Update policy-list command

This commit is contained in:
Seth Vargo
2017-09-05 00:03:21 -04:00
parent eece6eea4a
commit cfd378187a
2 changed files with 212 additions and 105 deletions

View File

@@ -4,89 +4,102 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/meta"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
// Ensure we are implementing the right interfaces.
var _ cli.Command = (*PolicyListCommand)(nil)
var _ cli.CommandAutocomplete = (*PolicyListCommand)(nil)
// PolicyListCommand is a Command that enables a new endpoint.
type PolicyListCommand struct {
meta.Meta
}
func (c *PolicyListCommand) Run(args []string) int {
flags := c.Meta.FlagSet("policy-list", meta.FlagSetDefault)
flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil {
return 1
}
args = flags.Args()
if len(args) == 1 {
return c.read(args[0])
} else if len(args) == 0 {
return c.list()
} else {
flags.Usage()
c.Ui.Error(fmt.Sprintf(
"\npolicies expects zero or one arguments"))
return 1
}
}
func (c *PolicyListCommand) list() int {
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing client: %s", err))
return 2
}
policies, err := client.Sys().ListPolicies()
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error: %s", err))
return 1
}
for _, p := range policies {
c.Ui.Output(p)
}
return 0
}
func (c *PolicyListCommand) read(n string) int {
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing client: %s", err))
return 2
}
rules, err := client.Sys().GetPolicy(n)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error: %s", err))
return 1
}
c.Ui.Output(rules)
return 0
*BaseCommand
}
func (c *PolicyListCommand) Synopsis() string {
return "List the policies on the server"
return "Lists the installed policies"
}
func (c *PolicyListCommand) Help() string {
helpText := `
Usage: vault policies [options] [name]
Usage: vault policies [options] [NAME]
List the policies that are available or read a single policy.
Lists the policies that are installed on the Vault server. If the optional
argument is given, this command returns the policy's contents.
This command lists the policies that are written to the Vault server.
If a name of a policy is specified, that policy is outputted.
List all policies stored in Vault:
$ vault policies
Read the contents of the policy named "my-policy":
$ vault policies my-policy
For a full list of examples, please see the documentation.
` + c.Flags().Help()
General Options:
` + meta.GeneralOptionsUsage()
return strings.TrimSpace(helpText)
}
func (c *PolicyListCommand) Flags() *FlagSets {
return c.flagSet(FlagSetHTTP)
}
func (c *PolicyListCommand) AutocompleteArgs() complete.Predictor {
return c.PredictVaultPolicies()
}
func (c *PolicyListCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *PolicyListCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
args = f.Args()
switch len(args) {
case 0, 1:
default:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 0-2, got %d)", len(args)))
return 1
}
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
switch len(args) {
case 0:
policies, err := client.Sys().ListPolicies()
if err != nil {
c.UI.Error(fmt.Sprintf("Error listing policies: %s", err))
return 2
}
for _, p := range policies {
c.UI.Output(p)
}
case 1:
name := strings.ToLower(strings.TrimSpace(args[0]))
rules, err := client.Sys().GetPolicy(name)
if err != nil {
c.UI.Error(fmt.Sprintf("Error reading policy %s: %s", name, err))
return 2
}
if rules == "" {
c.UI.Error(fmt.Sprintf("Error reading policy: no policy named: %s", name))
return 2
}
c.UI.Output(strings.TrimSpace(rules))
}
return 0
}

View File

@@ -1,53 +1,147 @@
package command
import (
"strings"
"testing"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
func TestPolicyList(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
func testPolicyListCommand(tb testing.TB) (*cli.MockUi, *PolicyListCommand) {
tb.Helper()
ui := new(cli.MockUi)
c := &PolicyListCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
ui := cli.NewMockUi()
return ui, &PolicyListCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
}
args := []string{
"-address", addr,
}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
}
func TestPolicyRead(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
func TestPolicyListCommand_Run(t *testing.T) {
t.Parallel()
ui := new(cli.MockUi)
c := &PolicyListCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
cases := []struct {
name string
args []string
out string
code int
}{
{
"too_many_args",
[]string{"foo", "bar"},
"Too many arguments",
1,
},
{
"no_policy_exists",
[]string{"not-a-real-policy"},
"no policy named",
2,
},
}
args := []string{
"-address", addr,
"root",
}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
t.Run("validations", func(t *testing.T) {
t.Parallel()
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testPolicyListCommand(t)
cmd.client = client
code := cmd.Run(tc.args)
if code != tc.code {
t.Errorf("expected %d to be %d", code, tc.code)
}
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, tc.out) {
t.Errorf("expected %q to contain %q", combined, tc.out)
}
})
}
})
t.Run("list", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testPolicyListCommand(t)
cmd.client = client
code := cmd.Run([]string{})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
expected := "default\nroot"
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
}
})
t.Run("read", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
policy := `path "secret/" {}`
if err := client.Sys().PutPolicy("my-policy", policy); err != nil {
t.Fatal(err)
}
ui, cmd := testPolicyListCommand(t)
cmd.client = client
code := cmd.Run([]string{
"my-policy",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, policy) {
t.Errorf("expected %q to contain %q", combined, policy)
}
})
t.Run("communication_failure", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServerBad(t)
defer closer()
ui, cmd := testPolicyListCommand(t)
cmd.client = client
code := cmd.Run([]string{})
if exp := 2; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
expected := "Error listing policies: "
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
}
})
t.Run("no_tabs", func(t *testing.T) {
t.Parallel()
_, cmd := testPolicyListCommand(t)
assertNoTabs(t, cmd)
})
}