Workaround axum's graceful shutdown limitation wrt existing websocket connections

This commit is contained in:
Marcin Kulik
2025-06-06 14:56:14 +02:00
parent 27305aa0c6
commit d8b6cfeafc
3 changed files with 44 additions and 6 deletions

14
Cargo.lock generated
View File

@@ -606,6 +606,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.31" version = "0.3.31"
@@ -626,6 +637,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr", "memchr",
@@ -2019,6 +2031,8 @@ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"futures-util",
"hashbrown",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
] ]

View File

@@ -40,7 +40,7 @@ rgb = { version = "0.8.37", default-features = false }
url = "2.5.0" url = "2.5.0"
tokio-tungstenite = { version = "0.26.2", default-features = false, features = ["connect", "rustls-tls-native-roots"] } tokio-tungstenite = { version = "0.26.2", default-features = false, features = ["connect", "rustls-tls-native-roots"] }
rustls = { version = "0.23.26", default-features = false, features = ["aws_lc_rs"] } rustls = { version = "0.23.26", default-features = false, features = ["aws_lc_rs"] }
tokio-util = "0.7.10" tokio-util = { version = "0.7.10", features = ["rt"] }
rand = "0.9.1" rand = "0.9.1"
[build-dependencies] [build-dependencies]

View File

@@ -13,8 +13,10 @@ use axum::serve::ListenerExt;
use axum::Router; use axum::Router;
use futures_util::{sink, StreamExt}; use futures_util::{sink, StreamExt};
use rust_embed::RustEmbed; use rust_embed::RustEmbed;
use tokio::time::{self, Duration};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tower_http::trace::{DefaultMakeSpan, TraceLayer}; use tower_http::trace::{DefaultMakeSpan, TraceLayer};
use tracing::info; use tracing::info;
@@ -25,6 +27,12 @@ use crate::stream::Subscriber;
#[folder = "assets/"] #[folder = "assets/"]
struct Assets; struct Assets;
#[derive(Clone)]
struct AppState {
subscriber: Subscriber,
tracker: TaskTracker,
}
pub async fn serve( pub async fn serve(
listener: std::net::TcpListener, listener: std::net::TcpListener,
subscriber: Subscriber, subscriber: Subscriber,
@@ -36,9 +44,16 @@ pub async fn serve(
let trace = let trace =
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)); TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true));
let tracker = TaskTracker::new();
let state = AppState {
subscriber,
tracker: tracker.clone(),
};
let app = Router::new() let app = Router::new()
.route("/ws", get(ws_handler)) .route("/ws", get(ws_handler))
.with_state(subscriber) .with_state(state)
.fallback(static_handler) .fallback(static_handler)
.layer(trace); .layer(trace);
@@ -55,12 +70,17 @@ pub async fn serve(
let _ = tcp_stream.set_nodelay(true); let _ = tcp_stream.set_nodelay(true);
}); });
axum::serve( let result = axum::serve(
listener, listener,
app.into_make_service_with_connect_info::<SocketAddr>(), app.into_make_service_with_connect_info::<SocketAddr>(),
) )
.with_graceful_shutdown(signal) .with_graceful_shutdown(signal)
.await .await;
tracker.close();
let _ = time::timeout(Duration::from_secs(3), tracker.wait()).await;
result
} }
async fn static_handler(uri: Uri) -> impl IntoResponse { async fn static_handler(uri: Uri) -> impl IntoResponse {
@@ -99,14 +119,18 @@ fn mime_from_path(path: &str) -> &str {
async fn ws_handler( async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(subscriber): State<Subscriber>, State(state): State<AppState>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.protocols(["v1.alis"]) ws.protocols(["v1.alis"])
.on_upgrade(move |socket| async move { .on_upgrade(move |socket| async move {
info!("websocket client {addr} connected"); info!("websocket client {addr} connected");
if socket.protocol().is_some() { if socket.protocol().is_some() {
let _ = handle_socket(socket, subscriber).await; let _ = state
.tracker
.track_future(handle_socket(socket, state.subscriber))
.await;
info!("websocket client {addr} disconnected"); info!("websocket client {addr} disconnected");
} else { } else {
info!("subprotocol negotiation failed, closing connection"); info!("subprotocol negotiation failed, closing connection");