Files
vault/builtin/logical/pki/issuing/keys.go
Steven Clark cbf6dc2c4f PKI refactoring to start breaking apart monolith into sub-packages (#24406)
* PKI refactoring to start breaking apart monolith into sub-packages

 - This was broken down by commit within enterprise for ease of review
   but it would be too difficult to bring individual commits back
   to the CE repository. (they would be squashed anyways)
 - This change was created by exporting a patch of the enterprise PR
   and applying it to CE repository

* Fix TestBackend_OID_SANs to not rely on map ordering
2023-12-07 09:22:53 -05:00

154 lines
3.8 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
// KeyID is the storage identifier of a PKI key entry.
type KeyID string

// String returns the identifier as a plain string.
func (id KeyID) String() string { return string(id) }

const (
	// KeyPrefix is the storage path prefix under which key entries are kept;
	// an entry's full path is KeyPrefix followed by its KeyID.
	KeyPrefix = "config/key/"

	// KeyRefNotFound is the KeyID returned alongside an error when a key
	// reference could not be resolved.
	KeyRefNotFound KeyID = "not-found"
)
// KeyEntry is the JSON-encoded record stored at KeyPrefix+ID for each PKI key.
//
// NOTE: the JSON tags (and field order, which fixes encoding order) define the
// persisted storage format; changing them affects existing stored entries.
type KeyEntry struct {
	// ID is the key's unique identifier and the suffix of its storage path.
	ID KeyID `json:"id"`
	// Name is a key name; ResolveKeyReference matches references against it.
	Name string `json:"name"`
	// PrivateKeyType identifies the kind of key; certutil.ManagedPrivateKey
	// marks a managed key (see IsManagedPrivateKey).
	PrivateKeyType certutil.PrivateKeyType `json:"private_key_type"`
	// PrivateKey holds the key payload as a string. For managed keys it is
	// passed to managed_key.ExtractManagedKeyId; for other key types it is
	// presumably the encoded private key itself — TODO confirm encoding.
	PrivateKey string `json:"private_key"`
}
// IsManagedPrivateKey reports whether the entry's PrivateKeyType is
// certutil.ManagedPrivateKey.
func (e KeyEntry) IsManagedPrivateKey() bool {
	managedType := certutil.ManagedPrivateKey
	return managedType == e.PrivateKeyType
}
// ListKeys returns the identifiers of all key entries listed under KeyPrefix
// in storage. On success the returned slice is non-nil (empty when there are
// no keys); on a storage error it returns nil and that error unchanged.
func ListKeys(ctx context.Context, s logical.Storage) ([]KeyID, error) {
	names, err := s.List(ctx, KeyPrefix)
	if err != nil {
		return nil, err
	}

	ids := make([]KeyID, len(names))
	for i, name := range names {
		ids[i] = KeyID(name)
	}
	return ids, nil
}
// FetchKeyById loads and decodes the key entry stored at KeyPrefix+keyId.
//
// Errors:
//   - errutil.InternalError if keyId is empty, the storage read fails, or the
//     stored JSON cannot be decoded;
//   - errutil.UserError if no entry exists for keyId.
func FetchKeyById(ctx context.Context, s logical.Storage, keyId KeyID) (*KeyEntry, error) {
	if keyId == "" {
		return nil, errutil.InternalError{Err: "unable to fetch pki key: empty key identifier"}
	}

	path := KeyPrefix + keyId.String()
	raw, err := s.Get(ctx, path)
	switch {
	case err != nil:
		return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki key: %v", err)}
	case raw == nil:
		// A missing entry is reported as a user error: the caller asked for
		// an ID that does not exist.
		return nil, errutil.UserError{Err: fmt.Sprintf("pki key id %s does not exist", keyId.String())}
	}

	key := new(KeyEntry)
	if err := raw.DecodeJSON(key); err != nil {
		return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki key with id %s: %v", keyId.String(), err)}
	}
	return key, nil
}
// WriteKey JSON-encodes key and stores it via s.Put at KeyPrefix followed by
// key.ID. Encoding and storage errors are returned unchanged.
func WriteKey(ctx context.Context, s logical.Storage, key KeyEntry) error {
	// Named "entry" rather than "json" to avoid shadowing the package name.
	entry, err := logical.StorageEntryJSON(KeyPrefix+key.ID.String(), key)
	if err != nil {
		return err
	}
	return s.Put(ctx, entry)
}
// DeleteKey removes the key entry for id from storage and reports whether
// that key was the configured default key.
//
// If it was the default, the keys config is first saved with an empty
// DefaultKeyId. If saving that config fails, the key entry is left in place
// and the error is returned. Any error loading the config aborts before
// anything is changed.
func DeleteKey(ctx context.Context, s logical.Storage, id KeyID) (bool, error) {
	cfg, err := GetKeysConfig(ctx, s)
	if err != nil {
		return false, err
	}

	wasDefault := cfg.DefaultKeyId == id
	if wasDefault {
		// Clear the default before deleting so the config never names a
		// key that has already been removed.
		cfg.DefaultKeyId = KeyID("")
		if err := SetKeysConfig(ctx, s, cfg); err != nil {
			return wasDefault, err
		}
	}

	return wasDefault, s.Delete(ctx, KeyPrefix+id.String())
}
// ResolveKeyReference maps a user-supplied key reference to a KeyID.
//
// A reference is resolved in this order:
//  1. DefaultRef resolves to the configured default key. If no default is
//     set, KeyRefNotFound is returned with an error.
//  2. A reference whose length equals uuidLength is tried directly as a key
//     ID via a single storage Get.
//  3. Otherwise every stored key is fetched and its Name compared with the
//     reference.
//
// If nothing matches, KeyRefNotFound is returned with an errutil.UserError.
// On storage or config failures, a placeholder ID ("config-error",
// "key-read", "list-error") is returned together with the error. Callers
// must check the error before using the ID.
func ResolveKeyReference(ctx context.Context, s logical.Storage, reference string) (KeyID, error) {
	if reference == DefaultRef {
		cfg, err := GetKeysConfig(ctx, s)
		if err != nil {
			return KeyID("config-error"), err
		}
		if cfg.DefaultKeyId == "" {
			return KeyRefNotFound, fmt.Errorf("no default key currently configured")
		}
		return cfg.DefaultKeyId, nil
	}

	// A UUID-length reference may already be a key ID. One direct Get is far
	// cheaper than listing and decoding every stored key below.
	if len(reference) == uuidLength {
		switch entry, err := s.Get(ctx, KeyPrefix+reference); {
		case err != nil:
			return KeyID("key-read"), err
		case entry != nil:
			return KeyID(reference), nil
		}
	}

	// Fall back to a linear scan, matching the reference against key names.
	ids, err := ListKeys(ctx, s)
	if err != nil {
		return KeyID("list-error"), err
	}
	for _, id := range ids {
		entry, err := FetchKeyById(ctx, s, id)
		if err != nil {
			return KeyID("key-read"), err
		}
		if entry.Name == reference {
			return entry.ID, nil
		}
	}

	return KeyRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)}
}
// GetManagedKeyUUID returns the managed-key UUID for a key entry whose
// PrivateKeyType is certutil.ManagedPrivateKey. For such entries, the
// PrivateKey field is handed to managed_key.ExtractManagedKeyId.
//
// It returns an errutil.InternalError when key is nil or is not a managed
// key. Both cases indicate a caller bug rather than bad user input.
func GetManagedKeyUUID(key *KeyEntry) (managed_key.UUIDKey, error) {
	if key == nil {
		return "", errutil.InternalError{Err: "GetManagedKeyUUID called with a nil key entry"}
	}
	if !key.IsManagedPrivateKey() {
		// Include the key's identity so the misuse can be traced to a
		// specific entry.
		return "", errutil.InternalError{Err: fmt.Sprintf("GetManagedKeyUUID called on key id %s (%q), which is not a managed key", key.ID, key.Name)}
	}
	return managed_key.ExtractManagedKeyId([]byte(key.PrivateKey))
}