Let the CA determine the RA lifetime

When the RA mode with StepCAS is used, let the CA decide which lifetime
the RA should get instead of requiring always 24h.

This commit also fixes linter warnings.

Related to #1094
This commit is contained in:
Mariano Cano
2024-03-12 14:29:55 -07:00
parent ef1631b00d
commit 10f6a901ec
32 changed files with 179 additions and 38 deletions

View File

@@ -147,10 +147,10 @@ func validateJWS(next nextHTTP) nextHTTP {
sig := jws.Signatures[0]
uh := sig.Unprotected
if len(uh.KeyID) > 0 ||
if uh.KeyID != "" ||
uh.JSONWebKey != nil ||
len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 ||
uh.Algorithm != "" ||
uh.Nonce != "" ||
len(uh.ExtraHeaders) > 0 {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return
@@ -199,7 +199,7 @@ func validateJWS(next nextHTTP) nextHTTP {
return
}
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
if hdr.JSONWebKey != nil && hdr.KeyID != "" {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return
}

View File

@@ -565,7 +565,7 @@ func LogSSHCertificate(w http.ResponseWriter, cert *ssh.Certificate) {
func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
q := r.URL.Query()
cursor = q.Get("cursor")
if v := q.Get("limit"); len(v) > 0 {
if v := q.Get("limit"); v != "" {
limit, err = strconv.Atoi(v)
if err != nil {
return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)

View File

@@ -78,7 +78,7 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
// A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS.
if len(body.OTT) > 0 {
if body.OTT != "" {
logOtt(w, body.OTT)
if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err))

View File

@@ -38,7 +38,7 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) {
auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
if len(id) > 0 {
if id != "" {
if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
@@ -116,7 +116,7 @@ func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name")
auth := mustAuthority(r.Context())
if len(id) > 0 {
if id != "" {
if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return

View File

@@ -857,7 +857,7 @@ func TestDB_CreateAdmin(t *testing.T) {
var _dba = new(dbAdmin)
assert.FatalError(t, json.Unmarshal(nu, _dba))
assert.True(t, len(_dba.ID) > 0 && _dba.ID == string(key))
assert.True(t, _dba.ID != "" && _dba.ID == string(key))
assert.Equals(t, _dba.AuthorityID, adm.AuthorityId)
assert.Equals(t, _dba.ProvisionerID, adm.ProvisionerId)
assert.Equals(t, _dba.Subject, adm.Subject)
@@ -890,7 +890,7 @@ func TestDB_CreateAdmin(t *testing.T) {
var _dba = new(dbAdmin)
assert.FatalError(t, json.Unmarshal(nu, _dba))
assert.True(t, len(_dba.ID) > 0 && _dba.ID == string(key))
assert.True(t, _dba.ID != "" && _dba.ID == string(key))
assert.Equals(t, _dba.AuthorityID, adm.AuthorityId)
assert.Equals(t, _dba.ProvisionerID, adm.ProvisionerId)
assert.Equals(t, _dba.Subject, adm.Subject)

View File

@@ -906,7 +906,7 @@ func TestDB_CreateProvisioner(t *testing.T) {
var _dbp = new(dbProvisioner)
assert.FatalError(t, json.Unmarshal(nu, _dbp))
assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key))
assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
assert.Equals(t, _dbp.Type, prov.Type)
assert.Equals(t, _dbp.Name, prov.Name)
@@ -944,7 +944,7 @@ func TestDB_CreateProvisioner(t *testing.T) {
var _dbp = new(dbProvisioner)
assert.FatalError(t, json.Unmarshal(nu, _dbp))
assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key))
assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
assert.Equals(t, _dbp.Type, prov.Type)
assert.Equals(t, _dbp.Name, prov.Name)
@@ -1093,7 +1093,7 @@ func TestDB_UpdateProvisioner(t *testing.T) {
var _dbp = new(dbProvisioner)
assert.FatalError(t, json.Unmarshal(nu, _dbp))
assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key))
assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
assert.Equals(t, _dbp.Type, prov.Type)
assert.Equals(t, _dbp.Name, prov.Name)
@@ -1188,7 +1188,7 @@ func TestDB_UpdateProvisioner(t *testing.T) {
var _dbp = new(dbProvisioner)
assert.FatalError(t, json.Unmarshal(nu, _dbp))
assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key))
assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
assert.Equals(t, _dbp.Type, prov.Type)
assert.Equals(t, _dbp.Name, prov.Name)

View File

@@ -203,7 +203,7 @@ func matchURIConstraint(uri *url.URL, constraint string) (bool, error) {
// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
for len(domain) > 0 {
for domain != "" {
if i := strings.LastIndexByte(domain, '.'); i == -1 {
reverseLabels = append(reverseLabels, domain)
domain = ""
@@ -316,7 +316,7 @@ func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) {
} else {
// Atom ("." Atom)*
NextChar:
for len(in) > 0 {
for in != "" {
// atext from RFC 2822, Section 3.2.4
c := in[0]

View File

@@ -125,7 +125,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
}
// Try with azp (OIDC)
if len(payload.AuthorizedParty) > 0 {
if payload.AuthorizedParty != "" {
if p, ok := c.LoadByTokenID(payload.AuthorizedParty); ok {
return p, ok
}

View File

@@ -87,7 +87,7 @@ func (p *JWK) GetType() Type {
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
func (p *JWK) GetEncryptedKey() (string, string, bool) {
return p.Key.KeyID, p.EncryptedKey, len(p.EncryptedKey) > 0
return p.Key.KeyID, p.EncryptedKey, p.EncryptedKey != ""
}
// Init initializes and validates the fields of a JWK type.

View File

@@ -105,7 +105,7 @@ func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
func getCacheAge(cacheControl string) time.Duration {
age := defaultCacheAge
if len(cacheControl) > 0 {
if cacheControl != "" {
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
if len(match) > 0 {
if len(match[0]) == 2 {

View File

@@ -304,7 +304,7 @@ func (s *SCEP) Init(config Config) (err error) {
}
}
if decryptionKeyURI := s.DecrypterKeyURI; len(decryptionKeyURI) > 0 {
if decryptionKeyURI := s.DecrypterKeyURI; decryptionKeyURI != "" {
u, err := uri.Parse(s.DecrypterKeyURI)
if err != nil {
return fmt.Errorf("failed parsing decrypter key: %w", err)

View File

@@ -813,7 +813,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
}
tot++
}
if len(tc.claims.Step.SSH.CertType) > 0 {
if tc.claims.Step.SSH.CertType != "" {
assert.Equals(t, tot, 12)
} else {
assert.Equals(t, tot, 10)

View File

@@ -608,19 +608,19 @@ func provisionerWebhookToLinkedca(pwh *provisioner.Webhook) *linkedca.Webhook {
}
func durationsToCertificates(d *linkedca.Durations) (min, max, def *provisioner.Duration, err error) {
if len(d.Min) > 0 {
if d.Min != "" {
min, err = provisioner.NewDuration(d.Min)
if err != nil {
return nil, nil, nil, admin.WrapErrorISE(err, "error parsing minimum duration '%s'", d.Min)
}
}
if len(d.Max) > 0 {
if d.Max != "" {
max, err = provisioner.NewDuration(d.Max)
if err != nil {
return nil, nil, nil, admin.WrapErrorISE(err, "error parsing maximum duration '%s'", d.Max)
}
}
if len(d.Default) > 0 {
if d.Default != "" {
def, err = provisioner.NewDuration(d.Default)
if err != nil {
return nil, nil, nil, admin.WrapErrorISE(err, "error parsing default duration '%s'", d.Default)

View File

@@ -45,7 +45,7 @@ func (a *Authority) GetRoots() ([]*x509.Certificate, error) {
// GetFederation returns all the root certificates in the federation.
// This method implements the Authority interface.
func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) {
a.certificates.Range(func(k, v interface{}) bool {
a.certificates.Range(func(_, v interface{}) bool {
crt, ok := v.(*x509.Certificate)
if !ok {
federation = nil

View File

@@ -59,7 +59,7 @@ var (
)
func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
return func(crt *x509.Certificate, opts provisioner.SignOptions) error {
return func(crt *x509.Certificate, _ provisioner.SignOptions) error {
if def == nil {
return errors.New("default ASN1DN template cannot be nil")
}
@@ -913,10 +913,16 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
return fatal(err)
}
// For StepCAS RA let the lifetime to the provisioner used by the CA.
var lifetime time.Duration
if casapi.TypeOf(a.x509CAService) != casapi.StepCAS {
lifetime = 24 * time.Hour
}
resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
Template: certTpl,
CSR: cr,
Lifetime: 24 * time.Hour,
Lifetime: lifetime,
Backdate: 1 * time.Minute,
IsCAServerCert: true,
})

View File

@@ -204,7 +204,7 @@ func (o *adminOptions) apply(opts []AdminOption) (err error) {
func (o *adminOptions) rawQuery() string {
v := url.Values{}
if len(o.cursor) > 0 {
if o.cursor != "" {
v.Set("cursor", o.cursor)
}
if o.limit > 0 {

View File

@@ -678,7 +678,7 @@ func (ca *CA) shouldServeSCEPEndpoints() bool {
//nolint:unused // useful for debugging
func dumpRoutes(mux chi.Routes) {
// helpful routine for logging all routes
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
walkFunc := func(method string, route string, _ http.Handler, _ ...func(http.Handler) http.Handler) error {
fmt.Printf("%s %s\n", method, route)
return nil
}

View File

@@ -69,7 +69,7 @@ func init() {
GetClientCertificate: id.GetClientCertificateFunc(),
},
}
return func(ctx context.Context, network, address string) (net.Conn, error) {
return func(ctx context.Context, _, _ string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}

View File

@@ -67,6 +67,14 @@ func (t Type) String() string {
return strings.ToLower(string(t))
}
// TypeOf returns the type of the given CertificateAuthorityService.
func TypeOf(c CertificateAuthorityService) Type {
if ct, ok := c.(interface{ Type() Type }); ok {
return ct.Type()
}
return ExternalCAS
}
// NotImplementedError is the type of error returned if an operation is not implemented.
type NotImplementedError struct {
Message string

View File

@@ -4,6 +4,24 @@ import (
"testing"
)
type simpleCAS struct{}
func (*simpleCAS) CreateCertificate(req *CreateCertificateRequest) (*CreateCertificateResponse, error) {
return nil, NotImplementedError{}
}
func (*simpleCAS) RenewCertificate(req *RenewCertificateRequest) (*RenewCertificateResponse, error) {
return nil, NotImplementedError{}
}
func (*simpleCAS) RevokeCertificate(req *RevokeCertificateRequest) (*RevokeCertificateResponse, error) {
return nil, NotImplementedError{}
}
type fakeCAS struct {
simpleCAS
}
func (*fakeCAS) Type() Type { return SoftCAS }
func TestType_String(t *testing.T) {
tests := []struct {
name string
@@ -25,6 +43,27 @@ func TestType_String(t *testing.T) {
}
}
func TestTypeOf(t *testing.T) {
type args struct {
c CertificateAuthorityService
}
tests := []struct {
name string
args args
want Type
}{
{"ok", args{&simpleCAS{}}, ExternalCAS},
{"ok with type", args{&fakeCAS{}}, SoftCAS},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := TypeOf(tt.args.c); got != tt.want {
t.Errorf("TypeOf() = %v, want %v", got, tt.want)
}
})
}
}
func TestNotImplementedError_Error(t *testing.T) {
type fields struct {
Message string

View File

@@ -154,6 +154,11 @@ func New(ctx context.Context, opts apiv1.Options) (*CloudCAS, error) {
}, nil
}
// Type returns the type of this CertificateAuthorityService.
func (c *CloudCAS) Type() apiv1.Type {
return apiv1.CloudCAS
}
// GetCertificateAuthority returns the root certificate for the given
// certificate authority. It implements apiv1.CertificateAuthorityGetter
// interface.

View File

@@ -443,6 +443,23 @@ func TestNew_real(t *testing.T) {
}
}
func TestCloudCAS_Type(t *testing.T) {
tests := []struct {
name string
want apiv1.Type
}{
{"ok", apiv1.CloudCAS},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &CloudCAS{}
if got := c.Type(); got != tt.want {
t.Errorf("CloudCAS.Type() = %v, want %v", got, tt.want)
}
})
}
}
func TestCloudCAS_GetCertificateAuthority(t *testing.T) {
root := mustParseCertificate(t, testRootCertificate)
type fields struct {

View File

@@ -53,6 +53,11 @@ func New(_ context.Context, opts apiv1.Options) (*SoftCAS, error) {
}, nil
}
// Type returns the type of this CertificateAuthorityService.
func (c *SoftCAS) Type() apiv1.Type {
return apiv1.SoftCAS
}
// CreateCertificate signs a new certificate using Golang or KMS crypto.
func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
switch {

View File

@@ -252,6 +252,23 @@ func TestNew_register(t *testing.T) {
}
}
func TestSoftCAS_Type(t *testing.T) {
tests := []struct {
name string
want apiv1.Type
}{
{"ok", apiv1.SoftCAS},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &SoftCAS{}
if got := c.Type(); got != tt.want {
t.Errorf("SoftCAS.Type() = %v, want %v", got, tt.want)
}
})
}
}
func TestSoftCAS_CreateCertificate(t *testing.T) {
mockNow(t)
// Set rand.Reader to EOF

View File

@@ -65,6 +65,11 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) {
}, nil
}
// Type returns the type of this CertificateAuthorityService.
func (s *StepCAS) Type() apiv1.Type {
return apiv1.StepCAS
}
// CreateCertificate uses the step-ca sign request with the configured
// provisioner to get a new certificate from the certificate authority.
func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
@@ -73,8 +78,8 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
return nil, errors.New("createCertificateRequest `csr` cannot be nil")
case req.Template == nil:
return nil, errors.New("createCertificateRequest `template` cannot be nil")
case req.Lifetime == 0:
return nil, errors.New("createCertificateRequest `lifetime` cannot be 0")
case req.Lifetime < 0:
return nil, errors.New("createCertificateRequest `lifetime` cannot less than 0")
}
info := &raInfo{

View File

@@ -624,6 +624,23 @@ func TestNew(t *testing.T) {
}
}
func TestStepCAS_Type(t *testing.T) {
tests := []struct {
name string
want apiv1.Type
}{
{"ok", apiv1.StepCAS},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &StepCAS{}
if got := c.Type(); got != tt.want {
t.Errorf("StepCAS.Type() = %v, want %v", got, tt.want)
}
})
}
}
func TestStepCAS_CreateCertificate(t *testing.T) {
caURL, client := testCAHelper(t)
x5c := testX5CIssuer(t, caURL, "")

View File

@@ -110,6 +110,11 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
}, nil
}
// Type returns the type of this CertificateAuthorityService.
func (v *VaultCAS) Type() apiv1.Type {
return apiv1.VaultCAS
}
// CreateCertificate signs a new certificate using Hashicorp Vault.
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) {
switch {

View File

@@ -193,6 +193,23 @@ func TestNew_register(t *testing.T) {
}
}
func TestVaultCAS_Type(t *testing.T) {
tests := []struct {
name string
want apiv1.Type
}{
{"ok", apiv1.VaultCAS},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &VaultCAS{}
if got := c.Type(); got != tt.want {
t.Errorf("VaultCAS.Type() = %v, want %v", got, tt.want)
}
})
}
}
func TestVaultCAS_CreateCertificate(t *testing.T) {
_, client := testCAHelper(t)

View File

@@ -239,7 +239,7 @@ To get a linked authority token:
// replace resolver if requested
if resolver != "" {
net.DefaultResolver.PreferGo = true
net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
net.DefaultResolver.Dial = func(_ context.Context, network, _ string) (net.Conn, error) {
return net.Dial(network, resolver)
}
}

View File

@@ -116,7 +116,7 @@ func New(c *Config) (AuthDB, error) {
opts := []nosql.Option{nosql.WithDatabase(c.Database),
nosql.WithValueDir(c.ValueDir)}
if len(c.BadgerFileLoadingMode) > 0 {
if c.BadgerFileLoadingMode != "" {
opts = append(opts, nosql.WithBadgerFileLoadingMode(c.BadgerFileLoadingMode))
}

View File

@@ -80,7 +80,7 @@ func (e *Error) StatusCode() int {
// Message returns a user friendly error, if one is set.
func (e *Error) Message() string {
if len(e.Msg) > 0 {
if e.Msg != "" {
return e.Msg
}
return e.Err.Error()
@@ -123,7 +123,7 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *Error) MarshalJSON() ([]byte, error) {
var msg string
if len(e.Msg) > 0 {
if e.Msg != "" {
msg = e.Msg
} else {
msg = http.StatusText(e.Status)

View File

@@ -288,7 +288,7 @@ func checkNameConstraints(
// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
for len(domain) > 0 {
for domain != "" {
if i := strings.LastIndexByte(domain, '.'); i == -1 {
reverseLabels = append(reverseLabels, domain)
domain = ""
@@ -401,7 +401,7 @@ func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) {
} else {
// Atom ("." Atom)*
NextChar:
for len(in) > 0 {
for in != "" {
// atext from RFC 2822, Section 3.2.4
c := in[0]