Use CancellationToken for streamer shutdown

This commit is contained in:
Marcin Kulik
2024-04-17 20:45:56 +02:00
parent 4158b61eca
commit 9984f097b1
5 changed files with 15 additions and 18 deletions

1
Cargo.lock generated
View File

@@ -106,6 +106,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tokio-tungstenite",
"tokio-util",
"tower-http",
"tracing",
"tracing-subscriber",

View File

@@ -39,6 +39,7 @@ rgb = "0.8.37"
url = "2.5.0"
tokio-tungstenite = { version = "0.21.0", features = ["rustls-tls-webpki-roots"] }
sha2 = "0.10.8"
tokio-util = "0.7.10"
[profile.release]
strip = true

View File

@@ -5,7 +5,7 @@ use futures_util::{future, stream, SinkExt, Stream, StreamExt};
use std::borrow::Cow;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::IntervalStream;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
@@ -21,7 +21,7 @@ pub async fn forward(
url: url::Url,
clients_tx: mpsc::Sender<session::Client>,
notifier_tx: std::sync::mpsc::Sender<String>,
mut shutdown_rx: broadcast::Receiver<()>,
shutdown_token: tokio_util::sync::CancellationToken,
) {
info!("forwarding to {url}");
let mut reconnect_attempt = 0;
@@ -79,11 +79,9 @@ pub async fn forward(
tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(delay)) => (),
_ = shutdown_rx.recv() => break
_ = shutdown_token.cancelled() => break
}
}
info!("shutting down");
}
async fn connect_and_forward(
@@ -124,11 +122,7 @@ where
event = events.next() => {
match event {
Some(event) => sink.send(event?).await?,
None => {
info!("session ended");
return Ok(true);
}
None => return Ok(true)
}
},

View File

@@ -12,7 +12,7 @@ use std::net;
use std::thread;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tracing::info;
pub struct Streamer {
@@ -89,14 +89,14 @@ impl pty::Recorder for Streamer {
fn start(&mut self, tty_size: tty::TtySize) -> io::Result<()> {
let pty_rx = self.pty_rx.take().unwrap();
let (clients_tx, mut clients_rx) = mpsc::channel(1);
let (shutdown_tx, _shutdown_rx) = broadcast::channel::<()>(1);
let shutdown_token = tokio_util::sync::CancellationToken::new();
let runtime = build_tokio_runtime();
let server = self.listener.take().map(|listener| {
runtime.spawn(server::serve(
listener,
clients_tx.clone(),
shutdown_tx.subscribe(),
shutdown_token.clone(),
))
});
@@ -105,7 +105,7 @@ impl pty::Recorder for Streamer {
url,
clients_tx,
self.notifier_tx.clone(),
shutdown_tx.subscribe(),
shutdown_token.clone(),
))
});
@@ -114,7 +114,8 @@ impl pty::Recorder for Streamer {
self.event_loop_handle = wrap_thread_handle(thread::spawn(move || {
runtime.block_on(async move {
event_loop(pty_rx, &mut clients_rx, tty_size, theme).await;
let _ = shutdown_tx.send(());
info!("shutting down");
shutdown_token.cancel();
if let Some(task) = server {
let _ = tokio::time::timeout(Duration::from_secs(5), task).await;

View File

@@ -16,7 +16,7 @@ use std::borrow::Cow;
use std::future;
use std::io;
use std::net::SocketAddr;
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tower_http::trace;
use tracing::info;
@@ -28,7 +28,7 @@ struct Assets;
pub async fn serve(
listener: std::net::TcpListener,
clients_tx: mpsc::Sender<session::Client>,
mut shutdown_rx: broadcast::Receiver<()>,
shutdown_token: tokio_util::sync::CancellationToken,
) -> io::Result<()> {
listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(listener)?;
@@ -43,7 +43,7 @@ pub async fn serve(
.layer(trace);
let signal = async move {
let _ = shutdown_rx.recv().await;
let _ = shutdown_token.cancelled().await;
};
info!(