Add test for multihost connection strings with Postgres (#16912)

Co-authored-by: Austin Gebauer <34121980+austingebauer@users.noreply.github.com>
This commit is contained in:
Robert
2022-09-22 14:00:56 -05:00
committed by GitHub
parent 3c6807d574
commit 4a4fa72ff3
3 changed files with 260 additions and 40 deletions

View File

@@ -131,10 +131,26 @@ func (s ServiceURL) URL() *url.URL {
// connection string (typically a URL) and nil, or empty string and an error.
type ServiceAdapter func(ctx context.Context, host string, port int) (ServiceConfig, error)
// StartService will start the runner's configured docker container with a
// random UUID suffix appended to the name to make it unique and will return
// either a hostname or local address depending on if a Docker network was given.
//
// Most tests can default to using this.
func (d *Runner) StartService(ctx context.Context, connect ServiceAdapter) (*Service, error) {
container, hostIPs, err := d.Start(context.Background())
serv, _, err := d.StartNewService(ctx, true, false, connect)
return serv, err
}
// StartNewService will start the runner's configured docker container but with the
// ability to control adding a name suffix or forcing a local address to be returned.
// 'addSuffix' will add a random UUID to the end of the container name.
// 'forceLocalAddr' will force the container address returned to be in the
// form of '127.0.0.1:1234' where 1234 is the mapped container port.
func (d *Runner) StartNewService(ctx context.Context, addSuffix, forceLocalAddr bool, connect ServiceAdapter) (*Service, string, error) {
container, hostIPs, containerID, err := d.Start(context.Background(), addSuffix, forceLocalAddr)
if err != nil {
return nil, err
return nil, "", err
}
cleanup := func() {
@@ -171,7 +187,7 @@ func (d *Runner) StartService(ctx context.Context, connect ServiceAdapter) (*Ser
pieces := strings.Split(hostIPs[0], ":")
portInt, err := strconv.Atoi(pieces[1])
if err != nil {
return nil, err
return nil, "", err
}
var config ServiceConfig
@@ -191,14 +207,14 @@ func (d *Runner) StartService(ctx context.Context, connect ServiceAdapter) (*Ser
if !d.RunOptions.DoNotAutoRemove {
cleanup()
}
return nil, err
return nil, "", err
}
return &Service{
Config: config,
Cleanup: cleanup,
Container: container,
}, nil
}, containerID, nil
}
type Service struct {
@@ -207,12 +223,15 @@ type Service struct {
Container *types.ContainerJSON
}
func (d *Runner) Start(ctx context.Context) (*types.ContainerJSON, []string, error) {
suffix, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
func (d *Runner) Start(ctx context.Context, addSuffix, forceLocalAddr bool) (*types.ContainerJSON, []string, string, error) {
name := d.RunOptions.ContainerName
if addSuffix {
suffix, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, "", err
}
name += "-" + suffix
}
name := d.RunOptions.ContainerName + "-" + suffix
cfg := &container.Config{
Hostname: name,
@@ -251,7 +270,7 @@ func (d *Runner) Start(ctx context.Context) (*types.ContainerJSON, []string, err
"password": d.RunOptions.AuthPassword,
}
if err := json.NewEncoder(&buf).Encode(auth); err != nil {
return nil, nil, err
return nil, nil, "", err
}
opts.RegistryAuth = base64.URLEncoding.EncodeToString(buf.Bytes())
}
@@ -262,47 +281,67 @@ func (d *Runner) Start(ctx context.Context) (*types.ContainerJSON, []string, err
c, err := d.DockerAPI.ContainerCreate(ctx, cfg, hostConfig, netConfig, nil, cfg.Hostname)
if err != nil {
return nil, nil, fmt.Errorf("container create failed: %v", err)
return nil, nil, "", fmt.Errorf("container create failed: %v", err)
}
for from, to := range d.RunOptions.CopyFromTo {
if err := copyToContainer(ctx, d.DockerAPI, c.ID, from, to); err != nil {
_ = d.DockerAPI.ContainerRemove(ctx, c.ID, types.ContainerRemoveOptions{})
return nil, nil, err
return nil, nil, "", err
}
}
err = d.DockerAPI.ContainerStart(ctx, c.ID, types.ContainerStartOptions{})
if err != nil {
_ = d.DockerAPI.ContainerRemove(ctx, c.ID, types.ContainerRemoveOptions{})
return nil, nil, fmt.Errorf("container start failed: %v", err)
return nil, nil, "", fmt.Errorf("container start failed: %v", err)
}
inspect, err := d.DockerAPI.ContainerInspect(ctx, c.ID)
if err != nil {
_ = d.DockerAPI.ContainerRemove(ctx, c.ID, types.ContainerRemoveOptions{})
return nil, nil, err
return nil, nil, "", err
}
var addrs []string
for _, port := range d.RunOptions.Ports {
pieces := strings.Split(port, "/")
if len(pieces) < 2 {
return nil, nil, fmt.Errorf("expected port of the form 1234/tcp, got: %s", port)
return nil, nil, "", fmt.Errorf("expected port of the form 1234/tcp, got: %s", port)
}
if d.RunOptions.NetworkID != "" {
if d.RunOptions.NetworkID != "" && !forceLocalAddr {
addrs = append(addrs, fmt.Sprintf("%s:%s", cfg.Hostname, pieces[0]))
} else {
mapped, ok := inspect.NetworkSettings.Ports[nat.Port(port)]
if !ok || len(mapped) == 0 {
return nil, nil, fmt.Errorf("no port mapping found for %s", port)
return nil, nil, "", fmt.Errorf("no port mapping found for %s", port)
}
addrs = append(addrs, fmt.Sprintf("127.0.0.1:%s", mapped[0].HostPort))
}
}
return &inspect, addrs, nil
return &inspect, addrs, c.ID, nil
}
func (d *Runner) Stop(ctx context.Context, containerID string) error {
timeout := 5 * time.Second
if err := d.DockerAPI.ContainerStop(ctx, containerID, &timeout); err != nil {
return err
}
return d.DockerAPI.NetworkDisconnect(ctx, d.RunOptions.NetworkID, containerID, true)
}
func (d *Runner) Restart(ctx context.Context, containerID string) error {
if err := d.DockerAPI.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil {
return err
}
ends := &network.EndpointSettings{
NetworkID: d.RunOptions.NetworkID,
}
return d.DockerAPI.NetworkConnect(ctx, d.RunOptions.NetworkID, containerID, ends)
}
func copyToContainer(ctx context.Context, dapi *client.Client, containerID, from, to string) error {

View File

@@ -12,44 +12,74 @@ import (
)
func PrepareTestContainer(t *testing.T, version string) (func(), string) {
return prepareTestContainer(t, version, "secret", "database")
env := []string{
"POSTGRES_PASSWORD=secret",
"POSTGRES_DB=database",
}
_, cleanup, url, _ := prepareTestContainer(t, "postgres", "postgres", version, "secret", true, false, false, env)
return cleanup, url
}
func PrepareTestContainerWithPassword(t *testing.T, version, password string) (func(), string) {
return prepareTestContainer(t, version, password, "database")
env := []string{
"POSTGRES_PASSWORD=" + password,
"POSTGRES_DB=database",
}
_, cleanup, url, _ := prepareTestContainer(t, "postgres", "postgres", version, password, true, false, false, env)
return cleanup, url
}
func prepareTestContainer(t *testing.T, version, password, db string) (func(), string) {
func PrepareTestContainerRepmgr(t *testing.T, name, version string, envVars []string) (*docker.Runner, func(), string, string) {
env := append(envVars,
"REPMGR_PARTNER_NODES=psql-repl-node-0,psql-repl-node-1",
"REPMGR_PRIMARY_HOST=psql-repl-node-0",
"REPMGR_PASSWORD=repmgrpass",
"POSTGRESQL_PASSWORD=secret")
return prepareTestContainer(t, name, "bitnami/postgresql-repmgr", version, "secret", false, true, true, env)
}
func prepareTestContainer(t *testing.T, name, repo, version, password string,
addSuffix, forceLocalAddr, doNotAutoRemove bool, envVars []string,
) (*docker.Runner, func(), string, string) {
if os.Getenv("PG_URL") != "" {
return func() {}, os.Getenv("PG_URL")
return nil, func() {}, "", os.Getenv("PG_URL")
}
if version == "" {
version = "11"
}
runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "postgres",
ImageTag: version,
Env: []string{
"POSTGRES_PASSWORD=" + password,
"POSTGRES_DB=" + db,
},
Ports: []string{"5432/tcp"},
})
runOpts := docker.RunOptions{
ContainerName: name,
ImageRepo: repo,
ImageTag: version,
Env: envVars,
Ports: []string{"5432/tcp"},
DoNotAutoRemove: doNotAutoRemove,
}
if repo == "bitnami/postgresql-repmgr" {
runOpts.NetworkID = os.Getenv("POSTGRES_MULTIHOST_NET")
}
runner, err := docker.NewServiceRunner(runOpts)
if err != nil {
t.Fatalf("Could not start docker Postgres: %s", err)
}
svc, err := runner.StartService(context.Background(), connectPostgres(password))
svc, containerID, err := runner.StartNewService(context.Background(), addSuffix, forceLocalAddr, connectPostgres(password, repo))
if err != nil {
t.Fatalf("Could not start docker Postgres: %s", err)
}
return svc.Cleanup, svc.Config.URL().String()
return runner, svc.Cleanup, svc.Config.URL().String(), containerID
}
func connectPostgres(password string) docker.ServiceAdapter {
func connectPostgres(password, repo string) docker.ServiceAdapter {
return func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
u := url.URL{
Scheme: "postgres",
@@ -65,10 +95,21 @@ func connectPostgres(password string) docker.ServiceAdapter {
}
defer db.Close()
err = db.Ping()
if err != nil {
if err = db.Ping(); err != nil {
return nil, err
}
return docker.NewServiceURL(u), nil
}
}
func StopContainer(t *testing.T, ctx context.Context, runner *docker.Runner, containerID string) {
if err := runner.Stop(ctx, containerID); err != nil {
t.Fatalf("Could not stop docker Postgres: %s", err)
}
}
func RestartContainer(t *testing.T, ctx context.Context, runner *docker.Runner, containerID string) {
if err := runner.Restart(ctx, containerID); err != nil {
t.Fatalf("Could not restart docker Postgres: %s", err)
}
}

View File

@@ -4,15 +4,16 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"testing"
"time"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/helper/testhelpers/docker"
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/template"
"github.com/stretchr/testify/require"
)
@@ -990,3 +991,142 @@ func TestNewUser_CustomUsername(t *testing.T) {
})
}
}
// This is a long-running integration test which tests the functionality of Postgres's multi-host
// connection strings. It uses two Postgres containers preconfigured with Replication Manager
// provided by Bitnami. This test currently does not run in CI and must be run manually. This is
// due to the test length, as it requires multiple sleep calls to ensure cluster setup and
// primary node failover occurs before the test steps continue.
//
// To run the test, set the environment variable POSTGRES_MULTIHOST_NET to the value of
// a docker network you've preconfigured, e.g.
// 'docker network create -d bridge postgres-repmgr'
// 'export POSTGRES_MULTIHOST_NET=postgres-repmgr'
func TestPostgreSQL_Repmgr(t *testing.T) {
_, exists := os.LookupEnv("POSTGRES_MULTIHOST_NET")
if !exists {
t.Skipf("POSTGRES_MULTIHOST_NET not set, skipping test")
}
// Run two postgres-repmgr containers in a replication cluster
db0, runner0, url0, container0 := testPostgreSQL_Repmgr_Container(t, "psql-repl-node-0")
_, _, url1, _ := testPostgreSQL_Repmgr_Container(t, "psql-repl-node-1")
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
defer cancel()
time.Sleep(10 * time.Second)
// Write a read role to the cluster
_, err := db0.NewUser(ctx, dbplugin.NewUserRequest{
Statements: dbplugin.Statements{
Commands: []string{
`CREATE ROLE "ro" NOINHERIT;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "ro";`,
},
},
})
if err != nil {
t.Fatalf("no error expected, got: %s", err)
}
// Open a connection to both databases using the multihost connection string
connectionDetails := map[string]interface{}{
"connection_url": fmt.Sprintf("postgresql://{{username}}:{{password}}@%s,%s/postgres?target_session_attrs=read-write", getHost(url0), getHost(url1)),
"username": "postgres",
"password": "secret",
}
req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}
db := new()
dbtesting.AssertInitialize(t, db, req)
if !db.Initialized {
t.Fatal("Database should be initialized")
}
defer db.Close()
// Add a user to the cluster, then stop the primary container
if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil {
t.Fatalf("no error expected, got: %s", err)
}
postgresql.StopContainer(t, ctx, runner0, container0)
// Try adding a new user immediately - expect failure as the database
// cluster is still switching primaries
err = testPostgreSQL_Repmgr_AddUser(ctx, db)
if !strings.HasSuffix(err.Error(), "ValidateConnect failed (read only connection)") {
t.Fatalf("expected error was not received, got: %s", err)
}
time.Sleep(20 * time.Second)
// Try adding a new user again which should succeed after the sleep
// as the primary failover should have finished. Then, restart
// the first container which should become a secondary DB.
if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil {
t.Fatalf("no error expected, got: %s", err)
}
postgresql.RestartContainer(t, ctx, runner0, container0)
time.Sleep(10 * time.Second)
// A final new user to add, which should succeed after the secondary joins.
if err = testPostgreSQL_Repmgr_AddUser(ctx, db); err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if err := db.Close(); err != nil {
t.Fatalf("err: %s", err)
}
}
func testPostgreSQL_Repmgr_Container(t *testing.T, name string) (*PostgreSQL, *docker.Runner, string, string) {
envVars := []string{
"REPMGR_NODE_NAME=" + name,
"REPMGR_NODE_NETWORK_NAME=" + name,
}
runner, cleanup, connURL, containerID := postgresql.PrepareTestContainerRepmgr(t, name, "13.4.0", envVars)
t.Cleanup(cleanup)
connectionDetails := map[string]interface{}{
"connection_url": connURL,
}
req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}
db := new()
dbtesting.AssertInitialize(t, db, req)
if !db.Initialized {
t.Fatal("Database should be initialized")
}
if err := db.Close(); err != nil {
t.Fatalf("err: %s", err)
}
return db, runner, connURL, containerID
}
func testPostgreSQL_Repmgr_AddUser(ctx context.Context, db *PostgreSQL) error {
_, err := db.NewUser(ctx, dbplugin.NewUserRequest{
Statements: dbplugin.Statements{
Commands: []string{
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}' INHERIT;
GRANT ro TO "{{name}}";`,
},
},
})
return err
}
func getHost(url string) string {
splitCreds := strings.Split(url, "@")[1]
return strings.Split(splitCreds, "/")[0]
}