From cef66fd2f6e0e061ade3778d3d3868032e4f0a32 Mon Sep 17 00:00:00 2001 From: Alexey Trofimov Date: Wed, 18 Jan 2023 11:50:50 +0300 Subject: [PATCH 1/4] add strict subcommand parsing --- go.sum | 2 -- parse.go | 11 ++++++- parse_test.go | 82 +++++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/go.sum b/go.sum index 5b536f9..385ca8f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/alexflint/go-scalar v1.1.0 h1:aaAouLLzI9TChcPXotr6gUhq+Scr8rl0P9P4PnltbhM= -github.com/alexflint/go-scalar v1.1.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/parse.go b/parse.go index 6ac5e99..3efe858 100644 --- a/parse.go +++ b/parse.go @@ -115,6 +115,10 @@ type Config struct { // IgnoreDefault instructs the library not to reset the variables to the // default values, including pointers to sub commands IgnoreDefault bool + + // IgnoreDefault intructs the library not to allow global commands after + // subcommand + StrictSubcommands bool } // Parser represents a set of command line options with destination values @@ -588,7 +592,12 @@ func (p *Parser) process(args []string) error { } // add the new options to the set of allowed options - specs = append(specs, subcmd.specs...) + if p.config.StrictSubcommands { + specs = make([]*spec, len(subcmd.specs)) + copy(specs, subcmd.specs) + } else { + specs = append(specs, subcmd.specs...) + } // capture environment vars for these new options if !p.config.IgnoreEnv { diff --git a/parse_test.go b/parse_test.go index 5d38306..64119a8 100644 --- a/parse_test.go +++ b/parse_test.go @@ -98,9 +98,9 @@ func TestInt(t *testing.T) { func TestHexOctBin(t *testing.T) { var args struct { - Hex int - Oct int - Bin int + Hex int + Oct int + Bin int Underscored int } err := parse("--hex 0xA --oct 0o10 --bin 0b101 --underscored 123_456", &args) @@ -1614,3 +1614,79 @@ func TestTextMarshalerUnmarshalerEmptyPointer(t *testing.T) { require.NoError(t, err) assert.Nil(t, args.Config) } + +func TestSubcommandGlobalFlag_Before(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: false}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"-g", "sub"}) + assert.NoError(t, err) + assert.True(t, args.Global) +} + +func TestSubcommandGlobalFlag_InCommand(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: false}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub", "-g"}) + assert.NoError(t, err) + assert.True(t, args.Global) +} + +func TestSubcommandGlobalFlag_Before_Strict(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"-g", "sub"}) + assert.NoError(t, err) + assert.True(t, args.Global) +} + +func TestSubcommandGlobalFlag_InCommand_Strict(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub", "-g"}) + assert.Error(t, err) +} + +func TestSubcommandGlobalFlag_InCommand_Strict_Inner(t *testing.T) { + var args struct { + Global bool `arg:"-g"` + Sub *struct { + Guard bool `arg:"-g"` + } `arg:"subcommand"` + } + + p, err := NewParser(Config{StrictSubcommands: true}, &args) + require.NoError(t, err) + + err = p.Parse([]string{"sub", "-g"}) + assert.NoError(t, err) + assert.False(t, args.Global) + assert.True(t, args.Sub.Guard) +} From 5036dce2d6a64b2dd1b6e270947bde1e8110708c Mon Sep 17 00:00:00 2001 From: Alexey Trofimov Date: Wed, 18 Jan 2023 11:52:13 +0300 Subject: [PATCH 2/4] fix typo --- parse.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parse.go b/parse.go index 3efe858..b935fad 100644 --- a/parse.go +++ b/parse.go @@ -116,7 +116,7 @@ type Config struct { // default values, including pointers to sub commands IgnoreDefault bool - // IgnoreDefault intructs the library not to allow global commands after + // StrictSubcommands intructs the library not to allow global commands after // subcommand StrictSubcommands bool } From efae1938fd6c8434532ca7527cd90752e558d377 Mon Sep 17 00:00:00 2001 From: duxinlong Date: Wed, 8 Feb 2023 12:01:48 +0000 Subject: [PATCH 3/4] feat: support more env than terminal Change-Id: I7f35e90b8f19f4ea781832885d35e2f1e275207a --- parse.go | 31 +++++++++++++++++++++++++++---- usage.go | 6 +++--- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/parse.go b/parse.go index b935fad..6d8b509 100644 --- a/parse.go +++ b/parse.go @@ -5,6 +5,7 @@ import ( "encoding/csv" "errors" "fmt" + "io" "os" "path/filepath" "reflect" @@ -119,6 +120,10 @@ type Config struct { // StrictSubcommands intructs the library not to allow global commands after // subcommand StrictSubcommands bool + + OsExit func(int) + Stdout io.Writer + Stderr io.Writer } // Parser represents a set of command line options with destination values @@ -132,6 +137,10 @@ type Parser struct { // the following field changes during processing of command line arguments lastCmd *command + + osExit func(int) + stdout io.Writer + stderr io.Writer } // Versioned is the interface that the destination struct should implement to @@ -196,6 +205,20 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { p := Parser{ cmd: &command{name: name}, config: config, + + osExit: osExit, + stdout: stdout, + stderr: stderr, + } + + if config.OsExit != nil { + p.osExit = config.OsExit + } + if config.Stdout != nil { + p.stdout = config.Stdout + } + if config.Stderr != nil { + p.stderr = config.Stderr } // make a list of roots @@ -483,11 +506,11 @@ func (p *Parser) MustParse(args []string) { err := p.Parse(args) switch { case err == ErrHelp: - p.writeHelpForSubcommand(stdout, p.lastCmd) - osExit(0) + p.writeHelpForSubcommand(p.stdout, p.lastCmd) + p.osExit(0) case err == ErrVersion: - fmt.Fprintln(stdout, p.version) - osExit(0) + fmt.Fprintln(p.stdout, p.version) + p.osExit(0) case err != nil: p.failWithSubcommand(err.Error(), p.lastCmd) } diff --git a/usage.go b/usage.go index 7a480c3..80eba45 100644 --- a/usage.go +++ b/usage.go @@ -39,9 +39,9 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status func (p *Parser) failWithSubcommand(msg string, cmd *command) { - p.writeUsageForSubcommand(stderr, cmd) - fmt.Fprintln(stderr, "error:", msg) - osExit(-1) + p.writeUsageForSubcommand(p.stderr, cmd) + fmt.Fprintln(p.stderr, "error:", msg) + p.osExit(-1) } // WriteUsage writes usage information to the given writer From df28e7154bbab76436bc59e5dc67fb6d6824fc62 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Wed, 8 Feb 2023 09:49:03 -0500 Subject: [PATCH 4/4] clean up customizable stdout, stderr, and exit in parser config --- example_test.go | 31 ++++++++-------------- parse.go | 65 +++++++++++++++++++++++++--------------------- parse_test.go | 69 ++++++++++++++++--------------------------------- usage.go | 14 +++------- usage_test.go | 34 ++++++------------------ 5 files changed, 80 insertions(+), 133 deletions(-) diff --git a/example_test.go b/example_test.go index fd64777..5272393 100644 --- a/example_test.go +++ b/example_test.go @@ -162,8 +162,7 @@ func Example_helpText() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -195,8 +194,7 @@ func Example_helpPlaceholder() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -236,8 +234,7 @@ func Example_helpTextWithSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -274,8 +271,7 @@ func Example_helpTextWhenUsingSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -311,10 +307,9 @@ func Example_writeHelpForSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + exit := func(int) {} - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Exit: exit}, &args) if err != nil { fmt.Println(err) os.Exit(1) @@ -360,10 +355,9 @@ func Example_writeHelpForSubcommandNested() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + exit := func(int) {} - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Exit: exit}, &args) if err != nil { fmt.Println(err) os.Exit(1) @@ -397,8 +391,7 @@ func Example_errorText() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -421,8 +414,7 @@ func Example_errorTextForSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -457,8 +449,7 @@ func Example_subcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) diff --git a/parse.go b/parse.go index 6d8b509..be77924 100644 --- a/parse.go +++ b/parse.go @@ -75,13 +75,28 @@ var ErrHelp = errors.New("help requested by user") // ErrVersion indicates that --version was provided var ErrVersion = errors.New("version requested by user") +// for monkey patching in example code +var mustParseExit = os.Exit + // MustParse processes command line arguments and exits upon failure func MustParse(dest ...interface{}) *Parser { - p, err := NewParser(Config{}, dest...) + return mustParse(Config{Exit: mustParseExit}, dest...) +} + +// mustParse is a helper that facilitates testing +func mustParse(config Config, dest ...interface{}) *Parser { + if config.Exit == nil { + config.Exit = os.Exit + } + if config.Out == nil { + config.Out = os.Stdout + } + + p, err := NewParser(config, dest...) if err != nil { - fmt.Fprintln(stdout, err) - osExit(-1) - return nil // just in case osExit was monkey-patched + fmt.Fprintln(config.Out, err) + config.Exit(-1) + return nil } p.MustParse(flags()) @@ -121,9 +136,11 @@ type Config struct { // subcommand StrictSubcommands bool - OsExit func(int) - Stdout io.Writer - Stderr io.Writer + // Exit is called to terminate the process with an error code (defaults to os.Exit) + Exit func(int) + + // Out is where help text, usage text, and failure messages are printed (defaults to os.Stdout) + Out io.Writer } // Parser represents a set of command line options with destination values @@ -137,10 +154,6 @@ type Parser struct { // the following field changes during processing of command line arguments lastCmd *command - - osExit func(int) - stdout io.Writer - stderr io.Writer } // Versioned is the interface that the destination struct should implement to @@ -190,6 +203,14 @@ func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { + // fill in defaults + if config.Exit == nil { + config.Exit = os.Exit + } + if config.Out == nil { + config.Out = os.Stdout + } + // first pick a name for the command for use in the usage text var name string switch { @@ -205,20 +226,6 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { p := Parser{ cmd: &command{name: name}, config: config, - - osExit: osExit, - stdout: stdout, - stderr: stderr, - } - - if config.OsExit != nil { - p.osExit = config.OsExit - } - if config.Stdout != nil { - p.stdout = config.Stdout - } - if config.Stderr != nil { - p.stderr = config.Stderr } // make a list of roots @@ -506,11 +513,11 @@ func (p *Parser) MustParse(args []string) { err := p.Parse(args) switch { case err == ErrHelp: - p.writeHelpForSubcommand(p.stdout, p.lastCmd) - p.osExit(0) + p.writeHelpForSubcommand(p.config.Out, p.lastCmd) + p.config.Exit(0) case err == ErrVersion: - fmt.Fprintln(p.stdout, p.version) - p.osExit(0) + fmt.Fprintln(p.config.Out, p.version) + p.config.Exit(0) case err != nil: p.failWithSubcommand(err.Error(), p.lastCmd) } diff --git a/parse_test.go b/parse_test.go index 64119a8..d368b17 100644 --- a/parse_test.go +++ b/parse_test.go @@ -885,26 +885,18 @@ func TestParserMustParse(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - originalExit := osExit - originalStdout := stdout - defer func() { - osExit = originalExit - stdout = originalStdout - }() + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } - var exitCode *int - osExit = func(code int) { exitCode = &code } - var b bytes.Buffer - stdout = &b - - p, err := NewParser(Config{}, &tt.args) + p, err := NewParser(Config{Exit: exit, Out: &stdout}, &tt.args) require.NoError(t, err) assert.NotNil(t, p) p.MustParse(tt.cmdLine) assert.NotNil(t, exitCode) - assert.Equal(t, tt.code, *exitCode) - assert.Contains(t, b.String(), tt.output) + assert.Equal(t, tt.code, exitCode) + assert.Contains(t, stdout.String(), tt.output) }) } } @@ -1484,70 +1476,53 @@ func TestUnexportedFieldsSkipped(t *testing.T) { } func TestMustParseInvalidParser(t *testing.T) { - originalExit := osExit - originalStdout := stdout - defer func() { - osExit = originalExit - stdout = originalStdout - }() - var exitCode int - osExit = func(code int) { exitCode = code } - stdout = &bytes.Buffer{} + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } var args struct { CannotParse struct{} } - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) assert.Nil(t, parser) assert.Equal(t, -1, exitCode) } func TestMustParsePrintsHelp(t *testing.T) { - originalExit := osExit - originalStdout := stdout originalArgs := os.Args defer func() { - osExit = originalExit - stdout = originalStdout os.Args = originalArgs }() - var exitCode *int - osExit = func(code int) { exitCode = &code } os.Args = []string{"someprogram", "--help"} - stdout = &bytes.Buffer{} + + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } var args struct{} - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) assert.NotNil(t, parser) - require.NotNil(t, exitCode) - assert.Equal(t, 0, *exitCode) + assert.Equal(t, 0, exitCode) } func TestMustParsePrintsVersion(t *testing.T) { - originalExit := osExit - originalStdout := stdout originalArgs := os.Args defer func() { - osExit = originalExit - stdout = originalStdout os.Args = originalArgs }() - var exitCode *int - osExit = func(code int) { exitCode = &code } + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } + os.Args = []string{"someprogram", "--version"} - var b bytes.Buffer - stdout = &b - var args versioned - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) require.NotNil(t, parser) - require.NotNil(t, exitCode) - assert.Equal(t, 0, *exitCode) - assert.Equal(t, "example 3.2.1\n", b.String()) + assert.Equal(t, 0, exitCode) + assert.Equal(t, "example 3.2.1\n", stdout.String()) } type mapWithUnmarshalText struct { diff --git a/usage.go b/usage.go index 80eba45..43d6231 100644 --- a/usage.go +++ b/usage.go @@ -3,20 +3,12 @@ package arg import ( "fmt" "io" - "os" "strings" ) // the width of the left column const colWidth = 25 -// to allow monkey patching in tests -var ( - stdout io.Writer = os.Stdout - stderr io.Writer = os.Stderr - osExit = os.Exit -) - // Fail prints usage information to stderr and exits with non-zero status func (p *Parser) Fail(msg string) { p.failWithSubcommand(msg, p.cmd) @@ -39,9 +31,9 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status func (p *Parser) failWithSubcommand(msg string, cmd *command) { - p.writeUsageForSubcommand(p.stderr, cmd) - fmt.Fprintln(p.stderr, "error:", msg) - p.osExit(-1) + p.writeUsageForSubcommand(p.config.Out, cmd) + fmt.Fprintln(p.config.Out, "error:", msg) + p.config.Exit(-1) } // WriteUsage writes usage information to the given writer diff --git a/usage_test.go b/usage_test.go index be5894a..69feac2 100644 --- a/usage_test.go +++ b/usage_test.go @@ -572,18 +572,9 @@ Options: } func TestFail(t *testing.T) { - originalStderr := stderr - originalExit := osExit - defer func() { - stderr = originalStderr - osExit = originalExit - }() - - var b bytes.Buffer - stderr = &b - + var stdout bytes.Buffer var exitCode int - osExit = func(code int) { exitCode = code } + exit := func(code int) { exitCode = code } expectedStdout := ` Usage: example [--foo FOO] @@ -593,27 +584,18 @@ error: something went wrong var args struct { Foo int } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args) require.NoError(t, err) p.Fail("something went wrong") - assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, expectedStdout[1:], stdout.String()) assert.Equal(t, -1, exitCode) } func TestFailSubcommand(t *testing.T) { - originalStderr := stderr - originalExit := osExit - defer func() { - stderr = originalStderr - osExit = originalExit - }() - - var b bytes.Buffer - stderr = &b - + var stdout bytes.Buffer var exitCode int - osExit = func(code int) { exitCode = code } + exit := func(code int) { exitCode = code } expectedStdout := ` Usage: example sub @@ -623,13 +605,13 @@ error: something went wrong var args struct { Sub *struct{} `arg:"subcommand"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args) require.NoError(t, err) err = p.FailSubcommand("something went wrong", "sub") require.NoError(t, err) - assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, expectedStdout[1:], stdout.String()) assert.Equal(t, -1, exitCode) }