postgres: replace the package lib/pq with pgx (#15343)

* WIP replacing lib/pq

* change timezome param to be URI format

* add changelog

* add changelog for redshift

* update changelog

* add test for DSN style connection string

* more parseurl and quoteidentify to sdk; include copyright and license

* call dbutil.ParseURL instead, fix import ordering

Co-authored-by: Calvin Leung Huang <1883212+calvn@users.noreply.github.com>
This commit is contained in:
Jim Kalafut
2022-05-23 12:49:18 -07:00
committed by GitHub
parent 4f21baa69a
commit c5a88aa1a6
23 changed files with 350 additions and 110 deletions

View File

@@ -8,7 +8,6 @@ import (
"sync"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -108,10 +107,10 @@ func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
conn += "?timezone=utc"
}
} else {
conn += " timezone=utc"
conn += "&timezone=utc"
}
b.db, err = sql.Open("postgres", conn)
b.db, err = sql.Open("pgx", conn)
if err != nil {
return nil, err
}

View File

@@ -13,7 +13,6 @@ import (
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/logical"
"github.com/lib/pq"
"github.com/mitchellh/mapstructure"
)
@@ -272,14 +271,8 @@ func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, na
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
db, err := sql.Open("pgx", connURL+"&timezone=utc")
if err != nil {
t.Fatal(err)
}
@@ -356,14 +349,8 @@ func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage,
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
db, err := sql.Open("pgx", connURL+"&timezone=utc")
if err != nil {
t.Fatal(err)
}
@@ -410,14 +397,8 @@ func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, na
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
db, err := sql.Open("pgx", connURL+"&timezone=utc")
if err != nil {
t.Fatal(err)
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/lib/pq"
_ "github.com/jackc/pgx/v4/stdlib"
)
func pathConfigConnection(b *backend) *framework.Path {
@@ -109,7 +109,7 @@ func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request,
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("postgres", connURL)
db, err := sql.Open("pgx", connURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil

View File

@@ -11,7 +11,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/lib/pq"
_ "github.com/jackc/pgx/v4/stdlib"
)
func pathRoleCreate(b *backend) *framework.Path {

View File

@@ -7,11 +7,12 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
"github.com/lib/pq"
)
const SecretCredsType = "creds"
@@ -75,7 +76,7 @@ func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d
query := fmt.Sprintf(
"ALTER ROLE %s VALID UNTIL '%s';",
pq.QuoteIdentifier(username),
dbutil.QuoteIdentifier(username),
expiration)
stmt, err := db.Prepare(query)
if err != nil {
@@ -171,27 +172,27 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d
}
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
pq.QuoteIdentifier(schema),
pq.QuoteIdentifier(username)))
dbutil.QuoteIdentifier(schema),
dbutil.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
pq.QuoteIdentifier(schema),
pq.QuoteIdentifier(username)))
dbutil.QuoteIdentifier(schema),
dbutil.QuoteIdentifier(username)))
}
// for good measure, revoke all privileges and usage on schema public
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
pq.QuoteIdentifier(username)))
dbutil.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
dbutil.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE USAGE ON SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
dbutil.QuoteIdentifier(username)))
// get the current database name so we can issue a REVOKE CONNECT for
// this username
@@ -203,8 +204,8 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d
if dbname.Valid {
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
pq.QuoteIdentifier(dbname.String),
pq.QuoteIdentifier(username)))
dbutil.QuoteIdentifier(dbname.String),
dbutil.QuoteIdentifier(username)))
}
// again, here, we do not stop on error, as we want to remove as
@@ -226,7 +227,7 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d
// Drop this user
stmt, err = db.Prepare(fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
`DROP ROLE IF EXISTS %s;`, dbutil.QuoteIdentifier(username)))
if err != nil {
return nil, err
}