diff --git a/parse.go b/parse.go index c8cd79e..dc87947 100644 --- a/parse.go +++ b/parse.go @@ -82,18 +82,7 @@ func MustParse(dest ...interface{}) *Parser { return nil // just in case osExit was monkey-patched } - err = p.Parse(flags()) - switch { - case err == ErrHelp: - p.writeHelpForSubcommand(stdout, p.lastCmd) - osExit(0) - case err == ErrVersion: - fmt.Fprintln(stdout, p.version) - osExit(0) - case err != nil: - p.failWithSubcommand(err.Error(), p.lastCmd) - } - + p.MustParse(flags()) return p } @@ -449,6 +438,20 @@ func (p *Parser) Parse(args []string) error { return err } +func (p *Parser) MustParse(args []string) { + err := p.Parse(args) + switch { + case err == ErrHelp: + p.writeHelpForSubcommand(stdout, p.lastCmd) + osExit(0) + case err == ErrVersion: + fmt.Fprintln(stdout, p.version) + osExit(0) + case err != nil: + p.failWithSubcommand(err.Error(), p.lastCmd) + } +} + // process environment vars for the given arguments func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error { for _, spec := range specs { diff --git a/parse_test.go b/parse_test.go index 4ea6bc4..7e84def 100644 --- a/parse_test.go +++ b/parse_test.go @@ -860,6 +860,54 @@ func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) { assert.Equal(t, "", args.Sub.Foo) } +func TestParserMustParseEmptyArgs(t *testing.T) { + // this mirrors TestEmptyArgs + p, err := NewParser(Config{}, &struct{}{}) + require.NoError(t, err) + assert.NotNil(t, p) + p.MustParse(nil) +} + +func TestParserMustParse(t *testing.T) { + tests := []struct { + name string + args versioned + cmdLine []string + code int + output string + }{ + {name: "help", args: struct{}{}, cmdLine: []string{"--help"}, code: 0, output: "display this help and exit"}, + {name: "version", args: versioned{}, cmdLine: []string{"--version"}, code: 0, output: "example 3.2.1"}, + {name: "invalid", args: struct{}{}, cmdLine: []string{"invalid"}, code: -1, output: ""}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + originalExit := osExit + originalStdout := stdout + defer func() { + osExit = originalExit + stdout = originalStdout + }() + + var exitCode *int + osExit = func(code int) { exitCode = &code } + var b bytes.Buffer + stdout = &b + + p, err := NewParser(Config{}, &tt.args) + require.NoError(t, err) + assert.NotNil(t, p) + + p.MustParse(tt.cmdLine) + assert.NotNil(t, exitCode) + assert.Equal(t, tt.code, *exitCode) + assert.Contains(t, b.String(), tt.output) + }) + } +} + type textUnmarshaler struct { val int }