mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			317 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			317 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package mysql
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"database/sql"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"os"
 | 
						|
	paths "path"
 | 
						|
	"path/filepath"
 | 
						|
	"reflect"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
 | 
						|
	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
 | 
						|
	dockertest "github.com/ory/dockertest/v3"
 | 
						|
)
 | 
						|
 | 
						|
func Test_addTLStoDSN(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		rootUrl        string
 | 
						|
		tlsConfigName  string
 | 
						|
		expectedResult string
 | 
						|
	}
 | 
						|
 | 
						|
	tests := map[string]testCase{
 | 
						|
		"no tls, no query string": {
 | 
						|
			rootUrl:        "user:password@tcp(localhost:3306)/test",
 | 
						|
			tlsConfigName:  "",
 | 
						|
			expectedResult: "user:password@tcp(localhost:3306)/test",
 | 
						|
		},
 | 
						|
		"tls, no query string": {
 | 
						|
			rootUrl:        "user:password@tcp(localhost:3306)/test",
 | 
						|
			tlsConfigName:  "tlsTest101",
 | 
						|
			expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101",
 | 
						|
		},
 | 
						|
		"tls, query string": {
 | 
						|
			rootUrl:        "user:password@tcp(localhost:3306)/test?foo=bar",
 | 
						|
			tlsConfigName:  "tlsTest101",
 | 
						|
			expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
 | 
						|
		},
 | 
						|
		"tls, query string, ? in password": {
 | 
						|
			rootUrl:        "user:pa?ssword?@tcp(localhost:3306)/test?foo=bar",
 | 
						|
			tlsConfigName:  "tlsTest101",
 | 
						|
			expectedResult: "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
 | 
						|
		},
 | 
						|
		"tls, valid tls parameter in query string": {
 | 
						|
			rootUrl:        "user:password@tcp(localhost:3306)/test?tls=true",
 | 
						|
			tlsConfigName:  "",
 | 
						|
			expectedResult: "user:password@tcp(localhost:3306)/test?tls=true",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			tCase := mySQLConnectionProducer{
 | 
						|
				ConnectionURL: test.rootUrl,
 | 
						|
				tlsConfigName: test.tlsConfigName,
 | 
						|
			}
 | 
						|
 | 
						|
			actual, err := tCase.addTLStoDSN()
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("error occurred in test: %s", err)
 | 
						|
			}
 | 
						|
			if actual != test.expectedResult {
 | 
						|
				t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestInit_clientTLS(t *testing.T) {
 | 
						|
	t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " +
 | 
						|
		"https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-")
 | 
						|
 | 
						|
	// Set up temp directory so we can mount it to the docker container
 | 
						|
	confDir := makeTempDir(t)
 | 
						|
	defer os.RemoveAll(confDir)
 | 
						|
 | 
						|
	// Create certificates for MySQL authentication
 | 
						|
	caCert := certhelpers.NewCert(t,
 | 
						|
		certhelpers.CommonName("test certificate authority"),
 | 
						|
		certhelpers.IsCA(true),
 | 
						|
		certhelpers.SelfSign(),
 | 
						|
	)
 | 
						|
	serverCert := certhelpers.NewCert(t,
 | 
						|
		certhelpers.CommonName("server"),
 | 
						|
		certhelpers.DNS("localhost"),
 | 
						|
		certhelpers.Parent(caCert),
 | 
						|
	)
 | 
						|
	clientCert := certhelpers.NewCert(t,
 | 
						|
		certhelpers.CommonName("client"),
 | 
						|
		certhelpers.DNS("client"),
 | 
						|
		certhelpers.Parent(caCert),
 | 
						|
	)
 | 
						|
 | 
						|
	writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0o644)
 | 
						|
	writeFile(t, paths.Join(confDir, "server-cert.pem"), serverCert.Pem, 0o644)
 | 
						|
	writeFile(t, paths.Join(confDir, "server-key.pem"), serverCert.PrivateKeyPEM(), 0o644)
 | 
						|
	writeFile(t, paths.Join(confDir, "client.pem"), clientCert.CombinedPEM(), 0o644)
 | 
						|
 | 
						|
	// //////////////////////////////////////////////////////
 | 
						|
	// Set up MySQL config file
 | 
						|
	rawConf := `
 | 
						|
[mysqld]
 | 
						|
ssl
 | 
						|
ssl-ca=/etc/mysql/ca.pem
 | 
						|
ssl-cert=/etc/mysql/server-cert.pem
 | 
						|
ssl-key=/etc/mysql/server-key.pem`
 | 
						|
 | 
						|
	writeFile(t, paths.Join(confDir, "my.cnf"), []byte(rawConf), 0o644)
 | 
						|
 | 
						|
	// //////////////////////////////////////////////////////
 | 
						|
	// Start MySQL container
 | 
						|
	retURL, cleanup := startMySQLWithTLS(t, "5.7", confDir)
 | 
						|
	defer cleanup()
 | 
						|
 | 
						|
	// //////////////////////////////////////////////////////
 | 
						|
	// Set up x509 user
 | 
						|
	mClient := connect(t, retURL)
 | 
						|
 | 
						|
	username := setUpX509User(t, mClient, clientCert)
 | 
						|
 | 
						|
	// //////////////////////////////////////////////////////
 | 
						|
	// Test
 | 
						|
	mysql := newMySQL(DefaultUserNameTemplate)
 | 
						|
 | 
						|
	conf := map[string]interface{}{
 | 
						|
		"connection_url":      retURL,
 | 
						|
		"username":            username,
 | 
						|
		"tls_certificate_key": clientCert.CombinedPEM(),
 | 
						|
		"tls_ca":              caCert.Pem,
 | 
						|
	}
 | 
						|
 | 
						|
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	_, err := mysql.Init(ctx, conf, true)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to initialize mysql engine: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Initialization complete. The connection was established, but we need to ensure
 | 
						|
	// that we're connected as the right user
 | 
						|
	whoamiCmd := "SELECT CURRENT_USER()"
 | 
						|
 | 
						|
	client, err := mysql.getConnection(ctx)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to make connection to MySQL: %s", err)
 | 
						|
	}
 | 
						|
	stmt, err := client.Prepare(whoamiCmd)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to prepare MySQL statementL %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	results := stmt.QueryRow()
 | 
						|
 | 
						|
	expected := fmt.Sprintf("%s@%%", username)
 | 
						|
 | 
						|
	var result string
 | 
						|
	if err := results.Scan(&result); err != nil {
 | 
						|
		t.Fatalf("result could not be scanned from result set: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if !reflect.DeepEqual(result, expected) {
 | 
						|
		t.Fatalf("Actual:%#v\nExpected:\n%#v", result, expected)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// makeTempDir creates a new directory named "mysql-test-data*" under the
// current working directory and returns its absolute path. The path is made
// absolute because Docker needs absolute paths when mounting. The caller is
// responsible for removing the directory. Any failure aborts the test.
func makeTempDir(t *testing.T) (confDir string) {
	t.Helper()

	dir, err := ioutil.TempDir(".", "mysql-test-data")
	if err != nil {
		t.Fatalf("Unable to make temp directory: %s", err)
	}

	// Convert the directory to an absolute path because docker needs it when mounting
	confDir, err = filepath.Abs(filepath.Clean(dir))
	if err != nil {
		// We can't hand the directory back, so the caller could never clean it
		// up; remove it here before failing (best effort).
		_ = os.RemoveAll(dir)
		t.Fatalf("Unable to determine where temp directory is on absolute path: %s", err)
	}
	return confDir
}
 | 
						|
 | 
						|
func startMySQLWithTLS(t *testing.T, version string, confDir string) (retURL string, cleanup func()) {
 | 
						|
	if os.Getenv("MYSQL_URL") != "" {
 | 
						|
		return os.Getenv("MYSQL_URL"), func() {}
 | 
						|
	}
 | 
						|
 | 
						|
	pool, err := dockertest.NewPool("")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Failed to connect to docker: %s", err)
 | 
						|
	}
 | 
						|
	pool.MaxWait = 30 * time.Second
 | 
						|
 | 
						|
	containerName := "mysql-unit-test"
 | 
						|
 | 
						|
	// Remove previously running container if it is still running because cleanup failed
 | 
						|
	err = pool.RemoveContainerByName(containerName)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to remove old running containers: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	username := "root"
 | 
						|
	password := "x509test"
 | 
						|
 | 
						|
	runOpts := &dockertest.RunOptions{
 | 
						|
		Name:       containerName,
 | 
						|
		Repository: "mysql",
 | 
						|
		Tag:        version,
 | 
						|
		Cmd:        []string{"--defaults-extra-file=/etc/mysql/my.cnf", "--auto-generate-certs=OFF"},
 | 
						|
		Env:        []string{fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", password)},
 | 
						|
		// Mount the directory from local filesystem into the container
 | 
						|
		Mounts: []string{
 | 
						|
			fmt.Sprintf("%s:/etc/mysql", confDir),
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	resource, err := pool.RunWithOptions(runOpts)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Could not start local mysql docker container: %s", err)
 | 
						|
	}
 | 
						|
	resource.Expire(30)
 | 
						|
 | 
						|
	cleanup = func() {
 | 
						|
		err := pool.Purge(resource)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("Failed to cleanup local container: %s", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	dsn := fmt.Sprintf("{{username}}:{{password}}@tcp(localhost:%s)/mysql", resource.GetPort("3306/tcp"))
 | 
						|
 | 
						|
	url := dbutil.QueryHelper(dsn, map[string]string{
 | 
						|
		"username": username,
 | 
						|
		"password": password,
 | 
						|
	})
 | 
						|
	// exponential backoff-retry
 | 
						|
	err = pool.Retry(func() error {
 | 
						|
		var err error
 | 
						|
 | 
						|
		db, err := sql.Open("mysql", url)
 | 
						|
		if err != nil {
 | 
						|
			t.Logf("err: %s", err)
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		defer db.Close()
 | 
						|
		return db.Ping()
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		cleanup()
 | 
						|
		t.Fatalf("Could not connect to mysql docker container: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return dsn, cleanup
 | 
						|
}
 | 
						|
 | 
						|
func connect(t *testing.T, dsn string) (db *sql.DB) {
 | 
						|
	url := dbutil.QueryHelper(dsn, map[string]string{
 | 
						|
		"username": "root",
 | 
						|
		"password": "x509test",
 | 
						|
	})
 | 
						|
 | 
						|
	db, err := sql.Open("mysql", url)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Unable to make connection to MySQL: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	err = db.Ping()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("Failed to ping MySQL server: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return db
 | 
						|
}
 | 
						|
 | 
						|
func setUpX509User(t *testing.T, db *sql.DB, cert certhelpers.Certificate) (username string) {
 | 
						|
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	username = cert.Template.Subject.CommonName
 | 
						|
 | 
						|
	cmds := []string{
 | 
						|
		fmt.Sprintf("CREATE USER %s IDENTIFIED BY '' REQUIRE X509", username),
 | 
						|
		fmt.Sprintf("GRANT ALL ON mysql.* TO '%s'@'%s' REQUIRE X509", username, "%"),
 | 
						|
	}
 | 
						|
 | 
						|
	for _, cmd := range cmds {
 | 
						|
		stmt, err := db.PrepareContext(ctx, cmd)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("Failed to prepare query: %s", err)
 | 
						|
		}
 | 
						|
 | 
						|
		_, err = stmt.ExecContext(ctx)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("Failed to create x509 user in database: %s", err)
 | 
						|
		}
 | 
						|
		err = stmt.Close()
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("Failed to close prepared statement: %s", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return username
 | 
						|
}
 | 
						|
 | 
						|
// ////////////////////////////////////////////////////////////////////////////
 | 
						|
// Writing to file
 | 
						|
// ////////////////////////////////////////////////////////////////////////////
 | 
						|
// writeFile stores data at filename with the given permission bits. If the
// write fails, the test is aborted and the failure is reported at the caller's
// line.
func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) {
	t.Helper()

	if writeErr := ioutil.WriteFile(filename, data, perms); writeErr != nil {
		t.Fatalf("Unable to write to file [%s]: %s", filename, writeErr)
	}
}
 |