Allow mTLS for mysql secrets engine (#9181)

* Extract certificate helpers for use in non-mongodb packages
* Created mTLS/X509 test for MySQL secrets engine.
* Ensure mysql username and passwords aren't url encoded
* Skip mTLS test for circleCI
This commit is contained in:
Lauren Voswinkel
2020-06-17 11:46:01 -07:00
committed by GitHub
parent cf8eaacd4e
commit 601d0eb6ea
7 changed files with 826 additions and 46 deletions

View File

@@ -0,0 +1,226 @@
package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"net/url"
"sync"
"time"
"github.com/go-sql-driver/mysql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/mitchellh/mapstructure"
)
// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type mySQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
tlsConfigName string
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
Initialized bool
db *sql.DB
sync.Mutex
}
func (c *mySQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock()
defer c.Unlock()
c.RawConfig = conf
err := mapstructure.WeakDecode(conf, &c)
if err != nil {
return nil, err
}
if len(c.ConnectionURL) == 0 {
return nil, fmt.Errorf("connection_url cannot be empty")
}
// Don't escape special characters for MySQL password
password := c.Password
// QueryHelper doesn't do any SQL escaping, but if it starts to do so
// then maybe we won't be able to use it to do URL substitution any more.
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": url.PathEscape(c.Username),
"password": password,
})
if c.MaxOpenConnections == 0 {
c.MaxOpenConnections = 4
}
if c.MaxIdleConnections == 0 {
c.MaxIdleConnections = c.MaxOpenConnections
}
if c.MaxIdleConnections > c.MaxOpenConnections {
c.MaxIdleConnections = c.MaxOpenConnections
}
if c.MaxConnectionLifetimeRaw == nil {
c.MaxConnectionLifetimeRaw = "0s"
}
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
if err != nil {
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
}
tlsConfig, err := c.getTLSAuth()
if err != nil {
return nil, err
}
if tlsConfig != nil {
if c.tlsConfigName == "" {
c.tlsConfigName, err = uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err)
}
}
mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig)
}
// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(ctx); err != nil {
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
if err := c.db.PingContext(ctx); err != nil {
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
}
return c.RawConfig, nil
}
func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}
// If we already have a DB, test it and return
if c.db != nil {
if err := c.db.PingContext(ctx); err == nil {
return c.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
c.db.Close()
}
connURL, err := c.addTLStoDSN()
if err != nil {
return nil, err
}
c.db, err = sql.Open("mysql", connURL)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
c.db.SetMaxOpenConns(c.MaxOpenConnections)
c.db.SetMaxIdleConns(c.MaxIdleConnections)
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
return c.db, nil
}
func (c *mySQLConnectionProducer) SecretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
}
}
// Close attempts to close the connection
func (c *mySQLConnectionProducer) Close() error {
// Grab the write lock
c.Lock()
defer c.Unlock()
if c.db != nil {
c.db.Close()
}
c.db = nil
return nil
}
func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) {
if len(c.TLSCAData) == 0 &&
len(c.TLSCertificateKeyData) == 0 {
return nil, nil
}
rootCertPool := x509.NewCertPool()
if len(c.TLSCAData) > 0 {
ok := rootCertPool.AppendCertsFromPEM(c.TLSCAData)
if !ok {
return nil, fmt.Errorf("failed to append CA to client options")
}
}
clientCert := make([]tls.Certificate, 0, 1)
if len(c.TLSCertificateKeyData) > 0 {
certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData)
if err != nil {
return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err)
}
clientCert = append(clientCert, certificate)
}
tlsConfig = &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
}
return tlsConfig, nil
}
func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) {
config, err := mysql.ParseDSN(c.ConnectionURL)
if err != nil {
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
}
config.TLSConfig = c.tlsConfigName
connURL = config.FormatDSN()
return connURL, nil
}

View File

@@ -0,0 +1,311 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"os"
paths "path"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/ory/dockertest"
)
func Test_addTLStoDSN(t *testing.T) {
type testCase struct {
rootUrl string
tlsConfigName string
expectedResult string
}
tests := map[string]testCase{
"no tls, no query string": {
rootUrl: "user:password@tcp(localhost:3306)/test",
tlsConfigName: "",
expectedResult: "user:password@tcp(localhost:3306)/test",
},
"tls, no query string": {
rootUrl: "user:password@tcp(localhost:3306)/test",
tlsConfigName: "tlsTest101",
expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101",
},
"tls, query string": {
rootUrl: "user:password@tcp(localhost:3306)/test?foo=bar",
tlsConfigName: "tlsTest101",
expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
},
"tls, query string, ? in password": {
rootUrl: "user:pa?ssword?@tcp(localhost:3306)/test?foo=bar",
tlsConfigName: "tlsTest101",
expectedResult: "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
tCase := mySQLConnectionProducer{
ConnectionURL: test.rootUrl,
tlsConfigName: test.tlsConfigName,
}
actual, err := tCase.addTLStoDSN()
if err != nil {
t.Fatalf("error occurred in test: %s", err)
}
if actual != test.expectedResult {
t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult)
}
})
}
}
func TestInit_clientTLS(t *testing.T) {
t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " +
"https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-")
// Set up temp directory so we can mount it to the docker container
confDir := makeTempDir(t)
defer os.RemoveAll(confDir)
// Create certificates for MySQL authentication
caCert := certhelpers.NewCert(t,
certhelpers.CommonName("test certificate authority"),
certhelpers.IsCA(true),
certhelpers.SelfSign(),
)
serverCert := certhelpers.NewCert(t,
certhelpers.CommonName("server"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)
clientCert := certhelpers.NewCert(t,
certhelpers.CommonName("client"),
certhelpers.DNS("client"),
certhelpers.Parent(caCert),
)
writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0644)
writeFile(t, paths.Join(confDir, "server-cert.pem"), serverCert.Pem, 0644)
writeFile(t, paths.Join(confDir, "server-key.pem"), serverCert.PrivateKeyPEM(), 0644)
writeFile(t, paths.Join(confDir, "client.pem"), clientCert.CombinedPEM(), 0644)
// //////////////////////////////////////////////////////
// Set up MySQL config file
rawConf := `
[mysqld]
ssl
ssl-ca=/etc/mysql/ca.pem
ssl-cert=/etc/mysql/server-cert.pem
ssl-key=/etc/mysql/server-key.pem`
writeFile(t, paths.Join(confDir, "my.cnf"), []byte(rawConf), 0644)
// //////////////////////////////////////////////////////
// Start MySQL container
retURL, cleanup := startMySQLWithTLS(t, "5.7", confDir)
defer cleanup()
// //////////////////////////////////////////////////////
// Set up x509 user
mClient := connect(t, retURL)
username := setUpX509User(t, mClient, clientCert)
// //////////////////////////////////////////////////////
// Test
mysql := new(25, 25, 25)
conf := map[string]interface{}{
"connection_url": retURL,
"username": username,
"tls_certificate_key": clientCert.CombinedPEM(),
"tls_ca": caCert.Pem,
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := mysql.Init(ctx, conf, true)
if err != nil {
t.Fatalf("Unable to initialize mysql engine: %s", err)
}
// Initialization complete. The connection was established, but we need to ensure
// that we're connected as the right user
whoamiCmd := "SELECT CURRENT_USER()"
client, err := mysql.getConnection(ctx)
if err != nil {
t.Fatalf("Unable to make connection to MySQL: %s", err)
}
stmt, err := client.Prepare(whoamiCmd)
if err != nil {
t.Fatalf("Unable to prepare MySQL statementL %s", err)
}
results := stmt.QueryRow()
expected := fmt.Sprintf("%s@%%", username)
var result string
if err := results.Scan(&result); err != nil {
t.Fatalf("result could not be scanned from result set: %s", err)
}
if !reflect.DeepEqual(result, expected) {
t.Fatalf("Actual:%#v\nExpected:\n%#v", result, expected)
}
}
func makeTempDir(t *testing.T) (confDir string) {
confDir, err := ioutil.TempDir(".", "mysql-test-data")
if err != nil {
t.Fatalf("Unable to make temp directory: %s", err)
}
// Convert the directory to an absolute path because docker needs it when mounting
confDir, err = filepath.Abs(filepath.Clean(confDir))
if err != nil {
t.Fatalf("Unable to determine where temp directory is on absolute path: %s", err)
}
return confDir
}
func startMySQLWithTLS(t *testing.T, version string, confDir string) (retURL string, cleanup func()) {
if os.Getenv("MYSQL_URL") != "" {
return os.Getenv("MYSQL_URL"), func() {}
}
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
pool.MaxWait = 30 * time.Second
containerName := "mysql-unit-test"
// Remove previously running container if it is still running because cleanup failed
err = pool.RemoveContainerByName(containerName)
if err != nil {
t.Fatalf("Unable to remove old running containers: %s", err)
}
username := "root"
password := "x509test"
runOpts := &dockertest.RunOptions{
Name: containerName,
Repository: "mysql",
Tag: version,
Cmd: []string{"--defaults-extra-file=/etc/mysql/my.cnf", "--auto-generate-certs=OFF"},
Env: []string{fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", password)},
// Mount the directory from local filesystem into the container
Mounts: []string{
fmt.Sprintf("%s:/etc/mysql", confDir),
},
}
resource, err := pool.RunWithOptions(runOpts)
if err != nil {
t.Fatalf("Could not start local mysql docker container: %s", err)
}
resource.Expire(30)
cleanup = func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local container: %s", err)
}
}
dsn := fmt.Sprintf("{{username}}:{{password}}@tcp(localhost:%s)/mysql", resource.GetPort("3306/tcp"))
url := dbutil.QueryHelper(dsn, map[string]string{
"username": username,
"password": password,
})
// exponential backoff-retry
err = pool.Retry(func() error {
var err error
db, err := sql.Open("mysql", url)
if err != nil {
t.Logf("err: %s", err)
return err
}
defer db.Close()
return db.Ping()
})
if err != nil {
cleanup()
t.Fatalf("Could not connect to mysql docker container: %s", err)
}
return dsn, cleanup
}
func connect(t *testing.T, dsn string) (db *sql.DB) {
url := dbutil.QueryHelper(dsn, map[string]string{
"username": "root",
"password": "x509test",
})
db, err := sql.Open("mysql", url)
if err != nil {
t.Fatalf("Unable to make connection to MySQL: %s", err)
}
err = db.Ping()
if err != nil {
t.Fatalf("Failed to ping MySQL server: %s", err)
}
return db
}
func setUpX509User(t *testing.T, db *sql.DB, cert certhelpers.Certificate) (username string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
username = cert.Template.Subject.CommonName
cmds := []string{
fmt.Sprintf("CREATE USER %s IDENTIFIED BY '' REQUIRE X509", username),
fmt.Sprintf("GRANT ALL ON mysql.* TO '%s'@'%s' REQUIRE X509", username, "%"),
}
for _, cmd := range cmds {
stmt, err := db.PrepareContext(ctx, cmd)
if err != nil {
t.Fatalf("Failed to prepare query: %s", err)
}
_, err = stmt.ExecContext(ctx)
if err != nil {
t.Fatalf("Failed to create x509 user in database: %s", err)
}
err = stmt.Close()
if err != nil {
t.Fatalf("Failed to close prepared statement: %s", err)
}
}
return username
}
// ////////////////////////////////////////////////////////////////////////////
// Writing to file
// ////////////////////////////////////////////////////////////////////////////
func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) {
t.Helper()
err := ioutil.WriteFile(filename, data, perms)
if err != nil {
t.Fatalf("Unable to write to file [%s]: %s", filename, err)
}
}

View File

@@ -10,7 +10,6 @@ import (
stdmysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
@@ -39,7 +38,7 @@ var (
var _ dbplugin.Database = (*MySQL)(nil)
type MySQL struct {
*connutil.SQLConnectionProducer
*mySQLConnectionProducer
credsutil.CredentialsProducer
}
@@ -55,8 +54,7 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro
}
func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = mySQLTypeName
connProducer := &mySQLConnectionProducer{}
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: displayNameLen,
@@ -66,8 +64,8 @@ func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
}
return &MySQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
mySQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
}