use framed codecs to avoid unbounded buffer (#33)

* using stream

* fix tests

* 💄

* 💄 fix review comments

* clean up buffered data

* 💄 fix review comments

* Refactor Delimited to be its own struct

* Add very_long_frame test to ensure behavior

Co-authored-by: Eric Zhang <ekzhang1@gmail.com>
This commit is contained in:
Prasanth
2022-04-22 09:18:38 +05:30
committed by GitHub
parent e61362915d
commit 9cd43f458a
8 changed files with 220 additions and 111 deletions

View File

@@ -6,16 +6,14 @@ use std::time::Duration;
use anyhow::Result;
use dashmap::DashMap;
use tokio::io::BufReader;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{sleep, timeout};
use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{
proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
};
use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
/// State structure for the server.
pub struct Server {
@@ -64,16 +62,16 @@ impl Server {
}
async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
let mut stream = BufReader::new(stream);
let mut stream = Delimited::new(stream);
if let Some(auth) = &self.auth {
if let Err(err) = auth.server_handshake(&mut stream).await {
warn!(%err, "server handshake failed");
send_json(&mut stream, ServerMessage::Error(err.to_string())).await?;
stream.send(ServerMessage::Error(err.to_string())).await?;
return Ok(());
}
}
match recv_json_timeout(&mut stream).await? {
match stream.recv_timeout().await? {
Some(ClientMessage::Authenticate(_)) => {
warn!("unexpected authenticate");
Ok(())
@@ -88,22 +86,17 @@ impl Server {
Ok(listener) => listener,
Err(_) => {
warn!(?port, "could not bind to local port");
send_json(
&mut stream,
ServerMessage::Error("port already in use".into()),
)
.await?;
stream
.send(ServerMessage::Error("port already in use".into()))
.await?;
return Ok(());
}
};
let port = listener.local_addr()?.port();
send_json(&mut stream, ServerMessage::Hello(port)).await?;
stream.send(ServerMessage::Hello(port)).await?;
loop {
if send_json(&mut stream, ServerMessage::Heartbeat)
.await
.is_err()
{
if stream.send(ServerMessage::Heartbeat).await.is_err() {
// Assume that the TCP connection has been dropped.
return Ok(());
}
@@ -123,14 +116,19 @@ impl Server {
warn!(%id, "removed stale connection");
}
});
send_json(&mut stream, ServerMessage::Connection(id)).await?;
stream.send(ServerMessage::Connection(id)).await?;
}
}
}
Some(ClientMessage::Accept(id)) => {
info!(%id, "forwarding connection");
match self.conns.remove(&id) {
Some((_, stream2)) => proxy(stream, stream2).await?,
Some((_, mut stream2)) => {
let parts = stream.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
stream2.write_all(&parts.read_buf).await?;
proxy(parts.io, stream2).await?
}
None => warn!(%id, "missing connection"),
}
Ok(())