diff --git a/src/pty.rs b/src/pty.rs index 24d11ab..8c90dd8 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -3,8 +3,6 @@ use std::env; use std::ffi::{CString, NulError}; use std::os::fd::OwnedFd; use std::os::unix::io::AsRawFd; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; use nix::errno::Errno; use nix::pty::{ForkptyResult, Winsize}; @@ -13,7 +11,7 @@ use nix::sys::wait::{self, WaitPidFlag, WaitStatus}; use nix::unistd::{self, Pid}; use nix::{libc, pty}; use tokio::io::unix::AsyncFd; -use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{self, Interest}; use tokio::task; use crate::fd::FdExt; @@ -23,17 +21,25 @@ pub struct Pty { master: AsyncFd, } -pub struct PtyReadHalf<'a> { - pty: &'a Pty, -} - -pub struct PtyWriteHalf<'a> { - pty: &'a Pty, -} - impl Pty { - pub fn split(&self) -> (PtyReadHalf<'_>, PtyWriteHalf<'_>) { - (PtyReadHalf { pty: self }, PtyWriteHalf { pty: self }) + pub async fn read(&self, buffer: &mut [u8]) -> io::Result { + self.master + .async_io(Interest::READABLE, |fd| match unistd::read(fd, buffer) { + Ok(n) => Ok(n), + Err(Errno::EIO) => Ok(0), + Err(e) => Err(e.into()), + }) + .await + } + + pub async fn write(&self, buffer: &[u8]) -> io::Result { + self.master + .async_io(Interest::WRITABLE, |fd| match unistd::write(fd, buffer) { + Ok(n) => Ok(n), + Err(Errno::EIO) => Ok(0), + Err(e) => Err(e.into()), + }) + .await } pub fn resize(&self, winsize: Winsize) { @@ -51,62 +57,6 @@ impl Pty { } } -impl AsyncRead for Pty { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - loop { - let mut guard = ready!(self.master.poll_read_ready(cx))?; - let unfilled = buf.initialize_unfilled(); - - match guard.try_io(|fd| match unistd::read(fd, unfilled) { - Ok(n) => Ok(n), - Err(Errno::EIO) => Ok(0), - Err(e) => Err(io::Error::from(e)), - }) { - Ok(Ok(n)) => { - buf.advance(n); - return Poll::Ready(Ok(())); - } - - Ok(Err(e)) => return Poll::Ready(Err(e)), - Err(_would_block) => continue, - } - } - } -} - -impl AsyncWrite for Pty { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - let mut guard = ready!(self.master.poll_write_ready(cx))?; - - match guard.try_io(|fd| match unistd::write(fd, buf) { - Ok(n) => Ok(n), - Err(Errno::EIO) => Ok(0), - Err(e) => Err(io::Error::from(e)), - }) { - Ok(result) => return Poll::Ready(result), - Err(_would_block) => continue, - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - impl Drop for Pty { fn drop(&mut self) { self.kill(); @@ -114,62 +64,6 @@ impl Drop for Pty { } } -impl AsyncRead for PtyReadHalf<'_> { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - loop { - let mut guard = ready!(self.pty.master.poll_read_ready(cx))?; - let unfilled = buf.initialize_unfilled(); - - match guard.try_io(|fd| match unistd::read(fd, unfilled) { - Ok(n) => Ok(n), - Err(Errno::EIO) => Ok(0), - Err(e) => Err(io::Error::from(e)), - }) { - Ok(Ok(n)) => { - buf.advance(n); - return Poll::Ready(Ok(())); - } - - Ok(Err(e)) => return Poll::Ready(Err(e)), - Err(_would_block) => continue, - } - } - } -} - -impl AsyncWrite for PtyWriteHalf<'_> { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - let mut guard = ready!(self.pty.master.poll_write_ready(cx))?; - - match guard.try_io(|fd| match unistd::write(fd, buf) { - Ok(n) => Ok(n), - Err(Errno::EIO) => Ok(0), - Err(e) => Err(io::Error::from(e)), - }) { - Ok(result) => return Poll::Ready(result), - Err(_would_block) => continue, - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - pub fn spawn>( command: &[S], winsize: Winsize, @@ -214,8 +108,6 @@ fn handle_child>( mod tests { use std::collections::HashMap; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use super::Pty; use crate::tty::TtySize; @@ -223,7 +115,7 @@ mod tests { super::spawn(command, TtySize::default().into(), extra_env).unwrap() } - async fn read_output(mut pty: Pty) -> Vec { + async fn read_output(pty: Pty) -> Vec { let mut buf = [0u8; 1024]; let mut output = Vec::new(); @@ -284,9 +176,9 @@ sys.stdout.write('bar'); #[tokio::test] async fn spawn_echo_input() { - let mut pty = spawn(&["cat"], &HashMap::new()).await; - pty.write_all(b"foo").await.unwrap(); - pty.write_all(b"bar").await.unwrap(); + let pty = spawn(&["cat"], &HashMap::new()).await; + pty.write(b"foo").await.unwrap(); + pty.write(b"bar").await.unwrap(); pty.kill(); let output = read_output(pty).await.join(""); diff --git a/src/session.rs b/src/session.rs index e07f35a..0b470e2 100644 --- a/src/session.rs +++ b/src/session.rs @@ -8,7 +8,6 @@ use nix::sys::wait::{WaitPidFlag, WaitStatus}; use signal_hook::consts::signal::*; use signal_hook_tokio::Signals; use tokio::io; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc; use tokio::time::Instant; use tracing::error; @@ -139,11 +138,10 @@ impl Session { let mut input: Vec = Vec::with_capacity(BUF_SIZE); let mut output: Vec = Vec::with_capacity(BUF_SIZE); let mut wait_status = None; - let (mut pty_reader, mut pty_writer) = pty.split(); loop { tokio::select! { - result = pty_reader.read(&mut output_buf) => { + result = pty.read(&mut output_buf) => { let n = result?; if n > 0 { @@ -154,7 +152,7 @@ impl Session { } } - result = pty_writer.write(&input), if !input.is_empty() => { + result = pty.write(&input), if !input.is_empty() => { let n = result?; input.drain(..n); } diff --git a/src/tty/default.rs b/src/tty/default.rs index 20a1f66..d180cae 100644 --- a/src/tty/default.rs +++ b/src/tty/default.rs @@ -4,13 +4,13 @@ use std::os::fd::{AsFd, AsRawFd}; use std::os::unix::fs::OpenOptionsExt; use async_trait::async_trait; +use nix::libc; use nix::pty::Winsize; use nix::sys::termios::{self, SetArg, Termios}; -use nix::libc; use tokio::io::unix::AsyncFd; use tokio::io::{self, Interest}; -use super::{TtySize, Tty, TtyTheme}; +use super::{Tty, TtySize, TtyTheme}; pub struct DevTty { file: AsyncFd,