//! Convert panics into responses. //! //! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result` //! whenever possible. //! //! # Example //! //! ```rust //! use http::{Request, Response, header::HeaderName}; //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::catch_panic::CatchPanicLayer; //! use http_body_util::Full; //! use bytes::Bytes; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! async fn handle(req: Request>) -> Result>, Infallible> { //! panic!("something went wrong...") //! } //! //! let mut svc = ServiceBuilder::new() //! // Catch panics and convert them into responses. //! .layer(CatchPanicLayer::new()) //! .service_fn(handle); //! //! // Call the service. //! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), 500); //! # //! # Ok(()) //! # } //! ``` //! //! Using a custom panic handler: //! //! ```rust //! use http::{Request, StatusCode, Response, header::{self, HeaderName}}; //! use std::{any::Any, convert::Infallible}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::catch_panic::CatchPanicLayer; //! use bytes::Bytes; //! use http_body_util::Full; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! async fn handle(req: Request>) -> Result>, Infallible> { //! panic!("something went wrong...") //! } //! //! fn handle_panic(err: Box) -> Response> { //! let details = if let Some(s) = err.downcast_ref::() { //! s.clone() //! } else if let Some(s) = err.downcast_ref::<&str>() { //! s.to_string() //! } else { //! "Unknown panic message".to_string() //! }; //! //! let body = serde_json::json!({ //! "error": { //! "kind": "panic", //! "details": details, //! } //! }); //! let body = serde_json::to_string(&body).unwrap(); //! //! Response::builder() //! .status(StatusCode::INTERNAL_SERVER_ERROR) //! .header(header::CONTENT_TYPE, "application/json") //! .body(Full::from(body)) //! .unwrap() //! } //! //! let svc = ServiceBuilder::new() //! // Use `handle_panic` to create the response. //! .layer(CatchPanicLayer::custom(handle_panic)) //! .service_fn(handle); //! # //! # Ok(()) //! # } //! ``` use bytes::Bytes; use futures_util::future::{CatchUnwind, FutureExt}; use http::{HeaderValue, Request, Response, StatusCode}; use http_body::Body; use http_body_util::BodyExt; use pin_project_lite::pin_project; use std::{ any::Any, future::Future, panic::AssertUnwindSafe, pin::Pin, task::{ready, Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; use crate::{ body::{Full, UnsyncBoxBody}, BoxError, }; /// Layer that applies the [`CatchPanic`] middleware that catches panics and converts them into /// `500 Internal Server` responses. /// /// See the [module docs](self) for an example. #[derive(Debug, Clone, Copy, Default)] pub struct CatchPanicLayer { panic_handler: T, } impl CatchPanicLayer { /// Create a new `CatchPanicLayer` with the default panic handler. pub fn new() -> Self { CatchPanicLayer { panic_handler: DefaultResponseForPanic, } } } impl CatchPanicLayer { /// Create a new `CatchPanicLayer` with a custom panic handler. pub fn custom(panic_handler: T) -> Self where T: ResponseForPanic, { Self { panic_handler } } } impl Layer for CatchPanicLayer where T: Clone, { type Service = CatchPanic; fn layer(&self, inner: S) -> Self::Service { CatchPanic { inner, panic_handler: self.panic_handler.clone(), } } } /// Middleware that catches panics and converts them into `500 Internal Server` responses. /// /// See the [module docs](self) for an example. #[derive(Debug, Clone, Copy)] pub struct CatchPanic { inner: S, panic_handler: T, } impl CatchPanic { /// Create a new `CatchPanic` with the default panic handler. pub fn new(inner: S) -> Self { Self { inner, panic_handler: DefaultResponseForPanic, } } } impl CatchPanic { define_inner_service_accessors!(); /// Create a new `CatchPanic` with a custom panic handler. pub fn custom(inner: S, panic_handler: T) -> Self where T: ResponseForPanic, { Self { inner, panic_handler, } } } impl Service> for CatchPanic where S: Service, Response = Response>, ResBody: Body + Send + 'static, ResBody::Error: Into, T: ResponseForPanic + Clone, T::ResponseBody: Body + Send + 'static, ::Error: Into, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) { Ok(future) => ResponseFuture { kind: Kind::Future { future: AssertUnwindSafe(future).catch_unwind(), panic_handler: Some(self.panic_handler.clone()), }, }, Err(panic_err) => ResponseFuture { kind: Kind::Panicked { panic_err: Some(panic_err), panic_handler: Some(self.panic_handler.clone()), }, }, } } } pin_project! { /// Response future for [`CatchPanic`]. pub struct ResponseFuture { #[pin] kind: Kind, } } pin_project! { #[project = KindProj] enum Kind { Panicked { panic_err: Option>, panic_handler: Option, }, Future { #[pin] future: CatchUnwind>, panic_handler: Option, } } } impl Future for ResponseFuture where F: Future, E>>, ResBody: Body + Send + 'static, ResBody::Error: Into, T: ResponseForPanic, T::ResponseBody: Body + Send + 'static, ::Error: Into, { type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { KindProj::Panicked { panic_err, panic_handler, } => { let panic_handler = panic_handler .take() .expect("future polled after completion"); let panic_err = panic_err.take().expect("future polled after completion"); Poll::Ready(Ok(response_for_panic(panic_handler, panic_err))) } KindProj::Future { future, panic_handler, } => match ready!(future.poll(cx)) { Ok(Ok(res)) => { Poll::Ready(Ok(res.map(|body| { UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync()) }))) } Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)), Err(panic_err) => Poll::Ready(Ok(response_for_panic( panic_handler .take() .expect("future polled after completion"), panic_err, ))), }, } } } fn response_for_panic( mut panic_handler: T, err: Box, ) -> Response> where T: ResponseForPanic, T::ResponseBody: Body + Send + 'static, ::Error: Into, { panic_handler .response_for_panic(err) .map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())) } /// Trait for creating responses from panics. pub trait ResponseForPanic: Clone { /// The body type used for responses to panics. type ResponseBody; /// Create a response from the panic error. fn response_for_panic( &mut self, err: Box, ) -> Response; } impl ResponseForPanic for F where F: FnMut(Box) -> Response + Clone, { type ResponseBody = B; fn response_for_panic( &mut self, err: Box, ) -> Response { self(err) } } /// The default `ResponseForPanic` used by `CatchPanic`. /// /// It will log the panic message and return a `500 Internal Server` error response with an empty /// body. #[derive(Debug, Default, Clone, Copy)] #[non_exhaustive] pub struct DefaultResponseForPanic; impl ResponseForPanic for DefaultResponseForPanic { type ResponseBody = Full; fn response_for_panic( &mut self, err: Box, ) -> Response { if let Some(s) = err.downcast_ref::() { tracing::error!("Service panicked: {}", s); } else if let Some(s) = err.downcast_ref::<&str>() { tracing::error!("Service panicked: {}", s); } else { tracing::error!( "Service panicked but `CatchPanic` was unable to downcast the panic info" ); }; let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked"))); *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; #[allow(clippy::declare_interior_mutable_const)] const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); res.headers_mut() .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); res } } #[cfg(test)] mod tests { #![allow(unreachable_code)] use super::*; use crate::test_helpers::Body; use http::Response; use std::convert::Infallible; use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] async fn panic_before_returning_future() { let svc = ServiceBuilder::new() .layer(CatchPanicLayer::new()) .service_fn(|_: Request| { panic!("service panic"); async { Ok::<_, Infallible>(Response::new(Body::empty())) } }); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); let body = crate::test_helpers::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } #[tokio::test] async fn panic_in_future() { let svc = ServiceBuilder::new() .layer(CatchPanicLayer::new()) .service_fn(|_: Request| async { panic!("future panic"); Ok::<_, Infallible>(Response::new(Body::empty())) }); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); let body = crate::test_helpers::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } }