Refactor plugin catalog set functions (#22666)

Use a struct arg instead of a long list of args. Plugins running in containers
will require even more args and it's getting difficult to maintain.
This commit is contained in:
Tom Proctor
2023-08-31 10:32:24 +01:00
committed by GitHub
parent 1acd0c6d24
commit 3e55447036
6 changed files with 155 additions and 62 deletions

View File

@@ -61,6 +61,20 @@ type PluginRunner struct {
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
} }
// SetPluginInput is only used as input for the plugin catalog's set methods.
// We don't use the very similar PluginRunner struct to avoid confusion about
// what's settable, which does not include the builtin fields.
type SetPluginInput struct {
Name string
Type consts.PluginType
Version string
Command string
OCIImage string
Args []string
Env []string
Sha256 []byte
}
// Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and // Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and
// returns a configured plugin.Client with TLS Configured and a wrapping token set // returns a configured plugin.Client with TLS Configured and a wrapping token set
// on PluginUnwrapTokenEnv for plugin process consumption. // on PluginUnwrapTokenEnv for plugin process consumption.

View File

@@ -556,7 +556,15 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica
return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err
} }
err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes) err = b.Core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{
Name: pluginName,
Type: pluginType,
Version: pluginVersion,
Command: parts[0],
Args: args,
Env: env,
Sha256: sha256Bytes,
})
if err != nil { if err != nil {
if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") { if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") {
return logical.ErrorResponse(err.Error()), nil return logical.ErrorResponse(err.Error()), nil

View File

@@ -2189,7 +2189,15 @@ func TestSystemBackend_tuneAuth(t *testing.T) {
if err := file.Close(); err != nil { if err := file.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = c.pluginCatalog.Set(context.Background(), "token", consts.PluginTypeCredential, "v1.0.0", "foo", []string{}, []string{}, []byte{}) err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "token",
Type: consts.PluginTypeCredential,
Version: "v1.0.0",
Command: "foo",
Args: []string{},
Env: []string{},
Sha256: []byte{},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -5742,7 +5750,15 @@ func TestValidateVersion_HelpfulErrorWhenBuiltinOverridden(t *testing.T) {
defer file.Close() defer file.Close()
command := filepath.Base(file.Name()) command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), "kubernetes", consts.PluginTypeCredential, "", command, nil, nil, nil) err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "kubernetes",
Type: consts.PluginTypeCredential,
Version: "",
Command: command,
Args: nil,
Env: nil,
Sha256: nil,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -773,7 +773,15 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
plugin.Command = filepath.Join(c.directory, plugin.Command) plugin.Command = filepath.Join(c.directory, plugin.Command)
// Upgrade the storage. At this point we don't know what type of plugin this is so pass in the unknown type. // Upgrade the storage. At this point we don't know what type of plugin this is so pass in the unknown type.
runner, err := c.setInternal(ctx, pluginName, consts.PluginTypeUnknown, plugin.Version, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) runner, err := c.setInternal(ctx, pluginutil.SetPluginInput{
Name: pluginName,
Type: consts.PluginTypeUnknown,
Version: plugin.Version,
Command: cmdOld,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
})
if err != nil { if err != nil {
if errors.Is(err, ErrPluginBadType) { if errors.Is(err, ErrPluginBadType) {
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: plugin of unknown type", pluginName)) retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: plugin of unknown type", pluginName))
@@ -868,29 +876,29 @@ func (c *PluginCatalog) get(ctx context.Context, name string, pluginType consts.
// Set registers a new external plugin with the catalog, or updates an existing // Set registers a new external plugin with the catalog, or updates an existing
// external plugin. It takes the name, command and SHA256 of the plugin. // external plugin. It takes the name, command and SHA256 of the plugin.
func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.PluginType, version string, command string, args []string, env []string, sha256 []byte) error { func (c *PluginCatalog) Set(ctx context.Context, plugin pluginutil.SetPluginInput) error {
if c.directory == "" { if c.directory == "" {
return ErrDirectoryNotConfigured return ErrDirectoryNotConfigured
} }
switch { switch {
case strings.Contains(name, ".."): case strings.Contains(plugin.Name, ".."):
fallthrough fallthrough
case strings.Contains(command, ".."): case strings.Contains(plugin.Command, ".."):
return consts.ErrPathContainsParentReferences return consts.ErrPathContainsParentReferences
} }
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
_, err := c.setInternal(ctx, name, pluginType, version, command, args, env, sha256) _, err := c.setInternal(ctx, plugin)
return err return err
} }
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, version string, command string, args []string, env []string, sha256 []byte) (*pluginutil.PluginRunner, error) { func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) {
// Best effort check to make sure the command isn't breaking out of the // Best effort check to make sure the command isn't breaking out of the
// configured plugin directory. // configured plugin directory.
commandFull := filepath.Join(c.directory, command) commandFull := filepath.Join(c.directory, plugin.Command)
sym, err := filepath.EvalSymlinks(commandFull) sym, err := filepath.EvalSymlinks(commandFull)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while validating the command path: %w", err) return nil, fmt.Errorf("error while validating the command path: %w", err)
@@ -907,20 +915,21 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
// entryTmp should only be used for the below type and version checks, it uses the // entryTmp should only be used for the below type and version checks, it uses the
// full command instead of the relative command. // full command instead of the relative command.
entryTmp := &pluginutil.PluginRunner{ entryTmp := &pluginutil.PluginRunner{
Name: name, Name: plugin.Name,
Command: commandFull, Command: commandFull,
Args: args, Args: plugin.Args,
Env: env, Env: plugin.Env,
Sha256: sha256, Sha256: plugin.Sha256,
Builtin: false, Builtin: false,
} }
// If the plugin type is unknown, we want to attempt to determine the type // If the plugin type is unknown, we want to attempt to determine the type
if pluginType == consts.PluginTypeUnknown { if plugin.Type == consts.PluginTypeUnknown {
pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp) var err error
plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if pluginType == consts.PluginTypeUnknown { if plugin.Type == consts.PluginTypeUnknown {
return nil, ErrPluginBadType return nil, ErrPluginBadType
} }
} }
@@ -928,36 +937,36 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
// getting the plugin version is best-effort, so errors are not fatal // getting the plugin version is best-effort, so errors are not fatal
runningVersion := logical.EmptyPluginVersion runningVersion := logical.EmptyPluginVersion
var versionErr error var versionErr error
switch pluginType { switch plugin.Type {
case consts.PluginTypeSecrets, consts.PluginTypeCredential: case consts.PluginTypeSecrets, consts.PluginTypeCredential:
runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp) runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp)
case consts.PluginTypeDatabase: case consts.PluginTypeDatabase:
runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp) runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp)
default: default:
return nil, fmt.Errorf("unknown plugin type: %v", pluginType) return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type)
} }
if versionErr != nil { if versionErr != nil {
c.logger.Warn("Error determining plugin version", "error", versionErr) c.logger.Warn("Error determining plugin version", "error", versionErr)
} else if version != "" && runningVersion.Version != "" && version != runningVersion.Version { } else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version {
c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version) c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version)
return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version) return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", plugin.Name, runningVersion.Version, plugin.Version)
} else if version == "" && runningVersion.Version != "" { } else if plugin.Version == "" && runningVersion.Version != "" {
version = runningVersion.Version plugin.Version = runningVersion.Version
_, err := semver.NewVersion(version) _, err := semver.NewVersion(plugin.Version)
if err != nil { if err != nil {
return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", version, err) return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err)
} }
} }
entry := &pluginutil.PluginRunner{ entry := &pluginutil.PluginRunner{
Name: name, Name: plugin.Name,
Type: pluginType, Type: plugin.Type,
Version: version, Version: plugin.Version,
Command: command, Command: plugin.Command,
Args: args, Args: plugin.Args,
Env: env, Env: plugin.Env,
Sha256: sha256, Sha256: plugin.Sha256,
Builtin: false, Builtin: false,
} }
@@ -966,9 +975,9 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
return nil, fmt.Errorf("failed to encode plugin entry: %w", err) return nil, fmt.Errorf("failed to encode plugin entry: %w", err)
} }
storageKey := path.Join(pluginType.String(), name) storageKey := path.Join(plugin.Type.String(), plugin.Name)
if version != "" { if plugin.Version != "" {
storageKey = path.Join(storageKey, version) storageKey = path.Join(storageKey, plugin.Version)
} }
logicalEntry := logical.StorageEntry{ logicalEntry := logical.StorageEntry{
Key: storageKey, Key: storageKey,

View File

@@ -79,7 +79,15 @@ func TestPluginCatalog_CRUD(t *testing.T) {
defer file.Close() defer file.Close()
command := filepath.Base(file.Name()) command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), pluginName, consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: pluginName,
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{"FOO=BAR"},
Sha256: []byte{'1'},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -163,7 +171,15 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) {
const name = "mysql-database-plugin" const name = "mysql-database-plugin"
const version = "1.0.0" const version = "1.0.0"
command := fmt.Sprintf("%s", filepath.Base(file.Name())) command := fmt.Sprintf("%s", filepath.Base(file.Name()))
err = core.pluginCatalog.Set(context.Background(), name, consts.PluginTypeDatabase, version, command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: name,
Type: consts.PluginTypeDatabase,
Version: version,
Command: command,
Args: []string{"--test"},
Env: []string{"FOO=BAR"},
Sha256: []byte{'1'},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -270,13 +286,29 @@ func TestPluginCatalog_List(t *testing.T) {
defer file.Close() defer file.Close()
command := filepath.Base(file.Name()) command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{}, []byte{'1'}) err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "mysql-database-plugin",
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{},
Sha256: []byte{'1'},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Set another plugin // Set another plugin
err = core.pluginCatalog.Set(context.Background(), "aaaaaaa", consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{}, []byte{'1'}) err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "aaaaaaa",
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{},
Sha256: []byte{'1'},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -341,31 +373,29 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) {
defer file.Close() defer file.Close()
command := filepath.Base(file.Name()) command := filepath.Base(file.Name())
err = core.pluginCatalog.Set( err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
context.Background(), Name: "mysql-database-plugin",
"mysql-database-plugin", Type: consts.PluginTypeDatabase,
consts.PluginTypeDatabase, Version: "",
"", Command: command,
command, Args: []string{"--test"},
[]string{"--test"}, Env: []string{},
[]string{}, Sha256: []byte{'1'},
[]byte{'1'}, })
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Set another plugin, with version information // Set another plugin, with version information
err = core.pluginCatalog.Set( err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
context.Background(), Name: "aaaaaaa",
"aaaaaaa", Type: consts.PluginTypeDatabase,
consts.PluginTypeDatabase, Version: "1.1.0",
"1.1.0", Command: command,
command, Args: []string{"--test"},
[]string{"--test"}, Env: []string{},
[]string{}, Sha256: []byte{'1'},
[]byte{'1'}, })
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -458,7 +488,15 @@ func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) {
}, },
} }
for _, entry := range pluginsToRegister { for _, entry := range pluginsToRegister {
err = core.pluginCatalog.Set(ctx, entry.Name, consts.PluginTypeCredential, entry.Version, command, nil, nil, nil) err = core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{
Name: entry.Name,
Type: consts.PluginTypeCredential,
Version: entry.Version,
Command: command,
Args: nil,
Env: nil,
Sha256: nil,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -593,7 +593,15 @@ func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.Plug
c.pluginCatalog.directory = fullPath c.pluginCatalog.directory = fullPath
args := []string{fmt.Sprintf("--test.run=%s", testFunc)} args := []string{fmt.Sprintf("--test.run=%s", testFunc)}
err = c.pluginCatalog.Set(context.Background(), name, pluginType, version, fileName, args, env, sum) err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: name,
Type: pluginType,
Version: version,
Command: fileName,
Args: args,
Env: env,
Sha256: sum,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }