Make forwarder code more readable, split methods, etc

This commit is contained in:
Marcin Kulik
2025-06-06 20:01:02 +02:00
parent 0a5dbc2bf5
commit d06786fb3d

View File

@@ -4,10 +4,11 @@ use std::time::Duration;
use anyhow::{anyhow, bail}; use anyhow::{anyhow, bail};
use axum::http::Uri; use axum::http::Uri;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, Stream, StreamExt}; use futures_util::{SinkExt, Stream, StreamExt};
use rand::Rng; use rand::Rng;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::{interval, sleep, timeout}; use tokio::time;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::IntervalStream; use tokio_stream::wrappers::IntervalStream;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
@@ -38,18 +39,18 @@ pub async fn forward<N: Notifier>(
let mut connection_count: u64 = 0; let mut connection_count: u64 = 0;
loop { loop {
let stream = subscriber let session_stream = subscriber
.subscribe() .subscribe()
.await .await
.expect("stream should be alive"); .expect("stream should be alive");
let conn = connect_and_forward(&url, stream); let conn = connect_and_forward(&url, session_stream);
tokio::pin!(conn); tokio::pin!(conn);
let result = tokio::select! { let result = tokio::select! {
result = &mut conn => result, result = &mut conn => result,
_ = sleep(Duration::from_secs(3)) => { _ = time::sleep(Duration::from_secs(3)) => {
if reconnect_attempt > 0 { if reconnect_attempt > 0 {
if connection_count == 0 { if connection_count == 0 {
let _ = notifier.notify("Connected to the server".to_string()); let _ = notifier.notify("Connected to the server".to_string());
@@ -124,7 +125,7 @@ pub async fn forward<N: Notifier>(
info!("reconnecting in {delay}"); info!("reconnecting in {delay}");
tokio::select! { tokio::select! {
_ = sleep(Duration::from_millis(delay)) => (), _ = time::sleep(Duration::from_millis(delay)) => (),
_ = shutdown_token.cancelled() => break _ = shutdown_token.cancelled() => break
} }
} }
@@ -132,44 +133,51 @@ pub async fn forward<N: Notifier>(
async fn connect_and_forward( async fn connect_and_forward(
url: &url::Url, url: &url::Url,
stream: impl Stream<Item = Result<Event, BroadcastStreamRecvError>> + Unpin, session_stream: impl Stream<Item = Result<Event, BroadcastStreamRecvError>> + Unpin,
) -> anyhow::Result<bool> { ) -> anyhow::Result<bool> {
let uri: Uri = url.to_string().parse()?; let request = build_request(url)?;
let (ws, _) = tokio_tungstenite::connect_async_with_config(request, None, true).await?;
let builder = ClientRequestBuilder::new(uri)
.with_sub_protocol("v1.alis")
.with_header("user-agent", api::build_user_agent());
let (ws, _) = tokio_tungstenite::connect_async_with_config(builder, None, true).await?;
info!("connected to the endpoint"); info!("connected to the endpoint");
let events = alis::stream(stream) handle_socket(ws, get_alis_stream(session_stream)).await
}
fn build_request(url: &url::Url) -> anyhow::Result<ClientRequestBuilder> {
let uri: Uri = url.to_string().parse()?;
Ok(ClientRequestBuilder::new(uri)
.with_sub_protocol("v1.alis")
.with_header("user-agent", api::build_user_agent()))
}
fn get_alis_stream(
stream: impl Stream<Item = Result<Event, BroadcastStreamRecvError>>,
) -> impl Stream<Item = anyhow::Result<Message>> {
alis::stream(stream)
.map(ws_result) .map(ws_result)
.chain(futures_util::stream::once(future::ready(Ok( .chain(futures_util::stream::once(future::ready(Ok(
close_message(), close_message(),
)))); ))))
handle_socket(ws, events).await
} }
async fn handle_socket<T>( async fn handle_socket<T>(
ws: WebSocketStream<MaybeTlsStream<TcpStream>>, ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
events: T, alis_messages: T,
) -> anyhow::Result<bool> ) -> anyhow::Result<bool>
where where
T: Stream<Item = anyhow::Result<Message>> + Unpin, T: Stream<Item = anyhow::Result<Message>> + Unpin,
{ {
let (mut sink, mut stream) = ws.split(); let (mut sink, mut stream) = ws.split();
let mut events = events; let mut alis_messages = alis_messages;
let mut pings = ping_stream(); let mut pings = ping_stream();
let mut ping_timeout: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(future::pending()); let mut ping_timeout: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(future::pending());
loop { loop {
tokio::select! { tokio::select! {
event = events.next() => { message = alis_messages.next() => {
match event { match message {
Some(event) => { Some(message) => {
timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(event?)).await.map_err(|_| anyhow!("send timeout"))??; send_with_timeout(&mut sink, message?).await??;
}, },
None => { None => {
@@ -179,8 +187,8 @@ where
}, },
ping = pings.next() => { ping = pings.next() => {
timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(ping.unwrap())).await.map_err(|_| anyhow!("send timeout"))??; send_with_timeout(&mut sink, ping.unwrap()).await??;
ping_timeout = Box::pin(sleep(Duration::from_secs(PING_TIMEOUT))); ping_timeout = Box::pin(time::sleep(Duration::from_secs(PING_TIMEOUT)));
} }
_ = &mut ping_timeout => bail!("ping timeout"), _ = &mut ping_timeout => bail!("ping timeout"),
@@ -208,6 +216,15 @@ where
} }
} }
async fn send_with_timeout(
sink: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
message: Message,
) -> anyhow::Result<Result<(), tungstenite::Error>> {
time::timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(message))
.await
.map_err(|_| anyhow!("send timeout"))
}
fn handle_close_frame(frame: Option<CloseFrame>) -> anyhow::Result<()> { fn handle_close_frame(frame: Option<CloseFrame>) -> anyhow::Result<()> {
match frame { match frame {
Some(CloseFrame { code, reason }) => { Some(CloseFrame { code, reason }) => {
@@ -249,7 +266,7 @@ fn close_message() -> Message {
} }
fn ping_stream() -> impl Stream<Item = Message> { fn ping_stream() -> impl Stream<Item = Message> {
IntervalStream::new(interval(Duration::from_secs(PING_INTERVAL))) IntervalStream::new(time::interval(Duration::from_secs(PING_INTERVAL)))
.skip(1) .skip(1)
.map(|_| Message::Ping(vec![].into())) .map(|_| Message::Ping(vec![].into()))
} }