Fix linter issues

This commit is contained in:
Herman Slatman
2025-02-18 11:04:54 +01:00
parent 86c04f0ce8
commit 27944b4eae
33 changed files with 308 additions and 112 deletions

View File

@@ -180,7 +180,7 @@ func isAccountAuthorized(_ context.Context, dbCert *acme.Certificate, certToBeRe
func wrapRevokeErr(err error) *acme.Error {
t := err.Error()
if strings.Contains(t, "is already revoked") {
return acme.NewError(acme.ErrorAlreadyRevokedType, t) //nolint:govet // allow non-constant error messages
return acme.NewError(acme.ErrorAlreadyRevokedType, t)
}
return acme.WrapErrorISE(err, "error when revoking certificate")
}
@@ -190,9 +190,9 @@ func wrapRevokeErr(err error) *acme.Error {
func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error {
var acmeErr *acme.Error
if err == nil {
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg) //nolint:govet // allow non-constant error messages
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg)
} else {
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg) //nolint:govet // allow non-constant error messages
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
}
acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403

View File

@@ -39,6 +39,7 @@ import (
"github.com/smallstep/certificates/acme/wire"
"github.com/smallstep/certificates/authority/provisioner"
wireprovisioner "github.com/smallstep/certificates/authority/provisioner/wire"
"github.com/smallstep/certificates/internal/cast"
)
type ChallengeType string
@@ -229,7 +230,7 @@ func tlsAlert(err error) uint8 {
if errors.As(err, &opErr) {
v := reflect.ValueOf(opErr.Err)
if v.Kind() == reflect.Uint8 {
return uint8(v.Uint())
return uint8(v.Uint()) //nolint:gosec // handled by checking its type
}
}
return 0
@@ -978,9 +979,9 @@ type tpmAttestationData struct {
type coseAlgorithmIdentifier int32
const (
coseAlgES256 coseAlgorithmIdentifier = -7
coseAlgRS256 coseAlgorithmIdentifier = -257
coseAlgRS1 coseAlgorithmIdentifier = -65535 // deprecated, but (still) often used in TPMs
coseAlgES256 = coseAlgorithmIdentifier(-7)
coseAlgRS256 = coseAlgorithmIdentifier(-257)
coseAlgRS1 = coseAlgorithmIdentifier(-65535) // deprecated, but (still) often used in TPMs
)
func doTPMAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*tpmAttestationData, error) {
@@ -1105,8 +1106,13 @@ func doTPMAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge,
return nil, NewDetailedError(ErrorBadAttestationStatementType, "invalid alg in attestation statement")
}
algI32, err := cast.SafeInt32(alg)
if err != nil {
return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "invalid alg %d in attestation statement", alg)
}
var hash crypto.Hash
switch coseAlgorithmIdentifier(alg) {
switch coseAlgorithmIdentifier(algI32) {
case coseAlgRS256, coseAlgES256:
hash = crypto.SHA256
case coseAlgRS1:

View File

@@ -86,7 +86,7 @@ func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...stri
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) //nolint:gosec // operating on internally defined inputs
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:

View File

@@ -309,7 +309,6 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
// Add subproblem for webhook errors, others can be added later.
var webhookErr *webhook.Error
if errors.As(err, &webhookErr) {
//nolint:govet // ignore non-constant format string
acmeError := NewDetailedError(ErrorUnauthorizedType, webhookErr.Error())
acmeError.AddSubproblems(Subproblem{
Type: fmt.Sprintf("urn:smallstep:acme:error:%s", webhookErr.Code),

View File

@@ -4,7 +4,7 @@ import (
"bytes"
"context"
"crypto"
"crypto/dsa" // support legacy algorithms
"crypto/dsa" //nolint:staticcheck // support legacy algorithms
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
@@ -31,6 +31,7 @@ import (
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
"github.com/smallstep/certificates/logging"
)
@@ -595,8 +596,8 @@ func LogSSHCertificate(w http.ResponseWriter, cert *ssh.Certificate) {
m := map[string]interface{}{
"serial": cert.Serial,
"principals": cert.ValidPrincipals,
"valid-from": time.Unix(int64(cert.ValidAfter), 0).Format(time.RFC3339),
"valid-to": time.Unix(int64(cert.ValidBefore), 0).Format(time.RFC3339),
"valid-from": time.Unix(cast.Int64(cert.ValidAfter), 0).Format(time.RFC3339),
"valid-to": time.Unix(cast.Int64(cert.ValidBefore), 0).Format(time.RFC3339),
"certificate": certificate,
"certificate-type": certificateType,
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
"github.com/smallstep/certificates/templates"
)
@@ -331,8 +332,8 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
// Enforce the same duration as ssh certificate.
signOpts = append(signOpts, &identityModifier{
Identity: getIdentityURI(cr),
NotBefore: time.Unix(int64(cert.ValidAfter), 0),
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
NotBefore: time.Unix(cast.Int64(cert.ValidAfter), 0),
NotAfter: time.Unix(cast.Int64(cert.ValidBefore), 0),
})
certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)

View File

@@ -10,6 +10,7 @@ import (
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
// SSHRekeyRequest is the request body of an SSH certificate request.
@@ -80,8 +81,8 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
}
// Match identity cert with the SSH cert
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
notBefore := time.Unix(cast.Int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(cast.Int64(oldCert.ValidBefore), 0)
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
// SSHRenewRequest is the request body of an SSH certificate request.
@@ -72,8 +73,8 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
}
// Match identity cert with the SSH cert
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
notBefore := time.Unix(cast.Int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(cast.Int64(oldCert.ValidBefore), 0)
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {

View File

@@ -202,7 +202,7 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
}
if !found {
msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name)
err := admin.NewError(admin.ErrorNotFoundType, msg) //nolint:govet // allow non-constant error messages
err := admin.NewError(admin.ErrorNotFoundType, msg)
render.Error(w, r, err)
return
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/cast"
)
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
@@ -336,7 +337,7 @@ func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificat
Serial: rci.Serial,
PemCertificate: serializeCertificate(crt),
Reason: rci.Reason,
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
ReasonCode: linkedca.RevocationReasonCode(cast.Int32(rci.ReasonCode)),
Passive: true,
})
@@ -350,7 +351,7 @@ func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertifi
Serial: rci.Serial,
Certificate: serializeSSHCertificate(cert),
Reason: rci.Reason,
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
ReasonCode: linkedca.RevocationReasonCode(cast.Int32(rci.ReasonCode)),
Passive: true,
})
@@ -403,7 +404,7 @@ func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIde
}
return &linkedca.ProvisionerIdentity{
Id: p.GetID(),
Type: linkedca.Provisioner_Type(p.GetType()),
Type: linkedca.Provisioner_Type(cast.Int32(int(p.GetType()))),
Name: p.GetName(),
}
}

View File

@@ -12,8 +12,10 @@ import (
"strings"
"sync"
"github.com/smallstep/certificates/authority/admin"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/internal/cast"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
@@ -210,7 +212,7 @@ func (c *Collection) Store(p Interface) error {
// 0x00000000, 0x00000001, 0x00000002, ...
bi := make([]byte, 4)
sum := provisionerSum(p)
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
binary.BigEndian.PutUint32(bi, cast.Uint32(c.sorted.Len()))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
c.sorted = append(c.sorted, uidProvisioner{
provisioner: p,

View File

@@ -8,11 +8,14 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"github.com/smallstep/linkedca"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
"github.com/smallstep/linkedca"
"golang.org/x/crypto/ssh"
)
// Controller wraps a provisioner with other attributes useful in callback
@@ -189,10 +192,10 @@ func DefaultAuthorizeSSHRenew(_ context.Context, p *Controller, cert *ssh.Certif
}
unixNow := time.Now().Unix()
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
if after := cast.Int64(cert.ValidAfter); after < 0 || unixNow < cast.Int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() {
if before := cast.Int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}

View File

@@ -14,6 +14,7 @@ import (
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
// jwtPayload extends jwt.Claims with step attributes.
@@ -249,7 +250,7 @@ func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Use options in the token.
if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.BadRequestErr(err, err.Error()) //nolint:govet // allow non-constant error messages
return nil, errs.BadRequestErr(err, err.Error())
}
}
if opts.KeyID != "" {
@@ -274,10 +275,10 @@ func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Add modifiers from custom claims
t := now()
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidAfterModifier(cast.Uint64(opts.ValidAfter.RelativeTime(t).Unix())))
}
if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidBeforeModifier(cast.Uint64(opts.ValidBefore.RelativeTime(t).Unix())))
}
return append(signOptions,

View File

@@ -14,15 +14,16 @@ import (
"github.com/pkg/errors"
nebula "github.com/slackhq/nebula/cert"
"golang.org/x/crypto/ssh"
"github.com/smallstep/linkedca"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x25519"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
const (
@@ -237,10 +238,10 @@ func (p *Nebula) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption
// Add modifiers from custom claims
t := now()
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidAfterModifier(cast.Uint64(opts.ValidAfter.RelativeTime(t).Unix())))
}
if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidBeforeModifier(cast.Uint64(opts.ValidBefore.RelativeTime(t).Unix())))
}
}

View File

@@ -10,10 +10,13 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/keyutil"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/internal/cast"
)
const (
@@ -103,10 +106,10 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
t := now()
if !o.ValidAfter.IsZero() {
cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix())
cert.ValidAfter = cast.Uint64(o.ValidAfter.RelativeTime(t).Unix())
}
if !o.ValidBefore.IsZero() {
cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix())
cert.ValidBefore = cast.Uint64(o.ValidBefore.RelativeTime(t).Unix())
}
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errs.BadRequest("ssh certificate validAfter cannot be greater than validBefore")
@@ -167,11 +170,11 @@ func (m *sshDefaultDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) err
var backdate uint64
if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
backdate = cast.Uint64(o.Backdate / time.Second)
cert.ValidAfter = cast.Uint64(now().Truncate(time.Second).Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second)
cert.ValidBefore = cert.ValidAfter + cast.Uint64(d/time.Second)
}
// Apply backdate safely
if cert.ValidAfter > backdate {
@@ -206,11 +209,11 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
var backdate uint64
if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
backdate = cast.Uint64(o.Backdate / time.Second)
cert.ValidAfter = cast.Uint64(now().Truncate(time.Second).Unix())
}
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
certValidAfter := time.Unix(cast.Int64(cert.ValidAfter), 0)
if certValidAfter.After(m.NotAfter) {
return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
m.NotAfter, certValidAfter)
@@ -221,9 +224,9 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
if m.NotAfter.Before(certValidBefore) {
certValidBefore = m.NotAfter
}
cert.ValidBefore = uint64(certValidBefore.Unix())
cert.ValidBefore = cast.Uint64(certValidBefore.Unix())
} else {
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
certValidBefore := time.Unix(cast.Int64(cert.ValidBefore), 0)
if m.NotAfter.Before(certValidBefore) {
return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore)
@@ -277,7 +280,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch {
case cert.ValidAfter == 0:
return errs.BadRequest("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()):
case cert.ValidBefore < cast.Uint64(now().Unix()):
return errs.BadRequest("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter:
return errs.BadRequest("ssh certificate validBefore cannot be before validAfter")
@@ -299,7 +302,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
// To not take into account the backdate, time.Now() will be used to
// calculate the duration if ValidAfter is in the past.
dur := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
dur := time.Duration(cast.Int64(cert.ValidBefore-cert.ValidAfter)) * time.Second
switch {
case dur < minDur:
@@ -332,7 +335,7 @@ func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions)
return errs.Forbidden("ssh certificate key id cannot be empty")
case cert.ValidAfter == 0:
return errs.Forbidden("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()):
case cert.ValidBefore < cast.Uint64(now().Unix()):
return errs.Forbidden("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter:
return errs.Forbidden("ssh certificate validBefore cannot be before validAfter")
@@ -462,7 +465,7 @@ func sshParseString(in []byte) (out, rest []byte, ok bool) {
}
length := binary.BigEndian.Uint32(in)
in = in[4:]
if uint32(len(in)) < length {
if cast.Uint32(len(in)) < length {
return
}
out = in[:length]

View File

@@ -8,9 +8,12 @@ import (
"reflect"
"time"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/sshutil"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
func validateSSHCertificate(cert *ssh.Certificate, opts *SignSSHOptions) error {
@@ -30,9 +33,9 @@ func validateSSHCertificate(cert *ssh.Certificate, opts *SignSSHOptions) error {
case opts.CertType == "host" && cert.CertType != ssh.HostCert:
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType)
case cert.ValidAfter != uint64(opts.ValidAfter.Unix()):
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(cast.Int64(cert.ValidAfter), 0))
case cert.ValidBefore != uint64(opts.ValidBefore.Unix()):
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(cast.Int64(cert.ValidAfter), 0))
case opts.CertType == "user" && len(cert.Extensions) != 5:
return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions))
case opts.CertType == "host" && len(cert.Extensions) != 0:
@@ -90,7 +93,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []Si
var templErr *sshutil.TemplateError
if errors.As(err, &templErr) {
return nil, errs.NewErr(http.StatusBadRequest, templErr,
errs.WithMessage(templErr.Error()), //nolint:govet // allow non-constant error messages
errs.WithMessage(templErr.Error()),
errs.WithKeyVal("signOptions", signOpts),
)
}

View File

@@ -13,6 +13,7 @@ import (
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
// sshPOPPayload extends jwt.Claims with step attributes.
@@ -118,10 +119,10 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity
// Controller.AuthorizeSSHRenew will validate this on the renewal flow.
if checkValidity {
unixNow := time.Now().Unix()
if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
if after := cast.Int64(sshCert.ValidAfter); after < 0 || unixNow < cast.Int64(sshCert.ValidAfter) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
if before := cast.Int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
}

View File

@@ -15,6 +15,7 @@ import (
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
"github.com/smallstep/certificates/webhook"
)
@@ -301,7 +302,7 @@ func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Use options in the token.
if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.BadRequestErr(err, err.Error()) //nolint:govet // allow non-constant error messages
return nil, errs.BadRequestErr(err, err.Error())
}
}
if opts.KeyID != "" {
@@ -332,10 +333,10 @@ func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Add modifiers from custom claims
t := now()
if !opts.ValidAfter.IsZero() {
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidAfterModifier(cast.Uint64(opts.ValidAfter.RelativeTime(t).Unix())))
}
if !opts.ValidBefore.IsZero() {
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
signOptions = append(signOptions, sshCertValidBeforeModifier(cast.Uint64(opts.ValidBefore.RelativeTime(t).Unix())))
}
return append(signOptions,

View File

@@ -22,6 +22,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
)
type raProvisioner interface {
@@ -1257,10 +1258,10 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro
ForceCn: p.ForceCN,
Challenge: p.ChallengePassword,
Capabilities: p.Capabilities,
MinimumPublicKeyLength: int32(p.MinimumPublicKeyLength),
MinimumPublicKeyLength: cast.Int32(p.MinimumPublicKeyLength),
IncludeRoot: p.IncludeRoot,
ExcludeIntermediate: p.ExcludeIntermediate,
EncryptionAlgorithmIdentifier: int32(p.EncryptionAlgorithmIdentifier),
EncryptionAlgorithmIdentifier: cast.Int32(p.EncryptionAlgorithmIdentifier),
Decrypter: &linkedca.SCEPDecrypter{
Certificate: p.DecrypterCertificate,
Key: p.DecrypterKeyPEM,

View File

@@ -19,6 +19,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/cast"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook"
)
@@ -214,7 +215,7 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi
for _, v := range keyValidators {
if err := v.Valid(key); err != nil {
return nil, nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), //nolint:govet // allow non-constant error messages
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
)
}
@@ -231,7 +232,7 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Call enriching webhooks
if err := a.callEnrichingWebhooksSSH(ctx, prov, webhookCtl, cr); err != nil {
return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), //nolint:govet // allow non-constant error messages
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
)
}
@@ -243,7 +244,7 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi
switch {
case errors.As(err, &te):
return nil, prov, errs.ApplyOptions(
errs.BadRequestErr(err, err.Error()), //nolint:govet // allow non-constant error messages
errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
)
case strings.HasPrefix(err.Error(), "error unmarshaling certificate"):
@@ -263,7 +264,7 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Use SignSSHOptions to modify the certificate validity. It will be later
// checked or set if not defined.
if err := opts.ModifyValidity(certTpl); err != nil {
return nil, prov, errs.BadRequestErr(err, err.Error()) //nolint:govet // allow non-constant error messages
return nil, prov, errs.BadRequestErr(err, err.Error())
}
// Use provisioner modifiers.
@@ -356,7 +357,7 @@ func (a *Authority) renewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
}
backdate := a.config.AuthorityConfig.Backdate.Duration
duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second
duration := time.Duration(cast.Int64(oldCert.ValidBefore-oldCert.ValidAfter)) * time.Second
now := time.Now()
va := now.Add(-1 * backdate)
vb := now.Add(duration - backdate)
@@ -370,8 +371,8 @@ func (a *Authority) renewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
ValidPrincipals: oldCert.ValidPrincipals,
Permissions: oldCert.Permissions,
Reserved: oldCert.Reserved,
ValidAfter: uint64(va.Unix()),
ValidBefore: uint64(vb.Unix()),
ValidAfter: cast.Uint64(va.Unix()),
ValidBefore: cast.Uint64(vb.Unix()),
}
// Get signer from authority keys
@@ -436,7 +437,7 @@ func (a *Authority) rekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
}
backdate := a.config.AuthorityConfig.Backdate.Duration
duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second
duration := time.Duration(cast.Int64(oldCert.ValidBefore-oldCert.ValidAfter)) * time.Second
now := time.Now()
va := now.Add(-1 * backdate)
vb := now.Add(duration - backdate)
@@ -450,8 +451,8 @@ func (a *Authority) rekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
ValidPrincipals: oldCert.ValidPrincipals,
Permissions: oldCert.Permissions,
Reserved: oldCert.Reserved,
ValidAfter: uint64(va.Unix()),
ValidBefore: uint64(vb.Unix()),
ValidAfter: cast.Uint64(va.Unix()),
ValidBefore: cast.Uint64(vb.Unix()),
}
// Get signer from authority keys

View File

@@ -197,7 +197,7 @@ func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest,
if err := a.callEnrichingWebhooksX509(ctx, prov, webhookCtl, attData, csr); err != nil {
return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), //nolint:govet // allow non-constant error messages
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts),
)
@@ -209,7 +209,7 @@ func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest,
switch {
case errors.As(err, &te):
return nil, prov, errs.ApplyOptions(
errs.BadRequestErr(err, err.Error()), //nolint:govet // allow non-constant error messages
errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts),
)

View File

@@ -187,10 +187,7 @@ func Test_fileExists(t *testing.T) {
}
func TestWriteDefaultIdentity(t *testing.T) {
tmpDir, err := os.MkdirTemp(os.TempDir(), "go-tests")
if err != nil {
t.Fatal(err)
}
tmpDir := t.TempDir()
oldConfigDir := configDir
oldIdentityDir := identityDir
@@ -372,10 +369,7 @@ func (r *renewer) Renew(http.RoundTripper) (*api.SignResponse, error) {
}
func TestIdentity_Renew(t *testing.T) {
tmpDir, err := os.MkdirTemp(os.TempDir(), "go-tests")
if err != nil {
t.Fatal(err)
}
tmpDir := t.TempDir()
oldIdentityDir := identityDir
identityDir = returnInput("testdata/identity")

View File

@@ -62,7 +62,7 @@ func init() {
}
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
d := &tls.Dialer{
NetDialer: getDefaultDialer(),
NetDialer: createDefaultDialer(),
Config: &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
@@ -132,8 +132,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
}
tr := getDefaultTransport(tlsConfig)
tr.DialTLS = c.buildDialTLS(tlsCtx)
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context
// Update client transport
@@ -179,8 +178,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
// Update renew function with transport
tr := getDefaultTransport(tlsConfig)
tr.DialTLS = c.buildDialTLS(tlsCtx)
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context
// Update client transport
@@ -212,17 +210,10 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
}
}
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
return func(network, addr string) (net.Conn, error) {
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
}
}
//nolint:unused // buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer()
d := createDefaultDialer()
// TLS dialers do not support context, but we can use the context
// deadline if it is set.
if t, ok := ctx.Deadline(); ok {
@@ -300,8 +291,8 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
}
}
// getDefaultDialer returns a new dialer with the default configuration.
func getDefaultDialer() *net.Dialer {
// createDefaultDialer returns a new dialer with the default configuration.
func createDefaultDialer() *net.Dialer {
// With the KeepAlive parameter set to 0, it will be use Golang's default.
return &net.Dialer{
Timeout: 30 * time.Second,
@@ -325,7 +316,7 @@ func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
// context if it is available, required and expected to work.
dialContext = nil
case mTLSDialContext == nil:
d := getDefaultDialer()
d := createDefaultDialer()
dialContext = d.DialContext
default:
dialContext = mTLSDialContext()

View File

@@ -12,7 +12,10 @@ import (
pb "cloud.google.com/go/security/privateca/apiv1/privatecapb"
"github.com/pkg/errors"
kmsapi "go.step.sm/crypto/kms/apiv1"
"github.com/smallstep/certificates/internal/cast"
)
var (
@@ -250,7 +253,7 @@ func createX509Parameters(cert *x509.Certificate) *pb.X509Parameters {
maxPathLength = 0
caOptions.MaxIssuerPathLength = &maxPathLength
case cert.MaxPathLen > 0:
maxPathLength = int32(cert.MaxPathLen)
maxPathLength = cast.Int32(cert.MaxPathLen)
caOptions.MaxIssuerPathLength = &maxPathLength
}
caOptions.IsCa = &cert.IsCA
@@ -304,7 +307,7 @@ func isExtraExtension(oid asn1.ObjectIdentifier) bool {
func createObjectID(oid asn1.ObjectIdentifier) *pb.ObjectId {
ret := make([]int32, len(oid))
for i, v := range oid {
ret[i] = int32(v)
ret[i] = cast.Int32(v)
}
return &pb.ObjectId{
ObjectIdPath: ret,

View File

@@ -9,10 +9,13 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"golang.org/x/crypto/ssh"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/internal/cast"
)
var (
@@ -465,7 +468,7 @@ func (db *DB) GetSSHHostPrincipals() ([]string, error) {
if err := json.Unmarshal(e.Value, &data); err != nil {
return nil, err
}
if time.Unix(int64(data.Expiry), 0).After(time.Now()) {
if time.Unix(cast.Int64(data.Expiry), 0).After(time.Now()) {
principals = append(principals, string(e.Key))
}
}

1
go.mod
View File

@@ -6,6 +6,7 @@ require (
cloud.google.com/go/longrunning v0.6.4
cloud.google.com/go/security v1.18.3
github.com/Masterminds/sprig/v3 v3.3.0
github.com/ccoveille/go-safecast v1.5.0
github.com/coreos/go-oidc/v3 v3.12.0
github.com/dgraph-io/badger v1.6.2
github.com/dgraph-io/badger/v2 v2.2007.4

2
go.sum
View File

@@ -84,6 +84,8 @@ github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxY
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/ccoveille/go-safecast v1.5.0 h1:cT/3uVQ/i5PTiJvhvkSU81HeKNurtyQtBndXEH3hDg4=
github.com/ccoveille/go-safecast v1.5.0/go.mod h1:QqwNjxQ7DAqY0C721OIO9InMk9zCwcsO7tnRuHytad8=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=

95
internal/cast/cast.go Normal file
View File

@@ -0,0 +1,95 @@
package cast
import (
"github.com/ccoveille/go-safecast"
)
type signed interface {
~int | ~int8 | ~int16 | ~int32 | ~int64
}
type unsigned interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}
type number interface {
signed | unsigned
}
func SafeUint(x int) (uint, error) {
return safecast.ToUint(x)
}
func Uint(x int) uint {
u, err := SafeUint(x)
if err != nil {
panic(err)
}
return u
}
func SafeInt64[T number](x T) (int64, error) {
return safecast.ToInt64(x)
}
func Int64[T number](x T) int64 {
i64, err := SafeInt64(x)
if err != nil {
panic(err)
}
return i64
}
func SafeUint64[T signed](x T) (uint64, error) {
return safecast.ToUint64(x)
}
func Uint64[T signed](x T) uint64 {
u64, err := SafeUint64(x)
if err != nil {
panic(err)
}
return u64
}
func SafeInt32[T signed](x T) (int32, error) {
return safecast.ToInt32(x)
}
func Int32[T signed](x T) int32 {
i32, err := SafeInt32(x)
if err != nil {
panic(err)
}
return i32
}
func SafeUint32(x int) (uint32, error) {
return safecast.ToUint32(x)
}
func Uint32(x int) uint32 {
u32, err := SafeUint32(x)
if err != nil {
panic(err)
}
return u32
}
func SafeUint16(x int) (uint16, error) {
return safecast.ToUint16(x)
}
func Uint16(x int) uint16 {
u16, err := SafeUint16(x)
if err != nil {
panic(err)
}
return u16
}

View File

@@ -0,0 +1,79 @@
package cast
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestUintConvertsValues(t *testing.T) {
require.Equal(t, uint(0), Uint(0))
require.Equal(t, uint(math.MaxInt), Uint(math.MaxInt))
require.Equal(t, uint(42), Uint(42))
}
func TestUintPanicsOnNegativeValue(t *testing.T) {
require.Panics(t, func() { Uint(-1) })
}
func TestInt64ConvertsValues(t *testing.T) {
require.Equal(t, int64(0), Int64(0))
require.Equal(t, int64(math.MaxInt), Int64(math.MaxInt))
require.Equal(t, int64(42), Int64(42))
}
func TestInt64PanicsOnLargeValue(t *testing.T) {
require.Panics(t, func() { Int64(uint64(math.MaxInt + 1)) })
}
func TestUint64ConvertsValues(t *testing.T) {
require.Equal(t, uint64(0), Uint64(0))
require.Equal(t, uint64(math.MaxInt), Uint64((math.MaxInt)))
require.Equal(t, uint64(42), Uint64(42))
}
func TestUint64PanicsOnNegativeValue(t *testing.T) {
require.Panics(t, func() { Uint64(-1) })
}
func TestInt32ConvertsValues(t *testing.T) {
require.Equal(t, int32(0), Int32(0))
require.Equal(t, int32(math.MaxInt32), Int32(math.MaxInt32))
require.Equal(t, int32(42), Int32(42))
}
func TestInt32PanicsOnTooSmallValue(t *testing.T) {
require.Panics(t, func() { Int32(math.MinInt32 - 1) })
}
func TestInt32PanicsOnLargeValue(t *testing.T) {
require.Panics(t, func() { Int32(math.MaxInt32 + 1) })
}
func TestUint32ConvertsValues(t *testing.T) {
require.Equal(t, uint32(0), Uint32(0))
require.Equal(t, uint32(math.MaxUint32), Uint32(math.MaxUint32))
require.Equal(t, uint32(42), Uint32(42))
}
func TestUint32PanicsOnNegativeValue(t *testing.T) {
require.Panics(t, func() { Uint32(-1) })
}
func TestUint32PanicsOnLargeValue(t *testing.T) {
require.Panics(t, func() { Uint32(math.MaxUint32 + 1) })
}
func TestUint16ConvertsValues(t *testing.T) {
require.Equal(t, uint16(0), Uint16(0))
require.Equal(t, uint16(math.MaxUint16), Uint16(math.MaxUint16))
require.Equal(t, uint16(42), Uint16(42))
}
func TestUint16PanicsOnNegativeValue(t *testing.T) {
require.Panics(t, func() { Uint16(-1) })
}
func TestUint16PanicsOnLargeValue(t *testing.T) {
require.Panics(t, func() { Uint16(math.MaxUint32 + 1) })
}

View File

@@ -13,6 +13,7 @@ import (
badgerv1 "github.com/dgraph-io/badger"
badgerv2 "github.com/dgraph-io/badger/v2"
"github.com/smallstep/certificates/internal/cast"
"github.com/smallstep/nosql"
)
@@ -306,9 +307,9 @@ func parseBadgerEncode(bk []byte) (value, rest []byte) {
var (
keyLen uint16
start = uint16(2)
length = uint16(len(bk))
length = cast.Uint16(len(bk))
)
if uint16(len(bk)) < start {
if cast.Uint16(len(bk)) < start {
return nil, bk
}
// First 2 bytes stores the length of the value.

View File

@@ -368,9 +368,7 @@ func TestTemplate_Output(t *testing.T) {
}
func TestOutput_Write(t *testing.T) {
dir, err := os.MkdirTemp("", "test-output-write")
assert.FatalError(t, err)
defer os.RemoveAll(dir)
dir := t.TempDir()
join := func(elem ...string) string {
elems := append([]string{dir}, elem...)

View File

@@ -14,6 +14,8 @@ import (
"math/bits"
"strconv"
"strings"
"github.com/smallstep/certificates/internal/cast"
)
var (
@@ -80,7 +82,7 @@ func base128IntLength(n uint64) int {
func appendBase128Int(dst []byte, n uint64) []byte {
for i := base128IntLength(n) - 1; i >= 0; i-- {
o := byte(n >> uint(i*7))
o := byte(n >> cast.Uint(i*7))
o &= 0x7f
if i != 0 {
o |= 0x80

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//nolint:revive,gocritic,errorlint,unconvert // code copied from crypto/x509
//nolint:revive,gocritic,errorlint,unconvert,staticcheck // code copied from crypto/x509
package legacyx509
import (