mirror of
https://github.com/ekzhang/bore.git
synced 2025-12-15 19:37:47 +01:00
183 lines
7.0 KiB
Rust
183 lines
7.0 KiB
Rust
//! Server implementation for the `bore` service.
|
|
|
|
use std::{io, ops::RangeInclusive, sync::Arc, time::Duration};
|
|
use std::net::IpAddr;
|
|
|
|
use anyhow::Result;
|
|
use dashmap::DashMap;
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::time::{sleep, timeout};
|
|
use tracing::{info, info_span, warn, Instrument};
|
|
use uuid::Uuid;
|
|
|
|
use crate::auth::Authenticator;
|
|
use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
|
|
|
/// State structure for the server.
|
|
pub struct Server {
|
|
/// Range of TCP ports that can be forwarded.
|
|
port_range: RangeInclusive<u16>,
|
|
|
|
/// Optional secret used to authenticate clients.
|
|
auth: Option<Authenticator>,
|
|
|
|
/// Concurrent map of IDs to incoming connections.
|
|
conns: Arc<DashMap<Uuid, TcpStream>>,
|
|
|
|
/// IP address where the control server will bind to. Bore clients must reach this.
|
|
bind_addr: IpAddr,
|
|
|
|
/// IP address where tunnels will listen on.
|
|
bind_tunnels: IpAddr,
|
|
}
|
|
|
|
impl Server {
|
|
/// Create a new server with a specified minimum port number.
|
|
pub fn new(
|
|
port_range: RangeInclusive<u16>,
|
|
secret: Option<&str>,
|
|
bind_addr: IpAddr,
|
|
bind_tunnels: IpAddr,
|
|
) -> Self {
|
|
assert!(!port_range.is_empty(), "must provide at least one port");
|
|
Server {
|
|
port_range,
|
|
conns: Arc::new(DashMap::new()),
|
|
auth: secret.map(Authenticator::new),
|
|
bind_addr,
|
|
bind_tunnels,
|
|
}
|
|
}
|
|
|
|
/// Start the server, listening for new connections.
|
|
pub async fn listen(self) -> Result<()> {
|
|
let this = Arc::new(self);
|
|
let listener = TcpListener::bind((this.bind_addr, CONTROL_PORT)).await?;
|
|
info!(addr = ?this.bind_addr, port = CONTROL_PORT, "server listening");
|
|
|
|
loop {
|
|
let (stream, addr) = listener.accept().await?;
|
|
let this = Arc::clone(&this);
|
|
tokio::spawn(
|
|
async move {
|
|
info!("incoming connection");
|
|
if let Err(err) = this.handle_connection(stream).await {
|
|
warn!(%err, "connection exited with error");
|
|
} else {
|
|
info!("connection exited");
|
|
}
|
|
}
|
|
.instrument(info_span!("control", ?addr)),
|
|
);
|
|
}
|
|
}
|
|
|
|
async fn create_listener(&self, port: u16) -> Result<TcpListener, &'static str> {
|
|
let try_bind = |port: u16| async move {
|
|
TcpListener::bind((self.bind_tunnels, port))
|
|
.await
|
|
.map_err(|err| match err.kind() {
|
|
io::ErrorKind::AddrInUse => "port already in use",
|
|
io::ErrorKind::PermissionDenied => "permission denied",
|
|
_ => "failed to bind to port",
|
|
})
|
|
};
|
|
if port > 0 {
|
|
// Client requests a specific port number.
|
|
if !self.port_range.contains(&port) {
|
|
return Err("client port number not in allowed range");
|
|
}
|
|
try_bind(port).await
|
|
} else {
|
|
// Client requests any available port in range.
|
|
//
|
|
// In this case, we bind to 150 random port numbers. We choose this value because in
|
|
// order to find a free port with probability at least 1-δ, when ε proportion of the
|
|
// ports are currently available, it suffices to check approximately -2 ln(δ) / ε
|
|
// independently and uniformly chosen ports (up to a second-order term in ε).
|
|
//
|
|
// Checking 150 times gives us 99.999% success at utilizing 85% of ports under these
|
|
// conditions, when ε=0.15 and δ=0.00001.
|
|
for _ in 0..150 {
|
|
let port = fastrand::u16(self.port_range.clone());
|
|
match try_bind(port).await {
|
|
Ok(listener) => return Ok(listener),
|
|
Err(_) => continue,
|
|
}
|
|
}
|
|
Err("failed to find an available port")
|
|
}
|
|
}
|
|
|
|
async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
|
|
let mut stream = Delimited::new(stream);
|
|
if let Some(auth) = &self.auth {
|
|
if let Err(err) = auth.server_handshake(&mut stream).await {
|
|
warn!(%err, "server handshake failed");
|
|
stream.send(ServerMessage::Error(err.to_string())).await?;
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
match stream.recv_timeout().await? {
|
|
Some(ClientMessage::Authenticate(_)) => {
|
|
warn!("unexpected authenticate");
|
|
Ok(())
|
|
}
|
|
Some(ClientMessage::Hello(port)) => {
|
|
let listener = match self.create_listener(port).await {
|
|
Ok(listener) => listener,
|
|
Err(err) => {
|
|
stream.send(ServerMessage::Error(err.into())).await?;
|
|
return Ok(());
|
|
}
|
|
};
|
|
let host = listener.local_addr()?.ip();
|
|
let port = listener.local_addr()?.port();
|
|
info!(?host, ?port, "new client");
|
|
stream.send(ServerMessage::Hello(port)).await?;
|
|
|
|
loop {
|
|
if stream.send(ServerMessage::Heartbeat).await.is_err() {
|
|
// Assume that the TCP connection has been dropped.
|
|
return Ok(());
|
|
}
|
|
const TIMEOUT: Duration = Duration::from_millis(500);
|
|
if let Ok(result) = timeout(TIMEOUT, listener.accept()).await {
|
|
let (stream2, addr) = result?;
|
|
info!(?addr, ?port, "new connection");
|
|
|
|
let id = Uuid::new_v4();
|
|
let conns = Arc::clone(&self.conns);
|
|
|
|
conns.insert(id, stream2);
|
|
tokio::spawn(async move {
|
|
// Remove stale entries to avoid memory leaks.
|
|
sleep(Duration::from_secs(10)).await;
|
|
if conns.remove(&id).is_some() {
|
|
warn!(%id, "removed stale connection");
|
|
}
|
|
});
|
|
stream.send(ServerMessage::Connection(id)).await?;
|
|
}
|
|
}
|
|
}
|
|
Some(ClientMessage::Accept(id)) => {
|
|
info!(%id, "forwarding connection");
|
|
match self.conns.remove(&id) {
|
|
Some((_, mut stream2)) => {
|
|
let parts = stream.into_parts();
|
|
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
|
stream2.write_all(&parts.read_buf).await?;
|
|
proxy(parts.io, stream2).await?
|
|
}
|
|
None => warn!(%id, "missing connection"),
|
|
}
|
|
Ok(())
|
|
}
|
|
None => Ok(()),
|
|
}
|
|
}
|
|
}
|