From df28e7154bbab76436bc59e5dc67fb6d6824fc62 Mon Sep 17 00:00:00 2001 From: Alex Flint Date: Wed, 8 Feb 2023 09:49:03 -0500 Subject: [PATCH] clean up customizable stdout, stderr, and exit in parser config --- example_test.go | 31 ++++++++-------------- parse.go | 65 +++++++++++++++++++++++++--------------------- parse_test.go | 69 ++++++++++++++++--------------------------------- usage.go | 14 +++------- usage_test.go | 34 ++++++------------------ 5 files changed, 80 insertions(+), 133 deletions(-) diff --git a/example_test.go b/example_test.go index fd64777..5272393 100644 --- a/example_test.go +++ b/example_test.go @@ -162,8 +162,7 @@ func Example_helpText() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -195,8 +194,7 @@ func Example_helpPlaceholder() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -236,8 +234,7 @@ func Example_helpTextWithSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -274,8 +271,7 @@ func Example_helpTextWhenUsingSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -311,10 +307,9 @@ func Example_writeHelpForSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + exit := func(int) {} - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Exit: exit}, &args) if err != nil { fmt.Println(err) os.Exit(1) @@ -360,10 +355,9 @@ func Example_writeHelpForSubcommandNested() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stdout = os.Stdout + exit := func(int) {} - p, err := NewParser(Config{}, &args) + p, err := NewParser(Config{Exit: exit}, &args) if err != nil { fmt.Println(err) os.Exit(1) @@ -397,8 +391,7 @@ func Example_errorText() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -421,8 +414,7 @@ func Example_errorTextForSubcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) @@ -457,8 +449,7 @@ func Example_subcommand() { } // This is only necessary when running inside golang's runnable example harness - osExit = func(int) {} - stderr = os.Stdout + mustParseExit = func(int) {} MustParse(&args) diff --git a/parse.go b/parse.go index 6d8b509..be77924 100644 --- a/parse.go +++ b/parse.go @@ -75,13 +75,28 @@ var ErrHelp = errors.New("help requested by user") // ErrVersion indicates that --version was provided var ErrVersion = errors.New("version requested by user") +// for monkey patching in example code +var mustParseExit = os.Exit + // MustParse processes command line arguments and exits upon failure func MustParse(dest ...interface{}) *Parser { - p, err := NewParser(Config{}, dest...) + return mustParse(Config{Exit: mustParseExit}, dest...) +} + +// mustParse is a helper that facilitates testing +func mustParse(config Config, dest ...interface{}) *Parser { + if config.Exit == nil { + config.Exit = os.Exit + } + if config.Out == nil { + config.Out = os.Stdout + } + + p, err := NewParser(config, dest...) if err != nil { - fmt.Fprintln(stdout, err) - osExit(-1) - return nil // just in case osExit was monkey-patched + fmt.Fprintln(config.Out, err) + config.Exit(-1) + return nil } p.MustParse(flags()) @@ -121,9 +136,11 @@ type Config struct { // subcommand StrictSubcommands bool - OsExit func(int) - Stdout io.Writer - Stderr io.Writer + // Exit is called to terminate the process with an error code (defaults to os.Exit) + Exit func(int) + + // Out is where help text, usage text, and failure messages are printed (defaults to os.Stdout) + Out io.Writer } // Parser represents a set of command line options with destination values @@ -137,10 +154,6 @@ type Parser struct { // the following field changes during processing of command line arguments lastCmd *command - - osExit func(int) - stdout io.Writer - stderr io.Writer } // Versioned is the interface that the destination struct should implement to @@ -190,6 +203,14 @@ func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner // NewParser constructs a parser from a list of destination structs func NewParser(config Config, dests ...interface{}) (*Parser, error) { + // fill in defaults + if config.Exit == nil { + config.Exit = os.Exit + } + if config.Out == nil { + config.Out = os.Stdout + } + // first pick a name for the command for use in the usage text var name string switch { @@ -205,20 +226,6 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { p := Parser{ cmd: &command{name: name}, config: config, - - osExit: osExit, - stdout: stdout, - stderr: stderr, - } - - if config.OsExit != nil { - p.osExit = config.OsExit - } - if config.Stdout != nil { - p.stdout = config.Stdout - } - if config.Stderr != nil { - p.stderr = config.Stderr } // make a list of roots @@ -506,11 +513,11 @@ func (p *Parser) MustParse(args []string) { err := p.Parse(args) switch { case err == ErrHelp: - p.writeHelpForSubcommand(p.stdout, p.lastCmd) - p.osExit(0) + p.writeHelpForSubcommand(p.config.Out, p.lastCmd) + p.config.Exit(0) case err == ErrVersion: - fmt.Fprintln(p.stdout, p.version) - p.osExit(0) + fmt.Fprintln(p.config.Out, p.version) + p.config.Exit(0) case err != nil: p.failWithSubcommand(err.Error(), p.lastCmd) } diff --git a/parse_test.go b/parse_test.go index 64119a8..d368b17 100644 --- a/parse_test.go +++ b/parse_test.go @@ -885,26 +885,18 @@ func TestParserMustParse(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - originalExit := osExit - originalStdout := stdout - defer func() { - osExit = originalExit - stdout = originalStdout - }() + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } - var exitCode *int - osExit = func(code int) { exitCode = &code } - var b bytes.Buffer - stdout = &b - - p, err := NewParser(Config{}, &tt.args) + p, err := NewParser(Config{Exit: exit, Out: &stdout}, &tt.args) require.NoError(t, err) assert.NotNil(t, p) p.MustParse(tt.cmdLine) assert.NotNil(t, exitCode) - assert.Equal(t, tt.code, *exitCode) - assert.Contains(t, b.String(), tt.output) + assert.Equal(t, tt.code, exitCode) + assert.Contains(t, stdout.String(), tt.output) }) } } @@ -1484,70 +1476,53 @@ func TestUnexportedFieldsSkipped(t *testing.T) { } func TestMustParseInvalidParser(t *testing.T) { - originalExit := osExit - originalStdout := stdout - defer func() { - osExit = originalExit - stdout = originalStdout - }() - var exitCode int - osExit = func(code int) { exitCode = code } - stdout = &bytes.Buffer{} + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } var args struct { CannotParse struct{} } - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) assert.Nil(t, parser) assert.Equal(t, -1, exitCode) } func TestMustParsePrintsHelp(t *testing.T) { - originalExit := osExit - originalStdout := stdout originalArgs := os.Args defer func() { - osExit = originalExit - stdout = originalStdout os.Args = originalArgs }() - var exitCode *int - osExit = func(code int) { exitCode = &code } os.Args = []string{"someprogram", "--help"} - stdout = &bytes.Buffer{} + + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } var args struct{} - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) assert.NotNil(t, parser) - require.NotNil(t, exitCode) - assert.Equal(t, 0, *exitCode) + assert.Equal(t, 0, exitCode) } func TestMustParsePrintsVersion(t *testing.T) { - originalExit := osExit - originalStdout := stdout originalArgs := os.Args defer func() { - osExit = originalExit - stdout = originalStdout os.Args = originalArgs }() - var exitCode *int - osExit = func(code int) { exitCode = &code } + var exitCode int + var stdout bytes.Buffer + exit := func(code int) { exitCode = code } + os.Args = []string{"someprogram", "--version"} - var b bytes.Buffer - stdout = &b - var args versioned - parser := MustParse(&args) + parser := mustParse(Config{Out: &stdout, Exit: exit}, &args) require.NotNil(t, parser) - require.NotNil(t, exitCode) - assert.Equal(t, 0, *exitCode) - assert.Equal(t, "example 3.2.1\n", b.String()) + assert.Equal(t, 0, exitCode) + assert.Equal(t, "example 3.2.1\n", stdout.String()) } type mapWithUnmarshalText struct { diff --git a/usage.go b/usage.go index 80eba45..43d6231 100644 --- a/usage.go +++ b/usage.go @@ -3,20 +3,12 @@ package arg import ( "fmt" "io" - "os" "strings" ) // the width of the left column const colWidth = 25 -// to allow monkey patching in tests -var ( - stdout io.Writer = os.Stdout - stderr io.Writer = os.Stderr - osExit = os.Exit -) - // Fail prints usage information to stderr and exits with non-zero status func (p *Parser) Fail(msg string) { p.failWithSubcommand(msg, p.cmd) @@ -39,9 +31,9 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error { // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status func (p *Parser) failWithSubcommand(msg string, cmd *command) { - p.writeUsageForSubcommand(p.stderr, cmd) - fmt.Fprintln(p.stderr, "error:", msg) - p.osExit(-1) + p.writeUsageForSubcommand(p.config.Out, cmd) + fmt.Fprintln(p.config.Out, "error:", msg) + p.config.Exit(-1) } // WriteUsage writes usage information to the given writer diff --git a/usage_test.go b/usage_test.go index be5894a..69feac2 100644 --- a/usage_test.go +++ b/usage_test.go @@ -572,18 +572,9 @@ Options: } func TestFail(t *testing.T) { - originalStderr := stderr - originalExit := osExit - defer func() { - stderr = originalStderr - osExit = originalExit - }() - - var b bytes.Buffer - stderr = &b - + var stdout bytes.Buffer var exitCode int - osExit = func(code int) { exitCode = code } + exit := func(code int) { exitCode = code } expectedStdout := ` Usage: example [--foo FOO] @@ -593,27 +584,18 @@ error: something went wrong var args struct { Foo int } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args) require.NoError(t, err) p.Fail("something went wrong") - assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, expectedStdout[1:], stdout.String()) assert.Equal(t, -1, exitCode) } func TestFailSubcommand(t *testing.T) { - originalStderr := stderr - originalExit := osExit - defer func() { - stderr = originalStderr - osExit = originalExit - }() - - var b bytes.Buffer - stderr = &b - + var stdout bytes.Buffer var exitCode int - osExit = func(code int) { exitCode = code } + exit := func(code int) { exitCode = code } expectedStdout := ` Usage: example sub @@ -623,13 +605,13 @@ error: something went wrong var args struct { Sub *struct{} `arg:"subcommand"` } - p, err := NewParser(Config{Program: "example"}, &args) + p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args) require.NoError(t, err) err = p.FailSubcommand("something went wrong", "sub") require.NoError(t, err) - assert.Equal(t, expectedStdout[1:], b.String()) + assert.Equal(t, expectedStdout[1:], stdout.String()) assert.Equal(t, -1, exitCode) }