mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	Add support for IAM Auth for Google CloudSQL DBs (#22445)
This commit is contained in:
		| @@ -13,6 +13,7 @@ import ( | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	cloudmysql "cloud.google.com/go/cloudsqlconn/mysql/mysql" | ||||
| 	"github.com/go-sql-driver/mysql" | ||||
| 	"github.com/hashicorp/go-secure-stdlib/parseutil" | ||||
| 	"github.com/hashicorp/go-uuid" | ||||
| @@ -21,6 +22,11 @@ import ( | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	cloudSQLMySQL = "cloudsql-mysql" | ||||
| 	driverMySQL   = "mysql" | ||||
| ) | ||||
|  | ||||
| // mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases | ||||
| type mySQLConnectionProducer struct { | ||||
| 	ConnectionURL            string      `json:"connection_url"          mapstructure:"connection_url"          structs:"connection_url"` | ||||
| @@ -29,6 +35,8 @@ type mySQLConnectionProducer struct { | ||||
| 	MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"` | ||||
| 	Username                 string      `json:"username" mapstructure:"username" structs:"username"` | ||||
| 	Password                 string      `json:"password" mapstructure:"password" structs:"password"` | ||||
| 	AuthType                 string      `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"` | ||||
| 	ServiceAccountJSON       string      `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"` | ||||
|  | ||||
| 	TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"` | ||||
| 	TLSCAData             []byte `json:"tls_ca"              mapstructure:"tls_ca"              structs:"-"` | ||||
| @@ -38,6 +46,10 @@ type mySQLConnectionProducer struct { | ||||
| 	// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver | ||||
| 	tlsConfigName string | ||||
|  | ||||
| 	// cloudDriverName is a globally unique name that references the cloud dialer config for this instance of the driver | ||||
| 	cloudDriverName    string | ||||
| 	cloudDialerCleanup func() error | ||||
|  | ||||
| 	RawConfig             map[string]interface{} | ||||
| 	maxConnectionLifetime time.Duration | ||||
| 	Initialized           bool | ||||
| @@ -110,6 +122,32 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte | ||||
| 		mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig) | ||||
| 	} | ||||
|  | ||||
| 	// validate auth_type if provided | ||||
| 	authType := c.AuthType | ||||
| 	if authType != "" { | ||||
| 		if ok := connutil.ValidateAuthType(authType); !ok { | ||||
| 			return nil, fmt.Errorf("invalid auth_type %s provided", authType) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if c.AuthType == connutil.AuthTypeGCPIAM { | ||||
| 		c.cloudDriverName, err = uuid.GenerateUUID() | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err) | ||||
| 		} | ||||
|  | ||||
| 		// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers, | ||||
| 		// however, the driver might store a credentials file, in which case the state stored by the driver is in | ||||
| 		// fact critical to the proper function of the connection. So it needs to be registered here inside the | ||||
| 		// ConnectionProducer init. | ||||
| 		dialerCleanup, err := registerDriverMySQL(c.cloudDriverName, c.ServiceAccountJSON) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		c.cloudDialerCleanup = dialerCleanup | ||||
| 	} | ||||
|  | ||||
| 	// Set initialized to true at this point since all fields are set, | ||||
| 	// and the connection can be established at a later time. | ||||
| 	c.Initialized = true | ||||
| @@ -140,6 +178,20 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, | ||||
| 		// If the ping was unsuccessful, close it and ignore errors as we'll be | ||||
| 		// reestablishing anyways | ||||
| 		c.db.Close() | ||||
|  | ||||
| 		// if IAM authentication was enabled | ||||
| 		// ensure open dialer is also closed | ||||
| 		if c.AuthType == connutil.AuthTypeGCPIAM { | ||||
| 			if c.cloudDialerCleanup != nil { | ||||
| 				c.cloudDialerCleanup() | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	driverName := driverMySQL | ||||
| 	if c.cloudDriverName != "" { | ||||
| 		driverName = c.cloudDriverName | ||||
| 	} | ||||
|  | ||||
| 	connURL, err := c.addTLStoDSN() | ||||
| @@ -147,7 +199,12 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	c.db, err = sql.Open("mysql", connURL) | ||||
| 	cloudURL, err := c.rewriteProtocolForGCP(connURL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	c.db, err = sql.Open(driverName, cloudURL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -174,6 +231,13 @@ func (c *mySQLConnectionProducer) Close() error { | ||||
| 	defer c.Unlock() | ||||
|  | ||||
| 	if c.db != nil { | ||||
| 		// if auth_type is IAM, ensure cleanup | ||||
| 		// of cloudSQL resources | ||||
| 		if c.AuthType == connutil.AuthTypeGCPIAM { | ||||
| 			if c.cloudDialerCleanup != nil { | ||||
| 				c.cloudDialerCleanup() | ||||
| 			} | ||||
| 		} | ||||
| 		c.db.Close() | ||||
| 	} | ||||
|  | ||||
| @@ -230,3 +294,38 @@ func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) { | ||||
| 	connURL = config.FormatDSN() | ||||
| 	return connURL, nil | ||||
| } | ||||
|  | ||||
| // rewriteProtocolForGCP rewrites the protocol in the DSN to contain the protocol name associated | ||||
| // with the dialer and therefore driver associated with the provided cloudsqlconn.DialerOpts. | ||||
| // As a safety/sanity check, it will only do this for protocol "cloudsql-mysql", the name GCP uses in its documentation. | ||||
| // | ||||
| // For example, it will rewrite the dsn "user@cloudsql-mysql(zone:region:instance)/ to | ||||
| // "user@the-uuid-generated(zone:region:instance)/ | ||||
| func (c *mySQLConnectionProducer) rewriteProtocolForGCP(inDSN string) (string, error) { | ||||
| 	if c.cloudDriverName == "" { | ||||
| 		// unchanged if not cloud | ||||
| 		return inDSN, nil | ||||
| 	} | ||||
|  | ||||
| 	config, err := mysql.ParseDSN(inDSN) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("unable to parse connectionURL: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if config.Net != cloudSQLMySQL { | ||||
| 		return "", fmt.Errorf("didn't update net name because it wasn't what we expected as a placeholder: %s", config.Net) | ||||
| 	} | ||||
|  | ||||
| 	config.Net = c.cloudDriverName | ||||
|  | ||||
| 	return config.FormatDSN(), nil | ||||
| } | ||||
|  | ||||
| func registerDriverMySQL(driverName, credentials string) (cleanup func() error, err error) { | ||||
| 	opts, err := connutil.GetCloudSQLAuthOptions(credentials) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return cloudmysql.RegisterDriver(driverName, opts...) | ||||
| } | ||||
|   | ||||
| @@ -7,18 +7,21 @@ import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	stdmysql "github.com/go-sql-driver/mysql" | ||||
| 	"github.com/hashicorp/go-secure-stdlib/strutil" | ||||
| 	"github.com/stretchr/testify/require" | ||||
|  | ||||
| 	mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql" | ||||
| 	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" | ||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin/v5" | ||||
| 	dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/connutil" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/credsutil" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var _ dbplugin.Database = (*MySQL)(nil) | ||||
| @@ -44,6 +47,79 @@ func TestMySQL_Initialize(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestMySQL_Initialize_CloudGCP validates the proper initialization of a MySQL backend pointing | ||||
| // to a GCP CloudSQL MySQL instance. This expects some external setup (exact TBD) | ||||
| func TestMySQL_Initialize_CloudGCP(t *testing.T) { | ||||
| 	envConnURL := "CONNECTION_URL" | ||||
| 	connURL := os.Getenv(envConnURL) | ||||
| 	if connURL == "" { | ||||
| 		t.Skipf("env var %s not set, skipping test", envConnURL) | ||||
| 	} | ||||
|  | ||||
| 	credStr := dbtesting.GetGCPTestCredentials(t) | ||||
|  | ||||
| 	tests := map[string]struct { | ||||
| 		req           dbplugin.InitializeRequest | ||||
| 		wantErr       bool | ||||
| 		expectedError string | ||||
| 	}{ | ||||
| 		"empty auth type": { | ||||
| 			req: dbplugin.InitializeRequest{ | ||||
| 				Config: map[string]interface{}{ | ||||
| 					"connection_url": connURL, | ||||
| 					"auth_type":      "", | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"invalid auth type": { | ||||
| 			req: dbplugin.InitializeRequest{ | ||||
| 				Config: map[string]interface{}{ | ||||
| 					"connection_url": connURL, | ||||
| 					"auth_type":      "invalid", | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantErr:       true, | ||||
| 			expectedError: "invalid auth_type", | ||||
| 		}, | ||||
| 		"JSON credentials": { | ||||
| 			req: dbplugin.InitializeRequest{ | ||||
| 				Config: map[string]interface{}{ | ||||
| 					"connection_url":       connURL, | ||||
| 					"auth_type":            connutil.AuthTypeGCPIAM, | ||||
| 					"service_account_json": credStr, | ||||
| 				}, | ||||
| 				VerifyConnection: true, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for n, tc := range tests { | ||||
| 		t.Run(n, func(t *testing.T) { | ||||
| 			db := newMySQL(DefaultUserNameTemplate) | ||||
| 			defer dbtesting.AssertClose(t, db) | ||||
| 			_, err := db.Initialize(context.Background(), tc.req) | ||||
|  | ||||
| 			if tc.wantErr { | ||||
| 				if err == nil { | ||||
| 					t.Fatalf("expected error but received nil") | ||||
| 				} | ||||
|  | ||||
| 				if !strings.Contains(err.Error(), tc.expectedError) { | ||||
| 					t.Fatalf("expected error %s, got %s", tc.expectedError, err.Error()) | ||||
| 				} | ||||
| 			} else { | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("expected no error, received %s", err) | ||||
| 				} | ||||
|  | ||||
| 				if !db.Initialized { | ||||
| 					t.Fatal("Database should be initialized") | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func testInitialize(t *testing.T, rootPassword string) { | ||||
| 	cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, rootPassword) | ||||
| 	defer cleanup() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 kpcraig
					kpcraig