diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-02 18:36:06 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-02 18:36:06 -0600 |
| commit | 8cdfa445d6629ffef4cb84967ff7017654045bc2 (patch) | |
| tree | 22f0b0907c024c78d26a731e2e1f5219407d8102 /vendor/hyper-util/src | |
| parent | 4351c74c7c5f97156bc94d3a8549b9940ac80e3f (diff) | |
chore: add vendor directory
Diffstat (limited to 'vendor/hyper-util/src')
43 files changed, 10903 insertions, 0 deletions
diff --git a/vendor/hyper-util/src/client/client.rs b/vendor/hyper-util/src/client/client.rs new file mode 100644 index 00000000..a9fb244a --- /dev/null +++ b/vendor/hyper-util/src/client/client.rs @@ -0,0 +1,132 @@ +use hyper::{Request, Response}; +use tower::{Service, MakeService}; + +use super::connect::Connect; +use super::pool; + +pub struct Client<M> { + // Hi there. So, let's take a 0.14.x hyper::Client, and build up its layers + // here. We don't need to fully expose the layers to start with, but that + // is the end goal. + // + // Client = MakeSvcAsService< + // SetHost< + // Http1RequestTarget< + // DelayedRelease< + // ConnectingPool<C, P> + // > + // > + // > + // > + make_svc: M, +} + +// We might change this... :shrug: +type PoolKey = hyper::Uri; + +struct ConnectingPool<C, P> { + connector: C, + pool: P, +} + +struct PoolableSvc<S>(S); + +/// A marker to identify what version a pooled connection is. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[allow(dead_code)] +pub enum Ver { + Auto, + Http2, +} + +// ===== impl Client ===== + +impl<M, /*ReqBody, ResBody,*/ E> Client<M> +where + M: MakeService< + hyper::Uri, + Request<()>, + Response = Response<()>, + Error = E, + MakeError = E, + >, + //M: Service<hyper::Uri, Error = E>, + //M::Response: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + pub async fn request(&mut self, req: Request<()>) -> Result<Response<()>, E> { + let mut svc = self.make_svc.make_service(req.uri().clone()).await?; + svc.call(req).await + } +} + +impl<M, /*ReqBody, ResBody,*/ E> Client<M> +where + M: MakeService< + hyper::Uri, + Request<()>, + Response = Response<()>, + Error = E, + MakeError = E, + >, + //M: Service<hyper::Uri, Error = E>, + //M::Response: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + +} + +// ===== impl ConnectingPool ===== + +impl<C, P> ConnectingPool<C, P> +where + C: Connect, + C::_Svc: Unpin + Send + 'static, +{ + async fn connection_for(&self, target: PoolKey) -> Result<pool::Pooled<PoolableSvc<C::_Svc>, PoolKey>, ()> { + todo!() + } +} + +impl<S> pool::Poolable for PoolableSvc<S> +where + S: Unpin + Send + 'static, +{ + fn is_open(&self) -> bool { + /* + match self.tx { + PoolTx::Http1(ref tx) => tx.is_ready(), + #[cfg(feature = "http2")] + PoolTx::Http2(ref tx) => tx.is_ready(), + } + */ + true + } + + fn reserve(self) -> pool::Reservation<Self> { + /* + match self.tx { + PoolTx::Http1(tx) => Reservation::Unique(PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http1(tx), + }), + #[cfg(feature = "http2")] + PoolTx::Http2(tx) => { + let b = PoolClient { + conn_info: self.conn_info.clone(), + tx: PoolTx::Http2(tx.clone()), + }; + let a = PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http2(tx), + }; + Reservation::Shared(a, b) + } + } + */ + pool::Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + //self.is_http2() + } +} diff --git a/vendor/hyper-util/src/client/legacy/client.rs b/vendor/hyper-util/src/client/legacy/client.rs new file mode 100644 index 00000000..9899d346 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/client.rs @@ -0,0 +1,1690 @@ +//! The legacy HTTP Client from 0.14.x +//! +//! This `Client` will eventually be deconstructed into more composable parts. +//! For now, to enable people to use hyper 1.0 quicker, this `Client` exists +//! in much the same way it did in hyper 0.14. + +use std::error::Error as StdError; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll}; +use std::time::Duration; + +use futures_util::future::{self, Either, FutureExt, TryFutureExt}; +use http::uri::Scheme; +use hyper::client::conn::TrySendError as ConnTrySendError; +use hyper::header::{HeaderValue, HOST}; +use hyper::rt::Timer; +use hyper::{body::Body, Method, Request, Response, Uri, Version}; +use tracing::{debug, trace, warn}; + +use super::connect::capture::CaptureConnectionExtension; +#[cfg(feature = "tokio")] +use super::connect::HttpConnector; +use super::connect::{Alpn, Connect, Connected, Connection}; +use super::pool::{self, Ver}; + +use crate::common::future::poll_fn; +use crate::common::{lazy as hyper_lazy, timer, Exec, Lazy, SyncWrapper}; + +type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>; + +/// A Client to make outgoing HTTP requests. +/// +/// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The +/// underlying connection pool will be reused. +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +pub struct Client<C, B> { + config: Config, + connector: C, + exec: Exec, + #[cfg(feature = "http1")] + h1_builder: hyper::client::conn::http1::Builder, + #[cfg(feature = "http2")] + h2_builder: hyper::client::conn::http2::Builder<Exec>, + pool: pool::Pool<PoolClient<B>, PoolKey>, +} + +#[derive(Clone, Copy, Debug)] +struct Config { + retry_canceled_requests: bool, + set_host: bool, + ver: Ver, +} + +/// Client errors +pub struct Error { + kind: ErrorKind, + source: Option<Box<dyn StdError + Send + Sync>>, + #[cfg(any(feature = "http1", feature = "http2"))] + connect_info: Option<Connected>, +} + +#[derive(Debug)] +enum ErrorKind { + Canceled, + ChannelClosed, + Connect, + UserUnsupportedRequestMethod, + UserUnsupportedVersion, + UserAbsoluteUriRequired, + SendRequest, +} + +macro_rules! e { + ($kind:ident) => { + Error { + kind: ErrorKind::$kind, + source: None, + connect_info: None, + } + }; + ($kind:ident, $src:expr) => { + Error { + kind: ErrorKind::$kind, + source: Some($src.into()), + connect_info: None, + } + }; +} + +// We might change this... :shrug: +type PoolKey = (http::uri::Scheme, http::uri::Authority); + +enum TrySendError<B> { + Retryable { + error: Error, + req: Request<B>, + connection_reused: bool, + }, + Nope(Error), +} + +/// A `Future` that will resolve to an HTTP Response. +/// +/// This is returned by `Client::request` (and `Client::get`). +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + inner: SyncWrapper< + Pin<Box<dyn Future<Output = Result<Response<hyper::body::Incoming>, Error>> + Send>>, + >, +} + +// ===== impl Client ===== + +impl Client<(), ()> { + /// Create a builder to configure a new `Client`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "tokio")] + /// # fn run () { + /// use std::time::Duration; + /// use hyper_util::client::legacy::Client; + /// use hyper_util::rt::{TokioExecutor, TokioTimer}; + /// + /// let client = Client::builder(TokioExecutor::new()) + /// .pool_timer(TokioTimer::new()) + /// .pool_idle_timeout(Duration::from_secs(30)) + /// .http2_only(true) + /// .build_http(); + /// # let infer: Client<_, http_body_util::Full<bytes::Bytes>> = client; + /// # drop(infer); + /// # } + /// # fn main() {} + /// ``` + pub fn builder<E>(executor: E) -> Builder + where + E: hyper::rt::Executor<BoxSendFuture> + Send + Sync + Clone + 'static, + { + Builder::new(executor) + } +} + +impl<C, B> Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: Body + Send + 'static + Unpin, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + /// Send a `GET` request to the supplied `Uri`. + /// + /// # Note + /// + /// This requires that the `Body` type have a `Default` implementation. + /// It *should* return an "empty" version of itself, such that + /// `Body::is_end_stream` is `true`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "tokio")] + /// # fn run () { + /// use hyper::Uri; + /// use hyper_util::client::legacy::Client; + /// use hyper_util::rt::TokioExecutor; + /// use bytes::Bytes; + /// use http_body_util::Full; + /// + /// let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build_http(); + /// + /// let future = client.get(Uri::from_static("http://httpbin.org/ip")); + /// # } + /// # fn main() {} + /// ``` + pub fn get(&self, uri: Uri) -> ResponseFuture + where + B: Default, + { + let body = B::default(); + if !body.is_end_stream() { + warn!("default Body used for get() does not return true for is_end_stream"); + } + + let mut req = Request::new(body); + *req.uri_mut() = uri; + self.request(req) + } + + /// Send a constructed `Request` using this `Client`. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "tokio")] + /// # fn run () { + /// use hyper::{Method, Request}; + /// use hyper_util::client::legacy::Client; + /// use http_body_util::Full; + /// use hyper_util::rt::TokioExecutor; + /// use bytes::Bytes; + /// + /// let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build_http(); + /// + /// let req: Request<Full<Bytes>> = Request::builder() + /// .method(Method::POST) + /// .uri("http://httpbin.org/post") + /// .body(Full::from("Hallo!")) + /// .expect("request builder"); + /// + /// let future = client.request(req); + /// # } + /// # fn main() {} + /// ``` + pub fn request(&self, mut req: Request<B>) -> ResponseFuture { + let is_http_connect = req.method() == Method::CONNECT; + match req.version() { + Version::HTTP_11 => (), + Version::HTTP_10 => { + if is_http_connect { + warn!("CONNECT is not allowed for HTTP/1.0"); + return ResponseFuture::new(future::err(e!(UserUnsupportedRequestMethod))); + } + } + Version::HTTP_2 => (), + // completely unsupported HTTP version (like HTTP/0.9)! + other => return ResponseFuture::error_version(other), + }; + + let pool_key = match extract_domain(req.uri_mut(), is_http_connect) { + Ok(s) => s, + Err(err) => { + return ResponseFuture::new(future::err(err)); + } + }; + + ResponseFuture::new(self.clone().send_request(req, pool_key)) + } + + async fn send_request( + self, + mut req: Request<B>, + pool_key: PoolKey, + ) -> Result<Response<hyper::body::Incoming>, Error> { + let uri = req.uri().clone(); + + loop { + req = match self.try_send_request(req, pool_key.clone()).await { + Ok(resp) => return Ok(resp), + Err(TrySendError::Nope(err)) => return Err(err), + Err(TrySendError::Retryable { + mut req, + error, + connection_reused, + }) => { + if !self.config.retry_canceled_requests || !connection_reused { + // if client disabled, don't retry + // a fresh connection means we definitely can't retry + return Err(error); + } + + trace!( + "unstarted request canceled, trying again (reason={:?})", + error + ); + *req.uri_mut() = uri.clone(); + req + } + } + } + } + + async fn try_send_request( + &self, + mut req: Request<B>, + pool_key: PoolKey, + ) -> Result<Response<hyper::body::Incoming>, TrySendError<B>> { + let mut pooled = self + .connection_for(pool_key) + .await + // `connection_for` already retries checkout errors, so if + // it returns an error, there's not much else to retry + .map_err(TrySendError::Nope)?; + + if let Some(conn) = req.extensions_mut().get_mut::<CaptureConnectionExtension>() { + conn.set(&pooled.conn_info); + } + + if pooled.is_http1() { + if req.version() == Version::HTTP_2 { + warn!("Connection is HTTP/1, but request requires HTTP/2"); + return Err(TrySendError::Nope( + e!(UserUnsupportedVersion).with_connect_info(pooled.conn_info.clone()), + )); + } + + if self.config.set_host { + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = get_non_default_port(&uri) { + let s = format!("{hostname}:{port}"); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); + } + + // CONNECT always sends authority-form, so check it first... + if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + } else if pooled.conn_info.is_proxied { + absolute_form(req.uri_mut()); + } else { + origin_form(req.uri_mut()); + } + } else if req.method() == Method::CONNECT && !pooled.is_http2() { + authority_form(req.uri_mut()); + } + + let mut res = match pooled.try_send_request(req).await { + Ok(res) => res, + Err(mut err) => { + return if let Some(req) = err.take_message() { + Err(TrySendError::Retryable { + connection_reused: pooled.is_reused(), + error: e!(Canceled, err.into_error()) + .with_connect_info(pooled.conn_info.clone()), + req, + }) + } else { + Err(TrySendError::Nope( + e!(SendRequest, err.into_error()) + .with_connect_info(pooled.conn_info.clone()), + )) + } + } + }; + + // If the Connector included 'extra' info, add to Response... + if let Some(extra) = &pooled.conn_info.extra { + extra.set(res.extensions_mut()); + } + + // If pooled is HTTP/2, we can toss this reference immediately. + // + // when pooled is dropped, it will try to insert back into the + // pool. To delay that, spawn a future that completes once the + // sender is ready again. + // + // This *should* only be once the related `Connection` has polled + // for a new request to start. + // + // It won't be ready if there is a body to stream. + if pooled.is_http2() || !pooled.is_pool_enabled() || pooled.is_ready() { + drop(pooled); + } else if !res.body().is_end_stream() { + //let (delayed_tx, delayed_rx) = oneshot::channel::<()>(); + //res.body_mut().delayed_eof(delayed_rx); + let on_idle = poll_fn(move |cx| pooled.poll_ready(cx)).map(move |_| { + // At this point, `pooled` is dropped, and had a chance + // to insert into the pool (if conn was idle) + //drop(delayed_tx); + }); + + self.exec.execute(on_idle); + } else { + // There's no body to delay, but the connection isn't + // ready yet. Only re-insert when it's ready + let on_idle = poll_fn(move |cx| pooled.poll_ready(cx)).map(|_| ()); + + self.exec.execute(on_idle); + } + + Ok(res) + } + + async fn connection_for( + &self, + pool_key: PoolKey, + ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, Error> { + loop { + match self.one_connection_for(pool_key.clone()).await { + Ok(pooled) => return Ok(pooled), + Err(ClientConnectError::Normal(err)) => return Err(err), + Err(ClientConnectError::CheckoutIsClosed(reason)) => { + if !self.config.retry_canceled_requests { + return Err(e!(Connect, reason)); + } + + trace!( + "unstarted request canceled, trying again (reason={:?})", + reason, + ); + continue; + } + }; + } + } + + async fn one_connection_for( + &self, + pool_key: PoolKey, + ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, ClientConnectError> { + // Return a single connection if pooling is not enabled + if !self.pool.is_enabled() { + return self + .connect_to(pool_key) + .await + .map_err(ClientConnectError::Normal); + } + + // This actually races 2 different futures to try to get a ready + // connection the fastest, and to reduce connection churn. + // + // - If the pool has an idle connection waiting, that's used + // immediately. + // - Otherwise, the Connector is asked to start connecting to + // the destination Uri. + // - Meanwhile, the pool Checkout is watching to see if any other + // request finishes and tries to insert an idle connection. + // - If a new connection is started, but the Checkout wins after + // (an idle connection became available first), the started + // connection future is spawned into the runtime to complete, + // and then be inserted into the pool as an idle connection. + let checkout = self.pool.checkout(pool_key.clone()); + let connect = self.connect_to(pool_key); + let is_ver_h2 = self.config.ver == Ver::Http2; + + // The order of the `select` is depended on below... + + match future::select(checkout, connect).await { + // Checkout won, connect future may have been started or not. + // + // If it has, let it finish and insert back into the pool, + // so as to not waste the socket... + Either::Left((Ok(checked_out), connecting)) => { + // This depends on the `select` above having the correct + // order, such that if the checkout future were ready + // immediately, the connect future will never have been + // started. + // + // If it *wasn't* ready yet, then the connect future will + // have been started... + if connecting.started() { + let bg = connecting + .map_err(|err| { + trace!("background connect error: {}", err); + }) + .map(|_pooled| { + // dropping here should just place it in + // the Pool for us... + }); + // An execute error here isn't important, we're just trying + // to prevent a waste of a socket... + self.exec.execute(bg); + } + Ok(checked_out) + } + // Connect won, checkout can just be dropped. + Either::Right((Ok(connected), _checkout)) => Ok(connected), + // Either checkout or connect could get canceled: + // + // 1. Connect is canceled if this is HTTP/2 and there is + // an outstanding HTTP/2 connecting task. + // 2. Checkout is canceled if the pool cannot deliver an + // idle connection reliably. + // + // In both cases, we should just wait for the other future. + Either::Left((Err(err), connecting)) => { + if err.is_canceled() { + connecting.await.map_err(ClientConnectError::Normal) + } else { + Err(ClientConnectError::Normal(e!(Connect, err))) + } + } + Either::Right((Err(err), checkout)) => { + if err.is_canceled() { + checkout.await.map_err(move |err| { + if is_ver_h2 && err.is_canceled() { + ClientConnectError::CheckoutIsClosed(err) + } else { + ClientConnectError::Normal(e!(Connect, err)) + } + }) + } else { + Err(ClientConnectError::Normal(err)) + } + } + } + } + + #[cfg(any(feature = "http1", feature = "http2"))] + fn connect_to( + &self, + pool_key: PoolKey, + ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PoolKey>, Error>> + Send + Unpin + { + let executor = self.exec.clone(); + let pool = self.pool.clone(); + #[cfg(feature = "http1")] + let h1_builder = self.h1_builder.clone(); + #[cfg(feature = "http2")] + let h2_builder = self.h2_builder.clone(); + let ver = self.config.ver; + let is_ver_h2 = ver == Ver::Http2; + let connector = self.connector.clone(); + let dst = domain_as_uri(pool_key.clone()); + hyper_lazy(move || { + // Try to take a "connecting lock". + // + // If the pool_key is for HTTP/2, and there is already a + // connection being established, then this can't take a + // second lock. The "connect_to" future is Canceled. + let connecting = match pool.connecting(&pool_key, ver) { + Some(lock) => lock, + None => { + let canceled = e!(Canceled); + // TODO + //crate::Error::new_canceled().with("HTTP/2 connection in progress"); + return Either::Right(future::err(canceled)); + } + }; + Either::Left( + connector + .connect(super::connect::sealed::Internal, dst) + .map_err(|src| e!(Connect, src)) + .and_then(move |io| { + let connected = io.connected(); + // If ALPN is h2 and we aren't http2_only already, + // then we need to convert our pool checkout into + // a single HTTP2 one. + let connecting = if connected.alpn == Alpn::H2 && !is_ver_h2 { + match connecting.alpn_h2(&pool) { + Some(lock) => { + trace!("ALPN negotiated h2, updating pool"); + lock + } + None => { + // Another connection has already upgraded, + // the pool checkout should finish up for us. + let canceled = e!(Canceled, "ALPN upgraded to HTTP/2"); + return Either::Right(future::err(canceled)); + } + } + } else { + connecting + }; + + #[cfg_attr(not(feature = "http2"), allow(unused))] + let is_h2 = is_ver_h2 || connected.alpn == Alpn::H2; + + Either::Left(Box::pin(async move { + let tx = if is_h2 { + #[cfg(feature = "http2")] { + let (mut tx, conn) = + h2_builder.handshake(io).await.map_err(Error::tx)?; + + trace!( + "http2 handshake complete, spawning background dispatcher task" + ); + executor.execute( + conn.map_err(|e| debug!("client connection error: {}", e)) + .map(|_| ()), + ); + + // Wait for 'conn' to ready up before we + // declare this tx as usable + tx.ready().await.map_err(Error::tx)?; + PoolTx::Http2(tx) + } + #[cfg(not(feature = "http2"))] + panic!("http2 feature is not enabled"); + } else { + #[cfg(feature = "http1")] { + // Perform the HTTP/1.1 handshake on the provided I/O stream. + // Uses the h1_builder to establish a connection, returning a sender (tx) for requests + // and a connection task (conn) that manages the connection lifecycle. + let (mut tx, conn) = + h1_builder.handshake(io).await.map_err(crate::client::legacy::client::Error::tx)?; + + // Log that the HTTP/1.1 handshake has completed successfully. + // This indicates the connection is established and ready for request processing. + trace!( + "http1 handshake complete, spawning background dispatcher task" + ); + // Create a oneshot channel to communicate errors from the connection task. + // err_tx sends errors from the connection task, and err_rx receives them + // to correlate connection failures with request readiness errors. + let (err_tx, err_rx) = tokio::sync::oneshot::channel(); + // Spawn the connection task in the background using the executor. + // The task manages the HTTP/1.1 connection, including upgrades (e.g., WebSocket). + // Errors are sent via err_tx to ensure they can be checked if the sender (tx) fails. + executor.execute( + conn.with_upgrades() + .map_err(|e| { + // Log the connection error at debug level for diagnostic purposes. + debug!("client connection error: {:?}", e); + // Log that the error is being sent to the error channel. + trace!("sending connection error to error channel"); + // Send the error via the oneshot channel, ignoring send failures + // (e.g., if the receiver is dropped, which is handled later). + let _ =err_tx.send(e); + }) + .map(|_| ()), + ); + // Log that the client is waiting for the connection to be ready. + // Readiness indicates the sender (tx) can accept a request without blocking. + trace!("waiting for connection to be ready"); + // Check if the sender is ready to accept a request. + // This ensures the connection is fully established before proceeding. + // aka: + // Wait for 'conn' to ready up before we + // declare this tx as usable + match tx.ready().await { + // If ready, the connection is usable for sending requests. + Ok(_) => { + // Log that the connection is ready for use. + trace!("connection is ready"); + // Drop the error receiver, as it’s no longer needed since the sender is ready. + // This prevents waiting for errors that won’t occur in a successful case. + drop(err_rx); + // Wrap the sender in PoolTx::Http1 for use in the connection pool. + PoolTx::Http1(tx) + } + // If the sender fails with a closed channel error, check for a specific connection error. + // This distinguishes between a vague ChannelClosed error and an actual connection failure. + Err(e) if e.is_closed() => { + // Log that the channel is closed, indicating a potential connection issue. + trace!("connection channel closed, checking for connection error"); + // Check the oneshot channel for a specific error from the connection task. + match err_rx.await { + // If an error was received, it’s a specific connection failure. + Ok(err) => { + // Log the specific connection error for diagnostics. + trace!("received connection error: {:?}", err); + // Return the error wrapped in Error::tx to propagate it. + return Err(crate::client::legacy::client::Error::tx(err)); + } + // If the error channel is closed, no specific error was sent. + // Fall back to the vague ChannelClosed error. + Err(_) => { + // Log that the error channel is closed, indicating no specific error. + trace!("error channel closed, returning the vague ChannelClosed error"); + // Return the original error wrapped in Error::tx. + return Err(crate::client::legacy::client::Error::tx(e)); + } + } + } + // For other errors (e.g., timeout, I/O issues), propagate them directly. + // These are not ChannelClosed errors and don’t require error channel checks. + Err(e) => { + // Log the specific readiness failure for diagnostics. + trace!("connection readiness failed: {:?}", e); + // Return the error wrapped in Error::tx to propagate it. + return Err(crate::client::legacy::client::Error::tx(e)); + } + } + } + #[cfg(not(feature = "http1"))] { + panic!("http1 feature is not enabled"); + } + }; + + Ok(pool.pooled( + connecting, + PoolClient { + conn_info: connected, + tx, + }, + )) + })) + }), + ) + }) + } +} + +impl<C, B> tower_service::Service<Request<B>> for Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: Body + Send + 'static + Unpin, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Response = Response<hyper::body::Incoming>; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.request(req) + } +} + +impl<C, B> tower_service::Service<Request<B>> for &'_ Client<C, B> +where + C: Connect + Clone + Send + Sync + 'static, + B: Body + Send + 'static + Unpin, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Response = Response<hyper::body::Incoming>; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request<B>) -> Self::Future { + self.request(req) + } +} + +impl<C: Clone, B> Clone for Client<C, B> { + fn clone(&self) -> Client<C, B> { + Client { + config: self.config, + exec: self.exec.clone(), + #[cfg(feature = "http1")] + h1_builder: self.h1_builder.clone(), + #[cfg(feature = "http2")] + h2_builder: self.h2_builder.clone(), + connector: self.connector.clone(), + pool: self.pool.clone(), + } + } +} + +impl<C, B> fmt::Debug for Client<C, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} + +// ===== impl ResponseFuture ===== + +impl ResponseFuture { + fn new<F>(value: F) -> Self + where + F: Future<Output = Result<Response<hyper::body::Incoming>, Error>> + Send + 'static, + { + Self { + inner: SyncWrapper::new(Box::pin(value)), + } + } + + fn error_version(ver: Version) -> Self { + warn!("Request has unsupported version \"{:?}\"", ver); + ResponseFuture::new(Box::pin(future::err(e!(UserUnsupportedVersion)))) + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Future<Response>") + } +} + +impl Future for ResponseFuture { + type Output = Result<Response<hyper::body::Incoming>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.inner.get_mut().as_mut().poll(cx) + } +} + +// ===== impl PoolClient ===== + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +struct PoolClient<B> { + conn_info: Connected, + tx: PoolTx<B>, +} + +enum PoolTx<B> { + #[cfg(feature = "http1")] + Http1(hyper::client::conn::http1::SendRequest<B>), + #[cfg(feature = "http2")] + Http2(hyper::client::conn::http2::SendRequest<B>), +} + +impl<B> PoolClient<B> { + fn poll_ready( + &mut self, + #[allow(unused_variables)] cx: &mut task::Context<'_>, + ) -> Poll<Result<(), Error>> { + match self.tx { + #[cfg(feature = "http1")] + PoolTx::Http1(ref mut tx) => tx.poll_ready(cx).map_err(Error::closed), + #[cfg(feature = "http2")] + PoolTx::Http2(_) => Poll::Ready(Ok(())), + } + } + + fn is_http1(&self) -> bool { + !self.is_http2() + } + + fn is_http2(&self) -> bool { + match self.tx { + #[cfg(feature = "http1")] + PoolTx::Http1(_) => false, + #[cfg(feature = "http2")] + PoolTx::Http2(_) => true, + } + } + + fn is_poisoned(&self) -> bool { + self.conn_info.poisoned.poisoned() + } + + fn is_ready(&self) -> bool { + match self.tx { + #[cfg(feature = "http1")] + PoolTx::Http1(ref tx) => tx.is_ready(), + #[cfg(feature = "http2")] + PoolTx::Http2(ref tx) => tx.is_ready(), + } + } +} + +impl<B: Body + 'static> PoolClient<B> { + fn try_send_request( + &mut self, + req: Request<B>, + ) -> impl Future<Output = Result<Response<hyper::body::Incoming>, ConnTrySendError<Request<B>>>> + where + B: Send, + { + #[cfg(all(feature = "http1", feature = "http2"))] + return match self.tx { + #[cfg(feature = "http1")] + PoolTx::Http1(ref mut tx) => Either::Left(tx.try_send_request(req)), + #[cfg(feature = "http2")] + PoolTx::Http2(ref mut tx) => Either::Right(tx.try_send_request(req)), + }; + + #[cfg(feature = "http1")] + #[cfg(not(feature = "http2"))] + return match self.tx { + #[cfg(feature = "http1")] + PoolTx::Http1(ref mut tx) => tx.try_send_request(req), + }; + + #[cfg(not(feature = "http1"))] + #[cfg(feature = "http2")] + return match self.tx { + #[cfg(feature = "http2")] + PoolTx::Http2(ref mut tx) => tx.try_send_request(req), + }; + } +} + +impl<B> pool::Poolable for PoolClient<B> +where + B: Send + 'static, +{ + fn is_open(&self) -> bool { + !self.is_poisoned() && self.is_ready() + } + + fn reserve(self) -> pool::Reservation<Self> { + match self.tx { + #[cfg(feature = "http1")] + PoolTx::Http1(tx) => pool::Reservation::Unique(PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http1(tx), + }), + #[cfg(feature = "http2")] + PoolTx::Http2(tx) => { + let b = PoolClient { + conn_info: self.conn_info.clone(), + tx: PoolTx::Http2(tx.clone()), + }; + let a = PoolClient { + conn_info: self.conn_info, + tx: PoolTx::Http2(tx), + }; + pool::Reservation::Shared(a, b) + } + } + } + + fn can_share(&self) -> bool { + self.is_http2() + } +} + +enum ClientConnectError { + Normal(Error), + CheckoutIsClosed(pool::Error), +} + +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} + +fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if uri.scheme() == Some(&Scheme::HTTPS) { + origin_form(uri); + } +} + +fn authority_form(uri: &mut Uri) { + if let Some(path) = uri.path_and_query() { + // `https://hyper.rs` would parse with `/` path, don't + // annoy people about that... + if path != "/" { + warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); + } + } + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; +} + +fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error> { + let uri_clone = uri.clone(); + match (uri_clone.scheme(), uri_clone.authority()) { + (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), + (None, Some(auth)) if is_http_connect => { + let scheme = match auth.port_u16() { + Some(443) => { + set_scheme(uri, Scheme::HTTPS); + Scheme::HTTPS + } + _ => { + set_scheme(uri, Scheme::HTTP); + Scheme::HTTP + } + }; + Ok((scheme, auth.clone())) + } + _ => { + debug!("Client requires absolute-form URIs, received: {:?}", uri); + Err(e!(UserAbsoluteUriRequired)) + } + } +} + +fn domain_as_uri((scheme, auth): PoolKey) -> Uri { + http::uri::Builder::new() + .scheme(scheme) + .authority(auth) + .path_and_query("/") + .build() + .expect("domain is valid Uri") +} + +fn set_scheme(uri: &mut Uri, scheme: Scheme) { + debug_assert!( + uri.scheme().is_none(), + "set_scheme expects no existing scheme" + ); + let old = std::mem::take(uri); + let mut parts: ::http::uri::Parts = old.into(); + parts.scheme = Some(scheme); + parts.path_and_query = Some("/".parse().expect("slash is a valid path")); + *uri = Uri::from_parts(parts).expect("scheme is valid"); +} + +fn get_non_default_port(uri: &Uri) -> Option<http::uri::Port<&str>> { + match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { + (Some(443), true) => None, + (Some(80), false) => None, + _ => uri.port(), + } +} + +fn is_schema_secure(uri: &Uri) -> bool { + uri.scheme_str() + .map(|scheme_str| matches!(scheme_str, "wss" | "https")) + .unwrap_or_default() +} + +/// A builder to configure a new [`Client`](Client). +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "tokio")] +/// # fn run () { +/// use std::time::Duration; +/// use hyper_util::client::legacy::Client; +/// use hyper_util::rt::TokioExecutor; +/// +/// let client = Client::builder(TokioExecutor::new()) +/// .pool_idle_timeout(Duration::from_secs(30)) +/// .http2_only(true) +/// .build_http(); +/// # let infer: Client<_, http_body_util::Full<bytes::Bytes>> = client; +/// # drop(infer); +/// # } +/// # fn main() {} +/// ``` +#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] +#[derive(Clone)] +pub struct Builder { + client_config: Config, + exec: Exec, + #[cfg(feature = "http1")] + h1_builder: hyper::client::conn::http1::Builder, + #[cfg(feature = "http2")] + h2_builder: hyper::client::conn::http2::Builder<Exec>, + pool_config: pool::Config, + pool_timer: Option<timer::Timer>, +} + +impl Builder { + /// Construct a new Builder. + pub fn new<E>(executor: E) -> Self + where + E: hyper::rt::Executor<BoxSendFuture> + Send + Sync + Clone + 'static, + { + let exec = Exec::new(executor); + Self { + client_config: Config { + retry_canceled_requests: true, + set_host: true, + ver: Ver::Auto, + }, + exec: exec.clone(), + #[cfg(feature = "http1")] + h1_builder: hyper::client::conn::http1::Builder::new(), + #[cfg(feature = "http2")] + h2_builder: hyper::client::conn::http2::Builder::new(exec), + pool_config: pool::Config { + idle_timeout: Some(Duration::from_secs(90)), + max_idle_per_host: usize::MAX, + }, + pool_timer: None, + } + } + /// Set an optional timeout for idle sockets being kept-alive. + /// A `Timer` is required for this to take effect. See `Builder::pool_timer` + /// + /// Pass `None` to disable timeout. + /// + /// Default is 90 seconds. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "tokio")] + /// # fn run () { + /// use std::time::Duration; + /// use hyper_util::client::legacy::Client; + /// use hyper_util::rt::{TokioExecutor, TokioTimer}; + /// + /// let client = Client::builder(TokioExecutor::new()) + /// .pool_idle_timeout(Duration::from_secs(30)) + /// .pool_timer(TokioTimer::new()) + /// .build_http(); + /// + /// # let infer: Client<_, http_body_util::Full<bytes::Bytes>> = client; + /// # } + /// # fn main() {} + /// ``` + pub fn pool_idle_timeout<D>(&mut self, val: D) -> &mut Self + where + D: Into<Option<Duration>>, + { + self.pool_config.idle_timeout = val.into(); + self + } + + #[doc(hidden)] + #[deprecated(note = "renamed to `pool_max_idle_per_host`")] + pub fn max_idle_per_host(&mut self, max_idle: usize) -> &mut Self { + self.pool_config.max_idle_per_host = max_idle; + self + } + + /// Sets the maximum idle connection per host allowed in the pool. + /// + /// Default is `usize::MAX` (no limit). + pub fn pool_max_idle_per_host(&mut self, max_idle: usize) -> &mut Self { + self.pool_config.max_idle_per_host = max_idle; + self + } + + // HTTP/1 options + + /// Sets the exact size of the read buffer to *always* use. + /// + /// Note that setting this option unsets the `http1_max_buf_size` option. + /// + /// Default is an adaptive read buffer. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_read_buf_exact_size(&mut self, sz: usize) -> &mut Self { + self.h1_builder.read_buf_exact_size(Some(sz)); + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// Note that setting this option unsets the `http1_read_exact_buf_size` option. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_max_buf_size(&mut self, max: usize) -> &mut Self { + self.h1_builder.max_buf_size(max); + self + } + + /// Set whether HTTP/1 connections will accept spaces between header names + /// and the colon that follow them in responses. + /// + /// Newline codepoints (`\r` and `\n`) will be transformed to spaces when + /// parsing. + /// + /// You probably don't need this, here is what [RFC 7230 Section 3.2.4.] has + /// to say about it: + /// + /// > No whitespace is allowed between the header field-name and colon. In + /// > the past, differences in the handling of such whitespace have led to + /// > security vulnerabilities in request routing and response handling. A + /// > server MUST reject any received request message that contains + /// > whitespace between a header field-name and colon with a response code + /// > of 400 (Bad Request). A proxy MUST remove any such whitespace from a + /// > response message before forwarding the message downstream. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + /// + /// [RFC 7230 Section 3.2.4.]: https://tools.ietf.org/html/rfc7230#section-3.2.4 + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_allow_spaces_after_header_name_in_responses(&mut self, val: bool) -> &mut Self { + self.h1_builder + .allow_spaces_after_header_name_in_responses(val); + self + } + + /// Set whether HTTP/1 connections will accept obsolete line folding for + /// header values. + /// + /// You probably don't need this, here is what [RFC 7230 Section 3.2.4.] has + /// to say about it: + /// + /// > A server that receives an obs-fold in a request message that is not + /// > within a message/http container MUST either reject the message by + /// > sending a 400 (Bad Request), preferably with a representation + /// > explaining that obsolete line folding is unacceptable, or replace + /// > each received obs-fold with one or more SP octets prior to + /// > interpreting the field value or forwarding the message downstream. + /// + /// > A proxy or gateway that receives an obs-fold in a response message + /// > that is not within a message/http container MUST either discard the + /// > message and replace it with a 502 (Bad Gateway) response, preferably + /// > with a representation explaining that unacceptable line folding was + /// > received, or replace each received obs-fold with one or more SP + /// > octets prior to interpreting the field value or forwarding the + /// > message downstream. + /// + /// > A user agent that receives an obs-fold in a response message that is + /// > not within a message/http container MUST replace each received + /// > obs-fold with one or more SP octets prior to interpreting the field + /// > value. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + /// + /// [RFC 7230 Section 3.2.4.]: https://tools.ietf.org/html/rfc7230#section-3.2.4 + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_allow_obsolete_multiline_headers_in_responses(&mut self, val: bool) -> &mut Self { + self.h1_builder + .allow_obsolete_multiline_headers_in_responses(val); + self + } + + /// Sets whether invalid header lines should be silently ignored in HTTP/1 responses. + /// + /// This mimics the behaviour of major browsers. You probably don't want this. + /// You should only want this if you are implementing a proxy whose main + /// purpose is to sit in front of browsers whose users access arbitrary content + /// which may be malformed, and they expect everything that works without + /// the proxy to keep working with the proxy. + /// + /// This option will prevent Hyper's client from returning an error encountered + /// when parsing a header, except if the error was caused by the character NUL + /// (ASCII code 0), as Chrome specifically always reject those. + /// + /// The ignorable errors are: + /// * empty header names; + /// * characters that are not allowed in header names, except for `\0` and `\r`; + /// * when `allow_spaces_after_header_name_in_responses` is not enabled, + /// spaces and tabs between the header name and the colon; + /// * missing colon between header name and colon; + /// * characters that are not allowed in header values except for `\0` and `\r`. + /// + /// If an ignorable error is encountered, the parser tries to find the next + /// line in the input to resume parsing the rest of the headers. An error + /// will be emitted nonetheless if it finds `\0` or a lone `\r` while + /// looking for the next line. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_ignore_invalid_headers_in_responses(&mut self, val: bool) -> &mut Builder { + self.h1_builder.ignore_invalid_headers_in_responses(val); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_writev(&mut self, enabled: bool) -> &mut Builder { + self.h1_builder.writev(enabled); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_title_case_headers(&mut self, val: bool) -> &mut Self { + self.h1_builder.title_case_headers(val); + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Response`. It will also look for and use + /// such an extension in any provided `Request`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_preserve_header_case(&mut self, val: bool) -> &mut Self { + self.h1_builder.preserve_header_case(val); + self + } + + /// Set the maximum number of headers. + /// + /// When a response is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If client receives more headers than the buffer size, the error "message header too large" + /// is returned. + /// + /// The headers is allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each response, and there will be a performance drop of about 5%. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is 100. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_max_headers(&mut self, val: usize) -> &mut Self { + self.h1_builder.max_headers(val); + self + } + + /// Set whether HTTP/0.9 responses should be tolerated. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http09_responses(&mut self, val: bool) -> &mut Self { + self.h1_builder.http09_responses(val); + self + } + + /// Set whether the connection **must** use HTTP/2. + /// + /// The destination must either allow HTTP2 Prior Knowledge, or the + /// `Connect` should be configured to do use ALPN to upgrade to `h2` + /// as part of the connection process. This will not make the `Client` + /// utilize ALPN by itself. + /// + /// Note that setting this to true prevents HTTP/1 from being allowed. + /// + /// Default is false. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_only(&mut self, val: bool) -> &mut Self { + self.client_config.ver = if val { Ver::Http2 } else { Ver::Auto }; + self + } + + /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent. + /// + /// This will default to the default value set by the [`h2` crate](https://crates.io/crates/h2). + /// As of v0.4.0, it is 20. + /// + /// See <https://github.com/hyperium/hyper/issues/2877> for more information. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_pending_accept_reset_streams( + &mut self, + max: impl Into<Option<usize>>, + ) -> &mut Self { + self.h2_builder.max_pending_accept_reset_streams(max.into()); + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.h2_builder.initial_stream_window_size(sz.into()); + self + } + + /// Sets the max connection-level flow control for HTTP2 + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_connection_window_size( + &mut self, + sz: impl Into<Option<u32>>, + ) -> &mut Self { + self.h2_builder.initial_connection_window_size(sz.into()); + self + } + + /// Sets the initial maximum of locally initiated (send) streams. + /// + /// This value will be overwritten by the value included in the initial + /// SETTINGS frame received from the peer as part of a [connection preface]. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [connection preface]: https://httpwg.org/specs/rfc9113.html#preface + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_initial_max_send_streams( + &mut self, + initial: impl Into<Option<usize>>, + ) -> &mut Self { + self.h2_builder.initial_max_send_streams(initial); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { + self.h2_builder.adaptive_window(enabled); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.h2_builder.max_frame_size(sz); + self + } + + /// Sets the max size of received header frames for HTTP2. + /// + /// Default is currently 16KB, but can change. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_header_list_size(&mut self, max: u32) -> &mut Self { + self.h2_builder.max_header_list_size(max); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `tokio` cargo feature to be enabled. + #[cfg(feature = "tokio")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into<Option<Duration>>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `tokio` cargo feature to be enabled. + #[cfg(feature = "tokio")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout(timeout); + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `tokio` cargo feature to be enabled. + #[cfg(feature = "tokio")] + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.h2_builder.keep_alive_while_idle(enabled); + self + } + + /// Sets the maximum number of HTTP2 concurrent locally reset streams. + /// + /// See the documentation of [`h2::client::Builder::max_concurrent_reset_streams`] for more + /// details. + /// + /// The default value is determined by the `h2` crate. + /// + /// [`h2::client::Builder::max_concurrent_reset_streams`]: https://docs.rs/h2/client/struct.Builder.html#method.max_concurrent_reset_streams + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_concurrent_reset_streams(&mut self, max: usize) -> &mut Self { + self.h2_builder.max_concurrent_reset_streams(max); + self + } + + /// Provide a timer to be used for h2 + /// + /// See the documentation of [`h2::client::Builder::timer`] for more + /// details. + /// + /// [`h2::client::Builder::timer`]: https://docs.rs/h2/client/struct.Builder.html#method.timer + pub fn timer<M>(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + #[cfg(feature = "http2")] + self.h2_builder.timer(timer); + self + } + + /// Provide a timer to be used for timeouts and intervals in connection pools. + pub fn pool_timer<M>(&mut self, timer: M) -> &mut Self + where + M: Timer + Clone + Send + Sync + 'static, + { + self.pool_timer = Some(timer::Timer::new(timer.clone())); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently 1MB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + #[cfg(feature = "http2")] + #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] + pub fn http2_max_send_buf_size(&mut self, max: usize) -> &mut Self { + self.h2_builder.max_send_buf_size(max); + self + } + + /// Set whether to retry requests that get disrupted before ever starting + /// to write. + /// + /// This means a request that is queued, and gets given an idle, reused + /// connection, and then encounters an error immediately as the idle + /// connection was found to be unusable. + /// + /// When this is set to `false`, the related `ResponseFuture` would instead + /// resolve to an `Error::Cancel`. + /// + /// Default is `true`. + #[inline] + pub fn retry_canceled_requests(&mut self, val: bool) -> &mut Self { + self.client_config.retry_canceled_requests = val; + self + } + + /// Set whether to automatically add the `Host` header to requests. + /// + /// If true, and a request does not include a `Host` header, one will be + /// added automatically, derived from the authority of the `Uri`. + /// + /// Default is `true`. + #[inline] + pub fn set_host(&mut self, val: bool) -> &mut Self { + self.client_config.set_host = val; + self + } + + /// Build a client with this configuration and the default `HttpConnector`. + #[cfg(feature = "tokio")] + pub fn build_http<B>(&self) -> Client<HttpConnector, B> + where + B: Body + Send, + B::Data: Send, + { + let mut connector = HttpConnector::new(); + if self.pool_config.is_enabled() { + connector.set_keepalive(self.pool_config.idle_timeout); + } + self.build(connector) + } + + /// Combine the configuration of this builder with a connector to create a `Client`. + pub fn build<C, B>(&self, connector: C) -> Client<C, B> + where + C: Connect + Clone, + B: Body + Send, + B::Data: Send, + { + let exec = self.exec.clone(); + let timer = self.pool_timer.clone(); + Client { + config: self.client_config, + exec: exec.clone(), + #[cfg(feature = "http1")] + h1_builder: self.h1_builder.clone(), + #[cfg(feature = "http2")] + h2_builder: self.h2_builder.clone(), + connector, + pool: pool::Pool::new(self.pool_config, exec, timer), + } + } +} + +impl fmt::Debug for Builder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Builder") + .field("client_config", &self.client_config) + .field("pool_config", &self.pool_config) + .finish() + } +} + +// ==== impl Error ==== + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("hyper_util::client::legacy::Error"); + f.field(&self.kind); + if let Some(ref cause) = self.source { + f.field(cause); + } + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "client error ({:?})", self.kind) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.source.as_ref().map(|e| &**e as _) + } +} + +impl Error { + /// Returns true if this was an error from `Connect`. + pub fn is_connect(&self) -> bool { + matches!(self.kind, ErrorKind::Connect) + } + + /// Returns the info of the client connection on which this error occurred. + #[cfg(any(feature = "http1", feature = "http2"))] + pub fn connect_info(&self) -> Option<&Connected> { + self.connect_info.as_ref() + } + + #[cfg(any(feature = "http1", feature = "http2"))] + fn with_connect_info(self, connect_info: Connected) -> Self { + Self { + connect_info: Some(connect_info), + ..self + } + } + fn is_canceled(&self) -> bool { + matches!(self.kind, ErrorKind::Canceled) + } + + fn tx(src: hyper::Error) -> Self { + e!(SendRequest, src) + } + + fn closed(src: hyper::Error) -> Self { + e!(ChannelClosed, src) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/capture.rs b/vendor/hyper-util/src/client/legacy/connect/capture.rs new file mode 100644 index 00000000..b31b6433 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/capture.rs @@ -0,0 +1,187 @@ +use std::{ops::Deref, sync::Arc}; + +use http::Request; +use tokio::sync::watch; + +use super::Connected; + +/// [`CaptureConnection`] allows callers to capture [`Connected`] information +/// +/// To capture a connection for a request, use [`capture_connection`]. +#[derive(Debug, Clone)] +pub struct CaptureConnection { + rx: watch::Receiver<Option<Connected>>, +} + +/// Capture the connection for a given request +/// +/// When making a request with Hyper, the underlying connection must implement the [`Connection`] trait. +/// [`capture_connection`] allows a caller to capture the returned [`Connected`] structure as soon +/// as the connection is established. +/// +/// [`Connection`]: crate::client::legacy::connect::Connection +/// +/// *Note*: If establishing a connection fails, [`CaptureConnection::connection_metadata`] will always return none. +/// +/// # Examples +/// +/// **Synchronous access**: +/// The [`CaptureConnection::connection_metadata`] method allows callers to check if a connection has been +/// established. This is ideal for situations where you are certain the connection has already +/// been established (e.g. after the response future has already completed). +/// ```rust +/// use hyper_util::client::legacy::connect::capture_connection; +/// let mut request = http::Request::builder() +/// .uri("http://foo.com") +/// .body(()) +/// .unwrap(); +/// +/// let captured_connection = capture_connection(&mut request); +/// // some time later after the request has been sent... +/// let connection_info = captured_connection.connection_metadata(); +/// println!("we are connected! {:?}", connection_info.as_ref()); +/// ``` +/// +/// **Asynchronous access**: +/// The [`CaptureConnection::wait_for_connection_metadata`] method returns a future resolves as soon as the +/// connection is available. +/// +/// ```rust +/// # #[cfg(feature = "tokio")] +/// # async fn example() { +/// use hyper_util::client::legacy::connect::capture_connection; +/// use hyper_util::client::legacy::Client; +/// use hyper_util::rt::TokioExecutor; +/// use bytes::Bytes; +/// use http_body_util::Empty; +/// let mut request = http::Request::builder() +/// .uri("http://foo.com") +/// .body(Empty::<Bytes>::new()) +/// .unwrap(); +/// +/// let mut captured = capture_connection(&mut request); +/// tokio::task::spawn(async move { +/// let connection_info = captured.wait_for_connection_metadata().await; +/// println!("we are connected! {:?}", connection_info.as_ref()); +/// }); +/// +/// let client = Client::builder(TokioExecutor::new()).build_http(); +/// client.request(request).await.expect("request failed"); +/// # } +/// ``` +pub fn capture_connection<B>(request: &mut Request<B>) -> CaptureConnection { + let (tx, rx) = CaptureConnection::new(); + request.extensions_mut().insert(tx); + rx +} + +/// TxSide for [`CaptureConnection`] +/// +/// This is inserted into `Extensions` to allow Hyper to back channel connection info +#[derive(Clone)] +pub(crate) struct CaptureConnectionExtension { + tx: Arc<watch::Sender<Option<Connected>>>, +} + +impl CaptureConnectionExtension { + pub(crate) fn set(&self, connected: &Connected) { + self.tx.send_replace(Some(connected.clone())); + } +} + +impl CaptureConnection { + /// Internal API to create the tx and rx half of [`CaptureConnection`] + pub(crate) fn new() -> (CaptureConnectionExtension, Self) { + let (tx, rx) = watch::channel(None); + ( + CaptureConnectionExtension { tx: Arc::new(tx) }, + CaptureConnection { rx }, + ) + } + + /// Retrieve the connection metadata, if available + pub fn connection_metadata(&self) -> impl Deref<Target = Option<Connected>> + '_ { + self.rx.borrow() + } + + /// Wait for the connection to be established + /// + /// If a connection was established, this will always return `Some(...)`. If the request never + /// successfully connected (e.g. DNS resolution failure), this method will never return. + pub async fn wait_for_connection_metadata( + &mut self, + ) -> impl Deref<Target = Option<Connected>> + '_ { + if self.rx.borrow().is_some() { + return self.rx.borrow(); + } + let _ = self.rx.changed().await; + self.rx.borrow() + } +} + +#[cfg(all(test, not(miri)))] +mod test { + use super::*; + + #[test] + fn test_sync_capture_connection() { + let (tx, rx) = CaptureConnection::new(); + assert!( + rx.connection_metadata().is_none(), + "connection has not been set" + ); + tx.set(&Connected::new().proxy(true)); + assert!(rx + .connection_metadata() + .as_ref() + .expect("connected should be set") + .is_proxied()); + + // ensure it can be called multiple times + assert!(rx + .connection_metadata() + .as_ref() + .expect("connected should be set") + .is_proxied()); + } + + #[tokio::test] + async fn async_capture_connection() { + let (tx, mut rx) = CaptureConnection::new(); + assert!( + rx.connection_metadata().is_none(), + "connection has not been set" + ); + let test_task = tokio::spawn(async move { + assert!(rx + .wait_for_connection_metadata() + .await + .as_ref() + .expect("connection should be set") + .is_proxied()); + // can be awaited multiple times + assert!( + rx.wait_for_connection_metadata().await.is_some(), + "should be awaitable multiple times" + ); + + assert!(rx.connection_metadata().is_some()); + }); + // can't be finished, we haven't set the connection yet + assert!(!test_task.is_finished()); + tx.set(&Connected::new().proxy(true)); + + assert!(test_task.await.is_ok()); + } + + #[tokio::test] + async fn capture_connection_sender_side_dropped() { + let (tx, mut rx) = CaptureConnection::new(); + assert!( + rx.connection_metadata().is_none(), + "connection has not been set" + ); + drop(tx); + assert!(rx.wait_for_connection_metadata().await.is_none()); + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/dns.rs b/vendor/hyper-util/src/client/legacy/connect/dns.rs new file mode 100644 index 00000000..abeb2cca --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/dns.rs @@ -0,0 +1,360 @@ +//! DNS Resolution used by the `HttpConnector`. +//! +//! This module contains: +//! +//! - A [`GaiResolver`] that is the default resolver for the `HttpConnector`. +//! - The `Name` type used as an argument to custom resolvers. +//! +//! # Resolvers are `Service`s +//! +//! A resolver is just a +//! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`. +//! +//! A simple resolver that ignores the name and always returns a specific +//! address: +//! +//! ```rust,ignore +//! use std::{convert::Infallible, iter, net::SocketAddr}; +//! +//! let resolver = tower::service_fn(|_name| async { +//! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080)))) +//! }); +//! ``` +use std::error::Error; +use std::future::Future; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{self, Poll}; +use std::{fmt, io, vec}; + +use tokio::task::JoinHandle; +use tower_service::Service; + +pub(super) use self::sealed::Resolve; + +/// A domain name to resolve into IP addresses. +#[derive(Clone, Hash, Eq, PartialEq)] +pub struct Name { + host: Box<str>, +} + +/// A resolver using blocking `getaddrinfo` calls in a threadpool. +#[derive(Clone)] +pub struct GaiResolver { + _priv: (), +} + +/// An iterator of IP addresses returned from `getaddrinfo`. +pub struct GaiAddrs { + inner: SocketAddrs, +} + +/// A future to resolve a name returned by `GaiResolver`. +pub struct GaiFuture { + inner: JoinHandle<Result<SocketAddrs, io::Error>>, +} + +impl Name { + pub(super) fn new(host: Box<str>) -> Name { + Name { host } + } + + /// View the hostname as a string slice. + pub fn as_str(&self) -> &str { + &self.host + } +} + +impl fmt::Debug for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.host, f) + } +} + +impl fmt::Display for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.host, f) + } +} + +impl FromStr for Name { + type Err = InvalidNameError; + + fn from_str(host: &str) -> Result<Self, Self::Err> { + // Possibly add validation later + Ok(Name::new(host.into())) + } +} + +/// Error indicating a given string was not a valid domain name. +#[derive(Debug)] +pub struct InvalidNameError(()); + +impl fmt::Display for InvalidNameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Not a valid domain name") + } +} + +impl Error for InvalidNameError {} + +impl GaiResolver { + /// Construct a new `GaiResolver`. + pub fn new() -> Self { + GaiResolver { _priv: () } + } +} + +impl Service<Name> for GaiResolver { + type Response = GaiAddrs; + type Error = io::Error; + type Future = GaiFuture; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + let blocking = tokio::task::spawn_blocking(move || { + (&*name.host, 0) + .to_socket_addrs() + .map(|i| SocketAddrs { iter: i }) + }); + + GaiFuture { inner: blocking } + } +} + +impl fmt::Debug for GaiResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiResolver") + } +} + +impl Future for GaiFuture { + type Output = Result<GaiAddrs, io::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.inner).poll(cx).map(|res| match res { + Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }), + Ok(Err(err)) => Err(err), + Err(join_err) => { + if join_err.is_cancelled() { + Err(io::Error::new(io::ErrorKind::Interrupted, join_err)) + } else { + panic!("gai background task failed: {join_err:?}") + } + } + }) + } +} + +impl fmt::Debug for GaiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiFuture") + } +} + +impl Drop for GaiFuture { + fn drop(&mut self) { + self.inner.abort(); + } +} + +impl Iterator for GaiAddrs { + type Item = SocketAddr; + + fn next(&mut self) -> Option<Self::Item> { + self.inner.next() + } +} + +impl fmt::Debug for GaiAddrs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("GaiAddrs") + } +} + +pub(super) struct SocketAddrs { + iter: vec::IntoIter<SocketAddr>, +} + +impl SocketAddrs { + pub(super) fn new(addrs: Vec<SocketAddr>) -> Self { + SocketAddrs { + iter: addrs.into_iter(), + } + } + + pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> { + if let Ok(addr) = host.parse::<Ipv4Addr>() { + let addr = SocketAddrV4::new(addr, port); + return Some(SocketAddrs { + iter: vec![SocketAddr::V4(addr)].into_iter(), + }); + } + if let Ok(addr) = host.parse::<Ipv6Addr>() { + let addr = SocketAddrV6::new(addr, port, 0, 0); + return Some(SocketAddrs { + iter: vec![SocketAddr::V6(addr)].into_iter(), + }); + } + None + } + + #[inline] + fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs { + SocketAddrs::new(self.iter.filter(predicate).collect()) + } + + pub(super) fn split_by_preference( + self, + local_addr_ipv4: Option<Ipv4Addr>, + local_addr_ipv6: Option<Ipv6Addr>, + ) -> (SocketAddrs, SocketAddrs) { + match (local_addr_ipv4, local_addr_ipv6) { + (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])), + (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])), + _ => { + let preferring_v6 = self + .iter + .as_slice() + .first() + .map(SocketAddr::is_ipv6) + .unwrap_or(false); + + let (preferred, fallback) = self + .iter + .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6); + + (SocketAddrs::new(preferred), SocketAddrs::new(fallback)) + } + } + } + + pub(super) fn is_empty(&self) -> bool { + self.iter.as_slice().is_empty() + } + + pub(super) fn len(&self) -> usize { + self.iter.as_slice().len() + } +} + +impl Iterator for SocketAddrs { + type Item = SocketAddr; + #[inline] + fn next(&mut self) -> Option<SocketAddr> { + self.iter.next() + } +} + +mod sealed { + use std::future::Future; + use std::task::{self, Poll}; + + use super::{Name, SocketAddr}; + use tower_service::Service; + + // "Trait alias" for `Service<Name, Response = Addrs>` + pub trait Resolve { + type Addrs: Iterator<Item = SocketAddr>; + type Error: Into<Box<dyn std::error::Error + Send + Sync>>; + type Future: Future<Output = Result<Self::Addrs, Self::Error>>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; + fn resolve(&mut self, name: Name) -> Self::Future; + } + + impl<S> Resolve for S + where + S: Service<Name>, + S::Response: Iterator<Item = SocketAddr>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + type Addrs = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + Service::poll_ready(self, cx) + } + + fn resolve(&mut self, name: Name) -> Self::Future { + Service::call(self, name) + } + } +} + +pub(super) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> +where + R: Resolve, +{ + crate::common::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; + resolver.resolve(name).await +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn test_ip_addrs_split_by_preference() { + let ip_v4 = Ipv4Addr::new(127, 0, 0, 1); + let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); + let v4_addr = (ip_v4, 80).into(); + let v6_addr = (ip_v6, 80).into(); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(ip_v4), Some(ip_v6)); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = SocketAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(Some(ip_v4), Some(ip_v6)); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); + + let (mut preferred, fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(ip_v4), None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.is_empty()); + + let (mut preferred, fallback) = SocketAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None, Some(ip_v6)); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.is_empty()); + } + + #[test] + fn test_name_from_str() { + const DOMAIN: &str = "test.example.com"; + let name = Name::from_str(DOMAIN).expect("Should be a valid domain"); + assert_eq!(name.as_str(), DOMAIN); + assert_eq!(name.to_string(), DOMAIN); + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/http.rs b/vendor/hyper-util/src/client/legacy/connect/http.rs new file mode 100644 index 00000000..f19a78eb --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/http.rs @@ -0,0 +1,1468 @@ +use std::error::Error as StdError; +use std::fmt; +use std::future::Future; +use std::io; +use std::marker::PhantomData; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{self, Poll}; +use std::time::Duration; + +use futures_core::ready; +use futures_util::future::Either; +use http::uri::{Scheme, Uri}; +use pin_project_lite::pin_project; +use socket2::TcpKeepalive; +use tokio::net::{TcpSocket, TcpStream}; +use tokio::time::Sleep; +use tracing::{debug, trace, warn}; + +use super::dns::{self, resolve, GaiResolver, Resolve}; +use super::{Connected, Connection}; +use crate::rt::TokioIo; + +/// A connector for the `http` scheme. +/// +/// Performs DNS resolution in a thread pool, and then connects over TCP. +/// +/// # Note +/// +/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes +/// transport information such as the remote socket address used. +#[derive(Clone)] +pub struct HttpConnector<R = GaiResolver> { + config: Arc<Config>, + resolver: R, +} + +/// Extra information about the transport when an HttpConnector is used. +/// +/// # Example +/// +/// ``` +/// # fn doc(res: http::Response<()>) { +/// use hyper_util::client::legacy::connect::HttpInfo; +/// +/// // res = http::Response +/// res +/// .extensions() +/// .get::<HttpInfo>() +/// .map(|info| { +/// println!("remote addr = {}", info.remote_addr()); +/// }); +/// # } +/// ``` +/// +/// # Note +/// +/// If a different connector is used besides [`HttpConnector`](HttpConnector), +/// this value will not exist in the extensions. Consult that specific +/// connector to see what "extra" information it might provide to responses. +#[derive(Clone, Debug)] +pub struct HttpInfo { + remote_addr: SocketAddr, + local_addr: SocketAddr, +} + +#[derive(Clone)] +struct Config { + connect_timeout: Option<Duration>, + enforce_http: bool, + happy_eyeballs_timeout: Option<Duration>, + tcp_keepalive_config: TcpKeepaliveConfig, + local_address_ipv4: Option<Ipv4Addr>, + local_address_ipv6: Option<Ipv6Addr>, + nodelay: bool, + reuse_address: bool, + send_buffer_size: Option<usize>, + recv_buffer_size: Option<usize>, + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + interface: Option<String>, + #[cfg(any( + target_os = "illumos", + target_os = "ios", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + interface: Option<std::ffi::CString>, + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + tcp_user_timeout: Option<Duration>, +} + +#[derive(Default, Debug, Clone, Copy)] +struct TcpKeepaliveConfig { + time: Option<Duration>, + interval: Option<Duration>, + retries: Option<u32>, +} + +impl TcpKeepaliveConfig { + /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration. + fn into_tcpkeepalive(self) -> Option<TcpKeepalive> { + let mut dirty = false; + let mut ka = TcpKeepalive::new(); + if let Some(time) = self.time { + ka = ka.with_time(time); + dirty = true + } + if let Some(interval) = self.interval { + ka = Self::ka_with_interval(ka, interval, &mut dirty) + }; + if let Some(retries) = self.retries { + ka = Self::ka_with_retries(ka, retries, &mut dirty) + }; + if dirty { + Some(ka) + } else { + None + } + } + + #[cfg( + // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#511-525 + any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ) + )] + fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_interval(interval) + } + + #[cfg(not( + // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#511-525 + any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ) + ))] + fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive interval is not supported on this platform + } + + #[cfg( + // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#557-570 + any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + ) + )] + fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_retries(retries) + } + + #[cfg(not( + // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#557-570 + any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + ) + ))] + fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive retries is not supported on this platform + } +} + +// ===== impl HttpConnector ===== + +impl HttpConnector { + /// Construct a new HttpConnector. + pub fn new() -> HttpConnector { + HttpConnector::new_with_resolver(GaiResolver::new()) + } +} + +impl<R> HttpConnector<R> { + /// Construct a new HttpConnector. + /// + /// Takes a [`Resolver`](crate::client::legacy::connect::dns#resolvers-are-services) to handle DNS lookups. + pub fn new_with_resolver(resolver: R) -> HttpConnector<R> { + HttpConnector { + config: Arc::new(Config { + connect_timeout: None, + enforce_http: true, + happy_eyeballs_timeout: Some(Duration::from_millis(300)), + tcp_keepalive_config: TcpKeepaliveConfig::default(), + local_address_ipv4: None, + local_address_ipv6: None, + nodelay: false, + reuse_address: false, + send_buffer_size: None, + recv_buffer_size: None, + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + interface: None, + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + tcp_user_timeout: None, + }), + resolver, + } + } + + /// Option to enforce all `Uri`s have the `http` scheme. + /// + /// Enabled by default. + #[inline] + pub fn enforce_http(&mut self, is_enforced: bool) { + self.config_mut().enforce_http = is_enforced; + } + + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration + /// to remain idle before sending TCP keepalive probes. + /// + /// If `None`, keepalive is disabled. + /// + /// Default is `None`. + #[inline] + pub fn set_keepalive(&mut self, time: Option<Duration>) { + self.config_mut().tcp_keepalive_config.time = time; + } + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + #[inline] + pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) { + self.config_mut().tcp_keepalive_config.interval = interval; + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + #[inline] + pub fn set_keepalive_retries(&mut self, retries: Option<u32>) { + self.config_mut().tcp_keepalive_config.retries = retries; + } + + /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`. + /// + /// Default is `false`. + #[inline] + pub fn set_nodelay(&mut self, nodelay: bool) { + self.config_mut().nodelay = nodelay; + } + + /// Sets the value of the SO_SNDBUF option on the socket. + #[inline] + pub fn set_send_buffer_size(&mut self, size: Option<usize>) { + self.config_mut().send_buffer_size = size; + } + + /// Sets the value of the SO_RCVBUF option on the socket. + #[inline] + pub fn set_recv_buffer_size(&mut self, size: Option<usize>) { + self.config_mut().recv_buffer_size = size; + } + + /// Set that all sockets are bound to the configured address before connection. + /// + /// If `None`, the sockets will not be bound. + /// + /// Default is `None`. + #[inline] + pub fn set_local_address(&mut self, addr: Option<IpAddr>) { + let (v4, v6) = match addr { + Some(IpAddr::V4(a)) => (Some(a), None), + Some(IpAddr::V6(a)) => (None, Some(a)), + _ => (None, None), + }; + + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = v4; + cfg.local_address_ipv6 = v6; + } + + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + #[inline] + pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) { + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = Some(addr_ipv4); + cfg.local_address_ipv6 = Some(addr_ipv6); + } + + /// Set the connect timeout. + /// + /// If a domain resolves to multiple IP addresses, the timeout will be + /// evenly divided across them. + /// + /// Default is `None`. + #[inline] + pub fn set_connect_timeout(&mut self, dur: Option<Duration>) { + self.config_mut().connect_timeout = dur; + } + + /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm. + /// + /// If hostname resolves to both IPv4 and IPv6 addresses and connection + /// cannot be established using preferred address family before timeout + /// elapses, then connector will in parallel attempt connection using other + /// address family. + /// + /// If `None`, parallel connection attempts are disabled. + /// + /// Default is 300 milliseconds. + /// + /// [RFC 6555]: https://tools.ietf.org/html/rfc6555 + #[inline] + pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) { + self.config_mut().happy_eyeballs_timeout = dur; + } + + /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`. + /// + /// Default is `false`. + #[inline] + pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self { + self.config_mut().reuse_address = reuse_address; + self + } + + /// Sets the name of the interface to bind sockets produced by this + /// connector. + /// + /// On Linux, this sets the `SO_BINDTODEVICE` option on this socket (see + /// [`man 7 socket`] for details). On macOS (and macOS-derived systems like + /// iOS), illumos, and Solaris, this will instead use the `IP_BOUND_IF` + /// socket option (see [`man 7p ip`]). + /// + /// If a socket is bound to an interface, only packets received from that particular + /// interface are processed by the socket. Note that this only works for some socket + /// types, particularly `AF_INET`` sockets. + /// + /// On Linux it can be used to specify a [VRF], but the binary needs + /// to either have `CAP_NET_RAW` or to be run as root. + /// + /// This function is only available on the following operating systems: + /// - Linux, including Android + /// - Fuchsia + /// - illumos and Solaris + /// - macOS, iOS, visionOS, watchOS, and tvOS + /// + /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt + /// [`man 7 socket`] https://man7.org/linux/man-pages/man7/socket.7.html + /// [`man 7p ip`]: https://docs.oracle.com/cd/E86824_01/html/E54777/ip-7p.html + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + #[inline] + pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self { + let interface = interface.into(); + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + { + self.config_mut().interface = Some(interface); + } + #[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))] + { + let interface = std::ffi::CString::new(interface) + .expect("interface name should not have nulls in it"); + self.config_mut().interface = Some(interface); + } + self + } + + /// Sets the value of the TCP_USER_TIMEOUT option on the socket. + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + #[inline] + pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) { + self.config_mut().tcp_user_timeout = time; + } + + // private + + fn config_mut(&mut self) -> &mut Config { + // If the are HttpConnector clones, this will clone the inner + // config. So mutating the config won't ever affect previous + // clones. + Arc::make_mut(&mut self.config) + } +} + +static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; +static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; +static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; + +// R: Debug required for now to allow adding it to debug output later... +impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpConnector").finish() + } +} + +impl<R> tower_service::Service<Uri> for HttpConnector<R> +where + R: Resolve + Clone + Send + Sync + 'static, + R::Future: Send, +{ + type Response = TokioIo<TcpStream>; + type Error = ConnectError; + type Future = HttpConnecting<R>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let mut self_ = self.clone(); + HttpConnecting { + fut: Box::pin(async move { self_.call_async(dst).await }), + _marker: PhantomData, + } + } +} + +fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> { + trace!( + "Http::connect; scheme={:?}, host={:?}, port={:?}", + dst.scheme(), + dst.host(), + dst.port(), + ); + + if config.enforce_http { + if dst.scheme() != Some(&Scheme::HTTP) { + return Err(ConnectError { + msg: INVALID_NOT_HTTP, + addr: None, + cause: None, + }); + } + } else if dst.scheme().is_none() { + return Err(ConnectError { + msg: INVALID_MISSING_SCHEME, + addr: None, + cause: None, + }); + } + + let host = match dst.host() { + Some(s) => s, + None => { + return Err(ConnectError { + msg: INVALID_MISSING_HOST, + addr: None, + cause: None, + }) + } + }; + let port = match dst.port() { + Some(port) => port.as_u16(), + None => { + if dst.scheme() == Some(&Scheme::HTTPS) { + 443 + } else { + 80 + } + } + }; + + Ok((host, port)) +} + +impl<R> HttpConnector<R> +where + R: Resolve, +{ + async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> { + let config = &self.config; + + let (host, port) = get_host_port(config, &dst)?; + let host = host.trim_start_matches('[').trim_end_matches(']'); + + // If the host is already an IP addr (v4 or v6), + // skip resolving the dns and start connecting right away. + let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) { + addrs + } else { + let addrs = resolve(&mut self.resolver, dns::Name::new(host.into())) + .await + .map_err(ConnectError::dns)?; + let addrs = addrs + .map(|mut addr| { + set_port(&mut addr, port, dst.port().is_some()); + + addr + }) + .collect(); + dns::SocketAddrs::new(addrs) + }; + + let c = ConnectingTcp::new(addrs, config); + + let sock = c.connect().await?; + + if let Err(e) = sock.set_nodelay(config.nodelay) { + warn!("tcp set_nodelay error: {}", e); + } + + Ok(TokioIo::new(sock)) + } +} + +impl Connection for TcpStream { + fn connected(&self) -> Connected { + let connected = Connected::new(); + if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) { + connected.extra(HttpInfo { + remote_addr, + local_addr, + }) + } else { + connected + } + } +} + +#[cfg(unix)] +impl Connection for tokio::net::UnixStream { + fn connected(&self) -> Connected { + Connected::new() + } +} + +#[cfg(windows)] +impl Connection for tokio::net::windows::named_pipe::NamedPipeClient { + fn connected(&self) -> Connected { + Connected::new() + } +} + +// Implement `Connection` for generic `TokioIo<T>` so that external crates can +// implement their own `HttpConnector` with `TokioIo<CustomTcpStream>`. +impl<T> Connection for TokioIo<T> +where + T: Connection, +{ + fn connected(&self) -> Connected { + self.inner().connected() + } +} + +impl HttpInfo { + /// Get the remote address of the transport used. + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } + + /// Get the local address of the transport used. + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly + // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct HttpConnecting<R> { + #[pin] + fut: BoxConnecting, + _marker: PhantomData<R>, + } +} + +type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>; +type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>; + +impl<R: Resolve> Future for HttpConnecting<R> { + type Output = ConnectResult; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.project().fut.poll(cx) + } +} + +// Not publicly exported (so missing_docs doesn't trigger). +pub struct ConnectError { + msg: &'static str, + addr: Option<SocketAddr>, + cause: Option<Box<dyn StdError + Send + Sync>>, +} + +impl ConnectError { + fn new<E>(msg: &'static str, cause: E) -> ConnectError + where + E: Into<Box<dyn StdError + Send + Sync>>, + { + ConnectError { + msg, + addr: None, + cause: Some(cause.into()), + } + } + + fn dns<E>(cause: E) -> ConnectError + where + E: Into<Box<dyn StdError + Send + Sync>>, + { + ConnectError::new("dns error", cause) + } + + fn m<E>(msg: &'static str) -> impl FnOnce(E) -> ConnectError + where + E: Into<Box<dyn StdError + Send + Sync>>, + { + move |cause| ConnectError::new(msg, cause) + } +} + +impl fmt::Debug for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut b = f.debug_tuple("ConnectError"); + b.field(&self.msg); + if let Some(ref addr) = self.addr { + b.field(addr); + } + if let Some(ref cause) = self.cause { + b.field(cause); + } + b.finish() + } +} + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.msg) + } +} + +impl StdError for ConnectError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause.as_ref().map(|e| &**e as _) + } +} + +struct ConnectingTcp<'a> { + preferred: ConnectingTcpRemote, + fallback: Option<ConnectingTcpFallback>, + config: &'a Config, +} + +impl<'a> ConnectingTcp<'a> { + fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self { + if let Some(fallback_timeout) = config.happy_eyeballs_timeout { + let (preferred_addrs, fallback_addrs) = remote_addrs + .split_by_preference(config.local_address_ipv4, config.local_address_ipv6); + if fallback_addrs.is_empty() { + return ConnectingTcp { + preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout), + fallback: None, + config, + }; + } + + ConnectingTcp { + preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout), + fallback: Some(ConnectingTcpFallback { + delay: tokio::time::sleep(fallback_timeout), + remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout), + }), + config, + } + } else { + ConnectingTcp { + preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout), + fallback: None, + config, + } + } + } +} + +struct ConnectingTcpFallback { + delay: Sleep, + remote: ConnectingTcpRemote, +} + +struct ConnectingTcpRemote { + addrs: dns::SocketAddrs, + connect_timeout: Option<Duration>, +} + +impl ConnectingTcpRemote { + fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self { + let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32)); + + Self { + addrs, + connect_timeout, + } + } +} + +impl ConnectingTcpRemote { + async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> { + let mut err = None; + for addr in &mut self.addrs { + debug!("connecting to {}", addr); + match connect(&addr, config, self.connect_timeout)?.await { + Ok(tcp) => { + debug!("connected to {}", addr); + return Ok(tcp); + } + Err(mut e) => { + trace!("connect error for {}: {:?}", addr, e); + e.addr = Some(addr); + // only return the first error, we assume it's the most relevant + if err.is_none() { + err = Some(e); + } + } + } + } + + match err { + Some(e) => Err(e), + None => Err(ConnectError::new( + "tcp connect error", + std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"), + )), + } + } +} + +fn bind_local_address( + socket: &socket2::Socket, + dst_addr: &SocketAddr, + local_addr_ipv4: &Option<Ipv4Addr>, + local_addr_ipv6: &Option<Ipv6Addr>, +) -> io::Result<()> { + match (*dst_addr, local_addr_ipv4, local_addr_ipv6) { + (SocketAddr::V4(_), Some(addr), _) => { + socket.bind(&SocketAddr::new((*addr).into(), 0).into())?; + } + (SocketAddr::V6(_), _, Some(addr)) => { + socket.bind(&SocketAddr::new((*addr).into(), 0).into())?; + } + _ => { + if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match *dst_addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + socket.bind(&any.into())?; + } + } + } + + Ok(()) +} + +fn connect( + addr: &SocketAddr, + config: &Config, + connect_timeout: Option<Duration>, +) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> { + // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the + // keepalive timeout, it would be nice to use that instead of socket2, + // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance... + use socket2::{Domain, Protocol, Socket, Type}; + + let domain = Domain::for_address(*addr); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) + .map_err(ConnectError::m("tcp open error"))?; + + // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is + // responsible for ensuring O_NONBLOCK is set. + socket + .set_nonblocking(true) + .map_err(ConnectError::m("tcp set_nonblocking error"))?; + + if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() { + if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) { + warn!("tcp set_keepalive error: {}", e); + } + } + + // That this only works for some socket types, particularly AF_INET sockets. + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + if let Some(interface) = &config.interface { + // On Linux-like systems, set the interface to bind using + // `SO_BINDTODEVICE`. + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + socket + .bind_device(Some(interface.as_bytes())) + .map_err(ConnectError::m("tcp bind interface error"))?; + + // On macOS-like and Solaris-like systems, we instead use `IP_BOUND_IF`. + // This socket option desires an integer index for the interface, so we + // must first determine the index of the requested interface name using + // `if_nametoindex`. + #[cfg(any( + target_os = "illumos", + target_os = "ios", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + { + let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) }; + let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| { + // If the index is 0, check errno and return an I/O error. + ConnectError::new( + "error converting interface name to index", + io::Error::last_os_error(), + ) + })?; + // Different setsockopt calls are necessary depending on whether the + // address is IPv4 or IPv6. + match addr { + SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)), + SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)), + } + .map_err(ConnectError::m("tcp bind interface error"))?; + } + } + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + if let Some(tcp_user_timeout) = &config.tcp_user_timeout { + if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) { + warn!("tcp set_tcp_user_timeout error: {}", e); + } + } + + bind_local_address( + &socket, + addr, + &config.local_address_ipv4, + &config.local_address_ipv6, + ) + .map_err(ConnectError::m("tcp bind local error"))?; + + #[cfg(unix)] + let socket = unsafe { + // Safety: `from_raw_fd` is only safe to call if ownership of the raw + // file descriptor is transferred. Since we call `into_raw_fd` on the + // socket2 socket, it gives up ownership of the fd and will not close + // it, so this is safe. + use std::os::unix::io::{FromRawFd, IntoRawFd}; + TcpSocket::from_raw_fd(socket.into_raw_fd()) + }; + #[cfg(windows)] + let socket = unsafe { + // Safety: `from_raw_socket` is only safe to call if ownership of the raw + // Windows SOCKET is transferred. Since we call `into_raw_socket` on the + // socket2 socket, it gives up ownership of the SOCKET and will not close + // it, so this is safe. + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + TcpSocket::from_raw_socket(socket.into_raw_socket()) + }; + + if config.reuse_address { + if let Err(e) = socket.set_reuseaddr(true) { + warn!("tcp set_reuse_address error: {}", e); + } + } + + if let Some(size) = config.send_buffer_size { + if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) { + warn!("tcp set_buffer_size error: {}", e); + } + } + + if let Some(size) = config.recv_buffer_size { + if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) { + warn!("tcp set_recv_buffer_size error: {}", e); + } + } + + let connect = socket.connect(*addr); + Ok(async move { + match connect_timeout { + Some(dur) => match tokio::time::timeout(dur, connect).await { + Ok(Ok(s)) => Ok(s), + Ok(Err(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), + }, + None => connect.await, + } + .map_err(ConnectError::m("tcp connect error")) + }) +} + +impl ConnectingTcp<'_> { + async fn connect(mut self) -> Result<TcpStream, ConnectError> { + match self.fallback { + None => self.preferred.connect(self.config).await, + Some(mut fallback) => { + let preferred_fut = self.preferred.connect(self.config); + futures_util::pin_mut!(preferred_fut); + + let fallback_fut = fallback.remote.connect(self.config); + futures_util::pin_mut!(fallback_fut); + + let fallback_delay = fallback.delay; + futures_util::pin_mut!(fallback_delay); + + let (result, future) = + match futures_util::future::select(preferred_fut, fallback_delay).await { + Either::Left((result, _fallback_delay)) => { + (result, Either::Right(fallback_fut)) + } + Either::Right(((), preferred_fut)) => { + // Delay is done, start polling both the preferred and the fallback + futures_util::future::select(preferred_fut, fallback_fut) + .await + .factor_first() + } + }; + + if result.is_err() { + // Fallback to the remaining future (could be preferred or fallback) + // if we get an error + future.await + } else { + result + } + } + } + } +} + +/// Respect explicit ports in the URI, if none, either +/// keep non `0` ports resolved from a custom dns resolver, +/// or use the default port for the scheme. +fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) { + if explicit || addr.port() == 0 { + addr.set_port(host_port) + }; +} + +#[cfg(test)] +mod tests { + use std::io; + use std::net::SocketAddr; + + use ::http::Uri; + + use crate::client::legacy::connect::http::TcpKeepaliveConfig; + + use super::super::sealed::{Connect, ConnectSvc}; + use super::{Config, ConnectError, HttpConnector}; + + use super::set_port; + + async fn connect<C>( + connector: C, + dst: Uri, + ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> + where + C: Connect, + { + connector.connect(super::super::sealed::Internal, dst).await + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] + async fn test_errors_enforce_http() { + let dst = "https://example.domain/foo/bar?baz".parse().unwrap(); + let connector = HttpConnector::new(); + + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); + } + + #[cfg(any(target_os = "linux", target_os = "macos"))] + fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) { + use std::net::{IpAddr, TcpListener}; + + let mut ip_v4 = None; + let mut ip_v6 = None; + + let ips = pnet_datalink::interfaces() + .into_iter() + .flat_map(|i| i.ips.into_iter().map(|n| n.ip())); + + for ip in ips { + match ip { + IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip), + IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip), + _ => (), + } + + if ip_v4.is_some() && ip_v6.is_some() { + break; + } + } + + (ip_v4, ip_v6) + } + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + fn default_interface() -> Option<String> { + pnet_datalink::interfaces() + .iter() + .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty()) + .map(|e| e.name.clone()) + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] + async fn test_errors_missing_scheme() { + let dst = "example.domain".parse().unwrap(); + let mut connector = HttpConnector::new(); + connector.enforce_http(false); + + let err = connect(connector, dst).await.unwrap_err(); + assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); + } + + // NOTE: pnet crate that we use in this test doesn't compile on Windows + #[cfg(any(target_os = "linux", target_os = "macos"))] + #[cfg_attr(miri, ignore)] + #[tokio::test] + async fn local_address() { + use std::net::{IpAddr, TcpListener}; + + let (bind_ip_v4, bind_ip_v6) = get_local_ips(); + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = server4.local_addr().unwrap().port(); + let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap(); + + let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move { + let mut connector = HttpConnector::new(); + + match (bind_ip_v4, bind_ip_v6) { + (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6), + (Some(v4), None) => connector.set_local_address(Some(v4.into())), + (None, Some(v6)) => connector.set_local_address(Some(v6.into())), + _ => unreachable!(), + } + + connect(connector, dst.parse().unwrap()).await.unwrap(); + + let (_, client_addr) = server.accept().unwrap(); + + assert_eq!(client_addr.ip(), expected_ip); + }; + + if let Some(ip) = bind_ip_v4 { + assert_client_ip(format!("http://127.0.0.1:{port}"), server4, ip.into()).await; + } + + if let Some(ip) = bind_ip_v6 { + assert_client_ip(format!("http://[::1]:{port}"), server6, ip.into()).await; + } + } + + // NOTE: pnet crate that we use in this test doesn't compile on Windows + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + #[tokio::test] + #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"] + async fn interface() { + use socket2::{Domain, Protocol, Socket, Type}; + use std::net::TcpListener; + + let interface: Option<String> = default_interface(); + + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = server4.local_addr().unwrap().port(); + + let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap(); + + let assert_interface_name = + |dst: String, + server: TcpListener, + bind_iface: Option<String>, + expected_interface: Option<String>| async move { + let mut connector = HttpConnector::new(); + if let Some(iface) = bind_iface { + connector.set_interface(iface); + } + + connect(connector, dst.parse().unwrap()).await.unwrap(); + let domain = Domain::for_address(server.local_addr().unwrap()); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap(); + + assert_eq!( + socket.device().unwrap().as_deref(), + expected_interface.as_deref().map(|val| val.as_bytes()) + ); + }; + + assert_interface_name( + format!("http://127.0.0.1:{port}"), + server4, + interface.clone(), + interface.clone(), + ) + .await; + assert_interface_name( + format!("http://[::1]:{port}"), + server6, + interface.clone(), + interface.clone(), + ) + .await; + } + + #[test] + #[ignore] // TODO + #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)] + fn client_happy_eyeballs() { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener}; + use std::time::{Duration, Instant}; + + use super::dns; + use super::ConnectingTcp; + + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server4.local_addr().unwrap(); + let _server6 = TcpListener::bind(format!("[::1]:{}", addr.port())).unwrap(); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let local_timeout = Duration::default(); + let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1; + let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1; + let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout) + + Duration::from_millis(250); + + let scenarios = &[ + // Fast primary, without fallback. + (&[local_ipv4_addr()][..], 4, local_timeout, false), + (&[local_ipv6_addr()][..], 6, local_timeout, false), + // Fast primary, with (unused) fallback. + ( + &[local_ipv4_addr(), local_ipv6_addr()][..], + 4, + local_timeout, + false, + ), + ( + &[local_ipv6_addr(), local_ipv4_addr()][..], + 6, + local_timeout, + false, + ), + // Unreachable + fast primary, without fallback. + ( + &[unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + unreachable_v6_timeout, + false, + ), + // Unreachable + fast primary, with (unused) fallback. + ( + &[ + unreachable_ipv4_addr(), + local_ipv4_addr(), + local_ipv6_addr(), + ][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[ + unreachable_ipv6_addr(), + local_ipv6_addr(), + local_ipv4_addr(), + ][..], + 6, + unreachable_v6_timeout, + true, + ), + // Slow primary, with (used) fallback. + ( + &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout, + false, + ), + ( + &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout, + true, + ), + // Slow primary, with (used) unreachable + fast fallback. + ( + &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout + unreachable_v6_timeout, + false, + ), + ( + &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout + unreachable_v4_timeout, + true, + ), + ]; + + // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network. + // Otherwise, connection to "slow" IPv6 address will error-out immediately. + let ipv6_accessible = measure_connect(slow_ipv6_addr()).0; + + for &(hosts, family, timeout, needs_ipv6_access) in scenarios { + if needs_ipv6_access && !ipv6_accessible { + continue; + } + + let (start, stream) = rt + .block_on(async move { + let addrs = hosts + .iter() + .map(|host| (*host, addr.port()).into()) + .collect(); + let cfg = Config { + local_address_ipv4: None, + local_address_ipv6: None, + connect_timeout: None, + tcp_keepalive_config: TcpKeepaliveConfig::default(), + happy_eyeballs_timeout: Some(fallback_timeout), + nodelay: false, + reuse_address: false, + enforce_http: false, + send_buffer_size: None, + recv_buffer_size: None, + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux" + ))] + interface: None, + #[cfg(any( + target_os = "illumos", + target_os = "ios", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + interface: None, + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux" + ))] + tcp_user_timeout: None, + }; + let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg); + let start = Instant::now(); + Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?)) + }) + .unwrap(); + let res = if stream.peer_addr().unwrap().is_ipv4() { + 4 + } else { + 6 + }; + let duration = start.elapsed(); + + // Allow actual duration to be +/- 150ms off. + let min_duration = if timeout >= Duration::from_millis(150) { + timeout - Duration::from_millis(150) + } else { + Duration::default() + }; + let max_duration = timeout + Duration::from_millis(150); + + assert_eq!(res, family); + assert!(duration >= min_duration); + assert!(duration <= max_duration); + } + + fn local_ipv4_addr() -> IpAddr { + Ipv4Addr::new(127, 0, 0, 1).into() + } + + fn local_ipv6_addr() -> IpAddr { + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into() + } + + fn unreachable_ipv4_addr() -> IpAddr { + Ipv4Addr::new(127, 0, 0, 2).into() + } + + fn unreachable_ipv6_addr() -> IpAddr { + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into() + } + + fn slow_ipv4_addr() -> IpAddr { + // RFC 6890 reserved IPv4 address. + Ipv4Addr::new(198, 18, 0, 25).into() + } + + fn slow_ipv6_addr() -> IpAddr { + // RFC 6890 reserved IPv6 address. + Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into() + } + + fn measure_connect(addr: IpAddr) -> (bool, Duration) { + let start = Instant::now(); + let result = + std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1)); + + let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut; + let duration = start.elapsed(); + (reachable, duration) + } + } + + use std::time::Duration; + + #[test] + fn no_tcp_keepalive_config() { + assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none()); + } + + #[test] + fn tcp_keepalive_time_config() { + let kac = TcpKeepaliveConfig { + time: Some(Duration::from_secs(60)), + ..Default::default() + }; + if let Some(tcp_keepalive) = kac.into_tcpkeepalive() { + assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))] + #[test] + fn tcp_keepalive_interval_config() { + let kac = TcpKeepaliveConfig { + interval: Some(Duration::from_secs(1)), + ..Default::default() + }; + if let Some(tcp_keepalive) = kac.into_tcpkeepalive() { + assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows" + )))] + #[test] + fn tcp_keepalive_retries_config() { + let kac = TcpKeepaliveConfig { + retries: Some(3), + ..Default::default() + }; + if let Some(tcp_keepalive) = kac.into_tcpkeepalive() { + assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)")); + } else { + panic!("test failed"); + } + } + + #[test] + fn test_set_port() { + // Respect explicit ports no matter what the resolved port is. + let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881)); + set_port(&mut addr, 42, true); + assert_eq!(addr.port(), 42); + + // Ignore default host port, and use the socket port instead. + let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881)); + set_port(&mut addr, 443, false); + assert_eq!(addr.port(), 6881); + + // Use the default port if the resolved port is `0`. + let mut addr = SocketAddr::from(([0, 0, 0, 0], 0)); + set_port(&mut addr, 443, false); + assert_eq!(addr.port(), 443); + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/mod.rs b/vendor/hyper-util/src/client/legacy/connect/mod.rs new file mode 100644 index 00000000..90a97679 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/mod.rs @@ -0,0 +1,444 @@ +//! Connectors used by the `Client`. +//! +//! This module contains: +//! +//! - A default [`HttpConnector`][] that does DNS resolution and establishes +//! connections over TCP. +//! - Types to build custom connectors. +//! +//! # Connectors +//! +//! A "connector" is a [`Service`][] that takes a [`Uri`][] destination, and +//! its `Response` is some type implementing [`Read`][], [`Write`][], +//! and [`Connection`][]. +//! +//! ## Custom Connectors +//! +//! A simple connector that ignores the `Uri` destination and always returns +//! a TCP connection to the same address could be written like this: +//! +//! ```rust,ignore +//! let connector = tower::service_fn(|_dst| async { +//! tokio::net::TcpStream::connect("127.0.0.1:1337") +//! }) +//! ``` +//! +//! Or, fully written out: +//! +//! ``` +//! use std::{future::Future, net::SocketAddr, pin::Pin, task::{self, Poll}}; +//! use http::Uri; +//! use tokio::net::TcpStream; +//! use tower_service::Service; +//! +//! #[derive(Clone)] +//! struct LocalConnector; +//! +//! impl Service<Uri> for LocalConnector { +//! type Response = TcpStream; +//! type Error = std::io::Error; +//! // We can't "name" an `async` generated future. +//! type Future = Pin<Box< +//! dyn Future<Output = Result<Self::Response, Self::Error>> + Send +//! >>; +//! +//! fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { +//! // This connector is always ready, but others might not be. +//! Poll::Ready(Ok(())) +//! } +//! +//! fn call(&mut self, _: Uri) -> Self::Future { +//! Box::pin(TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 1337)))) +//! } +//! } +//! ``` +//! +//! It's worth noting that for `TcpStream`s, the [`HttpConnector`][] is a +//! better starting place to extend from. +//! +//! [`HttpConnector`]: HttpConnector +//! [`Service`]: tower_service::Service +//! [`Uri`]: ::http::Uri +//! [`Read`]: hyper::rt::Read +//! [`Write`]: hyper::rt::Write +//! [`Connection`]: Connection +use std::{ + fmt::{self, Formatter}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use ::http::Extensions; + +#[cfg(feature = "tokio")] +pub use self::http::{HttpConnector, HttpInfo}; + +#[cfg(feature = "tokio")] +pub mod dns; +#[cfg(feature = "tokio")] +mod http; + +pub mod proxy; + +pub(crate) mod capture; +pub use capture::{capture_connection, CaptureConnection}; + +pub use self::sealed::Connect; + +/// Describes a type returned by a connector. +pub trait Connection { + /// Return metadata describing the connection. + fn connected(&self) -> Connected; +} + +/// Extra information about the connected transport. +/// +/// This can be used to inform recipients about things like if ALPN +/// was used, or if connected to an HTTP proxy. +#[derive(Debug)] +pub struct Connected { + pub(super) alpn: Alpn, + pub(super) is_proxied: bool, + pub(super) extra: Option<Extra>, + pub(super) poisoned: PoisonPill, +} + +#[derive(Clone)] +pub(crate) struct PoisonPill { + poisoned: Arc<AtomicBool>, +} + +impl fmt::Debug for PoisonPill { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // print the address of the pill—this makes debugging issues much easier + write!( + f, + "PoisonPill@{:p} {{ poisoned: {} }}", + self.poisoned, + self.poisoned.load(Ordering::Relaxed) + ) + } +} + +impl PoisonPill { + pub(crate) fn healthy() -> Self { + Self { + poisoned: Arc::new(AtomicBool::new(false)), + } + } + pub(crate) fn poison(&self) { + self.poisoned.store(true, Ordering::Relaxed) + } + + pub(crate) fn poisoned(&self) -> bool { + self.poisoned.load(Ordering::Relaxed) + } +} + +pub(super) struct Extra(Box<dyn ExtraInner>); + +#[derive(Clone, Copy, Debug, PartialEq)] +pub(super) enum Alpn { + H2, + None, +} + +impl Connected { + /// Create new `Connected` type with empty metadata. + pub fn new() -> Connected { + Connected { + alpn: Alpn::None, + is_proxied: false, + extra: None, + poisoned: PoisonPill::healthy(), + } + } + + /// Set whether the connected transport is to an HTTP proxy. + /// + /// This setting will affect if HTTP/1 requests written on the transport + /// will have the request-target in absolute-form or origin-form: + /// + /// - When `proxy(false)`: + /// + /// ```http + /// GET /guide HTTP/1.1 + /// ``` + /// + /// - When `proxy(true)`: + /// + /// ```http + /// GET http://hyper.rs/guide HTTP/1.1 + /// ``` + /// + /// Default is `false`. + pub fn proxy(mut self, is_proxied: bool) -> Connected { + self.is_proxied = is_proxied; + self + } + + /// Determines if the connected transport is to an HTTP proxy. + pub fn is_proxied(&self) -> bool { + self.is_proxied + } + + /// Set extra connection information to be set in the extensions of every `Response`. + pub fn extra<T: Clone + Send + Sync + 'static>(mut self, extra: T) -> Connected { + if let Some(prev) = self.extra { + self.extra = Some(Extra(Box::new(ExtraChain(prev.0, extra)))); + } else { + self.extra = Some(Extra(Box::new(ExtraEnvelope(extra)))); + } + self + } + + /// Copies the extra connection information into an `Extensions` map. + pub fn get_extras(&self, extensions: &mut Extensions) { + if let Some(extra) = &self.extra { + extra.set(extensions); + } + } + + /// Set that the connected transport negotiated HTTP/2 as its next protocol. + pub fn negotiated_h2(mut self) -> Connected { + self.alpn = Alpn::H2; + self + } + + /// Determines if the connected transport negotiated HTTP/2 as its next protocol. + pub fn is_negotiated_h2(&self) -> bool { + self.alpn == Alpn::H2 + } + + /// Poison this connection + /// + /// A poisoned connection will not be reused for subsequent requests by the pool + pub fn poison(&self) { + self.poisoned.poison(); + tracing::debug!( + poison_pill = ?self.poisoned, "connection was poisoned. this connection will not be reused for subsequent requests" + ); + } + + // Don't public expose that `Connected` is `Clone`, unsure if we want to + // keep that contract... + pub(super) fn clone(&self) -> Connected { + Connected { + alpn: self.alpn, + is_proxied: self.is_proxied, + extra: self.extra.clone(), + poisoned: self.poisoned.clone(), + } + } +} + +// ===== impl Extra ===== + +impl Extra { + pub(super) fn set(&self, res: &mut Extensions) { + self.0.set(res); + } +} + +impl Clone for Extra { + fn clone(&self) -> Extra { + Extra(self.0.clone_box()) + } +} + +impl fmt::Debug for Extra { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extra").finish() + } +} + +trait ExtraInner: Send + Sync { + fn clone_box(&self) -> Box<dyn ExtraInner>; + fn set(&self, res: &mut Extensions); +} + +// This indirection allows the `Connected` to have a type-erased "extra" value, +// while that type still knows its inner extra type. This allows the correct +// TypeId to be used when inserting into `res.extensions_mut()`. +#[derive(Clone)] +struct ExtraEnvelope<T>(T); + +impl<T> ExtraInner for ExtraEnvelope<T> +where + T: Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box<dyn ExtraInner> { + Box::new(self.clone()) + } + + fn set(&self, res: &mut Extensions) { + res.insert(self.0.clone()); + } +} + +struct ExtraChain<T>(Box<dyn ExtraInner>, T); + +impl<T: Clone> Clone for ExtraChain<T> { + fn clone(&self) -> Self { + ExtraChain(self.0.clone_box(), self.1.clone()) + } +} + +impl<T> ExtraInner for ExtraChain<T> +where + T: Clone + Send + Sync + 'static, +{ + fn clone_box(&self) -> Box<dyn ExtraInner> { + Box::new(self.clone()) + } + + fn set(&self, res: &mut Extensions) { + self.0.set(res); + res.insert(self.1.clone()); + } +} + +pub(super) mod sealed { + use std::error::Error as StdError; + use std::future::Future; + + use ::http::Uri; + use hyper::rt::{Read, Write}; + + use super::Connection; + + /// Connect to a destination, returning an IO transport. + /// + /// A connector receives a [`Uri`](::http::Uri) and returns a `Future` of the + /// ready connection. + /// + /// # Trait Alias + /// + /// This is really just an *alias* for the `tower::Service` trait, with + /// additional bounds set for convenience *inside* hyper. You don't actually + /// implement this trait, but `tower::Service<Uri>` instead. + // The `Sized` bound is to prevent creating `dyn Connect`, since they cannot + // fit the `Connect` bounds because of the blanket impl for `Service`. + pub trait Connect: Sealed + Sized { + #[doc(hidden)] + type _Svc: ConnectSvc; + #[doc(hidden)] + fn connect(self, internal_only: Internal, dst: Uri) -> <Self::_Svc as ConnectSvc>::Future; + } + + pub trait ConnectSvc { + type Connection: Read + Write + Connection + Unpin + Send + 'static; + type Error: Into<Box<dyn StdError + Send + Sync>>; + type Future: Future<Output = Result<Self::Connection, Self::Error>> + Unpin + Send + 'static; + + fn connect(self, internal_only: Internal, dst: Uri) -> Self::Future; + } + + impl<S, T> Connect for S + where + S: tower_service::Service<Uri, Response = T> + Send + 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: Read + Write + Connection + Unpin + Send + 'static, + { + type _Svc = S; + + fn connect(self, _: Internal, dst: Uri) -> crate::service::Oneshot<S, Uri> { + crate::service::Oneshot::new(self, dst) + } + } + + impl<S, T> ConnectSvc for S + where + S: tower_service::Service<Uri, Response = T> + Send + 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: Read + Write + Connection + Unpin + Send + 'static, + { + type Connection = T; + type Error = S::Error; + type Future = crate::service::Oneshot<S, Uri>; + + fn connect(self, _: Internal, dst: Uri) -> Self::Future { + crate::service::Oneshot::new(self, dst) + } + } + + impl<S, T> Sealed for S + where + S: tower_service::Service<Uri, Response = T> + Send, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + S::Future: Unpin + Send, + T: Read + Write + Connection + Unpin + Send + 'static, + { + } + + pub trait Sealed {} + #[allow(missing_debug_implementations)] + pub struct Internal; +} + +#[cfg(test)] +mod tests { + use super::Connected; + + #[derive(Clone, Debug, PartialEq)] + struct Ex1(usize); + + #[derive(Clone, Debug, PartialEq)] + struct Ex2(&'static str); + + #[derive(Clone, Debug, PartialEq)] + struct Ex3(&'static str); + + #[test] + fn test_connected_extra() { + let c1 = Connected::new().extra(Ex1(41)); + + let mut ex = ::http::Extensions::new(); + + assert_eq!(ex.get::<Ex1>(), None); + + c1.extra.as_ref().expect("c1 extra").set(&mut ex); + + assert_eq!(ex.get::<Ex1>(), Some(&Ex1(41))); + } + + #[test] + fn test_connected_extra_chain() { + // If a user composes connectors and at each stage, there's "extra" + // info to attach, it shouldn't override the previous extras. + + let c1 = Connected::new() + .extra(Ex1(45)) + .extra(Ex2("zoom")) + .extra(Ex3("pew pew")); + + let mut ex1 = ::http::Extensions::new(); + + assert_eq!(ex1.get::<Ex1>(), None); + assert_eq!(ex1.get::<Ex2>(), None); + assert_eq!(ex1.get::<Ex3>(), None); + + c1.extra.as_ref().expect("c1 extra").set(&mut ex1); + + assert_eq!(ex1.get::<Ex1>(), Some(&Ex1(45))); + assert_eq!(ex1.get::<Ex2>(), Some(&Ex2("zoom"))); + assert_eq!(ex1.get::<Ex3>(), Some(&Ex3("pew pew"))); + + // Just like extensions, inserting the same type overrides previous type. + let c2 = Connected::new() + .extra(Ex1(33)) + .extra(Ex2("hiccup")) + .extra(Ex1(99)); + + let mut ex2 = ::http::Extensions::new(); + + c2.extra.as_ref().expect("c2 extra").set(&mut ex2); + + assert_eq!(ex2.get::<Ex1>(), Some(&Ex1(99))); + assert_eq!(ex2.get::<Ex2>(), Some(&Ex2("hiccup"))); + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/mod.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/mod.rs new file mode 100644 index 00000000..56ca3291 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/mod.rs @@ -0,0 +1,6 @@ +//! Proxy helpers +mod socks; +mod tunnel; + +pub use self::socks::{SocksV4, SocksV5}; +pub use self::tunnel::Tunnel; diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/mod.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/mod.rs new file mode 100644 index 00000000..d6077b94 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/mod.rs @@ -0,0 +1,121 @@ +mod v5; +pub use v5::{SocksV5, SocksV5Error}; + +mod v4; +pub use v4::{SocksV4, SocksV4Error}; + +use bytes::BytesMut; + +use hyper::rt::Read; + +#[derive(Debug)] +pub enum SocksError<C> { + Inner(C), + Io(std::io::Error), + + DnsFailure, + MissingHost, + MissingPort, + + V4(SocksV4Error), + V5(SocksV5Error), + + Parsing(ParsingError), + Serialize(SerializeError), +} + +#[derive(Debug)] +pub enum ParsingError { + Incomplete, + WouldOverflow, + Other, +} + +#[derive(Debug)] +pub enum SerializeError { + WouldOverflow, +} + +async fn read_message<T, M, C>(mut conn: &mut T, buf: &mut BytesMut) -> Result<M, SocksError<C>> +where + T: Read + Unpin, + M: for<'a> TryFrom<&'a mut BytesMut, Error = ParsingError>, +{ + let mut tmp = [0; 513]; + + loop { + let n = crate::rt::read(&mut conn, &mut tmp).await?; + buf.extend_from_slice(&tmp[..n]); + + match M::try_from(buf) { + Err(ParsingError::Incomplete) => { + if n == 0 { + if buf.spare_capacity_mut().is_empty() { + return Err(SocksError::Parsing(ParsingError::WouldOverflow)); + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected eof", + ) + .into()); + } + } + } + Err(err) => return Err(err.into()), + Ok(res) => return Ok(res), + } + } +} + +impl<C> std::fmt::Display for SocksError<C> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("SOCKS error: ")?; + + match self { + Self::Inner(_) => f.write_str("failed to create underlying connection"), + Self::Io(_) => f.write_str("io error during SOCKS handshake"), + + Self::DnsFailure => f.write_str("could not resolve to acceptable address type"), + Self::MissingHost => f.write_str("missing destination host"), + Self::MissingPort => f.write_str("missing destination port"), + + Self::Parsing(_) => f.write_str("failed parsing server response"), + Self::Serialize(_) => f.write_str("failed serialize request"), + + Self::V4(e) => e.fmt(f), + Self::V5(e) => e.fmt(f), + } + } +} + +impl<C: std::fmt::Debug + std::fmt::Display> std::error::Error for SocksError<C> {} + +impl<C> From<std::io::Error> for SocksError<C> { + fn from(err: std::io::Error) -> Self { + Self::Io(err) + } +} + +impl<C> From<ParsingError> for SocksError<C> { + fn from(err: ParsingError) -> Self { + Self::Parsing(err) + } +} + +impl<C> From<SerializeError> for SocksError<C> { + fn from(err: SerializeError) -> Self { + Self::Serialize(err) + } +} + +impl<C> From<SocksV4Error> for SocksError<C> { + fn from(err: SocksV4Error) -> Self { + Self::V4(err) + } +} + +impl<C> From<SocksV5Error> for SocksError<C> { + fn from(err: SocksV5Error) -> Self { + Self::V5(err) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/errors.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/errors.rs new file mode 100644 index 00000000..5fdbd05c --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/errors.rs @@ -0,0 +1,22 @@ +use super::Status; + +#[derive(Debug)] +pub enum SocksV4Error { + IpV6, + Command(Status), +} + +impl From<Status> for SocksV4Error { + fn from(err: Status) -> Self { + Self::Command(err) + } +} + +impl std::fmt::Display for SocksV4Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::IpV6 => f.write_str("IPV6 is not supported"), + Self::Command(status) => status.fmt(f), + } + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/messages.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/messages.rs new file mode 100644 index 00000000..bec8d081 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/messages.rs @@ -0,0 +1,131 @@ +use super::super::{ParsingError, SerializeError}; + +use bytes::{Buf, BufMut, BytesMut}; +use std::net::SocketAddrV4; + +/// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+ +/// | VN | CD | DSTPORT | DSTIP | USERID | NULL | DOMAIN | NULL | +/// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+ +/// | 1 | 1 | 2 | 4 | Variable | 1 | Variable | 1 | +/// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+ +/// ^^^^^^^^^^^^^^^^^^^^^ +/// optional: only do IP is 0.0.0.X +#[derive(Debug)] +pub struct Request<'a>(pub &'a Address); + +/// +-----+-----+----+----+----+----+----+----+ +/// | VN | CD | DSTPORT | DSTIP | +/// +-----+-----+----+----+----+----+----+----+ +/// | 1 | 1 | 2 | 4 | +/// +-----+-----+----+----+----+----+----+----+ +/// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +/// ignore: only for SOCKSv4 BIND +#[derive(Debug)] +pub struct Response(pub Status); + +#[derive(Debug)] +pub enum Address { + Socket(SocketAddrV4), + Domain(String, u16), +} + +#[derive(Debug, PartialEq)] +pub enum Status { + Success = 90, + Failed = 91, + IdentFailure = 92, + IdentMismatch = 93, +} + +impl Request<'_> { + pub fn write_to_buf<B: BufMut>(&self, mut buf: B) -> Result<usize, SerializeError> { + match self.0 { + Address::Socket(socket) => { + if buf.remaining_mut() < 10 { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x04); // Version + buf.put_u8(0x01); // CONNECT + + buf.put_u16(socket.port()); // Port + buf.put_slice(&socket.ip().octets()); // IP + + buf.put_u8(0x00); // USERID + buf.put_u8(0x00); // NULL + + Ok(10) + } + + Address::Domain(domain, port) => { + if buf.remaining_mut() < 10 + domain.len() + 1 { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x04); // Version + buf.put_u8(0x01); // CONNECT + + buf.put_u16(*port); // IP + buf.put_slice(&[0x00, 0x00, 0x00, 0xFF]); // Invalid IP + + buf.put_u8(0x00); // USERID + buf.put_u8(0x00); // NULL + + buf.put_slice(domain.as_bytes()); // Domain + buf.put_u8(0x00); // NULL + + Ok(10 + domain.len() + 1) + } + } + } +} + +impl TryFrom<&mut BytesMut> for Response { + type Error = ParsingError; + + fn try_from(buf: &mut BytesMut) -> Result<Self, Self::Error> { + if buf.remaining() < 8 { + return Err(ParsingError::Incomplete); + } + + if buf.get_u8() != 0x00 { + return Err(ParsingError::Other); + } + + let status = buf.get_u8().try_into()?; + let _addr = { + let port = buf.get_u16(); + let mut ip = [0; 4]; + buf.copy_to_slice(&mut ip); + + SocketAddrV4::new(ip.into(), port) + }; + + Ok(Self(status)) + } +} + +impl TryFrom<u8> for Status { + type Error = ParsingError; + + fn try_from(byte: u8) -> Result<Self, Self::Error> { + Ok(match byte { + 90 => Self::Success, + 91 => Self::Failed, + 92 => Self::IdentFailure, + 93 => Self::IdentMismatch, + _ => return Err(ParsingError::Other), + }) + } +} + +impl std::fmt::Display for Status { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + Self::Success => "success", + Self::Failed => "server failed to execute command", + Self::IdentFailure => "server ident service failed", + Self::IdentMismatch => "server ident service did not recognise client identifier", + }) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/mod.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/mod.rs new file mode 100644 index 00000000..bee7e6dc --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v4/mod.rs @@ -0,0 +1,183 @@ +mod errors; +pub use errors::*; + +mod messages; +use messages::*; + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use std::net::{IpAddr, SocketAddr, SocketAddrV4, ToSocketAddrs}; + +use http::Uri; +use hyper::rt::{Read, Write}; +use tower_service::Service; + +use bytes::BytesMut; + +use pin_project_lite::pin_project; + +/// Tunnel Proxy via SOCKSv4 +/// +/// This is a connector that can be used by the `legacy::Client`. It wraps +/// another connector, and after getting an underlying connection, it established +/// a TCP tunnel over it using SOCKSv4. +#[derive(Debug, Clone)] +pub struct SocksV4<C> { + inner: C, + config: SocksConfig, +} + +#[derive(Debug, Clone)] +struct SocksConfig { + proxy: Uri, + local_dns: bool, +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly + // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct Handshaking<F, T, E> { + #[pin] + fut: BoxHandshaking<T, E>, + _marker: std::marker::PhantomData<F> + } +} + +type BoxHandshaking<T, E> = Pin<Box<dyn Future<Output = Result<T, super::SocksError<E>>> + Send>>; + +impl<C> SocksV4<C> { + /// Create a new SOCKSv4 handshake service + /// + /// Wraps an underlying connector and stores the address of a tunneling + /// proxying server. + /// + /// A `SocksV4` can then be called with any destination. The `dst` passed to + /// `call` will not be used to create the underlying connection, but will + /// be used in a SOCKS handshake with the proxy destination. + pub fn new(proxy_dst: Uri, connector: C) -> Self { + Self { + inner: connector, + config: SocksConfig::new(proxy_dst), + } + } + + /// Resolve domain names locally on the client, rather than on the proxy server. + /// + /// Disabled by default as local resolution of domain names can be detected as a + /// DNS leak. + pub fn local_dns(mut self, local_dns: bool) -> Self { + self.config.local_dns = local_dns; + self + } +} + +impl SocksConfig { + pub fn new(proxy: Uri) -> Self { + Self { + proxy, + local_dns: false, + } + } + + async fn execute<T, E>( + self, + mut conn: T, + host: String, + port: u16, + ) -> Result<T, super::SocksError<E>> + where + T: Read + Write + Unpin, + { + let address = match host.parse::<IpAddr>() { + Ok(IpAddr::V6(_)) => return Err(SocksV4Error::IpV6.into()), + Ok(IpAddr::V4(ip)) => Address::Socket(SocketAddrV4::new(ip, port)), + Err(_) => { + if self.local_dns { + (host, port) + .to_socket_addrs()? + .find_map(|s| { + if let SocketAddr::V4(v4) = s { + Some(Address::Socket(v4)) + } else { + None + } + }) + .ok_or(super::SocksError::DnsFailure)? + } else { + Address::Domain(host, port) + } + } + }; + + let mut send_buf = BytesMut::with_capacity(1024); + let mut recv_buf = BytesMut::with_capacity(1024); + + // Send Request + let req = Request(&address); + let n = req.write_to_buf(&mut send_buf)?; + crate::rt::write_all(&mut conn, &send_buf[..n]).await?; + + // Read Response + let res: Response = super::read_message(&mut conn, &mut recv_buf).await?; + if res.0 == Status::Success { + Ok(conn) + } else { + Err(SocksV4Error::Command(res.0).into()) + } + } +} + +impl<C> Service<Uri> for SocksV4<C> +where + C: Service<Uri>, + C::Future: Send + 'static, + C::Response: Read + Write + Unpin + Send + 'static, + C::Error: Send + 'static, +{ + type Response = C::Response; + type Error = super::SocksError<C::Error>; + type Future = Handshaking<C::Future, C::Response, C::Error>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx).map_err(super::SocksError::Inner) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let config = self.config.clone(); + let connecting = self.inner.call(config.proxy.clone()); + + let fut = async move { + let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); + let host = dst + .host() + .ok_or(super::SocksError::MissingHost)? + .to_string(); + + let conn = connecting.await.map_err(super::SocksError::Inner)?; + config.execute(conn, host, port).await + }; + + Handshaking { + fut: Box::pin(fut), + _marker: Default::default(), + } + } +} + +impl<F, T, E> Future for Handshaking<F, T, E> +where + F: Future<Output = Result<T, E>>, +{ + type Output = Result<T, super::SocksError<E>>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().fut.poll(cx) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/errors.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/errors.rs new file mode 100644 index 00000000..06b1a9a8 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/errors.rs @@ -0,0 +1,47 @@ +use super::Status; + +#[derive(Debug)] +pub enum SocksV5Error { + HostTooLong, + Auth(AuthError), + Command(Status), +} + +#[derive(Debug)] +pub enum AuthError { + Unsupported, + MethodMismatch, + Failed, +} + +impl From<Status> for SocksV5Error { + fn from(err: Status) -> Self { + Self::Command(err) + } +} + +impl From<AuthError> for SocksV5Error { + fn from(err: AuthError) -> Self { + Self::Auth(err) + } +} + +impl std::fmt::Display for SocksV5Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::HostTooLong => f.write_str("host address is more than 255 characters"), + Self::Command(e) => e.fmt(f), + Self::Auth(e) => e.fmt(f), + } + } +} + +impl std::fmt::Display for AuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + Self::Unsupported => "server does not support user/pass authentication", + Self::MethodMismatch => "server implements authentication incorrectly", + Self::Failed => "credentials not accepted", + }) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/messages.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/messages.rs new file mode 100644 index 00000000..ddf93538 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/messages.rs @@ -0,0 +1,347 @@ +use super::super::{ParsingError, SerializeError}; + +use bytes::{Buf, BufMut, BytesMut}; +use std::net::SocketAddr; + +/// +----+----------+----------+ +/// |VER | NMETHODS | METHODS | +/// +----+----------+----------+ +/// | 1 | 1 | 1 to 255 | +/// +----+----------+----------+ +#[derive(Debug)] +pub struct NegotiationReq<'a>(pub &'a AuthMethod); + +/// +----+--------+ +/// |VER | METHOD | +/// +----+--------+ +/// | 1 | 1 | +/// +----+--------+ +#[derive(Debug)] +pub struct NegotiationRes(pub AuthMethod); + +/// +----+------+----------+------+----------+ +/// |VER | ULEN | UNAME | PLEN | PASSWD | +/// +----+------+----------+------+----------+ +/// | 1 | 1 | 1 to 255 | 1 | 1 to 255 | +/// +----+------+----------+------+----------+ +#[derive(Debug)] +pub struct AuthenticationReq<'a>(pub &'a str, pub &'a str); + +/// +----+--------+ +/// |VER | STATUS | +/// +----+--------+ +/// | 1 | 1 | +/// +----+--------+ +#[derive(Debug)] +pub struct AuthenticationRes(pub bool); + +/// +----+-----+-------+------+----------+----------+ +/// |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | +/// +----+-----+-------+------+----------+----------+ +/// | 1 | 1 | X'00' | 1 | Variable | 2 | +/// +----+-----+-------+------+----------+----------+ +#[derive(Debug)] +pub struct ProxyReq<'a>(pub &'a Address); + +/// +----+-----+-------+------+----------+----------+ +/// |VER | REP | RSV | ATYP | BND.ADDR | BND.PORT | +/// +----+-----+-------+------+----------+----------+ +/// | 1 | 1 | X'00' | 1 | Variable | 2 | +/// +----+-----+-------+------+----------+----------+ +#[derive(Debug)] +pub struct ProxyRes(pub Status); + +#[repr(u8)] +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum AuthMethod { + NoAuth = 0x00, + UserPass = 0x02, + NoneAcceptable = 0xFF, +} + +#[derive(Debug)] +pub enum Address { + Socket(SocketAddr), + Domain(String, u16), +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Status { + Success, + GeneralServerFailure, + ConnectionNotAllowed, + NetworkUnreachable, + HostUnreachable, + ConnectionRefused, + TtlExpired, + CommandNotSupported, + AddressTypeNotSupported, +} + +impl NegotiationReq<'_> { + pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> { + if buf.capacity() - buf.len() < 3 { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x05); // Version + buf.put_u8(0x01); // Number of authentication methods + buf.put_u8(*self.0 as u8); // Authentication method + + Ok(3) + } +} + +impl TryFrom<&mut BytesMut> for NegotiationRes { + type Error = ParsingError; + + fn try_from(buf: &mut BytesMut) -> Result<Self, ParsingError> { + if buf.remaining() < 2 { + return Err(ParsingError::Incomplete); + } + + if buf.get_u8() != 0x05 { + return Err(ParsingError::Other); + } + + let method = buf.get_u8().try_into()?; + Ok(Self(method)) + } +} + +impl AuthenticationReq<'_> { + pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> { + if buf.capacity() - buf.len() < 3 + self.0.len() + self.1.len() { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x01); // Version + + buf.put_u8(self.0.len() as u8); // Username length (guarenteed to be 255 or less) + buf.put_slice(self.0.as_bytes()); // Username + + buf.put_u8(self.1.len() as u8); // Password length (guarenteed to be 255 or less) + buf.put_slice(self.1.as_bytes()); // Password + + Ok(3 + self.0.len() + self.1.len()) + } +} + +impl TryFrom<&mut BytesMut> for AuthenticationRes { + type Error = ParsingError; + + fn try_from(buf: &mut BytesMut) -> Result<Self, ParsingError> { + if buf.remaining() < 2 { + return Err(ParsingError::Incomplete); + } + + if buf.get_u8() != 0x01 { + return Err(ParsingError::Other); + } + + if buf.get_u8() == 0 { + Ok(Self(true)) + } else { + Ok(Self(false)) + } + } +} + +impl ProxyReq<'_> { + pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> { + let addr_len = match self.0 { + Address::Socket(SocketAddr::V4(_)) => 1 + 4 + 2, + Address::Socket(SocketAddr::V6(_)) => 1 + 16 + 2, + Address::Domain(ref domain, _) => 1 + 1 + domain.len() + 2, + }; + + if buf.capacity() - buf.len() < 3 + addr_len { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x05); // Version + buf.put_u8(0x01); // TCP tunneling command + buf.put_u8(0x00); // Reserved + let _ = self.0.write_to_buf(buf); // Address + + Ok(3 + addr_len) + } +} + +impl TryFrom<&mut BytesMut> for ProxyRes { + type Error = ParsingError; + + fn try_from(buf: &mut BytesMut) -> Result<Self, ParsingError> { + if buf.remaining() < 2 { + return Err(ParsingError::Incomplete); + } + + // VER + if buf.get_u8() != 0x05 { + return Err(ParsingError::Other); + } + + // REP + let status = buf.get_u8().try_into()?; + + // RSV + if buf.get_u8() != 0x00 { + return Err(ParsingError::Other); + } + + // ATYP + ADDR + Address::try_from(buf)?; + + Ok(Self(status)) + } +} + +impl Address { + pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result<usize, SerializeError> { + match self { + Self::Socket(SocketAddr::V4(v4)) => { + if buf.capacity() - buf.len() < 1 + 4 + 2 { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x01); + buf.put_slice(&v4.ip().octets()); + buf.put_u16(v4.port()); // Network Order/BigEndian for port + + Ok(7) + } + + Self::Socket(SocketAddr::V6(v6)) => { + if buf.capacity() - buf.len() < 1 + 16 + 2 { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x04); + buf.put_slice(&v6.ip().octets()); + buf.put_u16(v6.port()); // Network Order/BigEndian for port + + Ok(19) + } + + Self::Domain(domain, port) => { + if buf.capacity() - buf.len() < 1 + 1 + domain.len() + 2 { + return Err(SerializeError::WouldOverflow); + } + + buf.put_u8(0x03); + buf.put_u8(domain.len() as u8); // Guarenteed to be less than 255 + buf.put_slice(domain.as_bytes()); + buf.put_u16(*port); + + Ok(4 + domain.len()) + } + } + } +} + +impl TryFrom<&mut BytesMut> for Address { + type Error = ParsingError; + + fn try_from(buf: &mut BytesMut) -> Result<Self, Self::Error> { + if buf.remaining() < 2 { + return Err(ParsingError::Incomplete); + } + + Ok(match buf.get_u8() { + 0x01 => { + let mut ip = [0; 4]; + + if buf.remaining() < 6 { + return Err(ParsingError::Incomplete); + } + + buf.copy_to_slice(&mut ip); + let port = buf.get_u16(); + + Self::Socket(SocketAddr::new(ip.into(), port)) + } + + 0x03 => { + let len = buf.get_u8(); + + if len == 0 { + return Err(ParsingError::Other); + } else if buf.remaining() < (len as usize) + 2 { + return Err(ParsingError::Incomplete); + } + + let domain = std::str::from_utf8(&buf[..len as usize]) + .map_err(|_| ParsingError::Other)? + .to_string(); + + let port = buf.get_u16(); + + Self::Domain(domain, port) + } + + 0x04 => { + let mut ip = [0; 16]; + + if buf.remaining() < 6 { + return Err(ParsingError::Incomplete); + } + buf.copy_to_slice(&mut ip); + let port = buf.get_u16(); + + Self::Socket(SocketAddr::new(ip.into(), port)) + } + + _ => return Err(ParsingError::Other), + }) + } +} + +impl TryFrom<u8> for Status { + type Error = ParsingError; + + fn try_from(byte: u8) -> Result<Self, Self::Error> { + Ok(match byte { + 0x00 => Self::Success, + + 0x01 => Self::GeneralServerFailure, + 0x02 => Self::ConnectionNotAllowed, + 0x03 => Self::NetworkUnreachable, + 0x04 => Self::HostUnreachable, + 0x05 => Self::ConnectionRefused, + 0x06 => Self::TtlExpired, + 0x07 => Self::CommandNotSupported, + 0x08 => Self::AddressTypeNotSupported, + _ => return Err(ParsingError::Other), + }) + } +} + +impl TryFrom<u8> for AuthMethod { + type Error = ParsingError; + + fn try_from(byte: u8) -> Result<Self, Self::Error> { + Ok(match byte { + 0x00 => Self::NoAuth, + 0x02 => Self::UserPass, + 0xFF => Self::NoneAcceptable, + + _ => return Err(ParsingError::Other), + }) + } +} + +impl std::fmt::Display for Status { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + Self::Success => "success", + Self::GeneralServerFailure => "general server failure", + Self::ConnectionNotAllowed => "connection not allowed", + Self::NetworkUnreachable => "network unreachable", + Self::HostUnreachable => "host unreachable", + Self::ConnectionRefused => "connection refused", + Self::TtlExpired => "ttl expired", + Self::CommandNotSupported => "command not supported", + Self::AddressTypeNotSupported => "address type not supported", + }) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/mod.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/mod.rs new file mode 100644 index 00000000..caf2446b --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/socks/v5/mod.rs @@ -0,0 +1,313 @@ +mod errors; +pub use errors::*; + +mod messages; +use messages::*; + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; + +use http::Uri; +use hyper::rt::{Read, Write}; +use tower_service::Service; + +use bytes::BytesMut; + +use pin_project_lite::pin_project; + +/// Tunnel Proxy via SOCKSv5 +/// +/// This is a connector that can be used by the `legacy::Client`. It wraps +/// another connector, and after getting an underlying connection, it established +/// a TCP tunnel over it using SOCKSv5. +#[derive(Debug, Clone)] +pub struct SocksV5<C> { + inner: C, + config: SocksConfig, +} + +#[derive(Debug, Clone)] +pub struct SocksConfig { + proxy: Uri, + proxy_auth: Option<(String, String)>, + + local_dns: bool, + optimistic: bool, +} + +#[derive(Debug)] +enum State { + SendingNegReq, + ReadingNegRes, + SendingAuthReq, + ReadingAuthRes, + SendingProxyReq, + ReadingProxyRes, +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly + // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct Handshaking<F, T, E> { + #[pin] + fut: BoxHandshaking<T, E>, + _marker: std::marker::PhantomData<F> + } +} + +type BoxHandshaking<T, E> = Pin<Box<dyn Future<Output = Result<T, super::SocksError<E>>> + Send>>; + +impl<C> SocksV5<C> { + /// Create a new SOCKSv5 handshake service. + /// + /// Wraps an underlying connector and stores the address of a tunneling + /// proxying server. + /// + /// A `SocksV5` can then be called with any destination. The `dst` passed to + /// `call` will not be used to create the underlying connection, but will + /// be used in a SOCKS handshake with the proxy destination. + pub fn new(proxy_dst: Uri, connector: C) -> Self { + Self { + inner: connector, + config: SocksConfig::new(proxy_dst), + } + } + + /// Use User/Pass authentication method during handshake. + /// + /// Username and Password must be maximum of 255 characters each. + /// 0 length strings are allowed despite RFC prohibiting it. This is done so that + /// for compatablity with server implementations that require it for IP authentication. + pub fn with_auth(mut self, user: String, pass: String) -> Self { + self.config.proxy_auth = Some((user, pass)); + self + } + + /// Resolve domain names locally on the client, rather than on the proxy server. + /// + /// Disabled by default as local resolution of domain names can be detected as a + /// DNS leak. + pub fn local_dns(mut self, local_dns: bool) -> Self { + self.config.local_dns = local_dns; + self + } + + /// Send all messages of the handshake optmistically (without waiting for server response). + /// + /// Typical SOCKS handshake with auithentication takes 3 round trips. Optimistic sending + /// can reduce round trip times and dramatically increase speed of handshake at the cost of + /// reduced portability; many server implementations do not support optimistic sending as it + /// is not defined in the RFC (RFC 1928). + /// + /// Recommended to ensure connector works correctly without optimistic sending before trying + /// with optimistic sending. + pub fn send_optimistically(mut self, optimistic: bool) -> Self { + self.config.optimistic = optimistic; + self + } +} + +impl SocksConfig { + fn new(proxy: Uri) -> Self { + Self { + proxy, + proxy_auth: None, + + local_dns: false, + optimistic: false, + } + } + + async fn execute<T, E>( + self, + mut conn: T, + host: String, + port: u16, + ) -> Result<T, super::SocksError<E>> + where + T: Read + Write + Unpin, + { + let address = match host.parse::<IpAddr>() { + Ok(ip) => Address::Socket(SocketAddr::new(ip, port)), + Err(_) if host.len() <= 255 => { + if self.local_dns { + let socket = (host, port) + .to_socket_addrs()? + .next() + .ok_or(super::SocksError::DnsFailure)?; + + Address::Socket(socket) + } else { + Address::Domain(host, port) + } + } + Err(_) => return Err(SocksV5Error::HostTooLong.into()), + }; + + let method = if self.proxy_auth.is_some() { + AuthMethod::UserPass + } else { + AuthMethod::NoAuth + }; + + let mut recv_buf = BytesMut::with_capacity(513); // Max length of valid recievable message is 513 from Auth Request + let mut send_buf = BytesMut::with_capacity(262); // Max length of valid sendable message is 262 from Auth Response + let mut state = State::SendingNegReq; + + loop { + match state { + State::SendingNegReq => { + let req = NegotiationReq(&method); + + let start = send_buf.len(); + req.write_to_buf(&mut send_buf)?; + crate::rt::write_all(&mut conn, &send_buf[start..]).await?; + + if self.optimistic { + if method == AuthMethod::UserPass { + state = State::SendingAuthReq; + } else { + state = State::SendingProxyReq; + } + } else { + state = State::ReadingNegRes; + } + } + + State::ReadingNegRes => { + let res: NegotiationRes = super::read_message(&mut conn, &mut recv_buf).await?; + + if res.0 == AuthMethod::NoneAcceptable { + return Err(SocksV5Error::Auth(AuthError::Unsupported).into()); + } + + if res.0 != method { + return Err(SocksV5Error::Auth(AuthError::MethodMismatch).into()); + } + + if self.optimistic { + if res.0 == AuthMethod::UserPass { + state = State::ReadingAuthRes; + } else { + state = State::ReadingProxyRes; + } + } else if res.0 == AuthMethod::UserPass { + state = State::SendingAuthReq; + } else { + state = State::SendingProxyReq; + } + } + + State::SendingAuthReq => { + let (user, pass) = self.proxy_auth.as_ref().unwrap(); + let req = AuthenticationReq(user, pass); + + let start = send_buf.len(); + req.write_to_buf(&mut send_buf)?; + crate::rt::write_all(&mut conn, &send_buf[start..]).await?; + + if self.optimistic { + state = State::SendingProxyReq; + } else { + state = State::ReadingAuthRes; + } + } + + State::ReadingAuthRes => { + let res: AuthenticationRes = + super::read_message(&mut conn, &mut recv_buf).await?; + + if !res.0 { + return Err(SocksV5Error::Auth(AuthError::Failed).into()); + } + + if self.optimistic { + state = State::ReadingProxyRes; + } else { + state = State::SendingProxyReq; + } + } + + State::SendingProxyReq => { + let req = ProxyReq(&address); + + let start = send_buf.len(); + req.write_to_buf(&mut send_buf)?; + crate::rt::write_all(&mut conn, &send_buf[start..]).await?; + + if self.optimistic { + state = State::ReadingNegRes; + } else { + state = State::ReadingProxyRes; + } + } + + State::ReadingProxyRes => { + let res: ProxyRes = super::read_message(&mut conn, &mut recv_buf).await?; + + if res.0 == Status::Success { + return Ok(conn); + } else { + return Err(SocksV5Error::Command(res.0).into()); + } + } + } + } + } +} + +impl<C> Service<Uri> for SocksV5<C> +where + C: Service<Uri>, + C::Future: Send + 'static, + C::Response: Read + Write + Unpin + Send + 'static, + C::Error: Send + 'static, +{ + type Response = C::Response; + type Error = super::SocksError<C::Error>; + type Future = Handshaking<C::Future, C::Response, C::Error>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx).map_err(super::SocksError::Inner) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let config = self.config.clone(); + let connecting = self.inner.call(config.proxy.clone()); + + let fut = async move { + let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); + let host = dst + .host() + .ok_or(super::SocksError::MissingHost)? + .to_string(); + + let conn = connecting.await.map_err(super::SocksError::Inner)?; + config.execute(conn, host, port).await + }; + + Handshaking { + fut: Box::pin(fut), + _marker: Default::default(), + } + } +} + +impl<F, T, E> Future for Handshaking<F, T, E> +where + F: Future<Output = Result<T, E>>, +{ + type Output = Result<T, super::SocksError<E>>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().fut.poll(cx) + } +} diff --git a/vendor/hyper-util/src/client/legacy/connect/proxy/tunnel.rs b/vendor/hyper-util/src/client/legacy/connect/proxy/tunnel.rs new file mode 100644 index 00000000..ad948596 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/connect/proxy/tunnel.rs @@ -0,0 +1,258 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::marker::{PhantomData, Unpin}; +use std::pin::Pin; +use std::task::{self, Poll}; + +use futures_core::ready; +use http::{HeaderMap, HeaderValue, Uri}; +use hyper::rt::{Read, Write}; +use pin_project_lite::pin_project; +use tower_service::Service; + +/// Tunnel Proxy via HTTP CONNECT +/// +/// This is a connector that can be used by the `legacy::Client`. It wraps +/// another connector, and after getting an underlying connection, it creates +/// an HTTP CONNECT tunnel over it. +#[derive(Debug)] +pub struct Tunnel<C> { + headers: Headers, + inner: C, + proxy_dst: Uri, +} + +#[derive(Clone, Debug)] +enum Headers { + Empty, + Auth(HeaderValue), + Extra(HeaderMap), +} + +#[derive(Debug)] +pub enum TunnelError { + ConnectFailed(Box<dyn StdError + Send + Sync>), + Io(std::io::Error), + MissingHost, + ProxyAuthRequired, + ProxyHeadersTooLong, + TunnelUnexpectedEof, + TunnelUnsuccessful, +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly + // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct Tunneling<F, T> { + #[pin] + fut: BoxTunneling<T>, + _marker: PhantomData<F>, + } +} + +type BoxTunneling<T> = Pin<Box<dyn Future<Output = Result<T, TunnelError>> + Send>>; + +impl<C> Tunnel<C> { + /// Create a new Tunnel service. + /// + /// This wraps an underlying connector, and stores the address of a + /// tunneling proxy server. + /// + /// A `Tunnel` can then be called with any destination. The `dst` passed to + /// `call` will not be used to create the underlying connection, but will + /// be used in an HTTP CONNECT request sent to the proxy destination. + pub fn new(proxy_dst: Uri, connector: C) -> Self { + Self { + headers: Headers::Empty, + inner: connector, + proxy_dst, + } + } + + /// Add `proxy-authorization` header value to the CONNECT request. + pub fn with_auth(mut self, mut auth: HeaderValue) -> Self { + // just in case the user forgot + auth.set_sensitive(true); + match self.headers { + Headers::Empty => { + self.headers = Headers::Auth(auth); + } + Headers::Auth(ref mut existing) => { + *existing = auth; + } + Headers::Extra(ref mut extra) => { + extra.insert(http::header::PROXY_AUTHORIZATION, auth); + } + } + + self + } + + /// Add extra headers to be sent with the CONNECT request. + /// + /// If existing headers have been set, these will be merged. + pub fn with_headers(mut self, mut headers: HeaderMap) -> Self { + match self.headers { + Headers::Empty => { + self.headers = Headers::Extra(headers); + } + Headers::Auth(auth) => { + headers + .entry(http::header::PROXY_AUTHORIZATION) + .or_insert(auth); + self.headers = Headers::Extra(headers); + } + Headers::Extra(ref mut extra) => { + extra.extend(headers); + } + } + + self + } +} + +impl<C> Service<Uri> for Tunnel<C> +where + C: Service<Uri>, + C::Future: Send + 'static, + C::Response: Read + Write + Unpin + Send + 'static, + C::Error: Into<Box<dyn StdError + Send + Sync>>, +{ + type Response = C::Response; + type Error = TunnelError; + type Future = Tunneling<C::Future, C::Response>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + ready!(self.inner.poll_ready(cx)).map_err(|e| TunnelError::ConnectFailed(e.into()))?; + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let connecting = self.inner.call(self.proxy_dst.clone()); + let headers = self.headers.clone(); + + Tunneling { + fut: Box::pin(async move { + let conn = connecting + .await + .map_err(|e| TunnelError::ConnectFailed(e.into()))?; + tunnel( + conn, + dst.host().ok_or(TunnelError::MissingHost)?, + dst.port().map(|p| p.as_u16()).unwrap_or(443), + &headers, + ) + .await + }), + _marker: PhantomData, + } + } +} + +impl<F, T, E> Future for Tunneling<F, T> +where + F: Future<Output = Result<T, E>>, +{ + type Output = Result<T, TunnelError>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + self.project().fut.poll(cx) + } +} + +async fn tunnel<T>(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result<T, TunnelError> +where + T: Read + Write + Unpin, +{ + let mut buf = format!( + "\ + CONNECT {host}:{port} HTTP/1.1\r\n\ + Host: {host}:{port}\r\n\ + " + ) + .into_bytes(); + + match headers { + Headers::Auth(auth) => { + buf.extend_from_slice(b"Proxy-Authorization: "); + buf.extend_from_slice(auth.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + Headers::Extra(extra) => { + for (name, value) in extra { + buf.extend_from_slice(name.as_str().as_bytes()); + buf.extend_from_slice(b": "); + buf.extend_from_slice(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + } + Headers::Empty => (), + } + + // headers end + buf.extend_from_slice(b"\r\n"); + + crate::rt::write_all(&mut conn, &buf) + .await + .map_err(TunnelError::Io)?; + + let mut buf = [0; 8192]; + let mut pos = 0; + + loop { + let n = crate::rt::read(&mut conn, &mut buf[pos..]) + .await + .map_err(TunnelError::Io)?; + + if n == 0 { + return Err(TunnelError::TunnelUnexpectedEof); + } + pos += n; + + let recvd = &buf[..pos]; + if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") { + if recvd.ends_with(b"\r\n\r\n") { + return Ok(conn); + } + if pos == buf.len() { + return Err(TunnelError::ProxyHeadersTooLong); + } + // else read more + } else if recvd.starts_with(b"HTTP/1.1 407") { + return Err(TunnelError::ProxyAuthRequired); + } else { + return Err(TunnelError::TunnelUnsuccessful); + } + } +} + +impl std::fmt::Display for TunnelError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("tunnel error: ")?; + + f.write_str(match self { + TunnelError::MissingHost => "missing destination host", + TunnelError::ProxyAuthRequired => "proxy authorization required", + TunnelError::ProxyHeadersTooLong => "proxy response headers too long", + TunnelError::TunnelUnexpectedEof => "unexpected end of file", + TunnelError::TunnelUnsuccessful => "unsuccessful", + TunnelError::ConnectFailed(_) => "failed to create underlying connection", + TunnelError::Io(_) => "io error establishing tunnel", + }) + } +} + +impl std::error::Error for TunnelError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + TunnelError::Io(ref e) => Some(e), + TunnelError::ConnectFailed(ref e) => Some(&**e), + _ => None, + } + } +} diff --git a/vendor/hyper-util/src/client/legacy/mod.rs b/vendor/hyper-util/src/client/legacy/mod.rs new file mode 100644 index 00000000..1649ae7e --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/mod.rs @@ -0,0 +1,10 @@ +#[cfg(any(feature = "http1", feature = "http2"))] +mod client; +#[cfg(any(feature = "http1", feature = "http2"))] +pub use client::{Builder, Client, Error, ResponseFuture}; + +pub mod connect; +#[doc(hidden)] +// Publicly available, but just for legacy purposes. A better pool will be +// designed. +pub mod pool; diff --git a/vendor/hyper-util/src/client/legacy/pool.rs b/vendor/hyper-util/src/client/legacy/pool.rs new file mode 100644 index 00000000..727f54b2 --- /dev/null +++ b/vendor/hyper-util/src/client/legacy/pool.rs @@ -0,0 +1,1093 @@ +#![allow(dead_code)] + +use std::collections::{HashMap, HashSet, VecDeque}; +use std::convert::Infallible; +use std::error::Error as StdError; +use std::fmt::{self, Debug}; +use std::future::Future; +use std::hash::Hash; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::sync::{Arc, Mutex, Weak}; +use std::task::{self, Poll}; + +use std::time::{Duration, Instant}; + +use futures_channel::oneshot; +use futures_core::ready; +use tracing::{debug, trace}; + +use hyper::rt::Sleep; +use hyper::rt::Timer as _; + +use crate::common::{exec, exec::Exec, timer::Timer}; + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub struct Pool<T, K: Key> { + // If the pool is disabled, this is None. + inner: Option<Arc<Mutex<PoolInner<T, K>>>>, +} + +// Before using a pooled connection, make sure the sender is not dead. +// +// This is a trait to allow the `client::pool::tests` to work for `i32`. +// +// See https://github.com/hyperium/hyper/issues/1429 +pub trait Poolable: Unpin + Send + Sized + 'static { + fn is_open(&self) -> bool; + /// Reserve this connection. + /// + /// Allows for HTTP/2 to return a shared reservation. + fn reserve(self) -> Reservation<Self>; + fn can_share(&self) -> bool; +} + +pub trait Key: Eq + Hash + Clone + Debug + Unpin + Send + 'static {} + +impl<T> Key for T where T: Eq + Hash + Clone + Debug + Unpin + Send + 'static {} + +/// A marker to identify what version a pooled connection is. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[allow(dead_code)] +pub enum Ver { + Auto, + Http2, +} + +/// When checking out a pooled connection, it might be that the connection +/// only supports a single reservation, or it might be usable for many. +/// +/// Specifically, HTTP/1 requires a unique reservation, but HTTP/2 can be +/// used for multiple requests. +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub enum Reservation<T> { + /// This connection could be used multiple times, the first one will be + /// reinserted into the `idle` pool, and the second will be given to + /// the `Checkout`. + #[cfg(feature = "http2")] + Shared(T, T), + /// This connection requires unique access. It will be returned after + /// use is complete. + Unique(T), +} + +/// Simple type alias in case the key type needs to be adjusted. +// pub type Key = (http::uri::Scheme, http::uri::Authority); //Arc<String>; + +struct PoolInner<T, K: Eq + Hash> { + // A flag that a connection is being established, and the connection + // should be shared. This prevents making multiple HTTP/2 connections + // to the same host. + connecting: HashSet<K>, + // These are internal Conns sitting in the event loop in the KeepAlive + // state, waiting to receive a new Request to send on the socket. + idle: HashMap<K, Vec<Idle<T>>>, + max_idle_per_host: usize, + // These are outstanding Checkouts that are waiting for a socket to be + // able to send a Request one. This is used when "racing" for a new + // connection. + // + // The Client starts 2 tasks, 1 to connect a new socket, and 1 to wait + // for the Pool to receive an idle Conn. When a Conn becomes idle, + // this list is checked for any parked Checkouts, and tries to notify + // them that the Conn could be used instead of waiting for a brand new + // connection. + waiters: HashMap<K, VecDeque<oneshot::Sender<T>>>, + // A oneshot channel is used to allow the interval to be notified when + // the Pool completely drops. That way, the interval can cancel immediately. + idle_interval_ref: Option<oneshot::Sender<Infallible>>, + exec: Exec, + timer: Option<Timer>, + timeout: Option<Duration>, +} + +// This is because `Weak::new()` *allocates* space for `T`, even if it +// doesn't need it! +struct WeakOpt<T>(Option<Weak<T>>); + +#[derive(Clone, Copy, Debug)] +pub struct Config { + pub idle_timeout: Option<Duration>, + pub max_idle_per_host: usize, +} + +impl Config { + pub fn is_enabled(&self) -> bool { + self.max_idle_per_host > 0 + } +} + +impl<T, K: Key> Pool<T, K> { + pub fn new<E, M>(config: Config, executor: E, timer: Option<M>) -> Pool<T, K> + where + E: hyper::rt::Executor<exec::BoxSendFuture> + Send + Sync + Clone + 'static, + M: hyper::rt::Timer + Send + Sync + Clone + 'static, + { + let exec = Exec::new(executor); + let timer = timer.map(|t| Timer::new(t)); + let inner = if config.is_enabled() { + Some(Arc::new(Mutex::new(PoolInner { + connecting: HashSet::new(), + idle: HashMap::new(), + idle_interval_ref: None, + max_idle_per_host: config.max_idle_per_host, + waiters: HashMap::new(), + exec, + timer, + timeout: config.idle_timeout, + }))) + } else { + None + }; + + Pool { inner } + } + + pub(crate) fn is_enabled(&self) -> bool { + self.inner.is_some() + } + + #[cfg(test)] + pub(super) fn no_timer(&self) { + // Prevent an actual interval from being created for this pool... + { + let mut inner = self.inner.as_ref().unwrap().lock().unwrap(); + assert!(inner.idle_interval_ref.is_none(), "timer already spawned"); + let (tx, _) = oneshot::channel(); + inner.idle_interval_ref = Some(tx); + } + } +} + +impl<T: Poolable, K: Key> Pool<T, K> { + /// Returns a `Checkout` which is a future that resolves if an idle + /// connection becomes available. + pub fn checkout(&self, key: K) -> Checkout<T, K> { + Checkout { + key, + pool: self.clone(), + waiter: None, + } + } + + /// Ensure that there is only ever 1 connecting task for HTTP/2 + /// connections. This does nothing for HTTP/1. + pub fn connecting(&self, key: &K, ver: Ver) -> Option<Connecting<T, K>> { + if ver == Ver::Http2 { + if let Some(ref enabled) = self.inner { + let mut inner = enabled.lock().unwrap(); + return if inner.connecting.insert(key.clone()) { + let connecting = Connecting { + key: key.clone(), + pool: WeakOpt::downgrade(enabled), + }; + Some(connecting) + } else { + trace!("HTTP/2 connecting already in progress for {:?}", key); + None + }; + } + } + + // else + Some(Connecting { + key: key.clone(), + // in HTTP/1's case, there is never a lock, so we don't + // need to do anything in Drop. + pool: WeakOpt::none(), + }) + } + + #[cfg(test)] + fn locked(&self) -> std::sync::MutexGuard<'_, PoolInner<T, K>> { + self.inner.as_ref().expect("enabled").lock().expect("lock") + } + + /* Used in client/tests.rs... + #[cfg(test)] + pub(super) fn h1_key(&self, s: &str) -> Key { + Arc::new(s.to_string()) + } + + #[cfg(test)] + pub(super) fn idle_count(&self, key: &Key) -> usize { + self + .locked() + .idle + .get(key) + .map(|list| list.len()) + .unwrap_or(0) + } + */ + + pub fn pooled( + &self, + #[cfg_attr(not(feature = "http2"), allow(unused_mut))] mut connecting: Connecting<T, K>, + value: T, + ) -> Pooled<T, K> { + let (value, pool_ref) = if let Some(ref enabled) = self.inner { + match value.reserve() { + #[cfg(feature = "http2")] + Reservation::Shared(to_insert, to_return) => { + let mut inner = enabled.lock().unwrap(); + inner.put(connecting.key.clone(), to_insert, enabled); + // Do this here instead of Drop for Connecting because we + // already have a lock, no need to lock the mutex twice. + inner.connected(&connecting.key); + // prevent the Drop of Connecting from repeating inner.connected() + connecting.pool = WeakOpt::none(); + + // Shared reservations don't need a reference to the pool, + // since the pool always keeps a copy. + (to_return, WeakOpt::none()) + } + Reservation::Unique(value) => { + // Unique reservations must take a reference to the pool + // since they hope to reinsert once the reservation is + // completed + (value, WeakOpt::downgrade(enabled)) + } + } + } else { + // If pool is not enabled, skip all the things... + + // The Connecting should have had no pool ref + debug_assert!(connecting.pool.upgrade().is_none()); + + (value, WeakOpt::none()) + }; + Pooled { + key: connecting.key.clone(), + is_reused: false, + pool: pool_ref, + value: Some(value), + } + } + + fn reuse(&self, key: &K, value: T) -> Pooled<T, K> { + debug!("reuse idle connection for {:?}", key); + // TODO: unhack this + // In Pool::pooled(), which is used for inserting brand new connections, + // there's some code that adjusts the pool reference taken depending + // on if the Reservation can be shared or is unique. By the time + // reuse() is called, the reservation has already been made, and + // we just have the final value, without knowledge of if this is + // unique or shared. So, the hack is to just assume Ver::Http2 means + // shared... :( + let mut pool_ref = WeakOpt::none(); + if !value.can_share() { + if let Some(ref enabled) = self.inner { + pool_ref = WeakOpt::downgrade(enabled); + } + } + + Pooled { + is_reused: true, + key: key.clone(), + pool: pool_ref, + value: Some(value), + } + } +} + +/// Pop off this list, looking for a usable connection that hasn't expired. +struct IdlePopper<'a, T, K> { + key: &'a K, + list: &'a mut Vec<Idle<T>>, +} + +impl<'a, T: Poolable + 'a, K: Debug> IdlePopper<'a, T, K> { + fn pop(self, expiration: &Expiration) -> Option<Idle<T>> { + while let Some(entry) = self.list.pop() { + // If the connection has been closed, or is older than our idle + // timeout, simply drop it and keep looking... + if !entry.value.is_open() { + trace!("removing closed connection for {:?}", self.key); + continue; + } + // TODO: Actually, since the `idle` list is pushed to the end always, + // that would imply that if *this* entry is expired, then anything + // "earlier" in the list would *have* to be expired also... Right? + // + // In that case, we could just break out of the loop and drop the + // whole list... + if expiration.expires(entry.idle_at) { + trace!("removing expired connection for {:?}", self.key); + continue; + } + + let value = match entry.value.reserve() { + #[cfg(feature = "http2")] + Reservation::Shared(to_reinsert, to_checkout) => { + self.list.push(Idle { + idle_at: Instant::now(), + value: to_reinsert, + }); + to_checkout + } + Reservation::Unique(unique) => unique, + }; + + return Some(Idle { + idle_at: entry.idle_at, + value, + }); + } + + None + } +} + +impl<T: Poolable, K: Key> PoolInner<T, K> { + fn put(&mut self, key: K, value: T, __pool_ref: &Arc<Mutex<PoolInner<T, K>>>) { + if value.can_share() && self.idle.contains_key(&key) { + trace!("put; existing idle HTTP/2 connection for {:?}", key); + return; + } + trace!("put; add idle connection for {:?}", key); + let mut remove_waiters = false; + let mut value = Some(value); + if let Some(waiters) = self.waiters.get_mut(&key) { + while let Some(tx) = waiters.pop_front() { + if !tx.is_canceled() { + let reserved = value.take().expect("value already sent"); + let reserved = match reserved.reserve() { + #[cfg(feature = "http2")] + Reservation::Shared(to_keep, to_send) => { + value = Some(to_keep); + to_send + } + Reservation::Unique(uniq) => uniq, + }; + match tx.send(reserved) { + Ok(()) => { + if value.is_none() { + break; + } else { + continue; + } + } + Err(e) => { + value = Some(e); + } + } + } + + trace!("put; removing canceled waiter for {:?}", key); + } + remove_waiters = waiters.is_empty(); + } + if remove_waiters { + self.waiters.remove(&key); + } + + match value { + Some(value) => { + // borrow-check scope... + { + let idle_list = self.idle.entry(key.clone()).or_default(); + if self.max_idle_per_host <= idle_list.len() { + trace!("max idle per host for {:?}, dropping connection", key); + return; + } + + debug!("pooling idle connection for {:?}", key); + idle_list.push(Idle { + value, + idle_at: Instant::now(), + }); + } + + self.spawn_idle_interval(__pool_ref); + } + None => trace!("put; found waiter for {:?}", key), + } + } + + /// A `Connecting` task is complete. Not necessarily successfully, + /// but the lock is going away, so clean up. + fn connected(&mut self, key: &K) { + let existed = self.connecting.remove(key); + debug_assert!(existed, "Connecting dropped, key not in pool.connecting"); + // cancel any waiters. if there are any, it's because + // this Connecting task didn't complete successfully. + // those waiters would never receive a connection. + self.waiters.remove(key); + } + + fn spawn_idle_interval(&mut self, pool_ref: &Arc<Mutex<PoolInner<T, K>>>) { + if self.idle_interval_ref.is_some() { + return; + } + let dur = if let Some(dur) = self.timeout { + dur + } else { + return; + }; + let timer = if let Some(timer) = self.timer.clone() { + timer + } else { + return; + }; + let (tx, rx) = oneshot::channel(); + self.idle_interval_ref = Some(tx); + + let interval = IdleTask { + timer: timer.clone(), + duration: dur, + deadline: Instant::now(), + fut: timer.sleep_until(Instant::now()), // ready at first tick + pool: WeakOpt::downgrade(pool_ref), + pool_drop_notifier: rx, + }; + + self.exec.execute(interval); + } +} + +impl<T, K: Eq + Hash> PoolInner<T, K> { + /// Any `FutureResponse`s that were created will have made a `Checkout`, + /// and possibly inserted into the pool that it is waiting for an idle + /// connection. If a user ever dropped that future, we need to clean out + /// those parked senders. + fn clean_waiters(&mut self, key: &K) { + let mut remove_waiters = false; + if let Some(waiters) = self.waiters.get_mut(key) { + waiters.retain(|tx| !tx.is_canceled()); + remove_waiters = waiters.is_empty(); + } + if remove_waiters { + self.waiters.remove(key); + } + } +} + +impl<T: Poolable, K: Key> PoolInner<T, K> { + /// This should *only* be called by the IdleTask + fn clear_expired(&mut self) { + let dur = self.timeout.expect("interval assumes timeout"); + + let now = Instant::now(); + //self.last_idle_check_at = now; + + self.idle.retain(|key, values| { + values.retain(|entry| { + if !entry.value.is_open() { + trace!("idle interval evicting closed for {:?}", key); + return false; + } + + // Avoid `Instant::sub` to avoid issues like rust-lang/rust#86470. + if now.saturating_duration_since(entry.idle_at) > dur { + trace!("idle interval evicting expired for {:?}", key); + return false; + } + + // Otherwise, keep this value... + true + }); + + // returning false evicts this key/val + !values.is_empty() + }); + } +} + +impl<T, K: Key> Clone for Pool<T, K> { + fn clone(&self) -> Pool<T, K> { + Pool { + inner: self.inner.clone(), + } + } +} + +/// A wrapped poolable value that tries to reinsert to the Pool on Drop. +// Note: The bounds `T: Poolable` is needed for the Drop impl. +pub struct Pooled<T: Poolable, K: Key> { + value: Option<T>, + is_reused: bool, + key: K, + pool: WeakOpt<Mutex<PoolInner<T, K>>>, +} + +impl<T: Poolable, K: Key> Pooled<T, K> { + pub fn is_reused(&self) -> bool { + self.is_reused + } + + pub fn is_pool_enabled(&self) -> bool { + self.pool.0.is_some() + } + + fn as_ref(&self) -> &T { + self.value.as_ref().expect("not dropped") + } + + fn as_mut(&mut self) -> &mut T { + self.value.as_mut().expect("not dropped") + } +} + +impl<T: Poolable, K: Key> Deref for Pooled<T, K> { + type Target = T; + fn deref(&self) -> &T { + self.as_ref() + } +} + +impl<T: Poolable, K: Key> DerefMut for Pooled<T, K> { + fn deref_mut(&mut self) -> &mut T { + self.as_mut() + } +} + +impl<T: Poolable, K: Key> Drop for Pooled<T, K> { + fn drop(&mut self) { + if let Some(value) = self.value.take() { + if !value.is_open() { + // If we *already* know the connection is done here, + // it shouldn't be re-inserted back into the pool. + return; + } + + if let Some(pool) = self.pool.upgrade() { + if let Ok(mut inner) = pool.lock() { + inner.put(self.key.clone(), value, &pool); + } + } else if !value.can_share() { + trace!("pool dropped, dropping pooled ({:?})", self.key); + } + // Ver::Http2 is already in the Pool (or dead), so we wouldn't + // have an actual reference to the Pool. + } + } +} + +impl<T: Poolable, K: Key> fmt::Debug for Pooled<T, K> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Pooled").field("key", &self.key).finish() + } +} + +struct Idle<T> { + idle_at: Instant, + value: T, +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub struct Checkout<T, K: Key> { + key: K, + pool: Pool<T, K>, + waiter: Option<oneshot::Receiver<T>>, +} + +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + PoolDisabled, + CheckoutNoLongerWanted, + CheckedOutClosedValue, +} + +impl Error { + pub(super) fn is_canceled(&self) -> bool { + matches!(self, Error::CheckedOutClosedValue) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::PoolDisabled => "pool is disabled", + Error::CheckedOutClosedValue => "checked out connection was closed", + Error::CheckoutNoLongerWanted => "request was canceled", + }) + } +} + +impl StdError for Error {} + +impl<T: Poolable, K: Key> Checkout<T, K> { + fn poll_waiter( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Option<Result<Pooled<T, K>, Error>>> { + if let Some(mut rx) = self.waiter.take() { + match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(value)) => { + if value.is_open() { + Poll::Ready(Some(Ok(self.pool.reuse(&self.key, value)))) + } else { + Poll::Ready(Some(Err(Error::CheckedOutClosedValue))) + } + } + Poll::Pending => { + self.waiter = Some(rx); + Poll::Pending + } + Poll::Ready(Err(_canceled)) => { + Poll::Ready(Some(Err(Error::CheckoutNoLongerWanted))) + } + } + } else { + Poll::Ready(None) + } + } + + fn checkout(&mut self, cx: &mut task::Context<'_>) -> Option<Pooled<T, K>> { + let entry = { + let mut inner = self.pool.inner.as_ref()?.lock().unwrap(); + let expiration = Expiration::new(inner.timeout); + let maybe_entry = inner.idle.get_mut(&self.key).and_then(|list| { + trace!("take? {:?}: expiration = {:?}", self.key, expiration.0); + // A block to end the mutable borrow on list, + // so the map below can check is_empty() + { + let popper = IdlePopper { + key: &self.key, + list, + }; + popper.pop(&expiration) + } + .map(|e| (e, list.is_empty())) + }); + + let (entry, empty) = if let Some((e, empty)) = maybe_entry { + (Some(e), empty) + } else { + // No entry found means nuke the list for sure. + (None, true) + }; + if empty { + //TODO: This could be done with the HashMap::entry API instead. + inner.idle.remove(&self.key); + } + + if entry.is_none() && self.waiter.is_none() { + let (tx, mut rx) = oneshot::channel(); + trace!("checkout waiting for idle connection: {:?}", self.key); + inner + .waiters + .entry(self.key.clone()) + .or_insert_with(VecDeque::new) + .push_back(tx); + + // register the waker with this oneshot + assert!(Pin::new(&mut rx).poll(cx).is_pending()); + self.waiter = Some(rx); + } + + entry + }; + + entry.map(|e| self.pool.reuse(&self.key, e.value)) + } +} + +impl<T: Poolable, K: Key> Future for Checkout<T, K> { + type Output = Result<Pooled<T, K>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + if let Some(pooled) = ready!(self.poll_waiter(cx)?) { + return Poll::Ready(Ok(pooled)); + } + + if let Some(pooled) = self.checkout(cx) { + Poll::Ready(Ok(pooled)) + } else if !self.pool.is_enabled() { + Poll::Ready(Err(Error::PoolDisabled)) + } else { + // There's a new waiter, already registered in self.checkout() + debug_assert!(self.waiter.is_some()); + Poll::Pending + } + } +} + +impl<T, K: Key> Drop for Checkout<T, K> { + fn drop(&mut self) { + if self.waiter.take().is_some() { + trace!("checkout dropped for {:?}", self.key); + if let Some(Ok(mut inner)) = self.pool.inner.as_ref().map(|i| i.lock()) { + inner.clean_waiters(&self.key); + } + } + } +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +#[allow(missing_debug_implementations)] +pub struct Connecting<T: Poolable, K: Key> { + key: K, + pool: WeakOpt<Mutex<PoolInner<T, K>>>, +} + +impl<T: Poolable, K: Key> Connecting<T, K> { + pub fn alpn_h2(self, pool: &Pool<T, K>) -> Option<Self> { + debug_assert!( + self.pool.0.is_none(), + "Connecting::alpn_h2 but already Http2" + ); + + pool.connecting(&self.key, Ver::Http2) + } +} + +impl<T: Poolable, K: Key> Drop for Connecting<T, K> { + fn drop(&mut self) { + if let Some(pool) = self.pool.upgrade() { + // No need to panic on drop, that could abort! + if let Ok(mut inner) = pool.lock() { + inner.connected(&self.key); + } + } + } +} + +struct Expiration(Option<Duration>); + +impl Expiration { + fn new(dur: Option<Duration>) -> Expiration { + Expiration(dur) + } + + fn expires(&self, instant: Instant) -> bool { + match self.0 { + // Avoid `Instant::elapsed` to avoid issues like rust-lang/rust#86470. + Some(timeout) => Instant::now().saturating_duration_since(instant) > timeout, + None => false, + } + } +} + +pin_project_lite::pin_project! { + struct IdleTask<T, K: Key> { + timer: Timer, + duration: Duration, + deadline: Instant, + fut: Pin<Box<dyn Sleep>>, + pool: WeakOpt<Mutex<PoolInner<T, K>>>, + // This allows the IdleTask to be notified as soon as the entire + // Pool is fully dropped, and shutdown. This channel is never sent on, + // but Err(Canceled) will be received when the Pool is dropped. + #[pin] + pool_drop_notifier: oneshot::Receiver<Infallible>, + } +} + +impl<T: Poolable + 'static, K: Key> Future for IdleTask<T, K> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + loop { + match this.pool_drop_notifier.as_mut().poll(cx) { + Poll::Ready(Ok(n)) => match n {}, + Poll::Pending => (), + Poll::Ready(Err(_canceled)) => { + trace!("pool closed, canceling idle interval"); + return Poll::Ready(()); + } + } + + ready!(Pin::new(&mut this.fut).poll(cx)); + // Set this task to run after the next deadline + // If the poll missed the deadline by a lot, set the deadline + // from the current time instead + *this.deadline += *this.duration; + if *this.deadline < Instant::now() - Duration::from_millis(5) { + *this.deadline = Instant::now() + *this.duration; + } + *this.fut = this.timer.sleep_until(*this.deadline); + + if let Some(inner) = this.pool.upgrade() { + if let Ok(mut inner) = inner.lock() { + trace!("idle interval checking for expired"); + inner.clear_expired(); + continue; + } + } + return Poll::Ready(()); + } + } +} + +impl<T> WeakOpt<T> { + fn none() -> Self { + WeakOpt(None) + } + + fn downgrade(arc: &Arc<T>) -> Self { + WeakOpt(Some(Arc::downgrade(arc))) + } + + fn upgrade(&self) -> Option<Arc<T>> { + self.0.as_ref().and_then(Weak::upgrade) + } +} + +#[cfg(all(test, not(miri)))] +mod tests { + use std::fmt::Debug; + use std::future::Future; + use std::hash::Hash; + use std::pin::Pin; + use std::task::{self, Poll}; + use std::time::Duration; + + use super::{Connecting, Key, Pool, Poolable, Reservation, WeakOpt}; + use crate::rt::{TokioExecutor, TokioTimer}; + + use crate::common::timer; + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct KeyImpl(http::uri::Scheme, http::uri::Authority); + + type KeyTuple = (http::uri::Scheme, http::uri::Authority); + + /// Test unique reservations. + #[derive(Debug, PartialEq, Eq)] + struct Uniq<T>(T); + + impl<T: Send + 'static + Unpin> Poolable for Uniq<T> { + fn is_open(&self) -> bool { + true + } + + fn reserve(self) -> Reservation<Self> { + Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + } + } + + fn c<T: Poolable, K: Key>(key: K) -> Connecting<T, K> { + Connecting { + key, + pool: WeakOpt::none(), + } + } + + fn host_key(s: &str) -> KeyImpl { + KeyImpl(http::uri::Scheme::HTTP, s.parse().expect("host key")) + } + + fn pool_no_timer<T, K: Key>() -> Pool<T, K> { + pool_max_idle_no_timer(usize::MAX) + } + + fn pool_max_idle_no_timer<T, K: Key>(max_idle: usize) -> Pool<T, K> { + let pool = Pool::new( + super::Config { + idle_timeout: Some(Duration::from_millis(100)), + max_idle_per_host: max_idle, + }, + TokioExecutor::new(), + Option::<timer::Timer>::None, + ); + pool.no_timer(); + pool + } + + #[tokio::test] + async fn test_pool_checkout_smoke() { + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + drop(pooled); + + match pool.checkout(key).await { + Ok(pooled) => assert_eq!(*pooled, Uniq(41)), + Err(_) => panic!("not ready"), + }; + } + + /// Helper to check if the future is ready after polling once. + struct PollOnce<'a, F>(&'a mut F); + + impl<F, T, U> Future for PollOnce<'_, F> + where + F: Future<Output = Result<T, U>> + Unpin, + { + type Output = Option<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Some(())), + Poll::Ready(Err(_)) => Poll::Ready(Some(())), + Poll::Pending => Poll::Ready(None), + } + } + } + + #[tokio::test] + async fn test_pool_checkout_returns_none_if_expired() { + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + drop(pooled); + tokio::time::sleep(pool.locked().timeout.unwrap()).await; + let mut checkout = pool.checkout(key); + let poll_once = PollOnce(&mut checkout); + let is_not_ready = poll_once.await.is_none(); + assert!(is_not_ready); + } + + #[tokio::test] + async fn test_pool_checkout_removes_expired() { + let pool = pool_no_timer(); + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(3) + ); + tokio::time::sleep(pool.locked().timeout.unwrap()).await; + + let mut checkout = pool.checkout(key.clone()); + let poll_once = PollOnce(&mut checkout); + // checkout.await should clean out the expired + poll_once.await; + assert!(!pool.locked().idle.contains_key(&key)); + } + + #[test] + fn test_pool_max_idle_per_host() { + let pool = pool_max_idle_no_timer(2); + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + // pooled and dropped 3, max_idle should only allow 2 + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(2) + ); + } + + #[tokio::test] + async fn test_pool_timer_removes_expired() { + let pool = Pool::new( + super::Config { + idle_timeout: Some(Duration::from_millis(10)), + max_idle_per_host: usize::MAX, + }, + TokioExecutor::new(), + Some(TokioTimer::new()), + ); + + let key = host_key("foo"); + + pool.pooled(c(key.clone()), Uniq(41)); + pool.pooled(c(key.clone()), Uniq(5)); + pool.pooled(c(key.clone()), Uniq(99)); + + assert_eq!( + pool.locked().idle.get(&key).map(|entries| entries.len()), + Some(3) + ); + + // Let the timer tick passed the expiration... + tokio::time::sleep(Duration::from_millis(30)).await; + // Yield so the Interval can reap... + tokio::task::yield_now().await; + + assert!(!pool.locked().idle.contains_key(&key)); + } + + #[tokio::test] + async fn test_pool_checkout_task_unparked() { + use futures_util::future::join; + use futures_util::FutureExt; + + let pool = pool_no_timer(); + let key = host_key("foo"); + let pooled = pool.pooled(c(key.clone()), Uniq(41)); + + let checkout = join(pool.checkout(key), async { + // the checkout future will park first, + // and then this lazy future will be polled, which will insert + // the pooled back into the pool + // + // this test makes sure that doing so will unpark the checkout + drop(pooled); + }) + .map(|(entry, _)| entry); + + assert_eq!(*checkout.await.unwrap(), Uniq(41)); + } + + #[tokio::test] + async fn test_pool_checkout_drop_cleans_up_waiters() { + let pool = pool_no_timer::<Uniq<i32>, KeyImpl>(); + let key = host_key("foo"); + + let mut checkout1 = pool.checkout(key.clone()); + let mut checkout2 = pool.checkout(key.clone()); + + let poll_once1 = PollOnce(&mut checkout1); + let poll_once2 = PollOnce(&mut checkout2); + + // first poll needed to get into Pool's parked + poll_once1.await; + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 1); + poll_once2.await; + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 2); + + // on drop, clean up Pool + drop(checkout1); + assert_eq!(pool.locked().waiters.get(&key).unwrap().len(), 1); + + drop(checkout2); + assert!(!pool.locked().waiters.contains_key(&key)); + } + + #[derive(Debug)] + struct CanClose { + #[allow(unused)] + val: i32, + closed: bool, + } + + impl Poolable for CanClose { + fn is_open(&self) -> bool { + !self.closed + } + + fn reserve(self) -> Reservation<Self> { + Reservation::Unique(self) + } + + fn can_share(&self) -> bool { + false + } + } + + #[test] + fn pooled_drop_if_closed_doesnt_reinsert() { + let pool = pool_no_timer(); + let key = host_key("foo"); + pool.pooled( + c(key.clone()), + CanClose { + val: 57, + closed: true, + }, + ); + + assert!(!pool.locked().idle.contains_key(&key)); + } +} diff --git a/vendor/hyper-util/src/client/mod.rs b/vendor/hyper-util/src/client/mod.rs new file mode 100644 index 00000000..0d896030 --- /dev/null +++ b/vendor/hyper-util/src/client/mod.rs @@ -0,0 +1,8 @@ +//! HTTP client utilities + +/// Legacy implementations of `connect` module and `Client` +#[cfg(feature = "client-legacy")] +pub mod legacy; + +#[cfg(feature = "client-proxy")] +pub mod proxy; diff --git a/vendor/hyper-util/src/client/proxy/matcher.rs b/vendor/hyper-util/src/client/proxy/matcher.rs new file mode 100644 index 00000000..fd563bca --- /dev/null +++ b/vendor/hyper-util/src/client/proxy/matcher.rs @@ -0,0 +1,848 @@ +//! Proxy matchers +//! +//! This module contains different matchers to configure rules for when a proxy +//! should be used, and if so, with what arguments. +//! +//! A [`Matcher`] can be constructed either using environment variables, or +//! a [`Matcher::builder()`]. +//! +//! Once constructed, the `Matcher` can be asked if it intercepts a `Uri` by +//! calling [`Matcher::intercept()`]. +//! +//! An [`Intercept`] includes the destination for the proxy, and any parsed +//! authentication to be used. + +use std::fmt; +use std::net::IpAddr; + +use http::header::HeaderValue; +use ipnet::IpNet; +use percent_encoding::percent_decode_str; + +#[cfg(docsrs)] +pub use self::builder::IntoValue; +#[cfg(not(docsrs))] +use self::builder::IntoValue; + +/// A proxy matcher, usually built from environment variables. +pub struct Matcher { + http: Option<Intercept>, + https: Option<Intercept>, + no: NoProxy, +} + +/// A matched proxy, +/// +/// This is returned by a matcher if a proxy should be used. +#[derive(Clone)] +pub struct Intercept { + uri: http::Uri, + auth: Auth, +} + +/// A builder to create a [`Matcher`]. +/// +/// Construct with [`Matcher::builder()`]. +#[derive(Default)] +pub struct Builder { + is_cgi: bool, + all: String, + http: String, + https: String, + no: String, +} + +#[derive(Clone)] +enum Auth { + Empty, + Basic(http::header::HeaderValue), + Raw(String, String), +} + +/// A filter for proxy matchers. +/// +/// This type is based off the `NO_PROXY` rules used by curl. +#[derive(Clone, Debug, Default)] +struct NoProxy { + ips: IpMatcher, + domains: DomainMatcher, +} + +#[derive(Clone, Debug, Default)] +struct DomainMatcher(Vec<String>); + +#[derive(Clone, Debug, Default)] +struct IpMatcher(Vec<Ip>); + +#[derive(Clone, Debug)] +enum Ip { + Address(IpAddr), + Network(IpNet), +} + +// ===== impl Matcher ===== + +impl Matcher { + /// Create a matcher reading the current environment variables. + /// + /// This checks for values in the following variables, treating them the + /// same as curl does: + /// + /// - `ALL_PROXY`/`all_proxy` + /// - `HTTPS_PROXY`/`https_proxy` + /// - `HTTP_PROXY`/`http_proxy` + /// - `NO_PROXY`/`no_proxy` + pub fn from_env() -> Self { + Builder::from_env().build() + } + + /// Create a matcher from the environment or system. + /// + /// This checks the same environment variables as `from_env()`, and if not + /// set, checks the system configuration for values for the OS. + /// + /// This constructor is always available, but if the `client-proxy-system` + /// feature is enabled, it will check more configuration. Use this + /// constructor if you want to allow users to optionally enable more, or + /// use `from_env` if you do not want the values to change based on an + /// enabled feature. + pub fn from_system() -> Self { + Builder::from_system().build() + } + + /// Start a builder to configure a matcher. + pub fn builder() -> Builder { + Builder::default() + } + + /// Check if the destination should be intercepted by a proxy. + /// + /// If the proxy rules match the destination, a new `Uri` will be returned + /// to connect to. + pub fn intercept(&self, dst: &http::Uri) -> Option<Intercept> { + // TODO(perf): don't need to check `no` if below doesn't match... + if self.no.contains(dst.host()?) { + return None; + } + + match dst.scheme_str() { + Some("http") => self.http.clone(), + Some("https") => self.https.clone(), + _ => None, + } + } +} + +impl fmt::Debug for Matcher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut b = f.debug_struct("Matcher"); + + if let Some(ref http) = self.http { + b.field("http", http); + } + + if let Some(ref https) = self.https { + b.field("https", https); + } + + if !self.no.is_empty() { + b.field("no", &self.no); + } + b.finish() + } +} + +// ===== impl Intercept ===== + +impl Intercept { + /// Get the `http::Uri` for the target proxy. + pub fn uri(&self) -> &http::Uri { + &self.uri + } + + /// Get any configured basic authorization. + /// + /// This should usually be used with a `Proxy-Authorization` header, to + /// send in Basic format. + /// + /// # Example + /// + /// ```rust + /// # use hyper_util::client::proxy::matcher::Matcher; + /// # let uri = http::Uri::from_static("https://hyper.rs"); + /// let m = Matcher::builder() + /// .all("https://Aladdin:opensesame@localhost:8887") + /// .build(); + /// + /// let proxy = m.intercept(&uri).expect("example"); + /// let auth = proxy.basic_auth().expect("example"); + /// assert_eq!(auth, "Basic QWxhZGRpbjpvcGVuc2VzYW1l"); + /// ``` + pub fn basic_auth(&self) -> Option<&HeaderValue> { + if let Auth::Basic(ref val) = self.auth { + Some(val) + } else { + None + } + } + + /// Get any configured raw authorization. + /// + /// If not detected as another scheme, this is the username and password + /// that should be sent with whatever protocol the proxy handshake uses. + /// + /// # Example + /// + /// ```rust + /// # use hyper_util::client::proxy::matcher::Matcher; + /// # let uri = http::Uri::from_static("https://hyper.rs"); + /// let m = Matcher::builder() + /// .all("socks5h://Aladdin:opensesame@localhost:8887") + /// .build(); + /// + /// let proxy = m.intercept(&uri).expect("example"); + /// let auth = proxy.raw_auth().expect("example"); + /// assert_eq!(auth, ("Aladdin", "opensesame")); + /// ``` + pub fn raw_auth(&self) -> Option<(&str, &str)> { + if let Auth::Raw(ref u, ref p) = self.auth { + Some((u.as_str(), p.as_str())) + } else { + None + } + } +} + +impl fmt::Debug for Intercept { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Intercept") + .field("uri", &self.uri) + // dont output auth, its sensitive + .finish() + } +} + +// ===== impl Builder ===== + +impl Builder { + fn from_env() -> Self { + Builder { + is_cgi: std::env::var_os("REQUEST_METHOD").is_some(), + all: get_first_env(&["ALL_PROXY", "all_proxy"]), + http: get_first_env(&["HTTP_PROXY", "http_proxy"]), + https: get_first_env(&["HTTPS_PROXY", "https_proxy"]), + no: get_first_env(&["NO_PROXY", "no_proxy"]), + } + } + + fn from_system() -> Self { + #[allow(unused_mut)] + let mut builder = Self::from_env(); + + #[cfg(all(feature = "client-proxy-system", target_os = "macos"))] + mac::with_system(&mut builder); + + #[cfg(all(feature = "client-proxy-system", windows))] + win::with_system(&mut builder); + + builder + } + + /// Set the target proxy for all destinations. + pub fn all<S>(mut self, val: S) -> Self + where + S: IntoValue, + { + self.all = val.into_value(); + self + } + + /// Set the target proxy for HTTP destinations. + pub fn http<S>(mut self, val: S) -> Self + where + S: IntoValue, + { + self.http = val.into_value(); + self + } + + /// Set the target proxy for HTTPS destinations. + pub fn https<S>(mut self, val: S) -> Self + where + S: IntoValue, + { + self.https = val.into_value(); + self + } + + /// Set the "no" proxy filter. + /// + /// The rules are as follows: + /// * Entries are expected to be comma-separated (whitespace between entries is ignored) + /// * IP addresses (both IPv4 and IPv6) are allowed, as are optional subnet masks (by adding /size, + /// for example "`192.168.1.0/24`"). + /// * An entry "`*`" matches all hostnames (this is the only wildcard allowed) + /// * Any other entry is considered a domain name (and may contain a leading dot, for example `google.com` + /// and `.google.com` are equivalent) and would match both that domain AND all subdomains. + /// + /// For example, if `"NO_PROXY=google.com, 192.168.1.0/24"` was set, all of the following would match + /// (and therefore would bypass the proxy): + /// * `http://google.com/` + /// * `http://www.google.com/` + /// * `http://192.168.1.42/` + /// + /// The URL `http://notgoogle.com/` would not match. + pub fn no<S>(mut self, val: S) -> Self + where + S: IntoValue, + { + self.no = val.into_value(); + self + } + + /// Construct a [`Matcher`] using the configured values. + pub fn build(self) -> Matcher { + if self.is_cgi { + return Matcher { + http: None, + https: None, + no: NoProxy::empty(), + }; + } + + let all = parse_env_uri(&self.all); + + Matcher { + http: parse_env_uri(&self.http).or_else(|| all.clone()), + https: parse_env_uri(&self.https).or(all), + no: NoProxy::from_string(&self.no), + } + } +} + +fn get_first_env(names: &[&str]) -> String { + for name in names { + if let Ok(val) = std::env::var(name) { + return val; + } + } + + String::new() +} + +fn parse_env_uri(val: &str) -> Option<Intercept> { + let uri = val.parse::<http::Uri>().ok()?; + let mut builder = http::Uri::builder(); + let mut is_httpish = false; + let mut auth = Auth::Empty; + + builder = builder.scheme(match uri.scheme() { + Some(s) => { + if s == &http::uri::Scheme::HTTP || s == &http::uri::Scheme::HTTPS { + is_httpish = true; + s.clone() + } else if s.as_str() == "socks5" || s.as_str() == "socks5h" { + s.clone() + } else { + // can't use this proxy scheme + return None; + } + } + // if no scheme provided, assume they meant 'http' + None => { + is_httpish = true; + http::uri::Scheme::HTTP + } + }); + + let authority = uri.authority()?; + + if let Some((userinfo, host_port)) = authority.as_str().split_once('@') { + let (user, pass) = userinfo.split_once(':')?; + let user = percent_decode_str(user).decode_utf8_lossy(); + let pass = percent_decode_str(pass).decode_utf8_lossy(); + if is_httpish { + auth = Auth::Basic(encode_basic_auth(&user, Some(&pass))); + } else { + auth = Auth::Raw(user.into(), pass.into()); + } + builder = builder.authority(host_port); + } else { + builder = builder.authority(authority.clone()); + } + + // removing any path, but we MUST specify one or the builder errors + builder = builder.path_and_query("/"); + + let dst = builder.build().ok()?; + + Some(Intercept { uri: dst, auth }) +} + +fn encode_basic_auth(user: &str, pass: Option<&str>) -> HeaderValue { + use base64::prelude::BASE64_STANDARD; + use base64::write::EncoderWriter; + use std::io::Write; + + let mut buf = b"Basic ".to_vec(); + { + let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD); + let _ = write!(encoder, "{user}:"); + if let Some(password) = pass { + let _ = write!(encoder, "{password}"); + } + } + let mut header = HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue"); + header.set_sensitive(true); + header +} + +impl NoProxy { + /* + fn from_env() -> NoProxy { + let raw = std::env::var("NO_PROXY") + .or_else(|_| std::env::var("no_proxy")) + .unwrap_or_default(); + + Self::from_string(&raw) + } + */ + + fn empty() -> NoProxy { + NoProxy { + ips: IpMatcher(Vec::new()), + domains: DomainMatcher(Vec::new()), + } + } + + /// Returns a new no-proxy configuration based on a `no_proxy` string (or `None` if no variables + /// are set) + /// The rules are as follows: + /// * The environment variable `NO_PROXY` is checked, if it is not set, `no_proxy` is checked + /// * If neither environment variable is set, `None` is returned + /// * Entries are expected to be comma-separated (whitespace between entries is ignored) + /// * IP addresses (both IPv4 and IPv6) are allowed, as are optional subnet masks (by adding /size, + /// for example "`192.168.1.0/24`"). + /// * An entry "`*`" matches all hostnames (this is the only wildcard allowed) + /// * Any other entry is considered a domain name (and may contain a leading dot, for example `google.com` + /// and `.google.com` are equivalent) and would match both that domain AND all subdomains. + /// + /// For example, if `"NO_PROXY=google.com, 192.168.1.0/24"` was set, all of the following would match + /// (and therefore would bypass the proxy): + /// * `http://google.com/` + /// * `http://www.google.com/` + /// * `http://192.168.1.42/` + /// + /// The URL `http://notgoogle.com/` would not match. + pub fn from_string(no_proxy_list: &str) -> Self { + let mut ips = Vec::new(); + let mut domains = Vec::new(); + let parts = no_proxy_list.split(',').map(str::trim); + for part in parts { + match part.parse::<IpNet>() { + // If we can parse an IP net or address, then use it, otherwise, assume it is a domain + Ok(ip) => ips.push(Ip::Network(ip)), + Err(_) => match part.parse::<IpAddr>() { + Ok(addr) => ips.push(Ip::Address(addr)), + Err(_) => { + if !part.trim().is_empty() { + domains.push(part.to_owned()) + } + } + }, + } + } + NoProxy { + ips: IpMatcher(ips), + domains: DomainMatcher(domains), + } + } + + /// Return true if this matches the host (domain or IP). + pub fn contains(&self, host: &str) -> bool { + // According to RFC3986, raw IPv6 hosts will be wrapped in []. So we need to strip those off + // the end in order to parse correctly + let host = if host.starts_with('[') { + let x: &[_] = &['[', ']']; + host.trim_matches(x) + } else { + host + }; + match host.parse::<IpAddr>() { + // If we can parse an IP addr, then use it, otherwise, assume it is a domain + Ok(ip) => self.ips.contains(ip), + Err(_) => self.domains.contains(host), + } + } + + fn is_empty(&self) -> bool { + self.ips.0.is_empty() && self.domains.0.is_empty() + } +} + +impl IpMatcher { + fn contains(&self, addr: IpAddr) -> bool { + for ip in &self.0 { + match ip { + Ip::Address(address) => { + if &addr == address { + return true; + } + } + Ip::Network(net) => { + if net.contains(&addr) { + return true; + } + } + } + } + false + } +} + +impl DomainMatcher { + // The following links may be useful to understand the origin of these rules: + // * https://curl.se/libcurl/c/CURLOPT_NOPROXY.html + // * https://github.com/curl/curl/issues/1208 + fn contains(&self, domain: &str) -> bool { + let domain_len = domain.len(); + for d in &self.0 { + if d == domain || d.strip_prefix('.') == Some(domain) { + return true; + } else if domain.ends_with(d) { + if d.starts_with('.') { + // If the first character of d is a dot, that means the first character of domain + // must also be a dot, so we are looking at a subdomain of d and that matches + return true; + } else if domain.as_bytes().get(domain_len - d.len() - 1) == Some(&b'.') { + // Given that d is a prefix of domain, if the prior character in domain is a dot + // then that means we must be matching a subdomain of d, and that matches + return true; + } + } else if d == "*" { + return true; + } + } + false + } +} + +mod builder { + /// A type that can used as a `Builder` value. + /// + /// Private and sealed, only visible in docs. + pub trait IntoValue { + #[doc(hidden)] + fn into_value(self) -> String; + } + + impl IntoValue for String { + #[doc(hidden)] + fn into_value(self) -> String { + self + } + } + + impl IntoValue for &String { + #[doc(hidden)] + fn into_value(self) -> String { + self.into() + } + } + + impl IntoValue for &str { + #[doc(hidden)] + fn into_value(self) -> String { + self.into() + } + } +} + +#[cfg(feature = "client-proxy-system")] +#[cfg(target_os = "macos")] +mod mac { + use system_configuration::core_foundation::base::{CFType, TCFType, TCFTypeRef}; + use system_configuration::core_foundation::dictionary::CFDictionary; + use system_configuration::core_foundation::number::CFNumber; + use system_configuration::core_foundation::string::{CFString, CFStringRef}; + use system_configuration::dynamic_store::SCDynamicStoreBuilder; + use system_configuration::sys::schema_definitions::{ + kSCPropNetProxiesHTTPEnable, kSCPropNetProxiesHTTPPort, kSCPropNetProxiesHTTPProxy, + kSCPropNetProxiesHTTPSEnable, kSCPropNetProxiesHTTPSPort, kSCPropNetProxiesHTTPSProxy, + }; + + pub(super) fn with_system(builder: &mut super::Builder) { + let store = SCDynamicStoreBuilder::new("").build(); + + let proxies_map = if let Some(proxies_map) = store.get_proxies() { + proxies_map + } else { + return; + }; + + if builder.http.is_empty() { + let http_proxy_config = parse_setting_from_dynamic_store( + &proxies_map, + unsafe { kSCPropNetProxiesHTTPEnable }, + unsafe { kSCPropNetProxiesHTTPProxy }, + unsafe { kSCPropNetProxiesHTTPPort }, + ); + if let Some(http) = http_proxy_config { + builder.http = http; + } + } + + if builder.https.is_empty() { + let https_proxy_config = parse_setting_from_dynamic_store( + &proxies_map, + unsafe { kSCPropNetProxiesHTTPSEnable }, + unsafe { kSCPropNetProxiesHTTPSProxy }, + unsafe { kSCPropNetProxiesHTTPSPort }, + ); + + if let Some(https) = https_proxy_config { + builder.https = https; + } + } + } + + fn parse_setting_from_dynamic_store( + proxies_map: &CFDictionary<CFString, CFType>, + enabled_key: CFStringRef, + host_key: CFStringRef, + port_key: CFStringRef, + ) -> Option<String> { + let proxy_enabled = proxies_map + .find(enabled_key) + .and_then(|flag| flag.downcast::<CFNumber>()) + .and_then(|flag| flag.to_i32()) + .unwrap_or(0) + == 1; + + if proxy_enabled { + let proxy_host = proxies_map + .find(host_key) + .and_then(|host| host.downcast::<CFString>()) + .map(|host| host.to_string()); + let proxy_port = proxies_map + .find(port_key) + .and_then(|port| port.downcast::<CFNumber>()) + .and_then(|port| port.to_i32()); + + return match (proxy_host, proxy_port) { + (Some(proxy_host), Some(proxy_port)) => Some(format!("{proxy_host}:{proxy_port}")), + (Some(proxy_host), None) => Some(proxy_host), + (None, Some(_)) => None, + (None, None) => None, + }; + } + + None + } +} + +#[cfg(feature = "client-proxy-system")] +#[cfg(windows)] +mod win { + pub(super) fn with_system(builder: &mut super::Builder) { + let settings = if let Ok(settings) = windows_registry::CURRENT_USER + .open("Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings") + { + settings + } else { + return; + }; + + if settings.get_u32("ProxyEnable").unwrap_or(0) == 0 { + return; + } + + if let Ok(val) = settings.get_string("ProxyServer") { + if builder.http.is_empty() { + builder.http = val.clone(); + } + if builder.https.is_empty() { + builder.https = val; + } + } + + if builder.no.is_empty() { + if let Ok(val) = settings.get_string("ProxyOverride") { + builder.no = val + .split(';') + .map(|s| s.trim()) + .collect::<Vec<&str>>() + .join(",") + .replace("*.", ""); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_domain_matcher() { + let domains = vec![".foo.bar".into(), "bar.foo".into()]; + let matcher = DomainMatcher(domains); + + // domains match with leading `.` + assert!(matcher.contains("foo.bar")); + // subdomains match with leading `.` + assert!(matcher.contains("www.foo.bar")); + + // domains match with no leading `.` + assert!(matcher.contains("bar.foo")); + // subdomains match with no leading `.` + assert!(matcher.contains("www.bar.foo")); + + // non-subdomain string prefixes don't match + assert!(!matcher.contains("notfoo.bar")); + assert!(!matcher.contains("notbar.foo")); + } + + #[test] + fn test_no_proxy_wildcard() { + let no_proxy = NoProxy::from_string("*"); + assert!(no_proxy.contains("any.where")); + } + + #[test] + fn test_no_proxy_ip_ranges() { + let no_proxy = + NoProxy::from_string(".foo.bar, bar.baz,10.42.1.1/24,::1,10.124.7.8,2001::/17"); + + let should_not_match = [ + // random url, not in no_proxy + "hyper.rs", + // make sure that random non-subdomain string prefixes don't match + "notfoo.bar", + // make sure that random non-subdomain string prefixes don't match + "notbar.baz", + // ipv4 address out of range + "10.43.1.1", + // ipv4 address out of range + "10.124.7.7", + // ipv6 address out of range + "[ffff:db8:a0b:12f0::1]", + // ipv6 address out of range + "[2005:db8:a0b:12f0::1]", + ]; + + for host in &should_not_match { + assert!(!no_proxy.contains(host), "should not contain {host:?}"); + } + + let should_match = [ + // make sure subdomains (with leading .) match + "hello.foo.bar", + // make sure exact matches (without leading .) match (also makes sure spaces between entries work) + "bar.baz", + // make sure subdomains (without leading . in no_proxy) match + "foo.bar.baz", + // make sure subdomains (without leading . in no_proxy) match - this differs from cURL + "foo.bar", + // ipv4 address match within range + "10.42.1.100", + // ipv6 address exact match + "[::1]", + // ipv6 address match within range + "[2001:db8:a0b:12f0::1]", + // ipv4 address exact match + "10.124.7.8", + ]; + + for host in &should_match { + assert!(no_proxy.contains(host), "should contain {host:?}"); + } + } + + macro_rules! p { + ($($n:ident = $v:expr,)*) => ({Builder { + $($n: $v.into(),)* + ..Builder::default() + }.build()}); + } + + fn intercept(p: &Matcher, u: &str) -> Intercept { + p.intercept(&u.parse().unwrap()).unwrap() + } + + #[test] + fn test_all_proxy() { + let p = p! { + all = "http://om.nom", + }; + + assert_eq!("http://om.nom", intercept(&p, "http://example.com").uri()); + + assert_eq!("http://om.nom", intercept(&p, "https://example.com").uri()); + } + + #[test] + fn test_specific_overrides_all() { + let p = p! { + all = "http://no.pe", + http = "http://y.ep", + }; + + assert_eq!("http://no.pe", intercept(&p, "https://example.com").uri()); + + // the http rule is "more specific" than the all rule + assert_eq!("http://y.ep", intercept(&p, "http://example.com").uri()); + } + + #[test] + fn test_parse_no_scheme_defaults_to_http() { + let p = p! { + https = "y.ep", + http = "127.0.0.1:8887", + }; + + assert_eq!(intercept(&p, "https://example.local").uri(), "http://y.ep"); + assert_eq!( + intercept(&p, "http://example.local").uri(), + "http://127.0.0.1:8887" + ); + } + + #[test] + fn test_parse_http_auth() { + let p = p! { + all = "http://Aladdin:opensesame@y.ep", + }; + + let proxy = intercept(&p, "https://example.local"); + assert_eq!(proxy.uri(), "http://y.ep"); + assert_eq!( + proxy.basic_auth().expect("basic_auth"), + "Basic QWxhZGRpbjpvcGVuc2VzYW1l" + ); + } + + #[test] + fn test_parse_http_auth_without_scheme() { + let p = p! { + all = "Aladdin:opensesame@y.ep", + }; + + let proxy = intercept(&p, "https://example.local"); + assert_eq!(proxy.uri(), "http://y.ep"); + assert_eq!( + proxy.basic_auth().expect("basic_auth"), + "Basic QWxhZGRpbjpvcGVuc2VzYW1l" + ); + } + + #[test] + fn test_dont_parse_http_when_is_cgi() { + let mut builder = Matcher::builder(); + builder.is_cgi = true; + builder.http = "http://never.gonna.let.you.go".into(); + let m = builder.build(); + + assert!(m.intercept(&"http://rick.roll".parse().unwrap()).is_none()); + } +} diff --git a/vendor/hyper-util/src/client/proxy/mod.rs b/vendor/hyper-util/src/client/proxy/mod.rs new file mode 100644 index 00000000..59c8e46d --- /dev/null +++ b/vendor/hyper-util/src/client/proxy/mod.rs @@ -0,0 +1,3 @@ +//! Proxy utilities + +pub mod matcher; diff --git a/vendor/hyper-util/src/client/service.rs b/vendor/hyper-util/src/client/service.rs new file mode 100644 index 00000000..580fb105 --- /dev/null +++ b/vendor/hyper-util/src/client/service.rs @@ -0,0 +1,8 @@ +struct ConnectingPool<C, P> { + connector: C, + pool: P, +} + +struct PoolableSvc<S>(S); + + diff --git a/vendor/hyper-util/src/common/exec.rs b/vendor/hyper-util/src/common/exec.rs new file mode 100644 index 00000000..40860ee1 --- /dev/null +++ b/vendor/hyper-util/src/common/exec.rs @@ -0,0 +1,53 @@ +#![allow(dead_code)] + +use hyper::rt::Executor; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +pub(crate) type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>; + +// Either the user provides an executor for background tasks, or we use +// `tokio::spawn`. +#[derive(Clone)] +pub(crate) enum Exec { + Executor(Arc<dyn Executor<BoxSendFuture> + Send + Sync>), +} + +// ===== impl Exec ===== + +impl Exec { + pub(crate) fn new<E>(inner: E) -> Self + where + E: Executor<BoxSendFuture> + Send + Sync + 'static, + { + Exec::Executor(Arc::new(inner)) + } + + pub(crate) fn execute<F>(&self, fut: F) + where + F: Future<Output = ()> + Send + 'static, + { + match *self { + Exec::Executor(ref e) => { + e.execute(Box::pin(fut)); + } + } + } +} + +impl fmt::Debug for Exec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Exec").finish() + } +} + +impl<F> hyper::rt::Executor<F> for Exec +where + F: Future<Output = ()> + Send + 'static, +{ + fn execute(&self, fut: F) { + Exec::execute(self, fut); + } +} diff --git a/vendor/hyper-util/src/common/future.rs b/vendor/hyper-util/src/common/future.rs new file mode 100644 index 00000000..47897f24 --- /dev/null +++ b/vendor/hyper-util/src/common/future.rs @@ -0,0 +1,30 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +// TODO: replace with `std::future::poll_fn` once MSRV >= 1.64 +pub(crate) fn poll_fn<T, F>(f: F) -> PollFn<F> +where + F: FnMut(&mut Context<'_>) -> Poll<T>, +{ + PollFn { f } +} + +pub(crate) struct PollFn<F> { + f: F, +} + +impl<F> Unpin for PollFn<F> {} + +impl<T, F> Future for PollFn<F> +where + F: FnMut(&mut Context<'_>) -> Poll<T>, +{ + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + (self.f)(cx) + } +} diff --git a/vendor/hyper-util/src/common/lazy.rs b/vendor/hyper-util/src/common/lazy.rs new file mode 100644 index 00000000..7ec09bbe --- /dev/null +++ b/vendor/hyper-util/src/common/lazy.rs @@ -0,0 +1,78 @@ +use pin_project_lite::pin_project; + +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll}; + +pub(crate) trait Started: Future { + fn started(&self) -> bool; +} + +pub(crate) fn lazy<F, R>(func: F) -> Lazy<F, R> +where + F: FnOnce() -> R, + R: Future + Unpin, +{ + Lazy { + inner: Inner::Init { func }, + } +} + +// FIXME: allow() required due to `impl Trait` leaking types to this lint +pin_project! { + #[allow(missing_debug_implementations)] + pub(crate) struct Lazy<F, R> { + #[pin] + inner: Inner<F, R>, + } +} + +pin_project! { + #[project = InnerProj] + #[project_replace = InnerProjReplace] + enum Inner<F, R> { + Init { func: F }, + Fut { #[pin] fut: R }, + Empty, + } +} + +impl<F, R> Started for Lazy<F, R> +where + F: FnOnce() -> R, + R: Future, +{ + fn started(&self) -> bool { + match self.inner { + Inner::Init { .. } => false, + Inner::Fut { .. } | Inner::Empty => true, + } + } +} + +impl<F, R> Future for Lazy<F, R> +where + F: FnOnce() -> R, + R: Future, +{ + type Output = R::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + + if let InnerProj::Fut { fut } = this.inner.as_mut().project() { + return fut.poll(cx); + } + + match this.inner.as_mut().project_replace(Inner::Empty) { + InnerProjReplace::Init { func } => { + this.inner.set(Inner::Fut { fut: func() }); + if let InnerProj::Fut { fut } = this.inner.project() { + return fut.poll(cx); + } + unreachable!() + } + _ => unreachable!("lazy state wrong"), + } + } +} diff --git a/vendor/hyper-util/src/common/mod.rs b/vendor/hyper-util/src/common/mod.rs new file mode 100644 index 00000000..b45cd0b2 --- /dev/null +++ b/vendor/hyper-util/src/common/mod.rs @@ -0,0 +1,19 @@ +#![allow(missing_docs)] + +pub(crate) mod exec; +#[cfg(feature = "client")] +mod lazy; +pub(crate) mod rewind; +#[cfg(feature = "client")] +mod sync; +pub(crate) mod timer; + +#[cfg(feature = "client")] +pub(crate) use exec::Exec; + +#[cfg(feature = "client")] +pub(crate) use lazy::{lazy, Started as Lazy}; +#[cfg(feature = "client")] +pub(crate) use sync::SyncWrapper; + +pub(crate) mod future; diff --git a/vendor/hyper-util/src/common/rewind.rs b/vendor/hyper-util/src/common/rewind.rs new file mode 100644 index 00000000..760d7966 --- /dev/null +++ b/vendor/hyper-util/src/common/rewind.rs @@ -0,0 +1,137 @@ +use std::{cmp, io}; + +use bytes::{Buf, Bytes}; +use hyper::rt::{Read, ReadBufCursor, Write}; + +use std::{ + pin::Pin, + task::{self, Poll}, +}; + +/// Combine a buffer with an IO, rewinding reads to use the buffer. +#[derive(Debug)] +pub(crate) struct Rewind<T> { + pub(crate) pre: Option<Bytes>, + pub(crate) inner: T, +} + +impl<T> Rewind<T> { + #[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, + } + } +} + +impl<T> Read for Rewind<T> +where + T: Read + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + mut buf: ReadBufCursor<'_>, + ) -> Poll<io::Result<()>> { + if let Some(mut prefix) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = cmp::min(prefix.len(), buf.remaining()); + buf.put_slice(&prefix[..copy_len]); + prefix.advance(copy_len); + // Put back what's left + if !prefix.is_empty() { + self.pre = Some(prefix); + } + + return Poll::Ready(Ok(())); + } + } + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl<T> Write for Rewind<T> +where + T: Write + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +/* +#[cfg(test)] +mod tests { + use super::Rewind; + use bytes::Bytes; + use tokio::io::AsyncReadExt; + + #[cfg(not(miri))] + #[tokio::test] + async fn partial_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + // Read off some bytes, ensure we filled o1 + let mut buf = [0; 2]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // At this point we should have read everything that was in the MockStream + assert_eq!(&buf, &underlying); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn full_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + } +} +*/ diff --git a/vendor/hyper-util/src/common/sync.rs b/vendor/hyper-util/src/common/sync.rs new file mode 100644 index 00000000..2755fd05 --- /dev/null +++ b/vendor/hyper-util/src/common/sync.rs @@ -0,0 +1,67 @@ +pub(crate) struct SyncWrapper<T>(T); + +impl<T> SyncWrapper<T> { + /// Creates a new SyncWrapper containing the given value. + /// + /// # Examples + /// + /// ```ignore + /// use hyper::common::sync_wrapper::SyncWrapper; + /// + /// let wrapped = SyncWrapper::new(42); + /// ``` + pub(crate) fn new(value: T) -> Self { + Self(value) + } + + /// Acquires a reference to the protected value. + /// + /// This is safe because it requires an exclusive reference to the wrapper. Therefore this method + /// neither panics nor does it return an error. This is in contrast to [`Mutex::get_mut`] which + /// returns an error if another thread panicked while holding the lock. It is not recommended + /// to send an exclusive reference to a potentially damaged value to another thread for further + /// processing. + /// + /// [`Mutex::get_mut`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html#method.get_mut + /// + /// # Examples + /// + /// ```ignore + /// use hyper::common::sync_wrapper::SyncWrapper; + /// + /// let mut wrapped = SyncWrapper::new(42); + /// let value = wrapped.get_mut(); + /// *value = 0; + /// assert_eq!(*wrapped.get_mut(), 0); + /// ``` + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.0 + } + + /// Consumes this wrapper, returning the underlying data. + /// + /// This is safe because it requires ownership of the wrapper, aherefore this method will neither + /// panic nor does it return an error. This is in contrast to [`Mutex::into_inner`] which + /// returns an error if another thread panicked while holding the lock. It is not recommended + /// to send an exclusive reference to a potentially damaged value to another thread for further + /// processing. + /// + /// [`Mutex::into_inner`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html#method.into_inner + /// + /// # Examples + /// + /// ```ignore + /// use hyper::common::sync_wrapper::SyncWrapper; + /// + /// let mut wrapped = SyncWrapper::new(42); + /// assert_eq!(wrapped.into_inner(), 42); + /// ``` + #[allow(dead_code)] + pub(crate) fn into_inner(self) -> T { + self.0 + } +} + +// this is safe because the only operations permitted on this data structure require exclusive +// access or ownership +unsafe impl<T: Send> Sync for SyncWrapper<T> {} diff --git a/vendor/hyper-util/src/common/timer.rs b/vendor/hyper-util/src/common/timer.rs new file mode 100644 index 00000000..390be3b0 --- /dev/null +++ b/vendor/hyper-util/src/common/timer.rs @@ -0,0 +1,38 @@ +#![allow(dead_code)] + +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use hyper::rt::Sleep; + +#[derive(Clone)] +pub(crate) struct Timer(Arc<dyn hyper::rt::Timer + Send + Sync>); + +// =====impl Timer===== +impl Timer { + pub(crate) fn new<T>(inner: T) -> Self + where + T: hyper::rt::Timer + Send + Sync + 'static, + { + Self(Arc::new(inner)) + } +} + +impl fmt::Debug for Timer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Timer").finish() + } +} + +impl hyper::rt::Timer for Timer { + fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> { + self.0.sleep(duration) + } + + fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> { + self.0.sleep_until(deadline) + } +} diff --git a/vendor/hyper-util/src/error.rs b/vendor/hyper-util/src/error.rs new file mode 100644 index 00000000..d1894495 --- /dev/null +++ b/vendor/hyper-util/src/error.rs @@ -0,0 +1,14 @@ +/* +use std::error::Error; + +pub(crate) fn find<'a, E: Error + 'static>(top: &'a (dyn Error + 'static)) -> Option<&'a E> { + let mut err = Some(top); + while let Some(src) = err { + if src.is::<E>() { + return src.downcast_ref(); + } + err = src.source(); + } + None +} +*/ diff --git a/vendor/hyper-util/src/lib.rs b/vendor/hyper-util/src/lib.rs new file mode 100644 index 00000000..ac8f89b1 --- /dev/null +++ b/vendor/hyper-util/src/lib.rs @@ -0,0 +1,18 @@ +#![deny(missing_docs)] +#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] + +//! Utilities for working with hyper. +//! +//! This crate is less-stable than [`hyper`](https://docs.rs/hyper). However, +//! does respect Rust's semantic version regarding breaking changes. + +#[cfg(feature = "client")] +pub mod client; +mod common; +pub mod rt; +#[cfg(feature = "server")] +pub mod server; +#[cfg(any(feature = "service", feature = "client-legacy"))] +pub mod service; + +mod error; diff --git a/vendor/hyper-util/src/rt/io.rs b/vendor/hyper-util/src/rt/io.rs new file mode 100644 index 00000000..888756f6 --- /dev/null +++ b/vendor/hyper-util/src/rt/io.rs @@ -0,0 +1,34 @@ +use std::marker::Unpin; +use std::pin::Pin; +use std::task::Poll; + +use futures_core::ready; +use hyper::rt::{Read, ReadBuf, Write}; + +use crate::common::future::poll_fn; + +pub(crate) async fn read<T>(io: &mut T, buf: &mut [u8]) -> Result<usize, std::io::Error> +where + T: Read + Unpin, +{ + poll_fn(move |cx| { + let mut buf = ReadBuf::new(buf); + ready!(Pin::new(&mut *io).poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + }) + .await +} + +pub(crate) async fn write_all<T>(io: &mut T, buf: &[u8]) -> Result<(), std::io::Error> +where + T: Write + Unpin, +{ + let mut n = 0; + poll_fn(move |cx| { + while n < buf.len() { + n += ready!(Pin::new(&mut *io).poll_write(cx, &buf[n..])?); + } + Poll::Ready(Ok(())) + }) + .await +} diff --git a/vendor/hyper-util/src/rt/mod.rs b/vendor/hyper-util/src/rt/mod.rs new file mode 100644 index 00000000..71363ccd --- /dev/null +++ b/vendor/hyper-util/src/rt/mod.rs @@ -0,0 +1,12 @@ +//! Runtime utilities + +#[cfg(feature = "client-legacy")] +mod io; +#[cfg(feature = "client-legacy")] +pub(crate) use self::io::{read, write_all}; + +#[cfg(feature = "tokio")] +pub mod tokio; + +#[cfg(feature = "tokio")] +pub use self::tokio::{TokioExecutor, TokioIo, TokioTimer}; diff --git a/vendor/hyper-util/src/rt/tokio.rs b/vendor/hyper-util/src/rt/tokio.rs new file mode 100644 index 00000000..46ffeba8 --- /dev/null +++ b/vendor/hyper-util/src/rt/tokio.rs @@ -0,0 +1,339 @@ +//! [`tokio`] runtime components integration for [`hyper`]. +//! +//! [`hyper::rt`] exposes a set of traits to allow hyper to be agnostic to +//! its underlying asynchronous runtime. This submodule provides glue for +//! [`tokio`] users to bridge those types to [`hyper`]'s interfaces. +//! +//! # IO +//! +//! [`hyper`] abstracts over asynchronous readers and writers using [`Read`] +//! and [`Write`], while [`tokio`] abstracts over this using [`AsyncRead`] +//! and [`AsyncWrite`]. This submodule provides a collection of IO adaptors +//! to bridge these two IO ecosystems together: [`TokioIo<I>`], +//! [`WithHyperIo<I>`], and [`WithTokioIo<I>`]. +//! +//! To compare and constrast these IO adaptors and to help explain which +//! is the proper choice for your needs, here is a table showing which IO +//! traits these implement, given two types `T` and `H` which implement +//! Tokio's and Hyper's corresponding IO traits: +//! +//! | | [`AsyncRead`] | [`AsyncWrite`] | [`Read`] | [`Write`] | +//! |--------------------|------------------|-------------------|--------------|--------------| +//! | `T` | ✅ **true** | ✅ **true** | ❌ **false** | ❌ **false** | +//! | `H` | ❌ **false** | ❌ **false** | ✅ **true** | ✅ **true** | +//! | [`TokioIo<T>`] | ❌ **false** | ❌ **false** | ✅ **true** | ✅ **true** | +//! | [`TokioIo<H>`] | ✅ **true** | ✅ **true** | ❌ **false** | ❌ **false** | +//! | [`WithHyperIo<T>`] | ✅ **true** | ✅ **true** | ✅ **true** | ✅ **true** | +//! | [`WithHyperIo<H>`] | ❌ **false** | ❌ **false** | ❌ **false** | ❌ **false** | +//! | [`WithTokioIo<T>`] | ❌ **false** | ❌ **false** | ❌ **false** | ❌ **false** | +//! | [`WithTokioIo<H>`] | ✅ **true** | ✅ **true** | ✅ **true** | ✅ **true** | +//! +//! For most situations, [`TokioIo<I>`] is the proper choice. This should be +//! constructed, wrapping some underlying [`hyper`] or [`tokio`] IO, at the +//! call-site of a function like [`hyper::client::conn::http1::handshake`]. +//! +//! [`TokioIo<I>`] switches across these ecosystems, but notably does not +//! preserve the existing IO trait implementations of its underlying IO. If +//! one wishes to _extend_ IO with additional implementations, +//! [`WithHyperIo<I>`] and [`WithTokioIo<I>`] are the correct choice. +//! +//! For example, a Tokio reader/writer can be wrapped in [`WithHyperIo<I>`]. +//! That will implement _both_ sets of IO traits. Conversely, +//! [`WithTokioIo<I>`] will implement both sets of IO traits given a +//! reader/writer that implements Hyper's [`Read`] and [`Write`]. +//! +//! See [`tokio::io`] and ["_Asynchronous IO_"][tokio-async-docs] for more +//! information. +//! +//! [`AsyncRead`]: tokio::io::AsyncRead +//! [`AsyncWrite`]: tokio::io::AsyncWrite +//! [`Read`]: hyper::rt::Read +//! [`Write`]: hyper::rt::Write +//! [tokio-async-docs]: https://docs.rs/tokio/latest/tokio/#asynchronous-io + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use hyper::rt::{Executor, Sleep, Timer}; +use pin_project_lite::pin_project; + +#[cfg(feature = "tracing")] +use tracing::instrument::Instrument; + +pub use self::{with_hyper_io::WithHyperIo, with_tokio_io::WithTokioIo}; + +mod with_hyper_io; +mod with_tokio_io; + +/// Future executor that utilises `tokio` threads. +#[non_exhaustive] +#[derive(Default, Debug, Clone)] +pub struct TokioExecutor {} + +pin_project! { + /// A wrapper that implements Tokio's IO traits for an inner type that + /// implements hyper's IO traits, or vice versa (implements hyper's IO + /// traits for a type that implements Tokio's IO traits). + #[derive(Debug)] + pub struct TokioIo<T> { + #[pin] + inner: T, + } +} + +/// A Timer that uses the tokio runtime. +#[non_exhaustive] +#[derive(Default, Clone, Debug)] +pub struct TokioTimer; + +// Use TokioSleep to get tokio::time::Sleep to implement Unpin. +// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html +pin_project! { + #[derive(Debug)] + struct TokioSleep { + #[pin] + inner: tokio::time::Sleep, + } +} + +// ===== impl TokioExecutor ===== + +impl<Fut> Executor<Fut> for TokioExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + #[cfg(feature = "tracing")] + tokio::spawn(fut.in_current_span()); + + #[cfg(not(feature = "tracing"))] + tokio::spawn(fut); + } +} + +impl TokioExecutor { + /// Create new executor that relies on [`tokio::spawn`] to execute futures. + pub fn new() -> Self { + Self {} + } +} + +// ==== impl TokioIo ===== + +impl<T> TokioIo<T> { + /// Wrap a type implementing Tokio's or hyper's IO traits. + pub fn new(inner: T) -> Self { + Self { inner } + } + + /// Borrow the inner type. + pub fn inner(&self) -> &T { + &self.inner + } + + /// Mut borrow the inner type. + pub fn inner_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Consume this wrapper and get the inner type. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl<T> hyper::rt::Read for TokioIo<T> +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll<Result<(), std::io::Error>> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl<T> hyper::rt::Write for TokioIo<T> +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl<T> tokio::io::AsyncRead for TokioIo<T> +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<Result<(), std::io::Error>> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl<T> tokio::io::AsyncWrite for TokioIo<T> +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +// ==== impl TokioTimer ===== + +impl Timer for TokioTimer { + fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> { + Box::pin(TokioSleep { + inner: tokio::time::sleep(duration), + }) + } + + fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> { + Box::pin(TokioSleep { + inner: tokio::time::sleep_until(deadline.into()), + }) + } + + fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) { + if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() { + sleep.reset(new_deadline) + } + } +} + +impl TokioTimer { + /// Create a new TokioTimer + pub fn new() -> Self { + Self {} + } +} + +impl Future for TokioSleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().inner.poll(cx) + } +} + +impl Sleep for TokioSleep {} + +impl TokioSleep { + fn reset(self: Pin<&mut Self>, deadline: Instant) { + self.project().inner.as_mut().reset(deadline.into()); + } +} + +#[cfg(test)] +mod tests { + use crate::rt::TokioExecutor; + use hyper::rt::Executor; + use tokio::sync::oneshot; + + #[cfg(not(miri))] + #[tokio::test] + async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> { + let (tx, rx) = oneshot::channel(); + let executor = TokioExecutor::new(); + executor.execute(async move { + tx.send(()).unwrap(); + }); + rx.await.map_err(Into::into) + } +} diff --git a/vendor/hyper-util/src/rt/tokio/with_hyper_io.rs b/vendor/hyper-util/src/rt/tokio/with_hyper_io.rs new file mode 100644 index 00000000..9c5072d4 --- /dev/null +++ b/vendor/hyper-util/src/rt/tokio/with_hyper_io.rs @@ -0,0 +1,170 @@ +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +pin_project! { + /// Extends an underlying [`tokio`] I/O with [`hyper`] I/O implementations. + /// + /// This implements [`Read`] and [`Write`] given an inner type that implements [`AsyncRead`] + /// and [`AsyncWrite`], respectively. + #[derive(Debug)] + pub struct WithHyperIo<I> { + #[pin] + inner: I, + } +} + +// ==== impl WithHyperIo ===== + +impl<I> WithHyperIo<I> { + /// Wraps the inner I/O in an [`WithHyperIo<I>`] + pub fn new(inner: I) -> Self { + Self { inner } + } + + /// Returns a reference to the inner type. + pub fn inner(&self) -> &I { + &self.inner + } + + /// Returns a mutable reference to the inner type. + pub fn inner_mut(&mut self) -> &mut I { + &mut self.inner + } + + /// Consumes this wrapper and returns the inner type. + pub fn into_inner(self) -> I { + self.inner + } +} + +/// [`WithHyperIo<I>`] is [`Read`] if `I` is [`AsyncRead`]. +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +/// [`Read`]: hyper::rt::Read +impl<I> hyper::rt::Read for WithHyperIo<I> +where + I: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll<Result<(), std::io::Error>> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +/// [`WithHyperIo<I>`] is [`Write`] if `I` is [`AsyncWrite`]. +/// +/// [`AsyncWrite`]: tokio::io::AsyncWrite +/// [`Write`]: hyper::rt::Write +impl<I> hyper::rt::Write for WithHyperIo<I> +where + I: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +/// [`WithHyperIo<I>`] exposes its inner `I`'s [`AsyncRead`] implementation. +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +impl<I> tokio::io::AsyncRead for WithHyperIo<I> +where + I: tokio::io::AsyncRead, +{ + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<Result<(), std::io::Error>> { + self.project().inner.poll_read(cx, buf) + } +} + +/// [`WithHyperIo<I>`] exposes its inner `I`'s [`AsyncWrite`] implementation. +/// +/// [`AsyncWrite`]: tokio::io::AsyncWrite +impl<I> tokio::io::AsyncWrite for WithHyperIo<I> +where + I: tokio::io::AsyncWrite, +{ + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + self.project().inner.poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + self.project().inner.poll_flush(cx) + } + + #[inline] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + self.project().inner.poll_shutdown(cx) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + self.project().inner.poll_write_vectored(cx, bufs) + } +} diff --git a/vendor/hyper-util/src/rt/tokio/with_tokio_io.rs b/vendor/hyper-util/src/rt/tokio/with_tokio_io.rs new file mode 100644 index 00000000..223e0ed3 --- /dev/null +++ b/vendor/hyper-util/src/rt/tokio/with_tokio_io.rs @@ -0,0 +1,178 @@ +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +pin_project! { + /// Extends an underlying [`hyper`] I/O with [`tokio`] I/O implementations. + /// + /// This implements [`AsyncRead`] and [`AsyncWrite`] given an inner type that implements + /// [`Read`] and [`Write`], respectively. + #[derive(Debug)] + pub struct WithTokioIo<I> { + #[pin] + inner: I, + } +} + +// ==== impl WithTokioIo ===== + +/// [`WithTokioIo<I>`] is [`AsyncRead`] if `I` is [`Read`]. +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +/// [`Read`]: hyper::rt::Read +impl<I> tokio::io::AsyncRead for WithTokioIo<I> +where + I: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<Result<(), std::io::Error>> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +/// [`WithTokioIo<I>`] is [`AsyncWrite`] if `I` is [`Write`]. +/// +/// [`AsyncWrite`]: tokio::io::AsyncWrite +/// [`Write`]: hyper::rt::Write +impl<I> tokio::io::AsyncWrite for WithTokioIo<I> +where + I: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +/// [`WithTokioIo<I>`] exposes its inner `I`'s [`Write`] implementation. +/// +/// [`Write`]: hyper::rt::Write +impl<I> hyper::rt::Write for WithTokioIo<I> +where + I: hyper::rt::Write, +{ + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + self.project().inner.poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + self.project().inner.poll_flush(cx) + } + + #[inline] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + self.project().inner.poll_shutdown(cx) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + self.project().inner.poll_write_vectored(cx, bufs) + } +} + +impl<I> WithTokioIo<I> { + /// Wraps the inner I/O in an [`WithTokioIo<I>`] + pub fn new(inner: I) -> Self { + Self { inner } + } + + /// Returns a reference to the inner type. + pub fn inner(&self) -> &I { + &self.inner + } + + /// Returns a mutable reference to the inner type. + pub fn inner_mut(&mut self) -> &mut I { + &mut self.inner + } + + /// Consumes this wrapper and returns the inner type. + pub fn into_inner(self) -> I { + self.inner + } +} + +/// [`WithTokioIo<I>`] exposes its inner `I`'s [`Read`] implementation. +/// +/// [`Read`]: hyper::rt::Read +impl<I> hyper::rt::Read for WithTokioIo<I> +where + I: hyper::rt::Read, +{ + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll<Result<(), std::io::Error>> { + self.project().inner.poll_read(cx, buf) + } +} diff --git a/vendor/hyper-util/src/server/conn/auto/mod.rs b/vendor/hyper-util/src/server/conn/auto/mod.rs new file mode 100644 index 00000000..b2fc6556 --- /dev/null +++ b/vendor/hyper-util/src/server/conn/auto/mod.rs @@ -0,0 +1,1304 @@ +//! Http1 or Http2 connection. + +pub mod upgrade; + +use hyper::service::HttpService; +use std::future::Future; +use std::marker::PhantomPinned; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{error::Error as StdError, io, time::Duration}; + +use bytes::Bytes; +use futures_core::ready; +use http::{Request, Response}; +use http_body::Body; +use hyper::{ + body::Incoming, + rt::{Read, ReadBuf, Timer, Write}, + service::Service, +}; + +#[cfg(feature = "http1")] +use hyper::server::conn::http1; + +#[cfg(feature = "http2")] +use hyper::{rt::bounds::Http2ServerConnExec, server::conn::http2}; + +#[cfg(any(not(feature = "http2"), not(feature = "http1")))] +use std::marker::PhantomData; + +use pin_project_lite::pin_project; + +use crate::common::rewind::Rewind; + +type Error = Box<dyn std::error::Error + Send + Sync>; + +type Result<T> = std::result::Result<T, Error>; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// Exactly equivalent to [`Http2ServerConnExec`]. +#[cfg(feature = "http2")] +pub trait HttpServerConnExec<A, B: Body>: Http2ServerConnExec<A, B> {} + +#[cfg(feature = "http2")] +impl<A, B: Body, T: Http2ServerConnExec<A, B>> HttpServerConnExec<A, B> for T {} + +/// Exactly equivalent to [`Http2ServerConnExec`]. +#[cfg(not(feature = "http2"))] +pub trait HttpServerConnExec<A, B: Body> {} + +#[cfg(not(feature = "http2"))] +impl<A, B: Body, T> HttpServerConnExec<A, B> for T {} + +/// Http1 or Http2 connection builder. +#[derive(Clone, Debug)] +pub struct Builder<E> { + #[cfg(feature = "http1")] + http1: http1::Builder, + #[cfg(feature = "http2")] + http2: http2::Builder<E>, + #[cfg(any(feature = "http1", feature = "http2"))] + version: Option<Version>, + #[cfg(not(feature = "http2"))] + _executor: E, +} + +impl<E: Default> Default for Builder<E> { + fn default() -> Self { + Self::new(E::default()) + } +} + +impl<E> Builder<E> { + /// Create a new auto connection builder. + /// + /// `executor` parameter should be a type that implements + /// [`Executor`](hyper::rt::Executor) trait. + /// + /// # Example + /// + /// ``` + /// use hyper_util::{ + /// rt::TokioExecutor, + /// server::conn::auto, + /// }; + /// + /// auto::Builder::new(TokioExecutor::new()); + /// ``` + pub fn new(executor: E) -> Self { + Self { + #[cfg(feature = "http1")] + http1: http1::Builder::new(), + #[cfg(feature = "http2")] + http2: http2::Builder::new(executor), + #[cfg(any(feature = "http1", feature = "http2"))] + version: None, + #[cfg(not(feature = "http2"))] + _executor: executor, + } + } + + /// Http1 configuration. + #[cfg(feature = "http1")] + pub fn http1(&mut self) -> Http1Builder<'_, E> { + Http1Builder { inner: self } + } + + /// Http2 configuration. + #[cfg(feature = "http2")] + pub fn http2(&mut self) -> Http2Builder<'_, E> { + Http2Builder { inner: self } + } + + /// Only accepts HTTP/2 + /// + /// Does not do anything if used with [`serve_connection_with_upgrades`] + /// + /// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades + #[cfg(feature = "http2")] + pub fn http2_only(mut self) -> Self { + assert!(self.version.is_none()); + self.version = Some(Version::H2); + self + } + + /// Only accepts HTTP/1 + /// + /// Does not do anything if used with [`serve_connection_with_upgrades`] + /// + /// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades + #[cfg(feature = "http1")] + pub fn http1_only(mut self) -> Self { + assert!(self.version.is_none()); + self.version = Some(Version::H1); + self + } + + /// Returns `true` if this builder can serve an HTTP/1.1-based connection. + pub fn is_http1_available(&self) -> bool { + match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => true, + #[cfg(feature = "http2")] + Some(Version::H2) => false, + #[cfg(any(feature = "http1", feature = "http2"))] + _ => true, + } + } + + /// Returns `true` if this builder can serve an HTTP/2-based connection. + pub fn is_http2_available(&self) -> bool { + match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => false, + #[cfg(feature = "http2")] + Some(Version::H2) => true, + #[cfg(any(feature = "http1", feature = "http2"))] + _ => true, + } + } + + /// Bind a connection together with a [`Service`]. + pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + 'static, + E: HttpServerConnExec<S::Future, B>, + { + let state = match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http1.serve_connection(io, service); + ConnState::H1 { conn } + } + #[cfg(feature = "http2")] + Some(Version::H2) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http2.serve_connection(io, service); + ConnState::H2 { conn } + } + #[cfg(any(feature = "http1", feature = "http2"))] + _ => ConnState::ReadVersion { + read_version: read_version(io), + builder: Cow::Borrowed(self), + service: Some(service), + }, + }; + + Connection { state } + } + + /// Bind a connection together with a [`Service`], with the ability to + /// handle HTTP upgrades. This requires that the IO object implements + /// `Send`. + /// + /// Note that if you ever want to use [`hyper::upgrade::Upgraded::downcast`] + /// with this crate, you'll need to use [`hyper_util::server::conn::auto::upgrade::downcast`] + /// instead. See the documentation of the latter to understand why. + /// + /// [`hyper_util::server::conn::auto::upgrade::downcast`]: crate::server::conn::auto::upgrade::downcast + pub fn serve_connection_with_upgrades<I, S, B>( + &self, + io: I, + service: S, + ) -> UpgradeableConnection<'_, I, S, E> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + Send + 'static, + E: HttpServerConnExec<S::Future, B>, + { + UpgradeableConnection { + state: UpgradeableConnState::ReadVersion { + read_version: read_version(io), + builder: Cow::Borrowed(self), + service: Some(service), + }, + } + } +} + +#[derive(Copy, Clone, Debug)] +enum Version { + H1, + H2, +} + +impl Version { + #[must_use] + #[cfg(any(not(feature = "http2"), not(feature = "http1")))] + pub fn unsupported(self) -> Error { + match self { + Version::H1 => Error::from("HTTP/1 is not supported"), + Version::H2 => Error::from("HTTP/2 is not supported"), + } + } +} + +fn read_version<I>(io: I) -> ReadVersion<I> +where + I: Read + Unpin, +{ + ReadVersion { + io: Some(io), + buf: [MaybeUninit::uninit(); 24], + filled: 0, + version: Version::H2, + cancelled: false, + _pin: PhantomPinned, + } +} + +pin_project! { + struct ReadVersion<I> { + io: Option<I>, + buf: [MaybeUninit<u8>; 24], + // the amount of `buf` thats been filled + filled: usize, + version: Version, + cancelled: bool, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +impl<I> ReadVersion<I> { + pub fn cancel(self: Pin<&mut Self>) { + *self.project().cancelled = true; + } +} + +impl<I> Future for ReadVersion<I> +where + I: Read + Unpin, +{ + type Output = io::Result<(Version, Rewind<I>)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + if *this.cancelled { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled"))); + } + + let mut buf = ReadBuf::uninit(&mut *this.buf); + // SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and + // we're only advancing by that many. + unsafe { + buf.unfilled().advance(*this.filled); + }; + + // We start as H2 and switch to H1 as soon as we don't have the preface. + while buf.filled().len() < H2_PREFACE.len() { + let len = buf.filled().len(); + ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?; + *this.filled = buf.filled().len(); + + // We starts as H2 and switch to H1 when we don't get the preface. + if buf.filled().len() == len + || buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()] + { + *this.version = Version::H1; + break; + } + } + + let io = this.io.take().unwrap(); + let buf = buf.filled().to_vec(); + Poll::Ready(Ok(( + *this.version, + Rewind::new_buffered(io, Bytes::from(buf)), + ))) + } +} + +pin_project! { + /// A [`Future`](core::future::Future) representing an HTTP/1 connection, returned from + /// [`Builder::serve_connection`](struct.Builder.html#method.serve_connection). + /// + /// To drive HTTP on this connection this future **must be polled**, typically with + /// `.await`. If it isn't polled, no progress will be made on this connection. + #[must_use = "futures do nothing unless polled"] + pub struct Connection<'a, I, S, E> + where + S: HttpService<Incoming>, + { + #[pin] + state: ConnState<'a, I, S, E>, + } +} + +// A custom COW, since the libstd is has ToOwned bounds that are too eager. +enum Cow<'a, T> { + Borrowed(&'a T), + Owned(T), +} + +impl<T> std::ops::Deref for Cow<'_, T> { + type Target = T; + fn deref(&self) -> &T { + match self { + Cow::Borrowed(t) => &*t, + Cow::Owned(ref t) => t, + } + } +} + +#[cfg(feature = "http1")] +type Http1Connection<I, S> = hyper::server::conn::http1::Connection<Rewind<I>, S>; + +#[cfg(not(feature = "http1"))] +type Http1Connection<I, S> = (PhantomData<I>, PhantomData<S>); + +#[cfg(feature = "http2")] +type Http2Connection<I, S, E> = hyper::server::conn::http2::Connection<Rewind<I>, S, E>; + +#[cfg(not(feature = "http2"))] +type Http2Connection<I, S, E> = (PhantomData<I>, PhantomData<S>, PhantomData<E>); + +pin_project! { + #[project = ConnStateProj] + enum ConnState<'a, I, S, E> + where + S: HttpService<Incoming>, + { + ReadVersion { + #[pin] + read_version: ReadVersion<I>, + builder: Cow<'a, Builder<E>>, + service: Option<S>, + }, + H1 { + #[pin] + conn: Http1Connection<I, S>, + }, + H2 { + #[pin] + conn: Http2Connection<I, S, E>, + }, + } +} + +impl<I, S, E, B> Connection<'_, I, S, E> +where + S: HttpService<Incoming, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: HttpServerConnExec<S::Future, B>, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still pending. If called after + /// `Connection::poll` has resolved, this does nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().state.project() { + ConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(), + #[cfg(feature = "http1")] + ConnStateProj::H1 { conn } => conn.graceful_shutdown(), + #[cfg(feature = "http2")] + ConnStateProj::H2 { conn } => conn.graceful_shutdown(), + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => unreachable!(), + } + } + + /// Make this Connection static, instead of borrowing from Builder. + pub fn into_owned(self) -> Connection<'static, I, S, E> + where + Builder<E>: Clone, + { + Connection { + state: match self.state { + ConnState::ReadVersion { + read_version, + builder, + service, + } => ConnState::ReadVersion { + read_version, + service, + builder: Cow::Owned(builder.clone()), + }, + #[cfg(feature = "http1")] + ConnState::H1 { conn } => ConnState::H1 { conn }, + #[cfg(feature = "http2")] + ConnState::H2 { conn } => ConnState::H2 { conn }, + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => unreachable!(), + }, + } + } +} + +impl<I, S, E, B> Future for Connection<'_, I, S, E> +where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + 'static, + E: HttpServerConnExec<S::Future, B>, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + ConnStateProj::ReadVersion { + read_version, + builder, + service, + } => { + let (version, io) = ready!(read_version.poll(cx))?; + let service = service.take().unwrap(); + match version { + #[cfg(feature = "http1")] + Version::H1 => { + let conn = builder.http1.serve_connection(io, service); + this.state.set(ConnState::H1 { conn }); + } + #[cfg(feature = "http2")] + Version::H2 => { + let conn = builder.http2.serve_connection(io, service); + this.state.set(ConnState::H2 { conn }); + } + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => return Poll::Ready(Err(version.unsupported())), + } + } + #[cfg(feature = "http1")] + ConnStateProj::H1 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + #[cfg(feature = "http2")] + ConnStateProj::H2 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => unreachable!(), + } + } + } +} + +pin_project! { + /// An upgradable [`Connection`], returned by + /// [`Builder::serve_upgradable_connection`](struct.Builder.html#method.serve_connection_with_upgrades). + /// + /// To drive HTTP on this connection this future **must be polled**, typically with + /// `.await`. If it isn't polled, no progress will be made on this connection. + #[must_use = "futures do nothing unless polled"] + pub struct UpgradeableConnection<'a, I, S, E> + where + S: HttpService<Incoming>, + { + #[pin] + state: UpgradeableConnState<'a, I, S, E>, + } +} + +#[cfg(feature = "http1")] +type Http1UpgradeableConnection<I, S> = hyper::server::conn::http1::UpgradeableConnection<I, S>; + +#[cfg(not(feature = "http1"))] +type Http1UpgradeableConnection<I, S> = (PhantomData<I>, PhantomData<S>); + +pin_project! { + #[project = UpgradeableConnStateProj] + enum UpgradeableConnState<'a, I, S, E> + where + S: HttpService<Incoming>, + { + ReadVersion { + #[pin] + read_version: ReadVersion<I>, + builder: Cow<'a, Builder<E>>, + service: Option<S>, + }, + H1 { + #[pin] + conn: Http1UpgradeableConnection<Rewind<I>, S>, + }, + H2 { + #[pin] + conn: Http2Connection<I, S, E>, + }, + } +} + +impl<I, S, E, B> UpgradeableConnection<'_, I, S, E> +where + S: HttpService<Incoming, ResBody = B>, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + E: HttpServerConnExec<S::Future, B>, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `UpgradeableConnection` should continue to be polled until shutdown can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still nothing. pending. If + /// called after `UpgradeableConnection::poll` has resolved, this does nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().state.project() { + UpgradeableConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(), + #[cfg(feature = "http1")] + UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(), + #[cfg(feature = "http2")] + UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(), + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => unreachable!(), + } + } + + /// Make this Connection static, instead of borrowing from Builder. + pub fn into_owned(self) -> UpgradeableConnection<'static, I, S, E> + where + Builder<E>: Clone, + { + UpgradeableConnection { + state: match self.state { + UpgradeableConnState::ReadVersion { + read_version, + builder, + service, + } => UpgradeableConnState::ReadVersion { + read_version, + service, + builder: Cow::Owned(builder.clone()), + }, + #[cfg(feature = "http1")] + UpgradeableConnState::H1 { conn } => UpgradeableConnState::H1 { conn }, + #[cfg(feature = "http2")] + UpgradeableConnState::H2 { conn } => UpgradeableConnState::H2 { conn }, + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => unreachable!(), + }, + } + } +} + +impl<I, S, E, B> Future for UpgradeableConnection<'_, I, S, E> +where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + Send + 'static, + E: HttpServerConnExec<S::Future, B>, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + UpgradeableConnStateProj::ReadVersion { + read_version, + builder, + service, + } => { + let (version, io) = ready!(read_version.poll(cx))?; + let service = service.take().unwrap(); + match version { + #[cfg(feature = "http1")] + Version::H1 => { + let conn = builder.http1.serve_connection(io, service).with_upgrades(); + this.state.set(UpgradeableConnState::H1 { conn }); + } + #[cfg(feature = "http2")] + Version::H2 => { + let conn = builder.http2.serve_connection(io, service); + this.state.set(UpgradeableConnState::H2 { conn }); + } + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => return Poll::Ready(Err(version.unsupported())), + } + } + #[cfg(feature = "http1")] + UpgradeableConnStateProj::H1 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + #[cfg(feature = "http2")] + UpgradeableConnStateProj::H2 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + #[cfg(any(not(feature = "http1"), not(feature = "http2")))] + _ => unreachable!(), + } + } + } +} + +/// Http1 part of builder. +#[cfg(feature = "http1")] +pub struct Http1Builder<'a, E> { + inner: &'a mut Builder<E>, +} + +#[cfg(feature = "http1")] +impl<E> Http1Builder<'_, E> { + /// Http2 configuration. + #[cfg(feature = "http2")] + pub fn http2(&mut self) -> Http2Builder<'_, E> { + Http2Builder { inner: self.inner } + } + + /// Set whether the `date` header should be included in HTTP responses. + /// + /// Note that including the `date` header is recommended by RFC 7231. + /// + /// Default is true. + pub fn auto_date_header(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.auto_date_header(enabled); + self + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + pub fn half_close(&mut self, val: bool) -> &mut Self { + self.inner.http1.half_close(val); + self + } + + /// Enables or disables HTTP/1 keep-alive. + /// + /// Default is true. + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + self.inner.http1.keep_alive(val); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.title_case_headers(enabled); + self + } + + /// Set whether HTTP/1 connections will silently ignored malformed header lines. + /// + /// If this is enabled and a header line does not start with a valid header + /// name, or does not include a colon at all, the line will be silently ignored + /// and no error will be reported. + /// + /// Default is false. + pub fn ignore_invalid_headers(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.ignore_invalid_headers(enabled); + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Request`. It will also look for and use + /// such an extension in any provided `Response`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.preserve_header_case(enabled); + self + } + + /// Set the maximum number of headers. + /// + /// When a request is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If server receives more headers than the buffer size, it responds to the client with + /// "431 Request Header Fields Too Large". + /// + /// The headers is allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each request, and there will be a performance drop of about 5%. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is 100. + pub fn max_headers(&mut self, val: usize) -> &mut Self { + self.inner.http1.max_headers(val); + self + } + + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Requires a [`Timer`] set by [`Http1Builder::timer`] to take effect. Panics if `header_read_timeout` is configured + /// without a [`Timer`]. + /// + /// Pass `None` to disable. + /// + /// Default is currently 30 seconds, but do not depend on that. + pub fn header_read_timeout(&mut self, read_timeout: impl Into<Option<Duration>>) -> &mut Self { + self.inner.http1.header_read_timeout(read_timeout); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + pub fn writev(&mut self, val: bool) -> &mut Self { + self.inner.http1.writev(val); + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + pub fn max_buf_size(&mut self, max: usize) -> &mut Self { + self.inner.http1.max_buf_size(max); + self + } + + /// Aggregates flushes to better support pipelined responses. + /// + /// Experimental, may have bugs. + /// + /// Default is false. + pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.pipeline_flush(enabled); + self + } + + /// Set the timer used in background tasks. + pub fn timer<M>(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.inner.http1.timer(timer); + self + } + + /// Bind a connection together with a [`Service`]. + #[cfg(feature = "http2")] + pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + 'static, + E: HttpServerConnExec<S::Future, B>, + { + self.inner.serve_connection(io, service).await + } + + /// Bind a connection together with a [`Service`]. + #[cfg(not(feature = "http2"))] + pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + 'static, + { + self.inner.serve_connection(io, service).await + } + + /// Bind a connection together with a [`Service`], with the ability to + /// handle HTTP upgrades. This requires that the IO object implements + /// `Send`. + #[cfg(feature = "http2")] + pub fn serve_connection_with_upgrades<I, S, B>( + &self, + io: I, + service: S, + ) -> UpgradeableConnection<'_, I, S, E> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + Send + 'static, + E: HttpServerConnExec<S::Future, B>, + { + self.inner.serve_connection_with_upgrades(io, service) + } +} + +/// Http2 part of builder. +#[cfg(feature = "http2")] +pub struct Http2Builder<'a, E> { + inner: &'a mut Builder<E>, +} + +#[cfg(feature = "http2")] +impl<E> Http2Builder<'_, E> { + #[cfg(feature = "http1")] + /// Http1 configuration. + pub fn http1(&mut self) -> Http1Builder<'_, E> { + Http1Builder { inner: self.inner } + } + + /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent. + /// + /// This will default to the default value set by the [`h2` crate](https://crates.io/crates/h2). + /// As of v0.4.0, it is 20. + /// + /// See <https://github.com/hyperium/hyper/issues/2877> for more information. + pub fn max_pending_accept_reset_streams(&mut self, max: impl Into<Option<usize>>) -> &mut Self { + self.inner.http2.max_pending_accept_reset_streams(max); + self + } + + /// Configures the maximum number of local reset streams allowed before a GOAWAY will be sent. + /// + /// If not set, hyper will use a default, currently of 1024. + /// + /// If `None` is supplied, hyper will not apply any limit. + /// This is not advised, as it can potentially expose servers to DOS vulnerabilities. + /// + /// See <https://rustsec.org/advisories/RUSTSEC-2024-0003.html> for more information. + pub fn max_local_error_reset_streams(&mut self, max: impl Into<Option<usize>>) -> &mut Self { + self.inner.http2.max_local_error_reset_streams(max); + self + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.inner.http2.initial_stream_window_size(sz); + self + } + + /// Sets the max connection-level flow control for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.inner.http2.initial_connection_window_size(sz); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self { + self.inner.http2.adaptive_window(enabled); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self { + self.inner.http2.max_frame_size(sz); + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is 200. Passing `None` will remove any limit. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self { + self.inner.http2.max_concurrent_streams(max); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + pub fn keep_alive_interval(&mut self, interval: impl Into<Option<Duration>>) -> &mut Self { + self.inner.http2.keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.inner.http2.keep_alive_timeout(timeout); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self { + self.inner.http2.max_send_buf_size(max); + self + } + + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + pub fn enable_connect_protocol(&mut self) -> &mut Self { + self.inner.http2.enable_connect_protocol(); + self + } + + /// Sets the max size of received header frames. + /// + /// Default is currently ~16MB, but may change. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.inner.http2.max_header_list_size(max); + self + } + + /// Set the timer used in background tasks. + pub fn timer<M>(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.inner.http2.timer(timer); + self + } + + /// Set whether the `date` header should be included in HTTP responses. + /// + /// Note that including the `date` header is recommended by RFC 7231. + /// + /// Default is true. + pub fn auto_date_header(&mut self, enabled: bool) -> &mut Self { + self.inner.http2.auto_date_header(enabled); + self + } + + /// Bind a connection together with a [`Service`]. + pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + 'static, + E: HttpServerConnExec<S::Future, B>, + { + self.inner.serve_connection(io, service).await + } + + /// Bind a connection together with a [`Service`], with the ability to + /// handle HTTP upgrades. This requires that the IO object implements + /// `Send`. + pub fn serve_connection_with_upgrades<I, S, B>( + &self, + io: I, + service: S, + ) -> UpgradeableConnection<'_, I, S, E> + where + S: Service<Request<Incoming>, Response = Response<B>>, + S::Future: 'static, + S::Error: Into<Box<dyn StdError + Send + Sync>>, + B: Body + 'static, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + I: Read + Write + Unpin + Send + 'static, + E: HttpServerConnExec<S::Future, B>, + { + self.inner.serve_connection_with_upgrades(io, service) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto, + }; + use http::{Request, Response}; + use http_body::Body; + use http_body_util::{BodyExt, Empty, Full}; + use hyper::{body, body::Bytes, client, service::service_fn}; + use std::{convert::Infallible, error::Error as StdError, net::SocketAddr, time::Duration}; + use tokio::{ + net::{TcpListener, TcpStream}, + pin, + }; + + const BODY: &[u8] = b"Hello, world!"; + + #[test] + fn configuration() { + // One liner. + auto::Builder::new(TokioExecutor::new()) + .http1() + .keep_alive(true) + .http2() + .keep_alive_interval(None); + // .serve_connection(io, service); + + // Using variable. + let mut builder = auto::Builder::new(TokioExecutor::new()); + + builder.http1().keep_alive(true); + builder.http2().keep_alive_interval(None); + // builder.serve_connection(io, service); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1() { + let addr = start_server(false, false).await; + let mut sender = connect_h1(addr).await; + + let response = sender + .send_request(Request::new(Empty::<Bytes>::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http2() { + let addr = start_server(false, false).await; + let mut sender = connect_h2(addr).await; + + let response = sender + .send_request(Request::new(Empty::<Bytes>::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http2_only() { + let addr = start_server(false, true).await; + let mut sender = connect_h2(addr).await; + + let response = sender + .send_request(Request::new(Empty::<Bytes>::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http2_only_fail_if_client_is_http1() { + let addr = start_server(false, true).await; + let mut sender = connect_h1(addr).await; + + let _ = sender + .send_request(Request::new(Empty::<Bytes>::new())) + .await + .expect_err("should fail"); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1_only() { + let addr = start_server(true, false).await; + let mut sender = connect_h1(addr).await; + + let response = sender + .send_request(Request::new(Empty::<Bytes>::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1_only_fail_if_client_is_http2() { + let addr = start_server(true, false).await; + let mut sender = connect_h2(addr).await; + + let _ = sender + .send_request(Request::new(Empty::<Bytes>::new())) + .await + .expect_err("should fail"); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn graceful_shutdown() { + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + + let listener_addr = listener.local_addr().unwrap(); + + // Spawn the task in background so that we can connect there + let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() }); + // Only connect a stream, do not send headers or anything + let _stream = TcpStream::connect(listener_addr).await.unwrap(); + + let (stream, _) = listen_task.await.unwrap(); + let stream = TokioIo::new(stream); + let builder = auto::Builder::new(TokioExecutor::new()); + let connection = builder.serve_connection(stream, service_fn(hello)); + + pin!(connection); + + connection.as_mut().graceful_shutdown(); + + let connection_error = tokio::time::timeout(Duration::from_millis(200), connection) + .await + .expect("Connection should have finished in a timely manner after graceful shutdown.") + .expect_err("Connection should have been interrupted."); + + let connection_error = connection_error + .downcast_ref::<std::io::Error>() + .expect("The error should have been `std::io::Error`."); + assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted); + } + + async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B> + where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (sender, connection) = client::conn::http1::handshake(stream).await.unwrap(); + + tokio::spawn(connection); + + sender + } + + async fn connect_h2<B>(addr: SocketAddr) -> client::conn::http2::SendRequest<B> + where + B: Body + Unpin + Send + 'static, + B::Data: Send, + B::Error: Into<Box<dyn StdError + Send + Sync>>, + { + let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (sender, connection) = client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(stream) + .await + .unwrap(); + + tokio::spawn(connection); + + sender + } + + async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr { + let addr: SocketAddr = ([127, 0, 0, 1], 0).into(); + let listener = TcpListener::bind(addr).await.unwrap(); + + let local_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioIo::new(stream); + tokio::task::spawn(async move { + let mut builder = auto::Builder::new(TokioExecutor::new()); + if h1_only { + builder = builder.http1_only(); + builder.serve_connection(stream, service_fn(hello)).await + } else if h2_only { + builder = builder.http2_only(); + builder.serve_connection(stream, service_fn(hello)).await + } else { + builder + .http2() + .max_header_list_size(4096) + .serve_connection_with_upgrades(stream, service_fn(hello)) + .await + } + .unwrap(); + }); + } + }); + + local_addr + } + + async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> { + Ok(Response::new(Full::new(Bytes::from(BODY)))) + } +} diff --git a/vendor/hyper-util/src/server/conn/auto/upgrade.rs b/vendor/hyper-util/src/server/conn/auto/upgrade.rs new file mode 100644 index 00000000..8d94c409 --- /dev/null +++ b/vendor/hyper-util/src/server/conn/auto/upgrade.rs @@ -0,0 +1,68 @@ +//! Upgrade utilities. + +use bytes::{Bytes, BytesMut}; +use hyper::{ + rt::{Read, Write}, + upgrade::Upgraded, +}; + +use crate::common::rewind::Rewind; + +/// Tries to downcast the internal trait object to the type passed. +/// +/// On success, returns the downcasted parts. On error, returns the Upgraded back. +/// This is a kludge to work around the fact that the machinery provided by +/// [`hyper_util::server::conn::auto`] wraps the inner `T` with a private type +/// that is not reachable from outside the crate. +/// +/// [`hyper_util::server::conn::auto`]: crate::server::conn::auto +/// +/// This kludge will be removed when this machinery is added back to the main +/// `hyper` code. +pub fn downcast<T>(upgraded: Upgraded) -> Result<Parts<T>, Upgraded> +where + T: Read + Write + Unpin + 'static, +{ + let hyper::upgrade::Parts { + io: rewind, + mut read_buf, + .. + } = upgraded.downcast::<Rewind<T>>()?; + + if let Some(pre) = rewind.pre { + read_buf = if read_buf.is_empty() { + pre + } else { + let mut buf = BytesMut::from(read_buf); + + buf.extend_from_slice(&pre); + + buf.freeze() + }; + } + + Ok(Parts { + io: rewind.inner, + read_buf, + }) +} + +/// The deconstructed parts of an [`Upgraded`] type. +/// +/// Includes the original IO type, and a read buffer of bytes that the +/// HTTP state machine may have already read before completing an upgrade. +#[derive(Debug)] +#[non_exhaustive] +pub struct Parts<T> { + /// The original IO object used before the upgrade. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, +} diff --git a/vendor/hyper-util/src/server/conn/mod.rs b/vendor/hyper-util/src/server/conn/mod.rs new file mode 100644 index 00000000..b23503a1 --- /dev/null +++ b/vendor/hyper-util/src/server/conn/mod.rs @@ -0,0 +1,4 @@ +//! Connection utilities. + +#[cfg(any(feature = "http1", feature = "http2"))] +pub mod auto; diff --git a/vendor/hyper-util/src/server/graceful.rs b/vendor/hyper-util/src/server/graceful.rs new file mode 100644 index 00000000..b367fc8a --- /dev/null +++ b/vendor/hyper-util/src/server/graceful.rs @@ -0,0 +1,488 @@ +//! Utility to gracefully shutdown a server. +//! +//! This module provides a [`GracefulShutdown`] type, +//! which can be used to gracefully shutdown a server. +//! +//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs> +//! for an example of how to use this. + +use std::{ + fmt::{self, Debug}, + future::Future, + pin::Pin, + task::{self, Poll}, +}; + +use pin_project_lite::pin_project; +use tokio::sync::watch; + +/// A graceful shutdown utility +// Purposefully not `Clone`, see `watcher()` method for why. +pub struct GracefulShutdown { + tx: watch::Sender<()>, +} + +/// A watcher side of the graceful shutdown. +/// +/// This type can only watch a connection, it cannot trigger a shutdown. +/// +/// Call [`GracefulShutdown::watcher()`] to construct one of these. +pub struct Watcher { + rx: watch::Receiver<()>, +} + +impl GracefulShutdown { + /// Create a new graceful shutdown helper. + pub fn new() -> Self { + let (tx, _) = watch::channel(()); + Self { tx } + } + + /// Wrap a future for graceful shutdown watching. + pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> { + self.watcher().watch(conn) + } + + /// Create an owned type that can watch a connection. + /// + /// This method allows created an owned type that can be sent onto another + /// task before calling [`Watcher::watch()`]. + // Internal: this function exists because `Clone` allows footguns. + // If the `tx` were cloned (or the `rx`), race conditions can happens where + // one task starting a shutdown is scheduled and interwined with a task + // starting to watch a connection, and the "watch version" is one behind. + pub fn watcher(&self) -> Watcher { + let rx = self.tx.subscribe(); + Watcher { rx } + } + + /// Signal shutdown for all watched connections. + /// + /// This returns a `Future` which will complete once all watched + /// connections have shutdown. + pub async fn shutdown(self) { + let Self { tx } = self; + + // signal all the watched futures about the change + let _ = tx.send(()); + // and then wait for all of them to complete + tx.closed().await; + } + + /// Returns the number of the watching connections. + pub fn count(&self) -> usize { + self.tx.receiver_count() + } +} + +impl Debug for GracefulShutdown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulShutdown").finish() + } +} + +impl Default for GracefulShutdown { + fn default() -> Self { + Self::new() + } +} + +impl Watcher { + /// Wrap a future for graceful shutdown watching. + pub fn watch<C: GracefulConnection>(self, conn: C) -> impl Future<Output = C::Output> { + let Watcher { mut rx } = self; + GracefulConnectionFuture::new(conn, async move { + let _ = rx.changed().await; + // hold onto the rx until the watched future is completed + rx + }) + } +} + +impl Debug for Watcher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulWatcher").finish() + } +} + +pin_project! { + struct GracefulConnectionFuture<C, F: Future> { + #[pin] + conn: C, + #[pin] + cancel: F, + #[pin] + // If cancelled, this is held until the inner conn is done. + cancelled_guard: Option<F::Output>, + } +} + +impl<C, F: Future> GracefulConnectionFuture<C, F> { + fn new(conn: C, cancel: F) -> Self { + Self { + conn, + cancel, + cancelled_guard: None, + } + } +} + +impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulConnectionFuture").finish() + } +} + +impl<C, F> Future for GracefulConnectionFuture<C, F> +where + C: GracefulConnection, + F: Future, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + if this.cancelled_guard.is_none() { + if let Poll::Ready(guard) = this.cancel.poll(cx) { + this.cancelled_guard.set(Some(guard)); + this.conn.as_mut().graceful_shutdown(); + } + } + this.conn.poll(cx) + } +} + +/// An internal utility trait as an umbrella target for all (hyper) connection +/// types that the [`GracefulShutdown`] can watch. +pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed { + /// The error type returned by the connection when used as a future. + type Error; + + /// Start a graceful shutdown process for this connection. + fn graceful_shutdown(self: Pin<&mut Self>); +} + +#[cfg(feature = "http1")] +impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S> +where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http1::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "http2")] +impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E> +where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http2::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E> +where + S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, +{ + type Error = Box<dyn std::error::Error + Send + Sync>; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<I, B, S, E> GracefulConnection + for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E> +where + S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, +{ + type Error = Box<dyn std::error::Error + Send + Sync>; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self); + } +} + +mod private { + pub trait Sealed {} + + #[cfg(feature = "http1")] + impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S> + where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + } + + #[cfg(feature = "http1")] + impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S> + where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + { + } + + #[cfg(feature = "http2")] + impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E> + where + S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, + { + } + + #[cfg(feature = "server-auto")] + impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E> + where + S: hyper::service::Service< + http::Request<hyper::body::Incoming>, + Response = http::Response<B>, + >, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, + { + } + + #[cfg(feature = "server-auto")] + impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E> + where + S: hyper::service::Service< + http::Request<hyper::body::Incoming>, + Response = http::Response<B>, + >, + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>, + { + } +} + +#[cfg(test)] +mod test { + use super::*; + use pin_project_lite::pin_project; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + pin_project! { + #[derive(Debug)] + struct DummyConnection<F> { + #[pin] + future: F, + shutdown_counter: Arc<AtomicUsize>, + } + } + + impl<F> private::Sealed for DummyConnection<F> {} + + impl<F: Future> GracefulConnection for DummyConnection<F> { + type Error = (); + + fn graceful_shutdown(self: Pin<&mut Self>) { + self.shutdown_counter.fetch_add(1, Ordering::SeqCst); + } + } + + impl<F: Future> Future for DummyConnection<F> { + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + match self.project().future.poll(cx) { + Poll::Ready(_) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + let (dummy_tx, _) = tokio::sync::broadcast::channel(1); + + for i in 1..=3 { + let mut dummy_rx = dummy_tx.subscribe(); + let shutdown_counter = shutdown_counter.clone(); + + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await; + let _ = dummy_rx.recv().await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + let _ = dummy_tx.send(()); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_delayed_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + //tokio::time::sleep(std::time::Duration::from_millis(i * 5)).await; + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_multi_per_watcher_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + let mut futures = Vec::new(); + for u in 1..=i { + let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50)); + let dummy_conn = DummyConnection { + future, + shutdown_counter: shutdown_counter.clone(), + }; + let conn = graceful.watch(dummy_conn); + futures.push(conn); + } + tokio::spawn(async move { + futures_util::future::join_all(futures).await; + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6); + } + } + } + + #[cfg(not(miri))] + #[tokio::test] + async fn test_graceful_shutdown_timeout() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let shutdown_counter = shutdown_counter.clone(); + + let future = async move { + if i == 1 { + std::future::pending::<()>().await + } else { + std::future::ready(()).await + } + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = graceful.watch(dummy_conn); + tokio::spawn(async move { + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + }, + _ = graceful.shutdown() => { + panic!("shutdown should not be completed: as not all our conns finish") + } + } + } +} diff --git a/vendor/hyper-util/src/server/mod.rs b/vendor/hyper-util/src/server/mod.rs new file mode 100644 index 00000000..a4838ac5 --- /dev/null +++ b/vendor/hyper-util/src/server/mod.rs @@ -0,0 +1,6 @@ +//! Server utilities. + +pub mod conn; + +#[cfg(feature = "server-graceful")] +pub mod graceful; diff --git a/vendor/hyper-util/src/service/glue.rs b/vendor/hyper-util/src/service/glue.rs new file mode 100644 index 00000000..ceff86f5 --- /dev/null +++ b/vendor/hyper-util/src/service/glue.rs @@ -0,0 +1,72 @@ +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use super::Oneshot; + +/// A tower [`Service`][tower-svc] converted into a hyper [`Service`][hyper-svc]. +/// +/// This wraps an inner tower service `S` in a [`hyper::service::Service`] implementation. See +/// the module-level documentation of [`service`][crate::service] for more information about using +/// [`tower`][tower] services and middleware with [`hyper`]. +/// +/// [hyper-svc]: hyper::service::Service +/// [tower]: https://docs.rs/tower/latest/tower/ +/// [tower-svc]: https://docs.rs/tower/latest/tower/trait.Service.html +#[derive(Debug, Copy, Clone)] +pub struct TowerToHyperService<S> { + service: S, +} + +impl<S> TowerToHyperService<S> { + /// Create a new [`TowerToHyperService`] from a tower service. + pub fn new(tower_service: S) -> Self { + Self { + service: tower_service, + } + } +} + +impl<S, R> hyper::service::Service<R> for TowerToHyperService<S> +where + S: tower_service::Service<R> + Clone, +{ + type Response = S::Response; + type Error = S::Error; + type Future = TowerToHyperServiceFuture<S, R>; + + fn call(&self, req: R) -> Self::Future { + TowerToHyperServiceFuture { + future: Oneshot::new(self.service.clone(), req), + } + } +} + +pin_project! { + /// Response future for [`TowerToHyperService`]. + /// + /// This future is acquired by [`call`][hyper::service::Service::call]ing a + /// [`TowerToHyperService`]. + pub struct TowerToHyperServiceFuture<S, R> + where + S: tower_service::Service<R>, + { + #[pin] + future: Oneshot<S, R>, + } +} + +impl<S, R> Future for TowerToHyperServiceFuture<S, R> +where + S: tower_service::Service<R>, +{ + type Output = Result<S::Response, S::Error>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.project().future.poll(cx) + } +} diff --git a/vendor/hyper-util/src/service/mod.rs b/vendor/hyper-util/src/service/mod.rs new file mode 100644 index 00000000..34796431 --- /dev/null +++ b/vendor/hyper-util/src/service/mod.rs @@ -0,0 +1,32 @@ +//! Service utilities. +//! +//! [`hyper::service`] provides a [`Service`][hyper-svc] trait, representing an asynchronous +//! function from a `Request` to a `Response`. This provides an interface allowing middleware for +//! network application to be written in a modular and reusable way. +//! +//! This submodule provides an assortment of utilities for working with [`Service`][hyper-svc]s. +//! See the module-level documentation of [`hyper::service`] for more information. +//! +//! # Tower +//! +//! While [`hyper`] uses its own notion of a [`Service`][hyper-svc] internally, many other +//! libraries use a library such as [`tower`][tower] to provide the fundamental model of an +//! asynchronous function. +//! +//! The [`TowerToHyperService`] type provided by this submodule can be used to bridge these +//! ecosystems together. By wrapping a [`tower::Service`][tower-svc] in [`TowerToHyperService`], +//! it can be passed into [`hyper`] interfaces that expect a [`hyper::service::Service`]. +//! +//! [hyper-svc]: hyper::service::Service +//! [tower]: https://docs.rs/tower/latest/tower/ +//! [tower-svc]: https://docs.rs/tower/latest/tower/trait.Service.html + +#[cfg(feature = "service")] +mod glue; +#[cfg(any(feature = "client-legacy", feature = "service"))] +mod oneshot; + +#[cfg(feature = "service")] +pub use self::glue::{TowerToHyperService, TowerToHyperServiceFuture}; +#[cfg(any(feature = "client-legacy", feature = "service"))] +pub(crate) use self::oneshot::Oneshot; diff --git a/vendor/hyper-util/src/service/oneshot.rs b/vendor/hyper-util/src/service/oneshot.rs new file mode 100644 index 00000000..2cc3e6e9 --- /dev/null +++ b/vendor/hyper-util/src/service/oneshot.rs @@ -0,0 +1,63 @@ +use futures_core::ready; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tower_service::Service; + +// Vendored from tower::util to reduce dependencies, the code is small enough. + +// Not really pub, but used in a trait for bounds +pin_project! { + #[project = OneshotProj] + #[derive(Debug)] + pub enum Oneshot<S: Service<Req>, Req> { + NotReady { + svc: S, + req: Option<Req>, + }, + Called { + #[pin] + fut: S::Future, + }, + Done, + } +} + +impl<S, Req> Oneshot<S, Req> +where + S: Service<Req>, +{ + pub(crate) const fn new(svc: S, req: Req) -> Self { + Oneshot::NotReady { + svc, + req: Some(req), + } + } +} + +impl<S, Req> Future for Oneshot<S, Req> +where + S: Service<Req>, +{ + type Output = Result<S::Response, S::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let this = self.as_mut().project(); + match this { + OneshotProj::NotReady { svc, req } => { + ready!(svc.poll_ready(cx))?; + let fut = svc.call(req.take().expect("already called")); + self.set(Oneshot::Called { fut }); + } + OneshotProj::Called { fut } => { + let res = ready!(fut.poll(cx))?; + self.set(Oneshot::Done); + return Poll::Ready(Ok(res)); + } + OneshotProj::Done => panic!("polled after complete"), + } + } + } +} |
