Porting identity store (#3419)

* porting identity to OSS

* changes that glue things together

* add testing bits

* wrapped entity id

* fix mount error

* some more changes to core

* fix storagepacker tests

* fix some more tests

* fix mount tests

* fix http mount tests

* audit changes for identity

* remove upgrade structs on the oss side

* added go-memdb to vendor
Authored by Vishal Nayak on 2017-10-11 10:21:20 -07:00; committed by GitHub.
Parent: a3bd4530b6, commit: 6b9ce0c8c5
54 changed files with 10162 additions and 60 deletions

View File

@@ -73,6 +73,9 @@ bootstrap:
proto:
protoc -I helper/forwarding -I vault -I ../../.. vault/*.proto --go_out=plugins=grpc:vault
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
protoc -I helper/storagepacker helper/storagepacker/types.proto --go_out=plugins=grpc:helper/storagepacker
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
fmtcheck:
@sh -c "'$(CURDIR)/scripts/gofmtcheck.sh'"
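The sed pass in the proto target above rewrites the identifiers emitted by protoc-gen-go into Go-style initialisms. A before/after illustration (a sketch; the actual getters appear in the generated files below):

// Emitted by protoc-gen-go:
//     func (m *Alias) GetEntityId() string
// After the sed rewrite:
//     func (m *Alias) GetEntityID() string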

View File

@@ -123,6 +123,7 @@ func (f *AuditFormatter) FormatRequest(
DisplayName: auth.DisplayName,
Policies: auth.Policies,
Metadata: auth.Metadata,
EntityID: auth.EntityID,
RemainingUses: req.ClientTokenRemainingUses,
},
@@ -315,6 +316,7 @@ func (f *AuditFormatter) FormatResponse(
Policies: auth.Policies,
Metadata: auth.Metadata,
RemainingUses: req.ClientTokenRemainingUses,
EntityID: auth.EntityID,
},
Request: AuditRequest{
@@ -397,6 +399,7 @@ type AuditAuth struct {
Metadata map[string]string `json:"metadata"`
NumUses int `json:"num_uses,omitempty"`
RemainingUses int `json:"remaining_uses,omitempty"`
EntityID string `json:"entity_id"`
}
type AuditSecret struct {

View File

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
)
@@ -50,7 +51,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) {
errors.New("this is an error"),
"",
"",
fmt.Sprintf(`<json:object name="auth"><json:string name="accessor">bar</json:string><json:string name="client_token">%s</json:string><json:string name="display_name">testtoken</json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:object name="headers"><json:array name="foo"><json:string>bar</json:string></json:array></json:object><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
fmt.Sprintf(`<json:object name="auth"><json:string name="accessor">bar</json:string><json:string name="client_token">%s</json:string><json:string name="display_name">testtoken</json:string><json:string name="entity_id"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:object name="headers"><json:array name="foo"><json:string>bar</json:string></json:array></json:object><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
fooSalted),
},
"auth, request with prefix": {
@@ -71,7 +72,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) {
errors.New("this is an error"),
"",
"@cee: ",
fmt.Sprintf(`<json:object name="auth"><json:string name="accessor">bar</json:string><json:string name="client_token">%s</json:string><json:string name="display_name">testtoken</json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:object name="headers"><json:array name="foo"><json:string>bar</json:string></json:array></json:object><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
fmt.Sprintf(`<json:object name="auth"><json:string name="accessor">bar</json:string><json:string name="client_token">%s</json:string><json:string name="display_name">testtoken</json:string><json:string name="entity_id"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:object name="headers"><json:array name="foo"><json:string>bar</json:string></json:array></json:object><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
fooSalted),
},
}

View File

@@ -0,0 +1,64 @@
package identity
import (
"fmt"
"github.com/gogo/protobuf/proto"
)
func (g *Group) Clone() (*Group, error) {
if g == nil {
return nil, fmt.Errorf("nil group")
}
marshaledGroup, err := proto.Marshal(g)
if err != nil {
return nil, fmt.Errorf("failed to marshal group: %v", err)
}
var clonedGroup Group
err = proto.Unmarshal(marshaledGroup, &clonedGroup)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal group: %v", err)
}
return &clonedGroup, nil
}
func (e *Entity) Clone() (*Entity, error) {
if e == nil {
return nil, fmt.Errorf("nil entity")
}
marshaledEntity, err := proto.Marshal(e)
if err != nil {
return nil, fmt.Errorf("failed to marshal entity: %v", err)
}
var clonedEntity Entity
err = proto.Unmarshal(marshaledEntity, &clonedEntity)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal entity: %v", err)
}
return &clonedEntity, nil
}
func (p *Alias) Clone() (*Alias, error) {
if p == nil {
return nil, fmt.Errorf("nil alias")
}
marshaledAlias, err := proto.Marshal(p)
if err != nil {
return nil, fmt.Errorf("failed to marshal alias: %v", err)
}
var clonedAlias Alias
err = proto.Unmarshal(marshaledAlias, &clonedAlias)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal alias: %v", err)
}
return &clonedAlias, nil
}
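
All three Clone methods above share one pattern: a proto marshal/unmarshal round trip yields a deep copy. A minimal usage sketch, with illustrative values:

package main

import (
	"fmt"

	"github.com/hashicorp/vault/helper/identity"
)

func main() {
	entity := &identity.Entity{
		ID:       "entity-1",
		Policies: []string{"default"},
	}
	clone, err := entity.Clone()
	if err != nil {
		panic(err)
	}
	// The clone shares no memory with the original; mutating it leaves
	// the original intact.
	clone.Policies = append(clone.Policies, "admin")
	fmt.Println(entity.Policies) // [default]
}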

helper/identity/types.pb.go (new file, 411 lines)
View File

@@ -0,0 +1,411 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: types.proto
/*
Package identity is a generated protocol buffer package.
It is generated from these files:
types.proto
It has these top-level messages:
Group
Entity
Alias
*/
package identity
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Group represents an identity group.
type Group struct {
// ID is the unique identifier for this group
ID string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"`
// Name is the unique name for this group
Name string `protobuf:"bytes,2,opt,name=name" json:"name,omitempty"`
// Policies are the vault policies to be granted to members of this group
Policies []string `protobuf:"bytes,3,rep,name=policies" json:"policies,omitempty"`
// ParentGroupIDs are the identifiers of the groups of which this group
// is a member. These serve as references to the parent groups in the
// hierarchy.
ParentGroupIDs []string `protobuf:"bytes,4,rep,name=parent_group_ids,json=parentGroupIds" json:"parent_group_ids,omitempty"`
// MemberEntityIDs are the identifiers of entities which are members of this
// group
MemberEntityIDs []string `protobuf:"bytes,5,rep,name=member_entity_ids,json=memberEntityIDs" json:"member_entity_ids,omitempty"`
// Metadata represents the custom data tied to this group
Metadata map[string]string `protobuf:"bytes,6,rep,name=metadata" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
// CreationTime is the time at which this group was created
CreationTime *google_protobuf.Timestamp `protobuf:"bytes,7,opt,name=creation_time,json=creationTime" json:"creation_time,omitempty"`
// LastUpdateTime is the time at which this group was last modified
LastUpdateTime *google_protobuf.Timestamp `protobuf:"bytes,8,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"`
// ModifyIndex tracks the number of updates to the group. It is useful
// for detecting updates to the group.
ModifyIndex uint64 `protobuf:"varint,9,opt,name=modify_index,json=modifyIndex" json:"modify_index,omitempty"`
// BucketKeyHash is the MD5 hash of the storage bucket key into which this
// group is stored in the underlying storage. This is useful to find all
// the groups belonging to a particular bucket during invalidation of the
// storage key.
BucketKeyHash string `protobuf:"bytes,10,opt,name=bucket_key_hash,json=bucketKeyHash" json:"bucket_key_hash,omitempty"`
}
func (m *Group) Reset() { *m = Group{} }
func (m *Group) String() string { return proto.CompactTextString(m) }
func (*Group) ProtoMessage() {}
func (*Group) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Group) GetID() string {
if m != nil {
return m.ID
}
return ""
}
func (m *Group) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *Group) GetPolicies() []string {
if m != nil {
return m.Policies
}
return nil
}
func (m *Group) GetParentGroupIDs() []string {
if m != nil {
return m.ParentGroupIDs
}
return nil
}
func (m *Group) GetMemberEntityIDs() []string {
if m != nil {
return m.MemberEntityIDs
}
return nil
}
func (m *Group) GetMetadata() map[string]string {
if m != nil {
return m.Metadata
}
return nil
}
func (m *Group) GetCreationTime() *google_protobuf.Timestamp {
if m != nil {
return m.CreationTime
}
return nil
}
func (m *Group) GetLastUpdateTime() *google_protobuf.Timestamp {
if m != nil {
return m.LastUpdateTime
}
return nil
}
func (m *Group) GetModifyIndex() uint64 {
if m != nil {
return m.ModifyIndex
}
return 0
}
func (m *Group) GetBucketKeyHash() string {
if m != nil {
return m.BucketKeyHash
}
return ""
}
// Entity represents an entity that gets persisted and indexed.
// Entity is fundamentally composed of zero or many aliases.
type Entity struct {
// Aliases are the identities that this entity is made of. This can be
// empty as well, to favor creating the entity first and then
// incrementally adding aliases.
Aliases []*Alias `protobuf:"bytes,1,rep,name=aliases" json:"aliases,omitempty"`
// ID is the unique identifier of the entity, which will always be a
// UUID. This should never be allowed to be updated.
ID string `protobuf:"bytes,2,opt,name=id" json:"id,omitempty"`
// Name is a unique identifier of the entity which is intended to be
// human-friendly. The default name might not be human-friendly since it
// gets suffixed by a UUID, but it can optionally be updated, unlike the ID
// field.
Name string `protobuf:"bytes,3,opt,name=name" json:"name,omitempty"`
// Metadata represents the explicit metadata which is set by the
// clients. This is useful for tying together any information pertaining
// to the aliases. This is a non-unique field of the entity, meaning
// multiple entities can have the same metadata set. Entities will be
// indexed based on this explicit metadata. This enables virtual
// groupings of entities based on their metadata.
Metadata map[string]string `protobuf:"bytes,4,rep,name=metadata" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
// CreationTime is the time at which this entity is first created.
CreationTime *google_protobuf.Timestamp `protobuf:"bytes,5,opt,name=creation_time,json=creationTime" json:"creation_time,omitempty"`
// LastUpdateTime is the most recent time at which the properties of
// this entity were modified. This is helpful in filtering out entities
// based on their age and taking action on them, if desired.
LastUpdateTime *google_protobuf.Timestamp `protobuf:"bytes,6,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"`
// MergedEntityIDs are the entities which got merged into this one.
// Entities will be indexed based on all the entities that got merged
// into it. This helps apply the actions on this entity to the tokens
// that are tied to the merged entities. Merged entities will be deleted
// entirely, and this is the only trackable trail of their earlier
// presence.
MergedEntityIDs []string `protobuf:"bytes,7,rep,name=merged_entity_ids,json=mergedEntityIDs" json:"merged_entity_ids,omitempty"`
// Policies the entity is entitled to
Policies []string `protobuf:"bytes,8,rep,name=policies" json:"policies,omitempty"`
// BucketKeyHash is the MD5 hash of the storage bucket key into which this
// entity is stored in the underlying storage. This is useful to find all
// the entities belonging to a particular bucket during invalidation of the
// storage key.
BucketKeyHash string `protobuf:"bytes,9,opt,name=bucket_key_hash,json=bucketKeyHash" json:"bucket_key_hash,omitempty"`
}
func (m *Entity) Reset() { *m = Entity{} }
func (m *Entity) String() string { return proto.CompactTextString(m) }
func (*Entity) ProtoMessage() {}
func (*Entity) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *Entity) GetAliases() []*Alias {
if m != nil {
return m.Aliases
}
return nil
}
func (m *Entity) GetID() string {
if m != nil {
return m.ID
}
return ""
}
func (m *Entity) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *Entity) GetMetadata() map[string]string {
if m != nil {
return m.Metadata
}
return nil
}
func (m *Entity) GetCreationTime() *google_protobuf.Timestamp {
if m != nil {
return m.CreationTime
}
return nil
}
func (m *Entity) GetLastUpdateTime() *google_protobuf.Timestamp {
if m != nil {
return m.LastUpdateTime
}
return nil
}
func (m *Entity) GetMergedEntityIDs() []string {
if m != nil {
return m.MergedEntityIDs
}
return nil
}
func (m *Entity) GetPolicies() []string {
if m != nil {
return m.Policies
}
return nil
}
func (m *Entity) GetBucketKeyHash() string {
if m != nil {
return m.BucketKeyHash
}
return ""
}
// Alias represents the alias that gets stored inside of the entity
// object in storage, and is also what gets kept in an in-memory index of
// alias objects.
type Alias struct {
// ID is the unique identifier that represents this alias
ID string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"`
// EntityID is the entity identifier to which this alias belongs
EntityID string `protobuf:"bytes,2,opt,name=entity_id,json=entityId" json:"entity_id,omitempty"`
// MountType is the backend mount's type to which this alias belongs.
// This enables categorically querying aliases of specific backend types.
MountType string `protobuf:"bytes,3,opt,name=mount_type,json=mountType" json:"mount_type,omitempty"`
// MountAccessor is the backend mount's accessor to which this alias
// belongs.
MountAccessor string `protobuf:"bytes,4,opt,name=mount_accessor,json=mountAccessor" json:"mount_accessor,omitempty"`
// MountPath is the backend mount's path to which the MountAccessor
// belongs. This field is not used for any operational purposes and is
// only returned as a nicety when the alias is read.
MountPath string `protobuf:"bytes,5,opt,name=mount_path,json=mountPath" json:"mount_path,omitempty"`
// Metadata is the explicit metadata that clients set against an entity
// which enables virtual grouping of aliases. Aliases will be indexed
// against their metadata.
Metadata map[string]string `protobuf:"bytes,6,rep,name=metadata" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
// Name is the identifier of this alias in its authentication source.
// This does not uniquely identify an alias in Vault. This, in
// conjunction with MountAccessor, forms the set of factors that
// represent an alias in a unique way. Aliases will be indexed based on
// this combined uniqueness factor.
Name string `protobuf:"bytes,7,opt,name=name" json:"name,omitempty"`
// CreationTime is the time at which this alias was first created
CreationTime *google_protobuf.Timestamp `protobuf:"bytes,8,opt,name=creation_time,json=creationTime" json:"creation_time,omitempty"`
// LastUpdateTime is the most recent time at which the properties of
// this alias were modified. This is helpful in filtering out aliases
// based on their age and taking action on them, if desired.
LastUpdateTime *google_protobuf.Timestamp `protobuf:"bytes,9,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"`
// MergedFromEntityIDs is the FIFO history of merging activity, by the
// entity IDs from which this alias was transferred over to the entity
// to which it currently belongs.
MergedFromEntityIDs []string `protobuf:"bytes,10,rep,name=merged_from_entity_ids,json=mergedFromEntityIDs" json:"merged_from_entity_ids,omitempty"`
}
func (m *Alias) Reset() { *m = Alias{} }
func (m *Alias) String() string { return proto.CompactTextString(m) }
func (*Alias) ProtoMessage() {}
func (*Alias) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *Alias) GetID() string {
if m != nil {
return m.ID
}
return ""
}
func (m *Alias) GetEntityID() string {
if m != nil {
return m.EntityID
}
return ""
}
func (m *Alias) GetMountType() string {
if m != nil {
return m.MountType
}
return ""
}
func (m *Alias) GetMountAccessor() string {
if m != nil {
return m.MountAccessor
}
return ""
}
func (m *Alias) GetMountPath() string {
if m != nil {
return m.MountPath
}
return ""
}
func (m *Alias) GetMetadata() map[string]string {
if m != nil {
return m.Metadata
}
return nil
}
func (m *Alias) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *Alias) GetCreationTime() *google_protobuf.Timestamp {
if m != nil {
return m.CreationTime
}
return nil
}
func (m *Alias) GetLastUpdateTime() *google_protobuf.Timestamp {
if m != nil {
return m.LastUpdateTime
}
return nil
}
func (m *Alias) GetMergedFromEntityIDs() []string {
if m != nil {
return m.MergedFromEntityIDs
}
return nil
}
func init() {
proto.RegisterType((*Group)(nil), "identity.Group")
proto.RegisterType((*Entity)(nil), "identity.Entity")
proto.RegisterType((*Alias)(nil), "identity.Alias")
}
func init() { proto.RegisterFile("types.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 570 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x94, 0xcd, 0x6e, 0xd3, 0x40,
0x10, 0xc7, 0xe5, 0x38, 0x1f, 0xf6, 0xa4, 0x4d, 0xcb, 0x82, 0x90, 0x15, 0x54, 0x08, 0x95, 0x40,
0x86, 0x83, 0x2b, 0xb5, 0x17, 0x28, 0x07, 0x54, 0x89, 0x02, 0x15, 0x42, 0x42, 0x56, 0x39, 0x5b,
0x9b, 0x78, 0x9a, 0xac, 0x1a, 0x7b, 0x2d, 0xef, 0x1a, 0xe1, 0x27, 0xe4, 0x39, 0x38, 0xf1, 0x1a,
0xc8, 0xb3, 0x76, 0x62, 0x08, 0x5f, 0x15, 0xb9, 0xd9, 0xff, 0x99, 0x1d, 0xcf, 0xce, 0xff, 0x37,
0x86, 0xa1, 0x2e, 0x33, 0x54, 0x41, 0x96, 0x4b, 0x2d, 0x99, 0x23, 0x62, 0x4c, 0xb5, 0xd0, 0xe5,
0xf8, 0xc1, 0x5c, 0xca, 0xf9, 0x12, 0x8f, 0x48, 0x9f, 0x16, 0x57, 0x47, 0x5a, 0x24, 0xa8, 0x34,
0x4f, 0x32, 0x93, 0x7a, 0xf8, 0xcd, 0x86, 0xde, 0x9b, 0x5c, 0x16, 0x19, 0x1b, 0x41, 0x47, 0xc4,
0x9e, 0x35, 0xb1, 0x7c, 0x37, 0xec, 0x88, 0x98, 0x31, 0xe8, 0xa6, 0x3c, 0x41, 0xaf, 0x43, 0x0a,
0x3d, 0xb3, 0x31, 0x38, 0x99, 0x5c, 0x8a, 0x99, 0x40, 0xe5, 0xd9, 0x13, 0xdb, 0x77, 0xc3, 0xd5,
0x3b, 0xf3, 0x61, 0x3f, 0xe3, 0x39, 0xa6, 0x3a, 0x9a, 0x57, 0xf5, 0x22, 0x11, 0x2b, 0xaf, 0x4b,
0x39, 0x23, 0xa3, 0xd3, 0x67, 0x2e, 0x62, 0xc5, 0x9e, 0xc2, 0xad, 0x04, 0x93, 0x29, 0xe6, 0x91,
0xe9, 0x92, 0x52, 0x7b, 0x94, 0xba, 0x67, 0x02, 0xe7, 0xa4, 0x57, 0xb9, 0xcf, 0xc1, 0x49, 0x50,
0xf3, 0x98, 0x6b, 0xee, 0xf5, 0x27, 0xb6, 0x3f, 0x3c, 0x3e, 0x08, 0x9a, 0xdb, 0x05, 0x54, 0x31,
0x78, 0x5f, 0xc7, 0xcf, 0x53, 0x9d, 0x97, 0xe1, 0x2a, 0x9d, 0xbd, 0x84, 0xdd, 0x59, 0x8e, 0x5c,
0x0b, 0x99, 0x46, 0xd5, 0xb5, 0xbd, 0xc1, 0xc4, 0xf2, 0x87, 0xc7, 0xe3, 0xc0, 0xcc, 0x24, 0x68,
0x66, 0x12, 0x5c, 0x36, 0x33, 0x09, 0x77, 0x9a, 0x03, 0x95, 0xc4, 0x5e, 0xc1, 0xfe, 0x92, 0x2b,
0x1d, 0x15, 0x59, 0xcc, 0x35, 0x9a, 0x1a, 0xce, 0x5f, 0x6b, 0x8c, 0xaa, 0x33, 0x1f, 0xe9, 0x08,
0x55, 0x79, 0x08, 0x3b, 0x89, 0x8c, 0xc5, 0x55, 0x19, 0x89, 0x34, 0xc6, 0xcf, 0x9e, 0x3b, 0xb1,
0xfc, 0x6e, 0x38, 0x34, 0xda, 0x45, 0x25, 0xb1, 0xc7, 0xb0, 0x37, 0x2d, 0x66, 0xd7, 0xa8, 0xa3,
0x6b, 0x2c, 0xa3, 0x05, 0x57, 0x0b, 0x0f, 0x68, 0xea, 0xbb, 0x46, 0x7e, 0x87, 0xe5, 0x5b, 0xae,
0x16, 0xe3, 0x17, 0xb0, 0xfb, 0xc3, 0x65, 0xd9, 0x3e, 0xd8, 0xd7, 0x58, 0xd6, 0xa6, 0x55, 0x8f,
0xec, 0x0e, 0xf4, 0x3e, 0xf1, 0x65, 0xd1, 0xd8, 0x66, 0x5e, 0x4e, 0x3b, 0xcf, 0xac, 0xc3, 0x2f,
0x36, 0xf4, 0xcd, 0x5c, 0xd9, 0x13, 0x18, 0xf0, 0xa5, 0xe0, 0x0a, 0x95, 0x67, 0xd1, 0x4c, 0xf7,
0xd6, 0x33, 0x3d, 0xab, 0x02, 0x61, 0x13, 0xaf, 0xa9, 0xe8, 0x6c, 0x50, 0x61, 0xb7, 0xa8, 0x38,
0x6d, 0x79, 0xd4, 0xa5, 0x7a, 0xf7, 0xd7, 0xf5, 0xcc, 0x27, 0xff, 0xdd, 0xa4, 0xde, 0x16, 0x4c,
0xea, 0xdf, 0xd8, 0x24, 0x42, 0x32, 0x9f, 0x63, 0xdc, 0x46, 0x72, 0xd0, 0x20, 0x59, 0x05, 0xd6,
0x48, 0xb6, 0x97, 0xc0, 0xf9, 0x69, 0x09, 0x7e, 0xe1, 0xa4, 0xbb, 0x75, 0x27, 0xbf, 0xda, 0xd0,
0x23, 0x9b, 0x36, 0x76, 0xf6, 0x1e, 0xb8, 0xab, 0xfe, 0xeb, 0x73, 0x0e, 0xd6, 0x8d, 0xb3, 0x03,
0x80, 0x44, 0x16, 0xa9, 0x8e, 0xaa, 0x5f, 0x45, 0x6d, 0xa0, 0x4b, 0xca, 0x65, 0x99, 0x21, 0x7b,
0x04, 0x23, 0x13, 0xe6, 0xb3, 0x19, 0x2a, 0x25, 0x73, 0xaf, 0x6b, 0x3a, 0x27, 0xf5, 0xac, 0x16,
0xd7, 0x55, 0x32, 0xae, 0x17, 0xe4, 0x56, 0x53, 0xe5, 0x03, 0xd7, 0x8b, 0x3f, 0xef, 0x2b, 0x35,
0xfd, 0x5b, 0x14, 0x1a, 0xb4, 0x06, 0x2d, 0xb4, 0x36, 0xf0, 0x70, 0xb6, 0x80, 0x87, 0x7b, 0x63,
0x3c, 0x4e, 0xe0, 0x6e, 0x8d, 0xc7, 0x55, 0x2e, 0x93, 0x36, 0x23, 0x40, 0x00, 0xdc, 0x36, 0xd1,
0xd7, 0xb9, 0x4c, 0x56, 0x9c, 0xfc, 0x97, 0xc7, 0xd3, 0x3e, 0x75, 0x75, 0xf2, 0x3d, 0x00, 0x00,
0xff, 0xff, 0x17, 0x1c, 0xfc, 0x89, 0xd8, 0x05, 0x00, 0x00,
}

helper/identity/types.proto (new file, 151 lines)
View File

@@ -0,0 +1,151 @@
syntax = "proto3";
package identity;
import "google/protobuf/timestamp.proto";
// Group represents an identity group.
message Group {
// ID is the unique identifier for this group
string id = 1;
// Name is the unique name for this group
string name = 2;
// Policies are the vault policies to be granted to members of this group
repeated string policies = 3;
// ParentGroupIDs are the identifiers of the groups of which this group
// is a member. These serve as references to the parent groups in the
// hierarchy.
repeated string parent_group_ids = 4;
// MemberEntityIDs are the identifiers of entities which are members of this
// group
repeated string member_entity_ids = 5;
// Metadata represents the custom data tied to this group
map<string, string> metadata = 6;
// CreationTime is the time at which this group was created
google.protobuf.Timestamp creation_time = 7;
// LastUpdateTime is the time at which this group was last modified
google.protobuf.Timestamp last_update_time = 8;
// ModifyIndex tracks the number of updates to the group. It is useful
// for detecting updates to the group.
uint64 modify_index = 9;
// BucketKeyHash is the MD5 hash of the storage bucket key into which this
// group is stored in the underlying storage. This is useful to find all
// the groups belonging to a particular bucket during invalidation of the
// storage key.
string bucket_key_hash = 10;
}
// Entity represents an entity that gets persisted and indexed.
// Entity is fundamentally composed of zero or many aliases.
message Entity {
// Aliases are the identities that this entity is made of. This can be
// empty as well, to favor creating the entity first and then
// incrementally adding aliases.
repeated Alias aliases = 1;
// ID is the unique identifier of the entity, which will always be a
// UUID. This should never be allowed to be updated.
string id = 2;
// Name is a unique identifier of the entity which is intended to be
// human-friendly. The default name might not be human-friendly since it
// gets suffixed by a UUID, but it can optionally be updated, unlike the ID
// field.
string name = 3;
// Metadata represents the explicit metadata which is set by the
// clients. This is useful for tying together any information pertaining
// to the aliases. This is a non-unique field of the entity, meaning
// multiple entities can have the same metadata set. Entities will be
// indexed based on this explicit metadata. This enables virtual
// groupings of entities based on their metadata.
map<string, string> metadata = 4;
// CreationTime is the time at which this entity is first created.
google.protobuf.Timestamp creation_time = 5;
// LastUpdateTime is the most recent time at which the properties of
// this entity were modified. This is helpful in filtering out entities
// based on their age and taking action on them, if desired.
google.protobuf.Timestamp last_update_time = 6;
// MergedEntityIDs are the entities which got merged into this one.
// Entities will be indexed based on all the entities that got merged
// into it. This helps apply the actions on this entity to the tokens
// that are tied to the merged entities. Merged entities will be deleted
// entirely, and this is the only trackable trail of their earlier
// presence.
repeated string merged_entity_ids = 7;
// Policies the entity is entitled to
repeated string policies = 8;
// BucketKeyHash is the MD5 hash of the storage bucket key into which this
// entity is stored in the underlying storage. This is useful to find all
// the entities belonging to a particular bucket during invalidation of the
// storage key.
string bucket_key_hash = 9;
// **Enterprise only**
// MFASecrets holds the MFA secrets indexed by the identifier of the MFA
// method configuration.
//map<string, mfa.Secret> mfa_secrets = 10;
}
// Alias represents the alias that gets stored inside of the entity
// object in storage, and is also what gets kept in an in-memory index of
// alias objects.
message Alias {
// ID is the unique identifier that represents this alias
string id = 1;
// EntityID is the entity identifier to which this alias belongs
string entity_id = 2;
// MountType is the backend mount's type to which this alias belongs.
// This enables categorically querying aliases of specific backend types.
string mount_type = 3;
// MountAccessor is the backend mount's accessor to which this alias
// belongs.
string mount_accessor = 4;
// MountPath is the backend mount's path to which the MountAccessor
// belongs. This field is not used for any operational purposes and is
// only returned as a nicety when the alias is read.
string mount_path = 5;
// Metadata is the explicit metadata that clients set against an entity
// which enables virtual grouping of aliases. Aliases will be indexed
// against their metadata.
map<string, string> metadata = 6;
// Name is the identifier of this alias in its authentication source.
// This does not uniquely identify an alias in Vault. This, in
// conjunction with MountAccessor, forms the set of factors that
// represent an alias in a unique way. Aliases will be indexed based on
// this combined uniqueness factor.
string name = 7;
// CreationTime is the time at which this alias was first created
google.protobuf.Timestamp creation_time = 8;
// LastUpdateTime is the most recent time at which the properties of
// this alias were modified. This is helpful in filtering out aliases
// based on their age and taking action on them, if desired.
google.protobuf.Timestamp last_update_time = 9;
// MergedFromEntityIDs is the FIFO history of merging activity, by the
// entity IDs from which this alias was transferred over to the entity
// to which it currently belongs.
repeated string merged_from_entity_ids = 10;
}
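
Both messages stamp times as google.protobuf.Timestamp rather than time.Time. A small sketch of converting between the two on the Go side, using the well-known-types helpers (illustrative; the tests below use ptypes.TimestampNow for CreationTime/LastUpdateTime):

package main

import (
	"fmt"
	"time"

	"github.com/golang/protobuf/ptypes"
)

func main() {
	// Proto timestamp for the current time.
	ts := ptypes.TimestampNow()

	// Convert back to time.Time, and forward from time.Time.
	t, err := ptypes.Timestamp(ts)
	if err != nil {
		panic(err)
	}
	ts2, err := ptypes.TimestampProto(time.Now())
	if err != nil {
		panic(err)
	}
	fmt.Println(t, ts2)
}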

View File

@@ -0,0 +1,351 @@
package storagepacker
import (
"crypto/md5"
"encoding/hex"
"fmt"
"strconv"
"strings"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/compressutil"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
)
const (
bucketCount = 256
StoragePackerBucketsPrefix = "packer/buckets/"
)
// StoragePacker packs objects into a fixed number of buckets by hashing
// their IDs and indexing the result. Currently this supports only 256
// bucket entries and hence relies on the first byte of the hash value
// for indexing. Items that get inserted into the packer should implement
// the StorageBucketItem interface.
type StoragePacker struct {
view logical.Storage
logger log.Logger
storageLocks []*locksutil.LockEntry
viewPrefix string
}
// BucketPath returns the storage entry key for a given bucket key
func (s *StoragePacker) BucketPath(bucketKey string) string {
return s.viewPrefix + bucketKey
}
// BucketKeyHashByItemID returns the MD5 hash of the bucket storage key
// in which the item will be stored. The choice of MD5 is only for hash
// performance reasons, since its value is not used for any
// security-sensitive operation.
func (s *StoragePacker) BucketKeyHashByItemID(itemID string) string {
return s.BucketKeyHashByKey(s.BucketPath(s.BucketKey(itemID)))
}
// BucketKeyHashByKey returns the MD5 hash of the bucket storage key
func (s *StoragePacker) BucketKeyHashByKey(bucketKey string) string {
hf := md5.New()
hf.Write([]byte(bucketKey))
return hex.EncodeToString(hf.Sum(nil))
}
// View returns the storage view configured to be used by the packer
func (s *StoragePacker) View() logical.Storage {
return s.view
}
// GetBucket returns the bucket for a given key
func (s *StoragePacker) GetBucket(key string) (*Bucket, error) {
if key == "" {
return nil, fmt.Errorf("missing bucket key")
}
lock := locksutil.LockForKey(s.storageLocks, key)
lock.RLock()
defer lock.RUnlock()
// Read from the underlying view
storageEntry, err := s.view.Get(key)
if err != nil {
return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err)
}
if storageEntry == nil {
return nil, nil
}
uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value)
if err != nil {
return nil, errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err)
}
if notCompressed {
uncompressedData = storageEntry.Value
}
var bucket Bucket
err = proto.Unmarshal(uncompressedData, &bucket)
if err != nil {
return nil, errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err)
}
return &bucket, nil
}
// upsert either inserts a new item into the bucket or updates an existing one
// if an item with a matching ID is already present.
func (s *Bucket) upsert(item *Item) error {
if s == nil {
return fmt.Errorf("nil storage bucket")
}
if item == nil {
return fmt.Errorf("nil item")
}
if item.ID == "" {
return fmt.Errorf("missing item ID")
}
// Look for an item with a matching ID and don't modify the collection
// while iterating
foundIdx := -1
for itemIdx, bucketItems := range s.Items {
if bucketItems.ID == item.ID {
foundIdx = itemIdx
break
}
}
// If there is no match, append the item, otherwise update it
if foundIdx == -1 {
s.Items = append(s.Items, item)
} else {
s.Items[foundIdx] = item
}
return nil
}
// BucketIndex returns the bucket key index for a given storage key
func (s *StoragePacker) BucketIndex(key string) uint8 {
hf := md5.New()
hf.Write([]byte(key))
return uint8(hf.Sum(nil)[0])
}
// BucketKey returns the bucket key for a given item ID
func (s *StoragePacker) BucketKey(itemID string) string {
return strconv.Itoa(int(s.BucketIndex(itemID)))
}
// DeleteItem removes the storage entry which the given item ID refers
// to from its corresponding bucket.
func (s *StoragePacker) DeleteItem(itemID string) error {
if itemID == "" {
return fmt.Errorf("empty item ID")
}
// Get the bucket key
bucketKey := s.BucketKey(itemID)
// Prepend the view prefix
bucketPath := s.BucketPath(bucketKey)
// Read from underlying view
storageEntry, err := s.view.Get(bucketPath)
if err != nil {
return errwrap.Wrapf("failed to read packed storage value: {{err}}", err)
}
if storageEntry == nil {
return nil
}
uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value)
if err != nil {
return errwrap.Wrapf("failed to decompress packed storage value: {{err}}", err)
}
if notCompressed {
uncompressedData = storageEntry.Value
}
var bucket Bucket
err = proto.Unmarshal(uncompressedData, &bucket)
if err != nil {
return errwrap.Wrapf("failed decoding packed storage entry: {{err}}", err)
}
// Look for a matching storage entry
foundIdx := -1
for itemIdx, item := range bucket.Items {
if item.ID == itemID {
foundIdx = itemIdx
break
}
}
// If there is a match, remove it from the collection and persist the
// resulting collection
if foundIdx != -1 {
bucket.Items = append(bucket.Items[:foundIdx], bucket.Items[foundIdx+1:]...)
// Persist bucket entry only if there is an update
err = s.PutBucket(&bucket)
if err != nil {
return err
}
}
return nil
}
// PutBucket stores a packed bucket entry
func (s *StoragePacker) PutBucket(bucket *Bucket) error {
if bucket == nil {
return fmt.Errorf("nil bucket entry")
}
if bucket.Key == "" {
return fmt.Errorf("missing key")
}
if !strings.HasPrefix(bucket.Key, s.viewPrefix) {
return fmt.Errorf("incorrect prefix; bucket entry key should have %q prefix", s.viewPrefix)
}
marshaledBucket, err := proto.Marshal(bucket)
if err != nil {
return errwrap.Wrapf("failed to marshal bucket: {{err}}", err)
}
compressedBucket, err := compressutil.Compress(marshaledBucket, &compressutil.CompressionConfig{
Type: compressutil.CompressionTypeSnappy,
})
if err != nil {
return errwrap.Wrapf("failed to compress packed bucket: {{err}}", err)
}
// Store the compressed value
err = s.view.Put(&logical.StorageEntry{
Key: bucket.Key,
Value: compressedBucket,
})
if err != nil {
return errwrap.Wrapf("failed to persist packed storage entry: {{err}}", err)
}
return nil
}
// GetItem fetches the storage entry for a given item ID from its
// corresponding bucket.
func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
if itemID == "" {
return nil, fmt.Errorf("empty item ID")
}
bucketKey := s.BucketKey(itemID)
bucketPath := s.BucketPath(bucketKey)
// Fetch the bucket entry
bucket, err := s.GetBucket(bucketPath)
if err != nil {
return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err)
}
// GetBucket returns a nil bucket when the underlying entry does not
// exist; guard against dereferencing it below.
if bucket == nil {
return nil, nil
}
// Look for a matching storage entry in the bucket items
for _, item := range bucket.Items {
if item.ID == itemID {
return item, nil
}
}
return nil, nil
}
// PutItem stores a storage entry in its corresponding bucket
func (s *StoragePacker) PutItem(item *Item) error {
if item == nil {
return fmt.Errorf("nil item")
}
if item.ID == "" {
return fmt.Errorf("missing ID in item")
}
var err error
bucketKey := s.BucketKey(item.ID)
bucketPath := s.BucketPath(bucketKey)
bucket := &Bucket{
Key: bucketPath,
}
// In this case, the storage entry is persisted regardless of whether
// the storageEntry read below is nil. Hence, directly acquire the write
// lock, even to read the entry.
lock := locksutil.LockForKey(s.storageLocks, bucketPath)
lock.Lock()
defer lock.Unlock()
// Check if there is an existing bucket for a given key
storageEntry, err := s.view.Get(bucketPath)
if err != nil {
return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err)
}
if storageEntry == nil {
// If the bucket entry does not exist, this will be the only item in
// the bucket that is going to be persisted.
bucket.Items = []*Item{
item,
}
} else {
uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value)
if err != nil {
return errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err)
}
if notCompressed {
uncompressedData = storageEntry.Value
}
err = proto.Unmarshal(uncompressedData, bucket)
if err != nil {
return errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err)
}
err = bucket.upsert(item)
if err != nil {
return errwrap.Wrapf("failed to update entry in packed storage entry: {{err}}", err)
}
}
// Persist the result
return s.PutBucket(bucket)
}
// NewStoragePacker creates a new storage packer for a given view
func NewStoragePacker(view logical.Storage, logger log.Logger, viewPrefix string) (*StoragePacker, error) {
if view == nil {
return nil, fmt.Errorf("nil view")
}
if viewPrefix == "" {
viewPrefix = StoragePackerBucketsPrefix
}
if !strings.HasSuffix(viewPrefix, "/") {
viewPrefix = viewPrefix + "/"
}
// Create a new packer object for the given view
packer := &StoragePacker{
view: view,
viewPrefix: viewPrefix,
logger: logger,
storageLocks: locksutil.CreateLocks(),
}
return packer, nil
}
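
As the StoragePacker doc comment notes, bucketing relies on the first byte of the item ID's MD5 digest. A standalone sketch of that mapping (illustrative, not part of the commit):

package main

import (
	"crypto/md5"
	"fmt"
	"strconv"
)

func main() {
	itemID := "item1" // hypothetical item ID
	sum := md5.Sum([]byte(itemID))
	// The first digest byte selects one of the 256 buckets.
	bucketKey := strconv.Itoa(int(sum[0]))
	fmt.Println("bucket key:  ", bucketKey)
	// "packer/buckets/" mirrors StoragePackerBucketsPrefix above.
	fmt.Println("storage path:", "packer/buckets/"+bucketKey)
}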

View File

@@ -0,0 +1,172 @@
package storagepacker
import (
"reflect"
"testing"
"github.com/golang/protobuf/ptypes"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
)
func BenchmarkStoragePacker(b *testing.B) {
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New("storagepackertest"), "")
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
itemID, err := uuid.GenerateUUID()
if err != nil {
b.Fatal(err)
}
item := &Item{
ID: itemID,
}
err = storagePacker.PutItem(item)
if err != nil {
b.Fatal(err)
}
fetchedItem, err := storagePacker.GetItem(itemID)
if err != nil {
b.Fatal(err)
}
if fetchedItem == nil {
b.Fatalf("failed to read stored item with ID: %q, iteration: %d", item.ID, i)
}
if fetchedItem.ID != item.ID {
b.Fatalf("bad: item ID; expected: %q\n actual: %q", item.ID, fetchedItem.ID)
}
err = storagePacker.DeleteItem(item.ID)
if err != nil {
b.Fatal(err)
}
fetchedItem, err = storagePacker.GetItem(item.ID)
if err != nil {
b.Fatal(err)
}
if fetchedItem != nil {
b.Fatalf("failed to delete item")
}
}
}
func TestStoragePacker(t *testing.T) {
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New("storagepackertest"), "")
if err != nil {
t.Fatal(err)
}
// Persist a storage entry
item1 := &Item{
ID: "item1",
}
err = storagePacker.PutItem(item1)
if err != nil {
t.Fatal(err)
}
// Verify that it can be read
fetchedItem, err := storagePacker.GetItem(item1.ID)
if err != nil {
t.Fatal(err)
}
if fetchedItem == nil {
t.Fatalf("failed to read the stored item")
}
if item1.ID != fetchedItem.ID {
t.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item1.ID, fetchedItem.ID)
}
// Delete item1
err = storagePacker.DeleteItem(item1.ID)
if err != nil {
t.Fatal(err)
}
// Check that the deletion was successful
fetchedItem, err = storagePacker.GetItem(item1.ID)
if err != nil {
t.Fatal(err)
}
if fetchedItem != nil {
t.Fatalf("failed to delete item")
}
}
func TestStoragePacker_SerializeDeserializeComplexItem(t *testing.T) {
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New("storagepackertest"), "")
if err != nil {
t.Fatal(err)
}
timeNow := ptypes.TimestampNow()
alias1 := &identity.Alias{
ID: "alias_id",
EntityID: "entity_id",
MountType: "mount_type",
MountAccessor: "mount_accessor",
Metadata: map[string]string{
"aliasmkey": "aliasmvalue",
},
Name: "alias_name",
CreationTime: timeNow,
LastUpdateTime: timeNow,
MergedFromEntityIDs: []string{"merged_from_entity_id"},
}
entity := &identity.Entity{
Aliases: []*identity.Alias{alias1},
ID: "entity_id",
Name: "entity_name",
Metadata: map[string]string{
"testkey1": "testvalue1",
"testkey2": "testvalue2",
},
CreationTime: timeNow,
LastUpdateTime: timeNow,
BucketKeyHash: "entity_hash",
MergedEntityIDs: []string{"merged_entity_id1", "merged_entity_id2"},
Policies: []string{"policy1", "policy2"},
}
marshaledEntity, err := ptypes.MarshalAny(entity)
if err != nil {
t.Fatal(err)
}
err = storagePacker.PutItem(&Item{
ID: entity.ID,
Message: marshaledEntity,
})
if err != nil {
t.Fatal(err)
}
itemFetched, err := storagePacker.GetItem(entity.ID)
if err != nil {
t.Fatal(err)
}
var itemDecoded identity.Entity
err = ptypes.UnmarshalAny(itemFetched.Message, &itemDecoded)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(&itemDecoded, entity) {
t.Fatalf("bad: expected: %#v\nactual: %#v\n", entity, itemDecoded)
}
}

View File

@@ -0,0 +1,101 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: types.proto
/*
Package storagepacker is a generated protocol buffer package.
It is generated from these files:
types.proto
It has these top-level messages:
Item
Bucket
*/
package storagepacker
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/any"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Item struct {
ID string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"`
Message *google_protobuf.Any `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"`
}
func (m *Item) Reset() { *m = Item{} }
func (m *Item) String() string { return proto.CompactTextString(m) }
func (*Item) ProtoMessage() {}
func (*Item) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Item) GetID() string {
if m != nil {
return m.ID
}
return ""
}
func (m *Item) GetMessage() *google_protobuf.Any {
if m != nil {
return m.Message
}
return nil
}
type Bucket struct {
Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"`
Items []*Item `protobuf:"bytes,2,rep,name=items" json:"items,omitempty"`
}
func (m *Bucket) Reset() { *m = Bucket{} }
func (m *Bucket) String() string { return proto.CompactTextString(m) }
func (*Bucket) ProtoMessage() {}
func (*Bucket) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *Bucket) GetKey() string {
if m != nil {
return m.Key
}
return ""
}
func (m *Bucket) GetItems() []*Item {
if m != nil {
return m.Items
}
return nil
}
func init() {
proto.RegisterType((*Item)(nil), "storagepacker.Item")
proto.RegisterType((*Bucket)(nil), "storagepacker.Bucket")
}
func init() { proto.RegisterFile("types.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 181 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2e, 0xa9, 0x2c, 0x48,
0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x2d, 0x2e, 0xc9, 0x2f, 0x4a, 0x4c, 0x4f,
0x2d, 0x48, 0x4c, 0xce, 0x4e, 0x2d, 0x92, 0x92, 0x4c, 0xcf, 0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07,
0x4b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x54, 0x2a, 0xb9, 0x71, 0xb1, 0x78, 0x96,
0xa4, 0xe6, 0x0a, 0xf1, 0x71, 0x31, 0x65, 0xa6, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x70, 0x06, 0x31,
0x65, 0xa6, 0x08, 0xe9, 0x71, 0xb1, 0xe7, 0xa6, 0x16, 0x17, 0x27, 0xa6, 0xa7, 0x4a, 0x30, 0x29,
0x30, 0x6a, 0x70, 0x1b, 0x89, 0xe8, 0x41, 0x0c, 0xd1, 0x83, 0x19, 0xa2, 0xe7, 0x98, 0x57, 0x19,
0x04, 0x53, 0xa4, 0xe4, 0xca, 0xc5, 0xe6, 0x54, 0x9a, 0x9c, 0x9d, 0x5a, 0x22, 0x24, 0xc0, 0xc5,
0x9c, 0x9d, 0x5a, 0x09, 0x35, 0x0a, 0xc4, 0x14, 0xd2, 0xe4, 0x62, 0xcd, 0x2c, 0x49, 0xcd, 0x2d,
0x96, 0x60, 0x52, 0x60, 0xd6, 0xe0, 0x36, 0x12, 0xd6, 0x43, 0x71, 0x9d, 0x1e, 0xc8, 0xfe, 0x20,
0x88, 0x8a, 0x24, 0x36, 0xb0, 0xe9, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x30, 0x77,
0x9a, 0xce, 0x00, 0x00, 0x00,
}

View File

@@ -0,0 +1,15 @@
syntax = "proto3";
package storagepacker;
import "google/protobuf/any.proto";
message Item {
string id = 1;
google.protobuf.Any message = 2;
}
message Bucket {
string key = 1;
repeated Item items = 2;
}

View File

@@ -18,6 +18,10 @@ type ResponseWrapInfo struct {
// created token's accessor will be accessible here
WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"`
// WrappedEntityID is the entity identifier of the caller who initiated the
// wrapping request
WrappedEntityID string `json:"wrapped_entity_id" structs:"wrapped_entity_id" mapstructure:"wrapped_entity_id"`
// The format to use. This doesn't get returned; it's only internal.
Format string `json:"format" structs:"format" mapstructure:"format"`

View File

@@ -186,6 +186,16 @@ func TestSysMounts_headerAuth(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
@@ -217,6 +227,16 @@ func TestSysMounts_headerAuth(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)

View File

@@ -157,6 +157,7 @@ func TestLogical_StandbyRedirect(t *testing.T) {
"creation_ttl": json.Number("0"),
"explicit_max_ttl": json.Number("0"),
"expire_time": nil,
"entity_id": "",
},
"warnings": nilWarnings,
"wrap_info": nil,
@@ -206,6 +207,7 @@ func TestLogical_CreateToken(t *testing.T) {
"metadata": nil,
"lease_duration": json.Number("0"),
"renewable": false,
"entity_id": "",
},
"warnings": nilWarnings,
}

View File

@@ -313,6 +313,7 @@ func TestSysGenerateRoot_Update_OTP(t *testing.T) {
"path": "auth/token/root",
"explicit_max_ttl": json.Number("0"),
"expire_time": nil,
"entity_id": "",
}
resp = testHttpGet(t, newRootToken, addr+"/v1/auth/token/lookup-self")
@@ -403,6 +404,7 @@ func TestSysGenerateRoot_Update_PGP(t *testing.T) {
"path": "auth/token/root",
"explicit_max_ttl": json.Number("0"),
"expire_time": nil,
"entity_id": "",
}
resp = testHttpGet(t, newRootToken, addr+"/v1/auth/token/lookup-self")

View File

@@ -56,6 +56,16 @@ func TestSysMounts(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
@@ -87,6 +97,16 @@ func TestSysMounts(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@@ -167,6 +187,16 @@ func TestSysMount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"foo/": map[string]interface{}{
"description": "foo",
@@ -208,6 +238,16 @@ func TestSysMount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@@ -310,6 +350,16 @@ func TestSysRemount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"bar/": map[string]interface{}{
"description": "foo",
@@ -351,6 +401,16 @@ func TestSysRemount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@@ -424,6 +484,16 @@ func TestSysUnmount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
@@ -455,6 +525,16 @@ func TestSysUnmount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@@ -535,6 +615,16 @@ func TestSysTuneMount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"foo/": map[string]interface{}{
"description": "foo",
@@ -576,6 +666,16 @@ func TestSysTuneMount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@@ -677,6 +777,16 @@ func TestSysTuneMount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
},
"foo/": map[string]interface{}{
"description": "foo",
@@ -718,6 +828,16 @@ func TestSysTuneMount(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
},
"local": false,
},
}
testResponseStatus(t, resp, 200)

View File

@@ -52,6 +52,10 @@ type Auth struct {
// Number of allowed uses of the issued token
NumUses int `json:"num_uses" mapstructure:"num_uses" structs:"num_uses"`
// EntityID is the identifier of the entity in the identity store to
// which the identity of the authenticating client belongs.
EntityID string `json:"entity_id" mapstructure:"entity_id" structs:"entity_id"`
// Alias is the information about the authenticated client returned by
// the auth backend
Alias *Alias `json:"alias" structs:"alias" mapstructure:"alias"`

View File

@@ -3,24 +3,14 @@ package logical
// Alias represents the information used by core to create an implicit entity.
// Implicit entities get created when a client authenticates successfully from
// any of the authentication backends (except token backend).
//
// This is applicable to enterprise binaries only. Alias should be set in the
// Auth response returned by the credential backends. This structure is placed
// in the open source repository only to enable custom authentication
// plugins to be used along with the enterprise binary. The custom auth
// plugins should make use of this and fill out the Alias information in
// the authentication response.
type Alias struct {
// MountType is the backend mount's type to which this identity
// belongs.
MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"`
// MountAccessor is the identifier of the mount entry to which
// this identity
// belongs to.
// MountAccessor is the identifier of the mount entry to which this
// identity belongs
MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"`
// Name is the identifier of this identity in its
// authentication source.
// Name is the identifier of this identity in its authentication source
Name string `json:"name" structs:"name" mapstructure:"name"`
}
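
As the comment above describes, a custom auth plugin fills out Alias in the Auth response it returns from its login path. A hedged sketch of what that could look like (the values are hypothetical, and core is expected to tie the alias to its mount via the request's MountAccessor):

package main

import (
	"fmt"

	"github.com/hashicorp/vault/logical"
)

func main() {
	resp := &logical.Response{
		Auth: &logical.Auth{
			Policies:    []string{"default"},
			DisplayName: "alice",
			Alias: &logical.Alias{
				// Name: the client's identifier in the
				// external authentication source.
				Name: "alice",
			},
		},
	}
	fmt.Println(resp.Auth.Alias.Name)
}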

View File

@@ -87,6 +87,10 @@ type Request struct {
// aliases, generating different defaults depending on the alias)
MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"`
// MountAccessor is provided so that identities returned by the
// authentication backends can be tied to the mount they belong to.
MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"`
// WrapInfo contains requested response wrapping parameters
WrapInfo *RequestWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"`
@@ -94,6 +98,10 @@ type Request struct {
// token supplied
ClientTokenRemainingUses int `json:"client_token_remaining_uses" structs:"client_token_remaining_uses" mapstructure:"client_token_remaining_uses"`
// EntityID is the identity of the caller extracted out of the token used
// to make this request
EntityID string `json:"entity_id" structs:"entity_id" mapstructure:"entity_id"`
// For replication, contains the last WAL on the remote side after handling
// the request, used for best-effort avoidance of stale read-after-write
lastRemoteWAL uint64

View File

@@ -33,6 +33,7 @@ func LogicalResponseToHTTPResponse(input *Response) *HTTPResponse {
Metadata: input.Auth.Metadata,
LeaseDuration: int(input.Auth.TTL.Seconds()),
Renewable: input.Auth.Renewable,
EntityID: input.Auth.EntityID,
}
}
@@ -59,6 +60,7 @@ func HTTPResponseToLogicalResponse(input *HTTPResponse) *Response {
Accessor: input.Auth.Accessor,
Policies: input.Auth.Policies,
Metadata: input.Auth.Metadata,
EntityID: input.Auth.EntityID,
}
logicalResp.Auth.Renewable = input.Auth.Renewable
logicalResp.Auth.TTL = time.Second * time.Duration(input.Auth.LeaseDuration)
@@ -85,6 +87,7 @@ type HTTPAuth struct {
Metadata map[string]string `json:"metadata"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`
EntityID string `json:"entity_id"`
}
type HTTPWrapInfo struct {

View File

@@ -26,6 +26,7 @@ import (
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/helper/mlock"
@@ -240,6 +241,9 @@ type Core struct {
// can be output in the audit logs
auditedHeaders *AuditedHeadersConfig
// systemBackend is the backend which is used to manage internal operations
systemBackend *SystemBackend
// systemBarrierView is the barrier view for the system backend
systemBarrierView *BarrierView
@@ -256,6 +260,9 @@ type Core struct {
// token store is used to manage authentication tokens
tokenStore *TokenStore
// identityStore is used to manage client entities
identityStore *IdentityStore
// metricsCh is used to stop the metrics streaming
metricsCh chan struct{}
@@ -551,6 +558,11 @@ func NewCore(conf *CoreConfig) (*Core, error) {
}
return b, nil
}
logicalBackends["identity"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return NewIdentityStore(c, config)
}
c.logicalBackends = logicalBackends
credentialBackends := make(map[string]logical.Factory)
@@ -634,45 +646,88 @@ func (c *Core) LookupToken(token string) (*TokenEntry, error) {
return c.tokenStore.Lookup(token)
}
func (c *Core) fetchACLandTokenEntry(req *logical.Request) (*ACL, *TokenEntry, error) {
func (c *Core) fetchACLTokenEntryAndEntity(clientToken string) (*ACL, *TokenEntry, *identity.Entity, error) {
defer metrics.MeasureSince([]string{"core", "fetch_acl_and_token"}, time.Now())
// Ensure there is a client token
if req.ClientToken == "" {
return nil, nil, fmt.Errorf("missing client token")
if clientToken == "" {
return nil, nil, nil, fmt.Errorf("missing client token")
}
if c.tokenStore == nil {
c.logger.Error("core: token store is unavailable")
return nil, nil, ErrInternalError
return nil, nil, nil, ErrInternalError
}
// Resolve the token policy
te, err := c.tokenStore.Lookup(req.ClientToken)
te, err := c.tokenStore.Lookup(clientToken)
if err != nil {
c.logger.Error("core: failed to lookup token", "error", err)
return nil, nil, ErrInternalError
return nil, nil, nil, ErrInternalError
}
// Ensure the token is valid
if te == nil {
return nil, nil, logical.ErrPermissionDenied
return nil, nil, nil, logical.ErrPermissionDenied
}
tokenPolicies := te.Policies
var entity *identity.Entity
// Append the policies of the entity to those on the tokens and create ACL
// off of the combined list.
if te.EntityID != "" {
//c.logger.Debug("core: entity set on the token", "entity_id", te.EntityID)
// Fetch entity for the entity ID in the token entry
entity, err = c.identityStore.memDBEntityByID(te.EntityID, false)
if err != nil {
c.logger.Error("core: failed to lookup entity using its ID", "error", err)
return nil, nil, nil, ErrInternalError
}
if entity == nil {
// If there was no corresponding entity object found, it is
// possible that the entity got merged into another entity. Try
// finding the entity based on the merged entity ID index.
entity, err = c.identityStore.memDBEntityByMergedEntityID(te.EntityID, false)
if err != nil {
c.logger.Error("core: failed to lookup entity in merged entity ID index", "error", err)
return nil, nil, nil, ErrInternalError
}
}
if entity != nil {
//c.logger.Debug("core: entity successfully fetched; adding entity policies to token's policies to create ACL")
// Attach the policies on the entity to the policies tied to the token
tokenPolicies = append(tokenPolicies, entity.Policies...)
groupPolicies, err := c.identityStore.groupPoliciesByEntityID(entity.ID)
if err != nil {
c.logger.Error("core: failed to fetch group policies", "error", err)
return nil, nil, nil, ErrInternalError
}
// Attach the policies from all the groups to which this entity
// belongs
tokenPolicies = append(tokenPolicies, groupPolicies...)
}
}
// Construct the corresponding ACL object
acl, err := c.policyStore.ACL(te.Policies...)
acl, err := c.policyStore.ACL(tokenPolicies...)
if err != nil {
c.logger.Error("core: failed to construct ACL", "error", err)
return nil, nil, ErrInternalError
return nil, nil, nil, ErrInternalError
}
return acl, te, nil
return acl, te, entity, nil
}
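The net effect of the rename and the extra return value is that the ACL is now built from the union of three policy sources instead of the token's policies alone. A simplified, hypothetical helper capturing the ordering used above:

// combinePolicies mirrors the append order in fetchACLTokenEntryAndEntity:
// token policies first, then entity policies, then group-derived policies.
// Deduplication is left to the policy store, as in the real code.
func combinePolicies(token, entity, group []string) []string {
	out := make([]string, 0, len(token)+len(entity)+len(group))
	out = append(out, token...)
	out = append(out, entity...)
	out = append(out, group...)
	return out
}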
func (c *Core) checkToken(req *logical.Request) (*logical.Auth, *TokenEntry, error) {
defer metrics.MeasureSince([]string{"core", "check_token"}, time.Now())
acl, te, err := c.fetchACLandTokenEntry(req)
acl, te, _, err := c.fetchACLTokenEntryAndEntity(req.ClientToken)
if err != nil {
return nil, te, err
}
@@ -720,9 +775,15 @@ func (c *Core) checkToken(req *logical.Request) (*logical.Auth, *TokenEntry, err
auth := &logical.Auth{
ClientToken: req.ClientToken,
Accessor: req.ClientTokenAccessor,
Policies: te.Policies,
Metadata: te.Meta,
DisplayName: te.DisplayName,
}
if te != nil {
auth.Policies = te.Policies
auth.Metadata = te.Meta
auth.DisplayName = te.DisplayName
auth.EntityID = te.EntityID
// Store the entity ID in the request object
req.EntityID = te.EntityID
}
// Check the standard non-root ACLs. Return the token entry if it's not
@@ -1086,7 +1147,7 @@ func (c *Core) sealInitCommon(req *logical.Request) (retErr error) {
}
// Validate the token is a root token
acl, te, err := c.fetchACLandTokenEntry(req)
acl, te, _, err := c.fetchACLTokenEntryAndEntity(req.ClientToken)
if err != nil {
// Since there is no token store in standby nodes, sealing cannot
// be done. Ideally, the request has to be forwarded to leader node
@@ -1204,7 +1265,7 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) {
return nil
}
acl, te, err := c.fetchACLandTokenEntry(req)
acl, te, _, err := c.fetchACLTokenEntryAndEntity(req.ClientToken)
if err != nil {
retErr = multierror.Append(retErr, err)
return retErr

View File

@@ -109,7 +109,7 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim
resp.WrapInfo.Format = "jwt"
}
_, err := d.core.wrapInCubbyhole(req, resp)
_, err := d.core.wrapInCubbyhole(req, resp, nil)
if err != nil {
return nil, err
}

78
vault/identity_lookup.go Normal file
View File

@@ -0,0 +1,78 @@
package vault
import (
"fmt"
"strings"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func lookupPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
{
Pattern: "lookup/group$",
Fields: map[string]*framework.FieldSchema{
"type": {
Type: framework.TypeString,
Description: "Type of lookup. Current supported values are 'by_id' and 'by_name'",
},
"group_name": {
Type: framework.TypeString,
Description: "Name of the group.",
},
"group_id": {
Type: framework.TypeString,
Description: "ID of the group.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathLookupGroupUpdate),
},
HelpSynopsis: strings.TrimSpace(lookupHelp["lookup-group"][0]),
HelpDescription: strings.TrimSpace(lookupHelp["lookup-group"][1]),
},
}
}
func (i *IdentityStore) pathLookupGroupUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
lookupType := d.Get("type").(string)
if lookupType == "" {
return logical.ErrorResponse("empty type"), nil
}
switch lookupType {
case "by_id":
groupID := d.Get("group_id").(string)
if groupID == "" {
return logical.ErrorResponse("empty group_id"), nil
}
group, err := i.memDBGroupByID(groupID, false)
if err != nil {
return nil, err
}
return i.handleGroupReadCommon(group)
case "by_name":
groupName := d.Get("group_name").(string)
if groupName == "" {
return logical.ErrorResponse("empty group_name"), nil
}
group, err := i.memDBGroupByName(groupName, false)
if err != nil {
return nil, err
}
return i.handleGroupReadCommon(group)
default:
return logical.ErrorResponse(fmt.Sprintf("unrecognized type %q", lookupType)), nil
}
return nil, nil
}
var lookupHelp = map[string][2]string{
"lookup-group": {
"Query groups based on factors.",
"Currently this supports querying groups by its name or ID.",
},
}
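A usage sketch for the lookup endpoint defined above, mirroring the request style used in the tests elsewhere in this change; the group name is illustrative and iStore is assumed to be an *IdentityStore:

req := &logical.Request{
	Operation: logical.UpdateOperation,
	Path:      "lookup/group",
	Data: map[string]interface{}{
		"type":       "by_name",
		"group_name": "engineers",
	},
}
resp, err := iStore.HandleRequest(req)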

334
vault/identity_store.go Normal file
View File

@@ -0,0 +1,334 @@
package vault
import (
"fmt"
"strings"
"github.com/golang/protobuf/ptypes"
memdb "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/version"
)
const (
groupBucketsPrefix = "packer/group/buckets/"
)
// NewIdentityStore creates a new identity store
func NewIdentityStore(core *Core, config *logical.BackendConfig) (*IdentityStore, error) {
var err error
// Create a new in-memory database for the identity store
db, err := memdb.NewMemDB(identityStoreSchema())
if err != nil {
return nil, fmt.Errorf("failed to create memdb for identity store: %v", err)
}
iStore := &IdentityStore{
view: config.StorageView,
db: db,
entityLocks: locksutil.CreateLocks(),
logger: core.logger,
validateMountAccessorFunc: core.router.validateMountByAccessor,
}
iStore.entityPacker, err = storagepacker.NewStoragePacker(iStore.view, iStore.logger, "")
if err != nil {
return nil, fmt.Errorf("failed to create entity packer: %v", err)
}
iStore.groupPacker, err = storagepacker.NewStoragePacker(iStore.view, iStore.logger, groupBucketsPrefix)
if err != nil {
return nil, fmt.Errorf("failed to create group packer: %v", err)
}
iStore.Backend = &framework.Backend{
BackendType: logical.TypeLogical,
Paths: framework.PathAppend(
entityPaths(iStore),
aliasPaths(iStore),
groupPaths(iStore),
lookupPaths(iStore),
upgradePaths(iStore),
),
Invalidate: iStore.Invalidate,
}
err = iStore.Setup(config)
if err != nil {
return nil, err
}
return iStore, nil
}
func (i *IdentityStore) checkPremiumVersion(f framework.OperationFunc) framework.OperationFunc {
ver := version.GetVersion()
if ver.VersionMetadata == "pro" {
return func(*logical.Request, *framework.FieldData) (*logical.Response, error) {
return logical.ErrorResponse("identity features not available in the pro version"), logical.ErrInvalidRequest
}
}
return f
}
// Invalidate is a callback wherein the backend is informed that the value at
// the given key is updated. In the identity store's case, it would be the
// entity storage entries that get updated. The value needs to be read and
// MemDB needs to be updated accordingly.
func (i *IdentityStore) Invalidate(key string) {
i.logger.Debug("identity: invalidate notification received", "key", key)
switch {
// Check if the key is a storage entry key for an entity bucket
case strings.HasPrefix(key, storagepacker.StoragePackerBucketsPrefix):
// Get the hash value of the storage bucket entry key
bucketKeyHash := i.entityPacker.BucketKeyHashByKey(key)
if len(bucketKeyHash) == 0 {
i.logger.Error("failed to get the bucket entry key hash")
return
}
// Create a MemDB transaction
txn := i.db.Txn(true)
defer txn.Abort()
// Each entity object in MemDB holds the MD5 hash of the storage
// entry key of the entity bucket. Fetch all the entities that
// belong to this bucket using the hash value. Remove these entities
// from MemDB along with all the aliases of each entity.
entitiesFetched, err := i.memDBEntitiesByBucketEntryKeyHashInTxn(txn, string(bucketKeyHash))
if err != nil {
i.logger.Error("failed to fetch entities using the bucket entry key hash", "bucket_entry_key_hash", bucketKeyHash)
return
}
for _, entity := range entitiesFetched {
// Delete all the aliases in the entity. This function will also remove
// the corresponding alias indexes too.
err = i.deleteAliasesInEntityInTxn(txn, entity, entity.Aliases)
if err != nil {
i.logger.Error("failed to delete aliases in entity", "entity_id", entity.ID, "error", err)
return
}
// Delete the entity using the same transaction
err = i.memDBDeleteEntityByIDInTxn(txn, entity.ID)
if err != nil {
i.logger.Error("failed to delete entity from MemDB", "entity_id", entity.ID, "error", err)
return
}
}
// Get the storage bucket entry
bucket, err := i.entityPacker.GetBucket(key)
if err != nil {
i.logger.Error("failed to refresh entities", "key", key, "error", err)
return
}
// If the underlying entry is nil, it means that this invalidation
// notification is for the deletion of the underlying storage entry. At
// this point, since all the entities belonging to this bucket are
// already removed, there is nothing else to be done. But, if the
// storage entry is non-nil, it's an indication of an update. In this
// case, entities in the updated bucket need to be reinserted into
// MemDB.
if bucket != nil {
for _, item := range bucket.Items {
entity, err := i.parseEntityFromBucketItem(item)
if err != nil {
i.logger.Error("failed to parse entity from bucket entry item", "error", err)
return
}
// Only update MemDB and don't touch the storage
err = i.upsertEntityInTxn(txn, entity, nil, false, false)
if err != nil {
i.logger.Error("failed to update entity in MemDB", "error", err)
return
}
}
}
txn.Commit()
return
// Check if the key is a storage entry key for a group bucket
case strings.HasPrefix(key, groupBucketsPrefix):
// Get the hash value of the storage bucket entry key
bucketKeyHash := i.groupPacker.BucketKeyHashByKey(key)
if len(bucketKeyHash) == 0 {
i.logger.Error("failed to get the bucket entry key hash")
return
}
// Create a MemDB transaction
txn := i.db.Txn(true)
defer txn.Abort()
groupsFetched, err := i.memDBGroupsByBucketEntryKeyHashInTxn(txn, string(bucketKeyHash))
if err != nil {
i.logger.Error("failed to fetch groups using the bucket entry key hash", "bucket_entry_key_hash", bucketKeyHash)
return
}
for _, group := range groupsFetched {
// Delete the group using the same transaction
err = i.memDBDeleteGroupByIDInTxn(txn, group.ID)
if err != nil {
i.logger.Error("failed to delete group from MemDB", "group_id", group.ID, "error", err)
return
}
}
// Get the storage bucket entry
bucket, err := i.groupPacker.GetBucket(key)
if err != nil {
i.logger.Error("failed to refresh group", "key", key, "error", err)
return
}
if bucket != nil {
for _, item := range bucket.Items {
group, err := i.parseGroupFromBucketItem(item)
if err != nil {
i.logger.Error("failed to parse group from bucket entry item", "error", err)
return
}
// Only update MemDB and don't touch the storage
err = i.upsertGroupInTxn(txn, group, false)
if err != nil {
i.logger.Error("failed to update group in MemDB", "error", err)
return
}
}
}
txn.Commit()
return
}
}
func (i *IdentityStore) parseEntityFromBucketItem(item *storagepacker.Item) (*identity.Entity, error) {
if item == nil {
return nil, fmt.Errorf("nil item")
}
var entity identity.Entity
err := ptypes.UnmarshalAny(item.Message, &entity)
if err != nil {
return nil, fmt.Errorf("failed to decode entity from storage bucket item: %v", err)
}
return &entity, nil
}
func (i *IdentityStore) parseGroupFromBucketItem(item *storagepacker.Item) (*identity.Group, error) {
if item == nil {
return nil, fmt.Errorf("nil item")
}
var group identity.Group
err := ptypes.UnmarshalAny(item.Message, &group)
if err != nil {
return nil, fmt.Errorf("failed to decode group from storage bucket item: %v", err)
}
return &group, nil
}
// EntityByAliasFactors fetches the entity based on factors of alias, i.e. mount
// accessor and the alias name.
func (i *IdentityStore) EntityByAliasFactors(mountAccessor, aliasName string, clone bool) (*identity.Entity, error) {
if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}
if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}
alias, err := i.memDBAliasByFactors(mountAccessor, aliasName, false)
if err != nil {
return nil, err
}
if alias == nil {
return nil, nil
}
return i.memDBEntityByAliasID(alias.ID, clone)
}
// CreateEntity creates a new entity. This is used by core to
// associate each login attempt by an alias to a unified entity in Vault.
func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) {
var entity *identity.Entity
var err error
if alias == nil {
return nil, fmt.Errorf("alias is nil")
}
if alias.Name == "" {
return nil, fmt.Errorf("empty alias name")
}
mountValidationResp := i.validateMountAccessorFunc(alias.MountAccessor)
if mountValidationResp == nil {
return nil, fmt.Errorf("invalid mount accessor %q", alias.MountAccessor)
}
if mountValidationResp.MountType != alias.MountType {
return nil, fmt.Errorf("mount accessor %q is not a mount of type %q", alias.MountAccessor, alias.MountType)
}
// Check if an entity already exists for the given alias
entity, err = i.EntityByAliasFactors(alias.MountAccessor, alias.Name, false)
if err != nil {
return nil, err
}
if entity != nil {
return nil, fmt.Errorf("alias already belongs to a different entity")
}
entity = &identity.Entity{}
err = i.sanitizeEntity(entity)
if err != nil {
return nil, err
}
// Create a new alias
newAlias := &identity.Alias{
EntityID: entity.ID,
Name: alias.Name,
MountAccessor: alias.MountAccessor,
MountPath: mountValidationResp.MountPath,
MountType: mountValidationResp.MountType,
}
err = i.sanitizeAlias(newAlias)
if err != nil {
return nil, err
}
// Append the new alias to the new entity
entity.Aliases = []*identity.Alias{
newAlias,
}
// Update MemDB and persist entity object
err = i.upsertEntity(entity, nil, true)
if err != nil {
return nil, err
}
return entity, nil
}
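CreateEntity is the hook core can call on a successful login when no entity exists yet for the alias. A minimal sketch, with a hypothetical accessor value:

alias := &logical.Alias{
	MountType:     "github",
	MountAccessor: "auth_github_1234",
	Name:          "login-user",
}
entity, err := c.identityStore.CreateEntity(alias)
if err != nil {
	return err
}
// entity.ID can then be recorded on the token entry as te.EntityID.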

View File

@@ -0,0 +1,364 @@
package vault
import (
"fmt"
"strings"
"github.com/golang/protobuf/ptypes"
memdb "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// aliasPaths returns the API endpoints to operate on aliases.
// Following are the paths supported:
// alias - To register/modify an alias
// alias/id - To lookup, delete and list aliases based on ID
func aliasPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
{
Pattern: "alias$",
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the alias",
},
"entity_id": {
Type: framework.TypeString,
Description: "Entity ID to which this alias belongs to",
},
"mount_accessor": {
Type: framework.TypeString,
Description: "Mount accessor to which this alias belongs to",
},
"name": {
Type: framework.TypeString,
Description: "Name of the alias",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the alias. Format should be a list of `key=value` pairs.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasRegister),
},
HelpSynopsis: strings.TrimSpace(aliasHelp["alias"][0]),
HelpDescription: strings.TrimSpace(aliasHelp["alias"][1]),
},
{
Pattern: "alias/id/" + framework.GenericNameRegex("id"),
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the alias",
},
"entity_id": {
Type: framework.TypeString,
Description: "Entity ID to which this alias should be tied to",
},
"mount_accessor": {
Type: framework.TypeString,
Description: "Mount accessor to which this alias belongs to",
},
"name": {
Type: framework.TypeString,
Description: "Name of the alias",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the alias. Format should be a comma separated list of `key=value` pairs.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasIDUpdate),
logical.ReadOperation: i.checkPremiumVersion(i.pathAliasIDRead),
logical.DeleteOperation: i.checkPremiumVersion(i.pathAliasIDDelete),
},
HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id"][0]),
HelpDescription: strings.TrimSpace(aliasHelp["alias-id"][1]),
},
{
Pattern: "alias/id/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: i.checkPremiumVersion(i.pathAliasIDList),
},
HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id-list"][0]),
HelpDescription: strings.TrimSpace(aliasHelp["alias-id-list"][1]),
},
}
}
// pathAliasRegister is used to register a new alias
func (i *IdentityStore) pathAliasRegister(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
_, ok := d.GetOk("id")
if ok {
return i.pathAliasIDUpdate(req, d)
}
return i.handleAliasUpdateCommon(req, d, nil)
}
// pathAliasIDUpdate is used to update an alias based on the given
// alias ID
func (i *IdentityStore) pathAliasIDUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get alias id
aliasID := d.Get("id").(string)
if aliasID == "" {
return logical.ErrorResponse("missing alias ID"), nil
}
alias, err := i.memDBAliasByID(aliasID, true)
if err != nil {
return nil, err
}
if alias == nil {
return logical.ErrorResponse("invalid alias id"), nil
}
return i.handleAliasUpdateCommon(req, d, alias)
}
// handleAliasUpdateCommon is used to update an alias
func (i *IdentityStore) handleAliasUpdateCommon(req *logical.Request, d *framework.FieldData, alias *identity.Alias) (*logical.Response, error) {
var err error
var newAlias bool
var entity *identity.Entity
var previousEntity *identity.Entity
// Alias will be nil when a new alias is being registered; create a
// new struct in that case.
if alias == nil {
alias = &identity.Alias{}
newAlias = true
}
// Get entity id
entityID := d.Get("entity_id").(string)
if entityID != "" {
entity, err = i.memDBEntityByID(entityID, true)
if err != nil {
return nil, err
}
if entity == nil {
return logical.ErrorResponse("invalid entity ID"), nil
}
}
// Get alias name
aliasName := d.Get("name").(string)
if aliasName == "" {
return logical.ErrorResponse("missing alias name"), nil
}
mountAccessor := d.Get("mount_accessor").(string)
if mountAccessor == "" {
return logical.ErrorResponse("missing mount_accessor"), nil
}
mountValidationResp := i.validateMountAccessorFunc(mountAccessor)
if mountValidationResp == nil {
return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil
}
// Get alias metadata
// Accept metadata in the form of map[string]string to be able to index on
// it
var aliasMetadata map[string]string
aliasMetadataRaw, ok := d.GetOk("metadata")
if ok {
aliasMetadata, err = parseMetadata(aliasMetadataRaw.([]string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to parse alias metadata: %v", err)), nil
}
}
aliasByFactors, err := i.memDBAliasByFactors(mountValidationResp.MountAccessor, aliasName, false)
if err != nil {
return nil, err
}
resp := &logical.Response{}
if newAlias {
if aliasByFactors != nil {
return logical.ErrorResponse("combination of mount and alias name is already in use"), nil
}
// If this is an alias being tied to a non-existent entity, create
// a new entity for it.
if entity == nil {
entity = &identity.Entity{
Aliases: []*identity.Alias{
alias,
},
}
} else {
entity.Aliases = append(entity.Aliases, alias)
}
} else {
// Verify that the combination of alias name and mount is not
// already tied to a different alias
if aliasByFactors != nil && aliasByFactors.ID != alias.ID {
return logical.ErrorResponse("combination of mount and alias name is already in use"), nil
}
// Fetch the entity to which the alias is tied
existingEntity, err := i.memDBEntityByAliasID(alias.ID, true)
if err != nil {
return nil, err
}
if existingEntity == nil {
return nil, fmt.Errorf("alias is not associated with an entity")
}
if entity != nil && entity.ID != existingEntity.ID {
// Alias should be transferred from 'existingEntity' to 'entity'
err = i.deleteAliasFromEntity(existingEntity, alias)
if err != nil {
return nil, err
}
previousEntity = existingEntity
entity.Aliases = append(entity.Aliases, alias)
resp.AddWarning(fmt.Sprintf("alias is being transferred from entity %q to %q", existingEntity.ID, entity.ID))
} else {
// Update entity with modified alias
err = i.updateAliasInEntity(existingEntity, alias)
if err != nil {
return nil, err
}
entity = existingEntity
}
}
// ID creation and other validations; this is more useful for new entities
// and may be a no-op for existing entities. The check is placed here to
// keep the flow common for both new and existing entities.
err = i.sanitizeEntity(entity)
if err != nil {
return nil, err
}
// Update the fields
alias.Name = aliasName
alias.Metadata = aliasMetadata
alias.MountType = mountValidationResp.MountType
alias.MountAccessor = mountValidationResp.MountAccessor
alias.MountPath = mountValidationResp.MountPath
// Set the entity ID in the alias index. This should be done after
// sanitizing entity.
alias.EntityID = entity.ID
// ID creation and other validations
err = i.sanitizeAlias(alias)
if err != nil {
return nil, err
}
// Index entity and its aliases in MemDB and persist entity along with
// aliases in storage. If the alias is being transferred over from
// one entity to another, the previous entity needs to be refreshed in
// MemDB and persisted in storage as well.
err = i.upsertEntity(entity, previousEntity, true)
if err != nil {
return nil, err
}
// Return ID of both alias and entity
resp.Data = map[string]interface{}{
"id": alias.ID,
"entity_id": entity.ID,
}
return resp, nil
}
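One consequence of the branch above: updating an alias with a different entity_id transfers it between entities. A hedged request sketch; all IDs and the accessor are placeholders:

transferReq := &logical.Request{
	Operation: logical.UpdateOperation,
	Path:      "alias/id/" + aliasID,
	Data: map[string]interface{}{
		"name":           "testaliasname",
		"mount_accessor": githubAccessor,
		"entity_id":      newEntityID,
	},
}
resp, err := i.HandleRequest(transferReq)
// On success, resp carries a warning noting the transfer between the two entities.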
// pathAliasIDRead returns the properties of an alias for a given
// alias ID
func (i *IdentityStore) pathAliasIDRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
aliasID := d.Get("id").(string)
if aliasID == "" {
return logical.ErrorResponse("missing alias id"), nil
}
alias, err := i.memDBAliasByID(aliasID, false)
if err != nil {
return nil, err
}
if alias == nil {
return nil, nil
}
respData := map[string]interface{}{}
respData["id"] = alias.ID
respData["entity_id"] = alias.EntityID
respData["mount_type"] = alias.MountType
respData["mount_accessor"] = alias.MountAccessor
respData["mount_path"] = alias.MountPath
respData["metadata"] = alias.Metadata
respData["name"] = alias.Name
respData["merged_from_entity_ids"] = alias.MergedFromEntityIDs
// Convert protobuf timestamp into RFC3339 format
respData["creation_time"] = ptypes.TimestampString(alias.CreationTime)
respData["last_update_time"] = ptypes.TimestampString(alias.LastUpdateTime)
return &logical.Response{
Data: respData,
}, nil
}
// pathAliasIDDelete deletes the alias for a given alias ID
func (i *IdentityStore) pathAliasIDDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
aliasID := d.Get("id").(string)
if aliasID == "" {
return logical.ErrorResponse("missing alias ID"), nil
}
return nil, i.deleteAlias(aliasID)
}
// pathAliasIDList lists the IDs of all the valid aliases in the identity
// store
func (i *IdentityStore) pathAliasIDList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
ws := memdb.NewWatchSet()
iter, err := i.memDBAliases(ws)
if err != nil {
return nil, fmt.Errorf("failed to fetch iterator for aliases in memdb: %v", err)
}
var aliasIDs []string
for {
raw := iter.Next()
if raw == nil {
break
}
aliasIDs = append(aliasIDs, raw.(*identity.Alias).ID)
}
return logical.ListResponse(aliasIDs), nil
}
var aliasHelp = map[string][2]string{
"alias": {
"Create a new alias",
"",
},
"alias-id": {
"Update, read or delete an entity using alias ID",
"",
},
"alias-id-list": {
"List all the entity IDs",
"",
},
}
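Both the alias and entity endpoints accept metadata as a list of key=value strings and index it as a map. parseMetadata itself is not part of this hunk; the following is a simplified stand-in only, to illustrate the expected shape:

func parseMetadataSketch(pairs []string) (map[string]string, error) {
	out := make(map[string]string, len(pairs))
	for _, pair := range pairs {
		kv := strings.SplitN(pair, "=", 2)
		if len(kv) != 2 {
			return nil, fmt.Errorf("invalid metadata pair %q", pair)
		}
		out[kv[0]] = kv[1]
	}
	return out, nil
}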

View File

@@ -0,0 +1,531 @@
package vault
import (
"reflect"
"testing"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/logical"
)
func TestIdentityStore_ListAlias(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
entityReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
}
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp == nil {
t.Fatalf("expected a non-nil response")
}
entityID := resp.Data["id"].(string)
// Create an alias
aliasData := map[string]interface{}{
"name": "testaliasname",
"mount_accessor": githubAccessor,
}
aliasReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasData,
}
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
aliasData["name"] = "entityalias"
aliasData["entity_id"] = entityID
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
listReq := &logical.Request{
Operation: logical.ListOperation,
Path: "alias/id",
}
resp, err = is.HandleRequest(listReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
keys := resp.Data["keys"].([]string)
if len(keys) != 2 {
t.Fatalf("bad: lengh of alias IDs listed; expected: 2, actual: %d", len(keys))
}
}
// This test is required because MemDB does not enforce the uniqueness of
// indexes that are marked unique.
func TestIdentityStore_AliasSameAliasNames(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
aliasData := map[string]interface{}{
"name": "testaliasname",
"mount_accessor": githubAccessor,
}
aliasReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasData,
}
// Register an alias
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register another alias with same name
resp, err = is.HandleRequest(aliasReq)
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected an error due to alias name not being unique")
}
}
func TestIdentityStore_MemDBAliasIndexes(t *testing.T) {
var err error
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
if is == nil {
t.Fatal("failed to create test identity store")
}
validateMountResp := is.validateMountAccessorFunc(githubAccessor)
if validateMountResp == nil {
t.Fatal("failed to validate github auth mount")
}
entity := &identity.Entity{
ID: "testentityid",
Name: "testentityname",
}
entity.BucketKeyHash = is.entityPacker.BucketKeyHashByItemID(entity.ID)
err = is.memDBUpsertEntity(entity)
if err != nil {
t.Fatal(err)
}
alias := &identity.Alias{
EntityID: entity.ID,
ID: "testaliasid",
MountAccessor: githubAccessor,
MountType: validateMountResp.MountType,
Name: "testaliasname",
Metadata: map[string]string{
"testkey1": "testmetadatavalue1",
"testkey2": "testmetadatavalue2",
},
}
err = is.memDBUpsertAlias(alias)
if err != nil {
t.Fatal(err)
}
aliasFetched, err := is.memDBAliasByID("testaliasid", false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(alias, aliasFetched) {
t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched)
}
aliasFetched, err = is.memDBAliasByEntityID(entity.ID, false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(alias, aliasFetched) {
t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched)
}
aliasFetched, err = is.memDBAliasByFactors(validateMountResp.MountAccessor, "testaliasname", false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(alias, aliasFetched) {
t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched)
}
aliasesFetched, err := is.memDBAliasesByMetadata(map[string]string{
"testkey1": "testmetadatavalue1",
}, false)
if err != nil {
t.Fatal(err)
}
if len(aliasesFetched) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched))
}
if !reflect.DeepEqual(alias, aliasesFetched[0]) {
t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched)
}
aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{
"testkey2": "testmetadatavalue2",
}, false)
if err != nil {
t.Fatal(err)
}
if len(aliasesFetched) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched))
}
if !reflect.DeepEqual(alias, aliasesFetched[0]) {
t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched)
}
aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{
"testkey1": "testmetadatavalue1",
"testkey2": "testmetadatavalue2",
}, false)
if err != nil {
t.Fatal(err)
}
if len(aliasesFetched) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched))
}
if !reflect.DeepEqual(alias, aliasesFetched[0]) {
t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched)
}
alias2 := &identity.Alias{
EntityID: entity.ID,
ID: "testaliasid2",
MountAccessor: validateMountResp.MountAccessor,
MountType: validateMountResp.MountType,
Name: "testaliasname2",
Metadata: map[string]string{
"testkey1": "testmetadatavalue1",
"testkey3": "testmetadatavalue3",
},
}
err = is.memDBUpsertAlias(alias2)
if err != nil {
t.Fatal(err)
}
aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{
"testkey1": "testmetadatavalue1",
}, false)
if err != nil {
t.Fatal(err)
}
if len(aliasesFetched) != 2 {
t.Fatalf("bad: length of aliases; expected: 2, actual: %d", len(aliasesFetched))
}
aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{
"testkey3": "testmetadatavalue3",
}, false)
if err != nil {
t.Fatal(err)
}
if len(aliasesFetched) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched))
}
err = is.memDBDeleteAliasByID("testaliasid")
if err != nil {
t.Fatal(err)
}
aliasFetched, err = is.memDBAliasByID("testaliasid", false)
if err != nil {
t.Fatal(err)
}
if aliasFetched != nil {
t.Fatalf("expected a nil alias")
}
}
func TestIdentityStore_AliasRegister(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
if is == nil {
t.Fatal("failed to create test alias store")
}
aliasData := map[string]interface{}{
"name": "testaliasname",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
aliasReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasData,
}
// Register the alias
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
idRaw, ok := resp.Data["id"]
if !ok {
t.Fatalf("alias id not present in alias register response")
}
id := idRaw.(string)
if id == "" {
t.Fatalf("invalid alias id in alias register response")
}
entityIDRaw, ok := resp.Data["entity_id"]
if !ok {
t.Fatalf("entity id not present in alias register response")
}
entityID := entityIDRaw.(string)
if entityID == "" {
t.Fatalf("invalid entity id in alias register response")
}
}
func TestIdentityStore_AliasUpdate(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
aliasData := map[string]interface{}{
"name": "testaliasname",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
aliasReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasData,
}
// This will create an alias and a corresponding entity
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
aliasID := resp.Data["id"].(string)
updateData := map[string]interface{}{
"name": "updatedaliasname",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=updatedorganization", "team=updatedteam"},
}
aliasReq.Data = updateData
aliasReq.Path = "alias/id/" + aliasID
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
aliasReq.Operation = logical.ReadOperation
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
aliasMetadata := resp.Data["metadata"].(map[string]string)
updatedOrg := aliasMetadata["organization"]
updatedTeam := aliasMetadata["team"]
if resp.Data["name"] != "updatedaliasname" || updatedOrg != "updatedorganization" || updatedTeam != "updatedteam" {
t.Fatalf("failed to update alias information; \n response data: %#v\n", resp.Data)
}
}
func TestIdentityStore_AliasUpdate_ByID(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
updateData := map[string]interface{}{
"name": "updatedaliasname",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=updatedorganization", "team=updatedteam"},
}
updateReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias/id/invalidaliasid",
Data: updateData,
}
// Try to update a non-existent alias
resp, err = is.HandleRequest(updateReq)
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected an error due to invalid alias id")
}
registerData := map[string]interface{}{
"name": "testaliasname",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: registerData,
}
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
idRaw, ok := resp.Data["id"]
if !ok {
t.Fatalf("alias id not present in response")
}
id := idRaw.(string)
if id == "" {
t.Fatalf("invalid alias id")
}
updateReq.Path = "alias/id/" + id
resp, err = is.HandleRequest(updateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
readReq := &logical.Request{
Operation: logical.ReadOperation,
Path: updateReq.Path,
}
resp, err = is.HandleRequest(readReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
aliasMetadata := resp.Data["metadata"].(map[string]string)
updatedOrg := aliasMetadata["organization"]
updatedTeam := aliasMetadata["team"]
if resp.Data["name"] != "updatedaliasname" || updatedOrg != "updatedorganization" || updatedTeam != "updatedteam" {
t.Fatalf("failed to update alias information; \n response data: %#v\n", resp.Data)
}
delete(registerReq.Data, "name")
resp, err = is.HandleRequest(registerReq)
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected error due to missing alias name")
}
registerReq.Data["name"] = "testaliasname"
delete(registerReq.Data, "mount_accessor")
resp, err = is.HandleRequest(registerReq)
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected error due to missing mount accessor")
}
}
func TestIdentityStore_AliasReadDelete(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
registerData := map[string]interface{}{
"name": "testaliasname",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: registerData,
}
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
idRaw, ok := resp.Data["id"]
if !ok {
t.Fatalf("alias id not present in response")
}
id := idRaw.(string)
if id == "" {
t.Fatalf("invalid alias id")
}
// Read it back using alias id
aliasReq := &logical.Request{
Operation: logical.ReadOperation,
Path: "alias/id/" + id,
}
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["id"].(string) == "" ||
resp.Data["entity_id"].(string) == "" ||
resp.Data["name"].(string) != registerData["name"] ||
resp.Data["mount_type"].(string) != "github" {
t.Fatalf("bad: alias read response; \nexpected: %#v \nactual: %#v\n", registerData, resp.Data)
}
aliasReq.Operation = logical.DeleteOperation
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
aliasReq.Operation = logical.ReadOperation
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp != nil {
t.Fatalf("bad: alias read response; expected: nil, actual: %#v\n", resp)
}
}

View File

@@ -0,0 +1,501 @@
package vault
import (
"fmt"
"strings"
"github.com/golang/protobuf/ptypes"
memdb "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// entityPaths returns the API endpoints supported to operate on entities.
// Following are the paths supported:
// entity - To register a new entity
// entity/id - To lookup, modify, delete and list entities based on ID
// entity/merge - To merge entities based on ID
func entityPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
{
Pattern: "entity$",
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the entity",
},
"name": {
Type: framework.TypeString,
Description: "Name of the entity",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the entity. Format should be a list of `key=value` pairs.",
},
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Policies to be tied to the entity",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathEntityRegister),
},
HelpSynopsis: strings.TrimSpace(entityHelp["entity"][0]),
HelpDescription: strings.TrimSpace(entityHelp["entity"][1]),
},
{
Pattern: "entity/id/" + framework.GenericNameRegex("id"),
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the entity",
},
"name": {
Type: framework.TypeString,
Description: "Name of the entity",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the entity. Format should be a comma separated list of `key=value` pairs.",
},
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Policies to be tied to the entity",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathEntityIDUpdate),
logical.ReadOperation: i.checkPremiumVersion(i.pathEntityIDRead),
logical.DeleteOperation: i.checkPremiumVersion(i.pathEntityIDDelete),
},
HelpSynopsis: strings.TrimSpace(entityHelp["entity-id"][0]),
HelpDescription: strings.TrimSpace(entityHelp["entity-id"][1]),
},
{
Pattern: "entity/id/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: i.checkPremiumVersion(i.pathEntityIDList),
},
HelpSynopsis: strings.TrimSpace(entityHelp["entity-id-list"][0]),
HelpDescription: strings.TrimSpace(entityHelp["entity-id-list"][1]),
},
{
Pattern: "entity/merge/?$",
Fields: map[string]*framework.FieldSchema{
"from_entity_ids": {
Type: framework.TypeCommaStringSlice,
Description: "Entity IDs which needs to get merged",
},
"to_entity_id": {
Type: framework.TypeString,
Description: "Entity ID into which all the other entities need to get merged",
},
"force": {
Type: framework.TypeBool,
Description: "Setting this will follow the 'mine' strategy for merging MFA secrets. If there are secrets of the same type both in entities that are merged from and in entity into which all others are getting merged, secrets in the destination will be unaltered. If not set, this API will throw an error containing all the conflicts.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathEntityMergeID),
},
HelpSynopsis: strings.TrimSpace(entityHelp["entity-merge-id"][0]),
HelpDescription: strings.TrimSpace(entityHelp["entity-merge-id"][1]),
},
}
}
// pathEntityMergeID merges two or more entities into a single entity
func (i *IdentityStore) pathEntityMergeID(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
toEntityID := d.Get("to_entity_id").(string)
if toEntityID == "" {
return logical.ErrorResponse("missing entity id to merge to"), nil
}
fromEntityIDs := d.Get("from_entity_ids").([]string)
if len(fromEntityIDs) == 0 {
return logical.ErrorResponse("missing entity ids to merge from"), nil
}
force := d.Get("force").(bool)
toEntityForLocking, err := i.memDBEntityByID(toEntityID, false)
if err != nil {
return nil, err
}
if toEntityForLocking == nil {
return logical.ErrorResponse("entity id to merge to is invalid"), nil
}
// Acquire the lock to modify the entity storage entry to merge to
toEntityLock := locksutil.LockForKey(i.entityLocks, toEntityForLocking.ID)
toEntityLock.Lock()
defer toEntityLock.Unlock()
// Create a MemDB transaction to merge entities
txn := i.db.Txn(true)
defer txn.Abort()
// Re-read post lock acquisition
toEntity, err := i.memDBEntityByID(toEntityID, true)
if err != nil {
return nil, err
}
if toEntity == nil {
return logical.ErrorResponse("entity id to merge to is invalid"), nil
}
if toEntity.ID != toEntityForLocking.ID {
return logical.ErrorResponse("acquired lock for an undesired entity"), nil
}
var conflictErrors error
for _, fromEntityID := range fromEntityIDs {
if fromEntityID == toEntityID {
return logical.ErrorResponse("to_entity_id should not be present in from_entity_ids"), nil
}
lockFromEntity, err := i.memDBEntityByID(fromEntityID, false)
if err != nil {
return nil, err
}
if lockFromEntity == nil {
return logical.ErrorResponse("entity id to merge from is invalid"), nil
}
// Acquire the lock to modify the entity storage entry to merge from
fromEntityLock := locksutil.LockForKey(i.entityLocks, lockFromEntity.ID)
fromLockHeld := false
// There are only 256 lock buckets and the chance of entity ID collision
// is fairly high. When we are merging entities belonging to the same
// bucket, multiple attempts to acquire the same lock should be avoided.
if fromEntityLock != toEntityLock {
fromEntityLock.Lock()
fromLockHeld = true
}
// Re-read the entities post lock acquisition
fromEntity, err := i.memDBEntityByID(fromEntityID, false)
if err != nil {
if fromLockHeld {
fromEntityLock.Unlock()
}
return nil, err
}
if fromEntity == nil {
if fromLockHeld {
fromEntityLock.Unlock()
}
return logical.ErrorResponse("entity id to merge from is invalid"), nil
}
if fromEntity.ID != lockFromEntity.ID {
if fromLockHeld {
fromEntityLock.Unlock()
}
return logical.ErrorResponse("acquired lock for an undesired entity"), nil
}
for _, alias := range fromEntity.Aliases {
// Set the desired entity id
alias.EntityID = toEntity.ID
// Set the entity id of which this alias is now an alias to
alias.MergedFromEntityIDs = append(alias.MergedFromEntityIDs, fromEntity.ID)
err = i.memDBUpsertAliasInTxn(txn, alias)
if err != nil {
if fromLockHeld {
fromEntityLock.Unlock()
}
return nil, fmt.Errorf("failed to update alias during merge: %v", err)
}
// Add the alias to the desired entity
toEntity.Aliases = append(toEntity.Aliases, alias)
}
// If the entity we are merging from was already a merged entity,
// transfer over its merged entity IDs to the entity we are
// merging into.
toEntity.MergedEntityIDs = append(toEntity.MergedEntityIDs, fromEntity.MergedEntityIDs...)
// Add the entity we are merging from to the list of entities that
// the entity we are merging into is composed of.
toEntity.MergedEntityIDs = append(toEntity.MergedEntityIDs, fromEntity.ID)
// Delete the entity which we are merging from in MemDB using the same transaction
err = i.memDBDeleteEntityByIDInTxn(txn, fromEntity.ID)
if err != nil {
if fromLockHeld {
fromEntityLock.Unlock()
}
return nil, err
}
// Delete the entity which we are merging from in storage
err = i.entityPacker.DeleteItem(fromEntity.ID)
if err != nil {
if fromLockHeld {
fromEntityLock.Unlock()
}
return nil, err
}
if fromLockHeld {
fromEntityLock.Unlock()
}
}
if conflictErrors != nil && !force {
return logical.ErrorResponse(conflictErrors.Error()), nil
}
// Update MemDB with changes to the entity we are merging to
err = i.memDBUpsertEntityInTxn(txn, toEntity)
if err != nil {
return nil, err
}
// Persist the entity which we are merging to
toEntityAsAny, err := ptypes.MarshalAny(toEntity)
if err != nil {
return nil, err
}
item := &storagepacker.Item{
ID: toEntity.ID,
Message: toEntityAsAny,
}
err = i.entityPacker.PutItem(item)
if err != nil {
return nil, err
}
// Committing the transaction *after* successfully performing storage
// persistence
txn.Commit()
return nil, nil
}
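A request sketch for the merge endpoint implemented above; the entity IDs are placeholders, and from_entity_ids may also be given as a comma-separated string since the field is TypeCommaStringSlice:

mergeReq := &logical.Request{
	Operation: logical.UpdateOperation,
	Path:      "entity/merge",
	Data: map[string]interface{}{
		"from_entity_ids": []string{fromEntityID1, fromEntityID2},
		"to_entity_id":    toEntityID,
	},
}
resp, err := i.HandleRequest(mergeReq)
// On success, the source entities are deleted and their aliases now point
// at the destination entity.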
// pathEntityRegister is used to register a new entity
func (i *IdentityStore) pathEntityRegister(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
_, ok := d.GetOk("id")
if ok {
return i.pathEntityIDUpdate(req, d)
}
return i.handleEntityUpdateCommon(req, d, nil)
}
// pathEntityIDUpdate is used to update an entity based on the given entity ID
func (i *IdentityStore) pathEntityIDUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get entity id
entityID := d.Get("id").(string)
if entityID == "" {
return logical.ErrorResponse("missing entity id"), nil
}
entity, err := i.memDBEntityByID(entityID, true)
if err != nil {
return nil, err
}
if entity == nil {
return nil, fmt.Errorf("invalid entity id")
}
return i.handleEntityUpdateCommon(req, d, entity)
}
// handleEntityUpdateCommon is used to update an entity
func (i *IdentityStore) handleEntityUpdateCommon(req *logical.Request, d *framework.FieldData, entity *identity.Entity) (*logical.Response, error) {
var err error
var newEntity bool
// Entity will be nil when a new entity is being registered; create a new
// struct in that case.
if entity == nil {
entity = &identity.Entity{}
newEntity = true
}
// Update the policies if supplied
entityPoliciesRaw, ok := d.GetOk("policies")
if ok {
entity.Policies = entityPoliciesRaw.([]string)
}
// Get the name
entityName := d.Get("name").(string)
if entityName != "" {
entityByName, err := i.memDBEntityByName(entityName, false)
if err != nil {
return nil, err
}
switch {
case (newEntity && entityByName != nil), (entityByName != nil && entity.ID != "" && entityByName.ID != entity.ID):
return logical.ErrorResponse("entity name is already in use"), nil
}
entity.Name = entityName
}
// Get entity metadata
// Accept metadata in the form of map[string]string to be able to index on
// it
entityMetadataRaw, ok := d.GetOk("metadata")
if ok {
entity.Metadata, err = parseMetadata(entityMetadataRaw.([]string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to parse entity metadata: %v", err)), nil
}
}
// ID creation and some validations
err = i.sanitizeEntity(entity)
if err != nil {
return nil, err
}
// Prepare the response
respData := map[string]interface{}{
"id": entity.ID,
}
var aliasIDs []string
for _, alias := range entity.Aliases {
aliasIDs = append(aliasIDs, alias.ID)
}
respData["aliases"] = aliasIDs
// Update MemDB and persist entity object
err = i.upsertEntity(entity, nil, true)
if err != nil {
return nil, err
}
// Return ID of the entity that was either created or updated along with
// its aliases
return &logical.Response{
Data: respData,
}, nil
}
// pathEntityIDRead returns the properties of an entity for a given entity ID
func (i *IdentityStore) pathEntityIDRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entityID := d.Get("id").(string)
if entityID == "" {
return logical.ErrorResponse("missing entity id"), nil
}
entity, err := i.memDBEntityByID(entityID, false)
if err != nil {
return nil, err
}
if entity == nil {
return nil, nil
}
respData := map[string]interface{}{}
respData["id"] = entity.ID
respData["name"] = entity.Name
respData["metadata"] = entity.Metadata
respData["merged_entity_ids"] = entity.MergedEntityIDs
respData["policies"] = entity.Policies
// Convert protobuf timestamp into RFC3339 format
respData["creation_time"] = ptypes.TimestampString(entity.CreationTime)
respData["last_update_time"] = ptypes.TimestampString(entity.LastUpdateTime)
// Convert each alias into a map and replace the time format in each
aliasesToReturn := make([]interface{}, len(entity.Aliases))
for aliasIdx, alias := range entity.Aliases {
aliasMap := map[string]interface{}{}
aliasMap["id"] = alias.ID
aliasMap["entity_id"] = alias.EntityID
aliasMap["mount_type"] = alias.MountType
aliasMap["mount_accessor"] = alias.MountAccessor
aliasMap["mount_path"] = alias.MountPath
aliasMap["metadata"] = alias.Metadata
aliasMap["name"] = alias.Name
aliasMap["merged_from_entity_ids"] = alias.MergedFromEntityIDs
aliasMap["creation_time"] = ptypes.TimestampString(alias.CreationTime)
aliasMap["last_update_time"] = ptypes.TimestampString(alias.LastUpdateTime)
aliasesToReturn[aliasIdx] = aliasMap
}
// Add the aliases information to the response which has the correct time
// formats
respData["aliases"] = aliasesToReturn
resp := &logical.Response{
Data: respData,
}
return resp, nil
}
// pathEntityIDDelete deletes the entity for a given entity ID
func (i *IdentityStore) pathEntityIDDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entityID := d.Get("id").(string)
if entityID == "" {
return logical.ErrorResponse("missing entity id"), nil
}
return nil, i.deleteEntity(entityID)
}
// pathEntityIDList lists the IDs of all the valid entities in the identity
// store
func (i *IdentityStore) pathEntityIDList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
ws := memdb.NewWatchSet()
iter, err := i.memDBEntities(ws)
if err != nil {
return nil, fmt.Errorf("failed to fetch iterator for entities in memdb: %v", err)
}
var entityIDs []string
for {
raw := iter.Next()
if raw == nil {
break
}
entityIDs = append(entityIDs, raw.(*identity.Entity).ID)
}
return logical.ListResponse(entityIDs), nil
}
var entityHelp = map[string][2]string{
"entity": {
"Create a new entity",
"",
},
"entity-id": {
"Update, read or delete an entity using entity ID",
"",
},
"entity-id-list": {
"List all the entity IDs",
"",
},
"entity-merge-id": {
"Merge two or more entities together",
"",
},
}

View File

@@ -0,0 +1,783 @@
package vault
import (
"fmt"
"reflect"
"sort"
"testing"
uuid "github.com/hashicorp/go-uuid"
credGithub "github.com/hashicorp/vault/builtin/credential/github"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/logical"
)
func TestIdentityStore_EntityCreateUpdate(t *testing.T) {
var err error
var resp *logical.Response
is, _, _ := testIdentityStoreWithGithubAuth(t)
entityData := map[string]interface{}{
"name": "testentityname",
"metadata": []string{"someusefulkey=someusefulvalue"},
"policies": []string{"testpolicy1", "testpolicy2"},
}
entityReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: entityData,
}
// Create the entity
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entityID := resp.Data["id"].(string)
updateData := map[string]interface{}{
// Set the entity ID here
"id": entityID,
"name": "updatedentityname",
"metadata": []string{"updatedkey=updatedvalue"},
"policies": []string{"updatedpolicy1", "updatedpolicy2"},
}
entityReq.Data = updateData
// Update the entity
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entityReq.Path = "entity/id/" + entityID
entityReq.Operation = logical.ReadOperation
// Read the entity
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["id"] != entityID ||
resp.Data["name"] != updateData["name"] ||
!reflect.DeepEqual(resp.Data["policies"], updateData["policies"]) {
t.Fatalf("bad: entity response after update; resp: %#v\n updateData: %#v\n", resp.Data, updateData)
}
}
func TestIdentityStore_CloneImmutability(t *testing.T) {
alias := &identity.Alias{
ID: "testaliasid",
Name: "testaliasname",
MergedFromEntityIDs: []string{"entityid1"},
}
entity := &identity.Entity{
ID: "testentityid",
Name: "testentityname",
Aliases: []*identity.Alias{
alias,
},
}
clonedEntity, err := entity.Clone()
if err != nil {
t.Fatal(err)
}
// Modify entity
entity.Aliases[0].ID = "invalidid"
if clonedEntity.Aliases[0].ID == "invalidid" {
t.Fatalf("cloned entity is mutated")
}
clonedAlias, err := alias.Clone()
if err != nil {
t.Fatal(err)
}
alias.MergedFromEntityIDs[0] = "invalidid"
if clonedAlias.MergedFromEntityIDs[0] == "invalidid" {
t.Fatalf("cloned alias is mutated")
}
}
func TestIdentityStore_MemDBImmutability(t *testing.T) {
var err error
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
validateMountResp := is.validateMountAccessorFunc(githubAccessor)
if validateMountResp == nil {
t.Fatal("failed to validate github auth mount")
}
alias1 := &identity.Alias{
EntityID: "testentityid",
ID: "testaliasid",
MountAccessor: githubAccessor,
MountType: validateMountResp.MountType,
Name: "testaliasname",
Metadata: map[string]string{
"testkey1": "testmetadatavalue1",
"testkey2": "testmetadatavalue2",
},
}
entity := &identity.Entity{
ID: "testentityid",
Name: "testentityname",
Metadata: map[string]string{
"someusefulkey": "someusefulvalue",
},
Aliases: []*identity.Alias{
alias1,
},
}
entity.BucketKeyHash = is.entityPacker.BucketKeyHashByItemID(entity.ID)
err = is.memDBUpsertEntity(entity)
if err != nil {
t.Fatal(err)
}
entityFetched, err := is.memDBEntityByID(entity.ID, true)
if err != nil {
t.Fatal(err)
}
// Modify the fetched entity outside of a transaction
entityFetched.Aliases[0].ID = "invalidaliasid"
entityFetched, err = is.memDBEntityByID(entity.ID, false)
if err != nil {
t.Fatal(err)
}
if entityFetched.Aliases[0].ID == "invalidaliasid" {
t.Fatal("memdb item is mutable outside of transaction")
}
}
func TestIdentityStore_ListEntities(t *testing.T) {
var err error
var resp *logical.Response
is, _, _ := testIdentityStoreWithGithubAuth(t)
entityReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
}
expected := []string{}
for i := 0; i < 10; i++ {
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
expected = append(expected, resp.Data["id"].(string))
}
listReq := &logical.Request{
Operation: logical.ListOperation,
Path: "entity/id",
}
resp, err = is.HandleRequest(listReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
actual := resp.Data["keys"].([]string)
// Sort the operands for DeepEqual to work
sort.Strings(actual)
sort.Strings(expected)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: listed entity IDs; expected: %#v\n actual: %#v\n", expected, actual)
}
}
func TestIdentityStore_LoadingEntities(t *testing.T) {
var resp *logical.Response
// Add github credential factory to core config
err := AddTestCredentialBackend("github", credGithub.Factory)
if err != nil {
t.Fatalf("err: %s", err)
}
c := TestCore(t)
unsealKeys, token := TestCoreInit(t, c)
for _, key := range unsealKeys {
if _, err := TestCoreUnseal(c, TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
}
}
sealed, err := c.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
meGH := &MountEntry{
Table: credentialTableType,
Path: "github/",
Type: "github",
Description: "github auth",
}
// Mount UUID for github auth
meGHUUID, err := uuid.GenerateUUID()
if err != nil {
t.Fatal(err)
}
meGH.UUID = meGHUUID
// Mount accessor for github auth
githubAccessor, err := c.generateMountAccessor("github")
if err != nil {
panic(fmt.Sprintf("could not generate github accessor: %v", err))
}
meGH.Accessor = githubAccessor
// Storage view for github auth
ghView := NewBarrierView(c.barrier, credentialBarrierPrefix+meGH.UUID+"/")
// Sysview for github auth
ghSysview := c.mountEntrySysView(meGH)
// Create new github auth credential backend
ghAuth, err := c.newCredentialBackend(meGH.Type, ghSysview, ghView, nil)
if err != nil {
t.Fatal(err)
}
// Mount github auth
err = c.router.Mount(ghAuth, "auth/github", meGH, ghView)
if err != nil {
t.Fatal(err)
}
// Identity store will be mounted by now, just fetch it from router
identitystore := c.router.MatchingBackend("identity/")
if identitystore == nil {
t.Fatalf("failed to fetch identity store from router")
}
is := identitystore.(*IdentityStore)
registerData := map[string]interface{}{
"name": "testentityname",
"metadata": []string{"someusefulkey=someusefulvalue"},
"policies": []string{"testpolicy1", "testpolicy2"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: registerData,
}
// Register the entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entityID := resp.Data["id"].(string)
readReq := &logical.Request{
Path: "entity/id/" + entityID,
Operation: logical.ReadOperation,
}
// Ensure that entity is created
resp, err = is.HandleRequest(readReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["id"] != entityID {
t.Fatalf("failed to read the created entity")
}
// Perform a seal/unseal cycle
err = c.Seal(token)
if err != nil {
t.Fatalf("failed to seal core: %v", err)
}
sealed, err = c.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if !sealed {
t.Fatal("should be sealed")
}
for _, key := range unsealKeys {
if _, err := TestCoreUnseal(c, TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
}
}
sealed, err = c.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Check if the entity is restored
resp, err = is.HandleRequest(readReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["id"] != entityID {
t.Fatalf("failed to read the created entity after a seal/unseal cycle")
}
}
func TestIdentityStore_MemDBEntityIndexes(t *testing.T) {
var err error
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
validateMountResp := is.validateMountAccessorFunc(githubAccessor)
if validateMountResp == nil {
t.Fatal("failed to validate github auth mount")
}
alias1 := &identity.Alias{
EntityID: "testentityid",
ID: "testaliasid",
MountAccessor: githubAccessor,
MountType: validateMountResp.MountType,
Name: "testaliasname",
Metadata: map[string]string{
"testkey1": "testmetadatavalue1",
"testkey2": "testmetadatavalue2",
},
}
alias2 := &identity.Alias{
EntityID: "testentityid",
ID: "testaliasid2",
MountAccessor: validateMountResp.MountAccessor,
MountType: validateMountResp.MountType,
Name: "testaliasname2",
Metadata: map[string]string{
"testkey2": "testmetadatavalue2",
"testkey3": "testmetadatavalue3",
},
}
entity := &identity.Entity{
ID: "testentityid",
Name: "testentityname",
Metadata: map[string]string{
"someusefulkey": "someusefulvalue",
},
Aliases: []*identity.Alias{
alias1,
alias2,
},
}
entity.BucketKeyHash = is.entityPacker.BucketKeyHashByItemID(entity.ID)
err = is.memDBUpsertEntity(entity)
if err != nil {
t.Fatal(err)
}
// Fetch the entity using its ID
entityFetched, err := is.memDBEntityByID(entity.ID, false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(entity, entityFetched) {
t.Fatalf("bad: mismatched entities; expected: %#v\n actual: %#v\n", entity, entityFetched)
}
// Fetch the entity using its name
entityFetched, err = is.memDBEntityByName(entity.Name, false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(entity, entityFetched) {
t.Fatalf("entity mismatched entities; expected: %#v\n actual: %#v\n", entity, entityFetched)
}
// Fetch entities using the metadata
entitiesFetched, err := is.memDBEntitiesByMetadata(map[string]string{
"someusefulkey": "someusefulvalue",
}, false)
if err != nil {
t.Fatal(err)
}
if len(entitiesFetched) != 1 {
t.Fatalf("bad: length of entities; expected: 1, actual: %d", len(entitiesFetched))
}
if !reflect.DeepEqual(entity, entitiesFetched[0]) {
t.Fatalf("entity mismatch; entity: %#v\n entitiesFetched[0]: %#v\n", entity, entitiesFetched[0])
}
entitiesFetched, err = is.memDBEntitiesByBucketEntryKeyHash(entity.BucketKeyHash)
if err != nil {
t.Fatal(err)
}
if len(entitiesFetched) != 1 {
t.Fatalf("bad: length of entities; expected: 1, actual: %d", len(entitiesFetched))
}
err = is.memDBDeleteEntityByID(entity.ID)
if err != nil {
t.Fatal(err)
}
entityFetched, err = is.memDBEntityByID(entity.ID, false)
if err != nil {
t.Fatal(err)
}
if entityFetched != nil {
t.Fatalf("bad: entity; expected: nil, actual: %#v\n", entityFetched)
}
entityFetched, err = is.memDBEntityByName(entity.Name, false)
if err != nil {
t.Fatal(err)
}
if entityFetched != nil {
t.Fatalf("bad: entity; expected: nil, actual: %#v\n", entityFetched)
}
}
// This test is required because MemDB does not take care of ensuring
// uniqueness of indexes that are marked unique. It is the job of the
// higher-level abstraction, the identity store in this case.
func TestIdentityStore_EntitySameEntityNames(t *testing.T) {
var err error
var resp *logical.Response
is, _, _ := testIdentityStoreWithGithubAuth(t)
registerData := map[string]interface{}{
"name": "testentityname",
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: registerData,
}
// Register an entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register another entity with same name
resp, err = is.HandleRequest(registerReq)
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected an error due to entity name not being unique")
}
}
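
Since MemDB itself will not reject a duplicate value under a unique index, the store has to perform the lookup before upserting, while holding the appropriate lock. A minimal sketch of that check, reusing the memDBEntityByName helper exercised elsewhere in these tests; the real handler in the (suppressed) identity_store_util.go may structure this differently:

// Sketch: the uniqueness check the identity store performs before
// upserting, since MemDB itself will not reject a duplicate name.
func entityNameTaken(is *IdentityStore, name string) (bool, error) {
	existing, err := is.memDBEntityByName(name, false)
	if err != nil {
		return false, err
	}
	return existing != nil, nil
}
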
func TestIdentityStore_EntityCRUD(t *testing.T) {
var err error
var resp *logical.Response
is, _, _ := testIdentityStoreWithGithubAuth(t)
registerData := map[string]interface{}{
"name": "testentityname",
"metadata": []string{"someusefulkey=someusefulvalue"},
"policies": []string{"testpolicy1", "testpolicy2"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: registerData,
}
// Register the entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
idRaw, ok := resp.Data["id"]
if !ok {
t.Fatalf("entity id not present in response")
}
id := idRaw.(string)
if id == "" {
t.Fatalf("invalid entity id")
}
readReq := &logical.Request{
Path: "entity/id/" + id,
Operation: logical.ReadOperation,
}
resp, err = is.HandleRequest(readReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["id"] != id ||
resp.Data["name"] != registerData["name"] ||
!reflect.DeepEqual(resp.Data["policies"], registerData["policies"]) {
t.Fatalf("bad: entity response")
}
updateData := map[string]interface{}{
"name": "updatedentityname",
"metadata": []string{"updatedkey=updatedvalue"},
"policies": []string{"updatedpolicy1", "updatedpolicy2"},
}
updateReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity/id/" + id,
Data: updateData,
}
resp, err = is.HandleRequest(updateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
resp, err = is.HandleRequest(readReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["id"] != id ||
resp.Data["name"] != updateData["name"] ||
!reflect.DeepEqual(resp.Data["policies"], updateData["policies"]) {
t.Fatalf("bad: entity response after update; resp: %#v\n updateData: %#v\n", resp.Data, updateData)
}
deleteReq := &logical.Request{
Path: "entity/id/" + id,
Operation: logical.DeleteOperation,
}
resp, err = is.HandleRequest(deleteReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
resp, err = is.HandleRequest(readReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response; actual: %#v\n", resp)
}
}
func TestIdentityStore_MergeEntitiesByID(t *testing.T) {
var err error
var resp *logical.Response
is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t)
registerData := map[string]interface{}{
"name": "testentityname2",
"metadata": []string{"someusefulkey=someusefulvalue"},
}
registerData2 := map[string]interface{}{
"name": "testentityname",
"metadata": []string{"someusefulkey=someusefulvalue"},
}
aliasRegisterData1 := map[string]interface{}{
"name": "testaliasname1",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
aliasRegisterData2 := map[string]interface{}{
"name": "testaliasname2",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
aliasRegisterData3 := map[string]interface{}{
"name": "testaliasname3",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
aliasRegisterData4 := map[string]interface{}{
"name": "testaliasname4",
"mount_accessor": githubAccessor,
"metadata": []string{"organization=hashicorp", "team=vault"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: registerData,
}
// Register the entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entityID1 := resp.Data["id"].(string)
// Set entity ID in alias registration data and register alias
aliasRegisterData1["entity_id"] = entityID1
aliasRegisterData2["entity_id"] = entityID1
aliasReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasRegisterData1,
}
// Register the alias
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register the alias
aliasReq.Data = aliasRegisterData2
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entity1, err := is.memDBEntityByID(entityID1, false)
if err != nil {
t.Fatal(err)
}
if entity1 == nil {
t.Fatalf("failed to create entity: %v", err)
}
if len(entity1.Aliases) != 2 {
t.Fatalf("bad: number of aliases in entity; expected: 2, actual: %d", len(entity1.Aliases))
}
registerReq.Data = registerData2
// Register another entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entityID2 := resp.Data["id"].(string)
// Set entity ID in alias registration data and register alias
aliasRegisterData3["entity_id"] = entityID2
aliasRegisterData4["entity_id"] = entityID2
aliasReq = &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasRegisterData3,
}
// Register the alias
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register the alias
aliasReq.Data = aliasRegisterData4
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entity2, err := is.memDBEntityByID(entityID2, false)
if err != nil {
t.Fatal(err)
}
if entity2 == nil {
t.Fatalf("failed to create entity: %v", err)
}
if len(entity2.Aliases) != 2 {
t.Fatalf("bad: number of aliases in entity; expected: 2, actual: %d", len(entity2.Aliases))
}
mergeData := map[string]interface{}{
"to_entity_id": entityID1,
"from_entity_ids": []string{entityID2},
}
mergeReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity/merge",
Data: mergeData,
}
resp, err = is.HandleRequest(mergeReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entityReq := &logical.Request{
Operation: logical.ReadOperation,
Path: "entity/id/" + entityID2,
}
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp != nil {
t.Fatalf("entity should have been deleted")
}
entityReq.Path = "entity/id/" + entityID1
resp, err = is.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
entity1Aliases := resp.Data["aliases"].([]interface{})
if len(entity1Aliases) != 4 {
t.Fatalf("bad: number of aliases in entity; expected: 4, actual: %d", len(entity1Aliases))
}
for _, aliasRaw := range entity1Aliases {
alias := aliasRaw.(map[string]interface{})
aliasLookedUp, err := is.memDBAliasByID(alias["id"].(string), false)
if err != nil {
t.Fatal(err)
}
if aliasLookedUp == nil {
t.Fatalf("index for alias id %q is not updated", alias["id"].(string))
}
}
}
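
The merge endpoint used above moves every alias of each from_entity onto the to_entity and then deletes the source entities, which is why reading entityID2 returns nil while entityID1 ends up with all four aliases. A rough sketch of the heart of that operation, assuming the caller handles locking and storage persistence; the helper and field names mirror the structures above, but this is not the actual implementation, which lives in the suppressed identity_store_util.go:

// Illustrative merge loop: reparent aliases, record merge lineage, and
// drop the source entities from MemDB.
func mergeEntitiesSketch(is *IdentityStore, to *identity.Entity, from []*identity.Entity) error {
	for _, src := range from {
		for _, alias := range src.Aliases {
			alias.EntityID = to.ID
			alias.MergedFromEntityIDs = append(alias.MergedFromEntityIDs, src.ID)
			to.Aliases = append(to.Aliases, alias)
		}
		to.MergedEntityIDs = append(to.MergedEntityIDs, src.ID)
		if err := is.memDBDeleteEntityByID(src.ID); err != nil {
			return err
		}
	}
	return is.memDBUpsertEntity(to)
}
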

View File

@@ -0,0 +1,286 @@
package vault
import (
"fmt"
"strings"
"github.com/golang/protobuf/ptypes"
memdb "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func groupPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
{
Pattern: "group$",
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the group.",
},
"name": {
Type: framework.TypeString,
Description: "Name of the group.",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the group. Format should be a list of `key=value` pairs.",
},
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Policies to be tied to the group.",
},
"member_group_ids": {
Type: framework.TypeCommaStringSlice,
Description: "Group IDs to be assigned as group members.",
},
"member_entity_ids": {
Type: framework.TypeCommaStringSlice,
Description: "Entity IDs to be assigned as group members.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathGroupRegister),
},
HelpSynopsis: strings.TrimSpace(groupHelp["register"][0]),
HelpDescription: strings.TrimSpace(groupHelp["register"][1]),
},
{
Pattern: "group/id/" + framework.GenericNameRegex("id"),
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the group.",
},
"name": {
Type: framework.TypeString,
Description: "Name of the group.",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the group. Format should be a list of `key=value` pairs.",
},
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Policies to be tied to the group.",
},
"member_group_ids": {
Type: framework.TypeCommaStringSlice,
Description: "Group IDs to be assigned as group members.",
},
"member_entity_ids": {
Type: framework.TypeCommaStringSlice,
Description: "Entity IDs to be assigned as group members.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathGroupIDUpdate),
logical.ReadOperation: i.checkPremiumVersion(i.pathGroupIDRead),
logical.DeleteOperation: i.checkPremiumVersion(i.pathGroupIDDelete),
},
HelpSynopsis: strings.TrimSpace(groupHelp["group-by-id"][0]),
HelpDescription: strings.TrimSpace(groupHelp["group-by-id"][1]),
},
{
Pattern: "group/id/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: i.checkPremiumVersion(i.pathGroupIDList),
},
HelpSynopsis: strings.TrimSpace(entityHelp["group-id-list"][0]),
HelpDescription: strings.TrimSpace(entityHelp["group-id-list"][1]),
},
}
}
func (i *IdentityStore) pathGroupRegister(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
_, ok := d.GetOk("id")
if ok {
return i.pathGroupIDUpdate(req, d)
}
i.groupLock.Lock()
defer i.groupLock.Unlock()
return i.handleGroupUpdateCommon(req, d, nil)
}
func (i *IdentityStore) pathGroupIDUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groupID := d.Get("id").(string)
if groupID == "" {
return logical.ErrorResponse("empty group ID"), nil
}
i.groupLock.Lock()
defer i.groupLock.Unlock()
group, err := i.memDBGroupByID(groupID, true)
if err != nil {
return nil, err
}
if group == nil {
return logical.ErrorResponse("invalid group ID"), nil
}
return i.handleGroupUpdateCommon(req, d, group)
}
func (i *IdentityStore) handleGroupUpdateCommon(req *logical.Request, d *framework.FieldData, group *identity.Group) (*logical.Response, error) {
var err error
var newGroup bool
if group == nil {
group = &identity.Group{}
newGroup = true
}
// Update the policies if supplied
policiesRaw, ok := d.GetOk("policies")
if ok {
group.Policies = policiesRaw.([]string)
}
// Get the name
groupName := d.Get("name").(string)
if groupName != "" {
// Check if there is a group already existing for the given name
groupByName, err := i.memDBGroupByName(groupName, false)
if err != nil {
return nil, err
}
// If this is a new group and if there already exists a group by this
// name, error out. If the name of an existing group is about to be
// modified into something which is already tied to a different group,
// error out.
switch {
case (newGroup && groupByName != nil), (groupByName != nil && group.ID != "" && groupByName.ID != group.ID):
return logical.ErrorResponse("group name is already in use"), nil
}
group.Name = groupName
}
metadataRaw, ok := d.GetOk("metadata")
if ok {
group.Metadata, err = parseMetadata(metadataRaw.([]string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to parse group metadata: %v", err)), nil
}
}
memberEntityIDsRaw, ok := d.GetOk("member_entity_ids")
if ok {
group.MemberEntityIDs = memberEntityIDsRaw.([]string)
if len(group.MemberEntityIDs) > 512 {
return logical.ErrorResponse("member entity IDs exceeding the limit of 512"), nil
}
}
memberGroupIDsRaw, ok := d.GetOk("member_group_ids")
var memberGroupIDs []string
if ok {
memberGroupIDs = memberGroupIDsRaw.([]string)
}
err = i.sanitizeAndUpsertGroup(group, memberGroupIDs)
if err != nil {
return nil, err
}
respData := map[string]interface{}{
"id": group.ID,
"name": group.Name,
}
return &logical.Response{
Data: respData,
}, nil
}
func (i *IdentityStore) pathGroupIDRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groupID := d.Get("id").(string)
if groupID == "" {
return logical.ErrorResponse("empty group id"), nil
}
group, err := i.memDBGroupByID(groupID, false)
if err != nil {
return nil, err
}
if group == nil {
return nil, nil
}
return i.handleGroupReadCommon(group)
}
func (i *IdentityStore) handleGroupReadCommon(group *identity.Group) (*logical.Response, error) {
if group == nil {
return nil, fmt.Errorf("nil group")
}
respData := map[string]interface{}{}
respData["id"] = group.ID
respData["name"] = group.Name
respData["policies"] = group.Policies
respData["member_entity_ids"] = group.MemberEntityIDs
respData["metadata"] = group.Metadata
respData["creation_time"] = ptypes.TimestampString(group.CreationTime)
respData["last_update_time"] = ptypes.TimestampString(group.LastUpdateTime)
respData["modify_index"] = group.ModifyIndex
memberGroupIDs, err := i.memberGroupIDsByID(group.ID)
if err != nil {
return nil, err
}
respData["member_group_ids"] = memberGroupIDs
return &logical.Response{
Data: respData,
}, nil
}
func (i *IdentityStore) pathGroupIDDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groupID := d.Get("id").(string)
if groupID == "" {
return logical.ErrorResponse("empty group ID"), nil
}
return nil, i.deleteGroupByID(groupID)
}
// pathGroupIDList lists the IDs of all the groups in the identity store
func (i *IdentityStore) pathGroupIDList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
ws := memdb.NewWatchSet()
iter, err := i.memDBGroupIterator(ws)
if err != nil {
return nil, fmt.Errorf("failed to fetch iterator for group in memdb: %v", err)
}
var groupIDs []string
for {
raw := iter.Next()
if raw == nil {
break
}
groupIDs = append(groupIDs, raw.(*identity.Group).ID)
}
return logical.ListResponse(groupIDs), nil
}
var groupHelp = map[string][2]string{
"register": {
"Create a new group.",
"",
},
"group-by-id": {
"Update or delete an existing group using its ID.",
"",
},
"group-id-list": {
"List all the group IDs.",
"",
},
}

View File

@@ -0,0 +1,666 @@
package vault
import (
"reflect"
"sort"
"testing"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/logical"
)
func TestIdentityStore_MemDBGroupIndexes(t *testing.T) {
var err error
i, _, _ := testIdentityStoreWithGithubAuth(t)
// Create a dummy group
group := &identity.Group{
ID: "testgroupid",
Name: "testgroupname",
Metadata: map[string]string{
"testmetadatakey1": "testmetadatavalue1",
"testmetadatakey2": "testmetadatavalue2",
},
ParentGroupIDs: []string{"testparentgroupid1", "testparentgroupid2"},
MemberEntityIDs: []string{"testentityid1", "testentityid2"},
Policies: []string{"testpolicy1", "testpolicy2"},
BucketKeyHash: i.groupPacker.BucketKeyHashByItemID("testgroupid"),
}
// Insert it into memdb
err = i.memDBUpsertGroup(group)
if err != nil {
t.Fatal(err)
}
// Insert another dummy group
group = &identity.Group{
ID: "testgroupid2",
Name: "testgroupname2",
Metadata: map[string]string{
"testmetadatakey2": "testmetadatavalue2",
"testmetadatakey3": "testmetadatavalue3",
},
ParentGroupIDs: []string{"testparentgroupid2", "testparentgroupid3"},
MemberEntityIDs: []string{"testentityid2", "testentityid3"},
Policies: []string{"testpolicy2", "testpolicy3"},
BucketKeyHash: i.groupPacker.BucketKeyHashByItemID("testgroupid2"),
}
// Insert it into memdb
err = i.memDBUpsertGroup(group)
if err != nil {
t.Fatal(err)
}
var fetchedGroup *identity.Group
// Fetch group given the name
fetchedGroup, err = i.memDBGroupByName("testgroupname", false)
if err != nil {
t.Fatal(err)
}
if fetchedGroup == nil || fetchedGroup.Name != "testgroupname" {
t.Fatalf("failed to fetch an indexed group")
}
// Fetch group given the ID
fetchedGroup, err = i.memDBGroupByID("testgroupid", false)
if err != nil {
t.Fatal(err)
}
if fetchedGroup == nil || fetchedGroup.Name != "testgroupname" {
t.Fatalf("failed to fetch an indexed group")
}
var fetchedGroups []*identity.Group
// Fetch the subgroups of a given group ID
fetchedGroups, err = i.memDBGroupsByParentGroupID("testparentgroupid1", false)
if err != nil {
t.Fatal(err)
}
if len(fetchedGroups) != 1 || fetchedGroups[0].Name != "testgroupname" {
t.Fatalf("failed to fetch an indexed group")
}
fetchedGroups, err = i.memDBGroupsByParentGroupID("testparentgroupid2", false)
if err != nil {
t.Fatal(err)
}
if len(fetchedGroups) != 2 {
t.Fatalf("failed to fetch a indexed groups")
}
// Fetch groups based on policy name
fetchedGroups, err = i.memDBGroupsByPolicy("testpolicy1", false)
if err != nil {
t.Fatal(err)
}
if len(fetchedGroups) != 1 || fetchedGroups[0].Name != "testgroupname" {
t.Fatalf("failed to fetch an indexed group")
}
fetchedGroups, err = i.memDBGroupsByPolicy("testpolicy2", false)
if err != nil {
t.Fatal(err)
}
if len(fetchedGroups) != 2 {
t.Fatalf("failed to fetch indexed groups")
}
// Fetch groups based on member entity ID
fetchedGroups, err = i.memDBGroupsByMemberEntityID("testentityid1", false)
if err != nil {
t.Fatal(err)
}
if len(fetchedGroups) != 1 || fetchedGroups[0].Name != "testgroupname" {
t.Fatalf("failed to fetch an indexed group")
}
fetchedGroups, err = i.memDBGroupsByMemberEntityID("testentityid2", false)
if err != nil {
t.Fatal(err)
}
if len(fetchedGroups) != 2 {
t.Fatalf("failed to fetch groups by entity ID")
}
}
func TestIdentityStore_GroupsCreateUpdate(t *testing.T) {
var resp *logical.Response
var err error
is, _, _ := testIdentityStoreWithGithubAuth(t)
// Create an entity and get its ID
entityRegisterReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
}
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID1 := resp.Data["id"].(string)
// Create another entity and get its ID
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID2 := resp.Data["id"].(string)
// Create a group with the above created 2 entities as its members
groupData := map[string]interface{}{
"policies": "testpolicy1,testpolicy2",
"metadata": []string{"testkey1=testvalue1", "testkey2=testvalue2"},
"member_entity_ids": []string{entityID1, entityID2},
}
// Create a group and get its ID
groupReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "group",
Data: groupData,
}
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
memberGroupID1 := resp.Data["id"].(string)
// Create another group and get its ID
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
memberGroupID2 := resp.Data["id"].(string)
// Create a group with the above 2 groups as its members
groupData["member_group_ids"] = []string{memberGroupID1, memberGroupID2}
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
groupID := resp.Data["id"].(string)
// Read the group using its ID and check if all the fields are properly
// set
groupReq = &logical.Request{
Operation: logical.ReadOperation,
Path: "group/id/" + groupID,
}
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
expectedData := map[string]interface{}{
"policies": []string{"testpolicy1", "testpolicy2"},
"metadata": map[string]string{
"testkey1": "testvalue1",
"testkey2": "testvalue2",
},
}
expectedData["id"] = resp.Data["id"]
expectedData["name"] = resp.Data["name"]
expectedData["member_group_ids"] = resp.Data["member_group_ids"]
expectedData["member_entity_ids"] = resp.Data["member_entity_ids"]
expectedData["creation_time"] = resp.Data["creation_time"]
expectedData["last_update_time"] = resp.Data["last_update_time"]
expectedData["modify_index"] = resp.Data["modify_index"]
if !reflect.DeepEqual(expectedData, resp.Data) {
t.Fatalf("bad: group data;\nexpected: %#v\n actual: %#v\n", expectedData, resp.Data)
}
// Update the policies and metadata in the group
groupReq.Operation = logical.UpdateOperation
groupReq.Data = groupData
// Update by setting ID in the param
groupData["id"] = groupID
groupData["policies"] = "updatedpolicy1,updatedpolicy2"
groupData["metadata"] = []string{"updatedkey=updatedvalue"}
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
// Check if updates are reflected
groupReq.Operation = logical.ReadOperation
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
expectedData["policies"] = []string{"updatedpolicy1", "updatedpolicy2"}
expectedData["metadata"] = map[string]string{
"updatedkey": "updatedvalue",
}
expectedData["last_update_time"] = resp.Data["last_update_time"]
expectedData["modify_index"] = resp.Data["modify_index"]
if !reflect.DeepEqual(expectedData, resp.Data) {
t.Fatalf("bad: group data; expected: %#v\n actual: %#v\n", expectedData, resp.Data)
}
}
func TestIdentityStore_GroupsCRUD_ByID(t *testing.T) {
var resp *logical.Response
var err error
is, _, _ := testIdentityStoreWithGithubAuth(t)
// Create an entity and get its ID
entityRegisterReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
}
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID1 := resp.Data["id"].(string)
// Create another entity and get its ID
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID2 := resp.Data["id"].(string)
// Create a group with the above created 2 entities as its members
groupData := map[string]interface{}{
"policies": "testpolicy1,testpolicy2",
"metadata": []string{"testkey1=testvalue1", "testkey2=testvalue2"},
"member_entity_ids": []string{entityID1, entityID2},
}
// Create a group and get its ID
groupRegisterReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "group",
Data: groupData,
}
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
memberGroupID1 := resp.Data["id"].(string)
// Create another group and get its ID
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
memberGroupID2 := resp.Data["id"].(string)
// Create a group with the above 2 groups as its members
groupData["member_group_ids"] = []string{memberGroupID1, memberGroupID2}
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
groupID := resp.Data["id"].(string)
// Read the group using its ID and check if all the fields are properly
// set
groupReq := &logical.Request{
Operation: logical.ReadOperation,
Path: "group/id/" + groupID,
}
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
expectedData := map[string]interface{}{
"policies": []string{"testpolicy1", "testpolicy2"},
"metadata": map[string]string{
"testkey1": "testvalue1",
"testkey2": "testvalue2",
},
}
expectedData["id"] = resp.Data["id"]
expectedData["name"] = resp.Data["name"]
expectedData["member_group_ids"] = resp.Data["member_group_ids"]
expectedData["member_entity_ids"] = resp.Data["member_entity_ids"]
expectedData["creation_time"] = resp.Data["creation_time"]
expectedData["last_update_time"] = resp.Data["last_update_time"]
expectedData["modify_index"] = resp.Data["modify_index"]
if !reflect.DeepEqual(expectedData, resp.Data) {
t.Fatalf("bad: group data;\nexpected: %#v\n actual: %#v\n", expectedData, resp.Data)
}
// Update the policies and metadata in the group
groupReq.Operation = logical.UpdateOperation
groupReq.Data = groupData
groupData["policies"] = "updatedpolicy1,updatedpolicy2"
groupData["metadata"] = []string{"updatedkey=updatedvalue"}
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
// Check if updates are reflected
groupReq.Operation = logical.ReadOperation
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
expectedData["policies"] = []string{"updatedpolicy1", "updatedpolicy2"}
expectedData["metadata"] = map[string]string{
"updatedkey": "updatedvalue",
}
expectedData["last_update_time"] = resp.Data["last_update_time"]
expectedData["modify_index"] = resp.Data["modify_index"]
if !reflect.DeepEqual(expectedData, resp.Data) {
t.Fatalf("bad: group data; expected: %#v\n actual: %#v\n", expectedData, resp.Data)
}
// Check if delete is working properly
groupReq.Operation = logical.DeleteOperation
resp, err = is.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
groupReq.Operation = logical.ReadOperation
resp, err = is.HandleRequest(groupReq)
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatalf("expected a nil response")
}
}
/*
Test groups hierarchy:

                  eng
               /       \
          vault         ops
         /     \       /    \
      kube  identity build  deploy
*/
func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
var resp *logical.Response
var err error
is, _, _ := testIdentityStoreWithGithubAuth(t)
groupRegisterReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "group",
}
// Create 'kube' group
kubeGroupData := map[string]interface{}{
"name": "kube",
"policies": "kubepolicy",
}
groupRegisterReq.Data = kubeGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
kubeGroupID := resp.Data["id"].(string)
// Create 'identity' group
identityGroupData := map[string]interface{}{
"name": "identity",
"policies": "identitypolicy",
}
groupRegisterReq.Data = identityGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
identityGroupID := resp.Data["id"].(string)
// Create 'build' group
buildGroupData := map[string]interface{}{
"name": "build",
"policies": "buildpolicy",
}
groupRegisterReq.Data = buildGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
buildGroupID := resp.Data["id"].(string)
// Create 'deploy' group
deployGroupData := map[string]interface{}{
"name": "deploy",
"policies": "deploypolicy",
}
groupRegisterReq.Data = deployGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
deployGroupID := resp.Data["id"].(string)
// Create 'vault' with 'kube' and 'identity' as member groups
vaultMemberGroupIDs := []string{kubeGroupID, identityGroupID}
vaultGroupData := map[string]interface{}{
"name": "vault",
"policies": "vaultpolicy",
"member_group_ids": vaultMemberGroupIDs,
}
groupRegisterReq.Data = vaultGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
vaultGroupID := resp.Data["id"].(string)
// Create 'ops' group with 'build' and 'deploy' as member groups
opsMemberGroupIDs := []string{buildGroupID, deployGroupID}
opsGroupData := map[string]interface{}{
"name": "ops",
"policies": "opspolicy",
"member_group_ids": opsMemberGroupIDs,
}
groupRegisterReq.Data = opsGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
opsGroupID := resp.Data["id"].(string)
// Create 'eng' group with 'vault' and 'ops' as member groups
engMemberGroupIDs := []string{vaultGroupID, opsGroupID}
engGroupData := map[string]interface{}{
"name": "eng",
"policies": "engpolicy",
"member_group_ids": engMemberGroupIDs,
}
groupRegisterReq.Data = engGroupData
resp, err = is.HandleRequest(groupRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
engGroupID := resp.Data["id"].(string)
/*
fmt.Printf("engGroupID: %#v\n", engGroupID)
fmt.Printf("vaultGroupID: %#v\n", vaultGroupID)
fmt.Printf("opsGroupID: %#v\n", opsGroupID)
fmt.Printf("kubeGroupID: %#v\n", kubeGroupID)
fmt.Printf("identityGroupID: %#v\n", identityGroupID)
fmt.Printf("buildGroupID: %#v\n", buildGroupID)
fmt.Printf("deployGroupID: %#v\n", deployGroupID)
*/
var memberGroupIDs []string
// Fetch 'eng' group
engGroup, err := is.memDBGroupByID(engGroupID, false)
if err != nil {
t.Fatal(err)
}
memberGroupIDs, err = is.memberGroupIDsByID(engGroup.ID)
if err != nil {
t.Fatal(err)
}
sort.Strings(memberGroupIDs)
sort.Strings(engMemberGroupIDs)
if !reflect.DeepEqual(engMemberGroupIDs, memberGroupIDs) {
t.Fatalf("bad: group membership IDs; expected: %#v\n actual: %#v\n", engMemberGroupIDs, memberGroupIDs)
}
vaultGroup, err := is.memDBGroupByID(vaultGroupID, false)
if err != nil {
t.Fatal(err)
}
memberGroupIDs, err = is.memberGroupIDsByID(vaultGroup.ID)
if err != nil {
t.Fatal(err)
}
sort.Strings(memberGroupIDs)
sort.Strings(vaultMemberGroupIDs)
if !reflect.DeepEqual(vaultMemberGroupIDs, memberGroupIDs) {
t.Fatalf("bad: group membership IDs; expected: %#v\n actual: %#v\n", vaultMemberGroupIDs, memberGroupIDs)
}
opsGroup, err := is.memDBGroupByID(opsGroupID, false)
if err != nil {
t.Fatal(err)
}
memberGroupIDs, err = is.memberGroupIDsByID(opsGroup.ID)
if err != nil {
t.Fatal(err)
}
sort.Strings(memberGroupIDs)
sort.Strings(opsMemberGroupIDs)
if !reflect.DeepEqual(opsMemberGroupIDs, memberGroupIDs) {
t.Fatalf("bad: group membership IDs; expected: %#v\n actual: %#v\n", opsMemberGroupIDs, memberGroupIDs)
}
groupUpdateReq := &logical.Request{
Operation: logical.UpdateOperation,
}
// Adding 'engGroupID' under 'kubeGroupID' should fail
groupUpdateReq.Path = "group/name/kube"
groupUpdateReq.Data = kubeGroupData
kubeGroupData["member_group_ids"] = []string{engGroupID}
resp, err = is.HandleRequest(groupUpdateReq)
if err == nil {
t.Fatalf("expected an error response")
}
// Create an entity ID
entityRegisterReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
}
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID1 := resp.Data["id"].(string)
// Add the entity as a member of 'kube' group
entityIDReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "group/id/" + kubeGroupID,
Data: map[string]interface{}{
"member_entity_ids": []string{entityID1},
},
}
resp, err = is.HandleRequest(entityIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
// Create a second entity ID
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID2 := resp.Data["id"].(string)
// Add the entity as a member of 'ops' group
entityIDReq.Path = "group/id/" + opsGroupID
entityIDReq.Data = map[string]interface{}{
"member_entity_ids": []string{entityID2},
}
resp, err = is.HandleRequest(entityIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
// Create a third entity ID
resp, err = is.HandleRequest(entityRegisterReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
entityID3 := resp.Data["id"].(string)
// Add the entity as a member of 'eng' group
entityIDReq.Path = "group/id/" + engGroupID
entityIDReq.Data = map[string]interface{}{
"member_entity_ids": []string{entityID3},
}
resp, err = is.HandleRequest(entityIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
policies, err := is.groupPoliciesByEntityID(entityID1)
if err != nil {
t.Fatal(err)
}
sort.Strings(policies)
expected := []string{"kubepolicy", "vaultpolicy", "engpolicy"}
sort.Strings(expected)
if !reflect.DeepEqual(expected, policies) {
t.Fatalf("bad: policies; expected: %#v\nactual:%#v", expected, policies)
}
policies, err = is.groupPoliciesByEntityID(entityID2)
if err != nil {
t.Fatal(err)
}
sort.Strings(policies)
expected = []string{"opspolicy", "engpolicy"}
sort.Strings(expected)
if !reflect.DeepEqual(expected, policies) {
t.Fatalf("bad: policies; expected: %#v\nactual:%#v", expected, policies)
}
policies, err = is.groupPoliciesByEntityID(entityID3)
if err != nil {
t.Fatal(err)
}
if len(policies) != 1 || policies[0] != "engpolicy" {
t.Fatalf("bad: policies; expected: 'engpolicy'\nactual:%#v", policies)
}
groups, err := is.transitiveGroupsByEntityID(entityID1)
if err != nil {
t.Fatal(err)
}
if len(groups) != 3 {
t.Fatalf("bad: length of groups; expected: 3, actual: %d", len(groups))
}
groups, err = is.transitiveGroupsByEntityID(entityID2)
if err != nil {
t.Fatal(err)
}
if len(groups) != 2 {
t.Fatalf("bad: length of groups; expected: 2, actual: %d", len(groups))
}
groups, err = is.transitiveGroupsByEntityID(entityID3)
if err != nil {
t.Fatal(err)
}
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
}
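
transitiveGroupsByEntityID and groupPoliciesByEntityID come from identity_store_util.go, whose diff is suppressed later in this commit; conceptually they start from the groups that directly contain the entity and climb through ParentGroupIDs, collecting groups along the way. A hedged sketch of that walk using the MemDB helpers this test exercises:

// Sketch of the transitive walk the hierarchy assertions above rely on.
// The actual implementation may differ; this only mirrors the expected
// behavior: direct memberships first, then every ancestor group.
func transitiveGroupsSketch(is *IdentityStore, entityID string) ([]*identity.Group, error) {
	visited := make(map[string]bool)
	var all []*identity.Group
	queue, err := is.memDBGroupsByMemberEntityID(entityID, false)
	if err != nil {
		return nil, err
	}
	for len(queue) > 0 {
		group := queue[0]
		queue = queue[1:]
		if group == nil || visited[group.ID] {
			continue
		}
		visited[group.ID] = true
		all = append(all, group)
		for _, parentID := range group.ParentGroupIDs {
			parent, err := is.memDBGroupByID(parentID, false)
			if err != nil {
				return nil, err
			}
			queue = append(queue, parent)
		}
	}
	return all, nil
}

Taking the union of Policies over the returned groups yields what the policy assertions above expect for each entity.
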

View File

@@ -0,0 +1,180 @@
package vault
import (
"fmt"
memdb "github.com/hashicorp/go-memdb"
)
func identityStoreSchema() *memdb.DBSchema {
iStoreSchema := &memdb.DBSchema{
Tables: make(map[string]*memdb.TableSchema),
}
schemas := []func() *memdb.TableSchema{
entityTableSchema,
aliasesTableSchema,
groupTableSchema,
}
for _, schemaFunc := range schemas {
schema := schemaFunc()
if _, ok := iStoreSchema.Tables[schema.Name]; ok {
panic(fmt.Sprintf("duplicate table name: %s", schema.Name))
}
iStoreSchema.Tables[schema.Name] = schema
}
return iStoreSchema
}
func aliasesTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: "aliases",
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
"entity_id": &memdb.IndexSchema{
Name: "entity_id",
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "EntityID",
},
},
"mount_type": &memdb.IndexSchema{
Name: "mount_type",
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "MountType",
},
},
"factors": &memdb.IndexSchema{
Name: "factors",
Unique: true,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "MountAccessor",
},
&memdb.StringFieldIndex{
Field: "Name",
},
},
},
},
"metadata": &memdb.IndexSchema{
Name: "metadata",
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringMapFieldIndex{
Field: "Metadata",
},
},
},
}
}
func entityTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: "entities",
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
"name": &memdb.IndexSchema{
Name: "name",
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
},
},
"metadata": &memdb.IndexSchema{
Name: "metadata",
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringMapFieldIndex{
Field: "Metadata",
},
},
"merged_entity_ids": &memdb.IndexSchema{
Name: "merged_entity_ids",
Unique: true,
AllowMissing: true,
Indexer: &memdb.StringSliceFieldIndex{
Field: "MergedEntityIDs",
},
},
"bucket_key_hash": &memdb.IndexSchema{
Name: "bucket_key_hash",
Unique: false,
AllowMissing: false,
Indexer: &memdb.StringFieldIndex{
Field: "BucketKeyHash",
},
},
},
}
}
func groupTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: "groups",
Indexes: map[string]*memdb.IndexSchema{
"id": {
Name: "id",
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
"name": {
Name: "name",
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
},
},
"member_entity_ids": {
Name: "member_entity_ids",
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringSliceFieldIndex{
Field: "MemberEntityIDs",
},
},
"parent_group_ids": {
Name: "parent_group_ids",
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringSliceFieldIndex{
Field: "ParentGroupIDs",
},
},
"policies": {
Name: "policies",
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringSliceFieldIndex{
Field: "Policies",
},
},
"bucket_key_hash": &memdb.IndexSchema{
Name: "bucket_key_hash",
Unique: false,
AllowMissing: false,
Indexer: &memdb.StringFieldIndex{
Field: "BucketKeyHash",
},
},
},
}
}
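
With these schemas registered, any declared index can back a query inside a go-memdb transaction. A small hedged example against the groups table's "policies" index, which is presumably how helpers like memDBGroupsByPolicy are built; db is assumed to be the *memdb.MemDB constructed from identityStoreSchema(), with the identity package imported:

// Example: iterate all groups that carry a given policy by querying the
// "policies" StringSliceFieldIndex in a read-only transaction.
func groupsByPolicyExample(db *memdb.MemDB, policy string) ([]*identity.Group, error) {
	txn := db.Txn(false)
	defer txn.Abort()
	iter, err := txn.Get("groups", "policies", policy)
	if err != nil {
		return nil, err
	}
	var groups []*identity.Group
	for raw := iter.Next(); raw != nil; raw = iter.Next() {
		groups = append(groups, raw.(*identity.Group))
	}
	return groups, nil
}
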

View File

@@ -0,0 +1,75 @@
package vault
import (
"regexp"
"sync"
memdb "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
log "github.com/mgutz/logxi/v1"
)
const (
// Storage prefixes
entityPrefix = "entity/"
)
var (
// metaKeyFormatRegEx checks if a metadata key string is valid
metaKeyFormatRegEx = regexp.MustCompile(`^[a-zA-Z0-9=/+_-]+$`).MatchString
)
const (
// The meta key prefix reserved for Vault's internal use
metaKeyReservedPrefix = "vault-"
// The maximum number of metadata key pairs allowed to be registered
metaMaxKeyPairs = 64
// The maximum allowed length of a metadata key
metaKeyMaxLength = 128
// The maximum allowed length of a metadata value
metaValueMaxLength = 512
)
// IdentityStore is composed of its own storage view and a MemDB which
// maintains active in-memory replicas of the storage contents indexed by
// multiple fields.
type IdentityStore struct {
// IdentityStore is a secret backend in Vault
*framework.Backend
// view is the storage sub-view where all the artifacts of the identity
// store get persisted
view logical.Storage
// db is the in-memory database where the storage artifacts get replicated
// to enable richer queries based on multiple indexes.
db *memdb.MemDB
// validateMountAccessorFunc is a utility from the router which returns the
// properties of the mount given the mount accessor.
validateMountAccessorFunc func(string) *validateMountResponse
// entityLocks is a set of 256 locks across which entities are partitioned
// while performing storage modifications.
entityLocks []*locksutil.LockEntry
// groupLock is used to protect modifications to group entries
groupLock sync.RWMutex
// logger is the server logger copied over from core
logger log.Logger
// entityPacker is used to pack multiple entity storage entries into 256
// buckets
entityPacker *storagepacker.StoragePacker
// groupPacker is used to pack multiple group storage entries into 256
// buckets
groupPacker *storagepacker.StoragePacker
}
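
The 256 entityLocks follow Vault's locksutil pattern: the lock for an entity is chosen by hashing its ID, so writes to different entities proceed in parallel while writes to the same entity serialize. A sketch of how a write path would presumably take the lock, assuming the LockForKey helper from helper/locksutil:

// Sketch: pick the lock bucket derived from the entity ID before
// mutating that entity's storage.
func withEntityLock(is *IdentityStore, entityID string, fn func() error) error {
	lock := locksutil.LockForKey(is.entityLocks, entityID)
	lock.Lock()
	defer lock.Unlock()
	return fn()
}
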

View File

@@ -0,0 +1,269 @@
package vault
import (
"testing"
"time"
credGithub "github.com/hashicorp/vault/builtin/credential/github"
"github.com/hashicorp/vault/logical"
)
func TestIdentityStore_CreateEntity(t *testing.T) {
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
alias := &logical.Alias{
MountType: "github",
MountAccessor: ghAccessor,
Name: "githubuser",
}
entity, err := is.CreateEntity(alias)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}
if len(entity.Aliases) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases))
}
if entity.Aliases[0].Name != alias.Name {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}
// Try recreating an entity with the same alias details. It should fail.
entity, err = is.CreateEntity(alias)
if err == nil {
t.Fatalf("expected an error")
}
}
func TestIdentityStore_EntityByAliasFactors(t *testing.T) {
var err error
var resp *logical.Response
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
registerData := map[string]interface{}{
"name": "testentityname",
"metadata": []string{"someusefulkey=someusefulvalue"},
"policies": []string{"testpolicy1", "testpolicy2"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: registerData,
}
// Register the entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
idRaw, ok := resp.Data["id"]
if !ok {
t.Fatalf("entity id not present in response")
}
entityID := idRaw.(string)
if entityID == "" {
t.Fatalf("invalid entity id")
}
aliasData := map[string]interface{}{
"entity_id": entityID,
"name": "alias_name",
"mount_accessor": ghAccessor,
}
aliasReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "alias",
Data: aliasData,
}
resp, err = is.HandleRequest(aliasReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp == nil {
t.Fatalf("expected a non-nil response")
}
entity, err := is.EntityByAliasFactors(ghAccessor, "alias_name", false)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}
if entity.ID != entityID {
t.Fatalf("bad: entity ID; expected: %q actual: %q", entityID, entity.ID)
}
}
func TestIdentityStore_WrapInfoInheritance(t *testing.T) {
var err error
var resp *logical.Response
core, is, ts, _ := testCoreWithIdentityTokenGithub(t)
registerData := map[string]interface{}{
"name": "testentityname",
"metadata": []string{"someusefulkey=someusefulvalue"},
"policies": []string{"testpolicy1", "testpolicy2"},
}
registerReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "entity",
Data: registerData,
}
// Register the entity
resp, err = is.HandleRequest(registerReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
idRaw, ok := resp.Data["id"]
if !ok {
t.Fatalf("entity id not present in response")
}
entityID := idRaw.(string)
if entityID == "" {
t.Fatalf("invalid entity id")
}
// Create a token which has EntityID set and has update permissions to
// sys/wrapping/wrap
te := &TokenEntry{
Path: "test",
Policies: []string{"default", responseWrappingPolicyName},
EntityID: entityID,
}
if err := ts.create(te); err != nil {
t.Fatal(err)
}
wrapReq := &logical.Request{
Path: "sys/wrapping/wrap",
ClientToken: te.ID,
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"foo": "bar",
},
WrapInfo: &logical.RequestWrapInfo{
TTL: 5 * time.Second,
},
}
resp, err = core.HandleRequest(wrapReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v, err: %v", resp, err)
}
if resp.WrapInfo == nil {
t.Fatalf("expected a non-nil WrapInfo")
}
if resp.WrapInfo.WrappedEntityID != entityID {
t.Fatalf("bad: WrapInfo in response not having proper entity ID set; expected: %q, actual:%q", entityID, resp.WrapInfo.WrappedEntityID)
}
}
func TestIdentityStore_TokenEntityInheritance(t *testing.T) {
_, ts, _, _ := TestCoreWithTokenStore(t)
// Create a token which has EntityID set
te := &TokenEntry{
Path: "test",
Policies: []string{"dev", "prod"},
EntityID: "testentityid",
}
if err := ts.create(te); err != nil {
t.Fatal(err)
}
// Create a child token; this should inherit the EntityID
tokenReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "create",
ClientToken: te.ID,
}
resp, err := ts.HandleRequest(tokenReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v err: %v", err, resp)
}
if resp.Auth.EntityID != te.EntityID {
t.Fatalf("bad: entity ID; expected: %v, actual: %v", te.EntityID, resp.Auth.EntityID)
}
// Create an orphan token; this should not inherit the EntityID
tokenReq.Path = "create-orphan"
resp, err = ts.HandleRequest(tokenReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v err: %v", err, resp)
}
if resp.Auth.EntityID != "" {
t.Fatalf("expected entity ID to be not set")
}
}
func testCoreWithIdentityTokenGithub(t *testing.T) (*Core, *IdentityStore, *TokenStore, string) {
is, ghAccessor, core := testIdentityStoreWithGithubAuth(t)
ts := testTokenStore(t, core)
return core, is, ts, ghAccessor
}
// testIdentityStoreWithGithubAuth returns an instance of the identity store,
// which is mounted by default. This function also enables the github auth
// backend to assist with testing aliases and entities that require a valid
// mount accessor of an auth backend.
func testIdentityStoreWithGithubAuth(t *testing.T) (*IdentityStore, string, *Core) {
// Add github credential factory to core config
err := AddTestCredentialBackend("github", credGithub.Factory)
if err != nil {
t.Fatalf("err: %s", err)
}
c, _, _ := TestCoreUnsealed(t)
meGH := &MountEntry{
Table: credentialTableType,
Path: "github/",
Type: "github",
Description: "github auth",
}
err = c.enableCredential(meGH)
if err != nil {
t.Fatal(err)
}
// Identity store will be mounted by now, just fetch it from router
identitystore := c.router.MatchingBackend("identity/")
if identitystore == nil {
t.Fatalf("failed to fetch identity store from router")
}
return identitystore.(*IdentityStore), meGH.Accessor, c
}
func TestIdentityStore_MetadataKeyRegex(t *testing.T) {
key := "validVALID012_-=+/"
if !metaKeyFormatRegEx(key) {
t.Fatal("failed to accept valid metadata key")
}
key = "a:b"
if metaKeyFormatRegEx(key) {
t.Fatal("accepted invalid metadata key")
}
}
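
The regex is only one of the metadata constraints; the constants in identity_store.go also cap the pair count and the key and value lengths, and fence off the vault- prefix. A hedged sketch of a validator combining them; the real enforcement sits in the suppressed identity_store_util.go and may be shaped differently, and fmt and strings imports are assumed:

// Illustrative validator combining the metadata limits declared in
// identity_store.go; not necessarily how the store actually enforces them.
func validateMetadataSketch(meta map[string]string) error {
	if len(meta) > metaMaxKeyPairs {
		return fmt.Errorf("%d metadata key pairs exceeds the maximum of %d", len(meta), metaMaxKeyPairs)
	}
	for k, v := range meta {
		switch {
		case !metaKeyFormatRegEx(k):
			return fmt.Errorf("invalid metadata key %q", k)
		case strings.HasPrefix(k, metaKeyReservedPrefix):
			return fmt.Errorf("metadata key %q uses reserved prefix %q", k, metaKeyReservedPrefix)
		case len(k) > metaKeyMaxLength:
			return fmt.Errorf("metadata key %q exceeds maximum length %d", k, metaKeyMaxLength)
		case len(v) > metaValueMaxLength:
			return fmt.Errorf("metadata value for key %q exceeds maximum length %d", k, metaValueMaxLength)
		}
	}
	return nil
}
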

View File

@@ -0,0 +1,86 @@
package vault
import (
"strings"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func upgradePaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
{
Pattern: "persona$",
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the alias",
},
"entity_id": {
Type: framework.TypeString,
Description: "Entity ID to which this alias belongs to",
},
"mount_accessor": {
Type: framework.TypeString,
Description: "Mount accessor to which this alias belongs to",
},
"name": {
Type: framework.TypeString,
Description: "Name of the alias",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the alias. Format should be a list of `key=value` pairs.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasRegister),
},
HelpSynopsis: strings.TrimSpace(aliasHelp["alias"][0]),
HelpDescription: strings.TrimSpace(aliasHelp["alias"][1]),
},
{
Pattern: "persona/id/" + framework.GenericNameRegex("id"),
Fields: map[string]*framework.FieldSchema{
"id": {
Type: framework.TypeString,
Description: "ID of the alias",
},
"entity_id": {
Type: framework.TypeString,
Description: "Entity ID to which this alias should be tied to",
},
"mount_accessor": {
Type: framework.TypeString,
Description: "Mount accessor to which this alias belongs to",
},
"name": {
Type: framework.TypeString,
Description: "Name of the alias",
},
"metadata": {
Type: framework.TypeStringSlice,
Description: "Metadata to be associated with the alias. Format should be a comma separated list of `key=value` pairs.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasIDUpdate),
logical.ReadOperation: i.checkPremiumVersion(i.pathAliasIDRead),
logical.DeleteOperation: i.checkPremiumVersion(i.pathAliasIDDelete),
},
HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id"][0]),
HelpDescription: strings.TrimSpace(aliasHelp["alias-id"][1]),
},
{
Pattern: "persona/id/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: i.checkPremiumVersion(i.pathAliasIDList),
},
HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id-list"][0]),
HelpDescription: strings.TrimSpace(aliasHelp["alias-id-list"][1]),
},
}
}
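
These persona paths exist for backwards compatibility: they reuse the alias callbacks (pathAliasRegister and friends), so a write to persona behaves exactly like the alias registration exercised in the tests earlier. A hedged usage sketch, where is, entityID, and ghAccessor are assumed to come from a setup like testIdentityStoreWithGithubAuth:

// Writing to the legacy "persona" path invokes the same handler as "alias".
personaReq := &logical.Request{
	Operation: logical.UpdateOperation,
	Path:      "persona",
	Data: map[string]interface{}{
		"entity_id":      entityID,   // assumed to already exist
		"mount_accessor": ghAccessor, // accessor of an enabled auth mount
		"name":           "githubuser",
	},
}
resp, err := is.HandleRequest(personaReq)
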

vault/identity_store_util.go (new file, 2122 lines): diff suppressed because it is too large

View File

@@ -0,0 +1,40 @@
package vault
import (
"reflect"
"testing"
)
func TestIdentityStore_parseMetadata(t *testing.T) {
goodKVs := []string{
"key1=value1",
"key2=value1=value2",
}
expectedMap := map[string]string{
"key1": "value1",
"key2": "value1=value2",
}
actualMap, err := parseMetadata(goodKVs)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expectedMap, actualMap) {
t.Fatalf("bad: metadata; expected: %#v\n, actual: %#v\n", expectedMap, actualMap)
}
badKV := []string{
"=world",
}
actualMap, err = parseMetadata(badKV)
if err == nil {
t.Fatalf("expected an error; got: %#v", actualMap)
}
badKV[0] = "world"
actualMap, err = parseMetadata(badKV)
if err == nil {
t.Fatalf("expected an error: %#v", actualMap)
}
}
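
parseMetadata itself lives in identity_store_util.go, whose diff is suppressed above. The behavior this test pins down is a split on the first '=' with a non-empty key, keeping any further '=' characters in the value. A minimal sketch consistent with those expectations, not necessarily the actual implementation; strings and fmt imports are assumed:

// Minimal sketch matching the test: "key2=value1=value2" parses to
// key "key2" and value "value1=value2"; "=world" and "world" both error.
func parseMetadataSketch(keyPairs []string) (map[string]string, error) {
	meta := make(map[string]string, len(keyPairs))
	for _, pair := range keyPairs {
		idx := strings.Index(pair, "=")
		if idx < 1 {
			return nil, fmt.Errorf("invalid metadata entry %q", pair)
		}
		meta[pair[:idx]] = pair[idx+1:]
	}
	return meta, nil
}
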

View File

@@ -149,6 +149,17 @@ func TestSystemBackend_mounts(t *testing.T) {
},
"local": true,
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"accessor": resp.Data["identity/"].(map[string]interface{})["accessor"],
"config": map[string]interface{}{
"default_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64),
"max_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64),
"force_no_cache": false,
},
"local": false,
},
}
if !reflect.DeepEqual(resp.Data, exp) {
t.Fatalf("Got:\n%#v\nExpected:\n%#v", resp.Data, exp)

View File

@@ -50,12 +50,14 @@ var (
"auth/",
"sys/",
"cubbyhole/",
"identity/",
}
untunableMounts = []string{
"cubbyhole/",
"sys/",
"audit/",
"identity/",
}
// singletonMounts can only exist in one location and are
@@ -64,6 +66,7 @@ var (
"cubbyhole",
"system",
"token",
"identity",
}
// mountAliases maps old backend names to new backend names, allowing us
@@ -255,6 +258,8 @@ func (c *Core) mount(entry *MountEntry) error {
return err
}
c.setCoreBackend(entry, backend, view)
newTable := c.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistMounts(newTable, entry.Local); err != nil {
@@ -713,14 +718,8 @@ func (c *Core) setupMounts() error {
return err
}
switch entry.Type {
case "system":
c.systemBarrierView = view
case "cubbyhole":
ch := backend.(*CubbyholeBackend)
ch.saltUUID = entry.UUID
ch.storageView = view
}
c.setCoreBackend(entry, backend, view)
ROUTER_MOUNT:
// Mount the backend
err = c.router.Mount(backend, entry.Path, entry, view)
@@ -865,8 +864,29 @@ func (c *Core) requiredMountTable() *MountTable {
UUID: sysUUID,
Accessor: sysAccessor,
}
identityUUID, err := uuid.GenerateUUID()
if err != nil {
panic(fmt.Sprintf("could not create identity mount entry UUID: %v", err))
}
identityAccessor, err := c.generateMountAccessor("identity")
if err != nil {
panic(fmt.Sprintf("could not generate identity accessor: %v", err))
}
identityMount := &MountEntry{
Table: mountTableType,
Path: "identity/",
Type: "identity",
Description: "identity store",
UUID: identityUUID,
Accessor: identityAccessor,
}
table.Entries = append(table.Entries, cubbyholeMount)
table.Entries = append(table.Entries, sysMount)
table.Entries = append(table.Entries, identityMount)
return table
}
@@ -898,3 +918,17 @@ func (c *Core) singletonMountTables() (mounts, auth *MountTable) {
return
}
func (c *Core) setCoreBackend(entry *MountEntry, backend logical.Backend, view *BarrierView) {
switch entry.Type {
case "system":
c.systemBackend = backend.(*SystemBackend)
c.systemBarrierView = view
case "cubbyhole":
ch := backend.(*CubbyholeBackend)
ch.saltUUID = entry.UUID
ch.storageView = view
case "identity":
c.identityStore = backend.(*IdentityStore)
}
}

View File

@@ -592,31 +592,26 @@ func testCore_MountTable_UpgradeToTyped_Common(
}
func verifyDefaultTable(t *testing.T, table *MountTable) {
if len(table.Entries) != 3 {
if len(table.Entries) != 4 {
t.Fatalf("bad: %v", table.Entries)
}
table.sortEntriesByPath()
for idx, entry := range table.Entries {
switch idx {
case 0:
if entry.Path != "cubbyhole/" {
t.Fatalf("bad: %v", entry)
}
for _, entry := range table.Entries {
switch entry.Path {
case "cubbyhole/":
if entry.Type != "cubbyhole" {
t.Fatalf("bad: %v", entry)
}
case 1:
if entry.Path != "secret/" {
t.Fatalf("bad: %v", entry)
}
case "secret/":
if entry.Type != "kv" {
t.Fatalf("bad: %v", entry)
}
case 2:
if entry.Path != "sys/" {
case "sys/":
if entry.Type != "system" {
t.Fatalf("bad: %v", entry)
}
if entry.Type != "system" {
case "identity/":
if entry.Type != "identity" {
t.Fatalf("bad: %v", entry)
}
}
@@ -637,12 +632,13 @@ func TestSingletonMountTableFunc(t *testing.T) {
mounts, auth := c.singletonMountTables()
if len(mounts.Entries) != 1 {
if len(mounts.Entries) != 2 {
t.Fatal("length of mounts is wrong")
}
for _, entry := range mounts.Entries {
switch entry.Type {
case "system":
case "identity":
default:
t.Fatalf("unknown type %s", entry.Type)
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/helper/strutil"
@@ -66,7 +67,7 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err
resp.WrapInfo.TTL != 0
if wrapping {
cubbyResp, cubbyErr := c.wrapInCubbyhole(req, resp)
cubbyResp, cubbyErr := c.wrapInCubbyhole(req, resp, auth)
// If not successful, returns either an error response from the
// cubbyhole backend or an error; if either is set, set resp and err to
// those and continue so that that's what we audit log. Otherwise
@@ -387,8 +388,42 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
// If the response generated an authentication, then generate the token
var auth *logical.Auth
if resp != nil && resp.Auth != nil {
var entity *identity.Entity
auth = resp.Auth
if auth.Alias != nil {
// Overwrite the mount type and mount path in the alias
// information
auth.Alias.MountType = req.MountType
auth.Alias.MountAccessor = req.MountAccessor
if auth.Alias.Name == "" {
return nil, nil, fmt.Errorf("missing name in alias")
}
var err error
// Check if an entity already exists for the given alias
entity, err = c.identityStore.EntityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false)
if err != nil {
return nil, nil, err
}
// If not, create one.
if entity == nil {
c.logger.Debug("core: creating a new entity", "alias", auth.Alias)
entity, err = c.identityStore.CreateEntity(auth.Alias)
if err != nil {
return nil, nil, err
}
if entity == nil {
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}
}
auth.EntityID = entity.ID
}
if strutil.StrListSubset(auth.Policies, []string{"root"}) {
return logical.ErrorResponse("authentication backends cannot create root tokens"), nil, logical.ErrInvalidRequest
}

View File

@@ -47,6 +47,31 @@ type routeEntry struct {
loginPaths *radix.Tree
}
type validateMountResponse struct {
MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"`
MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"`
MountPath string `json:"mount_path" structs:"mount_path" mapstructure:"mount_path"`
}
// validateMountByAccessor returns the mount type and ID for a given mount
// accessor
func (r *Router) validateMountByAccessor(accessor string) *validateMountResponse {
if accessor == "" {
return nil
}
mountEntry := r.MatchingMountByAccessor(accessor)
if mountEntry == nil {
return nil
}
return &validateMountResponse{
MountAccessor: mountEntry.Accessor,
MountType: mountEntry.Type,
MountPath: mountEntry.Path,
}
}
// SaltID is used to apply a salt and hash to an ID to make sure its not reversible
func (re *routeEntry) SaltID(id string) string {
return salt.SaltID(re.mountEntry.UUID, id, salt.SHA1Hash)
@@ -330,6 +355,17 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
// Attach the storage view for the request
req.Storage = re.storageView
originalEntityID := req.EntityID
// Allow EntityID to passthrough to the system backend. This is required to
// allow clients to generate MFA credentials in respective entity objects
// in identity store via the system backend.
switch {
case strings.HasPrefix(originalPath, "sys/"):
default:
req.EntityID = ""
}
// Hash the request token unless this is the token backend
clientToken := req.ClientToken
switch {
@@ -385,6 +421,12 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
// This is only set in one place, after routing, so should never be set
// by a backend
req.SetLastRemoteWAL(0)
// This will be used for attaching the mount accessor for the identities
// returned by the authentication backends
req.MountAccessor = re.mountEntry.Accessor
req.EntityID = originalEntityID
}()
// Invoke the backend
@@ -393,6 +435,11 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
return nil, ok, exists, err
} else {
resp, err := re.backend.HandleRequest(req)
if resp != nil &&
resp.Auth != nil &&
resp.Auth.Alias != nil {
resp.Auth.Alias.MountAccessor = re.mountEntry.Accessor
}
return resp, false, false, err
}
}

View File

@@ -154,6 +154,15 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo
noopBackends["http"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return new(rawHTTP), nil
}
credentialBackends := make(map[string]logical.Factory)
for backendName, backendFactory := range noopBackends {
credentialBackends[backendName] = backendFactory
}
for backendName, backendFactory := range testCredentialBackends {
credentialBackends[backendName] = backendFactory
}
logicalBackends := make(map[string]logical.Factory)
for backendName, backendFactory := range noopBackends {
logicalBackends[backendName] = backendFactory
@@ -167,7 +176,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo
Physical: physicalBackend,
AuditBackends: noopAudits,
LogicalBackends: logicalBackends,
CredentialBackends: noopBackends,
CredentialBackends: credentialBackends,
DisableMlock: true,
Logger: logger,
}
@@ -375,6 +384,7 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
}
var testLogicalBackends = map[string]logical.Factory{}
var testCredentialBackends = map[string]logical.Factory{}
// Starts the test server which responds to SSH authentication.
// Used to test the SSH secret backend.
@@ -467,6 +477,19 @@ func executeServerCommand(ch ssh.Channel, req *ssh.Request) {
}()
}
// This adds a credential backend for the test core. This needs to be
// invoked before the test core is created.
func AddTestCredentialBackend(name string, factory logical.Factory) error {
if name == "" {
return fmt.Errorf("missing backend name")
}
if factory == nil {
return fmt.Errorf("missing backend factory function")
}
testCredentialBackends[name] = factory
return nil
}
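// Example (sketch; the factory name is hypothetical): register the backend in
// an init function so it exists before any test core is constructed, at which
// point testCoreConfig merges it into CredentialBackends:
//
//	func init() {
//		if err := vault.AddTestCredentialBackend("noop-cred", noopCredFactory); err != nil {
//			panic(err)
//		}
//	}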
// This adds a logical backend for the test core. This needs to be
// invoked before the test core is created.
func AddTestLogicalBackend(name string, factory logical.Factory) error {

View File

@@ -581,6 +581,8 @@ type TokenEntry struct {
NumUsesDeprecated int `json:"NumUses" mapstructure:"NumUses" structs:"NumUses"`
CreationTimeDeprecated int64 `json:"CreationTime" mapstructure:"CreationTime" structs:"CreationTime"`
ExplicitMaxTTLDeprecated time.Duration `json:"ExplicitMaxTTL" mapstructure:"ExplicitMaxTTL" structs:"ExplicitMaxTTL"`
EntityID string `json:"entity_id" mapstructure:"entity_id" structs:"entity_id"`
}
// tsRoleEntry contains token store role information
@@ -1730,6 +1732,13 @@ func (ts *TokenStore) handleCreateCommon(
}
}
// At this point, it is clear whether the token is going to be an orphan or
// not. If the token is not going to be an orphan, inherit the parent's
// entity identifier into the child token.
if te.Parent != "" {
te.EntityID = parent.EntityID
}
if data.ExplicitMaxTTL != "" {
dur, err := parseutil.ParseDurationSecond(data.ExplicitMaxTTL)
if err != nil {
@@ -1872,6 +1881,7 @@ func (ts *TokenStore) handleCreateCommon(
},
ClientToken: te.ID,
Accessor: te.Accessor,
EntityID: te.EntityID,
}
if ts.policyLookupFunc != nil {
@@ -2037,6 +2047,7 @@ func (ts *TokenStore) handleLookup(
"expire_time": nil,
"ttl": int64(0),
"explicit_max_ttl": int64(out.ExplicitMaxTTL.Seconds()),
"entity_id": out.EntityID,
},
}

View File

@@ -1448,6 +1448,7 @@ func TestTokenStore_HandleRequest_Lookup(t *testing.T) {
"ttl": int64(0),
"explicit_max_ttl": int64(0),
"expire_time": nil,
"entity_id": "",
}
if resp.Data["creation_time"].(int64) == 0 {
@@ -1487,6 +1488,7 @@ func TestTokenStore_HandleRequest_Lookup(t *testing.T) {
"ttl": int64(3600),
"explicit_max_ttl": int64(0),
"renewable": true,
"entity_id": "",
}
if resp.Data["creation_time"].(int64) == 0 {
@@ -1537,6 +1539,7 @@ func TestTokenStore_HandleRequest_Lookup(t *testing.T) {
"ttl": int64(3600),
"explicit_max_ttl": int64(0),
"renewable": true,
"entity_id": "",
}
if resp.Data["creation_time"].(int64) == 0 {
@@ -1618,6 +1621,7 @@ func TestTokenStore_HandleRequest_LookupSelf(t *testing.T) {
"creation_ttl": int64(3600),
"ttl": int64(3600),
"explicit_max_ttl": int64(0),
"entity_id": "",
}
if resp.Data["creation_time"].(int64) == 0 {

View File

@@ -72,7 +72,7 @@ func (c *Core) ensureWrappingKey() error {
return nil
}
func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*logical.Response, error) {
func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response, auth *logical.Auth) (*logical.Response, error) {
// Before wrapping, obey special rules for listing: if no entries are
// found, 404. This prevents unwrapping only to find empty data.
if req.Operation == logical.ListOperation {
@@ -120,6 +120,10 @@ func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*l
resp.WrapInfo.CreationPath = req.Path
}
if auth != nil && auth.EntityID != "" {
resp.WrapInfo.WrappedEntityID = auth.EntityID
}
// This will only be non-nil if this response contains a token, so in that
// case put the accessor in the wrap info.
if resp.Auth != nil {
@@ -223,7 +227,7 @@ func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*l
return cubbyResp, nil
}
auth := &logical.Auth{
wAuth := &logical.Auth{
ClientToken: te.ID,
Policies: []string{"response-wrapping"},
LeaseOptions: logical.LeaseOptions{
@@ -233,7 +237,7 @@ func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*l
}
// Register the wrapped token with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
if err := c.expiration.RegisterAuth(te.Path, wAuth); err != nil {
// Revoke since it's not yet being tracked for expiration
c.tokenStore.Revoke(te.ID)
c.logger.Error("core: failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err)

363
vendor/github.com/hashicorp/go-memdb/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,363 @@
Mozilla Public License, version 2.0
1. Definitions
1.1. "Contributor"
means each individual or legal entity that creates, contributes to the
creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used by a
Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached the
notice in Exhibit A, the Executable Form of such Source Code Form, and
Modifications of such Source Code Form, in each case including portions
thereof.
1.5. "Incompatible With Secondary Licenses"
means
a. that the initial Contributor has attached the notice described in
Exhibit B to the Covered Software; or
b. that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the terms of
a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in a
separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible, whether
at the time of the initial grant or subsequently, any and all of the
rights conveyed by this License.
1.10. "Modifications"
means any of the following:
a. any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered Software; or
b. any new file in Source Code Form that contains any Covered Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the License,
by the making, using, selling, offering for sale, having made, import,
or transfer of either its Contributions or its Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU Lesser
General Public License, Version 2.1, the GNU Affero General Public
License, Version 3.0, or any later versions of those licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that controls, is
controlled by, or is under common control with You. For purposes of this
definition, "control" means (a) the power, direct or indirect, to cause
the direction or management of such entity, whether by contract or
otherwise, or (b) ownership of more than fifty percent (50%) of the
outstanding shares or beneficial ownership of such entity.
2. License Grants and Conditions
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
a. under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
b. under Patent Claims of such Contributor to make, use, sell, offer for
sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
a. for any code that a Contributor has removed from Covered Software; or
b. for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
c. under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights to
grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
Section 2.1.
3. Responsibilities
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
a. such Covered Software must also be made available in Source Code Form,
as described in Section 3.1, and You must inform recipients of the
Executable Form how they can obtain a copy of such Source Code Form by
reasonable means in a timely manner, at a charge no more than the cost
of distribution to the recipient; and
b. You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter the
recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty, or
limitations of liability) contained within the Source Code Form of the
Covered Software, except that You may alter any license notices to the
extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
If it is impossible for You to comply with any of the terms of this License
with respect to some or all of the Covered Software due to statute,
judicial order, or regulation then You must: (a) comply with the terms of
this License to the maximum extent possible; and (b) describe the
limitations and the code they affect. Such description must be placed in a
text file included with all distributions of the Covered Software under
this License. Except to the extent prohibited by statute or regulation,
such description must be sufficiently detailed for a recipient of ordinary
skill to be able to understand it.
5. Termination
5.1. The rights granted under this License will terminate automatically if You
fail to comply with any of its terms. However, if You become compliant,
then the rights granted under this License from a particular Contributor
are reinstated (a) provisionally, unless and until such Contributor
explicitly and finally terminates Your grants, and (b) on an ongoing
basis, if such Contributor fails to notify You of the non-compliance by
some reasonable means prior to 60 days after You have come back into
compliance. Moreover, Your grants from a particular Contributor are
reinstated on an ongoing basis if such Contributor notifies You of the
non-compliance by some reasonable means, this is the first time You have
received notice of non-compliance with this License from such
Contributor, and You become compliant prior to 30 days after Your receipt
of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
license agreements (excluding distributors and resellers) which have been
validly granted by You or Your distributors under this License prior to
termination shall survive termination.
6. Disclaimer of Warranty
Covered Software is provided under this License on an "as is" basis,
without warranty of any kind, either expressed, implied, or statutory,
including, without limitation, warranties that the Covered Software is free
of defects, merchantable, fit for a particular purpose or non-infringing.
The entire risk as to the quality and performance of the Covered Software
is with You. Should any Covered Software prove defective in any respect,
You (not any Contributor) assume the cost of any necessary servicing,
repair, or correction. This disclaimer of warranty constitutes an essential
part of this License. No use of any Covered Software is authorized under
this License except under this disclaimer.
7. Limitation of Liability
Under no circumstances and under no legal theory, whether tort (including
negligence), contract, or otherwise, shall any Contributor, or anyone who
distributes Covered Software as permitted above, be liable to You for any
direct, indirect, special, incidental, or consequential damages of any
character including, without limitation, damages for lost profits, loss of
goodwill, work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses, even if such party shall have been
informed of the possibility of such damages. This limitation of liability
shall not apply to liability for death or personal injury resulting from
such party's negligence to the extent applicable law prohibits such
limitation. Some jurisdictions do not allow the exclusion or limitation of
incidental or consequential damages, so this exclusion and limitation may
not apply to You.
8. Litigation
Any litigation relating to this License may be brought only in the courts
of a jurisdiction where the defendant maintains its principal place of
business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions. Nothing
in this Section shall prevent a party's ability to bring cross-claims or
counter-claims.
9. Miscellaneous
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides that
the language of a contract shall be construed against the drafter shall not
be used to construe this License against a Contributor.
10. Versions of the License
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses If You choose to distribute Source Code Form that is
Incompatible With Secondary Licenses under the terms of this version of
the License, the notice described in Exhibit B of this License must be
attached.
Exhibit A - Source Code Form License Notice
This Source Code Form is subject to the
terms of the Mozilla Public License, v.
2.0. If a copy of the MPL was not
distributed with this file, You can
obtain one at
http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular file,
then You may include the notice in a location (such as a LICENSE file in a
relevant directory) where a recipient would be likely to look for such a
notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
This Source Code Form is "Incompatible
With Secondary Licenses", as defined by
the Mozilla Public License, v. 2.0.

98
vendor/github.com/hashicorp/go-memdb/README.md generated vendored Normal file
View File

@@ -0,0 +1,98 @@
# go-memdb
Provides the `memdb` package that implements a simple in-memory database
built on immutable radix trees. The database provides Atomicity, Consistency
and Isolation from ACID. Because it is in-memory, it does not provide durability.
The database is instantiated with a schema that specifies the tables and indices
that exist and allows transactions to be executed.
The database provides the following:
* Multi-Version Concurrency Control (MVCC) - By leveraging immutable radix trees
the database is able to support any number of concurrent readers without locking,
and allows a writer to make progress.
* Transaction Support - The database allows for rich transactions, in which multiple
objects are inserted, updated or deleted. The transactions can span multiple tables,
and are applied atomically. The database provides atomicity and isolation in ACID
terminology, such that until commit the updates are not visible.
* Rich Indexing - Tables can support any number of indexes, which can be simple like
a single field index, or more advanced compound field indexes. Certain types like
UUID can be efficiently compressed from strings into byte indexes for reduced
storage requirements.
* Watches - Callers can populate a watch set as part of a query, which can be used to
detect when a modification has been made to the database which affects the query
results. This lets callers easily watch for changes in the database in a very general
way.
For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix).
Documentation
=============
The full documentation is available on [Godoc](http://godoc.org/github.com/hashicorp/go-memdb).
Example
=======
Below is a simple example of usage
```go
// Create a sample struct
type Person struct {
Email string
Name string
Age int
}
// Create the DB schema
schema := &memdb.DBSchema{
Tables: map[string]*memdb.TableSchema{
"person": &memdb.TableSchema{
Name: "person",
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
Unique: true,
Indexer: &memdb.StringFieldIndex{Field: "Email"},
},
},
},
},
}
// Create a new database
db, err := memdb.NewMemDB(schema)
if err != nil {
panic(err)
}
// Create a write transaction
txn := db.Txn(true)
// Insert a new person
p := &Person{"joe@aol.com", "Joe", 30}
if err := txn.Insert("person", p); err != nil {
panic(err)
}
// Commit the transaction
txn.Commit()
// Create read-only transaction
txn = db.Txn(false)
defer txn.Abort()
// Lookup by email
raw, err := txn.First("person", "id", "joe@aol.com")
if err != nil {
panic(err)
}
// Say hi!
fmt.Printf("Hello %s!", raw.(*Person).Name)
```
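
For the watch support mentioned above, a sketch of registering interest through a result iterator's watch channel, reusing the `Person` schema from the example; the channel fires when a committed write affects rows covered by the query:

```go
// Open a read transaction and iterate the whole table
txn = db.Txn(false)
it, err := txn.Get("person", "id")
if err != nil {
	panic(err)
}
ws := it.WatchCh()
txn.Abort()

// Block until some committed write touches the query results
<-ws
fmt.Println("person table changed")
```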

33
vendor/github.com/hashicorp/go-memdb/filter.go generated vendored Normal file
View File

@@ -0,0 +1,33 @@
package memdb
// FilterFunc is a function that takes the results of an iterator and returns
// whether the result should be filtered out.
type FilterFunc func(interface{}) bool
// FilterIterator is used to wrap a ResultIterator and apply a filter over it.
type FilterIterator struct {
// filter is the filter function applied over the base iterator.
filter FilterFunc
// iter is the iterator that is being wrapped.
iter ResultIterator
}
func NewFilterIterator(wrap ResultIterator, filter FilterFunc) *FilterIterator {
return &FilterIterator{
filter: filter,
iter: wrap,
}
}
// WatchCh returns the watch channel of the wrapped iterator.
func (f *FilterIterator) WatchCh() <-chan struct{} { return f.iter.WatchCh() }
// Next returns the next non-filtered result from the wrapped iterator
func (f *FilterIterator) Next() interface{} {
for {
if value := f.iter.Next(); value == nil || !f.filter(value) {
return value
}
}
}
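
A sketch of wrapping an iterator with `NewFilterIterator`; note that the filter returning true means the result is dropped (`Person` as in the README example):

```go
it, err := txn.Get("person", "id")
if err != nil {
	panic(err)
}
// Drop anyone under 18 from the results
adults := memdb.NewFilterIterator(it, func(raw interface{}) bool {
	return raw.(*Person).Age < 18
})
for obj := adults.Next(); obj != nil; obj = adults.Next() {
	fmt.Println(obj.(*Person).Name)
}
```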

569
vendor/github.com/hashicorp/go-memdb/index.go generated vendored Normal file
View File

@@ -0,0 +1,569 @@
package memdb
import (
"encoding/binary"
"encoding/hex"
"fmt"
"reflect"
"strings"
)
// Indexer is an interface used for defining indexes
type Indexer interface {
// FromArgs is used to build an exact index lookup
// based on arguments
FromArgs(args ...interface{}) ([]byte, error)
}
// SingleIndexer is an interface used for defining indexes
// generating a single entry per object
type SingleIndexer interface {
// FromObject is used to extract an index value from an
// object or to indicate that the index value is missing.
FromObject(raw interface{}) (bool, []byte, error)
}
// MultiIndexer is an interface used for defining indexes
// generating multiple entries per object
type MultiIndexer interface {
// FromObject is used to extract index values from an
// object or to indicate that the index value is missing.
FromObject(raw interface{}) (bool, [][]byte, error)
}
// PrefixIndexer can optionally be implemented for any
// indexes that support prefix based iteration. This may
// not apply to all indexes.
type PrefixIndexer interface {
// PrefixFromArgs returns a prefix that should be used
// for scanning based on the arguments
PrefixFromArgs(args ...interface{}) ([]byte, error)
}
// StringFieldIndex is used to extract a field from an object
// using reflection and builds an index on that field.
type StringFieldIndex struct {
Field string
Lowercase bool
}
func (s *StringFieldIndex) FromObject(obj interface{}) (bool, []byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(s.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj)
}
val := fv.String()
if val == "" {
return false, nil, nil
}
if s.Lowercase {
val = strings.ToLower(val)
}
// Add the null character as a terminator
val += "\x00"
return true, []byte(val), nil
}
func (s *StringFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
arg, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
if s.Lowercase {
arg = strings.ToLower(arg)
}
// Add the null character as a terminator
arg += "\x00"
return []byte(arg), nil
}
func (s *StringFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
val, err := s.FromArgs(args...)
if err != nil {
return nil, err
}
// Strip the null terminator, the rest is a prefix
n := len(val)
if n > 0 {
return val[:n-1], nil
}
return val, nil
}
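// Example (sketch): because StringFieldIndex implements PrefixIndexer, a
// lookup can use the "<index>_prefix" naming convention (see DeletePrefix in
// txn.go below) to scan by prefix instead of requiring an exact match:
//
//	it, err := txn.Get("person", "id_prefix", "joe@")
//
// which routes through PrefixFromArgs above, with the null terminator stripped.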
// StringSliceFieldIndex is used to extract a field from an object
// using reflection and builds an index on that field.
type StringSliceFieldIndex struct {
Field string
Lowercase bool
}
func (s *StringSliceFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(s.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj)
}
if fv.Kind() != reflect.Slice || fv.Type().Elem().Kind() != reflect.String {
return false, nil, fmt.Errorf("field '%s' is not a string slice", s.Field)
}
length := fv.Len()
vals := make([][]byte, 0, length)
for i := 0; i < fv.Len(); i++ {
val := fv.Index(i).String()
if val == "" {
continue
}
if s.Lowercase {
val = strings.ToLower(val)
}
// Add the null character as a terminator
val += "\x00"
vals = append(vals, []byte(val))
}
if len(vals) == 0 {
return false, nil, nil
}
return true, vals, nil
}
func (s *StringSliceFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
arg, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
if s.Lowercase {
arg = strings.ToLower(arg)
}
// Add the null character as a terminator
arg += "\x00"
return []byte(arg), nil
}
func (s *StringSliceFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
val, err := s.FromArgs(args...)
if err != nil {
return nil, err
}
// Strip the null terminator, the rest is a prefix
n := len(val)
if n > 0 {
return val[:n-1], nil
}
return val, nil
}
// StringMapFieldIndex is used to extract a field of type map[string]string
// from an object using reflection and builds an index on that field.
type StringMapFieldIndex struct {
Field string
Lowercase bool
}
var MapType = reflect.MapOf(reflect.TypeOf(""), reflect.TypeOf("")).Kind()
func (s *StringMapFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(s.Field)
if !fv.IsValid() {
return false, nil, fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj)
}
if fv.Kind() != MapType {
return false, nil, fmt.Errorf("field '%s' is not a map[string]string", s.Field)
}
length := fv.Len()
vals := make([][]byte, 0, length)
for _, key := range fv.MapKeys() {
k := key.String()
if k == "" {
continue
}
val := fv.MapIndex(key).String()
if s.Lowercase {
k = strings.ToLower(k)
val = strings.ToLower(val)
}
// Add the null character as a terminator
k += "\x00" + val + "\x00"
vals = append(vals, []byte(k))
}
if len(vals) == 0 {
return false, nil, nil
}
return true, vals, nil
}
func (s *StringMapFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) > 2 || len(args) == 0 {
return nil, fmt.Errorf("must provide one or two arguments")
}
key, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
if s.Lowercase {
key = strings.ToLower(key)
}
// Add the null character as a terminator
key += "\x00"
if len(args) == 2 {
val, ok := args[1].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[1])
}
if s.Lowercase {
val = strings.ToLower(val)
}
// Add the null character as a terminator
key += val + "\x00"
}
return []byte(key), nil
}
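// Example (sketch; the index name "meta" is hypothetical): one argument
// encodes only the key and, since lookups seek by prefix on the encoded
// value, matches every entry carrying that key; two arguments pin the pair:
//
//	it, _ := txn.Get("person", "meta", "team")             // any team=* entry
//	obj, _ := txn.First("person", "meta", "team", "vault") // exact team=vault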
// UintFieldIndex is used to extract a uint field from an object using
// reflection and builds an index on that field.
type UintFieldIndex struct {
Field string
}
func (u *UintFieldIndex) FromObject(obj interface{}) (bool, []byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(u.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj)
}
// Check the type
k := fv.Kind()
size, ok := IsUintType(k)
if !ok {
return false, nil, fmt.Errorf("field %q is of type %v; want a uint", u.Field, k)
}
// Get the value and encode it
val := fv.Uint()
buf := make([]byte, size)
binary.PutUvarint(buf, val)
return true, buf, nil
}
func (u *UintFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
v := reflect.ValueOf(args[0])
if !v.IsValid() {
return nil, fmt.Errorf("%#v is invalid", args[0])
}
k := v.Kind()
size, ok := IsUintType(k)
if !ok {
return nil, fmt.Errorf("arg is of type %v; want a uint", k)
}
val := v.Uint()
buf := make([]byte, size)
binary.PutUvarint(buf, val)
return buf, nil
}
// IsUintType returns whether the passed type is a type of uint and the number
// of bytes needed to encode the type.
func IsUintType(k reflect.Kind) (size int, okay bool) {
switch k {
case reflect.Uint:
return binary.MaxVarintLen64, true
case reflect.Uint8:
return 2, true
case reflect.Uint16:
return binary.MaxVarintLen16, true
case reflect.Uint32:
return binary.MaxVarintLen32, true
case reflect.Uint64:
return binary.MaxVarintLen64, true
default:
return 0, false
}
}
// UUIDFieldIndex is used to extract a field from an object
// using reflection and builds an index on that field by treating
// it as a UUID. This is an optimization to using a StringFieldIndex
// as the UUID can be more compactly represented in byte form.
type UUIDFieldIndex struct {
Field string
}
func (u *UUIDFieldIndex) FromObject(obj interface{}) (bool, []byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(u.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj)
}
val := fv.String()
if val == "" {
return false, nil, nil
}
buf, err := u.parseString(val, true)
return true, buf, err
}
func (u *UUIDFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
switch arg := args[0].(type) {
case string:
return u.parseString(arg, true)
case []byte:
if len(arg) != 16 {
return nil, fmt.Errorf("byte slice must be 16 characters")
}
return arg, nil
default:
return nil,
fmt.Errorf("argument must be a string or byte slice: %#v", args[0])
}
}
func (u *UUIDFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
switch arg := args[0].(type) {
case string:
return u.parseString(arg, false)
case []byte:
return arg, nil
default:
return nil,
fmt.Errorf("argument must be a string or byte slice: %#v", args[0])
}
}
// parseString parses a UUID from the string. If enforceLength is false, it will
// parse a partial UUID. An error is returned if the input, stripped of hyphens,
// is not even length.
func (u *UUIDFieldIndex) parseString(s string, enforceLength bool) ([]byte, error) {
// Verify the length
l := len(s)
if enforceLength && l != 36 {
return nil, fmt.Errorf("UUID must be 36 characters")
} else if l > 36 {
return nil, fmt.Errorf("Invalid UUID length. UUID have 36 characters; got %d", l)
}
hyphens := strings.Count(s, "-")
if hyphens > 4 {
return nil, fmt.Errorf(`UUID should have maximum of 4 "-"; got %d`, hyphens)
}
// The sanitized length is the length of the original string without the "-".
sanitized := strings.Replace(s, "-", "", -1)
sanitizedLength := len(sanitized)
if sanitizedLength%2 != 0 {
return nil, fmt.Errorf("Input (without hyphens) must be even length")
}
dec, err := hex.DecodeString(sanitized)
if err != nil {
return nil, fmt.Errorf("Invalid UUID: %v", err)
}
return dec, nil
}
// FieldSetIndex is used to extract a field from an object using reflection and
// builds an index on whether the field is set by comparing it against its
// type's nil value.
type FieldSetIndex struct {
Field string
}
func (f *FieldSetIndex) FromObject(obj interface{}) (bool, []byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(f.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", f.Field, obj)
}
if fv.Interface() == reflect.Zero(fv.Type()).Interface() {
return true, []byte{0}, nil
}
return true, []byte{1}, nil
}
func (f *FieldSetIndex) FromArgs(args ...interface{}) ([]byte, error) {
return fromBoolArgs(args)
}
// ConditionalIndex builds an index based on a condition specified by a passed
// user function. This function may examine the passed object and return a
// boolean to encapsulate an arbitrarily complex conditional.
type ConditionalIndex struct {
Conditional ConditionalIndexFunc
}
// ConditionalIndexFunc is the required function interface for a
// ConditionalIndex.
type ConditionalIndexFunc func(obj interface{}) (bool, error)
func (c *ConditionalIndex) FromObject(obj interface{}) (bool, []byte, error) {
// Call the user's function
res, err := c.Conditional(obj)
if err != nil {
return false, nil, fmt.Errorf("ConditionalIndexFunc(%#v) failed: %v", obj, err)
}
if res {
return true, []byte{1}, nil
}
return true, []byte{0}, nil
}
func (c *ConditionalIndex) FromArgs(args ...interface{}) ([]byte, error) {
return fromBoolArgs(args)
}
// fromBoolArgs is a helper that expects only a single boolean argument and
// returns a single length byte array containing either a one or zero depending
// on whether the passed input is true or false respectively.
func fromBoolArgs(args []interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
if val, ok := args[0].(bool); !ok {
return nil, fmt.Errorf("argument must be a boolean type: %#v", args[0])
} else if val {
return []byte{1}, nil
}
return []byte{0}, nil
}
// CompoundIndex is used to build an index using multiple sub-indexes
// Prefix based iteration is supported as long as the appropriate prefix
// of indexers support it. All sub-indexers are only assumed to expect
// a single argument.
type CompoundIndex struct {
Indexes []Indexer
// AllowMissing results in an index based on only the indexers
// that return data. If true, you may end up with 2/3 columns
// indexed which might be useful for an index scan. Otherwise,
// the CompoundIndex requires all indexers to be satisfied.
AllowMissing bool
}
func (c *CompoundIndex) FromObject(raw interface{}) (bool, []byte, error) {
var out []byte
for i, idxRaw := range c.Indexes {
idx, ok := idxRaw.(SingleIndexer)
if !ok {
return false, nil, fmt.Errorf("sub-index %d error: %s", i, "sub-index must be a SingleIndexer")
}
ok, val, err := idx.FromObject(raw)
if err != nil {
return false, nil, fmt.Errorf("sub-index %d error: %v", i, err)
}
if !ok {
if c.AllowMissing {
break
} else {
return false, nil, nil
}
}
out = append(out, val...)
}
return true, out, nil
}
func (c *CompoundIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != len(c.Indexes) {
return nil, fmt.Errorf("less arguments than index fields")
}
var out []byte
for i, arg := range args {
val, err := c.Indexes[i].FromArgs(arg)
if err != nil {
return nil, fmt.Errorf("sub-index %d error: %v", i, err)
}
out = append(out, val...)
}
return out, nil
}
func (c *CompoundIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
if len(args) > len(c.Indexes) {
return nil, fmt.Errorf("more arguments than index fields")
}
var out []byte
for i, arg := range args {
if i+1 < len(args) {
val, err := c.Indexes[i].FromArgs(arg)
if err != nil {
return nil, fmt.Errorf("sub-index %d error: %v", i, err)
}
out = append(out, val...)
} else {
prefixIndexer, ok := c.Indexes[i].(PrefixIndexer)
if !ok {
return nil, fmt.Errorf("sub-index %d does not support prefix scanning", i)
}
val, err := prefixIndexer.PrefixFromArgs(arg)
if err != nil {
return nil, fmt.Errorf("sub-index %d error: %v", i, err)
}
out = append(out, val...)
}
}
return out, nil
}
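
A sketch of wiring a `CompoundIndex` into a schema; the struct and index names are illustrative, and `Age` is declared as a `uint` so that `UintFieldIndex` accepts it:

```go
type Employee struct {
	ID   string
	Name string
	Age  uint
}

schema := &memdb.DBSchema{
	Tables: map[string]*memdb.TableSchema{
		"employee": {
			Name: "employee",
			Indexes: map[string]*memdb.IndexSchema{
				"id": {
					Name:    "id",
					Unique:  true,
					Indexer: &memdb.StringFieldIndex{Field: "ID"},
				},
				"name_age": {
					Name:   "name_age",
					Unique: true,
					Indexer: &memdb.CompoundIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{Field: "Name"},
							&memdb.UintFieldIndex{Field: "Age"},
						},
					},
				},
			},
		},
	},
}

// Exact lookups supply one argument per sub-indexer:
//	txn.First("employee", "name_age", "Joe", uint(30))
```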

92
vendor/github.com/hashicorp/go-memdb/memdb.go generated vendored Normal file
View File

@@ -0,0 +1,92 @@
package memdb
import (
"sync"
"sync/atomic"
"unsafe"
"github.com/hashicorp/go-immutable-radix"
)
// MemDB is an in-memory database. It provides a table abstraction,
// which is used to store objects (rows) with multiple indexes based
// on values. The database makes use of immutable radix trees to provide
// transactions and MVCC.
type MemDB struct {
schema *DBSchema
root unsafe.Pointer // *iradix.Tree underneath
primary bool
// There can only be a single writer at once
writer sync.Mutex
}
// NewMemDB creates a new MemDB with the given schema
func NewMemDB(schema *DBSchema) (*MemDB, error) {
// Validate the schema
if err := schema.Validate(); err != nil {
return nil, err
}
// Create the MemDB
db := &MemDB{
schema: schema,
root: unsafe.Pointer(iradix.New()),
primary: true,
}
if err := db.initialize(); err != nil {
return nil, err
}
return db, nil
}
// getRoot is used to do an atomic load of the root pointer
func (db *MemDB) getRoot() *iradix.Tree {
root := (*iradix.Tree)(atomic.LoadPointer(&db.root))
return root
}
// Txn is used to start a new transaction, in either read or write mode.
// There can only be a single concurrent writer, but any number of readers.
func (db *MemDB) Txn(write bool) *Txn {
if write {
db.writer.Lock()
}
txn := &Txn{
db: db,
write: write,
rootTxn: db.getRoot().Txn(),
}
return txn
}
// Snapshot is used to capture a point-in-time snapshot
// of the database that will not be affected by any write
// operations to the existing DB.
func (db *MemDB) Snapshot() *MemDB {
clone := &MemDB{
schema: db.schema,
root: unsafe.Pointer(db.getRoot()),
primary: false,
}
return clone
}
// initialize is used to setup the DB for use after creation
func (db *MemDB) initialize() error {
root := db.getRoot()
for tName, tableSchema := range db.schema.Tables {
for iName := range tableSchema.Indexes {
index := iradix.New()
path := indexPath(tName, iName)
root, _, _ = root.Insert(path, index)
}
}
db.root = unsafe.Pointer(root)
return nil
}
// indexPath returns the path from the root to the given table index
func indexPath(table, index string) []byte {
return []byte(table + "." + index)
}
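
A sketch of using `Snapshot` for a stable read view; writes committed to `db` after the snapshot are not visible through `snap`:

```go
snap := db.Snapshot()

// Reads against the snapshot see the state as of the call above
txn := snap.Txn(false)
defer txn.Abort()

raw, err := txn.First("person", "id", "joe@aol.com")
if err != nil {
	panic(err)
}
_ = raw // nil if the row did not exist at snapshot time
```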

85
vendor/github.com/hashicorp/go-memdb/schema.go generated vendored Normal file
View File

@@ -0,0 +1,85 @@
package memdb
import "fmt"
// DBSchema contains the full database schema used for MemDB
type DBSchema struct {
Tables map[string]*TableSchema
}
// Validate is used to validate the database schema
func (s *DBSchema) Validate() error {
if s == nil {
return fmt.Errorf("missing schema")
}
if len(s.Tables) == 0 {
return fmt.Errorf("no tables defined")
}
for name, table := range s.Tables {
if name != table.Name {
return fmt.Errorf("table name mis-match for '%s'", name)
}
if err := table.Validate(); err != nil {
return err
}
}
return nil
}
// TableSchema contains the schema for a single table
type TableSchema struct {
Name string
Indexes map[string]*IndexSchema
}
// Validate is used to validate the table schema
func (s *TableSchema) Validate() error {
if s.Name == "" {
return fmt.Errorf("missing table name")
}
if len(s.Indexes) == 0 {
return fmt.Errorf("missing table indexes for '%s'", s.Name)
}
if _, ok := s.Indexes["id"]; !ok {
return fmt.Errorf("must have id index")
}
if !s.Indexes["id"].Unique {
return fmt.Errorf("id index must be unique")
}
if _, ok := s.Indexes["id"].Indexer.(SingleIndexer); !ok {
return fmt.Errorf("id index must be a SingleIndexer")
}
for name, index := range s.Indexes {
if name != index.Name {
return fmt.Errorf("index name mis-match for '%s'", name)
}
if err := index.Validate(); err != nil {
return err
}
}
return nil
}
// IndexSchema contains the schema for an index
type IndexSchema struct {
Name string
AllowMissing bool
Unique bool
Indexer Indexer
}
func (s *IndexSchema) Validate() error {
if s.Name == "" {
return fmt.Errorf("missing index name")
}
if s.Indexer == nil {
return fmt.Errorf("missing index function for '%s'", s.Name)
}
switch s.Indexer.(type) {
case SingleIndexer:
case MultiIndexer:
default:
return fmt.Errorf("indexer for '%s' must be a SingleIndexer or MultiIndexer", s.Name)
}
return nil
}
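
The `id` index requirements above are the usual tripwire; a sketch of a schema that fails `Validate` because its `id` index is not unique:

```go
schema := &memdb.DBSchema{
	Tables: map[string]*memdb.TableSchema{
		"person": {
			Name: "person",
			Indexes: map[string]*memdb.IndexSchema{
				"id": {
					Name:    "id",
					Unique:  false, // invalid: the id index must be unique
					Indexer: &memdb.StringFieldIndex{Field: "Email"},
				},
			},
		},
	},
}

if err := schema.Validate(); err != nil {
	fmt.Println(err) // "id index must be unique"
}
```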

644
vendor/github.com/hashicorp/go-memdb/txn.go generated vendored Normal file
View File

@@ -0,0 +1,644 @@
package memdb
import (
"bytes"
"fmt"
"strings"
"sync/atomic"
"unsafe"
"github.com/hashicorp/go-immutable-radix"
)
const (
id = "id"
)
var (
// ErrNotFound is returned when the requested item is not found
ErrNotFound = fmt.Errorf("not found")
)
// tableIndex is a tuple of (Table, Index) used for lookups
type tableIndex struct {
Table string
Index string
}
// Txn is a transaction against a MemDB.
// This can be a read or write transaction.
type Txn struct {
db *MemDB
write bool
rootTxn *iradix.Txn
after []func()
modified map[tableIndex]*iradix.Txn
}
// readableIndex returns a transaction usable for reading the given
// index in a table. If a write transaction is in progress, we may need
// to use an existing modified txn.
func (txn *Txn) readableIndex(table, index string) *iradix.Txn {
// Look for existing transaction
if txn.write && txn.modified != nil {
key := tableIndex{table, index}
exist, ok := txn.modified[key]
if ok {
return exist
}
}
// Create a read transaction
path := indexPath(table, index)
raw, _ := txn.rootTxn.Get(path)
indexTxn := raw.(*iradix.Tree).Txn()
return indexTxn
}
// writableIndex returns a transaction usable for modifying the
// given index in a table.
func (txn *Txn) writableIndex(table, index string) *iradix.Txn {
if txn.modified == nil {
txn.modified = make(map[tableIndex]*iradix.Txn)
}
// Look for existing transaction
key := tableIndex{table, index}
exist, ok := txn.modified[key]
if ok {
return exist
}
// Start a new transaction
path := indexPath(table, index)
raw, _ := txn.rootTxn.Get(path)
indexTxn := raw.(*iradix.Tree).Txn()
// If we are the primary DB, enable mutation tracking. Snapshots should
// not notify, otherwise we will trigger watches on the primary DB when
// the writes will not be visible.
indexTxn.TrackMutate(txn.db.primary)
// Keep this open for the duration of the txn
txn.modified[key] = indexTxn
return indexTxn
}
// Abort is used to cancel this transaction.
// This is a noop for read transactions.
func (txn *Txn) Abort() {
// Noop for a read transaction
if !txn.write {
return
}
// Check if already aborted or committed
if txn.rootTxn == nil {
return
}
// Clear the txn
txn.rootTxn = nil
txn.modified = nil
// Release the writer lock since this is invalid
txn.db.writer.Unlock()
}
// Commit is used to finalize this transaction.
// This is a noop for read transactions.
func (txn *Txn) Commit() {
// Noop for a read transaction
if !txn.write {
return
}
// Check if already aborted or committed
if txn.rootTxn == nil {
return
}
// Commit each sub-transaction scoped to (table, index)
for key, subTxn := range txn.modified {
path := indexPath(key.Table, key.Index)
final := subTxn.CommitOnly()
txn.rootTxn.Insert(path, final)
}
// Update the root of the DB
newRoot := txn.rootTxn.CommitOnly()
atomic.StorePointer(&txn.db.root, unsafe.Pointer(newRoot))
// Now issue all of the mutation updates (this is safe to call
// even if mutation tracking isn't enabled); we do this after
// the root pointer is swapped so that waking responders will
// see the new state.
for _, subTxn := range txn.modified {
subTxn.Notify()
}
txn.rootTxn.Notify()
// Clear the txn
txn.rootTxn = nil
txn.modified = nil
// Release the writer lock since this is invalid
txn.db.writer.Unlock()
// Run the deferred functions, if any
for i := len(txn.after); i > 0; i-- {
fn := txn.after[i-1]
fn()
}
}
// Insert is used to add or update an object into the given table
func (txn *Txn) Insert(table string, obj interface{}) error {
if !txn.write {
return fmt.Errorf("cannot insert in read-only transaction")
}
// Get the table schema
tableSchema, ok := txn.db.schema.Tables[table]
if !ok {
return fmt.Errorf("invalid table '%s'", table)
}
// Get the primary ID of the object
idSchema := tableSchema.Indexes[id]
idIndexer := idSchema.Indexer.(SingleIndexer)
ok, idVal, err := idIndexer.FromObject(obj)
if err != nil {
return fmt.Errorf("failed to build primary index: %v", err)
}
if !ok {
return fmt.Errorf("object missing primary index")
}
// Lookup the object by ID first, to see if this is an update
idTxn := txn.writableIndex(table, id)
existing, update := idTxn.Get(idVal)
// On an update, there is an existing object with the given
// primary ID. We do the update by deleting the current object
// and inserting the new object.
for name, indexSchema := range tableSchema.Indexes {
indexTxn := txn.writableIndex(table, name)
// Determine the new index value
var (
ok bool
vals [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var val []byte
ok, val, err = indexer.FromObject(obj)
vals = [][]byte{val}
case MultiIndexer:
ok, vals, err = indexer.FromObject(obj)
}
if err != nil {
return fmt.Errorf("failed to build index '%s': %v", name, err)
}
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyways.
if ok && !indexSchema.Unique {
for i := range vals {
vals[i] = append(vals[i], idVal...)
}
}
// Handle the update by deleting from the index first
if update {
var (
okExist bool
valsExist [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var valExist []byte
okExist, valExist, err = indexer.FromObject(existing)
valsExist = [][]byte{valExist}
case MultiIndexer:
okExist, valsExist, err = indexer.FromObject(existing)
}
if err != nil {
return fmt.Errorf("failed to build index '%s': %v", name, err)
}
if okExist {
for i, valExist := range valsExist {
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyways.
if !indexSchema.Unique {
valExist = append(valExist, idVal...)
}
// If we are writing to the same index with the same value,
// we can avoid the delete as the insert will overwrite the
// value anyways.
if i >= len(vals) || !bytes.Equal(valExist, vals[i]) {
indexTxn.Delete(valExist)
}
}
}
}
// If there is no index value, this is either an error or an expected
// case, and we can skip updating
if !ok {
if indexSchema.AllowMissing {
continue
} else {
return fmt.Errorf("missing value for index '%s'", name)
}
}
// Update the value of the index
for _, val := range vals {
indexTxn.Insert(val, obj)
}
}
return nil
}
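// exampleInsert is an illustrative usage sketch, not part of the vendored
// file. A second Insert with the same primary ID takes the update path
// above: the existing entries are deleted from each index and replaced. txn
// is assumed to be a write transaction over a schema with a "person" table.
func exampleInsert(txn *Txn) error {
type Person struct{ ID, Name string }
if err := txn.Insert("person", &Person{ID: "1", Name: "alice"}); err != nil {
return err
}
// Same primary ID: this updates the existing object rather than adding a row.
return txn.Insert("person", &Person{ID: "1", Name: "alice b"})
}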
// Delete is used to delete a single object from the given table
// This object must already exist in the table
func (txn *Txn) Delete(table string, obj interface{}) error {
if !txn.write {
return fmt.Errorf("cannot delete in read-only transaction")
}
// Get the table schema
tableSchema, ok := txn.db.schema.Tables[table]
if !ok {
return fmt.Errorf("invalid table '%s'", table)
}
// Get the primary ID of the object
idSchema := tableSchema.Indexes[id]
idIndexer := idSchema.Indexer.(SingleIndexer)
ok, idVal, err := idIndexer.FromObject(obj)
if err != nil {
return fmt.Errorf("failed to build primary index: %v", err)
}
if !ok {
return fmt.Errorf("object missing primary index")
}
// Lookup the object by ID first to check if we should continue
idTxn := txn.writableIndex(table, id)
existing, ok := idTxn.Get(idVal)
if !ok {
return ErrNotFound
}
// Remove the object from all the indexes
for name, indexSchema := range tableSchema.Indexes {
indexTxn := txn.writableIndex(table, name)
// Compute the index values for the existing object so they can be removed
var (
ok bool
vals [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var val []byte
ok, val, err = indexer.FromObject(existing)
vals = [][]byte{val}
case MultiIndexer:
ok, vals, err = indexer.FromObject(existing)
}
if err != nil {
return fmt.Errorf("failed to build index '%s': %v", name, err)
}
if ok {
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyway.
for _, val := range vals {
if !indexSchema.Unique {
val = append(val, idVal...)
}
indexTxn.Delete(val)
}
}
}
return nil
}
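// exampleDelete is an illustrative usage sketch, not part of the vendored
// file. Delete resolves the stored object through the primary ID, so only
// the ID field of the passed object needs to be accurate; txn must be a
// write transaction.
func exampleDelete(txn *Txn) error {
type Person struct{ ID, Name string }
err := txn.Delete("person", &Person{ID: "1"})
if err == ErrNotFound {
return nil // already absent; treated as success in this sketch
}
return err
}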
// DeletePrefix is used to delete an entire subtree based on a prefix.
// The given index must be a prefix index, and will be used to perform a scan and enumerate the set of objects to delete.
// These will be removed from all other indexes, and then a special prefix operation will delete the objects from the given index in an efficient subtree delete operation.
// This is useful when you have a very large number of objects indexed by the given index, along with a much smaller number of entries in the other indexes for those objects.
func (txn *Txn) DeletePrefix(table string, prefixIndex string, prefix string) (bool, error) {
if !txn.write {
return false, fmt.Errorf("cannot delete in read-only transaction")
}
if !strings.HasSuffix(prefixIndex, "_prefix") {
return false, fmt.Errorf("index name for DeletePrefix must be a prefix index, got %v", prefixIndex)
}
deletePrefixIndex := strings.TrimSuffix(prefixIndex, "_prefix")
// Get an iterator over all of the keys with the given prefix.
entries, err := txn.Get(table, prefixIndex, prefix)
if err != nil {
return false, fmt.Errorf("failed kvs lookup: %s", err)
}
// Get the table schema
tableSchema, ok := txn.db.schema.Tables[table]
if !ok {
return false, fmt.Errorf("invalid table '%s'", table)
}
foundAny := false
for entry := entries.Next(); entry != nil; entry = entries.Next() {
foundAny = true
// Get the primary ID of the object
idSchema := tableSchema.Indexes[id]
idIndexer := idSchema.Indexer.(SingleIndexer)
ok, idVal, err := idIndexer.FromObject(entry)
if err != nil {
return false, fmt.Errorf("failed to build primary index: %v", err)
}
if !ok {
return false, fmt.Errorf("object missing primary index")
}
// Remove the object from all the indexes except the given prefix index
for name, indexSchema := range tableSchema.Indexes {
if name == deletePrefixIndex {
continue
}
indexTxn := txn.writableIndex(table, name)
// Compute the index values so the object's entries can be removed
var (
ok bool
vals [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var val []byte
ok, val, err = indexer.FromObject(entry)
vals = [][]byte{val}
case MultiIndexer:
ok, vals, err = indexer.FromObject(entry)
}
if err != nil {
return false, fmt.Errorf("failed to build index '%s': %v", name, err)
}
if ok {
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyway.
for _, val := range vals {
if !indexSchema.Unique {
val = append(val, idVal...)
}
indexTxn.Delete(val)
}
}
}
}
if foundAny {
indexTxn := txn.writableIndex(table, deletePrefixIndex)
ok = indexTxn.DeletePrefix([]byte(prefix))
if !ok {
panic(fmt.Errorf("prefix %v matched some entries but DeletePrefix did not delete any", prefix))
}
return true, nil
}
return false, nil
}
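// exampleDeletePrefix is an illustrative usage sketch, not part of the
// vendored file: bulk-remove every object whose primary ID starts with
// "users/". Note the mandatory "_prefix" suffix on the index argument.
func exampleDeletePrefix(txn *Txn) error {
deleted, err := txn.DeletePrefix("person", "id_prefix", "users/")
if err != nil {
return err
}
fmt.Printf("prefix subtree removed: %v\n", deleted)
return nil
}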
// DeleteAll is used to delete all the objects in a given table
// matching the constraints on the index
func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) {
if !txn.write {
return 0, fmt.Errorf("cannot delete in read-only transaction")
}
// Get all the objects
iter, err := txn.Get(table, index, args...)
if err != nil {
return 0, err
}
// Put them into a slice so there are no safety concerns while actually
// performing the deletes
var objs []interface{}
for {
obj := iter.Next()
if obj == nil {
break
}
objs = append(objs, obj)
}
// Do the deletes
num := 0
for _, obj := range objs {
if err := txn.Delete(table, obj); err != nil {
return num, err
}
num++
}
return num, nil
}
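// exampleDeleteAll is an illustrative usage sketch, not part of the
// vendored file: delete every object matching an index constraint and
// report the count. The "name" index is an assumption about the schema.
func exampleDeleteAll(txn *Txn) error {
n, err := txn.DeleteAll("person", "name", "alice")
if err != nil {
return err
}
fmt.Printf("deleted %d objects\n", n)
return nil
}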
// FirstWatch is used to return the first matching object for
// the given constraints on the index along with the watch channel
func (txn *Txn) FirstWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) {
// Get the index value
indexSchema, val, err := txn.getIndexValue(table, index, args...)
if err != nil {
return nil, nil, err
}
// Get the index itself
indexTxn := txn.readableIndex(table, indexSchema.Name)
// Do an exact lookup
if indexSchema.Unique && val != nil && indexSchema.Name == index {
watch, obj, ok := indexTxn.GetWatch(val)
if !ok {
return watch, nil, nil
}
return watch, obj, nil
}
// Handle non-unique index by using an iterator and getting the first value
iter := indexTxn.Root().Iterator()
watch := iter.SeekPrefixWatch(val)
_, value, _ := iter.Next()
return watch, value, nil
}
// First is used to return the first matching object for
// the given constraints on the index
func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) {
_, val, err := txn.FirstWatch(table, index, args...)
return val, err
}
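// exampleFirst is an illustrative usage sketch, not part of the vendored
// file: a point lookup on the unique "id" index from a read transaction,
// which never contends with the writer lock.
func exampleFirst(db *MemDB) error {
txn := db.Txn(false) // read-only transaction
raw, err := txn.First("person", "id", "1")
if err != nil {
return err
}
if raw == nil {
return ErrNotFound
}
fmt.Println("found:", raw)
return nil
}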
// LongestPrefix is used to fetch the longest prefix match for the given
// constraints on the index. Note that this will not work with the memdb
// StringFieldIndex because it adds null terminators which prevent the
// algorithm from correctly finding a match (it will get to right before the
// null and fail to find a leaf node). This should only be used where the prefix
// given is capable of matching indexed entries directly, which typically only
// applies to a custom indexer. See the unit test for an example.
func (txn *Txn) LongestPrefix(table, index string, args ...interface{}) (interface{}, error) {
// Enforce that this only works on prefix indexes.
if !strings.HasSuffix(index, "_prefix") {
return nil, fmt.Errorf("must use '%s_prefix' on index", index)
}
// Get the index value.
indexSchema, val, err := txn.getIndexValue(table, index, args...)
if err != nil {
return nil, err
}
// This algorithm only makes sense against a unique index, otherwise the
// index keys will have the IDs appended to them.
if !indexSchema.Unique {
return nil, fmt.Errorf("index '%s' is not unique", index)
}
// Find the longest prefix match with the given index.
indexTxn := txn.readableIndex(table, indexSchema.Name)
if _, value, ok := indexTxn.Root().LongestPrefix(val); ok {
return value, nil
}
return nil, nil
}
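// exampleLongestPrefix is an illustrative usage sketch, not part of the
// vendored file. Per the caveat above it assumes a unique "route" index
// backed by a custom indexer whose keys match the argument bytes directly.
func exampleLongestPrefix(txn *Txn) (interface{}, error) {
// Returns the entry whose indexed key is the longest prefix of the
// argument, e.g. a route stored as "/api/v1" for "/api/v1/users".
return txn.LongestPrefix("routes", "route_prefix", "/api/v1/users")
}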
// getIndexValue is used to get the IndexSchema and the value
// used to scan the index given the parameters. This handles prefix based
// scans when the index has the "_prefix" suffix. The index must support
// prefix iteration.
func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexSchema, []byte, error) {
// Get the table schema
tableSchema, ok := txn.db.schema.Tables[table]
if !ok {
return nil, nil, fmt.Errorf("invalid table '%s'", table)
}
// Check for a prefix scan
prefixScan := false
if strings.HasSuffix(index, "_prefix") {
index = strings.TrimSuffix(index, "_prefix")
prefixScan = true
}
// Get the index schema
indexSchema, ok := tableSchema.Indexes[index]
if !ok {
return nil, nil, fmt.Errorf("invalid index '%s'", index)
}
// Hot-path for when there are no arguments
if len(args) == 0 {
return indexSchema, nil, nil
}
// Special case the prefix scanning
if prefixScan {
prefixIndexer, ok := indexSchema.Indexer.(PrefixIndexer)
if !ok {
return indexSchema, nil,
fmt.Errorf("index '%s' does not support prefix scanning", index)
}
val, err := prefixIndexer.PrefixFromArgs(args...)
if err != nil {
return indexSchema, nil, fmt.Errorf("index error: %v", err)
}
return indexSchema, val, err
}
// Get the exact match index
val, err := indexSchema.Indexer.FromArgs(args...)
if err != nil {
return indexSchema, nil, fmt.Errorf("index error: %v", err)
}
return indexSchema, val, err
}
// ResultIterator is used to iterate over a list of results
// from a Get query on a table.
type ResultIterator interface {
WatchCh() <-chan struct{}
Next() interface{}
}
// Get is used to construct a ResultIterator over all the
// rows that match the given constraints of an index.
func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) {
// Get the index value to scan
indexSchema, val, err := txn.getIndexValue(table, index, args...)
if err != nil {
return nil, err
}
// Get the index itself
indexTxn := txn.readableIndex(table, indexSchema.Name)
indexRoot := indexTxn.Root()
// Get an iterator over the index
indexIter := indexRoot.Iterator()
// Seek the iterator to the appropriate sub-set
watchCh := indexIter.SeekPrefixWatch(val)
// Create an iterator
iter := &radixIterator{
iter: indexIter,
watchCh: watchCh,
}
return iter, nil
}
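// exampleGetIterate is an illustrative usage sketch, not part of the
// vendored file: scan every object under an ID prefix and walk the
// ResultIterator until Next returns nil.
func exampleGetIterate(txn *Txn) error {
it, err := txn.Get("person", "id_prefix", "users/")
if err != nil {
return err
}
for obj := it.Next(); obj != nil; obj = it.Next() {
fmt.Println(obj)
}
return nil
}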
// Defer is used to push a new arbitrary function onto a stack which
// gets called when a transaction is committed and finished. Deferred
// functions are called in LIFO order, and only invoked at the end of
// write transactions.
func (txn *Txn) Defer(fn func()) {
txn.after = append(txn.after, fn)
}
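// exampleDefer is an illustrative usage sketch, not part of the vendored
// file: register a side effect that fires only if the write transaction
// actually commits, never on Abort.
func exampleDefer(txn *Txn) {
txn.Defer(func() {
fmt.Println("committed; safe to publish side effects")
})
}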
// radixIterator is used to wrap an underlying iradix iterator.
// This is much more efficient than a sliceIterator as we are not
// materializing the entire view.
type radixIterator struct {
iter *iradix.Iterator
watchCh <-chan struct{}
}
func (r *radixIterator) WatchCh() <-chan struct{} {
return r.watchCh
}
func (r *radixIterator) Next() interface{} {
_, value, ok := r.iter.Next()
if !ok {
return nil
}
return value
}

129
vendor/github.com/hashicorp/go-memdb/watch.go generated vendored Normal file
View File

@@ -0,0 +1,129 @@
package memdb
import (
"context"
"time"
)
// WatchSet is a collection of watch channels.
type WatchSet map[<-chan struct{}]struct{}
// NewWatchSet constructs a new watch set.
func NewWatchSet() WatchSet {
return make(map[<-chan struct{}]struct{})
}
// Add adds the given watch channel to the WatchSet. It is a no-op on a nil
// WatchSet.
func (w WatchSet) Add(watchCh <-chan struct{}) {
if w == nil {
return
}
w[watchCh] = struct{}{}
}
// AddWithLimit appends a watchCh to the WatchSet if non-nil, and if the given
// softLimit hasn't been exceeded. Otherwise, it will watch the given alternate
// channel. It's expected that the altCh will be the same on many calls to this
// function, so you will exceed the soft limit a little bit if you hit this, but
// not by much.
//
// This is useful if you want to track individual items up to some limit, after
// which you watch a higher-level channel (usually a channel from the start of
// an iterator higher up in the radix tree) that will watch a superset of items.
func (w WatchSet) AddWithLimit(softLimit int, watchCh <-chan struct{}, altCh <-chan struct{}) {
// This is safe for a nil WatchSet so we don't need to check that here.
if len(w) < softLimit {
w.Add(watchCh)
} else {
w.Add(altCh)
}
}
// Watch is used to wait for either the watch set to trigger or a timeout.
// Returns true on timeout.
func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool {
if w == nil {
return false
}
// Create a context that gets cancelled when the timeout is triggered
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
select {
case <-timeoutCh:
cancel()
case <-ctx.Done():
}
}()
return w.WatchCtx(ctx) == context.Canceled
}
// WatchCtx is used to wait for either the watch set to trigger or for the
// context to be cancelled. Watch with a timeout channel can be mimicked by
// creating a context with a deadline. WatchCtx should be preferred over Watch.
func (w WatchSet) WatchCtx(ctx context.Context) error {
if w == nil {
return nil
}
if n := len(w); n <= aFew {
idx := 0
chunk := make([]<-chan struct{}, aFew)
for watchCh := range w {
chunk[idx] = watchCh
idx++
}
return watchFew(ctx, chunk)
}
return w.watchMany(ctx)
}
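// exampleWatchCtx is an illustrative usage sketch, not part of the vendored
// file: the typical blocking-query loop. Each pass re-reads in a fresh read
// transaction, then blocks on the watch set until the result may have
// changed or the context ends. Table and index names are assumptions.
func exampleWatchCtx(ctx context.Context, db *MemDB) error {
for {
txn := db.Txn(false)
watch, _, err := txn.FirstWatch("person", "id", "1")
if err != nil {
return err
}
ws := NewWatchSet()
ws.Add(watch)
if err := ws.WatchCtx(ctx); err != nil {
return err // cancelled or deadline exceeded
}
// The watched result may have changed; loop and re-read.
}
}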
// watchMany is used if there are many watchers.
func (w WatchSet) watchMany(ctx context.Context) error {
// Set up a goroutine for each watcher.
triggerCh := make(chan struct{}, 1)
watcher := func(chunk []<-chan struct{}) {
if err := watchFew(ctx, chunk); err == nil {
select {
case triggerCh <- struct{}{}:
default:
}
}
}
// Apportion the watch channels into chunks we can feed into the
// watchFew helper.
idx := 0
chunk := make([]<-chan struct{}, aFew)
for watchCh := range w {
subIdx := idx % aFew
chunk[subIdx] = watchCh
idx++
// Fire off this chunk and start a fresh one.
if idx%aFew == 0 {
go watcher(chunk)
chunk = make([]<-chan struct{}, aFew)
}
}
// Make sure to watch any residual channels in the last chunk.
if idx%aFew != 0 {
go watcher(chunk)
}
// Wait for a channel to trigger or timeout.
select {
case <-triggerCh:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

117
vendor/github.com/hashicorp/go-memdb/watch_few.go generated vendored Normal file
View File

@@ -0,0 +1,117 @@
package memdb
//go:generate sh -c "go run watch-gen/main.go >watch_few.go"
import (
"context"
)
// aFew gives how many watchers this function is wired to support. You must
// always pass a full slice of this length, but unused channels can be nil.
const aFew = 32
// watchFew is used if there are only a few watchers as a performance
// optimization.
func watchFew(ctx context.Context, ch []<-chan struct{}) error {
select {
case <-ch[0]:
return nil
case <-ch[1]:
return nil
case <-ch[2]:
return nil
case <-ch[3]:
return nil
case <-ch[4]:
return nil
case <-ch[5]:
return nil
case <-ch[6]:
return nil
case <-ch[7]:
return nil
case <-ch[8]:
return nil
case <-ch[9]:
return nil
case <-ch[10]:
return nil
case <-ch[11]:
return nil
case <-ch[12]:
return nil
case <-ch[13]:
return nil
case <-ch[14]:
return nil
case <-ch[15]:
return nil
case <-ch[16]:
return nil
case <-ch[17]:
return nil
case <-ch[18]:
return nil
case <-ch[19]:
return nil
case <-ch[20]:
return nil
case <-ch[21]:
return nil
case <-ch[22]:
return nil
case <-ch[23]:
return nil
case <-ch[24]:
return nil
case <-ch[25]:
return nil
case <-ch[26]:
return nil
case <-ch[27]:
return nil
case <-ch[28]:
return nil
case <-ch[29]:
return nil
case <-ch[30]:
return nil
case <-ch[31]:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

6
vendor/vendor.json vendored
View File

@@ -1032,6 +1032,12 @@
"revision": "8aac2701530899b64bdea735a1de8da899815220",
"revisionTime": "2017-07-25T22:12:15Z"
},
{
"checksumSHA1": "2JVfMLNCW8hfVlPAwAHlOX4HW2s=",
"path": "github.com/hashicorp/go-memdb",
"revision": "75ff99613d288868d8888ec87594525906815dbc",
"revisionTime": "2017-10-05T03:07:53Z"
},
{
"checksumSHA1": "g7uHECbzuaWwdxvwoyxBwgeERPk=",
"path": "github.com/hashicorp/go-multierror",