diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 693a22ff97..86b14a537a 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -298,7 +298,17 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri return nil, err } - dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger) + // Override the configured version if there is a pinned version. + pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName) + if err != nil { + return nil, err + } + pluginVersion := config.PluginVersion + if pinnedVersion != "" { + pluginVersion = pinnedVersion + } + + dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger) if err != nil { return nil, fmt.Errorf("unable to create database instance: %w", err) } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index b017009da0..383a2bd458 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -436,58 +436,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(respErrEmptyPluginName), nil } - if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok { - config.PluginVersion = pluginVersionRaw.(string) - } - - var builtinShadowed bool - if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin { - builtinShadowed = true - } - switch { - case config.PluginVersion != "": - semanticVersion, err := version.NewVersion(config.PluginVersion) - if err != nil { - return logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil - } - - // Canonicalize the version. - config.PluginVersion = "v" + semanticVersion.String() - - if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) { - if builtinShadowed { - return logical.ErrorResponse("database plugin %q, version %s not found, as it is"+ - " overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil - } - - config.PluginVersion = "" - } - case builtinShadowed: - // We'll select the unversioned plugin that's been registered. - case req.Operation == logical.CreateOperation: - // No version provided and no unversioned plugin of that name available. - // Pin to the current latest version if any versioned plugins are registered. - plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase) - if err != nil { - return nil, err - } - - var versionedCandidates []pluginutil.VersionedPlugin - for _, plugin := range plugins { - if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" { - versionedCandidates = append(versionedCandidates, plugin) - } - } - - if len(versionedCandidates) != 0 { - // Sort in reverse order. - sort.SliceStable(versionedCandidates, func(i, j int) bool { - return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion) - }) - - config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String() - b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates)) - } + pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation) + if respErr != nil || err != nil { + return respErr, err } if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok { @@ -536,7 +487,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } // Create a database plugin and initialize it. - dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger) + dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger) if err != nil { return logical.ErrorResponse("error creating database object: %s", err), nil } @@ -613,6 +564,92 @@ func storeConfig(ctx context.Context, storage logical.Storage, name string, conf return nil } +func (b *databaseBackend) getPinnedVersion(ctx context.Context, pluginName string) (string, error) { + extendedSys, ok := b.System().(logical.ExtendedSystemView) + if !ok { + return "", fmt.Errorf("database backend does not support running as an external plugin") + } + + pin, err := extendedSys.GetPinnedPluginVersion(ctx, consts.PluginTypeDatabase, pluginName) + if errors.Is(err, pluginutil.ErrPinnedVersionNotFound) { + return "", nil + } + if err != nil { + return "", err + } + + return pin.Version, nil +} + +func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) { + pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName) + if err != nil { + return "", nil, err + } + pluginVersionRaw, ok := data.GetOk("plugin_version") + + switch { + case ok && pinnedVersion != "": + return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (v%s)", config.PluginName, pinnedVersion), nil + case pinnedVersion != "": + return pinnedVersion, nil, nil + case ok: + config.PluginVersion = pluginVersionRaw.(string) + } + + var builtinShadowed bool + if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin { + builtinShadowed = true + } + switch { + case config.PluginVersion != "": + semanticVersion, err := version.NewVersion(config.PluginVersion) + if err != nil { + return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil + } + + // Canonicalize the version. + config.PluginVersion = "v" + semanticVersion.String() + + if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) { + if builtinShadowed { + return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+ + " overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil + } + + config.PluginVersion = "" + } + case builtinShadowed: + // We'll select the unversioned plugin that's been registered. + case op == logical.CreateOperation: + // No version provided and no unversioned plugin of that name available. + // Pin to the current latest version if any versioned plugins are registered. + plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase) + if err != nil { + return "", nil, err + } + + var versionedCandidates []pluginutil.VersionedPlugin + for _, plugin := range plugins { + if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" { + versionedCandidates = append(versionedCandidates, plugin) + } + } + + if len(versionedCandidates) != 0 { + // Sort in reverse order. + sort.SliceStable(versionedCandidates, func(i, j int) bool { + return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion) + }) + + config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String() + b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates)) + } + } + + return config.PluginVersion, nil, nil +} + const pathConfigConnectionHelpSyn = ` Configure connection details to a database plugin. ` diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index c627204b1a..dc0432ef98 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -5,6 +5,7 @@ package pluginutil import ( "context" + "errors" "strings" "time" @@ -17,6 +18,9 @@ import ( "google.golang.org/grpc" ) +// ErrPluginNotFound is returned when a plugin does not have a pinned version. +var ErrPinnedVersionNotFound = errors.New("pinned version not found") + // Looker defines the plugin Lookup function that looks into the plugin catalog // for available plugins and returns a PluginRunner type Looker interface { @@ -144,6 +148,12 @@ type VersionedPlugin struct { SemanticVersion *version.Version `json:"-"` } +type PinnedVersion struct { + Name string `json:"name"` + Type consts.PluginType `json:"type"` + Version string `json:"version"` +} + // CtxCancelIfCanceled takes a context cancel func and a context. If the context is // shutdown the cancelfunc is called. This is useful for merging two cancel // functions. diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index 50a6080996..cecbc261e1 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -127,6 +127,9 @@ type ExtendedSystemView interface { // APILockShouldBlockRequest returns whether a namespace for the requested // mount is locked and should be blocked APILockShouldBlockRequest() (bool, error) + + // GetPinnedPluginVersion returns the pinned version for the given plugin, if any. + GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) } type PasswordGenerator func() (password string, err error) diff --git a/vault/auth.go b/vault/auth.go index 7c62e77706..ab74c56301 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -175,7 +175,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { return err } @@ -188,14 +188,6 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, if backendType != logical.TypeCredential { return fmt.Errorf("cannot mount %q of type %q as an auth backend", entry.Type, backendType) } - // update the entry running version with the configured version, which was verified during registration. - entry.RunningVersion = entry.Version - if entry.RunningVersion == "" { - // don't set the running version to a builtin if it is running as an external plugin - if entry.RunningSha256 == "" { - entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type) - } - } addPathCheckers(c, entry, backend, viewPath) // If the mount is filtered or we are on a DR secondary we don't want to @@ -249,7 +241,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, } if c.logger.IsInfo() { - c.logger.Info("enabled credential backend", "path", entry.Path, "type", entry.Type, "version", entry.Version) + c.logger.Info("enabled credential backend", "path", entry.Path, "type", entry.Type, "version", entry.RunningVersion) } return nil } @@ -805,29 +797,24 @@ func (c *Core) setupCredentials(ctx context.Context) error { // Initialize the backend sysView := c.mountEntrySysView(entry) - backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err) - if c.isMountable(ctx, entry, consts.PluginTypeCredential) { + mountable, checkErr := c.isMountable(ctx, entry, consts.PluginTypeSecrets) + if checkErr != nil { + return errors.Join(errLoadMountsFailed, checkErr, err) + } + if mountable { c.logger.Warn("skipping plugin-based auth entry", "path", entry.Path) goto ROUTER_MOUNT } - return errLoadAuthFailed + return errors.Join(errLoadAuthFailed, err) } if backend == nil { return fmt.Errorf("nil backend returned from %q factory", entry.Type) } - // update the entry running version with the configured version, which was verified during registration. - entry.RunningVersion = entry.Version - if entry.RunningVersion == "" { - // don't set the running version to a builtin if it is running as an external plugin - if entry.RunningSha256 == "" { - entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type) - } - } - // Do not start up deprecated builtin plugins. If this is a major // upgrade, stop unsealing and shutdown. If we've already mounted this // plugin, skip backend initialization and mount the data for posterity. @@ -952,34 +939,37 @@ func (c *Core) teardownCredentials(ctx context.Context) error { } // newCredentialBackend is used to create and configure a new credential backend by name. -// It also returns the SHA256 of the plugin, if available. -func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) { +func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) { t := entry.Type if alias, ok := credentialAliases[t]; ok { t = alias } + pluginVersion, err := c.resolveMountEntryVersion(ctx, consts.PluginTypeCredential, entry) + if err != nil { + return nil, err + } var runningSha string - f, ok := c.credentialBackends[t] + factory, ok := c.credentialBackends[t] if !ok { - plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version) + plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, pluginVersion) if err != nil { - return nil, "", err + return nil, err } if plug == nil { errContext := t - if entry.Version != "" { - errContext += fmt.Sprintf(", version=%s", entry.Version) + if pluginVersion != "" { + errContext += fmt.Sprintf(", version=%s", pluginVersion) } - return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) + return nil, fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) } if len(plug.Sha256) > 0 { runningSha = hex.EncodeToString(plug.Sha256) } - f = plugin.Factory + factory = plugin.Factory if !plug.Builtin { - f = wrapFactoryCheckPerms(c, plugin.Factory) + factory = wrapFactoryCheckPerms(c, plugin.Factory) } } // Set up conf to pass in plugin_name @@ -996,7 +986,7 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV } conf["plugin_type"] = consts.PluginTypeCredential.String() - conf["plugin_version"] = entry.Version + conf["plugin_version"] = pluginVersion authLogger := c.baseLogger.Named(fmt.Sprintf("auth.%s.%s", t, entry.Accessor)) c.AddLogger(authLogger) @@ -1005,11 +995,11 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV MountAccessor: entry.Accessor, MountPath: entry.Path, Plugin: entry.Type, - PluginVersion: entry.RunningVersion, - Version: entry.Version, + PluginVersion: pluginVersion, + Version: entry.Options["version"], }) if err != nil { - return nil, "", err + return nil, err } config := &logical.BackendConfig{ @@ -1021,12 +1011,19 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV EventsSender: pluginEventSender, } - b, err := f(ctx, config) + backend, err := factory(ctx, config) if err != nil { - return nil, "", err + return nil, err + } + if backend != nil { + entry.RunningVersion = pluginVersion + entry.RunningSha256 = runningSha + if entry.RunningVersion == "" && entry.RunningSha256 == "" { + entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type) + } } - return b, runningSha, nil + return backend, nil } func wrapFactoryCheckPerms(core *Core, f logical.Factory) logical.Factory { diff --git a/vault/core.go b/vault/core.go index 880acf0ac3..2b0cdbde43 100644 --- a/vault/core.go +++ b/vault/core.go @@ -3546,16 +3546,17 @@ func (c *Core) readFeatureFlags(ctx context.Context) (*FeatureFlags, error) { // misconfigured. This allows users to recover from errors when starting Vault // with misconfigured plugins. It should not be possible for existing builtins // to be misconfigured, so that is a fatal error. -func (c *Core) isMountable(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) bool { - return !c.isMountEntryBuiltin(ctx, entry, pluginType) +func (c *Core) isMountable(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) (bool, error) { + builtin, err := c.isMountEntryBuiltin(ctx, entry, pluginType) + return !builtin, err } // isMountEntryBuiltin determines whether a mount entry is associated with a // builtin of the specified plugin type. -func (c *Core) isMountEntryBuiltin(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) bool { +func (c *Core) isMountEntryBuiltin(ctx context.Context, entry *MountEntry, pluginType consts.PluginType) (bool, error) { // Prevent a panic early on if entry == nil || c.pluginCatalog == nil { - return false + return false, nil } // Allow type to be determined from mount entry when not otherwise specified @@ -3569,12 +3570,16 @@ func (c *Core) isMountEntryBuiltin(ctx context.Context, entry *MountEntry, plugi pluginName = alias } - plug, err := c.pluginCatalog.Get(ctx, pluginName, pluginType, entry.Version) + pluginVersion, err := c.resolveMountEntryVersion(ctx, pluginType, entry) + if err != nil { + return false, err + } + plug, err := c.pluginCatalog.Get(ctx, pluginName, pluginType, pluginVersion) if err != nil || plug == nil { - return false + return false, nil } - return plug.Builtin + return plug.Builtin, nil } // MatchingMount returns the path of the mount that will be responsible for diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 4ae8c83db6..2feae0a9a4 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -161,6 +161,11 @@ func (e extendedSystemViewImpl) DeregisterWellKnownRedirect(ctx context.Context, return e.core.WellKnownRedirects.DeregisterSource(e.mountEntry.UUID, src) } +// GetPinnedPluginVersion implements logical.ExtendedSystemView. +func (e extendedSystemViewImpl) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) { + return e.core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName) +} + func (d dynamicSystemView) DefaultLeaseTTL() time.Duration { def, _ := d.fetchTTLs() return def diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index 03151d985e..a332df23b6 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/sdk/plugin/mock" @@ -95,6 +96,98 @@ func TestCore_EnableExternalPlugin(t *testing.T) { } } +// TestCore_UpgradePluginUsingPinnedVersion tests a full workflow of upgrading +// an external plugin gated by pinned versions. +func TestCore_UpgradePluginUsingPinnedVersion(t *testing.T) { + cluster := NewTestCluster(t, &CoreConfig{}, &TestClusterOptions{ + Plugins: []*TestPluginConfig{ + { + Typ: consts.PluginTypeCredential, + Versions: []string{""}, + }, + { + Typ: consts.PluginTypeSecrets, + Versions: []string{""}, + }, + }, + }) + + cluster.Start() + t.Cleanup(cluster.Cleanup) + + c := cluster.Cores[0].Core + TestWaitActive(t, c) + + for name, tc := range map[string]struct { + idx int + }{ + "credential plugin": { + idx: 0, + }, + "secrets plugin": { + idx: 1, + }, + } { + t.Run(name, func(t *testing.T) { + plugin := cluster.Plugins[tc.idx] + for _, version := range []string{"v1.0.0", "v1.0.1"} { + registerPlugin(t, c.systemBackend, plugin.Name, plugin.Typ.String(), version, plugin.Sha256, plugin.FileName) + } + + // Mount 1.0.0 then pin to 1.0.1 + mountPlugin(t, c.systemBackend, plugin.Name, plugin.Typ, "v1.0.0", "") + err := c.pluginCatalog.SetPinnedVersion(context.Background(), &pluginutil.PinnedVersion{ + Name: plugin.Name, + Type: plugin.Typ, + Version: "v1.0.1", + }) + if err != nil { + t.Fatal(err) + } + + mountedPath := "foo/" + if plugin.Typ == consts.PluginTypeCredential { + mountedPath = "auth/" + mountedPath + } + expectRunningVersion(t, c, mountedPath, "v1.0.0") + + reloaded, err := c.reloadMatchingPlugin(context.Background(), nil, plugin.Typ, plugin.Name) + if reloaded != 1 || err != nil { + t.Fatal(reloaded, err) + } + + // Pinned version should be in effect after reloading. + expectRunningVersion(t, c, mountedPath, "v1.0.1") + + err = c.pluginCatalog.DeletePinnedVersion(context.Background(), plugin.Typ, plugin.Name) + if err != nil { + t.Fatal(err) + } + + reloaded, err = c.reloadMatchingPlugin(context.Background(), nil, plugin.Typ, plugin.Name) + if reloaded != 1 || err != nil { + t.Fatal(reloaded, err) + } + + // After pin is deleted, the previously configured version should stand. + expectRunningVersion(t, c, mountedPath, "v1.0.0") + }) + } +} + +func expectRunningVersion(t *testing.T, c *Core, path, expectedVersion string) { + t.Helper() + match := c.router.MatchingMount(namespace.RootContext(context.Background()), path) + if match != path { + t.Fatalf("missing mount for %s, match: %q", path, match) + } + + raw, _ := c.router.root.Get(match) + if actual := raw.(*routeEntry).mountEntry.RunningVersion; expectedVersion != actual { + t.Fatalf("expected running_plugin_version to be %s but got %s", expectedVersion, actual) + } +} + func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { for name, tc := range map[string]struct { pluginType consts.PluginType diff --git a/vault/identity_store_entities_test.go b/vault/identity_store_entities_test.go index 11b8024f46..260b6a9cea 100644 --- a/vault/identity_store_entities_test.go +++ b/vault/identity_store_entities_test.go @@ -691,7 +691,7 @@ func TestIdentityStore_LoadingEntities(t *testing.T) { ghSysview := c.mountEntrySysView(meGH) // Create new github auth credential backend - ghAuth, _, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) + ghAuth, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) if err != nil { t.Fatal(err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 3441911f8c..857aef3b62 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1535,7 +1535,11 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d Version: pluginVersion, } - if b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeSecrets) { + builtin, err := b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeSecrets) + if err != nil { + return nil, err + } + if builtin { resp, err = b.Core.handleDeprecatedMountEntry(ctx, me, consts.PluginTypeSecrets) if err != nil { b.Core.logger.Error("could not mount builtin", "name", me.Type, "path", me.Path, "error", err) @@ -1949,7 +1953,8 @@ func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) ( resp.Data["external_entropy_access"] = true } - if mountEntry.Table == credentialTableType { + isAuth := mountEntry.Table == credentialTableType + if isAuth { resp.Data["token_type"] = mountEntry.Config.TokenType.String() } @@ -1995,6 +2000,19 @@ func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) ( if mountEntry.Version != "" { resp.Data["plugin_version"] = mountEntry.Version } + var pinnedVersion *pluginutil.PinnedVersion + var err error + if isAuth { + pinnedVersion, err = b.Core.pluginCatalog.GetPinnedVersion(ctx, consts.PluginTypeCredential, mountEntry.Type) + } else { + pinnedVersion, err = b.Core.pluginCatalog.GetPinnedVersion(ctx, consts.PluginTypeSecrets, mountEntry.Type) + } + if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) { + return nil, err + } + if pinnedVersion != nil && mountEntry.Version != pinnedVersion.Version { + resp.AddWarning(fmt.Sprintf("plugin_version is configured as %s but a version pin for %s is in effect", mountEntry.Version, pinnedVersion.Version)) + } return resp, nil } @@ -2236,6 +2254,19 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, } if rawVal, ok := data.GetOk("plugin_version"); ok { + pluginType := consts.PluginTypeSecrets + if strings.HasPrefix(path, "auth/") { + pluginType = consts.PluginTypeCredential + } + + pinnedVersion, err := b.Core.pluginCatalog.GetPinnedVersion(ctx, pluginType, mountEntry.Type) + if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) { + return nil, err + } + if pinnedVersion != nil { + return logical.ErrorResponse(fmt.Sprintf("plugin_version cannot be set for %s plugin %q as a pinned version %s is in effect", pluginType, mountEntry.Type, pinnedVersion.Version)), nil + } + version := rawVal.(string) semanticVersion, err := semver.NewVersion(version) if err != nil { @@ -2244,10 +2275,6 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, version = "v" + semanticVersion.String() // Lookup the version to ensure it exists in the catalog before committing. - pluginType := consts.PluginTypeSecrets - if strings.HasPrefix(path, "auth/") { - pluginType = consts.PluginTypeCredential - } _, err = b.System().LookupPluginVersion(ctx, mountEntry.Type, pluginType, version) if err != nil { return handleError(err) @@ -3106,7 +3133,11 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque } var resp *logical.Response - if b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeCredential) { + builtin, err := b.Core.isMountEntryBuiltin(ctx, me, consts.PluginTypeCredential) + if err != nil { + return nil, err + } + if builtin { resp, err = b.Core.handleDeprecatedMountEntry(ctx, me, consts.PluginTypeCredential) if err != nil { b.Core.logger.Error("could not mount builtin", "name", me.Type, "path", me.Path, "error", err) @@ -3123,6 +3154,18 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque } func (b *SystemBackend) validateVersion(ctx context.Context, version string, pluginName string, pluginType consts.PluginType) (string, *logical.Response, error) { + pinnedVersion, err := b.Core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName) + if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) { + return "", nil, err + } + if pinnedVersion != nil { + if version != "" { + return "", logical.ErrorResponse("cannot specify plugin_version for %s plugin %q, as it is pinned to version %s", pluginType.String(), pluginName, pinnedVersion.Version), nil + } + + return pinnedVersion.Version, nil, nil + } + switch version { case "": var err error diff --git a/vault/mount.go b/vault/mount.go index f899564357..f9caf4f461 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -23,6 +23,7 @@ import ( "github.com/hashicorp/vault/helper/versions" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/plugincatalog" "github.com/mitchellh/copystructure" @@ -347,8 +348,8 @@ type MountEntry struct { synthesizedConfigCache sync.Map // version info - Version string `json:"plugin_version,omitempty"` // The semantic version of the mounted plugin, e.g. v1.2.3. - RunningVersion string `json:"running_plugin_version,omitempty"` // The semantic version of the mounted plugin as reported by the plugin. + Version string `json:"plugin_version,omitempty"` // The configured semantic version of the mounted plugin, e.g. v1.2.3. May be overridden by a pinned version. + RunningVersion string `json:"running_plugin_version,omitempty"` // The semantic version of the currently running mounted plugin. RunningSha256 string `json:"running_sha256,omitempty"` } @@ -703,13 +704,10 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora var backend logical.Backend sysView := c.mountEntrySysView(entry) - backend, entry.RunningSha256, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { return err } - if backend == nil { - return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type) - } // Check for the correct backend type backendType := backend.Type() @@ -719,15 +717,6 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora } } - // update the entry running version with the configured version, which was verified during registration. - entry.RunningVersion = entry.Version - if entry.RunningVersion == "" { - // don't set the running version to a builtin if it is running as an external plugin - if entry.RunningSha256 == "" { - entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type) - } - } - addPathCheckers(c, entry, backend, viewPath) c.setCoreBackend(entry, backend, view) @@ -788,7 +777,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora } if c.logger.IsInfo() { - c.logger.Info("successful mount", "namespace", entry.Namespace().Path, "path", entry.Path, "type", entry.Type, "version", entry.Version) + c.logger.Info("successful mount", "namespace", entry.Namespace().Path, "path", entry.Path, "type", entry.Type, "version", entry.RunningVersion) } return nil } @@ -1543,27 +1532,19 @@ func (c *Core) setupMounts(ctx context.Context) error { var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - backend, entry.RunningSha256, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err) - if c.isMountable(ctx, entry, consts.PluginTypeSecrets) { + mountable, checkErr := c.isMountable(ctx, entry, consts.PluginTypeSecrets) + if checkErr != nil { + return errors.Join(errLoadMountsFailed, checkErr, err) + } + if mountable { c.logger.Warn("skipping plugin-based mount entry", "path", entry.Path) goto ROUTER_MOUNT } - return errLoadMountsFailed - } - if backend == nil { - return fmt.Errorf("created mount entry of type %q is nil", entry.Type) - } - - // update the entry running version with the configured version, which was verified during registration. - entry.RunningVersion = entry.Version - if entry.RunningVersion == "" { - // don't set the running version to a builtin if it is running as an external plugin - if entry.RunningSha256 == "" { - entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type) - } + return errors.Join(errLoadMountsFailed, err) } // Do not start up deprecated builtin plugins. If this is a major @@ -1680,34 +1661,37 @@ func (c *Core) unloadMounts(ctx context.Context) error { } // newLogicalBackend is used to create and configure a new logical backend by name. -// It also returns the SHA256 of the plugin, if available. -func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) { +func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) { t := entry.Type if alias, ok := mountAliases[t]; ok { t = alias } + pluginVersion, err := c.resolveMountEntryVersion(ctx, consts.PluginTypeSecrets, entry) + if err != nil { + return nil, err + } var runningSha string - f, ok := c.logicalBackends[t] + factory, ok := c.logicalBackends[t] if !ok { - plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, entry.Version) + plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, pluginVersion) if err != nil { - return nil, "", err + return nil, err } if plug == nil { errContext := t - if entry.Version != "" { - errContext += fmt.Sprintf(", version=%s", entry.Version) + if pluginVersion != "" { + errContext += fmt.Sprintf(", version=%s", pluginVersion) } - return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) + return nil, fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) } if len(plug.Sha256) > 0 { runningSha = hex.EncodeToString(plug.Sha256) } - f = plugin.Factory + factory = plugin.Factory if !plug.Builtin { - f = wrapFactoryCheckPerms(c, plugin.Factory) + factory = wrapFactoryCheckPerms(c, factory) } } // Set up conf to pass in plugin_name @@ -1724,7 +1708,7 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView } conf["plugin_type"] = consts.PluginTypeSecrets.String() - conf["plugin_version"] = entry.Version + conf["plugin_version"] = pluginVersion backendLogger := c.baseLogger.Named(fmt.Sprintf("secrets.%s.%s", t, entry.Accessor)) c.AddLogger(backendLogger) @@ -1733,11 +1717,11 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView MountAccessor: entry.Accessor, MountPath: entry.Path, Plugin: entry.Type, - PluginVersion: entry.RunningVersion, - Version: entry.Version, + PluginVersion: pluginVersion, + Version: entry.Options["version"], }) if err != nil { - return nil, "", err + return nil, err } config := &logical.BackendConfig{ StorageView: view, @@ -1750,16 +1734,39 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView ctx = namespace.ContextWithNamespace(ctx, entry.namespace) ctx = context.WithValue(ctx, "core_number", c.coreNumber) - b, err := f(ctx, config) + backend, err := factory(ctx, config) if err != nil { - return nil, "", err + return nil, err } - if b == nil { - return nil, "", fmt.Errorf("nil backend of type %q returned from factory", t) + if backend == nil { + return nil, fmt.Errorf("nil backend of type %q returned from factory", t) } - addLicenseCallback(c, b) - return b, runningSha, nil + entry.RunningVersion = pluginVersion + entry.RunningSha256 = runningSha + if entry.RunningVersion == "" && entry.RunningSha256 == "" { + entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type) + } + addLicenseCallback(c, backend) + + return backend, nil +} + +// resolveMountEntryVersion allows entry.Version to be overridden if there is a +// corresponding pinned version. +func (c *Core) resolveMountEntryVersion(ctx context.Context, pluginType consts.PluginType, entry *MountEntry) (string, error) { + pluginName := entry.Type + if alias, ok := mountAliases[pluginName]; ok { + pluginName = alias + } + pinnedVersion, err := c.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName) + if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) { + return "", err + } + if pinnedVersion != nil { + return pinnedVersion.Version, nil + } + return entry.Version, nil } // defaultMountTable creates a default mount table diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index 74f395c403..5991469263 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/helper/versions" "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/strutil" @@ -65,7 +64,7 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, ns *namespace.Nam errors = multierror.Append(errors, fmt.Errorf("cannot reload plugin on %q: %w", mount, err)) continue } - c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version) + c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.RunningVersion) } return errors } @@ -106,7 +105,7 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, ns *namespace.Namespace return reloaded, err } reloaded++ - c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version) + c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.RunningVersion) } else if database && entry.Type == "database" { // The combined database plugin is itself a secrets engine, but // knowledge of whether a database plugin is in use within a particular @@ -152,7 +151,7 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, ns *namespace.Namespace return reloaded, err } reloaded++ - c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version) + c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.RunningVersion) } } } @@ -224,9 +223,9 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut oldSha := entry.RunningSha256 if !isAuth { // Dispense a new backend - backend, entry.RunningSha256, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, err = c.newLogicalBackend(ctx, entry, sysView, view) } else { - backend, entry.RunningSha256, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, err = c.newCredentialBackend(ctx, entry, sysView, view) } if err != nil { return err @@ -235,19 +234,6 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type) } - // update the entry running version with the configured version, which was verified during registration. - entry.RunningVersion = entry.Version - if entry.RunningVersion == "" { - // don't set the running version to a builtin if it is running as an external plugin - if entry.RunningSha256 == "" { - if isAuth { - entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeCredential, entry.Type) - } else { - entry.RunningVersion = versions.GetBuiltinVersion(consts.PluginTypeSecrets, entry.Type) - } - } - } - // update the mount table since we changed the runningSha if oldSha != entry.RunningSha256 && MountTableUpdateStorage { if isAuth { diff --git a/vault/plugincatalog/pin.go b/vault/plugincatalog/pin.go new file mode 100644 index 0000000000..981efa0bcd --- /dev/null +++ b/vault/plugincatalog/pin.go @@ -0,0 +1,120 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package plugincatalog + +import ( + "context" + "encoding/json" + "fmt" + "path" + "strings" + + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" +) + +const ( + pinnedVersionStoragePrefix = "pinned" +) + +func pinnedVersionStorageKey(pluginType consts.PluginType, pluginName string) string { + return path.Join(pinnedVersionStoragePrefix, pluginType.String(), pluginName) +} + +// SetPinnedVersion creates a pinned version for the given plugin name and type. +func (c *PluginCatalog) SetPinnedVersion(ctx context.Context, pin *pluginutil.PinnedVersion) error { + c.lock.Lock() + defer c.lock.Unlock() + + plugin, err := c.get(ctx, pin.Name, pin.Type, pin.Version) + if err != nil { + return err + } + if plugin == nil { + return fmt.Errorf("%s plugin %q version %s does not exist", pin.Type.String(), pin.Name, pin.Version) + } + + bytes, err := json.Marshal(pin) + if err != nil { + return fmt.Errorf("failed to encode pinned version entry: %w", err) + } + + logicalEntry := logical.StorageEntry{ + Key: path.Join(pinnedVersionStoragePrefix, pin.Type.String(), pin.Name), + Value: bytes, + } + + if err := c.catalogView.Put(ctx, &logicalEntry); err != nil { + return fmt.Errorf("failed to persist pinned version entry: %w", err) + } + + return nil +} + +// GetPinnedVersion returns the pinned version for the given plugin name and type. +func (c *PluginCatalog) GetPinnedVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.getPinnedVersionInternal(ctx, pinnedVersionStorageKey(pluginType, pluginName)) +} + +func (c *PluginCatalog) getPinnedVersionInternal(ctx context.Context, key string) (*pluginutil.PinnedVersion, error) { + logicalEntry, err := c.catalogView.Get(ctx, key) + if err != nil { + return nil, fmt.Errorf("failed to retrieve pinned version entry: %w", err) + } + + if logicalEntry == nil { + return nil, pluginutil.ErrPinnedVersionNotFound + } + + var pin pluginutil.PinnedVersion + if err := json.Unmarshal(logicalEntry.Value, &pin); err != nil { + return nil, fmt.Errorf("failed to decode pinned version entry: %w", err) + } + + return &pin, nil +} + +// DeletePinnedVersion deletes the pinned version for the given plugin name and type. +func (c *PluginCatalog) DeletePinnedVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) error { + c.lock.Lock() + defer c.lock.Unlock() + + if err := c.catalogView.Delete(ctx, path.Join(pinnedVersionStoragePrefix, pluginType.String(), pluginName)); err != nil { + return fmt.Errorf("failed to delete pinned version entry: %w", err) + } + + return nil +} + +// ListPinnedVersions returns a list of pinned versions for the given plugin type. +func (c *PluginCatalog) ListPinnedVersions(ctx context.Context) ([]*pluginutil.PinnedVersion, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + keys, err := logical.CollectKeys(ctx, c.catalogView) + if err != nil { + return nil, err + } + + var pinnedVersions []*pluginutil.PinnedVersion + for _, key := range keys { + // Skip: plugin entry. + if !strings.HasPrefix(key, pinnedVersionStoragePrefix) { + continue + } + + pin, err := c.getPinnedVersionInternal(ctx, key) + if err != nil { + return nil, err + } + + pinnedVersions = append(pinnedVersions, pin) + } + + return pinnedVersions, nil +} diff --git a/vault/plugincatalog/pin_test.go b/vault/plugincatalog/pin_test.go new file mode 100644 index 0000000000..ff93806a44 --- /dev/null +++ b/vault/plugincatalog/pin_test.go @@ -0,0 +1,97 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package plugincatalog + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestPluginCatalog_PinnedVersionCRUD tests the CRUD operations for pinned +// versions. +func TestPluginCatalog_PinnedVersionCRUD(t *testing.T) { + catalog := testPluginCatalog(t) + + // Register a plugin in the catalog. + file, err := os.CreateTemp(catalog.directory, "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + for _, version := range []string{"1.0.0", "2.0.0"} { + err = catalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "my-plugin", + Type: consts.PluginTypeSecrets, + Version: version, + Command: filepath.Base(file.Name()), + }) + require.NoError(t, err) + } + + // List pinned versions before creating a pin. + pinnedVersionsBefore, err := catalog.ListPinnedVersions(context.Background()) + require.NoError(t, err) + assert.Empty(t, pinnedVersionsBefore) + + // Create a pinned version. + pin := pluginutil.PinnedVersion{ + Name: "my-plugin", + Type: consts.PluginTypeSecrets, + Version: "1.0.0", + } + err = catalog.SetPinnedVersion(context.Background(), &pin) + require.NoError(t, err) + + // List pinned versions after creating a pin. + pinnedVersionsAfter, err := catalog.ListPinnedVersions(context.Background()) + require.NoError(t, err) + require.Len(t, pinnedVersionsAfter, 1) + assert.Equal(t, pin, *pinnedVersionsAfter[0]) + + // Get the pinned version. + pinnedVersion, err := catalog.GetPinnedVersion(context.Background(), pin.Type, pin.Name) + require.NoError(t, err) + assert.Equal(t, pin, *pinnedVersion) + + // Update the pinned version. + pin.Version = "2.0.0" + err = catalog.SetPinnedVersion(context.Background(), &pin) + require.NoError(t, err) + + // Get the updated pinned version. + pinnedVersion, err = catalog.GetPinnedVersion(context.Background(), pin.Type, pin.Name) + require.NoError(t, err) + assert.Equal(t, pin, *pinnedVersion) + + // Update to a version that isn't in the catalog. + pin.Version = "3.0.0" + err = catalog.SetPinnedVersion(context.Background(), &pin) + assert.Error(t, err) + + // Delete the pinned version. + err = catalog.DeletePinnedVersion(context.Background(), pin.Type, pin.Name) + require.NoError(t, err) + + // Delete it again, should not error (idempotent). + err = catalog.DeletePinnedVersion(context.Background(), pin.Type, pin.Name) + require.NoError(t, err) + + // Verify that the pinned version is deleted. + pinnedVersion, err = catalog.GetPinnedVersion(context.Background(), pin.Type, pin.Name) + assert.Equal(t, pluginutil.ErrPinnedVersionNotFound, err) + assert.Nil(t, pinnedVersion) + + // List should be empty again. + pinnedVersionsAfterDelete, err := catalog.ListPinnedVersions(context.Background()) + require.NoError(t, err) + assert.Empty(t, pinnedVersionsAfterDelete) +} diff --git a/vault/plugincatalog/plugin_catalog.go b/vault/plugincatalog/plugin_catalog.go index 2c39009df7..b136f47510 100644 --- a/vault/plugincatalog/plugin_catalog.go +++ b/vault/plugincatalog/plugin_catalog.go @@ -38,6 +38,7 @@ var ( ErrPluginNotFound = errors.New("plugin not found in the catalog") ErrPluginConnectionNotFound = errors.New("plugin connection not found for client") ErrPluginBadType = errors.New("unable to determine plugin type") + ErrPinnedVersion = errors.New("cannot delete a pinned version") ) // PluginCatalog keeps a record of plugins known to vault. External plugins need @@ -1013,6 +1014,14 @@ func (c *PluginCatalog) Delete(ctx context.Context, name string, pluginType cons c.lock.Lock() defer c.lock.Unlock() + pin, err := c.getPinnedVersionInternal(ctx, pinnedVersionStorageKey(pluginType, name)) + if err != nil && !errors.Is(err, pluginutil.ErrPinnedVersionNotFound) { + return err + } + if pin != nil && pin.Version == pluginVersion { + return ErrPinnedVersion + } + // Check the name under which the plugin exists, but if it's unfound, don't return any error. pluginKey := path.Join(pluginType.String(), name) if pluginVersion != "" { @@ -1059,6 +1068,10 @@ func (c *PluginCatalog) ListPluginsWithRuntime(ctx context.Context, runtime stri var ret []string for _, key := range keys { + // Skip: pinned version entry. + if strings.HasPrefix(key, pinnedVersionStoragePrefix) { + continue + } entry, err := c.catalogView.Get(ctx, key) if err != nil || entry == nil { continue @@ -1094,6 +1107,11 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug unversionedPlugins := make(map[string]struct{}) for _, key := range keys { + // Skip: pinned version entry. + if strings.HasPrefix(key, pinnedVersionStoragePrefix) { + continue + } + var semanticVersion *semver.Version entry, err := c.catalogView.Get(ctx, key) diff --git a/vault/plugincatalog/plugin_catalog_test.go b/vault/plugincatalog/plugin_catalog_test.go index 9704a37426..6c44872c46 100644 --- a/vault/plugincatalog/plugin_catalog_test.go +++ b/vault/plugincatalog/plugin_catalog_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io/ioutil" "os" @@ -71,6 +72,33 @@ func TestPluginCatalog_CRUD(t *testing.T) { pluginCatalog := testPluginCatalog(t) + // Register a fake plugin in the catalog. + file, err := os.CreateTemp(pluginCatalog.directory, "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: pluginName, + Type: consts.PluginTypeDatabase, + Version: "1.0.0", + Command: filepath.Base(file.Name()), + }) + if err != nil { + t.Fatal(err) + } + + // Register a pinned version, should not affect anything below. + err = pluginCatalog.SetPinnedVersion(context.Background(), &pluginutil.PinnedVersion{ + Name: pluginName, + Type: consts.PluginTypeDatabase, + Version: "1.0.0", + }) + if err != nil { + t.Fatal(err) + } + // Get builtin plugin p, err := pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") if err != nil { @@ -106,12 +134,6 @@ func TestPluginCatalog_CRUD(t *testing.T) { } // Set a plugin, test overwriting a builtin plugin - file, err := os.CreateTemp(pluginCatalog.directory, "temp") - if err != nil { - t.Fatal(err) - } - defer file.Close() - command := filepath.Base(file.Name()) err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: pluginName, @@ -1060,6 +1082,58 @@ func TestExternalPluginInContainer_GetBackendTypeVersion(t *testing.T) { } } +// TestPluginCatalog_CannotDeletePinnedVersion ensures we cannot delete a +// plugin which is referred to in an active pinned version. +func TestPluginCatalog_CannotDeletePinnedVersion(t *testing.T) { + pluginCatalog := testPluginCatalog(t) + + // Register a fake plugin in the catalog. + file, err := os.CreateTemp(pluginCatalog.directory, "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "my-plugin", + Type: consts.PluginTypeSecrets, + Version: "1.0.0", + Command: filepath.Base(file.Name()), + }) + if err != nil { + t.Fatal(err) + } + + // Pin a version and check we can't delete it. + err = pluginCatalog.SetPinnedVersion(context.Background(), &pluginutil.PinnedVersion{ + Name: "my-plugin", + Type: consts.PluginTypeSecrets, + Version: "1.0.0", + }) + if err != nil { + t.Fatal(err) + } + + err = pluginCatalog.Delete(context.Background(), "my-plugin", consts.PluginTypeSecrets, "1.0.0") + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, ErrPinnedVersion) { + t.Fatal(err) + } + + // Now delete the pinned version and we should be able to delete the plugin. + err = pluginCatalog.DeletePinnedVersion(context.Background(), consts.PluginTypeSecrets, "my-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + err = pluginCatalog.Delete(context.Background(), "my-plugin", consts.PluginTypeSecrets, "1.0.0") + if err != nil { + t.Fatal(err) + } +} + // testRunTestPlugin runs the testFunc which has already been registered to the // plugin catalog and returns a pluginClient. This can be called after calling // TestAddTestPlugin. diff --git a/vault/testing.go b/vault/testing.go index 3e3dbe2088..76de21af21 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -519,7 +519,7 @@ func TestKeyCopy(key []byte) []byte { return result } -func TestDynamicSystemView(c *Core, ns *namespace.Namespace) *dynamicSystemView { +func TestDynamicSystemView(c *Core, ns *namespace.Namespace) logical.SystemView { me := &MountEntry{ Config: MountConfig{ DefaultLeaseTTL: 24 * time.Hour, @@ -534,7 +534,9 @@ func TestDynamicSystemView(c *Core, ns *namespace.Namespace) *dynamicSystemView me.namespace = ns } - return &dynamicSystemView{c, me, c.perfStandby} + return &extendedSystemViewImpl{ + dynamicSystemView{c, me, c.perfStandby}, + } } func TestAddTestPlugin(t testing.T, core *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string) {