Implement the rest of the bore client

This commit is contained in:
Eric Zhang
2022-04-06 02:08:01 -04:00
parent fe1c8ad0e9
commit 599926d19c
6 changed files with 245 additions and 76 deletions

View File

@@ -4,22 +4,20 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use anyhow::Result;
use dashmap::DashMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::io::BufReader;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{sleep, timeout};
use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::shared::{proxy, ClientMessage, ServerMessage, CONTROL_PORT};
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
/// State structure for the server.
pub struct Server {
/// The minimum TCP port that can be forwarded.
pub min_port: u16,
min_port: u16,
/// Concurrent map of IDs to incoming connections.
conns: Arc<DashMap<Uuid, TcpStream>>,
@@ -27,7 +25,7 @@ pub struct Server {
impl Server {
/// Create a new server with a specified minimum port number.
pub fn new(min_port: u16) -> Server {
pub fn new(min_port: u16) -> Self {
Server {
min_port,
conns: Arc::new(DashMap::new()),
@@ -62,18 +60,28 @@ impl Server {
let mut stream = BufReader::new(stream);
let mut buf = Vec::new();
let msg = next_mp(&mut stream, &mut buf).await?;
let msg = recv_json(&mut stream, &mut buf).await?;
match msg {
Some(ClientMessage::Hello(port)) => {
if port < self.min_port {
if port != 0 && port < self.min_port {
warn!(?port, "client port number too low");
return Ok(());
}
info!(?port, "new client");
let listener = TcpListener::bind(("::", port)).await?;
let listener = match TcpListener::bind(("::", port)).await {
Ok(listener) => listener,
Err(_) => {
warn!(?port, "could not bind to local port");
send_json(&mut stream, "port already in use").await?;
return Ok(());
}
};
let port = listener.local_addr()?.port();
send_json(&mut stream, ServerMessage::Hello(port)).await?;
loop {
if send_mp(&mut stream, ServerMessage::Heartbeat)
if send_json(&mut stream, ServerMessage::Heartbeat)
.await
.is_err()
{
@@ -92,18 +100,18 @@ impl Server {
// Remove stale entries to avoid memory leaks.
sleep(Duration::from_secs(10)).await;
if conns.remove(&id).is_some() {
warn!(?id, "removed stale connection");
warn!(%id, "removed stale connection");
}
});
send_mp(&mut stream, ServerMessage::Connection(id)).await?;
send_json(&mut stream, ServerMessage::Connection(id)).await?;
}
}
}
Some(ClientMessage::Accept(id)) => {
info!(?id, "forwarding connection");
info!(%id, "forwarding connection");
match self.conns.remove(&id) {
Some((_, stream2)) => proxy(stream, stream2).await?,
None => warn!(?id, "missing connection ID"),
None => warn!(%id, "missing connection"),
}
Ok(())
}
@@ -120,27 +128,3 @@ impl Default for Server {
Server::new(1024)
}
}
/// Read the next null-delimited MessagePack instruction from a stream.
async fn next_mp<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin),
buf: &mut Vec<u8>,
) -> Result<Option<T>> {
buf.clear();
reader.read_until(0, buf).await?;
if buf.is_empty() {
return Ok(None);
}
if buf.last() == Some(&0) {
buf.pop();
}
Ok(rmp_serde::from_slice(buf).context("failed to parse MessagePack")?)
}
/// Send a null-terminated MessagePack instruction on a stream.
async fn send_mp<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
let msg = rmp_serde::to_vec(&msg)?;
writer.write_all(&msg).await?;
writer.write_all(&[0]).await?;
Ok(())
}