summaryrefslogtreecommitdiff
path: root/vendor/hyper-rustls/src/stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/hyper-rustls/src/stream.rs')
-rw-r--r--vendor/hyper-rustls/src/stream.rs121
1 files changed, 121 insertions, 0 deletions
diff --git a/vendor/hyper-rustls/src/stream.rs b/vendor/hyper-rustls/src/stream.rs
new file mode 100644
index 00000000..f08e7b1b
--- /dev/null
+++ b/vendor/hyper-rustls/src/stream.rs
@@ -0,0 +1,121 @@
+// Copied from hyperium/hyper-tls#62e3376/src/stream.rs
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use hyper::rt;
+use hyper_util::client::legacy::connect::{Connected, Connection};
+
+use hyper_util::rt::TokioIo;
+use tokio_rustls::client::TlsStream;
+
+/// A stream that might be protected with TLS.
+#[allow(clippy::large_enum_variant)]
+pub enum MaybeHttpsStream<T> {
+ /// A stream over plain text.
+ Http(T),
+ /// A stream protected with TLS.
+ Https(TokioIo<TlsStream<TokioIo<T>>>),
+}
+
+impl<T: rt::Read + rt::Write + Connection + Unpin> Connection for MaybeHttpsStream<T> {
+ fn connected(&self) -> Connected {
+ match self {
+ Self::Http(s) => s.connected(),
+ Self::Https(s) => {
+ let (tcp, tls) = s.inner().get_ref();
+ if tls.alpn_protocol() == Some(b"h2") {
+ tcp.inner().connected().negotiated_h2()
+ } else {
+ tcp.inner().connected()
+ }
+ }
+ }
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self {
+ Self::Http(..) => f.pad("Http(..)"),
+ Self::Https(..) => f.pad("Https(..)"),
+ }
+ }
+}
+
+impl<T> From<T> for MaybeHttpsStream<T> {
+ fn from(inner: T) -> Self {
+ Self::Http(inner)
+ }
+}
+
+impl<T> From<TlsStream<TokioIo<T>>> for MaybeHttpsStream<T> {
+ fn from(inner: TlsStream<TokioIo<T>>) -> Self {
+ Self::Https(TokioIo::new(inner))
+ }
+}
+
+impl<T: rt::Read + rt::Write + Unpin> rt::Read for MaybeHttpsStream<T> {
+ #[inline]
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context,
+ buf: rt::ReadBufCursor<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ match Pin::get_mut(self) {
+ Self::Http(s) => Pin::new(s).poll_read(cx, buf),
+ Self::Https(s) => Pin::new(s).poll_read(cx, buf),
+ }
+ }
+}
+
+impl<T: rt::Write + rt::Read + Unpin> rt::Write for MaybeHttpsStream<T> {
+ #[inline]
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ match Pin::get_mut(self) {
+ Self::Http(s) => Pin::new(s).poll_write(cx, buf),
+ Self::Https(s) => Pin::new(s).poll_write(cx, buf),
+ }
+ }
+
+ #[inline]
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ match Pin::get_mut(self) {
+ Self::Http(s) => Pin::new(s).poll_flush(cx),
+ Self::Https(s) => Pin::new(s).poll_flush(cx),
+ }
+ }
+
+ #[inline]
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ match Pin::get_mut(self) {
+ Self::Http(s) => Pin::new(s).poll_shutdown(cx),
+ Self::Https(s) => Pin::new(s).poll_shutdown(cx),
+ }
+ }
+
+ #[inline]
+ fn is_write_vectored(&self) -> bool {
+ match self {
+ Self::Http(s) => s.is_write_vectored(),
+ Self::Https(s) => s.is_write_vectored(),
+ }
+ }
+
+ #[inline]
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[io::IoSlice<'_>],
+ ) -> Poll<Result<usize, io::Error>> {
+ match Pin::get_mut(self) {
+ Self::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs),
+ Self::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs),
+ }
+ }
+}