From 7969486d32333653a84b34a8da621521e9e2be60 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Mon, 9 Jun 2025 16:10:40 -0400 Subject: [PATCH] Use copy_bidirectional, handle half-closed TCP streams (#165) --- src/client.rs | 8 +++----- src/server.rs | 6 +++--- src/shared.rs | 17 +---------------- tests/e2e_test.rs | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/client.rs b/src/client.rs index 2c21c16..cb8fa7b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,9 +8,7 @@ use tracing::{error, info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{ - proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT, -}; +use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT}; /// State structure for the client. pub struct Client { @@ -112,10 +110,10 @@ impl Client { } remote_conn.send(ClientMessage::Accept(id)).await?; let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?; - let parts = remote_conn.into_parts(); + let mut parts = remote_conn.into_parts(); debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty"); local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty - proxy(local_conn, parts.io).await?; + tokio::io::copy_bidirectional(&mut local_conn, &mut parts.io).await?; Ok(()) } } diff --git a/src/server.rs b/src/server.rs index 3c38988..f47d714 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT}; +use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT}; /// State structure for the server. pub struct Server { @@ -172,10 +172,10 @@ impl Server { info!(%id, "forwarding connection"); match self.conns.remove(&id) { Some((_, mut stream2)) => { - let parts = stream.into_parts(); + let mut parts = stream.into_parts(); debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty"); stream2.write_all(&parts.read_buf).await?; - proxy(parts.io, stream2).await? + tokio::io::copy_bidirectional(&mut parts.io, &mut stream2).await?; } None => warn!(%id, "missing connection"), } diff --git a/src/shared.rs b/src/shared.rs index 10b1bc8..d9c5d3b 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -5,7 +5,7 @@ use std::time::Duration; use anyhow::{Context, Result}; use futures_util::{SinkExt, StreamExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio::io::{self, AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::timeout; use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts}; use tracing::trace; @@ -97,18 +97,3 @@ impl Delimited { self.0.into_parts() } } - -/// Copy data mutually between two read/write streams. -pub async fn proxy(stream1: S1, stream2: S2) -> io::Result<()> -where - S1: AsyncRead + AsyncWrite + Unpin, - S2: AsyncRead + AsyncWrite + Unpin, -{ - let (mut s1_read, mut s1_write) = io::split(stream1); - let (mut s2_read, mut s2_write) = io::split(stream2); - tokio::select! { - res = io::copy(&mut s1_read, &mut s2_write) => res, - res = io::copy(&mut s2_read, &mut s1_write) => res, - }?; - Ok(()) -} diff --git a/tests/e2e_test.rs b/tests/e2e_test.rs index 50a6739..e8a0e78 100644 --- a/tests/e2e_test.rs +++ b/tests/e2e_test.rs @@ -125,3 +125,40 @@ fn empty_port_range() { let max_port = 3000; let _ = Server::new(min_port..=max_port, None); } + +#[tokio::test] +async fn half_closed_tcp_stream() -> Result<()> { + // Check that "half-closed" TCP streams will not result in spontaneous hangups. + let _guard = SERIAL_GUARD.lock().await; + + spawn_server(None).await; + let (listener, addr) = spawn_client(None).await?; + + let (mut cli, (mut srv, _)) = tokio::try_join!(TcpStream::connect(addr), listener.accept())?; + + // Send data before half-closing one of the streams. + let mut buf = b"message before shutdown".to_vec(); + cli.write_all(&buf).await?; + + // Only close the write half of the stream. This is a half-closed stream. In the + // TCP protocol, it is represented as a FIN packet on one end. The entire stream + // is only closed after two FINs are exchanged and ACKed by the other end. + cli.shutdown().await?; + + srv.read_exact(&mut buf).await?; + assert_eq!(buf, b"message before shutdown"); + assert_eq!(srv.read(&mut buf).await?, 0); // EOF + + // Now make sure that the other stream can still send data, despite + // half-shutdown on client->server side. + let mut buf = b"hello from the other side!".to_vec(); + srv.write_all(&buf).await?; + cli.read_exact(&mut buf).await?; + assert_eq!(buf, b"hello from the other side!"); + + // We don't have to think about CLOSE_RD handling because that's not really + // part of the TCP protocol, just the POSIX streams API. It is implemented by + // the OS ignoring future packets received on that stream. + + Ok(()) +}