From 6b9ce0c8c50fb52fbcc59e6016276cf96fa146dc Mon Sep 17 00:00:00 2001 From: Vishal Nayak Date: Wed, 11 Oct 2017 10:21:20 -0700 Subject: [PATCH] Porting identity store (#3419) * porting identity to OSS * changes that glue things together * add testing bits * wrapped entity id * fix mount error * some more changes to core * fix storagepacker tests * fix some more tests * fix mount tests * fix http mount tests * audit changes for identity * remove upgrade structs on the oss side * added go-memdb to vendor --- Makefile | 3 + audit/format.go | 3 + audit/format_jsonx_test.go | 5 +- helper/identity/identity.go | 64 + helper/identity/types.pb.go | 411 ++++ helper/identity/types.proto | 151 ++ helper/storagepacker/storagepacker.go | 351 +++ helper/storagepacker/storagepacker_test.go | 172 ++ helper/storagepacker/types.pb.go | 101 + helper/storagepacker/types.proto | 15 + helper/wrapping/wrapinfo.go | 4 + http/handler_test.go | 20 + http/logical_test.go | 2 + http/sys_generate_root_test.go | 2 + http/sys_mount_test.go | 120 + logical/auth.go | 4 + logical/identity.go | 16 +- logical/request.go | 8 + logical/translate_response.go | 3 + vault/core.go | 93 +- vault/dynamic_system_view.go | 2 +- vault/identity_lookup.go | 78 + vault/identity_store.go | 334 +++ vault/identity_store_aliases.go | 364 +++ vault/identity_store_aliases_test.go | 531 +++++ vault/identity_store_entities.go | 501 ++++ vault/identity_store_entities_test.go | 783 ++++++ vault/identity_store_groups.go | 286 +++ vault/identity_store_groups_test.go | 666 ++++++ vault/identity_store_schema.go | 180 ++ vault/identity_store_structs.go | 75 + vault/identity_store_test.go | 269 +++ vault/identity_store_upgrade.go | 86 + vault/identity_store_util.go | 2122 +++++++++++++++++ vault/identity_store_util_test.go | 40 + vault/logical_system_test.go | 11 + vault/mount.go | 50 +- vault/mount_test.go | 26 +- vault/request_handling.go | 37 +- vault/router.go | 47 + vault/testing.go | 25 +- vault/token_store.go | 11 + vault/token_store_test.go | 4 + vault/wrapping.go | 10 +- vendor/github.com/hashicorp/go-memdb/LICENSE | 363 +++ .../github.com/hashicorp/go-memdb/README.md | 98 + .../github.com/hashicorp/go-memdb/filter.go | 33 + vendor/github.com/hashicorp/go-memdb/index.go | 569 +++++ vendor/github.com/hashicorp/go-memdb/memdb.go | 92 + .../github.com/hashicorp/go-memdb/schema.go | 85 + vendor/github.com/hashicorp/go-memdb/txn.go | 644 +++++ vendor/github.com/hashicorp/go-memdb/watch.go | 129 + .../hashicorp/go-memdb/watch_few.go | 117 + vendor/vendor.json | 6 + 54 files changed, 10162 insertions(+), 60 deletions(-) create mode 100644 helper/identity/identity.go create mode 100644 helper/identity/types.pb.go create mode 100644 helper/identity/types.proto create mode 100644 helper/storagepacker/storagepacker.go create mode 100644 helper/storagepacker/storagepacker_test.go create mode 100644 helper/storagepacker/types.pb.go create mode 100644 helper/storagepacker/types.proto create mode 100644 vault/identity_lookup.go create mode 100644 vault/identity_store.go create mode 100644 vault/identity_store_aliases.go create mode 100644 vault/identity_store_aliases_test.go create mode 100644 vault/identity_store_entities.go create mode 100644 vault/identity_store_entities_test.go create mode 100644 vault/identity_store_groups.go create mode 100644 vault/identity_store_groups_test.go create mode 100644 vault/identity_store_schema.go create mode 100644 vault/identity_store_structs.go create mode 100644 vault/identity_store_test.go create mode 100644 vault/identity_store_upgrade.go create mode 100644 vault/identity_store_util.go create mode 100644 vault/identity_store_util_test.go create mode 100644 vendor/github.com/hashicorp/go-memdb/LICENSE create mode 100644 vendor/github.com/hashicorp/go-memdb/README.md create mode 100644 vendor/github.com/hashicorp/go-memdb/filter.go create mode 100644 vendor/github.com/hashicorp/go-memdb/index.go create mode 100644 vendor/github.com/hashicorp/go-memdb/memdb.go create mode 100644 vendor/github.com/hashicorp/go-memdb/schema.go create mode 100644 vendor/github.com/hashicorp/go-memdb/txn.go create mode 100644 vendor/github.com/hashicorp/go-memdb/watch.go create mode 100644 vendor/github.com/hashicorp/go-memdb/watch_few.go diff --git a/Makefile b/Makefile index 0bf1d14647..26db0be5b9 100644 --- a/Makefile +++ b/Makefile @@ -73,6 +73,9 @@ bootstrap: proto: protoc -I helper/forwarding -I vault -I ../../.. vault/*.proto --go_out=plugins=grpc:vault protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding + protoc -I helper/storagepacker helper/storagepacker/types.proto --go_out=plugins=grpc:helper/storagepacker + protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity + sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' helper/identity/types.pb.go helper/storagepacker/types.pb.go fmtcheck: @sh -c "'$(CURDIR)/scripts/gofmtcheck.sh'" diff --git a/audit/format.go b/audit/format.go index 18eb254eb5..df57276bfd 100644 --- a/audit/format.go +++ b/audit/format.go @@ -123,6 +123,7 @@ func (f *AuditFormatter) FormatRequest( DisplayName: auth.DisplayName, Policies: auth.Policies, Metadata: auth.Metadata, + EntityID: auth.EntityID, RemainingUses: req.ClientTokenRemainingUses, }, @@ -315,6 +316,7 @@ func (f *AuditFormatter) FormatResponse( Policies: auth.Policies, Metadata: auth.Metadata, RemainingUses: req.ClientTokenRemainingUses, + EntityID: auth.EntityID, }, Request: AuditRequest{ @@ -397,6 +399,7 @@ type AuditAuth struct { Metadata map[string]string `json:"metadata"` NumUses int `json:"num_uses,omitempty"` RemainingUses int `json:"remaining_uses,omitempty"` + EntityID string `json:"entity_id"` } type AuditSecret struct { diff --git a/audit/format_jsonx_test.go b/audit/format_jsonx_test.go index b04ccd0be1..24755430de 100644 --- a/audit/format_jsonx_test.go +++ b/audit/format_jsonx_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" + "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) @@ -50,7 +51,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) { errors.New("this is an error"), "", "", - fmt.Sprintf(`bar%stesttokenrootthis is an errorbarupdate/foo127.0.0.160request`, + fmt.Sprintf(`bar%stesttokenrootthis is an errorbarupdate/foo127.0.0.160request`, fooSalted), }, "auth, request with prefix": { @@ -71,7 +72,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) { errors.New("this is an error"), "", "@cee: ", - fmt.Sprintf(`bar%stesttokenrootthis is an errorbarupdate/foo127.0.0.160request`, + fmt.Sprintf(`bar%stesttokenrootthis is an errorbarupdate/foo127.0.0.160request`, fooSalted), }, } diff --git a/helper/identity/identity.go b/helper/identity/identity.go new file mode 100644 index 0000000000..a0d812a96b --- /dev/null +++ b/helper/identity/identity.go @@ -0,0 +1,64 @@ +package identity + +import ( + "fmt" + + "github.com/gogo/protobuf/proto" +) + +func (g *Group) Clone() (*Group, error) { + if g == nil { + return nil, fmt.Errorf("nil group") + } + + marshaledGroup, err := proto.Marshal(g) + if err != nil { + return nil, fmt.Errorf("failed to marshal group: %v", err) + } + + var clonedGroup Group + err = proto.Unmarshal(marshaledGroup, &clonedGroup) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal group: %v", err) + } + + return &clonedGroup, nil +} + +func (e *Entity) Clone() (*Entity, error) { + if e == nil { + return nil, fmt.Errorf("nil entity") + } + + marshaledEntity, err := proto.Marshal(e) + if err != nil { + return nil, fmt.Errorf("failed to marshal entity: %v", err) + } + + var clonedEntity Entity + err = proto.Unmarshal(marshaledEntity, &clonedEntity) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal entity: %v", err) + } + + return &clonedEntity, nil +} + +func (p *Alias) Clone() (*Alias, error) { + if p == nil { + return nil, fmt.Errorf("nil alias") + } + + marshaledAlias, err := proto.Marshal(p) + if err != nil { + return nil, fmt.Errorf("failed to marshal alias: %v", err) + } + + var clonedAlias Alias + err = proto.Unmarshal(marshaledAlias, &clonedAlias) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal alias: %v", err) + } + + return &clonedAlias, nil +} diff --git a/helper/identity/types.pb.go b/helper/identity/types.pb.go new file mode 100644 index 0000000000..d6803d5cb1 --- /dev/null +++ b/helper/identity/types.pb.go @@ -0,0 +1,411 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: types.proto + +/* +Package identity is a generated protocol buffer package. + +It is generated from these files: + types.proto + +It has these top-level messages: + Group + Entity + Alias +*/ +package identity + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/golang/protobuf/ptypes/timestamp" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +// Group represents an identity group. +type Group struct { + // ID is the unique identifier for this group + ID string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"` + // Name is the unique name for this group + Name string `protobuf:"bytes,2,opt,name=name" json:"name,omitempty"` + // Policies are the vault policies to be granted to members of this group + Policies []string `protobuf:"bytes,3,rep,name=policies" json:"policies,omitempty"` + // ParentGroupIDs are the identifiers of those groups to which this group is a + // member of. These will serve as references to the parent group in the + // hierarchy. + ParentGroupIDs []string `protobuf:"bytes,4,rep,name=parent_group_ids,json=parentGroupIds" json:"parent_group_ids,omitempty"` + // MemberEntityIDs are the identifiers of entities which are members of this + // group + MemberEntityIDs []string `protobuf:"bytes,5,rep,name=member_entity_ids,json=memberEntityIDs" json:"member_entity_ids,omitempty"` + // Metadata represents the custom data tied with this group + Metadata map[string]string `protobuf:"bytes,6,rep,name=metadata" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // CreationTime is the time at which this group was created + CreationTime *google_protobuf.Timestamp `protobuf:"bytes,7,opt,name=creation_time,json=creationTime" json:"creation_time,omitempty"` + // LastUpdateTime is the time at which this group was last modified + LastUpdateTime *google_protobuf.Timestamp `protobuf:"bytes,8,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"` + // ModifyIndex tracks the number of updates to the group. It is useful to detect + // updates to the groups. + ModifyIndex uint64 `protobuf:"varint,9,opt,name=modify_index,json=modifyIndex" json:"modify_index,omitempty"` + // BucketKeyHash is the MD5 hash of the storage bucket key into which this + // group is stored in the underlying storage. This is useful to find all + // the groups belonging to a particular bucket during invalidation of the + // storage key. + BucketKeyHash string `protobuf:"bytes,10,opt,name=bucket_key_hash,json=bucketKeyHash" json:"bucket_key_hash,omitempty"` +} + +func (m *Group) Reset() { *m = Group{} } +func (m *Group) String() string { return proto.CompactTextString(m) } +func (*Group) ProtoMessage() {} +func (*Group) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *Group) GetID() string { + if m != nil { + return m.ID + } + return "" +} + +func (m *Group) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Group) GetPolicies() []string { + if m != nil { + return m.Policies + } + return nil +} + +func (m *Group) GetParentGroupIDs() []string { + if m != nil { + return m.ParentGroupIDs + } + return nil +} + +func (m *Group) GetMemberEntityIDs() []string { + if m != nil { + return m.MemberEntityIDs + } + return nil +} + +func (m *Group) GetMetadata() map[string]string { + if m != nil { + return m.Metadata + } + return nil +} + +func (m *Group) GetCreationTime() *google_protobuf.Timestamp { + if m != nil { + return m.CreationTime + } + return nil +} + +func (m *Group) GetLastUpdateTime() *google_protobuf.Timestamp { + if m != nil { + return m.LastUpdateTime + } + return nil +} + +func (m *Group) GetModifyIndex() uint64 { + if m != nil { + return m.ModifyIndex + } + return 0 +} + +func (m *Group) GetBucketKeyHash() string { + if m != nil { + return m.BucketKeyHash + } + return "" +} + +// Entity represents an entity that gets persisted and indexed. +// Entity is fundamentally composed of zero or many aliases. +type Entity struct { + // Aliases are the identities that this entity is made of. This can be + // empty as well to favor being able to create the entity first and then + // incrementally adding aliases. + Aliases []*Alias `protobuf:"bytes,1,rep,name=aliases" json:"aliases,omitempty"` + // ID is the unique identifier of the entity which always be a UUID. This + // should never be allowed to be updated. + ID string `protobuf:"bytes,2,opt,name=id" json:"id,omitempty"` + // Name is a unique identifier of the entity which is intended to be + // human-friendly. The default name might not be human friendly since it + // gets suffixed by a UUID, but it can optionally be updated, unlike the ID + // field. + Name string `protobuf:"bytes,3,opt,name=name" json:"name,omitempty"` + // Metadata represents the explicit metadata which is set by the + // clients. This is useful to tie any information pertaining to the + // aliases. This is a non-unique field of entity, meaning multiple + // entities can have the same metadata set. Entities will be indexed based + // on this explicit metadata. This enables virtual groupings of entities + // based on its metadata. + Metadata map[string]string `protobuf:"bytes,4,rep,name=metadata" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // CreationTime is the time at which this entity is first created. + CreationTime *google_protobuf.Timestamp `protobuf:"bytes,5,opt,name=creation_time,json=creationTime" json:"creation_time,omitempty"` + // LastUpdateTime is the most recent time at which the properties of this + // entity got modified. This is helpful in filtering out entities based on + // its age and to take action on them, if desired. + LastUpdateTime *google_protobuf.Timestamp `protobuf:"bytes,6,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"` + // MergedEntityIDs are the entities which got merged to this one. Entities + // will be indexed based on all the entities that got merged into it. This + // helps to apply the actions on this entity on the tokens that are merged + // to the merged entities. Merged entities will be deleted entirely and + // this is the only trackable trail of its earlier presence. + MergedEntityIDs []string `protobuf:"bytes,7,rep,name=merged_entity_ids,json=mergedEntityIDs" json:"merged_entity_ids,omitempty"` + // Policies the entity is entitled to + Policies []string `protobuf:"bytes,8,rep,name=policies" json:"policies,omitempty"` + // BucketKeyHash is the MD5 hash of the storage bucket key into which this + // entity is stored in the underlying storage. This is useful to find all + // the entities belonging to a particular bucket during invalidation of the + // storage key. + BucketKeyHash string `protobuf:"bytes,9,opt,name=bucket_key_hash,json=bucketKeyHash" json:"bucket_key_hash,omitempty"` +} + +func (m *Entity) Reset() { *m = Entity{} } +func (m *Entity) String() string { return proto.CompactTextString(m) } +func (*Entity) ProtoMessage() {} +func (*Entity) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *Entity) GetAliases() []*Alias { + if m != nil { + return m.Aliases + } + return nil +} + +func (m *Entity) GetID() string { + if m != nil { + return m.ID + } + return "" +} + +func (m *Entity) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Entity) GetMetadata() map[string]string { + if m != nil { + return m.Metadata + } + return nil +} + +func (m *Entity) GetCreationTime() *google_protobuf.Timestamp { + if m != nil { + return m.CreationTime + } + return nil +} + +func (m *Entity) GetLastUpdateTime() *google_protobuf.Timestamp { + if m != nil { + return m.LastUpdateTime + } + return nil +} + +func (m *Entity) GetMergedEntityIDs() []string { + if m != nil { + return m.MergedEntityIDs + } + return nil +} + +func (m *Entity) GetPolicies() []string { + if m != nil { + return m.Policies + } + return nil +} + +func (m *Entity) GetBucketKeyHash() string { + if m != nil { + return m.BucketKeyHash + } + return "" +} + +// Alias represents the alias that gets stored inside of the +// entity object in storage and also represents in an in-memory index of an +// alias object. +type Alias struct { + // ID is the unique identifier that represents this alias + ID string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"` + // EntityID is the entity identifier to which this alias belongs to + EntityID string `protobuf:"bytes,2,opt,name=entity_id,json=entityId" json:"entity_id,omitempty"` + // MountType is the backend mount's type to which this alias belongs to. + // This enables categorically querying aliases of specific backend types. + MountType string `protobuf:"bytes,3,opt,name=mount_type,json=mountType" json:"mount_type,omitempty"` + // MountAccessor is the backend mount's accessor to which this alias + // belongs to. + MountAccessor string `protobuf:"bytes,4,opt,name=mount_accessor,json=mountAccessor" json:"mount_accessor,omitempty"` + // MountPath is the backend mount's path to which the Maccessor belongs to. This + // field is not used for any operational purposes. This is only returned when + // alias is read, only as a nicety. + MountPath string `protobuf:"bytes,5,opt,name=mount_path,json=mountPath" json:"mount_path,omitempty"` + // Metadata is the explicit metadata that clients set against an entity + // which enables virtual grouping of aliases. Aliases will be indexed + // against their metadata. + Metadata map[string]string `protobuf:"bytes,6,rep,name=metadata" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // Name is the identifier of this alias in its authentication source. + // This does not uniquely identify a alias in Vault. This in conjunction + // with MountAccessor form to be the factors that represent a alias in a + // unique way. Aliases will be indexed based on this combined uniqueness + // factor. + Name string `protobuf:"bytes,7,opt,name=name" json:"name,omitempty"` + // CreationTime is the time at which this alias was first created + CreationTime *google_protobuf.Timestamp `protobuf:"bytes,8,opt,name=creation_time,json=creationTime" json:"creation_time,omitempty"` + // LastUpdateTime is the most recent time at which the properties of this + // alias got modified. This is helpful in filtering out aliases based + // on its age and to take action on them, if desired. + LastUpdateTime *google_protobuf.Timestamp `protobuf:"bytes,9,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"` + // MergedFromEntityIDs is the FIFO history of merging activity by entity IDs from + // which this alias is transfered over to the entity to which it + // currently belongs to. + MergedFromEntityIDs []string `protobuf:"bytes,10,rep,name=merged_from_entity_ids,json=mergedFromEntityIDs" json:"merged_from_entity_ids,omitempty"` +} + +func (m *Alias) Reset() { *m = Alias{} } +func (m *Alias) String() string { return proto.CompactTextString(m) } +func (*Alias) ProtoMessage() {} +func (*Alias) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *Alias) GetID() string { + if m != nil { + return m.ID + } + return "" +} + +func (m *Alias) GetEntityID() string { + if m != nil { + return m.EntityID + } + return "" +} + +func (m *Alias) GetMountType() string { + if m != nil { + return m.MountType + } + return "" +} + +func (m *Alias) GetMountAccessor() string { + if m != nil { + return m.MountAccessor + } + return "" +} + +func (m *Alias) GetMountPath() string { + if m != nil { + return m.MountPath + } + return "" +} + +func (m *Alias) GetMetadata() map[string]string { + if m != nil { + return m.Metadata + } + return nil +} + +func (m *Alias) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Alias) GetCreationTime() *google_protobuf.Timestamp { + if m != nil { + return m.CreationTime + } + return nil +} + +func (m *Alias) GetLastUpdateTime() *google_protobuf.Timestamp { + if m != nil { + return m.LastUpdateTime + } + return nil +} + +func (m *Alias) GetMergedFromEntityIDs() []string { + if m != nil { + return m.MergedFromEntityIDs + } + return nil +} + +func init() { + proto.RegisterType((*Group)(nil), "identity.Group") + proto.RegisterType((*Entity)(nil), "identity.Entity") + proto.RegisterType((*Alias)(nil), "identity.Alias") +} + +func init() { proto.RegisterFile("types.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 570 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x94, 0xcd, 0x6e, 0xd3, 0x40, + 0x10, 0xc7, 0xe5, 0x38, 0x1f, 0xf6, 0xa4, 0x4d, 0xcb, 0x82, 0x90, 0x15, 0x54, 0x08, 0x95, 0x40, + 0x86, 0x83, 0x2b, 0xb5, 0x17, 0x28, 0x07, 0x54, 0x89, 0x02, 0x15, 0x42, 0x42, 0x56, 0x39, 0x5b, + 0x9b, 0x78, 0x9a, 0xac, 0x1a, 0x7b, 0x2d, 0xef, 0x1a, 0xe1, 0x27, 0xe4, 0x39, 0x38, 0xf1, 0x1a, + 0xc8, 0xb3, 0x76, 0x62, 0x08, 0x5f, 0x15, 0xb9, 0xd9, 0xff, 0x99, 0x1d, 0xcf, 0xce, 0xff, 0x37, + 0x86, 0xa1, 0x2e, 0x33, 0x54, 0x41, 0x96, 0x4b, 0x2d, 0x99, 0x23, 0x62, 0x4c, 0xb5, 0xd0, 0xe5, + 0xf8, 0xc1, 0x5c, 0xca, 0xf9, 0x12, 0x8f, 0x48, 0x9f, 0x16, 0x57, 0x47, 0x5a, 0x24, 0xa8, 0x34, + 0x4f, 0x32, 0x93, 0x7a, 0xf8, 0xcd, 0x86, 0xde, 0x9b, 0x5c, 0x16, 0x19, 0x1b, 0x41, 0x47, 0xc4, + 0x9e, 0x35, 0xb1, 0x7c, 0x37, 0xec, 0x88, 0x98, 0x31, 0xe8, 0xa6, 0x3c, 0x41, 0xaf, 0x43, 0x0a, + 0x3d, 0xb3, 0x31, 0x38, 0x99, 0x5c, 0x8a, 0x99, 0x40, 0xe5, 0xd9, 0x13, 0xdb, 0x77, 0xc3, 0xd5, + 0x3b, 0xf3, 0x61, 0x3f, 0xe3, 0x39, 0xa6, 0x3a, 0x9a, 0x57, 0xf5, 0x22, 0x11, 0x2b, 0xaf, 0x4b, + 0x39, 0x23, 0xa3, 0xd3, 0x67, 0x2e, 0x62, 0xc5, 0x9e, 0xc2, 0xad, 0x04, 0x93, 0x29, 0xe6, 0x91, + 0xe9, 0x92, 0x52, 0x7b, 0x94, 0xba, 0x67, 0x02, 0xe7, 0xa4, 0x57, 0xb9, 0xcf, 0xc1, 0x49, 0x50, + 0xf3, 0x98, 0x6b, 0xee, 0xf5, 0x27, 0xb6, 0x3f, 0x3c, 0x3e, 0x08, 0x9a, 0xdb, 0x05, 0x54, 0x31, + 0x78, 0x5f, 0xc7, 0xcf, 0x53, 0x9d, 0x97, 0xe1, 0x2a, 0x9d, 0xbd, 0x84, 0xdd, 0x59, 0x8e, 0x5c, + 0x0b, 0x99, 0x46, 0xd5, 0xb5, 0xbd, 0xc1, 0xc4, 0xf2, 0x87, 0xc7, 0xe3, 0xc0, 0xcc, 0x24, 0x68, + 0x66, 0x12, 0x5c, 0x36, 0x33, 0x09, 0x77, 0x9a, 0x03, 0x95, 0xc4, 0x5e, 0xc1, 0xfe, 0x92, 0x2b, + 0x1d, 0x15, 0x59, 0xcc, 0x35, 0x9a, 0x1a, 0xce, 0x5f, 0x6b, 0x8c, 0xaa, 0x33, 0x1f, 0xe9, 0x08, + 0x55, 0x79, 0x08, 0x3b, 0x89, 0x8c, 0xc5, 0x55, 0x19, 0x89, 0x34, 0xc6, 0xcf, 0x9e, 0x3b, 0xb1, + 0xfc, 0x6e, 0x38, 0x34, 0xda, 0x45, 0x25, 0xb1, 0xc7, 0xb0, 0x37, 0x2d, 0x66, 0xd7, 0xa8, 0xa3, + 0x6b, 0x2c, 0xa3, 0x05, 0x57, 0x0b, 0x0f, 0x68, 0xea, 0xbb, 0x46, 0x7e, 0x87, 0xe5, 0x5b, 0xae, + 0x16, 0xe3, 0x17, 0xb0, 0xfb, 0xc3, 0x65, 0xd9, 0x3e, 0xd8, 0xd7, 0x58, 0xd6, 0xa6, 0x55, 0x8f, + 0xec, 0x0e, 0xf4, 0x3e, 0xf1, 0x65, 0xd1, 0xd8, 0x66, 0x5e, 0x4e, 0x3b, 0xcf, 0xac, 0xc3, 0x2f, + 0x36, 0xf4, 0xcd, 0x5c, 0xd9, 0x13, 0x18, 0xf0, 0xa5, 0xe0, 0x0a, 0x95, 0x67, 0xd1, 0x4c, 0xf7, + 0xd6, 0x33, 0x3d, 0xab, 0x02, 0x61, 0x13, 0xaf, 0xa9, 0xe8, 0x6c, 0x50, 0x61, 0xb7, 0xa8, 0x38, + 0x6d, 0x79, 0xd4, 0xa5, 0x7a, 0xf7, 0xd7, 0xf5, 0xcc, 0x27, 0xff, 0xdd, 0xa4, 0xde, 0x16, 0x4c, + 0xea, 0xdf, 0xd8, 0x24, 0x42, 0x32, 0x9f, 0x63, 0xdc, 0x46, 0x72, 0xd0, 0x20, 0x59, 0x05, 0xd6, + 0x48, 0xb6, 0x97, 0xc0, 0xf9, 0x69, 0x09, 0x7e, 0xe1, 0xa4, 0xbb, 0x75, 0x27, 0xbf, 0xda, 0xd0, + 0x23, 0x9b, 0x36, 0x76, 0xf6, 0x1e, 0xb8, 0xab, 0xfe, 0xeb, 0x73, 0x0e, 0xd6, 0x8d, 0xb3, 0x03, + 0x80, 0x44, 0x16, 0xa9, 0x8e, 0xaa, 0x5f, 0x45, 0x6d, 0xa0, 0x4b, 0xca, 0x65, 0x99, 0x21, 0x7b, + 0x04, 0x23, 0x13, 0xe6, 0xb3, 0x19, 0x2a, 0x25, 0x73, 0xaf, 0x6b, 0x3a, 0x27, 0xf5, 0xac, 0x16, + 0xd7, 0x55, 0x32, 0xae, 0x17, 0xe4, 0x56, 0x53, 0xe5, 0x03, 0xd7, 0x8b, 0x3f, 0xef, 0x2b, 0x35, + 0xfd, 0x5b, 0x14, 0x1a, 0xb4, 0x06, 0x2d, 0xb4, 0x36, 0xf0, 0x70, 0xb6, 0x80, 0x87, 0x7b, 0x63, + 0x3c, 0x4e, 0xe0, 0x6e, 0x8d, 0xc7, 0x55, 0x2e, 0x93, 0x36, 0x23, 0x40, 0x00, 0xdc, 0x36, 0xd1, + 0xd7, 0xb9, 0x4c, 0x56, 0x9c, 0xfc, 0x97, 0xc7, 0xd3, 0x3e, 0x75, 0x75, 0xf2, 0x3d, 0x00, 0x00, + 0xff, 0xff, 0x17, 0x1c, 0xfc, 0x89, 0xd8, 0x05, 0x00, 0x00, +} diff --git a/helper/identity/types.proto b/helper/identity/types.proto new file mode 100644 index 0000000000..2c27442989 --- /dev/null +++ b/helper/identity/types.proto @@ -0,0 +1,151 @@ +syntax = "proto3"; + +package identity; + +import "google/protobuf/timestamp.proto"; + +// Group represents an identity group. +message Group { + // ID is the unique identifier for this group + string id = 1; + + // Name is the unique name for this group + string name = 2; + + // Policies are the vault policies to be granted to members of this group + repeated string policies = 3; + + // ParentGroupIDs are the identifiers of those groups to which this group is a + // member of. These will serve as references to the parent group in the + // hierarchy. + repeated string parent_group_ids = 4; + + // MemberEntityIDs are the identifiers of entities which are members of this + // group + repeated string member_entity_ids = 5; + + // Metadata represents the custom data tied with this group + map metadata = 6; + + // CreationTime is the time at which this group was created + google.protobuf.Timestamp creation_time = 7; + + // LastUpdateTime is the time at which this group was last modified + google.protobuf.Timestamp last_update_time= 8; + + // ModifyIndex tracks the number of updates to the group. It is useful to detect + // updates to the groups. + uint64 modify_index = 9; + + // BucketKeyHash is the MD5 hash of the storage bucket key into which this + // group is stored in the underlying storage. This is useful to find all + // the groups belonging to a particular bucket during invalidation of the + // storage key. + string bucket_key_hash = 10; +} + + +// Entity represents an entity that gets persisted and indexed. +// Entity is fundamentally composed of zero or many aliases. +message Entity { + // Aliases are the identities that this entity is made of. This can be + // empty as well to favor being able to create the entity first and then + // incrementally adding aliases. + repeated Alias aliases = 1; + + // ID is the unique identifier of the entity which always be a UUID. This + // should never be allowed to be updated. + string id = 2; + + // Name is a unique identifier of the entity which is intended to be + // human-friendly. The default name might not be human friendly since it + // gets suffixed by a UUID, but it can optionally be updated, unlike the ID + // field. + string name = 3; + + // Metadata represents the explicit metadata which is set by the + // clients. This is useful to tie any information pertaining to the + // aliases. This is a non-unique field of entity, meaning multiple + // entities can have the same metadata set. Entities will be indexed based + // on this explicit metadata. This enables virtual groupings of entities + // based on its metadata. + map metadata = 4; + + // CreationTime is the time at which this entity is first created. + google.protobuf.Timestamp creation_time = 5; + + // LastUpdateTime is the most recent time at which the properties of this + // entity got modified. This is helpful in filtering out entities based on + // its age and to take action on them, if desired. + google.protobuf.Timestamp last_update_time= 6; + + // MergedEntityIDs are the entities which got merged to this one. Entities + // will be indexed based on all the entities that got merged into it. This + // helps to apply the actions on this entity on the tokens that are merged + // to the merged entities. Merged entities will be deleted entirely and + // this is the only trackable trail of its earlier presence. + repeated string merged_entity_ids = 7; + + // Policies the entity is entitled to + repeated string policies = 8; + + // BucketKeyHash is the MD5 hash of the storage bucket key into which this + // entity is stored in the underlying storage. This is useful to find all + // the entities belonging to a particular bucket during invalidation of the + // storage key. + string bucket_key_hash = 9; + + // **Enterprise only** + // MFASecrets holds the MFA secrets indexed by the identifier of the MFA + // method configuration. + //map mfa_secrets = 10; +} + +// Alias represents the alias that gets stored inside of the +// entity object in storage and also represents in an in-memory index of an +// alias object. +message Alias { + // ID is the unique identifier that represents this alias + string id = 1; + + // EntityID is the entity identifier to which this alias belongs to + string entity_id = 2; + + // MountType is the backend mount's type to which this alias belongs to. + // This enables categorically querying aliases of specific backend types. + string mount_type = 3; + + // MountAccessor is the backend mount's accessor to which this alias + // belongs to. + string mount_accessor = 4; + + // MountPath is the backend mount's path to which the Maccessor belongs to. This + // field is not used for any operational purposes. This is only returned when + // alias is read, only as a nicety. + string mount_path = 5; + + // Metadata is the explicit metadata that clients set against an entity + // which enables virtual grouping of aliases. Aliases will be indexed + // against their metadata. + map metadata = 6; + + // Name is the identifier of this alias in its authentication source. + // This does not uniquely identify a alias in Vault. This in conjunction + // with MountAccessor form to be the factors that represent a alias in a + // unique way. Aliases will be indexed based on this combined uniqueness + // factor. + string name = 7; + + // CreationTime is the time at which this alias was first created + google.protobuf.Timestamp creation_time = 8; + + // LastUpdateTime is the most recent time at which the properties of this + // alias got modified. This is helpful in filtering out aliases based + // on its age and to take action on them, if desired. + google.protobuf.Timestamp last_update_time = 9; + + // MergedFromEntityIDs is the FIFO history of merging activity by entity IDs from + // which this alias is transfered over to the entity to which it + // currently belongs to. + repeated string merged_from_entity_ids = 10; +} diff --git a/helper/storagepacker/storagepacker.go b/helper/storagepacker/storagepacker.go new file mode 100644 index 0000000000..b9bdd68c8c --- /dev/null +++ b/helper/storagepacker/storagepacker.go @@ -0,0 +1,351 @@ +package storagepacker + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "strconv" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/compressutil" + "github.com/hashicorp/vault/helper/locksutil" + "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" +) + +const ( + bucketCount = 256 + StoragePackerBucketsPrefix = "packer/buckets/" +) + +// StoragePacker packs the objects into a specific number of buckets by hashing +// its ID and indexing it. Currently this supports only 256 bucket entries and +// hence relies on the first byte of the hash value for indexing. The items +// that gets inserted into the packer should implement StorageBucketItem +// interface. +type StoragePacker struct { + view logical.Storage + logger log.Logger + storageLocks []*locksutil.LockEntry + viewPrefix string +} + +// BucketPath returns the storage entry key for a given bucket key +func (s *StoragePacker) BucketPath(bucketKey string) string { + return s.viewPrefix + bucketKey +} + +// BucketKeyHash returns the MD5 hash of the bucket storage key in which +// the item will be stored. The choice of MD5 is only for hash performance +// reasons since its value is not used for any security sensitive operation. +func (s *StoragePacker) BucketKeyHashByItemID(itemID string) string { + return s.BucketKeyHashByKey(s.BucketPath(s.BucketKey(itemID))) +} + +// BucketKeyHashByKey returns the MD5 hash of the bucket storage key +func (s *StoragePacker) BucketKeyHashByKey(bucketKey string) string { + hf := md5.New() + hf.Write([]byte(bucketKey)) + return hex.EncodeToString(hf.Sum(nil)) +} + +// View returns the storage view configured to be used by the packer +func (s *StoragePacker) View() logical.Storage { + return s.view +} + +// Get returns a bucket for a given key +func (s *StoragePacker) GetBucket(key string) (*Bucket, error) { + if key == "" { + return nil, fmt.Errorf("missing bucket key") + } + + lock := locksutil.LockForKey(s.storageLocks, key) + lock.RLock() + defer lock.RUnlock() + + // Read from the underlying view + storageEntry, err := s.view.Get(key) + if err != nil { + return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err) + } + if storageEntry == nil { + return nil, nil + } + + uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value) + if err != nil { + return nil, errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err) + } + if notCompressed { + uncompressedData = storageEntry.Value + } + + var bucket Bucket + err = proto.Unmarshal(uncompressedData, &bucket) + if err != nil { + return nil, errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err) + } + + return &bucket, nil +} + +// upsert either inserts a new item into the bucket or updates an existing one +// if an item with a matching key is already present. +func (s *Bucket) upsert(item *Item) error { + if s == nil { + return fmt.Errorf("nil storage bucket") + } + + if item == nil { + return fmt.Errorf("nil item") + } + + if item.ID == "" { + return fmt.Errorf("missing item ID") + } + + // Look for an item with matching key and don't modify the collection + // while iterating + foundIdx := -1 + for itemIdx, bucketItems := range s.Items { + if bucketItems.ID == item.ID { + foundIdx = itemIdx + break + } + } + + // If there is no match, append the item, otherwise update it + if foundIdx == -1 { + s.Items = append(s.Items, item) + } else { + s.Items[foundIdx] = item + } + + return nil +} + +// BucketIndex returns the bucket key index for a given storage key +func (s *StoragePacker) BucketIndex(key string) uint8 { + hf := md5.New() + hf.Write([]byte(key)) + return uint8(hf.Sum(nil)[0]) +} + +// BucketKey returns the bucket key for a given item ID +func (s *StoragePacker) BucketKey(itemID string) string { + return strconv.Itoa(int(s.BucketIndex(itemID))) +} + +// DeleteItem removes the storage entry which the given key refers to from its +// corresponding bucket. +func (s *StoragePacker) DeleteItem(itemID string) error { + + if itemID == "" { + return fmt.Errorf("empty item ID") + } + + // Get the bucket key + bucketKey := s.BucketKey(itemID) + + // Prepend the view prefix + bucketPath := s.BucketPath(bucketKey) + + // Read from underlying view + storageEntry, err := s.view.Get(bucketPath) + if err != nil { + return errwrap.Wrapf("failed to read packed storage value: {{err}}", err) + } + if storageEntry == nil { + return nil + } + + uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value) + if err != nil { + return errwrap.Wrapf("failed to decompress packed storage value: {{err}}", err) + } + if notCompressed { + uncompressedData = storageEntry.Value + } + + var bucket Bucket + err = proto.Unmarshal(uncompressedData, &bucket) + if err != nil { + return errwrap.Wrapf("failed decoding packed storage entry: {{err}}", err) + } + + // Look for a matching storage entry + foundIdx := -1 + for itemIdx, item := range bucket.Items { + if item.ID == itemID { + foundIdx = itemIdx + break + } + } + + // If there is a match, remove it from the collection and persist the + // resulting collection + if foundIdx != -1 { + bucket.Items = append(bucket.Items[:foundIdx], bucket.Items[foundIdx+1:]...) + + // Persist bucket entry only if there is an update + err = s.PutBucket(&bucket) + if err != nil { + return err + } + } + + return nil +} + +// Put stores a packed bucket entry +func (s *StoragePacker) PutBucket(bucket *Bucket) error { + if bucket == nil { + return fmt.Errorf("nil bucket entry") + } + + if bucket.Key == "" { + return fmt.Errorf("missing key") + } + + if !strings.HasPrefix(bucket.Key, s.viewPrefix) { + return fmt.Errorf("incorrect prefix; bucket entry key should have %q prefix", s.viewPrefix) + } + + marshaledBucket, err := proto.Marshal(bucket) + if err != nil { + return errwrap.Wrapf("failed to marshal bucket: {{err}}", err) + } + + compressedBucket, err := compressutil.Compress(marshaledBucket, &compressutil.CompressionConfig{ + Type: compressutil.CompressionTypeSnappy, + }) + if err != nil { + return errwrap.Wrapf("failed to compress packed bucket: {{err}}", err) + } + + // Store the compressed value + err = s.view.Put(&logical.StorageEntry{ + Key: bucket.Key, + Value: compressedBucket, + }) + if err != nil { + return errwrap.Wrapf("failed to persist packed storage entry: {{err}}", err) + } + + return nil +} + +// GetItem fetches the storage entry for a given key from its corresponding +// bucket. +func (s *StoragePacker) GetItem(itemID string) (*Item, error) { + if itemID == "" { + return nil, fmt.Errorf("empty item ID") + } + + bucketKey := s.BucketKey(itemID) + bucketPath := s.BucketPath(bucketKey) + + // Fetch the bucket entry + bucket, err := s.GetBucket(bucketPath) + if err != nil { + return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err) + } + + // Look for a matching storage entry in the bucket items + for _, item := range bucket.Items { + if item.ID == itemID { + return item, nil + } + } + + return nil, nil +} + +// PutItem stores a storage entry in its corresponding bucket +func (s *StoragePacker) PutItem(item *Item) error { + if item == nil { + return fmt.Errorf("nil item") + } + + if item.ID == "" { + return fmt.Errorf("missing ID in item") + } + + var err error + bucketKey := s.BucketKey(item.ID) + bucketPath := s.BucketPath(bucketKey) + + bucket := &Bucket{ + Key: bucketPath, + } + + // In this case, we persist the storage entry regardless of the read + // storageEntry below is nil or not. Hence, directly acquire write lock + // even to read the entry. + lock := locksutil.LockForKey(s.storageLocks, bucketPath) + lock.Lock() + defer lock.Unlock() + + // Check if there is an existing bucket for a given key + storageEntry, err := s.view.Get(bucketPath) + if err != nil { + return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err) + } + + if storageEntry == nil { + // If the bucket entry does not exist, this will be the only item the + // bucket that is going to be persisted. + bucket.Items = []*Item{ + item, + } + } else { + uncompressedData, notCompressed, err := compressutil.Decompress(storageEntry.Value) + if err != nil { + return errwrap.Wrapf("failed to decompress packed storage entry: {{err}}", err) + } + if notCompressed { + uncompressedData = storageEntry.Value + } + + err = proto.Unmarshal(uncompressedData, bucket) + if err != nil { + return errwrap.Wrapf("failed to decode packed storage entry: {{err}}", err) + } + + err = bucket.upsert(item) + if err != nil { + return errwrap.Wrapf("failed to update entry in packed storage entry: {{err}}", err) + } + } + + // Persist the result + return s.PutBucket(bucket) +} + +// NewStoragePacker creates a new storage packer for a given view +func NewStoragePacker(view logical.Storage, logger log.Logger, viewPrefix string) (*StoragePacker, error) { + if view == nil { + return nil, fmt.Errorf("nil view") + } + + if viewPrefix == "" { + viewPrefix = StoragePackerBucketsPrefix + } + + if !strings.HasSuffix(viewPrefix, "/") { + viewPrefix = viewPrefix + "/" + } + + // Create a new packer object for the given view + packer := &StoragePacker{ + view: view, + viewPrefix: viewPrefix, + logger: logger, + storageLocks: locksutil.CreateLocks(), + } + + return packer, nil +} diff --git a/helper/storagepacker/storagepacker_test.go b/helper/storagepacker/storagepacker_test.go new file mode 100644 index 0000000000..9f8f287763 --- /dev/null +++ b/helper/storagepacker/storagepacker_test.go @@ -0,0 +1,172 @@ +package storagepacker + +import ( + "reflect" + "testing" + + "github.com/golang/protobuf/ptypes" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" +) + +func BenchmarkStoragePacker(b *testing.B) { + storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New("storagepackertest"), "") + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + itemID, err := uuid.GenerateUUID() + if err != nil { + b.Fatal(err) + } + + item := &Item{ + ID: itemID, + } + + err = storagePacker.PutItem(item) + if err != nil { + b.Fatal(err) + } + + fetchedItem, err := storagePacker.GetItem(itemID) + if err != nil { + b.Fatal(err) + } + + if fetchedItem == nil { + b.Fatalf("failed to read stored item with ID: %q, iteration: %d", item.ID, i) + } + + if fetchedItem.ID != item.ID { + b.Fatalf("bad: item ID; expected: %q\n actual: %q", item.ID, fetchedItem.ID) + } + + err = storagePacker.DeleteItem(item.ID) + if err != nil { + b.Fatal(err) + } + + fetchedItem, err = storagePacker.GetItem(item.ID) + if err != nil { + b.Fatal(err) + } + if fetchedItem != nil { + b.Fatalf("failed to delete item") + } + } +} + +func TestStoragePacker(t *testing.T) { + storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New("storagepackertest"), "") + if err != nil { + t.Fatal(err) + } + + // Persist a storage entry + item1 := &Item{ + ID: "item1", + } + + err = storagePacker.PutItem(item1) + if err != nil { + t.Fatal(err) + } + + // Verify that it can be read + fetchedItem, err := storagePacker.GetItem(item1.ID) + if err != nil { + t.Fatal(err) + } + if fetchedItem == nil { + t.Fatalf("failed to read the stored item") + } + + if item1.ID != fetchedItem.ID { + t.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item1.ID, fetchedItem.ID) + } + + // Delete item1 + err = storagePacker.DeleteItem(item1.ID) + if err != nil { + t.Fatal(err) + } + + // Check that the deletion was successful + fetchedItem, err = storagePacker.GetItem(item1.ID) + if err != nil { + t.Fatal(err) + } + + if fetchedItem != nil { + t.Fatalf("failed to delete item") + } +} + +func TestStoragePacker_SerializeDeserializeComplexItem(t *testing.T) { + storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New("storagepackertest"), "") + if err != nil { + t.Fatal(err) + } + + timeNow := ptypes.TimestampNow() + + alias1 := &identity.Alias{ + ID: "alias_id", + EntityID: "entity_id", + MountType: "mount_type", + MountAccessor: "mount_accessor", + Metadata: map[string]string{ + "aliasmkey": "aliasmvalue", + }, + Name: "alias_name", + CreationTime: timeNow, + LastUpdateTime: timeNow, + MergedFromEntityIDs: []string{"merged_from_entity_id"}, + } + + entity := &identity.Entity{ + Aliases: []*identity.Alias{alias1}, + ID: "entity_id", + Name: "entity_name", + Metadata: map[string]string{ + "testkey1": "testvalue1", + "testkey2": "testvalue2", + }, + CreationTime: timeNow, + LastUpdateTime: timeNow, + BucketKeyHash: "entity_hash", + MergedEntityIDs: []string{"merged_entity_id1", "merged_entity_id2"}, + Policies: []string{"policy1", "policy2"}, + } + + marshaledEntity, err := ptypes.MarshalAny(entity) + if err != nil { + t.Fatal(err) + } + err = storagePacker.PutItem(&Item{ + ID: entity.ID, + Message: marshaledEntity, + }) + if err != nil { + t.Fatal(err) + } + + itemFetched, err := storagePacker.GetItem(entity.ID) + if err != nil { + t.Fatal(err) + } + + var itemDecoded identity.Entity + err = ptypes.UnmarshalAny(itemFetched.Message, &itemDecoded) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(&itemDecoded, entity) { + t.Fatalf("bad: expected: %#v\nactual: %#v\n", entity, itemDecoded) + } +} diff --git a/helper/storagepacker/types.pb.go b/helper/storagepacker/types.pb.go new file mode 100644 index 0000000000..251989025a --- /dev/null +++ b/helper/storagepacker/types.pb.go @@ -0,0 +1,101 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: types.proto + +/* +Package storagepacker is a generated protocol buffer package. + +It is generated from these files: + types.proto + +It has these top-level messages: + Item + Bucket +*/ +package storagepacker + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/golang/protobuf/ptypes/any" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Item struct { + ID string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"` + Message *google_protobuf.Any `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` +} + +func (m *Item) Reset() { *m = Item{} } +func (m *Item) String() string { return proto.CompactTextString(m) } +func (*Item) ProtoMessage() {} +func (*Item) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *Item) GetID() string { + if m != nil { + return m.ID + } + return "" +} + +func (m *Item) GetMessage() *google_protobuf.Any { + if m != nil { + return m.Message + } + return nil +} + +type Bucket struct { + Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"` + Items []*Item `protobuf:"bytes,2,rep,name=items" json:"items,omitempty"` +} + +func (m *Bucket) Reset() { *m = Bucket{} } +func (m *Bucket) String() string { return proto.CompactTextString(m) } +func (*Bucket) ProtoMessage() {} +func (*Bucket) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *Bucket) GetKey() string { + if m != nil { + return m.Key + } + return "" +} + +func (m *Bucket) GetItems() []*Item { + if m != nil { + return m.Items + } + return nil +} + +func init() { + proto.RegisterType((*Item)(nil), "storagepacker.Item") + proto.RegisterType((*Bucket)(nil), "storagepacker.Bucket") +} + +func init() { proto.RegisterFile("types.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 181 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2e, 0xa9, 0x2c, 0x48, + 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x2d, 0x2e, 0xc9, 0x2f, 0x4a, 0x4c, 0x4f, + 0x2d, 0x48, 0x4c, 0xce, 0x4e, 0x2d, 0x92, 0x92, 0x4c, 0xcf, 0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, + 0x4b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x54, 0x2a, 0xb9, 0x71, 0xb1, 0x78, 0x96, + 0xa4, 0xe6, 0x0a, 0xf1, 0x71, 0x31, 0x65, 0xa6, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x70, 0x06, 0x31, + 0x65, 0xa6, 0x08, 0xe9, 0x71, 0xb1, 0xe7, 0xa6, 0x16, 0x17, 0x27, 0xa6, 0xa7, 0x4a, 0x30, 0x29, + 0x30, 0x6a, 0x70, 0x1b, 0x89, 0xe8, 0x41, 0x0c, 0xd1, 0x83, 0x19, 0xa2, 0xe7, 0x98, 0x57, 0x19, + 0x04, 0x53, 0xa4, 0xe4, 0xca, 0xc5, 0xe6, 0x54, 0x9a, 0x9c, 0x9d, 0x5a, 0x22, 0x24, 0xc0, 0xc5, + 0x9c, 0x9d, 0x5a, 0x09, 0x35, 0x0a, 0xc4, 0x14, 0xd2, 0xe4, 0x62, 0xcd, 0x2c, 0x49, 0xcd, 0x2d, + 0x96, 0x60, 0x52, 0x60, 0xd6, 0xe0, 0x36, 0x12, 0xd6, 0x43, 0x71, 0x9d, 0x1e, 0xc8, 0xfe, 0x20, + 0x88, 0x8a, 0x24, 0x36, 0xb0, 0xe9, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x30, 0x77, + 0x9a, 0xce, 0x00, 0x00, 0x00, +} diff --git a/helper/storagepacker/types.proto b/helper/storagepacker/types.proto new file mode 100644 index 0000000000..11c386b002 --- /dev/null +++ b/helper/storagepacker/types.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package storagepacker; + +import "google/protobuf/any.proto"; + +message Item { + string id = 1; + google.protobuf.Any message = 2; +} + +message Bucket { + string key = 1; + repeated Item items = 2; +} diff --git a/helper/wrapping/wrapinfo.go b/helper/wrapping/wrapinfo.go index 2242c7b309..6a9fa129db 100644 --- a/helper/wrapping/wrapinfo.go +++ b/helper/wrapping/wrapinfo.go @@ -18,6 +18,10 @@ type ResponseWrapInfo struct { // created token's accessor will be accessible here WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"` + // WrappedEntityID is the entity identifier of the caller who initiated the + // wrapping request + WrappedEntityID string `json:"wrapped_entity_id" structs:"wrapped_entity_id" mapstructure:"wrapped_entity_id"` + // The format to use. This doesn't get returned, it's only internal. Format string `json:"format" structs:"format" mapstructure:"format"` diff --git a/http/handler_test.go b/http/handler_test.go index 8eae984cca..514890f42e 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -186,6 +186,16 @@ func TestSysMounts_headerAuth(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "secret/": map[string]interface{}{ "description": "key/value secret storage", @@ -217,6 +227,16 @@ func TestSysMounts_headerAuth(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) diff --git a/http/logical_test.go b/http/logical_test.go index e4101a50bf..7f3ff66aeb 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -157,6 +157,7 @@ func TestLogical_StandbyRedirect(t *testing.T) { "creation_ttl": json.Number("0"), "explicit_max_ttl": json.Number("0"), "expire_time": nil, + "entity_id": "", }, "warnings": nilWarnings, "wrap_info": nil, @@ -206,6 +207,7 @@ func TestLogical_CreateToken(t *testing.T) { "metadata": nil, "lease_duration": json.Number("0"), "renewable": false, + "entity_id": "", }, "warnings": nilWarnings, } diff --git a/http/sys_generate_root_test.go b/http/sys_generate_root_test.go index 41cb2a540c..347dd2e43a 100644 --- a/http/sys_generate_root_test.go +++ b/http/sys_generate_root_test.go @@ -313,6 +313,7 @@ func TestSysGenerateRoot_Update_OTP(t *testing.T) { "path": "auth/token/root", "explicit_max_ttl": json.Number("0"), "expire_time": nil, + "entity_id": "", } resp = testHttpGet(t, newRootToken, addr+"/v1/auth/token/lookup-self") @@ -403,6 +404,7 @@ func TestSysGenerateRoot_Update_PGP(t *testing.T) { "path": "auth/token/root", "explicit_max_ttl": json.Number("0"), "expire_time": nil, + "entity_id": "", } resp = testHttpGet(t, newRootToken, addr+"/v1/auth/token/lookup-self") diff --git a/http/sys_mount_test.go b/http/sys_mount_test.go index 57f6dd7728..46135ec2ec 100644 --- a/http/sys_mount_test.go +++ b/http/sys_mount_test.go @@ -56,6 +56,16 @@ func TestSysMounts(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "secret/": map[string]interface{}{ "description": "key/value secret storage", @@ -87,6 +97,16 @@ func TestSysMounts(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -167,6 +187,16 @@ func TestSysMount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "foo/": map[string]interface{}{ "description": "foo", @@ -208,6 +238,16 @@ func TestSysMount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -310,6 +350,16 @@ func TestSysRemount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "bar/": map[string]interface{}{ "description": "foo", @@ -351,6 +401,16 @@ func TestSysRemount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -424,6 +484,16 @@ func TestSysUnmount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "secret/": map[string]interface{}{ "description": "key/value secret storage", @@ -455,6 +525,16 @@ func TestSysUnmount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -535,6 +615,16 @@ func TestSysTuneMount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "foo/": map[string]interface{}{ "description": "foo", @@ -576,6 +666,16 @@ func TestSysTuneMount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -677,6 +777,16 @@ func TestSysTuneMount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, }, "foo/": map[string]interface{}{ "description": "foo", @@ -718,6 +828,16 @@ func TestSysTuneMount(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "config": map[string]interface{}{ + "default_lease_ttl": json.Number("0"), + "max_lease_ttl": json.Number("0"), + "force_no_cache": false, + }, + "local": false, + }, } testResponseStatus(t, resp, 200) diff --git a/logical/auth.go b/logical/auth.go index cfb52f7b8a..c5d184da89 100644 --- a/logical/auth.go +++ b/logical/auth.go @@ -52,6 +52,10 @@ type Auth struct { // Number of allowed uses of the issued token NumUses int `json:"num_uses" mapstructure:"num_uses" structs:"num_uses"` + // EntityID is the identifier of the entity in identity store to which the + // identity of the authenticating client belongs to. + EntityID string `json:"entity_id" mapstructure:"entity_id" structs:"entity_id"` + // Alias is the information about the authenticated client returned by // the auth backend Alias *Alias `json:"alias" structs:"alias" mapstructure:"alias"` diff --git a/logical/identity.go b/logical/identity.go index 0ba62bf8ef..a49017a3eb 100644 --- a/logical/identity.go +++ b/logical/identity.go @@ -3,24 +3,14 @@ package logical // Alias represents the information used by core to create implicit entity. // Implicit entities get created when a client authenticates successfully from // any of the authentication backends (except token backend). -// -// This is applicable to enterprise binaries only. Alias should be set in the -// Auth response returned by the credential backends. This structure is placed -// in the open source repository only to enable custom authetication plugins to -// be used along with enterprise binary. The custom auth plugins should make -// use of this and fill out the Alias information in the authentication -// response. type Alias struct { // MountType is the backend mount's type to which this identity belongs - // to. MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"` - // MountAccessor is the identifier of the mount entry to which - // this identity - // belongs to. + // MountAccessor is the identifier of the mount entry to which this + // identity belongs MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"` - // Name is the identifier of this identity in its - // authentication source. + // Name is the identifier of this identity in its authentication source Name string `json:"name" structs:"name" mapstructure:"name"` } diff --git a/logical/request.go b/logical/request.go index 561fbbc0fb..a1afeba2b2 100644 --- a/logical/request.go +++ b/logical/request.go @@ -87,6 +87,10 @@ type Request struct { // aliases, generating different defaults depending on the alias) MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"` + // MountAccessor is provided so that identities returned by the authentication + // backends can be tied to the mount it belongs to. + MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"` + // WrapInfo contains requested response wrapping parameters WrapInfo *RequestWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"` @@ -94,6 +98,10 @@ type Request struct { // token supplied ClientTokenRemainingUses int `json:"client_token_remaining_uses" structs:"client_token_remaining_uses" mapstructure:"client_token_remaining_uses"` + // EntityID is the identity of the caller extracted out of the token used + // to make this request + EntityID string `json:"entity_id" structs:"entity_id" mapstructure:"entity_id"` + // For replication, contains the last WAL on the remote side after handling // the request, used for best-effort avoidance of stale read-after-write lastRemoteWAL uint64 diff --git a/logical/translate_response.go b/logical/translate_response.go index d3d727163c..bf2dae9d1e 100644 --- a/logical/translate_response.go +++ b/logical/translate_response.go @@ -33,6 +33,7 @@ func LogicalResponseToHTTPResponse(input *Response) *HTTPResponse { Metadata: input.Auth.Metadata, LeaseDuration: int(input.Auth.TTL.Seconds()), Renewable: input.Auth.Renewable, + EntityID: input.Auth.EntityID, } } @@ -59,6 +60,7 @@ func HTTPResponseToLogicalResponse(input *HTTPResponse) *Response { Accessor: input.Auth.Accessor, Policies: input.Auth.Policies, Metadata: input.Auth.Metadata, + EntityID: input.Auth.EntityID, } logicalResp.Auth.Renewable = input.Auth.Renewable logicalResp.Auth.TTL = time.Second * time.Duration(input.Auth.LeaseDuration) @@ -85,6 +87,7 @@ type HTTPAuth struct { Metadata map[string]string `json:"metadata"` LeaseDuration int `json:"lease_duration"` Renewable bool `json:"renewable"` + EntityID string `json:"entity_id"` } type HTTPWrapInfo struct { diff --git a/vault/core.go b/vault/core.go index 1259c03638..d6072249ed 100644 --- a/vault/core.go +++ b/vault/core.go @@ -26,6 +26,7 @@ import ( "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/errutil" + "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/mlock" @@ -240,6 +241,9 @@ type Core struct { // can be output in the audit logs auditedHeaders *AuditedHeadersConfig + // systemBackend is the backend which is used to manage internal operations + systemBackend *SystemBackend + // systemBarrierView is the barrier view for the system backend systemBarrierView *BarrierView @@ -256,6 +260,9 @@ type Core struct { // token store is used to manage authentication tokens tokenStore *TokenStore + // identityStore is used to manage client entities + identityStore *IdentityStore + // metricsCh is used to stop the metrics streaming metricsCh chan struct{} @@ -551,6 +558,11 @@ func NewCore(conf *CoreConfig) (*Core, error) { } return b, nil } + + logicalBackends["identity"] = func(config *logical.BackendConfig) (logical.Backend, error) { + return NewIdentityStore(c, config) + } + c.logicalBackends = logicalBackends credentialBackends := make(map[string]logical.Factory) @@ -634,45 +646,88 @@ func (c *Core) LookupToken(token string) (*TokenEntry, error) { return c.tokenStore.Lookup(token) } -func (c *Core) fetchACLandTokenEntry(req *logical.Request) (*ACL, *TokenEntry, error) { +func (c *Core) fetchACLTokenEntryAndEntity(clientToken string) (*ACL, *TokenEntry, *identity.Entity, error) { defer metrics.MeasureSince([]string{"core", "fetch_acl_and_token"}, time.Now()) // Ensure there is a client token - if req.ClientToken == "" { - return nil, nil, fmt.Errorf("missing client token") + if clientToken == "" { + return nil, nil, nil, fmt.Errorf("missing client token") } if c.tokenStore == nil { c.logger.Error("core: token store is unavailable") - return nil, nil, ErrInternalError + return nil, nil, nil, ErrInternalError } // Resolve the token policy - te, err := c.tokenStore.Lookup(req.ClientToken) + te, err := c.tokenStore.Lookup(clientToken) if err != nil { c.logger.Error("core: failed to lookup token", "error", err) - return nil, nil, ErrInternalError + return nil, nil, nil, ErrInternalError } // Ensure the token is valid if te == nil { - return nil, nil, logical.ErrPermissionDenied + return nil, nil, nil, logical.ErrPermissionDenied + } + + tokenPolicies := te.Policies + + var entity *identity.Entity + + // Append the policies of the entity to those on the tokens and create ACL + // off of the combined list. + if te.EntityID != "" { + //c.logger.Debug("core: entity set on the token", "entity_id", te.EntityID) + // Fetch entity for the entity ID in the token entry + entity, err = c.identityStore.memDBEntityByID(te.EntityID, false) + if err != nil { + c.logger.Error("core: failed to lookup entity using its ID", "error", err) + return nil, nil, nil, ErrInternalError + } + + if entity == nil { + // If there was no corresponding entity object found, it is + // possible that the entity got merged into another entity. Try + // finding entity based on the merged entity index. + entity, err = c.identityStore.memDBEntityByMergedEntityID(te.EntityID, false) + if err != nil { + c.logger.Error("core: failed to lookup entity in merged entity ID index", "error", err) + return nil, nil, nil, ErrInternalError + } + } + + if entity != nil { + //c.logger.Debug("core: entity successfully fetched; adding entity policies to token's policies to create ACL") + // Attach the policies on the entity to the policies tied to the token + tokenPolicies = append(tokenPolicies, entity.Policies...) + + groupPolicies, err := c.identityStore.groupPoliciesByEntityID(entity.ID) + if err != nil { + c.logger.Error("core: failed to fetch group policies", "error", err) + return nil, nil, nil, ErrInternalError + } + + // Attach the policies from all the groups to which this entity ID + // belongs to + tokenPolicies = append(tokenPolicies, groupPolicies...) + } } // Construct the corresponding ACL object - acl, err := c.policyStore.ACL(te.Policies...) + acl, err := c.policyStore.ACL(tokenPolicies...) if err != nil { c.logger.Error("core: failed to construct ACL", "error", err) - return nil, nil, ErrInternalError + return nil, nil, nil, ErrInternalError } - return acl, te, nil + return acl, te, entity, nil } func (c *Core) checkToken(req *logical.Request) (*logical.Auth, *TokenEntry, error) { defer metrics.MeasureSince([]string{"core", "check_token"}, time.Now()) - acl, te, err := c.fetchACLandTokenEntry(req) + acl, te, _, err := c.fetchACLTokenEntryAndEntity(req.ClientToken) if err != nil { return nil, te, err } @@ -720,9 +775,15 @@ func (c *Core) checkToken(req *logical.Request) (*logical.Auth, *TokenEntry, err auth := &logical.Auth{ ClientToken: req.ClientToken, Accessor: req.ClientTokenAccessor, - Policies: te.Policies, - Metadata: te.Meta, - DisplayName: te.DisplayName, + } + + if te != nil { + auth.Policies = te.Policies + auth.Metadata = te.Meta + auth.DisplayName = te.DisplayName + auth.EntityID = te.EntityID + // Store the entity ID in the request object + req.EntityID = te.EntityID } // Check the standard non-root ACLs. Return the token entry if it's not @@ -1086,7 +1147,7 @@ func (c *Core) sealInitCommon(req *logical.Request) (retErr error) { } // Validate the token is a root token - acl, te, err := c.fetchACLandTokenEntry(req) + acl, te, _, err := c.fetchACLTokenEntryAndEntity(req.ClientToken) if err != nil { // Since there is no token store in standby nodes, sealing cannot // be done. Ideally, the request has to be forwarded to leader node @@ -1204,7 +1265,7 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) { return nil } - acl, te, err := c.fetchACLandTokenEntry(req) + acl, te, _, err := c.fetchACLTokenEntryAndEntity(req.ClientToken) if err != nil { retErr = multierror.Append(retErr, err) return retErr diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index b5e477af90..9d546e7628 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -109,7 +109,7 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim resp.WrapInfo.Format = "jwt" } - _, err := d.core.wrapInCubbyhole(req, resp) + _, err := d.core.wrapInCubbyhole(req, resp, nil) if err != nil { return nil, err } diff --git a/vault/identity_lookup.go b/vault/identity_lookup.go new file mode 100644 index 0000000000..5ce3156a77 --- /dev/null +++ b/vault/identity_lookup.go @@ -0,0 +1,78 @@ +package vault + +import ( + "fmt" + "strings" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func lookupPaths(i *IdentityStore) []*framework.Path { + return []*framework.Path{ + { + Pattern: "lookup/group$", + Fields: map[string]*framework.FieldSchema{ + "type": { + Type: framework.TypeString, + Description: "Type of lookup. Current supported values are 'by_id' and 'by_name'", + }, + "group_name": { + Type: framework.TypeString, + Description: "Name of the group.", + }, + "group_id": { + Type: framework.TypeString, + Description: "ID of the group.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathLookupGroupUpdate), + }, + + HelpSynopsis: strings.TrimSpace(lookupHelp["lookup-group"][0]), + HelpDescription: strings.TrimSpace(lookupHelp["lookup-group"][1]), + }, + } +} + +func (i *IdentityStore) pathLookupGroupUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + lookupType := d.Get("type").(string) + if lookupType == "" { + return logical.ErrorResponse("empty type"), nil + } + + switch lookupType { + case "by_id": + groupID := d.Get("group_id").(string) + if groupID == "" { + return logical.ErrorResponse("empty group_id"), nil + } + group, err := i.memDBGroupByID(groupID, false) + if err != nil { + return nil, err + } + return i.handleGroupReadCommon(group) + case "by_name": + groupName := d.Get("group_name").(string) + if groupName == "" { + return logical.ErrorResponse("empty group_name"), nil + } + group, err := i.memDBGroupByName(groupName, false) + if err != nil { + return nil, err + } + return i.handleGroupReadCommon(group) + default: + return logical.ErrorResponse(fmt.Sprintf("unrecognized type %q", lookupType)), nil + } + + return nil, nil +} + +var lookupHelp = map[string][2]string{ + "lookup-group": { + "Query groups based on factors.", + "Currently this supports querying groups by its name or ID.", + }, +} diff --git a/vault/identity_store.go b/vault/identity_store.go new file mode 100644 index 0000000000..8e9ac9eeae --- /dev/null +++ b/vault/identity_store.go @@ -0,0 +1,334 @@ +package vault + +import ( + "fmt" + "strings" + + "github.com/golang/protobuf/ptypes" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/helper/locksutil" + "github.com/hashicorp/vault/helper/storagepacker" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + "github.com/hashicorp/vault/version" +) + +const ( + groupBucketsPrefix = "packer/group/buckets/" +) + +// NewIdentityStore creates a new identity store +func NewIdentityStore(core *Core, config *logical.BackendConfig) (*IdentityStore, error) { + var err error + + // Create a new in-memory database for the identity store + db, err := memdb.NewMemDB(identityStoreSchema()) + if err != nil { + return nil, fmt.Errorf("failed to create memdb for identity store: %v", err) + } + + iStore := &IdentityStore{ + view: config.StorageView, + db: db, + entityLocks: locksutil.CreateLocks(), + logger: core.logger, + validateMountAccessorFunc: core.router.validateMountByAccessor, + } + + iStore.entityPacker, err = storagepacker.NewStoragePacker(iStore.view, iStore.logger, "") + if err != nil { + return nil, fmt.Errorf("failed to create entity packer: %v", err) + } + + iStore.groupPacker, err = storagepacker.NewStoragePacker(iStore.view, iStore.logger, groupBucketsPrefix) + if err != nil { + return nil, fmt.Errorf("failed to create group packer: %v", err) + } + + iStore.Backend = &framework.Backend{ + BackendType: logical.TypeLogical, + Paths: framework.PathAppend( + entityPaths(iStore), + aliasPaths(iStore), + groupPaths(iStore), + lookupPaths(iStore), + upgradePaths(iStore), + ), + Invalidate: iStore.Invalidate, + } + + err = iStore.Setup(config) + if err != nil { + return nil, err + } + + return iStore, nil +} + +func (i *IdentityStore) checkPremiumVersion(f framework.OperationFunc) framework.OperationFunc { + ver := version.GetVersion() + if ver.VersionMetadata == "pro" { + return func(*logical.Request, *framework.FieldData) (*logical.Response, error) { + return logical.ErrorResponse("identity features not available in the pro version"), logical.ErrInvalidRequest + } + } + return f +} + +// Invalidate is a callback wherein the backend is informed that the value at +// the given key is updated. In identity store's case, it would be the entity +// storage entries that get updated. The value needs to be read and MemDB needs +// to be updated accordingly. +func (i *IdentityStore) Invalidate(key string) { + i.logger.Debug("identity: invalidate notification received", "key", key) + + switch { + // Check if the key is a storage entry key for an entity bucket + case strings.HasPrefix(key, storagepacker.StoragePackerBucketsPrefix): + // Get the hash value of the storage bucket entry key + bucketKeyHash := i.entityPacker.BucketKeyHashByKey(key) + if len(bucketKeyHash) == 0 { + i.logger.Error("failed to get the bucket entry key hash") + return + } + + // Create a MemDB transaction + txn := i.db.Txn(true) + defer txn.Abort() + + // Each entity object in MemDB holds the MD5 hash of the storage + // entry key of the entity bucket. Fetch all the entities that + // belong to this bucket using the hash value. Remove these entities + // from MemDB along with all the aliases of each entity. + entitiesFetched, err := i.memDBEntitiesByBucketEntryKeyHashInTxn(txn, string(bucketKeyHash)) + if err != nil { + i.logger.Error("failed to fetch entities using the bucket entry key hash", "bucket_entry_key_hash", bucketKeyHash) + return + } + + for _, entity := range entitiesFetched { + // Delete all the aliases in the entity. This function will also remove + // the corresponding alias indexes too. + err = i.deleteAliasesInEntityInTxn(txn, entity, entity.Aliases) + if err != nil { + i.logger.Error("failed to delete aliases in entity", "entity_id", entity.ID, "error", err) + return + } + + // Delete the entity using the same transaction + err = i.memDBDeleteEntityByIDInTxn(txn, entity.ID) + if err != nil { + i.logger.Error("failed to delete entity from MemDB", "entity_id", entity.ID, "error", err) + return + } + } + + // Get the storage bucket entry + bucket, err := i.entityPacker.GetBucket(key) + if err != nil { + i.logger.Error("failed to refresh entities", "key", key, "error", err) + return + } + + // If the underlying entry is nil, it means that this invalidation + // notification is for the deletion of the underlying storage entry. At + // this point, since all the entities belonging to this bucket are + // already removed, there is nothing else to be done. But, if the + // storage entry is non-nil, its an indication of an update. In this + // case, entities in the updated bucket needs to be reinserted into + // MemDB. + if bucket != nil { + for _, item := range bucket.Items { + entity, err := i.parseEntityFromBucketItem(item) + if err != nil { + i.logger.Error("failed to parse entity from bucket entry item", "error", err) + return + } + + // Only update MemDB and don't touch the storage + err = i.upsertEntityInTxn(txn, entity, nil, false, false) + if err != nil { + i.logger.Error("failed to update entity in MemDB", "error", err) + return + } + } + } + + txn.Commit() + return + + // Check if the key is a storage entry key for an group bucket + case strings.HasPrefix(key, groupBucketsPrefix): + // Get the hash value of the storage bucket entry key + bucketKeyHash := i.groupPacker.BucketKeyHashByKey(key) + if len(bucketKeyHash) == 0 { + i.logger.Error("failed to get the bucket entry key hash") + return + } + + // Create a MemDB transaction + txn := i.db.Txn(true) + defer txn.Abort() + + groupsFetched, err := i.memDBGroupsByBucketEntryKeyHashInTxn(txn, string(bucketKeyHash)) + if err != nil { + i.logger.Error("failed to fetch groups using the bucket entry key hash", "bucket_entry_key_hash", bucketKeyHash) + return + } + + for _, group := range groupsFetched { + // Delete the group using the same transaction + err = i.memDBDeleteGroupByIDInTxn(txn, group.ID) + if err != nil { + i.logger.Error("failed to delete group from MemDB", "group_id", group.ID, "error", err) + return + } + } + + // Get the storage bucket entry + bucket, err := i.groupPacker.GetBucket(key) + if err != nil { + i.logger.Error("failed to refresh group", "key", key, "error", err) + return + } + + if bucket != nil { + for _, item := range bucket.Items { + group, err := i.parseGroupFromBucketItem(item) + if err != nil { + i.logger.Error("failed to parse group from bucket entry item", "error", err) + return + } + + // Only update MemDB and don't touch the storage + err = i.upsertGroupInTxn(txn, group, false) + if err != nil { + i.logger.Error("failed to update group in MemDB", "error", err) + return + } + } + } + + txn.Commit() + return + } +} + +func (i *IdentityStore) parseEntityFromBucketItem(item *storagepacker.Item) (*identity.Entity, error) { + if item == nil { + return nil, fmt.Errorf("nil item") + } + + var entity identity.Entity + err := ptypes.UnmarshalAny(item.Message, &entity) + if err != nil { + return nil, fmt.Errorf("failed to decode entity from storage bucket item: %v", err) + } + + return &entity, nil +} + +func (i *IdentityStore) parseGroupFromBucketItem(item *storagepacker.Item) (*identity.Group, error) { + if item == nil { + return nil, fmt.Errorf("nil item") + } + + var group identity.Group + err := ptypes.UnmarshalAny(item.Message, &group) + if err != nil { + return nil, fmt.Errorf("failed to decode group from storage bucket item: %v", err) + } + + return &group, nil +} + +// EntityByAliasFactors fetches the entity based on factors of alias, i.e mount +// accessor and the alias name. +func (i *IdentityStore) EntityByAliasFactors(mountAccessor, aliasName string, clone bool) (*identity.Entity, error) { + if mountAccessor == "" { + return nil, fmt.Errorf("missing mount accessor") + } + + if aliasName == "" { + return nil, fmt.Errorf("missing alias name") + } + + alias, err := i.memDBAliasByFactors(mountAccessor, aliasName, false) + if err != nil { + return nil, err + } + + if alias == nil { + return nil, nil + } + + return i.memDBEntityByAliasID(alias.ID, clone) +} + +// CreateEntity creates a new entity. This is used by core to +// associate each login attempt by a alias to a unified entity in Vault. +func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) { + var entity *identity.Entity + var err error + + if alias == nil { + return nil, fmt.Errorf("alias is nil") + } + + if alias.Name == "" { + return nil, fmt.Errorf("empty alias name") + } + + mountValidationResp := i.validateMountAccessorFunc(alias.MountAccessor) + if mountValidationResp == nil { + return nil, fmt.Errorf("invalid mount accessor %q", alias.MountAccessor) + } + + if mountValidationResp.MountType != alias.MountType { + return nil, fmt.Errorf("mount accessor %q is not a mount of type %q", alias.MountAccessor, alias.MountType) + } + + // Check if an entity already exists for the given alais + entity, err = i.EntityByAliasFactors(alias.MountAccessor, alias.Name, false) + if err != nil { + return nil, err + } + if entity != nil { + return nil, fmt.Errorf("alias already belongs to a different entity") + } + + entity = &identity.Entity{} + + err = i.sanitizeEntity(entity) + if err != nil { + return nil, err + } + + // Create a new alias + newAlias := &identity.Alias{ + EntityID: entity.ID, + Name: alias.Name, + MountAccessor: alias.MountAccessor, + MountPath: mountValidationResp.MountPath, + MountType: mountValidationResp.MountType, + } + + err = i.sanitizeAlias(newAlias) + if err != nil { + return nil, err + } + + // Append the new alias to the new entity + entity.Aliases = []*identity.Alias{ + newAlias, + } + + // Update MemDB and persist entity object + err = i.upsertEntity(entity, nil, true) + if err != nil { + return nil, err + } + + return entity, nil +} diff --git a/vault/identity_store_aliases.go b/vault/identity_store_aliases.go new file mode 100644 index 0000000000..43d46ff29a --- /dev/null +++ b/vault/identity_store_aliases.go @@ -0,0 +1,364 @@ +package vault + +import ( + "fmt" + "strings" + + "github.com/golang/protobuf/ptypes" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +// aliasPaths returns the API endpoints to operate on aliases. +// Following are the paths supported: +// alias - To register/modify a alias +// alias/id - To lookup, delete and list aliases based on ID +func aliasPaths(i *IdentityStore) []*framework.Path { + return []*framework.Path{ + { + Pattern: "alias$", + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the alias", + }, + "entity_id": { + Type: framework.TypeString, + Description: "Entity ID to which this alias belongs to", + }, + "mount_accessor": { + Type: framework.TypeString, + Description: "Mount accessor to which this alias belongs to", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the alias", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the alias. Format should be a list of `key=value` pairs.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasRegister), + }, + + HelpSynopsis: strings.TrimSpace(aliasHelp["alias"][0]), + HelpDescription: strings.TrimSpace(aliasHelp["alias"][1]), + }, + { + Pattern: "alias/id/" + framework.GenericNameRegex("id"), + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the alias", + }, + "entity_id": { + Type: framework.TypeString, + Description: "Entity ID to which this alias should be tied to", + }, + "mount_accessor": { + Type: framework.TypeString, + Description: "Mount accessor to which this alias belongs to", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the alias", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the alias. Format should be a comma separated list of `key=value` pairs.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasIDUpdate), + logical.ReadOperation: i.checkPremiumVersion(i.pathAliasIDRead), + logical.DeleteOperation: i.checkPremiumVersion(i.pathAliasIDDelete), + }, + + HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id"][0]), + HelpDescription: strings.TrimSpace(aliasHelp["alias-id"][1]), + }, + { + Pattern: "alias/id/?$", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: i.checkPremiumVersion(i.pathAliasIDList), + }, + + HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id-list"][0]), + HelpDescription: strings.TrimSpace(aliasHelp["alias-id-list"][1]), + }, + } +} + +// pathAliasRegister is used to register new alias +func (i *IdentityStore) pathAliasRegister(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + _, ok := d.GetOk("id") + if ok { + return i.pathAliasIDUpdate(req, d) + } + + return i.handleAliasUpdateCommon(req, d, nil) +} + +// pathAliasIDUpdate is used to update a alias based on the given +// alias ID +func (i *IdentityStore) pathAliasIDUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get alias id + aliasID := d.Get("id").(string) + + if aliasID == "" { + return logical.ErrorResponse("missing alias ID"), nil + } + + alias, err := i.memDBAliasByID(aliasID, true) + if err != nil { + return nil, err + } + if alias == nil { + return logical.ErrorResponse("invalid alias id"), nil + } + + return i.handleAliasUpdateCommon(req, d, alias) +} + +// handleAliasUpdateCommon is used to update a alias +func (i *IdentityStore) handleAliasUpdateCommon(req *logical.Request, d *framework.FieldData, alias *identity.Alias) (*logical.Response, error) { + var err error + var newAlias bool + var entity *identity.Entity + var previousEntity *identity.Entity + + // Alias will be nil when a new alias is being registered; create a + // new struct in that case. + if alias == nil { + alias = &identity.Alias{} + newAlias = true + } + + // Get entity id + entityID := d.Get("entity_id").(string) + if entityID != "" { + entity, err = i.memDBEntityByID(entityID, true) + if err != nil { + return nil, err + } + if entity == nil { + return logical.ErrorResponse("invalid entity ID"), nil + } + } + + // Get alias name + aliasName := d.Get("name").(string) + if aliasName == "" { + return logical.ErrorResponse("missing alias name"), nil + } + + mountAccessor := d.Get("mount_accessor").(string) + if mountAccessor == "" { + return logical.ErrorResponse("missing mount_accessor"), nil + } + + mountValidationResp := i.validateMountAccessorFunc(mountAccessor) + if mountValidationResp == nil { + return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil + } + + // Get alias metadata + + // Accept metadata in the form of map[string]string to be able to index on + // it + var aliasMetadata map[string]string + aliasMetadataRaw, ok := d.GetOk("metadata") + if ok { + aliasMetadata, err = parseMetadata(aliasMetadataRaw.([]string)) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to parse alias metadata: %v", err)), nil + } + } + + aliasByFactors, err := i.memDBAliasByFactors(mountValidationResp.MountAccessor, aliasName, false) + if err != nil { + return nil, err + } + + resp := &logical.Response{} + + if newAlias { + if aliasByFactors != nil { + return logical.ErrorResponse("combination of mount and alias name is already in use"), nil + } + + // If this is a alias being tied to a non-existent entity, create + // a new entity for it. + if entity == nil { + entity = &identity.Entity{ + Aliases: []*identity.Alias{ + alias, + }, + } + } else { + entity.Aliases = append(entity.Aliases, alias) + } + } else { + // Verify that the combination of alias name and mount is not + // already tied to a different alias + if aliasByFactors != nil && aliasByFactors.ID != alias.ID { + return logical.ErrorResponse("combination of mount and alias name is already in use"), nil + } + + // Fetch the entity to which the alias is tied to + existingEntity, err := i.memDBEntityByAliasID(alias.ID, true) + if err != nil { + return nil, err + } + + if existingEntity == nil { + return nil, fmt.Errorf("alias is not associated with an entity") + } + + if entity != nil && entity.ID != existingEntity.ID { + // Alias should be transferred from 'existingEntity' to 'entity' + err = i.deleteAliasFromEntity(existingEntity, alias) + if err != nil { + return nil, err + } + previousEntity = existingEntity + entity.Aliases = append(entity.Aliases, alias) + resp.AddWarning(fmt.Sprintf("alias is being transferred from entity %q to %q", existingEntity.ID, entity.ID)) + } else { + // Update entity with modified alias + err = i.updateAliasInEntity(existingEntity, alias) + if err != nil { + return nil, err + } + entity = existingEntity + } + } + + // ID creation and other validations; This is more useful for new entities + // and may not perform anything for the existing entities. Placing the + // check here to make the flow common for both new and existing entities. + err = i.sanitizeEntity(entity) + if err != nil { + return nil, err + } + + // Update the fields + alias.Name = aliasName + alias.Metadata = aliasMetadata + alias.MountType = mountValidationResp.MountType + alias.MountAccessor = mountValidationResp.MountAccessor + alias.MountPath = mountValidationResp.MountPath + + // Set the entity ID in the alias index. This should be done after + // sanitizing entity. + alias.EntityID = entity.ID + + // ID creation and other validations + err = i.sanitizeAlias(alias) + if err != nil { + return nil, err + } + + // Index entity and its aliases in MemDB and persist entity along with + // aliases in storage. If the alias is being transferred over from + // one entity to another, previous entity needs to get refreshed in MemDB + // and persisted in storage as well. + err = i.upsertEntity(entity, previousEntity, true) + if err != nil { + return nil, err + } + + // Return ID of both alias and entity + resp.Data = map[string]interface{}{ + "id": alias.ID, + "entity_id": entity.ID, + } + + return resp, nil +} + +// pathAliasIDRead returns the properties of a alias for a given +// alias ID +func (i *IdentityStore) pathAliasIDRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + aliasID := d.Get("id").(string) + if aliasID == "" { + return logical.ErrorResponse("missing alias id"), nil + } + + alias, err := i.memDBAliasByID(aliasID, false) + if err != nil { + return nil, err + } + + if alias == nil { + return nil, nil + } + + respData := map[string]interface{}{} + respData["id"] = alias.ID + respData["entity_id"] = alias.EntityID + respData["mount_type"] = alias.MountType + respData["mount_accessor"] = alias.MountAccessor + respData["mount_path"] = alias.MountPath + respData["metadata"] = alias.Metadata + respData["name"] = alias.Name + respData["merged_from_entity_ids"] = alias.MergedFromEntityIDs + + // Convert protobuf timestamp into RFC3339 format + respData["creation_time"] = ptypes.TimestampString(alias.CreationTime) + respData["last_update_time"] = ptypes.TimestampString(alias.LastUpdateTime) + + return &logical.Response{ + Data: respData, + }, nil +} + +// pathAliasIDDelete deleted the alias for a given alias ID +func (i *IdentityStore) pathAliasIDDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + aliasID := d.Get("id").(string) + if aliasID == "" { + return logical.ErrorResponse("missing alias ID"), nil + } + + return nil, i.deleteAlias(aliasID) +} + +// pathAliasIDList lists the IDs of all the valid aliases in the identity +// store +func (i *IdentityStore) pathAliasIDList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + ws := memdb.NewWatchSet() + iter, err := i.memDBAliases(ws) + if err != nil { + return nil, fmt.Errorf("failed to fetch iterator for aliases in memdb: %v", err) + } + + var aliasIDs []string + for { + raw := iter.Next() + if raw == nil { + break + } + aliasIDs = append(aliasIDs, raw.(*identity.Alias).ID) + } + + return logical.ListResponse(aliasIDs), nil +} + +var aliasHelp = map[string][2]string{ + "alias": { + "Create a new alias", + "", + }, + "alias-id": { + "Update, read or delete an entity using alias ID", + "", + }, + "alias-id-list": { + "List all the entity IDs", + "", + }, +} diff --git a/vault/identity_store_aliases_test.go b/vault/identity_store_aliases_test.go new file mode 100644 index 0000000000..a4367f5f1f --- /dev/null +++ b/vault/identity_store_aliases_test.go @@ -0,0 +1,531 @@ +package vault + +import ( + "reflect" + "testing" + + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/logical" +) + +func TestIdentityStore_ListAlias(t *testing.T) { + var err error + var resp *logical.Response + + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + entityReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + } + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + if resp == nil { + t.Fatalf("expected a non-nil response") + } + entityID := resp.Data["id"].(string) + + // Create a alias + aliasData := map[string]interface{}{ + "name": "testaliasname", + "mount_accessor": githubAccessor, + } + aliasReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasData, + } + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + aliasData["name"] = "entityalias" + aliasData["entity_id"] = entityID + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + listReq := &logical.Request{ + Operation: logical.ListOperation, + Path: "alias/id", + } + resp, err = is.HandleRequest(listReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + keys := resp.Data["keys"].([]string) + if len(keys) != 2 { + t.Fatalf("bad: lengh of alias IDs listed; expected: 2, actual: %d", len(keys)) + } +} + +// This test is required because MemDB does not take care of ensuring +// uniqueness of indexes that are marked unique. +func TestIdentityStore_AliasSameAliasNames(t *testing.T) { + var err error + var resp *logical.Response + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + aliasData := map[string]interface{}{ + "name": "testaliasname", + "mount_accessor": githubAccessor, + } + + aliasReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasData, + } + + // Register a alias + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Register another alias with same name + resp, err = is.HandleRequest(aliasReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected an error due to alias name not being unique") + } +} + +func TestIdentityStore_MemDBAliasIndexes(t *testing.T) { + var err error + + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + if is == nil { + t.Fatal("failed to create test identity store") + } + + validateMountResp := is.validateMountAccessorFunc(githubAccessor) + if validateMountResp == nil { + t.Fatal("failed to validate github auth mount") + } + + entity := &identity.Entity{ + ID: "testentityid", + Name: "testentityname", + } + + entity.BucketKeyHash = is.entityPacker.BucketKeyHashByItemID(entity.ID) + + err = is.memDBUpsertEntity(entity) + if err != nil { + t.Fatal(err) + } + + alias := &identity.Alias{ + EntityID: entity.ID, + ID: "testaliasid", + MountAccessor: githubAccessor, + MountType: validateMountResp.MountType, + Name: "testaliasname", + Metadata: map[string]string{ + "testkey1": "testmetadatavalue1", + "testkey2": "testmetadatavalue2", + }, + } + + err = is.memDBUpsertAlias(alias) + if err != nil { + t.Fatal(err) + } + + aliasFetched, err := is.memDBAliasByID("testaliasid", false) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(alias, aliasFetched) { + t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched) + } + + aliasFetched, err = is.memDBAliasByEntityID(entity.ID, false) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(alias, aliasFetched) { + t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched) + } + + aliasFetched, err = is.memDBAliasByFactors(validateMountResp.MountAccessor, "testaliasname", false) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(alias, aliasFetched) { + t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched) + } + + aliasesFetched, err := is.memDBAliasesByMetadata(map[string]string{ + "testkey1": "testmetadatavalue1", + }, false) + if err != nil { + t.Fatal(err) + } + + if len(aliasesFetched) != 1 { + t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched)) + } + + if !reflect.DeepEqual(alias, aliasesFetched[0]) { + t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched) + } + + aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{ + "testkey2": "testmetadatavalue2", + }, false) + if err != nil { + t.Fatal(err) + } + + if len(aliasesFetched) != 1 { + t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched)) + } + + if !reflect.DeepEqual(alias, aliasesFetched[0]) { + t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched) + } + + aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{ + "testkey1": "testmetadatavalue1", + "testkey2": "testmetadatavalue2", + }, false) + if err != nil { + t.Fatal(err) + } + + if len(aliasesFetched) != 1 { + t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched)) + } + + if !reflect.DeepEqual(alias, aliasesFetched[0]) { + t.Fatalf("bad: mismatched aliases; expected: %#v\n actual: %#v\n", alias, aliasFetched) + } + + alias2 := &identity.Alias{ + EntityID: entity.ID, + ID: "testaliasid2", + MountAccessor: validateMountResp.MountAccessor, + MountType: validateMountResp.MountType, + Name: "testaliasname2", + Metadata: map[string]string{ + "testkey1": "testmetadatavalue1", + "testkey3": "testmetadatavalue3", + }, + } + + err = is.memDBUpsertAlias(alias2) + if err != nil { + t.Fatal(err) + } + + aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{ + "testkey1": "testmetadatavalue1", + }, false) + if err != nil { + t.Fatal(err) + } + + if len(aliasesFetched) != 2 { + t.Fatalf("bad: length of aliases; expected: 2, actual: %d", len(aliasesFetched)) + } + + aliasesFetched, err = is.memDBAliasesByMetadata(map[string]string{ + "testkey3": "testmetadatavalue3", + }, false) + if err != nil { + t.Fatal(err) + } + + if len(aliasesFetched) != 1 { + t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(aliasesFetched)) + } + + err = is.memDBDeleteAliasByID("testaliasid") + if err != nil { + t.Fatal(err) + } + + aliasFetched, err = is.memDBAliasByID("testaliasid", false) + if err != nil { + t.Fatal(err) + } + + if aliasFetched != nil { + t.Fatalf("expected a nil alias") + } +} + +func TestIdentityStore_AliasRegister(t *testing.T) { + var err error + var resp *logical.Response + + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + if is == nil { + t.Fatal("failed to create test alias store") + } + + aliasData := map[string]interface{}{ + "name": "testaliasname", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + aliasReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasData, + } + + // Register the alias + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + idRaw, ok := resp.Data["id"] + if !ok { + t.Fatalf("alias id not present in alias register response") + } + + id := idRaw.(string) + if id == "" { + t.Fatalf("invalid alias id in alias register response") + } + + entityIDRaw, ok := resp.Data["entity_id"] + if !ok { + t.Fatalf("entity id not present in alias register response") + } + + entityID := entityIDRaw.(string) + if entityID == "" { + t.Fatalf("invalid entity id in alias register response") + } +} + +func TestIdentityStore_AliasUpdate(t *testing.T) { + var err error + var resp *logical.Response + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + aliasData := map[string]interface{}{ + "name": "testaliasname", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + aliasReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasData, + } + + // This will create a alias and a corresponding entity + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + aliasID := resp.Data["id"].(string) + + updateData := map[string]interface{}{ + "name": "updatedaliasname", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=updatedorganization", "team=updatedteam"}, + } + + aliasReq.Data = updateData + aliasReq.Path = "alias/id/" + aliasID + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + aliasReq.Operation = logical.ReadOperation + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + aliasMetadata := resp.Data["metadata"].(map[string]string) + updatedOrg := aliasMetadata["organization"] + updatedTeam := aliasMetadata["team"] + + if resp.Data["name"] != "updatedaliasname" || updatedOrg != "updatedorganization" || updatedTeam != "updatedteam" { + t.Fatalf("failed to update alias information; \n response data: %#v\n", resp.Data) + } +} + +func TestIdentityStore_AliasUpdate_ByID(t *testing.T) { + var err error + var resp *logical.Response + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + updateData := map[string]interface{}{ + "name": "updatedaliasname", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=updatedorganization", "team=updatedteam"}, + } + + updateReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias/id/invalidaliasid", + Data: updateData, + } + + // Try to update an non-existent alias + resp, err = is.HandleRequest(updateReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected an error due to invalid alias id") + } + + registerData := map[string]interface{}{ + "name": "testaliasname", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: registerData, + } + + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + idRaw, ok := resp.Data["id"] + if !ok { + t.Fatalf("alias id not present in response") + } + id := idRaw.(string) + if id == "" { + t.Fatalf("invalid alias id") + } + + updateReq.Path = "alias/id/" + id + resp, err = is.HandleRequest(updateReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + readReq := &logical.Request{ + Operation: logical.ReadOperation, + Path: updateReq.Path, + } + resp, err = is.HandleRequest(readReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + aliasMetadata := resp.Data["metadata"].(map[string]string) + updatedOrg := aliasMetadata["organization"] + updatedTeam := aliasMetadata["team"] + + if resp.Data["name"] != "updatedaliasname" || updatedOrg != "updatedorganization" || updatedTeam != "updatedteam" { + t.Fatalf("failed to update alias information; \n response data: %#v\n", resp.Data) + } + + delete(registerReq.Data, "name") + + resp, err = is.HandleRequest(registerReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected error due to missing alias name") + } + + registerReq.Data["name"] = "testaliasname" + delete(registerReq.Data, "mount_accessor") + + resp, err = is.HandleRequest(registerReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected error due to missing mount accessor") + } +} + +func TestIdentityStore_AliasReadDelete(t *testing.T) { + var err error + var resp *logical.Response + + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + registerData := map[string]interface{}{ + "name": "testaliasname", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: registerData, + } + + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + idRaw, ok := resp.Data["id"] + if !ok { + t.Fatalf("alias id not present in response") + } + id := idRaw.(string) + if id == "" { + t.Fatalf("invalid alias id") + } + + // Read it back using alias id + aliasReq := &logical.Request{ + Operation: logical.ReadOperation, + Path: "alias/id/" + id, + } + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Data["id"].(string) == "" || + resp.Data["entity_id"].(string) == "" || + resp.Data["name"].(string) != registerData["name"] || + resp.Data["mount_type"].(string) != "github" { + t.Fatalf("bad: alias read response; \nexpected: %#v \nactual: %#v\n", registerData, resp.Data) + } + + aliasReq.Operation = logical.DeleteOperation + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + aliasReq.Operation = logical.ReadOperation + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + if resp != nil { + t.Fatalf("bad: alias read response; expected: nil, actual: %#v\n", resp) + } +} diff --git a/vault/identity_store_entities.go b/vault/identity_store_entities.go new file mode 100644 index 0000000000..12aafca76e --- /dev/null +++ b/vault/identity_store_entities.go @@ -0,0 +1,501 @@ +package vault + +import ( + "fmt" + "strings" + + "github.com/golang/protobuf/ptypes" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/helper/locksutil" + "github.com/hashicorp/vault/helper/storagepacker" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +// entityPaths returns the API endpoints supported to operate on entities. +// Following are the paths supported: +// entity - To register a new entity +// entity/id - To lookup, modify, delete and list entities based on ID +// entity/merge - To merge entities based on ID +func entityPaths(i *IdentityStore) []*framework.Path { + return []*framework.Path{ + { + Pattern: "entity$", + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the entity", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the entity", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the entity. Format should be a list of `key=value` pairs.", + }, + "policies": { + Type: framework.TypeCommaStringSlice, + Description: "Policies to be tied to the entity", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathEntityRegister), + }, + + HelpSynopsis: strings.TrimSpace(entityHelp["entity"][0]), + HelpDescription: strings.TrimSpace(entityHelp["entity"][1]), + }, + { + Pattern: "entity/id/" + framework.GenericNameRegex("id"), + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the entity", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the entity", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the entity. Format should be a comma separated list of `key=value` pairs.", + }, + "policies": { + Type: framework.TypeCommaStringSlice, + Description: "Policies to be tied to the entity", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathEntityIDUpdate), + logical.ReadOperation: i.checkPremiumVersion(i.pathEntityIDRead), + logical.DeleteOperation: i.checkPremiumVersion(i.pathEntityIDDelete), + }, + + HelpSynopsis: strings.TrimSpace(entityHelp["entity-id"][0]), + HelpDescription: strings.TrimSpace(entityHelp["entity-id"][1]), + }, + { + Pattern: "entity/id/?$", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: i.checkPremiumVersion(i.pathEntityIDList), + }, + + HelpSynopsis: strings.TrimSpace(entityHelp["entity-id-list"][0]), + HelpDescription: strings.TrimSpace(entityHelp["entity-id-list"][1]), + }, + { + Pattern: "entity/merge/?$", + Fields: map[string]*framework.FieldSchema{ + "from_entity_ids": { + Type: framework.TypeCommaStringSlice, + Description: "Entity IDs which needs to get merged", + }, + "to_entity_id": { + Type: framework.TypeString, + Description: "Entity ID into which all the other entities need to get merged", + }, + "force": { + Type: framework.TypeBool, + Description: "Setting this will follow the 'mine' strategy for merging MFA secrets. If there are secrets of the same type both in entities that are merged from and in entity into which all others are getting merged, secrets in the destination will be unaltered. If not set, this API will throw an error containing all the conflicts.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathEntityMergeID), + }, + + HelpSynopsis: strings.TrimSpace(entityHelp["entity-merge-id"][0]), + HelpDescription: strings.TrimSpace(entityHelp["entity-merge-id"][1]), + }, + } +} + +// pathEntityMergeID merges two or more entities into a single entity +func (i *IdentityStore) pathEntityMergeID(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + toEntityID := d.Get("to_entity_id").(string) + if toEntityID == "" { + return logical.ErrorResponse("missing entity id to merge to"), nil + } + + fromEntityIDs := d.Get("from_entity_ids").([]string) + if len(fromEntityIDs) == 0 { + return logical.ErrorResponse("missing entity ids to merge from"), nil + } + + force := d.Get("force").(bool) + + toEntityForLocking, err := i.memDBEntityByID(toEntityID, false) + if err != nil { + return nil, err + } + + if toEntityForLocking == nil { + return logical.ErrorResponse("entity id to merge to is invalid"), nil + } + + // Acquire the lock to modify the entity storage entry to merge to + toEntityLock := locksutil.LockForKey(i.entityLocks, toEntityForLocking.ID) + toEntityLock.Lock() + defer toEntityLock.Unlock() + + // Create a MemDB transaction to merge entities + txn := i.db.Txn(true) + defer txn.Abort() + + // Re-read post lock acquisition + toEntity, err := i.memDBEntityByID(toEntityID, true) + if err != nil { + return nil, err + } + + if toEntity == nil { + return logical.ErrorResponse("entity id to merge to is invalid"), nil + } + + if toEntity.ID != toEntityForLocking.ID { + return logical.ErrorResponse("acquired lock for an undesired entity"), nil + } + + var conflictErrors error + for _, fromEntityID := range fromEntityIDs { + if fromEntityID == toEntityID { + return logical.ErrorResponse("to_entity_id should not be present in from_entity_ids"), nil + } + + lockFromEntity, err := i.memDBEntityByID(fromEntityID, false) + if err != nil { + return nil, err + } + + if lockFromEntity == nil { + return logical.ErrorResponse("entity id to merge from is invalid"), nil + } + + // Acquire the lock to modify the entity storage entry to merge from + fromEntityLock := locksutil.LockForKey(i.entityLocks, lockFromEntity.ID) + + fromLockHeld := false + + // There are only 256 lock buckets and the chances of entity ID collision + // is fairly high. When we are merging entities belonging to the same + // bucket, multiple attempts to acquire the same lock should be avoided. + if fromEntityLock != toEntityLock { + fromEntityLock.Lock() + fromLockHeld = true + } + + // Re-read the entities post lock acquisition + fromEntity, err := i.memDBEntityByID(fromEntityID, false) + if err != nil { + if fromLockHeld { + fromEntityLock.Unlock() + } + return nil, err + } + + if fromEntity == nil { + if fromLockHeld { + fromEntityLock.Unlock() + } + return logical.ErrorResponse("entity id to merge from is invalid"), nil + } + + if fromEntity.ID != lockFromEntity.ID { + if fromLockHeld { + fromEntityLock.Unlock() + } + return logical.ErrorResponse("acquired lock for an undesired entity"), nil + } + + for _, alias := range fromEntity.Aliases { + // Set the desired entity id + alias.EntityID = toEntity.ID + + // Set the entity id of which this alias is now an alias to + alias.MergedFromEntityIDs = append(alias.MergedFromEntityIDs, fromEntity.ID) + + err = i.memDBUpsertAliasInTxn(txn, alias) + if err != nil { + if fromLockHeld { + fromEntityLock.Unlock() + } + return nil, fmt.Errorf("failed to update alias during merge: %v", err) + } + + // Add the alias to the desired entity + toEntity.Aliases = append(toEntity.Aliases, alias) + } + + // If the entity from which we are merging from was already a merged + // entity, transfer over the Merged set to the entity we are + // merging into. + toEntity.MergedEntityIDs = append(toEntity.MergedEntityIDs, fromEntity.MergedEntityIDs...) + + // Add the entity from which we are merging from to the list of entities + // the entity we are merging into is composed of. + toEntity.MergedEntityIDs = append(toEntity.MergedEntityIDs, fromEntity.ID) + + // Delete the entity which we are merging from in MemDB using the same transaction + err = i.memDBDeleteEntityByIDInTxn(txn, fromEntity.ID) + if err != nil { + if fromLockHeld { + fromEntityLock.Unlock() + } + return nil, err + } + + // Delete the entity which we are merging from in storage + err = i.entityPacker.DeleteItem(fromEntity.ID) + if err != nil { + if fromLockHeld { + fromEntityLock.Unlock() + } + return nil, err + } + + if fromLockHeld { + fromEntityLock.Unlock() + } + } + + if conflictErrors != nil && !force { + return logical.ErrorResponse(conflictErrors.Error()), nil + } + + // Update MemDB with changes to the entity we are merging to + err = i.memDBUpsertEntityInTxn(txn, toEntity) + if err != nil { + return nil, err + } + + // Persist the entity which we are merging to + toEntityAsAny, err := ptypes.MarshalAny(toEntity) + if err != nil { + return nil, err + } + item := &storagepacker.Item{ + ID: toEntity.ID, + Message: toEntityAsAny, + } + + err = i.entityPacker.PutItem(item) + if err != nil { + return nil, err + } + + // Committing the transaction *after* successfully performing storage + // persistence + txn.Commit() + + return nil, nil +} + +// pathEntityRegister is used to register a new entity +func (i *IdentityStore) pathEntityRegister(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + _, ok := d.GetOk("id") + if ok { + return i.pathEntityIDUpdate(req, d) + } + + return i.handleEntityUpdateCommon(req, d, nil) +} + +// pathEntityIDUpdate is used to update an entity based on the given entity ID +func (i *IdentityStore) pathEntityIDUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get entity id + entityID := d.Get("id").(string) + + if entityID == "" { + return logical.ErrorResponse("missing entity id"), nil + } + + entity, err := i.memDBEntityByID(entityID, true) + if err != nil { + return nil, err + } + if entity == nil { + return nil, fmt.Errorf("invalid entity id") + } + + return i.handleEntityUpdateCommon(req, d, entity) +} + +// handleEntityUpdateCommon is used to update an entity +func (i *IdentityStore) handleEntityUpdateCommon(req *logical.Request, d *framework.FieldData, entity *identity.Entity) (*logical.Response, error) { + var err error + var newEntity bool + + // Entity will be nil when a new entity is being registered; create a new + // struct in that case. + if entity == nil { + entity = &identity.Entity{} + newEntity = true + } + + // Update the policies if supplied + entityPoliciesRaw, ok := d.GetOk("policies") + if ok { + entity.Policies = entityPoliciesRaw.([]string) + } + + // Get the name + entityName := d.Get("name").(string) + if entityName != "" { + entityByName, err := i.memDBEntityByName(entityName, false) + if err != nil { + return nil, err + } + switch { + case (newEntity && entityByName != nil), (entityByName != nil && entity.ID != "" && entityByName.ID != entity.ID): + return logical.ErrorResponse("entity name is already in use"), nil + } + entity.Name = entityName + } + + // Get entity metadata + + // Accept metadata in the form of map[string]string to be able to index on + // it + entityMetadataRaw, ok := d.GetOk("metadata") + if ok { + entity.Metadata, err = parseMetadata(entityMetadataRaw.([]string)) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to parse entity metadata: %v", err)), nil + } + } + + // ID creation and some validations + err = i.sanitizeEntity(entity) + if err != nil { + return nil, err + } + + // Prepare the response + respData := map[string]interface{}{ + "id": entity.ID, + } + + var aliasIDs []string + for _, alias := range entity.Aliases { + aliasIDs = append(aliasIDs, alias.ID) + } + + respData["aliases"] = aliasIDs + + // Update MemDB and persist entity object + err = i.upsertEntity(entity, nil, true) + if err != nil { + return nil, err + } + + // Return ID of the entity that was either created or updated along with + // its aliases + return &logical.Response{ + Data: respData, + }, nil +} + +// pathEntityIDRead returns the properties of an entity for a given entity ID +func (i *IdentityStore) pathEntityIDRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + entityID := d.Get("id").(string) + if entityID == "" { + return logical.ErrorResponse("missing entity id"), nil + } + + entity, err := i.memDBEntityByID(entityID, false) + if err != nil { + return nil, err + } + if entity == nil { + return nil, nil + } + + respData := map[string]interface{}{} + respData["id"] = entity.ID + respData["name"] = entity.Name + respData["metadata"] = entity.Metadata + respData["merged_entity_ids"] = entity.MergedEntityIDs + respData["policies"] = entity.Policies + + // Convert protobuf timestamp into RFC3339 format + respData["creation_time"] = ptypes.TimestampString(entity.CreationTime) + respData["last_update_time"] = ptypes.TimestampString(entity.LastUpdateTime) + + // Convert each alias into a map and replace the time format in each + aliasesToReturn := make([]interface{}, len(entity.Aliases)) + for aliasIdx, alias := range entity.Aliases { + aliasMap := map[string]interface{}{} + aliasMap["id"] = alias.ID + aliasMap["entity_id"] = alias.EntityID + aliasMap["mount_type"] = alias.MountType + aliasMap["mount_accessor"] = alias.MountAccessor + aliasMap["mount_path"] = alias.MountPath + aliasMap["metadata"] = alias.Metadata + aliasMap["name"] = alias.Name + aliasMap["merged_from_entity_ids"] = alias.MergedFromEntityIDs + aliasMap["creation_time"] = ptypes.TimestampString(alias.CreationTime) + aliasMap["last_update_time"] = ptypes.TimestampString(alias.LastUpdateTime) + aliasesToReturn[aliasIdx] = aliasMap + } + + // Add the aliases information to the response which has the correct time + // formats + respData["aliases"] = aliasesToReturn + + resp := &logical.Response{ + Data: respData, + } + + return resp, nil +} + +// pathEntityIDDelete deletes the entity for a given entity ID +func (i *IdentityStore) pathEntityIDDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + entityID := d.Get("id").(string) + if entityID == "" { + return logical.ErrorResponse("missing entity id"), nil + } + + return nil, i.deleteEntity(entityID) +} + +// pathEntityIDList lists the IDs of all the valid entities in the identity +// store +func (i *IdentityStore) pathEntityIDList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + ws := memdb.NewWatchSet() + iter, err := i.memDBEntities(ws) + if err != nil { + return nil, fmt.Errorf("failed to fetch iterator for entities in memdb: %v", err) + } + + var entityIDs []string + for { + raw := iter.Next() + if raw == nil { + break + } + entityIDs = append(entityIDs, raw.(*identity.Entity).ID) + } + + return logical.ListResponse(entityIDs), nil +} + +var entityHelp = map[string][2]string{ + "entity": { + "Create a new entity", + "", + }, + "entity-id": { + "Update, read or delete an entity using entity ID", + "", + }, + "entity-id-list": { + "List all the entity IDs", + "", + }, + "entity-merge-id": { + "Merge two or more entities together", + "", + }, +} diff --git a/vault/identity_store_entities_test.go b/vault/identity_store_entities_test.go new file mode 100644 index 0000000000..b217456e7e --- /dev/null +++ b/vault/identity_store_entities_test.go @@ -0,0 +1,783 @@ +package vault + +import ( + "fmt" + "reflect" + "sort" + "testing" + + uuid "github.com/hashicorp/go-uuid" + credGithub "github.com/hashicorp/vault/builtin/credential/github" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/logical" +) + +func TestIdentityStore_EntityCreateUpdate(t *testing.T) { + var err error + var resp *logical.Response + + is, _, _ := testIdentityStoreWithGithubAuth(t) + + entityData := map[string]interface{}{ + "name": "testentityname", + "metadata": []string{"someusefulkey=someusefulvalue"}, + "policies": []string{"testpolicy1", "testpolicy2"}, + } + + entityReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: entityData, + } + + // Create the entity + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + entityID := resp.Data["id"].(string) + + updateData := map[string]interface{}{ + // Set the entity ID here + "id": entityID, + "name": "updatedentityname", + "metadata": []string{"updatedkey=updatedvalue"}, + "policies": []string{"updatedpolicy1", "updatedpolicy2"}, + } + entityReq.Data = updateData + + // Update the entity + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entityReq.Path = "entity/id/" + entityID + entityReq.Operation = logical.ReadOperation + + // Read the entity + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Data["id"] != entityID || + resp.Data["name"] != updateData["name"] || + !reflect.DeepEqual(resp.Data["policies"], updateData["policies"]) { + t.Fatalf("bad: entity response after update; resp: %#v\n updateData: %#v\n", resp.Data, updateData) + } +} + +func TestIdentityStore_CloneImmutability(t *testing.T) { + alias := &identity.Alias{ + ID: "testaliasid", + Name: "testaliasname", + MergedFromEntityIDs: []string{"entityid1"}, + } + + entity := &identity.Entity{ + ID: "testentityid", + Name: "testentityname", + Aliases: []*identity.Alias{ + alias, + }, + } + + clonedEntity, err := entity.Clone() + if err != nil { + t.Fatal(err) + } + + // Modify entity + entity.Aliases[0].ID = "invalidid" + + if clonedEntity.Aliases[0].ID == "invalidid" { + t.Fatalf("cloned entity is mutated") + } + + clonedAlias, err := alias.Clone() + if err != nil { + t.Fatal(err) + } + + alias.MergedFromEntityIDs[0] = "invalidid" + + if clonedAlias.MergedFromEntityIDs[0] == "invalidid" { + t.Fatalf("cloned alias is mutated") + } +} + +func TestIdentityStore_MemDBImmutability(t *testing.T) { + var err error + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + validateMountResp := is.validateMountAccessorFunc(githubAccessor) + if validateMountResp == nil { + t.Fatal("failed to validate github auth mount") + } + + alias1 := &identity.Alias{ + EntityID: "testentityid", + ID: "testaliasid", + MountAccessor: githubAccessor, + MountType: validateMountResp.MountType, + Name: "testaliasname", + Metadata: map[string]string{ + "testkey1": "testmetadatavalue1", + "testkey2": "testmetadatavalue2", + }, + } + + entity := &identity.Entity{ + ID: "testentityid", + Name: "testentityname", + Metadata: map[string]string{ + "someusefulkey": "someusefulvalue", + }, + Aliases: []*identity.Alias{ + alias1, + }, + } + + entity.BucketKeyHash = is.entityPacker.BucketKeyHashByItemID(entity.ID) + + err = is.memDBUpsertEntity(entity) + if err != nil { + t.Fatal(err) + } + + entityFetched, err := is.memDBEntityByID(entity.ID, true) + if err != nil { + t.Fatal(err) + } + + // Modify the fetched entity outside of a transaction + entityFetched.Aliases[0].ID = "invalidaliasid" + + entityFetched, err = is.memDBEntityByID(entity.ID, false) + if err != nil { + t.Fatal(err) + } + + if entityFetched.Aliases[0].ID == "invalidaliasid" { + t.Fatal("memdb item is mutable outside of transaction") + } +} + +func TestIdentityStore_ListEntities(t *testing.T) { + var err error + var resp *logical.Response + + is, _, _ := testIdentityStoreWithGithubAuth(t) + + entityReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + } + + expected := []string{} + for i := 0; i < 10; i++ { + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + expected = append(expected, resp.Data["id"].(string)) + } + + listReq := &logical.Request{ + Operation: logical.ListOperation, + Path: "entity/id", + } + + resp, err = is.HandleRequest(listReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + actual := resp.Data["keys"].([]string) + + // Sort the operands for DeepEqual to work + sort.Strings(actual) + sort.Strings(expected) + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("bad: listed entity IDs; expected: %#v\n actual: %#v\n", expected, actual) + } +} + +func TestIdentityStore_LoadingEntities(t *testing.T) { + var resp *logical.Response + // Add github credential factory to core config + err := AddTestCredentialBackend("github", credGithub.Factory) + if err != nil { + t.Fatalf("err: %s", err) + } + + c := TestCore(t) + unsealKeys, token := TestCoreInit(t, c) + for _, key := range unsealKeys { + if _, err := TestCoreUnseal(c, TestKeyCopy(key)); err != nil { + t.Fatalf("unseal err: %s", err) + } + } + + sealed, err := c.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + + meGH := &MountEntry{ + Table: credentialTableType, + Path: "github/", + Type: "github", + Description: "github auth", + } + + // Mount UUID for github auth + meGHUUID, err := uuid.GenerateUUID() + if err != nil { + t.Fatal(err) + } + meGH.UUID = meGHUUID + + // Mount accessor for github auth + githubAccessor, err := c.generateMountAccessor("github") + if err != nil { + panic(fmt.Sprintf("could not generate github accessor: %v", err)) + } + meGH.Accessor = githubAccessor + + // Storage view for github auth + ghView := NewBarrierView(c.barrier, credentialBarrierPrefix+meGH.UUID+"/") + + // Sysview for github auth + ghSysview := c.mountEntrySysView(meGH) + + // Create new github auth credential backend + ghAuth, err := c.newCredentialBackend(meGH.Type, ghSysview, ghView, nil) + if err != nil { + t.Fatal(err) + } + + // Mount github auth + err = c.router.Mount(ghAuth, "auth/github", meGH, ghView) + if err != nil { + t.Fatal(err) + } + + // Identity store will be mounted by now, just fetch it from router + identitystore := c.router.MatchingBackend("identity/") + if identitystore == nil { + t.Fatalf("failed to fetch identity store from router") + } + + is := identitystore.(*IdentityStore) + + registerData := map[string]interface{}{ + "name": "testentityname", + "metadata": []string{"someusefulkey=someusefulvalue"}, + "policies": []string{"testpolicy1", "testpolicy2"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: registerData, + } + + // Register the entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entityID := resp.Data["id"].(string) + + readReq := &logical.Request{ + Path: "entity/id/" + entityID, + Operation: logical.ReadOperation, + } + + // Ensure that entity is created + resp, err = is.HandleRequest(readReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Data["id"] != entityID { + t.Fatalf("failed to read the created entity") + } + + // Perform a seal/unseal cycle + err = c.Seal(token) + if err != nil { + t.Fatalf("failed to seal core: %v", err) + } + + sealed, err = c.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if !sealed { + t.Fatal("should be sealed") + } + + for _, key := range unsealKeys { + if _, err := TestCoreUnseal(c, TestKeyCopy(key)); err != nil { + t.Fatalf("unseal err: %s", err) + } + } + + sealed, err = c.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + + // Check if the entity is restored + resp, err = is.HandleRequest(readReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Data["id"] != entityID { + t.Fatalf("failed to read the created entity after a seal/unseal cycle") + } +} + +func TestIdentityStore_MemDBEntityIndexes(t *testing.T) { + var err error + + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + validateMountResp := is.validateMountAccessorFunc(githubAccessor) + if validateMountResp == nil { + t.Fatal("failed to validate github auth mount") + } + + alias1 := &identity.Alias{ + EntityID: "testentityid", + ID: "testaliasid", + MountAccessor: githubAccessor, + MountType: validateMountResp.MountType, + Name: "testaliasname", + Metadata: map[string]string{ + "testkey1": "testmetadatavalue1", + "testkey2": "testmetadatavalue2", + }, + } + + alias2 := &identity.Alias{ + EntityID: "testentityid", + ID: "testaliasid2", + MountAccessor: validateMountResp.MountAccessor, + MountType: validateMountResp.MountType, + Name: "testaliasname2", + Metadata: map[string]string{ + "testkey2": "testmetadatavalue2", + "testkey3": "testmetadatavalue3", + }, + } + + entity := &identity.Entity{ + ID: "testentityid", + Name: "testentityname", + Metadata: map[string]string{ + "someusefulkey": "someusefulvalue", + }, + Aliases: []*identity.Alias{ + alias1, + alias2, + }, + } + + entity.BucketKeyHash = is.entityPacker.BucketKeyHashByItemID(entity.ID) + + err = is.memDBUpsertEntity(entity) + if err != nil { + t.Fatal(err) + } + + // Fetch the entity using its ID + entityFetched, err := is.memDBEntityByID(entity.ID, false) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(entity, entityFetched) { + t.Fatalf("bad: mismatched entities; expected: %#v\n actual: %#v\n", entity, entityFetched) + } + + // Fetch the entity using its name + entityFetched, err = is.memDBEntityByName(entity.Name, false) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(entity, entityFetched) { + t.Fatalf("entity mismatched entities; expected: %#v\n actual: %#v\n", entity, entityFetched) + } + + // Fetch entities using the metadata + entitiesFetched, err := is.memDBEntitiesByMetadata(map[string]string{ + "someusefulkey": "someusefulvalue", + }, false) + if err != nil { + t.Fatal(err) + } + + if len(entitiesFetched) != 1 { + t.Fatalf("bad: length of entities; expected: 1, actual: %d", len(entitiesFetched)) + } + + if !reflect.DeepEqual(entity, entitiesFetched[0]) { + t.Fatalf("entity mismatch; entity: %#v\n entitiesFetched[0]: %#v\n", entity, entitiesFetched[0]) + } + + entitiesFetched, err = is.memDBEntitiesByBucketEntryKeyHash(entity.BucketKeyHash) + if err != nil { + t.Fatal(err) + } + + if len(entitiesFetched) != 1 { + t.Fatalf("bad: length of entities; expected: 1, actual: %d", len(entitiesFetched)) + } + + err = is.memDBDeleteEntityByID(entity.ID) + if err != nil { + t.Fatal(err) + } + + entityFetched, err = is.memDBEntityByID(entity.ID, false) + if err != nil { + t.Fatal(err) + } + + if entityFetched != nil { + t.Fatalf("bad: entity; expected: nil, actual: %#v\n", entityFetched) + } + + entityFetched, err = is.memDBEntityByName(entity.Name, false) + if err != nil { + t.Fatal(err) + } + + if entityFetched != nil { + t.Fatalf("bad: entity; expected: nil, actual: %#v\n", entityFetched) + } +} + +// This test is required because MemDB does not take care of ensuring +// uniqueness of indexes that are marked unique. It is the job of the higher +// level abstraction, the identity store in this case. +func TestIdentityStore_EntitySameEntityNames(t *testing.T) { + var err error + var resp *logical.Response + is, _, _ := testIdentityStoreWithGithubAuth(t) + + registerData := map[string]interface{}{ + "name": "testentityname", + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: registerData, + } + + // Register an entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Register another entity with same name + resp, err = is.HandleRequest(registerReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected an error due to entity name not being unique") + } +} + +func TestIdentityStore_EntityCRUD(t *testing.T) { + var err error + var resp *logical.Response + + is, _, _ := testIdentityStoreWithGithubAuth(t) + + registerData := map[string]interface{}{ + "name": "testentityname", + "metadata": []string{"someusefulkey=someusefulvalue"}, + "policies": []string{"testpolicy1", "testpolicy2"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: registerData, + } + + // Register the entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + idRaw, ok := resp.Data["id"] + if !ok { + t.Fatalf("entity id not present in response") + } + id := idRaw.(string) + if id == "" { + t.Fatalf("invalid entity id") + } + + readReq := &logical.Request{ + Path: "entity/id/" + id, + Operation: logical.ReadOperation, + } + + resp, err = is.HandleRequest(readReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Data["id"] != id || + resp.Data["name"] != registerData["name"] || + !reflect.DeepEqual(resp.Data["policies"], registerData["policies"]) { + t.Fatalf("bad: entity response") + } + + updateData := map[string]interface{}{ + "name": "updatedentityname", + "metadata": []string{"updatedkey=updatedvalue"}, + "policies": []string{"updatedpolicy1", "updatedpolicy2"}, + } + + updateReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity/id/" + id, + Data: updateData, + } + + resp, err = is.HandleRequest(updateReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + resp, err = is.HandleRequest(readReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Data["id"] != id || + resp.Data["name"] != updateData["name"] || + !reflect.DeepEqual(resp.Data["policies"], updateData["policies"]) { + t.Fatalf("bad: entity response after update; resp: %#v\n updateData: %#v\n", resp.Data, updateData) + } + + deleteReq := &logical.Request{ + Path: "entity/id/" + id, + Operation: logical.DeleteOperation, + } + + resp, err = is.HandleRequest(deleteReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + resp, err = is.HandleRequest(readReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response; actual: %#v\n", resp) + } +} + +func TestIdentityStore_MergeEntitiesByID(t *testing.T) { + var err error + var resp *logical.Response + + is, githubAccessor, _ := testIdentityStoreWithGithubAuth(t) + + registerData := map[string]interface{}{ + "name": "testentityname2", + "metadata": []string{"someusefulkey=someusefulvalue"}, + } + + registerData2 := map[string]interface{}{ + "name": "testentityname", + "metadata": []string{"someusefulkey=someusefulvalue"}, + } + + aliasRegisterData1 := map[string]interface{}{ + "name": "testaliasname1", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + aliasRegisterData2 := map[string]interface{}{ + "name": "testaliasname2", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + aliasRegisterData3 := map[string]interface{}{ + "name": "testaliasname3", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + aliasRegisterData4 := map[string]interface{}{ + "name": "testaliasname4", + "mount_accessor": githubAccessor, + "metadata": []string{"organization=hashicorp", "team=vault"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: registerData, + } + + // Register the entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entityID1 := resp.Data["id"].(string) + + // Set entity ID in alias registration data and register alias + aliasRegisterData1["entity_id"] = entityID1 + aliasRegisterData2["entity_id"] = entityID1 + + aliasReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasRegisterData1, + } + + // Register the alias + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Register the alias + aliasReq.Data = aliasRegisterData2 + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entity1, err := is.memDBEntityByID(entityID1, false) + if err != nil { + t.Fatal(err) + } + if entity1 == nil { + t.Fatalf("failed to create entity: %v", err) + } + if len(entity1.Aliases) != 2 { + t.Fatalf("bad: number of aliases in entity; expected: 2, actual: %d", len(entity1.Aliases)) + } + + registerReq.Data = registerData2 + // Register another entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entityID2 := resp.Data["id"].(string) + // Set entity ID in alias registration data and register alias + aliasRegisterData3["entity_id"] = entityID2 + aliasRegisterData4["entity_id"] = entityID2 + + aliasReq = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasRegisterData3, + } + + // Register the alias + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Register the alias + aliasReq.Data = aliasRegisterData4 + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entity2, err := is.memDBEntityByID(entityID2, false) + if err != nil { + t.Fatal(err) + } + if entity2 == nil { + t.Fatalf("failed to create entity: %v", err) + } + + if len(entity2.Aliases) != 2 { + t.Fatalf("bad: number of aliases in entity; expected: 2, actual: %d", len(entity2.Aliases)) + } + + mergeData := map[string]interface{}{ + "to_entity_id": entityID1, + "from_entity_ids": []string{entityID2}, + } + mergeReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity/merge", + Data: mergeData, + } + + resp, err = is.HandleRequest(mergeReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entityReq := &logical.Request{ + Operation: logical.ReadOperation, + Path: "entity/id/" + entityID2, + } + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + if resp != nil { + t.Fatalf("entity should have been deleted") + } + + entityReq.Path = "entity/id/" + entityID1 + resp, err = is.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + entity2Aliases := resp.Data["aliases"].([]interface{}) + if len(entity2Aliases) != 4 { + t.Fatalf("bad: number of aliases in entity; expected: 4, actual: %d", len(entity2Aliases)) + } + + for _, aliasRaw := range entity2Aliases { + alias := aliasRaw.(map[string]interface{}) + aliasLookedUp, err := is.memDBAliasByID(alias["id"].(string), false) + if err != nil { + t.Fatal(err) + } + if aliasLookedUp == nil { + t.Fatalf("index for alias id %q is not updated", alias["id"].(string)) + } + } +} diff --git a/vault/identity_store_groups.go b/vault/identity_store_groups.go new file mode 100644 index 0000000000..9bf42ce8a1 --- /dev/null +++ b/vault/identity_store_groups.go @@ -0,0 +1,286 @@ +package vault + +import ( + "fmt" + "strings" + + "github.com/golang/protobuf/ptypes" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func groupPaths(i *IdentityStore) []*framework.Path { + return []*framework.Path{ + { + Pattern: "group$", + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the group.", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the group.", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the group. Format should be a list of `key=value` pairs.", + }, + "policies": { + Type: framework.TypeCommaStringSlice, + Description: "Policies to be tied to the group.", + }, + "member_group_ids": { + Type: framework.TypeCommaStringSlice, + Description: "Group IDs to be assigned as group members.", + }, + "member_entity_ids": { + Type: framework.TypeCommaStringSlice, + Description: "Entity IDs to be assigned as group members.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathGroupRegister), + }, + + HelpSynopsis: strings.TrimSpace(groupHelp["register"][0]), + HelpDescription: strings.TrimSpace(groupHelp["register"][1]), + }, + { + Pattern: "group/id/" + framework.GenericNameRegex("id"), + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the group.", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the group.", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the group. Format should be a list of `key=value` pairs.", + }, + "policies": { + Type: framework.TypeCommaStringSlice, + Description: "Policies to be tied to the group.", + }, + "member_group_ids": { + Type: framework.TypeCommaStringSlice, + Description: "Group IDs to be assigned as group members.", + }, + "member_entity_ids": { + Type: framework.TypeCommaStringSlice, + Description: "Entity IDs to be assigned as group members.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathGroupIDUpdate), + logical.ReadOperation: i.checkPremiumVersion(i.pathGroupIDRead), + logical.DeleteOperation: i.checkPremiumVersion(i.pathGroupIDDelete), + }, + + HelpSynopsis: strings.TrimSpace(groupHelp["group-by-id"][0]), + HelpDescription: strings.TrimSpace(groupHelp["group-by-id"][1]), + }, + { + Pattern: "group/id/?$", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: i.checkPremiumVersion(i.pathGroupIDList), + }, + + HelpSynopsis: strings.TrimSpace(entityHelp["group-id-list"][0]), + HelpDescription: strings.TrimSpace(entityHelp["group-id-list"][1]), + }, + } +} + +func (i *IdentityStore) pathGroupRegister(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + _, ok := d.GetOk("id") + if ok { + return i.pathGroupIDUpdate(req, d) + } + + i.groupLock.Lock() + defer i.groupLock.Unlock() + + return i.handleGroupUpdateCommon(req, d, nil) +} + +func (i *IdentityStore) pathGroupIDUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + groupID := d.Get("id").(string) + if groupID == "" { + return logical.ErrorResponse("empty group ID"), nil + } + + i.groupLock.Lock() + defer i.groupLock.Unlock() + + group, err := i.memDBGroupByID(groupID, true) + if err != nil { + return nil, err + } + if group == nil { + return logical.ErrorResponse("invalid group ID"), nil + } + + return i.handleGroupUpdateCommon(req, d, group) +} + +func (i *IdentityStore) handleGroupUpdateCommon(req *logical.Request, d *framework.FieldData, group *identity.Group) (*logical.Response, error) { + var err error + var newGroup bool + if group == nil { + group = &identity.Group{} + newGroup = true + } + + // Update the policies if supplied + policiesRaw, ok := d.GetOk("policies") + if ok { + group.Policies = policiesRaw.([]string) + } + + // Get the name + groupName := d.Get("name").(string) + if groupName != "" { + // Check if there is a group already existing for the given name + groupByName, err := i.memDBGroupByName(groupName, false) + if err != nil { + return nil, err + } + + // If this is a new group and if there already exists a group by this + // name, error out. If the name of an existing group is about to be + // modified into something which is already tied to a different group, + // error out. + switch { + case (newGroup && groupByName != nil), (groupByName != nil && group.ID != "" && groupByName.ID != group.ID): + return logical.ErrorResponse("group name is already in use"), nil + } + group.Name = groupName + } + + metadataRaw, ok := d.GetOk("metadata") + if ok { + group.Metadata, err = parseMetadata(metadataRaw.([]string)) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to parse group metadata: %v", err)), nil + } + } + + memberEntityIDsRaw, ok := d.GetOk("member_entity_ids") + if ok { + group.MemberEntityIDs = memberEntityIDsRaw.([]string) + if len(group.MemberEntityIDs) > 512 { + return logical.ErrorResponse("member entity IDs exceeding the limit of 512"), nil + } + } + + memberGroupIDsRaw, ok := d.GetOk("member_group_ids") + var memberGroupIDs []string + if ok { + memberGroupIDs = memberGroupIDsRaw.([]string) + } + + err = i.sanitizeAndUpsertGroup(group, memberGroupIDs) + if err != nil { + return nil, err + } + + respData := map[string]interface{}{ + "id": group.ID, + "name": group.Name, + } + return &logical.Response{ + Data: respData, + }, nil +} + +func (i *IdentityStore) pathGroupIDRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + groupID := d.Get("id").(string) + if groupID == "" { + return logical.ErrorResponse("empty group id"), nil + } + + group, err := i.memDBGroupByID(groupID, false) + if err != nil { + return nil, err + } + if group == nil { + return nil, nil + } + + return i.handleGroupReadCommon(group) +} + +func (i *IdentityStore) handleGroupReadCommon(group *identity.Group) (*logical.Response, error) { + if group == nil { + return nil, fmt.Errorf("nil group") + } + + respData := map[string]interface{}{} + respData["id"] = group.ID + respData["name"] = group.Name + respData["policies"] = group.Policies + respData["member_entity_ids"] = group.MemberEntityIDs + respData["metadata"] = group.Metadata + respData["creation_time"] = ptypes.TimestampString(group.CreationTime) + respData["last_update_time"] = ptypes.TimestampString(group.LastUpdateTime) + respData["modify_index"] = group.ModifyIndex + + memberGroupIDs, err := i.memberGroupIDsByID(group.ID) + if err != nil { + return nil, err + } + respData["member_group_ids"] = memberGroupIDs + + return &logical.Response{ + Data: respData, + }, nil +} + +func (i *IdentityStore) pathGroupIDDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + groupID := d.Get("id").(string) + if groupID == "" { + return logical.ErrorResponse("empty group ID"), nil + } + return nil, i.deleteGroupByID(groupID) +} + +// pathGroupIDList lists the IDs of all the groups in the identity store +func (i *IdentityStore) pathGroupIDList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + ws := memdb.NewWatchSet() + iter, err := i.memDBGroupIterator(ws) + if err != nil { + return nil, fmt.Errorf("failed to fetch iterator for group in memdb: %v", err) + } + + var groupIDs []string + for { + raw := iter.Next() + if raw == nil { + break + } + groupIDs = append(groupIDs, raw.(*identity.Group).ID) + } + + return logical.ListResponse(groupIDs), nil +} + +var groupHelp = map[string][2]string{ + "register": { + "Create a new group.", + "", + }, + "group-by-id": { + "Update or delete an existing group using its ID.", + "", + }, + "group-id-list": { + "List all the group IDs.", + "", + }, +} diff --git a/vault/identity_store_groups_test.go b/vault/identity_store_groups_test.go new file mode 100644 index 0000000000..3886f8ef4b --- /dev/null +++ b/vault/identity_store_groups_test.go @@ -0,0 +1,666 @@ +package vault + +import ( + "reflect" + "sort" + "testing" + + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/logical" +) + +func TestIdentityStore_MemDBGroupIndexes(t *testing.T) { + var err error + i, _, _ := testIdentityStoreWithGithubAuth(t) + + // Create a dummy group + group := &identity.Group{ + ID: "testgroupid", + Name: "testgroupname", + Metadata: map[string]string{ + "testmetadatakey1": "testmetadatavalue1", + "testmetadatakey2": "testmetadatavalue2", + }, + ParentGroupIDs: []string{"testparentgroupid1", "testparentgroupid2"}, + MemberEntityIDs: []string{"testentityid1", "testentityid2"}, + Policies: []string{"testpolicy1", "testpolicy2"}, + BucketKeyHash: i.groupPacker.BucketKeyHashByItemID("testgroupid"), + } + + // Insert it into memdb + err = i.memDBUpsertGroup(group) + if err != nil { + t.Fatal(err) + } + + // Insert another dummy group + group = &identity.Group{ + ID: "testgroupid2", + Name: "testgroupname2", + Metadata: map[string]string{ + "testmetadatakey2": "testmetadatavalue2", + "testmetadatakey3": "testmetadatavalue3", + }, + ParentGroupIDs: []string{"testparentgroupid2", "testparentgroupid3"}, + MemberEntityIDs: []string{"testentityid2", "testentityid3"}, + Policies: []string{"testpolicy2", "testpolicy3"}, + BucketKeyHash: i.groupPacker.BucketKeyHashByItemID("testgroupid2"), + } + + // Insert it into memdb + err = i.memDBUpsertGroup(group) + if err != nil { + t.Fatal(err) + } + + var fetchedGroup *identity.Group + + // Fetch group given the name + fetchedGroup, err = i.memDBGroupByName("testgroupname", false) + if err != nil { + t.Fatal(err) + } + if fetchedGroup == nil || fetchedGroup.Name != "testgroupname" { + t.Fatalf("failed to fetch an indexed group") + } + + // Fetch group given the ID + fetchedGroup, err = i.memDBGroupByID("testgroupid", false) + if err != nil { + t.Fatal(err) + } + if fetchedGroup == nil || fetchedGroup.Name != "testgroupname" { + t.Fatalf("failed to fetch an indexed group") + } + + var fetchedGroups []*identity.Group + // Fetch the subgroups of a given group ID + fetchedGroups, err = i.memDBGroupsByParentGroupID("testparentgroupid1", false) + if err != nil { + t.Fatal(err) + } + if len(fetchedGroups) != 1 || fetchedGroups[0].Name != "testgroupname" { + t.Fatalf("failed to fetch an indexed group") + } + + fetchedGroups, err = i.memDBGroupsByParentGroupID("testparentgroupid2", false) + if err != nil { + t.Fatal(err) + } + if len(fetchedGroups) != 2 { + t.Fatalf("failed to fetch a indexed groups") + } + + // Fetch groups based on policy name + fetchedGroups, err = i.memDBGroupsByPolicy("testpolicy1", false) + if err != nil { + t.Fatal(err) + } + if len(fetchedGroups) != 1 || fetchedGroups[0].Name != "testgroupname" { + t.Fatalf("failed to fetch an indexed group") + } + + fetchedGroups, err = i.memDBGroupsByPolicy("testpolicy2", false) + if err != nil { + t.Fatal(err) + } + if len(fetchedGroups) != 2 { + t.Fatalf("failed to fetch indexed groups") + } + + // Fetch groups based on member entity ID + fetchedGroups, err = i.memDBGroupsByMemberEntityID("testentityid1", false) + if err != nil { + t.Fatal(err) + } + if len(fetchedGroups) != 1 || fetchedGroups[0].Name != "testgroupname" { + t.Fatalf("failed to fetch an indexed group") + } + + fetchedGroups, err = i.memDBGroupsByMemberEntityID("testentityid2", false) + if err != nil { + t.Fatal(err) + } + + if len(fetchedGroups) != 2 { + t.Fatalf("failed to fetch groups by entity ID") + } +} + +func TestIdentityStore_GroupsCreateUpdate(t *testing.T) { + var resp *logical.Response + var err error + is, _, _ := testIdentityStoreWithGithubAuth(t) + + // Create an entity and get its ID + entityRegisterReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + } + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID1 := resp.Data["id"].(string) + + // Create another entity and get its ID + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID2 := resp.Data["id"].(string) + + // Create a group with the above created 2 entities as its members + groupData := map[string]interface{}{ + "policies": "testpolicy1,testpolicy2", + "metadata": []string{"testkey1=testvalue1", "testkey2=testvalue2"}, + "member_entity_ids": []string{entityID1, entityID2}, + } + + // Create a group and get its ID + groupReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "group", + Data: groupData, + } + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + memberGroupID1 := resp.Data["id"].(string) + + // Create another group and get its ID + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + memberGroupID2 := resp.Data["id"].(string) + + // Create a group with the above 2 groups as its members + groupData["member_group_ids"] = []string{memberGroupID1, memberGroupID2} + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + groupID := resp.Data["id"].(string) + + // Read the group using its iD and check if all the fields are properly + // set + groupReq = &logical.Request{ + Operation: logical.ReadOperation, + Path: "group/id/" + groupID, + } + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + expectedData := map[string]interface{}{ + "policies": []string{"testpolicy1", "testpolicy2"}, + "metadata": map[string]string{ + "testkey1": "testvalue1", + "testkey2": "testvalue2", + }, + } + expectedData["id"] = resp.Data["id"] + expectedData["name"] = resp.Data["name"] + expectedData["member_group_ids"] = resp.Data["member_group_ids"] + expectedData["member_entity_ids"] = resp.Data["member_entity_ids"] + expectedData["creation_time"] = resp.Data["creation_time"] + expectedData["last_update_time"] = resp.Data["last_update_time"] + expectedData["modify_index"] = resp.Data["modify_index"] + + if !reflect.DeepEqual(expectedData, resp.Data) { + t.Fatalf("bad: group data;\nexpected: %#v\n actual: %#v\n", expectedData, resp.Data) + } + + // Update the policies and metadata in the group + groupReq.Operation = logical.UpdateOperation + groupReq.Data = groupData + + // Update by setting ID in the param + groupData["id"] = groupID + groupData["policies"] = "updatedpolicy1,updatedpolicy2" + groupData["metadata"] = []string{"updatedkey=updatedvalue"} + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + // Check if updates are reflected + groupReq.Operation = logical.ReadOperation + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + expectedData["policies"] = []string{"updatedpolicy1", "updatedpolicy2"} + expectedData["metadata"] = map[string]string{ + "updatedkey": "updatedvalue", + } + expectedData["last_update_time"] = resp.Data["last_update_time"] + expectedData["modify_index"] = resp.Data["modify_index"] + if !reflect.DeepEqual(expectedData, resp.Data) { + t.Fatalf("bad: group data; expected: %#v\n actual: %#v\n", expectedData, resp.Data) + } +} + +func TestIdentityStore_GroupsCRUD_ByID(t *testing.T) { + var resp *logical.Response + var err error + is, _, _ := testIdentityStoreWithGithubAuth(t) + + // Create an entity and get its ID + entityRegisterReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + } + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID1 := resp.Data["id"].(string) + + // Create another entity and get its ID + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID2 := resp.Data["id"].(string) + + // Create a group with the above created 2 entities as its members + groupData := map[string]interface{}{ + "policies": "testpolicy1,testpolicy2", + "metadata": []string{"testkey1=testvalue1", "testkey2=testvalue2"}, + "member_entity_ids": []string{entityID1, entityID2}, + } + + // Create a group and get its ID + groupRegisterReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "group", + Data: groupData, + } + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + memberGroupID1 := resp.Data["id"].(string) + + // Create another group and get its ID + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + memberGroupID2 := resp.Data["id"].(string) + + // Create a group with the above 2 groups as its members + groupData["member_group_ids"] = []string{memberGroupID1, memberGroupID2} + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + groupID := resp.Data["id"].(string) + + // Read the group using its name and check if all the fields are properly + // set + groupReq := &logical.Request{ + Operation: logical.ReadOperation, + Path: "group/id/" + groupID, + } + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + expectedData := map[string]interface{}{ + "policies": []string{"testpolicy1", "testpolicy2"}, + "metadata": map[string]string{ + "testkey1": "testvalue1", + "testkey2": "testvalue2", + }, + } + expectedData["id"] = resp.Data["id"] + expectedData["name"] = resp.Data["name"] + expectedData["member_group_ids"] = resp.Data["member_group_ids"] + expectedData["member_entity_ids"] = resp.Data["member_entity_ids"] + expectedData["creation_time"] = resp.Data["creation_time"] + expectedData["last_update_time"] = resp.Data["last_update_time"] + expectedData["modify_index"] = resp.Data["modify_index"] + + if !reflect.DeepEqual(expectedData, resp.Data) { + t.Fatalf("bad: group data;\nexpected: %#v\n actual: %#v\n", expectedData, resp.Data) + } + + // Update the policies and metadata in the group + groupReq.Operation = logical.UpdateOperation + groupReq.Data = groupData + groupData["policies"] = "updatedpolicy1,updatedpolicy2" + groupData["metadata"] = []string{"updatedkey=updatedvalue"} + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + // Check if updates are reflected + groupReq.Operation = logical.ReadOperation + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + expectedData["policies"] = []string{"updatedpolicy1", "updatedpolicy2"} + expectedData["metadata"] = map[string]string{ + "updatedkey": "updatedvalue", + } + expectedData["last_update_time"] = resp.Data["last_update_time"] + expectedData["modify_index"] = resp.Data["modify_index"] + if !reflect.DeepEqual(expectedData, resp.Data) { + t.Fatalf("bad: group data; expected: %#v\n actual: %#v\n", expectedData, resp.Data) + } + + // Check if delete is working properly + groupReq.Operation = logical.DeleteOperation + resp, err = is.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + groupReq.Operation = logical.ReadOperation + resp, err = is.HandleRequest(groupReq) + if err != nil { + t.Fatal(err) + } + if resp != nil { + t.Fatalf("expected a nil response") + } +} + +/* +Test groups hierarchy: + eng + | | + vault ops + | | | | + kube identity build deploy +*/ +func TestIdentityStore_GroupHierarchyCases(t *testing.T) { + var resp *logical.Response + var err error + is, _, _ := testIdentityStoreWithGithubAuth(t) + groupRegisterReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "group", + } + + // Create 'kube' group + kubeGroupData := map[string]interface{}{ + "name": "kube", + "policies": "kubepolicy", + } + groupRegisterReq.Data = kubeGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + kubeGroupID := resp.Data["id"].(string) + + // Create 'identity' group + identityGroupData := map[string]interface{}{ + "name": "identity", + "policies": "identitypolicy", + } + groupRegisterReq.Data = identityGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + identityGroupID := resp.Data["id"].(string) + + // Create 'build' group + buildGroupData := map[string]interface{}{ + "name": "build", + "policies": "buildpolicy", + } + groupRegisterReq.Data = buildGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + buildGroupID := resp.Data["id"].(string) + + // Create 'deploy' group + deployGroupData := map[string]interface{}{ + "name": "deploy", + "policies": "deploypolicy", + } + groupRegisterReq.Data = deployGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + deployGroupID := resp.Data["id"].(string) + + // Create 'vault' with 'kube' and 'identity' as member groups + vaultMemberGroupIDs := []string{kubeGroupID, identityGroupID} + vaultGroupData := map[string]interface{}{ + "name": "vault", + "policies": "vaultpolicy", + "member_group_ids": vaultMemberGroupIDs, + } + groupRegisterReq.Data = vaultGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + vaultGroupID := resp.Data["id"].(string) + + // Create 'ops' group with 'build' and 'deploy' as member groups + opsMemberGroupIDs := []string{buildGroupID, deployGroupID} + opsGroupData := map[string]interface{}{ + "name": "ops", + "policies": "opspolicy", + "member_group_ids": opsMemberGroupIDs, + } + groupRegisterReq.Data = opsGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + opsGroupID := resp.Data["id"].(string) + + // Create 'eng' group with 'vault' and 'ops' as member groups + engMemberGroupIDs := []string{vaultGroupID, opsGroupID} + engGroupData := map[string]interface{}{ + "name": "eng", + "policies": "engpolicy", + "member_group_ids": engMemberGroupIDs, + } + + groupRegisterReq.Data = engGroupData + resp, err = is.HandleRequest(groupRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + engGroupID := resp.Data["id"].(string) + + /* + fmt.Printf("engGroupID: %#v\n", engGroupID) + fmt.Printf("vaultGroupID: %#v\n", vaultGroupID) + fmt.Printf("opsGroupID: %#v\n", opsGroupID) + fmt.Printf("kubeGroupID: %#v\n", kubeGroupID) + fmt.Printf("identityGroupID: %#v\n", identityGroupID) + fmt.Printf("buildGroupID: %#v\n", buildGroupID) + fmt.Printf("deployGroupID: %#v\n", deployGroupID) + */ + + var memberGroupIDs []string + // Fetch 'eng' group + engGroup, err := is.memDBGroupByID(engGroupID, false) + if err != nil { + t.Fatal(err) + } + memberGroupIDs, err = is.memberGroupIDsByID(engGroup.ID) + if err != nil { + t.Fatal(err) + } + sort.Strings(memberGroupIDs) + sort.Strings(engMemberGroupIDs) + if !reflect.DeepEqual(engMemberGroupIDs, memberGroupIDs) { + t.Fatalf("bad: group membership IDs; expected: %#v\n actual: %#v\n", engMemberGroupIDs, memberGroupIDs) + } + + vaultGroup, err := is.memDBGroupByID(vaultGroupID, false) + if err != nil { + t.Fatal(err) + } + memberGroupIDs, err = is.memberGroupIDsByID(vaultGroup.ID) + if err != nil { + t.Fatal(err) + } + sort.Strings(memberGroupIDs) + sort.Strings(vaultMemberGroupIDs) + if !reflect.DeepEqual(vaultMemberGroupIDs, memberGroupIDs) { + t.Fatalf("bad: group membership IDs; expected: %#v\n actual: %#v\n", vaultMemberGroupIDs, memberGroupIDs) + } + + opsGroup, err := is.memDBGroupByID(opsGroupID, false) + if err != nil { + t.Fatal(err) + } + memberGroupIDs, err = is.memberGroupIDsByID(opsGroup.ID) + if err != nil { + t.Fatal(err) + } + sort.Strings(memberGroupIDs) + sort.Strings(opsMemberGroupIDs) + if !reflect.DeepEqual(opsMemberGroupIDs, memberGroupIDs) { + t.Fatalf("bad: group membership IDs; expected: %#v\n actual: %#v\n", opsMemberGroupIDs, memberGroupIDs) + } + + groupUpdateReq := &logical.Request{ + Operation: logical.UpdateOperation, + } + + // Adding 'engGroupID' under 'kubeGroupID' should fail + groupUpdateReq.Path = "group/name/kube" + groupUpdateReq.Data = kubeGroupData + kubeGroupData["member_group_ids"] = []string{engGroupID} + resp, err = is.HandleRequest(groupUpdateReq) + if err == nil { + t.Fatalf("expected an error response") + } + + // Create an entity ID + entityRegisterReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + } + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID1 := resp.Data["id"].(string) + + // Add the entity as a member of 'kube' group + entityIDReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "group/id/" + kubeGroupID, + Data: map[string]interface{}{ + "member_entity_ids": []string{entityID1}, + }, + } + resp, err = is.HandleRequest(entityIDReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + // Create a second entity ID + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID2 := resp.Data["id"].(string) + + // Add the entity as a member of 'ops' group + entityIDReq.Path = "group/id/" + opsGroupID + entityIDReq.Data = map[string]interface{}{ + "member_entity_ids": []string{entityID2}, + } + resp, err = is.HandleRequest(entityIDReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + // Create a third entity ID + resp, err = is.HandleRequest(entityRegisterReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + entityID3 := resp.Data["id"].(string) + + // Add the entity as a member of 'eng' group + entityIDReq.Path = "group/id/" + engGroupID + entityIDReq.Data = map[string]interface{}{ + "member_entity_ids": []string{entityID3}, + } + resp, err = is.HandleRequest(entityIDReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + policies, err := is.groupPoliciesByEntityID(entityID1) + if err != nil { + t.Fatal(err) + } + sort.Strings(policies) + expected := []string{"kubepolicy", "vaultpolicy", "engpolicy"} + sort.Strings(expected) + if !reflect.DeepEqual(expected, policies) { + t.Fatalf("bad: policies; expected: %#v\nactual:%#v", expected, policies) + } + + policies, err = is.groupPoliciesByEntityID(entityID2) + if err != nil { + t.Fatal(err) + } + sort.Strings(policies) + expected = []string{"opspolicy", "engpolicy"} + sort.Strings(expected) + if !reflect.DeepEqual(expected, policies) { + t.Fatalf("bad: policies; expected: %#v\nactual:%#v", expected, policies) + } + + policies, err = is.groupPoliciesByEntityID(entityID3) + if err != nil { + t.Fatal(err) + } + + if len(policies) != 1 && policies[0] != "engpolicy" { + t.Fatalf("bad: policies; expected: 'engpolicy'\nactual:%#v", policies) + } + + groups, err := is.transitiveGroupsByEntityID(entityID1) + if err != nil { + t.Fatal(err) + } + if len(groups) != 3 { + t.Fatalf("bad: length of groups; expected: 3, actual: %d", len(groups)) + } + + groups, err = is.transitiveGroupsByEntityID(entityID2) + if err != nil { + t.Fatal(err) + } + if len(groups) != 2 { + t.Fatalf("bad: length of groups; expected: 2, actual: %d", len(groups)) + } + + groups, err = is.transitiveGroupsByEntityID(entityID3) + if err != nil { + t.Fatal(err) + } + if len(groups) != 1 { + t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups)) + } +} diff --git a/vault/identity_store_schema.go b/vault/identity_store_schema.go new file mode 100644 index 0000000000..f3026f0e23 --- /dev/null +++ b/vault/identity_store_schema.go @@ -0,0 +1,180 @@ +package vault + +import ( + "fmt" + + memdb "github.com/hashicorp/go-memdb" +) + +func identityStoreSchema() *memdb.DBSchema { + iStoreSchema := &memdb.DBSchema{ + Tables: make(map[string]*memdb.TableSchema), + } + + schemas := []func() *memdb.TableSchema{ + entityTableSchema, + aliasesTableSchema, + groupTableSchema, + } + + for _, schemaFunc := range schemas { + schema := schemaFunc() + if _, ok := iStoreSchema.Tables[schema.Name]; ok { + panic(fmt.Sprintf("duplicate table name: %s", schema.Name)) + } + iStoreSchema.Tables[schema.Name] = schema + } + + return iStoreSchema +} + +func aliasesTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "aliases", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ID", + }, + }, + "entity_id": &memdb.IndexSchema{ + Name: "entity_id", + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "EntityID", + }, + }, + "mount_type": &memdb.IndexSchema{ + Name: "mount_type", + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "MountType", + }, + }, + "factors": &memdb.IndexSchema{ + Name: "factors", + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "MountAccessor", + }, + &memdb.StringFieldIndex{ + Field: "Name", + }, + }, + }, + }, + "metadata": &memdb.IndexSchema{ + Name: "metadata", + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringMapFieldIndex{ + Field: "Metadata", + }, + }, + }, + } +} + +func entityTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "entities", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ID", + }, + }, + "name": &memdb.IndexSchema{ + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + }, + }, + "metadata": &memdb.IndexSchema{ + Name: "metadata", + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringMapFieldIndex{ + Field: "Metadata", + }, + }, + "merged_entity_ids": &memdb.IndexSchema{ + Name: "merged_entity_ids", + Unique: true, + AllowMissing: true, + Indexer: &memdb.StringSliceFieldIndex{ + Field: "MergedEntityIDs", + }, + }, + "bucket_key_hash": &memdb.IndexSchema{ + Name: "bucket_key_hash", + Unique: false, + AllowMissing: false, + Indexer: &memdb.StringFieldIndex{ + Field: "BucketKeyHash", + }, + }, + }, + } +} + +func groupTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "groups", + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ID", + }, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Name", + }, + }, + "member_entity_ids": { + Name: "member_entity_ids", + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringSliceFieldIndex{ + Field: "MemberEntityIDs", + }, + }, + "parent_group_ids": { + Name: "parent_group_ids", + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringSliceFieldIndex{ + Field: "ParentGroupIDs", + }, + }, + "policies": { + Name: "policies", + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringSliceFieldIndex{ + Field: "Policies", + }, + }, + "bucket_key_hash": &memdb.IndexSchema{ + Name: "bucket_key_hash", + Unique: false, + AllowMissing: false, + Indexer: &memdb.StringFieldIndex{ + Field: "BucketKeyHash", + }, + }, + }, + } +} diff --git a/vault/identity_store_structs.go b/vault/identity_store_structs.go new file mode 100644 index 0000000000..b9020c0289 --- /dev/null +++ b/vault/identity_store_structs.go @@ -0,0 +1,75 @@ +package vault + +import ( + "regexp" + "sync" + + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/vault/helper/locksutil" + "github.com/hashicorp/vault/helper/storagepacker" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + log "github.com/mgutz/logxi/v1" +) + +const ( + // Storage prefixes + entityPrefix = "entity/" +) + +var ( + // metaKeyFormatRegEx checks if a metadata key string is valid + metaKeyFormatRegEx = regexp.MustCompile(`^[a-zA-Z0-9=/+_-]+$`).MatchString +) + +const ( + // The meta key prefix reserved for Vault's internal use + metaKeyReservedPrefix = "vault-" + + // The maximum number of metadata key pairs allowed to be registered + metaMaxKeyPairs = 64 + + // The maximum allowed length of a metadata key + metaKeyMaxLength = 128 + + // The maximum allowed length of a metadata value + metaValueMaxLength = 512 +) + +// IdentityStore is composed of its own storage view and a MemDB which +// maintains active in-memory replicas of the storage contents indexed by +// multiple fields. +type IdentityStore struct { + // IdentityStore is a secret backend in Vault + *framework.Backend + + // view is the storage sub-view where all the artifacts of identity store + // gets persisted + view logical.Storage + + // db is the in-memory database where the storage artifacts gets replicated + // to enable richer queries based on multiple indexes. + db *memdb.MemDB + + // validateMountAccessorFunc is a utility from router which returnes the + // properties of the mount given the mount accessor. + validateMountAccessorFunc func(string) *validateMountResponse + + // entityLocks are a set of 256 locks to which all the entities will be + // categorized to while performing storage modifications. + entityLocks []*locksutil.LockEntry + + // groupLock is used to protect modifications to group entries + groupLock sync.RWMutex + + // logger is the server logger copied over from core + logger log.Logger + + // entityPacker is used to pack multiple entity storage entries into 256 + // buckets + entityPacker *storagepacker.StoragePacker + + // groupPacker is used to pack multiple group storage entries into 256 + // buckets + groupPacker *storagepacker.StoragePacker +} diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go new file mode 100644 index 0000000000..9fce0b79d0 --- /dev/null +++ b/vault/identity_store_test.go @@ -0,0 +1,269 @@ +package vault + +import ( + "testing" + "time" + + credGithub "github.com/hashicorp/vault/builtin/credential/github" + "github.com/hashicorp/vault/logical" +) + +func TestIdentityStore_CreateEntity(t *testing.T) { + is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t) + alias := &logical.Alias{ + MountType: "github", + MountAccessor: ghAccessor, + Name: "githubuser", + } + + entity, err := is.CreateEntity(alias) + if err != nil { + t.Fatal(err) + } + if entity == nil { + t.Fatalf("expected a non-nil entity") + } + + if len(entity.Aliases) != 1 { + t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases)) + } + + if entity.Aliases[0].Name != alias.Name { + t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name) + } + + // Try recreating an entity with the same alias details. It should fail. + entity, err = is.CreateEntity(alias) + if err == nil { + t.Fatalf("expected an error") + } +} + +func TestIdentityStore_EntityByAliasFactors(t *testing.T) { + var err error + var resp *logical.Response + + is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t) + + registerData := map[string]interface{}{ + "name": "testentityname", + "metadata": []string{"someusefulkey=someusefulvalue"}, + "policies": []string{"testpolicy1", "testpolicy2"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: registerData, + } + + // Register the entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + idRaw, ok := resp.Data["id"] + if !ok { + t.Fatalf("entity id not present in response") + } + entityID := idRaw.(string) + if entityID == "" { + t.Fatalf("invalid entity id") + } + + aliasData := map[string]interface{}{ + "entity_id": entityID, + "name": "alias_name", + "mount_accessor": ghAccessor, + } + aliasReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "alias", + Data: aliasData, + } + + resp, err = is.HandleRequest(aliasReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + if resp == nil { + t.Fatalf("expected a non-nil response") + } + + entity, err := is.EntityByAliasFactors(ghAccessor, "alias_name", false) + if err != nil { + t.Fatal(err) + } + if entity == nil { + t.Fatalf("expected a non-nil entity") + } + if entity.ID != entityID { + t.Fatalf("bad: entity ID; expected: %q actual: %q", entityID, entity.ID) + } +} + +func TestIdentityStore_WrapInfoInheritance(t *testing.T) { + var err error + var resp *logical.Response + + core, is, ts, _ := testCoreWithIdentityTokenGithub(t) + + registerData := map[string]interface{}{ + "name": "testentityname", + "metadata": []string{"someusefulkey=someusefulvalue"}, + "policies": []string{"testpolicy1", "testpolicy2"}, + } + + registerReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "entity", + Data: registerData, + } + + // Register the entity + resp, err = is.HandleRequest(registerReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + idRaw, ok := resp.Data["id"] + if !ok { + t.Fatalf("entity id not present in response") + } + entityID := idRaw.(string) + if entityID == "" { + t.Fatalf("invalid entity id") + } + + // Create a token which has EntityID set and has update permissions to + // sys/wrapping/wrap + te := &TokenEntry{ + Path: "test", + Policies: []string{"default", responseWrappingPolicyName}, + EntityID: entityID, + } + + if err := ts.create(te); err != nil { + t.Fatal(err) + } + + wrapReq := &logical.Request{ + Path: "sys/wrapping/wrap", + ClientToken: te.ID, + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "foo": "bar", + }, + WrapInfo: &logical.RequestWrapInfo{ + TTL: time.Duration(5 * time.Second), + }, + } + + resp, err = core.HandleRequest(wrapReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v, err: %v", resp, err) + } + + if resp.WrapInfo == nil { + t.Fatalf("expected a non-nil WrapInfo") + } + + if resp.WrapInfo.WrappedEntityID != entityID { + t.Fatalf("bad: WrapInfo in response not having proper entity ID set; expected: %q, actual:%q", entityID, resp.WrapInfo.WrappedEntityID) + } +} + +func TestIdentityStore_TokenEntityInheritance(t *testing.T) { + _, ts, _, _ := TestCoreWithTokenStore(t) + + // Create a token which has EntityID set + te := &TokenEntry{ + Path: "test", + Policies: []string{"dev", "prod"}, + EntityID: "testentityid", + } + + if err := ts.create(te); err != nil { + t.Fatal(err) + } + + // Create a child token; this should inherit the EntityID + tokenReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "create", + ClientToken: te.ID, + } + + resp, err := ts.HandleRequest(tokenReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v err: %v", err, resp) + } + + if resp.Auth.EntityID != te.EntityID { + t.Fatalf("bad: entity ID; expected: %v, actual: %v", te.EntityID, resp.Auth.EntityID) + } + + // Create an orphan token; this should not inherit the EntityID + tokenReq.Path = "create-orphan" + resp, err = ts.HandleRequest(tokenReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v err: %v", err, resp) + } + + if resp.Auth.EntityID != "" { + t.Fatalf("expected entity ID to be not set") + } +} + +func testCoreWithIdentityTokenGithub(t *testing.T) (*Core, *IdentityStore, *TokenStore, string) { + is, ghAccessor, core := testIdentityStoreWithGithubAuth(t) + ts := testTokenStore(t, core) + return core, is, ts, ghAccessor +} + +// testIdentityStoreWithGithubAuth returns an instance of identity store which +// is mounted by default. This function also enables the github auth backend to +// assist with testing aliases and entities that require an valid mount +// accessor of an auth backend. +func testIdentityStoreWithGithubAuth(t *testing.T) (*IdentityStore, string, *Core) { + // Add github credential factory to core config + err := AddTestCredentialBackend("github", credGithub.Factory) + if err != nil { + t.Fatalf("err: %s", err) + } + + c, _, _ := TestCoreUnsealed(t) + + meGH := &MountEntry{ + Table: credentialTableType, + Path: "github/", + Type: "github", + Description: "github auth", + } + + err = c.enableCredential(meGH) + if err != nil { + t.Fatal(err) + } + + // Identity store will be mounted by now, just fetch it from router + identitystore := c.router.MatchingBackend("identity/") + if identitystore == nil { + t.Fatalf("failed to fetch identity store from router") + } + + return identitystore.(*IdentityStore), meGH.Accessor, c +} + +func TestIdentityStore_MetadataKeyRegex(t *testing.T) { + key := "validVALID012_-=+/" + + if !metaKeyFormatRegEx(key) { + t.Fatal("failed to accept valid metadata key") + } + + key = "a:b" + if metaKeyFormatRegEx(key) { + t.Fatal("accepted invalid metadata key") + } +} diff --git a/vault/identity_store_upgrade.go b/vault/identity_store_upgrade.go new file mode 100644 index 0000000000..4a96b20e0a --- /dev/null +++ b/vault/identity_store_upgrade.go @@ -0,0 +1,86 @@ +package vault + +import ( + "strings" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func upgradePaths(i *IdentityStore) []*framework.Path { + return []*framework.Path{ + { + Pattern: "persona$", + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the alias", + }, + "entity_id": { + Type: framework.TypeString, + Description: "Entity ID to which this alias belongs to", + }, + "mount_accessor": { + Type: framework.TypeString, + Description: "Mount accessor to which this alias belongs to", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the alias", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the alias. Format should be a list of `key=value` pairs.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasRegister), + }, + + HelpSynopsis: strings.TrimSpace(aliasHelp["alias"][0]), + HelpDescription: strings.TrimSpace(aliasHelp["alias"][1]), + }, + { + Pattern: "persona/id/" + framework.GenericNameRegex("id"), + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Description: "ID of the alias", + }, + "entity_id": { + Type: framework.TypeString, + Description: "Entity ID to which this alias should be tied to", + }, + "mount_accessor": { + Type: framework.TypeString, + Description: "Mount accessor to which this alias belongs to", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the alias", + }, + "metadata": { + Type: framework.TypeStringSlice, + Description: "Metadata to be associated with the alias. Format should be a comma separated list of `key=value` pairs.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: i.checkPremiumVersion(i.pathAliasIDUpdate), + logical.ReadOperation: i.checkPremiumVersion(i.pathAliasIDRead), + logical.DeleteOperation: i.checkPremiumVersion(i.pathAliasIDDelete), + }, + + HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id"][0]), + HelpDescription: strings.TrimSpace(aliasHelp["alias-id"][1]), + }, + { + Pattern: "persona/id/?$", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: i.checkPremiumVersion(i.pathAliasIDList), + }, + + HelpSynopsis: strings.TrimSpace(aliasHelp["alias-id-list"][0]), + HelpDescription: strings.TrimSpace(aliasHelp["alias-id-list"][1]), + }, + } +} diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go new file mode 100644 index 0000000000..c9ea1eb06e --- /dev/null +++ b/vault/identity_store_util.go @@ -0,0 +1,2122 @@ +package vault + +import ( + "fmt" + "strings" + "sync" + + "github.com/golang/protobuf/ptypes" + memdb "github.com/hashicorp/go-memdb" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/identity" + "github.com/hashicorp/vault/helper/locksutil" + "github.com/hashicorp/vault/helper/storagepacker" + "github.com/hashicorp/vault/helper/strutil" +) + +// parseMetadata takes in a slice of string and parses each item as a key value pair separated by an '=' sign. +func parseMetadata(keyPairs []string) (map[string]string, error) { + if len(keyPairs) == 0 { + return nil, nil + } + + metadata := make(map[string]string, len(keyPairs)) + for _, keyPair := range keyPairs { + keyPairSlice := strings.SplitN(keyPair, "=", 2) + if len(keyPairSlice) != 2 || keyPairSlice[0] == "" { + return nil, fmt.Errorf("invalid key pair %q", keyPair) + } + metadata[keyPairSlice[0]] = keyPairSlice[1] + } + + return metadata, nil +} + +func (c *Core) loadIdentityStoreArtifacts() error { + var err error + if c.identityStore == nil { + return fmt.Errorf("identity store is not setup") + } + + err = c.identityStore.loadEntities() + if err != nil { + return err + } + + err = c.identityStore.loadGroups() + if err != nil { + return err + } + + return nil +} + +func (i *IdentityStore) loadGroups() error { + i.logger.Debug("identity loading groups") + existing, err := i.groupPacker.View().List(groupBucketsPrefix) + if err != nil { + return fmt.Errorf("failed to scan for groups: %v", err) + } + i.logger.Debug("identity: groups collected", "num_existing", len(existing)) + + for _, key := range existing { + bucket, err := i.groupPacker.GetBucket(i.groupPacker.BucketPath(key)) + if err != nil { + return err + } + + if bucket == nil { + continue + } + + for _, item := range bucket.Items { + group, err := i.parseGroupFromBucketItem(item) + if err != nil { + return err + } + if group == nil { + continue + } + + if i.logger.IsTrace() { + i.logger.Trace("loading group", "name", group.Name, "id", group.ID) + } + + i.groupLock.Lock() + defer i.groupLock.Unlock() + + txn := i.db.Txn(true) + defer txn.Abort() + + err = i.upsertGroupInTxn(txn, group, false) + if err != nil { + return fmt.Errorf("failed to update group in memdb: %v", err) + } + + txn.Commit() + } + } + + if i.logger.IsInfo() { + i.logger.Info("identity: groups restored") + } + + return nil +} + +func (i *IdentityStore) loadEntities() error { + // Accumulate existing entities + i.logger.Debug("identity: loading entities") + existing, err := i.entityPacker.View().List(storagepacker.StoragePackerBucketsPrefix) + if err != nil { + return fmt.Errorf("failed to scan for entities: %v", err) + } + i.logger.Debug("identity: entities collected", "num_existing", len(existing)) + + // Make the channels used for the worker pool + broker := make(chan string) + quit := make(chan bool) + + // Buffer these channels to prevent deadlocks + errs := make(chan error, len(existing)) + result := make(chan *storagepacker.Bucket, len(existing)) + + // Use a wait group + wg := &sync.WaitGroup{} + + // Create 64 workers to distribute work to + for j := 0; j < consts.ExpirationRestoreWorkerCount; j++ { + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case bucketKey, ok := <-broker: + // broker has been closed, we are done + if !ok { + return + } + + bucket, err := i.entityPacker.GetBucket(i.entityPacker.BucketPath(bucketKey)) + if err != nil { + errs <- err + continue + } + + // Write results out to the result channel + result <- bucket + + // quit early + case <-quit: + return + } + } + }() + } + + // Distribute the collected keys to the workers in a go routine + wg.Add(1) + go func() { + defer wg.Done() + for j, bucketKey := range existing { + if j%500 == 0 { + i.logger.Trace("identity: enities loading", "progress", j) + } + + select { + case <-quit: + return + + default: + broker <- bucketKey + } + } + + // Close the broker, causing worker routines to exit + close(broker) + }() + + // Restore each key by pulling from the result chan + for j := 0; j < len(existing); j++ { + select { + case err := <-errs: + // Close all go routines + close(quit) + + return err + + case bucket := <-result: + // If there is no entry, nothing to restore + if bucket == nil { + continue + } + + for _, item := range bucket.Items { + entity, err := i.parseEntityFromBucketItem(item) + if err != nil { + return err + } + + if entity == nil { + continue + } + + // Only update MemDB and don't hit the storage again + err = i.upsertEntity(entity, nil, false) + if err != nil { + return fmt.Errorf("failed to update entity in MemDB: %v", err) + } + } + } + } + + // Let all go routines finish + wg.Wait() + + if i.logger.IsInfo() { + i.logger.Info("identity: entities restored") + } + + return nil +} + +// LockForEntityID returns the lock used to modify the entity. +func (i *IdentityStore) LockForEntityID(entityID string) *locksutil.LockEntry { + return locksutil.LockForKey(i.entityLocks, entityID) +} + +// upsertEntityInTxn either creates or updates an existing entity. The +// operations will be updated in both MemDB and storage. If 'persist' is set to +// false, then storage will not be updated. When a alias is transferred from +// one entity to another, both the source and destination entities should get +// updated, in which case, callers should send in both entity and +// previousEntity. +func (i *IdentityStore) upsertEntityInTxn(txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist, lockHeld bool) error { + var err error + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + if entity == nil { + return fmt.Errorf("entity is nil") + } + + // Acquire the lock to modify the entity storage entry + if !lockHeld { + lock := locksutil.LockForKey(i.entityLocks, entity.ID) + lock.Lock() + defer lock.Unlock() + } + + for _, alias := range entity.Aliases { + // Verify that alias is not associated to a different one already + aliasByFactors, err := i.memDBAliasByFactors(alias.MountAccessor, alias.Name, false) + if err != nil { + return err + } + + if aliasByFactors != nil && aliasByFactors.EntityID != entity.ID { + return fmt.Errorf("alias %q in already tied to a different entity %q", alias.ID, aliasByFactors.EntityID) + } + + // Insert or update alias in MemDB using the transaction created above + err = i.memDBUpsertAliasInTxn(txn, alias) + if err != nil { + return err + } + } + + // If previous entity is set, update it in MemDB and persist it + if previousEntity != nil && persist { + err = i.memDBUpsertEntityInTxn(txn, previousEntity) + if err != nil { + return err + } + + // Persist the previous entity object + marshaledPreviousEntity, err := ptypes.MarshalAny(previousEntity) + if err != nil { + return err + } + err = i.entityPacker.PutItem(&storagepacker.Item{ + ID: previousEntity.ID, + Message: marshaledPreviousEntity, + }) + if err != nil { + return err + } + } + + // Insert or update entity in MemDB using the transaction created above + err = i.memDBUpsertEntityInTxn(txn, entity) + if err != nil { + return err + } + + if persist { + entityAsAny, err := ptypes.MarshalAny(entity) + if err != nil { + return err + } + item := &storagepacker.Item{ + ID: entity.ID, + Message: entityAsAny, + } + + // Persist the entity object + err = i.entityPacker.PutItem(item) + if err != nil { + return err + } + } + + return nil +} + +// upsertEntity either creates or updates an existing entity. The operations +// will be updated in both MemDB and storage. If 'persist' is set to false, +// then storage will not be updated. When a alias is transferred from one +// entity to another, both the source and destination entities should get +// updated, in which case, callers should send in both entity and +// previousEntity. +func (i *IdentityStore) upsertEntity(entity *identity.Entity, previousEntity *identity.Entity, persist bool) error { + + // Create a MemDB transaction to update both alias and entity + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.upsertEntityInTxn(txn, entity, previousEntity, persist, false) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +// upsertEntityNonLocked creates or updates an entity. The lock to modify the +// entity should be held before calling this function. +func (i *IdentityStore) upsertEntityNonLocked(entity *identity.Entity, previousEntity *identity.Entity, persist bool) error { + // Create a MemDB transaction to update both alias and entity + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.upsertEntityInTxn(txn, entity, previousEntity, persist, true) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) deleteEntity(entityID string) error { + var err error + var entity *identity.Entity + + if entityID == "" { + return fmt.Errorf("missing entity id") + } + + // Since an entity ID is required to acquire the lock to modify the + // storage, fetch the entity without acquiring the lock + + lockEntity, err := i.memDBEntityByID(entityID, false) + if err != nil { + return err + } + + if lockEntity == nil { + return nil + } + + // Acquire the lock to modify the entity storage entry + lock := locksutil.LockForKey(i.entityLocks, lockEntity.ID) + lock.Lock() + defer lock.Unlock() + + // Create a MemDB transaction to delete entity + txn := i.db.Txn(true) + defer txn.Abort() + + // Fetch the entity using its ID + entity, err = i.memDBEntityByIDInTxn(txn, entityID, true) + if err != nil { + return err + } + + // If there is no entity for the ID, do nothing + if entity == nil { + return nil + } + + // Delete all the aliases in the entity. This function will also remove + // the corresponding alias indexes too. + err = i.deleteAliasesInEntityInTxn(txn, entity, entity.Aliases) + if err != nil { + return err + } + + // Delete the entity using the same transaction + err = i.memDBDeleteEntityByIDInTxn(txn, entity.ID) + if err != nil { + return err + } + + // Delete the entity from storage + err = i.entityPacker.DeleteItem(entity.ID) + if err != nil { + return err + } + + // Committing the transaction *after* successfully deleting entity + txn.Commit() + + return nil +} + +func (i *IdentityStore) deleteAlias(aliasID string) error { + var err error + var alias *identity.Alias + var entity *identity.Entity + + if aliasID == "" { + return fmt.Errorf("missing alias ID") + } + + // Since an entity ID is required to acquire the lock to modify the + // storage, fetch the entity without acquiring the lock + + // Fetch the alias using its ID + + alias, err = i.memDBAliasByID(aliasID, false) + if err != nil { + return err + } + + // If there is no alias for the ID, do nothing + if alias == nil { + return nil + } + + // Find the entity to which the alias is tied to + lockEntity, err := i.memDBEntityByAliasID(alias.ID, false) + if err != nil { + return err + } + + // If there is no entity tied to a valid alias, something is wrong + if lockEntity == nil { + return fmt.Errorf("alias not associated to an entity") + } + + // Acquire the lock to modify the entity storage entry + lock := locksutil.LockForKey(i.entityLocks, lockEntity.ID) + lock.Lock() + defer lock.Unlock() + + // Create a MemDB transaction to delete entity + txn := i.db.Txn(true) + defer txn.Abort() + + // Fetch the alias again after acquiring the lock using the transaction + // created above + alias, err = i.memDBAliasByIDInTxn(txn, aliasID, false) + if err != nil { + return err + } + + // If there is no alias for the ID, do nothing + if alias == nil { + return nil + } + + // Fetch the entity again after acquiring the lock using the transaction + // created above + entity, err = i.memDBEntityByAliasIDInTxn(txn, alias.ID, true) + if err != nil { + return err + } + + // If there is no entity tied to a valid alias, something is wrong + if entity == nil { + return fmt.Errorf("alias not associated to an entity") + } + + // Lock switching should not end up in this code pointing to different + // entities + if entity.ID != entity.ID { + return fmt.Errorf("operating on an entity to which the lock doesn't belong to") + } + + aliases := []*identity.Alias{ + alias, + } + + // Delete alias from the entity object + err = i.deleteAliasesInEntityInTxn(txn, entity, aliases) + if err != nil { + return err + } + + // Update the entity index in the entities table + err = i.memDBUpsertEntityInTxn(txn, entity) + if err != nil { + return err + } + + // Persist the entity object + entityAsAny, err := ptypes.MarshalAny(entity) + if err != nil { + return err + } + item := &storagepacker.Item{ + ID: entity.ID, + Message: entityAsAny, + } + + err = i.entityPacker.PutItem(item) + if err != nil { + return err + } + + // Committing the transaction *after* successfully updating entity in + // storage + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBUpsertAliasInTxn(txn *memdb.Txn, alias *identity.Alias) error { + if txn == nil { + return fmt.Errorf("nil txn") + } + + if alias == nil { + return fmt.Errorf("alias is nil") + } + + aliasRaw, err := txn.First("aliases", "id", alias.ID) + if err != nil { + return fmt.Errorf("failed to lookup alias from memdb using alias ID: %v", err) + } + + if aliasRaw != nil { + err = txn.Delete("aliases", aliasRaw) + if err != nil { + return fmt.Errorf("failed to delete alias from memdb: %v", err) + } + } + + if err := txn.Insert("aliases", alias); err != nil { + return fmt.Errorf("failed to update alias into memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) memDBUpsertAlias(alias *identity.Alias) error { + if alias == nil { + return fmt.Errorf("alias is nil") + } + + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.memDBUpsertAliasInTxn(txn, alias) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBAliasByEntityIDInTxn(txn *memdb.Txn, entityID string, clone bool) (*identity.Alias, error) { + if entityID == "" { + return nil, fmt.Errorf("missing entity id") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + aliasRaw, err := txn.First("aliases", "entity_id", entityID) + if err != nil { + return nil, fmt.Errorf("failed to fetch alias from memdb using entity id: %v", err) + } + + if aliasRaw == nil { + return nil, nil + } + + alias, ok := aliasRaw.(*identity.Alias) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched alias") + } + + if clone { + return alias.Clone() + } + + return alias, nil +} + +func (i *IdentityStore) memDBAliasByEntityID(entityID string, clone bool) (*identity.Alias, error) { + if entityID == "" { + return nil, fmt.Errorf("missing entity id") + } + + txn := i.db.Txn(false) + + return i.memDBAliasByEntityIDInTxn(txn, entityID, clone) +} + +func (i *IdentityStore) memDBAliasByIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Alias, error) { + if aliasID == "" { + return nil, fmt.Errorf("missing alias ID") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + aliasRaw, err := txn.First("aliases", "id", aliasID) + if err != nil { + return nil, fmt.Errorf("failed to fetch alias from memdb using alias ID: %v", err) + } + + if aliasRaw == nil { + return nil, nil + } + + alias, ok := aliasRaw.(*identity.Alias) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched alias") + } + + if clone { + return alias.Clone() + } + + return alias, nil +} + +func (i *IdentityStore) memDBAliasByID(aliasID string, clone bool) (*identity.Alias, error) { + if aliasID == "" { + return nil, fmt.Errorf("missing alias ID") + } + + txn := i.db.Txn(false) + + return i.memDBAliasByIDInTxn(txn, aliasID, clone) +} + +func (i *IdentityStore) memDBAliasByFactors(mountAccessor, aliasName string, clone bool) (*identity.Alias, error) { + if aliasName == "" { + return nil, fmt.Errorf("missing alias name") + } + + if mountAccessor == "" { + return nil, fmt.Errorf("missing mount accessor") + } + + txn := i.db.Txn(false) + aliasRaw, err := txn.First("aliases", "factors", mountAccessor, aliasName) + if err != nil { + return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err) + } + + if aliasRaw == nil { + return nil, nil + } + + alias, ok := aliasRaw.(*identity.Alias) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched alias") + } + + if clone { + return alias.Clone() + } + + return alias, nil +} + +func (i *IdentityStore) memDBAliasesByMetadata(filters map[string]string, clone bool) ([]*identity.Alias, error) { + if filters == nil { + return nil, fmt.Errorf("map filter is nil") + } + + txn := i.db.Txn(false) + defer txn.Abort() + + var args []interface{} + for key, value := range filters { + args = append(args, key, value) + break + } + + aliasesIter, err := txn.Get("aliases", "metadata", args...) + if err != nil { + return nil, fmt.Errorf("failed to lookup aliases using metadata: %v", err) + } + + var aliases []*identity.Alias + for alias := aliasesIter.Next(); alias != nil; alias = aliasesIter.Next() { + entry := alias.(*identity.Alias) + if len(filters) <= 1 || satisfiesMetadataFilters(entry.Metadata, filters) { + if clone { + entry, err = entry.Clone() + if err != nil { + return nil, err + } + } + aliases = append(aliases, entry) + } + } + return aliases, nil +} + +func (i *IdentityStore) memDBDeleteAliasByID(aliasID string) error { + if aliasID == "" { + return nil + } + + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.memDBDeleteAliasByIDInTxn(txn, aliasID) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBDeleteAliasByIDInTxn(txn *memdb.Txn, aliasID string) error { + if aliasID == "" { + return nil + } + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + alias, err := i.memDBAliasByIDInTxn(txn, aliasID, false) + if err != nil { + return err + } + + if alias == nil { + return nil + } + + err = txn.Delete("aliases", alias) + if err != nil { + return fmt.Errorf("failed to delete alias from memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) memDBAliases(ws memdb.WatchSet) (memdb.ResultIterator, error) { + txn := i.db.Txn(false) + + iter, err := txn.Get("aliases", "id") + if err != nil { + return nil, err + } + + ws.Add(iter.WatchCh()) + + return iter, nil +} + +func (i *IdentityStore) memDBUpsertEntityInTxn(txn *memdb.Txn, entity *identity.Entity) error { + if txn == nil { + return fmt.Errorf("nil txn") + } + + if entity == nil { + return fmt.Errorf("entity is nil") + } + + entityRaw, err := txn.First("entities", "id", entity.ID) + if err != nil { + return fmt.Errorf("failed to lookup entity from memdb using entity id: %v", err) + } + + if entityRaw != nil { + err = txn.Delete("entities", entityRaw) + if err != nil { + return fmt.Errorf("failed to delete entity from memdb: %v", err) + } + } + + if err := txn.Insert("entities", entity); err != nil { + return fmt.Errorf("failed to update entity into memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) memDBUpsertEntity(entity *identity.Entity) error { + if entity == nil { + return fmt.Errorf("entity to upsert is nil") + } + + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.memDBUpsertEntityInTxn(txn, entity) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBEntityByIDInTxn(txn *memdb.Txn, entityID string, clone bool) (*identity.Entity, error) { + if entityID == "" { + return nil, fmt.Errorf("missing entity id") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + entityRaw, err := txn.First("entities", "id", entityID) + if err != nil { + return nil, fmt.Errorf("failed to fetch entity from memdb using entity id: %v", err) + } + + if entityRaw == nil { + return nil, nil + } + + entity, ok := entityRaw.(*identity.Entity) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched entity") + } + + if clone { + return entity.Clone() + } + + return entity, nil +} + +func (i *IdentityStore) memDBEntityByID(entityID string, clone bool) (*identity.Entity, error) { + if entityID == "" { + return nil, fmt.Errorf("missing entity id") + } + + txn := i.db.Txn(false) + + return i.memDBEntityByIDInTxn(txn, entityID, clone) +} + +func (i *IdentityStore) memDBEntityByNameInTxn(txn *memdb.Txn, entityName string, clone bool) (*identity.Entity, error) { + if entityName == "" { + return nil, fmt.Errorf("missing entity name") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + entityRaw, err := txn.First("entities", "name", entityName) + if err != nil { + return nil, fmt.Errorf("failed to fetch entity from memdb using entity name: %v", err) + } + + if entityRaw == nil { + return nil, nil + } + + entity, ok := entityRaw.(*identity.Entity) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched entity") + } + + if clone { + return entity.Clone() + } + + return entity, nil +} + +func (i *IdentityStore) memDBEntityByName(entityName string, clone bool) (*identity.Entity, error) { + if entityName == "" { + return nil, fmt.Errorf("missing entity name") + } + + txn := i.db.Txn(false) + + return i.memDBEntityByNameInTxn(txn, entityName, clone) +} + +func (i *IdentityStore) memDBEntitiesByMetadata(filters map[string]string, clone bool) ([]*identity.Entity, error) { + if filters == nil { + return nil, fmt.Errorf("map filter is nil") + } + + txn := i.db.Txn(false) + defer txn.Abort() + + var args []interface{} + for key, value := range filters { + args = append(args, key, value) + break + } + + entitiesIter, err := txn.Get("entities", "metadata", args...) + if err != nil { + return nil, fmt.Errorf("failed to lookup entities using metadata: %v", err) + } + + var entities []*identity.Entity + for entity := entitiesIter.Next(); entity != nil; entity = entitiesIter.Next() { + entry := entity.(*identity.Entity) + if clone { + entry, err = entry.Clone() + if err != nil { + return nil, err + } + } + if len(filters) <= 1 || satisfiesMetadataFilters(entry.Metadata, filters) { + entities = append(entities, entry) + } + } + return entities, nil +} + +func (i *IdentityStore) memDBEntitiesByBucketEntryKeyHash(hashValue string) ([]*identity.Entity, error) { + if hashValue == "" { + return nil, fmt.Errorf("empty hash value") + } + + txn := i.db.Txn(false) + defer txn.Abort() + + return i.memDBEntitiesByBucketEntryKeyHashInTxn(txn, hashValue) +} + +func (i *IdentityStore) memDBEntitiesByBucketEntryKeyHashInTxn(txn *memdb.Txn, hashValue string) ([]*identity.Entity, error) { + if txn == nil { + return nil, fmt.Errorf("nil txn") + } + + if hashValue == "" { + return nil, fmt.Errorf("empty hash value") + } + + entitiesIter, err := txn.Get("entities", "bucket_key_hash", hashValue) + if err != nil { + return nil, fmt.Errorf("failed to lookup entities using bucket entry key hash: %v", err) + } + + var entities []*identity.Entity + for entity := entitiesIter.Next(); entity != nil; entity = entitiesIter.Next() { + entities = append(entities, entity.(*identity.Entity)) + } + + return entities, nil +} + +func (i *IdentityStore) memDBEntityByMergedEntityIDInTxn(txn *memdb.Txn, mergedEntityID string, clone bool) (*identity.Entity, error) { + if mergedEntityID == "" { + return nil, fmt.Errorf("missing merged entity id") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + entityRaw, err := txn.First("entities", "merged_entity_ids", mergedEntityID) + if err != nil { + return nil, fmt.Errorf("failed to fetch entity from memdb using merged entity id: %v", err) + } + + if entityRaw == nil { + return nil, nil + } + + entity, ok := entityRaw.(*identity.Entity) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched entity") + } + + if clone { + return entity.Clone() + } + + return entity, nil +} + +func (i *IdentityStore) memDBEntityByMergedEntityID(mergedEntityID string, clone bool) (*identity.Entity, error) { + if mergedEntityID == "" { + return nil, fmt.Errorf("missing merged entity id") + } + + txn := i.db.Txn(false) + + return i.memDBEntityByMergedEntityIDInTxn(txn, mergedEntityID, clone) +} + +func (i *IdentityStore) memDBEntityByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Entity, error) { + if aliasID == "" { + return nil, fmt.Errorf("missing alias ID") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + alias, err := i.memDBAliasByIDInTxn(txn, aliasID, false) + if err != nil { + return nil, err + } + + if alias == nil { + return nil, nil + } + + return i.memDBEntityByIDInTxn(txn, alias.EntityID, clone) +} + +func (i *IdentityStore) memDBEntityByAliasID(aliasID string, clone bool) (*identity.Entity, error) { + if aliasID == "" { + return nil, fmt.Errorf("missing alias ID") + } + + txn := i.db.Txn(false) + + return i.memDBEntityByAliasIDInTxn(txn, aliasID, clone) +} + +func (i *IdentityStore) memDBDeleteEntityByID(entityID string) error { + if entityID == "" { + return nil + } + + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.memDBDeleteEntityByIDInTxn(txn, entityID) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBDeleteEntityByIDInTxn(txn *memdb.Txn, entityID string) error { + if entityID == "" { + return nil + } + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + entity, err := i.memDBEntityByIDInTxn(txn, entityID, false) + if err != nil { + return err + } + + if entity == nil { + return nil + } + + err = txn.Delete("entities", entity) + if err != nil { + return fmt.Errorf("failed to delete entity from memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) memDBEntities(ws memdb.WatchSet) (memdb.ResultIterator, error) { + txn := i.db.Txn(false) + + iter, err := txn.Get("entities", "id") + if err != nil { + return nil, err + } + + ws.Add(iter.WatchCh()) + + return iter, nil +} + +func (i *IdentityStore) sanitizeAlias(alias *identity.Alias) error { + var err error + + if alias == nil { + return fmt.Errorf("alias is nil") + } + + // Alias must always be tied to an entity + if alias.EntityID == "" { + return fmt.Errorf("missing entity ID") + } + + // Alias must have a name + if alias.Name == "" { + return fmt.Errorf("missing alias name %q", alias.Name) + } + + // Alias metadata should always be map[string]string + err = validateMetadata(alias.Metadata) + if err != nil { + return fmt.Errorf("invalid alias metadata: %v", err) + } + + // Create an ID if there isn't one already + if alias.ID == "" { + alias.ID, err = uuid.GenerateUUID() + if err != nil { + return fmt.Errorf("failed to generate alias ID") + } + } + + // Set the creation and last update times + if alias.CreationTime == nil { + alias.CreationTime = ptypes.TimestampNow() + alias.LastUpdateTime = alias.CreationTime + } else { + alias.LastUpdateTime = ptypes.TimestampNow() + } + + return nil +} + +func (i *IdentityStore) sanitizeEntity(entity *identity.Entity) error { + var err error + + if entity == nil { + return fmt.Errorf("entity is nil") + } + + // Create an ID if there isn't one already + if entity.ID == "" { + entity.ID, err = uuid.GenerateUUID() + if err != nil { + return fmt.Errorf("failed to generate entity id") + } + + // Set the hash value of the storage bucket key in entity + entity.BucketKeyHash = i.entityPacker.BucketKeyHashByItemID(entity.ID) + } + + // Create a name if there isn't one already + if entity.Name == "" { + entity.Name, err = i.generateName("entity") + if err != nil { + return fmt.Errorf("failed to generate entity name") + } + } + + // Entity metadata should always be map[string]string + err = validateMetadata(entity.Metadata) + if err != nil { + return fmt.Errorf("invalid entity metadata: %v", err) + } + + // Set the creation and last update times + if entity.CreationTime == nil { + entity.CreationTime = ptypes.TimestampNow() + entity.LastUpdateTime = entity.CreationTime + } else { + entity.LastUpdateTime = ptypes.TimestampNow() + } + + return nil +} + +func (i *IdentityStore) sanitizeAndUpsertGroup(group *identity.Group, memberGroupIDs []string) error { + var err error + + if group == nil { + return fmt.Errorf("group is nil") + } + + // Create an ID if there isn't one already + if group.ID == "" { + group.ID, err = uuid.GenerateUUID() + if err != nil { + return fmt.Errorf("failed to generate group id") + } + + // Set the hash value of the storage bucket key in group + group.BucketKeyHash = i.groupPacker.BucketKeyHashByItemID(group.ID) + } + + // Create a name if there isn't one already + if group.Name == "" { + group.Name, err = i.generateName("group") + if err != nil { + return fmt.Errorf("failed to generate group name") + } + } + + // Entity metadata should always be map[string]string + err = validateMetadata(group.Metadata) + if err != nil { + return fmt.Errorf("invalid group metadata: %v", err) + } + + // Set the creation and last update times + if group.CreationTime == nil { + group.CreationTime = ptypes.TimestampNow() + group.LastUpdateTime = group.CreationTime + } else { + group.LastUpdateTime = ptypes.TimestampNow() + } + + // Remove duplicate entity IDs and check if all IDs are valid + group.MemberEntityIDs = strutil.RemoveDuplicates(group.MemberEntityIDs, false) + for _, entityID := range group.MemberEntityIDs { + err = i.validateEntityID(entityID) + if err != nil { + return err + } + } + + txn := i.db.Txn(true) + defer txn.Abort() + + memberGroupIDs = strutil.RemoveDuplicates(memberGroupIDs, false) + // After the group lock is held, make membership updates to all the + // relevant groups + for _, memberGroupID := range memberGroupIDs { + memberGroup, err := i.memDBGroupByID(memberGroupID, true) + if err != nil { + return err + } + if memberGroup == nil { + return fmt.Errorf("invalid member group ID %q", memberGroupID) + } + + // Skip if memberGroupID is already a member of group.ID + if strutil.StrListContains(memberGroup.ParentGroupIDs, group.ID) { + continue + } + + // Ensure that adding memberGroupID does not lead to cyclic + // relationships + err = i.validateMemberGroupID(group.ID, memberGroupID) + if err != nil { + return err + } + + memberGroup.ParentGroupIDs = append(memberGroup.ParentGroupIDs, group.ID) + + // This technically is not upsert. It is only update, only the method name is upsert here. + err = i.upsertGroupInTxn(txn, memberGroup, true) + if err != nil { + // Ideally we would want to revert the whole operation in case of + // errors while persisting in member groups. But there is no + // storage transaction support yet. When we do have it, this will need + // an update. + return err + } + } + + err = i.upsertGroupInTxn(txn, group, true) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) validateMemberGroupID(groupID string, memberGroupID string) error { + group, err := i.memDBGroupByID(groupID, true) + if err != nil { + return err + } + + // If group is nil, that means that a group doesn't already exist and its + // okay to add any group as its member group. + if group == nil { + return nil + } + + // Detect self loop + if groupID == memberGroupID { + fmt.Errorf("member group ID %q is same as the ID of the group") + } + + // If adding the memberGroupID to groupID creates a cycle, then groupID must + // be a hop in that loop. Start a DFS traversal from memberGroupID and see if + // it reaches back to groupID. If it does, then it's a loop. + + // Created a visited set + visited := make(map[string]bool) + cycleDetected, err := i.detectCycleDFS(visited, groupID, memberGroupID) + if err != nil { + return fmt.Errorf("failed to perform cyclic relationship detection for member group ID %q", memberGroupID) + } + if cycleDetected { + return fmt.Errorf("cyclic relationship detected for member group ID %q", memberGroupID) + } + + return nil +} + +func (i *IdentityStore) validateEntityID(entityID string) error { + entity, err := i.memDBEntityByID(entityID, false) + if err != nil { + return fmt.Errorf("failed to validate entity ID %q: %v", entityID, err) + } + if entity == nil { + return fmt.Errorf("invalid entity ID %q", entityID) + } + return nil +} + +func (i *IdentityStore) validateGroupID(groupID string) error { + group, err := i.memDBGroupByID(groupID, false) + if err != nil { + return fmt.Errorf("failed to validate group ID %q: %v", groupID, err) + } + if group == nil { + return fmt.Errorf("invalid group ID %q", groupID) + } + return nil +} + +func (i *IdentityStore) deleteAliasesInEntityInTxn(txn *memdb.Txn, entity *identity.Entity, aliases []*identity.Alias) error { + if entity == nil { + return fmt.Errorf("entity is nil") + } + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + var remainList []*identity.Alias + var removeList []*identity.Alias + + for _, item := range aliases { + for _, alias := range entity.Aliases { + if alias.ID == item.ID { + removeList = append(removeList, alias) + } else { + remainList = append(remainList, alias) + } + } + } + + // Remove identity indices from aliases table for those that needs to + // be removed + for _, alias := range removeList { + aliasToBeRemoved, err := i.memDBAliasByIDInTxn(txn, alias.ID, false) + if err != nil { + return err + } + if aliasToBeRemoved == nil { + return fmt.Errorf("alias was not indexed") + } + err = i.memDBDeleteAliasByIDInTxn(txn, aliasToBeRemoved.ID) + if err != nil { + return err + } + } + + // Update the entity with remaining items + entity.Aliases = remainList + + return nil +} + +func (i *IdentityStore) deleteAliasFromEntity(entity *identity.Entity, alias *identity.Alias) error { + if entity == nil { + return fmt.Errorf("entity is nil") + } + + if alias == nil { + return fmt.Errorf("alias is nil") + } + + for aliasIndex, item := range entity.Aliases { + if item.ID == alias.ID { + entity.Aliases = append(entity.Aliases[:aliasIndex], entity.Aliases[aliasIndex+1:]...) + break + } + } + + return nil +} + +func (i *IdentityStore) updateAliasInEntity(entity *identity.Entity, alias *identity.Alias) error { + if entity == nil { + return fmt.Errorf("entity is nil") + } + + if alias == nil { + return fmt.Errorf("alias is nil") + } + + aliasFound := false + for aliasIndex, item := range entity.Aliases { + if item.ID == alias.ID { + aliasFound = true + entity.Aliases[aliasIndex] = alias + } + } + + if !aliasFound { + return fmt.Errorf("alias does not exist in entity") + } + + return nil +} + +// validateMeta validates a set of key/value pairs from the agent config +func validateMetadata(meta map[string]string) error { + if len(meta) > metaMaxKeyPairs { + return fmt.Errorf("metadata cannot contain more than %d key/value pairs", metaMaxKeyPairs) + } + + for key, value := range meta { + if err := validateMetaPair(key, value); err != nil { + return fmt.Errorf("failed to load metadata pair (%q, %q): %v", key, value, err) + } + } + + return nil +} + +// validateMetaPair checks that the given key/value pair is in a valid format +func validateMetaPair(key, value string) error { + if key == "" { + return fmt.Errorf("key cannot be blank") + } + if !metaKeyFormatRegEx(key) { + return fmt.Errorf("key contains invalid characters") + } + if len(key) > metaKeyMaxLength { + return fmt.Errorf("key is too long (limit: %d characters)", metaKeyMaxLength) + } + if strings.HasPrefix(key, metaKeyReservedPrefix) { + return fmt.Errorf("key prefix %q is reserved for internal use", metaKeyReservedPrefix) + } + if len(value) > metaValueMaxLength { + return fmt.Errorf("value is too long (limit: %d characters)", metaValueMaxLength) + } + return nil +} + +// satisfiesMetadataFilters returns true if the metadata map contains the given filters +func satisfiesMetadataFilters(meta map[string]string, filters map[string]string) bool { + for key, value := range filters { + if v, ok := meta[key]; !ok || v != value { + return false + } + } + return true +} + +func (i *IdentityStore) memDBGroupByNameInTxn(txn *memdb.Txn, groupName string, clone bool) (*identity.Group, error) { + if groupName == "" { + return nil, fmt.Errorf("missing group name") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + groupRaw, err := txn.First("groups", "name", groupName) + if err != nil { + return nil, fmt.Errorf("failed to fetch group from memdb using group name: %v", err) + } + + if groupRaw == nil { + return nil, nil + } + + group, ok := groupRaw.(*identity.Group) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched group") + } + + if clone { + return group.Clone() + } + + return group, nil +} + +func (i *IdentityStore) memDBGroupByName(groupName string, clone bool) (*identity.Group, error) { + if groupName == "" { + return nil, fmt.Errorf("missing group name") + } + + txn := i.db.Txn(false) + + return i.memDBGroupByNameInTxn(txn, groupName, clone) +} + +func (i *IdentityStore) upsertGroupInTxn(txn *memdb.Txn, group *identity.Group, persist bool) error { + var err error + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + if group == nil { + return fmt.Errorf("group is nil") + } + + // Increment the modify index of the group + group.ModifyIndex++ + + // Insert or update group in MemDB using the transaction created above + err = i.memDBUpsertGroupInTxn(txn, group) + if err != nil { + return err + } + + if persist { + groupAsAny, err := ptypes.MarshalAny(group) + if err != nil { + return err + } + + item := &storagepacker.Item{ + ID: group.ID, + Message: groupAsAny, + } + + err = i.groupPacker.PutItem(item) + if err != nil { + return err + } + } + + return nil +} + +func (i *IdentityStore) memDBUpsertGroup(group *identity.Group) error { + txn := i.db.Txn(true) + defer txn.Abort() + + err := i.memDBUpsertGroupInTxn(txn, group) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBUpsertGroupInTxn(txn *memdb.Txn, group *identity.Group) error { + if txn == nil { + return fmt.Errorf("nil txn") + } + + if group == nil { + return fmt.Errorf("group is nil") + } + + groupRaw, err := txn.First("groups", "id", group.ID) + if err != nil { + return fmt.Errorf("failed to lookup group from memdb using group id: %v", err) + } + + if groupRaw != nil { + err = txn.Delete("groups", groupRaw) + if err != nil { + return fmt.Errorf("failed to delete group from memdb: %v", err) + } + } + + if err := txn.Insert("groups", group); err != nil { + return fmt.Errorf("failed to update group into memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) deleteGroupByID(groupID string) error { + var err error + var group *identity.Group + + if groupID == "" { + return fmt.Errorf("missing group ID") + } + + // Acquire the lock to modify the group storage entry + i.groupLock.Lock() + defer i.groupLock.Unlock() + + // Create a MemDB transaction to delete group + txn := i.db.Txn(true) + defer txn.Abort() + + group, err = i.memDBGroupByIDInTxn(txn, groupID, false) + if err != nil { + return err + } + + // If there is no entity for the ID, do nothing + if group == nil { + return nil + } + + // Delete the group using the same transaction + err = i.memDBDeleteGroupByIDInTxn(txn, group.ID) + if err != nil { + return err + } + + // Delete the entity from storage + err = i.groupPacker.DeleteItem(group.ID) + if err != nil { + return err + } + + // Committing the transaction *after* successfully deleting group + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBDeleteGroupByIDInTxn(txn *memdb.Txn, groupID string) error { + if groupID == "" { + return nil + } + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + group, err := i.memDBGroupByIDInTxn(txn, groupID, false) + if err != nil { + return err + } + + if group == nil { + return nil + } + + err = txn.Delete("groups", group) + if err != nil { + return fmt.Errorf("failed to delete group from memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) deleteGroupByName(groupName string) error { + var err error + var group *identity.Group + + if groupName == "" { + return fmt.Errorf("missing group name") + } + + // Acquire the lock to modify the group storage entry + i.groupLock.Lock() + defer i.groupLock.Unlock() + + // Create a MemDB transaction to delete group + txn := i.db.Txn(true) + defer txn.Abort() + + // Fetch the group using its ID + group, err = i.memDBGroupByNameInTxn(txn, groupName, false) + if err != nil { + return err + } + + // If there is no entity for the ID, do nothing + if group == nil { + return nil + } + + // Delete the group using the same transaction + err = i.memDBDeleteGroupByNameInTxn(txn, group.Name) + if err != nil { + return err + } + + // Delete the entity from storage + err = i.groupPacker.DeleteItem(group.ID) + if err != nil { + return err + } + + // Committing the transaction *after* successfully deleting group + txn.Commit() + + return nil +} + +func (i *IdentityStore) memDBDeleteGroupByNameInTxn(txn *memdb.Txn, groupName string) error { + if groupName == "" { + return nil + } + + if txn == nil { + return fmt.Errorf("txn is nil") + } + + group, err := i.memDBGroupByNameInTxn(txn, groupName, false) + if err != nil { + return err + } + + if group == nil { + return nil + } + + err = txn.Delete("groups", group) + if err != nil { + return fmt.Errorf("failed to delete group from memdb: %v", err) + } + + return nil +} + +func (i *IdentityStore) memDBGroupByIDInTxn(txn *memdb.Txn, groupID string, clone bool) (*identity.Group, error) { + if groupID == "" { + return nil, fmt.Errorf("missing group ID") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + groupRaw, err := txn.First("groups", "id", groupID) + if err != nil { + return nil, fmt.Errorf("failed to fetch group from memdb using group ID: %v", err) + } + + if groupRaw == nil { + return nil, nil + } + + group, ok := groupRaw.(*identity.Group) + if !ok { + return nil, fmt.Errorf("failed to declare the type of fetched group") + } + + if clone { + return group.Clone() + } + + return group, nil +} + +func (i *IdentityStore) memDBGroupByID(groupID string, clone bool) (*identity.Group, error) { + if groupID == "" { + return nil, fmt.Errorf("missing group ID") + } + + txn := i.db.Txn(false) + + return i.memDBGroupByIDInTxn(txn, groupID, clone) +} + +func (i *IdentityStore) memDBGroupsByPolicyInTxn(txn *memdb.Txn, policyName string, clone bool) ([]*identity.Group, error) { + if policyName == "" { + return nil, fmt.Errorf("missing policy name") + } + + groupsIter, err := txn.Get("groups", "policies", policyName) + if err != nil { + return nil, fmt.Errorf("failed to lookup groups using policy name: %v", err) + } + + var groups []*identity.Group + for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { + entry := group.(*identity.Group) + if clone { + entry, err = entry.Clone() + if err != nil { + return nil, err + } + } + groups = append(groups, entry) + } + + return groups, nil +} + +func (i *IdentityStore) memDBGroupsByPolicy(policyName string, clone bool) ([]*identity.Group, error) { + if policyName == "" { + return nil, fmt.Errorf("missing policy name") + } + + txn := i.db.Txn(false) + + return i.memDBGroupsByPolicyInTxn(txn, policyName, clone) +} + +func (i *IdentityStore) memDBGroupsByParentGroupIDInTxn(txn *memdb.Txn, memberGroupID string, clone bool) ([]*identity.Group, error) { + if memberGroupID == "" { + return nil, fmt.Errorf("missing member group ID") + } + + groupsIter, err := txn.Get("groups", "parent_group_ids", memberGroupID) + if err != nil { + return nil, fmt.Errorf("failed to lookup groups using member group ID: %v", err) + } + + var groups []*identity.Group + for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { + entry := group.(*identity.Group) + if clone { + entry, err = entry.Clone() + if err != nil { + return nil, err + } + } + groups = append(groups, entry) + } + + return groups, nil +} + +func (i *IdentityStore) memDBGroupsByParentGroupID(memberGroupID string, clone bool) ([]*identity.Group, error) { + if memberGroupID == "" { + return nil, fmt.Errorf("missing member group ID") + } + + txn := i.db.Txn(false) + + return i.memDBGroupsByParentGroupIDInTxn(txn, memberGroupID, clone) +} + +func (i *IdentityStore) memDBGroupsByMemberEntityID(entityID string, clone bool) ([]*identity.Group, error) { + if entityID == "" { + return nil, fmt.Errorf("missing entity ID") + } + + txn := i.db.Txn(false) + defer txn.Abort() + + groupsIter, err := txn.Get("groups", "member_entity_ids", entityID) + if err != nil { + return nil, fmt.Errorf("failed to lookup groups using entity ID: %v", err) + } + + var groups []*identity.Group + for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { + entry := group.(*identity.Group) + if clone { + entry, err = entry.Clone() + if err != nil { + return nil, err + } + } + groups = append(groups, entry) + } + + return groups, nil +} + +func (i *IdentityStore) groupPoliciesByEntityID(entityID string) ([]string, error) { + if entityID == "" { + return nil, fmt.Errorf("empty entity ID") + } + + groups, err := i.memDBGroupsByMemberEntityID(entityID, false) + if err != nil { + return nil, err + } + + visited := make(map[string]bool) + var policies []string + for _, group := range groups { + policies, err = i.collectPoliciesReverseDFS(group, visited, nil) + if err != nil { + return nil, err + } + } + + return strutil.RemoveDuplicates(policies, false), nil +} + +func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity.Group, error) { + if entityID == "" { + return nil, fmt.Errorf("empty entity ID") + } + + groups, err := i.memDBGroupsByMemberEntityID(entityID, false) + if err != nil { + return nil, err + } + + visited := make(map[string]bool) + var tGroups []*identity.Group + for _, group := range groups { + tGroups, err = i.collectGroupsReverseDFS(group, visited, nil) + if err != nil { + return nil, err + } + } + + // Remove duplicates + groupMap := make(map[string]*identity.Group) + for _, group := range tGroups { + groupMap[group.ID] = group + } + + tGroups = nil + for _, group := range groupMap { + tGroups = append(tGroups, group) + } + + return tGroups, nil +} + +func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) { + if group == nil { + return nil, fmt.Errorf("nil group") + } + + // If traversal for a groupID is performed before, skip it + if visited[group.ID] { + return groups, nil + } + visited[group.ID] = true + + groups = append(groups, group) + + // Traverse all the parent groups + for _, parentGroupID := range group.ParentGroupIDs { + parentGroup, err := i.memDBGroupByID(parentGroupID, false) + if err != nil { + return nil, err + } + groups, err = i.collectGroupsReverseDFS(parentGroup, visited, groups) + if err != nil { + return nil, fmt.Errorf("failed to collect group at parent group ID %q", parentGroup.ID) + } + } + + return groups, nil +} + +func (i *IdentityStore) collectPoliciesReverseDFS(group *identity.Group, visited map[string]bool, policies []string) ([]string, error) { + if group == nil { + return nil, fmt.Errorf("nil group") + } + + // If traversal for a groupID is performed before, skip it + if visited[group.ID] { + return policies, nil + } + visited[group.ID] = true + + policies = append(policies, group.Policies...) + + // Traverse all the parent groups + for _, parentGroupID := range group.ParentGroupIDs { + parentGroup, err := i.memDBGroupByID(parentGroupID, false) + if err != nil { + return nil, err + } + policies, err = i.collectPoliciesReverseDFS(parentGroup, visited, policies) + if err != nil { + return nil, fmt.Errorf("failed to collect policies at parent group ID %q", parentGroup.ID) + } + } + + return policies, nil +} + +func (i *IdentityStore) detectCycleDFS(visited map[string]bool, startingGroupID, groupID string) (bool, error) { + // If the traversal reaches the startingGroupID, a loop is detected + if startingGroupID == groupID { + return true, nil + } + + // If traversal for a groupID is performed before, skip it + if visited[groupID] { + return false, nil + } + visited[groupID] = true + + group, err := i.memDBGroupByID(groupID, true) + if err != nil { + return false, err + } + if group == nil { + return false, nil + } + + // Fetch all groups in which groupID is present as a ParentGroupID. In + // other words, find all the subgroups of groupID. + memberGroups, err := i.memDBGroupsByParentGroupID(groupID, false) + if err != nil { + return false, err + } + + // DFS traverse the member groups + for _, memberGroup := range memberGroups { + cycleDetected, err := i.detectCycleDFS(visited, startingGroupID, memberGroup.ID) + if err != nil { + return false, fmt.Errorf("failed to perform cycle detection at member group ID %q", memberGroup.ID) + } + if cycleDetected { + return true, fmt.Errorf("cycle detected at member group ID %q", memberGroup.ID) + } + } + + return false, nil +} + +func (i *IdentityStore) memberGroupIDsByID(groupID string) ([]string, error) { + var memberGroupIDs []string + memberGroups, err := i.memDBGroupsByParentGroupID(groupID, false) + if err != nil { + return nil, err + } + for _, memberGroup := range memberGroups { + memberGroupIDs = append(memberGroupIDs, memberGroup.ID) + } + return memberGroupIDs, nil +} + +func (i *IdentityStore) memDBGroupIterator(ws memdb.WatchSet) (memdb.ResultIterator, error) { + txn := i.db.Txn(false) + + iter, err := txn.Get("groups", "id") + if err != nil { + return nil, err + } + + ws.Add(iter.WatchCh()) + + return iter, nil +} + +func (i *IdentityStore) generateName(entryType string) (string, error) { + var name string +OUTER: + for { + randBytes, err := uuid.GenerateRandomBytes(4) + if err != nil { + return "", err + } + name = fmt.Sprintf("%s_%s", entryType, fmt.Sprintf("%08x", randBytes[0:4])) + + switch entryType { + case "entity": + entity, err := i.memDBEntityByName(name, false) + if err != nil { + return "", err + } + if entity == nil { + break OUTER + } + case "group": + group, err := i.memDBGroupByName(name, false) + if err != nil { + return "", err + } + if group == nil { + break OUTER + } + default: + return "", fmt.Errorf("unrecognized type %q", entryType) + } + } + + return name, nil +} + +func (i *IdentityStore) memDBGroupsByBucketEntryKeyHash(hashValue string) ([]*identity.Group, error) { + if hashValue == "" { + return nil, fmt.Errorf("empty hash value") + } + + txn := i.db.Txn(false) + defer txn.Abort() + + return i.memDBGroupsByBucketEntryKeyHashInTxn(txn, hashValue) +} + +func (i *IdentityStore) memDBGroupsByBucketEntryKeyHashInTxn(txn *memdb.Txn, hashValue string) ([]*identity.Group, error) { + if txn == nil { + return nil, fmt.Errorf("nil txn") + } + + if hashValue == "" { + return nil, fmt.Errorf("empty hash value") + } + + groupsIter, err := txn.Get("groups", "bucket_key_hash", hashValue) + if err != nil { + return nil, fmt.Errorf("failed to lookup groups using bucket entry key hash: %v", err) + } + + var groups []*identity.Group + for group := groupsIter.Next(); group != nil; group = groupsIter.Next() { + groups = append(groups, group.(*identity.Group)) + } + + return groups, nil +} diff --git a/vault/identity_store_util_test.go b/vault/identity_store_util_test.go new file mode 100644 index 0000000000..d82697ad0a --- /dev/null +++ b/vault/identity_store_util_test.go @@ -0,0 +1,40 @@ +package vault + +import ( + "reflect" + "testing" +) + +func TestIdentityStore_parseMetadata(t *testing.T) { + goodKVs := []string{ + "key1=value1", + "key2=value1=value2", + } + expectedMap := map[string]string{ + "key1": "value1", + "key2": "value1=value2", + } + + actualMap, err := parseMetadata(goodKVs) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expectedMap, actualMap) { + t.Fatalf("bad: metadata; expected: %#v\n, actual: %#v\n", expectedMap, actualMap) + } + + badKV := []string{ + "=world", + } + actualMap, err = parseMetadata(badKV) + if err == nil { + t.Fatalf("expected an error; got: %#v", actualMap) + } + + badKV[0] = "world" + actualMap, err = parseMetadata(badKV) + if err == nil { + t.Fatalf("expected an error: %#v", actualMap) + } +} diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 3f9243b017..9e6702d56c 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -149,6 +149,17 @@ func TestSystemBackend_mounts(t *testing.T) { }, "local": true, }, + "identity/": map[string]interface{}{ + "description": "identity store", + "type": "identity", + "accessor": resp.Data["identity/"].(map[string]interface{})["accessor"], + "config": map[string]interface{}{ + "default_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), + "max_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), + "force_no_cache": false, + }, + "local": false, + }, } if !reflect.DeepEqual(resp.Data, exp) { t.Fatalf("Got:\n%#v\nExpected:\n%#v", resp.Data, exp) diff --git a/vault/mount.go b/vault/mount.go index 41aece9762..7f4dd843da 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -50,12 +50,14 @@ var ( "auth/", "sys/", "cubbyhole/", + "identity/", } untunableMounts = []string{ "cubbyhole/", "sys/", "audit/", + "identity/", } // singletonMounts can only exist in one location and are @@ -64,6 +66,7 @@ var ( "cubbyhole", "system", "token", + "identity", } // mountAliases maps old backend names to new backend names, allowing us @@ -255,6 +258,8 @@ func (c *Core) mount(entry *MountEntry) error { return err } + c.setCoreBackend(entry, backend, view) + newTable := c.mounts.shallowClone() newTable.Entries = append(newTable.Entries, entry) if err := c.persistMounts(newTable, entry.Local); err != nil { @@ -713,14 +718,8 @@ func (c *Core) setupMounts() error { return err } - switch entry.Type { - case "system": - c.systemBarrierView = view - case "cubbyhole": - ch := backend.(*CubbyholeBackend) - ch.saltUUID = entry.UUID - ch.storageView = view - } + c.setCoreBackend(entry, backend, view) + ROUTER_MOUNT: // Mount the backend err = c.router.Mount(backend, entry.Path, entry, view) @@ -865,8 +864,29 @@ func (c *Core) requiredMountTable() *MountTable { UUID: sysUUID, Accessor: sysAccessor, } + + identityUUID, err := uuid.GenerateUUID() + if err != nil { + panic(fmt.Sprintf("could not create identity mount entry UUID: %v", err)) + } + identityAccessor, err := c.generateMountAccessor("identity") + if err != nil { + panic(fmt.Sprintf("could not generate identity accessor: %v", err)) + } + + identityMount := &MountEntry{ + Table: mountTableType, + Path: "identity/", + Type: "identity", + Description: "identity store", + UUID: identityUUID, + Accessor: identityAccessor, + } + table.Entries = append(table.Entries, cubbyholeMount) table.Entries = append(table.Entries, sysMount) + table.Entries = append(table.Entries, identityMount) + return table } @@ -898,3 +918,17 @@ func (c *Core) singletonMountTables() (mounts, auth *MountTable) { return } + +func (c *Core) setCoreBackend(entry *MountEntry, backend logical.Backend, view *BarrierView) { + switch entry.Type { + case "system": + c.systemBackend = backend.(*SystemBackend) + c.systemBarrierView = view + case "cubbyhole": + ch := backend.(*CubbyholeBackend) + ch.saltUUID = entry.UUID + ch.storageView = view + case "identity": + c.identityStore = backend.(*IdentityStore) + } +} diff --git a/vault/mount_test.go b/vault/mount_test.go index cf24e18a0f..7e4c5f709f 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -592,31 +592,26 @@ func testCore_MountTable_UpgradeToTyped_Common( } func verifyDefaultTable(t *testing.T, table *MountTable) { - if len(table.Entries) != 3 { + if len(table.Entries) != 4 { t.Fatalf("bad: %v", table.Entries) } table.sortEntriesByPath() - for idx, entry := range table.Entries { - switch idx { - case 0: - if entry.Path != "cubbyhole/" { - t.Fatalf("bad: %v", entry) - } + for _, entry := range table.Entries { + switch entry.Path { + case "cubbyhole/": if entry.Type != "cubbyhole" { t.Fatalf("bad: %v", entry) } - case 1: - if entry.Path != "secret/" { - t.Fatalf("bad: %v", entry) - } + case "secret/": if entry.Type != "kv" { t.Fatalf("bad: %v", entry) } - case 2: - if entry.Path != "sys/" { + case "sys/": + if entry.Type != "system" { t.Fatalf("bad: %v", entry) } - if entry.Type != "system" { + case "identity/": + if entry.Type != "identity" { t.Fatalf("bad: %v", entry) } } @@ -637,12 +632,13 @@ func TestSingletonMountTableFunc(t *testing.T) { mounts, auth := c.singletonMountTables() - if len(mounts.Entries) != 1 { + if len(mounts.Entries) != 2 { t.Fatal("length of mounts is wrong") } for _, entry := range mounts.Entries { switch entry.Type { case "system": + case "identity": default: t.Fatalf("unknown type %s", entry.Type) } diff --git a/vault/request_handling.go b/vault/request_handling.go index b003b3ff4a..0297242219 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -8,6 +8,7 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/helper/strutil" @@ -66,7 +67,7 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err resp.WrapInfo.TTL != 0 if wrapping { - cubbyResp, cubbyErr := c.wrapInCubbyhole(req, resp) + cubbyResp, cubbyErr := c.wrapInCubbyhole(req, resp, auth) // If not successful, returns either an error response from the // cubbyhole backend or an error; if either is set, set resp and err to // those and continue so that that's what we audit log. Otherwise @@ -387,8 +388,42 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log // If the response generated an authentication, then generate the token var auth *logical.Auth if resp != nil && resp.Auth != nil { + var entity *identity.Entity auth = resp.Auth + if auth.Alias != nil { + // Overwrite the mount type and mount path in the alias + // information + auth.Alias.MountType = req.MountType + auth.Alias.MountAccessor = req.MountAccessor + + if auth.Alias.Name == "" { + return nil, nil, fmt.Errorf("missing name in alias") + } + + var err error + + // Check if an entity already exists for the given alias + entity, err = c.identityStore.EntityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false) + if err != nil { + return nil, nil, err + } + + // If not, create one. + if entity == nil { + c.logger.Debug("core: creating a new entity", "alias", auth.Alias) + entity, err = c.identityStore.CreateEntity(auth.Alias) + if err != nil { + return nil, nil, err + } + if entity == nil { + return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias") + } + } + + auth.EntityID = entity.ID + } + if strutil.StrListSubset(auth.Policies, []string{"root"}) { return logical.ErrorResponse("authentication backends cannot create root tokens"), nil, logical.ErrInvalidRequest } diff --git a/vault/router.go b/vault/router.go index f05e2076d5..479a323c8a 100644 --- a/vault/router.go +++ b/vault/router.go @@ -47,6 +47,31 @@ type routeEntry struct { loginPaths *radix.Tree } +type validateMountResponse struct { + MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"` + MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"` + MountPath string `json:"mount_path" structs:"mount_path" mapstructure:"mount_path"` +} + +// validateMountByAccessor returns the mount type and ID for a given mount +// accessor +func (r *Router) validateMountByAccessor(accessor string) *validateMountResponse { + if accessor == "" { + return nil + } + + mountEntry := r.MatchingMountByAccessor(accessor) + if mountEntry == nil { + return nil + } + + return &validateMountResponse{ + MountAccessor: mountEntry.Accessor, + MountType: mountEntry.Type, + MountPath: mountEntry.Path, + } +} + // SaltID is used to apply a salt and hash to an ID to make sure its not reversible func (re *routeEntry) SaltID(id string) string { return salt.SaltID(re.mountEntry.UUID, id, salt.SHA1Hash) @@ -330,6 +355,17 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica // Attach the storage view for the request req.Storage = re.storageView + originalEntityID := req.EntityID + + // Allow EntityID to passthrough to the system backend. This is required to + // allow clients to generate MFA credentials in respective entity objects + // in identity store via the system backend. + switch { + case strings.HasPrefix(originalPath, "sys/"): + default: + req.EntityID = "" + } + // Hash the request token unless this is the token backend clientToken := req.ClientToken switch { @@ -385,6 +421,12 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica // This is only set in one place, after routing, so should never be set // by a backend req.SetLastRemoteWAL(0) + + // This will be used for attaching the mount accessor for the identities + // returned by the authentication backends + req.MountAccessor = re.mountEntry.Accessor + + req.EntityID = originalEntityID }() // Invoke the backend @@ -393,6 +435,11 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica return nil, ok, exists, err } else { resp, err := re.backend.HandleRequest(req) + if resp != nil && + resp.Auth != nil && + resp.Auth.Alias != nil { + resp.Auth.Alias.MountAccessor = re.mountEntry.Accessor + } return resp, false, false, err } } diff --git a/vault/testing.go b/vault/testing.go index 3e500c2fe1..d26f2d640b 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -154,6 +154,15 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo noopBackends["http"] = func(config *logical.BackendConfig) (logical.Backend, error) { return new(rawHTTP), nil } + + credentialBackends := make(map[string]logical.Factory) + for backendName, backendFactory := range noopBackends { + credentialBackends[backendName] = backendFactory + } + for backendName, backendFactory := range testCredentialBackends { + credentialBackends[backendName] = backendFactory + } + logicalBackends := make(map[string]logical.Factory) for backendName, backendFactory := range noopBackends { logicalBackends[backendName] = backendFactory @@ -167,7 +176,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo Physical: physicalBackend, AuditBackends: noopAudits, LogicalBackends: logicalBackends, - CredentialBackends: noopBackends, + CredentialBackends: credentialBackends, DisableMlock: true, Logger: logger, } @@ -375,6 +384,7 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) { } var testLogicalBackends = map[string]logical.Factory{} +var testCredentialBackends = map[string]logical.Factory{} // Starts the test server which responds to SSH authentication. // Used to test the SSH secret backend. @@ -467,6 +477,19 @@ func executeServerCommand(ch ssh.Channel, req *ssh.Request) { }() } +// This adds a credential backend for the test core. This needs to be +// invoked before the test core is created. +func AddTestCredentialBackend(name string, factory logical.Factory) error { + if name == "" { + return fmt.Errorf("missing backend name") + } + if factory == nil { + return fmt.Errorf("missing backend factory function") + } + testCredentialBackends[name] = factory + return nil +} + // This adds a logical backend for the test core. This needs to be // invoked before the test core is created. func AddTestLogicalBackend(name string, factory logical.Factory) error { diff --git a/vault/token_store.go b/vault/token_store.go index 2708e48c22..04d019257a 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -581,6 +581,8 @@ type TokenEntry struct { NumUsesDeprecated int `json:"NumUses" mapstructure:"NumUses" structs:"NumUses"` CreationTimeDeprecated int64 `json:"CreationTime" mapstructure:"CreationTime" structs:"CreationTime"` ExplicitMaxTTLDeprecated time.Duration `json:"ExplicitMaxTTL" mapstructure:"ExplicitMaxTTL" structs:"ExplicitMaxTTL"` + + EntityID string `json:"entity_id" mapstructure:"entity_id" structs:"entity_id"` } // tsRoleEntry contains token store role information @@ -1730,6 +1732,13 @@ func (ts *TokenStore) handleCreateCommon( } } + // At this point, it is clear whether the token is going to be an orphan or + // not. If the token is not going to be an orphan, inherit the parent's + // entity identifier into the child token. + if te.Parent != "" { + te.EntityID = parent.EntityID + } + if data.ExplicitMaxTTL != "" { dur, err := parseutil.ParseDurationSecond(data.ExplicitMaxTTL) if err != nil { @@ -1872,6 +1881,7 @@ func (ts *TokenStore) handleCreateCommon( }, ClientToken: te.ID, Accessor: te.Accessor, + EntityID: te.EntityID, } if ts.policyLookupFunc != nil { @@ -2037,6 +2047,7 @@ func (ts *TokenStore) handleLookup( "expire_time": nil, "ttl": int64(0), "explicit_max_ttl": int64(out.ExplicitMaxTTL.Seconds()), + "entity_id": out.EntityID, }, } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 966508c4cc..8df37557ea 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -1448,6 +1448,7 @@ func TestTokenStore_HandleRequest_Lookup(t *testing.T) { "ttl": int64(0), "explicit_max_ttl": int64(0), "expire_time": nil, + "entity_id": "", } if resp.Data["creation_time"].(int64) == 0 { @@ -1487,6 +1488,7 @@ func TestTokenStore_HandleRequest_Lookup(t *testing.T) { "ttl": int64(3600), "explicit_max_ttl": int64(0), "renewable": true, + "entity_id": "", } if resp.Data["creation_time"].(int64) == 0 { @@ -1537,6 +1539,7 @@ func TestTokenStore_HandleRequest_Lookup(t *testing.T) { "ttl": int64(3600), "explicit_max_ttl": int64(0), "renewable": true, + "entity_id": "", } if resp.Data["creation_time"].(int64) == 0 { @@ -1618,6 +1621,7 @@ func TestTokenStore_HandleRequest_LookupSelf(t *testing.T) { "creation_ttl": int64(3600), "ttl": int64(3600), "explicit_max_ttl": int64(0), + "entity_id": "", } if resp.Data["creation_time"].(int64) == 0 { diff --git a/vault/wrapping.go b/vault/wrapping.go index 51715938df..cad6ee0322 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -72,7 +72,7 @@ func (c *Core) ensureWrappingKey() error { return nil } -func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*logical.Response, error) { +func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response, auth *logical.Auth) (*logical.Response, error) { // Before wrapping, obey special rules for listing: if no entries are // found, 404. This prevents unwrapping only to find empty data. if req.Operation == logical.ListOperation { @@ -120,6 +120,10 @@ func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*l resp.WrapInfo.CreationPath = req.Path } + if auth != nil && auth.EntityID != "" { + resp.WrapInfo.WrappedEntityID = auth.EntityID + } + // This will only be non-nil if this response contains a token, so in that // case put the accessor in the wrap info. if resp.Auth != nil { @@ -223,7 +227,7 @@ func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*l return cubbyResp, nil } - auth := &logical.Auth{ + wAuth := &logical.Auth{ ClientToken: te.ID, Policies: []string{"response-wrapping"}, LeaseOptions: logical.LeaseOptions{ @@ -233,7 +237,7 @@ func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*l } // Register the wrapped token with the expiration manager - if err := c.expiration.RegisterAuth(te.Path, auth); err != nil { + if err := c.expiration.RegisterAuth(te.Path, wAuth); err != nil { // Revoke since it's not yet being tracked for expiration c.tokenStore.Revoke(te.ID) c.logger.Error("core: failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err) diff --git a/vendor/github.com/hashicorp/go-memdb/LICENSE b/vendor/github.com/hashicorp/go-memdb/LICENSE new file mode 100644 index 0000000000..e87a115e46 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/LICENSE @@ -0,0 +1,363 @@ +Mozilla Public License, version 2.0 + +1. Definitions + +1.1. "Contributor" + + means each individual or legal entity that creates, contributes to the + creation of, or owns Covered Software. + +1.2. "Contributor Version" + + means the combination of the Contributions of others (if any) used by a + Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + + means Source Code Form to which the initial Contributor has attached the + notice in Exhibit A, the Executable Form of such Source Code Form, and + Modifications of such Source Code Form, in each case including portions + thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + a. that the initial Contributor has attached the notice described in + Exhibit B to the Covered Software; or + + b. that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the terms of + a Secondary License. + +1.6. "Executable Form" + + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + + means a work that combines Covered Software with other material, in a + separate file or files, that is not Covered Software. + +1.8. "License" + + means this document. + +1.9. "Licensable" + + means having the right to grant, to the maximum extent possible, whether + at the time of the initial grant or subsequently, any and all of the + rights conveyed by this License. + +1.10. "Modifications" + + means any of the following: + + a. any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered Software; or + + b. any new file in Source Code Form that contains any Covered Software. + +1.11. "Patent Claims" of a Contributor + + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the License, + by the making, using, selling, offering for sale, having made, import, + or transfer of either its Contributions or its Contributor Version. + +1.12. "Secondary License" + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. "Source Code Form" + + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that controls, is + controlled by, or is under common control with You. For purposes of this + definition, "control" means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution + become effective for each Contribution on the date the Contributor first + distributes such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under + this License. No additional rights or licenses will be implied from the + distribution or licensing of Covered Software under this License. + Notwithstanding Section 2.1(b) above, no patent license is granted by a + Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of + its Contributions. + + This License does not grant any rights in the trademarks, service marks, + or logos of any Contributor (except as may be necessary to comply with + the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this + License (see Section 10.2) or under the terms of a Secondary License (if + permitted under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its + Contributions are its original creation(s) or it has sufficient rights to + grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + + This License is not intended to limit any rights You have under + applicable copyright doctrines of fair use, fair dealing, or other + equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under + the terms of this License. You must inform recipients that the Source + Code Form of the Covered Software is governed by the terms of this + License, and how they can obtain a copy of this License. You may not + attempt to alter or restrict the recipients' rights in the Source Code + Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter the + recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for + the Covered Software. If the Larger Work is a combination of Covered + Software with a work governed by one or more Secondary Licenses, and the + Covered Software is not Incompatible With Secondary Licenses, this + License permits You to additionally distribute such Covered Software + under the terms of such Secondary License(s), so that the recipient of + the Larger Work may, at their option, further distribute the Covered + Software under the terms of either this License or such Secondary + License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices + (including copyright notices, patent notices, disclaimers of warranty, or + limitations of liability) contained within the Source Code Form of the + Covered Software, except that You may alter any license notices to the + extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on + behalf of any Contributor. You must make it absolutely clear that any + such warranty, support, indemnity, or liability obligation is offered by + You alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, + judicial order, or regulation then You must: (a) comply with the terms of + this License to the maximum extent possible; and (b) describe the + limitations and the code they affect. Such description must be placed in a + text file included with all distributions of the Covered Software under + this License. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing + basis, if such Contributor fails to notify You of the non-compliance by + some reasonable means prior to 60 days after You have come back into + compliance. Moreover, Your grants from a particular Contributor are + reinstated on an ongoing basis if such Contributor notifies You of the + non-compliance by some reasonable means, this is the first time You have + received notice of non-compliance with this License from such + Contributor, and You become compliant prior to 30 days after Your receipt + of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, + counter-claims, and cross-claims) alleging that a Contributor Version + directly or indirectly infringes any patent, then the rights granted to + You by any and all Contributors for the Covered Software under Section + 2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an "as is" basis, + without warranty of any kind, either expressed, implied, or statutory, + including, without limitation, warranties that the Covered Software is free + of defects, merchantable, fit for a particular purpose or non-infringing. + The entire risk as to the quality and performance of the Covered Software + is with You. Should any Covered Software prove defective in any respect, + You (not any Contributor) assume the cost of any necessary servicing, + repair, or correction. This disclaimer of warranty constitutes an essential + part of this License. No use of any Covered Software is authorized under + this License except under this disclaimer. + +7. Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from + such party's negligence to the extent applicable law prohibits such + limitation. Some jurisdictions do not allow the exclusion or limitation of + incidental or consequential damages, so this exclusion and limitation may + not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts + of a jurisdiction where the defendant maintains its principal place of + business and such litigation shall be governed by laws of that + jurisdiction, without reference to its conflict-of-law provisions. Nothing + in this Section shall prevent a party's ability to bring cross-claims or + counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject + matter hereof. If any provision of this License is held to be + unenforceable, such provision shall be reformed only to the extent + necessary to make it enforceable. Any law or regulation which provides that + the language of a contract shall be construed against the drafter shall not + be used to construe this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version + of the License under which You originally received the Covered Software, + or under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a + modified version of this License if you rename the license and remove + any references to the name of the license steward (except to note that + such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary + Licenses If You choose to distribute Source Code Form that is + Incompatible With Secondary Licenses under the terms of this version of + the License, the notice described in Exhibit B of this License must be + attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, +then You may include the notice in a location (such as a LICENSE file in a +relevant directory) where a recipient would be likely to look for such a +notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice + + This Source Code Form is "Incompatible + With Secondary Licenses", as defined by + the Mozilla Public License, v. 2.0. + diff --git a/vendor/github.com/hashicorp/go-memdb/README.md b/vendor/github.com/hashicorp/go-memdb/README.md new file mode 100644 index 0000000000..65e1eaefe8 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/README.md @@ -0,0 +1,98 @@ +# go-memdb + +Provides the `memdb` package that implements a simple in-memory database +built on immutable radix trees. The database provides Atomicity, Consistency +and Isolation from ACID. Being that it is in-memory, it does not provide durability. +The database is instantiated with a schema that specifies the tables and indices +that exist and allows transactions to be executed. + +The database provides the following: + +* Multi-Version Concurrency Control (MVCC) - By leveraging immutable radix trees + the database is able to support any number of concurrent readers without locking, + and allows a writer to make progress. + +* Transaction Support - The database allows for rich transactions, in which multiple + objects are inserted, updated or deleted. The transactions can span multiple tables, + and are applied atomically. The database provides atomicity and isolation in ACID + terminology, such that until commit the updates are not visible. + +* Rich Indexing - Tables can support any number of indexes, which can be simple like + a single field index, or more advanced compound field indexes. Certain types like + UUID can be efficiently compressed from strings into byte indexes for reduced + storage requirements. + +* Watches - Callers can populate a watch set as part of a query, which can be used to + detect when a modification has been made to the database which affects the query + results. This lets callers easily watch for changes in the database in a very general + way. + +For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix). + +Documentation +============= + +The full documentation is available on [Godoc](http://godoc.org/github.com/hashicorp/go-memdb). + +Example +======= + +Below is a simple example of usage + +```go +// Create a sample struct +type Person struct { + Email string + Name string + Age int +} + +// Create the DB schema +schema := &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + "person": &memdb.TableSchema{ + Name: "person", + Indexes: map[string]*memdb.IndexSchema{ + "id": &memdb.IndexSchema{ + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Email"}, + }, + }, + }, + }, +} + +// Create a new data base +db, err := memdb.NewMemDB(schema) +if err != nil { + panic(err) +} + +// Create a write transaction +txn := db.Txn(true) + +// Insert a new person +p := &Person{"joe@aol.com", "Joe", 30} +if err := txn.Insert("person", p); err != nil { + panic(err) +} + +// Commit the transaction +txn.Commit() + +// Create read-only transaction +txn = db.Txn(false) +defer txn.Abort() + +// Lookup by email +raw, err := txn.First("person", "id", "joe@aol.com") +if err != nil { + panic(err) +} + +// Say hi! +fmt.Printf("Hello %s!", raw.(*Person).Name) + +``` + diff --git a/vendor/github.com/hashicorp/go-memdb/filter.go b/vendor/github.com/hashicorp/go-memdb/filter.go new file mode 100644 index 0000000000..2e3a9b3f7b --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/filter.go @@ -0,0 +1,33 @@ +package memdb + +// FilterFunc is a function that takes the results of an iterator and returns +// whether the result should be filtered out. +type FilterFunc func(interface{}) bool + +// FilterIterator is used to wrap a ResultIterator and apply a filter over it. +type FilterIterator struct { + // filter is the filter function applied over the base iterator. + filter FilterFunc + + // iter is the iterator that is being wrapped. + iter ResultIterator +} + +func NewFilterIterator(wrap ResultIterator, filter FilterFunc) *FilterIterator { + return &FilterIterator{ + filter: filter, + iter: wrap, + } +} + +// WatchCh returns the watch channel of the wrapped iterator. +func (f *FilterIterator) WatchCh() <-chan struct{} { return f.iter.WatchCh() } + +// Next returns the next non-filtered result from the wrapped iterator +func (f *FilterIterator) Next() interface{} { + for { + if value := f.iter.Next(); value == nil || !f.filter(value) { + return value + } + } +} diff --git a/vendor/github.com/hashicorp/go-memdb/index.go b/vendor/github.com/hashicorp/go-memdb/index.go new file mode 100644 index 0000000000..d1fb951466 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/index.go @@ -0,0 +1,569 @@ +package memdb + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "reflect" + "strings" +) + +// Indexer is an interface used for defining indexes +type Indexer interface { + // ExactFromArgs is used to build an exact index lookup + // based on arguments + FromArgs(args ...interface{}) ([]byte, error) +} + +// SingleIndexer is an interface used for defining indexes +// generating a single entry per object +type SingleIndexer interface { + // FromObject is used to extract an index value from an + // object or to indicate that the index value is missing. + FromObject(raw interface{}) (bool, []byte, error) +} + +// MultiIndexer is an interface used for defining indexes +// generating multiple entries per object +type MultiIndexer interface { + // FromObject is used to extract index values from an + // object or to indicate that the index value is missing. + FromObject(raw interface{}) (bool, [][]byte, error) +} + +// PrefixIndexer can optionally be implemented for any +// indexes that support prefix based iteration. This may +// not apply to all indexes. +type PrefixIndexer interface { + // PrefixFromArgs returns a prefix that should be used + // for scanning based on the arguments + PrefixFromArgs(args ...interface{}) ([]byte, error) +} + +// StringFieldIndex is used to extract a field from an object +// using reflection and builds an index on that field. +type StringFieldIndex struct { + Field string + Lowercase bool +} + +func (s *StringFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(s.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) + } + + val := fv.String() + if val == "" { + return false, nil, nil + } + + if s.Lowercase { + val = strings.ToLower(val) + } + + // Add the null character as a terminator + val += "\x00" + return true, []byte(val), nil +} + +func (s *StringFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + if s.Lowercase { + arg = strings.ToLower(arg) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *StringFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +// StringSliceFieldIndex is used to extract a field from an object +// using reflection and builds an index on that field. +type StringSliceFieldIndex struct { + Field string + Lowercase bool +} + +func (s *StringSliceFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(s.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) + } + + if fv.Kind() != reflect.Slice || fv.Type().Elem().Kind() != reflect.String { + return false, nil, fmt.Errorf("field '%s' is not a string slice", s.Field) + } + + length := fv.Len() + vals := make([][]byte, 0, length) + for i := 0; i < fv.Len(); i++ { + val := fv.Index(i).String() + if val == "" { + continue + } + + if s.Lowercase { + val = strings.ToLower(val) + } + + // Add the null character as a terminator + val += "\x00" + vals = append(vals, []byte(val)) + } + if len(vals) == 0 { + return false, nil, nil + } + return true, vals, nil +} + +func (s *StringSliceFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + if s.Lowercase { + arg = strings.ToLower(arg) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *StringSliceFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +// StringMapFieldIndex is used to extract a field of type map[string]string +// from an object using reflection and builds an index on that field. +type StringMapFieldIndex struct { + Field string + Lowercase bool +} + +var MapType = reflect.MapOf(reflect.TypeOf(""), reflect.TypeOf("")).Kind() + +func (s *StringMapFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(s.Field) + if !fv.IsValid() { + return false, nil, fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) + } + + if fv.Kind() != MapType { + return false, nil, fmt.Errorf("field '%s' is not a map[string]string", s.Field) + } + + length := fv.Len() + vals := make([][]byte, 0, length) + for _, key := range fv.MapKeys() { + k := key.String() + if k == "" { + continue + } + val := fv.MapIndex(key).String() + + if s.Lowercase { + k = strings.ToLower(k) + val = strings.ToLower(val) + } + + // Add the null character as a terminator + k += "\x00" + val + "\x00" + + vals = append(vals, []byte(k)) + } + if len(vals) == 0 { + return false, nil, nil + } + return true, vals, nil +} + +func (s *StringMapFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) > 2 || len(args) == 0 { + return nil, fmt.Errorf("must provide one or two arguments") + } + key, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + if s.Lowercase { + key = strings.ToLower(key) + } + // Add the null character as a terminator + key += "\x00" + + if len(args) == 2 { + val, ok := args[1].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[1]) + } + if s.Lowercase { + val = strings.ToLower(val) + } + // Add the null character as a terminator + key += val + "\x00" + } + + return []byte(key), nil +} + +// UintFieldIndex is used to extract a uint field from an object using +// reflection and builds an index on that field. +type UintFieldIndex struct { + Field string +} + +func (u *UintFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(u.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj) + } + + // Check the type + k := fv.Kind() + size, ok := IsUintType(k) + if !ok { + return false, nil, fmt.Errorf("field %q is of type %v; want a uint", u.Field, k) + } + + // Get the value and encode it + val := fv.Uint() + buf := make([]byte, size) + binary.PutUvarint(buf, val) + + return true, buf, nil +} + +func (u *UintFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + + v := reflect.ValueOf(args[0]) + if !v.IsValid() { + return nil, fmt.Errorf("%#v is invalid", args[0]) + } + + k := v.Kind() + size, ok := IsUintType(k) + if !ok { + return nil, fmt.Errorf("arg is of type %v; want a uint", k) + } + + val := v.Uint() + buf := make([]byte, size) + binary.PutUvarint(buf, val) + + return buf, nil +} + +// IsUintType returns whether the passed type is a type of uint and the number +// of bytes needed to encode the type. +func IsUintType(k reflect.Kind) (size int, okay bool) { + switch k { + case reflect.Uint: + return binary.MaxVarintLen64, true + case reflect.Uint8: + return 2, true + case reflect.Uint16: + return binary.MaxVarintLen16, true + case reflect.Uint32: + return binary.MaxVarintLen32, true + case reflect.Uint64: + return binary.MaxVarintLen64, true + default: + return 0, false + } +} + +// UUIDFieldIndex is used to extract a field from an object +// using reflection and builds an index on that field by treating +// it as a UUID. This is an optimization to using a StringFieldIndex +// as the UUID can be more compactly represented in byte form. +type UUIDFieldIndex struct { + Field string +} + +func (u *UUIDFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(u.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj) + } + + val := fv.String() + if val == "" { + return false, nil, nil + } + + buf, err := u.parseString(val, true) + return true, buf, err +} + +func (u *UUIDFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + switch arg := args[0].(type) { + case string: + return u.parseString(arg, true) + case []byte: + if len(arg) != 16 { + return nil, fmt.Errorf("byte slice must be 16 characters") + } + return arg, nil + default: + return nil, + fmt.Errorf("argument must be a string or byte slice: %#v", args[0]) + } +} + +func (u *UUIDFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + switch arg := args[0].(type) { + case string: + return u.parseString(arg, false) + case []byte: + return arg, nil + default: + return nil, + fmt.Errorf("argument must be a string or byte slice: %#v", args[0]) + } +} + +// parseString parses a UUID from the string. If enforceLength is false, it will +// parse a partial UUID. An error is returned if the input, stripped of hyphens, +// is not even length. +func (u *UUIDFieldIndex) parseString(s string, enforceLength bool) ([]byte, error) { + // Verify the length + l := len(s) + if enforceLength && l != 36 { + return nil, fmt.Errorf("UUID must be 36 characters") + } else if l > 36 { + return nil, fmt.Errorf("Invalid UUID length. UUID have 36 characters; got %d", l) + } + + hyphens := strings.Count(s, "-") + if hyphens > 4 { + return nil, fmt.Errorf(`UUID should have maximum of 4 "-"; got %d`, hyphens) + } + + // The sanitized length is the length of the original string without the "-". + sanitized := strings.Replace(s, "-", "", -1) + sanitizedLength := len(sanitized) + if sanitizedLength%2 != 0 { + return nil, fmt.Errorf("Input (without hyphens) must be even length") + } + + dec, err := hex.DecodeString(sanitized) + if err != nil { + return nil, fmt.Errorf("Invalid UUID: %v", err) + } + + return dec, nil +} + +// FieldSetIndex is used to extract a field from an object using reflection and +// builds an index on whether the field is set by comparing it against its +// type's nil value. +type FieldSetIndex struct { + Field string +} + +func (f *FieldSetIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(f.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", f.Field, obj) + } + + if fv.Interface() == reflect.Zero(fv.Type()).Interface() { + return true, []byte{0}, nil + } + + return true, []byte{1}, nil +} + +func (f *FieldSetIndex) FromArgs(args ...interface{}) ([]byte, error) { + return fromBoolArgs(args) +} + +// ConditionalIndex builds an index based on a condition specified by a passed +// user function. This function may examine the passed object and return a +// boolean to encapsulate an arbitrarily complex conditional. +type ConditionalIndex struct { + Conditional ConditionalIndexFunc +} + +// ConditionalIndexFunc is the required function interface for a +// ConditionalIndex. +type ConditionalIndexFunc func(obj interface{}) (bool, error) + +func (c *ConditionalIndex) FromObject(obj interface{}) (bool, []byte, error) { + // Call the user's function + res, err := c.Conditional(obj) + if err != nil { + return false, nil, fmt.Errorf("ConditionalIndexFunc(%#v) failed: %v", obj, err) + } + + if res { + return true, []byte{1}, nil + } + + return true, []byte{0}, nil +} + +func (c *ConditionalIndex) FromArgs(args ...interface{}) ([]byte, error) { + return fromBoolArgs(args) +} + +// fromBoolArgs is a helper that expects only a single boolean argument and +// returns a single length byte array containing either a one or zero depending +// on whether the passed input is true or false respectively. +func fromBoolArgs(args []interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + + if val, ok := args[0].(bool); !ok { + return nil, fmt.Errorf("argument must be a boolean type: %#v", args[0]) + } else if val { + return []byte{1}, nil + } + + return []byte{0}, nil +} + +// CompoundIndex is used to build an index using multiple sub-indexes +// Prefix based iteration is supported as long as the appropriate prefix +// of indexers support it. All sub-indexers are only assumed to expect +// a single argument. +type CompoundIndex struct { + Indexes []Indexer + + // AllowMissing results in an index based on only the indexers + // that return data. If true, you may end up with 2/3 columns + // indexed which might be useful for an index scan. Otherwise, + // the CompoundIndex requires all indexers to be satisfied. + AllowMissing bool +} + +func (c *CompoundIndex) FromObject(raw interface{}) (bool, []byte, error) { + var out []byte + for i, idxRaw := range c.Indexes { + idx, ok := idxRaw.(SingleIndexer) + if !ok { + return false, nil, fmt.Errorf("sub-index %d error: %s", i, "sub-index must be a SingleIndexer") + } + ok, val, err := idx.FromObject(raw) + if err != nil { + return false, nil, fmt.Errorf("sub-index %d error: %v", i, err) + } + if !ok { + if c.AllowMissing { + break + } else { + return false, nil, nil + } + } + out = append(out, val...) + } + return true, out, nil +} + +func (c *CompoundIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != len(c.Indexes) { + return nil, fmt.Errorf("less arguments than index fields") + } + var out []byte + for i, arg := range args { + val, err := c.Indexes[i].FromArgs(arg) + if err != nil { + return nil, fmt.Errorf("sub-index %d error: %v", i, err) + } + out = append(out, val...) + } + return out, nil +} + +func (c *CompoundIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + if len(args) > len(c.Indexes) { + return nil, fmt.Errorf("more arguments than index fields") + } + var out []byte + for i, arg := range args { + if i+1 < len(args) { + val, err := c.Indexes[i].FromArgs(arg) + if err != nil { + return nil, fmt.Errorf("sub-index %d error: %v", i, err) + } + out = append(out, val...) + } else { + prefixIndexer, ok := c.Indexes[i].(PrefixIndexer) + if !ok { + return nil, fmt.Errorf("sub-index %d does not support prefix scanning", i) + } + val, err := prefixIndexer.PrefixFromArgs(arg) + if err != nil { + return nil, fmt.Errorf("sub-index %d error: %v", i, err) + } + out = append(out, val...) + } + } + return out, nil +} diff --git a/vendor/github.com/hashicorp/go-memdb/memdb.go b/vendor/github.com/hashicorp/go-memdb/memdb.go new file mode 100644 index 0000000000..9e9b98df50 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/memdb.go @@ -0,0 +1,92 @@ +package memdb + +import ( + "sync" + "sync/atomic" + "unsafe" + + "github.com/hashicorp/go-immutable-radix" +) + +// MemDB is an in-memory database. It provides a table abstraction, +// which is used to store objects (rows) with multiple indexes based +// on values. The database makes use of immutable radix trees to provide +// transactions and MVCC. +type MemDB struct { + schema *DBSchema + root unsafe.Pointer // *iradix.Tree underneath + primary bool + + // There can only be a single writter at once + writer sync.Mutex +} + +// NewMemDB creates a new MemDB with the given schema +func NewMemDB(schema *DBSchema) (*MemDB, error) { + // Validate the schema + if err := schema.Validate(); err != nil { + return nil, err + } + + // Create the MemDB + db := &MemDB{ + schema: schema, + root: unsafe.Pointer(iradix.New()), + primary: true, + } + if err := db.initialize(); err != nil { + return nil, err + } + return db, nil +} + +// getRoot is used to do an atomic load of the root pointer +func (db *MemDB) getRoot() *iradix.Tree { + root := (*iradix.Tree)(atomic.LoadPointer(&db.root)) + return root +} + +// Txn is used to start a new transaction, in either read or write mode. +// There can only be a single concurrent writer, but any number of readers. +func (db *MemDB) Txn(write bool) *Txn { + if write { + db.writer.Lock() + } + txn := &Txn{ + db: db, + write: write, + rootTxn: db.getRoot().Txn(), + } + return txn +} + +// Snapshot is used to capture a point-in-time snapshot +// of the database that will not be affected by any write +// operations to the existing DB. +func (db *MemDB) Snapshot() *MemDB { + clone := &MemDB{ + schema: db.schema, + root: unsafe.Pointer(db.getRoot()), + primary: false, + } + return clone +} + +// initialize is used to setup the DB for use after creation +func (db *MemDB) initialize() error { + root := db.getRoot() + for tName, tableSchema := range db.schema.Tables { + for iName := range tableSchema.Indexes { + index := iradix.New() + path := indexPath(tName, iName) + root, _, _ = root.Insert(path, index) + } + } + db.root = unsafe.Pointer(root) + return nil +} + +// indexPath returns the path from the root to the given table index +func indexPath(table, index string) []byte { + return []byte(table + "." + index) +} diff --git a/vendor/github.com/hashicorp/go-memdb/schema.go b/vendor/github.com/hashicorp/go-memdb/schema.go new file mode 100644 index 0000000000..d7210f91cd --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/schema.go @@ -0,0 +1,85 @@ +package memdb + +import "fmt" + +// DBSchema contains the full database schema used for MemDB +type DBSchema struct { + Tables map[string]*TableSchema +} + +// Validate is used to validate the database schema +func (s *DBSchema) Validate() error { + if s == nil { + return fmt.Errorf("missing schema") + } + if len(s.Tables) == 0 { + return fmt.Errorf("no tables defined") + } + for name, table := range s.Tables { + if name != table.Name { + return fmt.Errorf("table name mis-match for '%s'", name) + } + if err := table.Validate(); err != nil { + return err + } + } + return nil +} + +// TableSchema contains the schema for a single table +type TableSchema struct { + Name string + Indexes map[string]*IndexSchema +} + +// Validate is used to validate the table schema +func (s *TableSchema) Validate() error { + if s.Name == "" { + return fmt.Errorf("missing table name") + } + if len(s.Indexes) == 0 { + return fmt.Errorf("missing table indexes for '%s'", s.Name) + } + if _, ok := s.Indexes["id"]; !ok { + return fmt.Errorf("must have id index") + } + if !s.Indexes["id"].Unique { + return fmt.Errorf("id index must be unique") + } + if _, ok := s.Indexes["id"].Indexer.(SingleIndexer); !ok { + return fmt.Errorf("id index must be a SingleIndexer") + } + for name, index := range s.Indexes { + if name != index.Name { + return fmt.Errorf("index name mis-match for '%s'", name) + } + if err := index.Validate(); err != nil { + return err + } + } + return nil +} + +// IndexSchema contains the schema for an index +type IndexSchema struct { + Name string + AllowMissing bool + Unique bool + Indexer Indexer +} + +func (s *IndexSchema) Validate() error { + if s.Name == "" { + return fmt.Errorf("missing index name") + } + if s.Indexer == nil { + return fmt.Errorf("missing index function for '%s'", s.Name) + } + switch s.Indexer.(type) { + case SingleIndexer: + case MultiIndexer: + default: + return fmt.Errorf("indexer for '%s' must be a SingleIndexer or MultiIndexer", s.Name) + } + return nil +} diff --git a/vendor/github.com/hashicorp/go-memdb/txn.go b/vendor/github.com/hashicorp/go-memdb/txn.go new file mode 100644 index 0000000000..2b85087ea3 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/txn.go @@ -0,0 +1,644 @@ +package memdb + +import ( + "bytes" + "fmt" + "strings" + "sync/atomic" + "unsafe" + + "github.com/hashicorp/go-immutable-radix" +) + +const ( + id = "id" +) + +var ( + // ErrNotFound is returned when the requested item is not found + ErrNotFound = fmt.Errorf("not found") +) + +// tableIndex is a tuple of (Table, Index) used for lookups +type tableIndex struct { + Table string + Index string +} + +// Txn is a transaction against a MemDB. +// This can be a read or write transaction. +type Txn struct { + db *MemDB + write bool + rootTxn *iradix.Txn + after []func() + + modified map[tableIndex]*iradix.Txn +} + +// readableIndex returns a transaction usable for reading the given +// index in a table. If a write transaction is in progress, we may need +// to use an existing modified txn. +func (txn *Txn) readableIndex(table, index string) *iradix.Txn { + // Look for existing transaction + if txn.write && txn.modified != nil { + key := tableIndex{table, index} + exist, ok := txn.modified[key] + if ok { + return exist + } + } + + // Create a read transaction + path := indexPath(table, index) + raw, _ := txn.rootTxn.Get(path) + indexTxn := raw.(*iradix.Tree).Txn() + return indexTxn +} + +// writableIndex returns a transaction usable for modifying the +// given index in a table. +func (txn *Txn) writableIndex(table, index string) *iradix.Txn { + if txn.modified == nil { + txn.modified = make(map[tableIndex]*iradix.Txn) + } + + // Look for existing transaction + key := tableIndex{table, index} + exist, ok := txn.modified[key] + if ok { + return exist + } + + // Start a new transaction + path := indexPath(table, index) + raw, _ := txn.rootTxn.Get(path) + indexTxn := raw.(*iradix.Tree).Txn() + + // If we are the primary DB, enable mutation tracking. Snapshots should + // not notify, otherwise we will trigger watches on the primary DB when + // the writes will not be visible. + indexTxn.TrackMutate(txn.db.primary) + + // Keep this open for the duration of the txn + txn.modified[key] = indexTxn + return indexTxn +} + +// Abort is used to cancel this transaction. +// This is a noop for read transactions. +func (txn *Txn) Abort() { + // Noop for a read transaction + if !txn.write { + return + } + + // Check if already aborted or committed + if txn.rootTxn == nil { + return + } + + // Clear the txn + txn.rootTxn = nil + txn.modified = nil + + // Release the writer lock since this is invalid + txn.db.writer.Unlock() +} + +// Commit is used to finalize this transaction. +// This is a noop for read transactions. +func (txn *Txn) Commit() { + // Noop for a read transaction + if !txn.write { + return + } + + // Check if already aborted or committed + if txn.rootTxn == nil { + return + } + + // Commit each sub-transaction scoped to (table, index) + for key, subTxn := range txn.modified { + path := indexPath(key.Table, key.Index) + final := subTxn.CommitOnly() + txn.rootTxn.Insert(path, final) + } + + // Update the root of the DB + newRoot := txn.rootTxn.CommitOnly() + atomic.StorePointer(&txn.db.root, unsafe.Pointer(newRoot)) + + // Now issue all of the mutation updates (this is safe to call + // even if mutation tracking isn't enabled); we do this after + // the root pointer is swapped so that waking responders will + // see the new state. + for _, subTxn := range txn.modified { + subTxn.Notify() + } + txn.rootTxn.Notify() + + // Clear the txn + txn.rootTxn = nil + txn.modified = nil + + // Release the writer lock since this is invalid + txn.db.writer.Unlock() + + // Run the deferred functions, if any + for i := len(txn.after); i > 0; i-- { + fn := txn.after[i-1] + fn() + } +} + +// Insert is used to add or update an object into the given table +func (txn *Txn) Insert(table string, obj interface{}) error { + if !txn.write { + return fmt.Errorf("cannot insert in read-only transaction") + } + + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return fmt.Errorf("invalid table '%s'", table) + } + + // Get the primary ID of the object + idSchema := tableSchema.Indexes[id] + idIndexer := idSchema.Indexer.(SingleIndexer) + ok, idVal, err := idIndexer.FromObject(obj) + if err != nil { + return fmt.Errorf("failed to build primary index: %v", err) + } + if !ok { + return fmt.Errorf("object missing primary index") + } + + // Lookup the object by ID first, to see if this is an update + idTxn := txn.writableIndex(table, id) + existing, update := idTxn.Get(idVal) + + // On an update, there is an existing object with the given + // primary ID. We do the update by deleting the current object + // and inserting the new object. + for name, indexSchema := range tableSchema.Indexes { + indexTxn := txn.writableIndex(table, name) + + // Determine the new index value + var ( + ok bool + vals [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var val []byte + ok, val, err = indexer.FromObject(obj) + vals = [][]byte{val} + case MultiIndexer: + ok, vals, err = indexer.FromObject(obj) + } + if err != nil { + return fmt.Errorf("failed to build index '%s': %v", name, err) + } + + // Handle non-unique index by computing a unique index. + // This is done by appending the primary key which must + // be unique anyways. + if ok && !indexSchema.Unique { + for i := range vals { + vals[i] = append(vals[i], idVal...) + } + } + + // Handle the update by deleting from the index first + if update { + var ( + okExist bool + valsExist [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var valExist []byte + okExist, valExist, err = indexer.FromObject(existing) + valsExist = [][]byte{valExist} + case MultiIndexer: + okExist, valsExist, err = indexer.FromObject(existing) + } + if err != nil { + return fmt.Errorf("failed to build index '%s': %v", name, err) + } + if okExist { + for i, valExist := range valsExist { + // Handle non-unique index by computing a unique index. + // This is done by appending the primary key which must + // be unique anyways. + if !indexSchema.Unique { + valExist = append(valExist, idVal...) + } + + // If we are writing to the same index with the same value, + // we can avoid the delete as the insert will overwrite the + // value anyways. + if i >= len(vals) || !bytes.Equal(valExist, vals[i]) { + indexTxn.Delete(valExist) + } + } + } + } + + // If there is no index value, either this is an error or an expected + // case and we can skip updating + if !ok { + if indexSchema.AllowMissing { + continue + } else { + return fmt.Errorf("missing value for index '%s'", name) + } + } + + // Update the value of the index + for _, val := range vals { + indexTxn.Insert(val, obj) + } + } + return nil +} + +// Delete is used to delete a single object from the given table +// This object must already exist in the table +func (txn *Txn) Delete(table string, obj interface{}) error { + if !txn.write { + return fmt.Errorf("cannot delete in read-only transaction") + } + + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return fmt.Errorf("invalid table '%s'", table) + } + + // Get the primary ID of the object + idSchema := tableSchema.Indexes[id] + idIndexer := idSchema.Indexer.(SingleIndexer) + ok, idVal, err := idIndexer.FromObject(obj) + if err != nil { + return fmt.Errorf("failed to build primary index: %v", err) + } + if !ok { + return fmt.Errorf("object missing primary index") + } + + // Lookup the object by ID first, check fi we should continue + idTxn := txn.writableIndex(table, id) + existing, ok := idTxn.Get(idVal) + if !ok { + return ErrNotFound + } + + // Remove the object from all the indexes + for name, indexSchema := range tableSchema.Indexes { + indexTxn := txn.writableIndex(table, name) + + // Handle the update by deleting from the index first + var ( + ok bool + vals [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var val []byte + ok, val, err = indexer.FromObject(existing) + vals = [][]byte{val} + case MultiIndexer: + ok, vals, err = indexer.FromObject(existing) + } + if err != nil { + return fmt.Errorf("failed to build index '%s': %v", name, err) + } + if ok { + // Handle non-unique index by computing a unique index. + // This is done by appending the primary key which must + // be unique anyways. + for _, val := range vals { + if !indexSchema.Unique { + val = append(val, idVal...) + } + indexTxn.Delete(val) + } + } + } + return nil +} + +// DeletePrefix is used to delete an entire subtree based on a prefix. +// The given index must be a prefix index, and will be used to perform a scan and enumerate the set of objects to delete. +// These will be removed from all other indexes, and then a special prefix operation will delete the objects from the given index in an efficient subtree delete operation. +// This is useful when you have a very large number of objects indexed by the given index, along with a much smaller number of entries in the other indexes for those objects. +func (txn *Txn) DeletePrefix(table string, prefix_index string, prefix string) (bool, error) { + if !txn.write { + return false, fmt.Errorf("cannot delete in read-only transaction") + } + + if !strings.HasSuffix(prefix_index, "_prefix") { + return false, fmt.Errorf("Index name for DeletePrefix must be a prefix index, Got %v ", prefix_index) + } + + deletePrefixIndex := strings.TrimSuffix(prefix_index, "_prefix") + + // Get an iterator over all of the keys with the given prefix. + entries, err := txn.Get(table, prefix_index, prefix) + if err != nil { + return false, fmt.Errorf("failed kvs lookup: %s", err) + } + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return false, fmt.Errorf("invalid table '%s'", table) + } + + foundAny := false + for entry := entries.Next(); entry != nil; entry = entries.Next() { + if !foundAny { + foundAny = true + } + // Get the primary ID of the object + idSchema := tableSchema.Indexes[id] + idIndexer := idSchema.Indexer.(SingleIndexer) + ok, idVal, err := idIndexer.FromObject(entry) + if err != nil { + return false, fmt.Errorf("failed to build primary index: %v", err) + } + if !ok { + return false, fmt.Errorf("object missing primary index") + } + // Remove the object from all the indexes except the given prefix index + for name, indexSchema := range tableSchema.Indexes { + if name == deletePrefixIndex { + continue + } + indexTxn := txn.writableIndex(table, name) + + // Handle the update by deleting from the index first + var ( + ok bool + vals [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var val []byte + ok, val, err = indexer.FromObject(entry) + vals = [][]byte{val} + case MultiIndexer: + ok, vals, err = indexer.FromObject(entry) + } + if err != nil { + return false, fmt.Errorf("failed to build index '%s': %v", name, err) + } + + if ok { + // Handle non-unique index by computing a unique index. + // This is done by appending the primary key which must + // be unique anyways. + for _, val := range vals { + if !indexSchema.Unique { + val = append(val, idVal...) + } + indexTxn.Delete(val) + } + } + } + } + if foundAny { + indexTxn := txn.writableIndex(table, deletePrefixIndex) + ok = indexTxn.DeletePrefix([]byte(prefix)) + if !ok { + panic(fmt.Errorf("prefix %v matched some entries but DeletePrefix did not delete any ", prefix)) + } + return true, nil + } + return false, nil +} + +// DeleteAll is used to delete all the objects in a given table +// matching the constraints on the index +func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) { + if !txn.write { + return 0, fmt.Errorf("cannot delete in read-only transaction") + } + + // Get all the objects + iter, err := txn.Get(table, index, args...) + if err != nil { + return 0, err + } + + // Put them into a slice so there are no safety concerns while actually + // performing the deletes + var objs []interface{} + for { + obj := iter.Next() + if obj == nil { + break + } + + objs = append(objs, obj) + } + + // Do the deletes + num := 0 + for _, obj := range objs { + if err := txn.Delete(table, obj); err != nil { + return num, err + } + num++ + } + return num, nil +} + +// FirstWatch is used to return the first matching object for +// the given constraints on the index along with the watch channel +func (txn *Txn) FirstWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) { + // Get the index value + indexSchema, val, err := txn.getIndexValue(table, index, args...) + if err != nil { + return nil, nil, err + } + + // Get the index itself + indexTxn := txn.readableIndex(table, indexSchema.Name) + + // Do an exact lookup + if indexSchema.Unique && val != nil && indexSchema.Name == index { + watch, obj, ok := indexTxn.GetWatch(val) + if !ok { + return watch, nil, nil + } + return watch, obj, nil + } + + // Handle non-unique index by using an iterator and getting the first value + iter := indexTxn.Root().Iterator() + watch := iter.SeekPrefixWatch(val) + _, value, _ := iter.Next() + return watch, value, nil +} + +// First is used to return the first matching object for +// the given constraints on the index +func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { + _, val, err := txn.FirstWatch(table, index, args...) + return val, err +} + +// LongestPrefix is used to fetch the longest prefix match for the given +// constraints on the index. Note that this will not work with the memdb +// StringFieldIndex because it adds null terminators which prevent the +// algorithm from correctly finding a match (it will get to right before the +// null and fail to find a leaf node). This should only be used where the prefix +// given is capable of matching indexed entries directly, which typically only +// applies to a custom indexer. See the unit test for an example. +func (txn *Txn) LongestPrefix(table, index string, args ...interface{}) (interface{}, error) { + // Enforce that this only works on prefix indexes. + if !strings.HasSuffix(index, "_prefix") { + return nil, fmt.Errorf("must use '%s_prefix' on index", index) + } + + // Get the index value. + indexSchema, val, err := txn.getIndexValue(table, index, args...) + if err != nil { + return nil, err + } + + // This algorithm only makes sense against a unique index, otherwise the + // index keys will have the IDs appended to them. + if !indexSchema.Unique { + return nil, fmt.Errorf("index '%s' is not unique", index) + } + + // Find the longest prefix match with the given index. + indexTxn := txn.readableIndex(table, indexSchema.Name) + if _, value, ok := indexTxn.Root().LongestPrefix(val); ok { + return value, nil + } + return nil, nil +} + +// getIndexValue is used to get the IndexSchema and the value +// used to scan the index given the parameters. This handles prefix based +// scans when the index has the "_prefix" suffix. The index must support +// prefix iteration. +func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexSchema, []byte, error) { + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return nil, nil, fmt.Errorf("invalid table '%s'", table) + } + + // Check for a prefix scan + prefixScan := false + if strings.HasSuffix(index, "_prefix") { + index = strings.TrimSuffix(index, "_prefix") + prefixScan = true + } + + // Get the index schema + indexSchema, ok := tableSchema.Indexes[index] + if !ok { + return nil, nil, fmt.Errorf("invalid index '%s'", index) + } + + // Hot-path for when there are no arguments + if len(args) == 0 { + return indexSchema, nil, nil + } + + // Special case the prefix scanning + if prefixScan { + prefixIndexer, ok := indexSchema.Indexer.(PrefixIndexer) + if !ok { + return indexSchema, nil, + fmt.Errorf("index '%s' does not support prefix scanning", index) + } + + val, err := prefixIndexer.PrefixFromArgs(args...) + if err != nil { + return indexSchema, nil, fmt.Errorf("index error: %v", err) + } + return indexSchema, val, err + } + + // Get the exact match index + val, err := indexSchema.Indexer.FromArgs(args...) + if err != nil { + return indexSchema, nil, fmt.Errorf("index error: %v", err) + } + return indexSchema, val, err +} + +// ResultIterator is used to iterate over a list of results +// from a Get query on a table. +type ResultIterator interface { + WatchCh() <-chan struct{} + Next() interface{} +} + +// Get is used to construct a ResultIterator over all the +// rows that match the given constraints of an index. +func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) { + // Get the index value to scan + indexSchema, val, err := txn.getIndexValue(table, index, args...) + if err != nil { + return nil, err + } + + // Get the index itself + indexTxn := txn.readableIndex(table, indexSchema.Name) + indexRoot := indexTxn.Root() + + // Get an interator over the index + indexIter := indexRoot.Iterator() + + // Seek the iterator to the appropriate sub-set + watchCh := indexIter.SeekPrefixWatch(val) + + // Create an iterator + iter := &radixIterator{ + iter: indexIter, + watchCh: watchCh, + } + return iter, nil +} + +// Defer is used to push a new arbitrary function onto a stack which +// gets called when a transaction is committed and finished. Deferred +// functions are called in LIFO order, and only invoked at the end of +// write transactions. +func (txn *Txn) Defer(fn func()) { + txn.after = append(txn.after, fn) +} + +// radixIterator is used to wrap an underlying iradix iterator. +// This is much more efficient than a sliceIterator as we are not +// materializing the entire view. +type radixIterator struct { + iter *iradix.Iterator + watchCh <-chan struct{} +} + +func (r *radixIterator) WatchCh() <-chan struct{} { + return r.watchCh +} + +func (r *radixIterator) Next() interface{} { + _, value, ok := r.iter.Next() + if !ok { + return nil + } + return value +} diff --git a/vendor/github.com/hashicorp/go-memdb/watch.go b/vendor/github.com/hashicorp/go-memdb/watch.go new file mode 100644 index 0000000000..a6f01213be --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/watch.go @@ -0,0 +1,129 @@ +package memdb + +import ( + "context" + "time" +) + +// WatchSet is a collection of watch channels. +type WatchSet map[<-chan struct{}]struct{} + +// NewWatchSet constructs a new watch set. +func NewWatchSet() WatchSet { + return make(map[<-chan struct{}]struct{}) +} + +// Add appends a watchCh to the WatchSet if non-nil. +func (w WatchSet) Add(watchCh <-chan struct{}) { + if w == nil { + return + } + + if _, ok := w[watchCh]; !ok { + w[watchCh] = struct{}{} + } +} + +// AddWithLimit appends a watchCh to the WatchSet if non-nil, and if the given +// softLimit hasn't been exceeded. Otherwise, it will watch the given alternate +// channel. It's expected that the altCh will be the same on many calls to this +// function, so you will exceed the soft limit a little bit if you hit this, but +// not by much. +// +// This is useful if you want to track individual items up to some limit, after +// which you watch a higher-level channel (usually a channel from start start of +// an iterator higher up in the radix tree) that will watch a superset of items. +func (w WatchSet) AddWithLimit(softLimit int, watchCh <-chan struct{}, altCh <-chan struct{}) { + // This is safe for a nil WatchSet so we don't need to check that here. + if len(w) < softLimit { + w.Add(watchCh) + } else { + w.Add(altCh) + } +} + +// Watch is used to wait for either the watch set to trigger or a timeout. +// Returns true on timeout. +func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool { + if w == nil { + return false + } + + // Create a context that gets cancelled when the timeout is triggered + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + select { + case <-timeoutCh: + cancel() + case <-ctx.Done(): + } + }() + + return w.WatchCtx(ctx) == context.Canceled +} + +// WatchCtx is used to wait for either the watch set to trigger or for the +// context to be cancelled. Watch with a timeout channel can be mimicked by +// creating a context with a deadline. WatchCtx should be preferred over Watch. +func (w WatchSet) WatchCtx(ctx context.Context) error { + if w == nil { + return nil + } + + if n := len(w); n <= aFew { + idx := 0 + chunk := make([]<-chan struct{}, aFew) + for watchCh := range w { + chunk[idx] = watchCh + idx++ + } + return watchFew(ctx, chunk) + } + + return w.watchMany(ctx) +} + +// watchMany is used if there are many watchers. +func (w WatchSet) watchMany(ctx context.Context) error { + // Set up a goroutine for each watcher. + triggerCh := make(chan struct{}, 1) + watcher := func(chunk []<-chan struct{}) { + if err := watchFew(ctx, chunk); err == nil { + select { + case triggerCh <- struct{}{}: + default: + } + } + } + + // Apportion the watch channels into chunks we can feed into the + // watchFew helper. + idx := 0 + chunk := make([]<-chan struct{}, aFew) + for watchCh := range w { + subIdx := idx % aFew + chunk[subIdx] = watchCh + idx++ + + // Fire off this chunk and start a fresh one. + if idx%aFew == 0 { + go watcher(chunk) + chunk = make([]<-chan struct{}, aFew) + } + } + + // Make sure to watch any residual channels in the last chunk. + if idx%aFew != 0 { + go watcher(chunk) + } + + // Wait for a channel to trigger or timeout. + select { + case <-triggerCh: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/vendor/github.com/hashicorp/go-memdb/watch_few.go b/vendor/github.com/hashicorp/go-memdb/watch_few.go new file mode 100644 index 0000000000..880f098b77 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/watch_few.go @@ -0,0 +1,117 @@ +package memdb + +//go:generate sh -c "go run watch-gen/main.go >watch_few.go" + +import( + "context" +) + +// aFew gives how many watchers this function is wired to support. You must +// always pass a full slice of this length, but unused channels can be nil. +const aFew = 32 + +// watchFew is used if there are only a few watchers as a performance +// optimization. +func watchFew(ctx context.Context, ch []<-chan struct{}) error { + select { + + case <-ch[0]: + return nil + + case <-ch[1]: + return nil + + case <-ch[2]: + return nil + + case <-ch[3]: + return nil + + case <-ch[4]: + return nil + + case <-ch[5]: + return nil + + case <-ch[6]: + return nil + + case <-ch[7]: + return nil + + case <-ch[8]: + return nil + + case <-ch[9]: + return nil + + case <-ch[10]: + return nil + + case <-ch[11]: + return nil + + case <-ch[12]: + return nil + + case <-ch[13]: + return nil + + case <-ch[14]: + return nil + + case <-ch[15]: + return nil + + case <-ch[16]: + return nil + + case <-ch[17]: + return nil + + case <-ch[18]: + return nil + + case <-ch[19]: + return nil + + case <-ch[20]: + return nil + + case <-ch[21]: + return nil + + case <-ch[22]: + return nil + + case <-ch[23]: + return nil + + case <-ch[24]: + return nil + + case <-ch[25]: + return nil + + case <-ch[26]: + return nil + + case <-ch[27]: + return nil + + case <-ch[28]: + return nil + + case <-ch[29]: + return nil + + case <-ch[30]: + return nil + + case <-ch[31]: + return nil + + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json index b4820b5718..4f387ca275 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -1032,6 +1032,12 @@ "revision": "8aac2701530899b64bdea735a1de8da899815220", "revisionTime": "2017-07-25T22:12:15Z" }, + { + "checksumSHA1": "2JVfMLNCW8hfVlPAwAHlOX4HW2s=", + "path": "github.com/hashicorp/go-memdb", + "revision": "75ff99613d288868d8888ec87594525906815dbc", + "revisionTime": "2017-10-05T03:07:53Z" + }, { "checksumSHA1": "g7uHECbzuaWwdxvwoyxBwgeERPk=", "path": "github.com/hashicorp/go-multierror",