mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Add context to storage backends and wire it through a lot of places (#3817)
This commit is contained in:
committed by
Jeff Mitchell
parent
2864fbd697
commit
8142b42d95
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
// BackendPlugin is the plugin.Plugin implementation
|
||||
type BackendPlugin struct {
|
||||
Factory func(*logical.BackendConfig) (logical.Backend, error)
|
||||
Factory logical.Factory
|
||||
metadataMode bool
|
||||
Logger hclog.Logger
|
||||
}
|
||||
|
||||
@@ -173,11 +173,11 @@ func (b *backendPluginClient) HandleExistenceCheck(ctx context.Context, req *log
|
||||
return reply.CheckFound, reply.Exists, nil
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Cleanup() {
|
||||
func (b *backendPluginClient) Cleanup(ctx context.Context) {
|
||||
b.client.Call("Plugin.Cleanup", new(interface{}), &struct{}{})
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Initialize() error {
|
||||
func (b *backendPluginClient) Initialize(ctx context.Context) error {
|
||||
if b.metadataMode {
|
||||
return ErrClientInMetadataMode
|
||||
}
|
||||
@@ -185,14 +185,14 @@ func (b *backendPluginClient) Initialize() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) InvalidateKey(key string) {
|
||||
func (b *backendPluginClient) InvalidateKey(ctx context.Context, key string) {
|
||||
if b.metadataMode {
|
||||
return
|
||||
}
|
||||
b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Setup(config *logical.BackendConfig) error {
|
||||
func (b *backendPluginClient) Setup(ctx context.Context, config *logical.BackendConfig) error {
|
||||
// Shim logical.Storage
|
||||
storageImpl := config.StorageView
|
||||
if b.metadataMode {
|
||||
|
||||
@@ -20,7 +20,7 @@ var (
|
||||
type backendPluginServer struct {
|
||||
broker *plugin.MuxBroker
|
||||
backend logical.Backend
|
||||
factory func(*logical.BackendConfig) (logical.Backend, error)
|
||||
factory logical.Factory
|
||||
|
||||
loggerClient *rpc.Client
|
||||
sysViewClient *rpc.Client
|
||||
@@ -39,7 +39,7 @@ func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *Hand
|
||||
storage := &StorageClient{client: b.storageClient}
|
||||
args.Request.Storage = storage
|
||||
|
||||
resp, err := b.backend.HandleRequest(context.TODO(), args.Request)
|
||||
resp, err := b.backend.HandleRequest(context.Background(), args.Request)
|
||||
*reply = HandleRequestReply{
|
||||
Response: resp,
|
||||
Error: wrapError(err),
|
||||
@@ -74,7 +74,7 @@ func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArg
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
|
||||
b.backend.Cleanup()
|
||||
b.backend.Cleanup(context.Background())
|
||||
|
||||
// Close rpc clients
|
||||
b.loggerClient.Close()
|
||||
@@ -88,7 +88,7 @@ func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
err := b.backend.Initialize()
|
||||
err := b.backend.Initialize(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
b.backend.InvalidateKey(args)
|
||||
b.backend.InvalidateKey(context.Background(), args)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (b *backendPluginServer) Setup(args *SetupArgs, reply *SetupReply) error {
|
||||
|
||||
// Call the underlying backend factory after shims have been created
|
||||
// to set b.backend
|
||||
backend, err := b.factory(config)
|
||||
backend, err := b.factory(context.Background(), config)
|
||||
if err != nil {
|
||||
*reply = SetupReply{
|
||||
Error: wrapError(err),
|
||||
|
||||
@@ -97,14 +97,14 @@ func TestBackendPlugin_Cleanup(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
b.Cleanup()
|
||||
b.Cleanup(context.Background())
|
||||
}
|
||||
|
||||
func TestBackendPlugin_Initialize(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
err := b.Initialize()
|
||||
err := b.Initialize(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -114,7 +114,9 @@ func TestBackendPlugin_InvalidateKey(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, err := b.HandleRequest(context.Background(), &logical.Request{
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := b.HandleRequest(ctx, &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "internal",
|
||||
})
|
||||
@@ -125,9 +127,9 @@ func TestBackendPlugin_InvalidateKey(t *testing.T) {
|
||||
t.Fatalf("bad: %#v, expected non-empty value", resp)
|
||||
}
|
||||
|
||||
b.InvalidateKey("internal")
|
||||
b.InvalidateKey(ctx, "internal")
|
||||
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
resp, err = b.HandleRequest(ctx, &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "internal",
|
||||
})
|
||||
@@ -163,7 +165,7 @@ func testBackend(t *testing.T) (logical.Backend, func()) {
|
||||
}
|
||||
b := raw.(logical.Backend)
|
||||
|
||||
err = b.Setup(&logical.BackendConfig{
|
||||
err = b.Setup(context.Background(), &logical.BackendConfig{
|
||||
Logger: logformat.NewVaultLogger(log.LevelTrace),
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: 300 * time.Second,
|
||||
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
|
||||
var ErrPluginShutdown = errors.New("plugin is shut down")
|
||||
|
||||
// Validate backendGRPCPluginClient satisfies the logical.Backend interface
|
||||
var _ logical.Backend = &backendGRPCPluginClient{}
|
||||
|
||||
// backendPluginClient implements logical.Backend and is the
|
||||
// go-plugin client.
|
||||
type backendGRPCPluginClient struct {
|
||||
@@ -126,33 +129,49 @@ func (b *backendGRPCPluginClient) HandleExistenceCheck(ctx context.Context, req
|
||||
return reply.CheckFound, reply.Exists, nil
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginClient) Cleanup() {
|
||||
b.client.Cleanup(b.doneCtx, &pb.Empty{})
|
||||
func (b *backendGRPCPluginClient) Cleanup(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
b.client.Cleanup(ctx, &pb.Empty{})
|
||||
if b.server != nil {
|
||||
b.server.GracefulStop()
|
||||
}
|
||||
b.clientConn.Close()
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginClient) Initialize() error {
|
||||
func (b *backendGRPCPluginClient) Initialize(ctx context.Context) error {
|
||||
if b.metadataMode {
|
||||
return ErrClientInMetadataMode
|
||||
}
|
||||
|
||||
_, err := b.client.Initialize(b.doneCtx, &pb.Empty{})
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
_, err := b.client.Initialize(ctx, &pb.Empty{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginClient) InvalidateKey(key string) {
|
||||
func (b *backendGRPCPluginClient) InvalidateKey(ctx context.Context, key string) {
|
||||
if b.metadataMode {
|
||||
return
|
||||
}
|
||||
b.client.InvalidateKey(b.doneCtx, &pb.InvalidateKeyArgs{
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
b.client.InvalidateKey(ctx, &pb.InvalidateKeyArgs{
|
||||
Key: key,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginClient) Setup(config *logical.BackendConfig) error {
|
||||
func (b *backendGRPCPluginClient) Setup(ctx context.Context, config *logical.BackendConfig) error {
|
||||
// Shim logical.Storage
|
||||
storageImpl := config.StorageView
|
||||
if b.metadataMode {
|
||||
@@ -187,7 +206,12 @@ func (b *backendGRPCPluginClient) Setup(config *logical.BackendConfig) error {
|
||||
Config: config.Config,
|
||||
}
|
||||
|
||||
reply, err := b.client.Setup(b.doneCtx, args)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, b.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
reply, err := b.client.Setup(ctx, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ type backendGRPCPluginServer struct {
|
||||
broker *plugin.GRPCBroker
|
||||
backend logical.Backend
|
||||
|
||||
factory func(*logical.BackendConfig) (logical.Backend, error)
|
||||
factory logical.Factory
|
||||
|
||||
brokeredClient *grpc.ClientConn
|
||||
|
||||
@@ -43,7 +43,7 @@ func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs)
|
||||
|
||||
// Call the underlying backend factory after shims have been created
|
||||
// to set b.backend
|
||||
backend, err := b.factory(config)
|
||||
backend, err := b.factory(ctx, config)
|
||||
if err != nil {
|
||||
return &pb.SetupReply{
|
||||
Err: pb.ErrToString(err),
|
||||
@@ -112,7 +112,7 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) Cleanup(ctx context.Context, _ *pb.Empty) (*pb.Empty, error) {
|
||||
b.backend.Cleanup()
|
||||
b.backend.Cleanup(ctx)
|
||||
|
||||
// Close rpc clients
|
||||
b.brokeredClient.Close()
|
||||
@@ -124,7 +124,7 @@ func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.Empty) (
|
||||
return &pb.Empty{}, ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
err := b.backend.Initialize()
|
||||
err := b.backend.Initialize(ctx)
|
||||
return &pb.Empty{}, err
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ func (b *backendGRPCPluginServer) InvalidateKey(ctx context.Context, args *pb.In
|
||||
return &pb.Empty{}, ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
b.backend.InvalidateKey(args.Key)
|
||||
b.backend.InvalidateKey(ctx, args.Key)
|
||||
return &pb.Empty{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -99,14 +99,14 @@ func TestGRPCBackendPlugin_Cleanup(t *testing.T) {
|
||||
b, cleanup := testGRPCBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
b.Cleanup()
|
||||
b.Cleanup(context.Background())
|
||||
}
|
||||
|
||||
func TestGRPCBackendPlugin_Initialize(t *testing.T) {
|
||||
b, cleanup := testGRPCBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
err := b.Initialize()
|
||||
err := b.Initialize(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -116,7 +116,9 @@ func TestGRPCBackendPlugin_InvalidateKey(t *testing.T) {
|
||||
b, cleanup := testGRPCBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, err := b.HandleRequest(context.Background(), &logical.Request{
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := b.HandleRequest(ctx, &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "internal",
|
||||
})
|
||||
@@ -127,9 +129,9 @@ func TestGRPCBackendPlugin_InvalidateKey(t *testing.T) {
|
||||
t.Fatalf("bad: %#v, expected non-empty value", resp)
|
||||
}
|
||||
|
||||
b.InvalidateKey("internal")
|
||||
b.InvalidateKey(ctx, "internal")
|
||||
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
resp, err = b.HandleRequest(ctx, &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "internal",
|
||||
})
|
||||
@@ -170,7 +172,7 @@ func testGRPCBackend(t *testing.T) (logical.Backend, func()) {
|
||||
}
|
||||
b := raw.(logical.Backend)
|
||||
|
||||
err = b.Setup(&logical.BackendConfig{
|
||||
err = b.Setup(context.Background(), &logical.BackendConfig{
|
||||
Logger: logformat.NewVaultLogger(log.LevelTrace),
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: 300 * time.Second,
|
||||
|
||||
@@ -22,8 +22,8 @@ type GRPCStorageClient struct {
|
||||
client pb.StorageClient
|
||||
}
|
||||
|
||||
func (s *GRPCStorageClient) List(prefix string) ([]string, error) {
|
||||
reply, err := s.client.List(context.Background(), &pb.StorageListArgs{
|
||||
func (s *GRPCStorageClient) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
reply, err := s.client.List(ctx, &pb.StorageListArgs{
|
||||
Prefix: prefix,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -35,8 +35,8 @@ func (s *GRPCStorageClient) List(prefix string) ([]string, error) {
|
||||
return reply.Keys, nil
|
||||
}
|
||||
|
||||
func (s *GRPCStorageClient) Get(key string) (*logical.StorageEntry, error) {
|
||||
reply, err := s.client.Get(context.Background(), &pb.StorageGetArgs{
|
||||
func (s *GRPCStorageClient) Get(ctx context.Context, key string) (*logical.StorageEntry, error) {
|
||||
reply, err := s.client.Get(ctx, &pb.StorageGetArgs{
|
||||
Key: key,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -48,8 +48,8 @@ func (s *GRPCStorageClient) Get(key string) (*logical.StorageEntry, error) {
|
||||
return pb.ProtoStorageEntryToLogicalStorageEntry(reply.Entry), nil
|
||||
}
|
||||
|
||||
func (s *GRPCStorageClient) Put(entry *logical.StorageEntry) error {
|
||||
reply, err := s.client.Put(context.Background(), &pb.StoragePutArgs{
|
||||
func (s *GRPCStorageClient) Put(ctx context.Context, entry *logical.StorageEntry) error {
|
||||
reply, err := s.client.Put(ctx, &pb.StoragePutArgs{
|
||||
Entry: pb.LogicalStorageEntryToProtoStorageEntry(entry),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -61,8 +61,8 @@ func (s *GRPCStorageClient) Put(entry *logical.StorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCStorageClient) Delete(key string) error {
|
||||
reply, err := s.client.Delete(context.Background(), &pb.StorageDeleteArgs{
|
||||
func (s *GRPCStorageClient) Delete(ctx context.Context, key string) error {
|
||||
reply, err := s.client.Delete(ctx, &pb.StorageDeleteArgs{
|
||||
Key: key,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -80,7 +80,7 @@ type GRPCStorageServer struct {
|
||||
}
|
||||
|
||||
func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs) (*pb.StorageListReply, error) {
|
||||
keys, err := s.impl.List(args.Prefix)
|
||||
keys, err := s.impl.List(ctx, args.Prefix)
|
||||
return &pb.StorageListReply{
|
||||
Keys: keys,
|
||||
Err: pb.ErrToString(err),
|
||||
@@ -88,7 +88,7 @@ func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs)
|
||||
}
|
||||
|
||||
func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*pb.StorageGetReply, error) {
|
||||
storageEntry, err := s.impl.Get(args.Key)
|
||||
storageEntry, err := s.impl.Get(ctx, args.Key)
|
||||
return &pb.StorageGetReply{
|
||||
Entry: pb.LogicalStorageEntryToProtoStorageEntry(storageEntry),
|
||||
Err: pb.ErrToString(err),
|
||||
@@ -96,14 +96,14 @@ func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*
|
||||
}
|
||||
|
||||
func (s *GRPCStorageServer) Put(ctx context.Context, args *pb.StoragePutArgs) (*pb.StoragePutReply, error) {
|
||||
err := s.impl.Put(pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
|
||||
err := s.impl.Put(ctx, pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
|
||||
return &pb.StoragePutReply{
|
||||
Err: pb.ErrToString(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteArgs) (*pb.StorageDeleteReply, error) {
|
||||
err := s.impl.Delete(args.Key)
|
||||
err := s.impl.Delete(ctx, args.Key)
|
||||
return &pb.StorageDeleteReply{
|
||||
Err: pb.ErrToString(err),
|
||||
}, nil
|
||||
|
||||
@@ -45,8 +45,8 @@ func (s *gRPCSystemViewClient) MaxLeaseTTL() time.Duration {
|
||||
return time.Duration(reply.TTL)
|
||||
}
|
||||
|
||||
func (s *gRPCSystemViewClient) SudoPrivilege(path string, token string) bool {
|
||||
reply, err := s.client.SudoPrivilege(context.Background(), &pb.SudoPrivilegeArgs{
|
||||
func (s *gRPCSystemViewClient) SudoPrivilege(ctx context.Context, path string, token string) bool {
|
||||
reply, err := s.client.SudoPrivilege(ctx, &pb.SudoPrivilegeArgs{
|
||||
Path: path,
|
||||
Token: token,
|
||||
})
|
||||
@@ -84,13 +84,13 @@ func (s *gRPCSystemViewClient) ReplicationState() consts.ReplicationState {
|
||||
return consts.ReplicationState(reply.State)
|
||||
}
|
||||
|
||||
func (s *gRPCSystemViewClient) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
buf, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply, err := s.client.ResponseWrapData(context.Background(), &pb.ResponseWrapDataArgs{
|
||||
reply, err := s.client.ResponseWrapData(ctx, &pb.ResponseWrapDataArgs{
|
||||
Data: buf,
|
||||
TTL: int64(ttl),
|
||||
JWT: false,
|
||||
@@ -110,7 +110,7 @@ func (s *gRPCSystemViewClient) ResponseWrapData(data map[string]interface{}, ttl
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *gRPCSystemViewClient) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
|
||||
func (s *gRPCSystemViewClient) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) {
|
||||
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func (s *gRPCSystemViewServer) MaxLeaseTTL(ctx context.Context, _ *pb.Empty) (*p
|
||||
}
|
||||
|
||||
func (s *gRPCSystemViewServer) SudoPrivilege(ctx context.Context, args *pb.SudoPrivilegeArgs) (*pb.SudoPrivilegeReply, error) {
|
||||
sudo := s.impl.SudoPrivilege(args.Path, args.Token)
|
||||
sudo := s.impl.SudoPrivilege(ctx, args.Path, args.Token)
|
||||
return &pb.SudoPrivilegeReply{
|
||||
Sudo: sudo,
|
||||
}, nil
|
||||
@@ -177,7 +177,7 @@ func (s *gRPCSystemViewServer) ResponseWrapData(ctx context.Context, args *pb.Re
|
||||
}
|
||||
|
||||
// Do not allow JWTs to be returned
|
||||
info, err := s.impl.ResponseWrapData(data, time.Duration(args.TTL), false)
|
||||
info, err := s.impl.ResponseWrapData(ctx, data, time.Duration(args.TTL), false)
|
||||
if err != nil {
|
||||
return &pb.ResponseWrapDataReply{
|
||||
Err: pb.ErrToString(err),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
@@ -62,8 +63,10 @@ func TestSystem_GRPC_sudoPrivilege(t *testing.T) {
|
||||
defer client.Close()
|
||||
testSystemView := newGRPCSystemView(client)
|
||||
|
||||
expected := sys.SudoPrivilege("foo", "bar")
|
||||
actual := testSystemView.SudoPrivilege("foo", "bar")
|
||||
ctx := context.Background()
|
||||
|
||||
expected := sys.SudoPrivilege(ctx, "foo", "bar")
|
||||
actual := testSystemView.SudoPrivilege(ctx, "foo", "bar")
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
@@ -138,7 +141,7 @@ func TestSystem_GRPC_lookupPlugin(t *testing.T) {
|
||||
|
||||
testSystemView := newGRPCSystemView(client)
|
||||
|
||||
if _, err := testSystemView.LookupPlugin("foo"); err == nil {
|
||||
if _, err := testSystemView.LookupPlugin(context.Background(), "foo"); err == nil {
|
||||
t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@ type backendTracingMiddleware struct {
|
||||
next logical.Backend
|
||||
}
|
||||
|
||||
// Validate the backendTracingMiddle object satisfies the backend interface
|
||||
var _ logical.Backend = &backendTracingMiddleware{}
|
||||
|
||||
func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("plugin.HandleRequest", "path", req.Path, "status", "finished", "type", b.typeStr, "transport", b.transport, "err", err, "took", time.Since(then))
|
||||
@@ -53,40 +56,40 @@ func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req
|
||||
return b.next.HandleExistenceCheck(ctx, req)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Cleanup() {
|
||||
func (b *backendTracingMiddleware) Cleanup(ctx context.Context) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("plugin.Cleanup", "status", "finished", "type", b.typeStr, "transport", b.transport, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
b.logger.Trace("plugin.Cleanup", "status", "started", "type", b.typeStr, "transport", b.transport)
|
||||
b.next.Cleanup()
|
||||
b.next.Cleanup(ctx)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Initialize() (err error) {
|
||||
func (b *backendTracingMiddleware) Initialize(ctx context.Context) (err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("plugin.Initialize", "status", "finished", "type", b.typeStr, "transport", b.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
b.logger.Trace("plugin.Initialize", "status", "started", "type", b.typeStr, "transport", b.transport)
|
||||
return b.next.Initialize()
|
||||
return b.next.Initialize(ctx)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) InvalidateKey(key string) {
|
||||
func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("plugin.InvalidateKey", "key", key, "status", "finished", "type", b.typeStr, "transport", b.transport, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
b.logger.Trace("plugin.InvalidateKey", "key", key, "status", "started", "type", b.typeStr, "transport", b.transport)
|
||||
b.next.InvalidateKey(key)
|
||||
b.next.InvalidateKey(ctx, key)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Setup(config *logical.BackendConfig) (err error) {
|
||||
func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("plugin.Setup", "status", "finished", "type", b.typeStr, "transport", b.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
b.logger.Trace("plugin.Setup", "status", "started", "type", b.typeStr, "transport", b.transport)
|
||||
return b.next.Setup(config)
|
||||
return b.next.Setup(ctx, config)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Type() logical.BackendType {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@@ -12,9 +14,9 @@ func New() (interface{}, error) {
|
||||
}
|
||||
|
||||
// Factory returns a new backend as logical.Backend.
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -22,11 +24,11 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
|
||||
// FactoryType is a wrapper func that allows the Factory func to specify
|
||||
// the backend type for the mock backend plugin instance.
|
||||
func FactoryType(backendType logical.BackendType) func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
return func(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func FactoryType(backendType logical.BackendType) logical.Factory {
|
||||
return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
b.BackendType = backendType
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -66,7 +68,7 @@ type backend struct {
|
||||
internal string
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "internal":
|
||||
b.internal = ""
|
||||
|
||||
@@ -36,7 +36,7 @@ func kvPaths(b *backend) []*framework.Path {
|
||||
}
|
||||
|
||||
func (b *backend) pathExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
out, err := req.Storage.Get(req.Path)
|
||||
out, err := req.Storage.Get(ctx, req.Path)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("existence check failed: %v", err)
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func (b *backend) pathExistenceCheck(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
func (b *backend) pathKVRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get(req.Path)
|
||||
entry, err := req.Storage.Get(ctx, req.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func (b *backend) pathKVCreateUpdate(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
s := req.Storage
|
||||
err := s.Put(entry)
|
||||
err := s.Put(ctx, entry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -88,7 +88,7 @@ func (b *backend) pathKVCreateUpdate(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
func (b *backend) pathKVDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
if err := req.Storage.Delete(req.Path); err != nil {
|
||||
if err := req.Storage.Delete(ctx, req.Path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (b *backend) pathKVDelete(ctx context.Context, req *logical.Request, data *
|
||||
}
|
||||
|
||||
func (b *backend) pathKVList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
vals, err := req.Storage.List("kv/")
|
||||
vals, err := req.Storage.List(ctx, "kv/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/gob"
|
||||
@@ -50,8 +51,8 @@ type BackendPluginClient struct {
|
||||
|
||||
// Cleanup calls the RPC client's Cleanup() func and also calls
|
||||
// the go-plugin's client Kill() func
|
||||
func (b *BackendPluginClient) Cleanup() {
|
||||
b.Backend.Cleanup()
|
||||
func (b *BackendPluginClient) Cleanup(ctx context.Context) {
|
||||
b.Backend.Cleanup(ctx)
|
||||
b.client.Kill()
|
||||
}
|
||||
|
||||
@@ -59,9 +60,9 @@ func (b *BackendPluginClient) Cleanup() {
|
||||
// external plugins, or a concrete implementation of the backend if it is a builtin backend.
|
||||
// The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether
|
||||
// the plugin should run in metadata mode.
|
||||
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
func NewBackend(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
// Look for plugin in the plugin catalog
|
||||
pluginRunner, err := sys.LookupPlugin(pluginName)
|
||||
pluginRunner, err := sys.LookupPlugin(ctx, pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,7 +84,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
|
||||
|
||||
} else {
|
||||
// create a backendPluginClient instance
|
||||
backend, err = newPluginClient(sys, pluginRunner, logger, isMetadataMode)
|
||||
backend, err = newPluginClient(ctx, sys, pluginRunner, logger, isMetadataMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -92,7 +93,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
pluginMap := map[string]plugin.Plugin{
|
||||
"backend": &BackendPlugin{
|
||||
@@ -103,9 +104,9 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
|
||||
var client *plugin.Client
|
||||
var err error
|
||||
if isMetadataMode {
|
||||
client, err = pluginRunner.RunMetadataMode(sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
} else {
|
||||
client, err = pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
client, err = pluginRunner.Run(ctx, sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -14,11 +14,10 @@ import (
|
||||
// dispensed rom the plugin server.
|
||||
const BackendPluginName = "backend"
|
||||
|
||||
type BackendFactoryFunc func(*logical.BackendConfig) (logical.Backend, error)
|
||||
type TLSProdiverFunc func() (*tls.Config, error)
|
||||
|
||||
type ServeOpts struct {
|
||||
BackendFactoryFunc BackendFactoryFunc
|
||||
BackendFactoryFunc logical.Factory
|
||||
TLSProviderFunc TLSProdiverFunc
|
||||
Logger hclog.Logger
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/rpc"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -12,7 +13,7 @@ type StorageClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (s *StorageClient) List(prefix string) ([]string, error) {
|
||||
func (s *StorageClient) List(_ context.Context, prefix string) ([]string, error) {
|
||||
var reply StorageListReply
|
||||
err := s.client.Call("Plugin.List", prefix, &reply)
|
||||
if err != nil {
|
||||
@@ -24,7 +25,7 @@ func (s *StorageClient) List(prefix string) ([]string, error) {
|
||||
return reply.Keys, nil
|
||||
}
|
||||
|
||||
func (s *StorageClient) Get(key string) (*logical.StorageEntry, error) {
|
||||
func (s *StorageClient) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
|
||||
var reply StorageGetReply
|
||||
err := s.client.Call("Plugin.Get", key, &reply)
|
||||
if err != nil {
|
||||
@@ -36,7 +37,7 @@ func (s *StorageClient) Get(key string) (*logical.StorageEntry, error) {
|
||||
return reply.StorageEntry, nil
|
||||
}
|
||||
|
||||
func (s *StorageClient) Put(entry *logical.StorageEntry) error {
|
||||
func (s *StorageClient) Put(_ context.Context, entry *logical.StorageEntry) error {
|
||||
var reply StoragePutReply
|
||||
err := s.client.Call("Plugin.Put", entry, &reply)
|
||||
if err != nil {
|
||||
@@ -48,7 +49,7 @@ func (s *StorageClient) Put(entry *logical.StorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageClient) Delete(key string) error {
|
||||
func (s *StorageClient) Delete(_ context.Context, key string) error {
|
||||
var reply StorageDeleteReply
|
||||
err := s.client.Call("Plugin.Delete", key, &reply)
|
||||
if err != nil {
|
||||
@@ -66,7 +67,7 @@ type StorageServer struct {
|
||||
}
|
||||
|
||||
func (s *StorageServer) List(prefix string, reply *StorageListReply) error {
|
||||
keys, err := s.impl.List(prefix)
|
||||
keys, err := s.impl.List(context.Background(), prefix)
|
||||
*reply = StorageListReply{
|
||||
Keys: keys,
|
||||
Error: wrapError(err),
|
||||
@@ -75,7 +76,7 @@ func (s *StorageServer) List(prefix string, reply *StorageListReply) error {
|
||||
}
|
||||
|
||||
func (s *StorageServer) Get(key string, reply *StorageGetReply) error {
|
||||
storageEntry, err := s.impl.Get(key)
|
||||
storageEntry, err := s.impl.Get(context.Background(), key)
|
||||
*reply = StorageGetReply{
|
||||
StorageEntry: storageEntry,
|
||||
Error: wrapError(err),
|
||||
@@ -84,7 +85,7 @@ func (s *StorageServer) Get(key string, reply *StorageGetReply) error {
|
||||
}
|
||||
|
||||
func (s *StorageServer) Put(entry *logical.StorageEntry, reply *StoragePutReply) error {
|
||||
err := s.impl.Put(entry)
|
||||
err := s.impl.Put(context.Background(), entry)
|
||||
*reply = StoragePutReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
@@ -92,7 +93,7 @@ func (s *StorageServer) Put(entry *logical.StorageEntry, reply *StoragePutReply)
|
||||
}
|
||||
|
||||
func (s *StorageServer) Delete(key string, reply *StorageDeleteReply) error {
|
||||
err := s.impl.Delete(key)
|
||||
err := s.impl.Delete(context.Background(), key)
|
||||
*reply = StorageDeleteReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
@@ -121,18 +122,18 @@ type StorageDeleteReply struct {
|
||||
// backend plugin in metadata mode.
|
||||
type NOOPStorage struct{}
|
||||
|
||||
func (s *NOOPStorage) List(prefix string) ([]string, error) {
|
||||
func (s *NOOPStorage) List(_ context.Context, prefix string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Get(key string) (*logical.StorageEntry, error) {
|
||||
func (s *NOOPStorage) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Put(entry *logical.StorageEntry) error {
|
||||
func (s *NOOPStorage) Put(_ context.Context, entry *logical.StorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Delete(key string) error {
|
||||
func (s *NOOPStorage) Delete(_ context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
@@ -36,7 +37,7 @@ func (s *SystemViewClient) MaxLeaseTTL() time.Duration {
|
||||
return reply.MaxLeaseTTL
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) SudoPrivilege(path string, token string) bool {
|
||||
func (s *SystemViewClient) SudoPrivilege(ctx context.Context, path string, token string) bool {
|
||||
var reply SudoPrivilegeReply
|
||||
args := &SudoPrivilegeArgs{
|
||||
Path: path,
|
||||
@@ -84,7 +85,7 @@ func (s *SystemViewClient) ReplicationState() consts.ReplicationState {
|
||||
return reply.ReplicationState
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
func (s *SystemViewClient) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
var reply ResponseWrapDataReply
|
||||
// Do not allow JWTs to be returned
|
||||
args := &ResponseWrapDataArgs{
|
||||
@@ -104,7 +105,7 @@ func (s *SystemViewClient) ResponseWrapData(data map[string]interface{}, ttl tim
|
||||
return reply.ResponseWrapInfo, nil
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
|
||||
func (s *SystemViewClient) LookupPlugin(ctx context.Context, name string) (*pluginutil.PluginRunner, error) {
|
||||
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
|
||||
}
|
||||
|
||||
@@ -141,7 +142,7 @@ func (s *SystemViewServer) MaxLeaseTTL(_ interface{}, reply *MaxLeaseTTLReply) e
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) SudoPrivilege(args *SudoPrivilegeArgs, reply *SudoPrivilegeReply) error {
|
||||
sudo := s.impl.SudoPrivilege(args.Path, args.Token)
|
||||
sudo := s.impl.SudoPrivilege(context.Background(), args.Path, args.Token)
|
||||
*reply = SudoPrivilegeReply{
|
||||
Sudo: sudo,
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func (s *SystemViewServer) ReplicationState(_ interface{}, reply *ReplicationSta
|
||||
|
||||
func (s *SystemViewServer) ResponseWrapData(args *ResponseWrapDataArgs, reply *ResponseWrapDataReply) error {
|
||||
// Do not allow JWTs to be returned
|
||||
info, err := s.impl.ResponseWrapData(args.Data, args.TTL, false)
|
||||
info, err := s.impl.ResponseWrapData(context.Background(), args.Data, args.TTL, false)
|
||||
if err != nil {
|
||||
*reply = ResponseWrapDataReply{
|
||||
Error: wrapError(err),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"reflect"
|
||||
@@ -64,9 +65,10 @@ func TestSystem_sudoPrivilege(t *testing.T) {
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
ctx := context.Background()
|
||||
|
||||
expected := sys.SudoPrivilege("foo", "bar")
|
||||
actual := testSystemView.SudoPrivilege("foo", "bar")
|
||||
expected := sys.SudoPrivilege(ctx, "foo", "bar")
|
||||
actual := testSystemView.SudoPrivilege(ctx, "foo", "bar")
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
@@ -148,7 +150,7 @@ func TestSystem_lookupPlugin(t *testing.T) {
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
if _, err := testSystemView.LookupPlugin("foo"); err == nil {
|
||||
if _, err := testSystemView.LookupPlugin(context.Background(), "foo"); err == nil {
|
||||
t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user