diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-10 13:11:11 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-10 13:11:11 -0600 |
| commit | 01959b16a21b22b5df5f16569c2a8e8f92beecef (patch) | |
| tree | 32afa5d747c5466345c59ec52161a7cba3d6d755 /vendor/tower-http | |
| parent | ff30574117a996df332e23d1fb6f65259b316b5b (diff) | |
chore: vendor dependencies
Diffstat (limited to 'vendor/tower-http')
100 files changed, 23781 insertions, 0 deletions
diff --git a/vendor/tower-http/.cargo-checksum.json b/vendor/tower-http/.cargo-checksum.json new file mode 100644 index 00000000..e4a2f2d6 --- /dev/null +++ b/vendor/tower-http/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"CHANGELOG.md":"0444e6082103f5ef9dc07ee95f7fd58c671729aca659f59db2a7deef562e2293","Cargo.lock":"c0d3a88b20acd5608170dce9f30a4edfa8696cecd9a93d8caa2ae6601a118aec","Cargo.toml":"14d34d789f3b0539fb6b7a7698bf0f22b525ef01b26a21157b54768968459a04","LICENSE":"5049cf464977eff4b4fcfa7988d84e74116956a3eb9d5f1d451b3f828f945233","README.md":"b222111103d0522442def1ccb0277520bf01c15a8902738dfd16d472e0375063","src/add_extension.rs":"689e8d8c0319911391a533616b1f642cebc7e3dc0838594b1582d7bb0dc0175c","src/auth/add_authorization.rs":"a5d866ead65ff7ca5bc89f7ec479c385dc90715d48911286bbd10d8a1ed644e8","src/auth/async_require_authorization.rs":"92a6bff979155de818f12b7d72841e0917a58789656bd7f6b06f3ad8a325857f","src/auth/mod.rs":"4a16268a7bfa5ca1f110c6eff9b8733ebfe96c0d280607c6b1f1154a92ccf128","src/auth/require_authorization.rs":"b8bbc94cc7d4e52a4e2305291b48c10012be7a8275f286d3d529e258995cb7cc","src/body.rs":"b6cb0269c26cd23838595288c495893b5eae5aa71b580dcae060ac157cb38af1","src/builder.rs":"0d471da399d156e83fb0831496c8d712d963cbb324523615827c631542b39711","src/catch_panic.rs":"4f933eeeb4b7e0ca0ea3ae5e72814efced9bbf629c2cf772bf00c9b7c6369ffc","src/classify/grpc_errors_as_failures.rs":"25bb3df620570525b2efe969dafef7198c458a2f3bbbf6cc0db1243f7bb1c65e","src/classify/map_failure_class.rs":"a9f7d36f3aede2990205a88c06542d68126bb58f0f530bb1f8716fa47587f1f6","src/classify/mod.rs":"b79278bee5fd8b7b9e1cae620ffbd3747f3d025783ac57746f743f1f4112f36b","src/classify/status_in_range_is_error.rs":"e63e5dbf55a37c05bf7c90d873870a65ebf598ddb921285b6ea338195780610a","src/compression/body.rs":"3e6fa2c569770abada1c9c21469baa0e85186fec3a537cf8797bf0607347411b","src/compression/future.rs":"77ca63e0c371e7a56e2401cbd5118dc013a35ad61090d5363f43ef650d86276f","src/compression/layer.rs":"6a57b901b4284fbfd3c3cd784a1b3ce90766510fd1ada8968ce6cea7e2989887","src/compression/mod.rs":"8c8284a429b75e467f7cd07bfd11f6144d199052f8db7ac5e91d811898514981","src/compression/pin_project_cfg.rs":"a98033f4b8b12f8372ba51522f22a8610d005e82e29b3e24e28b4a8cb902b2ef","src/compression/predicate.rs":"70753e77ed770ebb8c2fa40fffa79463badd27a42a0a51bcd7c65f21dc36962f","src/compression/service.rs":"61e5aa8d19cb032875d7024002926d132c9b7e8120ef6e3d2e68f68261e8e327","src/compression_utils.rs":"c76b1a47fa0cb336f92adc87859f146e0e7d9770e5abbc67c0a4230357e10913","src/content_encoding.rs":"dea2320a31bdcb02c4e36fbe1d0b1da614dab9d063f4a8940a0effdd1ed76e15","src/cors/allow_credentials.rs":"9b4e114f78f08e9fe583bcca642df069c4af9b97af82a1d611fd3223a6f254ea","src/cors/allow_headers.rs":"a30892e380530864a17e4fe432d3d8b5f3f3da0afb37bd19ac309833ee34734a","src/cors/allow_methods.rs":"ea35bf01415873be4a013d4851e495750d409854384c3bc32e24aed18b9079fd","src/cors/allow_origin.rs":"c818cc8830d137748e670b519eb43080274a0afd8d189428eff4e04d6d9edbfe","src/cors/allow_private_network.rs":"2ba150f838456be32588c2b95cee668e7f96ab191a769694c880a86d8a3c1814","src/cors/expose_headers.rs":"9ab7d0dbfa921013c9be7679d89cb811649b2606c51becc72a16f627eeea465b","src/cors/max_age.rs":"0d2b67f1c092b9d36b20e05a04bfdf7eb57215c23a04cf2ca5fae07fca066697","src/cors/mod.rs":"129427ee1de43ef69dc0b29265b2e300306699a6655f70d96d93d69df9c50e0e","src/cors/tests.rs":"459135218afdf6aa74629b7d9eb87dfaf60fae9e0333c510e9f6350a3a5a7f93","src/cors/vary.rs":"1f60668bb835da2c71d58711aa7f08f9707de65844055b9a4fed5ba608f1127e","src/decompression/body.rs":"be0cf9b51de29c80eff7b00201f021357b82a79d9fc1cc0a41eb730dfee39993","src/decompression/future.rs":"a9cfc2f175854bb85b76901e8dbcbfdf743db92c822ccf589647ba18ef82c730","src/decompression/layer.rs":"98d13d3a107ad4809b5bfbc6e509cde0c0876ae4596f6ae5d985d007594cbbdf","src/decompression/mod.rs":"605b418aca8c8655b753772d2d5c302e521b475d8d776276b2b53cb597694c7d","src/decompression/request/future.rs":"d7da33415760ef36cd42d519ff44f0157333f5f64e5372deb7b68fde058ed95c","src/decompression/request/layer.rs":"f17a14ab9d8408067767b058a6fb848a2891b9f68fbbf6e192086e8f00bc7d89","src/decompression/request/mod.rs":"57b9e4844d6b9320547b05e00a2256737afd38e86fc17fefb1b62974ac6d8e9e","src/decompression/request/service.rs":"af905c7eee15d72840ec4685bc2e68854ff1103a760504a6db91d00de47b7f93","src/decompression/service.rs":"94accf60490c9e6184b1da72f0a9dc9f4a2428481955f23f920a381348e19860","src/follow_redirect/mod.rs":"9acdbb54abec919498cbaf4d57fbef5b659a69897714e6c9d8d2002f17bebe24","src/follow_redirect/policy/and.rs":"a62623042f4d13029ca0d35a21cab20f26bf98fff0d321dd19ee6eadef96ee02","src/follow_redirect/policy/clone_body_fn.rs":"3a78bf37d4bd000d9c2d60d84a2d02d2d0ae584a0790da2dcdb34fab43fcd557","src/follow_redirect/policy/filter_credentials.rs":"918ce212685ce6501f78e6346c929fec8e01e81b26d681f6d3c86d88fe2eaa97","src/follow_redirect/policy/limited.rs":"b958035fc38736e12ef2a5421366de51c806c8ca849e8e9310b9d14e8b0b1e07","src/follow_redirect/policy/mod.rs":"e4185953e23944928f49eb9debe59da78ffb1fd87e5ac03cbab0079ccef3e316","src/follow_redirect/policy/or.rs":"02de001232c92a9e7e19cdef70b1321df181c6323973e6297642cc234dbf3119","src/follow_redirect/policy/redirect_fn.rs":"f4f7bf9219df8da1021ef9f44b07b50814dfa0728c8dbf52090e0dfab0b8edcc","src/follow_redirect/policy/same_origin.rs":"9c47be5b615c3dd31db8056e324a4bc87b0510c19df09b6a9e5f7ea8de2829fe","src/lib.rs":"ebbdd27b9937f9f944c4ac14842f46f0ae12d2321df149cee6e8ee9058c33667","src/limit/body.rs":"aa59aba00aae4ca98097d746efeabff2f650e1ac60bbea30179644e76b5ea1f9","src/limit/future.rs":"6c6feba8766a38e8fd68df7f73677a13047d71911acc0578f80b7d70ab0142d0","src/limit/layer.rs":"a9ebe5e09b32d7ca71442f855af622b3657dca140d95acf7021c4d57c7b50576","src/limit/mod.rs":"22ecc0e5cf5e2d526da2b04e8ec8af715ae58338fc4031e05050d4a64ce79c8a","src/limit/service.rs":"b73f2016d1feb0b61fc4924597cbb06a43106ce213ce8104feabf953c7eefe2d","src/macros.rs":"d9e425b15811809ef9a002c7f86376737ca401a435630f59d4451c1863aed823","src/map_request_body.rs":"35eb77cc8d2257b849348a68aae050d9dee7a0869b433748b4a038cc8c70ee2f","src/map_response_body.rs":"691263b8c85bd595aed6c55a3d3e2fd6d8e19dca77dd2e5df283fba581b7df56","src/metrics/in_flight_requests.rs":"615652b49586bb809b8f841f15ee3ba70c8e62620e31c81d062e1100e88619e2","src/metrics/mod.rs":"71d79df8dcb242f4925dc9b0d380d3c4fa7ae1f3d6a125f9db4f8a4ee3be9b3d","src/normalize_path.rs":"7859fefac3eb454e1fe52a315b96f018b0241e1a3c1202fd3a4032fda7df03a6","src/propagate_header.rs":"d123d13557a28d9a797486285233cb4bade6fc318d803d07f8e93bca831a7750","src/request_id.rs":"d555b5df675a1497d97768dd0d280d2d8a3d1d3abecc8969085559a483241f25","src/sensitive_headers.rs":"ab78f92d1482a3818f743b412316d6073dae6bf94ee06b22b60fe488d645cbbc","src/service_ext.rs":"74754c8d6e5d4c1a06e609244ffc7680094c6a33ce44e9bd024a7ad1822c183e","src/services/fs/mod.rs":"f89e32e1c49d567259e0828896896855b238e35dd3bfb17cd52b261e11087507","src/services/fs/serve_dir/future.rs":"743d45c645c126d3de208701e718746b13e89ccb9b0c48c142d6848b4662fb33","src/services/fs/serve_dir/headers.rs":"d48fb514ca575e5e380f80eb84521a5fcab0560e57a995f1fef5ca35590400e8","src/services/fs/serve_dir/mod.rs":"3efa5af990c5d92021563aebf33a99ae27782bc553a3fee36906610e109cbfd6","src/services/fs/serve_dir/open_file.rs":"2d494cd9b75d89d6299d24565bb36015621a9923acbe647800bc9a8eeac56f9f","src/services/fs/serve_dir/tests.rs":"a3672e27ea51a42d0284b4bb02c490b5171e0fcf14961a4ce37bea9bb369a3ed","src/services/fs/serve_file.rs":"a2f6feee5b6261d6643c91a2b6c5547f3dea2b48e823ccea4d95ec0f0f4bb561","src/services/mod.rs":"177bf1406c13c0386c82b247e6d1556c55c7a2f6704de7e50dbc987400713b96","src/services/redirect.rs":"480cb9d2fefdcbe1f70c428a78faa3aa283a4f44eb26dff5c2d36cdd543d011a","src/set_header/mod.rs":"642db9ef0fed43886e12311e26ae2522d25275ff9263cb5f0ef500ed6ce5f6bd","src/set_header/request.rs":"6261a0e89413bf8d5bcecb25acbf0c7eee6bbfd57590676ac22d4dfe46c12ad1","src/set_header/response.rs":"b6c659771a61cfba3ab814df32353582c05f695bf6bf21a2265d7d6ccb709440","src/set_status.rs":"9dfc8c6d598a45483b8a489d6362b7bb8debd1feb0c8304a69c003a6ae4882d3","src/test_helpers.rs":"31db04c69cc898baa7dd6c33a42356324585c98d4e120fa45a3c1b00da336fd5","src/timeout/body.rs":"fb2ae892a082f00835084693bdffe2980f5c94cd45ef0be05109ba8cae35b351","src/timeout/mod.rs":"8032cbcc0863d22d9bd3f89dda5e7dd85574e53811ab5c98c99aaa12d21bd646","src/timeout/service.rs":"fa1b0868ab246d9ad01504e89e5c2bf11564b72e4becc97f7e8967434537c2b4","src/trace/body.rs":"c4aabdc0c6799e8425ca750730f6a6c108727f0c48cef57f2160a5cc22e96ecb","src/trace/future.rs":"1b0334a22f07017b589e51f6d7bda472161ac58435202be031a5aab0e741e266","src/trace/layer.rs":"9f9a52c51f356fa0e2f4e83942dac1e04424d52589723ff74927754d68b38a77","src/trace/make_span.rs":"1066f20074c3da019901f4c614b18e8bd574170fe6cdcbc090ab9bf42361f876","src/trace/mod.rs":"57a650c3d191281c502ee6701f7ff0648e2e95aabbea4b037583910d40b2f075","src/trace/on_body_chunk.rs":"824f83e4b44e5656fd48922addf02c010764cd73ec4d290213be2b990751f3ca","src/trace/on_eos.rs":"321f2afd63eef9a1be0bbe5e5bb450555bea984bc28381f92b31a17b6e466237","src/trace/on_failure.rs":"2aa316893e4c2df0ac0bfe8b597a9eaee8db79a243c42480be16fe2ebdf58f41","src/trace/on_request.rs":"9a88d6061c2f638d04dabf79317d014f7d47abb3c6e30730b687294ff135d646","src/trace/on_response.rs":"9b22781e2c2f1003ad5c4d0525ab26c037905660d769dd0717f4dcf359e7319a","src/trace/service.rs":"2b96171af5c11ad7d7e372afd8d86aed824de84ca64ae1cfdf832b8506646a66","src/validate_request.rs":"2310e2b00bd3c6fd3330e8eaed502197c6fcccdc833331555c73e115e3b63400"},"package":"adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2"}
\ No newline at end of file diff --git a/vendor/tower-http/CHANGELOG.md b/vendor/tower-http/CHANGELOG.md new file mode 100644 index 00000000..7f8c151f --- /dev/null +++ b/vendor/tower-http/CHANGELOG.md @@ -0,0 +1,501 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +# 0.6.6 + +## Fixed + +- compression: fix panic when looking in vary header ([#578]) + +[#578]: https://github.com/tower-rs/tower-http/pull/578 + +# 0.6.5 + +## Added + +- normalize_path: add `append_trailing_slash()` mode ([#547]) + +## Fixed + +- redirect: remove payload headers if redirect changes method to GET ([#575]) +- compression: avoid setting `vary: accept-encoding` if already set ([#572]) + +[#547]: https://github.com/tower-rs/tower-http/pull/547 +[#572]: https://github.com/tower-rs/tower-http/pull/572 +[#575]: https://github.com/tower-rs/tower-http/pull/575 + +# 0.6.4 + +## Added + +- decompression: Support HTTP responses containing multiple ZSTD frames ([#548]) +- The `ServiceExt` trait for chaining layers onto an arbitrary http service just + like `ServiceBuilderExt` allows for `ServiceBuilder` ([#563]) + +## Fixed + +- Remove unnecessary trait bounds on `S::Error` for `Service` impls of + `RequestBodyTimeout<S>` and `ResponseBodyTimeout<S>` ([#533]) +- compression: Respect `is_end_stream` ([#535]) +- Fix a rare panic in `fs::ServeDir` ([#553]) +- Fix invalid `content-lenght` of 1 in response to range requests to empty + files ([#556]) +- In `AsyncRequireAuthorization`, use the original inner service after it is + ready, instead of using a clone ([#561]) + +[#533]: https://github.com/tower-rs/tower-http/pull/533 +[#535]: https://github.com/tower-rs/tower-http/pull/535 +[#548]: https://github.com/tower-rs/tower-http/pull/548 +[#553]: https://github.com/tower-rs/tower-http/pull/556 +[#556]: https://github.com/tower-rs/tower-http/pull/556 +[#561]: https://github.com/tower-rs/tower-http/pull/561 +[#563]: https://github.com/tower-rs/tower-http/pull/563 + +# 0.6.3 + +*This release was yanked because its definition of `ServiceExt` was quite unhelpful, in a way that's very unlikely that anybody would start depending on within the small timeframe before this was yanked, but that was technically breaking to change.* + +# 0.6.2 + +## Changed: + +- `CompressionBody<B>` now propagates `B`'s size hint in its `http_body::Body` + implementation, if compression is disabled ([#531]) + - this allows a `content-length` to be included in an HTTP message with this + body for those cases + +[#531]: https://github.com/tower-rs/tower-http/pull/531 + +# 0.6.1 + +## Fixed + +- **decompression:** reuse scratch buffer to significantly reduce allocations and improve performance ([#521]) + +[#521]: https://github.com/tower-rs/tower-http/pull/521 + +# 0.6.0 + +## Changed: + +- `body` module is disabled except for `catch-panic`, `decompression-*`, `fs`, or `limit` features (BREAKING) ([#477]) +- Update to `tower` 0.5 ([#503]) + +## Fixed + +- **fs:** Precompression of static files now supports files without a file extension ([#507]) + +[#477]: https://github.com/tower-rs/tower-http/pull/477 +[#503]: https://github.com/tower-rs/tower-http/pull/503 +[#507]: https://github.com/tower-rs/tower-http/pull/507 + +# 0.5.2 + +## Added: + +- **compression:** Will now send a `vary: accept-encoding` header on compressed responses ([#399]) +- **compression:** Support `x-gzip` as equivalent to `gzip` in `accept-encoding` request header ([#467]) + +## Fixed + +- **compression:** Skip compression for range requests ([#446]) +- **compression:** Skip compression for SSE responses by default ([#465]) +- **cors:** *Actually* keep Vary headers set by the inner service when setting response headers ([#473]) + - Version 0.5.1 intended to ship this, but the implementation was buggy and didn't actually do anything + +[#399]: https://github.com/tower-rs/tower-http/pull/399 +[#446]: https://github.com/tower-rs/tower-http/pull/446 +[#465]: https://github.com/tower-rs/tower-http/pull/465 +[#467]: https://github.com/tower-rs/tower-http/pull/467 +[#473]: https://github.com/tower-rs/tower-http/pull/473 + +# 0.5.1 (January 14, 2024) + +## Added + +- **fs:** Support files precompressed with `zstd` in `ServeFile` +- **trace:** Add default generic parameters for `ResponseBody` and `ResponseFuture` ([#455]) +- **trace:** Add type aliases `HttpMakeClassifier` and `GrpcMakeClassifier` ([#455]) + +## Fixed + +- **cors:** Keep Vary headers set by the inner service when setting response headers ([#398]) +- **fs:** `ServeDir` now no longer redirects from `/directory` to `/directory/` + if `append_index_html_on_directories` is disabled ([#421]) + +[#398]: https://github.com/tower-rs/tower-http/pull/398 +[#421]: https://github.com/tower-rs/tower-http/pull/421 +[#455]: https://github.com/tower-rs/tower-http/pull/455 + +# 0.5.0 (November 21, 2023) + +## Changed + +- Bump Minimum Supported Rust Version to 1.66 ([#433]) +- Update to http-body 1.0 ([#348]) +- Update to http 1.0 ([#348]) +- Preserve service error type in RequestDecompression ([#368]) + +## Fixed + +- Accepts range headers with ranges where the end of range goes past the end of the document by bumping +http-range-header to `0.4` + +[#418]: https://github.com/tower-rs/tower-http/pull/418 +[#433]: https://github.com/tower-rs/tower-http/pull/433 +[#348]: https://github.com/tower-rs/tower-http/pull/348 +[#368]: https://github.com/tower-rs/tower-http/pull/368 + +# 0.4.2 (July 19, 2023) + +## Added + +- **cors:** Add support for private network preflights ([#373]) +- **compression:** Implement `Default` for `DecompressionBody` ([#370]) + +## Changed + +- **compression:** Update to async-compression 0.4 ([#371]) + +## Fixed + +- **compression:** Override default brotli compression level 11 -> 4 ([#356]) +- **trace:** Simplify dynamic tracing level application ([#380]) +- **normalize_path:** Fix path normalization for preceding slashes ([#359]) + +[#356]: https://github.com/tower-rs/tower-http/pull/356 +[#359]: https://github.com/tower-rs/tower-http/pull/359 +[#370]: https://github.com/tower-rs/tower-http/pull/370 +[#371]: https://github.com/tower-rs/tower-http/pull/371 +[#373]: https://github.com/tower-rs/tower-http/pull/373 +[#380]: https://github.com/tower-rs/tower-http/pull/380 + +# 0.4.1 (June 20, 2023) + +## Added + +- **request_id:** Derive `Default` for `MakeRequestUuid` ([#335]) +- **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336]) +- **compression:** Expose compression quality on the CompressionLayer ([#333]) + +## Fixed + +- **compression:** Improve parsing of `Accept-Encoding` request header ([#220]) +- **normalize_path:** Fix path normalization of index route ([#347]) +- **decompression:** Enable `multiple_members` for `GzipDecoder` ([#354]) + +[#347]: https://github.com/tower-rs/tower-http/pull/347 +[#333]: https://github.com/tower-rs/tower-http/pull/333 +[#220]: https://github.com/tower-rs/tower-http/pull/220 +[#335]: https://github.com/tower-rs/tower-http/pull/335 +[#336]: https://github.com/tower-rs/tower-http/pull/336 +[#354]: https://github.com/tower-rs/tower-http/pull/354 + +# 0.4.0 (February 24, 2023) + +## Added + +- **decompression:** Add `RequestDecompression` middleware ([#282]) +- **compression:** Implement `Default` for `CompressionBody` ([#323]) +- **compression, decompression:** Support zstd (de)compression ([#322]) + +## Changed + +- **serve_dir:** `ServeDir` and `ServeFile`'s error types are now `Infallible` and any IO errors + will be converted into responses. Use `try_call` to generate error responses manually (BREAKING) ([#283]) +- **serve_dir:** `ServeDir::fallback` and `ServeDir::not_found_service` now requires + the fallback service to use `Infallible` as its error type (BREAKING) ([#283]) +- **compression, decompression:** Tweak prefered compression encodings ([#325]) + +## Removed + +- Removed `RequireAuthorization` in favor of `ValidateRequest` (BREAKING) ([#290]) + +## Fixed + +- **serve_dir:** Don't include identity in Content-Encoding header ([#317]) +- **compression:** Do compress SVGs ([#321]) +- **serve_dir:** In `ServeDir`, convert `io::ErrorKind::NotADirectory` to `404 Not Found` ([#331]) + +[#282]: https://github.com/tower-rs/tower-http/pull/282 +[#283]: https://github.com/tower-rs/tower-http/pull/283 +[#290]: https://github.com/tower-rs/tower-http/pull/290 +[#317]: https://github.com/tower-rs/tower-http/pull/317 +[#321]: https://github.com/tower-rs/tower-http/pull/321 +[#322]: https://github.com/tower-rs/tower-http/pull/322 +[#323]: https://github.com/tower-rs/tower-http/pull/323 +[#325]: https://github.com/tower-rs/tower-http/pull/325 +[#331]: https://github.com/tower-rs/tower-http/pull/331 + +# 0.3.5 (December 02, 2022) + +## Added + +- Add `NormalizePath` middleware ([#275]) +- Add `ValidateRequest` middleware ([#289]) +- Add `RequestBodyTimeout` middleware ([#303]) + +## Changed + +- Bump Minimum Supported Rust Version to 1.60 ([#299]) + +## Fixed + +- **trace:** Correctly identify gRPC requests in default `on_response` callback ([#278]) +- **cors:** Panic if a wildcard (`*`) is passed to `AllowOrigin::list`. Use + `AllowOrigin::any()` instead ([#285]) +- **serve_dir:** Call the fallback on non-uft8 request paths ([#310]) + +[#275]: https://github.com/tower-rs/tower-http/pull/275 +[#278]: https://github.com/tower-rs/tower-http/pull/278 +[#285]: https://github.com/tower-rs/tower-http/pull/285 +[#289]: https://github.com/tower-rs/tower-http/pull/289 +[#299]: https://github.com/tower-rs/tower-http/pull/299 +[#303]: https://github.com/tower-rs/tower-http/pull/303 +[#310]: https://github.com/tower-rs/tower-http/pull/310 + +# 0.3.4 (June 06, 2022) + +## Added + +- Add `Timeout` middleware ([#270]) +- Add `RequestBodyLimit` middleware ([#271]) + +[#270]: https://github.com/tower-rs/tower-http/pull/270 +[#271]: https://github.com/tower-rs/tower-http/pull/271 + +# 0.3.3 (May 08, 2022) + +## Added + +- **serve_dir:** Add `ServeDir::call_fallback_on_method_not_allowed` to allow calling the fallback + for requests that aren't `GET` or `HEAD` ([#264]) +- **request_id:** Add `MakeRequestUuid` for generating request ids using UUIDs ([#266]) + +[#264]: https://github.com/tower-rs/tower-http/pull/264 +[#266]: https://github.com/tower-rs/tower-http/pull/266 + +## Fixed + +- **serve_dir:** Include `Allow` header for `405 Method Not Allowed` responses ([#263]) + +[#263]: https://github.com/tower-rs/tower-http/pull/263 + +# 0.3.2 (April 29, 2022) + +## Fixed + +- **serve_dir**: Fix empty request parts being passed to `ServeDir`'s fallback instead of the actual ones ([#258]) + +[#258]: https://github.com/tower-rs/tower-http/pull/258 + +# 0.3.1 (April 28, 2022) + +## Fixed + +- **cors**: Only send a single origin in `Access-Control-Allow-Origin` header when a list of + allowed origins is configured (the previous behavior of sending a comma-separated list like for + allowed methods and allowed headers is not allowed by any standard) + +# 0.3.0 (April 25, 2022) + +## Added + +- **fs**: Add `ServeDir::{fallback, not_found_service}` for calling another service if + the file cannot be found ([#243]) +- **fs**: Add `SetStatus` to override status codes ([#248]) +- `ServeDir` and `ServeFile` now respond with `405 Method Not Allowed` to requests where the + method isn't `GET` or `HEAD` ([#249]) +- **cors**: Added `CorsLayer::very_permissive` which is like + `CorsLayer::permissive` except it (truly) allows credentials. This is made + possible by mirroring the request's origin as well as method and headers + back as CORS-whitelisted ones ([#237]) +- **cors**: Allow customizing the value(s) for the `Vary` header ([#237]) + +## Changed + +- **cors**: Removed `allow-credentials: true` from `CorsLayer::permissive`. + It never actually took effect in compliant browsers because it is mutually + exclusive with the `*` wildcard (`Any`) on origins, methods and headers ([#237]) +- **cors**: Rewrote the CORS middleware. Almost all existing usage patterns + will continue to work. (BREAKING) ([#237]) +- **cors**: The CORS middleware will now panic if you try to use `Any` in + combination with `.allow_credentials(true)`. This configuration worked + before, but resulted in browsers ignoring the `allow-credentials` header, + which defeats the purpose of setting it and can be very annoying to debug + ([#237]) + +## Fixed + +- **fs**: Fix content-length calculation on range requests ([#228]) + +[#228]: https://github.com/tower-rs/tower-http/pull/228 +[#237]: https://github.com/tower-rs/tower-http/pull/237 +[#243]: https://github.com/tower-rs/tower-http/pull/243 +[#248]: https://github.com/tower-rs/tower-http/pull/248 +[#249]: https://github.com/tower-rs/tower-http/pull/249 + +# 0.2.4 (March 5, 2022) + +## Added + +- Added `CatchPanic` middleware which catches panics and converts them + into `500 Internal Server` responses ([#214]) + +## Fixed + +- Make parsing of `Accept-Encoding` more robust ([#220]) + +[#214]: https://github.com/tower-rs/tower-http/pull/214 +[#220]: https://github.com/tower-rs/tower-http/pull/220 + +# 0.2.3 (February 18, 2022) + +## Changed + +- Update to tokio-util 0.7 ([#221]) + +## Fixed + +- The CORS layer / service methods `allow_headers`, `allow_methods`, `allow_origin` + and `expose_headers` now do nothing if given an empty `Vec`, instead of sending + the respective header with an empty value ([#218]) + +[#218]: https://github.com/tower-rs/tower-http/pull/218 +[#221]: https://github.com/tower-rs/tower-http/pull/221 + +# 0.2.2 (February 8, 2022) + +## Fixed + +- Add `Vary` headers for CORS preflight responses ([#216]) + +[#216]: https://github.com/tower-rs/tower-http/pull/216 + +# 0.2.1 (January 21, 2022) + +## Added + +- Support `Last-Modified` (and friends) headers in `ServeDir` and `ServeFile` ([#145]) +- Add `AsyncRequireAuthorization::layer` ([#195]) + +## Fixed + +- Fix build error for certain feature sets ([#209]) +- `Cors`: Set `Vary` header ([#199]) +- `ServeDir` and `ServeFile`: Fix potential directory traversal attack due to + improper path validation on Windows ([#204]) + +[#145]: https://github.com/tower-rs/tower-http/pull/145 +[#195]: https://github.com/tower-rs/tower-http/pull/195 +[#199]: https://github.com/tower-rs/tower-http/pull/199 +[#204]: https://github.com/tower-rs/tower-http/pull/204 +[#209]: https://github.com/tower-rs/tower-http/pull/209 + +# 0.2.0 (December 1, 2021) + +## Added + +- **builder**: Add `ServiceBuilderExt` which adds methods to `tower::ServiceBuilder` for + adding middleware from tower-http ([#106]) +- **request_id**: Add `SetRequestId` and `PropagateRequestId` middleware ([#150]) +- **trace**: Add `DefaultMakeSpan::level` to make log level of tracing spans easily configurable ([#124]) +- **trace**: Add `LatencyUnit::Seconds` for formatting latencies as seconds ([#179]) +- **trace**: Support customizing which status codes are considered failures by `GrpcErrorsAsFailures` ([#189]) +- **compression**: Support specifying predicates to choose when responses should + be compressed. This can be used to disable compression of small responses, + responses with a certain `content-type`, or something user defined ([#172]) +- **fs**: Ability to serve precompressed files ([#156]) +- **fs**: Support `Range` requests ([#173]) +- **fs**: Properly support HEAD requests which return no body and have the `Content-Length` header set ([#169]) + +## Changed + +- `AddAuthorization`, `InFlightRequests`, `SetRequestHeader`, + `SetResponseHeader`, `AddExtension`, `MapRequestBody` and `MapResponseBody` + now requires underlying service to use `http::Request<ReqBody>` and + `http::Response<ResBody>` as request and responses ([#182]) (BREAKING) +- **set_header**: Remove unnecessary generic parameter from `SetRequestHeaderLayer` + and `SetResponseHeaderLayer`. This removes the need (and possibility) to specify a + body type for these layers ([#148]) (BREAKING) +- **compression, decompression**: Change the response body error type to + `Box<dyn std::error::Error + Send + Sync>`. This makes them usable if + the body they're wrapping uses `Box<dyn std::error::Error + Send + Sync>` as + its error type which they previously weren't ([#166]) (BREAKING) +- **fs**: Change response body type of `ServeDir` and `ServeFile` to + `ServeFileSystemResponseBody` and `ServeFileSystemResponseFuture` ([#187]) (BREAKING) +- **auth**: Change `AuthorizeRequest` and `AsyncAuthorizeRequest` traits to be simpler ([#192]) (BREAKING) + +## Removed + +- **compression, decompression**: Remove `BodyOrIoError`. Its been replaced with `Box<dyn + std::error::Error + Send + Sync>` ([#166]) (BREAKING) +- **compression, decompression**: Remove the `compression` and `decompression` feature. They were unnecessary + and `compression-full`/`decompression-full` can be used to get full + compression/decompression support. For more granular control, `[compression|decompression]-gzip`, + `[compression|decompression]-br` and `[compression|decompression]-deflate` may + be used instead ([#170]) (BREAKING) + +[#106]: https://github.com/tower-rs/tower-http/pull/106 +[#124]: https://github.com/tower-rs/tower-http/pull/124 +[#148]: https://github.com/tower-rs/tower-http/pull/148 +[#150]: https://github.com/tower-rs/tower-http/pull/150 +[#156]: https://github.com/tower-rs/tower-http/pull/156 +[#166]: https://github.com/tower-rs/tower-http/pull/166 +[#169]: https://github.com/tower-rs/tower-http/pull/169 +[#170]: https://github.com/tower-rs/tower-http/pull/170 +[#172]: https://github.com/tower-rs/tower-http/pull/172 +[#173]: https://github.com/tower-rs/tower-http/pull/173 +[#179]: https://github.com/tower-rs/tower-http/pull/179 +[#182]: https://github.com/tower-rs/tower-http/pull/182 +[#187]: https://github.com/tower-rs/tower-http/pull/187 +[#189]: https://github.com/tower-rs/tower-http/pull/189 +[#192]: https://github.com/tower-rs/tower-http/pull/192 + +# 0.1.2 (November 13, 2021) + +- New middleware: Add `Cors` for setting [CORS] headers ([#112]) +- New middleware: Add `AsyncRequireAuthorization` ([#118]) +- `Compression`: Don't recompress HTTP responses ([#140]) +- `Compression` and `Decompression`: Pass configuration from layer into middleware ([#132]) +- `ServeDir` and `ServeFile`: Improve performance ([#137]) +- `Compression`: Remove needless `ResBody::Error: Into<BoxError>` bounds ([#117]) +- `ServeDir`: Percent decode path segments ([#129]) +- `ServeDir`: Use correct redirection status ([#130]) +- `ServeDir`: Return `404 Not Found` on requests to directories if + `append_index_html_on_directories` is set to `false` ([#122]) + +[#112]: https://github.com/tower-rs/tower-http/pull/112 +[#118]: https://github.com/tower-rs/tower-http/pull/118 +[#140]: https://github.com/tower-rs/tower-http/pull/140 +[#132]: https://github.com/tower-rs/tower-http/pull/132 +[#137]: https://github.com/tower-rs/tower-http/pull/137 +[#117]: https://github.com/tower-rs/tower-http/pull/117 +[#129]: https://github.com/tower-rs/tower-http/pull/129 +[#130]: https://github.com/tower-rs/tower-http/pull/130 +[#122]: https://github.com/tower-rs/tower-http/pull/122 + +# 0.1.1 (July 2, 2021) + +- Add example of using `SharedClassifier`. +- Add `StatusInRangeAsFailures` which is a response classifier that considers + responses with status code in a certain range as failures. Useful for HTTP + clients where both server errors (5xx) and client errors (4xx) are considered + failures. +- Implement `Debug` for `NeverClassifyEos`. +- Update iri-string to 0.4. +- Add `ClassifyResponse::map_failure_class` and `ClassifyEos::map_failure_class` + for transforming the failure classification using a function. +- Clarify exactly when each `Trace` callback is called. +- Add `AddAuthorizationLayer` for setting the `Authorization` header on + requests. + +# 0.1.0 (May 27, 2021) + +- Initial release. + +[CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS diff --git a/vendor/tower-http/Cargo.lock b/vendor/tower-http/Cargo.lock new file mode 100644 index 00000000..576fa044 --- /dev/null +++ b/vendor/tower-http/Cargo.lock @@ -0,0 +1,985 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "async-compression" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +dependencies = [ + "brotli", + "flate2", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "zstd", + "zstd-safe", +] + +[[package]] +name = "async-trait" +version = "0.1.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "brotli" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bytes" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" + +[[package]] +name = "cc" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +dependencies = [ + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-macro", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "http-range-header" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08a397c49fec283e3d6211adbe480be95aae5f304cfb923e9970e08956d5168a" + +[[package]] +name = "httparse" +version = "1.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "iri-string" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0f0a572e8ffe56e2ff4f769f32ffe919282c3916799f8b68688b6030063bea" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + +[[package]] +name = "object" +version = "0.36.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pin-project-lite" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.133" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" + +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "tokio" +version = "1.41.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.6" +dependencies = [ + "async-compression", + "async-trait", + "base64", + "bitflags", + "brotli", + "bytes", + "flate2", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "http-range-header", + "httpdate", + "hyper-util", + "iri-string", + "mime", + "mime_guess", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde_json", + "sync_wrapper 1.0.1", + "tokio", + "tokio-util", + "tower", + "tower-layer", + "tower-service", + "tracing", + "tracing-subscriber", + "uuid", + "zstd", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicase" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "uuid" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +dependencies = [ + "getrandom", +] + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/vendor/tower-http/Cargo.toml b/vendor/tower-http/Cargo.toml new file mode 100644 index 00000000..bf3ad35b --- /dev/null +++ b/vendor/tower-http/Cargo.toml @@ -0,0 +1,390 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2018" +rust-version = "1.64" +name = "tower-http" +version = "0.6.6" +authors = ["Tower Maintainers <team@tower-rs.com>"] +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "Tower middleware and utilities for HTTP clients and servers" +homepage = "https://github.com/tower-rs/tower-http" +readme = "README.md" +keywords = [ + "io", + "async", + "futures", + "service", + "http", +] +categories = [ + "asynchronous", + "network-programming", + "web-programming", +] +license = "MIT" +repository = "https://github.com/tower-rs/tower-http" +resolver = "2" + +[package.metadata.cargo-public-api-crates] +allowed = [ + "bytes", + "http", + "http_body", + "mime", + "pin-project-lite", + "tokio", + "tower", + "tower_layer", + "tower_service", + "tracing", + "tracing_core", +] + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = [ + "--cfg", + "docsrs", +] + +[package.metadata.playground] +features = ["full"] + +[features] +add-extension = [] +auth = [ + "base64", + "validate-request", +] +catch-panic = [ + "tracing", + "futures-util/std", + "dep:http-body", + "dep:http-body-util", +] +compression-br = [ + "async-compression/brotli", + "futures-core", + "dep:http-body", + "tokio-util", + "tokio", +] +compression-deflate = [ + "async-compression/zlib", + "futures-core", + "dep:http-body", + "tokio-util", + "tokio", +] +compression-full = [ + "compression-br", + "compression-deflate", + "compression-gzip", + "compression-zstd", +] +compression-gzip = [ + "async-compression/gzip", + "futures-core", + "dep:http-body", + "tokio-util", + "tokio", +] +compression-zstd = [ + "async-compression/zstd", + "futures-core", + "dep:http-body", + "tokio-util", + "tokio", +] +cors = [] +decompression-br = [ + "async-compression/brotli", + "futures-core", + "dep:http-body", + "dep:http-body-util", + "tokio-util", + "tokio", +] +decompression-deflate = [ + "async-compression/zlib", + "futures-core", + "dep:http-body", + "dep:http-body-util", + "tokio-util", + "tokio", +] +decompression-full = [ + "decompression-br", + "decompression-deflate", + "decompression-gzip", + "decompression-zstd", +] +decompression-gzip = [ + "async-compression/gzip", + "futures-core", + "dep:http-body", + "dep:http-body-util", + "tokio-util", + "tokio", +] +decompression-zstd = [ + "async-compression/zstd", + "futures-core", + "dep:http-body", + "dep:http-body-util", + "tokio-util", + "tokio", +] +default = [] +follow-redirect = [ + "futures-util", + "dep:http-body", + "iri-string", + "tower/util", +] +fs = [ + "futures-core", + "futures-util", + "dep:http-body", + "dep:http-body-util", + "tokio/fs", + "tokio-util/io", + "tokio/io-util", + "dep:http-range-header", + "mime_guess", + "mime", + "percent-encoding", + "httpdate", + "set-status", + "futures-util/alloc", + "tracing", +] +full = [ + "add-extension", + "auth", + "catch-panic", + "compression-full", + "cors", + "decompression-full", + "follow-redirect", + "fs", + "limit", + "map-request-body", + "map-response-body", + "metrics", + "normalize-path", + "propagate-header", + "redirect", + "request-id", + "sensitive-headers", + "set-header", + "set-status", + "timeout", + "trace", + "util", + "validate-request", +] +limit = [ + "dep:http-body", + "dep:http-body-util", +] +map-request-body = [] +map-response-body = [] +metrics = [ + "dep:http-body", + "tokio/time", +] +normalize-path = [] +propagate-header = [] +redirect = [] +request-id = ["uuid"] +sensitive-headers = [] +set-header = [] +set-status = [] +timeout = [ + "dep:http-body", + "tokio/time", +] +trace = [ + "dep:http-body", + "tracing", +] +util = ["tower"] +validate-request = ["mime"] + +[lib] +name = "tower_http" +path = "src/lib.rs" + +[dependencies.async-compression] +version = "0.4" +features = ["tokio"] +optional = true + +[dependencies.base64] +version = "0.22" +optional = true + +[dependencies.bitflags] +version = "2.0.2" + +[dependencies.bytes] +version = "1" + +[dependencies.futures-core] +version = "0.3" +optional = true +default-features = false + +[dependencies.futures-util] +version = "0.3.14" +optional = true +default-features = false + +[dependencies.http] +version = "1.0" + +[dependencies.http-body] +version = "1.0.0" +optional = true + +[dependencies.http-body-util] +version = "0.1.0" +optional = true + +[dependencies.http-range-header] +version = "0.4.0" +optional = true + +[dependencies.httpdate] +version = "1.0" +optional = true + +[dependencies.iri-string] +version = "0.7.0" +optional = true + +[dependencies.mime] +version = "0.3.17" +optional = true +default-features = false + +[dependencies.mime_guess] +version = "2" +optional = true +default-features = false + +[dependencies.percent-encoding] +version = "2.1.0" +optional = true + +[dependencies.pin-project-lite] +version = "0.2.7" + +[dependencies.tokio] +version = "1.6" +optional = true +default-features = false + +[dependencies.tokio-util] +version = "0.7" +features = ["io"] +optional = true +default-features = false + +[dependencies.tower] +version = "0.5" +optional = true + +[dependencies.tower-layer] +version = "0.3.3" + +[dependencies.tower-service] +version = "0.3" + +[dependencies.tracing] +version = "0.1" +optional = true +default-features = false + +[dependencies.uuid] +version = "1.0" +features = ["v4"] +optional = true + +[dev-dependencies.async-trait] +version = "0.1" + +[dev-dependencies.brotli] +version = "7" + +[dev-dependencies.bytes] +version = "1" + +[dev-dependencies.flate2] +version = "1.0" + +[dev-dependencies.futures-util] +version = "0.3.14" + +[dev-dependencies.http-body] +version = "1.0.0" + +[dev-dependencies.http-body-util] +version = "0.1.0" + +[dev-dependencies.hyper-util] +version = "0.1" +features = [ + "client-legacy", + "http1", + "tokio", +] + +[dev-dependencies.once_cell] +version = "1" + +[dev-dependencies.serde_json] +version = "1.0" + +[dev-dependencies.sync_wrapper] +version = "1" + +[dev-dependencies.tokio] +version = "1" +features = ["full"] + +[dev-dependencies.tower] +version = "0.5" +features = [ + "buffer", + "util", + "retry", + "make", + "timeout", +] + +[dev-dependencies.tracing-subscriber] +version = "0.3" + +[dev-dependencies.uuid] +version = "1.0" +features = ["v4"] + +[dev-dependencies.zstd] +version = "0.13" diff --git a/vendor/tower-http/LICENSE b/vendor/tower-http/LICENSE new file mode 100644 index 00000000..352c2cfa --- /dev/null +++ b/vendor/tower-http/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019-2021 Tower Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/vendor/tower-http/README.md b/vendor/tower-http/README.md new file mode 100644 index 00000000..4df24186 --- /dev/null +++ b/vendor/tower-http/README.md @@ -0,0 +1,79 @@ +# Tower HTTP + +Tower middleware and utilities for HTTP clients and servers. + +[](https://github.com/tower-rs/tower-http/actions) +[](https://crates.io/crates/tower-http) +[](https://docs.rs/tower-http) +[](tower-http/LICENSE) + +More information about this crate can be found in the [crate documentation][docs]. + +## Middleware + +Tower HTTP contains lots of middleware that are generally useful when building +HTTP servers and clients. Some of the highlights are: + +- `Trace` adds high level logging of requests and responses. Supports both + regular HTTP requests as well as gRPC. +- `Compression` and `Decompression` to compress/decompress response bodies. +- `FollowRedirect` to automatically follow redirection responses. + +See the [docs] for the complete list of middleware. + +Middleware uses the [http] crate as the HTTP interface so they're compatible +with any library or framework that also uses [http]. For example [hyper]. + +The middleware were originally extracted from one of [@EmbarkStudios] internal +projects. + +## Examples + +The [examples] folder contains various examples of how to use Tower HTTP: + +- [warp-key-value-store]: A key/value store with an HTTP API built with warp. +- [tonic-key-value-store]: A key/value store with a gRPC API and client built with tonic. +- [axum-key-value-store]: A key/value store with an HTTP API built with axum. + +## Minimum supported Rust version + +tower-http's MSRV is 1.66. + +## Getting Help + +If you're new to tower its [guides] might help. In the tower-http repo we also +have a [number of examples][examples] showing how to put everything together. +You're also welcome to ask in the [`#tower` Discord channel][chat] or open an +[issue] with your question. + +## Contributing + +:balloon: Thanks for your help improving the project! We are so happy to have +you! We have a [contributing guide][guide] to help you get involved in the Tower +HTTP project. + +[guide]: CONTRIBUTING.md + +## License + +This project is licensed under the [MIT license](tower-http/LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tower HTTP by you, shall be licensed as MIT, without any +additional terms or conditions. + +[@EmbarkStudios]: https://github.com/EmbarkStudios +[examples]: https://github.com/tower-rs/tower-http/tree/master/examples +[http]: https://crates.io/crates/http +[tonic-key-value-store]: https://github.com/tower-rs/tower-http/tree/master/examples/tonic-key-value-store +[warp-key-value-store]: https://github.com/tower-rs/tower-http/tree/master/examples/warp-key-value-store +[axum-key-value-store]: https://github.com/tower-rs/tower-http/tree/master/examples/axum-key-value-store +[chat]: https://discord.gg/tokio +[docs]: https://docs.rs/tower-http +[hyper]: https://github.com/hyperium/hyper +[issue]: https://github.com/tower-rs/tower-http/issues/new +[milestone]: https://github.com/tower-rs/tower-http/milestones +[examples]: https://github.com/tower-rs/tower-http/tree/master/examples +[guides]: https://github.com/tower-rs/tower/tree/master/guides diff --git a/vendor/tower-http/src/add_extension.rs b/vendor/tower-http/src/add_extension.rs new file mode 100644 index 00000000..095646df --- /dev/null +++ b/vendor/tower-http/src/add_extension.rs @@ -0,0 +1,167 @@ +//! Middleware that clones a value into each request's [extensions]. +//! +//! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html +//! +//! # Example +//! +//! ``` +//! use tower_http::add_extension::AddExtensionLayer; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use std::{sync::Arc, convert::Infallible}; +//! +//! # struct DatabaseConnectionPool; +//! # impl DatabaseConnectionPool { +//! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool } +//! # } +//! # +//! // Shared state across all request handlers --- in this case, a pool of database connections. +//! struct State { +//! pool: DatabaseConnectionPool, +//! } +//! +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // Grab the state from the request extensions. +//! let state = req.extensions().get::<Arc<State>>().unwrap(); +//! +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // Construct the shared state. +//! let state = State { +//! pool: DatabaseConnectionPool::new(), +//! }; +//! +//! let mut service = ServiceBuilder::new() +//! // Share an `Arc<State>` with all requests. +//! .layer(AddExtensionLayer::new(Arc::new(state))) +//! .service_fn(handle); +//! +//! // Call the service. +//! let response = service +//! .ready() +//! .await? +//! .call(Request::new(Full::default())) +//! .await?; +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response}; +use std::task::{Context, Poll}; +use tower_layer::Layer; +use tower_service::Service; + +/// [`Layer`] for adding some shareable value to [request extensions]. +/// +/// See the [module docs](crate::add_extension) for more details. +/// +/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html +#[derive(Clone, Copy, Debug)] +pub struct AddExtensionLayer<T> { + value: T, +} + +impl<T> AddExtensionLayer<T> { + /// Create a new [`AddExtensionLayer`]. + pub fn new(value: T) -> Self { + AddExtensionLayer { value } + } +} + +impl<S, T> Layer<S> for AddExtensionLayer<T> +where + T: Clone, +{ + type Service = AddExtension<S, T>; + + fn layer(&self, inner: S) -> Self::Service { + AddExtension { + inner, + value: self.value.clone(), + } + } +} + +/// Middleware for adding some shareable value to [request extensions]. +/// +/// See the [module docs](crate::add_extension) for more details. +/// +/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html +#[derive(Clone, Copy, Debug)] +pub struct AddExtension<S, T> { + inner: S, + value: T, +} + +impl<S, T> AddExtension<S, T> { + /// Create a new [`AddExtension`]. + pub fn new(inner: S, value: T) -> Self { + Self { inner, value } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `AddExtension` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(value: T) -> AddExtensionLayer<T> { + AddExtensionLayer::new(value) + } +} + +impl<ResBody, ReqBody, S, T> Service<Request<ReqBody>> for AddExtension<S, T> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + T: Clone + Send + Sync + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + req.extensions_mut().insert(self.value.clone()); + self.inner.call(req) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use crate::test_helpers::Body; + use http::Response; + use std::{convert::Infallible, sync::Arc}; + use tower::{service_fn, ServiceBuilder, ServiceExt}; + + struct State(i32); + + #[tokio::test] + async fn basic() { + let state = Arc::new(State(1)); + + let svc = ServiceBuilder::new() + .layer(AddExtensionLayer::new(state)) + .service(service_fn(|req: Request<Body>| async move { + let state = req.extensions().get::<Arc<State>>().unwrap(); + Ok::<_, Infallible>(Response::new(state.0)) + })); + + let res = svc + .oneshot(Request::new(Body::empty())) + .await + .unwrap() + .into_body(); + + assert_eq!(1, res); + } +} diff --git a/vendor/tower-http/src/auth/add_authorization.rs b/vendor/tower-http/src/auth/add_authorization.rs new file mode 100644 index 00000000..246c13b6 --- /dev/null +++ b/vendor/tower-http/src/auth/add_authorization.rs @@ -0,0 +1,267 @@ +//! Add authorization to requests using the [`Authorization`] header. +//! +//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +//! +//! # Example +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeader, ValidateRequestHeaderLayer}; +//! use tower_http::auth::AddAuthorizationLayer; +//! use http::{Request, Response, StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! # Ok(Response::new(Full::default())) +//! # } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! # let service_that_requires_auth = ValidateRequestHeader::basic( +//! # tower::service_fn(handle), +//! # "username", +//! # "password", +//! # ); +//! let mut client = ServiceBuilder::new() +//! // Use basic auth with the given username and password +//! .layer(AddAuthorizationLayer::basic("username", "password")) +//! .service(service_that_requires_auth); +//! +//! // Make a request, we don't have to add the `Authorization` header manually +//! let response = client +//! .ready() +//! .await? +//! .call(Request::new(Full::default())) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! # Ok(()) +//! # } +//! ``` + +use base64::Engine as _; +use http::{HeaderValue, Request, Response}; +use std::{ + convert::TryFrom, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; + +/// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the +/// [`Authorization`] header. +/// +/// See the [module docs](crate::auth::add_authorization) for an example. +/// +/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this +/// middleware. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +/// [`SetRequestHeader`]: crate::set_header::SetRequestHeader +#[derive(Debug, Clone)] +pub struct AddAuthorizationLayer { + value: HeaderValue, +} + +impl AddAuthorizationLayer { + /// Authorize requests using a username and password pair. + /// + /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + pub fn basic(username: &str, password: &str) -> Self { + let encoded = BASE64.encode(format!("{}:{}", username, password)); + let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap(); + Self { value } + } + + /// Authorize requests using a "bearer token". Commonly used for OAuth 2. + /// + /// The `Authorization` header will be set to `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`]. + pub fn bearer(token: &str) -> Self { + let value = + HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header"); + Self { value } + } + + /// Mark the header as [sensitive]. + /// + /// This can for example be used to hide the header value from logs. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + #[allow(clippy::wrong_self_convention)] + pub fn as_sensitive(mut self, sensitive: bool) -> Self { + self.value.set_sensitive(sensitive); + self + } +} + +impl<S> Layer<S> for AddAuthorizationLayer { + type Service = AddAuthorization<S>; + + fn layer(&self, inner: S) -> Self::Service { + AddAuthorization { + inner, + value: self.value.clone(), + } + } +} + +/// Middleware that adds authorization all requests using the [`Authorization`] header. +/// +/// See the [module docs](crate::auth::add_authorization) for an example. +/// +/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this +/// middleware. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +/// [`SetRequestHeader`]: crate::set_header::SetRequestHeader +#[derive(Debug, Clone)] +pub struct AddAuthorization<S> { + inner: S, + value: HeaderValue, +} + +impl<S> AddAuthorization<S> { + /// Authorize requests using a username and password pair. + /// + /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + pub fn basic(inner: S, username: &str, password: &str) -> Self { + AddAuthorizationLayer::basic(username, password).layer(inner) + } + + /// Authorize requests using a "bearer token". Commonly used for OAuth 2. + /// + /// The `Authorization` header will be set to `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`]. + pub fn bearer(inner: S, token: &str) -> Self { + AddAuthorizationLayer::bearer(token).layer(inner) + } + + define_inner_service_accessors!(); + + /// Mark the header as [sensitive]. + /// + /// This can for example be used to hide the header value from logs. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + #[allow(clippy::wrong_self_convention)] + pub fn as_sensitive(mut self, sensitive: bool) -> Self { + self.value.set_sensitive(sensitive); + self + } +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AddAuthorization<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + req.headers_mut() + .insert(http::header::AUTHORIZATION, self.value.clone()); + self.inner.call(req) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::Body; + use crate::validate_request::ValidateRequestHeaderLayer; + use http::{Response, StatusCode}; + use std::convert::Infallible; + use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn basic() { + // service that requires auth for all requests + let svc = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) + .service_fn(echo); + + // make a client that adds auth + let mut client = AddAuthorization::basic(svc, "foo", "bar"); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn token() { + // service that requires auth for all requests + let svc = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::bearer("foo")) + .service_fn(echo); + + // make a client that adds auth + let mut client = AddAuthorization::bearer(svc, "foo"); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn making_header_sensitive() { + let svc = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::bearer("foo")) + .service_fn(|request: Request<Body>| async move { + let auth = request.headers().get(http::header::AUTHORIZATION).unwrap(); + assert!(auth.is_sensitive()); + + Ok::<_, Infallible>(Response::new(Body::empty())) + }); + + let mut client = AddAuthorization::bearer(svc, "foo").as_sensitive(true); + + let res = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/vendor/tower-http/src/auth/async_require_authorization.rs b/vendor/tower-http/src/auth/async_require_authorization.rs new file mode 100644 index 00000000..fda9abea --- /dev/null +++ b/vendor/tower-http/src/auth/async_require_authorization.rs @@ -0,0 +1,385 @@ +//! Authorize requests using the [`Authorization`] header asynchronously. +//! +//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +//! +//! # Example +//! +//! ``` +//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; +//! use http::{Request, Response, StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use futures_core::future::BoxFuture; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! +//! #[derive(Clone, Copy)] +//! struct MyAuth; +//! +//! impl<B> AsyncAuthorizeRequest<B> for MyAuth +//! where +//! B: Send + Sync + 'static, +//! { +//! type RequestBody = B; +//! type ResponseBody = Full<Bytes>; +//! type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>; +//! +//! fn authorize(&mut self, mut request: Request<B>) -> Self::Future { +//! Box::pin(async { +//! if let Some(user_id) = check_auth(&request).await { +//! // Set `user_id` as a request extension so it can be accessed by other +//! // services down the stack. +//! request.extensions_mut().insert(user_id); +//! +//! Ok(request) +//! } else { +//! let unauthorized_response = Response::builder() +//! .status(StatusCode::UNAUTHORIZED) +//! .body(Full::<Bytes>::default()) +//! .unwrap(); +//! +//! Err(unauthorized_response) +//! } +//! }) +//! } +//! } +//! +//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> { +//! // ... +//! # None +//! } +//! +//! #[derive(Debug, Clone)] +//! struct UserId(String); +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the +//! // request was authorized and `UserId` will be present. +//! let user_id = request +//! .extensions() +//! .get::<UserId>() +//! .expect("UserId will be there if request was authorized"); +//! +//! println!("request from {:?}", user_id); +//! +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let service = ServiceBuilder::new() +//! // Authorize requests using `MyAuth` +//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! Or using a closure: +//! +//! ``` +//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; +//! use http::{Request, Response, StatusCode}; +//! use tower::{Service, ServiceExt, ServiceBuilder, BoxError}; +//! use futures_core::future::BoxFuture; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! +//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> { +//! // ... +//! # None +//! } +//! +//! #[derive(Debug)] +//! struct UserId(String); +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! # todo!(); +//! // ... +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let service = ServiceBuilder::new() +//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move { +//! if let Some(user_id) = check_auth(&request).await { +//! Ok(request) +//! } else { +//! let unauthorized_response = Response::builder() +//! .status(StatusCode::UNAUTHORIZED) +//! .body(Full::<Bytes>::default()) +//! .unwrap(); +//! +//! Err(unauthorized_response) +//! } +//! })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + mem, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the +/// [`Authorization`] header. +/// +/// See the [module docs](crate::auth::async_require_authorization) for an example. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +#[derive(Debug, Clone)] +pub struct AsyncRequireAuthorizationLayer<T> { + auth: T, +} + +impl<T> AsyncRequireAuthorizationLayer<T> { + /// Authorize requests using a custom scheme. + pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> { + Self { auth } + } +} + +impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T> +where + T: Clone, +{ + type Service = AsyncRequireAuthorization<S, T>; + + fn layer(&self, inner: S) -> Self::Service { + AsyncRequireAuthorization::new(inner, self.auth.clone()) + } +} + +/// Middleware that authorizes all requests using the [`Authorization`] header. +/// +/// See the [module docs](crate::auth::async_require_authorization) for an example. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +#[derive(Clone, Debug)] +pub struct AsyncRequireAuthorization<S, T> { + inner: S, + auth: T, +} + +impl<S, T> AsyncRequireAuthorization<S, T> { + define_inner_service_accessors!(); +} + +impl<S, T> AsyncRequireAuthorization<S, T> { + /// Authorize requests using a custom scheme. + /// + /// The `Authorization` header is required to have the value provided. + pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> { + Self { inner, auth } + } + + /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`] + /// middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> { + AsyncRequireAuthorizationLayer::new(auth) + } +} + +impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth> +where + Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>, + S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone, +{ + type Response = Response<ResBody>; + type Error = S::Error; + type Future = ResponseFuture<Auth, S, ReqBody>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let mut inner = self.inner.clone(); + let authorize = self.auth.authorize(req); + // mem::swap due to https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + mem::swap(&mut self.inner, &mut inner); + + ResponseFuture { + state: State::Authorize { authorize }, + service: inner, + } + } +} + +pin_project! { + /// Response future for [`AsyncRequireAuthorization`]. + pub struct ResponseFuture<Auth, S, ReqBody> + where + Auth: AsyncAuthorizeRequest<ReqBody>, + S: Service<Request<Auth::RequestBody>>, + { + #[pin] + state: State<Auth::Future, S::Future>, + service: S, + } +} + +pin_project! { + #[project = StateProj] + enum State<A, SFut> { + Authorize { + #[pin] + authorize: A, + }, + Authorized { + #[pin] + fut: SFut, + }, + } +} + +impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody> +where + Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = B>, + S: Service<Request<Auth::RequestBody>, Response = Response<B>>, +{ + type Output = Result<Response<B>, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + StateProj::Authorize { authorize } => { + let auth = ready!(authorize.poll(cx)); + match auth { + Ok(req) => { + let fut = this.service.call(req); + this.state.set(State::Authorized { fut }) + } + Err(res) => { + return Poll::Ready(Ok(res)); + } + }; + } + StateProj::Authorized { fut } => { + return fut.poll(cx); + } + } + } + } +} + +/// Trait for authorizing requests. +pub trait AsyncAuthorizeRequest<B> { + /// The type of request body returned by `authorize`. + /// + /// Set this to `B` unless you need to change the request body type. + type RequestBody; + + /// The body type used for responses to unauthorized requests. + type ResponseBody; + + /// The Future type returned by `authorize` + type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>; + + /// Authorize the request. + /// + /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not. + fn authorize(&mut self, request: Request<B>) -> Self::Future; +} + +impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F +where + F: FnMut(Request<B>) -> Fut, + Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>, +{ + type RequestBody = ReqBody; + type ResponseBody = ResBody; + type Future = Fut; + + fn authorize(&mut self, request: Request<B>) -> Self::Future { + self(request) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use crate::test_helpers::Body; + use futures_core::future::BoxFuture; + use http::{header, StatusCode}; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[derive(Clone, Copy)] + struct MyAuth; + + impl<B> AsyncAuthorizeRequest<B> for MyAuth + where + B: Send + 'static, + { + type RequestBody = B; + type ResponseBody = Body; + type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>; + + fn authorize(&mut self, request: Request<B>) -> Self::Future { + Box::pin(async move { + let authorized = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|auth| auth.to_str().ok()?.strip_prefix("Bearer ")) + == Some("69420"); + + if authorized { + Ok(request) + } else { + Err(Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap()) + } + }) + } + } + + #[tokio::test] + async fn require_async_auth_works() { + let mut service = ServiceBuilder::new() + .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer 69420") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn require_async_auth_401() { + let mut service = ServiceBuilder::new() + .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer deez") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/vendor/tower-http/src/auth/mod.rs b/vendor/tower-http/src/auth/mod.rs new file mode 100644 index 00000000..fc8c2308 --- /dev/null +++ b/vendor/tower-http/src/auth/mod.rs @@ -0,0 +1,13 @@ +//! Authorization related middleware. + +pub mod add_authorization; +pub mod async_require_authorization; +pub mod require_authorization; + +#[doc(inline)] +pub use self::{ + add_authorization::{AddAuthorization, AddAuthorizationLayer}, + async_require_authorization::{ + AsyncAuthorizeRequest, AsyncRequireAuthorization, AsyncRequireAuthorizationLayer, + }, +}; diff --git a/vendor/tower-http/src/auth/require_authorization.rs b/vendor/tower-http/src/auth/require_authorization.rs new file mode 100644 index 00000000..7aa1a87f --- /dev/null +++ b/vendor/tower-http/src/auth/require_authorization.rs @@ -0,0 +1,404 @@ +//! Authorize requests using [`ValidateRequest`]. +//! +//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +//! +//! # Example +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer}; +//! use http::{Request, Response, StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let mut service = ServiceBuilder::new() +//! // Require the `Authorization` header to be `Bearer passwordlol` +//! .layer(ValidateRequestHeaderLayer::bearer("passwordlol")) +//! .service_fn(handle); +//! +//! // Requests with the correct token are allowed through +//! let request = Request::builder() +//! .header(AUTHORIZATION, "Bearer passwordlol") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid token get a `401 Unauthorized` response +//! let request = Request::builder() +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::UNAUTHORIZED, response.status()); +//! # Ok(()) +//! # } +//! ``` +//! +//! Custom validation can be made by implementing [`ValidateRequest`]. + +use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer}; +use base64::Engine as _; +use http::{ + header::{self, HeaderValue}, + Request, Response, StatusCode, +}; +use std::{fmt, marker::PhantomData}; + +const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; + +impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> { + /// Authorize requests using a username and password pair. + /// + /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + pub fn basic(inner: S, username: &str, value: &str) -> Self + where + ResBody: Default, + { + Self::custom(inner, Basic::new(username, value)) + } +} + +impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> { + /// Authorize requests using a username and password pair. + /// + /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is + /// `base64_encode("{username}:{password}")`. + /// + /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS + /// with this method. However use of HTTPS/TLS is not enforced by this middleware. + pub fn basic(username: &str, password: &str) -> Self + where + ResBody: Default, + { + Self::custom(Basic::new(username, password)) + } +} + +impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> { + /// Authorize requests using a "bearer token". Commonly used for OAuth 2. + /// + /// The `Authorization` header is required to be `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`]. + pub fn bearer(inner: S, token: &str) -> Self + where + ResBody: Default, + { + Self::custom(inner, Bearer::new(token)) + } +} + +impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> { + /// Authorize requests using a "bearer token". Commonly used for OAuth 2. + /// + /// The `Authorization` header is required to be `Bearer {token}`. + /// + /// # Panics + /// + /// Panics if the token is not a valid [`HeaderValue`]. + pub fn bearer(token: &str) -> Self + where + ResBody: Default, + { + Self::custom(Bearer::new(token)) + } +} + +/// Type that performs "bearer token" authorization. +/// +/// See [`ValidateRequestHeader::bearer`] for more details. +pub struct Bearer<ResBody> { + header_value: HeaderValue, + _ty: PhantomData<fn() -> ResBody>, +} + +impl<ResBody> Bearer<ResBody> { + fn new(token: &str) -> Self + where + ResBody: Default, + { + Self { + header_value: format!("Bearer {}", token) + .parse() + .expect("token is not a valid header value"), + _ty: PhantomData, + } + } +} + +impl<ResBody> Clone for Bearer<ResBody> { + fn clone(&self) -> Self { + Self { + header_value: self.header_value.clone(), + _ty: PhantomData, + } + } +} + +impl<ResBody> fmt::Debug for Bearer<ResBody> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Bearer") + .field("header_value", &self.header_value) + .finish() + } +} + +impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody> +where + ResBody: Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> { + match request.headers().get(header::AUTHORIZATION) { + Some(actual) if actual == self.header_value => Ok(()), + _ => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + Err(res) + } + } + } +} + +/// Type that performs basic authorization. +/// +/// See [`ValidateRequestHeader::basic`] for more details. +pub struct Basic<ResBody> { + header_value: HeaderValue, + _ty: PhantomData<fn() -> ResBody>, +} + +impl<ResBody> Basic<ResBody> { + fn new(username: &str, password: &str) -> Self + where + ResBody: Default, + { + let encoded = BASE64.encode(format!("{}:{}", username, password)); + let header_value = format!("Basic {}", encoded).parse().unwrap(); + Self { + header_value, + _ty: PhantomData, + } + } +} + +impl<ResBody> Clone for Basic<ResBody> { + fn clone(&self) -> Self { + Self { + header_value: self.header_value.clone(), + _ty: PhantomData, + } + } +} + +impl<ResBody> fmt::Debug for Basic<ResBody> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Basic") + .field("header_value", &self.header_value) + .finish() + } +} + +impl<B, ResBody> ValidateRequest<B> for Basic<ResBody> +where + ResBody: Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> { + match request.headers().get(header::AUTHORIZATION) { + Some(actual) if actual == self.header_value => Ok(()), + _ => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + res.headers_mut() + .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap()); + Err(res) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::validate_request::ValidateRequestHeaderLayer; + + #[allow(unused_imports)] + use super::*; + use crate::test_helpers::Body; + use http::header; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower_service::Service; + + #[tokio::test] + async fn valid_basic_token() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::AUTHORIZATION, + format!("Basic {}", BASE64.encode("foo:bar")), + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn invalid_basic_token() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::AUTHORIZATION, + format!("Basic {}", BASE64.encode("wrong:credentials")), + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap(); + assert_eq!(www_authenticate, "Basic"); + } + + #[tokio::test] + async fn valid_bearer_token() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::bearer("foobar")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer foobar") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn basic_auth_is_case_sensitive_in_prefix() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::AUTHORIZATION, + format!("basic {}", BASE64.encode("foo:bar")), + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn basic_auth_is_case_sensitive_in_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::AUTHORIZATION, + format!("Basic {}", BASE64.encode("Foo:bar")), + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn invalid_bearer_token() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::bearer("foobar")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer wat") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn bearer_token_is_case_sensitive_in_prefix() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::bearer("foobar")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "bearer foobar") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn bearer_token_is_case_sensitive_in_token() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::bearer("foobar")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer Foobar") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/vendor/tower-http/src/body.rs b/vendor/tower-http/src/body.rs new file mode 100644 index 00000000..815a0d10 --- /dev/null +++ b/vendor/tower-http/src/body.rs @@ -0,0 +1,121 @@ +//! Body types. +//! +//! All these are wrappers around other body types. You shouldn't have to use them in your code. +//! Use `http-body-util` instead. +//! +//! They exist because we don't want to expose types from `http-body-util` in `tower-http`s public +//! API. + +#![allow(missing_docs)] + +use std::convert::Infallible; + +use bytes::{Buf, Bytes}; +use http_body::Body; +use pin_project_lite::pin_project; + +use crate::BoxError; + +macro_rules! body_methods { + () => { + #[inline] + fn poll_frame( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + self.project().inner.poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + Body::is_end_stream(&self.inner) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + Body::size_hint(&self.inner) + } + }; +} + +pin_project! { + #[derive(Default)] + pub struct Full { + #[pin] + pub(crate) inner: http_body_util::Full<Bytes> + } +} + +impl Full { + #[allow(dead_code)] + pub(crate) fn new(inner: http_body_util::Full<Bytes>) -> Self { + Self { inner } + } +} + +impl Body for Full { + type Data = Bytes; + type Error = Infallible; + + body_methods!(); +} + +pin_project! { + pub struct Limited<B> { + #[pin] + pub(crate) inner: http_body_util::Limited<B> + } +} + +impl<B> Limited<B> { + #[allow(dead_code)] + pub(crate) fn new(inner: http_body_util::Limited<B>) -> Self { + Self { inner } + } +} + +impl<B> Body for Limited<B> +where + B: Body, + B::Error: Into<BoxError>, +{ + type Data = B::Data; + type Error = BoxError; + + body_methods!(); +} + +pin_project! { + pub struct UnsyncBoxBody<D, E> { + #[pin] + pub(crate) inner: http_body_util::combinators::UnsyncBoxBody<D, E> + } +} + +impl<D, E> Default for UnsyncBoxBody<D, E> +where + D: Buf + 'static, +{ + fn default() -> Self { + Self { + inner: Default::default(), + } + } +} + +impl<D, E> UnsyncBoxBody<D, E> { + #[allow(dead_code)] + pub(crate) fn new(inner: http_body_util::combinators::UnsyncBoxBody<D, E>) -> Self { + Self { inner } + } +} + +impl<D, E> Body for UnsyncBoxBody<D, E> +where + D: Buf, +{ + type Data = D; + type Error = E; + + body_methods!(); +} diff --git a/vendor/tower-http/src/builder.rs b/vendor/tower-http/src/builder.rs new file mode 100644 index 00000000..3bdcf64a --- /dev/null +++ b/vendor/tower-http/src/builder.rs @@ -0,0 +1,616 @@ +use tower::ServiceBuilder; + +#[allow(unused_imports)] +use http::header::HeaderName; +#[allow(unused_imports)] +use tower_layer::Stack; + +mod sealed { + #[allow(unreachable_pub, unused)] + pub trait Sealed<T> {} +} + +/// Extension trait that adds methods to [`tower::ServiceBuilder`] for adding middleware from +/// tower-http. +/// +/// [`Service`]: tower::Service +/// +/// # Example +/// +/// ```rust +/// use http::{Request, Response, header::HeaderName}; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use std::{time::Duration, convert::Infallible}; +/// use tower::{ServiceBuilder, ServiceExt, Service}; +/// use tower_http::ServiceBuilderExt; +/// +/// async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +/// Ok(Response::new(Full::default())) +/// } +/// +/// # #[tokio::main] +/// # async fn main() { +/// let service = ServiceBuilder::new() +/// // Methods from tower +/// .timeout(Duration::from_secs(30)) +/// // Methods from tower-http +/// .trace_for_http() +/// .propagate_header(HeaderName::from_static("x-request-id")) +/// .service_fn(handle); +/// # let mut service = service; +/// # service.ready().await.unwrap().call(Request::new(Full::default())).await.unwrap(); +/// # } +/// ``` +#[cfg(feature = "util")] +// ^ work around rustdoc not inferring doc(cfg)s for cfg's from surrounding scopes +pub trait ServiceBuilderExt<L>: sealed::Sealed<L> + Sized { + /// Propagate a header from the request to the response. + /// + /// See [`tower_http::propagate_header`] for more details. + /// + /// [`tower_http::propagate_header`]: crate::propagate_header + #[cfg(feature = "propagate-header")] + fn propagate_header( + self, + header: HeaderName, + ) -> ServiceBuilder<Stack<crate::propagate_header::PropagateHeaderLayer, L>>; + + /// Add some shareable value to [request extensions]. + /// + /// See [`tower_http::add_extension`] for more details. + /// + /// [`tower_http::add_extension`]: crate::add_extension + /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html + #[cfg(feature = "add-extension")] + fn add_extension<T>( + self, + value: T, + ) -> ServiceBuilder<Stack<crate::add_extension::AddExtensionLayer<T>, L>>; + + /// Apply a transformation to the request body. + /// + /// See [`tower_http::map_request_body`] for more details. + /// + /// [`tower_http::map_request_body`]: crate::map_request_body + #[cfg(feature = "map-request-body")] + fn map_request_body<F>( + self, + f: F, + ) -> ServiceBuilder<Stack<crate::map_request_body::MapRequestBodyLayer<F>, L>>; + + /// Apply a transformation to the response body. + /// + /// See [`tower_http::map_response_body`] for more details. + /// + /// [`tower_http::map_response_body`]: crate::map_response_body + #[cfg(feature = "map-response-body")] + fn map_response_body<F>( + self, + f: F, + ) -> ServiceBuilder<Stack<crate::map_response_body::MapResponseBodyLayer<F>, L>>; + + /// Compresses response bodies. + /// + /// See [`tower_http::compression`] for more details. + /// + /// [`tower_http::compression`]: crate::compression + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + ))] + fn compression(self) -> ServiceBuilder<Stack<crate::compression::CompressionLayer, L>>; + + /// Decompress response bodies. + /// + /// See [`tower_http::decompression`] for more details. + /// + /// [`tower_http::decompression`]: crate::decompression + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", + ))] + fn decompression(self) -> ServiceBuilder<Stack<crate::decompression::DecompressionLayer, L>>; + + /// High level tracing that classifies responses using HTTP status codes. + /// + /// This method does not support customizing the output, to do that use [`TraceLayer`] + /// instead. + /// + /// See [`tower_http::trace`] for more details. + /// + /// [`tower_http::trace`]: crate::trace + /// [`TraceLayer`]: crate::trace::TraceLayer + #[cfg(feature = "trace")] + fn trace_for_http( + self, + ) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::HttpMakeClassifier>, L>>; + + /// High level tracing that classifies responses using gRPC headers. + /// + /// This method does not support customizing the output, to do that use [`TraceLayer`] + /// instead. + /// + /// See [`tower_http::trace`] for more details. + /// + /// [`tower_http::trace`]: crate::trace + /// [`TraceLayer`]: crate::trace::TraceLayer + #[cfg(feature = "trace")] + fn trace_for_grpc( + self, + ) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::GrpcMakeClassifier>, L>>; + + /// Follow redirect resposes using the [`Standard`] policy. + /// + /// See [`tower_http::follow_redirect`] for more details. + /// + /// [`tower_http::follow_redirect`]: crate::follow_redirect + /// [`Standard`]: crate::follow_redirect::policy::Standard + #[cfg(feature = "follow-redirect")] + fn follow_redirects( + self, + ) -> ServiceBuilder< + Stack< + crate::follow_redirect::FollowRedirectLayer<crate::follow_redirect::policy::Standard>, + L, + >, + >; + + /// Mark headers as [sensitive] on both requests and responses. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_headers<I>( + self, + headers: I, + ) -> ServiceBuilder<Stack<crate::sensitive_headers::SetSensitiveHeadersLayer, L>> + where + I: IntoIterator<Item = HeaderName>; + + /// Mark headers as [sensitive] on requests. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_request_headers( + self, + headers: std::sync::Arc<[HeaderName]>, + ) -> ServiceBuilder<Stack<crate::sensitive_headers::SetSensitiveRequestHeadersLayer, L>>; + + /// Mark headers as [sensitive] on responses. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_response_headers( + self, + headers: std::sync::Arc<[HeaderName]>, + ) -> ServiceBuilder<Stack<crate::sensitive_headers::SetSensitiveResponseHeadersLayer, L>>; + + /// Insert a header into the request. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn override_request_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetRequestHeaderLayer<M>, L>>; + + /// Append a header into the request. + /// + /// If previous values exist, the header will have multiple values. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn append_request_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetRequestHeaderLayer<M>, L>>; + + /// Insert a header into the request, if the header is not already present. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn insert_request_header_if_not_present<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetRequestHeaderLayer<M>, L>>; + + /// Insert a header into the response. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn override_response_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetResponseHeaderLayer<M>, L>>; + + /// Append a header into the response. + /// + /// If previous values exist, the header will have multiple values. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn append_response_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetResponseHeaderLayer<M>, L>>; + + /// Insert a header into the response, if the header is not already present. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn insert_response_header_if_not_present<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetResponseHeaderLayer<M>, L>>; + + /// Add request id header and extension. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn set_request_id<M>( + self, + header_name: HeaderName, + make_request_id: M, + ) -> ServiceBuilder<Stack<crate::request_id::SetRequestIdLayer<M>, L>> + where + M: crate::request_id::MakeRequestId; + + /// Add request id header and extension, using `x-request-id` as the header name. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn set_x_request_id<M>( + self, + make_request_id: M, + ) -> ServiceBuilder<Stack<crate::request_id::SetRequestIdLayer<M>, L>> + where + M: crate::request_id::MakeRequestId, + { + self.set_request_id(crate::request_id::X_REQUEST_ID, make_request_id) + } + + /// Propgate request ids from requests to responses. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn propagate_request_id( + self, + header_name: HeaderName, + ) -> ServiceBuilder<Stack<crate::request_id::PropagateRequestIdLayer, L>>; + + /// Propgate request ids from requests to responses, using `x-request-id` as the header name. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn propagate_x_request_id( + self, + ) -> ServiceBuilder<Stack<crate::request_id::PropagateRequestIdLayer, L>> { + self.propagate_request_id(crate::request_id::X_REQUEST_ID) + } + + /// Catch panics and convert them into `500 Internal Server` responses. + /// + /// See [`tower_http::catch_panic`] for more details. + /// + /// [`tower_http::catch_panic`]: crate::catch_panic + #[cfg(feature = "catch-panic")] + fn catch_panic( + self, + ) -> ServiceBuilder< + Stack<crate::catch_panic::CatchPanicLayer<crate::catch_panic::DefaultResponseForPanic>, L>, + >; + + /// Intercept requests with over-sized payloads and convert them into + /// `413 Payload Too Large` responses. + /// + /// See [`tower_http::limit`] for more details. + /// + /// [`tower_http::limit`]: crate::limit + #[cfg(feature = "limit")] + fn request_body_limit( + self, + limit: usize, + ) -> ServiceBuilder<Stack<crate::limit::RequestBodyLimitLayer, L>>; + + /// Remove trailing slashes from paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn trim_trailing_slash( + self, + ) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>; + + /// Append trailing slash to paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn append_trailing_slash( + self, + ) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>; +} + +impl<L> sealed::Sealed<L> for ServiceBuilder<L> {} + +impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> { + #[cfg(feature = "propagate-header")] + fn propagate_header( + self, + header: HeaderName, + ) -> ServiceBuilder<Stack<crate::propagate_header::PropagateHeaderLayer, L>> { + self.layer(crate::propagate_header::PropagateHeaderLayer::new(header)) + } + + #[cfg(feature = "add-extension")] + fn add_extension<T>( + self, + value: T, + ) -> ServiceBuilder<Stack<crate::add_extension::AddExtensionLayer<T>, L>> { + self.layer(crate::add_extension::AddExtensionLayer::new(value)) + } + + #[cfg(feature = "map-request-body")] + fn map_request_body<F>( + self, + f: F, + ) -> ServiceBuilder<Stack<crate::map_request_body::MapRequestBodyLayer<F>, L>> { + self.layer(crate::map_request_body::MapRequestBodyLayer::new(f)) + } + + #[cfg(feature = "map-response-body")] + fn map_response_body<F>( + self, + f: F, + ) -> ServiceBuilder<Stack<crate::map_response_body::MapResponseBodyLayer<F>, L>> { + self.layer(crate::map_response_body::MapResponseBodyLayer::new(f)) + } + + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + ))] + fn compression(self) -> ServiceBuilder<Stack<crate::compression::CompressionLayer, L>> { + self.layer(crate::compression::CompressionLayer::new()) + } + + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", + ))] + fn decompression(self) -> ServiceBuilder<Stack<crate::decompression::DecompressionLayer, L>> { + self.layer(crate::decompression::DecompressionLayer::new()) + } + + #[cfg(feature = "trace")] + fn trace_for_http( + self, + ) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::HttpMakeClassifier>, L>> { + self.layer(crate::trace::TraceLayer::new_for_http()) + } + + #[cfg(feature = "trace")] + fn trace_for_grpc( + self, + ) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::GrpcMakeClassifier>, L>> { + self.layer(crate::trace::TraceLayer::new_for_grpc()) + } + + #[cfg(feature = "follow-redirect")] + fn follow_redirects( + self, + ) -> ServiceBuilder< + Stack< + crate::follow_redirect::FollowRedirectLayer<crate::follow_redirect::policy::Standard>, + L, + >, + > { + self.layer(crate::follow_redirect::FollowRedirectLayer::new()) + } + + #[cfg(feature = "sensitive-headers")] + fn sensitive_headers<I>( + self, + headers: I, + ) -> ServiceBuilder<Stack<crate::sensitive_headers::SetSensitiveHeadersLayer, L>> + where + I: IntoIterator<Item = HeaderName>, + { + self.layer(crate::sensitive_headers::SetSensitiveHeadersLayer::new( + headers, + )) + } + + #[cfg(feature = "sensitive-headers")] + fn sensitive_request_headers( + self, + headers: std::sync::Arc<[HeaderName]>, + ) -> ServiceBuilder<Stack<crate::sensitive_headers::SetSensitiveRequestHeadersLayer, L>> { + self.layer(crate::sensitive_headers::SetSensitiveRequestHeadersLayer::from_shared(headers)) + } + + #[cfg(feature = "sensitive-headers")] + fn sensitive_response_headers( + self, + headers: std::sync::Arc<[HeaderName]>, + ) -> ServiceBuilder<Stack<crate::sensitive_headers::SetSensitiveResponseHeadersLayer, L>> { + self.layer(crate::sensitive_headers::SetSensitiveResponseHeadersLayer::from_shared(headers)) + } + + #[cfg(feature = "set-header")] + fn override_request_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetRequestHeaderLayer<M>, L>> { + self.layer(crate::set_header::SetRequestHeaderLayer::overriding( + header_name, + make, + )) + } + + #[cfg(feature = "set-header")] + fn append_request_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetRequestHeaderLayer<M>, L>> { + self.layer(crate::set_header::SetRequestHeaderLayer::appending( + header_name, + make, + )) + } + + #[cfg(feature = "set-header")] + fn insert_request_header_if_not_present<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetRequestHeaderLayer<M>, L>> { + self.layer(crate::set_header::SetRequestHeaderLayer::if_not_present( + header_name, + make, + )) + } + + #[cfg(feature = "set-header")] + fn override_response_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetResponseHeaderLayer<M>, L>> { + self.layer(crate::set_header::SetResponseHeaderLayer::overriding( + header_name, + make, + )) + } + + #[cfg(feature = "set-header")] + fn append_response_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetResponseHeaderLayer<M>, L>> { + self.layer(crate::set_header::SetResponseHeaderLayer::appending( + header_name, + make, + )) + } + + #[cfg(feature = "set-header")] + fn insert_response_header_if_not_present<M>( + self, + header_name: HeaderName, + make: M, + ) -> ServiceBuilder<Stack<crate::set_header::SetResponseHeaderLayer<M>, L>> { + self.layer(crate::set_header::SetResponseHeaderLayer::if_not_present( + header_name, + make, + )) + } + + #[cfg(feature = "request-id")] + fn set_request_id<M>( + self, + header_name: HeaderName, + make_request_id: M, + ) -> ServiceBuilder<Stack<crate::request_id::SetRequestIdLayer<M>, L>> + where + M: crate::request_id::MakeRequestId, + { + self.layer(crate::request_id::SetRequestIdLayer::new( + header_name, + make_request_id, + )) + } + + #[cfg(feature = "request-id")] + fn propagate_request_id( + self, + header_name: HeaderName, + ) -> ServiceBuilder<Stack<crate::request_id::PropagateRequestIdLayer, L>> { + self.layer(crate::request_id::PropagateRequestIdLayer::new(header_name)) + } + + #[cfg(feature = "catch-panic")] + fn catch_panic( + self, + ) -> ServiceBuilder< + Stack<crate::catch_panic::CatchPanicLayer<crate::catch_panic::DefaultResponseForPanic>, L>, + > { + self.layer(crate::catch_panic::CatchPanicLayer::new()) + } + + #[cfg(feature = "limit")] + fn request_body_limit( + self, + limit: usize, + ) -> ServiceBuilder<Stack<crate::limit::RequestBodyLimitLayer, L>> { + self.layer(crate::limit::RequestBodyLimitLayer::new(limit)) + } + + #[cfg(feature = "normalize-path")] + fn trim_trailing_slash( + self, + ) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> { + self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash()) + } + + #[cfg(feature = "normalize-path")] + fn append_trailing_slash( + self, + ) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> { + self.layer(crate::normalize_path::NormalizePathLayer::append_trailing_slash()) + } +} diff --git a/vendor/tower-http/src/catch_panic.rs b/vendor/tower-http/src/catch_panic.rs new file mode 100644 index 00000000..3f1c2279 --- /dev/null +++ b/vendor/tower-http/src/catch_panic.rs @@ -0,0 +1,409 @@ +//! Convert panics into responses. +//! +//! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result` +//! whenever possible. +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response, header::HeaderName}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::catch_panic::CatchPanicLayer; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! panic!("something went wrong...") +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! // Catch panics and convert them into responses. +//! .layer(CatchPanicLayer::new()) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::new(Full::default()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), 500); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Using a custom panic handler: +//! +//! ```rust +//! use http::{Request, StatusCode, Response, header::{self, HeaderName}}; +//! use std::{any::Any, convert::Infallible}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::catch_panic::CatchPanicLayer; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! panic!("something went wrong...") +//! } +//! +//! fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> { +//! let details = if let Some(s) = err.downcast_ref::<String>() { +//! s.clone() +//! } else if let Some(s) = err.downcast_ref::<&str>() { +//! s.to_string() +//! } else { +//! "Unknown panic message".to_string() +//! }; +//! +//! let body = serde_json::json!({ +//! "error": { +//! "kind": "panic", +//! "details": details, +//! } +//! }); +//! let body = serde_json::to_string(&body).unwrap(); +//! +//! Response::builder() +//! .status(StatusCode::INTERNAL_SERVER_ERROR) +//! .header(header::CONTENT_TYPE, "application/json") +//! .body(Full::from(body)) +//! .unwrap() +//! } +//! +//! let svc = ServiceBuilder::new() +//! // Use `handle_panic` to create the response. +//! .layer(CatchPanicLayer::custom(handle_panic)) +//! .service_fn(handle); +//! # +//! # Ok(()) +//! # } +//! ``` + +use bytes::Bytes; +use futures_util::future::{CatchUnwind, FutureExt}; +use http::{HeaderValue, Request, Response, StatusCode}; +use http_body::Body; +use http_body_util::BodyExt; +use pin_project_lite::pin_project; +use std::{ + any::Any, + future::Future, + panic::AssertUnwindSafe, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +use crate::{ + body::{Full, UnsyncBoxBody}, + BoxError, +}; + +/// Layer that applies the [`CatchPanic`] middleware that catches panics and converts them into +/// `500 Internal Server` responses. +/// +/// See the [module docs](self) for an example. +#[derive(Debug, Clone, Copy, Default)] +pub struct CatchPanicLayer<T> { + panic_handler: T, +} + +impl CatchPanicLayer<DefaultResponseForPanic> { + /// Create a new `CatchPanicLayer` with the default panic handler. + pub fn new() -> Self { + CatchPanicLayer { + panic_handler: DefaultResponseForPanic, + } + } +} + +impl<T> CatchPanicLayer<T> { + /// Create a new `CatchPanicLayer` with a custom panic handler. + pub fn custom(panic_handler: T) -> Self + where + T: ResponseForPanic, + { + Self { panic_handler } + } +} + +impl<T, S> Layer<S> for CatchPanicLayer<T> +where + T: Clone, +{ + type Service = CatchPanic<S, T>; + + fn layer(&self, inner: S) -> Self::Service { + CatchPanic { + inner, + panic_handler: self.panic_handler.clone(), + } + } +} + +/// Middleware that catches panics and converts them into `500 Internal Server` responses. +/// +/// See the [module docs](self) for an example. +#[derive(Debug, Clone, Copy)] +pub struct CatchPanic<S, T> { + inner: S, + panic_handler: T, +} + +impl<S> CatchPanic<S, DefaultResponseForPanic> { + /// Create a new `CatchPanic` with the default panic handler. + pub fn new(inner: S) -> Self { + Self { + inner, + panic_handler: DefaultResponseForPanic, + } + } +} + +impl<S, T> CatchPanic<S, T> { + define_inner_service_accessors!(); + + /// Create a new `CatchPanic` with a custom panic handler. + pub fn custom(inner: S, panic_handler: T) -> Self + where + T: ResponseForPanic, + { + Self { + inner, + panic_handler, + } + } +} + +impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + ResBody: Body<Data = Bytes> + Send + 'static, + ResBody::Error: Into<BoxError>, + T: ResponseForPanic + Clone, + T::ResponseBody: Body<Data = Bytes> + Send + 'static, + <T::ResponseBody as Body>::Error: Into<BoxError>, +{ + type Response = Response<UnsyncBoxBody<Bytes, BoxError>>; + type Error = S::Error; + type Future = ResponseFuture<S::Future, T>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) { + Ok(future) => ResponseFuture { + kind: Kind::Future { + future: AssertUnwindSafe(future).catch_unwind(), + panic_handler: Some(self.panic_handler.clone()), + }, + }, + Err(panic_err) => ResponseFuture { + kind: Kind::Panicked { + panic_err: Some(panic_err), + panic_handler: Some(self.panic_handler.clone()), + }, + }, + } + } +} + +pin_project! { + /// Response future for [`CatchPanic`]. + pub struct ResponseFuture<F, T> { + #[pin] + kind: Kind<F, T>, + } +} + +pin_project! { + #[project = KindProj] + enum Kind<F, T> { + Panicked { + panic_err: Option<Box<dyn Any + Send + 'static>>, + panic_handler: Option<T>, + }, + Future { + #[pin] + future: CatchUnwind<AssertUnwindSafe<F>>, + panic_handler: Option<T>, + } + } +} + +impl<F, ResBody, E, T> Future for ResponseFuture<F, T> +where + F: Future<Output = Result<Response<ResBody>, E>>, + ResBody: Body<Data = Bytes> + Send + 'static, + ResBody::Error: Into<BoxError>, + T: ResponseForPanic, + T::ResponseBody: Body<Data = Bytes> + Send + 'static, + <T::ResponseBody as Body>::Error: Into<BoxError>, +{ + type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.project().kind.project() { + KindProj::Panicked { + panic_err, + panic_handler, + } => { + let panic_handler = panic_handler + .take() + .expect("future polled after completion"); + let panic_err = panic_err.take().expect("future polled after completion"); + Poll::Ready(Ok(response_for_panic(panic_handler, panic_err))) + } + KindProj::Future { + future, + panic_handler, + } => match ready!(future.poll(cx)) { + Ok(Ok(res)) => { + Poll::Ready(Ok(res.map(|body| { + UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync()) + }))) + } + Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)), + Err(panic_err) => Poll::Ready(Ok(response_for_panic( + panic_handler + .take() + .expect("future polled after completion"), + panic_err, + ))), + }, + } + } +} + +fn response_for_panic<T>( + mut panic_handler: T, + err: Box<dyn Any + Send + 'static>, +) -> Response<UnsyncBoxBody<Bytes, BoxError>> +where + T: ResponseForPanic, + T::ResponseBody: Body<Data = Bytes> + Send + 'static, + <T::ResponseBody as Body>::Error: Into<BoxError>, +{ + panic_handler + .response_for_panic(err) + .map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())) +} + +/// Trait for creating responses from panics. +pub trait ResponseForPanic: Clone { + /// The body type used for responses to panics. + type ResponseBody; + + /// Create a response from the panic error. + fn response_for_panic( + &mut self, + err: Box<dyn Any + Send + 'static>, + ) -> Response<Self::ResponseBody>; +} + +impl<F, B> ResponseForPanic for F +where + F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone, +{ + type ResponseBody = B; + + fn response_for_panic( + &mut self, + err: Box<dyn Any + Send + 'static>, + ) -> Response<Self::ResponseBody> { + self(err) + } +} + +/// The default `ResponseForPanic` used by `CatchPanic`. +/// +/// It will log the panic message and return a `500 Internal Server` error response with an empty +/// body. +#[derive(Debug, Default, Clone, Copy)] +#[non_exhaustive] +pub struct DefaultResponseForPanic; + +impl ResponseForPanic for DefaultResponseForPanic { + type ResponseBody = Full; + + fn response_for_panic( + &mut self, + err: Box<dyn Any + Send + 'static>, + ) -> Response<Self::ResponseBody> { + if let Some(s) = err.downcast_ref::<String>() { + tracing::error!("Service panicked: {}", s); + } else if let Some(s) = err.downcast_ref::<&str>() { + tracing::error!("Service panicked: {}", s); + } else { + tracing::error!( + "Service panicked but `CatchPanic` was unable to downcast the panic info" + ); + }; + + let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked"))); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + + #[allow(clippy::declare_interior_mutable_const)] + const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); + res.headers_mut() + .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); + + res + } +} + +#[cfg(test)] +mod tests { + #![allow(unreachable_code)] + + use super::*; + use crate::test_helpers::Body; + use http::Response; + use std::convert::Infallible; + use tower::{ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn panic_before_returning_future() { + let svc = ServiceBuilder::new() + .layer(CatchPanicLayer::new()) + .service_fn(|_: Request<Body>| { + panic!("service panic"); + async { Ok::<_, Infallible>(Response::new(Body::empty())) } + }); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = crate::test_helpers::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"Service panicked"); + } + + #[tokio::test] + async fn panic_in_future() { + let svc = ServiceBuilder::new() + .layer(CatchPanicLayer::new()) + .service_fn(|_: Request<Body>| async { + panic!("future panic"); + Ok::<_, Infallible>(Response::new(Body::empty())) + }); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = crate::test_helpers::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"Service panicked"); + } +} diff --git a/vendor/tower-http/src/classify/grpc_errors_as_failures.rs b/vendor/tower-http/src/classify/grpc_errors_as_failures.rs new file mode 100644 index 00000000..3fc96c33 --- /dev/null +++ b/vendor/tower-http/src/classify/grpc_errors_as_failures.rs @@ -0,0 +1,357 @@ +use super::{ClassifiedResponse, ClassifyEos, ClassifyResponse, SharedClassifier}; +use bitflags::bitflags; +use http::{HeaderMap, Response}; +use std::{fmt, num::NonZeroI32}; + +/// gRPC status codes. +/// +/// These variants match the [gRPC status codes]. +/// +/// [gRPC status codes]: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc +#[derive(Clone, Copy, Debug)] +pub enum GrpcCode { + /// The operation completed successfully. + Ok, + /// The operation was cancelled. + Cancelled, + /// Unknown error. + Unknown, + /// Client specified an invalid argument. + InvalidArgument, + /// Deadline expired before operation could complete. + DeadlineExceeded, + /// Some requested entity was not found. + NotFound, + /// Some entity that we attempted to create already exists. + AlreadyExists, + /// The caller does not have permission to execute the specified operation. + PermissionDenied, + /// Some resource has been exhausted. + ResourceExhausted, + /// The system is not in a state required for the operation's execution. + FailedPrecondition, + /// The operation was aborted. + Aborted, + /// Operation was attempted past the valid range. + OutOfRange, + /// Operation is not implemented or not supported. + Unimplemented, + /// Internal error. + Internal, + /// The service is currently unavailable. + Unavailable, + /// Unrecoverable data loss or corruption. + DataLoss, + /// The request does not have valid authentication credentials + Unauthenticated, +} + +impl GrpcCode { + pub(crate) fn into_bitmask(self) -> GrpcCodeBitmask { + match self { + Self::Ok => GrpcCodeBitmask::OK, + Self::Cancelled => GrpcCodeBitmask::CANCELLED, + Self::Unknown => GrpcCodeBitmask::UNKNOWN, + Self::InvalidArgument => GrpcCodeBitmask::INVALID_ARGUMENT, + Self::DeadlineExceeded => GrpcCodeBitmask::DEADLINE_EXCEEDED, + Self::NotFound => GrpcCodeBitmask::NOT_FOUND, + Self::AlreadyExists => GrpcCodeBitmask::ALREADY_EXISTS, + Self::PermissionDenied => GrpcCodeBitmask::PERMISSION_DENIED, + Self::ResourceExhausted => GrpcCodeBitmask::RESOURCE_EXHAUSTED, + Self::FailedPrecondition => GrpcCodeBitmask::FAILED_PRECONDITION, + Self::Aborted => GrpcCodeBitmask::ABORTED, + Self::OutOfRange => GrpcCodeBitmask::OUT_OF_RANGE, + Self::Unimplemented => GrpcCodeBitmask::UNIMPLEMENTED, + Self::Internal => GrpcCodeBitmask::INTERNAL, + Self::Unavailable => GrpcCodeBitmask::UNAVAILABLE, + Self::DataLoss => GrpcCodeBitmask::DATA_LOSS, + Self::Unauthenticated => GrpcCodeBitmask::UNAUTHENTICATED, + } + } +} + +bitflags! { + #[derive(Debug, Clone, Copy)] + pub(crate) struct GrpcCodeBitmask: u32 { + const OK = 0b00000000000000001; + const CANCELLED = 0b00000000000000010; + const UNKNOWN = 0b00000000000000100; + const INVALID_ARGUMENT = 0b00000000000001000; + const DEADLINE_EXCEEDED = 0b00000000000010000; + const NOT_FOUND = 0b00000000000100000; + const ALREADY_EXISTS = 0b00000000001000000; + const PERMISSION_DENIED = 0b00000000010000000; + const RESOURCE_EXHAUSTED = 0b00000000100000000; + const FAILED_PRECONDITION = 0b00000001000000000; + const ABORTED = 0b00000010000000000; + const OUT_OF_RANGE = 0b00000100000000000; + const UNIMPLEMENTED = 0b00001000000000000; + const INTERNAL = 0b00010000000000000; + const UNAVAILABLE = 0b00100000000000000; + const DATA_LOSS = 0b01000000000000000; + const UNAUTHENTICATED = 0b10000000000000000; + } +} + +impl GrpcCodeBitmask { + fn try_from_u32(code: u32) -> Option<Self> { + match code { + 0 => Some(Self::OK), + 1 => Some(Self::CANCELLED), + 2 => Some(Self::UNKNOWN), + 3 => Some(Self::INVALID_ARGUMENT), + 4 => Some(Self::DEADLINE_EXCEEDED), + 5 => Some(Self::NOT_FOUND), + 6 => Some(Self::ALREADY_EXISTS), + 7 => Some(Self::PERMISSION_DENIED), + 8 => Some(Self::RESOURCE_EXHAUSTED), + 9 => Some(Self::FAILED_PRECONDITION), + 10 => Some(Self::ABORTED), + 11 => Some(Self::OUT_OF_RANGE), + 12 => Some(Self::UNIMPLEMENTED), + 13 => Some(Self::INTERNAL), + 14 => Some(Self::UNAVAILABLE), + 15 => Some(Self::DATA_LOSS), + 16 => Some(Self::UNAUTHENTICATED), + _ => None, + } + } +} + +/// Response classifier for gRPC responses. +/// +/// gRPC doesn't use normal HTTP statuses for indicating success or failure but instead a special +/// header that might appear in a trailer. +/// +/// Responses are considered successful if +/// +/// - `grpc-status` header value contains a success value. +/// - `grpc-status` header is missing. +/// - `grpc-status` header value isn't a valid `String`. +/// - `grpc-status` header value can't parsed into an `i32`. +/// +/// All others are considered failures. +#[derive(Debug, Clone)] +pub struct GrpcErrorsAsFailures { + success_codes: GrpcCodeBitmask, +} + +impl Default for GrpcErrorsAsFailures { + fn default() -> Self { + Self::new() + } +} + +impl GrpcErrorsAsFailures { + /// Create a new [`GrpcErrorsAsFailures`]. + pub fn new() -> Self { + Self { + success_codes: GrpcCodeBitmask::OK, + } + } + + /// Change which gRPC codes are considered success. + /// + /// Defaults to only considering `Ok` as success. + /// + /// `Ok` will always be considered a success. + /// + /// # Example + /// + /// Servers might not want to consider `Invalid Argument` or `Not Found` as failures since + /// thats likely the clients fault: + /// + /// ```rust + /// use tower_http::classify::{GrpcErrorsAsFailures, GrpcCode}; + /// + /// let classifier = GrpcErrorsAsFailures::new() + /// .with_success(GrpcCode::InvalidArgument) + /// .with_success(GrpcCode::NotFound); + /// ``` + pub fn with_success(mut self, code: GrpcCode) -> Self { + self.success_codes |= code.into_bitmask(); + self + } + + /// Returns a [`MakeClassifier`](super::MakeClassifier) that produces `GrpcErrorsAsFailures`. + /// + /// This is a convenience function that simply calls `SharedClassifier::new`. + pub fn make_classifier() -> SharedClassifier<Self> { + SharedClassifier::new(Self::new()) + } +} + +impl ClassifyResponse for GrpcErrorsAsFailures { + type FailureClass = GrpcFailureClass; + type ClassifyEos = GrpcEosErrorsAsFailures; + + fn classify_response<B>( + self, + res: &Response<B>, + ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> { + match classify_grpc_metadata(res.headers(), self.success_codes) { + ParsedGrpcStatus::Success + | ParsedGrpcStatus::HeaderNotString + | ParsedGrpcStatus::HeaderNotInt => ClassifiedResponse::Ready(Ok(())), + ParsedGrpcStatus::NonSuccess(status) => { + ClassifiedResponse::Ready(Err(GrpcFailureClass::Code(status))) + } + ParsedGrpcStatus::GrpcStatusHeaderMissing => { + ClassifiedResponse::RequiresEos(GrpcEosErrorsAsFailures { + success_codes: self.success_codes, + }) + } + } + } + + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: fmt::Display + 'static, + { + GrpcFailureClass::Error(error.to_string()) + } +} + +/// The [`ClassifyEos`] for [`GrpcErrorsAsFailures`]. +#[derive(Debug, Clone)] +pub struct GrpcEosErrorsAsFailures { + success_codes: GrpcCodeBitmask, +} + +impl ClassifyEos for GrpcEosErrorsAsFailures { + type FailureClass = GrpcFailureClass; + + fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> { + if let Some(trailers) = trailers { + match classify_grpc_metadata(trailers, self.success_codes) { + ParsedGrpcStatus::Success + | ParsedGrpcStatus::GrpcStatusHeaderMissing + | ParsedGrpcStatus::HeaderNotString + | ParsedGrpcStatus::HeaderNotInt => Ok(()), + ParsedGrpcStatus::NonSuccess(status) => Err(GrpcFailureClass::Code(status)), + } + } else { + Ok(()) + } + } + + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: fmt::Display + 'static, + { + GrpcFailureClass::Error(error.to_string()) + } +} + +/// The failure class for [`GrpcErrorsAsFailures`]. +#[derive(Debug)] +pub enum GrpcFailureClass { + /// A gRPC response was classified as a failure with the corresponding status. + Code(std::num::NonZeroI32), + /// A gRPC response was classified as an error with the corresponding error description. + Error(String), +} + +impl fmt::Display for GrpcFailureClass { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Code(code) => write!(f, "Code: {}", code), + Self::Error(error) => write!(f, "Error: {}", error), + } + } +} + +pub(crate) fn classify_grpc_metadata( + headers: &HeaderMap, + success_codes: GrpcCodeBitmask, +) -> ParsedGrpcStatus { + macro_rules! or_else { + ($expr:expr, $other:ident) => { + if let Some(value) = $expr { + value + } else { + return ParsedGrpcStatus::$other; + } + }; + } + + let status = or_else!(headers.get("grpc-status"), GrpcStatusHeaderMissing); + let status = or_else!(status.to_str().ok(), HeaderNotString); + let status = or_else!(status.parse::<i32>().ok(), HeaderNotInt); + + if GrpcCodeBitmask::try_from_u32(status as _) + .filter(|code| success_codes.contains(*code)) + .is_some() + { + ParsedGrpcStatus::Success + } else { + ParsedGrpcStatus::NonSuccess(NonZeroI32::new(status).unwrap()) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum ParsedGrpcStatus { + Success, + NonSuccess(NonZeroI32), + GrpcStatusHeaderMissing, + // these two are treated as `Success` but kept separate for clarity + HeaderNotString, + HeaderNotInt, +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! classify_grpc_metadata_test { + ( + name: $name:ident, + status: $status:expr, + success_flags: $success_flags:expr, + expected: $expected:expr, + ) => { + #[test] + fn $name() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-status", $status.parse().unwrap()); + let status = classify_grpc_metadata(&headers, $success_flags); + assert_eq!(status, $expected); + } + }; + } + + classify_grpc_metadata_test! { + name: basic_ok, + status: "0", + success_flags: GrpcCodeBitmask::OK, + expected: ParsedGrpcStatus::Success, + } + + classify_grpc_metadata_test! { + name: basic_error, + status: "1", + success_flags: GrpcCodeBitmask::OK, + expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(1).unwrap()), + } + + classify_grpc_metadata_test! { + name: two_success_codes_first_matches, + status: "0", + success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT, + expected: ParsedGrpcStatus::Success, + } + + classify_grpc_metadata_test! { + name: two_success_codes_second_matches, + status: "3", + success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT, + expected: ParsedGrpcStatus::Success, + } + + classify_grpc_metadata_test! { + name: two_success_codes_none_matches, + status: "16", + success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT, + expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(16).unwrap()), + } +} diff --git a/vendor/tower-http/src/classify/map_failure_class.rs b/vendor/tower-http/src/classify/map_failure_class.rs new file mode 100644 index 00000000..680593b5 --- /dev/null +++ b/vendor/tower-http/src/classify/map_failure_class.rs @@ -0,0 +1,80 @@ +use super::{ClassifiedResponse, ClassifyEos, ClassifyResponse}; +use http::{HeaderMap, Response}; +use std::fmt; + +/// Response classifier that transforms the failure class of some other +/// classifier. +/// +/// Created with [`ClassifyResponse::map_failure_class`] or +/// [`ClassifyEos::map_failure_class`]. +#[derive(Clone, Copy)] +pub struct MapFailureClass<C, F> { + inner: C, + f: F, +} + +impl<C, F> MapFailureClass<C, F> { + pub(super) fn new(classify: C, f: F) -> Self { + Self { inner: classify, f } + } +} + +impl<C, F> fmt::Debug for MapFailureClass<C, F> +where + C: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapFailureClass") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::<F>())) + .finish() + } +} + +impl<C, F, NewClass> ClassifyResponse for MapFailureClass<C, F> +where + C: ClassifyResponse, + F: FnOnce(C::FailureClass) -> NewClass, +{ + type FailureClass = NewClass; + type ClassifyEos = MapFailureClass<C::ClassifyEos, F>; + + fn classify_response<B>( + self, + res: &Response<B>, + ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> { + match self.inner.classify_response(res) { + ClassifiedResponse::Ready(result) => ClassifiedResponse::Ready(result.map_err(self.f)), + ClassifiedResponse::RequiresEos(classify_eos) => { + let mapped_classify_eos = MapFailureClass::new(classify_eos, self.f); + ClassifiedResponse::RequiresEos(mapped_classify_eos) + } + } + } + + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: std::fmt::Display + 'static, + { + (self.f)(self.inner.classify_error(error)) + } +} + +impl<C, F, NewClass> ClassifyEos for MapFailureClass<C, F> +where + C: ClassifyEos, + F: FnOnce(C::FailureClass) -> NewClass, +{ + type FailureClass = NewClass; + + fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> { + self.inner.classify_eos(trailers).map_err(self.f) + } + + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: std::fmt::Display + 'static, + { + (self.f)(self.inner.classify_error(error)) + } +} diff --git a/vendor/tower-http/src/classify/mod.rs b/vendor/tower-http/src/classify/mod.rs new file mode 100644 index 00000000..a3147843 --- /dev/null +++ b/vendor/tower-http/src/classify/mod.rs @@ -0,0 +1,436 @@ +//! Tools for classifying responses as either success or failure. + +use http::{HeaderMap, Request, Response, StatusCode}; +use std::{convert::Infallible, fmt, marker::PhantomData}; + +pub(crate) mod grpc_errors_as_failures; +mod map_failure_class; +mod status_in_range_is_error; + +pub use self::{ + grpc_errors_as_failures::{ + GrpcCode, GrpcEosErrorsAsFailures, GrpcErrorsAsFailures, GrpcFailureClass, + }, + map_failure_class::MapFailureClass, + status_in_range_is_error::{StatusInRangeAsFailures, StatusInRangeFailureClass}, +}; + +/// Trait for producing response classifiers from a request. +/// +/// This is useful when a classifier depends on data from the request. For example, this could +/// include the URI or HTTP method. +/// +/// This trait is generic over the [`Error` type] of the `Service`s used with the classifier. +/// This is necessary for [`ClassifyResponse::classify_error`]. +/// +/// [`Error` type]: https://docs.rs/tower/latest/tower/trait.Service.html#associatedtype.Error +pub trait MakeClassifier { + /// The response classifier produced. + type Classifier: ClassifyResponse< + FailureClass = Self::FailureClass, + ClassifyEos = Self::ClassifyEos, + >; + + /// The type of failure classifications. + /// + /// This might include additional information about the error, such as + /// whether it was a client or server error, or whether or not it should + /// be considered retryable. + type FailureClass; + + /// The type used to classify the response end of stream (EOS). + type ClassifyEos: ClassifyEos<FailureClass = Self::FailureClass>; + + /// Returns a response classifier for this request + fn make_classifier<B>(&self, req: &Request<B>) -> Self::Classifier; +} + +/// A [`MakeClassifier`] that produces new classifiers by cloning an inner classifier. +/// +/// When a type implementing [`ClassifyResponse`] doesn't depend on information +/// from the request, [`SharedClassifier`] can be used to turn an instance of that type +/// into a [`MakeClassifier`]. +/// +/// # Example +/// +/// ``` +/// use std::fmt; +/// use tower_http::classify::{ +/// ClassifyResponse, ClassifiedResponse, NeverClassifyEos, +/// SharedClassifier, MakeClassifier, +/// }; +/// use http::Response; +/// +/// // A response classifier that only considers errors to be failures. +/// #[derive(Clone, Copy)] +/// struct MyClassifier; +/// +/// impl ClassifyResponse for MyClassifier { +/// type FailureClass = String; +/// type ClassifyEos = NeverClassifyEos<Self::FailureClass>; +/// +/// fn classify_response<B>( +/// self, +/// _res: &Response<B>, +/// ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> { +/// ClassifiedResponse::Ready(Ok(())) +/// } +/// +/// fn classify_error<E>(self, error: &E) -> Self::FailureClass +/// where +/// E: fmt::Display + 'static, +/// { +/// error.to_string() +/// } +/// } +/// +/// // Some function that requires a `MakeClassifier` +/// fn use_make_classifier<M: MakeClassifier>(make: M) { +/// // ... +/// } +/// +/// // `MyClassifier` doesn't implement `MakeClassifier` but since it doesn't +/// // care about the incoming request we can make `MyClassifier`s by cloning. +/// // That is what `SharedClassifier` does. +/// let make_classifier = SharedClassifier::new(MyClassifier); +/// +/// // We now have a `MakeClassifier`! +/// use_make_classifier(make_classifier); +/// ``` +#[derive(Debug, Clone)] +pub struct SharedClassifier<C> { + classifier: C, +} + +impl<C> SharedClassifier<C> { + /// Create a new `SharedClassifier` from the given classifier. + pub fn new(classifier: C) -> Self + where + C: ClassifyResponse + Clone, + { + Self { classifier } + } +} + +impl<C> MakeClassifier for SharedClassifier<C> +where + C: ClassifyResponse + Clone, +{ + type FailureClass = C::FailureClass; + type ClassifyEos = C::ClassifyEos; + type Classifier = C; + + fn make_classifier<B>(&self, _req: &Request<B>) -> Self::Classifier { + self.classifier.clone() + } +} + +/// Trait for classifying responses as either success or failure. Designed to support both unary +/// requests (single request for a single response) as well as streaming responses. +/// +/// Response classifiers are used in cases where middleware needs to determine +/// whether a response completed successfully or failed. For example, they may +/// be used by logging or metrics middleware to record failures differently +/// from successes. +/// +/// Furthermore, when a response fails, a response classifier may provide +/// additional information about the failure. This can, for example, be used to +/// build [retry policies] by indicating whether or not a particular failure is +/// retryable. +/// +/// [retry policies]: https://docs.rs/tower/latest/tower/retry/trait.Policy.html +pub trait ClassifyResponse { + /// The type returned when a response is classified as a failure. + /// + /// Depending on the classifier, this may simply indicate that the + /// request failed, or it may contain additional information about + /// the failure, such as whether or not it is retryable. + type FailureClass; + + /// The type used to classify the response end of stream (EOS). + type ClassifyEos: ClassifyEos<FailureClass = Self::FailureClass>; + + /// Attempt to classify the beginning of a response. + /// + /// In some cases, the response can be classified immediately, without + /// waiting for a body to complete. This may include: + /// + /// - When the response has an error status code. + /// - When a successful response does not have a streaming body. + /// - When the classifier does not care about streaming bodies. + /// + /// When the response can be classified immediately, `classify_response` + /// returns a [`ClassifiedResponse::Ready`] which indicates whether the + /// response succeeded or failed. + /// + /// In other cases, however, the classifier may need to wait until the + /// response body stream completes before it can classify the response. + /// For example, gRPC indicates RPC failures using the `grpc-status` + /// trailer. In this case, `classify_response` returns a + /// [`ClassifiedResponse::RequiresEos`] containing a type which will + /// be used to classify the response when the body stream ends. + fn classify_response<B>( + self, + res: &Response<B>, + ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos>; + + /// Classify an error. + /// + /// Errors are always errors (doh) but sometimes it might be useful to have multiple classes of + /// errors. A retry policy might allow retrying some errors and not others. + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: fmt::Display + 'static; + + /// Transform the failure classification using a function. + /// + /// # Example + /// + /// ``` + /// use tower_http::classify::{ + /// ServerErrorsAsFailures, ServerErrorsFailureClass, + /// ClassifyResponse, ClassifiedResponse + /// }; + /// use http::{Response, StatusCode}; + /// use http_body_util::Empty; + /// use bytes::Bytes; + /// + /// fn transform_failure_class(class: ServerErrorsFailureClass) -> NewFailureClass { + /// match class { + /// // Convert status codes into u16 + /// ServerErrorsFailureClass::StatusCode(status) => { + /// NewFailureClass::Status(status.as_u16()) + /// } + /// // Don't change errors. + /// ServerErrorsFailureClass::Error(error) => { + /// NewFailureClass::Error(error) + /// } + /// } + /// } + /// + /// enum NewFailureClass { + /// Status(u16), + /// Error(String), + /// } + /// + /// // Create a classifier who's failure class will be transformed by `transform_failure_class` + /// let classifier = ServerErrorsAsFailures::new().map_failure_class(transform_failure_class); + /// + /// let response = Response::builder() + /// .status(StatusCode::INTERNAL_SERVER_ERROR) + /// .body(Empty::<Bytes>::new()) + /// .unwrap(); + /// + /// let classification = classifier.classify_response(&response); + /// + /// assert!(matches!( + /// classification, + /// ClassifiedResponse::Ready(Err(NewFailureClass::Status(500))) + /// )); + /// ``` + fn map_failure_class<F, NewClass>(self, f: F) -> MapFailureClass<Self, F> + where + Self: Sized, + F: FnOnce(Self::FailureClass) -> NewClass, + { + MapFailureClass::new(self, f) + } +} + +/// Trait for classifying end of streams (EOS) as either success or failure. +pub trait ClassifyEos { + /// The type of failure classifications. + type FailureClass; + + /// Perform the classification from response trailers. + fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass>; + + /// Classify an error. + /// + /// Errors are always errors (doh) but sometimes it might be useful to have multiple classes of + /// errors. A retry policy might allow retrying some errors and not others. + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: fmt::Display + 'static; + + /// Transform the failure classification using a function. + /// + /// See [`ClassifyResponse::map_failure_class`] for more details. + fn map_failure_class<F, NewClass>(self, f: F) -> MapFailureClass<Self, F> + where + Self: Sized, + F: FnOnce(Self::FailureClass) -> NewClass, + { + MapFailureClass::new(self, f) + } +} + +/// Result of doing a classification. +#[derive(Debug)] +pub enum ClassifiedResponse<FailureClass, ClassifyEos> { + /// The response was able to be classified immediately. + Ready(Result<(), FailureClass>), + /// We have to wait until the end of a streaming response to classify it. + RequiresEos(ClassifyEos), +} + +/// A [`ClassifyEos`] type that can be used in [`ClassifyResponse`] implementations that never have +/// to classify streaming responses. +/// +/// `NeverClassifyEos` exists only as type. `NeverClassifyEos` values cannot be constructed. +pub struct NeverClassifyEos<T> { + _output_ty: PhantomData<fn() -> T>, + _never: Infallible, +} + +impl<T> ClassifyEos for NeverClassifyEos<T> { + type FailureClass = T; + + fn classify_eos(self, _trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> { + // `NeverClassifyEos` contains an `Infallible` so it can never be constructed + unreachable!() + } + + fn classify_error<E>(self, _error: &E) -> Self::FailureClass + where + E: fmt::Display + 'static, + { + // `NeverClassifyEos` contains an `Infallible` so it can never be constructed + unreachable!() + } +} + +impl<T> fmt::Debug for NeverClassifyEos<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NeverClassifyEos").finish() + } +} + +/// The default classifier used for normal HTTP responses. +/// +/// Responses with a `5xx` status code are considered failures, all others are considered +/// successes. +#[derive(Clone, Debug, Default)] +pub struct ServerErrorsAsFailures { + _priv: (), +} + +impl ServerErrorsAsFailures { + /// Create a new [`ServerErrorsAsFailures`]. + pub fn new() -> Self { + Self::default() + } + + /// Returns a [`MakeClassifier`] that produces `ServerErrorsAsFailures`. + /// + /// This is a convenience function that simply calls `SharedClassifier::new`. + pub fn make_classifier() -> SharedClassifier<Self> { + SharedClassifier::new(Self::new()) + } +} + +impl ClassifyResponse for ServerErrorsAsFailures { + type FailureClass = ServerErrorsFailureClass; + type ClassifyEos = NeverClassifyEos<ServerErrorsFailureClass>; + + fn classify_response<B>( + self, + res: &Response<B>, + ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> { + if res.status().is_server_error() { + ClassifiedResponse::Ready(Err(ServerErrorsFailureClass::StatusCode(res.status()))) + } else { + ClassifiedResponse::Ready(Ok(())) + } + } + + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: fmt::Display + 'static, + { + ServerErrorsFailureClass::Error(error.to_string()) + } +} + +/// The failure class for [`ServerErrorsAsFailures`]. +#[derive(Debug)] +pub enum ServerErrorsFailureClass { + /// A response was classified as a failure with the corresponding status. + StatusCode(StatusCode), + /// A response was classified as an error with the corresponding error description. + Error(String), +} + +impl fmt::Display for ServerErrorsFailureClass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::StatusCode(code) => write!(f, "Status code: {}", code), + Self::Error(error) => write!(f, "Error: {}", error), + } + } +} + +// Just verify that we can actually use this response classifier to determine retries as well +#[cfg(test)] +mod usable_for_retries { + #![allow(dead_code)] + + use std::fmt; + + use http::{Request, Response}; + use tower::retry::Policy; + + use super::{ClassifiedResponse, ClassifyResponse}; + + trait IsRetryable { + fn is_retryable(&self) -> bool; + } + + #[derive(Clone)] + struct RetryBasedOnClassification<C> { + classifier: C, + // ... + } + + impl<ReqB, ResB, E, C> Policy<Request<ReqB>, Response<ResB>, E> for RetryBasedOnClassification<C> + where + C: ClassifyResponse + Clone, + E: fmt::Display + 'static, + C::FailureClass: IsRetryable, + ResB: http_body::Body, + Request<ReqB>: Clone, + E: std::error::Error + 'static, + { + type Future = std::future::Ready<()>; + + fn retry( + &mut self, + _req: &mut Request<ReqB>, + res: &mut Result<Response<ResB>, E>, + ) -> Option<Self::Future> { + match res { + Ok(res) => { + if let ClassifiedResponse::Ready(class) = + self.classifier.clone().classify_response(res) + { + if class.err()?.is_retryable() { + return Some(std::future::ready(())); + } + } + + None + } + Err(err) => self + .classifier + .clone() + .classify_error(err) + .is_retryable() + .then(|| std::future::ready(())), + } + } + + fn clone_request(&mut self, req: &Request<ReqB>) -> Option<Request<ReqB>> { + Some(req.clone()) + } + } +} diff --git a/vendor/tower-http/src/classify/status_in_range_is_error.rs b/vendor/tower-http/src/classify/status_in_range_is_error.rs new file mode 100644 index 00000000..8ff830b9 --- /dev/null +++ b/vendor/tower-http/src/classify/status_in_range_is_error.rs @@ -0,0 +1,160 @@ +use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier}; +use http::StatusCode; +use std::{fmt, ops::RangeInclusive}; + +/// Response classifier that considers responses with a status code within some range to be +/// failures. +/// +/// # Example +/// +/// A client with tracing where server errors _and_ client errors are considered failures. +/// +/// ```no_run +/// use tower_http::{trace::TraceLayer, classify::StatusInRangeAsFailures}; +/// use tower::{ServiceBuilder, Service, ServiceExt}; +/// use http::{Request, Method}; +/// use http_body_util::Full; +/// use bytes::Bytes; +/// use hyper_util::{rt::TokioExecutor, client::legacy::Client}; +/// +/// # async fn foo() -> Result<(), tower::BoxError> { +/// let classifier = StatusInRangeAsFailures::new(400..=599); +/// +/// let client = Client::builder(TokioExecutor::new()).build_http(); +/// let mut client = ServiceBuilder::new() +/// .layer(TraceLayer::new(classifier.into_make_classifier())) +/// .service(client); +/// +/// let request = Request::builder() +/// .method(Method::GET) +/// .uri("https://example.com") +/// .body(Full::<Bytes>::default()) +/// .unwrap(); +/// +/// let response = client.ready().await?.call(request).await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct StatusInRangeAsFailures { + range: RangeInclusive<u16>, +} + +impl StatusInRangeAsFailures { + /// Creates a new `StatusInRangeAsFailures`. + /// + /// # Panics + /// + /// Panics if the start or end of `range` aren't valid status codes as determined by + /// [`StatusCode::from_u16`]. + /// + /// [`StatusCode::from_u16`]: https://docs.rs/http/latest/http/status/struct.StatusCode.html#method.from_u16 + pub fn new(range: RangeInclusive<u16>) -> Self { + assert!( + StatusCode::from_u16(*range.start()).is_ok(), + "range start isn't a valid status code" + ); + assert!( + StatusCode::from_u16(*range.end()).is_ok(), + "range end isn't a valid status code" + ); + + Self { range } + } + + /// Creates a new `StatusInRangeAsFailures` that classifies client and server responses as + /// failures. + /// + /// This is a convenience for `StatusInRangeAsFailures::new(400..=599)`. + pub fn new_for_client_and_server_errors() -> Self { + Self::new(400..=599) + } + + /// Convert this `StatusInRangeAsFailures` into a [`MakeClassifier`]. + /// + /// [`MakeClassifier`]: super::MakeClassifier + pub fn into_make_classifier(self) -> SharedClassifier<Self> { + SharedClassifier::new(self) + } +} + +impl ClassifyResponse for StatusInRangeAsFailures { + type FailureClass = StatusInRangeFailureClass; + type ClassifyEos = NeverClassifyEos<Self::FailureClass>; + + fn classify_response<B>( + self, + res: &http::Response<B>, + ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> { + if self.range.contains(&res.status().as_u16()) { + let class = StatusInRangeFailureClass::StatusCode(res.status()); + ClassifiedResponse::Ready(Err(class)) + } else { + ClassifiedResponse::Ready(Ok(())) + } + } + + fn classify_error<E>(self, error: &E) -> Self::FailureClass + where + E: std::fmt::Display + 'static, + { + StatusInRangeFailureClass::Error(error.to_string()) + } +} + +/// The failure class for [`StatusInRangeAsFailures`]. +#[derive(Debug)] +pub enum StatusInRangeFailureClass { + /// A response was classified as a failure with the corresponding status. + StatusCode(StatusCode), + /// A response was classified as an error with the corresponding error description. + Error(String), +} + +impl fmt::Display for StatusInRangeFailureClass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::StatusCode(code) => write!(f, "Status code: {}", code), + Self::Error(error) => write!(f, "Error: {}", error), + } + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use http::Response; + + #[test] + fn basic() { + let classifier = StatusInRangeAsFailures::new(400..=599); + + assert!(matches!( + classifier + .clone() + .classify_response(&response_with_status(200)), + ClassifiedResponse::Ready(Ok(())), + )); + + assert!(matches!( + classifier + .clone() + .classify_response(&response_with_status(400)), + ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode( + StatusCode::BAD_REQUEST + ))), + )); + + assert!(matches!( + classifier.classify_response(&response_with_status(500)), + ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode( + StatusCode::INTERNAL_SERVER_ERROR + ))), + )); + } + + fn response_with_status(status: u16) -> Response<()> { + Response::builder().status(status).body(()).unwrap() + } +} diff --git a/vendor/tower-http/src/compression/body.rs b/vendor/tower-http/src/compression/body.rs new file mode 100644 index 00000000..259e4a27 --- /dev/null +++ b/vendor/tower-http/src/compression/body.rs @@ -0,0 +1,387 @@ +#![allow(unused_imports)] + +use crate::compression::CompressionLevel; +use crate::{ + compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody}, + BoxError, +}; +#[cfg(feature = "compression-br")] +use async_compression::tokio::bufread::BrotliEncoder; +#[cfg(feature = "compression-gzip")] +use async_compression::tokio::bufread::GzipEncoder; +#[cfg(feature = "compression-deflate")] +use async_compression::tokio::bufread::ZlibEncoder; +#[cfg(feature = "compression-zstd")] +use async_compression::tokio::bufread::ZstdEncoder; + +use bytes::{Buf, Bytes}; +use http::HeaderMap; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + io, + marker::PhantomData, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tokio_util::io::StreamReader; + +use super::pin_project_cfg::pin_project_cfg; + +pin_project! { + /// Response body of [`Compression`]. + /// + /// [`Compression`]: super::Compression + pub struct CompressionBody<B> + where + B: Body, + { + #[pin] + pub(crate) inner: BodyInner<B>, + } +} + +impl<B> Default for CompressionBody<B> +where + B: Body + Default, +{ + fn default() -> Self { + Self { + inner: BodyInner::Identity { + inner: B::default(), + }, + } + } +} + +impl<B> CompressionBody<B> +where + B: Body, +{ + pub(crate) fn new(inner: BodyInner<B>) -> Self { + Self { inner } + } + + /// Get a reference to the inner body + pub fn get_ref(&self) -> &B { + match &self.inner { + #[cfg(feature = "compression-gzip")] + BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "compression-deflate")] + BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "compression-br")] + BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "compression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + BodyInner::Identity { inner } => inner, + } + } + + /// Get a mutable reference to the inner body + pub fn get_mut(&mut self) -> &mut B { + match &mut self.inner { + #[cfg(feature = "compression-gzip")] + BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "compression-deflate")] + BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "compression-br")] + BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "compression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + BodyInner::Identity { inner } => inner, + } + } + + /// Get a pinned mutable reference to the inner body + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> { + match self.project().inner.project() { + #[cfg(feature = "compression-gzip")] + BodyInnerProj::Gzip { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + #[cfg(feature = "compression-deflate")] + BodyInnerProj::Deflate { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + #[cfg(feature = "compression-br")] + BodyInnerProj::Brotli { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + #[cfg(feature = "compression-zstd")] + BodyInnerProj::Zstd { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + BodyInnerProj::Identity { inner } => inner, + } + } + + /// Consume `self`, returning the inner body + pub fn into_inner(self) -> B { + match self.inner { + #[cfg(feature = "compression-gzip")] + BodyInner::Gzip { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + #[cfg(feature = "compression-deflate")] + BodyInner::Deflate { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + #[cfg(feature = "compression-br")] + BodyInner::Brotli { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + #[cfg(feature = "compression-zstd")] + BodyInner::Zstd { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + BodyInner::Identity { inner } => inner, + } + } +} + +#[cfg(feature = "compression-gzip")] +type GzipBody<B> = WrapBody<GzipEncoder<B>>; + +#[cfg(feature = "compression-deflate")] +type DeflateBody<B> = WrapBody<ZlibEncoder<B>>; + +#[cfg(feature = "compression-br")] +type BrotliBody<B> = WrapBody<BrotliEncoder<B>>; + +#[cfg(feature = "compression-zstd")] +type ZstdBody<B> = WrapBody<ZstdEncoder<B>>; + +pin_project_cfg! { + #[project = BodyInnerProj] + pub(crate) enum BodyInner<B> + where + B: Body, + { + #[cfg(feature = "compression-gzip")] + Gzip { + #[pin] + inner: GzipBody<B>, + }, + #[cfg(feature = "compression-deflate")] + Deflate { + #[pin] + inner: DeflateBody<B>, + }, + #[cfg(feature = "compression-br")] + Brotli { + #[pin] + inner: BrotliBody<B>, + }, + #[cfg(feature = "compression-zstd")] + Zstd { + #[pin] + inner: ZstdBody<B>, + }, + Identity { + #[pin] + inner: B, + }, + } +} + +impl<B: Body> BodyInner<B> { + #[cfg(feature = "compression-gzip")] + pub(crate) fn gzip(inner: WrapBody<GzipEncoder<B>>) -> Self { + Self::Gzip { inner } + } + + #[cfg(feature = "compression-deflate")] + pub(crate) fn deflate(inner: WrapBody<ZlibEncoder<B>>) -> Self { + Self::Deflate { inner } + } + + #[cfg(feature = "compression-br")] + pub(crate) fn brotli(inner: WrapBody<BrotliEncoder<B>>) -> Self { + Self::Brotli { inner } + } + + #[cfg(feature = "compression-zstd")] + pub(crate) fn zstd(inner: WrapBody<ZstdEncoder<B>>) -> Self { + Self::Zstd { inner } + } + + pub(crate) fn identity(inner: B) -> Self { + Self::Identity { inner } + } +} + +impl<B> Body for CompressionBody<B> +where + B: Body, + B::Error: Into<BoxError>, +{ + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + match self.project().inner.project() { + #[cfg(feature = "compression-gzip")] + BodyInnerProj::Gzip { inner } => inner.poll_frame(cx), + #[cfg(feature = "compression-deflate")] + BodyInnerProj::Deflate { inner } => inner.poll_frame(cx), + #[cfg(feature = "compression-br")] + BodyInnerProj::Brotli { inner } => inner.poll_frame(cx), + #[cfg(feature = "compression-zstd")] + BodyInnerProj::Zstd { inner } => inner.poll_frame(cx), + BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())); + Poll::Ready(Some(Ok(frame))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + None => Poll::Ready(None), + }, + } + } + + fn size_hint(&self) -> http_body::SizeHint { + if let BodyInner::Identity { inner } = &self.inner { + inner.size_hint() + } else { + http_body::SizeHint::new() + } + } + + fn is_end_stream(&self) -> bool { + if let BodyInner::Identity { inner } = &self.inner { + inner.is_end_stream() + } else { + false + } + } +} + +#[cfg(feature = "compression-gzip")] +impl<B> DecorateAsyncRead for GzipEncoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = GzipEncoder<Self::Input>; + + fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { + GzipEncoder::with_quality(input, quality.into_async_compression()) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} + +#[cfg(feature = "compression-deflate")] +impl<B> DecorateAsyncRead for ZlibEncoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = ZlibEncoder<Self::Input>; + + fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { + ZlibEncoder::with_quality(input, quality.into_async_compression()) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} + +#[cfg(feature = "compression-br")] +impl<B> DecorateAsyncRead for BrotliEncoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = BrotliEncoder<Self::Input>; + + fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { + // The brotli crate used under the hood here has a default compression level of 11, + // which is the max for brotli. This causes extremely slow compression times, so we + // manually set a default of 4 here. + // + // This is the same default used by NGINX for on-the-fly brotli compression. + let level = match quality { + CompressionLevel::Default => async_compression::Level::Precise(4), + other => other.into_async_compression(), + }; + BrotliEncoder::with_quality(input, level) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} + +#[cfg(feature = "compression-zstd")] +impl<B> DecorateAsyncRead for ZstdEncoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = ZstdEncoder<Self::Input>; + + fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { + // See https://issues.chromium.org/issues/41493659: + // "For memory usage reasons, Chromium limits the window size to 8MB" + // See https://datatracker.ietf.org/doc/html/rfc8878#name-window-descriptor + // "For improved interoperability, it's recommended for decoders to support values + // of Window_Size up to 8 MB and for encoders not to generate frames requiring a + // Window_Size larger than 8 MB." + // Level 17 in zstd (as of v1.5.6) is the first level with a window size of 8 MB (2^23): + // https://github.com/facebook/zstd/blob/v1.5.6/lib/compress/clevels.h#L25-L51 + // Set the parameter for all levels >= 17. This will either have no effect (but reduce + // the risk of future changes in zstd) or limit the window log to 8MB. + let needs_window_limit = match quality { + CompressionLevel::Best => true, // level 20 + CompressionLevel::Precise(level) => level >= 17, + _ => false, + }; + // The parameter is not set for levels below 17 as it will increase the window size + // for those levels. + if needs_window_limit { + let params = [async_compression::zstd::CParameter::window_log(23)]; + ZstdEncoder::with_quality_and_params(input, quality.into_async_compression(), ¶ms) + } else { + ZstdEncoder::with_quality(input, quality.into_async_compression()) + } + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} diff --git a/vendor/tower-http/src/compression/future.rs b/vendor/tower-http/src/compression/future.rs new file mode 100644 index 00000000..3e899a73 --- /dev/null +++ b/vendor/tower-http/src/compression/future.rs @@ -0,0 +1,133 @@ +#![allow(unused_imports)] + +use super::{body::BodyInner, CompressionBody}; +use crate::compression::predicate::Predicate; +use crate::compression::CompressionLevel; +use crate::compression_utils::WrapBody; +use crate::content_encoding::Encoding; +use http::{header, HeaderMap, HeaderValue, Response}; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, +}; + +pin_project! { + /// Response future of [`Compression`]. + /// + /// [`Compression`]: super::Compression + #[derive(Debug)] + pub struct ResponseFuture<F, P> { + #[pin] + pub(crate) inner: F, + pub(crate) encoding: Encoding, + pub(crate) predicate: P, + pub(crate) quality: CompressionLevel, + } +} + +impl<F, B, E, P> Future for ResponseFuture<F, P> +where + F: Future<Output = Result<Response<B>, E>>, + B: Body, + P: Predicate, +{ + type Output = Result<Response<CompressionBody<B>>, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let res = ready!(self.as_mut().project().inner.poll(cx)?); + + // never recompress responses that are already compressed + let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING) + // never compress responses that are ranges + && !res.headers().contains_key(header::CONTENT_RANGE) + && self.predicate.should_compress(&res); + + let (mut parts, body) = res.into_parts(); + + if should_compress + && !parts.headers.get_all(header::VARY).iter().any(|value| { + contains_ignore_ascii_case( + value.as_bytes(), + header::ACCEPT_ENCODING.as_str().as_bytes(), + ) + }) + { + parts + .headers + .append(header::VARY, header::ACCEPT_ENCODING.into()); + } + + let body = match (should_compress, self.encoding) { + // if compression is _not_ supported or the client doesn't accept it + (false, _) | (_, Encoding::Identity) => { + return Poll::Ready(Ok(Response::from_parts( + parts, + CompressionBody::new(BodyInner::identity(body)), + ))) + } + + #[cfg(feature = "compression-gzip")] + (_, Encoding::Gzip) => { + CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality))) + } + #[cfg(feature = "compression-deflate")] + (_, Encoding::Deflate) => { + CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality))) + } + #[cfg(feature = "compression-br")] + (_, Encoding::Brotli) => { + CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality))) + } + #[cfg(feature = "compression-zstd")] + (_, Encoding::Zstd) => { + CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality))) + } + #[cfg(feature = "fs")] + #[allow(unreachable_patterns)] + (true, _) => { + // This should never happen because the `AcceptEncoding` struct which is used to determine + // `self.encoding` will only enable the different compression algorithms if the + // corresponding crate feature has been enabled. This means + // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the + // features enabled. + // + // The match arm is still required though because the `fs` feature uses the + // Encoding struct independently and requires no compression logic to be enabled. + // This means a combination of an individual compression feature and `fs` will fail + // to compile without this branch even though it will never be reached. + // + // To safeguard against refactors that changes this relationship or other bugs the + // server will return an uncompressed response instead of panicking since that could + // become a ddos attack vector. + return Poll::Ready(Ok(Response::from_parts( + parts, + CompressionBody::new(BodyInner::identity(body)), + ))); + } + }; + + parts.headers.remove(header::ACCEPT_RANGES); + parts.headers.remove(header::CONTENT_LENGTH); + + parts + .headers + .insert(header::CONTENT_ENCODING, self.encoding.into_header_value()); + + let res = Response::from_parts(parts, body); + Poll::Ready(Ok(res)) + } +} + +fn contains_ignore_ascii_case(mut haystack: &[u8], needle: &[u8]) -> bool { + while needle.len() <= haystack.len() { + if haystack[..needle.len()].eq_ignore_ascii_case(needle) { + return true; + } + haystack = &haystack[1..]; + } + + false +} diff --git a/vendor/tower-http/src/compression/layer.rs b/vendor/tower-http/src/compression/layer.rs new file mode 100644 index 00000000..5eca0c50 --- /dev/null +++ b/vendor/tower-http/src/compression/layer.rs @@ -0,0 +1,240 @@ +use super::{Compression, Predicate}; +use crate::compression::predicate::DefaultPredicate; +use crate::compression::CompressionLevel; +use crate::compression_utils::AcceptEncoding; +use tower_layer::Layer; + +/// Compress response bodies of the underlying service. +/// +/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the +/// `Content-Encoding` header to responses. +/// +/// See the [module docs](crate::compression) for more details. +#[derive(Clone, Debug, Default)] +pub struct CompressionLayer<P = DefaultPredicate> { + accept: AcceptEncoding, + predicate: P, + quality: CompressionLevel, +} + +impl<S, P> Layer<S> for CompressionLayer<P> +where + P: Predicate, +{ + type Service = Compression<S, P>; + + fn layer(&self, inner: S) -> Self::Service { + Compression { + inner, + accept: self.accept, + predicate: self.predicate.clone(), + quality: self.quality, + } + } +} + +impl CompressionLayer { + /// Creates a new [`CompressionLayer`]. + pub fn new() -> Self { + Self::default() + } + + /// Sets whether to enable the gzip encoding. + #[cfg(feature = "compression-gzip")] + pub fn gzip(mut self, enable: bool) -> Self { + self.accept.set_gzip(enable); + self + } + + /// Sets whether to enable the Deflate encoding. + #[cfg(feature = "compression-deflate")] + pub fn deflate(mut self, enable: bool) -> Self { + self.accept.set_deflate(enable); + self + } + + /// Sets whether to enable the Brotli encoding. + #[cfg(feature = "compression-br")] + pub fn br(mut self, enable: bool) -> Self { + self.accept.set_br(enable); + self + } + + /// Sets whether to enable the Zstd encoding. + #[cfg(feature = "compression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + + /// Sets the compression quality. + pub fn quality(mut self, quality: CompressionLevel) -> Self { + self.quality = quality; + self + } + + /// Disables the gzip encoding. + /// + /// This method is available even if the `gzip` crate feature is disabled. + pub fn no_gzip(mut self) -> Self { + self.accept.set_gzip(false); + self + } + + /// Disables the Deflate encoding. + /// + /// This method is available even if the `deflate` crate feature is disabled. + pub fn no_deflate(mut self) -> Self { + self.accept.set_deflate(false); + self + } + + /// Disables the Brotli encoding. + /// + /// This method is available even if the `br` crate feature is disabled. + pub fn no_br(mut self) -> Self { + self.accept.set_br(false); + self + } + + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } + + /// Replace the current compression predicate. + /// + /// See [`Compression::compress_when`] for more details. + pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C> + where + C: Predicate, + { + CompressionLayer { + accept: self.accept, + predicate, + quality: self.quality, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::Body; + use http::{header::ACCEPT_ENCODING, Request, Response}; + use http_body_util::BodyExt; + use std::convert::Infallible; + use tokio::fs::File; + use tokio_util::io::ReaderStream; + use tower::{Service, ServiceBuilder, ServiceExt}; + + async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + // Open the file. + let file = File::open("Cargo.toml").await.expect("file missing"); + // Convert the file into a `Stream`. + let stream = ReaderStream::new(file); + // Convert the `Stream` into a `Body`. + let body = Body::from_stream(stream); + // Create response. + Ok(Response::new(body)) + } + + #[tokio::test] + async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> { + let deflate_only_layer = CompressionLayer::new() + .quality(CompressionLevel::Best) + .no_br() + .no_gzip(); + + let mut service = ServiceBuilder::new() + // Compress responses based on the `Accept-Encoding` header. + .layer(deflate_only_layer) + .service_fn(handle); + + // Call the service with the deflate only layer + let request = Request::builder() + .header(ACCEPT_ENCODING, "gzip, deflate, br") + .body(Body::empty())?; + + let response = service.ready().await?.call(request).await?; + + assert_eq!(response.headers()["content-encoding"], "deflate"); + + // Read the body + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + + let deflate_bytes_len = bytes.len(); + + let br_only_layer = CompressionLayer::new() + .quality(CompressionLevel::Best) + .no_gzip() + .no_deflate(); + + let mut service = ServiceBuilder::new() + // Compress responses based on the `Accept-Encoding` header. + .layer(br_only_layer) + .service_fn(handle); + + // Call the service with the br only layer + let request = Request::builder() + .header(ACCEPT_ENCODING, "gzip, deflate, br") + .body(Body::empty())?; + + let response = service.ready().await?.call(request).await?; + + assert_eq!(response.headers()["content-encoding"], "br"); + + // Read the body + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + + let br_byte_length = bytes.len(); + + // check the corresponding algorithms are actually used + // br should compresses better than deflate + assert!(br_byte_length < deflate_bytes_len * 9 / 10); + + Ok(()) + } + + /// Test ensuring that zstd compression will not exceed an 8MiB window size; browsers do not + /// accept responses using 16MiB+ window sizes. + #[tokio::test] + async fn zstd_is_web_safe() -> Result<(), crate::BoxError> { + async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + Ok(Response::new(Body::from(vec![0u8; 18_874_368]))) + } + // zstd will (I believe) lower its window size if a larger one isn't beneficial and + // it knows the size of the input; use an 18MiB body to ensure it would want a + // >=16MiB window (though it might not be able to see the input size here). + + let zstd_layer = CompressionLayer::new() + .quality(CompressionLevel::Best) + .no_br() + .no_deflate() + .no_gzip(); + + let mut service = ServiceBuilder::new().layer(zstd_layer).service_fn(zeroes); + + let request = Request::builder() + .header(ACCEPT_ENCODING, "zstd") + .body(Body::empty())?; + + let response = service.ready().await?.call(request).await?; + + assert_eq!(response.headers()["content-encoding"], "zstd"); + + let body = response.into_body(); + let bytes = body.collect().await?.to_bytes(); + let mut dec = zstd::Decoder::new(&*bytes)?; + dec.window_log_max(23)?; // Limit window size accepted by decoder to 2 ^ 23 bytes (8MiB) + + std::io::copy(&mut dec, &mut std::io::sink())?; + + Ok(()) + } +} diff --git a/vendor/tower-http/src/compression/mod.rs b/vendor/tower-http/src/compression/mod.rs new file mode 100644 index 00000000..5772b9ba --- /dev/null +++ b/vendor/tower-http/src/compression/mod.rs @@ -0,0 +1,511 @@ +//! Middleware that compresses response bodies. +//! +//! # Example +//! +//! Example showing how to respond with the compressed contents of a file. +//! +//! ```rust +//! use bytes::{Bytes, BytesMut}; +//! use http::{Request, Response, header::ACCEPT_ENCODING}; +//! use http_body_util::{Full, BodyExt, StreamBody, combinators::UnsyncBoxBody}; +//! use http_body::Frame; +//! use std::convert::Infallible; +//! use tokio::fs::{self, File}; +//! use tokio_util::io::ReaderStream; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::{compression::CompressionLayer, BoxError}; +//! use futures_util::TryStreamExt; +//! +//! type BoxBody = UnsyncBoxBody<Bytes, std::io::Error>; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<BoxBody>, Infallible> { +//! // Open the file. +//! let file = File::open("Cargo.toml").await.expect("file missing"); +//! // Convert the file into a `Stream` of `Bytes`. +//! let stream = ReaderStream::new(file); +//! // Convert the stream into a stream of data `Frame`s. +//! let stream = stream.map_ok(Frame::data); +//! // Convert the `Stream` into a `Body`. +//! let body = StreamBody::new(stream); +//! // Erase the type because its very hard to name in the function signature. +//! let body = body.boxed_unsync(); +//! // Create response. +//! Ok(Response::new(body)) +//! } +//! +//! let mut service = ServiceBuilder::new() +//! // Compress responses based on the `Accept-Encoding` header. +//! .layer(CompressionLayer::new()) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::builder() +//! .header(ACCEPT_ENCODING, "gzip") +//! .body(Full::<Bytes>::default())?; +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(response.headers()["content-encoding"], "gzip"); +//! +//! // Read the body +//! let bytes = response +//! .into_body() +//! .collect() +//! .await? +//! .to_bytes(); +//! +//! // The compressed body should be smaller 🤞 +//! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len(); +//! assert!(bytes.len() < uncompressed_len); +//! # +//! # Ok(()) +//! # } +//! ``` +//! + +pub mod predicate; + +mod body; +mod future; +mod layer; +mod pin_project_cfg; +mod service; + +#[doc(inline)] +pub use self::{ + body::CompressionBody, + future::ResponseFuture, + layer::CompressionLayer, + predicate::{DefaultPredicate, Predicate}, + service::Compression, +}; +pub use crate::compression_utils::CompressionLevel; + +#[cfg(test)] +mod tests { + use crate::compression::predicate::SizeAbove; + + use super::*; + use crate::test_helpers::{Body, WithTrailers}; + use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; + use flate2::read::GzDecoder; + use http::header::{ + ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE, + }; + use http::{HeaderMap, HeaderName, HeaderValue, Request, Response}; + use http_body::Body as _; + use http_body_util::BodyExt; + use std::convert::Infallible; + use std::io::Read; + use std::sync::{Arc, RwLock}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_util::io::StreamReader; + use tower::{service_fn, Service, ServiceExt}; + + // Compression filter allows every other request to be compressed + #[derive(Clone)] + struct Always; + + impl Predicate for Always { + fn should_compress<B>(&self, _: &http::Response<B>) -> bool + where + B: http_body::Body, + { + true + } + } + + #[tokio::test] + async fn gzip_works() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + // call the service + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let compressed_data = collected.to_bytes(); + + // decompress the body + // doing this with flate2 as that is much easier than async-compression and blocking during + // tests is fine + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + assert_eq!(decompressed, "Hello, World!"); + + // trailers are maintained + assert_eq!(trailers["foo"], "bar"); + } + + #[tokio::test] + async fn x_gzip_works() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + // call the service + let req = Request::builder() + .header("accept-encoding", "x-gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // we treat x-gzip as equivalent to gzip and don't have to return x-gzip + // taking extra caution by checking all headers with this name + assert_eq!( + res.headers() + .get_all("content-encoding") + .iter() + .collect::<Vec<&HeaderValue>>(), + vec!(HeaderValue::from_static("gzip")) + ); + + // read the compressed body + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let compressed_data = collected.to_bytes(); + + // decompress the body + // doing this with flate2 as that is much easier than async-compression and blocking during + // tests is fine + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + assert_eq!(decompressed, "Hello, World!"); + + // trailers are maintained + assert_eq!(trailers["foo"], "bar"); + } + + #[tokio::test] + async fn zstd_works() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + // call the service + let req = Request::builder() + .header("accept-encoding", "zstd") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let body = res.into_body(); + let compressed_data = body.collect().await.unwrap().to_bytes(); + + // decompress the body + let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap(); + let decompressed = String::from_utf8(decompressed).unwrap(); + + assert_eq!(decompressed, "Hello, World!"); + } + + #[tokio::test] + async fn no_recompress() { + const DATA: &str = "Hello, World! I'm already compressed with br!"; + + let svc = service_fn(|_| async { + let buf = { + let mut buf = Vec::new(); + + let mut enc = BrotliEncoder::new(&mut buf); + enc.write_all(DATA.as_bytes()).await?; + enc.flush().await?; + buf + }; + + let resp = Response::builder() + .header("content-encoding", "br") + .body(Body::from(buf)) + .unwrap(); + Ok::<_, std::io::Error>(resp) + }); + let mut svc = Compression::new(svc); + + // call the service + // + // note: the accept-encoding doesn't match the content-encoding above, so that + // we're able to see if the compression layer triggered or not + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // check we didn't recompress + assert_eq!( + res.headers() + .get("content-encoding") + .and_then(|h| h.to_str().ok()) + .unwrap_or_default(), + "br", + ); + + // read the compressed body + let body = res.into_body(); + let data = body.collect().await.unwrap().to_bytes(); + + // decompress the body + let data = { + let mut output_buf = Vec::new(); + let mut decoder = BrotliDecoder::new(&mut output_buf); + decoder + .write_all(&data) + .await + .expect("couldn't brotli-decode"); + decoder.flush().await.expect("couldn't flush"); + output_buf + }; + + assert_eq!(data, DATA.as_bytes()); + } + + async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> { + let mut trailers = HeaderMap::new(); + trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap()); + let body = Body::from("Hello, World!").with_trailers(trailers); + Ok(Response::builder().body(body).unwrap()) + } + + #[tokio::test] + async fn will_not_compress_if_filtered_out() { + use predicate::Predicate; + + const DATA: &str = "Hello world uncompressed"; + + let svc_fn = service_fn(|_| async { + let resp = Response::builder() + // .header("content-encoding", "br") + .body(Body::from(DATA.as_bytes())) + .unwrap(); + Ok::<_, std::io::Error>(resp) + }); + + // Compression filter allows every other request to be compressed + #[derive(Default, Clone)] + struct EveryOtherResponse(Arc<RwLock<u64>>); + + #[allow(clippy::dbg_macro)] + impl Predicate for EveryOtherResponse { + fn should_compress<B>(&self, _: &http::Response<B>) -> bool + where + B: http_body::Body, + { + let mut guard = self.0.write().unwrap(); + let should_compress = *guard % 2 != 0; + *guard += 1; + dbg!(should_compress) + } + } + + let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default()); + let req = Request::builder() + .header("accept-encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the uncompressed body + let body = res.into_body(); + let data = body.collect().await.unwrap().to_bytes(); + let still_uncompressed = String::from_utf8(data.to_vec()).unwrap(); + assert_eq!(DATA, &still_uncompressed); + + // Compression filter will compress the next body + let req = Request::builder() + .header("accept-encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let body = res.into_body(); + let data = body.collect().await.unwrap().to_bytes(); + assert!(String::from_utf8(data.to_vec()).is_err()); + } + + #[tokio::test] + async fn doesnt_compress_images() { + async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + let mut res = Response::new(Body::from( + "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), + )); + res.headers_mut() + .insert(CONTENT_TYPE, "image/png".parse().unwrap()); + Ok(res) + } + + let svc = Compression::new(service_fn(handle)); + + let res = svc + .oneshot( + Request::builder() + .header(ACCEPT_ENCODING, "gzip") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert!(res.headers().get(CONTENT_ENCODING).is_none()); + } + + #[tokio::test] + async fn does_compress_svg() { + async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + let mut res = Response::new(Body::from( + "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), + )); + res.headers_mut() + .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap()); + Ok(res) + } + + let svc = Compression::new(service_fn(handle)); + + let res = svc + .oneshot( + Request::builder() + .header(ACCEPT_ENCODING, "gzip") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.headers()[CONTENT_ENCODING], "gzip"); + } + + #[tokio::test] + async fn compress_with_quality() { + const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!"; + let level = CompressionLevel::Best; + + let svc = service_fn(|_| async { + let resp = Response::builder() + .body(Body::from(DATA.as_bytes())) + .unwrap(); + Ok::<_, std::io::Error>(resp) + }); + + let mut svc = Compression::new(svc).quality(level); + + // call the service + let req = Request::builder() + .header("accept-encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let body = res.into_body(); + let compressed_data = body.collect().await.unwrap().to_bytes(); + + // build the compressed body with the same quality level + let compressed_with_level = { + use async_compression::tokio::bufread::BrotliEncoder; + + let stream = Box::pin(futures_util::stream::once(async move { + Ok::<_, std::io::Error>(DATA.as_bytes()) + })); + let reader = StreamReader::new(stream); + let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression()); + + let mut buf = Vec::new(); + enc.read_to_end(&mut buf).await.unwrap(); + buf + }; + + assert_eq!( + compressed_data, + compressed_with_level.as_slice(), + "Compression level is not respected" + ); + } + + #[tokio::test] + async fn should_not_compress_ranges() { + let svc = service_fn(|_| async { + let mut res = Response::new(Body::from("Hello")); + let headers = res.headers_mut(); + headers.insert(ACCEPT_RANGES, "bytes".parse().unwrap()); + headers.insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap()); + Ok::<_, std::io::Error>(res) + }); + let mut svc = Compression::new(svc).compress_when(Always); + + // call the service + let req = Request::builder() + .header(ACCEPT_ENCODING, "gzip") + .header(RANGE, "bytes=0-4") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + let headers = res.headers().clone(); + + // read the uncompressed body + let collected = res.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(headers[ACCEPT_RANGES], "bytes"); + assert!(!headers.contains_key(CONTENT_ENCODING)); + assert_eq!(collected, "Hello"); + } + + #[tokio::test] + async fn should_strip_accept_ranges_header_when_compressing() { + let svc = service_fn(|_| async { + let mut res = Response::new(Body::from("Hello, World!")); + res.headers_mut() + .insert(ACCEPT_RANGES, "bytes".parse().unwrap()); + Ok::<_, std::io::Error>(res) + }); + let mut svc = Compression::new(svc).compress_when(Always); + + // call the service + let req = Request::builder() + .header(ACCEPT_ENCODING, "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + let headers = res.headers().clone(); + + // read the compressed body + let collected = res.into_body().collect().await.unwrap(); + let compressed_data = collected.to_bytes(); + + // decompress the body + // doing this with flate2 as that is much easier than async-compression and blocking during + // tests is fine + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + assert!(!headers.contains_key(ACCEPT_RANGES)); + assert_eq!(headers[CONTENT_ENCODING], "gzip"); + assert_eq!(decompressed, "Hello, World!"); + } + + #[tokio::test] + async fn size_hint_identity() { + let msg = "Hello, world!"; + let svc = service_fn(|_| async { Ok::<_, std::io::Error>(Response::new(Body::from(msg))) }); + let mut svc = Compression::new(svc); + + let req = Request::new(Body::empty()); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + let body = res.into_body(); + assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64); + } +} diff --git a/vendor/tower-http/src/compression/pin_project_cfg.rs b/vendor/tower-http/src/compression/pin_project_cfg.rs new file mode 100644 index 00000000..655b8d94 --- /dev/null +++ b/vendor/tower-http/src/compression/pin_project_cfg.rs @@ -0,0 +1,144 @@ +// Full credit to @tesaguri who posted this gist under CC0 1.0 Universal licence +// https://gist.github.com/tesaguri/2a1c0790a48bbda3dd7f71c26d02a793 + +macro_rules! pin_project_cfg { + ($(#[$($attr:tt)*])* $vis:vis enum $($rest:tt)+) => { + pin_project_cfg! { + @outer [$(#[$($attr)*])* $vis enum] $($rest)+ + } + }; + // Accumulate type parameters and `where` clause. + (@outer [$($accum:tt)*] $tt:tt $($rest:tt)+) => { + pin_project_cfg! { + @outer [$($accum)* $tt] $($rest)+ + } + }; + (@outer [$($accum:tt)*] { $($body:tt)* }) => { + pin_project_cfg! { + @body #[cfg(all())] [$($accum)*] {} $($body)* + } + }; + // Process a variant with `cfg`. + ( + @body + #[cfg(all($($pred_accum:tt)*))] + $outer:tt + { $($accum:tt)* } + + #[cfg($($pred:tt)*)] + $(#[$($attr:tt)*])* $variant:ident { $($body:tt)* }, + $($rest:tt)* + ) => { + // Create two versions of the enum with `cfg($pred)` and `cfg(not($pred))`. + pin_project_cfg! { + @variant_body + { $($body)* } + {} + #[cfg(all($($pred_accum)* $($pred)*,))] + $outer + { $($accum)* $(#[$($attr)*])* $variant } + $($rest)* + } + pin_project_cfg! { + @body + #[cfg(all($($pred_accum)* not($($pred)*),))] + $outer + { $($accum)* } + $($rest)* + } + }; + // Process a variant without `cfg`. + ( + @body + #[cfg(all($($pred_accum:tt)*))] + $outer:tt + { $($accum:tt)* } + + $(#[$($attr:tt)*])* $variant:ident { $($body:tt)* }, + $($rest:tt)* + ) => { + pin_project_cfg! { + @variant_body + { $($body)* } + {} + #[cfg(all($($pred_accum)*))] + $outer + { $($accum)* $(#[$($attr)*])* $variant } + $($rest)* + } + }; + // Process a variant field with `cfg`. + ( + @variant_body + { + #[cfg($($pred:tt)*)] + $(#[$($attr:tt)*])* $field:ident: $ty:ty, + $($rest:tt)* + } + { $($accum:tt)* } + #[cfg(all($($pred_accum:tt)*))] + $($outer:tt)* + ) => { + pin_project_cfg! { + @variant_body + {$($rest)*} + { $($accum)* $(#[$($attr)*])* $field: $ty, } + #[cfg(all($($pred_accum)* $($pred)*,))] + $($outer)* + } + pin_project_cfg! { + @variant_body + { $($rest)* } + { $($accum)* } + #[cfg(all($($pred_accum)* not($($pred)*),))] + $($outer)* + } + }; + // Process a variant field without `cfg`. + ( + @variant_body + { + $(#[$($attr:tt)*])* $field:ident: $ty:ty, + $($rest:tt)* + } + { $($accum:tt)* } + $($outer:tt)* + ) => { + pin_project_cfg! { + @variant_body + {$($rest)*} + { $($accum)* $(#[$($attr)*])* $field: $ty, } + $($outer)* + } + }; + ( + @variant_body + {} + $body:tt + #[cfg(all($($pred_accum:tt)*))] + $outer:tt + { $($accum:tt)* } + $($rest:tt)* + ) => { + pin_project_cfg! { + @body + #[cfg(all($($pred_accum)*))] + $outer + { $($accum)* $body, } + $($rest)* + } + }; + ( + @body + #[$cfg:meta] + [$($outer:tt)*] + $body:tt + ) => { + #[$cfg] + pin_project_lite::pin_project! { + $($outer)* $body + } + }; +} + +pub(crate) use pin_project_cfg; diff --git a/vendor/tower-http/src/compression/predicate.rs b/vendor/tower-http/src/compression/predicate.rs new file mode 100644 index 00000000..88c3101c --- /dev/null +++ b/vendor/tower-http/src/compression/predicate.rs @@ -0,0 +1,272 @@ +//! Predicates for disabling compression of responses. +//! +//! Predicates are applied with [`Compression::compress_when`] or +//! [`CompressionLayer::compress_when`]. +//! +//! [`Compression::compress_when`]: super::Compression::compress_when +//! [`CompressionLayer::compress_when`]: super::CompressionLayer::compress_when + +use http::{header, Extensions, HeaderMap, StatusCode, Version}; +use http_body::Body; +use std::{fmt, sync::Arc}; + +/// Predicate used to determine if a response should be compressed or not. +pub trait Predicate: Clone { + /// Should this response be compressed or not? + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body; + + /// Combine two predicates into one. + /// + /// The resulting predicate enables compression if both inner predicates do. + fn and<Other>(self, other: Other) -> And<Self, Other> + where + Self: Sized, + Other: Predicate, + { + And { + lhs: self, + rhs: other, + } + } +} + +impl<F> Predicate for F +where + F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone, +{ + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body, + { + let status = response.status(); + let version = response.version(); + let headers = response.headers(); + let extensions = response.extensions(); + self(status, version, headers, extensions) + } +} + +impl<T> Predicate for Option<T> +where + T: Predicate, +{ + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body, + { + self.as_ref() + .map(|inner| inner.should_compress(response)) + .unwrap_or(true) + } +} + +/// Two predicates combined into one. +/// +/// Created with [`Predicate::and`] +#[derive(Debug, Clone, Default, Copy)] +pub struct And<Lhs, Rhs> { + lhs: Lhs, + rhs: Rhs, +} + +impl<Lhs, Rhs> Predicate for And<Lhs, Rhs> +where + Lhs: Predicate, + Rhs: Predicate, +{ + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body, + { + self.lhs.should_compress(response) && self.rhs.should_compress(response) + } +} + +/// The default predicate used by [`Compression`] and [`CompressionLayer`]. +/// +/// This will compress responses unless: +/// +/// - They're gRPC, which has its own protocol specific compression scheme. +/// - It's an image as determined by the `content-type` starting with `image/`. +/// - They're Server-Sent Events (SSE) as determined by the `content-type` being `text/event-stream`. +/// - The response is less than 32 bytes. +/// +/// # Configuring the defaults +/// +/// `DefaultPredicate` doesn't support any configuration. Instead you can build your own predicate +/// by combining types in this module: +/// +/// ```rust +/// use tower_http::compression::predicate::{SizeAbove, NotForContentType, Predicate}; +/// +/// // slightly large min size than the default 32 +/// let predicate = SizeAbove::new(256) +/// // still don't compress gRPC +/// .and(NotForContentType::GRPC) +/// // still don't compress images +/// .and(NotForContentType::IMAGES) +/// // also don't compress JSON +/// .and(NotForContentType::const_new("application/json")); +/// ``` +/// +/// [`Compression`]: super::Compression +/// [`CompressionLayer`]: super::CompressionLayer +#[derive(Clone)] +pub struct DefaultPredicate( + And<And<And<SizeAbove, NotForContentType>, NotForContentType>, NotForContentType>, +); + +impl DefaultPredicate { + /// Create a new `DefaultPredicate`. + pub fn new() -> Self { + let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE) + .and(NotForContentType::GRPC) + .and(NotForContentType::IMAGES) + .and(NotForContentType::SSE); + Self(inner) + } +} + +impl Default for DefaultPredicate { + fn default() -> Self { + Self::new() + } +} + +impl Predicate for DefaultPredicate { + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body, + { + self.0.should_compress(response) + } +} + +/// [`Predicate`] that will only allow compression of responses above a certain size. +#[derive(Clone, Copy, Debug)] +pub struct SizeAbove(u16); + +impl SizeAbove { + pub(crate) const DEFAULT_MIN_SIZE: u16 = 32; + + /// Create a new `SizeAbove` predicate that will only compress responses larger than + /// `min_size_bytes`. + /// + /// The response will be compressed if the exact size cannot be determined through either the + /// `content-length` header or [`Body::size_hint`]. + pub const fn new(min_size_bytes: u16) -> Self { + Self(min_size_bytes) + } +} + +impl Default for SizeAbove { + fn default() -> Self { + Self(Self::DEFAULT_MIN_SIZE) + } +} + +impl Predicate for SizeAbove { + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body, + { + let content_size = response.body().size_hint().exact().or_else(|| { + response + .headers() + .get(header::CONTENT_LENGTH) + .and_then(|h| h.to_str().ok()) + .and_then(|val| val.parse().ok()) + }); + + match content_size { + Some(size) => size >= (self.0 as u64), + _ => true, + } + } +} + +/// Predicate that wont allow responses with a specific `content-type` to be compressed. +#[derive(Clone, Debug)] +pub struct NotForContentType { + content_type: Str, + exception: Option<Str>, +} + +impl NotForContentType { + /// Predicate that wont compress gRPC responses. + pub const GRPC: Self = Self::const_new("application/grpc"); + + /// Predicate that wont compress images. + pub const IMAGES: Self = Self { + content_type: Str::Static("image/"), + exception: Some(Str::Static("image/svg+xml")), + }; + + /// Predicate that wont compress Server-Sent Events (SSE) responses. + pub const SSE: Self = Self::const_new("text/event-stream"); + + /// Create a new `NotForContentType`. + pub fn new(content_type: &str) -> Self { + Self { + content_type: Str::Shared(content_type.into()), + exception: None, + } + } + + /// Create a new `NotForContentType` from a static string. + pub const fn const_new(content_type: &'static str) -> Self { + Self { + content_type: Str::Static(content_type), + exception: None, + } + } +} + +impl Predicate for NotForContentType { + fn should_compress<B>(&self, response: &http::Response<B>) -> bool + where + B: Body, + { + if let Some(except) = &self.exception { + if content_type(response) == except.as_str() { + return true; + } + } + + !content_type(response).starts_with(self.content_type.as_str()) + } +} + +#[derive(Clone)] +enum Str { + Static(&'static str), + Shared(Arc<str>), +} + +impl Str { + fn as_str(&self) -> &str { + match self { + Str::Static(s) => s, + Str::Shared(s) => s, + } + } +} + +impl fmt::Debug for Str { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Static(inner) => inner.fmt(f), + Self::Shared(inner) => inner.fmt(f), + } + } +} + +fn content_type<B>(response: &http::Response<B>) -> &str { + response + .headers() + .get(header::CONTENT_TYPE) + .and_then(|h| h.to_str().ok()) + .unwrap_or_default() +} diff --git a/vendor/tower-http/src/compression/service.rs b/vendor/tower-http/src/compression/service.rs new file mode 100644 index 00000000..22dcf73a --- /dev/null +++ b/vendor/tower-http/src/compression/service.rs @@ -0,0 +1,185 @@ +use super::{CompressionBody, CompressionLayer, ResponseFuture}; +use crate::compression::predicate::{DefaultPredicate, Predicate}; +use crate::compression::CompressionLevel; +use crate::{compression_utils::AcceptEncoding, content_encoding::Encoding}; +use http::{Request, Response}; +use http_body::Body; +use std::task::{Context, Poll}; +use tower_service::Service; + +/// Compress response bodies of the underlying service. +/// +/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the +/// `Content-Encoding` header to responses. +/// +/// See the [module docs](crate::compression) for more details. +#[derive(Clone, Copy)] +pub struct Compression<S, P = DefaultPredicate> { + pub(crate) inner: S, + pub(crate) accept: AcceptEncoding, + pub(crate) predicate: P, + pub(crate) quality: CompressionLevel, +} + +impl<S> Compression<S, DefaultPredicate> { + /// Creates a new `Compression` wrapping the `service`. + pub fn new(service: S) -> Compression<S, DefaultPredicate> { + Self { + inner: service, + accept: AcceptEncoding::default(), + predicate: DefaultPredicate::default(), + quality: CompressionLevel::default(), + } + } +} + +impl<S, P> Compression<S, P> { + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `Compression` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> CompressionLayer { + CompressionLayer::new() + } + + /// Sets whether to enable the gzip encoding. + #[cfg(feature = "compression-gzip")] + pub fn gzip(mut self, enable: bool) -> Self { + self.accept.set_gzip(enable); + self + } + + /// Sets whether to enable the Deflate encoding. + #[cfg(feature = "compression-deflate")] + pub fn deflate(mut self, enable: bool) -> Self { + self.accept.set_deflate(enable); + self + } + + /// Sets whether to enable the Brotli encoding. + #[cfg(feature = "compression-br")] + pub fn br(mut self, enable: bool) -> Self { + self.accept.set_br(enable); + self + } + + /// Sets whether to enable the Zstd encoding. + #[cfg(feature = "compression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + + /// Sets the compression quality. + pub fn quality(mut self, quality: CompressionLevel) -> Self { + self.quality = quality; + self + } + + /// Disables the gzip encoding. + /// + /// This method is available even if the `gzip` crate feature is disabled. + pub fn no_gzip(mut self) -> Self { + self.accept.set_gzip(false); + self + } + + /// Disables the Deflate encoding. + /// + /// This method is available even if the `deflate` crate feature is disabled. + pub fn no_deflate(mut self) -> Self { + self.accept.set_deflate(false); + self + } + + /// Disables the Brotli encoding. + /// + /// This method is available even if the `br` crate feature is disabled. + pub fn no_br(mut self) -> Self { + self.accept.set_br(false); + self + } + + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } + + /// Replace the current compression predicate. + /// + /// Predicates are used to determine whether a response should be compressed or not. + /// + /// The default predicate is [`DefaultPredicate`]. See its documentation for more + /// details on which responses it wont compress. + /// + /// # Changing the compression predicate + /// + /// ``` + /// use tower_http::compression::{ + /// Compression, + /// predicate::{Predicate, NotForContentType, DefaultPredicate}, + /// }; + /// use tower::util::service_fn; + /// + /// // Placeholder service_fn + /// let service = service_fn(|_: ()| async { + /// Ok::<_, std::io::Error>(http::Response::new(())) + /// }); + /// + /// // build our custom compression predicate + /// // its recommended to still include `DefaultPredicate` as part of + /// // custom predicates + /// let predicate = DefaultPredicate::new() + /// // don't compress responses who's `content-type` starts with `application/json` + /// .and(NotForContentType::new("application/json")); + /// + /// let service = Compression::new(service).compress_when(predicate); + /// ``` + /// + /// See [`predicate`](super::predicate) for more utilities for building compression predicates. + /// + /// Responses that are already compressed (ie have a `content-encoding` header) will _never_ be + /// recompressed, regardless what they predicate says. + pub fn compress_when<C>(self, predicate: C) -> Compression<S, C> + where + C: Predicate, + { + Compression { + inner: self.inner, + accept: self.accept, + predicate, + quality: self.quality, + } + } +} + +impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for Compression<S, P> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + ResBody: Body, + P: Predicate, +{ + type Response = Response<CompressionBody<ResBody>>; + type Error = S::Error; + type Future = ResponseFuture<S::Future, P>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let encoding = Encoding::from_headers(req.headers(), self.accept); + + ResponseFuture { + inner: self.inner.call(req), + encoding, + predicate: self.predicate.clone(), + quality: self.quality, + } + } +} diff --git a/vendor/tower-http/src/compression_utils.rs b/vendor/tower-http/src/compression_utils.rs new file mode 100644 index 00000000..58aa6a08 --- /dev/null +++ b/vendor/tower-http/src/compression_utils.rs @@ -0,0 +1,457 @@ +//! Types used by compression and decompression middleware. + +use crate::{content_encoding::SupportedEncodings, BoxError}; +use bytes::{Buf, Bytes, BytesMut}; +use futures_core::Stream; +use http::HeaderValue; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; +use std::{ + io, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tokio::io::AsyncRead; +use tokio_util::io::StreamReader; + +#[derive(Debug, Clone, Copy)] +pub(crate) struct AcceptEncoding { + pub(crate) gzip: bool, + pub(crate) deflate: bool, + pub(crate) br: bool, + pub(crate) zstd: bool, +} + +impl AcceptEncoding { + #[allow(dead_code)] + pub(crate) fn to_header_value(self) -> Option<HeaderValue> { + let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { + (true, true, true, false) => "gzip,deflate,br", + (true, true, false, false) => "gzip,deflate", + (true, false, true, false) => "gzip,br", + (true, false, false, false) => "gzip", + (false, true, true, false) => "deflate,br", + (false, true, false, false) => "deflate", + (false, false, true, false) => "br", + (true, true, true, true) => "zstd,gzip,deflate,br", + (true, true, false, true) => "zstd,gzip,deflate", + (true, false, true, true) => "zstd,gzip,br", + (true, false, false, true) => "zstd,gzip", + (false, true, true, true) => "zstd,deflate,br", + (false, true, false, true) => "zstd,deflate", + (false, false, true, true) => "zstd,br", + (false, false, false, true) => "zstd", + (false, false, false, false) => return None, + }; + Some(HeaderValue::from_static(accept)) + } + + #[allow(dead_code)] + pub(crate) fn set_gzip(&mut self, enable: bool) { + self.gzip = enable; + } + + #[allow(dead_code)] + pub(crate) fn set_deflate(&mut self, enable: bool) { + self.deflate = enable; + } + + #[allow(dead_code)] + pub(crate) fn set_br(&mut self, enable: bool) { + self.br = enable; + } + + #[allow(dead_code)] + pub(crate) fn set_zstd(&mut self, enable: bool) { + self.zstd = enable; + } +} + +impl SupportedEncodings for AcceptEncoding { + #[allow(dead_code)] + fn gzip(&self) -> bool { + #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))] + return self.gzip; + + #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))] + return false; + } + + #[allow(dead_code)] + fn deflate(&self) -> bool { + #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))] + return self.deflate; + + #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))] + return false; + } + + #[allow(dead_code)] + fn br(&self) -> bool { + #[cfg(any(feature = "decompression-br", feature = "compression-br"))] + return self.br; + + #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))] + return false; + } + + #[allow(dead_code)] + fn zstd(&self) -> bool { + #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))] + return self.zstd; + + #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))] + return false; + } +} + +impl Default for AcceptEncoding { + fn default() -> Self { + AcceptEncoding { + gzip: true, + deflate: true, + br: true, + zstd: true, + } + } +} + +/// A `Body` that has been converted into an `AsyncRead`. +pub(crate) type AsyncReadBody<B> = + StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>; + +/// Trait for applying some decorator to an `AsyncRead` +pub(crate) trait DecorateAsyncRead { + type Input: AsyncRead; + type Output: AsyncRead; + + /// Apply the decorator + fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output; + + /// Get a pinned mutable reference to the original input. + /// + /// This is necessary to implement `Body::poll_trailers`. + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>; +} + +pin_project! { + /// `Body` that has been decorated by an `AsyncRead` + pub(crate) struct WrapBody<M: DecorateAsyncRead> { + #[pin] + // rust-analyer thinks this field is private if its `pub(crate)` but works fine when its + // `pub` + pub read: M::Output, + // A buffer to temporarily store the data read from the underlying body. + // Reused as much as possible to optimize allocations. + buf: BytesMut, + read_all_data: bool, + } +} + +impl<M: DecorateAsyncRead> WrapBody<M> { + const INTERNAL_BUF_CAPACITY: usize = 4096; +} + +impl<M: DecorateAsyncRead> WrapBody<M> { + #[allow(dead_code)] + pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self + where + B: Body, + M: DecorateAsyncRead<Input = AsyncReadBody<B>>, + { + // convert `Body` into a `Stream` + let stream = BodyIntoStream::new(body); + + // an adapter that converts the error type into `io::Error` while storing the actual error + // `StreamReader` requires the error type is `io::Error` + let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream); + + // convert `Stream` into an `AsyncRead` + let read = StreamReader::new(stream); + + // apply decorator to `AsyncRead` yielding another `AsyncRead` + let read = M::apply(read, quality); + + Self { + read, + buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY), + read_all_data: false, + } + } +} + +impl<B, M> Body for WrapBody<M> +where + B: Body, + B::Error: Into<BoxError>, + M: DecorateAsyncRead<Input = AsyncReadBody<B>>, +{ + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + let mut this = self.project(); + + if !*this.read_all_data { + if this.buf.capacity() == 0 { + this.buf.reserve(Self::INTERNAL_BUF_CAPACITY); + } + + let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf); + + match ready!(result) { + Ok(0) => { + *this.read_all_data = true; + } + Ok(_) => { + let chunk = this.buf.split().freeze(); + return Poll::Ready(Some(Ok(Frame::data(chunk)))); + } + Err(err) => { + let body_error: Option<B::Error> = M::get_pin_mut(this.read) + .get_pin_mut() + .project() + .error + .take(); + + if let Some(body_error) = body_error { + return Poll::Ready(Some(Err(body_error.into()))); + } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) { + // SENTINEL_ERROR_CODE only gets used when storing + // an underlying body error + unreachable!() + } else { + return Poll::Ready(Some(Err(err.into()))); + } + } + } + } + + // poll any remaining frames, such as trailers + let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut(); + body.poll_frame(cx).map(|option| { + option.map(|result| { + result + .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining()))) + .map_err(|err| err.into()) + }) + }) + } +} + +pin_project! { + pub(crate) struct BodyIntoStream<B> + where + B: Body, + { + #[pin] + body: B, + yielded_all_data: bool, + non_data_frame: Option<Frame<B::Data>>, + } +} + +#[allow(dead_code)] +impl<B> BodyIntoStream<B> +where + B: Body, +{ + pub(crate) fn new(body: B) -> Self { + Self { + body, + yielded_all_data: false, + non_data_frame: None, + } + } + + /// Get a reference to the inner body + pub(crate) fn get_ref(&self) -> &B { + &self.body + } + + /// Get a mutable reference to the inner body + pub(crate) fn get_mut(&mut self) -> &mut B { + &mut self.body + } + + /// Get a pinned mutable reference to the inner body + pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> { + self.project().body + } + + /// Consume `self`, returning the inner body + pub(crate) fn into_inner(self) -> B { + self.body + } +} + +impl<B> Stream for BodyIntoStream<B> +where + B: Body, +{ + type Item = Result<B::Data, B::Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + loop { + let this = self.as_mut().project(); + + if *this.yielded_all_data { + return Poll::Ready(None); + } + + match std::task::ready!(this.body.poll_frame(cx)) { + Some(Ok(frame)) => match frame.into_data() { + Ok(data) => return Poll::Ready(Some(Ok(data))), + Err(frame) => { + *this.yielded_all_data = true; + *this.non_data_frame = Some(frame); + } + }, + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => { + *this.yielded_all_data = true; + } + } + } + } +} + +impl<B> Body for BodyIntoStream<B> +where + B: Body, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + // First drive the stream impl. This consumes all data frames and buffer at most one + // trailers frame. + if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) { + return Poll::Ready(Some(frame.map(Frame::data))); + } + + let this = self.project(); + + // Yield the trailers frame `poll_next` hit. + if let Some(frame) = this.non_data_frame.take() { + return Poll::Ready(Some(Ok(frame))); + } + + // Yield any remaining frames in the body. There shouldn't be any after the trailers but + // you never know. + this.body.poll_frame(cx) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.body.size_hint() + } +} + +pin_project! { + pub(crate) struct StreamErrorIntoIoError<S, E> { + #[pin] + inner: S, + error: Option<E>, + } +} + +impl<S, E> StreamErrorIntoIoError<S, E> { + pub(crate) fn new(inner: S) -> Self { + Self { inner, error: None } + } + + /// Get a reference to the inner body + pub(crate) fn get_ref(&self) -> &S { + &self.inner + } + + /// Get a mutable reference to the inner inner + pub(crate) fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Get a pinned mutable reference to the inner inner + pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { + self.project().inner + } + + /// Consume `self`, returning the inner inner + pub(crate) fn into_inner(self) -> S { + self.inner + } +} + +impl<S, T, E> Stream for StreamErrorIntoIoError<S, E> +where + S: Stream<Item = Result<T, E>>, +{ + type Item = Result<T, io::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let this = self.project(); + match ready!(this.inner.poll_next(cx)) { + None => Poll::Ready(None), + Some(Ok(value)) => Poll::Ready(Some(Ok(value))), + Some(Err(err)) => { + *this.error = Some(err); + Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE)))) + } + } + } +} + +pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418; + +/// Level of compression data should be compressed with. +#[non_exhaustive] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)] +pub enum CompressionLevel { + /// Fastest quality of compression, usually produces bigger size. + Fastest, + /// Best quality of compression, usually produces the smallest size. + Best, + /// Default quality of compression defined by the selected compression + /// algorithm. + #[default] + Default, + /// Precise quality based on the underlying compression algorithms' + /// qualities. + /// + /// The interpretation of this depends on the algorithm chosen and the + /// specific implementation backing it. + /// + /// Qualities are implicitly clamped to the algorithm's maximum. + Precise(i32), +} + +#[cfg(any( + feature = "compression-br", + feature = "compression-gzip", + feature = "compression-deflate", + feature = "compression-zstd" +))] +use async_compression::Level as AsyncCompressionLevel; + +#[cfg(any( + feature = "compression-br", + feature = "compression-gzip", + feature = "compression-deflate", + feature = "compression-zstd" +))] +impl CompressionLevel { + pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel { + match self { + CompressionLevel::Fastest => AsyncCompressionLevel::Fastest, + CompressionLevel::Best => AsyncCompressionLevel::Best, + CompressionLevel::Default => AsyncCompressionLevel::Default, + CompressionLevel::Precise(quality) => AsyncCompressionLevel::Precise(quality), + } + } +} diff --git a/vendor/tower-http/src/content_encoding.rs b/vendor/tower-http/src/content_encoding.rs new file mode 100644 index 00000000..91c21d45 --- /dev/null +++ b/vendor/tower-http/src/content_encoding.rs @@ -0,0 +1,605 @@ +pub(crate) trait SupportedEncodings: Copy { + fn gzip(&self) -> bool; + fn deflate(&self) -> bool; + fn br(&self) -> bool; + fn zstd(&self) -> bool; +} + +// This enum's variants are ordered from least to most preferred. +#[derive(Copy, Clone, Debug, Ord, PartialOrd, PartialEq, Eq)] +pub(crate) enum Encoding { + #[allow(dead_code)] + Identity, + #[cfg(any(feature = "fs", feature = "compression-deflate"))] + Deflate, + #[cfg(any(feature = "fs", feature = "compression-gzip"))] + Gzip, + #[cfg(any(feature = "fs", feature = "compression-br"))] + Brotli, + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + Zstd, +} + +impl Encoding { + #[allow(dead_code)] + fn to_str(self) -> &'static str { + match self { + #[cfg(any(feature = "fs", feature = "compression-gzip"))] + Encoding::Gzip => "gzip", + #[cfg(any(feature = "fs", feature = "compression-deflate"))] + Encoding::Deflate => "deflate", + #[cfg(any(feature = "fs", feature = "compression-br"))] + Encoding::Brotli => "br", + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + Encoding::Zstd => "zstd", + Encoding::Identity => "identity", + } + } + + #[cfg(feature = "fs")] + pub(crate) fn to_file_extension(self) -> Option<&'static std::ffi::OsStr> { + match self { + Encoding::Gzip => Some(std::ffi::OsStr::new(".gz")), + Encoding::Deflate => Some(std::ffi::OsStr::new(".zz")), + Encoding::Brotli => Some(std::ffi::OsStr::new(".br")), + Encoding::Zstd => Some(std::ffi::OsStr::new(".zst")), + Encoding::Identity => None, + } + } + + #[allow(dead_code)] + pub(crate) fn into_header_value(self) -> http::HeaderValue { + http::HeaderValue::from_static(self.to_str()) + } + + #[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-zstd", + feature = "fs", + ))] + fn parse(s: &str, _supported_encoding: impl SupportedEncodings) -> Option<Encoding> { + #[cfg(any(feature = "fs", feature = "compression-gzip"))] + if (s.eq_ignore_ascii_case("gzip") || s.eq_ignore_ascii_case("x-gzip")) + && _supported_encoding.gzip() + { + return Some(Encoding::Gzip); + } + + #[cfg(any(feature = "fs", feature = "compression-deflate"))] + if s.eq_ignore_ascii_case("deflate") && _supported_encoding.deflate() { + return Some(Encoding::Deflate); + } + + #[cfg(any(feature = "fs", feature = "compression-br"))] + if s.eq_ignore_ascii_case("br") && _supported_encoding.br() { + return Some(Encoding::Brotli); + } + + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + if s.eq_ignore_ascii_case("zstd") && _supported_encoding.zstd() { + return Some(Encoding::Zstd); + } + + if s.eq_ignore_ascii_case("identity") { + return Some(Encoding::Identity); + } + + None + } + + #[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", + ))] + // based on https://github.com/http-rs/accept-encoding + pub(crate) fn from_headers( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, + ) -> Self { + Encoding::preferred_encoding(encodings(headers, supported_encoding)) + .unwrap_or(Encoding::Identity) + } + + #[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", + feature = "fs", + ))] + pub(crate) fn preferred_encoding( + accepted_encodings: impl Iterator<Item = (Encoding, QValue)>, + ) -> Option<Self> { + accepted_encodings + .filter(|(_, qvalue)| qvalue.0 > 0) + .max_by_key(|&(encoding, qvalue)| (qvalue, encoding)) + .map(|(encoding, _)| encoding) + } +} + +// Allowed q-values are numbers between 0 and 1 with at most 3 digits in the fractional part. They +// are presented here as an unsigned integer between 0 and 1000. +#[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", + feature = "fs", +))] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct QValue(u16); + +#[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", + feature = "fs", +))] +impl QValue { + #[inline] + pub(crate) fn one() -> Self { + Self(1000) + } + + // Parse a q-value as specified in RFC 7231 section 5.3.1. + fn parse(s: &str) -> Option<Self> { + let mut c = s.chars(); + // Parse "q=" (case-insensitively). + match c.next() { + Some('q' | 'Q') => (), + _ => return None, + }; + match c.next() { + Some('=') => (), + _ => return None, + }; + + // Parse leading digit. Since valid q-values are between 0.000 and 1.000, only "0" and "1" + // are allowed. + let mut value = match c.next() { + Some('0') => 0, + Some('1') => 1000, + _ => return None, + }; + + // Parse optional decimal point. + match c.next() { + Some('.') => (), + None => return Some(Self(value)), + _ => return None, + }; + + // Parse optional fractional digits. The value of each digit is multiplied by `factor`. + // Since the q-value is represented as an integer between 0 and 1000, `factor` is `100` for + // the first digit, `10` for the next, and `1` for the digit after that. + let mut factor = 100; + loop { + match c.next() { + Some(n @ '0'..='9') => { + // If `factor` is less than `1`, three digits have already been parsed. A + // q-value having more than 3 fractional digits is invalid. + if factor < 1 { + return None; + } + // Add the digit's value multiplied by `factor` to `value`. + value += factor * (n as u16 - '0' as u16); + } + None => { + // No more characters to parse. Check that the value representing the q-value is + // in the valid range. + return if value <= 1000 { + Some(Self(value)) + } else { + None + }; + } + _ => return None, + }; + factor /= 10; + } + } +} + +#[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", + feature = "fs", +))] +// based on https://github.com/http-rs/accept-encoding +pub(crate) fn encodings<'a>( + headers: &'a http::HeaderMap, + supported_encoding: impl SupportedEncodings + 'a, +) -> impl Iterator<Item = (Encoding, QValue)> + 'a { + headers + .get_all(http::header::ACCEPT_ENCODING) + .iter() + .filter_map(|hval| hval.to_str().ok()) + .flat_map(|s| s.split(',')) + .filter_map(move |v| { + let mut v = v.splitn(2, ';'); + + let encoding = match Encoding::parse(v.next().unwrap().trim(), supported_encoding) { + Some(encoding) => encoding, + None => return None, // ignore unknown encodings + }; + + let qval = if let Some(qval) = v.next() { + QValue::parse(qval.trim())? + } else { + QValue::one() + }; + + Some((encoding, qval)) + }) +} + +#[cfg(all( + test, + feature = "compression-gzip", + feature = "compression-deflate", + feature = "compression-br", + feature = "compression-zstd", +))] +mod tests { + use super::*; + + #[derive(Copy, Clone, Default)] + struct SupportedEncodingsAll; + + impl SupportedEncodings for SupportedEncodingsAll { + fn gzip(&self) -> bool { + true + } + + fn deflate(&self) -> bool { + true + } + + fn br(&self) -> bool { + true + } + + fn zstd(&self) -> bool { + true + } + } + + #[test] + fn no_accept_encoding_header() { + let encoding = Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + } + + #[test] + fn accept_encoding_header_single_encoding() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + } + + #[test] + fn accept_encoding_header_two_encodings() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip,br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_gzip_x_gzip() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip,x-gzip"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + } + + #[test] + fn accept_encoding_header_x_gzip_deflate() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("deflate,x-gzip"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + } + + #[test] + fn accept_encoding_header_three_encodings() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip,deflate,br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_two_encodings_with_one_qvalue() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5,br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_three_encodings_with_one_qvalue() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5,deflate,br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn two_accept_encoding_headers_with_one_qvalue() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5"), + ); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn two_accept_encoding_headers_three_encodings_with_one_qvalue() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5,deflate"), + ); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn three_accept_encoding_headers_with_one_qvalue() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5"), + ); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("deflate"), + ); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_two_encodings_with_two_qvalues() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5,br;q=0.8"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.8,br;q=0.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.995,br;q=0.999"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_three_encodings_with_three_qvalues() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5,deflate;q=0.6,br;q=0.8"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.8,deflate;q=0.6,br;q=0.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.6,deflate;q=0.8,br;q=0.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Deflate, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.995,deflate;q=0.997,br;q=0.999"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_invalid_encdoing() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("invalid,gzip"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + } + + #[test] + fn accept_encoding_header_with_qvalue_zero() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0."), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0,br;q=0.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_with_uppercase_letters() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gZiP"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Gzip, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5,br;Q=0.8"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_with_allowed_spaces() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static(" gzip\t; q=0.5 ,\tbr ;\tq=0.8\t"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Brotli, encoding); + } + + #[test] + fn accept_encoding_header_with_invalid_spaces() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q =0.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q= 0.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + } + + #[test] + fn accept_encoding_header_with_invalid_quvalues() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=-0.1"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=00.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=0.5000"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=.5"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=1.01"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("gzip;q=1.001"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Encoding::Identity, encoding); + } +} diff --git a/vendor/tower-http/src/cors/allow_credentials.rs b/vendor/tower-http/src/cors/allow_credentials.rs new file mode 100644 index 00000000..de53ffed --- /dev/null +++ b/vendor/tower-http/src/cors/allow_credentials.rs @@ -0,0 +1,96 @@ +use std::{fmt, sync::Arc}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Credentials`][mdn] header. +/// +/// See [`CorsLayer::allow_credentials`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials +/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials +#[derive(Clone, Default)] +#[must_use] +pub struct AllowCredentials(AllowCredentialsInner); + +impl AllowCredentials { + /// Allow credentials for all requests + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials + pub fn yes() -> Self { + Self(AllowCredentialsInner::Yes) + } + + /// Allow credentials for some requests, based on a given predicate + /// + /// The first argument to the predicate is the request origin. + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials + pub fn predicate<F>(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(AllowCredentialsInner::Predicate(Arc::new(f))) + } + + pub(super) fn is_true(&self) -> bool { + matches!(&self.0, AllowCredentialsInner::Yes) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + #[allow(clippy::declare_interior_mutable_const)] + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + let allow_creds = match &self.0 { + AllowCredentialsInner::Yes => true, + AllowCredentialsInner::No => false, + AllowCredentialsInner::Predicate(c) => c(origin?, parts), + }; + + allow_creds.then_some((header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) + } +} + +impl From<bool> for AllowCredentials { + fn from(v: bool) -> Self { + match v { + true => Self(AllowCredentialsInner::Yes), + false => Self(AllowCredentialsInner::No), + } + } +} + +impl fmt::Debug for AllowCredentials { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + AllowCredentialsInner::Yes => f.debug_tuple("Yes").finish(), + AllowCredentialsInner::No => f.debug_tuple("No").finish(), + AllowCredentialsInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +#[derive(Clone)] +enum AllowCredentialsInner { + Yes, + No, + Predicate( + Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for AllowCredentialsInner { + fn default() -> Self { + Self::No + } +} diff --git a/vendor/tower-http/src/cors/allow_headers.rs b/vendor/tower-http/src/cors/allow_headers.rs new file mode 100644 index 00000000..06c19928 --- /dev/null +++ b/vendor/tower-http/src/cors/allow_headers.rs @@ -0,0 +1,112 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Headers`][mdn] header. +/// +/// See [`CorsLayer::allow_headers`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers +/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers +#[derive(Clone, Default)] +#[must_use] +pub struct AllowHeaders(AllowHeadersInner); + +impl AllowHeaders { + /// Allow any headers by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + pub fn any() -> Self { + Self(AllowHeadersInner::Const(Some(WILDCARD))) + } + + /// Set multiple allowed headers + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + pub fn list<I>(headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + Self(AllowHeadersInner::Const(separated_by_commas( + headers.into_iter().map(Into::into), + ))) + } + + /// Allow any headers, by mirroring the preflight [`Access-Control-Request-Headers`][mdn] + /// header. + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers + pub fn mirror_request() -> Self { + Self(AllowHeadersInner::MirrorRequest) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let allow_headers = match &self.0 { + AllowHeadersInner::Const(v) => v.clone()?, + AllowHeadersInner::MirrorRequest => parts + .headers + .get(header::ACCESS_CONTROL_REQUEST_HEADERS)? + .clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers)) + } +} + +impl fmt::Debug for AllowHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + AllowHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + AllowHeadersInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), + } + } +} + +impl From<Any> for AllowHeaders { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl<const N: usize> From<[HeaderName; N]> for AllowHeaders { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From<Vec<HeaderName>> for AllowHeaders { + fn from(vec: Vec<HeaderName>) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum AllowHeadersInner { + Const(Option<HeaderValue>), + MirrorRequest, +} + +impl Default for AllowHeadersInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/vendor/tower-http/src/cors/allow_methods.rs b/vendor/tower-http/src/cors/allow_methods.rs new file mode 100644 index 00000000..df1a3cbd --- /dev/null +++ b/vendor/tower-http/src/cors/allow_methods.rs @@ -0,0 +1,132 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, + Method, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Methods`][mdn] header. +/// +/// See [`CorsLayer::allow_methods`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods +/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods +#[derive(Clone, Default)] +#[must_use] +pub struct AllowMethods(AllowMethodsInner); + +impl AllowMethods { + /// Allow any method by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn any() -> Self { + Self(AllowMethodsInner::Const(Some(WILDCARD))) + } + + /// Set a single allowed method + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn exact(method: Method) -> Self { + Self(AllowMethodsInner::Const(Some( + HeaderValue::from_str(method.as_str()).unwrap(), + ))) + } + + /// Set multiple allowed methods + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn list<I>(methods: I) -> Self + where + I: IntoIterator<Item = Method>, + { + Self(AllowMethodsInner::Const(separated_by_commas( + methods + .into_iter() + .map(|m| HeaderValue::from_str(m.as_str()).unwrap()), + ))) + } + + /// Allow any method, by mirroring the preflight [`Access-Control-Request-Method`][mdn] + /// header. + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Method + pub fn mirror_request() -> Self { + Self(AllowMethodsInner::MirrorRequest) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let allow_methods = match &self.0 { + AllowMethodsInner::Const(v) => v.clone()?, + AllowMethodsInner::MirrorRequest => parts + .headers + .get(header::ACCESS_CONTROL_REQUEST_METHOD)? + .clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods)) + } +} + +impl fmt::Debug for AllowMethods { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + AllowMethodsInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + AllowMethodsInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), + } + } +} + +impl From<Any> for AllowMethods { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From<Method> for AllowMethods { + fn from(method: Method) -> Self { + Self::exact(method) + } +} + +impl<const N: usize> From<[Method; N]> for AllowMethods { + fn from(arr: [Method; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From<Vec<Method>> for AllowMethods { + fn from(vec: Vec<Method>) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum AllowMethodsInner { + Const(Option<HeaderValue>), + MirrorRequest, +} + +impl Default for AllowMethodsInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/vendor/tower-http/src/cors/allow_origin.rs b/vendor/tower-http/src/cors/allow_origin.rs new file mode 100644 index 00000000..d5fdd7b6 --- /dev/null +++ b/vendor/tower-http/src/cors/allow_origin.rs @@ -0,0 +1,241 @@ +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; +use pin_project_lite::pin_project; +use std::{ + array, fmt, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use super::{Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header. +/// +/// See [`CorsLayer::allow_origin`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin +/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin +#[derive(Clone, Default)] +#[must_use] +pub struct AllowOrigin(OriginInner); + +impl AllowOrigin { + /// Allow any origin by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn any() -> Self { + Self(OriginInner::Const(WILDCARD)) + } + + /// Set a single allowed origin + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn exact(origin: HeaderValue) -> Self { + Self(OriginInner::Const(origin)) + } + + /// Set multiple allowed origins + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// # Panics + /// + /// If the iterator contains a wildcard (`*`). + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + #[allow(clippy::borrow_interior_mutable_const)] + pub fn list<I>(origins: I) -> Self + where + I: IntoIterator<Item = HeaderValue>, + { + let origins = origins.into_iter().collect::<Vec<_>>(); + if origins.contains(&WILDCARD) { + panic!( + "Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. \ + Use `AllowOrigin::any()` instead" + ); + } + + Self(OriginInner::List(origins)) + } + + /// Set the allowed origins from a predicate + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn predicate<F>(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(OriginInner::Predicate(Arc::new(f))) + } + + /// Set the allowed origins from an async predicate + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn async_predicate<F, Fut>(f: F) -> Self + where + F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone, + Fut: Future<Output = bool> + Send + 'static, + { + Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| { + Box::pin((f.clone())(v, p)) + }))) + } + + /// Allow any origin, by mirroring the request origin + /// + /// This is equivalent to + /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate]. + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn mirror_request() -> Self { + Self::predicate(|_, _| true) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, OriginInner::Const(v) if v == WILDCARD) + } + + pub(super) fn to_future( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> AllowOriginFuture { + let name = header::ACCESS_CONTROL_ALLOW_ORIGIN; + + match &self.0 { + OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))), + OriginInner::List(l) => { + AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone()))) + } + OriginInner::Predicate(c) => AllowOriginFuture::ok( + origin + .filter(|origin| c(origin, parts)) + .map(|o| (name, o.clone())), + ), + OriginInner::AsyncPredicate(f) => { + if let Some(origin) = origin.cloned() { + let fut = f(origin.clone(), parts); + AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) }) + } else { + AllowOriginFuture::ok(None) + } + } + } + } +} + +pin_project! { + #[project = AllowOriginFutureProj] + pub(super) enum AllowOriginFuture { + Ok{ + res: Option<(HeaderName, HeaderValue)> + }, + Future{ + #[pin] + future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>> + }, + } +} + +impl AllowOriginFuture { + fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self { + Self::Ok { res } + } + + fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>( + future: F, + ) -> Self { + Self::Future { + future: Box::pin(future), + } + } +} + +impl Future for AllowOriginFuture { + type Output = Option<(HeaderName, HeaderValue)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.project() { + AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()), + AllowOriginFutureProj::Future { future } => future.poll(cx), + } + } +} + +impl fmt::Debug for AllowOrigin { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(), + OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(), + } + } +} + +impl From<Any> for AllowOrigin { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From<HeaderValue> for AllowOrigin { + fn from(val: HeaderValue) -> Self { + Self::exact(val) + } +} + +impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin { + fn from(arr: [HeaderValue; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From<Vec<HeaderValue>> for AllowOrigin { + fn from(vec: Vec<HeaderValue>) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum OriginInner { + Const(HeaderValue), + List(Vec<HeaderValue>), + Predicate( + Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), + AsyncPredicate( + Arc< + dyn for<'a> Fn( + HeaderValue, + &'a RequestParts, + ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + + Send + + Sync + + 'static, + >, + ), +} + +impl Default for OriginInner { + fn default() -> Self { + Self::List(Vec::new()) + } +} diff --git a/vendor/tower-http/src/cors/allow_private_network.rs b/vendor/tower-http/src/cors/allow_private_network.rs new file mode 100644 index 00000000..9f97dc11 --- /dev/null +++ b/vendor/tower-http/src/cors/allow_private_network.rs @@ -0,0 +1,205 @@ +use std::{fmt, sync::Arc}; + +use http::{ + header::{HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header. +/// +/// See [`CorsLayer::allow_private_network`] for more details. +/// +/// [wicg]: https://wicg.github.io/private-network-access/ +/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network +#[derive(Clone, Default)] +#[must_use] +pub struct AllowPrivateNetwork(AllowPrivateNetworkInner); + +impl AllowPrivateNetwork { + /// Allow requests via a more private network than the one used to access the origin + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network + pub fn yes() -> Self { + Self(AllowPrivateNetworkInner::Yes) + } + + /// Allow requests via private network for some requests, based on a given predicate + /// + /// The first argument to the predicate is the request origin. + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network + pub fn predicate<F>(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(AllowPrivateNetworkInner::Predicate(Arc::new(f))) + } + + #[allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const + )] + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + #[allow(clippy::declare_interior_mutable_const)] + const REQUEST_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-request-private-network"); + + #[allow(clippy::declare_interior_mutable_const)] + const ALLOW_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-allow-private-network"); + + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + // Cheapest fallback: allow_private_network hasn't been set + if let AllowPrivateNetworkInner::No = &self.0 { + return None; + } + + // Access-Control-Allow-Private-Network is only relevant if the request + // has the Access-Control-Request-Private-Network header set, else skip + if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) { + return None; + } + + let allow_private_network = match &self.0 { + AllowPrivateNetworkInner::Yes => true, + AllowPrivateNetworkInner::No => false, // unreachable, but not harmful + AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts), + }; + + allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE)) + } +} + +impl From<bool> for AllowPrivateNetwork { + fn from(v: bool) -> Self { + match v { + true => Self(AllowPrivateNetworkInner::Yes), + false => Self(AllowPrivateNetworkInner::No), + } + } +} + +impl fmt::Debug for AllowPrivateNetwork { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(), + AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(), + AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +#[derive(Clone)] +enum AllowPrivateNetworkInner { + Yes, + No, + Predicate( + Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for AllowPrivateNetworkInner { + fn default() -> Self { + Self::No + } +} + +#[cfg(test)] +mod tests { + #![allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const + )] + + use super::AllowPrivateNetwork; + use crate::cors::CorsLayer; + + use crate::test_helpers::Body; + use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response}; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + use tower_service::Service; + + const REQUEST_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-request-private-network"); + + const ALLOW_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-allow-private-network"); + + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + #[tokio::test] + async fn cors_private_network_header_is_added_correctly() { + let mut service = ServiceBuilder::new() + .layer(CorsLayer::new().allow_private_network(true)) + .service_fn(echo); + + let req = Request::builder() + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .body(Body::empty()) + .unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + } + + #[tokio::test] + async fn cors_private_network_header_is_added_correctly_with_predicate() { + let allow_private_network = + AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| { + parts.uri.path() == "/allow-private" && origin == "localhost" + }); + let mut service = ServiceBuilder::new() + .layer(CorsLayer::new().allow_private_network(allow_private_network)) + .service_fn(echo); + + let req = Request::builder() + .header(ORIGIN, "localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/allow-private") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); + + let req = Request::builder() + .header(ORIGIN, "localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/other") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + + let req = Request::builder() + .header(ORIGIN, "not-localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/allow-private") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/vendor/tower-http/src/cors/expose_headers.rs b/vendor/tower-http/src/cors/expose_headers.rs new file mode 100644 index 00000000..2b1a2267 --- /dev/null +++ b/vendor/tower-http/src/cors/expose_headers.rs @@ -0,0 +1,94 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Expose-Headers`][mdn] header. +/// +/// See [`CorsLayer::expose_headers`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers +/// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers +#[derive(Clone, Default)] +#[must_use] +pub struct ExposeHeaders(ExposeHeadersInner); + +impl ExposeHeaders { + /// Expose any / all headers by sending a wildcard (`*`) + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers + pub fn any() -> Self { + Self(ExposeHeadersInner::Const(Some(WILDCARD))) + } + + /// Set multiple exposed header names + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers + pub fn list<I>(headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + Self(ExposeHeadersInner::Const(separated_by_commas( + headers.into_iter().map(Into::into), + ))) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, ExposeHeadersInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header(&self, _parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let expose_headers = match &self.0 { + ExposeHeadersInner::Const(v) => v.clone()?, + }; + + Some((header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers)) + } +} + +impl fmt::Debug for ExposeHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + ExposeHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + } + } +} + +impl From<Any> for ExposeHeaders { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl<const N: usize> From<[HeaderName; N]> for ExposeHeaders { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From<Vec<HeaderName>> for ExposeHeaders { + fn from(vec: Vec<HeaderName>) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum ExposeHeadersInner { + Const(Option<HeaderValue>), +} + +impl Default for ExposeHeadersInner { + fn default() -> Self { + ExposeHeadersInner::Const(None) + } +} diff --git a/vendor/tower-http/src/cors/max_age.rs b/vendor/tower-http/src/cors/max_age.rs new file mode 100644 index 00000000..98189926 --- /dev/null +++ b/vendor/tower-http/src/cors/max_age.rs @@ -0,0 +1,74 @@ +use std::{fmt, sync::Arc, time::Duration}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Max-Age`][mdn] header. +/// +/// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age +#[derive(Clone, Default)] +#[must_use] +pub struct MaxAge(MaxAgeInner); + +impl MaxAge { + /// Set a static max-age value + /// + /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. + pub fn exact(max_age: Duration) -> Self { + Self(MaxAgeInner::Exact(Some(max_age.as_secs().into()))) + } + + /// Set the max-age based on the preflight request parts + /// + /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. + pub fn dynamic<F>(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> Duration + Send + Sync + 'static, + { + Self(MaxAgeInner::Fn(Arc::new(f))) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + let max_age = match &self.0 { + MaxAgeInner::Exact(v) => v.clone()?, + MaxAgeInner::Fn(c) => c(origin?, parts).as_secs().into(), + }; + + Some((header::ACCESS_CONTROL_MAX_AGE, max_age)) + } +} + +impl fmt::Debug for MaxAge { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + MaxAgeInner::Exact(inner) => f.debug_tuple("Exact").field(inner).finish(), + MaxAgeInner::Fn(_) => f.debug_tuple("Fn").finish(), + } + } +} + +impl From<Duration> for MaxAge { + fn from(max_age: Duration) -> Self { + Self::exact(max_age) + } +} + +#[derive(Clone)] +enum MaxAgeInner { + Exact(Option<HeaderValue>), + Fn(Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> Duration + Send + Sync + 'static>), +} + +impl Default for MaxAgeInner { + fn default() -> Self { + Self::Exact(None) + } +} diff --git a/vendor/tower-http/src/cors/mod.rs b/vendor/tower-http/src/cors/mod.rs new file mode 100644 index 00000000..9da666c2 --- /dev/null +++ b/vendor/tower-http/src/cors/mod.rs @@ -0,0 +1,822 @@ +//! Middleware which adds headers for [CORS][mdn]. +//! +//! # Example +//! +//! ``` +//! use http::{Request, Response, Method, header}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{ServiceBuilder, ServiceExt, Service}; +//! use tower_http::cors::{Any, CorsLayer}; +//! use std::convert::Infallible; +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let cors = CorsLayer::new() +//! // allow `GET` and `POST` when accessing the resource +//! .allow_methods([Method::GET, Method::POST]) +//! // allow requests from any origin +//! .allow_origin(Any); +//! +//! let mut service = ServiceBuilder::new() +//! .layer(cors) +//! .service_fn(handle); +//! +//! let request = Request::builder() +//! .header(header::ORIGIN, "https://example.com") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!( +//! response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), +//! "*", +//! ); +//! # Ok(()) +//! # } +//! ``` +//! +//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + +#![allow(clippy::enum_variant_names)] + +use allow_origin::AllowOriginFuture; +use bytes::{BufMut, BytesMut}; +use http::{ + header::{self, HeaderName}, + HeaderMap, HeaderValue, Method, Request, Response, +}; +use pin_project_lite::pin_project; +use std::{ + array, + future::Future, + mem, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +mod allow_credentials; +mod allow_headers; +mod allow_methods; +mod allow_origin; +mod allow_private_network; +mod expose_headers; +mod max_age; +mod vary; + +#[cfg(test)] +mod tests; + +pub use self::{ + allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, + allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork, + expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, +}; + +/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. +/// +/// See the [module docs](crate::cors) for an example. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS +#[derive(Debug, Clone)] +#[must_use] +pub struct CorsLayer { + allow_credentials: AllowCredentials, + allow_headers: AllowHeaders, + allow_methods: AllowMethods, + allow_origin: AllowOrigin, + allow_private_network: AllowPrivateNetwork, + expose_headers: ExposeHeaders, + max_age: MaxAge, + vary: Vary, +} + +#[allow(clippy::declare_interior_mutable_const)] +const WILDCARD: HeaderValue = HeaderValue::from_static("*"); + +impl CorsLayer { + /// Create a new `CorsLayer`. + /// + /// No headers are sent by default. Use the builder methods to customize + /// the behavior. + /// + /// You need to set at least an allowed origin for browsers to make + /// successful cross-origin requests to your service. + pub fn new() -> Self { + Self { + allow_credentials: Default::default(), + allow_headers: Default::default(), + allow_methods: Default::default(), + allow_origin: Default::default(), + allow_private_network: Default::default(), + expose_headers: Default::default(), + max_age: Default::default(), + vary: Default::default(), + } + } + + /// A permissive configuration: + /// + /// - All request headers allowed. + /// - All methods allowed. + /// - All origins allowed. + /// - All headers exposed. + pub fn permissive() -> Self { + Self::new() + .allow_headers(Any) + .allow_methods(Any) + .allow_origin(Any) + .expose_headers(Any) + } + + /// A very permissive configuration: + /// + /// - **Credentials allowed.** + /// - The method received in `Access-Control-Request-Method` is sent back + /// as an allowed method. + /// - The origin of the preflight request is sent back as an allowed origin. + /// - The header names received in `Access-Control-Request-Headers` are sent + /// back as allowed headers. + /// - No headers are currently exposed, but this may change in the future. + pub fn very_permissive() -> Self { + Self::new() + .allow_credentials(true) + .allow_headers(AllowHeaders::mirror_request()) + .allow_methods(AllowMethods::mirror_request()) + .allow_origin(AllowOrigin::mirror_request()) + } + + /// Set the [`Access-Control-Allow-Credentials`][mdn] header. + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// + /// let layer = CorsLayer::new().allow_credentials(true); + /// ``` + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self + where + T: Into<AllowCredentials>, + { + self.allow_credentials = allow_credentials.into(); + self + } + + /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header. + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// use http::header::{AUTHORIZATION, ACCEPT}; + /// + /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]); + /// ``` + /// + /// All headers can be allowed with + /// + /// ``` + /// use tower_http::cors::{Any, CorsLayer}; + /// + /// let layer = CorsLayer::new().allow_headers(Any); + /// ``` + /// + /// Note that multiple calls to this method will override any previous + /// calls. + /// + /// Also note that `Access-Control-Allow-Headers` is required for requests that have + /// `Access-Control-Request-Headers`. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + pub fn allow_headers<T>(mut self, headers: T) -> Self + where + T: Into<AllowHeaders>, + { + self.allow_headers = headers.into(); + self + } + + /// Set the value of the [`Access-Control-Max-Age`][mdn] header. + /// + /// ``` + /// use std::time::Duration; + /// use tower_http::cors::CorsLayer; + /// + /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10); + /// ``` + /// + /// By default the header will not be set which disables caching and will + /// require a preflight call for all requests. + /// + /// Note that each browser has a maximum internal value that takes + /// precedence when the Access-Control-Max-Age is greater. For more details + /// see [mdn]. + /// + /// If you need more flexibility, you can use supply a function which can + /// dynamically decide the max-age based on the origin and other parts of + /// each preflight request: + /// + /// ``` + /// # struct MyServerConfig { cors_max_age: Duration } + /// use std::time::Duration; + /// + /// use http::{request::Parts as RequestParts, HeaderValue}; + /// use tower_http::cors::{CorsLayer, MaxAge}; + /// + /// let layer = CorsLayer::new().max_age(MaxAge::dynamic( + /// |_origin: &HeaderValue, parts: &RequestParts| -> Duration { + /// // Let's say you want to be able to reload your config at + /// // runtime and have another middleware that always inserts + /// // the current config into the request extensions + /// let config = parts.extensions.get::<MyServerConfig>().unwrap(); + /// config.cors_max_age + /// }, + /// )); + /// ``` + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + pub fn max_age<T>(mut self, max_age: T) -> Self + where + T: Into<MaxAge>, + { + self.max_age = max_age.into(); + self + } + + /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header. + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// use http::Method; + /// + /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]); + /// ``` + /// + /// All methods can be allowed with + /// + /// ``` + /// use tower_http::cors::{Any, CorsLayer}; + /// + /// let layer = CorsLayer::new().allow_methods(Any); + /// ``` + /// + /// Note that multiple calls to this method will override any previous + /// calls. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + pub fn allow_methods<T>(mut self, methods: T) -> Self + where + T: Into<AllowMethods>, + { + self.allow_methods = methods.into(); + self + } + + /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header. + /// + /// ``` + /// use http::HeaderValue; + /// use tower_http::cors::CorsLayer; + /// + /// let layer = CorsLayer::new().allow_origin( + /// "http://example.com".parse::<HeaderValue>().unwrap(), + /// ); + /// ``` + /// + /// Multiple origins can be allowed with + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// + /// let origins = [ + /// "http://example.com".parse().unwrap(), + /// "http://api.example.com".parse().unwrap(), + /// ]; + /// + /// let layer = CorsLayer::new().allow_origin(origins); + /// ``` + /// + /// All origins can be allowed with + /// + /// ``` + /// use tower_http::cors::{Any, CorsLayer}; + /// + /// let layer = CorsLayer::new().allow_origin(Any); + /// ``` + /// + /// You can also use a closure + /// + /// ``` + /// use tower_http::cors::{CorsLayer, AllowOrigin}; + /// use http::{request::Parts as RequestParts, HeaderValue}; + /// + /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate( + /// |origin: &HeaderValue, _request_parts: &RequestParts| { + /// origin.as_bytes().ends_with(b".rust-lang.org") + /// }, + /// )); + /// ``` + /// + /// You can also use an async closure: + /// + /// ``` + /// # #[derive(Clone)] + /// # struct Client; + /// # fn get_api_client() -> Client { + /// # Client + /// # } + /// # impl Client { + /// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> { + /// # vec![HeaderValue::from_static("http://example.com")] + /// # } + /// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> { + /// # vec![HeaderValue::from_static("http://example.com")] + /// # } + /// # } + /// use tower_http::cors::{CorsLayer, AllowOrigin}; + /// use http::{request::Parts as RequestParts, HeaderValue}; + /// + /// let client = get_api_client(); + /// + /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate( + /// |origin: HeaderValue, _request_parts: &RequestParts| async move { + /// // fetch list of origins that are allowed + /// let origins = client.fetch_allowed_origins().await; + /// origins.contains(&origin) + /// }, + /// )); + /// + /// let client = get_api_client(); + /// + /// // if using &RequestParts, make sure all the values are owned + /// // before passing into the future + /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate( + /// |origin: HeaderValue, parts: &RequestParts| { + /// let path = parts.uri.path().to_owned(); + /// + /// async move { + /// // fetch list of origins that are allowed for this path + /// let origins = client.fetch_allowed_origins_for_path(path).await; + /// origins.contains(&origin) + /// } + /// }, + /// )); + /// ``` + /// + /// Note that multiple calls to this method will override any previous + /// calls. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + pub fn allow_origin<T>(mut self, origin: T) -> Self + where + T: Into<AllowOrigin>, + { + self.allow_origin = origin.into(); + self + } + + /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header. + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// use http::header::CONTENT_ENCODING; + /// + /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]); + /// ``` + /// + /// All headers can be allowed with + /// + /// ``` + /// use tower_http::cors::{Any, CorsLayer}; + /// + /// let layer = CorsLayer::new().expose_headers(Any); + /// ``` + /// + /// Note that multiple calls to this method will override any previous + /// calls. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + pub fn expose_headers<T>(mut self, headers: T) -> Self + where + T: Into<ExposeHeaders>, + { + self.expose_headers = headers.into(); + self + } + + /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. + /// + /// ``` + /// use tower_http::cors::CorsLayer; + /// + /// let layer = CorsLayer::new().allow_private_network(true); + /// ``` + /// + /// [wicg]: https://wicg.github.io/private-network-access/ + pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self + where + T: Into<AllowPrivateNetwork>, + { + self.allow_private_network = allow_private_network.into(); + self + } + + /// Set the value(s) of the [`Vary`][mdn] header. + /// + /// In contrast to the other headers, this one has a non-empty default of + /// [`preflight_request_headers()`]. + /// + /// You only need to set this is you want to remove some of these defaults, + /// or if you use a closure for one of the other headers and want to add a + /// vary header accordingly. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary + pub fn vary<T>(mut self, headers: T) -> Self + where + T: Into<Vary>, + { + self.vary = headers.into(); + self + } +} + +/// Represents a wildcard value (`*`) used with some CORS headers such as +/// [`CorsLayer::allow_methods`]. +#[derive(Debug, Clone, Copy)] +#[must_use] +pub struct Any; + +/// Represents a wildcard value (`*`) used with some CORS headers such as +/// [`CorsLayer::allow_methods`]. +#[deprecated = "Use Any as a unit struct literal instead"] +pub fn any() -> Any { + Any +} + +fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue> +where + I: Iterator<Item = HeaderValue>, +{ + match iter.next() { + Some(fst) => { + let mut result = BytesMut::from(fst.as_bytes()); + for val in iter { + result.reserve(val.len() + 1); + result.put_u8(b','); + result.extend_from_slice(val.as_bytes()); + } + + Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap()) + } + None => None, + } +} + +impl Default for CorsLayer { + fn default() -> Self { + Self::new() + } +} + +impl<S> Layer<S> for CorsLayer { + type Service = Cors<S>; + + fn layer(&self, inner: S) -> Self::Service { + ensure_usable_cors_rules(self); + + Cors { + inner, + layer: self.clone(), + } + } +} + +/// Middleware which adds headers for [CORS][mdn]. +/// +/// See the [module docs](crate::cors) for an example. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS +#[derive(Debug, Clone)] +#[must_use] +pub struct Cors<S> { + inner: S, + layer: CorsLayer, +} + +impl<S> Cors<S> { + /// Create a new `Cors`. + /// + /// See [`CorsLayer::new`] for more details. + pub fn new(inner: S) -> Self { + Self { + inner, + layer: CorsLayer::new(), + } + } + + /// A permissive configuration. + /// + /// See [`CorsLayer::permissive`] for more details. + pub fn permissive(inner: S) -> Self { + Self { + inner, + layer: CorsLayer::permissive(), + } + } + + /// A very permissive configuration. + /// + /// See [`CorsLayer::very_permissive`] for more details. + pub fn very_permissive(inner: S) -> Self { + Self { + inner, + layer: CorsLayer::very_permissive(), + } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> CorsLayer { + CorsLayer::new() + } + + /// Set the [`Access-Control-Allow-Credentials`][mdn] header. + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + pub fn allow_credentials<T>(self, allow_credentials: T) -> Self + where + T: Into<AllowCredentials>, + { + self.map_layer(|layer| layer.allow_credentials(allow_credentials)) + } + + /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header. + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + pub fn allow_headers<T>(self, headers: T) -> Self + where + T: Into<AllowHeaders>, + { + self.map_layer(|layer| layer.allow_headers(headers)) + } + + /// Set the value of the [`Access-Control-Max-Age`][mdn] header. + /// + /// See [`CorsLayer::max_age`] for more details. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + pub fn max_age<T>(self, max_age: T) -> Self + where + T: Into<MaxAge>, + { + self.map_layer(|layer| layer.max_age(max_age)) + } + + /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header. + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + pub fn allow_methods<T>(self, methods: T) -> Self + where + T: Into<AllowMethods>, + { + self.map_layer(|layer| layer.allow_methods(methods)) + } + + /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header. + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + pub fn allow_origin<T>(self, origin: T) -> Self + where + T: Into<AllowOrigin>, + { + self.map_layer(|layer| layer.allow_origin(origin)) + } + + /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header. + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + pub fn expose_headers<T>(self, headers: T) -> Self + where + T: Into<ExposeHeaders>, + { + self.map_layer(|layer| layer.expose_headers(headers)) + } + + /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [wicg]: https://wicg.github.io/private-network-access/ + pub fn allow_private_network<T>(self, allow_private_network: T) -> Self + where + T: Into<AllowPrivateNetwork>, + { + self.map_layer(|layer| layer.allow_private_network(allow_private_network)) + } + + fn map_layer<F>(mut self, f: F) -> Self + where + F: FnOnce(CorsLayer) -> CorsLayer, + { + self.layer = f(self.layer); + self + } +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + ResBody: Default, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + ensure_usable_cors_rules(&self.layer); + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let (parts, body) = req.into_parts(); + let origin = parts.headers.get(&header::ORIGIN); + + let mut headers = HeaderMap::new(); + + // These headers are applied to both preflight and subsequent regular CORS requests: + // https://fetch.spec.whatwg.org/#http-responses + + headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); + headers.extend(self.layer.allow_private_network.to_header(origin, &parts)); + headers.extend(self.layer.vary.to_header()); + + let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts); + + // Return results immediately upon preflight request + if parts.method == Method::OPTIONS { + // These headers are applied only to preflight requests + headers.extend(self.layer.allow_methods.to_header(&parts)); + headers.extend(self.layer.allow_headers.to_header(&parts)); + headers.extend(self.layer.max_age.to_header(origin, &parts)); + + ResponseFuture { + inner: Kind::PreflightCall { + allow_origin_future, + headers, + }, + } + } else { + // This header is applied only to non-preflight requests + headers.extend(self.layer.expose_headers.to_header(&parts)); + + let req = Request::from_parts(parts, body); + ResponseFuture { + inner: Kind::CorsCall { + allow_origin_future, + allow_origin_complete: false, + future: self.inner.call(req), + headers, + }, + } + } + } +} + +pin_project! { + /// Response future for [`Cors`]. + pub struct ResponseFuture<F> { + #[pin] + inner: Kind<F>, + } +} + +pin_project! { + #[project = KindProj] + enum Kind<F> { + CorsCall { + #[pin] + allow_origin_future: AllowOriginFuture, + allow_origin_complete: bool, + #[pin] + future: F, + headers: HeaderMap, + }, + PreflightCall { + #[pin] + allow_origin_future: AllowOriginFuture, + headers: HeaderMap, + }, + } +} + +impl<F, B, E> Future for ResponseFuture<F> +where + F: Future<Output = Result<Response<B>, E>>, + B: Default, +{ + type Output = Result<Response<B>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.project().inner.project() { + KindProj::CorsCall { + allow_origin_future, + allow_origin_complete, + future, + headers, + } => { + if !*allow_origin_complete { + headers.extend(ready!(allow_origin_future.poll(cx))); + *allow_origin_complete = true; + } + + let mut response: Response<B> = ready!(future.poll(cx))?; + + let response_headers = response.headers_mut(); + + // vary header can have multiple values, don't overwrite + // previously-set value(s). + if let Some(vary) = headers.remove(header::VARY) { + response_headers.append(header::VARY, vary); + } + // extend will overwrite previous headers of remaining names + response_headers.extend(headers.drain()); + + Poll::Ready(Ok(response)) + } + KindProj::PreflightCall { + allow_origin_future, + headers, + } => { + headers.extend(ready!(allow_origin_future.poll(cx))); + + let mut response = Response::new(B::default()); + mem::swap(response.headers_mut(), headers); + + Poll::Ready(Ok(response)) + } + } + } +} + +fn ensure_usable_cors_rules(layer: &CorsLayer) { + if layer.allow_credentials.is_true() { + assert!( + !layer.allow_headers.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Headers: *`" + ); + + assert!( + !layer.allow_methods.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Methods: *`" + ); + + assert!( + !layer.allow_origin.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Origin: *`" + ); + + assert!( + !layer.expose_headers.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Expose-Headers: *`" + ); + } +} + +/// Returns an iterator over the three request headers that may be involved in a CORS preflight request. +/// +/// This is the default set of header names returned in the `vary` header +pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + array::IntoIter::new([ + header::ORIGIN, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::ACCESS_CONTROL_REQUEST_HEADERS, + ]) +} diff --git a/vendor/tower-http/src/cors/tests.rs b/vendor/tower-http/src/cors/tests.rs new file mode 100644 index 00000000..8f3f4acb --- /dev/null +++ b/vendor/tower-http/src/cors/tests.rs @@ -0,0 +1,73 @@ +use std::convert::Infallible; + +use crate::test_helpers::Body; +use http::{header, HeaderValue, Request, Response}; +use tower::{service_fn, util::ServiceExt, Layer}; + +use crate::cors::{AllowOrigin, CorsLayer}; + +#[tokio::test] +#[allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const +)] +async fn vary_set_by_inner_service() { + const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding"); + const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static( + "origin, access-control-request-method, access-control-request-headers", + ); + + async fn inner_svc(_: Request<Body>) -> Result<Response<Body>, Infallible> { + Ok(Response::builder() + .header(header::VARY, CUSTOM_VARY_HEADERS) + .body(Body::empty()) + .unwrap()) + } + + let svc = CorsLayer::permissive().layer(service_fn(inner_svc)); + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + let mut vary_headers = res.headers().get_all(header::VARY).into_iter(); + assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS)); + assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS)); + assert_eq!(vary_headers.next(), None); +} + +#[tokio::test] +async fn test_allow_origin_async_predicate() { + #[derive(Clone)] + struct Client; + + impl Client { + async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> { + vec![HeaderValue::from_static("http://example.com")] + } + } + + let client = Client; + + let allow_origin = AllowOrigin::async_predicate(|origin, parts| { + let path = parts.uri.path().to_owned(); + + async move { + let origins = client.fetch_allowed_origins_for_path(path).await; + + origins.contains(&origin) + } + }); + + let valid_origin = HeaderValue::from_static("http://example.com"); + let parts = http::Request::new("hello world").into_parts().0; + + let header = allow_origin + .to_future(Some(&valid_origin), &parts) + .await + .unwrap(); + assert_eq!(header.0, header::ACCESS_CONTROL_ALLOW_ORIGIN); + assert_eq!(header.1, valid_origin); + + let invalid_origin = HeaderValue::from_static("http://example.org"); + let parts = http::Request::new("hello world").into_parts().0; + + let res = allow_origin.to_future(Some(&invalid_origin), &parts).await; + assert!(res.is_none()); +} diff --git a/vendor/tower-http/src/cors/vary.rs b/vendor/tower-http/src/cors/vary.rs new file mode 100644 index 00000000..1ed7e672 --- /dev/null +++ b/vendor/tower-http/src/cors/vary.rs @@ -0,0 +1,60 @@ +use std::array; + +use http::header::{self, HeaderName, HeaderValue}; + +use super::preflight_request_headers; + +/// Holds configuration for how to set the [`Vary`][mdn] header. +/// +/// See [`CorsLayer::vary`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary +/// [`CorsLayer::vary`]: super::CorsLayer::vary +#[derive(Clone, Debug)] +pub struct Vary(Vec<HeaderValue>); + +impl Vary { + /// Set the list of header names to return as vary header values + /// + /// See [`CorsLayer::vary`] for more details. + /// + /// [`CorsLayer::vary`]: super::CorsLayer::vary + pub fn list<I>(headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + Self(headers.into_iter().map(Into::into).collect()) + } + + pub(super) fn to_header(&self) -> Option<(HeaderName, HeaderValue)> { + let values = &self.0; + let mut res = values.first()?.as_bytes().to_owned(); + for val in &values[1..] { + res.extend_from_slice(b", "); + res.extend_from_slice(val.as_bytes()); + } + + let header_val = HeaderValue::from_bytes(&res) + .expect("comma-separated list of HeaderValues is always a valid HeaderValue"); + Some((header::VARY, header_val)) + } +} + +impl Default for Vary { + fn default() -> Self { + Self::list(preflight_request_headers()) + } +} + +impl<const N: usize> From<[HeaderName; N]> for Vary { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From<Vec<HeaderName>> for Vary { + fn from(vec: Vec<HeaderName>) -> Self { + Self::list(vec) + } +} diff --git a/vendor/tower-http/src/decompression/body.rs b/vendor/tower-http/src/decompression/body.rs new file mode 100644 index 00000000..a2970d65 --- /dev/null +++ b/vendor/tower-http/src/decompression/body.rs @@ -0,0 +1,408 @@ +#![allow(unused_imports)] + +use crate::compression_utils::CompressionLevel; +use crate::{ + compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody}, + BoxError, +}; +#[cfg(feature = "decompression-br")] +use async_compression::tokio::bufread::BrotliDecoder; +#[cfg(feature = "decompression-gzip")] +use async_compression::tokio::bufread::GzipDecoder; +#[cfg(feature = "decompression-deflate")] +use async_compression::tokio::bufread::ZlibDecoder; +#[cfg(feature = "decompression-zstd")] +use async_compression::tokio::bufread::ZstdDecoder; +use bytes::{Buf, Bytes}; +use http::HeaderMap; +use http_body::{Body, SizeHint}; +use pin_project_lite::pin_project; +use std::task::Context; +use std::{ + io, + marker::PhantomData, + pin::Pin, + task::{ready, Poll}, +}; +use tokio_util::io::StreamReader; + +pin_project! { + /// Response body of [`RequestDecompression`] and [`Decompression`]. + /// + /// [`RequestDecompression`]: super::RequestDecompression + /// [`Decompression`]: super::Decompression + pub struct DecompressionBody<B> + where + B: Body + { + #[pin] + pub(crate) inner: BodyInner<B>, + } +} + +impl<B> Default for DecompressionBody<B> +where + B: Body + Default, +{ + fn default() -> Self { + Self { + inner: BodyInner::Identity { + inner: B::default(), + }, + } + } +} + +impl<B> DecompressionBody<B> +where + B: Body, +{ + pub(crate) fn new(inner: BodyInner<B>) -> Self { + Self { inner } + } + + /// Get a reference to the inner body + pub fn get_ref(&self) -> &B { + match &self.inner { + #[cfg(feature = "decompression-gzip")] + BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "decompression-deflate")] + BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "decompression-br")] + BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "decompression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + BodyInner::Identity { inner } => inner, + + // FIXME: Remove once possible; see https://github.com/rust-lang/rust/issues/51085 + #[cfg(not(feature = "decompression-gzip"))] + BodyInner::Gzip { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-deflate"))] + BodyInner::Deflate { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-br"))] + BodyInner::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInner::Zstd { inner } => match inner.0 {}, + } + } + + /// Get a mutable reference to the inner body + pub fn get_mut(&mut self) -> &mut B { + match &mut self.inner { + #[cfg(feature = "decompression-gzip")] + BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "decompression-deflate")] + BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "decompression-br")] + BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "decompression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + BodyInner::Identity { inner } => inner, + + #[cfg(not(feature = "decompression-gzip"))] + BodyInner::Gzip { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-deflate"))] + BodyInner::Deflate { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-br"))] + BodyInner::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInner::Zstd { inner } => match inner.0 {}, + } + } + + /// Get a pinned mutable reference to the inner body + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> { + match self.project().inner.project() { + #[cfg(feature = "decompression-gzip")] + BodyInnerProj::Gzip { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + #[cfg(feature = "decompression-deflate")] + BodyInnerProj::Deflate { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + #[cfg(feature = "decompression-br")] + BodyInnerProj::Brotli { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + #[cfg(feature = "decompression-zstd")] + BodyInnerProj::Zstd { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), + BodyInnerProj::Identity { inner } => inner, + + #[cfg(not(feature = "decompression-gzip"))] + BodyInnerProj::Gzip { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-deflate"))] + BodyInnerProj::Deflate { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-br"))] + BodyInnerProj::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInnerProj::Zstd { inner } => match inner.0 {}, + } + } + + /// Consume `self`, returning the inner body + pub fn into_inner(self) -> B { + match self.inner { + #[cfg(feature = "decompression-gzip")] + BodyInner::Gzip { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + #[cfg(feature = "decompression-deflate")] + BodyInner::Deflate { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + #[cfg(feature = "decompression-br")] + BodyInner::Brotli { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + #[cfg(feature = "decompression-zstd")] + BodyInner::Zstd { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), + BodyInner::Identity { inner } => inner, + + #[cfg(not(feature = "decompression-gzip"))] + BodyInner::Gzip { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-deflate"))] + BodyInner::Deflate { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-br"))] + BodyInner::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInner::Zstd { inner } => match inner.0 {}, + } + } +} + +#[cfg(any( + not(feature = "decompression-gzip"), + not(feature = "decompression-deflate"), + not(feature = "decompression-br"), + not(feature = "decompression-zstd") +))] +pub(crate) enum Never {} + +#[cfg(feature = "decompression-gzip")] +type GzipBody<B> = WrapBody<GzipDecoder<B>>; +#[cfg(not(feature = "decompression-gzip"))] +type GzipBody<B> = (Never, PhantomData<B>); + +#[cfg(feature = "decompression-deflate")] +type DeflateBody<B> = WrapBody<ZlibDecoder<B>>; +#[cfg(not(feature = "decompression-deflate"))] +type DeflateBody<B> = (Never, PhantomData<B>); + +#[cfg(feature = "decompression-br")] +type BrotliBody<B> = WrapBody<BrotliDecoder<B>>; +#[cfg(not(feature = "decompression-br"))] +type BrotliBody<B> = (Never, PhantomData<B>); + +#[cfg(feature = "decompression-zstd")] +type ZstdBody<B> = WrapBody<ZstdDecoder<B>>; +#[cfg(not(feature = "decompression-zstd"))] +type ZstdBody<B> = (Never, PhantomData<B>); + +pin_project! { + #[project = BodyInnerProj] + pub(crate) enum BodyInner<B> + where + B: Body, + { + Gzip { + #[pin] + inner: GzipBody<B>, + }, + Deflate { + #[pin] + inner: DeflateBody<B>, + }, + Brotli { + #[pin] + inner: BrotliBody<B>, + }, + Zstd { + #[pin] + inner: ZstdBody<B>, + }, + Identity { + #[pin] + inner: B, + }, + } +} + +impl<B: Body> BodyInner<B> { + #[cfg(feature = "decompression-gzip")] + pub(crate) fn gzip(inner: WrapBody<GzipDecoder<B>>) -> Self { + Self::Gzip { inner } + } + + #[cfg(feature = "decompression-deflate")] + pub(crate) fn deflate(inner: WrapBody<ZlibDecoder<B>>) -> Self { + Self::Deflate { inner } + } + + #[cfg(feature = "decompression-br")] + pub(crate) fn brotli(inner: WrapBody<BrotliDecoder<B>>) -> Self { + Self::Brotli { inner } + } + + #[cfg(feature = "decompression-zstd")] + pub(crate) fn zstd(inner: WrapBody<ZstdDecoder<B>>) -> Self { + Self::Zstd { inner } + } + + pub(crate) fn identity(inner: B) -> Self { + Self::Identity { inner } + } +} + +impl<B> Body for DecompressionBody<B> +where + B: Body, + B::Error: Into<BoxError>, +{ + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + match self.project().inner.project() { + #[cfg(feature = "decompression-gzip")] + BodyInnerProj::Gzip { inner } => inner.poll_frame(cx), + #[cfg(feature = "decompression-deflate")] + BodyInnerProj::Deflate { inner } => inner.poll_frame(cx), + #[cfg(feature = "decompression-br")] + BodyInnerProj::Brotli { inner } => inner.poll_frame(cx), + #[cfg(feature = "decompression-zstd")] + BodyInnerProj::Zstd { inner } => inner.poll_frame(cx), + BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())); + Poll::Ready(Some(Ok(frame))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + None => Poll::Ready(None), + }, + + #[cfg(not(feature = "decompression-gzip"))] + BodyInnerProj::Gzip { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-deflate"))] + BodyInnerProj::Deflate { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-br"))] + BodyInnerProj::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInnerProj::Zstd { inner } => match inner.0 {}, + } + } + + fn size_hint(&self) -> SizeHint { + match self.inner { + BodyInner::Identity { ref inner } => inner.size_hint(), + _ => SizeHint::default(), + } + } +} + +#[cfg(feature = "decompression-gzip")] +impl<B> DecorateAsyncRead for GzipDecoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = GzipDecoder<Self::Input>; + + fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { + let mut decoder = GzipDecoder::new(input); + decoder.multiple_members(true); + decoder + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} + +#[cfg(feature = "decompression-deflate")] +impl<B> DecorateAsyncRead for ZlibDecoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = ZlibDecoder<Self::Input>; + + fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { + ZlibDecoder::new(input) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} + +#[cfg(feature = "decompression-br")] +impl<B> DecorateAsyncRead for BrotliDecoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = BrotliDecoder<Self::Input>; + + fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { + BrotliDecoder::new(input) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} + +#[cfg(feature = "decompression-zstd")] +impl<B> DecorateAsyncRead for ZstdDecoder<B> +where + B: Body, +{ + type Input = AsyncReadBody<B>; + type Output = ZstdDecoder<Self::Input>; + + fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { + let mut decoder = ZstdDecoder::new(input); + decoder.multiple_members(true); + decoder + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} diff --git a/vendor/tower-http/src/decompression/future.rs b/vendor/tower-http/src/decompression/future.rs new file mode 100644 index 00000000..36867e97 --- /dev/null +++ b/vendor/tower-http/src/decompression/future.rs @@ -0,0 +1,80 @@ +#![allow(unused_imports)] + +use super::{body::BodyInner, DecompressionBody}; +use crate::compression_utils::{AcceptEncoding, CompressionLevel, WrapBody}; +use crate::content_encoding::SupportedEncodings; +use http::{header, Response}; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, +}; + +pin_project! { + /// Response future of [`Decompression`]. + /// + /// [`Decompression`]: super::Decompression + #[derive(Debug)] + pub struct ResponseFuture<F> { + #[pin] + pub(crate) inner: F, + pub(crate) accept: AcceptEncoding, + } +} + +impl<F, B, E> Future for ResponseFuture<F> +where + F: Future<Output = Result<Response<B>, E>>, + B: Body, +{ + type Output = Result<Response<DecompressionBody<B>>, E>; + + #[allow(unreachable_code, unused_mut, unused_variables)] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let res = ready!(self.as_mut().project().inner.poll(cx)?); + let (mut parts, body) = res.into_parts(); + + let res = + if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) { + let body = match entry.get().as_bytes() { + #[cfg(feature = "decompression-gzip")] + b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip( + WrapBody::new(body, CompressionLevel::default()), + )), + + #[cfg(feature = "decompression-deflate")] + b"deflate" if self.accept.deflate() => DecompressionBody::new( + BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())), + ), + + #[cfg(feature = "decompression-br")] + b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli( + WrapBody::new(body, CompressionLevel::default()), + )), + + #[cfg(feature = "decompression-zstd")] + b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd( + WrapBody::new(body, CompressionLevel::default()), + )), + + _ => { + return Poll::Ready(Ok(Response::from_parts( + parts, + DecompressionBody::new(BodyInner::identity(body)), + ))) + } + }; + + entry.remove(); + parts.headers.remove(header::CONTENT_LENGTH); + + Response::from_parts(parts, body) + } else { + Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body))) + }; + + Poll::Ready(Ok(res)) + } +} diff --git a/vendor/tower-http/src/decompression/layer.rs b/vendor/tower-http/src/decompression/layer.rs new file mode 100644 index 00000000..4a184c16 --- /dev/null +++ b/vendor/tower-http/src/decompression/layer.rs @@ -0,0 +1,92 @@ +use super::Decompression; +use crate::compression_utils::AcceptEncoding; +use tower_layer::Layer; + +/// Decompresses response bodies of the underlying service. +/// +/// This adds the `Accept-Encoding` header to requests and transparently decompresses response +/// bodies based on the `Content-Encoding` header. +/// +/// See the [module docs](crate::decompression) for more details. +#[derive(Debug, Default, Clone)] +pub struct DecompressionLayer { + accept: AcceptEncoding, +} + +impl<S> Layer<S> for DecompressionLayer { + type Service = Decompression<S>; + + fn layer(&self, service: S) -> Self::Service { + Decompression { + inner: service, + accept: self.accept, + } + } +} + +impl DecompressionLayer { + /// Creates a new `DecompressionLayer`. + pub fn new() -> Self { + Default::default() + } + + /// Sets whether to request the gzip encoding. + #[cfg(feature = "decompression-gzip")] + pub fn gzip(mut self, enable: bool) -> Self { + self.accept.set_gzip(enable); + self + } + + /// Sets whether to request the Deflate encoding. + #[cfg(feature = "decompression-deflate")] + pub fn deflate(mut self, enable: bool) -> Self { + self.accept.set_deflate(enable); + self + } + + /// Sets whether to request the Brotli encoding. + #[cfg(feature = "decompression-br")] + pub fn br(mut self, enable: bool) -> Self { + self.accept.set_br(enable); + self + } + + /// Sets whether to request the Zstd encoding. + #[cfg(feature = "decompression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + + /// Disables the gzip encoding. + /// + /// This method is available even if the `gzip` crate feature is disabled. + pub fn no_gzip(mut self) -> Self { + self.accept.set_gzip(false); + self + } + + /// Disables the Deflate encoding. + /// + /// This method is available even if the `deflate` crate feature is disabled. + pub fn no_deflate(mut self) -> Self { + self.accept.set_deflate(false); + self + } + + /// Disables the Brotli encoding. + /// + /// This method is available even if the `br` crate feature is disabled. + pub fn no_br(mut self) -> Self { + self.accept.set_br(false); + self + } + + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } +} diff --git a/vendor/tower-http/src/decompression/mod.rs b/vendor/tower-http/src/decompression/mod.rs new file mode 100644 index 00000000..50d4d5fa --- /dev/null +++ b/vendor/tower-http/src/decompression/mod.rs @@ -0,0 +1,233 @@ +//! Middleware that decompresses request and response bodies. +//! +//! # Examples +//! +//! #### Request +//! +//! ```rust +//! use bytes::Bytes; +//! use flate2::{write::GzEncoder, Compression}; +//! use http::{header, HeaderValue, Request, Response}; +//! use http_body_util::{Full, BodyExt}; +//! use std::{error::Error, io::Write}; +//! use tower::{Service, ServiceBuilder, service_fn, ServiceExt}; +//! use tower_http::{BoxError, decompression::{DecompressionBody, RequestDecompressionLayer}}; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! // A request encoded with gzip coming from some HTTP client. +//! let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); +//! encoder.write_all(b"Hello?")?; +//! let request = Request::builder() +//! .header(header::CONTENT_ENCODING, "gzip") +//! .body(Full::from(encoder.finish()?))?; +//! +//! // Our HTTP server +//! let mut server = ServiceBuilder::new() +//! // Automatically decompress request bodies. +//! .layer(RequestDecompressionLayer::new()) +//! .service(service_fn(handler)); +//! +//! // Send the request, with the gzip encoded body, to our server. +//! let _response = server.ready().await?.call(request).await?; +//! +//! // Handler receives request whose body is decoded when read +//! async fn handler( +//! mut req: Request<DecompressionBody<Full<Bytes>>>, +//! ) -> Result<Response<Full<Bytes>>, BoxError>{ +//! let data = req.into_body().collect().await?.to_bytes(); +//! assert_eq!(&data[..], b"Hello?"); +//! Ok(Response::new(Full::from("Hello, World!"))) +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! #### Response +//! +//! ```rust +//! use bytes::Bytes; +//! use http::{Request, Response}; +//! use http_body_util::{Full, BodyExt}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::{compression::Compression, decompression::DecompressionLayer, BoxError}; +//! # +//! # #[tokio::main] +//! # async fn main() -> Result<(), tower_http::BoxError> { +//! # async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # let body = Full::from("Hello, World!"); +//! # Ok(Response::new(body)) +//! # } +//! +//! // Some opaque service that applies compression. +//! let service = Compression::new(service_fn(handle)); +//! +//! // Our HTTP client. +//! let mut client = ServiceBuilder::new() +//! // Automatically decompress response bodies. +//! .layer(DecompressionLayer::new()) +//! .service(service); +//! +//! // Call the service. +//! // +//! // `DecompressionLayer` takes care of setting `Accept-Encoding`. +//! let request = Request::new(Full::<Bytes>::default()); +//! +//! let response = client +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! // Read the body +//! let body = response.into_body(); +//! let bytes = body.collect().await?.to_bytes().to_vec(); +//! let body = String::from_utf8(bytes).map_err(Into::<BoxError>::into)?; +//! +//! assert_eq!(body, "Hello, World!"); +//! # +//! # Ok(()) +//! # } +//! ``` + +mod request; + +mod body; +mod future; +mod layer; +mod service; + +pub use self::{ + body::DecompressionBody, future::ResponseFuture, layer::DecompressionLayer, + service::Decompression, +}; + +pub use self::request::future::RequestDecompressionFuture; +pub use self::request::layer::RequestDecompressionLayer; +pub use self::request::service::RequestDecompression; + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + use std::io::Write; + + use super::*; + use crate::test_helpers::Body; + use crate::{compression::Compression, test_helpers::WithTrailers}; + use flate2::write::GzEncoder; + use http::Response; + use http::{HeaderMap, HeaderName, Request}; + use http_body_util::BodyExt; + use tower::{service_fn, Service, ServiceExt}; + + #[tokio::test] + async fn works() { + let mut client = Decompression::new(Compression::new(service_fn(handle))); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = client.ready().await.unwrap().call(req).await.unwrap(); + + // read the body, it will be decompressed automatically + let body = res.into_body(); + let collected = body.collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let decompressed_data = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + + assert_eq!(decompressed_data, "Hello, World!"); + + // maintains trailers + assert_eq!(trailers["foo"], "bar"); + } + + async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> { + let mut trailers = HeaderMap::new(); + trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap()); + let body = Body::from("Hello, World!").with_trailers(trailers); + Ok(Response::builder().body(body).unwrap()) + } + + #[tokio::test] + async fn decompress_multi_gz() { + let mut client = Decompression::new(service_fn(handle_multi_gz)); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = client.ready().await.unwrap().call(req).await.unwrap(); + + // read the body, it will be decompressed automatically + let body = res.into_body(); + let decompressed_data = + String::from_utf8(body.collect().await.unwrap().to_bytes().to_vec()).unwrap(); + + assert_eq!(decompressed_data, "Hello, World!"); + } + + #[tokio::test] + async fn decompress_multi_zstd() { + let mut client = Decompression::new(service_fn(handle_multi_zstd)); + + let req = Request::builder() + .header("accept-encoding", "zstd") + .body(Body::empty()) + .unwrap(); + let res = client.ready().await.unwrap().call(req).await.unwrap(); + + // read the body, it will be decompressed automatically + let body = res.into_body(); + let decompressed_data = + String::from_utf8(body.collect().await.unwrap().to_bytes().to_vec()).unwrap(); + + assert_eq!(decompressed_data, "Hello, World!"); + } + + async fn handle_multi_gz(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + let mut buf = Vec::new(); + let mut enc1 = GzEncoder::new(&mut buf, Default::default()); + enc1.write_all(b"Hello, ").unwrap(); + enc1.finish().unwrap(); + + let mut enc2 = GzEncoder::new(&mut buf, Default::default()); + enc2.write_all(b"World!").unwrap(); + enc2.finish().unwrap(); + + let mut res = Response::new(Body::from(buf)); + res.headers_mut() + .insert("content-encoding", "gzip".parse().unwrap()); + Ok(res) + } + + async fn handle_multi_zstd(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + let mut buf = Vec::new(); + let mut enc1 = zstd::Encoder::new(&mut buf, Default::default()).unwrap(); + enc1.write_all(b"Hello, ").unwrap(); + enc1.finish().unwrap(); + + let mut enc2 = zstd::Encoder::new(&mut buf, Default::default()).unwrap(); + enc2.write_all(b"World!").unwrap(); + enc2.finish().unwrap(); + + let mut res = Response::new(Body::from(buf)); + res.headers_mut() + .insert("content-encoding", "zstd".parse().unwrap()); + Ok(res) + } + + #[allow(dead_code)] + async fn is_compatible_with_hyper() { + let client = + hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build_http(); + let mut client = Decompression::new(client); + + let req = Request::new(Body::empty()); + + let _: Response<DecompressionBody<_>> = + client.ready().await.unwrap().call(req).await.unwrap(); + } +} diff --git a/vendor/tower-http/src/decompression/request/future.rs b/vendor/tower-http/src/decompression/request/future.rs new file mode 100644 index 00000000..bdb22f8b --- /dev/null +++ b/vendor/tower-http/src/decompression/request/future.rs @@ -0,0 +1,98 @@ +use crate::body::UnsyncBoxBody; +use crate::compression_utils::AcceptEncoding; +use crate::BoxError; +use bytes::Buf; +use http::{header, HeaderValue, Response, StatusCode}; +use http_body::Body; +use http_body_util::BodyExt; +use http_body_util::Empty; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +pin_project! { + #[derive(Debug)] + /// Response future of [`RequestDecompression`] + pub struct RequestDecompressionFuture<F, B, E> + where + F: Future<Output = Result<Response<B>, E>>, + B: Body + { + #[pin] + kind: Kind<F, B, E>, + } +} + +pin_project! { + #[derive(Debug)] + #[project = StateProj] + enum Kind<F, B, E> + where + F: Future<Output = Result<Response<B>, E>>, + B: Body + { + Inner { + #[pin] + fut: F + }, + Unsupported { + #[pin] + accept: AcceptEncoding + }, + } +} + +impl<F, B, E> RequestDecompressionFuture<F, B, E> +where + F: Future<Output = Result<Response<B>, E>>, + B: Body, +{ + #[must_use] + pub(super) fn unsupported_encoding(accept: AcceptEncoding) -> Self { + Self { + kind: Kind::Unsupported { accept }, + } + } + + #[must_use] + pub(super) fn inner(fut: F) -> Self { + Self { + kind: Kind::Inner { fut }, + } + } +} + +impl<F, B, E> Future for RequestDecompressionFuture<F, B, E> +where + F: Future<Output = Result<Response<B>, E>>, + B: Body + Send + 'static, + B::Data: Buf + 'static, + B::Error: Into<BoxError> + 'static, +{ + type Output = Result<Response<UnsyncBoxBody<B::Data, BoxError>>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.project().kind.project() { + StateProj::Inner { fut } => fut.poll(cx).map_ok(|res| { + res.map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())) + }), + StateProj::Unsupported { accept } => { + let res = Response::builder() + .header( + header::ACCEPT_ENCODING, + accept + .to_header_value() + .unwrap_or(HeaderValue::from_static("identity")), + ) + .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body(UnsyncBoxBody::new( + Empty::new().map_err(Into::into).boxed_unsync(), + )) + .unwrap(); + Poll::Ready(Ok(res)) + } + } + } +} diff --git a/vendor/tower-http/src/decompression/request/layer.rs b/vendor/tower-http/src/decompression/request/layer.rs new file mode 100644 index 00000000..71200960 --- /dev/null +++ b/vendor/tower-http/src/decompression/request/layer.rs @@ -0,0 +1,105 @@ +use super::service::RequestDecompression; +use crate::compression_utils::AcceptEncoding; +use tower_layer::Layer; + +/// Decompresses request bodies and calls its underlying service. +/// +/// Transparently decompresses request bodies based on the `Content-Encoding` header. +/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` +/// status code will be returned with the accepted encodings in the `Accept-Encoding` header. +/// +/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type`. But +/// will call the underlying service with the unmodified request if the encoding is not supported. +/// This is disabled by default. +/// +/// See the [module docs](crate::decompression) for more details. +#[derive(Debug, Default, Clone)] +pub struct RequestDecompressionLayer { + accept: AcceptEncoding, + pass_through_unaccepted: bool, +} + +impl<S> Layer<S> for RequestDecompressionLayer { + type Service = RequestDecompression<S>; + + fn layer(&self, service: S) -> Self::Service { + RequestDecompression { + inner: service, + accept: self.accept, + pass_through_unaccepted: self.pass_through_unaccepted, + } + } +} + +impl RequestDecompressionLayer { + /// Creates a new `RequestDecompressionLayer`. + pub fn new() -> Self { + Default::default() + } + + /// Sets whether to support gzip encoding. + #[cfg(feature = "decompression-gzip")] + pub fn gzip(mut self, enable: bool) -> Self { + self.accept.set_gzip(enable); + self + } + + /// Sets whether to support Deflate encoding. + #[cfg(feature = "decompression-deflate")] + pub fn deflate(mut self, enable: bool) -> Self { + self.accept.set_deflate(enable); + self + } + + /// Sets whether to support Brotli encoding. + #[cfg(feature = "decompression-br")] + pub fn br(mut self, enable: bool) -> Self { + self.accept.set_br(enable); + self + } + + /// Sets whether to support Zstd encoding. + #[cfg(feature = "decompression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + + /// Disables support for gzip encoding. + /// + /// This method is available even if the `gzip` crate feature is disabled. + pub fn no_gzip(mut self) -> Self { + self.accept.set_gzip(false); + self + } + + /// Disables support for Deflate encoding. + /// + /// This method is available even if the `deflate` crate feature is disabled. + pub fn no_deflate(mut self) -> Self { + self.accept.set_deflate(false); + self + } + + /// Disables support for Brotli encoding. + /// + /// This method is available even if the `br` crate feature is disabled. + pub fn no_br(mut self) -> Self { + self.accept.set_br(false); + self + } + + /// Disables support for Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } + + /// Sets whether to pass through the request even when the encoding is not supported. + pub fn pass_through_unaccepted(mut self, enable: bool) -> Self { + self.pass_through_unaccepted = enable; + self + } +} diff --git a/vendor/tower-http/src/decompression/request/mod.rs b/vendor/tower-http/src/decompression/request/mod.rs new file mode 100644 index 00000000..da3d9409 --- /dev/null +++ b/vendor/tower-http/src/decompression/request/mod.rs @@ -0,0 +1,90 @@ +pub(super) mod future; +pub(super) mod layer; +pub(super) mod service; + +#[cfg(test)] +mod tests { + use super::service::RequestDecompression; + use crate::decompression::DecompressionBody; + use crate::test_helpers::Body; + use flate2::{write::GzEncoder, Compression}; + use http::{header, Request, Response, StatusCode}; + use http_body_util::BodyExt; + use std::{convert::Infallible, io::Write}; + use tower::{service_fn, Service, ServiceExt}; + + #[tokio::test] + async fn decompress_accepted_encoding() { + let req = request_gzip(); + let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed)); + let _ = svc.ready().await.unwrap().call(req).await.unwrap(); + } + + #[tokio::test] + async fn support_unencoded_body() { + let req = Request::builder().body(Body::from("Hello?")).unwrap(); + let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed)); + let _ = svc.ready().await.unwrap().call(req).await.unwrap(); + } + + #[tokio::test] + async fn unaccepted_content_encoding_returns_unsupported_media_type() { + let req = request_gzip(); + let mut svc = RequestDecompression::new(service_fn(should_not_be_called)).gzip(false); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, res.status()); + } + + #[tokio::test] + async fn pass_through_unsupported_encoding_when_enabled() { + let req = request_gzip(); + let mut svc = RequestDecompression::new(service_fn(assert_request_is_passed_through)) + .pass_through_unaccepted(true) + .gzip(false); + let _ = svc.ready().await.unwrap().call(req).await.unwrap(); + } + + async fn assert_request_is_decompressed( + req: Request<DecompressionBody<Body>>, + ) -> Result<Response<Body>, Infallible> { + let (parts, mut body) = req.into_parts(); + let body = read_body(&mut body).await; + + assert_eq!(body, b"Hello?"); + assert!(!parts.headers.contains_key(header::CONTENT_ENCODING)); + + Ok(Response::new(Body::from("Hello, World!"))) + } + + async fn assert_request_is_passed_through( + req: Request<DecompressionBody<Body>>, + ) -> Result<Response<Body>, Infallible> { + let (parts, mut body) = req.into_parts(); + let body = read_body(&mut body).await; + + assert_ne!(body, b"Hello?"); + assert!(parts.headers.contains_key(header::CONTENT_ENCODING)); + + Ok(Response::new(Body::empty())) + } + + async fn should_not_be_called( + _: Request<DecompressionBody<Body>>, + ) -> Result<Response<Body>, Infallible> { + panic!("Inner service should not be called"); + } + + fn request_gzip() -> Request<Body> { + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(b"Hello?").unwrap(); + let body = encoder.finish().unwrap(); + Request::builder() + .header(header::CONTENT_ENCODING, "gzip") + .body(Body::from(body)) + .unwrap() + } + + async fn read_body(body: &mut DecompressionBody<Body>) -> Vec<u8> { + body.collect().await.unwrap().to_bytes().to_vec() + } +} diff --git a/vendor/tower-http/src/decompression/request/service.rs b/vendor/tower-http/src/decompression/request/service.rs new file mode 100644 index 00000000..663436e5 --- /dev/null +++ b/vendor/tower-http/src/decompression/request/service.rs @@ -0,0 +1,198 @@ +use super::future::RequestDecompressionFuture as ResponseFuture; +use super::layer::RequestDecompressionLayer; +use crate::body::UnsyncBoxBody; +use crate::compression_utils::CompressionLevel; +use crate::{ + compression_utils::AcceptEncoding, decompression::body::BodyInner, + decompression::DecompressionBody, BoxError, +}; +use bytes::Buf; +use http::{header, Request, Response}; +use http_body::Body; +use std::task::{Context, Poll}; +use tower_service::Service; + +#[cfg(any( + feature = "decompression-gzip", + feature = "decompression-deflate", + feature = "decompression-br", + feature = "decompression-zstd", +))] +use crate::content_encoding::SupportedEncodings; + +/// Decompresses request bodies and calls its underlying service. +/// +/// Transparently decompresses request bodies based on the `Content-Encoding` header. +/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` +/// status code will be returned with the accepted encodings in the `Accept-Encoding` header. +/// +/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type` but +/// will call the underlying service with the unmodified request if the encoding is not supported. +/// This is disabled by default. +/// +/// See the [module docs](crate::decompression) for more details. +#[derive(Debug, Clone)] +pub struct RequestDecompression<S> { + pub(super) inner: S, + pub(super) accept: AcceptEncoding, + pub(super) pass_through_unaccepted: bool, +} + +impl<S, ReqBody, ResBody, D> Service<Request<ReqBody>> for RequestDecompression<S> +where + S: Service<Request<DecompressionBody<ReqBody>>, Response = Response<ResBody>>, + ReqBody: Body, + ResBody: Body<Data = D> + Send + 'static, + <ResBody as Body>::Error: Into<BoxError>, + D: Buf + 'static, +{ + type Response = Response<UnsyncBoxBody<D, BoxError>>; + type Error = S::Error; + type Future = ResponseFuture<S::Future, ResBody, S::Error>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let (mut parts, body) = req.into_parts(); + + let body = + if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) { + match entry.get().as_bytes() { + #[cfg(feature = "decompression-gzip")] + b"gzip" if self.accept.gzip() => { + entry.remove(); + parts.headers.remove(header::CONTENT_LENGTH); + BodyInner::gzip(crate::compression_utils::WrapBody::new( + body, + CompressionLevel::default(), + )) + } + #[cfg(feature = "decompression-deflate")] + b"deflate" if self.accept.deflate() => { + entry.remove(); + parts.headers.remove(header::CONTENT_LENGTH); + BodyInner::deflate(crate::compression_utils::WrapBody::new( + body, + CompressionLevel::default(), + )) + } + #[cfg(feature = "decompression-br")] + b"br" if self.accept.br() => { + entry.remove(); + parts.headers.remove(header::CONTENT_LENGTH); + BodyInner::brotli(crate::compression_utils::WrapBody::new( + body, + CompressionLevel::default(), + )) + } + #[cfg(feature = "decompression-zstd")] + b"zstd" if self.accept.zstd() => { + entry.remove(); + parts.headers.remove(header::CONTENT_LENGTH); + BodyInner::zstd(crate::compression_utils::WrapBody::new( + body, + CompressionLevel::default(), + )) + } + b"identity" => BodyInner::identity(body), + _ if self.pass_through_unaccepted => BodyInner::identity(body), + _ => return ResponseFuture::unsupported_encoding(self.accept), + } + } else { + BodyInner::identity(body) + }; + let body = DecompressionBody::new(body); + let req = Request::from_parts(parts, body); + ResponseFuture::inner(self.inner.call(req)) + } +} + +impl<S> RequestDecompression<S> { + /// Creates a new `RequestDecompression` wrapping the `service`. + pub fn new(service: S) -> Self { + Self { + inner: service, + accept: AcceptEncoding::default(), + pass_through_unaccepted: false, + } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `RequestDecompression` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> RequestDecompressionLayer { + RequestDecompressionLayer::new() + } + + /// Passes through the request even when the encoding is not supported. + /// + /// By default pass-through is disabled. + pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self { + self.pass_through_unaccepted = enabled; + self + } + + /// Sets whether to support gzip encoding. + #[cfg(feature = "decompression-gzip")] + pub fn gzip(mut self, enable: bool) -> Self { + self.accept.set_gzip(enable); + self + } + + /// Sets whether to support Deflate encoding. + #[cfg(feature = "decompression-deflate")] + pub fn deflate(mut self, enable: bool) -> Self { + self.accept.set_deflate(enable); + self + } + + /// Sets whether to support Brotli encoding. + #[cfg(feature = "decompression-br")] + pub fn br(mut self, enable: bool) -> Self { + self.accept.set_br(enable); + self + } + + /// Sets whether to support Zstd encoding. + #[cfg(feature = "decompression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + + /// Disables support for gzip encoding. + /// + /// This method is available even if the `gzip` crate feature is disabled. + pub fn no_gzip(mut self) -> Self { + self.accept.set_gzip(false); + self + } + + /// Disables support for Deflate encoding. + /// + /// This method is available even if the `deflate` crate feature is disabled. + pub fn no_deflate(mut self) -> Self { + self.accept.set_deflate(false); + self + } + + /// Disables support for Brotli encoding. + /// + /// This method is available even if the `br` crate feature is disabled. + pub fn no_br(mut self) -> Self { + self.accept.set_br(false); + self + } + + /// Disables support for Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } +} diff --git a/vendor/tower-http/src/decompression/service.rs b/vendor/tower-http/src/decompression/service.rs new file mode 100644 index 00000000..50e8ead5 --- /dev/null +++ b/vendor/tower-http/src/decompression/service.rs @@ -0,0 +1,127 @@ +use super::{DecompressionBody, DecompressionLayer, ResponseFuture}; +use crate::compression_utils::AcceptEncoding; +use http::{ + header::{self, ACCEPT_ENCODING}, + Request, Response, +}; +use http_body::Body; +use std::task::{Context, Poll}; +use tower_service::Service; + +/// Decompresses response bodies of the underlying service. +/// +/// This adds the `Accept-Encoding` header to requests and transparently decompresses response +/// bodies based on the `Content-Encoding` header. +/// +/// See the [module docs](crate::decompression) for more details. +#[derive(Debug, Clone)] +pub struct Decompression<S> { + pub(crate) inner: S, + pub(crate) accept: AcceptEncoding, +} + +impl<S> Decompression<S> { + /// Creates a new `Decompression` wrapping the `service`. + pub fn new(service: S) -> Self { + Self { + inner: service, + accept: AcceptEncoding::default(), + } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `Decompression` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> DecompressionLayer { + DecompressionLayer::new() + } + + /// Sets whether to request the gzip encoding. + #[cfg(feature = "decompression-gzip")] + pub fn gzip(mut self, enable: bool) -> Self { + self.accept.set_gzip(enable); + self + } + + /// Sets whether to request the Deflate encoding. + #[cfg(feature = "decompression-deflate")] + pub fn deflate(mut self, enable: bool) -> Self { + self.accept.set_deflate(enable); + self + } + + /// Sets whether to request the Brotli encoding. + #[cfg(feature = "decompression-br")] + pub fn br(mut self, enable: bool) -> Self { + self.accept.set_br(enable); + self + } + + /// Sets whether to request the Zstd encoding. + #[cfg(feature = "decompression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + + /// Disables the gzip encoding. + /// + /// This method is available even if the `gzip` crate feature is disabled. + pub fn no_gzip(mut self) -> Self { + self.accept.set_gzip(false); + self + } + + /// Disables the Deflate encoding. + /// + /// This method is available even if the `deflate` crate feature is disabled. + pub fn no_deflate(mut self) -> Self { + self.accept.set_deflate(false); + self + } + + /// Disables the Brotli encoding. + /// + /// This method is available even if the `br` crate feature is disabled. + pub fn no_br(mut self) -> Self { + self.accept.set_br(false); + self + } + + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Decompression<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + ResBody: Body, +{ + type Response = Response<DecompressionBody<ResBody>>; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) { + if let Some(accept) = self.accept.to_header_value() { + entry.insert(accept); + } + } + + ResponseFuture { + inner: self.inner.call(req), + accept: self.accept, + } + } +} diff --git a/vendor/tower-http/src/follow_redirect/mod.rs b/vendor/tower-http/src/follow_redirect/mod.rs new file mode 100644 index 00000000..a90e0825 --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/mod.rs @@ -0,0 +1,476 @@ +//! Middleware for following redirections. +//! +//! # Overview +//! +//! The [`FollowRedirect`] middleware retries requests with the inner [`Service`] to follow HTTP +//! redirections. +//! +//! The middleware tries to clone the original [`Request`] when making a redirected request. +//! However, since [`Extensions`][http::Extensions] are `!Clone`, any extensions set by outer +//! middleware will be discarded. Also, the request body cannot always be cloned. When the +//! original body is known to be empty by [`Body::size_hint`], the middleware uses `Default` +//! implementation of the body type to create a new request body. If you know that the body can be +//! cloned in some way, you can tell the middleware to clone it by configuring a [`policy`]. +//! +//! # Examples +//! +//! ## Basic usage +//! +//! ``` +//! use http::{Request, Response}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use tower::{Service, ServiceBuilder, ServiceExt}; +//! use tower_http::follow_redirect::{FollowRedirectLayer, RequestUri}; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), std::convert::Infallible> { +//! # let http_client = tower::service_fn(|req: Request<_>| async move { +//! # let dest = "https://www.rust-lang.org/"; +//! # let mut res = http::Response::builder(); +//! # if req.uri() != dest { +//! # res = res +//! # .status(http::StatusCode::MOVED_PERMANENTLY) +//! # .header(http::header::LOCATION, dest); +//! # } +//! # Ok::<_, std::convert::Infallible>(res.body(Full::<Bytes>::default()).unwrap()) +//! # }); +//! let mut client = ServiceBuilder::new() +//! .layer(FollowRedirectLayer::new()) +//! .service(http_client); +//! +//! let request = Request::builder() +//! .uri("https://rust-lang.org/") +//! .body(Full::<Bytes>::default()) +//! .unwrap(); +//! +//! let response = client.ready().await?.call(request).await?; +//! // Get the final request URI. +//! assert_eq!(response.extensions().get::<RequestUri>().unwrap().0, "https://www.rust-lang.org/"); +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Customizing the `Policy` +//! +//! You can use a [`Policy`] value to customize how the middleware handles redirections. +//! +//! ``` +//! use http::{Request, Response}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{Service, ServiceBuilder, ServiceExt}; +//! use tower_http::follow_redirect::{ +//! policy::{self, PolicyExt}, +//! FollowRedirectLayer, +//! }; +//! +//! #[derive(Debug)] +//! enum MyError { +//! TooManyRedirects, +//! Other(tower::BoxError), +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), MyError> { +//! # let http_client = +//! # tower::service_fn(|_: Request<Full<Bytes>>| async { Ok(Response::new(Full::<Bytes>::default())) }); +//! let policy = policy::Limited::new(10) // Set the maximum number of redirections to 10. +//! // Return an error when the limit was reached. +//! .or::<_, (), _>(policy::redirect_fn(|_| Err(MyError::TooManyRedirects))) +//! // Do not follow cross-origin redirections, and return the redirection responses as-is. +//! .and::<_, (), _>(policy::SameOrigin::new()); +//! +//! let mut client = ServiceBuilder::new() +//! .layer(FollowRedirectLayer::with_policy(policy)) +//! .map_err(MyError::Other) +//! .service(http_client); +//! +//! // ... +//! # let _ = client.ready().await?.call(Request::default()).await?; +//! # Ok(()) +//! # } +//! ``` + +pub mod policy; + +use self::policy::{Action, Attempt, Policy, Standard}; +use futures_util::future::Either; +use http::{ + header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION, + header::TRANSFER_ENCODING, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri, + Version, +}; +use http_body::Body; +use iri_string::types::{UriAbsoluteString, UriReferenceStr}; +use pin_project_lite::pin_project; +use std::{ + convert::TryFrom, + future::Future, + mem, + pin::Pin, + str, + task::{ready, Context, Poll}, +}; +use tower::util::Oneshot; +use tower_layer::Layer; +use tower_service::Service; + +/// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses. +/// +/// See the [module docs](self) for more details. +#[derive(Clone, Copy, Debug, Default)] +pub struct FollowRedirectLayer<P = Standard> { + policy: P, +} + +impl FollowRedirectLayer { + /// Create a new [`FollowRedirectLayer`] with a [`Standard`] redirection policy. + pub fn new() -> Self { + Self::default() + } +} + +impl<P> FollowRedirectLayer<P> { + /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`]. + pub fn with_policy(policy: P) -> Self { + FollowRedirectLayer { policy } + } +} + +impl<S, P> Layer<S> for FollowRedirectLayer<P> +where + S: Clone, + P: Clone, +{ + type Service = FollowRedirect<S, P>; + + fn layer(&self, inner: S) -> Self::Service { + FollowRedirect::with_policy(inner, self.policy.clone()) + } +} + +/// Middleware that retries requests with a [`Service`] to follow redirection responses. +/// +/// See the [module docs](self) for more details. +#[derive(Clone, Copy, Debug)] +pub struct FollowRedirect<S, P = Standard> { + inner: S, + policy: P, +} + +impl<S> FollowRedirect<S> { + /// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy. + pub fn new(inner: S) -> Self { + Self::with_policy(inner, Standard::default()) + } + + /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> FollowRedirectLayer { + FollowRedirectLayer::new() + } +} + +impl<S, P> FollowRedirect<S, P> +where + P: Clone, +{ + /// Create a new [`FollowRedirect`] with the given redirection [`Policy`]. + pub fn with_policy(inner: S, policy: P) -> Self { + FollowRedirect { inner, policy } + } + + /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware + /// with the given redirection [`Policy`]. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> { + FollowRedirectLayer::with_policy(policy) + } + + define_inner_service_accessors!(); +} + +impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone, + ReqBody: Body + Default, + P: Policy<ReqBody, S::Error> + Clone, +{ + type Response = Response<ResBody>; + type Error = S::Error; + type Future = ResponseFuture<S, ReqBody, P>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + let service = self.inner.clone(); + let mut service = mem::replace(&mut self.inner, service); + let mut policy = self.policy.clone(); + let mut body = BodyRepr::None; + body.try_clone_from(req.body(), &policy); + policy.on_request(&mut req); + ResponseFuture { + method: req.method().clone(), + uri: req.uri().clone(), + version: req.version(), + headers: req.headers().clone(), + body, + future: Either::Left(service.call(req)), + service, + policy, + } + } +} + +pin_project! { + /// Response future for [`FollowRedirect`]. + #[derive(Debug)] + pub struct ResponseFuture<S, B, P> + where + S: Service<Request<B>>, + { + #[pin] + future: Either<S::Future, Oneshot<S, Request<B>>>, + service: S, + policy: P, + method: Method, + uri: Uri, + version: Version, + headers: HeaderMap<HeaderValue>, + body: BodyRepr<B>, + } +} + +impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone, + ReqBody: Body + Default, + P: Policy<ReqBody, S::Error>, +{ + type Output = Result<Response<ResBody>, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + let mut res = ready!(this.future.as_mut().poll(cx)?); + res.extensions_mut().insert(RequestUri(this.uri.clone())); + + let drop_payload_headers = |headers: &mut HeaderMap| { + for header in &[ + CONTENT_TYPE, + CONTENT_LENGTH, + CONTENT_ENCODING, + TRANSFER_ENCODING, + ] { + headers.remove(header); + } + }; + match res.status() { + StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => { + // User agents MAY change the request method from POST to GET + // (RFC 7231 section 6.4.2. and 6.4.3.). + if *this.method == Method::POST { + *this.method = Method::GET; + *this.body = BodyRepr::Empty; + drop_payload_headers(this.headers); + } + } + StatusCode::SEE_OTHER => { + // A user agent can perform a GET or HEAD request (RFC 7231 section 6.4.4.). + if *this.method != Method::HEAD { + *this.method = Method::GET; + } + *this.body = BodyRepr::Empty; + drop_payload_headers(this.headers); + } + StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {} + _ => return Poll::Ready(Ok(res)), + }; + + let body = if let Some(body) = this.body.take() { + body + } else { + return Poll::Ready(Ok(res)); + }; + + let location = res + .headers() + .get(&LOCATION) + .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri)); + let location = if let Some(loc) = location { + loc + } else { + return Poll::Ready(Ok(res)); + }; + + let attempt = Attempt { + status: res.status(), + location: &location, + previous: this.uri, + }; + match this.policy.redirect(&attempt)? { + Action::Follow => { + *this.uri = location; + this.body.try_clone_from(&body, &this.policy); + + let mut req = Request::new(body); + *req.uri_mut() = this.uri.clone(); + *req.method_mut() = this.method.clone(); + *req.version_mut() = *this.version; + *req.headers_mut() = this.headers.clone(); + this.policy.on_request(&mut req); + this.future + .set(Either::Right(Oneshot::new(this.service.clone(), req))); + + cx.waker().wake_by_ref(); + Poll::Pending + } + Action::Stop => Poll::Ready(Ok(res)), + } + } +} + +/// Response [`Extensions`][http::Extensions] value that represents the effective request URI of +/// a response returned by a [`FollowRedirect`] middleware. +/// +/// The value differs from the original request's effective URI if the middleware has followed +/// redirections. +#[derive(Clone)] +pub struct RequestUri(pub Uri); + +#[derive(Debug)] +enum BodyRepr<B> { + Some(B), + Empty, + None, +} + +impl<B> BodyRepr<B> +where + B: Body + Default, +{ + fn take(&mut self) -> Option<B> { + match mem::replace(self, BodyRepr::None) { + BodyRepr::Some(body) => Some(body), + BodyRepr::Empty => { + *self = BodyRepr::Empty; + Some(B::default()) + } + BodyRepr::None => None, + } + } + + fn try_clone_from<P, E>(&mut self, body: &B, policy: &P) + where + P: Policy<B, E>, + { + match self { + BodyRepr::Some(_) | BodyRepr::Empty => {} + BodyRepr::None => { + if let Some(body) = clone_body(policy, body) { + *self = BodyRepr::Some(body); + } + } + } + } +} + +fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B> +where + P: Policy<B, E>, + B: Body + Default, +{ + if body.size_hint().exact() == Some(0) { + Some(B::default()) + } else { + policy.clone_body(body) + } +} + +/// Try to resolve a URI reference `relative` against a base URI `base`. +fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> { + let relative = UriReferenceStr::new(relative).ok()?; + let base = UriAbsoluteString::try_from(base.to_string()).ok()?; + let uri = relative.resolve_against(&base).to_string(); + Uri::try_from(uri).ok() +} + +#[cfg(test)] +mod tests { + use super::{policy::*, *}; + use crate::test_helpers::Body; + use http::header::LOCATION; + use std::convert::Infallible; + use tower::{ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn follows() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::with_policy(Action::Follow)) + .buffer(1) + .service_fn(handle); + let req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(*res.body(), 0); + assert_eq!( + res.extensions().get::<RequestUri>().unwrap().0, + "http://example.com/0" + ); + } + + #[tokio::test] + async fn stops() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::with_policy(Action::Stop)) + .buffer(1) + .service_fn(handle); + let req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(*res.body(), 42); + assert_eq!( + res.extensions().get::<RequestUri>().unwrap().0, + "http://example.com/42" + ); + } + + #[tokio::test] + async fn limited() { + let svc = ServiceBuilder::new() + .layer(FollowRedirectLayer::with_policy(Limited::new(10))) + .buffer(1) + .service_fn(handle); + let req = Request::builder() + .uri("http://example.com/42") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(*res.body(), 42 - 10); + assert_eq!( + res.extensions().get::<RequestUri>().unwrap().0, + "http://example.com/32" + ); + } + + /// A server with an endpoint `GET /{n}` which redirects to `/{n-1}` unless `n` equals zero, + /// returning `n` as the response body. + async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> { + let n: u64 = req.uri().path()[1..].parse().unwrap(); + let mut res = Response::builder(); + if n > 0 { + res = res + .status(StatusCode::MOVED_PERMANENTLY) + .header(LOCATION, format!("/{}", n - 1)); + } + Ok::<_, Infallible>(res.body(n).unwrap()) + } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/and.rs b/vendor/tower-http/src/follow_redirect/policy/and.rs new file mode 100644 index 00000000..69d2b7da --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/and.rs @@ -0,0 +1,118 @@ +use super::{Action, Attempt, Policy}; +use http::Request; + +/// A redirection [`Policy`] that combines the results of two `Policy`s. +/// +/// See [`PolicyExt::and`][super::PolicyExt::and] for more details. +#[derive(Clone, Copy, Debug, Default)] +pub struct And<A, B> { + a: A, + b: B, +} + +impl<A, B> And<A, B> { + pub(crate) fn new<Bd, E>(a: A, b: B) -> Self + where + A: Policy<Bd, E>, + B: Policy<Bd, E>, + { + And { a, b } + } +} + +impl<Bd, E, A, B> Policy<Bd, E> for And<A, B> +where + A: Policy<Bd, E>, + B: Policy<Bd, E>, +{ + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + match self.a.redirect(attempt) { + Ok(Action::Follow) => self.b.redirect(attempt), + a => a, + } + } + + fn on_request(&mut self, request: &mut Request<Bd>) { + self.a.on_request(request); + self.b.on_request(request); + } + + fn clone_body(&self, body: &Bd) -> Option<Bd> { + self.a.clone_body(body).or_else(|| self.b.clone_body(body)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::Uri; + + struct Taint<P> { + policy: P, + used: bool, + } + + impl<P> Taint<P> { + fn new(policy: P) -> Self { + Taint { + policy, + used: false, + } + } + } + + impl<B, E, P> Policy<B, E> for Taint<P> + where + P: Policy<B, E>, + { + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + self.used = true; + self.policy.redirect(attempt) + } + } + + #[test] + fn redirect() { + let attempt = Attempt { + status: Default::default(), + location: &Uri::from_static("*"), + previous: &Uri::from_static("*"), + }; + + let mut a = Taint::new(Action::Follow); + let mut b = Taint::new(Action::Follow); + let mut policy = And::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + assert!(a.used); + assert!(b.used); + + let mut a = Taint::new(Action::Stop); + let mut b = Taint::new(Action::Follow); + let mut policy = And::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_stop()); + assert!(a.used); + assert!(!b.used); // short-circuiting + + let mut a = Taint::new(Action::Follow); + let mut b = Taint::new(Action::Stop); + let mut policy = And::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_stop()); + assert!(a.used); + assert!(b.used); + + let mut a = Taint::new(Action::Stop); + let mut b = Taint::new(Action::Stop); + let mut policy = And::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_stop()); + assert!(a.used); + assert!(!b.used); + } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/clone_body_fn.rs b/vendor/tower-http/src/follow_redirect/policy/clone_body_fn.rs new file mode 100644 index 00000000..d7d7cb7c --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/clone_body_fn.rs @@ -0,0 +1,42 @@ +use super::{Action, Attempt, Policy}; +use std::fmt; + +/// A redirection [`Policy`] created from a closure. +/// +/// See [`clone_body_fn`] for more details. +#[derive(Clone, Copy)] +pub struct CloneBodyFn<F> { + f: F, +} + +impl<F> fmt::Debug for CloneBodyFn<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CloneBodyFn") + .field("f", &std::any::type_name::<F>()) + .finish() + } +} + +impl<F, B, E> Policy<B, E> for CloneBodyFn<F> +where + F: Fn(&B) -> Option<B>, +{ + fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> { + Ok(Action::Follow) + } + + fn clone_body(&self, body: &B) -> Option<B> { + (self.f)(body) + } +} + +/// Create a new redirection [`Policy`] from a closure `F: Fn(&B) -> Option<B>`. +/// +/// [`clone_body`][Policy::clone_body] method of the returned `Policy` delegates to the wrapped +/// closure and [`redirect`][Policy::redirect] method always returns [`Action::Follow`]. +pub fn clone_body_fn<F, B>(f: F) -> CloneBodyFn<F> +where + F: Fn(&B) -> Option<B>, +{ + CloneBodyFn { f } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/filter_credentials.rs b/vendor/tower-http/src/follow_redirect/policy/filter_credentials.rs new file mode 100644 index 00000000..fea80f11 --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/filter_credentials.rs @@ -0,0 +1,161 @@ +use super::{eq_origin, Action, Attempt, Policy}; +use http::{ + header::{self, HeaderName}, + Request, +}; + +/// A redirection [`Policy`] that removes credentials from requests in redirections. +#[derive(Clone, Debug)] +pub struct FilterCredentials { + block_cross_origin: bool, + block_any: bool, + remove_blocklisted: bool, + remove_all: bool, + blocked: bool, +} + +const BLOCKLIST: &[HeaderName] = &[ + header::AUTHORIZATION, + header::COOKIE, + header::PROXY_AUTHORIZATION, +]; + +impl FilterCredentials { + /// Create a new [`FilterCredentials`] that removes blocklisted request headers in cross-origin + /// redirections. + pub fn new() -> Self { + FilterCredentials { + block_cross_origin: true, + block_any: false, + remove_blocklisted: true, + remove_all: false, + blocked: false, + } + } + + /// Configure `self` to mark cross-origin redirections as "blocked". + pub fn block_cross_origin(mut self, enable: bool) -> Self { + self.block_cross_origin = enable; + self + } + + /// Configure `self` to mark every redirection as "blocked". + pub fn block_any(mut self) -> Self { + self.block_any = true; + self + } + + /// Configure `self` to mark no redirections as "blocked". + pub fn block_none(mut self) -> Self { + self.block_any = false; + self.block_cross_origin(false) + } + + /// Configure `self` to remove blocklisted headers in "blocked" redirections. + /// + /// The blocklist includes the following headers: + /// + /// - `Authorization` + /// - `Cookie` + /// - `Proxy-Authorization` + pub fn remove_blocklisted(mut self, enable: bool) -> Self { + self.remove_blocklisted = enable; + self + } + + /// Configure `self` to remove all headers in "blocked" redirections. + pub fn remove_all(mut self) -> Self { + self.remove_all = true; + self + } + + /// Configure `self` to remove no headers in "blocked" redirections. + pub fn remove_none(mut self) -> Self { + self.remove_all = false; + self.remove_blocklisted(false) + } +} + +impl Default for FilterCredentials { + fn default() -> Self { + Self::new() + } +} + +impl<B, E> Policy<B, E> for FilterCredentials { + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + self.blocked = self.block_any + || (self.block_cross_origin && !eq_origin(attempt.previous(), attempt.location())); + Ok(Action::Follow) + } + + fn on_request(&mut self, request: &mut Request<B>) { + if self.blocked { + let headers = request.headers_mut(); + if self.remove_all { + headers.clear(); + } else if self.remove_blocklisted { + for key in BLOCKLIST { + headers.remove(key); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::Uri; + + #[test] + fn works() { + let mut policy = FilterCredentials::default(); + + let initial = Uri::from_static("http://example.com/old"); + let same_origin = Uri::from_static("http://example.com/new"); + let cross_origin = Uri::from_static("https://example.com/new"); + + let mut request = Request::builder() + .uri(initial) + .header(header::COOKIE, "42") + .body(()) + .unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert!(request.headers().contains_key(header::COOKIE)); + + let attempt = Attempt { + status: Default::default(), + location: &same_origin, + previous: request.uri(), + }; + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + + let mut request = Request::builder() + .uri(same_origin) + .header(header::COOKIE, "42") + .body(()) + .unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert!(request.headers().contains_key(header::COOKIE)); + + let attempt = Attempt { + status: Default::default(), + location: &cross_origin, + previous: request.uri(), + }; + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + + let mut request = Request::builder() + .uri(cross_origin) + .header(header::COOKIE, "42") + .body(()) + .unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + assert!(!request.headers().contains_key(header::COOKIE)); + } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/limited.rs b/vendor/tower-http/src/follow_redirect/policy/limited.rs new file mode 100644 index 00000000..a81b0d79 --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/limited.rs @@ -0,0 +1,74 @@ +use super::{Action, Attempt, Policy}; + +/// A redirection [`Policy`] that limits the number of successive redirections. +#[derive(Clone, Copy, Debug)] +pub struct Limited { + remaining: usize, +} + +impl Limited { + /// Create a new [`Limited`] with a limit of `max` redirections. + pub fn new(max: usize) -> Self { + Limited { remaining: max } + } +} + +impl Default for Limited { + /// Returns the default [`Limited`] with a limit of `20` redirections. + fn default() -> Self { + // This is the (default) limit of Firefox and the Fetch API. + // https://hg.mozilla.org/mozilla-central/file/6264f13d54a1caa4f5b60303617a819efd91b8ee/modules/libpref/init/all.js#l1371 + // https://fetch.spec.whatwg.org/#http-redirect-fetch + Limited::new(20) + } +} + +impl<B, E> Policy<B, E> for Limited { + fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> { + if self.remaining > 0 { + self.remaining -= 1; + Ok(Action::Follow) + } else { + Ok(Action::Stop) + } + } +} + +#[cfg(test)] +mod tests { + use http::{Request, Uri}; + + use super::*; + + #[test] + fn works() { + let uri = Uri::from_static("https://example.com/"); + let mut policy = Limited::new(2); + + for _ in 0..2 { + let mut request = Request::builder().uri(uri.clone()).body(()).unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + + let attempt = Attempt { + status: Default::default(), + location: &uri, + previous: &uri, + }; + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + } + + let mut request = Request::builder().uri(uri.clone()).body(()).unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + + let attempt = Attempt { + status: Default::default(), + location: &uri, + previous: &uri, + }; + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_stop()); + } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/mod.rs b/vendor/tower-http/src/follow_redirect/policy/mod.rs new file mode 100644 index 00000000..8e5d39ce --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/mod.rs @@ -0,0 +1,316 @@ +//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware. + +mod and; +mod clone_body_fn; +mod filter_credentials; +mod limited; +mod or; +mod redirect_fn; +mod same_origin; + +pub use self::{ + and::And, + clone_body_fn::{clone_body_fn, CloneBodyFn}, + filter_credentials::FilterCredentials, + limited::Limited, + or::Or, + redirect_fn::{redirect_fn, RedirectFn}, + same_origin::SameOrigin, +}; + +use http::{uri::Scheme, Request, StatusCode, Uri}; + +/// Trait for the policy on handling redirection responses. +/// +/// # Example +/// +/// Detecting a cyclic redirection: +/// +/// ``` +/// use http::{Request, Uri}; +/// use std::collections::HashSet; +/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy}; +/// +/// #[derive(Clone)] +/// pub struct DetectCycle { +/// uris: HashSet<Uri>, +/// } +/// +/// impl<B, E> Policy<B, E> for DetectCycle { +/// fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { +/// if self.uris.contains(attempt.location()) { +/// Ok(Action::Stop) +/// } else { +/// self.uris.insert(attempt.previous().clone()); +/// Ok(Action::Follow) +/// } +/// } +/// } +/// ``` +pub trait Policy<B, E> { + /// Invoked when the service received a response with a redirection status code (`3xx`). + /// + /// This method returns an [`Action`] which indicates whether the service should follow + /// the redirection. + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>; + + /// Invoked right before the service makes a request, regardless of whether it is redirected + /// or not. + /// + /// This can for example be used to remove sensitive headers from the request + /// or prepare the request in other ways. + /// + /// The default implementation does nothing. + fn on_request(&mut self, _request: &mut Request<B>) {} + + /// Try to clone a request body before the service makes a redirected request. + /// + /// If the request body cannot be cloned, return `None`. + /// + /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero, + /// in which case `B::default()` will be used to create a new request body. + /// + /// The default implementation returns `None`. + fn clone_body(&self, _body: &B) -> Option<B> { + None + } +} + +impl<B, E, P> Policy<B, E> for &mut P +where + P: Policy<B, E> + ?Sized, +{ + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + (**self).redirect(attempt) + } + + fn on_request(&mut self, request: &mut Request<B>) { + (**self).on_request(request) + } + + fn clone_body(&self, body: &B) -> Option<B> { + (**self).clone_body(body) + } +} + +impl<B, E, P> Policy<B, E> for Box<P> +where + P: Policy<B, E> + ?Sized, +{ + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + (**self).redirect(attempt) + } + + fn on_request(&mut self, request: &mut Request<B>) { + (**self).on_request(request) + } + + fn clone_body(&self, body: &B) -> Option<B> { + (**self).clone_body(body) + } +} + +/// An extension trait for `Policy` that provides additional adapters. +pub trait PolicyExt { + /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return + /// `Action::Follow`. + /// + /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body + /// with both policies. + /// + /// # Example + /// + /// ``` + /// use bytes::Bytes; + /// use http_body_util::Full; + /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt}; + /// + /// enum MyBody { + /// Bytes(Bytes), + /// Full(Full<Bytes>), + /// } + /// + /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| { + /// if let MyBody::Bytes(buf) = body { + /// Some(MyBody::Bytes(buf.clone())) + /// } else { + /// None + /// } + /// })); + /// ``` + fn and<P, B, E>(self, other: P) -> And<Self, P> + where + Self: Policy<B, E> + Sized, + P: Policy<B, E>; + + /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns + /// `Action::Follow`. + /// + /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body + /// with both policies. + /// + /// # Example + /// + /// ``` + /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt}; + /// + /// #[derive(Clone)] + /// enum MyError { + /// TooManyRedirects, + /// // ... + /// } + /// + /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects)); + /// ``` + fn or<P, B, E>(self, other: P) -> Or<Self, P> + where + Self: Policy<B, E> + Sized, + P: Policy<B, E>; +} + +impl<T> PolicyExt for T +where + T: ?Sized, +{ + fn and<P, B, E>(self, other: P) -> And<Self, P> + where + Self: Policy<B, E> + Sized, + P: Policy<B, E>, + { + And::new(self, other) + } + + fn or<P, B, E>(self, other: P) -> Or<Self, P> + where + Self: Policy<B, E> + Sized, + P: Policy<B, E>, + { + Or::new(self, other) + } +} + +/// A redirection [`Policy`] with a reasonable set of standard behavior. +/// +/// This policy limits the number of successive redirections ([`Limited`]) +/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]). +pub type Standard = And<Limited, FilterCredentials>; + +/// A type that holds information on a redirection attempt. +pub struct Attempt<'a> { + pub(crate) status: StatusCode, + pub(crate) location: &'a Uri, + pub(crate) previous: &'a Uri, +} + +impl<'a> Attempt<'a> { + /// Returns the redirection response. + pub fn status(&self) -> StatusCode { + self.status + } + + /// Returns the destination URI of the redirection. + pub fn location(&self) -> &'a Uri { + self.location + } + + /// Returns the URI of the original request. + pub fn previous(&self) -> &'a Uri { + self.previous + } +} + +/// A value returned by [`Policy::redirect`] which indicates the action +/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response. +#[derive(Clone, Copy, Debug)] +pub enum Action { + /// Follow the redirection. + Follow, + /// Do not follow the redirection, and return the redirection response as-is. + Stop, +} + +impl Action { + /// Returns `true` if the `Action` is a `Follow` value. + pub fn is_follow(&self) -> bool { + if let Action::Follow = self { + true + } else { + false + } + } + + /// Returns `true` if the `Action` is a `Stop` value. + pub fn is_stop(&self) -> bool { + if let Action::Stop = self { + true + } else { + false + } + } +} + +impl<B, E> Policy<B, E> for Action { + fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> { + Ok(*self) + } +} + +impl<B, E> Policy<B, E> for Result<Action, E> +where + E: Clone, +{ + fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> { + self.clone() + } +} + +/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5. +fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool { + let default_port = match (lhs.scheme(), rhs.scheme()) { + (Some(l), Some(r)) if l == r => { + if l == &Scheme::HTTP { + 80 + } else if l == &Scheme::HTTPS { + 443 + } else { + return false; + } + } + _ => return false, + }; + match (lhs.host(), rhs.host()) { + (Some(l), Some(r)) if l == r => {} + _ => return false, + } + lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn eq_origin_works() { + assert!(eq_origin( + &Uri::from_static("https://example.com/1"), + &Uri::from_static("https://example.com/2") + )); + assert!(eq_origin( + &Uri::from_static("https://example.com:443/"), + &Uri::from_static("https://example.com/") + )); + assert!(eq_origin( + &Uri::from_static("https://example.com/"), + &Uri::from_static("https://user@example.com/") + )); + + assert!(!eq_origin( + &Uri::from_static("https://example.com/"), + &Uri::from_static("https://www.example.com/") + )); + assert!(!eq_origin( + &Uri::from_static("https://example.com/"), + &Uri::from_static("http://example.com/") + )); + } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/or.rs b/vendor/tower-http/src/follow_redirect/policy/or.rs new file mode 100644 index 00000000..858e57bd --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/or.rs @@ -0,0 +1,118 @@ +use super::{Action, Attempt, Policy}; +use http::Request; + +/// A redirection [`Policy`] that combines the results of two `Policy`s. +/// +/// See [`PolicyExt::or`][super::PolicyExt::or] for more details. +#[derive(Clone, Copy, Debug, Default)] +pub struct Or<A, B> { + a: A, + b: B, +} + +impl<A, B> Or<A, B> { + pub(crate) fn new<Bd, E>(a: A, b: B) -> Self + where + A: Policy<Bd, E>, + B: Policy<Bd, E>, + { + Or { a, b } + } +} + +impl<Bd, E, A, B> Policy<Bd, E> for Or<A, B> +where + A: Policy<Bd, E>, + B: Policy<Bd, E>, +{ + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + match self.a.redirect(attempt) { + Ok(Action::Stop) | Err(_) => self.b.redirect(attempt), + a => a, + } + } + + fn on_request(&mut self, request: &mut Request<Bd>) { + self.a.on_request(request); + self.b.on_request(request); + } + + fn clone_body(&self, body: &Bd) -> Option<Bd> { + self.a.clone_body(body).or_else(|| self.b.clone_body(body)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::Uri; + + struct Taint<P> { + policy: P, + used: bool, + } + + impl<P> Taint<P> { + fn new(policy: P) -> Self { + Taint { + policy, + used: false, + } + } + } + + impl<B, E, P> Policy<B, E> for Taint<P> + where + P: Policy<B, E>, + { + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + self.used = true; + self.policy.redirect(attempt) + } + } + + #[test] + fn redirect() { + let attempt = Attempt { + status: Default::default(), + location: &Uri::from_static("*"), + previous: &Uri::from_static("*"), + }; + + let mut a = Taint::new(Action::Follow); + let mut b = Taint::new(Action::Follow); + let mut policy = Or::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + assert!(a.used); + assert!(!b.used); // short-circuiting + + let mut a = Taint::new(Action::Stop); + let mut b = Taint::new(Action::Follow); + let mut policy = Or::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + assert!(a.used); + assert!(b.used); + + let mut a = Taint::new(Action::Follow); + let mut b = Taint::new(Action::Stop); + let mut policy = Or::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + assert!(a.used); + assert!(!b.used); + + let mut a = Taint::new(Action::Stop); + let mut b = Taint::new(Action::Stop); + let mut policy = Or::new::<(), ()>(&mut a, &mut b); + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_stop()); + assert!(a.used); + assert!(b.used); + } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/redirect_fn.rs b/vendor/tower-http/src/follow_redirect/policy/redirect_fn.rs new file mode 100644 index 00000000..a16593ac --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/redirect_fn.rs @@ -0,0 +1,39 @@ +use super::{Action, Attempt, Policy}; +use std::fmt; + +/// A redirection [`Policy`] created from a closure. +/// +/// See [`redirect_fn`] for more details. +#[derive(Clone, Copy)] +pub struct RedirectFn<F> { + f: F, +} + +impl<F> fmt::Debug for RedirectFn<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RedirectFn") + .field("f", &std::any::type_name::<F>()) + .finish() + } +} + +impl<B, E, F> Policy<B, E> for RedirectFn<F> +where + F: FnMut(&Attempt<'_>) -> Result<Action, E>, +{ + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + (self.f)(attempt) + } +} + +/// Create a new redirection [`Policy`] from a closure +/// `F: FnMut(&Attempt<'_>) -> Result<Action, E>`. +/// +/// [`redirect`][Policy::redirect] method of the returned `Policy` delegates to +/// the wrapped closure. +pub fn redirect_fn<F, E>(f: F) -> RedirectFn<F> +where + F: FnMut(&Attempt<'_>) -> Result<Action, E>, +{ + RedirectFn { f } +} diff --git a/vendor/tower-http/src/follow_redirect/policy/same_origin.rs b/vendor/tower-http/src/follow_redirect/policy/same_origin.rs new file mode 100644 index 00000000..cf7b7b19 --- /dev/null +++ b/vendor/tower-http/src/follow_redirect/policy/same_origin.rs @@ -0,0 +1,70 @@ +use super::{eq_origin, Action, Attempt, Policy}; +use std::fmt; + +/// A redirection [`Policy`] that stops cross-origin redirections. +#[derive(Clone, Copy, Default)] +pub struct SameOrigin { + _priv: (), +} + +impl SameOrigin { + /// Create a new [`SameOrigin`]. + pub fn new() -> Self { + Self::default() + } +} + +impl fmt::Debug for SameOrigin { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SameOrigin").finish() + } +} + +impl<B, E> Policy<B, E> for SameOrigin { + fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> { + if eq_origin(attempt.previous(), attempt.location()) { + Ok(Action::Follow) + } else { + Ok(Action::Stop) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::{Request, Uri}; + + #[test] + fn works() { + let mut policy = SameOrigin::default(); + + let initial = Uri::from_static("http://example.com/old"); + let same_origin = Uri::from_static("http://example.com/new"); + let cross_origin = Uri::from_static("https://example.com/new"); + + let mut request = Request::builder().uri(initial).body(()).unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + + let attempt = Attempt { + status: Default::default(), + location: &same_origin, + previous: request.uri(), + }; + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_follow()); + + let mut request = Request::builder().uri(same_origin).body(()).unwrap(); + Policy::<(), ()>::on_request(&mut policy, &mut request); + + let attempt = Attempt { + status: Default::default(), + location: &cross_origin, + previous: request.uri(), + }; + assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) + .unwrap() + .is_stop()); + } +} diff --git a/vendor/tower-http/src/lib.rs b/vendor/tower-http/src/lib.rs new file mode 100644 index 00000000..372bef8c --- /dev/null +++ b/vendor/tower-http/src/lib.rs @@ -0,0 +1,373 @@ +//! `async fn(HttpRequest) -> Result<HttpResponse, Error>` +//! +//! # Overview +//! +//! tower-http is a library that provides HTTP-specific middleware and utilities built on top of +//! [tower]. +//! +//! All middleware uses the [http] and [http-body] crates as the HTTP abstractions. That means +//! they're compatible with any library or framework that also uses those crates, such as +//! [hyper], [tonic], and [warp]. +//! +//! # Example server +//! +//! This example shows how to apply middleware from tower-http to a [`Service`] and then run +//! that service using [hyper]. +//! +//! ```rust,no_run +//! use tower_http::{ +//! add_extension::AddExtensionLayer, +//! compression::CompressionLayer, +//! propagate_header::PropagateHeaderLayer, +//! sensitive_headers::SetSensitiveRequestHeadersLayer, +//! set_header::SetResponseHeaderLayer, +//! trace::TraceLayer, +//! validate_request::ValidateRequestHeaderLayer, +//! }; +//! use tower::{ServiceBuilder, service_fn, BoxError}; +//! use http::{Request, Response, header::{HeaderName, CONTENT_TYPE, AUTHORIZATION}}; +//! use std::{sync::Arc, net::SocketAddr, convert::Infallible, iter::once}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! # struct DatabaseConnectionPool; +//! # impl DatabaseConnectionPool { +//! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool } +//! # } +//! # fn content_length_from_response<B>(_: &http::Response<B>) -> Option<http::HeaderValue> { None } +//! # async fn update_in_flight_requests_metric(count: usize) {} +//! +//! // Our request handler. This is where we would implement the application logic +//! // for responding to HTTP requests... +//! async fn handler(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! // ... +//! # todo!() +//! } +//! +//! // Shared state across all request handlers --- in this case, a pool of database connections. +//! struct State { +//! pool: DatabaseConnectionPool, +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! // Construct the shared state. +//! let state = State { +//! pool: DatabaseConnectionPool::new(), +//! }; +//! +//! // Use tower's `ServiceBuilder` API to build a stack of tower middleware +//! // wrapping our request handler. +//! let service = ServiceBuilder::new() +//! // Mark the `Authorization` request header as sensitive so it doesn't show in logs +//! .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION))) +//! // High level logging of requests and responses +//! .layer(TraceLayer::new_for_http()) +//! // Share an `Arc<State>` with all requests +//! .layer(AddExtensionLayer::new(Arc::new(state))) +//! // Compress responses +//! .layer(CompressionLayer::new()) +//! // Propagate `X-Request-Id`s from requests to responses +//! .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id"))) +//! // If the response has a known size set the `Content-Length` header +//! .layer(SetResponseHeaderLayer::overriding(CONTENT_TYPE, content_length_from_response)) +//! // Authorize requests using a token +//! .layer(ValidateRequestHeaderLayer::bearer("passwordlol")) +//! // Accept only application/json, application/* and */* in a request's ACCEPT header +//! .layer(ValidateRequestHeaderLayer::accept("application/json")) +//! // Wrap a `Service` in our middleware stack +//! .service_fn(handler); +//! # let mut service = service; +//! # tower::Service::call(&mut service, Request::new(Full::default())); +//! } +//! ``` +//! +//! Keep in mind that while this example uses [hyper], tower-http supports any HTTP +//! client/server implementation that uses the [http] and [http-body] crates. +//! +//! # Example client +//! +//! tower-http middleware can also be applied to HTTP clients: +//! +//! ```rust,no_run +//! use tower_http::{ +//! decompression::DecompressionLayer, +//! set_header::SetRequestHeaderLayer, +//! trace::TraceLayer, +//! classify::StatusInRangeAsFailures, +//! }; +//! use tower::{ServiceBuilder, Service, ServiceExt}; +//! use hyper_util::{rt::TokioExecutor, client::legacy::Client}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use http::{Request, HeaderValue, header::USER_AGENT}; +//! +//! #[tokio::main] +//! async fn main() { +//! let client = Client::builder(TokioExecutor::new()).build_http(); +//! let mut client = ServiceBuilder::new() +//! // Add tracing and consider server errors and client +//! // errors as failures. +//! .layer(TraceLayer::new( +//! StatusInRangeAsFailures::new(400..=599).into_make_classifier() +//! )) +//! // Set a `User-Agent` header on all requests. +//! .layer(SetRequestHeaderLayer::overriding( +//! USER_AGENT, +//! HeaderValue::from_static("tower-http demo") +//! )) +//! // Decompress response bodies +//! .layer(DecompressionLayer::new()) +//! // Wrap a `Client` in our middleware stack. +//! // This is possible because `Client` implements +//! // `tower::Service`. +//! .service(client); +//! +//! // Make a request +//! let request = Request::builder() +//! .uri("http://example.com") +//! .body(Full::<Bytes>::default()) +//! .unwrap(); +//! +//! let response = client +//! .ready() +//! .await +//! .unwrap() +//! .call(request) +//! .await +//! .unwrap(); +//! } +//! ``` +//! +//! # Feature Flags +//! +//! All middleware are disabled by default and can be enabled using [cargo features]. +//! +//! For example, to enable the [`Trace`] middleware, add the "trace" feature flag in +//! your `Cargo.toml`: +//! +//! ```toml +//! tower-http = { version = "0.1", features = ["trace"] } +//! ``` +//! +//! You can use `"full"` to enable everything: +//! +//! ```toml +//! tower-http = { version = "0.1", features = ["full"] } +//! ``` +//! +//! # Getting Help +//! +//! If you're new to tower its [guides] might help. In the tower-http repo we also have a [number +//! of examples][examples] showing how to put everything together. You're also welcome to ask in +//! the [`#tower` Discord channel][chat] or open an [issue] with your question. +//! +//! [tower]: https://crates.io/crates/tower +//! [http]: https://crates.io/crates/http +//! [http-body]: https://crates.io/crates/http-body +//! [hyper]: https://crates.io/crates/hyper +//! [guides]: https://github.com/tower-rs/tower/tree/master/guides +//! [tonic]: https://crates.io/crates/tonic +//! [warp]: https://crates.io/crates/warp +//! [cargo features]: https://doc.rust-lang.org/cargo/reference/features.html +//! [`AddExtension`]: crate::add_extension::AddExtension +//! [`Service`]: https://docs.rs/tower/latest/tower/trait.Service.html +//! [chat]: https://discord.gg/tokio +//! [issue]: https://github.com/tower-rs/tower-http/issues/new +//! [`Trace`]: crate::trace::Trace +//! [examples]: https://github.com/tower-rs/tower-http/tree/master/examples + +#![warn( + clippy::all, + clippy::dbg_macro, + clippy::todo, + clippy::empty_enum, + clippy::enum_glob_use, + clippy::mem_forget, + clippy::unused_self, + clippy::filter_map_next, + clippy::needless_continue, + clippy::needless_borrow, + clippy::match_wildcard_for_single_variants, + clippy::if_let_mutex, + clippy::await_holding_lock, + clippy::match_on_vec_items, + clippy::imprecise_flops, + clippy::suboptimal_flops, + clippy::lossy_float_literal, + clippy::rest_pat_in_fully_bound_structs, + clippy::fn_params_excessive_bools, + clippy::exit, + clippy::inefficient_to_string, + clippy::linkedlist, + clippy::macro_use_imports, + clippy::option_option, + clippy::verbose_file_reads, + clippy::unnested_or_patterns, + rust_2018_idioms, + future_incompatible, + nonstandard_style, + missing_docs +)] +#![deny(unreachable_pub)] +#![allow( + elided_lifetimes_in_paths, + // TODO: Remove this once the MSRV bumps to 1.42.0 or above. + clippy::match_like_matches_macro, + clippy::type_complexity +)] +#![forbid(unsafe_code)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(test, allow(clippy::float_cmp))] + +#[macro_use] +pub(crate) mod macros; + +#[cfg(test)] +mod test_helpers; + +#[cfg(feature = "auth")] +pub mod auth; + +#[cfg(feature = "set-header")] +pub mod set_header; + +#[cfg(feature = "propagate-header")] +pub mod propagate_header; + +#[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", +))] +pub mod compression; + +#[cfg(feature = "add-extension")] +pub mod add_extension; + +#[cfg(feature = "sensitive-headers")] +pub mod sensitive_headers; + +#[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", +))] +pub mod decompression; + +#[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", + feature = "fs" // Used for serving precompressed static files as well +))] +mod content_encoding; + +#[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", +))] +mod compression_utils; + +#[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", +))] +pub use compression_utils::CompressionLevel; + +#[cfg(feature = "map-response-body")] +pub mod map_response_body; + +#[cfg(feature = "map-request-body")] +pub mod map_request_body; + +#[cfg(feature = "trace")] +pub mod trace; + +#[cfg(feature = "follow-redirect")] +pub mod follow_redirect; + +#[cfg(feature = "limit")] +pub mod limit; + +#[cfg(feature = "metrics")] +pub mod metrics; + +#[cfg(feature = "cors")] +pub mod cors; + +#[cfg(feature = "request-id")] +pub mod request_id; + +#[cfg(feature = "catch-panic")] +pub mod catch_panic; + +#[cfg(feature = "set-status")] +pub mod set_status; + +#[cfg(feature = "timeout")] +pub mod timeout; + +#[cfg(feature = "normalize-path")] +pub mod normalize_path; + +pub mod classify; +pub mod services; + +#[cfg(feature = "util")] +mod builder; +#[cfg(feature = "util")] +mod service_ext; + +#[cfg(feature = "util")] +#[doc(inline)] +pub use self::{builder::ServiceBuilderExt, service_ext::ServiceExt}; + +#[cfg(feature = "validate-request")] +pub mod validate_request; + +#[cfg(any( + feature = "catch-panic", + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", + feature = "fs", + feature = "limit", +))] +pub mod body; + +/// The latency unit used to report latencies by middleware. +#[non_exhaustive] +#[derive(Copy, Clone, Debug)] +pub enum LatencyUnit { + /// Use seconds. + Seconds, + /// Use milliseconds. + Millis, + /// Use microseconds. + Micros, + /// Use nanoseconds. + Nanos, +} + +/// Alias for a type-erased error type. +pub type BoxError = Box<dyn std::error::Error + Send + Sync>; diff --git a/vendor/tower-http/src/limit/body.rs b/vendor/tower-http/src/limit/body.rs new file mode 100644 index 00000000..4e540f8b --- /dev/null +++ b/vendor/tower-http/src/limit/body.rs @@ -0,0 +1,96 @@ +use bytes::Bytes; +use http::{HeaderValue, Response, StatusCode}; +use http_body::{Body, SizeHint}; +use http_body_util::Full; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Response body for [`RequestBodyLimit`]. + /// + /// [`RequestBodyLimit`]: super::RequestBodyLimit + pub struct ResponseBody<B> { + #[pin] + inner: ResponseBodyInner<B> + } +} + +impl<B> ResponseBody<B> { + fn payload_too_large() -> Self { + Self { + inner: ResponseBodyInner::PayloadTooLarge { + body: Full::from(BODY), + }, + } + } + + pub(crate) fn new(body: B) -> Self { + Self { + inner: ResponseBodyInner::Body { body }, + } + } +} + +pin_project! { + #[project = BodyProj] + enum ResponseBodyInner<B> { + PayloadTooLarge { + #[pin] + body: Full<Bytes>, + }, + Body { + #[pin] + body: B + } + } +} + +impl<B> Body for ResponseBody<B> +where + B: Body<Data = Bytes>, +{ + type Data = Bytes; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + match self.project().inner.project() { + BodyProj::PayloadTooLarge { body } => body.poll_frame(cx).map_err(|err| match err {}), + BodyProj::Body { body } => body.poll_frame(cx), + } + } + + fn is_end_stream(&self) -> bool { + match &self.inner { + ResponseBodyInner::PayloadTooLarge { body } => body.is_end_stream(), + ResponseBodyInner::Body { body } => body.is_end_stream(), + } + } + + fn size_hint(&self) -> SizeHint { + match &self.inner { + ResponseBodyInner::PayloadTooLarge { body } => body.size_hint(), + ResponseBodyInner::Body { body } => body.size_hint(), + } + } +} + +const BODY: &[u8] = b"length limit exceeded"; + +pub(crate) fn create_error_response<B>() -> Response<ResponseBody<B>> +where + B: Body, +{ + let mut res = Response::new(ResponseBody::payload_too_large()); + *res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; + + #[allow(clippy::declare_interior_mutable_const)] + const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); + res.headers_mut() + .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); + + res +} diff --git a/vendor/tower-http/src/limit/future.rs b/vendor/tower-http/src/limit/future.rs new file mode 100644 index 00000000..fd913c75 --- /dev/null +++ b/vendor/tower-http/src/limit/future.rs @@ -0,0 +1,60 @@ +use super::body::create_error_response; +use super::ResponseBody; +use http::Response; +use http_body::Body; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +pin_project! { + /// Response future for [`RequestBodyLimit`]. + /// + /// [`RequestBodyLimit`]: super::RequestBodyLimit + pub struct ResponseFuture<F> { + #[pin] + inner: ResponseFutureInner<F>, + } +} + +impl<F> ResponseFuture<F> { + pub(crate) fn payload_too_large() -> Self { + Self { + inner: ResponseFutureInner::PayloadTooLarge, + } + } + + pub(crate) fn new(future: F) -> Self { + Self { + inner: ResponseFutureInner::Future { future }, + } + } +} + +pin_project! { + #[project = ResFutProj] + enum ResponseFutureInner<F> { + PayloadTooLarge, + Future { + #[pin] + future: F, + } + } +} + +impl<ResBody, F, E> Future for ResponseFuture<F> +where + ResBody: Body, + F: Future<Output = Result<Response<ResBody>, E>>, +{ + type Output = Result<Response<ResponseBody<ResBody>>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let res = match self.project().inner.project() { + ResFutProj::PayloadTooLarge => create_error_response(), + ResFutProj::Future { future } => ready!(future.poll(cx))?.map(ResponseBody::new), + }; + + Poll::Ready(Ok(res)) + } +} diff --git a/vendor/tower-http/src/limit/layer.rs b/vendor/tower-http/src/limit/layer.rs new file mode 100644 index 00000000..2dcff71a --- /dev/null +++ b/vendor/tower-http/src/limit/layer.rs @@ -0,0 +1,32 @@ +use super::RequestBodyLimit; +use tower_layer::Layer; + +/// Layer that applies the [`RequestBodyLimit`] middleware that intercepts requests +/// with body lengths greater than the configured limit and converts them into +/// `413 Payload Too Large` responses. +/// +/// See the [module docs](crate::limit) for an example. +/// +/// [`RequestBodyLimit`]: super::RequestBodyLimit +#[derive(Clone, Copy, Debug)] +pub struct RequestBodyLimitLayer { + limit: usize, +} + +impl RequestBodyLimitLayer { + /// Create a new `RequestBodyLimitLayer` with the given body length limit. + pub fn new(limit: usize) -> Self { + Self { limit } + } +} + +impl<S> Layer<S> for RequestBodyLimitLayer { + type Service = RequestBodyLimit<S>; + + fn layer(&self, inner: S) -> Self::Service { + RequestBodyLimit { + inner, + limit: self.limit, + } + } +} diff --git a/vendor/tower-http/src/limit/mod.rs b/vendor/tower-http/src/limit/mod.rs new file mode 100644 index 00000000..3f2fede3 --- /dev/null +++ b/vendor/tower-http/src/limit/mod.rs @@ -0,0 +1,142 @@ +//! Middleware for limiting request bodies. +//! +//! This layer will also intercept requests with a `Content-Length` header +//! larger than the allowable limit and return an immediate error response +//! before reading any of the body. +//! +//! Note that payload length errors can be used by adversaries in an attempt +//! to smuggle requests. When an incoming stream is dropped due to an +//! over-sized payload, servers should close the connection or resynchronize +//! by optimistically consuming some data in an attempt to reach the end of +//! the current HTTP frame. If the incoming stream cannot be resynchronized, +//! then the connection should be closed. If you're using [hyper] this is +//! automatically handled for you. +//! +//! # Examples +//! +//! ## Limiting based on `Content-Length` +//! +//! If a `Content-Length` header is present and indicates a payload that is +//! larger than the acceptable limit, then the underlying service will not +//! be called and a `413 Payload Too Large` response will be generated. +//! +//! ```rust +//! use bytes::Bytes; +//! use std::convert::Infallible; +//! use http::{Request, Response, StatusCode, HeaderValue, header::CONTENT_LENGTH}; +//! use http_body_util::{LengthLimitError}; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::{body::Limited, limit::RequestBodyLimitLayer}; +//! use http_body_util::Full; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! async fn handle(req: Request<Limited<Full<Bytes>>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! panic!("This should not be hit") +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! // Limit incoming requests to 4096 bytes. +//! .layer(RequestBodyLimitLayer::new(4096)) +//! .service_fn(handle); +//! +//! // Call the service with a header that indicates the body is too large. +//! let mut request = Request::builder() +//! .header(CONTENT_LENGTH, HeaderValue::from_static("5000")) +//! .body(Full::<Bytes>::default()) +//! .unwrap(); +//! +//! // let response = svc.ready().await?.call(request).await?; +//! let response = svc.call(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Limiting without known `Content-Length` +//! +//! If a `Content-Length` header is not present, then the body will be read +//! until the configured limit has been reached. If the payload is larger than +//! the limit, the [`http_body_util::Limited`] body will return an error. This +//! error can be inspected to determine if it is a [`http_body_util::LengthLimitError`] +//! and return an appropriate response in such case. +//! +//! Note that no error will be generated if the body is never read. Similarly, +//! if the body _would be_ to large, but is never consumed beyond the length +//! limit, then no error is generated, and handling of the remaining incoming +//! data stream is left to the server implementation as described above. +//! +//! ```rust +//! # use bytes::Bytes; +//! # use std::convert::Infallible; +//! # use http::{Request, Response, StatusCode}; +//! # use http_body_util::LengthLimitError; +//! # use tower::{Service, ServiceExt, ServiceBuilder, BoxError}; +//! # use tower_http::{body::Limited, limit::RequestBodyLimitLayer}; +//! # use http_body_util::Full; +//! # use http_body_util::BodyExt; +//! # +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! async fn handle(req: Request<Limited<Full<Bytes>>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! let data = match req.into_body().collect().await { +//! Ok(collected) => collected.to_bytes(), +//! Err(err) => { +//! if let Some(_) = err.downcast_ref::<LengthLimitError>() { +//! let mut resp = Response::new(Full::default()); +//! *resp.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; +//! return Ok(resp); +//! } else { +//! return Err(err); +//! } +//! } +//! }; +//! +//! Ok(Response::new(Full::default())) +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! // Limit incoming requests to 4096 bytes. +//! .layer(RequestBodyLimitLayer::new(4096)) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::new(Full::<Bytes>::default()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::OK); +//! +//! // Call the service with a body that is too large. +//! let request = Request::new(Full::<Bytes>::from(Bytes::from(vec![0u8; 4097]))); +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Limiting without `Content-Length` +//! +//! If enforcement of body size limits is desired without preemptively +//! handling requests with a `Content-Length` header indicating an over-sized +//! request, consider using [`MapRequestBody`] to wrap the request body with +//! [`http_body_util::Limited`] and checking for [`http_body_util::LengthLimitError`] +//! like in the previous example. +//! +//! [`MapRequestBody`]: crate::map_request_body +//! [hyper]: https://crates.io/crates/hyper + +mod body; +mod future; +mod layer; +mod service; + +pub use body::ResponseBody; +pub use future::ResponseFuture; +pub use layer::RequestBodyLimitLayer; +pub use service::RequestBodyLimit; diff --git a/vendor/tower-http/src/limit/service.rs b/vendor/tower-http/src/limit/service.rs new file mode 100644 index 00000000..fdf65d25 --- /dev/null +++ b/vendor/tower-http/src/limit/service.rs @@ -0,0 +1,64 @@ +use super::{RequestBodyLimitLayer, ResponseBody, ResponseFuture}; +use crate::body::Limited; +use http::{Request, Response}; +use http_body::Body; +use std::task::{Context, Poll}; +use tower_service::Service; + +/// Middleware that intercepts requests with body lengths greater than the +/// configured limit and converts them into `413 Payload Too Large` responses. +/// +/// See the [module docs](crate::limit) for an example. +#[derive(Clone, Copy, Debug)] +pub struct RequestBodyLimit<S> { + pub(crate) inner: S, + pub(crate) limit: usize, +} + +impl<S> RequestBodyLimit<S> { + /// Create a new `RequestBodyLimit` with the given body length limit. + pub fn new(inner: S, limit: usize) -> Self { + Self { inner, limit } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `RequestBodyLimit` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(limit: usize) -> RequestBodyLimitLayer { + RequestBodyLimitLayer::new(limit) + } +} + +impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for RequestBodyLimit<S> +where + ResBody: Body, + S: Service<Request<Limited<ReqBody>>, Response = Response<ResBody>>, +{ + type Response = Response<ResponseBody<ResBody>>; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let content_length = req + .headers() + .get(http::header::CONTENT_LENGTH) + .and_then(|value| value.to_str().ok()?.parse::<usize>().ok()); + + let body_limit = match content_length { + Some(len) if len > self.limit => return ResponseFuture::payload_too_large(), + Some(len) => self.limit.min(len), + None => self.limit, + }; + + let req = req.map(|body| Limited::new(http_body_util::Limited::new(body, body_limit))); + + ResponseFuture::new(self.inner.call(req)) + } +} diff --git a/vendor/tower-http/src/macros.rs b/vendor/tower-http/src/macros.rs new file mode 100644 index 00000000..f58d34a6 --- /dev/null +++ b/vendor/tower-http/src/macros.rs @@ -0,0 +1,105 @@ +#[allow(unused_macros)] +macro_rules! define_inner_service_accessors { + () => { + /// Gets a reference to the underlying service. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Gets a mutable reference to the underlying service. + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consumes `self`, returning the underlying service. + pub fn into_inner(self) -> S { + self.inner + } + }; +} + +#[allow(unused_macros)] +macro_rules! opaque_body { + ($(#[$m:meta])* pub type $name:ident = $actual:ty;) => { + opaque_body! { + $(#[$m])* pub type $name<> = $actual; + } + }; + + ($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => { + pin_project_lite::pin_project! { + $(#[$m])* + pub struct $name<$($param),*> { + #[pin] + pub(crate) inner: $actual + } + } + + impl<$($param),*> $name<$($param),*> { + pub(crate) fn new(inner: $actual) -> Self { + Self { inner } + } + } + + impl<$($param),*> http_body::Body for $name<$($param),*> { + type Data = <$actual as http_body::Body>::Data; + type Error = <$actual as http_body::Body>::Error; + + #[inline] + fn poll_frame( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + self.project().inner.poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + http_body::Body::is_end_stream(&self.inner) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + http_body::Body::size_hint(&self.inner) + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! opaque_future { + ($(#[$m:meta])* pub type $name:ident<$($param:ident),+> = $actual:ty;) => { + pin_project_lite::pin_project! { + $(#[$m])* + pub struct $name<$($param),+> { + #[pin] + inner: $actual + } + } + + impl<$($param),+> $name<$($param),+> { + pub(crate) fn new(inner: $actual) -> Self { + Self { + inner + } + } + } + + impl<$($param),+> std::fmt::Debug for $name<$($param),+> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(stringify!($name)).field(&format_args!("...")).finish() + } + } + + impl<$($param),+> std::future::Future for $name<$($param),+> + where + $actual: std::future::Future, + { + type Output = <$actual as std::future::Future>::Output; + #[inline] + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> { + self.project().inner.poll(cx) + } + } + } +} diff --git a/vendor/tower-http/src/map_request_body.rs b/vendor/tower-http/src/map_request_body.rs new file mode 100644 index 00000000..dd067e92 --- /dev/null +++ b/vendor/tower-http/src/map_request_body.rs @@ -0,0 +1,157 @@ +//! Apply a transformation to the request body. +//! +//! # Example +//! +//! ``` +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use http::{Request, Response}; +//! use std::convert::Infallible; +//! use std::{pin::Pin, task::{ready, Context, Poll}}; +//! use tower::{ServiceBuilder, service_fn, ServiceExt, Service}; +//! use tower_http::map_request_body::MapRequestBodyLayer; +//! +//! // A wrapper for a `Full<Bytes>` +//! struct BodyWrapper { +//! inner: Full<Bytes>, +//! } +//! +//! impl BodyWrapper { +//! fn new(inner: Full<Bytes>) -> Self { +//! Self { inner } +//! } +//! } +//! +//! impl http_body::Body for BodyWrapper { +//! // ... +//! # type Data = Bytes; +//! # type Error = tower::BoxError; +//! # fn poll_frame( +//! # self: Pin<&mut Self>, +//! # cx: &mut Context<'_> +//! # ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { unimplemented!() } +//! # fn is_end_stream(&self) -> bool { unimplemented!() } +//! # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() } +//! } +//! +//! async fn handle<B>(_: Request<B>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut svc = ServiceBuilder::new() +//! // Wrap response bodies in `BodyWrapper` +//! .layer(MapRequestBodyLayer::new(BodyWrapper::new)) +//! .service_fn(handle); +//! +//! // Call the service +//! let request = Request::new(Full::default()); +//! +//! svc.ready().await?.call(request).await?; +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response}; +use std::{ + fmt, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Apply a transformation to the request body. +/// +/// See the [module docs](crate::map_request_body) for an example. +#[derive(Clone)] +pub struct MapRequestBodyLayer<F> { + f: F, +} + +impl<F> MapRequestBodyLayer<F> { + /// Create a new [`MapRequestBodyLayer`]. + /// + /// `F` is expected to be a function that takes a body and returns another body. + pub fn new(f: F) -> Self { + Self { f } + } +} + +impl<S, F> Layer<S> for MapRequestBodyLayer<F> +where + F: Clone, +{ + type Service = MapRequestBody<S, F>; + + fn layer(&self, inner: S) -> Self::Service { + MapRequestBody::new(inner, self.f.clone()) + } +} + +impl<F> fmt::Debug for MapRequestBodyLayer<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapRequestBodyLayer") + .field("f", &std::any::type_name::<F>()) + .finish() + } +} + +/// Apply a transformation to the request body. +/// +/// See the [module docs](crate::map_request_body) for an example. +#[derive(Clone)] +pub struct MapRequestBody<S, F> { + inner: S, + f: F, +} + +impl<S, F> MapRequestBody<S, F> { + /// Create a new [`MapRequestBody`]. + /// + /// `F` is expected to be a function that takes a body and returns another body. + pub fn new(service: S, f: F) -> Self { + Self { inner: service, f } + } + + /// Returns a new [`Layer`] that wraps services with a `MapRequestBodyLayer` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(f: F) -> MapRequestBodyLayer<F> { + MapRequestBodyLayer::new(f) + } + + define_inner_service_accessors!(); +} + +impl<F, S, ReqBody, ResBody, NewReqBody> Service<Request<ReqBody>> for MapRequestBody<S, F> +where + S: Service<Request<NewReqBody>, Response = Response<ResBody>>, + F: FnMut(ReqBody) -> NewReqBody, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let req = req.map(&mut self.f); + self.inner.call(req) + } +} + +impl<S, F> fmt::Debug for MapRequestBody<S, F> +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapRequestBody") + .field("inner", &self.inner) + .field("f", &std::any::type_name::<F>()) + .finish() + } +} diff --git a/vendor/tower-http/src/map_response_body.rs b/vendor/tower-http/src/map_response_body.rs new file mode 100644 index 00000000..5329e5d5 --- /dev/null +++ b/vendor/tower-http/src/map_response_body.rs @@ -0,0 +1,185 @@ +//! Apply a transformation to the response body. +//! +//! # Example +//! +//! ``` +//! use bytes::Bytes; +//! use http::{Request, Response}; +//! use http_body_util::Full; +//! use std::convert::Infallible; +//! use std::{pin::Pin, task::{ready, Context, Poll}}; +//! use tower::{ServiceBuilder, service_fn, ServiceExt, Service}; +//! use tower_http::map_response_body::MapResponseBodyLayer; +//! +//! // A wrapper for a `Full<Bytes>` +//! struct BodyWrapper { +//! inner: Full<Bytes>, +//! } +//! +//! impl BodyWrapper { +//! fn new(inner: Full<Bytes>) -> Self { +//! Self { inner } +//! } +//! } +//! +//! impl http_body::Body for BodyWrapper { +//! // ... +//! # type Data = Bytes; +//! # type Error = tower::BoxError; +//! # fn poll_frame( +//! # self: Pin<&mut Self>, +//! # cx: &mut Context<'_> +//! # ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { unimplemented!() } +//! # fn is_end_stream(&self) -> bool { unimplemented!() } +//! # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() } +//! } +//! +//! async fn handle<B>(_: Request<B>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut svc = ServiceBuilder::new() +//! // Wrap response bodies in `BodyWrapper` +//! .layer(MapResponseBodyLayer::new(BodyWrapper::new)) +//! .service_fn(handle); +//! +//! // Call the service +//! let request = Request::new(Full::<Bytes>::from("foobar")); +//! +//! svc.ready().await?.call(request).await?; +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::{ + fmt, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Apply a transformation to the response body. +/// +/// See the [module docs](crate::map_response_body) for an example. +#[derive(Clone)] +pub struct MapResponseBodyLayer<F> { + f: F, +} + +impl<F> MapResponseBodyLayer<F> { + /// Create a new [`MapResponseBodyLayer`]. + /// + /// `F` is expected to be a function that takes a body and returns another body. + pub fn new(f: F) -> Self { + Self { f } + } +} + +impl<S, F> Layer<S> for MapResponseBodyLayer<F> +where + F: Clone, +{ + type Service = MapResponseBody<S, F>; + + fn layer(&self, inner: S) -> Self::Service { + MapResponseBody::new(inner, self.f.clone()) + } +} + +impl<F> fmt::Debug for MapResponseBodyLayer<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapResponseBodyLayer") + .field("f", &std::any::type_name::<F>()) + .finish() + } +} + +/// Apply a transformation to the response body. +/// +/// See the [module docs](crate::map_response_body) for an example. +#[derive(Clone)] +pub struct MapResponseBody<S, F> { + inner: S, + f: F, +} + +impl<S, F> MapResponseBody<S, F> { + /// Create a new [`MapResponseBody`]. + /// + /// `F` is expected to be a function that takes a body and returns another body. + pub fn new(service: S, f: F) -> Self { + Self { inner: service, f } + } + + /// Returns a new [`Layer`] that wraps services with a `MapResponseBodyLayer` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(f: F) -> MapResponseBodyLayer<F> { + MapResponseBodyLayer::new(f) + } + + define_inner_service_accessors!(); +} + +impl<F, S, ReqBody, ResBody, NewResBody> Service<Request<ReqBody>> for MapResponseBody<S, F> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + F: FnMut(ResBody) -> NewResBody + Clone, +{ + type Response = Response<NewResBody>; + type Error = S::Error; + type Future = ResponseFuture<S::Future, F>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + ResponseFuture { + inner: self.inner.call(req), + f: self.f.clone(), + } + } +} + +impl<S, F> fmt::Debug for MapResponseBody<S, F> +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapResponseBody") + .field("inner", &self.inner) + .field("f", &std::any::type_name::<F>()) + .finish() + } +} + +pin_project! { + /// Response future for [`MapResponseBody`]. + pub struct ResponseFuture<Fut, F> { + #[pin] + inner: Fut, + f: F, + } +} + +impl<Fut, F, ResBody, E, NewResBody> Future for ResponseFuture<Fut, F> +where + Fut: Future<Output = Result<Response<ResBody>, E>>, + F: FnMut(ResBody) -> NewResBody, +{ + type Output = Result<Response<NewResBody>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let res = ready!(this.inner.poll(cx)?); + Poll::Ready(Ok(res.map(this.f))) + } +} diff --git a/vendor/tower-http/src/metrics/in_flight_requests.rs b/vendor/tower-http/src/metrics/in_flight_requests.rs new file mode 100644 index 00000000..dbb5e2ff --- /dev/null +++ b/vendor/tower-http/src/metrics/in_flight_requests.rs @@ -0,0 +1,327 @@ +//! Measure the number of in-flight requests. +//! +//! In-flight requests is the number of requests a service is currently processing. The processing +//! of a request starts when it is received by the service (`tower::Service::call` is called) and +//! is considered complete when the response body is consumed, dropped, or an error happens. +//! +//! # Example +//! +//! ``` +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::metrics::InFlightRequestsLayer; +//! use http::{Request, Response}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use std::{time::Duration, convert::Infallible}; +//! +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! async fn update_in_flight_requests_metric(count: usize) { +//! // ... +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // Create a `Layer` with an associated counter. +//! let (in_flight_requests_layer, counter) = InFlightRequestsLayer::pair(); +//! +//! // Spawn a task that will receive the number of in-flight requests every 10 seconds. +//! tokio::spawn( +//! counter.run_emitter(Duration::from_secs(10), |count| async move { +//! update_in_flight_requests_metric(count).await; +//! }), +//! ); +//! +//! let mut service = ServiceBuilder::new() +//! // Keep track of the number of in-flight requests. This will increment and decrement +//! // `counter` automatically. +//! .layer(in_flight_requests_layer) +//! .service_fn(handle); +//! +//! // Call the service. +//! let response = service +//! .ready() +//! .await? +//! .call(Request::new(Full::default())) +//! .await?; +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response}; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{ready, Context, Poll}, + time::Duration, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer for applying [`InFlightRequests`] which counts the number of in-flight requests. +/// +/// See the [module docs](crate::metrics::in_flight_requests) for more details. +#[derive(Clone, Debug)] +pub struct InFlightRequestsLayer { + counter: InFlightRequestsCounter, +} + +impl InFlightRequestsLayer { + /// Create a new `InFlightRequestsLayer` and its associated counter. + pub fn pair() -> (Self, InFlightRequestsCounter) { + let counter = InFlightRequestsCounter::new(); + let layer = Self::new(counter.clone()); + (layer, counter) + } + + /// Create a new `InFlightRequestsLayer` that will update the given counter. + pub fn new(counter: InFlightRequestsCounter) -> Self { + Self { counter } + } +} + +impl<S> Layer<S> for InFlightRequestsLayer { + type Service = InFlightRequests<S>; + + fn layer(&self, inner: S) -> Self::Service { + InFlightRequests { + inner, + counter: self.counter.clone(), + } + } +} + +/// Middleware that counts the number of in-flight requests. +/// +/// See the [module docs](crate::metrics::in_flight_requests) for more details. +#[derive(Clone, Debug)] +pub struct InFlightRequests<S> { + inner: S, + counter: InFlightRequestsCounter, +} + +impl<S> InFlightRequests<S> { + /// Create a new `InFlightRequests` and its associated counter. + pub fn pair(inner: S) -> (Self, InFlightRequestsCounter) { + let counter = InFlightRequestsCounter::new(); + let service = Self::new(inner, counter.clone()); + (service, counter) + } + + /// Create a new `InFlightRequests` that will update the given counter. + pub fn new(inner: S, counter: InFlightRequestsCounter) -> Self { + Self { inner, counter } + } + + define_inner_service_accessors!(); +} + +/// An atomic counter that keeps track of the number of in-flight requests. +/// +/// This will normally combined with [`InFlightRequestsLayer`] or [`InFlightRequests`] which will +/// update the counter as requests arrive. +#[derive(Debug, Clone, Default)] +pub struct InFlightRequestsCounter { + count: Arc<AtomicUsize>, +} + +impl InFlightRequestsCounter { + /// Create a new `InFlightRequestsCounter`. + pub fn new() -> Self { + Self::default() + } + + /// Get the current number of in-flight requests. + pub fn get(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn increment(&self) -> IncrementGuard { + self.count.fetch_add(1, Ordering::Relaxed); + IncrementGuard { + count: self.count.clone(), + } + } + + /// Run a future every `interval` which receives the current number of in-flight requests. + /// + /// This can be used to send the current count to your metrics system. + /// + /// This function will loop forever so normally it is called with [`tokio::spawn`]: + /// + /// ```rust,no_run + /// use tower_http::metrics::in_flight_requests::InFlightRequestsCounter; + /// use std::time::Duration; + /// + /// let counter = InFlightRequestsCounter::new(); + /// + /// tokio::spawn( + /// counter.run_emitter(Duration::from_secs(10), |count: usize| async move { + /// // Send `count` to metrics system. + /// }), + /// ); + /// ``` + pub async fn run_emitter<F, Fut>(mut self, interval: Duration, mut emit: F) + where + F: FnMut(usize) -> Fut + Send + 'static, + Fut: Future<Output = ()> + Send, + { + let mut interval = tokio::time::interval(interval); + + loop { + // if all producers have gone away we don't need to emit anymore + match Arc::try_unwrap(self.count) { + Ok(_) => return, + Err(shared_count) => { + self = Self { + count: shared_count, + } + } + } + + interval.tick().await; + emit(self.get()).await; + } + } +} + +struct IncrementGuard { + count: Arc<AtomicUsize>, +} + +impl Drop for IncrementGuard { + fn drop(&mut self) { + self.count.fetch_sub(1, Ordering::Relaxed); + } +} + +impl<S, R, ResBody> Service<Request<R>> for InFlightRequests<S> +where + S: Service<Request<R>, Response = Response<ResBody>>, +{ + type Response = Response<ResponseBody<ResBody>>; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<R>) -> Self::Future { + let guard = self.counter.increment(); + ResponseFuture { + inner: self.inner.call(req), + guard: Some(guard), + } + } +} + +pin_project! { + /// Response future for [`InFlightRequests`]. + pub struct ResponseFuture<F> { + #[pin] + inner: F, + guard: Option<IncrementGuard>, + } +} + +impl<F, B, E> Future for ResponseFuture<F> +where + F: Future<Output = Result<Response<B>, E>>, +{ + type Output = Result<Response<ResponseBody<B>>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let response = ready!(this.inner.poll(cx))?; + let guard = this.guard.take().unwrap(); + let response = response.map(move |body| ResponseBody { inner: body, guard }); + + Poll::Ready(Ok(response)) + } +} + +pin_project! { + /// Response body for [`InFlightRequests`]. + pub struct ResponseBody<B> { + #[pin] + inner: B, + guard: IncrementGuard, + } +} + +impl<B> Body for ResponseBody<B> +where + B: Body, +{ + type Data = B::Data; + type Error = B::Error; + + #[inline] + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + self.project().inner.poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use crate::test_helpers::Body; + use http::Request; + use tower::{BoxError, ServiceBuilder}; + + #[tokio::test] + async fn basic() { + let (in_flight_requests_layer, counter) = InFlightRequestsLayer::pair(); + + let mut service = ServiceBuilder::new() + .layer(in_flight_requests_layer) + .service_fn(echo); + assert_eq!(counter.get(), 0); + + // driving service to ready shouldn't increment the counter + std::future::poll_fn(|cx| service.poll_ready(cx)) + .await + .unwrap(); + assert_eq!(counter.get(), 0); + + // creating the response future should increment the count + let response_future = service.call(Request::new(Body::empty())); + assert_eq!(counter.get(), 1); + + // count shouldn't decrement until the full body has been comsumed + let response = response_future.await.unwrap(); + assert_eq!(counter.get(), 1); + + let body = response.into_body(); + crate::test_helpers::to_bytes(body).await.unwrap(); + assert_eq!(counter.get(), 0); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/vendor/tower-http/src/metrics/mod.rs b/vendor/tower-http/src/metrics/mod.rs new file mode 100644 index 00000000..317d17b8 --- /dev/null +++ b/vendor/tower-http/src/metrics/mod.rs @@ -0,0 +1,12 @@ +//! Middlewares for adding metrics to services. +//! +//! Supported metrics: +//! +//! - [In-flight requests][]: Measure the number of requests a service is currently processing. +//! +//! [In-flight requests]: in_flight_requests + +pub mod in_flight_requests; + +#[doc(inline)] +pub use self::in_flight_requests::{InFlightRequests, InFlightRequestsLayer}; diff --git a/vendor/tower-http/src/normalize_path.rs b/vendor/tower-http/src/normalize_path.rs new file mode 100644 index 00000000..f9b9dd2e --- /dev/null +++ b/vendor/tower-http/src/normalize_path.rs @@ -0,0 +1,384 @@ +//! Middleware that normalizes paths. +//! +//! # Example +//! +//! ``` +//! use tower_http::normalize_path::NormalizePathLayer; +//! use http::{Request, Response, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use std::{iter::once, convert::Infallible}; +//! use tower::{ServiceBuilder, Service, ServiceExt}; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // `req.uri().path()` will not have trailing slashes +//! # Ok(Response::new(Full::default())) +//! } +//! +//! let mut service = ServiceBuilder::new() +//! // trim trailing slashes from paths +//! .layer(NormalizePathLayer::trim_trailing_slash()) +//! .service_fn(handle); +//! +//! // call the service +//! let request = Request::builder() +//! // `handle` will see `/foo` +//! .uri("/foo/") +//! .body(Full::default())?; +//! +//! service.ready().await?.call(request).await?; +//! # +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response, Uri}; +use std::{ + borrow::Cow, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Different modes of normalizing paths +#[derive(Debug, Copy, Clone)] +enum NormalizeMode { + /// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo + Trim, + /// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/ + Append, +} + +/// Layer that applies [`NormalizePath`] which normalizes paths. +/// +/// See the [module docs](self) for more details. +#[derive(Debug, Copy, Clone)] +pub struct NormalizePathLayer { + mode: NormalizeMode, +} + +impl NormalizePathLayer { + /// Create a new [`NormalizePathLayer`]. + /// + /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` + /// will be changed to `/foo` before reaching the inner service. + pub fn trim_trailing_slash() -> Self { + NormalizePathLayer { + mode: NormalizeMode::Trim, + } + } + + /// Create a new [`NormalizePathLayer`]. + /// + /// Request paths without trailing slash will be appended with a trailing slash. For example, a request with `/foo` + /// will be changed to `/foo/` before reaching the inner service. + pub fn append_trailing_slash() -> Self { + NormalizePathLayer { + mode: NormalizeMode::Append, + } + } +} + +impl<S> Layer<S> for NormalizePathLayer { + type Service = NormalizePath<S>; + + fn layer(&self, inner: S) -> Self::Service { + NormalizePath { + mode: self.mode, + inner, + } + } +} + +/// Middleware that normalizes paths. +/// +/// See the [module docs](self) for more details. +#[derive(Debug, Copy, Clone)] +pub struct NormalizePath<S> { + mode: NormalizeMode, + inner: S, +} + +impl<S> NormalizePath<S> { + /// Construct a new [`NormalizePath`] with trim mode. + pub fn trim_trailing_slash(inner: S) -> Self { + Self { + mode: NormalizeMode::Trim, + inner, + } + } + + /// Construct a new [`NormalizePath`] with append mode. + pub fn append_trailing_slash(inner: S) -> Self { + Self { + mode: NormalizeMode::Append, + inner, + } + } + + define_inner_service_accessors!(); +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + match self.mode { + NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()), + NormalizeMode::Append => append_trailing_slash(req.uri_mut()), + } + self.inner.call(req) + } +} + +fn trim_trailing_slash(uri: &mut Uri) { + if !uri.path().ends_with('/') && !uri.path().starts_with("//") { + return; + } + + let new_path = format!("/{}", uri.path().trim_matches('/')); + + let mut parts = uri.clone().into_parts(); + + let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query { + let new_path_and_query = if let Some(query) = path_and_query.query() { + Cow::Owned(format!("{}?{}", new_path, query)) + } else { + new_path.into() + } + .parse() + .unwrap(); + + Some(new_path_and_query) + } else { + None + }; + + parts.path_and_query = new_path_and_query; + if let Ok(new_uri) = Uri::from_parts(parts) { + *uri = new_uri; + } +} + +fn append_trailing_slash(uri: &mut Uri) { + if uri.path().ends_with("/") && !uri.path().ends_with("//") { + return; + } + + let trimmed = uri.path().trim_matches('/'); + let new_path = if trimmed.is_empty() { + "/".to_string() + } else { + format!("/{trimmed}/") + }; + + let mut parts = uri.clone().into_parts(); + + let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query { + let new_path_and_query = if let Some(query) = path_and_query.query() { + Cow::Owned(format!("{new_path}?{query}")) + } else { + new_path.into() + } + .parse() + .unwrap(); + + Some(new_path_and_query) + } else { + Some(new_path.parse().unwrap()) + }; + + parts.path_and_query = new_path_and_query; + if let Ok(new_uri) = Uri::from_parts(parts) { + *uri = new_uri; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::convert::Infallible; + use tower::{ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn trim_works() { + async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> { + Ok(Response::new(request.uri().to_string())) + } + + let mut svc = ServiceBuilder::new() + .layer(NormalizePathLayer::trim_trailing_slash()) + .service_fn(handle); + + let body = svc + .ready() + .await + .unwrap() + .call(Request::builder().uri("/foo/").body(()).unwrap()) + .await + .unwrap() + .into_body(); + + assert_eq!(body, "/foo"); + } + + #[test] + fn is_noop_if_no_trailing_slash() { + let mut uri = "/foo".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/foo"); + } + + #[test] + fn maintains_query() { + let mut uri = "/foo/?a=a".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/foo?a=a"); + } + + #[test] + fn removes_multiple_trailing_slashes() { + let mut uri = "/foo////".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/foo"); + } + + #[test] + fn removes_multiple_trailing_slashes_even_with_query() { + let mut uri = "/foo////?a=a".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/foo?a=a"); + } + + #[test] + fn is_noop_on_index() { + let mut uri = "/".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn removes_multiple_trailing_slashes_on_index() { + let mut uri = "////".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn removes_multiple_trailing_slashes_on_index_even_with_query() { + let mut uri = "////?a=a".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/?a=a"); + } + + #[test] + fn removes_multiple_preceding_slashes_even_with_query() { + let mut uri = "///foo//?a=a".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/foo?a=a"); + } + + #[test] + fn removes_multiple_preceding_slashes() { + let mut uri = "///foo".parse::<Uri>().unwrap(); + trim_trailing_slash(&mut uri); + assert_eq!(uri, "/foo"); + } + + #[tokio::test] + async fn append_works() { + async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> { + Ok(Response::new(request.uri().to_string())) + } + + let mut svc = ServiceBuilder::new() + .layer(NormalizePathLayer::append_trailing_slash()) + .service_fn(handle); + + let body = svc + .ready() + .await + .unwrap() + .call(Request::builder().uri("/foo").body(()).unwrap()) + .await + .unwrap() + .into_body(); + + assert_eq!(body, "/foo/"); + } + + #[test] + fn is_noop_if_trailing_slash() { + let mut uri = "/foo/".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } + + #[test] + fn append_maintains_query() { + let mut uri = "/foo?a=a".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_only_keeps_one_slash() { + let mut uri = "/foo////".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } + + #[test] + fn append_only_keeps_one_slash_even_with_query() { + let mut uri = "/foo////?a=a".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_is_noop_on_index() { + let mut uri = "/".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn append_removes_multiple_trailing_slashes_on_index() { + let mut uri = "////".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn append_removes_multiple_trailing_slashes_on_index_even_with_query() { + let mut uri = "////?a=a".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/?a=a"); + } + + #[test] + fn append_removes_multiple_preceding_slashes_even_with_query() { + let mut uri = "///foo//?a=a".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_removes_multiple_preceding_slashes() { + let mut uri = "///foo".parse::<Uri>().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } +} diff --git a/vendor/tower-http/src/propagate_header.rs b/vendor/tower-http/src/propagate_header.rs new file mode 100644 index 00000000..6c77ec32 --- /dev/null +++ b/vendor/tower-http/src/propagate_header.rs @@ -0,0 +1,154 @@ +//! Propagate a header from the request to the response. +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response, header::HeaderName}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::propagate_header::PropagateHeaderLayer; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! // This will copy `x-request-id` headers from requests onto responses. +//! .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id"))) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::builder() +//! .header("x-request-id", "1337") +//! .body(Full::default())?; +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.headers()["x-request-id"], "1337"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use http::{header::HeaderName, HeaderValue, Request, Response}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::{ + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`PropagateHeader`] which propagates headers from requests to responses. +/// +/// If the header is present on the request it'll be applied to the response as well. This could +/// for example be used to propagate headers such as `X-Request-Id`. +/// +/// See the [module docs](crate::propagate_header) for more details. +#[derive(Clone, Debug)] +pub struct PropagateHeaderLayer { + header: HeaderName, +} + +impl PropagateHeaderLayer { + /// Create a new [`PropagateHeaderLayer`]. + pub fn new(header: HeaderName) -> Self { + Self { header } + } +} + +impl<S> Layer<S> for PropagateHeaderLayer { + type Service = PropagateHeader<S>; + + fn layer(&self, inner: S) -> Self::Service { + PropagateHeader { + inner, + header: self.header.clone(), + } + } +} + +/// Middleware that propagates headers from requests to responses. +/// +/// If the header is present on the request it'll be applied to the response as well. This could +/// for example be used to propagate headers such as `X-Request-Id`. +/// +/// See the [module docs](crate::propagate_header) for more details. +#[derive(Clone, Debug)] +pub struct PropagateHeader<S> { + inner: S, + header: HeaderName, +} + +impl<S> PropagateHeader<S> { + /// Create a new [`PropagateHeader`] that propagates the given header. + pub fn new(inner: S, header: HeaderName) -> Self { + Self { inner, header } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `PropagateHeader` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(header: HeaderName) -> PropagateHeaderLayer { + PropagateHeaderLayer::new(header) + } +} + +impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for PropagateHeader<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let value = req.headers().get(&self.header).cloned(); + + ResponseFuture { + future: self.inner.call(req), + header_and_value: Some(self.header.clone()).zip(value), + } + } +} + +pin_project! { + /// Response future for [`PropagateHeader`]. + #[derive(Debug)] + pub struct ResponseFuture<F> { + #[pin] + future: F, + header_and_value: Option<(HeaderName, HeaderValue)>, + } +} + +impl<F, ResBody, E> Future for ResponseFuture<F> +where + F: Future<Output = Result<Response<ResBody>, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let mut res = ready!(this.future.poll(cx)?); + + if let Some((header, value)) = this.header_and_value.take() { + res.headers_mut().insert(header, value); + } + + Poll::Ready(Ok(res)) + } +} diff --git a/vendor/tower-http/src/request_id.rs b/vendor/tower-http/src/request_id.rs new file mode 100644 index 00000000..3c8c43fa --- /dev/null +++ b/vendor/tower-http/src/request_id.rs @@ -0,0 +1,604 @@ +//! Set and propagate request ids. +//! +//! # Example +//! +//! ``` +//! use http::{Request, Response, header::HeaderName}; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::request_id::{ +//! SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, +//! }; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) +//! # }); +//! # +//! // A `MakeRequestId` that increments an atomic counter +//! #[derive(Clone, Default)] +//! struct MyMakeRequestId { +//! counter: Arc<AtomicU64>, +//! } +//! +//! impl MakeRequestId for MyMakeRequestId { +//! fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> { +//! let request_id = self.counter +//! .fetch_add(1, Ordering::SeqCst) +//! .to_string() +//! .parse() +//! .unwrap(); +//! +//! Some(RequestId::new(request_id)) +//! } +//! } +//! +//! let x_request_id = HeaderName::from_static("x-request-id"); +//! +//! let mut svc = ServiceBuilder::new() +//! // set `x-request-id` header on all requests +//! .layer(SetRequestIdLayer::new( +//! x_request_id.clone(), +//! MyMakeRequestId::default(), +//! )) +//! // propagate `x-request-id` headers from request to response +//! .layer(PropagateRequestIdLayer::new(x_request_id)) +//! .service(handler); +//! +//! let request = Request::new(Full::default()); +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.headers()["x-request-id"], "0"); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Additional convenience methods are available on [`ServiceBuilderExt`]: +//! +//! ``` +//! use tower_http::ServiceBuilderExt; +//! # use http::{Request, Response, header::HeaderName}; +//! # use tower::{Service, ServiceExt, ServiceBuilder}; +//! # use tower_http::request_id::{ +//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, +//! # }; +//! # use bytes::Bytes; +//! # use http_body_util::Full; +//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) +//! # }); +//! # #[derive(Clone, Default)] +//! # struct MyMakeRequestId { +//! # counter: Arc<AtomicU64>, +//! # } +//! # impl MakeRequestId for MyMakeRequestId { +//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> { +//! # let request_id = self.counter +//! # .fetch_add(1, Ordering::SeqCst) +//! # .to_string() +//! # .parse() +//! # .unwrap(); +//! # Some(RequestId::new(request_id)) +//! # } +//! # } +//! +//! let mut svc = ServiceBuilder::new() +//! .set_x_request_id(MyMakeRequestId::default()) +//! .propagate_x_request_id() +//! .service(handler); +//! +//! let request = Request::new(Full::default()); +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.headers()["x-request-id"], "0"); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! See [`SetRequestId`] and [`PropagateRequestId`] for more details. +//! +//! # Using `Trace` +//! +//! To have request ids show up correctly in logs produced by [`Trace`] you must apply the layers +//! in this order: +//! +//! ``` +//! use tower_http::{ +//! ServiceBuilderExt, +//! trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse}, +//! }; +//! # use http::{Request, Response, header::HeaderName}; +//! # use tower::{Service, ServiceExt, ServiceBuilder}; +//! # use tower_http::request_id::{ +//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, +//! # }; +//! # use http_body_util::Full; +//! # use bytes::Bytes; +//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) +//! # }); +//! # #[derive(Clone, Default)] +//! # struct MyMakeRequestId { +//! # counter: Arc<AtomicU64>, +//! # } +//! # impl MakeRequestId for MyMakeRequestId { +//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> { +//! # let request_id = self.counter +//! # .fetch_add(1, Ordering::SeqCst) +//! # .to_string() +//! # .parse() +//! # .unwrap(); +//! # Some(RequestId::new(request_id)) +//! # } +//! # } +//! +//! let svc = ServiceBuilder::new() +//! // make sure to set request ids before the request reaches `TraceLayer` +//! .set_x_request_id(MyMakeRequestId::default()) +//! // log requests and responses +//! .layer( +//! TraceLayer::new_for_http() +//! .make_span_with(DefaultMakeSpan::new().include_headers(true)) +//! .on_response(DefaultOnResponse::new().include_headers(true)) +//! ) +//! // propagate the header to the response before the response reaches `TraceLayer` +//! .propagate_x_request_id() +//! .service(handler); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! # Doesn't override existing headers +//! +//! [`SetRequestId`] and [`PropagateRequestId`] wont override request ids if its already present on +//! requests or responses. Among other things, this allows other middleware to conditionally set +//! request ids and use the middleware in this module as a fallback. +//! +//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt +//! [`Uuid`]: https://crates.io/crates/uuid +//! [`Trace`]: crate::trace::Trace + +use http::{ + header::{HeaderName, HeaderValue}, + Request, Response, +}; +use pin_project_lite::pin_project; +use std::task::{ready, Context, Poll}; +use std::{future::Future, pin::Pin}; +use tower_layer::Layer; +use tower_service::Service; +use uuid::Uuid; + +pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); + +/// Trait for producing [`RequestId`]s. +/// +/// Used by [`SetRequestId`]. +pub trait MakeRequestId { + /// Try and produce a [`RequestId`] from the request. + fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId>; +} + +/// An identifier for a request. +#[derive(Debug, Clone)] +pub struct RequestId(HeaderValue); + +impl RequestId { + /// Create a new `RequestId` from a [`HeaderValue`]. + pub fn new(header_value: HeaderValue) -> Self { + Self(header_value) + } + + /// Gets a reference to the underlying [`HeaderValue`]. + pub fn header_value(&self) -> &HeaderValue { + &self.0 + } + + /// Consumes `self`, returning the underlying [`HeaderValue`]. + pub fn into_header_value(self) -> HeaderValue { + self.0 + } +} + +impl From<HeaderValue> for RequestId { + fn from(value: HeaderValue) -> Self { + Self::new(value) + } +} + +/// Set request id headers and extensions on requests. +/// +/// This layer applies the [`SetRequestId`] middleware. +/// +/// See the [module docs](self) and [`SetRequestId`] for more details. +#[derive(Debug, Clone)] +pub struct SetRequestIdLayer<M> { + header_name: HeaderName, + make_request_id: M, +} + +impl<M> SetRequestIdLayer<M> { + /// Create a new `SetRequestIdLayer`. + pub fn new(header_name: HeaderName, make_request_id: M) -> Self + where + M: MakeRequestId, + { + SetRequestIdLayer { + header_name, + make_request_id, + } + } + + /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name. + pub fn x_request_id(make_request_id: M) -> Self + where + M: MakeRequestId, + { + SetRequestIdLayer::new(X_REQUEST_ID, make_request_id) + } +} + +impl<S, M> Layer<S> for SetRequestIdLayer<M> +where + M: Clone + MakeRequestId, +{ + type Service = SetRequestId<S, M>; + + fn layer(&self, inner: S) -> Self::Service { + SetRequestId::new( + inner, + self.header_name.clone(), + self.make_request_id.clone(), + ) + } +} + +/// Set request id headers and extensions on requests. +/// +/// See the [module docs](self) for an example. +/// +/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a +/// header with the same name, then the header will be inserted. +/// +/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other +/// services can access it. +#[derive(Debug, Clone)] +pub struct SetRequestId<S, M> { + inner: S, + header_name: HeaderName, + make_request_id: M, +} + +impl<S, M> SetRequestId<S, M> { + /// Create a new `SetRequestId`. + pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self + where + M: MakeRequestId, + { + Self { + inner, + header_name, + make_request_id, + } + } + + /// Create a new `SetRequestId` that uses `x-request-id` as the header name. + pub fn x_request_id(inner: S, make_request_id: M) -> Self + where + M: MakeRequestId, + { + Self::new(inner, X_REQUEST_ID, make_request_id) + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware. + pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M> + where + M: MakeRequestId, + { + SetRequestIdLayer::new(header_name, make_request_id) + } +} + +impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + M: MakeRequestId, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + if let Some(request_id) = req.headers().get(&self.header_name) { + if req.extensions().get::<RequestId>().is_none() { + let request_id = request_id.clone(); + req.extensions_mut().insert(RequestId::new(request_id)); + } + } else if let Some(request_id) = self.make_request_id.make_request_id(&req) { + req.extensions_mut().insert(request_id.clone()); + req.headers_mut() + .insert(self.header_name.clone(), request_id.0); + } + + self.inner.call(req) + } +} + +/// Propagate request ids from requests to responses. +/// +/// This layer applies the [`PropagateRequestId`] middleware. +/// +/// See the [module docs](self) and [`PropagateRequestId`] for more details. +#[derive(Debug, Clone)] +pub struct PropagateRequestIdLayer { + header_name: HeaderName, +} + +impl PropagateRequestIdLayer { + /// Create a new `PropagateRequestIdLayer`. + pub fn new(header_name: HeaderName) -> Self { + PropagateRequestIdLayer { header_name } + } + + /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name. + pub fn x_request_id() -> Self { + Self::new(X_REQUEST_ID) + } +} + +impl<S> Layer<S> for PropagateRequestIdLayer { + type Service = PropagateRequestId<S>; + + fn layer(&self, inner: S) -> Self::Service { + PropagateRequestId::new(inner, self.header_name.clone()) + } +} + +/// Propagate request ids from requests to responses. +/// +/// See the [module docs](self) for an example. +/// +/// If the request contains a matching header that header will be applied to responses. If a +/// [`RequestId`] extension is also present it will be propagated as well. +#[derive(Debug, Clone)] +pub struct PropagateRequestId<S> { + inner: S, + header_name: HeaderName, +} + +impl<S> PropagateRequestId<S> { + /// Create a new `PropagateRequestId`. + pub fn new(inner: S, header_name: HeaderName) -> Self { + Self { inner, header_name } + } + + /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name. + pub fn x_request_id(inner: S) -> Self { + Self::new(inner, X_REQUEST_ID) + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware. + pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer { + PropagateRequestIdLayer::new(header_name) + } +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = PropagateRequestIdResponseFuture<S::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let request_id = req + .headers() + .get(&self.header_name) + .cloned() + .map(RequestId::new); + + PropagateRequestIdResponseFuture { + inner: self.inner.call(req), + header_name: self.header_name.clone(), + request_id, + } + } +} + +pin_project! { + /// Response future for [`PropagateRequestId`]. + pub struct PropagateRequestIdResponseFuture<F> { + #[pin] + inner: F, + header_name: HeaderName, + request_id: Option<RequestId>, + } +} + +impl<F, B, E> Future for PropagateRequestIdResponseFuture<F> +where + F: Future<Output = Result<Response<B>, E>>, +{ + type Output = Result<Response<B>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let mut response = ready!(this.inner.poll(cx))?; + + if let Some(current_id) = response.headers().get(&*this.header_name) { + if response.extensions().get::<RequestId>().is_none() { + let current_id = current_id.clone(); + response.extensions_mut().insert(RequestId::new(current_id)); + } + } else if let Some(request_id) = this.request_id.take() { + response + .headers_mut() + .insert(this.header_name.clone(), request_id.0.clone()); + response.extensions_mut().insert(request_id); + } + + Poll::Ready(Ok(response)) + } +} + +/// A [`MakeRequestId`] that generates `UUID`s. +#[derive(Clone, Copy, Default)] +pub struct MakeRequestUuid; + +impl MakeRequestId for MakeRequestUuid { + fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> { + let request_id = Uuid::new_v4().to_string().parse().unwrap(); + Some(RequestId::new(request_id)) + } +} + +#[cfg(test)] +mod tests { + use crate::test_helpers::Body; + use crate::ServiceBuilderExt as _; + use http::Response; + use std::{ + convert::Infallible, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + }; + use tower::{ServiceBuilder, ServiceExt}; + + #[allow(unused_imports)] + use super::*; + + #[tokio::test] + async fn basic() { + let svc = ServiceBuilder::new() + .set_x_request_id(Counter::default()) + .propagate_x_request_id() + .service_fn(handler); + + // header on response + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "0"); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "1"); + + // doesn't override if header is already there + let req = Request::builder() + .header("x-request-id", "foo") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "foo"); + + // extension propagated + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2"); + } + + #[tokio::test] + async fn other_middleware_setting_request_id() { + let svc = ServiceBuilder::new() + .override_request_header( + HeaderName::from_static("x-request-id"), + HeaderValue::from_str("foo").unwrap(), + ) + .set_x_request_id(Counter::default()) + .map_request(|request: Request<_>| { + // `set_x_request_id` should set the extension if its missing + assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo"); + request + }) + .propagate_x_request_id() + .service_fn(handler); + + let req = Request::builder() + .header( + "x-request-id", + "this-will-be-overriden-by-override_request_header-middleware", + ) + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "foo"); + assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo"); + } + + #[tokio::test] + async fn other_middleware_setting_request_id_on_response() { + let svc = ServiceBuilder::new() + .set_x_request_id(Counter::default()) + .propagate_x_request_id() + .override_response_header( + HeaderName::from_static("x-request-id"), + HeaderValue::from_str("foo").unwrap(), + ) + .service_fn(handler); + + let req = Request::builder() + .header("x-request-id", "foo") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(req).await.unwrap(); + assert_eq!(res.headers()["x-request-id"], "foo"); + assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo"); + } + + #[derive(Clone, Default)] + struct Counter(Arc<AtomicU64>); + + impl MakeRequestId for Counter { + fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> { + let id = + HeaderValue::from_str(&self.0.fetch_add(1, Ordering::SeqCst).to_string()).unwrap(); + Some(RequestId::new(id)) + } + } + + async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> { + Ok(Response::new(Body::empty())) + } + + #[tokio::test] + async fn uuid() { + let svc = ServiceBuilder::new() + .set_x_request_id(MakeRequestUuid) + .propagate_x_request_id() + .service_fn(handler); + + // header on response + let req = Request::builder().body(Body::empty()).unwrap(); + let mut res = svc.clone().oneshot(req).await.unwrap(); + let id = res.headers_mut().remove("x-request-id").unwrap(); + id.to_str().unwrap().parse::<Uuid>().unwrap(); + } +} diff --git a/vendor/tower-http/src/sensitive_headers.rs b/vendor/tower-http/src/sensitive_headers.rs new file mode 100644 index 00000000..3bd081db --- /dev/null +++ b/vendor/tower-http/src/sensitive_headers.rs @@ -0,0 +1,448 @@ +//! Middlewares that mark headers as [sensitive]. +//! +//! [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +//! +//! # Example +//! +//! ``` +//! use tower_http::sensitive_headers::SetSensitiveHeadersLayer; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response, header::AUTHORIZATION}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use std::{iter::once, convert::Infallible}; +//! +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut service = ServiceBuilder::new() +//! // Mark the `Authorization` header as sensitive so it doesn't show in logs +//! // +//! // `SetSensitiveHeadersLayer` will mark the header as sensitive on both the +//! // request and response. +//! // +//! // The middleware is constructed from an iterator of headers to easily mark +//! // multiple headers at once. +//! .layer(SetSensitiveHeadersLayer::new(once(AUTHORIZATION))) +//! .service(service_fn(handle)); +//! +//! // Call the service. +//! let response = service +//! .ready() +//! .await? +//! .call(Request::new(Full::default())) +//! .await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! Its important to think about the order in which requests and responses arrive at your +//! middleware. For example to hide headers both on requests and responses when using +//! [`TraceLayer`] you have to apply [`SetSensitiveRequestHeadersLayer`] before [`TraceLayer`] +//! and [`SetSensitiveResponseHeadersLayer`] afterwards. +//! +//! ``` +//! use tower_http::{ +//! trace::TraceLayer, +//! sensitive_headers::{ +//! SetSensitiveRequestHeadersLayer, +//! SetSensitiveResponseHeadersLayer, +//! }, +//! }; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::header; +//! use std::sync::Arc; +//! # use http::{Request, Response}; +//! # use bytes::Bytes; +//! # use http_body_util::Full; +//! # use std::convert::Infallible; +//! # async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # Ok(Response::new(Full::default())) +//! # } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let headers: Arc<[_]> = Arc::new([ +//! header::AUTHORIZATION, +//! header::PROXY_AUTHORIZATION, +//! header::COOKIE, +//! header::SET_COOKIE, +//! ]); +//! +//! let service = ServiceBuilder::new() +//! .layer(SetSensitiveRequestHeadersLayer::from_shared(Arc::clone(&headers))) +//! .layer(TraceLayer::new_for_http()) +//! .layer(SetSensitiveResponseHeadersLayer::from_shared(headers)) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! [`TraceLayer`]: crate::trace::TraceLayer + +use http::{header::HeaderName, Request, Response}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Mark headers as [sensitive] on both requests and responses. +/// +/// Produces [`SetSensitiveHeaders`] services. +/// +/// See the [module docs](crate::sensitive_headers) for more details. +/// +/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +#[derive(Clone, Debug)] +pub struct SetSensitiveHeadersLayer { + headers: Arc<[HeaderName]>, +} + +impl SetSensitiveHeadersLayer { + /// Create a new [`SetSensitiveHeadersLayer`]. + pub fn new<I>(headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + let headers = headers.into_iter().collect::<Vec<_>>(); + Self::from_shared(headers.into()) + } + + /// Create a new [`SetSensitiveHeadersLayer`] from a shared slice of headers. + pub fn from_shared(headers: Arc<[HeaderName]>) -> Self { + Self { headers } + } +} + +impl<S> Layer<S> for SetSensitiveHeadersLayer { + type Service = SetSensitiveHeaders<S>; + + fn layer(&self, inner: S) -> Self::Service { + SetSensitiveRequestHeaders::from_shared( + SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()), + self.headers.clone(), + ) + } +} + +/// Mark headers as [sensitive] on both requests and responses. +/// +/// See the [module docs](crate::sensitive_headers) for more details. +/// +/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +pub type SetSensitiveHeaders<S> = SetSensitiveRequestHeaders<SetSensitiveResponseHeaders<S>>; + +/// Mark request headers as [sensitive]. +/// +/// Produces [`SetSensitiveRequestHeaders`] services. +/// +/// See the [module docs](crate::sensitive_headers) for more details. +/// +/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +#[derive(Clone, Debug)] +pub struct SetSensitiveRequestHeadersLayer { + headers: Arc<[HeaderName]>, +} + +impl SetSensitiveRequestHeadersLayer { + /// Create a new [`SetSensitiveRequestHeadersLayer`]. + pub fn new<I>(headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + let headers = headers.into_iter().collect::<Vec<_>>(); + Self::from_shared(headers.into()) + } + + /// Create a new [`SetSensitiveRequestHeadersLayer`] from a shared slice of headers. + pub fn from_shared(headers: Arc<[HeaderName]>) -> Self { + Self { headers } + } +} + +impl<S> Layer<S> for SetSensitiveRequestHeadersLayer { + type Service = SetSensitiveRequestHeaders<S>; + + fn layer(&self, inner: S) -> Self::Service { + SetSensitiveRequestHeaders { + inner, + headers: self.headers.clone(), + } + } +} + +/// Mark request headers as [sensitive]. +/// +/// See the [module docs](crate::sensitive_headers) for more details. +/// +/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +#[derive(Clone, Debug)] +pub struct SetSensitiveRequestHeaders<S> { + inner: S, + headers: Arc<[HeaderName]>, +} + +impl<S> SetSensitiveRequestHeaders<S> { + /// Create a new [`SetSensitiveRequestHeaders`]. + pub fn new<I>(inner: S, headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + let headers = headers.into_iter().collect::<Vec<_>>(); + Self::from_shared(inner, headers.into()) + } + + /// Create a new [`SetSensitiveRequestHeaders`] from a shared slice of headers. + pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self { + Self { inner, headers } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `SetSensitiveRequestHeaders` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer<I>(headers: I) -> SetSensitiveRequestHeadersLayer + where + I: IntoIterator<Item = HeaderName>, + { + SetSensitiveRequestHeadersLayer::new(headers) + } +} + +impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for SetSensitiveRequestHeaders<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + let headers = req.headers_mut(); + for header in &*self.headers { + if let http::header::Entry::Occupied(mut entry) = headers.entry(header) { + for value in entry.iter_mut() { + value.set_sensitive(true); + } + } + } + + self.inner.call(req) + } +} + +/// Mark response headers as [sensitive]. +/// +/// Produces [`SetSensitiveResponseHeaders`] services. +/// +/// See the [module docs](crate::sensitive_headers) for more details. +/// +/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +#[derive(Clone, Debug)] +pub struct SetSensitiveResponseHeadersLayer { + headers: Arc<[HeaderName]>, +} + +impl SetSensitiveResponseHeadersLayer { + /// Create a new [`SetSensitiveResponseHeadersLayer`]. + pub fn new<I>(headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + let headers = headers.into_iter().collect::<Vec<_>>(); + Self::from_shared(headers.into()) + } + + /// Create a new [`SetSensitiveResponseHeadersLayer`] from a shared slice of headers. + pub fn from_shared(headers: Arc<[HeaderName]>) -> Self { + Self { headers } + } +} + +impl<S> Layer<S> for SetSensitiveResponseHeadersLayer { + type Service = SetSensitiveResponseHeaders<S>; + + fn layer(&self, inner: S) -> Self::Service { + SetSensitiveResponseHeaders { + inner, + headers: self.headers.clone(), + } + } +} + +/// Mark response headers as [sensitive]. +/// +/// See the [module docs](crate::sensitive_headers) for more details. +/// +/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive +#[derive(Clone, Debug)] +pub struct SetSensitiveResponseHeaders<S> { + inner: S, + headers: Arc<[HeaderName]>, +} + +impl<S> SetSensitiveResponseHeaders<S> { + /// Create a new [`SetSensitiveResponseHeaders`]. + pub fn new<I>(inner: S, headers: I) -> Self + where + I: IntoIterator<Item = HeaderName>, + { + let headers = headers.into_iter().collect::<Vec<_>>(); + Self::from_shared(inner, headers.into()) + } + + /// Create a new [`SetSensitiveResponseHeaders`] from a shared slice of headers. + pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self { + Self { inner, headers } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `SetSensitiveResponseHeaders` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer<I>(headers: I) -> SetSensitiveResponseHeadersLayer + where + I: IntoIterator<Item = HeaderName>, + { + SetSensitiveResponseHeadersLayer::new(headers) + } +} + +impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for SetSensitiveResponseHeaders<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = SetSensitiveResponseHeadersResponseFuture<S::Future>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + SetSensitiveResponseHeadersResponseFuture { + future: self.inner.call(req), + headers: self.headers.clone(), + } + } +} + +pin_project! { + /// Response future for [`SetSensitiveResponseHeaders`]. + #[derive(Debug)] + pub struct SetSensitiveResponseHeadersResponseFuture<F> { + #[pin] + future: F, + headers: Arc<[HeaderName]>, + } +} + +impl<F, ResBody, E> Future for SetSensitiveResponseHeadersResponseFuture<F> +where + F: Future<Output = Result<Response<ResBody>, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let mut res = ready!(this.future.poll(cx)?); + + let headers = res.headers_mut(); + for header in &**this.headers { + if let http::header::Entry::Occupied(mut entry) = headers.entry(header) { + for value in entry.iter_mut() { + value.set_sensitive(true); + } + } + } + + Poll::Ready(Ok(res)) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use http::header; + use tower::{ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn multiple_value_header() { + async fn response_set_cookie(req: http::Request<()>) -> Result<http::Response<()>, ()> { + let mut iter = req.headers().get_all(header::COOKIE).iter().peekable(); + + assert!(iter.peek().is_some()); + + for value in iter { + assert!(value.is_sensitive()) + } + + let mut resp = http::Response::new(()); + resp.headers_mut().append( + header::CONTENT_TYPE, + http::HeaderValue::from_static("text/html"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-1"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-2"), + ); + resp.headers_mut().append( + header::SET_COOKIE, + http::HeaderValue::from_static("cookie-3"), + ); + Ok(resp) + } + + let mut service = ServiceBuilder::new() + .layer(SetSensitiveRequestHeadersLayer::new(vec![header::COOKIE])) + .layer(SetSensitiveResponseHeadersLayer::new(vec![ + header::SET_COOKIE, + ])) + .service_fn(response_set_cookie); + + let mut req = http::Request::new(()); + req.headers_mut() + .append(header::COOKIE, http::HeaderValue::from_static("cookie+1")); + req.headers_mut() + .append(header::COOKIE, http::HeaderValue::from_static("cookie+2")); + + let resp = service.ready().await.unwrap().call(req).await.unwrap(); + + assert!(!resp + .headers() + .get(header::CONTENT_TYPE) + .unwrap() + .is_sensitive()); + + let mut iter = resp.headers().get_all(header::SET_COOKIE).iter().peekable(); + + assert!(iter.peek().is_some()); + + for value in iter { + assert!(value.is_sensitive()) + } + } +} diff --git a/vendor/tower-http/src/service_ext.rs b/vendor/tower-http/src/service_ext.rs new file mode 100644 index 00000000..8973d8a4 --- /dev/null +++ b/vendor/tower-http/src/service_ext.rs @@ -0,0 +1,442 @@ +#[allow(unused_imports)] +use http::header::HeaderName; + +/// Extension trait that adds methods to any [`Service`] for adding middleware from +/// tower-http. +/// +/// [`Service`]: tower::Service +#[cfg(feature = "util")] +// ^ work around rustdoc not inferring doc(cfg)s for cfg's from surrounding scopes +pub trait ServiceExt { + /// Propagate a header from the request to the response. + /// + /// See [`tower_http::propagate_header`] for more details. + /// + /// [`tower_http::propagate_header`]: crate::propagate_header + #[cfg(feature = "propagate-header")] + fn propagate_header(self, header: HeaderName) -> crate::propagate_header::PropagateHeader<Self> + where + Self: Sized, + { + crate::propagate_header::PropagateHeader::new(self, header) + } + + /// Add some shareable value to [request extensions]. + /// + /// See [`tower_http::add_extension`] for more details. + /// + /// [`tower_http::add_extension`]: crate::add_extension + /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html + #[cfg(feature = "add-extension")] + fn add_extension<T>(self, value: T) -> crate::add_extension::AddExtension<Self, T> + where + Self: Sized, + { + crate::add_extension::AddExtension::new(self, value) + } + + /// Apply a transformation to the request body. + /// + /// See [`tower_http::map_request_body`] for more details. + /// + /// [`tower_http::map_request_body`]: crate::map_request_body + #[cfg(feature = "map-request-body")] + fn map_request_body<F>(self, f: F) -> crate::map_request_body::MapRequestBody<Self, F> + where + Self: Sized, + { + crate::map_request_body::MapRequestBody::new(self, f) + } + + /// Apply a transformation to the response body. + /// + /// See [`tower_http::map_response_body`] for more details. + /// + /// [`tower_http::map_response_body`]: crate::map_response_body + #[cfg(feature = "map-response-body")] + fn map_response_body<F>(self, f: F) -> crate::map_response_body::MapResponseBody<Self, F> + where + Self: Sized, + { + crate::map_response_body::MapResponseBody::new(self, f) + } + + /// Compresses response bodies. + /// + /// See [`tower_http::compression`] for more details. + /// + /// [`tower_http::compression`]: crate::compression + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + ))] + fn compression(self) -> crate::compression::Compression<Self> + where + Self: Sized, + { + crate::compression::Compression::new(self) + } + + /// Decompress response bodies. + /// + /// See [`tower_http::decompression`] for more details. + /// + /// [`tower_http::decompression`]: crate::decompression + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", + ))] + fn decompression(self) -> crate::decompression::Decompression<Self> + where + Self: Sized, + { + crate::decompression::Decompression::new(self) + } + + /// High level tracing that classifies responses using HTTP status codes. + /// + /// This method does not support customizing the output, to do that use [`TraceLayer`] + /// instead. + /// + /// See [`tower_http::trace`] for more details. + /// + /// [`tower_http::trace`]: crate::trace + /// [`TraceLayer`]: crate::trace::TraceLayer + #[cfg(feature = "trace")] + fn trace_for_http(self) -> crate::trace::Trace<Self, crate::trace::HttpMakeClassifier> + where + Self: Sized, + { + crate::trace::Trace::new_for_http(self) + } + + /// High level tracing that classifies responses using gRPC headers. + /// + /// This method does not support customizing the output, to do that use [`TraceLayer`] + /// instead. + /// + /// See [`tower_http::trace`] for more details. + /// + /// [`tower_http::trace`]: crate::trace + /// [`TraceLayer`]: crate::trace::TraceLayer + #[cfg(feature = "trace")] + fn trace_for_grpc(self) -> crate::trace::Trace<Self, crate::trace::GrpcMakeClassifier> + where + Self: Sized, + { + crate::trace::Trace::new_for_grpc(self) + } + + /// Follow redirect resposes using the [`Standard`] policy. + /// + /// See [`tower_http::follow_redirect`] for more details. + /// + /// [`tower_http::follow_redirect`]: crate::follow_redirect + /// [`Standard`]: crate::follow_redirect::policy::Standard + #[cfg(feature = "follow-redirect")] + fn follow_redirects( + self, + ) -> crate::follow_redirect::FollowRedirect<Self, crate::follow_redirect::policy::Standard> + where + Self: Sized, + { + crate::follow_redirect::FollowRedirect::new(self) + } + + /// Mark headers as [sensitive] on both requests and responses. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_headers( + self, + headers: impl IntoIterator<Item = HeaderName>, + ) -> crate::sensitive_headers::SetSensitiveHeaders<Self> + where + Self: Sized, + { + use tower_layer::Layer as _; + crate::sensitive_headers::SetSensitiveHeadersLayer::new(headers).layer(self) + } + + /// Mark headers as [sensitive] on requests. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_request_headers( + self, + headers: impl IntoIterator<Item = HeaderName>, + ) -> crate::sensitive_headers::SetSensitiveRequestHeaders<Self> + where + Self: Sized, + { + crate::sensitive_headers::SetSensitiveRequestHeaders::new(self, headers) + } + + /// Mark headers as [sensitive] on responses. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_response_headers( + self, + headers: impl IntoIterator<Item = HeaderName>, + ) -> crate::sensitive_headers::SetSensitiveResponseHeaders<Self> + where + Self: Sized, + { + crate::sensitive_headers::SetSensitiveResponseHeaders::new(self, headers) + } + + /// Insert a header into the request. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn override_request_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetRequestHeader<Self, M> + where + Self: Sized, + { + crate::set_header::SetRequestHeader::overriding(self, header_name, make) + } + + /// Append a header into the request. + /// + /// If previous values exist, the header will have multiple values. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn append_request_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetRequestHeader<Self, M> + where + Self: Sized, + { + crate::set_header::SetRequestHeader::appending(self, header_name, make) + } + + /// Insert a header into the request, if the header is not already present. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn insert_request_header_if_not_present<M>( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetRequestHeader<Self, M> + where + Self: Sized, + { + crate::set_header::SetRequestHeader::if_not_present(self, header_name, make) + } + + /// Insert a header into the response. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn override_response_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetResponseHeader<Self, M> + where + Self: Sized, + { + crate::set_header::SetResponseHeader::overriding(self, header_name, make) + } + + /// Append a header into the response. + /// + /// If previous values exist, the header will have multiple values. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn append_response_header<M>( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetResponseHeader<Self, M> + where + Self: Sized, + { + crate::set_header::SetResponseHeader::appending(self, header_name, make) + } + + /// Insert a header into the response, if the header is not already present. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn insert_response_header_if_not_present<M>( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetResponseHeader<Self, M> + where + Self: Sized, + { + crate::set_header::SetResponseHeader::if_not_present(self, header_name, make) + } + + /// Add request id header and extension. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn set_request_id<M>( + self, + header_name: HeaderName, + make_request_id: M, + ) -> crate::request_id::SetRequestId<Self, M> + where + Self: Sized, + M: crate::request_id::MakeRequestId, + { + crate::request_id::SetRequestId::new(self, header_name, make_request_id) + } + + /// Add request id header and extension, using `x-request-id` as the header name. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn set_x_request_id<M>(self, make_request_id: M) -> crate::request_id::SetRequestId<Self, M> + where + Self: Sized, + M: crate::request_id::MakeRequestId, + { + self.set_request_id(crate::request_id::X_REQUEST_ID, make_request_id) + } + + /// Propgate request ids from requests to responses. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn propagate_request_id( + self, + header_name: HeaderName, + ) -> crate::request_id::PropagateRequestId<Self> + where + Self: Sized, + { + crate::request_id::PropagateRequestId::new(self, header_name) + } + + /// Propgate request ids from requests to responses, using `x-request-id` as the header name. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn propagate_x_request_id(self) -> crate::request_id::PropagateRequestId<Self> + where + Self: Sized, + { + self.propagate_request_id(crate::request_id::X_REQUEST_ID) + } + + /// Catch panics and convert them into `500 Internal Server` responses. + /// + /// See [`tower_http::catch_panic`] for more details. + /// + /// [`tower_http::catch_panic`]: crate::catch_panic + #[cfg(feature = "catch-panic")] + fn catch_panic( + self, + ) -> crate::catch_panic::CatchPanic<Self, crate::catch_panic::DefaultResponseForPanic> + where + Self: Sized, + { + crate::catch_panic::CatchPanic::new(self) + } + + /// Intercept requests with over-sized payloads and convert them into + /// `413 Payload Too Large` responses. + /// + /// See [`tower_http::limit`] for more details. + /// + /// [`tower_http::limit`]: crate::limit + #[cfg(feature = "limit")] + fn request_body_limit(self, limit: usize) -> crate::limit::RequestBodyLimit<Self> + where + Self: Sized, + { + crate::limit::RequestBodyLimit::new(self, limit) + } + + /// Remove trailing slashes from paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn trim_trailing_slash(self) -> crate::normalize_path::NormalizePath<Self> + where + Self: Sized, + { + crate::normalize_path::NormalizePath::trim_trailing_slash(self) + } + + /// Append trailing slash to paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn append_trailing_slash(self) -> crate::normalize_path::NormalizePath<Self> + where + Self: Sized, + { + crate::normalize_path::NormalizePath::append_trailing_slash(self) + } +} + +impl<T> ServiceExt for T {} + +#[cfg(all(test, feature = "fs", feature = "add-extension"))] +mod tests { + use super::ServiceExt; + use crate::services; + + #[allow(dead_code)] + fn test_type_inference() { + let _svc = services::fs::ServeDir::new(".").add_extension("&'static str"); + } +} diff --git a/vendor/tower-http/src/services/fs/mod.rs b/vendor/tower-http/src/services/fs/mod.rs new file mode 100644 index 00000000..c23f9619 --- /dev/null +++ b/vendor/tower-http/src/services/fs/mod.rs @@ -0,0 +1,79 @@ +//! File system related services. + +use bytes::Bytes; +use futures_core::Stream; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, Take}; +use tokio_util::io::ReaderStream; + +mod serve_dir; +mod serve_file; + +pub use self::{ + serve_dir::{ + future::ResponseFuture as ServeFileSystemResponseFuture, + DefaultServeDirFallback, + // The response body and future are used for both ServeDir and ServeFile + ResponseBody as ServeFileSystemResponseBody, + ServeDir, + }, + serve_file::ServeFile, +}; + +pin_project! { + // NOTE: This could potentially be upstreamed to `http-body`. + /// Adapter that turns an [`impl AsyncRead`][tokio::io::AsyncRead] to an [`impl Body`][http_body::Body]. + #[derive(Debug)] + pub struct AsyncReadBody<T> { + #[pin] + reader: ReaderStream<T>, + } +} + +impl<T> AsyncReadBody<T> +where + T: AsyncRead, +{ + /// Create a new [`AsyncReadBody`] wrapping the given reader, + /// with a specific read buffer capacity + fn with_capacity(read: T, capacity: usize) -> Self { + Self { + reader: ReaderStream::with_capacity(read, capacity), + } + } + + fn with_capacity_limited( + read: T, + capacity: usize, + max_read_bytes: u64, + ) -> AsyncReadBody<Take<T>> { + AsyncReadBody { + reader: ReaderStream::with_capacity(read.take(max_read_bytes), capacity), + } + } +} + +impl<T> Body for AsyncReadBody<T> +where + T: AsyncRead, +{ + type Data = Bytes; + type Error = io::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + match std::task::ready!(self.project().reader.poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk)))), + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} diff --git a/vendor/tower-http/src/services/fs/serve_dir/future.rs b/vendor/tower-http/src/services/fs/serve_dir/future.rs new file mode 100644 index 00000000..6d96204d --- /dev/null +++ b/vendor/tower-http/src/services/fs/serve_dir/future.rs @@ -0,0 +1,332 @@ +use super::{ + open_file::{FileOpened, FileRequestExtent, OpenFileOutput}, + DefaultServeDirFallback, ResponseBody, +}; +use crate::{ + body::UnsyncBoxBody, content_encoding::Encoding, services::fs::AsyncReadBody, BoxError, +}; +use bytes::Bytes; +use futures_core::future::BoxFuture; +use futures_util::future::{FutureExt, TryFutureExt}; +use http::{ + header::{self, ALLOW}, + HeaderValue, Request, Response, StatusCode, +}; +use http_body_util::{BodyExt, Empty, Full}; +use pin_project_lite::pin_project; +use std::{ + convert::Infallible, + future::Future, + io, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_service::Service; + +pin_project! { + /// Response future of [`ServeDir::try_call()`][`super::ServeDir::try_call()`]. + pub struct ResponseFuture<ReqBody, F = DefaultServeDirFallback> { + #[pin] + pub(super) inner: ResponseFutureInner<ReqBody, F>, + } +} + +impl<ReqBody, F> ResponseFuture<ReqBody, F> { + pub(super) fn open_file_future( + future: BoxFuture<'static, io::Result<OpenFileOutput>>, + fallback_and_request: Option<(F, Request<ReqBody>)>, + ) -> Self { + Self { + inner: ResponseFutureInner::OpenFileFuture { + future, + fallback_and_request, + }, + } + } + + pub(super) fn invalid_path(fallback_and_request: Option<(F, Request<ReqBody>)>) -> Self { + Self { + inner: ResponseFutureInner::InvalidPath { + fallback_and_request, + }, + } + } + + pub(super) fn method_not_allowed() -> Self { + Self { + inner: ResponseFutureInner::MethodNotAllowed, + } + } +} + +pin_project! { + #[project = ResponseFutureInnerProj] + pub(super) enum ResponseFutureInner<ReqBody, F> { + OpenFileFuture { + #[pin] + future: BoxFuture<'static, io::Result<OpenFileOutput>>, + fallback_and_request: Option<(F, Request<ReqBody>)>, + }, + FallbackFuture { + future: BoxFuture<'static, Result<Response<ResponseBody>, Infallible>>, + }, + InvalidPath { + fallback_and_request: Option<(F, Request<ReqBody>)>, + }, + MethodNotAllowed, + } +} + +impl<F, ReqBody, ResBody> Future for ResponseFuture<ReqBody, F> +where + F: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Clone, + F::Future: Send + 'static, + ResBody: http_body::Body<Data = Bytes> + Send + 'static, + ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>, +{ + type Output = io::Result<Response<ResponseBody>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let mut this = self.as_mut().project(); + + let new_state = match this.inner.as_mut().project() { + ResponseFutureInnerProj::OpenFileFuture { + future: open_file_future, + fallback_and_request, + } => match ready!(open_file_future.poll(cx)) { + Ok(OpenFileOutput::FileOpened(file_output)) => { + break Poll::Ready(Ok(build_response(*file_output))); + } + + Ok(OpenFileOutput::Redirect { location }) => { + let mut res = response_with_status(StatusCode::TEMPORARY_REDIRECT); + res.headers_mut().insert(http::header::LOCATION, location); + break Poll::Ready(Ok(res)); + } + + Ok(OpenFileOutput::FileNotFound) => { + if let Some((mut fallback, request)) = fallback_and_request.take() { + call_fallback(&mut fallback, request) + } else { + break Poll::Ready(Ok(not_found())); + } + } + + Ok(OpenFileOutput::PreconditionFailed) => { + break Poll::Ready(Ok(response_with_status( + StatusCode::PRECONDITION_FAILED, + ))); + } + + Ok(OpenFileOutput::NotModified) => { + break Poll::Ready(Ok(response_with_status(StatusCode::NOT_MODIFIED))); + } + + Ok(OpenFileOutput::InvalidRedirectUri) => { + break Poll::Ready(Ok(response_with_status( + StatusCode::INTERNAL_SERVER_ERROR, + ))); + } + + Err(err) => { + #[cfg(unix)] + // 20 = libc::ENOTDIR => "not a directory + // when `io_error_more` landed, this can be changed + // to checking for `io::ErrorKind::NotADirectory`. + // https://github.com/rust-lang/rust/issues/86442 + let error_is_not_a_directory = err.raw_os_error() == Some(20); + #[cfg(not(unix))] + let error_is_not_a_directory = false; + + if matches!( + err.kind(), + io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied + ) || error_is_not_a_directory + { + if let Some((mut fallback, request)) = fallback_and_request.take() { + call_fallback(&mut fallback, request) + } else { + break Poll::Ready(Ok(not_found())); + } + } else { + break Poll::Ready(Err(err)); + } + } + }, + + ResponseFutureInnerProj::FallbackFuture { future } => { + break Pin::new(future).poll(cx).map_err(|err| match err {}) + } + + ResponseFutureInnerProj::InvalidPath { + fallback_and_request, + } => { + if let Some((mut fallback, request)) = fallback_and_request.take() { + call_fallback(&mut fallback, request) + } else { + break Poll::Ready(Ok(not_found())); + } + } + + ResponseFutureInnerProj::MethodNotAllowed => { + let mut res = response_with_status(StatusCode::METHOD_NOT_ALLOWED); + res.headers_mut() + .insert(ALLOW, HeaderValue::from_static("GET,HEAD")); + break Poll::Ready(Ok(res)); + } + }; + + this.inner.set(new_state); + } + } +} + +fn response_with_status(status: StatusCode) -> Response<ResponseBody> { + Response::builder() + .status(status) + .body(empty_body()) + .unwrap() +} + +fn not_found() -> Response<ResponseBody> { + response_with_status(StatusCode::NOT_FOUND) +} + +pub(super) fn call_fallback<F, B, FResBody>( + fallback: &mut F, + req: Request<B>, +) -> ResponseFutureInner<B, F> +where + F: Service<Request<B>, Response = Response<FResBody>, Error = Infallible> + Clone, + F::Future: Send + 'static, + FResBody: http_body::Body<Data = Bytes> + Send + 'static, + FResBody::Error: Into<BoxError>, +{ + let future = fallback + .call(req) + .map_ok(|response| { + response + .map(|body| { + UnsyncBoxBody::new( + body.map_err(|err| match err.into().downcast::<io::Error>() { + Ok(err) => *err, + Err(err) => io::Error::new(io::ErrorKind::Other, err), + }) + .boxed_unsync(), + ) + }) + .map(ResponseBody::new) + }) + .boxed(); + + ResponseFutureInner::FallbackFuture { future } +} + +fn build_response(output: FileOpened) -> Response<ResponseBody> { + let (maybe_file, size) = match output.extent { + FileRequestExtent::Full(file, meta) => (Some(file), meta.len()), + FileRequestExtent::Head(meta) => (None, meta.len()), + }; + + let mut builder = Response::builder() + .header(header::CONTENT_TYPE, output.mime_header_value) + .header(header::ACCEPT_RANGES, "bytes"); + + if let Some(encoding) = output + .maybe_encoding + .filter(|encoding| *encoding != Encoding::Identity) + { + builder = builder.header(header::CONTENT_ENCODING, encoding.into_header_value()); + } + + if let Some(last_modified) = output.last_modified { + builder = builder.header(header::LAST_MODIFIED, last_modified.0.to_string()); + } + + match output.maybe_range { + Some(Ok(ranges)) => { + if let Some(range) = ranges.first() { + if ranges.len() > 1 { + builder + .header(header::CONTENT_RANGE, format!("bytes */{}", size)) + .status(StatusCode::RANGE_NOT_SATISFIABLE) + .body(body_from_bytes(Bytes::from( + "Cannot serve multipart range requests", + ))) + .unwrap() + } else { + let body = if let Some(file) = maybe_file { + let range_size = range.end() - range.start() + 1; + ResponseBody::new(UnsyncBoxBody::new( + AsyncReadBody::with_capacity_limited( + file, + output.chunk_size, + range_size, + ) + .boxed_unsync(), + )) + } else { + empty_body() + }; + + let content_length = if size == 0 { + 0 + } else { + range.end() - range.start() + 1 + }; + + builder + .header( + header::CONTENT_RANGE, + format!("bytes {}-{}/{}", range.start(), range.end(), size), + ) + .header(header::CONTENT_LENGTH, content_length) + .status(StatusCode::PARTIAL_CONTENT) + .body(body) + .unwrap() + } + } else { + builder + .header(header::CONTENT_RANGE, format!("bytes */{}", size)) + .status(StatusCode::RANGE_NOT_SATISFIABLE) + .body(body_from_bytes(Bytes::from( + "No range found after parsing range header, please file an issue", + ))) + .unwrap() + } + } + + Some(Err(_)) => builder + .header(header::CONTENT_RANGE, format!("bytes */{}", size)) + .status(StatusCode::RANGE_NOT_SATISFIABLE) + .body(empty_body()) + .unwrap(), + + // Not a range request + None => { + let body = if let Some(file) = maybe_file { + ResponseBody::new(UnsyncBoxBody::new( + AsyncReadBody::with_capacity(file, output.chunk_size).boxed_unsync(), + )) + } else { + empty_body() + }; + + builder + .header(header::CONTENT_LENGTH, size.to_string()) + .body(body) + .unwrap() + } + } +} + +fn body_from_bytes(bytes: Bytes) -> ResponseBody { + let body = Full::from(bytes).map_err(|err| match err {}).boxed_unsync(); + ResponseBody::new(UnsyncBoxBody::new(body)) +} + +fn empty_body() -> ResponseBody { + let body = Empty::new().map_err(|err| match err {}).boxed_unsync(); + ResponseBody::new(UnsyncBoxBody::new(body)) +} diff --git a/vendor/tower-http/src/services/fs/serve_dir/headers.rs b/vendor/tower-http/src/services/fs/serve_dir/headers.rs new file mode 100644 index 00000000..e9e80907 --- /dev/null +++ b/vendor/tower-http/src/services/fs/serve_dir/headers.rs @@ -0,0 +1,45 @@ +use http::header::HeaderValue; +use httpdate::HttpDate; +use std::time::SystemTime; + +pub(super) struct LastModified(pub(super) HttpDate); + +impl From<SystemTime> for LastModified { + fn from(time: SystemTime) -> Self { + LastModified(time.into()) + } +} + +pub(super) struct IfModifiedSince(HttpDate); + +impl IfModifiedSince { + /// Check if the supplied time means the resource has been modified. + pub(super) fn is_modified(&self, last_modified: &LastModified) -> bool { + self.0 < last_modified.0 + } + + /// convert a header value into a IfModifiedSince, invalid values are silentely ignored + pub(super) fn from_header_value(value: &HeaderValue) -> Option<IfModifiedSince> { + std::str::from_utf8(value.as_bytes()) + .ok() + .and_then(|value| httpdate::parse_http_date(value).ok()) + .map(|time| IfModifiedSince(time.into())) + } +} + +pub(super) struct IfUnmodifiedSince(HttpDate); + +impl IfUnmodifiedSince { + /// Check if the supplied time passes the precondtion. + pub(super) fn precondition_passes(&self, last_modified: &LastModified) -> bool { + self.0 >= last_modified.0 + } + + /// Convert a header value into a IfModifiedSince, invalid values are silentely ignored + pub(super) fn from_header_value(value: &HeaderValue) -> Option<IfUnmodifiedSince> { + std::str::from_utf8(value.as_bytes()) + .ok() + .and_then(|value| httpdate::parse_http_date(value).ok()) + .map(|time| IfUnmodifiedSince(time.into())) + } +} diff --git a/vendor/tower-http/src/services/fs/serve_dir/mod.rs b/vendor/tower-http/src/services/fs/serve_dir/mod.rs new file mode 100644 index 00000000..61b956d1 --- /dev/null +++ b/vendor/tower-http/src/services/fs/serve_dir/mod.rs @@ -0,0 +1,541 @@ +use self::future::ResponseFuture; +use crate::{ + body::UnsyncBoxBody, + content_encoding::{encodings, SupportedEncodings}, + set_status::SetStatus, +}; +use bytes::Bytes; +use futures_util::FutureExt; +use http::{header, HeaderValue, Method, Request, Response, StatusCode}; +use http_body_util::{BodyExt, Empty}; +use percent_encoding::percent_decode; +use std::{ + convert::Infallible, + io, + path::{Component, Path, PathBuf}, + task::{Context, Poll}, +}; +use tower_service::Service; + +pub(crate) mod future; +mod headers; +mod open_file; + +#[cfg(test)] +mod tests; + +// default capacity 64KiB +const DEFAULT_CAPACITY: usize = 65536; + +/// Service that serves files from a given directory and all its sub directories. +/// +/// The `Content-Type` will be guessed from the file extension. +/// +/// An empty response with status `404 Not Found` will be returned if: +/// +/// - The file doesn't exist +/// - Any segment of the path contains `..` +/// - Any segment of the path contains a backslash +/// - On unix, any segment of the path referenced as directory is actually an +/// existing file (`/file.html/something`) +/// - We don't have necessary permissions to read the file +/// +/// # Example +/// +/// ``` +/// use tower_http::services::ServeDir; +/// +/// // This will serve files in the "assets" directory and +/// // its subdirectories +/// let service = ServeDir::new("assets"); +/// ``` +#[derive(Clone, Debug)] +pub struct ServeDir<F = DefaultServeDirFallback> { + base: PathBuf, + buf_chunk_size: usize, + precompressed_variants: Option<PrecompressedVariants>, + // This is used to specialise implementation for + // single files + variant: ServeVariant, + fallback: Option<F>, + call_fallback_on_method_not_allowed: bool, +} + +impl ServeDir<DefaultServeDirFallback> { + /// Create a new [`ServeDir`]. + pub fn new<P>(path: P) -> Self + where + P: AsRef<Path>, + { + let mut base = PathBuf::from("."); + base.push(path.as_ref()); + + Self { + base, + buf_chunk_size: DEFAULT_CAPACITY, + precompressed_variants: None, + variant: ServeVariant::Directory { + append_index_html_on_directories: true, + }, + fallback: None, + call_fallback_on_method_not_allowed: false, + } + } + + pub(crate) fn new_single_file<P>(path: P, mime: HeaderValue) -> Self + where + P: AsRef<Path>, + { + Self { + base: path.as_ref().to_owned(), + buf_chunk_size: DEFAULT_CAPACITY, + precompressed_variants: None, + variant: ServeVariant::SingleFile { mime }, + fallback: None, + call_fallback_on_method_not_allowed: false, + } + } +} + +impl<F> ServeDir<F> { + /// If the requested path is a directory append `index.html`. + /// + /// This is useful for static sites. + /// + /// Defaults to `true`. + pub fn append_index_html_on_directories(mut self, append: bool) -> Self { + match &mut self.variant { + ServeVariant::Directory { + append_index_html_on_directories, + } => { + *append_index_html_on_directories = append; + self + } + ServeVariant::SingleFile { mime: _ } => self, + } + } + + /// Set a specific read buffer chunk size. + /// + /// The default capacity is 64kb. + pub fn with_buf_chunk_size(mut self, chunk_size: usize) -> Self { + self.buf_chunk_size = chunk_size; + self + } + + /// Informs the service that it should also look for a precompressed gzip + /// version of _any_ file in the directory. + /// + /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, + /// a client with an `Accept-Encoding` header that allows the gzip encoding + /// will receive the file `dir/foo.txt.gz` instead of `dir/foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the directory. Different precompressed variants can be combined. + pub fn precompressed_gzip(mut self) -> Self { + self.precompressed_variants + .get_or_insert(Default::default()) + .gzip = true; + self + } + + /// Informs the service that it should also look for a precompressed brotli + /// version of _any_ file in the directory. + /// + /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, + /// a client with an `Accept-Encoding` header that allows the brotli encoding + /// will receive the file `dir/foo.txt.br` instead of `dir/foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the directory. Different precompressed variants can be combined. + pub fn precompressed_br(mut self) -> Self { + self.precompressed_variants + .get_or_insert(Default::default()) + .br = true; + self + } + + /// Informs the service that it should also look for a precompressed deflate + /// version of _any_ file in the directory. + /// + /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, + /// a client with an `Accept-Encoding` header that allows the deflate encoding + /// will receive the file `dir/foo.txt.zz` instead of `dir/foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the directory. Different precompressed variants can be combined. + pub fn precompressed_deflate(mut self) -> Self { + self.precompressed_variants + .get_or_insert(Default::default()) + .deflate = true; + self + } + + /// Informs the service that it should also look for a precompressed zstd + /// version of _any_ file in the directory. + /// + /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, + /// a client with an `Accept-Encoding` header that allows the zstd encoding + /// will receive the file `dir/foo.txt.zst` instead of `dir/foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the directory. Different precompressed variants can be combined. + pub fn precompressed_zstd(mut self) -> Self { + self.precompressed_variants + .get_or_insert(Default::default()) + .zstd = true; + self + } + + /// Set the fallback service. + /// + /// This service will be called if there is no file at the path of the request. + /// + /// The status code returned by the fallback will not be altered. Use + /// [`ServeDir::not_found_service`] to set a fallback and always respond with `404 Not Found`. + /// + /// # Example + /// + /// This can be used to respond with a different file: + /// + /// ```rust + /// use tower_http::services::{ServeDir, ServeFile}; + /// + /// let service = ServeDir::new("assets") + /// // respond with `not_found.html` for missing files + /// .fallback(ServeFile::new("assets/not_found.html")); + /// ``` + pub fn fallback<F2>(self, new_fallback: F2) -> ServeDir<F2> { + ServeDir { + base: self.base, + buf_chunk_size: self.buf_chunk_size, + precompressed_variants: self.precompressed_variants, + variant: self.variant, + fallback: Some(new_fallback), + call_fallback_on_method_not_allowed: self.call_fallback_on_method_not_allowed, + } + } + + /// Set the fallback service and override the fallback's status code to `404 Not Found`. + /// + /// This service will be called if there is no file at the path of the request. + /// + /// # Example + /// + /// This can be used to respond with a different file: + /// + /// ```rust + /// use tower_http::services::{ServeDir, ServeFile}; + /// + /// let service = ServeDir::new("assets") + /// // respond with `404 Not Found` and the contents of `not_found.html` for missing files + /// .not_found_service(ServeFile::new("assets/not_found.html")); + /// ``` + /// + /// Setups like this are often found in single page applications. + pub fn not_found_service<F2>(self, new_fallback: F2) -> ServeDir<SetStatus<F2>> { + self.fallback(SetStatus::new(new_fallback, StatusCode::NOT_FOUND)) + } + + /// Customize whether or not to call the fallback for requests that aren't `GET` or `HEAD`. + /// + /// Defaults to not calling the fallback and instead returning `405 Method Not Allowed`. + pub fn call_fallback_on_method_not_allowed(mut self, call_fallback: bool) -> Self { + self.call_fallback_on_method_not_allowed = call_fallback; + self + } + + /// Call the service and get a future that contains any `std::io::Error` that might have + /// happened. + /// + /// By default `<ServeDir as Service<_>>::call` will handle IO errors and convert them into + /// responses. It does that by converting [`std::io::ErrorKind::NotFound`] and + /// [`std::io::ErrorKind::PermissionDenied`] to `404 Not Found` and any other error to `500 + /// Internal Server Error`. The error will also be logged with `tracing`. + /// + /// If you want to manually control how the error response is generated you can make a new + /// service that wraps a `ServeDir` and calls `try_call` instead of `call`. + /// + /// # Example + /// + /// ``` + /// use tower_http::services::ServeDir; + /// use std::{io, convert::Infallible}; + /// use http::{Request, Response, StatusCode}; + /// use http_body::Body as _; + /// use http_body_util::{Full, BodyExt, combinators::UnsyncBoxBody}; + /// use bytes::Bytes; + /// use tower::{service_fn, ServiceExt, BoxError}; + /// + /// async fn serve_dir( + /// request: Request<Full<Bytes>> + /// ) -> Result<Response<UnsyncBoxBody<Bytes, BoxError>>, Infallible> { + /// let mut service = ServeDir::new("assets"); + /// + /// // You only need to worry about backpressure, and thus call `ServiceExt::ready`, if + /// // your adding a fallback to `ServeDir` that cares about backpressure. + /// // + /// // Its shown here for demonstration but you can do `service.try_call(request)` + /// // otherwise + /// let ready_service = match ServiceExt::<Request<Full<Bytes>>>::ready(&mut service).await { + /// Ok(ready_service) => ready_service, + /// Err(infallible) => match infallible {}, + /// }; + /// + /// match ready_service.try_call(request).await { + /// Ok(response) => { + /// Ok(response.map(|body| body.map_err(Into::into).boxed_unsync())) + /// } + /// Err(err) => { + /// let body = Full::from("Something went wrong...") + /// .map_err(Into::into) + /// .boxed_unsync(); + /// let response = Response::builder() + /// .status(StatusCode::INTERNAL_SERVER_ERROR) + /// .body(body) + /// .unwrap(); + /// Ok(response) + /// } + /// } + /// } + /// ``` + pub fn try_call<ReqBody, FResBody>( + &mut self, + req: Request<ReqBody>, + ) -> ResponseFuture<ReqBody, F> + where + F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone, + F::Future: Send + 'static, + FResBody: http_body::Body<Data = Bytes> + Send + 'static, + FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + if req.method() != Method::GET && req.method() != Method::HEAD { + if self.call_fallback_on_method_not_allowed { + if let Some(fallback) = &mut self.fallback { + return ResponseFuture { + inner: future::call_fallback(fallback, req), + }; + } + } else { + return ResponseFuture::method_not_allowed(); + } + } + + // `ServeDir` doesn't care about the request body but the fallback might. So move out the + // body and pass it to the fallback, leaving an empty body in its place + // + // this is necessary because we cannot clone bodies + let (mut parts, body) = req.into_parts(); + // same goes for extensions + let extensions = std::mem::take(&mut parts.extensions); + let req = Request::from_parts(parts, Empty::<Bytes>::new()); + + let fallback_and_request = self.fallback.as_mut().map(|fallback| { + let mut fallback_req = Request::new(body); + *fallback_req.method_mut() = req.method().clone(); + *fallback_req.uri_mut() = req.uri().clone(); + *fallback_req.headers_mut() = req.headers().clone(); + *fallback_req.extensions_mut() = extensions; + + // get the ready fallback and leave a non-ready clone in its place + let clone = fallback.clone(); + let fallback = std::mem::replace(fallback, clone); + + (fallback, fallback_req) + }); + + let path_to_file = match self + .variant + .build_and_validate_path(&self.base, req.uri().path()) + { + Some(path_to_file) => path_to_file, + None => { + return ResponseFuture::invalid_path(fallback_and_request); + } + }; + + let buf_chunk_size = self.buf_chunk_size; + let range_header = req + .headers() + .get(header::RANGE) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_owned()); + + let negotiated_encodings: Vec<_> = encodings( + req.headers(), + self.precompressed_variants.unwrap_or_default(), + ) + .collect(); + + let variant = self.variant.clone(); + + let open_file_future = Box::pin(open_file::open_file( + variant, + path_to_file, + req, + negotiated_encodings, + range_header, + buf_chunk_size, + )); + + ResponseFuture::open_file_future(open_file_future, fallback_and_request) + } +} + +impl<ReqBody, F, FResBody> Service<Request<ReqBody>> for ServeDir<F> +where + F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone, + F::Future: Send + 'static, + FResBody: http_body::Body<Data = Bytes> + Send + 'static, + FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>, +{ + type Response = Response<ResponseBody>; + type Error = Infallible; + type Future = InfallibleResponseFuture<ReqBody, F>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if let Some(fallback) = &mut self.fallback { + fallback.poll_ready(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let future = self + .try_call(req) + .map(|result: Result<_, _>| -> Result<_, Infallible> { + let response = result.unwrap_or_else(|err| { + tracing::error!(error = %err, "Failed to read file"); + + let body = ResponseBody::new(UnsyncBoxBody::new( + Empty::new().map_err(|err| match err {}).boxed_unsync(), + )); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(body) + .unwrap() + }); + Ok(response) + } as _); + + InfallibleResponseFuture::new(future) + } +} + +opaque_future! { + /// Response future of [`ServeDir`]. + pub type InfallibleResponseFuture<ReqBody, F> = + futures_util::future::Map< + ResponseFuture<ReqBody, F>, + fn(Result<Response<ResponseBody>, io::Error>) -> Result<Response<ResponseBody>, Infallible>, + >; +} + +// Allow the ServeDir service to be used in the ServeFile service +// with almost no overhead +#[derive(Clone, Debug)] +enum ServeVariant { + Directory { + append_index_html_on_directories: bool, + }, + SingleFile { + mime: HeaderValue, + }, +} + +impl ServeVariant { + fn build_and_validate_path(&self, base_path: &Path, requested_path: &str) -> Option<PathBuf> { + match self { + ServeVariant::Directory { + append_index_html_on_directories: _, + } => { + let path = requested_path.trim_start_matches('/'); + + let path_decoded = percent_decode(path.as_ref()).decode_utf8().ok()?; + let path_decoded = Path::new(&*path_decoded); + + let mut path_to_file = base_path.to_path_buf(); + for component in path_decoded.components() { + match component { + Component::Normal(comp) => { + // protect against paths like `/foo/c:/bar/baz` (#204) + if Path::new(&comp) + .components() + .all(|c| matches!(c, Component::Normal(_))) + { + path_to_file.push(comp) + } else { + return None; + } + } + Component::CurDir => {} + Component::Prefix(_) | Component::RootDir | Component::ParentDir => { + return None; + } + } + } + Some(path_to_file) + } + ServeVariant::SingleFile { mime: _ } => Some(base_path.to_path_buf()), + } + } +} + +opaque_body! { + /// Response body for [`ServeDir`] and [`ServeFile`][super::ServeFile]. + #[derive(Default)] + pub type ResponseBody = UnsyncBoxBody<Bytes, io::Error>; +} + +/// The default fallback service used with [`ServeDir`]. +#[derive(Debug, Clone, Copy)] +pub struct DefaultServeDirFallback(Infallible); + +impl<ReqBody> Service<Request<ReqBody>> for DefaultServeDirFallback +where + ReqBody: Send + 'static, +{ + type Response = Response<ResponseBody>; + type Error = Infallible; + type Future = InfallibleResponseFuture<ReqBody, Self>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + match self.0 {} + } + + fn call(&mut self, _req: Request<ReqBody>) -> Self::Future { + match self.0 {} + } +} + +#[derive(Clone, Copy, Debug, Default)] +struct PrecompressedVariants { + gzip: bool, + deflate: bool, + br: bool, + zstd: bool, +} + +impl SupportedEncodings for PrecompressedVariants { + fn gzip(&self) -> bool { + self.gzip + } + + fn deflate(&self) -> bool { + self.deflate + } + + fn br(&self) -> bool { + self.br + } + + fn zstd(&self) -> bool { + self.zstd + } +} diff --git a/vendor/tower-http/src/services/fs/serve_dir/open_file.rs b/vendor/tower-http/src/services/fs/serve_dir/open_file.rs new file mode 100644 index 00000000..9ddedd8a --- /dev/null +++ b/vendor/tower-http/src/services/fs/serve_dir/open_file.rs @@ -0,0 +1,340 @@ +use super::{ + headers::{IfModifiedSince, IfUnmodifiedSince, LastModified}, + ServeVariant, +}; +use crate::content_encoding::{Encoding, QValue}; +use bytes::Bytes; +use http::{header, HeaderValue, Method, Request, Uri}; +use http_body_util::Empty; +use http_range_header::RangeUnsatisfiableError; +use std::{ + ffi::OsStr, + fs::Metadata, + io::{self, SeekFrom}, + ops::RangeInclusive, + path::{Path, PathBuf}, +}; +use tokio::{fs::File, io::AsyncSeekExt}; + +pub(super) enum OpenFileOutput { + FileOpened(Box<FileOpened>), + Redirect { location: HeaderValue }, + FileNotFound, + PreconditionFailed, + NotModified, + InvalidRedirectUri, +} + +pub(super) struct FileOpened { + pub(super) extent: FileRequestExtent, + pub(super) chunk_size: usize, + pub(super) mime_header_value: HeaderValue, + pub(super) maybe_encoding: Option<Encoding>, + pub(super) maybe_range: Option<Result<Vec<RangeInclusive<u64>>, RangeUnsatisfiableError>>, + pub(super) last_modified: Option<LastModified>, +} + +pub(super) enum FileRequestExtent { + Full(File, Metadata), + Head(Metadata), +} + +pub(super) async fn open_file( + variant: ServeVariant, + mut path_to_file: PathBuf, + req: Request<Empty<Bytes>>, + negotiated_encodings: Vec<(Encoding, QValue)>, + range_header: Option<String>, + buf_chunk_size: usize, +) -> io::Result<OpenFileOutput> { + let if_unmodified_since = req + .headers() + .get(header::IF_UNMODIFIED_SINCE) + .and_then(IfUnmodifiedSince::from_header_value); + + let if_modified_since = req + .headers() + .get(header::IF_MODIFIED_SINCE) + .and_then(IfModifiedSince::from_header_value); + + let mime = match variant { + ServeVariant::Directory { + append_index_html_on_directories, + } => { + // Might already at this point know a redirect or not found result should be + // returned which corresponds to a Some(output). Otherwise the path might be + // modified and proceed to the open file/metadata future. + if let Some(output) = maybe_redirect_or_append_path( + &mut path_to_file, + req.uri(), + append_index_html_on_directories, + ) + .await + { + return Ok(output); + } + + mime_guess::from_path(&path_to_file) + .first_raw() + .map(HeaderValue::from_static) + .unwrap_or_else(|| { + HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap() + }) + } + + ServeVariant::SingleFile { mime } => mime, + }; + + if req.method() == Method::HEAD { + let (meta, maybe_encoding) = + file_metadata_with_fallback(path_to_file, negotiated_encodings).await?; + + let last_modified = meta.modified().ok().map(LastModified::from); + if let Some(output) = check_modified_headers( + last_modified.as_ref(), + if_unmodified_since, + if_modified_since, + ) { + return Ok(output); + } + + let maybe_range = try_parse_range(range_header.as_deref(), meta.len()); + + Ok(OpenFileOutput::FileOpened(Box::new(FileOpened { + extent: FileRequestExtent::Head(meta), + chunk_size: buf_chunk_size, + mime_header_value: mime, + maybe_encoding, + maybe_range, + last_modified, + }))) + } else { + let (mut file, maybe_encoding) = + open_file_with_fallback(path_to_file, negotiated_encodings).await?; + let meta = file.metadata().await?; + let last_modified = meta.modified().ok().map(LastModified::from); + if let Some(output) = check_modified_headers( + last_modified.as_ref(), + if_unmodified_since, + if_modified_since, + ) { + return Ok(output); + } + + let maybe_range = try_parse_range(range_header.as_deref(), meta.len()); + if let Some(Ok(ranges)) = maybe_range.as_ref() { + // if there is any other amount of ranges than 1 we'll return an + // unsatisfiable later as there isn't yet support for multipart ranges + if ranges.len() == 1 { + file.seek(SeekFrom::Start(*ranges[0].start())).await?; + } + } + + Ok(OpenFileOutput::FileOpened(Box::new(FileOpened { + extent: FileRequestExtent::Full(file, meta), + chunk_size: buf_chunk_size, + mime_header_value: mime, + maybe_encoding, + maybe_range, + last_modified, + }))) + } +} + +fn check_modified_headers( + modified: Option<&LastModified>, + if_unmodified_since: Option<IfUnmodifiedSince>, + if_modified_since: Option<IfModifiedSince>, +) -> Option<OpenFileOutput> { + if let Some(since) = if_unmodified_since { + let precondition = modified + .as_ref() + .map(|time| since.precondition_passes(time)) + .unwrap_or(false); + + if !precondition { + return Some(OpenFileOutput::PreconditionFailed); + } + } + + if let Some(since) = if_modified_since { + let unmodified = modified + .as_ref() + .map(|time| !since.is_modified(time)) + // no last_modified means its always modified + .unwrap_or(false); + if unmodified { + return Some(OpenFileOutput::NotModified); + } + } + + None +} + +// Returns the preferred_encoding encoding and modifies the path extension +// to the corresponding file extension for the encoding. +fn preferred_encoding( + path: &mut PathBuf, + negotiated_encoding: &[(Encoding, QValue)], +) -> Option<Encoding> { + let preferred_encoding = Encoding::preferred_encoding(negotiated_encoding.iter().copied()); + + if let Some(file_extension) = + preferred_encoding.and_then(|encoding| encoding.to_file_extension()) + { + let new_file_name = path + .file_name() + .map(|file_name| { + let mut os_string = file_name.to_os_string(); + os_string.push(file_extension); + os_string + }) + .unwrap_or_else(|| file_extension.to_os_string()); + + path.set_file_name(new_file_name); + } + + preferred_encoding +} + +// Attempts to open the file with any of the possible negotiated_encodings in the +// preferred order. If none of the negotiated_encodings have a corresponding precompressed +// file the uncompressed file is used as a fallback. +async fn open_file_with_fallback( + mut path: PathBuf, + mut negotiated_encoding: Vec<(Encoding, QValue)>, +) -> io::Result<(File, Option<Encoding>)> { + let (file, encoding) = loop { + // Get the preferred encoding among the negotiated ones. + let encoding = preferred_encoding(&mut path, &negotiated_encoding); + match (File::open(&path).await, encoding) { + (Ok(file), maybe_encoding) => break (file, maybe_encoding), + (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => { + // Remove the extension corresponding to a precompressed file (.gz, .br, .zz) + // to reset the path before the next iteration. + path.set_extension(OsStr::new("")); + // Remove the encoding from the negotiated_encodings since the file doesn't exist + negotiated_encoding + .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); + } + (Err(err), _) => return Err(err), + } + }; + Ok((file, encoding)) +} + +// Attempts to get the file metadata with any of the possible negotiated_encodings in the +// preferred order. If none of the negotiated_encodings have a corresponding precompressed +// file the uncompressed file is used as a fallback. +async fn file_metadata_with_fallback( + mut path: PathBuf, + mut negotiated_encoding: Vec<(Encoding, QValue)>, +) -> io::Result<(Metadata, Option<Encoding>)> { + let (file, encoding) = loop { + // Get the preferred encoding among the negotiated ones. + let encoding = preferred_encoding(&mut path, &negotiated_encoding); + match (tokio::fs::metadata(&path).await, encoding) { + (Ok(file), maybe_encoding) => break (file, maybe_encoding), + (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => { + // Remove the extension corresponding to a precompressed file (.gz, .br, .zz) + // to reset the path before the next iteration. + path.set_extension(OsStr::new("")); + // Remove the encoding from the negotiated_encodings since the file doesn't exist + negotiated_encoding + .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); + } + (Err(err), _) => return Err(err), + } + }; + Ok((file, encoding)) +} + +async fn maybe_redirect_or_append_path( + path_to_file: &mut PathBuf, + uri: &Uri, + append_index_html_on_directories: bool, +) -> Option<OpenFileOutput> { + if !is_dir(path_to_file).await { + return None; + } + + if !append_index_html_on_directories { + return Some(OpenFileOutput::FileNotFound); + } + + if uri.path().ends_with('/') { + path_to_file.push("index.html"); + None + } else { + let uri = match append_slash_on_path(uri.clone()) { + Ok(uri) => uri, + Err(err) => return Some(err), + }; + let location = HeaderValue::from_str(&uri.to_string()).unwrap(); + Some(OpenFileOutput::Redirect { location }) + } +} + +fn try_parse_range( + maybe_range_ref: Option<&str>, + file_size: u64, +) -> Option<Result<Vec<RangeInclusive<u64>>, RangeUnsatisfiableError>> { + maybe_range_ref.map(|header_value| { + http_range_header::parse_range_header(header_value) + .and_then(|first_pass| first_pass.validate(file_size)) + }) +} + +async fn is_dir(path_to_file: &Path) -> bool { + tokio::fs::metadata(path_to_file) + .await + .map_or(false, |meta_data| meta_data.is_dir()) +} + +fn append_slash_on_path(uri: Uri) -> Result<Uri, OpenFileOutput> { + let http::uri::Parts { + scheme, + authority, + path_and_query, + .. + } = uri.into_parts(); + + let mut uri_builder = Uri::builder(); + + if let Some(scheme) = scheme { + uri_builder = uri_builder.scheme(scheme); + } + + if let Some(authority) = authority { + uri_builder = uri_builder.authority(authority); + } + + let uri_builder = if let Some(path_and_query) = path_and_query { + if let Some(query) = path_and_query.query() { + uri_builder.path_and_query(format!("{}/?{}", path_and_query.path(), query)) + } else { + uri_builder.path_and_query(format!("{}/", path_and_query.path())) + } + } else { + uri_builder.path_and_query("/") + }; + + uri_builder.build().map_err(|err| { + tracing::error!(?err, "redirect uri failed to build"); + OpenFileOutput::InvalidRedirectUri + }) +} + +#[test] +fn preferred_encoding_with_extension() { + let mut path = PathBuf::from("hello.txt"); + preferred_encoding(&mut path, &[(Encoding::Gzip, QValue::one())]); + assert_eq!(path, PathBuf::from("hello.txt.gz")); +} + +#[test] +fn preferred_encoding_without_extension() { + let mut path = PathBuf::from("hello"); + preferred_encoding(&mut path, &[(Encoding::Gzip, QValue::one())]); + assert_eq!(path, PathBuf::from("hello.gz")); +} diff --git a/vendor/tower-http/src/services/fs/serve_dir/tests.rs b/vendor/tower-http/src/services/fs/serve_dir/tests.rs new file mode 100644 index 00000000..ea1c543e --- /dev/null +++ b/vendor/tower-http/src/services/fs/serve_dir/tests.rs @@ -0,0 +1,836 @@ +use crate::services::{ServeDir, ServeFile}; +use crate::test_helpers::{to_bytes, Body}; +use brotli::BrotliDecompress; +use bytes::Bytes; +use flate2::bufread::{DeflateDecoder, GzDecoder}; +use http::header::ALLOW; +use http::{header, Method, Response}; +use http::{Request, StatusCode}; +use http_body::Body as HttpBody; +use http_body_util::BodyExt; +use std::convert::Infallible; +use std::fs; +use std::io::Read; +use tower::{service_fn, ServiceExt}; + +#[tokio::test] +async fn basic() { + let svc = ServeDir::new(".."); + + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = body_into_text(res.into_body()).await; + + let contents = std::fs::read_to_string("../README.md").unwrap(); + assert_eq!(body, contents); +} + +#[tokio::test] +async fn basic_with_index() { + let svc = ServeDir::new("../test-files"); + + let req = Request::new(Body::empty()); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()[header::CONTENT_TYPE], "text/html"); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "<b>HTML!</b>\n"); +} + +#[tokio::test] +async fn head_request() { + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri("/precompressed.txt") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-length"], "23"); + + assert!(res.into_body().frame().await.is_none()); +} + +#[tokio::test] +async fn precompresed_head_request() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let req = Request::builder() + .uri("/precompressed.txt") + .header("Accept-Encoding", "gzip") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + assert_eq!(res.headers()["content-length"], "59"); + + assert!(res.into_body().frame().await.is_none()); +} + +#[tokio::test] +async fn with_custom_chunk_size() { + let svc = ServeDir::new("..").with_buf_chunk_size(1024 * 32); + + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = body_into_text(res.into_body()).await; + + let contents = std::fs::read_to_string("../README.md").unwrap(); + assert_eq!(body, contents); +} + +#[tokio::test] +async fn precompressed_gzip() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let req = Request::builder() + .uri("/precompressed.txt") + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = GzDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); +} + +#[tokio::test] +async fn precompressed_br() { + let svc = ServeDir::new("../test-files").precompressed_br(); + + let req = Request::builder() + .uri("/precompressed.txt") + .header("Accept-Encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "br"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decompressed = Vec::new(); + BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); + let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); +} + +#[tokio::test] +async fn precompressed_deflate() { + let svc = ServeDir::new("../test-files").precompressed_deflate(); + let request = Request::builder() + .uri("/precompressed.txt") + .header("Accept-Encoding", "deflate,br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "deflate"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = DeflateDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); +} + +#[tokio::test] +async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let request = Request::builder() + .uri("/precompressed.txt") + .header("Accept-Encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert!(res.headers().get("content-encoding").is_none()); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.starts_with("\"This is a test file!\"")); +} + +#[tokio::test] +async fn only_precompressed_variant_existing() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let request = Request::builder() + .uri("/only_gzipped.txt") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // Should reply with gzipped file if client supports it + let request = Request::builder() + .uri("/only_gzipped.txt") + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = GzDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file\"")); +} + +#[tokio::test] +async fn missing_precompressed_variant_fallbacks_to_uncompressed() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let request = Request::builder() + .uri("/missing_precompressed.txt") + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + // Uncompressed file is served because compressed version is missing + assert!(res.headers().get("content-encoding").is_none()); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.starts_with("Test file!")); +} + +#[tokio::test] +async fn missing_precompressed_variant_fallbacks_to_uncompressed_for_head_request() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let request = Request::builder() + .uri("/missing_precompressed.txt") + .header("Accept-Encoding", "gzip") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-length"], "11"); + // Uncompressed file is served because compressed version is missing + assert!(res.headers().get("content-encoding").is_none()); + + assert!(res.into_body().frame().await.is_none()); +} + +#[tokio::test] +async fn precompressed_without_extension() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let request = Request::builder() + .uri("/extensionless_precompressed") + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + assert_eq!(res.headers()["content-type"], "application/octet-stream"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = GzDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + let correct = fs::read_to_string("../test-files/extensionless_precompressed").unwrap(); + assert_eq!(decompressed, correct); +} + +#[tokio::test] +async fn missing_precompressed_without_extension_fallbacks_to_uncompressed() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let request = Request::builder() + .uri("/extensionless_precompressed_missing") + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + assert_eq!(res.headers()["content-type"], "application/octet-stream"); + assert!(res.headers().get("content-encoding").is_none()); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + let correct = fs::read_to_string("../test-files/extensionless_precompressed_missing").unwrap(); + assert_eq!(body, correct); +} + +#[tokio::test] +async fn access_to_sub_dirs() { + let svc = ServeDir::new(".."); + + let req = Request::builder() + .uri("/tower-http/Cargo.toml") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/x-toml"); + + let body = body_into_text(res.into_body()).await; + + let contents = std::fs::read_to_string("Cargo.toml").unwrap(); + assert_eq!(body, contents); +} + +#[tokio::test] +async fn not_found() { + let svc = ServeDir::new(".."); + + let req = Request::builder() + .uri("/not-found") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + + let body = body_into_text(res.into_body()).await; + assert!(body.is_empty()); +} + +#[cfg(unix)] +#[tokio::test] +async fn not_found_when_not_a_directory() { + let svc = ServeDir::new("../test-files"); + + // `index.html` is a file, and we are trying to request + // it as a directory. + let req = Request::builder() + .uri("/index.html/some_file") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + // This should lead to a 404 + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + + let body = body_into_text(res.into_body()).await; + assert!(body.is_empty()); +} + +#[tokio::test] +async fn not_found_precompressed() { + let svc = ServeDir::new("../test-files").precompressed_gzip(); + + let req = Request::builder() + .uri("/not-found") + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + + let body = body_into_text(res.into_body()).await; + assert!(body.is_empty()); +} + +#[tokio::test] +async fn fallbacks_to_different_precompressed_variant_if_not_found_for_head_request() { + let svc = ServeDir::new("../test-files") + .precompressed_gzip() + .precompressed_br(); + + let req = Request::builder() + .uri("/precompressed_br.txt") + .header("Accept-Encoding", "gzip,br,deflate") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "br"); + assert_eq!(res.headers()["content-length"], "15"); + + assert!(res.into_body().frame().await.is_none()); +} + +#[tokio::test] +async fn fallbacks_to_different_precompressed_variant_if_not_found() { + let svc = ServeDir::new("../test-files") + .precompressed_gzip() + .precompressed_br(); + + let req = Request::builder() + .uri("/precompressed_br.txt") + .header("Accept-Encoding", "gzip,br,deflate") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "br"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decompressed = Vec::new(); + BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); + let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); + assert!(decompressed.starts_with("Test file")); +} + +#[tokio::test] +async fn redirect_to_trailing_slash_on_dir() { + let svc = ServeDir::new("."); + + let req = Request::builder().uri("/src").body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); + + let location = &res.headers()[http::header::LOCATION]; + assert_eq!(location, "/src/"); +} + +#[tokio::test] +async fn empty_directory_without_index() { + let svc = ServeDir::new(".").append_index_html_on_directories(false); + + let req = Request::new(Body::empty()); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + + let body = body_into_text(res.into_body()).await; + assert!(body.is_empty()); +} + +#[tokio::test] +async fn empty_directory_without_index_no_information_leak() { + let svc = ServeDir::new("..").append_index_html_on_directories(false); + + let req = Request::builder() + .uri("/test-files") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + + let body = body_into_text(res.into_body()).await; + assert!(body.is_empty()); +} + +async fn body_into_text<B>(body: B) -> String +where + B: HttpBody<Data = bytes::Bytes> + Unpin, + B::Error: std::fmt::Debug, +{ + let bytes = to_bytes(body).await.unwrap(); + String::from_utf8(bytes.to_vec()).unwrap() +} + +#[tokio::test] +async fn access_cjk_percent_encoded_uri_path() { + // percent encoding present of ä½ å¥½ä¸–ç•Œ.txt + let cjk_filename_encoded = "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C.txt"; + + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri(format!("/{}", cjk_filename_encoded)) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/plain"); +} + +#[tokio::test] +async fn access_space_percent_encoded_uri_path() { + let encoded_filename = "filename%20with%20space.txt"; + + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri(format!("/{}", encoded_filename)) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/plain"); +} + +#[tokio::test] +async fn read_partial_empty() { + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri("/empty.txt") + .header("Range", "bytes=0-") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); + assert_eq!(res.headers()["content-length"], "0"); + assert_eq!(res.headers()["content-range"], "bytes 0-0/0"); + + let body = to_bytes(res.into_body()).await.ok().unwrap(); + assert!(body.is_empty()); +} + +#[tokio::test] +async fn read_partial_in_bounds() { + let svc = ServeDir::new(".."); + let bytes_start_incl = 9; + let bytes_end_incl = 1023; + + let req = Request::builder() + .uri("/README.md") + .header( + "Range", + format!("bytes={}-{}", bytes_start_incl, bytes_end_incl), + ) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + let file_contents = std::fs::read("../README.md").unwrap(); + assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); + assert_eq!( + res.headers()["content-length"], + (bytes_end_incl - bytes_start_incl + 1).to_string() + ); + assert!(res.headers()["content-range"] + .to_str() + .unwrap() + .starts_with(&format!( + "bytes {}-{}/{}", + bytes_start_incl, + bytes_end_incl, + file_contents.len() + ))); + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = to_bytes(res.into_body()).await.ok().unwrap(); + let source = Bytes::from(file_contents[bytes_start_incl..=bytes_end_incl].to_vec()); + assert_eq!(body, source); +} + +#[tokio::test] +async fn read_partial_accepts_out_of_bounds_range() { + let svc = ServeDir::new(".."); + let bytes_start_incl = 0; + let bytes_end_excl = 9999999; + let requested_len = bytes_end_excl - bytes_start_incl; + + let req = Request::builder() + .uri("/README.md") + .header( + "Range", + format!("bytes={}-{}", bytes_start_incl, requested_len - 1), + ) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); + let file_contents = std::fs::read("../README.md").unwrap(); + // Out of bounds range gives all bytes + assert_eq!( + res.headers()["content-range"], + &format!( + "bytes 0-{}/{}", + file_contents.len() - 1, + file_contents.len() + ) + ) +} + +#[tokio::test] +async fn read_partial_errs_on_garbage_header() { + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header("Range", "bad_format") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); + let file_contents = std::fs::read("../README.md").unwrap(); + assert_eq!( + res.headers()["content-range"], + &format!("bytes */{}", file_contents.len()) + ) +} + +#[tokio::test] +async fn read_partial_errs_on_bad_range() { + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header("Range", "bytes=-1-15") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); + let file_contents = std::fs::read("../README.md").unwrap(); + assert_eq!( + res.headers()["content-range"], + &format!("bytes */{}", file_contents.len()) + ) +} + +#[tokio::test] +async fn accept_encoding_identity() { + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header("Accept-Encoding", "identity") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + // Identity encoding should not be included in the response headers + assert!(res.headers().get("content-encoding").is_none()); +} + +#[tokio::test] +async fn last_modified() { + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let last_modified = res + .headers() + .get(header::LAST_MODIFIED) + .expect("Missing last modified header!"); + + // -- If-Modified-Since + + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MODIFIED_SINCE, last_modified) + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); + assert!(res.into_body().frame().await.is_none()); + + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let readme_bytes = include_bytes!("../../../../../README.md"); + let body = res.into_body().collect().await.unwrap().to_bytes(); + assert_eq!(body.as_ref(), readme_bytes); + + // -- If-Unmodified-Since + + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_UNMODIFIED_SINCE, last_modified) + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.into_body().collect().await.unwrap().to_bytes(); + assert_eq!(body.as_ref(), readme_bytes); + + let svc = ServeDir::new(".."); + let req = Request::builder() + .uri("/README.md") + .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); + assert!(res.into_body().frame().await.is_none()); +} + +#[tokio::test] +async fn with_fallback_svc() { + async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> { + Ok(Response::new(Body::from(format!( + "from fallback {}", + req.uri().path() + )))) + } + + let svc = ServeDir::new("..").fallback(tower::service_fn(fallback)); + + let req = Request::builder() + .uri("/doesnt-exist") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "from fallback /doesnt-exist"); +} + +#[tokio::test] +async fn with_fallback_serve_file() { + let svc = ServeDir::new("..").fallback(ServeFile::new("../README.md")); + + let req = Request::builder() + .uri("/doesnt-exist") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = body_into_text(res.into_body()).await; + + let contents = std::fs::read_to_string("../README.md").unwrap(); + assert_eq!(body, contents); +} + +#[tokio::test] +async fn method_not_allowed() { + let svc = ServeDir::new(".."); + + let req = Request::builder() + .method(Method::POST) + .uri("/README.md") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(res.headers()[ALLOW], "GET,HEAD"); +} + +#[tokio::test] +async fn calling_fallback_on_not_allowed() { + async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> { + Ok(Response::new(Body::from(format!( + "from fallback {}", + req.uri().path() + )))) + } + + let svc = ServeDir::new("..") + .call_fallback_on_method_not_allowed(true) + .fallback(tower::service_fn(fallback)); + + let req = Request::builder() + .method(Method::POST) + .uri("/doesnt-exist") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "from fallback /doesnt-exist"); +} + +#[tokio::test] +async fn with_fallback_svc_and_not_append_index_html_on_directories() { + async fn fallback<B>(req: Request<B>) -> Result<Response<Body>, Infallible> { + Ok(Response::new(Body::from(format!( + "from fallback {}", + req.uri().path() + )))) + } + + let svc = ServeDir::new("..") + .append_index_html_on_directories(false) + .fallback(tower::service_fn(fallback)); + + let req = Request::builder().uri("/").body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "from fallback /"); +} + +// https://github.com/tower-rs/tower-http/issues/308 +#[tokio::test] +async fn calls_fallback_on_invalid_paths() { + async fn fallback<T>(_: T) -> Result<Response<Body>, Infallible> { + let mut res = Response::new(Body::empty()); + res.headers_mut() + .insert("from-fallback", "1".parse().unwrap()); + Ok(res) + } + + let svc = ServeDir::new("..").fallback(service_fn(fallback)); + + let req = Request::builder() + .uri("/weird_%c3%28_path") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["from-fallback"], "1"); +} diff --git a/vendor/tower-http/src/services/fs/serve_file.rs b/vendor/tower-http/src/services/fs/serve_file.rs new file mode 100644 index 00000000..ade3cd15 --- /dev/null +++ b/vendor/tower-http/src/services/fs/serve_file.rs @@ -0,0 +1,560 @@ +//! Service that serves a file. + +use super::ServeDir; +use http::{HeaderValue, Request}; +use mime::Mime; +use std::{ + path::Path, + task::{Context, Poll}, +}; +use tower_service::Service; + +/// Service that serves a file. +#[derive(Clone, Debug)] +pub struct ServeFile(ServeDir); + +// Note that this is just a special case of ServeDir +impl ServeFile { + /// Create a new [`ServeFile`]. + /// + /// The `Content-Type` will be guessed from the file extension. + pub fn new<P: AsRef<Path>>(path: P) -> Self { + let guess = mime_guess::from_path(path.as_ref()); + let mime = guess + .first_raw() + .map(HeaderValue::from_static) + .unwrap_or_else(|| { + HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap() + }); + + Self(ServeDir::new_single_file(path, mime)) + } + + /// Create a new [`ServeFile`] with a specific mime type. + /// + /// # Panics + /// + /// Will panic if the mime type isn't a valid [header value]. + /// + /// [header value]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html + pub fn new_with_mime<P: AsRef<Path>>(path: P, mime: &Mime) -> Self { + let mime = HeaderValue::from_str(mime.as_ref()).expect("mime isn't a valid header value"); + Self(ServeDir::new_single_file(path, mime)) + } + + /// Informs the service that it should also look for a precompressed gzip + /// version of the file. + /// + /// If the client has an `Accept-Encoding` header that allows the gzip encoding, + /// the file `foo.txt.gz` will be served instead of `foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the same directory. Different precompressed + /// variants can be combined. + pub fn precompressed_gzip(self) -> Self { + Self(self.0.precompressed_gzip()) + } + + /// Informs the service that it should also look for a precompressed brotli + /// version of the file. + /// + /// If the client has an `Accept-Encoding` header that allows the brotli encoding, + /// the file `foo.txt.br` will be served instead of `foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the same directory. Different precompressed + /// variants can be combined. + pub fn precompressed_br(self) -> Self { + Self(self.0.precompressed_br()) + } + + /// Informs the service that it should also look for a precompressed deflate + /// version of the file. + /// + /// If the client has an `Accept-Encoding` header that allows the deflate encoding, + /// the file `foo.txt.zz` will be served instead of `foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the same directory. Different precompressed + /// variants can be combined. + pub fn precompressed_deflate(self) -> Self { + Self(self.0.precompressed_deflate()) + } + + /// Informs the service that it should also look for a precompressed zstd + /// version of the file. + /// + /// If the client has an `Accept-Encoding` header that allows the zstd encoding, + /// the file `foo.txt.zst` will be served instead of `foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the same directory. Different precompressed + /// variants can be combined. + pub fn precompressed_zstd(self) -> Self { + Self(self.0.precompressed_zstd()) + } + + /// Set a specific read buffer chunk size. + /// + /// The default capacity is 64kb. + pub fn with_buf_chunk_size(self, chunk_size: usize) -> Self { + Self(self.0.with_buf_chunk_size(chunk_size)) + } + + /// Call the service and get a future that contains any `std::io::Error` that might have + /// happened. + /// + /// See [`ServeDir::try_call`] for more details. + pub fn try_call<ReqBody>( + &mut self, + req: Request<ReqBody>, + ) -> super::serve_dir::future::ResponseFuture<ReqBody> + where + ReqBody: Send + 'static, + { + self.0.try_call(req) + } +} + +impl<ReqBody> Service<Request<ReqBody>> for ServeFile +where + ReqBody: Send + 'static, +{ + type Error = <ServeDir as Service<Request<ReqBody>>>::Error; + type Response = <ServeDir as Service<Request<ReqBody>>>::Response; + type Future = <ServeDir as Service<Request<ReqBody>>>::Future; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + self.0.call(req) + } +} + +#[cfg(test)] +mod tests { + use crate::services::ServeFile; + use crate::test_helpers::Body; + use async_compression::tokio::bufread::ZstdDecoder; + use brotli::BrotliDecompress; + use flate2::bufread::DeflateDecoder; + use flate2::bufread::GzDecoder; + use http::header; + use http::Method; + use http::{Request, StatusCode}; + use http_body_util::BodyExt; + use mime::Mime; + use std::io::Read; + use std::str::FromStr; + use tokio::io::AsyncReadExt; + use tower::ServiceExt; + + #[tokio::test] + async fn basic() { + let svc = ServeFile::new("../README.md"); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + assert!(body.starts_with("# Tower HTTP")); + } + + #[tokio::test] + async fn basic_with_mime() { + let svc = ServeFile::new_with_mime("../README.md", &Mime::from_str("image/jpg").unwrap()); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "image/jpg"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + assert!(body.starts_with("# Tower HTTP")); + } + + #[tokio::test] + async fn head_request() { + let svc = ServeFile::new("../test-files/precompressed.txt"); + + let mut request = Request::new(Body::empty()); + *request.method_mut() = Method::HEAD; + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-length"], "23"); + + assert!(res.into_body().frame().await.is_none()); + } + + #[tokio::test] + async fn precompresed_head_request() { + let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + assert_eq!(res.headers()["content-length"], "59"); + + assert!(res.into_body().frame().await.is_none()); + } + + #[tokio::test] + async fn precompressed_gzip() { + let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = GzDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); + } + + #[tokio::test] + async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { + let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); + + let request = Request::builder() + .header("Accept-Encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert!(res.headers().get("content-encoding").is_none()); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.starts_with("\"This is a test file!\"")); + } + + #[tokio::test] + async fn missing_precompressed_variant_fallbacks_to_uncompressed() { + let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + // Uncompressed file is served because compressed version is missing + assert!(res.headers().get("content-encoding").is_none()); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.starts_with("Test file!")); + } + + #[tokio::test] + async fn missing_precompressed_variant_fallbacks_to_uncompressed_head_request() { + let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-length"], "11"); + // Uncompressed file is served because compressed version is missing + assert!(res.headers().get("content-encoding").is_none()); + + assert!(res.into_body().frame().await.is_none()); + } + + #[tokio::test] + async fn only_precompressed_variant_existing() { + let svc = ServeFile::new("../test-files/only_gzipped.txt").precompressed_gzip(); + + let request = Request::builder().body(Body::empty()).unwrap(); + let res = svc.clone().oneshot(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // Should reply with gzipped file if client supports it + let request = Request::builder() + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = GzDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file\"")); + } + + #[tokio::test] + async fn precompressed_br() { + let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_br(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip,br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "br"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decompressed = Vec::new(); + BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); + let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); + } + + #[tokio::test] + async fn precompressed_deflate() { + let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_deflate(); + let request = Request::builder() + .header("Accept-Encoding", "deflate,br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "deflate"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = DeflateDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); + } + + #[tokio::test] + async fn precompressed_zstd() { + let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_zstd(); + let request = Request::builder() + .header("Accept-Encoding", "zstd,br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "zstd"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = ZstdDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).await.unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); + } + + #[tokio::test] + async fn multi_precompressed() { + let svc = ServeFile::new("../test-files/precompressed.txt") + .precompressed_gzip() + .precompressed_br(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "gzip"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decoder = GzDecoder::new(&body[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); + + let request = Request::builder() + .header("Accept-Encoding", "br") + .body(Body::empty()) + .unwrap(); + let res = svc.clone().oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "br"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decompressed = Vec::new(); + BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); + let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); + assert!(decompressed.starts_with("\"This is a test file!\"")); + } + + #[tokio::test] + async fn with_custom_chunk_size() { + let svc = ServeFile::new("../README.md").with_buf_chunk_size(1024 * 32); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + assert!(body.starts_with("# Tower HTTP")); + } + + #[tokio::test] + async fn fallbacks_to_different_precompressed_variant_if_not_found() { + let svc = ServeFile::new("../test-files/precompressed_br.txt") + .precompressed_gzip() + .precompressed_deflate() + .precompressed_br(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip,deflate,br") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-encoding"], "br"); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let mut decompressed = Vec::new(); + BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); + let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); + assert!(decompressed.starts_with("Test file")); + } + + #[tokio::test] + async fn fallbacks_to_different_precompressed_variant_if_not_found_head_request() { + let svc = ServeFile::new("../test-files/precompressed_br.txt") + .precompressed_gzip() + .precompressed_deflate() + .precompressed_br(); + + let request = Request::builder() + .header("Accept-Encoding", "gzip,deflate,br") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.headers()["content-length"], "15"); + assert_eq!(res.headers()["content-encoding"], "br"); + + assert!(res.into_body().frame().await.is_none()); + } + + #[tokio::test] + async fn returns_404_if_file_doesnt_exist() { + let svc = ServeFile::new("../this-doesnt-exist.md"); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + } + + #[tokio::test] + async fn returns_404_if_file_doesnt_exist_when_precompression_is_used() { + let svc = ServeFile::new("../this-doesnt-exist.md").precompressed_deflate(); + + let request = Request::builder() + .header("Accept-Encoding", "deflate") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert!(res.headers().get(header::CONTENT_TYPE).is_none()); + } + + #[tokio::test] + async fn last_modified() { + let svc = ServeFile::new("../README.md"); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let last_modified = res + .headers() + .get(header::LAST_MODIFIED) + .expect("Missing last modified header!"); + + // -- If-Modified-Since + + let svc = ServeFile::new("../README.md"); + let req = Request::builder() + .header(header::IF_MODIFIED_SINCE, last_modified) + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_MODIFIED); + assert!(res.into_body().frame().await.is_none()); + + let svc = ServeFile::new("../README.md"); + let req = Request::builder() + .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let readme_bytes = include_bytes!("../../../../README.md"); + let body = res.into_body().collect().await.unwrap().to_bytes(); + assert_eq!(body.as_ref(), readme_bytes); + + // -- If-Unmodified-Since + + let svc = ServeFile::new("../README.md"); + let req = Request::builder() + .header(header::IF_UNMODIFIED_SINCE, last_modified) + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.into_body().collect().await.unwrap().to_bytes(); + assert_eq!(body.as_ref(), readme_bytes); + + let svc = ServeFile::new("../README.md"); + let req = Request::builder() + .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); + assert!(res.into_body().frame().await.is_none()); + } +} diff --git a/vendor/tower-http/src/services/mod.rs b/vendor/tower-http/src/services/mod.rs new file mode 100644 index 00000000..737d2fa1 --- /dev/null +++ b/vendor/tower-http/src/services/mod.rs @@ -0,0 +1,21 @@ +//! [`Service`]s that return responses without wrapping other [`Service`]s. +//! +//! These kinds of services are also referred to as "leaf services" since they sit at the leaves of +//! a [tree] of services. +//! +//! [`Service`]: https://docs.rs/tower/latest/tower/trait.Service.html +//! [tree]: https://en.wikipedia.org/wiki/Tree_(data_structure) + +#[cfg(feature = "redirect")] +pub mod redirect; + +#[cfg(feature = "redirect")] +#[doc(inline)] +pub use self::redirect::Redirect; + +#[cfg(feature = "fs")] +pub mod fs; + +#[cfg(feature = "fs")] +#[doc(inline)] +pub use self::fs::{ServeDir, ServeFile}; diff --git a/vendor/tower-http/src/services/redirect.rs b/vendor/tower-http/src/services/redirect.rs new file mode 100644 index 00000000..020927c9 --- /dev/null +++ b/vendor/tower-http/src/services/redirect.rs @@ -0,0 +1,159 @@ +//! Service that redirects all requests. +//! +//! # Example +//! +//! Imagine that we run `example.com` and want to redirect all requests using `HTTP` to `HTTPS`. +//! That can be done like so: +//! +//! ```rust +//! use http::{Request, Uri, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{Service, ServiceExt}; +//! use tower_http::services::Redirect; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let uri: Uri = "https://example.com/".parse().unwrap(); +//! let mut service: Redirect<Full<Bytes>> = Redirect::permanent(uri); +//! +//! let request = Request::builder() +//! .uri("http://example.com") +//! .body(Full::<Bytes>::default()) +//! .unwrap(); +//! +//! let response = service.oneshot(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT); +//! assert_eq!(response.headers()["location"], "https://example.com/"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use http::{header, HeaderValue, Response, StatusCode, Uri}; +use std::{ + convert::{Infallible, TryFrom}, + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; +use tower_service::Service; + +/// Service that redirects all requests. +/// +/// See the [module docs](crate::services::redirect) for more details. +pub struct Redirect<ResBody> { + status_code: StatusCode, + location: HeaderValue, + // Covariant over ResBody, no dropping of ResBody + _marker: PhantomData<fn() -> ResBody>, +} + +impl<ResBody> Redirect<ResBody> { + /// Create a new [`Redirect`] that uses a [`307 Temporary Redirect`][mdn] status code. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/307 + pub fn temporary(uri: Uri) -> Self { + Self::with_status_code(StatusCode::TEMPORARY_REDIRECT, uri) + } + + /// Create a new [`Redirect`] that uses a [`308 Permanent Redirect`][mdn] status code. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308 + pub fn permanent(uri: Uri) -> Self { + Self::with_status_code(StatusCode::PERMANENT_REDIRECT, uri) + } + + /// Create a new [`Redirect`] that uses the given status code. + /// + /// # Panics + /// + /// - If `status_code` isn't a [redirection status code][mdn] (3xx). + /// - If `uri` isn't a valid [`HeaderValue`]. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#redirection_messages + pub fn with_status_code(status_code: StatusCode, uri: Uri) -> Self { + assert!( + status_code.is_redirection(), + "not a redirection status code" + ); + + Self { + status_code, + location: HeaderValue::try_from(uri.to_string()) + .expect("URI isn't a valid header value"), + _marker: PhantomData, + } + } +} + +impl<R, ResBody> Service<R> for Redirect<ResBody> +where + ResBody: Default, +{ + type Response = Response<ResBody>; + type Error = Infallible; + type Future = ResponseFuture<ResBody>; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: R) -> Self::Future { + ResponseFuture { + status_code: self.status_code, + location: Some(self.location.clone()), + _marker: PhantomData, + } + } +} + +impl<ResBody> fmt::Debug for Redirect<ResBody> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Redirect") + .field("status_code", &self.status_code) + .field("location", &self.location) + .finish() + } +} + +impl<ResBody> Clone for Redirect<ResBody> { + fn clone(&self) -> Self { + Self { + status_code: self.status_code, + location: self.location.clone(), + _marker: PhantomData, + } + } +} + +/// Response future of [`Redirect`]. +#[derive(Debug)] +pub struct ResponseFuture<ResBody> { + location: Option<HeaderValue>, + status_code: StatusCode, + // Covariant over ResBody, no dropping of ResBody + _marker: PhantomData<fn() -> ResBody>, +} + +impl<ResBody> Future for ResponseFuture<ResBody> +where + ResBody: Default, +{ + type Output = Result<Response<ResBody>, Infallible>; + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut res = Response::default(); + + *res.status_mut() = self.status_code; + + res.headers_mut() + .insert(header::LOCATION, self.location.take().unwrap()); + + Poll::Ready(Ok(res)) + } +} diff --git a/vendor/tower-http/src/set_header/mod.rs b/vendor/tower-http/src/set_header/mod.rs new file mode 100644 index 00000000..396527ef --- /dev/null +++ b/vendor/tower-http/src/set_header/mod.rs @@ -0,0 +1,110 @@ +//! Middleware for setting headers on requests and responses. +//! +//! See [request] and [response] for more details. + +use http::{header::HeaderName, HeaderMap, HeaderValue, Request, Response}; + +pub mod request; +pub mod response; + +#[doc(inline)] +pub use self::{ + request::{SetRequestHeader, SetRequestHeaderLayer}, + response::{SetResponseHeader, SetResponseHeaderLayer}, +}; + +/// Trait for producing header values. +/// +/// Used by [`SetRequestHeader`] and [`SetResponseHeader`]. +/// +/// This trait is implemented for closures with the correct type signature. Typically users will +/// not have to implement this trait for their own types. +/// +/// It is also implemented directly for [`HeaderValue`]. When a fixed header value should be added +/// to all responses, it can be supplied directly to the middleware. +pub trait MakeHeaderValue<T> { + /// Try to create a header value from the request or response. + fn make_header_value(&mut self, message: &T) -> Option<HeaderValue>; +} + +impl<F, T> MakeHeaderValue<T> for F +where + F: FnMut(&T) -> Option<HeaderValue>, +{ + fn make_header_value(&mut self, message: &T) -> Option<HeaderValue> { + self(message) + } +} + +impl<T> MakeHeaderValue<T> for HeaderValue { + fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> { + Some(self.clone()) + } +} + +impl<T> MakeHeaderValue<T> for Option<HeaderValue> { + fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> { + self.clone() + } +} + +#[derive(Debug, Clone, Copy)] +enum InsertHeaderMode { + Override, + Append, + IfNotPresent, +} + +impl InsertHeaderMode { + fn apply<T, M>(self, header_name: &HeaderName, target: &mut T, make: &mut M) + where + T: Headers, + M: MakeHeaderValue<T>, + { + match self { + InsertHeaderMode::Override => { + if let Some(value) = make.make_header_value(target) { + target.headers_mut().insert(header_name.clone(), value); + } + } + InsertHeaderMode::IfNotPresent => { + if !target.headers().contains_key(header_name) { + if let Some(value) = make.make_header_value(target) { + target.headers_mut().insert(header_name.clone(), value); + } + } + } + InsertHeaderMode::Append => { + if let Some(value) = make.make_header_value(target) { + target.headers_mut().append(header_name.clone(), value); + } + } + } + } +} + +trait Headers { + fn headers(&self) -> &HeaderMap; + + fn headers_mut(&mut self) -> &mut HeaderMap; +} + +impl<B> Headers for Request<B> { + fn headers(&self) -> &HeaderMap { + Request::headers(self) + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + Request::headers_mut(self) + } +} + +impl<B> Headers for Response<B> { + fn headers(&self) -> &HeaderMap { + Response::headers(self) + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + Response::headers_mut(self) + } +} diff --git a/vendor/tower-http/src/set_header/request.rs b/vendor/tower-http/src/set_header/request.rs new file mode 100644 index 00000000..4032e23a --- /dev/null +++ b/vendor/tower-http/src/set_header/request.rs @@ -0,0 +1,254 @@ +//! Set a header on the request. +//! +//! The header value to be set may be provided as a fixed value when the +//! middleware is constructed, or determined dynamically based on the request +//! by a closure. See the [`MakeHeaderValue`] trait for details. +//! +//! # Example +//! +//! Setting a header from a fixed value provided when the middleware is constructed: +//! +//! ``` +//! use http::{Request, Response, header::{self, HeaderValue}}; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::set_header::SetRequestHeaderLayer; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let http_client = tower::service_fn(|_: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default())) +//! # }); +//! # +//! let mut svc = ServiceBuilder::new() +//! .layer( +//! // Layer that sets `User-Agent: my very cool app` on requests. +//! // +//! // `if_not_present` will only insert the header if it does not already +//! // have a value. +//! SetRequestHeaderLayer::if_not_present( +//! header::USER_AGENT, +//! HeaderValue::from_static("my very cool app"), +//! ) +//! ) +//! .service(http_client); +//! +//! let request = Request::new(Full::default()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Setting a header based on a value determined dynamically from the request: +//! +//! ``` +//! use http::{Request, Response, header::{self, HeaderValue}}; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::set_header::SetRequestHeaderLayer; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let http_client = tower::service_fn(|_: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default())) +//! # }); +//! fn date_header_value() -> HeaderValue { +//! // ... +//! # HeaderValue::from_static("now") +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! .layer( +//! // Layer that sets `Date` to the current date and time. +//! // +//! // `overriding` will insert the header and override any previous values it +//! // may have. +//! SetRequestHeaderLayer::overriding( +//! header::DATE, +//! |request: &Request<Full<Bytes>>| { +//! Some(date_header_value()) +//! } +//! ) +//! ) +//! .service(http_client); +//! +//! let request = Request::new(Full::default()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! # +//! # Ok(()) +//! # } +//! ``` + +use super::{InsertHeaderMode, MakeHeaderValue}; +use http::{header::HeaderName, Request, Response}; +use std::{ + fmt, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`SetRequestHeader`] which adds a request header. +/// +/// See [`SetRequestHeader`] for more details. +pub struct SetRequestHeaderLayer<M> { + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, +} + +impl<M> fmt::Debug for SetRequestHeaderLayer<M> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SetRequestHeaderLayer") + .field("header_name", &self.header_name) + .field("mode", &self.mode) + .field("make", &std::any::type_name::<M>()) + .finish() + } +} + +impl<M> SetRequestHeaderLayer<M> { + /// Create a new [`SetRequestHeaderLayer`]. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + pub fn overriding(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::Override) + } + + /// Create a new [`SetRequestHeaderLayer`]. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + pub fn appending(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::Append) + } + + /// Create a new [`SetRequestHeaderLayer`]. + /// + /// If a previous value exists for the header, the new value is not inserted. + pub fn if_not_present(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::IfNotPresent) + } + + fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { + Self { + make, + header_name, + mode, + } + } +} + +impl<S, M> Layer<S> for SetRequestHeaderLayer<M> +where + M: Clone, +{ + type Service = SetRequestHeader<S, M>; + + fn layer(&self, inner: S) -> Self::Service { + SetRequestHeader { + inner, + header_name: self.header_name.clone(), + make: self.make.clone(), + mode: self.mode, + } + } +} + +impl<M> Clone for SetRequestHeaderLayer<M> +where + M: Clone, +{ + fn clone(&self) -> Self { + Self { + make: self.make.clone(), + header_name: self.header_name.clone(), + mode: self.mode, + } + } +} + +/// Middleware that sets a header on the request. +#[derive(Clone)] +pub struct SetRequestHeader<S, M> { + inner: S, + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, +} + +impl<S, M> SetRequestHeader<S, M> { + /// Create a new [`SetRequestHeader`]. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::Override) + } + + /// Create a new [`SetRequestHeader`]. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::Append) + } + + /// Create a new [`SetRequestHeader`]. + /// + /// If a previous value exists for the header, the new value is not inserted. + pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) + } + + fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { + Self { + inner, + header_name, + make, + mode, + } + } + + define_inner_service_accessors!(); +} + +impl<S, M> fmt::Debug for SetRequestHeader<S, M> +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SetRequestHeader") + .field("inner", &self.inner) + .field("header_name", &self.header_name) + .field("mode", &self.mode) + .field("make", &std::any::type_name::<M>()) + .finish() + } +} + +impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetRequestHeader<S, M> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + M: MakeHeaderValue<Request<ReqBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + self.mode.apply(&self.header_name, &mut req, &mut self.make); + self.inner.call(req) + } +} diff --git a/vendor/tower-http/src/set_header/response.rs b/vendor/tower-http/src/set_header/response.rs new file mode 100644 index 00000000..c7b8ea84 --- /dev/null +++ b/vendor/tower-http/src/set_header/response.rs @@ -0,0 +1,391 @@ +//! Set a header on the response. +//! +//! The header value to be set may be provided as a fixed value when the +//! middleware is constructed, or determined dynamically based on the response +//! by a closure. See the [`MakeHeaderValue`] trait for details. +//! +//! # Example +//! +//! Setting a header from a fixed value provided when the middleware is constructed: +//! +//! ``` +//! use http::{Request, Response, header::{self, HeaderValue}}; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::set_header::SetResponseHeaderLayer; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let render_html = tower::service_fn(|request: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) +//! # }); +//! # +//! let mut svc = ServiceBuilder::new() +//! .layer( +//! // Layer that sets `Content-Type: text/html` on responses. +//! // +//! // `if_not_present` will only insert the header if it does not already +//! // have a value. +//! SetResponseHeaderLayer::if_not_present( +//! header::CONTENT_TYPE, +//! HeaderValue::from_static("text/html"), +//! ) +//! ) +//! .service(render_html); +//! +//! let request = Request::new(Full::default()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.headers()["content-type"], "text/html"); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Setting a header based on a value determined dynamically from the response: +//! +//! ``` +//! use http::{Request, Response, header::{self, HeaderValue}}; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use tower_http::set_header::SetResponseHeaderLayer; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use http_body::Body as _; // for `Body::size_hint` +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # let render_html = tower::service_fn(|request: Request<Full<Bytes>>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Full::from("1234567890"))) +//! # }); +//! # +//! let mut svc = ServiceBuilder::new() +//! .layer( +//! // Layer that sets `Content-Length` if the body has a known size. +//! // Bodies with streaming responses wont have a known size. +//! // +//! // `overriding` will insert the header and override any previous values it +//! // may have. +//! SetResponseHeaderLayer::overriding( +//! header::CONTENT_LENGTH, +//! |response: &Response<Full<Bytes>>| { +//! if let Some(size) = response.body().size_hint().exact() { +//! // If the response body has a known size, returning `Some` will +//! // set the `Content-Length` header to that value. +//! Some(HeaderValue::from_str(&size.to_string()).unwrap()) +//! } else { +//! // If the response body doesn't have a known size, return `None` +//! // to skip setting the header on this response. +//! None +//! } +//! } +//! ) +//! ) +//! .service(render_html); +//! +//! let request = Request::new(Full::default()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.headers()["content-length"], "10"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use super::{InsertHeaderMode, MakeHeaderValue}; +use http::{header::HeaderName, Request, Response}; +use pin_project_lite::pin_project; +use std::{ + fmt, + future::Future, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`SetResponseHeader`] which adds a response header. +/// +/// See [`SetResponseHeader`] for more details. +pub struct SetResponseHeaderLayer<M> { + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, +} + +impl<M> fmt::Debug for SetResponseHeaderLayer<M> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SetResponseHeaderLayer") + .field("header_name", &self.header_name) + .field("mode", &self.mode) + .field("make", &std::any::type_name::<M>()) + .finish() + } +} + +impl<M> SetResponseHeaderLayer<M> { + /// Create a new [`SetResponseHeaderLayer`]. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + pub fn overriding(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::Override) + } + + /// Create a new [`SetResponseHeaderLayer`]. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + pub fn appending(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::Append) + } + + /// Create a new [`SetResponseHeaderLayer`]. + /// + /// If a previous value exists for the header, the new value is not inserted. + pub fn if_not_present(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::IfNotPresent) + } + + fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { + Self { + make, + header_name, + mode, + } + } +} + +impl<S, M> Layer<S> for SetResponseHeaderLayer<M> +where + M: Clone, +{ + type Service = SetResponseHeader<S, M>; + + fn layer(&self, inner: S) -> Self::Service { + SetResponseHeader { + inner, + header_name: self.header_name.clone(), + make: self.make.clone(), + mode: self.mode, + } + } +} + +impl<M> Clone for SetResponseHeaderLayer<M> +where + M: Clone, +{ + fn clone(&self) -> Self { + Self { + make: self.make.clone(), + header_name: self.header_name.clone(), + mode: self.mode, + } + } +} + +/// Middleware that sets a header on the response. +#[derive(Clone)] +pub struct SetResponseHeader<S, M> { + inner: S, + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, +} + +impl<S, M> SetResponseHeader<S, M> { + /// Create a new [`SetResponseHeader`]. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::Override) + } + + /// Create a new [`SetResponseHeader`]. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::Append) + } + + /// Create a new [`SetResponseHeader`]. + /// + /// If a previous value exists for the header, the new value is not inserted. + pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) + } + + fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { + Self { + inner, + header_name, + make, + mode, + } + } + + define_inner_service_accessors!(); +} + +impl<S, M> fmt::Debug for SetResponseHeader<S, M> +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SetResponseHeader") + .field("inner", &self.inner) + .field("header_name", &self.header_name) + .field("mode", &self.mode) + .field("make", &std::any::type_name::<M>()) + .finish() + } +} + +impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + M: MakeHeaderValue<Response<ResBody>> + Clone, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture<S::Future, M>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + ResponseFuture { + future: self.inner.call(req), + header_name: self.header_name.clone(), + make: self.make.clone(), + mode: self.mode, + } + } +} + +pin_project! { + /// Response future for [`SetResponseHeader`]. + #[derive(Debug)] + pub struct ResponseFuture<F, M> { + #[pin] + future: F, + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, + } +} + +impl<F, ResBody, E, M> Future for ResponseFuture<F, M> +where + F: Future<Output = Result<Response<ResBody>, E>>, + M: MakeHeaderValue<Response<ResBody>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let mut res = ready!(this.future.poll(cx)?); + + this.mode.apply(this.header_name, &mut res, &mut *this.make); + + Poll::Ready(Ok(res)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::Body; + use http::{header, HeaderValue}; + use std::convert::Infallible; + use tower::{service_fn, ServiceExt}; + + #[tokio::test] + async fn test_override_mode() { + let svc = SetResponseHeader::overriding( + service_fn(|_req: Request<Body>| async { + let res = Response::builder() + .header(header::CONTENT_TYPE, "good-content") + .body(Body::empty()) + .unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "text/html"); + assert_eq!(values.next(), None); + } + + #[tokio::test] + async fn test_append_mode() { + let svc = SetResponseHeader::appending( + service_fn(|_req: Request<Body>| async { + let res = Response::builder() + .header(header::CONTENT_TYPE, "good-content") + .body(Body::empty()) + .unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "good-content"); + assert_eq!(values.next().unwrap(), "text/html"); + assert_eq!(values.next(), None); + } + + #[tokio::test] + async fn test_skip_if_present_mode() { + let svc = SetResponseHeader::if_not_present( + service_fn(|_req: Request<Body>| async { + let res = Response::builder() + .header(header::CONTENT_TYPE, "good-content") + .body(Body::empty()) + .unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "good-content"); + assert_eq!(values.next(), None); + } + + #[tokio::test] + async fn test_skip_if_present_mode_when_not_present() { + let svc = SetResponseHeader::if_not_present( + service_fn(|_req: Request<Body>| async { + let res = Response::builder().body(Body::empty()).unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "text/html"); + assert_eq!(values.next(), None); + } +} diff --git a/vendor/tower-http/src/set_status.rs b/vendor/tower-http/src/set_status.rs new file mode 100644 index 00000000..65f5405e --- /dev/null +++ b/vendor/tower-http/src/set_status.rs @@ -0,0 +1,137 @@ +//! Middleware to override status codes. +//! +//! # Example +//! +//! ``` +//! use tower_http::set_status::SetStatusLayer; +//! use http::{Request, Response, StatusCode}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use std::{iter::once, convert::Infallible}; +//! use tower::{ServiceBuilder, Service, ServiceExt}; +//! +//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut service = ServiceBuilder::new() +//! // change the status to `404 Not Found` regardless what the inner service returns +//! .layer(SetStatusLayer::new(StatusCode::NOT_FOUND)) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::builder().body(Full::default())?; +//! +//! let response = service.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::NOT_FOUND); +//! # +//! # Ok(()) +//! # } +//! ``` + +use http::{Request, Response, StatusCode}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`SetStatus`] which overrides the status codes. +#[derive(Debug, Clone, Copy)] +pub struct SetStatusLayer { + status: StatusCode, +} + +impl SetStatusLayer { + /// Create a new [`SetStatusLayer`]. + /// + /// The response status code will be `status` regardless of what the inner service returns. + pub fn new(status: StatusCode) -> Self { + SetStatusLayer { status } + } +} + +impl<S> Layer<S> for SetStatusLayer { + type Service = SetStatus<S>; + + fn layer(&self, inner: S) -> Self::Service { + SetStatus::new(inner, self.status) + } +} + +/// Middleware to override status codes. +/// +/// See the [module docs](self) for more details. +#[derive(Debug, Clone, Copy)] +pub struct SetStatus<S> { + inner: S, + status: StatusCode, +} + +impl<S> SetStatus<S> { + /// Create a new [`SetStatus`]. + /// + /// The response status code will be `status` regardless of what the inner service returns. + pub fn new(inner: S, status: StatusCode) -> Self { + Self { status, inner } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `SetStatus` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(status: StatusCode) -> SetStatusLayer { + SetStatusLayer::new(status) + } +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetStatus<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + ResponseFuture { + inner: self.inner.call(req), + status: Some(self.status), + } + } +} + +pin_project! { + /// Response future for [`SetStatus`]. + pub struct ResponseFuture<F> { + #[pin] + inner: F, + status: Option<StatusCode>, + } +} + +impl<F, B, E> Future for ResponseFuture<F> +where + F: Future<Output = Result<Response<B>, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let mut response = ready!(this.inner.poll(cx)?); + *response.status_mut() = this.status.take().expect("future polled after completion"); + Poll::Ready(Ok(response)) + } +} diff --git a/vendor/tower-http/src/test_helpers.rs b/vendor/tower-http/src/test_helpers.rs new file mode 100644 index 00000000..af28463c --- /dev/null +++ b/vendor/tower-http/src/test_helpers.rs @@ -0,0 +1,165 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_util::TryStream; +use http::HeaderMap; +use http_body::Frame; +use http_body_util::BodyExt; +use pin_project_lite::pin_project; +use sync_wrapper::SyncWrapper; +use tower::BoxError; + +type BoxBody = http_body_util::combinators::UnsyncBoxBody<Bytes, BoxError>; + +#[derive(Debug)] +pub(crate) struct Body(BoxBody); + +impl Body { + pub(crate) fn new<B>(body: B) -> Self + where + B: http_body::Body<Data = Bytes> + Send + 'static, + B::Error: Into<BoxError>, + { + Self(body.map_err(Into::into).boxed_unsync()) + } + + pub(crate) fn empty() -> Self { + Self::new(http_body_util::Empty::new()) + } + + pub(crate) fn from_stream<S>(stream: S) -> Self + where + S: TryStream + Send + 'static, + S::Ok: Into<Bytes>, + S::Error: Into<BoxError>, + { + Self::new(StreamBody { + stream: SyncWrapper::new(stream), + }) + } + + pub(crate) fn with_trailers(self, trailers: HeaderMap) -> WithTrailers<Self> { + WithTrailers { + inner: self, + trailers: Some(trailers), + } + } +} + +impl Default for Body { + fn default() -> Self { + Self::empty() + } +} + +macro_rules! body_from_impl { + ($ty:ty) => { + impl From<$ty> for Body { + fn from(buf: $ty) -> Self { + Self::new(http_body_util::Full::from(buf)) + } + } + }; +} + +body_from_impl!(&'static [u8]); +body_from_impl!(std::borrow::Cow<'static, [u8]>); +body_from_impl!(Vec<u8>); + +body_from_impl!(&'static str); +body_from_impl!(std::borrow::Cow<'static, str>); +body_from_impl!(String); + +body_from_impl!(Bytes); + +impl http_body::Body for Body { + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + Pin::new(&mut self.0).poll_frame(cx) + } + + fn size_hint(&self) -> http_body::SizeHint { + self.0.size_hint() + } + + fn is_end_stream(&self) -> bool { + self.0.is_end_stream() + } +} + +pin_project! { + struct StreamBody<S> { + #[pin] + stream: SyncWrapper<S>, + } +} + +impl<S> http_body::Body for StreamBody<S> +where + S: TryStream, + S::Ok: Into<Bytes>, + S::Error: Into<BoxError>, +{ + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + let stream = self.project().stream.get_pin_mut(); + match std::task::ready!(stream.try_poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk.into())))), + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + None => Poll::Ready(None), + } + } +} + +pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error> +where + T: http_body::Body, +{ + Ok(body.collect().await?.to_bytes()) +} + +pin_project! { + pub(crate) struct WithTrailers<B> { + #[pin] + inner: B, + trailers: Option<HeaderMap>, + } +} + +impl<B> http_body::Body for WithTrailers<B> +where + B: http_body::Body, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { + let this = self.project(); + match std::task::ready!(this.inner.poll_frame(cx)) { + Some(frame) => Poll::Ready(Some(frame)), + None => { + if let Some(trailers) = this.trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) + } else { + Poll::Ready(None) + } + } + } + } +} diff --git a/vendor/tower-http/src/timeout/body.rs b/vendor/tower-http/src/timeout/body.rs new file mode 100644 index 00000000..d44f35b8 --- /dev/null +++ b/vendor/tower-http/src/timeout/body.rs @@ -0,0 +1,193 @@ +use crate::BoxError; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, + time::Duration, +}; +use tokio::time::{sleep, Sleep}; + +pin_project! { + /// Middleware that applies a timeout to request and response bodies. + /// + /// Wrapper around a [`Body`][`http_body::Body`] to time out if data is not ready within the specified duration. + /// The timeout is enforced between consecutive [`Frame`][`http_body::Frame`] polls, and it + /// resets after each poll. + /// The total time to produce a [`Body`][`http_body::Body`] could exceed the timeout duration without + /// timing out, as long as no single interval between polls exceeds the timeout. + /// + /// If the [`Body`][`http_body::Body`] does not produce a requested data frame within the timeout period, it will return a [`TimeoutError`]. + /// + /// # Differences from [`Timeout`][crate::timeout::Timeout] + /// + /// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not body. + /// That timeout is not reset when bytes are handled, whether the request is active or not. + /// Bodies are handled asynchronously outside of the tower stack's future and thus needs an additional timeout. + /// + /// # Example + /// + /// ``` + /// use http::{Request, Response}; + /// use bytes::Bytes; + /// use http_body_util::Full; + /// use std::time::Duration; + /// use tower::ServiceBuilder; + /// use tower_http::timeout::RequestBodyTimeoutLayer; + /// + /// async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> { + /// // ... + /// # todo!() + /// } + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// let svc = ServiceBuilder::new() + /// // Timeout bodies after 30 seconds of inactivity + /// .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30))) + /// .service_fn(handle); + /// # Ok(()) + /// # } + /// ``` + pub struct TimeoutBody<B> { + timeout: Duration, + #[pin] + sleep: Option<Sleep>, + #[pin] + body: B, + } +} + +impl<B> TimeoutBody<B> { + /// Creates a new [`TimeoutBody`]. + pub fn new(timeout: Duration, body: B) -> Self { + TimeoutBody { + timeout, + sleep: None, + body, + } + } +} + +impl<B> Body for TimeoutBody<B> +where + B: Body, + B::Error: Into<BoxError>, +{ + type Data = B::Data; + type Error = Box<dyn std::error::Error + Send + Sync>; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + let mut this = self.project(); + + // Start the `Sleep` if not active. + let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { + some + } else { + this.sleep.set(Some(sleep(*this.timeout))); + this.sleep.as_mut().as_pin_mut().unwrap() + }; + + // Error if the timeout has expired. + if let Poll::Ready(()) = sleep_pinned.poll(cx) { + return Poll::Ready(Some(Err(Box::new(TimeoutError(()))))); + } + + // Check for body data. + let frame = ready!(this.body.poll_frame(cx)); + // A frame is ready. Reset the `Sleep`... + this.sleep.set(None); + + Poll::Ready(frame.transpose().map_err(Into::into).transpose()) + } +} + +/// Error for [`TimeoutBody`]. +#[derive(Debug)] +pub struct TimeoutError(()); + +impl std::error::Error for TimeoutError {} + +impl std::fmt::Display for TimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "data was not received within the designated timeout") + } +} +#[cfg(test)] +mod tests { + use super::*; + + use bytes::Bytes; + use http_body::Frame; + use http_body_util::BodyExt; + use pin_project_lite::pin_project; + use std::{error::Error, fmt::Display}; + + #[derive(Debug)] + struct MockError; + + impl Error for MockError {} + + impl Display for MockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "mock error") + } + } + + pin_project! { + struct MockBody { + #[pin] + sleep: Sleep + } + } + + impl Body for MockBody { + type Data = Bytes; + type Error = MockError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + let this = self.project(); + this.sleep + .poll(cx) + .map(|_| Some(Ok(Frame::data(vec![].into())))) + } + } + + #[tokio::test] + async fn test_body_available_within_timeout() { + let mock_sleep = Duration::from_secs(1); + let timeout_sleep = Duration::from_secs(2); + + let mock_body = MockBody { + sleep: sleep(mock_sleep), + }; + let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); + + assert!(timeout_body + .boxed() + .frame() + .await + .expect("no frame") + .is_ok()); + } + + #[tokio::test] + async fn test_body_unavailable_within_timeout_error() { + let mock_sleep = Duration::from_secs(2); + let timeout_sleep = Duration::from_secs(1); + + let mock_body = MockBody { + sleep: sleep(mock_sleep), + }; + let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); + + assert!(timeout_body.boxed().frame().await.unwrap().is_err()); + } +} diff --git a/vendor/tower-http/src/timeout/mod.rs b/vendor/tower-http/src/timeout/mod.rs new file mode 100644 index 00000000..facb6a92 --- /dev/null +++ b/vendor/tower-http/src/timeout/mod.rs @@ -0,0 +1,50 @@ +//! Middleware that applies a timeout to requests. +//! +//! If the request does not complete within the specified timeout it will be aborted and a `408 +//! Request Timeout` response will be sent. +//! +//! # Differences from `tower::timeout` +//! +//! tower's [`Timeout`](tower::timeout::Timeout) middleware uses an error to signal timeout, i.e. +//! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely +//! what you want as returning errors will terminate the connection without sending a response. +//! +//! This middleware won't change the error type and instead return a `408 Request Timeout` +//! response. That means if your service's error type is [`Infallible`] it will still be +//! [`Infallible`] after applying this middleware. +//! +//! # Example +//! +//! ``` +//! use http::{Request, Response}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use std::{convert::Infallible, time::Duration}; +//! use tower::ServiceBuilder; +//! use tower_http::timeout::TimeoutLayer; +//! +//! async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! // ... +//! # Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let svc = ServiceBuilder::new() +//! // Timeout requests after 30 seconds +//! .layer(TimeoutLayer::new(Duration::from_secs(30))) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! [`Infallible`]: std::convert::Infallible + +mod body; +mod service; + +pub use body::{TimeoutBody, TimeoutError}; +pub use service::{ + RequestBodyTimeout, RequestBodyTimeoutLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer, + Timeout, TimeoutLayer, +}; diff --git a/vendor/tower-http/src/timeout/service.rs b/vendor/tower-http/src/timeout/service.rs new file mode 100644 index 00000000..230fe717 --- /dev/null +++ b/vendor/tower-http/src/timeout/service.rs @@ -0,0 +1,271 @@ +use crate::timeout::body::TimeoutBody; +use http::{Request, Response, StatusCode}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies the [`Timeout`] middleware which apply a timeout to requests. +/// +/// See the [module docs](super) for an example. +#[derive(Debug, Clone, Copy)] +pub struct TimeoutLayer { + timeout: Duration, +} + +impl TimeoutLayer { + /// Creates a new [`TimeoutLayer`]. + pub fn new(timeout: Duration) -> Self { + TimeoutLayer { timeout } + } +} + +impl<S> Layer<S> for TimeoutLayer { + type Service = Timeout<S>; + + fn layer(&self, inner: S) -> Self::Service { + Timeout::new(inner, self.timeout) + } +} + +/// Middleware which apply a timeout to requests. +/// +/// If the request does not complete within the specified timeout it will be aborted and a `408 +/// Request Timeout` response will be sent. +/// +/// See the [module docs](super) for an example. +#[derive(Debug, Clone, Copy)] +pub struct Timeout<S> { + inner: S, + timeout: Duration, +} + +impl<S> Timeout<S> { + /// Creates a new [`Timeout`]. + pub fn new(inner: S, timeout: Duration) -> Self { + Self { inner, timeout } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(timeout: Duration) -> TimeoutLayer { + TimeoutLayer::new(timeout) + } +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Timeout<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + ResBody: Default, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture<S::Future>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let sleep = tokio::time::sleep(self.timeout); + ResponseFuture { + inner: self.inner.call(req), + sleep, + } + } +} + +pin_project! { + /// Response future for [`Timeout`]. + pub struct ResponseFuture<F> { + #[pin] + inner: F, + #[pin] + sleep: Sleep, + } +} + +impl<F, B, E> Future for ResponseFuture<F> +where + F: Future<Output = Result<Response<B>, E>>, + B: Default, +{ + type Output = Result<Response<B>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + + if this.sleep.poll(cx).is_ready() { + let mut res = Response::new(B::default()); + *res.status_mut() = StatusCode::REQUEST_TIMEOUT; + return Poll::Ready(Ok(res)); + } + + this.inner.poll(cx) + } +} + +/// Applies a [`TimeoutBody`] to the request body. +#[derive(Clone, Debug)] +pub struct RequestBodyTimeoutLayer { + timeout: Duration, +} + +impl RequestBodyTimeoutLayer { + /// Creates a new [`RequestBodyTimeoutLayer`]. + pub fn new(timeout: Duration) -> Self { + Self { timeout } + } +} + +impl<S> Layer<S> for RequestBodyTimeoutLayer { + type Service = RequestBodyTimeout<S>; + + fn layer(&self, inner: S) -> Self::Service { + RequestBodyTimeout::new(inner, self.timeout) + } +} + +/// Applies a [`TimeoutBody`] to the request body. +#[derive(Clone, Debug)] +pub struct RequestBodyTimeout<S> { + inner: S, + timeout: Duration, +} + +impl<S> RequestBodyTimeout<S> { + /// Creates a new [`RequestBodyTimeout`]. + pub fn new(service: S, timeout: Duration) -> Self { + Self { + inner: service, + timeout, + } + } + + /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer { + RequestBodyTimeoutLayer::new(timeout) + } + + define_inner_service_accessors!(); +} + +impl<S, ReqBody> Service<Request<ReqBody>> for RequestBodyTimeout<S> +where + S: Service<Request<TimeoutBody<ReqBody>>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let req = req.map(|body| TimeoutBody::new(self.timeout, body)); + self.inner.call(req) + } +} + +/// Applies a [`TimeoutBody`] to the response body. +#[derive(Clone)] +pub struct ResponseBodyTimeoutLayer { + timeout: Duration, +} + +impl ResponseBodyTimeoutLayer { + /// Creates a new [`ResponseBodyTimeoutLayer`]. + pub fn new(timeout: Duration) -> Self { + Self { timeout } + } +} + +impl<S> Layer<S> for ResponseBodyTimeoutLayer { + type Service = ResponseBodyTimeout<S>; + + fn layer(&self, inner: S) -> Self::Service { + ResponseBodyTimeout::new(inner, self.timeout) + } +} + +/// Applies a [`TimeoutBody`] to the response body. +#[derive(Clone)] +pub struct ResponseBodyTimeout<S> { + inner: S, + timeout: Duration, +} + +impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ResponseBodyTimeout<S> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = Response<TimeoutBody<ResBody>>; + type Error = S::Error; + type Future = ResponseBodyTimeoutFuture<S::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + ResponseBodyTimeoutFuture { + inner: self.inner.call(req), + timeout: self.timeout, + } + } +} + +impl<S> ResponseBodyTimeout<S> { + /// Creates a new [`ResponseBodyTimeout`]. + pub fn new(service: S, timeout: Duration) -> Self { + Self { + inner: service, + timeout, + } + } + + /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer { + ResponseBodyTimeoutLayer::new(timeout) + } + + define_inner_service_accessors!(); +} + +pin_project! { + /// Response future for [`ResponseBodyTimeout`]. + pub struct ResponseBodyTimeoutFuture<Fut> { + #[pin] + inner: Fut, + timeout: Duration, + } +} + +impl<Fut, ResBody, E> Future for ResponseBodyTimeoutFuture<Fut> +where + Fut: Future<Output = Result<Response<ResBody>, E>>, +{ + type Output = Result<Response<TimeoutBody<ResBody>>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let timeout = self.timeout; + let this = self.project(); + let res = ready!(this.inner.poll(cx))?; + Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body)))) + } +} diff --git a/vendor/tower-http/src/trace/body.rs b/vendor/tower-http/src/trace/body.rs new file mode 100644 index 00000000..d713f243 --- /dev/null +++ b/vendor/tower-http/src/trace/body.rs @@ -0,0 +1,102 @@ +use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure}; +use crate::classify::ClassifyEos; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; +use std::{ + fmt, + pin::Pin, + task::{ready, Context, Poll}, + time::Instant, +}; +use tracing::Span; + +pin_project! { + /// Response body for [`Trace`]. + /// + /// [`Trace`]: super::Trace + pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> { + #[pin] + pub(crate) inner: B, + pub(crate) classify_eos: Option<C>, + pub(crate) on_eos: Option<(OnEos, Instant)>, + pub(crate) on_body_chunk: OnBodyChunk, + pub(crate) on_failure: Option<OnFailure>, + pub(crate) start: Instant, + pub(crate) span: Span, + } +} + +impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body + for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT> +where + B: Body, + B::Error: fmt::Display + 'static, + C: ClassifyEos, + OnEosT: OnEos, + OnBodyChunkT: OnBodyChunk<B::Data>, + OnFailureT: OnFailure<C::FailureClass>, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { + let this = self.project(); + let _guard = this.span.enter(); + let result = ready!(this.inner.poll_frame(cx)); + + let latency = this.start.elapsed(); + *this.start = Instant::now(); + + match result { + Some(Ok(frame)) => { + let frame = match frame.into_data() { + Ok(chunk) => { + this.on_body_chunk.on_body_chunk(&chunk, latency, this.span); + Frame::data(chunk) + } + Err(frame) => frame, + }; + + let frame = match frame.into_trailers() { + Ok(trailers) => { + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span); + } + Frame::trailers(trailers) + } + Err(frame) => frame, + }; + + Poll::Ready(Some(Ok(frame))) + } + Some(Err(err)) => { + if let Some((classify_eos, mut on_failure)) = + this.classify_eos.take().zip(this.on_failure.take()) + { + let failure_class = classify_eos.classify_error(&err); + on_failure.on_failure(failure_class, latency, this.span); + } + + Poll::Ready(Some(Err(err))) + } + None => { + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(None, stream_start.elapsed(), this.span); + } + + Poll::Ready(None) + } + } + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +} diff --git a/vendor/tower-http/src/trace/future.rs b/vendor/tower-http/src/trace/future.rs new file mode 100644 index 00000000..e205ea32 --- /dev/null +++ b/vendor/tower-http/src/trace/future.rs @@ -0,0 +1,116 @@ +use super::{ + DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnResponse, OnBodyChunk, OnEos, + OnFailure, OnResponse, ResponseBody, +}; +use crate::classify::{ClassifiedResponse, ClassifyResponse}; +use http::Response; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, + time::Instant, +}; +use tracing::Span; + +pin_project! { + /// Response future for [`Trace`]. + /// + /// [`Trace`]: super::Trace + pub struct ResponseFuture<F, C, OnResponse = DefaultOnResponse, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> { + #[pin] + pub(crate) inner: F, + pub(crate) span: Span, + pub(crate) classifier: Option<C>, + pub(crate) on_response: Option<OnResponse>, + pub(crate) on_body_chunk: Option<OnBodyChunk>, + pub(crate) on_eos: Option<OnEos>, + pub(crate) on_failure: Option<OnFailure>, + pub(crate) start: Instant, + } +} + +impl<Fut, ResBody, E, C, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT> Future + for ResponseFuture<Fut, C, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT> +where + Fut: Future<Output = Result<Response<ResBody>, E>>, + ResBody: Body, + ResBody::Error: std::fmt::Display + 'static, + E: std::fmt::Display + 'static, + C: ClassifyResponse, + OnResponseT: OnResponse<ResBody>, + OnFailureT: OnFailure<C::FailureClass>, + OnBodyChunkT: OnBodyChunk<ResBody::Data>, + OnEosT: OnEos, +{ + type Output = Result< + Response<ResponseBody<ResBody, C::ClassifyEos, OnBodyChunkT, OnEosT, OnFailureT>>, + E, + >; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let _guard = this.span.enter(); + let result = ready!(this.inner.poll(cx)); + let latency = this.start.elapsed(); + + let classifier = this.classifier.take().unwrap(); + let on_eos = this.on_eos.take(); + let on_body_chunk = this.on_body_chunk.take().unwrap(); + let mut on_failure = this.on_failure.take().unwrap(); + + match result { + Ok(res) => { + let classification = classifier.classify_response(&res); + let start = *this.start; + + this.on_response + .take() + .unwrap() + .on_response(&res, latency, this.span); + + match classification { + ClassifiedResponse::Ready(classification) => { + if let Err(failure_class) = classification { + on_failure.on_failure(failure_class, latency, this.span); + } + + let span = this.span.clone(); + let res = res.map(|body| ResponseBody { + inner: body, + classify_eos: None, + on_eos: None, + on_body_chunk, + on_failure: Some(on_failure), + start, + span, + }); + + Poll::Ready(Ok(res)) + } + ClassifiedResponse::RequiresEos(classify_eos) => { + let span = this.span.clone(); + let res = res.map(|body| ResponseBody { + inner: body, + classify_eos: Some(classify_eos), + on_eos: on_eos.zip(Some(Instant::now())), + on_body_chunk, + on_failure: Some(on_failure), + start, + span, + }); + + Poll::Ready(Ok(res)) + } + } + } + Err(err) => { + let failure_class = classifier.classify_error(&err); + on_failure.on_failure(failure_class, latency, this.span); + + Poll::Ready(Err(err)) + } + } + } +} diff --git a/vendor/tower-http/src/trace/layer.rs b/vendor/tower-http/src/trace/layer.rs new file mode 100644 index 00000000..21ff321c --- /dev/null +++ b/vendor/tower-http/src/trace/layer.rs @@ -0,0 +1,236 @@ +use super::{ + DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, + DefaultOnResponse, GrpcMakeClassifier, HttpMakeClassifier, Trace, +}; +use crate::classify::{ + GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier, +}; +use tower_layer::Layer; + +/// [`Layer`] that adds high level [tracing] to a [`Service`]. +/// +/// See the [module docs](crate::trace) for more details. +/// +/// [`Layer`]: tower_layer::Layer +/// [tracing]: https://crates.io/crates/tracing +/// [`Service`]: tower_service::Service +#[derive(Debug, Copy, Clone)] +pub struct TraceLayer< + M, + MakeSpan = DefaultMakeSpan, + OnRequest = DefaultOnRequest, + OnResponse = DefaultOnResponse, + OnBodyChunk = DefaultOnBodyChunk, + OnEos = DefaultOnEos, + OnFailure = DefaultOnFailure, +> { + pub(crate) make_classifier: M, + pub(crate) make_span: MakeSpan, + pub(crate) on_request: OnRequest, + pub(crate) on_response: OnResponse, + pub(crate) on_body_chunk: OnBodyChunk, + pub(crate) on_eos: OnEos, + pub(crate) on_failure: OnFailure, +} + +impl<M> TraceLayer<M> { + /// Create a new [`TraceLayer`] using the given [`MakeClassifier`]. + pub fn new(make_classifier: M) -> Self + where + M: MakeClassifier, + { + Self { + make_classifier, + make_span: DefaultMakeSpan::new(), + on_failure: DefaultOnFailure::default(), + on_request: DefaultOnRequest::default(), + on_eos: DefaultOnEos::default(), + on_body_chunk: DefaultOnBodyChunk::default(), + on_response: DefaultOnResponse::default(), + } + } +} + +impl<M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> + TraceLayer<M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> +{ + /// Customize what to do when a request is received. + /// + /// `NewOnRequest` is expected to implement [`OnRequest`]. + /// + /// [`OnRequest`]: super::OnRequest + pub fn on_request<NewOnRequest>( + self, + new_on_request: NewOnRequest, + ) -> TraceLayer<M, MakeSpan, NewOnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> { + TraceLayer { + on_request: new_on_request, + on_failure: self.on_failure, + on_eos: self.on_eos, + on_body_chunk: self.on_body_chunk, + make_span: self.make_span, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a response has been produced. + /// + /// `NewOnResponse` is expected to implement [`OnResponse`]. + /// + /// [`OnResponse`]: super::OnResponse + pub fn on_response<NewOnResponse>( + self, + new_on_response: NewOnResponse, + ) -> TraceLayer<M, MakeSpan, OnRequest, NewOnResponse, OnBodyChunk, OnEos, OnFailure> { + TraceLayer { + on_response: new_on_response, + on_request: self.on_request, + on_eos: self.on_eos, + on_body_chunk: self.on_body_chunk, + on_failure: self.on_failure, + make_span: self.make_span, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a body chunk has been sent. + /// + /// `NewOnBodyChunk` is expected to implement [`OnBodyChunk`]. + /// + /// [`OnBodyChunk`]: super::OnBodyChunk + pub fn on_body_chunk<NewOnBodyChunk>( + self, + new_on_body_chunk: NewOnBodyChunk, + ) -> TraceLayer<M, MakeSpan, OnRequest, OnResponse, NewOnBodyChunk, OnEos, OnFailure> { + TraceLayer { + on_body_chunk: new_on_body_chunk, + on_eos: self.on_eos, + on_failure: self.on_failure, + on_request: self.on_request, + make_span: self.make_span, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a streaming response has closed. + /// + /// `NewOnEos` is expected to implement [`OnEos`]. + /// + /// [`OnEos`]: super::OnEos + pub fn on_eos<NewOnEos>( + self, + new_on_eos: NewOnEos, + ) -> TraceLayer<M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, NewOnEos, OnFailure> { + TraceLayer { + on_eos: new_on_eos, + on_body_chunk: self.on_body_chunk, + on_failure: self.on_failure, + on_request: self.on_request, + make_span: self.make_span, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a response has been classified as a failure. + /// + /// `NewOnFailure` is expected to implement [`OnFailure`]. + /// + /// [`OnFailure`]: super::OnFailure + pub fn on_failure<NewOnFailure>( + self, + new_on_failure: NewOnFailure, + ) -> TraceLayer<M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, NewOnFailure> { + TraceLayer { + on_failure: new_on_failure, + on_request: self.on_request, + on_eos: self.on_eos, + on_body_chunk: self.on_body_chunk, + make_span: self.make_span, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize how to make [`Span`]s that all request handling will be wrapped in. + /// + /// `NewMakeSpan` is expected to implement [`MakeSpan`]. + /// + /// [`MakeSpan`]: super::MakeSpan + /// [`Span`]: tracing::Span + pub fn make_span_with<NewMakeSpan>( + self, + new_make_span: NewMakeSpan, + ) -> TraceLayer<M, NewMakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> { + TraceLayer { + make_span: new_make_span, + on_request: self.on_request, + on_failure: self.on_failure, + on_body_chunk: self.on_body_chunk, + on_eos: self.on_eos, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } +} + +impl TraceLayer<HttpMakeClassifier> { + /// Create a new [`TraceLayer`] using [`ServerErrorsAsFailures`] which supports classifying + /// regular HTTP responses based on the status code. + pub fn new_for_http() -> Self { + Self { + make_classifier: SharedClassifier::new(ServerErrorsAsFailures::default()), + make_span: DefaultMakeSpan::new(), + on_response: DefaultOnResponse::default(), + on_request: DefaultOnRequest::default(), + on_body_chunk: DefaultOnBodyChunk::default(), + on_eos: DefaultOnEos::default(), + on_failure: DefaultOnFailure::default(), + } + } +} + +impl TraceLayer<GrpcMakeClassifier> { + /// Create a new [`TraceLayer`] using [`GrpcErrorsAsFailures`] which supports classifying + /// gRPC responses and streams based on the `grpc-status` header. + pub fn new_for_grpc() -> Self { + Self { + make_classifier: SharedClassifier::new(GrpcErrorsAsFailures::default()), + make_span: DefaultMakeSpan::new(), + on_response: DefaultOnResponse::default(), + on_request: DefaultOnRequest::default(), + on_body_chunk: DefaultOnBodyChunk::default(), + on_eos: DefaultOnEos::default(), + on_failure: DefaultOnFailure::default(), + } + } +} + +impl<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> Layer<S> + for TraceLayer<M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> +where + M: Clone, + MakeSpan: Clone, + OnRequest: Clone, + OnResponse: Clone, + OnEos: Clone, + OnBodyChunk: Clone, + OnFailure: Clone, +{ + type Service = Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>; + + fn layer(&self, inner: S) -> Self::Service { + Trace { + inner, + make_classifier: self.make_classifier.clone(), + make_span: self.make_span.clone(), + on_request: self.on_request.clone(), + on_eos: self.on_eos.clone(), + on_body_chunk: self.on_body_chunk.clone(), + on_response: self.on_response.clone(), + on_failure: self.on_failure.clone(), + } + } +} diff --git a/vendor/tower-http/src/trace/make_span.rs b/vendor/tower-http/src/trace/make_span.rs new file mode 100644 index 00000000..bf558d3b --- /dev/null +++ b/vendor/tower-http/src/trace/make_span.rs @@ -0,0 +1,113 @@ +use http::Request; +use tracing::{Level, Span}; + +use super::DEFAULT_MESSAGE_LEVEL; + +/// Trait used to generate [`Span`]s from requests. [`Trace`] wraps all request handling in this +/// span. +/// +/// [`Span`]: tracing::Span +/// [`Trace`]: super::Trace +pub trait MakeSpan<B> { + /// Make a span from a request. + fn make_span(&mut self, request: &Request<B>) -> Span; +} + +impl<B> MakeSpan<B> for Span { + fn make_span(&mut self, _request: &Request<B>) -> Span { + self.clone() + } +} + +impl<F, B> MakeSpan<B> for F +where + F: FnMut(&Request<B>) -> Span, +{ + fn make_span(&mut self, request: &Request<B>) -> Span { + self(request) + } +} + +/// The default way [`Span`]s will be created for [`Trace`]. +/// +/// [`Span`]: tracing::Span +/// [`Trace`]: super::Trace +#[derive(Debug, Clone)] +pub struct DefaultMakeSpan { + level: Level, + include_headers: bool, +} + +impl DefaultMakeSpan { + /// Create a new `DefaultMakeSpan`. + pub fn new() -> Self { + Self { + level: DEFAULT_MESSAGE_LEVEL, + include_headers: false, + } + } + + /// Set the [`Level`] used for the [tracing span]. + /// + /// Defaults to [`Level::DEBUG`]. + /// + /// [tracing span]: https://docs.rs/tracing/latest/tracing/#spans + pub fn level(mut self, level: Level) -> Self { + self.level = level; + self + } + + /// Include request headers on the [`Span`]. + /// + /// By default headers are not included. + /// + /// [`Span`]: tracing::Span + pub fn include_headers(mut self, include_headers: bool) -> Self { + self.include_headers = include_headers; + self + } +} + +impl Default for DefaultMakeSpan { + fn default() -> Self { + Self::new() + } +} + +impl<B> MakeSpan<B> for DefaultMakeSpan { + fn make_span(&mut self, request: &Request<B>) -> Span { + // This ugly macro is needed, unfortunately, because `tracing::span!` + // required the level argument to be static. Meaning we can't just pass + // `self.level`. + macro_rules! make_span { + ($level:expr) => { + if self.include_headers { + tracing::span!( + $level, + "request", + method = %request.method(), + uri = %request.uri(), + version = ?request.version(), + headers = ?request.headers(), + ) + } else { + tracing::span!( + $level, + "request", + method = %request.method(), + uri = %request.uri(), + version = ?request.version(), + ) + } + } + } + + match self.level { + Level::ERROR => make_span!(Level::ERROR), + Level::WARN => make_span!(Level::WARN), + Level::INFO => make_span!(Level::INFO), + Level::DEBUG => make_span!(Level::DEBUG), + Level::TRACE => make_span!(Level::TRACE), + } + } +} diff --git a/vendor/tower-http/src/trace/mod.rs b/vendor/tower-http/src/trace/mod.rs new file mode 100644 index 00000000..ec5036aa --- /dev/null +++ b/vendor/tower-http/src/trace/mod.rs @@ -0,0 +1,635 @@ +//! Middleware that adds high level [tracing] to a [`Service`]. +//! +//! # Example +//! +//! Adding tracing to your service can be as simple as: +//! +//! ```rust +//! use http::{Request, Response}; +//! use tower::{ServiceBuilder, ServiceExt, Service}; +//! use tower_http::trace::TraceLayer; +//! use std::convert::Infallible; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // Setup tracing +//! tracing_subscriber::fmt::init(); +//! +//! let mut service = ServiceBuilder::new() +//! .layer(TraceLayer::new_for_http()) +//! .service_fn(handle); +//! +//! let request = Request::new(Full::from("foo")); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! If you run this application with `RUST_LOG=tower_http=trace cargo run` you should see logs like: +//! +//! ```text +//! Mar 05 20:50:28.523 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_request: started processing request +//! Mar 05 20:50:28.524 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_response: finished processing request latency=1 ms status=200 +//! ``` +//! +//! # Customization +//! +//! [`Trace`] comes with good defaults but also supports customizing many aspects of the output. +//! +//! The default behaviour supports some customization: +//! +//! ```rust +//! use http::{Request, Response, HeaderMap, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::ServiceBuilder; +//! use tracing::Level; +//! use tower_http::{ +//! LatencyUnit, +//! trace::{TraceLayer, DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse}, +//! }; +//! use std::time::Duration; +//! # use tower::{ServiceExt, Service}; +//! # use std::convert::Infallible; +//! +//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) +//! # } +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # tracing_subscriber::fmt::init(); +//! # +//! let service = ServiceBuilder::new() +//! .layer( +//! TraceLayer::new_for_http() +//! .make_span_with( +//! DefaultMakeSpan::new().include_headers(true) +//! ) +//! .on_request( +//! DefaultOnRequest::new().level(Level::INFO) +//! ) +//! .on_response( +//! DefaultOnResponse::new() +//! .level(Level::INFO) +//! .latency_unit(LatencyUnit::Micros) +//! ) +//! // on so on for `on_eos`, `on_body_chunk`, and `on_failure` +//! ) +//! .service_fn(handle); +//! # let mut service = service; +//! # let response = service +//! # .ready() +//! # .await? +//! # .call(Request::new(Full::from("foo"))) +//! # .await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! However for maximum control you can provide callbacks: +//! +//! ```rust +//! use http::{Request, Response, HeaderMap, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::ServiceBuilder; +//! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; +//! use std::time::Duration; +//! use tracing::Span; +//! # use tower::{ServiceExt, Service}; +//! # use std::convert::Infallible; +//! +//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) +//! # } +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # tracing_subscriber::fmt::init(); +//! # +//! let service = ServiceBuilder::new() +//! .layer( +//! TraceLayer::new_for_http() +//! .make_span_with(|request: &Request<Full<Bytes>>| { +//! tracing::debug_span!("http-request") +//! }) +//! .on_request(|request: &Request<Full<Bytes>>, _span: &Span| { +//! tracing::debug!("started {} {}", request.method(), request.uri().path()) +//! }) +//! .on_response(|response: &Response<Full<Bytes>>, latency: Duration, _span: &Span| { +//! tracing::debug!("response generated in {:?}", latency) +//! }) +//! .on_body_chunk(|chunk: &Bytes, latency: Duration, _span: &Span| { +//! tracing::debug!("sending {} bytes", chunk.len()) +//! }) +//! .on_eos(|trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span| { +//! tracing::debug!("stream closed after {:?}", stream_duration) +//! }) +//! .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| { +//! tracing::debug!("something went wrong") +//! }) +//! ) +//! .service_fn(handle); +//! # let mut service = service; +//! # let response = service +//! # .ready() +//! # .await? +//! # .call(Request::new(Full::from("foo"))) +//! # .await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Disabling something +//! +//! Setting the behaviour to `()` will be disable that particular step: +//! +//! ```rust +//! use http::StatusCode; +//! use tower::ServiceBuilder; +//! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; +//! use std::time::Duration; +//! use tracing::Span; +//! # use tower::{ServiceExt, Service}; +//! # use http_body_util::Full; +//! # use bytes::Bytes; +//! # use http::{Response, Request}; +//! # use std::convert::Infallible; +//! +//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) +//! # } +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # tracing_subscriber::fmt::init(); +//! # +//! let service = ServiceBuilder::new() +//! .layer( +//! // This configuration will only emit events on failures +//! TraceLayer::new_for_http() +//! .on_request(()) +//! .on_response(()) +//! .on_body_chunk(()) +//! .on_eos(()) +//! .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| { +//! tracing::debug!("something went wrong") +//! }) +//! ) +//! .service_fn(handle); +//! # let mut service = service; +//! # let response = service +//! # .ready() +//! # .await? +//! # .call(Request::new(Full::from("foo"))) +//! # .await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # When the callbacks are called +//! +//! ### `on_request` +//! +//! The `on_request` callback is called when the request arrives at the +//! middleware in [`Service::call`] just prior to passing the request to the +//! inner service. +//! +//! ### `on_response` +//! +//! The `on_response` callback is called when the inner service's response +//! future completes with `Ok(response)` regardless if the response is +//! classified as a success or a failure. +//! +//! For example if you're using [`ServerErrorsAsFailures`] as your classifier +//! and the inner service responds with `500 Internal Server Error` then the +//! `on_response` callback is still called. `on_failure` would _also_ be called +//! in this case since the response was classified as a failure. +//! +//! ### `on_body_chunk` +//! +//! The `on_body_chunk` callback is called when the response body produces a new +//! chunk, that is when [`Body::poll_frame`] returns a data frame. +//! +//! `on_body_chunk` is called even if the chunk is empty. +//! +//! ### `on_eos` +//! +//! The `on_eos` callback is called when a streaming response body ends, that is +//! when [`Body::poll_frame`] returns a trailers frame. +//! +//! `on_eos` is called even if the trailers produced are `None`. +//! +//! ### `on_failure` +//! +//! The `on_failure` callback is called when: +//! +//! - The inner [`Service`]'s response future resolves to an error. +//! - A response is classified as a failure. +//! - [`Body::poll_frame`] returns an error. +//! - An end-of-stream is classified as a failure. +//! +//! # Recording fields on the span +//! +//! All callbacks receive a reference to the [tracing] [`Span`], corresponding to this request, +//! produced by the closure passed to [`TraceLayer::make_span_with`]. It can be used to [record +//! field values][record] that weren't known when the span was created. +//! +//! ```rust +//! use http::{Request, Response, HeaderMap, StatusCode}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::ServiceBuilder; +//! use tower_http::trace::TraceLayer; +//! use tracing::Span; +//! use std::time::Duration; +//! # use std::convert::Infallible; +//! +//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) +//! # } +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # tracing_subscriber::fmt::init(); +//! # +//! let service = ServiceBuilder::new() +//! .layer( +//! TraceLayer::new_for_http() +//! .make_span_with(|request: &Request<Full<Bytes>>| { +//! tracing::debug_span!( +//! "http-request", +//! status_code = tracing::field::Empty, +//! ) +//! }) +//! .on_response(|response: &Response<Full<Bytes>>, _latency: Duration, span: &Span| { +//! span.record("status_code", &tracing::field::display(response.status())); +//! +//! tracing::debug!("response generated") +//! }) +//! ) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Providing classifiers +//! +//! Tracing requires determining if a response is a success or failure. [`MakeClassifier`] is used +//! to create a classifier for the incoming request. See the docs for [`MakeClassifier`] and +//! [`ClassifyResponse`] for more details on classification. +//! +//! A [`MakeClassifier`] can be provided when creating a [`TraceLayer`]: +//! +//! ```rust +//! use http::{Request, Response}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::ServiceBuilder; +//! use tower_http::{ +//! trace::TraceLayer, +//! classify::{ +//! MakeClassifier, ClassifyResponse, ClassifiedResponse, NeverClassifyEos, +//! SharedClassifier, +//! }, +//! }; +//! use std::convert::Infallible; +//! +//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) +//! # } +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # tracing_subscriber::fmt::init(); +//! # +//! // Our `MakeClassifier` that always crates `MyClassifier` classifiers. +//! #[derive(Copy, Clone)] +//! struct MyMakeClassify; +//! +//! impl MakeClassifier for MyMakeClassify { +//! type Classifier = MyClassifier; +//! type FailureClass = &'static str; +//! type ClassifyEos = NeverClassifyEos<&'static str>; +//! +//! fn make_classifier<B>(&self, req: &Request<B>) -> Self::Classifier { +//! MyClassifier +//! } +//! } +//! +//! // A classifier that classifies failures as `"something went wrong..."`. +//! #[derive(Copy, Clone)] +//! struct MyClassifier; +//! +//! impl ClassifyResponse for MyClassifier { +//! type FailureClass = &'static str; +//! type ClassifyEos = NeverClassifyEos<&'static str>; +//! +//! fn classify_response<B>( +//! self, +//! res: &Response<B> +//! ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> { +//! // Classify based on the status code. +//! if res.status().is_server_error() { +//! ClassifiedResponse::Ready(Err("something went wrong...")) +//! } else { +//! ClassifiedResponse::Ready(Ok(())) +//! } +//! } +//! +//! fn classify_error<E>(self, error: &E) -> Self::FailureClass +//! where +//! E: std::fmt::Display + 'static, +//! { +//! "something went wrong..." +//! } +//! } +//! +//! let service = ServiceBuilder::new() +//! // Create a trace layer that uses our classifier. +//! .layer(TraceLayer::new(MyMakeClassify)) +//! .service_fn(handle); +//! +//! // Since `MyClassifier` is `Clone` we can also use `SharedClassifier` +//! // to avoid having to define a separate `MakeClassifier`. +//! let service = ServiceBuilder::new() +//! .layer(TraceLayer::new(SharedClassifier::new(MyClassifier))) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! [`TraceLayer`] comes with convenience methods for using common classifiers: +//! +//! - [`TraceLayer::new_for_http`] classifies based on the status code. It doesn't consider +//! streaming responses. +//! - [`TraceLayer::new_for_grpc`] classifies based on the gRPC protocol and supports streaming +//! responses. +//! +//! [tracing]: https://crates.io/crates/tracing +//! [`Service`]: tower_service::Service +//! [`Service::call`]: tower_service::Service::call +//! [`MakeClassifier`]: crate::classify::MakeClassifier +//! [`ClassifyResponse`]: crate::classify::ClassifyResponse +//! [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record +//! [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with +//! [`Span`]: tracing::Span +//! [`ServerErrorsAsFailures`]: crate::classify::ServerErrorsAsFailures +//! [`Body::poll_frame`]: http_body::Body::poll_frame + +use std::{fmt, time::Duration}; + +use tracing::Level; + +pub use self::{ + body::ResponseBody, + future::ResponseFuture, + layer::TraceLayer, + make_span::{DefaultMakeSpan, MakeSpan}, + on_body_chunk::{DefaultOnBodyChunk, OnBodyChunk}, + on_eos::{DefaultOnEos, OnEos}, + on_failure::{DefaultOnFailure, OnFailure}, + on_request::{DefaultOnRequest, OnRequest}, + on_response::{DefaultOnResponse, OnResponse}, + service::Trace, +}; +use crate::{ + classify::{GrpcErrorsAsFailures, ServerErrorsAsFailures, SharedClassifier}, + LatencyUnit, +}; + +/// MakeClassifier for HTTP requests. +pub type HttpMakeClassifier = SharedClassifier<ServerErrorsAsFailures>; + +/// MakeClassifier for gRPC requests. +pub type GrpcMakeClassifier = SharedClassifier<GrpcErrorsAsFailures>; + +macro_rules! event_dynamic_lvl { + ( $(target: $target:expr,)? $(parent: $parent:expr,)? $lvl:expr, $($tt:tt)* ) => { + match $lvl { + tracing::Level::ERROR => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::ERROR, + $($tt)* + ); + } + tracing::Level::WARN => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::WARN, + $($tt)* + ); + } + tracing::Level::INFO => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::INFO, + $($tt)* + ); + } + tracing::Level::DEBUG => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::DEBUG, + $($tt)* + ); + } + tracing::Level::TRACE => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::TRACE, + $($tt)* + ); + } + } + }; +} + +mod body; +mod future; +mod layer; +mod make_span; +mod on_body_chunk; +mod on_eos; +mod on_failure; +mod on_request; +mod on_response; +mod service; + +const DEFAULT_MESSAGE_LEVEL: Level = Level::DEBUG; +const DEFAULT_ERROR_LEVEL: Level = Level::ERROR; + +struct Latency { + unit: LatencyUnit, + duration: Duration, +} + +impl fmt::Display for Latency { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.unit { + LatencyUnit::Seconds => write!(f, "{} s", self.duration.as_secs_f64()), + LatencyUnit::Millis => write!(f, "{} ms", self.duration.as_millis()), + LatencyUnit::Micros => write!(f, "{} μs", self.duration.as_micros()), + LatencyUnit::Nanos => write!(f, "{} ns", self.duration.as_nanos()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::classify::ServerErrorsFailureClass; + use crate::test_helpers::Body; + use bytes::Bytes; + use http::{HeaderMap, Request, Response}; + use once_cell::sync::Lazy; + use std::{ + sync::atomic::{AtomicU32, Ordering}, + time::Duration, + }; + use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; + use tracing::Span; + + #[tokio::test] + async fn unary_request() { + static ON_REQUEST_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_RESPONSE_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_BODY_CHUNK_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + + let trace_layer = TraceLayer::new_for_http() + .make_span_with(|_req: &Request<Body>| { + tracing::info_span!("test-span", foo = tracing::field::Empty) + }) + .on_request(|_req: &Request<Body>, span: &Span| { + span.record("foo", 42); + ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_response(|_res: &Response<Body>, _latency: Duration, _span: &Span| { + ON_RESPONSE_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| { + ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_eos( + |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { + ON_EOS.fetch_add(1, Ordering::SeqCst); + }, + ) + .on_failure( + |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { + ON_FAILURE.fetch_add(1, Ordering::SeqCst); + }, + ); + + let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo); + + let res = svc + .ready() + .await + .unwrap() + .call(Request::new(Body::from("foobar"))) + .await + .unwrap(); + + assert_eq!(1, ON_REQUEST_COUNT.load(Ordering::SeqCst), "request"); + assert_eq!(1, ON_RESPONSE_COUNT.load(Ordering::SeqCst), "request"); + assert_eq!(0, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); + assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); + assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); + + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); + assert_eq!(1, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); + assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); + assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); + } + + #[tokio::test] + async fn streaming_response() { + static ON_REQUEST_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_RESPONSE_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_BODY_CHUNK_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0)); + + let trace_layer = TraceLayer::new_for_http() + .on_request(|_req: &Request<Body>, _span: &Span| { + ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_response(|_res: &Response<Body>, _latency: Duration, _span: &Span| { + ON_RESPONSE_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| { + ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_eos( + |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { + ON_EOS.fetch_add(1, Ordering::SeqCst); + }, + ) + .on_failure( + |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { + ON_FAILURE.fetch_add(1, Ordering::SeqCst); + }, + ); + + let mut svc = ServiceBuilder::new() + .layer(trace_layer) + .service_fn(streaming_body); + + let res = svc + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + assert_eq!(1, ON_REQUEST_COUNT.load(Ordering::SeqCst), "request"); + assert_eq!(1, ON_RESPONSE_COUNT.load(Ordering::SeqCst), "request"); + assert_eq!(0, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); + assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); + assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); + + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); + assert_eq!(3, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); + assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); + assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } + + async fn streaming_body(_req: Request<Body>) -> Result<Response<Body>, BoxError> { + use futures_util::stream::iter; + + let stream = iter(vec![ + Ok::<_, BoxError>(Bytes::from("one")), + Ok::<_, BoxError>(Bytes::from("two")), + Ok::<_, BoxError>(Bytes::from("three")), + ]); + + let body = Body::from_stream(stream); + + Ok(Response::new(body)) + } +} diff --git a/vendor/tower-http/src/trace/on_body_chunk.rs b/vendor/tower-http/src/trace/on_body_chunk.rs new file mode 100644 index 00000000..543f2a63 --- /dev/null +++ b/vendor/tower-http/src/trace/on_body_chunk.rs @@ -0,0 +1,64 @@ +use std::time::Duration; +use tracing::Span; + +/// Trait used to tell [`Trace`] what to do when a body chunk has been sent. +/// +/// See the [module docs](../trace/index.html#on_body_chunk) for details on exactly when the +/// `on_body_chunk` callback is called. +/// +/// [`Trace`]: super::Trace +pub trait OnBodyChunk<B> { + /// Do the thing. + /// + /// `latency` is the duration since the response was sent or since the last body chunk as sent. + /// + /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure + /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] + /// that weren't known when the span was created. + /// + /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html + /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record + /// + /// If you're using [hyper] as your server `B` will most likely be [`Bytes`]. + /// + /// [hyper]: https://hyper.rs + /// [`Bytes`]: https://docs.rs/bytes/latest/bytes/struct.Bytes.html + /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with + fn on_body_chunk(&mut self, chunk: &B, latency: Duration, span: &Span); +} + +impl<B, F> OnBodyChunk<B> for F +where + F: FnMut(&B, Duration, &Span), +{ + fn on_body_chunk(&mut self, chunk: &B, latency: Duration, span: &Span) { + self(chunk, latency, span) + } +} + +impl<B> OnBodyChunk<B> for () { + #[inline] + fn on_body_chunk(&mut self, _: &B, _: Duration, _: &Span) {} +} + +/// The default [`OnBodyChunk`] implementation used by [`Trace`]. +/// +/// Simply does nothing. +/// +/// [`Trace`]: super::Trace +#[derive(Debug, Default, Clone)] +pub struct DefaultOnBodyChunk { + _priv: (), +} + +impl DefaultOnBodyChunk { + /// Create a new `DefaultOnBodyChunk`. + pub fn new() -> Self { + Self { _priv: () } + } +} + +impl<B> OnBodyChunk<B> for DefaultOnBodyChunk { + #[inline] + fn on_body_chunk(&mut self, _: &B, _: Duration, _: &Span) {} +} diff --git a/vendor/tower-http/src/trace/on_eos.rs b/vendor/tower-http/src/trace/on_eos.rs new file mode 100644 index 00000000..ab90fc9c --- /dev/null +++ b/vendor/tower-http/src/trace/on_eos.rs @@ -0,0 +1,107 @@ +use super::{Latency, DEFAULT_MESSAGE_LEVEL}; +use crate::{classify::grpc_errors_as_failures::ParsedGrpcStatus, LatencyUnit}; +use http::header::HeaderMap; +use std::time::Duration; +use tracing::{Level, Span}; + +/// Trait used to tell [`Trace`] what to do when a stream closes. +/// +/// See the [module docs](../trace/index.html#on_eos) for details on exactly when the `on_eos` +/// callback is called. +/// +/// [`Trace`]: super::Trace +pub trait OnEos { + /// Do the thing. + /// + /// `stream_duration` is the duration since the response was sent. + /// + /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure + /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] + /// that weren't known when the span was created. + /// + /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html + /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record + /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with + fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span); +} + +impl OnEos for () { + #[inline] + fn on_eos(self, _: Option<&HeaderMap>, _: Duration, _: &Span) {} +} + +impl<F> OnEos for F +where + F: FnOnce(Option<&HeaderMap>, Duration, &Span), +{ + fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span) { + self(trailers, stream_duration, span) + } +} + +/// The default [`OnEos`] implementation used by [`Trace`]. +/// +/// [`Trace`]: super::Trace +#[derive(Clone, Debug)] +pub struct DefaultOnEos { + level: Level, + latency_unit: LatencyUnit, +} + +impl Default for DefaultOnEos { + fn default() -> Self { + Self { + level: DEFAULT_MESSAGE_LEVEL, + latency_unit: LatencyUnit::Millis, + } + } +} + +impl DefaultOnEos { + /// Create a new [`DefaultOnEos`]. + pub fn new() -> Self { + Self::default() + } + + /// Set the [`Level`] used for [tracing events]. + /// + /// Defaults to [`Level::DEBUG`]. + /// + /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events + /// [`Level::DEBUG`]: https://docs.rs/tracing/latest/tracing/struct.Level.html#associatedconstant.DEBUG + pub fn level(mut self, level: Level) -> Self { + self.level = level; + self + } + + /// Set the [`LatencyUnit`] latencies will be reported in. + /// + /// Defaults to [`LatencyUnit::Millis`]. + pub fn latency_unit(mut self, latency_unit: LatencyUnit) -> Self { + self.latency_unit = latency_unit; + self + } +} + +impl OnEos for DefaultOnEos { + fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span) { + let stream_duration = Latency { + unit: self.latency_unit, + duration: stream_duration, + }; + let status = trailers.and_then(|trailers| { + match crate::classify::grpc_errors_as_failures::classify_grpc_metadata( + trailers, + crate::classify::GrpcCode::Ok.into_bitmask(), + ) { + ParsedGrpcStatus::Success + | ParsedGrpcStatus::HeaderNotString + | ParsedGrpcStatus::HeaderNotInt => Some(0), + ParsedGrpcStatus::NonSuccess(status) => Some(status.get()), + ParsedGrpcStatus::GrpcStatusHeaderMissing => None, + } + }); + + event_dynamic_lvl!(self.level, %stream_duration, status, "end of stream"); + } +} diff --git a/vendor/tower-http/src/trace/on_failure.rs b/vendor/tower-http/src/trace/on_failure.rs new file mode 100644 index 00000000..7dfa186d --- /dev/null +++ b/vendor/tower-http/src/trace/on_failure.rs @@ -0,0 +1,100 @@ +use super::{Latency, DEFAULT_ERROR_LEVEL}; +use crate::LatencyUnit; +use std::{fmt, time::Duration}; +use tracing::{Level, Span}; + +/// Trait used to tell [`Trace`] what to do when a request fails. +/// +/// See the [module docs](../trace/index.html#on_failure) for details on exactly when the +/// `on_failure` callback is called. +/// +/// [`Trace`]: super::Trace +pub trait OnFailure<FailureClass> { + /// Do the thing. + /// + /// `latency` is the duration since the request was received. + /// + /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure + /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] + /// that weren't known when the span was created. + /// + /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html + /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record + /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with + fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, span: &Span); +} + +impl<FailureClass> OnFailure<FailureClass> for () { + #[inline] + fn on_failure(&mut self, _: FailureClass, _: Duration, _: &Span) {} +} + +impl<F, FailureClass> OnFailure<FailureClass> for F +where + F: FnMut(FailureClass, Duration, &Span), +{ + fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, span: &Span) { + self(failure_classification, latency, span) + } +} + +/// The default [`OnFailure`] implementation used by [`Trace`]. +/// +/// [`Trace`]: super::Trace +#[derive(Clone, Debug)] +pub struct DefaultOnFailure { + level: Level, + latency_unit: LatencyUnit, +} + +impl Default for DefaultOnFailure { + fn default() -> Self { + Self { + level: DEFAULT_ERROR_LEVEL, + latency_unit: LatencyUnit::Millis, + } + } +} + +impl DefaultOnFailure { + /// Create a new `DefaultOnFailure`. + pub fn new() -> Self { + Self::default() + } + + /// Set the [`Level`] used for [tracing events]. + /// + /// Defaults to [`Level::ERROR`]. + /// + /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events + pub fn level(mut self, level: Level) -> Self { + self.level = level; + self + } + + /// Set the [`LatencyUnit`] latencies will be reported in. + /// + /// Defaults to [`LatencyUnit::Millis`]. + pub fn latency_unit(mut self, latency_unit: LatencyUnit) -> Self { + self.latency_unit = latency_unit; + self + } +} + +impl<FailureClass> OnFailure<FailureClass> for DefaultOnFailure +where + FailureClass: fmt::Display, +{ + fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, _: &Span) { + let latency = Latency { + unit: self.latency_unit, + duration: latency, + }; + event_dynamic_lvl!( + self.level, + classification = %failure_classification, + %latency, + "response failed" + ); + } +} diff --git a/vendor/tower-http/src/trace/on_request.rs b/vendor/tower-http/src/trace/on_request.rs new file mode 100644 index 00000000..07de1893 --- /dev/null +++ b/vendor/tower-http/src/trace/on_request.rs @@ -0,0 +1,82 @@ +use super::DEFAULT_MESSAGE_LEVEL; +use http::Request; +use tracing::Level; +use tracing::Span; + +/// Trait used to tell [`Trace`] what to do when a request is received. +/// +/// See the [module docs](../trace/index.html#on_request) for details on exactly when the +/// `on_request` callback is called. +/// +/// [`Trace`]: super::Trace +pub trait OnRequest<B> { + /// Do the thing. + /// + /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure + /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] + /// that weren't known when the span was created. + /// + /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html + /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record + /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with + fn on_request(&mut self, request: &Request<B>, span: &Span); +} + +impl<B> OnRequest<B> for () { + #[inline] + fn on_request(&mut self, _: &Request<B>, _: &Span) {} +} + +impl<B, F> OnRequest<B> for F +where + F: FnMut(&Request<B>, &Span), +{ + fn on_request(&mut self, request: &Request<B>, span: &Span) { + self(request, span) + } +} + +/// The default [`OnRequest`] implementation used by [`Trace`]. +/// +/// [`Trace`]: super::Trace +#[derive(Clone, Debug)] +pub struct DefaultOnRequest { + level: Level, +} + +impl Default for DefaultOnRequest { + fn default() -> Self { + Self { + level: DEFAULT_MESSAGE_LEVEL, + } + } +} + +impl DefaultOnRequest { + /// Create a new `DefaultOnRequest`. + pub fn new() -> Self { + Self::default() + } + + /// Set the [`Level`] used for [tracing events]. + /// + /// Please note that while this will set the level for the tracing events + /// themselves, it might cause them to lack expected information, like + /// request method or path. You can address this using + /// [`DefaultMakeSpan::level`]. + /// + /// Defaults to [`Level::DEBUG`]. + /// + /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events + /// [`DefaultMakeSpan::level`]: crate::trace::DefaultMakeSpan::level + pub fn level(mut self, level: Level) -> Self { + self.level = level; + self + } +} + +impl<B> OnRequest<B> for DefaultOnRequest { + fn on_request(&mut self, _: &Request<B>, _: &Span) { + event_dynamic_lvl!(self.level, "started processing request"); + } +} diff --git a/vendor/tower-http/src/trace/on_response.rs b/vendor/tower-http/src/trace/on_response.rs new file mode 100644 index 00000000..c6ece840 --- /dev/null +++ b/vendor/tower-http/src/trace/on_response.rs @@ -0,0 +1,161 @@ +use super::{Latency, DEFAULT_MESSAGE_LEVEL}; +use crate::LatencyUnit; +use http::Response; +use std::time::Duration; +use tracing::Level; +use tracing::Span; + +/// Trait used to tell [`Trace`] what to do when a response has been produced. +/// +/// See the [module docs](../trace/index.html#on_response) for details on exactly when the +/// `on_response` callback is called. +/// +/// [`Trace`]: super::Trace +pub trait OnResponse<B> { + /// Do the thing. + /// + /// `latency` is the duration since the request was received. + /// + /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure + /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] + /// that weren't known when the span was created. + /// + /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html + /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record + /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with + fn on_response(self, response: &Response<B>, latency: Duration, span: &Span); +} + +impl<B> OnResponse<B> for () { + #[inline] + fn on_response(self, _: &Response<B>, _: Duration, _: &Span) {} +} + +impl<B, F> OnResponse<B> for F +where + F: FnOnce(&Response<B>, Duration, &Span), +{ + fn on_response(self, response: &Response<B>, latency: Duration, span: &Span) { + self(response, latency, span) + } +} + +/// The default [`OnResponse`] implementation used by [`Trace`]. +/// +/// [`Trace`]: super::Trace +#[derive(Clone, Debug)] +pub struct DefaultOnResponse { + level: Level, + latency_unit: LatencyUnit, + include_headers: bool, +} + +impl Default for DefaultOnResponse { + fn default() -> Self { + Self { + level: DEFAULT_MESSAGE_LEVEL, + latency_unit: LatencyUnit::Millis, + include_headers: false, + } + } +} + +impl DefaultOnResponse { + /// Create a new `DefaultOnResponse`. + pub fn new() -> Self { + Self::default() + } + + /// Set the [`Level`] used for [tracing events]. + /// + /// Please note that while this will set the level for the tracing events + /// themselves, it might cause them to lack expected information, like + /// request method or path. You can address this using + /// [`DefaultMakeSpan::level`]. + /// + /// Defaults to [`Level::DEBUG`]. + /// + /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events + /// [`DefaultMakeSpan::level`]: crate::trace::DefaultMakeSpan::level + pub fn level(mut self, level: Level) -> Self { + self.level = level; + self + } + + /// Set the [`LatencyUnit`] latencies will be reported in. + /// + /// Defaults to [`LatencyUnit::Millis`]. + pub fn latency_unit(mut self, latency_unit: LatencyUnit) -> Self { + self.latency_unit = latency_unit; + self + } + + /// Include response headers on the [`Event`]. + /// + /// By default headers are not included. + /// + /// [`Event`]: tracing::Event + pub fn include_headers(mut self, include_headers: bool) -> Self { + self.include_headers = include_headers; + self + } +} + +impl<B> OnResponse<B> for DefaultOnResponse { + fn on_response(self, response: &Response<B>, latency: Duration, _: &Span) { + let latency = Latency { + unit: self.latency_unit, + duration: latency, + }; + let response_headers = self + .include_headers + .then(|| tracing::field::debug(response.headers())); + + event_dynamic_lvl!( + self.level, + %latency, + status = status(response), + response_headers, + "finished processing request" + ); + } +} + +fn status<B>(res: &Response<B>) -> Option<i32> { + use crate::classify::grpc_errors_as_failures::ParsedGrpcStatus; + + // gRPC-over-HTTP2 uses the "application/grpc[+format]" content type, and gRPC-Web uses + // "application/grpc-web[+format]" or "application/grpc-web-text[+format]", where "format" is + // the message format, e.g. +proto, +json. + // + // So, valid grpc content types include (but are not limited to): + // - application/grpc + // - application/grpc+proto + // - application/grpc-web+proto + // - application/grpc-web-text+proto + // + // For simplicity, we simply check that the content type starts with "application/grpc". + let is_grpc = res + .headers() + .get(http::header::CONTENT_TYPE) + .map_or(false, |value| { + value.as_bytes().starts_with("application/grpc".as_bytes()) + }); + + if is_grpc { + match crate::classify::grpc_errors_as_failures::classify_grpc_metadata( + res.headers(), + crate::classify::GrpcCode::Ok.into_bitmask(), + ) { + ParsedGrpcStatus::Success + | ParsedGrpcStatus::HeaderNotString + | ParsedGrpcStatus::HeaderNotInt => Some(0), + ParsedGrpcStatus::NonSuccess(status) => Some(status.get()), + // if `grpc-status` is missing then its a streaming response and there is no status + // _yet_, so its neither success nor error + ParsedGrpcStatus::GrpcStatusHeaderMissing => None, + } + } else { + Some(res.status().as_u16().into()) + } +} diff --git a/vendor/tower-http/src/trace/service.rs b/vendor/tower-http/src/trace/service.rs new file mode 100644 index 00000000..1ab4c1f0 --- /dev/null +++ b/vendor/tower-http/src/trace/service.rs @@ -0,0 +1,325 @@ +use super::{ + DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, + DefaultOnResponse, GrpcMakeClassifier, HttpMakeClassifier, MakeSpan, OnBodyChunk, OnEos, + OnFailure, OnRequest, OnResponse, ResponseBody, ResponseFuture, TraceLayer, +}; +use crate::classify::{ + GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier, +}; +use http::{Request, Response}; +use http_body::Body; +use std::{ + fmt, + task::{Context, Poll}, + time::Instant, +}; +use tower_service::Service; + +/// Middleware that adds high level [tracing] to a [`Service`]. +/// +/// See the [module docs](crate::trace) for an example. +/// +/// [tracing]: https://crates.io/crates/tracing +/// [`Service`]: tower_service::Service +#[derive(Debug, Clone, Copy)] +pub struct Trace< + S, + M, + MakeSpan = DefaultMakeSpan, + OnRequest = DefaultOnRequest, + OnResponse = DefaultOnResponse, + OnBodyChunk = DefaultOnBodyChunk, + OnEos = DefaultOnEos, + OnFailure = DefaultOnFailure, +> { + pub(crate) inner: S, + pub(crate) make_classifier: M, + pub(crate) make_span: MakeSpan, + pub(crate) on_request: OnRequest, + pub(crate) on_response: OnResponse, + pub(crate) on_body_chunk: OnBodyChunk, + pub(crate) on_eos: OnEos, + pub(crate) on_failure: OnFailure, +} + +impl<S, M> Trace<S, M> { + /// Create a new [`Trace`] using the given [`MakeClassifier`]. + pub fn new(inner: S, make_classifier: M) -> Self + where + M: MakeClassifier, + { + Self { + inner, + make_classifier, + make_span: DefaultMakeSpan::new(), + on_request: DefaultOnRequest::default(), + on_response: DefaultOnResponse::default(), + on_body_chunk: DefaultOnBodyChunk::default(), + on_eos: DefaultOnEos::default(), + on_failure: DefaultOnFailure::default(), + } + } + + /// Returns a new [`Layer`] that wraps services with a [`TraceLayer`] middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(make_classifier: M) -> TraceLayer<M> + where + M: MakeClassifier, + { + TraceLayer::new(make_classifier) + } +} + +impl<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> + Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> +{ + define_inner_service_accessors!(); + + /// Customize what to do when a request is received. + /// + /// `NewOnRequest` is expected to implement [`OnRequest`]. + /// + /// [`OnRequest`]: super::OnRequest + pub fn on_request<NewOnRequest>( + self, + new_on_request: NewOnRequest, + ) -> Trace<S, M, MakeSpan, NewOnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> { + Trace { + on_request: new_on_request, + inner: self.inner, + on_failure: self.on_failure, + on_eos: self.on_eos, + on_body_chunk: self.on_body_chunk, + make_span: self.make_span, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a response has been produced. + /// + /// `NewOnResponse` is expected to implement [`OnResponse`]. + /// + /// [`OnResponse`]: super::OnResponse + pub fn on_response<NewOnResponse>( + self, + new_on_response: NewOnResponse, + ) -> Trace<S, M, MakeSpan, OnRequest, NewOnResponse, OnBodyChunk, OnEos, OnFailure> { + Trace { + on_response: new_on_response, + inner: self.inner, + on_request: self.on_request, + on_failure: self.on_failure, + on_body_chunk: self.on_body_chunk, + on_eos: self.on_eos, + make_span: self.make_span, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a body chunk has been sent. + /// + /// `NewOnBodyChunk` is expected to implement [`OnBodyChunk`]. + /// + /// [`OnBodyChunk`]: super::OnBodyChunk + pub fn on_body_chunk<NewOnBodyChunk>( + self, + new_on_body_chunk: NewOnBodyChunk, + ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, NewOnBodyChunk, OnEos, OnFailure> { + Trace { + on_body_chunk: new_on_body_chunk, + on_eos: self.on_eos, + make_span: self.make_span, + inner: self.inner, + on_failure: self.on_failure, + on_request: self.on_request, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a streaming response has closed. + /// + /// `NewOnEos` is expected to implement [`OnEos`]. + /// + /// [`OnEos`]: super::OnEos + pub fn on_eos<NewOnEos>( + self, + new_on_eos: NewOnEos, + ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, NewOnEos, OnFailure> { + Trace { + on_eos: new_on_eos, + make_span: self.make_span, + inner: self.inner, + on_failure: self.on_failure, + on_request: self.on_request, + on_body_chunk: self.on_body_chunk, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize what to do when a response has been classified as a failure. + /// + /// `NewOnFailure` is expected to implement [`OnFailure`]. + /// + /// [`OnFailure`]: super::OnFailure + pub fn on_failure<NewOnFailure>( + self, + new_on_failure: NewOnFailure, + ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, NewOnFailure> { + Trace { + on_failure: new_on_failure, + inner: self.inner, + make_span: self.make_span, + on_body_chunk: self.on_body_chunk, + on_request: self.on_request, + on_eos: self.on_eos, + on_response: self.on_response, + make_classifier: self.make_classifier, + } + } + + /// Customize how to make [`Span`]s that all request handling will be wrapped in. + /// + /// `NewMakeSpan` is expected to implement [`MakeSpan`]. + /// + /// [`MakeSpan`]: super::MakeSpan + /// [`Span`]: tracing::Span + pub fn make_span_with<NewMakeSpan>( + self, + new_make_span: NewMakeSpan, + ) -> Trace<S, M, NewMakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> { + Trace { + make_span: new_make_span, + inner: self.inner, + on_failure: self.on_failure, + on_request: self.on_request, + on_body_chunk: self.on_body_chunk, + on_response: self.on_response, + on_eos: self.on_eos, + make_classifier: self.make_classifier, + } + } +} + +impl<S> + Trace< + S, + HttpMakeClassifier, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > +{ + /// Create a new [`Trace`] using [`ServerErrorsAsFailures`] which supports classifying + /// regular HTTP responses based on the status code. + pub fn new_for_http(inner: S) -> Self { + Self { + inner, + make_classifier: SharedClassifier::new(ServerErrorsAsFailures::default()), + make_span: DefaultMakeSpan::new(), + on_request: DefaultOnRequest::default(), + on_response: DefaultOnResponse::default(), + on_body_chunk: DefaultOnBodyChunk::default(), + on_eos: DefaultOnEos::default(), + on_failure: DefaultOnFailure::default(), + } + } +} + +impl<S> + Trace< + S, + GrpcMakeClassifier, + DefaultMakeSpan, + DefaultOnRequest, + DefaultOnResponse, + DefaultOnBodyChunk, + DefaultOnEos, + DefaultOnFailure, + > +{ + /// Create a new [`Trace`] using [`GrpcErrorsAsFailures`] which supports classifying + /// gRPC responses and streams based on the `grpc-status` header. + pub fn new_for_grpc(inner: S) -> Self { + Self { + inner, + make_classifier: SharedClassifier::new(GrpcErrorsAsFailures::default()), + make_span: DefaultMakeSpan::new(), + on_request: DefaultOnRequest::default(), + on_response: DefaultOnResponse::default(), + on_body_chunk: DefaultOnBodyChunk::default(), + on_eos: DefaultOnEos::default(), + on_failure: DefaultOnFailure::default(), + } + } +} + +impl< + S, + ReqBody, + ResBody, + M, + OnRequestT, + OnResponseT, + OnFailureT, + OnBodyChunkT, + OnEosT, + MakeSpanT, + > Service<Request<ReqBody>> + for Trace<S, M, MakeSpanT, OnRequestT, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT> +where + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + ReqBody: Body, + ResBody: Body, + ResBody::Error: fmt::Display + 'static, + S::Error: fmt::Display + 'static, + M: MakeClassifier, + M::Classifier: Clone, + MakeSpanT: MakeSpan<ReqBody>, + OnRequestT: OnRequest<ReqBody>, + OnResponseT: OnResponse<ResBody> + Clone, + OnBodyChunkT: OnBodyChunk<ResBody::Data> + Clone, + OnEosT: OnEos + Clone, + OnFailureT: OnFailure<M::FailureClass> + Clone, +{ + type Response = + Response<ResponseBody<ResBody, M::ClassifyEos, OnBodyChunkT, OnEosT, OnFailureT>>; + type Error = S::Error; + type Future = + ResponseFuture<S::Future, M::Classifier, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request<ReqBody>) -> Self::Future { + let start = Instant::now(); + + let span = self.make_span.make_span(&req); + + let classifier = self.make_classifier.make_classifier(&req); + + let future = { + let _guard = span.enter(); + self.on_request.on_request(&req, &span); + self.inner.call(req) + }; + + ResponseFuture { + inner: future, + span, + classifier: Some(classifier), + on_response: Some(self.on_response.clone()), + on_body_chunk: Some(self.on_body_chunk.clone()), + on_eos: Some(self.on_eos.clone()), + on_failure: Some(self.on_failure.clone()), + start, + } + } +} diff --git a/vendor/tower-http/src/validate_request.rs b/vendor/tower-http/src/validate_request.rs new file mode 100644 index 00000000..efb301e4 --- /dev/null +++ b/vendor/tower-http/src/validate_request.rs @@ -0,0 +1,587 @@ +//! Middleware that validates requests. +//! +//! # Example +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use http::{Request, Response, StatusCode, header::ACCEPT}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let mut service = ServiceBuilder::new() +//! // Require the `Accept` header to be `application/json`, `*/*` or `application/*` +//! .layer(ValidateRequestHeaderLayer::accept("application/json")) +//! .service_fn(handle); +//! +//! // Requests with the correct value are allowed through +//! let request = Request::builder() +//! .header(ACCEPT, "application/json") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid value get a `406 Not Acceptable` response +//! let request = Request::builder() +//! .header(ACCEPT, "text/strings") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status()); +//! # Ok(()) +//! # } +//! ``` +//! +//! Custom validation can be made by implementing [`ValidateRequest`]: +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; +//! use http::{Request, Response, StatusCode, header::ACCEPT}; +//! use http_body_util::Full; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use bytes::Bytes; +//! +//! #[derive(Clone, Copy)] +//! pub struct MyHeader { /* ... */ } +//! +//! impl<B> ValidateRequest<B> for MyHeader { +//! type ResponseBody = Full<Bytes>; +//! +//! fn validate( +//! &mut self, +//! request: &mut Request<B>, +//! ) -> Result<(), Response<Self::ResponseBody>> { +//! // validate the request... +//! # unimplemented!() +//! } +//! } +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let service = ServiceBuilder::new() +//! // Validate requests using `MyHeader` +//! .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! Or using a closure: +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; +//! use http::{Request, Response, StatusCode, header::ACCEPT}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! +//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> { +//! # todo!(); +//! // ... +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let service = ServiceBuilder::new() +//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request<Full<Bytes>>| { +//! // Validate the request +//! # Ok::<_, Response<Full<Bytes>>>(()) +//! })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use http::{header, Request, Response, StatusCode}; +use mime::{Mime, MimeIter}; +use pin_project_lite::pin_project; +use std::{ + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`ValidateRequestHeader`] which validates all requests. +/// +/// See the [module docs](crate::validate_request) for an example. +#[derive(Debug, Clone)] +pub struct ValidateRequestHeaderLayer<T> { + validate: T, +} + +impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> { + /// Validate requests have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// # Panics + /// + /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` + /// See `AcceptHeader::new` for when this method panics. + /// + /// # Example + /// + /// ``` + /// use http_body_util::Full; + /// use bytes::Bytes; + /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer}; + /// + /// let layer = ValidateRequestHeaderLayer::<AcceptHeader<Full<Bytes>>>::accept("application/json"); + /// ``` + /// + /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept + pub fn accept(value: &str) -> Self + where + ResBody: Default, + { + Self::custom(AcceptHeader::new(value)) + } +} + +impl<T> ValidateRequestHeaderLayer<T> { + /// Validate requests using a custom method. + pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> { + Self { validate } + } +} + +impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T> +where + T: Clone, +{ + type Service = ValidateRequestHeader<S, T>; + + fn layer(&self, inner: S) -> Self::Service { + ValidateRequestHeader::new(inner, self.validate.clone()) + } +} + +/// Middleware that validates requests. +/// +/// See the [module docs](crate::validate_request) for an example. +#[derive(Clone, Debug)] +pub struct ValidateRequestHeader<S, T> { + inner: S, + validate: T, +} + +impl<S, T> ValidateRequestHeader<S, T> { + fn new(inner: S, validate: T) -> Self { + Self::custom(inner, validate) + } + + define_inner_service_accessors!(); +} + +impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> { + /// Validate requests have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// # Panics + /// + /// See `AcceptHeader::new` for when this method panics. + pub fn accept(inner: S, value: &str) -> Self + where + ResBody: Default, + { + Self::custom(inner, AcceptHeader::new(value)) + } +} + +impl<S, T> ValidateRequestHeader<S, T> { + /// Validate requests using a custom method. + pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> { + Self { inner, validate } + } +} + +impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V> +where + V: ValidateRequest<ReqBody, ResponseBody = ResBody>, + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = Response<ResBody>; + type Error = S::Error; + type Future = ResponseFuture<S::Future, ResBody>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + match self.validate.validate(&mut req) { + Ok(_) => ResponseFuture::future(self.inner.call(req)), + Err(res) => ResponseFuture::invalid_header_value(res), + } + } +} + +pin_project! { + /// Response future for [`ValidateRequestHeader`]. + pub struct ResponseFuture<F, B> { + #[pin] + kind: Kind<F, B>, + } +} + +impl<F, B> ResponseFuture<F, B> { + fn future(future: F) -> Self { + Self { + kind: Kind::Future { future }, + } + } + + fn invalid_header_value(res: Response<B>) -> Self { + Self { + kind: Kind::Error { + response: Some(res), + }, + } + } +} + +pin_project! { + #[project = KindProj] + enum Kind<F, B> { + Future { + #[pin] + future: F, + }, + Error { + response: Option<Response<B>>, + }, + } +} + +impl<F, B, E> Future for ResponseFuture<F, B> +where + F: Future<Output = Result<Response<B>, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.project().kind.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Error { response } => { + let response = response.take().expect("future polled after completion"); + Poll::Ready(Ok(response)) + } + } + } +} + +/// Trait for validating requests. +pub trait ValidateRequest<B> { + /// The body type used for responses to unvalidated requests. + type ResponseBody; + + /// Validate the request. + /// + /// If `Ok(())` is returned then the request is allowed through, otherwise not. + fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>; +} + +impl<B, F, ResBody> ValidateRequest<B> for F +where + F: FnMut(&mut Request<B>) -> Result<(), Response<ResBody>>, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> { + self(request) + } +} + +/// Type that performs validation of the Accept header. +pub struct AcceptHeader<ResBody> { + header_value: Arc<Mime>, + _ty: PhantomData<fn() -> ResBody>, +} + +impl<ResBody> AcceptHeader<ResBody> { + /// Create a new `AcceptHeader`. + /// + /// # Panics + /// + /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` + fn new(header_value: &str) -> Self + where + ResBody: Default, + { + Self { + header_value: Arc::new( + header_value + .parse::<Mime>() + .expect("value is not a valid header value"), + ), + _ty: PhantomData, + } + } +} + +impl<ResBody> Clone for AcceptHeader<ResBody> { + fn clone(&self) -> Self { + Self { + header_value: self.header_value.clone(), + _ty: PhantomData, + } + } +} + +impl<ResBody> fmt::Debug for AcceptHeader<ResBody> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AcceptHeader") + .field("header_value", &self.header_value) + .finish() + } +} + +impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody> +where + ResBody: Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> { + if !req.headers().contains_key(header::ACCEPT) { + return Ok(()); + } + if req + .headers() + .get_all(header::ACCEPT) + .into_iter() + .filter_map(|header| header.to_str().ok()) + .any(|h| { + MimeIter::new(h) + .map(|mim| { + if let Ok(mim) = mim { + let typ = self.header_value.type_(); + let subtype = self.header_value.subtype(); + match (mim.type_(), mim.subtype()) { + (t, s) if t == typ && s == subtype => true, + (t, mime::STAR) if t == typ => true, + (mime::STAR, mime::STAR) => true, + _ => false, + } + } else { + false + } + }) + .reduce(|acc, mim| acc || mim) + .unwrap_or(false) + }) + { + return Ok(()); + } + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::NOT_ACCEPTABLE; + Err(res) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use crate::test_helpers::Body; + use http::header; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn valid_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_accept_header_accept_all_json() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/*") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_accept_header_accept_all() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "*/*") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn invalid_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "invalid") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + #[tokio::test] + async fn not_accepted_accept_header_subtype() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/strings") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + + #[tokio::test] + async fn not_accepted_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "text/strings") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + + #[tokio::test] + async fn accepted_multiple_header_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "text/strings") + .header(header::ACCEPT, "invalid, application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn accepted_inner_header_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "text/strings, invalid, application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn accepted_header_with_quotes_valid() { + let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*"; + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/xml")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, value) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn accepted_header_with_quotes_invalid() { + let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\""; + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("text/html")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, value) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + + async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> { + Ok(Response::new(req.into_body())) + } +} |
