From caf0cf3707fff9a840b19449914fb48f05a5b9b1 Mon Sep 17 00:00:00 2001 From: Marcin Kulik Date: Fri, 6 Jun 2025 14:56:14 +0200 Subject: [PATCH] Refactor session/pty/tty --- Cargo.lock | 99 ++++----- Cargo.toml | 5 +- src/alis.rs | 1 - src/cmd/play.rs | 19 +- src/cmd/session.rs | 226 ++++++++++--------- src/fd.rs | 17 ++ src/file_writer.rs | 41 ++-- src/forwarder.rs | 26 ++- src/io.rs | 15 -- src/main.rs | 2 +- src/notifier.rs | 61 +++--- src/player.rs | 176 +++++++-------- src/pty.rs | 532 +++++++++------------------------------------ src/server.rs | 5 +- src/session.rs | 320 ++++++++++++++++++--------- src/stream.rs | 37 ++-- src/tty.rs | 346 ++++++++++++++--------------- src/util.rs | 26 +-- 18 files changed, 840 insertions(+), 1114 deletions(-) create mode 100644 src/fd.rs delete mode 100644 src/io.rs diff --git a/Cargo.lock b/Cargo.lock index 607532c..084f3fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,6 +87,7 @@ name = "asciinema" version = "3.0.0-rc.4" dependencies = [ "anyhow", + "async-trait", "avt", "axum", "clap", @@ -104,8 +105,8 @@ dependencies = [ "serde", "serde_json", "signal-hook", + "signal-hook-tokio", "tempfile", - "termion", "tokio", "tokio-stream", "tokio-tungstenite", @@ -131,6 +132,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -248,7 +260,7 @@ version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ - "bitflags 2.9.0", + "bitflags", "cexpr", "clang-sys", "itertools", @@ -265,12 +277,6 @@ dependencies = [ "which 4.4.2", ] -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "bitflags" version = "2.9.0" @@ -1044,17 +1050,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "libredox" -version = "0.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3af92c55d7d839293953fcd0fda5ecfe93297cfde6ffbdec13b41d99c0ba6607" -dependencies = [ - "bitflags 2.9.0", - "libc", - "redox_syscall", -] - [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -1148,7 +1143,7 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 2.9.0", + "bitflags", "cfg-if", "libc", ] @@ -1159,7 +1154,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.9.0", + "bitflags", "cfg-if", "cfg_aliases", "libc", @@ -1175,12 +1170,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "numtoa" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef" - [[package]] name = "object" version = "0.36.7" @@ -1352,21 +1341,6 @@ dependencies = [ "getrandom 0.3.2", ] -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "redox_termios" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20145670ba436b55d91fc92d25e71160fbfbdd57831631c8d7d36377a476f1cb" - [[package]] name = "regex" version = "1.11.1" @@ -1545,7 +1519,7 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.9.0", + "bitflags", "errno", "libc", "linux-raw-sys 0.4.15", @@ -1558,7 +1532,7 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" dependencies = [ - "bitflags 2.9.0", + "bitflags", "errno", "libc", "linux-raw-sys 0.9.4", @@ -1634,7 +1608,7 @@ version = "13.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02a2d683a4ac90aeef5b1013933f6d977bd37d51ff3f4dad829d4931a7e6be86" dependencies = [ - "bitflags 2.9.0", + "bitflags", "cfg-if", "clipboard-win", "libc", @@ -1677,7 +1651,7 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.9.0", + "bitflags", "core-foundation", "core-foundation-sys", "libc", @@ -1803,6 +1777,18 @@ dependencies = [ "libc", ] +[[package]] +name = "signal-hook-tokio" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213241f76fb1e37e27de3b6aa1b068a2c333233b59cca6634f634b80a27ecf1e" +dependencies = [ + "futures-core", + "libc", + "signal-hook", + "tokio", +] + [[package]] name = "slab" version = "0.4.9" @@ -1890,18 +1876,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "termion" -version = "3.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "417813675a504dfbbf21bfde32c03e5bf9f2413999962b479023c02848c1c7a5" -dependencies = [ - "libc", - "libredox", - "numtoa", - "redox_termios", -] - [[package]] name = "thiserror" version = "2.0.12" @@ -1959,15 +1933,16 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.2" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", "bytes", "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -2092,7 +2067,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" dependencies = [ - "bitflags 2.9.0", + "bitflags", "bytes", "http", "http-body", @@ -2641,7 +2616,7 @@ version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.9.0", + "bitflags", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c5c2aa6..54129b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ rust-version = "1.75.0" [dependencies] anyhow = "1.0.98" nix = { version = "0.30", features = ["fs", "term", "process", "signal", "poll"] } -termion = "3.0.0" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" clap = { version = "4.5.37", features = ["derive"] } @@ -29,7 +28,7 @@ which = "6.0.3" tempfile = "3.9.0" avt = "0.16.0" axum = { version = "0.8.4", default-features = false, features = ["http1", "ws"] } -tokio = { version = "1.44.2", features = ["rt-multi-thread", "net", "sync", "time"] } +tokio = { version = "1.45.1", features = ["rt-multi-thread", "net", "sync", "time", "fs", "process"] } futures-util = { version = "0.3.31", default-features = false, features = ["sink"] } tokio-stream = { version = "0.1.17", default-features = false, features = ["sync", "time"] } rust-embed = "8.2.0" @@ -42,6 +41,8 @@ tokio-tungstenite = { version = "0.26.2", default-features = false, features = [ rustls = { version = "0.23.26", default-features = false, features = ["aws_lc_rs"] } tokio-util = { version = "0.7.10", features = ["rt"] } rand = "0.9.1" +async-trait = "0.1.88" +signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"] } [build-dependencies] clap = { version = "4.5.37", features = ["derive"] } diff --git a/src/alis.rs b/src/alis.rs index 90deb51..bf16abb 100644 --- a/src/alis.rs +++ b/src/alis.rs @@ -6,7 +6,6 @@ use std::future; -use anyhow::Result; use futures_util::{stream, Stream, StreamExt}; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; diff --git a/src/cmd/play.rs b/src/cmd/play.rs index 1fc5dfb..42a1d4c 100644 --- a/src/cmd/play.rs +++ b/src/cmd/play.rs @@ -1,37 +1,34 @@ -use anyhow::Result; +use tokio::runtime::Runtime; use crate::asciicast; use crate::cli; use crate::config::{self, Config}; use crate::player::{self, KeyBindings}; use crate::status; -use crate::tty; use crate::util; impl cli::Play { - pub fn run(self) -> Result<()> { + pub fn run(self) -> anyhow::Result<()> { let config = Config::new(None)?; let speed = self.speed.or(config.playback.speed).unwrap_or(1.0); let idle_time_limit = self.idle_time_limit.or(config.playback.idle_time_limit); + let path = util::get_local_path(&self.file)?; + let keys = get_key_bindings(&config.playback)?; + let runtime = Runtime::new()?; status::info!("Replaying session from {}", self.file); - let path = util::get_local_path(&self.file)?; - let keys = get_key_bindings(&config.playback)?; - let ended = loop { let recording = asciicast::open_from_path(&*path)?; - let tty = tty::DevTty::open()?; - let ended = player::play( + let ended = runtime.block_on(player::play( recording, - tty, speed, idle_time_limit, self.pause_on_markers, &keys, self.resize, - )?; + ))?; if !self.loop_ || !ended { break ended; @@ -48,7 +45,7 @@ impl cli::Play { } } -fn get_key_bindings(config: &config::Playback) -> Result { +fn get_key_bindings(config: &config::Playback) -> anyhow::Result { let mut keys = KeyBindings::default(); if let Some(key) = config.pause_key()? { diff --git a/src/cmd/session.rs b/src/cmd/session.rs index 960b2ac..0f4e556 100644 --- a/src/cmd/session.rs +++ b/src/cmd/session.rs @@ -1,13 +1,11 @@ use std::collections::{HashMap, HashSet}; use std::env; -use std::fs::{self, File, OpenOptions}; -use std::io::LineWriter; -use std::net::TcpListener; use std::path::{Path, PathBuf}; use std::process::ExitCode; use std::time::{Duration, SystemTime}; use anyhow::{anyhow, bail, Context, Result}; +use tokio::net::TcpListener; use tokio::runtime::Runtime; use tokio::time; use tokio_util::sync::CancellationToken; @@ -20,15 +18,14 @@ use crate::api; use crate::asciicast::{self, Version}; use crate::cli::{self, Format, RelayTarget}; use crate::config::{self, Config}; -use crate::encoder::{AsciicastV2Encoder, AsciicastV3Encoder, RawEncoder, TextEncoder}; +use crate::encoder::{AsciicastV2Encoder, AsciicastV3Encoder, Encoder, RawEncoder, TextEncoder}; use crate::file_writer::FileWriter; use crate::forwarder; use crate::hash; use crate::locale; -use crate::notifier::{self, Notifier, NullNotifier}; -use crate::pty; +use crate::notifier::{self, BackgroundNotifier, Notifier, NullNotifier}; use crate::server; -use crate::session::{self, KeyBindings, Metadata, Session, TermInfo}; +use crate::session::{self, KeyBindings, Metadata, TermInfo}; use crate::status; use crate::stream::Stream; use crate::tty::{DevTty, FixedSizeTty, NullTty, Tty}; @@ -37,15 +34,25 @@ impl cli::Session { pub fn run(mut self) -> Result { locale::check_utf8_locale()?; + let exit_status = Runtime::new()?.block_on(self.do_run())?; + + if !self.return_ || exit_status == 0 { + Ok(ExitCode::from(0)) + } else if exit_status > 0 { + Ok(ExitCode::from(exit_status as u8)) + } else { + Ok(ExitCode::from(1)) + } + } + + async fn do_run(&mut self) -> Result { let config = Config::new(self.server_url.clone())?; - let runtime = Runtime::new()?; let command = self.get_command(&config.recording); let keys = get_key_bindings(&config.recording)?; - let notifier = notifier::threaded(get_notifier(&config)); - let signal_fd = pty::open_signal_fd()?; - let metadata = self.get_session_metadata(&config.recording)?; - let file_writer = self.get_file_writer(&metadata, notifier.clone())?; - let listener = self.get_listener()?; + let notifier = get_notifier(&config); + let metadata = self.get_session_metadata(&config.recording).await?; + let file_writer = self.get_file_writer(&metadata, notifier.clone()).await?; + let listener = self.get_listener().await?; let relay = self.get_relay(&metadata, &config)?; let relay_id = relay.as_ref().map(|r| r.id()); let parent_session_relay_id = get_parent_session_relay_id(); @@ -101,12 +108,12 @@ impl cli::Session { let mut outputs: Vec> = Vec::new(); if let Some(writer) = file_writer { - let output = writer.start()?; + let output = writer.start().await?; outputs.push(Box::new(output)); } let server = listener.map(|listener| { - runtime.spawn(server::serve( + tokio::spawn(server::serve( listener, stream.subscriber(), shutdown_token.clone(), @@ -114,7 +121,7 @@ impl cli::Session { }); let forwarder = relay.map(|relay| { - runtime.spawn(forwarder::forward( + tokio::spawn(forwarder::forward( relay.ws_producer_url, stream.subscriber(), notifier.clone(), @@ -123,67 +130,52 @@ impl cli::Session { }); if server.is_some() || forwarder.is_some() { - let output = stream.start(runtime.handle().clone(), &metadata); + let output = stream.start(&metadata).await; outputs.push(Box::new(output)); } - let exit_status = { - let mut tty = self.get_tty(true)?; + let command = &build_exec_command(command.as_ref().cloned()); + let extra_env = &build_exec_extra_env(relay_id.as_ref()); - let mut session = Session::new( - outputs, - metadata.term.size, + let exit_status = { + let mut tty = self.get_tty(true).await?; + + session::run( + command, + extra_env, + tty.as_mut(), self.rec_input || config.recording.rec_input, + outputs, keys, notifier, - ); - - pty::exec( - &build_exec_command(command.as_ref().cloned()), - &build_exec_extra_env(relay_id.as_ref()), - metadata.term.size, - &mut tty, - &mut session, - signal_fd, - )? + ) + .await? }; - runtime.block_on(async { - debug!("session shutting down..."); - shutdown_token.cancel(); - - if let Some(task) = server { - debug!("waiting for server shutdown..."); - let _ = time::timeout(Duration::from_secs(5), task).await; - } - - if let Some(task) = forwarder { - debug!("waiting for forwarder shutdown..."); - let _ = time::timeout(Duration::from_secs(5), task).await; - } - - debug!("shutdown complete"); - }); - status::info!("asciinema session ended"); + shutdown_token.cancel(); - if !self.return_ || exit_status == 0 { - Ok(ExitCode::from(0)) - } else if exit_status > 0 { - Ok(ExitCode::from(exit_status as u8)) - } else { - Ok(ExitCode::from(1)) + if let Some(task) = server { + debug!("waiting for server shutdown..."); + let _ = time::timeout(Duration::from_secs(5), task).await; } + + if let Some(task) = forwarder { + debug!("waiting for forwarder shutdown..."); + let _ = time::timeout(Duration::from_secs(5), task).await; + } + + Ok(exit_status) } fn get_command(&self, config: &config::Recording) -> Option { self.command.as_ref().cloned().or(config.command.clone()) } - fn get_session_metadata(&self, config: &config::Recording) -> Result { + async fn get_session_metadata(&self, config: &config::Recording) -> Result { Ok(Metadata { time: SystemTime::now(), - term: self.get_term_info()?, + term: self.get_term_info().await?, idle_time_limit: self.idle_time_limit.or(config.idle_time_limit), command: self.get_command(config), title: self.title.clone(), @@ -191,18 +183,18 @@ impl cli::Session { }) } - fn get_term_info(&self) -> Result { - let tty = self.get_tty(false)?; + async fn get_term_info(&self) -> Result { + let tty = self.get_tty(false).await?; Ok(TermInfo { type_: env::var("TERM").ok(), - version: tty.get_version(), + version: tty.get_version().await, size: tty.get_size().into(), - theme: tty.get_theme(), + theme: tty.get_theme().await, }) } - fn get_file_writer( + async fn get_file_writer( &self, metadata: &Metadata, notifier: N, @@ -213,47 +205,18 @@ impl cli::Session { let path = Path::new(path); let (overwrite, append) = self.get_file_mode(path)?; - let file = self.open_output_file(path, overwrite, append)?; + let file = self.open_output_file(path, overwrite, append).await?; let format = self.get_file_format(path, append)?; + let writer = Box::new(file); let notifier = Box::new(notifier); + let encoder = self.get_encoder(format, path, append)?; - let file_writer = match format { - Format::AsciicastV3 => { - let writer = Box::new(LineWriter::new(file)); - let encoder = Box::new(AsciicastV3Encoder::new(append)); - - FileWriter::new(writer, encoder, notifier, metadata.clone()) - } - - Format::AsciicastV2 => { - let time_offset = if append { - asciicast::get_duration(path)? - } else { - 0 - }; - - let writer = Box::new(LineWriter::new(file)); - let encoder = Box::new(AsciicastV2Encoder::new(append, time_offset)); - - FileWriter::new(writer, encoder, notifier, metadata.clone()) - } - - Format::Raw => { - let writer = Box::new(file); - let encoder = Box::new(RawEncoder::new()); - - FileWriter::new(writer, encoder, notifier, metadata.clone()) - } - - Format::Txt => { - let writer = Box::new(file); - let encoder = Box::new(TextEncoder::new()); - - FileWriter::new(writer, encoder, notifier, metadata.clone()) - } - }; - - Ok(Some(file_writer)) + Ok(Some(FileWriter::new( + writer, + encoder, + notifier, + metadata.clone(), + ))) } fn get_file_mode(&self, path: &Path) -> Result<(bool, bool)> { @@ -261,7 +224,7 @@ impl cli::Session { let mut append = self.append; if path.exists() { - let metadata = fs::metadata(path)?; + let metadata = std::fs::metadata(path)?; if metadata.len() == 0 { overwrite = true; @@ -298,27 +261,58 @@ impl cli::Session { }) } - fn open_output_file(&self, path: &Path, overwrite: bool, append: bool) -> Result { + fn get_encoder( + &self, + format: Format, + path: &Path, + append: bool, + ) -> Result> { + match format { + Format::AsciicastV3 => Ok(Box::new(AsciicastV3Encoder::new(append))), + + Format::AsciicastV2 => { + let time_offset = if append { + asciicast::get_duration(path)? + } else { + 0 + }; + + Ok(Box::new(AsciicastV2Encoder::new(append, time_offset))) + } + + Format::Raw => Ok(Box::new(RawEncoder::new())), + Format::Txt => Ok(Box::new(TextEncoder::new())), + } + } + + async fn open_output_file( + &self, + path: &Path, + overwrite: bool, + append: bool, + ) -> Result { if let Some(dir) = path.parent() { - let _ = fs::create_dir_all(dir); + let _ = std::fs::create_dir_all(dir); } - OpenOptions::new() + tokio::fs::File::options() .write(true) .append(append) .create(overwrite) .create_new(!overwrite && !append) .truncate(overwrite) .open(path) + .await .map_err(|e| e.into()) } - fn get_listener(&self) -> Result> { + async fn get_listener(&self) -> Result> { let Some(addr) = self.stream_local else { return Ok(None); }; TcpListener::bind(addr) + .await .map(Some) .context("cannot start listener") } @@ -348,19 +342,19 @@ impl cli::Session { Ok(Some(relay)) } - fn get_tty(&self, quiet: bool) -> Result { + async fn get_tty(&self, quiet: bool) -> Result> { let (cols, rows) = self.window_size.unwrap_or((None, None)); if self.headless { - Ok(FixedSizeTty::new(NullTty::open()?, cols, rows)) - } else if let Ok(dev_tty) = DevTty::open() { - Ok(FixedSizeTty::new(dev_tty, cols, rows)) + Ok(Box::new(FixedSizeTty::new(NullTty, cols, rows))) + } else if let Ok(dev_tty) = DevTty::open().await { + Ok(Box::new(FixedSizeTty::new(dev_tty, cols, rows))) } else { if !quiet { status::info!("TTY not available, recording in headless mode"); } - Ok(FixedSizeTty::new(NullTty::open()?, cols, rows)) + Ok(Box::new(FixedSizeTty::new(NullTty, cols, rows))) } } @@ -384,8 +378,8 @@ impl cli::Session { Ok(()) } - fn open_log_file(&self, path: &PathBuf) -> Result { - OpenOptions::new() + fn open_log_file(&self, path: &PathBuf) -> Result { + std::fs::File::options() .create(true) .append(true) .open(path) @@ -470,12 +464,14 @@ fn capture_env(var_names: Option, config: &config::Recording) -> HashMap .collect::>() } -fn get_notifier(config: &Config) -> Box { - if config.notifications.enabled { +fn get_notifier(config: &Config) -> BackgroundNotifier { + let inner = if config.notifications.enabled { notifier::get_notifier(config.notifications.command.clone()) } else { Box::new(NullNotifier) - } + }; + + notifier::background(inner) } fn build_exec_command(command: Option) -> Vec { diff --git a/src/fd.rs b/src/fd.rs new file mode 100644 index 0000000..0477211 --- /dev/null +++ b/src/fd.rs @@ -0,0 +1,17 @@ +use std::io; +use std::os::fd::AsFd; + +use nix::fcntl::{self, FcntlArg::*, OFlag}; + +pub trait FdExt: AsFd { + fn set_nonblocking(&self) -> io::Result<()> { + let flags = fcntl::fcntl(self.as_fd(), F_GETFL)?; + let mut oflags = OFlag::from_bits_truncate(flags); + oflags |= OFlag::O_NONBLOCK; + fcntl::fcntl(self.as_fd(), F_SETFL(oflags))?; + + Ok(()) + } +} + +impl FdExt for T {} diff --git a/src/file_writer.rs b/src/file_writer.rs index a26a1fd..3b38836 100644 --- a/src/file_writer.rs +++ b/src/file_writer.rs @@ -1,28 +1,30 @@ -use std::io::{self, Write}; use std::time::UNIX_EPOCH; +use async_trait::async_trait; +use tokio::io::{self, AsyncWrite, AsyncWriteExt}; + use crate::asciicast; -use crate::encoder; +use crate::encoder::Encoder; use crate::notifier::Notifier; use crate::session::{self, Metadata}; pub struct FileWriter { - writer: Box, - encoder: Box, + writer: Box, + encoder: Box, notifier: Box, metadata: Metadata, } pub struct LiveFileWriter { - writer: Box, - encoder: Box, + writer: Box, + encoder: Box, notifier: Box, } impl FileWriter { pub fn new( - writer: Box, - encoder: Box, + writer: Box, + encoder: Box, notifier: Box, metadata: Metadata, ) -> Self { @@ -34,7 +36,7 @@ impl FileWriter { } } - pub fn start(mut self) -> io::Result { + pub async fn start(mut self) -> io::Result { let timestamp = self .metadata .time @@ -55,10 +57,11 @@ impl FileWriter { env: Some(self.metadata.env.clone()), }; - if let Err(e) = self.writer.write_all(&self.encoder.header(&header)) { + if let Err(e) = self.writer.write_all(&self.encoder.header(&header)).await { let _ = self .notifier - .notify("Write error, session won't be recorded".to_owned()); + .notify("Write error, session won't be recorded".to_owned()) + .await; return Err(e); } @@ -71,23 +74,29 @@ impl FileWriter { } } +#[async_trait] impl session::Output for LiveFileWriter { - fn event(&mut self, event: session::Event) -> io::Result<()> { - match self.writer.write_all(&self.encoder.event(event.into())) { + async fn event(&mut self, event: session::Event) -> io::Result<()> { + match self + .writer + .write_all(&self.encoder.event(event.into())) + .await + { Ok(_) => Ok(()), Err(e) => { let _ = self .notifier - .notify("Write error, recording suspended".to_owned()); + .notify("Write error, recording suspended".to_owned()) + .await; Err(e) } } } - fn flush(&mut self) -> io::Result<()> { - self.writer.write_all(&self.encoder.flush()) + async fn flush(&mut self) -> io::Result<()> { + self.writer.write_all(&self.encoder.flush()).await } } diff --git a/src/forwarder.rs b/src/forwarder.rs index e569d1f..c4614c1 100644 --- a/src/forwarder.rs +++ b/src/forwarder.rs @@ -49,9 +49,9 @@ pub async fn forward( _ = time::sleep(Duration::from_secs(3)) => { if reconnect_attempt > 0 { if connection_count == 0 { - let _ = notifier.notify("Connected to the server".to_string()); + let _ = notifier.notify("Connected to the server".to_string()).await; } else { - let _ = notifier.notify("Reconnected to the server".to_string()); + let _ = notifier.notify("Reconnected to the server".to_string()).await; } } @@ -68,7 +68,10 @@ pub async fn forward( } Ok(false) => { - let _ = notifier.notify("Stream halted by the server".to_string()); + let _ = notifier + .notify("Stream halted by the server".to_string()) + .await; + break; } @@ -82,7 +85,8 @@ pub async fn forward( // This applies to asciinema-server v20241103 and earlier. let _ = notifier - .notify("The server version is too old, forwarding failed".to_string()); + .notify("The server version is too old, forwarding failed".to_string()) + .await; break; } @@ -94,9 +98,11 @@ pub async fn forward( // This happens when the server doesn't support our protocol (version). // This applies to asciinema-server versions newer than v20241103. - let _ = notifier.notify( - "CLI not compatible with the server, forwarding failed".to_string(), - ); + let _ = notifier + .notify( + "CLI not compatible with the server, forwarding failed".to_string(), + ) + .await; break; } @@ -107,10 +113,12 @@ pub async fn forward( if reconnect_attempt == 0 { if connection_count == 0 { let _ = notifier - .notify("Cannot connect to the server, retrying...".to_string()); + .notify("Cannot connect to the server, retrying...".to_string()) + .await; } else { let _ = notifier - .notify("Disconnected from the server, reconnecting...".to_string()); + .notify("Disconnected from the server, reconnecting...".to_string()) + .await; } } } diff --git a/src/io.rs b/src/io.rs deleted file mode 100644 index d01d01a..0000000 --- a/src/io.rs +++ /dev/null @@ -1,15 +0,0 @@ -use std::io; -use std::os::fd::AsFd; - -use anyhow::Result; - -pub fn set_non_blocking(fd: &T) -> Result<(), io::Error> { - use nix::fcntl::{fcntl, FcntlArg::*, OFlag}; - - let flags = fcntl(fd, F_GETFL)?; - let mut oflags = OFlag::from_bits_truncate(flags); - oflags |= OFlag::O_NONBLOCK; - fcntl(fd, F_SETFL(oflags))?; - - Ok(()) -} diff --git a/src/main.rs b/src/main.rs index dbfbd78..0b14a3c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,11 @@ mod cli; mod cmd; mod config; mod encoder; +mod fd; mod file_writer; mod forwarder; mod hash; mod html; -mod io; mod leb128; mod locale; mod notifier; diff --git a/src/notifier.rs b/src/notifier.rs index 71d83ec..22e6329 100644 --- a/src/notifier.rs +++ b/src/notifier.rs @@ -1,15 +1,16 @@ use std::env; use std::ffi::OsStr; use std::path::PathBuf; -use std::process::{Command, Stdio}; -use std::sync::mpsc; -use std::thread; +use std::process::Stdio; -use anyhow::Result; +use async_trait::async_trait; +use tokio::process::Command; +use tokio::sync::mpsc; use which::which; +#[async_trait] pub trait Notifier: Send { - fn notify(&mut self, message: String) -> Result<()>; + async fn notify(&mut self, message: String) -> anyhow::Result<()>; } pub fn get_notifier(custom_command: Option) -> Box { @@ -34,11 +35,12 @@ impl TmuxNotifier { } } +#[async_trait] impl Notifier for TmuxNotifier { - fn notify(&mut self, message: String) -> Result<()> { + async fn notify(&mut self, message: String) -> anyhow::Result<()> { let args = ["display-message", &format!("asciinema: {message}")]; - exec(&mut Command::new(&self.0), &args) + exec(&mut Command::new(&self.0), &args).await } } @@ -50,9 +52,10 @@ impl LibNotifyNotifier { } } +#[async_trait] impl Notifier for LibNotifyNotifier { - fn notify(&mut self, message: String) -> Result<()> { - exec(&mut Command::new(&self.0), &["asciinema", &message]) + async fn notify(&mut self, message: String) -> anyhow::Result<()> { + exec(&mut Command::new(&self.0), &["asciinema", &message]).await } } @@ -64,43 +67,48 @@ impl AppleScriptNotifier { } } +#[async_trait] impl Notifier for AppleScriptNotifier { - fn notify(&mut self, message: String) -> Result<()> { + async fn notify(&mut self, message: String) -> anyhow::Result<()> { let text = message.replace('\"', "\\\""); let script = format!("display notification \"{text}\" with title \"asciinema\""); - exec(&mut Command::new(&self.0), &["-e", &script]) + exec(&mut Command::new(&self.0), &["-e", &script]).await } } pub struct CustomNotifier(String); +#[async_trait] impl Notifier for CustomNotifier { - fn notify(&mut self, text: String) -> Result<()> { + async fn notify(&mut self, text: String) -> anyhow::Result<()> { exec::<&str>( Command::new("/bin/sh") .args(["-c", &self.0]) .env("TEXT", text), &[], ) + .await } } pub struct NullNotifier; +#[async_trait] impl Notifier for NullNotifier { - fn notify(&mut self, _text: String) -> Result<()> { + async fn notify(&mut self, _text: String) -> anyhow::Result<()> { Ok(()) } } -fn exec>(command: &mut Command, args: &[S]) -> Result<()> { +async fn exec>(command: &mut Command, args: &[S]) -> anyhow::Result<()> { let status = command .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .args(args) - .status()?; + .status() + .await?; if status.success() { Ok(()) @@ -113,27 +121,28 @@ fn exec>(command: &mut Command, args: &[S]) -> Result<()> { } #[derive(Clone)] -pub struct ThreadedNotifier(mpsc::Sender); +pub struct BackgroundNotifier(mpsc::Sender); -pub fn threaded(mut notifier: Box) -> ThreadedNotifier { - let (tx, rx) = mpsc::channel(); +pub fn background(mut notifier: Box) -> BackgroundNotifier { + let (tx, mut rx) = mpsc::channel(16); - thread::spawn(move || { - for message in &rx { - if notifier.notify(message).is_err() { + tokio::spawn(async move { + while let Some(message) = rx.recv().await { + if notifier.notify(message).await.is_err() { break; } } - for _ in rx {} + while rx.recv().await.is_some() {} }); - ThreadedNotifier(tx) + BackgroundNotifier(tx) } -impl Notifier for ThreadedNotifier { - fn notify(&mut self, message: String) -> Result<()> { - self.0.send(message)?; +#[async_trait] +impl Notifier for BackgroundNotifier { + async fn notify(&mut self, message: String) -> anyhow::Result<()> { + self.0.send(message).await?; Ok(()) } diff --git a/src/player.rs b/src/player.rs index e083148..3efe61a 100644 --- a/src/player.rs +++ b/src/player.rs @@ -1,14 +1,10 @@ -use std::io::{self, Write}; -use std::os::unix::io::AsRawFd; -use std::time::{Duration, Instant}; - use anyhow::Result; -use nix::sys::select::{pselect, FdSet}; -use nix::sys::time::{TimeSpec, TimeValLike}; +use tokio::sync::mpsc; +use tokio::time::{self, Duration, Instant}; use crate::asciicast::{self, Event, EventData}; use crate::config::Key; -use crate::tty::Tty; +use crate::tty::{DevTty, Tty}; pub struct KeyBindings { pub quit: Key, @@ -28,78 +24,77 @@ impl Default for KeyBindings { } } -pub fn play( - recording: asciicast::Asciicast, - mut tty: impl Tty, +pub async fn play( + recording: asciicast::Asciicast<'static>, speed: f64, - idle_time_limit: Option, + idle_time_limit_override: Option, pause_on_markers: bool, keys: &KeyBindings, auto_resize: bool, ) -> Result { let initial_cols = recording.header.term_cols; let initial_rows = recording.header.term_rows; - let mut events = open_recording(recording, speed, idle_time_limit)?; - let mut stdout = io::stdout(); + let mut events = emit_session_events(recording, speed, idle_time_limit_override)?; let mut epoch = Instant::now(); let mut pause_elapsed_time: Option = None; - let mut next_event = events.next().transpose()?; + let mut next_event = events.recv().await.transpose()?; + let mut input = [0u8; 1024]; + let mut tty = DevTty::open().await?; if auto_resize { - resize_terminal(&mut stdout, initial_cols, initial_rows)?; + tty.resize((initial_cols as usize, initial_rows as usize).into()) + .await?; } while let Some(Event { time, data }) = &next_event { if let Some(pet) = pause_elapsed_time { - if let Some(input) = read_input(&mut tty, 1_000_000)? { - if keys.quit.as_ref().is_some_and(|k| k == &input) { - stdout.write_all("\r\n".as_bytes())?; - return Ok(false); + let n = tty.read(&mut input).await?; + let key = &input[..n]; + + if keys.quit.as_ref().is_some_and(|k| k == key) { + tty.write_all("\r\n".as_bytes()).await?; + return Ok(false); + } + + if keys.pause.as_ref().is_some_and(|k| k == key) { + epoch = Instant::now() - Duration::from_micros(pet); + pause_elapsed_time = None; + } else if keys.step.as_ref().is_some_and(|k| k == key) { + pause_elapsed_time = Some(*time); + + match data { + EventData::Output(data) => { + tty.write_all(data.as_bytes()).await?; + } + + EventData::Resize(cols, rows) if auto_resize => { + tty.resize((*cols as usize, *rows as usize).into()).await?; + } + + _ => {} } - if keys.pause.as_ref().is_some_and(|k| k == &input) { - epoch = Instant::now() - Duration::from_micros(pet); - pause_elapsed_time = None; - } else if keys.step.as_ref().is_some_and(|k| k == &input) { - pause_elapsed_time = Some(*time); + next_event = events.recv().await.transpose()?; + } else if keys.next_marker.as_ref().is_some_and(|k| k == key) { + while let Some(Event { time, data }) = next_event { + next_event = events.recv().await.transpose()?; match data { EventData::Output(data) => { - stdout.write_all(data.as_bytes())?; - stdout.flush()?; + tty.write_all(data.as_bytes()).await?; + } + + EventData::Marker(_) => { + pause_elapsed_time = Some(time); + break; } EventData::Resize(cols, rows) if auto_resize => { - resize_terminal(&mut stdout, *cols, *rows)?; + tty.resize((cols as usize, rows as usize).into()).await?; } _ => {} } - - next_event = events.next().transpose()?; - } else if keys.next_marker.as_ref().is_some_and(|k| k == &input) { - while let Some(Event { time, data }) = next_event { - next_event = events.next().transpose()?; - - match data { - EventData::Output(data) => { - stdout.write_all(data.as_bytes())?; - } - - EventData::Marker(_) => { - pause_elapsed_time = Some(time); - break; - } - - EventData::Resize(cols, rows) if auto_resize => { - resize_terminal(&mut stdout, cols, rows)?; - } - - _ => {} - } - } - - stdout.flush()?; } } } else { @@ -107,15 +102,20 @@ pub fn play( let delay = *time as i64 - epoch.elapsed().as_micros() as i64; if delay > 0 { - stdout.flush()?; + let delay = (*time as i64 - epoch.elapsed().as_micros() as i64).max(0) as u64; - if let Some(key) = read_input(&mut tty, delay)? { - if keys.quit.as_ref().is_some_and(|k| k == &key) { - stdout.write_all("\r\n".as_bytes())?; + if let Ok(result) = + time::timeout(Duration::from_micros(delay), tty.read(&mut input)).await + { + let n = result?; + let key = &input[..n]; + + if keys.quit.as_ref().is_some_and(|k| k == key) { + tty.write_all("\r\n".as_bytes()).await?; return Ok(false); } - if keys.pause.as_ref().is_some_and(|k| k == &key) { + if keys.pause.as_ref().is_some_and(|k| k == key) { pause_elapsed_time = Some(epoch.elapsed().as_micros() as u64); break; } @@ -126,17 +126,17 @@ pub fn play( match data { EventData::Output(data) => { - stdout.write_all(data.as_bytes())?; + tty.write_all(data.as_bytes()).await?; } EventData::Resize(cols, rows) if auto_resize => { - resize_terminal(&mut stdout, *cols, *rows)?; + tty.resize((*cols as usize, *rows as usize).into()).await?; } EventData::Marker(_) => { if pause_on_markers { pause_elapsed_time = Some(*time); - next_event = events.next().transpose()?; + next_event = events.recv().await.transpose()?; break; } } @@ -144,7 +144,7 @@ pub fn play( _ => (), } - next_event = events.next().transpose()?; + next_event = events.recv().await.transpose()?; } } } @@ -152,56 +152,28 @@ pub fn play( Ok(true) } -fn resize_terminal(stdout: &mut impl Write, cols: u16, rows: u16) -> io::Result<()> { - let resize_sequence = format!("\x1b[8;{};{}t", rows, cols); - stdout.write_all(resize_sequence.as_bytes())?; - stdout.flush()?; - - Ok(()) -} - -fn open_recording( - recording: asciicast::Asciicast<'_>, +fn emit_session_events( + recording: asciicast::Asciicast<'static>, speed: f64, - idle_time_limit: Option, -) -> Result> + '_> { - let idle_time_limit = idle_time_limit + idle_time_limit_override: Option, +) -> Result>> { + let idle_time_limit = idle_time_limit_override .or(recording.header.idle_time_limit) .unwrap_or(f64::MAX); let events = asciicast::limit_idle_time(recording.events, idle_time_limit); let events = asciicast::accelerate(events, speed); + // TODO avoid collect, support playback from stdin + let events: Vec<_> = events.collect(); + let (tx, rx) = mpsc::channel::>(1024); - Ok(events) -} - -fn read_input(tty: &mut T, timeout: i64) -> Result>> { - let tty_fd = tty.as_fd(); - let nfds = Some(tty_fd.as_raw_fd() + 1); - let mut rfds = FdSet::new(); - rfds.insert(tty_fd); - let timeout = TimeSpec::microseconds(timeout); - let mut input: Vec = Vec::new(); - - pselect(nfds, &mut rfds, None, None, &timeout, None)?; - - if rfds.contains(tty_fd) { - let mut buf = [0u8; 1024]; - - while let Ok(n) = tty.read(&mut buf) { - if n == 0 { + tokio::spawn(async move { + for event in events { + if tx.send(event).await.is_err() { break; } - - input.extend_from_slice(&buf[0..n]); } + }); - if !input.is_empty() { - Ok(Some(input)) - } else { - Ok(None) - } - } else { - Ok(None) - } + Ok(rx) } diff --git a/src/pty.rs b/src/pty.rs index ab955f6..83d497d 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -1,260 +1,95 @@ use std::collections::HashMap; use std::env; use std::ffi::{CString, NulError}; -use std::fs::File; -use std::io::{self, ErrorKind, Read, Write}; -use std::os::fd::AsFd; -use std::os::fd::{BorrowedFd, OwnedFd}; +use std::os::fd::OwnedFd; use std::os::unix::io::AsRawFd; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use anyhow::bail; use nix::errno::Errno; -use nix::libc::EIO; -use nix::sys::select::{select, FdSet}; -use nix::sys::signal::{self, kill, Signal}; +use nix::pty::{ForkptyResult, Winsize}; +use nix::sys::signal::{self, SigHandler, Signal}; use nix::sys::wait::{self, WaitPidFlag, WaitStatus}; -use nix::unistd; +use nix::unistd::{self, Pid}; use nix::{libc, pty}; -use signal_hook::consts::{SIGALRM, SIGCHLD, SIGHUP, SIGINT, SIGQUIT, SIGTERM, SIGWINCH}; -use signal_hook::SigId; +use tokio::io::unix::AsyncFd; +use tokio::io::{self, Interest}; +use tokio::task; -use crate::io::set_non_blocking; -use crate::tty::{Tty, TtySize}; +use crate::fd::FdExt; -type ExtraEnv = HashMap; - -pub trait Handler { - fn output(&mut self, time: Duration, data: &[u8]) -> bool; - fn input(&mut self, time: Duration, data: &[u8]) -> bool; - fn resize(&mut self, time: Duration, tty_size: TtySize) -> bool; - fn stop(&mut self, time: Duration, exit_status: i32); +pub struct Pty { + child: Pid, + master: AsyncFd, } -pub fn open_signal_fd() -> anyhow::Result { - SignalFd::open(&[SIGWINCH, SIGINT, SIGTERM, SIGQUIT, SIGHUP, SIGALRM, SIGCHLD]) -} - -pub fn exec, T: Tty, H: Handler>( +pub async fn spawn>( command: &[S], - extra_env: &ExtraEnv, - initial_tty_size: TtySize, - tty: &mut T, - handler: &mut H, - signal_fd: SignalFd, -) -> anyhow::Result { - let winsize = initial_tty_size.into(); - let epoch = Instant::now(); + winsize: Winsize, + extra_env: &HashMap, +) -> anyhow::Result { let result = unsafe { pty::forkpty(Some(&winsize), None) }?; match result { - pty::ForkptyResult::Parent { child, master } => { - let code = handle_parent(master, child, tty, handler, epoch, signal_fd)?; - handler.stop(epoch.elapsed(), code); + ForkptyResult::Parent { child, master } => { + master.set_nonblocking()?; + let master = AsyncFd::new(master)?; - Ok(code) + Ok(Pty { child, master }) } - pty::ForkptyResult::Child => { + ForkptyResult::Child => { handle_child(command, extra_env)?; unreachable!(); } } } -fn handle_parent( - master_fd: OwnedFd, - child: unistd::Pid, - tty: &mut T, - handler: &mut H, - epoch: Instant, - signal_fd: SignalFd, -) -> anyhow::Result { - let wait_result = match copy(master_fd, child, tty, handler, epoch, signal_fd) { - Ok(Some(status)) => Ok(status), - Ok(None) => wait::waitpid(child, None), +impl Pty { + pub async fn read(&self, buffer: &mut [u8]) -> io::Result { + self.master + .async_io(Interest::READABLE, |fd| match unistd::read(fd, buffer) { + Ok(n) => Ok(n), + Err(Errno::EIO) => Ok(0), + Err(e) => Err(e.into()), + }) + .await + } - Err(e) => { - let _ = wait::waitpid(child, None); - return Err(e); - } - }; + pub async fn write(&self, buffer: &[u8]) -> io::Result { + self.master + .async_io(Interest::WRITABLE, |fd| match unistd::write(fd, buffer) { + Ok(n) => Ok(n), + Err(Errno::EIO) => Ok(0), + Err(e) => Err(e.into()), + }) + .await + } - match wait_result { - Ok(WaitStatus::Exited(_pid, status)) => Ok(status), - Ok(WaitStatus::Signaled(_pid, signal, ..)) => Ok(128 + signal as i32), - Ok(_) => Ok(1), - Err(e) => Err(anyhow::anyhow!(e)), + pub fn resize(&self, winsize: Winsize) { + unsafe { libc::ioctl(self.master.as_raw_fd(), libc::TIOCSWINSZ, &winsize) }; + } + + pub fn kill(&self) { + // Any errors occurred when killing the child are ignored. + let _ = signal::kill(self.child, Signal::SIGTERM); + } + + pub async fn wait(&self, options: Option) -> io::Result { + let pid = self.child; + task::spawn_blocking(move || Ok(wait::waitpid(pid, options)?)).await? } } -const BUF_SIZE: usize = 128 * 1024; - -fn copy( - master_fd: OwnedFd, - child: unistd::Pid, - tty: &mut T, - handler: &mut H, - epoch: Instant, - mut signal_fd: SignalFd, -) -> anyhow::Result> { - let mut master = File::from(master_fd); - let master_raw_fd = master.as_raw_fd(); - let mut buf = [0u8; BUF_SIZE]; - let mut input: Vec = Vec::with_capacity(BUF_SIZE); - let mut output: Vec = Vec::with_capacity(BUF_SIZE); - let mut master_closed = false; - - set_non_blocking(&master)?; - - loop { - let master_fd = master.as_fd(); - let tty_fd = tty.as_fd(); - let mut rfds = FdSet::new(); - let mut wfds = FdSet::new(); - - rfds.insert(tty_fd); - rfds.insert(signal_fd.as_fd()); - - if !master_closed { - rfds.insert(master_fd); - - if !input.is_empty() { - wfds.insert(master_fd); - } - } - - if !output.is_empty() { - wfds.insert(tty_fd); - } - - if let Err(e) = select(None, &mut rfds, &mut wfds, None, None) { - if e == Errno::EINTR { - continue; - } - - bail!(e); - } - - let master_read = rfds.contains(master_fd); - let master_write = wfds.contains(master_fd); - let tty_read = rfds.contains(tty_fd); - let tty_write = wfds.contains(tty_fd); - let signal_read = rfds.contains(signal_fd.as_fd()); - - if master_read { - while let Some(n) = read_non_blocking(&mut master, &mut buf)? { - if n > 0 { - if handler.output(epoch.elapsed(), &buf[0..n]) { - output.extend_from_slice(&buf[0..n]); - } - } else if output.is_empty() { - return Ok(None); - } else { - master_closed = true; - break; - } - } - } - - if master_write { - let mut buf: &[u8] = input.as_ref(); - - while let Some(n) = write_non_blocking(&mut master, buf)? { - buf = &buf[n..]; - - if buf.is_empty() { - break; - } - } - - let left = buf.len(); - - if left == 0 { - input.clear(); - } else { - input.drain(..input.len() - left); - } - } - - if tty_write { - let mut buf: &[u8] = output.as_ref(); - - while let Some(n) = write_non_blocking(tty, buf)? { - buf = &buf[n..]; - - if buf.is_empty() { - break; - } - } - - let left = buf.len(); - - if left == 0 { - if master_closed { - return Ok(None); - } - - output.clear(); - } else { - output.drain(..output.len() - left); - } - } - - if tty_read { - while let Some(n) = read_non_blocking(tty, &mut buf)? { - if n > 0 { - if handler.input(epoch.elapsed(), &buf[0..n]) { - input.extend_from_slice(&buf[0..n]); - } - } else { - return Ok(None); - } - } - } - - let mut kill_the_child = false; - - if signal_read { - for signal in signal_fd.flush() { - match signal { - SIGWINCH => { - let winsize = tty.get_size(); - - if handler.resize(epoch.elapsed(), winsize.into()) { - set_pty_size(master_raw_fd, &winsize); - } - } - - SIGINT | SIGTERM | SIGQUIT | SIGHUP => { - kill_the_child = true; - } - - SIGCHLD => { - if let Ok(status) = wait::waitpid(child, Some(WaitPidFlag::WNOHANG)) { - if status != WaitStatus::StillAlive { - return Ok(Some(status)); - } - } - } - - _ => {} - } - } - } - - if kill_the_child { - // Any errors occurred when killing the child are ignored. - let _ = kill(child, Signal::SIGTERM); - return Ok(None); - } +impl Drop for Pty { + fn drop(&mut self) { + self.kill(); + let _ = wait::waitpid(self.child, None); } } -fn handle_child>(command: &[S], extra_env: &ExtraEnv) -> anyhow::Result<()> { - use signal::SigHandler; - +fn handle_child>( + command: &[S], + extra_env: &HashMap, +) -> anyhow::Result<()> { let command = command .iter() .map(|s| CString::new(s.as_ref())) @@ -269,166 +104,42 @@ fn handle_child>(command: &[S], extra_env: &ExtraEnv) -> anyhow::R unsafe { libc::_exit(1) } } -fn set_pty_size(pty_fd: i32, winsize: &pty::Winsize) { - unsafe { libc::ioctl(pty_fd, libc::TIOCSWINSZ, winsize) }; -} - -fn read_non_blocking( - source: &mut R, - buf: &mut [u8], -) -> io::Result> { - match source.read(buf) { - Ok(n) => Ok(Some(n)), - - Err(e) => { - if e.kind() == ErrorKind::WouldBlock { - Ok(None) - } else if e.raw_os_error().is_some_and(|code| code == EIO) { - Ok(Some(0)) - } else { - return Err(e); - } - } - } -} - -fn write_non_blocking(sink: &mut W, buf: &[u8]) -> io::Result> { - match sink.write(buf) { - Ok(n) => Ok(Some(n)), - - Err(e) => { - if e.kind() == ErrorKind::WouldBlock { - Ok(None) - } else if e.raw_os_error().is_some_and(|code| code == EIO) { - Ok(Some(0)) - } else { - return Err(e); - } - } - } -} - -pub struct SignalFd { - sigids: Vec, - rx: OwnedFd, -} - -impl SignalFd { - fn open(signals: &[libc::c_int]) -> anyhow::Result { - let (rx, tx) = unistd::pipe()?; - set_non_blocking(&rx)?; - set_non_blocking(&tx)?; - - let tx = Arc::new(tx); - - let mut sigids = Vec::new(); - - for signal in signals { - let tx_ = Arc::clone(&tx); - let num = *signal as u8; - - let sigid = unsafe { - signal_hook::low_level::register(*signal, move || { - let _ = unistd::write(&tx_, &[num]); - }) - }?; - - sigids.push(sigid); - } - - Ok(Self { sigids, rx }) - } - - fn flush(&mut self) -> Vec { - let mut buf = [0; 256]; - let mut signals = Vec::new(); - - while let Ok(n) = unistd::read(&self.rx, &mut buf) { - for num in &buf[..n] { - signals.push(*num as i32); - } - - if n == 0 { - break; - }; - } - - signals - } -} - -impl AsFd for SignalFd { - fn as_fd(&self) -> BorrowedFd<'_> { - self.rx.as_fd() - } -} - -impl Drop for SignalFd { - fn drop(&mut self) { - for sigid in &self.sigids { - signal_hook::low_level::unregister(*sigid); - } - } -} - #[cfg(test)] mod tests { - use super::{Handler, SignalFd}; - use crate::pty::ExtraEnv; - use crate::tty::{NullTty, TtySize}; - use std::time::Duration; + use std::{collections::HashMap, sync::Arc}; - #[derive(Default)] - struct TestHandler { - tty_size: TtySize, - output: Vec>, + use super::Pty; + use crate::tty::TtySize; + + async fn spawn>(command: &[S], extra_env: &HashMap) -> Arc { + Arc::new( + super::spawn(command, TtySize::default().into(), extra_env) + .await + .unwrap(), + ) } - impl Handler for TestHandler { - fn output(&mut self, _time: Duration, data: &[u8]) -> bool { - self.output.push(data.into()); + async fn read_output(pty: Arc) -> Vec { + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + let mut output = Vec::new(); - true - } + while let Ok(n) = pty.read(&mut buf).await { + if n == 0 { + break; + } - fn input(&mut self, _time: Duration, _data: &[u8]) -> bool { - true - } - - fn resize(&mut self, _time: Duration, _size: TtySize) -> bool { - true - } - - fn stop(&mut self, _time: Duration, _exit_status: i32) {} - } - - impl TestHandler { - fn new() -> Self { - Self { - tty_size: Default::default(), - output: Vec::new(), + output.push(String::from_utf8_lossy(&buf[..n]).to_string()); } - } - fn output(&self) -> Vec { - self.output - .iter() - .map(|x| String::from_utf8_lossy(x).to_string()) - .collect::>() - } + output + }) + .await + .unwrap() } - fn setup() -> (TestHandler, SignalFd) { - let handler = TestHandler::new(); - let signal_fd = super::open_signal_fd().unwrap(); - - (handler, signal_fd) - } - - #[test] - fn exec_basic() { - let (mut handler, signal_fd) = setup(); - + #[tokio::test] + async fn spawn_basic() { let code = r#" import sys; import time; @@ -438,71 +149,36 @@ time.sleep(0.1); sys.stdout.write('bar'); "#; - let _code = super::exec( - &["python3", "-c", code], - &ExtraEnv::new(), - TtySize::default(), - &mut NullTty::open().unwrap(), - &mut handler, - signal_fd, - ) - .unwrap(); + let pty = spawn(&["python3", "-c", code], &HashMap::new()).await; + let output = read_output(pty).await; - assert_eq!(handler.output(), vec!["foo", "bar"]); - assert_eq!(handler.tty_size, TtySize(80, 24)); + assert_eq!(output, vec!["foo", "bar"]); } - #[test] - fn exec_no_output() { - let (mut handler, signal_fd) = setup(); + #[tokio::test] + async fn spawn_no_output() { + let pty = spawn(&["true"], &HashMap::new()).await; + let output = read_output(pty).await; - let _code = super::exec( - &["true"], - &ExtraEnv::new(), - TtySize::default(), - &mut NullTty::open().unwrap(), - &mut handler, - signal_fd, - ) - .unwrap(); - - assert!(handler.output().is_empty()); + assert!(output.is_empty()); } - #[test] - fn exec_quick() { - let (mut handler, signal_fd) = setup(); + #[tokio::test] + async fn spawn_quick() { + let pty = spawn(&["printf", "hello world\n"], &HashMap::new()).await; + let output = read_output(pty).await; - let _code = super::exec( - &["printf", "hello world\n"], - &ExtraEnv::new(), - TtySize::default(), - &mut NullTty::open().unwrap(), - &mut handler, - signal_fd, - ) - .unwrap(); - - assert!(!handler.output().is_empty()); + assert!(!output.is_empty()); } - #[test] - fn exec_extra_env() { - let (mut handler, signal_fd) = setup(); + #[tokio::test] + async fn spawn_extra_env() { + let mut extra_env = HashMap::new(); + extra_env.insert("ASCIINEMA_TEST_FOO".to_owned(), "bar".to_owned()); - let mut env = ExtraEnv::new(); - env.insert("ASCIINEMA_TEST_FOO".to_owned(), "bar".to_owned()); + let pty = spawn(&["sh", "-c", "echo -n $ASCIINEMA_TEST_FOO"], &extra_env).await; + let output = read_output(pty).await; - let _code = super::exec( - &["sh", "-c", "echo -n $ASCIINEMA_TEST_FOO"], - &env, - TtySize::default(), - &mut NullTty::open().unwrap(), - &mut handler, - signal_fd, - ) - .unwrap(); - - assert_eq!(handler.output(), vec!["bar"]); + assert_eq!(output, vec!["bar"]); } } diff --git a/src/server.rs b/src/server.rs index f00dd10..eb338d0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -34,13 +34,10 @@ struct AppState { } pub async fn serve( - listener: std::net::TcpListener, + listener: tokio::net::TcpListener, subscriber: Subscriber, shutdown_token: CancellationToken, ) -> io::Result<()> { - listener.set_nonblocking(true)?; - let listener = tokio::net::TcpListener::from_std(listener)?; - let trace = TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)); diff --git a/src/session.rs b/src/session.rs index 1d5da3c..95be153 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,21 +1,24 @@ use std::collections::HashMap; -use std::io; -use std::sync::mpsc; -use std::thread; -use std::time::{Duration, SystemTime}; +use std::time::SystemTime; +use async_trait::async_trait; +use futures_util::future; +use futures_util::stream::StreamExt; +use nix::sys::wait::{WaitPidFlag, WaitStatus}; +use signal_hook::consts::signal::*; +use signal_hook_tokio::Signals; +use tokio::io; +use tokio::sync::mpsc; +use tokio::time::Instant; use tracing::error; use crate::config::Key; use crate::notifier::Notifier; -use crate::pty; -use crate::tty::{TtySize, TtyTheme}; -use crate::util::{JoinHandle, Utf8Decoder}; +use crate::pty::{self, Pty}; +use crate::tty::{Tty, TtySize, TtyTheme}; +use crate::util::Utf8Decoder; -pub trait Output: Send { - fn event(&mut self, event: Event) -> io::Result<()>; - fn flush(&mut self) -> io::Result<()>; -} +const BUF_SIZE: usize = 128 * 1024; #[derive(Clone)] pub enum Event { @@ -26,20 +29,6 @@ pub enum Event { Exit(u64, i32), } -pub struct Session { - notifier: N, - input_decoder: Utf8Decoder, - output_decoder: Utf8Decoder, - tty_size: TtySize, - record_input: bool, - keys: KeyBindings, - sender: mpsc::Sender, - time_offset: u64, - pause_time: Option, - prefix_mode: bool, - _handle: JoinHandle, -} - #[derive(Clone)] pub struct Metadata { pub time: SystemTime, @@ -58,85 +47,193 @@ pub struct TermInfo { pub theme: Option, } -impl Session { - pub fn new( - mut outputs: Vec>, - tty_size: TtySize, - record_input: bool, - keys: KeyBindings, - notifier: N, - ) -> Self { - let (sender, receiver) = mpsc::channel::(); +struct Session { + epoch: Instant, + events_tx: mpsc::Sender, + input_decoder: Utf8Decoder, + keys: KeyBindings, + notifier: N, + output_decoder: Utf8Decoder, + pause_time: Option, + prefix_mode: bool, + record_input: bool, + time_offset: u64, + tty_size: TtySize, +} - let handle = thread::spawn(move || { - for event in receiver { - outputs.retain_mut(|output| match output.event(event.clone()) { - Ok(_) => true, +#[async_trait] +pub trait Output: Send { + async fn event(&mut self, event: Event) -> io::Result<()>; + async fn flush(&mut self) -> io::Result<()>; +} - Err(e) => { - error!("output event handler failed: {e:?}"); +pub async fn run, T: Tty + ?Sized, N: Notifier>( + command: &[S], + extra_env: &HashMap, + tty: &mut T, + record_input: bool, + outputs: Vec>, + keys: KeyBindings, + notifier: N, +) -> anyhow::Result { + let epoch = Instant::now(); + let (events_tx, events_rx) = mpsc::channel::(1024); + let winsize = tty.get_size(); + let pty = pty::spawn(command, winsize, extra_env).await?; + tokio::spawn(forward_events(events_rx, outputs)); - false - } - }); - } + let mut session = Session { + epoch, + events_tx, + input_decoder: Utf8Decoder::new(), + keys, + notifier, + output_decoder: Utf8Decoder::new(), + pause_time: None, + prefix_mode: false, + record_input, + time_offset: 0, + tty_size: winsize.into(), + }; - for mut output in outputs { - match output.flush() { - Ok(_) => {} + session.run(pty, tty).await +} - Err(e) => { - error!("output flush failed: {e:?}"); - } - } - } - }); +async fn forward_events(mut events_rx: mpsc::Receiver, outputs: Vec>) { + let mut outputs = outputs; - Session { - notifier, - input_decoder: Utf8Decoder::new(), - output_decoder: Utf8Decoder::new(), - record_input, - keys, - tty_size, - sender, - time_offset: 0, - pause_time: None, - prefix_mode: false, - _handle: JoinHandle::new(handle), - } + while let Some(event) = events_rx.recv().await { + let futs: Vec<_> = outputs + .into_iter() + .map(|output| forward_event(output, event.clone())) + .collect(); + + outputs = future::join_all(futs).await.into_iter().flatten().collect(); } - fn elapsed_time(&self, time: Duration) -> u64 { - if let Some(pause_time) = self.pause_time { - pause_time - } else { - time.as_micros() as u64 - self.time_offset + for mut output in outputs { + if let Err(e) = output.flush().await { + error!("output flush failed: {e:?}"); } } - - fn notify(&mut self, text: S) { - self.notifier - .notify(text.to_string()) - .expect("notification should succeed"); - } } -impl pty::Handler for Session { - fn output(&mut self, time: Duration, data: &[u8]) -> bool { +async fn forward_event(mut output: Box, event: Event) -> Option> { + match output.event(event).await { + Ok(()) => Some(output), + + Err(e) => { + error!("output event handler failed: {e:?}"); + None + } + } +} + +impl Session { + async fn run(&mut self, pty: Pty, tty: &mut T) -> anyhow::Result { + let mut signals = + Signals::new([SIGWINCH, SIGINT, SIGTERM, SIGQUIT, SIGHUP, SIGALRM, SIGCHLD])?; + let mut output_buf = [0u8; BUF_SIZE]; + let mut input_buf = [0u8; BUF_SIZE]; + let mut input: Vec = Vec::with_capacity(BUF_SIZE); + let mut output: Vec = Vec::with_capacity(BUF_SIZE); + let mut wait_status = None; + + loop { + tokio::select! { + result = pty.read(&mut output_buf) => { + let n = result?; + + if n > 0 { + self.handle_output(&output_buf[..n]).await; + output.extend_from_slice(&output_buf[0..n]); + } else { + break; + } + } + + result = pty.write(&input), if !input.is_empty() => { + let n = result?; + input.drain(..n); + } + + result = tty.read(&mut input_buf) => { + let n = result?; + + if n > 0 { + if self.handle_input(&input_buf[..n]).await { + input.extend_from_slice(&input_buf[..n]); + } + } else { + break; + } + } + + result = tty.write(&output), if !output.is_empty() => { + let n = result?; + output.drain(..n); + } + + Some(signal) = signals.next() => { + match signal { + SIGWINCH => { + let winsize = tty.get_size(); + pty.resize(winsize); + self.handle_resize(winsize.into()).await; + } + + SIGINT | SIGTERM | SIGQUIT | SIGHUP => { + pty.kill(); + } + + SIGCHLD => { + if let Ok(status) = pty.wait(Some(WaitPidFlag::WNOHANG)).await { + if status != WaitStatus::StillAlive { + wait_status = Some(status); + break; + } + } + } + + _ => {} + } + } + } + } + + if !output.is_empty() { + self.handle_output(&output).await; + let _ = tty.write_all(&output).await; + } + + let wait_status = match wait_status { + Some(ws) => ws, + None => pty.wait(None).await?, + }; + + let status = match wait_status { + WaitStatus::Exited(_pid, status) => status, + WaitStatus::Signaled(_pid, signal, ..) => 128 + signal as i32, + _ => 1, + }; + + self.handle_exit(status).await; + + Ok(status) + } + + async fn handle_output(&mut self, data: &[u8]) { if self.pause_time.is_none() { let text = self.output_decoder.feed(data); if !text.is_empty() { - let msg = Event::Output(self.elapsed_time(time), text); - self.sender.send(msg).expect("output send should succeed"); + let event = Event::Output(self.elapsed_time(), text); + self.send_session_event(event).await; } } - - true } - fn input(&mut self, time: Duration, data: &[u8]) -> bool { + async fn handle_input(&mut self, data: &[u8]) -> bool { let prefix_key = self.keys.prefix.as_ref(); let pause_key = self.keys.pause.as_ref(); let add_marker_key = self.keys.add_marker.as_ref(); @@ -152,18 +249,18 @@ impl pty::Handler for Session { if pause_key.is_some_and(|key| data == key) { if let Some(pt) = self.pause_time { self.pause_time = None; - self.time_offset += self.elapsed_time(time) - pt; - self.notify("Resumed recording"); + self.time_offset += self.elapsed_time() - pt; + self.notify("Resumed recording").await; } else { - self.pause_time = Some(self.elapsed_time(time)); - self.notify("Paused recording"); + self.pause_time = Some(self.elapsed_time()); + self.notify("Paused recording").await; } return false; } else if add_marker_key.is_some_and(|key| data == key) { - let msg = Event::Marker(self.elapsed_time(time), "".to_owned()); - self.sender.send(msg).expect("marker send should succeed"); - self.notify("Marker added"); + let event = Event::Marker(self.elapsed_time(), "".to_owned()); + self.send_session_event(event).await; + self.notify("Marker added").await; return false; } } @@ -172,28 +269,47 @@ impl pty::Handler for Session { let text = self.input_decoder.feed(data); if !text.is_empty() { - let msg = Event::Input(self.elapsed_time(time), text); - self.sender.send(msg).expect("input send should succeed"); + let event = Event::Input(self.elapsed_time(), text); + self.send_session_event(event).await; } } true } - fn resize(&mut self, time: Duration, tty_size: TtySize) -> bool { + async fn handle_resize(&mut self, tty_size: TtySize) { if tty_size != self.tty_size { - let msg = Event::Resize(self.elapsed_time(time), tty_size); - self.sender.send(msg).expect("resize send should succeed"); - + let event = Event::Resize(self.elapsed_time(), tty_size); + self.send_session_event(event).await; self.tty_size = tty_size; } - - true } - fn stop(&mut self, time: Duration, exit_status: i32) { - let msg = Event::Exit(self.elapsed_time(time), exit_status); - self.sender.send(msg).expect("exit send should succeed"); + async fn handle_exit(&mut self, status: i32) { + let event = Event::Exit(self.elapsed_time(), status); + self.send_session_event(event).await; + } + + fn elapsed_time(&self) -> u64 { + if let Some(pause_time) = self.pause_time { + pause_time + } else { + self.epoch.elapsed().as_micros() as u64 - self.time_offset + } + } + + async fn send_session_event(&mut self, event: Event) { + self.events_tx + .send(event) + .await + .expect("session event send should succeed"); + } + + async fn notify(&mut self, text: S) { + self.notifier + .notify(text.to_string()) + .await + .expect("notification should succeed"); } } diff --git a/src/stream.rs b/src/stream.rs index 523d24c..1f9c910 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,28 +1,23 @@ use std::future; -use std::io; use std::time::{Duration, Instant}; -use anyhow::Result; +use async_trait::async_trait; use avt::Vt; use futures_util::{stream, StreamExt}; -use tokio::runtime::Handle; use tokio::sync::{broadcast, mpsc, oneshot}; -use tokio::time; +use tokio::{io, time}; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; use tracing::info; use crate::session::{self, Metadata}; -use crate::tty::TtySize; -use crate::tty::TtyTheme; +use crate::tty::{TtySize, TtyTheme}; pub struct Stream { request_tx: mpsc::Sender, request_rx: mpsc::Receiver, } -pub struct LiveStream(mpsc::UnboundedSender); - type Request = oneshot::Sender; struct Subscription { @@ -30,9 +25,6 @@ struct Subscription { events_rx: broadcast::Receiver, } -#[derive(Clone)] -pub struct Subscriber(mpsc::Sender); - #[derive(Clone)] pub enum Event { Init(u64, u64, TtySize, Option, String), @@ -43,9 +35,14 @@ pub enum Event { Exit(u64, u64, i32), } +#[derive(Clone)] +pub struct Subscriber(mpsc::Sender); + +pub struct LiveStream(mpsc::UnboundedSender); + impl Stream { pub fn new() -> Self { - let (request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(16); Stream { request_tx, @@ -57,18 +54,16 @@ impl Stream { Subscriber(self.request_tx.clone()) } - pub fn start(self, handle: Handle, metadata: &Metadata) -> LiveStream { + pub async fn start(self, metadata: &Metadata) -> LiveStream { let (stream_tx, stream_rx) = mpsc::unbounded_channel(); let request_rx = self.request_rx; - let fut = run( + tokio::spawn(run( metadata.term.size, metadata.term.theme.clone(), stream_rx, request_rx, - ); - - handle.spawn(fut); + )); LiveStream(stream_tx) } @@ -162,7 +157,8 @@ async fn run( impl Subscriber { pub async fn subscribe( &self, - ) -> Result>> { + ) -> anyhow::Result>> + { let (tx, rx) = oneshot::channel(); self.0.send(tx).await?; let subscription = time::timeout(Duration::from_secs(5), rx).await??; @@ -179,12 +175,13 @@ fn build_vt(tty_size: TtySize) -> Vt { .build() } +#[async_trait] impl session::Output for LiveStream { - fn event(&mut self, event: session::Event) -> io::Result<()> { + async fn event(&mut self, event: session::Event) -> io::Result<()> { self.0.send(event).map_err(io::Error::other) } - fn flush(&mut self) -> io::Result<()> { + async fn flush(&mut self) -> io::Result<()> { Ok(()) } } diff --git a/src/tty.rs b/src/tty.rs index 806ee9c..a14b06b 100644 --- a/src/tty.rs +++ b/src/tty.rs @@ -1,33 +1,70 @@ -use std::fs; -use std::io; -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd}; +use std::fs::File; +use std::future::pending; +use std::io::{Read, Write}; +use std::os::fd::{AsFd, AsRawFd}; +use std::os::unix::fs::OpenOptionsExt; -use anyhow::Result; -use nix::errno::Errno; -use nix::sys::select::{select, FdSet}; -use nix::sys::time::TimeVal; -use nix::{libc, pty, unistd}; +use async_trait::async_trait; +use nix::libc; +use nix::pty::Winsize; +use nix::sys::termios::{self, SetArg, Termios}; use rgb::RGB8; -use termion::raw::{IntoRawMode, RawTerminal}; +use tokio::io::unix::AsyncFd; +use tokio::io::{self, Interest}; +use tokio::time::{self, Duration}; + +const QUERY_READ_TIMEOUT: u64 = 500; +const COLORS_QUERY: &str = "\x1b]10;?\x07\x1b]11;?\x07\x1b]4;0;?\x07\x1b]4;1;?\x07\x1b]4;2;?\x07\x1b]4;3;?\x07\x1b]4;4;?\x07\x1b]4;5;?\x07\x1b]4;6;?\x07\x1b]4;7;?\x07\x1b]4;8;?\x07\x1b]4;9;?\x07\x1b]4;10;?\x07\x1b]4;11;?\x07\x1b]4;12;?\x07\x1b]4;13;?\x07\x1b]4;14;?\x07\x1b]4;15;?\x07"; +const XTVERSION_QUERY: &str = "\x1b[>0q"; + +pub struct DevTty { + file: AsyncFd, + settings: libc::termios, +} + +pub struct NullTty; + +pub struct FixedSizeTty { + inner: T, + cols: Option, + rows: Option, +} #[derive(Clone, Copy, Debug, PartialEq)] pub struct TtySize(pub u16, pub u16); +#[derive(Clone)] +pub struct TtyTheme { + pub fg: RGB8, + pub bg: RGB8, + pub palette: Vec, +} + +#[async_trait] +pub trait Tty { + fn get_size(&self) -> Winsize; + async fn get_theme(&self) -> Option; + async fn get_version(&self) -> Option; + async fn read<'e>(&self, buffer: &'e mut [u8]) -> io::Result; + async fn write<'e>(&self, buffer: &'e [u8]) -> io::Result; + async fn write_all<'e>(&self, buffer: &'e [u8]) -> io::Result<()>; +} + impl Default for TtySize { fn default() -> Self { TtySize(80, 24) } } -impl From for TtySize { - fn from(winsize: pty::Winsize) -> Self { +impl From for TtySize { + fn from(winsize: Winsize) -> Self { TtySize(winsize.ws_col, winsize.ws_row) } } -impl From for pty::Winsize { +impl From for Winsize { fn from(tty_size: TtySize) -> Self { - pty::Winsize { + Winsize { ws_col: tty_size.0, ws_row: tty_size.1, ws_xpixel: 0, @@ -48,82 +85,46 @@ impl From for (u16, u16) { } } -pub trait Tty: io::Write + io::Read + AsFd { - fn get_size(&self) -> pty::Winsize; - fn get_theme(&self) -> Option; - fn get_version(&self) -> Option; -} - -#[derive(Clone)] -pub struct TtyTheme { - pub fg: RGB8, - pub bg: RGB8, - pub palette: Vec, -} - -pub struct DevTty { - file: RawTerminal, -} - -const QUERY_READ_TIMEOUT: i64 = 500_000; - impl DevTty { - pub fn open() -> Result { - let file = fs::OpenOptions::new() + pub async fn open() -> anyhow::Result { + let file = File::options() .read(true) .write(true) - .open("/dev/tty")? - .into_raw_mode()?; + .custom_flags(libc::O_NONBLOCK) + .open("/dev/tty")?; - crate::io::set_non_blocking(&file)?; + let file = AsyncFd::new(file)?; + let settings = make_raw(&file)?; - Ok(Self { file }) + Ok(Self { file, settings }) } - fn query(&self, query: &str) -> Result> { + async fn query(&self, query: &str) -> anyhow::Result> { let mut query = query.to_string().into_bytes(); query.extend_from_slice(b"\x1b[c"); let mut query = &query[..]; let mut response = Vec::new(); let mut buf = [0u8; 1024]; - let fd = self.as_fd(); loop { - let mut timeout = TimeVal::new(0, QUERY_READ_TIMEOUT); - let mut rfds = FdSet::new(); - let mut wfds = FdSet::new(); - rfds.insert(fd); + tokio::select! { + result = self.read(&mut buf) => { + let n = result?; + response.extend_from_slice(&buf[..n]); - if !query.is_empty() { - wfds.insert(fd); - } - - match select(None, &mut rfds, &mut wfds, None, &mut timeout) { - Ok(0) => break, - - Ok(_) => { - if rfds.contains(fd) { - let n = unistd::read(fd, &mut buf)?; - response.extend_from_slice(&buf[..n]); - - if let Some(len) = self.complete_response_len(&response) { - response.truncate(len); - break; - } - } - - if wfds.contains(fd) { - let n = unistd::write(fd, query)?; - query = &query[n..]; + if let Some(len) = self.complete_response_len(&response) { + response.truncate(len); + break; } } - Err(e) => { - if e == Errno::EINTR { - continue; - } else { - return Err(e.into()); - } + result = self.write(query), if !query.is_empty() => { + let n = result?; + query = &query[n..]; + } + + _ = time::sleep(Duration::from_millis(QUERY_READ_TIMEOUT)) => { + break; } } } @@ -159,6 +160,29 @@ impl DevTty { None } } + + pub async fn resize(&mut self, size: TtySize) -> io::Result<()> { + let xtwinops_seq = format!("\x1b[8;{};{}t", size.1, size.0); + self.write_all(xtwinops_seq.as_bytes()).await?; + + Ok(()) + } +} + +fn make_raw(fd: F) -> anyhow::Result { + let termios = termios::tcgetattr(fd.as_fd())?; + let mut raw_termios = termios.clone(); + termios::cfmakeraw(&mut raw_termios); + termios::tcsetattr(fd.as_fd(), SetArg::TCSANOW, &raw_termios)?; + + Ok(termios.into()) +} + +impl Drop for DevTty { + fn drop(&mut self) { + let termios = Termios::from(self.settings); + let _ = termios::tcsetattr(self.file.as_fd(), SetArg::TCSANOW, &termios); + } } fn parse_color(rgb: &str) -> Option { @@ -178,13 +202,10 @@ fn parse_color(rgb: &str) -> Option { Some(RGB8::new(r, g, b)) } -static COLORS_QUERY: &str = "\x1b]10;?\x07\x1b]11;?\x07\x1b]4;0;?\x07\x1b]4;1;?\x07\x1b]4;2;?\x07\x1b]4;3;?\x07\x1b]4;4;?\x07\x1b]4;5;?\x07\x1b]4;6;?\x07\x1b]4;7;?\x07\x1b]4;8;?\x07\x1b]4;9;?\x07\x1b]4;10;?\x07\x1b]4;11;?\x07\x1b]4;12;?\x07\x1b]4;13;?\x07\x1b]4;14;?\x07\x1b]4;15;?\x07"; - -static XTVERSION_QUERY: &str = "\x1b[>0q"; - +#[async_trait] impl Tty for DevTty { - fn get_size(&self) -> pty::Winsize { - let mut winsize = pty::Winsize { + fn get_size(&self) -> Winsize { + let mut winsize = Winsize { ws_row: 24, ws_col: 80, ws_xpixel: 0, @@ -196,8 +217,8 @@ impl Tty for DevTty { winsize } - fn get_theme(&self) -> Option { - let response = self.query(COLORS_QUERY).ok()?; + async fn get_theme(&self) -> Option { + let response = self.query(COLORS_QUERY).await.ok()?; let response = String::from_utf8_lossy(response.as_slice()); let mut colors = response.match_indices("rgb:"); let (idx, _) = colors.next()?; @@ -215,8 +236,8 @@ impl Tty for DevTty { Some(TtyTheme { fg, bg, palette }) } - fn get_version(&self) -> Option { - let response = self.query(XTVERSION_QUERY).ok()?; + async fn get_version(&self) -> Option { + let response = self.query(XTVERSION_QUERY).await.ok()?; if let [b'\x1b', b'P', b'>', b'|', version @ .., b'\x1b', b'\\'] = &response[..] { Some(String::from_utf8_lossy(version).to_string()) @@ -224,46 +245,35 @@ impl Tty for DevTty { None } } -} -impl io::Read for DevTty { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.file.read(buf) - } -} - -impl io::Write for DevTty { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.file.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.file.flush() - } -} - -impl AsFd for DevTty { - fn as_fd(&self) -> BorrowedFd<'_> { - self.file.as_fd() - } -} - -pub struct NullTty { - tx: OwnedFd, - _rx: OwnedFd, -} - -impl NullTty { - pub fn open() -> Result { - let (rx, tx) = unistd::pipe()?; - - Ok(Self { tx, _rx: rx }) + async fn read<'e>(&self, buffer: &'e mut [u8]) -> io::Result { + self.file + .async_io(Interest::READABLE, |mut file| file.read(buffer)) + .await + } + + async fn write<'e>(&self, buffer: &'e [u8]) -> io::Result { + self.file + .async_io(Interest::WRITABLE, |mut file| file.write(buffer)) + .await + } + + async fn write_all<'e>(&self, buffer: &'e [u8]) -> io::Result<()> { + let mut buffer = buffer; + + while !buffer.is_empty() { + let n = self.write(buffer).await?; + buffer = &buffer[n..]; + } + + Ok(()) } } +#[async_trait] impl Tty for NullTty { - fn get_size(&self) -> pty::Winsize { - pty::Winsize { + fn get_size(&self) -> Winsize { + Winsize { ws_row: 24, ws_col: 80, ws_xpixel: 0, @@ -271,55 +281,37 @@ impl Tty for NullTty { } } - fn get_theme(&self) -> Option { + async fn get_theme(&self) -> Option { None } - fn get_version(&self) -> Option { + async fn get_version(&self) -> Option { None } -} -impl io::Read for NullTty { - fn read(&mut self, _buf: &mut [u8]) -> io::Result { - panic!("read attempt from NullTty"); - } -} - -impl io::Write for NullTty { - fn write(&mut self, buf: &[u8]) -> io::Result { - Ok(buf.len()) + async fn read<'e>(&self, _buffer: &'e mut [u8]) -> io::Result { + pending::<()>().await; + unreachable!() } - fn flush(&mut self) -> io::Result<()> { + async fn write<'e>(&self, buffer: &'e [u8]) -> io::Result { + Ok(buffer.len()) + } + + async fn write_all<'e>(&self, _buffer: &'e [u8]) -> io::Result<()> { Ok(()) } } -impl AsFd for NullTty { - fn as_fd(&self) -> BorrowedFd<'_> { - self.tx.as_fd() +impl FixedSizeTty { + pub fn new(inner: T, cols: Option, rows: Option) -> Self { + Self { inner, cols, rows } } } -pub struct FixedSizeTty { - inner: Box, - cols: Option, - rows: Option, -} - -impl FixedSizeTty { - pub fn new(inner: T, cols: Option, rows: Option) -> Self { - Self { - inner: Box::new(inner), - cols, - rows, - } - } -} - -impl Tty for FixedSizeTty { - fn get_size(&self) -> pty::Winsize { +#[async_trait] +impl Tty for FixedSizeTty { + fn get_size(&self) -> Winsize { let mut winsize = self.inner.get_size(); if let Some(cols) = self.cols { @@ -333,34 +325,24 @@ impl Tty for FixedSizeTty { winsize } - fn get_theme(&self) -> Option { - self.inner.get_theme() + async fn get_theme(&self) -> Option { + self.inner.get_theme().await } - fn get_version(&self) -> Option { - self.inner.get_version() - } -} - -impl AsFd for FixedSizeTty { - fn as_fd(&self) -> BorrowedFd<'_> { - return self.inner.as_fd(); - } -} - -impl io::Read for FixedSizeTty { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.read(buf) - } -} - -impl io::Write for FixedSizeTty { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.write(buf) + async fn get_version(&self) -> Option { + self.inner.get_version().await } - fn flush(&mut self) -> io::Result<()> { - self.inner.flush() + async fn read<'e>(&self, buffer: &'e mut [u8]) -> io::Result { + self.inner.read(buffer).await + } + + async fn write<'e>(&self, buffer: &'e [u8]) -> io::Result { + self.inner.write(buffer).await + } + + async fn write_all<'e>(&self, buffer: &'e [u8]) -> io::Result<()> { + self.inner.write_all(buffer).await } } @@ -396,12 +378,20 @@ mod tests { } #[test] - fn fixed_size_tty() { - let tty = FixedSizeTty::new(NullTty::open().unwrap(), Some(100), Some(50)); - + fn fixed_size_tty_get_size() { + let tty = FixedSizeTty::new(NullTty, Some(100), Some(50)); let winsize = tty.get_size(); - assert!(winsize.ws_col == 100); assert!(winsize.ws_row == 50); + + let tty = FixedSizeTty::new(NullTty, Some(100), None); + let winsize = tty.get_size(); + assert!(winsize.ws_col == 100); + assert!(winsize.ws_row == 24); + + let tty = FixedSizeTty::new(NullTty, None, None); + let winsize = tty.get_size(); + assert!(winsize.ws_col == 80); + assert!(winsize.ws_row == 24); } } diff --git a/src/util.rs b/src/util.rs index f54d6c7..d5a8b3a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,13 +1,13 @@ +use std::io; use std::path::{Path, PathBuf}; -use std::{io, thread}; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail}; use reqwest::Url; use tempfile::NamedTempFile; use crate::html; -pub fn get_local_path(filename: &str) -> Result>> { +pub fn get_local_path(filename: &str) -> anyhow::Result>> { if filename.starts_with("https://") || filename.starts_with("http://") { match download_asciicast(filename) { Ok(path) => Ok(Box::new(path)), @@ -18,7 +18,7 @@ pub fn get_local_path(filename: &str) -> Result>> { } } -fn download_asciicast(url: &str) -> Result { +fn download_asciicast(url: &str) -> anyhow::Result { use reqwest::blocking::get; let mut response = get(Url::parse(url)?)?; @@ -50,24 +50,6 @@ fn download_asciicast(url: &str) -> Result { } } -pub struct JoinHandle(Option>); - -impl JoinHandle { - pub fn new(handle: thread::JoinHandle<()>) -> Self { - Self(Some(handle)) - } -} - -impl Drop for JoinHandle { - fn drop(&mut self) { - self.0 - .take() - .unwrap() - .join() - .expect("worker thread should finish cleanly"); - } -} - pub struct Utf8Decoder(Vec); impl Utf8Decoder {