diff --git a/ca/provisioner.go b/ca/provisioner.go index 7ac61f1b..e0c50362 100644 --- a/ca/provisioner.go +++ b/ca/provisioner.go @@ -33,14 +33,17 @@ type Provisioner struct { // NewProvisioner loads and decrypts key material from the CA for the named // provisioner. The key identified by `kid` will be used if specified. If `kid` // is the empty string we'll use the first key for the named provisioner that -// decrypts using `passFile`. +// decrypts using `password`. func NewProvisioner(name, kid, caURL, caRoot string, password []byte) (*Provisioner, error) { var jwk *jose.JSONWebKey var err error - if kid != "" { - jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password) - } else { + switch { + case name == "": + return nil, errors.New("provisioner name cannot be empty") + case kid == "": jwk, err = loadProvisionerJWKByName(name, caURL, caRoot, password) + default: + jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password) } if err != nil { return nil, err @@ -113,7 +116,7 @@ func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebK } // loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and -// decrypts it using the specified password file. +// decrypts it using the specified password. func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.JSONWebKey, error) { encrypted, err := getProvisionerKey(caURL, caRoot, kid) if err != nil { @@ -125,7 +128,7 @@ func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose. // loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then // returns the key of the first provisioner with a matching name that can be successfully -// decrypted with the specified password file. +// decrypted with the specified password. func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key *jose.JSONWebKey, err error) { provisioners, err := getProvisioners(caURL, caRoot) if err != nil { @@ -176,8 +179,7 @@ func getProvisioners(caURL, rootFile string) (provisioner.List, error) { } } -// getProvisionerKey returns the encrypted provisioner key with the for the -// given kid. +// getProvisionerKey returns the encrypted provisioner key for the given kid. func getProvisionerKey(caURL, rootFile, kid string) (string, error) { if len(rootFile) == 0 { rootFile = getRootCAPath() diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index cb4da7da..bc8a2b68 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -1,7 +1,6 @@ package ca import ( - "os" "reflect" "testing" "time" @@ -25,16 +24,10 @@ func getTestProvisioner(t *testing.T, url string) *Provisioner { } func TestNewProvisioner(t *testing.T) { - value := os.Getenv("STEPPATH") - defer os.Setenv("STEPPATH", value) - os.Setenv("STEPPATH", "testdata") - ca := startCATestServer() defer ca.Close() - want := getTestProvisioner(t, ca.URL) - wantByKid := getTestProvisioner(t, ca.URL) - wantByKid.name = "" + type args struct { name string kid string @@ -49,12 +42,12 @@ func TestNewProvisioner(t *testing.T) { wantErr bool }{ {"ok", args{want.name, want.kid, want.caURL, want.caRoot, []byte("password")}, want, false}, - {"ok-by-kid", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, wantByKid, false}, {"ok-by-name", args{want.name, "", want.caURL, want.caRoot, []byte("password")}, want, false}, - {"fail-by-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true}, - {"fail-by-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true}, - {"fail-by-password", args{"", want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, - {"fail-by-password", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, + {"fail-bad-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true}, + {"fail-empty-name", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, nil, true}, + {"fail-bad-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true}, + {"fail-by-password", args{want.name, want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, + {"fail-by-password-no-kid", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {