Merge pull request #223 from hhromic/fix-version-flag

Improve handling of version flag
This commit is contained in:
Alex Flint 2023-07-14 15:52:33 -04:00 committed by GitHub
commit 660b9045e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 7 deletions

View File

@ -69,10 +69,10 @@ type command struct {
parent *command parent *command
} }
// ErrHelp indicates that -h or --help were provided // ErrHelp indicates that the builtin -h or --help were provided
var ErrHelp = errors.New("help requested by user") var ErrHelp = errors.New("help requested by user")
// ErrVersion indicates that --version was provided // ErrVersion indicates that the builtin --version was provided
var ErrVersion = errors.New("version requested by user") var ErrVersion = errors.New("version requested by user")
// for monkey patching in example code // for monkey patching in example code
@ -591,6 +591,15 @@ func (p *Parser) process(args []string) error {
} }
} }
// determine if the current command has a version option spec
var hasVersionOption bool
for _, spec := range curCmd.specs {
if spec.long == "version" {
hasVersionOption = true
break
}
}
// process each string from the command line // process each string from the command line
var allpositional bool var allpositional bool
var positionals []string var positionals []string
@ -648,8 +657,10 @@ func (p *Parser) process(args []string) error {
case "-h", "--help": case "-h", "--help":
return ErrHelp return ErrHelp
case "--version": case "--version":
if !hasVersionOption && p.version != "" {
return ErrVersion return ErrVersion
} }
}
// check for an equals sign, as in "--foo=bar" // check for an equals sign, as in "--foo=bar"
var value string var value string

View File

@ -1380,11 +1380,55 @@ func TestReuseParser(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
func TestVersion(t *testing.T) { func TestNoVersion(t *testing.T) {
var args struct{} var args struct{}
err := parse("--version", &args)
assert.Equal(t, ErrVersion, err)
p, err := NewParser(Config{}, &args)
require.NoError(t, err)
err = p.Parse([]string{"--version"})
assert.Error(t, err)
assert.NotEqual(t, ErrVersion, err)
}
func TestBuiltinVersion(t *testing.T) {
var args struct{}
p, err := NewParser(Config{}, &args)
require.NoError(t, err)
p.version = "example 3.2.1"
err = p.Parse([]string{"--version"})
assert.Equal(t, ErrVersion, err)
}
func TestArgsVersion(t *testing.T) {
var args struct {
Version bool `arg:"--version"`
}
p, err := NewParser(Config{}, &args)
require.NoError(t, err)
err = p.Parse([]string{"--version"})
require.NoError(t, err)
require.Equal(t, args.Version, true)
}
func TestArgsAndBuiltinVersion(t *testing.T) {
var args struct {
Version bool `arg:"--version"`
}
p, err := NewParser(Config{}, &args)
require.NoError(t, err)
p.version = "example 3.2.1"
err = p.Parse([]string{"--version"})
require.NoError(t, err)
require.Equal(t, args.Version, true)
} }
func TestMultipleTerminates(t *testing.T) { func TestMultipleTerminates(t *testing.T) {

View File

@ -209,6 +209,7 @@ func (p *Parser) WriteHelpForSubcommand(w io.Writer, subcommand ...string) error
// writeHelp writes the usage string for the given subcommand // writeHelp writes the usage string for the given subcommand
func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) { func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) {
var positionals, longOptions, shortOptions, envOnlyOptions []*spec var positionals, longOptions, shortOptions, envOnlyOptions []*spec
var hasVersionOption bool
for _, spec := range cmd.specs { for _, spec := range cmd.specs {
switch { switch {
case spec.positional: case spec.positional:
@ -243,6 +244,9 @@ func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) {
} }
for _, spec := range longOptions { for _, spec := range longOptions {
p.printOption(w, spec) p.printOption(w, spec)
if spec.long == "version" {
hasVersionOption = true
}
} }
} }
@ -259,6 +263,9 @@ func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) {
fmt.Fprint(w, "\nGlobal options:\n") fmt.Fprint(w, "\nGlobal options:\n")
for _, spec := range globals { for _, spec := range globals {
p.printOption(w, spec) p.printOption(w, spec)
if spec.long == "version" {
hasVersionOption = true
}
} }
} }
@ -269,7 +276,7 @@ func (p *Parser) writeHelpForSubcommand(w io.Writer, cmd *command) {
short: "h", short: "h",
help: "display this help and exit", help: "display this help and exit",
}) })
if p.version != "" { if !hasVersionOption && p.version != "" {
p.printOption(w, &spec{ p.printOption(w, &spec{
cardinality: zero, cardinality: zero,
long: "version", long: "version",