use framed codecs to avoid unbounded buffer (#33)

* using stream

* fix tests

* 💄

* 💄 fix review comments

* clean up buffered data

* 💄 fix review comments

* Refactor Delimited to be its own struct

* Add very_long_frame test to ensure behavior

Co-authored-by: Eric Zhang <ekzhang1@gmail.com>
This commit is contained in:
Prasanth
2022-04-22 09:18:38 +05:30
committed by GitHub
parent e61362915d
commit 9cd43f458a
8 changed files with 220 additions and 111 deletions

View File

@@ -3,10 +3,10 @@
use anyhow::{bail, ensure, Result};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
use tokio::io::{AsyncBufRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};
use uuid::Uuid;
use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage};
use crate::shared::{ClientMessage, Delimited, ServerMessage};
/// Wrapper around a MAC used for authenticating clients that have a secret.
pub struct Authenticator(Hmac<Sha256>);
@@ -48,13 +48,13 @@ impl Authenticator {
}
/// As the server, send a challenge to the client and validate their response.
pub async fn server_handshake(
pub async fn server_handshake<T: AsyncRead + AsyncWrite + Unpin>(
&self,
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
stream: &mut Delimited<T>,
) -> Result<()> {
let challenge = Uuid::new_v4();
send_json(stream, ServerMessage::Challenge(challenge)).await?;
match recv_json_timeout(stream).await? {
stream.send(ServerMessage::Challenge(challenge)).await?;
match stream.recv_timeout().await? {
Some(ClientMessage::Authenticate(tag)) => {
ensure!(self.validate(&challenge, &tag), "invalid secret");
Ok(())
@@ -64,16 +64,16 @@ impl Authenticator {
}
/// As the client, answer a challenge to attempt to authenticate with the server.
pub async fn client_handshake(
pub async fn client_handshake<T: AsyncRead + AsyncWrite + Unpin>(
&self,
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
stream: &mut Delimited<T>,
) -> Result<()> {
let challenge = match recv_json_timeout(stream).await? {
let challenge = match stream.recv_timeout().await? {
Some(ServerMessage::Challenge(challenge)) => challenge,
_ => bail!("expected authentication challenge, but no secret was required"),
};
let tag = self.answer(&challenge);
send_json(stream, ClientMessage::Authenticate(tag)).await?;
stream.send(ClientMessage::Authenticate(tag)).await?;
Ok(())
}
}

View File

@@ -3,20 +3,21 @@
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use tokio::{io::BufReader, net::TcpStream, time::timeout};
use tokio::io::AsyncWriteExt;
use tokio::{net::TcpStream, time::timeout};
use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
NETWORK_TIMEOUT,
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
};
/// State structure for the client.
pub struct Client {
/// Control connection to the server.
conn: Option<BufReader<TcpStream>>,
conn: Option<Delimited<TcpStream>>,
/// Destination address of the server.
to: String,
@@ -43,15 +44,14 @@ impl Client {
port: u16,
secret: Option<&str>,
) -> Result<Self> {
let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);
let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth {
auth.client_handshake(&mut stream).await?;
}
send_json(&mut stream, ClientMessage::Hello(port)).await?;
let remote_port = match recv_json_timeout(&mut stream).await? {
stream.send(ClientMessage::Hello(port)).await?;
let remote_port = match stream.recv_timeout().await? {
Some(ServerMessage::Hello(remote_port)) => remote_port,
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
Some(ServerMessage::Challenge(_)) => {
@@ -82,10 +82,8 @@ impl Client {
pub async fn listen(mut self) -> Result<()> {
let mut conn = self.conn.take().unwrap();
let this = Arc::new(self);
let mut buf = Vec::new();
loop {
let msg = recv_json(&mut conn, &mut buf).await?;
match msg {
match conn.recv().await? {
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
Some(ServerMessage::Heartbeat) => (),
@@ -110,14 +108,16 @@ impl Client {
async fn handle_connection(&self, id: Uuid) -> Result<()> {
let mut remote_conn =
BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
if let Some(auth) = &self.auth {
auth.client_handshake(&mut remote_conn).await?;
}
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;
let local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
proxy(local_conn, remote_conn).await?;
remote_conn.send(ClientMessage::Accept(id)).await?;
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
let parts = remote_conn.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
proxy(local_conn, parts.io).await?;
Ok(())
}
}

View File

@@ -6,16 +6,14 @@ use std::time::Duration;
use anyhow::Result;
use dashmap::DashMap;
use tokio::io::BufReader;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{sleep, timeout};
use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{
proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
};
use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
/// State structure for the server.
pub struct Server {
@@ -64,16 +62,16 @@ impl Server {
}
async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
let mut stream = BufReader::new(stream);
let mut stream = Delimited::new(stream);
if let Some(auth) = &self.auth {
if let Err(err) = auth.server_handshake(&mut stream).await {
warn!(%err, "server handshake failed");
send_json(&mut stream, ServerMessage::Error(err.to_string())).await?;
stream.send(ServerMessage::Error(err.to_string())).await?;
return Ok(());
}
}
match recv_json_timeout(&mut stream).await? {
match stream.recv_timeout().await? {
Some(ClientMessage::Authenticate(_)) => {
warn!("unexpected authenticate");
Ok(())
@@ -88,22 +86,17 @@ impl Server {
Ok(listener) => listener,
Err(_) => {
warn!(?port, "could not bind to local port");
send_json(
&mut stream,
ServerMessage::Error("port already in use".into()),
)
.await?;
stream
.send(ServerMessage::Error("port already in use".into()))
.await?;
return Ok(());
}
};
let port = listener.local_addr()?.port();
send_json(&mut stream, ServerMessage::Hello(port)).await?;
stream.send(ServerMessage::Hello(port)).await?;
loop {
if send_json(&mut stream, ServerMessage::Heartbeat)
.await
.is_err()
{
if stream.send(ServerMessage::Heartbeat).await.is_err() {
// Assume that the TCP connection has been dropped.
return Ok(());
}
@@ -123,14 +116,19 @@ impl Server {
warn!(%id, "removed stale connection");
}
});
send_json(&mut stream, ServerMessage::Connection(id)).await?;
stream.send(ServerMessage::Connection(id)).await?;
}
}
}
Some(ClientMessage::Accept(id)) => {
info!(%id, "forwarding connection");
match self.conns.remove(&id) {
Some((_, stream2)) => proxy(stream, stream2).await?,
Some((_, mut stream2)) => {
let parts = stream.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
stream2.write_all(&parts.read_buf).await?;
proxy(parts.io, stream2).await?
}
None => warn!(%id, "missing connection"),
}
Ok(())

View File

@@ -3,16 +3,22 @@
use std::time::Duration;
use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
use tracing::trace;
use uuid::Uuid;
/// TCP port used for control connections with the server.
pub const CONTROL_PORT: u16 = 7835;
/// Maxmium byte length for a JSON frame in the stream.
pub const MAX_FRAME_LENGTH: usize = 256;
/// Timeout for network connections and initial protocol messages.
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
@@ -48,6 +54,52 @@ pub enum ServerMessage {
Error(String),
}
/// Transport stream with JSON frames delimited by null characters.
pub struct Delimited<U>(Framed<U, AnyDelimiterCodec>);
impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
/// Construct a new delimited stream.
pub fn new(stream: U) -> Self {
let codec = AnyDelimiterCodec::new_with_max_length(vec![0], vec![0], MAX_FRAME_LENGTH);
Self(Framed::new(stream, codec))
}
/// Read the next null-delimited JSON instruction from a stream.
pub async fn recv<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
trace!("waiting to receive json message");
if let Some(next_message) = self.0.next().await {
let byte_message = next_message.context("frame error, invalid byte length")?;
let serialized_obj = serde_json::from_slice(&byte_message.to_vec())
.context("unable to parse message")?;
Ok(serialized_obj)
} else {
Ok(None)
}
}
/// Read the next null-delimited JSON instruction, with a default timeout.
///
/// This is useful for parsing the initial message of a stream for handshake or
/// other protocol purposes, where we do not want to wait indefinitely.
pub async fn recv_timeout<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
timeout(NETWORK_TIMEOUT, self.recv())
.await
.context("timed out waiting for initial message")?
}
/// Send a null-terminated JSON instruction on a stream.
pub async fn send<T: Serialize>(&mut self, msg: T) -> Result<()> {
trace!("sending json message");
self.0.send(serde_json::to_string(&msg)?).await?;
Ok(())
}
/// Consume this object, returning current buffers and the inner transport.
pub fn into_parts(self) -> FramedParts<U, AnyDelimiterCodec> {
self.0.into_parts()
}
}
/// Copy data mutually between two read/write streams.
pub async fn proxy<S1, S2>(stream1: S1, stream2: S2) -> io::Result<()>
where
@@ -62,41 +114,3 @@ where
}?;
Ok(())
}
/// Read the next null-delimited JSON instruction from a stream.
pub async fn recv_json<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin),
buf: &mut Vec<u8>,
) -> Result<Option<T>> {
trace!("waiting to receive json message");
buf.clear();
reader.read_until(0, buf).await?;
if buf.is_empty() {
return Ok(None);
}
if buf.last() == Some(&0) {
buf.pop();
}
Ok(serde_json::from_slice(buf).context("failed to parse JSON")?)
}
/// Read the next null-delimited JSON instruction, with a default timeout.
///
/// This is useful for parsing the initial message of a stream for handshake or
/// other protocol purposes, where we do not want to wait indefinitely.
pub async fn recv_json_timeout<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin),
) -> Result<Option<T>> {
timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new()))
.await
.context("timed out waiting for initial message")?
}
/// Send a null-terminated JSON instruction on a stream.
pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
trace!("sending json message");
let msg = serde_json::to_vec(&msg)?;
writer.write_all(&msg).await?;
writer.write_all(&[0]).await?;
Ok(())
}