mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-29 17:52:32 +00:00 
			
		
		
		
	database/postgres: add inline certificate authentication fields (#28024)
* add inline cert auth to postres db plugin * handle both sslinline and new TLS plugin fields * refactor PrepareTestContainerWithSSL * add tests for postgres inline TLS fields * changelog * revert back to errwrap since the middleware sanitizing depends on it * enable only setting sslrootcert
This commit is contained in:
		 John-Michael Faircloth
					John-Michael Faircloth
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							a19195c901
						
					
				
				
					commit
					3fcb1a67c5
				
			| @@ -345,6 +345,8 @@ func TestBackend_config_connection(t *testing.T) { | ||||
| 	assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"]) | ||||
| } | ||||
|  | ||||
| // TestBackend_BadConnectionString tests that an error response resulting from | ||||
| // a failed connection does not expose the URL. The middleware should sanitize it. | ||||
| func TestBackend_BadConnectionString(t *testing.T) { | ||||
| 	cluster, sys := getClusterPostgresDB(t) | ||||
| 	defer cluster.Cleanup() | ||||
|   | ||||
							
								
								
									
										3
									
								
								changelog/28024.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/28024.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| ```release-note:improvement | ||||
| database/postgres: Add new fields to the plugin's config endpoint for client certificate authentication. | ||||
| ``` | ||||
| @@ -9,11 +9,13 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/helper/testhelpers/certhelpers" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/connutil" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/docker" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/pluginutil" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -68,7 +70,13 @@ func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func( | ||||
|  | ||||
| // PrepareTestContainerWithSSL will setup a test container with SSL enabled so | ||||
| // that we can test client certificate authentication. | ||||
| func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) { | ||||
| func PrepareTestContainerWithSSL( | ||||
| 	t *testing.T, | ||||
| 	sslMode string, | ||||
| 	caCert certhelpers.Certificate, | ||||
| 	clientCert certhelpers.Certificate, | ||||
| 	useFallback bool, | ||||
| ) (func(), string) { | ||||
| 	runOpts := defaultRunOpts(t) | ||||
| 	runner, err := docker.NewServiceRunner(runOpts) | ||||
| 	if err != nil { | ||||
| @@ -82,21 +90,11 @@ func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode stri | ||||
| 	} | ||||
|  | ||||
| 	// Create certificates for postgres authentication | ||||
| 	caCert := certhelpers.NewCert(t, | ||||
| 		certhelpers.CommonName("ca"), | ||||
| 		certhelpers.IsCA(true), | ||||
| 		certhelpers.SelfSign(), | ||||
| 	) | ||||
| 	serverCert := certhelpers.NewCert(t, | ||||
| 		certhelpers.CommonName("server"), | ||||
| 		certhelpers.DNS("localhost"), | ||||
| 		certhelpers.Parent(caCert), | ||||
| 	) | ||||
| 	clientCert := certhelpers.NewCert(t, | ||||
| 		certhelpers.CommonName("postgres"), | ||||
| 		certhelpers.DNS("localhost"), | ||||
| 		certhelpers.Parent(caCert), | ||||
| 	) | ||||
|  | ||||
| 	bCtx := docker.NewBuildContext() | ||||
| 	bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM()) | ||||
| @@ -133,6 +131,9 @@ EOF | ||||
| 		t.Fatalf("failed to copy to container: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	// overwrite the postgresql.conf config file with our ssl settings | ||||
| 	mustRunCommand(t, ctx, runner, id, | ||||
| 		[]string{"bash", "/var/lib/postgresql/pg-conf.sh"}) | ||||
| @@ -150,7 +151,7 @@ EOF | ||||
| 		return svc.Cleanup, svc.Config.URL().String() | ||||
| 	} | ||||
|  | ||||
| 	sslConfig, err := connectPostgresSSL( | ||||
| 	sslConfig := getPostgresSSLConfig( | ||||
| 		t, | ||||
| 		svc.Config.URL().Host, | ||||
| 		sslMode, | ||||
| @@ -197,42 +198,40 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri | ||||
| 	return runner, svc.Cleanup, svc.Config.URL().String(), containerID | ||||
| } | ||||
|  | ||||
| // connectPostgresSSL is used to verify the connection of our test container | ||||
| // and construct the connection string that is used in tests. | ||||
| // | ||||
| // NOTE: The RawQuery component of the url sets the custom sslinline field and | ||||
| // inlines the certificate material in the sslrootcert, sslcert, and sslkey | ||||
| // fields. This feature will be removed in a future version of the SDK. | ||||
| func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) { | ||||
| func getPostgresSSLConfig(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) docker.ServiceConfig { | ||||
| 	if useFallback { | ||||
| 		// set the first host to a bad address so we can test the fallback logic | ||||
| 		host = "localhost:55," + host | ||||
| 	} | ||||
| 	u := url.URL{ | ||||
| 		Scheme: "postgres", | ||||
| 		User:   url.User("postgres"), | ||||
| 		Host:   host, | ||||
| 		Path:   "postgres", | ||||
| 		RawQuery: url.Values{ | ||||
| 			"sslmode":     {sslMode}, | ||||
| 			"sslinline":   {"true"}, | ||||
| 			"sslrootcert": {caCert}, | ||||
| 			"sslcert":     {clientCert}, | ||||
| 			"sslkey":      {clientKey}, | ||||
| 		}.Encode(), | ||||
|  | ||||
| 	u := url.URL{} | ||||
|  | ||||
| 	if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok { | ||||
| 		// TODO: remove this when we remove the underlying feature in a future SDK version | ||||
| 		u = url.URL{ | ||||
| 			Scheme: "postgres", | ||||
| 			User:   url.User("postgres"), | ||||
| 			Host:   host, | ||||
| 			Path:   "postgres", | ||||
| 			RawQuery: url.Values{ | ||||
| 				"sslmode":     {sslMode}, | ||||
| 				"sslinline":   {"true"}, | ||||
| 				"sslrootcert": {caCert}, | ||||
| 				"sslcert":     {clientCert}, | ||||
| 				"sslkey":      {clientKey}, | ||||
| 			}.Encode(), | ||||
| 		} | ||||
| 	} else { | ||||
| 		u = url.URL{ | ||||
| 			Scheme:   "postgres", | ||||
| 			User:     url.User("postgres"), | ||||
| 			Host:     host, | ||||
| 			Path:     "postgres", | ||||
| 			RawQuery: url.Values{"sslmode": {sslMode}}.Encode(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// TODO: remove this deprecated function call in a future SDK version | ||||
| 	db, err := connutil.OpenPostgres("pgx", u.String()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	if err = db.Ping(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return docker.NewServiceURL(u), nil | ||||
| 	return docker.NewServiceURL(u) | ||||
| } | ||||
|  | ||||
| func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter { | ||||
|   | ||||
| @@ -123,11 +123,8 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte | ||||
| 	} | ||||
|  | ||||
| 	// validate auth_type if provided | ||||
| 	authType := c.AuthType | ||||
| 	if authType != "" { | ||||
| 		if ok := connutil.ValidateAuthType(authType); !ok { | ||||
| 			return nil, fmt.Errorf("invalid auth_type %s provided", authType) | ||||
| 		} | ||||
| 	if ok := connutil.ValidateAuthType(c.AuthType); !ok { | ||||
| 		return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType) | ||||
| 	} | ||||
|  | ||||
| 	if c.AuthType == connutil.AuthTypeGCPIAM { | ||||
|   | ||||
| @@ -5,7 +5,11 @@ package postgresql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"database/sql" | ||||
| 	"encoding/pem" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| @@ -79,11 +83,65 @@ func new() *PostgreSQL { | ||||
| type PostgreSQL struct { | ||||
| 	*connutil.SQLConnectionProducer | ||||
|  | ||||
| 	TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"` | ||||
| 	TLSPrivateKey      []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"` | ||||
| 	TLSCAData          []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"` | ||||
|  | ||||
| 	usernameProducer       template.StringTemplate | ||||
| 	passwordAuthentication passwordAuthentication | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) { | ||||
| 	sslcert, err := strutil.GetString(req.Config, "tls_certificate") | ||||
| 	if err != nil { | ||||
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	sslkey, err := strutil.GetString(req.Config, "tls_private_key") | ||||
| 	if err != nil { | ||||
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	sslrootcert, err := strutil.GetString(req.Config, "tls_ca") | ||||
| 	if err != nil { | ||||
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_ca: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	useTLS := false | ||||
| 	tlsConfig := &tls.Config{} | ||||
| 	if sslrootcert != "" { | ||||
| 		caCertPool := x509.NewCertPool() | ||||
| 		if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) { | ||||
| 			return dbplugin.InitializeResponse{}, errors.New("unable to add CA to cert pool") | ||||
| 		} | ||||
|  | ||||
| 		tlsConfig.RootCAs = caCertPool | ||||
| 		tlsConfig.ClientCAs = caCertPool | ||||
| 		p.TLSConfig = tlsConfig | ||||
| 		useTLS = true | ||||
| 	} | ||||
|  | ||||
| 	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { | ||||
| 		return dbplugin.InitializeResponse{}, errors.New(`both "sslcert" and "sslkey" are required`) | ||||
| 	} | ||||
|  | ||||
| 	if sslcert != "" && sslkey != "" { | ||||
| 		block, _ := pem.Decode([]byte(sslkey)) | ||||
|  | ||||
| 		cert, err := tls.X509KeyPair([]byte(sslcert), pem.EncodeToMemory(block)) | ||||
| 		if err != nil { | ||||
| 			return dbplugin.InitializeResponse{}, fmt.Errorf("unable to load cert: %w", err) | ||||
| 		} | ||||
| 		tlsConfig.Certificates = []tls.Certificate{cert} | ||||
| 		p.TLSConfig = tlsConfig | ||||
| 		useTLS = true | ||||
| 	} | ||||
|  | ||||
| 	if !useTLS { | ||||
| 		// set to nil to flag that this connection does not use a custom TLS config | ||||
| 		p.TLSConfig = nil | ||||
| 	} | ||||
|  | ||||
| 	newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) | ||||
| 	if err != nil { | ||||
| 		return dbplugin.InitializeResponse{}, err | ||||
|   | ||||
| @@ -12,6 +12,7 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/helper/testhelpers/certhelpers" | ||||
| 	"github.com/hashicorp/vault/helper/testhelpers/postgresql" | ||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin/v5" | ||||
| 	dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing" | ||||
| @@ -86,15 +87,18 @@ func TestPostgreSQL_InitializeMultiHost(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE | ||||
| // TestPostgreSQL_InitializeSSLInlineFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE | ||||
| // flag guards against unwanted usage of the deprecated SSL client authentication path. | ||||
| // TODO: remove this when we remove the underlying feature in a future SDK version | ||||
| func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) { | ||||
| func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) { | ||||
| 	// set the flag to true so we can call PrepareTestContainerWithSSL | ||||
| 	// which does a validation check on the connection | ||||
| 	t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") | ||||
|  | ||||
| 	cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false) | ||||
| 	// Create certificates for postgres authentication | ||||
| 	caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) | ||||
| 	clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) | ||||
| 	cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false) | ||||
| 	t.Cleanup(cleanup) | ||||
|  | ||||
| 	type testCase struct { | ||||
| @@ -166,11 +170,11 @@ func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestPostgreSQL_InitializeSSL tests that we can successfully authenticate | ||||
| // TestPostgreSQL_InitializeSSLInline tests that we can successfully authenticate | ||||
| // with a postgres server via ssl with a URL connection string or DSN (key/value) | ||||
| // for each ssl mode. | ||||
| // TODO: remove this when we remove the underlying feature in a future SDK version | ||||
| func TestPostgreSQL_InitializeSSL(t *testing.T) { | ||||
| func TestPostgreSQL_InitializeSSLInline(t *testing.T) { | ||||
| 	// required to enable the sslinline custom parsing | ||||
| 	t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") | ||||
|  | ||||
| @@ -287,7 +291,11 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) { | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
| 			cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback) | ||||
|  | ||||
| 			// Create certificates for postgres authentication | ||||
| 			caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) | ||||
| 			clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) | ||||
| 			cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback) | ||||
| 			t.Cleanup(cleanup) | ||||
|  | ||||
| 			if test.useDSN { | ||||
| @@ -326,6 +334,188 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestPostgreSQL_InitializeSSL tests that we can successfully authenticate | ||||
| // with a postgres server via ssl with a URL connection string or DSN (key/value) | ||||
| // for each ssl mode. | ||||
| func TestPostgreSQL_InitializeSSL(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		sslMode       string | ||||
| 		useDSN        bool | ||||
| 		useFallback   bool | ||||
| 		wantErr       bool | ||||
| 		expectedError string | ||||
| 	} | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| 		"disable sslmode": { | ||||
| 			sslMode:       "disable", | ||||
| 			wantErr:       true, | ||||
| 			expectedError: "error verifying connection", | ||||
| 		}, | ||||
| 		"allow sslmode": { | ||||
| 			sslMode: "allow", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"prefer sslmode": { | ||||
| 			sslMode: "prefer", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"require sslmode": { | ||||
| 			sslMode: "require", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"verify-ca sslmode": { | ||||
| 			sslMode: "verify-ca", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"verify-full sslmode": { | ||||
| 			sslMode: "verify-full", | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"disable sslmode with DSN": { | ||||
| 			sslMode:       "disable", | ||||
| 			useDSN:        true, | ||||
| 			wantErr:       true, | ||||
| 			expectedError: "error verifying connection", | ||||
| 		}, | ||||
| 		"allow sslmode with DSN": { | ||||
| 			sslMode: "allow", | ||||
| 			useDSN:  true, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"prefer sslmode with DSN": { | ||||
| 			sslMode: "prefer", | ||||
| 			useDSN:  true, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"require sslmode with DSN": { | ||||
| 			sslMode: "require", | ||||
| 			useDSN:  true, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"verify-ca sslmode with DSN": { | ||||
| 			sslMode: "verify-ca", | ||||
| 			useDSN:  true, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"verify-full sslmode with DSN": { | ||||
| 			sslMode: "verify-full", | ||||
| 			useDSN:  true, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		"disable sslmode with fallback": { | ||||
| 			sslMode:       "disable", | ||||
| 			useFallback:   true, | ||||
| 			wantErr:       true, | ||||
| 			expectedError: "error verifying connection", | ||||
| 		}, | ||||
| 		"allow sslmode with fallback": { | ||||
| 			sslMode:     "allow", | ||||
| 			useFallback: true, | ||||
| 		}, | ||||
| 		"prefer sslmode with fallback": { | ||||
| 			sslMode:     "prefer", | ||||
| 			useFallback: true, | ||||
| 		}, | ||||
| 		"require sslmode with fallback": { | ||||
| 			sslMode:     "require", | ||||
| 			useFallback: true, | ||||
| 		}, | ||||
| 		"verify-ca sslmode with fallback": { | ||||
| 			sslMode:     "verify-ca", | ||||
| 			useFallback: true, | ||||
| 		}, | ||||
| 		"verify-full sslmode with fallback": { | ||||
| 			sslMode:     "verify-full", | ||||
| 			useFallback: true, | ||||
| 		}, | ||||
| 		"disable sslmode with DSN with fallback": { | ||||
| 			sslMode:       "disable", | ||||
| 			useDSN:        true, | ||||
| 			useFallback:   true, | ||||
| 			wantErr:       true, | ||||
| 			expectedError: "error verifying connection", | ||||
| 		}, | ||||
| 		"allow sslmode with DSN with fallback": { | ||||
| 			sslMode:     "allow", | ||||
| 			useDSN:      true, | ||||
| 			useFallback: true, | ||||
| 			wantErr:     false, | ||||
| 		}, | ||||
| 		"prefer sslmode with DSN with fallback": { | ||||
| 			sslMode:     "prefer", | ||||
| 			useDSN:      true, | ||||
| 			useFallback: true, | ||||
| 			wantErr:     false, | ||||
| 		}, | ||||
| 		"require sslmode with DSN with fallback": { | ||||
| 			sslMode:     "require", | ||||
| 			useDSN:      true, | ||||
| 			useFallback: true, | ||||
| 			wantErr:     false, | ||||
| 		}, | ||||
| 		"verify-ca sslmode with DSN with fallback": { | ||||
| 			sslMode:     "verify-ca", | ||||
| 			useDSN:      true, | ||||
| 			useFallback: true, | ||||
| 			wantErr:     false, | ||||
| 		}, | ||||
| 		"verify-full sslmode with DSN with fallback": { | ||||
| 			sslMode:     "verify-full", | ||||
| 			useDSN:      true, | ||||
| 			useFallback: true, | ||||
| 			wantErr:     false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			// Create certificates for postgres authentication | ||||
| 			caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) | ||||
| 			clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) | ||||
| 			cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback) | ||||
| 			t.Cleanup(cleanup) | ||||
|  | ||||
| 			if test.useDSN { | ||||
| 				var err error | ||||
| 				connURL, err = dbutil.ParseURL(connURL) | ||||
| 				if err != nil { | ||||
| 					t.Fatal(err) | ||||
| 				} | ||||
| 			} | ||||
| 			connectionDetails := map[string]interface{}{ | ||||
| 				"connection_url":       connURL, | ||||
| 				"max_open_connections": 5, | ||||
| 				"tls_certificate":      string(clientCert.CombinedPEM()), | ||||
| 				"tls_private_key":      string(clientCert.PrivateKeyPEM()), | ||||
| 				"tls_ca":               string(caCert.CombinedPEM()), | ||||
| 			} | ||||
|  | ||||
| 			req := dbplugin.InitializeRequest{ | ||||
| 				Config:           connectionDetails, | ||||
| 				VerifyConnection: true, | ||||
| 			} | ||||
|  | ||||
| 			db := new() | ||||
| 			_, err := dbtesting.VerifyInitialize(t, db, req) | ||||
| 			if test.wantErr && err == nil { | ||||
| 				t.Fatal("expected error, got nil") | ||||
| 			} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) { | ||||
| 				t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError) | ||||
| 			} | ||||
|  | ||||
| 			if !test.wantErr && !db.Initialized { | ||||
| 				t.Fatal("Database should be initialized") | ||||
| 			} | ||||
|  | ||||
| 			if err := db.Close(); err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_InitializeWithStringVals(t *testing.T) { | ||||
| 	db, cleanup := getPostgreSQL(t, map[string]interface{}{ | ||||
| 		"max_open_connections": "5", | ||||
|   | ||||
| @@ -10,10 +10,6 @@ import ( | ||||
| 	"cloud.google.com/go/cloudsqlconn/postgres/pgxv4" | ||||
| ) | ||||
|  | ||||
| var configurableAuthTypes = []string{ | ||||
| 	AuthTypeGCPIAM, | ||||
| } | ||||
|  | ||||
| func (c *SQLConnectionProducer) getCloudSQLDriverType() (string, error) { | ||||
| 	var driverType string | ||||
| 	// using switch case for future extensibility | ||||
| @@ -62,15 +58,3 @@ func GetCloudSQLAuthOptions(credentials string, usePrivateIP bool) ([]cloudsqlco | ||||
|  | ||||
| 	return opts, nil | ||||
| } | ||||
|  | ||||
| func ValidateAuthType(authType string) bool { | ||||
| 	var valid bool | ||||
| 	for _, typ := range configurableAuthTypes { | ||||
| 		if authType == typ { | ||||
| 			valid = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return valid | ||||
| } | ||||
|   | ||||
| @@ -46,7 +46,7 @@ import ( | ||||
| 	"github.com/jackc/pgx/v4/stdlib" | ||||
| ) | ||||
|  | ||||
| // OpenPostgres parses the connection string and opens a connection to the database. | ||||
| // openPostgres parses the connection string and opens a connection to the database. | ||||
| // | ||||
| // If sslinline is set, strips the connection string of all ssl settings and | ||||
| // creates a TLS config based on the settings provided, then uses the | ||||
| @@ -54,8 +54,8 @@ import ( | ||||
| // because the pgx driver does not support the sslinline parameter and instead | ||||
| // expects to source ssl material from the file system. | ||||
| // | ||||
| // Deprecated: OpenPostgres will be removed in a future version of the Vault SDK. | ||||
| func OpenPostgres(driverName, connString string) (*sql.DB, error) { | ||||
| // Deprecated: openPostgres will be removed in a future version of the Vault SDK. | ||||
| func openPostgres(driverName, connString string) (*sql.DB, error) { | ||||
| 	if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok { | ||||
| 		return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable") | ||||
| 	} | ||||
|   | ||||
| @@ -5,6 +5,7 @@ package connutil | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| @@ -19,12 +20,18 @@ import ( | ||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/pluginutil" | ||||
| 	"github.com/jackc/pgx/v4" | ||||
| 	"github.com/jackc/pgx/v4/stdlib" | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	AuthTypeGCPIAM = "gcp_iam" | ||||
| 	AuthTypeGCPIAM           = "gcp_iam" | ||||
| 	AuthTypeCert             = "cert" | ||||
| 	AuthTypeUsernamePassword = "" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	dbTypePostgres   = "pgx" | ||||
| 	cloudSQLPostgres = "cloudsql-postgres" | ||||
| ) | ||||
| @@ -37,14 +44,19 @@ type SQLConnectionProducer struct { | ||||
| 	MaxOpenConnections       int         `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"` | ||||
| 	MaxIdleConnections       int         `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"` | ||||
| 	MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"` | ||||
| 	Username                 string      `json:"username" mapstructure:"username" structs:"username"` | ||||
| 	Password                 string      `json:"password" mapstructure:"password" structs:"password"` | ||||
| 	AuthType                 string      `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"` | ||||
| 	ServiceAccountJSON       string      `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"` | ||||
| 	DisableEscaping          bool        `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"` | ||||
| 	usePrivateIP             bool        `json:"use_private_ip" mapstructure:"use_private_ip" structs:"use_private_ip"` | ||||
|  | ||||
| 	// cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime | ||||
| 	// Username/Password is the default auth type when AuthType is not set | ||||
| 	Username string `json:"username" mapstructure:"username" structs:"username"` | ||||
| 	Password string `json:"password" mapstructure:"password" structs:"password"` | ||||
|  | ||||
| 	// AuthType defines the type of client authenticate used for this connection | ||||
| 	AuthType           string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"` | ||||
| 	ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"` | ||||
| 	TLSConfig          *tls.Config | ||||
|  | ||||
| 	// cloudDriverName is globally unique, but only needs to be retained for the lifetime | ||||
| 	// of driver registration, not across plugin restarts. | ||||
| 	cloudDriverName    string | ||||
| 	cloudDialerCleanup func() error | ||||
| @@ -125,15 +137,11 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf | ||||
| 		return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err) | ||||
| 	} | ||||
|  | ||||
| 	// validate auth_type if provided | ||||
| 	authType := c.AuthType | ||||
| 	if authType != "" { | ||||
| 		if ok := ValidateAuthType(authType); !ok { | ||||
| 			return nil, fmt.Errorf("invalid auth_type %s provided", authType) | ||||
| 		} | ||||
| 	if ok := ValidateAuthType(c.AuthType); !ok { | ||||
| 		return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType) | ||||
| 	} | ||||
|  | ||||
| 	if authType == AuthTypeGCPIAM { | ||||
| 	if c.AuthType == AuthTypeGCPIAM { | ||||
| 		c.cloudDriverName, err = uuid.GenerateUUID() | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err) | ||||
| @@ -161,7 +169,7 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf | ||||
| 		} | ||||
|  | ||||
| 		if err := c.db.PingContext(ctx); err != nil { | ||||
| 			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) | ||||
| 			return nil, errwrap.Wrapf("error verifying connection: ping failed: {{err}}", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -219,16 +227,42 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var err error | ||||
| 	if driverName == "pgx" && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" { | ||||
| 		// TODO: remove this deprecated function call in a future SDK version | ||||
| 		c.db, err = OpenPostgres(driverName, conn) | ||||
| 	} else { | ||||
| 		c.db, err = sql.Open(driverName, conn) | ||||
| 	} | ||||
| 	if driverName == dbTypePostgres && c.TLSConfig != nil { | ||||
| 		config, err := pgx.ParseConfig(conn) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to parse config: %w", err) | ||||
| 		} | ||||
| 		if config.TLSConfig == nil { | ||||
| 			// handle sslmode=disable | ||||
| 			config.TLSConfig = &tls.Config{} | ||||
| 		} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 		config.TLSConfig.RootCAs = c.TLSConfig.RootCAs | ||||
| 		config.TLSConfig.ClientCAs = c.TLSConfig.ClientCAs | ||||
| 		config.TLSConfig.Certificates = c.TLSConfig.Certificates | ||||
|  | ||||
| 		// Ensure there are no stale fallbacks when manually setting TLSConfig | ||||
| 		for _, fallback := range config.Fallbacks { | ||||
| 			fallback.TLSConfig = config.TLSConfig | ||||
| 		} | ||||
|  | ||||
| 		c.db = stdlib.OpenDB(*config) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to open connection: %w", err) | ||||
| 		} | ||||
| 	} else if driverName == dbTypePostgres && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" { | ||||
| 		var err error | ||||
| 		// TODO: remove this deprecated function call in a future SDK version | ||||
| 		c.db, err = openPostgres(driverName, conn) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to open connection: %w", err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		var err error | ||||
| 		c.db, err = sql.Open(driverName, conn) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to open connection: %w", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Set some connection pool settings. We don't need much of this, | ||||
| @@ -277,3 +311,13 @@ func (c *SQLConnectionProducer) Close() error { | ||||
| func (c *SQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { | ||||
| 	return "", "", dbutil.Unimplemented() | ||||
| } | ||||
|  | ||||
| var configurableAuthTypes = map[string]bool{ | ||||
| 	AuthTypeUsernamePassword: true, | ||||
| 	AuthTypeCert:             true, | ||||
| 	AuthTypeGCPIAM:           true, | ||||
| } | ||||
|  | ||||
| func ValidateAuthType(authType string) bool { | ||||
| 	return configurableAuthTypes[authType] | ||||
| } | ||||
|   | ||||
| @@ -84,7 +84,7 @@ require ( | ||||
| 	github.com/jackc/pgproto3/v2 v2.3.3 // indirect | ||||
| 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | ||||
| 	github.com/jackc/pgtype v1.14.0 // indirect | ||||
| 	github.com/jackc/pgx/v4 v4.18.3 // indirect | ||||
| 	github.com/jackc/pgx/v4 v4.18.3 | ||||
| 	github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531 // indirect | ||||
| 	github.com/klauspost/compress v1.16.5 // indirect | ||||
| 	github.com/mattn/go-colorable v0.1.13 // indirect | ||||
|   | ||||
		Reference in New Issue
	
	Block a user