mirror of
https://github.com/asciinema/asciinema.git
synced 2025-12-16 03:38:03 +01:00
Workaround axum's graceful shutdown limitation wrt existing websocket connections
This commit is contained in:
14
Cargo.lock
generated
14
Cargo.lock
generated
@@ -606,6 +606,17 @@ version = "0.3.31"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
|
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.31"
|
version = "0.3.31"
|
||||||
@@ -626,6 +637,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-io",
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"memchr",
|
"memchr",
|
||||||
@@ -2019,6 +2031,8 @@ dependencies = [
|
|||||||
"bytes",
|
"bytes",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
|
"futures-util",
|
||||||
|
"hashbrown",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ rgb = { version = "0.8.37", default-features = false }
|
|||||||
url = "2.5.0"
|
url = "2.5.0"
|
||||||
tokio-tungstenite = { version = "0.26.2", default-features = false, features = ["connect", "rustls-tls-native-roots"] }
|
tokio-tungstenite = { version = "0.26.2", default-features = false, features = ["connect", "rustls-tls-native-roots"] }
|
||||||
rustls = { version = "0.23.26", default-features = false, features = ["aws_lc_rs"] }
|
rustls = { version = "0.23.26", default-features = false, features = ["aws_lc_rs"] }
|
||||||
tokio-util = "0.7.10"
|
tokio-util = { version = "0.7.10", features = ["rt"] }
|
||||||
rand = "0.9.1"
|
rand = "0.9.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|||||||
@@ -13,8 +13,10 @@ use axum::serve::ListenerExt;
|
|||||||
use axum::Router;
|
use axum::Router;
|
||||||
use futures_util::{sink, StreamExt};
|
use futures_util::{sink, StreamExt};
|
||||||
use rust_embed::RustEmbed;
|
use rust_embed::RustEmbed;
|
||||||
|
use tokio::time::{self, Duration};
|
||||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tokio_util::task::TaskTracker;
|
||||||
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
|
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
@@ -25,6 +27,12 @@ use crate::stream::Subscriber;
|
|||||||
#[folder = "assets/"]
|
#[folder = "assets/"]
|
||||||
struct Assets;
|
struct Assets;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {
|
||||||
|
subscriber: Subscriber,
|
||||||
|
tracker: TaskTracker,
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn serve(
|
pub async fn serve(
|
||||||
listener: std::net::TcpListener,
|
listener: std::net::TcpListener,
|
||||||
subscriber: Subscriber,
|
subscriber: Subscriber,
|
||||||
@@ -36,9 +44,16 @@ pub async fn serve(
|
|||||||
let trace =
|
let trace =
|
||||||
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true));
|
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true));
|
||||||
|
|
||||||
|
let tracker = TaskTracker::new();
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
subscriber,
|
||||||
|
tracker: tracker.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/ws", get(ws_handler))
|
.route("/ws", get(ws_handler))
|
||||||
.with_state(subscriber)
|
.with_state(state)
|
||||||
.fallback(static_handler)
|
.fallback(static_handler)
|
||||||
.layer(trace);
|
.layer(trace);
|
||||||
|
|
||||||
@@ -55,12 +70,17 @@ pub async fn serve(
|
|||||||
let _ = tcp_stream.set_nodelay(true);
|
let _ = tcp_stream.set_nodelay(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
axum::serve(
|
let result = axum::serve(
|
||||||
listener,
|
listener,
|
||||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||||
)
|
)
|
||||||
.with_graceful_shutdown(signal)
|
.with_graceful_shutdown(signal)
|
||||||
.await
|
.await;
|
||||||
|
|
||||||
|
tracker.close();
|
||||||
|
let _ = time::timeout(Duration::from_secs(3), tracker.wait()).await;
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn static_handler(uri: Uri) -> impl IntoResponse {
|
async fn static_handler(uri: Uri) -> impl IntoResponse {
|
||||||
@@ -99,14 +119,18 @@ fn mime_from_path(path: &str) -> &str {
|
|||||||
async fn ws_handler(
|
async fn ws_handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
State(subscriber): State<Subscriber>,
|
State(state): State<AppState>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
ws.protocols(["v1.alis"])
|
ws.protocols(["v1.alis"])
|
||||||
.on_upgrade(move |socket| async move {
|
.on_upgrade(move |socket| async move {
|
||||||
info!("websocket client {addr} connected");
|
info!("websocket client {addr} connected");
|
||||||
|
|
||||||
if socket.protocol().is_some() {
|
if socket.protocol().is_some() {
|
||||||
let _ = handle_socket(socket, subscriber).await;
|
let _ = state
|
||||||
|
.tracker
|
||||||
|
.track_future(handle_socket(socket, state.subscriber))
|
||||||
|
.await;
|
||||||
|
|
||||||
info!("websocket client {addr} disconnected");
|
info!("websocket client {addr} disconnected");
|
||||||
} else {
|
} else {
|
||||||
info!("subprotocol negotiation failed, closing connection");
|
info!("subprotocol negotiation failed, closing connection");
|
||||||
|
|||||||
Reference in New Issue
Block a user