Add duration flag

This commit is contained in:
Eyal Posener 2019-11-27 21:33:55 +02:00
parent 8a431c416e
commit 85f1fe4e1d
2 changed files with 97 additions and 6 deletions

View File

@ -39,6 +39,7 @@ import (
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
"time"
"github.com/posener/complete/v2" "github.com/posener/complete/v2"
"github.com/posener/complete/v2/predict" "github.com/posener/complete/v2/predict"
@ -75,6 +76,12 @@ func (fs *FlagSet) Int(name string, value int, usage string, options ...predict.
return p return p
} }
func (fs *FlagSet) Duration(name string, value time.Duration, usage string, options ...predict.Option) *time.Duration {
p := new(time.Duration)
(*flag.FlagSet)(fs).Var(newDurationValue(value, p, predict.Options(options...)), name, usage)
return p
}
var CommandLine = (*FlagSet)(flag.CommandLine) var CommandLine = (*FlagSet)(flag.CommandLine)
// Parse parses command line arguments. It also performs bash completion when needed. // Parse parses command line arguments. It also performs bash completion when needed.
@ -95,6 +102,10 @@ func Int(name string, value int, usage string, options ...predict.Option) *int {
return CommandLine.Int(name, value, usage, options...) return CommandLine.Int(name, value, usage, options...)
} }
func Duration(name string, value time.Duration, usage string, options ...predict.Option) *time.Duration {
return CommandLine.Duration(name, value, usage, options...)
}
type boolValue struct { type boolValue struct {
v *bool v *bool
predict.Config predict.Config
@ -114,13 +125,13 @@ func (b *boolValue) Set(val string) error {
return b.Check(val) return b.Check(val)
} }
func (b *boolValue) Get() interface{} { return bool(*b.v) } func (b *boolValue) Get() interface{} { return *b.v }
func (b *boolValue) String() string { func (b *boolValue) String() string {
if b == nil || b.v == nil { if b == nil || b.v == nil {
return strconv.FormatBool(false) return strconv.FormatBool(false)
} }
return strconv.FormatBool(bool(*b.v)) return strconv.FormatBool(*b.v)
} }
func (b *boolValue) IsBoolFlag() bool { return true } func (b *boolValue) IsBoolFlag() bool { return true }
@ -154,14 +165,14 @@ func (s *stringValue) Set(val string) error {
} }
func (s *stringValue) Get() interface{} { func (s *stringValue) Get() interface{} {
return string(*s.v) return *s.v
} }
func (s *stringValue) String() string { func (s *stringValue) String() string {
if s == nil || s.v == nil { if s == nil || s.v == nil {
return "" return ""
} }
return string(*s.v) return *s.v
} }
func (s *stringValue) Predict(prefix string) []string { func (s *stringValue) Predict(prefix string) []string {
@ -190,13 +201,13 @@ func (i *intValue) Set(val string) error {
return i.Check(val) return i.Check(val)
} }
func (i *intValue) Get() interface{} { return int(*i.v) } func (i *intValue) Get() interface{} { return *i.v }
func (i *intValue) String() string { func (i *intValue) String() string {
if i == nil || i.v == nil { if i == nil || i.v == nil {
return strconv.Itoa(0) return strconv.Itoa(0)
} }
return strconv.Itoa(int(*i.v)) return strconv.Itoa(*i.v)
} }
func (s *intValue) Predict(prefix string) []string { func (s *intValue) Predict(prefix string) []string {
@ -205,3 +216,38 @@ func (s *intValue) Predict(prefix string) []string {
} }
return []string{""} return []string{""}
} }
type durationValue struct {
v *time.Duration
predict.Config
}
func newDurationValue(val time.Duration, p *time.Duration, c predict.Config) *durationValue {
*p = val
return &durationValue{v: p, Config: c}
}
func (i *durationValue) Set(val string) error {
v, err := time.ParseDuration(val)
*i.v = v
if err != nil {
return fmt.Errorf("bad value for duration flag")
}
return i.Check(val)
}
func (i *durationValue) Get() interface{} { return *i.v }
func (i *durationValue) String() string {
if i == nil || i.v == nil {
return time.Duration(0).String()
}
return i.v.String()
}
func (s *durationValue) Predict(prefix string) []string {
if s.Predictor != nil {
return s.Predictor.Predict(prefix)
}
return []string{""}
}

View File

@ -3,6 +3,7 @@ package compflag
import ( import (
"flag" "flag"
"testing" "testing"
"time"
"github.com/posener/complete/v2" "github.com/posener/complete/v2"
"github.com/posener/complete/v2/predict" "github.com/posener/complete/v2/predict"
@ -104,3 +105,47 @@ func TestInt(t *testing.T) {
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1"}) complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1"})
}) })
} }
func TestDuration(t *testing.T) {
t.Parallel()
t.Run("options invalid not checked", func(t *testing.T) {
var cmd FlagSet
value := cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"))
err := cmd.Parse([]string{"-a", "1h"})
assert.NoError(t, err)
assert.Equal(t, time.Hour, *value)
})
t.Run("options valid checked", func(t *testing.T) {
var cmd FlagSet
value := cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"), predict.OptCheck())
err := cmd.Parse([]string{"-a", "1m"})
assert.NoError(t, err)
assert.Equal(t, time.Minute, *value)
})
t.Run("options invalid checked", func(t *testing.T) {
var cmd FlagSet
_ = cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"), predict.OptCheck())
err := cmd.Parse([]string{"-a", "1h"})
assert.Error(t, err)
})
t.Run("options invalid duration value", func(t *testing.T) {
var cmd FlagSet
_ = cmd.Duration("a", 0, "", predict.OptValues("1h", "1m", "1"), predict.OptCheck())
err := cmd.Parse([]string{"-a", "1"})
assert.Error(t, err)
})
t.Run("complete", func(t *testing.T) {
var cmd FlagSet
_ = cmd.Duration("a", 0, "", predict.OptValues("1s", "1m"))
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a ", []string{"1s", "1m"})
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=", []string{"1s", "1m"})
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a 1", []string{"1s", "1m"})
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1", []string{"1s", "1m"})
complete.Test(t, complete.FlagSet((*flag.FlagSet)(&cmd)), "-a=1m", []string{"1m"})
})
}