From 08a888dc8aa9e35c1c10e2496bf55e92cd37b0f3 Mon Sep 17 00:00:00 2001 From: Pete Davison Date: Sun, 10 Mar 2024 17:11:07 +0000 Subject: [PATCH] feat: parse templates in collection-type variables (#1526) * refactor: replacer * feat: move traverser to deepcopy package * feat: nested map variable templating * refactor: ReplaceVar function * feat: test cases * fix: TraverseStringsFunc copy value instead of pointer --- internal/compiler/compiler.go | 23 ++---- internal/deepcopy/deepcopy.go | 106 ++++++++++++++++++++++++++ internal/output/group.go | 8 +- internal/output/interleaved.go | 4 +- internal/output/output.go | 10 +-- internal/output/output_test.go | 2 +- internal/output/prefixed.go | 4 +- internal/templater/templater.go | 128 +++++++++++++++----------------- task.go | 2 +- taskfile/dotenv.go | 5 +- taskfile/reader.go | 8 +- testdata/vars/any2/Taskfile.yml | 20 +++++ variables.go | 58 +++++++-------- 13 files changed, 243 insertions(+), 135 deletions(-) diff --git a/internal/compiler/compiler.go b/internal/compiler/compiler.go index ec10c245..db8aea5f 100644 --- a/internal/compiler/compiler.go +++ b/internal/compiler/compiler.go @@ -59,20 +59,9 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool getRangeFunc := func(dir string) func(k string, v ast.Var) error { return func(k string, v ast.Var) error { - tr := templater.Templater{Vars: result} + cache := &templater.Cache{Vars: result} // Replace values - newVar := ast.Var{} - switch value := v.Value.(type) { - case string: - newVar.Value = tr.Replace(value) - default: - newVar.Value = value - } - newVar.Sh = tr.Replace(v.Sh) - newVar.Ref = v.Ref - newVar.Json = tr.Replace(v.Json) - newVar.Yaml = tr.Replace(v.Yaml) - newVar.Dir = v.Dir + newVar := templater.ReplaceVar(v, cache) // If the variable is a reference, we can resolve it if newVar.Ref != "" { newVar.Value = result.Get(newVar.Ref).Value @@ -89,7 +78,7 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool return nil } // Now we can check for errors since we've handled all the cases when we don't want to evaluate - if err := tr.Err(); err != nil { + if err := cache.Err(); err != nil { return err } // Evaluate JSON @@ -124,9 +113,9 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool if t != nil { // NOTE(@andreynering): We're manually joining these paths here because // this is the raw task, not the compiled one. - tr := templater.Templater{Vars: result} - dir := tr.Replace(t.Dir) - if err := tr.Err(); err != nil { + cache := &templater.Cache{Vars: result} + dir := templater.Replace(t.Dir, cache) + if err := cache.Err(); err != nil { return nil, err } dir = filepathext.SmartJoin(c.Dir, dir) diff --git a/internal/deepcopy/deepcopy.go b/internal/deepcopy/deepcopy.go index e378f9fd..261d443e 100644 --- a/internal/deepcopy/deepcopy.go +++ b/internal/deepcopy/deepcopy.go @@ -1,5 +1,9 @@ package deepcopy +import ( + "reflect" +) + type Copier[T any] interface { DeepCopy() T } @@ -33,3 +37,105 @@ func Map[K comparable, V any](orig map[K]V) map[K]V { } return c } + +// TraverseStringsFunc runs the given function on every string in the given +// value by traversing it recursively. If the given value is a string, the +// function will run on a copy of the string and return it. If the value is a +// struct, map or a slice, the function will recursively call itself for each +// field or element of the struct, map or slice until all strings inside the +// struct or slice are replaced. +func TraverseStringsFunc[T any](v T, fn func(v string) (string, error)) (T, error) { + original := reflect.ValueOf(v) + if original.Kind() == reflect.Invalid || !original.IsValid() { + return v, nil + } + copy := reflect.New(original.Type()).Elem() + + var traverseFunc func(copy, v reflect.Value) error + traverseFunc = func(copy, v reflect.Value) error { + switch v.Kind() { + + case reflect.Ptr: + // Unwrap the pointer + originalValue := v.Elem() + // If the pointer is nil, do nothing + if !originalValue.IsValid() { + return nil + } + // Create an empty copy from the original value's type + copy.Set(reflect.New(originalValue.Type())) + // Unwrap the newly created pointer and call traverseFunc recursively + if err := traverseFunc(copy.Elem(), originalValue); err != nil { + return err + } + + case reflect.Interface: + // Unwrap the interface + originalValue := v.Elem() + if !originalValue.IsValid() { + return nil + } + // Create an empty copy from the original value's type + copyValue := reflect.New(originalValue.Type()).Elem() + // Unwrap the newly created pointer and call traverseFunc recursively + if err := traverseFunc(copyValue, originalValue); err != nil { + return err + } + copy.Set(copyValue) + + case reflect.Struct: + // Loop over each field and call traverseFunc recursively + for i := 0; i < v.NumField(); i += 1 { + if err := traverseFunc(copy.Field(i), v.Field(i)); err != nil { + return err + } + } + + case reflect.Slice: + // Create an empty copy from the original value's type + copy.Set(reflect.MakeSlice(v.Type(), v.Len(), v.Cap())) + // Loop over each element and call traverseFunc recursively + for i := 0; i < v.Len(); i += 1 { + if err := traverseFunc(copy.Index(i), v.Index(i)); err != nil { + return err + } + } + + case reflect.Map: + // Create an empty copy from the original value's type + copy.Set(reflect.MakeMap(v.Type())) + // Loop over each key + for _, key := range v.MapKeys() { + // Create a copy of each map index + originalValue := v.MapIndex(key) + if originalValue.IsNil() { + continue + } + copyValue := reflect.New(originalValue.Type()).Elem() + // Call traverseFunc recursively + if err := traverseFunc(copyValue, originalValue); err != nil { + return err + } + copy.SetMapIndex(key, copyValue) + } + + case reflect.String: + rv, err := fn(v.String()) + if err != nil { + return err + } + copy.Set(reflect.ValueOf(rv)) + + default: + copy.Set(v) + } + + return nil + } + + if err := traverseFunc(copy, original); err != nil { + return v, err + } + + return copy.Interface().(T), nil +} diff --git a/internal/output/group.go b/internal/output/group.go index c602cd17..46973a27 100644 --- a/internal/output/group.go +++ b/internal/output/group.go @@ -3,6 +3,8 @@ package output import ( "bytes" "io" + + "github.com/go-task/task/v3/internal/templater" ) type Group struct { @@ -10,13 +12,13 @@ type Group struct { ErrorOnly bool } -func (g Group) WrapWriter(stdOut, _ io.Writer, _ string, tmpl Templater) (io.Writer, io.Writer, CloseFunc) { +func (g Group) WrapWriter(stdOut, _ io.Writer, _ string, cache *templater.Cache) (io.Writer, io.Writer, CloseFunc) { gw := &groupWriter{writer: stdOut} if g.Begin != "" { - gw.begin = tmpl.Replace(g.Begin) + "\n" + gw.begin = templater.Replace(g.Begin, cache) + "\n" } if g.End != "" { - gw.end = tmpl.Replace(g.End) + "\n" + gw.end = templater.Replace(g.End, cache) + "\n" } return gw, gw, func(err error) error { if g.ErrorOnly && err == nil { diff --git a/internal/output/interleaved.go b/internal/output/interleaved.go index 0bdd1640..32c7fc8d 100644 --- a/internal/output/interleaved.go +++ b/internal/output/interleaved.go @@ -2,10 +2,12 @@ package output import ( "io" + + "github.com/go-task/task/v3/internal/templater" ) type Interleaved struct{} -func (Interleaved) WrapWriter(stdOut, stdErr io.Writer, _ string, _ Templater) (io.Writer, io.Writer, CloseFunc) { +func (Interleaved) WrapWriter(stdOut, stdErr io.Writer, _ string, _ *templater.Cache) (io.Writer, io.Writer, CloseFunc) { return stdOut, stdErr, func(error) error { return nil } } diff --git a/internal/output/output.go b/internal/output/output.go index 8dc7b6e3..c3c1346c 100644 --- a/internal/output/output.go +++ b/internal/output/output.go @@ -4,18 +4,12 @@ import ( "fmt" "io" + "github.com/go-task/task/v3/internal/templater" "github.com/go-task/task/v3/taskfile/ast" ) -// Templater executes a template engine. -// It is provided by the templater.Templater package. -type Templater interface { - // Replace replaces the provided template string with a rendered string. - Replace(tmpl string) string -} - type Output interface { - WrapWriter(stdOut, stdErr io.Writer, prefix string, tmpl Templater) (io.Writer, io.Writer, CloseFunc) + WrapWriter(stdOut, stdErr io.Writer, prefix string, cache *templater.Cache) (io.Writer, io.Writer, CloseFunc) } type CloseFunc func(err error) error diff --git a/internal/output/output_test.go b/internal/output/output_test.go index 9236736e..41b35552 100644 --- a/internal/output/output_test.go +++ b/internal/output/output_test.go @@ -46,7 +46,7 @@ func TestGroup(t *testing.T) { } func TestGroupWithBeginEnd(t *testing.T) { - tmpl := templater.Templater{ + tmpl := templater.Cache{ Vars: &ast.Vars{ OrderedMap: omap.FromMap(map[string]ast.Var{ "VAR1": {Value: "example-value"}, diff --git a/internal/output/prefixed.go b/internal/output/prefixed.go index da6d5e6b..cea2c7d3 100644 --- a/internal/output/prefixed.go +++ b/internal/output/prefixed.go @@ -5,11 +5,13 @@ import ( "fmt" "io" "strings" + + "github.com/go-task/task/v3/internal/templater" ) type Prefixed struct{} -func (Prefixed) WrapWriter(stdOut, _ io.Writer, prefix string, _ Templater) (io.Writer, io.Writer, CloseFunc) { +func (Prefixed) WrapWriter(stdOut, _ io.Writer, prefix string, _ *templater.Cache) (io.Writer, io.Writer, CloseFunc) { pw := &prefixWriter{writer: stdOut, prefix: prefix} return pw, pw, func(error) error { return pw.close() } } diff --git a/internal/templater/templater.go b/internal/templater/templater.go index b7cf3cd8..58b7d71a 100644 --- a/internal/templater/templater.go +++ b/internal/templater/templater.go @@ -6,122 +6,116 @@ import ( "strings" "text/template" + "github.com/go-task/task/v3/internal/deepcopy" "github.com/go-task/task/v3/taskfile/ast" ) -// Templater is a help struct that allow us to call "replaceX" funcs multiple +// Cache is a help struct that allow us to call "replaceX" funcs multiple // times, without having to check for error each time. The first error that // happen will be assigned to r.err, and consecutive calls to funcs will just // return the zero value. -type Templater struct { +type Cache struct { Vars *ast.Vars cacheMap map[string]any err error } -func (r *Templater) ResetCache() { +func (r *Cache) ResetCache() { r.cacheMap = r.Vars.ToCacheMap() } -func (r *Templater) Replace(str string) string { - return r.replace(str, nil) +func (r *Cache) Err() error { + return r.err } -func (r *Templater) ReplaceWithExtra(str string, extra map[string]any) string { - return r.replace(str, extra) +func Replace[T any](v T, cache *Cache) T { + return ReplaceWithExtra(v, cache, nil) } -func (r *Templater) replace(str string, extra map[string]any) string { - if r.err != nil || str == "" { - return "" +func ReplaceWithExtra[T any](v T, cache *Cache, extra map[string]any) T { + // If there is already an error, do nothing + if cache.err != nil { + return v } - templ, err := template.New("").Funcs(templateFuncs).Parse(str) + // Initialize the cache map if it's not already initialized + if cache.cacheMap == nil { + cache.cacheMap = cache.Vars.ToCacheMap() + } + + // Create a copy of the cache map to avoid editing the original + // If there is extra data, merge it with the cache map + data := maps.Clone(cache.cacheMap) + if extra != nil { + maps.Copy(data, extra) + } + + // Traverse the value and parse any template variables + copy, err := deepcopy.TraverseStringsFunc(v, func(v string) (string, error) { + tpl, err := template.New("").Funcs(templateFuncs).Parse(v) + if err != nil { + return v, err + } + var b bytes.Buffer + if err := tpl.Execute(&b, data); err != nil { + return v, err + } + return strings.ReplaceAll(b.String(), "", ""), nil + }) if err != nil { - r.err = err - return "" + cache.err = err + return v } - if r.cacheMap == nil { - r.cacheMap = r.Vars.ToCacheMap() - } - - var b bytes.Buffer - if extra == nil { - err = templ.Execute(&b, r.cacheMap) - } else { - // Copy the map to avoid modifying the cached map - m := maps.Clone(r.cacheMap) - maps.Copy(m, extra) - err = templ.Execute(&b, m) - } - if err != nil { - r.err = err - return "" - } - return strings.ReplaceAll(b.String(), "", "") + return copy } -func (r *Templater) ReplaceSlice(strs []string) []string { - if r.err != nil || len(strs) == 0 { - return nil - } - - new := make([]string, len(strs)) - for i, str := range strs { - new[i] = r.Replace(str) - } - return new -} - -func (r *Templater) ReplaceGlobs(globs []*ast.Glob) []*ast.Glob { - if r.err != nil || len(globs) == 0 { +func ReplaceGlobs(globs []*ast.Glob, cache *Cache) []*ast.Glob { + if cache.err != nil || len(globs) == 0 { return nil } new := make([]*ast.Glob, len(globs)) for i, g := range globs { new[i] = &ast.Glob{ - Glob: r.Replace(g.Glob), + Glob: Replace(g.Glob, cache), Negate: g.Negate, } } return new } -func (r *Templater) ReplaceVars(vars *ast.Vars) *ast.Vars { - return r.replaceVars(vars, nil) +func ReplaceVar(v ast.Var, cache *Cache) ast.Var { + return ReplaceVarWithExtra(v, cache, nil) } -func (r *Templater) ReplaceVarsWithExtra(vars *ast.Vars, extra map[string]any) *ast.Vars { - return r.replaceVars(vars, extra) +func ReplaceVarWithExtra(v ast.Var, cache *Cache, extra map[string]any) ast.Var { + return ast.Var{ + Value: ReplaceWithExtra(v.Value, cache, extra), + Sh: ReplaceWithExtra(v.Sh, cache, extra), + Live: v.Live, + Ref: v.Ref, + Dir: v.Dir, + Json: ReplaceWithExtra(v.Json, cache, extra), + Yaml: ReplaceWithExtra(v.Yaml, cache, extra), + } } -func (r *Templater) replaceVars(vars *ast.Vars, extra map[string]any) *ast.Vars { - if r.err != nil || vars.Len() == 0 { +func ReplaceVars(vars *ast.Vars, cache *Cache) *ast.Vars { + return ReplaceVarsWithExtra(vars, cache, nil) +} + +func ReplaceVarsWithExtra(vars *ast.Vars, cache *Cache, extra map[string]any) *ast.Vars { + if cache.err != nil || vars.Len() == 0 { return nil } var newVars ast.Vars _ = vars.Range(func(k string, v ast.Var) error { - var newVar ast.Var - switch value := v.Value.(type) { - case string: - newVar.Value = r.ReplaceWithExtra(value, extra) - } - newVar.Live = v.Live - newVar.Sh = r.ReplaceWithExtra(v.Sh, extra) - newVar.Ref = v.Ref - newVar.Json = r.ReplaceWithExtra(v.Json, extra) - newVar.Yaml = r.ReplaceWithExtra(v.Yaml, extra) - newVars.Set(k, newVar) + newVars.Set(k, ReplaceVarWithExtra(v, cache, extra)) return nil }) return &newVars } - -func (r *Templater) Err() error { - return r.err -} diff --git a/task.go b/task.go index eaafeaa4..a19d7ba2 100644 --- a/task.go +++ b/task.go @@ -348,7 +348,7 @@ func (e *Executor) runCommand(ctx context.Context, t *ast.Task, call *ast.Call, outputWrapper = output.Interleaved{} } vars, err := e.Compiler.FastGetVariables(t, call) - outputTemplater := &templater.Templater{Vars: vars} + outputTemplater := &templater.Cache{Vars: vars} if err != nil { return fmt.Errorf("task: failed to get variables: %w", err) } diff --git a/taskfile/dotenv.go b/taskfile/dotenv.go index 143d3156..3971a9d2 100644 --- a/taskfile/dotenv.go +++ b/taskfile/dotenv.go @@ -22,11 +22,10 @@ func Dotenv(c *compiler.Compiler, tf *ast.Taskfile, dir string) (*ast.Vars, erro } env := &ast.Vars{} - - tr := templater.Templater{Vars: vars} + cache := &templater.Cache{Vars: vars} for _, dotEnvPath := range tf.Dotenv { - dotEnvPath = tr.Replace(dotEnvPath) + dotEnvPath = templater.Replace(dotEnvPath, cache) if dotEnvPath == "" { continue } diff --git a/taskfile/reader.go b/taskfile/reader.go index 92127a35..399a8e75 100644 --- a/taskfile/reader.go +++ b/taskfile/reader.go @@ -60,11 +60,11 @@ func Read( } err = tf.Includes.Range(func(namespace string, include ast.Include) error { - tr := templater.Templater{Vars: tf.Vars} + cache := &templater.Cache{Vars: tf.Vars} include = ast.Include{ Namespace: include.Namespace, - Taskfile: tr.Replace(include.Taskfile), - Dir: tr.Replace(include.Dir), + Taskfile: templater.Replace(include.Taskfile, cache), + Dir: templater.Replace(include.Dir, cache), Optional: include.Optional, Internal: include.Internal, Aliases: include.Aliases, @@ -72,7 +72,7 @@ func Read( Vars: include.Vars, BaseDir: include.BaseDir, } - if err := tr.Err(); err != nil { + if err := cache.Err(); err != nil { return err } diff --git a/testdata/vars/any2/Taskfile.yml b/testdata/vars/any2/Taskfile.yml index 0f20932c..4a11ff86 100644 --- a/testdata/vars/any2/Taskfile.yml +++ b/testdata/vars/any2/Taskfile.yml @@ -3,6 +3,8 @@ version: '3' tasks: default: - task: map + - task: nested-map + - task: slice - task: ref - task: ref-sh - task: ref-dep @@ -19,6 +21,24 @@ tasks: VAR: ref: MAP + nested-map: + vars: + FOO: "foo" + nested: + map: + variables: + work: "{{.FOO}}" + cmds: + - echo {{.nested.variables.work}} + + slice: + vars: + FOO: "foo" + BAR: "bar" + slice_variables_work: ["{{.FOO}}","{{.BAR}}"] + cmds: + - echo {{index .slice_variables_work 0}} {{index .slice_variables_work 1}} + ref: vars: MAP: diff --git a/variables.go b/variables.go index 1322d5c9..39ec98dd 100644 --- a/variables.go +++ b/variables.go @@ -42,30 +42,30 @@ func (e *Executor) compiledTask(call *ast.Call, evaluateShVars bool) (*ast.Task, return nil, err } - r := templater.Templater{Vars: vars} + cache := &templater.Cache{Vars: vars} new := ast.Task{ Task: origTask.Task, - Label: r.Replace(origTask.Label), - Desc: r.Replace(origTask.Desc), - Prompt: r.Replace(origTask.Prompt), - Summary: r.Replace(origTask.Summary), + Label: templater.Replace(origTask.Label, cache), + Desc: templater.Replace(origTask.Desc, cache), + Prompt: templater.Replace(origTask.Prompt, cache), + Summary: templater.Replace(origTask.Summary, cache), Aliases: origTask.Aliases, - Sources: r.ReplaceGlobs(origTask.Sources), - Generates: r.ReplaceGlobs(origTask.Generates), - Dir: r.Replace(origTask.Dir), + Sources: templater.ReplaceGlobs(origTask.Sources, cache), + Generates: templater.ReplaceGlobs(origTask.Generates, cache), + Dir: templater.Replace(origTask.Dir, cache), Set: origTask.Set, Shopt: origTask.Shopt, Vars: nil, Env: nil, - Dotenv: r.ReplaceSlice(origTask.Dotenv), + Dotenv: templater.Replace(origTask.Dotenv, cache), Silent: origTask.Silent, Interactive: origTask.Interactive, Internal: origTask.Internal, - Method: r.Replace(origTask.Method), - Prefix: r.Replace(origTask.Prefix), + Method: templater.Replace(origTask.Method, cache), + Prefix: templater.Replace(origTask.Prefix, cache), IgnoreError: origTask.IgnoreError, - Run: r.Replace(origTask.Run), + Run: templater.Replace(origTask.Run, cache), IncludeVars: origTask.IncludeVars, IncludedTaskfileVars: origTask.IncludedTaskfileVars, Platforms: origTask.Platforms, @@ -104,9 +104,9 @@ func (e *Executor) compiledTask(call *ast.Call, evaluateShVars bool) (*ast.Task, } new.Env = &ast.Vars{} - new.Env.Merge(r.ReplaceVars(e.Taskfile.Env)) - new.Env.Merge(r.ReplaceVars(dotenvEnvs)) - new.Env.Merge(r.ReplaceVars(origTask.Env)) + new.Env.Merge(templater.ReplaceVars(e.Taskfile.Env, cache)) + new.Env.Merge(templater.ReplaceVars(dotenvEnvs, cache)) + new.Env.Merge(templater.ReplaceVars(origTask.Env, cache)) if evaluateShVars { err = new.Env.Range(func(k string, v ast.Var) error { // If the variable is not dynamic, we can set it and return @@ -200,17 +200,17 @@ func (e *Executor) compiledTask(call *ast.Call, evaluateShVars bool) (*ast.Task, extra["KEY"] = keys[i] } newCmd := cmd.DeepCopy() - newCmd.Cmd = r.ReplaceWithExtra(cmd.Cmd, extra) - newCmd.Task = r.ReplaceWithExtra(cmd.Task, extra) - newCmd.Vars = r.ReplaceVarsWithExtra(cmd.Vars, extra) + newCmd.Cmd = templater.ReplaceWithExtra(cmd.Cmd, cache, extra) + newCmd.Task = templater.ReplaceWithExtra(cmd.Task, cache, extra) + newCmd.Vars = templater.ReplaceVarsWithExtra(cmd.Vars, cache, extra) new.Cmds = append(new.Cmds, newCmd) } continue } newCmd := cmd.DeepCopy() - newCmd.Cmd = r.Replace(cmd.Cmd) - newCmd.Task = r.Replace(cmd.Task) - newCmd.Vars = r.ReplaceVars(cmd.Vars) + newCmd.Cmd = templater.Replace(cmd.Cmd, cache) + newCmd.Task = templater.Replace(cmd.Task, cache) + newCmd.Vars = templater.ReplaceVars(cmd.Vars, cache) // Loop over the command's variables and resolve any references to other variables err := cmd.Vars.Range(func(k string, v ast.Var) error { if v.Ref != "" { @@ -232,8 +232,8 @@ func (e *Executor) compiledTask(call *ast.Call, evaluateShVars bool) (*ast.Task, continue } newDep := dep.DeepCopy() - newDep.Task = r.Replace(dep.Task) - newDep.Vars = r.ReplaceVars(dep.Vars) + newDep.Task = templater.Replace(dep.Task, cache) + newDep.Vars = templater.ReplaceVars(dep.Vars, cache) // Loop over the dep's variables and resolve any references to other variables err := dep.Vars.Range(func(k string, v ast.Var) error { if v.Ref != "" { @@ -256,8 +256,8 @@ func (e *Executor) compiledTask(call *ast.Call, evaluateShVars bool) (*ast.Task, continue } newPrecondition := precondition.DeepCopy() - newPrecondition.Sh = r.Replace(precondition.Sh) - newPrecondition.Msg = r.Replace(precondition.Msg) + newPrecondition.Sh = templater.Replace(precondition.Sh, cache) + newPrecondition.Msg = templater.Replace(precondition.Msg, cache) new.Preconditions = append(new.Preconditions, newPrecondition) } } @@ -276,14 +276,14 @@ func (e *Executor) compiledTask(call *ast.Call, evaluateShVars bool) (*ast.Task, // Adding new variables, requires us to refresh the templaters // cache of the the values manually - r.ResetCache() + cache.ResetCache() - new.Status = r.ReplaceSlice(origTask.Status) + new.Status = templater.Replace(origTask.Status, cache) } // We only care about templater errors if we are evaluating shell variables - if evaluateShVars && r.Err() != nil { - return &new, r.Err() + if evaluateShVars && cache.Err() != nil { + return &new, cache.Err() } return &new, nil