Fix stream session shutdown when forwarder is in a reconnection loop

This commit is contained in:
Marcin Kulik
2024-04-02 22:08:57 +02:00
parent 0b282a5737
commit 0d7951b54c
3 changed files with 28 additions and 13 deletions

View File

@@ -6,7 +6,7 @@ use futures_util::Sink;
use futures_util::{sink, SinkExt, Stream, StreamExt};
use std::borrow::Cow;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::IntervalStream;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
@@ -20,6 +20,7 @@ const MAX_RECONNECT_DELAY: u64 = 5000;
pub async fn forward(
clients_tx: mpsc::Sender<session::Client>,
url: url::Url,
mut shutdown_rx: broadcast::Receiver<()>,
) -> anyhow::Result<()> {
let mut reconnect_attempt = 0;
@@ -39,9 +40,19 @@ pub async fn forward(
let delay = exponential_delay(reconnect_attempt);
reconnect_attempt = (reconnect_attempt + 1).min(10);
info!("connection closed, reconnecting in {delay}");
tokio::time::sleep(Duration::from_millis(delay)).await;
info!("connection error, reconnecting in {delay}");
tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(delay)) => (),
_ = shutdown_rx.recv() => {
info!("shutting down");
break;
}
}
}
Ok(())
}
async fn forward_once(

View File

@@ -11,7 +11,7 @@ use std::io;
use std::net::{self, TcpListener};
use std::thread;
use std::time::Instant;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{broadcast, mpsc};
use tracing::info;
pub struct Streamer {
@@ -87,30 +87,34 @@ impl pty::Recorder for Streamer {
fn start(&mut self, tty_size: tty::TtySize) -> io::Result<()> {
let pty_rx = self.pty_rx.take().unwrap();
let (clients_tx, mut clients_rx) = mpsc::channel(1);
let (server_shutdown_tx, server_shutdown_rx) = oneshot::channel::<()>();
let (shutdown_tx, _shutdown_rx) = broadcast::channel::<()>(1);
let listener = TcpListener::bind(self.listen_addr)?;
let runtime = build_tokio_runtime();
let server = runtime.spawn(server::serve(
listener,
clients_tx.clone(),
server_shutdown_rx,
shutdown_tx.subscribe(),
));
let forwarder = self
.forward_url
.take()
.map(|url| runtime.spawn(forwarder::forward(clients_tx, url)));
.map(|url| runtime.spawn(forwarder::forward(clients_tx, url, shutdown_tx.subscribe())));
let theme = self.theme.take();
self.event_loop_handle = wrap_thread_handle(thread::spawn(move || {
runtime.block_on(async move {
event_loop(pty_rx, &mut clients_rx, tty_size, theme).await;
let _ = server_shutdown_tx.send(());
let _ = shutdown_tx.send(());
let _ = server.await;
if let Some(task) = forwarder {
let _ = task.await;
}
let _ = clients_rx.recv().await;
let _ = forwarder.map(|task| task.abort());
});
}));

View File

@@ -16,7 +16,7 @@ use std::borrow::Cow;
use std::future;
use std::io;
use std::net::SocketAddr;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tower_http::trace;
use tracing::info;
@@ -28,7 +28,7 @@ struct Assets;
pub async fn serve(
listener: std::net::TcpListener,
clients_tx: mpsc::Sender<session::Client>,
shutdown_rx: oneshot::Receiver<()>,
mut shutdown_rx: broadcast::Receiver<()>,
) -> io::Result<()> {
listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(listener)?;
@@ -42,8 +42,8 @@ pub async fn serve(
.fallback(static_handler)
.layer(trace);
let signal = async {
let _ = shutdown_rx.await;
let signal = async move {
let _ = shutdown_rx.recv().await;
};
info!(