database/postgres: add inline certificate authentication fields (#28024)

* add inline cert auth to postres db plugin

* handle both sslinline and new TLS plugin fields

* refactor PrepareTestContainerWithSSL

* add tests for postgres inline TLS fields

* changelog

* revert back to errwrap since the middleware sanitizing depends on it

* enable only setting sslrootcert
This commit is contained in:
John-Michael Faircloth
2024-08-09 14:20:19 -05:00
committed by GitHub
parent a19195c901
commit 3fcb1a67c5
10 changed files with 374 additions and 97 deletions

View File

@@ -123,11 +123,8 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
}
// validate auth_type if provided
authType := c.AuthType
if authType != "" {
if ok := connutil.ValidateAuthType(authType); !ok {
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
}
if ok := connutil.ValidateAuthType(c.AuthType); !ok {
return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType)
}
if c.AuthType == connutil.AuthTypeGCPIAM {

View File

@@ -5,7 +5,11 @@ package postgresql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"regexp"
"strings"
@@ -79,11 +83,65 @@ func new() *PostgreSQL {
type PostgreSQL struct {
*connutil.SQLConnectionProducer
TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"`
TLSPrivateKey []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"`
TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"`
usernameProducer template.StringTemplate
passwordAuthentication passwordAuthentication
}
func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
sslcert, err := strutil.GetString(req.Config, "tls_certificate")
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err)
}
sslkey, err := strutil.GetString(req.Config, "tls_private_key")
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err)
}
sslrootcert, err := strutil.GetString(req.Config, "tls_ca")
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_ca: %w", err)
}
useTLS := false
tlsConfig := &tls.Config{}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
return dbplugin.InitializeResponse{}, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
p.TLSConfig = tlsConfig
useTLS = true
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return dbplugin.InitializeResponse{}, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
block, _ := pem.Decode([]byte(sslkey))
cert, err := tls.X509KeyPair([]byte(sslcert), pem.EncodeToMemory(block))
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to load cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
p.TLSConfig = tlsConfig
useTLS = true
}
if !useTLS {
// set to nil to flag that this connection does not use a custom TLS config
p.TLSConfig = nil
}
newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return dbplugin.InitializeResponse{}, err

View File

@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
@@ -86,15 +87,18 @@ func TestPostgreSQL_InitializeMultiHost(t *testing.T) {
}
}
// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
// TestPostgreSQL_InitializeSSLInlineFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
// flag guards against unwanted usage of the deprecated SSL client authentication path.
// TODO: remove this when we remove the underlying feature in a future SDK version
func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) {
// set the flag to true so we can call PrepareTestContainerWithSSL
// which does a validation check on the connection
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false)
// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false)
t.Cleanup(cleanup)
type testCase struct {
@@ -166,11 +170,11 @@ func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
}
}
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
// TestPostgreSQL_InitializeSSLInline tests that we can successfully authenticate
// with a postgres server via ssl with a URL connection string or DSN (key/value)
// for each ssl mode.
// TODO: remove this when we remove the underlying feature in a future SDK version
func TestPostgreSQL_InitializeSSL(t *testing.T) {
func TestPostgreSQL_InitializeSSLInline(t *testing.T) {
// required to enable the sslinline custom parsing
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
@@ -287,7 +291,11 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback)
// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
t.Cleanup(cleanup)
if test.useDSN {
@@ -326,6 +334,188 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
}
}
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
// with a postgres server via ssl with a URL connection string or DSN (key/value)
// for each ssl mode.
func TestPostgreSQL_InitializeSSL(t *testing.T) {
type testCase struct {
sslMode string
useDSN bool
useFallback bool
wantErr bool
expectedError string
}
tests := map[string]testCase{
"disable sslmode": {
sslMode: "disable",
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode": {
sslMode: "allow",
wantErr: false,
},
"prefer sslmode": {
sslMode: "prefer",
wantErr: false,
},
"require sslmode": {
sslMode: "require",
wantErr: false,
},
"verify-ca sslmode": {
sslMode: "verify-ca",
wantErr: false,
},
"verify-full sslmode": {
sslMode: "verify-full",
wantErr: false,
},
"disable sslmode with DSN": {
sslMode: "disable",
useDSN: true,
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode with DSN": {
sslMode: "allow",
useDSN: true,
wantErr: false,
},
"prefer sslmode with DSN": {
sslMode: "prefer",
useDSN: true,
wantErr: false,
},
"require sslmode with DSN": {
sslMode: "require",
useDSN: true,
wantErr: false,
},
"verify-ca sslmode with DSN": {
sslMode: "verify-ca",
useDSN: true,
wantErr: false,
},
"verify-full sslmode with DSN": {
sslMode: "verify-full",
useDSN: true,
wantErr: false,
},
"disable sslmode with fallback": {
sslMode: "disable",
useFallback: true,
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode with fallback": {
sslMode: "allow",
useFallback: true,
},
"prefer sslmode with fallback": {
sslMode: "prefer",
useFallback: true,
},
"require sslmode with fallback": {
sslMode: "require",
useFallback: true,
},
"verify-ca sslmode with fallback": {
sslMode: "verify-ca",
useFallback: true,
},
"verify-full sslmode with fallback": {
sslMode: "verify-full",
useFallback: true,
},
"disable sslmode with DSN with fallback": {
sslMode: "disable",
useDSN: true,
useFallback: true,
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode with DSN with fallback": {
sslMode: "allow",
useDSN: true,
useFallback: true,
wantErr: false,
},
"prefer sslmode with DSN with fallback": {
sslMode: "prefer",
useDSN: true,
useFallback: true,
wantErr: false,
},
"require sslmode with DSN with fallback": {
sslMode: "require",
useDSN: true,
useFallback: true,
wantErr: false,
},
"verify-ca sslmode with DSN with fallback": {
sslMode: "verify-ca",
useDSN: true,
useFallback: true,
wantErr: false,
},
"verify-full sslmode with DSN with fallback": {
sslMode: "verify-full",
useDSN: true,
useFallback: true,
wantErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
t.Cleanup(cleanup)
if test.useDSN {
var err error
connURL, err = dbutil.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
}
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"max_open_connections": 5,
"tls_certificate": string(clientCert.CombinedPEM()),
"tls_private_key": string(clientCert.PrivateKeyPEM()),
"tls_ca": string(caCert.CombinedPEM()),
}
req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}
db := new()
_, err := dbtesting.VerifyInitialize(t, db, req)
if test.wantErr && err == nil {
t.Fatal("expected error, got nil")
} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) {
t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError)
}
if !test.wantErr && !db.Initialized {
t.Fatal("Database should be initialized")
}
if err := db.Close(); err != nil {
t.Fatalf("err: %s", err)
}
})
}
}
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
db, cleanup := getPostgreSQL(t, map[string]interface{}{
"max_open_connections": "5",