mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	Also fixes a function where we may call go-plugin's client.Client() without ever calling client.Kill(), which could leak plugin processes
		
			
				
	
	
		
			175 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			175 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) HashiCorp, Inc.
 | 
						|
// SPDX-License-Identifier: BUSL-1.1
 | 
						|
 | 
						|
// Package pluginhelpers contains testhelpers that don't depend on package
 | 
						|
// vault, and thus can be used within vault (as well as elsewhere.)
 | 
						|
package pluginhelpers
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/sha256"
 | 
						|
	"fmt"
 | 
						|
	"os"
 | 
						|
	"os/exec"
 | 
						|
	"path"
 | 
						|
	"path/filepath"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"github.com/hashicorp/vault/sdk/helper/consts"
 | 
						|
	"github.com/mitchellh/go-testing-interface"
 | 
						|
)
 | 
						|
 | 
						|
// testPluginCacheLock guards testPluginCache; CompilePlugin holds it for the
// whole duration of a compilation.
var testPluginCacheLock sync.Mutex

// testPluginCache maps a "<name> <type> <version>" key to the bytes of the
// compiled plugin binary, so that each plugin variant is only built once.
var testPluginCache = make(map[string][]byte)
 | 
						|
 | 
						|
type TestPlugin struct {
 | 
						|
	Name        string
 | 
						|
	Typ         consts.PluginType
 | 
						|
	Version     string
 | 
						|
	FileName    string
 | 
						|
	Sha256      string
 | 
						|
	Image       string
 | 
						|
	ImageSha256 string
 | 
						|
}
 | 
						|
 | 
						|
func GetPlugin(t testing.T, typ consts.PluginType) (string, string, string, string) {
 | 
						|
	t.Helper()
 | 
						|
	var pluginName string
 | 
						|
	var pluginType string
 | 
						|
	var pluginMain string
 | 
						|
	var pluginVersionLocation string
 | 
						|
 | 
						|
	switch typ {
 | 
						|
	case consts.PluginTypeCredential:
 | 
						|
		pluginType = "approle"
 | 
						|
		pluginName = "vault-plugin-auth-" + pluginType
 | 
						|
		pluginMain = filepath.Join("builtin", "credential", pluginType, "cmd", pluginType, "main.go")
 | 
						|
		pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/credential/%s.ReportedVersion", pluginType)
 | 
						|
	case consts.PluginTypeSecrets:
 | 
						|
		pluginType = "consul"
 | 
						|
		pluginName = "vault-plugin-secrets-" + pluginType
 | 
						|
		pluginMain = filepath.Join("builtin", "logical", pluginType, "cmd", pluginType, "main.go")
 | 
						|
		pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/logical/%s.ReportedVersion", pluginType)
 | 
						|
	case consts.PluginTypeDatabase:
 | 
						|
		pluginType = "postgresql"
 | 
						|
		pluginName = "vault-plugin-database-" + pluginType
 | 
						|
		pluginMain = filepath.Join("plugins", "database", pluginType, fmt.Sprintf("%s-database-plugin", pluginType), "main.go")
 | 
						|
		pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/plugins/database/%s.ReportedVersion", pluginType)
 | 
						|
	default:
 | 
						|
		t.Fatal(typ.String())
 | 
						|
	}
 | 
						|
	return pluginName, pluginType, pluginMain, pluginVersionLocation
 | 
						|
}
 | 
						|
 | 
						|
// to mount a plugin, we need a working binary plugin, so we compile one here.
 | 
						|
// pluginVersion is used to override the plugin's self-reported version
 | 
						|
func CompilePlugin(t testing.T, typ consts.PluginType, pluginVersion string, pluginDir string) TestPlugin {
 | 
						|
	t.Helper()
 | 
						|
 | 
						|
	pluginName, pluginType, pluginMain, pluginVersionLocation := GetPlugin(t, typ)
 | 
						|
 | 
						|
	testPluginCacheLock.Lock()
 | 
						|
	defer testPluginCacheLock.Unlock()
 | 
						|
 | 
						|
	var pluginBytes []byte
 | 
						|
 | 
						|
	dir := ""
 | 
						|
	var err error
 | 
						|
	pluginRootDir := "builtin"
 | 
						|
	if typ == consts.PluginTypeDatabase {
 | 
						|
		pluginRootDir = "plugins"
 | 
						|
	}
 | 
						|
	for {
 | 
						|
		dir, err = os.Getwd()
 | 
						|
		if err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
		// detect if we are in a subdirectory or the root directory and compensate
 | 
						|
		if _, err := os.Stat(pluginRootDir); os.IsNotExist(err) {
 | 
						|
			err := os.Chdir("..")
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	pluginPath := path.Join(pluginDir, pluginName)
 | 
						|
	if pluginVersion != "" {
 | 
						|
		pluginPath += "-" + pluginVersion
 | 
						|
	}
 | 
						|
 | 
						|
	key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion)
 | 
						|
	// cache the compilation to only run once
 | 
						|
	var ok bool
 | 
						|
	pluginBytes, ok = testPluginCache[key]
 | 
						|
	if !ok {
 | 
						|
		// we need to compile
 | 
						|
		line := []string{"build"}
 | 
						|
		if pluginVersion != "" {
 | 
						|
			line = append(line, "-ldflags", fmt.Sprintf("-X %s=%s", pluginVersionLocation, pluginVersion))
 | 
						|
		}
 | 
						|
		line = append(line, "-o", pluginPath, pluginMain)
 | 
						|
		cmd := exec.Command("go", line...)
 | 
						|
		cmd.Env = append(os.Environ(), "CGO_ENABLED=0")
 | 
						|
		cmd.Dir = dir
 | 
						|
		output, err := cmd.CombinedOutput()
 | 
						|
		if err != nil {
 | 
						|
			t.Fatal(fmt.Errorf("error running go build %v output: %s", err, output))
 | 
						|
		}
 | 
						|
		testPluginCache[key], err = os.ReadFile(pluginPath)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
		pluginBytes = testPluginCache[key]
 | 
						|
	}
 | 
						|
 | 
						|
	// write the cached plugin if necessary
 | 
						|
	if _, err := os.Stat(pluginPath); os.IsNotExist(err) {
 | 
						|
		err = os.WriteFile(pluginPath, pluginBytes, 0o755)
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	sha := sha256.New()
 | 
						|
	_, err = sha.Write(pluginBytes)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	return TestPlugin{
 | 
						|
		Name:     pluginName,
 | 
						|
		Typ:      typ,
 | 
						|
		Version:  pluginVersion,
 | 
						|
		FileName: path.Base(pluginPath),
 | 
						|
		Sha256:   fmt.Sprintf("%x", sha.Sum(nil)),
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func BuildPluginContainerImage(t testing.T, plugin TestPlugin, pluginDir string) (image string, sha256 string) {
 | 
						|
	t.Helper()
 | 
						|
	ref := plugin.Name
 | 
						|
	if plugin.Version != "" {
 | 
						|
		ref += ":" + strings.TrimPrefix(plugin.Version, "v")
 | 
						|
	} else {
 | 
						|
		ref += ":latest"
 | 
						|
	}
 | 
						|
	args := []string{"build", "--tag=" + ref, "--build-arg=plugin=" + plugin.FileName, "--file=vault/testdata/Dockerfile", pluginDir}
 | 
						|
	cmd := exec.Command("docker", args...)
 | 
						|
	output, err := cmd.CombinedOutput()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(fmt.Errorf("error running docker build %v output: %s", err, output))
 | 
						|
	}
 | 
						|
 | 
						|
	cmd = exec.Command("docker", "images", ref, "--format={{ .ID }}", "--no-trunc")
 | 
						|
	id, err := cmd.CombinedOutput()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(fmt.Errorf("error running docker build %v output: %s", err, output))
 | 
						|
	}
 | 
						|
 | 
						|
	return plugin.Name, strings.TrimSpace(strings.TrimPrefix(string(id), "sha256:"))
 | 
						|
}
 |