Add support for IAM Auth for Google CloudSQL DBs (#22445)

This commit is contained in:
kpcraig
2023-09-06 17:40:39 -04:00
committed by GitHub
parent 2ca784ad11
commit 2172786316
11 changed files with 1024 additions and 41 deletions

View File

@@ -5,6 +5,7 @@ package dbtesting
import (
"context"
"io/ioutil"
"os"
"testing"
"time"
@@ -39,7 +40,7 @@ func AssertInitializeCircleCiTest(t *testing.T, db dbplugin.Database, req dbplug
var err error
for i := 1; i <= maxAttempts; i++ {
resp, err = verifyInitialize(t, db, req)
resp, err = VerifyInitialize(t, db, req)
if err != nil {
t.Errorf("Failed AssertInitialize attempt: %d with error:\n%+v\n", i, err)
time.Sleep(1 * time.Second)
@@ -57,14 +58,14 @@ func AssertInitializeCircleCiTest(t *testing.T, db dbplugin.Database, req dbplug
func AssertInitialize(t *testing.T, db dbplugin.Database, req dbplugin.InitializeRequest) dbplugin.InitializeResponse {
t.Helper()
resp, err := verifyInitialize(t, db, req)
resp, err := VerifyInitialize(t, db, req)
if err != nil {
t.Fatalf("Failed to initialize: %s", err)
}
return resp
}
func verifyInitialize(t *testing.T, db dbplugin.Database, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
func VerifyInitialize(t *testing.T, db dbplugin.Database, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), getRequestTimeout(t))
defer cancel()
@@ -119,3 +120,31 @@ func AssertClose(t *testing.T, db dbplugin.Database) {
t.Fatalf("Failed to close database: %s", err)
}
}
// GetGCPTestCredentials reads the credentials from the
// GOOGLE_APPLICATIONS_CREDENTIALS environment variable
// The credentials are read from a file if a file exists
// otherwise they are returned as JSON
func GetGCPTestCredentials(t *testing.T) string {
t.Helper()
envCredentials := "GOOGLE_APPLICATIONS_CREDENTIALS"
var credsStr string
credsEnv := os.Getenv(envCredentials)
if credsEnv == "" {
t.Skipf("env var %s not set, skipping test", envCredentials)
}
// Attempt to read as file path; if invalid, assume given JSON value directly
if _, err := os.Stat(credsEnv); err == nil {
credsBytes, err := ioutil.ReadFile(credsEnv)
if err != nil {
t.Fatalf("unable to read credentials file %s: %v", credsStr, err)
}
credsStr = string(credsBytes)
} else {
credsStr = credsEnv
}
return credsStr
}

View File

@@ -0,0 +1,72 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package connutil
import (
"fmt"
"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/postgres/pgxv4"
)
var configurableAuthTypes = []string{
AuthTypeGCPIAM,
}
func (c *SQLConnectionProducer) getCloudSQLDriverType() (string, error) {
var driverType string
// using switch case for future extensibility
switch c.Type {
case dbTypePostgres:
driverType = cloudSQLPostgres
default:
return "", fmt.Errorf("unsupported DB type for cloud IAM: %s", c.Type)
}
return driverType, nil
}
func (c *SQLConnectionProducer) registerDrivers(driverName string, credentials string) (func() error, error) {
typ, err := c.getCloudSQLDriverType()
if err != nil {
return nil, err
}
opts, err := GetCloudSQLAuthOptions(credentials)
if err != nil {
return nil, err
}
// using switch case for future extensibility
switch typ {
case cloudSQLPostgres:
return pgxv4.RegisterDriver(driverName, opts...)
}
return nil, fmt.Errorf("unrecognized cloudsql type encountered: %s", typ)
}
// GetCloudSQLAuthOptions takes a credentials JSON and returns
// a set of GCP CloudSQL options - always WithIAMAUthN, and then the appropriate file/JSON option.
func GetCloudSQLAuthOptions(credentials string) ([]cloudsqlconn.Option, error) {
opts := []cloudsqlconn.Option{cloudsqlconn.WithIAMAuthN()}
if credentials != "" {
opts = append(opts, cloudsqlconn.WithCredentialsJSON([]byte(credentials)))
}
return opts, nil
}
func ValidateAuthType(authType string) bool {
var valid bool
for _, typ := range configurableAuthTypes {
if authType == typ {
valid = true
break
}
}
return valid
}

View File

@@ -14,11 +14,19 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/mitchellh/mapstructure"
)
const (
AuthTypeGCPIAM = "gcp_iam"
dbTypePostgres = "pgx"
cloudSQLPostgres = "cloudsql-postgres"
)
var _ ConnectionProducer = &SQLConnectionProducer{}
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
@@ -29,8 +37,15 @@ type SQLConnectionProducer struct {
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
DisableEscaping bool `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"`
// cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime
// of driver registration, not across plugin restarts.
cloudDriverName string
cloudDialerCleanup func() error
Type string
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
@@ -107,6 +122,32 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
}
// validate auth_type if provided
authType := c.AuthType
if authType != "" {
if ok := ValidateAuthType(authType); !ok {
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
}
}
if authType == AuthTypeGCPIAM {
c.cloudDriverName, err = uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
}
// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers,
// however, the driver might store a credentials file, in which case the state stored by the driver is in
// fact critical to the proper function of the connection. So it needs to be registered here inside the
// ConnectionProducer init.
dialerCleanup, err := c.registerDrivers(c.cloudDriverName, c.ServiceAccountJSON)
if err != nil {
return nil, err
}
c.cloudDialerCleanup = dialerCleanup
}
// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true
@@ -137,12 +178,24 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
c.db.Close()
// if IAM authentication is enabled
// ensure open dialer is also closed
if c.AuthType == AuthTypeGCPIAM {
if c.cloudDialerCleanup != nil {
c.cloudDialerCleanup()
}
}
}
// For mssql backend, switch to sqlserver instead
dbType := c.Type
if c.Type == "mssql" {
dbType = "sqlserver"
// default non-IAM behavior
driverName := c.Type
if c.AuthType == AuthTypeGCPIAM {
driverName = c.cloudDriverName
} else if c.Type == "mssql" {
// For mssql backend, switch to sqlserver instead
driverName = "sqlserver"
}
// Otherwise, attempt to make connection
@@ -164,7 +217,7 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
}
var err error
c.db, err = sql.Open(dbType, conn)
c.db, err = sql.Open(driverName, conn)
if err != nil {
return nil, err
}
@@ -192,6 +245,13 @@ func (c *SQLConnectionProducer) Close() error {
if c.db != nil {
c.db.Close()
// cleanup IAM dialer if it exists
if c.AuthType == AuthTypeGCPIAM {
if c.cloudDialerCleanup != nil {
c.cloudDialerCleanup()
}
}
}
c.db = nil