diff --git a/parse.go b/parse.go index 4f62c60..ecbb7e4 100644 --- a/parse.go +++ b/parse.go @@ -356,6 +356,9 @@ func process(specs []*spec, args []string) error { for _, spec := range specs { if spec.positional { if spec.multiple { + if spec.required && len(positionals) == 0 { + return fmt.Errorf("%s is required", spec.long) + } err := setSlice(spec.dest, positionals, true) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) diff --git a/parse_test.go b/parse_test.go index 267e57c..a646f2b 100644 --- a/parse_test.go +++ b/parse_test.go @@ -250,6 +250,15 @@ func TestRequiredPositional(t *testing.T) { assert.Error(t, err) } +func TestRequiredPositionalMultiple(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Multiple []string `arg:"positional,required"` + } + err := parse("foo", &args) + assert.Error(t, err) +} + func TestTooManyPositional(t *testing.T) { var args struct { Input string `arg:"positional"` @@ -270,6 +279,17 @@ func TestMultiple(t *testing.T) { assert.Equal(t, []string{"x", "y", "z"}, args.Bar) } +func TestMultiplePositionals(t *testing.T) { + var args struct { + Input string `arg:"positional"` + Multiple []string `arg:"positional,required"` + } + err := parse("foo a b c", &args) + assert.NoError(t, err) + assert.Equal(t, "foo", args.Input) + assert.Equal(t, []string{"a", "b", "c"}, args.Multiple) +} + func TestMultipleWithEq(t *testing.T) { var args struct { Foo []int @@ -321,6 +341,14 @@ func TestMissingRequired(t *testing.T) { assert.Error(t, err) } +func TestMissingRequiredMultiplePositional(t *testing.T) { + var args struct { + X []string `arg:"positional, required"` + } + err := parse("x", &args) + assert.Error(t, err) +} + func TestMissingValue(t *testing.T) { var args struct { Foo string diff --git a/usage.go b/usage.go index bf7fb83..4652b36 100644 --- a/usage.go +++ b/usage.go @@ -54,7 +54,13 @@ func (p *Parser) WriteUsage(w io.Writer) { fmt.Fprint(w, " ") up := strings.ToUpper(spec.long) if spec.multiple { - fmt.Fprintf(w, "[%s [%s ...]]", up, up) + if !spec.required { + fmt.Fprint(w, "[") + } + fmt.Fprintf(w, "%s [%s ...]", up, up) + if !spec.required { + fmt.Fprint(w, "]") + } } else { fmt.Fprint(w, up) } diff --git a/usage_test.go b/usage_test.go index bf78a80..1bb1071 100644 --- a/usage_test.go +++ b/usage_test.go @@ -157,3 +157,25 @@ Options: t.Fail() } } + +func TestRequiredMultiplePositionals(t *testing.T) { + expectedHelp := `Usage: example REQUIREDMULTIPLE [REQUIREDMULTIPLE ...] + +Positional arguments: + REQUIREDMULTIPLE required multiple positional + +Options: + --help, -h display this help and exit +` + var args struct { + RequiredMultiple []string `arg:"positional,required,help:required multiple positional"` + } + + p, err := NewParser(Config{}, &args) + require.NoError(t, err) + + os.Args[0] = "example" + var help bytes.Buffer + p.WriteHelp(&help) + assert.Equal(t, expectedHelp, help.String()) +}