Refactor pty module

This commit is contained in:
Marcin Kulik
2025-06-18 20:27:08 +02:00
parent ac4e92dfc4
commit 1ef501def7
2 changed files with 172 additions and 71 deletions

View File

@@ -3,6 +3,8 @@ use std::env;
use std::ffi::{CString, NulError};
use std::os::fd::OwnedFd;
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use nix::errno::Errno;
use nix::pty::{ForkptyResult, Winsize};
@@ -11,7 +13,7 @@ use nix::sys::wait::{self, WaitPidFlag, WaitStatus};
use nix::unistd::{self, Pid};
use nix::{libc, pty};
use tokio::io::unix::AsyncFd;
use tokio::io::{self, Interest};
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
use tokio::task;
use crate::fd::FdExt;
@@ -21,7 +23,154 @@ pub struct Pty {
master: AsyncFd<OwnedFd>,
}
pub async fn spawn<S: AsRef<str>>(
pub struct PtyReadHalf<'a> {
pty: &'a Pty,
}
pub struct PtyWriteHalf<'a> {
pty: &'a Pty,
}
impl Pty {
pub fn split(&self) -> (PtyReadHalf<'_>, PtyWriteHalf<'_>) {
(PtyReadHalf { pty: self }, PtyWriteHalf { pty: self })
}
pub fn resize(&self, winsize: Winsize) {
unsafe { libc::ioctl(self.master.as_raw_fd(), libc::TIOCSWINSZ, &winsize) };
}
pub fn kill(&self) {
// Any errors occurred when killing the child are ignored.
let _ = signal::kill(self.child, Signal::SIGTERM);
}
pub async fn wait(&self, options: Option<WaitPidFlag>) -> io::Result<WaitStatus> {
let pid = self.child;
task::spawn_blocking(move || Ok(wait::waitpid(pid, options)?)).await?
}
}
impl AsyncRead for Pty {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
let mut guard = ready!(self.master.poll_read_ready(cx))?;
let unfilled = buf.initialize_unfilled();
match guard.try_io(|fd| match unistd::read(fd, unfilled) {
Ok(n) => Ok(n),
Err(Errno::EIO) => Ok(0),
Err(e) => Err(io::Error::from(e)),
}) {
Ok(Ok(n)) => {
buf.advance(n);
return Poll::Ready(Ok(()));
}
Ok(Err(e)) => return Poll::Ready(Err(e)),
Err(_would_block) => continue,
}
}
}
}
impl AsyncWrite for Pty {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.master.poll_write_ready(cx))?;
match guard.try_io(|fd| match unistd::write(fd, buf) {
Ok(n) => Ok(n),
Err(Errno::EIO) => Ok(0),
Err(e) => Err(io::Error::from(e)),
}) {
Ok(result) => return Poll::Ready(result),
Err(_would_block) => continue,
}
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl Drop for Pty {
fn drop(&mut self) {
self.kill();
let _ = wait::waitpid(self.child, None);
}
}
impl AsyncRead for PtyReadHalf<'_> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
let mut guard = ready!(self.pty.master.poll_read_ready(cx))?;
let unfilled = buf.initialize_unfilled();
match guard.try_io(|fd| match unistd::read(fd, unfilled) {
Ok(n) => Ok(n),
Err(Errno::EIO) => Ok(0),
Err(e) => Err(io::Error::from(e)),
}) {
Ok(Ok(n)) => {
buf.advance(n);
return Poll::Ready(Ok(()));
}
Ok(Err(e)) => return Poll::Ready(Err(e)),
Err(_would_block) => continue,
}
}
}
}
impl AsyncWrite for PtyWriteHalf<'_> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.pty.master.poll_write_ready(cx))?;
match guard.try_io(|fd| match unistd::write(fd, buf) {
Ok(n) => Ok(n),
Err(Errno::EIO) => Ok(0),
Err(e) => Err(io::Error::from(e)),
}) {
Ok(result) => return Poll::Ready(result),
Err(_would_block) => continue,
}
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
pub fn spawn<S: AsRef<str>>(
command: &[S],
winsize: Winsize,
extra_env: &HashMap<String, String>,
@@ -43,49 +192,6 @@ pub async fn spawn<S: AsRef<str>>(
}
}
impl Pty {
pub async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
self.master
.async_io(Interest::READABLE, |fd| match unistd::read(fd, buffer) {
Ok(n) => Ok(n),
Err(Errno::EIO) => Ok(0),
Err(e) => Err(e.into()),
})
.await
}
pub async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
self.master
.async_io(Interest::WRITABLE, |fd| match unistd::write(fd, buffer) {
Ok(n) => Ok(n),
Err(Errno::EIO) => Ok(0),
Err(e) => Err(e.into()),
})
.await
}
pub fn resize(&self, winsize: Winsize) {
unsafe { libc::ioctl(self.master.as_raw_fd(), libc::TIOCSWINSZ, &winsize) };
}
pub fn kill(&self) {
// Any errors occurred when killing the child are ignored.
let _ = signal::kill(self.child, Signal::SIGTERM);
}
pub async fn wait(&self, options: Option<WaitPidFlag>) -> io::Result<WaitStatus> {
let pid = self.child;
task::spawn_blocking(move || Ok(wait::waitpid(pid, options)?)).await?
}
}
impl Drop for Pty {
fn drop(&mut self) {
self.kill();
let _ = wait::waitpid(self.child, None);
}
}
fn handle_child<S: AsRef<str>>(
command: &[S],
extra_env: &HashMap<String, String>,
@@ -106,36 +212,30 @@ fn handle_child<S: AsRef<str>>(
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::Pty;
use crate::tty::TtySize;
async fn spawn<S: AsRef<str>>(command: &[S], extra_env: &HashMap<String, String>) -> Arc<Pty> {
Arc::new(
super::spawn(command, TtySize::default().into(), extra_env)
.await
.unwrap(),
)
async fn spawn<S: AsRef<str>>(command: &[S], extra_env: &HashMap<String, String>) -> Pty {
super::spawn(command, TtySize::default().into(), extra_env).unwrap()
}
async fn read_output(pty: Arc<Pty>) -> Vec<String> {
tokio::spawn(async move {
let mut buf = [0u8; 1024];
let mut output = Vec::new();
async fn read_output(mut pty: Pty) -> Vec<String> {
let mut buf = [0u8; 1024];
let mut output = Vec::new();
while let Ok(n) = pty.read(&mut buf).await {
if n == 0 {
break;
}
output.push(String::from_utf8_lossy(&buf[..n]).to_string());
while let Ok(n) = pty.read(&mut buf).await {
if n == 0 {
break;
}
output
})
.await
.unwrap()
output.push(String::from_utf8_lossy(&buf[..n]).to_string());
}
output
}
#[tokio::test]
@@ -166,9 +266,9 @@ sys.stdout.write('bar');
#[tokio::test]
async fn spawn_quick() {
let pty = spawn(&["printf", "hello world\n"], &HashMap::new()).await;
let output = read_output(pty).await;
let output = read_output(pty).await.join("");
assert!(!output.is_empty());
assert_eq!(output, "hello world\r\n");
}
#[tokio::test]

View File

@@ -80,7 +80,7 @@ pub async fn run<S: AsRef<str>, T: Tty + ?Sized, N: Notifier>(
let epoch = Instant::now();
let (events_tx, events_rx) = mpsc::channel::<Event>(1024);
let winsize = tty.get_size();
let pty = pty::spawn(command, winsize, extra_env).await?;
let pty = pty::spawn(command, winsize, extra_env)?;
tokio::spawn(forward_events(events_rx, outputs));
let mut session = Session {
@@ -140,10 +140,11 @@ impl<N: Notifier> Session<N> {
let mut output: Vec<u8> = Vec::with_capacity(BUF_SIZE);
let mut wait_status = None;
let (mut tty_reader, mut tty_writer) = tty.split();
let (mut pty_reader, mut pty_writer) = pty.split();
loop {
tokio::select! {
result = pty.read(&mut output_buf) => {
result = pty_reader.read(&mut output_buf) => {
let n = result?;
if n > 0 {
@@ -154,7 +155,7 @@ impl<N: Notifier> Session<N> {
}
}
result = pty.write(&input), if !input.is_empty() => {
result = pty_writer.write(&input), if !input.is_empty() => {
let n = result?;
input.drain(..n);
}