diff --git a/task.go b/task.go index 17e5be10..21abcf79 100644 --- a/task.go +++ b/task.go @@ -106,7 +106,12 @@ func (e *Executor) Run(ctx context.Context, calls ...taskfile.Call) error { // Setup setups Executor's internal state func (e *Executor) Setup() error { var err error - e.Taskfile, err = read.Taskfile(e.Dir, e.Entrypoint) + e.Taskfile, err = read.Taskfile(&read.ReaderNode{ + Dir: e.Dir, + Entrypoint: e.Entrypoint, + Parent: nil, + Optional: false, + }) if err != nil { return err } diff --git a/task_test.go b/task_test.go index cb75bb00..47c2a8bd 100644 --- a/task_test.go +++ b/task_test.go @@ -753,6 +753,31 @@ func TestIncludes(t *testing.T) { tt.Run(t) } +func TestIncludesMultiLevel(t *testing.T) { + tt := fileContentTest{ + Dir: "testdata/includes_multi_level", + Target: "default", + TrimSpace: true, + Files: map[string]string{}, + } + tt.Run(t) +} + +func TestIncludeCycle(t *testing.T) { + const dir = "testdata/includes_cycle" + expectedError := "include cycle detected between testdata/includes_cycle/Taskfile.yml <--> testdata/includes_cycle/one/two/Taskfile.yml" + + var buff bytes.Buffer + e := task.Executor{ + Dir: dir, + Stdout: &buff, + Stderr: &buff, + Silent: true, + } + + assert.EqualError(t, e.Setup(), expectedError) +} + func TestIncorrectVersionIncludes(t *testing.T) { const dir = "testdata/incorrect_includes" expectedError := "task: Import with additional parameters is only available starting on Taskfile version v3" diff --git a/taskfile/read/taskfile.go b/taskfile/read/taskfile.go index ec7a3ab9..b97d21ab 100644 --- a/taskfile/read/taskfile.go +++ b/taskfile/read/taskfile.go @@ -15,29 +15,35 @@ import ( ) var ( - // ErrIncludedTaskfilesCantHaveIncludes is returned when a included Taskfile contains includes - ErrIncludedTaskfilesCantHaveIncludes = errors.New("task: Included Taskfiles can't have includes. Please, move the include to the main Taskfile") // ErrIncludedTaskfilesCantHaveDotenvs is returned when a included Taskfile contains dotenvs ErrIncludedTaskfilesCantHaveDotenvs = errors.New("task: Included Taskfiles can't have dotenv declarations. Please, move the dotenv declaration to the main Taskfile") defaultTaskfiles = []string{"Taskfile.yml", "Taskfile.yaml"} ) +type ReaderNode struct { + Dir string + Entrypoint string + Optional bool + Parent *ReaderNode +} + // Taskfile reads a Taskfile for a given directory // Uses current dir when dir is left empty. Uses Taskfile.yml // or Taskfile.yaml when entrypoint is left empty -func Taskfile(dir string, entrypoint string) (*taskfile.Taskfile, error) { - if dir == "" { +func Taskfile(readerNode *ReaderNode) (*taskfile.Taskfile, error) { + if readerNode.Dir == "" { d, err := os.Getwd() if err != nil { return nil, err } - dir = d + readerNode.Dir = d } - path, err := exists(filepath.Join(dir, entrypoint)) + path, err := exists(filepath.Join(readerNode.Dir, readerNode.Entrypoint)) if err != nil { return nil, err } + readerNode.Entrypoint = filepath.Base(path) t, err := readTaskfile(path) if err != nil { @@ -68,10 +74,32 @@ func Taskfile(dir string, entrypoint string) (*taskfile.Taskfile, error) { return err } if !filepath.IsAbs(path) { - path = filepath.Join(dir, path) + path = filepath.Join(readerNode.Dir, path) } - path, err = exists(path) + // check for cyclic include references by walking up + // node tree of parents and comparing paths + var curNode = readerNode + for curNode.Parent != nil { + curNode = curNode.Parent + curPath := filepath.Join(curNode.Dir, curNode.Entrypoint) + if curPath == path { + return fmt.Errorf("include cycle detected between %s <--> %s", + curPath, + filepath.Join(readerNode.Dir, readerNode.Entrypoint), + ) + } + } + + // if we made it here then there is no cyclic include + readOpts := &ReaderNode{ + Dir: filepath.Dir(path), + Entrypoint: filepath.Base(path), + Parent: readerNode, + Optional: includedTask.Optional, + } + + includedTaskfile, err := Taskfile(readOpts) if err != nil { if includedTask.Optional { return nil @@ -79,14 +107,6 @@ func Taskfile(dir string, entrypoint string) (*taskfile.Taskfile, error) { return err } - includedTaskfile, err := readTaskfile(path) - if err != nil { - return err - } - if includedTaskfile.Includes.Len() > 0 { - return ErrIncludedTaskfilesCantHaveIncludes - } - if v >= 3.0 && len(includedTaskfile.Dotenv) > 0 { return ErrIncludedTaskfilesCantHaveDotenvs } @@ -94,12 +114,12 @@ func Taskfile(dir string, entrypoint string) (*taskfile.Taskfile, error) { if includedTask.AdvancedImport { for k, v := range includedTaskfile.Vars.Mapping { o := v - o.Dir = filepath.Join(dir, includedTask.Dir) + o.Dir = filepath.Join(readerNode.Dir, includedTask.Dir) includedTaskfile.Vars.Mapping[k] = o } for k, v := range includedTaskfile.Env.Mapping { o := v - o.Dir = filepath.Join(dir, includedTask.Dir) + o.Dir = filepath.Join(readerNode.Dir, includedTask.Dir) includedTaskfile.Env.Mapping[k] = o } @@ -120,7 +140,7 @@ func Taskfile(dir string, entrypoint string) (*taskfile.Taskfile, error) { } if v < 3.0 { - path = filepath.Join(dir, fmt.Sprintf("Taskfile_%s.yml", runtime.GOOS)) + path = filepath.Join(readerNode.Dir, fmt.Sprintf("Taskfile_%s.yml", runtime.GOOS)) if _, err = os.Stat(path); err == nil { osTaskfile, err := readTaskfile(path) if err != nil { diff --git a/testdata/includes_cycle/Taskfile.yml b/testdata/includes_cycle/Taskfile.yml new file mode 100644 index 00000000..56748cd7 --- /dev/null +++ b/testdata/includes_cycle/Taskfile.yml @@ -0,0 +1,12 @@ +version: '3' + +includes: + 'one': ./one/Taskfile.yml + +tasks: + default: + cmds: + - echo "called_dep" > called_dep.txt + level1: + cmds: + - echo "hello level 1" \ No newline at end of file diff --git a/testdata/includes_cycle/one/Taskfile.yml b/testdata/includes_cycle/one/Taskfile.yml new file mode 100644 index 00000000..a948df55 --- /dev/null +++ b/testdata/includes_cycle/one/Taskfile.yml @@ -0,0 +1,9 @@ +version: '3' + +includes: + 'two': ./two/Taskfile.yml + +tasks: + level2: + cmds: + - echo "hello level 2" \ No newline at end of file diff --git a/testdata/includes_cycle/one/two/Taskfile.yml b/testdata/includes_cycle/one/two/Taskfile.yml new file mode 100644 index 00000000..b01e4642 --- /dev/null +++ b/testdata/includes_cycle/one/two/Taskfile.yml @@ -0,0 +1,9 @@ +version: '3' + +includes: + bad: "../../Taskfile.yml" + +tasks: + level3: + cmds: + - echo "hello level 3" \ No newline at end of file diff --git a/testdata/includes_multi_level/Taskfile.yml b/testdata/includes_multi_level/Taskfile.yml new file mode 100644 index 00000000..56748cd7 --- /dev/null +++ b/testdata/includes_multi_level/Taskfile.yml @@ -0,0 +1,12 @@ +version: '3' + +includes: + 'one': ./one/Taskfile.yml + +tasks: + default: + cmds: + - echo "called_dep" > called_dep.txt + level1: + cmds: + - echo "hello level 1" \ No newline at end of file diff --git a/testdata/includes_multi_level/one/Taskfile.yml b/testdata/includes_multi_level/one/Taskfile.yml new file mode 100644 index 00000000..a948df55 --- /dev/null +++ b/testdata/includes_multi_level/one/Taskfile.yml @@ -0,0 +1,9 @@ +version: '3' + +includes: + 'two': ./two/Taskfile.yml + +tasks: + level2: + cmds: + - echo "hello level 2" \ No newline at end of file diff --git a/testdata/includes_multi_level/one/two/Taskfile.yml b/testdata/includes_multi_level/one/two/Taskfile.yml new file mode 100644 index 00000000..738fa5ae --- /dev/null +++ b/testdata/includes_multi_level/one/two/Taskfile.yml @@ -0,0 +1,6 @@ +version: '3' + +tasks: + level3: + cmds: + - echo "hello level 3" \ No newline at end of file