//! Serve services. use std::{ convert::Infallible, fmt::Debug, future::{Future, IntoFuture}, io, marker::PhantomData, pin::pin, }; use axum_core::{body::Body, extract::Request, response::Response}; use futures_util::FutureExt; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::sync::watch; use tower::ServiceExt as _; use tower_service::Service; mod listener; pub use self::listener::{Listener, ListenerExt, TapIo}; /// Serve the service with the supplied listener. /// /// This method of running a service is intentionally simple and doesn't support any configuration. /// Use hyper or hyper-util if you need configuration. /// /// It supports both HTTP/1 as well as HTTP/2. /// /// # Examples /// /// Serving a [`Router`]: /// /// ``` /// use axum::{Router, routing::get}; /// /// # async { /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, router).await.unwrap(); /// # }; /// ``` /// /// See also [`Router::into_make_service_with_connect_info`]. /// /// Serving a [`MethodRouter`]: /// /// ``` /// use axum::routing::get; /// /// # async { /// let router = get(|| async { "Hello, World!" }); /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, router).await.unwrap(); /// # }; /// ``` /// /// See also [`MethodRouter::into_make_service_with_connect_info`]. /// /// Serving a [`Handler`]: /// /// ``` /// use axum::handler::HandlerWithoutStateExt; /// /// # async { /// async fn handler() -> &'static str { /// "Hello, World!" /// } /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, handler.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// See also [`HandlerWithoutStateExt::into_make_service_with_connect_info`] and /// [`HandlerService::into_make_service_with_connect_info`]. /// /// # Return Value /// /// Although this future resolves to `io::Result<()>`, it will never actually complete or return an /// error. Errors on the TCP socket will be handled by sleeping for a short while (currently, one /// second). /// /// [`Router`]: crate::Router /// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info /// [`MethodRouter`]: crate::routing::MethodRouter /// [`MethodRouter::into_make_service_with_connect_info`]: crate::routing::MethodRouter::into_make_service_with_connect_info /// [`Handler`]: crate::handler::Handler /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] pub fn serve(listener: L, make_service: M) -> Serve where L: Listener, M: for<'a> Service, Error = Infallible, Response = S>, S: Service + Clone + Send + 'static, S::Future: Send, { Serve { listener, make_service, _marker: PhantomData, } } /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] pub struct Serve { listener: L, make_service: M, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Serve where L: Listener, { /// Prepares a server to handle graceful shutdown when the provided future completes. /// /// # Example /// /// ``` /// use axum::{Router, routing::get}; /// /// # async { /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, router) /// .with_graceful_shutdown(shutdown_signal()) /// .await /// .unwrap(); /// # }; /// /// async fn shutdown_signal() { /// // ... /// } /// ``` /// /// # Return Value /// /// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never /// error. It returns `Ok(())` only after the `signal` future completes. pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { WithGracefulShutdown { listener: self.listener, make_service: self.make_service, signal, _marker: PhantomData, } } /// Returns the local address this server is bound to. pub fn local_addr(&self) -> io::Result { self.listener.local_addr() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Serve where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { async fn run(self) -> ! { let Self { mut listener, mut make_service, _marker, } = self; let (signal_tx, _signal_rx) = watch::channel(()); let (_close_tx, close_rx) = watch::channel(()); loop { let (io, remote_addr) = listener.accept().await; handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; } } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for Serve where L: Debug + 'static, M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { listener, make_service, _marker: _, } = self; let mut s = f.debug_struct("Serve"); s.field("listener", listener) .field("make_service", make_service); s.finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl IntoFuture for Serve where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { type Output = io::Result<()>; type IntoFuture = private::ServeFuture; fn into_future(self) -> Self::IntoFuture { private::ServeFuture(Box::pin(async move { self.run().await })) } } /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] pub struct WithGracefulShutdown { listener: L, make_service: M, signal: F, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl WithGracefulShutdown where L: Listener, { /// Returns the local address this server is bound to. pub fn local_addr(&self) -> io::Result { self.listener.local_addr() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl WithGracefulShutdown where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, { async fn run(self) { let Self { mut listener, mut make_service, signal, _marker, } = self; let (signal_tx, signal_rx) = watch::channel(()); tokio::spawn(async move { signal.await; trace!("received graceful shutdown signal. Telling tasks to shutdown"); drop(signal_rx); }); let (close_tx, close_rx) = watch::channel(()); loop { let (io, remote_addr) = tokio::select! { conn = listener.accept() => conn, _ = signal_tx.closed() => { trace!("signal received, not accepting new connections"); break; } }; handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; } drop(close_rx); drop(listener); trace!( "waiting for {} task(s) to finish", close_tx.receiver_count() ); close_tx.closed().await; } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for WithGracefulShutdown where L: Debug + 'static, M: Debug, S: Debug, F: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { listener, make_service, signal, _marker: _, } = self; f.debug_struct("WithGracefulShutdown") .field("listener", listener) .field("make_service", make_service) .field("signal", signal) .finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl IntoFuture for WithGracefulShutdown where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, { type Output = io::Result<()>; type IntoFuture = private::ServeFuture; fn into_future(self) -> Self::IntoFuture { private::ServeFuture(Box::pin(async move { self.run().await; Ok(()) })) } } async fn handle_connection( make_service: &mut M, signal_tx: &watch::Sender<()>, close_rx: &watch::Receiver<()>, io: ::Io, remote_addr: ::Addr, ) where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { let io = TokioIo::new(io); trace!("connection {remote_addr:?} accepted"); make_service .ready() .await .unwrap_or_else(|err| match err {}); let tower_service = make_service .call(IncomingStream { io: &io, remote_addr, }) .await .unwrap_or_else(|err| match err {}) .map_request(|req: Request| req.map(Body::new)); let hyper_service = TowerToHyperService::new(tower_service); let signal_tx = signal_tx.clone(); let close_rx = close_rx.clone(); tokio::spawn(async move { #[allow(unused_mut)] let mut builder = Builder::new(TokioExecutor::new()); // CONNECT protocol needed for HTTP/2 websockets #[cfg(feature = "http2")] builder.http2().enable_connect_protocol(); let mut conn = pin!(builder.serve_connection_with_upgrades(io, hyper_service)); let mut signal_closed = pin!(signal_tx.closed().fuse()); loop { tokio::select! { result = conn.as_mut() => { if let Err(_err) = result { trace!("failed to serve connection: {_err:#}"); } break; } _ = &mut signal_closed => { trace!("signal received in task, starting graceful shutdown"); conn.as_mut().graceful_shutdown(); } } } drop(close_rx); }); } /// An incoming stream. /// /// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. /// /// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo #[derive(Debug)] pub struct IncomingStream<'a, L> where L: Listener, { io: &'a TokioIo, remote_addr: L::Addr, } impl IncomingStream<'_, L> where L: Listener, { /// Get a reference to the inner IO type. pub fn io(&self) -> &L::Io { self.io.inner() } /// Returns the remote address that this stream is bound to. pub fn remote_addr(&self) -> &L::Addr { &self.remote_addr } } mod private { use std::{ future::Future, io, pin::Pin, task::{Context, Poll}, }; pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>); impl Future for ServeFuture { type Output = io::Result<()>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.0.as_mut().poll(cx) } } impl std::fmt::Debug for ServeFuture { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ServeFuture").finish_non_exhaustive() } } } #[cfg(test)] mod tests { use std::{ future::{pending, IntoFuture as _}, net::{IpAddr, Ipv4Addr}, }; use axum_core::{body::Body, extract::Request}; use http::StatusCode; use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixListener; use tokio::{ io::{self, AsyncRead, AsyncWrite}, net::TcpListener, }; #[cfg(unix)] use super::IncomingStream; use super::{serve, Listener}; #[cfg(unix)] use crate::extract::connect_info::Connected; use crate::{ body::to_bytes, handler::{Handler, HandlerWithoutStateExt}, routing::get, serve::ListenerExt, Router, }; #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { #[derive(Clone, Debug)] struct UdsConnectInfo; #[cfg(unix)] impl Connected> for UdsConnectInfo { fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self { Self } } let router: Router = Router::new(); let addr = "0.0.0.0:0"; let tcp_nodelay_listener = || async { TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| { if let Err(err) = tcp_stream.set_nodelay(true) { eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}"); } }) }; // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); serve(tcp_nodelay_listener().await, router.clone()) .await .unwrap(); #[cfg(unix)] serve(UnixListener::bind("").unwrap(), router.clone()); serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); serve( tcp_nodelay_listener().await, router.clone().into_make_service(), ); #[cfg(unix)] serve( UnixListener::bind("").unwrap(), router.clone().into_make_service(), ); serve( TcpListener::bind(addr).await.unwrap(), router .clone() .into_make_service_with_connect_info::(), ); serve( tcp_nodelay_listener().await, router .clone() .into_make_service_with_connect_info::(), ); #[cfg(unix)] serve( UnixListener::bind("").unwrap(), router.into_make_service_with_connect_info::(), ); // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); serve(tcp_nodelay_listener().await, get(handler)); #[cfg(unix)] serve(UnixListener::bind("").unwrap(), get(handler)); serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); serve( tcp_nodelay_listener().await, get(handler).into_make_service(), ); #[cfg(unix)] serve( UnixListener::bind("").unwrap(), get(handler).into_make_service(), ); serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service_with_connect_info::(), ); serve( tcp_nodelay_listener().await, get(handler).into_make_service_with_connect_info::(), ); #[cfg(unix)] serve( UnixListener::bind("").unwrap(), get(handler).into_make_service_with_connect_info::(), ); // handler serve( TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); serve(tcp_nodelay_listener().await, handler.into_service()); #[cfg(unix)] serve(UnixListener::bind("").unwrap(), handler.into_service()); serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); serve(tcp_nodelay_listener().await, handler.with_state(())); #[cfg(unix)] serve(UnixListener::bind("").unwrap(), handler.with_state(())); serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); serve(tcp_nodelay_listener().await, handler.into_make_service()); #[cfg(unix)] serve(UnixListener::bind("").unwrap(), handler.into_make_service()); serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service_with_connect_info::(), ); serve( tcp_nodelay_listener().await, handler.into_make_service_with_connect_info::(), ); #[cfg(unix)] serve( UnixListener::bind("").unwrap(), handler.into_make_service_with_connect_info::(), ); } async fn handler() {} #[crate::test] async fn test_serve_local_addr() { let router: Router = Router::new(); let addr = "0.0.0.0:0"; let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone()); let address = server.local_addr().unwrap(); assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); assert_ne!(address.port(), 0); } #[crate::test] async fn test_with_graceful_shutdown_local_addr() { let router: Router = Router::new(); let addr = "0.0.0.0:0"; let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone()) .with_graceful_shutdown(pending()); let address = server.local_addr().unwrap(); assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); assert_ne!(address.port(), 0); } #[test] fn into_future_outside_tokio() { let router: Router = Router::new(); let addr = "0.0.0.0:0"; let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap(); // Call Serve::into_future outside of a tokio context. This used to panic. _ = serve(listener, router).into_future(); } #[crate::test] async fn serving_on_custom_io_type() { struct ReadyListener(Option); impl Listener for ReadyListener where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Io = T; type Addr = (); async fn accept(&mut self) -> (Self::Io, Self::Addr) { match self.0.take() { Some(server) => (server, ()), None => std::future::pending().await, } } fn local_addr(&self) -> io::Result { Ok(()) } } let (client, server) = io::duplex(1024); let listener = ReadyListener(Some(server)); let app = Router::new().route("/", get(|| async { "Hello, World!" })); tokio::spawn(serve(listener, app).into_future()); let stream = TokioIo::new(client); let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap(); tokio::spawn(conn); let request = Request::builder().body(Body::empty()).unwrap(); let response = sender.send_request(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = Body::new(response.into_body()); let body = to_bytes(body, usize::MAX).await.unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert_eq!(body, "Hello, World!"); } }