mirror of
https://github.com/dokku/dokku.git
synced 2025-12-16 20:17:44 +01:00
406 lines
9.9 KiB
Go
406 lines
9.9 KiB
Go
package common
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/url"
|
|
"os"
|
|
"os/exec"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/fatih/color"
|
|
"github.com/melbahja/goph"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// SshCommandInput describes a single remote command invocation for
// CallSshCommand and CallSshCommandWithContext.
type SshCommandInput struct {
	// Command is the executable to run on the remote host, given either as
	// a path or as a bare executable name.
	//
	// Arguments are not part of Command; supply them through Args.
	Command string

	// Args holds the arguments handed to Command.
	Args []string

	// DisableStdioBuffer turns off buffering of the command's stdio.
	DisableStdioBuffer bool

	// Env holds extra environment variables for the command, keyed by
	// variable name.
	Env map[string]string

	// AllowUknownHosts permits connecting to hosts whose host key is not
	// already known.
	AllowUknownHosts bool

	// RemoteHost identifies the host to connect to. It is parsed as a URL
	// and must use the ssh scheme.
	RemoteHost string

	// Stdin, when set, is used as the stdin of the command.
	Stdin io.Reader

	// StreamStdio copies both stdout and stderr to os.Stdout/os.Stderr
	// while the command is running.
	StreamStdio bool

	// StreamStdout copies stdout to os.Stdout while the command is running.
	StreamStdout bool

	// StreamStderr copies stderr to os.Stderr while the command is running.
	StreamStderr bool

	// Sudo wraps the command in `sudo -n -u root`.
	Sudo bool
}
|
|
|
|
// CallSshCommand executes a command on a remote host via ssh
|
|
func CallSshCommand(input SshCommandInput) (SshResult, error) {
|
|
return CallSshCommandWithContext(context.Background(), input)
|
|
}
|
|
|
|
// CallSshCommandWithContext executes a command on a remote host via ssh with the given context
|
|
func CallSshCommandWithContext(ctx context.Context, input SshCommandInput) (SshResult, error) {
|
|
signals := make(chan os.Signal, 1)
|
|
signal.Notify(signals, os.Interrupt, syscall.SIGHUP,
|
|
syscall.SIGINT,
|
|
syscall.SIGQUIT,
|
|
syscall.SIGTERM)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
go func() {
|
|
<-signals
|
|
cancel()
|
|
}()
|
|
|
|
// hack: colors do not work natively with io.MultiWriter
|
|
// as it isn't detected as a tty. If the output isn't
|
|
// being captured, then color output can be forced.
|
|
isatty := !color.NoColor
|
|
env := []string{}
|
|
if input.Env != nil {
|
|
for k, v := range input.Env {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
}
|
|
|
|
u, err := url.Parse(input.RemoteHost)
|
|
if err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to parse remote host: %w", err)
|
|
}
|
|
if u.Scheme == "" {
|
|
return SshResult{}, fmt.Errorf("missing remote host ssh scheme in remote host: %s", input.RemoteHost)
|
|
}
|
|
if u.Scheme != "ssh" {
|
|
return SshResult{}, fmt.Errorf("invalid remote host scheme: %s", u.Scheme)
|
|
}
|
|
|
|
username := ""
|
|
password := ""
|
|
if u.User != nil {
|
|
username = u.User.Username()
|
|
if pass, ok := u.User.Password(); ok {
|
|
password = pass
|
|
}
|
|
}
|
|
if username == "" {
|
|
username = os.Getenv("USER")
|
|
}
|
|
|
|
portStr := u.Port()
|
|
port := 0
|
|
if portStr != "" {
|
|
portVal, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to parse port: %w", err)
|
|
}
|
|
port = portVal
|
|
}
|
|
if port == 0 {
|
|
port = 22
|
|
}
|
|
|
|
sshKeyPath := filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh/id_ed25519")
|
|
if !FileExists(sshKeyPath) {
|
|
sshKeyPath = filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh/id_rsa")
|
|
}
|
|
if !FileExists(sshKeyPath) {
|
|
return SshResult{}, errors.New("ssh key not found at ~/.ssh/id_ed25519 or ~/.ssh/id_rsa")
|
|
}
|
|
|
|
cmd := SshTask{
|
|
Command: input.Command,
|
|
Args: input.Args,
|
|
Env: env,
|
|
DisableStdioBuffer: input.DisableStdioBuffer,
|
|
AllowUknownHosts: input.AllowUknownHosts,
|
|
Hostname: u.Hostname(),
|
|
Port: uint(port),
|
|
Username: username,
|
|
Password: password,
|
|
SshKeyPath: sshKeyPath,
|
|
Sudo: input.Sudo,
|
|
}
|
|
|
|
if os.Getenv("DOKKU_TRACE") == "1" {
|
|
cmd.PrintCommand = true
|
|
}
|
|
|
|
if input.Stdin != nil {
|
|
cmd.Stdin = input.Stdin
|
|
} else if isatty {
|
|
cmd.Stdin = os.Stdin
|
|
}
|
|
|
|
if input.StreamStdio {
|
|
cmd.StreamStdio = true
|
|
}
|
|
if input.StreamStdout {
|
|
cmd.StdOutWriter = os.Stdout
|
|
}
|
|
if input.StreamStderr {
|
|
cmd.StdErrWriter = os.Stderr
|
|
}
|
|
|
|
res, err := cmd.Execute(ctx)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
if res.ExitCode != 0 {
|
|
return res, errors.New(res.Stderr)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// SshTask describes one command to be run on a remote host over ssh,
// together with the connection details needed to reach that host.
type SshTask struct {
	// Command is the executable to run, given either as a path or as a
	// bare executable name.
	//
	// Arguments are not part of Command; supply them through Args.
	Command string

	// Args holds the arguments handed to Command.
	Args []string

	// Shell runs the command through a bash shell.
	// The system must have `bash` available in the PATH or at /bin/bash.
	Shell bool

	// Env lists environment variables to add to the current environment.
	Env []string

	// Stdin, when set, is connected to the stdin of the executed command.
	Stdin io.Reader

	// PrintCommand logs the command before it is executed.
	PrintCommand bool

	// StreamStdio copies stdout and stderr to os.Stdout/os.Stderr while
	// the command is running.
	StreamStdio bool

	// DisableStdioBuffer keeps output from being stored in the result.
	// Useful when the output is very large, or when it should only be
	// streamed to another writer.
	DisableStdioBuffer bool

	// StdOutWriter, when set, receives a copy of the command's stdout.
	StdOutWriter io.Writer

	// StdErrWriter, when set, receives a copy of the command's stderr.
	StdErrWriter io.Writer

	// AllowUknownHosts permits connecting to hosts whose host key is not
	// already known.
	AllowUknownHosts bool

	// Hostname is the host to connect to.
	Hostname string

	// Port is the port to connect to.
	Port uint

	// Username is the user to connect as.
	Username string

	// Password is the password to connect with.
	Password string

	// SshKeyPath is the path of the ssh key to use.
	SshKeyPath string

	// Sudo wraps the command in `sudo -n -u root`.
	Sudo bool
}
|
|
|
|
// SshResult holds the outcome of running a command on a remote host via ssh.
type SshResult struct {
	// Stdout is the buffered standard output of the command.
	Stdout string

	// Stderr is the buffered standard error of the command.
	Stderr string

	// ExitCode is the exit code recorded for the command.
	ExitCode int

	// Cancelled reports whether the context was cancelled.
	Cancelled bool
}
|
|
|
|
// Execute runs the task
|
|
func (task SshTask) Execute(ctx context.Context) (SshResult, error) {
|
|
if task.Command == "" {
|
|
return SshResult{}, errors.New("command is required")
|
|
}
|
|
if task.Hostname == "" {
|
|
return SshResult{}, errors.New("hostname is required")
|
|
}
|
|
if task.SshKeyPath == "" {
|
|
return SshResult{}, errors.New("ssh key path is required")
|
|
}
|
|
if task.Username == "" {
|
|
return SshResult{}, errors.New("username is required")
|
|
}
|
|
if task.Port == 0 {
|
|
task.Port = 22
|
|
}
|
|
|
|
if err := TouchFile(filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh", "known_hosts")); err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to touch known_hosts file: %w", err)
|
|
}
|
|
|
|
auth, err := goph.Key(task.SshKeyPath, "")
|
|
if err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to load ssh key: %w", err)
|
|
}
|
|
|
|
callback, err := goph.DefaultKnownHosts()
|
|
if err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to load known hosts: %w", err)
|
|
}
|
|
|
|
if task.AllowUknownHosts {
|
|
callback = ssh.InsecureIgnoreHostKey()
|
|
}
|
|
|
|
connectionConf := goph.Config{
|
|
User: task.Username,
|
|
Addr: task.Hostname,
|
|
Port: task.Port,
|
|
Timeout: goph.DefaultTimeout,
|
|
Callback: callback,
|
|
Auth: auth,
|
|
}
|
|
if task.Password != "" {
|
|
connectionConf.Auth = goph.Password(task.Password)
|
|
}
|
|
|
|
client, err := goph.NewConn(&connectionConf)
|
|
if err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to create ssh client: %w", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
// don't try to run if the context is already cancelled
|
|
if ctx.Err() != nil {
|
|
return SshResult{
|
|
// the exec package returns -1 for cancelled commands
|
|
ExitCode: -1,
|
|
Cancelled: ctx.Err() == context.Canceled,
|
|
}, ctx.Err()
|
|
}
|
|
|
|
var command string
|
|
var commandArgs []string
|
|
if task.Shell {
|
|
command = "bash"
|
|
if len(task.Args) == 0 {
|
|
// use Split and Join to remove any extra whitespace?
|
|
startArgs := strings.Split(task.Command, " ")
|
|
script := strings.Join(startArgs, " ")
|
|
commandArgs = append([]string{"-c"}, script)
|
|
} else {
|
|
script := strings.Join(task.Args, " ")
|
|
commandArgs = append([]string{"-c"}, fmt.Sprintf("%s %s", task.Command, script))
|
|
}
|
|
} else {
|
|
command = task.Command
|
|
commandArgs = task.Args
|
|
}
|
|
|
|
if task.Sudo {
|
|
commandArgs = append([]string{"-n", "-u", "root", command}, commandArgs...)
|
|
command = "sudo"
|
|
}
|
|
|
|
if task.PrintCommand || os.Getenv("DOKKU_TRACE") == "1" {
|
|
LogDebug(fmt.Sprintf("ssh %s@%s %s %v", task.Username, task.Hostname, command, commandArgs))
|
|
}
|
|
|
|
cmd, err := client.CommandContext(ctx, command, commandArgs...)
|
|
if err != nil {
|
|
return SshResult{}, fmt.Errorf("failed to create ssh command: %w", err)
|
|
}
|
|
|
|
if len(task.Env) > 0 {
|
|
overrides := map[string]bool{}
|
|
for _, env := range task.Env {
|
|
key := strings.Split(env, "=")[0]
|
|
overrides[key] = true
|
|
cmd.Env = append(cmd.Env, env)
|
|
}
|
|
}
|
|
|
|
if task.Stdin != nil {
|
|
cmd.Stdin = task.Stdin
|
|
}
|
|
|
|
stdoutBuff := bytes.Buffer{}
|
|
stderrBuff := bytes.Buffer{}
|
|
|
|
var stdoutWriters []io.Writer
|
|
var stderrWriters []io.Writer
|
|
|
|
if !task.DisableStdioBuffer {
|
|
stdoutWriters = append(stdoutWriters, &stdoutBuff)
|
|
stderrWriters = append(stderrWriters, &stderrBuff)
|
|
}
|
|
|
|
if task.StreamStdio {
|
|
stdoutWriters = append(stdoutWriters, os.Stdout)
|
|
stderrWriters = append(stderrWriters, os.Stderr)
|
|
}
|
|
|
|
if task.StdOutWriter != nil {
|
|
stdoutWriters = append(stdoutWriters, task.StdOutWriter)
|
|
}
|
|
if task.StdErrWriter != nil {
|
|
stderrWriters = append(stderrWriters, task.StdErrWriter)
|
|
}
|
|
|
|
cmd.Stdout = io.MultiWriter(stdoutWriters...)
|
|
cmd.Stderr = io.MultiWriter(stderrWriters...)
|
|
|
|
startErr := cmd.Start()
|
|
if startErr != nil {
|
|
return SshResult{}, startErr
|
|
}
|
|
|
|
exitCode := 0
|
|
execErr := cmd.Wait()
|
|
if execErr != nil {
|
|
if exitError, ok := execErr.(*exec.ExitError); ok {
|
|
exitCode = exitError.ExitCode()
|
|
}
|
|
}
|
|
|
|
return SshResult{
|
|
Stdout: stdoutBuff.String(),
|
|
Stderr: stderrBuff.String(),
|
|
ExitCode: exitCode,
|
|
Cancelled: ctx.Err() == context.Canceled,
|
|
}, ctx.Err()
|
|
}
|