Compare groups case-insensitively at login time (#3240)

* Compare groups case-insensitively at login time, since Okta groups are
case-insensitive but preserving.

* Make other group operations case-preserving but otherwise
case-insensitive. New groups will be written in lowercase.
This commit is contained in:
Jeff Mitchell
2017-08-25 14:48:37 -04:00
committed by GitHub
parent ae825401e1
commit 341636336b
3 changed files with 68 additions and 18 deletions

View File

@@ -84,6 +84,11 @@ func (b *backend) Login(req *logical.Request, username string, password string)
var allGroups []string
// Import the custom added groups from okta backend
user, err := b.User(req.Storage, username)
if err != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: error looking up user", "error", err)
}
}
if err == nil && user != nil && user.Groups != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
@@ -96,9 +101,14 @@ func (b *backend) Login(req *logical.Request, username string, password string)
// Retrieve policies
var policies []string
for _, groupName := range allGroups {
group, err := b.Group(req.Storage, groupName)
if err == nil && group != nil && group.Policies != nil {
policies = append(policies, group.Policies...)
entry, _, err := b.Group(req.Storage, groupName)
if err != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: error looking up group policies", "error", err)
}
}
if err == nil && entry != nil && entry.Policies != nil {
policies = append(policies, entry.Policies...)
}
}

View File

@@ -10,9 +10,10 @@ import (
"github.com/hashicorp/vault/helper/policyutil"
log "github.com/mgutz/logxi/v1"
"time"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"time"
)
func TestBackend_Config(t *testing.T) {
@@ -52,15 +53,15 @@ func TestBackend_Config(t *testing.T) {
testConfigCreate(t, configData),
testLoginWrite(t, username, "wrong", "E0000004", 0, nil),
testLoginWrite(t, username, password, "user is not a member of any authorized policy", 0, nil),
testAccUserGroups(t, username, "local_group,local_group2"),
testAccGroups(t, "local_group", "local_group_policy"),
testAccUserGroups(t, username, "local_grouP,lOcal_group2"),
testAccGroups(t, "local_groUp", "loCal_group_policy"),
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
testAccGroups(t, "Everyone", "everyone_group_policy,every_group_policy2"),
testAccGroups(t, "everyoNe", "everyone_grouP_policy,eveRy_group_policy2"),
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
testConfigUpdate(t, configDataToken),
testConfigRead(t, token, configData),
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy"}),
testAccGroups(t, "local_group2", "testgroup_group_policy"),
testAccGroups(t, "locAl_group2", "testgroup_group_policy"),
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy"}),
},
})

View File

@@ -1,6 +1,8 @@
package okta
import (
"strings"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@@ -45,34 +47,59 @@ func pathGroups(b *backend) *framework.Path {
}
}
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
// We look up groups in a case-insensitive manner since Okta is case-preserving
// but case-insensitive for comparisons
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, string, error) {
canonicalName := n
entry, err := s.Get("group/" + n)
if err != nil {
return nil, err
return nil, "", err
}
if entry == nil {
return nil, nil
entries, err := s.List("group/")
if err != nil {
return nil, "", err
}
for _, groupName := range entries {
if strings.ToLower(groupName) == strings.ToLower(n) {
entry, err = s.Get("group/" + groupName)
if err != nil {
return nil, "", err
}
canonicalName = groupName
break
}
}
}
if entry == nil {
return nil, "", nil
}
var result GroupEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
return nil, "", err
}
return &result, nil
return &result, canonicalName, nil
}
func (b *backend) pathGroupDelete(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if len(name) == 0 {
return logical.ErrorResponse("Error empty name"), nil
return logical.ErrorResponse("'name' must be supplied"), nil
}
err := req.Storage.Delete("group/" + name)
entry, canonicalName, err := b.Group(req.Storage, name)
if err != nil {
return nil, err
}
if entry != nil {
err := req.Storage.Delete("group/" + canonicalName)
if err != nil {
return nil, err
}
}
return nil, nil
}
@@ -81,10 +108,10 @@ func (b *backend) pathGroupRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if len(name) == 0 {
return logical.ErrorResponse("Error empty name"), nil
return logical.ErrorResponse("'name' must be supplied"), nil
}
group, err := b.Group(req.Storage, name)
group, _, err := b.Group(req.Storage, name)
if err != nil {
return nil, err
}
@@ -103,7 +130,19 @@ func (b *backend) pathGroupWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if len(name) == 0 {
return logical.ErrorResponse("Error empty name"), nil
return logical.ErrorResponse("'name' must be supplied"), nil
}
// Check for an existing group, possibly lowercased so that we keep using
// existing user set values
_, canonicalName, err := b.Group(req.Storage, name)
if err != nil {
return nil, err
}
if canonicalName != "" {
name = canonicalName
} else {
name = strings.ToLower(name)
}
entry, err := logical.StorageEntryJSON("group/"+name, &GroupEntry{