PKI: Refactor common role path policy code into common area (#27759)

This commit is contained in:
Steven Clark
2024-07-11 13:22:33 -04:00
committed by GitHub
parent f102434c4c
commit 2d7a3fba99
10 changed files with 237 additions and 29 deletions

View File

@@ -17,13 +17,13 @@ import (
)
type acmeContext struct {
issuing.IssuerRoleContext
// baseUrl is the combination of the configured cluster local URL and the acmePath up to /acme/
baseUrl *url.URL
clusterUrl *url.URL
sc *storageContext
acmeState *acmeState
role *issuing.RoleEntry
issuer *issuing.IssuerEntry
// acmeDirectory is a string that can distinguish the various acme directories we have configured
// if something needs to remain locked into a directory path structure.
acmeDirectory string
@@ -161,16 +161,15 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
}
acmeCtx := &acmeContext{
baseUrl: acmeBaseUrl,
clusterUrl: clusterBase,
sc: sc,
acmeState: b.acmeState,
role: role,
issuer: issuer,
acmeDirectory: acmeDirectory,
eabPolicy: eabPolicy,
ciepsPolicy: ciepsPolicy,
runtimeOpts: runtimeOpts,
IssuerRoleContext: issuing.NewIssuerRoleContext(ctx, issuer, role),
baseUrl: acmeBaseUrl,
clusterUrl: clusterBase,
sc: sc,
acmeState: b.acmeState,
acmeDirectory: acmeDirectory,
eabPolicy: eabPolicy,
ciepsPolicy: ciepsPolicy,
runtimeOpts: runtimeOpts,
}
return op(acmeCtx, r, data)

View File

@@ -82,12 +82,12 @@ func TestACMEIssuerRoleLoading(t *testing.T) {
for _, tt := range tc {
t.Run(tt.name, func(t *testing.T) {
f := b.acmeWrapper(acmeWrapperOpts{}, func(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
if tt.roleName != acmeCtx.role.Name {
return nil, fmt.Errorf("expected role %s but got %s", tt.roleName, acmeCtx.role.Name)
if tt.roleName != acmeCtx.Role.Name {
return nil, fmt.Errorf("expected role %s but got %s", tt.roleName, acmeCtx.Role.Name)
}
if tt.expectedIssuerName != acmeCtx.issuer.Name {
return nil, fmt.Errorf("expected issuer %s but got %s", tt.expectedIssuerName, acmeCtx.issuer.Name)
if tt.expectedIssuerName != acmeCtx.Issuer.Name {
return nil, fmt.Errorf("expected issuer %s but got %s", tt.expectedIssuerName, acmeCtx.Issuer.Name)
}
return nil, nil

View File

@@ -0,0 +1,22 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package issuing
import "context"
// IssuerRoleContext combines in a single struct an issuer and a role that we should
// leverage to issue a certificate along with the
type IssuerRoleContext struct {
context.Context
Role *RoleEntry
Issuer *IssuerEntry
}
func NewIssuerRoleContext(ctx context.Context, issuer *IssuerEntry, role *RoleEntry) IssuerRoleContext {
return IssuerRoleContext{
Context: ctx,
Role: role,
Issuer: issuer,
}
}

View File

@@ -17,7 +17,6 @@ import (
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/builtin/logical/pki/parsing"
"github.com/hashicorp/vault/builtin/logical/pki/pki_backend"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
@@ -64,6 +63,12 @@ type EntityInfo struct {
EntityID string
}
type CertificateCounter interface {
IsInitialized() bool
IncrementTotalCertificatesCount(certsCounted bool, newSerial string)
IncrementTotalRevokedCertificatesCount(certsCounted bool, newSerial string)
}
func NewEntityInfoFromReq(req *logical.Request) EntityInfo {
if req == nil {
return EntityInfo{}
@@ -1012,7 +1017,7 @@ func ApplyIssuerLeafNotAfterBehavior(caSign *certutil.CAInfoBundle, notAfter tim
}
// StoreCertificate given a certificate bundle that was signed, persist the certificate to storage
func StoreCertificate(ctx context.Context, s logical.Storage, certCounter pki_backend.CertificateCounter, certBundle *certutil.ParsedCertBundle) error {
func StoreCertificate(ctx context.Context, s logical.Storage, certCounter CertificateCounter, certBundle *certutil.ParsedCertBundle) error {
hyphenSerialNumber := parsing.NormalizeSerialForStorageFromBigInt(certBundle.Certificate.SerialNumber)
key := PathCerts + hyphenSerialNumber
certsCounted := certCounter.IsInitialized()

View File

@@ -245,6 +245,29 @@ func (i IssuerEntry) CanMaybeSignWithAlgo(algo x509.SignatureAlgorithm) error {
return fmt.Errorf("unable to use issuer of type %v to sign with %v key type", cert.PublicKeyAlgorithm.String(), algo.String())
}
// ResolveAndFetchIssuerForIssuance takes a name or uuid referencing an issuer, loads the issuer
// and validates that we have the associated private key and is allowed to perform issuance operations.
func ResolveAndFetchIssuerForIssuance(ctx context.Context, s logical.Storage, issuerName string) (*IssuerEntry, error) {
if len(issuerName) == 0 {
return nil, fmt.Errorf("unable to fetch pki issuer: empty issuer name")
}
issuerId, err := ResolveIssuerReference(ctx, s, issuerName)
if err != nil {
return nil, fmt.Errorf("failed to resolve issuer %s: %w", issuerName, err)
}
issuer, err := FetchIssuerById(ctx, s, issuerId)
if err != nil {
return nil, fmt.Errorf("failed to load issuer %s: %w", issuerName, err)
}
if issuer.Usage.HasUsage(IssuanceUsage) && len(issuer.KeyID) > 0 {
return issuer, nil
}
return nil, fmt.Errorf("issuer %s missing proper issuance usage or doesn't have associated key", issuerName)
}
func ResolveIssuerReference(ctx context.Context, s logical.Storage, reference string) (IssuerID, error) {
if reference == DefaultRef {
// Handle fetching the default issuer.

View File

@@ -476,7 +476,7 @@ func removeDuplicatesAndSortIps(ipIdentifiers []net.IP) []net.IP {
func maybeAugmentReqDataWithSuitableCN(ac *acmeContext, csr *x509.CertificateRequest, data *framework.FieldData) {
// Role doesn't require a CN, so we don't care.
if !ac.role.RequireCN {
if !ac.Role.RequireCN {
return
}
@@ -522,9 +522,9 @@ func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.
// (TLS) clients are mostly verifying against server's DNS SANs.
maybeAugmentReqDataWithSuitableCN(ac, csr, data)
signingBundle, issuerId, err := ac.sc.fetchCAInfoWithIssuer(ac.issuer.ID.String(), issuing.IssuanceUsage)
signingBundle, issuerId, err := ac.sc.fetchCAInfoWithIssuer(ac.Issuer.ID.String(), issuing.IssuanceUsage)
if err != nil {
return nil, "", fmt.Errorf("failed loading CA %s: %w", ac.issuer.ID.String(), err)
return nil, "", fmt.Errorf("failed loading CA %s: %w", ac.Issuer.ID.String(), err)
}
// ACME issued cert will override the TTL values to truncate to the issuer's
@@ -536,7 +536,7 @@ func issueCertFromCsr(ac *acmeContext, csr *x509.CertificateRequest) (*certutil.
input := &inputBundle{
req: &logical.Request{},
apiData: data,
role: ac.role,
role: ac.Role,
}
normalNotAfter, _, err := getCertificateNotAfter(ac.sc.System(), input, signingBundle)
@@ -730,7 +730,7 @@ func (b *backend) acmeNewOrderHandler(ac *acmeContext, _ *logical.Request, _ *fr
return nil, err
}
err = b.validateIdentifiersAgainstRole(ac.role, identifiers)
err = b.validateIdentifiersAgainstRole(ac.Role, identifiers)
if err != nil {
return nil, err
}

View File

@@ -4,7 +4,13 @@
package pki_backend
import (
"context"
"fmt"
"strings"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -20,8 +26,80 @@ type Logger interface {
Logger() log.Logger
}
type CertificateCounter interface {
IsInitialized() bool
IncrementTotalCertificatesCount(certsCounted bool, newSerial string)
IncrementTotalRevokedCertificatesCount(certsCounted bool, newSerial string)
//go:generate enumer -type=RolePathPolicy -text -json -transform=kebab-case
type RolePathPolicy int
const (
RPPUnknown RolePathPolicy = iota
RPPSignVerbatim
RPPRole
)
var (
pathPolicyRolePrefix = "role:"
pathPolicyRolePrefixLength = len(pathPolicyRolePrefix)
)
// GetRoleByPathOrPathPolicy loads an existing role based on if the data field data contains a 'role' parameter
// or by the values within the pathPolicy
func GetRoleByPathOrPathPolicy(ctx context.Context, s logical.Storage, data *framework.FieldData, pathPolicy string) (*issuing.RoleEntry, error) {
var role *issuing.RoleEntry
// The role name from the path is the highest priority
if roleName, ok := getRoleNameFromPath(data); ok {
var err error
role, err = issuing.GetRole(ctx, s, roleName)
if err != nil {
return nil, err
}
} else {
policyType, policyVal, err := GetPathPolicyType(pathPolicy)
if err != nil {
return nil, err
}
switch policyType {
case RPPRole:
role, err = issuing.GetRole(ctx, s, policyVal)
if err != nil {
return nil, err
}
case RPPSignVerbatim:
role = issuing.SignVerbatimRole()
default:
return nil, fmt.Errorf("unsupported policy type returned: %s from policy path: %s", policyType, pathPolicy)
}
}
return role, nil
}
func GetPathPolicyType(pathPolicy string) (RolePathPolicy, string, error) {
policy := strings.TrimSpace(pathPolicy)
switch {
case policy == "sign-verbatim":
return RPPSignVerbatim, "", nil
case strings.HasPrefix(policy, pathPolicyRolePrefix):
if policy == pathPolicyRolePrefix {
return RPPUnknown, "", fmt.Errorf("no role specified by policy %v", pathPolicy)
}
roleName := pathPolicy[pathPolicyRolePrefixLength:]
return RPPRole, roleName, nil
default:
return RPPUnknown, "", fmt.Errorf("string %v was not a valid default path policy", pathPolicy)
}
}
func getRoleNameFromPath(data *framework.FieldData) (string, bool) {
// If our schema doesn't include the parameter bail
if _, ok := data.Schema["role"]; !ok {
return "", false
}
if roleName, ok := data.GetOk("role"); ok {
return roleName.(string), true
}
return "", false
}

View File

@@ -0,0 +1,80 @@
// Code generated by "enumer -type=RolePathPolicy -text -json -transform=kebab-case"; DO NOT EDIT.
package pki_backend
import (
"encoding/json"
"fmt"
)
const _RolePathPolicyName = "RPPUnknownRPPSignVerbatimRPPRole"
var _RolePathPolicyIndex = [...]uint8{0, 10, 25, 32}
func (i RolePathPolicy) String() string {
if i < 0 || i >= RolePathPolicy(len(_RolePathPolicyIndex)-1) {
return fmt.Sprintf("RolePathPolicy(%d)", i)
}
return _RolePathPolicyName[_RolePathPolicyIndex[i]:_RolePathPolicyIndex[i+1]]
}
var _RolePathPolicyValues = []RolePathPolicy{0, 1, 2}
var _RolePathPolicyNameToValueMap = map[string]RolePathPolicy{
_RolePathPolicyName[0:10]: 0,
_RolePathPolicyName[10:25]: 1,
_RolePathPolicyName[25:32]: 2,
}
// RolePathPolicyString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func RolePathPolicyString(s string) (RolePathPolicy, error) {
if val, ok := _RolePathPolicyNameToValueMap[s]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to RolePathPolicy values", s)
}
// RolePathPolicyValues returns all values of the enum
func RolePathPolicyValues() []RolePathPolicy {
return _RolePathPolicyValues
}
// IsARolePathPolicy returns "true" if the value is listed in the enum definition. "false" otherwise
func (i RolePathPolicy) IsARolePathPolicy() bool {
for _, v := range _RolePathPolicyValues {
if i == v {
return true
}
}
return false
}
// MarshalJSON implements the json.Marshaler interface for RolePathPolicy
func (i RolePathPolicy) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
// UnmarshalJSON implements the json.Unmarshaler interface for RolePathPolicy
func (i *RolePathPolicy) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("RolePathPolicy should be a string, got %s", data)
}
var err error
*i, err = RolePathPolicyString(s)
return err
}
// MarshalText implements the encoding.TextMarshaler interface for RolePathPolicy
func (i RolePathPolicy) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for RolePathPolicy
func (i *RolePathPolicy) UnmarshalText(text []byte) error {
var err error
*i, err = RolePathPolicyString(string(text))
return err
}

View File

@@ -7,6 +7,7 @@ import (
"context"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/builtin/logical/pki/issuing"
"github.com/hashicorp/vault/builtin/logical/pki/managed_key"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -18,7 +19,7 @@ type StorageContext interface {
UseLegacyBundleCaStorage() bool
GetPkiManagedView() managed_key.PkiManagedKeyView
CrlBuilder() CrlBuilderType
GetCertificateCounter() CertificateCounter
GetCertificateCounter() issuing.CertificateCounter
Logger() hclog.Logger
}

View File

@@ -111,7 +111,7 @@ func (sc *storageContext) GetPkiManagedView() managed_key.PkiManagedKeyView {
return sc.Backend
}
func (sc *storageContext) GetCertificateCounter() pki_backend.CertificateCounter {
func (sc *storageContext) GetCertificateCounter() issuing.CertificateCounter {
return sc.Backend.GetCertificateCounter()
}