diff --git a/acme/api/handler.go b/acme/api/handler.go index c1d2d62a..2a6d3a02 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -91,8 +91,8 @@ func (h *Handler) Route(r api.Router) { // Standard ACME API r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) extractPayloadByJWK := func(next nextHTTP) nextHTTP { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) diff --git a/authority/tls.go b/authority/tls.go index c848d188..bc160ad0 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -156,14 +156,15 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...) } - if err = a.db.StoreCertificate(resp.Certificate); err != nil { + fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) + if err = a.storeCertificate(fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error storing certificate in db", opts...) } } - return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil + return fullchain, nil } // Renew creates a new Certificate identical to the old certificate, except @@ -261,13 +262,29 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) } - if err = a.db.StoreCertificate(resp.Certificate); err != nil { + fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) + if err = a.storeCertificate(fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) } } - return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil + return fullchain, nil +} + +// storeCertificate allows to use an extension of the db.AuthDB interface that +// can log the full chain of certificates. +// +// TODO: at some point we should replace the db.AuthDB interface to implement +// `StoreCertificate(...*x509.Certificate) error` instead of just +// `StoreCertificate(*x509.Certificate) error`. +func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { + if s, ok := a.db.(interface { + StoreCertificateChain(...*x509.Certificate) error + }); ok { + return s.StoreCertificateChain(fullchain...) + } + return a.db.StoreCertificate(fullchain[0]) } // RevokeOptions are the options for the Revoke API. diff --git a/ca/client.go b/ca/client.go index b9593162..19f758f1 100644 --- a/ca/client.go +++ b/ca/client.go @@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient { func newInsecureClient() *uaClient { return &uaClient{ Client: &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, + Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}), }, } } @@ -99,12 +96,13 @@ type RetryFunc func(code int) bool type ClientOption func(o *clientOptions) error type clientOptions struct { - transport http.RoundTripper - rootSHA256 string - rootFilename string - rootBundle []byte - certificate tls.Certificate - retryFunc RetryFunc + transport http.RoundTripper + rootSHA256 string + rootFilename string + rootBundle []byte + certificate tls.Certificate + getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + retryFunc RetryFunc } func (o *clientOptions) apply(opts []ClientOption) (err error) { @@ -139,6 +137,7 @@ func (o *clientOptions) applyDefaultIdentity() error { return nil } o.certificate = crt + o.getClientCertificate = i.GetClientCertificateFunc() return nil } @@ -193,6 +192,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} + tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } case *http2.Transport: if tr.TLSClientConfig == nil { @@ -200,6 +200,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} + tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } default: return nil, errors.Errorf("unsupported transport type %T", tr) @@ -288,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { @@ -307,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { @@ -319,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } // parseEndpoint parses and validates the given endpoint. It supports general diff --git a/ca/identity/identity.go b/ca/identity/identity.go index fa9ebf71..08a70c7f 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -26,9 +26,16 @@ type Type string // Disabled represents a disabled identity type const Disabled Type = "" -// MutualTLS represents the identity using mTLS +// MutualTLS represents the identity using mTLS. const MutualTLS Type = "mTLS" +// TunnelTLS represents an identity using a (m)TLS tunnel. +// +// TunnelTLS can be optionally configured with client certificates and a root +// file with the CAs to trust. By default it will use the system truststore +// instead of the CA truststore. +const TunnelTLS Type = "tTLS" + // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute @@ -44,19 +51,30 @@ type Identity struct { Type string `json:"type"` Certificate string `json:"crt"` Key string `json:"key"` + + // Host is the tunnel host for a TunnelTLS (tTLS) identity. + Host string `json:"host,omitempty"` + // Root is the CA bundle of root CAs used in TunnelTLS to trust the + // certificate of the host. + Root string `json:"root,omitempty"` +} + +// LoadIdentity loads an identity present in the given filename. +func LoadIdentity(filename string) (*Identity, error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, errors.Wrapf(err, "error reading %s", filename) + } + identity := new(Identity) + if err := json.Unmarshal(b, &identity); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling %s", filename) + } + return identity, nil } // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { - b, err := ioutil.ReadFile(IdentityFile) - if err != nil { - return nil, errors.Wrapf(err, "error reading %s", IdentityFile) - } - identity := new(Identity) - if err := json.Unmarshal(b, &identity); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile) - } - return identity, nil + return LoadIdentity(IdentityFile) } // configDir and identityDir are used in WriteDefaultIdentity for testing @@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er keyFilename := filepath.Join(identityDir, "identity_key") // Write certificate - if err := WriteIdentityCertificate(certChain); err != nil { + if err := writeCertificate(certFilename, certChain); err != nil { return err } @@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er return nil } -// WriteIdentityCertificate writes the identity certificate in disk. -func WriteIdentityCertificate(certChain []api.Certificate) error { +// writeCertificate writes the given certificate on disk. +func writeCertificate(filename string, certChain []api.Certificate) error { buf := new(bytes.Buffer) - certFilename := filepath.Join(identityDir, "identity.crt") for _, crt := range certChain { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { - return errors.Wrap(err, "error encoding identity certificate") + return errors.Wrap(err, "error encoding certificate") } } - if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { - return errors.Wrap(err, "error writing identity certificate") + if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing certificate") } return nil @@ -144,6 +161,8 @@ func (i *Identity) Kind() Type { return Disabled case "mtls": return MutualTLS + case "ttls": + return TunnelTLS default: return Type(i.Type) } @@ -164,8 +183,26 @@ func (i *Identity) Validate() error { if err := fileExists(i.Certificate); err != nil { return err } - if err := fileExists(i.Key); err != nil { - return err + return fileExists(i.Key) + case TunnelTLS: + if i.Host == "" { + return errors.New("tunnel.host cannot be empty") + } + if i.Certificate != "" { + if err := fileExists(i.Certificate); err != nil { + return err + } + if i.Key == "" { + return errors.New("tunnel.key cannot be empty") + } + if err := fileExists(i.Key); err != nil { + return err + } + } + if i.Root != "" { + if err := fileExists(i.Root); err != nil { + return err + } } return nil default: @@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) { switch i.Kind() { case Disabled: return tls.Certificate{}, nil - case MutualTLS: + case MutualTLS, TunnelTLS: crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) @@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) } } +// GetCertPool returns a x509.CertPool if the identity defines a custom root. +func (i *Identity) GetCertPool() (*x509.CertPool, error) { + if i.Root == "" { + return nil, nil + } + b, err := ioutil.ReadFile(i.Root) + if err != nil { + return nil, errors.Wrap(err, "error reading identity root") + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(b) { + return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root) + } + return pool, nil +} + // Renewer is that interface that a renew client must implement. type Renewer interface { GetRootCAs() *x509.CertPool @@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error { switch i.Kind() { case Disabled: return nil - case MutualTLS: + case MutualTLS, TunnelTLS: cert, err := i.TLSCertificate() if err != nil { return err diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index 7064cead..ce64768c 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) { }{ {"disabled", fields{""}, Disabled}, {"mutualTLS", fields{"mTLS"}, MutualTLS}, + {"tunnelTLS", fields{"tTLS"}, TunnelTLS}, {"unknown", fields{"unknown"}, Type("unknown")}, } for _, tt := range tests { @@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) { Type string Certificate string Key string + Host string + Root string } tests := []struct { name string fields fields wantErr bool }{ - {"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false}, + {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false}, + {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false}, {"ok disabled", fields{}, false}, - {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true}, - {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true}, - {"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true}, - {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true}, - {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true}, + {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true}, + {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true}, + {"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true}, + {"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true}, + {"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true}, + {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true}, + {"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) { Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, } if err := i.Validate(); (err != nil) != tt.wantErr { t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr) @@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) { want tls.Certificate wantErr bool }{ - {"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, + {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, + {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok disabled", fields{}, tls.Certificate{}, false}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, @@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) { } } +func TestIdentity_GetClientCertificateFunc(t *testing.T) { + expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key") + if err != nil { + t.Fatal(err) + } + + type fields struct { + Type string + Certificate string + Key string + Host string + Root string + } + tests := []struct { + name string + fields fields + want *tls.Certificate + wantErr bool + }{ + {"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false}, + {"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false}, + {"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true}, + {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, + } + fn := i.GetClientCertificateFunc() + got, err := fn(&tls.CertificateRequestInfo{}) + if (err != nil) != tt.wantErr { + t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIdentity_GetCertPool(t *testing.T) { + type fields struct { + Type string + Certificate string + Key string + Host string + Root string + } + tests := []struct { + name string + fields fields + wantSubjects [][]byte + wantErr bool + }{ + {"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false}, + {"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false}, + {"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true}, + {"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, + } + got, err := i.GetCertPool() + if (err != nil) != tt.wantErr { + t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil { + subjects := got.Subjects() + if !reflect.DeepEqual(subjects, tt.wantSubjects) { + t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) + } + } + + }) + } +} + type renewer struct { pool *x509.CertPool sign *api.SignResponse diff --git a/ca/identity/testdata/config/tunnel.json b/ca/identity/testdata/config/tunnel.json new file mode 100644 index 00000000..49c76a55 --- /dev/null +++ b/ca/identity/testdata/config/tunnel.json @@ -0,0 +1,7 @@ +{ + "type": "mTLS", + "crt": "testdata/identity/identity.crt", + "key": "testdata/identity/identity_key", + "host": "tunnel:443", + "root": "testdata/certs/root_ca.crt" +} \ No newline at end of file diff --git a/ca/tls.go b/ca/tls.go index 20a5e504..2d9b8f92 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -10,13 +10,65 @@ import ( "encoding/pem" "net" "net/http" + "os" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" - "golang.org/x/net/http2" + "github.com/smallstep/certificates/ca/identity" ) +// mTLSDialContext will hold the dial context function to use in +// getDefaultTransport. +var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error) + +func init() { + // STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS + // over (m)TLS tunnel to step-ca using identity-like credentials. The value + // is a path to a json file with the tunnel host, certificate, key and root + // used to create the (m)TLS tunnel. + // + // The configuration should look like: + // { + // "type": "tTLS", + // "host": "tunnel.example.com:443" + // "crt": "/path/to/tunnel.crt", + // "key": "/path/to/tunnel.key", + // "root": "/path/to/tunnel-root.crt" + // } + // + // This feature is EXPERIMENTAL and might change at any time. + if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" { + id, err := identity.LoadIdentity(path) + if err != nil { + panic(err) + } + if err := id.Validate(); err != nil { + panic(err) + } + host, port, err := net.SplitHostPort(id.Host) + if err != nil { + panic(err) + } + pool, err := id.GetCertPool() + if err != nil { + panic(err) + } + mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) { + d := &tls.Dialer{ + NetDialer: getDefaultDialer(), + Config: &tls.Config{ + RootCAs: pool, + GetClientCertificate: id.GetClientCertificateFunc(), + }, + } + return func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) + } + } + } +} + // GetClientTLSConfig returns a tls.Config for client use configured with the // sign certificate, and a new certificate pool with the sign root certificate. // The client certificate will automatically rotate before expiring. @@ -51,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Update renew function with transport - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - return nil, nil, err - } + tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) @@ -103,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) // Update renew function with transport - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - return nil, err - } + tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) @@ -144,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell // buildDialTLS returns an implementation of DialTLS callback in http.Transport. func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { return func(network, addr string) (net.Conn, error) { - return tls.DialWithDialer(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }, network, addr, ctx.mutableConfig.TLSConfig()) + return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig()) } } @@ -156,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net // nolint:unused func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { + d := getDefaultDialer() // TLS dialers do not support context, but we can use the context // deadline if it is set. - var deadline time.Time if t, ok := ctx.Deadline(); ok { - deadline = t + d.Deadline = t } - return tls.DialWithDialer(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - Deadline: deadline, - DualStack: true, - }, network, addr, tlsCtx.mutableConfig.TLSConfig()) + return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig()) } } @@ -238,27 +275,35 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { } } +// getDefaultDialer returns a new dialer with the default configuration. +func getDefaultDialer() *net.Dialer { + return &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } +} + // getDefaultTransport returns an http.Transport with the same parameters than // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. -func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { - tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, +func getDefaultTransport(tlsConfig *tls.Config) *http.Transport { + var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) + if mTLSDialContext == nil { + d := getDefaultDialer() + dialContext = d.DialContext + } else { + dialContext = mTLSDialContext() + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialContext, + ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, } - if err := http2.ConfigureTransport(tr); err != nil { - return nil, errors.Wrap(err, "error configuring transport") - } - return tr, nil } func getPEM(i interface{}) ([]byte, error) { diff --git a/ca/tls_test.go b/ca/tls_test.go index 5513e06d..ac1d84b6 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -181,13 +181,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { t.Errorf("Client.GetClientTLSConfig() error = %v", err) return nil } - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Errorf("getDefaultTransport() error = %v", err) - return nil - } return &http.Client{ - Transport: tr, + Transport: getDefaultTransport(tlsConfig), } }, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}}, {"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { @@ -199,14 +194,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { tlsConfig := getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) - - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Errorf("getDefaultTransport() error = %v", err) - return nil - } return &http.Client{ - Transport: tr, + Transport: getDefaultTransport(tlsConfig), } }, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}}, {"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { @@ -288,10 +277,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("Client.GetClientTLSConfig() error = %v", err) } - tr2, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Fatalf("getDefaultTransport() error = %v", err) - } + tr2 := getDefaultTransport(tlsConfig) // No client cert root, err := RootCertificate(sr) if err != nil { @@ -300,10 +286,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { tlsConfig = getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) - tr3, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Fatalf("getDefaultTransport() error = %v", err) - } + tr3 := getDefaultTransport(tlsConfig) // Disable keep alives to force TLS handshake tr1.DisableKeepAlives = true diff --git a/docs/provisioners.md b/docs/provisioners.md index 63275916..7ee9af50 100644 --- a/docs/provisioners.md +++ b/docs/provisioners.md @@ -191,7 +191,7 @@ In the ca.json configuration file, a complete JWK provisioner example looks like ### OIDC An OIDC provisioner allows a user to get a certificate after authenticating -himself with an OAuth OpenID Connect identity provider. The ID token provided +with an OAuth OpenID Connect identity provider. The ID token provided will be used on the CA authentication, and by default, the certificate will only have the user's email as a Subject Alternative Name (SAN) Extension. diff --git a/kms/yubikey/yubikey.go b/kms/yubikey/yubikey.go index 19cef55e..2dde244a 100644 --- a/kms/yubikey/yubikey.go +++ b/kms/yubikey/yubikey.go @@ -313,7 +313,7 @@ func getSlotAndName(name string) (piv.Slot, string, error) { s, ok := slotMapping[slotID] if !ok { - return piv.Slot{}, "", errors.Errorf("usupported slot-id '%s'", name) + return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name) } name = "yubikey:slot-id=" + url.QueryEscape(slotID) diff --git a/systemd/cert-renewer@.service b/systemd/cert-renewer@.service index f38951b5..0cac0fbf 100644 --- a/systemd/cert-renewer@.service +++ b/systemd/cert-renewer@.service @@ -26,7 +26,7 @@ ExecStart=/usr/bin/step ca renew --force $CERT_LOCATION $KEY_LOCATION ; Try to reload or restart the systemd service that relies on this cert-renewer ; If the relying service doesn't exist, forge ahead. -ExecStartPost=/usr/bin/env bash -c "if ! systemctl --quiet is-enabled %i.service ; then exit 0; fi; systemctl try-reload-or-restart %i" +ExecStartPost=-systemctl try-reload-or-restart %i [Install] WantedBy=multi-user.target