mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 11:38:02 +00:00
Properly check for policy equivalency during renewal.
This introduces a function that compares two string policy sets while ignoring the presence of "default" (since it's added by core, not the backend), and ensuring that ordering and/or duplication are not failure conditions. Fixes #1256
This commit is contained in:
@@ -6,10 +6,9 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/policies"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
)
|
)
|
||||||
@@ -120,12 +119,11 @@ func (b *backend) pathLoginRenew(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the policies associated with the app
|
// Get the policies associated with the app
|
||||||
policies, err := b.MapAppId.Policies(req.Storage, appId)
|
mapPolicies, err := b.MapAppId.Policies(req.Storage, appId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sort.Strings(req.Auth.Policies)
|
if !policies.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
|
||||||
if !reflect.DeepEqual(policies, req.Auth.Policies) {
|
|
||||||
return logical.ErrorResponse("policies do not match"), nil
|
return logical.ErrorResponse("policies do not match"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -436,6 +436,7 @@ func Test_Renew(t *testing.T) {
|
|||||||
req.Auth.InternalData = resp.Auth.InternalData
|
req.Auth.InternalData = resp.Auth.InternalData
|
||||||
req.Auth.Metadata = resp.Auth.Metadata
|
req.Auth.Metadata = resp.Auth.Metadata
|
||||||
req.Auth.LeaseOptions = resp.Auth.LeaseOptions
|
req.Auth.LeaseOptions = resp.Auth.LeaseOptions
|
||||||
|
req.Auth.Policies = resp.Auth.Policies
|
||||||
req.Auth.IssueTime = time.Now()
|
req.Auth.IssueTime = time.Now()
|
||||||
|
|
||||||
// Normal renewal
|
// Normal renewal
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"sort"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/certutil"
|
"github.com/hashicorp/vault/helper/certutil"
|
||||||
|
"github.com/hashicorp/vault/helper/policies"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
)
|
)
|
||||||
@@ -59,16 +60,12 @@ func (b *backend) pathLogin(
|
|||||||
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
|
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
|
||||||
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
|
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
|
||||||
|
|
||||||
// We want to sort here so we can check properly during renewal)
|
|
||||||
sort.Strings(matched.Entry.Policies)
|
|
||||||
|
|
||||||
// Generate a response
|
// Generate a response
|
||||||
resp := &logical.Response{
|
resp := &logical.Response{
|
||||||
Auth: &logical.Auth{
|
Auth: &logical.Auth{
|
||||||
InternalData: map[string]interface{}{
|
InternalData: map[string]interface{}{
|
||||||
"subject_key_id": skid,
|
"subject_key_id": skid,
|
||||||
"authority_key_id": akid,
|
"authority_key_id": akid,
|
||||||
"policies": strings.Join(matched.Entry.Policies, ","),
|
|
||||||
},
|
},
|
||||||
Policies: matched.Entry.Policies,
|
Policies: matched.Entry.Policies,
|
||||||
DisplayName: matched.Entry.DisplayName,
|
DisplayName: matched.Entry.DisplayName,
|
||||||
@@ -132,10 +129,9 @@ func (b *backend) pathLoginRenew(
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
policies := cert.Policies
|
if !policies.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
|
||||||
sort.Strings(policies)
|
return logical.ErrorResponse(fmt.Sprintf("policies have changed (%#v vs %#v), not renewing", cert.Policies, req.Auth.Policies)), nil
|
||||||
if strings.Join(policies, ",") != req.Auth.InternalData["policies"] {
|
// return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d)
|
return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d)
|
||||||
|
|||||||
@@ -3,10 +3,9 @@ package github
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/google/go-github/github"
|
"github.com/google/go-github/github"
|
||||||
|
"github.com/hashicorp/vault/helper/policies"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
)
|
)
|
||||||
@@ -83,8 +82,7 @@ func (b *backend) pathLoginRenew(
|
|||||||
} else {
|
} else {
|
||||||
verifyResp = verifyResponse
|
verifyResp = verifyResponse
|
||||||
}
|
}
|
||||||
sort.Strings(req.Auth.Policies)
|
if !policies.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
|
||||||
if !reflect.DeepEqual(verifyResp.Policies, req.Auth.Policies) {
|
|
||||||
return logical.ErrorResponse("policies do not match"), nil
|
return logical.ErrorResponse("policies do not match"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ func testAccStepLogin(t *testing.T, user string, pass string) logicaltest.TestSt
|
|||||||
},
|
},
|
||||||
Unauthenticated: true,
|
Unauthenticated: true,
|
||||||
|
|
||||||
Check: logicaltest.TestCheckAuth([]string{"foo", "bar"}),
|
Check: logicaltest.TestCheckAuth([]string{"bar", "default", "foo"}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/policies"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
)
|
)
|
||||||
@@ -67,15 +68,13 @@ func (b *backend) pathLoginRenew(
|
|||||||
|
|
||||||
username := req.Auth.Metadata["username"]
|
username := req.Auth.Metadata["username"]
|
||||||
password := req.Auth.InternalData["password"].(string)
|
password := req.Auth.InternalData["password"].(string)
|
||||||
prevpolicies := req.Auth.Metadata["policies"]
|
|
||||||
|
|
||||||
policies, resp, err := b.Login(req, username, password)
|
loginPolicies, resp, err := b.Login(req, username, password)
|
||||||
if len(policies) == 0 {
|
if len(loginPolicies) == 0 {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(policies)
|
if !policies.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
|
||||||
if strings.Join(policies, ",") != prevpolicies {
|
|
||||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/policies"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -92,6 +93,10 @@ func (b *backend) pathLoginRenew(
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !policies.EquivalentPolicies(user.Policies, req.Auth.Policies) {
|
||||||
|
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||||
|
}
|
||||||
|
|
||||||
return framework.LeaseExtend(user.TTL, user.MaxTTL, b.System())(req, d)
|
return framework.LeaseExtend(user.TTL, user.MaxTTL, b.System())(req, d)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
57
helper/policies/policies.go
Normal file
57
helper/policies/policies.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package policies
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
// ComparePolicies checks whether the given policy sets are equivalent, as in,
|
||||||
|
// they contain the same values. The benefit of this method is that it leaves
|
||||||
|
// the "default" policy out of its comparisons as it may be added later by core
|
||||||
|
// after a set of policies has been saved by a backend.
|
||||||
|
func EquivalentPolicies(a, b []string) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// First we'll build maps to ensure unique values and filter default
|
||||||
|
mapA := map[string]bool{}
|
||||||
|
mapB := map[string]bool{}
|
||||||
|
for _, keyA := range a {
|
||||||
|
if keyA == "default" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mapA[keyA] = true
|
||||||
|
}
|
||||||
|
for _, keyB := range b {
|
||||||
|
if keyB == "default" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mapB[keyB] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we'll build our checking slices
|
||||||
|
var sortedA, sortedB []string
|
||||||
|
for keyA, _ := range mapA {
|
||||||
|
sortedA = append(sortedA, keyA)
|
||||||
|
}
|
||||||
|
for keyB, _ := range mapB {
|
||||||
|
sortedB = append(sortedB, keyB)
|
||||||
|
}
|
||||||
|
sort.Strings(sortedA)
|
||||||
|
sort.Strings(sortedB)
|
||||||
|
|
||||||
|
// Finally, compare
|
||||||
|
if len(sortedA) != len(sortedB) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range a {
|
||||||
|
if sortedA[i] != sortedB[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
26
helper/policies/policies_test.go
Normal file
26
helper/policies/policies_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package policies
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestEquivalentPolicies(t *testing.T) {
|
||||||
|
a := []string{"foo", "bar"}
|
||||||
|
var b []string
|
||||||
|
if EquivalentPolicies(a, b) {
|
||||||
|
t.Fatal("bad")
|
||||||
|
}
|
||||||
|
|
||||||
|
b = []string{"foo"}
|
||||||
|
if EquivalentPolicies(a, b) {
|
||||||
|
t.Fatal("bad")
|
||||||
|
}
|
||||||
|
|
||||||
|
b = []string{"bar", "foo"}
|
||||||
|
if !EquivalentPolicies(a, b) {
|
||||||
|
t.Fatal("bad")
|
||||||
|
}
|
||||||
|
|
||||||
|
b = []string{"foo", "default", "bar"}
|
||||||
|
if !EquivalentPolicies(a, b) {
|
||||||
|
t.Fatal("bad")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user