feat: add WithContext functions for all subprocess-related code

Also deprecate the existing subprocess calls.
This commit is contained in:
Jose Diaz-Gonzalez
2024-02-11 23:04:20 -05:00
parent 03df28af87
commit bcf25ba5cd
5 changed files with 177 additions and 22 deletions

View File

@@ -3,8 +3,10 @@ package common
import (
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"context"
@@ -13,21 +15,76 @@ import (
"github.com/fatih/color"
)
// ExecCommandInput is the input for the ExecCommand function
type ExecCommandInput struct {
Command string
Args []string
// Command is the command to execute
Command string
// Args are the arguments to pass to the command
Args []string
// CaptureOutput determines whether to capture the output of the command
CaptureOutput bool
Env map[string]string
StreamStdio bool
// Env is the environment variables to pass to the command
Env map[string]string
// Stdin is the stdin of the command
Stdin io.Reader
// StreamStdio prints stdout and stderr directly to os.Stdout/err as
// the command runs
StreamStdio bool
// StreamStdout prints stdout directly to os.Stdout as the command runs.
StreamStdout bool
// StreamStderr prints stderr directly to os.Stderr as the command runs.
StreamStderr bool
// Sudo runs the command with sudo -n -u root
Sudo bool
}
func CallExecCommand(input ExecCommandInput) (execute.ExecResult, error) {
// ExecCommandResponse is the response for the ExecCommand function
type ExecCommandResponse struct {
// Stdout is the stdout of the command
Stdout string
// Stderr is the stderr of the command
Stderr string
// ExitCode is the exit code of the command
ExitCode int
// Cancelled is whether the command was cancelled
Cancelled bool
}
// StdoutContents returns the trimmed stdout of the command
func (ecr ExecCommandResponse) StdoutContents() string {
return strings.TrimSpace(ecr.Stdout)
}
// StderrContents returns the trimmed stderr of the command
func (ecr ExecCommandResponse) StderrContents() string {
return strings.TrimSpace(ecr.Stderr)
}
// CallExecCommand executes a command on the local host
func CallExecCommand(input ExecCommandInput) (ExecCommandResponse, error) {
ctx := context.Background()
return CallExecCommandWithContext(ctx, input)
}
// CallExecCommandWithContext executes a command on the local host with the given context
func CallExecCommandWithContext(ctx context.Context, input ExecCommandInput) (ExecCommandResponse, error) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt, syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGQUIT,
syscall.SIGTERM)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(ctx)
go func() {
<-signals
cancel()
@@ -47,9 +104,16 @@ func CallExecCommand(input ExecCommandInput) (execute.ExecResult, error) {
}
}
command := input.Command
commandArgs := input.Args
if input.Sudo {
commandArgs = append([]string{"-n", "-u", "root", command}, commandArgs...)
command = "sudo"
}
cmd := execute.ExecTask{
Command: input.Command,
Args: input.Args,
Command: command,
Args: commandArgs,
Env: env,
DisableStdioBuffer: !input.CaptureOutput,
}
@@ -58,22 +122,45 @@ func CallExecCommand(input ExecCommandInput) (execute.ExecResult, error) {
cmd.PrintCommand = true
}
if isatty {
if input.Stdin != nil {
cmd.Stdin = input.Stdin
} else if isatty {
cmd.Stdin = os.Stdin
}
if input.StreamStdio {
cmd.StreamStdio = true
}
if input.StreamStdout {
cmd.StdOutWriter = os.Stdout
}
if input.StreamStderr {
cmd.StdErrWriter = os.Stderr
}
res, err := cmd.Execute(ctx)
if err != nil {
return res, err
return ExecCommandResponse{
Stdout: res.Stdout,
Stderr: res.Stderr,
ExitCode: res.ExitCode,
Cancelled: res.Cancelled,
}, err
}
if res.ExitCode != 0 {
return res, errors.New(res.Stderr)
return ExecCommandResponse{
Stdout: res.Stdout,
Stderr: res.Stderr,
ExitCode: res.ExitCode,
Cancelled: res.Cancelled,
}, errors.New(res.Stderr)
}
return res, nil
return ExecCommandResponse{
Stdout: res.Stdout,
Stderr: res.Stderr,
ExitCode: res.ExitCode,
Cancelled: res.Cancelled,
}, nil
}

View File

@@ -1,25 +1,54 @@
package common
import (
execute "github.com/alexellis/go-execute/v2"
"context"
"io"
)
// PlugnTriggerInput is the input for CallPlugnTrigger
type PlugnTriggerInput struct {
Args []string
// Args are the arguments to pass to the trigger
Args []string
// CaptureOutput determines whether to capture the output of the trigger
CaptureOutput bool
Env map[string]string
StreamStdio bool
Trigger string
// Env is the environment variables to pass to the trigger
Env map[string]string
// Stdin is the stdin of the command
Stdin io.Reader
// StreamStdio determines whether to stream the stdio of the trigger
StreamStdio bool
// StreamStdout prints stdout directly to os.Stdout as the command runs.
StreamStdout bool
// StreamStderr prints stderr directly to os.Stderr as the command runs.
StreamStderr bool
// Trigger is the trigger to execute
Trigger string
}
func CallPlugnTrigger(input PlugnTriggerInput) (execute.ExecResult, error) {
// CallPlugnTrigger executes a trigger via plugn
func CallPlugnTrigger(input PlugnTriggerInput) (ExecCommandResponse, error) {
return CallPlugnTriggerWithContext(context.Background(), input)
}
// CallPlugnTriggerWithContext executes a trigger via plugn with the given context
func CallPlugnTriggerWithContext(ctx context.Context, input PlugnTriggerInput) (ExecCommandResponse, error) {
args := []string{"trigger", input.Trigger}
args = append(args, input.Args...)
return CallExecCommand(ExecCommandInput{
return CallExecCommandWithContext(ctx, ExecCommandInput{
Command: "plugn",
Args: args,
CaptureOutput: input.CaptureOutput,
Env: input.Env,
Stdin: input.Stdin,
StreamStdio: input.StreamStdio,
StreamStdout: input.StreamStdout,
StreamStderr: input.StreamStderr,
})
}

View File

@@ -33,12 +33,17 @@ type SftpCopyInput struct {
// CallSftpCopy copies a file to a remote host via sftp
func CallSftpCopy(input SftpCopyInput) (SftpCopyResult, error) {
return CallSftpCopyWithContext(context.Background(), input)
}
// CallSftpCopyWithContext copies a file to a remote host via sftp with the given context
func CallSftpCopyWithContext(ctx context.Context, input SftpCopyInput) (SftpCopyResult, error) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt, syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGQUIT,
syscall.SIGTERM)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(ctx)
go func() {
<-signals
cancel()

View File

@@ -43,22 +43,36 @@ type SshCommandInput struct {
// RemoteHost is the remote host to connect to
RemoteHost string
// Stdin is the stdin of the command
Stdin io.Reader
// StreamStdio prints stdout and stderr directly to os.Stdout/err as
// the command runs.
StreamStdio bool
// StreamStdout prints stdout directly to os.Stdout as the command runs.
StreamStdout bool
// StreamStderr prints stderr directly to os.Stderr as the command runs.
StreamStderr bool
// Sudo runs the command with sudo -n -u root
Sudo bool
}
// CallSshCommand executes a command on a remote host via ssh
func CallSshCommand(input SshCommandInput) (SshResult, error) {
return CallSshCommandWithContext(context.Background(), input)
}
// CallSshCommand executes a command on a remote host via ssh with the given context
func CallSshCommandWithContext(ctx context.Context, input SshCommandInput) (SshResult, error) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt, syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGQUIT,
syscall.SIGTERM)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(ctx)
go func() {
<-signals
cancel()
@@ -137,13 +151,21 @@ func CallSshCommand(input SshCommandInput) (SshResult, error) {
cmd.PrintCommand = true
}
if isatty {
if input.Stdin != nil {
cmd.Stdin = input.Stdin
} else if isatty {
cmd.Stdin = os.Stdin
}
if input.StreamStdio {
cmd.StreamStdio = true
}
if input.StreamStdout {
cmd.StdOutWriter = os.Stdout
}
if input.StreamStderr {
cmd.StdErrWriter = os.Stderr
}
res, err := cmd.Execute(ctx)
if err != nil {

View File

@@ -22,6 +22,8 @@ type ShellCmd struct {
}
// NewShellCmd returns a new ShellCmd struct
//
// Deprecated: use CallExecCommand instead
func NewShellCmd(command string) *ShellCmd {
items := strings.Split(command, " ")
cmd := items[0]
@@ -30,6 +32,8 @@ func NewShellCmd(command string) *ShellCmd {
}
// NewShellCmdWithArgs returns a new ShellCmd struct
//
// Deprecated: use CallExecCommand instead
func NewShellCmdWithArgs(cmd string, args ...string) *ShellCmd {
commandString := strings.Join(append([]string{cmd}, args...), " ")
@@ -87,12 +91,16 @@ func (sc *ShellCmd) CombinedOutput() ([]byte, error) {
}
// PlugnTrigger fire the given plugn trigger with the given args
//
// Deprecated: use CallPlugnTrigger instead
func PlugnTrigger(triggerName string, args ...string) error {
LogDebug(fmt.Sprintf("plugn trigger %s %v", triggerName, args))
return PlugnTriggerSetup(triggerName, args...).Run()
}
// PlugnTriggerOutput fire the given plugn trigger with the given args
//
// Deprecated: use CallPlugnTrigger with CaptureOutput=true instead
func PlugnTriggerOutput(triggerName string, args ...string) ([]byte, error) {
LogDebug(fmt.Sprintf("plugn trigger %s %v", triggerName, args))
rE, wE, _ := os.Pipe()
@@ -125,12 +133,16 @@ func PlugnTriggerOutput(triggerName string, args ...string) ([]byte, error) {
}
// PlugnTriggerOutputAsString fires the given plugn trigger with the given args and returns the string contents instead of bytes
//
// Deprecated: use CallPlugnTrigger with CaptureOutput=true instead
func PlugnTriggerOutputAsString(triggerName string, args ...string) (string, error) {
b, err := PlugnTriggerOutput(triggerName, args...)
return strings.TrimSpace(string(b[:])), err
}
// PlugnTriggerSetup sets up a plugn trigger call
//
// Deprecated: use CallPlugnTrigger instead
func PlugnTriggerSetup(triggerName string, args ...string) *sh.Session {
shellArgs := make([]interface{}, len(args)+2)
shellArgs[0] = "trigger"