Files
dokku/plugins/common/ssh.go
2025-11-09 20:21:10 -05:00

406 lines
9.9 KiB
Go

package common
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/url"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"github.com/fatih/color"
"github.com/melbahja/goph"
"golang.org/x/crypto/ssh"
)
// SshCommandInput describes a single remote command invocation for
// CallSshCommand and CallSshCommandWithContext.
type SshCommandInput struct {
	// Command is the executable to run on the remote host, either as a
	// path or as a bare executable name.
	//
	// Arguments are supplied separately through Args.
	Command string

	// Args holds the arguments handed to Command.
	Args []string

	// DisableStdioBuffer, when true, keeps stdout and stderr from being
	// captured into the returned result.
	DisableStdioBuffer bool

	// Env maps variable names to values; each pair is turned into a
	// KEY=VALUE entry for the remote command.
	Env map[string]string

	// AllowUknownHosts permits connections to hosts whose host key is
	// not already known.
	AllowUknownHosts bool

	// RemoteHost is the target, written as an ssh:// URL. A user, a
	// password and a port may be embedded in the URL; the user falls back
	// to the USER environment variable and the port to 22.
	RemoteHost string

	// Stdin, when non-nil, is connected to the stdin of the command.
	Stdin io.Reader

	// StreamStdio mirrors both stdout and stderr to os.Stdout and
	// os.Stderr while the command runs.
	StreamStdio bool

	// StreamStdout mirrors stdout to os.Stdout while the command runs.
	StreamStdout bool

	// StreamStderr mirrors stderr to os.Stderr while the command runs.
	StreamStderr bool

	// Sudo wraps the command in `sudo -n -u root`.
	Sudo bool
}
// CallSshCommand runs the described command on a remote host over ssh.
// It delegates to CallSshCommandWithContext using a background context.
func CallSshCommand(input SshCommandInput) (SshResult, error) {
	background := context.Background()
	return CallSshCommandWithContext(background, input)
}
// CallSshCommandWithContext executes a command on a remote host via ssh with
// the given context.
//
// input.RemoteHost must be a URL using the ssh scheme, optionally carrying a
// user, a password and a port. The username falls back to the USER
// environment variable and the port to 22. The key path handed to the task is
// $DOKKU_ROOT/.ssh/id_ed25519 when that file exists, otherwise
// $DOKKU_ROOT/.ssh/id_rsa.
//
// The context given to the task is cancelled when the process receives
// SIGHUP, SIGINT, SIGQUIT or SIGTERM while the call is in flight. The signal
// registration, the derived context and the watcher goroutine are all
// released before the function returns.
//
// An error is returned when the remote host cannot be parsed or has the wrong
// scheme, when the port is invalid, when no ssh key is found, when the task
// fails to execute, or when the result carries a non-zero exit code (the
// error message is then the captured stderr).
func CallSshCommandWithContext(ctx context.Context, input SshCommandInput) (SshResult, error) {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt, syscall.SIGHUP,
		syscall.SIGINT,
		syscall.SIGQUIT,
		syscall.SIGTERM)
	// without Stop the registration would outlive this call
	defer signal.Stop(signals)

	ctx, cancel := context.WithCancel(ctx)
	// always release the derived context, even when no signal arrives
	defer cancel()

	go func() {
		// exit on either event so the goroutine never outlives the call
		select {
		case <-signals:
			cancel()
		case <-ctx.Done():
		}
	}()

	// hack: colors do not work natively with io.MultiWriter
	// as it isn't detected as a tty. If the output isn't
	// being captured, then color output can be forced.
	isatty := !color.NoColor

	// ranging over a nil map is a no-op, so no nil check is needed
	env := make([]string, 0, len(input.Env))
	for k, v := range input.Env {
		env = append(env, fmt.Sprintf("%s=%s", k, v))
	}

	u, err := url.Parse(input.RemoteHost)
	if err != nil {
		return SshResult{}, fmt.Errorf("failed to parse remote host: %w", err)
	}
	if u.Scheme == "" {
		return SshResult{}, fmt.Errorf("missing remote host ssh scheme in remote host: %s", input.RemoteHost)
	}
	if u.Scheme != "ssh" {
		return SshResult{}, fmt.Errorf("invalid remote host scheme: %s", u.Scheme)
	}

	username, password := "", ""
	if u.User != nil {
		username = u.User.Username()
		// Password returns "" when no password is set
		password, _ = u.User.Password()
	}
	if username == "" {
		username = os.Getenv("USER")
	}

	// an absent or zero port means the default ssh port
	port := 22
	if portStr := u.Port(); portStr != "" {
		portVal, err := strconv.Atoi(portStr)
		if err != nil {
			return SshResult{}, fmt.Errorf("failed to parse port: %w", err)
		}
		if portVal != 0 {
			port = portVal
		}
	}

	sshKeyPath := filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh/id_ed25519")
	if !FileExists(sshKeyPath) {
		sshKeyPath = filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh/id_rsa")
		if !FileExists(sshKeyPath) {
			return SshResult{}, errors.New("ssh key not found at ~/.ssh/id_ed25519 or ~/.ssh/id_rsa")
		}
	}

	cmd := SshTask{
		Command:            input.Command,
		Args:               input.Args,
		Env:                env,
		DisableStdioBuffer: input.DisableStdioBuffer,
		AllowUknownHosts:   input.AllowUknownHosts,
		Hostname:           u.Hostname(),
		Port:               uint(port),
		Username:           username,
		Password:           password,
		SshKeyPath:         sshKeyPath,
		Sudo:               input.Sudo,
		StreamStdio:        input.StreamStdio,
		PrintCommand:       os.Getenv("DOKKU_TRACE") == "1",
	}

	// an explicit reader wins; otherwise fall back to the process stdin
	// only when color output is enabled
	if input.Stdin != nil {
		cmd.Stdin = input.Stdin
	} else if isatty {
		cmd.Stdin = os.Stdin
	}

	if input.StreamStdout {
		cmd.StdOutWriter = os.Stdout
	}
	if input.StreamStderr {
		cmd.StdErrWriter = os.Stderr
	}

	res, err := cmd.Execute(ctx)
	if err != nil {
		return res, err
	}
	if res.ExitCode != 0 {
		return res, errors.New(res.Stderr)
	}

	return res, nil
}
// SshTask holds everything needed to run one command on a remote host
// over ssh; see Execute.
type SshTask struct {
	// Command is the executable to run, either as a path or as a bare
	// executable name.
	//
	// Arguments are supplied separately through Args.
	Command string

	// Args holds the arguments handed to Command.
	Args []string

	// Shell wraps the command in `bash -c`, so bash has to be available
	// on the host that runs the command.
	Shell bool

	// Env is a list of KEY=VALUE entries appended to the environment of
	// the command.
	Env []string

	// Stdin, when non-nil, is connected to the stdin of the command.
	Stdin io.Reader

	// PrintCommand logs the command before it is executed.
	PrintCommand bool

	// StreamStdio mirrors stdout and stderr to os.Stdout and os.Stderr
	// while the command runs.
	StreamStdio bool

	// DisableStdioBuffer keeps output from being stored in the returned
	// SshResult. Useful when the output is very large or should only go
	// to another writer.
	DisableStdioBuffer bool

	// StdOutWriter, when set, receives a copy of the command's stdout.
	StdOutWriter io.Writer

	// StdErrWriter, when set, receives a copy of the command's stderr.
	StdErrWriter io.Writer

	// AllowUknownHosts permits connections to hosts whose host key is
	// not already known.
	AllowUknownHosts bool

	// Hostname is the host to connect to.
	Hostname string

	// Port is the port to connect to; Execute treats 0 as 22.
	Port uint

	// Username is the user to authenticate as.
	Username string

	// Password, when non-empty, is used for authentication instead of
	// the ssh key.
	Password string

	// SshKeyPath is the filesystem path of the private key to load.
	SshKeyPath string

	// Sudo wraps the command in `sudo -n -u root`.
	Sudo bool
}
// SshResult reports the outcome of a command run on a remote host over ssh.
type SshResult struct {
	// Stdout and Stderr carry the buffered output of the command; both
	// stay empty when buffering is disabled on the task.
	Stdout, Stderr string

	// ExitCode is the exit status recorded for the command. It is -1
	// when the context was already done before the command was started.
	ExitCode int

	// Cancelled is true when the context ended with context.Canceled.
	Cancelled bool
}
// Execute runs the task on the remote host and waits for it to finish.
//
// Command, Hostname, SshKeyPath and Username are required; a zero Port is
// treated as 22. TouchFile is called on $DOKKU_ROOT/.ssh/known_hosts before
// the host key callback is loaded. When AllowUknownHosts is set any host key
// is accepted. A non-empty Password replaces key authentication.
//
// With Shell set the command and its arguments are joined into a single
// script that is passed to `bash -c`. With Sudo set the resulting command is
// prefixed with `sudo -n -u root`.
//
// The returned SshResult carries the buffered stdout/stderr (unless
// DisableStdioBuffer is set), the exit status reported by the remote command
// and whether the context ended with context.Canceled. The returned error is
// non-nil when validation, connection setup or command start fails, or when
// the context is done; a non-zero exit status alone does not produce an
// error.
func (task SshTask) Execute(ctx context.Context) (SshResult, error) {
	if task.Command == "" {
		return SshResult{}, errors.New("command is required")
	}
	if task.Hostname == "" {
		return SshResult{}, errors.New("hostname is required")
	}
	if task.SshKeyPath == "" {
		return SshResult{}, errors.New("ssh key path is required")
	}
	if task.Username == "" {
		return SshResult{}, errors.New("username is required")
	}
	if task.Port == 0 {
		task.Port = 22
	}

	if err := TouchFile(filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh", "known_hosts")); err != nil {
		return SshResult{}, fmt.Errorf("failed to touch known_hosts file: %w", err)
	}

	auth, err := goph.Key(task.SshKeyPath, "")
	if err != nil {
		return SshResult{}, fmt.Errorf("failed to load ssh key: %w", err)
	}

	callback, err := goph.DefaultKnownHosts()
	if err != nil {
		return SshResult{}, fmt.Errorf("failed to load known hosts: %w", err)
	}
	if task.AllowUknownHosts {
		callback = ssh.InsecureIgnoreHostKey()
	}

	connectionConf := goph.Config{
		User:     task.Username,
		Addr:     task.Hostname,
		Port:     task.Port,
		Timeout:  goph.DefaultTimeout,
		Callback: callback,
		Auth:     auth,
	}
	// an explicit password takes precedence over the key
	if task.Password != "" {
		connectionConf.Auth = goph.Password(task.Password)
	}

	client, err := goph.NewConn(&connectionConf)
	if err != nil {
		return SshResult{}, fmt.Errorf("failed to create ssh client: %w", err)
	}
	defer client.Close()

	// don't try to run if the context is already cancelled
	if ctx.Err() != nil {
		return SshResult{
			// the exec package returns -1 for cancelled commands
			ExitCode:  -1,
			Cancelled: ctx.Err() == context.Canceled,
		}, ctx.Err()
	}

	command := task.Command
	commandArgs := task.Args
	if task.Shell {
		// the whole invocation becomes one script handed to bash -c
		script := task.Command
		if len(task.Args) > 0 {
			script = fmt.Sprintf("%s %s", task.Command, strings.Join(task.Args, " "))
		}
		command = "bash"
		commandArgs = []string{"-c", script}
	}

	if task.Sudo {
		commandArgs = append([]string{"-n", "-u", "root", command}, commandArgs...)
		command = "sudo"
	}

	if task.PrintCommand || os.Getenv("DOKKU_TRACE") == "1" {
		LogDebug(fmt.Sprintf("ssh %s@%s %s %v", task.Username, task.Hostname, command, commandArgs))
	}

	cmd, err := client.CommandContext(ctx, command, commandArgs...)
	if err != nil {
		return SshResult{}, fmt.Errorf("failed to create ssh command: %w", err)
	}

	cmd.Env = append(cmd.Env, task.Env...)

	if task.Stdin != nil {
		cmd.Stdin = task.Stdin
	}

	var stdoutBuff, stderrBuff bytes.Buffer
	var stdoutWriters, stderrWriters []io.Writer

	if !task.DisableStdioBuffer {
		stdoutWriters = append(stdoutWriters, &stdoutBuff)
		stderrWriters = append(stderrWriters, &stderrBuff)
	}
	if task.StreamStdio {
		stdoutWriters = append(stdoutWriters, os.Stdout)
		stderrWriters = append(stderrWriters, os.Stderr)
	}
	if task.StdOutWriter != nil {
		stdoutWriters = append(stdoutWriters, task.StdOutWriter)
	}
	if task.StdErrWriter != nil {
		stderrWriters = append(stderrWriters, task.StdErrWriter)
	}

	cmd.Stdout = io.MultiWriter(stdoutWriters...)
	cmd.Stderr = io.MultiWriter(stderrWriters...)

	if startErr := cmd.Start(); startErr != nil {
		return SshResult{}, startErr
	}

	exitCode := 0
	if execErr := cmd.Wait(); execErr != nil {
		// The ssh session reports a failed remote command as
		// *ssh.ExitError; *exec.ExitError is still recognised in case a
		// local exec error is ever returned here. Any other Wait error
		// leaves the exit code at 0 and is not reported.
		var sshExitErr *ssh.ExitError
		var execExitErr *exec.ExitError
		switch {
		case errors.As(execErr, &sshExitErr):
			exitCode = sshExitErr.ExitStatus()
		case errors.As(execErr, &execExitErr):
			exitCode = execExitErr.ExitCode()
		}
	}

	return SshResult{
		Stdout:    stdoutBuff.String(),
		Stderr:    stderrBuff.String(),
		ExitCode:  exitCode,
		Cancelled: ctx.Err() == context.Canceled,
	}, ctx.Err()
}