diff --git a/changelog/22249.txt b/changelog/22249.txt new file mode 100644 index 0000000000..d470b9743f --- /dev/null +++ b/changelog/22249.txt @@ -0,0 +1,4 @@ +```release-note:bug +sdk/ldaputil: Properly escape user filters when using UPN domains +sdk/ldaputil: use EscapeLDAPValue implementation from cap/ldap +``` \ No newline at end of file diff --git a/sdk/helper/ldaputil/client.go b/sdk/helper/ldaputil/client.go index c3562a5d40..61320bccfb 100644 --- a/sdk/helper/ldaputil/client.go +++ b/sdk/helper/ldaputil/client.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/binary" + "encoding/hex" "fmt" "math" "net" @@ -227,7 +228,11 @@ func (c *Client) RenderUserSearchFilter(cfg *ConfigEntry, username string) (stri } if cfg.UPNDomain != "" { context.UserAttr = "userPrincipalName" - context.Username = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain) + // Intentionally, calling EscapeFilter(...) (vs EscapeValue) since the + // username is being injected into a search filter. + // As an untrusted string, the username must be escaped according to RFC + // 4515, in order to prevent attackers from injecting characters that could modify the filter + context.Username = fmt.Sprintf("%s@%s", ldap.EscapeFilter(username), cfg.UPNDomain) } // Execute the template. Note that the template context contains escaped input and does @@ -595,42 +600,59 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, } // EscapeLDAPValue is exported because a plugin uses it outside this package. +// EscapeLDAPValue will properly escape the input string as an ldap value +// rfc4514 states the following must be escaped: +// - leading space or hash +// - trailing space +// - special characters '"', '+', ',', ';', '<', '>', '\\' +// - hex func EscapeLDAPValue(input string) string { if input == "" { return "" } - // RFC4514 forbids un-escaped: - // - leading space or hash - // - trailing space - // - special characters '"', '+', ',', ';', '<', '>', '\\' - // - null - for i := 0; i < len(input); i++ { - escaped := false - if input[i] == '\\' && i+1 < len(input)-1 { - i++ - escaped = true - } - switch input[i] { - case '"', '+', ',', ';', '<', '>', '\\': - if !escaped { - input = input[0:i] + "\\" + input[i:] - i++ - } + buf := bytes.Buffer{} + + escFn := func(c byte) { + buf.WriteByte('\\') + buf.WriteByte(c) + } + + inputLen := len(input) + for i := 0; i < inputLen; i++ { + char := input[i] + switch { + case i == 0 && char == ' ' || char == '#': + // leading space or hash. + escFn(char) continue - } - if escaped { - input = input[0:i] + "\\" + input[i:] - i++ + case i == inputLen-1 && char == ' ': + // trailing space. + escFn(char) + continue + case specialChar(char): + escFn(char) + continue + case char < ' ' || char > '~': + // anything that's not between the ascii space and tilde must be hex + buf.WriteByte('\\') + buf.WriteString(hex.EncodeToString([]byte{char})) + continue + default: + // everything remaining, doesn't need to be escaped + buf.WriteByte(char) } } - if input[0] == ' ' || input[0] == '#' { - input = "\\" + input + return buf.String() +} + +func specialChar(char byte) bool { + switch char { + case '"', '+', ',', ';', '<', '>', '\\': + return true + default: + return false } - if input[len(input)-1] == ' ' { - input = input[0:len(input)-1] + "\\ " - } - return input } /* diff --git a/sdk/helper/ldaputil/client_test.go b/sdk/helper/ldaputil/client_test.go index 167d50f22d..dcce9c6e0e 100644 --- a/sdk/helper/ldaputil/client_test.go +++ b/sdk/helper/ldaputil/client_test.go @@ -7,6 +7,8 @@ import ( "testing" "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestDialLDAP duplicates a potential panic that was @@ -29,15 +31,20 @@ func TestDialLDAP(t *testing.T) { func TestLDAPEscape(t *testing.T) { testcases := map[string]string{ - "#test": "\\#test", - "test,hello": "test\\,hello", - "test,hel+lo": "test\\,hel\\+lo", - "test\\hello": "test\\\\hello", - " test ": "\\ test \\ ", - "": "", - "\\test": "\\\\test", - "test\\": "test\\\\", - "test\\ ": "test\\\\\\ ", + "#test": "\\#test", + "test,hello": "test\\,hello", + "test,hel+lo": "test\\,hel\\+lo", + "test\\hello": "test\\\\hello", + " test ": "\\ test \\ ", + "": "", + `\`: `\\`, + "trailing\000": `trailing\00`, + "mid\000dle": `mid\00dle`, + "\000": `\00`, + "multiple\000\000": `multiple\00\00`, + "backlash-before-null\\\000": `backlash-before-null\\\00`, + "trailing\\": `trailing\\`, + "double-escaping\\>": `double-escaping\\\>`, } for test, answer := range testcases { @@ -88,3 +95,58 @@ func TestSIDBytesToString(t *testing.T) { } } } + +func TestClient_renderUserSearchFilter(t *testing.T) { + t.Parallel() + tests := []struct { + name string + conf *ConfigEntry + username string + want string + errContains string + }{ + { + name: "valid-default", + username: "alice", + conf: &ConfigEntry{ + UserAttr: "cn", + }, + want: "(cn=alice)", + }, + { + name: "escaped-malicious-filter", + username: "foo@example.com)((((((((((((((((((((((((((((((((((((((userPrincipalName=foo", + conf: &ConfigEntry{ + UPNDomain: "example.com", + UserFilter: "(&({{.UserAttr}}={{.Username}})({{.UserAttr}}=admin@example.com))", + }, + want: "(&(userPrincipalName=foo@example.com\\29\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28userPrincipalName=foo@example.com)(userPrincipalName=admin@example.com))", + }, + { + name: "bad-filter-unclosed-action", + username: "alice", + conf: &ConfigEntry{ + UserFilter: "hello{{range", + }, + errContains: "search failed due to template compilation error", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c := Client{ + Logger: hclog.NewNullLogger(), + LDAP: NewLDAP(), + } + + f, err := c.RenderUserSearchFilter(tc.conf, tc.username) + if tc.errContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.errContains) + return + } + require.NoError(t, err) + assert.NotEmpty(t, f) + assert.Equal(t, tc.want, f) + }) + } +}