PKI refactoring to start breaking apart monolith into sub-packages (#24406)

* PKI refactoring to start breaking apart monolith into sub-packages

 - This was broken down by commit within enterprise for ease of review,
   but it would be too difficult to bring the individual commits back
   to the CE repository (they would be squashed anyway).
 - This change was created by exporting a patch of the enterprise PR
   and applying it to the CE repository.

* Fix TestBackend_OID_SANs so it does not rely on map ordering
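
The ordering fix sorts the other-name SANs parsed back out of the issued certificate before comparing them to the expected slice, since they are assembled from a Go map and can come back in any order. A minimal, self-contained sketch of that pattern (OtherNameUtf8 here is a local stand-in for the exported issuing.OtherNameUtf8 type, and the OIDs/values are illustrative):

    package main

    import (
        "cmp"
        "fmt"
        "slices"
    )

    // OtherNameUtf8 is a stand-in for the SAN type the test compares.
    type OtherNameUtf8 struct {
        Oid   string
        Value string
    }

    func main() {
        // SANs reconstructed from a map arrive in nondeterministic order.
        found := []OtherNameUtf8{
            {Oid: "1.2.3.5", Value: "me@example.com"},
            {Oid: "1.2.3.4", Value: "devops@nope.com"},
        }
        // Sorting by OID makes the subsequent deep-equality check stable.
        slices.SortFunc(found, func(a, b OtherNameUtf8) int { return cmp.Compare(a.Oid, b.Oid) })
        fmt.Println(found)
    }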
Author: Steven Clark
Date: 2023-12-07 09:22:53 -05:00
Committed by: GitHub
Parent: a4180c193b
Commit: cbf6dc2c4f
70 changed files with 4620 additions and 3543 deletions

@@ -17,6 +17,7 @@ import (
 	"time"
 	"github.com/hashicorp/go-secure-stdlib/nonceutil"
+	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
 	"github.com/hashicorp/vault/sdk/framework"
 	"github.com/hashicorp/vault/sdk/helper/consts"
 	"github.com/hashicorp/vault/sdk/logical"
@@ -220,7 +221,7 @@ type acmeOrder struct {
 	CertificateSerialNumber string `json:"cert-serial-number"`
 	CertificateExpiry time.Time `json:"cert-expiry"`
 	// The actual issuer UUID that issued the certificate, blank if an order exists but no certificate was issued.
-	IssuerId issuerID `json:"issuer-id"`
+	IssuerId issuing.IssuerID `json:"issuer-id"`
 }

 func (o acmeOrder) getIdentifierDNSValues() []string {

@@ -13,6 +13,8 @@ import (
 	"github.com/hashicorp/vault/sdk/framework"
 	"github.com/hashicorp/vault/sdk/logical"
+	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
 )

 type acmeContext struct {
@@ -20,8 +22,8 @@ type acmeContext struct {
 	baseUrl *url.URL
 	clusterUrl *url.URL
 	sc *storageContext
-	role *roleEntry
-	issuer *issuerEntry
+	role *issuing.RoleEntry
+	issuer *issuing.IssuerEntry
 	// acmeDirectory is a string that can distinguish the various acme directories we have configured
 	// if something needs to remain locked into a directory path structure.
 	acmeDirectory string
@@ -31,7 +33,7 @@ type acmeContext struct {
 }

 func (c acmeContext) getAcmeState() *acmeState {
-	return c.sc.Backend.acmeState
+	return c.sc.Backend.GetAcmeState()
 }

 type (
@@ -109,7 +111,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
 	return acmeErrorWrapper(func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 		sc := b.makeStorageContext(ctx, r.Storage)

-		config, err := sc.Backend.acmeState.getConfigWithUpdate(sc)
+		config, err := sc.Backend.GetAcmeState().getConfigWithUpdate(sc)
 		if err != nil {
 			return nil, fmt.Errorf("failed to fetch ACME configuration: %w", err)
 		}
@@ -124,7 +126,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
 			return nil, ErrAcmeDisabled
 		}

-		if b.useLegacyBundleCaStorage() {
+		if b.UseLegacyBundleCaStorage() {
 			return nil, fmt.Errorf("%w: Can not perform ACME operations until migration has completed", ErrServerInternal)
 		}
@@ -180,7 +182,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
 // it does not enforce the account being in a valid state nor existing.
 func (b *backend) acmeParsedWrapper(opt acmeWrapperOpts, op acmeParsedOperation) framework.OperationFunc {
 	return b.acmeWrapper(opt, func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
-		user, data, err := b.acmeState.ParseRequestParams(acmeCtx, r, fields)
+		user, data, err := b.GetAcmeState().ParseRequestParams(acmeCtx, r, fields)
 		if err != nil {
 			return nil, err
 		}
@@ -194,7 +196,7 @@ func (b *backend) acmeParsedWrapper(opt acmeWrapperOpts, op acmeParsedOperation)
 		}

 		if _, ok := resp.Headers["Replay-Nonce"]; !ok {
-			nonce, _, err := b.acmeState.GetNonce()
+			nonce, _, err := b.GetAcmeState().GetNonce()
 			if err != nil {
 				return nil, err
 			}
@@ -255,8 +257,7 @@ func (b *backend) acmeParsedWrapper(opt acmeWrapperOpts, op acmeParsedOperation)
 // request has a proper signature for an existing account, and that account is
 // in a valid status. It passes to the operation a decoded form of the request
 // parameters as well as the ACME account the request is for.
-func (b *backend) acmeAccountRequiredWrapper(opt acmeWrapperOpts, op acmeAccountRequiredOperation) framework.
-OperationFunc {
+func (b *backend) acmeAccountRequiredWrapper(opt acmeWrapperOpts, op acmeAccountRequiredOperation) framework.OperationFunc {
 	return b.acmeParsedWrapper(opt, func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
 		if !uc.Existing {
 			return nil, fmt.Errorf("cannot process request without a 'kid': %w", ErrMalformed)
@@ -320,7 +321,7 @@ func getBasePathFromClusterConfig(sc *storageContext) (*url.URL, error) {
 	return baseUrl, nil
 }

-func getAcmeIssuer(sc *storageContext, issuerName string) (*issuerEntry, error) {
+func getAcmeIssuer(sc *storageContext, issuerName string) (*issuing.IssuerEntry, error) {
 	if issuerName == "" {
 		issuerName = defaultRef
 	}
@@ -334,7 +335,7 @@ func getAcmeIssuer(sc *storageContext, issuerName string) (*issuerEntry, error)
 		return nil, fmt.Errorf("issuer failed to load: %w", err)
 	}

-	if issuer.Usage.HasUsage(IssuanceUsage) && len(issuer.KeyID) > 0 {
+	if issuer.Usage.HasUsage(issuing.IssuanceUsage) && len(issuer.KeyID) > 0 {
 		return issuer, nil
 	}
@@ -358,12 +359,12 @@ func getAcmeDirectory(r *logical.Request) (string, error) {
 	return strings.TrimLeft(acmePath[0:lastIndex]+"/acme/", "/"), nil
 }

-func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config *acmeConfigEntry) (*roleEntry, *issuerEntry, error) {
+func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config *acmeConfigEntry) (*issuing.RoleEntry, *issuing.IssuerEntry, error) {
 	requestedIssuer := getRequestedAcmeIssuerFromPath(data)
 	requestedRole := getRequestedAcmeRoleFromPath(data)
 	issuerToLoad := requestedIssuer

-	var role *roleEntry
+	var role *issuing.RoleEntry
 	var err error
 	if len(requestedRole) == 0 { // Default Directory
@@ -375,11 +376,9 @@ func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config
 		case Forbid:
 			return nil, nil, fmt.Errorf("%w: default directory not allowed by ACME policy", ErrServerInternal)
 		case SignVerbatim, ExternalPolicy:
-			role = buildSignVerbatimRoleWithNoData(&roleEntry{
-				Issuer: requestedIssuer,
-				NoStore: false,
-				Name: requestedRole,
-			})
+			role = issuing.SignVerbatimRoleWithOpts(
+				issuing.WithIssuer(requestedIssuer),
+				issuing.WithNoStore(false))
 		case Role:
 			role, err = getAndValidateAcmeRole(sc, extraInfo)
 			if err != nil {
@@ -455,9 +454,9 @@ func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config
 	return role, issuer, nil
 }

-func getAndValidateAcmeRole(sc *storageContext, requestedRole string) (*roleEntry, error) {
+func getAndValidateAcmeRole(sc *storageContext, requestedRole string) (*issuing.RoleEntry, error) {
 	var err error
-	role, err := sc.Backend.getRole(sc.Context, sc.Storage, requestedRole)
+	role, err := sc.Backend.GetRole(sc.Context, sc.Storage, requestedRole)
 	if err != nil {
 		return nil, fmt.Errorf("%w: err loading role", ErrServerInternal)
 	}
@@ -491,14 +490,6 @@ func getRequestedAcmeIssuerFromPath(data *framework.FieldData) string {
 	return requestedIssuer
 }

-func getRequestedPolicyFromPath(data *framework.FieldData) string {
-	requestedPolicy := ""
-	if requestedPolicyRaw, present := data.GetOk("policy"); present {
-		requestedPolicy = requestedPolicyRaw.(string)
-	}
-	return requestedPolicy
-}
-
 func isAcmeDisabled(sc *storageContext, config *acmeConfigEntry, policy EabPolicy) bool {
 	if !config.Enabled {
 		return true

@@ -6,21 +6,23 @@ package pki
 import (
 	"context"
 	"fmt"
-	"sort"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
-	atomic2 "go.uber.org/atomic"
 	"github.com/armon/go-metrics"
 	"github.com/hashicorp/go-multierror"
 	"github.com/hashicorp/vault/helper/metricsutil"
 	"github.com/hashicorp/vault/helper/namespace"
 	"github.com/hashicorp/vault/sdk/framework"
 	"github.com/hashicorp/vault/sdk/helper/consts"
+	"github.com/hashicorp/vault/sdk/helper/errutil"
 	"github.com/hashicorp/vault/sdk/logical"
+	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
+	"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
+	"github.com/hashicorp/vault/builtin/logical/pki/pki_backend"
 )

 const (
@@ -287,19 +289,10 @@ func Backend(conf *logical.BackendConfig) *backend {
 	// Delay the first tidy until after we've started up.
 	b.lastTidy = time.Now()

-	// Metrics initialization for count of certificates in storage
-	b.certCountEnabled = atomic2.NewBool(false)
-	b.publishCertCountMetrics = atomic2.NewBool(false)
-	b.certsCounted = atomic2.NewBool(false)
-	b.certCountError = "Initialize Not Yet Run, Cert Counts Unavailable"
-	b.certCount = &atomic.Uint32{}
-	b.revokedCertCount = &atomic.Uint32{}
-	b.possibleDoubleCountedSerials = make([]string, 0, 250)
-	b.possibleDoubleCountedRevokedSerials = make([]string, 0, 250)
-
 	b.unifiedTransferStatus = newUnifiedTransferStatus()

 	b.acmeState = NewACMEState()
+	b.certificateCounter = NewCertificateCounter(b.backendUUID)

 	b.SetupEnt()
 	return &b
@@ -319,19 +312,12 @@ type backend struct {
 	tidyStatus *tidyStatus
 	lastTidy time.Time

-	unifiedTransferStatus *unifiedTransferStatus
+	unifiedTransferStatus *UnifiedTransferStatus

-	certCountEnabled *atomic2.Bool
-	publishCertCountMetrics *atomic2.Bool
-	certCount *atomic.Uint32
-	revokedCertCount *atomic.Uint32
-	certsCounted *atomic2.Bool
-	certCountError string
-	possibleDoubleCountedSerials []string
-	possibleDoubleCountedRevokedSerials []string
+	certificateCounter *CertificateCounter

 	pkiStorageVersion atomic.Value
-	crlBuilder *crlBuilder
+	crlBuilder *CrlBuilder

 	// Write lock around issuers and keys.
 	issuersLock sync.RWMutex
@@ -341,7 +327,25 @@ type backend struct {
 	acmeAccountLock sync.RWMutex // (Write) Locked on Tidy, (Read) Locked on Account Creation
 }

-type roleOperation func(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error)
+// BackendOps a bridge/legacy interface until we can further
+// separate out backend things into distinct packages.
+type BackendOps interface {
+	managed_key.PkiManagedKeyView
+	pki_backend.SystemViewGetter
+	pki_backend.MountInfo
+	pki_backend.Logger
+	UseLegacyBundleCaStorage() bool
+	CrlBuilder() *CrlBuilder
+	GetRevokeStorageLock() *sync.RWMutex
+	GetUnifiedTransferStatus() *UnifiedTransferStatus
+	GetAcmeState() *acmeState
+	GetRole(ctx context.Context, s logical.Storage, n string) (*issuing.RoleEntry, error)
+	GetCertificateCounter() *CertificateCounter
+}
+
+var _ BackendOps = &backend{}
+
+type roleOperation func(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error)

 const backendHelp = `
 The PKI backend dynamically generates X509 server and client certificates.
@@ -363,7 +367,7 @@ func metricsKey(req *logical.Request, extra ...string) []string {
 func (b *backend) metricsWrap(callType string, roleMode int, ofunc roleOperation) framework.OperationFunc {
 	return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
 		key := metricsKey(req, callType)
-		var role *roleEntry
+		var role *issuing.RoleEntry
 		var labels []metrics.Label
 		var err error
@@ -379,7 +383,7 @@ func (b *backend) metricsWrap(callType string, roleMode int, ofunc roleOperation
 		}
 		if roleMode > noRole {
 			// Get the role
-			role, err = b.getRole(ctx, req.Storage, roleName)
+			role, err = b.GetRole(ctx, req.Storage, roleName)
 			if err != nil {
 				return nil, err
 			}
@@ -410,7 +414,7 @@ func (b *backend) metricsWrap(callType string, roleMode int, ofunc roleOperation
 // initialize is used to perform a possible PKI storage migration if needed
 func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequest) error {
 	sc := b.makeStorageContext(ctx, b.storage)
-	if err := b.crlBuilder.reloadConfigIfRequired(sc); err != nil {
+	if err := b.CrlBuilder().reloadConfigIfRequired(sc); err != nil {
 		return err
 	}
@@ -419,7 +423,7 @@ func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequ
 		return err
 	}

-	err = b.acmeState.Initialize(b, sc)
+	err = b.GetAcmeState().Initialize(b, sc)
 	if err != nil {
 		return err
 	}
@@ -429,7 +433,7 @@ func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequ
 	if err != nil {
 		// Don't block/err initialize/startup for metrics. Context on this call can time out due to number of certificates.
 		b.Logger().Error("Could not initialize stored certificate counts", "error", err)
-		b.certCountError = err.Error()
+		b.GetCertificateCounter().SetError(err)
 	}

 	return b.initializeEnt(sc, ir)
@@ -438,7 +442,7 @@ func (b *backend) initialize(ctx context.Context, ir *logical.InitializationRequ
 func (b *backend) cleanup(ctx context.Context) {
 	sc := b.makeStorageContext(ctx, b.storage)

-	b.acmeState.Shutdown(b)
+	b.GetAcmeState().Shutdown(b)

 	b.cleanupEnt(sc)
 }
@@ -469,7 +473,31 @@ func (b *backend) initializePKIIssuersStorage(ctx context.Context) error {
 	return nil
 }

-func (b *backend) useLegacyBundleCaStorage() bool {
+func (b *backend) BackendUUID() string {
+	return b.backendUUID
+}
+
+func (b *backend) CrlBuilder() *CrlBuilder {
+	return b.crlBuilder
+}
+
+func (b *backend) GetRevokeStorageLock() *sync.RWMutex {
+	return &b.revokeStorageLock
+}
+
+func (b *backend) GetUnifiedTransferStatus() *UnifiedTransferStatus {
+	return b.unifiedTransferStatus
+}
+
+func (b *backend) GetAcmeState() *acmeState {
+	return b.acmeState
+}
+
+func (b *backend) GetCertificateCounter() *CertificateCounter {
+	return b.certificateCounter
+}
+
+func (b *backend) UseLegacyBundleCaStorage() bool {
 	// This helper function is here to choose whether or not we use the newer
 	// issuer/key storage format or the older legacy ca bundle format.
 	//
@@ -482,6 +510,18 @@ func (b *backend) useLegacyBundleCaStorage() bool {
 	return version == nil || version == 0
 }

+func (b *backend) IsSecondaryNode() bool {
+	return b.System().ReplicationState().HasState(consts.ReplicationPerformanceStandby)
+}
+
+func (b *backend) GetManagedKeyView() (logical.ManagedKeySystemView, error) {
+	managedKeyView, ok := b.System().(logical.ManagedKeySystemView)
+	if !ok {
+		return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported system view")}
+	}
+	return managedKeyView, nil
+}
+
 func (b *backend) updatePkiStorageVersion(ctx context.Context, grabIssuersLock bool) {
 	info, err := getMigrationInfo(ctx, b.storage)
 	if err != nil {
@@ -520,36 +560,36 @@ func (b *backend) invalidate(ctx context.Context, key string) {
 		go func() {
 			b.Logger().Info("Detected a migration completed, resetting pki storage version")
 			b.updatePkiStorageVersion(ctx, true)
-			b.crlBuilder.requestRebuildIfActiveNode(b)
+			b.CrlBuilder().requestRebuildIfActiveNode(b)
 		}()
 	case strings.HasPrefix(key, issuerPrefix):
-		if !b.useLegacyBundleCaStorage() {
+		if !b.UseLegacyBundleCaStorage() {
 			// See note in updateDefaultIssuerId about why this is necessary.
 			// We do this ahead of CRL rebuilding just so we know that things
 			// are stale.
-			b.crlBuilder.invalidateCRLBuildTime()
+			b.CrlBuilder().invalidateCRLBuildTime()
 			// If an issuer has changed on the primary, we need to schedule an update of our CRL,
 			// the primary cluster would have done it already, but the CRL is cluster specific so
 			// force a rebuild of ours.
-			b.crlBuilder.requestRebuildIfActiveNode(b)
+			b.CrlBuilder().requestRebuildIfActiveNode(b)
 		} else {
 			b.Logger().Debug("Ignoring invalidation updates for issuer as the PKI migration has yet to complete.")
 		}
 	case key == "config/crl":
 		// We may need to reload our OCSP status flag
-		b.crlBuilder.markConfigDirty()
+		b.CrlBuilder().markConfigDirty()
 	case key == storageAcmeConfig:
-		b.acmeState.markConfigDirty()
+		b.GetAcmeState().markConfigDirty()
 	case key == storageIssuerConfig:
-		b.crlBuilder.invalidateCRLBuildTime()
+		b.CrlBuilder().invalidateCRLBuildTime()
 	case strings.HasPrefix(key, crossRevocationPrefix):
 		split := strings.Split(key, "/")
 		if !strings.HasSuffix(key, "/confirmed") {
 			cluster := split[len(split)-2]
 			serial := split[len(split)-1]
-			b.crlBuilder.addCertForRevocationCheck(cluster, serial)
+			b.CrlBuilder().addCertForRevocationCheck(cluster, serial)
 		} else {
 			if len(split) >= 3 {
 				cluster := split[len(split)-3]
@@ -560,7 +600,7 @@ func (b *backend) invalidate(ctx context.Context, key string) {
 				// ignore them). On performance primary nodes though,
 				// we do want to track them to remove them.
 				if !isNotPerfPrimary {
-					b.crlBuilder.addCertForRevocationRemoval(cluster, serial)
+					b.CrlBuilder().addCertForRevocationRemoval(cluster, serial)
 				}
 			}
 		}
@@ -569,7 +609,7 @@ func (b *backend) invalidate(ctx context.Context, key string) {
 		split := strings.Split(key, "/")
 		cluster := split[len(split)-2]
 		serial := split[len(split)-1]
-		b.crlBuilder.addCertFromCrossRevocation(cluster, serial)
+		b.CrlBuilder().addCertFromCrossRevocation(cluster, serial)
 	}

 	b.invalidateEnt(ctx, key)
@@ -580,7 +620,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
 	doCRL := func() error {
 		// First attempt to reload the CRL configuration.
-		if err := b.crlBuilder.reloadConfigIfRequired(sc); err != nil {
+		if err := b.CrlBuilder().reloadConfigIfRequired(sc); err != nil {
 			return err
 		}
@@ -592,22 +632,22 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
 		}

 		// First handle any global revocation queue entries.
-		if err := b.crlBuilder.processRevocationQueue(sc); err != nil {
+		if err := b.CrlBuilder().processRevocationQueue(sc); err != nil {
 			return err
 		}

 		// Then handle any unified cross-cluster revocations.
-		if err := b.crlBuilder.processCrossClusterRevocations(sc); err != nil {
+		if err := b.CrlBuilder().processCrossClusterRevocations(sc); err != nil {
 			return err
 		}

 		// Check if we're set to auto rebuild and a CRL is set to expire.
-		if err := b.crlBuilder.checkForAutoRebuild(sc); err != nil {
+		if err := b.CrlBuilder().checkForAutoRebuild(sc); err != nil {
 			return err
 		}

 		// Then attempt to rebuild the CRLs if required.
-		warnings, err := b.crlBuilder.rebuildIfForced(sc)
+		warnings, err := b.CrlBuilder().rebuildIfForced(sc)
 		if err != nil {
 			return err
 		}
@@ -622,7 +662,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
 		// If a delta CRL was rebuilt above as part of the complete CRL rebuild,
 		// this will be a no-op. However, if we do need to rebuild delta CRLs,
 		// this would cause us to do so.
-		warnings, err = b.crlBuilder.rebuildDeltaCRLsIfForced(sc, false)
+		warnings, err = b.CrlBuilder().rebuildDeltaCRLsIfForced(sc, false)
 		if err != nil {
 			return err
 		}
@@ -689,7 +729,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
 	}

 	// First tidy any ACME nonces to free memory.
-	b.acmeState.DoTidyNonces()
+	b.GetAcmeState().DoTidyNonces()

 	// Then run unified transfer.
 	backgroundSc := b.makeStorageContext(context.Background(), b.storage)
@@ -700,11 +740,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
 	tidyErr := doAutoTidy()

 	// Periodically re-emit gauges so that they don't disappear/go stale
-	tidyConfig, err := sc.getAutoTidyConfig()
-	if err != nil {
-		return err
-	}
-	b.emitCertStoreMetrics(tidyConfig)
+	b.GetCertificateCounter().EmitCertStoreMetrics()

 	var errors error
 	if crlErr != nil {
@@ -721,7 +757,7 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
 	// Check if the CRL was invalidated due to issuer swap and update
 	// accordingly.
-	if err := b.crlBuilder.flushCRLBuildTimeInvalidation(sc); err != nil {
+	if err := b.CrlBuilder().flushCRLBuildTimeInvalidation(sc); err != nil {
 		return err
 	}
@@ -742,211 +778,22 @@ func (b *backend) initializeStoredCertificateCounts(ctx context.Context) error {
 		return err
 	}

-	b.certCountEnabled.Store(config.MaintainCount)
-	b.publishCertCountMetrics.Store(config.PublishMetrics)
-
-	if config.MaintainCount == false {
-		b.possibleDoubleCountedRevokedSerials = nil
-		b.possibleDoubleCountedSerials = nil
-		b.certsCounted.Store(true)
-		b.certCount.Store(0)
-		b.revokedCertCount.Store(0)
-		b.certCountError = "Cert Count is Disabled: enable via Tidy Config maintain_stored_certificate_counts"
+	certCounter := b.GetCertificateCounter()
+	isEnabled := certCounter.ReconfigureWithTidyConfig(config)
+	if !isEnabled {
 		return nil
 	}

-	// Ideally these three things would be set in one transaction, since that isn't possible, set the counts to "0",
-	// first, so count will over-count (and miss putting things in deduplicate queue), rather than under-count.
-	b.certCount.Store(0)
-	b.revokedCertCount.Store(0)
-	b.possibleDoubleCountedRevokedSerials = nil
-	b.possibleDoubleCountedSerials = nil
-	// A cert issued or revoked here will be double-counted. That's okay, this is "best effort" metrics.
-	b.certsCounted.Store(false)
-
 	entries, err := b.storage.List(ctx, "certs/")
 	if err != nil {
 		return err
 	}
-	b.certCount.Add(uint32(len(entries)))

 	revokedEntries, err := b.storage.List(ctx, "revoked/")
 	if err != nil {
 		return err
 	}
-	b.revokedCertCount.Add(uint32(len(revokedEntries)))
-
-	b.certsCounted.Store(true)
-	// Now that the metrics are set, we can switch from appending newly-stored certificates to the possible double-count
-	// list, and instead have them update the counter directly. We need to do this so that we are looking at a static
-	// slice of possibly double counted serials. Note that certsCounted is computed before the storage operation, so
-	// there may be some delay here.
-
-	// Sort the listed-entries first, to accommodate that delay.
-	sort.Slice(entries, func(i, j int) bool {
-		return entries[i] < entries[j]
-	})
-
-	sort.Slice(revokedEntries, func(i, j int) bool {
-		return revokedEntries[i] < revokedEntries[j]
-	})
-
-	// We assume here that these lists are now complete.
-	sort.Slice(b.possibleDoubleCountedSerials, func(i, j int) bool {
-		return b.possibleDoubleCountedSerials[i] < b.possibleDoubleCountedSerials[j]
-	})
-
-	listEntriesIndex := 0
-	possibleDoubleCountIndex := 0
-	for {
-		if listEntriesIndex >= len(entries) {
-			break
-		}
-		if possibleDoubleCountIndex >= len(b.possibleDoubleCountedSerials) {
-			break
-		}
-		if entries[listEntriesIndex] == b.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
-			// This represents a double-counted entry
-			b.decrementTotalCertificatesCountNoReport()
-			listEntriesIndex = listEntriesIndex + 1
-			possibleDoubleCountIndex = possibleDoubleCountIndex + 1
-			continue
-		}
-		if entries[listEntriesIndex] < b.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
-			listEntriesIndex = listEntriesIndex + 1
-			continue
-		}
-		if entries[listEntriesIndex] > b.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
-			possibleDoubleCountIndex = possibleDoubleCountIndex + 1
-			continue
-		}
-	}
-
-	sort.Slice(b.possibleDoubleCountedRevokedSerials, func(i, j int) bool {
-		return b.possibleDoubleCountedRevokedSerials[i] < b.possibleDoubleCountedRevokedSerials[j]
-	})
-
-	listRevokedEntriesIndex := 0
-	possibleRevokedDoubleCountIndex := 0
-	for {
-		if listRevokedEntriesIndex >= len(revokedEntries) {
-			break
-		}
-		if possibleRevokedDoubleCountIndex >= len(b.possibleDoubleCountedRevokedSerials) {
-			break
-		}
-		if revokedEntries[listRevokedEntriesIndex] == b.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
-			// This represents a double-counted revoked entry
-			b.decrementTotalRevokedCertificatesCountNoReport()
-			listRevokedEntriesIndex = listRevokedEntriesIndex + 1
-			possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
-			continue
-		}
-		if revokedEntries[listRevokedEntriesIndex] < b.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
-			listRevokedEntriesIndex = listRevokedEntriesIndex + 1
-			continue
-		}
-		if revokedEntries[listRevokedEntriesIndex] > b.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
-			possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
-			continue
-		}
-	}
-
-	b.possibleDoubleCountedRevokedSerials = nil
-	b.possibleDoubleCountedSerials = nil
-
-	b.emitCertStoreMetrics(config)
-
-	b.certCountError = ""
+	certCounter.InitializeCountsFromStorage(entries, revokedEntries)

 	return nil
 }
-
-func (b *backend) emitCertStoreMetrics(config *tidyConfig) {
-	if config.PublishMetrics == true {
-		certCount := b.certCount.Load()
-		b.emitTotalCertCountMetric(certCount)
-		revokedCertCount := b.revokedCertCount.Load()
-		b.emitTotalRevokedCountMetric(revokedCertCount)
-	}
-}
-
-// The "certsCounted" boolean here should be loaded from the backend certsCounted before the corresponding storage call:
-// eg. certsCounted := b.certsCounted.Load()
-func (b *backend) ifCountEnabledIncrementTotalCertificatesCount(certsCounted bool, newSerial string) {
-	if b.certCountEnabled.Load() {
-		certCount := b.certCount.Add(1)
-		switch {
-		case !certsCounted:
-			// This is unsafe, but a good best-attempt
-			if strings.HasPrefix(newSerial, "certs/") {
-				newSerial = newSerial[6:]
-			}
-			b.possibleDoubleCountedSerials = append(b.possibleDoubleCountedSerials, newSerial)
-		default:
-			if b.publishCertCountMetrics.Load() {
-				b.emitTotalCertCountMetric(certCount)
-			}
-		}
-	}
-}
-
-func (b *backend) ifCountEnabledDecrementTotalCertificatesCountReport() {
-	if b.certCountEnabled.Load() {
-		certCount := b.decrementTotalCertificatesCountNoReport()
-		if b.publishCertCountMetrics.Load() {
-			b.emitTotalCertCountMetric(certCount)
-		}
-	}
-}
-
-func (b *backend) emitTotalCertCountMetric(certCount uint32) {
-	metrics.SetGauge([]string{"secrets", "pki", b.backendUUID, "total_certificates_stored"}, float32(certCount))
-}
-
-// Called directly only by the initialize function to deduplicate the count, when we don't have a full count yet
-// Does not respect whether-we-are-counting backend information.
-func (b *backend) decrementTotalCertificatesCountNoReport() uint32 {
-	newCount := b.certCount.Add(^uint32(0))
-	return newCount
-}
-
-// The "certsCounted" boolean here should be loaded from the backend certsCounted before the corresponding storage call:
-// eg. certsCounted := b.certsCounted.Load()
-func (b *backend) ifCountEnabledIncrementTotalRevokedCertificatesCount(certsCounted bool, newSerial string) {
-	if b.certCountEnabled.Load() {
-		newRevokedCertCount := b.revokedCertCount.Add(1)
-		switch {
-		case !certsCounted:
-			// This is unsafe, but a good best-attempt
-			if strings.HasPrefix(newSerial, "revoked/") { // allow passing in the path (revoked/serial) OR the serial
-				newSerial = newSerial[8:]
-			}
-			b.possibleDoubleCountedRevokedSerials = append(b.possibleDoubleCountedRevokedSerials, newSerial)
-		default:
-			if b.publishCertCountMetrics.Load() {
-				b.emitTotalRevokedCountMetric(newRevokedCertCount)
-			}
-		}
-	}
-}
-
-func (b *backend) ifCountEnabledDecrementTotalRevokedCertificatesCountReport() {
-	if b.certCountEnabled.Load() {
-		revokedCertCount := b.decrementTotalRevokedCertificatesCountNoReport()
-		if b.publishCertCountMetrics.Load() {
-			b.emitTotalRevokedCountMetric(revokedCertCount)
-		}
-	}
-}
-
-func (b *backend) emitTotalRevokedCountMetric(revokedCertCount uint32) {
-	metrics.SetGauge([]string{"secrets", "pki", b.backendUUID, "total_revoked_certificates_stored"}, float32(revokedCertCount))
-}
-
-// Called directly only by the initialize function to deduplicate the count, when we don't have a full count yet
-// Does not respect whether-we-are-counting backend information.
-func (b *backend) decrementTotalRevokedCertificatesCountNoReport() uint32 {
-	newRevokedCertCount := b.revokedCertCount.Add(^uint32(0))
-	return newRevokedCertCount
-}

@@ -5,6 +5,7 @@ package pki
 import (
 	"bytes"
+	"cmp"
 	"context"
 	"crypto"
 	"crypto/ecdsa"
@@ -26,6 +27,7 @@ import (
 	"net/url"
 	"os"
 	"reflect"
+	"slices"
 	"sort"
 	"strconv"
 	"strings"
@@ -33,6 +35,7 @@ import (
 	"testing"
 	"time"

+	"github.com/hashicorp/vault/builtin/logical/pki/parsing"
 	"github.com/hashicorp/vault/helper/testhelpers/teststorage"
 	"golang.org/x/exp/maps"
@@ -56,6 +59,8 @@ import (
 	"github.com/hashicorp/vault/vault"
 	"github.com/mitchellh/mapstructure"
 	"golang.org/x/net/idna"
+
+	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
 )

 var stepCount = 0
@@ -856,7 +861,7 @@ func generateTestCsr(t *testing.T, keyType certutil.PrivateKeyType, keyBits int)
 // Generates steps to test out various role permutations
 func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
-	roleVals := roleEntry{
+	roleVals := issuing.RoleEntry{
 		MaxTTL: 12 * time.Hour,
 		KeyType: "rsa",
 		KeyBits: 2048,
@@ -938,7 +943,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		ret = append(ret, issueTestStep)
 	}

-	getCountryCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getCountryCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -959,7 +964,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getOuCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getOuCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -980,7 +985,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getOrganizationCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1001,7 +1006,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getLocalityCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getLocalityCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1022,7 +1027,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getProvinceCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getProvinceCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1043,7 +1048,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getStreetAddressCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getStreetAddressCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1064,7 +1069,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getPostalCodeCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getPostalCodeCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1085,7 +1090,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 	}

-	getNotBeforeCheck := func(role roleEntry) logicaltest.TestCheckFunc {
+	getNotBeforeCheck := func(role issuing.RoleEntry) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1110,7 +1115,9 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 	// Returns a TestCheckFunc that performs various validity checks on the
 	// returned certificate information, mostly within checkCertsAndPrivateKey
-	getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
+	getCnCheck := func(name string, role issuing.RoleEntry, key crypto.Signer, usage x509.KeyUsage,
+		extUsage x509.ExtKeyUsage, validity time.Duration,
+	) logicaltest.TestCheckFunc {
 		var certBundle certutil.CertBundle
 		return func(resp *logical.Response) error {
 			err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1333,7 +1340,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		}
 		roleVals.KeyUsage = usage
-		parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage)
+		parsedKeyUsage := parsing.ParseKeyUsages(roleVals.KeyUsage)
 		if parsedKeyUsage == 0 && len(usage) != 0 {
 			panic("parsed key usages was zero")
 		}
@@ -1592,7 +1599,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 	}

 	{
-		getOtherCheck := func(expectedOthers ...otherNameUtf8) logicaltest.TestCheckFunc {
+		getOtherCheck := func(expectedOthers ...issuing.OtherNameUtf8) logicaltest.TestCheckFunc {
 			return func(resp *logical.Response) error {
 				var certBundle certutil.CertBundle
 				err := mapstructure.Decode(resp.Data, &certBundle)
@@ -1608,7 +1615,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 				if err != nil {
 					return err
 				}
-				var expected []otherNameUtf8
+				var expected []issuing.OtherNameUtf8
 				expected = append(expected, expectedOthers...)
 				if diff := deep.Equal(foundOthers, expected); len(diff) > 0 {
 					return fmt.Errorf("wrong SAN IPs, diff: %v", diff)
@@ -1617,11 +1624,11 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 			}
 		}

-		addOtherSANTests := func(useCSRs, useCSRSANs bool, allowedOtherSANs []string, errorOk bool, otherSANs []string, csrOtherSANs []otherNameUtf8, check logicaltest.TestCheckFunc) {
-			otherSansMap := func(os []otherNameUtf8) map[string][]string {
+		addOtherSANTests := func(useCSRs, useCSRSANs bool, allowedOtherSANs []string, errorOk bool, otherSANs []string, csrOtherSANs []issuing.OtherNameUtf8, check logicaltest.TestCheckFunc) {
+			otherSansMap := func(os []issuing.OtherNameUtf8) map[string][]string {
 				ret := make(map[string][]string)
 				for _, o := range os {
-					ret[o.oid] = append(ret[o.oid], o.value)
+					ret[o.Oid] = append(ret[o.Oid], o.Value)
 				}
 				return ret
 			}
@@ -1652,14 +1659,14 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		roleVals.UseCSRCommonName = true
 		commonNames.Localhost = true

-		newOtherNameUtf8 := func(s string) (ret otherNameUtf8) {
+		newOtherNameUtf8 := func(s string) (ret issuing.OtherNameUtf8) {
 			pieces := strings.Split(s, ";")
 			if len(pieces) == 2 {
 				piecesRest := strings.Split(pieces[1], ":")
 				if len(piecesRest) == 2 {
 					switch strings.ToUpper(piecesRest[0]) {
 					case "UTF-8", "UTF8":
-						return otherNameUtf8{oid: pieces[0], value: piecesRest[1]}
+						return issuing.OtherNameUtf8{Oid: pieces[0], Value: piecesRest[1]}
 					}
 				}
 			}
@@ -1669,7 +1676,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		oid1 := "1.3.6.1.4.1.311.20.2.3"
 		oth1str := oid1 + ";utf8:devops@nope.com"
 		oth1 := newOtherNameUtf8(oth1str)
-		oth2 := otherNameUtf8{oid1, "me@example.com"}
+		oth2 := issuing.OtherNameUtf8{oid1, "me@example.com"}
 		// allowNone, allowAll := []string{}, []string{oid1 + ";UTF-8:*"}
 		allowNone, allowAll := []string{}, []string{"*"}
@@ -1684,15 +1691,15 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
 		// Given OtherSANs as API argument and useCSRSANs false, CSR arg ignored.
 		addOtherSANTests(useCSRs, false, allowAll, false, []string{oth1str},
-			[]otherNameUtf8{oth2}, getOtherCheck(oth1))
+			[]issuing.OtherNameUtf8{oth2}, getOtherCheck(oth1))

 		if useCSRs {
 			// OtherSANs not allowed, valid OtherSANs provided via CSR, should be an error.
-			addOtherSANTests(useCSRs, true, allowNone, true, nil, []otherNameUtf8{oth1}, nil)
+			addOtherSANTests(useCSRs, true, allowNone, true, nil, []issuing.OtherNameUtf8{oth1}, nil)

 			// Given OtherSANs as both API and CSR arguments and useCSRSANs=true, API arg ignored.
 			addOtherSANTests(useCSRs, false, allowAll, false, []string{oth2.String()},
-				[]otherNameUtf8{oth1}, getOtherCheck(oth2))
+				[]issuing.OtherNameUtf8{oth1}, getOtherCheck(oth2))
 		}
 	}
@@ -2405,7 +2412,7 @@ func TestBackend_Root_Idempotency(t *testing.T) {
 	certSkid := certutil.GetHexFormatted(cert.SubjectKeyId, ":")

 	// -> Validate the SKID matches between the root cert and the key
-	resp, err = CBRead(b, s, "key/"+keyId1.(keyID).String())
+	resp, err = CBRead(b, s, "key/"+keyId1.(issuing.KeyID).String())
 	require.NoError(t, err)
 	require.NotNil(t, resp, "expected a response")
 	require.Equal(t, resp.Data["subject_key_id"], certSkid)
@@ -2427,7 +2434,7 @@ func TestBackend_Root_Idempotency(t *testing.T) {
 	certSkid = certutil.GetHexFormatted(cert.SubjectKeyId, ":")

 	// -> Validate the SKID matches between the root cert and the key
-	resp, err = CBRead(b, s, "key/"+keyId2.(keyID).String())
+	resp, err = CBRead(b, s, "key/"+keyId2.(issuing.KeyID).String())
 	require.NoError(t, err)
 	require.NotNil(t, resp, "expected a response")
 	require.Equal(t, resp.Data["subject_key_id"], certSkid)
@@ -2562,7 +2569,7 @@ func TestBackend_SignIntermediate_AllowedPastCAValidity(t *testing.T) {
 	})
 	schema.ValidateResponse(t, schema.GetResponseSchema(t, b_root.Route("intermediate/generate/internal"), logical.UpdateOperation), resp, true)
 	require.Contains(t, resp.Data, "key_id")
-	intKeyId := resp.Data["key_id"].(keyID)
+	intKeyId := resp.Data["key_id"].(issuing.KeyID)
 	csr := resp.Data["csr"]

 	resp, err = CBRead(b_int, s_int, "key/"+intKeyId.String())
@@ -2764,7 +2771,7 @@ func TestBackend_SignSelfIssued(t *testing.T) {
 	}

 	sc := b.makeStorageContext(context.Background(), storage)
-	signingBundle, err := sc.fetchCAInfo(defaultRef, ReadOnlyUsage)
+	signingBundle, err := sc.fetchCAInfo(defaultRef, issuing.ReadOnlyUsage)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -3123,11 +3130,14 @@ func TestBackend_OID_SANs(t *testing.T) {
 		cert.DNSNames[2] != "foobar.com" {
 		t.Fatalf("unexpected DNS SANs %v", cert.DNSNames)
 	}
-	expectedOtherNames := []otherNameUtf8{{oid1, val1}, {oid2, val2}}
+	expectedOtherNames := []issuing.OtherNameUtf8{{oid1, val1}, {oid2, val2}}
 	foundOtherNames, err := getOtherSANsFromX509Extensions(cert.Extensions)
 	if err != nil {
 		t.Fatal(err)
 	}
+
+	// Sort our returned list as SANS are built internally with a map so ordering can be inconsistent
+	slices.SortFunc(foundOtherNames, func(a, b issuing.OtherNameUtf8) int { return cmp.Compare(a.Oid, b.Oid) })
 	if diff := deep.Equal(expectedOtherNames, foundOtherNames); len(diff) != 0 {
 		t.Errorf("unexpected otherNames: %v", diff)
 	}
@@ -3874,9 +3884,11 @@ func TestBackend_RevokePlusTidy_Intermediate(t *testing.T) {
 		"maintain_stored_certificate_counts": true,
 		"publish_stored_certificate_count_metrics": true,
 	})
+	require.NoError(t, err, "failed calling auto-tidy")

 	_, err = client.Logical().Write("/sys/plugins/reload/backend", map[string]interface{}{
 		"mounts": "pki/",
 	})
+	require.NoError(t, err, "failed calling backend reload")

 	// Check the metrics initialized in order to calculate backendUUID for /pki
 	// BackendUUID not consistent during tests with UUID from /sys/mounts/pki
@@ -4934,9 +4946,9 @@ func TestRootWithExistingKey(t *testing.T) {
 	resp, err = CBList(b, s, "issuers")
 	require.NoError(t, err)
 	require.Equal(t, 3, len(resp.Data["keys"].([]string)))
-	require.Contains(t, resp.Data["keys"], string(myIssuerId1.(issuerID)))
-	require.Contains(t, resp.Data["keys"], string(myIssuerId2.(issuerID)))
-	require.Contains(t, resp.Data["keys"], string(myIssuerId3.(issuerID)))
+	require.Contains(t, resp.Data["keys"], string(myIssuerId1.(issuing.IssuerID)))
+	require.Contains(t, resp.Data["keys"], string(myIssuerId2.(issuing.IssuerID)))
+	require.Contains(t, resp.Data["keys"], string(myIssuerId3.(issuing.IssuerID)))
 }

 func TestIntermediateWithExistingKey(t *testing.T) {
@@ -5718,17 +5730,18 @@ func TestBackend_InitializeCertificateCounts(t *testing.T) {
 		}
 	}

-	if b.certCount.Load() != 6 {
-		t.Fatalf("Failed to count six certificates root,A,B,C,D,E, instead counted %d certs", b.certCount.Load())
+	certCounter := b.GetCertificateCounter()
+	if certCounter.CertificateCount() != 6 {
+		t.Fatalf("Failed to count six certificates root,A,B,C,D,E, instead counted %d certs", certCounter.CertificateCount())
 	}
-	if b.revokedCertCount.Load() != 2 {
-		t.Fatalf("Failed to count two revoked certificates A+B, instead counted %d certs", b.revokedCertCount.Load())
+	if certCounter.RevokedCount() != 2 {
+		t.Fatalf("Failed to count two revoked certificates A+B, instead counted %d certs", certCounter.RevokedCount())
 	}

 	// Simulates listing while initialize in progress, by "restarting it"
-	b.certCount.Store(0)
-	b.revokedCertCount.Store(0)
-	b.certsCounted.Store(false)
+	certCounter.certCount.Store(0)
+	certCounter.revokedCertCount.Store(0)
+	certCounter.certsCounted.Store(false)

 	// Revoke certificates C, D
 	dirtyRevocations := serials[2:4]
@@ -5753,15 +5766,16 @@ func TestBackend_InitializeCertificateCounts(t *testing.T) {
 	}

 	// Run initialize
-	b.initializeStoredCertificateCounts(ctx)
+	err = b.initializeStoredCertificateCounts(ctx)
+	require.NoError(t, err, "failed initializing certificate counts")

 	// Test certificate count
-	if b.certCount.Load() != 8 {
-		t.Fatalf("Failed to initialize count of certificates root, A,B,C,D,E,F,G counted %d certs", b.certCount.Load())
+	if certCounter.CertificateCount() != 8 {
+		t.Fatalf("Failed to initialize count of certificates root, A,B,C,D,E,F,G counted %d certs", certCounter.CertificateCount())
 	}
-	if b.revokedCertCount.Load() != 4 {
-		t.Fatalf("Failed to count revoked certificates A,B,C,D counted %d certs", b.revokedCertCount.Load())
+	if certCounter.RevokedCount() != 4 {
+		t.Fatalf("Failed to count revoked certificates A,B,C,D counted %d certs", certCounter.RevokedCount())
 	}

 	return
@@ -6147,7 +6161,7 @@ func TestPKI_TemplatedAIAs(t *testing.T) {
 	require.NoError(t, err)
 	resp, err = CBWrite(b, s, "root/generate/internal", rootData)
 	requireSuccessNonNilResponse(t, resp, err)
-	issuerId := string(resp.Data["issuer_id"].(issuerID))
+	issuerId := string(resp.Data["issuer_id"].(issuing.IssuerID))

 	// Now write the original AIA config and sign a leaf.
 	_, err = CBWrite(b, s, "config/urls", aiaData)
@@ -7063,7 +7077,7 @@ func TestPatchIssuer(t *testing.T) {
 		"issuer_name": "root",
 	})
 	requireSuccessNonNilResponse(t, resp, err, "failed generating root issuer")
-	id := string(resp.Data["issuer_id"].(issuerID))
+	id := string(resp.Data["issuer_id"].(issuing.IssuerID))

 	// 2. Enable Cluster paths
 	resp, err = CBWrite(b, s, "config/urls", map[string]interface{}{

@@ -18,9 +18,12 @@ import (
 	"github.com/hashicorp/vault/sdk/framework"
 	"github.com/hashicorp/vault/sdk/helper/certutil"
 	"github.com/hashicorp/vault/sdk/logical"
+	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
+	"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
 )

-func getGenerationParams(sc *storageContext, data *framework.FieldData) (exported bool, format string, role *roleEntry, errorResp *logical.Response) {
+func getGenerationParams(sc *storageContext, data *framework.FieldData) (exported bool, format string, role *issuing.RoleEntry, errorResp *logical.Response) {
 	exportedStr := data.Get("exported").(string)
 	switch exportedStr {
 	case "exported":
@@ -47,7 +50,7 @@ func getGenerationParams(sc *storageContext, data *framework.FieldData) (exporte
 		return
 	}

-	role = &roleEntry{
+	role = &issuing.RoleEntry{
 		TTL: time.Duration(data.Get("ttl").(int)) * time.Second,
 		KeyType: keyType,
 		KeyBits: keyBits,
@@ -90,7 +93,7 @@ func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.Cre
 		if err != nil {
 			return nil, err
 		}

-		return generateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
+		return managed_key.GenerateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
 	}

 	if existingKeyRequested(input) {
@@ -104,12 +107,12 @@ func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.Cre
 			return nil, err
 		}

-		if keyEntry.isManagedPrivateKey() {
-			keyId, err := keyEntry.getManagedKeyUUID()
+		if keyEntry.IsManagedPrivateKey() {
+			keyId, err := issuing.GetManagedKeyUUID(keyEntry)
 			if err != nil {
 				return nil, err
 			}

-			return generateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
+			return managed_key.GenerateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
 		}

 		return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingKeyGeneratorFromBytes(keyEntry))
@@ -128,7 +131,7 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
 			return nil, err
 		}

-		return generateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
+		return managed_key.GenerateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
 	}

 	if existingKeyRequested(input) {
@@ -142,12 +145,12 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
 			return nil, err
 		}

-		if key.isManagedPrivateKey() {
-			keyId, err := key.getManagedKeyUUID()
+		if key.IsManagedPrivateKey() {
+			keyId, err := issuing.GetManagedKeyUUID(key)
 			if err != nil {
 				return nil, err
 			}

-			return generateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
+			return managed_key.GenerateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
 		}

 		return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingKeyGeneratorFromBytes(key))
@@ -157,10 +160,7 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
 }

 func parseCABundle(ctx context.Context, b *backend, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
-	if bundle.PrivateKeyType == certutil.ManagedPrivateKey {
-		return parseManagedKeyCABundle(ctx, b, bundle)
-	}
-	return bundle.ToParsedCertBundle()
+	return issuing.ParseCABundle(ctx, b, bundle)
 }

 func (sc *storageContext) getKeyTypeAndBitsForRole(data *framework.FieldData) (string, int, error) {
@@ -192,7 +192,7 @@ func (sc *storageContext) getKeyTypeAndBitsForRole(data *framework.FieldData) (s
 		return "", 0, errors.New("unable to determine managed key id: " + err.Error())
 	}

-	pubKeyManagedKey, err := getManagedKeyPublicKey(sc.Context, sc.Backend, keyId)
+	pubKeyManagedKey, err := managed_key.GetManagedKeyPublicKey(sc.Context, sc.Backend, keyId)
 	if err != nil {
 		return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error())
 	}
@@ -245,7 +245,7 @@ func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.Pr
 	return keyType, keyBits, nil
 }

-func (sc *storageContext) getExistingKeyFromRef(keyRef string) (*keyEntry, error) {
+func (sc *storageContext) getExistingKeyFromRef(keyRef string) (*issuing.KeyEntry, error) {
 	keyId, err := sc.resolveKeyReference(keyRef)
 	if err != nil {
 		return nil, err
@@ -253,7 +253,7 @@ func (sc *storageContext) getExistingKeyFromRef(keyRef string) (*keyEntry, error
return sc.fetchKeyById(keyId) return sc.fetchKeyById(keyId)
} }
func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator { func existingKeyGeneratorFromBytes(key *issuing.KeyEntry) certutil.KeyGenerator {
return func(_ string, _ int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error { return func(_ string, _ int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error {
signer, _, pemBytes, err := getSignerFromKeyEntryBytes(key) signer, _, pemBytes, err := getSignerFromKeyEntryBytes(key)
if err != nil { if err != nil {
@@ -264,61 +264,3 @@ func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator {
return nil return nil
} }
} }
func buildSignVerbatimRoleWithNoData(role *roleEntry) *roleEntry {
data := &framework.FieldData{
Raw: map[string]interface{}{},
Schema: addSignVerbatimRoleFields(map[string]*framework.FieldSchema{}),
}
return buildSignVerbatimRole(data, role)
}
func buildSignVerbatimRole(data *framework.FieldData, role *roleEntry) *roleEntry {
entry := &roleEntry{
AllowLocalhost: true,
AllowAnyName: true,
AllowIPSANs: true,
AllowWildcardCertificates: new(bool),
EnforceHostnames: false,
KeyType: "any",
UseCSRCommonName: true,
UseCSRSANs: true,
AllowedOtherSANs: []string{"*"},
AllowedSerialNumbers: []string{"*"},
AllowedURISANs: []string{"*"},
AllowedUserIDs: []string{"*"},
CNValidations: []string{"disabled"},
GenerateLease: new(bool),
// If adding new fields to be read, update the field list within addSignVerbatimRoleFields
KeyUsage: data.Get("key_usage").([]string),
ExtKeyUsage: data.Get("ext_key_usage").([]string),
ExtKeyUsageOIDs: data.Get("ext_key_usage_oids").([]string),
SignatureBits: data.Get("signature_bits").(int),
UsePSS: data.Get("use_pss").(bool),
}
*entry.AllowWildcardCertificates = true
*entry.GenerateLease = false
if role != nil {
if role.TTL > 0 {
entry.TTL = role.TTL
}
if role.MaxTTL > 0 {
entry.MaxTTL = role.MaxTTL
}
if role.GenerateLease != nil {
*entry.GenerateLease = *role.GenerateLease
}
if role.NotBeforeDuration > 0 {
entry.NotBeforeDuration = role.NotBeforeDuration
}
entry.NoStore = role.NoStore
entry.Issuer = role.Issuer
}
if len(entry.Issuer) == 0 {
entry.Issuer = defaultRef
}
return entry
}

 File diff suppressed because it is too large

View File

@@ -10,6 +10,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -97,7 +98,7 @@ func TestPki_FetchCertBySerial(t *testing.T) {
// order-preserving way. // order-preserving way.
func TestPki_MultipleOUs(t *testing.T) { func TestPki_MultipleOUs(t *testing.T) {
t.Parallel() t.Parallel()
var b backend b, _ := CreateBackendWithStorage(t)
fields := addCACommonFields(map[string]*framework.FieldSchema{}) fields := addCACommonFields(map[string]*framework.FieldSchema{})
apiData := &framework.FieldData{ apiData := &framework.FieldData{
@@ -109,12 +110,12 @@ func TestPki_MultipleOUs(t *testing.T) {
} }
input := &inputBundle{ input := &inputBundle{
apiData: apiData, apiData: apiData,
role: &roleEntry{ role: &issuing.RoleEntry{
MaxTTL: 3600, MaxTTL: 3600,
OU: []string{"Z", "E", "V"}, OU: []string{"Z", "E", "V"},
}, },
} }
cb, _, err := generateCreationBundle(&b, input, nil, nil) cb, _, err := generateCreationBundle(b, input, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Error: %v", err) t.Fatalf("Error: %v", err)
} }
@@ -129,7 +130,7 @@ func TestPki_MultipleOUs(t *testing.T) {
func TestPki_PermitFQDNs(t *testing.T) { func TestPki_PermitFQDNs(t *testing.T) {
t.Parallel() t.Parallel()
var b backend b, _ := CreateBackendWithStorage(t)
fields := addCACommonFields(map[string]*framework.FieldSchema{}) fields := addCACommonFields(map[string]*framework.FieldSchema{})
cases := map[string]struct { cases := map[string]struct {
@@ -146,7 +147,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600, "ttl": 3600,
}, },
}, },
role: &roleEntry{ role: &issuing.RoleEntry{
AllowAnyName: true, AllowAnyName: true,
MaxTTL: 3600, MaxTTL: 3600,
EnforceHostnames: true, EnforceHostnames: true,
@@ -165,7 +166,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600, "ttl": 3600,
}, },
}, },
role: &roleEntry{ role: &issuing.RoleEntry{
AllowedDomains: []string{"example.net", "EXAMPLE.COM"}, AllowedDomains: []string{"example.net", "EXAMPLE.COM"},
AllowBareDomains: true, AllowBareDomains: true,
MaxTTL: 3600, MaxTTL: 3600,
@@ -183,7 +184,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600, "ttl": 3600,
}, },
}, },
role: &roleEntry{ role: &issuing.RoleEntry{
AllowedDomains: []string{"example.com", "*.Example.com"}, AllowedDomains: []string{"example.com", "*.Example.com"},
AllowGlobDomains: true, AllowGlobDomains: true,
MaxTTL: 3600, MaxTTL: 3600,
@@ -201,7 +202,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600, "ttl": 3600,
}, },
}, },
role: &roleEntry{ role: &issuing.RoleEntry{
AllowedDomains: []string{"test@testemail.com"}, AllowedDomains: []string{"test@testemail.com"},
AllowBareDomains: true, AllowBareDomains: true,
MaxTTL: 3600, MaxTTL: 3600,
@@ -219,7 +220,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
"ttl": 3600, "ttl": 3600,
}, },
}, },
role: &roleEntry{ role: &issuing.RoleEntry{
AllowedDomains: []string{"testemail.com"}, AllowedDomains: []string{"testemail.com"},
AllowBareDomains: true, AllowBareDomains: true,
MaxTTL: 3600, MaxTTL: 3600,
@@ -234,7 +235,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
name := name name := name
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
cb, _, err := generateCreationBundle(&b, testCase.input, nil, nil) cb, _, err := generateCreationBundle(b, testCase.input, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Error: %v", err) t.Fatalf("Error: %v", err)
} }

View File

@@ -16,6 +16,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -575,7 +576,7 @@ func (c CBIssueLeaf) RevokeLeaf(t testing.TB, b *backend, s logical.Storage, kno
if resp == nil { if resp == nil {
t.Fatalf("failed to read default issuer config: nil response") t.Fatalf("failed to read default issuer config: nil response")
} }
defaultID := resp.Data["default"].(issuerID).String() defaultID := resp.Data["default"].(issuing.IssuerID).String()
c.Issuer = defaultID c.Issuer = defaultID
issuer = nil issuer = nil
} }
@@ -637,7 +638,7 @@ func (c CBIssueLeaf) Run(t testing.TB, b *backend, s logical.Storage, knownKeys
if resp == nil { if resp == nil {
t.Fatalf("failed to read default issuer config: nil response") t.Fatalf("failed to read default issuer config: nil response")
} }
defaultID := resp.Data["default"].(issuerID).String() defaultID := resp.Data["default"].(issuing.IssuerID).String()
resp, err = CBRead(b, s, "issuer/"+c.Issuer) resp, err = CBRead(b, s, "issuer/"+c.Issuer)
if err != nil { if err != nil {
@@ -646,7 +647,7 @@ func (c CBIssueLeaf) Run(t testing.TB, b *backend, s logical.Storage, knownKeys
if resp == nil { if resp == nil {
t.Fatalf("failed to read issuer %v: nil response", c.Issuer) t.Fatalf("failed to read issuer %v: nil response", c.Issuer)
} }
ourID := resp.Data["issuer_id"].(issuerID).String() ourID := resp.Data["issuer_id"].(issuing.IssuerID).String()
areDefault := ourID == defaultID areDefault := ourID == defaultID
for _, usage := range []string{"read-only", "crl-signing", "issuing-certificates", "issuing-certificates,crl-signing"} { for _, usage := range []string{"read-only", "crl-signing", "issuing-certificates", "issuing-certificates,crl-signing"} {

View File

@@ -9,10 +9,11 @@ import (
"fmt" "fmt"
"sort" "sort"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
) )
func prettyIssuer(issuerIdEntryMap map[issuerID]*issuerEntry, issuer issuerID) string { func prettyIssuer(issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry, issuer issuing.IssuerID) string {
if entry, ok := issuerIdEntryMap[issuer]; ok && len(entry.Name) > 0 { if entry, ok := issuerIdEntryMap[issuer]; ok && len(entry.Name) > 0 {
return "[id:" + string(issuer) + "/name:" + entry.Name + "]" return "[id:" + string(issuer) + "/name:" + entry.Name + "]"
} }
@@ -20,7 +21,7 @@ func prettyIssuer(issuerIdEntryMap map[issuerID]*issuerEntry, issuer issuerID) s
return "[" + string(issuer) + "]" return "[" + string(issuer) + "]"
} }
func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* optional */) error { func (sc *storageContext) rebuildIssuersChains(referenceCert *issuing.IssuerEntry /* optional */) error {
// This function rebuilds the CAChain field of all known issuers. This // This function rebuilds the CAChain field of all known issuers. This
// function should usually be invoked when a new issuer is added to the // function should usually be invoked when a new issuer is added to the
// pool of issuers. // pool of issuers.
@@ -116,22 +117,22 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// fourth maps that certificate back to the other issuers with that // fourth maps that certificate back to the other issuers with that
// subject (note the keyword _other_: we'll exclude self-loops here) -- // subject (note the keyword _other_: we'll exclude self-loops here) --
// either via a parent or child relationship. // either via a parent or child relationship.
issuerIdEntryMap := make(map[issuerID]*issuerEntry, len(issuers)) issuerIdEntryMap := make(map[issuing.IssuerID]*issuing.IssuerEntry, len(issuers))
issuerIdCertMap := make(map[issuerID]*x509.Certificate, len(issuers)) issuerIdCertMap := make(map[issuing.IssuerID]*x509.Certificate, len(issuers))
issuerIdParentsMap := make(map[issuerID][]issuerID, len(issuers)) issuerIdParentsMap := make(map[issuing.IssuerID][]issuing.IssuerID, len(issuers))
issuerIdChildrenMap := make(map[issuerID][]issuerID, len(issuers)) issuerIdChildrenMap := make(map[issuing.IssuerID][]issuing.IssuerID, len(issuers))
// For every known issuer, we map that subject back to the id of issuers // For every known issuer, we map that subject back to the id of issuers
// containing that subject. This lets us build our issuerID -> parents // containing that subject. This lets us build our IssuerID -> parents
// mapping efficiently. Worst case we'll have a single linear chain where // mapping efficiently. Worst case we'll have a single linear chain where
// every entry has a distinct subject. // every entry has a distinct subject.
subjectIssuerIdsMap := make(map[string][]issuerID, len(issuers)) subjectIssuerIdsMap := make(map[string][]issuing.IssuerID, len(issuers))
// First, read every issuer entry from storage. We'll propagate entries // First, read every issuer entry from storage. We'll propagate entries
// to three of the maps here: all but issuerIdParentsMap and // to three of the maps here: all but issuerIdParentsMap and
// issuerIdChildrenMap, which we'll do in a second pass. // issuerIdChildrenMap, which we'll do in a second pass.
for _, identifier := range issuers { for _, identifier := range issuers {
var stored *issuerEntry var stored *issuing.IssuerEntry
// When the reference issuer is provided and matches this identifier, // When the reference issuer is provided and matches this identifier,
// prefer the updated reference copy instead. // prefer the updated reference copy instead.
@@ -261,8 +262,8 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// manually building their chain prior to starting the topographical sort. // manually building their chain prior to starting the topographical sort.
// //
// This thus runs in O(|V| + |E|) -> O(n^2) in the number of issuers. // This thus runs in O(|V| + |E|) -> O(n^2) in the number of issuers.
processedIssuers := make(map[issuerID]bool, len(issuers)) processedIssuers := make(map[issuing.IssuerID]bool, len(issuers))
toVisit := make([]issuerID, 0, len(issuers)) toVisit := make([]issuing.IssuerID, 0, len(issuers))
// Handle any explicitly constructed certificate chains. Here, we don't // Handle any explicitly constructed certificate chains. Here, we don't
// validate much what the user provides; if they provide since-deleted // validate much what the user provides; if they provide since-deleted
@@ -323,7 +324,7 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// ensure we don't accidentally infinite-loop (if we introduce a bug). // ensure we don't accidentally infinite-loop (if we introduce a bug).
maxVisitCount := len(issuers)*len(issuers)*len(issuers) + 100 maxVisitCount := len(issuers)*len(issuers)*len(issuers) + 100
for len(toVisit) > 0 && maxVisitCount >= 0 { for len(toVisit) > 0 && maxVisitCount >= 0 {
var issuer issuerID var issuer issuing.IssuerID
issuer, toVisit = toVisit[0], toVisit[1:] issuer, toVisit = toVisit[0], toVisit[1:]
// If (and only if) we're presently starved for next nodes to visit, // If (and only if) we're presently starved for next nodes to visit,
@@ -387,8 +388,8 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
// However, if you directly step onto the cross-signed, now you're // However, if you directly step onto the cross-signed, now you're
// taken in an alternative direction (via its chain), and must // taken in an alternative direction (via its chain), and must
// revisit any roots later. // revisit any roots later.
var roots []issuerID var roots []issuing.IssuerID
var intermediates []issuerID var intermediates []issuing.IssuerID
for _, parentCertId := range parentCerts { for _, parentCertId := range parentCerts {
if bytes.Equal(issuerIdCertMap[parentCertId].RawSubject, issuerIdCertMap[parentCertId].RawIssuer) { if bytes.Equal(issuerIdCertMap[parentCertId].RawSubject, issuerIdCertMap[parentCertId].RawIssuer) {
roots = append(roots, parentCertId) roots = append(roots, parentCertId)
@@ -470,7 +471,7 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
return nil return nil
} }
func addToChainIfNotExisting(includedParentCerts map[string]bool, entry *issuerEntry, certToAdd string) { func addToChainIfNotExisting(includedParentCerts map[string]bool, entry *issuing.IssuerEntry, certToAdd string) {
included, ok := includedParentCerts[certToAdd] included, ok := includedParentCerts[certToAdd]
if ok && included { if ok && included {
return return
@@ -481,15 +482,15 @@ func addToChainIfNotExisting(includedParentCerts map[string]bool, entry *issuerE
} }
func processAnyCliqueOrCycle( func processAnyCliqueOrCycle(
issuers []issuerID, issuers []issuing.IssuerID,
processedIssuers map[issuerID]bool, processedIssuers map[issuing.IssuerID]bool,
toVisit []issuerID, toVisit []issuing.IssuerID,
issuerIdEntryMap map[issuerID]*issuerEntry, issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
issuerIdParentsMap map[issuerID][]issuerID, issuerIdParentsMap map[issuing.IssuerID][]issuing.IssuerID,
issuerIdChildrenMap map[issuerID][]issuerID, issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
subjectIssuerIdsMap map[string][]issuerID, subjectIssuerIdsMap map[string][]issuing.IssuerID,
) ([]issuerID /* toVisit */, error) { ) ([]issuing.IssuerID /* toVisit */, error) {
// Topological sort really only works on directed acyclic graphs (DAGs). // Topological sort really only works on directed acyclic graphs (DAGs).
// But a pool of arbitrary (issuer) certificates are actually neither! // But a pool of arbitrary (issuer) certificates are actually neither!
// This pool could contain both cliques and cycles. Because this could // This pool could contain both cliques and cycles. Because this could
@@ -550,15 +551,15 @@ func processAnyCliqueOrCycle(
// Finally -- it isn't enough to consider this chain in isolation // Finally -- it isn't enough to consider this chain in isolation
// either. We need to consider _all_ parents and ensure they've been // either. We need to consider _all_ parents and ensure they've been
// processed before processing this closure. // processed before processing this closure.
var cliques [][]issuerID var cliques [][]issuing.IssuerID
var cycles [][]issuerID var cycles [][]issuing.IssuerID
closure := make(map[issuerID]bool) closure := make(map[issuing.IssuerID]bool)
var cliquesToProcess []issuerID var cliquesToProcess []issuing.IssuerID
cliquesToProcess = append(cliquesToProcess, issuer) cliquesToProcess = append(cliquesToProcess, issuer)
for len(cliquesToProcess) > 0 { for len(cliquesToProcess) > 0 {
var node issuerID var node issuing.IssuerID
node, cliquesToProcess = cliquesToProcess[0], cliquesToProcess[1:] node, cliquesToProcess = cliquesToProcess[0], cliquesToProcess[1:]
// Skip potential clique nodes which have already been processed // Skip potential clique nodes which have already been processed
@@ -753,7 +754,7 @@ func processAnyCliqueOrCycle(
return nil, err return nil, err
} }
closure := make(map[issuerID]bool) closure := make(map[issuing.IssuerID]bool)
for _, cycle := range cycles { for _, cycle := range cycles {
for _, node := range cycle { for _, node := range cycle {
closure[node] = true closure[node] = true
@@ -811,14 +812,14 @@ func processAnyCliqueOrCycle(
} }
func findAllCliques( func findAllCliques(
processedIssuers map[issuerID]bool, processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
subjectIssuerIdsMap map[string][]issuerID, subjectIssuerIdsMap map[string][]issuing.IssuerID,
issuers []issuerID, issuers []issuing.IssuerID,
) ([][]issuerID, map[issuerID]int, []issuerID, error) { ) ([][]issuing.IssuerID, map[issuing.IssuerID]int, []issuing.IssuerID, error) {
var allCliques [][]issuerID var allCliques [][]issuing.IssuerID
issuerIdCliqueMap := make(map[issuerID]int) issuerIdCliqueMap := make(map[issuing.IssuerID]int)
var allCliqueNodes []issuerID var allCliqueNodes []issuing.IssuerID
for _, node := range issuers { for _, node := range issuers {
// Check if the node has already been visited... // Check if the node has already been visited...
@@ -859,11 +860,11 @@ func findAllCliques(
} }
func isOnReissuedClique( func isOnReissuedClique(
processedIssuers map[issuerID]bool, processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
subjectIssuerIdsMap map[string][]issuerID, subjectIssuerIdsMap map[string][]issuing.IssuerID,
node issuerID, node issuing.IssuerID,
) ([]issuerID, error) { ) ([]issuing.IssuerID, error) {
// Finding max cliques in arbitrary graphs is a nearly pathological // Finding max cliques in arbitrary graphs is a nearly pathological
// problem, usually left to the realm of SAT solvers and NP-Complete // problem, usually left to the realm of SAT solvers and NP-Complete
// theoretical. // theoretical.
@@ -891,7 +892,7 @@ func isOnReissuedClique(
// under this reissued clique detection code). // under this reissued clique detection code).
// //
// What does this mean for our algorithm? A simple greedy search is // What does this mean for our algorithm? A simple greedy search is
// sufficient. If we index our certificates by subject -> issuerID // sufficient. If we index our certificates by subject -> IssuerID
// (and cache its value across calls, which we've already done for // (and cache its value across calls, which we've already done for
// building the parent/child relationship), we can find all other issuers // building the parent/child relationship), we can find all other issuers
// with the same public key and subject as the existing node fairly // with the same public key and subject as the existing node fairly
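Since maximal-clique finding in arbitrary graphs is intractable, the comments above rely on a greedy shortcut that works because reissued CA certificates share both subject and public key. A self-contained sketch of that equality check (a standalone helper, not Vault's code):

package chainsketch

import (
	"bytes"
	"crypto"
	"crypto/x509"
)

// sameSubjectAndKey reports whether two issuer certificates look like
// reissuances of one another: identical raw subject and identical public key.
func sameSubjectAndKey(a, b *x509.Certificate) bool {
	if !bytes.Equal(a.RawSubject, b.RawSubject) {
		return false
	}
	// rsa, ecdsa, and ed25519 public keys all implement Equal (Go 1.15+).
	type equaler interface{ Equal(x crypto.PublicKey) bool }
	pub, ok := a.PublicKey.(equaler)
	return ok && pub.Equal(b.PublicKey)
}
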
@@ -925,7 +926,7 @@ func isOnReissuedClique(
// condition (the subject half), so validate they match the other half // condition (the subject half), so validate they match the other half
// (the issuer half) and the second condition. For node (which is // (the issuer half) and the second condition. For node (which is
// included in candidates), the condition should vacuously hold. // included in candidates), the condition should vacuously hold.
var clique []issuerID var clique []issuing.IssuerID
for _, candidate := range candidates { for _, candidate := range candidates {
// Skip already processed nodes, even if they could be clique // Skip already processed nodes, even if they could be clique
// candidates. We'll treat them as any other (already processed) // candidates. We'll treat them as any other (already processed)
@@ -957,7 +958,7 @@ func isOnReissuedClique(
return clique, nil return clique, nil
} }
func containsIssuer(collection []issuerID, target issuerID) bool { func containsIssuer(collection []issuing.IssuerID, target issuing.IssuerID) bool {
if len(collection) == 0 { if len(collection) == 0 {
return false return false
} }
@@ -971,7 +972,7 @@ func containsIssuer(collection []issuerID, target issuerID) bool {
return false return false
} }
func appendCycleIfNotExisting(knownCycles [][]issuerID, candidate []issuerID) [][]issuerID { func appendCycleIfNotExisting(knownCycles [][]issuing.IssuerID, candidate []issuing.IssuerID) [][]issuing.IssuerID {
// There's two ways to do cycle detection: canonicalize the cycles, // There's two ways to do cycle detection: canonicalize the cycles,
// rewriting them to have the least (or max) element first or just // rewriting them to have the least (or max) element first or just
// brute force the detection. // brute force the detection.
@@ -1007,7 +1008,7 @@ func appendCycleIfNotExisting(knownCycles [][]issuerID, candidate []issuerID) []
return knownCycles return knownCycles
} }
func canonicalizeCycle(cycle []issuerID) []issuerID { func canonicalizeCycle(cycle []issuing.IssuerID) []issuing.IssuerID {
// Find the minimum value and put it at the head, keeping the relative // Find the minimum value and put it at the head, keeping the relative
// ordering the same. // ordering the same.
minIndex := 0 minIndex := 0
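A compact sketch of the canonicalization described above (plain string IDs, not the package's implementation): rotate the cycle so its smallest element leads while keeping relative order, so two traversals of the same cycle compare equal regardless of starting node.

package chainsketch

// canonicalize rotates a cycle so its minimum ID is first while preserving
// the relative ordering of the remaining elements.
func canonicalize(cycle []string) []string {
	if len(cycle) == 0 {
		return cycle
	}
	minIndex := 0
	for i, id := range cycle {
		if id < cycle[minIndex] {
			minIndex = i
		}
	}
	out := make([]string, 0, len(cycle))
	out = append(out, cycle[minIndex:]...)
	out = append(out, cycle[:minIndex]...)
	return out
}
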
@@ -1026,11 +1027,11 @@ func canonicalizeCycle(cycle []issuerID) []issuerID {
} }
func findCyclesNearClique( func findCyclesNearClique(
processedIssuers map[issuerID]bool, processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
issuerIdChildrenMap map[issuerID][]issuerID, issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
cliqueNodes []issuerID, cliqueNodes []issuing.IssuerID,
) ([][]issuerID, error) { ) ([][]issuing.IssuerID, error) {
// When we have a reissued clique, we need to find all cycles next to it. // When we have a reissued clique, we need to find all cycles next to it.
// Presumably, because they all have non-empty parents, they should not // Presumably, because they all have non-empty parents, they should not
// have been visited yet. We further know that (because we're exploring // have been visited yet. We further know that (because we're exploring
@@ -1046,7 +1047,7 @@ func findCyclesNearClique(
// Copy the clique nodes as excluded nodes; we'll avoid exploring cycles // Copy the clique nodes as excluded nodes; we'll avoid exploring cycles
// which have parents that have been already explored. // which have parents that have been already explored.
excludeNodes := cliqueNodes[:] excludeNodes := cliqueNodes[:]
var knownCycles [][]issuerID var knownCycles [][]issuing.IssuerID
// We know the node has at least one child, since the clique is non-empty. // We know the node has at least one child, since the clique is non-empty.
for _, child := range issuerIdChildrenMap[cliqueNode] { for _, child := range issuerIdChildrenMap[cliqueNode] {
@@ -1081,12 +1082,12 @@ func findCyclesNearClique(
} }
func findAllCyclesWithNode( func findAllCyclesWithNode(
processedIssuers map[issuerID]bool, processedIssuers map[issuing.IssuerID]bool,
issuerIdCertMap map[issuerID]*x509.Certificate, issuerIdCertMap map[issuing.IssuerID]*x509.Certificate,
issuerIdChildrenMap map[issuerID][]issuerID, issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
source issuerID, source issuing.IssuerID,
exclude []issuerID, exclude []issuing.IssuerID,
) ([][]issuerID, error) { ) ([][]issuing.IssuerID, error) {
// We wish to find all cycles involving this particular node and report // We wish to find all cycles involving this particular node and report
// the corresponding paths. This is a full-graph traversal (excluding // the corresponding paths. This is a full-graph traversal (excluding
// certain paths) as we're not just checking if a cycle occurred, but // certain paths) as we're not just checking if a cycle occurred, but
@@ -1096,28 +1097,28 @@ func findAllCyclesWithNode(
maxCycleSize := 8 maxCycleSize := 8
// Whether we've visited any given node. // Whether we've visited any given node.
cycleVisited := make(map[issuerID]bool) cycleVisited := make(map[issuing.IssuerID]bool)
visitCounts := make(map[issuerID]int) visitCounts := make(map[issuing.IssuerID]int)
parentCounts := make(map[issuerID]map[issuerID]bool) parentCounts := make(map[issuing.IssuerID]map[issuing.IssuerID]bool)
// Paths to the specified node. Some of these might be cycles. // Paths to the specified node. Some of these might be cycles.
pathsTo := make(map[issuerID][][]issuerID) pathsTo := make(map[issuing.IssuerID][][]issuing.IssuerID)
// Nodes to visit. // Nodes to visit.
var visitQueue []issuerID var visitQueue []issuing.IssuerID
// Add the source node to start. In order to set up the paths to a // Add the source node to start. In order to set up the paths to a
// given node, we seed pathsTo with the single path involving just // given node, we seed pathsTo with the single path involving just
// this node // this node
visitQueue = append(visitQueue, source) visitQueue = append(visitQueue, source)
pathsTo[source] = [][]issuerID{{source}} pathsTo[source] = [][]issuing.IssuerID{{source}}
// Begin building paths. // Begin building paths.
// //
// Loop invariant: // Loop invariant:
// pathTo[x] contains valid paths to reach this node, from source. // pathTo[x] contains valid paths to reach this node, from source.
for len(visitQueue) > 0 { for len(visitQueue) > 0 {
var current issuerID var current issuing.IssuerID
current, visitQueue = visitQueue[0], visitQueue[1:] current, visitQueue = visitQueue[0], visitQueue[1:]
// If we've already processed this node, we have a cycle. Skip this // If we've already processed this node, we have a cycle. Skip this
@@ -1162,7 +1163,7 @@ func findAllCyclesWithNode(
// Track this parent->child relationship to know when to exit. // Track this parent->child relationship to know when to exit.
setOfParents, ok := parentCounts[child] setOfParents, ok := parentCounts[child]
if !ok { if !ok {
setOfParents = make(map[issuerID]bool) setOfParents = make(map[issuing.IssuerID]bool)
parentCounts[child] = setOfParents parentCounts[child] = setOfParents
} }
_, existingParent := setOfParents[current] _, existingParent := setOfParents[current]
@@ -1179,7 +1180,7 @@ func findAllCyclesWithNode(
// externally with an existing path). // externally with an existing path).
addedPath := false addedPath := false
if _, ok := pathsTo[child]; !ok { if _, ok := pathsTo[child]; !ok {
pathsTo[child] = make([][]issuerID, 0) pathsTo[child] = make([][]issuing.IssuerID, 0)
} }
for _, path := range pathsTo[current] { for _, path := range pathsTo[current] {
@@ -1204,7 +1205,7 @@ func findAllCyclesWithNode(
return nil, errutil.InternalError{Err: fmt.Sprintf("Error updating certificate path: path of length %d is too long", len(path))} return nil, errutil.InternalError{Err: fmt.Sprintf("Error updating certificate path: path of length %d is too long", len(path))}
} }
// Make sure to deep copy the path. // Make sure to deep copy the path.
newPath := make([]issuerID, 0, len(path)+1) newPath := make([]issuing.IssuerID, 0, len(path)+1)
newPath = append(newPath, path...) newPath = append(newPath, path...)
newPath = append(newPath, child) newPath = append(newPath, child)
@@ -1249,7 +1250,7 @@ func findAllCyclesWithNode(
// Ok, we've now exited from our loop. Any cycles would've been detected // Ok, we've now exited from our loop. Any cycles would've been detected
// and their paths recorded in pathsTo. Now we can iterate over these // and their paths recorded in pathsTo. Now we can iterate over these
// (starting a source), clean them up and validate them. // (starting a source), clean them up and validate them.
var cycles [][]issuerID var cycles [][]issuing.IssuerID
for _, cycle := range pathsTo[source] { for _, cycle := range pathsTo[source] {
// Skip the trivial cycle. // Skip the trivial cycle.
if len(cycle) == 1 && cycle[0] == source { if len(cycle) == 1 && cycle[0] == source {
@@ -1287,8 +1288,8 @@ func findAllCyclesWithNode(
return cycles, nil return cycles, nil
} }
func reversedCycle(cycle []issuerID) []issuerID { func reversedCycle(cycle []issuing.IssuerID) []issuing.IssuerID {
var result []issuerID var result []issuing.IssuerID
for index := len(cycle) - 1; index >= 0; index-- { for index := len(cycle) - 1; index >= 0; index-- {
result = append(result, cycle[index]) result = append(result, cycle[index])
} }
@@ -1297,11 +1298,11 @@ func reversedCycle(cycle []issuerID) []issuerID {
} }
func computeParentsFromClosure( func computeParentsFromClosure(
processedIssuers map[issuerID]bool, processedIssuers map[issuing.IssuerID]bool,
issuerIdParentsMap map[issuerID][]issuerID, issuerIdParentsMap map[issuing.IssuerID][]issuing.IssuerID,
closure map[issuerID]bool, closure map[issuing.IssuerID]bool,
) (map[issuerID]bool, bool) { ) (map[issuing.IssuerID]bool, bool) {
parents := make(map[issuerID]bool) parents := make(map[issuing.IssuerID]bool)
for node := range closure { for node := range closure {
nodeParents, ok := issuerIdParentsMap[node] nodeParents, ok := issuerIdParentsMap[node]
if !ok { if !ok {
@@ -1326,11 +1327,11 @@ func computeParentsFromClosure(
} }
func addNodeCertsToEntry( func addNodeCertsToEntry(
issuerIdEntryMap map[issuerID]*issuerEntry, issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIdChildrenMap map[issuerID][]issuerID, issuerIdChildrenMap map[issuing.IssuerID][]issuing.IssuerID,
includedParentCerts map[string]bool, includedParentCerts map[string]bool,
entry *issuerEntry, entry *issuing.IssuerEntry,
issuersCollection ...[]issuerID, issuersCollection ...[]issuing.IssuerID,
) { ) {
for _, collection := range issuersCollection { for _, collection := range issuersCollection {
// Find a starting point into this collection such that it verifies // Find a starting point into this collection such that it verifies
@@ -1369,10 +1370,10 @@ func addNodeCertsToEntry(
} }
func addParentChainsToEntry( func addParentChainsToEntry(
issuerIdEntryMap map[issuerID]*issuerEntry, issuerIdEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
includedParentCerts map[string]bool, includedParentCerts map[string]bool,
entry *issuerEntry, entry *issuing.IssuerEntry,
parents map[issuerID]bool, parents map[issuing.IssuerID]bool,
) { ) {
for parent := range parents { for parent := range parents {
nodeEntry := issuerIdEntryMap[parent] nodeEntry := issuerIdEntryMap[parent]

View File

@@ -9,13 +9,14 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
// issueAcmeCertUsingCieps based on the passed in ACME information, perform a CIEPS request/response // issueAcmeCertUsingCieps based on the passed in ACME information, perform a CIEPS request/response
func issueAcmeCertUsingCieps(_ *backend, _ *acmeContext, _ *logical.Request, _ *framework.FieldData, _ *jwsCtx, _ *acmeAccount, _ *acmeOrder, _ *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuerID, error) { func issueAcmeCertUsingCieps(_ *backend, _ *acmeContext, _ *logical.Request, _ *framework.FieldData, _ *jwsCtx, _ *acmeAccount, _ *acmeOrder, _ *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuing.IssuerID, error) {
return nil, "", fmt.Errorf("cieps is an enterprise only feature") return nil, "", fmt.Errorf("cieps is an enterprise only feature")
} }

View File

@@ -4,9 +4,9 @@
package pki package pki
import ( import (
"fmt"
"strings" "strings"
"time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
func (sc *storageContext) isDefaultKeySet() (bool, error) { func (sc *storageContext) isDefaultKeySet() (bool, error) {
@@ -27,14 +27,14 @@ func (sc *storageContext) isDefaultIssuerSet() (bool, error) {
return strings.TrimSpace(config.DefaultIssuerId.String()) != "", nil return strings.TrimSpace(config.DefaultIssuerId.String()) != "", nil
} }
func (sc *storageContext) updateDefaultKeyId(id keyID) error { func (sc *storageContext) updateDefaultKeyId(id issuing.KeyID) error {
config, err := sc.getKeysConfig() config, err := sc.getKeysConfig()
if err != nil { if err != nil {
return err return err
} }
if config.DefaultKeyId != id { if config.DefaultKeyId != id {
return sc.setKeysConfig(&keyConfigEntry{ return sc.setKeysConfig(&issuing.KeyConfigEntry{
DefaultKeyId: id, DefaultKeyId: id,
}) })
} }
@@ -42,7 +42,7 @@ func (sc *storageContext) updateDefaultKeyId(id keyID) error {
return nil return nil
} }
func (sc *storageContext) updateDefaultIssuerId(id issuerID) error { func (sc *storageContext) updateDefaultIssuerId(id issuing.IssuerID) error {
config, err := sc.getIssuersConfig() config, err := sc.getIssuersConfig()
if err != nil { if err != nil {
return err return err
@@ -55,67 +55,3 @@ func (sc *storageContext) updateDefaultIssuerId(id issuerID) error {
return nil return nil
} }
func (sc *storageContext) changeDefaultIssuerTimestamps(oldDefault issuerID, newDefault issuerID) error {
if newDefault == oldDefault {
return nil
}
now := time.Now().UTC()
// When the default issuer changes, we need to modify four
// pieces of information:
//
// 1. The old default issuer's modification time, as it no
// longer works for the /cert/ca path.
// 2. The new default issuer's modification time, as it now
// works for the /cert/ca path.
// 3. & 4. Both issuer's CRLs, as they behave the same, under
// the /cert/crl path!
for _, thisId := range []issuerID{oldDefault, newDefault} {
if len(thisId) == 0 {
continue
}
// 1 & 2 above.
issuer, err := sc.fetchIssuerById(thisId)
if err != nil {
// Due to the lack of transactions, if we deleted the default
// issuer (successfully), but the subsequent issuer config write
// (to clear the default issuer's old id) failed, we might have
// an inconsistent config. If we later hit this loop (and flush
// these timestamps again -- perhaps because the operator
// selected a new default), we'd have erred out here, because
// the since-deleted default issuer doesn't exist. In this case,
// skip the issuer instead of bailing.
err := fmt.Errorf("unable to update issuer (%v)'s modification time: error fetching issuer: %w", thisId, err)
if strings.Contains(err.Error(), "does not exist") {
sc.Backend.Logger().Warn(err.Error())
continue
}
return err
}
issuer.LastModified = now
err = sc.writeIssuer(issuer)
if err != nil {
return fmt.Errorf("unable to update issuer (%v)'s modification time: error persisting issuer: %w", thisId, err)
}
}
// Fetch and update the internalCRLConfigEntry (3&4).
cfg, err := sc.getLocalCRLConfig()
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error fetching local CRL config: %w", err)
}
cfg.LastModified = now
cfg.DeltaLastModified = now
err = sc.setLocalCRLConfig(cfg)
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error persisting local CRL config: %w", err)
}
return nil
}
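The removed helper's comment lays out why switching the default issuer must touch both issuers' LastModified and the CRL config timestamps. A condensed sketch of that bookkeeping, with a hypothetical store interface standing in for the storage context (names here are illustrative, not the refactored API):

package sketch

import (
	"strings"
	"time"
)

// issuerStore is a hypothetical stand-in for the PKI storage context.
type issuerStore interface {
	FetchIssuer(id string) (lastModified time.Time, err error)
	TouchIssuer(id string, t time.Time) error
	TouchCRLConfig(t time.Time) error
}

// touchOnDefaultChange bumps modification times for the old and new default
// issuers (points 1 and 2 in the comment above) and the CRL config (3 and 4),
// skipping issuers that no longer exist rather than failing.
func touchOnDefaultChange(s issuerStore, oldDefault, newDefault string) error {
	if oldDefault == newDefault {
		return nil
	}
	now := time.Now().UTC()
	for _, id := range []string{oldDefault, newDefault} {
		if id == "" {
			continue
		}
		if _, err := s.FetchIssuer(id); err != nil {
			if strings.Contains(err.Error(), "does not exist") {
				continue // a since-deleted default; skip it
			}
			return err
		}
		if err := s.TouchIssuer(id, now); err != nil {
			return err
		}
	}
	return s.TouchCRLConfig(now)
}
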

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/helper/constants" "github.com/hashicorp/vault/helper/constants"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema" "github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
@@ -1063,7 +1064,7 @@ func TestAutoRebuild(t *testing.T) {
var revInfo revocationInfo var revInfo revocationInfo
err = json.Unmarshal([]byte(revEntryValue), &revInfo) err = json.Unmarshal([]byte(revEntryValue), &revInfo)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, revInfo.CertificateIssuer, issuerID(rootIssuer)) require.Equal(t, revInfo.CertificateIssuer, issuing.IssuerID(rootIssuer))
// New serial should not appear on CRL. // New serial should not appear on CRL.
crl = getCrlCertificateList(t, client, "pki") crl = getCrlCertificateList(t, client, "pki")
@@ -1201,7 +1202,7 @@ func TestTidyIssuerAssociation(t *testing.T) {
require.NotEmpty(t, resp.Data["certificate"]) require.NotEmpty(t, resp.Data["certificate"])
require.NotEmpty(t, resp.Data["issuer_id"]) require.NotEmpty(t, resp.Data["issuer_id"])
rootCert := resp.Data["certificate"].(string) rootCert := resp.Data["certificate"].(string)
rootID := resp.Data["issuer_id"].(issuerID) rootID := resp.Data["issuer_id"].(issuing.IssuerID)
// Create a role for issuance. // Create a role for issuance.
_, err = CBWrite(b, s, "roles/local-testing", map[string]interface{}{ _, err = CBWrite(b, s, "roles/local-testing", map[string]interface{}{
@@ -1495,9 +1496,9 @@ func TestCRLIssuerRemoval(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp) require.NotNil(t, resp)
key := string(resp.Data["key_id"].(keyID)) key := string(resp.Data["key_id"].(issuing.KeyID))
keyIDs = append(keyIDs, key) keyIDs = append(keyIDs, key)
issuer := string(resp.Data["issuer_id"].(issuerID)) issuer := string(resp.Data["issuer_id"].(issuing.IssuerID))
issuerIDs = append(issuerIDs, issuer) issuerIDs = append(issuerIDs, issuer)
} }
_, err = CBRead(b, s, "crl/rotate") _, err = CBRead(b, s, "crl/rotate")

View File

@@ -12,14 +12,15 @@ import (
"math/big" "math/big"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
atomic2 "go.uber.org/atomic"
) )
const ( const (
@@ -38,10 +39,10 @@ const (
) )
type revocationInfo struct { type revocationInfo struct {
CertificateBytes []byte `json:"certificate_bytes"` CertificateBytes []byte `json:"certificate_bytes"`
RevocationTime int64 `json:"revocation_time"` RevocationTime int64 `json:"revocation_time"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"` RevocationTimeUTC time.Time `json:"revocation_time_utc"`
CertificateIssuer issuerID `json:"issuer_id"` CertificateIssuer issuing.IssuerID `json:"issuer_id"`
} }
type revocationRequest struct { type revocationRequest struct {
@@ -81,31 +82,31 @@ type (
} }
) )
// crlBuilder is gatekeeper for controlling various read/write operations to the storage of the CRL. // CrlBuilder is gatekeeper for controlling various read/write operations to the storage of the CRL.
// The extra complexity arises from secondary performance clusters seeing various writes to its storage // The extra complexity arises from secondary performance clusters seeing various writes to its storage
// without the actual API calls. During the storage invalidation process, we do not have the required state // without the actual API calls. During the storage invalidation process, we do not have the required state
// to actually rebuild the CRLs, so we need to schedule it in a deferred fashion. This allows either // to actually rebuild the CRLs, so we need to schedule it in a deferred fashion. This allows either
// read or write calls to perform the operation if required, or have the flag reset upon a write operation // read or write calls to perform the operation if required, or have the flag reset upon a write operation
// //
// The CRL builder also tracks the revocation configuration. // The CRL builder also tracks the revocation configuration.
type crlBuilder struct { type CrlBuilder struct {
_builder sync.Mutex _builder sync.Mutex
forceRebuild *atomic2.Bool forceRebuild *atomic.Bool
canRebuild bool canRebuild bool
lastDeltaRebuildCheck time.Time lastDeltaRebuildCheck time.Time
_config sync.RWMutex _config sync.RWMutex
dirty *atomic2.Bool dirty *atomic.Bool
config crlConfig config crlConfig
haveInitializedConfig bool haveInitializedConfig bool
// Whether to invalidate our LastModifiedTime due to write on the // Whether to invalidate our LastModifiedTime due to write on the
// global issuance config. // global issuance config.
invalidate *atomic2.Bool invalidate *atomic.Bool
// Global revocation queue entries get accepted by the invalidate func // Global revocation queue entries get accepted by the invalidate func
// and passed to the crlBuilder for processing. // and passed to the CrlBuilder for processing.
haveInitializedQueue *atomic2.Bool haveInitializedQueue *atomic.Bool
revQueue *revocationQueue revQueue *revocationQueue
removalQueue *revocationQueue removalQueue *revocationQueue
crossQueue *revocationQueue crossQueue *revocationQueue
@@ -116,29 +117,31 @@ const (
_enforceForceFlag = false _enforceForceFlag = false
) )
func newCRLBuilder(canRebuild bool) *crlBuilder { func newCRLBuilder(canRebuild bool) *CrlBuilder {
return &crlBuilder{ builder := &CrlBuilder{
forceRebuild: atomic2.NewBool(false), forceRebuild: &atomic.Bool{},
canRebuild: canRebuild, canRebuild: canRebuild,
// Set the last delta rebuild window to now, delaying the first delta // Set the last delta rebuild window to now, delaying the first delta
// rebuild by the first rebuild period to give us some time on startup // rebuild by the first rebuild period to give us some time on startup
// to stabilize. // to stabilize.
lastDeltaRebuildCheck: time.Now(), lastDeltaRebuildCheck: time.Now(),
dirty: atomic2.NewBool(true), dirty: &atomic.Bool{},
config: defaultCrlConfig, config: defaultCrlConfig,
invalidate: atomic2.NewBool(false), invalidate: &atomic.Bool{},
haveInitializedQueue: atomic2.NewBool(false), haveInitializedQueue: &atomic.Bool{},
revQueue: newRevocationQueue(), revQueue: newRevocationQueue(),
removalQueue: newRevocationQueue(), removalQueue: newRevocationQueue(),
crossQueue: newRevocationQueue(), crossQueue: newRevocationQueue(),
} }
builder.dirty.Store(true)
return builder
} }
func (cb *crlBuilder) markConfigDirty() { func (cb *CrlBuilder) markConfigDirty() {
cb.dirty.Store(true) cb.dirty.Store(true)
} }
func (cb *crlBuilder) reloadConfigIfRequired(sc *storageContext) error { func (cb *CrlBuilder) reloadConfigIfRequired(sc *storageContext) error {
if cb.dirty.Load() { if cb.dirty.Load() {
// Acquire a write lock. // Acquire a write lock.
cb._config.Lock() cb._config.Lock()
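A note on the dependency swap in this hunk (go.uber.org/atomic to the standard library's sync/atomic): std atomic.Bool's zero value is false, which is why the constructor now stores true into dirty explicitly, and CAS calls become CompareAndSwap. A tiny standalone illustration of both points (not Vault code; names invented):

package sketch

import "sync/atomic"

type flags struct {
	dirty atomic.Bool // zero value is false; no constructor helper needed
}

func newFlags() *flags {
	f := &flags{}
	f.dirty.Store(true) // mirrors builder.dirty.Store(true) above
	return f
}

// flushIfDirty clears the flag exactly once per set, the same pattern the
// CRL builder uses when it swaps invalidate from true back to false.
func (f *flags) flushIfDirty() bool {
	return f.dirty.CompareAndSwap(true, false)
}
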
@@ -180,12 +183,12 @@ func (cb *crlBuilder) reloadConfigIfRequired(sc *storageContext) error {
return nil return nil
} }
func (cb *crlBuilder) notifyOnConfigChange(sc *storageContext, priorConfig crlConfig, newConfig crlConfig) { func (cb *CrlBuilder) notifyOnConfigChange(sc *storageContext, priorConfig crlConfig, newConfig crlConfig) {
// If you need to hook into a CRL configuration change across different server types // If you need to hook into a CRL configuration change across different server types
// such as primary clusters as well as performance replicas, it is easier to do here than // such as primary clusters as well as performance replicas, it is easier to do here than
// in two places (API layer and in invalidateFunc) // in two places (API layer and in invalidateFunc)
if priorConfig.UnifiedCRL != newConfig.UnifiedCRL && newConfig.UnifiedCRL { if priorConfig.UnifiedCRL != newConfig.UnifiedCRL && newConfig.UnifiedCRL {
sc.Backend.unifiedTransferStatus.forceRun() sc.Backend.GetUnifiedTransferStatus().forceRun()
} }
if priorConfig.UseGlobalQueue != newConfig.UseGlobalQueue && newConfig.UseGlobalQueue { if priorConfig.UseGlobalQueue != newConfig.UseGlobalQueue && newConfig.UseGlobalQueue {
@@ -193,7 +196,7 @@ func (cb *crlBuilder) notifyOnConfigChange(sc *storageContext, priorConfig crlCo
} }
} }
func (cb *crlBuilder) getConfigWithUpdate(sc *storageContext) (*crlConfig, error) { func (cb *CrlBuilder) getConfigWithUpdate(sc *storageContext) (*crlConfig, error) {
// Config may mutate immediately after accessing, but will be freshly // Config may mutate immediately after accessing, but will be freshly
// fetched if necessary. // fetched if necessary.
if err := cb.reloadConfigIfRequired(sc); err != nil { if err := cb.reloadConfigIfRequired(sc); err != nil {
@@ -207,12 +210,12 @@ func (cb *crlBuilder) getConfigWithUpdate(sc *storageContext) (*crlConfig, error
return &configCopy, nil return &configCopy, nil
} }
func (cb *crlBuilder) getConfigWithForcedUpdate(sc *storageContext) (*crlConfig, error) { func (cb *CrlBuilder) getConfigWithForcedUpdate(sc *storageContext) (*crlConfig, error) {
cb.markConfigDirty() cb.markConfigDirty()
return cb.getConfigWithUpdate(sc) return cb.getConfigWithUpdate(sc)
} }
func (cb *crlBuilder) writeConfig(sc *storageContext, config *crlConfig) (*crlConfig, error) { func (cb *CrlBuilder) writeConfig(sc *storageContext, config *crlConfig) (*crlConfig, error) {
cb._config.Lock() cb._config.Lock()
defer cb._config.Unlock() defer cb._config.Unlock()
@@ -242,7 +245,7 @@ func (cb *crlBuilder) writeConfig(sc *storageContext, config *crlConfig) (*crlCo
return config, nil return config, nil
} }
func (cb *crlBuilder) checkForAutoRebuild(sc *storageContext) error { func (cb *CrlBuilder) checkForAutoRebuild(sc *storageContext) error {
cfg, err := cb.getConfigWithUpdate(sc) cfg, err := cb.getConfigWithUpdate(sc)
if err != nil { if err != nil {
return err return err
@@ -307,14 +310,14 @@ func (cb *crlBuilder) checkForAutoRebuild(sc *storageContext) error {
} }
// Mark the internal LastModifiedTime tracker invalid. // Mark the internal LastModifiedTime tracker invalid.
func (cb *crlBuilder) invalidateCRLBuildTime() { func (cb *CrlBuilder) invalidateCRLBuildTime() {
cb.invalidate.Store(true) cb.invalidate.Store(true)
} }
// Update the config to mark the modified CRL. See note in // Update the config to mark the modified CRL. See note in
// updateDefaultIssuerId about why this is necessary. // updateDefaultIssuerId about why this is necessary.
func (cb *crlBuilder) flushCRLBuildTimeInvalidation(sc *storageContext) error { func (cb *CrlBuilder) flushCRLBuildTimeInvalidation(sc *storageContext) error {
if cb.invalidate.CAS(true, false) { if cb.invalidate.CompareAndSwap(true, false) {
// Flush out our invalidation. // Flush out our invalidation.
cfg, err := sc.getLocalCRLConfig() cfg, err := sc.getLocalCRLConfig()
if err != nil { if err != nil {
@@ -336,7 +339,7 @@ func (cb *crlBuilder) flushCRLBuildTimeInvalidation(sc *storageContext) error {
// rebuildIfForced is to be called by readers or periodic functions that might need to trigger // rebuildIfForced is to be called by readers or periodic functions that might need to trigger
// a refresh of the CRL before the read occurs. // a refresh of the CRL before the read occurs.
func (cb *crlBuilder) rebuildIfForced(sc *storageContext) ([]string, error) { func (cb *CrlBuilder) rebuildIfForced(sc *storageContext) ([]string, error) {
if cb.forceRebuild.Load() { if cb.forceRebuild.Load() {
return cb._doRebuild(sc, true, _enforceForceFlag) return cb._doRebuild(sc, true, _enforceForceFlag)
} }
@@ -345,12 +348,12 @@ func (cb *crlBuilder) rebuildIfForced(sc *storageContext) ([]string, error) {
} }
// rebuild is to be called by various write apis that know the CRL is to be updated and can be now. // rebuild is to be called by various write apis that know the CRL is to be updated and can be now.
func (cb *crlBuilder) rebuild(sc *storageContext, forceNew bool) ([]string, error) { func (cb *CrlBuilder) rebuild(sc *storageContext, forceNew bool) ([]string, error) {
return cb._doRebuild(sc, forceNew, _ignoreForceFlag) return cb._doRebuild(sc, forceNew, _ignoreForceFlag)
} }
// requestRebuildIfActiveNode will schedule a rebuild of the CRL from the next read or write api call assuming we are the active node of a cluster // requestRebuildIfActiveNode will schedule a rebuild of the CRL from the next read or write api call assuming we are the active node of a cluster
func (cb *crlBuilder) requestRebuildIfActiveNode(b *backend) { func (cb *CrlBuilder) requestRebuildIfActiveNode(b *backend) {
// Only schedule us on active nodes, as the active node is the only node that can rebuild/write the CRL. // Only schedule us on active nodes, as the active node is the only node that can rebuild/write the CRL.
// Note 1: The CRL is cluster specific, so this does need to run on the active node of a performance secondary cluster. // Note 1: The CRL is cluster specific, so this does need to run on the active node of a performance secondary cluster.
// Note 2: This is called by the storage invalidation function, so it should not block. // Note 2: This is called by the storage invalidation function, so it should not block.
@@ -364,7 +367,7 @@ func (cb *crlBuilder) requestRebuildIfActiveNode(b *backend) {
cb.forceRebuild.Store(true) cb.forceRebuild.Store(true)
} }
func (cb *crlBuilder) _doRebuild(sc *storageContext, forceNew bool, ignoreForceFlag bool) ([]string, error) { func (cb *CrlBuilder) _doRebuild(sc *storageContext, forceNew bool, ignoreForceFlag bool) ([]string, error) {
cb._builder.Lock() cb._builder.Lock()
defer cb._builder.Unlock() defer cb._builder.Unlock()
// Re-read the lock in case someone beat us to the punch between the previous load op. // Re-read the lock in case someone beat us to the punch between the previous load op.
@@ -384,7 +387,7 @@ func (cb *crlBuilder) _doRebuild(sc *storageContext, forceNew bool, ignoreForceF
return nil, nil return nil, nil
} }
func (cb *crlBuilder) _getPresentDeltaWALForClearing(sc *storageContext, path string) ([]string, error) { func (cb *CrlBuilder) _getPresentDeltaWALForClearing(sc *storageContext, path string) ([]string, error) {
// Clearing of the delta WAL occurs after a new complete CRL has been built. // Clearing of the delta WAL occurs after a new complete CRL has been built.
walSerials, err := sc.Storage.List(sc.Context, path) walSerials, err := sc.Storage.List(sc.Context, path)
if err != nil { if err != nil {
@@ -397,11 +400,11 @@ func (cb *crlBuilder) _getPresentDeltaWALForClearing(sc *storageContext, path st
return walSerials, nil return walSerials, nil
} }
func (cb *crlBuilder) getPresentLocalDeltaWALForClearing(sc *storageContext) ([]string, error) { func (cb *CrlBuilder) getPresentLocalDeltaWALForClearing(sc *storageContext) ([]string, error) {
return cb._getPresentDeltaWALForClearing(sc, localDeltaWALPath) return cb._getPresentDeltaWALForClearing(sc, localDeltaWALPath)
} }
func (cb *crlBuilder) getPresentUnifiedDeltaWALForClearing(sc *storageContext) ([]string, error) { func (cb *CrlBuilder) getPresentUnifiedDeltaWALForClearing(sc *storageContext) ([]string, error) {
walClusters, err := sc.Storage.List(sc.Context, unifiedDeltaWALPrefix) walClusters, err := sc.Storage.List(sc.Context, unifiedDeltaWALPrefix)
if err != nil { if err != nil {
return nil, fmt.Errorf("error fetching list of clusters with delta WAL entries: %w", err) return nil, fmt.Errorf("error fetching list of clusters with delta WAL entries: %w", err)
@@ -426,7 +429,7 @@ func (cb *crlBuilder) getPresentUnifiedDeltaWALForClearing(sc *storageContext) (
return allPaths, nil return allPaths, nil
} }
func (cb *crlBuilder) _clearDeltaWAL(sc *storageContext, walSerials []string, path string) error { func (cb *CrlBuilder) _clearDeltaWAL(sc *storageContext, walSerials []string, path string) error {
// Clearing of the delta WAL occurs after a new complete CRL has been built. // Clearing of the delta WAL occurs after a new complete CRL has been built.
for _, serial := range walSerials { for _, serial := range walSerials {
// Don't remove our special entries! // Don't remove our special entries!
@@ -442,15 +445,15 @@ func (cb *crlBuilder) _clearDeltaWAL(sc *storageContext, walSerials []string, pa
return nil return nil
} }
func (cb *crlBuilder) clearLocalDeltaWAL(sc *storageContext, walSerials []string) error { func (cb *CrlBuilder) clearLocalDeltaWAL(sc *storageContext, walSerials []string) error {
return cb._clearDeltaWAL(sc, walSerials, localDeltaWALPath) return cb._clearDeltaWAL(sc, walSerials, localDeltaWALPath)
} }
func (cb *crlBuilder) clearUnifiedDeltaWAL(sc *storageContext, walSerials []string) error { func (cb *CrlBuilder) clearUnifiedDeltaWAL(sc *storageContext, walSerials []string) error {
return cb._clearDeltaWAL(sc, walSerials, unifiedDeltaWALPrefix) return cb._clearDeltaWAL(sc, walSerials, unifiedDeltaWALPrefix)
} }
func (cb *crlBuilder) rebuildDeltaCRLsIfForced(sc *storageContext, override bool) ([]string, error) { func (cb *CrlBuilder) rebuildDeltaCRLsIfForced(sc *storageContext, override bool) ([]string, error) {
// Delta CRLs use the same expiry duration as the complete CRL. Because // Delta CRLs use the same expiry duration as the complete CRL. Because
// we always rebuild the complete CRL and then the delta CRL, we can // we always rebuild the complete CRL and then the delta CRL, we can
// be assured that the delta CRL always expires after a complete CRL, // be assured that the delta CRL always expires after a complete CRL,
@@ -516,7 +519,7 @@ func (cb *crlBuilder) rebuildDeltaCRLsIfForced(sc *storageContext, override bool
return cb.rebuildDeltaCRLsHoldingLock(sc, false) return cb.rebuildDeltaCRLsHoldingLock(sc, false)
} }
func (cb *crlBuilder) _shouldRebuildLocalCRLs(sc *storageContext, override bool) (bool, error) { func (cb *CrlBuilder) _shouldRebuildLocalCRLs(sc *storageContext, override bool) (bool, error) {
// Fetch two storage entries to see if we actually need to do this // Fetch two storage entries to see if we actually need to do this
// rebuild, given we're within the window. // rebuild, given we're within the window.
lastWALEntry, err := sc.Storage.Get(sc.Context, localDeltaWALLastRevokedSerial) lastWALEntry, err := sc.Storage.Get(sc.Context, localDeltaWALLastRevokedSerial)
@@ -562,7 +565,7 @@ func (cb *crlBuilder) _shouldRebuildLocalCRLs(sc *storageContext, override bool)
return true, nil return true, nil
} }
func (cb *crlBuilder) _shouldRebuildUnifiedCRLs(sc *storageContext, override bool) (bool, error) { func (cb *CrlBuilder) _shouldRebuildUnifiedCRLs(sc *storageContext, override bool) (bool, error) {
// Unified CRL can only be built by the main cluster. // Unified CRL can only be built by the main cluster.
b := sc.Backend b := sc.Backend
if b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) || if b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
@@ -636,18 +639,18 @@ func (cb *crlBuilder) _shouldRebuildUnifiedCRLs(sc *storageContext, override boo
return shouldRebuild, nil return shouldRebuild, nil
} }
func (cb *crlBuilder) rebuildDeltaCRLs(sc *storageContext, forceNew bool) ([]string, error) { func (cb *CrlBuilder) rebuildDeltaCRLs(sc *storageContext, forceNew bool) ([]string, error) {
cb._builder.Lock() cb._builder.Lock()
defer cb._builder.Unlock() defer cb._builder.Unlock()
return cb.rebuildDeltaCRLsHoldingLock(sc, forceNew) return cb.rebuildDeltaCRLsHoldingLock(sc, forceNew)
} }
func (cb *crlBuilder) rebuildDeltaCRLsHoldingLock(sc *storageContext, forceNew bool) ([]string, error) { func (cb *CrlBuilder) rebuildDeltaCRLsHoldingLock(sc *storageContext, forceNew bool) ([]string, error) {
return buildAnyCRLs(sc, forceNew, true /* building delta */) return buildAnyCRLs(sc, forceNew, true /* building delta */)
} }
func (cb *crlBuilder) addCertForRevocationCheck(cluster, serial string) { func (cb *CrlBuilder) addCertForRevocationCheck(cluster, serial string) {
entry := &revocationQueueEntry{ entry := &revocationQueueEntry{
Cluster: cluster, Cluster: cluster,
Serial: serial, Serial: serial,
@@ -655,7 +658,7 @@ func (cb *crlBuilder) addCertForRevocationCheck(cluster, serial string) {
cb.revQueue.Add(entry) cb.revQueue.Add(entry)
} }
func (cb *crlBuilder) addCertForRevocationRemoval(cluster, serial string) { func (cb *CrlBuilder) addCertForRevocationRemoval(cluster, serial string) {
entry := &revocationQueueEntry{ entry := &revocationQueueEntry{
Cluster: cluster, Cluster: cluster,
Serial: serial, Serial: serial,
@@ -663,7 +666,7 @@ func (cb *crlBuilder) addCertForRevocationRemoval(cluster, serial string) {
cb.removalQueue.Add(entry) cb.removalQueue.Add(entry)
} }
func (cb *crlBuilder) addCertFromCrossRevocation(cluster, serial string) { func (cb *CrlBuilder) addCertFromCrossRevocation(cluster, serial string) {
entry := &revocationQueueEntry{ entry := &revocationQueueEntry{
Cluster: cluster, Cluster: cluster,
Serial: serial, Serial: serial,
@@ -671,7 +674,7 @@ func (cb *crlBuilder) addCertFromCrossRevocation(cluster, serial string) {
cb.crossQueue.Add(entry) cb.crossQueue.Add(entry)
} }
func (cb *crlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotPerfPrimary bool) error { func (cb *CrlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotPerfPrimary bool) error {
// Assume holding lock. // Assume holding lock.
if cb.haveInitializedQueue.Load() { if cb.haveInitializedQueue.Load() {
return nil return nil
@@ -727,7 +730,7 @@ func (cb *crlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotP
return nil return nil
} }
func (cb *crlBuilder) processRevocationQueue(sc *storageContext) error { func (cb *CrlBuilder) processRevocationQueue(sc *storageContext) error {
sc.Backend.Logger().Debug(fmt.Sprintf("starting to process revocation requests")) sc.Backend.Logger().Debug(fmt.Sprintf("starting to process revocation requests"))
isNotPerfPrimary := sc.Backend.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) || isNotPerfPrimary := sc.Backend.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
@@ -844,7 +847,7 @@ func (cb *crlBuilder) processRevocationQueue(sc *storageContext) error {
return nil return nil
} }
func (cb *crlBuilder) processCrossClusterRevocations(sc *storageContext) error { func (cb *CrlBuilder) processCrossClusterRevocations(sc *storageContext) error {
sc.Backend.Logger().Debug(fmt.Sprintf("starting to process unified revocations")) sc.Backend.Logger().Debug(fmt.Sprintf("starting to process unified revocations"))
crlConfig, err := cb.getConfigWithUpdate(sc) crlConfig, err := cb.getConfigWithUpdate(sc)
@@ -906,25 +909,25 @@ func (cb *crlBuilder) processCrossClusterRevocations(sc *storageContext) error {
return nil return nil
} }
// Helper function to fetch a map of issuerID->parsed cert for revocation // Helper function to fetch a map of IssuerID->parsed cert for revocation
// usage. Unlike other paths, this needs to handle the legacy bundle // usage. Unlike other paths, this needs to handle the legacy bundle
// more gracefully than rejecting it outright. // more gracefully than rejecting it outright.
func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuerID]*x509.Certificate, error) { func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuing.IssuerID]*x509.Certificate, error) {
var err error var err error
var issuers []issuerID var issuers []issuing.IssuerID
if !sc.Backend.useLegacyBundleCaStorage() { if !sc.Backend.UseLegacyBundleCaStorage() {
issuers, err = sc.listIssuers() issuers, err = sc.listIssuers()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not fetch issuers list: %w", err) return nil, fmt.Errorf("could not fetch issuers list: %w", err)
} }
} else { } else {
// Hack: this isn't a real issuerID, but it works for fetchCAInfo // Hack: this isn't a real IssuerID, but it works for fetchCAInfo
// since it resolves the reference. // since it resolves the reference.
issuers = []issuerID{legacyBundleShimID} issuers = []issuing.IssuerID{legacyBundleShimID}
} }
issuerIDCertMap := make(map[issuerID]*x509.Certificate, len(issuers)) issuerIDCertMap := make(map[issuing.IssuerID]*x509.Certificate, len(issuers))
for _, issuer := range issuers { for _, issuer := range issuers {
_, bundle, caErr := sc.fetchCertBundleByIssuerId(issuer, false) _, bundle, caErr := sc.fetchCertBundleByIssuerId(issuer, false)
if caErr != nil { if caErr != nil {
@@ -954,8 +957,8 @@ func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuerID]*x509
// storage. // storage.
func tryRevokeCertBySerial(sc *storageContext, config *crlConfig, serial string) (*logical.Response, error) { func tryRevokeCertBySerial(sc *storageContext, config *crlConfig, serial string) (*logical.Response, error) {
// revokeCert requires us to hold these locks before calling it. // revokeCert requires us to hold these locks before calling it.
sc.Backend.revokeStorageLock.Lock() sc.Backend.GetRevokeStorageLock().Lock()
defer sc.Backend.revokeStorageLock.Unlock() defer sc.Backend.GetRevokeStorageLock().Unlock()
certEntry, err := fetchCertBySerial(sc, "certs/", serial) certEntry, err := fetchCertBySerial(sc, "certs/", serial)
if err != nil { if err != nil {
@@ -1052,12 +1055,13 @@ func revokeCert(sc *storageContext, config *crlConfig, cert *x509.Certificate) (
return nil, fmt.Errorf("error creating revocation entry: %w", err) return nil, fmt.Errorf("error creating revocation entry: %w", err)
} }
certsCounted := sc.Backend.certsCounted.Load() certCounter := sc.Backend.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = sc.Storage.Put(sc.Context, revEntry) err = sc.Storage.Put(sc.Context, revEntry)
if err != nil { if err != nil {
return nil, fmt.Errorf("error saving revoked certificate to new location: %w", err) return nil, fmt.Errorf("error saving revoked certificate to new location: %w", err)
} }
sc.Backend.ifCountEnabledIncrementTotalRevokedCertificatesCount(certsCounted, revEntry.Key) certCounter.IncrementTotalRevokedCertificatesCount(certsCounted, revEntry.Key)
// From here on out, the certificate has been revoked locally. Any other // From here on out, the certificate has been revoked locally. Any other
// persistence issues might still err, but any other failure messages // persistence issues might still err, but any other failure messages
@@ -1087,7 +1091,7 @@ func revokeCert(sc *storageContext, config *crlConfig, cert *x509.Certificate) (
// thread will reattempt it later on as we have the local write done. // thread will reattempt it later on as we have the local write done.
sc.Backend.Logger().Error("Failed to write unified revocation entry, will re-attempt later", sc.Backend.Logger().Error("Failed to write unified revocation entry, will re-attempt later",
"serial_number", colonSerial, "error", ignoreErr) "serial_number", colonSerial, "error", ignoreErr)
sc.Backend.unifiedTransferStatus.forceRun() sc.Backend.GetUnifiedTransferStatus().forceRun()
resp.AddWarning(fmt.Sprintf("Failed to write unified revocation entry, will re-attempt later: %v", err)) resp.AddWarning(fmt.Sprintf("Failed to write unified revocation entry, will re-attempt later: %v", err))
failedWritingUnifiedCRL = true failedWritingUnifiedCRL = true
@@ -1099,7 +1103,7 @@ func revokeCert(sc *storageContext, config *crlConfig, cert *x509.Certificate) (
// already rebuilt the full CRL so the Delta WAL will be cleared // already rebuilt the full CRL so the Delta WAL will be cleared
// afterwards. Writing an entry only to immediately remove it // afterwards. Writing an entry only to immediately remove it
// isn't necessary. // isn't necessary.
warnings, crlErr := sc.Backend.crlBuilder.rebuild(sc, false) warnings, crlErr := sc.Backend.CrlBuilder().rebuild(sc, false)
if crlErr != nil { if crlErr != nil {
switch crlErr.(type) { switch crlErr.(type) {
case errutil.UserError: case errutil.UserError:
@@ -1144,7 +1148,7 @@ func writeRevocationDeltaWALs(sc *storageContext, config *crlConfig, resp *logic
// thread will reattempt it later on as we have the local write done. // thread will reattempt it later on as we have the local write done.
sc.Backend.Logger().Error("Failed to write cross-cluster delta WAL entry, will re-attempt later", sc.Backend.Logger().Error("Failed to write cross-cluster delta WAL entry, will re-attempt later",
"serial_number", colonSerial, "error", ignoredErr) "serial_number", colonSerial, "error", ignoredErr)
sc.Backend.unifiedTransferStatus.forceRun() sc.Backend.GetUnifiedTransferStatus().forceRun()
resp.AddWarning(fmt.Sprintf("Failed to write cross-cluster delta WAL entry, will re-attempt later: %v", ignoredErr)) resp.AddWarning(fmt.Sprintf("Failed to write cross-cluster delta WAL entry, will re-attempt later: %v", ignoredErr))
} }
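
A pattern running through these hunks is that direct access to backend fields (revokeStorageLock, certsCounted, unifiedTransferStatus, crlBuilder) is replaced with accessor methods (GetRevokeStorageLock, GetCertificateCounter, GetUnifiedTransferStatus, CrlBuilder). A minimal, purely hypothetical sketch of the idea: a helper that only needs a narrow interface could later move into a sub-package without importing the concrete backend. None of the names below are the real Vault PKI types; they are stand-ins for illustration only.

package main

import (
	"fmt"
	"sync"
)

// Hypothetical, minimal stand-ins; not the real Vault PKI types.
type CertCounter interface {
	IsInitialized() bool
	IncrementTotalRevokedCertificatesCount(initialized bool, key string)
}

type RevocationBackend interface {
	GetRevokeStorageLock() *sync.RWMutex
	GetCertificateCounter() CertCounter
}

type noopCounter struct{}

func (noopCounter) IsInitialized() bool { return false }
func (noopCounter) IncrementTotalRevokedCertificatesCount(initialized bool, key string) {
	fmt.Printf("counted=%v key=%s\n", initialized, key)
}

type fakeBackend struct {
	lock    sync.RWMutex
	counter noopCounter
}

func (b *fakeBackend) GetRevokeStorageLock() *sync.RWMutex { return &b.lock }
func (b *fakeBackend) GetCertificateCounter() CertCounter { return b.counter }

// revokeSerial only depends on the narrow interface, so it could live in a sub-package.
func revokeSerial(b RevocationBackend, serial string) {
	b.GetRevokeStorageLock().Lock()
	defer b.GetRevokeStorageLock().Unlock()
	c := b.GetCertificateCounter()
	c.IncrementTotalRevokedCertificatesCount(c.IsInitialized(), "revoked/"+serial)
}

func main() {
	revokeSerial(&fakeBackend{}, "11-22-33")
}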
@@ -1235,12 +1239,12 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
// See the message in revokedCert about rebuilding CRLs: we need to // See the message in revokedCert about rebuilding CRLs: we need to
// gracefully handle revoking entries with the legacy cert bundle. // gracefully handle revoking entries with the legacy cert bundle.
var err error var err error
var issuers []issuerID var issuers []issuing.IssuerID
var wasLegacy bool var wasLegacy bool
// First, fetch an updated copy of the CRL config. We'll pass this into // First, fetch an updated copy of the CRL config. We'll pass this into buildCRL.
// buildCRL. crlBuilder := sc.Backend.CrlBuilder()
globalCRLConfig, err := sc.Backend.crlBuilder.getConfigWithUpdate(sc) globalCRLConfig, err := crlBuilder.getConfigWithUpdate(sc)
if err != nil { if err != nil {
return nil, fmt.Errorf("error building CRL: while updating config: %w", err) return nil, fmt.Errorf("error building CRL: while updating config: %w", err)
} }
@@ -1257,7 +1261,7 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
return nil, nil return nil, nil
} }
if !sc.Backend.useLegacyBundleCaStorage() { if !sc.Backend.UseLegacyBundleCaStorage() {
issuers, err = sc.listIssuers() issuers, err = sc.listIssuers()
if err != nil { if err != nil {
return nil, fmt.Errorf("error building CRL: while listing issuers: %w", err) return nil, fmt.Errorf("error building CRL: while listing issuers: %w", err)
@@ -1266,7 +1270,7 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
// Here, we hard-code the legacy issuer entry instead of using the // Here, we hard-code the legacy issuer entry instead of using the
// default ref. This is because we need to hack some of the logic // default ref. This is because we need to hack some of the logic
// below for revocation to handle the legacy bundle. // below for revocation to handle the legacy bundle.
issuers = []issuerID{legacyBundleShimID} issuers = []issuing.IssuerID{legacyBundleShimID}
wasLegacy = true wasLegacy = true
// Here, we avoid building a delta CRL with the legacy CRL bundle. // Here, we avoid building a delta CRL with the legacy CRL bundle.
@@ -1283,15 +1287,15 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
return nil, fmt.Errorf("error building CRLs: while getting the default config: %w", err) return nil, fmt.Errorf("error building CRLs: while getting the default config: %w", err)
} }
// We map issuerID->entry for fast lookup and also issuerID->Cert for // We map IssuerID->entry for fast lookup and also IssuerID->Cert for
// signature verification and correlation of revoked certs. // signature verification and correlation of revoked certs.
issuerIDEntryMap := make(map[issuerID]*issuerEntry, len(issuers)) issuerIDEntryMap := make(map[issuing.IssuerID]*issuing.IssuerEntry, len(issuers))
issuerIDCertMap := make(map[issuerID]*x509.Certificate, len(issuers)) issuerIDCertMap := make(map[issuing.IssuerID]*x509.Certificate, len(issuers))
// We use a double map (keyID->subject->issuerID) to store whether or not this // We use a double map (KeyID->subject->IssuerID) to store whether or not this
// key+subject paring has been seen before. We can then iterate over each // key+subject paring has been seen before. We can then iterate over each
// key/subject and choose any representative issuer for that combination. // key/subject and choose any representative issuer for that combination.
keySubjectIssuersMap := make(map[keyID]map[string][]issuerID) keySubjectIssuersMap := make(map[issuing.KeyID]map[string][]issuing.IssuerID)
for _, issuer := range issuers { for _, issuer := range issuers {
// We don't strictly need this call, but by requesting the bundle, the // We don't strictly need this call, but by requesting the bundle, the
// legacy path is automatically ignored. // legacy path is automatically ignored.
@@ -1328,7 +1332,7 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
subject := string(thisCert.RawSubject) subject := string(thisCert.RawSubject)
if _, ok := keySubjectIssuersMap[thisEntry.KeyID]; !ok { if _, ok := keySubjectIssuersMap[thisEntry.KeyID]; !ok {
keySubjectIssuersMap[thisEntry.KeyID] = make(map[string][]issuerID) keySubjectIssuersMap[thisEntry.KeyID] = make(map[string][]issuing.IssuerID)
} }
keySubjectIssuersMap[thisEntry.KeyID][subject] = append(keySubjectIssuersMap[thisEntry.KeyID][subject], issuer) keySubjectIssuersMap[thisEntry.KeyID][subject] = append(keySubjectIssuersMap[thisEntry.KeyID][subject], issuer)
@@ -1365,13 +1369,13 @@ func buildAnyCRLs(sc *storageContext, forceNew bool, isDelta bool) ([]string, er
if !isDelta { if !isDelta {
// After we've confirmed the primary CRLs have built OK, go ahead and // After we've confirmed the primary CRLs have built OK, go ahead and
// clear the delta CRL WAL and rebuild it. // clear the delta CRL WAL and rebuild it.
if err := sc.Backend.crlBuilder.clearLocalDeltaWAL(sc, currLocalDeltaSerials); err != nil { if err := crlBuilder.clearLocalDeltaWAL(sc, currLocalDeltaSerials); err != nil {
return nil, fmt.Errorf("error building CRLs: unable to clear Delta WAL: %w", err) return nil, fmt.Errorf("error building CRLs: unable to clear Delta WAL: %w", err)
} }
if err := sc.Backend.crlBuilder.clearUnifiedDeltaWAL(sc, currUnifiedDeltaSerials); err != nil { if err := crlBuilder.clearUnifiedDeltaWAL(sc, currUnifiedDeltaSerials); err != nil {
return nil, fmt.Errorf("error building CRLs: unable to clear Delta WAL: %w", err) return nil, fmt.Errorf("error building CRLs: unable to clear Delta WAL: %w", err)
} }
deltaWarnings, err := sc.Backend.crlBuilder.rebuildDeltaCRLsHoldingLock(sc, forceNew) deltaWarnings, err := crlBuilder.rebuildDeltaCRLsHoldingLock(sc, forceNew)
if err != nil { if err != nil {
return nil, fmt.Errorf("error building CRLs: unable to rebuild empty Delta WAL: %w", err) return nil, fmt.Errorf("error building CRLs: unable to rebuild empty Delta WAL: %w", err)
} }
@@ -1404,12 +1408,12 @@ func getLastWALSerial(sc *storageContext, path string) (string, error) {
func buildAnyLocalCRLs( func buildAnyLocalCRLs(
sc *storageContext, sc *storageContext,
issuersConfig *issuerConfigEntry, issuersConfig *issuing.IssuerConfigEntry,
globalCRLConfig *crlConfig, globalCRLConfig *crlConfig,
issuers []issuerID, issuers []issuing.IssuerID,
issuerIDEntryMap map[issuerID]*issuerEntry, issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIDCertMap map[issuerID]*x509.Certificate, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate,
keySubjectIssuersMap map[keyID]map[string][]issuerID, keySubjectIssuersMap map[issuing.KeyID]map[string][]issuing.IssuerID,
wasLegacy bool, wasLegacy bool,
forceNew bool, forceNew bool,
isDelta bool, isDelta bool,
@@ -1435,14 +1439,14 @@ func buildAnyLocalCRLs(
// visible now, should also be visible on the complete CRL we're writing. // visible now, should also be visible on the complete CRL we're writing.
var currDeltaCerts []string var currDeltaCerts []string
if !isDelta { if !isDelta {
currDeltaCerts, err = sc.Backend.crlBuilder.getPresentLocalDeltaWALForClearing(sc) currDeltaCerts, err = sc.Backend.CrlBuilder().getPresentLocalDeltaWALForClearing(sc)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error building CRLs: unable to get present delta WAL entries for removal: %w", err) return nil, nil, fmt.Errorf("error building CRLs: unable to get present delta WAL entries for removal: %w", err)
} }
} }
var unassignedCerts []pkix.RevokedCertificate var unassignedCerts []pkix.RevokedCertificate
var revokedCertsMap map[issuerID][]pkix.RevokedCertificate var revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate
// If the CRL is disabled do not bother reading in all the revoked certificates. // If the CRL is disabled do not bother reading in all the revoked certificates.
if !globalCRLConfig.Disable { if !globalCRLConfig.Disable {
@@ -1499,7 +1503,7 @@ func buildAnyLocalCRLs(
if isDelta { if isDelta {
// Update our last build time here so we avoid checking for new certs // Update our last build time here so we avoid checking for new certs
// for a while. // for a while.
sc.Backend.crlBuilder.lastDeltaRebuildCheck = time.Now() sc.Backend.CrlBuilder().lastDeltaRebuildCheck = time.Now()
if len(lastDeltaSerial) > 0 { if len(lastDeltaSerial) > 0 {
// When we have a last delta serial, write out the relevant info // When we have a last delta serial, write out the relevant info
@@ -1523,12 +1527,12 @@ func buildAnyLocalCRLs(
func buildAnyUnifiedCRLs( func buildAnyUnifiedCRLs(
sc *storageContext, sc *storageContext,
issuersConfig *issuerConfigEntry, issuersConfig *issuing.IssuerConfigEntry,
globalCRLConfig *crlConfig, globalCRLConfig *crlConfig,
issuers []issuerID, issuers []issuing.IssuerID,
issuerIDEntryMap map[issuerID]*issuerEntry, issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
issuerIDCertMap map[issuerID]*x509.Certificate, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate,
keySubjectIssuersMap map[keyID]map[string][]issuerID, keySubjectIssuersMap map[issuing.KeyID]map[string][]issuing.IssuerID,
wasLegacy bool, wasLegacy bool,
forceNew bool, forceNew bool,
isDelta bool, isDelta bool,
@@ -1578,14 +1582,14 @@ func buildAnyUnifiedCRLs(
// visible now, should also be visible on the complete CRL we're writing. // visible now, should also be visible on the complete CRL we're writing.
var currDeltaCerts []string var currDeltaCerts []string
if !isDelta { if !isDelta {
currDeltaCerts, err = sc.Backend.crlBuilder.getPresentUnifiedDeltaWALForClearing(sc) currDeltaCerts, err = sc.Backend.CrlBuilder().getPresentUnifiedDeltaWALForClearing(sc)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error building CRLs: unable to get present delta WAL entries for removal: %w", err) return nil, nil, fmt.Errorf("error building CRLs: unable to get present delta WAL entries for removal: %w", err)
} }
} }
var unassignedCerts []pkix.RevokedCertificate var unassignedCerts []pkix.RevokedCertificate
var revokedCertsMap map[issuerID][]pkix.RevokedCertificate var revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate
// If the CRL is disabled do not bother reading in all the revoked certificates. // If the CRL is disabled do not bother reading in all the revoked certificates.
if !globalCRLConfig.Disable { if !globalCRLConfig.Disable {
@@ -1642,7 +1646,7 @@ func buildAnyUnifiedCRLs(
if isDelta { if isDelta {
// Update our last build time here so we avoid checking for new certs // Update our last build time here so we avoid checking for new certs
// for a while. // for a while.
sc.Backend.crlBuilder.lastDeltaRebuildCheck = time.Now() sc.Backend.CrlBuilder().lastDeltaRebuildCheck = time.Now()
// Persist all of our known last revoked serial numbers here, as the // Persist all of our known last revoked serial numbers here, as the
// last seen serial during build. This will allow us to detect if any // last seen serial during build. This will allow us to detect if any
@@ -1674,20 +1678,20 @@ func buildAnyUnifiedCRLs(
func buildAnyCRLsWithCerts( func buildAnyCRLsWithCerts(
sc *storageContext, sc *storageContext,
issuersConfig *issuerConfigEntry, issuersConfig *issuing.IssuerConfigEntry,
globalCRLConfig *crlConfig, globalCRLConfig *crlConfig,
internalCRLConfig *internalCRLConfigEntry, internalCRLConfig *issuing.InternalCRLConfigEntry,
issuers []issuerID, issuers []issuing.IssuerID,
issuerIDEntryMap map[issuerID]*issuerEntry, issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry,
keySubjectIssuersMap map[keyID]map[string][]issuerID, keySubjectIssuersMap map[issuing.KeyID]map[string][]issuing.IssuerID,
unassignedCerts []pkix.RevokedCertificate, unassignedCerts []pkix.RevokedCertificate,
revokedCertsMap map[issuerID][]pkix.RevokedCertificate, revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate,
forceNew bool, forceNew bool,
isUnified bool, isUnified bool,
isDelta bool, isDelta bool,
) ([]string, error) { ) ([]string, error) {
// Now we can call buildCRL once, on an arbitrary/representative issuer // Now we can call buildCRL once, on an arbitrary/representative issuer
// from each of these (keyID, subject) sets. // from each of these (KeyID, subject) sets.
var warnings []string var warnings []string
for _, subjectIssuersMap := range keySubjectIssuersMap { for _, subjectIssuersMap := range keySubjectIssuersMap {
for _, issuersSet := range subjectIssuersMap { for _, issuersSet := range subjectIssuersMap {
@@ -1696,15 +1700,15 @@ func buildAnyCRLsWithCerts(
} }
var revokedCerts []pkix.RevokedCertificate var revokedCerts []pkix.RevokedCertificate
representative := issuerID("") representative := issuing.IssuerID("")
var crlIdentifier crlID var crlIdentifier issuing.CrlID
var crlIdIssuer issuerID var crlIdIssuer issuing.IssuerID
for _, issuerId := range issuersSet { for _, issuerId := range issuersSet {
// Skip entries which aren't enabled for CRL signing. We don't // Skip entries which aren't enabled for CRL signing. We don't
// particularly care which issuer is ultimately chosen as the // particularly care which issuer is ultimately chosen as the
// set representative for signing at this point, other than // set representative for signing at this point, other than
// that it has crl-signing usage. // that it has crl-signing usage.
if err := issuerIDEntryMap[issuerId].EnsureUsage(CRLSigningUsage); err != nil { if err := issuerIDEntryMap[issuerId].EnsureUsage(issuing.CRLSigningUsage); err != nil {
continue continue
} }
@@ -1724,7 +1728,7 @@ func buildAnyCRLsWithCerts(
// Otherwise, use any other random issuer if we've not yet // Otherwise, use any other random issuer if we've not yet
// chosen one. // chosen one.
if representative == issuerID("") { if representative == issuing.IssuerID("") {
representative = issuerId representative = issuerId
} }
@@ -1864,7 +1868,7 @@ func buildAnyCRLsWithCerts(
return warnings, nil return warnings, nil
} }
func isRevInfoIssuerValid(revInfo *revocationInfo, issuerIDCertMap map[issuerID]*x509.Certificate) bool { func isRevInfoIssuerValid(revInfo *revocationInfo, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate) bool {
if len(revInfo.CertificateIssuer) > 0 { if len(revInfo.CertificateIssuer) > 0 {
issuerId := revInfo.CertificateIssuer issuerId := revInfo.CertificateIssuer
if _, issuerExists := issuerIDCertMap[issuerId]; issuerExists { if _, issuerExists := issuerIDCertMap[issuerId]; issuerExists {
@@ -1875,7 +1879,7 @@ func isRevInfoIssuerValid(revInfo *revocationInfo, issuerIDCertMap map[issuerID]
return false return false
} }
func associateRevokedCertWithIsssuer(revInfo *revocationInfo, revokedCert *x509.Certificate, issuerIDCertMap map[issuerID]*x509.Certificate) bool { func associateRevokedCertWithIsssuer(revInfo *revocationInfo, revokedCert *x509.Certificate, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate) bool {
for issuerId, issuerCert := range issuerIDCertMap { for issuerId, issuerCert := range issuerIDCertMap {
if bytes.Equal(revokedCert.RawIssuer, issuerCert.RawSubject) { if bytes.Equal(revokedCert.RawIssuer, issuerCert.RawSubject) {
if err := revokedCert.CheckSignatureFrom(issuerCert); err == nil { if err := revokedCert.CheckSignatureFrom(issuerCert); err == nil {
@@ -1889,9 +1893,9 @@ func associateRevokedCertWithIsssuer(revInfo *revocationInfo, revokedCert *x509.
return false return false
} }
func getLocalRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuerID][]pkix.RevokedCertificate, error) { func getLocalRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuing.IssuerID][]pkix.RevokedCertificate, error) {
var unassignedCerts []pkix.RevokedCertificate var unassignedCerts []pkix.RevokedCertificate
revokedCertsMap := make(map[issuerID][]pkix.RevokedCertificate) revokedCertsMap := make(map[issuing.IssuerID][]pkix.RevokedCertificate)
listingPath := revokedPath listingPath := revokedPath
if isDelta { if isDelta {
@@ -2018,13 +2022,13 @@ func getLocalRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuerID
return unassignedCerts, revokedCertsMap, nil return unassignedCerts, revokedCertsMap, nil
} }
func getUnifiedRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuerID][]pkix.RevokedCertificate, error) { func getUnifiedRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate, isDelta bool) ([]pkix.RevokedCertificate, map[issuing.IssuerID][]pkix.RevokedCertificate, error) {
// Getting unified revocation entries is a bit different than getting // Getting unified revocation entries is a bit different than getting
// the local ones. In particular, the full copy of the certificate is // the local ones. In particular, the full copy of the certificate is
// unavailable, so we'll be able to avoid parsing the stored certificate, // unavailable, so we'll be able to avoid parsing the stored certificate,
// at the expense of potentially having incorrect issuer mappings. // at the expense of potentially having incorrect issuer mappings.
var unassignedCerts []pkix.RevokedCertificate var unassignedCerts []pkix.RevokedCertificate
revokedCertsMap := make(map[issuerID][]pkix.RevokedCertificate) revokedCertsMap := make(map[issuing.IssuerID][]pkix.RevokedCertificate)
listingPath := unifiedRevocationReadPathPrefix listingPath := unifiedRevocationReadPathPrefix
if isDelta { if isDelta {
@@ -2114,7 +2118,7 @@ func getUnifiedRevokedCertEntries(sc *storageContext, issuerIDCertMap map[issuer
return unassignedCerts, revokedCertsMap, nil return unassignedCerts, revokedCertsMap, nil
} }
func augmentWithRevokedIssuers(issuerIDEntryMap map[issuerID]*issuerEntry, issuerIDCertMap map[issuerID]*x509.Certificate, revokedCertsMap map[issuerID][]pkix.RevokedCertificate) error { func augmentWithRevokedIssuers(issuerIDEntryMap map[issuing.IssuerID]*issuing.IssuerEntry, issuerIDCertMap map[issuing.IssuerID]*x509.Certificate, revokedCertsMap map[issuing.IssuerID][]pkix.RevokedCertificate) error {
// When setup our maps with the legacy CA bundle, we only have a // When setup our maps with the legacy CA bundle, we only have a
// single entry here. This entry is never revoked, so the outer loop // single entry here. This entry is never revoked, so the outer loop
// will exit quickly. // will exit quickly.
@@ -2150,7 +2154,7 @@ func augmentWithRevokedIssuers(issuerIDEntryMap map[issuerID]*issuerEntry, issue
// Builds a CRL by going through the list of revoked certificates and building // Builds a CRL by going through the list of revoked certificates and building
// a new CRL with the stored revocation times and serial numbers. // a new CRL with the stored revocation times and serial numbers.
func buildCRL(sc *storageContext, crlInfo *crlConfig, forceNew bool, thisIssuerId issuerID, revoked []pkix.RevokedCertificate, identifier crlID, crlNumber int64, isUnified bool, isDelta bool, lastCompleteNumber int64) (*time.Time, error) { func buildCRL(sc *storageContext, crlInfo *crlConfig, forceNew bool, thisIssuerId issuing.IssuerID, revoked []pkix.RevokedCertificate, identifier issuing.CrlID, crlNumber int64, isUnified bool, isDelta bool, lastCompleteNumber int64) (*time.Time, error) {
var revokedCerts []pkix.RevokedCertificate var revokedCerts []pkix.RevokedCertificate
crlLifetime, err := parseutil.ParseDurationSecond(crlInfo.Expiry) crlLifetime, err := parseutil.ParseDurationSecond(crlInfo.Expiry)
@@ -2177,7 +2181,7 @@ func buildCRL(sc *storageContext, crlInfo *crlConfig, forceNew bool, thisIssuerI
revokedCerts = revoked revokedCerts = revoked
WRITE: WRITE:
signingBundle, caErr := sc.fetchCAInfoByIssuerId(thisIssuerId, CRLSigningUsage) signingBundle, caErr := sc.fetchCAInfoByIssuerId(thisIssuerId, issuing.CRLSigningUsage)
if caErr != nil { if caErr != nil {
switch caErr.(type) { switch caErr.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -6,6 +6,7 @@ package pki
import ( import (
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
) )
@@ -597,7 +598,7 @@ basic constraints.`,
func addSignVerbatimRoleFields(fields map[string]*framework.FieldSchema) map[string]*framework.FieldSchema { func addSignVerbatimRoleFields(fields map[string]*framework.FieldSchema) map[string]*framework.FieldSchema {
fields["key_usage"] = &framework.FieldSchema{ fields["key_usage"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice, Type: framework.TypeCommaStringSlice,
Default: []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"}, Default: issuing.DefaultRoleKeyUsages,
Description: `A comma-separated string or list of key usages (not extended Description: `A comma-separated string or list of key usages (not extended
key usages). Valid values can be found at key usages). Valid values can be found at
https://golang.org/pkg/crypto/x509/#KeyUsage https://golang.org/pkg/crypto/x509/#KeyUsage
@@ -608,7 +609,7 @@ this value to an empty list.`,
fields["ext_key_usage"] = &framework.FieldSchema{ fields["ext_key_usage"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice, Type: framework.TypeCommaStringSlice,
Default: []string{}, Default: issuing.DefaultRoleEstKeyUsages,
Description: `A comma-separated string or list of extended key usages. Valid values can be found at Description: `A comma-separated string or list of extended key usages. Valid values can be found at
https://golang.org/pkg/crypto/x509/#ExtKeyUsage https://golang.org/pkg/crypto/x509/#ExtKeyUsage
-- simply drop the "ExtKeyUsage" part of the name. -- simply drop the "ExtKeyUsage" part of the name.
@@ -618,24 +619,25 @@ this value to an empty list.`,
fields["ext_key_usage_oids"] = &framework.FieldSchema{ fields["ext_key_usage_oids"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice, Type: framework.TypeCommaStringSlice,
Default: issuing.DefaultRoleEstKeyUsageOids,
Description: `A comma-separated string or list of extended key usage oids.`, Description: `A comma-separated string or list of extended key usage oids.`,
} }
fields["signature_bits"] = &framework.FieldSchema{ fields["signature_bits"] = &framework.FieldSchema{
Type: framework.TypeInt, Type: framework.TypeInt,
Default: 0, Default: issuing.DefaultRoleSignatureBits,
Description: `The number of bits to use in the signature Description: `The number of bits to use in the signature
algorithm; accepts 256 for SHA-2-256, 384 for SHA-2-384, and 512 for algorithm; accepts 256 for SHA-2-256, 384 for SHA-2-384, and 512 for
SHA-2-512. Defaults to 0 to automatically detect based on key length SHA-2-512. Defaults to 0 to automatically detect based on key length
(SHA-2-256 for RSA keys, and matching the curve size for NIST P-Curves).`, (SHA-2-256 for RSA keys, and matching the curve size for NIST P-Curves).`,
DisplayAttrs: &framework.DisplayAttributes{ DisplayAttrs: &framework.DisplayAttributes{
Value: 0, Value: issuing.DefaultRoleSignatureBits,
}, },
} }
fields["use_pss"] = &framework.FieldSchema{ fields["use_pss"] = &framework.FieldSchema{
Type: framework.TypeBool, Type: framework.TypeBool,
Default: false, Default: issuing.DefaultRoleUsePss,
Description: `Whether or not to use PSS signatures when using a Description: `Whether or not to use PSS signatures when using a
RSA key-type issuer. Defaults to false.`, RSA key-type issuer. Defaults to false.`,
} }
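
The hard-coded defaults above move into shared constants in the new issuing package. Judging only from the literals they replace in this hunk, those definitions presumably look roughly like the sketch below; the exact values (especially DefaultRoleEstKeyUsageOids, which had no prior literal) are an assumption, not taken from the diff.

// Assumed shape of the shared defaults in builtin/logical/pki/issuing,
// inferred from the literals replaced in this hunk.
package issuing

var (
	DefaultRoleKeyUsages       = []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"}
	DefaultRoleEstKeyUsages    = []string{}
	DefaultRoleEstKeyUsageOids = []string{} // assumed; no prior literal shown
)

const (
	DefaultRoleSignatureBits = 0
	DefaultRoleUsePss        = false
)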

View File

@@ -15,6 +15,7 @@ import (
"testing" "testing"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
vaultocsp "github.com/hashicorp/vault/sdk/helper/ocsp" vaultocsp "github.com/hashicorp/vault/sdk/helper/ocsp"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema" "github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
@@ -41,7 +42,7 @@ func TestIntegration_RotateRootUsesNext(t *testing.T) {
require.NotNil(t, resp, "got nil response from rotate root") require.NotNil(t, resp, "got nil response from rotate root")
require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp) require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp)
issuerId1 := resp.Data["issuer_id"].(issuerID) issuerId1 := resp.Data["issuer_id"].(issuing.IssuerID)
issuerName1 := resp.Data["issuer_name"] issuerName1 := resp.Data["issuer_name"]
require.NotEmpty(t, issuerId1, "issuer id was empty on initial rotate root command") require.NotEmpty(t, issuerId1, "issuer id was empty on initial rotate root command")
@@ -61,7 +62,7 @@ func TestIntegration_RotateRootUsesNext(t *testing.T) {
require.NotNil(t, resp, "got nil response from rotate root") require.NotNil(t, resp, "got nil response from rotate root")
require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp) require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp)
issuerId2 := resp.Data["issuer_id"].(issuerID) issuerId2 := resp.Data["issuer_id"].(issuing.IssuerID)
issuerName2 := resp.Data["issuer_name"] issuerName2 := resp.Data["issuer_name"]
require.NotEmpty(t, issuerId2, "issuer id was empty on second rotate root command") require.NotEmpty(t, issuerId2, "issuer id was empty on second rotate root command")
@@ -83,7 +84,7 @@ func TestIntegration_RotateRootUsesNext(t *testing.T) {
require.NotNil(t, resp, "got nil response from rotate root") require.NotNil(t, resp, "got nil response from rotate root")
require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp) require.False(t, resp.IsError(), "got an error from rotate root: %#v", resp)
issuerId3 := resp.Data["issuer_id"].(issuerID) issuerId3 := resp.Data["issuer_id"].(issuing.IssuerID)
issuerName3 := resp.Data["issuer_name"] issuerName3 := resp.Data["issuer_name"]
require.NotEmpty(t, issuerId3, "issuer id was empty on third rotate root command") require.NotEmpty(t, issuerId3, "issuer id was empty on third rotate root command")
@@ -436,7 +437,7 @@ func TestIntegration_AutoIssuer(t *testing.T) {
"pem_bundle": certOne, "pem_bundle": certOne,
}) })
requireSuccessNonNilResponse(t, resp, err) requireSuccessNonNilResponse(t, resp, err)
issuerIdOneReimported := issuerID(resp.Data["imported_issuers"].([]string)[0]) issuerIdOneReimported := issuing.IssuerID(resp.Data["imported_issuers"].([]string)[0])
resp, err = CBRead(b, s, "config/issuers") resp, err = CBRead(b, s, "config/issuers")
requireSuccessNonNilResponse(t, resp, err) requireSuccessNonNilResponse(t, resp, err)
@@ -643,11 +644,11 @@ func TestIntegrationOCSPClientWithPKI(t *testing.T) {
} }
} }
func genTestRootCa(t *testing.T, b *backend, s logical.Storage) (issuerID, keyID) { func genTestRootCa(t *testing.T, b *backend, s logical.Storage) (issuing.IssuerID, issuing.KeyID) {
return genTestRootCaWithIssuerName(t, b, s, "") return genTestRootCaWithIssuerName(t, b, s, "")
} }
func genTestRootCaWithIssuerName(t *testing.T, b *backend, s logical.Storage, issuerName string) (issuerID, keyID) { func genTestRootCaWithIssuerName(t *testing.T, b *backend, s logical.Storage, issuerName string) (issuing.IssuerID, issuing.KeyID) {
data := map[string]interface{}{ data := map[string]interface{}{
"common_name": "test.com", "common_name": "test.com",
} }
@@ -665,8 +666,8 @@ func genTestRootCaWithIssuerName(t *testing.T, b *backend, s logical.Storage, is
require.NotNil(t, resp, "got nil response from generating root ca") require.NotNil(t, resp, "got nil response from generating root ca")
require.False(t, resp.IsError(), "got an error from generating root ca: %#v", resp) require.False(t, resp.IsError(), "got an error from generating root ca: %#v", resp)
issuerId := resp.Data["issuer_id"].(issuerID) issuerId := resp.Data["issuer_id"].(issuing.IssuerID)
keyId := resp.Data["key_id"].(keyID) keyId := resp.Data["key_id"].(issuing.KeyID)
require.NotEmpty(t, issuerId, "returned issuer id was empty") require.NotEmpty(t, issuerId, "returned issuer id was empty")
require.NotEmpty(t, keyId, "returned key id was empty") require.NotEmpty(t, keyId, "returned key id was empty")

View File

@@ -0,0 +1,154 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"strings"
"github.com/asaskevich/govalidator"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
const ClusterConfigPath = "config/cluster"
type AiaConfigEntry struct {
IssuingCertificates []string `json:"issuing_certificates"`
CRLDistributionPoints []string `json:"crl_distribution_points"`
OCSPServers []string `json:"ocsp_servers"`
EnableTemplating bool `json:"enable_templating"`
}
type ClusterConfigEntry struct {
Path string `json:"path"`
AIAPath string `json:"aia_path"`
}
func GetAIAURLs(ctx context.Context, s logical.Storage, i *IssuerEntry) (*certutil.URLEntries, error) {
// Default to the per-issuer AIA URLs.
entries := i.AIAURIs
// If none are set (either due to a nil entry or because no URLs have
// been provided), fall back to the global AIA URL config.
if entries == nil || (len(entries.IssuingCertificates) == 0 && len(entries.CRLDistributionPoints) == 0 && len(entries.OCSPServers) == 0) {
var err error
entries, err = GetGlobalAIAURLs(ctx, s)
if err != nil {
return nil, err
}
}
if entries == nil {
return &certutil.URLEntries{}, nil
}
return ToURLEntries(ctx, s, i.ID, entries)
}
func GetGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*AiaConfigEntry, error) {
entry, err := storage.Get(ctx, "urls")
if err != nil {
return nil, err
}
entries := &AiaConfigEntry{
IssuingCertificates: []string{},
CRLDistributionPoints: []string{},
OCSPServers: []string{},
EnableTemplating: false,
}
if entry == nil {
return entries, nil
}
if err := entry.DecodeJSON(entries); err != nil {
return nil, err
}
return entries, nil
}
func ToURLEntries(ctx context.Context, s logical.Storage, issuer IssuerID, c *AiaConfigEntry) (*certutil.URLEntries, error) {
if len(c.IssuingCertificates) == 0 && len(c.CRLDistributionPoints) == 0 && len(c.OCSPServers) == 0 {
return &certutil.URLEntries{}, nil
}
result := certutil.URLEntries{
IssuingCertificates: c.IssuingCertificates[:],
CRLDistributionPoints: c.CRLDistributionPoints[:],
OCSPServers: c.OCSPServers[:],
}
if c.EnableTemplating {
cfg, err := GetClusterConfig(ctx, s)
if err != nil {
return nil, fmt.Errorf("error fetching cluster-local address config: %w", err)
}
for name, source := range map[string]*[]string{
"issuing_certificates": &result.IssuingCertificates,
"crl_distribution_points": &result.CRLDistributionPoints,
"ocsp_servers": &result.OCSPServers,
} {
templated := make([]string, len(*source))
for index, uri := range *source {
if strings.Contains(uri, "{{cluster_path}}") && len(cfg.Path) == 0 {
return nil, fmt.Errorf("unable to template AIA URLs as we lack local cluster address information (path)")
}
if strings.Contains(uri, "{{cluster_aia_path}}") && len(cfg.AIAPath) == 0 {
return nil, fmt.Errorf("unable to template AIA URLs as we lack local cluster address information (aia_path)")
}
if strings.Contains(uri, "{{issuer_id}}") && len(issuer) == 0 {
// Elide issuer AIA info as we lack an issuer_id.
return nil, fmt.Errorf("unable to template AIA URLs as we lack an issuer_id for this operation")
}
uri = strings.ReplaceAll(uri, "{{cluster_path}}", cfg.Path)
uri = strings.ReplaceAll(uri, "{{cluster_aia_path}}", cfg.AIAPath)
uri = strings.ReplaceAll(uri, "{{issuer_id}}", issuer.String())
templated[index] = uri
}
if uri := ValidateURLs(templated); uri != "" {
return nil, fmt.Errorf("error validating templated %v; invalid URI: %v", name, uri)
}
*source = templated
}
}
return &result, nil
}
func GetClusterConfig(ctx context.Context, s logical.Storage) (*ClusterConfigEntry, error) {
entry, err := s.Get(ctx, ClusterConfigPath)
if err != nil {
return nil, err
}
var result ClusterConfigEntry
if entry == nil {
return &result, nil
}
if err = entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func ValidateURLs(urls []string) string {
for _, curr := range urls {
if !govalidator.IsURL(curr) || strings.Contains(curr, "{{issuer_id}}") || strings.Contains(curr, "{{cluster_path}}") || strings.Contains(curr, "{{cluster_aia_path}}") {
return curr
}
}
return ""
}
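
A minimal sketch of exercising the templating path in ToURLEntries directly, assuming an empty logical.InmemStorage (so no cluster path is configured and only {{issuer_id}} gets substituted); the hostname and issuer ID are illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	ctx := context.Background()
	storage := &logical.InmemStorage{}

	cfg := &issuing.AiaConfigEntry{
		EnableTemplating:    true,
		IssuingCertificates: []string{"https://vault.example.com/v1/pki/issuer/{{issuer_id}}/der"},
	}

	urls, err := issuing.ToURLEntries(ctx, storage, issuing.IssuerID("a0b1c2d3"), cfg)
	if err != nil {
		panic(err)
	}
	// {{issuer_id}} is replaced with the concrete issuer ID before validation.
	fmt.Println(urls.IssuingCertificates[0])
}

Because the cluster config is empty in this sketch, a URL containing {{cluster_path}} or {{cluster_aia_path}} would instead fail with the "lack local cluster address information" error above.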

View File

@@ -0,0 +1,124 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const StorageIssuerConfig = "config/issuers"
type IssuerConfigEntry struct {
// This new fetchedDefault field allows us to detect if the default
// issuer was modified, in turn dispatching the timestamp updater
// if necessary.
fetchedDefault IssuerID `json:"-"`
DefaultIssuerId IssuerID `json:"default"`
DefaultFollowsLatestIssuer bool `json:"default_follows_latest_issuer"`
}
func GetIssuersConfig(ctx context.Context, s logical.Storage) (*IssuerConfigEntry, error) {
entry, err := s.Get(ctx, StorageIssuerConfig)
if err != nil {
return nil, err
}
issuerConfig := &IssuerConfigEntry{}
if entry != nil {
if err := entry.DecodeJSON(issuerConfig); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode issuer configuration: %v", err)}
}
}
issuerConfig.fetchedDefault = issuerConfig.DefaultIssuerId
return issuerConfig, nil
}
func SetIssuersConfig(ctx context.Context, s logical.Storage, config *IssuerConfigEntry) error {
json, err := logical.StorageEntryJSON(StorageIssuerConfig, config)
if err != nil {
return err
}
if err := s.Put(ctx, json); err != nil {
return err
}
if err := changeDefaultIssuerTimestamps(ctx, s, config.fetchedDefault, config.DefaultIssuerId); err != nil {
return err
}
return nil
}
func changeDefaultIssuerTimestamps(ctx context.Context, s logical.Storage, oldDefault IssuerID, newDefault IssuerID) error {
if newDefault == oldDefault {
return nil
}
now := time.Now().UTC()
// When the default issuer changes, we need to modify four
// pieces of information:
//
// 1. The old default issuer's modification time, as it no
// longer works for the /cert/ca path.
// 2. The new default issuer's modification time, as it now
// works for the /cert/ca path.
// 3. & 4. Both issuer's CRLs, as they behave the same, under
// the /cert/crl path!
for _, thisId := range []IssuerID{oldDefault, newDefault} {
if len(thisId) == 0 {
continue
}
// 1 & 2 above.
issuer, err := FetchIssuerById(ctx, s, thisId)
if err != nil {
// Due to the lack of transactions, if we deleted the default
// issuer (successfully), but the subsequent issuer config write
// (to clear the default issuer's old id) failed, we might have
// an inconsistent config. If we later hit this loop (and flush
// these timestamps again -- perhaps because the operator
// selected a new default), we'd have erred out here, because
// the since-deleted default issuer doesn't exist. In this case,
// skip the issuer instead of bailing.
err := fmt.Errorf("unable to update issuer (%v)'s modification time: error fetching issuer: %w", thisId, err)
if strings.Contains(err.Error(), "does not exist") {
hclog.L().Warn(err.Error())
continue
}
return err
}
issuer.LastModified = now
err = WriteIssuer(ctx, s, issuer)
if err != nil {
return fmt.Errorf("unable to update issuer (%v)'s modification time: error persisting issuer: %w", thisId, err)
}
}
// Fetch and update the internalCRLConfigEntry (3&4).
cfg, err := GetLocalCRLConfig(ctx, s)
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error fetching local CRL config: %w", err)
}
cfg.LastModified = now
cfg.DeltaLastModified = now
err = SetLocalCRLConfig(ctx, s, cfg)
if err != nil {
return fmt.Errorf("unable to update local CRL config's modification time: error persisting local CRL config: %w", err)
}
return nil
}
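
GetIssuersConfig tolerates a missing config/issuers entry and records the fetched default so that SetIssuersConfig can tell whether the default changed and refresh the affected timestamps. A minimal read-side sketch, assuming an empty logical.InmemStorage:

package main

import (
	"context"
	"fmt"

	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	ctx := context.Background()
	storage := &logical.InmemStorage{}

	// With no config/issuers entry yet, we get a zero-value config back
	// rather than an error.
	cfg, err := issuing.GetIssuersConfig(ctx, storage)
	if err != nil {
		panic(err)
	}
	fmt.Printf("default issuer: %q, follows latest: %v\n",
		cfg.DefaultIssuerId, cfg.DefaultFollowsLatestIssuer)
}

Writing back a config whose DefaultIssuerId differs from the fetched one would additionally touch both issuers' LastModified fields and the local CRL config's timestamps, per changeDefaultIssuerTimestamps above.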

View File

@@ -0,0 +1,45 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
StorageKeyConfig = "config/keys"
)
type KeyConfigEntry struct {
DefaultKeyId KeyID `json:"default"`
}
func SetKeysConfig(ctx context.Context, s logical.Storage, config *KeyConfigEntry) error {
json, err := logical.StorageEntryJSON(StorageKeyConfig, config)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func GetKeysConfig(ctx context.Context, s logical.Storage) (*KeyConfigEntry, error) {
entry, err := s.Get(ctx, StorageKeyConfig)
if err != nil {
return nil, err
}
keyConfig := &KeyConfigEntry{}
if entry != nil {
if err := entry.DecodeJSON(keyConfig); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode key configuration: %v", err)}
}
}
return keyConfig, nil
}
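
The key config helpers are a thin get/set pair over a single storage entry. A minimal round-trip sketch, assuming logical.InmemStorage; the key ID is illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	ctx := context.Background()
	storage := &logical.InmemStorage{}

	// Persist a default key id, then read it back.
	if err := issuing.SetKeysConfig(ctx, storage, &issuing.KeyConfigEntry{
		DefaultKeyId: issuing.KeyID("e4f5a6b7"),
	}); err != nil {
		panic(err)
	}

	cfg, err := issuing.GetKeysConfig(ctx, storage)
	if err != nil {
		panic(err)
	}
	fmt.Println("default key:", cfg.DefaultKeyId)
}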

View File

@@ -0,0 +1,190 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
StorageLocalCRLConfig = "crls/config"
StorageUnifiedCRLConfig = "unified-crls/config"
)
type InternalCRLConfigEntry struct {
IssuerIDCRLMap map[IssuerID]CrlID `json:"issuer_id_crl_map"`
CRLNumberMap map[CrlID]int64 `json:"crl_number_map"`
LastCompleteNumberMap map[CrlID]int64 `json:"last_complete_number_map"`
CRLExpirationMap map[CrlID]time.Time `json:"crl_expiration_map"`
LastModified time.Time `json:"last_modified"`
DeltaLastModified time.Time `json:"delta_last_modified"`
UseGlobalQueue bool `json:"cross_cluster_revocation"`
}
type CrlID string
func (p CrlID) String() string {
return string(p)
}
func GetLocalCRLConfig(ctx context.Context, s logical.Storage) (*InternalCRLConfigEntry, error) {
return _getInternalCRLConfig(ctx, s, StorageLocalCRLConfig)
}
func GetUnifiedCRLConfig(ctx context.Context, s logical.Storage) (*InternalCRLConfigEntry, error) {
return _getInternalCRLConfig(ctx, s, StorageUnifiedCRLConfig)
}
func _getInternalCRLConfig(ctx context.Context, s logical.Storage, path string) (*InternalCRLConfigEntry, error) {
entry, err := s.Get(ctx, path)
if err != nil {
return nil, err
}
mapping := &InternalCRLConfigEntry{}
if entry != nil {
if err := entry.DecodeJSON(mapping); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode cluster-local CRL configuration: %v", err)}
}
}
if len(mapping.IssuerIDCRLMap) == 0 {
mapping.IssuerIDCRLMap = make(map[IssuerID]CrlID)
}
if len(mapping.CRLNumberMap) == 0 {
mapping.CRLNumberMap = make(map[CrlID]int64)
}
if len(mapping.LastCompleteNumberMap) == 0 {
mapping.LastCompleteNumberMap = make(map[CrlID]int64)
// Since this might not exist on migration, we want to guess what
// the last full CRL number was. This was likely the last
// value from CRLNumberMap if it existed, since we're just adding
// the mapping here in this block.
//
// After the next full CRL build, we will have set this value
// correctly, so it doesn't really matter in the long term if
// we're off here.
for id, number := range mapping.CRLNumberMap {
// Decrement by one, since CRLNumberMap is the future number,
// not the last built number.
mapping.LastCompleteNumberMap[id] = number - 1
}
}
if len(mapping.CRLExpirationMap) == 0 {
mapping.CRLExpirationMap = make(map[CrlID]time.Time)
}
return mapping, nil
}
func SetLocalCRLConfig(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry) error {
return _setInternalCRLConfig(ctx, s, mapping, StorageLocalCRLConfig)
}
func SetUnifiedCRLConfig(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry) error {
return _setInternalCRLConfig(ctx, s, mapping, StorageUnifiedCRLConfig)
}
func _setInternalCRLConfig(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry, path string) error {
if err := _cleanupInternalCRLMapping(ctx, s, mapping, path); err != nil {
return fmt.Errorf("failed to clean up internal CRL mapping: %w", err)
}
json, err := logical.StorageEntryJSON(path, mapping)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func _cleanupInternalCRLMapping(ctx context.Context, s logical.Storage, mapping *InternalCRLConfigEntry, path string) error {
// Track which CRL IDs are presently referred to by issuers; any other CRL
// IDs are subject to cleanup.
//
// Unused IDs both need to be removed from this map (cleaning up the size
// of this storage entry) but also the full CRLs removed from disk.
presentMap := make(map[CrlID]bool)
for _, id := range mapping.IssuerIDCRLMap {
presentMap[id] = true
}
// Identify which CRL IDs exist and are candidates for removal;
// theoretically these three maps should be in sync, but were added
// at different times.
toRemove := make(map[CrlID]bool)
for id := range mapping.CRLNumberMap {
if !presentMap[id] {
toRemove[id] = true
}
}
for id := range mapping.LastCompleteNumberMap {
if !presentMap[id] {
toRemove[id] = true
}
}
for id := range mapping.CRLExpirationMap {
if !presentMap[id] {
toRemove[id] = true
}
}
// Depending on which path we're writing this config to, we need to
// remove CRLs from the relevant folder too.
isLocal := path == StorageLocalCRLConfig
baseCRLPath := "crls/"
if !isLocal {
baseCRLPath = "unified-crls/"
}
for id := range toRemove {
// Clean up space in this mapping...
delete(mapping.CRLNumberMap, id)
delete(mapping.LastCompleteNumberMap, id)
delete(mapping.CRLExpirationMap, id)
// And clean up space on disk from the fat CRL mapping.
crlPath := baseCRLPath + string(id)
deltaCRLPath := crlPath + "-delta"
if err := s.Delete(ctx, crlPath); err != nil {
return fmt.Errorf("failed to delete unreferenced CRL %v: %w", id, err)
}
if err := s.Delete(ctx, deltaCRLPath); err != nil {
return fmt.Errorf("failed to delete unreferenced delta CRL %v: %w", id, err)
}
}
// Lastly, some CRLs could've been partially removed from the map but
// not from disk. Check to see if we have any dangling CRLs and remove
// them too.
list, err := s.List(ctx, baseCRLPath)
if err != nil {
return fmt.Errorf("failed listing all CRLs: %w", err)
}
for _, crl := range list {
if crl == "config" || strings.HasSuffix(crl, "/") {
continue
}
if presentMap[CrlID(crl)] {
continue
}
if err := s.Delete(ctx, baseCRLPath+"/"+crl); err != nil {
return fmt.Errorf("failed cleaning up orphaned CRL %v: %w", crl, err)
}
}
return nil
}
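
GetLocalCRLConfig always returns fully initialized maps, and SetLocalCRLConfig prunes CRL IDs (and their stored CRLs) that no issuer references any more. A minimal round-trip sketch, assuming logical.InmemStorage; the issuer and CRL IDs are illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/hashicorp/vault/builtin/logical/pki/issuing"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	ctx := context.Background()
	storage := &logical.InmemStorage{}

	cfg, err := issuing.GetLocalCRLConfig(ctx, storage)
	if err != nil {
		panic(err)
	}

	// Point a (hypothetical) issuer at a CRL identifier and persist;
	// any CRL IDs left unreferenced in the other maps would be cleaned up here.
	crl := issuing.CrlID("c1d2e3f4")
	cfg.IssuerIDCRLMap[issuing.IssuerID("a0b1c2d3")] = crl
	cfg.CRLNumberMap[crl] = 1

	if err := issuing.SetLocalCRLConfig(ctx, storage, cfg); err != nil {
		panic(err)
	}
	fmt.Println("next CRL number:", cfg.CRLNumberMap[crl])
}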

File diff suppressed because it is too large

View File

@@ -0,0 +1,495 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"crypto/x509"
"fmt"
"sort"
"strings"
"time"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
"github.com/hashicorp/vault/builtin/logical/pki/parsing"
)
const (
ReadOnlyUsage IssuerUsage = iota
IssuanceUsage IssuerUsage = 1 << iota
CRLSigningUsage IssuerUsage = 1 << iota
OCSPSigningUsage IssuerUsage = 1 << iota
)
const (
// When adding a new usage in the future, we'll need to create a usage
// mask field on the IssuerEntry and handle migrations to a newer mask,
// inferring a value for the new bits.
AllIssuerUsages = ReadOnlyUsage | IssuanceUsage | CRLSigningUsage | OCSPSigningUsage
DefaultRef = "default"
IssuerPrefix = "config/issuer/"
// Used as a quick sanity check for reference id lookups...
uuidLength = 36
IssuerRefNotFound = IssuerID("not-found")
LatestIssuerVersion = 1
LegacyCertBundlePath = "config/ca_bundle"
LegacyBundleShimID = IssuerID("legacy-entry-shim-id")
LegacyBundleShimKeyID = KeyID("legacy-entry-shim-key-id")
)
type IssuerID string
func (p IssuerID) String() string {
return string(p)
}
type IssuerUsage uint
var namedIssuerUsages = map[string]IssuerUsage{
"read-only": ReadOnlyUsage,
"issuing-certificates": IssuanceUsage,
"crl-signing": CRLSigningUsage,
"ocsp-signing": OCSPSigningUsage,
}
func (i *IssuerUsage) ToggleUsage(usages ...IssuerUsage) {
for _, usage := range usages {
*i ^= usage
}
}
func (i IssuerUsage) HasUsage(usage IssuerUsage) bool {
return (i & usage) == usage
}
func (i IssuerUsage) Names() string {
var names []string
var builtUsage IssuerUsage
// Return the known set of usages in sorted order so that Terraform state files don't flip,
// reporting values as different when it's the same list in a different order.
keys := make([]string, 0, len(namedIssuerUsages))
for k := range namedIssuerUsages {
keys = append(keys, k)
}
sort.Strings(keys)
for _, name := range keys {
usage := namedIssuerUsages[name]
if i.HasUsage(usage) {
names = append(names, name)
builtUsage.ToggleUsage(usage)
}
}
if i != builtUsage {
// Found some unknown usage, we should indicate this in the names.
names = append(names, fmt.Sprintf("unknown:%v", i^builtUsage))
}
return strings.Join(names, ",")
}
func NewIssuerUsageFromNames(names []string) (IssuerUsage, error) {
var result IssuerUsage
for index, name := range names {
usage, ok := namedIssuerUsages[name]
if !ok {
return ReadOnlyUsage, fmt.Errorf("unknown name for usage at index %v: %v", index, name)
}
result.ToggleUsage(usage)
}
return result, nil
}
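// Illustrative sketch (not part of the original change): how the bitmask
// helpers above compose. The function name is hypothetical and only relies on
// the exported identifiers defined in this file.
func exampleIssuerUsageRoundTrip() error {
	usage, err := NewIssuerUsageFromNames([]string{"issuing-certificates", "crl-signing"})
	if err != nil {
		return err
	}
	if !usage.HasUsage(CRLSigningUsage) {
		return fmt.Errorf("expected crl-signing to be set, got %v", usage.Names())
	}
	// ToggleUsage flips bits: it clears a set usage and sets an unset one.
	usage.ToggleUsage(CRLSigningUsage)
	if usage.HasUsage(CRLSigningUsage) {
		return fmt.Errorf("expected crl-signing to be cleared, got %v", usage.Names())
	}
	return nil
}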
type IssuerEntry struct {
ID IssuerID `json:"id"`
Name string `json:"name"`
KeyID KeyID `json:"key_id"`
Certificate string `json:"certificate"`
CAChain []string `json:"ca_chain"`
ManualChain []IssuerID `json:"manual_chain"`
SerialNumber string `json:"serial_number"`
LeafNotAfterBehavior certutil.NotAfterBehavior `json:"not_after_behavior"`
Usage IssuerUsage `json:"usage"`
RevocationSigAlg x509.SignatureAlgorithm `json:"revocation_signature_algorithm"`
Revoked bool `json:"revoked"`
RevocationTime int64 `json:"revocation_time"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"`
AIAURIs *AiaConfigEntry `json:"aia_uris,omitempty"`
LastModified time.Time `json:"last_modified"`
Version uint `json:"version"`
}
func (i IssuerEntry) GetCertificate() (*x509.Certificate, error) {
cert, err := parsing.ParseCertificateFromBytes([]byte(i.Certificate))
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse certificate from issuer: %s: %v", err.Error(), i.ID)}
}
return cert, nil
}
func (i IssuerEntry) EnsureUsage(usage IssuerUsage) error {
// We want to spit out a nice error message about missing usages.
if i.Usage.HasUsage(usage) {
return nil
}
issuerRef := fmt.Sprintf("id:%v", i.ID)
if len(i.Name) > 0 {
issuerRef = fmt.Sprintf("%v / name:%v", issuerRef, i.Name)
}
// These usages differ at some point in time. We've gotta find the first
// usage that differs and return a logical-sounding error message around
// that difference.
for name, candidate := range namedIssuerUsages {
if usage.HasUsage(candidate) && !i.Usage.HasUsage(candidate) {
return fmt.Errorf("requested usage %v for issuer [%v] but only had usage %v", name, issuerRef, i.Usage.Names())
}
}
// Maybe we have an unnamed usage that's requested.
return fmt.Errorf("unknown delta between usages: %v -> %v / for issuer [%v]", usage.Names(), i.Usage.Names(), issuerRef)
}
func (i IssuerEntry) CanMaybeSignWithAlgo(algo x509.SignatureAlgorithm) error {
// Hack: Go isn't kind enough to expose its lovely signatureAlgorithmDetails
// informational struct for our usage. However, we don't want to actually
// fetch the private key and attempt a signature with this algo (as we'll
// mint new, previously unsigned material in the process that could maybe
// be potentially abused if it leaks).
//
// So...
//
// ...we maintain our own mapping of cert.PKI<->sigAlgos. Notably, we
// exclude DSA support as the PKI engine has never supported DSA keys.
if algo == x509.UnknownSignatureAlgorithm {
// Special-cased to indicate an upgrade, letting Go automatically
// choose the correct value.
return nil
}
cert, err := i.GetCertificate()
if err != nil {
return fmt.Errorf("unable to parse issuer's potential signature algorithm types: %w", err)
}
switch cert.PublicKeyAlgorithm {
case x509.RSA:
switch algo {
case x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA,
x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS,
x509.SHA512WithRSAPSS:
return nil
}
case x509.ECDSA:
switch algo {
case x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return nil
}
case x509.Ed25519:
switch algo {
case x509.PureEd25519:
return nil
}
}
return fmt.Errorf("unable to use issuer of type %v to sign with %v key type", cert.PublicKeyAlgorithm.String(), algo.String())
}
func ResolveIssuerReference(ctx context.Context, s logical.Storage, reference string) (IssuerID, error) {
if reference == DefaultRef {
// Handle fetching the default issuer.
config, err := GetIssuersConfig(ctx, s)
if err != nil {
return IssuerID("config-error"), err
}
if len(config.DefaultIssuerId) == 0 {
return IssuerRefNotFound, fmt.Errorf("no default issuer currently configured")
}
return config.DefaultIssuerId, nil
}
// Look up by a direct get first to see if our reference is an ID; this is quick and cached.
if len(reference) == uuidLength {
entry, err := s.Get(ctx, IssuerPrefix+reference)
if err != nil {
return IssuerID("issuer-read"), err
}
if entry != nil {
return IssuerID(reference), nil
}
}
// Otherwise, fall back to pulling all issuers from storage.
issuers, err := ListIssuers(ctx, s)
if err != nil {
return IssuerID("list-error"), err
}
for _, issuerId := range issuers {
issuer, err := FetchIssuerById(ctx, s, issuerId)
if err != nil {
return IssuerID("issuer-read"), err
}
if issuer.Name == reference {
return issuer.ID, nil
}
}
// Otherwise, we must not have found the issuer.
return IssuerRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI issuer for reference: %v", reference)}
}
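// Illustrative sketch (not part of the original change): resolving a
// user-supplied reference (a name, a UUID, or "default") down to a concrete
// issuer entry. The function name is hypothetical.
func exampleFetchIssuerByReference(ctx context.Context, s logical.Storage, reference string) (*IssuerEntry, error) {
	id, err := ResolveIssuerReference(ctx, s, reference)
	if err != nil {
		return nil, err
	}
	return FetchIssuerById(ctx, s, id)
}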
func ListIssuers(ctx context.Context, s logical.Storage) ([]IssuerID, error) {
strList, err := s.List(ctx, IssuerPrefix)
if err != nil {
return nil, err
}
issuerIds := make([]IssuerID, 0, len(strList))
for _, entry := range strList {
issuerIds = append(issuerIds, IssuerID(entry))
}
return issuerIds, nil
}
// FetchIssuerById returns an IssuerEntry based on issuerId; if none is found, an error is returned.
func FetchIssuerById(ctx context.Context, s logical.Storage, issuerId IssuerID) (*IssuerEntry, error) {
if len(issuerId) == 0 {
return nil, errutil.InternalError{Err: "unable to fetch pki issuer: empty issuer identifier"}
}
entry, err := s.Get(ctx, IssuerPrefix+issuerId.String())
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki issuer: %v", err)}
}
if entry == nil {
return nil, errutil.UserError{Err: fmt.Sprintf("pki issuer id %s does not exist", issuerId.String())}
}
var issuer IssuerEntry
if err := entry.DecodeJSON(&issuer); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki issuer with id %s: %v", issuerId.String(), err)}
}
return upgradeIssuerIfRequired(&issuer), nil
}
func WriteIssuer(ctx context.Context, s logical.Storage, issuer *IssuerEntry) error {
issuerId := issuer.ID
if issuer.LastModified.IsZero() {
issuer.LastModified = time.Now().UTC()
}
json, err := logical.StorageEntryJSON(IssuerPrefix+issuerId.String(), issuer)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func DeleteIssuer(ctx context.Context, s logical.Storage, id IssuerID) (bool, error) {
config, err := GetIssuersConfig(ctx, s)
if err != nil {
return false, err
}
wasDefault := false
if config.DefaultIssuerId == id {
wasDefault = true
// Overwrite the fetched default issuer as we're going to remove this
// entry.
config.fetchedDefault = IssuerID("")
config.DefaultIssuerId = IssuerID("")
if err := SetIssuersConfig(ctx, s, config); err != nil {
return wasDefault, err
}
}
return wasDefault, s.Delete(ctx, IssuerPrefix+id.String())
}
func upgradeIssuerIfRequired(issuer *IssuerEntry) *IssuerEntry {
// *NOTE*: Don't attempt to write out the issuer here as it may cause ErrReadOnly that will direct the
// request all the way up to the primary cluster which would be horrible for local cluster operations such
// as generating a leaf cert or a revoke.
// Also, even though we could tell if we are the primary cluster's active node, we can't tell if we have
// a full rw issuer lock, so it might not be safe to write.
if issuer.Version == LatestIssuerVersion {
return issuer
}
if issuer.Version == 0 {
// Upgrade at this step requires interrogating the certificate itself;
// if this decode fails, it indicates internal problems and the
// request will subsequently fail elsewhere. However, decoding this
// certificate is mildly expensive, so we only do it in the event of
// a Version 0 certificate.
cert, err := issuer.GetCertificate()
if err != nil {
return issuer
}
hadCRL := issuer.Usage.HasUsage(CRLSigningUsage)
// Remove CRL signing usage if it exists on the issuer but doesn't
// exist in the KU of the x509 certificate.
if hadCRL && (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 {
issuer.Usage.ToggleUsage(CRLSigningUsage)
}
// Handle our new OCSPSigning usage flag for earlier versions. If we
// had it (prior to removing it in this upgrade), we'll add the OCSP
// flag since EKUs don't matter.
if hadCRL && !issuer.Usage.HasUsage(OCSPSigningUsage) {
issuer.Usage.ToggleUsage(OCSPSigningUsage)
}
}
issuer.Version = LatestIssuerVersion
return issuer
}
// FetchCAInfoByIssuerId will fetch the CA info and return an error if no CA info exists for the given issuerId.
// This supports loading via the LegacyBundleShimID.
func FetchCAInfoByIssuerId(ctx context.Context, s logical.Storage, mkv managed_key.PkiManagedKeyView, issuerId IssuerID, usage IssuerUsage) (*certutil.CAInfoBundle, error) {
entry, bundle, err := FetchCertBundleByIssuerId(ctx, s, issuerId, true)
if err != nil {
switch err.(type) {
case errutil.UserError:
return nil, err
case errutil.InternalError:
return nil, err
default:
return nil, errutil.InternalError{Err: fmt.Sprintf("error fetching CA info: %v", err)}
}
}
if err = entry.EnsureUsage(usage); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("error while attempting to use issuer %v: %v", issuerId, err)}
}
parsedBundle, err := ParseCABundle(ctx, mkv, bundle)
if err != nil {
return nil, errutil.InternalError{Err: err.Error()}
}
if parsedBundle.Certificate == nil {
return nil, errutil.InternalError{Err: "stored CA information not able to be parsed"}
}
if parsedBundle.PrivateKey == nil {
return nil, errutil.UserError{Err: fmt.Sprintf("unable to fetch corresponding key for issuer %v; unable to use this issuer for signing", issuerId)}
}
caInfo := &certutil.CAInfoBundle{
ParsedCertBundle: *parsedBundle,
URLs: nil,
LeafNotAfterBehavior: entry.LeafNotAfterBehavior,
RevocationSigAlg: entry.RevocationSigAlg,
}
entries, err := GetAIAURLs(ctx, s, entry)
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch AIA URL information: %v", err)}
}
caInfo.URLs = entries
return caInfo, nil
}
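// Illustrative sketch (not part of the original change): fetching a CA bundle
// that is explicitly allowed to sign CRLs. The function name is hypothetical;
// mkv is whatever managed-key view the calling backend provides.
func exampleFetchCRLSigningBundle(ctx context.Context, s logical.Storage, mkv managed_key.PkiManagedKeyView, ref string) (*certutil.CAInfoBundle, error) {
	id, err := ResolveIssuerReference(ctx, s, ref)
	if err != nil {
		return nil, err
	}
	// The usage argument is enforced via EnsureUsage inside FetchCAInfoByIssuerId.
	return FetchCAInfoByIssuerId(ctx, s, mkv, id, CRLSigningUsage)
}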
func ParseCABundle(ctx context.Context, mkv managed_key.PkiManagedKeyView, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
if bundle.PrivateKeyType == certutil.ManagedPrivateKey {
return managed_key.ParseManagedKeyCABundle(ctx, mkv, bundle)
}
return bundle.ToParsedCertBundle()
}
// FetchCertBundleByIssuerId builds a certutil.CertBundle from the specified issuer identifier,
// optionally loading the private key. This method supports loading legacy
// bundles using the LegacyBundleShimID issuerId; if no entry is found, an error is returned.
func FetchCertBundleByIssuerId(ctx context.Context, s logical.Storage, id IssuerID, loadKey bool) (*IssuerEntry, *certutil.CertBundle, error) {
if id == LegacyBundleShimID {
// We have not completed the migration, or started a request in legacy mode, so
// attempt to load the bundle from the legacy location
issuer, bundle, err := GetLegacyCertBundle(ctx, s)
if err != nil {
return nil, nil, err
}
if issuer == nil || bundle == nil {
return nil, nil, errutil.UserError{Err: "no legacy cert bundle exists"}
}
return issuer, bundle, err
}
issuer, err := FetchIssuerById(ctx, s, id)
if err != nil {
return nil, nil, err
}
var bundle certutil.CertBundle
bundle.Certificate = issuer.Certificate
bundle.CAChain = issuer.CAChain
bundle.SerialNumber = issuer.SerialNumber
// Fetch the key if it exists. Sometimes we don't need the key immediately.
if loadKey && issuer.KeyID != KeyID("") {
key, err := FetchKeyById(ctx, s, issuer.KeyID)
if err != nil {
return nil, nil, err
}
bundle.PrivateKeyType = key.PrivateKeyType
bundle.PrivateKey = key.PrivateKey
}
return issuer, &bundle, nil
}
func GetLegacyCertBundle(ctx context.Context, s logical.Storage) (*IssuerEntry, *certutil.CertBundle, error) {
entry, err := s.Get(ctx, LegacyCertBundlePath)
if err != nil {
return nil, nil, err
}
if entry == nil {
return nil, nil, nil
}
cb := &certutil.CertBundle{}
err = entry.DecodeJSON(cb)
if err != nil {
return nil, nil, err
}
// Fake a storage entry with backwards compatibility in mind.
issuer := &IssuerEntry{
ID: LegacyBundleShimID,
KeyID: LegacyBundleShimKeyID,
Name: "legacy-entry-shim",
Certificate: cb.Certificate,
CAChain: cb.CAChain,
SerialNumber: cb.SerialNumber,
LeafNotAfterBehavior: certutil.ErrNotAfterBehavior,
}
issuer.Usage.ToggleUsage(AllIssuerUsages)
return issuer, cb, nil
}

View File

@@ -0,0 +1,153 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
const (
KeyPrefix = "config/key/"
KeyRefNotFound = KeyID("not-found")
)
type KeyID string
func (p KeyID) String() string {
return string(p)
}
type KeyEntry struct {
ID KeyID `json:"id"`
Name string `json:"name"`
PrivateKeyType certutil.PrivateKeyType `json:"private_key_type"`
PrivateKey string `json:"private_key"`
}
func (e KeyEntry) IsManagedPrivateKey() bool {
return e.PrivateKeyType == certutil.ManagedPrivateKey
}
func ListKeys(ctx context.Context, s logical.Storage) ([]KeyID, error) {
strList, err := s.List(ctx, KeyPrefix)
if err != nil {
return nil, err
}
keyIds := make([]KeyID, 0, len(strList))
for _, entry := range strList {
keyIds = append(keyIds, KeyID(entry))
}
return keyIds, nil
}
func FetchKeyById(ctx context.Context, s logical.Storage, keyId KeyID) (*KeyEntry, error) {
if len(keyId) == 0 {
return nil, errutil.InternalError{Err: "unable to fetch pki key: empty key identifier"}
}
entry, err := s.Get(ctx, KeyPrefix+keyId.String())
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch pki key: %v", err)}
}
if entry == nil {
return nil, errutil.UserError{Err: fmt.Sprintf("pki key id %s does not exist", keyId.String())}
}
var key KeyEntry
if err := entry.DecodeJSON(&key); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode pki key with id %s: %v", keyId.String(), err)}
}
return &key, nil
}
func WriteKey(ctx context.Context, s logical.Storage, key KeyEntry) error {
keyId := key.ID
json, err := logical.StorageEntryJSON(KeyPrefix+keyId.String(), key)
if err != nil {
return err
}
return s.Put(ctx, json)
}
func DeleteKey(ctx context.Context, s logical.Storage, id KeyID) (bool, error) {
config, err := GetKeysConfig(ctx, s)
if err != nil {
return false, err
}
wasDefault := false
if config.DefaultKeyId == id {
wasDefault = true
config.DefaultKeyId = KeyID("")
if err := SetKeysConfig(ctx, s, config); err != nil {
return wasDefault, err
}
}
return wasDefault, s.Delete(ctx, KeyPrefix+id.String())
}
func ResolveKeyReference(ctx context.Context, s logical.Storage, reference string) (KeyID, error) {
if reference == DefaultRef {
// Handle fetching the default key.
config, err := GetKeysConfig(ctx, s)
if err != nil {
return KeyID("config-error"), err
}
if len(config.DefaultKeyId) == 0 {
return KeyRefNotFound, fmt.Errorf("no default key currently configured")
}
return config.DefaultKeyId, nil
}
// Look up by a direct get first to see if our reference is an ID; this is quick and cached.
if len(reference) == uuidLength {
entry, err := s.Get(ctx, KeyPrefix+reference)
if err != nil {
return KeyID("key-read"), err
}
if entry != nil {
return KeyID(reference), nil
}
}
// Otherwise, fall back to pulling all keys from storage.
keys, err := ListKeys(ctx, s)
if err != nil {
return KeyID("list-error"), err
}
for _, keyId := range keys {
key, err := FetchKeyById(ctx, s, keyId)
if err != nil {
return KeyID("key-read"), err
}
if key.Name == reference {
return key.ID, nil
}
}
// Otherwise, we must not have found the key.
return KeyRefNotFound, errutil.UserError{Err: fmt.Sprintf("unable to find PKI key for reference: %v", reference)}
}
func GetManagedKeyUUID(key *KeyEntry) (managed_key.UUIDKey, error) {
if !key.IsManagedPrivateKey() {
return "", errutil.InternalError{Err: "getManagedKeyUUID called on a key id %s (%s) "}
}
return managed_key.ExtractManagedKeyId([]byte(key.PrivateKey))
}
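// Illustrative sketch (not part of the original change): loading a key and
// distinguishing managed (KMS-backed) keys from locally stored PEM keys.
// The function name is hypothetical.
func exampleDescribeKey(ctx context.Context, s logical.Storage, ref string) (string, error) {
	id, err := ResolveKeyReference(ctx, s, ref)
	if err != nil {
		return "", err
	}
	key, err := FetchKeyById(ctx, s, id)
	if err != nil {
		return "", err
	}
	if key.IsManagedPrivateKey() {
		uuid, err := GetManagedKeyUUID(key)
		if err != nil {
			return "", err
		}
		return fmt.Sprintf("managed key %s (uuid %s)", key.Name, uuid), nil
	}
	return fmt.Sprintf("local %s key %s", key.PrivateKeyType, key.Name), nil
}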

View File

@@ -0,0 +1,452 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
var (
DefaultRoleKeyUsages = []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"}
DefaultRoleEstKeyUsages = []string{}
DefaultRoleEstKeyUsageOids = []string{}
)
const (
DefaultRoleSignatureBits = 0
DefaultRoleUsePss = false
)
type RoleEntry struct {
LeaseMax string `json:"lease_max"`
Lease string `json:"lease"`
DeprecatedMaxTTL string `json:"max_ttl"`
DeprecatedTTL string `json:"ttl"`
TTL time.Duration `json:"ttl_duration"`
MaxTTL time.Duration `json:"max_ttl_duration"`
AllowLocalhost bool `json:"allow_localhost"`
AllowedBaseDomain string `json:"allowed_base_domain"`
AllowedDomainsOld string `json:"allowed_domains,omitempty"`
AllowedDomains []string `json:"allowed_domains_list"`
AllowedDomainsTemplate bool `json:"allowed_domains_template"`
AllowBaseDomain bool `json:"allow_base_domain"`
AllowBareDomains bool `json:"allow_bare_domains"`
AllowTokenDisplayName bool `json:"allow_token_displayname"`
AllowSubdomains bool `json:"allow_subdomains"`
AllowGlobDomains bool `json:"allow_glob_domains"`
AllowWildcardCertificates *bool `json:"allow_wildcard_certificates,omitempty"`
AllowAnyName bool `json:"allow_any_name"`
EnforceHostnames bool `json:"enforce_hostnames"`
AllowIPSANs bool `json:"allow_ip_sans"`
ServerFlag bool `json:"server_flag"`
ClientFlag bool `json:"client_flag"`
CodeSigningFlag bool `json:"code_signing_flag"`
EmailProtectionFlag bool `json:"email_protection_flag"`
UseCSRCommonName bool `json:"use_csr_common_name"`
UseCSRSANs bool `json:"use_csr_sans"`
KeyType string `json:"key_type"`
KeyBits int `json:"key_bits"`
UsePSS bool `json:"use_pss"`
SignatureBits int `json:"signature_bits"`
MaxPathLength *int `json:",omitempty"`
KeyUsageOld string `json:"key_usage,omitempty"`
KeyUsage []string `json:"key_usage_list"`
ExtKeyUsage []string `json:"extended_key_usage_list"`
OUOld string `json:"ou,omitempty"`
OU []string `json:"ou_list"`
OrganizationOld string `json:"organization,omitempty"`
Organization []string `json:"organization_list"`
Country []string `json:"country"`
Locality []string `json:"locality"`
Province []string `json:"province"`
StreetAddress []string `json:"street_address"`
PostalCode []string `json:"postal_code"`
GenerateLease *bool `json:"generate_lease,omitempty"`
NoStore bool `json:"no_store"`
RequireCN bool `json:"require_cn"`
CNValidations []string `json:"cn_validations"`
AllowedOtherSANs []string `json:"allowed_other_sans"`
AllowedSerialNumbers []string `json:"allowed_serial_numbers"`
AllowedUserIDs []string `json:"allowed_user_ids"`
AllowedURISANs []string `json:"allowed_uri_sans"`
AllowedURISANsTemplate bool `json:"allowed_uri_sans_template"`
PolicyIdentifiers []string `json:"policy_identifiers"`
ExtKeyUsageOIDs []string `json:"ext_key_usage_oids"`
BasicConstraintsValidForNonCA bool `json:"basic_constraints_valid_for_non_ca"`
NotBeforeDuration time.Duration `json:"not_before_duration"`
NotAfter string `json:"not_after"`
Issuer string `json:"issuer"`
// Name is only set when the role has been stored; on-the-fly roles have a blank name
Name string `json:"-"`
// WasModified indicates to callers if the returned entry is different than the persisted version
WasModified bool `json:"-"`
}
func (r *RoleEntry) ToResponseData() map[string]interface{} {
responseData := map[string]interface{}{
"ttl": int64(r.TTL.Seconds()),
"max_ttl": int64(r.MaxTTL.Seconds()),
"allow_localhost": r.AllowLocalhost,
"allowed_domains": r.AllowedDomains,
"allowed_domains_template": r.AllowedDomainsTemplate,
"allow_bare_domains": r.AllowBareDomains,
"allow_token_displayname": r.AllowTokenDisplayName,
"allow_subdomains": r.AllowSubdomains,
"allow_glob_domains": r.AllowGlobDomains,
"allow_wildcard_certificates": r.AllowWildcardCertificates,
"allow_any_name": r.AllowAnyName,
"allowed_uri_sans_template": r.AllowedURISANsTemplate,
"enforce_hostnames": r.EnforceHostnames,
"allow_ip_sans": r.AllowIPSANs,
"server_flag": r.ServerFlag,
"client_flag": r.ClientFlag,
"code_signing_flag": r.CodeSigningFlag,
"email_protection_flag": r.EmailProtectionFlag,
"use_csr_common_name": r.UseCSRCommonName,
"use_csr_sans": r.UseCSRSANs,
"key_type": r.KeyType,
"key_bits": r.KeyBits,
"signature_bits": r.SignatureBits,
"use_pss": r.UsePSS,
"key_usage": r.KeyUsage,
"ext_key_usage": r.ExtKeyUsage,
"ext_key_usage_oids": r.ExtKeyUsageOIDs,
"ou": r.OU,
"organization": r.Organization,
"country": r.Country,
"locality": r.Locality,
"province": r.Province,
"street_address": r.StreetAddress,
"postal_code": r.PostalCode,
"no_store": r.NoStore,
"allowed_other_sans": r.AllowedOtherSANs,
"allowed_serial_numbers": r.AllowedSerialNumbers,
"allowed_user_ids": r.AllowedUserIDs,
"allowed_uri_sans": r.AllowedURISANs,
"require_cn": r.RequireCN,
"cn_validations": r.CNValidations,
"policy_identifiers": r.PolicyIdentifiers,
"basic_constraints_valid_for_non_ca": r.BasicConstraintsValidForNonCA,
"not_before_duration": int64(r.NotBeforeDuration.Seconds()),
"not_after": r.NotAfter,
"issuer_ref": r.Issuer,
}
if r.MaxPathLength != nil {
responseData["max_path_length"] = r.MaxPathLength
}
if r.GenerateLease != nil {
responseData["generate_lease"] = r.GenerateLease
}
return responseData
}
var ErrRoleNotFound = errors.New("role not found")
// GetRole will load a role from storage based on the provided name and
// update its contents to the latest version if out of date. The WasModified field
// will be set to true if modifications were made, indicating the caller should, if
// possible, write them back to disk. If the role is not found, an ErrRoleNotFound
// error will be returned.
func GetRole(ctx context.Context, s logical.Storage, n string) (*RoleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, fmt.Errorf("failed to load role %s: %w", n, err)
}
if entry == nil {
return nil, fmt.Errorf("%w: with name %s", ErrRoleNotFound, n)
}
var result RoleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, fmt.Errorf("failed decoding role %s: %w", n, err)
}
// Migrate existing saved entries and save back if changed
modified := false
if len(result.DeprecatedTTL) == 0 && len(result.Lease) != 0 {
result.DeprecatedTTL = result.Lease
result.Lease = ""
modified = true
}
if result.TTL == 0 && len(result.DeprecatedTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedTTL)
if err != nil {
return nil, err
}
result.TTL = parsed
result.DeprecatedTTL = ""
modified = true
}
if len(result.DeprecatedMaxTTL) == 0 && len(result.LeaseMax) != 0 {
result.DeprecatedMaxTTL = result.LeaseMax
result.LeaseMax = ""
modified = true
}
if result.MaxTTL == 0 && len(result.DeprecatedMaxTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedMaxTTL)
if err != nil {
return nil, fmt.Errorf("failed parsing max_ttl field in %s: %w", n, err)
}
result.MaxTTL = parsed
result.DeprecatedMaxTTL = ""
modified = true
}
if result.AllowBaseDomain {
result.AllowBaseDomain = false
result.AllowBareDomains = true
modified = true
}
if result.AllowedDomainsOld != "" {
result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",")
result.AllowedDomainsOld = ""
modified = true
}
if result.AllowedBaseDomain != "" {
found := false
for _, v := range result.AllowedDomains {
if v == result.AllowedBaseDomain {
found = true
break
}
}
if !found {
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
}
result.AllowedBaseDomain = ""
modified = true
}
if result.AllowWildcardCertificates == nil {
// While not the most secure default, when AllowWildcardCertificates isn't
// explicitly specified in the stored Role, we automatically upgrade it to
// true to preserve compatibility with previous versions of Vault. Once this
// field is set, this logic will not be triggered any more.
result.AllowWildcardCertificates = new(bool)
*result.AllowWildcardCertificates = true
modified = true
}
// Upgrade generate_lease in role
if result.GenerateLease == nil {
// All new roles will have GenerateLease set to a value. A
// nil value indicates that this role needs an upgrade. Set it to
// `true` to not alter its current behavior.
result.GenerateLease = new(bool)
*result.GenerateLease = true
modified = true
}
// Upgrade key usages
if result.KeyUsageOld != "" {
result.KeyUsage = strings.Split(result.KeyUsageOld, ",")
result.KeyUsageOld = ""
modified = true
}
// Upgrade OU
if result.OUOld != "" {
result.OU = strings.Split(result.OUOld, ",")
result.OUOld = ""
modified = true
}
// Upgrade Organization
if result.OrganizationOld != "" {
result.Organization = strings.Split(result.OrganizationOld, ",")
result.OrganizationOld = ""
modified = true
}
// Set the issuer field to default if not set. We want to do this
// unconditionally as we should probably never have an empty issuer
// on a stored role.
if len(result.Issuer) == 0 {
result.Issuer = DefaultRef
modified = true
}
// Update CN Validations to be the present default, "email,hostname"
if len(result.CNValidations) == 0 {
result.CNValidations = []string{"email", "hostname"}
modified = true
}
result.Name = n
result.WasModified = modified
return &result, nil
}
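// Illustrative sketch (not part of the original change): callers that hold a
// write-capable storage view can persist the in-memory migrations GetRole
// performs. The function name is hypothetical.
func exampleLoadAndPersistRole(ctx context.Context, s logical.Storage, name string) (*RoleEntry, error) {
	role, err := GetRole(ctx, s, name)
	if err != nil {
		return nil, err
	}
	if role.WasModified {
		entry, err := logical.StorageEntryJSON("role/"+name, role)
		if err != nil {
			return nil, err
		}
		if err := s.Put(ctx, entry); err != nil {
			return nil, err
		}
	}
	return role, nil
}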
type RoleModifier func(r *RoleEntry)
func WithKeyUsage(keyUsages []string) RoleModifier {
return func(r *RoleEntry) {
r.KeyUsage = keyUsages
}
}
func WithExtKeyUsage(extKeyUsages []string) RoleModifier {
return func(r *RoleEntry) {
r.ExtKeyUsage = extKeyUsages
}
}
func WithExtKeyUsageOIDs(extKeyUsageOids []string) RoleModifier {
return func(r *RoleEntry) {
r.ExtKeyUsageOIDs = extKeyUsageOids
}
}
func WithSignatureBits(signatureBits int) RoleModifier {
return func(r *RoleEntry) {
r.SignatureBits = signatureBits
}
}
func WithUsePSS(usePss bool) RoleModifier {
return func(r *RoleEntry) {
r.UsePSS = usePss
}
}
func WithTTL(ttl time.Duration) RoleModifier {
return func(r *RoleEntry) {
r.TTL = ttl
}
}
func WithMaxTTL(ttl time.Duration) RoleModifier {
return func(r *RoleEntry) {
r.MaxTTL = ttl
}
}
func WithGenerateLease(genLease bool) RoleModifier {
return func(r *RoleEntry) {
*r.GenerateLease = genLease
}
}
func WithNotBeforeDuration(ttl time.Duration) RoleModifier {
return func(r *RoleEntry) {
r.NotBeforeDuration = ttl
}
}
func WithNoStore(noStore bool) RoleModifier {
return func(r *RoleEntry) {
r.NoStore = noStore
}
}
func WithIssuer(issuer string) RoleModifier {
return func(r *RoleEntry) {
if issuer == "" {
issuer = DefaultRef
}
r.Issuer = issuer
}
}
// SignVerbatimRole creates a sign-verbatim role with no overrides. This will store
// the signed certificate, allowing any key type and value, free of role restrictions.
func SignVerbatimRole() *RoleEntry {
return SignVerbatimRoleWithOpts()
}
// SignVerbatimRoleWithOpts creates a sign-verbatim role with the normal defaults,
// but allows any field to be tweaked based on the consumer's needs.
func SignVerbatimRoleWithOpts(opts ...RoleModifier) *RoleEntry {
entry := &RoleEntry{
AllowLocalhost: true,
AllowAnyName: true,
AllowIPSANs: true,
AllowWildcardCertificates: new(bool),
EnforceHostnames: false,
KeyType: "any",
UseCSRCommonName: true,
UseCSRSANs: true,
AllowedOtherSANs: []string{"*"},
AllowedSerialNumbers: []string{"*"},
AllowedURISANs: []string{"*"},
AllowedUserIDs: []string{"*"},
CNValidations: []string{"disabled"},
GenerateLease: new(bool),
KeyUsage: DefaultRoleKeyUsages,
ExtKeyUsage: DefaultRoleEstKeyUsages,
ExtKeyUsageOIDs: DefaultRoleEstKeyUsageOids,
SignatureBits: DefaultRoleSignatureBits,
UsePSS: DefaultRoleUsePss,
}
*entry.AllowWildcardCertificates = true
*entry.GenerateLease = false
if opts != nil {
for _, opt := range opts {
if opt != nil {
opt(entry)
}
}
}
return entry
}
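// Illustrative sketch (not part of the original change): composing a
// sign-verbatim role with a few of the functional options above. The issuer
// name "my-intermediate" is a placeholder.
var exampleVerbatimRole = SignVerbatimRoleWithOpts(
	WithTTL(12*time.Hour),
	WithNoStore(true),
	WithIssuer("my-intermediate"), // an empty string falls back to DefaultRef
)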
func ParseExtKeyUsagesFromRole(role *RoleEntry) certutil.CertExtKeyUsage {
var parsedKeyUsages certutil.CertExtKeyUsage
if role.ServerFlag {
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
}
if role.ClientFlag {
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
}
if role.CodeSigningFlag {
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
}
if role.EmailProtectionFlag {
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
}
for _, k := range role.ExtKeyUsage {
switch strings.ToLower(strings.TrimSpace(k)) {
case "any":
parsedKeyUsages |= certutil.AnyExtKeyUsage
case "serverauth":
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
case "clientauth":
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
case "codesigning":
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
case "emailprotection":
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
case "ipsecendsystem":
parsedKeyUsages |= certutil.IpsecEndSystemExtKeyUsage
case "ipsectunnel":
parsedKeyUsages |= certutil.IpsecTunnelExtKeyUsage
case "ipsecuser":
parsedKeyUsages |= certutil.IpsecUserExtKeyUsage
case "timestamping":
parsedKeyUsages |= certutil.TimeStampingExtKeyUsage
case "ocspsigning":
parsedKeyUsages |= certutil.OcspSigningExtKeyUsage
case "microsoftservergatedcrypto":
parsedKeyUsages |= certutil.MicrosoftServerGatedCryptoExtKeyUsage
case "netscapeservergatedcrypto":
parsedKeyUsages |= certutil.NetscapeServerGatedCryptoExtKeyUsage
}
}
return parsedKeyUsages
}

View File

@@ -0,0 +1,291 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)
type SignCertInput interface {
CreationBundleInput
GetCSR() (*x509.CertificateRequest, error)
IsCA() bool
UseCSRValues() bool
GetPermittedDomains() []string
}
func NewBasicSignCertInput(csr *x509.CertificateRequest, isCA bool, useCSRValues bool) BasicSignCertInput {
return BasicSignCertInput{
isCA: isCA,
useCSRValues: useCSRValues,
csr: csr,
}
}
var _ SignCertInput = BasicSignCertInput{}
type BasicSignCertInput struct {
isCA bool
useCSRValues bool
csr *x509.CertificateRequest
}
func (b BasicSignCertInput) GetTTL() int {
return 0
}
func (b BasicSignCertInput) GetOptionalNotAfter() (interface{}, bool) {
return "", false
}
func (b BasicSignCertInput) GetCommonName() string {
return ""
}
func (b BasicSignCertInput) GetSerialNumber() string {
return ""
}
func (b BasicSignCertInput) GetExcludeCnFromSans() bool {
return false
}
func (b BasicSignCertInput) GetOptionalAltNames() (interface{}, bool) {
return []string{}, false
}
func (b BasicSignCertInput) GetOtherSans() []string {
return []string{}
}
func (b BasicSignCertInput) GetIpSans() []string {
return []string{}
}
func (b BasicSignCertInput) GetURISans() []string {
return []string{}
}
func (b BasicSignCertInput) GetOptionalSkid() (interface{}, bool) {
return "", false
}
func (b BasicSignCertInput) IsUserIdInSchema() (interface{}, bool) {
return []string{}, false
}
func (b BasicSignCertInput) GetUserIds() []string {
return []string{}
}
func (b BasicSignCertInput) GetCSR() (*x509.CertificateRequest, error) {
return b.csr, nil
}
func (b BasicSignCertInput) IsCA() bool {
return b.isCA
}
func (b BasicSignCertInput) UseCSRValues() bool {
return b.useCSRValues
}
func (b BasicSignCertInput) GetPermittedDomains() []string {
return []string{}
}
func SignCert(b logical.SystemView, role *RoleEntry, entityInfo EntityInfo, caSign *certutil.CAInfoBundle, signInput SignCertInput) (*certutil.ParsedCertBundle, []string, error) {
if role == nil {
return nil, nil, errutil.InternalError{Err: "no role found in data bundle"}
}
csr, err := signInput.GetCSR()
if err != nil {
return nil, nil, err
}
if csr.PublicKeyAlgorithm == x509.UnknownPublicKeyAlgorithm || csr.PublicKey == nil {
return nil, nil, errutil.UserError{Err: "Refusing to sign CSR with empty PublicKey. This usually means the SubjectPublicKeyInfo field has an OID not recognized by Go, such as 1.2.840.113549.1.1.10 for rsaPSS."}
}
// This switch validates that the CSR key type matches the role and sets
// the value in the actualKeyType/actualKeyBits variables.
actualKeyType := ""
actualKeyBits := 0
switch role.KeyType {
case "rsa":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.RSA {
return nil, nil, errutil.UserError{Err: fmt.Sprintf("role requires keys of type %s", role.KeyType)}
}
pubKey, ok := csr.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "rsa"
actualKeyBits = pubKey.N.BitLen()
case "ec":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.ECDSA {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires keys of type %s",
role.KeyType)}
}
pubKey, ok := csr.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ec"
actualKeyBits = pubKey.Params().BitSize
case "ed25519":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.Ed25519 {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires keys of type %s",
role.KeyType)}
}
_, ok := csr.PublicKey.(ed25519.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ed25519"
actualKeyBits = 0
case "any":
// We need to compute the actual key type and key bits, to correctly
// validate minimums and SignatureBits below.
switch csr.PublicKeyAlgorithm {
case x509.RSA:
pubKey, ok := csr.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
if pubKey.N.BitLen() < 2048 {
return nil, nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
}
actualKeyType = "rsa"
actualKeyBits = pubKey.N.BitLen()
case x509.ECDSA:
pubKey, ok := csr.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ec"
actualKeyBits = pubKey.Params().BitSize
case x509.Ed25519:
_, ok := csr.PublicKey.(ed25519.PublicKey)
if !ok {
return nil, nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
actualKeyType = "ed25519"
actualKeyBits = 0
default:
return nil, nil, errutil.UserError{Err: "Unknown key type in CSR: " + csr.PublicKeyAlgorithm.String()}
}
default:
return nil, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported key type Value: %s", role.KeyType)}
}
// Before validating key lengths, update our KeyBits/SignatureBits based
// on the actual CSR key type.
if role.KeyType == "any" {
// We update the value of KeyBits and SignatureBits here (from the
// role), using the specified key type. This allows us to convert
// the default value (0) for SignatureBits and KeyBits to a
// meaningful value.
//
// We ignore the role's original KeyBits value if the KeyType is any,
// as legacy (pre-1.10) roles had default values that made sense only
// for RSA keys (key_bits=2048) and the older code paths ignored the role value
// set for KeyBits when KeyType was set to any. This also enforces the
// docs saying that when key_type=any, we only enforce our specified minimums
// for signing operations.
var err error
if role.KeyBits, role.SignatureBits, err = certutil.ValidateDefaultOrValueKeyTypeSignatureLength(
actualKeyType, 0, role.SignatureBits); err != nil {
return nil, nil, errutil.InternalError{Err: fmt.Sprintf("unknown internal error updating default values: %v", err)}
}
// We're using the KeyBits field as a minimum value below, and P-224 is safe
// and a previously allowed value. However, the above call defaults
// to P-256 as that's a saner default than P-224 (w.r.t. generation), so
// override it here to allow 224 as the smallest size we permit.
if actualKeyType == "ec" {
role.KeyBits = 224
}
}
// At this point, role.KeyBits and role.SignatureBits should both
// be non-zero, for RSA and ECDSA keys. Validate the actualKeyBits based on
// the role's values. If the KeyType was any, and KeyBits was set to 0,
// KeyBits should be updated to 2048 unless some other value was chosen
// explicitly.
//
// This validation needs to occur regardless of the role's key type, so
// that we always validate both RSA and ECDSA key sizes.
if actualKeyType == "rsa" {
if actualKeyBits < role.KeyBits {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires a minimum of a %d-bit key, but CSR's key is %d bits",
role.KeyBits, actualKeyBits)}
}
if actualKeyBits < 2048 {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"Vault requires a minimum of a 2048-bit key, but CSR's key is %d bits",
actualKeyBits)}
}
} else if actualKeyType == "ec" {
if actualKeyBits < role.KeyBits {
return nil, nil, errutil.UserError{Err: fmt.Sprintf(
"role requires a minimum of a %d-bit key, but CSR's key is %d bits",
role.KeyBits,
actualKeyBits)}
}
}
creation, warnings, err := GenerateCreationBundle(b, role, entityInfo, signInput, caSign, csr)
if err != nil {
return nil, nil, err
}
if creation.Params == nil {
return nil, nil, errutil.InternalError{Err: "nil parameters received from parameter bundle generation"}
}
creation.Params.IsCA = signInput.IsCA()
creation.Params.UseCSRValues = signInput.UseCSRValues()
if signInput.IsCA() {
creation.Params.PermittedDNSDomains = signInput.GetPermittedDomains()
} else {
for _, ext := range csr.Extensions {
if ext.Id.Equal(certutil.ExtensionBasicConstraintsOID) {
warnings = append(warnings, "specified CSR contained a Basic Constraints extension that was ignored during issuance")
}
}
}
parsedBundle, err := certutil.SignCertificate(creation)
if err != nil {
return nil, nil, err
}
return parsedBundle, warnings, nil
}
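// Illustrative sketch (not part of the original change): the typical shape of
// a sign-verbatim call built from the helpers above. The function name is
// hypothetical; sys, entityInfo and caBundle come from the calling backend.
func exampleSignVerbatim(sys logical.SystemView, entityInfo EntityInfo, caBundle *certutil.CAInfoBundle, csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, []string, error) {
	role := SignVerbatimRole()
	signInput := NewBasicSignCertInput(csr, false /* isCA */, true /* useCSRValues */)
	return SignCert(sys, role, entityInfo, caBundle, signInput)
}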

View File

@@ -12,9 +12,12 @@ import (
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
+ "github.com/hashicorp/vault/builtin/logical/pki/issuing"
+ "github.com/hashicorp/vault/builtin/logical/pki/managed_key"
)
- func comparePublicKey(sc *storageContext, key *keyEntry, publicKey crypto.PublicKey) (bool, error) {
+ func comparePublicKey(sc *storageContext, key *issuing.KeyEntry, publicKey crypto.PublicKey) (bool, error) {
publicKeyForKeyEntry, err := getPublicKey(sc.Context, sc.Backend, key)
if err != nil {
return false, err
@@ -23,13 +26,9 @@ func comparePublicKey(sc *storageContext, key *keyEntry, publicKey crypto.Public
return certutil.ComparePublicKeysAndType(publicKeyForKeyEntry, publicKey)
}
- func getPublicKey(ctx context.Context, b *backend, key *keyEntry) (crypto.PublicKey, error) {
+ func getPublicKey(ctx context.Context, b *backend, key *issuing.KeyEntry) (crypto.PublicKey, error) {
if key.PrivateKeyType == certutil.ManagedPrivateKey {
- keyId, err := extractManagedKeyId([]byte(key.PrivateKey))
- if err != nil {
- return nil, err
- }
- return getManagedKeyPublicKey(ctx, b, keyId)
+ return managed_key.GetPublicKeyFromKeyBytes(ctx, b, []byte(key.PrivateKey))
}
signer, _, _, err := getSignerFromKeyEntryBytes(key)
@@ -39,7 +38,7 @@ func getPublicKey(ctx context.Context, b *backend, key *keyEntry) (crypto.Public
return signer.Public(), nil
}
- func getSignerFromKeyEntryBytes(key *keyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
+ func getSignerFromKeyEntryBytes(key *issuing.KeyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
if key.PrivateKeyType == certutil.UnknownPrivateKey {
return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported unknown private key type for key: %s (%s)", key.ID, key.Name)}
}
@@ -78,7 +77,7 @@ func getPublicKeyFromBytes(keyBytes []byte) (crypto.PublicKey, error) {
return signer.Public(), nil
}
- func importKeyFromBytes(sc *storageContext, keyValue string, keyName string) (*keyEntry, bool, error) {
+ func importKeyFromBytes(sc *storageContext, keyValue string, keyName string) (*issuing.KeyEntry, bool, error) {
signer, _, _, err := getSignerFromBytes([]byte(keyValue))
if err != nil {
return nil, false, err
View File

@@ -0,0 +1,43 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package managed_key
import (
"crypto"
"io"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
type ManagedKeyInfo struct {
publicKey crypto.PublicKey
KeyType certutil.PrivateKeyType
Name NameKey
Uuid UUIDKey
}
type managedKeyId interface {
String() string
}
type PkiManagedKeyView interface {
BackendUUID() string
IsSecondaryNode() bool
GetManagedKeyView() (logical.ManagedKeySystemView, error)
GetRandomReader() io.Reader
}
type (
UUIDKey string
NameKey string
)
func (u UUIDKey) String() string {
return string(u)
}
func (n NameKey) String() string {
return string(n)
}
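// Illustrative sketch (not part of the original change): the minimal surface a
// backend must expose to satisfy PkiManagedKeyView. The type and field names
// are hypothetical.
type exampleBackendView struct {
	uuid      string
	secondary bool
	sysView   logical.ManagedKeySystemView
	rand      io.Reader
}

func (v *exampleBackendView) BackendUUID() string        { return v.uuid }
func (v *exampleBackendView) IsSecondaryNode() bool      { return v.secondary }
func (v *exampleBackendView) GetRandomReader() io.Reader { return v.rand }
func (v *exampleBackendView) GetManagedKeyView() (logical.ManagedKeySystemView, error) {
	return v.sysView, nil
}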

View File

@@ -0,0 +1,49 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package managed_key
import (
"context"
"crypto"
"errors"
"io"
"github.com/hashicorp/vault/sdk/helper/certutil"
)
var errEntOnly = errors.New("managed keys are supported within enterprise edition only")
func GetPublicKeyFromKeyBytes(ctx context.Context, mkv PkiManagedKeyView, keyBytes []byte) (crypto.PublicKey, error) {
return nil, errEntOnly
}
func GenerateManagedKeyCABundle(ctx context.Context, b PkiManagedKeyView, keyId managedKeyId, data *certutil.CreationBundle, randomSource io.Reader) (bundle *certutil.ParsedCertBundle, err error) {
return nil, errEntOnly
}
func GenerateManagedKeyCSRBundle(ctx context.Context, b PkiManagedKeyView, keyId managedKeyId, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (bundle *certutil.ParsedCSRBundle, err error) {
return nil, errEntOnly
}
func GetManagedKeyPublicKey(ctx context.Context, b PkiManagedKeyView, keyId managedKeyId) (crypto.PublicKey, error) {
return nil, errEntOnly
}
func ParseManagedKeyCABundle(ctx context.Context, mkv PkiManagedKeyView, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return nil, errEntOnly
}
func ExtractManagedKeyId(privateKeyBytes []byte) (UUIDKey, error) {
return "", errEntOnly
}
func CreateKmsKeyBundle(ctx context.Context, mkv PkiManagedKeyView, keyId managedKeyId) (certutil.KeyBundle, certutil.PrivateKeyType, error) {
return certutil.KeyBundle{}, certutil.UnknownPrivateKey, errEntOnly
}
func GetManagedKeyInfo(ctx context.Context, mkv PkiManagedKeyView, keyId managedKeyId) (*ManagedKeyInfo, error) {
return nil, errEntOnly
}

View File

@@ -1,45 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package pki
import (
"context"
"crypto"
"errors"
"io"
"github.com/hashicorp/vault/sdk/helper/certutil"
)
var errEntOnly = errors.New("managed keys are supported within enterprise edition only")
func generateManagedKeyCABundle(ctx context.Context, b *backend, keyId managedKeyId, data *certutil.CreationBundle, randomSource io.Reader) (bundle *certutil.ParsedCertBundle, err error) {
return nil, errEntOnly
}
func generateManagedKeyCSRBundle(ctx context.Context, b *backend, keyId managedKeyId, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (bundle *certutil.ParsedCSRBundle, err error) {
return nil, errEntOnly
}
func getManagedKeyPublicKey(ctx context.Context, b *backend, keyId managedKeyId) (crypto.PublicKey, error) {
return nil, errEntOnly
}
func parseManagedKeyCABundle(ctx context.Context, b *backend, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return nil, errEntOnly
}
func extractManagedKeyId(privateKeyBytes []byte) (UUIDKey, error) {
return "", errEntOnly
}
func createKmsKeyBundle(ctx context.Context, b *backend, keyId managedKeyId) (certutil.KeyBundle, certutil.PrivateKeyType, error) {
return certutil.KeyBundle{}, certutil.UnknownPrivateKey, errEntOnly
}
func getManagedKeyInfo(ctx context.Context, b *backend, keyId managedKeyId) (*managedKeyInfo, error) {
return nil, errEntOnly
}

View File

@@ -0,0 +1,263 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package pki
import (
"errors"
"sort"
"strings"
"sync/atomic"
"github.com/armon/go-metrics"
)
type CertificateCounter struct {
certCountEnabled *atomic.Bool
publishCertCountMetrics *atomic.Bool
certCount *atomic.Uint32
revokedCertCount *atomic.Uint32
certsCounted *atomic.Bool
certCountError error
possibleDoubleCountedSerials []string
possibleDoubleCountedRevokedSerials []string
backendUuid string
}
func (c *CertificateCounter) IsInitialized() bool {
return c.certsCounted.Load()
}
func (c *CertificateCounter) IsEnabled() bool {
return c.certCountEnabled.Load()
}
func (c *CertificateCounter) Error() error {
return c.certCountError
}
func (c *CertificateCounter) SetError(err error) {
c.certCountError = err
}
func (c *CertificateCounter) ReconfigureWithTidyConfig(config *tidyConfig) bool {
if config.MaintainCount {
c.enableCertCounting(config.PublishMetrics)
} else {
c.disableCertCounting()
}
return config.MaintainCount
}
func (c *CertificateCounter) disableCertCounting() {
c.possibleDoubleCountedRevokedSerials = nil
c.possibleDoubleCountedSerials = nil
c.certsCounted.Store(false)
c.certCount.Store(0)
c.revokedCertCount.Store(0)
c.certCountError = errors.New("Cert Count is Disabled: enable via Tidy Config maintain_stored_certificate_counts")
c.certCountEnabled.Store(false)
c.publishCertCountMetrics.Store(false)
}
func (c *CertificateCounter) enableCertCounting(publishMetrics bool) {
c.publishCertCountMetrics.Store(publishMetrics)
c.certCountEnabled.Store(true)
if !c.certsCounted.Load() {
c.certCountError = errors.New("Certificate Counting Has Not Been Initialized, re-initialize this mount")
}
}
func (c *CertificateCounter) InitializeCountsFromStorage(certs, revoked []string) {
c.certCount.Add(uint32(len(certs)))
c.revokedCertCount.Add(uint32(len(revoked)))
c.pruneDuplicates(certs, revoked)
c.certCountError = nil
c.certsCounted.Store(true)
c.emitTotalCertCountMetric()
}
func (c *CertificateCounter) pruneDuplicates(entries, revokedEntries []string) {
// Now that the metrics are set, we can switch from appending newly-stored certificates to the possible double-count
// list, and instead have them update the counter directly. We need to do this so that we are looking at a static
// slice of possibly double counted serials. Note that certsCounted is computed before the storage operation, so
// there may be some delay here.
// Sort the listed-entries first, to accommodate that delay.
sort.Slice(entries, func(i, j int) bool {
return entries[i] < entries[j]
})
sort.Slice(revokedEntries, func(i, j int) bool {
return revokedEntries[i] < revokedEntries[j]
})
// We assume here that these lists are now complete.
sort.Slice(c.possibleDoubleCountedSerials, func(i, j int) bool {
return c.possibleDoubleCountedSerials[i] < c.possibleDoubleCountedSerials[j]
})
listEntriesIndex := 0
possibleDoubleCountIndex := 0
for {
if listEntriesIndex >= len(entries) {
break
}
if possibleDoubleCountIndex >= len(c.possibleDoubleCountedSerials) {
break
}
if entries[listEntriesIndex] == c.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
// This represents a double-counted entry
c.decrementTotalCertificatesCountNoReport()
listEntriesIndex = listEntriesIndex + 1
possibleDoubleCountIndex = possibleDoubleCountIndex + 1
continue
}
if entries[listEntriesIndex] < c.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
listEntriesIndex = listEntriesIndex + 1
continue
}
if entries[listEntriesIndex] > c.possibleDoubleCountedSerials[possibleDoubleCountIndex] {
possibleDoubleCountIndex = possibleDoubleCountIndex + 1
continue
}
}
sort.Slice(c.possibleDoubleCountedRevokedSerials, func(i, j int) bool {
return c.possibleDoubleCountedRevokedSerials[i] < c.possibleDoubleCountedRevokedSerials[j]
})
listRevokedEntriesIndex := 0
possibleRevokedDoubleCountIndex := 0
for {
if listRevokedEntriesIndex >= len(revokedEntries) {
break
}
if possibleRevokedDoubleCountIndex >= len(c.possibleDoubleCountedRevokedSerials) {
break
}
if revokedEntries[listRevokedEntriesIndex] == c.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
// This represents a double-counted revoked entry
c.decrementTotalRevokedCertificatesCountNoReport()
listRevokedEntriesIndex = listRevokedEntriesIndex + 1
possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
continue
}
if revokedEntries[listRevokedEntriesIndex] < c.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
listRevokedEntriesIndex = listRevokedEntriesIndex + 1
continue
}
if revokedEntries[listRevokedEntriesIndex] > c.possibleDoubleCountedRevokedSerials[possibleRevokedDoubleCountIndex] {
possibleRevokedDoubleCountIndex = possibleRevokedDoubleCountIndex + 1
continue
}
}
c.possibleDoubleCountedRevokedSerials = nil
c.possibleDoubleCountedSerials = nil
}
func (c *CertificateCounter) decrementTotalCertificatesCountNoReport() uint32 {
newCount := c.certCount.Add(^uint32(0))
return newCount
}
func (c *CertificateCounter) decrementTotalRevokedCertificatesCountNoReport() uint32 {
newRevokedCertCount := c.revokedCertCount.Add(^uint32(0))
return newRevokedCertCount
}
func (c *CertificateCounter) CertificateCount() uint32 {
return c.certCount.Load()
}
func (c *CertificateCounter) RevokedCount() uint32 {
return c.revokedCertCount.Load()
}
func (c *CertificateCounter) IncrementTotalCertificatesCount(certsCounted bool, newSerial string) {
if c.certCountEnabled.Load() {
c.certCount.Add(1)
switch {
case !certsCounted:
// This is unsafe, but a good best-attempt
if strings.HasPrefix(newSerial, "certs/") {
newSerial = newSerial[6:]
}
c.possibleDoubleCountedSerials = append(c.possibleDoubleCountedSerials, newSerial)
default:
c.emitTotalCertCountMetric()
}
}
}
// The "certsCounted" boolean here should be loaded from the backend certsCounted before the corresponding storage call:
// eg. certsCounted := certCounter.IsInitialized()
func (c *CertificateCounter) IncrementTotalRevokedCertificatesCount(certsCounted bool, newSerial string) {
if c.certCountEnabled.Load() {
c.revokedCertCount.Add(1)
switch {
case !certsCounted:
// This is unsafe, but a good best-attempt
if strings.HasPrefix(newSerial, "revoked/") { // allow passing in the path (revoked/serial) OR the serial
newSerial = newSerial[8:]
}
c.possibleDoubleCountedRevokedSerials = append(c.possibleDoubleCountedRevokedSerials, newSerial)
default:
c.emitTotalRevokedCountMetric()
}
}
}
func (c *CertificateCounter) DecrementTotalCertificatesCountReport() {
if c.certCountEnabled.Load() {
c.decrementTotalCertificatesCountNoReport()
c.emitTotalCertCountMetric()
}
}
func (c *CertificateCounter) DecrementTotalRevokedCertificatesCountReport() {
if c.certCountEnabled.Load() {
c.decrementTotalRevokedCertificatesCountNoReport()
c.emitTotalRevokedCountMetric()
}
}
func (c *CertificateCounter) EmitCertStoreMetrics() {
c.emitTotalCertCountMetric()
c.emitTotalRevokedCountMetric()
}
func (c *CertificateCounter) emitTotalCertCountMetric() {
if c.publishCertCountMetrics.Load() {
certCount := float32(c.CertificateCount())
metrics.SetGauge([]string{"secrets", "pki", c.backendUuid, "total_certificates_stored"}, certCount)
}
}
func (c *CertificateCounter) emitTotalRevokedCountMetric() {
if c.publishCertCountMetrics.Load() {
revokedCount := float32(c.RevokedCount())
metrics.SetGauge([]string{"secrets", "pki", c.backendUuid, "total_revoked_certificates_stored"}, revokedCount)
}
}
func NewCertificateCounter(backendUuid string) *CertificateCounter {
counter := &CertificateCounter{
backendUuid: backendUuid,
certCountEnabled: &atomic.Bool{},
publishCertCountMetrics: &atomic.Bool{},
certCount: &atomic.Uint32{},
revokedCertCount: &atomic.Uint32{},
certsCounted: &atomic.Bool{},
certCountError: errors.New("Initialize Not Yet Run, Cert Counts Unavailable"),
possibleDoubleCountedSerials: make([]string, 0, 250),
possibleDoubleCountedRevokedSerials: make([]string, 0, 250),
}
return counter
}
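// Illustrative sketch (not part of the original change): the expected
// lifecycle of a CertificateCounter inside the backend. The backend UUID and
// serial below are placeholders.
func exampleCertCounterLifecycle(storedCerts, revokedCerts []string) *CertificateCounter {
	counter := NewCertificateCounter("hypothetical-backend-uuid")
	counter.ReconfigureWithTidyConfig(&tidyConfig{MaintainCount: true, PublishMetrics: true})
	// A write that races with initialization may be double counted; it is
	// remembered here and pruned once the storage listing is complete.
	counter.IncrementTotalCertificatesCount(counter.IsInitialized(), "certs/17-aa-bb")
	// Seed the counters from the full storage listing and prune duplicates.
	counter.InitializeCountsFromStorage(storedCerts, revokedCerts)
	return counter
}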

View File

@@ -0,0 +1,74 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package parsing
import (
"crypto/x509"
"fmt"
"strings"
)
func ParseCertificateFromString(pemCert string) (*x509.Certificate, error) {
return ParseCertificateFromBytes([]byte(pemCert))
}
func ParseCertificateFromBytes(certBytes []byte) (*x509.Certificate, error) {
block, err := DecodePem(certBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
return cert, nil
}
func ParseCertificatesFromString(pemCerts string) ([]*x509.Certificate, error) {
return ParseCertificatesFromBytes([]byte(pemCerts))
}
func ParseCertificatesFromBytes(certBytes []byte) ([]*x509.Certificate, error) {
block, err := DecodePem(certBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
cert, err := x509.ParseCertificates(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %w", err)
}
return cert, nil
}
func ParseKeyUsages(input []string) int {
var parsedKeyUsages x509.KeyUsage
for _, k := range input {
switch strings.ToLower(strings.TrimSpace(k)) {
case "digitalsignature":
parsedKeyUsages |= x509.KeyUsageDigitalSignature
case "contentcommitment":
parsedKeyUsages |= x509.KeyUsageContentCommitment
case "keyencipherment":
parsedKeyUsages |= x509.KeyUsageKeyEncipherment
case "dataencipherment":
parsedKeyUsages |= x509.KeyUsageDataEncipherment
case "keyagreement":
parsedKeyUsages |= x509.KeyUsageKeyAgreement
case "certsign":
parsedKeyUsages |= x509.KeyUsageCertSign
case "crlsign":
parsedKeyUsages |= x509.KeyUsageCRLSign
case "encipheronly":
parsedKeyUsages |= x509.KeyUsageEncipherOnly
case "decipheronly":
parsedKeyUsages |= x509.KeyUsageDecipherOnly
}
}
return int(parsedKeyUsages)
}
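// Illustrative sketch (not part of the original change): parsing a PEM
// certificate and translating role-style key usage names into the x509
// bitmask. The function name is hypothetical.
func exampleParse(pemCert string) (*x509.Certificate, x509.KeyUsage, error) {
	cert, err := ParseCertificateFromString(pemCert)
	if err != nil {
		return nil, 0, err
	}
	// ParseKeyUsages lower-cases and trims its input, so role-style casing works.
	usages := x509.KeyUsage(ParseKeyUsages([]string{"DigitalSignature", "KeyEncipherment"}))
	return cert, usages, nil
}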

View File

@@ -0,0 +1,27 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package parsing
import (
"crypto/x509"
"fmt"
)
func ParseCertificateRequestFromString(pemCert string) (*x509.CertificateRequest, error) {
return ParseCertificateRequestFromBytes([]byte(pemCert))
}
func ParseCertificateRequestFromBytes(certBytes []byte) (*x509.CertificateRequest, error) {
block, err := DecodePem(certBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate request: %w", err)
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate request: %w", err)
}
return csr, nil
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package parsing
import (
"encoding/pem"
"errors"
"strings"
)
func DecodePem(certBytes []byte) (*pem.Block, error) {
block, extra := pem.Decode(certBytes)
if block == nil {
return nil, errors.New("invalid PEM")
}
if len(strings.TrimSpace(string(extra))) > 0 {
return nil, errors.New("trailing PEM data")
}
return block, nil
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -222,7 +221,7 @@ func (b *backend) acmeAccountSearchHandler(acmeCtx *acmeContext, userCtx *jwsCtx
return nil, fmt.Errorf("failed generating thumbprint for key: %w", err)
}
- account, err := b.acmeState.LoadAccountByKey(acmeCtx, thumbprint)
+ account, err := b.GetAcmeState().LoadAccountByKey(acmeCtx, thumbprint)
if err != nil {
return nil, fmt.Errorf("failed to load account by thumbprint: %w", err)
}
@@ -253,7 +252,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
return nil, fmt.Errorf("failed generating thumbprint for key: %w", err)
}
- accountByKey, err := b.acmeState.LoadAccountByKey(acmeCtx, thumbprint)
+ accountByKey, err := b.GetAcmeState().LoadAccountByKey(acmeCtx, thumbprint)
if err != nil {
return nil, fmt.Errorf("failed to load account by thumbprint: %w", err)
}
@@ -267,7 +266,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
var eab *eabType
if len(eabData) != 0 {
- eab, err = verifyEabPayload(b.acmeState, acmeCtx, userCtx, r.Path, eabData)
+ eab, err = verifyEabPayload(b.GetAcmeState(), acmeCtx, userCtx, r.Path, eabData)
if err != nil {
return nil, err
}
@@ -288,7 +287,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
// We delete the EAB to prevent future re-use after associating it with an account, worst
// case if we fail creating the account we simply nuked the EAB which they can create another
// and retry
- wasDeleted, err := b.acmeState.DeleteEab(acmeCtx.sc, eab.KeyID)
+ wasDeleted, err := b.GetAcmeState().DeleteEab(acmeCtx.sc, eab.KeyID)
if err != nil {
return nil, fmt.Errorf("failed to delete eab reference: %w", err)
}
@@ -302,7 +301,7 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, userCtx *jws
b.acmeAccountLock.RLock() // Prevents Account Creation and Tidy Interfering
defer b.acmeAccountLock.RUnlock()
- accountByKid, err := b.acmeState.CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed, eab)
+ accountByKid, err := b.GetAcmeState().CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed, eab)
if err != nil {
if eab != nil {
return nil, fmt.Errorf("failed to create account: %w; the EAB key used for this request has been deleted as a result of this operation; fetch a new EAB key before retrying", err)
@@ -329,7 +328,7 @@ func (b *backend) acmeNewAccountUpdateHandler(acmeCtx *acmeContext, userCtx *jws
return nil, fmt.Errorf("%w: not allowed to update EAB data in accounts", ErrMalformed)
}
- account, err := b.acmeState.LoadAccount(acmeCtx, userCtx.Kid)
+ account, err := b.GetAcmeState().LoadAccount(acmeCtx, userCtx.Kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
@@ -363,7 +362,7 @@ func (b *backend) acmeNewAccountUpdateHandler(acmeCtx *acmeContext, userCtx *jws
}
if shouldUpdate {
- err = b.acmeState.UpdateAccount(acmeCtx.sc, account)
+ err = b.GetAcmeState().UpdateAccount(acmeCtx.sc, account)
if err != nil {
return nil, fmt.Errorf("failed to update account: %w", err) return nil, fmt.Errorf("failed to update account: %w", err)
} }

View File

@@ -48,7 +48,7 @@ func patternAcmeAuthorization(b *backend, pattern string, opts acmeWrapperOpts)
func (b *backend) acmeAuthorizationHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, _ *acmeAccount) (*logical.Response, error) { func (b *backend) acmeAuthorizationHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, _ *acmeAccount) (*logical.Response, error) {
authId := fields.Get("auth_id").(string) authId := fields.Get("auth_id").(string)
authz, err := b.acmeState.LoadAuthorization(acmeCtx, userCtx, authId) authz, err := b.GetAcmeState().LoadAuthorization(acmeCtx, userCtx, authId)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load authorization: %w", err) return nil, fmt.Errorf("failed to load authorization: %w", err)
} }
@@ -90,7 +90,7 @@ func (b *backend) acmeAuthorizationDeactivateHandler(acmeCtx *acmeContext, r *lo
challenge.Status = ACMEChallengeInvalid challenge.Status = ACMEChallengeInvalid
} }
if err := b.acmeState.SaveAuthorization(acmeCtx, authz); err != nil { if err := b.GetAcmeState().SaveAuthorization(acmeCtx, authz); err != nil {
return nil, fmt.Errorf("error saving deactivated authorization: %w", err) return nil, fmt.Errorf("error saving deactivated authorization: %w", err)
} }

View File

@@ -57,7 +57,7 @@ func (b *backend) acmeChallengeHandler(acmeCtx *acmeContext, r *logical.Request,
authId := fields.Get("auth_id").(string) authId := fields.Get("auth_id").(string)
challengeType := fields.Get("challenge_type").(string) challengeType := fields.Get("challenge_type").(string)
authz, err := b.acmeState.LoadAuthorization(acmeCtx, userCtx, authId) authz, err := b.GetAcmeState().LoadAuthorization(acmeCtx, userCtx, authId)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load authorization: %w", err) return nil, fmt.Errorf("failed to load authorization: %w", err)
} }
@@ -95,7 +95,7 @@ func (b *backend) acmeChallengeFetchHandler(acmeCtx *acmeContext, r *logical.Req
return nil, fmt.Errorf("failed to get thumbprint for key: %w", err) return nil, fmt.Errorf("failed to get thumbprint for key: %w", err)
} }
if err := b.acmeState.validator.AcceptChallenge(acmeCtx.sc, userCtx.Kid, authz, challenge, thumbprint); err != nil { if err := b.GetAcmeState().validator.AcceptChallenge(acmeCtx.sc, userCtx.Kid, authz, challenge, thumbprint); err != nil {
return nil, fmt.Errorf("error submitting challenge for validation: %w", err) return nil, fmt.Errorf("error submitting challenge for validation: %w", err)
} }
} }

View File

@@ -183,7 +183,8 @@ type eabType struct {
func (b *backend) pathAcmeListEab(ctx context.Context, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathAcmeListEab(ctx context.Context, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage) sc := b.makeStorageContext(ctx, r.Storage)
eabIds, err := b.acmeState.ListEabIds(sc) acmeState := b.GetAcmeState()
eabIds, err := acmeState.ListEabIds(sc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -193,7 +194,7 @@ func (b *backend) pathAcmeListEab(ctx context.Context, r *logical.Request, _ *fr
keyInfos := map[string]interface{}{} keyInfos := map[string]interface{}{}
for _, eabKey := range eabIds { for _, eabKey := range eabIds {
eab, err := b.acmeState.LoadEab(sc, eabKey) eab, err := acmeState.LoadEab(sc, eabKey)
if err != nil { if err != nil {
warnings = append(warnings, fmt.Sprintf("failed loading eab entry %s: %v", eabKey, err)) warnings = append(warnings, fmt.Sprintf("failed loading eab entry %s: %v", eabKey, err))
continue continue
@@ -236,7 +237,7 @@ func (b *backend) pathAcmeCreateEab(ctx context.Context, r *logical.Request, dat
} }
sc := b.makeStorageContext(ctx, r.Storage) sc := b.makeStorageContext(ctx, r.Storage)
err = b.acmeState.SaveEab(sc, eab) err = b.GetAcmeState().SaveEab(sc, eab)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed saving generated eab: %w", err) return nil, fmt.Errorf("failed saving generated eab: %w", err)
} }
@@ -263,7 +264,7 @@ func (b *backend) pathAcmeDeleteEab(ctx context.Context, r *logical.Request, d *
return nil, fmt.Errorf("badly formatted key_id field") return nil, fmt.Errorf("badly formatted key_id field")
} }
deleted, err := b.acmeState.DeleteEab(sc, keyId) deleted, err := b.GetAcmeState().DeleteEab(sc, keyId)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed deleting key id: %w", err) return nil, fmt.Errorf("failed deleting key id: %w", err)
} }

View File

@@ -41,7 +41,7 @@ func patternAcmeNonce(b *backend, pattern string, opts acmeWrapperOpts) *framewo
} }
func (b *backend) acmeNonceHandler(ctx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) acmeNonceHandler(ctx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
nonce, _, err := b.acmeState.GetNonce() nonce, _, err := b.GetAcmeState().GetNonce()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -21,6 +21,8 @@ import (
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/idna" "golang.org/x/net/idna"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
var maxAcmeCertTTL = 90 * (24 * time.Hour) var maxAcmeCertTTL = 90 * (24 * time.Hour)
@@ -164,7 +166,7 @@ func addFieldsForACMEOrder(fields map[string]*framework.FieldSchema) {
func (b *backend) acmeFetchCertOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}, _ *acmeAccount) (*logical.Response, error) { func (b *backend) acmeFetchCertOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}, _ *acmeAccount) (*logical.Response, error) {
orderId := fields.Get("order_id").(string) orderId := fields.Get("order_id").(string)
order, err := b.acmeState.LoadOrder(ac, uc, orderId) order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -232,7 +234,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
return nil, err return nil, err
} }
order, err := b.acmeState.LoadOrder(ac, uc, orderId) order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -260,7 +262,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
} }
var signedCertBundle *certutil.ParsedCertBundle var signedCertBundle *certutil.ParsedCertBundle
var issuerId issuerID var issuerId issuing.IssuerID
if ac.runtimeOpts.isCiepsEnabled { if ac.runtimeOpts.isCiepsEnabled {
// Note that issueAcmeCertUsingCieps enforces storage requirements and // Note that issueAcmeCertUsingCieps enforces storage requirements and
// does the certificate storage for us // does the certificate storage for us
@@ -281,7 +283,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
} }
hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber) hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber)
if err := b.acmeState.TrackIssuedCert(ac, order.AccountId, hyphenSerialNumber, order.OrderId); err != nil { if err := b.GetAcmeState().TrackIssuedCert(ac, order.AccountId, hyphenSerialNumber, order.OrderId); err != nil {
b.Logger().Warn("orphaned generated ACME certificate due to error saving account->cert->order reference", "serial_number", hyphenSerialNumber, "error", err) b.Logger().Warn("orphaned generated ACME certificate due to error saving account->cert->order reference", "serial_number", hyphenSerialNumber, "error", err)
return nil, err return nil, err
} }
@@ -291,7 +293,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
order.CertificateExpiry = signedCertBundle.Certificate.NotAfter order.CertificateExpiry = signedCertBundle.Certificate.NotAfter
order.IssuerId = issuerId order.IssuerId = issuerId
err = b.acmeState.SaveOrder(ac, order) err = b.GetAcmeState().SaveOrder(ac, order)
if err != nil { if err != nil {
b.Logger().Warn("orphaned generated ACME certificate due to error saving order", "serial_number", hyphenSerialNumber, "error", err) b.Logger().Warn("orphaned generated ACME certificate due to error saving order", "serial_number", hyphenSerialNumber, "error", err)
return nil, fmt.Errorf("failed saving updated order: %w", err) return nil, fmt.Errorf("failed saving updated order: %w", err)
@@ -413,7 +415,7 @@ func validateCsrMatchesOrder(csr *x509.CertificateRequest, order *acmeOrder) err
return nil return nil
} }
func (b *backend) validateIdentifiersAgainstRole(role *roleEntry, identifiers []*ACMEIdentifier) error { func (b *backend) validateIdentifiersAgainstRole(role *issuing.RoleEntry, identifiers []*ACMEIdentifier) error {
for _, identifier := range identifiers { for _, identifier := range identifiers {
switch identifier.Type { switch identifier.Type {
case ACMEDNSIdentifier: case ACMEDNSIdentifier:
@@ -480,7 +482,8 @@ func removeDuplicatesAndSortIps(ipIdentifiers []net.IP) []net.IP {
func storeCertificate(sc *storageContext, signedCertBundle *certutil.ParsedCertBundle) error { func storeCertificate(sc *storageContext, signedCertBundle *certutil.ParsedCertBundle) error {
hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber) hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber)
key := "certs/" + hyphenSerialNumber key := "certs/" + hyphenSerialNumber
certsCounted := sc.Backend.certsCounted.Load() certCounter := sc.Backend.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err := sc.Storage.Put(sc.Context, &logical.StorageEntry{ err := sc.Storage.Put(sc.Context, &logical.StorageEntry{
Key: key, Key: key,
Value: signedCertBundle.CertificateBytes, Value: signedCertBundle.CertificateBytes,
@@ -488,7 +491,7 @@ func storeCertificate(sc *storageContext, signedCertBundle *certutil.ParsedCertB
if err != nil { if err != nil {
return fmt.Errorf("unable to store certificate locally: %w", err) return fmt.Errorf("unable to store certificate locally: %w", err)
} }
sc.Backend.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key) certCounter.IncrementTotalCertificatesCount(certsCounted, key)
return nil return nil
} }
@@ -520,7 +523,7 @@ func maybeAugmentReqDataWithSuitableCN(ac *acmeContext, csr *x509.CertificateReq
} }
} }
func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuerID, error) { func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, issuing.IssuerID, error) {
pemBlock := &pem.Block{ pemBlock := &pem.Block{
Type: "CERTIFICATE REQUEST", Type: "CERTIFICATE REQUEST",
Headers: nil, Headers: nil,
@@ -540,7 +543,7 @@ func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.
// (TLS) clients are mostly verifying against server's DNS SANs. // (TLS) clients are mostly verifying against server's DNS SANs.
maybeAugmentReqDataWithSuitableCN(ac, csr, data) maybeAugmentReqDataWithSuitableCN(ac, csr, data)
signingBundle, issuerId, err := ac.sc.fetchCAInfoWithIssuer(ac.issuer.ID.String(), IssuanceUsage) signingBundle, issuerId, err := ac.sc.fetchCAInfoWithIssuer(ac.issuer.ID.String(), issuing.IssuanceUsage)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed loading CA %s: %w", ac.issuer.ID.String(), err) return nil, "", fmt.Errorf("failed loading CA %s: %w", ac.issuer.ID.String(), err)
} }
@@ -595,7 +598,7 @@ func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.
// We only allow ServerAuth key usage from ACME issued certs // We only allow ServerAuth key usage from ACME issued certs
// when configuration does not allow usage of ExtKeyusage field. // when configuration does not allow usage of ExtKeyusage field.
config, err := ac.sc.Backend.acmeState.getConfigWithUpdate(ac.sc) config, err := ac.sc.Backend.GetAcmeState().getConfigWithUpdate(ac.sc)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to fetch ACME configuration: %w", err) return nil, "", fmt.Errorf("failed to fetch ACME configuration: %w", err)
} }
@@ -655,7 +658,7 @@ func parseCsrFromFinalize(data map[string]interface{}) (*x509.CertificateRequest
func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, _ *acmeAccount) (*logical.Response, error) { func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, _ *acmeAccount) (*logical.Response, error) {
orderId := fields.Get("order_id").(string) orderId := fields.Get("order_id").(string)
order, err := b.acmeState.LoadOrder(ac, uc, orderId) order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -674,7 +677,7 @@ func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, field
filteredAuthorizationIds := []string{} filteredAuthorizationIds := []string{}
for _, authId := range order.AuthorizationIds { for _, authId := range order.AuthorizationIds {
authorization, err := b.acmeState.LoadAuthorization(ac, uc, authId) authorization, err := b.GetAcmeState().LoadAuthorization(ac, uc, authId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -692,14 +695,14 @@ func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, field
} }
func (b *backend) acmeListOrdersHandler(ac *acmeContext, _ *logical.Request, _ *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, acct *acmeAccount) (*logical.Response, error) { func (b *backend) acmeListOrdersHandler(ac *acmeContext, _ *logical.Request, _ *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, acct *acmeAccount) (*logical.Response, error) {
orderIds, err := b.acmeState.ListOrderIds(ac.sc, acct.KeyId) orderIds, err := b.GetAcmeState().ListOrderIds(ac.sc, acct.KeyId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
orderUrls := []string{} orderUrls := []string{}
for _, orderId := range orderIds { for _, orderId := range orderIds {
order, err := b.acmeState.LoadOrder(ac, uc, orderId) order, err := b.GetAcmeState().LoadOrder(ac, uc, orderId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -771,7 +774,7 @@ func (b *backend) acmeNewOrderHandler(ac *acmeContext, _ *logical.Request, _ *fr
} }
authorizations = append(authorizations, authz) authorizations = append(authorizations, authz)
err = b.acmeState.SaveAuthorization(ac, authz) err = b.GetAcmeState().SaveAuthorization(ac, authz)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed storing authorization: %w", err) return nil, fmt.Errorf("failed storing authorization: %w", err)
} }
@@ -788,7 +791,7 @@ func (b *backend) acmeNewOrderHandler(ac *acmeContext, _ *logical.Request, _ *fr
AuthorizationIds: authorizationIds, AuthorizationIds: authorizationIds,
} }
err = b.acmeState.SaveOrder(ac, order) err = b.GetAcmeState().SaveOrder(ac, order)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed storing order: %w", err) return nil, fmt.Errorf("failed storing order: %w", err)
} }

View File

@@ -10,6 +10,8 @@ import (
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
// TestACME_ValidateIdentifiersAgainstRole Verify the ACME order creation // TestACME_ValidateIdentifiersAgainstRole Verify the ACME order creation
@@ -20,13 +22,13 @@ func TestACME_ValidateIdentifiersAgainstRole(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
role *roleEntry role *issuing.RoleEntry
identifiers []*ACMEIdentifier identifiers []*ACMEIdentifier
expectErr bool expectErr bool
}{ }{
{ {
name: "verbatim-role-allows-dns-ip", name: "verbatim-role-allows-dns-ip",
role: buildSignVerbatimRoleWithNoData(nil), role: issuing.SignVerbatimRole(),
identifiers: _buildACMEIdentifiers("test.com", "127.0.0.1"), identifiers: _buildACMEIdentifiers("test.com", "127.0.0.1"),
expectErr: false, expectErr: false,
}, },
@@ -119,7 +121,7 @@ func _buildACMEIdentifier(val string) *ACMEIdentifier {
// Easily allow tests to create valid roles with proper defaults, since we don't have an easy // Easily allow tests to create valid roles with proper defaults, since we don't have an easy
// way to generate roles with proper defaults, go through the createRole handler with the handlers // way to generate roles with proper defaults, go through the createRole handler with the handlers
// field data so we pickup all the defaults specified there. // field data so we pickup all the defaults specified there.
func buildTestRole(t *testing.T, config map[string]interface{}) *roleEntry { func buildTestRole(t *testing.T, config map[string]interface{}) *issuing.RoleEntry {
b, s := CreateBackendWithStorage(t) b, s := CreateBackendWithStorage(t)
path := pathRoles(b) path := pathRoles(b)
@@ -135,7 +137,7 @@ func buildTestRole(t *testing.T, config map[string]interface{}) *roleEntry {
_, err := b.pathRoleCreate(ctx, &logical.Request{Storage: s}, &framework.FieldData{Raw: config, Schema: fields}) _, err := b.pathRoleCreate(ctx, &logical.Request{Storage: s}, &framework.FieldData{Raw: config, Schema: fields})
require.NoError(t, err, "failed generating role with config %v", config) require.NoError(t, err, "failed generating role with config %v", config)
role, err := b.getRole(ctx, s, config["name"].(string)) role, err := b.GetRole(ctx, s, config["name"].(string))
require.NoError(t, err, "failed loading stored role") require.NoError(t, err, "failed loading stored role")
return role return role

View File

@@ -82,7 +82,7 @@ func (b *backend) acmeRevocationHandler(acmeCtx *acmeContext, _ *logical.Request
// Fetch the CRL config as we need it to ultimately do the // Fetch the CRL config as we need it to ultimately do the
// revocation. This should be cached and thus relatively fast. // revocation. This should be cached and thus relatively fast.
config, err := b.crlBuilder.getConfigWithUpdate(acmeCtx.sc) config, err := b.CrlBuilder().getConfigWithUpdate(acmeCtx.sc)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to revoke certificate: failed reading revocation config: %v: %w", err, ErrServerInternal) return nil, fmt.Errorf("unable to revoke certificate: failed reading revocation config: %v: %w", err, ErrServerInternal)
} }
@@ -153,8 +153,8 @@ func (b *backend) acmeRevocationByPoP(acmeCtx *acmeContext, userCtx *jwsCtx, cer
} }
// Now it is safe to revoke. // Now it is safe to revoke.
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
return revokeCert(acmeCtx.sc, config, cert) return revokeCert(acmeCtx.sc, config, cert)
} }
@@ -169,14 +169,14 @@ func (b *backend) acmeRevocationByAccount(acmeCtx *acmeContext, userCtx *jwsCtx,
// We only support certificates issued by this user, we don't support // We only support certificates issued by this user, we don't support
// cross-account revocations. // cross-account revocations.
serial := serialFromCert(cert) serial := serialFromCert(cert)
acmeEntry, err := b.acmeState.GetIssuedCert(acmeCtx, userCtx.Kid, serial) acmeEntry, err := b.GetAcmeState().GetIssuedCert(acmeCtx, userCtx.Kid, serial)
if err != nil || acmeEntry == nil { if err != nil || acmeEntry == nil {
return nil, fmt.Errorf("unable to revoke certificate: %v: %w", err, ErrMalformed) return nil, fmt.Errorf("unable to revoke certificate: %v: %w", err, ErrMalformed)
} }
// Now it is safe to revoke. // Now it is safe to revoke.
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
return revokeCert(acmeCtx.sc, config, cert) return revokeCert(acmeCtx.sc, config, cert)
} }

View File

@@ -157,7 +157,7 @@ func pathAcmeConfig(b *backend) *framework.Path {
func (b *backend) pathAcmeRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathAcmeRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.acmeState.getConfigWithForcedUpdate(sc) config, err := b.GetAcmeState().getConfigWithForcedUpdate(sc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -195,7 +195,7 @@ func genResponseFromAcmeConfig(config *acmeConfigEntry, warnings []string) *logi
func (b *backend) pathAcmeWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { func (b *backend) pathAcmeWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.acmeState.getConfigWithForcedUpdate(sc) config, err := b.GetAcmeState().getConfigWithForcedUpdate(sc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -337,7 +337,7 @@ func (b *backend) pathAcmeWrite(ctx context.Context, req *logical.Request, d *fr
} }
} }
if _, err := b.acmeState.writeConfig(sc, config); err != nil { if _, err := b.GetAcmeState().writeConfig(sc, config); err != nil {
return nil, fmt.Errorf("failed persisting: %w", err) return nil, fmt.Errorf("failed persisting: %w", err)
} }

View File

@@ -7,6 +7,7 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -212,7 +213,7 @@ func pathReplaceRoot(b *backend) *framework.Path {
} }
func (b *backend) pathCAIssuersRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathCAIssuersRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot read defaults until migration has completed"), nil return logical.ErrorResponse("Cannot read defaults until migration has completed"), nil
} }
@@ -225,7 +226,7 @@ func (b *backend) pathCAIssuersRead(ctx context.Context, req *logical.Request, _
return b.formatCAIssuerConfigRead(config), nil return b.formatCAIssuerConfigRead(config), nil
} }
func (b *backend) formatCAIssuerConfigRead(config *issuerConfigEntry) *logical.Response { func (b *backend) formatCAIssuerConfigRead(config *issuing.IssuerConfigEntry) *logical.Response {
return &logical.Response{ return &logical.Response{
Data: map[string]interface{}{ Data: map[string]interface{}{
defaultRef: config.DefaultIssuerId, defaultRef: config.DefaultIssuerId,
@@ -240,7 +241,7 @@ func (b *backend) pathCAIssuersWrite(ctx context.Context, req *logical.Request,
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot update defaults until migration has completed"), nil return logical.ErrorResponse("Cannot update defaults until migration has completed"), nil
} }
@@ -370,7 +371,7 @@ func pathConfigKeys(b *backend) *framework.Path {
} }
func (b *backend) pathKeyDefaultRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathKeyDefaultRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot read key defaults until migration has completed"), nil return logical.ErrorResponse("Cannot read key defaults until migration has completed"), nil
} }
@@ -393,7 +394,7 @@ func (b *backend) pathKeyDefaultWrite(ctx context.Context, req *logical.Request,
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot update key defaults until migration has completed"), nil return logical.ErrorResponse("Cannot update key defaults until migration has completed"), nil
} }

View File

@@ -275,7 +275,7 @@ existing CRL and OCSP paths will return the unified CRL instead of a response ba
func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.crlBuilder.getConfigWithForcedUpdate(sc) config, err := b.CrlBuilder().getConfigWithForcedUpdate(sc)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed fetching CRL config: %w", err) return nil, fmt.Errorf("failed fetching CRL config: %w", err)
} }
@@ -285,7 +285,7 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, _ *fram
func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
config, err := b.crlBuilder.getConfigWithForcedUpdate(sc) config, err := b.CrlBuilder().getConfigWithForcedUpdate(sc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -410,7 +410,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
return logical.ErrorResponse("unified_crl=true requires auto_rebuild=true, as unified CRLs cannot be rebuilt on every revocation."), nil return logical.ErrorResponse("unified_crl=true requires auto_rebuild=true, as unified CRLs cannot be rebuilt on every revocation."), nil
} }
if _, err := b.crlBuilder.writeConfig(sc, config); err != nil { if _, err := b.CrlBuilder().writeConfig(sc, config); err != nil {
return nil, fmt.Errorf("failed persisting CRL config: %w", err) return nil, fmt.Errorf("failed persisting CRL config: %w", err)
} }
@@ -418,13 +418,13 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
// Note this only affects/happens on the main cluster node, if you need to // Note this only affects/happens on the main cluster node, if you need to
// notify something based on a configuration change on all server types // notify something based on a configuration change on all server types
// have a look at crlBuilder::reloadConfigIfRequired // have a look at CrlBuilder::reloadConfigIfRequired
if oldDisable != config.Disable || (oldAutoRebuild && !config.AutoRebuild) || (oldEnableDelta != config.EnableDelta) || (oldUnifiedCRL != config.UnifiedCRL) { if oldDisable != config.Disable || (oldAutoRebuild && !config.AutoRebuild) || (oldEnableDelta != config.EnableDelta) || (oldUnifiedCRL != config.UnifiedCRL) {
// It wasn't disabled but now it is (or equivalently, we were set to // It wasn't disabled but now it is (or equivalently, we were set to
// auto-rebuild and we aren't now or equivalently, we changed our // auto-rebuild and we aren't now or equivalently, we changed our
// mind about delta CRLs and need a new complete one or equivalently, // mind about delta CRLs and need a new complete one or equivalently,
// we changed our mind about unified CRLs), rotate the CRLs. // we changed our mind about unified CRLs), rotate the CRLs.
warnings, crlErr := b.crlBuilder.rebuild(sc, true) warnings, crlErr := b.CrlBuilder().rebuild(sc, true)
if crlErr != nil { if crlErr != nil {
switch crlErr.(type) { switch crlErr.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -7,9 +7,8 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/asaskevich/govalidator" "github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -140,23 +139,13 @@ set on all PR Secondary clusters.`,
} }
} }
func validateURLs(urls []string) string { func getGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*issuing.AiaConfigEntry, error) {
for _, curr := range urls {
if !govalidator.IsURL(curr) || strings.Contains(curr, "{{issuer_id}}") || strings.Contains(curr, "{{cluster_path}}") || strings.Contains(curr, "{{cluster_aia_path}}") {
return curr
}
}
return ""
}
func getGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*aiaConfigEntry, error) {
entry, err := storage.Get(ctx, "urls") entry, err := storage.Get(ctx, "urls")
if err != nil { if err != nil {
return nil, err return nil, err
} }
entries := &aiaConfigEntry{ entries := &issuing.AiaConfigEntry{
IssuingCertificates: []string{}, IssuingCertificates: []string{},
CRLDistributionPoints: []string{}, CRLDistributionPoints: []string{},
OCSPServers: []string{}, OCSPServers: []string{},
@@ -174,7 +163,7 @@ func getGlobalAIAURLs(ctx context.Context, storage logical.Storage) (*aiaConfigE
return entries, nil return entries, nil
} }
func writeURLs(ctx context.Context, storage logical.Storage, entries *aiaConfigEntry) error { func writeURLs(ctx context.Context, storage logical.Storage, entries *issuing.AiaConfigEntry) error {
entry, err := logical.StorageEntryJSON("urls", entries) entry, err := logical.StorageEntryJSON("urls", entries)
if err != nil { if err != nil {
return err return err
@@ -237,7 +226,7 @@ func (b *backend) pathWriteURL(ctx context.Context, req *logical.Request, data *
}, },
} }
if entries.EnableTemplating && !b.useLegacyBundleCaStorage() { if entries.EnableTemplating && !b.UseLegacyBundleCaStorage() {
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
issuers, err := sc.listIssuers() issuers, err := sc.listIssuers()
if err != nil { if err != nil {
@@ -250,23 +239,23 @@ func (b *backend) pathWriteURL(ctx context.Context, req *logical.Request, data *
return nil, fmt.Errorf("unable to read issuer to validate templated URIs: %w", err) return nil, fmt.Errorf("unable to read issuer to validate templated URIs: %w", err)
} }
_, err = entries.toURLEntries(sc, issuer.ID) _, err = ToURLEntries(sc, issuer.ID, entries)
if err != nil { if err != nil {
resp.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", err)) resp.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", err))
} }
} }
} else if !entries.EnableTemplating { } else if !entries.EnableTemplating {
if badURL := validateURLs(entries.IssuingCertificates); badURL != "" { if badURL := issuing.ValidateURLs(entries.IssuingCertificates); badURL != "" {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse(fmt.Sprintf(
"invalid URL found in Authority Information Access (AIA) parameter issuing_certificates: %s", badURL)), nil "invalid URL found in Authority Information Access (AIA) parameter issuing_certificates: %s", badURL)), nil
} }
if badURL := validateURLs(entries.CRLDistributionPoints); badURL != "" { if badURL := issuing.ValidateURLs(entries.CRLDistributionPoints); badURL != "" {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse(fmt.Sprintf(
"invalid URL found in Authority Information Access (AIA) parameter crl_distribution_points: %s", badURL)), nil "invalid URL found in Authority Information Access (AIA) parameter crl_distribution_points: %s", badURL)), nil
} }
if badURL := validateURLs(entries.OCSPServers); badURL != "" { if badURL := issuing.ValidateURLs(entries.OCSPServers); badURL != "" {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse(fmt.Sprintf(
"invalid URL found in Authority Information Access (AIA) parameter ocsp_servers: %s", badURL)), nil "invalid URL found in Authority Information Access (AIA) parameter ocsp_servers: %s", badURL)), nil
} }

View File

@@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/helper/constants" "github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
@@ -306,7 +307,7 @@ func (b *backend) pathFetchRead(ctx context.Context, req *logical.Request, data
contentType = "application/pkix-cert" contentType = "application/pkix-cert"
} }
case req.Path == "crl" || req.Path == "crl/pem" || req.Path == "crl/delta" || req.Path == "crl/delta/pem" || req.Path == "cert/crl" || req.Path == "cert/crl/raw" || req.Path == "cert/crl/raw/pem" || req.Path == "cert/delta-crl" || req.Path == "cert/delta-crl/raw" || req.Path == "cert/delta-crl/raw/pem" || req.Path == "unified-crl" || req.Path == "unified-crl/pem" || req.Path == "unified-crl/delta" || req.Path == "unified-crl/delta/pem" || req.Path == "cert/unified-crl" || req.Path == "cert/unified-crl/raw" || req.Path == "cert/unified-crl/raw/pem" || req.Path == "cert/unified-delta-crl" || req.Path == "cert/unified-delta-crl/raw" || req.Path == "cert/unified-delta-crl/raw/pem": case req.Path == "crl" || req.Path == "crl/pem" || req.Path == "crl/delta" || req.Path == "crl/delta/pem" || req.Path == "cert/crl" || req.Path == "cert/crl/raw" || req.Path == "cert/crl/raw/pem" || req.Path == "cert/delta-crl" || req.Path == "cert/delta-crl/raw" || req.Path == "cert/delta-crl/raw/pem" || req.Path == "unified-crl" || req.Path == "unified-crl/pem" || req.Path == "unified-crl/delta" || req.Path == "unified-crl/delta/pem" || req.Path == "cert/unified-crl" || req.Path == "cert/unified-crl/raw" || req.Path == "cert/unified-crl/raw/pem" || req.Path == "cert/unified-delta-crl" || req.Path == "cert/unified-delta-crl/raw" || req.Path == "cert/unified-delta-crl/raw/pem":
config, err := b.crlBuilder.getConfigWithUpdate(sc) config, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil { if err != nil {
retErr = err retErr = err
goto reply goto reply
@@ -370,7 +371,7 @@ func (b *backend) pathFetchRead(ctx context.Context, req *logical.Request, data
// Prefer fetchCAInfo to fetchCertBySerial for CA certificates. // Prefer fetchCAInfo to fetchCertBySerial for CA certificates.
if serial == "ca_chain" || serial == "ca" { if serial == "ca_chain" || serial == "ca" {
caInfo, err := sc.fetchCAInfo(defaultRef, ReadOnlyUsage) caInfo, err := sc.fetchCAInfo(defaultRef, issuing.ReadOnlyUsage)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -55,7 +56,7 @@ func pathListIssuers(b *backend) *framework.Path {
} }
func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathListIssuersHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not list issuers until migration has completed"), nil return logical.ErrorResponse("Can not list issuers until migration has completed"), nil
} }
@@ -398,11 +399,11 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data
return b.pathGetRawIssuer(ctx, req, data) return b.pathGetRawIssuer(ctx, req, data)
} }
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get issuer until migration has completed"), nil return logical.ErrorResponse("Can not get issuer until migration has completed"), nil
} }
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -424,7 +425,7 @@ func (b *backend) pathGetIssuer(ctx context.Context, req *logical.Request, data
return respondReadIssuer(issuer) return respondReadIssuer(issuer)
} }
func respondReadIssuer(issuer *issuerEntry) (*logical.Response, error) { func respondReadIssuer(issuer *issuing.IssuerEntry) (*logical.Response, error) {
var respManualChain []string var respManualChain []string
for _, entity := range issuer.ManualChain { for _, entity := range issuer.ManualChain {
respManualChain = append(respManualChain, string(entity)) respManualChain = append(respManualChain, string(entity))
@@ -483,11 +484,11 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not update issuer until migration has completed"), nil return logical.ErrorResponse("Can not update issuer until migration has completed"), nil
} }
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -537,9 +538,9 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
} }
rawUsage := data.Get("usage").([]string) rawUsage := data.Get("usage").([]string)
newUsage, err := NewIssuerUsageFromNames(rawUsage) newUsage, err := issuing.NewIssuerUsageFromNames(rawUsage)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, AllIssuerUsages.Names())), nil return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, issuing.AllIssuerUsages.Names())), nil
} }
// Revocation signature algorithm changes // Revocation signature algorithm changes
@@ -562,15 +563,15 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
// AIA access changes // AIA access changes
enableTemplating := data.Get("enable_aia_url_templating").(bool) enableTemplating := data.Get("enable_aia_url_templating").(bool)
issuerCertificates := data.Get("issuing_certificates").([]string) issuerCertificates := data.Get("issuing_certificates").([]string)
if badURL := validateURLs(issuerCertificates); !enableTemplating && badURL != "" { if badURL := issuing.ValidateURLs(issuerCertificates); !enableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter issuing_certificates: %s", badURL)), nil return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter issuing_certificates: %s", badURL)), nil
} }
crlDistributionPoints := data.Get("crl_distribution_points").([]string) crlDistributionPoints := data.Get("crl_distribution_points").([]string)
if badURL := validateURLs(crlDistributionPoints); !enableTemplating && badURL != "" { if badURL := issuing.ValidateURLs(crlDistributionPoints); !enableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter crl_distribution_points: %s", badURL)), nil return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter crl_distribution_points: %s", badURL)), nil
} }
ocspServers := data.Get("ocsp_servers").([]string) ocspServers := data.Get("ocsp_servers").([]string)
if badURL := validateURLs(ocspServers); !enableTemplating && badURL != "" { if badURL := issuing.ValidateURLs(ocspServers); !enableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter ocsp_servers: %s", badURL)), nil return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter ocsp_servers: %s", badURL)), nil
} }
@@ -582,8 +583,8 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
issuer.Name = newName issuer.Name = newName
issuer.LastModified = time.Now().UTC() issuer.LastModified = time.Now().UTC()
// See note in updateDefaultIssuerId about why this is necessary. // See note in updateDefaultIssuerId about why this is necessary.
b.crlBuilder.invalidateCRLBuildTime() b.CrlBuilder().invalidateCRLBuildTime()
b.crlBuilder.flushCRLBuildTimeInvalidation(sc) b.CrlBuilder().flushCRLBuildTimeInvalidation(sc)
modified = true modified = true
} }
@@ -593,7 +594,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
} }
if newUsage != issuer.Usage { if newUsage != issuer.Usage {
if issuer.Revoked && newUsage.HasUsage(IssuanceUsage) { if issuer.Revoked && newUsage.HasUsage(issuing.IssuanceUsage) {
// Forbid allowing cert signing on its usage. // Forbid allowing cert signing on its usage.
return logical.ErrorResponse("This issuer was revoked; unable to modify its usage to include certificate signing again. Reissue this certificate (preferably with a new key) and modify that entry instead."), nil return logical.ErrorResponse("This issuer was revoked; unable to modify its usage to include certificate signing again. Reissue this certificate (preferably with a new key) and modify that entry instead."), nil
} }
@@ -604,7 +605,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse issuer's certificate: %w", err) return nil, fmt.Errorf("unable to parse issuer's certificate: %w", err)
} }
if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(CRLSigningUsage) { if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(issuing.CRLSigningUsage) {
return logical.ErrorResponse("This issuer's underlying certificate lacks the CRLSign KeyUsage value; unable to set CRLSigningUsage on this issuer as a result."), nil return logical.ErrorResponse("This issuer's underlying certificate lacks the CRLSign KeyUsage value; unable to set CRLSigningUsage on this issuer as a result."), nil
} }
@@ -618,7 +619,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
} }
if issuer.AIAURIs == nil && (len(issuerCertificates) > 0 || len(crlDistributionPoints) > 0 || len(ocspServers) > 0) { if issuer.AIAURIs == nil && (len(issuerCertificates) > 0 || len(crlDistributionPoints) > 0 || len(ocspServers) > 0) {
issuer.AIAURIs = &aiaConfigEntry{} issuer.AIAURIs = &issuing.AiaConfigEntry{}
} }
if issuer.AIAURIs != nil { if issuer.AIAURIs != nil {
// Associative mapping from data source to destination on the // Associative mapping from data source to destination on the
@@ -665,7 +666,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
// it'll write it out to disk for us. We'd hate to then modify the issuer // it'll write it out to disk for us. We'd hate to then modify the issuer
// again and write it a second time. // again and write it a second time.
var updateChain bool var updateChain bool
var constructedChain []issuerID var constructedChain []issuing.IssuerID
for index, newPathRef := range newPath { for index, newPathRef := range newPath {
// Allow self for the first entry. // Allow self for the first entry.
if index == 0 && newPathRef == "self" { if index == 0 && newPathRef == "self" {
@@ -715,7 +716,7 @@ func (b *backend) pathUpdateIssuer(ctx context.Context, req *logical.Request, da
addWarningOnDereferencing(sc, oldName, response) addWarningOnDereferencing(sc, oldName, response)
} }
if issuer.AIAURIs != nil && issuer.AIAURIs.EnableTemplating { if issuer.AIAURIs != nil && issuer.AIAURIs.EnableTemplating {
_, aiaErr := issuer.AIAURIs.toURLEntries(sc, issuer.ID) _, aiaErr := ToURLEntries(sc, issuer.ID, issuer.AIAURIs)
if aiaErr != nil { if aiaErr != nil {
response.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", aiaErr)) response.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", aiaErr))
} }
@@ -730,12 +731,12 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not patch issuer until migration has completed"), nil return logical.ErrorResponse("Can not patch issuer until migration has completed"), nil
} }
// First we fetch the issuer // First we fetch the issuer
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -782,8 +783,8 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
issuer.Name = newName issuer.Name = newName
issuer.LastModified = time.Now().UTC() issuer.LastModified = time.Now().UTC()
// See note in updateDefaultIssuerId about why this is necessary. // See note in updateDefaultIssuerId about why this is necessary.
b.crlBuilder.invalidateCRLBuildTime() b.CrlBuilder().invalidateCRLBuildTime()
b.crlBuilder.flushCRLBuildTimeInvalidation(sc) b.CrlBuilder().flushCRLBuildTimeInvalidation(sc)
modified = true modified = true
} }
} }
@@ -813,12 +814,12 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
rawUsageData, ok := data.GetOk("usage") rawUsageData, ok := data.GetOk("usage")
if ok { if ok {
rawUsage := rawUsageData.([]string) rawUsage := rawUsageData.([]string)
newUsage, err := NewIssuerUsageFromNames(rawUsage) newUsage, err := issuing.NewIssuerUsageFromNames(rawUsage)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, AllIssuerUsages.Names())), nil return logical.ErrorResponse(fmt.Sprintf("Unable to parse specified usages: %v - valid values are %v", rawUsage, issuing.AllIssuerUsages.Names())), nil
} }
if newUsage != issuer.Usage { if newUsage != issuer.Usage {
if issuer.Revoked && newUsage.HasUsage(IssuanceUsage) { if issuer.Revoked && newUsage.HasUsage(issuing.IssuanceUsage) {
// Forbid allowing cert signing on its usage. // Forbid allowing cert signing on its usage.
return logical.ErrorResponse("This issuer was revoked; unable to modify its usage to include certificate signing again. Reissue this certificate (preferably with a new key) and modify that entry instead."), nil return logical.ErrorResponse("This issuer was revoked; unable to modify its usage to include certificate signing again. Reissue this certificate (preferably with a new key) and modify that entry instead."), nil
} }
@@ -827,7 +828,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse issuer's certificate: %w", err) return nil, fmt.Errorf("unable to parse issuer's certificate: %w", err)
} }
if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(CRLSigningUsage) { if (cert.KeyUsage&x509.KeyUsageCRLSign) == 0 && newUsage.HasUsage(issuing.CRLSigningUsage) {
return logical.ErrorResponse("This issuer's underlying certificate lacks the CRLSign KeyUsage value; unable to set CRLSigningUsage on this issuer as a result."), nil return logical.ErrorResponse("This issuer's underlying certificate lacks the CRLSign KeyUsage value; unable to set CRLSigningUsage on this issuer as a result."), nil
} }
@@ -864,7 +865,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
// AIA access changes. // AIA access changes.
if issuer.AIAURIs == nil { if issuer.AIAURIs == nil {
issuer.AIAURIs = &aiaConfigEntry{} issuer.AIAURIs = &issuing.AiaConfigEntry{}
} }
// Associative mapping from data source to destination on the // Associative mapping from data source to destination on the
@@ -903,7 +904,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
rawURLsValue, ok := data.GetOk(pair.Source) rawURLsValue, ok := data.GetOk(pair.Source)
if ok { if ok {
urlsValue := rawURLsValue.([]string) urlsValue := rawURLsValue.([]string)
if badURL := validateURLs(urlsValue); !issuer.AIAURIs.EnableTemplating && badURL != "" { if badURL := issuing.ValidateURLs(urlsValue); !issuer.AIAURIs.EnableTemplating && badURL != "" {
return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter %v: %s", pair.Source, badURL)), nil return logical.ErrorResponse(fmt.Sprintf("invalid URL found in Authority Information Access (AIA) parameter %v: %s", pair.Source, badURL)), nil
} }
@@ -925,7 +926,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
if ok { if ok {
newPath := newPathData.([]string) newPath := newPathData.([]string)
var updateChain bool var updateChain bool
var constructedChain []issuerID var constructedChain []issuing.IssuerID
for index, newPathRef := range newPath { for index, newPathRef := range newPath {
// Allow self for the first entry. // Allow self for the first entry.
if index == 0 && newPathRef == "self" { if index == 0 && newPathRef == "self" {
@@ -976,7 +977,7 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
addWarningOnDereferencing(sc, oldName, response) addWarningOnDereferencing(sc, oldName, response)
} }
if issuer.AIAURIs != nil && issuer.AIAURIs.EnableTemplating { if issuer.AIAURIs != nil && issuer.AIAURIs.EnableTemplating {
_, aiaErr := issuer.AIAURIs.toURLEntries(sc, issuer.ID) _, aiaErr := ToURLEntries(sc, issuer.ID, issuer.AIAURIs)
if aiaErr != nil { if aiaErr != nil {
response.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", aiaErr)) response.AddWarning(fmt.Sprintf("issuance may fail: %v\n\nConsider setting the cluster-local address if it is not already set.", aiaErr))
} }
@@ -986,11 +987,11 @@ func (b *backend) pathPatchIssuer(ctx context.Context, req *logical.Request, dat
} }
func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathGetRawIssuer(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get issuer until migration has completed"), nil return logical.ErrorResponse("Can not get issuer until migration has completed"), nil
} }
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -1069,11 +1070,11 @@ func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, da
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not delete issuer until migration has completed"), nil return logical.ErrorResponse("Can not delete issuer until migration has completed"), nil
} }
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -1082,7 +1083,7 @@ func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, da
ref, err := sc.resolveIssuerReference(issuerName) ref, err := sc.resolveIssuerReference(issuerName)
if err != nil { if err != nil {
// Return as if we deleted it if we fail to lookup the issuer. // Return as if we deleted it if we fail to lookup the issuer.
if ref == IssuerRefNotFound { if ref == issuing.IssuerRefNotFound {
return &logical.Response{}, nil return &logical.Response{}, nil
} }
return nil, err return nil, err
@@ -1120,7 +1121,7 @@ func (b *backend) pathDeleteIssuer(ctx context.Context, req *logical.Request, da
// Finally, we need to rebuild both the local and the unified CRLs. This // Finally, we need to rebuild both the local and the unified CRLs. This
// will free up any now unnecessary space used in both the CRL config // will free up any now unnecessary space used in both the CRL config
// and for the underlying CRL. // and for the underlying CRL.
warnings, err := b.crlBuilder.rebuild(sc, true) warnings, err := b.CrlBuilder().rebuild(sc, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1220,17 +1221,17 @@ func buildPathGetIssuerCRL(b *backend, pattern string, displayAttrs *framework.D
} }
func (b *backend) pathGetIssuerCRL(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathGetIssuerCRL(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get issuer's CRL until migration has completed"), nil return logical.ErrorResponse("Can not get issuer's CRL until migration has completed"), nil
} }
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
warnings, err := b.crlBuilder.rebuildIfForced(sc) warnings, err := b.CrlBuilder().rebuildIfForced(sc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
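Throughout these hunks, direct field access on the backend (b.crlBuilder, b.useLegacyBundleCaStorage, b.revokeStorageLock) becomes exported accessor calls so that code being carved out into sub-packages can reach shared state without touching unexported fields. A rough sketch of the pattern, with illustrative types rather than the backend's real ones:

package pki

import "sync"

// crlBuilder stands in for the real CRL builder type; only the accessor
// pattern itself is being illustrated here.
type crlBuilder struct{}

type backend struct {
	crlBuilder        *crlBuilder
	revokeStorageLock *sync.RWMutex
	useLegacyCaBundle bool
}

// Exported accessors give the new sub-packages a small method set to depend
// on instead of the backend's unexported fields.
func (b *backend) CrlBuilder() *crlBuilder             { return b.crlBuilder }
func (b *backend) GetRevokeStorageLock() *sync.RWMutex { return b.revokeStorageLock }
func (b *backend) UseLegacyBundleCaStorage() bool      { return b.useLegacyCaBundle }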

View File

@@ -14,6 +14,9 @@ import (
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
) )
func pathListKeys(b *backend) *framework.Path { func pathListKeys(b *backend) *framework.Path {
@@ -62,7 +65,7 @@ their identifier and their name (if set).`
) )
func (b *backend) pathListKeysHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathListKeysHandler(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not list keys until migration has completed"), nil return logical.ErrorResponse("Can not list keys until migration has completed"), nil
} }
@@ -225,7 +228,7 @@ the certificate.
) )
func (b *backend) pathGetKeyHandler(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathGetKeyHandler(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not get keys until migration has completed"), nil return logical.ErrorResponse("Can not get keys until migration has completed"), nil
} }
@@ -255,27 +258,27 @@ func (b *backend) pathGetKeyHandler(ctx context.Context, req *logical.Request, d
} }
var pkForSkid crypto.PublicKey var pkForSkid crypto.PublicKey
if key.isManagedPrivateKey() { if key.IsManagedPrivateKey() {
managedKeyUUID, err := key.getManagedKeyUUID() managedKeyUUID, err := issuing.GetManagedKeyUUID(key)
if err != nil { if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("failed extracting managed key uuid from key id %s (%s): %v", key.ID, key.Name, err)} return nil, errutil.InternalError{Err: fmt.Sprintf("failed extracting managed key uuid from key id %s (%s): %v", key.ID, key.Name, err)}
} }
keyInfo, err := getManagedKeyInfo(ctx, b, managedKeyUUID) keyInfo, err := managed_key.GetManagedKeyInfo(ctx, b, managedKeyUUID)
if err != nil { if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("failed fetching managed key info from key id %s (%s): %v", key.ID, key.Name, err)} return nil, errutil.InternalError{Err: fmt.Sprintf("failed fetching managed key info from key id %s (%s): %v", key.ID, key.Name, err)}
} }
pkForSkid, err = getManagedKeyPublicKey(sc.Context, sc.Backend, managedKeyUUID) pkForSkid, err = managed_key.GetManagedKeyPublicKey(sc.Context, sc.Backend, managedKeyUUID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// To remain consistent across the api responses (mainly generate root/intermediate calls), return the actual // To remain consistent across the api responses (mainly generate root/intermediate calls), return the actual
// type of key, not that it is a managed key. // type of key, not that it is a managed key.
respData[keyTypeParam] = string(keyInfo.keyType) respData[keyTypeParam] = string(keyInfo.KeyType)
respData[managedKeyIdArg] = string(keyInfo.uuid) respData[managedKeyIdArg] = string(keyInfo.Uuid)
respData[managedKeyNameArg] = string(keyInfo.name) respData[managedKeyNameArg] = string(keyInfo.Name)
} else { } else {
pkForSkid, err = getPublicKeyFromBytes([]byte(key.PrivateKey)) pkForSkid, err = getPublicKeyFromBytes([]byte(key.PrivateKey))
if err != nil { if err != nil {
@@ -298,7 +301,7 @@ func (b *backend) pathUpdateKeyHandler(ctx context.Context, req *logical.Request
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not update keys until migration has completed"), nil return logical.ErrorResponse("Can not update keys until migration has completed"), nil
} }
@@ -356,7 +359,7 @@ func (b *backend) pathDeleteKeyHandler(ctx context.Context, req *logical.Request
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not delete keys until migration has completed"), nil return logical.ErrorResponse("Can not delete keys until migration has completed"), nil
} }
@@ -368,7 +371,7 @@ func (b *backend) pathDeleteKeyHandler(ctx context.Context, req *logical.Request
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
keyId, err := sc.resolveKeyReference(keyRef) keyId, err := sc.resolveKeyReference(keyRef)
if err != nil { if err != nil {
if keyId == KeyRefNotFound { if keyId == issuing.KeyRefNotFound {
// We failed to lookup the key, we should ignore any error here and reply as if it was deleted. // We failed to lookup the key, we should ignore any error here and reply as if it was deleted.
return nil, nil return nil, nil
} }
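The managed-key helpers move into the new managed_key sub-package, which also means the key-info fields the handler reads are now exported (KeyType, Uuid, Name) instead of lower-case. A sketch of the shape the calling code relies on; the struct layout, the responseFields helper, and the map keys here are assumptions for illustration, not copied from the package:

package managed_key

// ManagedKeyInfo mirrors the exported fields the PKI handlers read back into
// their responses; the real struct likely carries more, including the public
// key itself.
type ManagedKeyInfo struct {
	Uuid    string
	Name    string
	KeyType string // certutil.PrivateKeyType in the real package
}

// responseFields is a made-up helper showing how the exported fields feed an
// API response map; the keys follow the parameter names used by the handler.
func responseFields(info ManagedKeyInfo) map[string]interface{} {
	return map[string]interface{}{
		"key_type":         info.KeyType,
		"managed_key_id":   info.Uuid,
		"managed_key_name": info.Name,
	}
}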

View File

@@ -102,7 +102,7 @@ func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Req
var err error var err error
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not create intermediate until migration has completed"), nil return logical.ErrorResponse("Can not create intermediate until migration has completed"), nil
} }

View File

@@ -19,6 +19,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
func pathIssue(b *backend) *framework.Path { func pathIssue(b *backend) *framework.Path {
@@ -285,7 +287,7 @@ See the API documentation for more information about required parameters.
// pathIssue issues a certificate and private key from given parameters, // pathIssue issues a certificate and private key from given parameters,
// subject to role restrictions // subject to role restrictions
func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error) { func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error) {
if role.KeyType == "any" { if role.KeyType == "any" {
return logical.ErrorResponse("role key type \"any\" not allowed for issuing certificates, only signing"), nil return logical.ErrorResponse("role key type \"any\" not allowed for issuing certificates, only signing"), nil
} }
@@ -295,19 +297,49 @@ func (b *backend) pathIssue(ctx context.Context, req *logical.Request, data *fra
// pathSign issues a certificate from a submitted CSR, subject to role // pathSign issues a certificate from a submitted CSR, subject to role
// restrictions // restrictions
func (b *backend) pathSign(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error) { func (b *backend) pathSign(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error) {
return b.pathIssueSignCert(ctx, req, data, role, true, false) return b.pathIssueSignCert(ctx, req, data, role, true, false)
} }
// pathSignVerbatim issues a certificate from a submitted CSR, *not* subject to // pathSignVerbatim issues a certificate from a submitted CSR, *not* subject to
// role restrictions // role restrictions
func (b *backend) pathSignVerbatim(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error) { func (b *backend) pathSignVerbatim(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry) (*logical.Response, error) {
entry := buildSignVerbatimRole(data, role) opts := []issuing.RoleModifier{
issuing.WithKeyUsage(data.Get("key_usage").([]string)),
issuing.WithExtKeyUsage(data.Get("ext_key_usage").([]string)),
issuing.WithExtKeyUsageOIDs(data.Get("ext_key_usage_oids").([]string)),
issuing.WithSignatureBits(data.Get("signature_bits").(int)),
issuing.WithUsePSS(data.Get("use_pss").(bool)),
}
// if we did receive a role parameter value with a valid role, use some of its values
// to populate and influence the sign-verbatim behavior.
if role != nil {
opts = append(opts, issuing.WithNoStore(role.NoStore))
opts = append(opts, issuing.WithIssuer(role.Issuer))
if role.TTL > 0 {
opts = append(opts, issuing.WithTTL(role.TTL))
}
if role.MaxTTL > 0 {
opts = append(opts, issuing.WithMaxTTL(role.MaxTTL))
}
if role.GenerateLease != nil {
opts = append(opts, issuing.WithGenerateLease(*role.GenerateLease))
}
if role.NotBeforeDuration > 0 {
opts = append(opts, issuing.WithNotBeforeDuration(role.NotBeforeDuration))
}
}
entry := issuing.SignVerbatimRoleWithOpts(opts...)
return b.pathIssueSignCert(ctx, req, data, entry, true, true) return b.pathIssueSignCert(ctx, req, data, entry, true, true)
} }
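The sign-verbatim path no longer builds its ephemeral role by copying and patching a roleEntry; it now assembles it through issuing.SignVerbatimRoleWithOpts and a slice of issuing.RoleModifier options. A compact sketch of that functional-options pattern, using a trimmed-down role so the shape stays visible (the real option set and defaults are much larger):

package issuing

import "time"

// RoleEntry is trimmed to the fields needed to show the pattern.
type RoleEntry struct {
	TTL      time.Duration
	MaxTTL   time.Duration
	NoStore  bool
	Issuer   string
	KeyUsage []string
}

// RoleModifier mutates a role entry under construction.
type RoleModifier func(*RoleEntry)

func WithTTL(ttl time.Duration) RoleModifier    { return func(r *RoleEntry) { r.TTL = ttl } }
func WithMaxTTL(ttl time.Duration) RoleModifier { return func(r *RoleEntry) { r.MaxTTL = ttl } }
func WithNoStore(noStore bool) RoleModifier     { return func(r *RoleEntry) { r.NoStore = noStore } }
func WithIssuer(issuer string) RoleModifier     { return func(r *RoleEntry) { r.Issuer = issuer } }
func WithKeyUsage(usages []string) RoleModifier { return func(r *RoleEntry) { r.KeyUsage = usages } }

// SignVerbatimRoleWithOpts starts from permissive sign-verbatim defaults and
// applies each modifier in order.
func SignVerbatimRoleWithOpts(opts ...RoleModifier) *RoleEntry {
	entry := &RoleEntry{Issuer: "default"} // remaining defaults elided
	for _, opt := range opts {
		if opt != nil {
			opt(entry)
		}
	}
	return entry
}

A caller can then write entry := SignVerbatimRoleWithOpts(WithNoStore(true), WithTTL(time.Hour)) instead of mutating a copied role by hand, which is what the handler above does with the role-derived options.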
func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry, useCSR, useCSRValues bool) (*logical.Response, error) { func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, data *framework.FieldData, role *issuing.RoleEntry, useCSR, useCSRValues bool) (*logical.Response, error) {
// If storing the certificate and on a performance standby, forward this request on to the primary // If storing the certificate and on a performance standby, forward this request on to the primary
// Allow performance secondaries to generate and store certificates locally to them. // Allow performance secondaries to generate and store certificates locally to them.
if !role.NoStore && b.System().ReplicationState().HasState(consts.ReplicationPerformanceStandby) { if !role.NoStore && b.System().ReplicationState().HasState(consts.ReplicationPerformanceStandby) {
@@ -333,7 +365,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
} else { } else {
// Otherwise, we must have a newer API which requires an issuer // Otherwise, we must have a newer API which requires an issuer
// reference. Fetch it in this case // reference. Fetch it in this case
issuerName = getIssuerRef(data) issuerName = GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -347,7 +379,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
var caErr error var caErr error
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, IssuanceUsage) signingBundle, caErr := sc.fetchCAInfo(issuerName, issuing.IssuanceUsage)
if caErr != nil { if caErr != nil {
switch caErr.(type) { switch caErr.(type) {
case errutil.UserError: case errutil.UserError:
@@ -400,7 +432,8 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
if !role.NoStore { if !role.NoStore {
key := "certs/" + normalizeSerial(cb.SerialNumber) key := "certs/" + normalizeSerial(cb.SerialNumber)
certsCounted := b.certsCounted.Load() certCounter := b.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = req.Storage.Put(ctx, &logical.StorageEntry{ err = req.Storage.Put(ctx, &logical.StorageEntry{
Key: key, Key: key,
Value: parsedBundle.CertificateBytes, Value: parsedBundle.CertificateBytes,
@@ -408,7 +441,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to store certificate locally: %w", err) return nil, fmt.Errorf("unable to store certificate locally: %w", err)
} }
b.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key) certCounter.IncrementTotalCertificatesCount(certsCounted, key)
} }
if useCSR { if useCSR {

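Certificate counting moves behind an accessor as well: instead of reading b.certsCounted and calling b.ifCountEnabledIncrementTotalCertificatesCount, the handler asks the counter whether it has finished initializing and then hands it the storage key. A sketch of the minimal surface this implies, with the method names taken from the hunk above and the internals assumed:

package pki

import "sync/atomic"

// certificateCounter sketches the two methods the issue/sign path uses; the
// real counter also de-duplicates keys seen while the initial backlog is
// still being counted.
type certificateCounter struct {
	initialized atomic.Bool
	total       atomic.Uint64
}

// IsInitialized reports whether the existing certificates have been counted,
// replacing the old direct read of b.certsCounted.
func (c *certificateCounter) IsInitialized() bool {
	return c.initialized.Load()
}

// IncrementTotalCertificatesCount bumps the running total for a newly stored
// certificate identified by its storage key.
func (c *certificateCounter) IncrementTotalCertificatesCount(certsCounted bool, key string) {
	if !certsCounted {
		// The initial count will pick this key up later; avoid double counting.
		return
	}
	c.total.Add(1)
}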
View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -275,7 +276,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
keysAllowed := strings.HasSuffix(req.Path, "bundle") || req.Path == "config/ca" keysAllowed := strings.HasSuffix(req.Path, "bundle") || req.Path == "config/ca"
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not import issuers until migration has completed"), nil return logical.ErrorResponse("Can not import issuers until migration has completed"), nil
} }
@@ -406,7 +407,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
} }
if len(createdIssuers) > 0 { if len(createdIssuers) > 0 {
warnings, err := b.crlBuilder.rebuild(sc, true) warnings, err := b.CrlBuilder().rebuild(sc, true)
if err != nil { if err != nil {
// Before returning, check if the error message includes the // Before returning, check if the error message includes the
// string "PSS". If so, it indicates we might've wanted to modify // string "PSS". If so, it indicates we might've wanted to modify
@@ -438,7 +439,7 @@ func (b *backend) pathImportIssuers(ctx context.Context, req *logical.Request, d
response.AddWarning("Unable to fetch default issuers configuration to update default issuer if necessary: " + err.Error()) response.AddWarning("Unable to fetch default issuers configuration to update default issuer if necessary: " + err.Error())
} else if config.DefaultFollowsLatestIssuer { } else if config.DefaultFollowsLatestIssuer {
if len(issuersWithKeys) == 1 { if len(issuersWithKeys) == 1 {
if err := sc.updateDefaultIssuerId(issuerID(issuersWithKeys[0])); err != nil { if err := sc.updateDefaultIssuerId(issuing.IssuerID(issuersWithKeys[0])); err != nil {
response.AddWarning("Unable to update this new root as the default issuer: " + err.Error()) response.AddWarning("Unable to update this new root as the default issuer: " + err.Error())
} }
} else if len(issuersWithKeys) > 1 { } else if len(issuersWithKeys) > 1 {
@@ -627,11 +628,11 @@ func (b *backend) pathRevokeIssuer(ctx context.Context, req *logical.Request, da
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
// Issuer revocation can't work on the legacy cert bundle. // Issuer revocation can't work on the legacy cert bundle.
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("cannot revoke issuer until migration has completed"), nil return logical.ErrorResponse("cannot revoke issuer until migration has completed"), nil
} }
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -661,8 +662,8 @@ func (b *backend) pathRevokeIssuer(ctx context.Context, req *logical.Request, da
// new revocations of leaves issued by this issuer to trigger a CRL // new revocations of leaves issued by this issuer to trigger a CRL
// rebuild still. // rebuild still.
issuer.Revoked = true issuer.Revoked = true
if issuer.Usage.HasUsage(IssuanceUsage) { if issuer.Usage.HasUsage(issuing.IssuanceUsage) {
issuer.Usage.ToggleUsage(IssuanceUsage) issuer.Usage.ToggleUsage(issuing.IssuanceUsage)
} }
currTime := time.Now() currTime := time.Now()
@@ -730,7 +731,7 @@ func (b *backend) pathRevokeIssuer(ctx context.Context, req *logical.Request, da
} }
// Rebuild the CRL to include the newly revoked issuer. // Rebuild the CRL to include the newly revoked issuer.
warnings, crlErr := b.crlBuilder.rebuild(sc, false) warnings, crlErr := b.CrlBuilder().rebuild(sc, false)
if crlErr != nil { if crlErr != nil {
switch crlErr.(type) { switch crlErr.(type) {
case errutil.UserError: case errutil.UserError:

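Issuer usage flags (IssuanceUsage, CRLSigningUsage, OCSPSigningUsage) now live in the issuing package, and revoking an issuer clears the issuance bit with ToggleUsage while leaving CRL and OCSP signing usable. The flags behave like an ordinary bitmask; a small sketch assuming the same semantics as the calls above, though the real constants and receiver type differ:

package issuing

// IssuerUsage is a bitmask of what an issuer is allowed to do.
type IssuerUsage uint

const (
	IssuanceUsage    IssuerUsage = 1 << iota // sign leaf certificates
	CRLSigningUsage                          // sign CRLs
	OCSPSigningUsage                         // sign OCSP responses
)

// HasUsage reports whether every bit in usage is set on u.
func (u IssuerUsage) HasUsage(usage IssuerUsage) bool {
	return u&usage == usage
}

// ToggleUsage flips the given bits; revocation uses it to drop IssuanceUsage
// from a revoked issuer.
func (u *IssuerUsage) ToggleUsage(usages ...IssuerUsage) {
	for _, usage := range usages {
		*u ^= usage
	}
}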
View File

@@ -13,6 +13,8 @@ import (
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
) )
func pathGenerateKey(b *backend) *framework.Path { func pathGenerateKey(b *backend) *framework.Path {
@@ -114,7 +116,7 @@ func (b *backend) pathGenerateKeyHandler(ctx context.Context, req *logical.Reque
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not generate keys until migration has completed"), nil return logical.ErrorResponse("Can not generate keys until migration has completed"), nil
} }
@@ -153,7 +155,7 @@ func (b *backend) pathGenerateKeyHandler(ctx context.Context, req *logical.Reque
return nil, err return nil, err
} }
keyBundle, actualPrivateKeyType, err = createKmsKeyBundle(ctx, b, keyId) keyBundle, actualPrivateKeyType, err = managed_key.CreateKmsKeyBundle(ctx, b, keyId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -252,7 +254,7 @@ func (b *backend) pathImportKeyHandler(ctx context.Context, req *logical.Request
b.issuersLock.Lock() b.issuersLock.Lock()
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Cannot import keys until migration has completed"), nil return logical.ErrorResponse("Cannot import keys until migration has completed"), nil
} }

View File

@@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema" "github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
@@ -152,7 +153,7 @@ func TestPKI_PathManageKeys_ImportKeyBundle(t *testing.T) {
require.NotEmpty(t, resp.Data["key_id"], "key id for ec import response was empty") require.NotEmpty(t, resp.Data["key_id"], "key id for ec import response was empty")
require.Equal(t, "my-ec-key", resp.Data["key_name"], "key_name was incorrect for ec key") require.Equal(t, "my-ec-key", resp.Data["key_name"], "key_name was incorrect for ec key")
require.Equal(t, certutil.ECPrivateKey, resp.Data["key_type"]) require.Equal(t, certutil.ECPrivateKey, resp.Data["key_type"])
keyId1 := resp.Data["key_id"].(keyID) keyId1 := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{ resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
@@ -170,7 +171,7 @@ func TestPKI_PathManageKeys_ImportKeyBundle(t *testing.T) {
require.NotEmpty(t, resp.Data["key_id"], "key id for rsa import response was empty") require.NotEmpty(t, resp.Data["key_id"], "key id for rsa import response was empty")
require.Equal(t, "my-rsa-key", resp.Data["key_name"], "key_name was incorrect for ec key") require.Equal(t, "my-rsa-key", resp.Data["key_name"], "key_name was incorrect for ec key")
require.Equal(t, certutil.RSAPrivateKey, resp.Data["key_type"]) require.Equal(t, certutil.RSAPrivateKey, resp.Data["key_type"])
keyId2 := resp.Data["key_id"].(keyID) keyId2 := resp.Data["key_id"].(issuing.KeyID)
require.NotEqual(t, keyId1, keyId2) require.NotEqual(t, keyId1, keyId2)
@@ -251,7 +252,7 @@ func TestPKI_PathManageKeys_ImportKeyBundle(t *testing.T) {
require.NotEmpty(t, resp.Data["key_id"], "key id for rsa import response was empty") require.NotEmpty(t, resp.Data["key_id"], "key id for rsa import response was empty")
require.Equal(t, "my-rsa-key", resp.Data["key_name"], "key_name was incorrect for ec key") require.Equal(t, "my-rsa-key", resp.Data["key_name"], "key_name was incorrect for ec key")
require.Equal(t, certutil.RSAPrivateKey, resp.Data["key_type"]) require.Equal(t, certutil.RSAPrivateKey, resp.Data["key_type"])
keyId2Reimport := resp.Data["key_id"].(keyID) keyId2Reimport := resp.Data["key_id"].(issuing.KeyID)
require.NotEqual(t, keyId2, keyId2Reimport, "re-importing key 2 did not generate a new key id") require.NotEqual(t, keyId2, keyId2Reimport, "re-importing key 2 did not generate a new key id")
} }
@@ -270,7 +271,7 @@ func TestPKI_PathManageKeys_DeleteDefaultKeyWarns(t *testing.T) {
require.NoError(t, err, "Failed generating key") require.NoError(t, err, "Failed generating key")
require.NotNil(t, resp, "Got nil response generating key") require.NotNil(t, resp, "Got nil response generating key")
require.False(t, resp.IsError(), "resp contained errors generating key: %#v", resp.Error()) require.False(t, resp.IsError(), "resp contained errors generating key: %#v", resp.Error())
keyId := resp.Data["key_id"].(keyID) keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{ resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.DeleteOperation, Operation: logical.DeleteOperation,
@@ -298,7 +299,7 @@ func TestPKI_PathManageKeys_DeleteUsedKeyFails(t *testing.T) {
require.NoError(t, err, "Failed generating issuer") require.NoError(t, err, "Failed generating issuer")
require.NotNil(t, resp, "Got nil response generating issuer") require.NotNil(t, resp, "Got nil response generating issuer")
require.False(t, resp.IsError(), "resp contained errors generating issuer: %#v", resp.Error()) require.False(t, resp.IsError(), "resp contained errors generating issuer: %#v", resp.Error())
keyId := resp.Data["key_id"].(keyID) keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{ resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.DeleteOperation, Operation: logical.DeleteOperation,
@@ -325,7 +326,7 @@ func TestPKI_PathManageKeys_UpdateKeyDetails(t *testing.T) {
require.NoError(t, err, "Failed generating key") require.NoError(t, err, "Failed generating key")
require.NotNil(t, resp, "Got nil response generating key") require.NotNil(t, resp, "Got nil response generating key")
require.False(t, resp.IsError(), "resp contained errors generating key: %#v", resp.Error()) require.False(t, resp.IsError(), "resp contained errors generating key: %#v", resp.Error())
keyId := resp.Data["key_id"].(keyID) keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = b.HandleRequest(context.Background(), &logical.Request{ resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,

View File

@@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
@@ -37,7 +38,7 @@ type ocspRespInfo struct {
serialNumber *big.Int serialNumber *big.Int
ocspStatus int ocspStatus int
revocationTimeUTC *time.Time revocationTimeUTC *time.Time
issuerID issuerID issuerID issuing.IssuerID
} }
// These response variables should not be mutated, instead treat them as constants // These response variables should not be mutated, instead treat them as constants
@@ -155,7 +156,7 @@ func buildOcspPostWithPath(b *backend, pattern string, displayAttrs *framework.D
func (b *backend) ocspHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) ocspHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, request.Storage) sc := b.makeStorageContext(ctx, request.Storage)
cfg, err := b.crlBuilder.getConfigWithUpdate(sc) cfg, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil || cfg.OcspDisable || (isUnifiedOcspPath(request) && !cfg.UnifiedCRL) { if err != nil || cfg.OcspDisable || (isUnifiedOcspPath(request) && !cfg.UnifiedCRL) {
return OcspUnauthorizedResponse, nil return OcspUnauthorizedResponse, nil
} }
@@ -247,7 +248,7 @@ func generateUnknownResponse(cfg *crlConfig, sc *storageContext, ocspReq *ocsp.R
return logAndReturnInternalError(sc.Backend, err) return logAndReturnInternalError(sc.Backend, err)
} }
if !issuer.Usage.HasUsage(OCSPSigningUsage) { if !issuer.Usage.HasUsage(issuing.OCSPSigningUsage) {
// If we don't have any issuers or default issuers set, no way to sign a response so Unauthorized it is. // If we don't have any issuers or default issuers set, no way to sign a response so Unauthorized it is.
return OcspUnauthorizedResponse return OcspUnauthorizedResponse
} }
@@ -358,7 +359,7 @@ func getOcspStatus(sc *storageContext, ocspReq *ocsp.Request, useUnifiedStorage
return &info, nil return &info, nil
} }
func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer issuerID) (*certutil.ParsedCertBundle, *issuerEntry, error) { func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer issuing.IssuerID) (*certutil.ParsedCertBundle, *issuing.IssuerEntry, error) {
reqHash := req.HashAlgorithm reqHash := req.HashAlgorithm
if !reqHash.Available() { if !reqHash.Available() {
return nil, nil, x509.ErrUnsupportedAlgorithm return nil, nil, x509.ErrUnsupportedAlgorithm
@@ -395,7 +396,7 @@ func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer is
} }
if matches { if matches {
if !issuer.Usage.HasUsage(OCSPSigningUsage) { if !issuer.Usage.HasUsage(issuing.OCSPSigningUsage) {
matchedButNoUsage = true matchedButNoUsage = true
// We found a matching issuer, but it's not allowed to sign the // We found a matching issuer, but it's not allowed to sign the
// response, there might be another issuer that we rotated // response, there might be another issuer that we rotated
@@ -415,7 +416,7 @@ func lookupOcspIssuer(sc *storageContext, req *ocsp.Request, optRevokedIssuer is
return nil, nil, ErrUnknownIssuer return nil, nil, ErrUnknownIssuer
} }
func getOcspIssuerParsedBundle(sc *storageContext, issuerId issuerID) (*certutil.ParsedCertBundle, *issuerEntry, error) { func getOcspIssuerParsedBundle(sc *storageContext, issuerId issuing.IssuerID) (*certutil.ParsedCertBundle, *issuing.IssuerEntry, error) {
issuer, bundle, err := sc.fetchCertBundleByIssuerId(issuerId, true) issuer, bundle, err := sc.fetchCertBundleByIssuerId(issuerId, true)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
@@ -440,13 +441,13 @@ func getOcspIssuerParsedBundle(sc *storageContext, issuerId issuerID) (*certutil
return caBundle, issuer, nil return caBundle, issuer, nil
} }
func lookupIssuerIds(sc *storageContext, optRevokedIssuer issuerID) ([]issuerID, error) { func lookupIssuerIds(sc *storageContext, optRevokedIssuer issuing.IssuerID) ([]issuing.IssuerID, error) {
if optRevokedIssuer != "" { if optRevokedIssuer != "" {
return []issuerID{optRevokedIssuer}, nil return []issuing.IssuerID{optRevokedIssuer}, nil
} }
if sc.Backend.useLegacyBundleCaStorage() { if sc.Backend.UseLegacyBundleCaStorage() {
return []issuerID{legacyBundleShimID}, nil return []issuing.IssuerID{legacyBundleShimID}, nil
} }
return sc.listIssuers() return sc.listIssuers()

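The untyped issuerID and keyID aliases become issuing.IssuerID and issuing.KeyID, which is why response data in the tests below is now type-asserted against the new names. These are plain defined string types; a sketch of the shape the assertions and the IssuerRefNotFound/KeyRefNotFound comparisons depend on, assuming nothing beyond what the call sites show:

package issuing

// IssuerID and KeyID are distinct defined string types, so an issuer
// reference cannot silently be used where a key reference is expected.
type (
	IssuerID string
	KeyID    string
)

func (i IssuerID) String() string { return string(i) }

func (k KeyID) String() string { return string(k) }

// Sentinel values returned when a reference fails to resolve; callers compare
// against these instead of inspecting the error text.
const (
	IssuerRefNotFound = IssuerID("not-found")
	KeyRefNotFound    = KeyID("not-found")
)

Because the typed values are what gets stored in resp.Data, tests assert resp.Data["key_id"].(issuing.KeyID) rather than a plain string.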
View File

@@ -18,6 +18,7 @@ import (
"time" "time"
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/testhelpers/schema" "github.com/hashicorp/vault/sdk/helper/testhelpers/schema"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -258,11 +259,11 @@ func TestOcsp_RevokedCertHasIssuerWithoutOcspUsage(t *testing.T) {
requireFieldsSetInResp(t, resp, "usage") requireFieldsSetInResp(t, resp, "usage")
// Do not assume a specific ordering for usage... // Do not assume a specific ordering for usage...
usages, err := NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ",")) usages, err := issuing.NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ","))
require.NoError(t, err, "failed parsing usage return value") require.NoError(t, err, "failed parsing usage return value")
require.True(t, usages.HasUsage(IssuanceUsage)) require.True(t, usages.HasUsage(issuing.IssuanceUsage))
require.True(t, usages.HasUsage(CRLSigningUsage)) require.True(t, usages.HasUsage(issuing.CRLSigningUsage))
require.False(t, usages.HasUsage(OCSPSigningUsage)) require.False(t, usages.HasUsage(issuing.OCSPSigningUsage))
// Request an OCSP request from it, we should get an Unauthorized response back // Request an OCSP request from it, we should get an Unauthorized response back
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
@@ -290,7 +291,7 @@ func TestOcsp_RevokedCertHasIssuerWithoutAKey(t *testing.T) {
resp, err = CBRead(b, s, "issuer/"+testEnv.issuerId1.String()) resp, err = CBRead(b, s, "issuer/"+testEnv.issuerId1.String())
requireSuccessNonNilResponse(t, resp, err, "failed reading issuer") requireSuccessNonNilResponse(t, resp, err, "failed reading issuer")
requireFieldsSetInResp(t, resp, "key_id") requireFieldsSetInResp(t, resp, "key_id")
keyId := resp.Data["key_id"].(keyID) keyId := resp.Data["key_id"].(issuing.KeyID)
// This is a bit naughty but allow me to delete the key... // This is a bit naughty but allow me to delete the key...
sc := b.makeStorageContext(context.Background(), s) sc := b.makeStorageContext(context.Background(), s)
@@ -343,11 +344,11 @@ func TestOcsp_MultipleMatchingIssuersOneWithoutSigningUsage(t *testing.T) {
requireSuccessNonNilResponse(t, resp, err, "failed resetting usage flags on issuer") requireSuccessNonNilResponse(t, resp, err, "failed resetting usage flags on issuer")
requireFieldsSetInResp(t, resp, "usage") requireFieldsSetInResp(t, resp, "usage")
// Do not assume a specific ordering for usage... // Do not assume a specific ordering for usage...
usages, err := NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ",")) usages, err := issuing.NewIssuerUsageFromNames(strings.Split(resp.Data["usage"].(string), ","))
require.NoError(t, err, "failed parsing usage return value") require.NoError(t, err, "failed parsing usage return value")
require.True(t, usages.HasUsage(IssuanceUsage)) require.True(t, usages.HasUsage(issuing.IssuanceUsage))
require.True(t, usages.HasUsage(CRLSigningUsage)) require.True(t, usages.HasUsage(issuing.CRLSigningUsage))
require.False(t, usages.HasUsage(OCSPSigningUsage)) require.False(t, usages.HasUsage(issuing.OCSPSigningUsage))
// Request an OCSP request from it, we should get a Good response back, from the rotated cert // Request an OCSP request from it, we should get a Good response back, from the rotated cert
resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1) resp, err = SendOcspRequest(t, b, s, "get", testEnv.leafCertIssuer1, testEnv.issuer1, crypto.SHA1)
@@ -625,14 +626,14 @@ type ocspTestEnv struct {
issuer1 *x509.Certificate issuer1 *x509.Certificate
issuer2 *x509.Certificate issuer2 *x509.Certificate
issuerId1 issuerID issuerId1 issuing.IssuerID
issuerId2 issuerID issuerId2 issuing.IssuerID
leafCertIssuer1 *x509.Certificate leafCertIssuer1 *x509.Certificate
leafCertIssuer2 *x509.Certificate leafCertIssuer2 *x509.Certificate
keyId1 keyID keyId1 issuing.KeyID
keyId2 keyID keyId2 issuing.KeyID
} }
func setupOcspEnv(t *testing.T, keyType string) (*backend, logical.Storage, *ocspTestEnv) { func setupOcspEnv(t *testing.T, keyType string) (*backend, logical.Storage, *ocspTestEnv) {
@@ -643,8 +644,8 @@ func setupOcspEnvWithCaKeyConfig(t *testing.T, keyType string, caKeyBits int, ca
b, s := CreateBackendWithStorage(t) b, s := CreateBackendWithStorage(t)
var issuerCerts []*x509.Certificate var issuerCerts []*x509.Certificate
var leafCerts []*x509.Certificate var leafCerts []*x509.Certificate
var issuerIds []issuerID var issuerIds []issuing.IssuerID
var keyIds []keyID var keyIds []issuing.KeyID
resp, err := CBWrite(b, s, "config/crl", map[string]interface{}{ resp, err := CBWrite(b, s, "config/crl", map[string]interface{}{
"ocsp_enable": true, "ocsp_enable": true,
@@ -662,8 +663,8 @@ func setupOcspEnvWithCaKeyConfig(t *testing.T, keyType string, caKeyBits int, ca
}) })
requireSuccessNonNilResponse(t, resp, err, "root/generate/internal") requireSuccessNonNilResponse(t, resp, err, "root/generate/internal")
requireFieldsSetInResp(t, resp, "issuer_id", "key_id") requireFieldsSetInResp(t, resp, "issuer_id", "key_id")
issuerId := resp.Data["issuer_id"].(issuerID) issuerId := resp.Data["issuer_id"].(issuing.IssuerID)
keyId := resp.Data["key_id"].(keyID) keyId := resp.Data["key_id"].(issuing.KeyID)
resp, err = CBWrite(b, s, "roles/test"+strconv.FormatInt(int64(i), 10), map[string]interface{}{ resp, err = CBWrite(b, s, "roles/test"+strconv.FormatInt(int64(i), 10), map[string]interface{}{
"allow_bare_domains": true, "allow_bare_domains": true,

View File

@@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -185,11 +186,11 @@ return a signed CRL based on the parameter values.`,
} }
func (b *backend) pathUpdateResignCrlsHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathUpdateResignCrlsHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("This API cannot be used until the migration has completed"), nil return logical.ErrorResponse("This API cannot be used until the migration has completed"), nil
} }
issuerRef := getIssuerRef(data) issuerRef := GetIssuerRef(data)
crlNumber := data.Get(crlNumberParam).(int) crlNumber := data.Get(crlNumberParam).(int)
deltaCrlBaseNumber := data.Get(deltaCrlBaseNumberParam).(int) deltaCrlBaseNumber := data.Get(deltaCrlBaseNumberParam).(int)
nextUpdateStr := data.Get(nextUpdateParam).(string) nextUpdateStr := data.Get(nextUpdateParam).(string)
@@ -273,11 +274,11 @@ func (b *backend) pathUpdateResignCrlsHandler(ctx context.Context, request *logi
} }
func (b *backend) pathUpdateSignRevocationListHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathUpdateSignRevocationListHandler(ctx context.Context, request *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("This API cannot be used until the migration has completed"), nil return logical.ErrorResponse("This API cannot be used until the migration has completed"), nil
} }
issuerRef := getIssuerRef(data) issuerRef := GetIssuerRef(data)
crlNumber := data.Get(crlNumberParam).(int) crlNumber := data.Get(crlNumberParam).(int)
deltaCrlBaseNumber := data.Get(deltaCrlBaseNumberParam).(int) deltaCrlBaseNumber := data.Get(deltaCrlBaseNumberParam).(int)
nextUpdateStr := data.Get(nextUpdateParam).(string) nextUpdateStr := data.Get(nextUpdateParam).(string)
@@ -649,7 +650,7 @@ func getCaBundle(sc *storageContext, issuerRef string) (*certutil.CAInfoBundle,
return nil, fmt.Errorf("failed to resolve issuer %s: %w", issuerRefParam, err) return nil, fmt.Errorf("failed to resolve issuer %s: %w", issuerRefParam, err)
} }
return sc.fetchCAInfoByIssuerId(issuerId, CRLSigningUsage) return sc.fetchCAInfoByIssuerId(issuerId, issuing.CRLSigningUsage)
} }
func decodePemCrls(rawCrls []string) ([]*x509.RevocationList, error) { func decodePemCrls(rawCrls []string) ([]*x509.RevocationList, error) {

View File

@@ -23,6 +23,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
func pathListCertsRevoked(b *backend) *framework.Path { func pathListCertsRevoked(b *backend) *framework.Path {
@@ -306,7 +308,7 @@ func (b *backend) pathRevokeWriteHandleCertificate(ctx context.Context, req *log
// //
// We return the parsed serial number, an optionally-nil byte array to // We return the parsed serial number, an optionally-nil byte array to
// write out to disk, and an error if one occurred. // write out to disk, and an error if one occurred.
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
// We require listing all issuers from the 1.11 method. If we're // We require listing all issuers from the 1.11 method. If we're
// still using the legacy CA bundle but with the newer certificate // still using the legacy CA bundle but with the newer certificate
// attribute, we err and require the operator to upgrade and migrate // attribute, we err and require the operator to upgrade and migrate
@@ -534,7 +536,7 @@ func (b *backend) maybeRevokeCrossCluster(sc *storageContext, config *crlConfig,
return resp, nil return resp, nil
} }
func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, data *framework.FieldData, _ *roleEntry) (*logical.Response, error) { func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, data *framework.FieldData, _ *issuing.RoleEntry) (*logical.Response, error) {
rawSerial, haveSerial := data.GetOk("serial_number") rawSerial, haveSerial := data.GetOk("serial_number")
rawCertificate, haveCert := data.GetOk("certificate") rawCertificate, haveCert := data.GetOk("certificate")
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
@@ -563,7 +565,7 @@ func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, dat
var cert *x509.Certificate var cert *x509.Certificate
var serial string var serial string
config, err := sc.Backend.crlBuilder.getConfigWithUpdate(sc) config, err := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if err != nil { if err != nil {
return nil, fmt.Errorf("error revoking serial: %s: failed reading config: %w", serial, err) return nil, fmt.Errorf("error revoking serial: %s: failed reading config: %w", serial, err)
} }
@@ -647,18 +649,18 @@ func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, dat
return nil, logical.ErrReadOnly return nil, logical.ErrReadOnly
} }
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
return revokeCert(sc, config, cert) return revokeCert(sc, config, cert)
} }
func (b *backend) pathRotateCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRotateCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
b.revokeStorageLock.RLock() b.GetRevokeStorageLock().RLock()
defer b.revokeStorageLock.RUnlock() defer b.GetRevokeStorageLock().RUnlock()
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
warnings, crlErr := b.crlBuilder.rebuild(sc, false) warnings, crlErr := b.CrlBuilder().rebuild(sc, false)
if crlErr != nil { if crlErr != nil {
switch crlErr.(type) { switch crlErr.(type) {
case errutil.UserError: case errutil.UserError:
@@ -684,14 +686,14 @@ func (b *backend) pathRotateCRLRead(ctx context.Context, req *logical.Request, _
func (b *backend) pathRotateDeltaCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRotateDeltaCRLRead(ctx context.Context, req *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
cfg, err := b.crlBuilder.getConfigWithUpdate(sc) cfg, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil { if err != nil {
return nil, fmt.Errorf("error fetching CRL configuration: %w", err) return nil, fmt.Errorf("error fetching CRL configuration: %w", err)
} }
isEnabled := cfg.EnableDelta isEnabled := cfg.EnableDelta
warnings, crlErr := b.crlBuilder.rebuildDeltaCRLsIfForced(sc, true) warnings, crlErr := b.CrlBuilder().rebuildDeltaCRLsIfForced(sc, true)
if crlErr != nil { if crlErr != nil {
switch crlErr.(type) { switch crlErr.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -5,19 +5,20 @@ package pki
import ( import (
"context" "context"
"crypto/x509"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
func pathListRoles(b *backend) *framework.Path { func pathListRoles(b *backend) *framework.Path {
@@ -860,137 +861,26 @@ serviced by this role.`,
} }
} }
func (b *backend) getRole(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) { // GetRole loads a role from storage, validates it, and, where needed,
entry, err := s.Get(ctx, "role/"+n) // upgrades and stores it back. If the role does not exist,
// a nil, nil response is returned.

func (b *backend) GetRole(ctx context.Context, s logical.Storage, n string) (*issuing.RoleEntry, error) {
result, err := issuing.GetRole(ctx, s, n)
if err != nil { if err != nil {
if errors.Is(err, issuing.ErrRoleNotFound) {
return nil, nil
}
return nil, err return nil, err
} }
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
// Migrate existing saved entries and save back if changed
modified := false
if len(result.DeprecatedTTL) == 0 && len(result.Lease) != 0 {
result.DeprecatedTTL = result.Lease
result.Lease = ""
modified = true
}
if result.TTL == 0 && len(result.DeprecatedTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedTTL)
if err != nil {
return nil, err
}
result.TTL = parsed
result.DeprecatedTTL = ""
modified = true
}
if len(result.DeprecatedMaxTTL) == 0 && len(result.LeaseMax) != 0 {
result.DeprecatedMaxTTL = result.LeaseMax
result.LeaseMax = ""
modified = true
}
if result.MaxTTL == 0 && len(result.DeprecatedMaxTTL) != 0 {
parsed, err := parseutil.ParseDurationSecond(result.DeprecatedMaxTTL)
if err != nil {
return nil, err
}
result.MaxTTL = parsed
result.DeprecatedMaxTTL = ""
modified = true
}
if result.AllowBaseDomain {
result.AllowBaseDomain = false
result.AllowBareDomains = true
modified = true
}
if result.AllowedDomainsOld != "" {
result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",")
result.AllowedDomainsOld = ""
modified = true
}
if result.AllowedBaseDomain != "" {
found := false
for _, v := range result.AllowedDomains {
if v == result.AllowedBaseDomain {
found = true
break
}
}
if !found {
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
}
result.AllowedBaseDomain = ""
modified = true
}
if result.AllowWildcardCertificates == nil {
// While not the most secure default, when AllowWildcardCertificates isn't
// explicitly specified in the stored Role, we automatically upgrade it to
// true to preserve compatibility with previous versions of Vault. Once this
// field is set, this logic will not be triggered any more.
result.AllowWildcardCertificates = new(bool)
*result.AllowWildcardCertificates = true
modified = true
}
// Upgrade generate_lease in role
if result.GenerateLease == nil {
// All the new roles will have GenerateLease always set to a value. A
// nil value indicates that this role needs an upgrade. Set it to
// `true` to not alter its current behavior.
result.GenerateLease = new(bool)
*result.GenerateLease = true
modified = true
}
// Upgrade key usages
if result.KeyUsageOld != "" {
result.KeyUsage = strings.Split(result.KeyUsageOld, ",")
result.KeyUsageOld = ""
modified = true
}
// Upgrade OU
if result.OUOld != "" {
result.OU = strings.Split(result.OUOld, ",")
result.OUOld = ""
modified = true
}
// Upgrade Organization
if result.OrganizationOld != "" {
result.Organization = strings.Split(result.OrganizationOld, ",")
result.OrganizationOld = ""
modified = true
}
// Set the issuer field to default if not set. We want to do this
// unconditionally as we should probably never have an empty issuer
// on a stored roles.
if len(result.Issuer) == 0 {
result.Issuer = defaultRef
modified = true
}
// Update CN Validations to be the present default, "email,hostname"
if len(result.CNValidations) == 0 {
result.CNValidations = []string{"email", "hostname"}
modified = true
}
// Ensure the role is valid after updating. // Ensure the role is valid after updating.
_, err = validateRole(b, &result, ctx, s) _, err = validateRole(b, result, ctx, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if modified && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) { if result.WasModified && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) {
jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result) jsonEntry, err := logical.StorageEntryJSON("role/"+n, result)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1000,11 +890,10 @@ func (b *backend) getRole(ctx context.Context, s logical.Storage, n string) (*ro
return nil, err return nil, err
} }
} }
result.WasModified = false
} }
result.Name = n return result, nil
return &result, nil
} }
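The role-loading logic, including the legacy field migrations deleted above, now lives in issuing.GetRole; the backend wrapper only translates the package's not-found sentinel into the (nil, nil) contract the handlers expect and persists upgraded entries when WasModified is set. A stripped-down sketch of that error-translation pattern (errRoleNotFound, storedRole, and loadRole are stand-ins for the real package, and the storage write is elided):

package pki

import (
	"context"
	"errors"
	"fmt"
)

// errRoleNotFound stands in for issuing.ErrRoleNotFound.
var errRoleNotFound = errors.New("role not found")

// storedRole stands in for issuing.RoleEntry; WasModified signals that a
// legacy entry was upgraded while being read.
type storedRole struct {
	Name        string
	WasModified bool
}

// loadRole stands in for issuing.GetRole.
func loadRole(_ context.Context, name string) (*storedRole, error) {
	if name == "" {
		return nil, errRoleNotFound
	}
	return &storedRole{Name: name}, nil
}

// getRole shows the wrapper pattern from the hunk above: the sentinel error
// becomes a (nil, nil) result, upgraded entries are persisted when allowed,
// and the modification flag is cleared before the role is returned.
func getRole(ctx context.Context, name string) (*storedRole, error) {
	role, err := loadRole(ctx, name)
	if err != nil {
		if errors.Is(err, errRoleNotFound) {
			return nil, nil
		}
		return nil, err
	}
	if role.WasModified {
		fmt.Printf("persisting upgraded role %q\n", role.Name) // storage write elided
		role.WasModified = false
	}
	return role, nil
}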
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1022,7 +911,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
return logical.ErrorResponse("missing role name"), nil return logical.ErrorResponse("missing role name"), nil
} }
role, err := b.getRole(ctx, req.Storage, roleName) role, err := b.GetRole(ctx, req.Storage, roleName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1049,7 +938,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data
var err error var err error
name := data.Get("name").(string) name := data.Get("name").(string)
entry := &roleEntry{ entry := &issuing.RoleEntry{
MaxTTL: time.Duration(data.Get("max_ttl").(int)) * time.Second, MaxTTL: time.Duration(data.Get("max_ttl").(int)) * time.Second,
TTL: time.Duration(data.Get("ttl").(int)) * time.Second, TTL: time.Duration(data.Get("ttl").(int)) * time.Second,
AllowLocalhost: data.Get("allow_localhost").(bool), AllowLocalhost: data.Get("allow_localhost").(bool),
@@ -1156,7 +1045,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data
return resp, nil return resp, nil
} }
func validateRole(b *backend, entry *roleEntry, ctx context.Context, s logical.Storage) (*logical.Response, error) { func validateRole(b *backend, entry *issuing.RoleEntry, ctx context.Context, s logical.Storage) (*logical.Response, error) {
resp := &logical.Response{} resp := &logical.Response{}
var err error var err error
@@ -1194,11 +1083,11 @@ func validateRole(b *backend, entry *roleEntry, ctx context.Context, s logical.S
entry.Issuer = defaultRef entry.Issuer = defaultRef
} }
// Check that the issuers reference set resolves to something // Check that the issuers reference set resolves to something
if !b.useLegacyBundleCaStorage() { if !b.UseLegacyBundleCaStorage() {
sc := b.makeStorageContext(ctx, s) sc := b.makeStorageContext(ctx, s)
issuerId, err := sc.resolveIssuerReference(entry.Issuer) issuerId, err := sc.resolveIssuerReference(entry.Issuer)
if err != nil { if err != nil {
if issuerId == IssuerRefNotFound { if issuerId == issuing.IssuerRefNotFound {
resp = &logical.Response{} resp = &logical.Response{}
if entry.Issuer == defaultRef { if entry.Issuer == defaultRef {
resp.AddWarning("Issuing Certificate was set to default, but no default issuing certificate (configurable at /config/issuers) is currently set") resp.AddWarning("Issuing Certificate was set to default, but no default issuing certificate (configurable at /config/issuers) is currently set")
@@ -1241,7 +1130,7 @@ func getTimeWithExplicitDefault(data *framework.FieldData, field string, default
func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string) name := data.Get("name").(string)
oldEntry, err := b.getRole(ctx, req.Storage, name) oldEntry, err := b.GetRole(ctx, req.Storage, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1249,7 +1138,7 @@ func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data
return logical.ErrorResponse("Unable to fetch role entry to patch"), nil return logical.ErrorResponse("Unable to fetch role entry to patch"), nil
} }
entry := &roleEntry{ entry := &issuing.RoleEntry{
MaxTTL: getTimeWithExplicitDefault(data, "max_ttl", oldEntry.MaxTTL), MaxTTL: getTimeWithExplicitDefault(data, "max_ttl", oldEntry.MaxTTL),
TTL: getTimeWithExplicitDefault(data, "ttl", oldEntry.TTL), TTL: getTimeWithExplicitDefault(data, "ttl", oldEntry.TTL),
AllowLocalhost: getWithExplicitDefault(data, "allow_localhost", oldEntry.AllowLocalhost).(bool), AllowLocalhost: getWithExplicitDefault(data, "allow_localhost", oldEntry.AllowLocalhost).(bool),
@@ -1363,206 +1252,6 @@ func (b *backend) pathRolePatch(ctx context.Context, req *logical.Request, data
return resp, nil return resp, nil
} }
func parseKeyUsages(input []string) int {
var parsedKeyUsages x509.KeyUsage
for _, k := range input {
switch strings.ToLower(strings.TrimSpace(k)) {
case "digitalsignature":
parsedKeyUsages |= x509.KeyUsageDigitalSignature
case "contentcommitment":
parsedKeyUsages |= x509.KeyUsageContentCommitment
case "keyencipherment":
parsedKeyUsages |= x509.KeyUsageKeyEncipherment
case "dataencipherment":
parsedKeyUsages |= x509.KeyUsageDataEncipherment
case "keyagreement":
parsedKeyUsages |= x509.KeyUsageKeyAgreement
case "certsign":
parsedKeyUsages |= x509.KeyUsageCertSign
case "crlsign":
parsedKeyUsages |= x509.KeyUsageCRLSign
case "encipheronly":
parsedKeyUsages |= x509.KeyUsageEncipherOnly
case "decipheronly":
parsedKeyUsages |= x509.KeyUsageDecipherOnly
}
}
return int(parsedKeyUsages)
}
func parseExtKeyUsages(role *roleEntry) certutil.CertExtKeyUsage {
var parsedKeyUsages certutil.CertExtKeyUsage
if role.ServerFlag {
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
}
if role.ClientFlag {
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
}
if role.CodeSigningFlag {
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
}
if role.EmailProtectionFlag {
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
}
for _, k := range role.ExtKeyUsage {
switch strings.ToLower(strings.TrimSpace(k)) {
case "any":
parsedKeyUsages |= certutil.AnyExtKeyUsage
case "serverauth":
parsedKeyUsages |= certutil.ServerAuthExtKeyUsage
case "clientauth":
parsedKeyUsages |= certutil.ClientAuthExtKeyUsage
case "codesigning":
parsedKeyUsages |= certutil.CodeSigningExtKeyUsage
case "emailprotection":
parsedKeyUsages |= certutil.EmailProtectionExtKeyUsage
case "ipsecendsystem":
parsedKeyUsages |= certutil.IpsecEndSystemExtKeyUsage
case "ipsectunnel":
parsedKeyUsages |= certutil.IpsecTunnelExtKeyUsage
case "ipsecuser":
parsedKeyUsages |= certutil.IpsecUserExtKeyUsage
case "timestamping":
parsedKeyUsages |= certutil.TimeStampingExtKeyUsage
case "ocspsigning":
parsedKeyUsages |= certutil.OcspSigningExtKeyUsage
case "microsoftservergatedcrypto":
parsedKeyUsages |= certutil.MicrosoftServerGatedCryptoExtKeyUsage
case "netscapeservergatedcrypto":
parsedKeyUsages |= certutil.NetscapeServerGatedCryptoExtKeyUsage
}
}
return parsedKeyUsages
}
type roleEntry struct {
LeaseMax string `json:"lease_max"`
Lease string `json:"lease"`
DeprecatedMaxTTL string `json:"max_ttl"`
DeprecatedTTL string `json:"ttl"`
TTL time.Duration `json:"ttl_duration"`
MaxTTL time.Duration `json:"max_ttl_duration"`
AllowLocalhost bool `json:"allow_localhost"`
AllowedBaseDomain string `json:"allowed_base_domain"`
AllowedDomainsOld string `json:"allowed_domains,omitempty"`
AllowedDomains []string `json:"allowed_domains_list"`
AllowedDomainsTemplate bool `json:"allowed_domains_template"`
AllowBaseDomain bool `json:"allow_base_domain"`
AllowBareDomains bool `json:"allow_bare_domains"`
AllowTokenDisplayName bool `json:"allow_token_displayname"`
AllowSubdomains bool `json:"allow_subdomains"`
AllowGlobDomains bool `json:"allow_glob_domains"`
AllowWildcardCertificates *bool `json:"allow_wildcard_certificates,omitempty"`
AllowAnyName bool `json:"allow_any_name"`
EnforceHostnames bool `json:"enforce_hostnames"`
AllowIPSANs bool `json:"allow_ip_sans"`
ServerFlag bool `json:"server_flag"`
ClientFlag bool `json:"client_flag"`
CodeSigningFlag bool `json:"code_signing_flag"`
EmailProtectionFlag bool `json:"email_protection_flag"`
UseCSRCommonName bool `json:"use_csr_common_name"`
UseCSRSANs bool `json:"use_csr_sans"`
KeyType string `json:"key_type"`
KeyBits int `json:"key_bits"`
UsePSS bool `json:"use_pss"`
SignatureBits int `json:"signature_bits"`
MaxPathLength *int `json:",omitempty"`
KeyUsageOld string `json:"key_usage,omitempty"`
KeyUsage []string `json:"key_usage_list"`
ExtKeyUsage []string `json:"extended_key_usage_list"`
OUOld string `json:"ou,omitempty"`
OU []string `json:"ou_list"`
OrganizationOld string `json:"organization,omitempty"`
Organization []string `json:"organization_list"`
Country []string `json:"country"`
Locality []string `json:"locality"`
Province []string `json:"province"`
StreetAddress []string `json:"street_address"`
PostalCode []string `json:"postal_code"`
GenerateLease *bool `json:"generate_lease,omitempty"`
NoStore bool `json:"no_store"`
RequireCN bool `json:"require_cn"`
CNValidations []string `json:"cn_validations"`
AllowedOtherSANs []string `json:"allowed_other_sans"`
AllowedSerialNumbers []string `json:"allowed_serial_numbers"`
AllowedUserIDs []string `json:"allowed_user_ids"`
AllowedURISANs []string `json:"allowed_uri_sans"`
AllowedURISANsTemplate bool `json:"allowed_uri_sans_template"`
PolicyIdentifiers []string `json:"policy_identifiers"`
ExtKeyUsageOIDs []string `json:"ext_key_usage_oids"`
BasicConstraintsValidForNonCA bool `json:"basic_constraints_valid_for_non_ca"`
NotBeforeDuration time.Duration `json:"not_before_duration"`
NotAfter string `json:"not_after"`
Issuer string `json:"issuer"`
// Name is only set when the role has been stored, on the fly roles have a blank name
Name string `json:"-"`
}
func (r *roleEntry) ToResponseData() map[string]interface{} {
responseData := map[string]interface{}{
"ttl": int64(r.TTL.Seconds()),
"max_ttl": int64(r.MaxTTL.Seconds()),
"allow_localhost": r.AllowLocalhost,
"allowed_domains": r.AllowedDomains,
"allowed_domains_template": r.AllowedDomainsTemplate,
"allow_bare_domains": r.AllowBareDomains,
"allow_token_displayname": r.AllowTokenDisplayName,
"allow_subdomains": r.AllowSubdomains,
"allow_glob_domains": r.AllowGlobDomains,
"allow_wildcard_certificates": r.AllowWildcardCertificates,
"allow_any_name": r.AllowAnyName,
"allowed_uri_sans_template": r.AllowedURISANsTemplate,
"enforce_hostnames": r.EnforceHostnames,
"allow_ip_sans": r.AllowIPSANs,
"server_flag": r.ServerFlag,
"client_flag": r.ClientFlag,
"code_signing_flag": r.CodeSigningFlag,
"email_protection_flag": r.EmailProtectionFlag,
"use_csr_common_name": r.UseCSRCommonName,
"use_csr_sans": r.UseCSRSANs,
"key_type": r.KeyType,
"key_bits": r.KeyBits,
"signature_bits": r.SignatureBits,
"use_pss": r.UsePSS,
"key_usage": r.KeyUsage,
"ext_key_usage": r.ExtKeyUsage,
"ext_key_usage_oids": r.ExtKeyUsageOIDs,
"ou": r.OU,
"organization": r.Organization,
"country": r.Country,
"locality": r.Locality,
"province": r.Province,
"street_address": r.StreetAddress,
"postal_code": r.PostalCode,
"no_store": r.NoStore,
"allowed_other_sans": r.AllowedOtherSANs,
"allowed_serial_numbers": r.AllowedSerialNumbers,
"allowed_user_ids": r.AllowedUserIDs,
"allowed_uri_sans": r.AllowedURISANs,
"require_cn": r.RequireCN,
"cn_validations": r.CNValidations,
"policy_identifiers": r.PolicyIdentifiers,
"basic_constraints_valid_for_non_ca": r.BasicConstraintsValidForNonCA,
"not_before_duration": int64(r.NotBeforeDuration.Seconds()),
"not_after": r.NotAfter,
"issuer_ref": r.Issuer,
}
if r.MaxPathLength != nil {
responseData["max_path_length"] = r.MaxPathLength
}
if r.GenerateLease != nil {
responseData["generate_lease"] = r.GenerateLease
}
return responseData
}
func checkCNValidations(validations []string) ([]string, error) { func checkCNValidations(validations []string) ([]string, error) {
var haveDisabled bool var haveDisabled bool
var haveEmail bool var haveEmail bool

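The roleEntry struct, its ToResponseData method, and the parseKeyUsages/parseExtKeyUsages helpers removed above are not deleted outright; consistent with the commit's goal of breaking the monolith into sub-packages, they presumably now live in the new issuing package (the test files below decode into issuing.RoleEntry). The key-usage parsing itself is just an OR of x509.KeyUsage bits; a self-contained sketch of that pattern, illustrative only and separate from the relocated Vault code:

    package main

    import (
        "crypto/x509"
        "fmt"
        "strings"
    )

    // keyUsageBits mirrors the parseKeyUsages loop above: normalize each
    // name and OR the matching x509.KeyUsage flag into the result.
    func keyUsageBits(names []string) x509.KeyUsage {
        var usage x509.KeyUsage
        for _, n := range names {
            switch strings.ToLower(strings.TrimSpace(n)) {
            case "digitalsignature":
                usage |= x509.KeyUsageDigitalSignature
            case "keyencipherment":
                usage |= x509.KeyUsageKeyEncipherment
            case "certsign":
                usage |= x509.KeyUsageCertSign
            }
        }
        return usage
    }

    func main() {
        u := keyUsageBits([]string{"DigitalSignature", " KeyEncipherment "})
        fmt.Printf("%b\n", u) // 101: DigitalSignature and KeyEncipherment set
    }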

@@ -18,6 +18,8 @@ import (
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
) )
func TestPki_RoleGenerateLease(t *testing.T) { func TestPki_RoleGenerateLease(t *testing.T) {
@@ -69,7 +71,7 @@ func TestPki_RoleGenerateLease(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var role roleEntry var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil { if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -170,7 +172,7 @@ func TestPki_RoleKeyUsage(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var role roleEntry var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil { if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -205,7 +207,7 @@ func TestPki_RoleKeyUsage(t *testing.T) {
if entry == nil { if entry == nil {
t.Fatalf("role should not be nil") t.Fatalf("role should not be nil")
} }
var result roleEntry var result issuing.RoleEntry
if err := entry.DecodeJSON(&result); err != nil { if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -265,7 +267,7 @@ func TestPki_RoleOUOrganizationUpgrade(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var role roleEntry var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil { if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -305,7 +307,7 @@ func TestPki_RoleOUOrganizationUpgrade(t *testing.T) {
if entry == nil { if entry == nil {
t.Fatalf("role should not be nil") t.Fatalf("role should not be nil")
} }
var result roleEntry var result issuing.RoleEntry
if err := entry.DecodeJSON(&result); err != nil { if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -365,7 +367,7 @@ func TestPki_RoleAllowedDomains(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var role roleEntry var role issuing.RoleEntry
if err := entry.DecodeJSON(&role); err != nil { if err := entry.DecodeJSON(&role); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -399,7 +401,7 @@ func TestPki_RoleAllowedDomains(t *testing.T) {
if entry == nil { if entry == nil {
t.Fatalf("role should not be nil") t.Fatalf("role should not be nil")
} }
var result roleEntry var result issuing.RoleEntry
if err := entry.DecodeJSON(&result); err != nil { if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }


@@ -27,6 +27,9 @@ import (
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/parsing"
) )
func pathGenerateRoot(b *backend) *framework.Path { func pathGenerateRoot(b *backend) *framework.Path {
@@ -78,7 +81,7 @@ func (b *backend) pathCADeleteRoot(ctx context.Context, req *logical.Request, _
defer b.issuersLock.Unlock() defer b.issuersLock.Unlock()
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
if !b.useLegacyBundleCaStorage() { if !b.UseLegacyBundleCaStorage() {
issuers, err := sc.listIssuers() issuers, err := sc.listIssuers()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -132,7 +135,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
var err error var err error
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return logical.ErrorResponse("Can not create root CA until migration has completed"), nil return logical.ErrorResponse("Can not create root CA until migration has completed"), nil
} }
@@ -286,7 +289,8 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
// Also store it as just the certificate identified by serial number, so it // Also store it as just the certificate identified by serial number, so it
// can be revoked // can be revoked
key := "certs/" + normalizeSerial(cb.SerialNumber) key := "certs/" + normalizeSerial(cb.SerialNumber)
certsCounted := b.certsCounted.Load() certCounter := b.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = req.Storage.Put(ctx, &logical.StorageEntry{ err = req.Storage.Put(ctx, &logical.StorageEntry{
Key: key, Key: key,
Value: parsedBundle.CertificateBytes, Value: parsedBundle.CertificateBytes,
@@ -294,10 +298,10 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to store certificate locally: %w", err) return nil, fmt.Errorf("unable to store certificate locally: %w", err)
} }
b.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key) certCounter.IncrementTotalCertificatesCount(certsCounted, key)
// Build a fresh CRL // Build a fresh CRL
warnings, err = b.crlBuilder.rebuild(sc, true) warnings, err = b.CrlBuilder().rebuild(sc, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -327,7 +331,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error var err error
issuerName := getIssuerRef(data) issuerName := GetIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
@@ -337,7 +341,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
return logical.ErrorResponse(`The "format" path parameter must be "pem", "der" or "pem_bundle"`), nil return logical.ErrorResponse(`The "format" path parameter must be "pem", "der" or "pem_bundle"`), nil
} }
role := &roleEntry{ role := &issuing.RoleEntry{
OU: data.Get("ou").([]string), OU: data.Get("ou").([]string),
Organization: data.Get("organization").([]string), Organization: data.Get("organization").([]string),
Country: data.Get("country").([]string), Country: data.Get("country").([]string),
@@ -369,7 +373,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
var caErr error var caErr error
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, IssuanceUsage) signingBundle, caErr := sc.fetchCAInfo(issuerName, issuing.IssuanceUsage)
if caErr != nil { if caErr != nil {
switch caErr.(type) { switch caErr.(type) {
case errutil.UserError: case errutil.UserError:
@@ -420,7 +424,8 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
} }
key := "certs/" + normalizeSerialFromBigInt(parsedBundle.Certificate.SerialNumber) key := "certs/" + normalizeSerialFromBigInt(parsedBundle.Certificate.SerialNumber)
certsCounted := b.certsCounted.Load() certCounter := b.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
err = req.Storage.Put(ctx, &logical.StorageEntry{ err = req.Storage.Put(ctx, &logical.StorageEntry{
Key: key, Key: key,
Value: parsedBundle.CertificateBytes, Value: parsedBundle.CertificateBytes,
@@ -428,7 +433,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to store certificate locally: %w", err) return nil, fmt.Errorf("unable to store certificate locally: %w", err)
} }
b.ifCountEnabledIncrementTotalCertificatesCount(certsCounted, key) certCounter.IncrementTotalCertificatesCount(certsCounted, key)
return resp, nil return resp, nil
} }
@@ -512,19 +517,13 @@ func signIntermediateResponse(signingBundle *certutil.CAInfoBundle, parsedBundle
} }
func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error issuerName := GetIssuerRef(data)
issuerName := getIssuerRef(data)
if len(issuerName) == 0 { if len(issuerName) == 0 {
return logical.ErrorResponse("missing issuer reference"), nil return logical.ErrorResponse("missing issuer reference"), nil
} }
certPem := data.Get("certificate").(string) certPem := data.Get("certificate").(string)
block, _ := pem.Decode([]byte(certPem)) certs, err := parsing.ParseCertificatesFromString(certPem)
if block == nil || len(block.Bytes) == 0 {
return logical.ErrorResponse("certificate could not be PEM-decoded"), nil
}
certs, err := x509.ParseCertificates(block.Bytes)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error parsing certificate: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("error parsing certificate: %s", err)), nil
} }
@@ -540,9 +539,8 @@ func (b *backend) pathIssuerSignSelfIssued(ctx context.Context, req *logical.Req
return logical.ErrorResponse("given certificate is not self-issued"), nil return logical.ErrorResponse("given certificate is not self-issued"), nil
} }
var caErr error
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
signingBundle, caErr := sc.fetchCAInfo(issuerName, IssuanceUsage) signingBundle, caErr := sc.fetchCAInfo(issuerName, issuing.IssuanceUsage)
if caErr != nil { if caErr != nil {
switch caErr.(type) { switch caErr.(type) {
case errutil.UserError: case errutil.UserError:

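pathIssuerSignSelfIssued now delegates PEM handling to the new parsing sub-package instead of calling pem.Decode and x509.ParseCertificates inline. The helper's body is not visible in this diff; a plausible minimal sketch, assuming it wraps the same two calls that the removed code used:

    package parsing

    import (
        "crypto/x509"
        "encoding/pem"
        "errors"
    )

    // ParseCertificatesFromString is assumed to behave like the inline code it
    // replaced: decode one PEM block and parse the certificates inside it.
    func ParseCertificatesFromString(pemCerts string) ([]*x509.Certificate, error) {
        block, _ := pem.Decode([]byte(pemCerts))
        if block == nil || len(block.Bytes) == 0 {
            return nil, errors.New("certificate could not be PEM-decoded")
        }
        return x509.ParseCertificates(block.Bytes)
    }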

@@ -15,6 +15,7 @@ import (
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -1013,8 +1014,8 @@ func (b *backend) doTidyCertStore(ctx context.Context, req *logical.Request, log
} }
func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Request, logger hclog.Logger, config *tidyConfig) error { func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Request, logger hclog.Logger, config *tidyConfig) error {
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
// Fetch and parse our issuers so we can associate them if necessary. // Fetch and parse our issuers so we can associate them if necessary.
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
@@ -1047,9 +1048,9 @@ func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Reques
// Check for pause duration to reduce resource consumption. // Check for pause duration to reduce resource consumption.
if config.PauseDuration > (0 * time.Second) { if config.PauseDuration > (0 * time.Second) {
b.revokeStorageLock.Unlock() b.GetRevokeStorageLock().Unlock()
time.Sleep(config.PauseDuration) time.Sleep(config.PauseDuration)
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
} }
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial) revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
@@ -1092,7 +1093,7 @@ func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Reques
if config.IssuerAssocs { if config.IssuerAssocs {
if !isRevInfoIssuerValid(&revInfo, issuerIDCertMap) { if !isRevInfoIssuerValid(&revInfo, issuerIDCertMap) {
b.tidyStatusIncMissingIssuerCertCount() b.tidyStatusIncMissingIssuerCertCount()
revInfo.CertificateIssuer = issuerID("") revInfo.CertificateIssuer = issuing.IssuerID("")
storeCert = true storeCert = true
if associateRevokedCertWithIsssuer(&revInfo, revokedCert, issuerIDCertMap) { if associateRevokedCertWithIsssuer(&revInfo, revokedCert, issuerIDCertMap) {
fixedIssuers += 1 fixedIssuers += 1
@@ -1150,7 +1151,7 @@ func (b *backend) doTidyRevocationStore(ctx context.Context, req *logical.Reques
} }
if !config.AutoRebuild { if !config.AutoRebuild {
warnings, err := b.crlBuilder.rebuild(sc, false) warnings, err := b.CrlBuilder().rebuild(sc, false)
if err != nil { if err != nil {
return err return err
} }
@@ -1180,7 +1181,7 @@ func (b *backend) doTidyExpiredIssuers(ctx context.Context, req *logical.Request
// Short-circuit to avoid having to deal with the legacy mounts. While we // Short-circuit to avoid having to deal with the legacy mounts. While we
// could handle this case and remove these issuers, its somewhat // could handle this case and remove these issuers, its somewhat
// unexpected behavior and we'd prefer to finish the migration first. // unexpected behavior and we'd prefer to finish the migration first.
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return nil return nil
} }
@@ -1259,10 +1260,10 @@ func (b *backend) doTidyExpiredIssuers(ctx context.Context, req *logical.Request
// Removal of issuers is generally a good reason to rebuild the CRL, // Removal of issuers is generally a good reason to rebuild the CRL,
// even if auto-rebuild is enabled. // even if auto-rebuild is enabled.
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
warnings, err := b.crlBuilder.rebuild(sc, false) warnings, err := b.CrlBuilder().rebuild(sc, false)
if err != nil { if err != nil {
return err return err
} }
@@ -1290,7 +1291,7 @@ func (b *backend) doTidyMoveCABundle(ctx context.Context, req *logical.Request,
// Short-circuit to avoid moving the legacy bundle from under a legacy // Short-circuit to avoid moving the legacy bundle from under a legacy
// mount. // mount.
if b.useLegacyBundleCaStorage() { if b.UseLegacyBundleCaStorage() {
return nil return nil
} }
@@ -1353,8 +1354,8 @@ func (b *backend) doTidyRevocationQueue(ctx context.Context, req *logical.Reques
} }
// Grab locks as we're potentially modifying revocation-related storage. // Grab locks as we're potentially modifying revocation-related storage.
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
for cIndex, cluster := range clusters { for cIndex, cluster := range clusters {
if cluster[len(cluster)-1] == '/' { if cluster[len(cluster)-1] == '/' {
@@ -1375,9 +1376,9 @@ func (b *backend) doTidyRevocationQueue(ctx context.Context, req *logical.Reques
// Check for pause duration to reduce resource consumption. // Check for pause duration to reduce resource consumption.
if config.PauseDuration > (0 * time.Second) { if config.PauseDuration > (0 * time.Second) {
b.revokeStorageLock.Unlock() b.GetRevokeStorageLock().Unlock()
time.Sleep(config.PauseDuration) time.Sleep(config.PauseDuration)
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
} }
// Confirmation entries _should_ be handled by this cluster's // Confirmation entries _should_ be handled by this cluster's
@@ -1475,8 +1476,8 @@ func (b *backend) doTidyCrossRevocationStore(ctx context.Context, req *logical.R
} }
// Grab locks as we're potentially modifying revocation-related storage. // Grab locks as we're potentially modifying revocation-related storage.
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
for cIndex, cluster := range clusters { for cIndex, cluster := range clusters {
if cluster[len(cluster)-1] == '/' { if cluster[len(cluster)-1] == '/' {
@@ -1497,9 +1498,9 @@ func (b *backend) doTidyCrossRevocationStore(ctx context.Context, req *logical.R
// Check for pause duration to reduce resource consumption. // Check for pause duration to reduce resource consumption.
if config.PauseDuration > (0 * time.Second) { if config.PauseDuration > (0 * time.Second) {
b.revokeStorageLock.Unlock() b.GetRevokeStorageLock().Unlock()
time.Sleep(config.PauseDuration) time.Sleep(config.PauseDuration)
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
} }
ePath := cPath + serial ePath := cPath + serial
@@ -1547,7 +1548,7 @@ func (b *backend) doTidyAcme(ctx context.Context, req *logical.Request, logger h
b.tidyStatusLock.Unlock() b.tidyStatusLock.Unlock()
for _, thumbprint := range thumbprints { for _, thumbprint := range thumbprints {
err := b.tidyAcmeAccountByThumbprint(b.acmeState, sc, thumbprint, config.SafetyBuffer, config.AcmeAccountSafetyBuffer) err := b.tidyAcmeAccountByThumbprint(b.GetAcmeState(), sc, thumbprint, config.SafetyBuffer, config.AcmeAccountSafetyBuffer)
if err != nil { if err != nil {
logger.Warn("error tidying account %v: %v", thumbprint, err.Error()) logger.Warn("error tidying account %v: %v", thumbprint, err.Error())
} }
@@ -1567,13 +1568,13 @@ func (b *backend) doTidyAcme(ctx context.Context, req *logical.Request, logger h
} }
// Clean up any unused EAB // Clean up any unused EAB
eabIds, err := b.acmeState.ListEabIds(sc) eabIds, err := b.GetAcmeState().ListEabIds(sc)
if err != nil { if err != nil {
return fmt.Errorf("failed listing EAB ids: %w", err) return fmt.Errorf("failed listing EAB ids: %w", err)
} }
for _, eabId := range eabIds { for _, eabId := range eabIds {
eab, err := b.acmeState.LoadEab(sc, eabId) eab, err := b.GetAcmeState().LoadEab(sc, eabId)
if err != nil { if err != nil {
if errors.Is(err, ErrStorageItemNotFound) { if errors.Is(err, ErrStorageItemNotFound) {
// We don't need to worry about a consumed EAB // We don't need to worry about a consumed EAB
@@ -1584,7 +1585,7 @@ func (b *backend) doTidyAcme(ctx context.Context, req *logical.Request, logger h
eabExpiration := eab.CreatedOn.Add(config.AcmeAccountSafetyBuffer) eabExpiration := eab.CreatedOn.Add(config.AcmeAccountSafetyBuffer)
if time.Now().After(eabExpiration) { if time.Now().After(eabExpiration) {
_, err := b.acmeState.DeleteEab(sc, eabId) _, err := b.GetAcmeState().DeleteEab(sc, eabId)
if err != nil { if err != nil {
return fmt.Errorf("failed to tidy eab %s: %w", eabId, err) return fmt.Errorf("failed to tidy eab %s: %w", eabId, err)
} }
@@ -1669,15 +1670,17 @@ func (b *backend) pathTidyStatusRead(_ context.Context, _ *logical.Request, _ *f
resp.Data["internal_backend_uuid"] = b.backendUUID resp.Data["internal_backend_uuid"] = b.backendUUID
if b.certCountEnabled.Load() { certCounter := b.GetCertificateCounter()
resp.Data["current_cert_store_count"] = b.certCount.Load() if certCounter.IsEnabled() {
resp.Data["current_revoked_cert_count"] = b.revokedCertCount.Load() resp.Data["current_cert_store_count"] = certCounter.CertificateCount()
if !b.certsCounted.Load() { resp.Data["current_revoked_cert_count"] = certCounter.RevokedCount()
if !certCounter.IsInitialized() {
resp.AddWarning("Certificates in storage are still being counted, current counts provided may be " + resp.AddWarning("Certificates in storage are still being counted, current counts provided may be " +
"inaccurate") "inaccurate")
} }
if b.certCountError != "" { certError := certCounter.Error()
resp.Data["certificate_counting_error"] = b.certCountError if certError != nil {
resp.Data["certificate_counting_error"] = certError.Error()
} }
} }
@@ -1925,7 +1928,7 @@ func (b *backend) tidyStatusIncCertStoreCount() {
b.tidyStatus.certStoreDeletedCount++ b.tidyStatus.certStoreDeletedCount++
b.ifCountEnabledDecrementTotalCertificatesCountReport() b.GetCertificateCounter().DecrementTotalCertificatesCountReport()
} }
func (b *backend) tidyStatusIncRevokedCertCount() { func (b *backend) tidyStatusIncRevokedCertCount() {
@@ -1934,7 +1937,7 @@ func (b *backend) tidyStatusIncRevokedCertCount() {
b.tidyStatus.revokedCertDeletedCount++ b.tidyStatus.revokedCertDeletedCount++
b.ifCountEnabledDecrementTotalRevokedCertificatesCountReport() b.GetCertificateCounter().DecrementTotalRevokedCertificatesCountReport()
} }
func (b *backend) tidyStatusIncMissingIssuerCertCount() { func (b *backend) tidyStatusIncMissingIssuerCertCount() {

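The recurring pattern in this file is the same as elsewhere in the commit: direct access to unexported backend fields (b.revokeStorageLock, b.certsCounted, b.certCount, b.revokedCertCount, b.certCountError, b.acmeState) is replaced with exported accessors (GetRevokeStorageLock, GetCertificateCounter, GetAcmeState) so code moved into sub-packages can keep reaching this state. The counter object itself is not shown in this diff; a partial sketch of the shape implied by the call sites, with all internals assumed:

    package pki // placement assumed; only the accessor call sites appear in this diff

    import "sync/atomic"

    // Partial sketch of the counter handed out by GetCertificateCounter().
    // The diff shows IsEnabled, IsInitialized, CertificateCount, RevokedCount,
    // Error, IncrementTotalCertificatesCount and the Decrement...CountReport
    // methods being called; the fields and bookkeeping below are assumptions.
    type certificateCounter struct {
        enabled     atomic.Bool
        initialized atomic.Bool
        certCount   atomic.Uint32
        revoked     atomic.Uint32
    }

    func (c *certificateCounter) IsEnabled() bool          { return c.enabled.Load() }
    func (c *certificateCounter) IsInitialized() bool      { return c.initialized.Load() }
    func (c *certificateCounter) CertificateCount() uint32 { return c.certCount.Load() }
    func (c *certificateCounter) RevokedCount() uint32     { return c.revoked.Load() }

    // Only bump the running total once the initial storage walk has finished;
    // the real implementation must also reconcile keys written while that
    // walk is still in progress (hence the key argument).
    func (c *certificateCounter) IncrementTotalCertificatesCount(certsCounted bool, key string) {
        if c.enabled.Load() && certsCounted {
            c.certCount.Add(1)
        }
    }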

@@ -18,18 +18,18 @@ const (
minUnifiedTransferDelay = 30 * time.Minute minUnifiedTransferDelay = 30 * time.Minute
) )
type unifiedTransferStatus struct { type UnifiedTransferStatus struct {
isRunning atomic.Bool isRunning atomic.Bool
lastRun time.Time lastRun time.Time
forceRerun atomic.Bool forceRerun atomic.Bool
} }
func (uts *unifiedTransferStatus) forceRun() { func (uts *UnifiedTransferStatus) forceRun() {
uts.forceRerun.Store(true) uts.forceRerun.Store(true)
} }
func newUnifiedTransferStatus() *unifiedTransferStatus { func newUnifiedTransferStatus() *UnifiedTransferStatus {
return &unifiedTransferStatus{} return &UnifiedTransferStatus{}
} }
// runUnifiedTransfer meant to run as a background, this will process all and // runUnifiedTransfer meant to run as a background, this will process all and
@@ -37,7 +37,7 @@ func newUnifiedTransferStatus() *unifiedTransferStatus {
// is enabled. // is enabled.
func runUnifiedTransfer(sc *storageContext) { func runUnifiedTransfer(sc *storageContext) {
b := sc.Backend b := sc.Backend
status := b.unifiedTransferStatus status := b.GetUnifiedTransferStatus()
isPerfStandby := b.System().ReplicationState().HasState(consts.ReplicationDRSecondary | consts.ReplicationPerformanceStandby) isPerfStandby := b.System().ReplicationState().HasState(consts.ReplicationDRSecondary | consts.ReplicationPerformanceStandby)
@@ -46,7 +46,7 @@ func runUnifiedTransfer(sc *storageContext) {
return return
} }
config, err := b.crlBuilder.getConfigWithUpdate(sc) config, err := b.CrlBuilder().getConfigWithUpdate(sc)
if err != nil { if err != nil {
b.Logger().Error("failed to retrieve crl config from storage for unified transfer background process", b.Logger().Error("failed to retrieve crl config from storage for unified transfer background process",
"error", err) "error", err)
@@ -125,7 +125,7 @@ func doUnifiedTransferMissingLocalSerials(sc *storageContext, clusterId string)
errCount := 0 errCount := 0
for i, serialNum := range localRevokedSerialNums { for i, serialNum := range localRevokedSerialNums {
if i%25 == 0 { if i%25 == 0 {
config, _ := sc.Backend.crlBuilder.getConfigWithUpdate(sc) config, _ := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if config != nil && !config.UnifiedCRL { if config != nil && !config.UnifiedCRL {
return errors.New("unified crl has been disabled after we started, stopping") return errors.New("unified crl has been disabled after we started, stopping")
} }
@@ -224,7 +224,7 @@ func doUnifiedTransferMissingDeltaWALSerials(sc *storageContext, clusterId strin
errCount := 0 errCount := 0
for index, serial := range localWALEntries { for index, serial := range localWALEntries {
if index%25 == 0 { if index%25 == 0 {
config, _ := sc.Backend.crlBuilder.getConfigWithUpdate(sc) config, _ := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if config != nil && (!config.UnifiedCRL || !config.EnableDelta) { if config != nil && (!config.UnifiedCRL || !config.EnableDelta) {
return errors.New("unified or delta CRLs have been disabled after we started, stopping") return errors.New("unified or delta CRLs have been disabled after we started, stopping")
} }

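unifiedTransferStatus becomes the exported UnifiedTransferStatus, and callers now reach it through b.GetUnifiedTransferStatus() rather than the field. Only the call site is visible here; the accessor is presumably a trivial getter along these lines:

    // Assumed accessor; the field name on the backend struct is not shown
    // in this hunk.
    func (b *backend) GetUnifiedTransferStatus() *UnifiedTransferStatus {
        return b.unifiedTransferStatus
    }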

@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package pki_backend
import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/logical"
)
type SystemViewGetter interface {
System() logical.SystemView
}
type MountInfo interface {
BackendUUID() string
}
type Logger interface {
Logger() log.Logger
}

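These are new, deliberately small capability interfaces in the pki_backend sub-package. Go's implicit interface satisfaction means the concrete backend can implement them without referencing this package, provided it exposes matching System(), BackendUUID() and Logger() methods. A hypothetical consumer inside the sub-package could then depend on a composite of the three instead of on the pki package itself (assuming the same imports as the file above; this example is not part of the commit):

    // Hypothetical composite of the three interfaces defined above.
    type backendHandle interface {
        SystemViewGetter
        MountInfo
        Logger
    }

    func logMountDetails(b backendHandle) {
        b.Logger().Info("pki mount details",
            "backend_uuid", b.BackendUUID(),
            "local_mount", b.System().LocalMount())
    }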

@@ -49,8 +49,8 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, _
return nil, fmt.Errorf("could not find serial in internal secret data") return nil, fmt.Errorf("could not find serial in internal secret data")
} }
b.revokeStorageLock.Lock() b.GetRevokeStorageLock().Lock()
defer b.revokeStorageLock.Unlock() defer b.GetRevokeStorageLock().Unlock()
sc := b.makeStorageContext(ctx, req.Storage) sc := b.makeStorageContext(ctx, req.Storage)
serial := serialInt.(string) serial := serialInt.(string)
@@ -77,7 +77,7 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, _
return nil, nil return nil, nil
} }
config, err := sc.Backend.crlBuilder.getConfigWithUpdate(sc) config, err := sc.Backend.CrlBuilder().getConfigWithUpdate(sc)
if err != nil { if err != nil {
return nil, fmt.Errorf("error revoking serial: %s: failed reading config: %w", serial, err) return nil, fmt.Errorf("error revoking serial: %s: failed reading config: %w", serial, err)
} }

File diff suppressed because it is too large


@@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -19,16 +20,16 @@ import (
// and we need to perform it again... // and we need to perform it again...
const ( const (
latestMigrationVersion = 2 latestMigrationVersion = 2
legacyBundleShimID = issuerID("legacy-entry-shim-id") legacyBundleShimID = issuing.LegacyBundleShimID
legacyBundleShimKeyID = keyID("legacy-entry-shim-key-id") legacyBundleShimKeyID = issuing.LegacyBundleShimKeyID
) )
type legacyBundleMigrationLog struct { type legacyBundleMigrationLog struct {
Hash string `json:"hash"` Hash string `json:"hash"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
CreatedIssuer issuerID `json:"issuer_id"` CreatedIssuer issuing.IssuerID `json:"issuer_id"`
CreatedKey keyID `json:"key_id"` CreatedKey issuing.KeyID `json:"key_id"`
MigrationVersion int `json:"migrationVersion"` MigrationVersion int `json:"migrationVersion"`
} }
type migrationInfo struct { type migrationInfo struct {
@@ -84,8 +85,8 @@ func migrateStorage(ctx context.Context, b *backend, s logical.Storage) error {
return nil return nil
} }
var issuerIdentifier issuerID var issuerIdentifier issuing.IssuerID
var keyIdentifier keyID var keyIdentifier issuing.KeyID
sc := b.makeStorageContext(ctx, s) sc := b.makeStorageContext(ctx, s)
if migrationInfo.legacyBundle != nil { if migrationInfo.legacyBundle != nil {
// When the legacy bundle still exists, there's three scenarios we // When the legacy bundle still exists, there's three scenarios we
@@ -120,7 +121,7 @@ func migrateStorage(ctx context.Context, b *backend, s logical.Storage) error {
// Since we do not have all the mount information available we must schedule // Since we do not have all the mount information available we must schedule
// the CRL to be rebuilt at a later time. // the CRL to be rebuilt at a later time.
b.crlBuilder.requestRebuildIfActiveNode(b) b.CrlBuilder().requestRebuildIfActiveNode(b)
} }
} }
@@ -202,33 +203,6 @@ func setLegacyBundleMigrationLog(ctx context.Context, s logical.Storage, lbm *le
return s.Put(ctx, json) return s.Put(ctx, json)
} }
func getLegacyCertBundle(ctx context.Context, s logical.Storage) (*issuerEntry, *certutil.CertBundle, error) { func getLegacyCertBundle(ctx context.Context, s logical.Storage) (*issuing.IssuerEntry, *certutil.CertBundle, error) {
entry, err := s.Get(ctx, legacyCertBundlePath) return issuing.GetLegacyCertBundle(ctx, s)
if err != nil {
return nil, nil, err
}
if entry == nil {
return nil, nil, nil
}
cb := &certutil.CertBundle{}
err = entry.DecodeJSON(cb)
if err != nil {
return nil, nil, err
}
// Fake a storage entry with backwards compatibility in mind.
issuer := &issuerEntry{
ID: legacyBundleShimID,
KeyID: legacyBundleShimKeyID,
Name: "legacy-entry-shim",
Certificate: cb.Certificate,
CAChain: cb.CAChain,
SerialNumber: cb.SerialNumber,
LeafNotAfterBehavior: certutil.ErrNotAfterBehavior,
}
issuer.Usage.ToggleUsage(AllIssuerUsages)
return issuer, cb, nil
} }


@@ -9,6 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -23,7 +24,7 @@ func Test_migrateStorageEmptyStorage(t *testing.T) {
// Reset the version the helper above set to 1. // Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0) b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.") require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
request := &logical.InitializationRequest{Storage: s} request := &logical.InitializationRequest{Storage: s}
err := b.initialize(ctx, request) err := b.initialize(ctx, request)
@@ -48,7 +49,7 @@ func Test_migrateStorageEmptyStorage(t *testing.T) {
require.Empty(t, logEntry.CreatedIssuer) require.Empty(t, logEntry.CreatedIssuer)
require.Empty(t, logEntry.CreatedKey) require.Empty(t, logEntry.CreatedKey)
require.False(t, b.useLegacyBundleCaStorage(), "post migration we are still told to use legacy storage") require.False(t, b.UseLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
// Make sure we can re-run the migration without issues // Make sure we can re-run the migration without issues
request = &logical.InitializationRequest{Storage: s} request = &logical.InitializationRequest{Storage: s}
@@ -72,7 +73,7 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
// Reset the version the helper above set to 1. // Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0) b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.") require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
bundle := genCertBundle(t, b, s) bundle := genCertBundle(t, b, s)
// Clear everything except for the key // Clear everything except for the key
@@ -106,7 +107,7 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
"Hash value (%s) should not have been empty", logEntry.Hash) "Hash value (%s) should not have been empty", logEntry.Hash)
require.True(t, startTime.Before(logEntry.Created), require.True(t, startTime.Before(logEntry.Created),
"created log entry time (%v) was before our start time(%v)?", logEntry.Created, startTime) "created log entry time (%v) was before our start time(%v)?", logEntry.Created, startTime)
require.Equal(t, logEntry.CreatedIssuer, issuerID("")) require.Equal(t, logEntry.CreatedIssuer, issuing.IssuerID(""))
require.Equal(t, logEntry.CreatedKey, keyIds[0]) require.Equal(t, logEntry.CreatedKey, keyIds[0])
keyId := keyIds[0] keyId := keyIds[0]
@@ -126,11 +127,11 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
// Make sure we setup the default values // Make sure we setup the default values
keysConfig, err := sc.getKeysConfig() keysConfig, err := sc.getKeysConfig()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &keyConfigEntry{DefaultKeyId: keyId}, keysConfig) require.Equal(t, &issuing.KeyConfigEntry{DefaultKeyId: keyId}, keysConfig)
issuersConfig, err := sc.getIssuersConfig() issuersConfig, err := sc.getIssuersConfig()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, issuerID(""), issuersConfig.DefaultIssuerId) require.Equal(t, issuing.IssuerID(""), issuersConfig.DefaultIssuerId)
// Make sure if we attempt to re-run the migration nothing happens... // Make sure if we attempt to re-run the migration nothing happens...
err = migrateStorage(ctx, b, s) err = migrateStorage(ctx, b, s)
@@ -142,7 +143,7 @@ func Test_migrateStorageOnlyKey(t *testing.T) {
require.Equal(t, logEntry.Created, logEntry2.Created) require.Equal(t, logEntry.Created, logEntry2.Created)
require.Equal(t, logEntry.Hash, logEntry2.Hash) require.Equal(t, logEntry.Hash, logEntry2.Hash)
require.False(t, b.useLegacyBundleCaStorage(), "post migration we are still told to use legacy storage") require.False(t, b.UseLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
} }
func Test_migrateStorageSimpleBundle(t *testing.T) { func Test_migrateStorageSimpleBundle(t *testing.T) {
@@ -154,7 +155,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
// Reset the version the helper above set to 1. // Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0) b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.") require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
bundle := genCertBundle(t, b, s) bundle := genCertBundle(t, b, s)
json, err := logical.StorageEntryJSON(legacyCertBundlePath, bundle) json, err := logical.StorageEntryJSON(legacyCertBundlePath, bundle)
@@ -204,7 +205,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
require.Equal(t, keyId, issuer.KeyID) require.Equal(t, keyId, issuer.KeyID)
require.Empty(t, issuer.ManualChain) require.Empty(t, issuer.ManualChain)
require.Equal(t, []string{bundle.Certificate + "\n"}, issuer.CAChain) require.Equal(t, []string{bundle.Certificate + "\n"}, issuer.CAChain)
require.Equal(t, AllIssuerUsages, issuer.Usage) require.Equal(t, issuing.AllIssuerUsages, issuer.Usage)
require.Equal(t, certutil.ErrNotAfterBehavior, issuer.LeafNotAfterBehavior) require.Equal(t, certutil.ErrNotAfterBehavior, issuer.LeafNotAfterBehavior)
require.Equal(t, keyId, key.ID) require.Equal(t, keyId, key.ID)
@@ -219,7 +220,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
// Make sure we setup the default values // Make sure we setup the default values
keysConfig, err := sc.getKeysConfig() keysConfig, err := sc.getKeysConfig()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &keyConfigEntry{DefaultKeyId: keyId}, keysConfig) require.Equal(t, &issuing.KeyConfigEntry{DefaultKeyId: keyId}, keysConfig)
issuersConfig, err := sc.getIssuersConfig() issuersConfig, err := sc.getIssuersConfig()
require.NoError(t, err) require.NoError(t, err)
@@ -235,7 +236,7 @@ func Test_migrateStorageSimpleBundle(t *testing.T) {
require.Equal(t, logEntry.Created, logEntry2.Created) require.Equal(t, logEntry.Created, logEntry2.Created)
require.Equal(t, logEntry.Hash, logEntry2.Hash) require.Equal(t, logEntry.Hash, logEntry2.Hash)
require.False(t, b.useLegacyBundleCaStorage(), "post migration we are still told to use legacy storage") require.False(t, b.UseLegacyBundleCaStorage(), "post migration we are still told to use legacy storage")
// Make sure we can re-process a migration from scratch for whatever reason // Make sure we can re-process a migration from scratch for whatever reason
err = s.Delete(ctx, legacyMigrationBundleLogKey) err = s.Delete(ctx, legacyMigrationBundleLogKey)
@@ -296,8 +297,8 @@ func TestMigration_OnceChainRebuild(t *testing.T) {
// //
// Afterwards, we mutate these issuers to only point at themselves and // Afterwards, we mutate these issuers to only point at themselves and
// write back out. // write back out.
var rootIssuerId issuerID var rootIssuerId issuing.IssuerID
var intIssuerId issuerID var intIssuerId issuing.IssuerID
for _, issuerId := range issuerIds { for _, issuerId := range issuerIds {
issuer, err := sc.fetchIssuerById(issuerId) issuer, err := sc.fetchIssuerById(issuerId)
require.NoError(t, err) require.NoError(t, err)
@@ -368,7 +369,7 @@ func TestExpectedOpsWork_PreMigration(t *testing.T) {
b, s := CreateBackendWithStorage(t) b, s := CreateBackendWithStorage(t)
// Reset the version the helper above set to 1. // Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0) b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.") require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
bundle := genCertBundle(t, b, s) bundle := genCertBundle(t, b, s)
json, err := logical.StorageEntryJSON(legacyCertBundlePath, bundle) json, err := logical.StorageEntryJSON(legacyCertBundlePath, bundle)
@@ -601,7 +602,7 @@ func TestBackupBundle(t *testing.T) {
// Reset the version the helper above set to 1. // Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0) b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.") require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
// Create an empty request and tidy configuration for us. // Create an empty request and tidy configuration for us.
req := &logical.Request{ req := &logical.Request{
@@ -793,7 +794,7 @@ func TestDeletedIssuersPostMigration(t *testing.T) {
// Reset the version the helper above set to 1. // Reset the version the helper above set to 1.
b.pkiStorageVersion.Store(0) b.pkiStorageVersion.Store(0)
require.True(t, b.useLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.") require.True(t, b.UseLegacyBundleCaStorage(), "pre migration we should have been told to use legacy storage.")
// Create a legacy CA bundle and write it out. // Create a legacy CA bundle and write it out.
bundle := genCertBundle(t, b, s) bundle := genCertBundle(t, b, s)


@@ -8,6 +8,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -22,27 +23,27 @@ func Test_ConfigsRoundTrip(t *testing.T) {
sc := b.makeStorageContext(ctx, s) sc := b.makeStorageContext(ctx, s)
// Create an empty key, issuer for testing. // Create an empty key, issuer for testing.
key := keyEntry{ID: genKeyId()} key := issuing.KeyEntry{ID: genKeyId()}
err := sc.writeKey(key) err := sc.writeKey(key)
require.NoError(t, err) require.NoError(t, err)
issuer := &issuerEntry{ID: genIssuerId()} issuer := &issuing.IssuerEntry{ID: genIssuerId()}
err = sc.writeIssuer(issuer) err = sc.writeIssuer(issuer)
require.NoError(t, err) require.NoError(t, err)
// Verify we handle nothing stored properly // Verify we handle nothing stored properly
keyConfigEmpty, err := sc.getKeysConfig() keyConfigEmpty, err := sc.getKeysConfig()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &keyConfigEntry{}, keyConfigEmpty) require.Equal(t, &issuing.KeyConfigEntry{}, keyConfigEmpty)
issuerConfigEmpty, err := sc.getIssuersConfig() issuerConfigEmpty, err := sc.getIssuersConfig()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &issuerConfigEntry{}, issuerConfigEmpty) require.Equal(t, &issuing.IssuerConfigEntry{}, issuerConfigEmpty)
// Now attempt to store and reload properly // Now attempt to store and reload properly
origKeyConfig := &keyConfigEntry{ origKeyConfig := &issuing.KeyConfigEntry{
DefaultKeyId: key.ID, DefaultKeyId: key.ID,
} }
origIssuerConfig := &issuerConfigEntry{ origIssuerConfig := &issuing.IssuerConfigEntry{
DefaultIssuerId: issuer.ID, DefaultIssuerId: issuer.ID,
} }
@@ -98,12 +99,12 @@ func Test_IssuerRoundTrip(t *testing.T) {
keys, err := sc.listKeys() keys, err := sc.listKeys()
require.NoError(t, err) require.NoError(t, err)
require.ElementsMatch(t, []keyID{key1.ID, key2.ID}, keys) require.ElementsMatch(t, []issuing.KeyID{key1.ID, key2.ID}, keys)
issuers, err := sc.listIssuers() issuers, err := sc.listIssuers()
require.NoError(t, err) require.NoError(t, err)
require.ElementsMatch(t, []issuerID{issuer1.ID, issuer2.ID}, issuers) require.ElementsMatch(t, []issuing.IssuerID{issuer1.ID, issuer2.ID}, issuers)
} }
func Test_KeysIssuerImport(t *testing.T) { func Test_KeysIssuerImport(t *testing.T) {
@@ -183,7 +184,7 @@ func Test_IssuerUpgrade(t *testing.T) {
// Make sure that we add OCSP signing to v0 issuers if CRLSigning is enabled // Make sure that we add OCSP signing to v0 issuers if CRLSigning is enabled
issuer, _ := genIssuerAndKey(t, b, s) issuer, _ := genIssuerAndKey(t, b, s)
issuer.Version = 0 issuer.Version = 0
issuer.Usage.ToggleUsage(OCSPSigningUsage) issuer.Usage.ToggleUsage(issuing.OCSPSigningUsage)
err := sc.writeIssuer(&issuer) err := sc.writeIssuer(&issuer)
require.NoError(t, err, "failed writing out issuer") require.NoError(t, err, "failed writing out issuer")
@@ -192,13 +193,13 @@ func Test_IssuerUpgrade(t *testing.T) {
require.NoError(t, err, "failed fetching issuer") require.NoError(t, err, "failed fetching issuer")
require.Equal(t, uint(1), newIssuer.Version) require.Equal(t, uint(1), newIssuer.Version)
require.True(t, newIssuer.Usage.HasUsage(OCSPSigningUsage)) require.True(t, newIssuer.Usage.HasUsage(issuing.OCSPSigningUsage))
// If CRLSigning is not present on a v0, we should not have OCSP signing after upgrade. // If CRLSigning is not present on a v0, we should not have OCSP signing after upgrade.
issuer, _ = genIssuerAndKey(t, b, s) issuer, _ = genIssuerAndKey(t, b, s)
issuer.Version = 0 issuer.Version = 0
issuer.Usage.ToggleUsage(OCSPSigningUsage) issuer.Usage.ToggleUsage(issuing.OCSPSigningUsage)
issuer.Usage.ToggleUsage(CRLSigningUsage) issuer.Usage.ToggleUsage(issuing.CRLSigningUsage)
err = sc.writeIssuer(&issuer) err = sc.writeIssuer(&issuer)
require.NoError(t, err, "failed writing out issuer") require.NoError(t, err, "failed writing out issuer")
@@ -207,15 +208,15 @@ func Test_IssuerUpgrade(t *testing.T) {
require.NoError(t, err, "failed fetching issuer") require.NoError(t, err, "failed fetching issuer")
require.Equal(t, uint(1), newIssuer.Version) require.Equal(t, uint(1), newIssuer.Version)
require.False(t, newIssuer.Usage.HasUsage(OCSPSigningUsage)) require.False(t, newIssuer.Usage.HasUsage(issuing.OCSPSigningUsage))
} }
func genIssuerAndKey(t *testing.T, b *backend, s logical.Storage) (issuerEntry, keyEntry) { func genIssuerAndKey(t *testing.T, b *backend, s logical.Storage) (issuing.IssuerEntry, issuing.KeyEntry) {
certBundle := genCertBundle(t, b, s) certBundle := genCertBundle(t, b, s)
keyId := genKeyId() keyId := genKeyId()
pkiKey := keyEntry{ pkiKey := issuing.KeyEntry{
ID: keyId, ID: keyId,
PrivateKeyType: certBundle.PrivateKeyType, PrivateKeyType: certBundle.PrivateKeyType,
PrivateKey: strings.TrimSpace(certBundle.PrivateKey) + "\n", PrivateKey: strings.TrimSpace(certBundle.PrivateKey) + "\n",
@@ -223,14 +224,14 @@ func genIssuerAndKey(t *testing.T, b *backend, s logical.Storage) (issuerEntry,
issuerId := genIssuerId() issuerId := genIssuerId()
pkiIssuer := issuerEntry{ pkiIssuer := issuing.IssuerEntry{
ID: issuerId, ID: issuerId,
KeyID: keyId, KeyID: keyId,
Certificate: strings.TrimSpace(certBundle.Certificate) + "\n", Certificate: strings.TrimSpace(certBundle.Certificate) + "\n",
CAChain: certBundle.CAChain, CAChain: certBundle.CAChain,
SerialNumber: certBundle.SerialNumber, SerialNumber: certBundle.SerialNumber,
Usage: AllIssuerUsages, Usage: issuing.AllIssuerUsages,
Version: latestIssuerVersion, Version: issuing.LatestIssuerVersion,
} }
return pkiIssuer, pkiKey return pkiIssuer, pkiKey


@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -17,10 +18,10 @@ const (
) )
type unifiedRevocationEntry struct { type unifiedRevocationEntry struct {
SerialNumber string `json:"-"` SerialNumber string `json:"-"`
CertExpiration time.Time `json:"certificate_expiration_utc"` CertExpiration time.Time `json:"certificate_expiration_utc"`
RevocationTimeUTC time.Time `json:"revocation_time_utc"` RevocationTimeUTC time.Time `json:"revocation_time_utc"`
CertificateIssuer issuerID `json:"issuer_id"` CertificateIssuer issuing.IssuerID `json:"issuer_id"`
} }
func getUnifiedRevocationBySerial(sc *storageContext, serial string) (*unifiedRevocationEntry, error) { func getUnifiedRevocationBySerial(sc *storageContext, serial string) (*unifiedRevocationEntry, error) {


@@ -4,7 +4,6 @@
package pki package pki
import ( import (
"crypto"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"math/big" "math/big"
@@ -14,6 +13,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
@@ -24,7 +25,7 @@ import (
const ( const (
managedKeyNameArg = "managed_key_name" managedKeyNameArg = "managed_key_name"
managedKeyIdArg = "managed_key_id" managedKeyIdArg = "managed_key_id"
defaultRef = "default" defaultRef = issuing.DefaultRef
// Constants for If-Modified-Since operation // Constants for If-Modified-Since operation
headerIfModifiedSince = "If-Modified-Since" headerIfModifiedSince = "If-Modified-Since"
@@ -92,26 +93,6 @@ type managedKeyId interface {
String() string String() string
} }
type (
UUIDKey string
NameKey string
)
func (u UUIDKey) String() string {
return string(u)
}
func (n NameKey) String() string {
return string(n)
}
type managedKeyInfo struct {
publicKey crypto.PublicKey
keyType certutil.PrivateKeyType
name NameKey
uuid UUIDKey
}
// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the // getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the
// request API data. // request API data.
func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) { func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) {
@@ -120,9 +101,9 @@ func getManagedKeyId(data *framework.FieldData) (managedKeyId, error) {
return nil, err return nil, err
} }
var keyId managedKeyId = NameKey(name) var keyId managedKeyId = managed_key.NameKey(name)
if len(UUID) > 0 { if len(UUID) > 0 {
keyId = UUIDKey(UUID) keyId = managed_key.UUIDKey(UUID)
} }
return keyId, nil return keyId, nil
@@ -188,7 +169,7 @@ func getIssuerName(sc *storageContext, data *framework.FieldData) (string, error
return issuerName, errIssuerNameInUse return issuerName, errIssuerNameInUse
} }
if err != nil && issuerId != IssuerRefNotFound { if err != nil && issuerId != issuing.IssuerRefNotFound {
return issuerName, errutil.InternalError{Err: err.Error()} return issuerName, errutil.InternalError{Err: err.Error()}
} }
} }
@@ -213,14 +194,14 @@ func getKeyName(sc *storageContext, data *framework.FieldData) (string, error) {
return "", errKeyNameInUse return "", errKeyNameInUse
} }
if err != nil && keyId != KeyRefNotFound { if err != nil && keyId != issuing.KeyRefNotFound {
return "", errutil.InternalError{Err: err.Error()} return "", errutil.InternalError{Err: err.Error()}
} }
} }
return keyName, nil return keyName, nil
} }
func getIssuerRef(data *framework.FieldData) string { func GetIssuerRef(data *framework.FieldData) string {
return extractRef(data, issuerRefParam) return extractRef(data, issuerRefParam)
} }
@@ -286,7 +267,7 @@ const (
type IfModifiedSinceHelper struct { type IfModifiedSinceHelper struct {
req *logical.Request req *logical.Request
reqType ifModifiedReqType reqType ifModifiedReqType
issuerRef issuerID issuerRef issuing.IssuerID
} }
func sendNotModifiedResponseIfNecessary(helper *IfModifiedSinceHelper, sc *storageContext, resp *logical.Response) (bool, error) { func sendNotModifiedResponseIfNecessary(helper *IfModifiedSinceHelper, sc *storageContext, resp *logical.Response) (bool, error) {
@@ -326,7 +307,7 @@ func (sc *storageContext) isIfModifiedSinceBeforeLastModified(helper *IfModified
switch helper.reqType { switch helper.reqType {
case ifModifiedCRL, ifModifiedDeltaCRL: case ifModifiedCRL, ifModifiedDeltaCRL:
if sc.Backend.crlBuilder.invalidate.Load() { if sc.Backend.CrlBuilder().invalidate.Load() {
// When we see the CRL is invalidated, respond with false // When we see the CRL is invalidated, respond with false
// regardless of what the local CRL state says. We've likely // regardless of what the local CRL state says. We've likely
// renamed some issuers or are about to rebuild a new CRL.... // renamed some issuers or are about to rebuild a new CRL....
@@ -346,7 +327,7 @@ func (sc *storageContext) isIfModifiedSinceBeforeLastModified(helper *IfModified
lastModified = crlConfig.DeltaLastModified lastModified = crlConfig.DeltaLastModified
} }
case ifModifiedUnifiedCRL, ifModifiedUnifiedDeltaCRL: case ifModifiedUnifiedCRL, ifModifiedUnifiedDeltaCRL:
if sc.Backend.crlBuilder.invalidate.Load() { if sc.Backend.CrlBuilder().invalidate.Load() {
// When we see the CRL is invalidated, respond with false // When we see the CRL is invalidated, respond with false
// regardless of what the local CRL state says. We've likely // regardless of what the local CRL state says. We've likely
// renamed some issuers or are about to rebuild a new CRL.... // renamed some issuers or are about to rebuild a new CRL....

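The UUIDKey/NameKey types and the managedKeyInfo struct removed from this file now live in the new managed_key sub-package (see the managed_key.NameKey and managed_key.UUIDKey references at the call site above). Assuming the type definitions carried over unchanged, the relocated core is simply:

    package managed_key // destination assumed from the import added above

    type (
        UUIDKey string
        NameKey string
    )

    func (u UUIDKey) String() string { return string(u) }
    func (n NameKey) String() string { return string(n) }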

@@ -163,6 +163,21 @@ func GetPrivateKeyTypeFromSigner(signer crypto.Signer) PrivateKeyType {
return UnknownPrivateKey return UnknownPrivateKey
} }
// GetPrivateKeyTypeFromPublicKey based on the public key, return the PrivateKeyType
// that would be associated with it, returning UnknownPrivateKey for unsupported types
func GetPrivateKeyTypeFromPublicKey(pubKey crypto.PublicKey) PrivateKeyType {
switch pubKey.(type) {
case *rsa.PublicKey:
return RSAPrivateKey
case *ecdsa.PublicKey:
return ECPrivateKey
case *ed25519.PublicKey:
return Ed25519PrivateKey
default:
return UnknownPrivateKey
}
}
// ToPEMBundle converts a string-based certificate bundle // ToPEMBundle converts a string-based certificate bundle
// to a PEM-based string certificate bundle in trust path // to a PEM-based string certificate bundle in trust path
// order, leaf certificate first // order, leaf certificate first
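
GetPrivateKeyTypeFromPublicKey is a new exported helper in sdk/helper/certutil, complementing GetPrivateKeyTypeFromSigner for callers that only hold the public half of a key. A minimal usage sketch, assuming a module that builds against this version of the sdk:

    package main

    import (
        "crypto/ecdsa"
        "crypto/elliptic"
        "crypto/rand"
        "fmt"

        "github.com/hashicorp/vault/sdk/helper/certutil"
    )

    func main() {
        // Classify the public half of a freshly generated P-256 key.
        key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
        if err != nil {
            panic(err)
        }
        fmt.Println(certutil.GetPrivateKeyTypeFromPublicKey(&key.PublicKey)) // "ec"
    }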