diff --git a/cmd/main.go b/cmd/main.go index ec30df6..f3ca33a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "os" + "time" est "github.com/foundriesio/estserver" "github.com/labstack/echo/v4" @@ -33,6 +34,7 @@ func main() { {name: "root-cert", help: "EST CA PEM encoded root certificate"}, } port := flag.Int("port", 8443, "Port to listen on") + certDuration := flag.Duration("cert-duration", time.Hour*24*365*3, "How long new certs should be valid for. e.g. such as '1.5h' or '2h45m'. 3 years is default") clientCas := flag.String("client-cas", "", "PEM encoded list of device CA's to allow. The device must present a certificate signed by a CA in this list or the `ca-cert` to authenticate") for _, opt := range required { @@ -83,7 +85,7 @@ func main() { log.Fatal().Err(err).Msg("Unable to create tls cert handler") } - svcHandler := est.NewStaticServiceHandler(est.NewService(rootCert, caCert, caKey)) + svcHandler := est.NewStaticServiceHandler(est.NewService(rootCert, caCert, caKey, *certDuration)) e := echo.New() s := http.Server{ diff --git a/http_handlers.go b/http_handlers.go index 19af7b0..6147349 100644 --- a/http_handlers.go +++ b/http_handlers.go @@ -33,7 +33,7 @@ func RegisterEchoHandlers(svcHandler ServiceHandler, e *echo.Echo) { } bytes, err = svc.Enroll(c.Request().Context(), bytes) if err != nil { - if errors.Is(err, EstError) { + if errors.Is(err, ErrEst) { return c.String(http.StatusBadRequest, err.Error()) } return c.String(http.StatusInternalServerError, err.Error()) @@ -52,7 +52,7 @@ func RegisterEchoHandlers(svcHandler ServiceHandler, e *echo.Echo) { peerCerts := c.Request().TLS.PeerCertificates bytes, err = svc.ReEnroll(c.Request().Context(), bytes, peerCerts[0]) if err != nil { - if errors.Is(err, EstError) { + if errors.Is(err, ErrEst) { return c.String(http.StatusBadRequest, err.Error()) } return c.String(http.StatusInternalServerError, err.Error()) diff --git a/service.go b/service.go index 1b6589b..6f63cbd 100644 --- a/service.go +++ b/service.go @@ -19,7 +19,7 @@ import ( ) var ( - EstError = errors.New("Base EstError") + ErrEst = errors.New("base EstError") ) type EstErrorType int @@ -34,7 +34,7 @@ const ( ) func (e EstErrorType) Unwrap() error { - return EstError + return ErrEst } func (e EstErrorType) Error() string { switch e { @@ -96,14 +96,18 @@ type Service struct { // ca and key are the EST7030 keypair used for signing EST7030 requests ca *x509.Certificate key crypto.Signer + + certDuration time.Duration } // NewService creates an EST7030 API for a Factory -func NewService(rootCa *x509.Certificate, ca *x509.Certificate, key crypto.Signer) Service { +func NewService(rootCa *x509.Certificate, ca *x509.Certificate, key crypto.Signer, certDuration time.Duration) Service { return Service{ rootCa: rootCa, ca: ca, key: key, + + certDuration: certDuration, } } @@ -207,7 +211,7 @@ func (s Service) signCsr(ctx context.Context, csr *x509.CertificateRequest) ([]b } now := time.Now() - notAfter := now.Add(time.Hour * 24 * 365) + notAfter := now.Add(s.certDuration) if notAfter.After(s.ca.NotAfter) { log.Warn().Msg("Adjusting default cert expiry") notAfter = s.ca.NotAfter diff --git a/service_test.go b/service_test.go index a9d231a..3e0a526 100644 --- a/service_test.go +++ b/service_test.go @@ -71,7 +71,7 @@ func createService(t *testing.T) Service { cert, err := x509.ParseCertificate(der) require.Nil(t, err) - return Service{cert, cert, key} + return Service{cert, cert, key, time.Hour * 24} } func TestService_CA(t *testing.T) { @@ -102,7 +102,7 @@ func TestService_loadCsrBase64(t *testing.T) { content := base64.StdEncoding.EncodeToString([]byte("not a valid CSR")) _, err = Service{}.loadCsr(ctx, []byte(content)) require.True(t, errors.Is(err, ErrInvalidCsr)) - require.True(t, errors.Is(err, EstError)) + require.True(t, errors.Is(err, ErrEst)) // valid Csr cn := random.String(12) diff --git a/tls_handler.go b/tls_handler.go index 702f2eb..035d918 100644 --- a/tls_handler.go +++ b/tls_handler.go @@ -9,7 +9,7 @@ import ( ) var ( - errNoCerts = errors.New("Unable to find certs for this server") + errNoCerts = errors.New("unable to find certs for this server") ) // TlsCerts represents the Server TLS keypair to advertise and CA roots we trust @@ -30,7 +30,7 @@ type TlsCertHandler interface { // Apply the TlsCertHandler logic to the tlsConfig func ApplyTlsCertHandler(tlsConfig *tls.Config, handler TlsCertHandler) error { if tlsConfig.ClientAuth != tls.VerifyClientCertIfGiven { - return fmt.Errorf("Invalid TLS ClientAuth value: %d. It must be `tls.VerifyClientCertIfGiven` to fulfill EST requirements", tlsConfig.ClientAuth) + return fmt.Errorf("invalid TLS ClientAuth value: %d. It must be `tls.VerifyClientCertIfGiven` to fulfill EST requirements", tlsConfig.ClientAuth) } tlsConfig.GetConfigForClient = func(helloInfo *tls.ClientHelloInfo) (*tls.Config, error) { return getConfigForClient(tlsConfig, handler, helloInfo)