diff --git a/errors.go b/errors.go index 4b94e194..0c6a1dcf 100644 --- a/errors.go +++ b/errors.go @@ -4,8 +4,13 @@ import ( "fmt" ) -// ErrNoTaskFile is returned when the program can not find a proper TaskFile -var ErrNoTaskFile = fmt.Errorf(`task: No task file found (is it named "%s"?)`, TaskFilePath) +type taskFileNotFound struct { + taskFile string +} + +func (err taskFileNotFound) Error() string { + return fmt.Sprintf(`task: No task file found (is it named "%s"?)`, err.taskFile) +} type taskNotFoundError struct { taskName string diff --git a/read_taskfile.go b/read_taskfile.go new file mode 100644 index 00000000..e76df86f --- /dev/null +++ b/read_taskfile.go @@ -0,0 +1,58 @@ +package task + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + + "runtime" + + "github.com/BurntSushi/toml" + "github.com/imdario/mergo" + "gopkg.in/yaml.v2" +) + +func readTaskfile() (map[string]*Task, error) { + initialTasks, err := readTaskfileData(TaskFilePath) + if err != nil { + return nil, err + } + mergeTasks, err := readTaskfileData(fmt.Sprintf("%s_%s", TaskFilePath, runtime.GOOS)) + if err != nil { + switch err.(type) { + default: + return nil, err + case taskFileNotFound: + return initialTasks, nil + } + } + if err := mergo.MapWithOverwrite(&initialTasks, mergeTasks); err != nil { + return nil, err + } + return initialTasks, nil +} + +func readTaskfileData(path string) (tasks map[string]*Task, err error) { + if b, err := ioutil.ReadFile(path + ".yml"); err == nil { + return tasks, yaml.Unmarshal(b, &tasks) + } + if b, err := ioutil.ReadFile(path + ".json"); err == nil { + return tasks, json.Unmarshal(b, &tasks) + } + if b, err := ioutil.ReadFile(path + ".toml"); err == nil { + return tasks, toml.Unmarshal(b, &tasks) + } + return nil, taskFileNotFound{path} +} + +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return true, err +} diff --git a/task.go b/task.go index 26aa9809..ec07617f 100644 --- a/task.go +++ b/task.go @@ -1,16 +1,12 @@ package task import ( - "encoding/json" - "io/ioutil" "log" "os" "os/exec" "strings" - "github.com/BurntSushi/toml" "github.com/spf13/pflag" - "gopkg.in/yaml.v2" ) var ( @@ -168,16 +164,3 @@ func (t *Task) runCommand(i int) error { } return nil } - -func readTaskfile() (tasks map[string]*Task, err error) { - if b, err := ioutil.ReadFile(TaskFilePath + ".yml"); err == nil { - return tasks, yaml.Unmarshal(b, &tasks) - } - if b, err := ioutil.ReadFile(TaskFilePath + ".json"); err == nil { - return tasks, json.Unmarshal(b, &tasks) - } - if b, err := ioutil.ReadFile(TaskFilePath + ".toml"); err == nil { - return tasks, toml.Unmarshal(b, &tasks) - } - return nil, ErrNoTaskFile -}