Fix bad rebase

Apparently I can't git...
This commit is contained in:
Seth Vargo
2017-09-21 20:51:12 -04:00
parent 50caac0bb6
commit be7c31f695
7 changed files with 328 additions and 403 deletions

View File

@@ -250,6 +250,6 @@ func OutputSealStatus(ui cli.Ui, client *api.Client, status *api.SealStatusRespo
} }
} }
ui.Output(columnOutput(out, nil)) ui.Output(tableOutput(out, nil))
return 0 return 0
} }

View File

@@ -1,7 +1,6 @@
package command package command
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"strings" "strings"
"testing" "testing"
@@ -68,7 +67,7 @@ func TestOperatorUnsealCommand_Run(t *testing.T) {
if exp := 0; code != exp { if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp) t.Errorf("expected %d to be %d", code, exp)
} }
expected := "Unseal Progress: 0" expected := "0/3"
combined := ui.OutputWriter.String() + ui.ErrorWriter.String() combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) { if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected) t.Errorf("expected %q to contain %q", combined, expected)
@@ -86,7 +85,7 @@ func TestOperatorUnsealCommand_Run(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
for i, key := range keys { for _, key := range keys {
ui, cmd := testOperatorUnsealCommand(t) ui, cmd := testOperatorUnsealCommand(t)
cmd.client = client cmd.client = client
cmd.testOutput = ioutil.Discard cmd.testOutput = ioutil.Discard
@@ -96,14 +95,17 @@ func TestOperatorUnsealCommand_Run(t *testing.T) {
key, key,
}) })
if exp := 0; code != exp { if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp) t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
expected := fmt.Sprintf("Unseal Progress: %d", (i+1)%3) // 1, 2, 0
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
} }
} }
status, err := client.Sys().SealStatus()
if err != nil {
t.Fatal(err)
}
if status.Sealed {
t.Error("expected unsealed")
}
}) })
t.Run("communication_failure", func(t *testing.T) { t.Run("communication_failure", func(t *testing.T) {

View File

@@ -46,9 +46,9 @@ func TestPathHelpCommand_Run(t *testing.T) {
2, 2,
}, },
{ {
"generic", "kv",
[]string{"secret/"}, []string{"secret/"},
"The generic backend", "The kv backend",
0, 0,
}, },
{ {

View File

@@ -33,7 +33,6 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/flag-slice"
"github.com/hashicorp/vault/helper/gated-writer" "github.com/hashicorp/vault/helper/gated-writer"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/helper/mlock" "github.com/hashicorp/vault/helper/mlock"
@@ -41,7 +40,6 @@ import (
"github.com/hashicorp/vault/helper/reload" "github.com/hashicorp/vault/helper/reload"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/physical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/version" "github.com/hashicorp/vault/version"
@@ -51,6 +49,8 @@ var _ cli.Command = (*ServerCommand)(nil)
var _ cli.CommandAutocomplete = (*ServerCommand)(nil) var _ cli.CommandAutocomplete = (*ServerCommand)(nil)
type ServerCommand struct { type ServerCommand struct {
*BaseCommand
AuditBackends map[string]audit.Factory AuditBackends map[string]audit.Factory
CredentialBackends map[string]logical.Factory CredentialBackends map[string]logical.Factory
LogicalBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory
@@ -61,8 +61,6 @@ type ServerCommand struct {
WaitGroup *sync.WaitGroup WaitGroup *sync.WaitGroup
meta.Meta
logGate *gatedwriter.Writer logGate *gatedwriter.Writer
logger log.Logger logger log.Logger
@@ -84,9 +82,10 @@ type ServerCommand struct {
flagDevHA bool flagDevHA bool
flagDevLatency int flagDevLatency int
flagDevLatencyJitter int flagDevLatencyJitter int
flagDevTransactional bool
flagDevLeasedKV bool flagDevLeasedKV bool
flagDevSkipInit bool
flagDevThreeNode bool flagDevThreeNode bool
flagDevTransactional bool
flagTestVerifyOnly bool flagTestVerifyOnly bool
} }
@@ -223,6 +222,13 @@ func (c *ServerCommand) Flags() *FlagSets {
Hidden: true, Hidden: true,
}) })
f.BoolVar(&BoolVar{
Name: "dev-skip-init",
Target: &c.flagDevSkipInit,
Default: false,
Hidden: true,
})
f.BoolVar(&BoolVar{ f.BoolVar(&BoolVar{
Name: "dev-three-node", Name: "dev-three-node",
Target: &c.flagDevThreeNode, Target: &c.flagDevThreeNode,
@@ -252,27 +258,10 @@ func (c *ServerCommand) AutocompleteFlags() complete.Flags {
} }
func (c *ServerCommand) Run(args []string) int { func (c *ServerCommand) Run(args []string) int {
var dev, verifyOnly, devHA, devTransactional, devLeasedKV, devThreeNode, devSkipInit bool f := c.Flags()
var configPath []string
var logLevel, devRootTokenID, devListenAddress, devPluginDir string if err := f.Parse(args); err != nil {
var devLatency, devLatencyJitter int c.UI.Error(err.Error())
flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
flags.BoolVar(&dev, "dev", false, "")
flags.StringVar(&devRootTokenID, "dev-root-token-id", "", "")
flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
flags.StringVar(&devPluginDir, "dev-plugin-dir", "", "")
flags.StringVar(&logLevel, "log-level", "info", "")
flags.IntVar(&devLatency, "dev-latency", 0, "")
flags.IntVar(&devLatencyJitter, "dev-latency-jitter", 20, "")
flags.BoolVar(&verifyOnly, "verify-only", false, "")
flags.BoolVar(&devHA, "dev-ha", false, "")
flags.BoolVar(&devTransactional, "dev-transactional", false, "")
flags.BoolVar(&devLeasedKV, "dev-leased-kv", false, "")
flags.BoolVar(&devThreeNode, "dev-three-node", false, "")
flags.BoolVar(&devSkipInit, "dev-skip-init", false, "")
flags.Usage = func() { c.Ui.Output(c.Help()) }
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
if err := flags.Parse(args); err != nil {
return 1 return 1
} }
@@ -280,8 +269,8 @@ func (c *ServerCommand) Run(args []string) int {
// start logging too early. // start logging too early.
c.logGate = &gatedwriter.Writer{Writer: colorable.NewColorable(os.Stderr)} c.logGate = &gatedwriter.Writer{Writer: colorable.NewColorable(os.Stderr)}
var level int var level int
logLevel = strings.ToLower(strings.TrimSpace(logLevel)) c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
switch logLevel { switch c.flagLogLevel {
case "trace": case "trace":
level = log.LevelTrace level = log.LevelTrace
case "debug": case "debug":
@@ -295,7 +284,7 @@ func (c *ServerCommand) Run(args []string) int {
case "err": case "err":
level = log.LevelError level = log.LevelError
default: default:
c.Ui.Output(fmt.Sprintf("Unknown log level %s", logLevel)) c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel))
return 1 return 1
} }
@@ -315,24 +304,16 @@ func (c *ServerCommand) Run(args []string) int {
log: os.Getenv("VAULT_GRPC_LOGGING") != "", log: os.Getenv("VAULT_GRPC_LOGGING") != "",
}) })
if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" && devRootTokenID == "" { // Automatically enable dev mode if other dev flags are provided.
devRootTokenID = os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") if c.flagDevHA || c.flagDevTransactional || c.flagDevLeasedKV || c.flagDevThreeNode {
} c.flagDev = true
if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" && devListenAddress == "" {
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
}
if devHA || devTransactional || devLeasedKV || devThreeNode {
dev = true
} }
// Validation // Validation
if !dev { if !c.flagDev {
switch { switch {
case len(configPath) == 0: case len(c.flagConfigs) == 0:
c.Ui.Output("At least one config path must be specified with -config") c.UI.Error("Must specify at least one config path using -config")
flags.Usage()
return 1 return 1
case c.flagDevRootTokenID != "": case c.flagDevRootTokenID != "":
c.UI.Warn(wrapAtLength( c.UI.Warn(wrapAtLength(
@@ -344,17 +325,16 @@ func (c *ServerCommand) Run(args []string) int {
// Load the configuration // Load the configuration
var config *server.Config var config *server.Config
if dev { if c.flagDev {
config = server.DevConfig(devHA, devTransactional) config = server.DevConfig(c.flagDevHA, c.flagDevTransactional)
if devListenAddress != "" { if c.flagDevListenAddr != "" {
config.Listeners[0].Config["address"] = devListenAddress config.Listeners[0].Config["address"] = c.flagDevListenAddr
} }
} }
for _, path := range configPath { for _, path := range c.flagConfigs {
current, err := server.LoadConfig(path, c.logger) current, err := server.LoadConfig(path, c.logger)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", path, err))
"Error loading configuration from %s: %s", path, err))
return 1 return 1
} }
@@ -367,43 +347,45 @@ func (c *ServerCommand) Run(args []string) int {
// Ensure at least one config was found. // Ensure at least one config was found.
if config == nil { if config == nil {
c.Ui.Output("No configuration files found.") c.UI.Output(wrapAtLength(
"No configuration files found. Please provide configurations with the " +
"-config flag. If you are supply the path to a directory, please " +
"ensure the directory contains files with the .hcl or .json " +
"extension."))
return 1 return 1
} }
// Ensure that a backend is provided // Ensure that a backend is provided
if config.Storage == nil { if config.Storage == nil {
c.Ui.Output("A storage backend must be specified") c.UI.Output("A storage backend must be specified")
return 1 return 1
} }
// If mlockall(2) isn't supported, show a warning. We disable this // If mlockall(2) isn't supported, show a warning. We disable this
// in dev because it is quite scary to see when first using Vault. // in dev because it is quite scary to see when first using Vault.
if !dev && !mlock.Supported() { if !c.flagDev && !mlock.Supported() {
c.Ui.Output("==> WARNING: mlock not supported on this system!\n") c.UI.Warn(wrapAtLength(
c.Ui.Output(" An `mlockall(2)`-like syscall to prevent memory from being") "WARNING! mlock is not supported on this system! An mlockall(2)-like " +
c.Ui.Output(" swapped to disk is not supported on this system. Running") "syscall to prevent memory from being swapped to disk is not " +
c.Ui.Output(" Vault on an mlockall(2) enabled system is much more secure.\n") "supported on this system. For better security, only run Vault on " +
"systems where this call is supported. If you are running Vault " +
"in a Docker container, provide the IPC_LOCK cap to the container."))
} }
if err := c.setupTelemetry(config); err != nil { if err := c.setupTelemetry(config); err != nil {
c.Ui.Output(fmt.Sprintf("Error initializing telemetry: %s", err)) c.UI.Error(fmt.Sprintf("Error initializing telemetry: %s", err))
return 1 return 1
} }
// Initialize the backend // Initialize the backend
factory, exists := c.PhysicalBackends[config.Storage.Type] factory, exists := c.PhysicalBackends[config.Storage.Type]
if !exists { if !exists {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Unknown storage type %s", config.Storage.Type))
"Unknown storage type %s",
config.Storage.Type))
return 1 return 1
} }
backend, err := factory(config.Storage.Config, c.logger) backend, err := factory(config.Storage.Config, c.logger)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Error initializing storage of type %s: %s", config.Storage.Type, err))
"Error initializing storage of type %s: %s",
config.Storage.Type, err))
return 1 return 1
} }
@@ -417,13 +399,13 @@ func (c *ServerCommand) Run(args []string) int {
if seal != nil { if seal != nil {
err = seal.Finalize() err = seal.Finalize()
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error finalizing seals: %v", err)) c.UI.Error(fmt.Sprintf("Error finalizing seals: %v", err))
} }
} }
}() }()
if seal == nil { if seal == nil {
c.Ui.Error(fmt.Sprintf("Could not create seal; most likely proper Seal configuration information was not set, but no error was generated.")) c.UI.Error(fmt.Sprintf("Could not create seal! Most likely proper Seal configuration information was not set, but no error was generated."))
return 1 return 1
} }
@@ -445,14 +427,13 @@ func (c *ServerCommand) Run(args []string) int {
PluginDirectory: config.PluginDirectory, PluginDirectory: config.PluginDirectory,
EnableRaw: config.EnableRawEndpoint, EnableRaw: config.EnableRawEndpoint,
} }
if c.flagDev {
if dev { coreConfig.DevToken = c.flagDevRootTokenID
coreConfig.DevToken = devRootTokenID if c.flagDevLeasedKV {
if devLeasedKV {
coreConfig.LogicalBackends["kv"] = vault.LeasedPassthroughBackendFactory coreConfig.LogicalBackends["kv"] = vault.LeasedPassthroughBackendFactory
} }
if devPluginDir != "" { if c.flagDevPluginDir != "" {
coreConfig.PluginDirectory = devPluginDir coreConfig.PluginDirectory = c.flagDevPluginDir
} }
if c.flagDevLatency > 0 { if c.flagDevLatency > 0 {
injectLatency := time.Duration(c.flagDevLatency) * time.Millisecond injectLatency := time.Duration(c.flagDevLatency) * time.Millisecond
@@ -464,8 +445,8 @@ func (c *ServerCommand) Run(args []string) int {
} }
} }
if devThreeNode { if c.flagDevThreeNode {
return c.enableThreeNodeDevCluster(coreConfig, info, infoKeys, devListenAddress) return c.enableThreeNodeDevCluster(coreConfig, info, infoKeys)
} }
var disableClustering bool var disableClustering bool
@@ -475,26 +456,25 @@ func (c *ServerCommand) Run(args []string) int {
if config.HAStorage != nil { if config.HAStorage != nil {
factory, exists := c.PhysicalBackends[config.HAStorage.Type] factory, exists := c.PhysicalBackends[config.HAStorage.Type]
if !exists { if !exists {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Unknown HA storage type %s", config.HAStorage.Type))
"Unknown HA storage type %s",
config.HAStorage.Type))
return 1 return 1
} }
habackend, err := factory(config.HAStorage.Config, c.logger) habackend, err := factory(config.HAStorage.Config, c.logger)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf(
"Error initializing HA storage of type %s: %s", "Error initializing HA storage of type %s: %s", config.HAStorage.Type, err))
config.HAStorage.Type, err))
return 1 return 1
} }
if coreConfig.HAPhysical, ok = habackend.(physical.HABackend); !ok { if coreConfig.HAPhysical, ok = habackend.(physical.HABackend); !ok {
c.Ui.Output("Specified HA storage does not support HA") c.UI.Error("Specified HA storage does not support HA")
return 1 return 1
} }
if !coreConfig.HAPhysical.HAEnabled() { if !coreConfig.HAPhysical.HAEnabled() {
c.Ui.Output("Specified HA storage has HA support disabled; please consult documentation") c.UI.Error("Specified HA storage has HA support disabled; please consult documentation")
return 1 return 1
} }
@@ -529,14 +509,14 @@ func (c *ServerCommand) Run(args []string) int {
if ok && coreConfig.RedirectAddr == "" { if ok && coreConfig.RedirectAddr == "" {
redirect, err := c.detectRedirect(detect, config) redirect, err := c.detectRedirect(detect, config)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("Error detecting redirect address: %s", err)) c.UI.Error(fmt.Sprintf("Error detecting redirect address: %s", err))
} else if redirect == "" { } else if redirect == "" {
c.Ui.Output("Failed to detect redirect address.") c.UI.Error("Failed to detect redirect address.")
} else { } else {
coreConfig.RedirectAddr = redirect coreConfig.RedirectAddr = redirect
} }
} }
if coreConfig.RedirectAddr == "" && dev { if coreConfig.RedirectAddr == "" && c.flagDev {
coreConfig.RedirectAddr = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"]) coreConfig.RedirectAddr = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"])
} }
@@ -551,14 +531,15 @@ func (c *ServerCommand) Run(args []string) int {
switch { switch {
case coreConfig.ClusterAddr == "" && coreConfig.RedirectAddr != "": case coreConfig.ClusterAddr == "" && coreConfig.RedirectAddr != "":
addrToUse = coreConfig.RedirectAddr addrToUse = coreConfig.RedirectAddr
case dev: case c.flagDev:
addrToUse = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"]) addrToUse = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"])
default: default:
goto CLUSTER_SYNTHESIS_COMPLETE goto CLUSTER_SYNTHESIS_COMPLETE
} }
u, err := url.ParseRequestURI(addrToUse) u, err := url.ParseRequestURI(addrToUse)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("Error parsing synthesized cluster address %s: %v", addrToUse, err)) c.UI.Error(fmt.Sprintf(
"Error parsing synthesized cluster address %s: %v", addrToUse, err))
return 1 return 1
} }
host, port, err := net.SplitHostPort(u.Host) host, port, err := net.SplitHostPort(u.Host)
@@ -568,13 +549,14 @@ func (c *ServerCommand) Run(args []string) int {
host = u.Host host = u.Host
port = "443" port = "443"
} else { } else {
c.Ui.Output(fmt.Sprintf("Error parsing redirect address: %v", err)) c.UI.Error(fmt.Sprintf("Error parsing redirect address: %v", err))
return 1 return 1
} }
} }
nPort, err := strconv.Atoi(port) nPort, err := strconv.Atoi(port)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("Error parsing synthesized address; failed to convert %q to a numeric: %v", port, err)) c.UI.Error(fmt.Sprintf(
"Error parsing synthesized address; failed to convert %q to a numeric: %v", port, err))
return 1 return 1
} }
u.Host = net.JoinHostPort(host, strconv.Itoa(nPort+1)) u.Host = net.JoinHostPort(host, strconv.Itoa(nPort+1))
@@ -589,8 +571,8 @@ CLUSTER_SYNTHESIS_COMPLETE:
// Force https as we'll always be TLS-secured // Force https as we'll always be TLS-secured
u, err := url.ParseRequestURI(coreConfig.ClusterAddr) u, err := url.ParseRequestURI(coreConfig.ClusterAddr)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) c.UI.Error(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err))
return 1 return 11
} }
u.Scheme = "https" u.Scheme = "https"
coreConfig.ClusterAddr = u.String() coreConfig.ClusterAddr = u.String()
@@ -600,7 +582,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
core, newCoreError := vault.NewCore(coreConfig) core, newCoreError := vault.NewCore(coreConfig)
if newCoreError != nil { if newCoreError != nil {
if !errwrap.ContainsType(newCoreError, new(vault.NonFatalError)) { if !errwrap.ContainsType(newCoreError, new(vault.NonFatalError)) {
c.Ui.Output(fmt.Sprintf("Error initializing core: %s", newCoreError)) c.UI.Error(fmt.Sprintf("Error initializing core: %s", newCoreError))
return 1 return 1
} }
} }
@@ -611,7 +593,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
// Compile server information for output later // Compile server information for output later
info["storage"] = config.Storage.Type info["storage"] = config.Storage.Type
info["log level"] = logLevel info["log level"] = c.flagLogLevel
info["mlock"] = fmt.Sprintf( info["mlock"] = fmt.Sprintf(
"supported: %v, enabled: %v", "supported: %v, enabled: %v",
mlock.Supported(), !config.DisableMlock && mlock.Supported()) mlock.Supported(), !config.DisableMlock && mlock.Supported())
@@ -648,9 +630,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
for i, lnConfig := range config.Listeners { for i, lnConfig := range config.Listeners {
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate) ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err))
"Error initializing listener of type %s: %s",
lnConfig.Type, err))
return 1 return 1
} }
@@ -670,16 +650,14 @@ CLUSTER_SYNTHESIS_COMPLETE:
addr = addrRaw.(string) addr = addrRaw.(string)
tcpAddr, err := net.ResolveTCPAddr("tcp", addr) tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Error resolving cluster_address: %s", err))
"Error resolving cluster_address: %s",
err))
return 1 return 1
} }
clusterAddrs = append(clusterAddrs, tcpAddr) clusterAddrs = append(clusterAddrs, tcpAddr)
} else { } else {
tcpAddr, ok := ln.Addr().(*net.TCPAddr) tcpAddr, ok := ln.Addr().(*net.TCPAddr)
if !ok { if !ok {
c.Ui.Output("Failed to parse tcp listener") c.UI.Error("Failed to parse tcp listener")
return 1 return 1
} }
clusterAddr := &net.TCPAddr{ clusterAddr := &net.TCPAddr{
@@ -737,17 +715,19 @@ CLUSTER_SYNTHESIS_COMPLETE:
// Server configuration output // Server configuration output
padding := 24 padding := 24
sort.Strings(infoKeys) sort.Strings(infoKeys)
c.Ui.Output("==> Vault server configuration:\n") c.UI.Output("==> Vault server configuration:\n")
for _, k := range infoKeys { for _, k := range infoKeys {
c.Ui.Output(fmt.Sprintf( c.UI.Output(fmt.Sprintf(
"%s%s: %s", "%s%s: %s",
strings.Repeat(" ", padding-len(k)), strings.Repeat(" ", padding-len(k)),
strings.Title(k), strings.Title(k),
info[k])) info[k]))
} }
c.Ui.Output("") c.UI.Output("")
if verifyOnly { // Tests might not want to start a vault server and just want to verify
// the configuration.
if c.flagTestVerifyOnly {
return 0 return 0
} }
@@ -761,7 +741,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
err = core.UnsealWithStoredKeys() err = core.UnsealWithStoredKeys()
if err != nil { if err != nil {
if !errwrap.ContainsType(err, new(vault.NonFatalError)) { if !errwrap.ContainsType(err, new(vault.NonFatalError)) {
c.Ui.Output(fmt.Sprintf("Error initializing core: %s", err)) c.UI.Error(fmt.Sprintf("Error initializing core: %s", err))
return 1 return 1
} }
} }
@@ -791,18 +771,17 @@ CLUSTER_SYNTHESIS_COMPLETE:
} }
if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, sealedFunc); err != nil { if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, sealedFunc); err != nil {
c.Ui.Output(fmt.Sprintf("Error initializing service discovery: %v", err)) c.UI.Error(fmt.Sprintf("Error initializing service discovery: %v", err))
return 1 return 1
} }
} }
} }
// If we're in Dev mode, then initialize the core // If we're in Dev mode, then initialize the core
if dev && !devSkipInit { if c.flagDev && !c.flagDevSkipInit {
init, err := c.enableDev(core, coreConfig) init, err := c.enableDev(core, coreConfig)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf( c.UI.Error(fmt.Sprintf("Error initializing Dev mode: %s", err))
"Error initializing Dev mode: %s", err))
return 1 return 1
} }
@@ -813,44 +792,49 @@ CLUSTER_SYNTHESIS_COMPLETE:
quote = "" quote = ""
} }
c.Ui.Output(fmt.Sprint( // Print the big dev mode warning!
"==> WARNING: Dev mode is enabled!\n\n" + c.UI.Warn(wrapAtLength(
"In this mode, Vault is completely in-memory and unsealed.\n" + "WARNING! dev mode is enabled! In this mode, Vault runs entirely " +
"Vault is configured to only have a single unseal key. The root\n" + "in-memory and starts unsealed with a single unseal key. The root " +
"token has already been authenticated with the CLI, so you can\n" + "token is already authenticated to the CLI, so you can immediately " +
"immediately begin using the Vault CLI.\n\n" + "begin using Vault."))
"The only step you need to take is to set the following\n" + c.UI.Warn("")
"environment variables:\n\n" + c.UI.Warn("You may need to set the following environment variable:")
" " + export + " VAULT_ADDR=" + quote + "http://" + config.Listeners[0].Config["address"].(string) + quote + "\n\n" + c.UI.Warn("")
"The unseal key and root token are reproduced below in case you\n" + c.UI.Warn(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s",
"want to seal/unseal the Vault or play with authentication.\n", export, quote, "http://"+config.Listeners[0].Config["address"].(string), quote))
))
// Unseal key is not returned if stored shares is supported // Unseal key is not returned if stored shares is supported
if len(init.SecretShares) > 0 { if len(init.SecretShares) > 0 {
c.Ui.Output(fmt.Sprintf( c.UI.Warn("")
"Unseal Key: %s", c.UI.Warn(wrapAtLength(
base64.StdEncoding.EncodeToString(init.SecretShares[0]), "The unseal key and root token are displayed below in case you want " +
)) "to seal/unseal the Vault or re-authenticate."))
c.UI.Warn("")
c.UI.Warn(fmt.Sprintf("Unseal Key: %s", base64.StdEncoding.EncodeToString(init.SecretShares[0])))
} }
if len(init.RecoveryShares) > 0 { if len(init.RecoveryShares) > 0 {
c.Ui.Output(fmt.Sprintf( c.UI.Warn("")
"Recovery Key: %s", c.UI.Warn(wrapAtLength(
base64.StdEncoding.EncodeToString(init.RecoveryShares[0]), "The recovery key and root token are displayed below in case you want " +
)) "to seal/unseal the Vault or re-authenticate."))
c.UI.Warn("")
c.UI.Warn(fmt.Sprintf("Unseal Key: %s", base64.StdEncoding.EncodeToString(init.RecoveryShares[0])))
} }
c.Ui.Output(fmt.Sprintf( c.UI.Warn(fmt.Sprintf("Root Token: %s", init.RootToken))
"Root Token: %s\n",
init.RootToken, c.UI.Warn("")
)) c.UI.Warn(wrapAtLength(
"Development mode should NOT be used in production installations!"))
c.UI.Warn("")
} }
// Initialize the HTTP server // Initialize the HTTP server
server := &http.Server{} server := &http.Server{}
if err := http2.ConfigureServer(server, nil); err != nil { if err := http2.ConfigureServer(server, nil); err != nil {
c.Ui.Output(fmt.Sprintf("Error configuring server for HTTP/2: %s", err)) c.UI.Error(fmt.Sprintf("Error configuring server for HTTP/2: %s", err))
return 1 return 1
} }
server.Handler = handler server.Handler = handler
@@ -859,12 +843,20 @@ CLUSTER_SYNTHESIS_COMPLETE:
} }
if newCoreError != nil { if newCoreError != nil {
c.Ui.Output("==> Warning:\n\nNon-fatal error during initialization; check the logs for more information.") c.UI.Warn(wrapAtLength(
c.Ui.Output("") "WARNING! A non-fatal error occurred during initialization. Please " +
"check the logs for more information."))
c.UI.Warn("")
} }
// Output the header that the server has started // Output the header that the server has started
c.Ui.Output("==> Vault server started! Log data will stream in below:\n") c.UI.Output("==> Vault server started! Log data will stream in below:\n")
// Inform any tests that the server is ready
select {
case c.startedCh <- struct{}{}:
default:
}
// Release the log gate. // Release the log gate.
c.logGate.Flush() c.logGate.Flush()
@@ -887,7 +879,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
for !shutdownTriggered { for !shutdownTriggered {
select { select {
case <-c.ShutdownCh: case <-c.ShutdownCh:
c.Ui.Output("==> Vault shutdown triggered") c.UI.Output("==> Vault shutdown triggered")
// Stop the listners so that we don't process further client requests. // Stop the listners so that we don't process further client requests.
c.cleanupGuard.Do(listenerCloseFunc) c.cleanupGuard.Do(listenerCloseFunc)
@@ -896,15 +888,15 @@ CLUSTER_SYNTHESIS_COMPLETE:
// request forwarding listeners will also be closed (and also // request forwarding listeners will also be closed (and also
// waited for). // waited for).
if err := core.Shutdown(); err != nil { if err := core.Shutdown(); err != nil {
c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err))
} }
shutdownTriggered = true shutdownTriggered = true
case <-c.SighupCh: case <-c.SighupCh:
c.Ui.Output("==> Vault reload triggered") c.UI.Output("==> Vault reload triggered")
if err := c.Reload(c.reloadFuncsLock, c.reloadFuncs, configPath); err != nil { if err := c.Reload(c.reloadFuncsLock, c.reloadFuncs, c.flagConfigs); err != nil {
c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) c.UI.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err))
} }
} }
} }
@@ -1031,10 +1023,10 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig
return init, nil return init, nil
} }
func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info map[string]string, infoKeys []string, devListenAddress string) int { func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info map[string]string, infoKeys []string) int {
testCluster := vault.NewTestCluster(&testing.RuntimeT{}, base, &vault.TestClusterOptions{ testCluster := vault.NewTestCluster(&testing.RuntimeT{}, base, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler, HandlerFunc: vaulthttp.Handler,
BaseListenAddress: devListenAddress, BaseListenAddress: c.flagDevListenAddr,
}) })
defer c.cleanupGuard.Do(testCluster.Cleanup) defer c.cleanupGuard.Do(testCluster.Cleanup)
@@ -1063,15 +1055,15 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
// Server configuration output // Server configuration output
padding := 24 padding := 24
sort.Strings(infoKeys) sort.Strings(infoKeys)
c.Ui.Output("==> Vault server configuration:\n") c.UI.Output("==> Vault server configuration:\n")
for _, k := range infoKeys { for _, k := range infoKeys {
c.Ui.Output(fmt.Sprintf( c.UI.Output(fmt.Sprintf(
"%s%s: %s", "%s%s: %s",
strings.Repeat(" ", padding-len(k)), strings.Repeat(" ", padding-len(k)),
strings.Title(k), strings.Title(k),
info[k])) info[k]))
} }
c.Ui.Output("") c.UI.Output("")
for _, core := range testCluster.Cores { for _, core := range testCluster.Cores {
core.Server.Handler = vaulthttp.Handler(core.Core) core.Server.Handler = vaulthttp.Handler(core.Core)
@@ -1095,15 +1087,15 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
} }
resp, err := testCluster.Cores[0].HandleRequest(req) resp, err := testCluster.Cores[0].HandleRequest(req)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err)) c.UI.Error(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err))
return 1 return 1
} }
if resp == nil { if resp == nil {
c.Ui.Output(fmt.Sprintf("nil response when creating root token with ID %s", base.DevToken)) c.UI.Error(fmt.Sprintf("nil response when creating root token with ID %s", base.DevToken))
return 1 return 1
} }
if resp.Auth == nil { if resp.Auth == nil {
c.Ui.Output(fmt.Sprintf("nil auth when creating root token with ID %s", base.DevToken)) c.UI.Error(fmt.Sprintf("nil auth when creating root token with ID %s", base.DevToken))
return 1 return 1
} }
@@ -1114,7 +1106,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
req.Data = nil req.Data = nil
resp, err = testCluster.Cores[0].HandleRequest(req) resp, err = testCluster.Cores[0].HandleRequest(req)
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("failed to revoke initial root token: %s", err)) c.UI.Output(fmt.Sprintf("failed to revoke initial root token: %s", err))
return 1 return 1
} }
} }
@@ -1122,37 +1114,37 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
// Set the token // Set the token
tokenHelper, err := c.TokenHelper() tokenHelper, err := c.TokenHelper()
if err != nil { if err != nil {
c.Ui.Output(fmt.Sprintf("%v", err)) c.UI.Error(fmt.Sprintf("Error getting token helper: %s", err))
return 1 return 1
} }
if err := tokenHelper.Store(testCluster.RootToken); err != nil { if err := tokenHelper.Store(testCluster.RootToken); err != nil {
c.Ui.Output(fmt.Sprintf("%v", err)) c.UI.Error(fmt.Sprintf("Error storing in token helper: %s", err))
return 1 return 1
} }
if err := ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(testCluster.RootToken), 0755); err != nil { if err := ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(testCluster.RootToken), 0755); err != nil {
c.Ui.Output(fmt.Sprintf("%v", err)) c.UI.Error(fmt.Sprintf("Error writing token to tempfile: %s", err))
return 1 return 1
} }
c.Ui.Output(fmt.Sprintf( c.UI.Output(fmt.Sprintf(
"==> Three node dev mode is enabled\n\n" + "==> Three node dev mode is enabled\n\n" +
"The unseal key and root token are reproduced below in case you\n" + "The unseal key and root token are reproduced below in case you\n" +
"want to seal/unseal the Vault or play with authentication.\n", "want to seal/unseal the Vault or play with authentication.\n",
)) ))
for i, key := range testCluster.BarrierKeys { for i, key := range testCluster.BarrierKeys {
c.Ui.Output(fmt.Sprintf( c.UI.Output(fmt.Sprintf(
"Unseal Key %d: %s", "Unseal Key %d: %s",
i+1, base64.StdEncoding.EncodeToString(key), i+1, base64.StdEncoding.EncodeToString(key),
)) ))
} }
c.Ui.Output(fmt.Sprintf( c.UI.Output(fmt.Sprintf(
"\nRoot Token: %s\n", testCluster.RootToken, "\nRoot Token: %s\n", testCluster.RootToken,
)) ))
c.Ui.Output(fmt.Sprintf( c.UI.Output(fmt.Sprintf(
"\nUseful env vars:\n"+ "\nUseful env vars:\n"+
"VAULT_TOKEN=%s\n"+ "VAULT_TOKEN=%s\n"+
"VAULT_ADDR=%s\n"+ "VAULT_ADDR=%s\n"+
@@ -1163,7 +1155,13 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
)) ))
// Output the header that the server has started // Output the header that the server has started
c.Ui.Output("==> Vault server started! Log data will stream in below:\n") c.UI.Output("==> Vault server started! Log data will stream in below:\n")
// Inform any tests that the server is ready
select {
case c.startedCh <- struct{}{}:
default:
}
// Release the log gate. // Release the log gate.
c.logGate.Flush() c.logGate.Flush()
@@ -1174,7 +1172,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
for !shutdownTriggered { for !shutdownTriggered {
select { select {
case <-c.ShutdownCh: case <-c.ShutdownCh:
c.Ui.Output("==> Vault shutdown triggered") c.UI.Output("==> Vault shutdown triggered")
// Stop the listners so that we don't process further client requests. // Stop the listners so that we don't process further client requests.
c.cleanupGuard.Do(testCluster.Cleanup) c.cleanupGuard.Do(testCluster.Cleanup)
@@ -1184,17 +1182,17 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
// waited for). // waited for).
for _, core := range testCluster.Cores { for _, core := range testCluster.Cores {
if err := core.Shutdown(); err != nil { if err := core.Shutdown(); err != nil {
c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err))
} }
} }
shutdownTriggered = true shutdownTriggered = true
case <-c.SighupCh: case <-c.SighupCh:
c.Ui.Output("==> Vault reload triggered") c.UI.Output("==> Vault reload triggered")
for _, core := range testCluster.Cores { for _, core := range testCluster.Cores {
if err := c.Reload(core.ReloadFuncsLock, core.ReloadFuncs, nil); err != nil { if err := c.Reload(core.ReloadFuncsLock, core.ReloadFuncs, nil); err != nil {
c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) c.UI.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err))
} }
} }
} }
@@ -1405,68 +1403,11 @@ func (c *ServerCommand) Reload(lock *sync.RWMutex, reloadFuncs *map[string][]rel
} }
} }
return reloadErrors.ErrorOrNil() // Send a message that we reloaded. This prevents "guessing" sleep times
} // in tests.
select {
func (c *ServerCommand) Synopsis() string { case c.reloadedCh <- struct{}{}:
return "Start a Vault server" default:
}
func (c *ServerCommand) Help() string {
helpText := `
Usage: vault server [options]
Start a Vault server.
This command starts a Vault server that responds to API requests.
Vault will start in a "sealed" state. The Vault must be unsealed
with "vault unseal" or the API before this server can respond to requests.
This must be done for every server.
If the server is being started against a storage backend that is
brand new (no existing Vault data in it), it must be initialized with
"vault init" or the API first.
General Options:
-config=<path> Path to the configuration file or directory. This can
be specified multiple times. If it is a directory,
all files with a ".hcl" or ".json" suffix will be
loaded.
-dev Enables Dev mode. In this mode, Vault is completely
in-memory and unsealed. Do not run the Dev server in
production!
-dev-root-token-id="" If set, the root token returned in Dev mode will have
the given ID. This *only* has an effect when running
in Dev mode. Can also be specified with the
VAULT_DEV_ROOT_TOKEN_ID environment variable.
-dev-listen-address="" If set, this overrides the normal Dev mode listen
address of "127.0.0.1:8200". Can also be specified
with the VAULT_DEV_LISTEN_ADDRESS environment
variable.
-log-level=info Log verbosity. Defaults to "info", will be output to
stderr. Supported values: "trace", "debug", "info",
"warn", "err"
`
return strings.TrimSpace(helpText)
}
func (c *ServerCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictNothing
}
func (c *ServerCommand) AutocompleteFlags() complete.Flags {
return complete.Flags{
"-config": complete.PredictOr(complete.PredictFiles("*.hcl"), complete.PredictFiles("*.json")),
"-dev": complete.PredictNothing,
"-dev-root-token-id": complete.PredictNothing,
"-dev-listen-address": complete.PredictNothing,
"-log-level": complete.PredictSet("trace", "debug", "info", "warn", "err"),
} }
return reloadErrors.ErrorOrNil() return reloadErrors.ErrorOrNil()

View File

@@ -1,106 +0,0 @@
// +build !race
package command
import (
"io/ioutil"
"os"
"strings"
"testing"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/physical"
"github.com/mitchellh/cli"
physConsul "github.com/hashicorp/vault/physical/consul"
)
// The following tests have a go-metrics/exp manager race condition
func TestServer_CommonHA(t *testing.T) {
ui := new(cli.MockUi)
c := &ServerCommand{
Meta: meta.Meta{
Ui: ui,
},
PhysicalBackends: map[string]physical.Factory{
"consul": physConsul.NewConsulBackend,
},
}
tmpfile, err := ioutil.TempFile("", "")
if err != nil {
t.Fatalf("error creating temp dir: %v", err)
}
tmpfile.WriteString(basehcl + consulhcl)
tmpfile.Close()
defer os.Remove(tmpfile.Name())
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
}
if !strings.Contains(ui.OutputWriter.String(), "(HA available)") {
t.Fatalf("did not find HA available: %s", ui.OutputWriter.String())
}
}
func TestServer_GoodSeparateHA(t *testing.T) {
ui := new(cli.MockUi)
c := &ServerCommand{
Meta: meta.Meta{
Ui: ui,
},
PhysicalBackends: map[string]physical.Factory{
"consul": physConsul.NewConsulBackend,
},
}
tmpfile, err := ioutil.TempFile("", "")
if err != nil {
t.Fatalf("error creating temp dir: %v", err)
}
tmpfile.WriteString(basehcl + consulhcl + haconsulhcl)
tmpfile.Close()
defer os.Remove(tmpfile.Name())
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
}
if !strings.Contains(ui.OutputWriter.String(), "HA Storage:") {
t.Fatalf("did not find HA Storage: %s", ui.OutputWriter.String())
}
}
func TestServer_BadSeparateHA(t *testing.T) {
ui := new(cli.MockUi)
c := &ServerCommand{
Meta: meta.Meta{
Ui: ui,
},
PhysicalBackends: map[string]physical.Factory{
"consul": physConsul.NewConsulBackend,
},
}
tmpfile, err := ioutil.TempFile("", "")
if err != nil {
t.Fatalf("error creating temp dir: %v", err)
}
tmpfile.WriteString(basehcl + consulhcl + badhaconsulhcl)
tmpfile.Close()
defer os.Remove(tmpfile.Name())
args := []string{"-config", tmpfile.Name()}
if code := c.Run(args); code == 0 {
t.Fatalf("bad: should have gotten an error on a bad HA config")
}
}

View File

@@ -1,4 +1,5 @@
// +build !race // +build !race
// The server tests have a go-metrics/exp manager race condition :(.
package command package command
@@ -7,72 +8,112 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand" "net"
"os" "os"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/physical"
"github.com/mitchellh/cli" "github.com/mitchellh/cli"
physConsul "github.com/hashicorp/vault/physical/consul"
physFile "github.com/hashicorp/vault/physical/file" physFile "github.com/hashicorp/vault/physical/file"
) )
var ( func testRandomPort(tb testing.TB) int {
basehcl = ` tb.Helper()
disable_mlock = true
listener "tcp" { addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
address = "127.0.0.1:8200" if err != nil {
tls_disable = "true" tb.Fatal(err)
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
tb.Fatal(err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port
} }
`
consulhcl = ` func testBaseHCL(tb testing.TB) string {
tb.Helper()
return strings.TrimSpace(fmt.Sprintf(`
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:%d"
tls_disable = "true"
}
`, testRandomPort(tb)))
}
const (
consulHCL = `
backend "consul" { backend "consul" {
prefix = "foo/" prefix = "foo/"
advertise_addr = "http://127.0.0.1:8200" advertise_addr = "http://127.0.0.1:8200"
disable_registration = "true" disable_registration = "true"
} }
` `
haconsulhcl = ` haConsulHCL = `
ha_backend "consul" { ha_backend "consul" {
prefix = "bar/" prefix = "bar/"
redirect_addr = "http://127.0.0.1:8200" redirect_addr = "http://127.0.0.1:8200"
disable_registration = "true" disable_registration = "true"
} }
` `
badhaconsulhcl = ` badHAConsulHCL = `
ha_backend "file" { ha_backend "file" {
path = "/dev/null" path = "/dev/null"
} }
` `
reloadhcl = ` reloadHCL = `
backend "file" { backend "file" {
path = "/dev/null" path = "/dev/null"
} }
disable_mlock = true disable_mlock = true
listener "tcp" { listener "tcp" {
address = "127.0.0.1:8203" address = "127.0.0.1:8203"
tls_cert_file = "TMPDIR/reload_cert.pem" tls_cert_file = "TMPDIR/reload_cert.pem"
tls_key_file = "TMPDIR/reload_key.pem" tls_key_file = "TMPDIR/reload_key.pem"
} }
` `
) )
// The following tests have a go-metrics/exp manager race condition func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) {
tb.Helper()
ui := cli.NewMockUi()
return ui, &ServerCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
PhysicalBackends: map[string]physical.Factory{
"file": physFile.NewFileBackend,
"consul": physConsul.NewConsulBackend,
},
// These prevent us from random sleep guessing...
startedCh: make(chan struct{}, 5),
reloadedCh: make(chan struct{}, 5),
}
}
func TestServer_ReloadListener(t *testing.T) { func TestServer_ReloadListener(t *testing.T) {
t.Parallel()
wd, _ := os.Getwd() wd, _ := os.Getwd()
wd += "/server/test-fixtures/reload/" wd += "/server/test-fixtures/reload/"
td, err := ioutil.TempDir("", fmt.Sprintf("vault-test-%d", rand.New(rand.NewSource(time.Now().Unix())).Int63)) td, err := ioutil.TempDir("", "vault-test-")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -86,7 +127,7 @@ func TestServer_ReloadListener(t *testing.T) {
inBytes, _ = ioutil.ReadFile(wd + "reload_foo.key") inBytes, _ = ioutil.ReadFile(wd + "reload_foo.key")
ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777) ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777)
relhcl := strings.Replace(reloadhcl, "TMPDIR", td, -1) relhcl := strings.Replace(reloadHCL, "TMPDIR", td, -1)
ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777) ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777)
inBytes, _ = ioutil.ReadFile(wd + "reload_ca.pem") inBytes, _ = ioutil.ReadFile(wd + "reload_ca.pem")
@@ -96,17 +137,8 @@ func TestServer_ReloadListener(t *testing.T) {
t.Fatal("not ok when appending CA cert") t.Fatal("not ok when appending CA cert")
} }
ui := new(cli.MockUi) ui, cmd := testServerCommand(t)
c := &ServerCommand{ _ = ui
Meta: meta.Meta{
Ui: ui,
},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
PhysicalBackends: map[string]physical.Factory{
"file": physFile.NewFileBackend,
},
}
finished := false finished := false
finishedMutex := sync.Mutex{} finishedMutex := sync.Mutex{}
@@ -114,7 +146,7 @@ func TestServer_ReloadListener(t *testing.T) {
wg.Add(1) wg.Add(1)
args := []string{"-config", td + "/reload.hcl"} args := []string{"-config", td + "/reload.hcl"}
go func() { go func() {
if code := c.Run(args); code != 0 { if code := cmd.Run(args); code != 0 {
t.Error("got a non-zero exit status") t.Error("got a non-zero exit status")
} }
finishedMutex.Lock() finishedMutex.Lock()
@@ -123,14 +155,6 @@ func TestServer_ReloadListener(t *testing.T) {
wg.Done() wg.Done()
}() }()
checkFinished := func() {
finishedMutex.Lock()
if finished {
t.Fatalf(fmt.Sprintf("finished early; relhcl was\n%s\nstdout was\n%s\nstderr was\n%s\n", relhcl, ui.OutputWriter.String(), ui.ErrorWriter.String()))
}
finishedMutex.Unlock()
}
testCertificateName := func(cn string) error { testCertificateName := func(cn string) error {
conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{ conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{
RootCAs: certPool, RootCAs: certPool,
@@ -149,31 +173,95 @@ func TestServer_ReloadListener(t *testing.T) {
return nil return nil
} }
checkFinished() select {
time.Sleep(5 * time.Second) case <-cmd.startedCh:
checkFinished() case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
if err := testCertificateName("foo.example.com"); err != nil { if err := testCertificateName("foo.example.com"); err != nil {
t.Fatalf("certificate name didn't check out: %s", err) t.Fatalf("certificate name didn't check out: %s", err)
} }
relhcl = strings.Replace(reloadhcl, "TMPDIR", td, -1) relhcl = strings.Replace(reloadHCL, "TMPDIR", td, -1)
inBytes, _ = ioutil.ReadFile(wd + "reload_bar.pem") inBytes, _ = ioutil.ReadFile(wd + "reload_bar.pem")
ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0777) ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0777)
inBytes, _ = ioutil.ReadFile(wd + "reload_bar.key") inBytes, _ = ioutil.ReadFile(wd + "reload_bar.key")
ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777) ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777)
ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777) ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777)
c.SighupCh <- struct{}{} cmd.SighupCh <- struct{}{}
checkFinished() select {
time.Sleep(2 * time.Second) case <-cmd.reloadedCh:
checkFinished() case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
if err := testCertificateName("bar.example.com"); err != nil { if err := testCertificateName("bar.example.com"); err != nil {
t.Fatalf("certificate name didn't check out: %s", err) t.Fatalf("certificate name didn't check out: %s", err)
} }
c.ShutdownCh <- struct{}{} cmd.ShutdownCh <- struct{}{}
wg.Wait() wg.Wait()
} }
func TestServer(t *testing.T) {
t.Parallel()
cases := []struct {
name string
contents string
exp string
code int
}{
{
"common_ha",
testBaseHCL(t) + consulHCL,
"(HA available)",
0,
},
{
"separate_ha",
testBaseHCL(t) + consulHCL + haConsulHCL,
"HA Storage:",
0,
},
{
"bad_separate_ha",
testBaseHCL(t) + consulHCL + badHAConsulHCL,
"Specified HA storage does not support HA",
1,
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ui, cmd := testServerCommand(t)
f, err := ioutil.TempFile("", "")
if err != nil {
t.Fatalf("error creating temp dir: %v", err)
}
f.WriteString(tc.contents)
f.Close()
defer os.Remove(f.Name())
code := cmd.Run([]string{
"-config", f.Name(),
"-test-verify-only",
})
output := ui.ErrorWriter.String() + ui.OutputWriter.String()
if code != tc.code {
t.Errorf("expected %d to be %d: %s", code, tc.code, output)
}
if !strings.Contains(output, tc.exp) {
t.Fatalf("expected %q to contain %q", output, tc.exp)
}
})
}
}

View File

@@ -32,14 +32,14 @@ func TestStatusCommand_Run(t *testing.T) {
"unsealed", "unsealed",
nil, nil,
false, false,
"Sealed: false", "Sealed false",
0, 0,
}, },
{ {
"sealed", "sealed",
nil, nil,
true, true,
"Sealed: true", "Sealed true",
2, 2,
}, },
{ {