use framed codecs to avoid unbounded buffer (#33)

* using stream

* fix tests

* 💄

* 💄 fix review comments

* clean up buffered data

* 💄 fix review comments

* Refactor Delimited to be its own struct

* Add very_long_frame test to ensure behavior

Co-authored-by: Eric Zhang <ekzhang1@gmail.com>
This commit is contained in:
Prasanth
2022-04-22 09:18:38 +05:30
committed by GitHub
parent e61362915d
commit 9cd43f458a
8 changed files with 220 additions and 111 deletions

View File

@@ -3,20 +3,21 @@
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use tokio::{io::BufReader, net::TcpStream, time::timeout};
use tokio::io::AsyncWriteExt;
use tokio::{net::TcpStream, time::timeout};
use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
NETWORK_TIMEOUT,
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
};
/// State structure for the client.
pub struct Client {
/// Control connection to the server.
conn: Option<BufReader<TcpStream>>,
conn: Option<Delimited<TcpStream>>,
/// Destination address of the server.
to: String,
@@ -43,15 +44,14 @@ impl Client {
port: u16,
secret: Option<&str>,
) -> Result<Self> {
let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);
let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth {
auth.client_handshake(&mut stream).await?;
}
send_json(&mut stream, ClientMessage::Hello(port)).await?;
let remote_port = match recv_json_timeout(&mut stream).await? {
stream.send(ClientMessage::Hello(port)).await?;
let remote_port = match stream.recv_timeout().await? {
Some(ServerMessage::Hello(remote_port)) => remote_port,
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
Some(ServerMessage::Challenge(_)) => {
@@ -82,10 +82,8 @@ impl Client {
pub async fn listen(mut self) -> Result<()> {
let mut conn = self.conn.take().unwrap();
let this = Arc::new(self);
let mut buf = Vec::new();
loop {
let msg = recv_json(&mut conn, &mut buf).await?;
match msg {
match conn.recv().await? {
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
Some(ServerMessage::Heartbeat) => (),
@@ -110,14 +108,16 @@ impl Client {
async fn handle_connection(&self, id: Uuid) -> Result<()> {
let mut remote_conn =
BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
if let Some(auth) = &self.auth {
auth.client_handshake(&mut remote_conn).await?;
}
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;
let local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
proxy(local_conn, remote_conn).await?;
remote_conn.send(ClientMessage::Accept(id)).await?;
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
let parts = remote_conn.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
proxy(local_conn, parts.io).await?;
Ok(())
}
}