mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 19:47:54 +00:00
AppRole/Identity: Fix for race when creating an entity during login (#3932)
* possible fix for race in approle login while creating entity * Add a test that hits the login request concurrently * address review comments
This commit is contained in:
86
command/approle_concurrency_integ_test.go
Normal file
86
command/approle_concurrency_integ_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
logxi "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
|
||||
var err error
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: logxi.NullLog,
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"approle": credAppRole.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
cores := cluster.Cores
|
||||
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
client := cores[0].Client
|
||||
|
||||
err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
|
||||
Type: "approle",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/approle/role/role1", map[string]interface{}{
|
||||
"bind_secret_id": "true",
|
||||
"period": "300",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secret, err := client.Logical().Write("auth/approle/role/role1/secret-id", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secretID := secret.Data["secret_id"].(string)
|
||||
|
||||
secret, err = client.Logical().Read("auth/approle/role/role1/role-id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
roleID := secret.Data["role_id"].(string)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
secret, err = client.Logical().Write("auth/approle/login", map[string]interface{}{
|
||||
"role_id": roleID,
|
||||
"secret_id": secretID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret.Auth.ClientToken == "" {
|
||||
t.Fatalf("expected a successful login")
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -249,7 +249,27 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
|
||||
return nil, fmt.Errorf("missing alias name")
|
||||
}
|
||||
|
||||
alias, err := i.MemDBAliasByFactors(mountAccessor, aliasName, false, false)
|
||||
txn := i.db.Txn(false)
|
||||
|
||||
return i.entityByAliasFactorsInTxn(txn, mountAccessor, aliasName, clone)
|
||||
}
|
||||
|
||||
// entityByAlaisFactorsInTxn fetches the entity based on factors of alias, i.e
|
||||
// mount accessor and the alias name.
|
||||
func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool) (*identity.Entity, error) {
|
||||
if txn == nil {
|
||||
return nil, fmt.Errorf("nil txn")
|
||||
}
|
||||
|
||||
if mountAccessor == "" {
|
||||
return nil, fmt.Errorf("missing mount accessor")
|
||||
}
|
||||
|
||||
if aliasName == "" {
|
||||
return nil, fmt.Errorf("missing alias name")
|
||||
}
|
||||
|
||||
alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -258,12 +278,12 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return i.MemDBEntityByAliasID(alias.ID, clone)
|
||||
return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone)
|
||||
}
|
||||
|
||||
// CreateEntity creates a new entity. This is used by core to
|
||||
// CreateOrFetchEntity creates a new entity. This is used by core to
|
||||
// associate each login attempt by an alias to a unified entity in Vault.
|
||||
func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) {
|
||||
func (i *IdentityStore) CreateOrFetchEntity(alias *logical.Alias) (*identity.Entity, error) {
|
||||
var entity *identity.Entity
|
||||
var err error
|
||||
|
||||
@@ -290,9 +310,24 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
|
||||
return nil, err
|
||||
}
|
||||
if entity != nil {
|
||||
return nil, fmt.Errorf("alias already belongs to a different entity")
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Create a MemDB transaction to update both alias and entity
|
||||
txn := i.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
// Check if an entity was created before acquiring the lock
|
||||
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entity != nil {
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
i.logger.Debug("identity: creating a new entity", "alias", alias)
|
||||
|
||||
entity = &identity.Entity{}
|
||||
|
||||
err = i.sanitizeEntity(entity)
|
||||
@@ -320,10 +355,12 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
|
||||
}
|
||||
|
||||
// Update MemDB and persist entity object
|
||||
err = i.upsertEntity(entity, nil, true)
|
||||
err = i.upsertEntityInTxn(txn, entity, nil, true, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func TestIdentityStore_CreateEntity(t *testing.T) {
|
||||
func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
|
||||
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
|
||||
alias := &logical.Alias{
|
||||
MountType: "github",
|
||||
@@ -17,7 +17,7 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
|
||||
Name: "githubuser",
|
||||
}
|
||||
|
||||
entity, err := is.CreateEntity(alias)
|
||||
entity, err := is.CreateOrFetchEntity(alias)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -33,10 +33,20 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
|
||||
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
|
||||
}
|
||||
|
||||
// Try recreating an entity with the same alias details. It should fail.
|
||||
entity, err = is.CreateEntity(alias)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error")
|
||||
entity, err = is.CreateOrFetchEntity(alias)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entity == nil {
|
||||
t.Fatalf("expected a non-nil entity")
|
||||
}
|
||||
|
||||
if len(entity.Aliases) != 1 {
|
||||
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases))
|
||||
}
|
||||
|
||||
if entity.Aliases[0].Name != alias.Name {
|
||||
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -666,12 +666,29 @@ func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clo
|
||||
return nil, fmt.Errorf("missing mount accessor")
|
||||
}
|
||||
|
||||
txn := i.db.Txn(false)
|
||||
|
||||
return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
|
||||
}
|
||||
|
||||
func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
|
||||
if txn == nil {
|
||||
return nil, fmt.Errorf("nil txn")
|
||||
}
|
||||
|
||||
if aliasName == "" {
|
||||
return nil, fmt.Errorf("missing alias name")
|
||||
}
|
||||
|
||||
if mountAccessor == "" {
|
||||
return nil, fmt.Errorf("missing mount accessor")
|
||||
}
|
||||
|
||||
tableName := entityAliasesTable
|
||||
if groupAlias {
|
||||
tableName = groupAliasesTable
|
||||
}
|
||||
|
||||
txn := i.db.Txn(false)
|
||||
aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err)
|
||||
|
||||
@@ -436,22 +436,15 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
|
||||
|
||||
var err error
|
||||
|
||||
// Check if an entity already exists for the given alias
|
||||
entity, err = c.identityStore.entityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false)
|
||||
// Fetch the entity for the alias, or create an entity if one
|
||||
// doesn't exist.
|
||||
entity, err = c.identityStore.CreateOrFetchEntity(auth.Alias)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// If not, create one.
|
||||
if entity == nil {
|
||||
c.logger.Debug("core: creating a new entity", "alias", auth.Alias)
|
||||
entity, err = c.identityStore.CreateEntity(auth.Alias)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if entity == nil {
|
||||
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
|
||||
}
|
||||
|
||||
auth.EntityID = entity.ID
|
||||
|
||||
Reference in New Issue
Block a user