use crate::timeout::body::TimeoutBody; use http::{Request, Response, StatusCode}; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{ready, Context, Poll}, time::Duration, }; use tokio::time::Sleep; use tower_layer::Layer; use tower_service::Service; /// Layer that applies the [`Timeout`] middleware which apply a timeout to requests. /// /// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct TimeoutLayer { timeout: Duration, } impl TimeoutLayer { /// Creates a new [`TimeoutLayer`]. pub fn new(timeout: Duration) -> Self { TimeoutLayer { timeout } } } impl Layer for TimeoutLayer { type Service = Timeout; fn layer(&self, inner: S) -> Self::Service { Timeout::new(inner, self.timeout) } } /// Middleware which apply a timeout to requests. /// /// If the request does not complete within the specified timeout it will be aborted and a `408 /// Request Timeout` response will be sent. /// /// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct Timeout { inner: S, timeout: Duration, } impl Timeout { /// Creates a new [`Timeout`]. pub fn new(inner: S, timeout: Duration) -> Self { Self { inner, timeout } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(timeout: Duration) -> TimeoutLayer { TimeoutLayer::new(timeout) } } impl Service> for Timeout where S: Service, Response = Response>, ResBody: Default, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let sleep = tokio::time::sleep(self.timeout); ResponseFuture { inner: self.inner.call(req), sleep, } } } pin_project! { /// Response future for [`Timeout`]. pub struct ResponseFuture { #[pin] inner: F, #[pin] sleep: Sleep, } } impl Future for ResponseFuture where F: Future, E>>, B: Default, { type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if this.sleep.poll(cx).is_ready() { let mut res = Response::new(B::default()); *res.status_mut() = StatusCode::REQUEST_TIMEOUT; return Poll::Ready(Ok(res)); } this.inner.poll(cx) } } /// Applies a [`TimeoutBody`] to the request body. #[derive(Clone, Debug)] pub struct RequestBodyTimeoutLayer { timeout: Duration, } impl RequestBodyTimeoutLayer { /// Creates a new [`RequestBodyTimeoutLayer`]. pub fn new(timeout: Duration) -> Self { Self { timeout } } } impl Layer for RequestBodyTimeoutLayer { type Service = RequestBodyTimeout; fn layer(&self, inner: S) -> Self::Service { RequestBodyTimeout::new(inner, self.timeout) } } /// Applies a [`TimeoutBody`] to the request body. #[derive(Clone, Debug)] pub struct RequestBodyTimeout { inner: S, timeout: Duration, } impl RequestBodyTimeout { /// Creates a new [`RequestBodyTimeout`]. pub fn new(service: S, timeout: Duration) -> Self { Self { inner: service, timeout, } } /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer { RequestBodyTimeoutLayer::new(timeout) } define_inner_service_accessors!(); } impl Service> for RequestBodyTimeout where S: Service>>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let req = req.map(|body| TimeoutBody::new(self.timeout, body)); self.inner.call(req) } } /// Applies a [`TimeoutBody`] to the response body. #[derive(Clone)] pub struct ResponseBodyTimeoutLayer { timeout: Duration, } impl ResponseBodyTimeoutLayer { /// Creates a new [`ResponseBodyTimeoutLayer`]. pub fn new(timeout: Duration) -> Self { Self { timeout } } } impl Layer for ResponseBodyTimeoutLayer { type Service = ResponseBodyTimeout; fn layer(&self, inner: S) -> Self::Service { ResponseBodyTimeout::new(inner, self.timeout) } } /// Applies a [`TimeoutBody`] to the response body. #[derive(Clone)] pub struct ResponseBodyTimeout { inner: S, timeout: Duration, } impl Service> for ResponseBodyTimeout where S: Service, Response = Response>, { type Response = Response>; type Error = S::Error; type Future = ResponseBodyTimeoutFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { ResponseBodyTimeoutFuture { inner: self.inner.call(req), timeout: self.timeout, } } } impl ResponseBodyTimeout { /// Creates a new [`ResponseBodyTimeout`]. pub fn new(service: S, timeout: Duration) -> Self { Self { inner: service, timeout, } } /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer { ResponseBodyTimeoutLayer::new(timeout) } define_inner_service_accessors!(); } pin_project! { /// Response future for [`ResponseBodyTimeout`]. pub struct ResponseBodyTimeoutFuture { #[pin] inner: Fut, timeout: Duration, } } impl Future for ResponseBodyTimeoutFuture where Fut: Future, E>>, { type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let timeout = self.timeout; let this = self.project(); let res = ready!(this.inner.poll(cx))?; Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body)))) } }