Merge pull request #2338 from smallstep/mariano/shutdown

Fix process hanging after SIGTERM
This commit is contained in:
Mariano Cano
2025-07-16 12:08:16 -07:00
committed by GitHub
2 changed files with 37 additions and 40 deletions

View File

@@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
"net"
@@ -12,12 +13,11 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
"github.com/smallstep/cli-utils/step"
"github.com/smallstep/nosql"
@@ -264,7 +264,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
if cfg.DB != nil {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
if err != nil {
return nil, errors.Wrap(err, "error configuring ACME DB interface")
return nil, fmt.Errorf("error configuring ACME DB interface: %w", err)
}
acmeLinker = acme.NewLinker(dns, "acme")
mux.Route("/acme", func(r chi.Router) {
@@ -418,9 +418,6 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
// Run starts the CA calling to the server ListenAndServe method.
func (ca *CA) Run() error {
var wg sync.WaitGroup
errs := make(chan error, 1)
if !ca.opts.quiet {
authorityInfo := ca.auth.GetInfo()
log.Printf("Starting %s", step.Version())
@@ -450,36 +447,29 @@ func (ca *CA) Run() error {
}
}
wg.Add(1)
go func() {
defer wg.Done()
eg := new(errgroup.Group)
eg.Go(func() error {
ca.runCompactJob()
}()
return nil
})
if ca.insecureSrv != nil {
wg.Add(1)
go func() {
defer wg.Done()
errs <- ca.insecureSrv.ListenAndServe()
}()
eg.Go(func() error {
return ca.insecureSrv.ListenAndServe()
})
}
if ca.metricsSrv != nil {
wg.Add(1)
go func() {
defer wg.Done()
errs <- ca.metricsSrv.ListenAndServe()
}()
eg.Go(func() error {
return ca.metricsSrv.ListenAndServe()
})
}
wg.Add(1)
go func() {
defer wg.Done()
errs <- ca.srv.ListenAndServe()
}()
eg.Go(func() error {
return ca.srv.ListenAndServe()
})
// wait till error occurs; ensures the servers keep listening
err := <-errs
err := eg.Wait()
// if the error is not the usual HTTP server closed error, it is
// highly likely that an error occurred when starting one of the
@@ -495,8 +485,6 @@ func (ca *CA) Run() error {
}
}
wg.Wait()
return err
}
@@ -510,17 +498,26 @@ func (ca *CA) Stop() error {
if err := ca.auth.Shutdown(); err != nil {
log.Printf("error stopping ca.Authority: %+v\n", err)
}
var insecureShutdownErr error
if ca.insecureSrv != nil {
insecureShutdownErr = ca.insecureSrv.Shutdown()
}
secureErr := ca.srv.Shutdown()
if insecureShutdownErr != nil {
return insecureShutdownErr
var metricsShutdownErr error
if ca.metricsSrv != nil {
metricsShutdownErr = ca.metricsSrv.Shutdown()
}
secureErr := ca.srv.Shutdown()
switch {
case insecureShutdownErr != nil:
return insecureShutdownErr
case metricsShutdownErr != nil:
return metricsShutdownErr
default:
return secureErr
}
return secureErr
}
// Reload reloads the configuration of the CA and calls to the server Reload
@@ -528,7 +525,7 @@ func (ca *CA) Stop() error {
func (ca *CA) Reload() error {
cfg, err := config.LoadConfiguration(ca.opts.configFile)
if err != nil {
return errors.Wrap(err, "error reloading ca configuration")
return fmt.Errorf("error reloading ca configuration: %w", err)
}
logContinue := func(reason string) {
@@ -555,26 +552,26 @@ func (ca *CA) Reload() error {
)
if err != nil {
logContinue("Reload failed because the CA with new configuration could not be initialized.")
return errors.Wrap(err, "error reloading ca")
return fmt.Errorf("error reloading ca: %w", err)
}
if ca.insecureSrv != nil {
if err = ca.insecureSrv.Reload(newCA.insecureSrv); err != nil {
logContinue("Reload failed because insecure server could not be replaced.")
return errors.Wrap(err, "error reloading insecure server")
return fmt.Errorf("error reloading insecure server: %w", err)
}
}
if ca.metricsSrv != nil {
if err = ca.metricsSrv.Reload(newCA.metricsSrv); err != nil {
logContinue("Reload failed because metrics server could not be replaced.")
return errors.Wrap(err, "error reloading metrics server")
return fmt.Errorf("error reloading metrics server: %w", err)
}
}
if err = ca.srv.Reload(newCA.srv); err != nil {
logContinue("Reload failed because server could not be replaced.")
return errors.Wrap(err, "error reloading server")
return fmt.Errorf("error reloading server: %w", err)
}
// 1. Stop previous renewer

2
go.mod
View File

@@ -41,6 +41,7 @@ require (
golang.org/x/crypto v0.39.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
golang.org/x/net v0.41.0
golang.org/x/sync v0.15.0
google.golang.org/api v0.240.0
google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6
@@ -160,7 +161,6 @@ require (
go.opentelemetry.io/otel/trace v1.36.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/time v0.12.0 // indirect