mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			317 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			317 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package mysql
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"os"
 | |
| 	paths "path"
 | |
| 	"path/filepath"
 | |
| 	"reflect"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
 | |
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
 | |
| 	dockertest "github.com/ory/dockertest/v3"
 | |
| )
 | |
| 
 | |
| func Test_addTLStoDSN(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		rootUrl        string
 | |
| 		tlsConfigName  string
 | |
| 		expectedResult string
 | |
| 	}
 | |
| 
 | |
| 	tests := map[string]testCase{
 | |
| 		"no tls, no query string": {
 | |
| 			rootUrl:        "user:password@tcp(localhost:3306)/test",
 | |
| 			tlsConfigName:  "",
 | |
| 			expectedResult: "user:password@tcp(localhost:3306)/test",
 | |
| 		},
 | |
| 		"tls, no query string": {
 | |
| 			rootUrl:        "user:password@tcp(localhost:3306)/test",
 | |
| 			tlsConfigName:  "tlsTest101",
 | |
| 			expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101",
 | |
| 		},
 | |
| 		"tls, query string": {
 | |
| 			rootUrl:        "user:password@tcp(localhost:3306)/test?foo=bar",
 | |
| 			tlsConfigName:  "tlsTest101",
 | |
| 			expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
 | |
| 		},
 | |
| 		"tls, query string, ? in password": {
 | |
| 			rootUrl:        "user:pa?ssword?@tcp(localhost:3306)/test?foo=bar",
 | |
| 			tlsConfigName:  "tlsTest101",
 | |
| 			expectedResult: "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
 | |
| 		},
 | |
| 		"tls, valid tls parameter in query string": {
 | |
| 			rootUrl:        "user:password@tcp(localhost:3306)/test?tls=true",
 | |
| 			tlsConfigName:  "",
 | |
| 			expectedResult: "user:password@tcp(localhost:3306)/test?tls=true",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for name, test := range tests {
 | |
| 		t.Run(name, func(t *testing.T) {
 | |
| 			tCase := mySQLConnectionProducer{
 | |
| 				ConnectionURL: test.rootUrl,
 | |
| 				tlsConfigName: test.tlsConfigName,
 | |
| 			}
 | |
| 
 | |
| 			actual, err := tCase.addTLStoDSN()
 | |
| 			if err != nil {
 | |
| 				t.Fatalf("error occurred in test: %s", err)
 | |
| 			}
 | |
| 			if actual != test.expectedResult {
 | |
| 				t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestInit_clientTLS(t *testing.T) {
 | |
| 	t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " +
 | |
| 		"https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-")
 | |
| 
 | |
| 	// Set up temp directory so we can mount it to the docker container
 | |
| 	confDir := makeTempDir(t)
 | |
| 	defer os.RemoveAll(confDir)
 | |
| 
 | |
| 	// Create certificates for MySQL authentication
 | |
| 	caCert := certhelpers.NewCert(t,
 | |
| 		certhelpers.CommonName("test certificate authority"),
 | |
| 		certhelpers.IsCA(true),
 | |
| 		certhelpers.SelfSign(),
 | |
| 	)
 | |
| 	serverCert := certhelpers.NewCert(t,
 | |
| 		certhelpers.CommonName("server"),
 | |
| 		certhelpers.DNS("localhost"),
 | |
| 		certhelpers.Parent(caCert),
 | |
| 	)
 | |
| 	clientCert := certhelpers.NewCert(t,
 | |
| 		certhelpers.CommonName("client"),
 | |
| 		certhelpers.DNS("client"),
 | |
| 		certhelpers.Parent(caCert),
 | |
| 	)
 | |
| 
 | |
| 	writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0o644)
 | |
| 	writeFile(t, paths.Join(confDir, "server-cert.pem"), serverCert.Pem, 0o644)
 | |
| 	writeFile(t, paths.Join(confDir, "server-key.pem"), serverCert.PrivateKeyPEM(), 0o644)
 | |
| 	writeFile(t, paths.Join(confDir, "client.pem"), clientCert.CombinedPEM(), 0o644)
 | |
| 
 | |
| 	// //////////////////////////////////////////////////////
 | |
| 	// Set up MySQL config file
 | |
| 	rawConf := `
 | |
| [mysqld]
 | |
| ssl
 | |
| ssl-ca=/etc/mysql/ca.pem
 | |
| ssl-cert=/etc/mysql/server-cert.pem
 | |
| ssl-key=/etc/mysql/server-key.pem`
 | |
| 
 | |
| 	writeFile(t, paths.Join(confDir, "my.cnf"), []byte(rawConf), 0o644)
 | |
| 
 | |
| 	// //////////////////////////////////////////////////////
 | |
| 	// Start MySQL container
 | |
| 	retURL, cleanup := startMySQLWithTLS(t, "5.7", confDir)
 | |
| 	defer cleanup()
 | |
| 
 | |
| 	// //////////////////////////////////////////////////////
 | |
| 	// Set up x509 user
 | |
| 	mClient := connect(t, retURL)
 | |
| 
 | |
| 	username := setUpX509User(t, mClient, clientCert)
 | |
| 
 | |
| 	// //////////////////////////////////////////////////////
 | |
| 	// Test
 | |
| 	mysql := newMySQL(DefaultUserNameTemplate)
 | |
| 
 | |
| 	conf := map[string]interface{}{
 | |
| 		"connection_url":      retURL,
 | |
| 		"username":            username,
 | |
| 		"tls_certificate_key": clientCert.CombinedPEM(),
 | |
| 		"tls_ca":              caCert.Pem,
 | |
| 	}
 | |
| 
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | |
| 	defer cancel()
 | |
| 
 | |
| 	_, err := mysql.Init(ctx, conf, true)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to initialize mysql engine: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// Initialization complete. The connection was established, but we need to ensure
 | |
| 	// that we're connected as the right user
 | |
| 	whoamiCmd := "SELECT CURRENT_USER()"
 | |
| 
 | |
| 	client, err := mysql.getConnection(ctx)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to make connection to MySQL: %s", err)
 | |
| 	}
 | |
| 	stmt, err := client.Prepare(whoamiCmd)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to prepare MySQL statementL %s", err)
 | |
| 	}
 | |
| 
 | |
| 	results := stmt.QueryRow()
 | |
| 
 | |
| 	expected := fmt.Sprintf("%s@%%", username)
 | |
| 
 | |
| 	var result string
 | |
| 	if err := results.Scan(&result); err != nil {
 | |
| 		t.Fatalf("result could not be scanned from result set: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if !reflect.DeepEqual(result, expected) {
 | |
| 		t.Fatalf("Actual:%#v\nExpected:\n%#v", result, expected)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func makeTempDir(t *testing.T) (confDir string) {
 | |
| 	confDir, err := ioutil.TempDir(".", "mysql-test-data")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to make temp directory: %s", err)
 | |
| 	}
 | |
| 	// Convert the directory to an absolute path because docker needs it when mounting
 | |
| 	confDir, err = filepath.Abs(filepath.Clean(confDir))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to determine where temp directory is on absolute path: %s", err)
 | |
| 	}
 | |
| 	return confDir
 | |
| }
 | |
| 
 | |
| func startMySQLWithTLS(t *testing.T, version string, confDir string) (retURL string, cleanup func()) {
 | |
| 	if os.Getenv("MYSQL_URL") != "" {
 | |
| 		return os.Getenv("MYSQL_URL"), func() {}
 | |
| 	}
 | |
| 
 | |
| 	pool, err := dockertest.NewPool("")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to connect to docker: %s", err)
 | |
| 	}
 | |
| 	pool.MaxWait = 30 * time.Second
 | |
| 
 | |
| 	containerName := "mysql-unit-test"
 | |
| 
 | |
| 	// Remove previously running container if it is still running because cleanup failed
 | |
| 	err = pool.RemoveContainerByName(containerName)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to remove old running containers: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	username := "root"
 | |
| 	password := "x509test"
 | |
| 
 | |
| 	runOpts := &dockertest.RunOptions{
 | |
| 		Name:       containerName,
 | |
| 		Repository: "mysql",
 | |
| 		Tag:        version,
 | |
| 		Cmd:        []string{"--defaults-extra-file=/etc/mysql/my.cnf", "--auto-generate-certs=OFF"},
 | |
| 		Env:        []string{fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", password)},
 | |
| 		// Mount the directory from local filesystem into the container
 | |
| 		Mounts: []string{
 | |
| 			fmt.Sprintf("%s:/etc/mysql", confDir),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	resource, err := pool.RunWithOptions(runOpts)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Could not start local mysql docker container: %s", err)
 | |
| 	}
 | |
| 	resource.Expire(30)
 | |
| 
 | |
| 	cleanup = func() {
 | |
| 		err := pool.Purge(resource)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("Failed to cleanup local container: %s", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	dsn := fmt.Sprintf("{{username}}:{{password}}@tcp(localhost:%s)/mysql", resource.GetPort("3306/tcp"))
 | |
| 
 | |
| 	url := dbutil.QueryHelper(dsn, map[string]string{
 | |
| 		"username": username,
 | |
| 		"password": password,
 | |
| 	})
 | |
| 	// exponential backoff-retry
 | |
| 	err = pool.Retry(func() error {
 | |
| 		var err error
 | |
| 
 | |
| 		db, err := sql.Open("mysql", url)
 | |
| 		if err != nil {
 | |
| 			t.Logf("err: %s", err)
 | |
| 			return err
 | |
| 		}
 | |
| 		defer db.Close()
 | |
| 		return db.Ping()
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		cleanup()
 | |
| 		t.Fatalf("Could not connect to mysql docker container: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	return dsn, cleanup
 | |
| }
 | |
| 
 | |
| func connect(t *testing.T, dsn string) (db *sql.DB) {
 | |
| 	url := dbutil.QueryHelper(dsn, map[string]string{
 | |
| 		"username": "root",
 | |
| 		"password": "x509test",
 | |
| 	})
 | |
| 
 | |
| 	db, err := sql.Open("mysql", url)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to make connection to MySQL: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	err = db.Ping()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to ping MySQL server: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	return db
 | |
| }
 | |
| 
 | |
| func setUpX509User(t *testing.T, db *sql.DB, cert certhelpers.Certificate) (username string) {
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 
 | |
| 	username = cert.Template.Subject.CommonName
 | |
| 
 | |
| 	cmds := []string{
 | |
| 		fmt.Sprintf("CREATE USER %s IDENTIFIED BY '' REQUIRE X509", username),
 | |
| 		fmt.Sprintf("GRANT ALL ON mysql.* TO '%s'@'%s' REQUIRE X509", username, "%"),
 | |
| 	}
 | |
| 
 | |
| 	for _, cmd := range cmds {
 | |
| 		stmt, err := db.PrepareContext(ctx, cmd)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("Failed to prepare query: %s", err)
 | |
| 		}
 | |
| 
 | |
| 		_, err = stmt.ExecContext(ctx)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("Failed to create x509 user in database: %s", err)
 | |
| 		}
 | |
| 		err = stmt.Close()
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("Failed to close prepared statement: %s", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return username
 | |
| }
 | |
| 
 | |
| // ////////////////////////////////////////////////////////////////////////////
 | |
| // Writing to file
 | |
| // ////////////////////////////////////////////////////////////////////////////
 | |
| func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	err := ioutil.WriteFile(filename, data, perms)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Unable to write to file [%s]: %s", filename, err)
 | |
| 	}
 | |
| }
 | 
