Postgres: Correct parsing of multiline statements (#8512)

* add test reproducing issue

* add code fixing issue

* check for END in unquoted string frags

* move delimiters inside parens

* begin checking with stmt

* PR feedback

* fix comment

* add tests with templates

* update test name

* remove unnecessary backslashes from test
This commit is contained in:
Becca Petrin
2020-03-17 12:45:25 -07:00
committed by GitHub
parent 000dc498e0
commit 87d7180204
2 changed files with 173 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"regexp"
"strings"
"time"
@@ -29,7 +30,22 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
`
)
var _ dbplugin.Database = &PostgreSQL{}
var (
	// doubleQuotedPhrases matches the shortest possible run of
	// characters enclosed in double quotes, such as "hello". The
	// surrounding quotes are part of the match. Because "." does not
	// match a newline by default, a phrase must sit on a single line.
	doubleQuotedPhrases = regexp.MustCompile(`(".*?")`)

	// singleQuotedPhrases is the single-quote counterpart of
	// doubleQuotedPhrases: it matches substrings such as 'hello',
	// quotes included.
	singleQuotedPhrases = regexp.MustCompile(`('.*?')`)

	// postgresEndStatement matches the upper-case keyword END only
	// when it stands alone as a word, which keeps it from matching
	// inside longer words like "APPEND".
	postgresEndStatement = regexp.MustCompile(`\bEND\b`)

	// Compile-time assertion that *PostgreSQL satisfies
	// dbplugin.Database.
	_ dbplugin.Database = (*PostgreSQL)(nil)
)
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
@@ -206,6 +222,20 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
// Execute each query
for _, stmt := range statements.Creation {
if containsMultilineStatement(stmt) {
// Execute it as-is.
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
return "", "", err
}
continue
}
// Otherwise, it's fine to split the statements on the semicolon.
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
@@ -501,3 +531,40 @@ func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []str
p.RawConfig["password"] = password
return p.RawConfig, nil
}
// containsMultilineStatement is a best effort to determine whether
// a particular statement is multiline, and therefore should not be
// split upon semicolons. If it's unsure, it defaults to false.
//
// A statement is reported as multiline when the keyword END appears
// outside of every single- or double-quoted substring of stmt.
func containsMultilineStatement(stmt string) bool {
	// Quoted text is user-provided data rather than Postgres syntax,
	// so an "END" that only appears inside quotes must be ignored.
	literals, err := extractQuotedStrings(stmt)
	if err != nil {
		return false
	}

	// Strip the literals cumulatively: each replacement must operate
	// on the result of the previous one. Replacing against the
	// original stmt every time would leave only the last literal
	// removed, letting an "END" inside an earlier literal leak through.
	stmtWithoutLiterals := stmt
	for _, literal := range literals {
		stmtWithoutLiterals = strings.Replace(stmtWithoutLiterals, literal, "", -1)
	}

	// Now look for the word "END" specifically. postgresEndStatement
	// only matches an upper-case END delimited by word boundaries, so
	// other spellings (for example lower-case "end") are missed, but
	// that should be easy to change on the user's side.
	return postgresEndStatement.MatchString(stmtWithoutLiterals)
}
// extractQuotedStrings extracts 0 or many substrings
// that have been single- or double-quoted. Ex:
// `"Hello", silly 'elephant' from the "zoo".`
// returns [ `Hello`, `'elephant'`, `"zoo"` ]
func extractQuotedStrings(s string) ([]string, error) {
var found []string
toFind := []*regexp.Regexp{
doubleQuotedPhrases,
singleQuotedPhrases,
}
for _, typeOfPhrase := range toFind {
found = append(found, typeOfPhrase.FindAllString(s, -1)...)
}
return found, nil
}

View File

@@ -114,7 +114,8 @@ func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) {
func TestPostgreSQL_CreateUser(t *testing.T) {
type testCase struct {
createStmts []string
createStmts []string
shouldTestCredsExist bool
}
tests := map[string]testCase{
@@ -126,6 +127,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
},
shouldTestCredsExist: true,
},
"admin username": {
createStmts: []string{`
@@ -135,6 +137,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
},
shouldTestCredsExist: true,
},
"read only name": {
createStmts: []string{`
@@ -145,6 +148,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
},
shouldTestCredsExist: true,
},
"read only username": {
createStmts: []string{`
@@ -155,6 +159,23 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
},
shouldTestCredsExist: true,
},
// https://github.com/hashicorp/vault/issues/6098
"reproduce GH-6098": {
createStmts: []string{
// NOTE: "rolname" in the following line is not a typo.
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
},
// This test statement doesn't generate creds.
shouldTestCredsExist: false,
},
"reproduce issue with template": {
createStmts: []string{
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
},
// This test statement doesn't generate creds.
shouldTestCredsExist: false,
},
}
@@ -192,6 +213,11 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
t.Fatalf("err: %s", err)
}
if !test.shouldTestCredsExist {
// We're done here.
return
}
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
@@ -657,3 +683,81 @@ func createTestPGUser(t *testing.T, connURL string, username, password, query st
t.Fatal(err)
}
}
// TestContainsMultilineStatement checks that containsMultilineStatement
// flags DO-block style statements containing END as multiline, and
// leaves ordinary semicolon-separated statements unflagged.
func TestContainsMultilineStatement(t *testing.T) {
	type scenario struct {
		Input    string
		Expected bool
	}

	scenarios := map[string]*scenario{
		"issue 6098 repro": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`,
			Expected: true,
		},
		"multiline with template fields": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: true,
		},
		"docs example": {
			Input: `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
			Expected: false,
		},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			got := containsMultilineStatement(sc.Input)
			if got == sc.Expected {
				return
			}
			t.Fatalf("%q should be %t for multiline input", sc.Input, sc.Expected)
		})
	}
}
// TestExtractQuotedStrings checks which quoted substrings
// extractQuotedStrings pulls out of a variety of inputs. Expected
// values list the double-quoted matches first and the single-quoted
// matches second, which is the order extractQuotedStrings returns.
func TestExtractQuotedStrings(t *testing.T) {
	type testCase struct {
		Input    string
		Expected []string
	}

	testCases := map[string]*testCase{
		"no quotes": {
			Input:    `Five little monkeys jumping on the bed`,
			Expected: []string{},
		},
		"two of both quote types": {
			Input:    `"Five" little 'monkeys' "jumping on" the' 'bed`,
			Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`},
		},
		"one single quote": {
			Input:    `Five little monkeys 'jumping on the bed`,
			Expected: []string{},
		},
		"empty string": {
			Input:    ``,
			Expected: []string{},
		},
		"templated field": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: []string{`"{{name}}"`},
		},
	}

	for tName, tCase := range testCases {
		t.Run(tName, func(t *testing.T) {
			results, err := extractQuotedStrings(tCase.Input)
			if err != nil {
				t.Fatal(err)
			}
			// Compare lengths rather than the slices themselves so that
			// a nil result and an empty Expected slice count as equal.
			if len(results) != len(tCase.Expected) {
				t.Fatalf("%s isn't equal to %s", results, tCase.Expected)
			}
			for i := range results {
				if results[i] != tCase.Expected[i] {
					// Report the single element that differed, not the
					// whole Expected slice.
					t.Fatalf(`expected %q but received %q`, tCase.Expected[i], results[i])
				}
			}
		})
	}
}