mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Compare groups case-insensitively at login time (#3240)
* Compare groups case-insensitively at login time, since Okta groups are case-insensitive but preserving. * Make other group operations case-preserving but otherwise case-insensitive. New groups will be written in lowercase.
This commit is contained in:
		@@ -84,6 +84,11 @@ func (b *backend) Login(req *logical.Request, username string, password string)
 | 
				
			|||||||
	var allGroups []string
 | 
						var allGroups []string
 | 
				
			||||||
	// Import the custom added groups from okta backend
 | 
						// Import the custom added groups from okta backend
 | 
				
			||||||
	user, err := b.User(req.Storage, username)
 | 
						user, err := b.User(req.Storage, username)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if b.Logger().IsDebug() {
 | 
				
			||||||
 | 
								b.Logger().Debug("auth/okta: error looking up user", "error", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	if err == nil && user != nil && user.Groups != nil {
 | 
						if err == nil && user != nil && user.Groups != nil {
 | 
				
			||||||
		if b.Logger().IsDebug() {
 | 
							if b.Logger().IsDebug() {
 | 
				
			||||||
			b.Logger().Debug("auth/okta: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
 | 
								b.Logger().Debug("auth/okta: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
 | 
				
			||||||
@@ -96,9 +101,14 @@ func (b *backend) Login(req *logical.Request, username string, password string)
 | 
				
			|||||||
	// Retrieve policies
 | 
						// Retrieve policies
 | 
				
			||||||
	var policies []string
 | 
						var policies []string
 | 
				
			||||||
	for _, groupName := range allGroups {
 | 
						for _, groupName := range allGroups {
 | 
				
			||||||
		group, err := b.Group(req.Storage, groupName)
 | 
							entry, _, err := b.Group(req.Storage, groupName)
 | 
				
			||||||
		if err == nil && group != nil && group.Policies != nil {
 | 
							if err != nil {
 | 
				
			||||||
			policies = append(policies, group.Policies...)
 | 
								if b.Logger().IsDebug() {
 | 
				
			||||||
 | 
									b.Logger().Debug("auth/okta: error looking up group policies", "error", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err == nil && entry != nil && entry.Policies != nil {
 | 
				
			||||||
 | 
								policies = append(policies, entry.Policies...)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,9 +10,10 @@ import (
 | 
				
			|||||||
	"github.com/hashicorp/vault/helper/policyutil"
 | 
						"github.com/hashicorp/vault/helper/policyutil"
 | 
				
			||||||
	log "github.com/mgutz/logxi/v1"
 | 
						log "github.com/mgutz/logxi/v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/hashicorp/vault/logical"
 | 
						"github.com/hashicorp/vault/logical"
 | 
				
			||||||
	logicaltest "github.com/hashicorp/vault/logical/testing"
 | 
						logicaltest "github.com/hashicorp/vault/logical/testing"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestBackend_Config(t *testing.T) {
 | 
					func TestBackend_Config(t *testing.T) {
 | 
				
			||||||
@@ -52,15 +53,15 @@ func TestBackend_Config(t *testing.T) {
 | 
				
			|||||||
			testConfigCreate(t, configData),
 | 
								testConfigCreate(t, configData),
 | 
				
			||||||
			testLoginWrite(t, username, "wrong", "E0000004", 0, nil),
 | 
								testLoginWrite(t, username, "wrong", "E0000004", 0, nil),
 | 
				
			||||||
			testLoginWrite(t, username, password, "user is not a member of any authorized policy", 0, nil),
 | 
								testLoginWrite(t, username, password, "user is not a member of any authorized policy", 0, nil),
 | 
				
			||||||
			testAccUserGroups(t, username, "local_group,local_group2"),
 | 
								testAccUserGroups(t, username, "local_grouP,lOcal_group2"),
 | 
				
			||||||
			testAccGroups(t, "local_group", "local_group_policy"),
 | 
								testAccGroups(t, "local_groUp", "loCal_group_policy"),
 | 
				
			||||||
			testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
 | 
								testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
 | 
				
			||||||
			testAccGroups(t, "Everyone", "everyone_group_policy,every_group_policy2"),
 | 
								testAccGroups(t, "everyoNe", "everyone_grouP_policy,eveRy_group_policy2"),
 | 
				
			||||||
			testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
 | 
								testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
 | 
				
			||||||
			testConfigUpdate(t, configDataToken),
 | 
								testConfigUpdate(t, configDataToken),
 | 
				
			||||||
			testConfigRead(t, token, configData),
 | 
								testConfigRead(t, token, configData),
 | 
				
			||||||
			testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy"}),
 | 
								testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy"}),
 | 
				
			||||||
			testAccGroups(t, "local_group2", "testgroup_group_policy"),
 | 
								testAccGroups(t, "locAl_group2", "testgroup_group_policy"),
 | 
				
			||||||
			testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy"}),
 | 
								testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy"}),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,8 @@
 | 
				
			|||||||
package okta
 | 
					package okta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/hashicorp/vault/helper/policyutil"
 | 
						"github.com/hashicorp/vault/helper/policyutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/logical"
 | 
						"github.com/hashicorp/vault/logical"
 | 
				
			||||||
	"github.com/hashicorp/vault/logical/framework"
 | 
						"github.com/hashicorp/vault/logical/framework"
 | 
				
			||||||
@@ -45,34 +47,59 @@ func pathGroups(b *backend) *framework.Path {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
 | 
					// We look up groups in a case-insensitive manner since Okta is case-preserving
 | 
				
			||||||
 | 
					// but case-insensitive for comparisons
 | 
				
			||||||
 | 
					func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, string, error) {
 | 
				
			||||||
 | 
						canonicalName := n
 | 
				
			||||||
	entry, err := s.Get("group/" + n)
 | 
						entry, err := s.Get("group/" + n)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if entry == nil {
 | 
						if entry == nil {
 | 
				
			||||||
		return nil, nil
 | 
							entries, err := s.List("group/")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							for _, groupName := range entries {
 | 
				
			||||||
 | 
								if strings.ToLower(groupName) == strings.ToLower(n) {
 | 
				
			||||||
 | 
									entry, err = s.Get("group/" + groupName)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return nil, "", err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									canonicalName = groupName
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if entry == nil {
 | 
				
			||||||
 | 
							return nil, "", nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var result GroupEntry
 | 
						var result GroupEntry
 | 
				
			||||||
	if err := entry.DecodeJSON(&result); err != nil {
 | 
						if err := entry.DecodeJSON(&result); err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &result, nil
 | 
						return &result, canonicalName, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (b *backend) pathGroupDelete(
 | 
					func (b *backend) pathGroupDelete(
 | 
				
			||||||
	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 | 
						req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 | 
				
			||||||
	name := d.Get("name").(string)
 | 
						name := d.Get("name").(string)
 | 
				
			||||||
	if len(name) == 0 {
 | 
						if len(name) == 0 {
 | 
				
			||||||
		return logical.ErrorResponse("Error empty name"), nil
 | 
							return logical.ErrorResponse("'name' must be supplied"), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := req.Storage.Delete("group/" + name)
 | 
						entry, canonicalName, err := b.Group(req.Storage, name)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if entry != nil {
 | 
				
			||||||
 | 
							err := req.Storage.Delete("group/" + canonicalName)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil, nil
 | 
						return nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -81,10 +108,10 @@ func (b *backend) pathGroupRead(
 | 
				
			|||||||
	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 | 
						req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 | 
				
			||||||
	name := d.Get("name").(string)
 | 
						name := d.Get("name").(string)
 | 
				
			||||||
	if len(name) == 0 {
 | 
						if len(name) == 0 {
 | 
				
			||||||
		return logical.ErrorResponse("Error empty name"), nil
 | 
							return logical.ErrorResponse("'name' must be supplied"), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	group, err := b.Group(req.Storage, name)
 | 
						group, _, err := b.Group(req.Storage, name)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -103,7 +130,19 @@ func (b *backend) pathGroupWrite(
 | 
				
			|||||||
	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 | 
						req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 | 
				
			||||||
	name := d.Get("name").(string)
 | 
						name := d.Get("name").(string)
 | 
				
			||||||
	if len(name) == 0 {
 | 
						if len(name) == 0 {
 | 
				
			||||||
		return logical.ErrorResponse("Error empty name"), nil
 | 
							return logical.ErrorResponse("'name' must be supplied"), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check for an existing group, possibly lowercased so that we keep using
 | 
				
			||||||
 | 
						// existing user set values
 | 
				
			||||||
 | 
						_, canonicalName, err := b.Group(req.Storage, name)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if canonicalName != "" {
 | 
				
			||||||
 | 
							name = canonicalName
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							name = strings.ToLower(name)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{
 | 
						entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user