diff --git a/command/server.go b/command/server.go index b29b208e84..501da04883 100644 --- a/command/server.go +++ b/command/server.go @@ -97,8 +97,9 @@ type ServerCommand struct { type ServerListener struct { net.Listener - config map[string]interface{} - maxRequestSize int64 + config map[string]interface{} + maxRequestSize int64 + maxRequestDuration time.Duration } func (c *ServerCommand) Synopsis() string { @@ -395,6 +396,10 @@ func (c *ServerCommand) Run(args []string) int { return 1 } + if config.DefaultMaxRequestDuration != 0 { + vault.DefaultMaxRequestDuration = config.DefaultMaxRequestDuration + } + // If mlockall(2) isn't supported, show a warning. We disable this in dev // because it is quite scary to see when first using Vault. We also disable // this if the user has explicitly disabled mlock in configuration. @@ -738,10 +743,25 @@ CLUSTER_SYNTHESIS_COMPLETE: } props["max_request_size"] = fmt.Sprintf("%d", maxRequestSize) + var maxRequestDuration time.Duration = vault.DefaultMaxRequestDuration + if valRaw, ok := lnConfig.Config["max_request_duration"]; ok { + val, err := parseutil.ParseDurationSecond(valRaw) + if err != nil { + c.UI.Error(fmt.Sprintf("Could not parse max_request_duration value %v", valRaw)) + return 1 + } + + if val >= 0 { + maxRequestDuration = val + } + } + props["max_request_duration"] = fmt.Sprintf("%s", maxRequestDuration.String()) + lns = append(lns, ServerListener{ - Listener: ln, - config: lnConfig.Config, - maxRequestSize: maxRequestSize, + Listener: ln, + config: lnConfig.Config, + maxRequestSize: maxRequestSize, + maxRequestDuration: maxRequestDuration, }) // Store the listener props for output later @@ -939,6 +959,7 @@ CLUSTER_SYNTHESIS_COMPLETE: handler := vaulthttp.Handler(&vault.HandlerProperties{ Core: core, MaxRequestSize: ln.maxRequestSize, + MaxRequestDuration: ln.maxRequestDuration, DisablePrintableCheck: config.DisablePrintableCheck, }) @@ -1113,7 +1134,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig "no_default_policy": true, }, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { return nil, errwrap.Wrapf(fmt.Sprintf("failed to create root token with ID %q: {{err}}", coreConfig.DevToken), err) } @@ -1129,7 +1150,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig req.ID = "dev-revoke-init-root" req.Path = "auth/token/revoke-self" req.Data = nil - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { return nil, errwrap.Wrapf("failed to revoke initial root token: {{err}}", err) } @@ -1156,7 +1177,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig }, }, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { return nil, errwrap.Wrapf("error upgrading default K/V store: {{err}}", err) } @@ -1233,7 +1254,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m "no_default_policy": true, }, } - resp, err := testCluster.Cores[0].HandleRequest(req) + resp, err := testCluster.Cores[0].HandleRequest(context.Background(), req) if err != nil { c.UI.Error(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err)) return 1 @@ -1252,7 +1273,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m req.ID = "dev-revoke-init-root" req.Path = "auth/token/revoke-self" req.Data = nil - resp, err = testCluster.Cores[0].HandleRequest(req) + resp, err = testCluster.Cores[0].HandleRequest(context.Background(), req) if err != nil { c.UI.Output(fmt.Sprintf("failed to revoke initial root token: %s", err)) return 1 @@ -1385,7 +1406,7 @@ func (c *ServerCommand) addPlugin(path, token string, core *vault.Core) error { "command": name, }, } - if _, err := core.HandleRequest(req); err != nil { + if _, err := core.HandleRequest(context.Background(), req); err != nil { return err } diff --git a/command/server/config.go b/command/server/config.go index d33cec7090..f2644a34d1 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -16,7 +16,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/hcl" "github.com/hashicorp/hcl/hcl/ast" - "github.com/hashicorp/vault/helper/hclutil" "github.com/hashicorp/vault/helper/parseutil" ) @@ -46,6 +45,9 @@ type Config struct { DefaultLeaseTTL time.Duration `hcl:"-"` DefaultLeaseTTLRaw interface{} `hcl:"default_lease_ttl"` + DefaultMaxRequestDuration time.Duration `hcl:"-"` + DefaultMaxRequestDurationRaw interface{} `hcl:"default_max_request_time"` + ClusterName string `hcl:"cluster_name"` ClusterCipherSuites string `hcl:"cluster_cipher_suites"` @@ -289,6 +291,11 @@ func (c *Config) Merge(c2 *Config) *Config { result.DefaultLeaseTTL = c2.DefaultLeaseTTL } + result.DefaultMaxRequestDuration = c.DefaultMaxRequestDuration + if c2.DefaultMaxRequestDuration > result.DefaultMaxRequestDuration { + result.DefaultMaxRequestDuration = c2.DefaultMaxRequestDuration + } + result.ClusterName = c.ClusterName if c2.ClusterName != "" { result.ClusterName = c2.ClusterName @@ -375,6 +382,12 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) { } } + if result.DefaultMaxRequestDurationRaw != nil { + if result.DefaultMaxRequestDuration, err = parseutil.ParseDurationSecond(result.DefaultMaxRequestDurationRaw); err != nil { + return nil, err + } + } + if result.EnableUIRaw != nil { if result.EnableUI, err = parseutil.ParseBool(result.EnableUIRaw); err != nil { return nil, err @@ -422,36 +435,6 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) { return nil, fmt.Errorf("error parsing: file doesn't contain a root object") } - valid := []string{ - "storage", - "ha_storage", - "backend", - "ha_backend", - "hsm", - "seal", - "listener", - "cache_size", - "disable_cache", - "disable_mlock", - "disable_printable_check", - "ui", - "telemetry", - "default_lease_ttl", - "max_lease_ttl", - "cluster_name", - "cluster_cipher_suites", - "plugin_directory", - "pid_file", - "raw_storage_endpoint", - "api_addr", - "cluster_addr", - "disable_clustering", - "disable_sealwrap", - } - if err := hclutil.CheckHCLKeys(list, valid); err != nil { - return nil, err - } - // Look for storage but still support old backend if o := list.Filter("storage"); len(o.Items) > 0 { if err := parseStorage(&result, o, "storage"); err != nil { @@ -728,61 +711,16 @@ func parseSeal(result *Config, list *ast.ObjectList, blockName string) error { key = item.Keys[0].Token.Value().(string) } - var valid []string // Valid parameter for the Seal types switch key { case "pkcs11": - valid = []string{ - "lib", - "slot", - "token_label", - "pin", - "mechanism", - "hmac_mechanism", - "key_label", - "default_key_label", - "hmac_key_label", - "hmac_default_key_label", - "generate_key", - "regenerate_key", - "max_parallel", - "disable_auto_reinit_on_error", - "rsa_encrypt_local", - "rsa_oaep_hash", - } case "awskms": - valid = []string{ - "region", - "access_key", - "secret_key", - "kms_key_id", - "max_parallel", - } case "gcpckms": - valid = []string{ - "credentials", - "project", - "region", - "key_ring", - "crypto_key", - } case "azurekeyvault": - valid = []string{ - "tenant_id", - "client_id", - "client_secret", - "environment", - "vault_name", - "key_name", - } default: return fmt.Errorf("invalid seal type %q", key) } - if err := hclutil.CheckHCLKeys(item.Val, valid); err != nil { - return multierror.Prefix(err, fmt.Sprintf("%s.%s:", blockName, key)) - } - var m map[string]string if err := hcl.DecodeObject(&m, item.Val); err != nil { return multierror.Prefix(err, fmt.Sprintf("%s.%s:", blockName, key)) @@ -804,34 +742,6 @@ func parseListeners(result *Config, list *ast.ObjectList) error { key = item.Keys[0].Token.Value().(string) } - valid := []string{ - "address", - "cluster_address", - "endpoint", - "x_forwarded_for_authorized_addrs", - "x_forwarded_for_hop_skips", - "x_forwarded_for_reject_not_authorized", - "x_forwarded_for_reject_not_present", - "infrastructure", - "max_request_size", - "node_id", - "proxy_protocol_behavior", - "proxy_protocol_authorized_addrs", - "tls_disable", - "tls_cert_file", - "tls_key_file", - "tls_min_version", - "tls_cipher_suites", - "tls_prefer_server_cipher_suites", - "tls_require_and_verify_client_cert", - "tls_disable_client_certs", - "tls_client_ca_file", - "token", - } - if err := hclutil.CheckHCLKeys(item.Val, valid); err != nil { - return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key)) - } - var m map[string]interface{} if err := hcl.DecodeObject(&m, item.Val); err != nil { return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key)) @@ -857,31 +767,6 @@ func parseTelemetry(result *Config, list *ast.ObjectList) error { // Get our one item item := list.Items[0] - // Check for invalid keys - valid := []string{ - "circonus_api_token", - "circonus_api_app", - "circonus_api_url", - "circonus_submission_interval", - "circonus_submission_url", - "circonus_check_id", - "circonus_check_force_metric_activation", - "circonus_check_instance_id", - "circonus_check_search_tag", - "circonus_check_display_name", - "circonus_check_tags", - "circonus_broker_id", - "circonus_broker_select_tag", - "disable_hostname", - "dogstatsd_addr", - "dogstatsd_tags", - "statsd_address", - "statsite_address", - } - if err := hclutil.CheckHCLKeys(item.Val, valid); err != nil { - return multierror.Prefix(err, "telemetry:") - } - var t Telemetry if err := hcl.DecodeObject(&t, item.Val); err != nil { return multierror.Prefix(err, "telemetry:") diff --git a/command/server/config_test.go b/command/server/config_test.go index 74b04b1e5b..9feba4038a 100644 --- a/command/server/config_test.go +++ b/command/server/config_test.go @@ -383,73 +383,3 @@ listener "tcp" { } } - -func TestParseConfig_badTopLevel(t *testing.T) { - logger := logging.NewVaultLogger(log.Debug) - - _, err := ParseConfig(strings.TrimSpace(` -backend {} -bad = "one" -nope = "yes" -`), logger) - - if err == nil { - t.Fatal("expected error") - } - - if !strings.Contains(err.Error(), `invalid key "bad" on line 2`) { - t.Errorf("bad error: %q", err) - } - - if !strings.Contains(err.Error(), `invalid key "nope" on line 3`) { - t.Errorf("bad error: %q", err) - } -} - -func TestParseConfig_badListener(t *testing.T) { - logger := logging.NewVaultLogger(log.Debug) - - _, err := ParseConfig(strings.TrimSpace(` -listener "tcp" { - address = "1.2.3.3" - bad = "one" - nope = "yes" -} -`), logger) - - if err == nil { - t.Fatal("expected error") - } - - if !strings.Contains(err.Error(), `listeners.tcp: invalid key "bad" on line 3`) { - t.Errorf("bad error: %q", err) - } - - if !strings.Contains(err.Error(), `listeners.tcp: invalid key "nope" on line 4`) { - t.Errorf("bad error: %q", err) - } -} - -func TestParseConfig_badTelemetry(t *testing.T) { - logger := logging.NewVaultLogger(log.Debug) - - _, err := ParseConfig(strings.TrimSpace(` -telemetry { - statsd_address = "1.2.3.3" - bad = "one" - nope = "yes" -} -`), logger) - - if err == nil { - t.Fatal("expected error") - } - - if !strings.Contains(err.Error(), `telemetry: invalid key "bad" on line 3`) { - t.Errorf("bad error: %q", err) - } - - if !strings.Contains(err.Error(), `telemetry: invalid key "nope" on line 4`) { - t.Errorf("bad error: %q", err) - } -} diff --git a/helper/parseutil/parseutil.go b/helper/parseutil/parseutil.go index ae8c58ba78..9b32bf7df4 100644 --- a/helper/parseutil/parseutil.go +++ b/helper/parseutil/parseutil.go @@ -28,7 +28,7 @@ func ParseDurationSecond(in interface{}) (time.Duration, error) { } var err error // Look for a suffix otherwise its a plain second value - if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") { + if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") || strings.HasSuffix(inp, "ms") { dur, err = time.ParseDuration(inp) if err != nil { return dur, err diff --git a/http/handler.go b/http/handler.go index f5586ab8fb..a7184b10fa 100644 --- a/http/handler.go +++ b/http/handler.go @@ -113,7 +113,7 @@ func Handler(props *vault.HandlerProperties) http.Handler { // Wrap the help wrapped handler with another layer with a generic // handler - genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize) + genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize, props.MaxRequestDuration) // Wrap the handler with PrintablePathCheckHandler to check for non-printable // characters in the request path. @@ -128,20 +128,27 @@ func Handler(props *vault.HandlerProperties) http.Handler { // wrapGenericHandler wraps the handler with an extra layer of handler where // tasks that should be commonly handled for all the requests and/or responses // are performed. -func wrapGenericHandler(h http.Handler, maxRequestSize int64) http.Handler { +func wrapGenericHandler(h http.Handler, maxRequestSize int64, maxRequestDuration time.Duration) http.Handler { + if maxRequestDuration == 0 { + maxRequestDuration = vault.DefaultMaxRequestDuration + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") - // Add a context and put the request limit for this handler in it + // Start with the request context + ctx := r.Context() + var cancelFunc context.CancelFunc + // Add our timeout + ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration) + // Add a size limiter if desired if maxRequestSize > 0 { - ctx := context.WithValue(r.Context(), "max_request_size", maxRequestSize) - h.ServeHTTP(w, r.WithContext(ctx)) - } else { - h.ServeHTTP(w, r) + ctx = context.WithValue(ctx, "max_request_size", maxRequestSize) } - + r = r.WithContext(ctx) + h.ServeHTTP(w, r) + cancelFunc() return }) } @@ -432,7 +439,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle // request is a helper to perform a request and properly exit in the // case of an error. func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) { - resp, err := core.HandleRequest(r) + resp, err := core.HandleRequest(rawReq.Context(), r) if errwrap.Contains(err, consts.ErrStandby.Error()) { respondStandby(core, w, rawReq.URL) return resp, false diff --git a/http/help.go b/http/help.go index 1c3a9560f0..597fbb0975 100644 --- a/http/help.go +++ b/http/help.go @@ -37,7 +37,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) { Connection: getConnection(req), }) - resp, err := core.HandleRequest(lreq) + resp, err := core.HandleRequest(req.Context(), lreq) if err != nil { respondErrorCommon(w, lreq, resp, err) return diff --git a/http/sys_seal.go b/http/sys_seal.go index 2eebb95cd1..ad6aaf9bff 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -32,7 +32,7 @@ func handleSysSeal(core *vault.Core) http.Handler { // Seal with the token above // We use context.Background since there won't be a request context if the node isn't active - if err := core.SealWithRequest(req); err != nil { + if err := core.SealWithRequest(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { respondError(w, http.StatusForbidden, err) return @@ -62,7 +62,7 @@ func handleSysStepDown(core *vault.Core) http.Handler { } // Seal with the token above - if err := core.StepDown(req); err != nil { + if err := core.StepDown(r.Context(), req); err != nil { respondError(w, http.StatusInternalServerError, err) return } diff --git a/http/sys_seal_test.go b/http/sys_seal_test.go index 68c9b53075..1b42f1b867 100644 --- a/http/sys_seal_test.go +++ b/http/sys_seal_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -273,7 +274,7 @@ func TestSysSeal_Permissions(t *testing.T) { }, ClientToken: root, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -288,7 +289,7 @@ func TestSysSeal_Permissions(t *testing.T) { "policies": []string{"test"}, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v %v", err, resp) } @@ -311,7 +312,7 @@ func TestSysSeal_Permissions(t *testing.T) { }, ClientToken: root, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -332,7 +333,7 @@ func TestSysSeal_Permissions(t *testing.T) { }, ClientToken: root, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -353,7 +354,7 @@ func TestSysSeal_Permissions(t *testing.T) { }, ClientToken: root, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } diff --git a/logical/testing/testing.go b/logical/testing/testing.go index 0f8a9b28d3..03d3b220e8 100644 --- a/logical/testing/testing.go +++ b/logical/testing/testing.go @@ -249,7 +249,7 @@ func Test(tt TestT, c TestCase) { req.Path = fmt.Sprintf("%s/%s", prefix, req.Path) // Make the request - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if resp != nil && resp.Secret != nil { // Revoke this secret later revoke = append(revoke, &logical.Request{ @@ -303,7 +303,7 @@ func Test(tt TestT, c TestCase) { logger.Warn("Revoking secret", "secret", fmt.Sprintf("%#v", req)) } req.ClientToken = client.Token() - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err == nil && resp.IsError() { err = fmt.Errorf("erroneous response:\n\n%#v", resp) } @@ -320,7 +320,7 @@ func Test(tt TestT, c TestCase) { req := logical.RollbackRequest(prefix + "/") req.Data["immediate"] = true req.ClientToken = client.Token() - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err == nil && resp.IsError() { err = fmt.Errorf("erroneous response:\n\n%#v", resp) } diff --git a/physical/file/file.go b/physical/file/file.go index 1edf7c749f..6d56f61b09 100644 --- a/physical/file/file.go +++ b/physical/file/file.go @@ -98,6 +98,12 @@ func (b *FileBackend) DeleteInternal(ctx context.Context, path string) error { basePath, key := b.expandPath(path) fullPath := filepath.Join(basePath, key) + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + err := os.Remove(fullPath) if err != nil && !os.IsNotExist(err) { return errwrap.Wrapf(fmt.Sprintf("failed to remove %q: {{err}}", fullPath), err) @@ -192,6 +198,12 @@ func (b *FileBackend) GetInternal(ctx context.Context, k string) (*physical.Entr return nil, err } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return &physical.Entry{ Key: k, Value: entry.Value, @@ -236,6 +248,12 @@ func (b *FileBackend) PutInternal(ctx context.Context, entry *physical.Entry) er return errors.New("could not successfully get a file handle") } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + enc := json.NewEncoder(f) encErr := enc.Encode(&fileEntry{ Value: entry.Value, @@ -270,10 +288,10 @@ func (b *FileBackend) List(ctx context.Context, prefix string) ([]string, error) b.RLock() defer b.RUnlock() - return b.ListInternal(prefix) + return b.ListInternal(ctx, prefix) } -func (b *FileBackend) ListInternal(prefix string) ([]string, error) { +func (b *FileBackend) ListInternal(ctx context.Context, prefix string) ([]string, error) { if err := b.validatePath(prefix); err != nil { return nil, err } @@ -315,6 +333,12 @@ func (b *FileBackend) ListInternal(prefix string) ([]string, error) { } } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return names, nil } diff --git a/physical/inmem/inmem.go b/physical/inmem/inmem.go index 0274305685..0616738c13 100644 --- a/physical/inmem/inmem.go +++ b/physical/inmem/inmem.go @@ -93,6 +93,12 @@ func (i *InmemBackend) PutInternal(ctx context.Context, entry *physical.Entry) e return PutDisabledError } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + i.root.Insert(entry.Key, entry.Value) return nil } @@ -121,6 +127,12 @@ func (i *InmemBackend) GetInternal(ctx context.Context, key string) (*physical.E return nil, GetDisabledError } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if raw, ok := i.root.Get(key); ok { return &physical.Entry{ Key: key, @@ -153,6 +165,11 @@ func (i *InmemBackend) DeleteInternal(ctx context.Context, key string) error { if atomic.LoadUint32(i.failDelete) != 0 { return DeleteDisabledError } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } i.root.Delete(key) return nil @@ -175,10 +192,10 @@ func (i *InmemBackend) List(ctx context.Context, prefix string) ([]string, error i.RLock() defer i.RUnlock() - return i.ListInternal(prefix) + return i.ListInternal(ctx, prefix) } -func (i *InmemBackend) ListInternal(prefix string) ([]string, error) { +func (i *InmemBackend) ListInternal(ctx context.Context, prefix string) ([]string, error) { if atomic.LoadUint32(i.failList) != 0 { return nil, ListDisabledError } @@ -201,6 +218,12 @@ func (i *InmemBackend) ListInternal(prefix string) ([]string, error) { } i.root.WalkPrefix(prefix, walkFn) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + return out, nil } diff --git a/vault/auth_test.go b/vault/auth_test.go index 8b32275997..2a24ab716f 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -342,7 +342,7 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) { Operation: logical.ReadOperation, Path: "auth/foo/login", } - resp, err := c.HandleRequest(r) + resp, err := c.HandleRequest(context.Background(), r) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/cluster_test.go b/vault/cluster_test.go index 2b5b661efb..f1440dbf0b 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -150,7 +150,7 @@ func TestCluster_ListenForRequests(t *testing.T) { time.Sleep(clusterTestPausePeriod) checkListenersFunc(false) - err := cores[0].StepDown(&logical.Request{ + err := cores[0].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: cluster.RootToken, @@ -222,7 +222,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { // // Ensure active core is cores[1] and test - err := cores[0].StepDown(&logical.Request{ + err := cores[0].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -231,7 +231,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { t.Fatal(err) } time.Sleep(clusterTestPausePeriod) - _ = cores[2].StepDown(&logical.Request{ + _ = cores[2].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -242,7 +242,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { testCluster_ForwardRequests(t, cores[2], root, "core2") // Ensure active core is cores[2] and test - err = cores[1].StepDown(&logical.Request{ + err = cores[1].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -251,7 +251,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { t.Fatal(err) } time.Sleep(clusterTestPausePeriod) - _ = cores[0].StepDown(&logical.Request{ + _ = cores[0].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -262,7 +262,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { testCluster_ForwardRequests(t, cores[1], root, "core3") // Ensure active core is cores[0] and test - err = cores[2].StepDown(&logical.Request{ + err = cores[2].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -271,7 +271,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { t.Fatal(err) } time.Sleep(clusterTestPausePeriod) - _ = cores[1].StepDown(&logical.Request{ + _ = cores[1].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -282,7 +282,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { testCluster_ForwardRequests(t, cores[2], root, "core1") // Ensure active core is cores[1] and test - err = cores[0].StepDown(&logical.Request{ + err = cores[0].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -291,7 +291,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { t.Fatal(err) } time.Sleep(clusterTestPausePeriod) - _ = cores[2].StepDown(&logical.Request{ + _ = cores[2].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -302,7 +302,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { testCluster_ForwardRequests(t, cores[2], root, "core2") // Ensure active core is cores[2] and test - err = cores[1].StepDown(&logical.Request{ + err = cores[1].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, @@ -311,7 +311,7 @@ func testCluster_ForwardRequestsCommon(t *testing.T) { t.Fatal(err) } time.Sleep(clusterTestPausePeriod) - _ = cores[0].StepDown(&logical.Request{ + _ = cores[0].StepDown(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/step-down", ClientToken: root, diff --git a/vault/core.go b/vault/core.go index 33515d5942..bddece2050 100644 --- a/vault/core.go +++ b/vault/core.go @@ -925,7 +925,7 @@ func (c *Core) unsealInternal(ctx context.Context, masterKey []byte) (bool, erro // SealWithRequest takes in a logical.Request, acquires the lock, and passes // through to sealInternal -func (c *Core) SealWithRequest(req *logical.Request) error { +func (c *Core) SealWithRequest(httpCtx context.Context, req *logical.Request) error { defer metrics.MeasureSince([]string{"core", "seal-with-request"}, time.Now()) if c.Sealed() { @@ -936,7 +936,19 @@ func (c *Core) SealWithRequest(req *logical.Request) error { // This will unlock the read lock // We use background context since we may not be active - return c.sealInitCommon(context.Background(), req) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + select { + case <-ctx.Done(): + case <-httpCtx.Done(): + cancel() + } + }() + + // This will unlock the read lock + return c.sealInitCommon(ctx, req) } // Seal takes in a token and creates a logical.Request, acquires the lock, and @@ -1068,7 +1080,7 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr // we won't have a token store after sealing. leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te) if err == nil { - err = c.expiration.Revoke(leaseID) + err = c.expiration.Revoke(ctx, leaseID) } if err != nil { c.logger.Error("token needed revocation before seal but failed to revoke", "error", err) diff --git a/vault/core_test.go b/vault/core_test.go index 3a31f02b66..81ed890c69 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -185,7 +185,7 @@ func TestCore_Route_Sealed(t *testing.T) { Operation: logical.ReadOperation, Path: "sys/mounts", } - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != consts.ErrSealed { t.Fatalf("err: %v", err) } @@ -208,7 +208,7 @@ func TestCore_Route_Sealed(t *testing.T) { // Should not error after unseal req.ClientToken = res.RootToken - _, err = c.HandleRequest(req) + _, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -301,7 +301,7 @@ func TestCore_HandleRequest_Lease(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -312,7 +312,7 @@ func TestCore_HandleRequest_Lease(t *testing.T) { // Read the key req.Operation = logical.ReadOperation req.Data = nil - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -342,7 +342,7 @@ func TestCore_HandleRequest_Lease_MaxLength(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -353,7 +353,7 @@ func TestCore_HandleRequest_Lease_MaxLength(t *testing.T) { // Read the key req.Operation = logical.ReadOperation req.Data = nil - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -383,7 +383,7 @@ func TestCore_HandleRequest_Lease_DefaultLength(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -394,7 +394,7 @@ func TestCore_HandleRequest_Lease_DefaultLength(t *testing.T) { // Read the key req.Operation = logical.ReadOperation req.Data = nil - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -423,7 +423,7 @@ func TestCore_HandleRequest_MissingToken(t *testing.T) { "lease": "1h", }, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err == nil || !errwrap.Contains(err, logical.ErrInvalidRequest.Error()) { t.Fatalf("err: %v", err) } @@ -444,7 +444,7 @@ func TestCore_HandleRequest_InvalidToken(t *testing.T) { }, ClientToken: "foobarbaz", } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { t.Fatalf("err: %v", err) } @@ -462,7 +462,7 @@ func TestCore_HandleRequest_NoSlash(t *testing.T) { Path: "secret", ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v, resp: %v", err, resp) } @@ -481,7 +481,7 @@ func TestCore_HandleRequest_RootPath(t *testing.T) { Path: "sys/policy", // root protected! ClientToken: "child", } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { t.Fatalf("err: %v, resp: %v", err, resp) } @@ -500,7 +500,7 @@ func TestCore_HandleRequest_RootPath_WithSudo(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -515,7 +515,7 @@ func TestCore_HandleRequest_RootPath_WithSudo(t *testing.T) { Path: "sys/policy", // root protected! ClientToken: "child", } - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -538,7 +538,7 @@ func TestCore_HandleRequest_PermissionDenied(t *testing.T) { }, ClientToken: "child", } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { t.Fatalf("err: %v, resp: %v", err, resp) } @@ -558,7 +558,7 @@ func TestCore_HandleRequest_PermissionAllowed(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -576,7 +576,7 @@ func TestCore_HandleRequest_PermissionAllowed(t *testing.T) { }, ClientToken: "child", } - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -599,7 +599,7 @@ func TestCore_HandleRequest_NoClientToken(t *testing.T) { req.Data["type"] = "noop" req.Data["description"] = "foo" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -609,7 +609,7 @@ func TestCore_HandleRequest_NoClientToken(t *testing.T) { Path: "foo/login", } req.ClientToken = root - if _, err := c.HandleRequest(req); err != nil { + if _, err := c.HandleRequest(context.Background(), req); err != nil { t.Fatalf("err: %v", err) } @@ -633,7 +633,7 @@ func TestCore_HandleRequest_ConnOnLogin(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/auth/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -643,7 +643,7 @@ func TestCore_HandleRequest_ConnOnLogin(t *testing.T) { Path: "auth/foo/login", Connection: &logical.Connection{}, } - if _, err := c.HandleRequest(req); err != nil { + if _, err := c.HandleRequest(context.Background(), req); err != nil { t.Fatalf("err: %v", err) } if noop.Requests[0].Connection == nil { @@ -674,7 +674,7 @@ func TestCore_HandleLogin_Token(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/auth/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -683,7 +683,7 @@ func TestCore_HandleLogin_Token(t *testing.T) { lreq := &logical.Request{ Path: "auth/foo/login", } - lresp, err := c.HandleRequest(lreq) + lresp, err := c.HandleRequest(context.Background(), lreq) if err != nil { t.Fatalf("err: %v", err) } @@ -738,7 +738,7 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/audit/noop") req.Data["type"] = "noop" req.ClientToken = root - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -754,7 +754,7 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { ClientToken: root, } req.ClientToken = root - if _, err := c.HandleRequest(req); err != nil { + if _, err := c.HandleRequest(context.Background(), req); err != nil { t.Fatalf("err: %v", err) } @@ -802,7 +802,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/secret/tune") req.Data["audit_non_hmac_request_keys"] = "foo" req.ClientToken = root - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -810,7 +810,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { req = logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/secret/tune") req.Data["audit_non_hmac_response_keys"] = "baz" req.ClientToken = root - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -819,7 +819,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { req = logical.TestRequest(t, logical.UpdateOperation, "sys/audit/noop") req.Data["type"] = "noop" req.ClientToken = root - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -834,7 +834,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { ClientToken: root, } req.ClientToken = root - if _, err := c.HandleRequest(req); err != nil { + if _, err := c.HandleRequest(context.Background(), req); err != nil { t.Fatalf("err: %v", err) } @@ -876,7 +876,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { ClientToken: root, } req.ClientToken = root - if _, err := c.HandleRequest(req); err != nil { + if _, err := c.HandleRequest(context.Background(), req); err != nil { t.Fatalf("err: %v", err) } if len(noop.RespNonHMACKeys) != 1 || noop.RespNonHMACKeys[0] != "baz" { @@ -920,7 +920,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/auth/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -929,7 +929,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { req = logical.TestRequest(t, logical.UpdateOperation, "sys/audit/noop") req.Data["type"] = "noop" req.ClientToken = root - _, err = c.HandleRequest(req) + _, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -938,7 +938,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { lreq := &logical.Request{ Path: "auth/foo/login", } - lresp, err := c.HandleRequest(lreq) + lresp, err := c.HandleRequest(context.Background(), lreq) if err != nil { t.Fatalf("err: %v", err) } @@ -983,7 +983,7 @@ func TestCore_HandleRequest_CreateToken_Lease(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/create") req.ClientToken = root req.Data["policies"] = []string{"foo"} - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1031,7 +1031,7 @@ func TestCore_HandleRequest_CreateToken_NoDefaultPolicy(t *testing.T) { req.ClientToken = root req.Data["policies"] = []string{"foo"} req.Data["no_default_policy"] = true - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1069,7 +1069,7 @@ func TestCore_LimitedUseToken(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/create") req.ClientToken = root req.Data["num_uses"] = "1" - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1083,13 +1083,13 @@ func TestCore_LimitedUseToken(t *testing.T) { }, ClientToken: resp.Auth.ClientToken, } - _, err = c.HandleRequest(req) + _, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } // Second operation should fail - _, err = c.HandleRequest(req) + _, err = c.HandleRequest(context.Background(), req) if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { t.Fatalf("err: %v", err) } @@ -1310,7 +1310,7 @@ func TestCore_StepDown(t *testing.T) { } // Step down core - err = core.StepDown(req) + err = core.StepDown(context.Background(), req) if err != nil { t.Fatal("error stepping down core 1") } @@ -1352,7 +1352,7 @@ func TestCore_StepDown(t *testing.T) { } // Step down core2 - err = core2.StepDown(req) + err = core2.StepDown(context.Background(), req) if err != nil { t.Fatal("error stepping down core 1") } @@ -1619,7 +1619,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical. }, ClientToken: root, } - _, err = core.HandleRequest(req) + _, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1668,7 +1668,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical. } // Request should fail in standby mode - _, err = core2.HandleRequest(req) + _, err = core2.HandleRequest(context.Background(), req) if err != consts.ErrStandby { t.Fatalf("err: %v", err) } @@ -1709,7 +1709,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical. Path: "secret/foo", ClientToken: root, } - resp, err := core2.HandleRequest(req) + resp, err := core2.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1772,7 +1772,7 @@ func TestCore_HandleRequest_Login_InternalData(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/auth/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1781,7 +1781,7 @@ func TestCore_HandleRequest_Login_InternalData(t *testing.T) { lreq := &logical.Request{ Path: "auth/foo/login", } - lresp, err := c.HandleRequest(lreq) + lresp, err := c.HandleRequest(context.Background(), lreq) if err != nil { t.Fatalf("err: %v", err) } @@ -1816,7 +1816,7 @@ func TestCore_HandleRequest_InternalData(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1827,7 +1827,7 @@ func TestCore_HandleRequest_InternalData(t *testing.T) { Path: "foo/test", ClientToken: root, } - lresp, err := c.HandleRequest(lreq) + lresp, err := c.HandleRequest(context.Background(), lreq) if err != nil { t.Fatalf("err: %v", err) } @@ -1859,7 +1859,7 @@ func TestCore_HandleLogin_ReturnSecret(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/auth/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1868,7 +1868,7 @@ func TestCore_HandleLogin_ReturnSecret(t *testing.T) { lreq := &logical.Request{ Path: "auth/foo/login", } - _, err = c.HandleRequest(lreq) + _, err = c.HandleRequest(context.Background(), lreq) if err != ErrInternalError { t.Fatalf("err: %v", err) } @@ -1888,7 +1888,7 @@ func TestCore_RenewSameLease(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1899,7 +1899,7 @@ func TestCore_RenewSameLease(t *testing.T) { // Read the key req.Operation = logical.ReadOperation req.Data = nil - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1911,7 +1911,7 @@ func TestCore_RenewSameLease(t *testing.T) { // Renew the lease req = logical.TestRequest(t, logical.UpdateOperation, "sys/renew/"+resp.Secret.LeaseID) req.ClientToken = root - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1924,7 +1924,7 @@ func TestCore_RenewSameLease(t *testing.T) { // Renew the lease (alternate path) req = logical.TestRequest(t, logical.UpdateOperation, "sys/leases/renew/"+resp.Secret.LeaseID) req.ClientToken = root - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1948,7 +1948,7 @@ func TestCore_RenewToken_SingleRegister(t *testing.T) { }, ClientToken: root, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1960,7 +1960,7 @@ func TestCore_RenewToken_SingleRegister(t *testing.T) { req.Data = map[string]interface{}{ "token": newClient, } - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1968,7 +1968,7 @@ func TestCore_RenewToken_SingleRegister(t *testing.T) { // Revoke using the renew prefix req = logical.TestRequest(t, logical.UpdateOperation, "sys/revoke-prefix/auth/token/renew/") req.ClientToken = root - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1979,7 +1979,7 @@ func TestCore_RenewToken_SingleRegister(t *testing.T) { "token": newClient, } req.ClientToken = newClient - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2023,7 +2023,7 @@ path "secret/*" { req := logical.TestRequest(t, logical.UpdateOperation, "sys/auth/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2032,7 +2032,7 @@ path "secret/*" { lreq := &logical.Request{ Path: "auth/foo/login", } - lresp, err := c.HandleRequest(lreq) + lresp, err := c.HandleRequest(context.Background(), lreq) if err == nil || lresp == nil || !lresp.IsError() { t.Fatalf("expected error trying to auth and receive root policy") } @@ -2042,7 +2042,7 @@ path "secret/*" { lreq = &logical.Request{ Path: "auth/foo/login", } - lresp, err = c.HandleRequest(lreq) + lresp, err = c.HandleRequest(context.Background(), lreq) if err != nil { t.Fatalf("err: %v", err) } @@ -2057,7 +2057,7 @@ path "secret/*" { }, ClientToken: lresp.Auth.ClientToken, } - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2068,7 +2068,7 @@ path "secret/*" { // Read the key req.Operation = logical.ReadOperation req.Data = nil - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2082,7 +2082,7 @@ path "secret/*" { "lease_id": resp.Secret.LeaseID, } req.ClientToken = lresp.Auth.ClientToken - _, err = c.HandleRequest(req) + _, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2090,7 +2090,7 @@ path "secret/*" { // Disable the credential backend req = logical.TestRequest(t, logical.DeleteOperation, "sys/auth/foo") req.ClientToken = root - resp, err = c.HandleRequest(req) + resp, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v %#v", err, resp) } @@ -2110,7 +2110,7 @@ func TestCore_HandleRequest_MountPointType(t *testing.T) { req.Data["type"] = "noop" req.Data["description"] = "foo" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2122,7 +2122,7 @@ func TestCore_HandleRequest_MountPointType(t *testing.T) { Connection: &logical.Connection{}, } req.ClientToken = root - if _, err := c.HandleRequest(req); err != nil { + if _, err := c.HandleRequest(context.Background(), req); err != nil { t.Fatalf("err: %v", err) } @@ -2194,7 +2194,7 @@ func TestCore_Standby_Rotate(t *testing.T) { Path: "sys/rotate", ClientToken: root, } - _, err = core.HandleRequest(req) + _, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2214,7 +2214,7 @@ func TestCore_Standby_Rotate(t *testing.T) { Path: "sys/key-status", ClientToken: root, } - resp, err := core2.HandleRequest(req) + resp, err := core2.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2242,7 +2242,7 @@ func TestCore_HandleRequest_Headers(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo") req.Data["type"] = "noop" req.ClientToken = root - _, err := c.HandleRequest(req) + _, err := c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2251,7 +2251,7 @@ func TestCore_HandleRequest_Headers(t *testing.T) { req = logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo/tune") req.Data["passthrough_request_headers"] = []string{"Should-Passthrough", "should-passthrough-case-insensitive"} req.ClientToken = root - _, err = c.HandleRequest(req) + _, err = c.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2267,7 +2267,7 @@ func TestCore_HandleRequest_Headers(t *testing.T) { "Should-Not-Passthrough": []string{"bar"}, }, } - _, err = c.HandleRequest(lreq) + _, err = c.HandleRequest(context.Background(), lreq) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/expiration.go b/vault/expiration.go index 246254d5b6..23c5eff23f 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -274,7 +274,7 @@ func (m *ExpirationManager) Tidy() error { if revokeLease { // Force the revocation and skip going through the token store // again - err = m.revokeCommon(leaseID, true, true) + err = m.revokeCommon(m.quitContext, leaseID, true, true) if err != nil { tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf(fmt.Sprintf("failed to revoke an invalid lease with ID %q: {{err}}", leaseID), err)) return @@ -487,10 +487,10 @@ func (m *ExpirationManager) Stop() error { } // Revoke is used to revoke a secret named by the given LeaseID -func (m *ExpirationManager) Revoke(leaseID string) error { +func (m *ExpirationManager) Revoke(ctx context.Context, leaseID string) error { defer metrics.MeasureSince([]string{"expire", "revoke"}, time.Now()) - return m.revokeCommon(leaseID, false, false) + return m.revokeCommon(ctx, leaseID, false, false) } // LazyRevoke is used to queue revocation for a secret named by the given @@ -527,7 +527,7 @@ func (m *ExpirationManager) LazyRevoke(leaseID string) error { // revokeCommon does the heavy lifting. If force is true, we ignore a problem // during revocation and still remove entries/index/lease timers -func (m *ExpirationManager) revokeCommon(leaseID string, force, skipToken bool) error { +func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, force, skipToken bool) error { defer metrics.MeasureSince([]string{"expire", "revoke-common"}, time.Now()) // Load the entry @@ -653,7 +653,7 @@ func (m *ExpirationManager) RevokeByToken(te *logical.TokenEntry) error { // we're already revoking the token, so we just want to clean up the lease. // This avoids spurious revocations later in the log when the timer runs // out, and eases up resource usage. - return m.revokeCommon(tokenLeaseID, false, true) + return m.revokeCommon(m.quitContext, tokenLeaseID, false, true) } return nil @@ -671,7 +671,7 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force, sync bool) le, err := m.loadEntry(prefix) if err == nil && le != nil { if sync { - if err := m.revokeCommon(prefix, force, false); err != nil { + if err := m.revokeCommon(m.quitContext, prefix, force, false); err != nil { return errwrap.Wrapf(fmt.Sprintf("failed to revoke %q: {{err}}", prefix), err) } return nil @@ -693,7 +693,7 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force, sync bool) leaseID := prefix + suffix switch { case sync: - if err := m.revokeCommon(leaseID, force, false); err != nil { + if err := m.revokeCommon(m.quitContext, leaseID, force, false); err != nil { return errwrap.Wrapf(fmt.Sprintf("failed to revoke %q (%d / %d): {{err}}", leaseID, idx+1, len(existing)), err) } default: @@ -1096,6 +1096,16 @@ func (m *ExpirationManager) expireID(leaseID string) { m.pendingLock.Unlock() for attempt := uint(0); attempt < maxRevokeAttempts; attempt++ { + ctx, cancel := context.WithTimeout(m.quitContext, DefaultMaxRequestDuration) + + go func() { + select { + case <-ctx.Done(): + case <-m.quitCh: + cancel() + } + }() + select { case <-m.quitCh: m.logger.Error("shutting down, not attempting further revocation of lease", "lease_id", leaseID) @@ -1107,8 +1117,9 @@ func (m *ExpirationManager) expireID(leaseID string) { } m.coreStateLock.RLock() - err := m.Revoke(leaseID) + err := m.Revoke(ctx, leaseID) m.coreStateLock.RUnlock() + cancel() if err == nil { return } diff --git a/vault/expiration_test.go b/vault/expiration_test.go index fe0f56e1c0..b3e6c49113 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -574,7 +574,7 @@ func TestExpiration_Revoke(t *testing.T) { t.Fatalf("err: %v", err) } - if err := exp.Revoke(id); err != nil { + if err := exp.Revoke(context.Background(), id); err != nil { t.Fatalf("err: %v", err) } @@ -1732,7 +1732,7 @@ func TestExpiration_RevokeForce(t *testing.T) { ClientToken: root, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatal(err) } @@ -1746,13 +1746,13 @@ func TestExpiration_RevokeForce(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "sys/revoke-prefix/badrenew/creds" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err == nil { t.Fatal("expected error") } req.Path = "sys/revoke-force/badrenew/creds" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("got error: %s", err) } @@ -1780,7 +1780,7 @@ func TestExpiration_RevokeForceSingle(t *testing.T) { ClientToken: root, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatal(err) } @@ -1795,7 +1795,7 @@ func TestExpiration_RevokeForceSingle(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "sys/leases/lookup" req.Data = map[string]interface{}{"lease_id": leaseID} - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatal(err) } @@ -1808,20 +1808,20 @@ func TestExpiration_RevokeForceSingle(t *testing.T) { req.Path = "sys/revoke-prefix/" + leaseID - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err == nil { t.Fatal("expected error") } req.Path = "sys/revoke-force/" + leaseID - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("got error: %s", err) } req.Path = "sys/leases/lookup" req.Data = map[string]interface{}{"lease_id": leaseID} - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err == nil { t.Fatal("expected error") } diff --git a/vault/ha.go b/vault/ha.go index 51fc9f3682..fec79e1759 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -142,7 +142,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr, clusterAddr string, err erro } // StepDown is used to step down from leadership -func (c *Core) StepDown(req *logical.Request) (retErr error) { +func (c *Core) StepDown(httpCtx context.Context, req *logical.Request) (retErr error) { defer metrics.MeasureSince([]string{"core", "step_down"}, time.Now()) if req == nil { @@ -159,7 +159,16 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) { return nil } - ctx := c.activeContext + ctx, cancel := context.WithCancel(c.activeContext) + defer cancel() + + go func() { + select { + case <-ctx.Done(): + case <-httpCtx.Done(): + cancel() + } + }() acl, te, entity, identityPolicies, err := c.fetchACLTokenEntryAndEntity(req) if err != nil { @@ -238,7 +247,7 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) { // we won't have a token store after sealing. leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te) if err == nil { - err = c.expiration.Revoke(leaseID) + err = c.expiration.Revoke(ctx, leaseID) } if err != nil { c.logger.Error("token needed revocation before step-down but failed to revoke", "error", err) diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index 0fcdad99bf..5996ecd5c6 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -67,7 +67,7 @@ func TestIdentityStore_EntityIDPassthrough(t *testing.T) { } // Make the request with the above created token - resp, err := core.HandleRequest(&logical.Request{ + resp, err := core.HandleRequest(context.Background(), &logical.Request{ ClientToken: "testtokenid", Operation: logical.ReadOperation, Path: "test/backend/foo", @@ -241,7 +241,7 @@ func TestIdentityStore_WrapInfoInheritance(t *testing.T) { }, } - resp, err = core.HandleRequest(wrapReq) + resp, err = core.HandleRequest(context.Background(), wrapReq) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("bad: resp: %#v, err: %v", resp, err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index f696424251..8942fe3df2 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2319,7 +2319,7 @@ func (b *SystemBackend) handleRevoke(ctx context.Context, req *logical.Request, if data.Get("sync").(bool) { // Invoke the expiration manager directly - if err := b.Core.expiration.Revoke(leaseID); err != nil { + if err := b.Core.expiration.Revoke(ctx, leaseID); err != nil { b.Backend.Logger().Error("lease revocation failed", "lease_id", leaseID, "error", err) return handleErrorNoReadOnlyForward(err) } diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index d633e83314..b929ef724f 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -1,6 +1,7 @@ package vault_test import ( + "context" "fmt" "io/ioutil" "os" @@ -28,7 +29,7 @@ func TestSystemBackend_Plugin_secret(t *testing.T) { // Make a request to lazy load the plugin req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -66,7 +67,7 @@ func TestSystemBackend_Plugin_auth(t *testing.T) { // Make a request to lazy load the plugin req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal") req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -108,7 +109,7 @@ func TestSystemBackend_Plugin_MismatchType(t *testing.T) { // and expect an error req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") req.ClientToken = core.Client.Token() - _, err := core.HandleRequest(req) + _, err := core.HandleRequest(context.Background(), req) if err == nil { t.Fatalf("expected error due to mismatch on error type: %s", err) } @@ -144,7 +145,7 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun // Remove the plugin from the catalog req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/mock-plugin") req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } @@ -229,7 +230,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc // Get the registered plugin req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin") req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || resp == nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } @@ -247,7 +248,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc "command": filepath.Base(command), } req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } @@ -293,7 +294,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc "plugin": "mock-plugin", } req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } @@ -309,7 +310,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc req = logical.TestRequest(t, logical.ReadOperation, reqPath) req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -328,7 +329,7 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal") req.ClientToken = core.Client.Token() req.Data["value"] = "baz" - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -339,7 +340,7 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) { // Call errors/rpc endpoint to trigger reload req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc") req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err == nil { t.Fatalf("expected error from error/rpc request") } @@ -347,7 +348,7 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) { // Check internal value to make sure it's reset req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index cdf61e2d27..fb18d51529 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -680,7 +680,7 @@ func TestSystemBackend_leases(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "secret/foo") req.Data["foo"] = "bar" req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -691,7 +691,7 @@ func TestSystemBackend_leases(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -726,7 +726,7 @@ func TestSystemBackend_leases_list(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "secret/foo") req.Data["foo"] = "bar" req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -737,7 +737,7 @@ func TestSystemBackend_leases_list(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -785,7 +785,7 @@ func TestSystemBackend_leases_list(t *testing.T) { // Generate multiple leases req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -795,7 +795,7 @@ func TestSystemBackend_leases_list(t *testing.T) { req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -823,7 +823,7 @@ func TestSystemBackend_leases_list(t *testing.T) { req = logical.TestRequest(t, logical.UpdateOperation, "secret/bar") req.Data["foo"] = "bar" req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -834,7 +834,7 @@ func TestSystemBackend_leases_list(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/bar") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -870,7 +870,7 @@ func TestSystemBackend_renew(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "secret/foo") req.Data["foo"] = "bar" req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -881,7 +881,7 @@ func TestSystemBackend_renew(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -906,7 +906,7 @@ func TestSystemBackend_renew(t *testing.T) { req.Data["foo"] = "bar" req.Data["ttl"] = "180s" req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -917,7 +917,7 @@ func TestSystemBackend_renew(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1034,7 +1034,7 @@ func TestSystemBackend_revoke(t *testing.T) { req.Data["foo"] = "bar" req.Data["lease"] = "1h" req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1045,7 +1045,7 @@ func TestSystemBackend_revoke(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1076,7 +1076,7 @@ func TestSystemBackend_revoke(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1098,7 +1098,7 @@ func TestSystemBackend_revoke(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1176,7 +1176,7 @@ func TestSystemBackend_revokePrefix(t *testing.T) { req.Data["foo"] = "bar" req.Data["lease"] = "1h" req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1187,7 +1187,7 @@ func TestSystemBackend_revokePrefix(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1224,7 +1224,7 @@ func TestSystemBackend_revokePrefix_origUrl(t *testing.T) { req.Data["foo"] = "bar" req.Data["lease"] = "1h" req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -1235,7 +1235,7 @@ func TestSystemBackend_revokePrefix_origUrl(t *testing.T) { // Read a key with a LeaseID req = logical.TestRequest(t, logical.ReadOperation, "secret/foo") req.ClientToken = root - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/mount_test.go b/vault/mount_test.go index e773003571..880460fafd 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -290,7 +290,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) { Path: "test/foo", ClientToken: root, } - resp, err := c.HandleRequest(r) + resp, err := c.HandleRequest(context.Background(), r) if err != nil { t.Fatalf("err: %v", err) } @@ -410,7 +410,7 @@ func TestCore_Remount_Cleanup(t *testing.T) { Path: "test/foo", ClientToken: root, } - resp, err := c.HandleRequest(r) + resp, err := c.HandleRequest(context.Background(), r) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/request_handling.go b/vault/request_handling.go index a8cd8af76e..cdc8c6af08 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -26,11 +26,20 @@ const ( replTimeout = 10 * time.Second ) +var ( + // DefaultMaxRequestDuration is the amount of time we'll wait for a request + // to complete, unless overridden on a per-handler basis + // FIXME: In 0.11 make this 90 seconds; for now keep it at essentially infinity if not set explicitly + //DefaultMaxRequestDuration = 90 * time.Second + DefaultMaxRequestDuration = 999999 * time.Hour +) + // HanlderProperties is used to seed configuration into a vaulthttp.Handler. // It's in this package to avoid a circular dependency type HandlerProperties struct { Core *Core MaxRequestSize int64 + MaxRequestDuration time.Duration DisablePrintableCheck bool } @@ -265,7 +274,7 @@ func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool } // HandleRequest is used to handle a new incoming request -func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) { +func (c *Core) HandleRequest(httpCtx context.Context, req *logical.Request) (resp *logical.Response, err error) { c.stateLock.RLock() defer c.stateLock.RUnlock() if c.Sealed() { @@ -278,6 +287,14 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err ctx, cancel := context.WithCancel(c.activeContext) defer cancel() + go func() { + select { + case <-ctx.Done(): + case <-httpCtx.Done(): + cancel() + } + }() + // Allowing writing to a path ending in / makes it extremely difficult to // understand user intent for the filesystem-like backends (kv, // cubbyhole) -- did they want a key named foo/ or did they want to write @@ -430,7 +447,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp defer func(id string) { leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te) if err == nil { - err = c.expiration.Revoke(leaseID) + err = c.expiration.Revoke(ctx, leaseID) } if err != nil { c.logger.Error("failed to revoke token", "error", err) diff --git a/vault/request_handling_test.go b/vault/request_handling_test.go index 2a47c3f9d8..7be38c9155 100644 --- a/vault/request_handling_test.go +++ b/vault/request_handling_test.go @@ -35,7 +35,7 @@ func TestRequestHandling_Wrapping(t *testing.T) { "zip": "zap", }, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -51,7 +51,7 @@ func TestRequestHandling_Wrapping(t *testing.T) { TTL: time.Duration(15 * time.Second), }, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -82,7 +82,7 @@ func TestRequestHandling_LoginWrapping(t *testing.T) { }, Connection: &logical.Connection{}, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -95,7 +95,7 @@ func TestRequestHandling_LoginWrapping(t *testing.T) { "password": "foo", "policies": "default", } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -111,7 +111,7 @@ func TestRequestHandling_LoginWrapping(t *testing.T) { }, Connection: &logical.Connection{}, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -133,7 +133,7 @@ func TestRequestHandling_LoginWrapping(t *testing.T) { }, Connection: &logical.Connection{}, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/rollback.go b/vault/rollback.go index 9954246319..5ba3ad0707 100644 --- a/vault/rollback.go +++ b/vault/rollback.go @@ -168,7 +168,10 @@ func (m *RollbackManager) attemptRollback(ctx context.Context, path string, rs * Operation: logical.RollbackOperation, Path: path, } + var cancelFunc context.CancelFunc + ctx, cancelFunc = context.WithTimeout(ctx, DefaultMaxRequestDuration) _, err = m.router.Route(ctx, req) + cancelFunc() // If the error is an unsupported operation, then it doesn't // matter, the backend doesn't support it. diff --git a/vault/testing.go b/vault/testing.go index 9bb3e6f587..1c0c7484a6 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1230,7 +1230,8 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te cores = append(cores, c) if opts != nil && opts.HandlerFunc != nil { handlers[i] = opts.HandlerFunc(&HandlerProperties{ - Core: c, + Core: c, + MaxRequestDuration: DefaultMaxRequestDuration, }) servers[i].Handler = handlers[i] } diff --git a/vault/token_store.go b/vault/token_store.go index c62ab4fa52..a8e57d4109 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -991,7 +991,7 @@ func (ts *TokenStore) lookupSalted(ctx context.Context, saltedID string, tainted return nil, err } - err = ts.expiration.Revoke(leaseID) + err = ts.expiration.Revoke(ctx, leaseID) if err != nil { return nil, err } @@ -1599,7 +1599,7 @@ func (ts *TokenStore) handleUpdateRevokeAccessor(ctx context.Context, req *logic return nil, err } - err = ts.expiration.Revoke(leaseID) + err = ts.expiration.Revoke(ctx, leaseID) if err != nil { return nil, err } @@ -2054,7 +2054,7 @@ func (ts *TokenStore) handleRevokeSelf(ctx context.Context, req *logical.Request return nil, err } - err = ts.expiration.Revoke(leaseID) + err = ts.expiration.Revoke(ctx, leaseID) if err != nil { return nil, err } @@ -2090,7 +2090,7 @@ func (ts *TokenStore) handleRevokeTree(ctx context.Context, req *logical.Request return nil, err } - err = ts.expiration.Revoke(leaseID) + err = ts.expiration.Revoke(ctx, leaseID) if err != nil { return nil, err } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 80dc81dce1..810b4385fa 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -333,7 +333,7 @@ func testMakeTokenViaCore(t testing.TB, c *Core, root, client, ttl string, polic req.Data["policies"] = policy req.Data["ttl"] = ttl - resp, err := c.HandleRequest(req) + resp, err := c.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2085,7 +2085,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { req := logical.TestRequest(t, logical.ReadOperation, "auth/token/roles/test") req.ClientToken = root - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2102,7 +2102,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { "path_suffix": "happenin", } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2114,7 +2114,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { req.Operation = logical.ReadOperation req.Data = map[string]interface{}{} - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2148,7 +2148,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { "renewable": false, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2160,7 +2160,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { req.Operation = logical.ReadOperation req.Data = map[string]interface{}{} - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2189,7 +2189,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { "explicit_max_ttl": "5", "period": "0s", } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2197,7 +2197,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { req.Operation = logical.ReadOperation req.Data = map[string]interface{}{} - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2223,7 +2223,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { req.Operation = logical.ListOperation req.Path = "auth/token/roles" req.Data = map[string]interface{}{} - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2247,7 +2247,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { req.Operation = logical.DeleteOperation req.Path = "auth/token/roles/test" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2256,7 +2256,7 @@ func TestTokenStore_RoleCRUD(t *testing.T) { } req.Operation = logical.ReadOperation - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2598,7 +2598,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { "period": 5, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2615,7 +2615,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.Data = map[string]interface{}{ "policies": []string{"default"}, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2626,7 +2626,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2640,14 +2640,14 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/renew-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2663,14 +2663,14 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.Data = map[string]interface{}{ "increment": 1, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2686,7 +2686,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.ClientToken = root req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2703,7 +2703,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2720,14 +2720,14 @@ func TestTokenStore_RolePeriod(t *testing.T) { req.Data = map[string]interface{}{ "increment": 1, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2754,7 +2754,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { "explicit_max_ttl": "100h", } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2764,7 +2764,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("expected an error") } @@ -2779,7 +2779,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { "explicit_max_ttl": "10s", } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2795,7 +2795,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Data = map[string]interface{}{ "policies": []string{"default"}, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2806,7 +2806,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2820,14 +2820,14 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/renew-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2843,7 +2843,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.ClientToken = root req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2860,7 +2860,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2882,14 +2882,14 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Data = map[string]interface{}{ "increment": 300, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2906,14 +2906,14 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Data = map[string]interface{}{ "increment": 300, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -2930,7 +2930,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Data = map[string]interface{}{ "increment": 300, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err == nil { t.Fatalf("expected error") } @@ -2939,7 +2939,7 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if resp != nil && err == nil { t.Fatalf("expected error, response is %#v", *resp) } @@ -2964,7 +2964,7 @@ func TestTokenStore_Periodic(t *testing.T) { "period": 5, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -2977,7 +2977,7 @@ func TestTokenStore_Periodic(t *testing.T) { req.ClientToken = root req.Operation = logical.UpdateOperation req.Path = "auth/token/create" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatal(err) } @@ -2994,7 +2994,7 @@ func TestTokenStore_Periodic(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3011,14 +3011,14 @@ func TestTokenStore_Periodic(t *testing.T) { req.Data = map[string]interface{}{ "increment": 1, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3038,7 +3038,7 @@ func TestTokenStore_Periodic(t *testing.T) { req.Data = map[string]interface{}{ "period": 5, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -3055,7 +3055,7 @@ func TestTokenStore_Periodic(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3072,14 +3072,14 @@ func TestTokenStore_Periodic(t *testing.T) { req.Data = map[string]interface{}{ "increment": 1, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3105,7 +3105,7 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { "period": 5, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -3122,7 +3122,7 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { "period": 5, "explicit_max_ttl": 4, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatal(err) } @@ -3139,7 +3139,7 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3156,14 +3156,14 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { req.Data = map[string]interface{}{ "increment": 76, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3185,7 +3185,7 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { "explicit_max_ttl": 4, } - resp, err := core.HandleRequest(req) + resp, err := core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -3196,7 +3196,7 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { req.ClientToken = root req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } @@ -3213,7 +3213,7 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { req.ClientToken = resp.Auth.ClientToken req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) } @@ -3230,14 +3230,14 @@ func TestTokenStore_Periodic_ExplicitMax(t *testing.T) { req.Data = map[string]interface{}{ "increment": 1, } - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err: %v\nresp: %#v", err, resp) } req.Operation = logical.ReadOperation req.Path = "auth/token/lookup-self" - resp, err = core.HandleRequest(req) + resp, err = core.HandleRequest(context.Background(), req) if err != nil { t.Fatalf("err: %v", err) }