mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
db/postgres: add feature flag protected sslinline configuration (#27871)
* adds sslinline option to postgres conn string * for database secrets type postgres, inspects the connection string for sslinline and generates a tlsconfig from the connection string. * support fallback hosts * remove broken multihost test * bootstrap container with cert material * overwrite pg config and set key file perms * add feature flag check * add tests * add license and comments * test all ssl modes * add test cases for dsn (key/value) connection strings * add fallback test cases * fix error formatting * add test for multi-host when using pgx native conn url parsing --------- Co-authored-by: Branden Horiuchi <Branden.Horiuchi@blackline.com>
This commit is contained in:
committed by
GitHub
parent
10068ffb0a
commit
899ebd4aff
@@ -1451,7 +1451,7 @@ func TestBackend_ConnectionURL_redacted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cleanup, u := postgreshelper.PrepareTestContainerWithPassword(t, "13.4-buster", tt.password)
|
cleanup, u := postgreshelper.PrepareTestContainerWithPassword(t, tt.password)
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
p, err := url.Parse(u)
|
p, err := url.Parse(u)
|
||||||
|
|||||||
@@ -11,22 +11,29 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||||
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/docker"
|
"github.com/hashicorp/vault/sdk/helper/docker"
|
||||||
)
|
)
|
||||||
|
|
||||||
const postgresVersion = "13.4-buster"
|
const (
|
||||||
|
defaultPGImage = "docker.mirror.hashicorp.services/postgres"
|
||||||
|
defaultPGVersion = "13.4-buster"
|
||||||
|
defaultPGPass = "secret"
|
||||||
|
)
|
||||||
|
|
||||||
func defaultRunOpts(t *testing.T) docker.RunOptions {
|
func defaultRunOpts(t *testing.T) docker.RunOptions {
|
||||||
return docker.RunOptions{
|
return docker.RunOptions{
|
||||||
ContainerName: "postgres",
|
ContainerName: "postgres",
|
||||||
ImageRepo: "docker.mirror.hashicorp.services/postgres",
|
ImageRepo: defaultPGImage,
|
||||||
ImageTag: postgresVersion,
|
ImageTag: defaultPGVersion,
|
||||||
Env: []string{
|
Env: []string{
|
||||||
"POSTGRES_PASSWORD=secret",
|
"POSTGRES_PASSWORD=" + defaultPGPass,
|
||||||
"POSTGRES_DB=database",
|
"POSTGRES_DB=database",
|
||||||
},
|
},
|
||||||
Ports: []string{"5432/tcp"},
|
Ports: []string{"5432/tcp"},
|
||||||
DoNotAutoRemove: false,
|
DoNotAutoRemove: false,
|
||||||
|
OmitLogTimestamps: true,
|
||||||
LogConsumer: func(s string) {
|
LogConsumer: func(s string) {
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
t.Logf("container logs: %s", s)
|
t.Logf("container logs: %s", s)
|
||||||
@@ -36,7 +43,13 @@ func defaultRunOpts(t *testing.T) docker.RunOptions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PrepareTestContainer(t *testing.T) (func(), string) {
|
func PrepareTestContainer(t *testing.T) (func(), string) {
|
||||||
_, cleanup, url, _ := prepareTestContainer(t, defaultRunOpts(t), "secret", true, false)
|
_, cleanup, url, _ := prepareTestContainer(t, defaultRunOpts(t), defaultPGPass, true, false, false)
|
||||||
|
|
||||||
|
return cleanup, url
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareTestContainerMultiHost(t *testing.T) (func(), string) {
|
||||||
|
_, cleanup, url, _ := prepareTestContainer(t, defaultRunOpts(t), defaultPGPass, true, false, true)
|
||||||
|
|
||||||
return cleanup, url
|
return cleanup, url
|
||||||
}
|
}
|
||||||
@@ -45,90 +58,138 @@ func PrepareTestContainer(t *testing.T) (func(), string) {
|
|||||||
// admin user configured so that we can safely call rotate-root without
|
// admin user configured so that we can safely call rotate-root without
|
||||||
// rotating the root DB credentials
|
// rotating the root DB credentials
|
||||||
func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func(), string) {
|
func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func(), string) {
|
||||||
runner, cleanup, url, id := prepareTestContainer(t, defaultRunOpts(t), "secret", true, false)
|
runner, cleanup, url, id := prepareTestContainer(t, defaultRunOpts(t), defaultPGPass, true, false, false)
|
||||||
|
|
||||||
cmd := []string{"psql", "-U", "postgres", "-c", "CREATE USER vaultadmin WITH LOGIN PASSWORD 'vaultpass' SUPERUSER"}
|
cmd := []string{"psql", "-U", "postgres", "-c", "CREATE USER vaultadmin WITH LOGIN PASSWORD 'vaultpass' SUPERUSER"}
|
||||||
_, err := runner.RunCmdInBackground(ctx, id, cmd)
|
mustRunCommand(t, ctx, runner, id, cmd)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Could not run command (%v) in container: %v", cmd, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cleanup, url
|
return cleanup, url
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, version string) (func(), string) {
|
// PrepareTestContainerWithSSL will setup a test container with SSL enabled so
|
||||||
|
// that we can test client certificate authentication.
|
||||||
|
func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) {
|
||||||
runOpts := defaultRunOpts(t)
|
runOpts := defaultRunOpts(t)
|
||||||
runOpts.Cmd = []string{"-c", "log_statement=all"}
|
runner, err := docker.NewServiceRunner(runOpts)
|
||||||
runner, cleanup, url, id := prepareTestContainer(t, runOpts, "secret", true, false)
|
|
||||||
|
|
||||||
content := "echo 'hostssl all all all cert clientcert=verify-ca' > /var/lib/postgresql/data/pg_hba.conf"
|
|
||||||
// Copy the ssl init script into the newly running container.
|
|
||||||
buildCtx := docker.NewBuildContext()
|
|
||||||
buildCtx["ssl-conf.sh"] = docker.PathContentsFromBytes([]byte(content))
|
|
||||||
if err := runner.CopyTo(id, "/usr/local/bin", buildCtx); err != nil {
|
|
||||||
t.Fatalf("Could not copy ssl init script into container: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// run the ssl init script to overwrite the pg_hba.conf file and set it to
|
|
||||||
// require SSL for each connection
|
|
||||||
cmd := []string{"bash", "/usr/local/bin/ssl-conf.sh"}
|
|
||||||
_, err := runner.RunCmdInBackground(ctx, id, cmd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not run command (%v) in container: %v", cmd, err)
|
t.Fatalf("Could not provision docker service runner: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reload so the config changes take effect
|
// first we connect with username/password because ssl is not enabled yet
|
||||||
cmd = []string{"psql", "-U", "postgres", "-c", "SELECT pg_reload_conf()"}
|
svc, id, err := runner.StartNewService(context.Background(), true, false, connectPostgres(defaultPGPass, runOpts.ImageRepo, false))
|
||||||
_, err = runner.RunCmdInBackground(ctx, id, cmd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not run command (%v) in container: %v", cmd, err)
|
t.Fatalf("Could not start docker Postgres: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cleanup, url
|
// Create certificates for postgres authentication
|
||||||
|
caCert := certhelpers.NewCert(t,
|
||||||
|
certhelpers.CommonName("ca"),
|
||||||
|
certhelpers.IsCA(true),
|
||||||
|
certhelpers.SelfSign(),
|
||||||
|
)
|
||||||
|
serverCert := certhelpers.NewCert(t,
|
||||||
|
certhelpers.CommonName("server"),
|
||||||
|
certhelpers.DNS("localhost"),
|
||||||
|
certhelpers.Parent(caCert),
|
||||||
|
)
|
||||||
|
clientCert := certhelpers.NewCert(t,
|
||||||
|
certhelpers.CommonName("postgres"),
|
||||||
|
certhelpers.DNS("localhost"),
|
||||||
|
certhelpers.Parent(caCert),
|
||||||
|
)
|
||||||
|
|
||||||
|
bCtx := docker.NewBuildContext()
|
||||||
|
bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM())
|
||||||
|
bCtx["server.crt"] = docker.PathContentsFromBytes(serverCert.CombinedPEM())
|
||||||
|
bCtx["server.key"] = &docker.FileContents{
|
||||||
|
Data: serverCert.PrivateKeyPEM(),
|
||||||
|
Mode: 0o600,
|
||||||
|
// postgres uid
|
||||||
|
UID: 999,
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrepareTestContainerWithPassword(t *testing.T, version, password string) (func(), string) {
|
// https://www.postgresql.org/docs/current/auth-pg-hba-conf.html
|
||||||
|
clientAuthConfig := "echo 'hostssl all all all cert clientcert=verify-ca' > /var/lib/postgresql/data/pg_hba.conf"
|
||||||
|
bCtx["ssl-conf.sh"] = docker.PathContentsFromString(clientAuthConfig)
|
||||||
|
pgConfig := `
|
||||||
|
cat << EOF > /var/lib/postgresql/data/postgresql.conf
|
||||||
|
# PostgreSQL configuration file
|
||||||
|
listen_addresses = '*'
|
||||||
|
max_connections = 100
|
||||||
|
shared_buffers = 128MB
|
||||||
|
dynamic_shared_memory_type = posix
|
||||||
|
max_wal_size = 1GB
|
||||||
|
min_wal_size = 80MB
|
||||||
|
ssl = on
|
||||||
|
ssl_ca_file = '/var/lib/postgresql/ca.crt'
|
||||||
|
ssl_cert_file = '/var/lib/postgresql/server.crt'
|
||||||
|
ssl_key_file= '/var/lib/postgresql/server.key'
|
||||||
|
EOF
|
||||||
|
`
|
||||||
|
bCtx["pg-conf.sh"] = docker.PathContentsFromString(pgConfig)
|
||||||
|
|
||||||
|
err = runner.CopyTo(id, "/var/lib/postgresql/", bCtx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to copy to container: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// overwrite the postgresql.conf config file with our ssl settings
|
||||||
|
mustRunCommand(t, ctx, runner, id,
|
||||||
|
[]string{"bash", "/var/lib/postgresql/pg-conf.sh"})
|
||||||
|
|
||||||
|
// overwrite the pg_hba.conf file and set it to require SSL for each connection
|
||||||
|
mustRunCommand(t, ctx, runner, id,
|
||||||
|
[]string{"bash", "/var/lib/postgresql/ssl-conf.sh"})
|
||||||
|
|
||||||
|
// reload so the config changes take effect and ssl is enabled
|
||||||
|
mustRunCommand(t, ctx, runner, id,
|
||||||
|
[]string{"psql", "-U", "postgres", "-c", "SELECT pg_reload_conf()"})
|
||||||
|
|
||||||
|
if sslMode == "disable" {
|
||||||
|
// return non-tls connection url
|
||||||
|
return svc.Cleanup, svc.Config.URL().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
sslConfig, err := connectPostgresSSL(
|
||||||
|
t,
|
||||||
|
svc.Config.URL().Host,
|
||||||
|
sslMode,
|
||||||
|
string(caCert.CombinedPEM()),
|
||||||
|
string(clientCert.CombinedPEM()),
|
||||||
|
string(clientCert.PrivateKeyPEM()),
|
||||||
|
useFallback,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
svc.Cleanup()
|
||||||
|
t.Fatalf("failed to connect to postgres container via SSL: %v", err)
|
||||||
|
}
|
||||||
|
return svc.Cleanup, sslConfig.URL().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareTestContainerWithPassword(t *testing.T, password string) (func(), string) {
|
||||||
runOpts := defaultRunOpts(t)
|
runOpts := defaultRunOpts(t)
|
||||||
runOpts.Env = []string{
|
runOpts.Env = []string{
|
||||||
"POSTGRES_PASSWORD=" + password,
|
"POSTGRES_PASSWORD=" + password,
|
||||||
"POSTGRES_DB=database",
|
"POSTGRES_DB=database",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, cleanup, url, _ := prepareTestContainer(t, runOpts, password, true, false)
|
_, cleanup, url, _ := prepareTestContainer(t, runOpts, password, true, false, false)
|
||||||
|
|
||||||
return cleanup, url
|
return cleanup, url
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrepareTestContainerRepmgr(t *testing.T, name, version string, envVars []string) (*docker.Runner, func(), string, string) {
|
func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password string, addSuffix, forceLocalAddr, useFallback bool,
|
||||||
runOpts := defaultRunOpts(t)
|
|
||||||
runOpts.ImageRepo = "docker.mirror.hashicorp.services/bitnami/postgresql-repmgr"
|
|
||||||
runOpts.ImageTag = version
|
|
||||||
runOpts.Env = append(envVars,
|
|
||||||
"REPMGR_PARTNER_NODES=psql-repl-node-0,psql-repl-node-1",
|
|
||||||
"REPMGR_PRIMARY_HOST=psql-repl-node-0",
|
|
||||||
"REPMGR_PASSWORD=repmgrpass",
|
|
||||||
"POSTGRESQL_PASSWORD=secret")
|
|
||||||
runOpts.DoNotAutoRemove = true
|
|
||||||
|
|
||||||
return prepareTestContainer(t, runOpts, "secret", false, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password string, addSuffix, forceLocalAddr bool,
|
|
||||||
) (*docker.Runner, func(), string, string) {
|
) (*docker.Runner, func(), string, string) {
|
||||||
if os.Getenv("PG_URL") != "" {
|
if os.Getenv("PG_URL") != "" {
|
||||||
return nil, func() {}, "", os.Getenv("PG_URL")
|
return nil, func() {}, "", os.Getenv("PG_URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
if runOpts.ImageRepo == "bitnami/postgresql-repmgr" {
|
|
||||||
runOpts.NetworkID = os.Getenv("POSTGRES_MULTIHOST_NET")
|
|
||||||
}
|
|
||||||
|
|
||||||
runner, err := docker.NewServiceRunner(runOpts)
|
runner, err := docker.NewServiceRunner(runOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not start docker Postgres: %s", err)
|
t.Fatalf("Could not start docker Postgres: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, containerID, err := runner.StartNewService(context.Background(), addSuffix, forceLocalAddr, connectPostgres(password, runOpts.ImageRepo))
|
svc, containerID, err := runner.StartNewService(context.Background(), addSuffix, forceLocalAddr, connectPostgres(password, runOpts.ImageRepo, useFallback))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not start docker Postgres: %s", err)
|
t.Fatalf("Could not start docker Postgres: %s", err)
|
||||||
}
|
}
|
||||||
@@ -136,12 +197,55 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri
|
|||||||
return runner, svc.Cleanup, svc.Config.URL().String(), containerID
|
return runner, svc.Cleanup, svc.Config.URL().String(), containerID
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectPostgres(password, repo string) docker.ServiceAdapter {
|
// connectPostgresSSL is used to verify the connection of our test container
|
||||||
|
// and construct the connection string that is used in tests.
|
||||||
|
//
|
||||||
|
// NOTE: The RawQuery component of the url sets the custom sslinline field and
|
||||||
|
// inlines the certificate material in the sslrootcert, sslcert, and sslkey
|
||||||
|
// fields. This feature will be removed in a future version of the SDK.
|
||||||
|
func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) {
|
||||||
|
if useFallback {
|
||||||
|
// set the first host to a bad address so we can test the fallback logic
|
||||||
|
host = "localhost:55," + host
|
||||||
|
}
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "postgres",
|
||||||
|
User: url.User("postgres"),
|
||||||
|
Host: host,
|
||||||
|
Path: "postgres",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"sslmode": {sslMode},
|
||||||
|
"sslinline": {"true"},
|
||||||
|
"sslrootcert": {caCert},
|
||||||
|
"sslcert": {clientCert},
|
||||||
|
"sslkey": {clientKey},
|
||||||
|
}.Encode(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove this deprecated function call in a future SDK version
|
||||||
|
db, err := connutil.OpenPostgres("pgx", u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err = db.Ping(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return docker.NewServiceURL(u), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter {
|
||||||
return func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
return func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
||||||
|
hostAddr := fmt.Sprintf("%s:%d", host, port)
|
||||||
|
if useFallback {
|
||||||
|
// set the first host to a bad address so we can test the fallback logic
|
||||||
|
hostAddr = "localhost:55," + hostAddr
|
||||||
|
}
|
||||||
u := url.URL{
|
u := url.URL{
|
||||||
Scheme: "postgres",
|
Scheme: "postgres",
|
||||||
User: url.UserPassword("postgres", password),
|
User: url.UserPassword("postgres", password),
|
||||||
Host: fmt.Sprintf("%s:%d", host, port),
|
Host: hostAddr,
|
||||||
Path: "postgres",
|
Path: "postgres",
|
||||||
RawQuery: "sslmode=disable",
|
RawQuery: "sslmode=disable",
|
||||||
}
|
}
|
||||||
@@ -170,3 +274,14 @@ func RestartContainer(t *testing.T, ctx context.Context, runner *docker.Runner,
|
|||||||
t.Fatalf("Could not restart docker Postgres: %s", err)
|
t.Fatalf("Could not restart docker Postgres: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustRunCommand(t *testing.T, ctx context.Context, runner *docker.Runner, containerID string, cmd []string) {
|
||||||
|
t.Helper()
|
||||||
|
_, stderr, retcode, err := runner.RunCmdWithOutput(ctx, containerID, cmd)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not run command (%v) in container: %v", cmd, err)
|
||||||
|
}
|
||||||
|
if retcode != 0 || len(stderr) != 0 {
|
||||||
|
t.Fatalf("exit code: %v, stderr: %v", retcode, string(stderr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/docker"
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/template"
|
"github.com/hashicorp/vault/sdk/helper/template"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -58,6 +58,274 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPostgreSQL_InitializeMultiHost tests the functionality of Postgres's
|
||||||
|
// multi-host connection strings.
|
||||||
|
func TestPostgreSQL_InitializeMultiHost(t *testing.T) {
|
||||||
|
cleanup, connURL := postgresql.PrepareTestContainerMultiHost(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
"max_open_connections": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := dbplugin.InitializeRequest{
|
||||||
|
Config: connectionDetails,
|
||||||
|
VerifyConnection: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new()
|
||||||
|
dbtesting.AssertInitialize(t, db, req)
|
||||||
|
|
||||||
|
if !db.Initialized {
|
||||||
|
t.Fatal("Database should be initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Close(); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
|
||||||
|
// flag guards against unwanted usage of the deprecated SSL client authentication path.
|
||||||
|
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||||
|
func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
|
||||||
|
// set the flag to true so we can call PrepareTestContainerWithSSL
|
||||||
|
// which does a validation check on the connection
|
||||||
|
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
|
||||||
|
|
||||||
|
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
env string
|
||||||
|
wantErr bool
|
||||||
|
expectedError string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]testCase{
|
||||||
|
"feature flag is true": {
|
||||||
|
env: "true",
|
||||||
|
wantErr: false,
|
||||||
|
expectedError: "",
|
||||||
|
},
|
||||||
|
"feature flag is unset or empty": {
|
||||||
|
env: "",
|
||||||
|
wantErr: true,
|
||||||
|
// this error is expected because the env var unset means we are
|
||||||
|
// using pgx's native connection string parsing which does not
|
||||||
|
// support inlining of the certificate material in the sslrootcert,
|
||||||
|
// sslcert, and sslkey fields
|
||||||
|
expectedError: "error verifying connection",
|
||||||
|
},
|
||||||
|
"feature flag is false": {
|
||||||
|
env: "false",
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "failed to open postgres connection with deprecated funtion",
|
||||||
|
},
|
||||||
|
"feature flag is invalid": {
|
||||||
|
env: "foo",
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "failed to open postgres connection with deprecated funtion",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
// update the env var with the value we are testing
|
||||||
|
t.Setenv(pluginutil.PluginUsePostgresSSLInline, test.env)
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
"max_open_connections": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := dbplugin.InitializeRequest{
|
||||||
|
Config: connectionDetails,
|
||||||
|
VerifyConnection: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new()
|
||||||
|
_, err := dbtesting.VerifyInitialize(t, db, req)
|
||||||
|
if test.wantErr && err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) {
|
||||||
|
t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !test.wantErr && !db.Initialized {
|
||||||
|
t.Fatal("Database should be initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Close(); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
// unset for the next test case
|
||||||
|
os.Unsetenv(pluginutil.PluginUsePostgresSSLInline)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
|
||||||
|
// with a postgres server via ssl with a URL connection string or DSN (key/value)
|
||||||
|
// for each ssl mode.
|
||||||
|
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||||
|
func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||||
|
// required to enable the sslinline custom parsing
|
||||||
|
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
sslMode string
|
||||||
|
useDSN bool
|
||||||
|
useFallback bool
|
||||||
|
wantErr bool
|
||||||
|
expectedError string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]testCase{
|
||||||
|
"disable sslmode": {
|
||||||
|
sslMode: "disable",
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "error verifying connection",
|
||||||
|
},
|
||||||
|
"allow sslmode": {
|
||||||
|
sslMode: "allow",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"prefer sslmode": {
|
||||||
|
sslMode: "prefer",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"require sslmode": {
|
||||||
|
sslMode: "require",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"verify-ca sslmode": {
|
||||||
|
sslMode: "verify-ca",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"disable sslmode with DSN": {
|
||||||
|
sslMode: "disable",
|
||||||
|
useDSN: true,
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "error verifying connection",
|
||||||
|
},
|
||||||
|
"allow sslmode with DSN": {
|
||||||
|
sslMode: "allow",
|
||||||
|
useDSN: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"prefer sslmode with DSN": {
|
||||||
|
sslMode: "prefer",
|
||||||
|
useDSN: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"require sslmode with DSN": {
|
||||||
|
sslMode: "require",
|
||||||
|
useDSN: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"verify-ca sslmode with DSN": {
|
||||||
|
sslMode: "verify-ca",
|
||||||
|
useDSN: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"disable sslmode with fallback": {
|
||||||
|
sslMode: "disable",
|
||||||
|
useFallback: true,
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "error verifying connection",
|
||||||
|
},
|
||||||
|
"allow sslmode with fallback": {
|
||||||
|
sslMode: "allow",
|
||||||
|
useFallback: true,
|
||||||
|
},
|
||||||
|
"prefer sslmode with fallback": {
|
||||||
|
sslMode: "prefer",
|
||||||
|
useFallback: true,
|
||||||
|
},
|
||||||
|
"require sslmode with fallback": {
|
||||||
|
sslMode: "require",
|
||||||
|
useFallback: true,
|
||||||
|
},
|
||||||
|
"verify-ca sslmode with fallback": {
|
||||||
|
sslMode: "verify-ca",
|
||||||
|
useFallback: true,
|
||||||
|
},
|
||||||
|
"disable sslmode with DSN with fallback": {
|
||||||
|
sslMode: "disable",
|
||||||
|
useDSN: true,
|
||||||
|
useFallback: true,
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "error verifying connection",
|
||||||
|
},
|
||||||
|
"allow sslmode with DSN with fallback": {
|
||||||
|
sslMode: "allow",
|
||||||
|
useDSN: true,
|
||||||
|
useFallback: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"prefer sslmode with DSN with fallback": {
|
||||||
|
sslMode: "prefer",
|
||||||
|
useDSN: true,
|
||||||
|
useFallback: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"require sslmode with DSN with fallback": {
|
||||||
|
sslMode: "require",
|
||||||
|
useDSN: true,
|
||||||
|
useFallback: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"verify-ca sslmode with DSN with fallback": {
|
||||||
|
sslMode: "verify-ca",
|
||||||
|
useDSN: true,
|
||||||
|
useFallback: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
if test.useDSN {
|
||||||
|
var err error
|
||||||
|
connURL, err = dbutil.ParseURL(connURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
"max_open_connections": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := dbplugin.InitializeRequest{
|
||||||
|
Config: connectionDetails,
|
||||||
|
VerifyConnection: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new()
|
||||||
|
_, err := dbtesting.VerifyInitialize(t, db, req)
|
||||||
|
if test.wantErr && err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) {
|
||||||
|
t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !test.wantErr && !db.Initialized {
|
||||||
|
t.Fatal("Database should be initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Close(); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
|
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
|
||||||
db, cleanup := getPostgreSQL(t, map[string]interface{}{
|
db, cleanup := getPostgreSQL(t, map[string]interface{}{
|
||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
@@ -1268,139 +1536,6 @@ func TestNewUser_CloudGCP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a long-running integration test which tests the functionality of Postgres's multi-host
|
|
||||||
// connection strings. It uses two Postgres containers preconfigured with Replication Manager
|
|
||||||
// provided by Bitnami. This test currently does not run in CI and must be run manually. This is
|
|
||||||
// due to the test length, as it requires multiple sleep calls to ensure cluster setup and
|
|
||||||
// primary node failover occurs before the test steps continue.
|
|
||||||
//
|
|
||||||
// To run the test, set the environment variable POSTGRES_MULTIHOST_NET to the value of
|
|
||||||
// a docker network you've preconfigured, e.g.
|
|
||||||
// 'docker network create -d bridge postgres-repmgr'
|
|
||||||
// 'export POSTGRES_MULTIHOST_NET=postgres-repmgr'
|
|
||||||
func TestPostgreSQL_Repmgr(t *testing.T) {
|
|
||||||
_, exists := os.LookupEnv("POSTGRES_MULTIHOST_NET")
|
|
||||||
if !exists {
|
|
||||||
t.Skipf("POSTGRES_MULTIHOST_NET not set, skipping test")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run two postgres-repmgr containers in a replication cluster
|
|
||||||
db0, runner0, url0, container0 := testPostgreSQL_Repmgr_Container(t, "psql-repl-node-0")
|
|
||||||
_, _, url1, _ := testPostgreSQL_Repmgr_Container(t, "psql-repl-node-1")
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Second)
|
|
||||||
|
|
||||||
// Write a read role to the cluster
|
|
||||||
_, err := db0.NewUser(ctx, dbplugin.NewUserRequest{
|
|
||||||
Statements: dbplugin.Statements{
|
|
||||||
Commands: []string{
|
|
||||||
`CREATE ROLE "ro" NOINHERIT;
|
|
||||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "ro";`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("no error expected, got: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open a connection to both databases using the multihost connection string
|
|
||||||
connectionDetails := map[string]interface{}{
|
|
||||||
"connection_url": fmt.Sprintf("postgresql://{{username}}:{{password}}@%s,%s/postgres?target_session_attrs=read-write", getHost(url0), getHost(url1)),
|
|
||||||
"username": "postgres",
|
|
||||||
"password": "secret",
|
|
||||||
}
|
|
||||||
req := dbplugin.InitializeRequest{
|
|
||||||
Config: connectionDetails,
|
|
||||||
VerifyConnection: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
db := new()
|
|
||||||
dbtesting.AssertInitialize(t, db, req)
|
|
||||||
if !db.Initialized {
|
|
||||||
t.Fatal("Database should be initialized")
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Add a user to the cluster, then stop the primary container
|
|
||||||
if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil {
|
|
||||||
t.Fatalf("no error expected, got: %s", err)
|
|
||||||
}
|
|
||||||
postgresql.StopContainer(t, ctx, runner0, container0)
|
|
||||||
|
|
||||||
// Try adding a new user immediately - expect failure as the database
|
|
||||||
// cluster is still switching primaries
|
|
||||||
err = testPostgreSQL_Repmgr_AddUser(ctx, db)
|
|
||||||
if !strings.HasSuffix(err.Error(), "ValidateConnect failed (read only connection)") {
|
|
||||||
t.Fatalf("expected error was not received, got: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(20 * time.Second)
|
|
||||||
|
|
||||||
// Try adding a new user again which should succeed after the sleep
|
|
||||||
// as the primary failover should have finished. Then, restart
|
|
||||||
// the first container which should become a secondary DB.
|
|
||||||
if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil {
|
|
||||||
t.Fatalf("no error expected, got: %s", err)
|
|
||||||
}
|
|
||||||
postgresql.RestartContainer(t, ctx, runner0, container0)
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Second)
|
|
||||||
|
|
||||||
// A final new user to add, which should succeed after the secondary joins.
|
|
||||||
if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil {
|
|
||||||
t.Fatalf("no error expected, got: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Close(); err != nil {
|
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testPostgreSQL_Repmgr_Container(t *testing.T, name string) (*PostgreSQL, *docker.Runner, string, string) {
|
|
||||||
envVars := []string{
|
|
||||||
"REPMGR_NODE_NAME=" + name,
|
|
||||||
"REPMGR_NODE_NETWORK_NAME=" + name,
|
|
||||||
}
|
|
||||||
|
|
||||||
runner, cleanup, connURL, containerID := postgresql.PrepareTestContainerRepmgr(t, name, "13.4.0", envVars)
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
connectionDetails := map[string]interface{}{
|
|
||||||
"connection_url": connURL,
|
|
||||||
}
|
|
||||||
req := dbplugin.InitializeRequest{
|
|
||||||
Config: connectionDetails,
|
|
||||||
VerifyConnection: true,
|
|
||||||
}
|
|
||||||
db := new()
|
|
||||||
dbtesting.AssertInitialize(t, db, req)
|
|
||||||
if !db.Initialized {
|
|
||||||
t.Fatal("Database should be initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Close(); err != nil {
|
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db, runner, connURL, containerID
|
|
||||||
}
|
|
||||||
|
|
||||||
func testPostgreSQL_Repmgr_AddUser(ctx context.Context, db *PostgreSQL) error {
|
|
||||||
_, err := db.NewUser(ctx, dbplugin.NewUserRequest{
|
|
||||||
Statements: dbplugin.Statements{
|
|
||||||
Commands: []string{
|
|
||||||
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}' INHERIT;
|
|
||||||
GRANT ro TO "{{name}}";`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHost(url string) string {
|
func getHost(url string) string {
|
||||||
splitCreds := strings.Split(url, "@")[1]
|
splitCreds := strings.Split(url, "@")[1]
|
||||||
|
|
||||||
|
|||||||
466
sdk/database/helper/connutil/postgres.go
Normal file
466
sdk/database/helper/connutil/postgres.go
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
// Copyright (c) 2019-2021 Jack Christensen
|
||||||
|
|
||||||
|
// MIT License
|
||||||
|
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
// a copy of this software and associated documentation files (the
|
||||||
|
// "Software"), to deal in the Software without restriction, including
|
||||||
|
// without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
// distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
// permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
// the following conditions:
|
||||||
|
|
||||||
|
// The above copyright notice and this permission notice shall be
|
||||||
|
// included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
// Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go
|
||||||
|
|
||||||
|
package connutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
"github.com/jackc/pgx/v4/stdlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenPostgres parses the connection string and opens a connection to the database.
|
||||||
|
//
|
||||||
|
// If sslinline is set, strips the connection string of all ssl settings and
|
||||||
|
// creates a TLS config based on the settings provided, then uses the
|
||||||
|
// RegisterConnConfig function to create a new connection. This is necessary
|
||||||
|
// because the pgx driver does not support the sslinline parameter and instead
|
||||||
|
// expects to source ssl material from the file system.
|
||||||
|
//
|
||||||
|
// Deprecated: OpenPostgres will be removed in a future version of the Vault SDK.
|
||||||
|
func OpenPostgres(driverName, connString string) (*sql.DB, error) {
|
||||||
|
if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok {
|
||||||
|
return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable")
|
||||||
|
}
|
||||||
|
|
||||||
|
var options pgconn.ParseConfigOptions
|
||||||
|
|
||||||
|
settings := make(map[string]string)
|
||||||
|
if connString != "" {
|
||||||
|
var err error
|
||||||
|
// connString may be a database URL or a DSN
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
settings, err = parsePostgresURLSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse as URL: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
settings, err = parsePostgresDSNSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse as DSN: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the inline flag
|
||||||
|
sslInline := settings["sslinline"] == "true"
|
||||||
|
|
||||||
|
// if sslinline is not set, open a regular connection
|
||||||
|
if !sslInline {
|
||||||
|
return sql.Open(driverName, connString)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate a new DSN without the ssl settings
|
||||||
|
newConnStr := []string{"sslmode=disable"}
|
||||||
|
for k, v := range settings {
|
||||||
|
switch k {
|
||||||
|
case "sslinline", "sslcert", "sslkey", "sslrootcert", "sslmode":
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newConnStr = append(newConnStr, fmt.Sprintf("%s='%s'", k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the updated config
|
||||||
|
config, err := pgx.ParseConfig(strings.Join(newConnStr, " "))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a TLS config
|
||||||
|
fallbacks := []*pgconn.FallbackConfig{}
|
||||||
|
|
||||||
|
hosts := strings.Split(settings["host"], ",")
|
||||||
|
ports := strings.Split(settings["port"], ",")
|
||||||
|
|
||||||
|
for i, host := range hosts {
|
||||||
|
var portStr string
|
||||||
|
if i < len(ports) {
|
||||||
|
portStr = ports[i]
|
||||||
|
} else {
|
||||||
|
portStr = ports[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := parsePort(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfigs []*tls.Config
|
||||||
|
|
||||||
|
// Ignore TLS settings if Unix domain socket like libpq
|
||||||
|
if network, _ := pgconn.NetworkAddress(host, port); network == "unix" {
|
||||||
|
tlsConfigs = append(tlsConfigs, nil)
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
tlsConfigs, err = configPostgresTLS(settings, host, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to configure TLS: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tlsConfig := range tlsConfigs {
|
||||||
|
fallbacks = append(fallbacks, &pgconn.FallbackConfig{
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Host = fallbacks[0].Host
|
||||||
|
config.Port = fallbacks[0].Port
|
||||||
|
config.TLSConfig = fallbacks[0].TLSConfig
|
||||||
|
config.Fallbacks = fallbacks[1:]
|
||||||
|
|
||||||
|
return sql.Open(driverName, stdlib.RegisterConnConfig(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
// configPostgresTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||||
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||||
|
// "prefer" allow fallback.
|
||||||
|
//
|
||||||
|
// Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go
|
||||||
|
// and modified to read ssl material by value instead of file location.
|
||||||
|
func configPostgresTLS(settings map[string]string, thisHost string, parseConfigOptions pgconn.ParseConfigOptions) ([]*tls.Config, error) {
|
||||||
|
host := thisHost
|
||||||
|
sslmode := settings["sslmode"]
|
||||||
|
sslrootcert := settings["sslrootcert"]
|
||||||
|
sslcert := settings["sslcert"]
|
||||||
|
sslkey := settings["sslkey"]
|
||||||
|
sslpassword := settings["sslpassword"]
|
||||||
|
sslsni := settings["sslsni"]
|
||||||
|
|
||||||
|
// Match libpq default behavior
|
||||||
|
if sslmode == "" {
|
||||||
|
sslmode = "prefer"
|
||||||
|
}
|
||||||
|
if sslsni == "" {
|
||||||
|
sslsni = "1"
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "disable":
|
||||||
|
return []*tls.Config{nil}, nil
|
||||||
|
case "allow", "prefer":
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
case "require":
|
||||||
|
// According to PostgreSQL documentation, if a root CA file exists,
|
||||||
|
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
||||||
|
if sslrootcert != "" {
|
||||||
|
goto nextCase
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
break
|
||||||
|
nextCase:
|
||||||
|
fallthrough
|
||||||
|
case "verify-ca":
|
||||||
|
// Don't perform the default certificate verification because it
|
||||||
|
// will verify the hostname. Instead, verify the server's
|
||||||
|
// certificate chain ourselves in VerifyPeerCertificate and
|
||||||
|
// ignore the server name. This emulates libpq's verify-ca
|
||||||
|
// behavior.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||||
|
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||||
|
// for more info.
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||||
|
certs := make([]*x509.Certificate, len(certificates))
|
||||||
|
for i, asn1Data := range certificates {
|
||||||
|
cert, err := x509.ParseCertificate(asn1Data)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||||
|
}
|
||||||
|
certs[i] = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave DNSName empty to skip hostname verification.
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: tlsConfig.RootCAs,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
// Skip the first cert because it's the leaf. All others
|
||||||
|
// are intermediates.
|
||||||
|
for _, cert := range certs[1:] {
|
||||||
|
opts.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
_, err := certs[0].Verify(opts)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "verify-full":
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
default:
|
||||||
|
return nil, errors.New("sslmode is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslrootcert != "" {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
|
||||||
|
return nil, errors.New("unable to add CA to cert pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
tlsConfig.ClientCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||||
|
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslcert != "" && sslkey != "" {
|
||||||
|
block, _ := pem.Decode([]byte(sslkey))
|
||||||
|
var pemKey []byte
|
||||||
|
var decryptedKey []byte
|
||||||
|
var decryptedError error
|
||||||
|
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||||
|
if x509.IsEncryptedPEMBlock(block) {
|
||||||
|
// Attempt decryption with pass phrase
|
||||||
|
// NOTE: only supports RSA (PKCS#1)
|
||||||
|
if sslpassword != "" {
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
}
|
||||||
|
// if sslpassword not provided or has decryption error when use it
|
||||||
|
// try to find sslpassword with callback function
|
||||||
|
if sslpassword == "" || decryptedError != nil {
|
||||||
|
if parseConfigOptions.GetSSLPassword != nil {
|
||||||
|
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||||
|
}
|
||||||
|
if sslpassword == "" {
|
||||||
|
return nil, fmt.Errorf("unable to find sslpassword")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
// Should we also provide warning for PKCS#1 needed?
|
||||||
|
if decryptedError != nil {
|
||||||
|
return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes := pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: decryptedKey,
|
||||||
|
}
|
||||||
|
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||||
|
} else {
|
||||||
|
pemKey = pem.EncodeToMemory(block)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair([]byte(sslcert), pemKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||||
|
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||||
|
// or IPv6).
|
||||||
|
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "allow":
|
||||||
|
return []*tls.Config{nil, tlsConfig}, nil
|
||||||
|
case "prefer":
|
||||||
|
return []*tls.Config{tlsConfig, nil}, nil
|
||||||
|
case "require", "verify-ca", "verify-full":
|
||||||
|
return []*tls.Config{tlsConfig}, nil
|
||||||
|
default:
|
||||||
|
panic("BUG: bad sslmode should already have been caught")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(s string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(s, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > math.MaxUint16 {
|
||||||
|
return 0, errors.New("outside range")
|
||||||
|
}
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||||
|
|
||||||
|
func parsePostgresURLSettings(connString string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
url, err := url.Parse(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.User != nil {
|
||||||
|
settings["user"] = url.User.Username()
|
||||||
|
if password, present := url.User.Password(); present {
|
||||||
|
settings["password"] = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||||
|
var hosts []string
|
||||||
|
var ports []string
|
||||||
|
for _, host := range strings.Split(url.Host, ",") {
|
||||||
|
if host == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isIPOnly(host) {
|
||||||
|
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h, p, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||||
|
}
|
||||||
|
if h != "" {
|
||||||
|
hosts = append(hosts, h)
|
||||||
|
}
|
||||||
|
if p != "" {
|
||||||
|
ports = append(ports, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(hosts) > 0 {
|
||||||
|
settings["host"] = strings.Join(hosts, ",")
|
||||||
|
}
|
||||||
|
if len(ports) > 0 {
|
||||||
|
settings["port"] = strings.Join(ports, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
database := strings.TrimLeft(url.Path, "/")
|
||||||
|
if database != "" {
|
||||||
|
settings["database"] = database
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range url.Query() {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[k] = v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePostgresDSNSettings(s string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(s) > 0 {
|
||||||
|
var key, val string
|
||||||
|
eqIdx := strings.IndexRune(s, '=')
|
||||||
|
if eqIdx < 0 {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||||
|
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||||
|
if len(s) == 0 {
|
||||||
|
} else if s[0] != '\'' {
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if asciiSpace[s[end]] == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("invalid backslash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
} else { // quoted string
|
||||||
|
s = s[1:]
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if s[end] == '\'' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("unterminated quoted string in connection info string")
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if k, ok := nameMap[key]; ok {
|
||||||
|
key = k
|
||||||
|
}
|
||||||
|
|
||||||
|
if key == "" {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPOnly(host string) bool {
|
||||||
|
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,7 +220,13 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
if driverName == "pgx" && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" {
|
||||||
|
// TODO: remove this deprecated function call in a future SDK version
|
||||||
|
c.db, err = OpenPostgres(driverName, conn)
|
||||||
|
} else {
|
||||||
c.db, err = sql.Open(driverName, conn)
|
c.db, err = sql.Open(driverName, conn)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,11 @@ const (
|
|||||||
// colliding plugin-specific environment variables. Otherwise, plugin-specific
|
// colliding plugin-specific environment variables. Otherwise, plugin-specific
|
||||||
// environment variables take precedence over Vault process environment variables.
|
// environment variables take precedence over Vault process environment variables.
|
||||||
PluginUseLegacyEnvLayering = "VAULT_PLUGIN_USE_LEGACY_ENV_LAYERING"
|
PluginUseLegacyEnvLayering = "VAULT_PLUGIN_USE_LEGACY_ENV_LAYERING"
|
||||||
|
|
||||||
|
// PluginUsePostgresSSLInline enables the usage of a custom sslinline
|
||||||
|
// configuration as a shim to the pgx posgtres library.
|
||||||
|
// Deprecated: VAULT_PLUGIN_USE_POSTGRES_SSLINLINE will be removed in a future version of the Vault SDK.
|
||||||
|
PluginUsePostgresSSLInline = "VAULT_PLUGIN_USE_POSTGRES_SSLINLINE"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OptionallyEnableMlock determines if mlock should be called, and if so enables
|
// OptionallyEnableMlock determines if mlock should be called, and if so enables
|
||||||
|
|||||||
Reference in New Issue
Block a user