use std::{ future::Future, pin::Pin, task::{Context, Poll}, time::Duration, }; use pin_project_lite::pin_project; use tokio::time::Sleep; use tower::{BoxError, Layer, Service}; /// This tower layer injects an arbitrary delay before calling downstream layers. #[derive(Clone)] pub struct DelayLayer { delay: Duration, } impl DelayLayer { pub const fn new(delay: Duration) -> Self { DelayLayer { delay } } } impl Layer for DelayLayer { type Service = Delay; fn layer(&self, service: S) -> Self::Service { Delay::new(service, self.delay) } } impl std::fmt::Debug for DelayLayer { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("DelayLayer") .field("delay", &self.delay) .finish() } } /// This tower service injects an arbitrary delay before calling downstream layers. #[derive(Debug, Clone)] pub struct Delay { inner: S, delay: Duration, } impl Delay { pub fn new(inner: S, delay: Duration) -> Self { Delay { inner, delay } } } impl Service for Delay where S: Service, S::Error: Into, { type Response = S::Response; type Error = BoxError; type Future = ResponseFuture; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { match self.inner.poll_ready(cx) { Poll::Pending => Poll::Pending, Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), } } fn call(&mut self, req: Request) -> Self::Future { let response = self.inner.call(req); let sleep = tokio::time::sleep(self.delay); ResponseFuture::new(response, sleep) } } // `Delay` response future pin_project! { #[derive(Debug)] pub struct ResponseFuture { #[pin] response: S, #[pin] sleep: Sleep, } } impl ResponseFuture { pub(crate) fn new(response: S, sleep: Sleep) -> Self { ResponseFuture { response, sleep } } } impl Future for ResponseFuture where F: Future>, E: Into, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); // First poll the sleep until complete match this.sleep.poll(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(_) => {} } // Then poll the inner future match this.response.poll(cx) { Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), Poll::Pending => Poll::Pending, } } }