mirror of
https://github.com/asciinema/asciinema.git
synced 2026-05-18 05:04:50 +02:00
324 lines
10 KiB
Rust
324 lines
10 KiB
Rust
use core::future::{self, Future};
|
|
use std::pin::Pin;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::{anyhow, bail};
|
|
use axum::http::Uri;
|
|
use futures_util::stream::SplitSink;
|
|
use futures_util::{SinkExt, Stream, StreamExt};
|
|
use rand::Rng;
|
|
use tokio::net::TcpStream;
|
|
use tokio::time;
|
|
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
|
use tokio_stream::wrappers::IntervalStream;
|
|
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
|
|
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
|
|
use tokio_tungstenite::tungstenite::{self, ClientRequestBuilder, Message};
|
|
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
|
|
use tracing::{debug, error, info};
|
|
|
|
use crate::alis;
|
|
use crate::api;
|
|
use crate::notifier::Notifier;
|
|
use crate::stream::{Event, Subscriber};
|
|
|
|
/// Seconds between keep-alive pings sent to the server.
const PING_INTERVAL: u64 = 15;
/// Seconds to wait for a pong after sending a ping before failing the connection.
const PING_TIMEOUT: u64 = 10;
/// Seconds a single WebSocket send may take before it is treated as stalled.
const SEND_TIMEOUT: u64 = 10;
/// Base reconnect delay in milliseconds; doubled for every consecutive failed attempt.
const RECONNECT_DELAY_BASE: u64 = 500;
/// Upper bound, in milliseconds, for the exponential reconnect delay (10 s).
const RECONNECT_DELAY_CAP: u64 = 10 * 1_000;
|
|
|
|
pub async fn forward<N: Notifier>(
|
|
url: url::Url,
|
|
subscriber: Subscriber,
|
|
mut notifier: N,
|
|
shutdown_token: tokio_util::sync::CancellationToken,
|
|
) -> anyhow::Result<()> {
|
|
info!("forwarding to {url}");
|
|
let mut reconnect_attempt = 0;
|
|
let mut connection_count: u64 = 0;
|
|
|
|
loop {
|
|
let session_stream = subscriber.subscribe().await?;
|
|
let conn = connect_and_forward(&url, session_stream);
|
|
tokio::pin!(conn);
|
|
|
|
let result = tokio::select! {
|
|
result = &mut conn => result,
|
|
|
|
_ = time::sleep(Duration::from_secs(3)) => {
|
|
if reconnect_attempt > 0 {
|
|
if connection_count == 0 {
|
|
let _ = notifier.notify("Connected to the server".to_string()).await;
|
|
} else {
|
|
let _ = notifier.notify("Reconnected to the server".to_string()).await;
|
|
}
|
|
}
|
|
|
|
connection_count += 1;
|
|
reconnect_attempt = 0;
|
|
|
|
conn.await
|
|
}
|
|
};
|
|
|
|
match result {
|
|
Ok(true) => {
|
|
break;
|
|
}
|
|
|
|
Ok(false) => {
|
|
let _ = notifier
|
|
.notify("Stream halted by the server".to_string())
|
|
.await;
|
|
|
|
break;
|
|
}
|
|
|
|
Err(e) => {
|
|
if let Some(tungstenite::error::Error::Protocol(
|
|
tungstenite::error::ProtocolError::SecWebSocketSubProtocolError(_),
|
|
)) = e.downcast_ref::<tungstenite::error::Error>()
|
|
{
|
|
// This happens when the server accepts the websocket connection
|
|
// but doesn't properly perform the protocol negotiation.
|
|
// This applies to asciinema-server v20241103 and earlier.
|
|
|
|
let _ = notifier
|
|
.notify("The server version is too old, forwarding failed".to_string())
|
|
.await;
|
|
|
|
break;
|
|
}
|
|
|
|
if let Some(tungstenite::error::Error::Http(response)) =
|
|
e.downcast_ref::<tungstenite::error::Error>()
|
|
{
|
|
if response.status().as_u16() == 400 {
|
|
// This happens when the server doesn't support our protocol (version).
|
|
// This applies to asciinema-server versions newer than v20241103.
|
|
|
|
let _ = notifier
|
|
.notify(
|
|
"CLI not compatible with the server, forwarding failed".to_string(),
|
|
)
|
|
.await;
|
|
|
|
break;
|
|
}
|
|
}
|
|
|
|
error!("connection error: {e}");
|
|
|
|
if reconnect_attempt == 0 {
|
|
if connection_count == 0 {
|
|
let _ = notifier
|
|
.notify("Cannot connect to the server, retrying...".to_string())
|
|
.await;
|
|
} else {
|
|
let _ = notifier
|
|
.notify("Disconnected from the server, reconnecting...".to_string())
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let delay = exponential_delay(reconnect_attempt);
|
|
reconnect_attempt = reconnect_attempt.saturating_add(1);
|
|
info!("reconnecting in {delay} ms");
|
|
|
|
tokio::select! {
|
|
_ = time::sleep(Duration::from_millis(delay)) => (),
|
|
_ = shutdown_token.cancelled() => break
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn connect_and_forward(
|
|
url: &url::Url,
|
|
session_stream: impl Stream<Item = Result<Event, BroadcastStreamRecvError>> + Unpin,
|
|
) -> anyhow::Result<bool> {
|
|
let request = build_request(url)?;
|
|
let (ws, _) = tokio_tungstenite::connect_async_with_config(request, None, true).await?;
|
|
info!("connected to the endpoint");
|
|
|
|
handle_socket(ws, get_alis_stream(session_stream)).await
|
|
}
|
|
|
|
fn build_request(url: &url::Url) -> anyhow::Result<ClientRequestBuilder> {
|
|
let uri: Uri = url.to_string().parse()?;
|
|
|
|
Ok(ClientRequestBuilder::new(uri)
|
|
.with_sub_protocol("v1.alis")
|
|
.with_header("user-agent", api::build_user_agent()))
|
|
}
|
|
|
|
fn get_alis_stream(
|
|
stream: impl Stream<Item = Result<Event, BroadcastStreamRecvError>>,
|
|
) -> impl Stream<Item = anyhow::Result<Message>> {
|
|
alis::stream(stream)
|
|
.map(ws_result)
|
|
.chain(futures_util::stream::once(future::ready(Ok(
|
|
close_message(),
|
|
))))
|
|
}
|
|
|
|
async fn handle_socket<T>(
|
|
ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
|
alis_messages: T,
|
|
) -> anyhow::Result<bool>
|
|
where
|
|
T: Stream<Item = anyhow::Result<Message>> + Unpin,
|
|
{
|
|
let (mut sink, mut stream) = ws.split();
|
|
let mut alis_messages = alis_messages;
|
|
let mut pings = ping_stream();
|
|
let mut ping_timeout: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(future::pending());
|
|
|
|
loop {
|
|
tokio::select! {
|
|
message = alis_messages.next() => {
|
|
match message {
|
|
Some(message) => {
|
|
send_with_timeout(&mut sink, message?).await??;
|
|
},
|
|
|
|
None => {
|
|
return Ok(true);
|
|
}
|
|
}
|
|
},
|
|
|
|
ping = pings.next() => {
|
|
send_with_timeout(&mut sink, ping.unwrap()).await??;
|
|
ping_timeout = Box::pin(time::sleep(Duration::from_secs(PING_TIMEOUT)));
|
|
}
|
|
|
|
_ = &mut ping_timeout => bail!("ping timeout"),
|
|
|
|
message = stream.next() => {
|
|
match message {
|
|
Some(Ok(Message::Close(close_frame))) => {
|
|
info!("server closed the connection");
|
|
handle_close_frame(close_frame)?;
|
|
return Ok(false);
|
|
},
|
|
|
|
Some(Ok(Message::Ping(_))) => (),
|
|
|
|
Some(Ok(Message::Pong(_))) => {
|
|
ping_timeout = Box::pin(future::pending());
|
|
},
|
|
|
|
Some(Ok(msg)) => debug!("unexpected message from the server: {msg:?}"),
|
|
Some(Err(e)) => bail!(e),
|
|
None => bail!("SplitStream closed")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn send_with_timeout(
|
|
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
|
|
message: Message,
|
|
) -> anyhow::Result<Result<(), tungstenite::Error>> {
|
|
time::timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(message))
|
|
.await
|
|
.map_err(|_| anyhow!("send timeout"))
|
|
}
|
|
|
|
fn handle_close_frame(frame: Option<CloseFrame>) -> anyhow::Result<()> {
|
|
match frame {
|
|
Some(CloseFrame { code, reason }) => {
|
|
info!("close reason: {code} ({reason})");
|
|
|
|
match code {
|
|
CloseCode::Normal => Ok(()),
|
|
CloseCode::Library(code) if code < 4100 => Ok(()),
|
|
c => bail!("unclean close: {c}"),
|
|
}
|
|
}
|
|
|
|
None => {
|
|
info!("close reason: none");
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
fn exponential_delay(attempt: usize) -> u64 {
|
|
let mut rng = rand::rng();
|
|
let attempt = attempt.min(10);
|
|
let exp = (RECONNECT_DELAY_BASE * 2_u64.pow(attempt as u32)).min(RECONNECT_DELAY_CAP);
|
|
|
|
rng.random_range((exp / 2)..exp)
|
|
}
|
|
|
|
fn ws_result(m: Result<Vec<u8>, BroadcastStreamRecvError>) -> anyhow::Result<Message> {
|
|
match m {
|
|
Ok(bytes) => Ok(Message::binary(bytes)),
|
|
Err(e) => Err(anyhow::anyhow!(e)),
|
|
}
|
|
}
|
|
|
|
fn close_message() -> Message {
|
|
Message::Close(Some(CloseFrame {
|
|
code: CloseCode::Normal,
|
|
reason: "ended".into(),
|
|
}))
|
|
}
|
|
|
|
fn ping_stream() -> impl Stream<Item = Message> {
|
|
IntervalStream::new(time::interval(Duration::from_secs(PING_INTERVAL)))
|
|
.skip(1)
|
|
.map(|_| Message::Ping(vec![].into()))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Every delay must fall within `[lower, upper)` for its attempt number.
    /// The ceiling doubles from 500 ms per attempt and saturates at 10 s.
    #[test]
    fn exponential_delay_is_within_equal_jitter_bounds() {
        let cases: [(usize, u64, u64); 6] = [
            (0, 250, 500),
            (1, 500, 1000),
            (4, 4000, 8000),
            (5, 5000, 10000),
            (10, 5000, 10000),
            (100, 5000, 10000),
        ];

        for (attempt, lower, upper) in cases {
            for _ in 0..1000 {
                let delay = exponential_delay(attempt);

                assert!(delay >= lower, "delay {delay} below lower bound {lower}");
                assert!(delay < upper, "delay {delay} above upper bound {}", upper - 1);
            }
        }
    }
}
|