mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Add start of base command, flags, prediction
This commit is contained in:
434
command/base.go
Normal file
434
command/base.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/command/token"
|
||||
"github.com/kr/text"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
// maxLineLength is the maximum width of any line.
|
||||
const maxLineLength int = 78
|
||||
|
||||
// reRemoveWhitespace is a regular expression for stripping whitespace from
|
||||
// a string.
|
||||
var reRemoveWhitespace = regexp.MustCompile(`[\s]+`)
|
||||
|
||||
type TokenHelperFunc func() (token.TokenHelper, error)
|
||||
|
||||
type BaseCommand struct {
|
||||
UI cli.Ui
|
||||
|
||||
flags *FlagSets
|
||||
flagsOnce sync.Once
|
||||
|
||||
flagAddress string
|
||||
flagCACert string
|
||||
flagCAPath string
|
||||
flagClientCert string
|
||||
flagClientKey string
|
||||
flagTLSServerName string
|
||||
flagTLSSkipVerify bool
|
||||
flagWrapTTL time.Duration
|
||||
|
||||
flagFormat string
|
||||
flagField string
|
||||
|
||||
tokenHelper TokenHelperFunc
|
||||
|
||||
client *api.Client
|
||||
clientErr error
|
||||
clientOnce sync.Once
|
||||
}
|
||||
|
||||
// Client returns the HTTP API client. The client is cached on the command to
|
||||
// save performance on future calls.
|
||||
func (c *BaseCommand) Client() (*api.Client, error) {
|
||||
c.clientOnce.Do(func() {
|
||||
// This should never happen in reality and is just for testing. Nothing
|
||||
// should be setting the underlying client.
|
||||
if c.client != nil {
|
||||
return
|
||||
}
|
||||
|
||||
config := api.DefaultConfig()
|
||||
|
||||
if err := config.ReadEnvironment(); err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to read environment")
|
||||
return
|
||||
}
|
||||
|
||||
if c.flagAddress != "" {
|
||||
config.Address = c.flagAddress
|
||||
}
|
||||
|
||||
// If we need custom TLS configuration, then set it
|
||||
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
|
||||
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
|
||||
t := &api.TLSConfig{
|
||||
CACert: c.flagCACert,
|
||||
CAPath: c.flagCAPath,
|
||||
ClientCert: c.flagClientCert,
|
||||
ClientKey: c.flagClientKey,
|
||||
TLSServerName: c.flagTLSServerName,
|
||||
Insecure: c.flagTLSSkipVerify,
|
||||
}
|
||||
config.ConfigureTLS(t)
|
||||
}
|
||||
|
||||
// Build the client
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to create client")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the wrapping function
|
||||
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
|
||||
|
||||
// Get the token if it came in from the environment
|
||||
token := client.Token()
|
||||
|
||||
// If we don't have a token, check the token helper
|
||||
if token == "" {
|
||||
if c.tokenHelper != nil {
|
||||
// If we have a token, then set that
|
||||
tokenHelper, err := c.tokenHelper()
|
||||
if err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to get token helper")
|
||||
return
|
||||
}
|
||||
token, err = tokenHelper.Get()
|
||||
if err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to retrieve from token helper")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the token
|
||||
if token != "" {
|
||||
client.SetToken(token)
|
||||
}
|
||||
|
||||
c.client = client
|
||||
})
|
||||
|
||||
return c.client, c.clientErr
|
||||
}
|
||||
|
||||
// DefaultWrappingLookupFunc is the default wrapping function based on the
|
||||
// CLI flag.
|
||||
func (c *BaseCommand) DefaultWrappingLookupFunc(operation, path string) string {
|
||||
if c.flagWrapTTL != 0 {
|
||||
return c.flagWrapTTL.String()
|
||||
}
|
||||
|
||||
return api.DefaultWrappingLookupFunc(operation, path)
|
||||
}
|
||||
|
||||
type FlagSetBit uint
|
||||
|
||||
const (
|
||||
FlagSetNone FlagSetBit = 1 << iota
|
||||
FlagSetHTTP
|
||||
FlagSetOutputField
|
||||
FlagSetOutputFormat
|
||||
)
|
||||
|
||||
// flagSet creates the flags for this command. The result is cached on the
|
||||
// command to save performance on future calls.
|
||||
func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
|
||||
c.flagsOnce.Do(func() {
|
||||
set := NewFlagSets(c.UI)
|
||||
|
||||
if bit&FlagSetHTTP != 0 {
|
||||
f := set.NewFlagSet("HTTP Options")
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "address",
|
||||
Target: &c.flagAddress,
|
||||
Default: "https://127.0.0.1:8200",
|
||||
EnvVar: "VAULT_ADDR",
|
||||
Completion: complete.PredictAnything,
|
||||
Usage: "Address of the Vault server.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "ca-cert",
|
||||
Target: &c.flagCACert,
|
||||
Default: "",
|
||||
EnvVar: "VAULT_CACERT",
|
||||
Completion: complete.PredictFiles("*"),
|
||||
Usage: "Path on the local disk to a single PEM-encoded CA " +
|
||||
"certificate to verify the Vault server's SSL certificate. This " +
|
||||
"takes precendence over -ca-path.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "ca-path",
|
||||
Target: &c.flagCAPath,
|
||||
Default: "",
|
||||
EnvVar: "VAULT_CAPATH",
|
||||
Completion: complete.PredictDirs("*"),
|
||||
Usage: "Path on the local disk to a directory of PEM-encoded CA " +
|
||||
"certificates to verify the Vault server's SSL certificate.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "client-cert",
|
||||
Target: &c.flagClientCert,
|
||||
Default: "",
|
||||
EnvVar: "VAULT_CLIENT_CERT",
|
||||
Completion: complete.PredictFiles("*"),
|
||||
Usage: "Path on the local disk to a single PEM-encoded CA " +
|
||||
"certificate to use for TLS authentication to the Vault server. If " +
|
||||
"this flag is specified, -client-key is also required.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "client-key",
|
||||
Target: &c.flagClientKey,
|
||||
Default: "",
|
||||
EnvVar: "VAULT_CLIENT_KEY",
|
||||
Completion: complete.PredictFiles("*"),
|
||||
Usage: "Path on the local disk to a single PEM-encoded private key " +
|
||||
"matching the client certificate from -client-cert.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "tls-server-name",
|
||||
Target: &c.flagTLSServerName,
|
||||
Default: "",
|
||||
EnvVar: "VAULT_TLS_SERVER_NAME",
|
||||
Completion: complete.PredictAnything,
|
||||
Usage: "Name to use as the SNI host when connecting to the Vault " +
|
||||
"server via TLS.",
|
||||
})
|
||||
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "tls-skip-verify",
|
||||
Target: &c.flagTLSSkipVerify,
|
||||
Default: false,
|
||||
EnvVar: "VAULT_SKIP_VERIFY",
|
||||
Completion: complete.PredictNothing,
|
||||
Usage: "Disable verification of TLS certificates. Using this option " +
|
||||
"is highly discouraged and decreases the security of data " +
|
||||
"transmissions to and from the Vault server.",
|
||||
})
|
||||
|
||||
f.DurationVar(&DurationVar{
|
||||
Name: "wrap-ttl",
|
||||
Target: &c.flagWrapTTL,
|
||||
Default: 0,
|
||||
EnvVar: "VAULT_WRAP_TTL",
|
||||
Completion: complete.PredictAnything,
|
||||
Usage: "Wraps the response in a cubbyhole token with the requested " +
|
||||
"TTL. The response is available via the \"vault unwrap\" command. " +
|
||||
"The TTL is specified as a numeric string with suffix like \"30s\" " +
|
||||
"or \"5m\"",
|
||||
})
|
||||
}
|
||||
|
||||
if bit&(FlagSetOutputField|FlagSetOutputFormat) != 0 {
|
||||
f := set.NewFlagSet("Output Options")
|
||||
|
||||
if bit&FlagSetOutputField != 0 {
|
||||
f.StringVar(&StringVar{
|
||||
Name: "field",
|
||||
Target: &c.flagField,
|
||||
Default: "",
|
||||
EnvVar: "",
|
||||
Completion: complete.PredictAnything,
|
||||
Usage: "Print only the field with the given name. Specifying " +
|
||||
"this option will take precedence over other formatting " +
|
||||
"directives. The result will not have a trailing newline " +
|
||||
"making it idea for piping to other processes.",
|
||||
})
|
||||
}
|
||||
|
||||
if bit&FlagSetOutputFormat != 0 {
|
||||
f.StringVar(&StringVar{
|
||||
Name: "format",
|
||||
Target: &c.flagFormat,
|
||||
Default: "table",
|
||||
EnvVar: "VAULT_FORMAT",
|
||||
Completion: complete.PredictSet("table", "json", "yaml"),
|
||||
Usage: "Print the output in the given format. Valid formats " +
|
||||
"are \"table\", \"json\", or \"yaml\".",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.flags = set
|
||||
})
|
||||
|
||||
return c.flags
|
||||
}
|
||||
|
||||
// printFlagTitle prints a consistently-formatted title to the given writer.
|
||||
func printFlagTitle(w io.Writer, s string) {
|
||||
fmt.Fprintf(w, "%s\n\n", s)
|
||||
}
|
||||
|
||||
// printFlagDetail prints a single flag to the given writer.
|
||||
func printFlagDetail(w io.Writer, f *flag.Flag) {
|
||||
example := ""
|
||||
if t, ok := f.Value.(FlagExample); ok {
|
||||
example = t.Example()
|
||||
}
|
||||
|
||||
if example != "" {
|
||||
fmt.Fprintf(w, " -%s=<%s>\n", f.Name, example)
|
||||
} else {
|
||||
fmt.Fprintf(w, " -%s\n", f.Name)
|
||||
}
|
||||
|
||||
usage := reRemoveWhitespace.ReplaceAllString(f.Usage, " ")
|
||||
indented := wrapAtLength(usage, 6)
|
||||
fmt.Fprintf(w, "%s\n\n", indented)
|
||||
}
|
||||
|
||||
// wrapAtLength wraps the given text at the maxLineLength, taking into account
|
||||
// any provided left padding.
|
||||
func wrapAtLength(s string, pad int) string {
|
||||
wrapped := text.Wrap(s, maxLineLength-pad)
|
||||
lines := strings.Split(wrapped, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.Repeat(" ", pad) + line
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// FlagSets is a group of flag sets.
|
||||
type FlagSets struct {
|
||||
flagSets []*FlagSet
|
||||
mainSet *flag.FlagSet
|
||||
hiddens map[string]struct{}
|
||||
completions complete.Flags
|
||||
}
|
||||
|
||||
// NewFlagSets creates a new flag sets.
|
||||
func NewFlagSets(ui cli.Ui) *FlagSets {
|
||||
mainSet := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
mainSet.Usage = func() {}
|
||||
|
||||
// Pull errors from the flagset into the ui's error
|
||||
errR, errW := io.Pipe()
|
||||
errScanner := bufio.NewScanner(errR)
|
||||
go func() {
|
||||
for errScanner.Scan() {
|
||||
ui.Error(errScanner.Text())
|
||||
}
|
||||
}()
|
||||
mainSet.SetOutput(errW)
|
||||
|
||||
return &FlagSets{
|
||||
flagSets: make([]*FlagSet, 0, 6),
|
||||
mainSet: mainSet,
|
||||
hiddens: make(map[string]struct{}),
|
||||
completions: complete.Flags{},
|
||||
}
|
||||
}
|
||||
|
||||
// NewFlagSet creates a new flag set from the given flag sets.
|
||||
func (f *FlagSets) NewFlagSet(name string) *FlagSet {
|
||||
flagSet := NewFlagSet(name)
|
||||
f.AddFlagSet(flagSet)
|
||||
return flagSet
|
||||
}
|
||||
|
||||
// AddFlagSet adds a new flag set to this flag set.
|
||||
func (f *FlagSets) AddFlagSet(set *FlagSet) {
|
||||
set.mainSet = f.mainSet
|
||||
set.completions = f.completions
|
||||
f.flagSets = append(f.flagSets, set)
|
||||
}
|
||||
|
||||
func (f *FlagSets) Completions() complete.Flags {
|
||||
return f.completions
|
||||
}
|
||||
|
||||
// Parse parses the given flags, returning any errors.
|
||||
func (f *FlagSets) Parse(args []string) error {
|
||||
return f.mainSet.Parse(args)
|
||||
}
|
||||
|
||||
// Args returns the remaining args after parsing.
|
||||
func (f *FlagSets) Args() []string {
|
||||
return f.mainSet.Args()
|
||||
}
|
||||
|
||||
// HideFlag excludes the flag from the list of flags to print in help. This is
|
||||
// useful when you want to include a flag in parsing for deprecations/bc, but
|
||||
// you don't want to include it in help output.
|
||||
func (f *FlagSets) HideFlag(n string) {
|
||||
if _, ok := f.hiddens[n]; !ok {
|
||||
f.hiddens[n] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// HiddenFlag returns true if the flag with the given name is hidden.
|
||||
func (f *FlagSets) HiddenFlag(n string) bool {
|
||||
_, ok := f.hiddens[n]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Help builds custom help for this command, grouping by flag set.
|
||||
func (fs *FlagSets) Help() string {
|
||||
var out bytes.Buffer
|
||||
|
||||
for _, set := range fs.flagSets {
|
||||
printFlagTitle(&out, set.name+":")
|
||||
set.VisitAll(func(f *flag.Flag) {
|
||||
// Skip any hidden flags
|
||||
if fs.HiddenFlag(f.Name) {
|
||||
return
|
||||
}
|
||||
printFlagDetail(&out, f)
|
||||
})
|
||||
}
|
||||
|
||||
return strings.TrimRight(out.String(), "\n")
|
||||
}
|
||||
|
||||
// FlagSet is a grouped wrapper around a real flag set and a grouped flag set.
|
||||
type FlagSet struct {
|
||||
name string
|
||||
flagSet *flag.FlagSet
|
||||
mainSet *flag.FlagSet
|
||||
completions complete.Flags
|
||||
}
|
||||
|
||||
// NewFlagSet creates a new flag set.
|
||||
func NewFlagSet(name string) *FlagSet {
|
||||
return &FlagSet{
|
||||
name: name,
|
||||
flagSet: flag.NewFlagSet(name, flag.ContinueOnError),
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name of this flag set.
|
||||
func (f *FlagSet) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f *FlagSet) Visit(fn func(*flag.Flag)) {
|
||||
f.flagSet.Visit(fn)
|
||||
}
|
||||
|
||||
func (f *FlagSet) VisitAll(fn func(*flag.Flag)) {
|
||||
f.flagSet.VisitAll(fn)
|
||||
}
|
||||
506
command/base_flags.go
Normal file
506
command/base_flags.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
// FlagExample is an interface which declares an example value.
|
||||
type FlagExample interface {
|
||||
Example() string
|
||||
}
|
||||
|
||||
type BoolVar struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default bool
|
||||
EnvVar string
|
||||
Target *bool
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) BoolVar(i *BoolVar) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if b, err := strconv.ParseBool(v); err != nil {
|
||||
def = b
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: strconv.FormatBool(i.Default),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newBoolValue(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type IntVar struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default int
|
||||
EnvVar string
|
||||
Target *int
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) IntVar(i *IntVar) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if i, err := strconv.ParseInt(v, 0, 64); err != nil {
|
||||
def = int(i)
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: strconv.FormatInt(int64(i.Default), 10),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newIntValue(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type Int64Var struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default int64
|
||||
EnvVar string
|
||||
Target *int64
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) Int64Var(i *Int64Var) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if i, err := strconv.ParseInt(v, 0, 64); err != nil {
|
||||
def = i
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: strconv.FormatInt(i.Default, 10),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newInt64Value(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type UintVar struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default uint
|
||||
EnvVar string
|
||||
Target *uint
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) UintVar(i *UintVar) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if i, err := strconv.ParseUint(v, 0, 64); err != nil {
|
||||
def = uint(i)
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: strconv.FormatUint(uint64(i.Default), 10),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newUintValue(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type Uint64Var struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default uint64
|
||||
EnvVar string
|
||||
Target *uint64
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) Uint64Var(i *Uint64Var) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if i, err := strconv.ParseUint(v, 0, 64); err != nil {
|
||||
def = i
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: strconv.FormatUint(i.Default, 10),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newUint64Value(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type StringVar struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default string
|
||||
EnvVar string
|
||||
Target *string
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) StringVar(i *StringVar) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
def = v
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: i.Default,
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newStringValue(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type Float64Var struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default float64
|
||||
EnvVar string
|
||||
Target *float64
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) Float64Var(i *Float64Var) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if i, err := strconv.ParseFloat(v, 64); err != nil {
|
||||
def = i
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: strconv.FormatFloat(i.Default, 'e', -1, 64),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newFloat64Value(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type DurationVar struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default time.Duration
|
||||
EnvVar string
|
||||
Target *time.Duration
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) DurationVar(i *DurationVar) {
|
||||
def := i.Default
|
||||
if v := os.Getenv(i.EnvVar); v != "" {
|
||||
if d, err := time.ParseDuration(v); err != nil {
|
||||
def = d
|
||||
}
|
||||
}
|
||||
|
||||
f.VarFlag(&VarFlag{
|
||||
Name: i.Name,
|
||||
Aliases: i.Aliases,
|
||||
Usage: i.Usage,
|
||||
Default: i.Default.String(),
|
||||
EnvVar: i.EnvVar,
|
||||
Value: newDurationValue(def, i.Target),
|
||||
Completion: i.Completion,
|
||||
})
|
||||
}
|
||||
|
||||
type VarFlag struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Usage string
|
||||
Default string
|
||||
EnvVar string
|
||||
Value flag.Value
|
||||
Completion complete.Predictor
|
||||
}
|
||||
|
||||
func (f *FlagSet) VarFlag(i *VarFlag) {
|
||||
// Calculate the full usage
|
||||
usage := i.Usage
|
||||
|
||||
if len(i.Aliases) > 0 {
|
||||
sentence := make([]string, len(i.Aliases))
|
||||
for i, a := range i.Aliases {
|
||||
sentence[i] = fmt.Sprintf(`"-%s"`, a)
|
||||
}
|
||||
|
||||
aliases := ""
|
||||
switch len(sentence) {
|
||||
case 0:
|
||||
// impossible...
|
||||
case 1:
|
||||
aliases = sentence[0]
|
||||
case 2:
|
||||
aliases = sentence[0] + " and " + sentence[1]
|
||||
default:
|
||||
sentence[len(sentence)-1] = "and " + sentence[len(sentence)-1]
|
||||
aliases = strings.Join(sentence, ", ")
|
||||
}
|
||||
|
||||
usage += fmt.Sprintf(" This is aliased as %s.", aliases)
|
||||
}
|
||||
|
||||
if i.Default != "" {
|
||||
usage += fmt.Sprintf(" The default is %s.", i.Default)
|
||||
}
|
||||
|
||||
if i.EnvVar != "" {
|
||||
usage += fmt.Sprintf(" This can also be specified via the %s "+
|
||||
"environment variable.", i.EnvVar)
|
||||
}
|
||||
|
||||
f.mainSet.Var(i.Value, i.Name, "") // No point in passing along usage here
|
||||
|
||||
// Add aliases to the main set
|
||||
for _, a := range i.Aliases {
|
||||
f.mainSet.Var(i.Value, a, "")
|
||||
}
|
||||
|
||||
f.flagSet.Var(i.Value, i.Name, usage)
|
||||
f.completions["-"+i.Name] = i.Completion
|
||||
}
|
||||
|
||||
func (f *FlagSet) Var(value flag.Value, name, usage string) {
|
||||
f.mainSet.Var(value, name, usage)
|
||||
f.flagSet.Var(value, name, usage)
|
||||
}
|
||||
|
||||
// -- bool Value
|
||||
type boolValue bool
|
||||
|
||||
func newBoolValue(val bool, p *bool) *boolValue {
|
||||
*p = val
|
||||
return (*boolValue)(p)
|
||||
}
|
||||
|
||||
func (b *boolValue) Set(s string) error {
|
||||
v, err := strconv.ParseBool(s)
|
||||
*b = boolValue(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *boolValue) Get() interface{} { return bool(*b) }
|
||||
|
||||
func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) }
|
||||
|
||||
func (b *boolValue) Example() string { return "" }
|
||||
|
||||
func (b *boolValue) IsBoolFlag() bool { return true }
|
||||
|
||||
// optional interface to indicate boolean flags that can be
|
||||
// supplied without "=value" text
|
||||
type boolFlag interface {
|
||||
flag.Value
|
||||
IsBoolFlag() bool
|
||||
}
|
||||
|
||||
// -- int Value
|
||||
type intValue int
|
||||
|
||||
func newIntValue(val int, p *int) *intValue {
|
||||
*p = val
|
||||
return (*intValue)(p)
|
||||
}
|
||||
|
||||
func (i *intValue) Set(s string) error {
|
||||
v, err := strconv.ParseInt(s, 0, 64)
|
||||
*i = intValue(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *intValue) Get() interface{} { return int(*i) }
|
||||
|
||||
func (i *intValue) String() string { return strconv.Itoa(int(*i)) }
|
||||
|
||||
func (i *intValue) Example() string { return "int" }
|
||||
|
||||
// -- int64 Value
|
||||
type int64Value int64
|
||||
|
||||
func newInt64Value(val int64, p *int64) *int64Value {
|
||||
*p = val
|
||||
return (*int64Value)(p)
|
||||
}
|
||||
|
||||
func (i *int64Value) Set(s string) error {
|
||||
v, err := strconv.ParseInt(s, 0, 64)
|
||||
*i = int64Value(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *int64Value) Get() interface{} { return int64(*i) }
|
||||
|
||||
func (i *int64Value) String() string { return strconv.FormatInt(int64(*i), 10) }
|
||||
|
||||
func (i *int64Value) Example() string { return "int" }
|
||||
|
||||
// -- uint Value
|
||||
type uintValue uint
|
||||
|
||||
func newUintValue(val uint, p *uint) *uintValue {
|
||||
*p = val
|
||||
return (*uintValue)(p)
|
||||
}
|
||||
|
||||
func (i *uintValue) Set(s string) error {
|
||||
v, err := strconv.ParseUint(s, 0, 64)
|
||||
*i = uintValue(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *uintValue) Get() interface{} { return uint(*i) }
|
||||
|
||||
func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i), 10) }
|
||||
|
||||
func (i *uintValue) Example() string { return "uint" }
|
||||
|
||||
// -- uint64 Value
|
||||
type uint64Value uint64
|
||||
|
||||
func newUint64Value(val uint64, p *uint64) *uint64Value {
|
||||
*p = val
|
||||
return (*uint64Value)(p)
|
||||
}
|
||||
|
||||
func (i *uint64Value) Set(s string) error {
|
||||
v, err := strconv.ParseUint(s, 0, 64)
|
||||
*i = uint64Value(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *uint64Value) Get() interface{} { return uint64(*i) }
|
||||
|
||||
func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i), 10) }
|
||||
|
||||
func (i *uint64Value) Example() string { return "uint" }
|
||||
|
||||
// -- string Value
|
||||
type stringValue string
|
||||
|
||||
func newStringValue(val string, p *string) *stringValue {
|
||||
*p = val
|
||||
return (*stringValue)(p)
|
||||
}
|
||||
|
||||
func (s *stringValue) Set(val string) error {
|
||||
*s = stringValue(val)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stringValue) Get() interface{} { return string(*s) }
|
||||
|
||||
func (s *stringValue) String() string { return string(*s) }
|
||||
|
||||
func (s *stringValue) Example() string { return "string" }
|
||||
|
||||
// -- float64 Value
|
||||
type float64Value float64
|
||||
|
||||
func newFloat64Value(val float64, p *float64) *float64Value {
|
||||
*p = val
|
||||
return (*float64Value)(p)
|
||||
}
|
||||
|
||||
func (f *float64Value) Set(s string) error {
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
*f = float64Value(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *float64Value) Get() interface{} { return float64(*f) }
|
||||
|
||||
func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f), 'g', -1, 64) }
|
||||
|
||||
func (f *float64Value) Example() string { return "float" }
|
||||
|
||||
// -- time.Duration Value
|
||||
type durationValue time.Duration
|
||||
|
||||
func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
|
||||
*p = val
|
||||
return (*durationValue)(p)
|
||||
}
|
||||
|
||||
func (d *durationValue) Set(s string) error {
|
||||
v, err := time.ParseDuration(s)
|
||||
*d = durationValue(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *durationValue) Get() interface{} { return time.Duration(*d) }
|
||||
|
||||
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
|
||||
|
||||
func (d *durationValue) Example() string { return "duration" }
|
||||
|
||||
// -- helpers
|
||||
func envDefault(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func envBoolDefault(key string, def bool) bool {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
b, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func envDurationDefault(key string, def time.Duration) time.Duration {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
d, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
return def
|
||||
}
|
||||
43
command/base_helpers.go
Normal file
43
command/base_helpers.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrMissingPath = fmt.Errorf("Missing PATH!")
|
||||
|
||||
// extractPath extracts the path and list of arguments from the args. If there
|
||||
// are no extra arguments, the remaining args will be nil.
|
||||
func extractPath(args []string) (string, []string, error) {
|
||||
if len(args) < 1 {
|
||||
return "", nil, ErrMissingPath
|
||||
}
|
||||
|
||||
// Path is always the first argument after all flags
|
||||
path := args[0]
|
||||
|
||||
// Strip leading and trailing slashes
|
||||
for len(path) > 0 && path[0] == '/' {
|
||||
path = path[1:]
|
||||
}
|
||||
for len(path) > 0 && path[len(path)-1] == '/' {
|
||||
path = path[:len(path)-1]
|
||||
}
|
||||
|
||||
// Trim any leading/trailing whitespace
|
||||
path = strings.TrimSpace(path)
|
||||
|
||||
// Verify we have a path
|
||||
if path == "" {
|
||||
return "", nil, ErrMissingPath
|
||||
}
|
||||
|
||||
// Splice remaining args
|
||||
var remaining []string
|
||||
if len(args) > 1 {
|
||||
remaining = args[1:]
|
||||
}
|
||||
|
||||
return path, remaining, nil
|
||||
}
|
||||
173
command/base_predict.go
Normal file
173
command/base_predict.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
// defaultPredictVaultMounts is the default list of mounts to return to the
|
||||
// user. This is a best-guess, given we haven't communicated with the Vault
|
||||
// server. If the user has no token or if the token does not have the default
|
||||
// policy attached, it won't be able to read cubbyhole/, but it's a better UX
|
||||
// that returning nothing.
|
||||
var defaultPredictVaultMounts = []string{"cubbyhole/"}
|
||||
|
||||
// PredictVaultPaths returns a predictor for Vault mounts and paths based on the
|
||||
// configured client for the base command. Unfortunately this happens pre-flag
|
||||
// parsing, so users must rely on environment variables for autocomplete if they
|
||||
// are not using Vault at the default endpoints.
|
||||
func (b *BaseCommand) PredictVaultPaths() complete.Predictor {
|
||||
client, err := b.Client()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return PredictVaultPaths(client)
|
||||
}
|
||||
|
||||
// PredictVaultPaths returns a predictor for Vault paths. This is a public API
|
||||
// for consumers, but you probably want BaseCommand.PredictVaultPaths instead.
|
||||
func PredictVaultPaths(client *api.Client) complete.Predictor {
|
||||
return predictVaultPaths(client)
|
||||
}
|
||||
|
||||
// predictVaultPaths parses the CLI options and returns the "best" list of
|
||||
// possible paths. If there are any errors, this function returns an empty
|
||||
// result. All errors are suppressed since this is a prediction function.
|
||||
func predictVaultPaths(client *api.Client) complete.PredictFunc {
|
||||
return func(args complete.Args) []string {
|
||||
// Do not predict more than one paths
|
||||
if predictHasPathArg(args.All) {
|
||||
return nil
|
||||
}
|
||||
|
||||
path := args.Last
|
||||
|
||||
var predictions []string
|
||||
if strings.Contains(path, "/") {
|
||||
predictions = predictPaths(client, path)
|
||||
} else {
|
||||
predictions = predictMounts(client, path)
|
||||
}
|
||||
|
||||
// Either no results or many results, so return.
|
||||
if len(predictions) != 1 {
|
||||
return predictions
|
||||
}
|
||||
|
||||
// If this is not a "folder", do not try to recurse.
|
||||
if !strings.HasSuffix(predictions[0], "/") {
|
||||
return predictions
|
||||
}
|
||||
|
||||
// If the prediction is the same as the last guess, return it (we have no
|
||||
// new information and we won't get anymore).
|
||||
if predictions[0] == args.Last {
|
||||
return predictions
|
||||
}
|
||||
|
||||
// Re-predict with the remaining path
|
||||
args.Last = predictions[0]
|
||||
return predictVaultPaths(client).Predict(args)
|
||||
}
|
||||
}
|
||||
|
||||
// predictMounts predicts all mounts which start with the given prefix. These
|
||||
// are predicted on mount path, not "type".
|
||||
func predictMounts(client *api.Client, path string) []string {
|
||||
mounts := predictListMounts(client)
|
||||
|
||||
var predictions []string
|
||||
for _, m := range mounts {
|
||||
if strings.HasPrefix(m, path) {
|
||||
predictions = append(predictions, m)
|
||||
}
|
||||
}
|
||||
|
||||
return predictions
|
||||
}
|
||||
|
||||
// predictPaths predicts all paths which start with the given path.
|
||||
func predictPaths(client *api.Client, path string) []string {
|
||||
// Vault does not support listing based on a sub-key, so we have to back-pedal
|
||||
// to the last "/" and return all paths on that "folder". Then we perform
|
||||
// client-side filtering.
|
||||
root := path
|
||||
idx := strings.LastIndex(root, "/")
|
||||
if idx > 0 && idx < len(root) {
|
||||
root = root[:idx+1]
|
||||
}
|
||||
|
||||
paths := predictListPaths(client, root)
|
||||
|
||||
var predictions []string
|
||||
for _, p := range paths {
|
||||
// Calculate the absolute "path" for matching.
|
||||
p = root + p
|
||||
|
||||
if strings.HasPrefix(p, path) {
|
||||
predictions = append(predictions, p)
|
||||
}
|
||||
}
|
||||
|
||||
// Add root to the path
|
||||
if len(predictions) == 0 {
|
||||
predictions = append(predictions, path)
|
||||
}
|
||||
|
||||
return predictions
|
||||
}
|
||||
|
||||
// predictListMounts returns a sorted list of the mount paths for Vault server
|
||||
// for which the client is configured to communicate with. This function returns
|
||||
// the default list of mounts if an error occurs.
|
||||
func predictListMounts(c *api.Client) []string {
|
||||
mounts, err := c.Sys().ListMounts()
|
||||
if err != nil {
|
||||
return defaultPredictVaultMounts
|
||||
}
|
||||
|
||||
list := make([]string, 0, len(mounts))
|
||||
for m := range mounts {
|
||||
list = append(list, m)
|
||||
}
|
||||
sort.Strings(list)
|
||||
return list
|
||||
}
|
||||
|
||||
// predictListPaths returns a list of paths (HTTP LIST) for the given path. This
|
||||
// function returns an empty list of any errors occur.
|
||||
func predictListPaths(c *api.Client, path string) []string {
|
||||
secret, err := c.Logical().List(path)
|
||||
if err != nil || secret == nil || secret.Data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
paths, ok := secret.Data["keys"].([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
list := make([]string, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
if str, ok := p.(string); ok {
|
||||
list = append(list, str)
|
||||
}
|
||||
}
|
||||
sort.Strings(list)
|
||||
return list
|
||||
}
|
||||
|
||||
// predictHasPathArg determines if the args have already accepted a path.
|
||||
func predictHasPathArg(args []string) bool {
|
||||
var nonFlags []string
|
||||
for _, a := range args {
|
||||
if !strings.HasPrefix(a, "-") {
|
||||
nonFlags = append(nonFlags, a)
|
||||
}
|
||||
}
|
||||
|
||||
return len(nonFlags) > 2
|
||||
}
|
||||
363
command/base_predict_test.go
Normal file
363
command/base_predict_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
func TestPredictVaultPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
data := map[string]interface{}{"a": "b"}
|
||||
if _, err := client.Logical().Write("secret/bar", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/foo", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/zip/zap", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/zip/zonk", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/zip/twoot", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f := predictVaultPaths(client)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
args complete.Args
|
||||
exp []string
|
||||
}{
|
||||
{
|
||||
"has_args",
|
||||
complete.Args{
|
||||
All: []string{"read", "secret/foo", "a=b"},
|
||||
Last: "a=b",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"part_mount",
|
||||
complete.Args{
|
||||
All: []string{"read", "s"},
|
||||
Last: "s",
|
||||
},
|
||||
[]string{"secret/", "sys/"},
|
||||
},
|
||||
{
|
||||
"only_mount",
|
||||
complete.Args{
|
||||
All: []string{"read", "sec"},
|
||||
Last: "sec",
|
||||
},
|
||||
[]string{"secret/bar", "secret/foo", "secret/zip/"},
|
||||
},
|
||||
{
|
||||
"full_mount",
|
||||
complete.Args{
|
||||
All: []string{"read", "secret"},
|
||||
Last: "secret",
|
||||
},
|
||||
[]string{"secret/bar", "secret/foo", "secret/zip/"},
|
||||
},
|
||||
{
|
||||
"full_mount_slash",
|
||||
complete.Args{
|
||||
All: []string{"read", "secret/"},
|
||||
Last: "secret/",
|
||||
},
|
||||
[]string{"secret/bar", "secret/foo", "secret/zip/"},
|
||||
},
|
||||
{
|
||||
"path_partial",
|
||||
complete.Args{
|
||||
All: []string{"read", "secret/z"},
|
||||
Last: "secret/z",
|
||||
},
|
||||
[]string{"secret/zip/twoot", "secret/zip/zap", "secret/zip/zonk"},
|
||||
},
|
||||
{
|
||||
"subpath_partial_z",
|
||||
complete.Args{
|
||||
All: []string{"read", "secret/zip/z"},
|
||||
Last: "secret/zip/z",
|
||||
},
|
||||
[]string{"secret/zip/zap", "secret/zip/zonk"},
|
||||
},
|
||||
{
|
||||
"subpath_partial_t",
|
||||
complete.Args{
|
||||
All: []string{"read", "secret/zip/t"},
|
||||
Last: "secret/zip/t",
|
||||
},
|
||||
[]string{"secret/zip/twoot"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := f(tc.args)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictMounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
client *api.Client
|
||||
path string
|
||||
exp []string
|
||||
}{
|
||||
{
|
||||
"no_match",
|
||||
client,
|
||||
"not-a-real-mount-seriously",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"s",
|
||||
client,
|
||||
"s",
|
||||
[]string{"secret/", "sys/"},
|
||||
},
|
||||
{
|
||||
"se",
|
||||
client,
|
||||
"se",
|
||||
[]string{"secret/"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictMounts(tc.client, tc.path)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
data := map[string]interface{}{"a": "b"}
|
||||
if _, err := client.Logical().Write("secret/bar", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/foo", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/zip/zap", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
client *api.Client
|
||||
path string
|
||||
exp []string
|
||||
}{
|
||||
{
|
||||
"bad_path",
|
||||
client,
|
||||
"nope/not/a/real/path/ever",
|
||||
[]string{"nope/not/a/real/path/ever"},
|
||||
},
|
||||
{
|
||||
"good_path",
|
||||
client,
|
||||
"secret/",
|
||||
[]string{"secret/bar", "secret/foo", "secret/zip/"},
|
||||
},
|
||||
{
|
||||
"partial_match",
|
||||
client,
|
||||
"secret/z",
|
||||
[]string{"secret/zip/"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictPaths(tc.client, tc.path)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictListMounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
client *api.Client
|
||||
exp []string
|
||||
}{
|
||||
{
|
||||
"not_connected_client",
|
||||
func() *api.Client {
|
||||
// Bad API client
|
||||
client, _ := api.NewClient(nil)
|
||||
return client
|
||||
}(),
|
||||
defaultPredictVaultMounts,
|
||||
},
|
||||
{
|
||||
"good_path",
|
||||
client,
|
||||
[]string{"cubbyhole/", "secret/", "sys/"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictListMounts(tc.client)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictListPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
data := map[string]interface{}{"a": "b"}
|
||||
if _, err := client.Logical().Write("secret/bar", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("secret/foo", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
client *api.Client
|
||||
path string
|
||||
exp []string
|
||||
}{
|
||||
{
|
||||
"bad_path",
|
||||
client,
|
||||
"nope/not/a/real/path/ever",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"good_path",
|
||||
client,
|
||||
"secret/",
|
||||
[]string{"bar", "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictListPaths(tc.client, tc.path)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictHasPathArg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
args []string
|
||||
exp bool
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
[]string{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty_string",
|
||||
[]string{""},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"single",
|
||||
[]string{"foo"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"multiple",
|
||||
[]string{"foo", "bar", "baz"},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if act := predictHasPathArg(tc.args); act != tc.exp {
|
||||
t.Errorf("expected %t to be %t", act, tc.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
14
command/base_test.go
Normal file
14
command/base_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func assertNoTabs(tb testing.TB, c cli.Command) {
|
||||
if strings.ContainsRune(c.Help(), '\t') {
|
||||
tb.Errorf("%#v help output contains tabs", c)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user