go-arg/parse.go

505 lines
12 KiB
Go
Raw Normal View History

2015-10-31 20:26:58 -05:00
package arg
2015-10-31 18:15:24 -05:00
import (
2017-02-15 20:37:19 -06:00
"encoding"
2015-11-01 01:57:26 -05:00
"errors"
2015-10-31 18:15:24 -05:00
"fmt"
"os"
"path/filepath"
2015-10-31 18:15:24 -05:00
"reflect"
"strings"
2017-02-15 20:19:41 -06:00
scalar "github.com/alexflint/go-scalar"
2015-10-31 18:15:24 -05:00
)
2015-10-31 20:26:58 -05:00
// spec describes one command line option or positional argument together
// with the struct field that receives its value.
type spec struct {
	dest       reflect.Value // field the parsed value is stored into
	long       string        // long name without "--"; defaults to the lower-cased field name
	short      string        // short name without "-"; empty when none was declared
	multiple   bool          // field is a slice, so several values are collected
	required   bool          // parsing fails if no value is supplied
	positional bool          // value comes from a positional argument, not a flag
	separate   bool          // each occurrence of the option contributes a single value

	help string // help text for this option
	env  string // environment variable consulted before the command line

	wasPresent bool // set once a value was supplied (command line or environment)
	boolean    bool // option takes no value; giving it alone means "true"
}
// Sentinel errors returned by (*Parser).Parse when the user asks for
// information instead of running the program.
var (
	// ErrHelp indicates that -h or --help were provided
	ErrHelp = errors.New("help requested by user")

	// ErrVersion indicates that --version was provided
	ErrVersion = errors.New("version requested by user")
)
// MustParse processes command line arguments and exits upon failure
2016-01-05 15:52:33 -06:00
func MustParse(dest ...interface{}) *Parser {
p, err := NewParser(Config{}, dest...)
2015-11-01 01:57:26 -05:00
if err != nil {
fmt.Println(err)
os.Exit(-1)
2015-11-01 01:57:26 -05:00
}
2017-02-09 17:12:33 -06:00
err = p.Parse(flags())
if err == ErrHelp {
p.WriteHelp(os.Stdout)
os.Exit(0)
}
2016-09-08 23:18:19 -05:00
if err == ErrVersion {
fmt.Println(p.version)
os.Exit(0)
}
2015-10-31 18:15:24 -05:00
if err != nil {
p.Fail(err.Error())
2015-10-31 18:15:24 -05:00
}
2016-01-05 15:52:33 -06:00
return p
2015-10-31 18:15:24 -05:00
}
// Parse processes command line arguments and stores them in dest
2015-10-31 20:26:58 -05:00
func Parse(dest ...interface{}) error {
p, err := NewParser(Config{}, dest...)
2015-11-01 01:57:26 -05:00
if err != nil {
return err
2015-11-01 01:57:26 -05:00
}
2017-02-09 17:12:33 -06:00
return p.Parse(flags())
}
// flags returns the command line arguments that follow the program name.
// It returns nil when os.Args is empty, which can happen, so it must not be
// sliced blindly.
func flags() []string {
	if len(os.Args) > 0 {
		return os.Args[1:]
	}
	return nil
}
// Config represents configuration options for an argument parser.
type Config struct {
	// Program is the name of the program used in the help text. If left
	// empty, NewParser uses the base name of os.Args[0] (or "program" when
	// os.Args is empty).
	Program string
}
2015-11-01 01:57:26 -05:00
// Parser represents a set of command line options with destination values
type Parser struct {
2017-01-23 19:41:12 -06:00
spec []*spec
config Config
version string
description string
2016-09-08 23:18:19 -05:00
}
// Versioned is implemented by destination structs that want a version string
// shown at the top of the help message (MustParse also prints it when
// --version is given).
type Versioned interface {
	// Version returns the version string, printed on a line by itself at the
	// top of the help message.
	Version() string
}
2015-10-31 18:15:24 -05:00
2017-01-23 19:41:12 -06:00
// Described is implemented by destination structs that want a description
// shown at the top of the help message.
type Described interface {
	// Description returns the text printed on a line by itself at the top of
	// the help message.
	Description() string
}
2016-10-09 19:18:28 -05:00
// walkFields calls visit for every field of the struct held in v, in
// declaration order. visit receives the field, its value and the type that
// owns it. When visit returns true and the field is itself a struct, that
// struct's fields are walked in turn (depth-first) before moving on.
func walkFields(v reflect.Value, visit func(field reflect.StructField, val reflect.Value, owner reflect.Type) bool) {
	typ := v.Type()
	for idx := 0; idx < typ.NumField(); idx++ {
		f, fv := typ.Field(idx), v.Field(idx)
		if !visit(f, fv, typ) {
			continue
		}
		if f.Type.Kind() == reflect.Struct {
			walkFields(fv, visit)
		}
	}
}
2015-11-01 01:57:26 -05:00
// NewParser constructs a parser from a list of destination structs
func NewParser(config Config, dests ...interface{}) (*Parser, error) {
2016-09-08 23:18:19 -05:00
p := Parser{
config: config,
}
2015-10-31 20:26:58 -05:00
for _, dest := range dests {
2016-09-08 23:18:19 -05:00
if dest, ok := dest.(Versioned); ok {
p.version = dest.Version()
}
2017-01-23 19:41:12 -06:00
if dest, ok := dest.(Described); ok {
p.description = dest.Description()
}
2015-10-31 20:26:58 -05:00
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", v.Type()))
2015-10-31 18:15:24 -05:00
}
2015-10-31 20:26:58 -05:00
v = v.Elem()
if v.Kind() != reflect.Struct {
panic(fmt.Sprintf("%T is not a struct pointer", dest))
2015-10-31 18:15:24 -05:00
}
2016-10-09 19:18:28 -05:00
var errs []string
walkFields(v, func(field reflect.StructField, val reflect.Value, t reflect.Type) bool {
2015-10-31 20:26:58 -05:00
// Check for the ignore switch in the tag
tag := field.Tag.Get("arg")
if tag == "-" {
2016-10-09 19:18:28 -05:00
return false
}
// If this is an embedded struct then recurse into its fields
if field.Anonymous && field.Type.Kind() == reflect.Struct {
return true
2015-10-31 18:15:24 -05:00
}
2015-10-31 20:26:58 -05:00
spec := spec{
2016-10-09 19:18:28 -05:00
long: strings.ToLower(field.Name),
dest: val,
2015-10-31 20:26:58 -05:00
}
2015-10-31 18:15:24 -05:00
help, exists := field.Tag.Lookup("help")
if exists {
spec.help = help
}
// Check whether this field is supported. It's good to do this here rather than
2018-04-18 23:23:08 -05:00
// wait until ParseValue because it means that a program with invalid argument
2016-07-31 11:14:44 -05:00
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
var parseable bool
parseable, spec.boolean, spec.multiple = canParse(field.Type)
if !parseable {
2016-10-09 19:18:28 -05:00
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
}
2015-10-31 20:26:58 -05:00
// Look at the tag
if tag != "" {
for _, key := range strings.Split(tag, ",") {
2017-09-16 06:05:53 -05:00
key = strings.TrimLeft(key, " ")
2015-10-31 20:26:58 -05:00
var value string
if pos := strings.Index(key, ":"); pos != -1 {
value = key[pos+1:]
key = key[:pos]
}
2015-10-31 18:15:24 -05:00
2015-10-31 20:26:58 -05:00
switch {
2017-02-21 11:08:08 -06:00
case strings.HasPrefix(key, "---"):
errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name))
2015-10-31 20:26:58 -05:00
case strings.HasPrefix(key, "--"):
spec.long = key[2:]
case strings.HasPrefix(key, "-"):
if len(key) != 2 {
2016-10-09 19:18:28 -05:00
errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
t.Name(), field.Name))
return false
2015-10-31 20:26:58 -05:00
}
spec.short = key[1:]
case key == "required":
spec.required = true
case key == "positional":
spec.positional = true
case key == "separate":
spec.separate = true
case key == "help": // deprecated
2015-10-31 20:26:58 -05:00
spec.help = value
2016-01-18 12:42:04 -06:00
case key == "env":
// Use override name if provided
if value != "" {
spec.env = value
} else {
spec.env = strings.ToUpper(field.Name)
}
2015-10-31 20:26:58 -05:00
default:
2016-10-09 19:18:28 -05:00
errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
return false
2015-10-31 18:15:24 -05:00
}
}
}
2016-09-08 23:18:19 -05:00
p.spec = append(p.spec, &spec)
2016-10-09 19:18:28 -05:00
// if this was an embedded field then we already returned true up above
return false
})
if len(errs) > 0 {
return nil, errors.New(strings.Join(errs, "\n"))
2015-10-31 18:15:24 -05:00
}
}
2016-09-08 23:18:19 -05:00
if p.config.Program == "" {
p.config.Program = "program"
if len(os.Args) > 0 {
2016-09-08 23:18:19 -05:00
p.config.Program = filepath.Base(os.Args[0])
}
}
2016-09-08 23:18:19 -05:00
return &p, nil
}
// Parse processes the given command line option, storing the results in the field
// of the structs from which NewParser was constructed
func (p *Parser) Parse(args []string) error {
// If -h or --help were specified then print usage
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return ErrHelp
}
2016-09-08 23:18:19 -05:00
if arg == "--version" {
return ErrVersion
}
if arg == "--" {
break
}
}
// Process all command line arguments
err := process(p.spec, args)
if err != nil {
return err
}
// Validate
return validate(p.spec)
2015-10-31 18:15:24 -05:00
}
// process goes through arguments one-by-one, parses them, and assigns the result to
2015-11-01 01:57:26 -05:00
// the underlying struct field
func process(specs []*spec, args []string) error {
2015-10-31 19:05:14 -05:00
// construct a map from --option to spec
optionMap := make(map[string]*spec)
2015-10-31 18:15:24 -05:00
for _, spec := range specs {
2015-10-31 19:05:14 -05:00
if spec.positional {
continue
}
2015-10-31 18:15:24 -05:00
if spec.long != "" {
2015-10-31 19:05:14 -05:00
optionMap[spec.long] = spec
2015-10-31 18:15:24 -05:00
}
if spec.short != "" {
2015-10-31 19:05:14 -05:00
optionMap[spec.short] = spec
2015-10-31 18:15:24 -05:00
}
2016-01-18 12:42:04 -06:00
if spec.env != "" {
if value, found := os.LookupEnv(spec.env); found {
2018-04-18 23:23:08 -05:00
err := scalar.ParseValue(spec.dest, value)
2016-01-18 12:42:04 -06:00
if err != nil {
return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
}
spec.wasPresent = true
}
}
2015-10-31 18:15:24 -05:00
}
// process each string from the command line
var allpositional bool
var positionals []string
// must use explicit for loop, not range, because we manipulate i inside the loop
for i := 0; i < len(args); i++ {
arg := args[i]
if arg == "--" {
allpositional = true
continue
}
2017-02-21 11:08:08 -06:00
if !isFlag(arg) || allpositional {
2015-10-31 18:15:24 -05:00
positionals = append(positionals, arg)
continue
}
// check for an equals sign, as in "--foo=bar"
var value string
opt := strings.TrimLeft(arg, "-")
if pos := strings.Index(opt, "="); pos != -1 {
value = opt[pos+1:]
opt = opt[:pos]
}
// lookup the spec for this option
2015-10-31 19:05:14 -05:00
spec, ok := optionMap[opt]
2015-10-31 18:15:24 -05:00
if !ok {
return fmt.Errorf("unknown argument %s", arg)
}
spec.wasPresent = true
// deal with the case of multiple values
if spec.multiple {
var values []string
if value == "" {
2017-02-21 11:08:08 -06:00
for i+1 < len(args) && !isFlag(args[i+1]) {
2015-10-31 19:05:14 -05:00
values = append(values, args[i+1])
i++
if spec.separate {
break
}
2015-10-31 18:15:24 -05:00
}
} else {
values = append(values, value)
}
err := setSlice(spec.dest, values, !spec.separate)
2015-10-31 19:05:14 -05:00
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
2015-10-31 18:15:24 -05:00
continue
}
// if it's a flag and it has no value then set the value to true
// use boolean because this takes account of TextUnmarshaler
if spec.boolean && value == "" {
2015-10-31 18:15:24 -05:00
value = "true"
}
// if we have something like "--foo" then the value is the next argument
if value == "" {
2018-01-13 16:20:00 -06:00
if i+1 == len(args) {
return fmt.Errorf("missing value for %s", arg)
}
if !nextIsNumeric(spec.dest.Type(), args[i+1]) && isFlag(args[i+1]) {
2015-10-31 18:15:24 -05:00
return fmt.Errorf("missing value for %s", arg)
}
value = args[i+1]
i++
}
2018-04-18 23:23:08 -05:00
err := scalar.ParseValue(spec.dest, value)
2015-10-31 18:15:24 -05:00
if err != nil {
return fmt.Errorf("error processing %s: %v", arg, err)
}
}
2015-10-31 19:05:14 -05:00
// process positionals
for _, spec := range specs {
if spec.positional {
if spec.multiple {
2017-03-30 13:32:39 -05:00
if spec.required && len(positionals) == 0 {
return fmt.Errorf("%s is required", spec.long)
}
err := setSlice(spec.dest, positionals, true)
2015-10-31 19:05:14 -05:00
if err != nil {
2015-10-31 20:26:58 -05:00
return fmt.Errorf("error processing %s: %v", spec.long, err)
2015-10-31 19:05:14 -05:00
}
positionals = nil
} else if len(positionals) > 0 {
2018-04-18 23:23:08 -05:00
err := scalar.ParseValue(spec.dest, positionals[0])
2015-10-31 19:05:14 -05:00
if err != nil {
2015-10-31 20:26:58 -05:00
return fmt.Errorf("error processing %s: %v", spec.long, err)
2015-10-31 19:05:14 -05:00
}
positionals = positionals[1:]
} else if spec.required {
2015-10-31 20:26:58 -05:00
return fmt.Errorf("%s is required", spec.long)
2015-10-31 19:05:14 -05:00
}
}
}
if len(positionals) > 0 {
return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
}
2015-10-31 18:15:24 -05:00
return nil
}
2018-01-13 16:20:00 -06:00
// nextIsNumeric reports whether s can be parsed as a value of type t, where t,
// after stripping any pointer indirections, is an integer or floating-point
// kind. For every other kind it returns false. process uses it so that an
// option expecting a number can take a negative value such as "-5", which
// would otherwise look like a flag.
func nextIsNumeric(t reflect.Type, s string) bool {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	switch t.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
		return scalar.ParseValue(reflect.New(t), s) == nil
	}
	return false
}
2017-02-21 11:08:08 -06:00
// isFlag reports whether a token looks like a flag such as "-v" or "--user".
// Tokens made only of hyphens ("-", "--", ...) are not flags.
func isFlag(s string) bool {
	if !strings.HasPrefix(s, "-") {
		return false
	}
	return strings.TrimLeft(s, "-") != ""
}
2015-10-31 18:15:24 -05:00
// validate an argument spec after arguments have been parse
func validate(spec []*spec) error {
for _, arg := range spec {
2015-10-31 19:05:14 -05:00
if !arg.positional && arg.required && !arg.wasPresent {
2015-10-31 20:26:58 -05:00
return fmt.Errorf("--%s is required", arg.long)
2015-10-31 18:15:24 -05:00
}
}
return nil
}
2016-07-31 11:14:44 -05:00
// parse a value as the appropriate type and store it in the struct
func setSlice(dest reflect.Value, values []string, trunc bool) error {
2015-10-31 19:05:14 -05:00
if !dest.CanSet() {
return fmt.Errorf("field is not writable")
}
var ptr bool
elem := dest.Type().Elem()
2018-04-18 23:51:16 -05:00
if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) {
2015-10-31 19:05:14 -05:00
ptr = true
elem = elem.Elem()
}
// Truncate the dest slice in case default values exist
if trunc && !dest.IsNil() {
dest.SetLen(0)
}
2015-10-31 19:05:14 -05:00
for _, s := range values {
v := reflect.New(elem)
2018-04-18 23:23:08 -05:00
if err := scalar.ParseValue(v.Elem(), s); err != nil {
2015-10-31 19:05:14 -05:00
return err
}
2015-11-04 12:27:17 -06:00
if !ptr {
v = v.Elem()
2015-10-31 19:05:14 -05:00
}
2015-11-04 12:27:17 -06:00
dest.Set(reflect.Append(dest, v))
2015-10-31 19:05:14 -05:00
}
2015-10-31 18:15:24 -05:00
return nil
}
// canParse reports whether a field of type t can be filled from command line
// strings. It also reports whether the option is boolean (takes no value) and
// whether it collects multiple values.
//
// t is tried as-is. If that fails, one pointer level and then one slice level
// are unwrapped (marking the field as multiple when a slice is found) and the
// element is tried. Finally one more pointer level is unwrapped, which covers
// []*T. If nothing matches, all results are false.
func canParse(t reflect.Type) (parseable, boolean, multiple bool) {
	// check records the parse/boolean results for typ and reports success.
	check := func(typ reflect.Type) bool {
		parseable, boolean = scalar.CanParse(typ), isBoolean(typ)
		return parseable
	}

	if check(t) {
		return parseable, boolean, multiple
	}

	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() == reflect.Slice {
		multiple = true
		t = t.Elem()
	}
	if check(t) {
		return parseable, boolean, multiple
	}

	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if check(t) {
		return parseable, boolean, multiple
	}
	return false, false, false
}
2017-02-15 20:19:41 -06:00
2017-02-15 20:37:19 -06:00
// textUnmarshalerType is the interface type encoding.TextUnmarshaler, used to
// detect types that supply their own text decoding.
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

// isBoolean reports whether a field of type t is a boolean switch, i.e. a bool
// or *bool that does not implement encoding.TextUnmarshaler. Such options may
// be given without a value.
func isBoolean(t reflect.Type) bool {
	// A type with its own text decoding always needs an explicit value, even
	// if its underlying kind is bool.
	if t.Implements(textUnmarshalerType) {
		return false
	}
	if t.Kind() == reflect.Ptr {
		return t.Elem().Kind() == reflect.Bool
	}
	return t.Kind() == reflect.Bool
}