diff --git a/.goreleaser.yml b/.goreleaser.yml index 97b9c24c..60eff7e7 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -12,25 +12,16 @@ builds: id: step-ca env: - CGO_ENABLED=0 - goos: - - linux - - darwin - - windows - goarch: - - amd64 - - arm - - arm64 - - 386 - goarm: - - 6 - - 7 - ignore: - - goos: windows - goarch: 386 - - goos: windows - goarm: 6 - - goos: windows - goarm: 7 + targets: + - darwin_amd64 + - darwin_arm64 + - freebsd_amd64 + - linux_386 + - linux_amd64 + - linux_arm64 + - linux_arm_6 + - linux_arm_7 + - windows_amd64 flags: - -trimpath main: ./cmd/step-ca/main.go diff --git a/CHANGELOG.md b/CHANGELOG.md index c5319526..65ddbc15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.17.5] - DATE ### Added +- Support for Azure Key Vault as a KMS. +- Adapt `pki` package to support key managers. - gocritic linter ### Changed ### Deprecated @@ -15,13 +17,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Security ## [0.17.4] - 2021-09-28 -### Added -### Changed -### Deprecated -### Removed ### Fixed - Support host-only or user-only SSH CA. -### Security ## [0.17.3] - 2021-09-24 ### Added @@ -57,4 +54,3 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Update TLS cipher suites to include 1.3 ### Security - Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512. - diff --git a/ca/tls.go b/ca/tls.go index 3a3b6766..0738d0e0 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -279,9 +279,9 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { // getDefaultDialer returns a new dialer with the default configuration. func getDefaultDialer() *net.Dialer { + // With the KeepAlive parameter set to 0, it will be use Golang's default. return &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: 30 * time.Second, } } diff --git a/cas/apiv1/requests.go b/cas/apiv1/requests.go index b47a9c13..bf745c17 100644 --- a/cas/apiv1/requests.go +++ b/cas/apiv1/requests.go @@ -108,6 +108,9 @@ type GetCertificateAuthorityResponse struct { RootCertificate *x509.Certificate } +// CreateKeyRequest is the request used to generate a new key using a KMS. +type CreateKeyRequest = apiv1.CreateKeyRequest + // CreateCertificateAuthorityRequest is the request used to generate a root or // intermediate certificate. type CreateCertificateAuthorityRequest struct { @@ -126,7 +129,7 @@ type CreateCertificateAuthorityRequest struct { // CreateKey defines the KMS CreateKeyRequest to use when creating a new // CertificateAuthority. If CreateKey is nil, a default algorithm will be // used. - CreateKey *apiv1.CreateKeyRequest + CreateKey *CreateKeyRequest } // CreateCertificateAuthorityResponse is the response for @@ -136,6 +139,7 @@ type CreateCertificateAuthorityResponse struct { Name string Certificate *x509.Certificate CertificateChain []*x509.Certificate + KeyName string PublicKey crypto.PublicKey PrivateKey crypto.PrivateKey Signer crypto.Signer diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 87dfa5c5..8e67d016 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -172,6 +172,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori Name: cert.Subject.CommonName, Certificate: cert, CertificateChain: chain, + KeyName: key.Name, PublicKey: key.PublicKey, PrivateKey: key.PrivateKey, Signer: signer, diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index bd13f310..7d3add4f 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -106,6 +106,7 @@ func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.Create signer = m.signer } return &kmsapi.CreateKeyResponse{ + Name: req.Name, PrivateKey: signer, PublicKey: signer.Public(), }, m.errCreateKey @@ -516,6 +517,22 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { PrivateKey: saSigner, Signer: saSigner, }, false}, + {"ok createKey", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testRootTemplate, + Lifetime: 24 * time.Hour, + CreateKey: &kmsapi.CreateKeyRequest{ + Name: "root_ca.crt", + SignatureAlgorithm: kmsapi.ECDSAWithSHA256, + }, + }}, &apiv1.CreateCertificateAuthorityResponse{ + Name: "Test Root CA", + Certificate: testSignedRootTemplate, + PublicKey: testSignedRootTemplate.PublicKey, + KeyName: "root_ca.crt", + PrivateKey: testSigner, + Signer: testSigner, + }, false}, {"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Lifetime: 24 * time.Hour, diff --git a/cmd/step-ca/main.go b/cmd/step-ca/main.go index bed1c14a..01d800d8 100644 --- a/cmd/step-ca/main.go +++ b/cmd/step-ca/main.go @@ -27,6 +27,7 @@ import ( // Enabled kms interfaces. _ "github.com/smallstep/certificates/kms/awskms" + _ "github.com/smallstep/certificates/kms/azurekms" _ "github.com/smallstep/certificates/kms/cloudkms" _ "github.com/smallstep/certificates/kms/softkms" _ "github.com/smallstep/certificates/kms/sshagentkms" diff --git a/examples/basic-client/client.go b/examples/basic-client/client.go index db6092bf..42358ac8 100644 --- a/examples/basic-client/client.go +++ b/examples/basic-client/client.go @@ -116,7 +116,6 @@ func main() { Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, DualStack: true, }).DialContext, MaxIdleConns: 100, diff --git a/go.mod b/go.mod index 35043e7e..b2edf0ca 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,15 @@ module github.com/smallstep/certificates -go 1.14 +go 1.15 require ( cloud.google.com/go v0.83.0 + github.com/Azure/azure-sdk-for-go v58.0.0+incompatible + github.com/Azure/go-autorest/autorest v0.11.17 + github.com/Azure/go-autorest/autorest/azure/auth v0.5.8 + github.com/Azure/go-autorest/autorest/date v0.3.0 + github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect + github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect github.com/Masterminds/sprig/v3 v3.2.2 github.com/ThalesIgnite/crypto11 v1.2.4 github.com/aws/aws-sdk-go v1.30.29 @@ -11,7 +17,7 @@ require ( github.com/go-chi/chi v4.0.2+incompatible github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 - github.com/golang/mock v1.5.0 + github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.0.5 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect diff --git a/go.sum b/go.sum index 4b5e2929..54a824a8 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,31 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= +github.com/Azure/azure-sdk-for-go v58.0.0+incompatible h1:Cw16jiP4dI+CK761aq44ol4RV5dUiIIXky1+EKpoiVM= +github.com/Azure/azure-sdk-for-go v58.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= +github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= +github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= +github.com/Azure/go-autorest/autorest v0.11.17 h1:2zCdHwNgRH+St1J+ZMf66xI8aLr/5KMy+wWLH97zwYM= +github.com/Azure/go-autorest/autorest v0.11.17/go.mod h1:eipySxLmqSyC5s5k1CLupqet0PSENBEDP93LQ9a8QYw= +github.com/Azure/go-autorest/autorest/adal v0.9.5/go.mod h1:B7KF7jKIeC9Mct5spmyCB/A8CG/sEz1vwIRGv/bbw7A= +github.com/Azure/go-autorest/autorest/adal v0.9.11 h1:L4/pmq7poLdsy41Bj1FayKvBhayuWRYkx9HU5i4Ybl0= +github.com/Azure/go-autorest/autorest/adal v0.9.11/go.mod h1:nBKAnTomx8gDtl+3ZCJv2v0KACFHWTB2drffI1B68Pk= +github.com/Azure/go-autorest/autorest/azure/auth v0.5.8 h1:TzPg6B6fTZ0G1zBf3T54aI7p3cAT6u//TOXGPmFMOXg= +github.com/Azure/go-autorest/autorest/azure/auth v0.5.8/go.mod h1:kxyKZTSfKh8OVFWPAgOgQ/frrJgeYQJPyR5fLFmXko4= +github.com/Azure/go-autorest/autorest/azure/cli v0.4.2 h1:dMOmEJfkLKW/7JsokJqkyoYSgmR08hi9KrhjZb+JALY= +github.com/Azure/go-autorest/autorest/azure/cli v0.4.2/go.mod h1:7qkJkT+j6b+hIpzMOwPChJhTqS8VbsqqgULzMNRugoM= +github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= +github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= +github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk= +github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= +github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk= +github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= +github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac= +github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E= +github.com/Azure/go-autorest/logger v0.2.0 h1:e4RVHVZKC5p6UANLJHkM4OfR1UKZPj8Wt8Pcx+3oqrE= +github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= +github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= +github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= @@ -146,6 +171,9 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8= +github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U= +github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -163,6 +191,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -206,8 +236,9 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -377,6 +408,7 @@ github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFW github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= @@ -539,8 +571,6 @@ go.step.sm/cli-utils v0.6.0/go.mod h1:jklBMavFl2PbmGlyxgax08ZnB0uWpadjuOlSKKXz+0 go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.11.0 h1:VDpeVgEmqme/FK2w5QINxkOQ1FWOm/Wi2TwQXiacKr8= go.step.sm/crypto v0.11.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= -go.step.sm/linkedca v0.5.0 h1:oZVRSpElM7lAL1XN2YkjdHwI/oIZ+1ULOnuqYPM6xjY= -go.step.sm/linkedca v0.5.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.step.sm/linkedca v0.8.0 h1:86DAufqUtUvFTJgYpgG0McKkpqnjXxg53FTXYyhs0HI= go.step.sm/linkedca v0.8.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -561,6 +591,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200414173820-0848c9571904/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 h1:3erb+vDS8lU1sxfDHF4/hhWyaXnhIaO+7RgL4fDZORA= golang.org/x/crypto v0.0.0-20210915214749-c084706c2272/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index 7cc7f748..79b07a60 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -29,6 +29,12 @@ type CertificateManager interface { StoreCertificate(req *StoreCertificateRequest) error } +// ValidateName is an interface that KeyManager can implement to validate a +// given name or URI. +type NameValidator interface { + ValidateName(s string) error +} + // ErrNotImplemented is the type of error returned if an operation is not // implemented. type ErrNotImplemented struct { @@ -73,6 +79,8 @@ const ( YubiKey Type = "yubikey" // SSHAgentKMS is a KMS implementation using ssh-agent to access keys. SSHAgentKMS Type = "sshagentkms" + // AzureKMS is a KMS implementation using Azure Key Vault. + AzureKMS Type = "azurekms" ) // Options are the KMS options. They represent the kms object in the ca.json. @@ -81,18 +89,18 @@ type Options struct { Type string `json:"type"` // Path to the credentials file used in CloudKMS and AmazonKMS. - CredentialsFile string `json:"credentialsFile"` + CredentialsFile string `json:"credentialsFile,omitempty"` // URI is based on the PKCS #11 URI Scheme defined in // https://tools.ietf.org/html/rfc7512 and represents the configuration used // to connect to the KMS. // // Used by: pkcs11 - URI string `json:"uri"` + URI string `json:"uri,omitempty"` // Pin used to access the PKCS11 module. It can be defined in the URI using // the pin-value or pin-source properties. - Pin string `json:"pin"` + Pin string `json:"pin,omitempty"` // ManagementKey used in YubiKeys. Default management key is the hexadecimal // string 010203040506070801020304050607080102030405060708: @@ -101,13 +109,13 @@ type Options struct { // 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // } - ManagementKey string `json:"managementKey"` + ManagementKey string `json:"managementKey,omitempty"` // Region to use in AmazonKMS. - Region string `json:"region"` + Region string `json:"region,omitempty"` // Profile to use in AmazonKMS. - Profile string `json:"profile"` + Profile string `json:"profile,omitempty"` } // Validate checks the fields in Options. @@ -118,8 +126,9 @@ func (o *Options) Validate() error { switch Type(strings.ToLower(o.Type)) { case DefaultKMS, SoftKMS: // Go crypto based kms. - case CloudKMS, AmazonKMS, SSHAgentKMS: // Cloud based kms. + case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms. case YubiKey, PKCS11: // Hardware based kms. + case SSHAgentKMS: // Others default: return errors.Errorf("unsupported kms type %s", o.Type) } diff --git a/kms/azurekms/internal/mock/key_vault_client.go b/kms/azurekms/internal/mock/key_vault_client.go new file mode 100644 index 00000000..42bd55fd --- /dev/null +++ b/kms/azurekms/internal/mock/key_vault_client.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/smallstep/certificates/kms/azurekms (interfaces: KeyVaultClient) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// KeyVaultClient is a mock of KeyVaultClient interface +type KeyVaultClient struct { + ctrl *gomock.Controller + recorder *KeyVaultClientMockRecorder +} + +// KeyVaultClientMockRecorder is the mock recorder for KeyVaultClient +type KeyVaultClientMockRecorder struct { + mock *KeyVaultClient +} + +// NewKeyVaultClient creates a new mock instance +func NewKeyVaultClient(ctrl *gomock.Controller) *KeyVaultClient { + mock := &KeyVaultClient{ctrl: ctrl} + mock.recorder = &KeyVaultClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *KeyVaultClient) EXPECT() *KeyVaultClientMockRecorder { + return m.recorder +} + +// CreateKey mocks base method +func (m *KeyVaultClient) CreateKey(arg0 context.Context, arg1, arg2 string, arg3 keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateKey", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(keyvault.KeyBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateKey indicates an expected call of CreateKey +func (mr *KeyVaultClientMockRecorder) CreateKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*KeyVaultClient)(nil).CreateKey), arg0, arg1, arg2, arg3) +} + +// GetKey mocks base method +func (m *KeyVaultClient) GetKey(arg0 context.Context, arg1, arg2, arg3 string) (keyvault.KeyBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKey", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(keyvault.KeyBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKey indicates an expected call of GetKey +func (mr *KeyVaultClientMockRecorder) GetKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*KeyVaultClient)(nil).GetKey), arg0, arg1, arg2, arg3) +} + +// Sign mocks base method +func (m *KeyVaultClient) Sign(arg0 context.Context, arg1, arg2, arg3 string, arg4 keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(keyvault.KeyOperationResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sign indicates an expected call of Sign +func (mr *KeyVaultClientMockRecorder) Sign(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*KeyVaultClient)(nil).Sign), arg0, arg1, arg2, arg3, arg4) +} diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go new file mode 100644 index 00000000..34d9c3f1 --- /dev/null +++ b/kms/azurekms/key_vault.go @@ -0,0 +1,342 @@ +package azurekms + +import ( + "context" + "crypto" + "regexp" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/Azure/go-autorest/autorest/date" + "github.com/pkg/errors" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/uri" +) + +func init() { + apiv1.Register(apiv1.AzureKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { + return New(ctx, opts) + }) +} + +// Scheme is the scheme used for the Azure Key Vault uris. +const Scheme = "azurekms" + +// keyIDRegexp is the regular expression that Key Vault uses on the kid. We can +// extract the vault, name and version of the key. +var keyIDRegexp = regexp.MustCompile(`^https://([0-9a-zA-Z-]+)\.vault\.azure\.net/keys/([0-9a-zA-Z-]+)/([0-9a-zA-Z-]+)$`) + +var ( + valueTrue = true + value2048 int32 = 2048 + value3072 int32 = 3072 + value4096 int32 = 4096 +) + +var now = func() time.Time { + return time.Now().UTC() +} + +type keyType struct { + Kty keyvault.JSONWebKeyType + Curve keyvault.JSONWebKeyCurveName +} + +func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType { + switch k.Kty { + case keyvault.EC: + if pl == apiv1.HSM { + return keyvault.ECHSM + } + return k.Kty + case keyvault.RSA: + if pl == apiv1.HSM { + return keyvault.RSAHSM + } + return k.Kty + default: + return "" + } +} + +var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]keyType{ + apiv1.UnspecifiedSignAlgorithm: { + Kty: keyvault.EC, + Curve: keyvault.P256, + }, + apiv1.SHA256WithRSA: { + Kty: keyvault.RSA, + }, + apiv1.SHA384WithRSA: { + Kty: keyvault.RSA, + }, + apiv1.SHA512WithRSA: { + Kty: keyvault.RSA, + }, + apiv1.SHA256WithRSAPSS: { + Kty: keyvault.RSA, + }, + apiv1.SHA384WithRSAPSS: { + Kty: keyvault.RSA, + }, + apiv1.SHA512WithRSAPSS: { + Kty: keyvault.RSA, + }, + apiv1.ECDSAWithSHA256: { + Kty: keyvault.EC, + Curve: keyvault.P256, + }, + apiv1.ECDSAWithSHA384: { + Kty: keyvault.EC, + Curve: keyvault.P384, + }, + apiv1.ECDSAWithSHA512: { + Kty: keyvault.EC, + Curve: keyvault.P521, + }, +} + +// vaultResource is the value the client will use as audience. +const vaultResource = "https://vault.azure.net" + +// KeyVaultClient is the interface implemented by keyvault.BaseClient. It will +// be used for testing purposes. +type KeyVaultClient interface { + GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (keyvault.KeyBundle, error) + CreateKey(ctx context.Context, vaultBaseURL string, keyName string, parameters keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) + Sign(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string, parameters keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) +} + +// KeyVault implements a KMS using Azure Key Vault. +// +// The URI format used in Azure Key Vault is the following: +// +// - azurekms:name=key-name;vault=vault-name +// - azurekms:name=key-name;vault=vault-name?version=key-version +// - azurekms:name=key-name;vault=vault-name?hsm=true +// +// The scheme is "azurekms"; "name" is the key name; "vault" is the key vault +// name where the key is located; "version" is an optional parameter that +// defines the version of they key, if version is not given, the latest one will +// be used; "hsm" defines if an HSM want to be used for this key, this is +// specially useful when this is used from `step`. +// +// TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault +// package, at some point Azure might create a keyvault client with all the +// functionality in /sdk/keyvault, we should migrate to that once available. +type KeyVault struct { + baseClient KeyVaultClient + defaults DefaultOptions +} + +// DefaultOptions are custom options that can be passed as defaults using the +// URI in apiv1.Options. +type DefaultOptions struct { + Vault string + ProtectionLevel apiv1.ProtectionLevel +} + +var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + baseClient := keyvault.New() + + // With an URI, try to log in only using client credentials in the URI. + // Client credentials requires: + // - client-id + // - client-secret + // - tenant-id + // And optionally the aad-endpoint to support custom clouds: + // - aad-endpoint (defaults to https://login.microsoftonline.com/) + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, err + } + + // Required options + clientID := u.Get("client-id") + clientSecret := u.Get("client-secret") + tenantID := u.Get("tenant-id") + // optional + aadEndpoint := u.Get("aad-endpoint") + + if clientID != "" && clientSecret != "" && tenantID != "" { + s := auth.EnvironmentSettings{ + Values: map[string]string{ + auth.ClientID: clientID, + auth.ClientSecret: clientSecret, + auth.TenantID: tenantID, + auth.Resource: vaultResource, + }, + Environment: azure.PublicCloud, + } + if aadEndpoint != "" { + s.Environment.ActiveDirectoryEndpoint = aadEndpoint + } + baseClient.Authorizer, err = s.GetAuthorizer() + if err != nil { + return nil, err + } + return baseClient, nil + } + } + + // Attempt to authorize with the following methods: + // 1. Environment variables. + // - Client credentials + // - Client certificate + // - Username and password + // - MSI + // 2. Using Azure CLI 2.0 on local development. + authorizer, err := auth.NewAuthorizerFromEnvironmentWithResource(vaultResource) + if err != nil { + authorizer, err = auth.NewAuthorizerFromCLIWithResource(vaultResource) + if err != nil { + return nil, errors.Wrap(err, "error getting authorizer for key vault") + } + } + baseClient.Authorizer = authorizer + return &baseClient, nil +} + +// New initializes a new KMS implemented using Azure Key Vault. +func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) { + baseClient, err := createClient(ctx, opts) + if err != nil { + return nil, err + } + + // step and step-ca do not need and URI, but having a default vault and + // protection level is useful if this package is used as an api + var defaults DefaultOptions + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, err + } + defaults.Vault = u.Get("vault") + if u.GetBool("hsm") { + defaults.ProtectionLevel = apiv1.HSM + } + } + + return &KeyVault{ + baseClient: baseClient, + defaults: defaults, + }, nil +} + +// GetPublicKey loads a public key from Azure Key Vault by its resource name. +func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + if req.Name == "" { + return nil, errors.New("getPublicKeyRequest 'name' cannot be empty") + } + + vault, name, version, _, err := parseKeyName(req.Name, k.defaults) + if err != nil { + return nil, err + } + + ctx, cancel := defaultContext() + defer cancel() + + resp, err := k.baseClient.GetKey(ctx, vaultBaseURL(vault), name, version) + if err != nil { + return nil, errors.Wrap(err, "keyVault GetKey failed") + } + + return convertKey(resp.Key) +} + +// CreateKey creates a asymmetric key in Azure Key Vault. +func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + if req.Name == "" { + return nil, errors.New("createKeyRequest 'name' cannot be empty") + } + + vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults) + if err != nil { + return nil, err + } + + // Override protection level to HSM only if it's not specified, and is given + // in the uri. + protectionLevel := req.ProtectionLevel + if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm { + protectionLevel = apiv1.HSM + } + + kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm] + if !ok { + return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm) + } + var keySize *int32 + if kt.Kty == keyvault.RSA || kt.Kty == keyvault.RSAHSM { + switch req.Bits { + case 2048: + keySize = &value2048 + case 0, 3072: + keySize = &value3072 + case 4096: + keySize = &value4096 + default: + return nil, errors.Errorf("keyVault does not support key size %d", req.Bits) + } + } + + created := date.UnixTime(now()) + + ctx, cancel := defaultContext() + defer cancel() + + resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{ + Kty: kt.KeyType(protectionLevel), + KeySize: keySize, + Curve: kt.Curve, + KeyOps: &[]keyvault.JSONWebKeyOperation{ + keyvault.Sign, keyvault.Verify, + }, + KeyAttributes: &keyvault.KeyAttributes{ + Enabled: &valueTrue, + Created: &created, + NotBefore: &created, + }, + }) + if err != nil { + return nil, errors.Wrap(err, "keyVault CreateKey failed") + } + + publicKey, err := convertKey(resp.Key) + if err != nil { + return nil, err + } + + keyURI := getKeyName(vault, name, resp) + return &apiv1.CreateKeyResponse{ + Name: keyURI, + PublicKey: publicKey, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: keyURI, + }, + }, nil +} + +// CreateSigner returns a crypto.Signer from a previously created asymmetric key. +func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + if req.SigningKey == "" { + return nil, errors.New("createSignerRequest 'signingKey' cannot be empty") + } + return NewSigner(k.baseClient, req.SigningKey, k.defaults) +} + +// Close closes the client connection to the Azure Key Vault. This is a noop. +func (k *KeyVault) Close() error { + return nil +} + +// ValidateName validates that the given string is a valid URI. +func (k *KeyVault) ValidateName(s string) error { + _, _, _, _, err := parseKeyName(s, k.defaults) + return err +} diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go new file mode 100644 index 00000000..8f968189 --- /dev/null +++ b/kms/azurekms/key_vault_test.go @@ -0,0 +1,653 @@ +//go:generate mockgen -package mock -mock_names=KeyVaultClient=KeyVaultClient -destination internal/mock/key_vault_client.go github.com/smallstep/certificates/kms/azurekms KeyVaultClient +package azurekms + +import ( + "context" + "crypto" + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/date" + "github.com/golang/mock/gomock" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/azurekms/internal/mock" + "go.step.sm/crypto/keyutil" + "gopkg.in/square/go-jose.v2" +) + +var errTest = fmt.Errorf("test error") + +func mockNow(t *testing.T) time.Time { + old := now + t0 := time.Unix(1234567890, 123).UTC() + now = func() time.Time { + return t0 + } + t.Cleanup(func() { + now = old + }) + return t0 +} + +func mockClient(t *testing.T) *mock.KeyVaultClient { + t.Helper() + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + return mock.NewKeyVaultClient(ctrl) +} + +func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey { + t.Helper() + b, err := json.Marshal(&jose.JSONWebKey{ + Key: pub, + }) + if err != nil { + t.Fatal(err) + } + key := new(keyvault.JSONWebKey) + if err := json.Unmarshal(b, key); err != nil { + t.Fatal(err) + } + return key +} + +func Test_now(t *testing.T) { + t0 := now() + if loc := t0.Location(); loc != time.UTC { + t.Errorf("now() Location = %v, want %v", loc, time.UTC) + } +} + +func TestNew(t *testing.T) { + client := mockClient(t) + old := createClient + t.Cleanup(func() { + createClient = old + }) + + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + setup func() + args args + want *KeyVault + wantErr bool + }{ + {"ok", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{}}, &KeyVault{ + baseClient: client, + }, false}, + {"ok with vault", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "azurekms:vault=my-vault", + }}, &KeyVault{ + baseClient: client, + defaults: DefaultOptions{ + Vault: "my-vault", + ProtectionLevel: apiv1.UnspecifiedProtectionLevel, + }, + }, false}, + {"ok with vault + hsm", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "azurekms:vault=my-vault;hsm=true", + }}, &KeyVault{ + baseClient: client, + defaults: DefaultOptions{ + Vault: "my-vault", + ProtectionLevel: apiv1.HSM, + }, + }, false}, + {"fail", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return nil, errTest + } + }, args{context.Background(), apiv1.Options{}}, nil, true}, + {"fail uri", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "kms:vault=my-vault;hsm=true", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup() + got, err := New(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_createClient(t *testing.T) { + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + args args + skip bool + wantErr bool + }{ + {"ok", args{context.Background(), apiv1.Options{}}, true, false}, + {"ok with uri", args{context.Background(), apiv1.Options{ + URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id", + }}, false, false}, + {"ok with uri+aad", args{context.Background(), apiv1.Options{ + URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F", + }}, false, false}, + {"ok with uri no config", args{context.Background(), apiv1.Options{ + URI: "azurekms:", + }}, true, false}, + {"fail uri", args{context.Background(), apiv1.Options{ + URI: "kms:client-id=id;client-secret=secret;tenant-id=id", + }}, false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.SkipNow() + } + _, err := createClient(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestKeyVault_GetPublicKey(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + jwk := createJWK(t, pub) + + client := mockClient(t) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + + type fields struct { + baseClient KeyVaultClient + } + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.PublicKey + wantErr bool + }{ + {"ok", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + }}, pub, false}, + {"ok with version", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key?version=my-version", + }}, pub, false}, + {"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found?version=my-version", + }}, nil, true}, + {"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "", + }}, nil, true}, + {"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=;name=not-found?version=my-version", + }}, nil, true}, + {"fail id", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=;name=?version=my-version", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + got, err := k.GetPublicKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KeyVault.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyVault.GetPublicKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_CreateKey(t *testing.T) { + ecKey, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + rsaKey, err := keyutil.GenerateSigner("RSA", "", 2048) + if err != nil { + t.Fatal(err) + } + ecPub := ecKey.Public() + rsaPub := rsaKey.Public() + ecJWK := createJWK(t, ecPub) + rsaJWK := createJWK(t, rsaPub) + + t0 := date.UnixTime(mockNow(t)) + client := mockClient(t) + + expects := []struct { + Name string + Kty keyvault.JSONWebKeyType + KeySize *int32 + Curve keyvault.JSONWebKeyCurveName + Key *keyvault.JSONWebKey + }{ + {"P-256", keyvault.EC, nil, keyvault.P256, ecJWK}, + {"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK}, + {"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK}, + {"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK}, + {"P-384", keyvault.EC, nil, keyvault.P384, ecJWK}, + {"P-521", keyvault.EC, nil, keyvault.P521, ecJWK}, + {"RSA 0", keyvault.RSA, &value3072, "", rsaJWK}, + {"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK}, + {"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK}, + {"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK}, + {"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK}, + {"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK}, + } + + for _, e := range expects { + client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", keyvault.KeyCreateParameters{ + Kty: e.Kty, + KeySize: e.KeySize, + Curve: e.Curve, + KeyOps: &[]keyvault.JSONWebKeyOperation{ + keyvault.Sign, keyvault.Verify, + }, + KeyAttributes: &keyvault.KeyAttributes{ + Enabled: &valueTrue, + Created: &t0, + NotBefore: &t0, + }, + }).Return(keyvault.KeyBundle{ + Key: e.Key, + }, nil) + } + client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{}, errTest) + client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{ + Key: nil, + }, nil) + + type fields struct { + baseClient KeyVaultClient + } + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + fields fields + args args + want *apiv1.CreateKeyResponse + wantErr bool + }{ + {"ok P-256", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + ProtectionLevel: apiv1.Software, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-256 HSM", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + ProtectionLevel: apiv1.HSM, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key?hsm=true", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-384", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA384, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-521", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA512, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 0", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSA, + ProtectionLevel: apiv1.Software, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 0 HSM", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSAPSS, + ProtectionLevel: apiv1.HSM, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key;hsm=true", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSAPSS, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 2048, + SignatureAlgorithm: apiv1.SHA384WithRSA, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 3072", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 3072, + SignatureAlgorithm: apiv1.SHA512WithRSA, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 4096", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 4096, + SignatureAlgorithm: apiv1.SHA512WithRSAPSS, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"fail createKey", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail convertKey", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail name", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "", + }}, nil, true}, + {"fail vault", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=;name=not-found?version=my-version", + }}, nil, true}, + {"fail id", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=?version=my-version", + }}, nil, true}, + {"fail SignatureAlgorithm", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.PureEd25519, + }}, nil, true}, + {"fail bit size", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.SHA384WithRSAPSS, + Bits: 1024, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + got, err := k.CreateKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KeyVault.CreateKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyVault.CreateKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_CreateSigner(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + jwk := createJWK(t, pub) + + client := mockClient(t) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + + type fields struct { + baseClient KeyVaultClient + } + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.Signer + wantErr bool + }{ + {"ok", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "azurekms:vault=my-vault;name=my-key", + }}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "", + publicKey: pub, + }, false}, + {"ok with version", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "azurekms:vault=my-vault;name=my-key;version=my-version", + }}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"fail GetKey", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "azurekms:vault=my-vault;name=not-found;version=my-version", + }}, nil, true}, + {"fail SigningKey", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + got, err := k.CreateSigner(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KeyVault.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyVault.CreateSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_Close(t *testing.T) { + client := mockClient(t) + type fields struct { + baseClient KeyVaultClient + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{client}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + if err := k.Close(); (err != nil) != tt.wantErr { + t.Errorf("KeyVault.Close() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_keyType_KeyType(t *testing.T) { + type fields struct { + Kty keyvault.JSONWebKeyType + Curve keyvault.JSONWebKeyCurveName + } + type args struct { + pl apiv1.ProtectionLevel + } + tests := []struct { + name string + fields fields + args args + want keyvault.JSONWebKeyType + }{ + {"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC}, + {"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC}, + {"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM}, + {"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA}, + {"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA}, + {"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM}, + {"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := keyType{ + Kty: tt.fields.Kty, + Curve: tt.fields.Curve, + } + if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) { + t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_ValidateName(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{"azurekms:name=my-key;vault=my-vault"}, false}, + {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, false}, + {"fail scheme", args{"azure:name=my-key;vault=my-vault"}, true}, + {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, true}, + {"fail no name", args{"azurekms:vault=my-vault"}, true}, + {"fail no vault", args{"azurekms:name=my-key"}, true}, + {"fail empty", args{""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{} + if err := k.ValidateName(tt.args.s); (err != nil) != tt.wantErr { + t.Errorf("KeyVault.ValidateName() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/kms/azurekms/signer.go b/kms/azurekms/signer.go new file mode 100644 index 00000000..e3aca5fe --- /dev/null +++ b/kms/azurekms/signer.go @@ -0,0 +1,160 @@ +package azurekms + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "encoding/base64" + "io" + "math/big" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/pkg/errors" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// Signer implements a crypto.Signer using the AWS KMS. +type Signer struct { + client KeyVaultClient + vaultBaseURL string + name string + version string + publicKey crypto.PublicKey +} + +// NewSigner creates a new signer using a key in the AWS KMS. +func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) { + vault, name, version, _, err := parseKeyName(signingKey, defaults) + if err != nil { + return nil, err + } + + // Make sure that the key exists. + signer := &Signer{ + client: client, + vaultBaseURL: vaultBaseURL(vault), + name: name, + version: version, + } + if err := signer.preloadKey(); err != nil { + return nil, err + } + + return signer, nil +} + +func (s *Signer) preloadKey() error { + ctx, cancel := defaultContext() + defer cancel() + + resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version) + if err != nil { + return errors.Wrap(err, "keyVault GetKey failed") + } + + s.publicKey, err = convertKey(resp.Key) + return err +} + +// Public returns the public key of this signer or an error. +func (s *Signer) Public() crypto.PublicKey { + return s.publicKey +} + +// Sign signs digest with the private key stored in the AWS KMS. +func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + alg, err := getSigningAlgorithm(s.Public(), opts) + if err != nil { + return nil, err + } + + ctx, cancel := defaultContext() + defer cancel() + + b64 := base64.RawURLEncoding.EncodeToString(digest) + + resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{ + Algorithm: alg, + Value: &b64, + }) + if err != nil { + return nil, errors.Wrap(err, "keyVault Sign failed") + } + + sig, err := base64.RawURLEncoding.DecodeString(*resp.Result) + if err != nil { + return nil, errors.Wrap(err, "error decoding keyVault Sign result") + } + + var octetSize int + switch alg { + case keyvault.ES256: + octetSize = 32 // 256-bit, concat(R,S) = 64 bytes + case keyvault.ES384: + octetSize = 48 // 384-bit, concat(R,S) = 96 bytes + case keyvault.ES512: + octetSize = 66 // 528-bit, concat(R,S) = 132 bytes + default: + return sig, nil + } + + // Convert to asn1 + if len(sig) != octetSize*2 { + return nil, errors.Errorf("keyVault Sign failed: unexpected signature length") + } + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R + b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S + }) + return b.Bytes() +} + +func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) { + switch key.(type) { + case *rsa.PublicKey: + hashFunc := opts.HashFunc() + pss, isPSS := opts.(*rsa.PSSOptions) + // Random salt lengths are not supported + if isPSS && + pss.SaltLength != rsa.PSSSaltLengthAuto && + pss.SaltLength != rsa.PSSSaltLengthEqualsHash && + pss.SaltLength != hashFunc.Size() { + return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength) + } + + switch h := hashFunc; h { + case crypto.SHA256: + if isPSS { + return keyvault.PS256, nil + } + return keyvault.RS256, nil + case crypto.SHA384: + if isPSS { + return keyvault.PS384, nil + } + return keyvault.RS384, nil + case crypto.SHA512: + if isPSS { + return keyvault.PS512, nil + } + return keyvault.RS512, nil + default: + return "", errors.Errorf("unsupported hash function %v", h) + } + case *ecdsa.PublicKey: + switch h := opts.HashFunc(); h { + case crypto.SHA256: + return keyvault.ES256, nil + case crypto.SHA384: + return keyvault.ES384, nil + case crypto.SHA512: + return keyvault.ES512, nil + default: + return "", errors.Errorf("unsupported hash function %v", h) + } + default: + return "", errors.Errorf("unsupported key type %T", key) + } +} diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go new file mode 100644 index 00000000..381c3577 --- /dev/null +++ b/kms/azurekms/signer_test.go @@ -0,0 +1,352 @@ +package azurekms + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "io" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/golang/mock/gomock" + "github.com/smallstep/certificates/kms/apiv1" + "go.step.sm/crypto/keyutil" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +func TestNewSigner(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + jwk := createJWK(t, pub) + + client := mockClient(t) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + + var noOptions DefaultOptions + type args struct { + client KeyVaultClient + signingKey string + defaults DefaultOptions + } + tests := []struct { + name string + args args + want crypto.Signer + wantErr bool + }{ + {"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "", + publicKey: pub, + }, false}, + {"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true}, + {"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true}, + {"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true}, + {"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults) + if (err != nil) != tt.wantErr { + t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Public(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + + type fields struct { + publicKey crypto.PublicKey + } + tests := []struct { + name string + fields fields + want crypto.PublicKey + }{ + {"ok", fields{pub}, pub}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + publicKey: tt.fields.publicKey, + } + if got := s.Public(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Public() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Sign(t *testing.T) { + sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) { + key, err := keyutil.GenerateSigner(kty, crv, bits) + if err != nil { + t.Fatal(err) + } + h := opts.HashFunc().New() + h.Write([]byte("random-data")) + sum := h.Sum(nil) + + var sig, resultSig []byte + if priv, ok := key.(*ecdsa.PrivateKey); ok { + r, s, err := ecdsa.Sign(rand.Reader, priv, sum) + if err != nil { + t.Fatal(err) + } + curveBits := priv.Params().BitSize + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + // nolint:gocritic + resultSig = append(rBytesPadded, sBytesPadded...) + + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(r) + b.AddASN1BigInt(s) + }) + sig, err = b.Bytes() + if err != nil { + t.Fatal(err) + } + } else { + sig, err = key.Sign(rand.Reader, sum, opts) + if err != nil { + t.Fatal(err) + } + resultSig = sig + } + + return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig + } + + p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256) + p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384) + p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512) + rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256) + rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384) + rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512) + rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA256, + }) + rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA512, + }) + rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA512, + }) + + ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0) + if err != nil { + t.Fatal(err) + } + + client := mockClient(t) + expects := []struct { + name string + keyVersion string + alg keyvault.JSONWebKeySignatureAlgorithm + digest []byte + result keyvault.KeyOperationResult + err error + }{ + {"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: &p256ResultSig, + }, nil}, + {"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{ + Result: &p386ResultSig, + }, nil}, + {"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{ + Result: &p521ResultSig, + }, nil}, + {"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA256ResultSig, + }, nil}, + {"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA384ResultSig, + }, nil}, + {"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA512ResultSig, + }, nil}, + {"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{ + Result: &rsaPSSSHA256ResultSig, + }, nil}, + {"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{ + Result: &rsaPSSSHA384ResultSig, + }, nil}, + {"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{ + Result: &rsaPSSSHA512ResultSig, + }, nil}, + // Errors + {"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest}, + {"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA256ResultSig, + }, nil}, + {"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: func() *string { + v := "😎" + return &v + }(), + }, nil}, + } + for _, e := range expects { + value := base64.RawURLEncoding.EncodeToString(e.digest) + client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{ + Algorithm: e.alg, + Value: &value, + }).Return(e.result, e.err) + } + + type fields struct { + client KeyVaultClient + vaultBaseURL string + name string + version string + publicKey crypto.PublicKey + } + type args struct { + rand io.Reader + digest []byte + opts crypto.SignerOpts + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, p256Sig, false}, + {"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{ + rand.Reader, p384Digest, crypto.SHA384, + }, p384Sig, false}, + {"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{ + rand.Reader, p521Digest, crypto.SHA512, + }, p521Sig, false}, + {"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ + rand.Reader, rsaSHA256Digest, crypto.SHA256, + }, rsaSHA256Sig, false}, + {"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{ + rand.Reader, rsaSHA384Digest, crypto.SHA384, + }, rsaSHA384Sig, false}, + {"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{ + rand.Reader, rsaSHA512Digest, crypto.SHA512, + }, rsaSHA512Sig, false}, + {"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ + rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA256, + }, + }, rsaPSSSHA256Sig, false}, + {"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{ + rand.Reader, rsaPSSSHA384Digest, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: crypto.SHA384, + }, + }, rsaPSSSHA384Sig, false}, + {"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{ + rand.Reader, rsaPSSSHA512Digest, &rsa.PSSOptions{ + SaltLength: 64, + Hash: crypto.SHA512, + }, + }, rsaPSSSHA512Sig, false}, + {"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ + rand.Reader, rsaSHA256Digest, crypto.SHA256, + }, nil, true}, + {"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, nil, true}, + {"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, nil, true}, + {"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ + rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{ + SaltLength: 64, + Hash: crypto.SHA256, + }, + }, nil, true}, + {"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ + rand.Reader, rsaSHA256Digest, crypto.SHA1, + }, nil, true}, + {"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.MD5, + }, nil, true}, + {"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{ + rand.Reader, []byte("message"), crypto.Hash(0), + }, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + client: tt.fields.client, + vaultBaseURL: tt.fields.vaultBaseURL, + name: tt.fields.name, + version: tt.fields.version, + publicKey: tt.fields.publicKey, + } + got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Sign() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kms/azurekms/utils.go b/kms/azurekms/utils.go new file mode 100644 index 00000000..d4201907 --- /dev/null +++ b/kms/azurekms/utils.go @@ -0,0 +1,98 @@ +package azurekms + +import ( + "context" + "crypto" + "encoding/json" + "net/url" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/pkg/errors" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/uri" + "go.step.sm/crypto/jose" +) + +// defaultContext returns the default context used in requests to azure. +func defaultContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 15*time.Second) +} + +// getKeyName returns the uri of the key vault key. +func getKeyName(vault, name string, bundle keyvault.KeyBundle) string { + if bundle.Key != nil && bundle.Key.Kid != nil { + sm := keyIDRegexp.FindAllStringSubmatch(*bundle.Key.Kid, 1) + if len(sm) == 1 && len(sm[0]) == 4 { + m := sm[0] + u := uri.New(Scheme, url.Values{ + "vault": []string{m[1]}, + "name": []string{m[2]}, + }) + u.RawQuery = url.Values{"version": []string{m[3]}}.Encode() + return u.String() + } + } + // Fallback to URI without id. + return uri.New(Scheme, url.Values{ + "vault": []string{vault}, + "name": []string{name}, + }).String() +} + +// parseKeyName returns the key vault, name and version from URIs like: +// +// - azurekms:vault=key-vault;name=key-name +// - azurekms:vault=key-vault;name=key-name?version=key-id +// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true +// +// The key-id defines the version of the key, if it is not passed the latest +// version will be used. +// +// HSM can also be passed to define the protection level if this is not given in +// CreateQuery. +func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) { + var u *uri.URI + + u, err = uri.ParseWithScheme(Scheme, rawURI) + if err != nil { + return + } + if name = u.Get("name"); name == "" { + err = errors.Errorf("key uri %s is not valid: name is missing", rawURI) + return + } + if vault = u.Get("vault"); vault == "" { + if defaults.Vault == "" { + name = "" + err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI) + return + } + vault = defaults.Vault + } + if u.Get("hsm") == "" { + hsm = (defaults.ProtectionLevel == apiv1.HSM) + } else { + hsm = u.GetBool("hsm") + } + + version = u.Get("version") + + return +} + +func vaultBaseURL(vault string) string { + return "https://" + vault + ".vault.azure.net/" +} + +func convertKey(key *keyvault.JSONWebKey) (crypto.PublicKey, error) { + b, err := json.Marshal(key) + if err != nil { + return nil, errors.Wrap(err, "error marshaling key") + } + var jwk jose.JSONWebKey + if err := jwk.UnmarshalJSON(b); err != nil { + return nil, errors.Wrap(err, "error unmarshaling key") + } + return jwk.Key, nil +} diff --git a/kms/azurekms/utils_test.go b/kms/azurekms/utils_test.go new file mode 100644 index 00000000..cded50ea --- /dev/null +++ b/kms/azurekms/utils_test.go @@ -0,0 +1,96 @@ +package azurekms + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/smallstep/certificates/kms/apiv1" +) + +func Test_getKeyName(t *testing.T) { + getBundle := func(kid string) keyvault.KeyBundle { + return keyvault.KeyBundle{ + Key: &keyvault.JSONWebKey{ + Kid: &kid, + }, + } + } + + type args struct { + vault string + name string + bundle keyvault.KeyBundle + } + tests := []struct { + name string + args args + want string + }{ + {"ok", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault?version=my-version"}, + {"ok default", args{"my-vault", "my-key", getBundle("https://my-vault.foo.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault"}, + {"ok too short", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-version")}, "azurekms:name=my-key;vault=my-vault"}, + {"ok too long", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version/sign")}, "azurekms:name=my-key;vault=my-vault"}, + {"ok nil key", args{"my-vault", "my-key", keyvault.KeyBundle{}}, "azurekms:name=my-key;vault=my-vault"}, + {"ok nil kid", args{"my-vault", "my-key", keyvault.KeyBundle{Key: &keyvault.JSONWebKey{}}}, "azurekms:name=my-key;vault=my-vault"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getKeyName(tt.args.vault, tt.args.name, tt.args.bundle); got != tt.want { + t.Errorf("getKeyName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseKeyName(t *testing.T) { + var noOptions DefaultOptions + type args struct { + rawURI string + defaults DefaultOptions + } + tests := []struct { + name string + args args + wantVault string + wantName string + wantVersion string + wantHsm bool + wantErr bool + }{ + {"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false}, + {"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false}, + {"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false}, + {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false}, + {"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false}, + {"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false}, + {"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false}, + {"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true}, + {"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true}, + {"fail empty", args{"", noOptions}, "", "", "", false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults) + if (err != nil) != tt.wantErr { + t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotVault != tt.wantVault { + t.Errorf("parseKeyName() gotVault = %v, want %v", gotVault, tt.wantVault) + } + if gotName != tt.wantName { + t.Errorf("parseKeyName() gotName = %v, want %v", gotName, tt.wantName) + } + if gotVersion != tt.wantVersion { + t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion) + } + if gotHsm != tt.wantHsm { + t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm) + } + }) + } +} diff --git a/kms/kms.go b/kms/kms.go index 3eddca93..92b544df 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -8,7 +8,7 @@ import ( "github.com/smallstep/certificates/kms/apiv1" // Enable default implementation - _ "github.com/smallstep/certificates/kms/softkms" + "github.com/smallstep/certificates/kms/softkms" ) // KeyManager is the interface implemented by all the KMS. @@ -18,6 +18,12 @@ type KeyManager = apiv1.KeyManager // store x509.Certificates. type CertificateManager = apiv1.CertificateManager +// Options are the KMS options. They represent the kms object in the ca.json. +type Options = apiv1.Options + +// Default is the implementation of the default KMS. +var Default = &softkms.SoftKMS{} + // New initializes a new KMS from the given type. func New(ctx context.Context, opts apiv1.Options) (KeyManager, error) { if err := opts.Validate(); err != nil { diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 44271e74..36e15e7d 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -95,6 +95,16 @@ func (u *URI) Get(key string) string { return v } +// GetBool returns true if a given key has the value "true". It returns false +// otherwise. +func (u *URI) GetBool(key string) bool { + v := u.Values.Get(key) + if v == "" { + v = u.URL.Query().Get(key) + } + return strings.EqualFold(v, "true") +} + // GetEncoded returns the first value in the uri with the given key, it will // return empty nil if that field is not present or is empty. If the return // value is hex encoded it will decode it and return it. diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index c2e0a9fe..01fbad0f 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -212,6 +212,40 @@ func TestURI_Get(t *testing.T) { } } +func TestURI_GetBool(t *testing.T) { + mustParse := func(s string) *URI { + u, err := Parse(s) + if err != nil { + t.Fatal(err) + } + return u + } + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want bool + }{ + {"true", mustParse("azurekms:name=foo;vault=bar;hsm=true"), args{"hsm"}, true}, + {"TRUE", mustParse("azurekms:name=foo;vault=bar;hsm=TRUE"), args{"hsm"}, true}, + {"tRUe query", mustParse("azurekms:name=foo;vault=bar?hsm=tRUe"), args{"hsm"}, true}, + {"false", mustParse("azurekms:name=foo;vault=bar;hsm=false"), args{"hsm"}, false}, + {"false query", mustParse("azurekms:name=foo;vault=bar?hsm=false"), args{"hsm"}, false}, + {"empty", mustParse("azurekms:name=foo;vault=bar;hsm=?bar=true"), args{"hsm"}, false}, + {"missing", mustParse("azurekms:name=foo;vault=bar"), args{"hsm"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.uri.GetBool(tt.args.key); got != tt.want { + t.Errorf("URI.GetBool() = %v, want %v", got, tt.want) + } + }) + } +} + func TestURI_GetEncoded(t *testing.T) { mustParse := func(s string) *URI { u, err := Parse(s) diff --git a/pki/pki.go b/pki/pki.go index 18cd0dda..61e20b6b 100644 --- a/pki/pki.go +++ b/pki/pki.go @@ -26,13 +26,14 @@ import ( "github.com/smallstep/certificates/cas" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/kms" + kmsapi "github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/nosql" "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/ui" "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/linkedca" "golang.org/x/crypto/ssh" @@ -168,14 +169,18 @@ func GetProvisionerKey(caURL, rootFile, kid string) (string, error) { } type options struct { - provisioner string - pkiOnly bool - enableACME bool - enableSSH bool - enableAdmin bool - noDB bool - isHelm bool - deploymentType DeploymentType + provisioner string + pkiOnly bool + enableACME bool + enableSSH bool + enableAdmin bool + noDB bool + isHelm bool + deploymentType DeploymentType + rootKeyURI string + intermediateKeyURI string + hostKeyURI string + userKeyURI string } // Option is the type of a configuration option on the pki constructor. @@ -258,6 +263,26 @@ func WithDeploymentType(dt DeploymentType) Option { } } +// WithKMS enables the kms with the given name. +func WithKMS(name string) Option { + return func(p *PKI) { + typ := linkedca.KMS_Type_value[strings.ToUpper(name)] + p.Configuration.Kms = &linkedca.KMS{ + Type: linkedca.KMS_Type(typ), + } + } +} + +// WithKeyURIs defines the key uris for X.509 and SSH keys. +func WithKeyURIs(rootKey, intermediateKey, hostKey, userKey string) Option { + return func(p *PKI) { + p.options.rootKeyURI = rootKey + p.options.intermediateKeyURI = intermediateKey + p.options.hostKeyURI = hostKey + p.options.userKeyURI = userKey + } +} + // PKI represents the Public Key Infrastructure used by a certificate authority. type PKI struct { linkedca.Configuration @@ -265,6 +290,7 @@ type PKI struct { casOptions apiv1.Options caService apiv1.CertificateAuthorityService caCreator apiv1.CertificateAuthorityCreator + keyManager kmsapi.KeyManager config string defaults string ottPublicKey *jose.JSONWebKey @@ -303,8 +329,9 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) { Files: make(map[string][]byte), }, casOptions: o, - caCreator: caCreator, caService: caService, + caCreator: caCreator, + keyManager: o.KeyManager, options: &options{ provisioner: "step-cli", }, @@ -313,6 +340,11 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) { fn(p) } + // Use default key manager + if p.keyManager == nil { + p.keyManager = kms.Default + } + // Use /home/step as the step path in helm configurations. // Use the current step path when creating pki in files. var public, private, cfg string @@ -448,11 +480,18 @@ func (p *PKI) GenerateKeyPairs(pass []byte) error { // GenerateRootCertificate generates a root certificate with the given name // and using the default key type. func (p *PKI) GenerateRootCertificate(name, org, resource string, pass []byte) (*apiv1.CreateCertificateAuthorityResponse, error) { + if uri := p.options.rootKeyURI; uri != "" { + p.RootKey[0] = uri + } + resp, err := p.caCreator.CreateCertificateAuthority(&apiv1.CreateCertificateAuthorityRequest{ - Name: resource + "-Root-CA", - Type: apiv1.RootCA, - Lifetime: 10 * 365 * 24 * time.Hour, - CreateKey: nil, // use default + Name: resource + "-Root-CA", + Type: apiv1.RootCA, + Lifetime: 10 * 365 * 24 * time.Hour, + CreateKey: &apiv1.CreateKeyRequest{ + Name: p.RootKey[0], + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }, Template: &x509.Certificate{ Subject: pkix.Name{ CommonName: name + " Root CA", @@ -469,6 +508,13 @@ func (p *PKI) GenerateRootCertificate(name, org, resource string, pass []byte) ( return nil, err } + // Replace key name with the one from the key manager if available. On + // softcas this will be the original filename, on any other kms will be the + // uri to the key. + if resp.KeyName != "" { + p.RootKey[0] = resp.KeyName + } + // PrivateKey will only be set if we have access to it (SoftCAS). if err := p.WriteRootCertificate(resp.Certificate, resp.PrivateKey, pass); err != nil { return nil, err @@ -495,11 +541,18 @@ func (p *PKI) WriteRootCertificate(rootCrt *x509.Certificate, rootKey interface{ // GenerateIntermediateCertificate generates an intermediate certificate with // the given name and using the default key type. func (p *PKI) GenerateIntermediateCertificate(name, org, resource string, parent *apiv1.CreateCertificateAuthorityResponse, pass []byte) error { + if uri := p.options.intermediateKeyURI; uri != "" { + p.IntermediateKey = uri + } + resp, err := p.caCreator.CreateCertificateAuthority(&apiv1.CreateCertificateAuthorityRequest{ - Name: resource + "-Intermediate-CA", - Type: apiv1.IntermediateCA, - Lifetime: 10 * 365 * 24 * time.Hour, - CreateKey: nil, // use default + Name: resource + "-Intermediate-CA", + Type: apiv1.IntermediateCA, + Lifetime: 10 * 365 * 24 * time.Hour, + CreateKey: &apiv1.CreateKeyRequest{ + Name: p.IntermediateKey, + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }, Template: &x509.Certificate{ Subject: pkix.Name{ CommonName: name + " Intermediate CA", @@ -519,7 +572,19 @@ func (p *PKI) GenerateIntermediateCertificate(name, org, resource string, parent p.casOptions.CertificateAuthority = resp.Name p.Files[p.Intermediate] = encodeCertificate(resp.Certificate) - p.Files[p.IntermediateKey], err = encodePrivateKey(resp.PrivateKey, pass) + + // Replace the key name with the one from the key manager. On softcas this + // will be the original filename, on any other kms will be the uri to the + // key. + if resp.KeyName != "" { + p.IntermediateKey = resp.KeyName + } + + // If a kms is used it will not have the private key + if resp.PrivateKey != nil { + p.Files[p.IntermediateKey], err = encodePrivateKey(resp.PrivateKey, pass) + } + return err } @@ -564,27 +629,67 @@ func (p *PKI) GetCertificateAuthority() error { // GenerateSSHSigningKeys generates and encrypts a private key used for signing // SSH user certificates and a private key used for signing host certificates. func (p *PKI) GenerateSSHSigningKeys(password []byte) error { - var pubNames = []string{p.Ssh.HostPublicKey, p.Ssh.UserPublicKey} - var privNames = []string{p.Ssh.HostKey, p.Ssh.UserKey} - for i := 0; i < 2; i++ { - pub, priv, err := keyutil.GenerateDefaultKeyPair() - if err != nil { - return err - } - if _, ok := priv.(crypto.Signer); !ok { - return errors.Errorf("key of type %T is not a crypto.Signer", priv) - } - sshKey, err := ssh.NewPublicKey(pub) - if err != nil { - return errors.Wrapf(err, "error converting public key") - } - p.Files[pubNames[i]] = ssh.MarshalAuthorizedKey(sshKey) - p.Files[privNames[i]], err = encodePrivateKey(priv, password) - if err != nil { - return err - } - } + // Enable SSH p.options.enableSSH = true + + // Create SSH key used to sign host certificates. Using + // kmsapi.UnspecifiedSignAlgorithm will default to the default algorithm. + name := p.Ssh.HostKey + if uri := p.options.hostKeyURI; uri != "" { + name = uri + } + resp, err := p.keyManager.CreateKey(&kmsapi.CreateKeyRequest{ + Name: name, + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }) + if err != nil { + return err + } + sshKey, err := ssh.NewPublicKey(resp.PublicKey) + if err != nil { + return errors.Wrapf(err, "error converting public key") + } + p.Files[p.Ssh.HostPublicKey] = ssh.MarshalAuthorizedKey(sshKey) + + // On softkms we will have the private key + if resp.PrivateKey != nil { + p.Files[p.Ssh.HostKey], err = encodePrivateKey(resp.PrivateKey, password) + if err != nil { + return err + } + } else { + p.Ssh.HostKey = resp.Name + } + + // Create SSH key used to sign user certificates. Using + // kmsapi.UnspecifiedSignAlgorithm will default to the default algorithm. + name = p.Ssh.UserKey + if uri := p.options.userKeyURI; uri != "" { + name = uri + } + resp, err = p.keyManager.CreateKey(&kmsapi.CreateKeyRequest{ + Name: name, + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }) + if err != nil { + return err + } + sshKey, err = ssh.NewPublicKey(resp.PublicKey) + if err != nil { + return errors.Wrapf(err, "error converting public key") + } + p.Files[p.Ssh.UserPublicKey] = ssh.MarshalAuthorizedKey(sshKey) + + // On softkms we will have the private key + if resp.PrivateKey != nil { + p.Files[p.Ssh.UserKey], err = encodePrivateKey(resp.PrivateKey, password) + if err != nil { + return err + } + } else { + p.Ssh.UserKey = resp.Name + } + return nil } @@ -685,6 +790,13 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) { cfg.AuthorityConfig.DeploymentType = LinkedDeployment.String() } + // Enable KMS if necessary + if p.Kms != nil { + cfg.KMS = &kmsapi.Options{ + Type: strings.ToLower(p.Kms.Type.String()), + } + } + // On standalone deployments add the provisioners to either the ca.json or // the database. var provisioners []provisioner.Interface diff --git a/server/server.go b/server/server.go index d3968c4a..2b864148 100644 --- a/server/server.go +++ b/server/server.go @@ -72,10 +72,10 @@ func (srv *Server) Serve(ln net.Listener) error { // Start server if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { log.Printf("Serving HTTP on %s ...", srv.Addr) - err = srv.Server.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) + err = srv.Server.Serve(ln) } else { log.Printf("Serving HTTPS on %s ...", srv.Addr) - err = srv.Server.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, "", "") + err = srv.Server.ServeTLS(ln, "", "") } // log unexpected errors @@ -155,21 +155,3 @@ func (srv *Server) Forbidden(w http.ResponseWriter) { w.WriteHeader(http.StatusForbidden) w.Write([]byte("Forbidden.\n")) } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -}