refactor database plugin SDK (#29479)

* prepare for enterprise database plugin SDK development
This commit is contained in:
Thy Ton
2025-02-03 08:50:33 -08:00
committed by GitHub
parent cda9ad3491
commit 193796bfc9
16 changed files with 270 additions and 166 deletions

View File

@@ -8,12 +8,13 @@ import (
"fmt"
"time"
"google.golang.org/grpc"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc"
)
// Database is the interface that all database objects must implement.

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
@@ -25,6 +26,7 @@ var (
)
type gRPCClient struct {
entGRPCClient
client proto.DatabaseClient
versionClient logical.PluginVersionClient
doneCtx context.Context
@@ -285,20 +287,6 @@ func (c gRPCClient) Type() (string, error) {
return typeResp.GetType(), nil
}
func (c gRPCClient) Close() error {
ctx, cancel := getContextWithTimeout(pluginutil.PluginGRPCTimeoutClose)
defer cancel()
_, err := c.client.Close(ctx, &proto.Empty{})
if err != nil {
if c.doneCtx.Err() != nil {
return ErrPluginShutdown
}
return err
}
return nil
}
func getContextWithTimeout(env string) (context.Context, context.CancelFunc) {
timeout := 1 // default timeout
if envTimeout, err := strconv.Atoi(os.Getenv(env)); err == nil && envTimeout > 0 {

View File

@@ -0,0 +1,27 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package dbplugin
import (
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)
type entGRPCClient struct{}
func (c gRPCClient) Close() error {
ctx, cancel := getContextWithTimeout(pluginutil.PluginGRPCTimeoutClose)
defer cancel()
_, err := c.client.Close(ctx, &proto.Empty{})
if err != nil {
if c.doneCtx.Err() != nil {
return ErrPluginShutdown
}
return err
}
return nil
}

View File

@@ -0,0 +1,31 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package dbplugin
import (
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
)
var _ proto.DatabaseClient = fakeClient{}
type fakeClient struct {
initResp *proto.InitializeResponse
initErr error
newUserResp *proto.NewUserResponse
newUserErr error
updateUserResp *proto.UpdateUserResponse
updateUserErr error
deleteUserResp *proto.DeleteUserResponse
deleteUserErr error
typeResp *proto.TypeResponse
typeErr error
closeErr error
}

View File

@@ -518,26 +518,7 @@ func assertErrEquals(expectedErr error) errorAssertion {
}
}
var _ proto.DatabaseClient = fakeClient{}
type fakeClient struct {
initResp *proto.InitializeResponse
initErr error
newUserResp *proto.NewUserResponse
newUserErr error
updateUserResp *proto.UpdateUserResponse
updateUserErr error
deleteUserResp *proto.DeleteUserResponse
deleteUserErr error
typeResp *proto.TypeResponse
typeErr error
closeErr error
}
// fakeClient methods
func (f fakeClient) Initialize(context.Context, *proto.InitializeRequest, ...grpc.CallOption) (*proto.InitializeResponse, error) {
return f.initResp, f.initErr

View File

@@ -4,13 +4,7 @@
package dbplugin
import (
"context"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/grpc"
)
// handshakeConfigs are used to just do a basic handshake between
@@ -37,36 +31,3 @@ var (
_ plugin.Plugin = &GRPCDatabasePlugin{}
_ plugin.GRPCPlugin = &GRPCDatabasePlugin{}
)
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
var server gRPCServer
if d.Impl != nil {
server = gRPCServer{singleImpl: d.Impl}
} else {
// multiplexing is supported
server = gRPCServer{
factoryFunc: d.FactoryFunc,
instances: make(map[string]Database),
}
// Multiplexing is enabled for this plugin, register the server so we
// can tell the client in Vault.
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
Supported: true,
})
}
proto.RegisterDatabaseServer(s, &server)
logical.RegisterPluginVersionServer(s, &server)
return nil
}
func (GRPCDatabasePlugin) GRPCClient(doneCtx context.Context, _ *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
client := gRPCClient{
client: proto.NewDatabaseClient(c),
versionClient: logical.NewPluginVersionClient(c),
doneCtx: doneCtx,
}
return client, nil
}

View File

@@ -0,0 +1,56 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package dbplugin
import (
"context"
"google.golang.org/grpc"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
// GRPCClient (Vault CE edition) initializes and returns a gRPCClient with Database and
// PluginVersion gRPC clients. It implements GRPCClient() defined
// by GRPCPlugin interface in go-plugin/plugin.go
func (GRPCDatabasePlugin) GRPCClient(doneCtx context.Context, _ *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
client := gRPCClient{
client: proto.NewDatabaseClient(c),
versionClient: logical.NewPluginVersionClient(c),
doneCtx: doneCtx,
}
return client, nil
}
// GRPCServer (Vault CE edition) registers multiplexing server if the plugin supports it, and
// registers the Database and PluginVersion gRPC servers. It implements GRPCServer() defined
// by GRPCPlugin interface in go-plugin/plugin.go
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
var server gRPCServer
if d.Impl != nil {
server = gRPCServer{singleImpl: d.Impl}
} else {
// multiplexing is supported
server = gRPCServer{
factoryFunc: d.FactoryFunc,
instances: make(map[string]Database),
}
// Multiplexing is enabled for this plugin, register the server so we
// can tell the client in Vault.
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
Supported: true,
})
}
proto.RegisterDatabaseServer(s, &server)
logical.RegisterPluginVersionServer(s, &server)
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/golang/protobuf/ptypes"
@@ -21,21 +20,6 @@ import (
var _ proto.DatabaseServer = &gRPCServer{}
type gRPCServer struct {
proto.UnimplementedDatabaseServer
logical.UnimplementedPluginVersionServer
// holds the non-multiplexed Database
// when this is set the plugin does not support multiplexing
singleImpl Database
// instances holds the multiplexed Databases
instances map[string]Database
factoryFunc func() (interface{}, error)
sync.RWMutex
}
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
g.Lock()
defer g.Unlock()

View File

@@ -0,0 +1,28 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package dbplugin
import (
"sync"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/logical"
)
type gRPCServer struct {
proto.UnimplementedDatabaseServer
logical.UnimplementedPluginVersionServer
// holds the non-multiplexed Database
// when this is set the plugin does not support multiplexing
singleImpl Database
// instances holds the multiplexed Databases
instances map[string]Database
factoryFunc func() (interface{}, error)
sync.RWMutex
}

View File

@@ -0,0 +1,54 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package dbplugin
import (
"github.com/hashicorp/vault/sdk/logical"
)
var _ Database = fakeDatabase{}
type fakeDatabase struct {
initResp InitializeResponse
initErr error
newUserResp NewUserResponse
newUserErr error
updateUserResp UpdateUserResponse
updateUserErr error
deleteUserResp DeleteUserResponse
deleteUserErr error
typeResp string
typeErr error
closeErr error
}
var _ Database = &recordingDatabase{}
type recordingDatabase struct {
initializeCalls int
newUserCalls int
updateUserCalls int
deleteUserCalls int
typeCalls int
closeCalls int
// recordingDatabase can act as middleware so we can record the calls to other test Database implementations
next Database
}
type fakeDatabaseWithVersion struct {
version string
}
var (
_ Database = (*fakeDatabaseWithVersion)(nil)
_ logical.PluginVersioner = (*fakeDatabaseWithVersion)(nil)
)

View File

@@ -706,26 +706,7 @@ func (badJSONValue) UnmarshalJSON([]byte) error {
return fmt.Errorf("this cannot be unmarshalled from JSON")
}
var _ Database = fakeDatabase{}
type fakeDatabase struct {
initResp InitializeResponse
initErr error
newUserResp NewUserResponse
newUserErr error
updateUserResp UpdateUserResponse
updateUserErr error
deleteUserResp DeleteUserResponse
deleteUserErr error
typeResp string
typeErr error
closeErr error
}
// fakeDatabase methods
func (e fakeDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
return e.initResp, e.initErr
@@ -751,19 +732,7 @@ func (e fakeDatabase) Close() error {
return e.closeErr
}
var _ Database = &recordingDatabase{}
type recordingDatabase struct {
initializeCalls int
newUserCalls int
updateUserCalls int
deleteUserCalls int
typeCalls int
closeCalls int
// recordingDatabase can act as middleware so we can record the calls to other test Database implementations
next Database
}
// recordingDatabase methods
func (f *recordingDatabase) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
f.initializeCalls++
@@ -813,9 +782,7 @@ func (f *recordingDatabase) Close() error {
return f.next.Close()
}
type fakeDatabaseWithVersion struct {
version string
}
// fakeDatabaseWithVersion methods
func (e fakeDatabaseWithVersion) PluginVersion() logical.PluginVersion {
return logical.PluginVersion{Version: e.version}

View File

@@ -4,11 +4,7 @@
package dbplugin
import (
"context"
"errors"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -45,38 +41,3 @@ var PluginSets = map[int]plugin.PluginSet{
"database": &GRPCDatabasePlugin{},
},
}
// NewPluginClient returns a databaseRPCClient with a connection to a running
// plugin.
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) {
pluginClient, err := sys.NewPluginClient(ctx, config)
if err != nil {
return nil, err
}
// Request the plugin
raw, err := pluginClient.Dispense("database")
if err != nil {
return nil, err
}
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
var db Database
switch c := raw.(type) {
case gRPCClient:
// This is an abstraction leak from go-plugin but it is necessary in
// order to enable multiplexing on multiplexed plugins
c.client = proto.NewDatabaseClient(pluginClient.Conn())
c.versionClient = logical.NewPluginVersionClient(pluginClient.Conn())
db = c
default:
return nil, errors.New("unsupported client type")
}
return &DatabasePluginClient{
client: pluginClient,
Database: db,
}, nil
}

View File

@@ -0,0 +1,50 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package dbplugin
import (
"context"
"errors"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
// NewPluginClient returns a databaseRPCClient with a connection to a running
// plugin.
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) {
pluginClient, err := sys.NewPluginClient(ctx, config)
if err != nil {
return nil, err
}
// Request the plugin
raw, err := pluginClient.Dispense("database")
if err != nil {
return nil, err
}
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
var db Database
switch c := raw.(type) {
case gRPCClient:
// This is an abstraction leak from go-plugin but it is necessary in
// order to enable multiplexing on multiplexed plugins
c.client = proto.NewDatabaseClient(pluginClient.Conn())
c.versionClient = logical.NewPluginVersionClient(pluginClient.Conn())
db = c
default:
return nil, errors.New("unsupported client type")
}
return &DatabasePluginClient{
client: pluginClient,
Database: db,
}, nil
}

View File

@@ -60,6 +60,8 @@ func PluginFactoryVersion(ctx context.Context, pluginName string, pluginVersion
AutoMTLS: true,
Wrapper: sys,
}
config.EntUpdate(pluginRunner)
// create a DatabasePluginClient instance
db, err = NewPluginClient(ctx, sys, config)
if err != nil {

View File

@@ -30,6 +30,7 @@ const (
)
type PluginClientConfig struct {
EntPluginClientConfig
Name string
PluginType consts.PluginType
Version string

View File

@@ -0,0 +1,12 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
//go:build !enterprise
package pluginutil
type EntPluginClientConfig struct{}
func (p *PluginClientConfig) EntUpdate(_ *PluginRunner) {
// no-op
}