Make SignalFd handle multiple signals with single pipe

This commit is contained in:
Marcin Kulik
2025-06-05 11:17:54 +02:00
parent 5051cc78bc
commit 0676b54033

View File

@@ -6,6 +6,7 @@ use std::io::{self, ErrorKind, Read, Write};
use std::os::fd::AsFd;
use std::os::fd::{BorrowedFd, OwnedFd};
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::bail;
@@ -100,13 +101,8 @@ fn copy<T: Tty, H: Handler>(
let mut output: Vec<u8> = Vec::with_capacity(BUF_SIZE);
let mut master_closed = false;
let sigwinch_fd = SignalFd::open(SIGWINCH)?;
let sigint_fd = SignalFd::open(SIGINT)?;
let sigterm_fd = SignalFd::open(SIGTERM)?;
let sigquit_fd = SignalFd::open(SIGQUIT)?;
let sighup_fd = SignalFd::open(SIGHUP)?;
let sigalrm_fd = SignalFd::open(SIGALRM)?;
let sigchld_fd = SignalFd::open(SIGCHLD)?;
let mut signal_fd =
SignalFd::open(&[SIGWINCH, SIGINT, SIGTERM, SIGQUIT, SIGHUP, SIGALRM, SIGCHLD])?;
set_non_blocking(&master)?;
@@ -117,13 +113,7 @@ fn copy<T: Tty, H: Handler>(
let mut wfds = FdSet::new();
rfds.insert(tty_fd);
rfds.insert(sigwinch_fd.as_fd());
rfds.insert(sigint_fd.as_fd());
rfds.insert(sigterm_fd.as_fd());
rfds.insert(sigquit_fd.as_fd());
rfds.insert(sighup_fd.as_fd());
rfds.insert(sigalrm_fd.as_fd());
rfds.insert(sigchld_fd.as_fd());
rfds.insert(signal_fd.as_fd());
if !master_closed {
rfds.insert(master_fd);
@@ -149,13 +139,7 @@ fn copy<T: Tty, H: Handler>(
let master_write = wfds.contains(master_fd);
let tty_read = rfds.contains(tty_fd);
let tty_write = wfds.contains(tty_fd);
let sigwinch_read = rfds.contains(sigwinch_fd.as_fd());
let sigint_read = rfds.contains(sigint_fd.as_fd());
let sigterm_read = rfds.contains(sigterm_fd.as_fd());
let sigquit_read = rfds.contains(sigquit_fd.as_fd());
let sighup_read = rfds.contains(sighup_fd.as_fd());
let sigalrm_read = rfds.contains(sigalrm_fd.as_fd());
let sigchld_read = rfds.contains(sigchld_fd.as_fd());
let signal_read = rfds.contains(signal_fd.as_fd());
if master_read {
while let Some(n) = read_non_blocking(&mut master, &mut buf)? {
@@ -228,47 +212,32 @@ fn copy<T: Tty, H: Handler>(
}
}
if sigwinch_read {
sigwinch_fd.flush();
let winsize = tty.get_size();
if handler.resize(epoch.elapsed(), winsize.into()) {
set_pty_size(master_raw_fd, &winsize);
}
}
let mut kill_the_child = false;
if sigint_read {
sigint_fd.flush();
kill_the_child = true;
}
if signal_read {
for signal in signal_fd.flush() {
match signal {
SIGWINCH => {
let winsize = tty.get_size();
if sigterm_read {
sigterm_fd.flush();
kill_the_child = true;
}
if handler.resize(epoch.elapsed(), winsize.into()) {
set_pty_size(master_raw_fd, &winsize);
}
}
if sigquit_read {
sigquit_fd.flush();
kill_the_child = true;
}
SIGINT | SIGTERM | SIGQUIT | SIGHUP => {
kill_the_child = true;
}
if sighup_read {
sighup_fd.flush();
kill_the_child = true;
}
SIGCHLD => {
if let Ok(status) = wait::waitpid(child, Some(WaitPidFlag::WNOHANG)) {
if status != WaitStatus::StillAlive {
return Ok(Some(status));
}
}
}
if sigalrm_read {
sigalrm_fd.flush();
}
if sigchld_read {
sigchld_fd.flush();
if let Ok(status) = wait::waitpid(child, Some(WaitPidFlag::WNOHANG)) {
if status != WaitStatus::StillAlive {
return Ok(Some(status));
_ => {}
}
}
}
@@ -338,33 +307,51 @@ fn write_non_blocking<W: Write + ?Sized>(sink: &mut W, buf: &[u8]) -> io::Result
}
struct SignalFd {
sigid: SigId,
sigids: Vec<SigId>,
rx: OwnedFd,
}
impl SignalFd {
fn open(signal: libc::c_int) -> anyhow::Result<Self> {
fn open(signals: &[libc::c_int]) -> anyhow::Result<Self> {
let (rx, tx) = unistd::pipe()?;
set_non_blocking(&rx)?;
set_non_blocking(&tx)?;
let sigid = unsafe {
signal_hook::low_level::register(signal, move || {
let _ = unistd::write(&tx, &[0]);
})
}?;
let tx = Arc::new(tx);
Ok(Self { sigid, rx })
let mut sigids = Vec::new();
for signal in signals {
let tx_ = Arc::clone(&tx);
let num = *signal as u8;
let sigid = unsafe {
signal_hook::low_level::register(*signal, move || {
let _ = unistd::write(&tx_, &[num]);
})
}?;
sigids.push(sigid);
}
Ok(Self { sigids, rx })
}
fn flush(&self) {
fn flush(&mut self) -> Vec<i32> {
let mut buf = [0; 256];
let mut signals = Vec::new();
while let Ok(n) = unistd::read(&self.rx, &mut buf) {
for num in &buf[..n] {
signals.push(*num as i32);
}
if n == 0 {
break;
};
}
signals
}
}
@@ -376,7 +363,9 @@ impl AsFd for SignalFd {
impl Drop for SignalFd {
fn drop(&mut self) {
signal_hook::low_level::unregister(self.sigid);
for sigid in &self.sigids {
signal_hook::low_level::unregister(*sigid);
}
}
}