diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index 977f95d722..a0590914e1 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -61,6 +61,20 @@ type PluginRunner struct { BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` } +// SetPluginInput is only used as input for the plugin catalog's set methods. +// We don't use the very similar PluginRunner struct to avoid confusion about +// what's settable, which does not include the builtin fields. +type SetPluginInput struct { + Name string + Type consts.PluginType + Version string + Command string + OCIImage string + Args []string + Env []string + Sha256 []byte +} + // Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and // returns a configured plugin.Client with TLS Configured and a wrapping token set // on PluginUnwrapTokenEnv for plugin process consumption. diff --git a/vault/logical_system.go b/vault/logical_system.go index c6380104d3..0f033a6489 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -556,7 +556,15 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err } - err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes) + err = b.Core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{ + Name: pluginName, + Type: pluginType, + Version: pluginVersion, + Command: parts[0], + Args: args, + Env: env, + Sha256: sha256Bytes, + }) if err != nil { if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") { return logical.ErrorResponse(err.Error()), nil diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 7d5ce33a47..94bd150bc4 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -2189,7 +2189,15 @@ func TestSystemBackend_tuneAuth(t *testing.T) { if err := file.Close(); err != nil { t.Fatal(err) } - err = c.pluginCatalog.Set(context.Background(), "token", consts.PluginTypeCredential, "v1.0.0", "foo", []string{}, []string{}, []byte{}) + err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "token", + Type: consts.PluginTypeCredential, + Version: "v1.0.0", + Command: "foo", + Args: []string{}, + Env: []string{}, + Sha256: []byte{}, + }) if err != nil { t.Fatal(err) } @@ -5742,7 +5750,15 @@ func TestValidateVersion_HelpfulErrorWhenBuiltinOverridden(t *testing.T) { defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), "kubernetes", consts.PluginTypeCredential, "", command, nil, nil, nil) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "kubernetes", + Type: consts.PluginTypeCredential, + Version: "", + Command: command, + Args: nil, + Env: nil, + Sha256: nil, + }) if err != nil { t.Fatal(err) } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 2454415742..e3e69377bc 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -773,7 +773,15 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e plugin.Command = filepath.Join(c.directory, plugin.Command) // Upgrade the storage. At this point we don't know what type of plugin this is so pass in the unknown type. - runner, err := c.setInternal(ctx, pluginName, consts.PluginTypeUnknown, plugin.Version, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) + runner, err := c.setInternal(ctx, pluginutil.SetPluginInput{ + Name: pluginName, + Type: consts.PluginTypeUnknown, + Version: plugin.Version, + Command: cmdOld, + Args: plugin.Args, + Env: plugin.Env, + Sha256: plugin.Sha256, + }) if err != nil { if errors.Is(err, ErrPluginBadType) { retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: plugin of unknown type", pluginName)) @@ -868,29 +876,29 @@ func (c *PluginCatalog) get(ctx context.Context, name string, pluginType consts. // Set registers a new external plugin with the catalog, or updates an existing // external plugin. It takes the name, command and SHA256 of the plugin. -func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.PluginType, version string, command string, args []string, env []string, sha256 []byte) error { +func (c *PluginCatalog) Set(ctx context.Context, plugin pluginutil.SetPluginInput) error { if c.directory == "" { return ErrDirectoryNotConfigured } switch { - case strings.Contains(name, ".."): + case strings.Contains(plugin.Name, ".."): fallthrough - case strings.Contains(command, ".."): + case strings.Contains(plugin.Command, ".."): return consts.ErrPathContainsParentReferences } c.lock.Lock() defer c.lock.Unlock() - _, err := c.setInternal(ctx, name, pluginType, version, command, args, env, sha256) + _, err := c.setInternal(ctx, plugin) return err } -func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, version string, command string, args []string, env []string, sha256 []byte) (*pluginutil.PluginRunner, error) { +func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) { // Best effort check to make sure the command isn't breaking out of the // configured plugin directory. - commandFull := filepath.Join(c.directory, command) + commandFull := filepath.Join(c.directory, plugin.Command) sym, err := filepath.EvalSymlinks(commandFull) if err != nil { return nil, fmt.Errorf("error while validating the command path: %w", err) @@ -907,20 +915,21 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType // entryTmp should only be used for the below type and version checks, it uses the // full command instead of the relative command. entryTmp := &pluginutil.PluginRunner{ - Name: name, + Name: plugin.Name, Command: commandFull, - Args: args, - Env: env, - Sha256: sha256, + Args: plugin.Args, + Env: plugin.Env, + Sha256: plugin.Sha256, Builtin: false, } // If the plugin type is unknown, we want to attempt to determine the type - if pluginType == consts.PluginTypeUnknown { - pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp) + if plugin.Type == consts.PluginTypeUnknown { + var err error + plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp) if err != nil { return nil, err } - if pluginType == consts.PluginTypeUnknown { + if plugin.Type == consts.PluginTypeUnknown { return nil, ErrPluginBadType } } @@ -928,36 +937,36 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType // getting the plugin version is best-effort, so errors are not fatal runningVersion := logical.EmptyPluginVersion var versionErr error - switch pluginType { + switch plugin.Type { case consts.PluginTypeSecrets, consts.PluginTypeCredential: runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp) case consts.PluginTypeDatabase: runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp) default: - return nil, fmt.Errorf("unknown plugin type: %v", pluginType) + return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type) } if versionErr != nil { c.logger.Warn("Error determining plugin version", "error", versionErr) - } else if version != "" && runningVersion.Version != "" && version != runningVersion.Version { - c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version) - return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version) - } else if version == "" && runningVersion.Version != "" { - version = runningVersion.Version - _, err := semver.NewVersion(version) + } else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version { + c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version) + return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", plugin.Name, runningVersion.Version, plugin.Version) + } else if plugin.Version == "" && runningVersion.Version != "" { + plugin.Version = runningVersion.Version + _, err := semver.NewVersion(plugin.Version) if err != nil { - return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", version, err) + return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err) } } entry := &pluginutil.PluginRunner{ - Name: name, - Type: pluginType, - Version: version, - Command: command, - Args: args, - Env: env, - Sha256: sha256, + Name: plugin.Name, + Type: plugin.Type, + Version: plugin.Version, + Command: plugin.Command, + Args: plugin.Args, + Env: plugin.Env, + Sha256: plugin.Sha256, Builtin: false, } @@ -966,9 +975,9 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType return nil, fmt.Errorf("failed to encode plugin entry: %w", err) } - storageKey := path.Join(pluginType.String(), name) - if version != "" { - storageKey = path.Join(storageKey, version) + storageKey := path.Join(plugin.Type.String(), plugin.Name) + if plugin.Version != "" { + storageKey = path.Join(storageKey, plugin.Version) } logicalEntry := logical.StorageEntry{ Key: storageKey, diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index cf4cf0d382..2f3ce61143 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -79,7 +79,15 @@ func TestPluginCatalog_CRUD(t *testing.T) { defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), pluginName, consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: pluginName, + Type: consts.PluginTypeDatabase, + Version: "", + Command: command, + Args: []string{"--test"}, + Env: []string{"FOO=BAR"}, + Sha256: []byte{'1'}, + }) if err != nil { t.Fatal(err) } @@ -163,7 +171,15 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { const name = "mysql-database-plugin" const version = "1.0.0" command := fmt.Sprintf("%s", filepath.Base(file.Name())) - err = core.pluginCatalog.Set(context.Background(), name, consts.PluginTypeDatabase, version, command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: name, + Type: consts.PluginTypeDatabase, + Version: version, + Command: command, + Args: []string{"--test"}, + Env: []string{"FOO=BAR"}, + Sha256: []byte{'1'}, + }) if err != nil { t.Fatal(err) } @@ -270,13 +286,29 @@ func TestPluginCatalog_List(t *testing.T) { defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "mysql-database-plugin", + Type: consts.PluginTypeDatabase, + Version: "", + Command: command, + Args: []string{"--test"}, + Env: []string{}, + Sha256: []byte{'1'}, + }) if err != nil { t.Fatal(err) } // Set another plugin - err = core.pluginCatalog.Set(context.Background(), "aaaaaaa", consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{}, []byte{'1'}) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "aaaaaaa", + Type: consts.PluginTypeDatabase, + Version: "", + Command: command, + Args: []string{"--test"}, + Env: []string{}, + Sha256: []byte{'1'}, + }) if err != nil { t.Fatal(err) } @@ -341,31 +373,29 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) { defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set( - context.Background(), - "mysql-database-plugin", - consts.PluginTypeDatabase, - "", - command, - []string{"--test"}, - []string{}, - []byte{'1'}, - ) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "mysql-database-plugin", + Type: consts.PluginTypeDatabase, + Version: "", + Command: command, + Args: []string{"--test"}, + Env: []string{}, + Sha256: []byte{'1'}, + }) if err != nil { t.Fatal(err) } // Set another plugin, with version information - err = core.pluginCatalog.Set( - context.Background(), - "aaaaaaa", - consts.PluginTypeDatabase, - "1.1.0", - command, - []string{"--test"}, - []string{}, - []byte{'1'}, - ) + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "aaaaaaa", + Type: consts.PluginTypeDatabase, + Version: "1.1.0", + Command: command, + Args: []string{"--test"}, + Env: []string{}, + Sha256: []byte{'1'}, + }) if err != nil { t.Fatal(err) } @@ -458,7 +488,15 @@ func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) { }, } for _, entry := range pluginsToRegister { - err = core.pluginCatalog.Set(ctx, entry.Name, consts.PluginTypeCredential, entry.Version, command, nil, nil, nil) + err = core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{ + Name: entry.Name, + Type: consts.PluginTypeCredential, + Version: entry.Version, + Command: command, + Args: nil, + Env: nil, + Sha256: nil, + }) if err != nil { t.Fatal(err) } diff --git a/vault/testing.go b/vault/testing.go index f1d159061f..464a348932 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -593,7 +593,15 @@ func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.Plug c.pluginCatalog.directory = fullPath args := []string{fmt.Sprintf("--test.run=%s", testFunc)} - err = c.pluginCatalog.Set(context.Background(), name, pluginType, version, fileName, args, env, sum) + err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: name, + Type: pluginType, + Version: version, + Command: fileName, + Args: args, + Env: env, + Sha256: sum, + }) if err != nil { t.Fatal(err) }