mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
database/postgres: add inline certificate authentication fields (#28024)
* add inline cert auth to postres db plugin * handle both sslinline and new TLS plugin fields * refactor PrepareTestContainerWithSSL * add tests for postgres inline TLS fields * changelog * revert back to errwrap since the middleware sanitizing depends on it * enable only setting sslrootcert
This commit is contained in:
committed by
GitHub
parent
a19195c901
commit
3fcb1a67c5
@@ -123,11 +123,8 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
|
||||
}
|
||||
|
||||
// validate auth_type if provided
|
||||
authType := c.AuthType
|
||||
if authType != "" {
|
||||
if ok := connutil.ValidateAuthType(authType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
|
||||
}
|
||||
if ok := connutil.ValidateAuthType(c.AuthType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType)
|
||||
}
|
||||
|
||||
if c.AuthType == connutil.AuthTypeGCPIAM {
|
||||
|
||||
@@ -5,7 +5,11 @@ package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -79,11 +83,65 @@ func new() *PostgreSQL {
|
||||
type PostgreSQL struct {
|
||||
*connutil.SQLConnectionProducer
|
||||
|
||||
TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"`
|
||||
TLSPrivateKey []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"`
|
||||
TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"`
|
||||
|
||||
usernameProducer template.StringTemplate
|
||||
passwordAuthentication passwordAuthentication
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
||||
sslcert, err := strutil.GetString(req.Config, "tls_certificate")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err)
|
||||
}
|
||||
|
||||
sslkey, err := strutil.GetString(req.Config, "tls_private_key")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err)
|
||||
}
|
||||
|
||||
sslrootcert, err := strutil.GetString(req.Config, "tls_ca")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_ca: %w", err)
|
||||
}
|
||||
|
||||
useTLS := false
|
||||
tlsConfig := &tls.Config{}
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
|
||||
return dbplugin.InitializeResponse{}, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
p.TLSConfig = tlsConfig
|
||||
useTLS = true
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return dbplugin.InitializeResponse{}, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
block, _ := pem.Decode([]byte(sslkey))
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(sslcert), pem.EncodeToMemory(block))
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to load cert: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
p.TLSConfig = tlsConfig
|
||||
useTLS = true
|
||||
}
|
||||
|
||||
if !useTLS {
|
||||
// set to nil to flag that this connection does not use a custom TLS config
|
||||
p.TLSConfig = nil
|
||||
}
|
||||
|
||||
newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, err
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
||||
@@ -86,15 +87,18 @@ func TestPostgreSQL_InitializeMultiHost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
|
||||
// TestPostgreSQL_InitializeSSLInlineFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
|
||||
// flag guards against unwanted usage of the deprecated SSL client authentication path.
|
||||
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||
func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
|
||||
func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) {
|
||||
// set the flag to true so we can call PrepareTestContainerWithSSL
|
||||
// which does a validation check on the connection
|
||||
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
|
||||
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false)
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
|
||||
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
type testCase struct {
|
||||
@@ -166,11 +170,11 @@ func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
|
||||
// TestPostgreSQL_InitializeSSLInline tests that we can successfully authenticate
|
||||
// with a postgres server via ssl with a URL connection string or DSN (key/value)
|
||||
// for each ssl mode.
|
||||
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||
func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
func TestPostgreSQL_InitializeSSLInline(t *testing.T) {
|
||||
// required to enable the sslinline custom parsing
|
||||
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
|
||||
|
||||
@@ -287,7 +291,11 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback)
|
||||
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
|
||||
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
if test.useDSN {
|
||||
@@ -326,6 +334,188 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
|
||||
// with a postgres server via ssl with a URL connection string or DSN (key/value)
|
||||
// for each ssl mode.
|
||||
func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
type testCase struct {
|
||||
sslMode string
|
||||
useDSN bool
|
||||
useFallback bool
|
||||
wantErr bool
|
||||
expectedError string
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"disable sslmode": {
|
||||
sslMode: "disable",
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode": {
|
||||
sslMode: "allow",
|
||||
wantErr: false,
|
||||
},
|
||||
"prefer sslmode": {
|
||||
sslMode: "prefer",
|
||||
wantErr: false,
|
||||
},
|
||||
"require sslmode": {
|
||||
sslMode: "require",
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-ca sslmode": {
|
||||
sslMode: "verify-ca",
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-full sslmode": {
|
||||
sslMode: "verify-full",
|
||||
wantErr: false,
|
||||
},
|
||||
"disable sslmode with DSN": {
|
||||
sslMode: "disable",
|
||||
useDSN: true,
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode with DSN": {
|
||||
sslMode: "allow",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"prefer sslmode with DSN": {
|
||||
sslMode: "prefer",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"require sslmode with DSN": {
|
||||
sslMode: "require",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-ca sslmode with DSN": {
|
||||
sslMode: "verify-ca",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-full sslmode with DSN": {
|
||||
sslMode: "verify-full",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"disable sslmode with fallback": {
|
||||
sslMode: "disable",
|
||||
useFallback: true,
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode with fallback": {
|
||||
sslMode: "allow",
|
||||
useFallback: true,
|
||||
},
|
||||
"prefer sslmode with fallback": {
|
||||
sslMode: "prefer",
|
||||
useFallback: true,
|
||||
},
|
||||
"require sslmode with fallback": {
|
||||
sslMode: "require",
|
||||
useFallback: true,
|
||||
},
|
||||
"verify-ca sslmode with fallback": {
|
||||
sslMode: "verify-ca",
|
||||
useFallback: true,
|
||||
},
|
||||
"verify-full sslmode with fallback": {
|
||||
sslMode: "verify-full",
|
||||
useFallback: true,
|
||||
},
|
||||
"disable sslmode with DSN with fallback": {
|
||||
sslMode: "disable",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode with DSN with fallback": {
|
||||
sslMode: "allow",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"prefer sslmode with DSN with fallback": {
|
||||
sslMode: "prefer",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"require sslmode with DSN with fallback": {
|
||||
sslMode: "require",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-ca sslmode with DSN with fallback": {
|
||||
sslMode: "verify-ca",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-full sslmode with DSN with fallback": {
|
||||
sslMode: "verify-full",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
|
||||
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
if test.useDSN {
|
||||
var err error
|
||||
connURL, err = dbutil.ParseURL(connURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"max_open_connections": 5,
|
||||
"tls_certificate": string(clientCert.CombinedPEM()),
|
||||
"tls_private_key": string(clientCert.PrivateKeyPEM()),
|
||||
"tls_ca": string(caCert.CombinedPEM()),
|
||||
}
|
||||
|
||||
req := dbplugin.InitializeRequest{
|
||||
Config: connectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
_, err := dbtesting.VerifyInitialize(t, db, req)
|
||||
if test.wantErr && err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) {
|
||||
t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError)
|
||||
}
|
||||
|
||||
if !test.wantErr && !db.Initialized {
|
||||
t.Fatal("Database should be initialized")
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
|
||||
db, cleanup := getPostgreSQL(t, map[string]interface{}{
|
||||
"max_open_connections": "5",
|
||||
|
||||
Reference in New Issue
Block a user