diff --git a/go.mod b/go.mod index 28e97f6a44..4ac339a24b 100644 --- a/go.mod +++ b/go.mod @@ -96,7 +96,7 @@ require ( github.com/hashicorp/go-memdb v1.3.4 github.com/hashicorp/go-msgpack v1.1.5 github.com/hashicorp/go-multierror v1.1.1 - github.com/hashicorp/go-plugin v1.5.0 + github.com/hashicorp/go-plugin v1.5.1 github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a github.com/hashicorp/go-retryablehttp v0.7.4 github.com/hashicorp/go-rootcerts v1.0.2 @@ -385,7 +385,7 @@ require ( github.com/hashicorp/go-metrics v0.5.1 // indirect github.com/hashicorp/go-msgpack/v2 v2.0.0 // indirect github.com/hashicorp/go-secure-stdlib/fileutil v0.1.0 // indirect - github.com/hashicorp/go-secure-stdlib/plugincontainer v0.1.1 // indirect + github.com/hashicorp/go-secure-stdlib/plugincontainer v0.2.0 // indirect github.com/hashicorp/go-slug v0.11.1 // indirect github.com/hashicorp/go-tfe v1.25.1 // indirect github.com/hashicorp/jsonapi v0.0.0-20210826224640-ee7dae0fb22d // indirect diff --git a/go.sum b/go.sum index 04affd0440..967906026d 100644 --- a/go.sum +++ b/go.sum @@ -2012,8 +2012,9 @@ github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= -github.com/hashicorp/go-plugin v1.5.0 h1:g6Lj3USwF5LaB8HlvCxPjN2X4nFE08ko2BJNVpl7TIE= github.com/hashicorp/go-plugin v1.5.0/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4= +github.com/hashicorp/go-plugin v1.5.1 h1:oGm7cWBaYIp3lJpx1RUEfLWophprE2EV/KUeqBYo+6k= +github.com/hashicorp/go-plugin v1.5.1/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4= github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a h1:FmnBDwGwlTgugDGbVxwV8UavqSMACbGrUpfc98yFLR4= github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a/go.mod h1:xbXnmKqX9/+RhPkJ4zrEx4738HacP72aaUPlT2RZ4sU= github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= @@ -2047,8 +2048,9 @@ github.com/hashicorp/go-secure-stdlib/parseutil v0.1.7 h1:UpiO20jno/eV1eVZcxqWnU github.com/hashicorp/go-secure-stdlib/parseutil v0.1.7/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= github.com/hashicorp/go-secure-stdlib/password v0.1.1 h1:6JzmBqXprakgFEHwBgdchsjaA9x3GyjdI568bXKxa60= github.com/hashicorp/go-secure-stdlib/password v0.1.1/go.mod h1:9hH302QllNwu1o2TGYtSk8I8kTAN0ca1EHpwhm5Mmzo= -github.com/hashicorp/go-secure-stdlib/plugincontainer v0.1.1 h1:1F0n5stk5uz4yIw2elN3k6bGbIv95OQaJVR2sVQ1kk0= github.com/hashicorp/go-secure-stdlib/plugincontainer v0.1.1/go.mod h1:kRpzC4wHYXc2+sjXA9vuKawXYs0x0d0HuqqbaW1fj1w= +github.com/hashicorp/go-secure-stdlib/plugincontainer v0.2.0 h1:1jd8y6HKfDED6vdsXFRM9SpFQNfhBEIHOC41GyILGyY= +github.com/hashicorp/go-secure-stdlib/plugincontainer v0.2.0/go.mod h1:Cv387jRKKbetAp5AWK4zL7UxdeBeDTgUJOnmS4T/4I8= github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1 h1:SMGUnbpAcat8rIKHkBPjfv81yC46a8eCNZ2hsR2l1EI= github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1/go.mod h1:Ch/bf00Qnx77MZd49JRgHYqHQjtEmTgGU2faufpVZb0= github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= diff --git a/sdk/go.mod b/sdk/go.mod index ca98f77359..aa52f0a613 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -23,13 +23,13 @@ require ( github.com/hashicorp/go-kms-wrapping/entropy/v2 v2.0.0 github.com/hashicorp/go-kms-wrapping/v2 v2.0.8 github.com/hashicorp/go-multierror v1.1.1 - github.com/hashicorp/go-plugin v1.5.0 + github.com/hashicorp/go-plugin v1.5.1 github.com/hashicorp/go-retryablehttp v0.7.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/mlock v0.1.2 github.com/hashicorp/go-secure-stdlib/parseutil v0.1.7 github.com/hashicorp/go-secure-stdlib/password v0.1.1 - github.com/hashicorp/go-secure-stdlib/plugincontainer v0.1.1 + github.com/hashicorp/go-secure-stdlib/plugincontainer v0.2.0 github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.2 github.com/hashicorp/go-sockaddr v1.0.2 diff --git a/sdk/go.sum b/sdk/go.sum index 11967d5c5c..4c6f67598e 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -171,8 +171,8 @@ github.com/hashicorp/go-kms-wrapping/v2 v2.0.8/go.mod h1:qTCjxGig/kjuj3hk1z8pOUr github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/hashicorp/go-plugin v1.5.0 h1:g6Lj3USwF5LaB8HlvCxPjN2X4nFE08ko2BJNVpl7TIE= -github.com/hashicorp/go-plugin v1.5.0/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4= +github.com/hashicorp/go-plugin v1.5.1 h1:oGm7cWBaYIp3lJpx1RUEfLWophprE2EV/KUeqBYo+6k= +github.com/hashicorp/go-plugin v1.5.1/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4= github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ= github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= @@ -187,8 +187,8 @@ github.com/hashicorp/go-secure-stdlib/parseutil v0.1.7 h1:UpiO20jno/eV1eVZcxqWnU github.com/hashicorp/go-secure-stdlib/parseutil v0.1.7/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= github.com/hashicorp/go-secure-stdlib/password v0.1.1 h1:6JzmBqXprakgFEHwBgdchsjaA9x3GyjdI568bXKxa60= github.com/hashicorp/go-secure-stdlib/password v0.1.1/go.mod h1:9hH302QllNwu1o2TGYtSk8I8kTAN0ca1EHpwhm5Mmzo= -github.com/hashicorp/go-secure-stdlib/plugincontainer v0.1.1 h1:1F0n5stk5uz4yIw2elN3k6bGbIv95OQaJVR2sVQ1kk0= -github.com/hashicorp/go-secure-stdlib/plugincontainer v0.1.1/go.mod h1:kRpzC4wHYXc2+sjXA9vuKawXYs0x0d0HuqqbaW1fj1w= +github.com/hashicorp/go-secure-stdlib/plugincontainer v0.2.0 h1:1jd8y6HKfDED6vdsXFRM9SpFQNfhBEIHOC41GyILGyY= +github.com/hashicorp/go-secure-stdlib/plugincontainer v0.2.0/go.mod h1:Cv387jRKKbetAp5AWK4zL7UxdeBeDTgUJOnmS4T/4I8= github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= diff --git a/sdk/helper/consts/plugin_runtime_types.go b/sdk/helper/consts/plugin_runtime_types.go index 63b8127ec2..bf0722dd31 100644 --- a/sdk/helper/consts/plugin_runtime_types.go +++ b/sdk/helper/consts/plugin_runtime_types.go @@ -18,6 +18,8 @@ type PluginRuntimeType uint32 // This is a list of PluginRuntimeTypes used by Vault. const ( + DefaultContainerPluginOCIRuntime = "runsc" + PluginRuntimeTypeUnsupported PluginRuntimeType = iota PluginRuntimeTypeContainer ) diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index 480bdc1d9a..fddca60508 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -10,15 +10,14 @@ import ( "fmt" "os" "os/exec" + "strconv" "strings" - "github.com/hashicorp/go-hclog" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" - "github.com/hashicorp/go-plugin/runner" "github.com/hashicorp/go-secure-stdlib/plugincontainer" - "github.com/hashicorp/go-secure-stdlib/plugincontainer/config" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" ) type PluginClientConfig struct { @@ -45,23 +44,13 @@ type runConfig struct { // Initialized with what's in PluginRunner.Env, but can be added to env []string + runtimeConfig *pluginruntimeutil.PluginRuntimeConfig + PluginClientConfig } -func overlayCmdSpec(base, cmd *exec.Cmd) { - if cmd.Path != "" { - base.Path = cmd.Path - } - if len(cmd.Args) > 0 { - base.Args = cmd.Args - } - if len(cmd.Env) > 0 { - base.Env = append(base.Env, cmd.Env...) - } -} - -func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) { - cmd := exec.Command(rc.command, rc.args...) +func (rc runConfig) generateCmd(ctx context.Context) (cmd *exec.Cmd, clientTLSConfig *tls.Config, err error) { + cmd = exec.Command(rc.command, rc.args...) cmd.Env = append(cmd.Env, rc.env...) // Add the mlock setting to the ENV of the plugin @@ -70,7 +59,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error } version, err := rc.Wrapper.VaultVersion(ctx) if err != nil { - return nil, err + return nil, nil, err } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version)) @@ -83,31 +72,39 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error automtlsEnv := fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, rc.AutoMTLS) cmd.Env = append(cmd.Env, automtlsEnv) - var clientTLSConfig *tls.Config if !rc.AutoMTLS && !rc.IsMetadataMode { // Get a CA TLS Certificate certBytes, key, err := generateCert() if err != nil { - return nil, err + return nil, nil, err } // Use CA to sign a client cert and return a configured TLS config clientTLSConfig, err = createClientTLSConfig(certBytes, key) if err != nil { - return nil, err + return nil, nil, err } // Use CA to sign a server cert and wrap the values in a response wrapped // token. wrapToken, err := wrapServerConfig(ctx, rc.Wrapper, certBytes, key) if err != nil { - return nil, err + return nil, nil, err } // Add the response wrap token to the ENV of the plugin cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) } + return cmd, clientTLSConfig, nil +} + +func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) { + cmd, clientTLSConfig, err := rc.generateCmd(ctx) + if err != nil { + return nil, err + } + clientConfig := &plugin.ClientConfig{ HandshakeConfig: rc.HandshakeConfig, VersionedPlugins: rc.PluginSets, @@ -126,32 +123,49 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error Hash: sha256.New(), } } else { + containerCfg := rc.containerConfig(cmd.Env) clientConfig.SkipHostEnv = true - clientConfig.RunnerFunc = func(logger hclog.Logger, goPluginCmd *exec.Cmd, tmpDir string) (runner.Runner, error) { - overlayCmdSpec(goPluginCmd, cmd) - cfg := &config.ContainerConfig{ - UnixSocketGroup: fmt.Sprintf("%d", os.Getgid()), - Image: rc.image, - Tag: rc.imageTag, - SHA256: fmt.Sprintf("%x", rc.sha256), - Labels: map[string]string{ - "managed-by": "hashicorp.com/vault", - }, - // TODO: More configurables. - // Defaulting to runsc will require installing gVisor in the GitHub runner. - // Runtime: "runsc", - // CgroupParent: "", - // NanoCpus: 100000000, - // Memory: 64 * 1024 * 1024, - // TODO: network - - } - return plugincontainer.NewContainerRunner(logger, goPluginCmd, cfg, tmpDir) + clientConfig.RunnerFunc = containerCfg.NewContainerRunner + clientConfig.UnixSocketConfig = &plugin.UnixSocketConfig{ + Group: strconv.Itoa(containerCfg.GroupAdd), } } return clientConfig, nil } +func (rc runConfig) containerConfig(env []string) *plugincontainer.Config { + cfg := &plugincontainer.Config{ + Image: rc.image, + Tag: rc.imageTag, + SHA256: fmt.Sprintf("%x", rc.sha256), + + Env: env, + GroupAdd: os.Getgid(), + Runtime: consts.DefaultContainerPluginOCIRuntime, + Labels: map[string]string{ + "managed-by": "hashicorp.com/vault", + }, + } + // Use rc.command and rc.args directly instead of cmd.Path and cmd.Args, as + // exec.Command may mutate the provided command. + if rc.command != "" { + cfg.Entrypoint = []string{rc.command} + } + if len(rc.args) > 0 { + cfg.Args = rc.args + } + if rc.runtimeConfig != nil { + cfg.CgroupParent = rc.runtimeConfig.CgroupParent + cfg.NanoCpus = rc.runtimeConfig.CPU + cfg.Memory = rc.runtimeConfig.Memory + if rc.runtimeConfig.OCIRuntime != "" { + cfg.Runtime = rc.runtimeConfig.OCIRuntime + } + } + + return cfg +} + func (rc runConfig) run(ctx context.Context) (*plugin.Client, error) { clientConfig, err := rc.makeConfig(ctx) if err != nil { @@ -219,12 +233,13 @@ func (r *PluginRunner) RunConfig(ctx context.Context, opts ...RunOpt) (*plugin.C imageTag = strings.TrimPrefix(r.Version, "v") } rc := runConfig{ - command: r.Command, - image: image, - imageTag: imageTag, - args: r.Args, - sha256: r.Sha256, - env: r.Env, + command: r.Command, + image: image, + imageTag: imageTag, + args: r.Args, + sha256: r.Sha256, + env: r.Env, + runtimeConfig: r.RuntimeConfig, } for _, opt := range opts { diff --git a/sdk/helper/pluginutil/run_config_test.go b/sdk/helper/pluginutil/run_config_test.go index e64057783a..4469401203 100644 --- a/sdk/helper/pluginutil/run_config_test.go +++ b/sdk/helper/pluginutil/run_config_test.go @@ -5,13 +5,19 @@ package pluginutil import ( "context" + "encoding/hex" "fmt" + "os" "os/exec" + "strconv" "testing" "time" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" + "github.com/hashicorp/go-secure-stdlib/plugincontainer" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -28,8 +34,10 @@ func TestMakeConfig(t *testing.T) { mlockEnabled bool mlockEnabledTimes int - expectedConfig *plugin.ClientConfig - expectTLSConfig bool + expectedConfig *plugin.ClientConfig + expectTLSConfig bool + expectRunnerFunc bool + skipSecureConfig bool } tests := map[string]testCase{ @@ -286,6 +294,64 @@ func TestMakeConfig(t *testing.T) { }, expectTLSConfig: false, }, + "image set": { + rc: runConfig{ + command: "echo", + args: []string{"foo", "bar"}, + sha256: []byte("some_sha256"), + env: []string{"initial=true"}, + image: "some-image", + imageTag: "0.1.0", + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, + }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, + }, + }, + + responseWrapInfoTimes: 0, + + mlockEnabled: false, + mlockEnabledTimes: 1, + + expectedConfig: &plugin.ClientConfig{ + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + VersionedPlugins: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, + }, + Cmd: nil, + SecureConfig: nil, + AllowedProtocols: []plugin.Protocol{ + plugin.ProtocolNetRPC, + plugin.ProtocolGRPC, + }, + Logger: hclog.NewNullLogger(), + AutoMTLS: true, + SkipHostEnv: true, + UnixSocketConfig: &plugin.UnixSocketConfig{ + Group: strconv.Itoa(os.Getgid()), + }, + }, + expectTLSConfig: false, + expectRunnerFunc: true, + skipSecureConfig: true, + }, } for name, test := range tests { @@ -309,11 +375,13 @@ func TestMakeConfig(t *testing.T) { // The following fields are generated, so we just need to check for existence, not specific value // The value must be nilled out before performing a DeepEqual check - hsh := config.SecureConfig.Hash - if hsh == nil { - t.Fatalf("Missing SecureConfig.Hash") + if !test.skipSecureConfig { + hsh := config.SecureConfig.Hash + if hsh == nil { + t.Fatalf("Missing SecureConfig.Hash") + } + config.SecureConfig.Hash = nil } - config.SecureConfig.Hash = nil if test.expectTLSConfig && config.TLSConfig == nil { t.Fatalf("TLS config expected, got nil") @@ -323,6 +391,11 @@ func TestMakeConfig(t *testing.T) { } config.TLSConfig = nil + if test.expectRunnerFunc != (config.RunnerFunc != nil) { + t.Fatalf("expected RunnerFunc: %v, actual: %v", test.expectRunnerFunc, config.RunnerFunc != nil) + } + config.RunnerFunc = nil + require.Equal(t, test.expectedConfig, config) }) } @@ -358,3 +431,117 @@ func (m *mockRunnerUtil) MlockEnabled() bool { args := m.Called() return args.Bool(0) } + +func TestContainerConfig(t *testing.T) { + dummySHA, err := hex.DecodeString("abc123") + if err != nil { + t.Fatal(err) + } + for name, tc := range map[string]struct { + rc runConfig + expected plugincontainer.Config + }{ + "image set, no runtime": { + rc: runConfig{ + command: "echo", + args: []string{"foo", "bar"}, + sha256: dummySHA, + env: []string{"initial=true"}, + image: "some-image", + imageTag: "0.1.0", + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, + }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + AutoMTLS: true, + }, + }, + expected: plugincontainer.Config{ + Image: "some-image", + Tag: "0.1.0", + SHA256: "abc123", + Entrypoint: []string{"echo"}, + Args: []string{"foo", "bar"}, + Env: []string{ + "initial=true", + fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), + fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), + }, + Labels: map[string]string{ + "managed-by": "hashicorp.com/vault", + }, + Runtime: consts.DefaultContainerPluginOCIRuntime, + GroupAdd: os.Getgid(), + }, + }, + "image set, with runtime": { + rc: runConfig{ + sha256: dummySHA, + image: "some-image", + imageTag: "0.1.0", + runtimeConfig: &pluginruntimeutil.PluginRuntimeConfig{ + OCIRuntime: "some-oci-runtime", + CgroupParent: "/cgroup/parent", + CPU: 1000, + Memory: 2000, + }, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, + }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + AutoMTLS: true, + }, + }, + expected: plugincontainer.Config{ + Image: "some-image", + Tag: "0.1.0", + SHA256: "abc123", + Env: []string{ + fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), + fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), + }, + Labels: map[string]string{ + "managed-by": "hashicorp.com/vault", + }, + Runtime: "some-oci-runtime", + GroupAdd: os.Getgid(), + CgroupParent: "/cgroup/parent", + NanoCpus: 1000, + Memory: 2000, + }, + }, + } { + t.Run(name, func(t *testing.T) { + mockWrapper := new(mockRunnerUtil) + mockWrapper.On("ResponseWrapData", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, nil) + mockWrapper.On("MlockEnabled"). + Return(false) + tc.rc.Wrapper = mockWrapper + cmd, _, err := tc.rc.generateCmd(context.Background()) + if err != nil { + t.Fatal(err) + } + cfg := tc.rc.containerConfig(cmd.Env) + require.Equal(t, tc.expected, *cfg) + }) + } +} diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index 6102995ef5..316a16fe3f 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/go-version" "github.com/hashicorp/vault/sdk/helper/consts" + prutil "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" "github.com/hashicorp/vault/sdk/helper/wrapping" "google.golang.org/grpc" ) @@ -62,6 +63,7 @@ type PluginRunner struct { Sha256 []byte `json:"sha256" structs:"sha256"` Builtin bool `json:"builtin" structs:"builtin"` BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` + RuntimeConfig *prutil.PluginRuntimeConfig `json:"-" structs:"-"` } // BinaryReference returns either the OCI image reference if it's a container diff --git a/vault/external_plugin_container_test.go b/vault/external_plugin_container_test.go index 0f2401bf31..8133c13563 100644 --- a/vault/external_plugin_container_test.go +++ b/vault/external_plugin_container_test.go @@ -7,11 +7,14 @@ import ( "context" "encoding/hex" "fmt" + "os/exec" + "strings" "testing" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -41,39 +44,66 @@ func testClusterWithContainerPlugin(t *testing.T, pluginType consts.PluginType, func TestExternalPluginInContainer_MountAndUnmount(t *testing.T) { for name, tc := range map[string]struct { - pluginType consts.PluginType - routerPath string - expectedMatch string - listRolesPath string + pluginType consts.PluginType }{ - "enable external credential plugin": { - pluginType: consts.PluginTypeCredential, - routerPath: "auth/foo/bar", - expectedMatch: "auth/foo/", + "auth": { + pluginType: consts.PluginTypeCredential, }, - "enable external secrets plugin": { - pluginType: consts.PluginTypeSecrets, - routerPath: "foo/bar", - expectedMatch: "foo/", + "secrets": { + pluginType: consts.PluginTypeSecrets, }, } { t.Run(name, func(t *testing.T) { c, plugin := testClusterWithContainerPlugin(t, tc.pluginType, "v1.0.0") - registerContainerPlugin(t, c.systemBackend, plugin.Name, tc.pluginType.String(), "1.0.0", plugin.ImageSha256, plugin.Image) + t.Run("default", func(t *testing.T) { + if _, err := exec.LookPath("runsc"); err != nil { + t.Skip("Skipping test as runsc not found on path") + } + mountAndUnmountContainerPlugin_WithRuntime(t, c, plugin, "") + }) - mountPlugin(t, c.systemBackend, plugin.Name, tc.pluginType, "v1.0.0", "") + t.Run("runc", func(t *testing.T) { + mountAndUnmountContainerPlugin_WithRuntime(t, c, plugin, "runc") + }) - match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) - if match != tc.expectedMatch { - t.Fatalf("missing mount, match: %q", match) - } - - unmountPlugin(t, c.systemBackend, plugin.Name, tc.pluginType, "v1.0.0", "foo") + t.Run("runsc", func(t *testing.T) { + if _, err := exec.LookPath("runsc"); err != nil { + t.Skip("Skipping test as runsc not found on path") + } + mountAndUnmountContainerPlugin_WithRuntime(t, c, plugin, "runsc") + }) }) } } +func mountAndUnmountContainerPlugin_WithRuntime(t *testing.T, c *Core, plugin pluginhelpers.TestPlugin, ociRuntime string) { + if ociRuntime != "" { + registerPluginRuntime(t, c.systemBackend, ociRuntime, ociRuntime) + } + registerContainerPlugin(t, c.systemBackend, plugin.Name, plugin.Typ.String(), "1.0.0", plugin.ImageSha256, plugin.Image, ociRuntime) + + mountPlugin(t, c.systemBackend, plugin.Name, plugin.Typ, "v1.0.0", "") + + routeRequest := func(expectMatch bool) { + pluginPath := "foo/bar" + if plugin.Typ == consts.PluginTypeCredential { + pluginPath = "auth/foo/bar" + } + match := c.router.MatchingMount(namespace.RootContext(nil), pluginPath) + if expectMatch && match != strings.TrimSuffix(pluginPath, "bar") { + t.Fatalf("missing mount, match: %q", match) + } + if !expectMatch && match != "" { + t.Fatalf("expected no match for path, but got %q", match) + } + } + + routeRequest(true) + unmountPlugin(t, c.systemBackend, plugin.Typ, "foo") + routeRequest(false) +} + func TestExternalPluginInContainer_GetBackendTypeVersion(t *testing.T) { for name, tc := range map[string]struct { pluginType consts.PluginType @@ -94,41 +124,63 @@ func TestExternalPluginInContainer_GetBackendTypeVersion(t *testing.T) { } { t.Run(name, func(t *testing.T) { c, plugin := testClusterWithContainerPlugin(t, tc.pluginType, tc.setRunningVersion) - registerContainerPlugin(t, c.systemBackend, plugin.Name, tc.pluginType.String(), tc.setRunningVersion, plugin.ImageSha256, plugin.Image) + for _, ociRuntime := range []string{"runc", "runsc"} { + t.Run(ociRuntime, func(t *testing.T) { + if _, err := exec.LookPath(ociRuntime); err != nil { + t.Skipf("Skipping test as %s not found on path", ociRuntime) + } + shaBytes, _ := hex.DecodeString(plugin.ImageSha256) + entry := &pluginutil.PluginRunner{ + Name: plugin.Name, + OCIImage: plugin.Image, + Args: nil, + Sha256: shaBytes, + Builtin: false, + Runtime: ociRuntime, + RuntimeConfig: &pluginruntimeutil.PluginRuntimeConfig{ + OCIRuntime: ociRuntime, + }, + } - shaBytes, _ := hex.DecodeString(plugin.ImageSha256) - entry := &pluginutil.PluginRunner{ - Name: plugin.Name, - OCIImage: plugin.Image, - Args: nil, - Sha256: shaBytes, - Builtin: false, - } - - var version logical.PluginVersion - var err error - if tc.pluginType == consts.PluginTypeDatabase { - version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) - } else { - version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry) - } - if err != nil { - t.Fatal(err) - } - if version.Version != tc.setRunningVersion { - t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version) + var version logical.PluginVersion + var err error + if tc.pluginType == consts.PluginTypeDatabase { + version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) + } else { + version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry) + } + if err != nil { + t.Fatal(err) + } + if version.Version != tc.setRunningVersion { + t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version) + } + }) } }) } } -func registerContainerPlugin(t *testing.T, sys *SystemBackend, pluginName, pluginType, version, sha, image string) { +func registerContainerPlugin(t *testing.T, sys *SystemBackend, pluginName, pluginType, version, sha, image, runtime string) { t.Helper() req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", pluginType, pluginName)) req.Data = map[string]interface{}{ "oci_image": image, "sha256": sha, "version": version, + "runtime": runtime, + } + resp, err := sys.HandleRequest(namespace.RootContext(nil), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } +} + +func registerPluginRuntime(t *testing.T, sys *SystemBackend, name, ociRuntime string) { + t.Helper() + req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/runtimes/catalog/%s/%s", consts.PluginRuntimeTypeContainer, name)) + req.Data = map[string]interface{}{ + "oci_runtime": ociRuntime, } resp, err := sys.HandleRequest(namespace.RootContext(nil), req) if err != nil || (resp != nil && resp.IsError()) { diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index 35ce4a4399..96f71b30fb 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -422,7 +422,7 @@ func TestCore_EnableExternalPlugin_ShadowBuiltin(t *testing.T) { } // Remount auth method using registered shadow plugin - unmountPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential, "", "") + unmountPlugin(t, c.systemBackend, consts.PluginTypeCredential, "") mountPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential, "", "") // Verify auth table has changed @@ -439,7 +439,7 @@ func TestCore_EnableExternalPlugin_ShadowBuiltin(t *testing.T) { } // Remount auth method - unmountPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential, "", "") + unmountPlugin(t, c.systemBackend, consts.PluginTypeCredential, "") mountPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential, "", "") // Verify auth table has changed @@ -935,23 +935,15 @@ func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType } } -func unmountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType consts.PluginType, version, path string) { +func unmountPlugin(t *testing.T, sys *SystemBackend, pluginType consts.PluginType, path string) { t.Helper() var mountPath string if path == "" { mountPath = mountTable(pluginType) } else { - mountPath = mountTableWithPath(consts.PluginTypeSecrets, path) + mountPath = mountTableWithPath(pluginType, path) } req := logical.TestRequest(t, logical.DeleteOperation, mountPath) - req.Data = map[string]interface{}{ - "type": pluginName, - } - if version != "" { - req.Data["config"] = map[string]interface{}{ - "plugin_version": version, - } - } resp, err := sys.HandleRequest(namespace.RootContext(nil), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) diff --git a/vault/logical_system.go b/vault/logical_system.go index a7f8203e71..a2fea6c95d 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -550,11 +550,19 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica return nil, err } } - if ociImage != "" && runtime.GOOS != "linux" { - return logical.ErrorResponse("specifying oci_image is currently only supported on Linux"), nil - } pluginRuntime := d.Get("runtime").(string) + if ociImage != "" { + if runtime.GOOS != "linux" { + return logical.ErrorResponse("specifying oci_image is currently only supported on Linux"), nil + } + if pluginRuntime != "" { + _, err := b.Core.pluginRuntimeCatalog.Get(ctx, pluginRuntime, consts.PluginRuntimeTypeContainer) + if err != nil { + return logical.ErrorResponse("specified plugin runtime %q, but failed to retrieve config: %w", pluginRuntime, err), nil + } + } + } // For backwards compatibility, also accept args as part of command. Don't // accepts args in both command and args. diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index b6c9d499cc..34a0f4b2ba 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -61,6 +61,8 @@ type PluginCatalog struct { lock sync.RWMutex wrapper pluginutil.RunnerUtil + + runtimeCatalog *PluginRuntimeCatalog } // Only plugins running with identical PluginRunner config can be multiplexed, @@ -181,6 +183,7 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error { logger: c.logger, mlockPlugins: c.enableMlock, wrapper: logical.StaticSystemView{VersionString: version.GetVersion().Version}, + runtimeCatalog: c.pluginRuntimeCatalog, } // Run upgrade if untyped plugins exist @@ -814,39 +817,46 @@ func (c *PluginCatalog) Get(ctx context.Context, name string, pluginType consts. } func (c *PluginCatalog) get(ctx context.Context, name string, pluginType consts.PluginType, version string) (*pluginutil.PluginRunner, error) { - // If the directory isn't set only look for builtin plugins. - if c.directory != "" { - // Look for external plugins in the barrier - storageKey := path.Join(pluginType.String(), name) - if version != "" { - storageKey = path.Join(storageKey, version) - } - out, err := c.catalogView.Get(ctx, storageKey) + // Look for external plugins in the barrier + storageKey := path.Join(pluginType.String(), name) + if version != "" { + storageKey = path.Join(storageKey, version) + } + out, err := c.catalogView.Get(ctx, storageKey) + if err != nil { + return nil, fmt.Errorf("failed to retrieve plugin %q: %w", name, err) + } + if out == nil && version == "" { + // Also look for external plugins under what their name would have been if they + // were registered before plugin types existed. + out, err = c.catalogView.Get(ctx, name) if err != nil { return nil, fmt.Errorf("failed to retrieve plugin %q: %w", name, err) } - if out == nil && version == "" { - // Also look for external plugins under what their name would have been if they - // were registered before plugin types existed. - out, err = c.catalogView.Get(ctx, name) - if err != nil { - return nil, fmt.Errorf("failed to retrieve plugin %q: %w", name, err) - } + } + entry := new(pluginutil.PluginRunner) + if out != nil { + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %w", err) + } + if entry.Type != pluginType && entry.Type != consts.PluginTypeUnknown { + return nil, nil } - if out != nil { - entry := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %w", err) - } - if entry.Type != pluginType && entry.Type != consts.PluginTypeUnknown { - return nil, nil - } - // Make the command path fully rooted if it's not a container plugin. - if entry.OCIImage == "" { - entry.Command = filepath.Join(c.directory, entry.Command) + // If none of the cases are satisfied, we'll search for a builtin plugin below. + switch { + case entry.OCIImage != "": + if entry.Runtime != "" { + entry.RuntimeConfig, err = c.runtimeCatalog.Get(ctx, entry.Runtime, consts.PluginRuntimeTypeContainer) + if err != nil { + return nil, fmt.Errorf("failed to get configured runtime for plugin %q: %w", name, err) + } } - + return entry, nil + case c.directory != "": + // Only allow returning non-container external plugins if we have a plugin directory. + // Make the command path fully rooted. + entry.Command = filepath.Join(c.directory, entry.Command) return entry, nil } } @@ -879,7 +889,7 @@ func (c *PluginCatalog) get(ctx context.Context, name string, pluginType consts. // Set registers a new external plugin with the catalog, or updates an existing // external plugin. It takes the name, command and SHA256 of the plugin. func (c *PluginCatalog) Set(ctx context.Context, plugin pluginutil.SetPluginInput) error { - if c.directory == "" { + if c.directory == "" && plugin.OCIImage == "" { return ErrDirectoryNotConfigured } @@ -930,6 +940,13 @@ func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPl Sha256: plugin.Sha256, Builtin: false, } + if entryTmp.OCIImage != "" && entryTmp.Runtime != "" { + var err error + entryTmp.RuntimeConfig, err = c.runtimeCatalog.Get(ctx, entryTmp.Runtime, consts.PluginRuntimeTypeContainer) + if err != nil { + return nil, fmt.Errorf("failed to get configured runtime for plugin %q: %w", plugin.Name, err) + } + } // If the plugin type is unknown, we want to attempt to determine the type if plugin.Type == consts.PluginTypeUnknown { var err error diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index 2f3ce61143..5dd62886ed 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/vault/plugins/database/postgresql" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" "github.com/hashicorp/vault/sdk/helper/pluginutil" backendplugin "github.com/hashicorp/vault/sdk/plugin" @@ -72,7 +73,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { } // Set a plugin, test overwriting a builtin plugin - file, err := ioutil.TempFile(tempDir, "temp") + file, err := os.CreateTemp(tempDir, "temp") if err != nil { t.Fatal(err) } @@ -648,6 +649,151 @@ func TestPluginCatalog_MakeExternalPluginsKey_Comparable(t *testing.T) { } } +// TestPluginCatalog_ErrDirectoryNotConfigured ensures we correctly report an +// error when registering a binary plugin without a directory configured, and +// always allow registration of container plugins (rejecting on non-Linux happens +// in the logical system API handler). +func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + tempDir, err := filepath.EvalSymlinks(t.TempDir()) + if err != nil { + t.Fatal(err) + } + + catalog := core.pluginCatalog + tests := map[string]func(t *testing.T){ + "set binary plugin": func(t *testing.T) { + file, err := os.CreateTemp(tempDir, "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := filepath.Base(file.Name()) + // Should error if directory not set. + err = catalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "binary", + Type: consts.PluginTypeDatabase, + Command: command, + }) + dirSet := catalog.directory != "" + if dirSet { + if err != nil { + t.Fatal(err) + } + p, err := catalog.Get(context.Background(), "binary", consts.PluginTypeDatabase, "") + if err != nil { + t.Fatal(err) + } + expectedCommand := filepath.Join(tempDir, command) + if p.Command != expectedCommand { + t.Fatalf("Expected %s, got %s", expectedCommand, p.Command) + } + } + if !dirSet && err == nil { + t.Fatal("expected error without directory set") + } + // Make sure we can still get builtins too + _, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + }, + "set container plugin": func(t *testing.T) { + // Should never error. + const image = "does-not-exist" + err = catalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "container", + Type: consts.PluginTypeDatabase, + OCIImage: image, + }) + if err != nil { + t.Fatal(err) + } + // Check we can get it back ok. + p, err := catalog.Get(context.Background(), "container", consts.PluginTypeDatabase, "") + if err != nil { + t.Fatal(err) + } + if p.OCIImage != image { + t.Fatalf("Expected %s, got %s", image, p.OCIImage) + } + // Make sure we can still get builtins too + _, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + }, + } + + t.Run("directory not set", func(t *testing.T) { + for name, test := range tests { + t.Run(name, test) + } + }) + + core.pluginCatalog.directory = tempDir + + t.Run("directory set", func(t *testing.T) { + for name, test := range tests { + t.Run(name, test) + } + }) +} + +// TestRuntimeConfigPopulatedIfSpecified ensures plugins read from the catalog +// are returned with their container runtime config populated if it was +// specified. +func TestRuntimeConfigPopulatedIfSpecified(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + const image = "does-not-exist" + const runtime = "custom-runtime" + err := core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "container", + Type: consts.PluginTypeDatabase, + OCIImage: image, + Runtime: runtime, + }) + if err == nil { + t.Fatal("specified runtime doesn't exist yet, should have failed") + } + + const ociRuntime = "some-other-oci-runtime" + err = core.pluginRuntimeCatalog.Set(context.Background(), &pluginruntimeutil.PluginRuntimeConfig{ + Name: runtime, + Type: consts.PluginRuntimeTypeContainer, + OCIRuntime: ociRuntime, + }) + if err != nil { + t.Fatal(err) + } + + // Now setting the plugin with a runtime should succeed. + err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: "container", + Type: consts.PluginTypeDatabase, + OCIImage: image, + Runtime: runtime, + }) + if err != nil { + t.Fatal(err) + } + + p, err := core.pluginCatalog.Get(context.Background(), "container", consts.PluginTypeDatabase, "") + if err != nil { + t.Fatal(err) + } + if p.Runtime != runtime { + t.Errorf("expected %s, got %s", runtime, p.Runtime) + } + if p.RuntimeConfig == nil { + t.Fatal() + } + if p.RuntimeConfig.OCIRuntime != ociRuntime { + t.Errorf("expected %s, got %s", ociRuntime, p.RuntimeConfig.OCIRuntime) + } +} + func TestPluginCatalog_PluginMain_Userpass(t *testing.T) { if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { return