Add context to storage backends and wire it through a lot of places (#3817)

This commit is contained in:
Brian Kassouf
2018-01-18 22:44:44 -08:00
committed by Jeff Mitchell
parent 2864fbd697
commit 8142b42d95
341 changed files with 3417 additions and 3083 deletions

View File

@@ -15,7 +15,7 @@ import (
// BackendPlugin is the plugin.Plugin implementation
type BackendPlugin struct {
Factory func(*logical.BackendConfig) (logical.Backend, error)
Factory logical.Factory
metadataMode bool
Logger hclog.Logger
}

View File

@@ -173,11 +173,11 @@ func (b *backendPluginClient) HandleExistenceCheck(ctx context.Context, req *log
return reply.CheckFound, reply.Exists, nil
}
func (b *backendPluginClient) Cleanup() {
func (b *backendPluginClient) Cleanup(ctx context.Context) {
b.client.Call("Plugin.Cleanup", new(interface{}), &struct{}{})
}
func (b *backendPluginClient) Initialize() error {
func (b *backendPluginClient) Initialize(ctx context.Context) error {
if b.metadataMode {
return ErrClientInMetadataMode
}
@@ -185,14 +185,14 @@ func (b *backendPluginClient) Initialize() error {
return err
}
func (b *backendPluginClient) InvalidateKey(key string) {
func (b *backendPluginClient) InvalidateKey(ctx context.Context, key string) {
if b.metadataMode {
return
}
b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
}
func (b *backendPluginClient) Setup(config *logical.BackendConfig) error {
func (b *backendPluginClient) Setup(ctx context.Context, config *logical.BackendConfig) error {
// Shim logical.Storage
storageImpl := config.StorageView
if b.metadataMode {

View File

@@ -20,7 +20,7 @@ var (
type backendPluginServer struct {
broker *plugin.MuxBroker
backend logical.Backend
factory func(*logical.BackendConfig) (logical.Backend, error)
factory logical.Factory
loggerClient *rpc.Client
sysViewClient *rpc.Client
@@ -39,7 +39,7 @@ func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *Hand
storage := &StorageClient{client: b.storageClient}
args.Request.Storage = storage
resp, err := b.backend.HandleRequest(context.TODO(), args.Request)
resp, err := b.backend.HandleRequest(context.Background(), args.Request)
*reply = HandleRequestReply{
Response: resp,
Error: wrapError(err),
@@ -74,7 +74,7 @@ func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArg
}
func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
b.backend.Cleanup()
b.backend.Cleanup(context.Background())
// Close rpc clients
b.loggerClient.Close()
@@ -88,7 +88,7 @@ func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error {
return ErrServerInMetadataMode
}
err := b.backend.Initialize()
err := b.backend.Initialize(context.Background())
return err
}
@@ -97,7 +97,7 @@ func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
return ErrServerInMetadataMode
}
b.backend.InvalidateKey(args)
b.backend.InvalidateKey(context.Background(), args)
return nil
}
@@ -153,7 +153,7 @@ func (b *backendPluginServer) Setup(args *SetupArgs, reply *SetupReply) error {
// Call the underlying backend factory after shims have been created
// to set b.backend
backend, err := b.factory(config)
backend, err := b.factory(context.Background(), config)
if err != nil {
*reply = SetupReply{
Error: wrapError(err),

View File

@@ -97,14 +97,14 @@ func TestBackendPlugin_Cleanup(t *testing.T) {
b, cleanup := testBackend(t)
defer cleanup()
b.Cleanup()
b.Cleanup(context.Background())
}
func TestBackendPlugin_Initialize(t *testing.T) {
b, cleanup := testBackend(t)
defer cleanup()
err := b.Initialize()
err := b.Initialize(context.Background())
if err != nil {
t.Fatal(err)
}
@@ -114,7 +114,9 @@ func TestBackendPlugin_InvalidateKey(t *testing.T) {
b, cleanup := testBackend(t)
defer cleanup()
resp, err := b.HandleRequest(context.Background(), &logical.Request{
ctx := context.Background()
resp, err := b.HandleRequest(ctx, &logical.Request{
Operation: logical.ReadOperation,
Path: "internal",
})
@@ -125,9 +127,9 @@ func TestBackendPlugin_InvalidateKey(t *testing.T) {
t.Fatalf("bad: %#v, expected non-empty value", resp)
}
b.InvalidateKey("internal")
b.InvalidateKey(ctx, "internal")
resp, err = b.HandleRequest(context.Background(), &logical.Request{
resp, err = b.HandleRequest(ctx, &logical.Request{
Operation: logical.ReadOperation,
Path: "internal",
})
@@ -163,7 +165,7 @@ func testBackend(t *testing.T) (logical.Backend, func()) {
}
b := raw.(logical.Backend)
err = b.Setup(&logical.BackendConfig{
err = b.Setup(context.Background(), &logical.BackendConfig{
Logger: logformat.NewVaultLogger(log.LevelTrace),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 300 * time.Second,

View File

@@ -15,6 +15,9 @@ import (
var ErrPluginShutdown = errors.New("plugin is shut down")
// Validate backendGRPCPluginClient satisfies the logical.Backend interface
var _ logical.Backend = &backendGRPCPluginClient{}
// backendPluginClient implements logical.Backend and is the
// go-plugin client.
type backendGRPCPluginClient struct {
@@ -126,33 +129,49 @@ func (b *backendGRPCPluginClient) HandleExistenceCheck(ctx context.Context, req
return reply.CheckFound, reply.Exists, nil
}
func (b *backendGRPCPluginClient) Cleanup() {
b.client.Cleanup(b.doneCtx, &pb.Empty{})
func (b *backendGRPCPluginClient) Cleanup(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
defer close(quitCh)
defer cancel()
b.client.Cleanup(ctx, &pb.Empty{})
if b.server != nil {
b.server.GracefulStop()
}
b.clientConn.Close()
}
func (b *backendGRPCPluginClient) Initialize() error {
func (b *backendGRPCPluginClient) Initialize(ctx context.Context) error {
if b.metadataMode {
return ErrClientInMetadataMode
}
_, err := b.client.Initialize(b.doneCtx, &pb.Empty{})
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
defer close(quitCh)
defer cancel()
_, err := b.client.Initialize(ctx, &pb.Empty{})
return err
}
func (b *backendGRPCPluginClient) InvalidateKey(key string) {
func (b *backendGRPCPluginClient) InvalidateKey(ctx context.Context, key string) {
if b.metadataMode {
return
}
b.client.InvalidateKey(b.doneCtx, &pb.InvalidateKeyArgs{
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
defer close(quitCh)
defer cancel()
b.client.InvalidateKey(ctx, &pb.InvalidateKeyArgs{
Key: key,
})
}
func (b *backendGRPCPluginClient) Setup(config *logical.BackendConfig) error {
func (b *backendGRPCPluginClient) Setup(ctx context.Context, config *logical.BackendConfig) error {
// Shim logical.Storage
storageImpl := config.StorageView
if b.metadataMode {
@@ -187,7 +206,12 @@ func (b *backendGRPCPluginClient) Setup(config *logical.BackendConfig) error {
Config: config.Config,
}
reply, err := b.client.Setup(b.doneCtx, args)
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
defer close(quitCh)
defer cancel()
reply, err := b.client.Setup(ctx, args)
if err != nil {
return err
}

View File

@@ -14,7 +14,7 @@ type backendGRPCPluginServer struct {
broker *plugin.GRPCBroker
backend logical.Backend
factory func(*logical.BackendConfig) (logical.Backend, error)
factory logical.Factory
brokeredClient *grpc.ClientConn
@@ -43,7 +43,7 @@ func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs)
// Call the underlying backend factory after shims have been created
// to set b.backend
backend, err := b.factory(config)
backend, err := b.factory(ctx, config)
if err != nil {
return &pb.SetupReply{
Err: pb.ErrToString(err),
@@ -112,7 +112,7 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args
}
func (b *backendGRPCPluginServer) Cleanup(ctx context.Context, _ *pb.Empty) (*pb.Empty, error) {
b.backend.Cleanup()
b.backend.Cleanup(ctx)
// Close rpc clients
b.brokeredClient.Close()
@@ -124,7 +124,7 @@ func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.Empty) (
return &pb.Empty{}, ErrServerInMetadataMode
}
err := b.backend.Initialize()
err := b.backend.Initialize(ctx)
return &pb.Empty{}, err
}
@@ -133,7 +133,7 @@ func (b *backendGRPCPluginServer) InvalidateKey(ctx context.Context, args *pb.In
return &pb.Empty{}, ErrServerInMetadataMode
}
b.backend.InvalidateKey(args.Key)
b.backend.InvalidateKey(ctx, args.Key)
return &pb.Empty{}, nil
}

View File

@@ -99,14 +99,14 @@ func TestGRPCBackendPlugin_Cleanup(t *testing.T) {
b, cleanup := testGRPCBackend(t)
defer cleanup()
b.Cleanup()
b.Cleanup(context.Background())
}
func TestGRPCBackendPlugin_Initialize(t *testing.T) {
b, cleanup := testGRPCBackend(t)
defer cleanup()
err := b.Initialize()
err := b.Initialize(context.Background())
if err != nil {
t.Fatal(err)
}
@@ -116,7 +116,9 @@ func TestGRPCBackendPlugin_InvalidateKey(t *testing.T) {
b, cleanup := testGRPCBackend(t)
defer cleanup()
resp, err := b.HandleRequest(context.Background(), &logical.Request{
ctx := context.Background()
resp, err := b.HandleRequest(ctx, &logical.Request{
Operation: logical.ReadOperation,
Path: "internal",
})
@@ -127,9 +129,9 @@ func TestGRPCBackendPlugin_InvalidateKey(t *testing.T) {
t.Fatalf("bad: %#v, expected non-empty value", resp)
}
b.InvalidateKey("internal")
b.InvalidateKey(ctx, "internal")
resp, err = b.HandleRequest(context.Background(), &logical.Request{
resp, err = b.HandleRequest(ctx, &logical.Request{
Operation: logical.ReadOperation,
Path: "internal",
})
@@ -170,7 +172,7 @@ func testGRPCBackend(t *testing.T) (logical.Backend, func()) {
}
b := raw.(logical.Backend)
err = b.Setup(&logical.BackendConfig{
err = b.Setup(context.Background(), &logical.BackendConfig{
Logger: logformat.NewVaultLogger(log.LevelTrace),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 300 * time.Second,

View File

@@ -22,8 +22,8 @@ type GRPCStorageClient struct {
client pb.StorageClient
}
func (s *GRPCStorageClient) List(prefix string) ([]string, error) {
reply, err := s.client.List(context.Background(), &pb.StorageListArgs{
func (s *GRPCStorageClient) List(ctx context.Context, prefix string) ([]string, error) {
reply, err := s.client.List(ctx, &pb.StorageListArgs{
Prefix: prefix,
})
if err != nil {
@@ -35,8 +35,8 @@ func (s *GRPCStorageClient) List(prefix string) ([]string, error) {
return reply.Keys, nil
}
func (s *GRPCStorageClient) Get(key string) (*logical.StorageEntry, error) {
reply, err := s.client.Get(context.Background(), &pb.StorageGetArgs{
func (s *GRPCStorageClient) Get(ctx context.Context, key string) (*logical.StorageEntry, error) {
reply, err := s.client.Get(ctx, &pb.StorageGetArgs{
Key: key,
})
if err != nil {
@@ -48,8 +48,8 @@ func (s *GRPCStorageClient) Get(key string) (*logical.StorageEntry, error) {
return pb.ProtoStorageEntryToLogicalStorageEntry(reply.Entry), nil
}
func (s *GRPCStorageClient) Put(entry *logical.StorageEntry) error {
reply, err := s.client.Put(context.Background(), &pb.StoragePutArgs{
func (s *GRPCStorageClient) Put(ctx context.Context, entry *logical.StorageEntry) error {
reply, err := s.client.Put(ctx, &pb.StoragePutArgs{
Entry: pb.LogicalStorageEntryToProtoStorageEntry(entry),
})
if err != nil {
@@ -61,8 +61,8 @@ func (s *GRPCStorageClient) Put(entry *logical.StorageEntry) error {
return nil
}
func (s *GRPCStorageClient) Delete(key string) error {
reply, err := s.client.Delete(context.Background(), &pb.StorageDeleteArgs{
func (s *GRPCStorageClient) Delete(ctx context.Context, key string) error {
reply, err := s.client.Delete(ctx, &pb.StorageDeleteArgs{
Key: key,
})
if err != nil {
@@ -80,7 +80,7 @@ type GRPCStorageServer struct {
}
func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs) (*pb.StorageListReply, error) {
keys, err := s.impl.List(args.Prefix)
keys, err := s.impl.List(ctx, args.Prefix)
return &pb.StorageListReply{
Keys: keys,
Err: pb.ErrToString(err),
@@ -88,7 +88,7 @@ func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs)
}
func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*pb.StorageGetReply, error) {
storageEntry, err := s.impl.Get(args.Key)
storageEntry, err := s.impl.Get(ctx, args.Key)
return &pb.StorageGetReply{
Entry: pb.LogicalStorageEntryToProtoStorageEntry(storageEntry),
Err: pb.ErrToString(err),
@@ -96,14 +96,14 @@ func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*
}
func (s *GRPCStorageServer) Put(ctx context.Context, args *pb.StoragePutArgs) (*pb.StoragePutReply, error) {
err := s.impl.Put(pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
err := s.impl.Put(ctx, pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
return &pb.StoragePutReply{
Err: pb.ErrToString(err),
}, nil
}
func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteArgs) (*pb.StorageDeleteReply, error) {
err := s.impl.Delete(args.Key)
err := s.impl.Delete(ctx, args.Key)
return &pb.StorageDeleteReply{
Err: pb.ErrToString(err),
}, nil

View File

@@ -45,8 +45,8 @@ func (s *gRPCSystemViewClient) MaxLeaseTTL() time.Duration {
return time.Duration(reply.TTL)
}
func (s *gRPCSystemViewClient) SudoPrivilege(path string, token string) bool {
reply, err := s.client.SudoPrivilege(context.Background(), &pb.SudoPrivilegeArgs{
func (s *gRPCSystemViewClient) SudoPrivilege(ctx context.Context, path string, token string) bool {
reply, err := s.client.SudoPrivilege(ctx, &pb.SudoPrivilegeArgs{
Path: path,
Token: token,
})
@@ -84,13 +84,13 @@ func (s *gRPCSystemViewClient) ReplicationState() consts.ReplicationState {
return consts.ReplicationState(reply.State)
}
func (s *gRPCSystemViewClient) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
buf, err := json.Marshal(data)
if err != nil {
return nil, err
}
reply, err := s.client.ResponseWrapData(context.Background(), &pb.ResponseWrapDataArgs{
reply, err := s.client.ResponseWrapData(ctx, &pb.ResponseWrapDataArgs{
Data: buf,
TTL: int64(ttl),
JWT: false,
@@ -110,7 +110,7 @@ func (s *gRPCSystemViewClient) ResponseWrapData(data map[string]interface{}, ttl
return info, nil
}
func (s *gRPCSystemViewClient) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
func (s *gRPCSystemViewClient) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) {
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
}
@@ -142,7 +142,7 @@ func (s *gRPCSystemViewServer) MaxLeaseTTL(ctx context.Context, _ *pb.Empty) (*p
}
func (s *gRPCSystemViewServer) SudoPrivilege(ctx context.Context, args *pb.SudoPrivilegeArgs) (*pb.SudoPrivilegeReply, error) {
sudo := s.impl.SudoPrivilege(args.Path, args.Token)
sudo := s.impl.SudoPrivilege(ctx, args.Path, args.Token)
return &pb.SudoPrivilegeReply{
Sudo: sudo,
}, nil
@@ -177,7 +177,7 @@ func (s *gRPCSystemViewServer) ResponseWrapData(ctx context.Context, args *pb.Re
}
// Do not allow JWTs to be returned
info, err := s.impl.ResponseWrapData(data, time.Duration(args.TTL), false)
info, err := s.impl.ResponseWrapData(ctx, data, time.Duration(args.TTL), false)
if err != nil {
return &pb.ResponseWrapDataReply{
Err: pb.ErrToString(err),

View File

@@ -1,6 +1,7 @@
package plugin
import (
"context"
"testing"
"google.golang.org/grpc"
@@ -62,8 +63,10 @@ func TestSystem_GRPC_sudoPrivilege(t *testing.T) {
defer client.Close()
testSystemView := newGRPCSystemView(client)
expected := sys.SudoPrivilege("foo", "bar")
actual := testSystemView.SudoPrivilege("foo", "bar")
ctx := context.Background()
expected := sys.SudoPrivilege(ctx, "foo", "bar")
actual := testSystemView.SudoPrivilege(ctx, "foo", "bar")
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("expected: %v, got: %v", expected, actual)
}
@@ -138,7 +141,7 @@ func TestSystem_GRPC_lookupPlugin(t *testing.T) {
testSystemView := newGRPCSystemView(client)
if _, err := testSystemView.LookupPlugin("foo"); err == nil {
if _, err := testSystemView.LookupPlugin(context.Background(), "foo"); err == nil {
t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin")
}
}

View File

@@ -18,6 +18,9 @@ type backendTracingMiddleware struct {
next logical.Backend
}
// Validate the backendTracingMiddle object satisfies the backend interface
var _ logical.Backend = &backendTracingMiddleware{}
func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) {
defer func(then time.Time) {
b.logger.Trace("plugin.HandleRequest", "path", req.Path, "status", "finished", "type", b.typeStr, "transport", b.transport, "err", err, "took", time.Since(then))
@@ -53,40 +56,40 @@ func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req
return b.next.HandleExistenceCheck(ctx, req)
}
func (b *backendTracingMiddleware) Cleanup() {
func (b *backendTracingMiddleware) Cleanup(ctx context.Context) {
defer func(then time.Time) {
b.logger.Trace("plugin.Cleanup", "status", "finished", "type", b.typeStr, "transport", b.transport, "took", time.Since(then))
}(time.Now())
b.logger.Trace("plugin.Cleanup", "status", "started", "type", b.typeStr, "transport", b.transport)
b.next.Cleanup()
b.next.Cleanup(ctx)
}
func (b *backendTracingMiddleware) Initialize() (err error) {
func (b *backendTracingMiddleware) Initialize(ctx context.Context) (err error) {
defer func(then time.Time) {
b.logger.Trace("plugin.Initialize", "status", "finished", "type", b.typeStr, "transport", b.transport, "err", err, "took", time.Since(then))
}(time.Now())
b.logger.Trace("plugin.Initialize", "status", "started", "type", b.typeStr, "transport", b.transport)
return b.next.Initialize()
return b.next.Initialize(ctx)
}
func (b *backendTracingMiddleware) InvalidateKey(key string) {
func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string) {
defer func(then time.Time) {
b.logger.Trace("plugin.InvalidateKey", "key", key, "status", "finished", "type", b.typeStr, "transport", b.transport, "took", time.Since(then))
}(time.Now())
b.logger.Trace("plugin.InvalidateKey", "key", key, "status", "started", "type", b.typeStr, "transport", b.transport)
b.next.InvalidateKey(key)
b.next.InvalidateKey(ctx, key)
}
func (b *backendTracingMiddleware) Setup(config *logical.BackendConfig) (err error) {
func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) {
defer func(then time.Time) {
b.logger.Trace("plugin.Setup", "status", "finished", "type", b.typeStr, "transport", b.transport, "err", err, "took", time.Since(then))
}(time.Now())
b.logger.Trace("plugin.Setup", "status", "started", "type", b.typeStr, "transport", b.transport)
return b.next.Setup(config)
return b.next.Setup(ctx, config)
}
func (b *backendTracingMiddleware) Type() logical.BackendType {

View File

@@ -1,6 +1,8 @@
package mock
import (
"context"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@@ -12,9 +14,9 @@ func New() (interface{}, error) {
}
// Factory returns a new backend as logical.Backend.
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -22,11 +24,11 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
// FactoryType is a wrapper func that allows the Factory func to specify
// the backend type for the mock backend plugin instance.
func FactoryType(backendType logical.BackendType) func(*logical.BackendConfig) (logical.Backend, error) {
return func(conf *logical.BackendConfig) (logical.Backend, error) {
func FactoryType(backendType logical.BackendType) logical.Factory {
return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
b.BackendType = backendType
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -66,7 +68,7 @@ type backend struct {
internal string
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "internal":
b.internal = ""

View File

@@ -36,7 +36,7 @@ func kvPaths(b *backend) []*framework.Path {
}
func (b *backend) pathExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
out, err := req.Storage.Get(req.Path)
out, err := req.Storage.Get(ctx, req.Path)
if err != nil {
return false, fmt.Errorf("existence check failed: %v", err)
}
@@ -45,7 +45,7 @@ func (b *backend) pathExistenceCheck(ctx context.Context, req *logical.Request,
}
func (b *backend) pathKVRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(req.Path)
entry, err := req.Storage.Get(ctx, req.Path)
if err != nil {
return nil, err
}
@@ -75,7 +75,7 @@ func (b *backend) pathKVCreateUpdate(ctx context.Context, req *logical.Request,
}
s := req.Storage
err := s.Put(entry)
err := s.Put(ctx, entry)
if err != nil {
return nil, err
}
@@ -88,7 +88,7 @@ func (b *backend) pathKVCreateUpdate(ctx context.Context, req *logical.Request,
}
func (b *backend) pathKVDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if err := req.Storage.Delete(req.Path); err != nil {
if err := req.Storage.Delete(ctx, req.Path); err != nil {
return nil, err
}
@@ -96,7 +96,7 @@ func (b *backend) pathKVDelete(ctx context.Context, req *logical.Request, data *
}
func (b *backend) pathKVList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
vals, err := req.Storage.List("kv/")
vals, err := req.Storage.List(ctx, "kv/")
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package plugin
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"encoding/gob"
@@ -50,8 +51,8 @@ type BackendPluginClient struct {
// Cleanup calls the RPC client's Cleanup() func and also calls
// the go-plugin's client Kill() func
func (b *BackendPluginClient) Cleanup() {
b.Backend.Cleanup()
func (b *BackendPluginClient) Cleanup(ctx context.Context) {
b.Backend.Cleanup(ctx)
b.client.Kill()
}
@@ -59,9 +60,9 @@ func (b *BackendPluginClient) Cleanup() {
// external plugins, or a concrete implementation of the backend if it is a builtin backend.
// The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether
// the plugin should run in metadata mode.
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
func NewBackend(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(pluginName)
pluginRunner, err := sys.LookupPlugin(ctx, pluginName)
if err != nil {
return nil, err
}
@@ -83,7 +84,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
} else {
// create a backendPluginClient instance
backend, err = newPluginClient(sys, pluginRunner, logger, isMetadataMode)
backend, err = newPluginClient(ctx, sys, pluginRunner, logger, isMetadataMode)
if err != nil {
return nil, err
}
@@ -92,7 +93,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
return backend, nil
}
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// pluginMap is the map of plugins we can dispense.
pluginMap := map[string]plugin.Plugin{
"backend": &BackendPlugin{
@@ -103,9 +104,9 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
var client *plugin.Client
var err error
if isMetadataMode {
client, err = pluginRunner.RunMetadataMode(sys, pluginMap, handshakeConfig, []string{}, logger)
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginMap, handshakeConfig, []string{}, logger)
} else {
client, err = pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
client, err = pluginRunner.Run(ctx, sys, pluginMap, handshakeConfig, []string{}, logger)
}
if err != nil {
return nil, err

View File

@@ -14,11 +14,10 @@ import (
// dispensed rom the plugin server.
const BackendPluginName = "backend"
type BackendFactoryFunc func(*logical.BackendConfig) (logical.Backend, error)
type TLSProdiverFunc func() (*tls.Config, error)
type ServeOpts struct {
BackendFactoryFunc BackendFactoryFunc
BackendFactoryFunc logical.Factory
TLSProviderFunc TLSProdiverFunc
Logger hclog.Logger
}

View File

@@ -1,6 +1,7 @@
package plugin
import (
"context"
"net/rpc"
"github.com/hashicorp/vault/logical"
@@ -12,7 +13,7 @@ type StorageClient struct {
client *rpc.Client
}
func (s *StorageClient) List(prefix string) ([]string, error) {
func (s *StorageClient) List(_ context.Context, prefix string) ([]string, error) {
var reply StorageListReply
err := s.client.Call("Plugin.List", prefix, &reply)
if err != nil {
@@ -24,7 +25,7 @@ func (s *StorageClient) List(prefix string) ([]string, error) {
return reply.Keys, nil
}
func (s *StorageClient) Get(key string) (*logical.StorageEntry, error) {
func (s *StorageClient) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
var reply StorageGetReply
err := s.client.Call("Plugin.Get", key, &reply)
if err != nil {
@@ -36,7 +37,7 @@ func (s *StorageClient) Get(key string) (*logical.StorageEntry, error) {
return reply.StorageEntry, nil
}
func (s *StorageClient) Put(entry *logical.StorageEntry) error {
func (s *StorageClient) Put(_ context.Context, entry *logical.StorageEntry) error {
var reply StoragePutReply
err := s.client.Call("Plugin.Put", entry, &reply)
if err != nil {
@@ -48,7 +49,7 @@ func (s *StorageClient) Put(entry *logical.StorageEntry) error {
return nil
}
func (s *StorageClient) Delete(key string) error {
func (s *StorageClient) Delete(_ context.Context, key string) error {
var reply StorageDeleteReply
err := s.client.Call("Plugin.Delete", key, &reply)
if err != nil {
@@ -66,7 +67,7 @@ type StorageServer struct {
}
func (s *StorageServer) List(prefix string, reply *StorageListReply) error {
keys, err := s.impl.List(prefix)
keys, err := s.impl.List(context.Background(), prefix)
*reply = StorageListReply{
Keys: keys,
Error: wrapError(err),
@@ -75,7 +76,7 @@ func (s *StorageServer) List(prefix string, reply *StorageListReply) error {
}
func (s *StorageServer) Get(key string, reply *StorageGetReply) error {
storageEntry, err := s.impl.Get(key)
storageEntry, err := s.impl.Get(context.Background(), key)
*reply = StorageGetReply{
StorageEntry: storageEntry,
Error: wrapError(err),
@@ -84,7 +85,7 @@ func (s *StorageServer) Get(key string, reply *StorageGetReply) error {
}
func (s *StorageServer) Put(entry *logical.StorageEntry, reply *StoragePutReply) error {
err := s.impl.Put(entry)
err := s.impl.Put(context.Background(), entry)
*reply = StoragePutReply{
Error: wrapError(err),
}
@@ -92,7 +93,7 @@ func (s *StorageServer) Put(entry *logical.StorageEntry, reply *StoragePutReply)
}
func (s *StorageServer) Delete(key string, reply *StorageDeleteReply) error {
err := s.impl.Delete(key)
err := s.impl.Delete(context.Background(), key)
*reply = StorageDeleteReply{
Error: wrapError(err),
}
@@ -121,18 +122,18 @@ type StorageDeleteReply struct {
// backend plugin in metadata mode.
type NOOPStorage struct{}
func (s *NOOPStorage) List(prefix string) ([]string, error) {
func (s *NOOPStorage) List(_ context.Context, prefix string) ([]string, error) {
return []string{}, nil
}
func (s *NOOPStorage) Get(key string) (*logical.StorageEntry, error) {
func (s *NOOPStorage) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
return nil, nil
}
func (s *NOOPStorage) Put(entry *logical.StorageEntry) error {
func (s *NOOPStorage) Put(_ context.Context, entry *logical.StorageEntry) error {
return nil
}
func (s *NOOPStorage) Delete(key string) error {
func (s *NOOPStorage) Delete(_ context.Context, key string) error {
return nil
}

View File

@@ -1,6 +1,7 @@
package plugin
import (
"context"
"net/rpc"
"time"
@@ -36,7 +37,7 @@ func (s *SystemViewClient) MaxLeaseTTL() time.Duration {
return reply.MaxLeaseTTL
}
func (s *SystemViewClient) SudoPrivilege(path string, token string) bool {
func (s *SystemViewClient) SudoPrivilege(ctx context.Context, path string, token string) bool {
var reply SudoPrivilegeReply
args := &SudoPrivilegeArgs{
Path: path,
@@ -84,7 +85,7 @@ func (s *SystemViewClient) ReplicationState() consts.ReplicationState {
return reply.ReplicationState
}
func (s *SystemViewClient) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
func (s *SystemViewClient) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
var reply ResponseWrapDataReply
// Do not allow JWTs to be returned
args := &ResponseWrapDataArgs{
@@ -104,7 +105,7 @@ func (s *SystemViewClient) ResponseWrapData(data map[string]interface{}, ttl tim
return reply.ResponseWrapInfo, nil
}
func (s *SystemViewClient) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
func (s *SystemViewClient) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) {
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
}
@@ -141,7 +142,7 @@ func (s *SystemViewServer) MaxLeaseTTL(_ interface{}, reply *MaxLeaseTTLReply) e
}
func (s *SystemViewServer) SudoPrivilege(args *SudoPrivilegeArgs, reply *SudoPrivilegeReply) error {
sudo := s.impl.SudoPrivilege(args.Path, args.Token)
sudo := s.impl.SudoPrivilege(context.Background(), args.Path, args.Token)
*reply = SudoPrivilegeReply{
Sudo: sudo,
}
@@ -178,7 +179,7 @@ func (s *SystemViewServer) ReplicationState(_ interface{}, reply *ReplicationSta
func (s *SystemViewServer) ResponseWrapData(args *ResponseWrapDataArgs, reply *ResponseWrapDataReply) error {
// Do not allow JWTs to be returned
info, err := s.impl.ResponseWrapData(args.Data, args.TTL, false)
info, err := s.impl.ResponseWrapData(context.Background(), args.Data, args.TTL, false)
if err != nil {
*reply = ResponseWrapDataReply{
Error: wrapError(err),

View File

@@ -1,6 +1,7 @@
package plugin
import (
"context"
"testing"
"reflect"
@@ -64,9 +65,10 @@ func TestSystem_sudoPrivilege(t *testing.T) {
})
testSystemView := &SystemViewClient{client: client}
ctx := context.Background()
expected := sys.SudoPrivilege("foo", "bar")
actual := testSystemView.SudoPrivilege("foo", "bar")
expected := sys.SudoPrivilege(ctx, "foo", "bar")
actual := testSystemView.SudoPrivilege(ctx, "foo", "bar")
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("expected: %v, got: %v", expected, actual)
}
@@ -148,7 +150,7 @@ func TestSystem_lookupPlugin(t *testing.T) {
testSystemView := &SystemViewClient{client: client}
if _, err := testSystemView.LookupPlugin("foo"); err == nil {
if _, err := testSystemView.LookupPlugin(context.Background(), "foo"); err == nil {
t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin")
}
}