PKI refactoring to start breaking apart monolith into sub-packages (#24406)

* PKI refactoring to start breaking apart monolith into sub-packages

 - This was broken down by commit within enterprise for ease of review
   but would be too difficult to bring back individual commits back
   to the CE repository. (they would be squashed anyways)
 - This change was created by exporting a patch of the enterprise PR
   and applying it to CE repository

* Fix TestBackend_OID_SANs to not be rely on map ordering
This commit is contained in:
Steven Clark
2023-12-07 09:22:53 -05:00
committed by GitHub
parent a4180c193b
commit cbf6dc2c4f
70 changed files with 4620 additions and 3543 deletions

View File

@@ -17,6 +17,7 @@ import (
"time"
"github.com/hashicorp/go-secure-stdlib/nonceutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
@@ -220,7 +221,7 @@ type acmeOrder struct {
CertificateSerialNumber string `json:"cert-serial-number"`
CertificateExpiry time.Time `json:"cert-expiry"`
// The actual issuer UUID that issued the certificate, blank if an order exists but no certificate was issued.
IssuerId issuerID `json:"issuer-id"`
IssuerId issuing.IssuerID `json:"issuer-id"`
}
func (o acmeOrder) getIdentifierDNSValues() []string {

View File

@@ -13,6 +13,8 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
type acmeContext struct {
@@ -20,8 +22,8 @@ type acmeContext struct {
baseUrl *url.URL
clusterUrl *url.URL
sc *storageContext
role *roleEntry
issuer *issuerEntry
role *issuing.RoleEntry
issuer *issuing.IssuerEntry
// acmeDirectory is a string that can distinguish the various acme directories we have configured
// if something needs to remain locked into a directory path structure.
acmeDirectory string
@@ -31,7 +33,7 @@ type acmeContext struct {
}
func (c acmeContext) getAcmeState() *acmeState {
return c.sc.Backend.acmeState
return c.sc.Backend.GetAcmeState()
}
type (
@@ -109,7 +111,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
return acmeErrorWrapper(func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage)
config, err := sc.Backend.acmeState.getConfigWithUpdate(sc)
config, err := sc.Backend.GetAcmeState().getConfigWithUpdate(sc)
if err != nil {
return nil, fmt.Errorf("failed to fetch ACME configuration: %w", err)
}
@@ -124,7 +126,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
return nil, ErrAcmeDisabled
}
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return nil, fmt.Errorf("%w: Can not perform ACME operations until migration has completed", ErrServerInternal)
}
@@ -180,7 +182,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
// it does not enforce the account being in a valid state nor existing.
func (b *backend) acmeParsedWrapper(opt acmeWrapperOpts, op acmeParsedOperation) framework.OperationFunc {
return b.acmeWrapper(opt, func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
user, data, err := b.acmeState.ParseRequestParams(acmeCtx, r, fields)
user, data, err := b.GetAcmeState().ParseRequestParams(acmeCtx, r, fields)
if err != nil {
return nil, err
}
@@ -194,7 +196,7 @@ func (b *backend) acmeParsedWrapper(opt acmeWrapperOpts, op acmeParsedOperation)
}
if _, ok := resp.Headers["Replay-Nonce"]; !ok {
nonce, _, err := b.acmeState.GetNonce()
nonce, _, err := b.GetAcmeState().GetNonce()
if err != nil {
return nil, err
}
@@ -255,8 +257,7 @@ func (b *backend) acmeParsedWrapper(opt acmeWrapperOpts, op acmeParsedOperation)
// request has a proper signature for an existing account, and that account is
// in a valid status. It passes to the operation a decoded form of the request
// parameters as well as the ACME account the request is for.
func (b *backend) acmeAccountRequiredWrapper(opt acmeWrapperOpts, op acmeAccountRequiredOperation) framework.
OperationFunc {
func (b *backend) acmeAccountRequiredWrapper(opt acmeWrapperOpts, op acmeAccountRequiredOperation) framework.OperationFunc {
return b.acmeParsedWrapper(opt, func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
if !uc.Existing {
return nil, fmt.Errorf("cannot process request without a 'kid': %w", ErrMalformed)
@@ -320,7 +321,7 @@ func getBasePathFromClusterConfig(sc *storageContext) (*url.URL, error) {
return baseUrl, nil
}
func getAcmeIssuer(sc *storageContext, issuerName string) (*issuerEntry, error) {
func getAcmeIssuer(sc *storageContext, issuerName string) (*issuing.IssuerEntry, error) {
if issuerName == "" {
issuerName = defaultRef
}
@@ -334,7 +335,7 @@ func getAcmeIssuer(sc *storageContext, issuerName string) (*issuerEntry, error)
return nil, fmt.Errorf("issuer failed to load: %w", err)
}
if issuer.Usage.HasUsage(IssuanceUsage) && len(issuer.KeyID) > 0 {
if issuer.Usage.HasUsage(issuing.IssuanceUsage) && len(issuer.KeyID) > 0 {
return issuer, nil
}
@@ -358,12 +359,12 @@ func getAcmeDirectory(r *logical.Request) (string, error) {
return strings.TrimLeft(acmePath[0:lastIndex]+"/acme/", "/"), nil
}
func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config *acmeConfigEntry) (*roleEntry, *issuerEntry, error) {
func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config *acmeConfigEntry) (*issuing.RoleEntry, *issuing.IssuerEntry, error) {
requestedIssuer := getRequestedAcmeIssuerFromPath(data)
requestedRole := getRequestedAcmeRoleFromPath(data)
issuerToLoad := requestedIssuer
var role *roleEntry
var role *issuing.RoleEntry
var err error
if len(requestedRole) == 0 { // Default Directory
@@ -375,11 +376,9 @@ func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config
case Forbid:
return nil, nil, fmt.Errorf("%w: default directory not allowed by ACME policy", ErrServerInternal)
case SignVerbatim, ExternalPolicy:
role = buildSignVerbatimRoleWithNoData(&roleEntry{
Issuer: requestedIssuer,
NoStore: false,
Name: requestedRole,
})
role = issuing.SignVerbatimRoleWithOpts(
issuing.WithIssuer(requestedIssuer),
issuing.WithNoStore(false))
case Role:
role, err = getAndValidateAcmeRole(sc, extraInfo)
if err != nil {
@@ -455,9 +454,9 @@ func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config
return role, issuer, nil
}
func getAndValidateAcmeRole(sc *storageContext, requestedRole string) (*roleEntry, error) {
func getAndValidateAcmeRole(sc *storageContext, requestedRole string) (*issuing.RoleEntry, error) {
var err error
role, err := sc.Backend.getRole(sc.Context, sc.Storage, requestedRole)
role, err := sc.Backend.GetRole(sc.Context, sc.Storage, requestedRole)
if err != nil {
return nil, fmt.Errorf("%w: err loading role", ErrServerInternal)
}
@@ -491,14 +490,6 @@ func getRequestedAcmeIssuerFromPath(data *framework.FieldData) string {
return requestedIssuer
}
func getRequestedPolicyFromPath(data *framework.FieldData) string {
requestedPolicy := ""
if requestedPolicyRaw, present := data.GetOk("policy"); present {
requestedPolicy = requestedPolicyRaw.(string)
}
return requestedPolicy
}
func isAcmeDisabled(sc *storageContext, config *acmeConfigEntry, policy EabPolicy) bool {
if !config.Enabled {
return true

View File

@@ -6,21 +6,23 @@ package pki
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
atomic2 "go.uber.org/atomic"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
"github.com/hashicorp/vault/builtin/logical/pki/pki_backend"
)
const (
@@ -287,19 +289,10 @@ func Backend(conf *logical.BackendConfig) *backend {
// Delay the first tidy until after we've started up.
b.lastTidy = time.Now()
// Metrics initialization for count of certificates in storage
b.certCountEnabled = atomic2.NewBool(false)
b.publishCertCountMetrics = atomic2.NewBool(false)
b.certsCounted = atomic2.NewBool(false)
b.certCountError = "Initialize Not Yet Run, Cert Counts Unavailable"
b.certCount = &atomic.Uint32{}
b.revokedCertCount = &atomic.Uint32{}
b.possibleDoubleCountedSerials = make([]string, 0, 250)
b.possibleDoubleCountedRevokedSerials = make([]string, 0, 250)
b.unifiedTransferStatus = newUnifiedTransferStatus()
b.acmeState = NewACMEState()
b.certificateCounter = NewCertificateCounter(b.backendUUID)
b.SetupEnt()
return &b
@@ -319,19 +312,12 @@ type backend struct {
tidyStatus *tidyStatus
lastTidy time.Time
unifiedTransferStatus *unifiedTransferStatus
unifiedTransferStatus *UnifiedTransferStatus
certCountEnabled *atomic2.Bool
publishCertCountMetrics *atomic2.Bool
certCount *atomic.Uint32
revokedCertCount *atomic.Uint32
certsCounted *atomic2.Bool
certCountError string
possibleDoubleCountedSerials []string
possibleDoubleCountedRevokedSerials []string
certificateCounter *CertificateCounter
pkiStorageVersion atomic.Value
crlBuilder *crlBuilder
crlBuilder *CrlBuilder
// Write lock around issuers and keys.
issuersLock sync.RWMutex
@@ -341,7 +327,25 @@ type backend struct {
acmeAccountLock sync.RWMutex // (Write) Locked on Tidy, (Read) Locked on Account Creation
}
type roleOperation func(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error)
// BackendOps a bridge/legacy interface until we can further
// separate out backend things into distinct packages.
type BackendOps interface {
managed_key.PkiManagedKeyView
pki_backend.SystemViewGetter
pki_backend.MountInfo
pki_backend.Logger
UseLegacyBundleCaStorage() bool
CrlBuilder() *CrlBuilder
GetRevokeStorageLock() *sync.RWMutex
GetUnifiedTransferStatus() *UnifiedTransferStatus
GetAcmeState() *acmeState
GetRole(ctx context.Context, s logical.Storage, n string) (*issuing.RoleEntry, error)
GetCertificateCounter() *CertificateCounter
}
var _ BackendOps = &backend{}
type roleOperation func(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error)
const backendHelp = `
The PKI backend dynamically generates X509 server and client certificates.
@@ -363,7 +367,7 @@ func metricsKey(req *logical.Request, extra ...string) []string {
func (b *backend) metricsWrap(callType string, roleMode int, ofunc roleOperation) framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
key := metricsKey(req, callType)
var role *roleEntry
var role *issuing.RoleEntry
var labels []metrics.Label
var err error
@@ -379,7 +383,7 @@ func (b *backend) metricsWrap(callType string, roleMode int, ofunc roleOperation
}
if roleMode > noRole {
// Get the role
role, err = b.getRole(ctx, req.Storage, roleName)
role, err = b.GetRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -410,7 +414,7 @@ func (b *backend) metricsWrap(callType string, roleMode int, ofunc roleOperation
// initialize is used to perform a possible PKI storage migration if needed
func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequest) error {
sc := b.makeStorageContext(ctx, b.storage)
if err := b.crlBuilder.reloadConfigIfRequired(sc); err != nil {
if err := b.CrlBuilder().reloadConfigIfRequired(sc); err != nil {
return err
}
@@ -419,7 +423,7 @@ func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequ
return err
}
err = b.acmeState.Initialize(b, sc)
err = b.GetAcmeState().Initialize(b, sc)
if err != nil {
return err
}
@@ -429,7 +433,7 @@ func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequ
if err != nil {
// Don't block/err initialize/startup for metrics. Context on this call can time out due to number of certificates.
b.Logger().Error("Could not initialize stored certificate counts", "error", err)
b.certCountError = err.Error()
b.GetCertificateCounter().SetError(err)
}
return b.initializeEnt(sc, ir)
@@ -438,7 +442,7 @@ func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequ
func (b *backend) cleanup(ctx context.Context) {
sc := b.makeStorageContext(ctx, b.storage)
b.acmeState.Shutdown(b)
b.GetAcmeState().Shutdown(b)
b.cleanupEnt(sc)
}
@@ -469,7 +473,31 @@ func (b *backend) initializePKIIssuersStorage(ctx context.Context) error {
return nil
}
func (b *backend) useLegacyBundleCaStorage() bool {
func (b *backend) BackendUUID() string {
return b.backendUUID
}
func (b *backend) CrlBuilder() *CrlBuilder {
return b.crlBuilder
}
func (b *backend) GetRevokeStorageLock() *sync.RWMutex {
return &b.revokeStorageLock
}
func (b *backend) GetUnifiedTransferStatus() *UnifiedTransferStatus {
return b.unifiedTransferStatus
}
func (b *backend) GetAcmeState() *acmeState {
return b.acmeState
}
func (b *backend) GetCertificateCounter() *CertificateCounter {
return b.certificateCounter
}
func (b *backend) UseLegacyBundleCaStorage() bool {
// This helper function is here to choose whether or not we use the newer
// issuer/key storage format or the older legacy ca bundle format.
//
@@ -482,6 +510,18 @@ func (b *backend) useLegacyBundleCaStorage() bool {
return version == nil || version == 0
}
func (b *backend) IsSecondaryNode() bool {
return b.System().ReplicationState().HasState(consts.ReplicationPerformanceStandby)
}
func (b *backend) GetManagedKeyView() (logical.ManagedKeySystemView, error) {
managedKeyView, ok := b.System().(logical.ManagedKeySystemView)
if !ok {
return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported system view")}
}
return managedKeyView, nil
}
func (b *backend) updatePkiStorageVersion(ctx context.Context, grabIssuersLock bool) {
info, err := getMigrationInfo(ctx, b.storage)
if err != nil {
@@ -520,36 +560,36 @@ func (b *backend) invalidate(ctx context.Context, key string) {
go func() {
b.Logger().Info("Detected a migration completed, resetting pki storage version")
b.updatePkiStorageVersion(ctx, true)
b.crlBuilder.requestRebuildIfActiveNode(b)
b.CrlBuilder().requestRebuildIfActiveNode(b)
}()
case strings.HasPrefix(key, issuerPrefix):
if !b.useLegacyBundleCaStorage() {
if !b.UseLegacyBundleCaStorage() {
// See note in updateDefaultIssuerId about why this is necessary.
// We do this ahead of CRL rebuilding just so we know that things
// are stale.
b.crlBuilder.invalidateCRLBuildTime()
b.CrlBuilder().invalidateCRLBuildTime()
// If an issuer has changed on the primary, we need to schedule an update of our CRL,
// the primary cluster would have done it already, but the CRL is cluster specific so
// force a rebuild of ours.
b.crlBuilder.requestRebuildIfActiveNode(b)
b.CrlBuilder().requestRebuildIfActiveNode(b)
} else {
b.Logger().Debug("Ignoring invalidation updates for issuer as the PKI migration has yet to complete.")
}
case key == "config/crl":
// We may need to reload our OCSP status flag
b.crlBuilder.markConfigDirty()
b.CrlBuilder().markConfigDirty()
case key == storageAcmeConfig:
b.acmeState.markConfigDirty()
b.GetAcmeState().markConfigDirty()
case key == storageIssuerConfig:
b.crlBuilder.invalidateCRLBuildTime()
b.CrlBuilder().invalidateCRLBuildTime()
case strings.HasPrefix(key, crossRevocationPrefix):
split := strings.Split(key, "/")
if !strings.HasSuffix(key, "/confirmed") {
cluster := split[len(split)-2]
serial := split[len(split)-1]
b.crlBuilder.addCertForRevocationCheck(cluster, serial)
b.CrlBuilder().addCertForRevocationCheck(cluster, serial)
} else {
if len(split) >= 3 {
cluster := split[len(split)-3]
@@ -560,7 +600,7 @@ func (b *backend) invalidate(ctx context.Context, key string) {
// ignore them). On performance primary nodes though,
// we do want to track them to remove them.
if !isNotPerfPrimary {
b.crlBuilder.addCertForRevocationRemoval(cluster, serial)
b.CrlBuilder().addCertForRevocationRemoval(cluster, serial)
}
}
}
@@ -569,7 +609,7 @@ func (b *backend) invalidate(ctx context.Context, key string) {
split := strings.Split(key, "/")
cluster := split[len(split)-2]
serial := split[len(split)-1]
b.crlBuilder.addCertFromCrossRevocation(cluster, serial)
b.CrlBuilder().addCertFromCrossRevocation(cluster, serial)
}
b.invalidateEnt(ctx, key)
@@ -580,7 +620,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
doCRL := func() error {
// First attempt to reload the CRL configuration.
if err := b.crlBuilder.reloadConfigIfRequired(sc); err != nil {
if err := b.CrlBuilder().reloadConfigIfRequired(sc); err != nil {
return err
}
@@ -592,22 +632,22 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
}
// First handle any global revocation queue entries.
if err := b.crlBuilder.processRevocationQueue(sc); err != nil {
if err := b.CrlBuilder().processRevocationQueue(sc); err != nil {
return err
}
// Then handle any unified cross-cluster revocations.
if err := b.crlBuilder.processCrossClusterRevocations(sc); err != nil {
if err := b.CrlBuilder().processCrossClusterRevocations(sc); err != nil {
return err
}
// Check if we're set to auto rebuild and a CRL is set to expire.
if err := b.crlBuilder.checkForAutoRebuild(sc); err != nil {
if err := b.CrlBuilder().checkForAutoRebuild(sc); err != nil {
return err
}
// Then attempt to rebuild the CRLs if required.
warnings, err := b.crlBuilder.rebuildIfForced(sc)
warnings, err := b.CrlBuilder().rebuildIfForced(sc)
if err != nil {
return err
}
@@ -622,7 +662,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
// If a delta CRL was rebuilt above as part of the complete CRL rebuild,
// this will be a no-op. However, if we do need to rebuild delta CRLs,
// this would cause us to do so.
warnings, err = b.crlBuilder.rebuildDeltaCRLsIfForced(sc, false)
warnings, err = b.CrlBuilder().rebuildDeltaCRLsIfForced(sc, false)
if err != nil {
return err
}
@@ -689,7 +729,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
}
// First tidy any ACME nonces to free memory.
b.acmeState.DoTidyNonces()
b.GetAcmeState().DoTidyNonces()
// Then run unified transfer.
backgroundSc := b.makeStorageContext(context.Background(), b.storage)
@@ -700,11 +740,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
tidyErr := doAutoTidy()
// Periodically re-emit gauges so that they don't disappear/go stale
tidyConfig, err := sc.getAutoTidyConfig()
if err != nil {
return err
}
b.emitCertStoreMetrics(tidyConfig)
b.GetCertificateCounter().EmitCertStoreMetrics()
var errors error
if crlErr != nil {
@@ -721,7 +757,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
// Check if the CRL was invalidated due to issuer swap and update
// accordingly.
if err := b.crlBuilder.flushCRLBuildTimeInvalidation(sc); err != nil {
if err := b.CrlBuilder().flushCRLBuildTimeInvalidation(sc); err != nil {
return err
}
@@ -742,211 +778,22 @@ func (b *backend) initializeStoredCertificateCounts(ctx context.Context) error {
return err
}
b.certCountEnabled.Store(config.MaintainCount)
b.publishCertCountMetrics.Store(config.PublishMetrics)
if config.MaintainCount == false {
b.possibleDoubleCountedRevokedSerials = nil
b.possibleDoubleCountedSerials = nil
b.certsCounted.Store(true)
b.certCount.Store(0)
b.revokedCertCount.Store(0)
b.certCountError = "Cert Count is Disabled: enable via Tidy Config maintain_stored_certificate_counts"
certCounter := b.GetCertificateCounter()
isEnabled := certCounter.ReconfigureWithTidyConfig(config)
if !isEnabled {
return nil
}
// Ideally these three things would be set in one transaction, since that isn't possible, set the counts to "0",
// first, so count will over-count (and miss putting things in deduplicate queue), rather than under-count.
b.certCount.Store(0)
b.revokedCertCount.Store(0)
b.possibleDoubleCountedRevokedSerials = nil
b.possibleDoubleCountedSerials = nil
// A cert issued or revoked here will be double-counted. That's okay, this is "best effort" metrics.
b.certsCounted.Store(false)
entries, err := b.storage.List(ctx, "certs/")
if err != nil {
return err
}
b.certCount.Add(uint32(len(entries)))
revokedEntries, err := b.storage.List(ctx, "revoked/")
if err != nil {
return err
}
b.revokedCertCount.Add(uint32(len(revokedEntries)))
b.certsCounted.Store(true)
// Now that the metrics are set, we can switch from appending newly-stored certificates to the possible double-count
// list, and instead have them update the counter directly. We need to do this so that we are looking at a static
// slice of possibly double counted serials. Note that certsCounted is computed before the storage operation, so
// there may be some delay here.
// Sort the listed-entries first, to accommodate that delay.
sort.Slice(entries, func(i, j int) bool {
return entries[i] < entries[j]
})
sort.Slice(revokedEntries, func(i, j int) bool {
return revokedEntries[i] < revokedEntries[j]
})
// We assume here that these lists are now complete.
sort.Slice(b.possibleDoubleCountedSerials, func(i, j int) bool {
return b.possibleDoubleCountedSerials[i] < b.possibleDoubleCountedSerials[j]
})
listEntriesIndex := 0
possibleDoubleCountIndex := 0
for {
if listEntriesIndex >= len(entries) {
break
}
if possibleDoubleCountIndex >= len(b.possibleDoubleCountedSerials) {
break
}
if entries[listEntriesIndex] == b.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
// This represents a double-counted entry
b.decrementTotalCertificatesCountNoReport()
listEntriesIndex = listEntriesIndex + 1
possibleDoubleCountIndex = possibleDoubleCountIndex + 1
continue
}
if entries[listEntriesIndex] < b.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
listEntriesIndex = listEntriesIndex + 1
continue
}
if entries[listEntriesIndex] > b.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
possibleDoubleCountIndex = possibleDoubleCountIndex + 1
continue
}
}
sort.Slice(b.possibleDoubleCountedRevokedSerials, func(i, j int) bool {
return b.possibleDoubleCountedRevokedSerials[i] < b.possibleDoubleCountedRevokedSerials[j]
})
listRevokedEntriesIndex := 0
possibleRevokedDoubleCountIndex := 0
for {
if listRevokedEntriesIndex >= len(revokedEntries) {
break
}
if possibleRevokedDoubleCountIndex >= len(b.possibleDoubleCountedRevokedSerials) {
break
}
if revokedEntries[listRevokedEntriesIndex] == b.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
// This represents a double-counted revoked entry
b.decrementTotalRevokedCertificatesCountNoReport()
listRevokedEntriesIndex = listRevokedEntriesIndex + 1
possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
continue
}
if revokedEntries[listRevokedEntriesIndex] < b.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
listRevokedEntriesIndex = listRevokedEntriesIndex + 1
continue
}
if revokedEntries[listRevokedEntriesIndex] > b.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
continue
}
}
b.possibleDoubleCountedRevokedSerials = nil
b.possibleDoubleCountedSerials = nil
b.emitCertStoreMetrics(config)
b.certCountError = ""
certCounter.InitializeCountsFromStorage(entries, revokedEntries)
return nil
}
func (b *backend) emitCertStoreMetrics(config *tidyConfig) {
if config.PublishMetrics == true {
certCount := b.certCount.Load()
b.emitTotalCertCountMetric(certCount)
revokedCertCount := b.revokedCertCount.Load()
b.emitTotalRevokedCountMetric(revokedCertCount)
}
}
// The "certsCounted" boolean here should be loaded from the backend certsCounted before the corresponding storage call:
// eg. certsCounted := b.certsCounted.Load()
func (b *backend) ifCountEnabledIncrementTotalCertificatesCount(certsCounted bool, newSerial string) {
if b.certCountEnabled.Load() {
certCount := b.certCount.Add(1)
switch {
case !certsCounted:
// This is unsafe, but a good best-attempt
if strings.HasPrefix(newSerial, "certs/") {
newSerial = newSerial[6:]
}
b.possibleDoubleCountedSerials = append(b.possibleDoubleCountedSerials, newSerial)
default:
if b.publishCertCountMetrics.Load() {
b.emitTotalCertCountMetric(certCount)
}
}
}
}
func (b *backend) ifCountEnabledDecrementTotalCertificatesCountReport() {
if b.certCountEnabled.Load() {
certCount := b.decrementTotalCertificatesCountNoReport()
if b.publishCertCountMetrics.Load() {
b.emitTotalCertCountMetric(certCount)
}
}
}
func (b *backend) emitTotalCertCountMetric(certCount uint32) {
metrics.SetGauge([]string{"secrets", "pki", b.backendUUID, "total_certificates_stored"}, float32(certCount))
}
// Called directly only by the initialize function to deduplicate the count, when we don't have a full count yet
// Does not respect whether-we-are-counting backend information.
func (b *backend) decrementTotalCertificatesCountNoReport() uint32 {
newCount := b.certCount.Add(^uint32(0))
return newCount
}
// The "certsCounted" boolean here should be loaded from the backend certsCounted before the corresponding storage call:
// eg. certsCounted := b.certsCounted.Load()
func (b *backend) ifCountEnabledIncrementTotalRevokedCertificatesCount(certsCounted bool, newSerial string) {
if b.certCountEnabled.Load() {
newRevokedCertCount := b.revokedCertCount.Add(1)
switch {
case !certsCounted:
// This is unsafe, but a good best-attempt
if strings.HasPrefix(newSerial, "revoked/") { // allow passing in the path (revoked/serial) OR the serial
newSerial = newSerial[8:]
}
b.possibleDoubleCountedRevokedSerials = append(b.possibleDoubleCountedRevokedSerials, newSerial)
default:
if b.publishCertCountMetrics.Load() {
b.emitTotalRevokedCountMetric(newRevokedCertCount)
}
}
}
}
func (b *backend) ifCountEnabledDecrementTotalRevokedCertificatesCountReport() {
if b.certCountEnabled.Load() {
revokedCertCount := b.decrementTotalRevokedCertificatesCountNoReport()
if b.publishCertCountMetrics.Load() {
b.emitTotalRevokedCountMetric(revokedCertCount)
}
}
}
func (b *backend) emitTotalRevokedCountMetric(revokedCertCount uint32) {
metrics.SetGauge([]string{"secrets", "pki", b.backendUUID, "total_revoked_certificates_stored"}, float32(revokedCertCount))
}
// Called directly only by the initialize function to deduplicate the count, when we don't have a full count yet
// Does not respect whether-we-are-counting backend information.
func (b *backend) decrementTotalRevokedCertificatesCountNoReport() uint32 {
newRevokedCertCount := b.revokedCertCount.Add(^uint32(0))
return newRevokedCertCount
}

View File

@@ -5,6 +5,7 @@ package pki
import (
"bytes"
"cmp"
"context"
"crypto"
"crypto/ecdsa"
@@ -26,6 +27,7 @@ import (
"net/url"
"os"
"reflect"
"slices"
"sort"
"strconv"
"strings"
@@ -33,6 +35,7 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/parsing"
"github.com/hashicorp/vault/helper/testhelpers/teststorage"
"golang.org/x/exp/maps"
@@ -56,6 +59,8 @@ import (
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
"golang.org/x/net/idna"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
var stepCount = 0
@@ -856,7 +861,7 @@ func generateTestCsr(t *testing.T, keyType certutil.PrivateKeyType, keyBits int)
// Generates steps to test out various role permutations
func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals := roleEntry{
roleVals := issuing.RoleEntry{
MaxTTL: 12 * time.Hour,
KeyType: "rsa",
KeyBits: 2048,
@@ -938,7 +943,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
ret = append(ret, issueTestStep)
}
getCountryCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getCountryCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -959,7 +964,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getOuCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getOuCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -980,7 +985,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getOrganizationCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1001,7 +1006,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getLocalityCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getLocalityCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1022,7 +1027,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getProvinceCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getProvinceCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1043,7 +1048,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getStreetAddressCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getStreetAddressCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1064,7 +1069,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getPostalCodeCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getPostalCodeCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1085,7 +1090,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getNotBeforeCheck := func(role roleEntry) logicaltest.TestCheckFunc {
getNotBeforeCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1110,7 +1115,9 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
// Returns a TestCheckFunc that performs various validity checks on the
// returned certificate information, mostly within checkCertsAndPrivateKey
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
getCnCheck := func(name string, role issuing.RoleEntry, key crypto.Signer, usage x509.KeyUsage,
extUsage x509.ExtKeyUsage, validity time.Duration,
) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1333,7 +1340,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
roleVals.KeyUsage = usage
parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage)
parsedKeyUsage := parsing.ParseKeyUsages(roleVals.KeyUsage)
if parsedKeyUsage == 0 && len(usage) != 0 {
panic("parsed key usages was zero")
}
@@ -1592,7 +1599,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
{
getOtherCheck := func(expectedOthers ...otherNameUtf8) logicaltest.TestCheckFunc {
getOtherCheck := func(expectedOthers ...issuing.OtherNameUtf8) logicaltest.TestCheckFunc {
return func(resp *logical.Response) error {
var certBundle certutil.CertBundle
err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1608,7 +1615,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
if err != nil {
return err
}
var expected []otherNameUtf8
var expected []issuing.OtherNameUtf8
expected = append(expected, expectedOthers...)
if diff := deep.Equal(foundOthers, expected); len(diff) > 0 {
return fmt.Errorf("wrong SAN IPs, diff: %v", diff)
@@ -1617,11 +1624,11 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
addOtherSANTests := func(useCSRs, useCSRSANs bool, allowedOtherSANs []string, errorOk bool, otherSANs []string, csrOtherSANs []otherNameUtf8, check logicaltest.TestCheckFunc) {
otherSansMap := func(os []otherNameUtf8) map[string][]string {
addOtherSANTests := func(useCSRs, useCSRSANs bool, allowedOtherSANs []string, errorOk bool, otherSANs []string, csrOtherSANs []issuing.OtherNameUtf8, check logicaltest.TestCheckFunc) {
otherSansMap := func(os []issuing.OtherNameUtf8) map[string][]string {
ret := make(map[string][]string)
for _, o := range os {
ret[o.oid] = append(ret[o.oid], o.value)
ret[o.Oid] = append(ret[o.Oid], o.Value)
}
return ret
}
@@ -1652,14 +1659,14 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals.UseCSRCommonName = true
commonNames.Localhost = true
newOtherNameUtf8 := func(s string) (ret otherNameUtf8) {
newOtherNameUtf8 := func(s string) (ret issuing.OtherNameUtf8) {
pieces := strings.Split(s, ";")
if len(pieces) == 2 {
piecesRest := strings.Split(pieces[1], ":")
if len(piecesRest) == 2 {
switch strings.ToUpper(piecesRest[0]) {
case "UTF-8", "UTF8":
return otherNameUtf8{oid: pieces[0], value: piecesRest[1]}
return issuing.OtherNameUtf8{Oid: pieces[0], Value: piecesRest[1]}
}
}
}
@@ -1669,7 +1676,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
oid1 := "1.3.6.1.4.1.311.20.2.3"
oth1str := oid1 + ";utf8:devops@nope.com"
oth1 := newOtherNameUtf8(oth1str)
oth2 := otherNameUtf8{oid1, "me@example.com"}
oth2 := issuing.OtherNameUtf8{oid1, "me@example.com"}
// allowNone, allowAll := []string{}, []string{oid1 + ";UTF-8:*"}
allowNone, allowAll := []string{}, []string{"*"}
@@ -1684,15 +1691,15 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
// Given OtherSANs as API argument and useCSRSANs false, CSR arg ignored.
addOtherSANTests(useCSRs, false, allowAll, false, []string{oth1str},
[]otherNameUtf8{oth2}, getOtherCheck(oth1))
[]issuing.OtherNameUtf8{oth2}, getOtherCheck(oth1))
if useCSRs {
// OtherSANs not allowed, valid OtherSANs provided via CSR, should be an error.
addOtherSANTests(useCSRs, true, allowNone, true, nil, []otherNameUtf8{oth1}, nil)
addOtherSANTests(useCSRs, true, allowNone, true, nil, []issuing.OtherNameUtf8{oth1}, nil)
// Given OtherSANs as both API and CSR arguments and useCSRSANs=true, API arg ignored.
addOtherSANTests(useCSRs, false, allowAll, false, []string{oth2.String()},
[]otherNameUtf8{oth1}, getOtherCheck(oth2))
[]issuing.OtherNameUtf8{oth1}, getOtherCheck(oth2))
}
}
@@ -2405,7 +2412,7 @@ func TestBackend_Root_Idempotency(t *testing.T) {
certSkid := certutil.GetHexFormatted(cert.SubjectKeyId, ":")
// -> Validate the SKID matches between the root cert and the key
resp, err = CBRead(b, s, "key/"+keyId1.(keyID).String())
resp, err = CBRead(b, s, "key/"+keyId1.(issuing.KeyID).String())
require.NoError(t, err)
require.NotNil(t, resp, "expected a response")
require.Equal(t, resp.Data["subject_key_id"], certSkid)
@@ -2427,7 +2434,7 @@ func TestBackend_Root_Idempotency(t *testing.T) {
certSkid = certutil.GetHexFormatted(cert.SubjectKeyId, ":")
// -> Validate the SKID matches between the root cert and the key
resp, err = CBRead(b, s, "key/"+keyId2.(keyID).String())
resp, err = CBRead(b, s, "key/"+keyId2.(issuing.KeyID).String())
require.NoError(t, err)
require.NotNil(t, resp, "expected a response")
require.Equal(t, resp.Data["subject_key_id"], certSkid)
@@ -2562,7 +2569,7 @@ func TestBackend_SignIntermediate_AllowedPastCAValidity(t *testing.T) {
})
schema.ValidateResponse(t, schema.GetResponseSchema(t, b_root.Route("intermediate/generate/internal"), logical.UpdateOperation), resp, true)
require.Contains(t, resp.Data, "key_id")
intKeyId := resp.Data["key_id"].(keyID)
intKeyId := resp.Data["key_id"].(issuing.KeyID)
csr := resp.Data["csr"]
resp, err = CBRead(b_int, s_int, "key/"+intKeyId.String())
@@ -2764,7 +2771,7 @@ func TestBackend_SignSelfIssued(t *testing.T) {
}
sc := b.makeStorageContext(context.Background(), storage)
signingBundle, err := sc.fetchCAInfo(defaultRef, ReadOnlyUsage)
signingBundle, err := sc.fetchCAInfo(defaultRef, issuing.ReadOnlyUsage)
if err != nil {
t.Fatal(err)
}
@@ -3123,11 +3130,14 @@ func TestBackend_OID_SANs(t *testing.T) {
cert.DNSNames[2] != "foobar.com" {
t.Fatalf("unexpected DNS SANs %v", cert.DNSNames)
}
expectedOtherNames := []otherNameUtf8{{oid1, val1}, {oid2, val2}}
expectedOtherNames := []issuing.OtherNameUtf8{{oid1, val1}, {oid2, val2}}
foundOtherNames, err := getOtherSANsFromX509Extensions(cert.Extensions)
if err != nil {
t.Fatal(err)
}
// Sort our returned list as SANS are built internally with a map so ordering can be inconsistent
slices.SortFunc(foundOtherNames, func(a, b issuing.OtherNameUtf8) int { return cmp.Compare(a.Oid, b.Oid) })
if diff := deep.Equal(expectedOtherNames, foundOtherNames); len(diff) != 0 {
t.Errorf("unexpected otherNames: %v", diff)
}
@@ -3874,9 +3884,11 @@ func TestBackend_RevokePlusTidy_Intermediate(t *testing.T) {
"maintain_stored_certificate_counts": true,
"publish_stored_certificate_count_metrics": true,
})
require.NoError(t, err, "failed calling auto-tidy")
_, err = client.Logical().Write("/sys/plugins/reload/backend", map[string]interface{}{
"mounts": "pki/",
})
require.NoError(t, err, "failed calling backend reload")
// Check the metrics initialized in order to calculate backendUUID for /pki
// BackendUUID not consistent during tests with UUID from /sys/mounts/pki
@@ -4934,9 +4946,9 @@ func TestRootWithExistingKey(t *testing.T) {
resp, err = CBList(b, s, "issuers")
require.NoError(t, err)
require.Equal(t, 3, len(resp.Data["keys"].([]string)))
require.Contains(t, resp.Data["keys"], string(myIssuerId1.(issuerID)))
require.Contains(t, resp.Data["keys"], string(myIssuerId2.(issuerID)))
require.Contains(t, resp.Data["keys"], string(myIssuerId3.(issuerID)))
require.Contains(t, resp.Data["keys"], string(myIssuerId1.(issuing.IssuerID)))
require.Contains(t, resp.Data["keys"], string(myIssuerId2.(issuing.IssuerID)))
require.Contains(t, resp.Data["keys"], string(myIssuerId3.(issuing.IssuerID)))
}
func TestIntermediateWithExistingKey(t *testing.T) {
@@ -5718,17 +5730,18 @@ func TestBackend_InitializeCertificateCounts(t *testing.T) {
}
}
if b.certCount.Load() != 6 {
t.Fatalf("Failed to count six certificates root,A,B,C,D,E, instead counted %d certs", b.certCount.Load())
certCounter := b.GetCertificateCounter()
if certCounter.CertificateCount() != 6 {
t.Fatalf("Failed to count six certificates root,A,B,C,D,E, instead counted %d certs", certCounter.CertificateCount())
}
if b.revokedCertCount.Load() != 2 {
t.Fatalf("Failed to count two revoked certificates A+B, instead counted %d certs", b.revokedCertCount.Load())
if certCounter.RevokedCount() != 2 {
t.Fatalf("Failed to count two revoked certificates A+B, instead counted %d certs", certCounter.RevokedCount())
}
// Simulates listing while initialize in progress, by "restarting it"
b.certCount.Store(0)
b.revokedCertCount.Store(0)
b.certsCounted.Store(false)
certCounter.certCount.Store(0)
certCounter.revokedCertCount.Store(0)
certCounter.certsCounted.Store(false)
// Revoke certificates C, D
dirtyRevocations := serials[2:4]
@@ -5753,15 +5766,16 @@ func TestBackend_InitializeCertificateCounts(t *testing.T) {
}
// Run initialize
b.initializeStoredCertificateCounts(ctx)
err = b.initializeStoredCertificateCounts(ctx)
require.NoError(t, err, "failed initializing certificate counts")
// Test certificate count
if b.certCount.Load() != 8 {
t.Fatalf("Failed to initialize count of certificates root, A,B,C,D,E,F,G counted %d certs", b.certCount.Load())
if certCounter.CertificateCount() != 8 {
t.Fatalf("Failed to initialize count of certificates root, A,B,C,D,E,F,G counted %d certs", certCounter.CertificateCount())
}
if b.revokedCertCount.Load() != 4 {
t.Fatalf("Failed to count revoked certificates A,B,C,D counted %d certs", b.revokedCertCount.Load())
if certCounter.RevokedCount() != 4 {
t.Fatalf("Failed to count revoked certificates A,B,C,D counted %d certs", certCounter.RevokedCount())
}
return
@@ -6147,7 +6161,7 @@ func TestPKI_TemplatedAIAs(t *testing.T) {
require.NoError(t, err)
resp, err = CBWrite(b, s, "root/generate/internal", rootData)
requireSuccessNonNilResponse(t, resp, err)
issuerId := string(resp.Data["issuer_id"].(issuerID))
issuerId := string(resp.Data["issuer_id"].(issuing.IssuerID))
// Now write the original AIA config and sign a leaf.
_, err = CBWrite(b, s, "config/urls", aiaData)
@@ -7063,7 +7077,7 @@ func TestPatchIssuer(t *testing.T) {
"issuer_name": "root",
})
requireSuccessNonNilResponse(t, resp, err, "failed generating root issuer")
id := string(resp.Data["issuer_id"].(issuerID))
id := string(resp.Data["issuer_id"].(issuing.IssuerID))
// 2. Enable Cluster paths
resp, err = CBWrite(b, s, "config/urls", map[string]interface{}{

View File

@@ -18,9 +18,12 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
func getGenerationParams(sc *storageContext, data *framework.FieldData) (exported bool, format string, role *roleEntry, errorResp *logical.Response) {
func getGenerationParams(sc *storageContext, data *framework.FieldData) (exported bool, format string, role *issuing.RoleEntry, errorResp *logical.Response) {
exportedStr := data.Get("exported").(string)
switch exportedStr {
case "exported":
@@ -47,7 +50,7 @@ func getGenerationParams(sc *storageContext, data *framework.FieldData) (exporte
return
}
role = &roleEntry{
role = &issuing.RoleEntry{
TTL: time.Duration(data.Get("ttl").(int)) * time.Second,
KeyType: keyType,
KeyBits: keyBits,
@@ -90,7 +93,7 @@ func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.Cre
if err != nil {
return nil, err
}
return generateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
return managed_key.GenerateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
}
if existingKeyRequested(input) {
@@ -104,12 +107,12 @@ func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.Cre
return nil, err
}
if keyEntry.isManagedPrivateKey() {
keyId, err := keyEntry.getManagedKeyUUID()
if keyEntry.IsManagedPrivateKey() {
keyId, err := issuing.GetManagedKeyUUID(keyEntry)
if err != nil {
return nil, err
}
return generateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
return managed_key.GenerateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
}
return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingKeyGeneratorFromBytes(keyEntry))
@@ -128,7 +131,7 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
return nil, err
}
return generateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
return managed_key.GenerateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
}
if existingKeyRequested(input) {
@@ -142,12 +145,12 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
return nil, err
}
if key.isManagedPrivateKey() {
keyId, err := key.getManagedKeyUUID()
if key.IsManagedPrivateKey() {
keyId, err := issuing.GetManagedKeyUUID(key)
if err != nil {
return nil, err
}
return generateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
return managed_key.GenerateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
}
return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingKeyGeneratorFromBytes(key))
@@ -157,10 +160,7 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
}
func parseCABundle(ctx context.Context, b *backend, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
if bundle.PrivateKeyType == certutil.ManagedPrivateKey {
return parseManagedKeyCABundle(ctx, b, bundle)
}
return bundle.ToParsedCertBundle()
return issuing.ParseCABundle(ctx, b, bundle)
}
func (sc *storageContext) getKeyTypeAndBitsForRole(data *framework.FieldData) (string, int, error) {
@@ -192,7 +192,7 @@ func (sc *storageContext) getKeyTypeAndBitsForRole(data *framework.FieldData) (s
return "", 0, errors.New("unable to determine managed key id: " + err.Error())
}
pubKeyManagedKey, err := getManagedKeyPublicKey(sc.Context, sc.Backend, keyId)
pubKeyManagedKey, err := managed_key.GetManagedKeyPublicKey(sc.Context, sc.Backend, keyId)
if err != nil {
return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error())
}
@@ -245,7 +245,7 @@ func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.Pr
return keyType, keyBits, nil
}
func (sc *storageContext) getExistingKeyFromRef(keyRef string) (*keyEntry, error) {
func (sc *storageContext) getExistingKeyFromRef(keyRef string) (*issuing.KeyEntry, error) {
keyId, err := sc.resolveKeyReference(keyRef)
if err != nil {
return nil, err
@@ -253,7 +253,7 @@ func (sc *storageContext) getExistingKeyFromRef(keyRef string) (*keyEntry, error
return sc.fetchKeyById(keyId)
}
func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator {
func existingKeyGeneratorFromBytes(key *issuing.KeyEntry) certutil.KeyGenerator {
return func(_ string, _ int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error {
signer, _, pemBytes, err := getSignerFromKeyEntryBytes(key)
if err != nil {
@@ -264,61 +264,3 @@ func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator {
return nil
}
}
func buildSignVerbatimRoleWithNoData(role *roleEntry) *roleEntry {
data := &framework.FieldData{
Raw: map[string]interface{}{},
Schema: addSignVerbatimRoleFields(map[string]*framework.FieldSchema{}),
}
return buildSignVerbatimRole(data, role)
}
func buildSignVerbatimRole(data *framework.FieldData, role *roleEntry) *roleEntry {
entry := &roleEntry{
AllowLocalhost: true,
AllowAnyName: true,
AllowIPSANs: true,
AllowWildcardCertificates: new(bool),
EnforceHostnames: false,
KeyType: "any",
UseCSRCommonName: true,
UseCSRSANs: true,
AllowedOtherSANs: []string{"*"},
AllowedSerialNumbers: []string{"*"},
AllowedURISANs: []string{"*"},
AllowedUserIDs: []string{"*"},
CNValidations: []string{"disabled"},
GenerateLease: new(bool),
// If adding new fields to be read, update the field list within addSignVerbatimRoleFields
KeyUsage: data.Get("key_usage").([]string),
ExtKeyUsage: data.Get("ext_key_usage").([]string),
ExtKeyUsageOIDs: data.Get("ext_key_usage_oids").([]string),
SignatureBits: data.Get("signature_bits").(int),
UsePSS: data.Get("use_pss").(bool),
}
*entry.AllowWildcardCertificates = true
*entry.GenerateLease = false
if role != nil {
if role.TTL > 0 {
entry.TTL = role.TTL
}
if role.MaxTTL > 0 {
entry.MaxTTL = role.MaxTTL
}
if role.GenerateLease != nil {
*entry.GenerateLease = *role.GenerateLease
}
if role.NotBeforeDuration > 0 {
entry.NotBeforeDuration = role.NotBeforeDuration
}
entry.NoStore = role.NoStore
entry.Issuer = role.Issuer
}
if len(entry.Issuer) == 0 {
entry.Issuer = defaultRef
}
return entry
}

File diff suppressed because it is too large Load Diff

View File

@@ -10,6 +10,7 @@ import (
"strings"
"testing"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -97,7 +98,7 @@ func TestPki_FetchCertBySerial(t *testing.T) {
// order-preserving way.
func TestPki_MultipleOUs(t *testing.T) {
t.Parallel()
var b backend
b, _ := CreateBackendWithStorage(t)
fields := addCACommonFields(map[string]*framework.FieldSchema{})
apiData := &framework.FieldData{
@@ -109,12 +110,12 @@ func TestPki_MultipleOUs(t *testing.T) {
}
input := &inputBundle{
apiData: apiData,
role: &roleEntry{
role: &issuing.RoleEntry{
MaxTTL: 3600,
OU: []string{"Z", "E", "V"},
},
}
cb, _, err := generateCreationBundle(&b, input, nil, nil)
cb, _, err := generateCreationBundle(b, input, nil, nil)
if err != nil {
t.Fatalf("Error: %v", err)
}
@@ -129,7 +130,7 @@ func TestPki_MultipleOUs(t *testing.T) {
func TestPki_PermitFQDNs(t *testing.T) {
t.Parallel()
var b backend
b, _ := CreateBackendWithStorage(t)
fields := addCACommonFields(map[string]*framework.FieldSchema{})
cases := map[string]struct {
@@ -146,7 +147,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600,
},
},
role: &roleEntry{
role: &issuing.RoleEntry{
AllowAnyName: true,
MaxTTL: 3600,
EnforceHostnames: true,
@@ -165,7 +166,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600,
},
},
role: &roleEntry{
role: &issuing.RoleEntry{
AllowedDomains: []string{"example.net", "EXAMPLE.COM"},
AllowBareDomains: true,
MaxTTL: 3600,
@@ -183,7 +184,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600,
},
},
role: &roleEntry{
role: &issuing.RoleEntry{
AllowedDomains: []string{"example.com", "*.Example.com"},
AllowGlobDomains: true,
MaxTTL: 3600,
@@ -201,7 +202,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600,
},
},
role: &roleEntry{
role: &issuing.RoleEntry{
AllowedDomains: []string{"test@testemail.com"},
AllowBareDomains: true,
MaxTTL: 3600,
@@ -219,7 +220,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600,
},
},
role: &roleEntry{
role: &issuing.RoleEntry{
AllowedDomains: []string{"testemail.com"},
AllowBareDomains: true,
MaxTTL: 3600,
@@ -234,7 +235,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
name := name
testCase := testCase
t.Run(name, func(t *testing.T) {
cb, _, err := generateCreationBundle(&b, testCase.input, nil, nil)
cb, _, err := generateCreationBundle(b, testCase.input, nil, nil)
if err != nil {
t.Fatalf("Error: %v", err)
}

View File

@@ -16,6 +16,7 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -575,7 +576,7 @@ func (c CBIssueLeaf) RevokeLeaf(t testing.TB, b *backend, s logical.Storage, kno
if resp == nil {
t.Fatalf("failed to read default issuer config: nil response")
}
defaultID := resp.Data["default"].(issuerID).String()
defaultID := resp.Data["default"].(issuing.IssuerID).String()
c.Issuer = defaultID
issuer = nil
}
@@ -637,7 +638,7 @@ func (c CBIssueLeaf) Run(t testing.TB, b *backend, s logical.Storage, knownKeys
if resp == nil {
t.Fatalf("failed to read default issuer config: nil response")
}
defaultID := resp.Data["default"].(issuerID).String()
defaultID := resp.Data["default"].(issuing.IssuerID).String()
resp, err = CBRead(b, s, "issuer/"+c.Issuer)
if err != nil {
@@ -646,7 +647,7 @@ func (c CBIssueLeaf) Run(t testing.TB, b *backend, s logical.Storage, knownKeys
if resp == nil {
t.Fatalf("failed to read issuer %v: nil response", c.Issuer)
}
ourID := resp.Data["issuer_id"].(issuerID).String()
ourID := resp.Data["issuer_id"].(issuing.IssuerID).String()
areDefault := ourID == defaultID
for _, usage := range []string{"read-only", "crl-signing", "issuing-certificates", "issuing-certificates,crl-signing"} {

View File

@@ -9,10 +9,11 @@ import (
"fmt"
"sort"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/errutil"
)
func prettyIssuer(issuerIdEntryMap map[issuerID]*issuerEntry, issuer issuerID) string {
func prettyIssuer(issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry, issuer issuing.IssuerID) string {
if entry, ok := issuerIdEntryMap[issuer]; ok && len(entry.Name) > 0 {
return "[id:" + string(issuer) + "/name:" + entry.Name + "]"
}
@@ -20,7 +21,7 @@ func prettyIssuer(issuerIdEntryMap map[issuerID]*issuerEntry, issuer issuerID) s
return "[" + string(issuer) + "]"
}
func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* optional */) error {
func (sc *storageContext) rebuildIssuersChains(referenceCert *issuing.IssuerEntry /* optional */) error {
// This function rebuilds the CAChain field of all known issuers. This
// function should usually be invoked when a new issuer is added to the
// pool of issuers.
@@ -116,22 +117,22 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// fourth maps that certificate back to the other issuers with that
// subject (note the keyword _other_: we'll exclude self-loops here) --
// either via a parent or child relationship.
issuerIdEntryMap := make(map[issuerID]*issuerEntry, len(issuers))
issuerIdCertMap := make(map[issuerID]*x509.Certificate, len(issuers))
issuerIdParentsMap := make(map[issuerID][]issuerID, len(issuers))
issuerIdChildrenMap := make(map[issuerID][]issuerID, len(issuers))
issuerIdEntryMap := make(map[issuing.IssuerID]*issuing.IssuerEntry, len(issuers))
issuerIdCertMap := make(map[issuing.IssuerID]*x509.Certificate, len(issuers))
issuerIdParentsMap := make(map[issuing.IssuerID][]issuing.IssuerID, len(issuers))
issuerIdChildrenMap := make(map[issuing.IssuerID][]issuing.IssuerID, len(issuers))
// For every known issuer, we map that subject back to the id of issuers
// containing that subject. This lets us build our issuerID -> parents
// containing that subject. This lets us build our IssuerID -> parents
// mapping efficiently. Worst case we'll have a single linear chain where
// every entry has a distinct subject.
subjectIssuerIdsMap := make(map[string][]issuerID, len(issuers))
subjectIssuerIdsMap := make(map[string][]issuing.IssuerID, len(issuers))
// First, read every issuer entry from storage. We'll propagate entries
// to three of the maps here: all but issuerIdParentsMap and
// issuerIdChildrenMap, which we'll do in a second pass.
for _, identifier := range issuers {
var stored *issuerEntry
var stored *issuing.IssuerEntry
// When the reference issuer is provided and matches this identifier,
// prefer the updated reference copy instead.
@@ -261,8 +262,8 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// manually building their chain prior to starting the topographical sort.
//
// This thus runs in O(|V| + |E|) -> O(n^2) in the number of issuers.
processedIssuers := make(map[issuerID]bool, len(issuers))
toVisit := make([]issuerID, 0, len(issuers))
processedIssuers := make(map[issuing.IssuerID]bool, len(issuers))
toVisit := make([]issuing.IssuerID, 0, len(issuers))
// Handle any explicitly constructed certificate chains. Here, we don't
// validate much what the user provides; if they provide since-deleted
@@ -323,7 +324,7 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// ensure we don't accidentally infinite-loop (if we introduce a bug).
maxVisitCount := len(issuers)*len(issuers)*len(issuers) + 100
for len(toVisit) > 0 && maxVisitCount >= 0 {
var issuer issuerID
var issuer issuing.IssuerID
issuer, toVisit = toVisit[0], toVisit[1:]
// If (and only if) we're presently starved for next nodes to visit,
@@ -387,8 +388,8 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// However, if you directly step onto the cross-signed, now you're
// taken in an alternative direction (via its chain), and must
// revisit any roots later.
var roots []issuerID
var intermediates []issuerID
var roots []issuing.IssuerID
var intermediates []issuing.IssuerID
for _, parentCertId := range parentCerts {
if bytes.Equal(issuerIdCertMap[parentCertId].RawSubject, issuerIdCertMap[parentCertId].RawIssuer) {
roots = append(roots, parentCertId)
@@ -470,7 +471,7 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
return nil
}
func addToChainIfNotExisting(includedParentCerts map[string]bool, entry *issuerEntry, certToAdd string) {
func addToChainIfNotExisting(includedParentCerts map[string]bool, entry *issuing.IssuerEntry, certToAdd string) {
included, ok := includedParentCerts[certToAdd]
if ok && included {
return
@@ -481,15 +482,15 @@ func addToChainIfNotExisting(includedParentCerts map[string]bool, entry *issuerE
}
func processAnyCliqueOrCycle(
issuers []issuerID,
processedIssuers map[issuerID]bool,
toVisit []issuerID,
issuerIdEntryMap map[issuerID]*issuerEntry,
issuerIdCertMap map[issuerID]*x509.Certificate,
issuerIdParentsMap map[issuerID][]issuerID,
issuerIdChildrenMap map[issuerID][]issuerID,
subjectIssuerIdsMap map[string][]issuerID,
) ([]issuerID /* toVisit */, error) {
issuers []issuing.IssuerID,
processedIssuers map[issuing.IssuerID]bool,
toVisit []issuing.IssuerID,
issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
issuerIdParentsMap map[issuing.IssuerID][]issuing.IssuerID,
issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
subjectIssuerIdsMap map[string][]issuing.IssuerID,
) ([]issuing.IssuerID /* toVisit */, error) {
// Topological sort really only works on directed acyclic graphs (DAGs).
// But a pool of arbitrary (issuer) certificates are actually neither!
// This pool could contain both cliques and cycles. Because this could
@@ -550,15 +551,15 @@ func processAnyCliqueOrCycle(
// Finally -- it isn't enough to consider this chain in isolation
// either. We need to consider _all_ parents and ensure they've been
// processed before processing this closure.
var cliques [][]issuerID
var cycles [][]issuerID
closure := make(map[issuerID]bool)
var cliques [][]issuing.IssuerID
var cycles [][]issuing.IssuerID
closure := make(map[issuing.IssuerID]bool)
var cliquesToProcess []issuerID
var cliquesToProcess []issuing.IssuerID
cliquesToProcess = append(cliquesToProcess, issuer)
for len(cliquesToProcess) > 0 {
var node issuerID
var node issuing.IssuerID
node, cliquesToProcess = cliquesToProcess[0], cliquesToProcess[1:]
// Skip potential clique nodes which have already been processed
@@ -753,7 +754,7 @@ func processAnyCliqueOrCycle(
return nil, err
}
closure := make(map[issuerID]bool)
closure := make(map[issuing.IssuerID]bool)
for _, cycle := range cycles {
for _, node := range cycle {
closure[node] = true
@@ -811,14 +812,14 @@ func processAnyCliqueOrCycle(
}
func findAllCliques(
processedIssuers map[issuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate,
subjectIssuerIdsMap map[string][]issuerID,
issuers []issuerID,
) ([][]issuerID, map[issuerID]int, []issuerID, error) {
var allCliques [][]issuerID
issuerIdCliqueMap := make(map[issuerID]int)
var allCliqueNodes []issuerID
processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
subjectIssuerIdsMap map[string][]issuing.IssuerID,
issuers []issuing.IssuerID,
) ([][]issuing.IssuerID, map[issuing.IssuerID]int, []issuing.IssuerID, error) {
var allCliques [][]issuing.IssuerID
issuerIdCliqueMap := make(map[issuing.IssuerID]int)
var allCliqueNodes []issuing.IssuerID
for _, node := range issuers {
// Check if the node has already been visited...
@@ -859,11 +860,11 @@ func findAllCliques(
}
func isOnReissuedClique(
processedIssuers map[issuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate,
subjectIssuerIdsMap map[string][]issuerID,
node issuerID,
) ([]issuerID, error) {
processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
subjectIssuerIdsMap map[string][]issuing.IssuerID,
node issuing.IssuerID,
) ([]issuing.IssuerID, error) {
// Finding max cliques in arbitrary graphs is a nearly pathological
// problem, usually left to the realm of SAT solvers and NP-Complete
// theoretical.
@@ -891,7 +892,7 @@ func isOnReissuedClique(
// under this reissued clique detection code).
//
// What does this mean for our algorithm? A simple greedy search is
// sufficient. If we index our certificates by subject -> issuerID
// sufficient. If we index our certificates by subject -> IssuerID
// (and cache its value across calls, which we've already done for
// building the parent/child relationship), we can find all other issuers
// with the same public key and subject as the existing node fairly
@@ -925,7 +926,7 @@ func isOnReissuedClique(
// condition (the subject half), so validate they match the other half
// (the issuer half) and the second condition. For node (which is
// included in candidates), the condition should vacuously hold.
var clique []issuerID
var clique []issuing.IssuerID
for _, candidate := range candidates {
// Skip already processed nodes, even if they could be clique
// candidates. We'll treat them as any other (already processed)
@@ -957,7 +958,7 @@ func isOnReissuedClique(
return clique, nil
}
func containsIssuer(collection []issuerID, target issuerID) bool {
func containsIssuer(collection []issuing.IssuerID, target issuing.IssuerID) bool {
if len(collection) == 0 {
return false
}
@@ -971,7 +972,7 @@ func containsIssuer(collection []issuerID, target issuerID) bool {
return false
}
func appendCycleIfNotExisting(knownCycles [][]issuerID, candidate []issuerID) [][]issuerID {
func appendCycleIfNotExisting(knownCycles [][]issuing.IssuerID, candidate []issuing.IssuerID) [][]issuing.IssuerID {
// There's two ways to do cycle detection: canonicalize the cycles,
// rewriting them to have the least (or max) element first or just
// brute force the detection.
@@ -1007,7 +1008,7 @@ func appendCycleIfNotExisting(knownCycles [][]issuerID, candidate []issuerID) []
return knownCycles
}
func canonicalizeCycle(cycle []issuerID) []issuerID {
func canonicalizeCycle(cycle []issuing.IssuerID) []issuing.IssuerID {
// Find the minimum value and put it at the head, keeping the relative
// ordering the same.
minIndex := 0
@@ -1026,11 +1027,11 @@ func canonicalizeCycle(cycle []issuerID) []issuerID {
}
func findCyclesNearClique(
processedIssuers map[issuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate,
issuerIdChildrenMap map[issuerID][]issuerID,
cliqueNodes []issuerID,
) ([][]issuerID, error) {
processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
cliqueNodes []issuing.IssuerID,
) ([][]issuing.IssuerID, error) {
// When we have a reissued clique, we need to find all cycles next to it.
// Presumably, because they all have non-empty parents, they should not
// have been visited yet. We further know that (because we're exploring
@@ -1046,7 +1047,7 @@ func findCyclesNearClique(
// Copy the clique nodes as excluded nodes; we'll avoid exploring cycles
// which have parents that have been already explored.
excludeNodes := cliqueNodes[:]
var knownCycles [][]issuerID
var knownCycles [][]issuing.IssuerID
// We know the node has at least one child, since the clique is non-empty.
for _, child := range issuerIdChildrenMap[cliqueNode] {
@@ -1081,12 +1082,12 @@ func findCyclesNearClique(
}
func findAllCyclesWithNode(
processedIssuers map[issuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate,
issuerIdChildrenMap map[issuerID][]issuerID,
source issuerID,
exclude []issuerID,
) ([][]issuerID, error) {
processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
source issuing.IssuerID,
exclude []issuing.IssuerID,
) ([][]issuing.IssuerID, error) {
// We wish to find all cycles involving this particular node and report
// the corresponding paths. This is a full-graph traversal (excluding
// certain paths) as we're not just checking if a cycle occurred, but
@@ -1096,28 +1097,28 @@ func findAllCyclesWithNode(
maxCycleSize := 8
// Whether we've visited any given node.
cycleVisited := make(map[issuerID]bool)
visitCounts := make(map[issuerID]int)
parentCounts := make(map[issuerID]map[issuerID]bool)
cycleVisited := make(map[issuing.IssuerID]bool)
visitCounts := make(map[issuing.IssuerID]int)
parentCounts := make(map[issuing.IssuerID]map[issuing.IssuerID]bool)
// Paths to the specified node. Some of these might be cycles.
pathsTo := make(map[issuerID][][]issuerID)
pathsTo := make(map[issuing.IssuerID][][]issuing.IssuerID)
// Nodes to visit.
var visitQueue []issuerID
var visitQueue []issuing.IssuerID
// Add the source node to start. In order to set up the paths to a
// given node, we seed pathsTo with the single path involving just
// this node
visitQueue = append(visitQueue, source)
pathsTo[source] = [][]issuerID{{source}}
pathsTo[source] = [][]issuing.IssuerID{{source}}
// Begin building paths.
//
// Loop invariant:
// pathTo[x] contains valid paths to reach this node, from source.
for len(visitQueue) > 0 {
var current issuerID
var current issuing.IssuerID
current, visitQueue = visitQueue[0], visitQueue[1:]
// If we've already processed this node, we have a cycle. Skip this
@@ -1162,7 +1163,7 @@ func findAllCyclesWithNode(
// Track this parent->child relationship to know when to exit.
setOfParents, ok := parentCounts[child]
if !ok {
setOfParents = make(map[issuerID]bool)
setOfParents = make(map[issuing.IssuerID]bool)
parentCounts[child] = setOfParents
}
_, existingParent := setOfParents[current]
@@ -1179,7 +1180,7 @@ func findAllCyclesWithNode(
// externally with an existing path).
addedPath := false
if _, ok := pathsTo[child]; !ok {
pathsTo[child] = make([][]issuerID, 0)
pathsTo[child] = make([][]issuing.IssuerID, 0)
}
for _, path := range pathsTo[current] {
@@ -1204,7 +1205,7 @@ func findAllCyclesWithNode(
return nil, errutil.InternalError{Err: fmt.Sprintf("Error updating certificate path: path of length %d is too long", len(path))}
}
// Make sure to deep copy the path.
newPath := make([]issuerID, 0, len(path)+1)
newPath := make([]issuing.IssuerID, 0, len(path)+1)
newPath = append(newPath, path...)
newPath = append(newPath, child)
@@ -1249,7 +1250,7 @@ func findAllCyclesWithNode(
// Ok, we've now exited from our loop. Any cycles would've been detected
// and their paths recorded in pathsTo. Now we can iterate over these
// (starting a source), clean them up and validate them.
var cycles [][]issuerID
var cycles [][]issuing.IssuerID
for _, cycle := range pathsTo[source] {
// Skip the trivial cycle.
if len(cycle) == 1 && cycle[0] == source {
@@ -1287,8 +1288,8 @@ func findAllCyclesWithNode(
return cycles, nil
}
func reversedCycle(cycle []issuerID) []issuerID {
var result []issuerID
func reversedCycle(cycle []issuing.IssuerID) []issuing.IssuerID {
var result []issuing.IssuerID
for index := len(cycle) - 1; index >= 0; index-- {
result = append(result, cycle[index])
}
@@ -1297,11 +1298,11 @@ func reversedCycle(cycle []issuerID) []issuerID {
}
func computeParentsFromClosure(
processedIssuers map[issuerID]bool,
issuerIdParentsMap map[issuerID][]issuerID,
closure map[issuerID]bool,
) (map[issuerID]bool, bool) {
parents := make(map[issuerID]bool)
processedIssuers map[issuing.IssuerID]bool,
issuerIdParentsMap map[issuing.IssuerID][]issuing.IssuerID,
closure map[issuing.IssuerID]bool,
) (map[issuing.IssuerID]bool, bool) {
parents := make(map[issuing.IssuerID]bool)
for node := range closure {
nodeParents, ok := issuerIdParentsMap[node]
if !ok {
@@ -1326,11 +1327,11 @@ func computeParentsFromClosure(
}
func addNodeCertsToEntry(
issuerIdEntryMap map[issuerID]*issuerEntry,
issuerIdChildrenMap map[issuerID][]issuerID,
issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
includedParentCerts map[string]bool,
entry *issuerEntry,
issuersCollection ...[]issuerID,
entry *issuing.IssuerEntry,
issuersCollection ...[]issuing.IssuerID,
) {
for _, collection := range issuersCollection {
// Find a starting point into this collection such that it verifies
@@ -1369,10 +1370,10 @@ func addNodeCertsToEntry(
}
func addParentChainsToEntry(
issuerIdEntryMap map[issuerID]*issuerEntry,
issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
includedParentCerts map[string]bool,
entry *issuerEntry,
parents map[issuerID]bool,
entry *issuing.IssuerEntry,
parents map[issuing.IssuerID]bool,
) {
for parent := range parents {
nodeEntry := issuerIdEntryMap[parent]

View File

@@ -9,13 +9,14 @@ import (
"crypto/x509"
"fmt"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
// issueAcmeCertUsingCieps based on the passed in ACME information, perform a CIEPS request/response
func issueAcmeCertUsingCieps(_ *backend, _ *acmeContext, _ *logical.Request, _ *framework.FieldData, _ *jwsCtx, _ *acmeAccount, _ *acmeOrder, _ *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuerID, error) {
func issueAcmeCertUsingCieps(_ *backend, _ *acmeContext, _ *logical.Request, _ *framework.FieldData, _ *jwsCtx, _ *acmeAccount, _ *acmeOrder, _ *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuing.IssuerID, error) {
return nil, "", fmt.Errorf("cieps is an enterprise only feature")
}

View File

@@ -4,9 +4,9 @@
package pki
import (
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
func (sc *storageContext) isDefaultKeySet() (bool, error) {
@@ -27,14 +27,14 @@ func (sc *storageContext) isDefaultIssuerSet() (bool, error) {
return strings.TrimSpace(config.DefaultIssuerId.String()) != "", nil
}
func (sc *storageContext) updateDefaultKeyId(id keyID) error {
func (sc *storageContext) updateDefaultKeyId(id issuing.KeyID) error {
config, err := sc.getKeysConfig()
if err != nil {
return err
}
if config.DefaultKeyId != id {
return sc.setKeysConfig(&keyConfigEntry{
return sc.setKeysConfig(&issuing.KeyConfigEntry{
DefaultKeyId: id,
})
}
@@ -42,7 +42,7 @@ func (sc *storageContext) updateDefaultKeyId(id keyID) error {
return nil
}
func (sc *storageContext) updateDefaultIssuerId(id issuerID) error {
func (sc *storageContext) updateDefaultIssuerId(id issuing.IssuerID) error {
config, err := sc.getIssuersConfig()
if err != nil {
return err
@@ -55,67 +55,3 @@ func (sc *storageContext) updateDefaultIssuerId(id issuerID) error {
return nil
}
func (sc *storageContext) changeDefaultIssuerTimestamps(oldDefault issuerID, newDefault issuerID) error {
if newDefault == oldDefault {
return nil
}
now := time.Now().UTC()
// When the default issuer changes, we need to modify four
// pieces of information:
//
// 1. The old default issuer's modification time, as it no
// longer works for the /cert/ca path.
// 2. The new default issuer's modification time, as it now
// works for the /cert/ca path.
// 3. & 4. Both issuer's CRLs, as they behave the same, under
// the /cert/crl path!
for _, thisId := range []issuerID{oldDefault, newDefault} {
if len(thisId) == 0 {
continue
}
// 1 & 2 above.
issuer, err := sc.fetchIssuerById(thisId)
if err != nil {
// Due to the lack of transactions, if we deleted the default
// issuer (successfully), but the subsequent issuer config write
// (to clear the default issuer's old id) failed, we might have
// an inconsistent config. If we later hit this loop (and flush
// these timestamps again -- perhaps because the operator
// selected a new default), we'd have erred out here, because
// the since-deleted default issuer doesn't exist. In this case,
// skip the issuer instead of bailing.
err := fmt.Errorf("unable to update issuer (%v)'s modification time: error fetching issuer: %w", thisId, err)
if strings.Contains(err.Error(), "does not exist") {
sc.Backend.Logger().Warn(err.Error())
continue
}
return err
}
issuer.LastModified = now
err = sc.writeIssuer(issuer)
if err != nil {
return fmt.Errorf("unable to update issuer (%v)'s modification time: error persisting issuer: %w", thisId, err)
}
}
// Fetch and update the internalCRLConfigEntry (3&4).
cfg, err := sc.getLocalCRLConfig()
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error fetching local CRL config: %w", err)
}
cfg.LastModified = now
cfg.DeltaLastModified = now
err = sc.setLocalCRLConfig(cfg)
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error persisting local CRL config: %w", err)
}
return nil
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/helper/constants"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
@@ -1063,7 +1064,7 @@ func TestAutoRebuild(t *testing.T) {
var revInfo revocationInfo
err = json.Unmarshal([]byte(revEntryValue), &revInfo)
require.NoError(t, err)
require.Equal(t, revInfo.CertificateIssuer, issuerID(rootIssuer))
require.Equal(t, revInfo.CertificateIssuer, issuing.IssuerID(rootIssuer))
// New serial should not appear on CRL.
crl = getCrlCertificateList(t, client, "pki")
@@ -1201,7 +1202,7 @@ func TestTidyIssuerAssociation(t *testing.T) {
require.NotEmpty(t, resp.Data["certificate"])
require.NotEmpty(t, resp.Data["issuer_id"])
rootCert := resp.Data["certificate"].(string)
rootID := resp.Data["issuer_id"].(issuerID)
rootID := resp.Data["issuer_id"].(issuing.IssuerID)
// Create a role for issuance.
_, err = CBWrite(b, s, "roles/local-testing", map[string]interface{}{
@@ -1495,9 +1496,9 @@ func TestCRLIssuerRemoval(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, resp)
key := string(resp.Data["key_id"].(keyID))
key := string(resp.Data["key_id"].(issuing.KeyID))
keyIDs = append(keyIDs, key)
issuer := string(resp.Data["issuer_id"].(issuerID))
issuer := string(resp.Data["issuer_id"].(issuing.IssuerID))
issuerIDs = append(issuerIDs, issuer)
}
_, err = CBRead(b, s, "crl/rotate")

View File

@@ -12,14 +12,15 @@ import (
"math/big"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
atomic2 "go.uber.org/atomic"
)
const (
@@ -38,10 +39,10 @@ const (
)
type revocationInfo struct {
CertificateBytes []byte `json:"certificate_bytes"`
RevocationTime int64 `json:"revocation_time"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"`
CertificateIssuer issuerID `json:"issuer_id"`
CertificateBytes []byte `json:"certificate_bytes"`
RevocationTime int64 `json:"revocation_time"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"`
CertificateIssuer issuing.IssuerID `json:"issuer_id"`
}
type revocationRequest struct {
@@ -81,31 +82,31 @@ type (
}
)
// crlBuilder is gatekeeper for controlling various read/write operations to the storage of the CRL.
// CrlBuilder is gatekeeper for controlling various read/write operations to the storage of the CRL.
// The extra complexity arises from secondary performance clusters seeing various writes to its storage
// without the actual API calls. During the storage invalidation process, we do not have the required state
// to actually rebuild the CRLs, so we need to schedule it in a deferred fashion. This allows either
// read or write calls to perform the operation if required, or have the flag reset upon a write operation
//
// The CRL builder also tracks the revocation configuration.
type crlBuilder struct {
type CrlBuilder struct {
_builder sync.Mutex
forceRebuild *atomic2.Bool
forceRebuild *atomic.Bool
canRebuild bool
lastDeltaRebuildCheck time.Time
_config sync.RWMutex
dirty *atomic2.Bool
dirty *atomic.Bool
config crlConfig
haveInitializedConfig bool
// Whether to invalidate our LastModifiedTime due to write on the
// global issuance config.
invalidate *atomic2.Bool
invalidate *atomic.Bool
// Global revocation queue entries get accepted by the invalidate func
// and passed to the crlBuilder for processing.
haveInitializedQueue *atomic2.Bool
// and passed to the CrlBuilder for processing.
haveInitializedQueue *atomic.Bool
revQueue *revocationQueue
removalQueue *revocationQueue
crossQueue *revocationQueue
@@ -116,29 +117,31 @@ const (
_enforceForceFlag = false
)
func newCRLBuilder(canRebuild bool) *crlBuilder {
return &crlBuilder{
forceRebuild: atomic2.NewBool(false),
func newCRLBuilder(canRebuild bool) *CrlBuilder {
builder := &CrlBuilder{
forceRebuild: &atomic.Bool{},
canRebuild: canRebuild,
// Set the last delta rebuild window to now, delaying the first delta
// rebuild by the first rebuild period to give us some time on startup
// to stabilize.
lastDeltaRebuildCheck: time.Now(),
dirty: atomic2.NewBool(true),
dirty: &atomic.Bool{},
config: defaultCrlConfig,
invalidate: atomic2.NewBool(false),
haveInitializedQueue: atomic2.NewBool(false),
invalidate: &atomic.Bool{},
haveInitializedQueue: &atomic.Bool{},
revQueue: newRevocationQueue(),
removalQueue: newRevocationQueue(),
crossQueue: newRevocationQueue(),
}
builder.dirty.Store(true)
return builder
}
func (cb *crlBuilder) markConfigDirty() {
func (cb *CrlBuilder) markConfigDirty() {
cb.dirty.Store(true)
}
func (cb *crlBuilder) reloadConfigIfRequired(sc *storageContext) error {
func (cb *CrlBuilder) reloadConfigIfRequired(sc *storageContext) error {
if cb.dirty.Load() {
// Acquire a write lock.
cb._config.Lock()
@@ -180,12 +183,12 @@ func (cb *crlBuilder) reloadConfigIfRequired(sc *storageContext) error {
return nil
}
func (cb *crlBuilder) notifyOnConfigChange(sc *storageContext, priorConfig crlConfig, newConfig crlConfig) {
func (cb *CrlBuilder) notifyOnConfigChange(sc *storageContext, priorConfig crlConfig, newConfig crlConfig) {
// If you need to hook into a CRL configuration change across different server types
// such as primary clusters as well as performance replicas, it is easier to do here than
// in two places (API layer and in invalidateFunc)
if priorConfig.UnifiedCRL != newConfig.UnifiedCRL && newConfig.UnifiedCRL {
sc.Backend.unifiedTransferStatus.forceRun()
sc.Backend.GetUnifiedTransferStatus().forceRun()
}
if priorConfig.UseGlobalQueue != newConfig.UseGlobalQueue && newConfig.UseGlobalQueue {
@@ -193,7 +196,7 @@ func (cb *crlBuilder) notifyOnConfigChange(sc *storageContext, priorConfig crlCo
}
}
func (cb *crlBuilder) getConfigWithUpdate(sc *storageContext) (*crlConfig, error) {
func (cb *CrlBuilder) getConfigWithUpdate(sc *storageContext) (*crlConfig, error) {
// Config may mutate immediately after accessing, but will be freshly
// fetched if necessary.
if err := cb.reloadConfigIfRequired(sc); err != nil {
@@ -207,12 +210,12 @@ func (cb *crlBuilder) getConfigWithUpdate(sc *storageContext) (*crlConfig, error
return &configCopy, nil
}
func (cb *crlBuilder) getConfigWithForcedUpdate(sc *storageContext) (*crlConfig, error) {
func (cb *CrlBuilder) getConfigWithForcedUpdate(sc *storageContext) (*crlConfig, error) {
cb.markConfigDirty()
return cb.getConfigWithUpdate(sc)
}
func (cb *crlBuilder) writeConfig(sc *storageContext, config *crlConfig) (*crlConfig, error) {
func (cb *CrlBuilder) writeConfig(sc *storageContext, config *crlConfig) (*crlConfig, error) {
cb._config.Lock()
defer cb._config.Unlock()
@@ -242,7 +245,7 @@ func (cb *crlBuilder) writeConfig(sc *storageContext, config *crlConfig) (*crlCo
return config, nil
}
func (cb *crlBuilder) checkForAutoRebuild(sc *storageContext) error {
func (cb *CrlBuilder) checkForAutoRebuild(sc *storageContext) error {
cfg, err := cb.getConfigWithUpdate(sc)
if err != nil {
return err
@@ -307,14 +310,14 @@ func (cb *crlBuilder) checkForAutoRebuild(sc *storageContext) error {
}
// Mark the internal LastModifiedTime tracker invalid.
func (cb *crlBuilder) invalidateCRLBuildTime() {
func (cb *CrlBuilder) invalidateCRLBuildTime() {
cb.invalidate.Store(true)
}
// Update the config to mark the modified CRL. See note in
// updateDefaultIssuerId about why this is necessary.
func (cb *crlBuilder) flushCRLBuildTimeInvalidation(sc *storageContext) error {
if cb.invalidate.CAS(true, false) {
func (cb *CrlBuilder) flushCRLBuildTimeInvalidation(sc *storageContext) error {
if cb.invalidate.CompareAndSwap(true, false) {
// Flush out our invalidation.
cfg, err := sc.getLocalCRLConfig()
if err != nil {
@@ -336,7 +339,7 @@ func (cb *crlBuilder) flushCRLBuildTimeInvalidation(sc *storageContext) error {
// rebuildIfForced is to be called by readers or periodic functions that might need to trigger
// a refresh of the CRL before the read occurs.
func (cb *crlBuilder) rebuildIfForced(sc *storageContext) ([]string, error) {
func (cb *CrlBuilder) rebuildIfForced(sc *storageContext) ([]string, error) {
if cb.forceRebuild.Load() {
return cb._doRebuild(sc, true, _enforceForceFlag)
}
@@ -345,12 +348,12 @@ func (cb *crlBuilder) rebuildIfForced(sc *storageContext) ([]string, error) {
}
// rebuild is to be called by various write apis that know the CRL is to be updated and can be now.
func (cb *crlBuilder) rebuild(sc *storageContext, forceNew bool) ([]string, error) {
func (cb *CrlBuilder) rebuild(sc *storageContext, forceNew bool) ([]string, error) {
return cb._doRebuild(sc, forceNew, _ignoreForceFlag)
}
// requestRebuildIfActiveNode will schedule a rebuild of the CRL from the next read or write api call assuming we are the active node of a cluster
func (cb *crlBuilder) requestRebuildIfActiveNode(b *backend) {
func (cb *CrlBuilder) requestRebuildIfActiveNode(b *backend) {
// Only schedule us on active nodes, as the active node is the only node that can rebuild/write the CRL.
// Note 1: The CRL is cluster specific, so this does need to run on the active node of a performance secondary cluster.
// Note 2: This is called by the storage invalidation function, so it should not block.
@@ -364,7 +367,7 @@ func (cb *crlBuilder) requestRebuildIfActiveNode(b *backend) {
cb.forceRebuild.Store(true)
}
func (cb *crlBuilder) _doRebuild(sc *storageContext, forceNew bool, ignoreForceFlag bool) ([]string, error) {
func (cb *CrlBuilder) _doRebuild(sc *storageContext, forceNew bool, ignoreForceFlag bool) ([]string, error) {
cb._builder.Lock()
defer cb._builder.Unlock()
// Re-read the lock in case someone beat us to the punch between the previous load op.
@@ -384,7 +387,7 @@ func (cb *crlBuilder) _doRebuild(sc *storageContext, forceNew bool, ignoreForceF
return nil, nil
}
func (cb *crlBuilder) _getPresentDeltaWALForClearing(sc *storageContext, path string) ([]string, error) {
func (cb *CrlBuilder) _getPresentDeltaWALForClearing(sc *storageContext, path string) ([]string, error) {
// Clearing of the delta WAL occurs after a new complete CRL has been built.
walSerials, err := sc.Storage.List(sc.Context, path)
if err != nil {
@@ -397,11 +400,11 @@ func (cb *crlBuilder) _getPresentDeltaWALForClearing(sc *storageContext, path st
return walSerials, nil
}
func (cb *crlBuilder) getPresentLocalDeltaWALForClearing(sc *storageContext) ([]string, error) {
func (cb *CrlBuilder) getPresentLocalDeltaWALForClearing(sc *storageContext) ([]string, error) {
return cb._getPresentDeltaWALForClearing(sc, localDeltaWALPath)
}
func (cb *crlBuilder) getPresentUnifiedDeltaWALForClearing(sc *storageContext) ([]string, error) {
func (cb *CrlBuilder) getPresentUnifiedDeltaWALForClearing(sc *storageContext) ([]string, error) {
walClusters, err := sc.Storage.List(sc.Context, unifiedDeltaWALPrefix)
if err != nil {
return nil, fmt.Errorf("error fetching list of clusters with delta WAL entries: %w", err)
@@ -426,7 +429,7 @@ func (cb *crlBuilder) getPresentUnifiedDeltaWALForClearing(sc *storageContext) (
return allPaths, nil
}
func (cb *crlBuilder) _clearDeltaWAL(sc *storageContext, walSerials []string, path string) error {
func (cb *CrlBuilder) _clearDeltaWAL(sc *storageContext, walSerials []string, path string) error {
// Clearing of the delta WAL occurs after a new complete CRL has been built.
for _, serial := range walSerials {
// Don't remove our special entries!
@@ -442,15 +445,15 @@ func (cb *crlBuilder) _clearDeltaWAL(sc *storageContext, walSerials []string, pa
return nil
}
func (cb *crlBuilder) clearLocalDeltaWAL(sc *storageContext, walSerials []string) error {
func (cb *CrlBuilder) clearLocalDeltaWAL(sc *storageContext, walSerials []string) error {
return cb._clearDeltaWAL(sc, walSerials, localDeltaWALPath)
}
func (cb *crlBuilder) clearUnifiedDeltaWAL(sc *storageContext, walSerials []string) error {
func (cb *CrlBuilder) clearUnifiedDeltaWAL(sc *storageContext, walSerials []string) error {
return cb._clearDeltaWAL(sc, walSerials, unifiedDeltaWALPrefix)
}
func (cb *crlBuilder) rebuildDeltaCRLsIfForced(sc *storageContext, override bool) ([]string, error) {
func (cb *CrlBuilder) rebuildDeltaCRLsIfForced(sc *storageContext, override bool) ([]string, error) {
// Delta CRLs use the same expiry duration as the complete CRL. Because
// we always rebuild the complete CRL and then the delta CRL, we can
// be assured that the delta CRL always expires after a complete CRL,
@@ -516,7 +519,7 @@ func (cb *crlBuilder) rebuildDeltaCRLsIfForced(sc *storageContext, override bool
return cb.rebuildDeltaCRLsHoldingLock(sc, false)
}
func (cb *crlBuilder) _shouldRebuildLocalCRLs(sc *storageContext, override bool) (bool, error) {
func (cb *CrlBuilder) _shouldRebuildLocalCRLs(sc *storageContext, override bool) (bool, error) {
// Fetch two storage entries to see if we actually need to do this
// rebuild, given we're within the window.
lastWALEntry, err := sc.Storage.Get(sc.Context, localDeltaWALLastRevokedSerial)
@@ -562,7 +565,7 @@ func (cb *crlBuilder) _shouldRebuildLocalCRLs(sc *storageContext, override bool)
return true, nil
}
func (cb *crlBuilder) _shouldRebuildUnifiedCRLs(sc *storageContext, override bool) (bool, error) {
func (cb *CrlBuilder) _shouldRebuildUnifiedCRLs(sc *storageContext, override bool) (bool, error) {
// Unified CRL can only be built by the main cluster.
b := sc.Backend
if b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
@@ -636,18 +639,18 @@ func (cb *crlBuilder) _shouldRebuildUnifiedCRLs(sc *storageContext, override boo
return shouldRebuild, nil
}
func (cb *crlBuilder) rebuildDeltaCRLs(sc *storageContext, forceNew bool) ([]string, error) {
func (cb *CrlBuilder) rebuildDeltaCRLs(sc *storageContext, forceNew bool) ([]string, error) {
cb._builder.Lock()
defer cb._builder.Unlock()
return cb.rebuildDeltaCRLsHoldingLock(sc, forceNew)
}
func (cb *crlBuilder) rebuildDeltaCRLsHoldingLock(sc *storageContext, forceNew bool) ([]string, error) {
func (cb *CrlBuilder) rebuildDeltaCRLsHoldingLock(sc *storageContext, forceNew bool) ([]string, error) {
return buildAnyCRLs(sc, forceNew, true /* building delta */)
}
func (cb *crlBuilder) addCertForRevocationCheck(cluster, serial string) {
func (cb *CrlBuilder) addCertForRevocationCheck(cluster, serial string) {
entry := &revocationQueueEntry{
Cluster: cluster,
Serial: serial,
@@ -655,7 +658,7 @@ func (cb *crlBuilder) addCertForRevocationCheck(cluster, serial string) {
cb.revQueue.Add(entry)
}
func (cb *crlBuilder) addCertForRevocationRemoval(cluster, serial string) {
func (cb *CrlBuilder) addCertForRevocationRemoval(cluster, serial string) {
entry := &revocationQueueEntry{
Cluster: cluster,
Serial: serial,
@@ -663,7 +666,7 @@ func (cb *crlBuilder) addCertForRevocationRemoval(cluster, serial string) {
cb.removalQueue.Add(entry)
}
func (cb *crlBuilder) addCertFromCrossRevocation(cluster, serial string) {
func (cb *CrlBuilder) addCertFromCrossRevocation(cluster, serial string) {
entry := &revocationQueueEntry{
Cluster: cluster,
Serial: serial,
@@ -671,7 +674,7 @@ func (cb *crlBuilder) addCertFromCrossRevocation(cluster, serial string) {
cb.crossQueue.Add(entry)
}
func (cb *crlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotPerfPrimary bool) error {
func (cb *CrlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotPerfPrimary bool) error {
// Assume holding lock.
if cb.haveInitializedQueue.Load() {
return nil
@@ -727,7 +730,7 @@ func (cb *crlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotP
return nil
}
func (cb *crlBuilder) processRevocationQueue(sc *storageContext) error {
func (cb *CrlBuilder) processRevocationQueue(sc *storageContext) error {
sc.Backend.Logger().Debug(fmt.Sprintf("starting to process revocation requests"))
isNotPerfPrimary := sc.Backend.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
@@ -844,7 +847,7 @@ func (cb *crlBuilder) processRevocationQueue(sc *storageContext) error {
return nil
}
func (cb *crlBuilder) processCrossClusterRevocations(sc *storageContext) error {
func (cb *CrlBuilder) processCrossClusterRevocations(sc *storageContext) error {
sc.Backend.Logger().Debug(fmt.Sprintf("starting to process unified revocations"))
crlConfig, err := cb.getConfigWithUpdate(sc)
@@ -906,25 +909,25 @@ func (cb *crlBuilder) processCrossClusterRevocations(sc *storageContext) error {
return nil
}
// Helper function to fetch a map of issuerID->parsed cert for revocation
// Helper function to fetch a map of IssuerID->parsed cert for revocation
// usage. Unlike other paths, this needs to handle the legacy bundle
// more gracefully than rejecting it outright.
func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuerID]*x509.Certificate, error) {
func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuing.IssuerID]*x509.Certificate, error) {
var err error
var issuers []issuerID
var issuers []issuing.IssuerID
if !sc.Backend.useLegacyBundleCaStorage() {
if !sc.Backend.UseLegacyBundleCaStorage() {
issuers, err = sc.listIssuers()
if err != nil {
return nil, fmt.Errorf("could not fetch issuers list: %w", err)
}
} else {
// Hack: this isn't a real issuerID, but it works for fetchCAInfo
// Hack: this isn't a real IssuerID, but it works for fetchCAInfo
// since it resolves the reference.
issuers = []issuerID{legacyBundleShimID}
issuers = []issuing.IssuerID{legacyBundleShimID}
}
issuerIDCertMap := make(map[issuerID]*x509.Certificate, len(issuers))
issuerIDCertMap := make(map[issuing.IssuerID]*x509.Certificate, len(issuers))
for _, issuer := range issuers {
_, bundle, caErr := sc.fetchCertBundleByIssuerId(issuer, false)
if caErr != nil {
@@ -954,8 +957,8 @@ func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuerID]*x509
// storage.
func tryRevokeCertBySerial(sc *storageContext, config *crlConfig, serial string) (*logical.Response, error) {
// revokeCert requires us to hold these locks before calling it.
sc.Backend.revokeStorageLock.Lock()
defer sc.Backend.revokeStorageLock.Unlock()
sc.Backend.GetRevokeStorageLock().Lock()
defer sc.Backend.GetRevokeStorageLock().Unlock()
certEntry, err := fetchCertBySerial(sc, "certs/", serial)
if err != nil {
@@ -1052,12 +1055,13 @@ func revokeCert(sc *storageContext, config *crlConfig, cert *x509.Certificate) (
return nil, fmt.Errorf("error creating revocation entry: %w", err)
}
certsCounted := sc.Backend.certsCounted.Load()
certCounter := sc.Backend.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = sc.Storage.Put(sc.Context, revEntry)
if err != nil {
return nil, fmt.Errorf("error saving revoked certificate to new location: %w", err)
}
sc.Backend.ifCountEnabledIncrementTotalRevokedCertificatesCount(certsCounted, revEntry.Key)
certCounter.IncrementTotalRevokedCertificatesCount(certsCounted, revEntry.Key)
// From here on out, the certificate has been revoked locally. Any other
// persistence issues might still err, but any other failure messages
@@ -1087,7 +1091,7 @@ func revokeCert(sc *storageContext, config *crlConfig, cert *x509.Certificate) (
// thread will reattempt it later on as we have the local write done.
sc.Backend.Logger().Error("Failed to write unified revocation entry, will re-attempt later",
"serial_number", colonSerial, "error", ignoreErr)
sc.Backend.unifiedTransferStatus.forceRun()
sc.Backend.GetUnifiedTransferStatus().forceRun()
resp.AddWarning(fmt.Sprintf("Failed to write unified revocation entry, will re-attempt later: %v", err))
failedWritingUnifiedCRL = true
@@ -1099,7 +1103,7 @@ func revokeCert(sc *storageContext, config *crlConfig, cert *x509.Certificate) (
// already rebuilt the full CRL so the Delta WAL will be cleared
// afterwards. Writing an entry only to immediately remove it
// isn't necessary.
warnings, crlErr := sc.Backend.crlBuilder.rebuild(sc, false)
warnings, crlErr := sc.Backend.CrlBuilder().rebuild(sc, false)
if crlErr != nil {
switch crlErr.(type) {
case errutil.UserError:
@@ -1144,7 +1148,7 @@ func writeRevocationDeltaWALs(sc *storageContext, config *crlConfig, resp *logic
// thread will reattempt it later on as we have the local write done.
sc.Backend.Logger().Error("Failed to write cross-cluster delta WAL entry, will re-attempt later",
"serial_number", colonSerial, "error", ignoredErr)
sc.Backend.unifiedTransferStatus.forceRun()
sc.Backend.GetUnifiedTransferStatus().forceRun()
resp.AddWarning(fmt.Sprintf("Failed to write cross-cluster delta WAL entry, will re-attempt later: %v", ignoredErr))
}
@@ -1235,12 +1239,12 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
// See the message in revokedCert about rebuilding CRLs: we need to
// gracefully handle revoking entries with the legacy cert bundle.
var err error
var issuers []issuerID
var issuers []issuing.IssuerID
var wasLegacy bool
// First, fetch an updated copy of the CRL config. We'll pass this into
// buildCRL.
globalCRLConfig, err := sc.Backend.crlBuilder.getConfigWithUpdate(sc)
// First, fetch an updated copy of the CRL config. We'll pass this into buildCRL.
crlBuilder := sc.Backend.CrlBuilder()
globalCRLConfig, err := crlBuilder.getConfigWithUpdate(sc)
if err != nil {
return nil, fmt.Errorf("error building CRL: while updating config: %w", err)
}
@@ -1257,7 +1261,7 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
return nil, nil
}
if !sc.Backend.useLegacyBundleCaStorage() {
if !sc.Backend.UseLegacyBundleCaStorage() {
issuers, err = sc.listIssuers()
if err != nil {
return nil, fmt.Errorf("error building CRL: while listing issuers: %w", err)
@@ -1266,7 +1270,7 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
// Here, we hard-code the legacy issuer entry instead of using the
// default ref. This is because we need to hack some of the logic
// below for revocation to handle the legacy bundle.
issuers = []issuerID{legacyBundleShimID}
issuers = []issuing.IssuerID{legacyBundleShimID}
wasLegacy = true
// Here, we avoid building a delta CRL with the legacy CRL bundle.
@@ -1283,15 +1287,15 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
return nil, fmt.Errorf("error building CRLs: while getting the default config: %w", err)
}
// We map issuerID->entry for fast lookup and also issuerID->Cert for
// We map IssuerID->entry for fast lookup and also IssuerID->Cert for
// signature verification and correlation of revoked certs.
issuerIDEntryMap := make(map[issuerID]*issuerEntry, len(issuers))
issuerIDCertMap := make(map[issuerID]*x509.Certificate, len(issuers))
issuerIDEntryMap := make(map[issuing.IssuerID]*issuing.IssuerEntry, len(issuers))
issuerIDCertMap := make(map[issuing.IssuerID]*x509.Certificate, len(issuers))
// We use a double map (keyID->subject->issuerID) to store whether or not this
// We use a double map (KeyID->subject->IssuerID) to store whether or not this
// key+subject paring has been seen before. We can then iterate over each
// key/subject and choose any representative issuer for that combination.
keySubjectIssuersMap := make(map[keyID]map[string][]issuerID)
keySubjectIssuersMap := make(map[issuing.KeyID]map[string][]issuing.IssuerID)
for _, issuer := range issuers {
// We don't strictly need this call, but by requesting the bundle, the
// legacy path is automatically ignored.
@@ -1328,7 +1332,7 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
subject := string(thisCert.RawSubject)
if _, ok := keySubjectIssuersMap[thisEntry.KeyID]; !ok {
keySubjectIssuersMap[thisEntry.KeyID] = make(map[string][]issuerID)
keySubjectIssuersMap[thisEntry.KeyID] = make(map[string][]issuing.IssuerID)
}
keySubjectIssuersMap[thisEntry.KeyID][subject] = append(keySubjectIssuersMap[thisEntry.KeyID][subject], issuer)
@@ -1365,13 +1369,13 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
if !isDelta {
// After we've confirmed the primary CRLs have built OK, go ahead and
// clear the delta CRL WAL and rebuild it.
if err := sc.Backend.crlBuilder.clearLocalDeltaWAL(sc, currLocalDeltaSerials); err != nil {
if err := crlBuilder.clearLocalDeltaWAL(sc, currLocalDeltaSerials); err != nil {
return nil, fmt.Errorf("error building CRLs: unable to clear Delta WAL: %w", err)
}
if err := sc.Backend.crlBuilder.clearUnifiedDeltaWAL(sc, currUnifiedDeltaSerials); err != nil {
if err := crlBuilder.clearUnifiedDeltaWAL(sc, currUnifiedDeltaSerials); err != nil {
return nil, fmt.Errorf("error building CRLs: unable to clear Delta WAL: %w", err)
}
deltaWarnings, err := sc.Backend.crlBuilder.rebuildDeltaCRLsHoldingLock(sc, forceNew)
deltaWarnings, err := crlBuilder.rebuildDeltaCRLsHoldingLock(sc, forceNew)
if err != nil {
return nil, fmt.Errorf("error building CRLs: unable to rebuild empty Delta WAL: %w", err)
}
@@ -1404,12 +1408,12 @@ func getLastWALSerial(sc *storageContext, path string) (string, error) {
func buildAnyLocalCRLs(
sc *storageContext,
issuersConfig *issuerConfigEntry,
issuersConfig *issuing.IssuerConfigEntry,
globalCRLConfig *crlConfig,
issuers []issuerID,
issuerIDEntryMap map[issuerID]*issuerEntry,
issuerIDCertMap map[issuerID]*x509.Certificate,
keySubjectIssuersMap map[keyID]map[string][]issuerID,
issuers []issuing.IssuerID,
issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIDCertMap map[issuing.IssuerID]*x509.Certificate,
keySubjectIssuersMap map[issuing.KeyID]map[string][]issuing.IssuerID,
wasLegacy bool,
forceNew bool,
isDelta bool,
@@ -1435,14 +1439,14 @@ func buildAnyLocalCRLs(
// visible now, should also be visible on the complete CRL we're writing.
var currDeltaCerts []string
if !isDelta {
currDeltaCerts, err = sc.Backend.crlBuilder.getPresentLocalDeltaWALForClearing(sc)
currDeltaCerts, err = sc.Backend.CrlBuilder().getPresentLocalDeltaWALForClearing(sc)
if err != nil {
return nil, nil, fmt.Errorf("error building CRLs: unable to get present delta WAL entries for removal: %w", err)
}
}
var unassignedCerts []pkix.RevokedCertificate
var revokedCertsMap map[issuerID][]pkix.RevokedCertificate
var revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate
// If the CRL is disabled do not bother reading in all the revoked certificates.
if !globalCRLConfig.Disable {
@@ -1499,7 +1503,7 @@ func buildAnyLocalCRLs(
if isDelta {
// Update our last build time here so we avoid checking for new certs
// for a while.
sc.Backend.crlBuilder.lastDeltaRebuildCheck = time.Now()
sc.Backend.CrlBuilder().lastDeltaRebuildCheck = time.Now()
if len(lastDeltaSerial) > 0 {
// When we have a last delta serial, write out the relevant info
@@ -1523,12 +1527,12 @@ func buildAnyLocalCRLs(
func buildAnyUnifiedCRLs(
sc *storageContext,
issuersConfig *issuerConfigEntry,
issuersConfig *issuing.IssuerConfigEntry,
globalCRLConfig *crlConfig,
issuers []issuerID,
issuerIDEntryMap map[issuerID]*issuerEntry,
issuerIDCertMap map[issuerID]*x509.Certificate,
keySubjectIssuersMap map[keyID]map[string][]issuerID,
issuers []issuing.IssuerID,
issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIDCertMap map[issuing.IssuerID]*x509.Certificate,
keySubjectIssuersMap map[issuing.KeyID]map[string][]issuing.IssuerID,
wasLegacy bool,
forceNew bool,
isDelta bool,
@@ -1578,14 +1582,14 @@ func buildAnyUnifiedCRLs(
// visible now, should also be visible on the complete CRL we're writing.
var currDeltaCerts []string
if !isDelta {
currDeltaCerts, err = sc.Backend.crlBuilder.getPresentUnifiedDeltaWALForClearing(sc)
currDeltaCerts, err = sc.Backend.CrlBuilder().getPresentUnifiedDeltaWALForClearing(sc)
if err != nil {
return nil, nil, fmt.Errorf("error building CRLs: unable to get present delta WAL entries for removal: %w", err)
}
}
var unassignedCerts []pkix.RevokedCertificate
var revokedCertsMap map[issuerID][]pkix.RevokedCertificate
var revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate
// If the CRL is disabled do not bother reading in all the revoked certificates.
if !globalCRLConfig.Disable {
@@ -1642,7 +1646,7 @@ func buildAnyUnifiedCRLs(
if isDelta {
// Update our last build time here so we avoid checking for new certs
// for a while.
sc.Backend.crlBuilder.lastDeltaRebuildCheck = time.Now()
sc.Backend.CrlBuilder().lastDeltaRebuildCheck = time.Now()
// Persist all of our known last revoked serial numbers here, as the
// last seen serial during build. This will allow us to detect if any
@@ -1674,20 +1678,20 @@ func buildAnyUnifiedCRLs(
func buildAnyCRLsWithCerts(
sc *storageContext,
issuersConfig *issuerConfigEntry,
issuersConfig *issuing.IssuerConfigEntry,
globalCRLConfig *crlConfig,
internalCRLConfig *internalCRLConfigEntry,
issuers []issuerID,
issuerIDEntryMap map[issuerID]*issuerEntry,
keySubjectIssuersMap map[keyID]map[string][]issuerID,
internalCRLConfig *issuing.InternalCRLConfigEntry,
issuers []issuing.IssuerID,
issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
keySubjectIssuersMap map[issuing.KeyID]map[string][]issuing.IssuerID,
unassignedCerts []pkix.RevokedCertificate,
revokedCertsMap map[issuerID][]pkix.RevokedCertificate,
revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate,
forceNew bool,
isUnified bool,
isDelta bool,
) ([]string, error) {
// Now we can call buildCRL once, on an arbitrary/representative issuer
// from each of these (keyID, subject) sets.
// from each of these (KeyID, subject) sets.
var warnings []string
for _, subjectIssuersMap := range keySubjectIssuersMap {
for _, issuersSet := range subjectIssuersMap {
@@ -1696,15 +1700,15 @@ func buildAnyCRLsWithCerts(
}
var revokedCerts []pkix.RevokedCertificate
representative := issuerID("")
var crlIdentifier crlID
var crlIdIssuer issuerID
representative := issuing.IssuerID("")
var crlIdentifier issuing.CrlID
var crlIdIssuer issuing.IssuerID
for _, issuerId := range issuersSet {
// Skip entries which aren't enabled for CRL signing. We don't
// particularly care which issuer is ultimately chosen as the
// set representative for signing at this point, other than
// that it has crl-signing usage.
if err := issuerIDEntryMap[issuerId].EnsureUsage(CRLSigningUsage); err != nil {
if err := issuerIDEntryMap[issuerId].EnsureUsage(issuing.CRLSigningUsage); err != nil {
continue
}
@@ -1724,7 +1728,7 @@ func buildAnyCRLsWithCerts(
// Otherwise, use any other random issuer if we've not yet
// chosen one.
if representative == issuerID("") {
if representative == issuing.IssuerID("") {
representative = issuerId
}
@@ -1864,7 +1868,7 @@ func buildAnyCRLsWithCerts(
return warnings, nil
}
func isRevInfoIssuerValid(revInfo *revocationInfo, issuerIDCertMap map[issuerID]*x509.Certificate) bool {
func isRevInfoIssuerValid(revInfo *revocationInfo, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate) bool {
if len(revInfo.CertificateIssuer) > 0 {
issuerId := revInfo.CertificateIssuer
if _, issuerExists := issuerIDCertMap[issuerId]; issuerExists {
@@ -1875,7 +1879,7 @@ func isRevInfoIssuerValid(revInfo *revocationInfo, issuerIDCertMap map[issuerID]
return false
}
func associateRevokedCertWithIsssuer(revInfo *revocationInfo, revokedCert *x509.Certificate, issuerIDCertMap map[issuerID]*x509.Certificate) bool {
func associateRevokedCertWithIsssuer(revInfo *revocationInfo, revokedCert *x509.Certificate, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate) bool {
for issuerId, issuerCert := range issuerIDCertMap {
if bytes.Equal(revokedCert.RawIssuer, issuerCert.RawSubject) {
if err := revokedCert.CheckSignatureFrom(issuerCert); err == nil {
@@ -1889,9 +1893,9 @@ func associateRevokedCertWithIsssuer(revInfo *revocationInfo, revokedCert *x509.
return false
}
func getLocalRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuerID][]pkix.RevokedCertificate, error) {
func getLocalRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuing.IssuerID][]pkix.RevokedCertificate, error) {
var unassignedCerts []pkix.RevokedCertificate
revokedCertsMap := make(map[issuerID][]pkix.RevokedCertificate)
revokedCertsMap := make(map[issuing.IssuerID][]pkix.RevokedCertificate)
listingPath := revokedPath
if isDelta {
@@ -2018,13 +2022,13 @@ func getLocalRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuerID
return unassignedCerts, revokedCertsMap, nil
}
func getUnifiedRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuerID][]pkix.RevokedCertificate, error) {
func getUnifiedRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuing.IssuerID][]pkix.RevokedCertificate, error) {
// Getting unified revocation entries is a bit different than getting
// the local ones. In particular, the full copy of the certificate is
// unavailable, so we'll be able to avoid parsing the stored certificate,
// at the expense of potentially having incorrect issuer mappings.
var unassignedCerts []pkix.RevokedCertificate
revokedCertsMap := make(map[issuerID][]pkix.RevokedCertificate)
revokedCertsMap := make(map[issuing.IssuerID][]pkix.RevokedCertificate)
listingPath := unifiedRevocationReadPathPrefix
if isDelta {
@@ -2114,7 +2118,7 @@ func getUnifiedRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuer
return unassignedCerts, revokedCertsMap, nil
}
func augmentWithRevokedIssuers(issuerIDEntryMap map[issuerID]*issuerEntry, issuerIDCertMap map[issuerID]*x509.Certificate, revokedCertsMap map[issuerID][]pkix.RevokedCertificate) error {
func augmentWithRevokedIssuers(issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate, revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate) error {
// When setup our maps with the legacy CA bundle, we only have a
// single entry here. This entry is never revoked, so the outer loop
// will exit quickly.
@@ -2150,7 +2154,7 @@ func augmentWithRevokedIssuers(issuerIDEntryMap map[issuerID]*issuerEntry, issue
// Builds a CRL by going through the list of revoked certificates and building
// a new CRL with the stored revocation times and serial numbers.
func buildCRL(sc *storageContext, crlInfo *crlConfig, forceNew bool, thisIssuerId issuerID, revoked []pkix.RevokedCertificate, identifier crlID, crlNumber int64, isUnified bool, isDelta bool, lastCompleteNumber int64) (*time.Time, error) {
func buildCRL(sc *storageContext, crlInfo *crlConfig, forceNew bool, thisIssuerId issuing.IssuerID, revoked []pkix.RevokedCertificate, identifier issuing.CrlID, crlNumber int64, isUnified bool, isDelta bool, lastCompleteNumber int64) (*time.Time, error) {
var revokedCerts []pkix.RevokedCertificate
crlLifetime, err := parseutil.ParseDurationSecond(crlInfo.Expiry)
@@ -2177,7 +2181,7 @@ func buildCRL(sc *storageContext, crlInfo *crlConfig, forceNew bool, thisIssuerI
revokedCerts = revoked
WRITE:
signingBundle, caErr := sc.fetchCAInfoByIssuerId(thisIssuerId, CRLSigningUsage)
signingBundle, caErr := sc.fetchCAInfoByIssuerId(thisIssuerId, issuing.CRLSigningUsage)
if caErr != nil {
switch caErr.(type) {
case errutil.UserError:

View File

@@ -6,6 +6,7 @@ package pki
import (
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
)
@@ -597,7 +598,7 @@ basic constraints.`,
func addSignVerbatimRoleFields(fields map[string]*framework.FieldSchema) map[string]*framework.FieldSchema {
fields["key_usage"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Default: []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"},
Default: issuing.DefaultRoleKeyUsages,
Description: `A comma-separated string or list of key usages (not extended
key usages). Valid values can be found at
https://golang.org/pkg/crypto/x509/#KeyUsage
@@ -608,7 +609,7 @@ this value to an empty list.`,
fields["ext_key_usage"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Default: []string{},
Default: issuing.DefaultRoleEstKeyUsages,
Description: `A comma-separated string or list of extended key usages. Valid values can be found at
https://golang.org/pkg/crypto/x509/#ExtKeyUsage
-- simply drop the "ExtKeyUsage" part of the name.
@@ -618,24 +619,25 @@ this value to an empty list.`,
fields["ext_key_usage_oids"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Default: issuing.DefaultRoleEstKeyUsageOids,
Description: `A comma-separated string or list of extended key usage oids.`,
}
fields["signature_bits"] = &framework.FieldSchema{
Type: framework.TypeInt,
Default: 0,
Default: issuing.DefaultRoleSignatureBits,
Description: `The number of bits to use in the signature
algorithm; accepts 256 for SHA-2-256, 384 for SHA-2-384, and 512 for
SHA-2-512. Defaults to 0 to automatically detect based on key length
(SHA-2-256 for RSA keys, and matching the curve size for NIST P-Curves).`,
DisplayAttrs: &framework.DisplayAttributes{
Value: 0,
Value: issuing.DefaultRoleSignatureBits,
},
}
fields["use_pss"] = &framework.FieldSchema{
Type: framework.TypeBool,
Default: false,
Default: issuing.DefaultRoleUsePss,
Description: `Whether or not to use PSS signatures when using a
RSA key-type issuer. Defaults to false.`,
}

View File

@@ -15,6 +15,7 @@ import (
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
vaulthttp "github.com/hashicorp/vault/http"
vaultocsp "github.com/hashicorp/vault/sdk/helper/ocsp"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
@@ -41,7 +42,7 @@ func TestIntegration_RotateRootUsesNext(t *testing.T) {
require.NotNil(t, resp, "got nil response from rotate root")
require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp)
issuerId1 := resp.Data["issuer_id"].(issuerID)
issuerId1 := resp.Data["issuer_id"].(issuing.IssuerID)
issuerName1 := resp.Data["issuer_name"]
require.NotEmpty(t, issuerId1, "issuer id was empty on initial rotate root command")
@@ -61,7 +62,7 @@ func TestIntegration_RotateRootUsesNext(t *testing.T) {
require.NotNil(t, resp, "got nil response from rotate root")
require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp)
issuerId2 := resp.Data["issuer_id"].(issuerID)
issuerId2 := resp.Data["issuer_id"].(issuing.IssuerID)
issuerName2 := resp.Data["issuer_name"]
require.NotEmpty(t, issuerId2, "issuer id was empty on second rotate root command")
@@ -83,7 +84,7 @@ func TestIntegration_RotateRootUsesNext(t *testing.T) {
require.NotNil(t, resp, "got nil response from rotate root")
require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp)
issuerId3 := resp.Data["issuer_id"].(issuerID)
issuerId3 := resp.Data["issuer_id"].(issuing.IssuerID)
issuerName3 := resp.Data["issuer_name"]
require.NotEmpty(t, issuerId3, "issuer id was empty on third rotate root command")
@@ -436,7 +437,7 @@ func TestIntegration_AutoIssuer(t *testing.T) {
"pem_bundle": certOne,
})
requireSuccessNonNilResponse(t, resp, err)
issuerIdOneReimported := issuerID(resp.Data["imported_issuers"].([]string)[0])
issuerIdOneReimported := issuing.IssuerID(resp.Data["imported_issuers"].([]string)[0])
resp, err = CBRead(b, s, "config/issuers")
requireSuccessNonNilResponse(t, resp, err)
@@ -643,11 +644,11 @@ func TestIntegrationOCSPClientWithPKI(t *testing.T) {
}
}
func genTestRootCa(t *testing.T, b *backend, s logical.Storage) (issuerID, keyID) {
func genTestRootCa(t *testing.T, b *backend, s logical.Storage) (issuing.IssuerID, issuing.KeyID) {
return genTestRootCaWithIssuerName(t, b, s, "")
}
func genTestRootCaWithIssuerName(t *testing.T, b *backend, s logical.Storage, issuerName string) (issuerID, keyID) {
func genTestRootCaWithIssuerName(t *testing.T, b *backend, s logical.Storage, issuerName string) (issuing.IssuerID, issuing.KeyID) {
data := map[string]interface{}{
"common_name": "test.com",
}
@@ -665,8 +666,8 @@ func genTestRootCaWithIssuerName(t *testing.T, b *backend, s logical.Storage, is
require.NotNil(t, resp, "got nil response from generating root ca")
require.False(t, resp.IsError(), "got an error from generating root ca: %#v", resp)
issuerId := resp.Data["issuer_id"].(issuerID)
keyId := resp.Data["key_id"].(keyID)
issuerId := resp.Data["issuer_id"].(issuing.IssuerID)
keyId := resp.Data["key_id"].(issuing.KeyID)
require.NotEmpty(t, issuerId, "returned issuer id was empty")
require.NotEmpty(t, keyId, "returned key id was empty")

View File

@@ -0,0 +1,154 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"strings"
"github.com/asaskevich/govalidator"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
const ClusterConfigPath = "config/cluster"
type AiaConfigEntry struct {
IssuingCertificates []string `json:"issuing_certificates"`
CRLDistributionPoints []string `json:"crl_distribution_points"`
OCSPServers []string `json:"ocsp_servers"`
EnableTemplating bool `json:"enable_templating"`
}
type ClusterConfigEntry struct {
Path string `json:"path"`
AIAPath string `json:"aia_path"`
}
func GetAIAURLs(ctx context.Context, s logical.Storage, i *IssuerEntry) (*certutil.URLEntries, error) {
// Default to the per-issuer AIA URLs.
entries := i.AIAURIs
// If none are set (either due to a nil entry or because no URLs have
// been provided), fall back to the global AIA URL config.
if entries == nil || (len(entries.IssuingCertificates) == 0 && len(entries.CRLDistributionPoints) == 0 && len(entries.OCSPServers) == 0) {
var err error
entries, err = GetGlobalAIAURLs(ctx, s)
if err != nil {
return nil, err
}
}
if entries == nil {
return &certutil.URLEntries{}, nil
}
return ToURLEntries(ctx, s, i.ID, entries)
}
func GetGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*AiaConfigEntry, error) {
entry, err := storage.Get(ctx, "urls")
if err != nil {
return nil, err
}
entries := &AiaConfigEntry{
IssuingCertificates: []string{},
CRLDistributionPoints: []string{},
OCSPServers: []string{},
EnableTemplating: false,
}
if entry == nil {
return entries, nil
}
if err := entry.DecodeJSON(entries); err != nil {
return nil, err
}
return entries, nil
}
func ToURLEntries(ctx context.Context, s logical.Storage, issuer IssuerID, c *AiaConfigEntry) (*certutil.URLEntries, error) {
if len(c.IssuingCertificates) == 0 && len(c.CRLDistributionPoints) == 0 && len(c.OCSPServers) == 0 {
return &certutil.URLEntries{}, nil
}
result := certutil.URLEntries{
IssuingCertificates: c.IssuingCertificates[:],
CRLDistributionPoints: c.CRLDistributionPoints[:],
OCSPServers: c.OCSPServers[:],
}
if c.EnableTemplating {
cfg, err := GetClusterConfig(ctx, s)
if err != nil {
return nil, fmt.Errorf("error fetching cluster-local address config: %w", err)
}
for name, source := range map[string]*[]string{
"issuing_certificates": &result.IssuingCertificates,
"crl_distribution_points": &result.CRLDistributionPoints,
"ocsp_servers": &result.OCSPServers,
} {
templated := make([]string, len(*source))
for index, uri := range *source {
if strings.Contains(uri, "{{cluster_path}}") && len(cfg.Path) == 0 {
return nil, fmt.Errorf("unable to template AIA URLs as we lack local cluster address information (path)")
}
if strings.Contains(uri, "{{cluster_aia_path}}") && len(cfg.AIAPath) == 0 {
return nil, fmt.Errorf("unable to template AIA URLs as we lack local cluster address information (aia_path)")
}
if strings.Contains(uri, "{{issuer_id}}") && len(issuer) == 0 {
// Elide issuer AIA info as we lack an issuer_id.
return nil, fmt.Errorf("unable to template AIA URLs as we lack an issuer_id for this operation")
}
uri = strings.ReplaceAll(uri, "{{cluster_path}}", cfg.Path)
uri = strings.ReplaceAll(uri, "{{cluster_aia_path}}", cfg.AIAPath)
uri = strings.ReplaceAll(uri, "{{issuer_id}}", issuer.String())
templated[index] = uri
}
if uri := ValidateURLs(templated); uri != "" {
return nil, fmt.Errorf("error validating templated %v; invalid URI: %v", name, uri)
}
*source = templated
}
}
return &result, nil
}
func GetClusterConfig(ctx context.Context, s logical.Storage) (*ClusterConfigEntry, error) {
entry, err := s.Get(ctx, ClusterConfigPath)
if err != nil {
return nil, err
}
var result ClusterConfigEntry
if entry == nil {
return &result, nil
}
if err = entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func ValidateURLs(urls []string) string {
for _, curr := range urls {
if !govalidator.IsURL(curr) || strings.Contains(curr, "{{issuer_id}}") || strings.Contains(curr, "{{cluster_path}}") || strings.Contains(curr, "{{cluster_aia_path}}") {
return curr
}
}
return ""
}

View File

@@ -0,0 +1,124 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const StorageIssuerConfig = "config/issuers"
type IssuerConfigEntry struct {
// This new fetchedDefault field allows us to detect if the default
// issuer was modified, in turn dispatching the timestamp updater
// if necessary.
fetchedDefault IssuerID `json:"-"`
DefaultIssuerId IssuerID `json:"default"`
DefaultFollowsLatestIssuer bool `json:"default_follows_latest_issuer"`
}
func GetIssuersConfig(ctx context.Context, s logical.Storage) (*IssuerConfigEntry, error) {
entry, err := s.Get(ctx, StorageIssuerConfig)
if err != nil {
return nil, err
}
issuerConfig := &IssuerConfigEntry{}
if entry != nil {
if err := entry.DecodeJSON(issuerConfig); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode issuer configuration: %v", err)}
}
}
issuerConfig.fetchedDefault = issuerConfig.DefaultIssuerId
return issuerConfig, nil
}
func SetIssuersConfig(ctx context.Context, s logical.Storage, config *IssuerConfigEntry) error {
json, err := logical.StorageEntryJSON(StorageIssuerConfig, config)
if err != nil {
return err
}
if err := s.Put(ctx, json); err != nil {
return err
}
if err := changeDefaultIssuerTimestamps(ctx, s, config.fetchedDefault, config.DefaultIssuerId); err != nil {
return err
}
return nil
}
func changeDefaultIssuerTimestamps(ctx context.Context, s logical.Storage, oldDefault IssuerID, newDefault IssuerID) error {
if newDefault == oldDefault {
return nil
}
now := time.Now().UTC()
// When the default issuer changes, we need to modify four
// pieces of information:
//
// 1. The old default issuer's modification time, as it no
// longer works for the /cert/ca path.
// 2. The new default issuer's modification time, as it now
// works for the /cert/ca path.
// 3. & 4. Both issuer's CRLs, as they behave the same, under
// the /cert/crl path!
for _, thisId := range []IssuerID{oldDefault, newDefault} {
if len(thisId) == 0 {
continue
}
// 1 & 2 above.
issuer, err := FetchIssuerById(ctx, s, thisId)
if err != nil {
// Due to the lack of transactions, if we deleted the default
// issuer (successfully), but the subsequent issuer config write
// (to clear the default issuer's old id) failed, we might have
// an inconsistent config. If we later hit this loop (and flush
// these timestamps again -- perhaps because the operator
// selected a new default), we'd have erred out here, because
// the since-deleted default issuer doesn't exist. In this case,
// skip the issuer instead of bailing.
err := fmt.Errorf("unable to update issuer (%v)'s modification time: error fetching issuer: %w", thisId, err)
if strings.Contains(err.Error(), "does not exist") {
hclog.L().Warn(err.Error())
continue
}
return err
}
issuer.LastModified = now
err = WriteIssuer(ctx, s, issuer)
if err != nil {
return fmt.Errorf("unable to update issuer (%v)'s modification time: error persisting issuer: %w", thisId, err)
}
}
// Fetch and update the internalCRLConfigEntry (3&4).
cfg, err := GetLocalCRLConfig(ctx, s)
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error fetching local CRL config: %w", err)
}
cfg.LastModified = now
cfg.DeltaLastModified = now
err = SetLocalCRLConfig(ctx, s, cfg)
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error persisting local CRL config: %w", err)
}
return nil
}

View File

@@ -0,0 +1,45 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
StorageKeyConfig = "config/keys"
)
type KeyConfigEntry struct {
DefaultKeyId KeyID `json:"default"`
}
func SetKeysConfig(ctx context.Context, s logical.Storage, config *KeyConfigEntry) error {
json, err := logical.StorageEntryJSON(StorageKeyConfig, config)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func GetKeysConfig(ctx context.Context, s logical.Storage) (*KeyConfigEntry, error) {
entry, err := s.Get(ctx, StorageKeyConfig)
if err != nil {
return nil, err
}
keyConfig := &KeyConfigEntry{}
if entry != nil {
if err := entry.DecodeJSON(keyConfig); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode key configuration: %v", err)}
}
}
return keyConfig, nil
}

View File

@@ -0,0 +1,190 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
StorageLocalCRLConfig = "crls/config"
StorageUnifiedCRLConfig = "unified-crls/config"
)
type InternalCRLConfigEntry struct {
IssuerIDCRLMap map[IssuerID]CrlID `json:"issuer_id_crl_map"`
CRLNumberMap map[CrlID]int64 `json:"crl_number_map"`
LastCompleteNumberMap map[CrlID]int64 `json:"last_complete_number_map"`
CRLExpirationMap map[CrlID]time.Time `json:"crl_expiration_map"`
LastModified time.Time `json:"last_modified"`
DeltaLastModified time.Time `json:"delta_last_modified"`
UseGlobalQueue bool `json:"cross_cluster_revocation"`
}
type CrlID string
func (p CrlID) String() string {
return string(p)
}
func GetLocalCRLConfig(ctx context.Context, s logical.Storage) (*InternalCRLConfigEntry, error) {
return _getInternalCRLConfig(ctx, s, StorageLocalCRLConfig)
}
func GetUnifiedCRLConfig(ctx context.Context, s logical.Storage) (*InternalCRLConfigEntry, error) {
return _getInternalCRLConfig(ctx, s, StorageUnifiedCRLConfig)
}
func _getInternalCRLConfig(ctx context.Context, s logical.Storage, path string) (*InternalCRLConfigEntry, error) {
entry, err := s.Get(ctx, path)
if err != nil {
return nil, err
}
mapping := &InternalCRLConfigEntry{}
if entry != nil {
if err := entry.DecodeJSON(mapping); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode cluster-local CRL configuration: %v", err)}
}
}
if len(mapping.IssuerIDCRLMap) == 0 {
mapping.IssuerIDCRLMap = make(map[IssuerID]CrlID)
}
if len(mapping.CRLNumberMap) == 0 {
mapping.CRLNumberMap = make(map[CrlID]int64)
}
if len(mapping.LastCompleteNumberMap) == 0 {
mapping.LastCompleteNumberMap = make(map[CrlID]int64)
// Since this might not exist on migration, we want to guess as
// to the last full CRL number was. This was likely the last
// Value from CRLNumberMap if it existed, since we're just adding
// the mapping here in this block.
//
// After the next full CRL build, we will have set this Value
// correctly, so it doesn't really matter in the long term if
// we're off here.
for id, number := range mapping.CRLNumberMap {
// Decrement by one, since CRLNumberMap is the future number,
// not the last built number.
mapping.LastCompleteNumberMap[id] = number - 1
}
}
if len(mapping.CRLExpirationMap) == 0 {
mapping.CRLExpirationMap = make(map[CrlID]time.Time)
}
return mapping, nil
}
func SetLocalCRLConfig(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry) error {
return _setInternalCRLConfig(ctx, s, mapping, StorageLocalCRLConfig)
}
func SetUnifiedCRLConfig(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry) error {
return _setInternalCRLConfig(ctx, s, mapping, StorageUnifiedCRLConfig)
}
func _setInternalCRLConfig(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry, path string) error {
if err := _cleanupInternalCRLMapping(ctx, s, mapping, path); err != nil {
return fmt.Errorf("failed to clean up internal CRL mapping: %w", err)
}
json, err := logical.StorageEntryJSON(path, mapping)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func _cleanupInternalCRLMapping(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry, path string) error {
// Track which CRL IDs are presently referred to by issuers; any other CRL
// IDs are subject to cleanup.
//
// Unused IDs both need to be removed from this map (cleaning up the size
// of this storage entry) but also the full CRLs removed from disk.
presentMap := make(map[CrlID]bool)
for _, id := range mapping.IssuerIDCRLMap {
presentMap[id] = true
}
// Identify which CRL IDs exist and are candidates for removal;
// theoretically these three maps should be in sync, but were added
// at different times.
toRemove := make(map[CrlID]bool)
for id := range mapping.CRLNumberMap {
if !presentMap[id] {
toRemove[id] = true
}
}
for id := range mapping.LastCompleteNumberMap {
if !presentMap[id] {
toRemove[id] = true
}
}
for id := range mapping.CRLExpirationMap {
if !presentMap[id] {
toRemove[id] = true
}
}
// Depending on which path we're writing this config to, we need to
// remove CRLs from the relevant folder too.
isLocal := path == StorageLocalCRLConfig
baseCRLPath := "crls/"
if !isLocal {
baseCRLPath = "unified-crls/"
}
for id := range toRemove {
// Clean up space in this mapping...
delete(mapping.CRLNumberMap, id)
delete(mapping.LastCompleteNumberMap, id)
delete(mapping.CRLExpirationMap, id)
// And clean up space on disk from the fat CRL mapping.
crlPath := baseCRLPath + string(id)
deltaCRLPath := crlPath + "-delta"
if err := s.Delete(ctx, crlPath); err != nil {
return fmt.Errorf("failed to delete unreferenced CRL %v: %w", id, err)
}
if err := s.Delete(ctx, deltaCRLPath); err != nil {
return fmt.Errorf("failed to delete unreferenced delta CRL %v: %w", id, err)
}
}
// Lastly, some CRLs could've been partially removed from the map but
// not from disk. Check to see if we have any dangling CRLs and remove
// them too.
list, err := s.List(ctx, baseCRLPath)
if err != nil {
return fmt.Errorf("failed listing all CRLs: %w", err)
}
for _, crl := range list {
if crl == "config" || strings.HasSuffix(crl, "/") {
continue
}
if presentMap[CrlID(crl)] {
continue
}
if err := s.Delete(ctx, baseCRLPath+"/"+crl); err != nil {
return fmt.Errorf("failed cleaning up orphaned CRL %v: %w", crl, err)
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,495 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"crypto/x509"
"fmt"
"sort"
"strings"
"time"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
"github.com/hashicorp/vault/builtin/logical/pki/parsing"
)
const (
ReadOnlyUsage IssuerUsage = iota
IssuanceUsage IssuerUsage = 1 << iota
CRLSigningUsage IssuerUsage = 1 << iota
OCSPSigningUsage IssuerUsage = 1 << iota
)
const (
// When adding a new usage in the future, we'll need to create a usage
// mask field on the IssuerEntry and handle migrations to a newer mask,
// inferring a value for the new bits.
AllIssuerUsages = ReadOnlyUsage | IssuanceUsage | CRLSigningUsage | OCSPSigningUsage
DefaultRef = "default"
IssuerPrefix = "config/issuer/"
// Used as a quick sanity check for a reference id lookups...
uuidLength = 36
IssuerRefNotFound = IssuerID("not-found")
LatestIssuerVersion = 1
LegacyCertBundlePath = "config/ca_bundle"
LegacyBundleShimID = IssuerID("legacy-entry-shim-id")
LegacyBundleShimKeyID = KeyID("legacy-entry-shim-key-id")
)
type IssuerID string
func (p IssuerID) String() string {
return string(p)
}
type IssuerUsage uint
var namedIssuerUsages = map[string]IssuerUsage{
"read-only": ReadOnlyUsage,
"issuing-certificates": IssuanceUsage,
"crl-signing": CRLSigningUsage,
"ocsp-signing": OCSPSigningUsage,
}
func (i *IssuerUsage) ToggleUsage(usages ...IssuerUsage) {
for _, usage := range usages {
*i ^= usage
}
}
func (i IssuerUsage) HasUsage(usage IssuerUsage) bool {
return (i & usage) == usage
}
func (i IssuerUsage) Names() string {
var names []string
var builtUsage IssuerUsage
// Return the known set of usages in a sorted order to not have Terraform state files flipping
// saying values are different when it's the same list in a different order.
keys := make([]string, 0, len(namedIssuerUsages))
for k := range namedIssuerUsages {
keys = append(keys, k)
}
sort.Strings(keys)
for _, name := range keys {
usage := namedIssuerUsages[name]
if i.HasUsage(usage) {
names = append(names, name)
builtUsage.ToggleUsage(usage)
}
}
if i != builtUsage {
// Found some unknown usage, we should indicate this in the names.
names = append(names, fmt.Sprintf("unknown:%v", i^builtUsage))
}
return strings.Join(names, ",")
}
func NewIssuerUsageFromNames(names []string) (IssuerUsage, error) {
var result IssuerUsage
for index, name := range names {
usage, ok := namedIssuerUsages[name]
if !ok {
return ReadOnlyUsage, fmt.Errorf("unknown name for usage at index %v: %v", index, name)
}
result.ToggleUsage(usage)
}
return result, nil
}
type IssuerEntry struct {
ID IssuerID `json:"id"`
Name string `json:"name"`
KeyID KeyID `json:"key_id"`
Certificate string `json:"certificate"`
CAChain []string `json:"ca_chain"`
ManualChain []IssuerID `json:"manual_chain"`
SerialNumber string `json:"serial_number"`
LeafNotAfterBehavior certutil.NotAfterBehavior `json:"not_after_behavior"`
Usage IssuerUsage `json:"usage"`
RevocationSigAlg x509.SignatureAlgorithm `json:"revocation_signature_algorithm"`
Revoked bool `json:"revoked"`
RevocationTime int64 `json:"revocation_time"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"`
AIAURIs *AiaConfigEntry `json:"aia_uris,omitempty"`
LastModified time.Time `json:"last_modified"`
Version uint `json:"version"`
}
func (i IssuerEntry) GetCertificate() (*x509.Certificate, error) {
cert, err := parsing.ParseCertificateFromBytes([]byte(i.Certificate))
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse certificate from issuer: %s: %v", err.Error(), i.ID)}
}
return cert, nil
}
func (i IssuerEntry) EnsureUsage(usage IssuerUsage) error {
// We want to spit out a nice error message about missing usages.
if i.Usage.HasUsage(usage) {
return nil
}
issuerRef := fmt.Sprintf("id:%v", i.ID)
if len(i.Name) > 0 {
issuerRef = fmt.Sprintf("%v / name:%v", issuerRef, i.Name)
}
// These usages differ at some point in time. We've gotta find the first
// usage that differs and return a logical-sounding error message around
// that difference.
for name, candidate := range namedIssuerUsages {
if usage.HasUsage(candidate) && !i.Usage.HasUsage(candidate) {
return fmt.Errorf("requested usage %v for issuer [%v] but only had usage %v", name, issuerRef, i.Usage.Names())
}
}
// Maybe we have an unnamed usage that's requested.
return fmt.Errorf("unknown delta between usages: %v -> %v / for issuer [%v]", usage.Names(), i.Usage.Names(), issuerRef)
}
func (i IssuerEntry) CanMaybeSignWithAlgo(algo x509.SignatureAlgorithm) error {
// Hack: Go isn't kind enough expose its lovely signatureAlgorithmDetails
// informational struct for our usage. However, we don't want to actually
// fetch the private key and attempt a signature with this algo (as we'll
// mint new, previously unsigned material in the process that could maybe
// be potentially abused if it leaks).
//
// So...
//
// ...we maintain our own mapping of cert.PKI<->sigAlgos. Notably, we
// exclude DSA support as the PKI engine has never supported DSA keys.
if algo == x509.UnknownSignatureAlgorithm {
// Special cased to indicate upgrade and letting Go automatically
// chose the correct value.
return nil
}
cert, err := i.GetCertificate()
if err != nil {
return fmt.Errorf("unable to parse issuer's potential signature algorithm types: %w", err)
}
switch cert.PublicKeyAlgorithm {
case x509.RSA:
switch algo {
case x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA,
x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS,
x509.SHA512WithRSAPSS:
return nil
}
case x509.ECDSA:
switch algo {
case x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return nil
}
case x509.Ed25519:
switch algo {
case x509.PureEd25519:
return nil
}
}
return fmt.Errorf("unable to use issuer of type %v to sign with %v key type", cert.PublicKeyAlgorithm.String(), algo.String())
}
func ResolveIssuerReference(ctx context.Context, s logical.Storage, reference string) (IssuerID, error) {
if reference == DefaultRef {
// Handle fetching the default issuer.
config, err := GetIssuersConfig(ctx, s)
if err != nil {
return IssuerID("config-error"), err
}
if len(config.DefaultIssuerId) == 0 {
return IssuerRefNotFound, fmt.Errorf("no default issuer currently configured")
}
return config.DefaultIssuerId, nil
}
// Lookup by a direct get first to see if our reference is an ID, this is quick and cached.
if len(reference) == uuidLength {
entry, err := s.Get(ctx, IssuerPrefix+reference)
if err != nil {
return IssuerID("issuer-read"), err
}
if entry != nil {
return IssuerID(reference), nil
}
}
// ... than to pull all issuers from storage.
issuers, err := ListIssuers(ctx, s)
if err != nil {
return IssuerID("list-error"), err
}
for _, issuerId := range issuers {
issuer, err := FetchIssuerById(ctx, s, issuerId)
if err != nil {
return IssuerID("issuer-read"), err
}
if issuer.Name == reference {
return issuer.ID, nil
}
}
// Otherwise, we must not have found the issuer.
return IssuerRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI issuer for reference: %v", reference)}
}
func ListIssuers(ctx context.Context, s logical.Storage) ([]IssuerID, error) {
strList, err := s.List(ctx, IssuerPrefix)
if err != nil {
return nil, err
}
issuerIds := make([]IssuerID, 0, len(strList))
for _, entry := range strList {
issuerIds = append(issuerIds, IssuerID(entry))
}
return issuerIds, nil
}
// FetchIssuerById returns an IssuerEntry based on issuerId, if none found an error is returned.
func FetchIssuerById(ctx context.Context, s logical.Storage, issuerId IssuerID) (*IssuerEntry, error) {
if len(issuerId) == 0 {
return nil, errutil.InternalError{Err: "unable to fetch pki issuer: empty issuer identifier"}
}
entry, err := s.Get(ctx, IssuerPrefix+issuerId.String())
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki issuer: %v", err)}
}
if entry == nil {
return nil, errutil.UserError{Err: fmt.Sprintf("pki issuer id %s does not exist", issuerId.String())}
}
var issuer IssuerEntry
if err := entry.DecodeJSON(&issuer); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki issuer with id %s: %v", issuerId.String(), err)}
}
return upgradeIssuerIfRequired(&issuer), nil
}
func WriteIssuer(ctx context.Context, s logical.Storage, issuer *IssuerEntry) error {
issuerId := issuer.ID
if issuer.LastModified.IsZero() {
issuer.LastModified = time.Now().UTC()
}
json, err := logical.StorageEntryJSON(IssuerPrefix+issuerId.String(), issuer)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func DeleteIssuer(ctx context.Context, s logical.Storage, id IssuerID) (bool, error) {
config, err := GetIssuersConfig(ctx, s)
if err != nil {
return false, err
}
wasDefault := false
if config.DefaultIssuerId == id {
wasDefault = true
// Overwrite the fetched default issuer as we're going to remove this
// entry.
config.fetchedDefault = IssuerID("")
config.DefaultIssuerId = IssuerID("")
if err := SetIssuersConfig(ctx, s, config); err != nil {
return wasDefault, err
}
}
return wasDefault, s.Delete(ctx, IssuerPrefix+id.String())
}
func upgradeIssuerIfRequired(issuer *IssuerEntry) *IssuerEntry {
// *NOTE*: Don't attempt to write out the issuer here as it may cause ErrReadOnly that will direct the
// request all the way up to the primary cluster which would be horrible for local cluster operations such
// as generating a leaf cert or a revoke.
// Also even though we could tell if we are the primary cluster's active node, we can't tell if we have the
// a full rw issuer lock, so it might not be safe to write.
if issuer.Version == LatestIssuerVersion {
return issuer
}
if issuer.Version == 0 {
// Upgrade at this step requires interrogating the certificate itself;
// if this decode fails, it indicates internal problems and the
// request will subsequently fail elsewhere. However, decoding this
// certificate is mildly expensive, so we only do it in the event of
// a Version 0 certificate.
cert, err := issuer.GetCertificate()
if err != nil {
return issuer
}
hadCRL := issuer.Usage.HasUsage(CRLSigningUsage)
// Remove CRL signing usage if it exists on the issuer but doesn't
// exist in the KU of the x509 certificate.
if hadCRL && (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 {
issuer.Usage.ToggleUsage(CRLSigningUsage)
}
// Handle our new OCSPSigning usage flag for earlier versions. If we
// had it (prior to removing it in this upgrade), we'll add the OCSP
// flag since EKUs don't matter.
if hadCRL && !issuer.Usage.HasUsage(OCSPSigningUsage) {
issuer.Usage.ToggleUsage(OCSPSigningUsage)
}
}
issuer.Version = LatestIssuerVersion
return issuer
}
// FetchCAInfoByIssuerId will fetch the CA info, will return an error if no ca info exists for the given issuerId.
// This does support the loading using the legacyBundleShimID
func FetchCAInfoByIssuerId(ctx context.Context, s logical.Storage, mkv managed_key.PkiManagedKeyView, issuerId IssuerID, usage IssuerUsage) (*certutil.CAInfoBundle, error) {
entry, bundle, err := FetchCertBundleByIssuerId(ctx, s, issuerId, true)
if err != nil {
switch err.(type) {
case errutil.UserError:
return nil, err
case errutil.InternalError:
return nil, err
default:
return nil, errutil.InternalError{Err: fmt.Sprintf("error fetching CA info: %v", err)}
}
}
if err = entry.EnsureUsage(usage); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("error while attempting to use issuer %v: %v", issuerId, err)}
}
parsedBundle, err := ParseCABundle(ctx, mkv, bundle)
if err != nil {
return nil, errutil.InternalError{Err: err.Error()}
}
if parsedBundle.Certificate == nil {
return nil, errutil.InternalError{Err: "stored CA information not able to be parsed"}
}
if parsedBundle.PrivateKey == nil {
return nil, errutil.UserError{Err: fmt.Sprintf("unable to fetch corresponding key for issuer %v; unable to use this issuer for signing", issuerId)}
}
caInfo := &certutil.CAInfoBundle{
ParsedCertBundle: *parsedBundle,
URLs: nil,
LeafNotAfterBehavior: entry.LeafNotAfterBehavior,
RevocationSigAlg: entry.RevocationSigAlg,
}
entries, err := GetAIAURLs(ctx, s, entry)
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch AIA URL information: %v", err)}
}
caInfo.URLs = entries
return caInfo, nil
}
func ParseCABundle(ctx context.Context, mkv managed_key.PkiManagedKeyView, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
if bundle.PrivateKeyType == certutil.ManagedPrivateKey {
return managed_key.ParseManagedKeyCABundle(ctx, mkv, bundle)
}
return bundle.ToParsedCertBundle()
}
// FetchCertBundleByIssuerId builds a certutil.CertBundle from the specified issuer identifier,
// optionally loading the key or not. This method supports loading legacy
// bundles using the legacyBundleShimID issuerId, and if no entry is found will return an error.
func FetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id IssuerID, loadKey bool) (*IssuerEntry, *certutil.CertBundle, error) {
if id == LegacyBundleShimID {
// We have not completed the migration, or started a request in legacy mode, so
// attempt to load the bundle from the legacy location
issuer, bundle, err := GetLegacyCertBundle(ctx, s)
if err != nil {
return nil, nil, err
}
if issuer == nil || bundle == nil {
return nil, nil, errutil.UserError{Err: "no legacy cert bundle exists"}
}
return issuer, bundle, err
}
issuer, err := FetchIssuerById(ctx, s, id)
if err != nil {
return nil, nil, err
}
var bundle certutil.CertBundle
bundle.Certificate = issuer.Certificate
bundle.CAChain = issuer.CAChain
bundle.SerialNumber = issuer.SerialNumber
// Fetch the key if it exists. Sometimes we don't need the key immediately.
if loadKey && issuer.KeyID != KeyID("") {
key, err := FetchKeyById(ctx, s, issuer.KeyID)
if err != nil {
return nil, nil, err
}
bundle.PrivateKeyType = key.PrivateKeyType
bundle.PrivateKey = key.PrivateKey
}
return issuer, &bundle, nil
}
func GetLegacyCertBundle(ctx context.Context, s logical.Storage) (*IssuerEntry, *certutil.CertBundle, error) {
entry, err := s.Get(ctx, LegacyCertBundlePath)
if err != nil {
return nil, nil, err
}
if entry == nil {
return nil, nil, nil
}
cb := &certutil.CertBundle{}
err = entry.DecodeJSON(cb)
if err != nil {
return nil, nil, err
}
// Fake a storage entry with backwards compatibility in mind.
issuer := &IssuerEntry{
ID: LegacyBundleShimID,
KeyID: LegacyBundleShimKeyID,
Name: "legacy-entry-shim",
Certificate: cb.Certificate,
CAChain: cb.CAChain,
SerialNumber: cb.SerialNumber,
LeafNotAfterBehavior: certutil.ErrNotAfterBehavior,
}
issuer.Usage.ToggleUsage(AllIssuerUsages)
return issuer, cb, nil
}

View File

@@ -0,0 +1,153 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
const (
KeyPrefix = "config/key/"
KeyRefNotFound = KeyID("not-found")
)
type KeyID string
func (p KeyID) String() string {
return string(p)
}
type KeyEntry struct {
ID KeyID `json:"id"`
Name string `json:"name"`
PrivateKeyType certutil.PrivateKeyType `json:"private_key_type"`
PrivateKey string `json:"private_key"`
}
func (e KeyEntry) IsManagedPrivateKey() bool {
return e.PrivateKeyType == certutil.ManagedPrivateKey
}
func ListKeys(ctx context.Context, s logical.Storage) ([]KeyID, error) {
strList, err := s.List(ctx, KeyPrefix)
if err != nil {
return nil, err
}
keyIds := make([]KeyID, 0, len(strList))
for _, entry := range strList {
keyIds = append(keyIds, KeyID(entry))
}
return keyIds, nil
}
func FetchKeyById(ctx context.Context, s logical.Storage, keyId KeyID) (*KeyEntry, error) {
if len(keyId) == 0 {
return nil, errutil.InternalError{Err: "unable to fetch pki key: empty key identifier"}
}
entry, err := s.Get(ctx, KeyPrefix+keyId.String())
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki key: %v", err)}
}
if entry == nil {
return nil, errutil.UserError{Err: fmt.Sprintf("pki key id %s does not exist", keyId.String())}
}
var key KeyEntry
if err := entry.DecodeJSON(&key); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki key with id %s: %v", keyId.String(), err)}
}
return &key, nil
}
func WriteKey(ctx context.Context, s logical.Storage, key KeyEntry) error {
keyId := key.ID
json, err := logical.StorageEntryJSON(KeyPrefix+keyId.String(), key)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func DeleteKey(ctx context.Context, s logical.Storage, id KeyID) (bool, error) {
config, err := GetKeysConfig(ctx, s)
if err != nil {
return false, err
}
wasDefault := false
if config.DefaultKeyId == id {
wasDefault = true
config.DefaultKeyId = KeyID("")
if err := SetKeysConfig(ctx, s, config); err != nil {
return wasDefault, err
}
}
return wasDefault, s.Delete(ctx, KeyPrefix+id.String())
}
func ResolveKeyReference(ctx context.Context, s logical.Storage, reference string) (KeyID, error) {
if reference == DefaultRef {
// Handle fetching the default key.
config, err := GetKeysConfig(ctx, s)
if err != nil {
return KeyID("config-error"), err
}
if len(config.DefaultKeyId) == 0 {
return KeyRefNotFound, fmt.Errorf("no default key currently configured")
}
return config.DefaultKeyId, nil
}
// Lookup by a direct get first to see if our reference is an ID, this is quick and cached.
if len(reference) == uuidLength {
entry, err := s.Get(ctx, KeyPrefix+reference)
if err != nil {
return KeyID("key-read"), err
}
if entry != nil {
return KeyID(reference), nil
}
}
// ... than to pull all keys from storage.
keys, err := ListKeys(ctx, s)
if err != nil {
return KeyID("list-error"), err
}
for _, keyId := range keys {
key, err := FetchKeyById(ctx, s, keyId)
if err != nil {
return KeyID("key-read"), err
}
if key.Name == reference {
return key.ID, nil
}
}
// Otherwise, we must not have found the key.
return KeyRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)}
}
func GetManagedKeyUUID(key *KeyEntry) (managed_key.UUIDKey, error) {
if !key.IsManagedPrivateKey() {
return "", errutil.InternalError{Err: "getManagedKeyUUID called on a key id %s (%s) "}
}
return managed_key.ExtractManagedKeyId([]byte(key.PrivateKey))
}

View File

@@ -0,0 +1,452 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
var (
DefaultRoleKeyUsages = []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"}
DefaultRoleEstKeyUsages = []string{}
DefaultRoleEstKeyUsageOids = []string{}
)
const (
DefaultRoleSignatureBits = 0
DefaultRoleUsePss = false
)
type RoleEntry struct {
LeaseMax string `json:"lease_max"`
Lease string `json:"lease"`
DeprecatedMaxTTL string `json:"max_ttl"`
DeprecatedTTL string `json:"ttl"`
TTL time.Duration `json:"ttl_duration"`
MaxTTL time.Duration `json:"max_ttl_duration"`
AllowLocalhost bool `json:"allow_localhost"`
AllowedBaseDomain string `json:"allowed_base_domain"`
AllowedDomainsOld string `json:"allowed_domains,omitempty"`
AllowedDomains []string `json:"allowed_domains_list"`
AllowedDomainsTemplate bool `json:"allowed_domains_template"`
AllowBaseDomain bool `json:"allow_base_domain"`
AllowBareDomains bool `json:"allow_bare_domains"`
AllowTokenDisplayName bool `json:"allow_token_displayname"`
AllowSubdomains bool `json:"allow_subdomains"`
AllowGlobDomains bool `json:"allow_glob_domains"`
AllowWildcardCertificates *bool `json:"allow_wildcard_certificates,omitempty"`
AllowAnyName bool `json:"allow_any_name"`
EnforceHostnames bool `json:"enforce_hostnames"`
AllowIPSANs bool `json:"allow_ip_sans"`
ServerFlag bool `json:"server_flag"`
ClientFlag bool `json:"client_flag"`
CodeSigningFlag bool `json:"code_signing_flag"`
EmailProtectionFlag bool `json:"email_protection_flag"`
UseCSRCommonName bool `json:"use_csr_common_name"`
UseCSRSANs bool `json:"use_csr_sans"`
KeyType string `json:"key_type"`
KeyBits int `json:"key_bits"`
UsePSS bool `json:"use_pss"`
SignatureBits int `json:"signature_bits"`
MaxPathLength *int `json:",omitempty"`
KeyUsageOld string `json:"key_usage,omitempty"`
KeyUsage []string `json:"key_usage_list"`
ExtKeyUsage []string `json:"extended_key_usage_list"`
OUOld string `json:"ou,omitempty"`
OU []string `json:"ou_list"`
OrganizationOld string `json:"organization,omitempty"`
Organization []string `json:"organization_list"`
Country []string `json:"country"`
Locality []string `json:"locality"`
Province []string `json:"province"`
StreetAddress []string `json:"street_address"`
PostalCode []string `json:"postal_code"`
GenerateLease *bool `json:"generate_lease,omitempty"`
NoStore bool `json:"no_store"`
RequireCN bool `json:"require_cn"`
CNValidations []string `json:"cn_validations"`
AllowedOtherSANs []string `json:"allowed_other_sans"`
AllowedSerialNumbers []string `json:"allowed_serial_numbers"`
AllowedUserIDs []string `json:"allowed_user_ids"`
AllowedURISANs []string `json:"allowed_uri_sans"`
AllowedURISANsTemplate bool `json:"allowed_uri_sans_template"`
PolicyIdentifiers []string `json:"policy_identifiers"`
ExtKeyUsageOIDs []string `json:"ext_key_usage_oids"`
BasicConstraintsValidForNonCA bool `json:"basic_constraints_valid_for_non_ca"`
NotBeforeDuration time.Duration `json:"not_before_duration"`
NotAfter string `json:"not_after"`
Issuer string `json:"issuer"`
// Name is only set when the role has been stored, on the fly roles have a blank name
Name string `json:"-"`
// WasModified indicates to callers if the returned entry is different than the persisted version
WasModified bool `json:"-"`
}
func (r *RoleEntry) ToResponseData() map[string]interface{} {
responseData := map[string]interface{}{
"ttl": int64(r.TTL.Seconds()),
"max_ttl": int64(r.MaxTTL.Seconds()),
"allow_localhost": r.AllowLocalhost,
"allowed_domains": r.AllowedDomains,
"allowed_domains_template": r.AllowedDomainsTemplate,
"allow_bare_domains": r.AllowBareDomains,
"allow_token_displayname": r.AllowTokenDisplayName,
"allow_subdomains": r.AllowSubdomains,
"allow_glob_domains": r.AllowGlobDomains,
"allow_wildcard_certificates": r.AllowWildcardCertificates,
"allow_any_name": r.AllowAnyName,
"allowed_uri_sans_template": r.AllowedURISANsTemplate,
"enforce_hostnames": r.EnforceHostnames,
"allow_ip_sans": r.AllowIPSANs,
"server_flag": r.ServerFlag,
"client_flag": r.ClientFlag,
"code_signing_flag": r.CodeSigningFlag,
"email_protection_flag": r.EmailProtectionFlag,
"use_csr_common_name": r.UseCSRCommonName,
"use_csr_sans": r.UseCSRSANs,
"key_type": r.KeyType,
"key_bits": r.KeyBits,
"signature_bits": r.SignatureBits,
"use_pss": r.UsePSS,
"key_usage": r.KeyUsage,
"ext_key_usage": r.ExtKeyUsage,
"ext_key_usage_oids": r.ExtKeyUsageOIDs,
"ou": r.OU,
"organization": r.Organization,
"country": r.Country,
"locality": r.Locality,
"province": r.Province,
"street_address": r.StreetAddress,
"postal_code": r.PostalCode,
"no_store": r.NoStore,
"allowed_other_sans": r.AllowedOtherSANs,
"allowed_serial_numbers": r.AllowedSerialNumbers,
"allowed_user_ids": r.AllowedUserIDs,
"allowed_uri_sans": r.AllowedURISANs,
"require_cn": r.RequireCN,
"cn_validations": r.CNValidations,
"policy_identifiers": r.PolicyIdentifiers,
"basic_constraints_valid_for_non_ca": r.BasicConstraintsValidForNonCA,
"not_before_duration": int64(r.NotBeforeDuration.Seconds()),
"not_after": r.NotAfter,
"issuer_ref": r.Issuer,
}
if r.MaxPathLength != nil {
responseData["max_path_length"] = r.MaxPathLength
}
if r.GenerateLease != nil {
responseData["generate_lease"] = r.GenerateLease
}
return responseData
}
var ErrRoleNotFound = errors.New("role not found")
// GetRole will load a role from storage based on the provided name and
// update its contents to the latest version if out of date. The WasUpdated field
// will be set to true if modifications were made indicating the caller should if
// possible write them back to disk. If the role is not found an ErrRoleNotFound
// will be returned as an error.
func GetRole(ctx context.Context, s logical.Storage, n string) (*RoleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, fmt.Errorf("failed to load role %s: %w", n, err)
}
if entry == nil {
return nil, fmt.Errorf("%w: with name %s", ErrRoleNotFound, n)
}
var result RoleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, fmt.Errorf("failed decoding role %s: %w", n, err)
}
// Migrate existing saved entries and save back if changed
modified := false
if len(result.DeprecatedTTL) == 0 && len(result.Lease) != 0 {
result.DeprecatedTTL = result.Lease
result.Lease = ""
modified = true
}
if result.TTL == 0 && len(result.DeprecatedTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedTTL)
if err != nil {
return nil, err
}
result.TTL = parsed
result.DeprecatedTTL = ""
modified = true
}
if len(result.DeprecatedMaxTTL) == 0 && len(result.LeaseMax) != 0 {
result.DeprecatedMaxTTL = result.LeaseMax
result.LeaseMax = ""
modified = true
}
if result.MaxTTL == 0 && len(result.DeprecatedMaxTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedMaxTTL)
if err != nil {
return nil, fmt.Errorf("failed parsing max_ttl field in %s: %w", n, err)
}
result.MaxTTL = parsed
result.DeprecatedMaxTTL = ""
modified = true
}
if result.AllowBaseDomain {
result.AllowBaseDomain = false
result.AllowBareDomains = true
modified = true
}
if result.AllowedDomainsOld != "" {
result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",")
result.AllowedDomainsOld = ""
modified = true
}
if result.AllowedBaseDomain != "" {
found := false
for _, v := range result.AllowedDomains {
if v == result.AllowedBaseDomain {
found = true
break
}
}
if !found {
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
}
result.AllowedBaseDomain = ""
modified = true
}
if result.AllowWildcardCertificates == nil {
// While not the most secure default, when AllowWildcardCertificates isn't
// explicitly specified in the stored Role, we automatically upgrade it to
// true to preserve compatibility with previous versions of Vault. Once this
// field is set, this logic will not be triggered any more.
result.AllowWildcardCertificates = new(bool)
*result.AllowWildcardCertificates = true
modified = true
}
// Upgrade generate_lease in role
if result.GenerateLease == nil {
// All the new roles will have GenerateLease always set to a Value. A
// nil Value indicates that this role needs an upgrade. Set it to
// `true` to not alter its current behavior.
result.GenerateLease = new(bool)
*result.GenerateLease = true
modified = true
}
// Upgrade key usages
if result.KeyUsageOld != "" {
result.KeyUsage = strings.Split(result.KeyUsageOld, ",")
result.KeyUsageOld = ""
modified = true
}
// Upgrade OU
if result.OUOld != "" {
result.OU = strings.Split(result.OUOld, ",")
result.OUOld = ""
modified = true
}
// Upgrade Organization
if result.OrganizationOld != "" {
result.Organization = strings.Split(result.OrganizationOld, ",")
result.OrganizationOld = ""
modified = true
}
// Set the issuer field to default if not set. We want to do this
// unconditionally as we should probably never have an empty issuer
// on a stored roles.
if len(result.Issuer) == 0 {
result.Issuer = DefaultRef
modified = true
}
// Update CN Validations to be the present default, "email,hostname"
if len(result.CNValidations) == 0 {
result.CNValidations = []string{"email", "hostname"}
modified = true
}
result.Name = n
result.WasModified = modified
return &result, nil
}
type RoleModifier func(r *RoleEntry)
func WithKeyUsage(keyUsages []string) RoleModifier {
return func(r *RoleEntry) {
r.KeyUsage = keyUsages
}
}
func WithExtKeyUsage(extKeyUsages []string) RoleModifier {
return func(r *RoleEntry) {
r.ExtKeyUsage = extKeyUsages
}
}
func WithExtKeyUsageOIDs(extKeyUsageOids []string) RoleModifier {
return func(r *RoleEntry) {
r.ExtKeyUsageOIDs = extKeyUsageOids
}
}
func WithSignatureBits(signatureBits int) RoleModifier {
return func(r *RoleEntry) {
r.SignatureBits = signatureBits
}
}
func WithUsePSS(usePss bool) RoleModifier {
return func(r *RoleEntry) {
r.UsePSS = usePss
}
}
func WithTTL(ttl time.Duration) RoleModifier {
return func(r *RoleEntry) {
r.TTL = ttl
}
}
func WithMaxTTL(ttl time.Duration) RoleModifier {
return func(r *RoleEntry) {
r.MaxTTL = ttl
}
}
func WithGenerateLease(genLease bool) RoleModifier {
return func(r *RoleEntry) {
*r.GenerateLease = genLease
}
}
func WithNotBeforeDuration(ttl time.Duration) RoleModifier {
return func(r *RoleEntry) {
r.NotBeforeDuration = ttl
}
}
func WithNoStore(noStore bool) RoleModifier {
return func(r *RoleEntry) {
r.NoStore = noStore
}
}
func WithIssuer(issuer string) RoleModifier {
return func(r *RoleEntry) {
if issuer == "" {
issuer = DefaultRef
}
r.Issuer = issuer
}
}
// SignVerbatimRole create a sign-verbatim role with no overrides. This will store
// the signed certificate, allowing any key type and Value from a role restriction.
func SignVerbatimRole() *RoleEntry {
return SignVerbatimRoleWithOpts()
}
// SignVerbatimRoleWithOpts create a sign-verbatim role with the normal defaults,
// but allowing any field to be tweaked based on the consumers needs.
func SignVerbatimRoleWithOpts(opts ...RoleModifier) *RoleEntry {
entry := &RoleEntry{
AllowLocalhost: true,
AllowAnyName: true,
AllowIPSANs: true,
AllowWildcardCertificates: new(bool),
EnforceHostnames: false,
KeyType: "any",
UseCSRCommonName: true,
UseCSRSANs: true,
AllowedOtherSANs: []string{"*"},
AllowedSerialNumbers: []string{"*"},
AllowedURISANs: []string{"*"},
AllowedUserIDs: []string{"*"},
CNValidations: []string{"disabled"},
GenerateLease: new(bool),
KeyUsage: DefaultRoleKeyUsages,
ExtKeyUsage: DefaultRoleEstKeyUsages,
ExtKeyUsageOIDs: DefaultRoleEstKeyUsageOids,
SignatureBits: DefaultRoleSignatureBits,
UsePSS: DefaultRoleUsePss,
}
*entry.AllowWildcardCertificates = true
*entry.GenerateLease = false
if opts != nil {
for _, opt := range opts {
if opt != nil {
opt(entry)
}
}
}
return entry
}
func ParseExtKeyUsagesFromRole(role *RoleEntry) certutil.CertExtKeyUsage {
var parsedKeyUsages certutil.CertExtKeyUsage
if role.ServerFlag {
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
}
if role.ClientFlag {
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
}
if role.CodeSigningFlag {
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
}
if role.EmailProtectionFlag {
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
}
for _, k := range role.ExtKeyUsage {
switch strings.ToLower(strings.TrimSpace(k)) {
case "any":
parsedKeyUsages |= certutil.AnyExtKeyUsage
case "serverauth":
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
case "clientauth":
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
case "codesigning":
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
case "emailprotection":
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
case "ipsecendsystem":
parsedKeyUsages |= certutil.IpsecEndSystemExtKeyUsage
case "ipsectunnel":
parsedKeyUsages |= certutil.IpsecTunnelExtKeyUsage
case "ipsecuser":
parsedKeyUsages |= certutil.IpsecUserExtKeyUsage
case "timestamping":
parsedKeyUsages |= certutil.TimeStampingExtKeyUsage
case "ocspsigning":
parsedKeyUsages |= certutil.OcspSigningExtKeyUsage
case "microsoftservergatedcrypto":
parsedKeyUsages |= certutil.MicrosoftServerGatedCryptoExtKeyUsage
case "netscapeservergatedcrypto":
parsedKeyUsages |= certutil.NetscapeServerGatedCryptoExtKeyUsage
}
}
return parsedKeyUsages
}

View File

@@ -0,0 +1,291 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
type SignCertInput interface {
CreationBundleInput
GetCSR() (*x509.CertificateRequest, error)
IsCA() bool
UseCSRValues() bool
GetPermittedDomains() []string
}
func NewBasicSignCertInput(csr *x509.CertificateRequest, isCA bool, useCSRValues bool) BasicSignCertInput {
return BasicSignCertInput{
isCA: isCA,
useCSRValues: useCSRValues,
csr: csr,
}
}
var _ SignCertInput = BasicSignCertInput{}
type BasicSignCertInput struct {
isCA bool
useCSRValues bool
csr *x509.CertificateRequest
}
func (b BasicSignCertInput) GetTTL() int {
return 0
}
func (b BasicSignCertInput) GetOptionalNotAfter() (interface{}, bool) {
return "", false
}
func (b BasicSignCertInput) GetCommonName() string {
return ""
}
func (b BasicSignCertInput) GetSerialNumber() string {
return ""
}
func (b BasicSignCertInput) GetExcludeCnFromSans() bool {
return false
}
func (b BasicSignCertInput) GetOptionalAltNames() (interface{}, bool) {
return []string{}, false
}
func (b BasicSignCertInput) GetOtherSans() []string {
return []string{}
}
func (b BasicSignCertInput) GetIpSans() []string {
return []string{}
}
func (b BasicSignCertInput) GetURISans() []string {
return []string{}
}
func (b BasicSignCertInput) GetOptionalSkid() (interface{}, bool) {
return "", false
}
func (b BasicSignCertInput) IsUserIdInSchema() (interface{}, bool) {
return []string{}, false
}
func (b BasicSignCertInput) GetUserIds() []string {
return []string{}
}
func (b BasicSignCertInput) GetCSR() (*x509.CertificateRequest, error) {
return b.csr, nil
}
func (b BasicSignCertInput) IsCA() bool {
return b.isCA
}
func (b BasicSignCertInput) UseCSRValues() bool {
return b.useCSRValues
}
func (b BasicSignCertInput) GetPermittedDomains() []string {
return []string{}
}
func SignCert(b logical.SystemView, role *RoleEntry, entityInfo EntityInfo, caSign *certutil.CAInfoBundle, signInput SignCertInput) (*certutil.ParsedCertBundle, []string, error) {
if role == nil {
return nil, nil, errutil.InternalError{Err: "no role found in data bundle"}
}
csr, err := signInput.GetCSR()
if err != nil {
return nil, nil, err
}
if csr.PublicKeyAlgorithm == x509.UnknownPublicKeyAlgorithm || csr.PublicKey == nil {
return nil, nil, errutil.UserError{Err: "Refusing to sign CSR with empty PublicKey. This usually means the SubjectPublicKeyInfo field has an OID not recognized by Go, such as 1.2.840.113549.1.1.10 for rsaPSS."}
}
// This switch validates that the CSR key type matches the role and sets
// the Value in the actualKeyType/actualKeyBits values.
actualKeyType := ""
actualKeyBits := 0
switch role.KeyType {
case "rsa":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.RSA {
return nil, nil, errutil.UserError{Err: fmt.Sprintf("role requires keys of type %s", role.KeyType)}
}
pubKey, ok := csr.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "rsa"
actualKeyBits = pubKey.N.BitLen()
case "ec":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.ECDSA {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires keys of type %s",
role.KeyType)}
}
pubKey, ok := csr.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ec"
actualKeyBits = pubKey.Params().BitSize
case "ed25519":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.Ed25519 {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires keys of type %s",
role.KeyType)}
}
_, ok := csr.PublicKey.(ed25519.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ed25519"
actualKeyBits = 0
case "any":
// We need to compute the actual key type and key bits, to correctly
// validate minimums and SignatureBits below.
switch csr.PublicKeyAlgorithm {
case x509.RSA:
pubKey, ok := csr.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
if pubKey.N.BitLen() < 2048 {
return nil, nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
}
actualKeyType = "rsa"
actualKeyBits = pubKey.N.BitLen()
case x509.ECDSA:
pubKey, ok := csr.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ec"
actualKeyBits = pubKey.Params().BitSize
case x509.Ed25519:
_, ok := csr.PublicKey.(ed25519.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ed25519"
actualKeyBits = 0
default:
return nil, nil, errutil.UserError{Err: "Unknown key type in CSR: " + csr.PublicKeyAlgorithm.String()}
}
default:
return nil, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported key type Value: %s", role.KeyType)}
}
// Before validating key lengths, update our KeyBits/SignatureBits based
// on the actual CSR key type.
if role.KeyType == "any" {
// We update the Value of KeyBits and SignatureBits here (from the
// role), using the specified key type. This allows us to convert
// the default Value (0) for SignatureBits and KeyBits to a
// meaningful Value.
//
// We ignore the role's original KeyBits Value if the KeyType is any
// as legacy (pre-1.10) roles had default values that made sense only
// for RSA keys (key_bits=2048) and the older code paths ignored the role Value
// set for KeyBits when KeyType was set to any. This also enforces the
// docs saying when key_type=any, we only enforce our specified minimums
// for signing operations
var err error
if role.KeyBits, role.SignatureBits, err = certutil.ValidateDefaultOrValueKeyTypeSignatureLength(
actualKeyType, 0, role.SignatureBits); err != nil {
return nil, nil, errutil.InternalError{Err: fmt.Sprintf("unknown internal error updating default values: %v", err)}
}
// We're using the KeyBits field as a minimum Value below, and P-224 is safe
// and a previously allowed Value. However, the above call defaults
// to P-256 as that's a saner default than P-224 (w.r.t. generation), so
// override it here to allow 224 as the smallest size we permit.
if actualKeyType == "ec" {
role.KeyBits = 224
}
}
// At this point, role.KeyBits and role.SignatureBits should both
// be non-zero, for RSA and ECDSA keys. Validate the actualKeyBits based on
// the role's values. If the KeyType was any, and KeyBits was set to 0,
// KeyBits should be updated to 2048 unless some other Value was chosen
// explicitly.
//
// This validation needs to occur regardless of the role's key type, so
// that we always validate both RSA and ECDSA key sizes.
if actualKeyType == "rsa" {
if actualKeyBits < role.KeyBits {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires a minimum of a %d-bit key, but CSR's key is %d bits",
role.KeyBits, actualKeyBits)}
}
if actualKeyBits < 2048 {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"Vault requires a minimum of a 2048-bit key, but CSR's key is %d bits",
actualKeyBits)}
}
} else if actualKeyType == "ec" {
if actualKeyBits < role.KeyBits {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires a minimum of a %d-bit key, but CSR's key is %d bits",
role.KeyBits,
actualKeyBits)}
}
}
creation, warnings, err := GenerateCreationBundle(b, role, entityInfo, signInput, caSign, csr)
if err != nil {
return nil, nil, err
}
if creation.Params == nil {
return nil, nil, errutil.InternalError{Err: "nil parameters received from parameter bundle generation"}
}
creation.Params.IsCA = signInput.IsCA()
creation.Params.UseCSRValues = signInput.UseCSRValues()
if signInput.IsCA() {
creation.Params.PermittedDNSDomains = signInput.GetPermittedDomains()
} else {
for _, ext := range csr.Extensions {
if ext.Id.Equal(certutil.ExtensionBasicConstraintsOID) {
warnings = append(warnings, "specified CSR contained a Basic Constraints extension that was ignored during issuance")
}
}
}
parsedBundle, err := certutil.SignCertificate(creation)
if err != nil {
return nil, nil, err
}
return parsedBundle, warnings, nil
}

View File

@@ -12,9 +12,12 @@ import (
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
func comparePublicKey(sc *storageContext, key *keyEntry, publicKey crypto.PublicKey) (bool, error) {
func comparePublicKey(sc *storageContext, key *issuing.KeyEntry, publicKey crypto.PublicKey) (bool, error) {
publicKeyForKeyEntry, err := getPublicKey(sc.Context, sc.Backend, key)
if err != nil {
return false, err
@@ -23,13 +26,9 @@ func comparePublicKey(sc *storageContext, key *keyEntry, publicKey crypto.Public
return certutil.ComparePublicKeysAndType(publicKeyForKeyEntry, publicKey)
}
func getPublicKey(ctx context.Context, b *backend, key *keyEntry) (crypto.PublicKey, error) {
func getPublicKey(ctx context.Context, b *backend, key *issuing.KeyEntry) (crypto.PublicKey, error) {
if key.PrivateKeyType == certutil.ManagedPrivateKey {
keyId, err := extractManagedKeyId([]byte(key.PrivateKey))
if err != nil {
return nil, err
}
return getManagedKeyPublicKey(ctx, b, keyId)
return managed_key.GetPublicKeyFromKeyBytes(ctx, b, []byte(key.PrivateKey))
}
signer, _, _, err := getSignerFromKeyEntryBytes(key)
@@ -39,7 +38,7 @@ func getPublicKey(ctx context.Context, b *backend, key *keyEntry) (crypto.Public
return signer.Public(), nil
}
func getSignerFromKeyEntryBytes(key *keyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
func getSignerFromKeyEntryBytes(key *issuing.KeyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
if key.PrivateKeyType == certutil.UnknownPrivateKey {
return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported unknown private key type for key: %s (%s)", key.ID, key.Name)}
}
@@ -78,7 +77,7 @@ func getPublicKeyFromBytes(keyBytes []byte) (crypto.PublicKey, error) {
return signer.Public(), nil
}
func importKeyFromBytes(sc *storageContext, keyValue string, keyName string) (*keyEntry, bool, error) {
func importKeyFromBytes(sc *storageContext, keyValue string, keyName string) (*issuing.KeyEntry, bool, error) {
signer, _, _, err := getSignerFromBytes([]byte(keyValue))
if err != nil {
return nil, false, err

View File

@@ -0,0 +1,43 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package managed_key
import (
"crypto"
"io"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
type ManagedKeyInfo struct {
publicKey crypto.PublicKey
KeyType certutil.PrivateKeyType
Name NameKey
Uuid UUIDKey
}
type managedKeyId interface {
String() string
}
type PkiManagedKeyView interface {
BackendUUID() string
IsSecondaryNode() bool
GetManagedKeyView() (logical.ManagedKeySystemView, error)
GetRandomReader() io.Reader
}
type (
UUIDKey string
NameKey string
)
func (u UUIDKey) String() string {
return string(u)
}
func (n NameKey) String() string {
return string(n)
}

View File

@@ -0,0 +1,49 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package managed_key
import (
"context"
"crypto"
"errors"
"io"
"github.com/hashicorp/vault/sdk/helper/certutil"
)
var errEntOnly = errors.New("managed keys are supported within enterprise edition only")
func GetPublicKeyFromKeyBytes(ctx context.Context, mkv PkiManagedKeyView, keyBytes []byte) (crypto.PublicKey, error) {
return nil, errEntOnly
}
func GenerateManagedKeyCABundle(ctx context.Context, b PkiManagedKeyView, keyId managedKeyId, data *certutil.CreationBundle, randomSource io.Reader) (bundle *certutil.ParsedCertBundle, err error) {
return nil, errEntOnly
}
func GenerateManagedKeyCSRBundle(ctx context.Context, b PkiManagedKeyView, keyId managedKeyId, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (bundle *certutil.ParsedCSRBundle, err error) {
return nil, errEntOnly
}
func GetManagedKeyPublicKey(ctx context.Context, b PkiManagedKeyView, keyId managedKeyId) (crypto.PublicKey, error) {
return nil, errEntOnly
}
func ParseManagedKeyCABundle(ctx context.Context, mkv PkiManagedKeyView, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return nil, errEntOnly
}
func ExtractManagedKeyId(privateKeyBytes []byte) (UUIDKey, error) {
return "", errEntOnly
}
func CreateKmsKeyBundle(ctx context.Context, mkv PkiManagedKeyView, keyId managedKeyId) (certutil.KeyBundle, certutil.PrivateKeyType, error) {
return certutil.KeyBundle{}, certutil.UnknownPrivateKey, errEntOnly
}
func GetManagedKeyInfo(ctx context.Context, mkv PkiManagedKeyView, keyId managedKeyId) (*ManagedKeyInfo, error) {
return nil, errEntOnly
}

View File

@@ -1,45 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package pki
import (
"context"
"crypto"
"errors"
"io"
"github.com/hashicorp/vault/sdk/helper/certutil"
)
var errEntOnly = errors.New("managed keys are supported within enterprise edition only")
func generateManagedKeyCABundle(ctx context.Context, b *backend, keyId managedKeyId, data *certutil.CreationBundle, randomSource io.Reader) (bundle *certutil.ParsedCertBundle, err error) {
return nil, errEntOnly
}
func generateManagedKeyCSRBundle(ctx context.Context, b *backend, keyId managedKeyId, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (bundle *certutil.ParsedCSRBundle, err error) {
return nil, errEntOnly
}
func getManagedKeyPublicKey(ctx context.Context, b *backend, keyId managedKeyId) (crypto.PublicKey, error) {
return nil, errEntOnly
}
func parseManagedKeyCABundle(ctx context.Context, b *backend, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return nil, errEntOnly
}
func extractManagedKeyId(privateKeyBytes []byte) (UUIDKey, error) {
return "", errEntOnly
}
func createKmsKeyBundle(ctx context.Context, b *backend, keyId managedKeyId) (certutil.KeyBundle, certutil.PrivateKeyType, error) {
return certutil.KeyBundle{}, certutil.UnknownPrivateKey, errEntOnly
}
func getManagedKeyInfo(ctx context.Context, b *backend, keyId managedKeyId) (*managedKeyInfo, error) {
return nil, errEntOnly
}

View File

@@ -0,0 +1,263 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package pki
import (
"errors"
"sort"
"strings"
"sync/atomic"
"github.com/armon/go-metrics"
)
type CertificateCounter struct {
certCountEnabled *atomic.Bool
publishCertCountMetrics *atomic.Bool
certCount *atomic.Uint32
revokedCertCount *atomic.Uint32
certsCounted *atomic.Bool
certCountError error
possibleDoubleCountedSerials []string
possibleDoubleCountedRevokedSerials []string
backendUuid string
}
func (c *CertificateCounter) IsInitialized() bool {
return c.certsCounted.Load()
}
func (c *CertificateCounter) IsEnabled() bool {
return c.certCountEnabled.Load()
}
func (c *CertificateCounter) Error() error {
return c.certCountError
}
func (c *CertificateCounter) SetError(err error) {
c.certCountError = err
}
func (c *CertificateCounter) ReconfigureWithTidyConfig(config *tidyConfig) bool {
if config.MaintainCount {
c.enableCertCounting(config.PublishMetrics)
} else {
c.disableCertCounting()
}
return config.MaintainCount
}
func (c *CertificateCounter) disableCertCounting() {
c.possibleDoubleCountedRevokedSerials = nil
c.possibleDoubleCountedSerials = nil
c.certsCounted.Store(false)
c.certCount.Store(0)
c.revokedCertCount.Store(0)
c.certCountError = errors.New("Cert Count is Disabled: enable via Tidy Config maintain_stored_certificate_counts")
c.certCountEnabled.Store(false)
c.publishCertCountMetrics.Store(false)
}
func (c *CertificateCounter) enableCertCounting(publishMetrics bool) {
c.publishCertCountMetrics.Store(publishMetrics)
c.certCountEnabled.Store(true)
if !c.certsCounted.Load() {
c.certCountError = errors.New("Certificate Counting Has Not Been Initialized, re-initialize this mount")
}
}
func (c *CertificateCounter) InitializeCountsFromStorage(certs, revoked []string) {
c.certCount.Add(uint32(len(certs)))
c.revokedCertCount.Add(uint32(len(revoked)))
c.pruneDuplicates(certs, revoked)
c.certCountError = nil
c.certsCounted.Store(true)
c.emitTotalCertCountMetric()
}
func (c *CertificateCounter) pruneDuplicates(entries, revokedEntries []string) {
// Now that the metrics are set, we can switch from appending newly-stored certificates to the possible double-count
// list, and instead have them update the counter directly. We need to do this so that we are looking at a static
// slice of possibly double counted serials. Note that certsCounted is computed before the storage operation, so
// there may be some delay here.
// Sort the listed-entries first, to accommodate that delay.
sort.Slice(entries, func(i, j int) bool {
return entries[i] < entries[j]
})
sort.Slice(revokedEntries, func(i, j int) bool {
return revokedEntries[i] < revokedEntries[j]
})
// We assume here that these lists are now complete.
sort.Slice(c.possibleDoubleCountedSerials, func(i, j int) bool {
return c.possibleDoubleCountedSerials[i] < c.possibleDoubleCountedSerials[j]
})
listEntriesIndex := 0
possibleDoubleCountIndex := 0
for {
if listEntriesIndex >= len(entries) {
break
}
if possibleDoubleCountIndex >= len(c.possibleDoubleCountedSerials) {
break
}
if entries[listEntriesIndex] == c.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
// This represents a double-counted entry
c.decrementTotalCertificatesCountNoReport()
listEntriesIndex = listEntriesIndex + 1
possibleDoubleCountIndex = possibleDoubleCountIndex + 1
continue
}
if entries[listEntriesIndex] < c.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
listEntriesIndex = listEntriesIndex + 1
continue
}
if entries[listEntriesIndex] > c.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
possibleDoubleCountIndex = possibleDoubleCountIndex + 1
continue
}
}
sort.Slice(c.possibleDoubleCountedRevokedSerials, func(i, j int) bool {
return c.possibleDoubleCountedRevokedSerials[i] < c.possibleDoubleCountedRevokedSerials[j]
})
listRevokedEntriesIndex := 0
possibleRevokedDoubleCountIndex := 0
for {
if listRevokedEntriesIndex >= len(revokedEntries) {
break
}
if possibleRevokedDoubleCountIndex >= len(c.possibleDoubleCountedRevokedSerials) {
break
}
if revokedEntries[listRevokedEntriesIndex] == c.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
// This represents a double-counted revoked entry
c.decrementTotalRevokedCertificatesCountNoReport()
listRevokedEntriesIndex = listRevokedEntriesIndex + 1
possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
continue
}
if revokedEntries[listRevokedEntriesIndex] < c.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
listRevokedEntriesIndex = listRevokedEntriesIndex + 1
continue
}
if revokedEntries[listRevokedEntriesIndex] > c.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
continue
}
}
c.possibleDoubleCountedRevokedSerials = nil
c.possibleDoubleCountedSerials = nil
}
func (c *CertificateCounter) decrementTotalCertificatesCountNoReport() uint32 {
newCount := c.certCount.Add(^uint32(0))
return newCount
}
func (c *CertificateCounter) decrementTotalRevokedCertificatesCountNoReport() uint32 {
newRevokedCertCount := c.revokedCertCount.Add(^uint32(0))
return newRevokedCertCount
}
func (c *CertificateCounter) CertificateCount() uint32 {
return c.certCount.Load()
}
func (c *CertificateCounter) RevokedCount() uint32 {
return c.revokedCertCount.Load()
}
func (c *CertificateCounter) IncrementTotalCertificatesCount(certsCounted bool, newSerial string) {
if c.certCountEnabled.Load() {
c.certCount.Add(1)
switch {
case !certsCounted:
// This is unsafe, but a good best-attempt
if strings.HasPrefix(newSerial, "certs/") {
newSerial = newSerial[6:]
}
c.possibleDoubleCountedSerials = append(c.possibleDoubleCountedSerials, newSerial)
default:
c.emitTotalCertCountMetric()
}
}
}
// The "certsCounted" boolean here should be loaded from the backend certsCounted before the corresponding storage call:
// eg. certsCounted := certCounter.IsInitialized()
func (c *CertificateCounter) IncrementTotalRevokedCertificatesCount(certsCounted bool, newSerial string) {
if c.certCountEnabled.Load() {
c.revokedCertCount.Add(1)
switch {
case !certsCounted:
// This is unsafe, but a good best-attempt
if strings.HasPrefix(newSerial, "revoked/") { // allow passing in the path (revoked/serial) OR the serial
newSerial = newSerial[8:]
}
c.possibleDoubleCountedRevokedSerials = append(c.possibleDoubleCountedRevokedSerials, newSerial)
default:
c.emitTotalRevokedCountMetric()
}
}
}
func (c *CertificateCounter) DecrementTotalCertificatesCountReport() {
if c.certCountEnabled.Load() {
c.decrementTotalCertificatesCountNoReport()
c.emitTotalCertCountMetric()
}
}
func (c *CertificateCounter) DecrementTotalRevokedCertificatesCountReport() {
if c.certCountEnabled.Load() {
c.decrementTotalRevokedCertificatesCountNoReport()
c.emitTotalRevokedCountMetric()
}
}
func (c *CertificateCounter) EmitCertStoreMetrics() {
c.emitTotalCertCountMetric()
c.emitTotalRevokedCountMetric()
}
func (c *CertificateCounter) emitTotalCertCountMetric() {
if c.publishCertCountMetrics.Load() {
certCount := float32(c.CertificateCount())
metrics.SetGauge([]string{"secrets", "pki", c.backendUuid, "total_certificates_stored"}, certCount)
}
}
func (c *CertificateCounter) emitTotalRevokedCountMetric() {
if c.publishCertCountMetrics.Load() {
revokedCount := float32(c.RevokedCount())
metrics.SetGauge([]string{"secrets", "pki", c.backendUuid, "total_revoked_certificates_stored"}, revokedCount)
}
}
func NewCertificateCounter(backendUuid string) *CertificateCounter {
counter := &CertificateCounter{
backendUuid: backendUuid,
certCountEnabled: &atomic.Bool{},
publishCertCountMetrics: &atomic.Bool{},
certCount: &atomic.Uint32{},
revokedCertCount: &atomic.Uint32{},
certsCounted: &atomic.Bool{},
certCountError: errors.New("Initialize Not Yet Run, Cert Counts Unavailable"),
possibleDoubleCountedSerials: make([]string, 0, 250),
possibleDoubleCountedRevokedSerials: make([]string, 0, 250),
}
return counter
}

View File

@@ -0,0 +1,74 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package parsing
import (
"crypto/x509"
"fmt"
"strings"
)
func ParseCertificateFromString(pemCert string) (*x509.Certificate, error) {
return ParseCertificateFromBytes([]byte(pemCert))
}
func ParseCertificateFromBytes(certBytes []byte) (*x509.Certificate, error) {
block, err := DecodePem(certBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
return cert, nil
}
func ParseCertificatesFromString(pemCerts string) ([]*x509.Certificate, error) {
return ParseCertificatesFromBytes([]byte(pemCerts))
}
func ParseCertificatesFromBytes(certBytes []byte) ([]*x509.Certificate, error) {
block, err := DecodePem(certBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
cert, err := x509.ParseCertificates(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
return cert, nil
}
func ParseKeyUsages(input []string) int {
var parsedKeyUsages x509.KeyUsage
for _, k := range input {
switch strings.ToLower(strings.TrimSpace(k)) {
case "digitalsignature":
parsedKeyUsages |= x509.KeyUsageDigitalSignature
case "contentcommitment":
parsedKeyUsages |= x509.KeyUsageContentCommitment
case "keyencipherment":
parsedKeyUsages |= x509.KeyUsageKeyEncipherment
case "dataencipherment":
parsedKeyUsages |= x509.KeyUsageDataEncipherment
case "keyagreement":
parsedKeyUsages |= x509.KeyUsageKeyAgreement
case "certsign":
parsedKeyUsages |= x509.KeyUsageCertSign
case "crlsign":
parsedKeyUsages |= x509.KeyUsageCRLSign
case "encipheronly":
parsedKeyUsages |= x509.KeyUsageEncipherOnly
case "decipheronly":
parsedKeyUsages |= x509.KeyUsageDecipherOnly
}
}
return int(parsedKeyUsages)
}

View File

@@ -0,0 +1,27 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package parsing
import (
"crypto/x509"
"fmt"
)
func ParseCertificateRequestFromString(pemCert string) (*x509.CertificateRequest, error) {
return ParseCertificateRequestFromBytes([]byte(pemCert))
}
func ParseCertificateRequestFromBytes(certBytes []byte) (*x509.CertificateRequest, error) {
block, err := DecodePem(certBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate request: %w", err)
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate request: %w", err)
}
return csr, nil
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package parsing
import (
"encoding/pem"
"errors"
"strings"
)
func DecodePem(certBytes []byte) (*pem.Block, error) {
block, extra := pem.Decode(certBytes)
if block == nil {
return nil, errors.New("invalid PEM")
}
if len(strings.TrimSpace(string(extra))) > 0 {
return nil, errors.New("trailing PEM data")
}
return block, nil
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -222,7 +221,7 @@ func (b *backend) acmeAccountSearchHandler(acmeCtx *acmeContext, userCtx *jwsCtx
return nil, fmt.Errorf("failed generating thumbprint for key: %w", err)
}
account, err := b.acmeState.LoadAccountByKey(acmeCtx, thumbprint)
account, err := b.GetAcmeState().LoadAccountByKey(acmeCtx, thumbprint)
if err != nil {
return nil, fmt.Errorf("failed to load account by thumbprint: %w", err)
}
@@ -253,7 +252,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
return nil, fmt.Errorf("failed generating thumbprint for key: %w", err)
}
accountByKey, err := b.acmeState.LoadAccountByKey(acmeCtx, thumbprint)
accountByKey, err := b.GetAcmeState().LoadAccountByKey(acmeCtx, thumbprint)
if err != nil {
return nil, fmt.Errorf("failed to load account by thumbprint: %w", err)
}
@@ -267,7 +266,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
var eab *eabType
if len(eabData) != 0 {
eab, err = verifyEabPayload(b.acmeState, acmeCtx, userCtx, r.Path, eabData)
eab, err = verifyEabPayload(b.GetAcmeState(), acmeCtx, userCtx, r.Path, eabData)
if err != nil {
return nil, err
}
@@ -288,7 +287,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
// We delete the EAB to prevent future re-use after associating it with an account, worst
// case if we fail creating the account we simply nuked the EAB which they can create another
// and retry
wasDeleted, err := b.acmeState.DeleteEab(acmeCtx.sc, eab.KeyID)
wasDeleted, err := b.GetAcmeState().DeleteEab(acmeCtx.sc, eab.KeyID)
if err != nil {
return nil, fmt.Errorf("failed to delete eab reference: %w", err)
}
@@ -302,7 +301,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
b.acmeAccountLock.RLock() // Prevents Account Creation and Tidy Interfering
defer b.acmeAccountLock.RUnlock()
accountByKid, err := b.acmeState.CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed, eab)
accountByKid, err := b.GetAcmeState().CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed, eab)
if err != nil {
if eab != nil {
return nil, fmt.Errorf("failed to create account: %w; the EAB key used for this request has been deleted as a result of this operation; fetch a new EAB key before retrying", err)
@@ -329,7 +328,7 @@ func (b *backend) acmeNewAccountUpdateHandler(acmeCtx *acmeContext, userCtx *jws
return nil, fmt.Errorf("%w: not allowed to update EAB data in accounts", ErrMalformed)
}
account, err := b.acmeState.LoadAccount(acmeCtx, userCtx.Kid)
account, err := b.GetAcmeState().LoadAccount(acmeCtx, userCtx.Kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
@@ -363,7 +362,7 @@ func (b *backend) acmeNewAccountUpdateHandler(acmeCtx *acmeContext, userCtx *jws
}
if shouldUpdate {
err = b.acmeState.UpdateAccount(acmeCtx.sc, account)
err = b.GetAcmeState().UpdateAccount(acmeCtx.sc, account)
if err != nil {
return nil, fmt.Errorf("failed to update account: %w", err)
}

View File

@@ -48,7 +48,7 @@ func patternAcmeAuthorization(b *backend, pattern string, opts acmeWrapperOpts)
func (b *backend) acmeAuthorizationHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, _ *acmeAccount) (*logical.Response, error) {
authId := fields.Get("auth_id").(string)
authz, err := b.acmeState.LoadAuthorization(acmeCtx, userCtx, authId)
authz, err := b.GetAcmeState().LoadAuthorization(acmeCtx, userCtx, authId)
if err != nil {
return nil, fmt.Errorf("failed to load authorization: %w", err)
}
@@ -90,7 +90,7 @@ func (b *backend) acmeAuthorizationDeactivateHandler(acmeCtx *acmeContext, r *lo
challenge.Status = ACMEChallengeInvalid
}
if err := b.acmeState.SaveAuthorization(acmeCtx, authz); err != nil {
if err := b.GetAcmeState().SaveAuthorization(acmeCtx, authz); err != nil {
return nil, fmt.Errorf("error saving deactivated authorization: %w", err)
}

View File

@@ -57,7 +57,7 @@ func (b *backend) acmeChallengeHandler(acmeCtx *acmeContext, r *logical.Request,
authId := fields.Get("auth_id").(string)
challengeType := fields.Get("challenge_type").(string)
authz, err := b.acmeState.LoadAuthorization(acmeCtx, userCtx, authId)
authz, err := b.GetAcmeState().LoadAuthorization(acmeCtx, userCtx, authId)
if err != nil {
return nil, fmt.Errorf("failed to load authorization: %w", err)
}
@@ -95,7 +95,7 @@ func (b *backend) acmeChallengeFetchHandler(acmeCtx *acmeContext, r *logical.Req
return nil, fmt.Errorf("failed to get thumbprint for key: %w", err)
}
if err := b.acmeState.validator.AcceptChallenge(acmeCtx.sc, userCtx.Kid, authz, challenge, thumbprint); err != nil {
if err := b.GetAcmeState().validator.AcceptChallenge(acmeCtx.sc, userCtx.Kid, authz, challenge, thumbprint); err != nil {
return nil, fmt.Errorf("error submitting challenge for validation: %w", err)
}
}

View File

@@ -183,7 +183,8 @@ type eabType struct {
func (b *backend) pathAcmeListEab(ctx context.Context, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage)
eabIds, err := b.acmeState.ListEabIds(sc)
acmeState := b.GetAcmeState()
eabIds, err := acmeState.ListEabIds(sc)
if err != nil {
return nil, err
}
@@ -193,7 +194,7 @@ func (b *backend) pathAcmeListEab(ctx context.Context, r *logical.Request, _ *fr
keyInfos := map[string]interface{}{}
for _, eabKey := range eabIds {
eab, err := b.acmeState.LoadEab(sc, eabKey)
eab, err := acmeState.LoadEab(sc, eabKey)
if err != nil {
warnings = append(warnings, fmt.Sprintf("failed loading eab entry %s: %v", eabKey, err))
continue
@@ -236,7 +237,7 @@ func (b *backend) pathAcmeCreateEab(ctx context.Context, r *logical.Request, dat
}
sc := b.makeStorageContext(ctx, r.Storage)
err = b.acmeState.SaveEab(sc, eab)
err = b.GetAcmeState().SaveEab(sc, eab)
if err != nil {
return nil, fmt.Errorf("failed saving generated eab: %w", err)
}
@@ -263,7 +264,7 @@ func (b *backend) pathAcmeDeleteEab(ctx context.Context, r *logical.Request, d *
return nil, fmt.Errorf("badly formatted key_id field")
}
deleted, err := b.acmeState.DeleteEab(sc, keyId)
deleted, err := b.GetAcmeState().DeleteEab(sc, keyId)
if err != nil {
return nil, fmt.Errorf("failed deleting key id: %w", err)
}

View File

@@ -41,7 +41,7 @@ func patternAcmeNonce(b *backend, pattern string, opts acmeWrapperOpts) *framewo
}
func (b *backend) acmeNonceHandler(ctx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
nonce, _, err := b.acmeState.GetNonce()
nonce, _, err := b.GetAcmeState().GetNonce()
if err != nil {
return nil, err
}

View File

@@ -21,6 +21,8 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/idna"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
var maxAcmeCertTTL = 90 * (24 * time.Hour)
@@ -164,7 +166,7 @@ func addFieldsForACMEOrder(fields map[string]*framework.FieldSchema) {
func (b *backend) acmeFetchCertOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}, _ *acmeAccount) (*logical.Response, error) {
orderId := fields.Get("order_id").(string)
order, err := b.acmeState.LoadOrder(ac, uc, orderId)
order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil {
return nil, err
}
@@ -232,7 +234,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
return nil, err
}
order, err := b.acmeState.LoadOrder(ac, uc, orderId)
order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil {
return nil, err
}
@@ -260,7 +262,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
}
var signedCertBundle *certutil.ParsedCertBundle
var issuerId issuerID
var issuerId issuing.IssuerID
if ac.runtimeOpts.isCiepsEnabled {
// Note that issueAcmeCertUsingCieps enforces storage requirements and
// does the certificate storage for us
@@ -281,7 +283,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
}
hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber)
if err := b.acmeState.TrackIssuedCert(ac, order.AccountId, hyphenSerialNumber, order.OrderId); err != nil {
if err := b.GetAcmeState().TrackIssuedCert(ac, order.AccountId, hyphenSerialNumber, order.OrderId); err != nil {
b.Logger().Warn("orphaned generated ACME certificate due to error saving account->cert->order reference", "serial_number", hyphenSerialNumber, "error", err)
return nil, err
}
@@ -291,7 +293,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
order.CertificateExpiry = signedCertBundle.Certificate.NotAfter
order.IssuerId = issuerId
err = b.acmeState.SaveOrder(ac, order)
err = b.GetAcmeState().SaveOrder(ac, order)
if err != nil {
b.Logger().Warn("orphaned generated ACME certificate due to error saving order", "serial_number", hyphenSerialNumber, "error", err)
return nil, fmt.Errorf("failed saving updated order: %w", err)
@@ -413,7 +415,7 @@ func validateCsrMatchesOrder(csr *x509.CertificateRequest, order *acmeOrder) err
return nil
}
func (b *backend) validateIdentifiersAgainstRole(role *roleEntry, identifiers []*ACMEIdentifier) error {
func (b *backend) validateIdentifiersAgainstRole(role *issuing.RoleEntry, identifiers []*ACMEIdentifier) error {
for _, identifier := range identifiers {
switch identifier.Type {
case ACMEDNSIdentifier:
@@ -480,7 +482,8 @@ func removeDuplicatesAndSortIps(ipIdentifiers []net.IP) []net.IP {
func storeCertificate(sc *storageContext, signedCertBundle *certutil.ParsedCertBundle) error {
hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber)
key := "certs/" + hyphenSerialNumber
certsCounted := sc.Backend.certsCounted.Load()
certCounter := sc.Backend.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err := sc.Storage.Put(sc.Context, &logical.StorageEntry{
Key: key,
Value: signedCertBundle.CertificateBytes,
@@ -488,7 +491,7 @@ func storeCertificate(sc *storageContext, signedCertBundle *certutil.ParsedCertB
if err != nil {
return fmt.Errorf("unable to store certificate locally: %w", err)
}
sc.Backend.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key)
certCounter.IncrementTotalCertificatesCount(certsCounted, key)
return nil
}
@@ -520,7 +523,7 @@ func maybeAugmentReqDataWithSuitableCN(ac *acmeContext, csr *x509.CertificateReq
}
}
func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuerID, error) {
func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuing.IssuerID, error) {
pemBlock := &pem.Block{
Type: "CERTIFICATE REQUEST",
Headers: nil,
@@ -540,7 +543,7 @@ func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.
// (TLS) clients are mostly verifying against server's DNS SANs.
maybeAugmentReqDataWithSuitableCN(ac, csr, data)
signingBundle, issuerId, err := ac.sc.fetchCAInfoWithIssuer(ac.issuer.ID.String(), IssuanceUsage)
signingBundle, issuerId, err := ac.sc.fetchCAInfoWithIssuer(ac.issuer.ID.String(), issuing.IssuanceUsage)
if err != nil {
return nil, "", fmt.Errorf("failed loading CA %s: %w", ac.issuer.ID.String(), err)
}
@@ -595,7 +598,7 @@ func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.
// We only allow ServerAuth key usage from ACME issued certs
// when configuration does not allow usage of ExtKeyusage field.
config, err := ac.sc.Backend.acmeState.getConfigWithUpdate(ac.sc)
config, err := ac.sc.Backend.GetAcmeState().getConfigWithUpdate(ac.sc)
if err != nil {
return nil, "", fmt.Errorf("failed to fetch ACME configuration: %w", err)
}
@@ -655,7 +658,7 @@ func parseCsrFromFinalize(data map[string]interface{}) (*x509.CertificateRequest
func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, _ *acmeAccount) (*logical.Response, error) {
orderId := fields.Get("order_id").(string)
order, err := b.acmeState.LoadOrder(ac, uc, orderId)
order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil {
return nil, err
}
@@ -674,7 +677,7 @@ func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, field
filteredAuthorizationIds := []string{}
for _, authId := range order.AuthorizationIds {
authorization, err := b.acmeState.LoadAuthorization(ac, uc, authId)
authorization, err := b.GetAcmeState().LoadAuthorization(ac, uc, authId)
if err != nil {
return nil, err
}
@@ -692,14 +695,14 @@ func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, field
}
func (b *backend) acmeListOrdersHandler(ac *acmeContext, _ *logical.Request, _ *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, acct *acmeAccount) (*logical.Response, error) {
orderIds, err := b.acmeState.ListOrderIds(ac.sc, acct.KeyId)
orderIds, err := b.GetAcmeState().ListOrderIds(ac.sc, acct.KeyId)
if err != nil {
return nil, err
}
orderUrls := []string{}
for _, orderId := range orderIds {
order, err := b.acmeState.LoadOrder(ac, uc, orderId)
order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil {
return nil, err
}
@@ -771,7 +774,7 @@ func (b *backend) acmeNewOrderHandler(ac *acmeContext, _ *logical.Request, _ *fr
}
authorizations = append(authorizations, authz)
err = b.acmeState.SaveAuthorization(ac, authz)
err = b.GetAcmeState().SaveAuthorization(ac, authz)
if err != nil {
return nil, fmt.Errorf("failed storing authorization: %w", err)
}
@@ -788,7 +791,7 @@ func (b *backend) acmeNewOrderHandler(ac *acmeContext, _ *logical.Request, _ *fr
AuthorizationIds: authorizationIds,
}
err = b.acmeState.SaveOrder(ac, order)
err = b.GetAcmeState().SaveOrder(ac, order)
if err != nil {
return nil, fmt.Errorf("failed storing order: %w", err)
}

View File

@@ -10,6 +10,8 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
// TestACME_ValidateIdentifiersAgainstRole Verify the ACME order creation
@@ -20,13 +22,13 @@ func TestACME_ValidateIdentifiersAgainstRole(t *testing.T) {
tests := []struct {
name string
role *roleEntry
role *issuing.RoleEntry
identifiers []*ACMEIdentifier
expectErr bool
}{
{
name: "verbatim-role-allows-dns-ip",
role: buildSignVerbatimRoleWithNoData(nil),
role: issuing.SignVerbatimRole(),
identifiers: _buildACMEIdentifiers("test.com", "127.0.0.1"),
expectErr: false,
},
@@ -119,7 +121,7 @@ func _buildACMEIdentifier(val string) *ACMEIdentifier {
// Easily allow tests to create valid roles with proper defaults, since we don't have an easy
// way to generate roles with proper defaults, go through the createRole handler with the handlers
// field data so we pickup all the defaults specified there.
func buildTestRole(t *testing.T, config map[string]interface{}) *roleEntry {
func buildTestRole(t *testing.T, config map[string]interface{}) *issuing.RoleEntry {
b, s := CreateBackendWithStorage(t)
path := pathRoles(b)
@@ -135,7 +137,7 @@ func buildTestRole(t *testing.T, config map[string]interface{}) *roleEntry {
_, err := b.pathRoleCreate(ctx, &logical.Request{Storage: s}, &framework.FieldData{Raw: config, Schema: fields})
require.NoError(t, err, "failed generating role with config %v", config)
role, err := b.getRole(ctx, s, config["name"].(string))
role, err := b.GetRole(ctx, s, config["name"].(string))
require.NoError(t, err, "failed loading stored role")
return role

View File

@@ -82,7 +82,7 @@ func (b *backend) acmeRevocationHandler(acmeCtx *acmeContext, _ *logical.Request
// Fetch the CRL config as we need it to ultimately do the
// revocation. This should be cached and thus relatively fast.
config, err := b.crlBuilder.getConfigWithUpdate(acmeCtx.sc)
config, err := b.CrlBuilder().getConfigWithUpdate(acmeCtx.sc)
if err != nil {
return nil, fmt.Errorf("unable to revoke certificate: failed reading revocation config: %v: %w", err, ErrServerInternal)
}
@@ -153,8 +153,8 @@ func (b *backend) acmeRevocationByPoP(acmeCtx *acmeContext, userCtx *jwsCtx, cer
}
// Now it is safe to revoke.
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
return revokeCert(acmeCtx.sc, config, cert)
}
@@ -169,14 +169,14 @@ func (b *backend) acmeRevocationByAccount(acmeCtx *acmeContext, userCtx *jwsCtx,
// We only support certificates issued by this user, we don't support
// cross-account revocations.
serial := serialFromCert(cert)
acmeEntry, err := b.acmeState.GetIssuedCert(acmeCtx, userCtx.Kid, serial)
acmeEntry, err := b.GetAcmeState().GetIssuedCert(acmeCtx, userCtx.Kid, serial)
if err != nil || acmeEntry == nil {
return nil, fmt.Errorf("unable to revoke certificate: %v: %w", err, ErrMalformed)
}
// Now it is safe to revoke.
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
return revokeCert(acmeCtx.sc, config, cert)
}

View File

@@ -157,7 +157,7 @@ func pathAcmeConfig(b *backend) *framework.Path {
func (b *backend) pathAcmeRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.acmeState.getConfigWithForcedUpdate(sc)
config, err := b.GetAcmeState().getConfigWithForcedUpdate(sc)
if err != nil {
return nil, err
}
@@ -195,7 +195,7 @@ func genResponseFromAcmeConfig(config *acmeConfigEntry, warnings []string) *logi
func (b *backend) pathAcmeWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.acmeState.getConfigWithForcedUpdate(sc)
config, err := b.GetAcmeState().getConfigWithForcedUpdate(sc)
if err != nil {
return nil, err
}
@@ -337,7 +337,7 @@ func (b *backend) pathAcmeWrite(ctx context.Context, req *logical.Request, d *fr
}
}
if _, err := b.acmeState.writeConfig(sc, config); err != nil {
if _, err := b.GetAcmeState().writeConfig(sc, config); err != nil {
return nil, fmt.Errorf("failed persisting: %w", err)
}

View File

@@ -7,6 +7,7 @@ import (
"context"
"net/http"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -212,7 +213,7 @@ func pathReplaceRoot(b *backend) *framework.Path {
}
func (b *backend) pathCAIssuersRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot read defaults until migration has completed"), nil
}
@@ -225,7 +226,7 @@ func (b *backend) pathCAIssuersRead(ctx context.Context, req *logical.Request, _
return b.formatCAIssuerConfigRead(config), nil
}
func (b *backend) formatCAIssuerConfigRead(config *issuerConfigEntry) *logical.Response {
func (b *backend) formatCAIssuerConfigRead(config *issuing.IssuerConfigEntry) *logical.Response {
return &logical.Response{
Data: map[string]interface{}{
defaultRef: config.DefaultIssuerId,
@@ -240,7 +241,7 @@ func (b *backend) pathCAIssuersWrite(ctx context.Context, req *logical.Request,
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot update defaults until migration has completed"), nil
}
@@ -370,7 +371,7 @@ func pathConfigKeys(b *backend) *framework.Path {
}
func (b *backend) pathKeyDefaultRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot read key defaults until migration has completed"), nil
}
@@ -393,7 +394,7 @@ func (b *backend) pathKeyDefaultWrite(ctx context.Context, req *logical.Request,
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot update key defaults until migration has completed"), nil
}

View File

@@ -275,7 +275,7 @@ existing CRL and OCSP paths will return the unified CRL instead of a response ba
func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.crlBuilder.getConfigWithForcedUpdate(sc)
config, err := b.CrlBuilder().getConfigWithForcedUpdate(sc)
if err != nil {
return nil, fmt.Errorf("failed fetching CRL config: %w", err)
}
@@ -285,7 +285,7 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, _ *fram
func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.crlBuilder.getConfigWithForcedUpdate(sc)
config, err := b.CrlBuilder().getConfigWithForcedUpdate(sc)
if err != nil {
return nil, err
}
@@ -410,7 +410,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
return logical.ErrorResponse("unified_crl=true requires auto_rebuild=true, as unified CRLs cannot be rebuilt on every revocation."), nil
}
if _, err := b.crlBuilder.writeConfig(sc, config); err != nil {
if _, err := b.CrlBuilder().writeConfig(sc, config); err != nil {
return nil, fmt.Errorf("failed persisting CRL config: %w", err)
}
@@ -418,13 +418,13 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
// Note this only affects/happens on the main cluster node, if you need to
// notify something based on a configuration change on all server types
// have a look at crlBuilder::reloadConfigIfRequired
// have a look at CrlBuilder::reloadConfigIfRequired
if oldDisable != config.Disable || (oldAutoRebuild && !config.AutoRebuild) || (oldEnableDelta != config.EnableDelta) || (oldUnifiedCRL != config.UnifiedCRL) {
// It wasn't disabled but now it is (or equivalently, we were set to
// auto-rebuild and we aren't now or equivalently, we changed our
// mind about delta CRLs and need a new complete one or equivalently,
// we changed our mind about unified CRLs), rotate the CRLs.
warnings, crlErr := b.crlBuilder.rebuild(sc, true)
warnings, crlErr := b.CrlBuilder().rebuild(sc, true)
if crlErr != nil {
switch crlErr.(type) {
case errutil.UserError:

View File

@@ -7,9 +7,8 @@ import (
"context"
"fmt"
"net/http"
"strings"
"github.com/asaskevich/govalidator"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -140,23 +139,13 @@ set on all PR Secondary clusters.`,
}
}
func validateURLs(urls []string) string {
for _, curr := range urls {
if !govalidator.IsURL(curr) || strings.Contains(curr, "{{issuer_id}}") || strings.Contains(curr, "{{cluster_path}}") || strings.Contains(curr, "{{cluster_aia_path}}") {
return curr
}
}
return ""
}
func getGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*aiaConfigEntry, error) {
func getGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*issuing.AiaConfigEntry, error) {
entry, err := storage.Get(ctx, "urls")
if err != nil {
return nil, err
}
entries := &aiaConfigEntry{
entries := &issuing.AiaConfigEntry{
IssuingCertificates: []string{},
CRLDistributionPoints: []string{},
OCSPServers: []string{},
@@ -174,7 +163,7 @@ func getGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*aiaConfigE
return entries, nil
}
func writeURLs(ctx context.Context, storage logical.Storage, entries *aiaConfigEntry) error {
func writeURLs(ctx context.Context, storage logical.Storage, entries *issuing.AiaConfigEntry) error {
entry, err := logical.StorageEntryJSON("urls", entries)
if err != nil {
return err
@@ -237,7 +226,7 @@ func (b *backend) pathWriteURL(ctx context.Context, req *logical.Request, data *
},
}
if entries.EnableTemplating && !b.useLegacyBundleCaStorage() {
if entries.EnableTemplating && !b.UseLegacyBundleCaStorage() {
sc := b.makeStorageContext(ctx, req.Storage)
issuers, err := sc.listIssuers()
if err != nil {
@@ -250,23 +239,23 @@ func (b *backend) pathWriteURL(ctx context.Context, req *logical.Request, data *
return nil, fmt.Errorf("unable to read issuer to validate templated URIs: %w", err)
}
_, err = entries.toURLEntries(sc, issuer.ID)
_, err = ToURLEntries(sc, issuer.ID, entries)
if err != nil {
resp.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", err))
}
}
} else if !entries.EnableTemplating {
if badURL := validateURLs(entries.IssuingCertificates); badURL != "" {
if badURL := issuing.ValidateURLs(entries.IssuingCertificates); badURL != "" {
return logical.ErrorResponse(fmt.Sprintf(
"invalid URL found in Authority Information Access (AIA) parameter issuing_certificates: %s", badURL)), nil
}
if badURL := validateURLs(entries.CRLDistributionPoints); badURL != "" {
if badURL := issuing.ValidateURLs(entries.CRLDistributionPoints); badURL != "" {
return logical.ErrorResponse(fmt.Sprintf(
"invalid URL found in Authority Information Access (AIA) parameter crl_distribution_points: %s", badURL)), nil
}
if badURL := validateURLs(entries.OCSPServers); badURL != "" {
if badURL := issuing.ValidateURLs(entries.OCSPServers); badURL != "" {
return logical.ErrorResponse(fmt.Sprintf(
"invalid URL found in Authority Information Access (AIA) parameter ocsp_servers: %s", badURL)), nil
}

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/sdk/framework"
@@ -306,7 +307,7 @@ func (b *backend) pathFetchRead(ctx context.Context, req *logical.Request, data
contentType = "application/pkix-cert"
}
case req.Path == "crl" || req.Path == "crl/pem" || req.Path == "crl/delta" || req.Path == "crl/delta/pem" || req.Path == "cert/crl" || req.Path == "cert/crl/raw" || req.Path == "cert/crl/raw/pem" || req.Path == "cert/delta-crl" || req.Path == "cert/delta-crl/raw" || req.Path == "cert/delta-crl/raw/pem" || req.Path == "unified-crl" || req.Path == "unified-crl/pem" || req.Path == "unified-crl/delta" || req.Path == "unified-crl/delta/pem" || req.Path == "cert/unified-crl" || req.Path == "cert/unified-crl/raw" || req.Path == "cert/unified-crl/raw/pem" || req.Path == "cert/unified-delta-crl" || req.Path == "cert/unified-delta-crl/raw" || req.Path == "cert/unified-delta-crl/raw/pem":
config, err := b.crlBuilder.getConfigWithUpdate(sc)
config, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil {
retErr = err
goto reply
@@ -370,7 +371,7 @@ func (b *backend) pathFetchRead(ctx context.Context, req *logical.Request, data
// Prefer fetchCAInfo to fetchCertBySerial for CA certificates.
if serial == "ca_chain" || serial == "ca" {
caInfo, err := sc.fetchCAInfo(defaultRef, ReadOnlyUsage)
caInfo, err := sc.fetchCAInfo(defaultRef, issuing.ReadOnlyUsage)
if err != nil {
switch err.(type) {
case errutil.UserError:

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
@@ -55,7 +56,7 @@ func pathListIssuers(b *backend) *framework.Path {
}
func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not list issuers until migration has completed"), nil
}
@@ -398,11 +399,11 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data
return b.pathGetRawIssuer(ctx, req, data)
}
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get issuer until migration has completed"), nil
}
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -424,7 +425,7 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data
return respondReadIssuer(issuer)
}
func respondReadIssuer(issuer *issuerEntry) (*logical.Response, error) {
func respondReadIssuer(issuer *issuing.IssuerEntry) (*logical.Response, error) {
var respManualChain []string
for _, entity := range issuer.ManualChain {
respManualChain = append(respManualChain, string(entity))
@@ -483,11 +484,11 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not update issuer until migration has completed"), nil
}
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -537,9 +538,9 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
}
rawUsage := data.Get("usage").([]string)
newUsage, err := NewIssuerUsageFromNames(rawUsage)
newUsage, err := issuing.NewIssuerUsageFromNames(rawUsage)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, AllIssuerUsages.Names())), nil
return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, issuing.AllIssuerUsages.Names())), nil
}
// Revocation signature algorithm changes
@@ -562,15 +563,15 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
// AIA access changes
enableTemplating := data.Get("enable_aia_url_templating").(bool)
issuerCertificates := data.Get("issuing_certificates").([]string)
if badURL := validateURLs(issuerCertificates); !enableTemplating && badURL != "" {
if badURL := issuing.ValidateURLs(issuerCertificates); !enableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter issuing_certificates: %s", badURL)), nil
}
crlDistributionPoints := data.Get("crl_distribution_points").([]string)
if badURL := validateURLs(crlDistributionPoints); !enableTemplating && badURL != "" {
if badURL := issuing.ValidateURLs(crlDistributionPoints); !enableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter crl_distribution_points: %s", badURL)), nil
}
ocspServers := data.Get("ocsp_servers").([]string)
if badURL := validateURLs(ocspServers); !enableTemplating && badURL != "" {
if badURL := issuing.ValidateURLs(ocspServers); !enableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter ocsp_servers: %s", badURL)), nil
}
@@ -582,8 +583,8 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
issuer.Name = newName
issuer.LastModified = time.Now().UTC()
// See note in updateDefaultIssuerId about why this is necessary.
b.crlBuilder.invalidateCRLBuildTime()
b.crlBuilder.flushCRLBuildTimeInvalidation(sc)
b.CrlBuilder().invalidateCRLBuildTime()
b.CrlBuilder().flushCRLBuildTimeInvalidation(sc)
modified = true
}
@@ -593,7 +594,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
}
if newUsage != issuer.Usage {
if issuer.Revoked && newUsage.HasUsage(IssuanceUsage) {
if issuer.Revoked && newUsage.HasUsage(issuing.IssuanceUsage) {
// Forbid allowing cert signing on its usage.
return logical.ErrorResponse("This issuer was revoked; unable to modify its usage to include certificate signing again. Reissue this certificate (preferably with a new key) and modify that entry instead."), nil
}
@@ -604,7 +605,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
if err != nil {
return nil, fmt.Errorf("unable to parse issuer's certificate: %w", err)
}
if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(CRLSigningUsage) {
if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(issuing.CRLSigningUsage) {
return logical.ErrorResponse("This issuer's underlying certificate lacks the CRLSign KeyUsage value; unable to set CRLSigningUsage on this issuer as a result."), nil
}
@@ -618,7 +619,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
}
if issuer.AIAURIs == nil && (len(issuerCertificates) > 0 || len(crlDistributionPoints) > 0 || len(ocspServers) > 0) {
issuer.AIAURIs = &aiaConfigEntry{}
issuer.AIAURIs = &issuing.AiaConfigEntry{}
}
if issuer.AIAURIs != nil {
// Associative mapping from data source to destination on the
@@ -665,7 +666,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
// it'll write it out to disk for us. We'd hate to then modify the issuer
// again and write it a second time.
var updateChain bool
var constructedChain []issuerID
var constructedChain []issuing.IssuerID
for index, newPathRef := range newPath {
// Allow self for the first entry.
if index == 0 && newPathRef == "self" {
@@ -715,7 +716,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
addWarningOnDereferencing(sc, oldName, response)
}
if issuer.AIAURIs != nil && issuer.AIAURIs.EnableTemplating {
_, aiaErr := issuer.AIAURIs.toURLEntries(sc, issuer.ID)
_, aiaErr := ToURLEntries(sc, issuer.ID, issuer.AIAURIs)
if aiaErr != nil {
response.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", aiaErr))
}
@@ -730,12 +731,12 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not patch issuer until migration has completed"), nil
}
// First we fetch the issuer
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -782,8 +783,8 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
issuer.Name = newName
issuer.LastModified = time.Now().UTC()
// See note in updateDefaultIssuerId about why this is necessary.
b.crlBuilder.invalidateCRLBuildTime()
b.crlBuilder.flushCRLBuildTimeInvalidation(sc)
b.CrlBuilder().invalidateCRLBuildTime()
b.CrlBuilder().flushCRLBuildTimeInvalidation(sc)
modified = true
}
}
@@ -813,12 +814,12 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
rawUsageData, ok := data.GetOk("usage")
if ok {
rawUsage := rawUsageData.([]string)
newUsage, err := NewIssuerUsageFromNames(rawUsage)
newUsage, err := issuing.NewIssuerUsageFromNames(rawUsage)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, AllIssuerUsages.Names())), nil
return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, issuing.AllIssuerUsages.Names())), nil
}
if newUsage != issuer.Usage {
if issuer.Revoked && newUsage.HasUsage(IssuanceUsage) {
if issuer.Revoked && newUsage.HasUsage(issuing.IssuanceUsage) {
// Forbid allowing cert signing on its usage.
return logical.ErrorResponse("This issuer was revoked; unable to modify its usage to include certificate signing again. Reissue this certificate (preferably with a new key) and modify that entry instead."), nil
}
@@ -827,7 +828,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
if err != nil {
return nil, fmt.Errorf("unable to parse issuer's certificate: %w", err)
}
if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(CRLSigningUsage) {
if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(issuing.CRLSigningUsage) {
return logical.ErrorResponse("This issuer's underlying certificate lacks the CRLSign KeyUsage value; unable to set CRLSigningUsage on this issuer as a result."), nil
}
@@ -864,7 +865,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
// AIA access changes.
if issuer.AIAURIs == nil {
issuer.AIAURIs = &aiaConfigEntry{}
issuer.AIAURIs = &issuing.AiaConfigEntry{}
}
// Associative mapping from data source to destination on the
@@ -903,7 +904,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
rawURLsValue, ok := data.GetOk(pair.Source)
if ok {
urlsValue := rawURLsValue.([]string)
if badURL := validateURLs(urlsValue); !issuer.AIAURIs.EnableTemplating && badURL != "" {
if badURL := issuing.ValidateURLs(urlsValue); !issuer.AIAURIs.EnableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter %v: %s", pair.Source, badURL)), nil
}
@@ -925,7 +926,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
if ok {
newPath := newPathData.([]string)
var updateChain bool
var constructedChain []issuerID
var constructedChain []issuing.IssuerID
for index, newPathRef := range newPath {
// Allow self for the first entry.
if index == 0 && newPathRef == "self" {
@@ -976,7 +977,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
addWarningOnDereferencing(sc, oldName, response)
}
if issuer.AIAURIs != nil && issuer.AIAURIs.EnableTemplating {
_, aiaErr := issuer.AIAURIs.toURLEntries(sc, issuer.ID)
_, aiaErr := ToURLEntries(sc, issuer.ID, issuer.AIAURIs)
if aiaErr != nil {
response.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", aiaErr))
}
@@ -986,11 +987,11 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
}
func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get issuer until migration has completed"), nil
}
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -1069,11 +1070,11 @@ func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, da
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not delete issuer until migration has completed"), nil
}
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -1082,7 +1083,7 @@ func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, da
ref, err := sc.resolveIssuerReference(issuerName)
if err != nil {
// Return as if we deleted it if we fail to lookup the issuer.
if ref == IssuerRefNotFound {
if ref == issuing.IssuerRefNotFound {
return &logical.Response{}, nil
}
return nil, err
@@ -1120,7 +1121,7 @@ func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, da
// Finally, we need to rebuild both the local and the unified CRLs. This
// will free up any now unnecessary space used in both the CRL config
// and for the underlying CRL.
warnings, err := b.crlBuilder.rebuild(sc, true)
warnings, err := b.CrlBuilder().rebuild(sc, true)
if err != nil {
return nil, err
}
@@ -1220,17 +1221,17 @@ func buildPathGetIssuerCRL(b *backend, pattern string, displayAttrs *framework.D
}
func (b *backend) pathGetIssuerCRL(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get issuer's CRL until migration has completed"), nil
}
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
sc := b.makeStorageContext(ctx, req.Storage)
warnings, err := b.crlBuilder.rebuildIfForced(sc)
warnings, err := b.CrlBuilder().rebuildIfForced(sc)
if err != nil {
return nil, err
}

View File

@@ -14,6 +14,9 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
func pathListKeys(b *backend) *framework.Path {
@@ -62,7 +65,7 @@ their identifier and their name (if set).`
)
func (b *backend) pathListKeysHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not list keys until migration has completed"), nil
}
@@ -225,7 +228,7 @@ the certificate.
)
func (b *backend) pathGetKeyHandler(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get keys until migration has completed"), nil
}
@@ -255,27 +258,27 @@ func (b *backend) pathGetKeyHandler(ctx context.Context, req *logical.Request, d
}
var pkForSkid crypto.PublicKey
if key.isManagedPrivateKey() {
managedKeyUUID, err := key.getManagedKeyUUID()
if key.IsManagedPrivateKey() {
managedKeyUUID, err := issuing.GetManagedKeyUUID(key)
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("failed extracting managed key uuid from key id %s (%s): %v", key.ID, key.Name, err)}
}
keyInfo, err := getManagedKeyInfo(ctx, b, managedKeyUUID)
keyInfo, err := managed_key.GetManagedKeyInfo(ctx, b, managedKeyUUID)
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("failed fetching managed key info from key id %s (%s): %v", key.ID, key.Name, err)}
}
pkForSkid, err = getManagedKeyPublicKey(sc.Context, sc.Backend, managedKeyUUID)
pkForSkid, err = managed_key.GetManagedKeyPublicKey(sc.Context, sc.Backend, managedKeyUUID)
if err != nil {
return nil, err
}
// To remain consistent across the api responses (mainly generate root/intermediate calls), return the actual
// type of key, not that it is a managed key.
respData[keyTypeParam] = string(keyInfo.keyType)
respData[managedKeyIdArg] = string(keyInfo.uuid)
respData[managedKeyNameArg] = string(keyInfo.name)
respData[keyTypeParam] = string(keyInfo.KeyType)
respData[managedKeyIdArg] = string(keyInfo.Uuid)
respData[managedKeyNameArg] = string(keyInfo.Name)
} else {
pkForSkid, err = getPublicKeyFromBytes([]byte(key.PrivateKey))
if err != nil {
@@ -298,7 +301,7 @@ func (b *backend) pathUpdateKeyHandler(ctx context.Context, req *logical.Request
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not update keys until migration has completed"), nil
}
@@ -356,7 +359,7 @@ func (b *backend) pathDeleteKeyHandler(ctx context.Context, req *logical.Request
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not delete keys until migration has completed"), nil
}
@@ -368,7 +371,7 @@ func (b *backend) pathDeleteKeyHandler(ctx context.Context, req *logical.Request
sc := b.makeStorageContext(ctx, req.Storage)
keyId, err := sc.resolveKeyReference(keyRef)
if err != nil {
if keyId == KeyRefNotFound {
if keyId == issuing.KeyRefNotFound {
// We failed to lookup the key, we should ignore any error here and reply as if it was deleted.
return nil, nil
}

View File

@@ -102,7 +102,7 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req
var err error
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not create intermediate until migration has completed"), nil
}

View File

@@ -19,6 +19,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
func pathIssue(b *backend) *framework.Path {
@@ -285,7 +287,7 @@ See the API documentation for more information about required parameters.
// pathIssue issues a certificate and private key from given parameters,
// subject to role restrictions
func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error) {
func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error) {
if role.KeyType == "any" {
return logical.ErrorResponse("role key type \"any\" not allowed for issuing certificates, only signing"), nil
}
@@ -295,19 +297,49 @@ func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *fra
// pathSign issues a certificate from a submitted CSR, subject to role
// restrictions
func (b *backend) pathSign(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error) {
func (b *backend) pathSign(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error) {
return b.pathIssueSignCert(ctx, req, data, role, true, false)
}
// pathSignVerbatim issues a certificate from a submitted CSR, *not* subject to
// role restrictions
func (b *backend) pathSignVerbatim(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error) {
entry := buildSignVerbatimRole(data, role)
func (b *backend) pathSignVerbatim(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error) {
opts := []issuing.RoleModifier{
issuing.WithKeyUsage(data.Get("key_usage").([]string)),
issuing.WithExtKeyUsage(data.Get("ext_key_usage").([]string)),
issuing.WithExtKeyUsageOIDs(data.Get("ext_key_usage_oids").([]string)),
issuing.WithSignatureBits(data.Get("signature_bits").(int)),
issuing.WithUsePSS(data.Get("use_pss").(bool)),
}
// if we did receive a role parameter value with a valid role, use some of its values
// to populate and influence the sign-verbatim behavior.
if role != nil {
opts = append(opts, issuing.WithNoStore(role.NoStore))
opts = append(opts, issuing.WithIssuer(role.Issuer))
if role.TTL > 0 {
opts = append(opts, issuing.WithTTL(role.TTL))
}
if role.MaxTTL > 0 {
opts = append(opts, issuing.WithMaxTTL(role.MaxTTL))
}
if role.GenerateLease != nil {
opts = append(opts, issuing.WithGenerateLease(*role.GenerateLease))
}
if role.NotBeforeDuration > 0 {
opts = append(opts, issuing.WithNotBeforeDuration(role.NotBeforeDuration))
}
}
entry := issuing.SignVerbatimRoleWithOpts(opts...)
return b.pathIssueSignCert(ctx, req, data, entry, true, true)
}
func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry, useCSR, useCSRValues bool) (*logical.Response, error) {
func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry, useCSR, useCSRValues bool) (*logical.Response, error) {
// If storing the certificate and on a performance standby, forward this request on to the primary
// Allow performance secondaries to generate and store certificates locally to them.
if !role.NoStore && b.System().ReplicationState().HasState(consts.ReplicationPerformanceStandby) {
@@ -333,7 +365,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
} else {
// Otherwise, we must have a newer API which requires an issuer
// reference. Fetch it in this case
issuerName = getIssuerRef(data)
issuerName = GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -347,7 +379,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
var caErr error
sc := b.makeStorageContext(ctx, req.Storage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, IssuanceUsage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, issuing.IssuanceUsage)
if caErr != nil {
switch caErr.(type) {
case errutil.UserError:
@@ -400,7 +432,8 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
if !role.NoStore {
key := "certs/" + normalizeSerial(cb.SerialNumber)
certsCounted := b.certsCounted.Load()
certCounter := b.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = req.Storage.Put(ctx, &logical.StorageEntry{
Key: key,
Value: parsedBundle.CertificateBytes,
@@ -408,7 +441,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
if err != nil {
return nil, fmt.Errorf("unable to store certificate locally: %w", err)
}
b.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key)
certCounter.IncrementTotalCertificatesCount(certsCounted, key)
}
if useCSR {

View File

@@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
@@ -275,7 +276,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
keysAllowed := strings.HasSuffix(req.Path, "bundle") || req.Path == "config/ca"
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not import issuers until migration has completed"), nil
}
@@ -406,7 +407,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
}
if len(createdIssuers) > 0 {
warnings, err := b.crlBuilder.rebuild(sc, true)
warnings, err := b.CrlBuilder().rebuild(sc, true)
if err != nil {
// Before returning, check if the error message includes the
// string "PSS". If so, it indicates we might've wanted to modify
@@ -438,7 +439,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
response.AddWarning("Unable to fetch default issuers configuration to update default issuer if necessary: " + err.Error())
} else if config.DefaultFollowsLatestIssuer {
if len(issuersWithKeys) == 1 {
if err := sc.updateDefaultIssuerId(issuerID(issuersWithKeys[0])); err != nil {
if err := sc.updateDefaultIssuerId(issuing.IssuerID(issuersWithKeys[0])); err != nil {
response.AddWarning("Unable to update this new root as the default issuer: " + err.Error())
}
} else if len(issuersWithKeys) > 1 {
@@ -627,11 +628,11 @@ func (b *backend) pathRevokeIssuer(ctx context.Context, req *logical.Request, da
defer b.issuersLock.Unlock()
// Issuer revocation can't work on the legacy cert bundle.
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("cannot revoke issuer until migration has completed"), nil
}
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -661,8 +662,8 @@ func (b *backend) pathRevokeIssuer(ctx context.Context, req *logical.Request, da
// new revocations of leaves issued by this issuer to trigger a CRL
// rebuild still.
issuer.Revoked = true
if issuer.Usage.HasUsage(IssuanceUsage) {
issuer.Usage.ToggleUsage(IssuanceUsage)
if issuer.Usage.HasUsage(issuing.IssuanceUsage) {
issuer.Usage.ToggleUsage(issuing.IssuanceUsage)
}
currTime := time.Now()
@@ -730,7 +731,7 @@ func (b *backend) pathRevokeIssuer(ctx context.Context, req *logical.Request, da
}
// Rebuild the CRL to include the newly revoked issuer.
warnings, crlErr := b.crlBuilder.rebuild(sc, false)
warnings, crlErr := b.CrlBuilder().rebuild(sc, false)
if crlErr != nil {
switch crlErr.(type) {
case errutil.UserError:

View File

@@ -13,6 +13,8 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
func pathGenerateKey(b *backend) *framework.Path {
@@ -114,7 +116,7 @@ func (b *backend) pathGenerateKeyHandler(ctx context.Context, req *logical.Reque
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not generate keys until migration has completed"), nil
}
@@ -153,7 +155,7 @@ func (b *backend) pathGenerateKeyHandler(ctx context.Context, req *logical.Reque
return nil, err
}
keyBundle, actualPrivateKeyType, err = createKmsKeyBundle(ctx, b, keyId)
keyBundle, actualPrivateKeyType, err = managed_key.CreateKmsKeyBundle(ctx, b, keyId)
if err != nil {
return nil, err
}
@@ -252,7 +254,7 @@ func (b *backend) pathImportKeyHandler(ctx context.Context, req *logical.Request
b.issuersLock.Lock()
defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot import keys until migration has completed"), nil
}

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"testing"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
"github.com/hashicorp/vault/sdk/helper/certutil"
@@ -152,7 +153,7 @@ func TestPKI_PathManageKeys_ImportKeyBundle(t *testing.T) {
require.NotEmpty(t, resp.Data["key_id"], "key id for ec import response was empty")
require.Equal(t, "my-ec-key", resp.Data["key_name"], "key_name was incorrect for ec key")
require.Equal(t, certutil.ECPrivateKey, resp.Data["key_type"])
keyId1 := resp.Data["key_id"].(keyID)
keyId1 := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
@@ -170,7 +171,7 @@ func TestPKI_PathManageKeys_ImportKeyBundle(t *testing.T) {
require.NotEmpty(t, resp.Data["key_id"], "key id for rsa import response was empty")
require.Equal(t, "my-rsa-key", resp.Data["key_name"], "key_name was incorrect for ec key")
require.Equal(t, certutil.RSAPrivateKey, resp.Data["key_type"])
keyId2 := resp.Data["key_id"].(keyID)
keyId2 := resp.Data["key_id"].(issuing.KeyID)
require.NotEqual(t, keyId1, keyId2)
@@ -251,7 +252,7 @@ func TestPKI_PathManageKeys_ImportKeyBundle(t *testing.T) {
require.NotEmpty(t, resp.Data["key_id"], "key id for rsa import response was empty")
require.Equal(t, "my-rsa-key", resp.Data["key_name"], "key_name was incorrect for ec key")
require.Equal(t, certutil.RSAPrivateKey, resp.Data["key_type"])
keyId2Reimport := resp.Data["key_id"].(keyID)
keyId2Reimport := resp.Data["key_id"].(issuing.KeyID)
require.NotEqual(t, keyId2, keyId2Reimport, "re-importing key 2 did not generate a new key id")
}
@@ -270,7 +271,7 @@ func TestPKI_PathManageKeys_DeleteDefaultKeyWarns(t *testing.T) {
require.NoError(t, err, "Failed generating key")
require.NotNil(t, resp, "Got nil response generating key")
require.False(t, resp.IsError(), "resp contained errors generating key: %#v", resp.Error())
keyId := resp.Data["key_id"].(keyID)
keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.DeleteOperation,
@@ -298,7 +299,7 @@ func TestPKI_PathManageKeys_DeleteUsedKeyFails(t *testing.T) {
require.NoError(t, err, "Failed generating issuer")
require.NotNil(t, resp, "Got nil response generating issuer")
require.False(t, resp.IsError(), "resp contained errors generating issuer: %#v", resp.Error())
keyId := resp.Data["key_id"].(keyID)
keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.DeleteOperation,
@@ -325,7 +326,7 @@ func TestPKI_PathManageKeys_UpdateKeyDetails(t *testing.T) {
require.NoError(t, err, "Failed generating key")
require.NotNil(t, resp, "Got nil response generating key")
require.False(t, resp.IsError(), "resp contained errors generating key: %#v", resp.Error())
keyId := resp.Data["key_id"].(keyID)
keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,

View File

@@ -20,6 +20,7 @@ import (
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
@@ -37,7 +38,7 @@ type ocspRespInfo struct {
serialNumber *big.Int
ocspStatus int
revocationTimeUTC *time.Time
issuerID issuerID
issuerID issuing.IssuerID
}
// These response variables should not be mutated, instead treat them as constants
@@ -155,7 +156,7 @@ func buildOcspPostWithPath(b *backend, pattern string, displayAttrs *framework.D
func (b *backend) ocspHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, request.Storage)
cfg, err := b.crlBuilder.getConfigWithUpdate(sc)
cfg, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil || cfg.OcspDisable || (isUnifiedOcspPath(request) && !cfg.UnifiedCRL) {
return OcspUnauthorizedResponse, nil
}
@@ -247,7 +248,7 @@ func generateUnknownResponse(cfg *crlConfig, sc *storageContext, ocspReq *ocsp.R
return logAndReturnInternalError(sc.Backend, err)
}
if !issuer.Usage.HasUsage(OCSPSigningUsage) {
if !issuer.Usage.HasUsage(issuing.OCSPSigningUsage) {
// If we don't have any issuers or default issuers set, no way to sign a response so Unauthorized it is.
return OcspUnauthorizedResponse
}
@@ -358,7 +359,7 @@ func getOcspStatus(sc *storageContext, ocspReq *ocsp.Request, useUnifiedStorage
return &info, nil
}
func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer issuerID) (*certutil.ParsedCertBundle, *issuerEntry, error) {
func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer issuing.IssuerID) (*certutil.ParsedCertBundle, *issuing.IssuerEntry, error) {
reqHash := req.HashAlgorithm
if !reqHash.Available() {
return nil, nil, x509.ErrUnsupportedAlgorithm
@@ -395,7 +396,7 @@ func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer is
}
if matches {
if !issuer.Usage.HasUsage(OCSPSigningUsage) {
if !issuer.Usage.HasUsage(issuing.OCSPSigningUsage) {
matchedButNoUsage = true
// We found a matching issuer, but it's not allowed to sign the
// response, there might be another issuer that we rotated
@@ -415,7 +416,7 @@ func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer is
return nil, nil, ErrUnknownIssuer
}
func getOcspIssuerParsedBundle(sc *storageContext, issuerId issuerID) (*certutil.ParsedCertBundle, *issuerEntry, error) {
func getOcspIssuerParsedBundle(sc *storageContext, issuerId issuing.IssuerID) (*certutil.ParsedCertBundle, *issuing.IssuerEntry, error) {
issuer, bundle, err := sc.fetchCertBundleByIssuerId(issuerId, true)
if err != nil {
switch err.(type) {
@@ -440,13 +441,13 @@ func getOcspIssuerParsedBundle(sc *storageContext, issuerId issuerID) (*certutil
return caBundle, issuer, nil
}
func lookupIssuerIds(sc *storageContext, optRevokedIssuer issuerID) ([]issuerID, error) {
func lookupIssuerIds(sc *storageContext, optRevokedIssuer issuing.IssuerID) ([]issuing.IssuerID, error) {
if optRevokedIssuer != "" {
return []issuerID{optRevokedIssuer}, nil
return []issuing.IssuerID{optRevokedIssuer}, nil
}
if sc.Backend.useLegacyBundleCaStorage() {
return []issuerID{legacyBundleShimID}, nil
if sc.Backend.UseLegacyBundleCaStorage() {
return []issuing.IssuerID{legacyBundleShimID}, nil
}
return sc.listIssuers()

View File

@@ -18,6 +18,7 @@ import (
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
"github.com/hashicorp/vault/sdk/logical"
@@ -258,11 +259,11 @@ func TestOcsp_RevokedCertHasIssuerWithoutOcspUsage(t *testing.T) {
requireFieldsSetInResp(t, resp, "usage")
// Do not assume a specific ordering for usage...
usages, err := NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ","))
usages, err := issuing.NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ","))
require.NoError(t, err, "failed parsing usage return value")
require.True(t, usages.HasUsage(IssuanceUsage))
require.True(t, usages.HasUsage(CRLSigningUsage))
require.False(t, usages.HasUsage(OCSPSigningUsage))
require.True(t, usages.HasUsage(issuing.IssuanceUsage))
require.True(t, usages.HasUsage(issuing.CRLSigningUsage))
require.False(t, usages.HasUsage(issuing.OCSPSigningUsage))
// Request an OCSP request from it, we should get an Unauthorized response back
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
@@ -290,7 +291,7 @@ func TestOcsp_RevokedCertHasIssuerWithoutAKey(t *testing.T) {
resp, err = CBRead(b, s, "issuer/"+testEnv.issuerId1.String())
requireSuccessNonNilResponse(t, resp, err, "failed reading issuer")
requireFieldsSetInResp(t, resp, "key_id")
keyId := resp.Data["key_id"].(keyID)
keyId := resp.Data["key_id"].(issuing.KeyID)
// This is a bit naughty but allow me to delete the key...
sc := b.makeStorageContext(context.Background(), s)
@@ -343,11 +344,11 @@ func TestOcsp_MultipleMatchingIssuersOneWithoutSigningUsage(t *testing.T) {
requireSuccessNonNilResponse(t, resp, err, "failed resetting usage flags on issuer")
requireFieldsSetInResp(t, resp, "usage")
// Do not assume a specific ordering for usage...
usages, err := NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ","))
usages, err := issuing.NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ","))
require.NoError(t, err, "failed parsing usage return value")
require.True(t, usages.HasUsage(IssuanceUsage))
require.True(t, usages.HasUsage(CRLSigningUsage))
require.False(t, usages.HasUsage(OCSPSigningUsage))
require.True(t, usages.HasUsage(issuing.IssuanceUsage))
require.True(t, usages.HasUsage(issuing.CRLSigningUsage))
require.False(t, usages.HasUsage(issuing.OCSPSigningUsage))
// Request an OCSP request from it, we should get a Good response back, from the rotated cert
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
@@ -625,14 +626,14 @@ type ocspTestEnv struct {
issuer1 *x509.Certificate
issuer2 *x509.Certificate
issuerId1 issuerID
issuerId2 issuerID
issuerId1 issuing.IssuerID
issuerId2 issuing.IssuerID
leafCertIssuer1 *x509.Certificate
leafCertIssuer2 *x509.Certificate
keyId1 keyID
keyId2 keyID
keyId1 issuing.KeyID
keyId2 issuing.KeyID
}
func setupOcspEnv(t *testing.T, keyType string) (*backend, logical.Storage, *ocspTestEnv) {
@@ -643,8 +644,8 @@ func setupOcspEnvWithCaKeyConfig(t *testing.T, keyType string, caKeyBits int, ca
b, s := CreateBackendWithStorage(t)
var issuerCerts []*x509.Certificate
var leafCerts []*x509.Certificate
var issuerIds []issuerID
var keyIds []keyID
var issuerIds []issuing.IssuerID
var keyIds []issuing.KeyID
resp, err := CBWrite(b, s, "config/crl", map[string]interface{}{
"ocsp_enable": true,
@@ -662,8 +663,8 @@ func setupOcspEnvWithCaKeyConfig(t *testing.T, keyType string, caKeyBits int, ca
})
requireSuccessNonNilResponse(t, resp, err, "root/generate/internal")
requireFieldsSetInResp(t, resp, "issuer_id", "key_id")
issuerId := resp.Data["issuer_id"].(issuerID)
keyId := resp.Data["key_id"].(keyID)
issuerId := resp.Data["issuer_id"].(issuing.IssuerID)
keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = CBWrite(b, s, "roles/test"+strconv.FormatInt(int64(i), 10), map[string]interface{}{
"allow_bare_domains": true,

View File

@@ -21,6 +21,7 @@ import (
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
@@ -185,11 +186,11 @@ return a signed CRL based on the parameter values.`,
}
func (b *backend) pathUpdateResignCrlsHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("This API cannot be used until the migration has completed"), nil
}
issuerRef := getIssuerRef(data)
issuerRef := GetIssuerRef(data)
crlNumber := data.Get(crlNumberParam).(int)
deltaCrlBaseNumber := data.Get(deltaCrlBaseNumberParam).(int)
nextUpdateStr := data.Get(nextUpdateParam).(string)
@@ -273,11 +274,11 @@ func (b *backend) pathUpdateResignCrlsHandler(ctx context.Context, request *logi
}
func (b *backend) pathUpdateSignRevocationListHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("This API cannot be used until the migration has completed"), nil
}
issuerRef := getIssuerRef(data)
issuerRef := GetIssuerRef(data)
crlNumber := data.Get(crlNumberParam).(int)
deltaCrlBaseNumber := data.Get(deltaCrlBaseNumberParam).(int)
nextUpdateStr := data.Get(nextUpdateParam).(string)
@@ -649,7 +650,7 @@ func getCaBundle(sc *storageContext, issuerRef string) (*certutil.CAInfoBundle,
return nil, fmt.Errorf("failed to resolve issuer %s: %w", issuerRefParam, err)
}
return sc.fetchCAInfoByIssuerId(issuerId, CRLSigningUsage)
return sc.fetchCAInfoByIssuerId(issuerId, issuing.CRLSigningUsage)
}
func decodePemCrls(rawCrls []string) ([]*x509.RevocationList, error) {

View File

@@ -23,6 +23,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
func pathListCertsRevoked(b *backend) *framework.Path {
@@ -306,7 +308,7 @@ func (b *backend) pathRevokeWriteHandleCertificate(ctx context.Context, req *log
//
// We return the parsed serial number, an optionally-nil byte array to
// write out to disk, and an error if one occurred.
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
// We require listing all issuers from the 1.11 method. If we're
// still using the legacy CA bundle but with the newer certificate
// attribute, we err and require the operator to upgrade and migrate
@@ -534,7 +536,7 @@ func (b *backend) maybeRevokeCrossCluster(sc *storageContext, config *crlConfig,
return resp, nil
}
func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, data *framework.FieldData, _ *roleEntry) (*logical.Response, error) {
func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, data *framework.FieldData, _ *issuing.RoleEntry) (*logical.Response, error) {
rawSerial, haveSerial := data.GetOk("serial_number")
rawCertificate, haveCert := data.GetOk("certificate")
sc := b.makeStorageContext(ctx, req.Storage)
@@ -563,7 +565,7 @@ func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, dat
var cert *x509.Certificate
var serial string
config, err := sc.Backend.crlBuilder.getConfigWithUpdate(sc)
config, err := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if err != nil {
return nil, fmt.Errorf("error revoking serial: %s: failed reading config: %w", serial, err)
}
@@ -647,18 +649,18 @@ func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, dat
return nil, logical.ErrReadOnly
}
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
return revokeCert(sc, config, cert)
}
func (b *backend) pathRotateCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
b.revokeStorageLock.RLock()
defer b.revokeStorageLock.RUnlock()
b.GetRevokeStorageLock().RLock()
defer b.GetRevokeStorageLock().RUnlock()
sc := b.makeStorageContext(ctx, req.Storage)
warnings, crlErr := b.crlBuilder.rebuild(sc, false)
warnings, crlErr := b.CrlBuilder().rebuild(sc, false)
if crlErr != nil {
switch crlErr.(type) {
case errutil.UserError:
@@ -684,14 +686,14 @@ func (b *backend) pathRotateCRLRead(ctx context.Context, req *logical.Request, _
func (b *backend) pathRotateDeltaCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage)
cfg, err := b.crlBuilder.getConfigWithUpdate(sc)
cfg, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil {
return nil, fmt.Errorf("error fetching CRL configuration: %w", err)
}
isEnabled := cfg.EnableDelta
warnings, crlErr := b.crlBuilder.rebuildDeltaCRLsIfForced(sc, true)
warnings, crlErr := b.CrlBuilder().rebuildDeltaCRLsIfForced(sc, true)
if crlErr != nil {
switch crlErr.(type) {
case errutil.UserError:

View File

@@ -5,19 +5,20 @@ package pki
import (
"context"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
func pathListRoles(b *backend) *framework.Path {
@@ -860,137 +861,26 @@ serviced by this role.`,
}
}
func (b *backend) getRole(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
// GetRole loads a role from storage, will validate it and error out if,
// updates it and stores it back if possible. If the role does not exist
// a nil, nil response is returned.
func (b *backend) GetRole(ctx context.Context, s logical.Storage, n string) (*issuing.RoleEntry, error) {
result, err := issuing.GetRole(ctx, s, n)
if err != nil {
if errors.Is(err, issuing.ErrRoleNotFound) {
return nil, nil
}
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
// Migrate existing saved entries and save back if changed
modified := false
if len(result.DeprecatedTTL) == 0 && len(result.Lease) != 0 {
result.DeprecatedTTL = result.Lease
result.Lease = ""
modified = true
}
if result.TTL == 0 && len(result.DeprecatedTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedTTL)
if err != nil {
return nil, err
}
result.TTL = parsed
result.DeprecatedTTL = ""
modified = true
}
if len(result.DeprecatedMaxTTL) == 0 && len(result.LeaseMax) != 0 {
result.DeprecatedMaxTTL = result.LeaseMax
result.LeaseMax = ""
modified = true
}
if result.MaxTTL == 0 && len(result.DeprecatedMaxTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedMaxTTL)
if err != nil {
return nil, err
}
result.MaxTTL = parsed
result.DeprecatedMaxTTL = ""
modified = true
}
if result.AllowBaseDomain {
result.AllowBaseDomain = false
result.AllowBareDomains = true
modified = true
}
if result.AllowedDomainsOld != "" {
result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",")
result.AllowedDomainsOld = ""
modified = true
}
if result.AllowedBaseDomain != "" {
found := false
for _, v := range result.AllowedDomains {
if v == result.AllowedBaseDomain {
found = true
break
}
}
if !found {
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
}
result.AllowedBaseDomain = ""
modified = true
}
if result.AllowWildcardCertificates == nil {
// While not the most secure default, when AllowWildcardCertificates isn't
// explicitly specified in the stored Role, we automatically upgrade it to
// true to preserve compatibility with previous versions of Vault. Once this
// field is set, this logic will not be triggered any more.
result.AllowWildcardCertificates = new(bool)
*result.AllowWildcardCertificates = true
modified = true
}
// Upgrade generate_lease in role
if result.GenerateLease == nil {
// All the new roles will have GenerateLease always set to a value. A
// nil value indicates that this role needs an upgrade. Set it to
// `true` to not alter its current behavior.
result.GenerateLease = new(bool)
*result.GenerateLease = true
modified = true
}
// Upgrade key usages
if result.KeyUsageOld != "" {
result.KeyUsage = strings.Split(result.KeyUsageOld, ",")
result.KeyUsageOld = ""
modified = true
}
// Upgrade OU
if result.OUOld != "" {
result.OU = strings.Split(result.OUOld, ",")
result.OUOld = ""
modified = true
}
// Upgrade Organization
if result.OrganizationOld != "" {
result.Organization = strings.Split(result.OrganizationOld, ",")
result.OrganizationOld = ""
modified = true
}
// Set the issuer field to default if not set. We want to do this
// unconditionally as we should probably never have an empty issuer
// on a stored roles.
if len(result.Issuer) == 0 {
result.Issuer = defaultRef
modified = true
}
// Update CN Validations to be the present default, "email,hostname"
if len(result.CNValidations) == 0 {
result.CNValidations = []string{"email", "hostname"}
modified = true
}
// Ensure the role is valid after updating.
_, err = validateRole(b, &result, ctx, s)
_, err = validateRole(b, result, ctx, s)
if err != nil {
return nil, err
}
if modified && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) {
jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result)
if result.WasModified && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) {
jsonEntry, err := logical.StorageEntryJSON("role/"+n, result)
if err != nil {
return nil, err
}
@@ -1000,11 +890,10 @@ func (b *backend) getRole(ctx context.Context, s logical.Storage, n string) (*ro
return nil, err
}
}
result.WasModified = false
}
result.Name = n
return &result, nil
return result, nil
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1022,7 +911,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
return logical.ErrorResponse("missing role name"), nil
}
role, err := b.getRole(ctx, req.Storage, roleName)
role, err := b.GetRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -1049,7 +938,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data
var err error
name := data.Get("name").(string)
entry := &roleEntry{
entry := &issuing.RoleEntry{
MaxTTL: time.Duration(data.Get("max_ttl").(int)) * time.Second,
TTL: time.Duration(data.Get("ttl").(int)) * time.Second,
AllowLocalhost: data.Get("allow_localhost").(bool),
@@ -1156,7 +1045,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data
return resp, nil
}
func validateRole(b *backend, entry *roleEntry, ctx context.Context, s logical.Storage) (*logical.Response, error) {
func validateRole(b *backend, entry *issuing.RoleEntry, ctx context.Context, s logical.Storage) (*logical.Response, error) {
resp := &logical.Response{}
var err error
@@ -1194,11 +1083,11 @@ func validateRole(b *backend, entry *roleEntry, ctx context.Context, s logical.S
entry.Issuer = defaultRef
}
// Check that the issuers reference set resolves to something
if !b.useLegacyBundleCaStorage() {
if !b.UseLegacyBundleCaStorage() {
sc := b.makeStorageContext(ctx, s)
issuerId, err := sc.resolveIssuerReference(entry.Issuer)
if err != nil {
if issuerId == IssuerRefNotFound {
if issuerId == issuing.IssuerRefNotFound {
resp = &logical.Response{}
if entry.Issuer == defaultRef {
resp.AddWarning("Issuing Certificate was set to default, but no default issuing certificate (configurable at /config/issuers) is currently set")
@@ -1241,7 +1130,7 @@ func getTimeWithExplicitDefault(data *framework.FieldData, field string, default
func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
oldEntry, err := b.getRole(ctx, req.Storage, name)
oldEntry, err := b.GetRole(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@@ -1249,7 +1138,7 @@ func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data
return logical.ErrorResponse("Unable to fetch role entry to patch"), nil
}
entry := &roleEntry{
entry := &issuing.RoleEntry{
MaxTTL: getTimeWithExplicitDefault(data, "max_ttl", oldEntry.MaxTTL),
TTL: getTimeWithExplicitDefault(data, "ttl", oldEntry.TTL),
AllowLocalhost: getWithExplicitDefault(data, "allow_localhost", oldEntry.AllowLocalhost).(bool),
@@ -1363,206 +1252,6 @@ func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data
return resp, nil
}
func parseKeyUsages(input []string) int {
var parsedKeyUsages x509.KeyUsage
for _, k := range input {
switch strings.ToLower(strings.TrimSpace(k)) {
case "digitalsignature":
parsedKeyUsages |= x509.KeyUsageDigitalSignature
case "contentcommitment":
parsedKeyUsages |= x509.KeyUsageContentCommitment
case "keyencipherment":
parsedKeyUsages |= x509.KeyUsageKeyEncipherment
case "dataencipherment":
parsedKeyUsages |= x509.KeyUsageDataEncipherment
case "keyagreement":
parsedKeyUsages |= x509.KeyUsageKeyAgreement
case "certsign":
parsedKeyUsages |= x509.KeyUsageCertSign
case "crlsign":
parsedKeyUsages |= x509.KeyUsageCRLSign
case "encipheronly":
parsedKeyUsages |= x509.KeyUsageEncipherOnly
case "decipheronly":
parsedKeyUsages |= x509.KeyUsageDecipherOnly
}
}
return int(parsedKeyUsages)
}
func parseExtKeyUsages(role *roleEntry) certutil.CertExtKeyUsage {
var parsedKeyUsages certutil.CertExtKeyUsage
if role.ServerFlag {
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
}
if role.ClientFlag {
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
}
if role.CodeSigningFlag {
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
}
if role.EmailProtectionFlag {
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
}
for _, k := range role.ExtKeyUsage {
switch strings.ToLower(strings.TrimSpace(k)) {
case "any":
parsedKeyUsages |= certutil.AnyExtKeyUsage
case "serverauth":
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
case "clientauth":
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
case "codesigning":
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
case "emailprotection":
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
case "ipsecendsystem":
parsedKeyUsages |= certutil.IpsecEndSystemExtKeyUsage
case "ipsectunnel":
parsedKeyUsages |= certutil.IpsecTunnelExtKeyUsage
case "ipsecuser":
parsedKeyUsages |= certutil.IpsecUserExtKeyUsage
case "timestamping":
parsedKeyUsages |= certutil.TimeStampingExtKeyUsage
case "ocspsigning":
parsedKeyUsages |= certutil.OcspSigningExtKeyUsage
case "microsoftservergatedcrypto":
parsedKeyUsages |= certutil.MicrosoftServerGatedCryptoExtKeyUsage
case "netscapeservergatedcrypto":
parsedKeyUsages |= certutil.NetscapeServerGatedCryptoExtKeyUsage
}
}
return parsedKeyUsages
}
type roleEntry struct {
LeaseMax string `json:"lease_max"`
Lease string `json:"lease"`
DeprecatedMaxTTL string `json:"max_ttl"`
DeprecatedTTL string `json:"ttl"`
TTL time.Duration `json:"ttl_duration"`
MaxTTL time.Duration `json:"max_ttl_duration"`
AllowLocalhost bool `json:"allow_localhost"`
AllowedBaseDomain string `json:"allowed_base_domain"`
AllowedDomainsOld string `json:"allowed_domains,omitempty"`
AllowedDomains []string `json:"allowed_domains_list"`
AllowedDomainsTemplate bool `json:"allowed_domains_template"`
AllowBaseDomain bool `json:"allow_base_domain"`
AllowBareDomains bool `json:"allow_bare_domains"`
AllowTokenDisplayName bool `json:"allow_token_displayname"`
AllowSubdomains bool `json:"allow_subdomains"`
AllowGlobDomains bool `json:"allow_glob_domains"`
AllowWildcardCertificates *bool `json:"allow_wildcard_certificates,omitempty"`
AllowAnyName bool `json:"allow_any_name"`
EnforceHostnames bool `json:"enforce_hostnames"`
AllowIPSANs bool `json:"allow_ip_sans"`
ServerFlag bool `json:"server_flag"`
ClientFlag bool `json:"client_flag"`
CodeSigningFlag bool `json:"code_signing_flag"`
EmailProtectionFlag bool `json:"email_protection_flag"`
UseCSRCommonName bool `json:"use_csr_common_name"`
UseCSRSANs bool `json:"use_csr_sans"`
KeyType string `json:"key_type"`
KeyBits int `json:"key_bits"`
UsePSS bool `json:"use_pss"`
SignatureBits int `json:"signature_bits"`
MaxPathLength *int `json:",omitempty"`
KeyUsageOld string `json:"key_usage,omitempty"`
KeyUsage []string `json:"key_usage_list"`
ExtKeyUsage []string `json:"extended_key_usage_list"`
OUOld string `json:"ou,omitempty"`
OU []string `json:"ou_list"`
OrganizationOld string `json:"organization,omitempty"`
Organization []string `json:"organization_list"`
Country []string `json:"country"`
Locality []string `json:"locality"`
Province []string `json:"province"`
StreetAddress []string `json:"street_address"`
PostalCode []string `json:"postal_code"`
GenerateLease *bool `json:"generate_lease,omitempty"`
NoStore bool `json:"no_store"`
RequireCN bool `json:"require_cn"`
CNValidations []string `json:"cn_validations"`
AllowedOtherSANs []string `json:"allowed_other_sans"`
AllowedSerialNumbers []string `json:"allowed_serial_numbers"`
AllowedUserIDs []string `json:"allowed_user_ids"`
AllowedURISANs []string `json:"allowed_uri_sans"`
AllowedURISANsTemplate bool `json:"allowed_uri_sans_template"`
PolicyIdentifiers []string `json:"policy_identifiers"`
ExtKeyUsageOIDs []string `json:"ext_key_usage_oids"`
BasicConstraintsValidForNonCA bool `json:"basic_constraints_valid_for_non_ca"`
NotBeforeDuration time.Duration `json:"not_before_duration"`
NotAfter string `json:"not_after"`
Issuer string `json:"issuer"`
// Name is only set when the role has been stored, on the fly roles have a blank name
Name string `json:"-"`
}
func (r *roleEntry) ToResponseData() map[string]interface{} {
responseData := map[string]interface{}{
"ttl": int64(r.TTL.Seconds()),
"max_ttl": int64(r.MaxTTL.Seconds()),
"allow_localhost": r.AllowLocalhost,
"allowed_domains": r.AllowedDomains,
"allowed_domains_template": r.AllowedDomainsTemplate,
"allow_bare_domains": r.AllowBareDomains,
"allow_token_displayname": r.AllowTokenDisplayName,
"allow_subdomains": r.AllowSubdomains,
"allow_glob_domains": r.AllowGlobDomains,
"allow_wildcard_certificates": r.AllowWildcardCertificates,
"allow_any_name": r.AllowAnyName,
"allowed_uri_sans_template": r.AllowedURISANsTemplate,
"enforce_hostnames": r.EnforceHostnames,
"allow_ip_sans": r.AllowIPSANs,
"server_flag": r.ServerFlag,
"client_flag": r.ClientFlag,
"code_signing_flag": r.CodeSigningFlag,
"email_protection_flag": r.EmailProtectionFlag,
"use_csr_common_name": r.UseCSRCommonName,
"use_csr_sans": r.UseCSRSANs,
"key_type": r.KeyType,
"key_bits": r.KeyBits,
"signature_bits": r.SignatureBits,
"use_pss": r.UsePSS,
"key_usage": r.KeyUsage,
"ext_key_usage": r.ExtKeyUsage,
"ext_key_usage_oids": r.ExtKeyUsageOIDs,
"ou": r.OU,
"organization": r.Organization,
"country": r.Country,
"locality": r.Locality,
"province": r.Province,
"street_address": r.StreetAddress,
"postal_code": r.PostalCode,
"no_store": r.NoStore,
"allowed_other_sans": r.AllowedOtherSANs,
"allowed_serial_numbers": r.AllowedSerialNumbers,
"allowed_user_ids": r.AllowedUserIDs,
"allowed_uri_sans": r.AllowedURISANs,
"require_cn": r.RequireCN,
"cn_validations": r.CNValidations,
"policy_identifiers": r.PolicyIdentifiers,
"basic_constraints_valid_for_non_ca": r.BasicConstraintsValidForNonCA,
"not_before_duration": int64(r.NotBeforeDuration.Seconds()),
"not_after": r.NotAfter,
"issuer_ref": r.Issuer,
}
if r.MaxPathLength != nil {
responseData["max_path_length"] = r.MaxPathLength
}
if r.GenerateLease != nil {
responseData["generate_lease"] = r.GenerateLease
}
return responseData
}
func checkCNValidations(validations []string) ([]string, error) {
var haveDisabled bool
var haveEmail bool

View File

@@ -18,6 +18,8 @@ import (
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
)
func TestPki_RoleGenerateLease(t *testing.T) {
@@ -69,7 +71,7 @@ func TestPki_RoleGenerateLease(t *testing.T) {
t.Fatal(err)
}
var role roleEntry
var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err)
}
@@ -170,7 +172,7 @@ func TestPki_RoleKeyUsage(t *testing.T) {
t.Fatal(err)
}
var role roleEntry
var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err)
}
@@ -205,7 +207,7 @@ func TestPki_RoleKeyUsage(t *testing.T) {
if entry == nil {
t.Fatalf("role should not be nil")
}
var result roleEntry
var result issuing.RoleEntry
if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err)
}
@@ -265,7 +267,7 @@ func TestPki_RoleOUOrganizationUpgrade(t *testing.T) {
t.Fatal(err)
}
var role roleEntry
var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err)
}
@@ -305,7 +307,7 @@ func TestPki_RoleOUOrganizationUpgrade(t *testing.T) {
if entry == nil {
t.Fatalf("role should not be nil")
}
var result roleEntry
var result issuing.RoleEntry
if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err)
}
@@ -365,7 +367,7 @@ func TestPki_RoleAllowedDomains(t *testing.T) {
t.Fatal(err)
}
var role roleEntry
var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err)
}
@@ -399,7 +401,7 @@ func TestPki_RoleAllowedDomains(t *testing.T) {
if entry == nil {
t.Fatalf("role should not be nil")
}
var result roleEntry
var result issuing.RoleEntry
if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err)
}

View File

@@ -27,6 +27,9 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/parsing"
)
func pathGenerateRoot(b *backend) *framework.Path {
@@ -78,7 +81,7 @@ func (b *backend) pathCADeleteRoot(ctx context.Context, req *logical.Request, _
defer b.issuersLock.Unlock()
sc := b.makeStorageContext(ctx, req.Storage)
if !b.useLegacyBundleCaStorage() {
if !b.UseLegacyBundleCaStorage() {
issuers, err := sc.listIssuers()
if err != nil {
return nil, err
@@ -132,7 +135,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
var err error
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not create root CA until migration has completed"), nil
}
@@ -286,7 +289,8 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
// Also store it as just the certificate identified by serial number, so it
// can be revoked
key := "certs/" + normalizeSerial(cb.SerialNumber)
certsCounted := b.certsCounted.Load()
certCounter := b.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = req.Storage.Put(ctx, &logical.StorageEntry{
Key: key,
Value: parsedBundle.CertificateBytes,
@@ -294,10 +298,10 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
if err != nil {
return nil, fmt.Errorf("unable to store certificate locally: %w", err)
}
b.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key)
certCounter.IncrementTotalCertificatesCount(certsCounted, key)
// Build a fresh CRL
warnings, err = b.crlBuilder.rebuild(sc, true)
warnings, err = b.CrlBuilder().rebuild(sc, true)
if err != nil {
return nil, err
}
@@ -327,7 +331,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
@@ -337,7 +341,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
return logical.ErrorResponse(`The "format" path parameter must be "pem", "der" or "pem_bundle"`), nil
}
role := &roleEntry{
role := &issuing.RoleEntry{
OU: data.Get("ou").([]string),
Organization: data.Get("organization").([]string),
Country: data.Get("country").([]string),
@@ -369,7 +373,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
var caErr error
sc := b.makeStorageContext(ctx, req.Storage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, IssuanceUsage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, issuing.IssuanceUsage)
if caErr != nil {
switch caErr.(type) {
case errutil.UserError:
@@ -420,7 +424,8 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
}
key := "certs/" + normalizeSerialFromBigInt(parsedBundle.Certificate.SerialNumber)
certsCounted := b.certsCounted.Load()
certCounter := b.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = req.Storage.Put(ctx, &logical.StorageEntry{
Key: key,
Value: parsedBundle.CertificateBytes,
@@ -428,7 +433,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
if err != nil {
return nil, fmt.Errorf("unable to store certificate locally: %w", err)
}
b.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key)
certCounter.IncrementTotalCertificatesCount(certsCounted, key)
return resp, nil
}
@@ -512,19 +517,13 @@ func signIntermediateResponse(signingBundle *certutil.CAInfoBundle, parsedBundle
}
func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error
issuerName := getIssuerRef(data)
issuerName := GetIssuerRef(data)
if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil
}
certPem := data.Get("certificate").(string)
block, _ := pem.Decode([]byte(certPem))
if block == nil || len(block.Bytes) == 0 {
return logical.ErrorResponse("certificate could not be PEM-decoded"), nil
}
certs, err := x509.ParseCertificates(block.Bytes)
certs, err := parsing.ParseCertificatesFromString(certPem)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error parsing certificate: %s", err)), nil
}
@@ -540,9 +539,8 @@ func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Req
return logical.ErrorResponse("given certificate is not self-issued"), nil
}
var caErr error
sc := b.makeStorageContext(ctx, req.Storage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, IssuanceUsage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, issuing.IssuanceUsage)
if caErr != nil {
switch caErr.(type) {
case errutil.UserError:

View File

@@ -15,6 +15,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
@@ -1013,8 +1014,8 @@ func (b *backend) doTidyCertStore(ctx context.Context, req *logical.Request, log
}
func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Request, logger hclog.Logger, config *tidyConfig) error {
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
// Fetch and parse our issuers so we can associate them if necessary.
sc := b.makeStorageContext(ctx, req.Storage)
@@ -1047,9 +1048,9 @@ func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Reques
// Check for pause duration to reduce resource consumption.
if config.PauseDuration > (0 * time.Second) {
b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Unlock()
time.Sleep(config.PauseDuration)
b.revokeStorageLock.Lock()
b.GetRevokeStorageLock().Lock()
}
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
@@ -1092,7 +1093,7 @@ func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Reques
if config.IssuerAssocs {
if !isRevInfoIssuerValid(&revInfo, issuerIDCertMap) {
b.tidyStatusIncMissingIssuerCertCount()
revInfo.CertificateIssuer = issuerID("")
revInfo.CertificateIssuer = issuing.IssuerID("")
storeCert = true
if associateRevokedCertWithIsssuer(&revInfo, revokedCert, issuerIDCertMap) {
fixedIssuers += 1
@@ -1150,7 +1151,7 @@ func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Reques
}
if !config.AutoRebuild {
warnings, err := b.crlBuilder.rebuild(sc, false)
warnings, err := b.CrlBuilder().rebuild(sc, false)
if err != nil {
return err
}
@@ -1180,7 +1181,7 @@ func (b *backend) doTidyExpiredIssuers(ctx context.Context, req *logical.Request
// Short-circuit to avoid having to deal with the legacy mounts. While we
// could handle this case and remove these issuers, its somewhat
// unexpected behavior and we'd prefer to finish the migration first.
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return nil
}
@@ -1259,10 +1260,10 @@ func (b *backend) doTidyExpiredIssuers(ctx context.Context, req *logical.Request
// Removal of issuers is generally a good reason to rebuild the CRL,
// even if auto-rebuild is enabled.
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
warnings, err := b.crlBuilder.rebuild(sc, false)
warnings, err := b.CrlBuilder().rebuild(sc, false)
if err != nil {
return err
}
@@ -1290,7 +1291,7 @@ func (b *backend) doTidyMoveCABundle(ctx context.Context, req *logical.Request,
// Short-circuit to avoid moving the legacy bundle from under a legacy
// mount.
if b.useLegacyBundleCaStorage() {
if b.UseLegacyBundleCaStorage() {
return nil
}
@@ -1353,8 +1354,8 @@ func (b *backend) doTidyRevocationQueue(ctx context.Context, req *logical.Reques
}
// Grab locks as we're potentially modifying revocation-related storage.
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
for cIndex, cluster := range clusters {
if cluster[len(cluster)-1] == '/' {
@@ -1375,9 +1376,9 @@ func (b *backend) doTidyRevocationQueue(ctx context.Context, req *logical.Reques
// Check for pause duration to reduce resource consumption.
if config.PauseDuration > (0 * time.Second) {
b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Unlock()
time.Sleep(config.PauseDuration)
b.revokeStorageLock.Lock()
b.GetRevokeStorageLock().Lock()
}
// Confirmation entries _should_ be handled by this cluster's
@@ -1475,8 +1476,8 @@ func (b *backend) doTidyCrossRevocationStore(ctx context.Context, req *logical.R
}
// Grab locks as we're potentially modifying revocation-related storage.
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
for cIndex, cluster := range clusters {
if cluster[len(cluster)-1] == '/' {
@@ -1497,9 +1498,9 @@ func (b *backend) doTidyCrossRevocationStore(ctx context.Context, req *logical.R
// Check for pause duration to reduce resource consumption.
if config.PauseDuration > (0 * time.Second) {
b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Unlock()
time.Sleep(config.PauseDuration)
b.revokeStorageLock.Lock()
b.GetRevokeStorageLock().Lock()
}
ePath := cPath + serial
@@ -1547,7 +1548,7 @@ func (b *backend) doTidyAcme(ctx context.Context, req *logical.Request, logger h
b.tidyStatusLock.Unlock()
for _, thumbprint := range thumbprints {
err := b.tidyAcmeAccountByThumbprint(b.acmeState, sc, thumbprint, config.SafetyBuffer, config.AcmeAccountSafetyBuffer)
err := b.tidyAcmeAccountByThumbprint(b.GetAcmeState(), sc, thumbprint, config.SafetyBuffer, config.AcmeAccountSafetyBuffer)
if err != nil {
logger.Warn("error tidying account %v: %v", thumbprint, err.Error())
}
@@ -1567,13 +1568,13 @@ func (b *backend) doTidyAcme(ctx context.Context, req *logical.Request, logger h
}
// Clean up any unused EAB
eabIds, err := b.acmeState.ListEabIds(sc)
eabIds, err := b.GetAcmeState().ListEabIds(sc)
if err != nil {
return fmt.Errorf("failed listing EAB ids: %w", err)
}
for _, eabId := range eabIds {
eab, err := b.acmeState.LoadEab(sc, eabId)
eab, err := b.GetAcmeState().LoadEab(sc, eabId)
if err != nil {
if errors.Is(err, ErrStorageItemNotFound) {
// We don't need to worry about a consumed EAB
@@ -1584,7 +1585,7 @@ func (b *backend) doTidyAcme(ctx context.Context, req *logical.Request, logger h
eabExpiration := eab.CreatedOn.Add(config.AcmeAccountSafetyBuffer)
if time.Now().After(eabExpiration) {
_, err := b.acmeState.DeleteEab(sc, eabId)
_, err := b.GetAcmeState().DeleteEab(sc, eabId)
if err != nil {
return fmt.Errorf("failed to tidy eab %s: %w", eabId, err)
}
@@ -1669,15 +1670,17 @@ func (b *backend) pathTidyStatusRead(_ context.Context, _ *logical.Request, _ *f
resp.Data["internal_backend_uuid"] = b.backendUUID
if b.certCountEnabled.Load() {
resp.Data["current_cert_store_count"] = b.certCount.Load()
resp.Data["current_revoked_cert_count"] = b.revokedCertCount.Load()
if !b.certsCounted.Load() {
certCounter := b.GetCertificateCounter()
if certCounter.IsEnabled() {
resp.Data["current_cert_store_count"] = certCounter.CertificateCount()
resp.Data["current_revoked_cert_count"] = certCounter.RevokedCount()
if !certCounter.IsInitialized() {
resp.AddWarning("Certificates in storage are still being counted, current counts provided may be " +
"inaccurate")
}
if b.certCountError != "" {
resp.Data["certificate_counting_error"] = b.certCountError
certError := certCounter.Error()
if certError != nil {
resp.Data["certificate_counting_error"] = certError.Error()
}
}
@@ -1925,7 +1928,7 @@ func (b *backend) tidyStatusIncCertStoreCount() {
b.tidyStatus.certStoreDeletedCount++
b.ifCountEnabledDecrementTotalCertificatesCountReport()
b.GetCertificateCounter().DecrementTotalCertificatesCountReport()
}
func (b *backend) tidyStatusIncRevokedCertCount() {
@@ -1934,7 +1937,7 @@ func (b *backend) tidyStatusIncRevokedCertCount() {
b.tidyStatus.revokedCertDeletedCount++
b.ifCountEnabledDecrementTotalRevokedCertificatesCountReport()
b.GetCertificateCounter().DecrementTotalRevokedCertificatesCountReport()
}
func (b *backend) tidyStatusIncMissingIssuerCertCount() {

View File

@@ -18,18 +18,18 @@ const (
minUnifiedTransferDelay = 30 * time.Minute
)
type unifiedTransferStatus struct {
type UnifiedTransferStatus struct {
isRunning atomic.Bool
lastRun time.Time
forceRerun atomic.Bool
}
func (uts *unifiedTransferStatus) forceRun() {
func (uts *UnifiedTransferStatus) forceRun() {
uts.forceRerun.Store(true)
}
func newUnifiedTransferStatus() *unifiedTransferStatus {
return &unifiedTransferStatus{}
func newUnifiedTransferStatus() *UnifiedTransferStatus {
return &UnifiedTransferStatus{}
}
// runUnifiedTransfer meant to run as a background, this will process all and
@@ -37,7 +37,7 @@ func newUnifiedTransferStatus() *unifiedTransferStatus {
// is enabled.
func runUnifiedTransfer(sc *storageContext) {
b := sc.Backend
status := b.unifiedTransferStatus
status := b.GetUnifiedTransferStatus()
isPerfStandby := b.System().ReplicationState().HasState(consts.ReplicationDRSecondary | consts.ReplicationPerformanceStandby)
@@ -46,7 +46,7 @@ func runUnifiedTransfer(sc *storageContext) {
return
}
config, err := b.crlBuilder.getConfigWithUpdate(sc)
config, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil {
b.Logger().Error("failed to retrieve crl config from storage for unified transfer background process",
"error", err)
@@ -125,7 +125,7 @@ func doUnifiedTransferMissingLocalSerials(sc *storageContext, clusterId string)
errCount := 0
for i, serialNum := range localRevokedSerialNums {
if i%25 == 0 {
config, _ := sc.Backend.crlBuilder.getConfigWithUpdate(sc)
config, _ := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if config != nil && !config.UnifiedCRL {
return errors.New("unified crl has been disabled after we started, stopping")
}
@@ -224,7 +224,7 @@ func doUnifiedTransferMissingDeltaWALSerials(sc *storageContext, clusterId strin
errCount := 0
for index, serial := range localWALEntries {
if index%25 == 0 {
config, _ := sc.Backend.crlBuilder.getConfigWithUpdate(sc)
config, _ := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if config != nil && (!config.UnifiedCRL || !config.EnableDelta) {
return errors.New("unified or delta CRLs have been disabled after we started, stopping")
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package pki_backend
import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/logical"
)
type SystemViewGetter interface {
System() logical.SystemView
}
type MountInfo interface {
BackendUUID() string
}
type Logger interface {
Logger() log.Logger
}

View File

@@ -49,8 +49,8 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, _
return nil, fmt.Errorf("could not find serial in internal secret data")
}
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
b.GetRevokeStorageLock().Lock()
defer b.GetRevokeStorageLock().Unlock()
sc := b.makeStorageContext(ctx, req.Storage)
serial := serialInt.(string)
@@ -77,7 +77,7 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, _
return nil, nil
}
config, err := sc.Backend.crlBuilder.getConfigWithUpdate(sc)
config, err := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if err != nil {
return nil, fmt.Errorf("error revoking serial: %s: failed reading config: %w", serial, err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -19,16 +20,16 @@ import (
// and we need to perform it again...
const (
latestMigrationVersion = 2
legacyBundleShimID = issuerID("legacy-entry-shim-id")
legacyBundleShimKeyID = keyID("legacy-entry-shim-key-id")
legacyBundleShimID = issuing.LegacyBundleShimID
legacyBundleShimKeyID = issuing.LegacyBundleShimKeyID
)
type legacyBundleMigrationLog struct {
Hash string `json:"hash"`
Created time.Time `json:"created"`
CreatedIssuer issuerID `json:"issuer_id"`
CreatedKey keyID `json:"key_id"`
MigrationVersion int `json:"migrationVersion"`
Hash string `json:"hash"`
Created time.Time `json:"created"`
CreatedIssuer issuing.IssuerID `json:"issuer_id"`
CreatedKey issuing.KeyID `json:"key_id"`
MigrationVersion int `json:"migrationVersion"`
}
type migrationInfo struct {
@@ -84,8 +85,8 @@ func migrateStorage(ctx context.Context, b *backend, s logical.Storage) error {
return nil
}
var issuerIdentifier issuerID
var keyIdentifier keyID
var issuerIdentifier issuing.IssuerID
var keyIdentifier issuing.KeyID
sc := b.makeStorageContext(ctx, s)
if migrationInfo.legacyBundle != nil {
// When the legacy bundle still exists, there's three scenarios we
@@ -120,7 +121,7 @@ func migrateStorage(ctx context.Context, b *backend, s logical.Storage) error {
// Since we do not have all the mount information available we must schedule
// the CRL to be rebuilt at a later time.
b.crlBuilder.requestRebuildIfActiveNode(b)
b.CrlBuilder().requestRebuildIfActiveNode(b)
}
}
@@ -202,33 +203,6 @@ func setLegacyBundleMigrationLog(ctx context.Context, s logical.Storage, lbm *le
return s.Put(ctx, json)
}
func getLegacyCertBundle(ctx context.Context, s logical.Storage) (*issuerEntry, *certutil.CertBundle, error) {
entry, err := s.Get(ctx, legacyCertBundlePath)
if err != nil {
return nil, nil, err
}
if entry == nil {
return nil, nil, nil
}
cb := &certutil.CertBundle{}
err = entry.DecodeJSON(cb)
if err != nil {
return nil, nil, err
}
// Fake a storage entry with backwards compatibility in mind.
issuer := &issuerEntry{
ID: legacyBundleShimID,
KeyID: legacyBundleShimKeyID,
Name: "legacy-entry-shim",
Certificate: cb.Certificate,
CAChain: cb.CAChain,
SerialNumber: cb.SerialNumber,
LeafNotAfterBehavior: certutil.ErrNotAfterBehavior,
}
issuer.Usage.ToggleUsage(AllIssuerUsages)
return issuer, cb, nil
func getLegacyCertBundle(ctx context.Context, s logical.Storage) (*issuing.IssuerEntry, *certutil.CertBundle, error) {
return issuing.GetLegacyCertBundle(ctx, s)
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
@@ -23,7 +24,7 @@ func Test_migrateStorageEmptyStorage(t *testing.T) {
// Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
request := &logical.InitializationRequest{Storage: s}
err := b.initialize(ctx, request)
@@ -48,7 +49,7 @@ func Test_migrateStorageEmptyStorage(t *testing.T) {
require.Empty(t, logEntry.CreatedIssuer)
require.Empty(t, logEntry.CreatedKey)
require.False(t, b.useLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
require.False(t, b.UseLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
// Make sure we can re-run the migration without issues
request = &logical.InitializationRequest{Storage: s}
@@ -72,7 +73,7 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
// Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
bundle := genCertBundle(t, b, s)
// Clear everything except for the key
@@ -106,7 +107,7 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
"Hash value (%s) should not have been empty", logEntry.Hash)
require.True(t, startTime.Before(logEntry.Created),
"created log entry time (%v) was before our start time(%v)?", logEntry.Created, startTime)
require.Equal(t, logEntry.CreatedIssuer, issuerID(""))
require.Equal(t, logEntry.CreatedIssuer, issuing.IssuerID(""))
require.Equal(t, logEntry.CreatedKey, keyIds[0])
keyId := keyIds[0]
@@ -126,11 +127,11 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
// Make sure we setup the default values
keysConfig, err := sc.getKeysConfig()
require.NoError(t, err)
require.Equal(t, &keyConfigEntry{DefaultKeyId: keyId}, keysConfig)
require.Equal(t, &issuing.KeyConfigEntry{DefaultKeyId: keyId}, keysConfig)
issuersConfig, err := sc.getIssuersConfig()
require.NoError(t, err)
require.Equal(t, issuerID(""), issuersConfig.DefaultIssuerId)
require.Equal(t, issuing.IssuerID(""), issuersConfig.DefaultIssuerId)
// Make sure if we attempt to re-run the migration nothing happens...
err = migrateStorage(ctx, b, s)
@@ -142,7 +143,7 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
require.Equal(t, logEntry.Created, logEntry2.Created)
require.Equal(t, logEntry.Hash, logEntry2.Hash)
require.False(t, b.useLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
require.False(t, b.UseLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
}
func Test_migrateStorageSimpleBundle(t *testing.T) {
@@ -154,7 +155,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
// Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
bundle := genCertBundle(t, b, s)
json, err := logical.StorageEntryJSON(legacyCertBundlePath, bundle)
@@ -204,7 +205,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
require.Equal(t, keyId, issuer.KeyID)
require.Empty(t, issuer.ManualChain)
require.Equal(t, []string{bundle.Certificate + "\n"}, issuer.CAChain)
require.Equal(t, AllIssuerUsages, issuer.Usage)
require.Equal(t, issuing.AllIssuerUsages, issuer.Usage)
require.Equal(t, certutil.ErrNotAfterBehavior, issuer.LeafNotAfterBehavior)
require.Equal(t, keyId, key.ID)
@@ -219,7 +220,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
// Make sure we setup the default values
keysConfig, err := sc.getKeysConfig()
require.NoError(t, err)
require.Equal(t, &keyConfigEntry{DefaultKeyId: keyId}, keysConfig)
require.Equal(t, &issuing.KeyConfigEntry{DefaultKeyId: keyId}, keysConfig)
issuersConfig, err := sc.getIssuersConfig()
require.NoError(t, err)
@@ -235,7 +236,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
require.Equal(t, logEntry.Created, logEntry2.Created)
require.Equal(t, logEntry.Hash, logEntry2.Hash)
require.False(t, b.useLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
require.False(t, b.UseLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
// Make sure we can re-process a migration from scratch for whatever reason
err = s.Delete(ctx, legacyMigrationBundleLogKey)
@@ -296,8 +297,8 @@ func TestMigration_OnceChainRebuild(t *testing.T) {
//
// Afterwards, we mutate these issuers to only point at themselves and
// write back out.
var rootIssuerId issuerID
var intIssuerId issuerID
var rootIssuerId issuing.IssuerID
var intIssuerId issuing.IssuerID
for _, issuerId := range issuerIds {
issuer, err := sc.fetchIssuerById(issuerId)
require.NoError(t, err)
@@ -368,7 +369,7 @@ func TestExpectedOpsWork_PreMigration(t *testing.T) {
b, s := CreateBackendWithStorage(t)
// Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
bundle := genCertBundle(t, b, s)
json, err := logical.StorageEntryJSON(legacyCertBundlePath, bundle)
@@ -601,7 +602,7 @@ func TestBackupBundle(t *testing.T) {
// Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
// Create an empty request and tidy configuration for us.
req := &logical.Request{
@@ -793,7 +794,7 @@ func TestDeletedIssuersPostMigration(t *testing.T) {
// Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
// Create a legacy CA bundle and write it out.
bundle := genCertBundle(t, b, s)

View File

@@ -8,6 +8,7 @@ import (
"strings"
"testing"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
@@ -22,27 +23,27 @@ func Test_ConfigsRoundTrip(t *testing.T) {
sc := b.makeStorageContext(ctx, s)
// Create an empty key, issuer for testing.
key := keyEntry{ID: genKeyId()}
key := issuing.KeyEntry{ID: genKeyId()}
err := sc.writeKey(key)
require.NoError(t, err)
issuer := &issuerEntry{ID: genIssuerId()}
issuer := &issuing.IssuerEntry{ID: genIssuerId()}
err = sc.writeIssuer(issuer)
require.NoError(t, err)
// Verify we handle nothing stored properly
keyConfigEmpty, err := sc.getKeysConfig()
require.NoError(t, err)
require.Equal(t, &keyConfigEntry{}, keyConfigEmpty)
require.Equal(t, &issuing.KeyConfigEntry{}, keyConfigEmpty)
issuerConfigEmpty, err := sc.getIssuersConfig()
require.NoError(t, err)
require.Equal(t, &issuerConfigEntry{}, issuerConfigEmpty)
require.Equal(t, &issuing.IssuerConfigEntry{}, issuerConfigEmpty)
// Now attempt to store and reload properly
origKeyConfig := &keyConfigEntry{
origKeyConfig := &issuing.KeyConfigEntry{
DefaultKeyId: key.ID,
}
origIssuerConfig := &issuerConfigEntry{
origIssuerConfig := &issuing.IssuerConfigEntry{
DefaultIssuerId: issuer.ID,
}
@@ -98,12 +99,12 @@ func Test_IssuerRoundTrip(t *testing.T) {
keys, err := sc.listKeys()
require.NoError(t, err)
require.ElementsMatch(t, []keyID{key1.ID, key2.ID}, keys)
require.ElementsMatch(t, []issuing.KeyID{key1.ID, key2.ID}, keys)
issuers, err := sc.listIssuers()
require.NoError(t, err)
require.ElementsMatch(t, []issuerID{issuer1.ID, issuer2.ID}, issuers)
require.ElementsMatch(t, []issuing.IssuerID{issuer1.ID, issuer2.ID}, issuers)
}
func Test_KeysIssuerImport(t *testing.T) {
@@ -183,7 +184,7 @@ func Test_IssuerUpgrade(t *testing.T) {
// Make sure that we add OCSP signing to v0 issuers if CRLSigning is enabled
issuer, _ := genIssuerAndKey(t, b, s)
issuer.Version = 0
issuer.Usage.ToggleUsage(OCSPSigningUsage)
issuer.Usage.ToggleUsage(issuing.OCSPSigningUsage)
err := sc.writeIssuer(&issuer)
require.NoError(t, err, "failed writing out issuer")
@@ -192,13 +193,13 @@ func Test_IssuerUpgrade(t *testing.T) {
require.NoError(t, err, "failed fetching issuer")
require.Equal(t, uint(1), newIssuer.Version)
require.True(t, newIssuer.Usage.HasUsage(OCSPSigningUsage))
require.True(t, newIssuer.Usage.HasUsage(issuing.OCSPSigningUsage))
// If CRLSigning is not present on a v0, we should not have OCSP signing after upgrade.
issuer, _ = genIssuerAndKey(t, b, s)
issuer.Version = 0
issuer.Usage.ToggleUsage(OCSPSigningUsage)
issuer.Usage.ToggleUsage(CRLSigningUsage)
issuer.Usage.ToggleUsage(issuing.OCSPSigningUsage)
issuer.Usage.ToggleUsage(issuing.CRLSigningUsage)
err = sc.writeIssuer(&issuer)
require.NoError(t, err, "failed writing out issuer")
@@ -207,15 +208,15 @@ func Test_IssuerUpgrade(t *testing.T) {
require.NoError(t, err, "failed fetching issuer")
require.Equal(t, uint(1), newIssuer.Version)
require.False(t, newIssuer.Usage.HasUsage(OCSPSigningUsage))
require.False(t, newIssuer.Usage.HasUsage(issuing.OCSPSigningUsage))
}
func genIssuerAndKey(t *testing.T, b *backend, s logical.Storage) (issuerEntry, keyEntry) {
func genIssuerAndKey(t *testing.T, b *backend, s logical.Storage) (issuing.IssuerEntry, issuing.KeyEntry) {
certBundle := genCertBundle(t, b, s)
keyId := genKeyId()
pkiKey := keyEntry{
pkiKey := issuing.KeyEntry{
ID: keyId,
PrivateKeyType: certBundle.PrivateKeyType,
PrivateKey: strings.TrimSpace(certBundle.PrivateKey) + "\n",
@@ -223,14 +224,14 @@ func genIssuerAndKey(t *testing.T, b *backend, s logical.Storage) (issuerEntry,
issuerId := genIssuerId()
pkiIssuer := issuerEntry{
pkiIssuer := issuing.IssuerEntry{
ID: issuerId,
KeyID: keyId,
Certificate: strings.TrimSpace(certBundle.Certificate) + "\n",
CAChain: certBundle.CAChain,
SerialNumber: certBundle.SerialNumber,
Usage: AllIssuerUsages,
Version: latestIssuerVersion,
Usage: issuing.AllIssuerUsages,
Version: issuing.LatestIssuerVersion,
}
return pkiIssuer, pkiKey

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -17,10 +18,10 @@ const (
)
type unifiedRevocationEntry struct {
SerialNumber string `json:"-"`
CertExpiration time.Time `json:"certificate_expiration_utc"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"`
CertificateIssuer issuerID `json:"issuer_id"`
SerialNumber string `json:"-"`
CertExpiration time.Time `json:"certificate_expiration_utc"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"`
CertificateIssuer issuing.IssuerID `json:"issuer_id"`
}
func getUnifiedRevocationBySerial(sc *storageContext, serial string) (*unifiedRevocationEntry, error) {

View File

@@ -4,7 +4,6 @@
package pki
import (
"crypto"
"crypto/x509"
"fmt"
"math/big"
@@ -14,6 +13,8 @@ import (
"sync"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
@@ -24,7 +25,7 @@ import (
const (
managedKeyNameArg = "managed_key_name"
managedKeyIdArg = "managed_key_id"
defaultRef = "default"
defaultRef = issuing.DefaultRef
// Constants for If-Modified-Since operation
headerIfModifiedSince = "If-Modified-Since"
@@ -92,26 +93,6 @@ type managedKeyId interface {
String() string
}
type (
UUIDKey string
NameKey string
)
func (u UUIDKey) String() string {
return string(u)
}
func (n NameKey) String() string {
return string(n)
}
type managedKeyInfo struct {
publicKey crypto.PublicKey
keyType certutil.PrivateKeyType
name NameKey
uuid UUIDKey
}
// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the
// request API data.
func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) {
@@ -120,9 +101,9 @@ func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) {
return nil, err
}
var keyId managedKeyId = NameKey(name)
var keyId managedKeyId = managed_key.NameKey(name)
if len(UUID) > 0 {
keyId = UUIDKey(UUID)
keyId = managed_key.UUIDKey(UUID)
}
return keyId, nil
@@ -188,7 +169,7 @@ func getIssuerName(sc *storageContext, data *framework.FieldData) (string, error
return issuerName, errIssuerNameInUse
}
if err != nil && issuerId != IssuerRefNotFound {
if err != nil && issuerId != issuing.IssuerRefNotFound {
return issuerName, errutil.InternalError{Err: err.Error()}
}
}
@@ -213,14 +194,14 @@ func getKeyName(sc *storageContext, data *framework.FieldData) (string, error) {
return "", errKeyNameInUse
}
if err != nil && keyId != KeyRefNotFound {
if err != nil && keyId != issuing.KeyRefNotFound {
return "", errutil.InternalError{Err: err.Error()}
}
}
return keyName, nil
}
func getIssuerRef(data *framework.FieldData) string {
func GetIssuerRef(data *framework.FieldData) string {
return extractRef(data, issuerRefParam)
}
@@ -286,7 +267,7 @@ const (
type IfModifiedSinceHelper struct {
req *logical.Request
reqType ifModifiedReqType
issuerRef issuerID
issuerRef issuing.IssuerID
}
func sendNotModifiedResponseIfNecessary(helper *IfModifiedSinceHelper, sc *storageContext, resp *logical.Response) (bool, error) {
@@ -326,7 +307,7 @@ func (sc *storageContext) isIfModifiedSinceBeforeLastModified(helper *IfModified
switch helper.reqType {
case ifModifiedCRL, ifModifiedDeltaCRL:
if sc.Backend.crlBuilder.invalidate.Load() {
if sc.Backend.CrlBuilder().invalidate.Load() {
// When we see the CRL is invalidated, respond with false
// regardless of what the local CRL state says. We've likely
// renamed some issuers or are about to rebuild a new CRL....
@@ -346,7 +327,7 @@ func (sc *storageContext) isIfModifiedSinceBeforeLastModified(helper *IfModified
lastModified = crlConfig.DeltaLastModified
}
case ifModifiedUnifiedCRL, ifModifiedUnifiedDeltaCRL:
if sc.Backend.crlBuilder.invalidate.Load() {
if sc.Backend.CrlBuilder().invalidate.Load() {
// When we see the CRL is invalidated, respond with false
// regardless of what the local CRL state says. We've likely
// renamed some issuers or are about to rebuild a new CRL....

View File

@@ -163,6 +163,21 @@ func GetPrivateKeyTypeFromSigner(signer crypto.Signer) PrivateKeyType {
return UnknownPrivateKey
}
// GetPrivateKeyTypeFromPublicKey based on the public key, return the PrivateKeyType
// that would be associated with it, returning UnknownPrivateKey for unsupported types
func GetPrivateKeyTypeFromPublicKey(pubKey crypto.PublicKey) PrivateKeyType {
switch pubKey.(type) {
case *rsa.PublicKey:
return RSAPrivateKey
case *ecdsa.PublicKey:
return ECPrivateKey
case *ed25519.PublicKey:
return Ed25519PrivateKey
default:
return UnknownPrivateKey
}
}
// ToPEMBundle converts a string-based certificate bundle
// to a PEM-based string certificate bundle in trust path
// order, leaf certificate first