mirror of
https://github.com/ekzhang/bore.git
synced 2025-12-16 03:47:50 +01:00
Implement the rest of the bore client
This commit is contained in:
@@ -4,22 +4,20 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::Result;
|
||||
use dashmap::DashMap;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use tokio::io::BufReader;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::{sleep, timeout};
|
||||
use tracing::{info, info_span, warn, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::shared::{proxy, ClientMessage, ServerMessage, CONTROL_PORT};
|
||||
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
|
||||
|
||||
/// State structure for the server.
|
||||
pub struct Server {
|
||||
/// The minimum TCP port that can be forwarded.
|
||||
pub min_port: u16,
|
||||
min_port: u16,
|
||||
|
||||
/// Concurrent map of IDs to incoming connections.
|
||||
conns: Arc<DashMap<Uuid, TcpStream>>,
|
||||
@@ -27,7 +25,7 @@ pub struct Server {
|
||||
|
||||
impl Server {
|
||||
/// Create a new server with a specified minimum port number.
|
||||
pub fn new(min_port: u16) -> Server {
|
||||
pub fn new(min_port: u16) -> Self {
|
||||
Server {
|
||||
min_port,
|
||||
conns: Arc::new(DashMap::new()),
|
||||
@@ -62,18 +60,28 @@ impl Server {
|
||||
let mut stream = BufReader::new(stream);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let msg = next_mp(&mut stream, &mut buf).await?;
|
||||
let msg = recv_json(&mut stream, &mut buf).await?;
|
||||
|
||||
match msg {
|
||||
Some(ClientMessage::Hello(port)) => {
|
||||
if port < self.min_port {
|
||||
if port != 0 && port < self.min_port {
|
||||
warn!(?port, "client port number too low");
|
||||
return Ok(());
|
||||
}
|
||||
info!(?port, "new client");
|
||||
let listener = TcpListener::bind(("::", port)).await?;
|
||||
let listener = match TcpListener::bind(("::", port)).await {
|
||||
Ok(listener) => listener,
|
||||
Err(_) => {
|
||||
warn!(?port, "could not bind to local port");
|
||||
send_json(&mut stream, "port already in use").await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let port = listener.local_addr()?.port();
|
||||
send_json(&mut stream, ServerMessage::Hello(port)).await?;
|
||||
|
||||
loop {
|
||||
if send_mp(&mut stream, ServerMessage::Heartbeat)
|
||||
if send_json(&mut stream, ServerMessage::Heartbeat)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
@@ -92,18 +100,18 @@ impl Server {
|
||||
// Remove stale entries to avoid memory leaks.
|
||||
sleep(Duration::from_secs(10)).await;
|
||||
if conns.remove(&id).is_some() {
|
||||
warn!(?id, "removed stale connection");
|
||||
warn!(%id, "removed stale connection");
|
||||
}
|
||||
});
|
||||
send_mp(&mut stream, ServerMessage::Connection(id)).await?;
|
||||
send_json(&mut stream, ServerMessage::Connection(id)).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ClientMessage::Accept(id)) => {
|
||||
info!(?id, "forwarding connection");
|
||||
info!(%id, "forwarding connection");
|
||||
match self.conns.remove(&id) {
|
||||
Some((_, stream2)) => proxy(stream, stream2).await?,
|
||||
None => warn!(?id, "missing connection ID"),
|
||||
None => warn!(%id, "missing connection"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -120,27 +128,3 @@ impl Default for Server {
|
||||
Server::new(1024)
|
||||
}
|
||||
}
|
||||
|
||||
/// Read the next null-delimited MessagePack instruction from a stream.
|
||||
async fn next_mp<T: DeserializeOwned>(
|
||||
reader: &mut (impl AsyncBufRead + Unpin),
|
||||
buf: &mut Vec<u8>,
|
||||
) -> Result<Option<T>> {
|
||||
buf.clear();
|
||||
reader.read_until(0, buf).await?;
|
||||
if buf.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
if buf.last() == Some(&0) {
|
||||
buf.pop();
|
||||
}
|
||||
Ok(rmp_serde::from_slice(buf).context("failed to parse MessagePack")?)
|
||||
}
|
||||
|
||||
/// Send a null-terminated MessagePack instruction on a stream.
|
||||
async fn send_mp<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
|
||||
let msg = rmp_serde::to_vec(&msg)?;
|
||||
writer.write_all(&msg).await?;
|
||||
writer.write_all(&[0]).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user