Add support for IAM Auth for Google CloudSQL DBs (#22445)

This commit is contained in:
kpcraig
2023-09-06 17:40:39 -04:00
committed by GitHub
parent 2ca784ad11
commit 2172786316
11 changed files with 1024 additions and 41 deletions

View File

@@ -13,6 +13,7 @@ import (
"sync"
"time"
cloudmysql "cloud.google.com/go/cloudsqlconn/mysql/mysql"
"github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-uuid"
@@ -21,6 +22,11 @@ import (
"github.com/mitchellh/mapstructure"
)
const (
cloudSQLMySQL = "cloudsql-mysql"
driverMySQL = "mysql"
)
// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type mySQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
@@ -29,6 +35,8 @@ type mySQLConnectionProducer struct {
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
@@ -38,6 +46,10 @@ type mySQLConnectionProducer struct {
// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
tlsConfigName string
// cloudDriverName is a globally unique name that references the cloud dialer config for this instance of the driver
cloudDriverName string
cloudDialerCleanup func() error
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
Initialized bool
@@ -110,6 +122,32 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig)
}
// validate auth_type if provided
authType := c.AuthType
if authType != "" {
if ok := connutil.ValidateAuthType(authType); !ok {
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
}
}
if c.AuthType == connutil.AuthTypeGCPIAM {
c.cloudDriverName, err = uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
}
// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers,
// however, the driver might store a credentials file, in which case the state stored by the driver is in
// fact critical to the proper function of the connection. So it needs to be registered here inside the
// ConnectionProducer init.
dialerCleanup, err := registerDriverMySQL(c.cloudDriverName, c.ServiceAccountJSON)
if err != nil {
return nil, err
}
c.cloudDialerCleanup = dialerCleanup
}
// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true
@@ -140,6 +178,20 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
c.db.Close()
// if IAM authentication was enabled
// ensure open dialer is also closed
if c.AuthType == connutil.AuthTypeGCPIAM {
if c.cloudDialerCleanup != nil {
c.cloudDialerCleanup()
}
}
}
driverName := driverMySQL
if c.cloudDriverName != "" {
driverName = c.cloudDriverName
}
connURL, err := c.addTLStoDSN()
@@ -147,7 +199,12 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
return nil, err
}
c.db, err = sql.Open("mysql", connURL)
cloudURL, err := c.rewriteProtocolForGCP(connURL)
if err != nil {
return nil, err
}
c.db, err = sql.Open(driverName, cloudURL)
if err != nil {
return nil, err
}
@@ -174,6 +231,13 @@ func (c *mySQLConnectionProducer) Close() error {
defer c.Unlock()
if c.db != nil {
// if auth_type is IAM, ensure cleanup
// of cloudSQL resources
if c.AuthType == connutil.AuthTypeGCPIAM {
if c.cloudDialerCleanup != nil {
c.cloudDialerCleanup()
}
}
c.db.Close()
}
@@ -230,3 +294,38 @@ func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) {
connURL = config.FormatDSN()
return connURL, nil
}
// rewriteProtocolForGCP rewrites the protocol in the DSN to contain the protocol name associated
// with the dialer and therefore driver associated with the provided cloudsqlconn.DialerOpts.
// As a safety/sanity check, it will only do this for protocol "cloudsql-mysql", the name GCP uses in its documentation.
//
// For example, it will rewrite the dsn "user@cloudsql-mysql(zone:region:instance)/ to
// "user@the-uuid-generated(zone:region:instance)/
func (c *mySQLConnectionProducer) rewriteProtocolForGCP(inDSN string) (string, error) {
if c.cloudDriverName == "" {
// unchanged if not cloud
return inDSN, nil
}
config, err := mysql.ParseDSN(inDSN)
if err != nil {
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
}
if config.Net != cloudSQLMySQL {
return "", fmt.Errorf("didn't update net name because it wasn't what we expected as a placeholder: %s", config.Net)
}
config.Net = c.cloudDriverName
return config.FormatDSN(), nil
}
func registerDriverMySQL(driverName, credentials string) (cleanup func() error, err error) {
opts, err := connutil.GetCloudSQLAuthOptions(credentials)
if err != nil {
return nil, err
}
return cloudmysql.RegisterDriver(driverName, opts...)
}

View File

@@ -7,18 +7,21 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"testing"
"time"
stdmysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/stretchr/testify/require"
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/stretchr/testify/require"
)
var _ dbplugin.Database = (*MySQL)(nil)
@@ -44,6 +47,79 @@ func TestMySQL_Initialize(t *testing.T) {
}
}
// TestMySQL_Initialize_CloudGCP validates the proper initialization of a MySQL backend pointing
// to a GCP CloudSQL MySQL instance. This expects some external setup (exact TBD)
func TestMySQL_Initialize_CloudGCP(t *testing.T) {
envConnURL := "CONNECTION_URL"
connURL := os.Getenv(envConnURL)
if connURL == "" {
t.Skipf("env var %s not set, skipping test", envConnURL)
}
credStr := dbtesting.GetGCPTestCredentials(t)
tests := map[string]struct {
req dbplugin.InitializeRequest
wantErr bool
expectedError string
}{
"empty auth type": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": "",
},
},
},
"invalid auth type": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": "invalid",
},
},
wantErr: true,
expectedError: "invalid auth_type",
},
"JSON credentials": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": connutil.AuthTypeGCPIAM,
"service_account_json": credStr,
},
VerifyConnection: true,
},
},
}
for n, tc := range tests {
t.Run(n, func(t *testing.T) {
db := newMySQL(DefaultUserNameTemplate)
defer dbtesting.AssertClose(t, db)
_, err := db.Initialize(context.Background(), tc.req)
if tc.wantErr {
if err == nil {
t.Fatalf("expected error but received nil")
}
if !strings.Contains(err.Error(), tc.expectedError) {
t.Fatalf("expected error %s, got %s", tc.expectedError, err.Error())
}
} else {
if err != nil {
t.Fatalf("expected no error, received %s", err)
}
if !db.Initialized {
t.Fatal("Database should be initialized")
}
}
})
}
}
func testInitialize(t *testing.T, rootPassword string) {
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, rootPassword)
defer cleanup()

View File

@@ -12,14 +12,17 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/docker"
"github.com/hashicorp/vault/sdk/helper/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
@@ -94,6 +97,93 @@ func TestPostgreSQL_Initialize_ConnURLWithDSNFormat(t *testing.T) {
}
}
// Ensures we can successfully initialize and connect to a CloudSQL database
// Requires the following:
// - GOOGLE_APPLICATION_CREDENTIALS either JSON or path to file
// - CONNECTION_URL to a valid Postgres instance on Google CloudSQL
func TestPostgreSQL_Initialize_CloudGCP(t *testing.T) {
envConnURL := "CONNECTION_URL"
connURL := os.Getenv(envConnURL)
if connURL == "" {
t.Skipf("env var %s not set, skipping test", envConnURL)
}
credStr := dbtesting.GetGCPTestCredentials(t)
type testCase struct {
req dbplugin.InitializeRequest
wantErr bool
expectedError string
}
tests := map[string]testCase{
"empty auth type": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": "",
},
},
},
"invalid auth type": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": "invalid",
},
},
wantErr: true,
expectedError: "invalid auth_type",
},
"default credentials": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": connutil.AuthTypeGCPIAM,
},
VerifyConnection: true,
},
},
"JSON credentials": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": connutil.AuthTypeGCPIAM,
"service_account_json": credStr,
},
VerifyConnection: true,
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
db := new()
defer dbtesting.AssertClose(t, db)
_, err := dbtesting.VerifyInitialize(t, db, test.req)
if test.wantErr {
if err == nil {
t.Fatalf("expected error but received nil")
}
if !strings.Contains(err.Error(), test.expectedError) {
t.Fatalf("expected error %s, got %s", test.expectedError, err.Error())
}
} else {
if err != nil {
t.Fatalf("expected no error, received %s", err)
}
if !db.Initialized {
t.Fatal("Database should be initialized")
}
}
})
}
}
// TestPostgreSQL_PasswordAuthentication tests that the default "password_authentication" is "none", and that
// an error is returned if an invalid "password_authentication" is provided.
func TestPostgreSQL_PasswordAuthentication(t *testing.T) {
@@ -1100,6 +1190,86 @@ func TestNewUser_CustomUsername(t *testing.T) {
}
}
func TestNewUser_CloudGCP(t *testing.T) {
envConnURL := "CONNECTION_URL"
connURL := os.Getenv(envConnURL)
if connURL == "" {
t.Skipf("env var %s not set, skipping test", envConnURL)
}
credStr := dbtesting.GetGCPTestCredentials(t)
type testCase struct {
usernameTemplate string
newUserData dbplugin.UsernameMetadata
expectedRegex string
}
tests := map[string]testCase{
"default template": {
usernameTemplate: "",
newUserData: dbplugin.UsernameMetadata{
DisplayName: "displayname",
RoleName: "longrolename",
},
expectedRegex: "^v-displayn-longrole-[a-zA-Z0-9]{20}-[0-9]{10}$",
},
"unique template": {
usernameTemplate: "foo-bar",
newUserData: dbplugin.UsernameMetadata{
DisplayName: "displayname",
RoleName: "longrolename",
},
expectedRegex: "^foo-bar$",
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
initReq := dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"username_template": test.usernameTemplate,
"auth_type": connutil.AuthTypeGCPIAM,
"service_account_json": credStr,
},
VerifyConnection: true,
}
db := new()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err := db.Initialize(ctx, initReq)
require.NoError(t, err)
newUserReq := dbplugin.NewUserRequest{
UsernameConfig: test.newUserData,
Statements: dbplugin.Statements{
Commands: []string{
`
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
},
},
Password: "myReally-S3curePassword",
Expiration: time.Now().Add(1 * time.Hour),
}
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
newUserResp, err := db.NewUser(ctx, newUserReq)
require.NoError(t, err)
require.Regexp(t, test.expectedRegex, newUserResp.Username)
})
}
}
// This is a long-running integration test which tests the functionality of Postgres's multi-host
// connection strings. It uses two Postgres containers preconfigured with Replication Manager
// provided by Bitnami. This test currently does not run in CI and must be run manually. This is