diff --git a/parse.go b/parse.go index a79db77..78cfa45 100644 --- a/parse.go +++ b/parse.go @@ -141,6 +141,9 @@ type Config struct { // Out is where help text, usage text, and failure messages are printed (defaults to os.Stdout) Out io.Writer + + // Environment is a map of environment variables to override those in the process environment, or provide values to those not in the process environment. + Environment map[string]string } // Parser represents a set of command line options with destination values @@ -531,7 +534,17 @@ func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error continue } - value, found := os.LookupEnv(spec.env) + var value string + var found bool + + if !p.config.IgnoreEnv { + value, found = os.LookupEnv(spec.env) + } + + if p.config.Environment != nil { + value, found = p.config.Environment[spec.env] + } + if !found { continue } @@ -584,7 +597,7 @@ func (p *Parser) process(args []string) error { copy(specs, curCmd.specs) // deal with environment vars - if !p.config.IgnoreEnv { + if !p.config.IgnoreEnv || p.config.Environment != nil { err := p.captureEnvVars(specs, wasPresent) if err != nil { return err @@ -640,7 +653,7 @@ func (p *Parser) process(args []string) error { } // capture environment vars for these new options - if !p.config.IgnoreEnv { + if !p.config.IgnoreEnv || p.config.Environment != nil { err := p.captureEnvVars(subcmd.specs, wasPresent) if err != nil { return err diff --git a/parse_test.go b/parse_test.go index c929b28..e437cdf 100644 --- a/parse_test.go +++ b/parse_test.go @@ -39,7 +39,13 @@ func parseWithEnv(tb testing.TB, cmdline string, env []string, dest interface{}) } func parseWithEnvErr(tb testing.TB, cmdline string, env []string, dest interface{}) (*Parser, error) { - p, err := NewParser(Config{}, dest) + tb.Helper() + return parseWithConfigEnvErr(tb, Config{}, cmdline, env, dest) +} + +func parseWithConfigEnvErr(tb testing.TB, config Config, cmdline string, env []string, dest interface{}) (*Parser, error) { + tb.Helper() + p, err := NewParser(config, dest) if err != nil { return nil, err } @@ -669,6 +675,24 @@ func TestEnvironmentVariable(t *testing.T) { assert.Equal(t, "bar", args.Foo) } +func TestEnvironmentVariableViaCustomEnvironment(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + _, err := parseWithConfigEnvErr(t, Config{Environment: map[string]string{"FOO": "bar"}}, "", nil, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + +func TestEnvironmentVariableOverriddenByCustomEnvironment(t *testing.T) { + var args struct { + Foo string `arg:"env"` + } + _, err := parseWithConfigEnvErr(t, Config{Environment: map[string]string{"FOO": "bar"}}, "", []string{"FOO=foo"}, &args) + require.NoError(t, err) + assert.Equal(t, "bar", args.Foo) +} + func TestEnvironmentVariableNotPresent(t *testing.T) { var args struct { NotPresent string `arg:"env"`