clean up customizable stdout, stderr, and exit in parser config

This commit is contained in:
Alex Flint 2023-02-08 09:49:03 -05:00
parent 5dbdd5d0c5
commit df28e7154b
5 changed files with 80 additions and 133 deletions

View File

@ -162,8 +162,7 @@ func Example_helpText() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stdout = os.Stdout
MustParse(&args) MustParse(&args)
@ -195,8 +194,7 @@ func Example_helpPlaceholder() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stdout = os.Stdout
MustParse(&args) MustParse(&args)
@ -236,8 +234,7 @@ func Example_helpTextWithSubcommand() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stdout = os.Stdout
MustParse(&args) MustParse(&args)
@ -274,8 +271,7 @@ func Example_helpTextWhenUsingSubcommand() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stdout = os.Stdout
MustParse(&args) MustParse(&args)
@ -311,10 +307,9 @@ func Example_writeHelpForSubcommand() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} exit := func(int) {}
stdout = os.Stdout
p, err := NewParser(Config{}, &args) p, err := NewParser(Config{Exit: exit}, &args)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -360,10 +355,9 @@ func Example_writeHelpForSubcommandNested() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} exit := func(int) {}
stdout = os.Stdout
p, err := NewParser(Config{}, &args) p, err := NewParser(Config{Exit: exit}, &args)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -397,8 +391,7 @@ func Example_errorText() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stderr = os.Stdout
MustParse(&args) MustParse(&args)
@ -421,8 +414,7 @@ func Example_errorTextForSubcommand() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stderr = os.Stdout
MustParse(&args) MustParse(&args)
@ -457,8 +449,7 @@ func Example_subcommand() {
} }
// This is only necessary when running inside golang's runnable example harness // This is only necessary when running inside golang's runnable example harness
osExit = func(int) {} mustParseExit = func(int) {}
stderr = os.Stdout
MustParse(&args) MustParse(&args)

View File

@ -75,13 +75,28 @@ var ErrHelp = errors.New("help requested by user")
// ErrVersion indicates that --version was provided // ErrVersion indicates that --version was provided
var ErrVersion = errors.New("version requested by user") var ErrVersion = errors.New("version requested by user")
// for monkey patching in example code
var mustParseExit = os.Exit
// MustParse processes command line arguments and exits upon failure // MustParse processes command line arguments and exits upon failure
func MustParse(dest ...interface{}) *Parser { func MustParse(dest ...interface{}) *Parser {
p, err := NewParser(Config{}, dest...) return mustParse(Config{Exit: mustParseExit}, dest...)
}
// mustParse is a helper that facilitates testing
func mustParse(config Config, dest ...interface{}) *Parser {
if config.Exit == nil {
config.Exit = os.Exit
}
if config.Out == nil {
config.Out = os.Stdout
}
p, err := NewParser(config, dest...)
if err != nil { if err != nil {
fmt.Fprintln(stdout, err) fmt.Fprintln(config.Out, err)
osExit(-1) config.Exit(-1)
return nil // just in case osExit was monkey-patched return nil
} }
p.MustParse(flags()) p.MustParse(flags())
@ -121,9 +136,11 @@ type Config struct {
// subcommand // subcommand
StrictSubcommands bool StrictSubcommands bool
OsExit func(int) // Exit is called to terminate the process with an error code (defaults to os.Exit)
Stdout io.Writer Exit func(int)
Stderr io.Writer
// Out is where help text, usage text, and failure messages are printed (defaults to os.Stdout)
Out io.Writer
} }
// Parser represents a set of command line options with destination values // Parser represents a set of command line options with destination values
@ -137,10 +154,6 @@ type Parser struct {
// the following field changes during processing of command line arguments // the following field changes during processing of command line arguments
lastCmd *command lastCmd *command
osExit func(int)
stdout io.Writer
stderr io.Writer
} }
// Versioned is the interface that the destination struct should implement to // Versioned is the interface that the destination struct should implement to
@ -190,6 +203,14 @@ func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner
// NewParser constructs a parser from a list of destination structs // NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) { func NewParser(config Config, dests ...interface{}) (*Parser, error) {
// fill in defaults
if config.Exit == nil {
config.Exit = os.Exit
}
if config.Out == nil {
config.Out = os.Stdout
}
// first pick a name for the command for use in the usage text // first pick a name for the command for use in the usage text
var name string var name string
switch { switch {
@ -205,20 +226,6 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
p := Parser{ p := Parser{
cmd: &command{name: name}, cmd: &command{name: name},
config: config, config: config,
osExit: osExit,
stdout: stdout,
stderr: stderr,
}
if config.OsExit != nil {
p.osExit = config.OsExit
}
if config.Stdout != nil {
p.stdout = config.Stdout
}
if config.Stderr != nil {
p.stderr = config.Stderr
} }
// make a list of roots // make a list of roots
@ -506,11 +513,11 @@ func (p *Parser) MustParse(args []string) {
err := p.Parse(args) err := p.Parse(args)
switch { switch {
case err == ErrHelp: case err == ErrHelp:
p.writeHelpForSubcommand(p.stdout, p.lastCmd) p.writeHelpForSubcommand(p.config.Out, p.lastCmd)
p.osExit(0) p.config.Exit(0)
case err == ErrVersion: case err == ErrVersion:
fmt.Fprintln(p.stdout, p.version) fmt.Fprintln(p.config.Out, p.version)
p.osExit(0) p.config.Exit(0)
case err != nil: case err != nil:
p.failWithSubcommand(err.Error(), p.lastCmd) p.failWithSubcommand(err.Error(), p.lastCmd)
} }

View File

@ -885,26 +885,18 @@ func TestParserMustParse(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
originalExit := osExit var exitCode int
originalStdout := stdout var stdout bytes.Buffer
defer func() { exit := func(code int) { exitCode = code }
osExit = originalExit
stdout = originalStdout
}()
var exitCode *int p, err := NewParser(Config{Exit: exit, Out: &stdout}, &tt.args)
osExit = func(code int) { exitCode = &code }
var b bytes.Buffer
stdout = &b
p, err := NewParser(Config{}, &tt.args)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, p) assert.NotNil(t, p)
p.MustParse(tt.cmdLine) p.MustParse(tt.cmdLine)
assert.NotNil(t, exitCode) assert.NotNil(t, exitCode)
assert.Equal(t, tt.code, *exitCode) assert.Equal(t, tt.code, exitCode)
assert.Contains(t, b.String(), tt.output) assert.Contains(t, stdout.String(), tt.output)
}) })
} }
} }
@ -1484,70 +1476,53 @@ func TestUnexportedFieldsSkipped(t *testing.T) {
} }
func TestMustParseInvalidParser(t *testing.T) { func TestMustParseInvalidParser(t *testing.T) {
originalExit := osExit
originalStdout := stdout
defer func() {
osExit = originalExit
stdout = originalStdout
}()
var exitCode int var exitCode int
osExit = func(code int) { exitCode = code } var stdout bytes.Buffer
stdout = &bytes.Buffer{} exit := func(code int) { exitCode = code }
var args struct { var args struct {
CannotParse struct{} CannotParse struct{}
} }
parser := MustParse(&args) parser := mustParse(Config{Out: &stdout, Exit: exit}, &args)
assert.Nil(t, parser) assert.Nil(t, parser)
assert.Equal(t, -1, exitCode) assert.Equal(t, -1, exitCode)
} }
func TestMustParsePrintsHelp(t *testing.T) { func TestMustParsePrintsHelp(t *testing.T) {
originalExit := osExit
originalStdout := stdout
originalArgs := os.Args originalArgs := os.Args
defer func() { defer func() {
osExit = originalExit
stdout = originalStdout
os.Args = originalArgs os.Args = originalArgs
}() }()
var exitCode *int
osExit = func(code int) { exitCode = &code }
os.Args = []string{"someprogram", "--help"} os.Args = []string{"someprogram", "--help"}
stdout = &bytes.Buffer{}
var exitCode int
var stdout bytes.Buffer
exit := func(code int) { exitCode = code }
var args struct{} var args struct{}
parser := MustParse(&args) parser := mustParse(Config{Out: &stdout, Exit: exit}, &args)
assert.NotNil(t, parser) assert.NotNil(t, parser)
require.NotNil(t, exitCode) assert.Equal(t, 0, exitCode)
assert.Equal(t, 0, *exitCode)
} }
func TestMustParsePrintsVersion(t *testing.T) { func TestMustParsePrintsVersion(t *testing.T) {
originalExit := osExit
originalStdout := stdout
originalArgs := os.Args originalArgs := os.Args
defer func() { defer func() {
osExit = originalExit
stdout = originalStdout
os.Args = originalArgs os.Args = originalArgs
}() }()
var exitCode *int var exitCode int
osExit = func(code int) { exitCode = &code } var stdout bytes.Buffer
exit := func(code int) { exitCode = code }
os.Args = []string{"someprogram", "--version"} os.Args = []string{"someprogram", "--version"}
var b bytes.Buffer
stdout = &b
var args versioned var args versioned
parser := MustParse(&args) parser := mustParse(Config{Out: &stdout, Exit: exit}, &args)
require.NotNil(t, parser) require.NotNil(t, parser)
require.NotNil(t, exitCode) assert.Equal(t, 0, exitCode)
assert.Equal(t, 0, *exitCode) assert.Equal(t, "example 3.2.1\n", stdout.String())
assert.Equal(t, "example 3.2.1\n", b.String())
} }
type mapWithUnmarshalText struct { type mapWithUnmarshalText struct {

View File

@ -3,20 +3,12 @@ package arg
import ( import (
"fmt" "fmt"
"io" "io"
"os"
"strings" "strings"
) )
// the width of the left column // the width of the left column
const colWidth = 25 const colWidth = 25
// to allow monkey patching in tests
var (
stdout io.Writer = os.Stdout
stderr io.Writer = os.Stderr
osExit = os.Exit
)
// Fail prints usage information to stderr and exits with non-zero status // Fail prints usage information to stderr and exits with non-zero status
func (p *Parser) Fail(msg string) { func (p *Parser) Fail(msg string) {
p.failWithSubcommand(msg, p.cmd) p.failWithSubcommand(msg, p.cmd)
@ -39,9 +31,9 @@ func (p *Parser) FailSubcommand(msg string, subcommand ...string) error {
// failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status // failWithSubcommand prints usage information for the given subcommand to stderr and exits with non-zero status
func (p *Parser) failWithSubcommand(msg string, cmd *command) { func (p *Parser) failWithSubcommand(msg string, cmd *command) {
p.writeUsageForSubcommand(p.stderr, cmd) p.writeUsageForSubcommand(p.config.Out, cmd)
fmt.Fprintln(p.stderr, "error:", msg) fmt.Fprintln(p.config.Out, "error:", msg)
p.osExit(-1) p.config.Exit(-1)
} }
// WriteUsage writes usage information to the given writer // WriteUsage writes usage information to the given writer

View File

@ -572,18 +572,9 @@ Options:
} }
func TestFail(t *testing.T) { func TestFail(t *testing.T) {
originalStderr := stderr var stdout bytes.Buffer
originalExit := osExit
defer func() {
stderr = originalStderr
osExit = originalExit
}()
var b bytes.Buffer
stderr = &b
var exitCode int var exitCode int
osExit = func(code int) { exitCode = code } exit := func(code int) { exitCode = code }
expectedStdout := ` expectedStdout := `
Usage: example [--foo FOO] Usage: example [--foo FOO]
@ -593,27 +584,18 @@ error: something went wrong
var args struct { var args struct {
Foo int Foo int
} }
p, err := NewParser(Config{Program: "example"}, &args) p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args)
require.NoError(t, err) require.NoError(t, err)
p.Fail("something went wrong") p.Fail("something went wrong")
assert.Equal(t, expectedStdout[1:], b.String()) assert.Equal(t, expectedStdout[1:], stdout.String())
assert.Equal(t, -1, exitCode) assert.Equal(t, -1, exitCode)
} }
func TestFailSubcommand(t *testing.T) { func TestFailSubcommand(t *testing.T) {
originalStderr := stderr var stdout bytes.Buffer
originalExit := osExit
defer func() {
stderr = originalStderr
osExit = originalExit
}()
var b bytes.Buffer
stderr = &b
var exitCode int var exitCode int
osExit = func(code int) { exitCode = code } exit := func(code int) { exitCode = code }
expectedStdout := ` expectedStdout := `
Usage: example sub Usage: example sub
@ -623,13 +605,13 @@ error: something went wrong
var args struct { var args struct {
Sub *struct{} `arg:"subcommand"` Sub *struct{} `arg:"subcommand"`
} }
p, err := NewParser(Config{Program: "example"}, &args) p, err := NewParser(Config{Program: "example", Exit: exit, Out: &stdout}, &args)
require.NoError(t, err) require.NoError(t, err)
err = p.FailSubcommand("something went wrong", "sub") err = p.FailSubcommand("something went wrong", "sub")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedStdout[1:], b.String()) assert.Equal(t, expectedStdout[1:], stdout.String())
assert.Equal(t, -1, exitCode) assert.Equal(t, -1, exitCode)
} }