1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
#![cfg(feature = "early-data")]
use std::io::{self, Read, Write};
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;
use futures_util::{future::Future, ready};
use rustls::pki_types::ServerName;
use rustls::{self, ClientConfig, ServerConnection, Stream};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
/// Future that drains the wrapped reader one byte at a time and completes
/// once a read yields zero bytes (end of stream). Bytes read are discarded.
struct Read1<T>(T);

impl<T: AsyncRead + Unpin> Future for Read1<T> {
    type Output = io::Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut byte = [0u8; 1];
        let mut read_buf = ReadBuf::new(&mut byte);
        // Propagates `Pending` and I/O errors straight to the caller.
        ready!(Pin::new(&mut self.0).poll_read(cx, &mut read_buf))?;

        if !read_buf.filled().is_empty() {
            // A byte arrived: throw it away and ask to be polled again right away.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        // Zero bytes filled means the reader hit EOF.
        Poll::Ready(Ok(()))
    }
}
/// Opens a TLS connection to `addr` (SNI `foobar.com`) with early data enabled on
/// the connector, then writes `data` using `utils::write` (vectored or not).
/// After that it flushes, shuts down the write side, and reads everything the
/// peer sends until EOF.
///
/// Returns the TLS stream, so the caller can inspect the session, together with
/// the bytes received.
async fn send(
    config: Arc<ClientConfig>,
    addr: SocketAddr,
    data: &[u8],
    vectored: bool,
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> {
    let tcp = TcpStream::connect(&addr).await?;
    let server_name = ServerName::try_from("foobar.com").unwrap();
    let mut tls = TlsConnector::from(config)
        .early_data(true)
        .connect(server_name, tcp)
        .await?;

    utils::write(&mut tls, data, vectored).await?;
    tls.flush().await?;
    // Half-close so the server sees EOF and finishes its reply.
    tls.shutdown().await?;

    let mut received = Vec::new();
    tls.read_to_end(&mut received).await?;
    Ok((tls, received))
}
/// 0-RTT round trip where the client sends its payload with a plain write.
#[tokio::test]
async fn test_0rtt() -> io::Result<()> {
    let vectored = false;
    test_0rtt_impl(vectored).await
}

/// 0-RTT round trip where the client sends its payload with a vectored write.
#[tokio::test]
async fn test_0rtt_vectored() -> io::Result<()> {
    let vectored = true;
    test_0rtt_impl(vectored).await
}
/// Runs two client connections against a blocking rustls echo server that
/// accepts up to 8192 bytes of early data.
///
/// For each connection, the server thread does the following:
/// - If rustls exposes early data, it replies `EARLY:` followed by those bytes.
/// - It then replies `LATE:`.
/// - It echoes application data until the client half-closes.
/// - Finally it sends close_notify.
///
/// The first connection must not have early data accepted. The second, reusing
/// the same client config, must. Presumably this is because the first handshake
/// provides the resumption state 0-RTT needs; confirm against the rustls docs.
async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
    let (mut server_config, mut client_config) = utils::make_configs();
    server_config.max_early_data_size = 8192;
    let server_config = Arc::new(server_config);

    let listener = TcpListener::bind("127.0.0.1:0")?;
    let port = listener.local_addr().unwrap().port();

    // Accept loop: one OS thread per incoming connection; runs for the life of the test.
    thread::spawn(move || loop {
        let (mut sock, _peer) = listener.accept().unwrap();
        let config = Arc::clone(&server_config);
        thread::spawn(move || {
            let mut conn = ServerConnection::new(config).unwrap();
            conn.complete_io(&mut sock).unwrap();

            // Drain any early data up front, so the reader's borrow of `conn` ends here.
            let early = conn.early_data().map(|mut reader| {
                let mut bytes = Vec::new();
                reader.read_to_end(&mut bytes).unwrap();
                bytes
            });

            let mut stream = Stream::new(&mut conn, &mut sock);
            if let Some(bytes) = early {
                stream.write_all(b"EARLY:").unwrap();
                stream.write_all(&bytes).unwrap();
            }
            stream.write_all(b"LATE:").unwrap();

            // Echo until the client shuts down its write side (read returns 0).
            let mut chunk = [0; 1024];
            loop {
                let n = stream.read(&mut chunk).unwrap();
                if n == 0 {
                    break;
                }
                stream.write_all(&chunk[..n]).unwrap();
            }
            conn.send_close_notify();
            conn.complete_io(&mut sock).unwrap();
        });
    });

    client_config.enable_early_data = true;
    let client_config = Arc::new(client_config);
    let addr = SocketAddr::from(([127, 0, 0, 1], port));

    let (first_io, first_reply) =
        send(Arc::clone(&client_config), addr, b"hello", vectored).await?;
    assert!(!first_io.get_ref().1.is_early_data_accepted());
    assert_eq!("LATE:hello", String::from_utf8_lossy(&first_reply));

    let (second_io, second_reply) = send(client_config, addr, b"world!", vectored).await?;
    assert!(second_io.get_ref().1.is_early_data_accepted());
    assert_eq!("EARLY:world!LATE:", String::from_utf8_lossy(&second_reply));

    Ok(())
}
// Textually include the shared test helpers. The tests above call
// `utils::make_configs()` and `utils::write(..)` from it; presumably utils.rs
// wraps them in `mod utils { .. }` (verify against utils.rs).
include!("utils.rs");
|