mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Add support for IAM Auth for Google CloudSQL DBs (#22445)
This commit is contained in:
@@ -5,6 +5,7 @@ package dbtesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -39,7 +40,7 @@ func AssertInitializeCircleCiTest(t *testing.T, db dbplugin.Database, req dbplug
|
||||
var err error
|
||||
|
||||
for i := 1; i <= maxAttempts; i++ {
|
||||
resp, err = verifyInitialize(t, db, req)
|
||||
resp, err = VerifyInitialize(t, db, req)
|
||||
if err != nil {
|
||||
t.Errorf("Failed AssertInitialize attempt: %d with error:\n%+v\n", i, err)
|
||||
time.Sleep(1 * time.Second)
|
||||
@@ -57,14 +58,14 @@ func AssertInitializeCircleCiTest(t *testing.T, db dbplugin.Database, req dbplug
|
||||
|
||||
func AssertInitialize(t *testing.T, db dbplugin.Database, req dbplugin.InitializeRequest) dbplugin.InitializeResponse {
|
||||
t.Helper()
|
||||
resp, err := verifyInitialize(t, db, req)
|
||||
resp, err := VerifyInitialize(t, db, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize: %s", err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func verifyInitialize(t *testing.T, db dbplugin.Database, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
||||
func VerifyInitialize(t *testing.T, db dbplugin.Database, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), getRequestTimeout(t))
|
||||
defer cancel()
|
||||
|
||||
@@ -119,3 +120,31 @@ func AssertClose(t *testing.T, db dbplugin.Database) {
|
||||
t.Fatalf("Failed to close database: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetGCPTestCredentials reads the credentials from the
|
||||
// GOOGLE_APPLICATIONS_CREDENTIALS environment variable
|
||||
// The credentials are read from a file if a file exists
|
||||
// otherwise they are returned as JSON
|
||||
func GetGCPTestCredentials(t *testing.T) string {
|
||||
t.Helper()
|
||||
envCredentials := "GOOGLE_APPLICATIONS_CREDENTIALS"
|
||||
|
||||
var credsStr string
|
||||
credsEnv := os.Getenv(envCredentials)
|
||||
if credsEnv == "" {
|
||||
t.Skipf("env var %s not set, skipping test", envCredentials)
|
||||
}
|
||||
|
||||
// Attempt to read as file path; if invalid, assume given JSON value directly
|
||||
if _, err := os.Stat(credsEnv); err == nil {
|
||||
credsBytes, err := ioutil.ReadFile(credsEnv)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read credentials file %s: %v", credsStr, err)
|
||||
}
|
||||
credsStr = string(credsBytes)
|
||||
} else {
|
||||
credsStr = credsEnv
|
||||
}
|
||||
|
||||
return credsStr
|
||||
}
|
||||
|
||||
72
sdk/database/helper/connutil/cloudsql.go
Normal file
72
sdk/database/helper/connutil/cloudsql.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/postgres/pgxv4"
|
||||
)
|
||||
|
||||
var configurableAuthTypes = []string{
|
||||
AuthTypeGCPIAM,
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) getCloudSQLDriverType() (string, error) {
|
||||
var driverType string
|
||||
// using switch case for future extensibility
|
||||
switch c.Type {
|
||||
case dbTypePostgres:
|
||||
driverType = cloudSQLPostgres
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported DB type for cloud IAM: %s", c.Type)
|
||||
}
|
||||
|
||||
return driverType, nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) registerDrivers(driverName string, credentials string) (func() error, error) {
|
||||
typ, err := c.getCloudSQLDriverType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts, err := GetCloudSQLAuthOptions(credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// using switch case for future extensibility
|
||||
switch typ {
|
||||
case cloudSQLPostgres:
|
||||
return pgxv4.RegisterDriver(driverName, opts...)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unrecognized cloudsql type encountered: %s", typ)
|
||||
}
|
||||
|
||||
// GetCloudSQLAuthOptions takes a credentials JSON and returns
|
||||
// a set of GCP CloudSQL options - always WithIAMAUthN, and then the appropriate file/JSON option.
|
||||
func GetCloudSQLAuthOptions(credentials string) ([]cloudsqlconn.Option, error) {
|
||||
opts := []cloudsqlconn.Option{cloudsqlconn.WithIAMAuthN()}
|
||||
|
||||
if credentials != "" {
|
||||
opts = append(opts, cloudsqlconn.WithCredentialsJSON([]byte(credentials)))
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func ValidateAuthType(authType string) bool {
|
||||
var valid bool
|
||||
for _, typ := range configurableAuthTypes {
|
||||
if authType == typ {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return valid
|
||||
}
|
||||
@@ -14,11 +14,19 @@ import (
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthTypeGCPIAM = "gcp_iam"
|
||||
|
||||
dbTypePostgres = "pgx"
|
||||
cloudSQLPostgres = "cloudsql-postgres"
|
||||
)
|
||||
|
||||
var _ ConnectionProducer = &SQLConnectionProducer{}
|
||||
|
||||
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||
@@ -29,8 +37,15 @@ type SQLConnectionProducer struct {
|
||||
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
|
||||
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
|
||||
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
|
||||
DisableEscaping bool `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"`
|
||||
|
||||
// cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime
|
||||
// of driver registration, not across plugin restarts.
|
||||
cloudDriverName string
|
||||
cloudDialerCleanup func() error
|
||||
|
||||
Type string
|
||||
RawConfig map[string]interface{}
|
||||
maxConnectionLifetime time.Duration
|
||||
@@ -107,6 +122,32 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf
|
||||
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
|
||||
}
|
||||
|
||||
// validate auth_type if provided
|
||||
authType := c.AuthType
|
||||
if authType != "" {
|
||||
if ok := ValidateAuthType(authType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
|
||||
}
|
||||
}
|
||||
|
||||
if authType == AuthTypeGCPIAM {
|
||||
c.cloudDriverName, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
|
||||
}
|
||||
|
||||
// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers,
|
||||
// however, the driver might store a credentials file, in which case the state stored by the driver is in
|
||||
// fact critical to the proper function of the connection. So it needs to be registered here inside the
|
||||
// ConnectionProducer init.
|
||||
dialerCleanup, err := c.registerDrivers(c.cloudDriverName, c.ServiceAccountJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.cloudDialerCleanup = dialerCleanup
|
||||
}
|
||||
|
||||
// Set initialized to true at this point since all fields are set,
|
||||
// and the connection can be established at a later time.
|
||||
c.Initialized = true
|
||||
@@ -137,12 +178,24 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
c.db.Close()
|
||||
|
||||
// if IAM authentication is enabled
|
||||
// ensure open dialer is also closed
|
||||
if c.AuthType == AuthTypeGCPIAM {
|
||||
if c.cloudDialerCleanup != nil {
|
||||
c.cloudDialerCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
dbType := c.Type
|
||||
if c.Type == "mssql" {
|
||||
dbType = "sqlserver"
|
||||
// default non-IAM behavior
|
||||
driverName := c.Type
|
||||
|
||||
if c.AuthType == AuthTypeGCPIAM {
|
||||
driverName = c.cloudDriverName
|
||||
} else if c.Type == "mssql" {
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
driverName = "sqlserver"
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
@@ -164,7 +217,7 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
|
||||
}
|
||||
|
||||
var err error
|
||||
c.db, err = sql.Open(dbType, conn)
|
||||
c.db, err = sql.Open(driverName, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -192,6 +245,13 @@ func (c *SQLConnectionProducer) Close() error {
|
||||
|
||||
if c.db != nil {
|
||||
c.db.Close()
|
||||
|
||||
// cleanup IAM dialer if it exists
|
||||
if c.AuthType == AuthTypeGCPIAM {
|
||||
if c.cloudDialerCleanup != nil {
|
||||
c.cloudDialerCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.db = nil
|
||||
|
||||
Reference in New Issue
Block a user