mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-30 02:02:43 +00:00
Add support for IAM Auth for Google CloudSQL DBs (#22445)
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cloudmysql "cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
@@ -21,6 +22,11 @@ import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
cloudSQLMySQL = "cloudsql-mysql"
|
||||
driverMySQL = "mysql"
|
||||
)
|
||||
|
||||
// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||
type mySQLConnectionProducer struct {
|
||||
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
|
||||
@@ -29,6 +35,8 @@ type mySQLConnectionProducer struct {
|
||||
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
|
||||
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
|
||||
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
|
||||
|
||||
TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
|
||||
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
|
||||
@@ -38,6 +46,10 @@ type mySQLConnectionProducer struct {
|
||||
// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
|
||||
tlsConfigName string
|
||||
|
||||
// cloudDriverName is a globally unique name that references the cloud dialer config for this instance of the driver
|
||||
cloudDriverName string
|
||||
cloudDialerCleanup func() error
|
||||
|
||||
RawConfig map[string]interface{}
|
||||
maxConnectionLifetime time.Duration
|
||||
Initialized bool
|
||||
@@ -110,6 +122,32 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
|
||||
mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig)
|
||||
}
|
||||
|
||||
// validate auth_type if provided
|
||||
authType := c.AuthType
|
||||
if authType != "" {
|
||||
if ok := connutil.ValidateAuthType(authType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
|
||||
}
|
||||
}
|
||||
|
||||
if c.AuthType == connutil.AuthTypeGCPIAM {
|
||||
c.cloudDriverName, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
|
||||
}
|
||||
|
||||
// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers,
|
||||
// however, the driver might store a credentials file, in which case the state stored by the driver is in
|
||||
// fact critical to the proper function of the connection. So it needs to be registered here inside the
|
||||
// ConnectionProducer init.
|
||||
dialerCleanup, err := registerDriverMySQL(c.cloudDriverName, c.ServiceAccountJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.cloudDialerCleanup = dialerCleanup
|
||||
}
|
||||
|
||||
// Set initialized to true at this point since all fields are set,
|
||||
// and the connection can be established at a later time.
|
||||
c.Initialized = true
|
||||
@@ -140,6 +178,20 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
c.db.Close()
|
||||
|
||||
// if IAM authentication was enabled
|
||||
// ensure open dialer is also closed
|
||||
if c.AuthType == connutil.AuthTypeGCPIAM {
|
||||
if c.cloudDialerCleanup != nil {
|
||||
c.cloudDialerCleanup()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
driverName := driverMySQL
|
||||
if c.cloudDriverName != "" {
|
||||
driverName = c.cloudDriverName
|
||||
}
|
||||
|
||||
connURL, err := c.addTLStoDSN()
|
||||
@@ -147,7 +199,12 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.db, err = sql.Open("mysql", connURL)
|
||||
cloudURL, err := c.rewriteProtocolForGCP(connURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.db, err = sql.Open(driverName, cloudURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -174,6 +231,13 @@ func (c *mySQLConnectionProducer) Close() error {
|
||||
defer c.Unlock()
|
||||
|
||||
if c.db != nil {
|
||||
// if auth_type is IAM, ensure cleanup
|
||||
// of cloudSQL resources
|
||||
if c.AuthType == connutil.AuthTypeGCPIAM {
|
||||
if c.cloudDialerCleanup != nil {
|
||||
c.cloudDialerCleanup()
|
||||
}
|
||||
}
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
@@ -230,3 +294,38 @@ func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) {
|
||||
connURL = config.FormatDSN()
|
||||
return connURL, nil
|
||||
}
|
||||
|
||||
// rewriteProtocolForGCP rewrites the protocol in the DSN to contain the protocol name associated
|
||||
// with the dialer and therefore driver associated with the provided cloudsqlconn.DialerOpts.
|
||||
// As a safety/sanity check, it will only do this for protocol "cloudsql-mysql", the name GCP uses in its documentation.
|
||||
//
|
||||
// For example, it will rewrite the dsn "user@cloudsql-mysql(zone:region:instance)/ to
|
||||
// "user@the-uuid-generated(zone:region:instance)/
|
||||
func (c *mySQLConnectionProducer) rewriteProtocolForGCP(inDSN string) (string, error) {
|
||||
if c.cloudDriverName == "" {
|
||||
// unchanged if not cloud
|
||||
return inDSN, nil
|
||||
}
|
||||
|
||||
config, err := mysql.ParseDSN(inDSN)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
|
||||
}
|
||||
|
||||
if config.Net != cloudSQLMySQL {
|
||||
return "", fmt.Errorf("didn't update net name because it wasn't what we expected as a placeholder: %s", config.Net)
|
||||
}
|
||||
|
||||
config.Net = c.cloudDriverName
|
||||
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
|
||||
func registerDriverMySQL(driverName, credentials string) (cleanup func() error, err error) {
|
||||
opts, err := connutil.GetCloudSQLAuthOptions(credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cloudmysql.RegisterDriver(driverName, opts...)
|
||||
}
|
||||
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
stdmysql "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
|
||||
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = (*MySQL)(nil)
|
||||
@@ -44,6 +47,79 @@ func TestMySQL_Initialize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMySQL_Initialize_CloudGCP validates the proper initialization of a MySQL backend pointing
|
||||
// to a GCP CloudSQL MySQL instance. This expects some external setup (exact TBD)
|
||||
func TestMySQL_Initialize_CloudGCP(t *testing.T) {
|
||||
envConnURL := "CONNECTION_URL"
|
||||
connURL := os.Getenv(envConnURL)
|
||||
if connURL == "" {
|
||||
t.Skipf("env var %s not set, skipping test", envConnURL)
|
||||
}
|
||||
|
||||
credStr := dbtesting.GetGCPTestCredentials(t)
|
||||
|
||||
tests := map[string]struct {
|
||||
req dbplugin.InitializeRequest
|
||||
wantErr bool
|
||||
expectedError string
|
||||
}{
|
||||
"empty auth type": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
"invalid auth type": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": "invalid",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "invalid auth_type",
|
||||
},
|
||||
"JSON credentials": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": connutil.AuthTypeGCPIAM,
|
||||
"service_account_json": credStr,
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for n, tc := range tests {
|
||||
t.Run(n, func(t *testing.T) {
|
||||
db := newMySQL(DefaultUserNameTemplate)
|
||||
defer dbtesting.AssertClose(t, db)
|
||||
_, err := db.Initialize(context.Background(), tc.req)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error but received nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), tc.expectedError) {
|
||||
t.Fatalf("expected error %s, got %s", tc.expectedError, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, received %s", err)
|
||||
}
|
||||
|
||||
if !db.Initialized {
|
||||
t.Fatal("Database should be initialized")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testInitialize(t *testing.T, rootPassword string) {
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, rootPassword)
|
||||
defer cleanup()
|
||||
|
||||
@@ -12,14 +12,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/docker"
|
||||
"github.com/hashicorp/vault/sdk/helper/template"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
|
||||
@@ -94,6 +97,93 @@ func TestPostgreSQL_Initialize_ConnURLWithDSNFormat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Ensures we can successfully initialize and connect to a CloudSQL database
|
||||
// Requires the following:
|
||||
// - GOOGLE_APPLICATION_CREDENTIALS either JSON or path to file
|
||||
// - CONNECTION_URL to a valid Postgres instance on Google CloudSQL
|
||||
func TestPostgreSQL_Initialize_CloudGCP(t *testing.T) {
|
||||
envConnURL := "CONNECTION_URL"
|
||||
connURL := os.Getenv(envConnURL)
|
||||
if connURL == "" {
|
||||
t.Skipf("env var %s not set, skipping test", envConnURL)
|
||||
}
|
||||
|
||||
credStr := dbtesting.GetGCPTestCredentials(t)
|
||||
|
||||
type testCase struct {
|
||||
req dbplugin.InitializeRequest
|
||||
wantErr bool
|
||||
expectedError string
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"empty auth type": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
"invalid auth type": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": "invalid",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "invalid auth_type",
|
||||
},
|
||||
"default credentials": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": connutil.AuthTypeGCPIAM,
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
},
|
||||
"JSON credentials": {
|
||||
req: dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"auth_type": connutil.AuthTypeGCPIAM,
|
||||
"service_account_json": credStr,
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := new()
|
||||
defer dbtesting.AssertClose(t, db)
|
||||
|
||||
_, err := dbtesting.VerifyInitialize(t, db, test.req)
|
||||
|
||||
if test.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error but received nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), test.expectedError) {
|
||||
t.Fatalf("expected error %s, got %s", test.expectedError, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, received %s", err)
|
||||
}
|
||||
|
||||
if !db.Initialized {
|
||||
t.Fatal("Database should be initialized")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_PasswordAuthentication tests that the default "password_authentication" is "none", and that
|
||||
// an error is returned if an invalid "password_authentication" is provided.
|
||||
func TestPostgreSQL_PasswordAuthentication(t *testing.T) {
|
||||
@@ -1100,6 +1190,86 @@ func TestNewUser_CustomUsername(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUser_CloudGCP(t *testing.T) {
|
||||
envConnURL := "CONNECTION_URL"
|
||||
connURL := os.Getenv(envConnURL)
|
||||
if connURL == "" {
|
||||
t.Skipf("env var %s not set, skipping test", envConnURL)
|
||||
}
|
||||
|
||||
credStr := dbtesting.GetGCPTestCredentials(t)
|
||||
|
||||
type testCase struct {
|
||||
usernameTemplate string
|
||||
newUserData dbplugin.UsernameMetadata
|
||||
expectedRegex string
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"default template": {
|
||||
usernameTemplate: "",
|
||||
newUserData: dbplugin.UsernameMetadata{
|
||||
DisplayName: "displayname",
|
||||
RoleName: "longrolename",
|
||||
},
|
||||
expectedRegex: "^v-displayn-longrole-[a-zA-Z0-9]{20}-[0-9]{10}$",
|
||||
},
|
||||
"unique template": {
|
||||
usernameTemplate: "foo-bar",
|
||||
newUserData: dbplugin.UsernameMetadata{
|
||||
DisplayName: "displayname",
|
||||
RoleName: "longrolename",
|
||||
},
|
||||
expectedRegex: "^foo-bar$",
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
initReq := dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"username_template": test.usernameTemplate,
|
||||
"auth_type": connutil.AuthTypeGCPIAM,
|
||||
"service_account_json": credStr,
|
||||
},
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.Initialize(ctx, initReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
newUserReq := dbplugin.NewUserRequest{
|
||||
UsernameConfig: test.newUserData,
|
||||
Statements: dbplugin.Statements{
|
||||
Commands: []string{
|
||||
`
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
|
||||
},
|
||||
},
|
||||
Password: "myReally-S3curePassword",
|
||||
Expiration: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
newUserResp, err := db.NewUser(ctx, newUserReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Regexp(t, test.expectedRegex, newUserResp.Username)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// This is a long-running integration test which tests the functionality of Postgres's multi-host
|
||||
// connection strings. It uses two Postgres containers preconfigured with Replication Manager
|
||||
// provided by Bitnami. This test currently does not run in CI and must be run manually. This is
|
||||
|
||||
Reference in New Issue
Block a user