mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-03 03:58:01 +00:00
Compare groups case-insensitively at login time (#3240)
* Compare groups case-insensitively at login time, since Okta groups are case-insensitive but preserving. * Make other group operations case-preserving but otherwise case-insensitive. New groups will be written in lowercase.
This commit is contained in:
@@ -84,6 +84,11 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
|||||||
var allGroups []string
|
var allGroups []string
|
||||||
// Import the custom added groups from okta backend
|
// Import the custom added groups from okta backend
|
||||||
user, err := b.User(req.Storage, username)
|
user, err := b.User(req.Storage, username)
|
||||||
|
if err != nil {
|
||||||
|
if b.Logger().IsDebug() {
|
||||||
|
b.Logger().Debug("auth/okta: error looking up user", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err == nil && user != nil && user.Groups != nil {
|
if err == nil && user != nil && user.Groups != nil {
|
||||||
if b.Logger().IsDebug() {
|
if b.Logger().IsDebug() {
|
||||||
b.Logger().Debug("auth/okta: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
|
b.Logger().Debug("auth/okta: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
|
||||||
@@ -96,9 +101,14 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
|||||||
// Retrieve policies
|
// Retrieve policies
|
||||||
var policies []string
|
var policies []string
|
||||||
for _, groupName := range allGroups {
|
for _, groupName := range allGroups {
|
||||||
group, err := b.Group(req.Storage, groupName)
|
entry, _, err := b.Group(req.Storage, groupName)
|
||||||
if err == nil && group != nil && group.Policies != nil {
|
if err != nil {
|
||||||
policies = append(policies, group.Policies...)
|
if b.Logger().IsDebug() {
|
||||||
|
b.Logger().Debug("auth/okta: error looking up group policies", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == nil && entry != nil && entry.Policies != nil {
|
||||||
|
policies = append(policies, entry.Policies...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,9 +10,10 @@ import (
|
|||||||
"github.com/hashicorp/vault/helper/policyutil"
|
"github.com/hashicorp/vault/helper/policyutil"
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBackend_Config(t *testing.T) {
|
func TestBackend_Config(t *testing.T) {
|
||||||
@@ -52,15 +53,15 @@ func TestBackend_Config(t *testing.T) {
|
|||||||
testConfigCreate(t, configData),
|
testConfigCreate(t, configData),
|
||||||
testLoginWrite(t, username, "wrong", "E0000004", 0, nil),
|
testLoginWrite(t, username, "wrong", "E0000004", 0, nil),
|
||||||
testLoginWrite(t, username, password, "user is not a member of any authorized policy", 0, nil),
|
testLoginWrite(t, username, password, "user is not a member of any authorized policy", 0, nil),
|
||||||
testAccUserGroups(t, username, "local_group,local_group2"),
|
testAccUserGroups(t, username, "local_grouP,lOcal_group2"),
|
||||||
testAccGroups(t, "local_group", "local_group_policy"),
|
testAccGroups(t, "local_groUp", "loCal_group_policy"),
|
||||||
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
|
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
|
||||||
testAccGroups(t, "Everyone", "everyone_group_policy,every_group_policy2"),
|
testAccGroups(t, "everyoNe", "everyone_grouP_policy,eveRy_group_policy2"),
|
||||||
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
|
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
|
||||||
testConfigUpdate(t, configDataToken),
|
testConfigUpdate(t, configDataToken),
|
||||||
testConfigRead(t, token, configData),
|
testConfigRead(t, token, configData),
|
||||||
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy"}),
|
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy"}),
|
||||||
testAccGroups(t, "local_group2", "testgroup_group_policy"),
|
testAccGroups(t, "locAl_group2", "testgroup_group_policy"),
|
||||||
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy"}),
|
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy"}),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package okta
|
package okta
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/policyutil"
|
"github.com/hashicorp/vault/helper/policyutil"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
@@ -45,34 +47,59 @@ func pathGroups(b *backend) *framework.Path {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
|
// We look up groups in a case-insensitive manner since Okta is case-preserving
|
||||||
|
// but case-insensitive for comparisons
|
||||||
|
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, string, error) {
|
||||||
|
canonicalName := n
|
||||||
entry, err := s.Get("group/" + n)
|
entry, err := s.Get("group/" + n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
return nil, nil
|
entries, err := s.List("group/")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
for _, groupName := range entries {
|
||||||
|
if strings.ToLower(groupName) == strings.ToLower(n) {
|
||||||
|
entry, err = s.Get("group/" + groupName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
canonicalName = groupName
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
return nil, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var result GroupEntry
|
var result GroupEntry
|
||||||
if err := entry.DecodeJSON(&result); err != nil {
|
if err := entry.DecodeJSON(&result); err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &result, nil
|
return &result, canonicalName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *backend) pathGroupDelete(
|
func (b *backend) pathGroupDelete(
|
||||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||||
name := d.Get("name").(string)
|
name := d.Get("name").(string)
|
||||||
if len(name) == 0 {
|
if len(name) == 0 {
|
||||||
return logical.ErrorResponse("Error empty name"), nil
|
return logical.ErrorResponse("'name' must be supplied"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := req.Storage.Delete("group/" + name)
|
entry, canonicalName, err := b.Group(req.Storage, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if entry != nil {
|
||||||
|
err := req.Storage.Delete("group/" + canonicalName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -81,10 +108,10 @@ func (b *backend) pathGroupRead(
|
|||||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||||
name := d.Get("name").(string)
|
name := d.Get("name").(string)
|
||||||
if len(name) == 0 {
|
if len(name) == 0 {
|
||||||
return logical.ErrorResponse("Error empty name"), nil
|
return logical.ErrorResponse("'name' must be supplied"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := b.Group(req.Storage, name)
|
group, _, err := b.Group(req.Storage, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -103,7 +130,19 @@ func (b *backend) pathGroupWrite(
|
|||||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||||
name := d.Get("name").(string)
|
name := d.Get("name").(string)
|
||||||
if len(name) == 0 {
|
if len(name) == 0 {
|
||||||
return logical.ErrorResponse("Error empty name"), nil
|
return logical.ErrorResponse("'name' must be supplied"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for an existing group, possibly lowercased so that we keep using
|
||||||
|
// existing user set values
|
||||||
|
_, canonicalName, err := b.Group(req.Storage, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if canonicalName != "" {
|
||||||
|
name = canonicalName
|
||||||
|
} else {
|
||||||
|
name = strings.ToLower(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{
|
entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{
|
||||||
|
|||||||
Reference in New Issue
Block a user