summaryrefslogtreecommitdiff
path: root/vendor/hyper-util/src/server/graceful.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/hyper-util/src/server/graceful.rs')
-rw-r--r--vendor/hyper-util/src/server/graceful.rs488
1 files changed, 488 insertions, 0 deletions
diff --git a/vendor/hyper-util/src/server/graceful.rs b/vendor/hyper-util/src/server/graceful.rs
new file mode 100644
index 00000000..b367fc8a
--- /dev/null
+++ b/vendor/hyper-util/src/server/graceful.rs
@@ -0,0 +1,488 @@
+//! Utility to gracefully shutdown a server.
+//!
+//! This module provides a [`GracefulShutdown`] type,
+//! which can be used to gracefully shutdown a server.
+//!
+//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs>
+//! for an example of how to use this.
+
+use std::{
+ fmt::{self, Debug},
+ future::Future,
+ pin::Pin,
+ task::{self, Poll},
+};
+
+use pin_project_lite::pin_project;
+use tokio::sync::watch;
+
+/// A graceful shutdown utility
+// Purposefully not `Clone`, see `watcher()` method for why.
+pub struct GracefulShutdown {
+ tx: watch::Sender<()>,
+}
+
+/// A watcher side of the graceful shutdown.
+///
+/// This type can only watch a connection, it cannot trigger a shutdown.
+///
+/// Call [`GracefulShutdown::watcher()`] to construct one of these.
+pub struct Watcher {
+ rx: watch::Receiver<()>,
+}
+
+impl GracefulShutdown {
+ /// Create a new graceful shutdown helper.
+ pub fn new() -> Self {
+ let (tx, _) = watch::channel(());
+ Self { tx }
+ }
+
+ /// Wrap a future for graceful shutdown watching.
+ pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
+ self.watcher().watch(conn)
+ }
+
+ /// Create an owned type that can watch a connection.
+ ///
+ /// This method allows created an owned type that can be sent onto another
+ /// task before calling [`Watcher::watch()`].
+ // Internal: this function exists because `Clone` allows footguns.
+ // If the `tx` were cloned (or the `rx`), race conditions can happens where
+ // one task starting a shutdown is scheduled and interwined with a task
+ // starting to watch a connection, and the "watch version" is one behind.
+ pub fn watcher(&self) -> Watcher {
+ let rx = self.tx.subscribe();
+ Watcher { rx }
+ }
+
+ /// Signal shutdown for all watched connections.
+ ///
+ /// This returns a `Future` which will complete once all watched
+ /// connections have shutdown.
+ pub async fn shutdown(self) {
+ let Self { tx } = self;
+
+ // signal all the watched futures about the change
+ let _ = tx.send(());
+ // and then wait for all of them to complete
+ tx.closed().await;
+ }
+
+ /// Returns the number of the watching connections.
+ pub fn count(&self) -> usize {
+ self.tx.receiver_count()
+ }
+}
+
+impl Debug for GracefulShutdown {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("GracefulShutdown").finish()
+ }
+}
+
+impl Default for GracefulShutdown {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl Watcher {
+ /// Wrap a future for graceful shutdown watching.
+ pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> {
+ let Watcher { mut rx } = self;
+ GracefulConnectionFuture::new(conn, async move {
+ let _ = rx.changed().await;
+ // hold onto the rx until the watched future is completed
+ rx
+ })
+ }
+}
+
+impl Debug for Watcher {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("GracefulWatcher").finish()
+ }
+}
+
+pin_project! {
+ struct GracefulConnectionFuture<C, F: Future> {
+ #[pin]
+ conn: C,
+ #[pin]
+ cancel: F,
+ #[pin]
+ // If cancelled, this is held until the inner conn is done.
+ cancelled_guard: Option<F::Output>,
+ }
+}
+
+impl<C, F: Future> GracefulConnectionFuture<C, F> {
+ fn new(conn: C, cancel: F) -> Self {
+ Self {
+ conn,
+ cancel,
+ cancelled_guard: None,
+ }
+ }
+}
+
+impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("GracefulConnectionFuture").finish()
+ }
+}
+
+impl<C, F> Future for GracefulConnectionFuture<C, F>
+where
+ C: GracefulConnection,
+ F: Future,
+{
+ type Output = C::Output;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ let mut this = self.project();
+ if this.cancelled_guard.is_none() {
+ if let Poll::Ready(guard) = this.cancel.poll(cx) {
+ this.cancelled_guard.set(Some(guard));
+ this.conn.as_mut().graceful_shutdown();
+ }
+ }
+ this.conn.poll(cx)
+ }
+}
+
+/// An internal utility trait as an umbrella target for all (hyper) connection
+/// types that the [`GracefulShutdown`] can watch.
+pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
+ /// The error type returned by the connection when used as a future.
+ type Error;
+
+ /// Start a graceful shutdown process for this connection.
+ fn graceful_shutdown(self: Pin<&mut Self>);
+}
+
+#[cfg(feature = "http1")]
+impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
+where
+ S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+{
+ type Error = hyper::Error;
+
+ fn graceful_shutdown(self: Pin<&mut Self>) {
+ hyper::server::conn::http1::Connection::graceful_shutdown(self);
+ }
+}
+
+#[cfg(feature = "http2")]
+impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
+where
+ S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
+{
+ type Error = hyper::Error;
+
+ fn graceful_shutdown(self: Pin<&mut Self>) {
+ hyper::server::conn::http2::Connection::graceful_shutdown(self);
+ }
+}
+
+#[cfg(feature = "server-auto")]
+impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
+where
+ S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ S::Future: 'static,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
+{
+ type Error = Box<dyn std::error::Error + Send + Sync>;
+
+ fn graceful_shutdown(self: Pin<&mut Self>) {
+ crate::server::conn::auto::Connection::graceful_shutdown(self);
+ }
+}
+
+#[cfg(feature = "server-auto")]
+impl<I, B, S, E> GracefulConnection
+ for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
+where
+ S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ S::Future: 'static,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
+{
+ type Error = Box<dyn std::error::Error + Send + Sync>;
+
+ fn graceful_shutdown(self: Pin<&mut Self>) {
+ crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
+ }
+}
+
+mod private {
+ pub trait Sealed {}
+
+ #[cfg(feature = "http1")]
+ impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
+ where
+ S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ {
+ }
+
+ #[cfg(feature = "http1")]
+ impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
+ where
+ S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ {
+ }
+
+ #[cfg(feature = "http2")]
+ impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
+ where
+ S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
+ {
+ }
+
+ #[cfg(feature = "server-auto")]
+ impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
+ where
+ S: hyper::service::Service<
+ http::Request<hyper::body::Incoming>,
+ Response = http::Response<B>,
+ >,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ S::Future: 'static,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
+ {
+ }
+
+ #[cfg(feature = "server-auto")]
+ impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
+ where
+ S: hyper::service::Service<
+ http::Request<hyper::body::Incoming>,
+ Response = http::Response<B>,
+ >,
+ S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ S::Future: 'static,
+ I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
+ B: hyper::body::Body + 'static,
+ B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+ E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
+ {
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use pin_project_lite::pin_project;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::Arc;
+
+ pin_project! {
+ #[derive(Debug)]
+ struct DummyConnection<F> {
+ #[pin]
+ future: F,
+ shutdown_counter: Arc<AtomicUsize>,
+ }
+ }
+
+ impl<F> private::Sealed for DummyConnection<F> {}
+
+ impl<F: Future> GracefulConnection for DummyConnection<F> {
+ type Error = ();
+
+ fn graceful_shutdown(self: Pin<&mut Self>) {
+ self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
+ }
+ }
+
+ impl<F: Future> Future for DummyConnection<F> {
+ type Output = Result<(), ()>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
+ match self.project().future.poll(cx) {
+ Poll::Ready(_) => Poll::Ready(Ok(())),
+ Poll::Pending => Poll::Pending,
+ }
+ }
+ }
+
+ #[cfg(not(miri))]
+ #[tokio::test]
+ async fn test_graceful_shutdown_ok() {
+ let graceful = GracefulShutdown::new();
+ let shutdown_counter = Arc::new(AtomicUsize::new(0));
+ let (dummy_tx, _) = tokio::sync::broadcast::channel(1);
+
+ for i in 1..=3 {
+ let mut dummy_rx = dummy_tx.subscribe();
+ let shutdown_counter = shutdown_counter.clone();
+
+ let future = async move {
+ tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await;
+ let _ = dummy_rx.recv().await;
+ };
+ let dummy_conn = DummyConnection {
+ future,
+ shutdown_counter,
+ };
+ let conn = graceful.watch(dummy_conn);
+ tokio::spawn(async move {
+ conn.await.unwrap();
+ });
+ }
+
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
+ let _ = dummy_tx.send(());
+
+ tokio::select! {
+ _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
+ panic!("timeout")
+ },
+ _ = graceful.shutdown() => {
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
+ }
+ }
+ }
+
+ #[cfg(not(miri))]
+ #[tokio::test]
+ async fn test_graceful_shutdown_delayed_ok() {
+ let graceful = GracefulShutdown::new();
+ let shutdown_counter = Arc::new(AtomicUsize::new(0));
+
+ for i in 1..=3 {
+ let shutdown_counter = shutdown_counter.clone();
+
+ //tokio::time::sleep(std::time::Duration::from_millis(i * 5)).await;
+ let future = async move {
+ tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await;
+ };
+ let dummy_conn = DummyConnection {
+ future,
+ shutdown_counter,
+ };
+ let conn = graceful.watch(dummy_conn);
+ tokio::spawn(async move {
+ conn.await.unwrap();
+ });
+ }
+
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
+
+ tokio::select! {
+ _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
+ panic!("timeout")
+ },
+ _ = graceful.shutdown() => {
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
+ }
+ }
+ }
+
+ #[cfg(not(miri))]
+ #[tokio::test]
+ async fn test_graceful_shutdown_multi_per_watcher_ok() {
+ let graceful = GracefulShutdown::new();
+ let shutdown_counter = Arc::new(AtomicUsize::new(0));
+
+ for i in 1..=3 {
+ let shutdown_counter = shutdown_counter.clone();
+
+ let mut futures = Vec::new();
+ for u in 1..=i {
+ let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50));
+ let dummy_conn = DummyConnection {
+ future,
+ shutdown_counter: shutdown_counter.clone(),
+ };
+ let conn = graceful.watch(dummy_conn);
+ futures.push(conn);
+ }
+ tokio::spawn(async move {
+ futures_util::future::join_all(futures).await;
+ });
+ }
+
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
+
+ tokio::select! {
+ _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
+ panic!("timeout")
+ },
+ _ = graceful.shutdown() => {
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
+ }
+ }
+ }
+
+ #[cfg(not(miri))]
+ #[tokio::test]
+ async fn test_graceful_shutdown_timeout() {
+ let graceful = GracefulShutdown::new();
+ let shutdown_counter = Arc::new(AtomicUsize::new(0));
+
+ for i in 1..=3 {
+ let shutdown_counter = shutdown_counter.clone();
+
+ let future = async move {
+ if i == 1 {
+ std::future::pending::<()>().await
+ } else {
+ std::future::ready(()).await
+ }
+ };
+ let dummy_conn = DummyConnection {
+ future,
+ shutdown_counter,
+ };
+ let conn = graceful.watch(dummy_conn);
+ tokio::spawn(async move {
+ conn.await.unwrap();
+ });
+ }
+
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
+
+ tokio::select! {
+ _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
+ assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
+ },
+ _ = graceful.shutdown() => {
+ panic!("shutdown should not be completed: as not all our conns finish")
+ }
+ }
+ }
+}