diff --git a/example_test.go b/example_test.go index 5645156..26f24e7 100644 --- a/example_test.go +++ b/example_test.go @@ -95,6 +95,19 @@ func Example_mappings() { // output: map[john:123 mary:456] } +// This example demonstrates arguments with keys and values separated by commas +func Example_mappingsWithCommas() { + // The args you would pass in on the command line + os.Args = split("./example --userids john=123 mary=456") + + var args struct { + UserIDs map[string]int + } + MustParse(&args) + fmt.Println(args.UserIDs) + // output: map[john:123 mary:456] +} + // This eample demonstrates multiple value arguments that can be mixed with // other arguments. func Example_multipleMixed() { @@ -130,6 +143,7 @@ func Example_helpText() { // This is only necessary when running inside golang's runnable example harness osExit = func(int) {} + stdout = os.Stdout MustParse(&args) @@ -162,6 +176,7 @@ func Example_helpPlaceholder() { // This is only necessary when running inside golang's runnable example harness osExit = func(int) {} + stdout = os.Stdout MustParse(&args) @@ -202,6 +217,7 @@ func Example_helpTextWithSubcommand() { // This is only necessary when running inside golang's runnable example harness osExit = func(int) {} + stdout = os.Stdout MustParse(&args) @@ -239,6 +255,7 @@ func Example_helpTextForSubcommand() { // This is only necessary when running inside golang's runnable example harness osExit = func(int) {} + stdout = os.Stdout MustParse(&args) diff --git a/parse.go b/parse.go index d357d5c..94c0a89 100644 --- a/parse.go +++ b/parse.go @@ -13,9 +13,6 @@ import ( scalar "github.com/alexflint/go-scalar" ) -// to enable monkey-patching during tests -var osExit = os.Exit - // path represents a sequence of steps to find the output location for an // argument or subcommand in the final destination struct type path struct { @@ -80,7 +77,7 @@ var ErrVersion = errors.New("version requested by user") func MustParse(dest ...interface{}) *Parser { p, err := NewParser(Config{}, dest...) if err != nil { - fmt.Println(err) + fmt.Fprintln(stdout, err) osExit(-1) return nil // just in case osExit was monkey-patched } @@ -88,10 +85,10 @@ func MustParse(dest ...interface{}) *Parser { err = p.Parse(flags()) switch { case err == ErrHelp: - p.writeHelpForCommand(os.Stdout, p.lastCmd) + p.writeHelpForCommand(stdout, p.lastCmd) osExit(0) case err == ErrVersion: - fmt.Println(p.version) + fmt.Fprintln(stdout, p.version) osExit(0) case err != nil: p.failWithCommand(err.Error(), p.lastCmd) @@ -688,15 +685,7 @@ func (p *Parser) val(dest path) reflect.Value { v = v.Elem() } - next := v.FieldByIndex(field.Index) - if !next.IsValid() { - // it is appropriate to panic here because this can only happen due to - // an internal bug in this library (since we construct the path ourselves - // by reflecting on the same struct) - panic(fmt.Errorf("error resolving path %v: %v has no field named %v", - dest.fields, v.Type(), field)) - } - v = next + v = v.FieldByIndex(field.Index) } return v } @@ -723,15 +712,3 @@ func findSubcommand(cmds []*command, name string) *command { } return nil } - -// isZero returns true if v contains the zero value for its type -func isZero(v reflect.Value) bool { - t := v.Type() - if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { - return v.IsNil() - } - if !t.Comparable() { - return false - } - return v.Interface() == reflect.Zero(t).Interface() -} diff --git a/parse_test.go b/parse_test.go index d03cbfd..09fb508 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1,6 +1,8 @@ package arg import ( + "bytes" + "fmt" "net" "net/mail" "os" @@ -24,14 +26,34 @@ func parse(cmdline string, dest interface{}) error { } func pparse(cmdline string, dest interface{}) (*Parser, error) { + return parseWithEnv(cmdline, nil, dest) +} + +func parseWithEnv(cmdline string, env []string, dest interface{}) (*Parser, error) { p, err := NewParser(Config{}, dest) if err != nil { return nil, err } + + // split the command line var parts []string if len(cmdline) > 0 { parts = strings.Split(cmdline, " ") } + + // split the environment vars + for _, s := range env { + pos := strings.Index(s, "=") + if pos == -1 { + return nil, fmt.Errorf("missing equals sign in %q", s) + } + err := os.Setenv(s[:pos], s[pos+1:]) + if err != nil { + return nil, err + } + } + + // execute the parser return p, p.Parse(parts) } @@ -461,7 +483,7 @@ func TestMissingValueAtEnd(t *testing.T) { assert.Error(t, err) } -func TestMissingValueInMIddle(t *testing.T) { +func TestMissingValueInMiddle(t *testing.T) { var args struct { Foo string Bar string @@ -546,6 +568,14 @@ func TestNoMoreOptions(t *testing.T) { assert.Equal(t, []string{"abc", "--foo", "xyz"}, args.Bar) } +func TestNoMoreOptionsBeforeHelp(t *testing.T) { + var args struct { + Foo int + } + err := parse("not_an_integer -- --help", &args) + assert.NotEqual(t, ErrHelp, err) +} + func TestHelpFlag(t *testing.T) { var args struct { Foo string @@ -633,9 +663,8 @@ func TestEnvironmentVariable(t *testing.T) { var args struct { Foo string `arg:"env"` } - setenv(t, "FOO", "bar") - os.Args = []string{"example"} - MustParse(&args) + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -643,8 +672,8 @@ func TestEnvironmentVariableNotPresent(t *testing.T) { var args struct { NotPresent string `arg:"env"` } - os.Args = []string{"example"} - MustParse(&args) + _, err := parseWithEnv("", nil, &args) + require.NoError(t, err) assert.Equal(t, "", args.NotPresent) } @@ -652,9 +681,8 @@ func TestEnvironmentVariableOverrideName(t *testing.T) { var args struct { Foo string `arg:"env:BAZ"` } - setenv(t, "BAZ", "bar") - os.Args = []string{"example"} - MustParse(&args) + _, err := parseWithEnv("", []string{"BAZ=bar"}, &args) + require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -662,19 +690,16 @@ func TestEnvironmentVariableOverrideArgument(t *testing.T) { var args struct { Foo string `arg:"env"` } - setenv(t, "FOO", "bar") - os.Args = []string{"example", "--foo", "baz"} - MustParse(&args) - assert.Equal(t, "baz", args.Foo) + _, err := parseWithEnv("--foo zzz", []string{"FOO=bar"}, &args) + require.NoError(t, err) + assert.Equal(t, "zzz", args.Foo) } func TestEnvironmentVariableError(t *testing.T) { var args struct { Foo int `arg:"env"` } - setenv(t, "FOO", "bar") - os.Args = []string{"example"} - err := Parse(&args) + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) assert.Error(t, err) } @@ -682,9 +707,8 @@ func TestEnvironmentVariableRequired(t *testing.T) { var args struct { Foo string `arg:"env,required"` } - setenv(t, "FOO", "bar") - os.Args = []string{"example"} - MustParse(&args) + _, err := parseWithEnv("", []string{"FOO=bar"}, &args) + require.NoError(t, err) assert.Equal(t, "bar", args.Foo) } @@ -692,8 +716,8 @@ func TestEnvironmentVariableSliceArgumentString(t *testing.T) { var args struct { Foo []string `arg:"env"` } - setenv(t, "FOO", `bar,"baz, qux"`) - MustParse(&args) + _, err := parseWithEnv("", []string{`FOO=bar,"baz, qux"`}, &args) + require.NoError(t, err) assert.Equal(t, []string{"bar", "baz, qux"}, args.Foo) } @@ -701,8 +725,8 @@ func TestEnvironmentVariableSliceArgumentInteger(t *testing.T) { var args struct { Foo []int `arg:"env"` } - setenv(t, "FOO", "1,99") - MustParse(&args) + _, err := parseWithEnv("", []string{`FOO=1,99`}, &args) + require.NoError(t, err) assert.Equal(t, []int{1, 99}, args.Foo) } @@ -710,8 +734,8 @@ func TestEnvironmentVariableSliceArgumentFloat(t *testing.T) { var args struct { Foo []float32 `arg:"env"` } - setenv(t, "FOO", "1.1,99.9") - MustParse(&args) + _, err := parseWithEnv("", []string{`FOO=1.1,99.9`}, &args) + require.NoError(t, err) assert.Equal(t, []float32{1.1, 99.9}, args.Foo) } @@ -719,8 +743,8 @@ func TestEnvironmentVariableSliceArgumentBool(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - setenv(t, "FOO", "true,false,0,1") - MustParse(&args) + _, err := parseWithEnv("", []string{`FOO=true,false,0,1`}, &args) + require.NoError(t, err) assert.Equal(t, []bool{true, false, false, true}, args.Foo) } @@ -728,8 +752,7 @@ func TestEnvironmentVariableSliceArgumentWrongCsv(t *testing.T) { var args struct { Foo []int `arg:"env"` } - setenv(t, "FOO", "1,99\"") - err := Parse(&args) + _, err := parseWithEnv("", []string{`FOO=1,99\"`}, &args) assert.Error(t, err) } @@ -737,8 +760,7 @@ func TestEnvironmentVariableSliceArgumentWrongType(t *testing.T) { var args struct { Foo []bool `arg:"env"` } - setenv(t, "FOO", "one,two") - err := Parse(&args) + _, err := parseWithEnv("", []string{`FOO=one,two`}, &args) assert.Error(t, err) } @@ -746,8 +768,8 @@ func TestEnvironmentVariableMap(t *testing.T) { var args struct { Foo map[int]string `arg:"env"` } - setenv(t, "FOO", "1=one,99=ninetynine") - MustParse(&args) + _, err := parseWithEnv("", []string{`FOO=1=one,99=ninetynine`}, &args) + require.NoError(t, err) assert.Len(t, args.Foo, 2) assert.Equal(t, "one", args.Foo[1]) assert.Equal(t, "ninetynine", args.Foo[99]) @@ -1299,3 +1321,70 @@ func TestUnexportedFieldsSkipped(t *testing.T) { _, err := NewParser(Config{}, &args) require.NoError(t, err) } + +func TestMustParseInvalidParser(t *testing.T) { + originalExit := osExit + originalStdout := stdout + defer func() { + osExit = originalExit + stdout = originalStdout + }() + + var exitCode int + osExit = func(code int) { exitCode = code } + stdout = &bytes.Buffer{} + + var args struct { + CannotParse struct{} + } + parser := MustParse(&args) + assert.Nil(t, parser) + assert.Equal(t, -1, exitCode) +} + +func TestMustParsePrintsHelp(t *testing.T) { + originalExit := osExit + originalStdout := stdout + originalArgs := os.Args + defer func() { + osExit = originalExit + stdout = originalStdout + os.Args = originalArgs + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + os.Args = []string{"someprogram", "--help"} + stdout = &bytes.Buffer{} + + var args struct{} + parser := MustParse(&args) + assert.NotNil(t, parser) + require.NotNil(t, exitCode) + assert.Equal(t, 0, *exitCode) +} + +func TestMustParsePrintsVersion(t *testing.T) { + originalExit := osExit + originalStdout := stdout + originalArgs := os.Args + defer func() { + osExit = originalExit + stdout = originalStdout + os.Args = originalArgs + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + os.Args = []string{"someprogram", "--version"} + + var b bytes.Buffer + stdout = &b + + var args versioned + parser := MustParse(&args) + require.NotNil(t, parser) + require.NotNil(t, exitCode) + assert.Equal(t, 0, *exitCode) + assert.Equal(t, "example 3.2.1\n", b.String()) +} diff --git a/reflect.go b/reflect.go index 1806973..c719b52 100644 --- a/reflect.go +++ b/reflect.go @@ -94,3 +94,15 @@ func isExported(field string) bool { r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8 return unicode.IsLetter(r) && unicode.IsUpper(r) } + +// isZero returns true if v contains the zero value for its type +func isZero(v reflect.Value) bool { + t := v.Type() + if t.Kind() == reflect.Slice || t.Kind() == reflect.Map { + return v.IsNil() + } + if !t.Comparable() { + return false + } + return v.Interface() == reflect.Zero(t).Interface() +} diff --git a/reflect_test.go b/reflect_test.go index 8d65fd9..10909b3 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -89,3 +89,24 @@ func TestCardinalityString(t *testing.T) { assert.Equal(t, "unsupported", unsupported.String()) assert.Equal(t, "unknown(42)", cardinality(42).String()) } + +func TestIsZero(t *testing.T) { + var zero int + var notZero = 3 + var nilSlice []int + var nonNilSlice = []int{1, 2, 3} + var nilMap map[string]string + var nonNilMap = map[string]string{"foo": "bar"} + var uncomparable = func() {} + + assert.True(t, isZero(reflect.ValueOf(zero))) + assert.False(t, isZero(reflect.ValueOf(notZero))) + + assert.True(t, isZero(reflect.ValueOf(nilSlice))) + assert.False(t, isZero(reflect.ValueOf(nonNilSlice))) + + assert.True(t, isZero(reflect.ValueOf(nilMap))) + assert.False(t, isZero(reflect.ValueOf(nonNilMap))) + + assert.False(t, isZero(reflect.ValueOf(uncomparable))) +} diff --git a/subcommand_test.go b/subcommand_test.go index c34ab01..2c61dd3 100644 --- a/subcommand_test.go +++ b/subcommand_test.go @@ -1,6 +1,7 @@ package arg import ( + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -48,6 +49,17 @@ func TestMinimalSubcommand(t *testing.T) { assert.Equal(t, []string{"list"}, p.SubcommandNames()) } +func TestSubcommandNamesBeforeParsing(t *testing.T) { + type listCmd struct{} + var args struct { + List *listCmd `arg:"subcommand"` + } + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + assert.Nil(t, p.Subcommand()) + assert.Nil(t, p.SubcommandNames()) +} + func TestNoSuchSubcommand(t *testing.T) { type listCmd struct { } @@ -179,6 +191,36 @@ func TestSubcommandsWithOptions(t *testing.T) { } } +func TestSubcommandsWithEnvVars(t *testing.T) { + type getCmd struct { + Name string `arg:"env"` + } + type listCmd struct { + Limit int `arg:"env"` + } + type cmd struct { + Verbose bool + Get *getCmd `arg:"subcommand"` + List *listCmd `arg:"subcommand"` + } + + { + var args cmd + setenv(t, "LIMIT", "123") + err := parse("list", &args) + require.NoError(t, err) + require.NotNil(t, args.List) + assert.Equal(t, 123, args.List.Limit) + } + + { + var args cmd + setenv(t, "LIMIT", "not_an_integer") + err := parse("list", &args) + assert.Error(t, err) + } +} + func TestNestedSubcommands(t *testing.T) { type child struct{} type parent struct { @@ -353,3 +395,19 @@ func TestSubcommandsWithMultiplePositionals(t *testing.T) { assert.Equal(t, 5, args.Limit) } } + +func TestValForNilStruct(t *testing.T) { + type subcmd struct{} + var cmd struct { + Sub *subcmd `arg:"subcommand"` + } + + p, err := NewParser(Config{}, &cmd) + require.NoError(t, err) + + typ := reflect.TypeOf(cmd) + subField, _ := typ.FieldByName("Sub") + + v := p.val(path{fields: []reflect.StructField{subField, subField}}) + assert.False(t, v.IsValid()) +} diff --git a/usage.go b/usage.go index 231476b..c121c45 100644 --- a/usage.go +++ b/usage.go @@ -11,7 +11,11 @@ import ( const colWidth = 25 // to allow monkey patching in tests -var stderr = os.Stderr +var ( + stdout io.Writer = os.Stdout + stderr io.Writer = os.Stderr + osExit = os.Exit +) // Fail prints usage information to stderr and exits with non-zero status func (p *Parser) Fail(msg string) { diff --git a/usage_test.go b/usage_test.go index 6dee402..1b6c475 100644 --- a/usage_test.go +++ b/usage_test.go @@ -33,9 +33,10 @@ func (n *NameDotName) MarshalText() (text []byte, err error) { } func TestWriteUsage(t *testing.T) { - expectedUsage := "Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]]\n" + expectedUsage := "Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]]" - expectedHelp := `Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]] + expectedHelp := ` +Usage: example [--name NAME] [--value VALUE] [--verbose] [--dataset DATASET] [--optimize OPTIMIZE] [--ids IDS] [--values VALUES] [--workers WORKERS] [--testenv TESTENV] [--file FILE] INPUT [OUTPUT [OUTPUT ...]] Positional arguments: INPUT @@ -56,6 +57,7 @@ Options: --file FILE, -f FILE File with mandatory extension [default: scratch.txt] --help, -h display this help and exit ` + var args struct { Input string `arg:"positional"` Output []string `arg:"positional" help:"list of outputs"` @@ -79,13 +81,13 @@ Options: os.Args[0] = "example" - var usage bytes.Buffer - p.WriteUsage(&usage) - assert.Equal(t, expectedUsage, usage.String()) - var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } type MyEnum int @@ -99,7 +101,10 @@ func (n *MyEnum) MarshalText() ([]byte, error) { } func TestUsageWithDefaults(t *testing.T) { - expectedHelp := `Usage: example [--label LABEL] [--content CONTENT] + expectedUsage := "Usage: example [--label LABEL] [--content CONTENT]" + + expectedHelp := ` +Usage: example [--label LABEL] [--content CONTENT] Options: --label LABEL [default: cat] @@ -118,7 +123,11 @@ Options: var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageCannotMarshalToString(t *testing.T) { @@ -132,7 +141,10 @@ func TestUsageCannotMarshalToString(t *testing.T) { } func TestUsageLongPositionalWithHelp_legacyForm(t *testing.T) { - expectedHelp := `Usage: example VERYLONGPOSITIONALWITHHELP + expectedUsage := "Usage: example VERYLONGPOSITIONALWITHHELP" + + expectedHelp := ` +Usage: example VERYLONGPOSITIONALWITHHELP Positional arguments: VERYLONGPOSITIONALWITHHELP @@ -145,17 +157,23 @@ Options: VeryLongPositionalWithHelp string `arg:"positional,help:this positional argument is very long but cannot include commas"` } - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Program: "example"}, &args) require.NoError(t, err) - os.Args[0] = "example" var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageLongPositionalWithHelp_newForm(t *testing.T) { - expectedHelp := `Usage: example VERYLONGPOSITIONALWITHHELP + expectedUsage := "Usage: example VERYLONGPOSITIONALWITHHELP" + + expectedHelp := ` +Usage: example VERYLONGPOSITIONALWITHHELP Positional arguments: VERYLONGPOSITIONALWITHHELP @@ -168,17 +186,23 @@ Options: VeryLongPositionalWithHelp string `arg:"positional" help:"this positional argument is very long, and includes: commas, colons etc"` } - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Program: "example"}, &args) require.NoError(t, err) - os.Args[0] = "example" var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageWithProgramName(t *testing.T) { - expectedHelp := `Usage: myprogram + expectedUsage := "Usage: myprogram" + + expectedHelp := ` +Usage: myprogram Options: --help, -h display this help and exit @@ -190,9 +214,14 @@ Options: require.NoError(t, err) os.Args[0] = "example" + var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } type versioned struct{} @@ -203,7 +232,10 @@ func (versioned) Version() string { } func TestUsageWithVersion(t *testing.T) { - expectedHelp := `example 3.2.1 + expectedUsage := "example 3.2.1\nUsage: example" + + expectedHelp := ` +example 3.2.1 Usage: example Options: @@ -216,12 +248,11 @@ Options: var help bytes.Buffer p.WriteHelp(&help) - actual := help.String() - if expectedHelp != actual { - t.Logf("Expected:\n%s", expectedHelp) - t.Logf("Actual:\n%s", actual) - t.Fail() - } + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } type described struct{} @@ -232,7 +263,10 @@ func (described) Description() string { } func TestUsageWithDescription(t *testing.T) { - expectedHelp := `this program does this and that + expectedUsage := "Usage: example" + + expectedHelp := ` +this program does this and that Usage: example Options: @@ -244,16 +278,18 @@ Options: var help bytes.Buffer p.WriteHelp(&help) - actual := help.String() - if expectedHelp != actual { - t.Logf("Expected:\n%s", expectedHelp) - t.Logf("Actual:\n%s", actual) - t.Fail() - } + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestRequiredMultiplePositionals(t *testing.T) { - expectedHelp := `Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...] + expectedUsage := "Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...]" + + expectedHelp := ` +Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...] Positional arguments: REQUIREDMULTIPLE required multiple positional @@ -270,11 +306,18 @@ Options: var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageWithNestedSubcommands(t *testing.T) { - expectedHelp := `Usage: example child nested [--enable] OUTPUT + expectedUsage := "Usage: example child nested [--enable] OUTPUT" + + expectedHelp := ` +Usage: example child nested [--enable] OUTPUT Positional arguments: OUTPUT @@ -307,11 +350,18 @@ Global options: var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageWithoutLongNames(t *testing.T) { - expectedHelp := `Usage: example [-a PLACEHOLDER] -b SHORTONLY2 + expectedUsage := "Usage: example [-a PLACEHOLDER] -b SHORTONLY2" + + expectedHelp := ` +Usage: example [-a PLACEHOLDER] -b SHORTONLY2 Options: -a PLACEHOLDER some help [default: some val] @@ -324,13 +374,21 @@ Options: } p, err := NewParser(Config{Program: "example"}, &args) assert.NoError(t, err) + var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageWithShortFirst(t *testing.T) { - expectedHelp := `Usage: example [-c CAT] [--dog DOG] + expectedUsage := "Usage: example [-c CAT] [--dog DOG]" + + expectedHelp := ` +Usage: example [-c CAT] [--dog DOG] Options: -c CAT @@ -343,13 +401,21 @@ Options: } p, err := NewParser(Config{Program: "example"}, &args) assert.NoError(t, err) + var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) } func TestUsageWithEnvOptions(t *testing.T) { - expectedHelp := `Usage: example [-s SHORT] + expectedUsage := "Usage: example [-s SHORT]" + + expectedHelp := ` +Usage: example [-s SHORT] Options: -s SHORT [env: SHORT] @@ -363,7 +429,42 @@ Options: p, err := NewParser(Config{Program: "example"}, &args) assert.NoError(t, err) + var help bytes.Buffer p.WriteHelp(&help) - assert.Equal(t, expectedHelp, help.String()) + assert.Equal(t, expectedHelp[1:], help.String()) + + var usage bytes.Buffer + p.WriteUsage(&usage) + assert.Equal(t, expectedUsage, strings.TrimSpace(usage.String())) +} + +func TestFail(t *testing.T) { + originalStderr := stderr + originalExit := osExit + defer func() { + stderr = originalStderr + osExit = originalExit + }() + + var b bytes.Buffer + stderr = &b + + var exitCode int + osExit = func(code int) { exitCode = code } + + expectedStdout := ` +Usage: example [--foo FOO] +error: something went wrong +` + + var args struct { + Foo int + } + p, err := NewParser(Config{Program: "example"}, &args) + require.NoError(t, err) + p.Fail("something went wrong") + + assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, -1, exitCode) }