diff --git a/go.mod b/go.mod index c4c4879..14c6119 100644 --- a/go.mod +++ b/go.mod @@ -4,3 +4,5 @@ require ( github.com/alexflint/go-scalar v1.0.0 github.com/stretchr/testify v1.2.2 ) + +go 1.13 diff --git a/parse.go b/parse.go index a29258a..d234ed2 100644 --- a/parse.go +++ b/parse.go @@ -54,6 +54,7 @@ type spec struct { help string env string boolean bool + defaultVal string // default value for this option, only if provided as a struct tag } // command represents a named subcommand, or the top-level command @@ -250,6 +251,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { spec.help = help } + defaultVal, hasDefault := field.Tag.Lookup("default") + if hasDefault { + spec.defaultVal = defaultVal + } + // Look at the tag var isSubcommand bool // tracks whether this field is a subcommand if tag != "" { @@ -274,6 +280,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { } spec.short = key[1:] case key == "required": + if hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified", + t.Name(), field.Name)) + return false + } spec.required = true case key == "positional": spec.positional = true @@ -328,6 +339,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) { t.Name(), field.Name, field.Type.String())) return false } + if spec.multiple && hasDefault { + errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields", + t.Name(), field.Name)) + return false + } } // if this was an embedded field then we already returned true up above @@ -570,15 +586,26 @@ func (p *Parser) process(args []string) error { return fmt.Errorf("too many positional arguments at '%s'", positionals[0]) } - // finally check that all the required args were provided + // fill in defaults and check that all the required args were provided for _, spec := range specs { - if spec.required && !wasPresent[spec] { - name := spec.long - if !spec.positional { - name = "--" + spec.long - } + if wasPresent[spec] { + continue + } + + name := spec.long + if !spec.positional { + name = "--" + spec.long + } + + if spec.required { return fmt.Errorf("%s is required", name) } + if spec.defaultVal != "" { + err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal) + if err != nil { + return fmt.Errorf("error processing default value for %s: %v", name, err) + } + } } return nil diff --git a/parse_test.go b/parse_test.go index 5909472..9cd8bce 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1057,3 +1057,71 @@ func TestMultipleTerminates(t *testing.T) { assert.Equal(t, []string{"a", "b"}, args.X) assert.Equal(t, "c", args.Y) } + +func TestDefaultOptionValues(t *testing.T) { + var args struct { + A int `default:"123"` + B *int `default:"123"` + C string `default:"abc"` + D *string `default:"abc"` + E float64 `default:"1.23"` + F *float64 `default:"1.23"` + G bool `default:"true"` + H *bool `default:"true"` + } + + err := parse("--c=xyz --e=4.56", &args) + require.NoError(t, err) + + assert.Equal(t, 123, args.A) + assert.Equal(t, 123, *args.B) + assert.Equal(t, "xyz", args.C) + assert.Equal(t, "abc", *args.D) + assert.Equal(t, 4.56, args.E) + assert.Equal(t, 1.23, *args.F) + assert.True(t, args.G) + assert.True(t, args.G) +} + +func TestDefaultPositionalValues(t *testing.T) { + var args struct { + A int `arg:"positional" default:"123"` + B *int `arg:"positional" default:"123"` + C string `arg:"positional" default:"abc"` + D *string `arg:"positional" default:"abc"` + E float64 `arg:"positional" default:"1.23"` + F *float64 `arg:"positional" default:"1.23"` + G bool `arg:"positional" default:"true"` + H *bool `arg:"positional" default:"true"` + } + + err := parse("456 789", &args) + require.NoError(t, err) + + assert.Equal(t, 456, args.A) + assert.Equal(t, 789, *args.B) + assert.Equal(t, "abc", args.C) + assert.Equal(t, "abc", *args.D) + assert.Equal(t, 1.23, args.E) + assert.Equal(t, 1.23, *args.F) + assert.True(t, args.G) + assert.True(t, args.G) +} + +func TestDefaultValuesNotAllowedWithRequired(t *testing.T) { + var args struct { + A int `arg:"required" default:"123"` // required not allowed with default! + } + + err := parse("", &args) + assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified") +} + +func TestDefaultValuesNotAllowedWithSlice(t *testing.T) { + var args struct { + A []int `default:"123"` // required not allowed with default! + } + + err := parse("", &args) + assert.EqualError(t, err, ".A: default values are not supported for slice fields") +}