mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 11:38:02 +00:00
Allow mTLS for mysql secrets engine (#9181)
* Extract certificate helpers for use in non-mongodb packages * Created mTLS/X509 test for MySQL secrets engine. * Ensure mysql username and passwords aren't url encoded * Skip mTLS test for circleCI
This commit is contained in:
244
helper/testhelpers/certhelpers/cert_helpers.go
Normal file
244
helper/testhelpers/certhelpers/cert_helpers.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
package certhelpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CertBuilder struct {
|
||||||
|
tmpl *x509.Certificate
|
||||||
|
parentTmpl *x509.Certificate
|
||||||
|
|
||||||
|
selfSign bool
|
||||||
|
parentKey *rsa.PrivateKey
|
||||||
|
|
||||||
|
isCA bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertOpt func(*CertBuilder) error
|
||||||
|
|
||||||
|
func CommonName(cn string) CertOpt {
|
||||||
|
return func(builder *CertBuilder) error {
|
||||||
|
builder.tmpl.Subject.CommonName = cn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Parent(parent Certificate) CertOpt {
|
||||||
|
return func(builder *CertBuilder) error {
|
||||||
|
builder.parentKey = parent.PrivKey.PrivKey
|
||||||
|
builder.parentTmpl = parent.Template
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsCA(isCA bool) CertOpt {
|
||||||
|
return func(builder *CertBuilder) error {
|
||||||
|
builder.isCA = isCA
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SelfSign() CertOpt {
|
||||||
|
return func(builder *CertBuilder) error {
|
||||||
|
builder.selfSign = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IP(ip ...string) CertOpt {
|
||||||
|
return func(builder *CertBuilder) error {
|
||||||
|
for _, addr := range ip {
|
||||||
|
if ipAddr := net.ParseIP(addr); ipAddr != nil {
|
||||||
|
builder.tmpl.IPAddresses = append(builder.tmpl.IPAddresses, ipAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DNS(dns ...string) CertOpt {
|
||||||
|
return func(builder *CertBuilder) error {
|
||||||
|
builder.tmpl.DNSNames = dns
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCert(t *testing.T, opts ...CertOpt) (cert Certificate) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
builder := CertBuilder{
|
||||||
|
tmpl: &x509.Certificate{
|
||||||
|
SerialNumber: makeSerial(t),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: makeCommonName(),
|
||||||
|
},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: time.Now().Add(1 * time.Hour),
|
||||||
|
IsCA: false,
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature |
|
||||||
|
x509.KeyUsageKeyEncipherment |
|
||||||
|
x509.KeyUsageKeyAgreement,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
err := opt(&builder)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to set up certificate builder: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
key := NewPrivateKey(t)
|
||||||
|
|
||||||
|
builder.tmpl.SubjectKeyId = getSubjKeyID(t, key.PrivKey)
|
||||||
|
|
||||||
|
tmpl := builder.tmpl
|
||||||
|
parent := builder.parentTmpl
|
||||||
|
publicKey := key.PrivKey.Public()
|
||||||
|
signingKey := builder.parentKey
|
||||||
|
|
||||||
|
if builder.selfSign {
|
||||||
|
parent = tmpl
|
||||||
|
signingKey = key.PrivKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if builder.isCA {
|
||||||
|
tmpl.IsCA = true
|
||||||
|
tmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign
|
||||||
|
tmpl.ExtKeyUsage = nil
|
||||||
|
} else {
|
||||||
|
tmpl.KeyUsage = x509.KeyUsageDigitalSignature |
|
||||||
|
x509.KeyUsageKeyEncipherment |
|
||||||
|
x509.KeyUsageKeyAgreement
|
||||||
|
tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
|
||||||
|
}
|
||||||
|
|
||||||
|
certBytes, err := x509.CreateCertificate(rand.Reader, tmpl, parent, publicKey, signingKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to generate certificate: %s", err)
|
||||||
|
}
|
||||||
|
certPem := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: certBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
tlsCert, err := tls.X509KeyPair(certPem, key.Pem)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to parse X509 key pair: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Certificate{
|
||||||
|
Template: tmpl,
|
||||||
|
PrivKey: key,
|
||||||
|
TLSCert: tlsCert,
|
||||||
|
RawCert: certBytes,
|
||||||
|
Pem: certPem,
|
||||||
|
IsCA: builder.isCA,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Private Key
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
type KeyWrapper struct {
|
||||||
|
PrivKey *rsa.PrivateKey
|
||||||
|
Pem []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPrivateKey(t *testing.T) (key KeyWrapper) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to generate key for cert: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privKeyPem := pem.EncodeToMemory(
|
||||||
|
&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
key = KeyWrapper{
|
||||||
|
PrivKey: privKey,
|
||||||
|
Pem: privKeyPem,
|
||||||
|
}
|
||||||
|
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Certificate
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
type Certificate struct {
|
||||||
|
PrivKey KeyWrapper
|
||||||
|
Template *x509.Certificate
|
||||||
|
TLSCert tls.Certificate
|
||||||
|
RawCert []byte
|
||||||
|
Pem []byte
|
||||||
|
IsCA bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cert Certificate) CombinedPEM() []byte {
|
||||||
|
if cert.IsCA {
|
||||||
|
return cert.Pem
|
||||||
|
}
|
||||||
|
return bytes.Join([][]byte{cert.PrivKey.Pem, cert.Pem}, []byte{'\n'})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cert Certificate) PrivateKeyPEM() []byte {
|
||||||
|
return cert.PrivKey.Pem
|
||||||
|
}
|
||||||
|
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Helpers
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
func makeSerial(t *testing.T) *big.Int {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
v := &big.Int{}
|
||||||
|
serialNumberLimit := v.Lsh(big.NewInt(1), 128)
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to generate serial number: %s", err)
|
||||||
|
}
|
||||||
|
return serialNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pulled from sdk/helper/certutil & slightly modified for test usage
|
||||||
|
func getSubjKeyID(t *testing.T, privateKey crypto.Signer) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if privateKey == nil {
|
||||||
|
t.Fatalf("passed-in private key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
marshaledKey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error marshalling public key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
subjKeyID := sha1.Sum(marshaledKey)
|
||||||
|
|
||||||
|
return subjKeyID[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCommonName() (cn string) {
|
||||||
|
return strings.ReplaceAll(time.Now().Format("20060102T150405.000"), ".", "")
|
||||||
|
}
|
||||||
@@ -10,9 +10,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"io/ioutil"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -192,18 +190,6 @@ func (cert certificate) CombinedPEM() []byte {
|
|||||||
return bytes.Join([][]byte{cert.privKey.pem, cert.pem}, []byte{'\n'})
|
return bytes.Join([][]byte{cert.privKey.pem, cert.pem}, []byte{'\n'})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Writing to file
|
|
||||||
// ////////////////////////////////////////////////////////////////////////////
|
|
||||||
func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
err := ioutil.WriteFile(filename, data, perms)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to write to file [%s]: %s", filename, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ////////////////////////////////////////////////////////////////////////////
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
// Helpers
|
// Helpers
|
||||||
// ////////////////////////////////////////////////////////////////////////////
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||||
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
|
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
|
||||||
"github.com/ory/dockertest"
|
"github.com/ory/dockertest"
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
@@ -30,20 +31,20 @@ func TestInit_clientTLS(t *testing.T) {
|
|||||||
defer os.RemoveAll(confDir)
|
defer os.RemoveAll(confDir)
|
||||||
|
|
||||||
// Create certificates for Mongo authentication
|
// Create certificates for Mongo authentication
|
||||||
caCert := newCert(t,
|
caCert := certhelpers.NewCert(t,
|
||||||
commonName("test certificate authority"),
|
certhelpers.CommonName("test certificate authority"),
|
||||||
isCA(true),
|
certhelpers.IsCA(true),
|
||||||
selfSign(),
|
certhelpers.SelfSign(),
|
||||||
)
|
)
|
||||||
serverCert := newCert(t,
|
serverCert := certhelpers.NewCert(t,
|
||||||
commonName("server"),
|
certhelpers.CommonName("server"),
|
||||||
dns("localhost"),
|
certhelpers.DNS("localhost"),
|
||||||
parent(caCert),
|
certhelpers.Parent(caCert),
|
||||||
)
|
)
|
||||||
clientCert := newCert(t,
|
clientCert := certhelpers.NewCert(t,
|
||||||
commonName("client"),
|
certhelpers.CommonName("client"),
|
||||||
dns("client"),
|
certhelpers.DNS("client"),
|
||||||
parent(caCert),
|
certhelpers.Parent(caCert),
|
||||||
)
|
)
|
||||||
|
|
||||||
writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0644)
|
writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0644)
|
||||||
@@ -81,7 +82,7 @@ net:
|
|||||||
"connection_url": retURL,
|
"connection_url": retURL,
|
||||||
"allowed_roles": "*",
|
"allowed_roles": "*",
|
||||||
"tls_certificate_key": clientCert.CombinedPEM(),
|
"tls_certificate_key": clientCert.CombinedPEM(),
|
||||||
"tls_ca": caCert.pem,
|
"tls_ca": caCert.Pem,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
@@ -111,7 +112,7 @@ net:
|
|||||||
AuthInfo: authInfo{
|
AuthInfo: authInfo{
|
||||||
AuthenticatedUsers: []user{
|
AuthenticatedUsers: []user{
|
||||||
{
|
{
|
||||||
User: fmt.Sprintf("CN=%s", clientCert.template.Subject.CommonName),
|
User: fmt.Sprintf("CN=%s", clientCert.Template.Subject.CommonName),
|
||||||
DB: "$external",
|
DB: "$external",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -249,11 +250,11 @@ func connect(t *testing.T, uri string) (client *mongo.Client) {
|
|||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func setUpX509User(t *testing.T, client *mongo.Client, cert certificate) {
|
func setUpX509User(t *testing.T, client *mongo.Client, cert certhelpers.Certificate) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
username := fmt.Sprintf("CN=%s", cert.template.Subject.CommonName)
|
username := fmt.Sprintf("CN=%s", cert.Template.Subject.CommonName)
|
||||||
|
|
||||||
cmd := &createUserCommand{
|
cmd := &createUserCommand{
|
||||||
Username: username,
|
Username: username,
|
||||||
@@ -301,3 +302,16 @@ type roles []role
|
|||||||
func (r roles) Len() int { return len(r) }
|
func (r roles) Len() int { return len(r) }
|
||||||
func (r roles) Less(i, j int) bool { return r[i].Role < r[j].Role }
|
func (r roles) Less(i, j int) bool { return r[i].Role < r[j].Role }
|
||||||
func (r roles) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
|
func (r roles) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
|
||||||
|
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Writing to file
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
err := ioutil.WriteFile(filename, data, perms)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to write to file [%s]: %s", filename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||||
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
|
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
@@ -239,14 +240,14 @@ func testCreateDBUser(t testing.TB, connURL, db, username, password string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTLSAuth(t *testing.T) {
|
func TestGetTLSAuth(t *testing.T) {
|
||||||
ca := newCert(t,
|
ca := certhelpers.NewCert(t,
|
||||||
commonName("certificate authority"),
|
certhelpers.CommonName("certificate authority"),
|
||||||
isCA(true),
|
certhelpers.IsCA(true),
|
||||||
selfSign(),
|
certhelpers.SelfSign(),
|
||||||
)
|
)
|
||||||
cert := newCert(t,
|
cert := certhelpers.NewCert(t,
|
||||||
commonName("test cert"),
|
certhelpers.CommonName("test cert"),
|
||||||
parent(ca),
|
certhelpers.Parent(ca),
|
||||||
)
|
)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -276,12 +277,12 @@ func TestGetTLSAuth(t *testing.T) {
|
|||||||
expectErr: true,
|
expectErr: true,
|
||||||
},
|
},
|
||||||
"good ca": {
|
"good ca": {
|
||||||
tlsCAData: cert.pem,
|
tlsCAData: cert.Pem,
|
||||||
|
|
||||||
expectOpts: options.Client().
|
expectOpts: options.Client().
|
||||||
SetTLSConfig(
|
SetTLSConfig(
|
||||||
&tls.Config{
|
&tls.Config{
|
||||||
RootCAs: appendToCertPool(t, x509.NewCertPool(), cert.pem),
|
RootCAs: appendToCertPool(t, x509.NewCertPool(), cert.Pem),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
expectErr: false,
|
expectErr: false,
|
||||||
@@ -293,7 +294,7 @@ func TestGetTLSAuth(t *testing.T) {
|
|||||||
expectOpts: options.Client().
|
expectOpts: options.Client().
|
||||||
SetTLSConfig(
|
SetTLSConfig(
|
||||||
&tls.Config{
|
&tls.Config{
|
||||||
Certificates: []tls.Certificate{cert.tlsCert},
|
Certificates: []tls.Certificate{cert.TLSCert},
|
||||||
},
|
},
|
||||||
).
|
).
|
||||||
SetAuth(options.Credential{
|
SetAuth(options.Credential{
|
||||||
|
|||||||
226
plugins/database/mysql/connection_producer.go
Normal file
226
plugins/database/mysql/connection_producer.go
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
|
"github.com/hashicorp/go-uuid"
|
||||||
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||||
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||||
|
type mySQLConnectionProducer struct {
|
||||||
|
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
|
||||||
|
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
|
||||||
|
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
|
||||||
|
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
|
||||||
|
|
||||||
|
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||||
|
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||||
|
|
||||||
|
TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
|
||||||
|
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
|
||||||
|
|
||||||
|
// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
|
||||||
|
tlsConfigName string
|
||||||
|
|
||||||
|
RawConfig map[string]interface{}
|
||||||
|
maxConnectionLifetime time.Duration
|
||||||
|
Initialized bool
|
||||||
|
db *sql.DB
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mySQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := c.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
|
c.RawConfig = conf
|
||||||
|
|
||||||
|
err := mapstructure.WeakDecode(conf, &c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.ConnectionURL) == 0 {
|
||||||
|
return nil, fmt.Errorf("connection_url cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't escape special characters for MySQL password
|
||||||
|
password := c.Password
|
||||||
|
|
||||||
|
// QueryHelper doesn't do any SQL escaping, but if it starts to do so
|
||||||
|
// then maybe we won't be able to use it to do URL substitution any more.
|
||||||
|
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
|
||||||
|
"username": url.PathEscape(c.Username),
|
||||||
|
"password": password,
|
||||||
|
})
|
||||||
|
|
||||||
|
if c.MaxOpenConnections == 0 {
|
||||||
|
c.MaxOpenConnections = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MaxIdleConnections == 0 {
|
||||||
|
c.MaxIdleConnections = c.MaxOpenConnections
|
||||||
|
}
|
||||||
|
if c.MaxIdleConnections > c.MaxOpenConnections {
|
||||||
|
c.MaxIdleConnections = c.MaxOpenConnections
|
||||||
|
}
|
||||||
|
if c.MaxConnectionLifetimeRaw == nil {
|
||||||
|
c.MaxConnectionLifetimeRaw = "0s"
|
||||||
|
}
|
||||||
|
|
||||||
|
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err := c.getTLSAuth()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig != nil {
|
||||||
|
if c.tlsConfigName == "" {
|
||||||
|
c.tlsConfigName, err = uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set initialized to true at this point since all fields are set,
|
||||||
|
// and the connection can be established at a later time.
|
||||||
|
c.Initialized = true
|
||||||
|
|
||||||
|
if verifyConnection {
|
||||||
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.db.PingContext(ctx); err != nil {
|
||||||
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.RawConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
||||||
|
if !c.Initialized {
|
||||||
|
return nil, connutil.ErrNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we already have a DB, test it and return
|
||||||
|
if c.db != nil {
|
||||||
|
if err := c.db.PingContext(ctx); err == nil {
|
||||||
|
return c.db, nil
|
||||||
|
}
|
||||||
|
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||||
|
// reestablishing anyways
|
||||||
|
c.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
connURL, err := c.addTLStoDSN()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.db, err = sql.Open("mysql", connURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set some connection pool settings. We don't need much of this,
|
||||||
|
// since the request rate shouldn't be high.
|
||||||
|
c.db.SetMaxOpenConns(c.MaxOpenConnections)
|
||||||
|
c.db.SetMaxIdleConns(c.MaxIdleConnections)
|
||||||
|
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
|
||||||
|
|
||||||
|
return c.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mySQLConnectionProducer) SecretValues() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
c.Password: "[password]",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close attempts to close the connection
|
||||||
|
func (c *mySQLConnectionProducer) Close() error {
|
||||||
|
// Grab the write lock
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
|
if c.db != nil {
|
||||||
|
c.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.db = nil
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) {
|
||||||
|
if len(c.TLSCAData) == 0 &&
|
||||||
|
len(c.TLSCertificateKeyData) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCertPool := x509.NewCertPool()
|
||||||
|
if len(c.TLSCAData) > 0 {
|
||||||
|
ok := rootCertPool.AppendCertsFromPEM(c.TLSCAData)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to append CA to client options")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCert := make([]tls.Certificate, 0, 1)
|
||||||
|
|
||||||
|
if len(c.TLSCertificateKeyData) > 0 {
|
||||||
|
certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCert = append(clientCert, certificate)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
RootCAs: rootCertPool,
|
||||||
|
Certificates: clientCert,
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) {
|
||||||
|
config, err := mysql.ParseDSN(c.ConnectionURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.TLSConfig = c.tlsConfigName
|
||||||
|
|
||||||
|
connURL = config.FormatDSN()
|
||||||
|
|
||||||
|
return connURL, nil
|
||||||
|
}
|
||||||
311
plugins/database/mysql/connection_producer_test.go
Normal file
311
plugins/database/mysql/connection_producer_test.go
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
paths "path"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||||
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
|
"github.com/ory/dockertest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_addTLStoDSN(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
rootUrl string
|
||||||
|
tlsConfigName string
|
||||||
|
expectedResult string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]testCase{
|
||||||
|
"no tls, no query string": {
|
||||||
|
rootUrl: "user:password@tcp(localhost:3306)/test",
|
||||||
|
tlsConfigName: "",
|
||||||
|
expectedResult: "user:password@tcp(localhost:3306)/test",
|
||||||
|
},
|
||||||
|
"tls, no query string": {
|
||||||
|
rootUrl: "user:password@tcp(localhost:3306)/test",
|
||||||
|
tlsConfigName: "tlsTest101",
|
||||||
|
expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101",
|
||||||
|
},
|
||||||
|
"tls, query string": {
|
||||||
|
rootUrl: "user:password@tcp(localhost:3306)/test?foo=bar",
|
||||||
|
tlsConfigName: "tlsTest101",
|
||||||
|
expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
|
||||||
|
},
|
||||||
|
"tls, query string, ? in password": {
|
||||||
|
rootUrl: "user:pa?ssword?@tcp(localhost:3306)/test?foo=bar",
|
||||||
|
tlsConfigName: "tlsTest101",
|
||||||
|
expectedResult: "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tCase := mySQLConnectionProducer{
|
||||||
|
ConnectionURL: test.rootUrl,
|
||||||
|
tlsConfigName: test.tlsConfigName,
|
||||||
|
}
|
||||||
|
|
||||||
|
actual, err := tCase.addTLStoDSN()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error occurred in test: %s", err)
|
||||||
|
}
|
||||||
|
if actual != test.expectedResult {
|
||||||
|
t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInit_clientTLS(t *testing.T) {
|
||||||
|
t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " +
|
||||||
|
"https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-")
|
||||||
|
|
||||||
|
// Set up temp directory so we can mount it to the docker container
|
||||||
|
confDir := makeTempDir(t)
|
||||||
|
defer os.RemoveAll(confDir)
|
||||||
|
|
||||||
|
// Create certificates for MySQL authentication
|
||||||
|
caCert := certhelpers.NewCert(t,
|
||||||
|
certhelpers.CommonName("test certificate authority"),
|
||||||
|
certhelpers.IsCA(true),
|
||||||
|
certhelpers.SelfSign(),
|
||||||
|
)
|
||||||
|
serverCert := certhelpers.NewCert(t,
|
||||||
|
certhelpers.CommonName("server"),
|
||||||
|
certhelpers.DNS("localhost"),
|
||||||
|
certhelpers.Parent(caCert),
|
||||||
|
)
|
||||||
|
clientCert := certhelpers.NewCert(t,
|
||||||
|
certhelpers.CommonName("client"),
|
||||||
|
certhelpers.DNS("client"),
|
||||||
|
certhelpers.Parent(caCert),
|
||||||
|
)
|
||||||
|
|
||||||
|
writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0644)
|
||||||
|
writeFile(t, paths.Join(confDir, "server-cert.pem"), serverCert.Pem, 0644)
|
||||||
|
writeFile(t, paths.Join(confDir, "server-key.pem"), serverCert.PrivateKeyPEM(), 0644)
|
||||||
|
writeFile(t, paths.Join(confDir, "client.pem"), clientCert.CombinedPEM(), 0644)
|
||||||
|
|
||||||
|
// //////////////////////////////////////////////////////
|
||||||
|
// Set up MySQL config file
|
||||||
|
rawConf := `
|
||||||
|
[mysqld]
|
||||||
|
ssl
|
||||||
|
ssl-ca=/etc/mysql/ca.pem
|
||||||
|
ssl-cert=/etc/mysql/server-cert.pem
|
||||||
|
ssl-key=/etc/mysql/server-key.pem`
|
||||||
|
|
||||||
|
writeFile(t, paths.Join(confDir, "my.cnf"), []byte(rawConf), 0644)
|
||||||
|
|
||||||
|
// //////////////////////////////////////////////////////
|
||||||
|
// Start MySQL container
|
||||||
|
retURL, cleanup := startMySQLWithTLS(t, "5.7", confDir)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// //////////////////////////////////////////////////////
|
||||||
|
// Set up x509 user
|
||||||
|
mClient := connect(t, retURL)
|
||||||
|
|
||||||
|
username := setUpX509User(t, mClient, clientCert)
|
||||||
|
|
||||||
|
// //////////////////////////////////////////////////////
|
||||||
|
// Test
|
||||||
|
mysql := new(25, 25, 25)
|
||||||
|
|
||||||
|
conf := map[string]interface{}{
|
||||||
|
"connection_url": retURL,
|
||||||
|
"username": username,
|
||||||
|
"tls_certificate_key": clientCert.CombinedPEM(),
|
||||||
|
"tls_ca": caCert.Pem,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := mysql.Init(ctx, conf, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to initialize mysql engine: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialization complete. The connection was established, but we need to ensure
|
||||||
|
// that we're connected as the right user
|
||||||
|
whoamiCmd := "SELECT CURRENT_USER()"
|
||||||
|
|
||||||
|
client, err := mysql.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to make connection to MySQL: %s", err)
|
||||||
|
}
|
||||||
|
stmt, err := client.Prepare(whoamiCmd)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to prepare MySQL statementL %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := stmt.QueryRow()
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("%s@%%", username)
|
||||||
|
|
||||||
|
var result string
|
||||||
|
if err := results.Scan(&result); err != nil {
|
||||||
|
t.Fatalf("result could not be scanned from result set: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(result, expected) {
|
||||||
|
t.Fatalf("Actual:%#v\nExpected:\n%#v", result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTempDir(t *testing.T) (confDir string) {
|
||||||
|
confDir, err := ioutil.TempDir(".", "mysql-test-data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to make temp directory: %s", err)
|
||||||
|
}
|
||||||
|
// Convert the directory to an absolute path because docker needs it when mounting
|
||||||
|
confDir, err = filepath.Abs(filepath.Clean(confDir))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to determine where temp directory is on absolute path: %s", err)
|
||||||
|
}
|
||||||
|
return confDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func startMySQLWithTLS(t *testing.T, version string, confDir string) (retURL string, cleanup func()) {
|
||||||
|
if os.Getenv("MYSQL_URL") != "" {
|
||||||
|
return os.Getenv("MYSQL_URL"), func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := dockertest.NewPool("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect to docker: %s", err)
|
||||||
|
}
|
||||||
|
pool.MaxWait = 30 * time.Second
|
||||||
|
|
||||||
|
containerName := "mysql-unit-test"
|
||||||
|
|
||||||
|
// Remove previously running container if it is still running because cleanup failed
|
||||||
|
err = pool.RemoveContainerByName(containerName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to remove old running containers: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := "root"
|
||||||
|
password := "x509test"
|
||||||
|
|
||||||
|
runOpts := &dockertest.RunOptions{
|
||||||
|
Name: containerName,
|
||||||
|
Repository: "mysql",
|
||||||
|
Tag: version,
|
||||||
|
Cmd: []string{"--defaults-extra-file=/etc/mysql/my.cnf", "--auto-generate-certs=OFF"},
|
||||||
|
Env: []string{fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", password)},
|
||||||
|
// Mount the directory from local filesystem into the container
|
||||||
|
Mounts: []string{
|
||||||
|
fmt.Sprintf("%s:/etc/mysql", confDir),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resource, err := pool.RunWithOptions(runOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not start local mysql docker container: %s", err)
|
||||||
|
}
|
||||||
|
resource.Expire(30)
|
||||||
|
|
||||||
|
cleanup = func() {
|
||||||
|
err := pool.Purge(resource)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("{{username}}:{{password}}@tcp(localhost:%s)/mysql", resource.GetPort("3306/tcp"))
|
||||||
|
|
||||||
|
url := dbutil.QueryHelper(dsn, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
})
|
||||||
|
// exponential backoff-retry
|
||||||
|
err = pool.Retry(func() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
db, err := sql.Open("mysql", url)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("err: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
return db.Ping()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
t.Fatalf("Could not connect to mysql docker container: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dsn, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func connect(t *testing.T, dsn string) (db *sql.DB) {
|
||||||
|
url := dbutil.QueryHelper(dsn, map[string]string{
|
||||||
|
"username": "root",
|
||||||
|
"password": "x509test",
|
||||||
|
})
|
||||||
|
|
||||||
|
db, err := sql.Open("mysql", url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to make connection to MySQL: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Ping()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to ping MySQL server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUpX509User(t *testing.T, db *sql.DB, cert certhelpers.Certificate) (username string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
username = cert.Template.Subject.CommonName
|
||||||
|
|
||||||
|
cmds := []string{
|
||||||
|
fmt.Sprintf("CREATE USER %s IDENTIFIED BY '' REQUIRE X509", username),
|
||||||
|
fmt.Sprintf("GRANT ALL ON mysql.* TO '%s'@'%s' REQUIRE X509", username, "%"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
stmt, err := db.PrepareContext(ctx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to prepare query: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stmt.ExecContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create x509 user in database: %s", err)
|
||||||
|
}
|
||||||
|
err = stmt.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to close prepared statement: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return username
|
||||||
|
}
|
||||||
|
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Writing to file
|
||||||
|
// ////////////////////////////////////////////////////////////////////////////
|
||||||
|
func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
err := ioutil.WriteFile(filename, data, perms)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to write to file [%s]: %s", filename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
stdmysql "github.com/go-sql-driver/mysql"
|
stdmysql "github.com/go-sql-driver/mysql"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
|
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||||
@@ -39,7 +38,7 @@ var (
|
|||||||
var _ dbplugin.Database = (*MySQL)(nil)
|
var _ dbplugin.Database = (*MySQL)(nil)
|
||||||
|
|
||||||
type MySQL struct {
|
type MySQL struct {
|
||||||
*connutil.SQLConnectionProducer
|
*mySQLConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,8 +54,7 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
|
func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
connProducer := &mySQLConnectionProducer{}
|
||||||
connProducer.Type = mySQLTypeName
|
|
||||||
|
|
||||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||||
DisplayNameLen: displayNameLen,
|
DisplayNameLen: displayNameLen,
|
||||||
@@ -66,7 +64,7 @@ func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &MySQL{
|
return &MySQL{
|
||||||
SQLConnectionProducer: connProducer,
|
mySQLConnectionProducer: connProducer,
|
||||||
CredentialsProducer: credsProducer,
|
CredentialsProducer: credsProducer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user