Change Wire DB operations into using a runtime type assertion

This commit is contained in:
Herman Slatman
2024-08-13 11:11:08 +02:00
parent 92e95e4df3
commit bb512e76c3
5 changed files with 271 additions and 181 deletions

View File

@@ -393,6 +393,10 @@ type wireOidcPayload struct {
}
func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
wireDB, ok := db.(WireDB)
if !ok {
return NewErrorISE("db %T is not a WireDB", db)
}
prov, ok := ProvisionerFromContext(ctx)
if !ok {
return NewErrorISE("missing provisioner")
@@ -472,7 +476,7 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
return WrapErrorISE(err, "error updating challenge")
}
orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID)
orders, err := wireDB.GetAllOrdersByAccountID(ctx, ch.AccountID)
if err != nil {
return WrapErrorISE(err, "could not retrieve current order by account id")
}
@@ -481,7 +485,7 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
}
order := orders[len(orders)-1]
if err := db.CreateOidcToken(ctx, order, transformedIDToken); err != nil {
if err := wireDB.CreateOidcToken(ctx, order, transformedIDToken); err != nil {
return WrapErrorISE(err, "failed storing OIDC id token")
}
@@ -523,6 +527,10 @@ type wireDpopPayload struct {
}
func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *jose.JSONWebKey, payload []byte) error {
wireDB, ok := db.(WireDB)
if !ok {
return NewErrorISE("db %T is not a WireDB", db)
}
prov, ok := ProvisionerFromContext(ctx)
if !ok {
return NewErrorISE("missing provisioner")
@@ -586,7 +594,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
return WrapErrorISE(err, "error updating challenge")
}
orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID)
orders, err := wireDB.GetAllOrdersByAccountID(ctx, ch.AccountID)
if err != nil {
return WrapErrorISE(err, "could not find current order by account id")
}
@@ -595,7 +603,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
}
order := orders[len(orders)-1]
if err := db.CreateDpopToken(ctx, order, map[string]any(*dpop)); err != nil {
if err := wireDB.CreateDpopToken(ctx, order, map[string]any(*dpop)); err != nil {
return WrapErrorISE(err, "failed storing DPoP token")
}

View File

@@ -962,14 +962,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -1111,14 +1113,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)

View File

@@ -35,9 +35,22 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
expectedErr *Error
}
tests := map[string]func(t *testing.T) test{
"fail/no-wire-db": func(t *testing.T) test {
return test{
ctx: context.Background(),
db: &MockDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
Status: 500,
Err: errors.New("db *acme.MockDB is not a WireDB"),
},
}
},
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -68,6 +81,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
}))
return test{
ctx: ctx,
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -109,6 +123,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Status: StatusPending,
Value: "1234",
},
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:malformed",
Detail: "The request message was malformed",
@@ -150,6 +165,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Status: StatusPending,
Value: "1234",
},
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -203,6 +219,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Status: StatusPending,
Value: string(valueBytes),
},
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -259,25 +276,27 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Status: StatusPending,
Value: string(valueBytes),
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, ch *Challenge) error {
assert.Equal(t, "chID", ch.ID)
assert.Equal(t, "azID", ch.AuthorizationID)
assert.Equal(t, "accID", ch.AccountID)
assert.Equal(t, "token", ch.Token)
assert.Equal(t, ChallengeType("wire-dpop-01"), ch.Type)
assert.Equal(t, StatusInvalid, ch.Status)
assert.Equal(t, string(valueBytes), ch.Value)
if assert.NotNil(t, ch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(ch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Equal(t, `failed validating Wire access token: failed parsing token: go-jose/go-jose: compact JWS format must have three parts`, k.Err.Error())
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, ch *Challenge) error {
assert.Equal(t, "chID", ch.ID)
assert.Equal(t, "azID", ch.AuthorizationID)
assert.Equal(t, "accID", ch.AccountID)
assert.Equal(t, "token", ch.Token)
assert.Equal(t, ChallengeType("wire-dpop-01"), ch.Type)
assert.Equal(t, StatusInvalid, ch.Status)
assert.Equal(t, string(valueBytes), ch.Value)
if assert.NotNil(t, ch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(ch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Equal(t, `failed validating Wire access token: failed parsing token: go-jose/go-jose: compact JWS format must have three parts`, k.Err.Error())
}
}
}
return nil
return nil
},
},
},
}
@@ -410,14 +429,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return errors.New("fail")
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return errors.New("fail")
},
},
},
expectedErr: &Error{
@@ -556,14 +577,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -706,14 +729,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -856,14 +881,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -1012,14 +1039,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -1072,9 +1101,22 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
expectedErr *Error
}
tests := map[string]func(t *testing.T) test{
"fail/no-wire-db": func(t *testing.T) test {
return test{
ctx: context.Background(),
db: &MockDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
Status: 500,
Err: errors.New("db *acme.MockDB is not a WireDB"),
},
}
},
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -1105,6 +1147,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
}))
return test{
ctx: ctx,
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -1146,10 +1189,12 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Status: StatusPending,
Value: "1234",
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, ch *Challenge) error {
assert.Equal(t, "chID", ch.ID)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, ch *Challenge) error {
assert.Equal(t, "chID", ch.ID)
return nil
},
},
},
expectedErr: &Error{
@@ -1193,6 +1238,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Status: StatusPending,
Value: "1234",
},
db: &MockWireDB{},
expectedErr: &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Detail: "The server experienced an internal error",
@@ -1288,23 +1334,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
if assert.NotNil(t, updch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(updch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Equal(t, `error verifying ID token signature: failed to verify signature: failed to verify id token signature`, k.Err.Error())
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
if assert.NotNil(t, updch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(updch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Equal(t, `error verifying ID token signature: failed to verify signature: failed to verify id token signature`, k.Err.Error())
}
}
}
return nil
return nil
},
},
},
}
@@ -1394,23 +1442,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
if assert.NotNil(t, updch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(updch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Contains(t, k.Err.Error(), "keyAuthorization does not match")
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
if assert.NotNil(t, updch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(updch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Contains(t, k.Err.Error(), "keyAuthorization does not match")
}
}
}
return nil
return nil
},
},
},
}
@@ -1500,23 +1550,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
if assert.NotNil(t, updch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(updch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Equal(t, `claims in OIDC ID token don't match: invalid 'preferred_username' "wireapp://%40bob@wire.com" after transformation`, k.Err.Error())
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
if assert.NotNil(t, updch.Error) {
var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
if errors.As(updch.Error, &k) {
assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
assert.Equal(t, 400, k.Status)
assert.Equal(t, `claims in OIDC ID token don't match: invalid 'preferred_username' "wireapp://%40bob@wire.com" after transformation`, k.Err.Error())
}
}
}
return nil
return nil
},
},
},
}
@@ -1606,14 +1658,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return errors.New("fail")
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return errors.New("fail")
},
},
},
expectedErr: &Error{
@@ -1709,14 +1763,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -1816,14 +1872,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -1923,14 +1981,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)
@@ -2036,14 +2096,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
payload: payload,
ctx: ctx,
jwk: jwk,
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
db: &MockWireDB{
MockDB: MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
assert.Equal(t, string(valueBytes), updch.Value)
return nil
},
},
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
assert.Equal(t, "accID", accountID)

View File

@@ -54,8 +54,14 @@ type DB interface {
GetOrder(ctx context.Context, id string) (*Order, error)
GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
UpdateOrder(ctx context.Context, o *Order) error
}
// TODO(hs): put in a different interface
// WireDB is the interface used for operations on ACME Orders for Wire identifiers. This
// is not a general purpose interface, and it should only be used when Wire identifiers
// are enabled in the CA configuration. Currently it provides a runtime assertion only;
// not at compile time.
type WireDB interface {
DB
GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
CreateDpopToken(ctx context.Context, orderID string, dpop map[string]interface{}) error
GetDpopToken(ctx context.Context, orderID string) (map[string]interface{}, error)
@@ -126,14 +132,20 @@ type MockDB struct {
MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error)
MockUpdateOrder func(ctx context.Context, o *Order) error
MockRet1 interface{}
MockError error
}
// MockWireDB is an implementation of the WireDB interface that should only be used as
// a mock in tests. It embeds the MockDB, as it is an extension of the existing database
// methods.
type MockWireDB struct {
MockDB
MockGetAllOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error)
MockGetDpopToken func(ctx context.Context, orderID string) (map[string]interface{}, error)
MockCreateDpopToken func(ctx context.Context, orderID string, dpop map[string]interface{}) error
MockGetOidcToken func(ctx context.Context, orderID string) (map[string]interface{}, error)
MockCreateOidcToken func(ctx context.Context, orderID string, idToken map[string]interface{}) error
MockRet1 interface{}
MockError error
}
// CreateAccount mock.
@@ -407,7 +419,7 @@ func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]stri
}
// GetAllOrdersByAccountID returns a list of any order IDs owned by the account.
func (m *MockDB) GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) {
func (m *MockWireDB) GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) {
if m.MockGetAllOrdersByAccountID != nil {
return m.MockGetAllOrdersByAccountID(ctx, accountID)
} else if m.MockError != nil {
@@ -417,7 +429,7 @@ func (m *MockDB) GetAllOrdersByAccountID(ctx context.Context, accountID string)
}
// GetDpop retrieves a DPoP from the database.
func (m *MockDB) GetDpopToken(ctx context.Context, orderID string) (map[string]any, error) {
func (m *MockWireDB) GetDpopToken(ctx context.Context, orderID string) (map[string]any, error) {
if m.MockGetDpopToken != nil {
return m.MockGetDpopToken(ctx, orderID)
} else if m.MockError != nil {
@@ -427,7 +439,7 @@ func (m *MockDB) GetDpopToken(ctx context.Context, orderID string) (map[string]a
}
// CreateDpop creates DPoP resources and saves them to the DB.
func (m *MockDB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error {
func (m *MockWireDB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error {
if m.MockCreateDpopToken != nil {
return m.MockCreateDpopToken(ctx, orderID, dpop)
}
@@ -435,7 +447,7 @@ func (m *MockDB) CreateDpopToken(ctx context.Context, orderID string, dpop map[s
}
// GetOidcToken retrieves an oidc token from the database.
func (m *MockDB) GetOidcToken(ctx context.Context, orderID string) (map[string]any, error) {
func (m *MockWireDB) GetOidcToken(ctx context.Context, orderID string) (map[string]any, error) {
if m.MockGetOidcToken != nil {
return m.MockGetOidcToken(ctx, orderID)
} else if m.MockError != nil {
@@ -445,7 +457,7 @@ func (m *MockDB) GetOidcToken(ctx context.Context, orderID string) (map[string]a
}
// CreateOidcToken creates oidc token resources and saves them to the DB.
func (m *MockDB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error {
func (m *MockWireDB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error {
if m.MockCreateOidcToken != nil {
return m.MockCreateOidcToken(ctx, orderID, idToken)
}

View File

@@ -208,6 +208,10 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
// Template data
data := x509util.NewTemplateData()
if o.containsWireIdentifiers() {
wireDB, ok := db.(WireDB)
if !ok {
return fmt.Errorf("db %T is not a WireDB", db)
}
subject, err := createWireSubject(o, csr)
if err != nil {
return fmt.Errorf("failed creating Wire subject: %w", err)
@@ -215,13 +219,13 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
data.SetSubject(subject)
// Inject Wire's custom challenges into the template once they have been validated
dpop, err := db.GetDpopToken(ctx, o.ID)
dpop, err := wireDB.GetDpopToken(ctx, o.ID)
if err != nil {
return fmt.Errorf("failed getting Wire DPoP token: %w", err)
}
data.Set("Dpop", dpop)
oidc, err := db.GetOidcToken(ctx, o.ID)
oidc, err := wireDB.GetOidcToken(ctx, o.ID)
if err != nil {
return fmt.Errorf("failed getting Wire OIDC token: %w", err)
}