Add acme account storage (#19953)

* Enable creation of accounts

 - Refactors many methods to take an acmeContext, which holds the
   storageContext on it.
 - Updates the core ACME Handlers to use *acmeContext, to avoid
   copying structs.
 - Makes JWK exported so the JSON parser can find it.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Finish ACME account creation

 - This ensures a Kid is created when one doesn't exist
 - Expands the parsed handler capabilities, to format the response and
   set required headers.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

---------

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
This commit is contained in:
Alexander Scheel
2023-04-03 16:08:25 -04:00
committed by GitHub
parent de381c30f6
commit 754e2adc99
5 changed files with 198 additions and 59 deletions

View File

@@ -26,20 +26,20 @@ var AllowedOuterJWSTypes = map[string]interface{}{
type jwsCtx struct { type jwsCtx struct {
Algo string `json:"alg"` Algo string `json:"alg"`
Kid string `json:"kid"` Kid string `json:"kid"`
jwk json.RawMessage `json:"jwk"` Jwk json.RawMessage `json:"jwk"`
Nonce string `json:"nonce"` Nonce string `json:"nonce"`
Url string `json:"url"` Url string `json:"url"`
key jose.JSONWebKey `json:"-"` Key jose.JSONWebKey `json:"-"`
Existing bool `json:"-"` Existing bool `json:"-"`
} }
func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error { func (c *jwsCtx) UnmarshalJSON(a *acmeState, ac *acmeContext, jws []byte) error {
var err error var err error
if err = json.Unmarshal(jws, c); err != nil { if err = json.Unmarshal(jws, c); err != nil {
return err return err
} }
if c.Kid != "" && len(c.jwk) > 0 { if c.Kid != "" && len(c.Jwk) > 0 {
// See RFC 8555 Section 6.2. Request Authentication: // See RFC 8555 Section 6.2. Request Authentication:
// //
// > The "jwk" and "kid" fields are mutually exclusive. Servers MUST // > The "jwk" and "kid" fields are mutually exclusive. Servers MUST
@@ -47,7 +47,7 @@ func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error {
return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one: %w", ErrMalformed) return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one: %w", ErrMalformed)
} }
if c.Kid == "" && len(c.jwk) == 0 { if c.Kid == "" && len(c.Jwk) == 0 {
// See RFC 8555 Section 6.2. Request Authentication: // See RFC 8555 Section 6.2. Request Authentication:
// //
// > Either "jwk" (JSON Web Key) or "kid" (Key ID) as specified // > Either "jwk" (JSON Web Key) or "kid" (Key ID) as specified
@@ -70,24 +70,24 @@ func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error {
if c.Kid != "" { if c.Kid != "" {
// Load KID from storage first. // Load KID from storage first.
c.jwk, err = a.LoadJWK(c.Kid) c.Jwk, err = a.LoadJWK(ac, c.Kid)
if err != nil { if err != nil {
return err return err
} }
c.Existing = true c.Existing = true
} }
if err = c.key.UnmarshalJSON(c.jwk); err != nil { if err = c.Key.UnmarshalJSON(c.Jwk); err != nil {
return err return err
} }
if !c.key.Valid() { if !c.Key.Valid() {
return fmt.Errorf("received invalid jwk: %w", ErrMalformed) return fmt.Errorf("received invalid jwk: %w", ErrMalformed)
} }
if c.Kid != "" { if c.Kid == "" {
// Create a key ID // Create a key ID
kid, err := c.key.Thumbprint(crypto.SHA256) kid, err := c.Key.Thumbprint(crypto.SHA256)
if err != nil { if err != nil {
return fmt.Errorf("failed creating thumbprint: %w", err) return fmt.Errorf("failed creating thumbprint: %w", err)
} }
@@ -128,7 +128,7 @@ func (c *jwsCtx) VerifyJWS(signature string) (map[string]interface{}, error) {
return nil, fmt.Errorf("request had unprotected headers: %w", ErrMalformed) return nil, fmt.Errorf("request had unprotected headers: %w", ErrMalformed)
} }
payload, err := sig.Verify(c.key) payload, err := sig.Verify(c.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -5,15 +5,23 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
) )
// How long nonces are considered valid. const (
const nonceExpiry = 15 * time.Minute // How long nonces are considered valid.
nonceExpiry = 15 * time.Minute
// Path Prefixes
acmePathPrefix = "acme/"
acmeAccountPrefix = acmePathPrefix + "accounts/"
)
type acmeState struct { type acmeState struct {
nextExpiry *atomic.Int64 nextExpiry *atomic.Int64
@@ -99,36 +107,86 @@ func (a *acmeState) TidyNonces() {
a.nextExpiry.Store(nextRun.Unix()) a.nextExpiry.Store(nextRun.Unix())
} }
func (a *acmeState) CreateAccount(c *jwsCtx, contact []string, termsOfServiceAgreed bool) (map[string]interface{}, error) { type ACMEStates string
// TODO
return nil, nil const (
StatusValid = "valid"
StatusDeactivated = "deactivated"
StatusRevoked = "revoked"
)
type acmeAccount struct {
KeyId string `json:"-"`
Status ACMEStates `json:"state"`
Contact []string `json:"contact"`
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
Jwk []byte `json:"jwk"`
} }
func (a *acmeState) LoadAccount(keyID string) (map[string]interface{}, error) { func (a *acmeState) CreateAccount(ac *acmeContext, c *jwsCtx, contact []string, termsOfServiceAgreed bool) (*acmeAccount, error) {
// TODO acct := &acmeAccount{
return nil, nil KeyId: c.Kid,
Contact: contact,
TermsOfServiceAgreed: termsOfServiceAgreed,
Jwk: c.Jwk,
}
json, err := logical.StorageEntryJSON(acmeAccountPrefix+c.Kid, acct)
if err != nil {
return nil, fmt.Errorf("error creating account entry: %w", err)
}
if err := ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
return nil, fmt.Errorf("error writing account entry: %w", err)
}
return acct, nil
} }
func (a *acmeState) DoesAccountExist(keyId string) bool { func cleanKid(keyID string) string {
account, err := a.LoadAccount(keyId) pieces := strings.Split(keyID, "/")
return err == nil && len(account) > 0 return pieces[len(pieces)-1]
} }
func (a *acmeState) LoadJWK(keyID string) ([]byte, error) { func (a *acmeState) LoadAccount(ac *acmeContext, keyID string) (*acmeAccount, error) {
key, err := a.LoadAccount(keyID) kid := cleanKid(keyID)
entry, err := ac.sc.Storage.Get(ac.sc.Context, acmeAccountPrefix+kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
if entry == nil {
return nil, fmt.Errorf("account not found: %w", ErrMalformed)
}
var acct acmeAccount
err = entry.DecodeJSON(&acct)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
return &acct, nil
}
func (a *acmeState) DoesAccountExist(ac *acmeContext, keyId string) bool {
account, err := a.LoadAccount(ac, keyId)
return err == nil && account != nil
}
func (a *acmeState) LoadJWK(ac *acmeContext, keyID string) ([]byte, error) {
key, err := a.LoadAccount(ac, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
jwk, present := key["jwk"] if len(key.Jwk) == 0 {
if !present {
return nil, fmt.Errorf("malformed key entry lacks JWK") return nil, fmt.Errorf("malformed key entry lacks JWK")
} }
return jwk.([]byte), nil return key.Jwk, nil
} }
func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) { func (a *acmeState) ParseRequestParams(ac *acmeContext, data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) {
var c jwsCtx var c jwsCtx
var m map[string]interface{} var m map[string]interface{}
@@ -143,7 +201,7 @@ func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed) return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed)
} }
if err = c.UnmarshalJSON(a, jwkBytes); err != nil { if err = c.UnmarshalJSON(a, ac, jwkBytes); err != nil {
return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err) return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err)
} }

View File

@@ -55,7 +55,7 @@ func patternAcmeDirectory(b *backend, pattern string) *framework.Path {
} }
} }
type acmeOperation func(acmeCtx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) type acmeOperation func(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error)
type acmeContext struct { type acmeContext struct {
baseUrl *url.URL baseUrl *url.URL
@@ -76,7 +76,7 @@ func (b *backend) acmeWrapper(op acmeOperation) framework.OperationFunc {
return nil, err return nil, err
} }
acmeCtx := acmeContext{ acmeCtx := &acmeContext{
baseUrl: baseUrl, baseUrl: baseUrl,
sc: sc, sc: sc,
} }
@@ -120,7 +120,7 @@ func acmeErrorWrapper(op framework.OperationFunc) framework.OperationFunc {
} }
} }
func (b *backend) acmeDirectoryHandler(acmeCtx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) acmeDirectoryHandler(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
rawBody, err := json.Marshal(map[string]interface{}{ rawBody, err := json.Marshal(map[string]interface{}{
"newNonce": acmeCtx.baseUrl.JoinPath("new-nonce").String(), "newNonce": acmeCtx.baseUrl.JoinPath("new-nonce").String(),
"newAccount": acmeCtx.baseUrl.JoinPath("new-account").String(), "newAccount": acmeCtx.baseUrl.JoinPath("new-account").String(),

View File

@@ -1,7 +1,9 @@
package pki package pki
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
@@ -88,31 +90,102 @@ func patternAcmeNewAccount(b *backend, pattern string) *framework.Path {
} }
} }
type acmeParsedOperation func(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) type acmeParsedOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error)
func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationFunc { func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationFunc {
return b.acmeWrapper(func(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) { return b.acmeWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
user, data, err := b.acmeState.ParseRequestParams(fields) user, data, err := b.acmeState.ParseRequestParams(acmeCtx, fields)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return op(acmeCtx, r, fields, user, data) resp, err := op(acmeCtx, r, fields, user, data)
// Our response handlers might not add the necessary headers.
if resp != nil {
if resp.Headers == nil {
resp.Headers = map[string][]string{}
}
if _, ok := resp.Headers["Replay-Nonce"]; !ok {
nonce, _, err := b.acmeState.GetNonce()
if err != nil {
return nil, err
}
resp.Headers["Replay-Nonce"] = []string{nonce}
}
if _, ok := resp.Headers["Link"]; !ok {
resp.Headers["Link"] = genAcmeLinkHeader(acmeCtx)
} else {
directory := genAcmeLinkHeader(acmeCtx)[0]
addDirectory := true
for _, item := range resp.Headers["Link"] {
if item == directory {
addDirectory = false
break
}
}
if addDirectory {
resp.Headers["Link"] = append(resp.Headers["Link"], directory)
}
}
// ACME responses don't understand Vault's default encoding
// format. Rather than expecting everything to handle creating
// ACME-formatted responses, do the marshaling in one place.
if _, ok := resp.Data[logical.HTTPRawBody]; !ok {
ignored_values := map[string]bool{logical.HTTPContentType: true, logical.HTTPStatusCode: true}
fields := map[string]interface{}{}
body := map[string]interface{}{
logical.HTTPContentType: "application/json",
logical.HTTPStatusCode: http.StatusOK,
}
for key, value := range resp.Data {
if _, present := ignored_values[key]; !present {
fields[key] = value
} else {
body[key] = value
}
}
rawBody, err := json.Marshal(fields)
if err != nil {
return nil, fmt.Errorf("Error marshaling JSON body: %w", err)
}
body[logical.HTTPRawBody] = rawBody
resp.Data = body
}
}
return resp, err
}) })
} }
func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) { func (b *backend) acmeNewAccountHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
// Parameters // Parameters
var ok bool var ok bool
var onlyReturnExisting bool var onlyReturnExisting bool
var contact []string var contacts []string
var termsOfServiceAgreed bool var termsOfServiceAgreed bool
rawContact, present := data["contact"] rawContact, present := data["contact"]
if present { if present {
contact, ok = rawContact.([]string) listContact, ok := rawContact.([]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("invalid type for field 'contact': %w", ErrMalformed) return nil, fmt.Errorf("invalid type (%T) for field 'contact': %w", rawContact, ErrMalformed)
}
for index, singleContact := range listContact {
contact, ok := singleContact.(string)
if !ok {
return nil, fmt.Errorf("invalid type (%T) for field 'contact' item %d: %w", singleContact, index, ErrMalformed)
}
contacts = append(contacts, contact)
} }
} }
@@ -120,7 +193,7 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request,
if present { if present {
termsOfServiceAgreed, ok = rawTermsOfServiceAgreed.(bool) termsOfServiceAgreed, ok = rawTermsOfServiceAgreed.(bool)
if !ok { if !ok {
return nil, fmt.Errorf("invalid type for field 'termsOfServiceAgreed': %w", ErrMalformed) return nil, fmt.Errorf("invalid type (%T) for field 'termsOfServiceAgreed': %w", rawTermsOfServiceAgreed, ErrMalformed)
} }
} }
@@ -128,7 +201,7 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request,
if present { if present {
onlyReturnExisting, ok = rawOnlyReturnExisting.(bool) onlyReturnExisting, ok = rawOnlyReturnExisting.(bool)
if !ok { if !ok {
return nil, fmt.Errorf("invalid type for field 'onlyReturnExisting': %w", ErrMalformed) return nil, fmt.Errorf("invalid type (%T) for field 'onlyReturnExisting': %w", rawOnlyReturnExisting, ErrMalformed)
} }
} }
@@ -139,38 +212,39 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request,
return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data) return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data)
} }
return b.acmeNewAccountCreateHandler(acmeCtx, r, fields, userCtx, data, contact, termsOfServiceAgreed) return b.acmeNewAccountCreateHandler(acmeCtx, r, fields, userCtx, data, contacts, termsOfServiceAgreed)
} }
func formatAccountResponse(location string, status string, contact []string) *logical.Response { func formatAccountResponse(location string, acct *acmeAccount) *logical.Response {
resp := &logical.Response{ resp := &logical.Response{
Data: map[string]interface{}{ Data: map[string]interface{}{
"status": status, "status": acct.Status,
"orders": location + "/orders", "orders": location + "/orders",
}, },
Headers: map[string][]string{
"Location": {location},
},
} }
if len(contact) > 0 { if len(acct.Contact) > 0 {
resp.Data["contact"] = contact resp.Data["contact"] = acct.Contact
} }
resp.Headers["Location"] = []string{location}
return resp return resp
} }
func (b *backend) acmeNewAccountSearchHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) { func (b *backend) acmeNewAccountSearchHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
if userCtx.Existing || b.acmeState.DoesAccountExist(userCtx.Kid) { if userCtx.Existing || b.acmeState.DoesAccountExist(acmeCtx, userCtx.Kid) {
// This account exists; return its details. It would be slightly // This account exists; return its details. It would be slightly
// weird to specify a kid in the request (and not use an explicit // weird to specify a kid in the request (and not use an explicit
// jwk here), but we might as well support it too. // jwk here), but we might as well support it too.
account, err := b.acmeState.LoadAccount(userCtx.Kid) account, err := b.acmeState.LoadAccount(acmeCtx, userCtx.Kid)
if err != nil { if err != nil {
return nil, fmt.Errorf("error loading account: %w", err) return nil, fmt.Errorf("error loading account: %w", err)
} }
location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid
return formatAccountResponse(location, account["status"].(string), account["contact"].([]string)), nil return formatAccountResponse(location, account), nil
} }
// Per RFC 8555 Section 7.3.1. Finding an Account URL Given a Key: // Per RFC 8555 Section 7.3.1. Finding an Account URL Given a Key:
@@ -181,13 +255,13 @@ func (b *backend) acmeNewAccountSearchHandler(acmeCtx acmeContext, r *logical.Re
return nil, fmt.Errorf("An account with this key does not exist: %w", ErrAccountDoesNotExist) return nil, fmt.Errorf("An account with this key does not exist: %w", ErrAccountDoesNotExist)
} }
func (b *backend) acmeNewAccountCreateHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, contact []string, termsOfServiceAgreed bool) (*logical.Response, error) { func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, contact []string, termsOfServiceAgreed bool) (*logical.Response, error) {
if userCtx.Existing { if userCtx.Existing {
return nil, fmt.Errorf("cannot submit to newAccount with 'kid': %w", ErrMalformed) return nil, fmt.Errorf("cannot submit to newAccount with 'kid': %w", ErrMalformed)
} }
// If the account already exists, return the existing one. // If the account already exists, return the existing one.
if b.acmeState.DoesAccountExist(userCtx.Kid) { if b.acmeState.DoesAccountExist(acmeCtx, userCtx.Kid) {
return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data) return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data)
} }
@@ -196,11 +270,18 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx acmeContext, r *logical.Re
return nil, fmt.Errorf("terms of service not agreed to: %w", ErrUserActionRequired) return nil, fmt.Errorf("terms of service not agreed to: %w", ErrUserActionRequired)
} }
account, err := b.acmeState.CreateAccount(userCtx, contact, termsOfServiceAgreed) account, err := b.acmeState.CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create account: %w", err) return nil, fmt.Errorf("failed to create account: %w", err)
} }
location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid
return formatAccountResponse(location, account["status"].(string), account["contact"].([]string)), nil resp := formatAccountResponse(location, account)
// Per RFC 8555 Section 7.3. Account Management:
//
// > The server returns this account object in a 201 (Created) response,
// > with the account URL in a Location header field.
resp.Data[logical.HTTPStatusCode] = http.StatusCreated
return resp, nil
} }

View File

@@ -51,7 +51,7 @@ func patternAcmeNonce(b *backend, pattern string) *framework.Path {
} }
} }
func (b *backend) acmeNonceHandler(ctx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { func (b *backend) acmeNonceHandler(ctx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
nonce, _, err := b.acmeState.GetNonce() nonce, _, err := b.acmeState.GetNonce()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -78,7 +78,7 @@ func (b *backend) acmeNonceHandler(ctx acmeContext, r *logical.Request, _ *frame
}, nil }, nil
} }
func genAcmeLinkHeader(ctx acmeContext) []string { func genAcmeLinkHeader(ctx *acmeContext) []string {
path := fmt.Sprintf("<%s>;rel=\"index\"", ctx.baseUrl.JoinPath("directory").String()) path := fmt.Sprintf("<%s>;rel=\"index\"", ctx.baseUrl.JoinPath("directory").String())
return []string{path} return []string{path}
} }