diff --git a/Cargo.toml b/Cargo.toml index 2cf3556..d94d737 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ termion = "2.0.1" serde = { version = "1.0.189", features = ["derive"] } serde_json = "1.0.107" clap = { version = "4.4.7", features = ["derive"] } -signal-hook = "0.3.17" +signal-hook = { version = "0.3.17", default-features = false } uuid = { version = "1.6.1", features = ["v4"] } reqwest = { version = "0.11.23", default-features = false, features = ["blocking", "rustls-tls", "multipart", "gzip", "json"] } rustyline = "13.0.0" diff --git a/src/pty.rs b/src/pty.rs index d1f5f1b..4191296 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -1,12 +1,16 @@ use crate::io::set_non_blocking; use crate::tty::Tty; -use anyhow::bail; +use anyhow::{bail, Result}; use nix::errno::Errno; -use nix::sys::select::{pselect, FdSet}; +use nix::sys::select::{select, FdSet}; +use nix::unistd::pipe; use nix::{libc, pty, sys::signal, sys::wait, unistd, unistd::ForkResult}; +use signal_hook::consts::{SIGHUP, SIGINT, SIGQUIT, SIGTERM, SIGWINCH}; +use signal_hook::SigId; use std::collections::HashMap; use std::ffi::{CString, NulError}; use std::io::{self, Read, Write}; +use std::os::fd::BorrowedFd; use std::os::fd::{AsFd, RawFd}; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::{env, fs}; @@ -80,6 +84,11 @@ fn copy( let mut input: Vec = Vec::with_capacity(BUF_SIZE); let mut output: Vec = Vec::with_capacity(BUF_SIZE); let mut flush = false; + let sigwinch_fd = SignalFd::open(SIGWINCH)?; + let sigint_fd = SignalFd::open(SIGINT)?; + let sigterm_fd = SignalFd::open(SIGTERM)?; + let sigquit_fd = SignalFd::open(SIGQUIT)?; + let sighup_fd = SignalFd::open(SIGHUP)?; set_non_blocking(&master_raw_fd)?; @@ -90,6 +99,11 @@ fn copy( let mut wfds = FdSet::new(); rfds.insert(&tty_fd); + rfds.insert(&sigwinch_fd); + rfds.insert(&sigint_fd); + rfds.insert(&sigterm_fd); + rfds.insert(&sigquit_fd); + rfds.insert(&sighup_fd); if !flush { rfds.insert(&master_fd); @@ -103,7 +117,7 @@ fn copy( wfds.insert(&tty_fd); } - if let Err(e) = pselect(None, &mut rfds, &mut wfds, None, None, None) { + if let Err(e) = select(None, &mut rfds, &mut wfds, None, None) { if e == Errno::EINTR { continue; } else { @@ -115,6 +129,11 @@ fn copy( let master_write = wfds.contains(&master_fd); let tty_read = rfds.contains(&tty_fd); let tty_write = wfds.contains(&tty_fd); + let sigwinch_read = rfds.contains(&sigwinch_fd); + let sigint_read = rfds.contains(&sigint_fd); + let sigterm_read = rfds.contains(&sigterm_fd); + let sigquit_read = rfds.contains(&sigquit_fd); + let sighup_read = rfds.contains(&sighup_fd); if master_read { let offset = output.len(); @@ -149,6 +168,35 @@ fn copy( recorder.input(&input[offset..]); } } + + if sigwinch_read { + sigwinch_fd.flush(); + let winsize = get_tty_size(&*tty, winsize_override); + set_pty_size(master_raw_fd, &winsize); + recorder.resize((winsize.ws_col, winsize.ws_row)); + } + + if sigint_read { + sigint_fd.flush(); + } + + if sigterm_read || sigquit_read || sighup_read { + if sigterm_read { + sigterm_fd.flush(); + } + + if sigquit_read { + sigquit_fd.flush(); + } + + if sighup_read { + sighup_fd.flush(); + } + + unsafe { libc::kill(child.as_raw(), SIGTERM) }; + + return Ok(()); + } } } @@ -241,6 +289,49 @@ fn write_all(sink: &mut W, data: &mut Vec) -> io::Result { Ok(left) } +struct SignalFd { + sigid: SigId, + rx: i32, +} + +impl SignalFd { + fn open(signal: libc::c_int) -> Result { + let (rx, tx) = pipe()?; + set_non_blocking(&rx)?; + set_non_blocking(&tx)?; + + let sigid = unsafe { + signal_hook::low_level::register(signal, move || { + let _ = unistd::write(tx, &[0]); + }) + }?; + + Ok(Self { sigid, rx }) + } + + fn flush(&self) { + let mut buf = [0; 256]; + + while let Ok(n) = unistd::read(self.rx, &mut buf) { + if n == 0 { + break; + }; + } + } +} + +impl AsFd for SignalFd { + fn as_fd(&self) -> BorrowedFd<'_> { + unsafe { BorrowedFd::borrow_raw(self.rx) } + } +} + +impl Drop for SignalFd { + fn drop(&mut self) { + signal_hook::low_level::unregister(self.sigid); + } +} + #[cfg(test)] mod tests { use crate::pty::ExtraEnv;