Implement MustParse on Parse

This moves most of the body of the MustParse function into a MustParse
method on a Parser. The MustParse function is now implemented by calling
the MustParse function on the Parser it implicitly creates.

Closes: #194
This commit is contained in:
Daniele Sluijters 2022-10-05 17:59:23 +02:00
parent 11f9b624a9
commit 4fc9666f79
2 changed files with 63 additions and 12 deletions

View File

@ -82,18 +82,7 @@ func MustParse(dest ...interface{}) *Parser {
return nil // just in case osExit was monkey-patched
}
err = p.Parse(flags())
switch {
case err == ErrHelp:
p.writeHelpForSubcommand(stdout, p.lastCmd)
osExit(0)
case err == ErrVersion:
fmt.Fprintln(stdout, p.version)
osExit(0)
case err != nil:
p.failWithSubcommand(err.Error(), p.lastCmd)
}
p.MustParse(flags())
return p
}
@ -449,6 +438,20 @@ func (p *Parser) Parse(args []string) error {
return err
}
func (p *Parser) MustParse(args []string) {
err := p.Parse(args)
switch {
case err == ErrHelp:
p.writeHelpForSubcommand(stdout, p.lastCmd)
osExit(0)
case err == ErrVersion:
fmt.Fprintln(stdout, p.version)
osExit(0)
case err != nil:
p.failWithSubcommand(err.Error(), p.lastCmd)
}
}
// process environment vars for the given arguments
func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error {
for _, spec := range specs {

View File

@ -860,6 +860,54 @@ func TestEnvironmentVariableInSubcommandIgnored(t *testing.T) {
assert.Equal(t, "", args.Sub.Foo)
}
func TestParserMustParseEmptyArgs(t *testing.T) {
// this mirrors TestEmptyArgs
p, err := NewParser(Config{}, &struct{}{})
require.NoError(t, err)
assert.NotNil(t, p)
p.MustParse(nil)
}
func TestParserMustParse(t *testing.T) {
tests := []struct {
name string
args versioned
cmdLine []string
code int
output string
}{
{name: "help", args: struct{}{}, cmdLine: []string{"--help"}, code: 0, output: "display this help and exit"},
{name: "version", args: versioned{}, cmdLine: []string{"--version"}, code: 0, output: "example 3.2.1"},
{name: "invalid", args: struct{}{}, cmdLine: []string{"invalid"}, code: -1, output: ""},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
originalExit := osExit
originalStdout := stdout
defer func() {
osExit = originalExit
stdout = originalStdout
}()
var exitCode *int
osExit = func(code int) { exitCode = &code }
var b bytes.Buffer
stdout = &b
p, err := NewParser(Config{}, &tt.args)
require.NoError(t, err)
assert.NotNil(t, p)
p.MustParse(tt.cmdLine)
assert.NotNil(t, exitCode)
assert.Equal(t, tt.code, *exitCode)
assert.Contains(t, b.String(), tt.output)
})
}
}
type textUnmarshaler struct {
val int
}