From d06786fb3d1193b428de42a812c8e56fb5f07c88 Mon Sep 17 00:00:00 2001 From: Marcin Kulik Date: Fri, 6 Jun 2025 20:01:02 +0200 Subject: [PATCH] Make forwarder code more readable, split methods, etc --- src/forwarder.rs | 69 ++++++++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/src/forwarder.rs b/src/forwarder.rs index fe15db5..ef87ab4 100644 --- a/src/forwarder.rs +++ b/src/forwarder.rs @@ -4,10 +4,11 @@ use std::time::Duration; use anyhow::{anyhow, bail}; use axum::http::Uri; +use futures_util::stream::SplitSink; use futures_util::{SinkExt, Stream, StreamExt}; use rand::Rng; use tokio::net::TcpStream; -use tokio::time::{interval, sleep, timeout}; +use tokio::time; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::IntervalStream; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; @@ -38,18 +39,18 @@ pub async fn forward( let mut connection_count: u64 = 0; loop { - let stream = subscriber + let session_stream = subscriber .subscribe() .await .expect("stream should be alive"); - let conn = connect_and_forward(&url, stream); + let conn = connect_and_forward(&url, session_stream); tokio::pin!(conn); let result = tokio::select! { result = &mut conn => result, - _ = sleep(Duration::from_secs(3)) => { + _ = time::sleep(Duration::from_secs(3)) => { if reconnect_attempt > 0 { if connection_count == 0 { let _ = notifier.notify("Connected to the server".to_string()); @@ -124,7 +125,7 @@ pub async fn forward( info!("reconnecting in {delay}"); tokio::select! { - _ = sleep(Duration::from_millis(delay)) => (), + _ = time::sleep(Duration::from_millis(delay)) => (), _ = shutdown_token.cancelled() => break } } @@ -132,44 +133,51 @@ pub async fn forward( async fn connect_and_forward( url: &url::Url, - stream: impl Stream> + Unpin, + session_stream: impl Stream> + Unpin, ) -> anyhow::Result { - let uri: Uri = url.to_string().parse()?; - - let builder = ClientRequestBuilder::new(uri) - .with_sub_protocol("v1.alis") - .with_header("user-agent", api::build_user_agent()); - - let (ws, _) = tokio_tungstenite::connect_async_with_config(builder, None, true).await?; + let request = build_request(url)?; + let (ws, _) = tokio_tungstenite::connect_async_with_config(request, None, true).await?; info!("connected to the endpoint"); - let events = alis::stream(stream) + handle_socket(ws, get_alis_stream(session_stream)).await +} + +fn build_request(url: &url::Url) -> anyhow::Result { + let uri: Uri = url.to_string().parse()?; + + Ok(ClientRequestBuilder::new(uri) + .with_sub_protocol("v1.alis") + .with_header("user-agent", api::build_user_agent())) +} + +fn get_alis_stream( + stream: impl Stream>, +) -> impl Stream> { + alis::stream(stream) .map(ws_result) .chain(futures_util::stream::once(future::ready(Ok( close_message(), - )))); - - handle_socket(ws, events).await + )))) } async fn handle_socket( ws: WebSocketStream>, - events: T, + alis_messages: T, ) -> anyhow::Result where T: Stream> + Unpin, { let (mut sink, mut stream) = ws.split(); - let mut events = events; + let mut alis_messages = alis_messages; let mut pings = ping_stream(); let mut ping_timeout: Pin + Send>> = Box::pin(future::pending()); loop { tokio::select! { - event = events.next() => { - match event { - Some(event) => { - timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(event?)).await.map_err(|_| anyhow!("send timeout"))??; + message = alis_messages.next() => { + match message { + Some(message) => { + send_with_timeout(&mut sink, message?).await??; }, None => { @@ -179,8 +187,8 @@ where }, ping = pings.next() => { - timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(ping.unwrap())).await.map_err(|_| anyhow!("send timeout"))??; - ping_timeout = Box::pin(sleep(Duration::from_secs(PING_TIMEOUT))); + send_with_timeout(&mut sink, ping.unwrap()).await??; + ping_timeout = Box::pin(time::sleep(Duration::from_secs(PING_TIMEOUT))); } _ = &mut ping_timeout => bail!("ping timeout"), @@ -208,6 +216,15 @@ where } } +async fn send_with_timeout( + sink: &mut SplitSink>, Message>, + message: Message, +) -> anyhow::Result> { + time::timeout(Duration::from_secs(SEND_TIMEOUT), sink.send(message)) + .await + .map_err(|_| anyhow!("send timeout")) +} + fn handle_close_frame(frame: Option) -> anyhow::Result<()> { match frame { Some(CloseFrame { code, reason }) => { @@ -249,7 +266,7 @@ fn close_message() -> Message { } fn ping_stream() -> impl Stream { - IntervalStream::new(interval(Duration::from_secs(PING_INTERVAL))) + IntervalStream::new(time::interval(Duration::from_secs(PING_INTERVAL))) .skip(1) .map(|_| Message::Ping(vec![].into())) }