use crate::{extract::rejection::*, response::IntoResponseParts};
use axum_core::extract::OptionalFromRequestParts;
use axum_core::{
    extract::FromRequestParts,
    response::{IntoResponse, Response, ResponseParts},
};
use http::{request::Parts, Extensions, Request};
use std::{
    convert::Infallible,
    task::{Context, Poll},
};
use tower_service::Service;

/// Extractor and response for extensions.
///
/// # As extractor
///
/// This is commonly used to share state across handlers.
///
/// ```rust,no_run
/// use axum::{
///     Router,
///     Extension,
///     routing::get,
/// };
/// use std::sync::Arc;
///
/// // Some shared state used throughout our application
/// struct State {
///     // ...
/// }
///
/// async fn handler(state: Extension<Arc<State>>) {
///     // ...
/// }
///
/// let state = Arc::new(State { /* ... */ });
///
/// let app = Router::new().route("/", get(handler))
///     // Add middleware that inserts the state into every incoming request's
///     // extensions.
///     .layer(Extension(state));
/// # let _: Router = app;
/// ```
///
/// If the extension is missing, the request is rejected with a `500 Internal
/// Server Error` response. Alternatively, you can use `Option<Extension<T>>` to
/// make the extension extractor optional: it yields `None` instead of rejecting.
///
/// # As response
///
/// Response extensions can be used to share state with middleware.
///
/// ```rust
/// use axum::{
///     Extension,
///     response::IntoResponse,
/// };
///
/// async fn handler() -> (Extension<Foo>, &'static str) {
///     (
///         Extension(Foo("foo")),
///         "Hello, World!"
///     )
/// }
///
/// #[derive(Clone)]
/// struct Foo(&'static str);
/// ```
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Extension<T>(pub T);

impl<T> Extension<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Looks up a value of type `T` in `extensions` and clones it into a new
    /// `Extension`, returning `None` when no value of that type is present.
    fn from_extensions(extensions: &Extensions) -> Option<Self> {
        extensions.get::<T>().cloned().map(Extension)
    }
}

/// Required extraction: when no `T` is present in the request extensions the
/// request is rejected with a `MissingExtension` error (wrapped in
/// `ExtensionRejection`) whose message names the missing type.
impl<T, S> FromRequestParts<S> for Extension<T>
where
    T: Clone + Send + Sync + 'static,
    S: Send + Sync,
{
    type Rejection = ExtensionRejection;

    async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        Self::from_extensions(&req.extensions)
            .ok_or_else(|| {
                // Name the missing type so the rejection tells the user what to add.
                MissingExtension::from_err(format!(
                    "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
                    std::any::type_name::<T>()
                ))
            })
            // Same `From` conversion the previous `Ok(...?)` performed, made explicit.
            .map_err(ExtensionRejection::from)
    }
}

/// Optional extraction (`Option<Extension<T>>`): never rejects; yields `None`
/// when no `T` is present in the request extensions.
impl<T, S> OptionalFromRequestParts<S> for Extension<T>
where
    T: Clone + Send + Sync + 'static,
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        req: &mut Parts,
        _state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        Ok(Self::from_extensions(&req.extensions))
    }
}

// Generates the deref impls for `Extension` (see `axum_core::__impl_deref`).
axum_core::__impl_deref!(Extension);

/// Inserts the wrapped value into the response's extensions; this cannot fail.
impl<T> IntoResponseParts for Extension<T>
where
    T: Clone + Send + Sync + 'static,
{
    type Error = Infallible;

    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
        res.extensions_mut().insert(self.0);
        Ok(res)
    }
}

/// Standalone response: the response that `()` converts into, carrying the
/// wrapped value in its extensions.
impl<T> IntoResponse for Extension<T>
where
    T: Clone + Send + Sync + 'static,
{
    fn into_response(self) -> Response {
        let mut res = ().into_response();
        res.extensions_mut().insert(self.0);
        res
    }
}

/// Using `Extension` as a layer wraps the inner service in [`AddExtension`],
/// which inserts a clone of the value into each request's extensions.
impl<S, T> tower_layer::Layer<S> for Extension<T>
where
    T: Clone + Send + Sync + 'static,
{
    type Service = AddExtension<S, T>;

    fn layer(&self, inner: S) -> Self::Service {
        AddExtension {
            inner,
            value: self.0.clone(),
        }
    }
}

/// Middleware for adding some shareable value to [request extensions].
///
/// See [Passing state from middleware to handlers](index.html#passing-state-from-middleware-to-handlers)
/// for more details.
/// /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[derive(Clone, Copy, Debug)] pub struct AddExtension { pub(crate) inner: S, pub(crate) value: T, } impl Service> for AddExtension where S: Service>, T: Clone + Send + Sync + 'static, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { req.extensions_mut().insert(self.value.clone()); self.inner.call(req) } } #[cfg(test)] mod tests { use super::*; use crate::routing::get; use crate::test_helpers::TestClient; use crate::Router; use http::StatusCode; #[derive(Clone)] struct Foo(String); #[derive(Clone)] struct Bar(String); #[crate::test] async fn extension_extractor() { async fn requires_foo(Extension(foo): Extension) -> String { foo.0 } async fn optional_foo(extension: Option>) -> String { extension.map(|foo| foo.0 .0).unwrap_or("none".to_owned()) } async fn requires_bar(Extension(bar): Extension) -> String { bar.0 } async fn optional_bar(extension: Option>) -> String { extension.map(|bar| bar.0 .0).unwrap_or("none".to_owned()) } let app = Router::new() .route("/requires_foo", get(requires_foo)) .route("/optional_foo", get(optional_foo)) .route("/requires_bar", get(requires_bar)) .route("/optional_bar", get(optional_bar)) .layer(Extension(Foo("foo".to_owned()))); let client = TestClient::new(app); let response = client.get("/requires_foo").await; assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.text().await, "foo"); let response = client.get("/optional_foo").await; assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.text().await, "foo"); let response = client.get("/requires_bar").await; assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.text().await, "Missing request extension: Extension of type `axum::extension::tests::Bar` was not found. 
Perhaps you forgot to add it? See `axum::Extension`."); let response = client.get("/optional_bar").await; assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.text().await, "none"); } }