diff --git a/example_test.go b/example_test.go index 0e21589..faca6b1 100644 --- a/example_test.go +++ b/example_test.go @@ -162,8 +162,7 @@ func Example_helpText() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -195,8 +194,7 @@ func Example_helpPlaceholder() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -236,8 +234,7 @@ func Example_helpTextWithSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -276,11 +273,8 @@ func Example_helpTextWithGroups() { Quiet bool `arg:"-q" help:"Quiet"` // this flag is global to all subcommands } - MustParse(&args) - // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -324,8 +318,7 @@ func Example_helpTextWhenUsingSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -361,10 +354,9 @@ func Example_writeHelpForSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + exit := func(int) {} - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Exit: exit}, &args) if err != nil { fmt.Println(err) os.Exit(1) @@ -410,10 +402,9 @@ func Example_writeHelpForSubcommandNested() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + exit := func(int) {} - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Exit: exit}, &args) if err != nil { fmt.Println(err) os.Exit(1) @@ -447,8 +438,7 @@ func Example_errorText() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -471,8 +461,7 @@ func Example_errorTextForSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -507,8 +496,7 @@ func Example_subcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) diff --git a/go.sum b/go.sum index 5b536f9..385ca8f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/alexflint/go-scalar v1.1.0 h1:aaAouLLzI9TChcPXotr6gUhq+Scr8rl0P9P4PnltbhM= -github.com/alexflint/go-scalar v1.1.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/parse.go b/parse.go index f2fc7bc..b54fbeb 100644 --- a/parse.go +++ b/parse.go @@ -5,6 +5,7 @@ import ( "encoding/csv" "errors" "fmt" + "io" "os" "path/filepath" "reflect" @@ -86,13 +87,28 @@ var ErrHelp = errors.New("help requested by user") // ErrVersion indicates that --version was provided var ErrVersion = errors.New("version requested by user") +// for monkey patching in example code +var mustParseExit = os.Exit + // MustParse processes command line arguments and exits upon failure func MustParse(dest ...interface{}) *Parser { - p, err := NewParser(Config{}, dest...) + return mustParse(Config{Exit: mustParseExit}, dest...) +} + +// mustParse is a helper that facilitates testing +func mustParse(config Config, dest ...interface{}) *Parser { + if config.Exit == nil { + config.Exit = os.Exit + } + if config.Out == nil { + config.Out = os.Stdout + } + + p, err := NewParser(config, dest...) if err != nil { - fmt.Fprintln(stdout, err) - osExit(-1) - return nil // just in case osExit was monkey-patched + fmt.Fprintln(config.Out, err) + config.Exit(-1) + return nil } p.MustParse(flags()) @@ -127,6 +143,16 @@ type Config struct { // IgnoreDefault instructs the library not to reset the variables to the // default values, including pointers to sub commands IgnoreDefault bool + + // StrictSubcommands intructs the library not to allow global commands after + // subcommand + StrictSubcommands bool + + // Exit is called to terminate the process with an error code (defaults to os.Exit) + Exit func(int) + + // Out is where help text, usage text, and failure messages are printed (defaults to os.Stdout) + Out io.Writer } // Parser represents a set of command line options with destination values @@ -189,6 +215,14 @@ func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { + // fill in defaults + if config.Exit == nil { + config.Exit = os.Exit + } + if config.Out == nil { + config.Out = os.Stdout + } + // first pick a name for the command for use in the usage text var name string switch { @@ -531,11 +565,11 @@ func (p *Parser) MustParse(args []string) { err := p.Parse(args) switch { case err == ErrHelp: - p.writeHelpForSubcommand(stdout, p.lastCmd) - osExit(0) + p.writeHelpForSubcommand(p.config.Out, p.lastCmd) + p.config.Exit(0) case err == ErrVersion: - fmt.Fprintln(stdout, p.version) - osExit(0) + fmt.Fprintln(p.config.Out, p.version) + p.config.Exit(0) case err != nil: p.failWithSubcommand(err.Error(), p.lastCmd) } @@ -636,7 +670,12 @@ func (p *Parser) process(args []string) error { p.val(subcmd.dest) // add the new options to the set of allowed options - specs = append(specs, subcmd.specs()...) + if p.config.StrictSubcommands { + specs = make([]*spec, len(subcmd.specs())) + copy(specs, subcmd.specs()) + } else { + specs = append(specs, subcmd.specs()...) + } // capture environment vars for these new options if !p.config.IgnoreEnv { diff --git a/parse_test.go b/parse_test.go index fe3197a..ab63cb5 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1198,26 +1198,18 @@ func TestParserMustParse(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - originalExit := osExit - originalStdout := stdout - defer func() { - osExit = originalExit - stdout = originalStdout - }() + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } - var exitCode *int - osExit = func(code int) { exitCode = &code } - var b bytes.Buffer - stdout = &b - - p, err := NewParser(Config{}, &tt.args) + p, err := NewParser(Config{Exit: exit, Out: &stdout}, &tt.args) require.NoError(t, err) assert.NotNil(t, p) p.MustParse(tt.cmdLine) assert.NotNil(t, exitCode) - assert.Equal(t, tt.code, *exitCode) - assert.Contains(t, b.String(), tt.output) + assert.Equal(t, tt.code, exitCode) + assert.Contains(t, stdout.String(), tt.output) }) } } @@ -1808,70 +1800,53 @@ func TestUnexportedFieldsSkipped(t *testing.T) { } func TestMustParseInvalidParser(t *testing.T) { - originalExit := osExit - originalStdout := stdout - defer func() { - osExit = originalExit - stdout = originalStdout - }() - var exitCode int - osExit = func(code int) { exitCode = code } - stdout = &bytes.Buffer{} + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } var args struct { CannotParse struct{} } - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) assert.Nil(t, parser) assert.Equal(t, -1, exitCode) } func TestMustParsePrintsHelp(t *testing.T) { - originalExit := osExit - originalStdout := stdout originalArgs := os.Args defer func() { - osExit = originalExit - stdout = originalStdout os.Args = originalArgs }() - var exitCode *int - osExit = func(code int) { exitCode = &code } os.Args = []string{"someprogram", "--help"} - stdout = &bytes.Buffer{} + + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } var args struct{} - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) assert.NotNil(t, parser) - require.NotNil(t, exitCode) - assert.Equal(t, 0, *exitCode) + assert.Equal(t, 0, exitCode) } func TestMustParsePrintsVersion(t *testing.T) { - originalExit := osExit - originalStdout := stdout originalArgs := os.Args defer func() { - osExit = originalExit - stdout = originalStdout os.Args = originalArgs }() - var exitCode *int - osExit = func(code int) { exitCode = &code } + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } + os.Args = []string{"someprogram", "--version"} - var b bytes.Buffer - stdout = &b - var args versioned - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) require.NotNil(t, parser) - require.NotNil(t, exitCode) - assert.Equal(t, 0, *exitCode) - assert.Equal(t, "example 3.2.1\n", b.String()) + assert.Equal(t, 0, exitCode) + assert.Equal(t, "example 3.2.1\n", stdout.String()) } type mapWithUnmarshalText struct { @@ -1938,3 +1913,79 @@ func TestTextMarshalerUnmarshalerEmptyPointer(t *testing.T) { require.NoError(t, err) assert.Nil(t, args.Config) } + +func TestSubcommandGlobalFlag_Before(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: false}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"-g", "sub"}) + assert.NoError(t, err) + assert.True(t, args.Global) +} + +func TestSubcommandGlobalFlag_InCommand(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: false}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub", "-g"}) + assert.NoError(t, err) + assert.True(t, args.Global) +} + +func TestSubcommandGlobalFlag_Before_Strict(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"-g", "sub"}) + assert.NoError(t, err) + assert.True(t, args.Global) +} + +func TestSubcommandGlobalFlag_InCommand_Strict(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub", "-g"}) + assert.Error(t, err) +} + +func TestSubcommandGlobalFlag_InCommand_Strict_Inner(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + Guard bool `arg:"-g"` + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub", "-g"}) + assert.NoError(t, err) + assert.False(t, args.Global) + assert.True(t, args.Sub.Guard) +} diff --git a/usage.go b/usage.go index b18e6f2..a595346 100644 --- a/usage.go +++ b/usage.go @@ -3,20 +3,12 @@ package arg import ( "fmt" "io" - "os" "strings" ) // the width of the left column const colWidth = 25 -// to allow monkey patching in tests -var ( - stdout io.Writer = os.Stdout - stderr io.Writer = os.Stderr - osExit = os.Exit -) - // Fail prints usage information to stderr and exits with non-zero status func (p *Parser) Fail(msg string) { p.failWithSubcommand(msg, p.cmd) @@ -39,9 +31,9 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status func (p *Parser) failWithSubcommand(msg string, cmd *command) { - p.writeUsageForSubcommand(stderr, cmd) - fmt.Fprintln(stderr, "error:", msg) - osExit(-1) + p.writeUsageForSubcommand(p.config.Out, cmd) + fmt.Fprintln(p.config.Out, "error:", msg) + p.config.Exit(-1) } // WriteUsage writes usage information to the given writer diff --git a/usage_test.go b/usage_test.go index 7add765..3ba5277 100644 --- a/usage_test.go +++ b/usage_test.go @@ -770,18 +770,9 @@ Options: } func TestFail(t *testing.T) { - originalStderr := stderr - originalExit := osExit - defer func() { - stderr = originalStderr - osExit = originalExit - }() - - var b bytes.Buffer - stderr = &b - + var stdout bytes.Buffer var exitCode int - osExit = func(code int) { exitCode = code } + exit := func(code int) { exitCode = code } expectedStdout := ` Usage: example [--foo FOO] @@ -791,27 +782,18 @@ error: something went wrong var args struct { Foo int } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args) require.NoError(t, err) p.Fail("something went wrong") - assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, expectedStdout[1:], stdout.String()) assert.Equal(t, -1, exitCode) } func TestFailSubcommand(t *testing.T) { - originalStderr := stderr - originalExit := osExit - defer func() { - stderr = originalStderr - osExit = originalExit - }() - - var b bytes.Buffer - stderr = &b - + var stdout bytes.Buffer var exitCode int - osExit = func(code int) { exitCode = code } + exit := func(code int) { exitCode = code } expectedStdout := ` Usage: example sub @@ -821,13 +803,13 @@ error: something went wrong var args struct { Sub *struct{} `arg:"subcommand"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args) require.NoError(t, err) err = p.FailSubcommand("something went wrong", "sub") require.NoError(t, err) - assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, expectedStdout[1:], stdout.String()) assert.Equal(t, -1, exitCode) }