diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index eb14d78443..f5ba7246fc 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -23,6 +23,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/testhelpers/certhelpers" "github.com/hashicorp/vault/helper/testhelpers/corehelpers" postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql" vaulthttp "github.com/hashicorp/vault/http" @@ -660,6 +661,7 @@ func (s *singletonDBFactory) factory(context.Context, *logical.BackendConfig) (l } func TestBackend_connectionCrud(t *testing.T) { + t.Parallel() dbFactory := &singletonDBFactory{} cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory) defer cluster.Cleanup() @@ -717,7 +719,6 @@ func TestBackend_connectionCrud(t *testing.T) { "allowed_roles": []string{"plugin-role-test"}, "username": "postgres", "password": "secret", - "private_key": "PRIVATE_KEY", }) if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) @@ -738,9 +739,6 @@ func TestBackend_connectionCrud(t *testing.T) { if _, exists := returnedConnectionDetails["password"]; exists { t.Fatal("password should NOT be found in the returned config") } - if _, exists := returnedConnectionDetails["private_key"]; exists { - t.Fatal("private_key should NOT be found in the returned config") - } // Replace connection url with templated version templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}") @@ -750,7 +748,6 @@ func TestBackend_connectionCrud(t *testing.T) { "allowed_roles": []string{"plugin-role-test"}, "username": "postgres", "password": "secret", - "private_key": "PRIVATE_KEY", }) if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) @@ -856,6 +853,57 @@ func TestBackend_connectionCrud(t *testing.T) { } } +func TestBackend_connectionSanitizePrivateKey(t *testing.T) { + t.Parallel() + dbFactory := &singletonDBFactory{} + cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory) + defer cluster.Cleanup() + + dbFactory.sys = sys + client := cluster.Cores[0].Client.Logical() + + cleanup, connURL := postgreshelper.PrepareTestContainer(t) + defer cleanup() + + // Mount the database plugin. + resp, err := client.Write("sys/mounts/database", map[string]interface{}{ + "type": "database", + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + + // Create a connection + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, + "username": "postgres", + "tls_certificate": string(clientCert.CombinedPEM()), + "private_key": string(clientCert.PrivateKeyPEM()), + "tls_ca": string(caCert.CombinedPEM()), + "verify_connection": false, + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + resp, err = client.Read("database/config/plugin-test") + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{}) + if strings.Contains(returnedConnectionDetails["connection_url"].(string), "secret") { + t.Fatal("password should not be found in the connection url") + } + if _, exists := returnedConnectionDetails["private_key"]; exists { + t.Fatal("private_key should NOT be found in the returned config") + } +} + func TestBackend_roleCrud(t *testing.T) { cluster, sys := getClusterPostgresDB(t) defer cluster.Cleanup() diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index a9279a2867..3a9bf7ae7c 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -84,7 +84,7 @@ type PostgreSQL struct { *connutil.SQLConnectionProducer TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"` - TLSPrivateKey []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"` + TLSPrivateKey []byte `json:"private_key" structs:"-" mapstructure:"private_key"` TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"` usernameProducer template.StringTemplate @@ -97,9 +97,9 @@ func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequ return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err) } - sslkey, err := strutil.GetString(req.Config, "tls_private_key") + sslkey, err := strutil.GetString(req.Config, "private_key") if err != nil { - return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err) + return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve private_key: %w", err) } sslrootcert, err := strutil.GetString(req.Config, "tls_ca") diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index e9d4efd20e..ba150870c7 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -488,7 +488,7 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) { "connection_url": connURL, "max_open_connections": 5, "tls_certificate": string(clientCert.CombinedPEM()), - "tls_private_key": string(clientCert.PrivateKeyPEM()), + "private_key": string(clientCert.PrivateKeyPEM()), "tls_ca": string(caCert.CombinedPEM()), }