diff --git a/command/auth_disable.go b/command/auth_disable.go index 621ce5907c..14cf092269 100644 --- a/command/auth_disable.go +++ b/command/auth_disable.go @@ -4,66 +4,88 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*AuthDisableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthDisableCommand)(nil) + // AuthDisableCommand is a Command that enables a new endpoint. type AuthDisableCommand struct { - meta.Meta -} - -func (c *AuthDisableCommand) Run(args []string) int { - flags := c.Meta.FlagSet("auth-disable", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nauth-disable expects one argument: the path to disable.")) - return 1 - } - - path := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().DisableAuth(path); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Disabled auth provider at path '%s' if it was enabled", path)) - - return 0 + *BaseCommand } func (c *AuthDisableCommand) Synopsis() string { - return "Disable an auth provider" + return "Disables an auth provider" } func (c *AuthDisableCommand) Help() string { helpText := ` -Usage: vault auth-disable [options] path +Usage: vault auth-disable [options] PATH - Disable an already-enabled auth provider. + Disables an existing authentication provider at the given PATH. The argument + corresponds to the PATH of the mount, not the TYPE!. Once the auth provider + is disabled its path can no longer be used to authenticate. All access tokens + generated via the disabled auth provider are revoked. - Once the auth provider is disabled its path can no longer be used - to authenticate. All access tokens generated via the disabled auth provider - will be revoked. This command will block until all tokens are revoked. - If the command is exited early the tokens will still be revoked. + This command will block until all tokens are revoked. + + Disable the authentication provider at userpass/: + + $ vault auth-disable userpass + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *AuthDisableCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *AuthDisableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAuths() +} + +func (c *AuthDisableCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuthDisableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + path := ensureTrailingSlash(sanitizePath(args[0])) + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().DisableAuth(path); err != nil { + c.UI.Error(fmt.Sprintf("Error disabling auth at %s: %s", path, err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Disabled the auth provider (if it existed) at: %s", path)) + return 0 +} diff --git a/command/auth_disable_test.go b/command/auth_disable_test.go index fb2b91fb23..86b1f90560 100644 --- a/command/auth_disable_test.go +++ b/command/auth_disable_test.go @@ -1,102 +1,133 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuthDisable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuthDisableCommand(tb testing.TB) (*cli.MockUi, *AuthDisableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuthDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuthDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - "noop", - } - - // Run the command once to setup the client, it will fail - c.Run(args) - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := client.Sys().EnableAuth("noop", "noop", ""); err != nil { - t.Fatalf("err: %s", err) - } - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - mounts, err := client.Sys().ListAuth() - if err != nil { - t.Fatalf("err: %s", err) - } - - if _, ok := mounts["noop"]; ok { - t.Fatal("should not have noop mount") - } } -func TestAuthDisableWithOptions(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestAuthDisableCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &AuthDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - "noop", - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run the command once to setup the client, it will fail - c.Run(args) + for _, tc := range cases { + tc := tc - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - if err := client.Sys().EnableAuthWithOptions("noop", &api.EnableAuthOptions{ - Type: "noop", - Description: "", - }); err != nil { - t.Fatalf("err: %#v", err) - } + ui, cmd := testAuthDisableCommand(t) - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } - mounts, err := client.Sys().ListAuth() - if err != nil { - t.Fatalf("err: %s", err) - } + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) - if _, ok := mounts["noop"]; ok { - t.Fatal("should not have noop mount") - } + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("my-auth", "userpass", ""); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuthDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-auth", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Disabled the auth provider" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + if auth, ok := auths["my-auth/"]; ok { + t.Errorf("expected auth to be disabled: %#v", auth) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-auth", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error disabling auth at my-auth/: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthDisableCommand(t) + assertNoTabs(t, cmd) + }) }