diff options
Diffstat (limited to 'vendor/hyper-util/src/server/graceful.rs')
| -rw-r--r-- | vendor/hyper-util/src/server/graceful.rs | 488 |
1 files changed, 488 insertions, 0 deletions
diff --git a/vendor/hyper-util/src/server/graceful.rs b/vendor/hyper-util/src/server/graceful.rs new file mode 100644 index 00000000..b367fc8a --- /dev/null +++ b/vendor/hyper-util/src/server/graceful.rs @@ -0,0 +1,488 @@ +//! Utility to gracefully shutdown a server. +//! +//! This module provides a [`GracefulShutdown`] type, +//! which can be used to gracefully shutdown a server. +//! +//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs> +//! for an example of how to use this. + +use std::{ + fmt::{self, Debug}, + future::Future, + pin::Pin, + task::{self, Poll}, +}; + +use pin_project_lite::pin_project; +use tokio::sync::watch; + +/// A graceful shutdown utility +// Purposefully not `Clone`, see `watcher()` method for why. +pub struct GracefulShutdown { + tx: watch::Sender<()>, +} + +/// A watcher side of the graceful shutdown. +/// +/// This type can only watch a connection, it cannot trigger a shutdown. +/// +/// Call [`GracefulShutdown::watcher()`] to construct one of these. +pub struct Watcher { + rx: watch::Receiver<()>, +} + +impl GracefulShutdown { + /// Create a new graceful shutdown helper. + pub fn new() -> Self { + let (tx, _) = watch::channel(()); + Self { tx } + } + + /// Wrap a future for graceful shutdown watching. + pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> { + self.watcher().watch(conn) + } + + /// Create an owned type that can watch a connection. + /// + /// This method allows created an owned type that can be sent onto another + /// task before calling [`Watcher::watch()`]. + // Internal: this function exists because `Clone` allows footguns. + // If the `tx` were cloned (or the `rx`), race conditions can happens where + // one task starting a shutdown is scheduled and interwined with a task + // starting to watch a connection, and the "watch version" is one behind. + pub fn watcher(&self) -> Watcher { + let rx = self.tx.subscribe(); + Watcher { rx } + } + + /// Signal shutdown for all watched connections. + /// + /// This returns a `Future` which will complete once all watched + /// connections have shutdown. + pub async fn shutdown(self) { + let Self { tx } = self; + + // signal all the watched futures about the change + let _ = tx.send(()); + // and then wait for all of them to complete + tx.closed().await; + } + + /// Returns the number of the watching connections. + pub fn count(&self) -> usize { + self.tx.receiver_count() + } +} + +impl Debug for GracefulShutdown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulShutdown").finish() + } +} + +impl Default for GracefulShutdown { + fn default() -> Self { + Self::new() + } +} + +impl Watcher { + /// Wrap a future for graceful shutdown watching. + pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> { + let Watcher { mut rx } = self; + GracefulConnectionFuture::new(conn, async move { + let _ = rx.changed().await; + // hold onto the rx until the watched future is completed + rx + }) + } +} + +impl Debug for Watcher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulWatcher").finish() + } +} + +pin_project! { + struct GracefulConnectionFuture<C, F: Future> { + #[pin] + conn: C, + #[pin] + cancel: F, + #[pin] + // If cancelled, this is held until the inner conn is done. + cancelled_guard: Option<F::Output>, + } +} + +impl<C, F: Future> GracefulConnectionFuture<C, F> { + fn new(conn: C, cancel: F) -> Self { + Self { + conn, + cancel, + cancelled_guard: None, + } + } +} + +impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulConnectionFuture").finish() + } +} + +impl<C, F> Future for GracefulConnectionFuture<C, F> +where + C: GracefulConnection, + F: Future, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + if this.cancelled_guard.is_none() { + if let Poll::Ready(guard) = this.cancel.poll(cx) { + this.cancelled_guard.set(Some(guard)); + this.conn.as_mut().graceful_shutdown(); + } + } + this.conn.poll(cx) + } +} + +/// An internal utility trait as an umbrella target for all (hyper) connection +/// types that the [`GracefulShutdown`] can watch. +pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed { + /// The error type returned by the connection when used as a future. + type Error; + + /// Start a graceful shutdown process for this connection. + fn graceful_shutdown(self: Pin<&mut Self>); +} + +#[cfg(feature = "http1")] +impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S> +where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http1::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "http2")] +impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E> +where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http2::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E> +where + S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, +{ + type Error = Box<dyn std::error::Error + Send + Sync>; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<I, B, S, E> GracefulConnection + for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E> +where + S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, +{ + type Error = Box<dyn std::error::Error + Send + Sync>; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self); + } +} + +mod private { + pub trait Sealed {} + + #[cfg(feature = "http1")] + impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S> + where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + } + + #[cfg(feature = "http1")] + impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S> + where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + } + + #[cfg(feature = "http2")] + impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E> + where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, + { + } + + #[cfg(feature = "server-auto")] + impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E> + where + S: hyper::service::Service< + http::Request<hyper::body::Incoming>, + Response = http::Response<B>, + >, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, + { + } + + #[cfg(feature = "server-auto")] + impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E> + where + S: hyper::service::Service< + http::Request<hyper::body::Incoming>, + Response = http::Response<B>, + >, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, + { + } +} + +#[cfg(test)] +mod test { + use super::*; + use pin_project_lite::pin_project; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + pin_project! { + #[derive(Debug)] + struct DummyConnection<F> { + #[pin] + future: F, + shutdown_counter: Arc<AtomicUsize>, + } + } + + impl<F> private::Sealed for DummyConnection<F> {} + + impl<F: Future> GracefulConnection for DummyConnection<F> { + type Error = (); + + fn graceful_shutdown(self: Pin<&mut Self>) { + self.shutdown_counter.fetch_add(1, Ordering::SeqCst); + } + } + + impl<F: Future> Future for DummyConnection<F> { + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.project().future.poll(cx) { + Poll::Ready(_) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + let (dummy_tx, _) = tokio::sync::broadcast::channel(1); + + for i in 1..=3 { + let mut dummy_rx = dummy_tx.subscribe(); + let shutdown_counter = shutdown_counter.clone(); + + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await; + let _ = dummy_rx.recv().await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + let _ = dummy_tx.send(()); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_delayed_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + //tokio::time::sleep(std::time::Duration::from_millis(i * 5)).await; + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_multi_per_watcher_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + let mut futures = Vec::new(); + for u in 1..=i { + let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50)); + let dummy_conn = DummyConnection { + future, + shutdown_counter: shutdown_counter.clone(), + }; + let conn = graceful.watch(dummy_conn); + futures.push(conn); + } + tokio::spawn(async move { + futures_util::future::join_all(futures).await; + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6); + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_timeout() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + let future = async move { + if i == 1 { + std::future::pending::<()>().await + } else { + std::future::ready(()).await + } + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + }, + _ = graceful.shutdown() => { + panic!("shutdown should not be completed: as not all our conns finish") + } + } + } +} |
