diff --git a/Cargo.lock b/Cargo.lock index 5022106..5c91e8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-tungstenite", + "tokio-util", "tower-http", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 6ed48ce..000c9ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ rgb = "0.8.37" url = "2.5.0" tokio-tungstenite = { version = "0.21.0", features = ["rustls-tls-webpki-roots"] } sha2 = "0.10.8" +tokio-util = "0.7.10" [profile.release] strip = true diff --git a/src/streamer/forwarder.rs b/src/streamer/forwarder.rs index a0d4215..0d0ef31 100644 --- a/src/streamer/forwarder.rs +++ b/src/streamer/forwarder.rs @@ -5,7 +5,7 @@ use futures_util::{future, stream, SinkExt, Stream, StreamExt}; use std::borrow::Cow; use std::time::Duration; use tokio::net::TcpStream; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::mpsc; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::IntervalStream; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; @@ -21,7 +21,7 @@ pub async fn forward( url: url::Url, clients_tx: mpsc::Sender, notifier_tx: std::sync::mpsc::Sender, - mut shutdown_rx: broadcast::Receiver<()>, + shutdown_token: tokio_util::sync::CancellationToken, ) { info!("forwarding to {url}"); let mut reconnect_attempt = 0; @@ -79,11 +79,9 @@ pub async fn forward( tokio::select! { _ = tokio::time::sleep(Duration::from_millis(delay)) => (), - _ = shutdown_rx.recv() => break + _ = shutdown_token.cancelled() => break } } - - info!("shutting down"); } async fn connect_and_forward( @@ -124,11 +122,7 @@ where event = events.next() => { match event { Some(event) => sink.send(event?).await?, - - None => { - info!("session ended"); - return Ok(true); - } + None => return Ok(true) } }, diff --git a/src/streamer/mod.rs b/src/streamer/mod.rs index d27ad94..ff3b861 100644 --- a/src/streamer/mod.rs +++ b/src/streamer/mod.rs @@ -12,7 +12,7 @@ use std::net; use std::thread; use std::time::Duration; use std::time::Instant; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::mpsc; use tracing::info; pub struct Streamer { @@ -89,14 +89,14 @@ impl pty::Recorder for Streamer { fn start(&mut self, tty_size: tty::TtySize) -> io::Result<()> { let pty_rx = self.pty_rx.take().unwrap(); let (clients_tx, mut clients_rx) = mpsc::channel(1); - let (shutdown_tx, _shutdown_rx) = broadcast::channel::<()>(1); + let shutdown_token = tokio_util::sync::CancellationToken::new(); let runtime = build_tokio_runtime(); let server = self.listener.take().map(|listener| { runtime.spawn(server::serve( listener, clients_tx.clone(), - shutdown_tx.subscribe(), + shutdown_token.clone(), )) }); @@ -105,7 +105,7 @@ impl pty::Recorder for Streamer { url, clients_tx, self.notifier_tx.clone(), - shutdown_tx.subscribe(), + shutdown_token.clone(), )) }); @@ -114,7 +114,8 @@ impl pty::Recorder for Streamer { self.event_loop_handle = wrap_thread_handle(thread::spawn(move || { runtime.block_on(async move { event_loop(pty_rx, &mut clients_rx, tty_size, theme).await; - let _ = shutdown_tx.send(()); + info!("shutting down"); + shutdown_token.cancel(); if let Some(task) = server { let _ = tokio::time::timeout(Duration::from_secs(5), task).await; diff --git a/src/streamer/server.rs b/src/streamer/server.rs index e76b44b..c43e942 100644 --- a/src/streamer/server.rs +++ b/src/streamer/server.rs @@ -16,7 +16,7 @@ use std::borrow::Cow; use std::future; use std::io; use std::net::SocketAddr; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::mpsc; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tower_http::trace; use tracing::info; @@ -28,7 +28,7 @@ struct Assets; pub async fn serve( listener: std::net::TcpListener, clients_tx: mpsc::Sender, - mut shutdown_rx: broadcast::Receiver<()>, + shutdown_token: tokio_util::sync::CancellationToken, ) -> io::Result<()> { listener.set_nonblocking(true)?; let listener = tokio::net::TcpListener::from_std(listener)?; @@ -43,7 +43,7 @@ pub async fn serve( .layer(trace); let signal = async move { - let _ = shutdown_rx.recv().await; + let _ = shutdown_token.cancelled().await; }; info!(