From 776a839a42eec6f77cb5e9a1632f9785e72b91c6 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Jan 2024 21:31:10 +0100 Subject: [PATCH] Fix linter issues and improve error handling --- acme/api/im_integration_test.go | 20 ++++---- acme/api/order.go | 8 ++-- acme/challenge.go | 18 ++++--- acme/db.go | 69 +++++++++++++++++++-------- acme/db/nosql/order.go | 4 +- acme/db/nosql/wire.go | 49 +++++++++---------- acme/order.go | 5 +- authority/provisioner/acme.go | 9 ++-- authority/provisioner/dpop_options.go | 14 ++++-- authority/provisioner/oidc_options.go | 22 +++++---- wire/id.go | 21 ++++---- 11 files changed, 139 insertions(+), 100 deletions(-) diff --git a/acme/api/im_integration_test.go b/acme/api/im_integration_test.go index 24ad2da3..83ed4dad 100644 --- a/acme/api/im_integration_test.go +++ b/acme/api/im_integration_test.go @@ -37,7 +37,6 @@ const ( ) func TestIMIntegration(t *testing.T) { - ctx := context.Background() prov := newACMEProvWithOptions(t, &provisioner.Options{ OIDC: &provisioner.OIDCOptions{ @@ -65,6 +64,7 @@ func TestIMIntegration(t *testing.T) { }) // mock provisioner and linker + ctx := context.Background() ctx = acme.NewProvisionerContext(ctx, prov) ctx = acme.NewLinkerContext(ctx, acme.NewLinker(baseURL, linkerPrefix)) @@ -113,7 +113,7 @@ func TestIMIntegration(t *testing.T) { // get directory dir := func(ctx context.Context) (dir Directory) { - req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil) + req := httptest.NewRequest(http.MethodGet, "/foo/bar", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -138,7 +138,7 @@ func TestIMIntegration(t *testing.T) { // get nonce nonce := func(ctx context.Context) (nonce string) { - req := httptest.NewRequest(http.MethodGet, dir.NewNonce, nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, dir.NewNonce, http.NoBody).WithContext(ctx) w := httptest.NewRecorder() addNonce(GetNonce)(w, req) res := w.Result() @@ -164,7 +164,7 @@ func TestIMIntegration(t *testing.T) { ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: rawNar}) // create account - req := httptest.NewRequest(http.MethodGet, dir.NewAccount, nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, dir.NewAccount, http.NoBody).WithContext(ctx) w := httptest.NewRecorder() NewAccount(w, req) @@ -214,7 +214,7 @@ func TestIMIntegration(t *testing.T) { ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - req := httptest.NewRequest("POST", "https://random.local/", nil) + req := httptest.NewRequest("POST", "https://random.local/", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() NewOrder(w, req) @@ -250,7 +250,7 @@ func TestIMIntegration(t *testing.T) { chiCtx.URLParams.Add("authzID", authzID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - req := httptest.NewRequest(http.MethodGet, "https://random.local/", nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() GetAuthorization(w, req) @@ -287,7 +287,7 @@ func TestIMIntegration(t *testing.T) { ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: nil}) - req := httptest.NewRequest(http.MethodGet, "https://random.local/", nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() GetChallenge(w, req) @@ -343,7 +343,7 @@ func TestIMIntegration(t *testing.T) { chiCtx.URLParams.Add("ordID", order.ID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - req := httptest.NewRequest(http.MethodGet, "https://random.local/", nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() GetOrder(w, req) @@ -371,7 +371,7 @@ func TestIMIntegration(t *testing.T) { }(ctx) t.Log("updated order status:", updatedOrder.Status) - // finalise order + // finalize order finalizedOrder := func(ctx context.Context) (finalizedOrder *acme.Order) { mockMustAuthority(t, &mockCASigner{ signer: func(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { @@ -436,7 +436,7 @@ func TestIMIntegration(t *testing.T) { chiCtx.URLParams.Add("ordID", order.ID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - req := httptest.NewRequest(http.MethodGet, "https://random.local/", nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() FinalizeOrder(w, req) diff --git a/acme/api/order.go b/acme/api/order.go index 9f30a5c4..a0f49b78 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -56,11 +56,11 @@ func (n *NewOrderRequest) Validate() error { if err != nil { return acme.NewError(acme.ErrorMalformedType, "ID cannot be parsed") } - clientIdUri, err := uri.Parse(orderValue.ClientID) + clientIDURI, err := uri.Parse(orderValue.ClientID) if err != nil { return acme.NewError(acme.ErrorMalformedType, "invalid client ID, it's supposed to be a valid URI") } - if clientIdUri.Scheme != "wireapp" { + if clientIDURI.Scheme != "wireapp" { return acme.NewError(acme.ErrorMalformedType, "invalid client ID scheme") } default: @@ -280,11 +280,11 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { var target string if az.Identifier.Type == acme.WireID { - wireId, err := wire.ParseID([]byte(az.Identifier.Value)) + wireID, err := wire.ParseID([]byte(az.Identifier.Value)) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing WireID") } - clientID, err := wire.ParseClientID(wireId.ClientID) + clientID, err := wire.ParseClientID(wireID.ClientID) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing ClientID") } diff --git a/acme/challenge.go b/acme/challenge.go index 51c9b6a3..633e0667 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -355,8 +355,8 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK } type WireChallengePayload struct { - // IdToken - IdToken string `json:"id_token,omitempty"` + // IDToken + IDToken string `json:"id_token,omitempty"` // KeyAuth ({challenge-token}.{jwk-thumbprint}) KeyAuth string `json:"keyauth,omitempty"` // AccessToken is the token generated by wire-server @@ -377,7 +377,7 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO } oidcOptions := prov.GetOptions().GetOIDCOptions() - idToken, err := oidcOptions.GetProvider(ctx).Verifier(oidcOptions.GetConfig()).Verify(ctx, wireChallengePayload.IdToken) + idToken, err := oidcOptions.GetProvider(ctx).Verifier(oidcOptions.GetConfig()).Verify(ctx, wireChallengePayload.IDToken) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorRejectedIdentifierType, err, "error verifying ID token signature")) @@ -422,12 +422,12 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO return WrapErrorISE(err, "error updating challenge") } - parsedIdToken, err := jwt.ParseSigned(wireChallengePayload.IdToken) + parsedIDToken, err := jwt.ParseSigned(wireChallengePayload.IDToken) if err != nil { return WrapErrorISE(err, "Invalid OIDC id token") } oidcToken := make(map[string]interface{}) - if err := parsedIdToken.UnsafeClaimsWithoutVerification(&oidcToken); err != nil { + if err := parsedIDToken.UnsafeClaimsWithoutVerification(&oidcToken); err != nil { return WrapErrorISE(err, "Failed parsing OIDC id token") } @@ -450,7 +450,7 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO } func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error { - provisioner, ok := ProvisionerFromContext(ctx) + prov, ok := ProvisionerFromContext(ctx) if !ok { return NewErrorISE("missing provisioner") } @@ -462,7 +462,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO kid := base64.RawURLEncoding.EncodeToString(rawKid) - dpopOptions := provisioner.GetOptions().GetDPOPOptions() + dpopOptions := prov.GetOptions().GetDPOPOptions() key := dpopOptions.GetSigningKey() var wireChallengePayload WireChallengePayload @@ -506,7 +506,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO } expiry := strconv.FormatInt(time.Now().Add(time.Hour*24*365).Unix(), 10) - cmd := exec.CommandContext( + cmd := exec.CommandContext( //nolint:gosec // TODO(hs): replace this with Go implementation ctx, dpopOptions.GetValidationExecPath(), "verify-access", @@ -573,8 +573,6 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO return WrapErrorISE(err, "Failed parsing access token") } - ctx = context.WithValue(ctx, "access", access) - rawDpop, ok := access["proof"].(string) if !ok { return WrapErrorISE(err, "Invalid dpop proof format in access token") diff --git a/acme/db.go b/acme/db.go index f6943e98..12a18a9d 100644 --- a/acme/db.go +++ b/acme/db.go @@ -55,10 +55,10 @@ type DB interface { GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) UpdateOrder(ctx context.Context, o *Order) error - CreateDpopToken(ctx context.Context, orderId string, dpop map[string]interface{}) error - GetDpopToken(ctx context.Context, orderId string) (map[string]interface{}, error) - CreateOidcToken(ctx context.Context, orderId string, idToken map[string]interface{}) error - GetOidcToken(ctx context.Context, orderId string) (map[string]interface{}, error) + CreateDpopToken(ctx context.Context, orderID string, dpop map[string]interface{}) error + GetDpopToken(ctx context.Context, orderID string) (map[string]interface{}, error) + CreateOidcToken(ctx context.Context, orderID string, idToken map[string]interface{}) error + GetOidcToken(ctx context.Context, orderID string) (map[string]interface{}, error) } type dbKey struct{} @@ -124,6 +124,12 @@ type MockDB struct { MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) MockUpdateOrder func(ctx context.Context, o *Order) error + MockGetAllOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) + MockGetDpopToken func(ctx context.Context, orderID string) (map[string]interface{}, error) + MockCreateDpopToken func(ctx context.Context, orderID string, dpop map[string]interface{}) error + MockGetOidcToken func(ctx context.Context, orderID string) (map[string]interface{}, error) + MockCreateOidcToken func(ctx context.Context, orderID string, idToken map[string]interface{}) error + MockRet1 interface{} MockError error } @@ -398,27 +404,48 @@ func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]stri return m.MockRet1.([]string), m.MockError } -// GetDpop retrieves an DPoP from the database. -func (m *MockDB) GetDpopToken(ctx context.Context, orderId string) (map[string]interface{}, error) { - return nil, errors.New("not implemented") +// GetAllOrdersByAccountID returns a list of any order IDs owned by the account. +func (m *MockDB) GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) { + if m.MockGetAllOrdersByAccountID != nil { + return m.MockGetAllOrdersByAccountID(ctx, accountID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.([]string), m.MockError +} + +// GetDpop retrieves a DPoP from the database. +func (m *MockDB) GetDpopToken(ctx context.Context, orderID string) (map[string]any, error) { + if m.MockGetDpopToken != nil { + return m.MockGetDpopToken(ctx, orderID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(map[string]any), m.MockError } // CreateDpop creates DPoP resources and saves them to the DB. -func (m *MockDB) CreateDpopToken(ctx context.Context, orderId string, dpop map[string]interface{}) error { - return errors.New("not implemented") -} - -// GetAllOrdersByAccountID returns a list of any order IDs owned by the account. -func (m *MockDB) GetAllOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { - return nil, errors.New("not implemented") -} - -// CreateOidcToken creates oidc token resources and saves them to the DB. -func (m *MockDB) CreateOidcToken(ctx context.Context, orderId string, idToken map[string]interface{}) error { - return errors.New("not implemented") +func (m *MockDB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error { + if m.MockCreateDpopToken != nil { + return m.MockCreateDpopToken(ctx, orderID, dpop) + } + return m.MockError } // GetOidcToken retrieves an oidc token from the database. -func (m *MockDB) GetOidcToken(ctx context.Context, orderId string) (map[string]interface{}, error) { - return nil, errors.New("not implemented") +func (m *MockDB) GetOidcToken(ctx context.Context, orderID string) (map[string]any, error) { + if m.MockGetOidcToken != nil { + return m.MockGetOidcToken(ctx, orderID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(map[string]any), m.MockError +} + +// CreateOidcToken creates oidc token resources and saves them to the DB. +func (m *MockDB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error { + if m.MockCreateOidcToken != nil { + return m.MockCreateOidcToken(ctx, orderID, idToken) + } + return m.MockError } diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index 29d69474..983fbe8d 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -121,7 +121,7 @@ func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { return db.save(ctx, old.ID, nu, old, "order", orderTable) } -func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, anyOrder bool, addOids ...string) ([]string, error) { +func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, includeReadyOrders bool, addOids ...string) ([]string, error) { ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() @@ -153,7 +153,7 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, anyOrder bool return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID) } - if o.Status == acme.StatusPending || o.Status == acme.StatusReady { + if o.Status == acme.StatusPending || (o.Status == acme.StatusReady && includeReadyOrders) { pendOids = append(pendOids, oid) } } diff --git a/acme/db/nosql/wire.go b/acme/db/nosql/wire.go index cf8c5fa9..9ceeb52d 100644 --- a/acme/db/nosql/wire.go +++ b/acme/db/nosql/wire.go @@ -3,6 +3,7 @@ package nosql import ( "context" "encoding/json" + "fmt" "time" "github.com/pkg/errors" @@ -17,36 +18,36 @@ type dbDpopToken struct { } // getDBDpopToken retrieves and unmarshals an DPoP type from the database. -func (db *DB) getDBDpopToken(ctx context.Context, orderId string) (*dbDpopToken, error) { - b, err := db.db.Get(wireDpopTokenTable, []byte(orderId)) +func (db *DB) getDBDpopToken(_ context.Context, orderID string) (*dbDpopToken, error) { + b, err := db.db.Get(wireDpopTokenTable, []byte(orderID)) if nosql.IsErrNotFound(err) { - return nil, acme.NewError(acme.ErrorMalformedType, "dpop %s not found", orderId) + return nil, acme.NewError(acme.ErrorMalformedType, "dpop %s not found", orderID) } else if err != nil { - return nil, errors.Wrapf(err, "error loading dpop %s", orderId) + return nil, errors.Wrapf(err, "error loading dpop %s", orderID) } d := new(dbDpopToken) if err := json.Unmarshal(b, d); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling dpop %s into dbDpopToken", orderId) + return nil, errors.Wrapf(err, "error unmarshaling dpop %s into dbDpopToken", orderID) } return d, nil } // GetDpopToken retrieves an DPoP from the database. -func (db *DB) GetDpopToken(ctx context.Context, orderId string) (map[string]interface{}, error) { - dbDpop, err := db.getDBDpopToken(ctx, orderId) +func (db *DB) GetDpopToken(ctx context.Context, orderID string) (map[string]any, error) { + dbDpop, err := db.getDBDpopToken(ctx, orderID) if err != nil { return nil, err } - dpop := make(map[string]interface{}) + dpop := make(map[string]any) err = json.Unmarshal(dbDpop.Content, &dpop) return dpop, err } // CreateDpopToken creates DPoP resources and saves them to the DB. -func (db *DB) CreateDpopToken(ctx context.Context, orderId string, dpop map[string]interface{}) error { +func (db *DB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error { content, err := json.Marshal(dpop) if err != nil { return err @@ -54,12 +55,12 @@ func (db *DB) CreateDpopToken(ctx context.Context, orderId string, dpop map[stri now := clock.Now() dbDpop := &dbDpopToken{ - ID: orderId, + ID: orderID, Content: content, CreatedAt: now, } - if err := db.save(ctx, orderId, dbDpop, nil, "dpop", wireDpopTokenTable); err != nil { - return err + if err := db.save(ctx, orderID, dbDpop, nil, "dpop", wireDpopTokenTable); err != nil { + return fmt.Errorf("failed saving dpop token: %w", err) } return nil } @@ -71,35 +72,35 @@ type dbOidcToken struct { } // getDBOidcToken retrieves and unmarshals an OIDC id token type from the database. -func (db *DB) getDBOidcToken(ctx context.Context, orderId string) (*dbOidcToken, error) { - b, err := db.db.Get(wireOidcTokenTable, []byte(orderId)) +func (db *DB) getDBOidcToken(_ context.Context, orderID string) (*dbOidcToken, error) { + b, err := db.db.Get(wireOidcTokenTable, []byte(orderID)) if nosql.IsErrNotFound(err) { - return nil, acme.NewError(acme.ErrorMalformedType, "oidc token %s not found", orderId) + return nil, acme.NewError(acme.ErrorMalformedType, "oidc token %s not found", orderID) } else if err != nil { - return nil, errors.Wrapf(err, "error loading oidc token %s", orderId) + return nil, errors.Wrapf(err, "error loading oidc token %s", orderID) } o := new(dbOidcToken) if err := json.Unmarshal(b, o); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling oidc token %s into dbOidcToken", orderId) + return nil, errors.Wrapf(err, "error unmarshaling oidc token %s into dbOidcToken", orderID) } return o, nil } // GetOidcToken retrieves an oidc token from the database. -func (db *DB) GetOidcToken(ctx context.Context, orderId string) (map[string]interface{}, error) { - dbOidc, err := db.getDBOidcToken(ctx, orderId) +func (db *DB) GetOidcToken(ctx context.Context, orderID string) (map[string]any, error) { + dbOidc, err := db.getDBOidcToken(ctx, orderID) if err != nil { return nil, err } - idToken := make(map[string]interface{}) + idToken := make(map[string]any) err = json.Unmarshal(dbOidc.Content, &idToken) return idToken, err } // CreateOidcToken creates oidc token resources and saves them to the DB. -func (db *DB) CreateOidcToken(ctx context.Context, orderId string, idToken map[string]interface{}) error { +func (db *DB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error { content, err := json.Marshal(idToken) if err != nil { return err @@ -107,12 +108,12 @@ func (db *DB) CreateOidcToken(ctx context.Context, orderId string, idToken map[s now := clock.Now() dbOidc := &dbOidcToken{ - ID: orderId, + ID: orderID, Content: content, CreatedAt: now, } - if err := db.save(ctx, orderId, dbOidc, nil, "oidc", wireOidcTokenTable); err != nil { - return err + if err := db.save(ctx, orderID, dbOidc, nil, "oidc", wireOidcTokenTable); err != nil { + return fmt.Errorf("failed saving oidc token: %w", err) } return nil } diff --git a/acme/order.go b/acme/order.go index 6e94ca6f..cd3c6bac 100644 --- a/acme/order.go +++ b/acme/order.go @@ -406,13 +406,12 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } - clientId, err := url.Parse(wireID.ClientID) + clientID, err := url.Parse(wireID.ClientID) if err != nil { return sans, NewErrorISE("clientId must be a URI: %s", wireID.ClientID) } - tmpOrderURIs[indexURI] = clientId + tmpOrderURIs[indexURI] = clientID indexURI++ - handle, err := url.Parse(wireID.Handle) if err != nil { return sans, NewErrorISE("handle must be a URI: %s", wireID.Handle) diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 99a54ccc..726fcd20 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -254,12 +254,11 @@ func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdenti case DNS: err = x509Policy.IsDNSAllowed(identifier.Value) case WireID: - wireSANs, err := wire.ParseID([]byte(identifier.Value)) - if err != nil { - err = fmt.Errorf("could not parse Wire SANs: %w", err) - break + var wireID wire.ID + if wireID, err = wire.ParseID([]byte(identifier.Value)); err != nil { + return fmt.Errorf("failed parsing Wire SANs: %w", err) } - err = x509Policy.AreSANsAllowed([]string{wireSANs.ClientID, wireSANs.Handle}) + err = x509Policy.AreSANsAllowed([]string{wireID.ClientID, wireID.Handle}) default: err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type) } diff --git a/authority/provisioner/dpop_options.go b/authority/provisioner/dpop_options.go index ced11d9c..ddfc5b15 100644 --- a/authority/provisioner/dpop_options.go +++ b/authority/provisioner/dpop_options.go @@ -2,6 +2,7 @@ package provisioner import ( "bytes" + "errors" "fmt" "text/template" ) @@ -39,11 +40,16 @@ func (o *DPOPOptions) GetDPOPTarget() string { func (o *DPOPOptions) GetTarget(deviceID string) (string, error) { if o == nil { - return "", fmt.Errorf("Misconfigured target template configuration") + return "", errors.New("misconfigured target template configuration") } targetTemplate := o.GetDPOPTarget() - tmpl, err := template.New("DeviceId").Parse(targetTemplate) + tmpl, err := template.New("DeviceID").Parse(targetTemplate) + if err != nil { + return "", fmt.Errorf("failed parsing dpop template: %w", err) + } buf := new(bytes.Buffer) - err = tmpl.Execute(buf, struct{ DeviceId string }{deviceID}) - return buf.String(), err + if err = tmpl.Execute(buf, struct{ DeviceID string }{DeviceID: deviceID}); err != nil { + return "", fmt.Errorf("failed executing dpop template: %w", err) + } + return buf.String(), nil } diff --git a/authority/provisioner/oidc_options.go b/authority/provisioner/oidc_options.go index 21df2d60..46dda9fe 100644 --- a/authority/provisioner/oidc_options.go +++ b/authority/provisioner/oidc_options.go @@ -3,6 +3,7 @@ package provisioner import ( "bytes" "context" + "errors" "fmt" "net/url" "text/template" @@ -52,17 +53,22 @@ func (o *OIDCOptions) GetConfig() *oidc.Config { func (o *OIDCOptions) GetTarget(deviceID string) (string, error) { if o == nil { - return "", fmt.Errorf("Misconfigured target template configuration") + return "", errors.New("misconfigured target template configuration") } targetTemplate := o.Provider.IssuerURL - tmpl, err := template.New("DeviceId").Parse(targetTemplate) + tmpl, err := template.New("DeviceID").Parse(targetTemplate) + if err != nil { + return "", fmt.Errorf("failed parsing oidc template: %w", err) + } buf := new(bytes.Buffer) - err = tmpl.Execute(buf, struct{ DeviceId string }{deviceID}) - return buf.String(), err + if err = tmpl.Execute(buf, struct{ DeviceID string }{DeviceID: deviceID}); err != nil { + return "", fmt.Errorf("failed executing oidc template: %w", err) + } + return buf.String(), nil } func toProviderConfig(in ProviderJSON) *oidc.ProviderConfig { - issuerUrl, err := url.Parse(in.IssuerURL) + issuerURL, err := url.Parse(in.IssuerURL) if err != nil { panic(err) // config error, it's ok to panic here } @@ -71,10 +77,10 @@ func toProviderConfig(in ProviderJSON) *oidc.ProviderConfig { // This URL is going to look like: "https://idp:5556/dex?clientid=foo" // If we don't trim the query params here i.e. 'clientid' then the idToken verification is going to fail because // the 'iss' claim of the idToken will be "https://idp:5556/dex" - issuerUrl.RawQuery = "" - issuerUrl.Fragment = "" + issuerURL.RawQuery = "" + issuerURL.Fragment = "" return &oidc.ProviderConfig{ - IssuerURL: issuerUrl.String(), + IssuerURL: issuerURL.String(), AuthURL: in.AuthURL, TokenURL: in.TokenURL, UserInfoURL: in.UserInfoURL, diff --git a/wire/id.go b/wire/id.go index 5ba8438a..9d57f79b 100644 --- a/wire/id.go +++ b/wire/id.go @@ -8,14 +8,14 @@ import ( "go.step.sm/crypto/kms/uri" ) -type WireIDJSON struct { +type ID struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` } -func ParseID(data []byte) (wireID WireIDJSON, err error) { +func ParseID(data []byte) (wireID ID, err error) { err = json.Unmarshal(data, &wireID) return } @@ -26,21 +26,24 @@ type ClientID struct { Domain string } -// ClientId format is : "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com" where '!' is used as a separator -// between the user id & device id +// ParseClientID parses a Wire clientID. The ClientID format is as follows: +// +// "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", +// +// where '!' is used as a separator between the user id & device id. func ParseClientID(clientID string) (ClientID, error) { - clientIdUri, err := uri.Parse(clientID) + clientIDURI, err := uri.Parse(clientID) if err != nil { - return ClientID{}, fmt.Errorf("invalid client id URI") + return ClientID{}, fmt.Errorf("invalid clientID URI %q: %w", clientID, err) } - fullUsername := clientIdUri.User.Username() + fullUsername := clientIDURI.User.Username() parts := strings.SplitN(fullUsername, "!", 2) if len(parts) != 2 { - return ClientID{}, fmt.Errorf("invalid client id") + return ClientID{}, fmt.Errorf("invalid clientID %q", fullUsername) } return ClientID{ Username: parts[0], DeviceID: parts[1], - Domain: clientIdUri.Host, + Domain: clientIDURI.Host, }, nil }