Workaround axum's graceful shutdown limitation wrt existing websocket connections

This commit is contained in:
Marcin Kulik
2025-06-06 14:56:14 +02:00
parent 27305aa0c6
commit d8b6cfeafc
3 changed files with 44 additions and 6 deletions

14
Cargo.lock generated
View File

@@ -606,6 +606,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@@ -626,6 +637,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@@ -2019,6 +2031,8 @@ dependencies = [
"bytes",
"futures-core",
"futures-sink",
"futures-util",
"hashbrown",
"pin-project-lite",
"tokio",
]

View File

@@ -40,7 +40,7 @@ rgb = { version = "0.8.37", default-features = false }
url = "2.5.0"
tokio-tungstenite = { version = "0.26.2", default-features = false, features = ["connect", "rustls-tls-native-roots"] }
rustls = { version = "0.23.26", default-features = false, features = ["aws_lc_rs"] }
tokio-util = "0.7.10"
tokio-util = { version = "0.7.10", features = ["rt"] }
rand = "0.9.1"
[build-dependencies]

View File

@@ -13,8 +13,10 @@ use axum::serve::ListenerExt;
use axum::Router;
use futures_util::{sink, StreamExt};
use rust_embed::RustEmbed;
use tokio::time::{self, Duration};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
use tracing::info;
@@ -25,6 +27,12 @@ use crate::stream::Subscriber;
#[folder = "assets/"]
struct Assets;
#[derive(Clone)]
struct AppState {
subscriber: Subscriber,
tracker: TaskTracker,
}
pub async fn serve(
listener: std::net::TcpListener,
subscriber: Subscriber,
@@ -36,9 +44,16 @@ pub async fn serve(
let trace =
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true));
let tracker = TaskTracker::new();
let state = AppState {
subscriber,
tracker: tracker.clone(),
};
let app = Router::new()
.route("/ws", get(ws_handler))
.with_state(subscriber)
.with_state(state)
.fallback(static_handler)
.layer(trace);
@@ -55,12 +70,17 @@ pub async fn serve(
let _ = tcp_stream.set_nodelay(true);
});
axum::serve(
let result = axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(signal)
.await
.await;
tracker.close();
let _ = time::timeout(Duration::from_secs(3), tracker.wait()).await;
result
}
async fn static_handler(uri: Uri) -> impl IntoResponse {
@@ -99,14 +119,18 @@ fn mime_from_path(path: &str) -> &str {
async fn ws_handler(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(subscriber): State<Subscriber>,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.protocols(["v1.alis"])
.on_upgrade(move |socket| async move {
info!("websocket client {addr} connected");
if socket.protocol().is_some() {
let _ = handle_socket(socket, subscriber).await;
let _ = state
.tracker
.track_future(handle_socket(socket, state.subscriber))
.await;
info!("websocket client {addr} disconnected");
} else {
info!("subprotocol negotiation failed, closing connection");