mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	* Add worker pool for LDAP group lookup * changelog * Add lock * derefAliases disappeared
		
			
				
	
	
		
			753 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			753 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) HashiCorp, Inc.
 | 
						|
// SPDX-License-Identifier: MPL-2.0
 | 
						|
 | 
						|
package ldaputil
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"crypto/tls"
 | 
						|
	"crypto/x509"
 | 
						|
	"encoding/binary"
 | 
						|
	"encoding/hex"
 | 
						|
	"fmt"
 | 
						|
	"math"
 | 
						|
	"net"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"text/template"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/go-ldap/ldap/v3"
 | 
						|
	hclog "github.com/hashicorp/go-hclog"
 | 
						|
	multierror "github.com/hashicorp/go-multierror"
 | 
						|
	"github.com/hashicorp/go-secure-stdlib/tlsutil"
 | 
						|
)
 | 
						|
 | 
						|
type Client struct {
 | 
						|
	Logger hclog.Logger
 | 
						|
	LDAP   LDAP
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) {
 | 
						|
	var retErr *multierror.Error
 | 
						|
	var conn Connection
 | 
						|
	urls := strings.Split(cfg.Url, ",")
 | 
						|
 | 
						|
	for _, uut := range urls {
 | 
						|
		u, err := url.Parse(uut)
 | 
						|
		if err != nil {
 | 
						|
			retErr = multierror.Append(retErr, fmt.Errorf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err))
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		host, port, err := net.SplitHostPort(u.Host)
 | 
						|
		if err != nil {
 | 
						|
			host = u.Host
 | 
						|
		}
 | 
						|
 | 
						|
		var tlsConfig *tls.Config
 | 
						|
		dialer := net.Dialer{
 | 
						|
			Timeout: time.Duration(cfg.ConnectionTimeout) * time.Second,
 | 
						|
		}
 | 
						|
 | 
						|
		switch u.Scheme {
 | 
						|
		case "ldap":
 | 
						|
			if port == "" {
 | 
						|
				port = "389"
 | 
						|
			}
 | 
						|
 | 
						|
			fullAddr := fmt.Sprintf("%s://%s", u.Scheme, net.JoinHostPort(host, port))
 | 
						|
			opt := ldap.DialWithDialer(&dialer)
 | 
						|
 | 
						|
			conn, err = c.LDAP.DialURL(fullAddr, opt)
 | 
						|
			if err != nil {
 | 
						|
				break
 | 
						|
			}
 | 
						|
			if conn == nil {
 | 
						|
				err = fmt.Errorf("empty connection after dialing")
 | 
						|
				break
 | 
						|
			}
 | 
						|
			if cfg.StartTLS {
 | 
						|
				tlsConfig, err = getTLSConfig(cfg, host)
 | 
						|
				if err != nil {
 | 
						|
					break
 | 
						|
				}
 | 
						|
				err = conn.StartTLS(tlsConfig)
 | 
						|
			}
 | 
						|
		case "ldaps":
 | 
						|
			if port == "" {
 | 
						|
				port = "636"
 | 
						|
			}
 | 
						|
			tlsConfig, err = getTLSConfig(cfg, host)
 | 
						|
			if err != nil {
 | 
						|
				break
 | 
						|
			}
 | 
						|
 | 
						|
			fullAddr := fmt.Sprintf("%s://%s", u.Scheme, net.JoinHostPort(host, port))
 | 
						|
			opt := ldap.DialWithDialer(&dialer)
 | 
						|
			tls := ldap.DialWithTLSConfig(tlsConfig)
 | 
						|
 | 
						|
			conn, err = c.LDAP.DialURL(fullAddr, opt, tls)
 | 
						|
			if err != nil {
 | 
						|
				break
 | 
						|
			}
 | 
						|
		default:
 | 
						|
			retErr = multierror.Append(retErr, fmt.Errorf("invalid LDAP scheme in url %q", net.JoinHostPort(host, port)))
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if err == nil {
 | 
						|
			if retErr != nil {
 | 
						|
				if c.Logger.IsDebug() {
 | 
						|
					c.Logger.Debug("errors connecting to some hosts", "error", retErr.Error())
 | 
						|
				}
 | 
						|
			}
 | 
						|
			retErr = nil
 | 
						|
			break
 | 
						|
		}
 | 
						|
		retErr = multierror.Append(retErr, fmt.Errorf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err))
 | 
						|
	}
 | 
						|
	if retErr != nil {
 | 
						|
		return nil, retErr
 | 
						|
	}
 | 
						|
	if timeout := cfg.RequestTimeout; timeout > 0 {
 | 
						|
		conn.SetTimeout(time.Duration(timeout) * time.Second)
 | 
						|
	}
 | 
						|
	return conn, nil
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
 * Searches for a username in the ldap server, returning a minimal subset of the
 | 
						|
 * user's attributes (if found)
 | 
						|
 */
 | 
						|
func (c *Client) makeLdapSearchRequest(cfg *ConfigEntry, conn Connection, username string) (*ldap.SearchResult, error) {
 | 
						|
	// Note: The logic below drives the logic in ConfigEntry.Validate().
 | 
						|
	// If updated, please update there as well.
 | 
						|
	var err error
 | 
						|
	if cfg.BindPassword != "" {
 | 
						|
		err = conn.Bind(cfg.BindDN, cfg.BindPassword)
 | 
						|
	} else {
 | 
						|
		err = conn.UnauthenticatedBind(cfg.BindDN)
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP bind (service) failed: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	renderedFilter, err := c.RenderUserSearchFilter(cfg, username)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if c.Logger.IsDebug() {
 | 
						|
		c.Logger.Debug("discovering user", "userdn", cfg.UserDN, "filter", renderedFilter)
 | 
						|
	}
 | 
						|
	ldapRequest := &ldap.SearchRequest{
 | 
						|
		BaseDN:       cfg.UserDN,
 | 
						|
		DerefAliases: ldapDerefAliasMap[cfg.DerefAliases],
 | 
						|
		Scope:        ldap.ScopeWholeSubtree,
 | 
						|
		Filter:       renderedFilter,
 | 
						|
		SizeLimit:    2, // Should be only 1 result. Any number larger (2 or more) means access denied.
 | 
						|
		Attributes: []string{
 | 
						|
			cfg.UserAttr, // Return only needed attributes
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	result, err := conn.Search(ldapRequest)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return result, nil
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
 * Discover and return the bind string for the user attempting to authenticate, as well as the
 | 
						|
 * value to use for the identity alias.
 | 
						|
 * This is handled in one of several ways:
 | 
						|
 *
 | 
						|
 * 1. If DiscoverDN is set, the user object will be searched for using userdn (base search path)
 | 
						|
 *    and userattr (the attribute that maps to the provided username) or user search filter.
 | 
						|
 *    The bind will either be anonymous or use binddn and bindpassword if they were provided.
 | 
						|
 * 2. If upndomain is set, the user dn and alias attribte are constructed as 'username@upndomain'.
 | 
						|
 *    See https://msdn.microsoft.com/en-us/library/cc223499.aspx
 | 
						|
 *
 | 
						|
 */
 | 
						|
func (c *Client) GetUserBindDN(cfg *ConfigEntry, conn Connection, username string) (string, error) {
 | 
						|
	bindDN := ""
 | 
						|
 | 
						|
	// Note: The logic below drives the logic in ConfigEntry.Validate().
 | 
						|
	// If updated, please update there as well.
 | 
						|
	if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
 | 
						|
 | 
						|
		result, err := c.makeLdapSearchRequest(cfg, conn, username)
 | 
						|
		if err != nil {
 | 
						|
			return bindDN, fmt.Errorf("LDAP search for binddn failed %w", err)
 | 
						|
		}
 | 
						|
		if len(result.Entries) != 1 {
 | 
						|
			return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
 | 
						|
		}
 | 
						|
 | 
						|
		bindDN = result.Entries[0].DN
 | 
						|
 | 
						|
	} else {
 | 
						|
		if cfg.UPNDomain != "" {
 | 
						|
			bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
 | 
						|
		} else {
 | 
						|
			bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return bindDN, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) RenderUserSearchFilter(cfg *ConfigEntry, username string) (string, error) {
 | 
						|
	// The UserFilter can be blank if not set, or running this version of the code
 | 
						|
	// on an existing ldap configuration
 | 
						|
	if cfg.UserFilter == "" {
 | 
						|
		cfg.UserFilter = "({{.UserAttr}}={{.Username}})"
 | 
						|
	}
 | 
						|
 | 
						|
	// If userfilter was defined, resolve it as a Go template and use the query to
 | 
						|
	// find the login user
 | 
						|
	if c.Logger.IsDebug() {
 | 
						|
		c.Logger.Debug("compiling search filter", "search_filter", cfg.UserFilter)
 | 
						|
	}
 | 
						|
 | 
						|
	// Parse the configuration as a template.
 | 
						|
	// Example template "({{.UserAttr}}={{.Username}})"
 | 
						|
	t, err := template.New("queryTemplate").Parse(cfg.UserFilter)
 | 
						|
	if err != nil {
 | 
						|
		return "", fmt.Errorf("LDAP search failed due to template compilation error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Build context to pass to template - we will be exposing UserDn and Username.
 | 
						|
	context := struct {
 | 
						|
		UserAttr string
 | 
						|
		Username string
 | 
						|
	}{
 | 
						|
		ldap.EscapeFilter(cfg.UserAttr),
 | 
						|
		ldap.EscapeFilter(username),
 | 
						|
	}
 | 
						|
	if cfg.UPNDomain != "" {
 | 
						|
		context.UserAttr = "userPrincipalName"
 | 
						|
		// Intentionally, calling EscapeFilter(...) (vs EscapeValue) since the
 | 
						|
		// username is being injected into a search filter.
 | 
						|
		// As an untrusted string, the username must be escaped according to RFC
 | 
						|
		// 4515, in order to prevent attackers from injecting characters that could modify the filter
 | 
						|
		context.Username = fmt.Sprintf("%s@%s", ldap.EscapeFilter(username), cfg.UPNDomain)
 | 
						|
	}
 | 
						|
 | 
						|
	// Execute the template. Note that the template context contains escaped input and does
 | 
						|
	// not provide behavior via functions. Additionally, no function map has been provided
 | 
						|
	// during template initialization. The only template functions available during execution
 | 
						|
	// are the predefined global functions: https://pkg.go.dev/text/template#hdr-Functions
 | 
						|
	var renderedFilter bytes.Buffer
 | 
						|
	if err := t.Execute(&renderedFilter, context); err != nil {
 | 
						|
		return "", fmt.Errorf("LDAP search failed due to template parsing error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return renderedFilter.String(), nil
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
 * Returns the value to be used for the entity alias of this user
 | 
						|
 * This is handled in one of several ways:
 | 
						|
 *
 | 
						|
 * 1. If DiscoverDN is set, the user will be searched for using userdn (base search path)
 | 
						|
 *    and userattr (the attribute that maps to the provided username) or user search filter.
 | 
						|
 *    The bind will either be anonymous or use binddn and bindpassword if they were provided.
 | 
						|
 * 2. If upndomain is set, the alias attribte is constructed as 'username@upndomain'.
 | 
						|
 *
 | 
						|
 */
 | 
						|
func (c *Client) GetUserAliasAttributeValue(cfg *ConfigEntry, conn Connection, username string) (string, error) {
 | 
						|
	aliasAttributeValue := ""
 | 
						|
 | 
						|
	// Note: The logic below drives the logic in ConfigEntry.Validate().
 | 
						|
	// If updated, please update there as well.
 | 
						|
	if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
 | 
						|
 | 
						|
		result, err := c.makeLdapSearchRequest(cfg, conn, username)
 | 
						|
		if err != nil {
 | 
						|
			return aliasAttributeValue, fmt.Errorf("LDAP search for entity alias attribute failed: %w", err)
 | 
						|
		}
 | 
						|
		if len(result.Entries) != 1 {
 | 
						|
			return aliasAttributeValue, fmt.Errorf("LDAP search for entity alias attribute 0 or not unique")
 | 
						|
		}
 | 
						|
 | 
						|
		if len(result.Entries[0].Attributes) != 1 {
 | 
						|
			return aliasAttributeValue, fmt.Errorf("LDAP attribute missing for entity alias mapping")
 | 
						|
		}
 | 
						|
 | 
						|
		if len(result.Entries[0].Attributes[0].Values) != 1 {
 | 
						|
			return aliasAttributeValue, fmt.Errorf("LDAP entity alias attribute %s empty or not unique for entity alias mapping", cfg.UserAttr)
 | 
						|
		}
 | 
						|
 | 
						|
		aliasAttributeValue = result.Entries[0].Attributes[0].Values[0]
 | 
						|
	} else {
 | 
						|
		if cfg.UPNDomain != "" {
 | 
						|
			aliasAttributeValue = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
 | 
						|
		} else {
 | 
						|
			aliasAttributeValue = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return aliasAttributeValue, nil
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
 * Returns the DN of the object representing the authenticated user.
 | 
						|
 */
 | 
						|
func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN, username string) (string, error) {
 | 
						|
	userDN := ""
 | 
						|
	if cfg.UPNDomain != "" {
 | 
						|
		// Find the distinguished name for the user if userPrincipalName used for login
 | 
						|
		filter := fmt.Sprintf("(userPrincipalName=%s@%s)", EscapeLDAPValue(username), cfg.UPNDomain)
 | 
						|
		if c.Logger.IsDebug() {
 | 
						|
			c.Logger.Debug("searching upn", "userdn", cfg.UserDN, "filter", filter)
 | 
						|
		}
 | 
						|
		result, err := conn.Search(&ldap.SearchRequest{
 | 
						|
			BaseDN:       cfg.UserDN,
 | 
						|
			Scope:        ldap.ScopeWholeSubtree,
 | 
						|
			DerefAliases: ldapDerefAliasMap[cfg.DerefAliases],
 | 
						|
			Filter:       filter,
 | 
						|
			SizeLimit:    math.MaxInt32,
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			return userDN, fmt.Errorf("LDAP search failed for detecting user: %w", err)
 | 
						|
		}
 | 
						|
		for _, e := range result.Entries {
 | 
						|
			userDN = e.DN
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		userDN = bindDN
 | 
						|
	}
 | 
						|
 | 
						|
	return userDN, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]*ldap.Entry, error) {
 | 
						|
	if cfg.GroupFilter == "" {
 | 
						|
		c.Logger.Warn("groupfilter is empty, will not query server")
 | 
						|
		return make([]*ldap.Entry, 0), nil
 | 
						|
	}
 | 
						|
 | 
						|
	if cfg.GroupDN == "" {
 | 
						|
		c.Logger.Warn("groupdn is empty, will not query server")
 | 
						|
		return make([]*ldap.Entry, 0), nil
 | 
						|
	}
 | 
						|
 | 
						|
	// If groupfilter was defined, resolve it as a Go template and use the query for
 | 
						|
	// returning the user's groups
 | 
						|
	if c.Logger.IsDebug() {
 | 
						|
		c.Logger.Debug("compiling group filter", "group_filter", cfg.GroupFilter)
 | 
						|
	}
 | 
						|
 | 
						|
	// Parse the configuration as a template.
 | 
						|
	// Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
 | 
						|
	t, err := template.New("queryTemplate").Parse(cfg.GroupFilter)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed due to template compilation error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Build context to pass to template - we will be exposing UserDn and Username.
 | 
						|
	context := struct {
 | 
						|
		UserDN   string
 | 
						|
		Username string
 | 
						|
	}{
 | 
						|
		ldap.EscapeFilter(userDN),
 | 
						|
		ldap.EscapeFilter(username),
 | 
						|
	}
 | 
						|
 | 
						|
	var renderedQuery bytes.Buffer
 | 
						|
	if err := t.Execute(&renderedQuery, context); err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed due to template parsing error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if c.Logger.IsDebug() {
 | 
						|
		c.Logger.Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String())
 | 
						|
	}
 | 
						|
 | 
						|
	result, err := conn.Search(&ldap.SearchRequest{
 | 
						|
		BaseDN:       cfg.GroupDN,
 | 
						|
		Scope:        ldap.ScopeWholeSubtree,
 | 
						|
		DerefAliases: ldapDerefAliasMap[cfg.DerefAliases],
 | 
						|
		Filter:       renderedQuery.String(),
 | 
						|
		Attributes: []string{
 | 
						|
			cfg.GroupAttr,
 | 
						|
		},
 | 
						|
		SizeLimit: math.MaxInt32,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return result.Entries, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) performLdapFilterGroupsSearchPaging(cfg *ConfigEntry, conn PagingConnection, userDN string, username string) ([]*ldap.Entry, error) {
 | 
						|
	if cfg.GroupFilter == "" {
 | 
						|
		c.Logger.Warn("groupfilter is empty, will not query server")
 | 
						|
		return make([]*ldap.Entry, 0), nil
 | 
						|
	}
 | 
						|
 | 
						|
	if cfg.GroupDN == "" {
 | 
						|
		c.Logger.Warn("groupdn is empty, will not query server")
 | 
						|
		return make([]*ldap.Entry, 0), nil
 | 
						|
	}
 | 
						|
 | 
						|
	// If groupfilter was defined, resolve it as a Go template and use the query for
 | 
						|
	// returning the user's groups
 | 
						|
	if c.Logger.IsDebug() {
 | 
						|
		c.Logger.Debug("compiling group filter", "group_filter", cfg.GroupFilter)
 | 
						|
	}
 | 
						|
 | 
						|
	// Parse the configuration as a template.
 | 
						|
	// Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
 | 
						|
	t, err := template.New("queryTemplate").Parse(cfg.GroupFilter)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed due to template compilation error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Build context to pass to template - we will be exposing UserDn and Username.
 | 
						|
	context := struct {
 | 
						|
		UserDN   string
 | 
						|
		Username string
 | 
						|
	}{
 | 
						|
		ldap.EscapeFilter(userDN),
 | 
						|
		ldap.EscapeFilter(username),
 | 
						|
	}
 | 
						|
 | 
						|
	// Execute the template. Note that the template context contains escaped input and does
 | 
						|
	// not provide behavior via functions. Additionally, no function map has been provided
 | 
						|
	// during template initialization. The only template functions available during execution
 | 
						|
	// are the predefined global functions: https://pkg.go.dev/text/template#hdr-Functions
 | 
						|
	var renderedQuery bytes.Buffer
 | 
						|
	if err := t.Execute(&renderedQuery, context); err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed due to template parsing error: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if c.Logger.IsDebug() {
 | 
						|
		c.Logger.Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String())
 | 
						|
	}
 | 
						|
 | 
						|
	result, err := conn.SearchWithPaging(&ldap.SearchRequest{
 | 
						|
		BaseDN:       cfg.GroupDN,
 | 
						|
		Scope:        ldap.ScopeWholeSubtree,
 | 
						|
		DerefAliases: ldapDerefAliasMap[cfg.DerefAliases],
 | 
						|
		Filter:       renderedQuery.String(),
 | 
						|
		Attributes: []string{
 | 
						|
			cfg.GroupAttr,
 | 
						|
		},
 | 
						|
		SizeLimit: math.MaxInt32,
 | 
						|
	}, uint32(cfg.MaximumPageSize))
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return result.Entries, nil
 | 
						|
}
 | 
						|
 | 
						|
// sidBytesToString converts a binary security identifier (SID) into its
// textual "S-<revision>-<authority>-<sub>-<sub>..." form.
//
// The layout decoded from b is:
//
//	byte 0      revision
//	byte 1      number of sub-authorities
//	bytes 2-7   identifier authority, 48-bit big-endian
//	then        one little-endian uint32 per sub-authority
//
// If b is too short for any of these fields an error is returned that wraps
// the underlying read error.
func sidBytesToString(b []byte) (string, error) {
	reader := bytes.NewReader(b)

	var revision, subAuthorityCount uint8
	var identifierAuthorityParts [3]uint16

	if err := binary.Read(reader, binary.LittleEndian, &revision); err != nil {
		return "", fmt.Errorf("SID %#v convert failed reading Revision: %w", b, err)
	}

	if err := binary.Read(reader, binary.LittleEndian, &subAuthorityCount); err != nil {
		return "", fmt.Errorf("SID %#v convert failed reading SubAuthorityCount: %w", b, err)
	}

	// The authority is read as three big-endian 16-bit words and recombined
	// below into a single 48-bit value.
	if err := binary.Read(reader, binary.BigEndian, &identifierAuthorityParts); err != nil {
		return "", fmt.Errorf("SID %#v convert failed reading IdentifierAuthority: %w", b, err)
	}
	identifierAuthority := uint64(identifierAuthorityParts[0])<<32 |
		uint64(identifierAuthorityParts[1])<<16 |
		uint64(identifierAuthorityParts[2])

	subAuthority := make([]uint32, subAuthorityCount)
	if err := binary.Read(reader, binary.LittleEndian, &subAuthority); err != nil {
		return "", fmt.Errorf("SID %#v convert failed reading SubAuthority: %w", b, err)
	}

	var sid strings.Builder
	fmt.Fprintf(&sid, "S-%d-%d", revision, identifierAuthority)
	for _, part := range subAuthority {
		fmt.Fprintf(&sid, "-%d", part)
	}

	return sid.String(), nil
}
 | 
						|
 | 
						|
func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string) ([]*ldap.Entry, error) {
 | 
						|
	var wg sync.WaitGroup
 | 
						|
	var lock sync.Mutex
 | 
						|
	taskChan := make(chan string)
 | 
						|
	maxWorkers := 10
 | 
						|
 | 
						|
	result, err := conn.Search(&ldap.SearchRequest{
 | 
						|
		BaseDN:       userDN,
 | 
						|
		Scope:        ldap.ScopeBaseObject,
 | 
						|
		DerefAliases: ldapDerefAliasMap[cfg.DerefAliases],
 | 
						|
		Filter:       "(objectClass=*)",
 | 
						|
		Attributes: []string{
 | 
						|
			"tokenGroups",
 | 
						|
		},
 | 
						|
		SizeLimit: 1,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("LDAP search failed: %w", err)
 | 
						|
	}
 | 
						|
	if len(result.Entries) == 0 {
 | 
						|
		c.Logger.Warn("unable to read object for group attributes", "userdn", userDN, "groupattr", cfg.GroupAttr)
 | 
						|
		return make([]*ldap.Entry, 0), nil
 | 
						|
	}
 | 
						|
 | 
						|
	userEntry := result.Entries[0]
 | 
						|
	groupAttrValues := userEntry.GetRawAttributeValues("tokenGroups")
 | 
						|
	groupEntries := make([]*ldap.Entry, 0, len(groupAttrValues))
 | 
						|
 | 
						|
	for i := 0; i < maxWorkers; i++ {
 | 
						|
		wg.Add(1)
 | 
						|
		go func() {
 | 
						|
			defer wg.Done()
 | 
						|
 | 
						|
			for sid := range taskChan {
 | 
						|
				groupResult, err := conn.Search(&ldap.SearchRequest{
 | 
						|
					BaseDN:       fmt.Sprintf("<SID=%s>", sid),
 | 
						|
					Scope:        ldap.ScopeBaseObject,
 | 
						|
					DerefAliases: ldapDerefAliasMap[cfg.DerefAliases],
 | 
						|
					Filter:       "(objectClass=*)",
 | 
						|
					Attributes: []string{
 | 
						|
						"1.1", // RFC no attributes
 | 
						|
					},
 | 
						|
					SizeLimit: 1,
 | 
						|
				})
 | 
						|
				if err != nil {
 | 
						|
					c.Logger.Warn("unable to read the group sid", "sid", sid)
 | 
						|
					continue
 | 
						|
				}
 | 
						|
 | 
						|
				if len(groupResult.Entries) == 0 {
 | 
						|
					c.Logger.Warn("unable to find the group", "sid", sid)
 | 
						|
					continue
 | 
						|
				}
 | 
						|
 | 
						|
				lock.Lock()
 | 
						|
				groupEntries = append(groupEntries, groupResult.Entries[0])
 | 
						|
				lock.Unlock()
 | 
						|
			}
 | 
						|
		}()
 | 
						|
	}
 | 
						|
 | 
						|
	for _, sidBytes := range groupAttrValues {
 | 
						|
		sidString, err := sidBytesToString(sidBytes)
 | 
						|
		if err != nil {
 | 
						|
			c.Logger.Warn("unable to read sid", "err", err)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		taskChan <- sidString
 | 
						|
	}
 | 
						|
 | 
						|
	close(taskChan)
 | 
						|
	wg.Wait()
 | 
						|
 | 
						|
	return groupEntries, nil
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
 * getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
 | 
						|
 *
 | 
						|
 * If cfg.UseTokenGroups is true then the search is performed directly on the userDN.
 | 
						|
 * The values of those attributes are converted to string SIDs, and then looked up to get ldap.Entry objects.
 | 
						|
 * Otherwise, the search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
 | 
						|
 * Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
 | 
						|
 *
 | 
						|
 * cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
 | 
						|
 *    UserDN - The DN of the authenticated user
 | 
						|
 *    Username - The Username of the authenticated user
 | 
						|
 *
 | 
						|
 * Example:
 | 
						|
 *   cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
 | 
						|
 *   cfg.GroupDN     = "OU=Groups,DC=myorg,DC=com"
 | 
						|
 *   cfg.GroupAttr   = "cn"
 | 
						|
 *
 | 
						|
 * NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
 | 
						|
 *
 | 
						|
 */
 | 
						|
func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]string, error) {
 | 
						|
	var entries []*ldap.Entry
 | 
						|
	var err error
 | 
						|
	if cfg.UseTokenGroups {
 | 
						|
		entries, err = c.performLdapTokenGroupsSearch(cfg, conn, userDN)
 | 
						|
	} else {
 | 
						|
		if paging, ok := conn.(PagingConnection); ok && cfg.MaximumPageSize > 0 {
 | 
						|
			entries, err = c.performLdapFilterGroupsSearchPaging(cfg, paging, userDN, username)
 | 
						|
		} else {
 | 
						|
			entries, err = c.performLdapFilterGroupsSearch(cfg, conn, userDN, username)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
 | 
						|
	ldapMap := make(map[string]bool)
 | 
						|
 | 
						|
	for _, e := range entries {
 | 
						|
		dn, err := ldap.ParseDN(e.DN)
 | 
						|
		if err != nil || len(dn.RDNs) == 0 {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		// Enumerate attributes of each result, parse out CN and add as group
 | 
						|
		values := e.GetAttributeValues(cfg.GroupAttr)
 | 
						|
		if len(values) > 0 {
 | 
						|
			for _, val := range values {
 | 
						|
				groupCN := getCN(cfg, val)
 | 
						|
				ldapMap[groupCN] = true
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			// If groupattr didn't resolve, use self (enumerating group objects)
 | 
						|
			groupCN := getCN(cfg, e.DN)
 | 
						|
			ldapMap[groupCN] = true
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	ldapGroups := make([]string, 0, len(ldapMap))
 | 
						|
	for key := range ldapMap {
 | 
						|
		ldapGroups = append(ldapGroups, key)
 | 
						|
	}
 | 
						|
 | 
						|
	return ldapGroups, nil
 | 
						|
}
 | 
						|
 | 
						|
// EscapeLDAPValue escapes input for use as an LDAP attribute value inside a
// distinguished name. It is exported because a plugin uses it outside this
// package.
//
// RFC 4514 requires escaping of:
//   - a leading space or hash
//   - a trailing space
//   - the special characters '"', '+', ',', ';', '<', '>', '\\'
//   - bytes that have to be written as hex
//
// As implemented:
//   - a space in the first or last position is prefixed with a backslash;
//   - '#' is prefixed with a backslash wherever it occurs, not only when
//     leading;
//   - the special characters listed above are prefixed with a backslash;
//   - every byte outside the range ' '..'~' is written as a backslash followed
//     by its two-digit lowercase hex value, byte by byte.
//
// An empty input yields an empty string.
func EscapeLDAPValue(input string) string {
	if input == "" {
		return ""
	}

	var escaped bytes.Buffer
	final := len(input) - 1

	for pos := 0; pos <= final; pos++ {
		b := input[pos]
		edgeSpace := b == ' ' && (pos == 0 || pos == final)

		switch {
		case edgeSpace || b == '#' || specialChar(b):
			// Backslash-escape the character itself.
			escaped.WriteByte('\\')
			escaped.WriteByte(b)
		case b < ' ' || b > '~':
			// Outside printable ASCII: emit the byte as hex.
			escaped.WriteByte('\\')
			escaped.WriteString(hex.EncodeToString([]byte{b}))
		default:
			escaped.WriteByte(b)
		}
	}
	return escaped.String()
}

// specialChar reports whether char is one of the characters that
// EscapeLDAPValue always backslash-escapes: '"', '+', ',', ';', '<', '>'
// and '\\'.
func specialChar(char byte) bool {
	const specials = "\"+,;<>\\"
	for i := 0; i < len(specials); i++ {
		if specials[i] == char {
			return true
		}
	}
	return false
}
 | 
						|
 | 
						|
/*
 | 
						|
 * Parses a distinguished name and returns the CN portion.
 | 
						|
 * Given a non-conforming string (such as an already-extracted CN),
 | 
						|
 * it will be returned as-is.
 | 
						|
 */
 | 
						|
func getCN(cfg *ConfigEntry, dn string) string {
 | 
						|
	parsedDN, err := ldap.ParseDN(dn)
 | 
						|
	if err != nil || len(parsedDN.RDNs) == 0 {
 | 
						|
		// It was already a CN, return as-is
 | 
						|
		return dn
 | 
						|
	}
 | 
						|
 | 
						|
	for _, rdn := range parsedDN.RDNs {
 | 
						|
		for _, rdnAttr := range rdn.Attributes {
 | 
						|
			if cfg.UsePre111GroupCNBehavior == nil || *cfg.UsePre111GroupCNBehavior {
 | 
						|
				if rdnAttr.Type == "CN" {
 | 
						|
					return rdnAttr.Value
 | 
						|
				}
 | 
						|
			} else {
 | 
						|
				if strings.EqualFold(rdnAttr.Type, "CN") {
 | 
						|
					return rdnAttr.Value
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Default, return self
 | 
						|
	return dn
 | 
						|
}
 | 
						|
 | 
						|
func getTLSConfig(cfg *ConfigEntry, host string) (*tls.Config, error) {
 | 
						|
	tlsConfig := &tls.Config{
 | 
						|
		ServerName: host,
 | 
						|
	}
 | 
						|
 | 
						|
	if cfg.TLSMinVersion != "" {
 | 
						|
		tlsMinVersion, ok := tlsutil.TLSLookup[cfg.TLSMinVersion]
 | 
						|
		if !ok {
 | 
						|
			return nil, fmt.Errorf("invalid 'tls_min_version' in config")
 | 
						|
		}
 | 
						|
		tlsConfig.MinVersion = tlsMinVersion
 | 
						|
	}
 | 
						|
 | 
						|
	if cfg.TLSMaxVersion != "" {
 | 
						|
		tlsMaxVersion, ok := tlsutil.TLSLookup[cfg.TLSMaxVersion]
 | 
						|
		if !ok {
 | 
						|
			return nil, fmt.Errorf("invalid 'tls_max_version' in config")
 | 
						|
		}
 | 
						|
		tlsConfig.MaxVersion = tlsMaxVersion
 | 
						|
	}
 | 
						|
 | 
						|
	if cfg.InsecureTLS {
 | 
						|
		tlsConfig.InsecureSkipVerify = true
 | 
						|
	}
 | 
						|
	if cfg.Certificate != "" {
 | 
						|
		caPool := x509.NewCertPool()
 | 
						|
		ok := caPool.AppendCertsFromPEM([]byte(cfg.Certificate))
 | 
						|
		if !ok {
 | 
						|
			return nil, fmt.Errorf("could not append CA certificate")
 | 
						|
		}
 | 
						|
		tlsConfig.RootCAs = caPool
 | 
						|
	}
 | 
						|
	if cfg.ClientTLSCert != "" && cfg.ClientTLSKey != "" {
 | 
						|
		certificate, err := tls.X509KeyPair([]byte(cfg.ClientTLSCert), []byte(cfg.ClientTLSKey))
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("failed to parse client X509 key pair: %w", err)
 | 
						|
		}
 | 
						|
		tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
 | 
						|
	} else if cfg.ClientTLSCert != "" || cfg.ClientTLSKey != "" {
 | 
						|
		return nil, fmt.Errorf("both client_tls_cert and client_tls_key must be set")
 | 
						|
	}
 | 
						|
	return tlsConfig, nil
 | 
						|
}
 |