diff --git a/cmd/task/task.go b/cmd/task/task.go index 74c54697..40950ffe 100644 --- a/cmd/task/task.go +++ b/cmd/task/task.go @@ -87,6 +87,11 @@ func main() { return } + ctx := context.Background() + if !watch { + ctx = getSignalContext() + } + e := task.Executor{ Force: force, Watch: watch, @@ -95,7 +100,7 @@ func main() { Dir: dir, Dry: dry, - Context: getSignalContext(), + Context: ctx, Stdin: os.Stdin, Stdout: os.Stdout, diff --git a/watch.go b/watch.go index 43d346b2..9660c057 100644 --- a/watch.go +++ b/watch.go @@ -2,7 +2,10 @@ package task import ( "context" + "os" + "os/signal" "strings" + "syscall" "time" "github.com/go-task/task/internal/taskfile" @@ -40,6 +43,8 @@ func (e *Executor) watchTasks(calls ...taskfile.Call) error { return err } + closeOnInterrupt(w) + go func() { for { select { @@ -66,6 +71,7 @@ func (e *Executor) watchTasks(calls ...taskfile.Call) error { e.Logger.Errf("%v", err) } case <-w.Closed: + cancel() return } } @@ -84,6 +90,19 @@ func (e *Executor) watchTasks(calls ...taskfile.Call) error { return w.Start(time.Second) } +func isContextError(err error) bool { + return err == context.Canceled || err == context.DeadlineExceeded +} + +func closeOnInterrupt(w *watcher.Watcher) { + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt, os.Kill, syscall.SIGTERM) + go func() { + <-ch + w.Close() + }() +} + func (e *Executor) registerWatchedFiles(w *watcher.Watcher, calls ...taskfile.Call) error { oldWatchedFiles := make(map[string]struct{}) for f := range w.WatchedFiles() { @@ -140,12 +159,3 @@ func (e *Executor) registerWatchedFiles(w *watcher.Watcher, calls ...taskfile.Ca } return nil } - -func isContextError(err error) bool { - switch err { - case context.Canceled, context.DeadlineExceeded: - return true - default: - return false - } -}