Add start of base command, flags, prediction

This commit is contained in:
Seth Vargo
2017-08-28 16:44:35 -04:00
parent 150e81f3f0
commit 7f6aa892a4
6 changed files with 1533 additions and 0 deletions

434
command/base.go Normal file
View File

@@ -0,0 +1,434 @@
package command
import (
"bufio"
"bytes"
"flag"
"fmt"
"io"
"regexp"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/token"
"github.com/kr/text"
"github.com/mitchellh/cli"
"github.com/pkg/errors"
"github.com/posener/complete"
)
// maxLineLength is the maximum width of any line.
const maxLineLength int = 78
// reRemoveWhitespace is a regular expression for stripping whitespace from
// a string.
var reRemoveWhitespace = regexp.MustCompile(`[\s]+`)
type TokenHelperFunc func() (token.TokenHelper, error)
type BaseCommand struct {
UI cli.Ui
flags *FlagSets
flagsOnce sync.Once
flagAddress string
flagCACert string
flagCAPath string
flagClientCert string
flagClientKey string
flagTLSServerName string
flagTLSSkipVerify bool
flagWrapTTL time.Duration
flagFormat string
flagField string
tokenHelper TokenHelperFunc
client *api.Client
clientErr error
clientOnce sync.Once
}
// Client returns the HTTP API client. The client is cached on the command to
// save performance on future calls.
func (c *BaseCommand) Client() (*api.Client, error) {
c.clientOnce.Do(func() {
// This should never happen in reality and is just for testing. Nothing
// should be setting the underlying client.
if c.client != nil {
return
}
config := api.DefaultConfig()
if err := config.ReadEnvironment(); err != nil {
c.clientErr = errors.Wrap(err, "failed to read environment")
return
}
if c.flagAddress != "" {
config.Address = c.flagAddress
}
// If we need custom TLS configuration, then set it
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
t := &api.TLSConfig{
CACert: c.flagCACert,
CAPath: c.flagCAPath,
ClientCert: c.flagClientCert,
ClientKey: c.flagClientKey,
TLSServerName: c.flagTLSServerName,
Insecure: c.flagTLSSkipVerify,
}
config.ConfigureTLS(t)
}
// Build the client
client, err := api.NewClient(config)
if err != nil {
c.clientErr = errors.Wrap(err, "failed to create client")
return
}
// Set the wrapping function
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
// Get the token if it came in from the environment
token := client.Token()
// If we don't have a token, check the token helper
if token == "" {
if c.tokenHelper != nil {
// If we have a token, then set that
tokenHelper, err := c.tokenHelper()
if err != nil {
c.clientErr = errors.Wrap(err, "failed to get token helper")
return
}
token, err = tokenHelper.Get()
if err != nil {
c.clientErr = errors.Wrap(err, "failed to retrieve from token helper")
return
}
}
}
// Set the token
if token != "" {
client.SetToken(token)
}
c.client = client
})
return c.client, c.clientErr
}
// DefaultWrappingLookupFunc is the default wrapping function based on the
// CLI flag.
func (c *BaseCommand) DefaultWrappingLookupFunc(operation, path string) string {
if c.flagWrapTTL != 0 {
return c.flagWrapTTL.String()
}
return api.DefaultWrappingLookupFunc(operation, path)
}
type FlagSetBit uint
const (
FlagSetNone FlagSetBit = 1 << iota
FlagSetHTTP
FlagSetOutputField
FlagSetOutputFormat
)
// flagSet creates the flags for this command. The result is cached on the
// command to save performance on future calls.
func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
c.flagsOnce.Do(func() {
set := NewFlagSets(c.UI)
if bit&FlagSetHTTP != 0 {
f := set.NewFlagSet("HTTP Options")
f.StringVar(&StringVar{
Name: "address",
Target: &c.flagAddress,
Default: "https://127.0.0.1:8200",
EnvVar: "VAULT_ADDR",
Completion: complete.PredictAnything,
Usage: "Address of the Vault server.",
})
f.StringVar(&StringVar{
Name: "ca-cert",
Target: &c.flagCACert,
Default: "",
EnvVar: "VAULT_CACERT",
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to verify the Vault server's SSL certificate. This " +
"takes precendence over -ca-path.",
})
f.StringVar(&StringVar{
Name: "ca-path",
Target: &c.flagCAPath,
Default: "",
EnvVar: "VAULT_CAPATH",
Completion: complete.PredictDirs("*"),
Usage: "Path on the local disk to a directory of PEM-encoded CA " +
"certificates to verify the Vault server's SSL certificate.",
})
f.StringVar(&StringVar{
Name: "client-cert",
Target: &c.flagClientCert,
Default: "",
EnvVar: "VAULT_CLIENT_CERT",
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to use for TLS authentication to the Vault server. If " +
"this flag is specified, -client-key is also required.",
})
f.StringVar(&StringVar{
Name: "client-key",
Target: &c.flagClientKey,
Default: "",
EnvVar: "VAULT_CLIENT_KEY",
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded private key " +
"matching the client certificate from -client-cert.",
})
f.StringVar(&StringVar{
Name: "tls-server-name",
Target: &c.flagTLSServerName,
Default: "",
EnvVar: "VAULT_TLS_SERVER_NAME",
Completion: complete.PredictAnything,
Usage: "Name to use as the SNI host when connecting to the Vault " +
"server via TLS.",
})
f.BoolVar(&BoolVar{
Name: "tls-skip-verify",
Target: &c.flagTLSSkipVerify,
Default: false,
EnvVar: "VAULT_SKIP_VERIFY",
Completion: complete.PredictNothing,
Usage: "Disable verification of TLS certificates. Using this option " +
"is highly discouraged and decreases the security of data " +
"transmissions to and from the Vault server.",
})
f.DurationVar(&DurationVar{
Name: "wrap-ttl",
Target: &c.flagWrapTTL,
Default: 0,
EnvVar: "VAULT_WRAP_TTL",
Completion: complete.PredictAnything,
Usage: "Wraps the response in a cubbyhole token with the requested " +
"TTL. The response is available via the \"vault unwrap\" command. " +
"The TTL is specified as a numeric string with suffix like \"30s\" " +
"or \"5m\"",
})
}
if bit&(FlagSetOutputField|FlagSetOutputFormat) != 0 {
f := set.NewFlagSet("Output Options")
if bit&FlagSetOutputField != 0 {
f.StringVar(&StringVar{
Name: "field",
Target: &c.flagField,
Default: "",
EnvVar: "",
Completion: complete.PredictAnything,
Usage: "Print only the field with the given name. Specifying " +
"this option will take precedence over other formatting " +
"directives. The result will not have a trailing newline " +
"making it idea for piping to other processes.",
})
}
if bit&FlagSetOutputFormat != 0 {
f.StringVar(&StringVar{
Name: "format",
Target: &c.flagFormat,
Default: "table",
EnvVar: "VAULT_FORMAT",
Completion: complete.PredictSet("table", "json", "yaml"),
Usage: "Print the output in the given format. Valid formats " +
"are \"table\", \"json\", or \"yaml\".",
})
}
}
c.flags = set
})
return c.flags
}
// printFlagTitle prints a consistently-formatted title to the given writer.
func printFlagTitle(w io.Writer, s string) {
fmt.Fprintf(w, "%s\n\n", s)
}
// printFlagDetail prints a single flag to the given writer.
func printFlagDetail(w io.Writer, f *flag.Flag) {
example := ""
if t, ok := f.Value.(FlagExample); ok {
example = t.Example()
}
if example != "" {
fmt.Fprintf(w, " -%s=<%s>\n", f.Name, example)
} else {
fmt.Fprintf(w, " -%s\n", f.Name)
}
usage := reRemoveWhitespace.ReplaceAllString(f.Usage, " ")
indented := wrapAtLength(usage, 6)
fmt.Fprintf(w, "%s\n\n", indented)
}
// wrapAtLength wraps the given text at the maxLineLength, taking into account
// any provided left padding.
func wrapAtLength(s string, pad int) string {
wrapped := text.Wrap(s, maxLineLength-pad)
lines := strings.Split(wrapped, "\n")
for i, line := range lines {
lines[i] = strings.Repeat(" ", pad) + line
}
return strings.Join(lines, "\n")
}
// FlagSets is a group of flag sets.
type FlagSets struct {
flagSets []*FlagSet
mainSet *flag.FlagSet
hiddens map[string]struct{}
completions complete.Flags
}
// NewFlagSets creates a new flag sets.
func NewFlagSets(ui cli.Ui) *FlagSets {
mainSet := flag.NewFlagSet("", flag.ContinueOnError)
mainSet.Usage = func() {}
// Pull errors from the flagset into the ui's error
errR, errW := io.Pipe()
errScanner := bufio.NewScanner(errR)
go func() {
for errScanner.Scan() {
ui.Error(errScanner.Text())
}
}()
mainSet.SetOutput(errW)
return &FlagSets{
flagSets: make([]*FlagSet, 0, 6),
mainSet: mainSet,
hiddens: make(map[string]struct{}),
completions: complete.Flags{},
}
}
// NewFlagSet creates a new flag set from the given flag sets.
func (f *FlagSets) NewFlagSet(name string) *FlagSet {
flagSet := NewFlagSet(name)
f.AddFlagSet(flagSet)
return flagSet
}
// AddFlagSet adds a new flag set to this flag set.
func (f *FlagSets) AddFlagSet(set *FlagSet) {
set.mainSet = f.mainSet
set.completions = f.completions
f.flagSets = append(f.flagSets, set)
}
func (f *FlagSets) Completions() complete.Flags {
return f.completions
}
// Parse parses the given flags, returning any errors.
func (f *FlagSets) Parse(args []string) error {
return f.mainSet.Parse(args)
}
// Args returns the remaining args after parsing.
func (f *FlagSets) Args() []string {
return f.mainSet.Args()
}
// HideFlag excludes the flag from the list of flags to print in help. This is
// useful when you want to include a flag in parsing for deprecations/bc, but
// you don't want to include it in help output.
func (f *FlagSets) HideFlag(n string) {
if _, ok := f.hiddens[n]; !ok {
f.hiddens[n] = struct{}{}
}
}
// HiddenFlag returns true if the flag with the given name is hidden.
func (f *FlagSets) HiddenFlag(n string) bool {
_, ok := f.hiddens[n]
return ok
}
// Help builds custom help for this command, grouping by flag set.
func (fs *FlagSets) Help() string {
var out bytes.Buffer
for _, set := range fs.flagSets {
printFlagTitle(&out, set.name+":")
set.VisitAll(func(f *flag.Flag) {
// Skip any hidden flags
if fs.HiddenFlag(f.Name) {
return
}
printFlagDetail(&out, f)
})
}
return strings.TrimRight(out.String(), "\n")
}
// FlagSet is a grouped wrapper around a real flag set and a grouped flag set.
type FlagSet struct {
name string
flagSet *flag.FlagSet
mainSet *flag.FlagSet
completions complete.Flags
}
// NewFlagSet creates a new flag set.
func NewFlagSet(name string) *FlagSet {
return &FlagSet{
name: name,
flagSet: flag.NewFlagSet(name, flag.ContinueOnError),
}
}
// Name returns the name of this flag set.
func (f *FlagSet) Name() string {
return f.name
}
func (f *FlagSet) Visit(fn func(*flag.Flag)) {
f.flagSet.Visit(fn)
}
func (f *FlagSet) VisitAll(fn func(*flag.Flag)) {
f.flagSet.VisitAll(fn)
}

506
command/base_flags.go Normal file
View File

@@ -0,0 +1,506 @@
package command
import (
"flag"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/posener/complete"
)
// FlagExample is an interface which declares an example value.
type FlagExample interface {
Example() string
}
type BoolVar struct {
Name string
Aliases []string
Usage string
Default bool
EnvVar string
Target *bool
Completion complete.Predictor
}
func (f *FlagSet) BoolVar(i *BoolVar) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if b, err := strconv.ParseBool(v); err != nil {
def = b
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: strconv.FormatBool(i.Default),
EnvVar: i.EnvVar,
Value: newBoolValue(def, i.Target),
Completion: i.Completion,
})
}
type IntVar struct {
Name string
Aliases []string
Usage string
Default int
EnvVar string
Target *int
Completion complete.Predictor
}
func (f *FlagSet) IntVar(i *IntVar) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if i, err := strconv.ParseInt(v, 0, 64); err != nil {
def = int(i)
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: strconv.FormatInt(int64(i.Default), 10),
EnvVar: i.EnvVar,
Value: newIntValue(def, i.Target),
Completion: i.Completion,
})
}
type Int64Var struct {
Name string
Aliases []string
Usage string
Default int64
EnvVar string
Target *int64
Completion complete.Predictor
}
func (f *FlagSet) Int64Var(i *Int64Var) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if i, err := strconv.ParseInt(v, 0, 64); err != nil {
def = i
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: strconv.FormatInt(i.Default, 10),
EnvVar: i.EnvVar,
Value: newInt64Value(def, i.Target),
Completion: i.Completion,
})
}
type UintVar struct {
Name string
Aliases []string
Usage string
Default uint
EnvVar string
Target *uint
Completion complete.Predictor
}
func (f *FlagSet) UintVar(i *UintVar) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if i, err := strconv.ParseUint(v, 0, 64); err != nil {
def = uint(i)
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: strconv.FormatUint(uint64(i.Default), 10),
EnvVar: i.EnvVar,
Value: newUintValue(def, i.Target),
Completion: i.Completion,
})
}
type Uint64Var struct {
Name string
Aliases []string
Usage string
Default uint64
EnvVar string
Target *uint64
Completion complete.Predictor
}
func (f *FlagSet) Uint64Var(i *Uint64Var) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if i, err := strconv.ParseUint(v, 0, 64); err != nil {
def = i
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: strconv.FormatUint(i.Default, 10),
EnvVar: i.EnvVar,
Value: newUint64Value(def, i.Target),
Completion: i.Completion,
})
}
type StringVar struct {
Name string
Aliases []string
Usage string
Default string
EnvVar string
Target *string
Completion complete.Predictor
}
func (f *FlagSet) StringVar(i *StringVar) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
def = v
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: i.Default,
EnvVar: i.EnvVar,
Value: newStringValue(def, i.Target),
Completion: i.Completion,
})
}
type Float64Var struct {
Name string
Aliases []string
Usage string
Default float64
EnvVar string
Target *float64
Completion complete.Predictor
}
func (f *FlagSet) Float64Var(i *Float64Var) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if i, err := strconv.ParseFloat(v, 64); err != nil {
def = i
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: strconv.FormatFloat(i.Default, 'e', -1, 64),
EnvVar: i.EnvVar,
Value: newFloat64Value(def, i.Target),
Completion: i.Completion,
})
}
type DurationVar struct {
Name string
Aliases []string
Usage string
Default time.Duration
EnvVar string
Target *time.Duration
Completion complete.Predictor
}
func (f *FlagSet) DurationVar(i *DurationVar) {
def := i.Default
if v := os.Getenv(i.EnvVar); v != "" {
if d, err := time.ParseDuration(v); err != nil {
def = d
}
}
f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: i.Default.String(),
EnvVar: i.EnvVar,
Value: newDurationValue(def, i.Target),
Completion: i.Completion,
})
}
type VarFlag struct {
Name string
Aliases []string
Usage string
Default string
EnvVar string
Value flag.Value
Completion complete.Predictor
}
func (f *FlagSet) VarFlag(i *VarFlag) {
// Calculate the full usage
usage := i.Usage
if len(i.Aliases) > 0 {
sentence := make([]string, len(i.Aliases))
for i, a := range i.Aliases {
sentence[i] = fmt.Sprintf(`"-%s"`, a)
}
aliases := ""
switch len(sentence) {
case 0:
// impossible...
case 1:
aliases = sentence[0]
case 2:
aliases = sentence[0] + " and " + sentence[1]
default:
sentence[len(sentence)-1] = "and " + sentence[len(sentence)-1]
aliases = strings.Join(sentence, ", ")
}
usage += fmt.Sprintf(" This is aliased as %s.", aliases)
}
if i.Default != "" {
usage += fmt.Sprintf(" The default is %s.", i.Default)
}
if i.EnvVar != "" {
usage += fmt.Sprintf(" This can also be specified via the %s "+
"environment variable.", i.EnvVar)
}
f.mainSet.Var(i.Value, i.Name, "") // No point in passing along usage here
// Add aliases to the main set
for _, a := range i.Aliases {
f.mainSet.Var(i.Value, a, "")
}
f.flagSet.Var(i.Value, i.Name, usage)
f.completions["-"+i.Name] = i.Completion
}
func (f *FlagSet) Var(value flag.Value, name, usage string) {
f.mainSet.Var(value, name, usage)
f.flagSet.Var(value, name, usage)
}
// -- bool Value
type boolValue bool
func newBoolValue(val bool, p *bool) *boolValue {
*p = val
return (*boolValue)(p)
}
func (b *boolValue) Set(s string) error {
v, err := strconv.ParseBool(s)
*b = boolValue(v)
return err
}
func (b *boolValue) Get() interface{} { return bool(*b) }
func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) }
func (b *boolValue) Example() string { return "" }
func (b *boolValue) IsBoolFlag() bool { return true }
// optional interface to indicate boolean flags that can be
// supplied without "=value" text
type boolFlag interface {
flag.Value
IsBoolFlag() bool
}
// -- int Value
type intValue int
func newIntValue(val int, p *int) *intValue {
*p = val
return (*intValue)(p)
}
func (i *intValue) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64)
*i = intValue(v)
return err
}
func (i *intValue) Get() interface{} { return int(*i) }
func (i *intValue) String() string { return strconv.Itoa(int(*i)) }
func (i *intValue) Example() string { return "int" }
// -- int64 Value
type int64Value int64
func newInt64Value(val int64, p *int64) *int64Value {
*p = val
return (*int64Value)(p)
}
func (i *int64Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64)
*i = int64Value(v)
return err
}
func (i *int64Value) Get() interface{} { return int64(*i) }
func (i *int64Value) String() string { return strconv.FormatInt(int64(*i), 10) }
func (i *int64Value) Example() string { return "int" }
// -- uint Value
type uintValue uint
func newUintValue(val uint, p *uint) *uintValue {
*p = val
return (*uintValue)(p)
}
func (i *uintValue) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64)
*i = uintValue(v)
return err
}
func (i *uintValue) Get() interface{} { return uint(*i) }
func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i), 10) }
func (i *uintValue) Example() string { return "uint" }
// -- uint64 Value
type uint64Value uint64
func newUint64Value(val uint64, p *uint64) *uint64Value {
*p = val
return (*uint64Value)(p)
}
func (i *uint64Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64)
*i = uint64Value(v)
return err
}
func (i *uint64Value) Get() interface{} { return uint64(*i) }
func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i), 10) }
func (i *uint64Value) Example() string { return "uint" }
// -- string Value
type stringValue string
func newStringValue(val string, p *string) *stringValue {
*p = val
return (*stringValue)(p)
}
func (s *stringValue) Set(val string) error {
*s = stringValue(val)
return nil
}
func (s *stringValue) Get() interface{} { return string(*s) }
func (s *stringValue) String() string { return string(*s) }
func (s *stringValue) Example() string { return "string" }
// -- float64 Value
type float64Value float64
func newFloat64Value(val float64, p *float64) *float64Value {
*p = val
return (*float64Value)(p)
}
func (f *float64Value) Set(s string) error {
v, err := strconv.ParseFloat(s, 64)
*f = float64Value(v)
return err
}
func (f *float64Value) Get() interface{} { return float64(*f) }
func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f), 'g', -1, 64) }
func (f *float64Value) Example() string { return "float" }
// -- time.Duration Value
type durationValue time.Duration
func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
*p = val
return (*durationValue)(p)
}
func (d *durationValue) Set(s string) error {
v, err := time.ParseDuration(s)
*d = durationValue(v)
return err
}
func (d *durationValue) Get() interface{} { return time.Duration(*d) }
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
func (d *durationValue) Example() string { return "duration" }
// -- helpers
func envDefault(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func envBoolDefault(key string, def bool) bool {
if v := os.Getenv(key); v != "" {
b, err := strconv.ParseBool(v)
if err != nil {
panic(err)
}
return b
}
return def
}
func envDurationDefault(key string, def time.Duration) time.Duration {
if v := os.Getenv(key); v != "" {
d, err := time.ParseDuration(v)
if err != nil {
panic(err)
}
return d
}
return def
}

43
command/base_helpers.go Normal file
View File

@@ -0,0 +1,43 @@
package command
import (
"fmt"
"strings"
)
var ErrMissingPath = fmt.Errorf("Missing PATH!")
// extractPath extracts the path and list of arguments from the args. If there
// are no extra arguments, the remaining args will be nil.
func extractPath(args []string) (string, []string, error) {
if len(args) < 1 {
return "", nil, ErrMissingPath
}
// Path is always the first argument after all flags
path := args[0]
// Strip leading and trailing slashes
for len(path) > 0 && path[0] == '/' {
path = path[1:]
}
for len(path) > 0 && path[len(path)-1] == '/' {
path = path[:len(path)-1]
}
// Trim any leading/trailing whitespace
path = strings.TrimSpace(path)
// Verify we have a path
if path == "" {
return "", nil, ErrMissingPath
}
// Splice remaining args
var remaining []string
if len(args) > 1 {
remaining = args[1:]
}
return path, remaining, nil
}

173
command/base_predict.go Normal file
View File

@@ -0,0 +1,173 @@
package command
import (
"sort"
"strings"
"github.com/hashicorp/vault/api"
"github.com/posener/complete"
)
// defaultPredictVaultMounts is the default list of mounts to return to the
// user. This is a best-guess, given we haven't communicated with the Vault
// server. If the user has no token or if the token does not have the default
// policy attached, it won't be able to read cubbyhole/, but it's a better UX
// that returning nothing.
var defaultPredictVaultMounts = []string{"cubbyhole/"}
// PredictVaultPaths returns a predictor for Vault mounts and paths based on the
// configured client for the base command. Unfortunately this happens pre-flag
// parsing, so users must rely on environment variables for autocomplete if they
// are not using Vault at the default endpoints.
func (b *BaseCommand) PredictVaultPaths() complete.Predictor {
client, err := b.Client()
if err != nil {
return nil
}
return PredictVaultPaths(client)
}
// PredictVaultPaths returns a predictor for Vault paths. This is a public API
// for consumers, but you probably want BaseCommand.PredictVaultPaths instead.
func PredictVaultPaths(client *api.Client) complete.Predictor {
return predictVaultPaths(client)
}
// predictVaultPaths parses the CLI options and returns the "best" list of
// possible paths. If there are any errors, this function returns an empty
// result. All errors are suppressed since this is a prediction function.
func predictVaultPaths(client *api.Client) complete.PredictFunc {
return func(args complete.Args) []string {
// Do not predict more than one paths
if predictHasPathArg(args.All) {
return nil
}
path := args.Last
var predictions []string
if strings.Contains(path, "/") {
predictions = predictPaths(client, path)
} else {
predictions = predictMounts(client, path)
}
// Either no results or many results, so return.
if len(predictions) != 1 {
return predictions
}
// If this is not a "folder", do not try to recurse.
if !strings.HasSuffix(predictions[0], "/") {
return predictions
}
// If the prediction is the same as the last guess, return it (we have no
// new information and we won't get anymore).
if predictions[0] == args.Last {
return predictions
}
// Re-predict with the remaining path
args.Last = predictions[0]
return predictVaultPaths(client).Predict(args)
}
}
// predictMounts predicts all mounts which start with the given prefix. These
// are predicted on mount path, not "type".
func predictMounts(client *api.Client, path string) []string {
mounts := predictListMounts(client)
var predictions []string
for _, m := range mounts {
if strings.HasPrefix(m, path) {
predictions = append(predictions, m)
}
}
return predictions
}
// predictPaths predicts all paths which start with the given path.
func predictPaths(client *api.Client, path string) []string {
// Vault does not support listing based on a sub-key, so we have to back-pedal
// to the last "/" and return all paths on that "folder". Then we perform
// client-side filtering.
root := path
idx := strings.LastIndex(root, "/")
if idx > 0 && idx < len(root) {
root = root[:idx+1]
}
paths := predictListPaths(client, root)
var predictions []string
for _, p := range paths {
// Calculate the absolute "path" for matching.
p = root + p
if strings.HasPrefix(p, path) {
predictions = append(predictions, p)
}
}
// Add root to the path
if len(predictions) == 0 {
predictions = append(predictions, path)
}
return predictions
}
// predictListMounts returns a sorted list of the mount paths for Vault server
// for which the client is configured to communicate with. This function returns
// the default list of mounts if an error occurs.
func predictListMounts(c *api.Client) []string {
mounts, err := c.Sys().ListMounts()
if err != nil {
return defaultPredictVaultMounts
}
list := make([]string, 0, len(mounts))
for m := range mounts {
list = append(list, m)
}
sort.Strings(list)
return list
}
// predictListPaths returns a list of paths (HTTP LIST) for the given path. This
// function returns an empty list of any errors occur.
func predictListPaths(c *api.Client, path string) []string {
secret, err := c.Logical().List(path)
if err != nil || secret == nil || secret.Data == nil {
return nil
}
paths, ok := secret.Data["keys"].([]interface{})
if !ok {
return nil
}
list := make([]string, 0, len(paths))
for _, p := range paths {
if str, ok := p.(string); ok {
list = append(list, str)
}
}
sort.Strings(list)
return list
}
// predictHasPathArg determines if the args have already accepted a path.
func predictHasPathArg(args []string) bool {
var nonFlags []string
for _, a := range args {
if !strings.HasPrefix(a, "-") {
nonFlags = append(nonFlags, a)
}
}
return len(nonFlags) > 2
}

View File

@@ -0,0 +1,363 @@
package command
import (
"reflect"
"testing"
"github.com/hashicorp/vault/api"
"github.com/posener/complete"
)
func TestPredictVaultPaths(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
data := map[string]interface{}{"a": "b"}
if _, err := client.Logical().Write("secret/bar", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/foo", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/zip/zap", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/zip/zonk", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/zip/twoot", data); err != nil {
t.Fatal(err)
}
f := predictVaultPaths(client)
cases := []struct {
name string
args complete.Args
exp []string
}{
{
"has_args",
complete.Args{
All: []string{"read", "secret/foo", "a=b"},
Last: "a=b",
},
nil,
},
{
"part_mount",
complete.Args{
All: []string{"read", "s"},
Last: "s",
},
[]string{"secret/", "sys/"},
},
{
"only_mount",
complete.Args{
All: []string{"read", "sec"},
Last: "sec",
},
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"full_mount",
complete.Args{
All: []string{"read", "secret"},
Last: "secret",
},
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"full_mount_slash",
complete.Args{
All: []string{"read", "secret/"},
Last: "secret/",
},
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"path_partial",
complete.Args{
All: []string{"read", "secret/z"},
Last: "secret/z",
},
[]string{"secret/zip/twoot", "secret/zip/zap", "secret/zip/zonk"},
},
{
"subpath_partial_z",
complete.Args{
All: []string{"read", "secret/zip/z"},
Last: "secret/zip/z",
},
[]string{"secret/zip/zap", "secret/zip/zonk"},
},
{
"subpath_partial_t",
complete.Args{
All: []string{"read", "secret/zip/t"},
Last: "secret/zip/t",
},
[]string{"secret/zip/twoot"},
},
}
t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := f(tc.args)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}
func TestPredictMounts(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
cases := []struct {
name string
client *api.Client
path string
exp []string
}{
{
"no_match",
client,
"not-a-real-mount-seriously",
nil,
},
{
"s",
client,
"s",
[]string{"secret/", "sys/"},
},
{
"se",
client,
"se",
[]string{"secret/"},
},
}
t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictMounts(tc.client, tc.path)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}
func TestPredictPaths(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
data := map[string]interface{}{"a": "b"}
if _, err := client.Logical().Write("secret/bar", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/foo", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/zip/zap", data); err != nil {
t.Fatal(err)
}
cases := []struct {
name string
client *api.Client
path string
exp []string
}{
{
"bad_path",
client,
"nope/not/a/real/path/ever",
[]string{"nope/not/a/real/path/ever"},
},
{
"good_path",
client,
"secret/",
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"partial_match",
client,
"secret/z",
[]string{"secret/zip/"},
},
}
t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictPaths(tc.client, tc.path)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}
func TestPredictListMounts(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
cases := []struct {
name string
client *api.Client
exp []string
}{
{
"not_connected_client",
func() *api.Client {
// Bad API client
client, _ := api.NewClient(nil)
return client
}(),
defaultPredictVaultMounts,
},
{
"good_path",
client,
[]string{"cubbyhole/", "secret/", "sys/"},
},
}
t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictListMounts(tc.client)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}
func TestPredictListPaths(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
data := map[string]interface{}{"a": "b"}
if _, err := client.Logical().Write("secret/bar", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/foo", data); err != nil {
t.Fatal(err)
}
cases := []struct {
name string
client *api.Client
path string
exp []string
}{
{
"bad_path",
client,
"nope/not/a/real/path/ever",
nil,
},
{
"good_path",
client,
"secret/",
[]string{"bar", "foo"},
},
}
t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictListPaths(tc.client, tc.path)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}
func TestPredictHasPathArg(t *testing.T) {
t.Parallel()
cases := []struct {
name string
args []string
exp bool
}{
{
"nil",
nil,
false,
},
{
"empty",
[]string{},
false,
},
{
"empty_string",
[]string{""},
false,
},
{
"single",
[]string{"foo"},
false,
},
{
"multiple",
[]string{"foo", "bar", "baz"},
true,
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if act := predictHasPathArg(tc.args); act != tc.exp {
t.Errorf("expected %t to be %t", act, tc.exp)
}
})
}
}

14
command/base_test.go Normal file
View File

@@ -0,0 +1,14 @@
package command
import (
"strings"
"testing"
"github.com/mitchellh/cli"
)
func assertNoTabs(tb testing.TB, c cli.Command) {
if strings.ContainsRune(c.Help(), '\t') {
tb.Errorf("%#v help output contains tabs", c)
}
}