mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			546 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			546 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) HashiCorp, Inc.
 | 
						|
// SPDX-License-Identifier: MPL-2.0
 | 
						|
 | 
						|
package dbplugin
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"reflect"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
 | 
						|
	"google.golang.org/grpc"
 | 
						|
)
 | 
						|
 | 
						|
func TestGRPCClient_Initialize(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		client       proto.DatabaseClient
 | 
						|
		req          InitializeRequest
 | 
						|
		expectedResp InitializeResponse
 | 
						|
		assertErr    errorAssertion
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"bad config": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: InitializeRequest{
 | 
						|
				Config: map[string]interface{}{
 | 
						|
					"foo": badJSONValue{},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			client: fakeClient{
 | 
						|
				initErr: errors.New("initialize error"),
 | 
						|
			},
 | 
						|
			req: InitializeRequest{
 | 
						|
				Config: map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			client: fakeClient{
 | 
						|
				initResp: &proto.InitializeResponse{
 | 
						|
					ConfigData: marshal(t, map[string]interface{}{
 | 
						|
						"foo": "bar",
 | 
						|
						"baz": "biz",
 | 
						|
					}),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: InitializeRequest{
 | 
						|
				Config: map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: InitializeResponse{
 | 
						|
				Config: map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
					"baz": "biz",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
		"JSON number type in initialize request": {
 | 
						|
			client: fakeClient{
 | 
						|
				initResp: &proto.InitializeResponse{
 | 
						|
					ConfigData: marshal(t, map[string]interface{}{
 | 
						|
						"foo": "bar",
 | 
						|
						"max": "10",
 | 
						|
					}),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: InitializeRequest{
 | 
						|
				Config: map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
					"max": json.Number("10"),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expectedResp: InitializeResponse{
 | 
						|
				Config: map[string]interface{}{
 | 
						|
					"foo": "bar",
 | 
						|
					"max": "10",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			c := gRPCClient{
 | 
						|
				client:  test.client,
 | 
						|
				doneCtx: nil,
 | 
						|
			}
 | 
						|
 | 
						|
			// Context doesn't need to timeout since this is just passed through
 | 
						|
			ctx := context.Background()
 | 
						|
 | 
						|
			resp, err := c.Initialize(ctx, test.req)
 | 
						|
			test.assertErr(t, err)
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCClient_NewUser(t *testing.T) {
 | 
						|
	runningCtx := context.Background()
 | 
						|
	cancelledCtx, cancel := context.WithCancel(context.Background())
 | 
						|
	cancel()
 | 
						|
 | 
						|
	type testCase struct {
 | 
						|
		client       proto.DatabaseClient
 | 
						|
		req          NewUserRequest
 | 
						|
		doneCtx      context.Context
 | 
						|
		expectedResp NewUserResponse
 | 
						|
		assertErr    errorAssertion
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing password": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: NewUserRequest{
 | 
						|
				Password:   "",
 | 
						|
				Expiration: time.Now(),
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"bad expiration": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: NewUserRequest{
 | 
						|
				Password:   "njkvcb8y934u90grsnkjl",
 | 
						|
				Expiration: invalidExpiration,
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			client: fakeClient{
 | 
						|
				newUserErr: errors.New("new user error"),
 | 
						|
			},
 | 
						|
			req: NewUserRequest{
 | 
						|
				Password:   "njkvcb8y934u90grsnkjl",
 | 
						|
				Expiration: time.Now(),
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"plugin shut down": {
 | 
						|
			client: fakeClient{
 | 
						|
				newUserErr: errors.New("new user error"),
 | 
						|
			},
 | 
						|
			req: NewUserRequest{
 | 
						|
				Password:   "njkvcb8y934u90grsnkjl",
 | 
						|
				Expiration: time.Now(),
 | 
						|
			},
 | 
						|
			doneCtx:   cancelledCtx,
 | 
						|
			assertErr: assertErrEquals(ErrPluginShutdown),
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			client: fakeClient{
 | 
						|
				newUserResp: &proto.NewUserResponse{
 | 
						|
					Username: "new_user",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			req: NewUserRequest{
 | 
						|
				Password:   "njkvcb8y934u90grsnkjl",
 | 
						|
				Expiration: time.Now(),
 | 
						|
			},
 | 
						|
			doneCtx: runningCtx,
 | 
						|
			expectedResp: NewUserResponse{
 | 
						|
				Username: "new_user",
 | 
						|
			},
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			c := gRPCClient{
 | 
						|
				client:  test.client,
 | 
						|
				doneCtx: test.doneCtx,
 | 
						|
			}
 | 
						|
 | 
						|
			ctx := context.Background()
 | 
						|
 | 
						|
			resp, err := c.NewUser(ctx, test.req)
 | 
						|
			test.assertErr(t, err)
 | 
						|
 | 
						|
			if !reflect.DeepEqual(resp, test.expectedResp) {
 | 
						|
				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCClient_UpdateUser(t *testing.T) {
 | 
						|
	runningCtx := context.Background()
 | 
						|
	cancelledCtx, cancel := context.WithCancel(context.Background())
 | 
						|
	cancel()
 | 
						|
 | 
						|
	type testCase struct {
 | 
						|
		client    proto.DatabaseClient
 | 
						|
		req       UpdateUserRequest
 | 
						|
		doneCtx   context.Context
 | 
						|
		assertErr errorAssertion
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing username": {
 | 
						|
			client:    fakeClient{},
 | 
						|
			req:       UpdateUserRequest{},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"missing changes": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"empty password": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Password: &ChangePassword{
 | 
						|
					NewPassword: "",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"zero expiration": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Expiration: &ChangeExpiration{
 | 
						|
					NewExpiration: time.Time{},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"bad expiration": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Expiration: &ChangeExpiration{
 | 
						|
					NewExpiration: invalidExpiration,
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			client: fakeClient{
 | 
						|
				updateUserErr: errors.New("update user error"),
 | 
						|
			},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Password: &ChangePassword{
 | 
						|
					NewPassword: "asdf",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"plugin shut down": {
 | 
						|
			client: fakeClient{
 | 
						|
				updateUserErr: errors.New("update user error"),
 | 
						|
			},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Password: &ChangePassword{
 | 
						|
					NewPassword: "asdf",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   cancelledCtx,
 | 
						|
			assertErr: assertErrEquals(ErrPluginShutdown),
 | 
						|
		},
 | 
						|
		"happy path - change password": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Password: &ChangePassword{
 | 
						|
					NewPassword: "asdf",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
		"happy path - change expiration": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: UpdateUserRequest{
 | 
						|
				Username: "user",
 | 
						|
				Expiration: &ChangeExpiration{
 | 
						|
					NewExpiration: time.Now(),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			c := gRPCClient{
 | 
						|
				client:  test.client,
 | 
						|
				doneCtx: test.doneCtx,
 | 
						|
			}
 | 
						|
 | 
						|
			ctx := context.Background()
 | 
						|
 | 
						|
			_, err := c.UpdateUser(ctx, test.req)
 | 
						|
			test.assertErr(t, err)
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCClient_DeleteUser(t *testing.T) {
 | 
						|
	runningCtx := context.Background()
 | 
						|
	cancelledCtx, cancel := context.WithCancel(context.Background())
 | 
						|
	cancel()
 | 
						|
 | 
						|
	type testCase struct {
 | 
						|
		client    proto.DatabaseClient
 | 
						|
		req       DeleteUserRequest
 | 
						|
		doneCtx   context.Context
 | 
						|
		assertErr errorAssertion
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"missing username": {
 | 
						|
			client:    fakeClient{},
 | 
						|
			req:       DeleteUserRequest{},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"database error": {
 | 
						|
			client: fakeClient{
 | 
						|
				deleteUserErr: errors.New("delete user error'"),
 | 
						|
			},
 | 
						|
			req: DeleteUserRequest{
 | 
						|
				Username: "user",
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"plugin shut down": {
 | 
						|
			client: fakeClient{
 | 
						|
				deleteUserErr: errors.New("delete user error'"),
 | 
						|
			},
 | 
						|
			req: DeleteUserRequest{
 | 
						|
				Username: "user",
 | 
						|
			},
 | 
						|
			doneCtx:   cancelledCtx,
 | 
						|
			assertErr: assertErrEquals(ErrPluginShutdown),
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			client: fakeClient{},
 | 
						|
			req: DeleteUserRequest{
 | 
						|
				Username: "user",
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			c := gRPCClient{
 | 
						|
				client:  test.client,
 | 
						|
				doneCtx: test.doneCtx,
 | 
						|
			}
 | 
						|
 | 
						|
			ctx := context.Background()
 | 
						|
 | 
						|
			_, err := c.DeleteUser(ctx, test.req)
 | 
						|
			test.assertErr(t, err)
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCClient_Type(t *testing.T) {
 | 
						|
	runningCtx := context.Background()
 | 
						|
	cancelledCtx, cancel := context.WithCancel(context.Background())
 | 
						|
	cancel()
 | 
						|
 | 
						|
	type testCase struct {
 | 
						|
		client       proto.DatabaseClient
 | 
						|
		doneCtx      context.Context
 | 
						|
		expectedType string
 | 
						|
		assertErr    errorAssertion
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"database error": {
 | 
						|
			client: fakeClient{
 | 
						|
				typeErr: errors.New("type error"),
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"plugin shut down": {
 | 
						|
			client: fakeClient{
 | 
						|
				typeErr: errors.New("type error"),
 | 
						|
			},
 | 
						|
			doneCtx:   cancelledCtx,
 | 
						|
			assertErr: assertErrEquals(ErrPluginShutdown),
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			client: fakeClient{
 | 
						|
				typeResp: &proto.TypeResponse{
 | 
						|
					Type: "test type",
 | 
						|
				},
 | 
						|
			},
 | 
						|
			doneCtx:      runningCtx,
 | 
						|
			expectedType: "test type",
 | 
						|
			assertErr:    assertErrNil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			c := gRPCClient{
 | 
						|
				client:  test.client,
 | 
						|
				doneCtx: test.doneCtx,
 | 
						|
			}
 | 
						|
 | 
						|
			dbType, err := c.Type()
 | 
						|
			test.assertErr(t, err)
 | 
						|
 | 
						|
			if dbType != test.expectedType {
 | 
						|
				t.Fatalf("Actual type: %s Expected type: %s", dbType, test.expectedType)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestGRPCClient_Close(t *testing.T) {
 | 
						|
	runningCtx := context.Background()
 | 
						|
	cancelledCtx, cancel := context.WithCancel(context.Background())
 | 
						|
	cancel()
 | 
						|
 | 
						|
	type testCase struct {
 | 
						|
		client    proto.DatabaseClient
 | 
						|
		doneCtx   context.Context
 | 
						|
		assertErr errorAssertion
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"database error": {
 | 
						|
			client: fakeClient{
 | 
						|
				typeErr: errors.New("type error"),
 | 
						|
			},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNotNil,
 | 
						|
		},
 | 
						|
		"plugin shut down": {
 | 
						|
			client: fakeClient{
 | 
						|
				typeErr: errors.New("type error"),
 | 
						|
			},
 | 
						|
			doneCtx:   cancelledCtx,
 | 
						|
			assertErr: assertErrEquals(ErrPluginShutdown),
 | 
						|
		},
 | 
						|
		"happy path": {
 | 
						|
			client:    fakeClient{},
 | 
						|
			doneCtx:   runningCtx,
 | 
						|
			assertErr: assertErrNil,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			c := gRPCClient{
 | 
						|
				client:  test.client,
 | 
						|
				doneCtx: test.doneCtx,
 | 
						|
			}
 | 
						|
 | 
						|
			err := c.Close()
 | 
						|
			test.assertErr(t, err)
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// errorAssertion inspects the error returned by the code under test and fails
// t when it is not the outcome the test case expects.
type errorAssertion func(t *testing.T, err error)
 | 
						|
 | 
						|
// assertErrNotNil is an errorAssertion that fails the test immediately unless
// an error was returned.
func assertErrNotNil(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		return
	}
	t.Fatalf("err expected, got nil")
}
 | 
						|
 | 
						|
// assertErrNil is an errorAssertion that fails the test immediately if an
// error was returned. The failure message includes the error's text.
func assertErrNil(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("no error expected, got: %s", err)
}
 | 
						|
 | 
						|
func assertErrEquals(expectedErr error) errorAssertion {
 | 
						|
	return func(t *testing.T, err error) {
 | 
						|
		t.Helper()
 | 
						|
		if err != expectedErr {
 | 
						|
			t.Fatalf("Actual err: %#v Expected err: %#v", err, expectedErr)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
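// fakeClient methods: a stub proto.DatabaseClient whose results come from fields
// set on the fakeClient value. The gRPCClient tests above use it in place of a
// real plugin connection.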
 | 
						|
 | 
						|
func (f fakeClient) Initialize(context.Context, *proto.InitializeRequest, ...grpc.CallOption) (*proto.InitializeResponse, error) {
 | 
						|
	return f.initResp, f.initErr
 | 
						|
}
 | 
						|
 | 
						|
func (f fakeClient) NewUser(context.Context, *proto.NewUserRequest, ...grpc.CallOption) (*proto.NewUserResponse, error) {
 | 
						|
	return f.newUserResp, f.newUserErr
 | 
						|
}
 | 
						|
 | 
						|
func (f fakeClient) UpdateUser(context.Context, *proto.UpdateUserRequest, ...grpc.CallOption) (*proto.UpdateUserResponse, error) {
 | 
						|
	return f.updateUserResp, f.updateUserErr
 | 
						|
}
 | 
						|
 | 
						|
func (f fakeClient) DeleteUser(context.Context, *proto.DeleteUserRequest, ...grpc.CallOption) (*proto.DeleteUserResponse, error) {
 | 
						|
	return f.deleteUserResp, f.deleteUserErr
 | 
						|
}
 | 
						|
 | 
						|
func (f fakeClient) Type(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.TypeResponse, error) {
 | 
						|
	return f.typeResp, f.typeErr
 | 
						|
}
 | 
						|
 | 
						|
func (f fakeClient) Close(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.Empty, error) {
 | 
						|
	return &proto.Empty{}, f.typeErr
 | 
						|
}
 |