Add version pinning to plugin catalog (#24960)

Adds the ability to pin a version for a specific plugin type + name to enable an easier plugin upgrade UX. After pinning and reloading, that version should be the only version in use.

No HTTP API implementation yet for managing pins, so no user-facing effects yet.
This commit is contained in:
Tom Proctor
2024-01-26 17:21:43 +00:00
committed by GitHub
parent 55d5880857
commit af27ab3524
17 changed files with 693 additions and 186 deletions

View File

@@ -298,7 +298,17 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
return nil, err return nil, err
} }
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger) // Override the configured version if there is a pinned version.
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return nil, err
}
pluginVersion := config.PluginVersion
if pinnedVersion != "" {
pluginVersion = pinnedVersion
}
dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err) return nil, fmt.Errorf("unable to create database instance: %w", err)
} }

View File

@@ -436,58 +436,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyPluginName), nil return logical.ErrorResponse(respErrEmptyPluginName), nil
} }
if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok { pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation)
config.PluginVersion = pluginVersionRaw.(string) if respErr != nil || err != nil {
} return respErr, err
var builtinShadowed bool
if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin {
builtinShadowed = true
}
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}
// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()
if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) {
if builtinShadowed {
return logical.ErrorResponse("database plugin %q, version %s not found, as it is"+
" overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil
}
config.PluginVersion = ""
}
case builtinShadowed:
// We'll select the unversioned plugin that's been registered.
case req.Operation == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return nil, err
}
var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}
if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})
config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
} }
if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok { if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok {
@@ -536,7 +487,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
} }
// Create a database plugin and initialize it. // Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger) dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil { if err != nil {
return logical.ErrorResponse("error creating database object: %s", err), nil return logical.ErrorResponse("error creating database object: %s", err), nil
} }
@@ -613,6 +564,92 @@ func storeConfig(ctx context.Context, storage logical.Storage, name string, conf
return nil return nil
} }
func (b *databaseBackend) getPinnedVersion(ctx context.Context, pluginName string) (string, error) {
extendedSys, ok := b.System().(logical.ExtendedSystemView)
if !ok {
return "", fmt.Errorf("database backend does not support running as an external plugin")
}
pin, err := extendedSys.GetPinnedPluginVersion(ctx, consts.PluginTypeDatabase, pluginName)
if errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return "", nil
}
if err != nil {
return "", err
}
return pin.Version, nil
}
func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) {
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return "", nil, err
}
pluginVersionRaw, ok := data.GetOk("plugin_version")
switch {
case ok && pinnedVersion != "":
return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (v%s)", config.PluginName, pinnedVersion), nil
case pinnedVersion != "":
return pinnedVersion, nil, nil
case ok:
config.PluginVersion = pluginVersionRaw.(string)
}
var builtinShadowed bool
if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin {
builtinShadowed = true
}
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}
// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()
if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) {
if builtinShadowed {
return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+
" overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil
}
config.PluginVersion = ""
}
case builtinShadowed:
// We'll select the unversioned plugin that's been registered.
case op == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return "", nil, err
}
var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}
if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})
config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
}
return config.PluginVersion, nil, nil
}
const pathConfigConnectionHelpSyn = ` const pathConfigConnectionHelpSyn = `
Configure connection details to a database plugin. Configure connection details to a database plugin.
` `

View File

@@ -5,6 +5,7 @@ package pluginutil
import ( import (
"context" "context"
"errors"
"strings" "strings"
"time" "time"
@@ -17,6 +18,9 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
// ErrPluginNotFound is returned when a plugin does not have a pinned version.
var ErrPinnedVersionNotFound = errors.New("pinned version not found")
// Looker defines the plugin Lookup function that looks into the plugin catalog // Looker defines the plugin Lookup function that looks into the plugin catalog
// for available plugins and returns a PluginRunner // for available plugins and returns a PluginRunner
type Looker interface { type Looker interface {
@@ -144,6 +148,12 @@ type VersionedPlugin struct {
SemanticVersion *version.Version `json:"-"` SemanticVersion *version.Version `json:"-"`
} }
type PinnedVersion struct {
Name string `json:"name"`
Type consts.PluginType `json:"type"`
Version string `json:"version"`
}
// CtxCancelIfCanceled takes a context cancel func and a context. If the context is // CtxCancelIfCanceled takes a context cancel func and a context. If the context is
// shutdown the cancelfunc is called. This is useful for merging two cancel // shutdown the cancelfunc is called. This is useful for merging two cancel
// functions. // functions.

View File

@@ -127,6 +127,9 @@ type ExtendedSystemView interface {
// APILockShouldBlockRequest returns whether a namespace for the requested // APILockShouldBlockRequest returns whether a namespace for the requested
// mount is locked and should be blocked // mount is locked and should be blocked
APILockShouldBlockRequest() (bool, error) APILockShouldBlockRequest() (bool, error)
// GetPinnedPluginVersion returns the pinned version for the given plugin, if any.
GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error)
} }
type PasswordGenerator func() (password string, err error) type PasswordGenerator func() (password string, err error)

View File

@@ -175,7 +175,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
var backend logical.Backend var backend logical.Backend
// Create the new backend // Create the new backend
sysView := c.mountEntrySysView(entry) sysView := c.mountEntrySysView(entry)
backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view) backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil { if err != nil {
return err return err
} }
@@ -188,14 +188,6 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
if backendType != logical.TypeCredential { if backendType != logical.TypeCredential {
return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Type, backendType) return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Type, backendType)
} }
// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
}
}
addPathCheckers(c, entry, backend, viewPath) addPathCheckers(c, entry, backend, viewPath)
// If the mount is filtered or we are on a DR secondary we don't want to // If the mount is filtered or we are on a DR secondary we don't want to
@@ -249,7 +241,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
} }
if c.logger.IsInfo() { if c.logger.IsInfo() {
c.logger.Info("enabled credential backend", "path", entry.Path, "type", entry.Type, "version", entry.Version) c.logger.Info("enabled credential backend", "path", entry.Path, "type", entry.Type, "version", entry.RunningVersion)
} }
return nil return nil
} }
@@ -805,29 +797,24 @@ func (c *Core) setupCredentials(ctx context.Context) error {
// Initialize the backend // Initialize the backend
sysView := c.mountEntrySysView(entry) sysView := c.mountEntrySysView(entry)
backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view) backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil { if err != nil {
c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err) c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err)
if c.isMountable(ctx, entry, consts.PluginTypeCredential) { mountable, checkErr := c.isMountable(ctx, entry, consts.PluginTypeSecrets)
if checkErr != nil {
return errors.Join(errLoadMountsFailed, checkErr, err)
}
if mountable {
c.logger.Warn("skipping plugin-based auth entry", "path", entry.Path) c.logger.Warn("skipping plugin-based auth entry", "path", entry.Path)
goto ROUTER_MOUNT goto ROUTER_MOUNT
} }
return errLoadAuthFailed return errors.Join(errLoadAuthFailed, err)
} }
if backend == nil { if backend == nil {
return fmt.Errorf("nil backend returned from %q factory", entry.Type) return fmt.Errorf("nil backend returned from %q factory", entry.Type)
} }
// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
}
}
// Do not start up deprecated builtin plugins. If this is a major // Do not start up deprecated builtin plugins. If this is a major
// upgrade, stop unsealing and shutdown. If we've already mounted this // upgrade, stop unsealing and shutdown. If we've already mounted this
// plugin, skip backend initialization and mount the data for posterity. // plugin, skip backend initialization and mount the data for posterity.
@@ -952,34 +939,37 @@ func (c *Core) teardownCredentials(ctx context.Context) error {
} }
// newCredentialBackend is used to create and configure a new credential backend by name. // newCredentialBackend is used to create and configure a new credential backend by name.
// It also returns the SHA256 of the plugin, if available. func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
t := entry.Type t := entry.Type
if alias, ok := credentialAliases[t]; ok { if alias, ok := credentialAliases[t]; ok {
t = alias t = alias
} }
var runningSha string pluginVersion, err := c.resolveMountEntryVersion(ctx, consts.PluginTypeCredential, entry)
f, ok := c.credentialBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version)
if err != nil { if err != nil {
return nil, "", err return nil, err
}
var runningSha string
factory, ok := c.credentialBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, pluginVersion)
if err != nil {
return nil, err
} }
if plug == nil { if plug == nil {
errContext := t errContext := t
if entry.Version != "" { if pluginVersion != "" {
errContext += fmt.Sprintf(", version=%s", entry.Version) errContext += fmt.Sprintf(", version=%s", pluginVersion)
} }
return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) return nil, fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext)
} }
if len(plug.Sha256) > 0 { if len(plug.Sha256) > 0 {
runningSha = hex.EncodeToString(plug.Sha256) runningSha = hex.EncodeToString(plug.Sha256)
} }
f = plugin.Factory factory = plugin.Factory
if !plug.Builtin { if !plug.Builtin {
f = wrapFactoryCheckPerms(c, plugin.Factory) factory = wrapFactoryCheckPerms(c, plugin.Factory)
} }
} }
// Set up conf to pass in plugin_name // Set up conf to pass in plugin_name
@@ -996,7 +986,7 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
} }
conf["plugin_type"] = consts.PluginTypeCredential.String() conf["plugin_type"] = consts.PluginTypeCredential.String()
conf["plugin_version"] = entry.Version conf["plugin_version"] = pluginVersion
authLogger := c.baseLogger.Named(fmt.Sprintf("auth.%s.%s", t, entry.Accessor)) authLogger := c.baseLogger.Named(fmt.Sprintf("auth.%s.%s", t, entry.Accessor))
c.AddLogger(authLogger) c.AddLogger(authLogger)
@@ -1005,11 +995,11 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
MountAccessor: entry.Accessor, MountAccessor: entry.Accessor,
MountPath: entry.Path, MountPath: entry.Path,
Plugin: entry.Type, Plugin: entry.Type,
PluginVersion: entry.RunningVersion, PluginVersion: pluginVersion,
Version: entry.Version, Version: entry.Options["version"],
}) })
if err != nil { if err != nil {
return nil, "", err return nil, err
} }
config := &logical.BackendConfig{ config := &logical.BackendConfig{
@@ -1021,12 +1011,19 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
EventsSender: pluginEventSender, EventsSender: pluginEventSender,
} }
b, err := f(ctx, config) backend, err := factory(ctx, config)
if err != nil { if err != nil {
return nil, "", err return nil, err
}
if backend != nil {
entry.RunningVersion = pluginVersion
entry.RunningSha256 = runningSha
if entry.RunningVersion == "" && entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
}
} }
return b, runningSha, nil return backend, nil
} }
func wrapFactoryCheckPerms(core *Core, f logical.Factory) logical.Factory { func wrapFactoryCheckPerms(core *Core, f logical.Factory) logical.Factory {

View File

@@ -3546,16 +3546,17 @@ func (c *Core) readFeatureFlags(ctx context.Context) (*FeatureFlags, error) {
// misconfigured. This allows users to recover from errors when starting Vault // misconfigured. This allows users to recover from errors when starting Vault
// with misconfigured plugins. It should not be possible for existing builtins // with misconfigured plugins. It should not be possible for existing builtins
// to be misconfigured, so that is a fatal error. // to be misconfigured, so that is a fatal error.
func (c *Core) isMountable(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) bool { func (c *Core) isMountable(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) (bool, error) {
return !c.isMountEntryBuiltin(ctx, entry, pluginType) builtin, err := c.isMountEntryBuiltin(ctx, entry, pluginType)
return !builtin, err
} }
// isMountEntryBuiltin determines whether a mount entry is associated with a // isMountEntryBuiltin determines whether a mount entry is associated with a
// builtin of the specified plugin type. // builtin of the specified plugin type.
func (c *Core) isMountEntryBuiltin(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) bool { func (c *Core) isMountEntryBuiltin(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) (bool, error) {
// Prevent a panic early on // Prevent a panic early on
if entry == nil || c.pluginCatalog == nil { if entry == nil || c.pluginCatalog == nil {
return false return false, nil
} }
// Allow type to be determined from mount entry when not otherwise specified // Allow type to be determined from mount entry when not otherwise specified
@@ -3569,12 +3570,16 @@ func (c *Core) isMountEntryBuiltin(ctx context.Context, entry *MountEntry, plugi
pluginName = alias pluginName = alias
} }
plug, err := c.pluginCatalog.Get(ctx, pluginName, pluginType, entry.Version) pluginVersion, err := c.resolveMountEntryVersion(ctx, pluginType, entry)
if err != nil {
return false, err
}
plug, err := c.pluginCatalog.Get(ctx, pluginName, pluginType, pluginVersion)
if err != nil || plug == nil { if err != nil || plug == nil {
return false return false, nil
} }
return plug.Builtin return plug.Builtin, nil
} }
// MatchingMount returns the path of the mount that will be responsible for // MatchingMount returns the path of the mount that will be responsible for

View File

@@ -161,6 +161,11 @@ func (e extendedSystemViewImpl) DeregisterWellKnownRedirect(ctx context.Context,
return e.core.WellKnownRedirects.DeregisterSource(e.mountEntry.UUID, src) return e.core.WellKnownRedirects.DeregisterSource(e.mountEntry.UUID, src)
} }
// GetPinnedPluginVersion implements logical.ExtendedSystemView.
func (e extendedSystemViewImpl) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
return e.core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName)
}
func (d dynamicSystemView) DefaultLeaseTTL() time.Duration { func (d dynamicSystemView) DefaultLeaseTTL() time.Duration {
def, _ := d.fetchTTLs() def, _ := d.fetchTTLs()
return def return def

View File

@@ -17,6 +17,7 @@ import (
"github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/sdk/plugin"
"github.com/hashicorp/vault/sdk/plugin/mock" "github.com/hashicorp/vault/sdk/plugin/mock"
@@ -95,6 +96,98 @@ func TestCore_EnableExternalPlugin(t *testing.T) {
} }
} }
// TestCore_UpgradePluginUsingPinnedVersion tests a full workflow of upgrading
// an external plugin gated by pinned versions.
func TestCore_UpgradePluginUsingPinnedVersion(t *testing.T) {
cluster := NewTestCluster(t, &CoreConfig{}, &TestClusterOptions{
Plugins: []*TestPluginConfig{
{
Typ: consts.PluginTypeCredential,
Versions: []string{""},
},
{
Typ: consts.PluginTypeSecrets,
Versions: []string{""},
},
},
})
cluster.Start()
t.Cleanup(cluster.Cleanup)
c := cluster.Cores[0].Core
TestWaitActive(t, c)
for name, tc := range map[string]struct {
idx int
}{
"credential plugin": {
idx: 0,
},
"secrets plugin": {
idx: 1,
},
} {
t.Run(name, func(t *testing.T) {
plugin := cluster.Plugins[tc.idx]
for _, version := range []string{"v1.0.0", "v1.0.1"} {
registerPlugin(t, c.systemBackend, plugin.Name, plugin.Typ.String(), version, plugin.Sha256, plugin.FileName)
}
// Mount 1.0.0 then pin to 1.0.1
mountPlugin(t, c.systemBackend, plugin.Name, plugin.Typ, "v1.0.0", "")
err := c.pluginCatalog.SetPinnedVersion(context.Background(), &pluginutil.PinnedVersion{
Name: plugin.Name,
Type: plugin.Typ,
Version: "v1.0.1",
})
if err != nil {
t.Fatal(err)
}
mountedPath := "foo/"
if plugin.Typ == consts.PluginTypeCredential {
mountedPath = "auth/" + mountedPath
}
expectRunningVersion(t, c, mountedPath, "v1.0.0")
reloaded, err := c.reloadMatchingPlugin(context.Background(), nil, plugin.Typ, plugin.Name)
if reloaded != 1 || err != nil {
t.Fatal(reloaded, err)
}
// Pinned version should be in effect after reloading.
expectRunningVersion(t, c, mountedPath, "v1.0.1")
err = c.pluginCatalog.DeletePinnedVersion(context.Background(), plugin.Typ, plugin.Name)
if err != nil {
t.Fatal(err)
}
reloaded, err = c.reloadMatchingPlugin(context.Background(), nil, plugin.Typ, plugin.Name)
if reloaded != 1 || err != nil {
t.Fatal(reloaded, err)
}
// After pin is deleted, the previously configured version should stand.
expectRunningVersion(t, c, mountedPath, "v1.0.0")
})
}
}
func expectRunningVersion(t *testing.T, c *Core, path, expectedVersion string) {
t.Helper()
match := c.router.MatchingMount(namespace.RootContext(context.Background()), path)
if match != path {
t.Fatalf("missing mount for %s, match: %q", path, match)
}
raw, _ := c.router.root.Get(match)
if actual := raw.(*routeEntry).mountEntry.RunningVersion; expectedVersion != actual {
t.Fatalf("expected running_plugin_version to be %s but got %s", expectedVersion, actual)
}
}
func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
for name, tc := range map[string]struct { for name, tc := range map[string]struct {
pluginType consts.PluginType pluginType consts.PluginType

View File

@@ -691,7 +691,7 @@ func TestIdentityStore_LoadingEntities(t *testing.T) {
ghSysview := c.mountEntrySysView(meGH) ghSysview := c.mountEntrySysView(meGH)
// Create new github auth credential backend // Create new github auth credential backend
ghAuth, _, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) ghAuth, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1535,7 +1535,11 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d
Version: pluginVersion, Version: pluginVersion,
} }
if b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeSecrets) { builtin, err := b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeSecrets)
if err != nil {
return nil, err
}
if builtin {
resp, err = b.Core.handleDeprecatedMountEntry(ctx, me, consts.PluginTypeSecrets) resp, err = b.Core.handleDeprecatedMountEntry(ctx, me, consts.PluginTypeSecrets)
if err != nil { if err != nil {
b.Core.logger.Error("could not mount builtin", "name", me.Type, "path", me.Path, "error", err) b.Core.logger.Error("could not mount builtin", "name", me.Type, "path", me.Path, "error", err)
@@ -1949,7 +1953,8 @@ func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (
resp.Data["external_entropy_access"] = true resp.Data["external_entropy_access"] = true
} }
if mountEntry.Table == credentialTableType { isAuth := mountEntry.Table == credentialTableType
if isAuth {
resp.Data["token_type"] = mountEntry.Config.TokenType.String() resp.Data["token_type"] = mountEntry.Config.TokenType.String()
} }
@@ -1995,6 +2000,19 @@ func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (
if mountEntry.Version != "" { if mountEntry.Version != "" {
resp.Data["plugin_version"] = mountEntry.Version resp.Data["plugin_version"] = mountEntry.Version
} }
var pinnedVersion *pluginutil.PinnedVersion
var err error
if isAuth {
pinnedVersion, err = b.Core.pluginCatalog.GetPinnedVersion(ctx, consts.PluginTypeCredential, mountEntry.Type)
} else {
pinnedVersion, err = b.Core.pluginCatalog.GetPinnedVersion(ctx, consts.PluginTypeSecrets, mountEntry.Type)
}
if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return nil, err
}
if pinnedVersion != nil && mountEntry.Version != pinnedVersion.Version {
resp.AddWarning(fmt.Sprintf("plugin_version is configured as %s but a version pin for %s is in effect", mountEntry.Version, pinnedVersion.Version))
}
return resp, nil return resp, nil
} }
@@ -2236,6 +2254,19 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
} }
if rawVal, ok := data.GetOk("plugin_version"); ok { if rawVal, ok := data.GetOk("plugin_version"); ok {
pluginType := consts.PluginTypeSecrets
if strings.HasPrefix(path, "auth/") {
pluginType = consts.PluginTypeCredential
}
pinnedVersion, err := b.Core.pluginCatalog.GetPinnedVersion(ctx, pluginType, mountEntry.Type)
if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return nil, err
}
if pinnedVersion != nil {
return logical.ErrorResponse(fmt.Sprintf("plugin_version cannot be set for %s plugin %q as a pinned version %s is in effect", pluginType, mountEntry.Type, pinnedVersion.Version)), nil
}
version := rawVal.(string) version := rawVal.(string)
semanticVersion, err := semver.NewVersion(version) semanticVersion, err := semver.NewVersion(version)
if err != nil { if err != nil {
@@ -2244,10 +2275,6 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
version = "v" + semanticVersion.String() version = "v" + semanticVersion.String()
// Lookup the version to ensure it exists in the catalog before committing. // Lookup the version to ensure it exists in the catalog before committing.
pluginType := consts.PluginTypeSecrets
if strings.HasPrefix(path, "auth/") {
pluginType = consts.PluginTypeCredential
}
_, err = b.System().LookupPluginVersion(ctx, mountEntry.Type, pluginType, version) _, err = b.System().LookupPluginVersion(ctx, mountEntry.Type, pluginType, version)
if err != nil { if err != nil {
return handleError(err) return handleError(err)
@@ -3106,7 +3133,11 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque
} }
var resp *logical.Response var resp *logical.Response
if b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeCredential) { builtin, err := b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeCredential)
if err != nil {
return nil, err
}
if builtin {
resp, err = b.Core.handleDeprecatedMountEntry(ctx, me, consts.PluginTypeCredential) resp, err = b.Core.handleDeprecatedMountEntry(ctx, me, consts.PluginTypeCredential)
if err != nil { if err != nil {
b.Core.logger.Error("could not mount builtin", "name", me.Type, "path", me.Path, "error", err) b.Core.logger.Error("could not mount builtin", "name", me.Type, "path", me.Path, "error", err)
@@ -3123,6 +3154,18 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque
} }
func (b *SystemBackend) validateVersion(ctx context.Context, version string, pluginName string, pluginType consts.PluginType) (string, *logical.Response, error) { func (b *SystemBackend) validateVersion(ctx context.Context, version string, pluginName string, pluginType consts.PluginType) (string, *logical.Response, error) {
pinnedVersion, err := b.Core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName)
if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return "", nil, err
}
if pinnedVersion != nil {
if version != "" {
return "", logical.ErrorResponse("cannot specify plugin_version for %s plugin %q, as it is pinned to version %s", pluginType.String(), pluginName, pinnedVersion.Version), nil
}
return pinnedVersion.Version, nil, nil
}
switch version { switch version {
case "": case "":
var err error var err error

View File

@@ -23,6 +23,7 @@ import (
"github.com/hashicorp/vault/helper/versions" "github.com/hashicorp/vault/helper/versions"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/vault/plugincatalog"
"github.com/mitchellh/copystructure" "github.com/mitchellh/copystructure"
@@ -347,8 +348,8 @@ type MountEntry struct {
synthesizedConfigCache sync.Map synthesizedConfigCache sync.Map
// version info // version info
Version string `json:"plugin_version,omitempty"` // The semantic version of the mounted plugin, e.g. v1.2.3. Version string `json:"plugin_version,omitempty"` // The configured semantic version of the mounted plugin, e.g. v1.2.3. May be overridden by a pinned version.
RunningVersion string `json:"running_plugin_version,omitempty"` // The semantic version of the mounted plugin as reported by the plugin. RunningVersion string `json:"running_plugin_version,omitempty"` // The semantic version of the currently running mounted plugin.
RunningSha256 string `json:"running_sha256,omitempty"` RunningSha256 string `json:"running_sha256,omitempty"`
} }
@@ -703,13 +704,10 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora
var backend logical.Backend var backend logical.Backend
sysView := c.mountEntrySysView(entry) sysView := c.mountEntrySysView(entry)
backend, entry.RunningSha256, err = c.newLogicalBackend(ctx, entry, sysView, view) backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
if err != nil { if err != nil {
return err return err
} }
if backend == nil {
return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
}
// Check for the correct backend type // Check for the correct backend type
backendType := backend.Type() backendType := backend.Type()
@@ -719,15 +717,6 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora
} }
} }
// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type)
}
}
addPathCheckers(c, entry, backend, viewPath) addPathCheckers(c, entry, backend, viewPath)
c.setCoreBackend(entry, backend, view) c.setCoreBackend(entry, backend, view)
@@ -788,7 +777,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora
} }
if c.logger.IsInfo() { if c.logger.IsInfo() {
c.logger.Info("successful mount", "namespace", entry.Namespace().Path, "path", entry.Path, "type", entry.Type, "version", entry.Version) c.logger.Info("successful mount", "namespace", entry.Namespace().Path, "path", entry.Path, "type", entry.Type, "version", entry.RunningVersion)
} }
return nil return nil
} }
@@ -1543,27 +1532,19 @@ func (c *Core) setupMounts(ctx context.Context) error {
var backend logical.Backend var backend logical.Backend
// Create the new backend // Create the new backend
sysView := c.mountEntrySysView(entry) sysView := c.mountEntrySysView(entry)
backend, entry.RunningSha256, err = c.newLogicalBackend(ctx, entry, sysView, view) backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
if err != nil { if err != nil {
c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err) c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err)
if c.isMountable(ctx, entry, consts.PluginTypeSecrets) { mountable, checkErr := c.isMountable(ctx, entry, consts.PluginTypeSecrets)
if checkErr != nil {
return errors.Join(errLoadMountsFailed, checkErr, err)
}
if mountable {
c.logger.Warn("skipping plugin-based mount entry", "path", entry.Path) c.logger.Warn("skipping plugin-based mount entry", "path", entry.Path)
goto ROUTER_MOUNT goto ROUTER_MOUNT
} }
return errLoadMountsFailed return errors.Join(errLoadMountsFailed, err)
}
if backend == nil {
return fmt.Errorf("created mount entry of type %q is nil", entry.Type)
}
// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type)
}
} }
// Do not start up deprecated builtin plugins. If this is a major // Do not start up deprecated builtin plugins. If this is a major
@@ -1680,34 +1661,37 @@ func (c *Core) unloadMounts(ctx context.Context) error {
} }
// newLogicalBackend is used to create and configure a new logical backend by name. // newLogicalBackend is used to create and configure a new logical backend by name.
// It also returns the SHA256 of the plugin, if available. func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
t := entry.Type t := entry.Type
if alias, ok := mountAliases[t]; ok { if alias, ok := mountAliases[t]; ok {
t = alias t = alias
} }
var runningSha string pluginVersion, err := c.resolveMountEntryVersion(ctx, consts.PluginTypeSecrets, entry)
f, ok := c.logicalBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, entry.Version)
if err != nil { if err != nil {
return nil, "", err return nil, err
}
var runningSha string
factory, ok := c.logicalBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, pluginVersion)
if err != nil {
return nil, err
} }
if plug == nil { if plug == nil {
errContext := t errContext := t
if entry.Version != "" { if pluginVersion != "" {
errContext += fmt.Sprintf(", version=%s", entry.Version) errContext += fmt.Sprintf(", version=%s", pluginVersion)
} }
return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) return nil, fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext)
} }
if len(plug.Sha256) > 0 { if len(plug.Sha256) > 0 {
runningSha = hex.EncodeToString(plug.Sha256) runningSha = hex.EncodeToString(plug.Sha256)
} }
f = plugin.Factory factory = plugin.Factory
if !plug.Builtin { if !plug.Builtin {
f = wrapFactoryCheckPerms(c, plugin.Factory) factory = wrapFactoryCheckPerms(c, factory)
} }
} }
// Set up conf to pass in plugin_name // Set up conf to pass in plugin_name
@@ -1724,7 +1708,7 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
} }
conf["plugin_type"] = consts.PluginTypeSecrets.String() conf["plugin_type"] = consts.PluginTypeSecrets.String()
conf["plugin_version"] = entry.Version conf["plugin_version"] = pluginVersion
backendLogger := c.baseLogger.Named(fmt.Sprintf("secrets.%s.%s", t, entry.Accessor)) backendLogger := c.baseLogger.Named(fmt.Sprintf("secrets.%s.%s", t, entry.Accessor))
c.AddLogger(backendLogger) c.AddLogger(backendLogger)
@@ -1733,11 +1717,11 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
MountAccessor: entry.Accessor, MountAccessor: entry.Accessor,
MountPath: entry.Path, MountPath: entry.Path,
Plugin: entry.Type, Plugin: entry.Type,
PluginVersion: entry.RunningVersion, PluginVersion: pluginVersion,
Version: entry.Version, Version: entry.Options["version"],
}) })
if err != nil { if err != nil {
return nil, "", err return nil, err
} }
config := &logical.BackendConfig{ config := &logical.BackendConfig{
StorageView: view, StorageView: view,
@@ -1750,16 +1734,39 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
ctx = namespace.ContextWithNamespace(ctx, entry.namespace) ctx = namespace.ContextWithNamespace(ctx, entry.namespace)
ctx = context.WithValue(ctx, "core_number", c.coreNumber) ctx = context.WithValue(ctx, "core_number", c.coreNumber)
b, err := f(ctx, config) backend, err := factory(ctx, config)
if err != nil { if err != nil {
return nil, "", err return nil, err
} }
if b == nil { if backend == nil {
return nil, "", fmt.Errorf("nil backend of type %q returned from factory", t) return nil, fmt.Errorf("nil backend of type %q returned from factory", t)
} }
addLicenseCallback(c, b)
return b, runningSha, nil entry.RunningVersion = pluginVersion
entry.RunningSha256 = runningSha
if entry.RunningVersion == "" && entry.RunningSha256 == "" {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type)
}
addLicenseCallback(c, backend)
return backend, nil
}
// resolveMountEntryVersion allows entry.Version to be overridden if there is a
// corresponding pinned version.
func (c *Core) resolveMountEntryVersion(ctx context.Context, pluginType consts.PluginType, entry *MountEntry) (string, error) {
pluginName := entry.Type
if alias, ok := mountAliases[pluginName]; ok {
pluginName = alias
}
pinnedVersion, err := c.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName)
if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return "", err
}
if pinnedVersion != nil {
return pinnedVersion.Version, nil
}
return entry.Version, nil
} }
// defaultMountTable creates a default mount table // defaultMountTable creates a default mount table

View File

@@ -9,7 +9,6 @@ import (
"strings" "strings"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/versions"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-secure-stdlib/strutil"
@@ -65,7 +64,7 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, ns *namespace.Nam
errors = multierror.Append(errors, fmt.Errorf("cannot reload plugin on %q: %w", mount, err)) errors = multierror.Append(errors, fmt.Errorf("cannot reload plugin on %q: %w", mount, err))
continue continue
} }
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version) c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.RunningVersion)
} }
return errors return errors
} }
@@ -106,7 +105,7 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, ns *namespace.Namespace
return reloaded, err return reloaded, err
} }
reloaded++ reloaded++
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version) c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.RunningVersion)
} else if database && entry.Type == "database" { } else if database && entry.Type == "database" {
// The combined database plugin is itself a secrets engine, but // The combined database plugin is itself a secrets engine, but
// knowledge of whether a database plugin is in use within a particular // knowledge of whether a database plugin is in use within a particular
@@ -152,7 +151,7 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, ns *namespace.Namespace
return reloaded, err return reloaded, err
} }
reloaded++ reloaded++
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version) c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.RunningVersion)
} }
} }
} }
@@ -224,9 +223,9 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
oldSha := entry.RunningSha256 oldSha := entry.RunningSha256
if !isAuth { if !isAuth {
// Dispense a new backend // Dispense a new backend
backend, entry.RunningSha256, err = c.newLogicalBackend(ctx, entry, sysView, view) backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
} else { } else {
backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view) backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
} }
if err != nil { if err != nil {
return err return err
@@ -235,19 +234,6 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type) return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
} }
// update the entry running version with the configured version, which was verified during registration.
entry.RunningVersion = entry.Version
if entry.RunningVersion == "" {
// don't set the running version to a builtin if it is running as an external plugin
if entry.RunningSha256 == "" {
if isAuth {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type)
} else {
entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type)
}
}
}
// update the mount table since we changed the runningSha // update the mount table since we changed the runningSha
if oldSha != entry.RunningSha256 && MountTableUpdateStorage { if oldSha != entry.RunningSha256 && MountTableUpdateStorage {
if isAuth { if isAuth {

120
vault/plugincatalog/pin.go Normal file
View File

@@ -0,0 +1,120 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package plugincatalog
import (
"context"
"encoding/json"
"fmt"
"path"
"strings"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
pinnedVersionStoragePrefix = "pinned"
)
func pinnedVersionStorageKey(pluginType consts.PluginType, pluginName string) string {
return path.Join(pinnedVersionStoragePrefix, pluginType.String(), pluginName)
}
// SetPinnedVersion creates a pinned version for the given plugin name and type.
func (c *PluginCatalog) SetPinnedVersion(ctx context.Context, pin *pluginutil.PinnedVersion) error {
c.lock.Lock()
defer c.lock.Unlock()
plugin, err := c.get(ctx, pin.Name, pin.Type, pin.Version)
if err != nil {
return err
}
if plugin == nil {
return fmt.Errorf("%s plugin %q version %s does not exist", pin.Type.String(), pin.Name, pin.Version)
}
bytes, err := json.Marshal(pin)
if err != nil {
return fmt.Errorf("failed to encode pinned version entry: %w", err)
}
logicalEntry := logical.StorageEntry{
Key: path.Join(pinnedVersionStoragePrefix, pin.Type.String(), pin.Name),
Value: bytes,
}
if err := c.catalogView.Put(ctx, &logicalEntry); err != nil {
return fmt.Errorf("failed to persist pinned version entry: %w", err)
}
return nil
}
// GetPinnedVersion returns the pinned version for the given plugin name and type.
func (c *PluginCatalog) GetPinnedVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
c.lock.RLock()
defer c.lock.RUnlock()
return c.getPinnedVersionInternal(ctx, pinnedVersionStorageKey(pluginType, pluginName))
}
func (c *PluginCatalog) getPinnedVersionInternal(ctx context.Context, key string) (*pluginutil.PinnedVersion, error) {
logicalEntry, err := c.catalogView.Get(ctx, key)
if err != nil {
return nil, fmt.Errorf("failed to retrieve pinned version entry: %w", err)
}
if logicalEntry == nil {
return nil, pluginutil.ErrPinnedVersionNotFound
}
var pin pluginutil.PinnedVersion
if err := json.Unmarshal(logicalEntry.Value, &pin); err != nil {
return nil, fmt.Errorf("failed to decode pinned version entry: %w", err)
}
return &pin, nil
}
// DeletePinnedVersion deletes the pinned version for the given plugin name and type.
func (c *PluginCatalog) DeletePinnedVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) error {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.catalogView.Delete(ctx, path.Join(pinnedVersionStoragePrefix, pluginType.String(), pluginName)); err != nil {
return fmt.Errorf("failed to delete pinned version entry: %w", err)
}
return nil
}
// ListPinnedVersions returns a list of pinned versions for the given plugin type.
func (c *PluginCatalog) ListPinnedVersions(ctx context.Context) ([]*pluginutil.PinnedVersion, error) {
c.lock.RLock()
defer c.lock.RUnlock()
keys, err := logical.CollectKeys(ctx, c.catalogView)
if err != nil {
return nil, err
}
var pinnedVersions []*pluginutil.PinnedVersion
for _, key := range keys {
// Skip: plugin entry.
if !strings.HasPrefix(key, pinnedVersionStoragePrefix) {
continue
}
pin, err := c.getPinnedVersionInternal(ctx, key)
if err != nil {
return nil, err
}
pinnedVersions = append(pinnedVersions, pin)
}
return pinnedVersions, nil
}

View File

@@ -0,0 +1,97 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package plugincatalog
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPluginCatalog_PinnedVersionCRUD tests the CRUD operations for pinned
// versions.
func TestPluginCatalog_PinnedVersionCRUD(t *testing.T) {
catalog := testPluginCatalog(t)
// Register a plugin in the catalog.
file, err := os.CreateTemp(catalog.directory, "temp")
if err != nil {
t.Fatal(err)
}
defer file.Close()
for _, version := range []string{"1.0.0", "2.0.0"} {
err = catalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "my-plugin",
Type: consts.PluginTypeSecrets,
Version: version,
Command: filepath.Base(file.Name()),
})
require.NoError(t, err)
}
// List pinned versions before creating a pin.
pinnedVersionsBefore, err := catalog.ListPinnedVersions(context.Background())
require.NoError(t, err)
assert.Empty(t, pinnedVersionsBefore)
// Create a pinned version.
pin := pluginutil.PinnedVersion{
Name: "my-plugin",
Type: consts.PluginTypeSecrets,
Version: "1.0.0",
}
err = catalog.SetPinnedVersion(context.Background(), &pin)
require.NoError(t, err)
// List pinned versions after creating a pin.
pinnedVersionsAfter, err := catalog.ListPinnedVersions(context.Background())
require.NoError(t, err)
require.Len(t, pinnedVersionsAfter, 1)
assert.Equal(t, pin, *pinnedVersionsAfter[0])
// Get the pinned version.
pinnedVersion, err := catalog.GetPinnedVersion(context.Background(), pin.Type, pin.Name)
require.NoError(t, err)
assert.Equal(t, pin, *pinnedVersion)
// Update the pinned version.
pin.Version = "2.0.0"
err = catalog.SetPinnedVersion(context.Background(), &pin)
require.NoError(t, err)
// Get the updated pinned version.
pinnedVersion, err = catalog.GetPinnedVersion(context.Background(), pin.Type, pin.Name)
require.NoError(t, err)
assert.Equal(t, pin, *pinnedVersion)
// Update to a version that isn't in the catalog.
pin.Version = "3.0.0"
err = catalog.SetPinnedVersion(context.Background(), &pin)
assert.Error(t, err)
// Delete the pinned version.
err = catalog.DeletePinnedVersion(context.Background(), pin.Type, pin.Name)
require.NoError(t, err)
// Delete it again, should not error (idempotent).
err = catalog.DeletePinnedVersion(context.Background(), pin.Type, pin.Name)
require.NoError(t, err)
// Verify that the pinned version is deleted.
pinnedVersion, err = catalog.GetPinnedVersion(context.Background(), pin.Type, pin.Name)
assert.Equal(t, pluginutil.ErrPinnedVersionNotFound, err)
assert.Nil(t, pinnedVersion)
// List should be empty again.
pinnedVersionsAfterDelete, err := catalog.ListPinnedVersions(context.Background())
require.NoError(t, err)
assert.Empty(t, pinnedVersionsAfterDelete)
}

View File

@@ -38,6 +38,7 @@ var (
ErrPluginNotFound = errors.New("plugin not found in the catalog") ErrPluginNotFound = errors.New("plugin not found in the catalog")
ErrPluginConnectionNotFound = errors.New("plugin connection not found for client") ErrPluginConnectionNotFound = errors.New("plugin connection not found for client")
ErrPluginBadType = errors.New("unable to determine plugin type") ErrPluginBadType = errors.New("unable to determine plugin type")
ErrPinnedVersion = errors.New("cannot delete a pinned version")
) )
// PluginCatalog keeps a record of plugins known to vault. External plugins need // PluginCatalog keeps a record of plugins known to vault. External plugins need
@@ -1013,6 +1014,14 @@ func (c *PluginCatalog) Delete(ctx context.Context, name string, pluginType cons
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
pin, err := c.getPinnedVersionInternal(ctx, pinnedVersionStorageKey(pluginType, name))
if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) {
return err
}
if pin != nil && pin.Version == pluginVersion {
return ErrPinnedVersion
}
// Check the name under which the plugin exists, but if it's unfound, don't return any error. // Check the name under which the plugin exists, but if it's unfound, don't return any error.
pluginKey := path.Join(pluginType.String(), name) pluginKey := path.Join(pluginType.String(), name)
if pluginVersion != "" { if pluginVersion != "" {
@@ -1059,6 +1068,10 @@ func (c *PluginCatalog) ListPluginsWithRuntime(ctx context.Context, runtime stri
var ret []string var ret []string
for _, key := range keys { for _, key := range keys {
// Skip: pinned version entry.
if strings.HasPrefix(key, pinnedVersionStoragePrefix) {
continue
}
entry, err := c.catalogView.Get(ctx, key) entry, err := c.catalogView.Get(ctx, key)
if err != nil || entry == nil { if err != nil || entry == nil {
continue continue
@@ -1094,6 +1107,11 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug
unversionedPlugins := make(map[string]struct{}) unversionedPlugins := make(map[string]struct{})
for _, key := range keys { for _, key := range keys {
// Skip: pinned version entry.
if strings.HasPrefix(key, pinnedVersionStoragePrefix) {
continue
}
var semanticVersion *semver.Version var semanticVersion *semver.Version
entry, err := c.catalogView.Get(ctx, key) entry, err := c.catalogView.Get(ctx, key)

View File

@@ -8,6 +8,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@@ -71,6 +72,33 @@ func TestPluginCatalog_CRUD(t *testing.T) {
pluginCatalog := testPluginCatalog(t) pluginCatalog := testPluginCatalog(t)
// Register a fake plugin in the catalog.
file, err := os.CreateTemp(pluginCatalog.directory, "temp")
if err != nil {
t.Fatal(err)
}
defer file.Close()
err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: pluginName,
Type: consts.PluginTypeDatabase,
Version: "1.0.0",
Command: filepath.Base(file.Name()),
})
if err != nil {
t.Fatal(err)
}
// Register a pinned version, should not affect anything below.
err = pluginCatalog.SetPinnedVersion(context.Background(), &pluginutil.PinnedVersion{
Name: pluginName,
Type: consts.PluginTypeDatabase,
Version: "1.0.0",
})
if err != nil {
t.Fatal(err)
}
// Get builtin plugin // Get builtin plugin
p, err := pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") p, err := pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "")
if err != nil { if err != nil {
@@ -106,12 +134,6 @@ func TestPluginCatalog_CRUD(t *testing.T) {
} }
// Set a plugin, test overwriting a builtin plugin // Set a plugin, test overwriting a builtin plugin
file, err := os.CreateTemp(pluginCatalog.directory, "temp")
if err != nil {
t.Fatal(err)
}
defer file.Close()
command := filepath.Base(file.Name()) command := filepath.Base(file.Name())
err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: pluginName, Name: pluginName,
@@ -1060,6 +1082,58 @@ func TestExternalPluginInContainer_GetBackendTypeVersion(t *testing.T) {
} }
} }
// TestPluginCatalog_CannotDeletePinnedVersion ensures we cannot delete a
// plugin which is referred to in an active pinned version.
func TestPluginCatalog_CannotDeletePinnedVersion(t *testing.T) {
pluginCatalog := testPluginCatalog(t)
// Register a fake plugin in the catalog.
file, err := os.CreateTemp(pluginCatalog.directory, "temp")
if err != nil {
t.Fatal(err)
}
defer file.Close()
err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "my-plugin",
Type: consts.PluginTypeSecrets,
Version: "1.0.0",
Command: filepath.Base(file.Name()),
})
if err != nil {
t.Fatal(err)
}
// Pin a version and check we can't delete it.
err = pluginCatalog.SetPinnedVersion(context.Background(), &pluginutil.PinnedVersion{
Name: "my-plugin",
Type: consts.PluginTypeSecrets,
Version: "1.0.0",
})
if err != nil {
t.Fatal(err)
}
err = pluginCatalog.Delete(context.Background(), "my-plugin", consts.PluginTypeSecrets, "1.0.0")
if err == nil {
t.Fatal("expected error")
}
if !errors.Is(err, ErrPinnedVersion) {
t.Fatal(err)
}
// Now delete the pinned version and we should be able to delete the plugin.
err = pluginCatalog.DeletePinnedVersion(context.Background(), consts.PluginTypeSecrets, "my-plugin")
if err != nil {
t.Fatalf("unexpected error %v", err)
}
err = pluginCatalog.Delete(context.Background(), "my-plugin", consts.PluginTypeSecrets, "1.0.0")
if err != nil {
t.Fatal(err)
}
}
// testRunTestPlugin runs the testFunc which has already been registered to the // testRunTestPlugin runs the testFunc which has already been registered to the
// plugin catalog and returns a pluginClient. This can be called after calling // plugin catalog and returns a pluginClient. This can be called after calling
// TestAddTestPlugin. // TestAddTestPlugin.

View File

@@ -519,7 +519,7 @@ func TestKeyCopy(key []byte) []byte {
return result return result
} }
func TestDynamicSystemView(c *Core, ns *namespace.Namespace) *dynamicSystemView { func TestDynamicSystemView(c *Core, ns *namespace.Namespace) logical.SystemView {
me := &MountEntry{ me := &MountEntry{
Config: MountConfig{ Config: MountConfig{
DefaultLeaseTTL: 24 * time.Hour, DefaultLeaseTTL: 24 * time.Hour,
@@ -534,7 +534,9 @@ func TestDynamicSystemView(c *Core, ns *namespace.Namespace) *dynamicSystemView
me.namespace = ns me.namespace = ns
} }
return &dynamicSystemView{c, me, c.perfStandby} return &extendedSystemViewImpl{
dynamicSystemView{c, me, c.perfStandby},
}
} }
func TestAddTestPlugin(t testing.T, core *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string) { func TestAddTestPlugin(t testing.T, core *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string) {