Refactor pty Handler trait

This commit is contained in:
Marcin Kulik
2024-09-20 12:05:38 +02:00
parent 6f2dcb434b
commit 2728a10bbd
3 changed files with 69 additions and 64 deletions

View File

@@ -19,26 +19,26 @@ use std::{env, fs};
type ExtraEnv = HashMap<String, String>; type ExtraEnv = HashMap<String, String>;
pub trait Recorder { pub trait Handler {
fn start(&mut self, size: TtySize) -> io::Result<()>; fn start(&mut self, size: TtySize);
fn output(&mut self, data: &[u8]); fn output(&mut self, data: &[u8]) -> bool;
fn input(&mut self, data: &[u8]) -> bool; fn input(&mut self, data: &[u8]) -> bool;
fn resize(&mut self, size: TtySize); fn resize(&mut self, size: TtySize) -> bool;
} }
pub fn exec<S: AsRef<str>, T: Tty + ?Sized, R: Recorder>( pub fn exec<S: AsRef<str>, T: Tty + ?Sized, H: Handler>(
command: &[S], command: &[S],
extra_env: &ExtraEnv, extra_env: &ExtraEnv,
tty: &mut T, tty: &mut T,
recorder: &mut R, handler: &mut H,
) -> Result<i32> { ) -> Result<i32> {
let winsize = tty.get_size(); let winsize = tty.get_size();
recorder.start(winsize.into())?; handler.start(winsize.into());
let result = unsafe { pty::forkpty(Some(&winsize), None) }?; let result = unsafe { pty::forkpty(Some(&winsize), None) }?;
match result.fork_result { match result.fork_result {
ForkResult::Parent { child } => { ForkResult::Parent { child } => {
handle_parent(result.master.as_raw_fd(), child, tty, recorder) handle_parent(result.master.as_raw_fd(), child, tty, handler)
} }
ForkResult::Child => { ForkResult::Child => {
@@ -48,13 +48,13 @@ pub fn exec<S: AsRef<str>, T: Tty + ?Sized, R: Recorder>(
} }
} }
fn handle_parent<T: Tty + ?Sized, R: Recorder>( fn handle_parent<T: Tty + ?Sized, H: Handler>(
master_fd: RawFd, master_fd: RawFd,
child: unistd::Pid, child: unistd::Pid,
tty: &mut T, tty: &mut T,
recorder: &mut R, handler: &mut H,
) -> Result<i32> { ) -> Result<i32> {
let wait_result = match copy(master_fd, child, tty, recorder) { let wait_result = match copy(master_fd, child, tty, handler) {
Ok(Some(status)) => Ok(status), Ok(Some(status)) => Ok(status),
Ok(None) => wait::waitpid(child, None), Ok(None) => wait::waitpid(child, None),
@@ -74,11 +74,11 @@ fn handle_parent<T: Tty + ?Sized, R: Recorder>(
const BUF_SIZE: usize = 128 * 1024; const BUF_SIZE: usize = 128 * 1024;
fn copy<T: Tty + ?Sized, R: Recorder>( fn copy<T: Tty + ?Sized, H: Handler>(
master_raw_fd: RawFd, master_raw_fd: RawFd,
child: unistd::Pid, child: unistd::Pid,
tty: &mut T, tty: &mut T,
recorder: &mut R, handler: &mut H,
) -> Result<Option<WaitStatus>> { ) -> Result<Option<WaitStatus>> {
let mut master = unsafe { fs::File::from_raw_fd(master_raw_fd) }; let mut master = unsafe { fs::File::from_raw_fd(master_raw_fd) };
let mut buf = [0u8; BUF_SIZE]; let mut buf = [0u8; BUF_SIZE];
@@ -146,8 +146,9 @@ fn copy<T: Tty + ?Sized, R: Recorder>(
if master_read { if master_read {
while let Some(n) = read_non_blocking(&mut master, &mut buf)? { while let Some(n) = read_non_blocking(&mut master, &mut buf)? {
if n > 0 { if n > 0 {
recorder.output(&buf[0..n]); if handler.output(&buf[0..n]) {
output.extend_from_slice(&buf[0..n]); output.extend_from_slice(&buf[0..n]);
}
} else if output.is_empty() { } else if output.is_empty() {
return Ok(None); return Ok(None);
} else { } else {
@@ -204,7 +205,7 @@ fn copy<T: Tty + ?Sized, R: Recorder>(
if tty_read { if tty_read {
while let Some(n) = read_non_blocking(tty, &mut buf)? { while let Some(n) = read_non_blocking(tty, &mut buf)? {
if n > 0 { if n > 0 {
if recorder.input(&buf[0..n]) { if handler.input(&buf[0..n]) {
input.extend_from_slice(&buf[0..n]); input.extend_from_slice(&buf[0..n]);
} }
} else { } else {
@@ -216,8 +217,10 @@ fn copy<T: Tty + ?Sized, R: Recorder>(
if sigwinch_read { if sigwinch_read {
sigwinch_fd.flush(); sigwinch_fd.flush();
let winsize = tty.get_size(); let winsize = tty.get_size();
set_pty_size(master_raw_fd, &winsize);
recorder.resize(winsize.into()); if handler.resize(winsize.into()) {
set_pty_size(master_raw_fd, &winsize);
}
} }
let mut kill_the_child = false; let mut kill_the_child = false;
@@ -366,34 +369,37 @@ impl Drop for SignalFd {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Recorder; use super::Handler;
use crate::pty::ExtraEnv; use crate::pty::ExtraEnv;
use crate::tty::{FixedSizeTty, NullTty, TtySize}; use crate::tty::{FixedSizeTty, NullTty, TtySize};
#[derive(Default)] #[derive(Default)]
struct TestRecorder { struct TestHandler {
tty_size: Option<TtySize>, tty_size: Option<TtySize>,
output: Vec<Vec<u8>>, output: Vec<Vec<u8>>,
} }
impl Recorder for TestRecorder { impl Handler for TestHandler {
fn start(&mut self, tty_size: TtySize) -> std::io::Result<()> { fn start(&mut self, tty_size: TtySize) {
self.tty_size = Some(tty_size); self.tty_size = Some(tty_size);
Ok(())
} }
fn output(&mut self, data: &[u8]) { fn output(&mut self, data: &[u8]) -> bool {
self.output.push(data.into()); self.output.push(data.into());
true
} }
fn input(&mut self, _data: &[u8]) -> bool { fn input(&mut self, _data: &[u8]) -> bool {
true true
} }
fn resize(&mut self, _size: TtySize) {} fn resize(&mut self, _size: TtySize) -> bool {
true
}
} }
impl TestRecorder { impl TestHandler {
fn output(&self) -> Vec<String> { fn output(&self) -> Vec<String> {
self.output self.output
.iter() .iter()
@@ -404,7 +410,7 @@ mod tests {
#[test] #[test]
fn exec_basic() { fn exec_basic() {
let mut recorder = TestRecorder::default(); let mut handler = TestHandler::default();
let code = r#" let code = r#"
import sys; import sys;
@@ -419,47 +425,47 @@ sys.stdout.write('bar');
&["python3", "-c", code], &["python3", "-c", code],
&ExtraEnv::new(), &ExtraEnv::new(),
&mut NullTty::open().unwrap(), &mut NullTty::open().unwrap(),
&mut recorder, &mut handler,
) )
.unwrap(); .unwrap();
assert_eq!(recorder.output(), vec!["foo", "bar"]); assert_eq!(handler.output(), vec!["foo", "bar"]);
assert_eq!(recorder.tty_size, Some(TtySize(80, 24))); assert_eq!(handler.tty_size, Some(TtySize(80, 24)));
} }
#[test] #[test]
fn exec_no_output() { fn exec_no_output() {
let mut recorder = TestRecorder::default(); let mut handler = TestHandler::default();
super::exec( super::exec(
&["true"], &["true"],
&ExtraEnv::new(), &ExtraEnv::new(),
&mut NullTty::open().unwrap(), &mut NullTty::open().unwrap(),
&mut recorder, &mut handler,
) )
.unwrap(); .unwrap();
assert!(recorder.output().is_empty()); assert!(handler.output().is_empty());
} }
#[test] #[test]
fn exec_quick() { fn exec_quick() {
let mut recorder = TestRecorder::default(); let mut handler = TestHandler::default();
super::exec( super::exec(
&["printf", "hello world\n"], &["printf", "hello world\n"],
&ExtraEnv::new(), &ExtraEnv::new(),
&mut NullTty::open().unwrap(), &mut NullTty::open().unwrap(),
&mut recorder, &mut handler,
) )
.unwrap(); .unwrap();
assert!(!recorder.output().is_empty()); assert!(!handler.output().is_empty());
} }
#[test] #[test]
fn exec_extra_env() { fn exec_extra_env() {
let mut recorder = TestRecorder::default(); let mut handler = TestHandler::default();
let mut env = ExtraEnv::new(); let mut env = ExtraEnv::new();
env.insert("ASCIINEMA_TEST_FOO".to_owned(), "bar".to_owned()); env.insert("ASCIINEMA_TEST_FOO".to_owned(), "bar".to_owned());
@@ -468,25 +474,25 @@ sys.stdout.write('bar');
&["sh", "-c", "echo -n $ASCIINEMA_TEST_FOO"], &["sh", "-c", "echo -n $ASCIINEMA_TEST_FOO"],
&env, &env,
&mut NullTty::open().unwrap(), &mut NullTty::open().unwrap(),
&mut recorder, &mut handler,
) )
.unwrap(); .unwrap();
assert_eq!(recorder.output(), vec!["bar"]); assert_eq!(handler.output(), vec!["bar"]);
} }
#[test] #[test]
fn exec_winsize_override() { fn exec_winsize_override() {
let mut recorder = TestRecorder::default(); let mut handler = TestHandler::default();
super::exec( super::exec(
&["true"], &["true"],
&ExtraEnv::new(), &ExtraEnv::new(),
&mut FixedSizeTty::new(NullTty::open().unwrap(), Some(100), Some(50)), &mut FixedSizeTty::new(NullTty::open().unwrap(), Some(100), Some(50)),
&mut recorder, &mut handler,
) )
.unwrap(); .unwrap();
assert_eq!(recorder.tty_size, Some(TtySize(100, 50))); assert_eq!(handler.tty_size, Some(TtySize(100, 50)));
} }
} }

View File

@@ -81,10 +81,10 @@ impl Recorder {
} }
} }
impl pty::Recorder for Recorder { impl pty::Handler for Recorder {
fn start(&mut self, tty_size: tty::TtySize) -> io::Result<()> { fn start(&mut self, tty_size: tty::TtySize) {
let mut output = self.output.take().unwrap(); let mut output = self.output.take().unwrap();
output.start(&tty_size)?; let _ = output.start(&tty_size);
let receiver = self.receiver.take().unwrap(); let receiver = self.receiver.take().unwrap();
let mut notifier = self.notifier.take().unwrap(); let mut notifier = self.notifier.take().unwrap();
@@ -134,17 +134,15 @@ impl pty::Recorder for Recorder {
self.handle = Some(util::JoinHandle::new(handle)); self.handle = Some(util::JoinHandle::new(handle));
self.start_time = Instant::now(); self.start_time = Instant::now();
Ok(())
} }
fn output(&mut self, data: &[u8]) { fn output(&mut self, data: &[u8]) -> bool {
if self.pause_time.is_some() { if self.pause_time.is_none() {
return; let msg = Message::Output(self.elapsed_time(), data.into());
self.sender.send(msg).expect("output send should succeed");
} }
let msg = Message::Output(self.elapsed_time(), data.into()); true
self.sender.send(msg).expect("output send should succeed");
} }
fn input(&mut self, data: &[u8]) -> bool { fn input(&mut self, data: &[u8]) -> bool {
@@ -187,9 +185,11 @@ impl pty::Recorder for Recorder {
true true
} }
fn resize(&mut self, size: tty::TtySize) { fn resize(&mut self, size: tty::TtySize) -> bool {
let msg = Message::Resize(self.elapsed_time(), size); let msg = Message::Resize(self.elapsed_time(), size);
self.sender.send(msg).expect("resize send should succeed"); self.sender.send(msg).expect("resize send should succeed");
true
} }
} }

View File

@@ -7,7 +7,6 @@ use crate::notifier::Notifier;
use crate::pty; use crate::pty;
use crate::tty; use crate::tty;
use crate::util; use crate::util;
use std::io;
use std::net; use std::net;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
@@ -85,8 +84,8 @@ impl Streamer {
} }
} }
impl pty::Recorder for Streamer { impl pty::Handler for Streamer {
fn start(&mut self, tty_size: tty::TtySize) -> io::Result<()> { fn start(&mut self, tty_size: tty::TtySize) {
let pty_rx = self.pty_rx.take().unwrap(); let pty_rx = self.pty_rx.take().unwrap();
let (clients_tx, mut clients_rx) = mpsc::channel(1); let (clients_tx, mut clients_rx) = mpsc::channel(1);
let shutdown_token = tokio_util::sync::CancellationToken::new(); let shutdown_token = tokio_util::sync::CancellationToken::new();
@@ -139,17 +138,15 @@ impl pty::Recorder for Streamer {
})); }));
self.start_time = Instant::now(); self.start_time = Instant::now();
Ok(())
} }
fn output(&mut self, raw: &[u8]) { fn output(&mut self, raw: &[u8]) -> bool {
if self.paused { if !self.paused {
return; let event = Event::Output(self.elapsed_time(), raw.into());
let _ = self.pty_tx.send(event);
} }
let event = Event::Output(self.elapsed_time(), raw.into()); true
let _ = self.pty_tx.send(event);
} }
fn input(&mut self, raw: &[u8]) -> bool { fn input(&mut self, raw: &[u8]) -> bool {
@@ -185,9 +182,11 @@ impl pty::Recorder for Streamer {
true true
} }
fn resize(&mut self, size: crate::tty::TtySize) { fn resize(&mut self, size: crate::tty::TtySize) -> bool {
let event = Event::Resize(self.elapsed_time(), size); let event = Event::Resize(self.elapsed_time(), size);
let _ = self.pty_tx.send(event); let _ = self.pty_tx.send(event);
true
} }
} }