From 0c7905db0cf6795a9b11a5b2b4a03388d675663f Mon Sep 17 00:00:00 2001 From: Marcin Kulik Date: Thu, 28 Mar 2024 21:26:53 +0100 Subject: [PATCH] Initial version of stream forwarder --- Cargo.lock | 85 ++++++++++++++++++++++++++++--- Cargo.toml | 2 + src/cmd/stream.rs | 18 ++++++- src/streamer/forwarder.rs | 103 ++++++++++++++++++++++++++++++++++++++ src/streamer/mod.rs | 18 ++++++- 5 files changed, 217 insertions(+), 9 deletions(-) create mode 100644 src/streamer/forwarder.rs diff --git a/Cargo.lock b/Cargo.lock index 7c6da2b..a2459e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,9 +104,11 @@ dependencies = [ "termion", "tokio", "tokio-stream", + "tokio-tungstenite", "tower-http", "tracing", "tracing-subscriber", + "url", "uuid", "which", ] @@ -930,9 +932,9 @@ dependencies = [ "futures-util", "http 0.2.11", "hyper 0.14.28", - "rustls", + "rustls 0.21.10", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", ] [[package]] @@ -1595,7 +1597,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.10", "rustls-pemfile", "serde", "serde_json", @@ -1603,14 +1605,14 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 0.25.4", "winreg", ] @@ -1717,10 +1719,24 @@ checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99008d7ad0bbbea527ec27bddbc0e432c5b87d8175178cee68d2eec9c4a1813c" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.2", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -1730,6 +1746,12 @@ dependencies = [ "base64", ] +[[package]] +name = "rustls-pki-types" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "868e20fada228fefaf6b652e00cc73623d54f8171e7352c18bb281571f2d92da" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -1740,6 +1762,17 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.102.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -2035,6 +2068,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "syn" version = "1.0.109" @@ -2209,7 +2248,18 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.10", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.3", + "rustls-pki-types", "tokio", ] @@ -2233,8 +2283,12 @@ checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" dependencies = [ "futures-util", "log", + "rustls 0.22.3", + "rustls-pki-types", "tokio", + "tokio-rustls 0.25.0", "tungstenite", + "webpki-roots 0.26.1", ] [[package]] @@ -2411,6 +2465,8 @@ dependencies = [ "httparse", "log", "rand 0.8.5", + "rustls 0.22.3", + "rustls-pki-types", "sha1", "thiserror", "url", @@ -2628,6 +2684,15 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "6.0.0" @@ -2822,3 +2887,9 @@ dependencies = [ "cfg-if", "windows-sys 0.48.0", ] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/Cargo.toml b/Cargo.toml index 099f76a..2149782 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,8 @@ tower-http = { version = "0.5.1", features = ["trace"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } rgb = "0.8.37" +url = "2.5.0" +tokio-tungstenite = { version = "0.21.0", features = ["rustls-tls-webpki-roots"] } [profile.release] strip = true diff --git a/src/cmd/stream.rs b/src/cmd/stream.rs index 48f0e2f..630c2b7 100644 --- a/src/cmd/stream.rs +++ b/src/cmd/stream.rs @@ -26,6 +26,10 @@ pub struct Cli { #[clap(short, long, default_value = "127.0.0.1:8080")] listen_addr: SocketAddr, + /// WebSocket forwarding address + #[clap(short, long, value_parser = validate_forward_url)] + forward_url: Option, + /// Override terminal size for the session #[arg(long, value_name = "COLSxROWS")] tty_size: Option, @@ -35,8 +39,19 @@ pub struct Cli { log_file: Option, } +fn validate_forward_url(s: &str) -> Result { + let url = url::Url::parse(s).map_err(|e| e.to_string())?; + let scheme = url.scheme(); + + if scheme == "ws" || scheme == "wss" { + Ok(url) + } else { + Err("must be WebSocket URL (ws:// or wss://)".to_owned()) + } +} + impl Cli { - pub fn run(self, config: &Config) -> Result<()> { + pub fn run(mut self, config: &Config) -> Result<()> { locale::check_utf8_locale()?; let command = self.get_command(config); @@ -65,6 +80,7 @@ impl Cli { let mut streamer = streamer::Streamer::new( self.listen_addr, + self.forward_url.take(), record_input, keys, notifier, diff --git a/src/streamer/forwarder.rs b/src/streamer/forwarder.rs new file mode 100644 index 0000000..14f8c1d --- /dev/null +++ b/src/streamer/forwarder.rs @@ -0,0 +1,103 @@ +use super::alis; +use super::session; +use futures_util::Sink; +use futures_util::{sink, SinkExt, Stream, StreamExt}; +use std::time::{Duration, Instant}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; +use tokio_stream::wrappers::IntervalStream; +use tokio_tungstenite::tungstenite; +use tokio_tungstenite::tungstenite::Message; +use tracing::{debug, info}; + +const WS_PING_INTERVAL: u64 = 15; +const MAX_RECONNECT_DELAY: u64 = 5000; + +pub async fn forward( + clients_tx: mpsc::Sender, + url: url::Url, +) -> anyhow::Result<()> { + let mut reconnect_attempt = 0; + + info!("forwarding to {url}"); + + loop { + let time = Instant::now(); + + match forward_once(&clients_tx, &url).await { + Ok(_) => return Ok(()), + Err(e) => debug!("{e:?}"), + } + + if time.elapsed().as_secs_f32() > 1.0 { + reconnect_attempt = 0; + } + + let delay = exponential_delay(reconnect_attempt); + reconnect_attempt = (reconnect_attempt + 1).min(10); + info!("connection closed, reconnecting in {delay}"); + tokio::time::sleep(Duration::from_millis(delay)).await; + } +} + +async fn forward_once( + clients_tx: &mpsc::Sender, + url: &url::Url, +) -> anyhow::Result<()> { + let (ws, _) = tokio_tungstenite::connect_async(url).await?; + info!("connected to the endpoint"); + let (sink, stream) = ws.split(); + let drainer = tokio::spawn(stream.map(Ok).forward(sink::drain())); + let events = alis::stream(clients_tx).await?.map(ws_result); + let result = forward_with_pings(events, sink).await; + drainer.abort(); + + result +} + +async fn forward_with_pings(events: T, mut sink: U) -> anyhow::Result<()> +where + T: Stream> + Unpin, + U: Sink + Unpin, + ::Error: Into, +{ + let mut events = events.fuse(); + let mut pings = ping_stream().fuse(); + + loop { + futures_util::select! { + event = events.next() => { + match event { + Some(event) => { + sink.send(event?).await.map_err(|e| e.into())?; + } + + None => return Ok(()) + } + }, + + ping = pings.next() => { + sink.send(ping.unwrap()).await.map_err(|e| e.into())?; + } + } + } +} + +fn exponential_delay(attempt: usize) -> u64 { + (2_u64.pow(attempt as u32) * 500).min(MAX_RECONNECT_DELAY) +} + +fn ws_result(m: Result, BroadcastStreamRecvError>) -> anyhow::Result { + match m { + Ok(bytes) => Ok(tungstenite::Message::binary(bytes)), + Err(e) => Err(anyhow::anyhow!(e)), + } +} + +fn ping_stream() -> impl Stream { + let interval = tokio::time::interval(Duration::from_secs(WS_PING_INTERVAL)); + + IntervalStream::new(interval) + .skip(1) + .map(|_| tungstenite::Message::Ping(vec![])) +} diff --git a/src/streamer/mod.rs b/src/streamer/mod.rs index fc297b8..9976a90 100644 --- a/src/streamer/mod.rs +++ b/src/streamer/mod.rs @@ -1,4 +1,5 @@ mod alis; +mod forwarder; mod server; mod session; use crate::config::Key; @@ -27,6 +28,7 @@ pub struct Streamer { paused: bool, prefix_mode: bool, listen_addr: net::SocketAddr, + forward_url: Option, theme: Option, } @@ -39,6 +41,7 @@ enum Event { impl Streamer { pub fn new( listen_addr: net::SocketAddr, + forward_url: Option, record_input: bool, keys: KeyBindings, notifier: Box, @@ -61,6 +64,7 @@ impl Streamer { paused: false, prefix_mode: false, listen_addr, + forward_url, theme, } } @@ -86,7 +90,18 @@ impl pty::Recorder for Streamer { let (server_shutdown_tx, server_shutdown_rx) = oneshot::channel::<()>(); let listener = TcpListener::bind(self.listen_addr)?; let runtime = build_tokio_runtime(); - let server = runtime.spawn(server::serve(listener, clients_tx, server_shutdown_rx)); + + let server = runtime.spawn(server::serve( + listener, + clients_tx.clone(), + server_shutdown_rx, + )); + + let forwarder = self + .forward_url + .take() + .map(|url| runtime.spawn(forwarder::forward(clients_tx, url))); + let theme = self.theme.take(); self.event_loop_handle = wrap_thread_handle(thread::spawn(move || { @@ -95,6 +110,7 @@ impl pty::Recorder for Streamer { let _ = server_shutdown_tx.send(()); let _ = server.await; let _ = clients_rx.recv().await; + let _ = forwarder.map(|task| task.abort()); }); }));