mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	Postgres: Correct parsing of multiline statements (#8512)
* add test reproducing issue * add code fixing issue * check for END in unquoted string frags * move delimiters inside parens * begin checking with stmt * PR feedback * fix comment * add tests with templates * update test name * remove unnecessary backslashes from test
This commit is contained in:
		| @@ -5,6 +5,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| @@ -29,7 +30,22 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; | ||||
| ` | ||||
| ) | ||||
|  | ||||
| var _ dbplugin.Database = &PostgreSQL{} | ||||
| var ( | ||||
| 	_ dbplugin.Database = &PostgreSQL{} | ||||
|  | ||||
| 	// postgresEndStatement is basically the word "END" but | ||||
| 	// surrounded by a word boundary to differentiate it from | ||||
| 	// other words like "APPEND". | ||||
| 	postgresEndStatement = regexp.MustCompile(`\bEND\b`) | ||||
|  | ||||
| 	// doubleQuotedPhrases finds substrings like "hello" | ||||
| 	// and pulls them out with the quotes included. | ||||
| 	doubleQuotedPhrases = regexp.MustCompile(`(".*?")`) | ||||
|  | ||||
| 	// singleQuotedPhrases finds substrings like 'hello' | ||||
| 	// and pulls them out with the quotes included. | ||||
| 	singleQuotedPhrases = regexp.MustCompile(`('.*?')`) | ||||
| ) | ||||
|  | ||||
| // New implements builtinplugins.BuiltinFactory | ||||
| func New() (interface{}, error) { | ||||
| @@ -206,6 +222,20 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme | ||||
|  | ||||
| 	// Execute each query | ||||
| 	for _, stmt := range statements.Creation { | ||||
| 		if containsMultilineStatement(stmt) { | ||||
| 			// Execute it as-is. | ||||
| 			m := map[string]string{ | ||||
| 				"name":       username, | ||||
| 				"username":   username, | ||||
| 				"password":   password, | ||||
| 				"expiration": expirationStr, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { | ||||
| 				return "", "", err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		// Otherwise, it's fine to split the statements on the semicolon. | ||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||
| 			query = strings.TrimSpace(query) | ||||
| 			if len(query) == 0 { | ||||
| @@ -501,3 +531,40 @@ func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []str | ||||
| 	p.RawConfig["password"] = password | ||||
| 	return p.RawConfig, nil | ||||
| } | ||||
|  | ||||
| // containsMultilineStatement is a best effort to determine whether | ||||
| // a particular statement is multiline, and therefore should not be | ||||
| // split upon semicolons. If it's unsure, it defaults to false. | ||||
| func containsMultilineStatement(stmt string) bool { | ||||
| 	// We're going to look for the word "END", but first let's ignore | ||||
| 	// anything the user provided within single or double quotes since | ||||
| 	// we're looking for an "END" within the Postgres syntax. | ||||
| 	literals, err := extractQuotedStrings(stmt) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	stmtWithoutLiterals := stmt | ||||
| 	for _, literal := range literals { | ||||
| 		stmtWithoutLiterals = strings.Replace(stmt, literal, "", -1) | ||||
| 	} | ||||
| 	// Now look for the word "END" specifically. This will miss any | ||||
| 	// representations of END that aren't surrounded by spaces, but | ||||
| 	// it should be easy to change on the user's side. | ||||
| 	return postgresEndStatement.MatchString(stmtWithoutLiterals) | ||||
| } | ||||
|  | ||||
| // extractQuotedStrings extracts 0 or many substrings | ||||
| // that have been single- or double-quoted. Ex: | ||||
| // `"Hello", silly 'elephant' from the "zoo".` | ||||
| // returns [ `Hello`, `'elephant'`, `"zoo"` ] | ||||
| func extractQuotedStrings(s string) ([]string, error) { | ||||
| 	var found []string | ||||
| 	toFind := []*regexp.Regexp{ | ||||
| 		doubleQuotedPhrases, | ||||
| 		singleQuotedPhrases, | ||||
| 	} | ||||
| 	for _, typeOfPhrase := range toFind { | ||||
| 		found = append(found, typeOfPhrase.FindAllString(s, -1)...) | ||||
| 	} | ||||
| 	return found, nil | ||||
| } | ||||
|   | ||||
| @@ -114,7 +114,8 @@ func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) { | ||||
|  | ||||
| func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		createStmts []string | ||||
| 		createStmts          []string | ||||
| 		shouldTestCredsExist bool | ||||
| 	} | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| @@ -126,6 +127,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 				  VALID UNTIL '{{expiration}}'; | ||||
| 				GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 		}, | ||||
| 		"admin username": { | ||||
| 			createStmts: []string{` | ||||
| @@ -135,6 +137,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 				  VALID UNTIL '{{expiration}}'; | ||||
| 				GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 		}, | ||||
| 		"read only name": { | ||||
| 			createStmts: []string{` | ||||
| @@ -145,6 +148,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 				GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; | ||||
| 				GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 		}, | ||||
| 		"read only username": { | ||||
| 			createStmts: []string{` | ||||
| @@ -155,6 +159,23 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 				GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; | ||||
| 				GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 		}, | ||||
| 		// https://github.com/hashicorp/vault/issues/6098 | ||||
| 		"reproduce GH-6098": { | ||||
| 			createStmts: []string{ | ||||
| 				// NOTE: "rolname" in the following line is not a typo. | ||||
| 				"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", | ||||
| 			}, | ||||
| 			// This test statement doesn't generate creds. | ||||
| 			shouldTestCredsExist: false, | ||||
| 		}, | ||||
| 		"reproduce issue with template": { | ||||
| 			createStmts: []string{ | ||||
| 				`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, | ||||
| 			}, | ||||
| 			// This test statement doesn't generate creds. | ||||
| 			shouldTestCredsExist: false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -192,6 +213,11 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if !test.shouldTestCredsExist { | ||||
| 				// We're done here. | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			if err = testCredsExist(t, connURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
| @@ -657,3 +683,81 @@ func createTestPGUser(t *testing.T, connURL string, username, password, query st | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestContainsMultilineStatement(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		Input    string | ||||
| 		Expected bool | ||||
| 	} | ||||
|  | ||||
| 	testCases := map[string]*testCase{ | ||||
| 		"issue 6098 repro": { | ||||
| 			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`, | ||||
| 			Expected: true, | ||||
| 		}, | ||||
| 		"multiline with template fields": { | ||||
| 			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`, | ||||
| 			Expected: true, | ||||
| 		}, | ||||
| 		"docs example": { | ||||
| 			Input: `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \ | ||||
|         GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`, | ||||
| 			Expected: false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for tName, tCase := range testCases { | ||||
| 		t.Run(tName, func(t *testing.T) { | ||||
| 			if containsMultilineStatement(tCase.Input) != tCase.Expected { | ||||
| 				t.Fatalf("%q should be %t for multiline input", tCase.Input, tCase.Expected) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExtractQuotedStrings(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		Input    string | ||||
| 		Expected []string | ||||
| 	} | ||||
|  | ||||
| 	testCases := map[string]*testCase{ | ||||
| 		"no quotes": { | ||||
| 			Input:    `Five little monkeys jumping on the bed`, | ||||
| 			Expected: []string{}, | ||||
| 		}, | ||||
| 		"two of both quote types": { | ||||
| 			Input:    `"Five" little 'monkeys' "jumping on" the' 'bed`, | ||||
| 			Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`}, | ||||
| 		}, | ||||
| 		"one single quote": { | ||||
| 			Input:    `Five little monkeys 'jumping on the bed`, | ||||
| 			Expected: []string{}, | ||||
| 		}, | ||||
| 		"empty string": { | ||||
| 			Input:    ``, | ||||
| 			Expected: []string{}, | ||||
| 		}, | ||||
| 		"templated field": { | ||||
| 			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`, | ||||
| 			Expected: []string{`"{{name}}"`}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for tName, tCase := range testCases { | ||||
| 		t.Run(tName, func(t *testing.T) { | ||||
| 			results, err := extractQuotedStrings(tCase.Input) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 			if len(results) != len(tCase.Expected) { | ||||
| 				t.Fatalf("%s isn't equal to %s", results, tCase.Expected) | ||||
| 			} | ||||
| 			for i := range results { | ||||
| 				if results[i] != tCase.Expected[i] { | ||||
| 					t.Fatalf(`expected %q but received %q`, tCase.Expected, results[i]) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Becca Petrin
					Becca Petrin