mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	* feat: DB plugin multiplexing (#13734)
* WIP: start from main and get a plugin runner from core
* move MultiplexedClient map to plugin catalog
- call sys.NewPluginClient from PluginFactory
- updates to getPluginClient
- thread through isMetadataMode
* use go-plugin ClientProtocol interface
- call sys.NewPluginClient from dbplugin.NewPluginClient
* move PluginSets to dbplugin package
- export dbplugin HandshakeConfig
- small refactor of PluginCatalog.getPluginClient
* add removeMultiplexedClient; clean up on Close()
- call client.Kill from plugin catalog
- set rpcClient when muxed client exists
* add ID to dbplugin.DatabasePluginClient struct
* only create one plugin process per plugin type
* update NewPluginClient to return connection ID to sdk
- wrap grpc.ClientConn so we can inject the ID into context
- get ID from context on grpc server
* add v6 multiplexing protocol version
* WIP: backwards compat for db plugins
* Ensure locking on plugin catalog access
- Create public GetPluginClient method for plugin catalog
- rename postgres db plugin
* use the New constructor for db plugins
* grpc server: use write lock for Close and rlock for CRUD
* cleanup MultiplexedClients on Close
* remove TODO
* fix multiplexing regression with grpc server connection
* cleanup grpc server instances on close
* embed ClientProtocol in Multiplexer interface
* use PluginClientConfig arg to make NewPluginClient plugin type agnostic
* create a new plugin process for non-muxed plugins
* feat: plugin multiplexing: handle plugin client cleanup (#13896)
* use closure for plugin client cleanup
* log and return errors; add comments
* move rpcClient wrapping to core for ID injection
* refactor core plugin client and sdk
* remove unused ID method
* refactor and only wrap clientConn on multiplexed plugins
* rename structs and do not export types
* Slight refactor of system view interface
* Revert "Slight refactor of system view interface"
This reverts commit 73d420e5cd.
* Revert "Revert "Slight refactor of system view interface""
This reverts commit f75527008a1db06d04a23e04c3059674be8adb5f.
* only provide pluginRunner arg to the internal newPluginClient method
* embed ClientProtocol in pluginClient and name logger
* Add back MLock support
* remove enableMlock arg from setupPluginCatalog
* rename plugin util interface to PluginClient
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
* feature: multiplexing: fix unit tests (#14007)
* fix grpc_server tests and add coverage
* update run_config tests
* add happy path test case for grpc_server ID from context
* update test helpers
* feat: multiplexing: handle v5 plugin compiled with new sdk
* add mux supported flag and increase test coverage
* set multiplexingSupport field in plugin server
* remove multiplexingSupport field in sdk
* revert postgres to non-multiplexed
* add comments on grpc server fields
* use pointer receiver on grpc server methods
* add changelog
* use pointer for grpcserver instance
* Use a gRPC server to determine if a plugin should be multiplexed
* Apply suggestions from code review
Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
* add lock to removePluginClient
* add multiplexingSupport field to externalPlugin struct
* do not send nil to grpc MultiplexingSupport
* check err before logging
* handle locking scenario for cleanupFunc
* allow ServeConfigMultiplex to dispense v5 plugin
* reposition structs, add err check and comments
* add comment on locking for cleanupExternalPlugin
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
		
	
		
			
				
	
	
		
			801 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			801 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package dbplugin
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"reflect"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"google.golang.org/protobuf/types/known/structpb"
 | 
						|
 | 
						|
	"github.com/golang/protobuf/ptypes"
 | 
						|
	"github.com/golang/protobuf/ptypes/timestamp"
 | 
						|
	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
 | 
						|
	"github.com/hashicorp/vault/sdk/helper/pluginutil"
 | 
						|
	"google.golang.org/grpc/codes"
 | 
						|
	"google.golang.org/grpc/metadata"
 | 
						|
	"google.golang.org/grpc/status"
 | 
						|
)
 | 
						|
 | 
						|
// invalidExpiration is midnight UTC on January 1 of year 0. It lies before
// minValidSeconds in the ptypes package, so tests use it as an expiration
// that the server must reject as an invalid argument.
var invalidExpiration = time.Date(0, time.January, 1, 0, 0, 0, 0, time.UTC)
 | 
						|
 | 
						|
func TestGRPCServer_Initialize(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		db            Database
 | 
						|
		req           *proto.InitializeRequest
 | 
						|
		expectedResp  *proto.InitializeResponse
 | 
						|
		expectErr     bool
 | 
						|
		expectCode    codes.Code
 | 
						|
		grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"database errored": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				initErr: errors.New("initialization error"),
 | 
						|
			},
 | 
						|
			req:           &proto.InitializeRequest{},
 | 
						|
			expectedResp:  &proto.InitializeResponse{},
 | 
						|
			expectErr:     true,
 | 
						|
			expectCode:    codes.Internal,
 | 
						|
			grpcSetupFunc: testGrpcServer,
 | 
						|
		},
 | 
						|
		"newConfig can't marshal to JSON": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				initResp: InitializeResponse{
 | 
						|
					Config: map[string]interface{}{
 | 
						|
						"bad-data": badJSONValue{},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req:           &proto.InitializeRequest{},
 | 
						|
			expectedResp:  &proto.InitializeResponse{},
 | 
						|
			expectErr:     true,
 | 
						|
			expectCode:    codes.Internal,
 | 
						|
			grpcSetupFunc: testGrpcServer,
 | 
						|
		},
 | 
						|
		"happy path with config data for multiplexed plugin": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				initResp: InitializeResponse{
 | 
						|
					Config: map[string]interface{}{
 | 
						|
						"foo": "bar",
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: &proto.InitializeRequest{
 | 
						|
				ConfigData: marshal(t, map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
				}),
 | 
						|
			},
 | 
						|
			expectedResp: &proto.InitializeResponse{
 | 
						|
				ConfigData: marshal(t, map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
				}),
 | 
						|
			},
 | 
						|
			expectErr:     false,
 | 
						|
			expectCode:    codes.OK,
 | 
						|
			grpcSetupFunc: testGrpcServer,
 | 
						|
		},
 | 
						|
		"happy path with config data for non-multiplexed plugin": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				initResp: InitializeResponse{
 | 
						|
					Config: map[string]interface{}{
 | 
						|
						"foo": "bar",
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: &proto.InitializeRequest{
 | 
						|
				ConfigData: marshal(t, map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
				}),
 | 
						|
			},
 | 
						|
			expectedResp: &proto.InitializeResponse{
 | 
						|
				ConfigData: marshal(t, map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
				}),
 | 
						|
			},
 | 
						|
			expectErr:     false,
 | 
						|
			expectCode:    codes.OK,
 | 
						|
			grpcSetupFunc: testGrpcServerSingleImpl,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			idCtx, g := test.grpcSetupFunc(t, test.db)
 | 
						|
			resp, err := g.Initialize(idCtx, test.req)
 | 
						|
 | 
						|
			if test.expectErr && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			}
 | 
						|
			if !test.expectErr && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			actualCode := status.Code(err)
 | 
						|
			if actualCode != test.expectCode {
 | 
						|
				t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
 | 
						|
			}
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCoerceFloatsToInt(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		input    map[string]interface{}
 | 
						|
		expected map[string]interface{}
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"no numbers": {
 | 
						|
			input: map[string]interface{}{
 | 
						|
				"foo": "bar",
 | 
						|
			},
 | 
						|
			expected: map[string]interface{}{
 | 
						|
				"foo": "bar",
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"raw integers": {
 | 
						|
			input: map[string]interface{}{
 | 
						|
				"foo": 42,
 | 
						|
			},
 | 
						|
			expected: map[string]interface{}{
 | 
						|
				"foo": 42,
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"floats ": {
 | 
						|
			input: map[string]interface{}{
 | 
						|
				"foo": 42.2,
 | 
						|
			},
 | 
						|
			expected: map[string]interface{}{
 | 
						|
				"foo": 42.2,
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"floats coerced to ints": {
 | 
						|
			input: map[string]interface{}{
 | 
						|
				"foo": float64(42),
 | 
						|
			},
 | 
						|
			expected: map[string]interface{}{
 | 
						|
				"foo": int64(42),
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			actual := copyMap(test.input)
 | 
						|
			coerceFloatsToInt(actual)
 | 
						|
			if !reflect.DeepEqual(actual, test.expected) {
 | 
						|
				t.Fatalf("Actual: %#v\nExpected: %#v", actual, test.expected)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// copyMap returns a shallow copy of m. Keys and top-level values are copied,
// so the result can be mutated without touching m. A nil m yields an empty,
// non-nil map.
func copyMap(m map[string]interface{}) map[string]interface{} {
	dup := make(map[string]interface{}, len(m))
	for key, val := range m {
		dup[key] = val
	}
	return dup
}
 | 
						|
 | 
						|
func TestGRPCServer_NewUser(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		db           Database
 | 
						|
		req          *proto.NewUserRequest
 | 
						|
		expectedResp *proto.NewUserResponse
 | 
						|
		expectErr    bool
 | 
						|
		expectCode   codes.Code
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing username config": {
 | 
						|
			db:           fakeDatabase{},
 | 
						|
			req:          &proto.NewUserRequest{},
 | 
						|
			expectedResp: &proto.NewUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.InvalidArgument,
 | 
						|
		},
 | 
						|
		"bad expiration": {
 | 
						|
			db: fakeDatabase{},
 | 
						|
			req: &proto.NewUserRequest{
 | 
						|
				UsernameConfig: &proto.UsernameConfig{
 | 
						|
					DisplayName: "dispname",
 | 
						|
					RoleName:    "rolename",
 | 
						|
				},
 | 
						|
				Expiration: ×tamp.Timestamp{
 | 
						|
					Seconds: invalidExpiration.Unix(),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: &proto.NewUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.InvalidArgument,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				newUserErr: errors.New("new user error"),
 | 
						|
			},
 | 
						|
			req: &proto.NewUserRequest{
 | 
						|
				UsernameConfig: &proto.UsernameConfig{
 | 
						|
					DisplayName: "dispname",
 | 
						|
					RoleName:    "rolename",
 | 
						|
				},
 | 
						|
				Expiration: ptypes.TimestampNow(),
 | 
						|
			},
 | 
						|
			expectedResp: &proto.NewUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.Internal,
 | 
						|
		},
 | 
						|
		"happy path with expiration": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				newUserResp: NewUserResponse{
 | 
						|
					Username: "someuser_foo",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: &proto.NewUserRequest{
 | 
						|
				UsernameConfig: &proto.UsernameConfig{
 | 
						|
					DisplayName: "dispname",
 | 
						|
					RoleName:    "rolename",
 | 
						|
				},
 | 
						|
				Expiration: ptypes.TimestampNow(),
 | 
						|
			},
 | 
						|
			expectedResp: &proto.NewUserResponse{
 | 
						|
				Username: "someuser_foo",
 | 
						|
			},
 | 
						|
			expectErr:  false,
 | 
						|
			expectCode: codes.OK,
 | 
						|
		},
 | 
						|
		"happy path without expiration": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				newUserResp: NewUserResponse{
 | 
						|
					Username: "someuser_foo",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: &proto.NewUserRequest{
 | 
						|
				UsernameConfig: &proto.UsernameConfig{
 | 
						|
					DisplayName: "dispname",
 | 
						|
					RoleName:    "rolename",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: &proto.NewUserResponse{
 | 
						|
				Username: "someuser_foo",
 | 
						|
			},
 | 
						|
			expectErr:  false,
 | 
						|
			expectCode: codes.OK,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			idCtx, g := testGrpcServer(t, test.db)
 | 
						|
			resp, err := g.NewUser(idCtx, test.req)
 | 
						|
 | 
						|
			if test.expectErr && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			}
 | 
						|
			if !test.expectErr && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			actualCode := status.Code(err)
 | 
						|
			if actualCode != test.expectCode {
 | 
						|
				t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
 | 
						|
			}
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCServer_UpdateUser(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		db           Database
 | 
						|
		req          *proto.UpdateUserRequest
 | 
						|
		expectedResp *proto.UpdateUserResponse
 | 
						|
		expectErr    bool
 | 
						|
		expectCode   codes.Code
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing username": {
 | 
						|
			db:           fakeDatabase{},
 | 
						|
			req:          &proto.UpdateUserRequest{},
 | 
						|
			expectedResp: &proto.UpdateUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.InvalidArgument,
 | 
						|
		},
 | 
						|
		"missing changes": {
 | 
						|
			db: fakeDatabase{},
 | 
						|
			req: &proto.UpdateUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
			},
 | 
						|
			expectedResp: &proto.UpdateUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.InvalidArgument,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				updateUserErr: errors.New("update user error"),
 | 
						|
			},
 | 
						|
			req: &proto.UpdateUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
				Password: &proto.ChangePassword{
 | 
						|
					NewPassword: "90ughaino",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: &proto.UpdateUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.Internal,
 | 
						|
		},
 | 
						|
		"bad expiration date": {
 | 
						|
			db: fakeDatabase{},
 | 
						|
			req: &proto.UpdateUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
				Expiration: &proto.ChangeExpiration{
 | 
						|
					NewExpiration: ×tamp.Timestamp{
 | 
						|
						// Before minValidSeconds in ptypes package
 | 
						|
						Seconds: invalidExpiration.Unix(),
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: &proto.UpdateUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.InvalidArgument,
 | 
						|
		},
 | 
						|
		"change password happy path": {
 | 
						|
			db: fakeDatabase{},
 | 
						|
			req: &proto.UpdateUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
				Password: &proto.ChangePassword{
 | 
						|
					NewPassword: "90ughaino",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: &proto.UpdateUserResponse{},
 | 
						|
			expectErr:    false,
 | 
						|
			expectCode:   codes.OK,
 | 
						|
		},
 | 
						|
		"change expiration happy path": {
 | 
						|
			db: fakeDatabase{},
 | 
						|
			req: &proto.UpdateUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
				Expiration: &proto.ChangeExpiration{
 | 
						|
					NewExpiration: ptypes.TimestampNow(),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: &proto.UpdateUserResponse{},
 | 
						|
			expectErr:    false,
 | 
						|
			expectCode:   codes.OK,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			idCtx, g := testGrpcServer(t, test.db)
 | 
						|
			resp, err := g.UpdateUser(idCtx, test.req)
 | 
						|
 | 
						|
			if test.expectErr && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			}
 | 
						|
			if !test.expectErr && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			actualCode := status.Code(err)
 | 
						|
			if actualCode != test.expectCode {
 | 
						|
				t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
 | 
						|
			}
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCServer_DeleteUser(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		db           Database
 | 
						|
		req          *proto.DeleteUserRequest
 | 
						|
		expectedResp *proto.DeleteUserResponse
 | 
						|
		expectErr    bool
 | 
						|
		expectCode   codes.Code
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing username": {
 | 
						|
			db:           fakeDatabase{},
 | 
						|
			req:          &proto.DeleteUserRequest{},
 | 
						|
			expectedResp: &proto.DeleteUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.InvalidArgument,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				deleteUserErr: errors.New("delete user error"),
 | 
						|
			},
 | 
						|
			req: &proto.DeleteUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
			},
 | 
						|
			expectedResp: &proto.DeleteUserResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.Internal,
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			db: fakeDatabase{},
 | 
						|
			req: &proto.DeleteUserRequest{
 | 
						|
				Username: "someuser",
 | 
						|
			},
 | 
						|
			expectedResp: &proto.DeleteUserResponse{},
 | 
						|
			expectErr:    false,
 | 
						|
			expectCode:   codes.OK,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			idCtx, g := testGrpcServer(t, test.db)
 | 
						|
			resp, err := g.DeleteUser(idCtx, test.req)
 | 
						|
 | 
						|
			if test.expectErr && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			}
 | 
						|
			if !test.expectErr && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			actualCode := status.Code(err)
 | 
						|
			if actualCode != test.expectCode {
 | 
						|
				t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
 | 
						|
			}
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCServer_Type(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		db           Database
 | 
						|
		expectedResp *proto.TypeResponse
 | 
						|
		expectErr    bool
 | 
						|
		expectCode   codes.Code
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"database error": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				typeErr: errors.New("type error"),
 | 
						|
			},
 | 
						|
			expectedResp: &proto.TypeResponse{},
 | 
						|
			expectErr:    true,
 | 
						|
			expectCode:   codes.Internal,
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				typeResp: "fake database",
 | 
						|
			},
 | 
						|
			expectedResp: &proto.TypeResponse{
 | 
						|
				Type: "fake database",
 | 
						|
			},
 | 
						|
			expectErr:  false,
 | 
						|
			expectCode: codes.OK,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			idCtx, g := testGrpcServer(t, test.db)
 | 
						|
			resp, err := g.Type(idCtx, &proto.Empty{})
 | 
						|
 | 
						|
			if test.expectErr && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			}
 | 
						|
			if !test.expectErr && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			actualCode := status.Code(err)
 | 
						|
			if actualCode != test.expectCode {
 | 
						|
				t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
 | 
						|
			}
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCServer_Close(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		db            Database
 | 
						|
		expectErr     bool
 | 
						|
		expectCode    codes.Code
 | 
						|
		grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
 | 
						|
		assertFunc    func(t *testing.T, g gRPCServer)
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"database error": {
 | 
						|
			db: fakeDatabase{
 | 
						|
				closeErr: errors.New("close error"),
 | 
						|
			},
 | 
						|
			expectErr:     true,
 | 
						|
			expectCode:    codes.Internal,
 | 
						|
			grpcSetupFunc: testGrpcServer,
 | 
						|
			assertFunc:    nil,
 | 
						|
		},
 | 
						|
		"happy path for multiplexed plugin": {
 | 
						|
			db:            fakeDatabase{},
 | 
						|
			expectErr:     false,
 | 
						|
			expectCode:    codes.OK,
 | 
						|
			grpcSetupFunc: testGrpcServer,
 | 
						|
			assertFunc: func(t *testing.T, g gRPCServer) {
 | 
						|
				if len(g.instances) != 0 {
 | 
						|
					t.Fatalf("err expected instances map to be empty")
 | 
						|
				}
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"happy path for non-multiplexed plugin": {
 | 
						|
			db:            fakeDatabase{},
 | 
						|
			expectErr:     false,
 | 
						|
			expectCode:    codes.OK,
 | 
						|
			grpcSetupFunc: testGrpcServerSingleImpl,
 | 
						|
			assertFunc:    nil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			idCtx, g := test.grpcSetupFunc(t, test.db)
 | 
						|
			_, err := g.Close(idCtx, &proto.Empty{})
 | 
						|
 | 
						|
			if test.expectErr && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			}
 | 
						|
			if !test.expectErr && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			actualCode := status.Code(err)
 | 
						|
			if actualCode != test.expectCode {
 | 
						|
				t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
 | 
						|
			}
 | 
						|
 | 
						|
			if test.assertFunc != nil {
 | 
						|
				test.assertFunc(t, g)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGetMultiplexIDFromContext(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		ctx          context.Context
 | 
						|
		expectedResp string
 | 
						|
		expectedErr  error
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing plugin multiplexing metadata": {
 | 
						|
			ctx:          context.Background(),
 | 
						|
			expectedResp: "",
 | 
						|
			expectedErr:  fmt.Errorf("missing plugin multiplexing metadata"),
 | 
						|
		},
 | 
						|
		"unexpected number of IDs in metadata": {
 | 
						|
			ctx:          idCtx(t, "12345", "67891"),
 | 
						|
			expectedResp: "",
 | 
						|
			expectedErr:  fmt.Errorf("unexpected number of IDs in metadata: (2)"),
 | 
						|
		},
 | 
						|
		"empty multiplex ID in metadata": {
 | 
						|
			ctx:          idCtx(t, ""),
 | 
						|
			expectedResp: "",
 | 
						|
			expectedErr:  fmt.Errorf("empty multiplex ID in metadata"),
 | 
						|
		},
 | 
						|
		"happy path, id is returned from metadata": {
 | 
						|
			ctx:          idCtx(t, "12345"),
 | 
						|
			expectedResp: "12345",
 | 
						|
			expectedErr:  nil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			resp, err := getMultiplexIDFromContext(test.ctx)
 | 
						|
 | 
						|
			if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
 | 
						|
				t.Fatalf("err expected, got nil")
 | 
						|
			} else if !reflect.DeepEqual(err, test.expectedErr) {
 | 
						|
				t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
 | 
						|
			}
 | 
						|
 | 
						|
			if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
 | 
						|
				t.Fatalf("no error expected, got: %s", err)
 | 
						|
			}
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// testGrpcServer is a test helper that returns a context with an ID set in its
 | 
						|
// metadata and a gRPCServer instance for a multiplexed plugin
 | 
						|
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
 | 
						|
	t.Helper()
 | 
						|
	g := gRPCServer{
 | 
						|
		factoryFunc: func() (interface{}, error) {
 | 
						|
			return db, nil
 | 
						|
		},
 | 
						|
		instances: make(map[string]Database),
 | 
						|
	}
 | 
						|
 | 
						|
	id := "12345"
 | 
						|
	idCtx := idCtx(t, id)
 | 
						|
	g.instances[id] = db
 | 
						|
 | 
						|
	return idCtx, g
 | 
						|
}
 | 
						|
 | 
						|
// testGrpcServerSingleImpl is a test helper that returns a context and a
 | 
						|
// gRPCServer instance for a non-multiplexed plugin
 | 
						|
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
 | 
						|
	t.Helper()
 | 
						|
	return context.Background(), gRPCServer{
 | 
						|
		singleImpl: db,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// idCtx is a test helper that will return a context with the IDs set in its
 | 
						|
// metadata
 | 
						|
func idCtx(t *testing.T, ids ...string) context.Context {
 | 
						|
	t.Helper()
 | 
						|
	// Context doesn't need to timeout since this is just passed through
 | 
						|
	ctx := context.Background()
 | 
						|
	md := metadata.MD{}
 | 
						|
	for _, id := range ids {
 | 
						|
		md.Append(pluginutil.MultiplexingCtxKey, id)
 | 
						|
	}
 | 
						|
	return metadata.NewIncomingContext(ctx, md)
 | 
						|
}
 | 
						|
 | 
						|
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
 | 
						|
	t.Helper()
 | 
						|
 | 
						|
	strct, err := mapToStruct(m)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("unable to marshal to protobuf: %s", err)
 | 
						|
	}
 | 
						|
	return strct
 | 
						|
}
 | 
						|
 | 
						|
// badJSONValue is a value whose JSON encoding and decoding always fail. Tests
// put it into a database config to force a JSON marshalling error.
type badJSONValue struct{}

// MarshalJSON always fails, returning no bytes and a fixed error.
func (badJSONValue) MarshalJSON() ([]byte, error) {
	return nil, errors.New("this cannot be marshalled to JSON")
}

// UnmarshalJSON ignores its input and always returns a fixed error.
func (badJSONValue) UnmarshalJSON([]byte) error {
	return errors.New("this cannot be unmarshalled from JSON")
}
 | 
						|
 | 
						|
var _ Database = fakeDatabase{}
 | 
						|
 | 
						|
type fakeDatabase struct {
 | 
						|
	initResp InitializeResponse
 | 
						|
	initErr  error
 | 
						|
 | 
						|
	newUserResp NewUserResponse
 | 
						|
	newUserErr  error
 | 
						|
 | 
						|
	updateUserResp UpdateUserResponse
 | 
						|
	updateUserErr  error
 | 
						|
 | 
						|
	deleteUserResp DeleteUserResponse
 | 
						|
	deleteUserErr  error
 | 
						|
 | 
						|
	typeResp string
 | 
						|
	typeErr  error
 | 
						|
 | 
						|
	closeErr error
 | 
						|
}
 | 
						|
 | 
						|
func (e fakeDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
 | 
						|
	return e.initResp, e.initErr
 | 
						|
}
 | 
						|
 | 
						|
func (e fakeDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
 | 
						|
	return e.newUserResp, e.newUserErr
 | 
						|
}
 | 
						|
 | 
						|
func (e fakeDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
 | 
						|
	return e.updateUserResp, e.updateUserErr
 | 
						|
}
 | 
						|
 | 
						|
func (e fakeDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
 | 
						|
	return e.deleteUserResp, e.deleteUserErr
 | 
						|
}
 | 
						|
 | 
						|
func (e fakeDatabase) Type() (string, error) {
 | 
						|
	return e.typeResp, e.typeErr
 | 
						|
}
 | 
						|
 | 
						|
func (e fakeDatabase) Close() error {
 | 
						|
	return e.closeErr
 | 
						|
}
 | 
						|
 | 
						|
var _ Database = &recordingDatabase{}
 | 
						|
 | 
						|
type recordingDatabase struct {
 | 
						|
	initializeCalls int
 | 
						|
	newUserCalls    int
 | 
						|
	updateUserCalls int
 | 
						|
	deleteUserCalls int
 | 
						|
	typeCalls       int
 | 
						|
	closeCalls      int
 | 
						|
 | 
						|
	// recordingDatabase can act as middleware so we can record the calls to other test Database implementations
 | 
						|
	next Database
 | 
						|
}
 | 
						|
 | 
						|
func (f *recordingDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
 | 
						|
	f.initializeCalls++
 | 
						|
	if f.next == nil {
 | 
						|
		return InitializeResponse{}, nil
 | 
						|
	}
 | 
						|
	return f.next.Initialize(ctx, req)
 | 
						|
}
 | 
						|
 | 
						|
func (f *recordingDatabase) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
 | 
						|
	f.newUserCalls++
 | 
						|
	if f.next == nil {
 | 
						|
		return NewUserResponse{}, nil
 | 
						|
	}
 | 
						|
	return f.next.NewUser(ctx, req)
 | 
						|
}
 | 
						|
 | 
						|
func (f *recordingDatabase) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
 | 
						|
	f.updateUserCalls++
 | 
						|
	if f.next == nil {
 | 
						|
		return UpdateUserResponse{}, nil
 | 
						|
	}
 | 
						|
	return f.next.UpdateUser(ctx, req)
 | 
						|
}
 | 
						|
 | 
						|
func (f *recordingDatabase) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
 | 
						|
	f.deleteUserCalls++
 | 
						|
	if f.next == nil {
 | 
						|
		return DeleteUserResponse{}, nil
 | 
						|
	}
 | 
						|
	return f.next.DeleteUser(ctx, req)
 | 
						|
}
 | 
						|
 | 
						|
func (f *recordingDatabase) Type() (string, error) {
 | 
						|
	f.typeCalls++
 | 
						|
	if f.next == nil {
 | 
						|
		return "recordingDatabase", nil
 | 
						|
	}
 | 
						|
	return f.next.Type()
 | 
						|
}
 | 
						|
 | 
						|
func (f *recordingDatabase) Close() error {
 | 
						|
	f.closeCalls++
 | 
						|
	if f.next == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	return f.next.Close()
 | 
						|
}
 |