AppRole/Identity: Fix for race when creating an entity during login (#3932)

* possible fix for race in approle login while creating entity

* Add a test that hits the login request concurrently

* address review comments
This commit is contained in:
Vishal Nayak
2018-02-09 10:40:56 -05:00
committed by GitHub
parent e47c7e866a
commit 5bb8fa2469
5 changed files with 167 additions and 24 deletions

View File

@@ -0,0 +1,86 @@
package command
import (
"sync"
"testing"
"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
logxi "github.com/mgutz/logxi/v1"
)
func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
var err error
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logxi.NullLog,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
client := cores[0].Client
err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
Type: "approle",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/approle/role/role1", map[string]interface{}{
"bind_secret_id": "true",
"period": "300",
})
if err != nil {
t.Fatal(err)
}
secret, err := client.Logical().Write("auth/approle/role/role1/secret-id", nil)
if err != nil {
t.Fatal(err)
}
secretID := secret.Data["secret_id"].(string)
secret, err = client.Logical().Read("auth/approle/role/role1/role-id")
if err != nil {
t.Fatal(err)
}
roleID := secret.Data["role_id"].(string)
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
secret, err = client.Logical().Write("auth/approle/login", map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
})
if err != nil {
t.Fatal(err)
}
if secret.Auth.ClientToken == "" {
t.Fatalf("expected a successful login")
}
}()
}
wg.Wait()
}

View File

@@ -249,7 +249,27 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
return nil, fmt.Errorf("missing alias name")
}
alias, err := i.MemDBAliasByFactors(mountAccessor, aliasName, false, false)
txn := i.db.Txn(false)
return i.entityByAliasFactorsInTxn(txn, mountAccessor, aliasName, clone)
}
// entityByAlaisFactorsInTxn fetches the entity based on factors of alias, i.e
// mount accessor and the alias name.
func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool) (*identity.Entity, error) {
if txn == nil {
return nil, fmt.Errorf("nil txn")
}
if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}
if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}
alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false)
if err != nil {
return nil, err
}
@@ -258,12 +278,12 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
return nil, nil
}
return i.MemDBEntityByAliasID(alias.ID, clone)
return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone)
}
// CreateEntity creates a new entity. This is used by core to
// CreateOrFetchEntity creates a new entity. This is used by core to
// associate each login attempt by an alias to a unified entity in Vault.
func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) {
func (i *IdentityStore) CreateOrFetchEntity(alias *logical.Alias) (*identity.Entity, error) {
var entity *identity.Entity
var err error
@@ -290,9 +310,24 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
return nil, err
}
if entity != nil {
return nil, fmt.Errorf("alias already belongs to a different entity")
return entity, nil
}
// Create a MemDB transaction to update both alias and entity
txn := i.db.Txn(true)
defer txn.Abort()
// Check if an entity was created before acquiring the lock
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
if err != nil {
return nil, err
}
if entity != nil {
return entity, nil
}
i.logger.Debug("identity: creating a new entity", "alias", alias)
entity = &identity.Entity{}
err = i.sanitizeEntity(entity)
@@ -320,10 +355,12 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
}
// Update MemDB and persist entity object
err = i.upsertEntity(entity, nil, true)
err = i.upsertEntityInTxn(txn, entity, nil, true, false)
if err != nil {
return nil, err
}
txn.Commit()
return entity, nil
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func TestIdentityStore_CreateEntity(t *testing.T) {
func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
alias := &logical.Alias{
MountType: "github",
@@ -17,7 +17,7 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
Name: "githubuser",
}
entity, err := is.CreateEntity(alias)
entity, err := is.CreateOrFetchEntity(alias)
if err != nil {
t.Fatal(err)
}
@@ -33,10 +33,20 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}
// Try recreating an entity with the same alias details. It should fail.
entity, err = is.CreateEntity(alias)
if err == nil {
t.Fatalf("expected an error")
entity, err = is.CreateOrFetchEntity(alias)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}
if len(entity.Aliases) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases))
}
if entity.Aliases[0].Name != alias.Name {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}
}

View File

@@ -666,12 +666,29 @@ func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clo
return nil, fmt.Errorf("missing mount accessor")
}
txn := i.db.Txn(false)
return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
}
func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
if txn == nil {
return nil, fmt.Errorf("nil txn")
}
if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}
if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}
tableName := entityAliasesTable
if groupAlias {
tableName = groupAliasesTable
}
txn := i.db.Txn(false)
aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
if err != nil {
return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err)

View File

@@ -436,22 +436,15 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
var err error
// Check if an entity already exists for the given alias
entity, err = c.identityStore.entityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false)
// Fetch the entity for the alias, or create an entity if one
// doesn't exist.
entity, err = c.identityStore.CreateOrFetchEntity(auth.Alias)
if err != nil {
return nil, nil, err
}
// If not, create one.
if entity == nil {
c.logger.Debug("core: creating a new entity", "alias", auth.Alias)
entity, err = c.identityStore.CreateEntity(auth.Alias)
if err != nil {
return nil, nil, err
}
if entity == nil {
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}
auth.EntityID = entity.ID