Merge pull request #91 from alexflint/defaults
Allow default values in struct tags
This commit is contained in:
commit
c0c7a3ba8a
|
@ -1,10 +1,7 @@
|
|||
language: go
|
||||
go:
|
||||
- "1.10"
|
||||
- "1.12"
|
||||
- tip
|
||||
env:
|
||||
- GO111MODULE=on # will only be used in go 1.11
|
||||
- "1.13"
|
||||
before_install:
|
||||
- go get github.com/axw/gocov/gocov
|
||||
- go get github.com/mattn/goveralls
|
||||
|
|
15
README.md
15
README.md
|
@ -140,12 +140,22 @@ Options:
|
|||
|
||||
### Default values
|
||||
|
||||
```go
|
||||
var args struct {
|
||||
Foo string `default:"abc"`
|
||||
Bar bool
|
||||
}
|
||||
arg.MustParse(&args)
|
||||
```
|
||||
|
||||
### Default values (before v1.2)
|
||||
|
||||
```go
|
||||
var args struct {
|
||||
Foo string
|
||||
Bar bool
|
||||
}
|
||||
args.Foo = "default value"
|
||||
arg.Foo = "abc"
|
||||
arg.MustParse(&args)
|
||||
```
|
||||
|
||||
|
@ -307,9 +317,8 @@ func (n *NameDotName) MarshalText() ([]byte, error) {
|
|||
|
||||
func main() {
|
||||
var args struct {
|
||||
Name NameDotName
|
||||
Name NameDotName `default:"file.txt"`
|
||||
}
|
||||
args.Name = NameDotName{"file", "txt"} // set default value
|
||||
arg.MustParse(&args)
|
||||
fmt.Printf("%#v\n", args.Name)
|
||||
}
|
||||
|
|
|
@ -30,12 +30,11 @@ func Example_defaultValues() {
|
|||
os.Args = split("./example")
|
||||
|
||||
var args struct {
|
||||
Foo string
|
||||
Foo string `default:"abc"`
|
||||
}
|
||||
args.Foo = "default value"
|
||||
MustParse(&args)
|
||||
fmt.Println(args.Foo)
|
||||
// output: default value
|
||||
// output: abc
|
||||
}
|
||||
|
||||
// This example demonstrates arguments that are required
|
||||
|
|
68
parse.go
68
parse.go
|
@ -1,6 +1,7 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -54,6 +55,7 @@ type spec struct {
|
|||
help string
|
||||
env string
|
||||
boolean bool
|
||||
defaultVal string // default value for this option
|
||||
}
|
||||
|
||||
// command represents a named subcommand, or the top-level command
|
||||
|
@ -192,6 +194,22 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// add nonzero field values as defaults
|
||||
for _, spec := range cmd.specs {
|
||||
if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
|
||||
if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
|
||||
str, err := defaultVal.MarshalText()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
|
||||
}
|
||||
spec.defaultVal = string(str)
|
||||
} else {
|
||||
spec.defaultVal = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.cmd.specs = append(p.cmd.specs, cmd.specs...)
|
||||
p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
|
||||
|
||||
|
@ -250,6 +268,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
|
|||
spec.help = help
|
||||
}
|
||||
|
||||
defaultVal, hasDefault := field.Tag.Lookup("default")
|
||||
if hasDefault {
|
||||
spec.defaultVal = defaultVal
|
||||
}
|
||||
|
||||
// Look at the tag
|
||||
var isSubcommand bool // tracks whether this field is a subcommand
|
||||
if tag != "" {
|
||||
|
@ -274,6 +297,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
|
|||
}
|
||||
spec.short = key[1:]
|
||||
case key == "required":
|
||||
if hasDefault {
|
||||
errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
|
||||
t.Name(), field.Name))
|
||||
return false
|
||||
}
|
||||
spec.required = true
|
||||
case key == "positional":
|
||||
spec.positional = true
|
||||
|
@ -328,6 +356,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
|
|||
t.Name(), field.Name, field.Type.String()))
|
||||
return false
|
||||
}
|
||||
if spec.multiple && hasDefault {
|
||||
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice fields",
|
||||
t.Name(), field.Name))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// if this was an embedded field then we already returned true up above
|
||||
|
@ -570,15 +603,26 @@ func (p *Parser) process(args []string) error {
|
|||
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
|
||||
}
|
||||
|
||||
// finally check that all the required args were provided
|
||||
// fill in defaults and check that all the required args were provided
|
||||
for _, spec := range specs {
|
||||
if spec.required && !wasPresent[spec] {
|
||||
name := spec.long
|
||||
if !spec.positional {
|
||||
name = "--" + spec.long
|
||||
}
|
||||
if wasPresent[spec] {
|
||||
continue
|
||||
}
|
||||
|
||||
name := spec.long
|
||||
if !spec.positional {
|
||||
name = "--" + spec.long
|
||||
}
|
||||
|
||||
if spec.required {
|
||||
return fmt.Errorf("%s is required", name)
|
||||
}
|
||||
if spec.defaultVal != "" {
|
||||
err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing default value for %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -679,3 +723,15 @@ func findSubcommand(cmds []*command, name string) *command {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isZero returns true if v contains the zero value for its type
|
||||
func isZero(v reflect.Value) bool {
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Slice {
|
||||
return v.IsNil()
|
||||
}
|
||||
if !t.Comparable() {
|
||||
return false
|
||||
}
|
||||
return v.Interface() == reflect.Zero(t).Interface()
|
||||
}
|
||||
|
|
|
@ -1057,3 +1057,80 @@ func TestMultipleTerminates(t *testing.T) {
|
|||
assert.Equal(t, []string{"a", "b"}, args.X)
|
||||
assert.Equal(t, "c", args.Y)
|
||||
}
|
||||
|
||||
func TestDefaultOptionValues(t *testing.T) {
|
||||
var args struct {
|
||||
A int `default:"123"`
|
||||
B *int `default:"123"`
|
||||
C string `default:"abc"`
|
||||
D *string `default:"abc"`
|
||||
E float64 `default:"1.23"`
|
||||
F *float64 `default:"1.23"`
|
||||
G bool `default:"true"`
|
||||
H *bool `default:"true"`
|
||||
}
|
||||
|
||||
err := parse("--c=xyz --e=4.56", &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 123, args.A)
|
||||
assert.Equal(t, 123, *args.B)
|
||||
assert.Equal(t, "xyz", args.C)
|
||||
assert.Equal(t, "abc", *args.D)
|
||||
assert.Equal(t, 4.56, args.E)
|
||||
assert.Equal(t, 1.23, *args.F)
|
||||
assert.True(t, args.G)
|
||||
assert.True(t, args.G)
|
||||
}
|
||||
|
||||
func TestDefaultUnparseable(t *testing.T) {
|
||||
var args struct {
|
||||
A int `default:"x"`
|
||||
}
|
||||
|
||||
err := parse("", &args)
|
||||
assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`)
|
||||
}
|
||||
|
||||
func TestDefaultPositionalValues(t *testing.T) {
|
||||
var args struct {
|
||||
A int `arg:"positional" default:"123"`
|
||||
B *int `arg:"positional" default:"123"`
|
||||
C string `arg:"positional" default:"abc"`
|
||||
D *string `arg:"positional" default:"abc"`
|
||||
E float64 `arg:"positional" default:"1.23"`
|
||||
F *float64 `arg:"positional" default:"1.23"`
|
||||
G bool `arg:"positional" default:"true"`
|
||||
H *bool `arg:"positional" default:"true"`
|
||||
}
|
||||
|
||||
err := parse("456 789", &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 456, args.A)
|
||||
assert.Equal(t, 789, *args.B)
|
||||
assert.Equal(t, "abc", args.C)
|
||||
assert.Equal(t, "abc", *args.D)
|
||||
assert.Equal(t, 1.23, args.E)
|
||||
assert.Equal(t, 1.23, *args.F)
|
||||
assert.True(t, args.G)
|
||||
assert.True(t, args.G)
|
||||
}
|
||||
|
||||
func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
|
||||
var args struct {
|
||||
A int `arg:"required" default:"123"` // required not allowed with default!
|
||||
}
|
||||
|
||||
err := parse("", &args)
|
||||
assert.EqualError(t, err, ".A: 'required' cannot be used when a default value is specified")
|
||||
}
|
||||
|
||||
func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
|
||||
var args struct {
|
||||
A []int `default:"123"` // required not allowed with default!
|
||||
}
|
||||
|
||||
err := parse("", &args)
|
||||
assert.EqualError(t, err, ".A: default values are not supported for slice fields")
|
||||
}
|
||||
|
|
36
usage.go
36
usage.go
|
@ -1,11 +1,9 @@
|
|||
package arg
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -94,7 +92,7 @@ func (p *Parser) writeUsageForCommand(w io.Writer, cmd *command) {
|
|||
fmt.Fprint(w, "\n")
|
||||
}
|
||||
|
||||
func printTwoCols(w io.Writer, left, help string, defaultVal *string) {
|
||||
func printTwoCols(w io.Writer, left, help string, defaultVal string) {
|
||||
lhs := " " + left
|
||||
fmt.Fprint(w, lhs)
|
||||
if help != "" {
|
||||
|
@ -105,8 +103,8 @@ func printTwoCols(w io.Writer, left, help string, defaultVal *string) {
|
|||
}
|
||||
fmt.Fprint(w, help)
|
||||
}
|
||||
if defaultVal != nil {
|
||||
fmt.Fprintf(w, " [default: %s]", *defaultVal)
|
||||
if defaultVal != "" {
|
||||
fmt.Fprintf(w, " [default: %s]", defaultVal)
|
||||
}
|
||||
fmt.Fprint(w, "\n")
|
||||
}
|
||||
|
@ -136,7 +134,7 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
|
|||
if len(positionals) > 0 {
|
||||
fmt.Fprint(w, "\nPositional arguments:\n")
|
||||
for _, spec := range positionals {
|
||||
printTwoCols(w, strings.ToUpper(spec.long), spec.help, nil)
|
||||
printTwoCols(w, strings.ToUpper(spec.long), spec.help, "")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -165,7 +163,7 @@ func (p *Parser) writeHelpForCommand(w io.Writer, cmd *command) {
|
|||
if len(cmd.subcommands) > 0 {
|
||||
fmt.Fprint(w, "\nCommands:\n")
|
||||
for _, subcmd := range cmd.subcommands {
|
||||
printTwoCols(w, subcmd.name, subcmd.help, nil)
|
||||
printTwoCols(w, subcmd.name, subcmd.help, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -175,29 +173,7 @@ func (p *Parser) printOption(w io.Writer, spec *spec) {
|
|||
if spec.short != "" {
|
||||
left += ", " + synopsis(spec, "-"+spec.short)
|
||||
}
|
||||
|
||||
// If spec.dest is not the zero value then a default value has been added.
|
||||
var v reflect.Value
|
||||
if len(spec.dest.fields) > 0 {
|
||||
v = p.val(spec.dest)
|
||||
}
|
||||
|
||||
var defaultVal *string
|
||||
if v.IsValid() {
|
||||
z := reflect.Zero(v.Type())
|
||||
if (v.Type().Comparable() && z.Type().Comparable() && v.Interface() != z.Interface()) || v.Kind() == reflect.Slice && !v.IsNil() {
|
||||
if scalar, ok := v.Interface().(encoding.TextMarshaler); ok {
|
||||
if value, err := scalar.MarshalText(); err != nil {
|
||||
defaultVal = ptrTo(fmt.Sprintf("error: %v", err))
|
||||
} else {
|
||||
defaultVal = ptrTo(fmt.Sprintf("%v", string(value)))
|
||||
}
|
||||
} else {
|
||||
defaultVal = ptrTo(fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
printTwoCols(w, left, spec.help, defaultVal)
|
||||
printTwoCols(w, left, spec.help, spec.defaultVal)
|
||||
}
|
||||
|
||||
func synopsis(spec *spec, form string) string {
|
||||
|
|
|
@ -96,26 +96,37 @@ func (n *MyEnum) MarshalText() ([]byte, error) {
|
|||
return nil, errors.New("There was a problem")
|
||||
}
|
||||
|
||||
func TestUsageError(t *testing.T) {
|
||||
expectedHelp := `Usage: example [--name NAME]
|
||||
func TestUsageWithDefaults(t *testing.T) {
|
||||
expectedHelp := `Usage: example [--label LABEL] [--content CONTENT]
|
||||
|
||||
Options:
|
||||
--name NAME [default: error: There was a problem]
|
||||
--label LABEL [default: cat]
|
||||
--content CONTENT [default: dog]
|
||||
--help, -h display this help and exit
|
||||
`
|
||||
var args struct {
|
||||
Label string
|
||||
Content string `default:"dog"`
|
||||
}
|
||||
args.Label = "cat"
|
||||
p, err := NewParser(Config{"example"}, &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
args.Label = "should_ignore_this"
|
||||
|
||||
var help bytes.Buffer
|
||||
p.WriteHelp(&help)
|
||||
assert.Equal(t, expectedHelp, help.String())
|
||||
}
|
||||
|
||||
func TestUsageCannotMarshalToString(t *testing.T) {
|
||||
var args struct {
|
||||
Name *MyEnum
|
||||
}
|
||||
v := MyEnum(42)
|
||||
args.Name = &v
|
||||
p, err := NewParser(Config{"example"}, &args)
|
||||
|
||||
// NB: some might might expect there to be an error here
|
||||
require.NoError(t, err)
|
||||
|
||||
var help bytes.Buffer
|
||||
p.WriteHelp(&help)
|
||||
assert.Equal(t, expectedHelp, help.String())
|
||||
_, err := NewParser(Config{"example"}, &args)
|
||||
assert.EqualError(t, err, `args.Name: error marshaling default value to string: There was a problem`)
|
||||
}
|
||||
|
||||
func TestUsageLongPositionalWithHelp_legacyForm(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue