//! Set a header on the response. //! //! The header value to be set may be provided as a fixed value when the //! middleware is constructed, or determined dynamically based on the response //! by a closure. See the [`MakeHeaderValue`] trait for details. //! //! # Example //! //! Setting a header from a fixed value provided when the middleware is constructed: //! //! ``` //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetResponseHeaderLayer; //! use http_body_util::Full; //! use bytes::Bytes; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let render_html = tower::service_fn(|request: Request>| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # //! let mut svc = ServiceBuilder::new() //! .layer( //! // Layer that sets `Content-Type: text/html` on responses. //! // //! // `if_not_present` will only insert the header if it does not already //! // have a value. //! SetResponseHeaderLayer::if_not_present( //! header::CONTENT_TYPE, //! HeaderValue::from_static("text/html"), //! ) //! ) //! .service(render_html); //! //! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["content-type"], "text/html"); //! # //! # Ok(()) //! # } //! ``` //! //! Setting a header based on a value determined dynamically from the response: //! //! ``` //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetResponseHeaderLayer; //! use bytes::Bytes; //! use http_body_util::Full; //! use http_body::Body as _; // for `Body::size_hint` //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let render_html = tower::service_fn(|request: Request>| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(Full::from("1234567890"))) //! # }); //! # //! let mut svc = ServiceBuilder::new() //! .layer( //! // Layer that sets `Content-Length` if the body has a known size. //! // Bodies with streaming responses wont have a known size. //! // //! // `overriding` will insert the header and override any previous values it //! // may have. //! SetResponseHeaderLayer::overriding( //! header::CONTENT_LENGTH, //! |response: &Response>| { //! if let Some(size) = response.body().size_hint().exact() { //! // If the response body has a known size, returning `Some` will //! // set the `Content-Length` header to that value. //! Some(HeaderValue::from_str(&size.to_string()).unwrap()) //! } else { //! // If the response body doesn't have a known size, return `None` //! // to skip setting the header on this response. //! None //! } //! } //! ) //! ) //! .service(render_html); //! //! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["content-length"], "10"); //! # //! # Ok(()) //! # } //! ``` use super::{InsertHeaderMode, MakeHeaderValue}; use http::{header::HeaderName, Request, Response}; use pin_project_lite::pin_project; use std::{ fmt, future::Future, pin::Pin, task::{ready, Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`SetResponseHeader`] which adds a response header. /// /// See [`SetResponseHeader`] for more details. pub struct SetResponseHeaderLayer { header_name: HeaderName, make: M, mode: InsertHeaderMode, } impl fmt::Debug for SetResponseHeaderLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SetResponseHeaderLayer") .field("header_name", &self.header_name) .field("mode", &self.mode) .field("make", &std::any::type_name::()) .finish() } } impl SetResponseHeaderLayer { /// Create a new [`SetResponseHeaderLayer`]. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. pub fn overriding(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::Override) } /// Create a new [`SetResponseHeaderLayer`]. /// /// The new header is always added, preserving any existing values. If previous values exist, /// the header will have multiple values. pub fn appending(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::Append) } /// Create a new [`SetResponseHeaderLayer`]. /// /// If a previous value exists for the header, the new value is not inserted. pub fn if_not_present(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::IfNotPresent) } fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { Self { make, header_name, mode, } } } impl Layer for SetResponseHeaderLayer where M: Clone, { type Service = SetResponseHeader; fn layer(&self, inner: S) -> Self::Service { SetResponseHeader { inner, header_name: self.header_name.clone(), make: self.make.clone(), mode: self.mode, } } } impl Clone for SetResponseHeaderLayer where M: Clone, { fn clone(&self) -> Self { Self { make: self.make.clone(), header_name: self.header_name.clone(), mode: self.mode, } } } /// Middleware that sets a header on the response. #[derive(Clone)] pub struct SetResponseHeader { inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode, } impl SetResponseHeader { /// Create a new [`SetResponseHeader`]. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::Override) } /// Create a new [`SetResponseHeader`]. /// /// The new header is always added, preserving any existing values. If previous values exist, /// the header will have multiple values. pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::Append) } /// Create a new [`SetResponseHeader`]. /// /// If a previous value exists for the header, the new value is not inserted. pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) } fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { Self { inner, header_name, make, mode, } } define_inner_service_accessors!(); } impl fmt::Debug for SetResponseHeader where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SetResponseHeader") .field("inner", &self.inner) .field("header_name", &self.header_name) .field("mode", &self.mode) .field("make", &std::any::type_name::()) .finish() } } impl Service> for SetResponseHeader where S: Service, Response = Response>, M: MakeHeaderValue> + Clone, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { ResponseFuture { future: self.inner.call(req), header_name: self.header_name.clone(), make: self.make.clone(), mode: self.mode, } } } pin_project! { /// Response future for [`SetResponseHeader`]. #[derive(Debug)] pub struct ResponseFuture { #[pin] future: F, header_name: HeaderName, make: M, mode: InsertHeaderMode, } } impl Future for ResponseFuture where F: Future, E>>, M: MakeHeaderValue>, { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = ready!(this.future.poll(cx)?); this.mode.apply(this.header_name, &mut res, &mut *this.make); Poll::Ready(Ok(res)) } } #[cfg(test)] mod tests { use super::*; use crate::test_helpers::Body; use http::{header, HeaderValue}; use std::convert::Infallible; use tower::{service_fn, ServiceExt}; #[tokio::test] async fn test_override_mode() { let svc = SetResponseHeader::overriding( service_fn(|_req: Request| async { let res = Response::builder() .header(header::CONTENT_TYPE, "good-content") .body(Body::empty()) .unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "text/html"); assert_eq!(values.next(), None); } #[tokio::test] async fn test_append_mode() { let svc = SetResponseHeader::appending( service_fn(|_req: Request| async { let res = Response::builder() .header(header::CONTENT_TYPE, "good-content") .body(Body::empty()) .unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "good-content"); assert_eq!(values.next().unwrap(), "text/html"); assert_eq!(values.next(), None); } #[tokio::test] async fn test_skip_if_present_mode() { let svc = SetResponseHeader::if_not_present( service_fn(|_req: Request| async { let res = Response::builder() .header(header::CONTENT_TYPE, "good-content") .body(Body::empty()) .unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "good-content"); assert_eq!(values.next(), None); } #[tokio::test] async fn test_skip_if_present_mode_when_not_present() { let svc = SetResponseHeader::if_not_present( service_fn(|_req: Request| async { let res = Response::builder().body(Body::empty()).unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "text/html"); assert_eq!(values.next(), None); } }