reload service registration configuration on SIGHUP (#17598)

* add reloading service configuration

* add changelog entry

* add tests

* fix typo

* check if config.ServiceRegistration is nil before signaling

* add changes for deregistering service on nil config with failing tests

* fix tests by decreasing reconcile_timeout + setting consul agent tokens

* fix races

* add comments in test

---------

Co-authored-by: Marc Boudreau <marc.boudreau@hashicorp.com>
This commit is contained in:
Kevin Schoonover
2024-05-09 14:13:14 -07:00
committed by GitHub
parent 077c70fc1f
commit c0ea7b1a35
7 changed files with 288 additions and 114 deletions

3
changelog/17598.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
core/config: reload service registration configuration on SIGHUP
```

View File

@@ -1685,6 +1685,15 @@ func (c *ServerCommand) Run(args []string) int {
}
}
// notify ServiceRegistration that a configuration reload has occurred
if sr := coreConfig.GetServiceRegistration(); sr != nil {
var srConfig *map[string]string
if config.ServiceRegistration != nil {
srConfig = &config.ServiceRegistration.Config
}
sr.NotifyConfigurationReload(srConfig)
}
if err := core.ReloadCensus(); err != nil {
c.UI.Error(err.Error())
}

View File

@@ -51,7 +51,7 @@ const (
// reconcileTimeout is how often Vault should query Consul to detect
// and fix any state drift.
reconcileTimeout = 60 * time.Second
DefaultReconcileTimeout = 60 * time.Second
// metaExternalSource is a metadata value for external-source that can be
// used by the Consul UI.
@@ -64,9 +64,11 @@ var hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*
// Vault to Consul.
type serviceRegistration struct {
Client *api.Client
config *api.Config
logger log.Logger
serviceLock sync.RWMutex
registeredServiceID string
redirectHost string
redirectPort int64
serviceName string
@@ -74,6 +76,7 @@ type serviceRegistration struct {
serviceAddress *string
disableRegistration bool
checkTimeout time.Duration
reconcileTimeout time.Duration
notifyActiveCh chan struct{}
notifySealedCh chan struct{}
@@ -92,90 +95,9 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.
return nil, errors.New("logger is required")
}
// Allow admins to disable consul integration
disableReg, ok := conf["disable_registration"]
var disableRegistration bool
if ok && disableReg != "" {
b, err := parseutil.ParseBool(disableReg)
if err != nil {
return nil, fmt.Errorf("failed parsing disable_registration parameter: %w", err)
}
disableRegistration = b
}
if logger.IsDebug() {
logger.Debug("config disable_registration set", "disable_registration", disableRegistration)
}
// Get the service name to advertise in Consul
service, ok := conf["service"]
if !ok {
service = DefaultServiceName
}
if !hostnameRegex.MatchString(service) {
return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes")
}
if logger.IsDebug() {
logger.Debug("config service set", "service", service)
}
// Get the additional tags to attach to the registered service name
tags := conf["service_tags"]
if logger.IsDebug() {
logger.Debug("config service_tags set", "service_tags", tags)
}
// Get the service-specific address to override the use of the HA redirect address
var serviceAddr *string
serviceAddrStr, ok := conf["service_address"]
if ok {
serviceAddr = &serviceAddrStr
}
if logger.IsDebug() {
logger.Debug("config service_address set", "service_address", serviceAddrStr)
}
checkTimeout := defaultCheckTimeout
checkTimeoutStr, ok := conf["check_timeout"]
if ok {
d, err := parseutil.ParseDurationSecond(checkTimeoutStr)
if err != nil {
return nil, err
}
min, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
if min < checkMinBuffer {
return nil, fmt.Errorf("consul check_timeout must be greater than %v", min)
}
checkTimeout = d
if logger.IsDebug() {
logger.Debug("config check_timeout set", "check_timeout", d)
}
}
// Configure the client
consulConf := api.DefaultConfig()
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
SetupSecureTLS(context.Background(), consulConf, conf, logger, false)
consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
client, err := api.NewClient(consulConf)
if err != nil {
return nil, fmt.Errorf("client setup failed: %w", err)
}
// Setup the backend
c := &serviceRegistration{
Client: client,
logger: logger,
serviceName: service,
serviceTags: strutil.ParseDedupAndSortStrings(tags, ","),
serviceAddress: serviceAddr,
checkTimeout: checkTimeout,
disableRegistration: disableRegistration,
logger: logger,
notifyActiveCh: make(chan struct{}),
notifySealedCh: make(chan struct{}),
@@ -187,7 +109,11 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.
isPerfStandby: atomicB.NewBool(state.IsPerformanceStandby),
isInitialized: atomicB.NewBool(state.IsInitialized),
}
return c, nil
c.serviceLock.Lock()
defer c.serviceLock.Unlock()
err := c.merge(conf)
return c, err
}
func SetupSecureTLS(ctx context.Context, consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error {
@@ -270,6 +196,112 @@ func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGro
return nil
}
// merge parses conf, builds a fresh Consul client from it, and then replaces
// the receiver's service registration settings with the parsed values. Keys
// that are absent from conf fall back to their defaults, not to the values
// currently stored on c.
//
// No field of c is written until every setting has parsed and the client has
// been created, so a failed merge leaves the previous configuration in place.
// merge takes no lock itself; NewServiceRegistration and
// NotifyConfigurationReload hold c.serviceLock for writing around the call.
func (c *serviceRegistration) merge(conf map[string]string) error {
	// Allow admins to disable consul integration
	var disableRegistration bool
	if raw, ok := conf["disable_registration"]; ok && raw != "" {
		b, err := parseutil.ParseBool(raw)
		if err != nil {
			return fmt.Errorf("failed parsing disable_registration parameter: %w", err)
		}
		disableRegistration = b
	}
	if c.logger.IsDebug() {
		c.logger.Debug("config disable_registration set", "disable_registration", disableRegistration)
	}

	// Get the service name to advertise in Consul
	service, ok := conf["service"]
	if !ok {
		service = DefaultServiceName
	}
	if !hostnameRegex.MatchString(service) {
		return errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes")
	}
	if c.logger.IsDebug() {
		c.logger.Debug("config service set", "service", service)
	}

	// Get the additional tags to attach to the registered service name
	tags := conf["service_tags"]
	if c.logger.IsDebug() {
		c.logger.Debug("config service_tags set", "service_tags", tags)
	}

	// Get the service-specific address to override the use of the HA redirect address
	var serviceAddr *string
	serviceAddrStr, ok := conf["service_address"]
	if ok {
		serviceAddr = &serviceAddrStr
	}
	if c.logger.IsDebug() {
		c.logger.Debug("config service_address set", "service_address", serviceAddrStr)
	}

	checkTimeout, err := c.timeoutFromConf(conf, "check_timeout", defaultCheckTimeout)
	if err != nil {
		return err
	}
	reconcileTimeout, err := c.timeoutFromConf(conf, "reconcile_timeout", DefaultReconcileTimeout)
	if err != nil {
		return err
	}

	// Configure the client
	consulConf := api.DefaultConfig()
	// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
	consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
	// Surface connection/TLS setup failures instead of silently continuing
	// with a client that does not reflect the requested configuration.
	if err := SetupSecureTLS(context.Background(), consulConf, conf, c.logger, false); err != nil {
		return fmt.Errorf("client setup failed: %w", err)
	}
	consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
	client, err := api.NewClient(consulConf)
	if err != nil {
		return fmt.Errorf("client setup failed: %w", err)
	}

	// Everything validated: commit the new configuration in one go.
	c.Client = client
	c.config = consulConf
	c.serviceName = service
	c.serviceTags = strutil.ParseDedupAndSortStrings(tags, ",")
	c.serviceAddress = serviceAddr
	c.checkTimeout = checkTimeout
	c.disableRegistration = disableRegistration
	c.reconcileTimeout = reconcileTimeout
	return nil
}

// timeoutFromConf returns the duration stored under key in conf, or fallback
// when the key is absent. A present value is rejected when it fails to parse
// or when the lower bound computed by durationMinusBufferDomain is below
// checkMinBuffer.
func (c *serviceRegistration) timeoutFromConf(conf map[string]string, key string, fallback time.Duration) (time.Duration, error) {
	raw, ok := conf[key]
	if !ok {
		return fallback, nil
	}
	d, err := parseutil.ParseDurationSecond(raw)
	if err != nil {
		return 0, err
	}
	lower, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
	if lower < checkMinBuffer {
		return 0, fmt.Errorf("consul %s must be greater than %v", key, lower)
	}
	if c.logger.IsDebug() {
		c.logger.Debug("config "+key+" set", key, d)
	}
	return d, nil
}
func (c *serviceRegistration) NotifyActiveStateChange(isActive bool) error {
c.isActive.Store(isActive)
select {
@@ -322,6 +354,25 @@ func (c *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) e
return nil
}
// NotifyConfigurationReload applies a reloaded service registration
// configuration while holding serviceLock for writing.
//
// A non-nil conf is handed to merge and merge's error is returned. A nil conf
// disables registration and deregisters the service; the client is dropped
// only once that deregistration has succeeded.
func (c *serviceRegistration) NotifyConfigurationReload(conf *map[string]string) error {
	c.serviceLock.Lock()
	defer c.serviceLock.Unlock()

	if conf != nil {
		if c.logger.IsDebug() {
			c.logger.Debug("service registration configuration received, merging with existing configuation")
		}
		return c.merge(*conf)
	}

	if c.logger.IsDebug() {
		c.logger.Debug("registration is now empty, deregistering service from consul")
	}
	c.disableRegistration = true
	if err := c.deregisterService(); err != nil {
		// deregisterService leaves registeredServiceID set when it fails, so a
		// later deregistration attempt will call through c.Client again. Keep
		// the client so that attempt does not hit a nil pointer.
		return err
	}
	c.Client = nil
	return nil
}
func (c *serviceRegistration) checkDuration() time.Duration {
return durationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
}
@@ -363,7 +414,6 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow
// and end of a handler's life (or after a handler wakes up from
// sleeping during a back-off/retry).
var shutdown atomicB.Bool
var registeredServiceID string
checkLock := new(int32)
serviceRegLock := new(int32)
@@ -383,16 +433,19 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow
checkTimer.Reset(0)
case <-reconcileTimer.C:
// Unconditionally rearm the reconcileTimer
reconcileTimer.Reset(reconcileTimeout - randomStagger(reconcileTimeout/checkJitterFactor))
c.serviceLock.RLock()
reconcileTimer.Reset(c.reconcileTimeout - randomStagger(c.reconcileTimeout/checkJitterFactor))
disableRegistration := c.disableRegistration
c.serviceLock.RUnlock()
// Abort if service discovery is disabled or a
// reconcile handler is already active
if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) {
if !disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) {
// Enter handler with serviceRegLock held
go func() {
defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0)
for !shutdown.Load() {
serviceID, err := c.reconcileConsul(registeredServiceID)
serviceID, err := c.reconcileConsul()
if err != nil {
if c.logger.IsWarn() {
c.logger.Warn("reconcile unable to talk with Consul backend", "error", err)
@@ -402,7 +455,7 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow
}
c.serviceLock.Lock()
registeredServiceID = serviceID
c.registeredServiceID = serviceID
c.serviceLock.Unlock()
return
@@ -411,19 +464,29 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow
}
case <-checkTimer.C:
checkTimer.Reset(c.checkDuration())
c.serviceLock.RLock()
disableRegistration := c.disableRegistration
c.serviceLock.RUnlock()
// Abort if service discovery is disabled or a
// reconcile handler is active
if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) {
if !disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) {
// Enter handler with checkLock held
go func() {
defer atomic.CompareAndSwapInt32(checkLock, 1, 0)
for !shutdown.Load() {
if err := c.runCheck(c.isSealed.Load()); err != nil {
if c.logger.IsWarn() {
c.logger.Warn("check unable to talk with Consul backend", "error", err)
c.serviceLock.RLock()
registeredServiceID := c.registeredServiceID
c.serviceLock.RUnlock()
if registeredServiceID != "" {
if err := c.runCheck(c.isSealed.Load()); err != nil {
if c.logger.IsWarn() {
c.logger.Warn("check unable to talk with Consul backend", "error", err)
}
time.Sleep(consulRetryInterval)
continue
}
time.Sleep(consulRetryInterval)
continue
}
return
}
@@ -435,13 +498,23 @@ func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdow
}
}
c.serviceLock.RLock()
defer c.serviceLock.RUnlock()
if err := c.Client.Agent().ServiceDeregister(registeredServiceID); err != nil {
if c.logger.IsWarn() {
c.logger.Warn("service deregistration failed", "error", err)
c.serviceLock.Lock()
defer c.serviceLock.Unlock()
c.deregisterService()
}
// deregisterService removes the currently registered service ID from the
// Consul agent and clears registeredServiceID on success. It is a no-op when
// no service ID is recorded. On failure the error is logged at warn level and
// returned, and registeredServiceID is left unchanged.
//
// It takes no lock itself; the callers visible in this file hold serviceLock
// for writing.
func (c *serviceRegistration) deregisterService() error {
	id := c.registeredServiceID
	if id == "" {
		return nil
	}

	err := c.Client.Agent().ServiceDeregister(id)
	if err == nil {
		c.registeredServiceID = ""
		return nil
	}

	if c.logger.IsWarn() {
		c.logger.Warn("service deregistration failed", "error", err)
	}
	return err
}
// checkID returns the ID used for a Consul Check. Assume at least a read
@@ -458,10 +531,12 @@ func (c *serviceRegistration) serviceID() string {
// reconcileConsul queries the state of Vault Core and Consul and fixes up
// Consul's state according to what's in Vault. reconcileConsul is called
// without any locks held and can be run concurrently, therefore no changes
// with a read lock and can be run concurrently, therefore no changes
// to serviceRegistration can be made in this method (i.e. wtb const receiver for
// compiler enforced safety).
func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (serviceID string, err error) {
func (c *serviceRegistration) reconcileConsul() (serviceID string, err error) {
c.serviceLock.RLock()
defer c.serviceLock.RUnlock()
agent := c.Client.Agent()
catalog := c.Client.Catalog()
@@ -483,7 +558,7 @@ func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (servi
var reregister bool
switch {
case currentVaultService == nil, registeredServiceID == "":
case currentVaultService == nil, c.registeredServiceID == "":
reregister = true
default:
switch {

View File

@@ -63,6 +63,17 @@ func TestConsul_ServiceRegistration(t *testing.T) {
t.Fatal(err)
}
// update the agent's ACL token so that we can successfully deregister the
// service later in the test
_, err = client.Agent().UpdateAgentACLToken(config.Token, nil)
if err != nil {
t.Fatal(err)
}
_, err = client.Agent().UpdateDefaultACLToken(config.Token, nil)
if err != nil {
t.Fatal(err)
}
// waitForServices waits for the services in the Consul catalog to
// reach an expected value, returning the delta if that doesn't happen in time.
waitForServices := func(t *testing.T, expected map[string][]string) map[string][]string {
@@ -92,10 +103,13 @@ func TestConsul_ServiceRegistration(t *testing.T) {
// Create a ServiceRegistration that points to our consul instance
logger := logging.NewVaultLogger(log.Trace)
sd, err := NewServiceRegistration(map[string]string{
srConfig := map[string]string{
"address": config.Address(),
"token": config.Token,
}, logger, sr.State{})
// decrease reconcile timeout to make test run faster
"reconcile_timeout": "1s",
}
sd, err := NewServiceRegistration(srConfig, logger, sr.State{})
if err != nil {
t.Fatal(err)
}
@@ -147,6 +161,58 @@ func TestConsul_ServiceRegistration(t *testing.T) {
"consul": {},
"vault": {"active", "initialized"},
})
// change the token and trigger reload
if sd.(*serviceRegistration).config.Token == "" {
t.Fatal("expected service registration token to not be '' before configuration reload")
}
srConfigWithoutToken := make(map[string]string)
for k, v := range srConfig {
srConfigWithoutToken[k] = v
}
srConfigWithoutToken["token"] = ""
err = sd.NotifyConfigurationReload(&srConfigWithoutToken)
if err != nil {
t.Fatal(err)
}
if sd.(*serviceRegistration).config.Token != "" {
t.Fatal("expected service registration token to be '' after configuration reload")
}
// reconfigure the configuration back to its original state and verify vault is registered
err = sd.NotifyConfigurationReload(&srConfig)
if err != nil {
t.Fatal(err)
}
waitForServices(t, map[string][]string{
"consul": {},
"vault": {"active", "initialized"},
})
// send 'nil' configuration to verify that the service is deregistered
err = sd.NotifyConfigurationReload(nil)
if err != nil {
t.Fatal(err)
}
waitForServices(t, map[string][]string{
"consul": {},
})
// reconfigure the configuration back to its original state and verify vault
// is re-registered
err = sd.NotifyConfigurationReload(&srConfig)
if err != nil {
t.Fatal(err)
}
waitForServices(t, map[string][]string{
"consul": {},
"vault": {"active", "initialized"},
})
}
func TestConsul_ServiceAddress(t *testing.T) {

View File

@@ -106,6 +106,10 @@ func (r *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) e
return nil
}
// NotifyConfigurationReload is a no-op for this implementation: conf is
// ignored and nil is always returned.
func (r *serviceRegistration) NotifyConfigurationReload(conf *map[string]string) error {
	return nil
}
func getRequiredField(logger hclog.Logger, config map[string]string, envVar, configParam string) (string, error) {
value := ""
switch {

View File

@@ -96,4 +96,14 @@ type ServiceRegistration interface {
// the implementation's responsibility to retry updating state
// in the face of errors.
NotifyInitializedStateChange(isInitialized bool) error
// NotifyConfigurationReload is used by Core to notify that the Vault
// configuration has been reloaded.
// If errors are returned, Vault only logs a warning, so it is
// the implementation's responsibility to retry updating state
// in the face of errors.
//
// If the passed in conf is nil, it is assumed that the service registration
// configuration no longer exists and the service should be deregistered.
NotifyConfigurationReload(conf *map[string]string) error
}

View File

@@ -3286,11 +3286,12 @@ func TestCore_HandleRequest_TokenCreate_RegisterAuthFailure(t *testing.T) {
// mockServiceRegistration helps test whether standalone ServiceRegistration works
type mockServiceRegistration struct {
notifyActiveCount int
notifySealedCount int
notifyPerfCount int
notifyInitCount int
runDiscoveryCount int
notifyActiveCount int
notifySealedCount int
notifyPerfCount int
notifyInitCount int
notifyConfigurationReload int
runDiscoveryCount int
}
func (m *mockServiceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error {
@@ -3318,6 +3319,11 @@ func (m *mockServiceRegistration) NotifyInitializedStateChange(isInitialized boo
return nil
}
// NotifyConfigurationReload records that a configuration reload notification
// was received by bumping notifyConfigurationReload. The config argument is
// ignored and the call always succeeds.
func (m *mockServiceRegistration) NotifyConfigurationReload(config *map[string]string) error {
	m.notifyConfigurationReload += 1
	return nil
}
// TestCore_ServiceRegistration tests whether standalone ServiceRegistration works
func TestCore_ServiceRegistration(t *testing.T) {
// Make a mock service discovery
@@ -3374,10 +3380,11 @@ func TestCore_ServiceRegistration(t *testing.T) {
// Vault should be registered, unsealed, and active
if diff := deep.Equal(sr, &mockServiceRegistration{
runDiscoveryCount: 1,
notifyActiveCount: 1,
notifySealedCount: 1,
notifyInitCount: 1,
runDiscoveryCount: 1,
notifyActiveCount: 1,
notifySealedCount: 1,
notifyInitCount: 1,
notifyConfigurationReload: 1,
}); diff != nil {
t.Fatal(diff)
}