From 2f4ae7f2158cd415bbe120fa5340a7deb3995ebc Mon Sep 17 00:00:00 2001 From: Vishal Nayak Date: Tue, 15 Oct 2019 00:55:31 -0400 Subject: [PATCH] Recovery Mode (#7559) * Initial work * rework * s/dr/recovery * Add sys/raw support to recovery mode (#7577) * Factor the raw paths out so they can be run with a SystemBackend. # Conflicts: # vault/logical_system.go * Add handleLogicalRecovery which is like handleLogical but is only sufficient for use with the sys-raw endpoint in recovery mode. No authentication is done yet. * Integrate with recovery-mode. We now handle unauthenticated sys/raw requests, albeit on path v1/raw instead v1/sys/raw. * Use sys/raw instead raw during recovery. * Don't bother persisting the recovery token. Authenticate sys/raw requests with it. * RecoveryMode: Support generate-root for autounseals (#7591) * Recovery: Abstract config creation and log settings * Recovery mode integration test. (#7600) * Recovery: Touch up (#7607) * Recovery: Touch up * revert the raw backend creation changes * Added recovery operation token prefix * Move RawBackend to its own file * Update API path and hit it using CLI flag on generate-root * Fix a panic triggered when handling a request that yields a nil response. (#7618) * Improve integ test to actually make changes while in recovery mode and verify they're still there after coming back in regular mode. * Refuse to allow a second recovery token to be generated. * Resize raft cluster to size 1 and start as leader (#7626) * RecoveryMode: Setup raft cluster post unseal (#7635) * Setup raft cluster post unseal in recovery mode * Remove marking as unsealed as its not needed * Address review comments * Accept only one seal config in recovery mode as there is no scope for migration --- api/sys_generate_root.go | 16 + command/operator_generate_root.go | 112 +++- command/server.go | 494 +++++++++++++++--- helper/testhelpers/testhelpers.go | 44 +- http/handler.go | 96 ++-- http/logical.go | 60 ++- http/sys_generate_root.go | 4 +- http/sys_init.go | 2 +- http/sys_raft.go | 2 +- http/sys_rekey.go | 6 +- http/sys_seal.go | 6 +- physical/raft/raft.go | 33 +- vault/core.go | 41 +- .../external_tests/api/sys_rekey_ext_test.go | 6 +- .../external_tests/misc/kvv2_upgrade_test.go | 2 +- .../misc/recover_from_panic_test.go | 2 +- vault/external_tests/misc/recovery_test.go | 140 +++++ vault/generate_root.go | 125 ++++- vault/generate_root_recovery.go | 31 ++ vault/init.go | 25 + vault/logical_raw.go | 216 ++++++++ vault/logical_system.go | 178 +------ vault/request_handling.go | 3 + vault/testing.go | 33 +- 24 files changed, 1264 insertions(+), 413 deletions(-) create mode 100644 vault/external_tests/misc/recovery_test.go create mode 100644 vault/generate_root_recovery.go create mode 100644 vault/logical_raw.go diff --git a/api/sys_generate_root.go b/api/sys_generate_root.go index 66f72dff69..870dacb09e 100644 --- a/api/sys_generate_root.go +++ b/api/sys_generate_root.go @@ -10,6 +10,10 @@ func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, err return c.generateRootStatusCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt") } +func (c *Sys) GenerateRecoveryOperationTokenStatus() (*GenerateRootStatusResponse, error) { + return c.generateRootStatusCommon("/v1/sys/generate-recovery-token/attempt") +} + func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) { r := c.c.NewRequest("GET", path) @@ -34,6 +38,10 @@ func (c *Sys) GenerateDROperationTokenInit(otp, pgpKey string) (*GenerateRootSta return c.generateRootInitCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt", otp, pgpKey) } +func (c *Sys) GenerateRecoveryOperationTokenInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) { + return c.generateRootInitCommon("/v1/sys/generate-recovery-token/attempt", otp, pgpKey) +} + func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootStatusResponse, error) { body := map[string]interface{}{ "otp": otp, @@ -66,6 +74,10 @@ func (c *Sys) GenerateDROperationTokenCancel() error { return c.generateRootCancelCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt") } +func (c *Sys) GenerateRecoveryOperationTokenCancel() error { + return c.generateRootCancelCommon("/v1/sys/generate-recovery-token/attempt") +} + func (c *Sys) generateRootCancelCommon(path string) error { r := c.c.NewRequest("DELETE", path) @@ -86,6 +98,10 @@ func (c *Sys) GenerateDROperationTokenUpdate(shard, nonce string) (*GenerateRoot return c.generateRootUpdateCommon("/v1/sys/replication/dr/secondary/generate-operation-token/update", shard, nonce) } +func (c *Sys) GenerateRecoveryOperationTokenUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) { + return c.generateRootUpdateCommon("/v1/sys/generate-recovery-token/update", shard, nonce) +} + func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRootStatusResponse, error) { body := map[string]interface{}{ "key": shard, diff --git a/command/operator_generate_root.go b/command/operator_generate_root.go index fcecaaf1a2..78e6793a1e 100644 --- a/command/operator_generate_root.go +++ b/command/operator_generate_root.go @@ -23,18 +23,27 @@ import ( var _ cli.Command = (*OperatorGenerateRootCommand)(nil) var _ cli.CommandAutocomplete = (*OperatorGenerateRootCommand)(nil) +type generateRootKind int + +const ( + generateRootRegular generateRootKind = iota + generateRootDR + generateRootRecovery +) + type OperatorGenerateRootCommand struct { *BaseCommand - flagInit bool - flagCancel bool - flagStatus bool - flagDecode string - flagOTP string - flagPGPKey string - flagNonce string - flagGenerateOTP bool - flagDRToken bool + flagInit bool + flagCancel bool + flagStatus bool + flagDecode string + flagOTP string + flagPGPKey string + flagNonce string + flagGenerateOTP bool + flagDRToken bool + flagRecoveryToken bool testStdin io.Reader // for tests } @@ -143,6 +152,16 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets { "tokens.", }) + f.BoolVar(&BoolVar{ + Name: "recovery-token", + Target: &c.flagRecoveryToken, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Set this flag to do generate root operations on Recovery Operational " + + "tokens.", + }) + f.StringVar(&StringVar{ Name: "otp", Target: &c.flagOTP, @@ -200,43 +219,60 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int { return 1 } + if c.flagDRToken && c.flagRecoveryToken { + c.UI.Error("Both -recovery-token and -dr-token flags are set") + return 1 + } + client, err := c.Client() if err != nil { c.UI.Error(err.Error()) return 2 } + kind := generateRootRegular + switch { + case c.flagDRToken: + kind = generateRootDR + case c.flagRecoveryToken: + kind = generateRootRecovery + } + switch { case c.flagGenerateOTP: - otp, code := c.generateOTP(client, c.flagDRToken) + otp, code := c.generateOTP(client, kind) if code == 0 { return PrintRaw(c.UI, otp) } return code case c.flagDecode != "": - return c.decode(client, c.flagDecode, c.flagOTP, c.flagDRToken) + return c.decode(client, c.flagDecode, c.flagOTP, kind) case c.flagCancel: - return c.cancel(client, c.flagDRToken) + return c.cancel(client, kind) case c.flagInit: - return c.init(client, c.flagOTP, c.flagPGPKey, c.flagDRToken) + return c.init(client, c.flagOTP, c.flagPGPKey, kind) case c.flagStatus: - return c.status(client, c.flagDRToken) + return c.status(client, kind) default: // If there are no other flags, prompt for an unseal key. key := "" if len(args) > 0 { key = strings.TrimSpace(args[0]) } - return c.provide(client, key, c.flagDRToken) + return c.provide(client, key, kind) } } // generateOTP generates a suitable OTP code for generating a root token. -func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bool) (string, int) { +func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, kind generateRootKind) (string, int) { f := client.Sys().GenerateRootStatus - if drToken { + switch kind { + case generateRootDR: f = client.Sys().GenerateDROperationTokenStatus + case generateRootRecovery: + f = client.Sys().GenerateRecoveryOperationTokenStatus } + status, err := f() if err != nil { c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) @@ -272,7 +308,7 @@ func (c *OperatorGenerateRootCommand) generateOTP(client *api.Client, drToken bo } // decode decodes the given value using the otp. -func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, drToken bool) int { +func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp string, kind generateRootKind) int { if encoded == "" { c.UI.Error("Missing encoded value: use -decode= to supply it") return 1 @@ -283,9 +319,13 @@ func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp st } f := client.Sys().GenerateRootStatus - if drToken { + switch kind { + case generateRootDR: f = client.Sys().GenerateDROperationTokenStatus + case generateRootRecovery: + f = client.Sys().GenerateRecoveryOperationTokenStatus } + status, err := f() if err != nil { c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) @@ -327,7 +367,7 @@ func (c *OperatorGenerateRootCommand) decode(client *api.Client, encoded, otp st } // init is used to start the generation process -func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, drToken bool) int { +func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey string, kind generateRootKind) int { // Validate incoming fields. Either OTP OR PGP keys must be supplied. if otp != "" && pgpKey != "" { c.UI.Error("Error initializing: cannot specify both -otp and -pgp-key") @@ -336,8 +376,11 @@ func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey strin // Start the root generation f := client.Sys().GenerateRootInit - if drToken { + switch kind { + case generateRootDR: f = client.Sys().GenerateDROperationTokenInit + case generateRootRecovery: + f = client.Sys().GenerateRecoveryOperationTokenInit } status, err := f(otp, pgpKey) if err != nil { @@ -355,10 +398,13 @@ func (c *OperatorGenerateRootCommand) init(client *api.Client, otp, pgpKey strin // provide prompts the user for the seal key and posts it to the update root // endpoint. If this is the last unseal, this function outputs it. -func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, drToken bool) int { +func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, kind generateRootKind) int { f := client.Sys().GenerateRootStatus - if drToken { + switch kind { + case generateRootDR: f = client.Sys().GenerateDROperationTokenStatus + case generateRootRecovery: + f = client.Sys().GenerateRecoveryOperationTokenStatus } status, err := f() if err != nil { @@ -437,8 +483,11 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr // Provide the key, this may potentially complete the update fUpd := client.Sys().GenerateRootUpdate - if drToken { + switch kind { + case generateRootDR: fUpd = client.Sys().GenerateDROperationTokenUpdate + case generateRootRecovery: + fUpd = client.Sys().GenerateRecoveryOperationTokenUpdate } status, err = fUpd(key, nonce) if err != nil { @@ -454,10 +503,13 @@ func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string, dr } // cancel cancels the root token generation -func (c *OperatorGenerateRootCommand) cancel(client *api.Client, drToken bool) int { +func (c *OperatorGenerateRootCommand) cancel(client *api.Client, kind generateRootKind) int { f := client.Sys().GenerateRootCancel - if drToken { + switch kind { + case generateRootDR: f = client.Sys().GenerateDROperationTokenCancel + case generateRootRecovery: + f = client.Sys().GenerateRecoveryOperationTokenCancel } if err := f(); err != nil { c.UI.Error(fmt.Sprintf("Error canceling root token generation: %s", err)) @@ -468,11 +520,15 @@ func (c *OperatorGenerateRootCommand) cancel(client *api.Client, drToken bool) i } // status is used just to fetch and dump the status -func (c *OperatorGenerateRootCommand) status(client *api.Client, drToken bool) int { +func (c *OperatorGenerateRootCommand) status(client *api.Client, kind generateRootKind) int { f := client.Sys().GenerateRootStatus - if drToken { + switch kind { + case generateRootDR: f = client.Sys().GenerateDROperationTokenStatus + case generateRootRecovery: + f = client.Sys().GenerateRecoveryOperationTokenStatus } + status, err := f() if err != nil { c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) diff --git a/command/server.go b/command/server.go index e113cbc39f..bd0ee9978e 100644 --- a/command/server.go +++ b/command/server.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" "net" @@ -99,6 +100,7 @@ type ServerCommand struct { flagConfigs []string flagLogLevel string flagLogFormat string + flagRecovery bool flagDev bool flagDevRootTokenID string flagDevListenAddr string @@ -197,6 +199,13 @@ func (c *ServerCommand) Flags() *FlagSets { Usage: `Log format. Supported values are "standard" and "json".`, }) + f.BoolVar(&BoolVar{ + Name: "recovery", + Target: &c.flagRecovery, + Usage: "Enable recovery mode. In this mode, Vault is used to perform recovery actions." + + "Using a recovery operation token, \"sys/raw\" API can be used to manipulate the storage.", + }) + f = set.NewFlagSet("Dev Options") f.BoolVar(&BoolVar{ @@ -365,6 +374,384 @@ func (c *ServerCommand) AutocompleteFlags() complete.Flags { return c.Flags().Completions() } +func (c *ServerCommand) parseConfig() (*server.Config, error) { + // Load the configuration + var config *server.Config + for _, path := range c.flagConfigs { + current, err := server.LoadConfig(path) + if err != nil { + return nil, errwrap.Wrapf(fmt.Sprintf("error loading configuration from %s: {{err}}", path), err) + } + + if config == nil { + config = current + } else { + config = config.Merge(current) + } + } + return config, nil +} + +func (c *ServerCommand) runRecoveryMode() int { + config, err := c.parseConfig() + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Ensure at least one config was found. + if config == nil { + c.UI.Output(wrapAtLength( + "No configuration files found. Please provide configurations with the " + + "-config flag. If you are supplying the path to a directory, please " + + "ensure the directory contains files with the .hcl or .json " + + "extension.")) + return 1 + } + + level, logLevelString, logLevelWasNotSet, logFormat, err := c.processLogLevelAndFormat(config) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + + c.logger = log.New(&log.LoggerOptions{ + Output: c.logWriter, + Level: level, + // Note that if logFormat is either unspecified or standard, then + // the resulting logger's format will be standard. + JSONFormat: logFormat == logging.JSONFormat, + }) + + logLevelStr, err := c.adjustLogLevel(config, logLevelWasNotSet) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + if logLevelStr != "" { + logLevelString = logLevelStr + } + + // create GRPC logger + namedGRPCLogFaker := c.logger.Named("grpclogfaker") + grpclog.SetLogger(&grpclogFaker{ + logger: namedGRPCLogFaker, + log: os.Getenv("VAULT_GRPC_LOGGING") != "", + }) + + if config.Storage == nil { + c.UI.Output("A storage backend must be specified") + return 1 + } + + if config.DefaultMaxRequestDuration != 0 { + vault.DefaultMaxRequestDuration = config.DefaultMaxRequestDuration + } + + proxyCfg := httpproxy.FromEnvironment() + c.logger.Info("proxy environment", "http_proxy", proxyCfg.HTTPProxy, + "https_proxy", proxyCfg.HTTPSProxy, "no_proxy", proxyCfg.NoProxy) + + // Initialize the storage backend + factory, exists := c.PhysicalBackends[config.Storage.Type] + if !exists { + c.UI.Error(fmt.Sprintf("Unknown storage type %s", config.Storage.Type)) + return 1 + } + if config.Storage.Type == "raft" { + if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" { + config.ClusterAddr = envCA + } + + if len(config.ClusterAddr) == 0 { + c.UI.Error("Cluster address must be set when using raft storage") + return 1 + } + } + + namedStorageLogger := c.logger.Named("storage." + config.Storage.Type) + backend, err := factory(config.Storage.Config, namedStorageLogger) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing storage of type %s: %s", config.Storage.Type, err)) + return 1 + } + + infoKeys := make([]string, 0, 10) + info := make(map[string]string) + info["log level"] = logLevelString + infoKeys = append(infoKeys, "log level") + + var barrierSeal vault.Seal + var sealConfigError error + + if len(config.Seals) == 0 { + config.Seals = append(config.Seals, &server.Seal{Type: vaultseal.Shamir}) + } + + if len(config.Seals) > 1 { + c.UI.Error("Only one seal block is accepted in recovery mode") + return 1 + } + + configSeal := config.Seals[0] + sealType := vaultseal.Shamir + if !configSeal.Disabled && os.Getenv("VAULT_SEAL_TYPE") != "" { + sealType = os.Getenv("VAULT_SEAL_TYPE") + configSeal.Type = sealType + } else { + sealType = configSeal.Type + } + + var seal vault.Seal + sealLogger := c.logger.Named(sealType) + seal, sealConfigError = serverseal.ConfigureSeal(configSeal, &infoKeys, &info, sealLogger, vault.NewDefaultSeal(shamirseal.NewSeal(c.logger.Named("shamir")))) + if sealConfigError != nil { + if !errwrap.ContainsType(sealConfigError, new(logical.KeyNotFoundError)) { + c.UI.Error(fmt.Sprintf( + "Error parsing Seal configuration: %s", sealConfigError)) + return 1 + } + } + if seal == nil { + c.UI.Error(fmt.Sprintf( + "After configuring seal nil returned, seal type was %s", sealType)) + return 1 + } + + barrierSeal = seal + + // Ensure that the seal finalizer is called, even if using verify-only + defer func() { + err = seal.Finalize(context.Background()) + if err != nil { + c.UI.Error(fmt.Sprintf("Error finalizing seals: %v", err)) + } + }() + + coreConfig := &vault.CoreConfig{ + Physical: backend, + StorageType: config.Storage.Type, + Seal: barrierSeal, + Logger: c.logger, + DisableMlock: config.DisableMlock, + RecoveryMode: c.flagRecovery, + ClusterAddr: config.ClusterAddr, + } + + core, newCoreError := vault.NewCore(coreConfig) + if newCoreError != nil { + if vault.IsFatalError(newCoreError) { + c.UI.Error(fmt.Sprintf("Error initializing core: %s", newCoreError)) + return 1 + } + } + + if err := core.InitializeRecovery(context.Background()); err != nil { + c.UI.Error(fmt.Sprintf("Error initializing core in recovery mode: %s", err)) + return 1 + } + + // Compile server information for output later + infoKeys = append(infoKeys, "storage") + info["storage"] = config.Storage.Type + + if coreConfig.ClusterAddr != "" { + info["cluster address"] = coreConfig.ClusterAddr + infoKeys = append(infoKeys, "cluster address") + } + + // Initialize the listeners + lns := make([]ServerListener, 0, len(config.Listeners)) + for _, lnConfig := range config.Listeners { + ln, _, _, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logWriter, c.UI) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err)) + return 1 + } + + lns = append(lns, ServerListener{ + Listener: ln, + config: lnConfig.Config, + }) + } + + listenerCloseFunc := func() { + for _, ln := range lns { + ln.Listener.Close() + } + } + + defer c.cleanupGuard.Do(listenerCloseFunc) + + infoKeys = append(infoKeys, "version") + verInfo := version.GetVersion() + info["version"] = verInfo.FullVersionNumber(false) + if verInfo.Revision != "" { + info["version sha"] = strings.Trim(verInfo.Revision, "'") + infoKeys = append(infoKeys, "version sha") + } + + infoKeys = append(infoKeys, "recovery mode") + info["recovery mode"] = "true" + + // Server configuration output + padding := 24 + sort.Strings(infoKeys) + c.UI.Output("==> Vault server configuration:\n") + for _, k := range infoKeys { + c.UI.Output(fmt.Sprintf( + "%s%s: %s", + strings.Repeat(" ", padding-len(k)), + strings.Title(k), + info[k])) + } + c.UI.Output("") + + for _, ln := range lns { + handler := vaulthttp.Handler(&vault.HandlerProperties{ + Core: core, + MaxRequestSize: ln.maxRequestSize, + MaxRequestDuration: ln.maxRequestDuration, + DisablePrintableCheck: config.DisablePrintableCheck, + RecoveryMode: c.flagRecovery, + RecoveryToken: atomic.NewString(""), + }) + + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: c.logger.StandardLogger(nil), + } + + go server.Serve(ln.Listener) + } + + if sealConfigError != nil { + init, err := core.Initialized(context.Background()) + if err != nil { + c.UI.Error(fmt.Sprintf("Error checking if core is initialized: %v", err)) + return 1 + } + if init { + c.UI.Error("Vault is initialized but no Seal key could be loaded") + return 1 + } + } + + if newCoreError != nil { + c.UI.Warn(wrapAtLength( + "WARNING! A non-fatal error occurred during initialization. Please " + + "check the logs for more information.")) + c.UI.Warn("") + } + + if !c.flagCombineLogs { + c.UI.Output("==> Vault server started! Log data will stream in below:\n") + } + + c.logGate.Flush() + + for { + select { + case <-c.ShutdownCh: + c.UI.Output("==> Vault shutdown triggered") + + c.cleanupGuard.Do(listenerCloseFunc) + + if err := core.Shutdown(); err != nil { + c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err)) + } + + return 0 + + case <-c.SigUSR2Ch: + buf := make([]byte, 32*1024*1024) + n := runtime.Stack(buf[:], true) + c.logger.Info("goroutine trace", "stack", string(buf[:n])) + } + } + + return 0 +} + +func (c *ServerCommand) adjustLogLevel(config *server.Config, logLevelWasNotSet bool) (string, error) { + var logLevelString string + if config.LogLevel != "" && logLevelWasNotSet { + configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) + logLevelString = configLogLevel + switch configLogLevel { + case "trace": + c.logger.SetLevel(log.Trace) + case "debug": + c.logger.SetLevel(log.Debug) + case "notice", "info", "": + c.logger.SetLevel(log.Info) + case "warn", "warning": + c.logger.SetLevel(log.Warn) + case "err", "error": + c.logger.SetLevel(log.Error) + default: + return "", fmt.Errorf("unknown log level: %s", config.LogLevel) + } + } + return logLevelString, nil +} + +func (c *ServerCommand) processLogLevelAndFormat(config *server.Config) (log.Level, string, bool, logging.LogFormat, error) { + // Create a logger. We wrap it in a gated writer so that it doesn't + // start logging too early. + c.logGate = &gatedwriter.Writer{Writer: os.Stderr} + c.logWriter = c.logGate + if c.flagCombineLogs { + c.logWriter = os.Stdout + } + var level log.Level + var logLevelWasNotSet bool + logFormat := logging.UnspecifiedFormat + logLevelString := c.flagLogLevel + c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel)) + switch c.flagLogLevel { + case notSetValue, "": + logLevelWasNotSet = true + logLevelString = "info" + level = log.Info + case "trace": + level = log.Trace + case "debug": + level = log.Debug + case "notice", "info": + level = log.Info + case "warn", "warning": + level = log.Warn + case "err", "error": + level = log.Error + default: + return level, logLevelString, logLevelWasNotSet, logFormat, fmt.Errorf("unknown log level: %s", c.flagLogLevel) + } + + if c.flagLogFormat != notSetValue { + var err error + logFormat, err = logging.ParseLogFormat(c.flagLogFormat) + if err != nil { + return level, logLevelString, logLevelWasNotSet, logFormat, err + } + } + if logFormat == logging.UnspecifiedFormat { + logFormat = logging.ParseEnvLogFormat() + } + if logFormat == logging.UnspecifiedFormat { + var err error + logFormat, err = logging.ParseLogFormat(config.LogFormat) + if err != nil { + return level, logLevelString, logLevelWasNotSet, logFormat, err + } + } + + return level, logLevelString, logLevelWasNotSet, logFormat, nil +} + func (c *ServerCommand) Run(args []string) int { f := c.Flags() @@ -373,6 +760,10 @@ func (c *ServerCommand) Run(args []string) int { return 1 } + if c.flagRecovery { + return c.runRecoveryMode() + } + // Automatically enable dev mode if other dev flags are provided. if c.flagDevHA || c.flagDevTransactional || c.flagDevLeasedKV || c.flagDevThreeNode || c.flagDevFourCluster || c.flagDevAutoSeal || c.flagDevKVV1 { c.flagDev = true @@ -413,18 +804,16 @@ func (c *ServerCommand) Run(args []string) int { config.Listeners[0].Config["address"] = c.flagDevListenAddr } } - for _, path := range c.flagConfigs { - current, err := server.LoadConfig(path) - if err != nil { - c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", path, err)) - return 1 - } - if config == nil { - config = current - } else { - config = config.Merge(current) - } + parsedConfig, err := c.parseConfig() + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + if config == nil { + config = parsedConfig + } else { + config = config.Merge(parsedConfig) } // Ensure at least one config was found. @@ -437,58 +826,12 @@ func (c *ServerCommand) Run(args []string) int { return 1 } - // Create a logger. We wrap it in a gated writer so that it doesn't - // start logging too early. - c.logGate = &gatedwriter.Writer{Writer: os.Stderr} - c.logWriter = c.logGate - if c.flagCombineLogs { - c.logWriter = os.Stdout - } - var level log.Level - var logLevelWasNotSet bool - logLevelString := c.flagLogLevel - c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel)) - switch c.flagLogLevel { - case notSetValue, "": - logLevelWasNotSet = true - logLevelString = "info" - level = log.Info - case "trace": - level = log.Trace - case "debug": - level = log.Debug - case "notice", "info": - level = log.Info - case "warn", "warning": - level = log.Warn - case "err", "error": - level = log.Error - default: - c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel)) + level, logLevelString, logLevelWasNotSet, logFormat, err := c.processLogLevelAndFormat(config) + if err != nil { + c.UI.Error(err.Error()) return 1 } - logFormat := logging.UnspecifiedFormat - if c.flagLogFormat != notSetValue { - var err error - logFormat, err = logging.ParseLogFormat(c.flagLogFormat) - if err != nil { - c.UI.Error(err.Error()) - return 1 - } - } - if logFormat == logging.UnspecifiedFormat { - logFormat = logging.ParseEnvLogFormat() - } - if logFormat == logging.UnspecifiedFormat { - var err error - logFormat, err = logging.ParseLogFormat(config.LogFormat) - if err != nil { - c.UI.Error(err.Error()) - return 1 - } - } - if c.flagDevThreeNode || c.flagDevFourCluster { c.logger = log.New(&log.LoggerOptions{ Mutex: &sync.Mutex{}, @@ -507,25 +850,13 @@ func (c *ServerCommand) Run(args []string) int { allLoggers := []log.Logger{c.logger} - // adjust log level based on config setting - if config.LogLevel != "" && logLevelWasNotSet { - configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) - logLevelString = configLogLevel - switch configLogLevel { - case "trace": - c.logger.SetLevel(log.Trace) - case "debug": - c.logger.SetLevel(log.Debug) - case "notice", "info", "": - c.logger.SetLevel(log.Info) - case "warn", "warning": - c.logger.SetLevel(log.Warn) - case "err", "error": - c.logger.SetLevel(log.Error) - default: - c.UI.Error(fmt.Sprintf("Unknown log level: %s", config.LogLevel)) - return 1 - } + logLevelStr, err := c.adjustLogLevel(config, logLevelWasNotSet) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + if logLevelStr != "" { + logLevelString = logLevelStr } // create GRPC logger @@ -580,7 +911,6 @@ func (c *ServerCommand) Run(args []string) int { return 1 } - if config.Storage.Type == "raft" { if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" { config.ClusterAddr = envCA @@ -1066,6 +1396,9 @@ CLUSTER_SYNTHESIS_COMPLETE: info["cgo"] = "enabled" } + infoKeys = append(infoKeys, "recovery mode") + info["recovery mode"] = "false" + // Server configuration output padding := 24 sort.Strings(infoKeys) @@ -1263,6 +1596,7 @@ CLUSTER_SYNTHESIS_COMPLETE: MaxRequestDuration: ln.maxRequestDuration, DisablePrintableCheck: config.DisablePrintableCheck, UnauthenticatedMetricsAccess: ln.unauthenticatedMetricsAccess, + RecoveryMode: c.flagRecovery, }) // We perform validation on the config earlier, we can just cast here diff --git a/helper/testhelpers/testhelpers.go b/helper/testhelpers/testhelpers.go index 79bb2f1292..945671d13a 100644 --- a/helper/testhelpers/testhelpers.go +++ b/helper/testhelpers/testhelpers.go @@ -19,16 +19,26 @@ import ( "github.com/mitchellh/go-testing-interface" ) +type GenerateRootKind int + +const ( + GenerateRootRegular GenerateRootKind = iota + GenerateRootDR + GenerateRecovery +) + // Generates a root token on the target cluster. -func GenerateRoot(t testing.T, cluster *vault.TestCluster, drToken bool) string { - token, err := GenerateRootWithError(t, cluster, drToken) +func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string { + t.Helper() + token, err := GenerateRootWithError(t, cluster, kind) if err != nil { t.Fatal(err) } return token } -func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool) (string, error) { +func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) (string, error) { + t.Helper() // If recovery keys supported, use those to perform root token generation instead var keys [][]byte if cluster.Cores[0].SealAccess().RecoveryKeySupported() { @@ -36,13 +46,18 @@ func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool } else { keys = cluster.BarrierKeys } - client := cluster.Cores[0].Client - f := client.Sys().GenerateRootInit - if drToken { - f = client.Sys().GenerateDROperationTokenInit + + var err error + var status *api.GenerateRootStatusResponse + switch kind { + case GenerateRootRegular: + status, err = client.Sys().GenerateRootInit("", "") + case GenerateRootDR: + status, err = client.Sys().GenerateDROperationTokenInit("", "") + case GenerateRecovery: + status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "") } - status, err := f("", "") if err != nil { return "", err } @@ -57,11 +72,16 @@ func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, drToken bool if i >= status.Required { break } - f := client.Sys().GenerateRootUpdate - if drToken { - f = client.Sys().GenerateDROperationTokenUpdate + + strKey := base64.StdEncoding.EncodeToString(key) + switch kind { + case GenerateRootRegular: + status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce) + case GenerateRootDR: + status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce) + case GenerateRecovery: + status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce) } - status, err = f(base64.StdEncoding.EncodeToString(key), status.Nonce) if err != nil { return "", err } diff --git a/http/handler.go b/http/handler.go index 8ccc436b1f..079a4950e2 100644 --- a/http/handler.go +++ b/http/handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/NYTimes/gziphandler" "io" "io/ioutil" "net" @@ -16,11 +17,10 @@ import ( "strings" "time" - "github.com/NYTimes/gziphandler" assetfs "github.com/elazarl/go-bindata-assetfs" "github.com/hashicorp/errwrap" - cleanhttp "github.com/hashicorp/go-cleanhttp" - sockaddr "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" @@ -111,49 +111,57 @@ func Handler(props *vault.HandlerProperties) http.Handler { // Create the muxer to handle the actual endpoints mux := http.NewServeMux() - // Handle non-forwarded paths - mux.Handle("/v1/sys/config/state/", handleLogicalNoForward(core)) - mux.Handle("/v1/sys/host-info", handleLogicalNoForward(core)) - mux.Handle("/v1/sys/pprof/", handleLogicalNoForward(core)) + switch { + case props.RecoveryMode: + raw := vault.NewRawBackend(core) + strategy := vault.GenerateRecoveryTokenStrategy(props.RecoveryToken) + mux.Handle("/v1/sys/raw/", handleLogicalRecovery(raw, props.RecoveryToken)) + mux.Handle("/v1/sys/generate-recovery-token/attempt", handleSysGenerateRootAttempt(core, strategy)) + mux.Handle("/v1/sys/generate-recovery-token/update", handleSysGenerateRootUpdate(core, strategy)) + default: + // Handle pprof paths + mux.Handle("/v1/sys/pprof/", handleLogicalNoForward(core)) - mux.Handle("/v1/sys/init", handleSysInit(core)) - mux.Handle("/v1/sys/seal-status", handleSysSealStatus(core)) - mux.Handle("/v1/sys/seal", handleSysSeal(core)) - mux.Handle("/v1/sys/step-down", handleRequestForwarding(core, handleSysStepDown(core))) - mux.Handle("/v1/sys/unseal", handleSysUnseal(core)) - mux.Handle("/v1/sys/leader", handleSysLeader(core)) - mux.Handle("/v1/sys/health", handleSysHealth(core)) - mux.Handle("/v1/sys/generate-root/attempt", handleRequestForwarding(core, handleSysGenerateRootAttempt(core, vault.GenerateStandardRootTokenStrategy))) - mux.Handle("/v1/sys/generate-root/update", handleRequestForwarding(core, handleSysGenerateRootUpdate(core, vault.GenerateStandardRootTokenStrategy))) - mux.Handle("/v1/sys/rekey/init", handleRequestForwarding(core, handleSysRekeyInit(core, false))) - mux.Handle("/v1/sys/rekey/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, false))) - mux.Handle("/v1/sys/rekey/verify", handleRequestForwarding(core, handleSysRekeyVerify(core, false))) - mux.Handle("/v1/sys/rekey-recovery-key/init", handleRequestForwarding(core, handleSysRekeyInit(core, true))) - mux.Handle("/v1/sys/rekey-recovery-key/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, true))) - mux.Handle("/v1/sys/rekey-recovery-key/verify", handleRequestForwarding(core, handleSysRekeyVerify(core, true))) - mux.Handle("/v1/sys/storage/raft/join", handleSysRaftJoin(core)) - for _, path := range injectDataIntoTopRoutes { - mux.Handle(path, handleRequestForwarding(core, handleLogicalWithInjector(core))) - } - mux.Handle("/v1/sys/", handleRequestForwarding(core, handleLogical(core))) - mux.Handle("/v1/", handleRequestForwarding(core, handleLogical(core))) - if core.UIEnabled() == true { - if uiBuiltIn { - mux.Handle("/ui/", http.StripPrefix("/ui/", gziphandler.GzipHandler(handleUIHeaders(core, handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()})))))) - mux.Handle("/robots.txt", gziphandler.GzipHandler(handleUIHeaders(core, handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()}))))) - } else { - mux.Handle("/ui/", handleUIHeaders(core, handleUIStub())) + mux.Handle("/v1/sys/init", handleSysInit(core)) + mux.Handle("/v1/sys/seal-status", handleSysSealStatus(core)) + mux.Handle("/v1/sys/seal", handleSysSeal(core)) + mux.Handle("/v1/sys/step-down", handleRequestForwarding(core, handleSysStepDown(core))) + mux.Handle("/v1/sys/unseal", handleSysUnseal(core)) + mux.Handle("/v1/sys/leader", handleSysLeader(core)) + mux.Handle("/v1/sys/health", handleSysHealth(core)) + mux.Handle("/v1/sys/generate-root/attempt", handleRequestForwarding(core, handleSysGenerateRootAttempt(core, vault.GenerateStandardRootTokenStrategy))) + mux.Handle("/v1/sys/generate-root/update", handleRequestForwarding(core, handleSysGenerateRootUpdate(core, vault.GenerateStandardRootTokenStrategy))) + mux.Handle("/v1/sys/rekey/init", handleRequestForwarding(core, handleSysRekeyInit(core, false))) + mux.Handle("/v1/sys/rekey/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, false))) + mux.Handle("/v1/sys/rekey/verify", handleRequestForwarding(core, handleSysRekeyVerify(core, false))) + mux.Handle("/v1/sys/rekey-recovery-key/init", handleRequestForwarding(core, handleSysRekeyInit(core, true))) + mux.Handle("/v1/sys/rekey-recovery-key/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, true))) + mux.Handle("/v1/sys/rekey-recovery-key/verify", handleRequestForwarding(core, handleSysRekeyVerify(core, true))) + mux.Handle("/v1/sys/storage/raft/join", handleSysRaftJoin(core)) + for _, path := range injectDataIntoTopRoutes { + mux.Handle(path, handleRequestForwarding(core, handleLogicalWithInjector(core))) } - mux.Handle("/ui", handleUIRedirect()) - mux.Handle("/", handleUIRedirect()) - } + mux.Handle("/v1/sys/", handleRequestForwarding(core, handleLogical(core))) + mux.Handle("/v1/", handleRequestForwarding(core, handleLogical(core))) + if core.UIEnabled() == true { + if uiBuiltIn { + mux.Handle("/ui/", http.StripPrefix("/ui/", gziphandler.GzipHandler(handleUIHeaders(core, handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()})))))) + mux.Handle("/robots.txt", gziphandler.GzipHandler(handleUIHeaders(core, handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()}))))) + } else { + mux.Handle("/ui/", handleUIHeaders(core, handleUIStub())) + } + mux.Handle("/ui", handleUIRedirect()) + mux.Handle("/", handleUIRedirect()) - // Register metrics path without authentication if enabled - if props.UnauthenticatedMetricsAccess { - mux.Handle("/v1/sys/metrics", handleMetricsUnauthenticated(core)) - } + } - additionalRoutes(mux, core) + // Register metrics path without authentication if enabled + if props.UnauthenticatedMetricsAccess { + mux.Handle("/v1/sys/metrics", handleMetricsUnauthenticated(core)) + } + + additionalRoutes(mux, core) + } // Wrap the handler in another handler to trigger all help paths. helpWrappedHandler := wrapHelpHandler(mux, core) @@ -489,7 +497,7 @@ func parseQuery(values url.Values) map[string]interface{} { return nil } -func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) { +func parseRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) { // Limit the maximum number of bytes to MaxRequestSize to protect // against an indefinite amount of data being read. reader := r.Body @@ -505,7 +513,7 @@ func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out } } var origBody io.ReadWriter - if core.PerfStandby() { + if perfStandby { // Since we're checking PerfStandby here we key on origBody being nil // or not later, so we need to always allocate so it's non-nil origBody = new(bytes.Buffer) diff --git a/http/logical.go b/http/logical.go index 8b04dbb59f..80b4f66659 100644 --- a/http/logical.go +++ b/http/logical.go @@ -4,6 +4,8 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/hashicorp/vault/sdk/helper/consts" + "go.uber.org/atomic" "io" "net" "net/http" @@ -18,7 +20,7 @@ import ( "github.com/hashicorp/vault/vault" ) -func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { +func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { ns, err := namespace.FromContext(r.Context()) if err != nil { return nil, nil, http.StatusBadRequest, nil @@ -78,7 +80,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques passHTTPReq = true origBody = r.Body } else { - origBody, err = parseRequest(core, r, w, &data) + origBody, err = parseRequest(perfStandby, r, w, &data) if err == io.EOF { data = nil err = nil @@ -105,14 +107,32 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err) } - req, err := requestAuth(core, r, &logical.Request{ + req := &logical.Request{ ID: request_id, Operation: op, Path: path, Data: data, Connection: getConnection(r), Headers: r.Header, - }) + } + + if passHTTPReq { + req.HTTPRequest = r + } + if responseWriter != nil { + req.ResponseWriter = logical.NewHTTPResponseWriter(responseWriter) + } + + return req, origBody, 0, nil +} + +func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { + req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) + if err != nil { + return nil, nil, status, err + } + + req, err = requestAuth(core, r, req) if err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { return nil, nil, http.StatusForbidden, nil @@ -135,12 +155,6 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err) } - if passHTTPReq { - req.HTTPRequest = r - } - if responseWriter != nil { - req.ResponseWriter = logical.NewHTTPResponseWriter(responseWriter) - } return req, origBody, 0, nil } @@ -168,6 +182,32 @@ func handleLogicalNoForward(core *vault.Core) http.Handler { return handleLogicalInternal(core, false, true) } +func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req, _, statusCode, err := buildLogicalRequestNoAuth(false, w, r) + if err != nil || statusCode != 0 { + respondError(w, statusCode, err) + return + } + reqToken := r.Header.Get(consts.AuthHeaderName) + if reqToken == "" || token.Load() == "" || reqToken != token.Load() { + respondError(w, http.StatusForbidden, nil) + } + + resp, err := raw.HandleRequest(r.Context(), req) + if respondErrorCommon(w, req, resp, err) { + return + } + + var httpResp *logical.HTTPResponse + if resp != nil { + httpResp = logical.LogicalResponseToHTTPResponse(resp) + httpResp.RequestID = req.ID + } + respondOk(w, httpResp) + }) +} + // handleLogicalInternal is a common helper that returns a handler for // processing logical requests. The behavior depends on the various boolean // toggles. Refer to usage on functions for possible behaviors. diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 56b8694a29..dae751a469 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { // Parse the request var req GenerateRootInitRequest - if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { respondError(w, http.StatusBadRequest, err) return } @@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Parse the request var req GenerateRootUpdateRequest - if _, err := parseRequest(core, r, w, &req); err != nil { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_init.go b/http/sys_init.go index 6615ea79bd..0e43d4248a 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -40,7 +40,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) // Parse the request var req InitRequest - if _, err := parseRequest(core, r, w, &req); err != nil { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_raft.go b/http/sys_raft.go index 78b411f934..75d6ccf96b 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -25,7 +25,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler { func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) { // Parse the request var req JoinRequest - if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_rekey.go b/http/sys_rekey.go index 58c2edf1bf..eb8760f927 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { // Parse the request var req RekeyRequest - if _, err := parseRequest(core, r, w, &req); err != nil { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } @@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Parse the request var req RekeyUpdateRequest - if _, err := parseRequest(core, r, w, &req); err != nil { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } @@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { // Parse the request var req RekeyVerificationUpdateRequest - if _, err := parseRequest(core, r, w, &req); err != nil { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } diff --git a/http/sys_seal.go b/http/sys_seal.go index 384d6c5ab7..1cf520c098 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { // Parse the request var req UnsealRequest - if _, err := parseRequest(core, r, w, &req); err != nil { + if _, err := parseRequest(core.PerfStandby(), r, w, &req); err != nil { respondError(w, http.StatusBadRequest, err) return } @@ -198,7 +198,7 @@ func handleSysSealStatusRaw(core *vault.Core, w http.ResponseWriter, r *http.Req Initialized: false, Sealed: true, RecoverySeal: core.SealAccess().RecoveryKeySupported(), - StorageType: core.StorageType(), + StorageType: core.StorageType(), }) return } @@ -234,7 +234,7 @@ func handleSysSealStatusRaw(core *vault.Core, w http.ResponseWriter, r *http.Req ClusterName: clusterName, ClusterID: clusterID, RecoverySeal: core.SealAccess().RecoveryKeySupported(), - StorageType: core.StorageType(), + StorageType: core.StorageType(), }) } diff --git a/physical/raft/raft.go b/physical/raft/raft.go index 00e8789d53..015dcb480d 100644 --- a/physical/raft/raft.go +++ b/physical/raft/raft.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/armon/go-metrics" "io" "io/ioutil" "os" @@ -12,12 +13,11 @@ import ( "sync" "time" - metrics "github.com/armon/go-metrics" - proto "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/proto" "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-raftchunking" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/raft" snapshot "github.com/hashicorp/raft-snapshot" raftboltdb "github.com/hashicorp/vault/physical/raft/logstore" @@ -367,6 +367,26 @@ type SetupOpts struct { // StartAsLeader is used to specify this node should start as leader and // bypass the leader election. This should be used with caution. StartAsLeader bool + + // RecoveryModeConfig is the configuration for the raft cluster in recovery + // mode. + RecoveryModeConfig *raft.Configuration +} + +func (b *RaftBackend) StartRecoveryCluster(ctx context.Context, peer Peer) error { + recoveryModeConfig := &raft.Configuration{ + Servers: []raft.Server{ + { + ID: raft.ServerID(peer.ID), + Address: raft.ServerAddress(peer.Address), + }, + }, + } + + return b.SetupCluster(context.Background(), SetupOpts{ + StartAsLeader: true, + RecoveryModeConfig: recoveryModeConfig, + }) } // SetupCluster starts the raft cluster and enables the networking needed for @@ -477,6 +497,13 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error { b.logger.Info("raft recovery deleted peers.json") } + if opts.RecoveryModeConfig != nil { + err = raft.RecoverCluster(raftConfig, b.fsm, b.logStore, b.stableStore, b.snapStore, b.raftTransport, *opts.RecoveryModeConfig) + if err != nil { + return errwrap.Wrapf("recovering raft cluster failed: {{err}}", err) + } + } + raftObj, err := raft.NewRaft(raftConfig, b.fsm.chunker, b.logStore, b.stableStore, b.snapStore, b.raftTransport) b.fsm.SetNoopRestore(false) if err != nil { diff --git a/vault/core.go b/vault/core.go index 3d2c54481c..9d2f37dc60 100644 --- a/vault/core.go +++ b/vault/core.go @@ -423,6 +423,10 @@ type Core struct { // Stores any funcs that should be run on successful postUnseal postUnsealFuncs []func() + // Stores any funcs that should be run on successful barrier unseal in + // recovery mode + postRecoveryUnsealFuncs []func() error + // replicationFailure is used to mark when replication has entered an // unrecoverable failure. replicationFailure *uint32 @@ -465,6 +469,8 @@ type Core struct { rawConfig *server.Config coreNumber int + + recoveryMode bool } // CoreConfig is used to parameterize a core @@ -542,6 +548,8 @@ type CoreConfig struct { MetricsHelper *metricsutil.MetricsHelper CounterSyncInterval time.Duration + + RecoveryMode bool } func (c *CoreConfig) Clone() *CoreConfig { @@ -668,6 +676,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { requests: new(uint64), syncInterval: syncInterval, }, + recoveryMode: conf.RecoveryMode, } atomic.StoreUint32(c.sealed, 1) @@ -726,25 +735,12 @@ func NewCore(conf *CoreConfig) (*Core, error) { var err error - if conf.PluginDirectory != "" { - c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) - if err != nil { - return nil, errwrap.Wrapf("core setup failed, could not verify plugin directory: {{err}}", err) - } - } - // Construct a new AES-GCM barrier c.barrier, err = NewAESGCMBarrier(c.physical) if err != nil { return nil, errwrap.Wrapf("barrier setup failed: {{err}}", err) } - createSecondaries(c, conf) - - if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() { - c.ha = conf.HAPhysical - } - // We create the funcs here, then populate the given config with it so that // the caller can share state conf.ReloadFuncsLock = &c.reloadFuncsLock @@ -753,6 +749,25 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.reloadFuncsLock.Unlock() conf.ReloadFuncs = &c.reloadFuncs + // All the things happening below this are not required in + // recovery mode + if c.recoveryMode { + return c, nil + } + + if conf.PluginDirectory != "" { + c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) + if err != nil { + return nil, errwrap.Wrapf("core setup failed, could not verify plugin directory: {{err}}", err) + } + } + + createSecondaries(c, conf) + + if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() { + c.ha = conf.HAPhysical + } + logicalBackends := make(map[string]logical.Factory) for k, f := range conf.LogicalBackends { logicalBackends[k] = f diff --git a/vault/external_tests/api/sys_rekey_ext_test.go b/vault/external_tests/api/sys_rekey_ext_test.go index b0ca8cf88b..b3d7ca1829 100644 --- a/vault/external_tests/api/sys_rekey_ext_test.go +++ b/vault/external_tests/api/sys_rekey_ext_test.go @@ -218,7 +218,7 @@ func testSysRekey_Verification(t *testing.T, recovery bool) { } else { // We haven't finished, so generating a root token should still be the // old keys (which are still currently set) - testhelpers.GenerateRoot(t, cluster, false) + testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular) } // Provide the final new key @@ -256,7 +256,7 @@ func testSysRekey_Verification(t *testing.T, recovery bool) { } } else { // The old keys should no longer work - _, err := testhelpers.GenerateRootWithError(t, cluster, false) + _, err := testhelpers.GenerateRootWithError(t, cluster, testhelpers.GenerateRootRegular) if err == nil { t.Fatal("expected error") } @@ -273,6 +273,6 @@ func testSysRekey_Verification(t *testing.T, recovery bool) { if err := client.Sys().GenerateRootCancel(); err != nil { t.Fatal(err) } - testhelpers.GenerateRoot(t, cluster, false) + testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRootRegular) } } diff --git a/vault/external_tests/misc/kvv2_upgrade_test.go b/vault/external_tests/misc/kvv2_upgrade_test.go index 4318128ea3..330a615354 100644 --- a/vault/external_tests/misc/kvv2_upgrade_test.go +++ b/vault/external_tests/misc/kvv2_upgrade_test.go @@ -1,4 +1,4 @@ -package token +package misc import ( "bytes" diff --git a/vault/external_tests/misc/recover_from_panic_test.go b/vault/external_tests/misc/recover_from_panic_test.go index 157a4e9694..6c2ec1cf42 100644 --- a/vault/external_tests/misc/recover_from_panic_test.go +++ b/vault/external_tests/misc/recover_from_panic_test.go @@ -1,4 +1,4 @@ -package token +package misc import ( "testing" diff --git a/vault/external_tests/misc/recovery_test.go b/vault/external_tests/misc/recovery_test.go new file mode 100644 index 0000000000..1121f3c4e4 --- /dev/null +++ b/vault/external_tests/misc/recovery_test.go @@ -0,0 +1,140 @@ +package misc + +import ( + "github.com/go-test/deep" + "go.uber.org/atomic" + "path" + "testing" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/testhelpers" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/physical/inmem" + "github.com/hashicorp/vault/vault" +) + +func TestRecovery(t *testing.T) { + logger := logging.NewVaultLogger(hclog.Debug).Named(t.Name()) + inm, err := inmem.NewInmemHA(nil, logger) + if err != nil { + t.Fatal(err) + } + + var keys [][]byte + var secretUUID string + var rootToken string + { + conf := vault.CoreConfig{ + Physical: inm, + Logger: logger, + } + opts := vault.TestClusterOptions{ + HandlerFunc: http.Handler, + NumCores: 1, + } + + cluster := vault.NewTestCluster(t, &conf, &opts) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + rootToken = client.Token() + var fooVal = map[string]interface{}{"bar": 1.0} + _, err = client.Logical().Write("secret/foo", fooVal) + if err != nil { + t.Fatal(err) + } + secret, err := client.Logical().List("secret/") + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(secret.Data["keys"], []interface{}{"foo"}); len(diff) > 0 { + t.Fatalf("got=%v, want=%v, diff: %v", secret.Data["keys"], []string{"foo"}, diff) + } + mounts, err := cluster.Cores[0].Client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + secretMount := mounts["secret/"] + if secretMount == nil { + t.Fatalf("secret mount not found, mounts: %v", mounts) + } + secretUUID = secretMount.UUID + cluster.EnsureCoresSealed(t) + keys = cluster.BarrierKeys + } + + { + // Now bring it up in recovery mode. + var tokenRef atomic.String + conf := vault.CoreConfig{ + Physical: inm, + Logger: logger, + RecoveryMode: true, + } + opts := vault.TestClusterOptions{ + HandlerFunc: http.Handler, + NumCores: 1, + SkipInit: true, + DefaultHandlerProperties: vault.HandlerProperties{ + RecoveryMode: true, + RecoveryToken: &tokenRef, + }, + } + cluster := vault.NewTestCluster(t, &conf, &opts) + cluster.BarrierKeys = keys + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + recoveryToken := testhelpers.GenerateRoot(t, cluster, testhelpers.GenerateRecovery) + _, err = testhelpers.GenerateRootWithError(t, cluster, testhelpers.GenerateRecovery) + if err == nil { + t.Fatal("expected second generate-root to fail") + } + client.SetToken(recoveryToken) + + secret, err := client.Logical().List(path.Join("sys/raw/logical", secretUUID)) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(secret.Data["keys"], []interface{}{"foo"}); len(diff) > 0 { + t.Fatalf("got=%v, want=%v, diff: %v", secret.Data, []string{"foo"}, diff) + } + + _, err = client.Logical().Delete(path.Join("sys/raw/logical", secretUUID, "foo")) + if err != nil { + t.Fatal(err) + } + cluster.EnsureCoresSealed(t) + } + + { + // Now go back to regular mode and verify that our changes are present + conf := vault.CoreConfig{ + Physical: inm, + Logger: logger, + } + opts := vault.TestClusterOptions{ + HandlerFunc: http.Handler, + NumCores: 1, + SkipInit: true, + } + cluster := vault.NewTestCluster(t, &conf, &opts) + cluster.BarrierKeys = keys + cluster.Start() + testhelpers.EnsureCoresUnsealed(t, cluster) + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + client.SetToken(rootToken) + secret, err := client.Logical().List("secret/") + if err != nil { + t.Fatal(err) + } + if secret != nil { + t.Fatal("expected no data in secret mount") + } + } +} diff --git a/vault/generate_root.go b/vault/generate_root.go index c7b554bbb2..9708877709 100644 --- a/vault/generate_root.go +++ b/vault/generate_root.go @@ -4,10 +4,10 @@ import ( "bytes" "context" "encoding/base64" + "errors" "fmt" - "github.com/hashicorp/errwrap" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/helper/xor" "github.com/hashicorp/vault/sdk/helper/consts" @@ -77,10 +77,10 @@ type GenerateRootResult struct { func (c *Core) GenerateRootProgress() (int, error) { c.stateLock.RLock() defer c.stateLock.RUnlock() - if c.Sealed() { + if c.Sealed() && !c.recoveryMode { return 0, consts.ErrSealed } - if c.standby { + if c.standby && !c.recoveryMode { return 0, consts.ErrStandby } @@ -95,10 +95,10 @@ func (c *Core) GenerateRootProgress() (int, error) { func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) { c.stateLock.RLock() defer c.stateLock.RUnlock() - if c.Sealed() { + if c.Sealed() && !c.recoveryMode { return nil, consts.ErrSealed } - if c.standby { + if c.standby && !c.recoveryMode { return nil, consts.ErrStandby } @@ -141,10 +141,17 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg c.stateLock.RLock() defer c.stateLock.RUnlock() - if c.Sealed() { + if c.Sealed() && !c.recoveryMode { return consts.ErrSealed } - if c.standby { + barrierSealed, err := c.barrier.Sealed() + if err != nil { + return errors.New("unable to check barrier seal status") + } + if !barrierSealed && c.recoveryMode { + return errors.New("attempt to generate recovery operation token when already unsealed") + } + if c.standby && !c.recoveryMode { return consts.ErrStandby } @@ -174,6 +181,8 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg switch strategy.(type) { case generateStandardRootToken: c.logger.Info("root generation initialized", "nonce", c.generateRootConfig.Nonce) + case *generateRecoveryToken: + c.logger.Info("recovery operation token generation initialized", "nonce", c.generateRootConfig.Nonce) default: c.logger.Info("dr operation token generation initialized", "nonce", c.generateRootConfig.Nonce) } @@ -217,10 +226,19 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string, // Ensure we are already unsealed c.stateLock.RLock() defer c.stateLock.RUnlock() - if c.Sealed() { + if c.Sealed() && !c.recoveryMode { return nil, consts.ErrSealed } - if c.standby { + + barrierSealed, err := c.barrier.Sealed() + if err != nil { + return nil, errors.New("unable to check barrier seal status") + } + if !barrierSealed && c.recoveryMode { + return nil, errors.New("attempt to generate recovery operation token when already unsealed") + } + + if c.standby && !c.recoveryMode { return nil, consts.ErrStandby } @@ -263,29 +281,80 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string, }, nil } - // Recover the master key - var masterKey []byte + // Combine the key parts + var combinedKey []byte if config.SecretThreshold == 1 { - masterKey = c.generateRootProgress[0] + combinedKey = c.generateRootProgress[0] c.generateRootProgress = nil } else { - masterKey, err = shamir.Combine(c.generateRootProgress) + combinedKey, err = shamir.Combine(c.generateRootProgress) c.generateRootProgress = nil if err != nil { return nil, errwrap.Wrapf("failed to compute master key: {{err}}", err) } } - // Verify the master key - if c.seal.RecoveryKeySupported() { - if err := c.seal.VerifyRecoveryKey(ctx, masterKey); err != nil { + switch { + case c.seal.RecoveryKeySupported(): + // Ensure that the combined recovery key is valid + if err := c.seal.VerifyRecoveryKey(ctx, combinedKey); err != nil { c.logger.Error("root generation aborted, recovery key verification failed", "error", err) return nil, err } - } else { - if err := c.barrier.VerifyMaster(masterKey); err != nil { - c.logger.Error("root generation aborted, master key verification failed", "error", err) - return nil, err + // If we are in recovery mode, then retrieve + // the stored keys and unseal the barrier + if c.recoveryMode { + if !c.seal.StoredKeysSupported() { + c.logger.Error("root generation aborted, recovery key verified but stored keys unsupported") + return nil, errors.New("recovery key verified but stored keys unsupported") + } + masterKeyShares, err := c.seal.GetStoredKeys(ctx) + if err != nil { + return nil, errwrap.Wrapf("unable to retrieve stored keys in recovery mode: {{err}}", err) + } + + switch len(masterKeyShares) { + case 0: + return nil, errors.New("seal returned no master key shares in recovery mode") + case 1: + combinedKey = masterKeyShares[0] + default: + combinedKey, err = shamir.Combine(masterKeyShares) + if err != nil { + return nil, errwrap.Wrapf("failed to compute master key in recovery mode: {{err}}", err) + } + } + + // Use the retrieved master key to unseal the barrier + if err := c.barrier.Unseal(ctx, combinedKey); err != nil { + c.logger.Error("root generation aborted, recovery operation token verification failed", "error", err) + return nil, err + } + } + default: + switch { + case c.recoveryMode: + // If we are in recovery mode, being able to unseal + // the barrier is how we establish authentication + if err := c.barrier.Unseal(ctx, combinedKey); err != nil { + c.logger.Error("root generation aborted, recovery operation token verification failed", "error", err) + return nil, err + } + default: + if err := c.barrier.VerifyMaster(combinedKey); err != nil { + c.logger.Error("root generation aborted, master key verification failed", "error", err) + return nil, err + } + } + } + + // Authentication in recovery mode is successful + if c.recoveryMode { + // Run any post unseal functions that are set + for _, v := range c.postRecoveryUnsealFuncs { + if err := v(); err != nil { + return nil, errwrap.Wrapf("failed to run post unseal func: {{err}}", err) + } } } @@ -334,13 +403,11 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string, switch strategy.(type) { case generateStandardRootToken: - if c.logger.IsInfo() { - c.logger.Info("root generation finished", "nonce", c.generateRootConfig.Nonce) - } + c.logger.Info("root generation finished", "nonce", c.generateRootConfig.Nonce) + case *generateRecoveryToken: + c.logger.Info("recovery operation token generation finished", "nonce", c.generateRootConfig.Nonce) default: - if c.logger.IsInfo() { - c.logger.Info("dr operation token generation finished", "nonce", c.generateRootConfig.Nonce) - } + c.logger.Info("dr operation token generation finished", "nonce", c.generateRootConfig.Nonce) } c.generateRootProgress = nil @@ -352,10 +419,10 @@ func (c *Core) GenerateRootUpdate(ctx context.Context, key []byte, nonce string, func (c *Core) GenerateRootCancel() error { c.stateLock.RLock() defer c.stateLock.RUnlock() - if c.Sealed() { + if c.Sealed() && !c.recoveryMode { return consts.ErrSealed } - if c.standby { + if c.standby && !c.recoveryMode { return consts.ErrStandby } diff --git a/vault/generate_root_recovery.go b/vault/generate_root_recovery.go new file mode 100644 index 0000000000..73b9c2c336 --- /dev/null +++ b/vault/generate_root_recovery.go @@ -0,0 +1,31 @@ +package vault + +import ( + "context" + + "github.com/hashicorp/vault/sdk/helper/base62" + "go.uber.org/atomic" +) + +// GenerateRecoveryTokenStrategy is the strategy used to generate a +// recovery token +func GenerateRecoveryTokenStrategy(token *atomic.String) GenerateRootStrategy { + return &generateRecoveryToken{token: token} +} + +// generateRecoveryToken implements the GenerateRootStrategy and is in +// charge of creating recovery tokens. +type generateRecoveryToken struct { + token *atomic.String +} + +func (g *generateRecoveryToken) generate(ctx context.Context, c *Core) (string, func(), error) { + id, err := base62.Random(TokenLength) + if err != nil { + return "", nil, err + } + token := "r." + id + g.token.Store(token) + + return token, func() { g.token.Store("") }, nil +} diff --git a/vault/init.go b/vault/init.go index 72408a7850..b112e18175 100644 --- a/vault/init.go +++ b/vault/init.go @@ -38,6 +38,31 @@ var ( initInProgress uint32 ) +func (c *Core) InitializeRecovery(ctx context.Context) error { + if !c.recoveryMode { + return nil + } + + raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend) + if !ok { + return nil + } + + parsedClusterAddr, err := url.Parse(c.ClusterAddr()) + if err != nil { + return err + } + + c.postRecoveryUnsealFuncs = append(c.postRecoveryUnsealFuncs, func() error { + return raftStorage.StartRecoveryCluster(context.Background(), raft.Peer{ + ID: raftStorage.NodeID(), + Address: parsedClusterAddr.Host, + }) + }) + + return nil +} + // Initialized checks if the Vault is already initialized func (c *Core) Initialized(ctx context.Context) (bool, error) { // Check the barrier first diff --git a/vault/logical_raw.go b/vault/logical_raw.go new file mode 100644 index 0000000000..9474e55a76 --- /dev/null +++ b/vault/logical_raw.go @@ -0,0 +1,216 @@ +package vault + +import ( + "context" + "fmt" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/compressutil" + "github.com/hashicorp/vault/sdk/logical" + "strings" +) + +var ( + // protectedPaths cannot be accessed via the raw APIs. + // This is both for security and to prevent disrupting Vault. + protectedPaths = []string{ + keyringPath, + // Changing the cluster info path can change the cluster ID which can be disruptive + coreLocalClusterInfoPath, + } +) + +type RawBackend struct { + *framework.Backend + barrier SecurityBarrier + logger log.Logger + checkRaw func(path string) error + recoveryMode bool +} + +func NewRawBackend(core *Core) *RawBackend { + r := &RawBackend{ + barrier: core.barrier, + logger: core.logger.Named("raw"), + checkRaw: func(path string) error { + return nil + }, + recoveryMode: core.recoveryMode, + } + r.Backend = &framework.Backend{ + Paths: rawPaths("sys/", r), + } + return r +} + +// handleRawRead is used to read directly from the barrier +func (b *RawBackend) handleRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + path := data.Get("path").(string) + + if b.recoveryMode { + b.logger.Info("reading", "path", path) + } + + // Prevent access of protected paths + for _, p := range protectedPaths { + if strings.HasPrefix(path, p) { + err := fmt.Sprintf("cannot read '%s'", path) + return logical.ErrorResponse(err), logical.ErrInvalidRequest + } + } + + // Run additional checks if needed + if err := b.checkRaw(path); err != nil { + b.logger.Warn(err.Error(), "path", path) + return logical.ErrorResponse("cannot read '%s'", path), logical.ErrInvalidRequest + } + + entry, err := b.barrier.Get(ctx, path) + if err != nil { + return handleErrorNoReadOnlyForward(err) + } + if entry == nil { + return nil, nil + } + + // Run this through the decompression helper to see if it's been compressed. + // If the input contained the compression canary, `outputBytes` will hold + // the decompressed data. If the input was not compressed, then `outputBytes` + // will be nil. + outputBytes, _, err := compressutil.Decompress(entry.Value) + if err != nil { + return handleErrorNoReadOnlyForward(err) + } + + // `outputBytes` is nil if the input is uncompressed. In that case set it to the original input. + if outputBytes == nil { + outputBytes = entry.Value + } + + resp := &logical.Response{ + Data: map[string]interface{}{ + "value": string(outputBytes), + }, + } + return resp, nil +} + +// handleRawWrite is used to write directly to the barrier +func (b *RawBackend) handleRawWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + path := data.Get("path").(string) + + if b.recoveryMode { + b.logger.Info("writing", "path", path) + } + + // Prevent access of protected paths + for _, p := range protectedPaths { + if strings.HasPrefix(path, p) { + err := fmt.Sprintf("cannot write '%s'", path) + return logical.ErrorResponse(err), logical.ErrInvalidRequest + } + } + + value := data.Get("value").(string) + entry := &logical.StorageEntry{ + Key: path, + Value: []byte(value), + } + if err := b.barrier.Put(ctx, entry); err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + return nil, nil +} + +// handleRawDelete is used to delete directly from the barrier +func (b *RawBackend) handleRawDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + path := data.Get("path").(string) + + if b.recoveryMode { + b.logger.Info("deleting", "path", path) + } + + // Prevent access of protected paths + for _, p := range protectedPaths { + if strings.HasPrefix(path, p) { + err := fmt.Sprintf("cannot delete '%s'", path) + return logical.ErrorResponse(err), logical.ErrInvalidRequest + } + } + + if err := b.barrier.Delete(ctx, path); err != nil { + return handleErrorNoReadOnlyForward(err) + } + return nil, nil +} + +// handleRawList is used to list directly from the barrier +func (b *RawBackend) handleRawList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + path := data.Get("path").(string) + if path != "" && !strings.HasSuffix(path, "/") { + path = path + "/" + } + + if b.recoveryMode { + b.logger.Info("listing", "path", path) + } + + // Prevent access of protected paths + for _, p := range protectedPaths { + if strings.HasPrefix(path, p) { + err := fmt.Sprintf("cannot list '%s'", path) + return logical.ErrorResponse(err), logical.ErrInvalidRequest + } + } + + // Run additional checks if needed + if err := b.checkRaw(path); err != nil { + b.logger.Warn(err.Error(), "path", path) + return logical.ErrorResponse("cannot list '%s'", path), logical.ErrInvalidRequest + } + + keys, err := b.barrier.List(ctx, path) + if err != nil { + return handleErrorNoReadOnlyForward(err) + } + return logical.ListResponse(keys), nil +} + +func rawPaths(prefix string, r *RawBackend) []*framework.Path { + return []*framework.Path{ + &framework.Path{ + Pattern: prefix + "(raw/?$|raw/(?P.+))", + + Fields: map[string]*framework.FieldSchema{ + "path": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "value": &framework.FieldSchema{ + Type: framework.TypeString, + }, + }, + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: r.handleRawRead, + Summary: "Read the value of the key at the given path.", + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: r.handleRawWrite, + Summary: "Update the value of the key at the given path.", + }, + logical.DeleteOperation: &framework.PathOperation{ + Callback: r.handleRawDelete, + Summary: "Delete the key with given path.", + }, + logical.ListOperation: &framework.PathOperation{ + Callback: r.handleRawList, + Summary: "Return a list keys for a given path prefix.", + }, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["raw"][0]), + HelpDescription: strings.TrimSpace(sysHelp["raw"][1]), + }, + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index bf2528e0fc..20cc9188bd 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -23,14 +23,13 @@ import ( "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" - memdb "github.com/hashicorp/go-memdb" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/hostutil" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" - "github.com/hashicorp/vault/sdk/helper/compressutil" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/parseutil" @@ -40,16 +39,6 @@ import ( "github.com/mitchellh/mapstructure" ) -var ( - // protectedPaths cannot be accessed via the raw APIs. - // This is both for security and to prevent disrupting Vault. - protectedPaths = []string{ - keyringPath, - // Changing the cluster info path can change the cluster ID which can be disruptive - coreLocalClusterInfoPath, - } -) - const maxBytes = 128 * 1024 func systemBackendMemDBSchema() *memdb.DBSchema { @@ -172,40 +161,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend { b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath()) if core.rawEnabled { - b.Backend.Paths = append(b.Backend.Paths, &framework.Path{ - Pattern: "(raw/?$|raw/(?P.+))", - - Fields: map[string]*framework.FieldSchema{ - "path": &framework.FieldSchema{ - Type: framework.TypeString, - }, - "value": &framework.FieldSchema{ - Type: framework.TypeString, - }, - }, - - Operations: map[logical.Operation]framework.OperationHandler{ - logical.ReadOperation: &framework.PathOperation{ - Callback: b.handleRawRead, - Summary: "Read the value of the key at the given path.", - }, - logical.UpdateOperation: &framework.PathOperation{ - Callback: b.handleRawWrite, - Summary: "Update the value of the key at the given path.", - }, - logical.DeleteOperation: &framework.PathOperation{ - Callback: b.handleRawDelete, - Summary: "Delete the key with given path.", - }, - logical.ListOperation: &framework.PathOperation{ - Callback: b.handleRawList, - Summary: "Return a list keys for a given path prefix.", - }, - }, - - HelpSynopsis: strings.TrimSpace(sysHelp["raw"][0]), - HelpDescription: strings.TrimSpace(sysHelp["raw"][1]), - }) + b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) } if _, ok := core.underlyingPhysical.(*raft.RaftBackend); ok { @@ -216,6 +172,17 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend { return b } +func (b *SystemBackend) rawPaths() []*framework.Path { + r := &RawBackend{ + barrier: b.Core.barrier, + logger: b.logger, + checkRaw: func(path string) error { + return checkRaw(b, path) + }, + } + return rawPaths("", r) +} + // SystemBackend implements logical.Backend and is used to interact with // the core of the system. This backend is hardcoded to exist at the "sys" // prefix. Conceptually it is similar to procfs on Linux. @@ -2248,123 +2215,6 @@ func (b *SystemBackend) handleConfigUIHeadersDelete(ctx context.Context, req *lo return nil, nil } -// handleRawRead is used to read directly from the barrier -func (b *SystemBackend) handleRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - path := data.Get("path").(string) - - // Prevent access of protected paths - for _, p := range protectedPaths { - if strings.HasPrefix(path, p) { - err := fmt.Sprintf("cannot read '%s'", path) - return logical.ErrorResponse(err), logical.ErrInvalidRequest - } - } - - // Run additional checks if needed - if err := checkRaw(b, path); err != nil { - b.Core.logger.Warn(err.Error(), "path", path) - return logical.ErrorResponse("cannot read '%s'", path), logical.ErrInvalidRequest - } - - entry, err := b.Core.barrier.Get(ctx, path) - if err != nil { - return handleErrorNoReadOnlyForward(err) - } - if entry == nil { - return nil, nil - } - - // Run this through the decompression helper to see if it's been compressed. - // If the input contained the compression canary, `outputBytes` will hold - // the decompressed data. If the input was not compressed, then `outputBytes` - // will be nil. - outputBytes, _, err := compressutil.Decompress(entry.Value) - if err != nil { - return handleErrorNoReadOnlyForward(err) - } - - // `outputBytes` is nil if the input is uncompressed. In that case set it to the original input. - if outputBytes == nil { - outputBytes = entry.Value - } - - resp := &logical.Response{ - Data: map[string]interface{}{ - "value": string(outputBytes), - }, - } - return resp, nil -} - -// handleRawWrite is used to write directly to the barrier -func (b *SystemBackend) handleRawWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - path := data.Get("path").(string) - - // Prevent access of protected paths - for _, p := range protectedPaths { - if strings.HasPrefix(path, p) { - err := fmt.Sprintf("cannot write '%s'", path) - return logical.ErrorResponse(err), logical.ErrInvalidRequest - } - } - - value := data.Get("value").(string) - entry := &logical.StorageEntry{ - Key: path, - Value: []byte(value), - } - if err := b.Core.barrier.Put(ctx, entry); err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - } - return nil, nil -} - -// handleRawDelete is used to delete directly from the barrier -func (b *SystemBackend) handleRawDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - path := data.Get("path").(string) - - // Prevent access of protected paths - for _, p := range protectedPaths { - if strings.HasPrefix(path, p) { - err := fmt.Sprintf("cannot delete '%s'", path) - return logical.ErrorResponse(err), logical.ErrInvalidRequest - } - } - - if err := b.Core.barrier.Delete(ctx, path); err != nil { - return handleErrorNoReadOnlyForward(err) - } - return nil, nil -} - -// handleRawList is used to list directly from the barrier -func (b *SystemBackend) handleRawList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - path := data.Get("path").(string) - if path != "" && !strings.HasSuffix(path, "/") { - path = path + "/" - } - - // Prevent access of protected paths - for _, p := range protectedPaths { - if strings.HasPrefix(path, p) { - err := fmt.Sprintf("cannot list '%s'", path) - return logical.ErrorResponse(err), logical.ErrInvalidRequest - } - } - - // Run additional checks if needed - if err := checkRaw(b, path); err != nil { - b.Core.logger.Warn(err.Error(), "path", path) - return logical.ErrorResponse("cannot list '%s'", path), logical.ErrInvalidRequest - } - - keys, err := b.Core.barrier.List(ctx, path) - if err != nil { - return handleErrorNoReadOnlyForward(err) - } - return logical.ListResponse(keys), nil -} - // handleKeyStatus returns status information about the backend key func (b *SystemBackend) handleKeyStatus(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Get the key info diff --git a/vault/request_handling.go b/vault/request_handling.go index cd67765ac3..4f2e44f92f 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -22,6 +22,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" + uberAtomic "go.uber.org/atomic" ) const ( @@ -43,6 +44,8 @@ type HandlerProperties struct { MaxRequestSize int64 MaxRequestDuration time.Duration DisablePrintableCheck bool + RecoveryMode bool + RecoveryToken *uberAtomic.String UnauthenticatedMetricsAccess bool } diff --git a/vault/testing.go b/vault/testing.go index 55ed15f3b6..0ce4889a34 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1039,16 +1039,17 @@ type PhysicalBackendBundle struct { } type TestClusterOptions struct { - KeepStandbysSealed bool - SkipInit bool - HandlerFunc func(*HandlerProperties) http.Handler - BaseListenAddress string - NumCores int - SealFunc func() Seal - Logger log.Logger - TempDir string - CACert []byte - CAKey *ecdsa.PrivateKey + KeepStandbysSealed bool + SkipInit bool + HandlerFunc func(*HandlerProperties) http.Handler + DefaultHandlerProperties HandlerProperties + BaseListenAddress string + NumCores int + SealFunc func() Seal + Logger log.Logger + TempDir string + CACert []byte + CAKey *ecdsa.PrivateKey // PhysicalFactory is used to create backends. // The int argument is the index of the core within the cluster, i.e. first // core in cluster will have 0, second 1, etc. @@ -1417,7 +1418,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te coreConfig.DevToken = base.DevToken coreConfig.CounterSyncInterval = base.CounterSyncInterval - + coreConfig.RecoveryMode = base.RecoveryMode } if coreConfig.RawConfig == nil { @@ -1511,10 +1512,12 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te cores = append(cores, c) coreConfigs = append(coreConfigs, &localConfig) if opts != nil && opts.HandlerFunc != nil { - handlers[i] = opts.HandlerFunc(&HandlerProperties{ - Core: c, - MaxRequestDuration: DefaultMaxRequestDuration, - }) + props := opts.DefaultHandlerProperties + props.Core = c + if props.MaxRequestDuration == 0 { + props.MaxRequestDuration = DefaultMaxRequestDuration + } + handlers[i] = opts.HandlerFunc(&props) servers[i].Handler = handlers[i] }