added parser struct

This commit is contained in:
Alex Flint 2015-10-31 23:57:26 -07:00
parent 026a824666
commit f427e9f317
3 changed files with 88 additions and 53 deletions

View File

@ -1,8 +1,11 @@
package arg package arg
import ( import (
"errors"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -20,48 +23,74 @@ type spec struct {
wasPresent bool wasPresent bool
} }
// Parse returns this value to indicate that -h or --help were provided
var ErrHelp = errors.New("help requested by user")
// MustParse processes command line arguments and exits upon failure. // MustParse processes command line arguments and exits upon failure.
func MustParse(dest ...interface{}) { func MustParse(dest ...interface{}) {
err := Parse(dest...) p, err := NewParser(dest...)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
} err = p.Parse(os.Args[1:])
if err != nil {
// Parse processes command line arguments and stores the result in args. fmt.Println(err)
func Parse(dest ...interface{}) error { writeUsage(os.Stdout, filepath.Base(os.Args[0]), p.spec)
return ParseFrom(os.Args[1:], dest...) os.Exit(1)
}
// ParseFrom processes command line arguments and stores the result in args.
func ParseFrom(args []string, dest ...interface{}) error {
// Add the help option if one is not already defined
var internal struct {
Help bool `arg:"-h,help:print this help message"`
} }
}
// Parse the spec // Parse processes command line arguments and stores them in dest.
dest = append(dest, &internal) func Parse(dest ...interface{}) error {
p, err := NewParser(dest...)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
return p.Parse(os.Args[1:])
}
// Parser represents a set of command line options with destination values
type Parser struct {
spec []*spec
}
// NewParser constructs a parser from a list of destination structs
func NewParser(dest ...interface{}) (*Parser, error) {
spec, err := extractSpec(dest...) spec, err := extractSpec(dest...)
if err != nil { if err != nil {
return err return nil, err
}
return &Parser{spec: spec}, nil
}
// Parse processes the given command line option, storing the results in the field
// of the structs from which NewParser was constructed
func (p *Parser) Parse(args []string) error {
// If -h or --help were specified then print usage
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return ErrHelp
}
if arg == "--" {
break
}
} }
// Process args // Process all command line arguments
err = processArgs(spec, args) err := process(p.spec, args)
if err != nil { if err != nil {
return err return err
} }
// If -h or --help were specified then print help
if internal.Help {
writeUsage(os.Stdout, spec)
os.Exit(0)
}
// Validate // Validate
return validate(spec) return validate(p.spec)
}
// WriteUsage writes usage information to the given writer
func (p *Parser) WriteUsage(w io.Writer) {
writeUsage(w, filepath.Base(os.Args[0]), p.spec)
} }
// extractSpec gets specifications for each argument from the tags in a struct // extractSpec gets specifications for each argument from the tags in a struct
@ -143,8 +172,9 @@ func extractSpec(dests ...interface{}) ([]*spec, error) {
return specs, nil return specs, nil
} }
// processArgs processes arguments using a pre-constructed spec // process goes through arguments the arguments one-by-one, parses them, and assigns the result to
func processArgs(specs []*spec, args []string) error { // the underlying struct field
func process(specs []*spec, args []string) error {
// construct a map from --option to spec // construct a map from --option to spec
optionMap := make(map[string]*spec) optionMap := make(map[string]*spec)
for _, spec := range specs { for _, spec := range specs {

View File

@ -8,15 +8,19 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func split(s string) []string { func parse(cmdline string, dest interface{}) error {
return strings.Split(s, " ") p, err := NewParser(dest)
if err != nil {
return err
}
return p.Parse(strings.Split(cmdline, " "))
} }
func TestStringSingle(t *testing.T) { func TestStringSingle(t *testing.T) {
var args struct { var args struct {
Foo string Foo string
} }
err := ParseFrom(split("--foo bar"), &args) err := parse("--foo bar", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "bar", args.Foo) assert.Equal(t, "bar", args.Foo)
} }
@ -30,7 +34,7 @@ func TestMixed(t *testing.T) {
Spam float32 Spam float32
} }
args.Bar = 3 args.Bar = 3
err := ParseFrom(split("123 -spam=1.2 -ham -f xyz"), &args) err := parse("123 -spam=1.2 -ham -f xyz", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
assert.Equal(t, 3, args.Bar) assert.Equal(t, 3, args.Bar)
@ -43,7 +47,7 @@ func TestRequired(t *testing.T) {
var args struct { var args struct {
Foo string `arg:"required"` Foo string `arg:"required"`
} }
err := ParseFrom(nil, &args) err := parse("", &args)
require.Error(t, err, "--foo is required") require.Error(t, err, "--foo is required")
} }
@ -52,15 +56,15 @@ func TestShortFlag(t *testing.T) {
Foo string `arg:"-f"` Foo string `arg:"-f"`
} }
err := ParseFrom(split("-f xyz"), &args) err := parse("-f xyz", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
err = ParseFrom(split("-foo xyz"), &args) err = parse("-foo xyz", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
err = ParseFrom(split("--foo xyz"), &args) err = parse("--foo xyz", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "xyz", args.Foo) assert.Equal(t, "xyz", args.Foo)
} }
@ -71,7 +75,7 @@ func TestCaseSensitive(t *testing.T) {
Upper bool `arg:"-V"` Upper bool `arg:"-V"`
} }
err := ParseFrom(split("-v"), &args) err := parse("-v", &args)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, args.Lower) assert.True(t, args.Lower)
assert.False(t, args.Upper) assert.False(t, args.Upper)
@ -83,7 +87,7 @@ func TestCaseSensitive2(t *testing.T) {
Upper bool `arg:"-V"` Upper bool `arg:"-V"`
} }
err := ParseFrom(split("-V"), &args) err := parse("-V", &args)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, args.Lower) assert.False(t, args.Lower)
assert.True(t, args.Upper) assert.True(t, args.Upper)
@ -94,7 +98,7 @@ func TestPositional(t *testing.T) {
Input string `arg:"positional"` Input string `arg:"positional"`
Output string `arg:"positional"` Output string `arg:"positional"`
} }
err := ParseFrom(split("foo"), &args) err := parse("foo", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "foo", args.Input) assert.Equal(t, "foo", args.Input)
assert.Equal(t, "", args.Output) assert.Equal(t, "", args.Output)
@ -105,7 +109,7 @@ func TestRequiredPositional(t *testing.T) {
Input string `arg:"positional"` Input string `arg:"positional"`
Output string `arg:"positional,required"` Output string `arg:"positional,required"`
} }
err := ParseFrom(split("foo"), &args) err := parse("foo", &args)
assert.Error(t, err) assert.Error(t, err)
} }
@ -114,7 +118,7 @@ func TestTooManyPositional(t *testing.T) {
Input string `arg:"positional"` Input string `arg:"positional"`
Output string `arg:"positional"` Output string `arg:"positional"`
} }
err := ParseFrom(split("foo bar baz"), &args) err := parse("foo bar baz", &args)
assert.Error(t, err) assert.Error(t, err)
} }
@ -123,7 +127,7 @@ func TestMultiple(t *testing.T) {
Foo []int Foo []int
Bar []string Bar []string
} }
err := ParseFrom(split("--foo 1 2 3 --bar x y z"), &args) err := parse("--foo 1 2 3 --bar x y z", &args)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3}, args.Foo) assert.Equal(t, []int{1, 2, 3}, args.Foo)
assert.Equal(t, []string{"x", "y", "z"}, args.Bar) assert.Equal(t, []string{"x", "y", "z"}, args.Bar)

View File

@ -4,11 +4,12 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"path/filepath"
"reflect" "reflect"
"strings" "strings"
) )
// Usage prints usage information to stdout information and exits with status zero // Usage prints usage information to stdout and exits with status zero
func Usage(dest ...interface{}) { func Usage(dest ...interface{}) {
if err := WriteUsage(os.Stdout, dest...); err != nil { if err := WriteUsage(os.Stdout, dest...); err != nil {
fmt.Println(err) fmt.Println(err)
@ -31,20 +32,12 @@ func WriteUsage(w io.Writer, dest ...interface{}) error {
if err != nil { if err != nil {
return err return err
} }
writeUsage(w, spec) writeUsage(w, filepath.Base(os.Args[0]), spec)
return nil return nil
} }
func synopsis(spec *spec, form string) string {
if spec.dest.Kind() == reflect.Bool {
return form
} else {
return form + " " + strings.ToUpper(spec.long)
}
}
// writeUsage writes usage information to the given writer // writeUsage writes usage information to the given writer
func writeUsage(w io.Writer, specs []*spec) { func writeUsage(w io.Writer, cmd string, specs []*spec) {
var positionals, options []*spec var positionals, options []*spec
for _, spec := range specs { for _, spec := range specs {
if spec.positional { if spec.positional {
@ -54,7 +47,7 @@ func writeUsage(w io.Writer, specs []*spec) {
} }
} }
fmt.Fprint(w, "usage: ") fmt.Fprint(w, "usage: %s ", cmd)
// write the option component of the one-line usage message // write the option component of the one-line usage message
for _, spec := range options { for _, spec := range options {
@ -110,3 +103,11 @@ func writeUsage(w io.Writer, specs []*spec) {
} }
} }
} }
func synopsis(spec *spec, form string) string {
if spec.dest.Kind() == reflect.Bool {
return form
} else {
return form + " " + strings.ToUpper(spec.long)
}
}