From 194341e520f2c7f0f9049ab65fa42b8898101d0e Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 7 Feb 2024 00:54:29 +0100 Subject: [PATCH] Address review comments --- acme/api/order.go | 20 ++++++++++---------- acme/api/order_test.go | 4 ++-- acme/challenge.go | 21 ++++++++++----------- acme/order.go | 6 +++--- acme/wire/id.go | 8 ++++---- acme/wire/id_test.go | 30 +++++++++++++++--------------- authority/provisioner/acme.go | 10 +++++----- authority/provisioner/acme_test.go | 13 ++++++++++++- authority/provisioner/options.go | 12 ++++++++---- 9 files changed, 69 insertions(+), 55 deletions(-) diff --git a/acme/api/order.go b/acme/api/order.go index 9db94e90..14549e75 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -84,12 +84,12 @@ func (n *NewOrderRequest) validateWireIdentifiers() error { return fmt.Errorf("expected exactly one Wire DeviceID identifier, got %d", len(deviceIdentifiers)) } - wireUserID, err := wire.ParseUserID([]byte(userIdentifiers[0].Value)) + wireUserID, err := wire.ParseUserID(userIdentifiers[0].Value) if err != nil { return fmt.Errorf("failed parsing Wire UserID: %w", err) } - wireDeviceID, err := wire.ParseDeviceID([]byte(deviceIdentifiers[0].Value)) + wireDeviceID, err := wire.ParseDeviceID(deviceIdentifiers[0].Value) if err != nil { return fmt.Errorf("failed parsing Wire DeviceID: %w", err) } @@ -337,16 +337,16 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { var target string switch az.Identifier.Type { case acme.WireUser: - wireOptions := prov.GetOptions().GetWireOptions() - if wireOptions == nil { - return acme.NewErrorISE("failed getting Wire options") + wireOptions, err := prov.GetOptions().GetWireOptions() + if err != nil { + return acme.WrapErrorISE(err, "failed getting Wire options") } - target, err = wireOptions.GetOIDCOptions().EvaluateTarget("") + target, err = wireOptions.GetOIDCOptions().EvaluateTarget("") // TODO(hs): determine if required by Wire if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "invalid Go template registered for 'target'") } case acme.WireDevice: - wireID, err := wire.ParseDeviceID([]byte(az.Identifier.Value)) + wireID, err := wire.ParseDeviceID(az.Identifier.Value) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing WireDevice") } @@ -354,9 +354,9 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing ClientID") } - wireOptions := prov.GetOptions().GetWireOptions() - if wireOptions == nil { - return acme.NewErrorISE("failed getting Wire options") + wireOptions, err := prov.GetOptions().GetWireOptions() + if err != nil { + return acme.WrapErrorISE(err, "failed getting Wire options") } target, err = wireOptions.GetDPOPOptions().EvaluateTarget(clientID.DeviceID) if err != nil { diff --git a/acme/api/order_test.go b/acme/api/order_test.go index b57a2c75..9daa2f70 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -699,7 +699,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", - Err: errors.New("failed getting Wire options"), + Err: errors.New("failed getting Wire options: no Wire options available"), Detail: "The server experienced an internal error", Status: 500, }, @@ -765,7 +765,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", - Err: errors.New("failed getting Wire options"), + Err: errors.New("failed getting Wire options: no Wire options available"), Detail: "The server experienced an internal error", Status: 500, }, diff --git a/acme/challenge.go b/acme/challenge.go index b087e83c..465125e9 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -362,9 +362,9 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO if !ok { return NewErrorISE("missing provisioner") } - wireOptions := prov.GetOptions().GetWireOptions() - if wireOptions == nil { - return NewErrorISE("no Wire options available") + wireOptions, err := prov.GetOptions().GetWireOptions() + if err != nil { + return WrapErrorISE(err, "failed getting Wire options") } linker, ok := LinkerFromContext(ctx) if !ok { @@ -372,12 +372,11 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO } var oidcPayload wireOidcPayload - err := json.Unmarshal(payload, &oidcPayload) - if err != nil { + if err := json.Unmarshal(payload, &oidcPayload); err != nil { return WrapError(ErrorMalformedType, err, "error unmarshalling Wire OIDC challenge payload") } - wireID, err := wire.ParseUserID([]byte(ch.Value)) + wireID, err := wire.ParseUserID(ch.Value) if err != nil { return WrapErrorISE(err, "error unmarshalling challenge data") } @@ -493,9 +492,9 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j if !ok { return NewErrorISE("missing provisioner") } - wireOptions := prov.GetOptions().GetWireOptions() - if wireOptions == nil { - return NewErrorISE("no Wire options available") + wireOptions, err := prov.GetOptions().GetWireOptions() + if err != nil { + return WrapErrorISE(err, "failed getting Wire options") } linker, ok := LinkerFromContext(ctx) if !ok { @@ -507,7 +506,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j return WrapError(ErrorMalformedType, err, "error unmarshalling Wire DPoP challenge payload") } - wireID, err := wire.ParseDeviceID([]byte(ch.Value)) + wireID, err := wire.ParseDeviceID(ch.Value) if err != nil { return WrapErrorISE(err, "error unmarshalling challenge data") } @@ -728,7 +727,7 @@ func parseAndVerifyWireAccessToken(v wireVerifyParams) (*wireAccessToken, *wireD return nil, nil, fmt.Errorf("invalid display name in Wire DPoP token") } if name == "" || name != v.wireID.Name { - return nil, nil, fmt.Errorf("invalid Wire client display name %q", handle) + return nil, nil, fmt.Errorf("invalid Wire client display name %q", name) } return &accessToken, &dpopToken, nil diff --git a/acme/order.go b/acme/order.go index 974bac5f..1175bc38 100644 --- a/acme/order.go +++ b/acme/order.go @@ -340,7 +340,7 @@ func createWireSubject(o *Order, csr *x509.CertificateRequest) (subject x509util for _, identifier := range o.Identifiers { switch identifier.Type { case WireUser: - wireID, err := wire.ParseUserID([]byte(identifier.Value)) + wireID, err := wire.ParseUserID(identifier.Value) if err != nil { return subject, NewErrorISE("unmarshal wireID: %s", err) } @@ -406,7 +406,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ orderPIDs[indexPID] = n.Value indexPID++ case WireUser: - wireID, err := wire.ParseUserID([]byte(n.Value)) + wireID, err := wire.ParseUserID(n.Value) if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } @@ -417,7 +417,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ tmpOrderURIs[indexURI] = handle indexURI++ case WireDevice: - wireID, err := wire.ParseDeviceID([]byte(n.Value)) + wireID, err := wire.ParseDeviceID(n.Value) if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } diff --git a/acme/wire/id.go b/acme/wire/id.go index 6e6b2dde..d1abcedb 100644 --- a/acme/wire/id.go +++ b/acme/wire/id.go @@ -22,8 +22,8 @@ type DeviceID struct { Handle string `json:"handle,omitempty"` } -func ParseUserID(data []byte) (id UserID, err error) { - if err = json.Unmarshal(data, &id); err != nil { +func ParseUserID(value string) (id UserID, err error) { + if err = json.Unmarshal([]byte(value), &id); err != nil { return } @@ -39,8 +39,8 @@ func ParseUserID(data []byte) (id UserID, err error) { return } -func ParseDeviceID(data []byte) (id DeviceID, err error) { - if err = json.Unmarshal(data, &id); err != nil { +func ParseDeviceID(value string) (id DeviceID, err error) { + if err = json.Unmarshal([]byte(value), &id); err != nil { return } diff --git a/acme/wire/id_test.go b/acme/wire/id_test.go index 3913c722..3cf114b7 100644 --- a/acme/wire/id_test.go +++ b/acme/wire/id_test.go @@ -15,19 +15,19 @@ func TestParseUserID(t *testing.T) { emptyDomain := `{"name": "Alice Smith", "domain": "", "handle": "wireapp://%40alice_wire@wire.com"}` tests := []struct { name string - data []byte + value string wantWireID UserID wantErr bool }{ - {name: "ok", data: []byte(ok), wantWireID: UserID{Name: "Alice Smith", Domain: "wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, - {name: "fail/json", data: []byte(failJSON), wantErr: true}, - {name: "fail/empty-handle", data: []byte(emptyHandle), wantErr: true}, - {name: "fail/empty-name", data: []byte(emptyName), wantErr: true}, - {name: "fail/empty-domain", data: []byte(emptyDomain), wantErr: true}, + {name: "ok", value: ok, wantWireID: UserID{Name: "Alice Smith", Domain: "wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, + {name: "fail/json", value: failJSON, wantErr: true}, + {name: "fail/empty-handle", value: emptyHandle, wantErr: true}, + {name: "fail/empty-name", value: emptyName, wantErr: true}, + {name: "fail/empty-domain", value: emptyDomain, wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotWireID, err := ParseUserID(tt.data) + gotWireID, err := ParseUserID(tt.value) if tt.wantErr { assert.Error(t, err) return @@ -48,20 +48,20 @@ func TestParseDeviceID(t *testing.T) { emptyClientID := `{"name": "device", "domain": "wire.com", "client-id": "", "handle": "wireapp://%40alice_wire@wire.com"}` tests := []struct { name string - data []byte + value string wantWireID DeviceID wantErr bool }{ - {name: "ok", data: []byte(ok), wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, - {name: "fail/json", data: []byte(failJSON), wantErr: true}, - {name: "fail/empty-handle", data: []byte(emptyHandle), wantErr: true}, - {name: "fail/empty-name", data: []byte(emptyName), wantErr: true}, - {name: "fail/empty-domain", data: []byte(emptyDomain), wantErr: true}, - {name: "fail/empty-client-id", data: []byte(emptyClientID), wantErr: true}, + {name: "ok", value: ok, wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, + {name: "fail/json", value: failJSON, wantErr: true}, + {name: "fail/empty-handle", value: emptyHandle, wantErr: true}, + {name: "fail/empty-name", value: emptyName, wantErr: true}, + {name: "fail/empty-domain", value: emptyDomain, wantErr: true}, + {name: "fail/empty-client-id", value: emptyClientID, wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotWireID, err := ParseDeviceID(tt.data) + gotWireID, err := ParseDeviceID(tt.value) if tt.wantErr { assert.Error(t, err) return diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 35b31da6..3b7fa654 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -235,9 +235,9 @@ func (p *ACME) initializeWireOptions() error { return nil } - w := p.GetOptions().GetWireOptions() - if w == nil { - return errors.New("no Wire options available") + w, err := p.GetOptions().GetWireOptions() + if err != nil { + return fmt.Errorf("failed getting Wire options: %w", err) } if err := w.Validate(); err != nil { @@ -295,13 +295,13 @@ func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdenti err = x509Policy.IsDNSAllowed(identifier.Value) case WireUser: var wireID wire.UserID - if wireID, err = wire.ParseUserID([]byte(identifier.Value)); err != nil { + if wireID, err = wire.ParseUserID(identifier.Value); err != nil { return fmt.Errorf("failed parsing Wire SANs: %w", err) } err = x509Policy.AreSANsAllowed([]string{wireID.Handle}) case WireDevice: var wireID wire.DeviceID - if wireID, err = wire.ParseDeviceID([]byte(identifier.Value)); err != nil { + if wireID, err = wire.ParseDeviceID(identifier.Value); err != nil { return fmt.Errorf("failed parsing Wire SANs: %w", err) } err = x509Policy.AreSANsAllowed([]string{wireID.ClientID}) diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 8f5f8306..96f4bd8b 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -155,7 +155,18 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Type: "ACME", Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, }, - err: errors.New("failed initializing Wire options: no Wire options available"), + err: errors.New("failed initializing Wire options: failed getting Wire options: no options available"), + } + }, + "fail/wire-missing-wire-options": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{ + Name: "foo", + Type: "ACME", + Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, + Options: &Options{}, + }, + err: errors.New("failed initializing Wire options: failed getting Wire options: no Wire options available"), } }, "fail/wire-validate-options": func(t *testing.T) ProvisionerValidateTest { diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go index 1e0457c5..ec778081 100644 --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -53,12 +53,16 @@ func (o *Options) GetSSHOptions() *SSHOptions { return o.SSH } -// GetWireOptions returns the SSH options. -func (o *Options) GetWireOptions() *wire.Options { +// GetWireOptions returns the Wire options if available. It +// returns an error if they're not available. +func (o *Options) GetWireOptions() (*wire.Options, error) { if o == nil { - return nil + return nil, errors.New("no options available") } - return o.Wire + if o.Wire == nil { + return nil, errors.New("no Wire options available") + } + return o.Wire, nil } // GetWebhooks returns the webhooks options.