mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Move cluster logic out of vault package (#6601)
* Move cluster logic out of vault package * Dedup heartbeat and fix tests * Fix test
This commit is contained in:
		@@ -13,6 +13,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	"github.com/hashicorp/go-hclog"
 | 
						"github.com/hashicorp/go-hclog"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/vault/cluster"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log "github.com/hashicorp/go-hclog"
 | 
						log "github.com/hashicorp/go-hclog"
 | 
				
			||||||
	"github.com/hashicorp/vault/api"
 | 
						"github.com/hashicorp/vault/api"
 | 
				
			||||||
@@ -302,7 +303,7 @@ func GetPerfReplicatedClusters(t testing.T, conf *vault.CoreConfig, opts *vault.
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Set this lower so that state populates quickly to standby nodes
 | 
						// Set this lower so that state populates quickly to standby nodes
 | 
				
			||||||
	vault.HeartbeatInterval = 2 * time.Second
 | 
						cluster.HeartbeatInterval = 2 * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	opts1 := *opts
 | 
						opts1 := *opts
 | 
				
			||||||
	opts1.Logger = logger.Named("perf-pri")
 | 
						opts1.Logger = logger.Named("perf-pri")
 | 
				
			||||||
@@ -325,7 +326,7 @@ func GetFourReplicatedClusters(t testing.T, handlerFunc func(*vault.HandlerPrope
 | 
				
			|||||||
		Level: log.Trace,
 | 
							Level: log.Trace,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	// Set this lower so that state populates quickly to standby nodes
 | 
						// Set this lower so that state populates quickly to standby nodes
 | 
				
			||||||
	vault.HeartbeatInterval = 2 * time.Second
 | 
						cluster.HeartbeatInterval = 2 * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ret.PerfPrimaryCluster, _ = GetClusterAndCore(t, logger.Named("perf-pri"), handlerFunc)
 | 
						ret.PerfPrimaryCluster, _ = GetClusterAndCore(t, logger.Named("perf-pri"), handlerFunc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,4 +11,18 @@ const (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// AuthHeaderName is the name of the header containing the token.
 | 
						// AuthHeaderName is the name of the header containing the token.
 | 
				
			||||||
	AuthHeaderName = "X-Vault-Token"
 | 
						AuthHeaderName = "X-Vault-Token"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// PerformanceReplicationALPN is the negotiated protocol used for
 | 
				
			||||||
 | 
						// performance replication.
 | 
				
			||||||
 | 
						PerformanceReplicationALPN = "replication_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// DRReplicationALPN is the negotiated protocol used for
 | 
				
			||||||
 | 
						// dr replication.
 | 
				
			||||||
 | 
						DRReplicationALPN = "replication_dr_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						PerfStandbyALPN = "perf_standby_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						RequestForwardingALPN = "req_fw_sb-act_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						RaftStorageALPN = "raft_storage_v1"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										320
									
								
								vault/cluster.go
									
									
									
									
									
								
							
							
						
						
									
										320
									
								
								vault/cluster.go
									
									
									
									
									
								
							@@ -16,16 +16,13 @@ import (
 | 
				
			|||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"sync/atomic"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/hashicorp/errwrap"
 | 
						"github.com/hashicorp/errwrap"
 | 
				
			||||||
	log "github.com/hashicorp/go-hclog"
 | 
					 | 
				
			||||||
	uuid "github.com/hashicorp/go-uuid"
 | 
						uuid "github.com/hashicorp/go-uuid"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
 | 
						"github.com/hashicorp/vault/sdk/helper/jsonutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
	"golang.org/x/net/http2"
 | 
						"github.com/hashicorp/vault/vault/cluster"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@@ -282,296 +279,6 @@ func (c *Core) setupCluster(ctx context.Context) error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ClusterClient is used to lookup a client certificate.
 | 
					 | 
				
			||||||
type ClusterClient interface {
 | 
					 | 
				
			||||||
	ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ClusterHandler exposes functions for looking up TLS configuration and handing
 | 
					 | 
				
			||||||
// off a connection for a cluster listener application.
 | 
					 | 
				
			||||||
type ClusterHandler interface {
 | 
					 | 
				
			||||||
	ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
 | 
					 | 
				
			||||||
	CALookup(context.Context) (*x509.Certificate, error)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Handoff is used to pass the connection lifetime off to
 | 
					 | 
				
			||||||
	// the handler
 | 
					 | 
				
			||||||
	Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error
 | 
					 | 
				
			||||||
	Stop() error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ClusterListener is the source of truth for cluster handlers and connection
 | 
					 | 
				
			||||||
// clients. It dynamically builds the cluster TLS information. It's also
 | 
					 | 
				
			||||||
// responsible for starting tcp listeners and accepting new cluster connections.
 | 
					 | 
				
			||||||
type ClusterListener struct {
 | 
					 | 
				
			||||||
	handlers   map[string]ClusterHandler
 | 
					 | 
				
			||||||
	clients    map[string]ClusterClient
 | 
					 | 
				
			||||||
	shutdown   *uint32
 | 
					 | 
				
			||||||
	shutdownWg *sync.WaitGroup
 | 
					 | 
				
			||||||
	server     *http2.Server
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	clusterListenerAddrs []*net.TCPAddr
 | 
					 | 
				
			||||||
	clusterCipherSuites  []uint16
 | 
					 | 
				
			||||||
	logger               log.Logger
 | 
					 | 
				
			||||||
	l                    sync.RWMutex
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// AddClient adds a new client for an ALPN name
 | 
					 | 
				
			||||||
func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) {
 | 
					 | 
				
			||||||
	cl.l.Lock()
 | 
					 | 
				
			||||||
	cl.clients[alpn] = client
 | 
					 | 
				
			||||||
	cl.l.Unlock()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// RemoveClient removes the client for the specified ALPN name
 | 
					 | 
				
			||||||
func (cl *ClusterListener) RemoveClient(alpn string) {
 | 
					 | 
				
			||||||
	cl.l.Lock()
 | 
					 | 
				
			||||||
	delete(cl.clients, alpn)
 | 
					 | 
				
			||||||
	cl.l.Unlock()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// AddHandler registers a new cluster handler for the provided ALPN name.
 | 
					 | 
				
			||||||
func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) {
 | 
					 | 
				
			||||||
	cl.l.Lock()
 | 
					 | 
				
			||||||
	cl.handlers[alpn] = handler
 | 
					 | 
				
			||||||
	cl.l.Unlock()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// StopHandler stops the cluster handler for the provided ALPN name, it also
 | 
					 | 
				
			||||||
// calls stop on the handler.
 | 
					 | 
				
			||||||
func (cl *ClusterListener) StopHandler(alpn string) {
 | 
					 | 
				
			||||||
	cl.l.Lock()
 | 
					 | 
				
			||||||
	handler, ok := cl.handlers[alpn]
 | 
					 | 
				
			||||||
	delete(cl.handlers, alpn)
 | 
					 | 
				
			||||||
	cl.l.Unlock()
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		handler.Stop()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Server returns the http2 server that the cluster listener is using
 | 
					 | 
				
			||||||
func (cl *ClusterListener) Server() *http2.Server {
 | 
					 | 
				
			||||||
	return cl.server
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// TLSConfig returns a tls config object that uses dynamic lookups to correctly
 | 
					 | 
				
			||||||
// authenticate registered handlers/clients
 | 
					 | 
				
			||||||
func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) {
 | 
					 | 
				
			||||||
	serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
					 | 
				
			||||||
		cl.logger.Debug("performing server cert lookup")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.l.RLock()
 | 
					 | 
				
			||||||
		defer cl.l.RUnlock()
 | 
					 | 
				
			||||||
		for _, v := range clientHello.SupportedProtos {
 | 
					 | 
				
			||||||
			if handler, ok := cl.handlers[v]; ok {
 | 
					 | 
				
			||||||
				return handler.ServerLookup(ctx, clientHello)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos)
 | 
					 | 
				
			||||||
		return nil, errors.New("unsupported protocol")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
 | 
					 | 
				
			||||||
		cl.logger.Debug("performing client cert lookup")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.l.RLock()
 | 
					 | 
				
			||||||
		defer cl.l.RUnlock()
 | 
					 | 
				
			||||||
		for _, client := range cl.clients {
 | 
					 | 
				
			||||||
			cert, err := client.ClientLookup(ctx, requestInfo)
 | 
					 | 
				
			||||||
			if err == nil && cert != nil {
 | 
					 | 
				
			||||||
				return cert, nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.logger.Warn("no client information found")
 | 
					 | 
				
			||||||
		return nil, errors.New("no client cert found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
 | 
					 | 
				
			||||||
		caPool := x509.NewCertPool()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		ret := &tls.Config{
 | 
					 | 
				
			||||||
			ClientAuth:           tls.RequireAndVerifyClientCert,
 | 
					 | 
				
			||||||
			GetCertificate:       serverLookup,
 | 
					 | 
				
			||||||
			GetClientCertificate: clientLookup,
 | 
					 | 
				
			||||||
			MinVersion:           tls.VersionTLS12,
 | 
					 | 
				
			||||||
			RootCAs:              caPool,
 | 
					 | 
				
			||||||
			ClientCAs:            caPool,
 | 
					 | 
				
			||||||
			NextProtos:           clientHello.SupportedProtos,
 | 
					 | 
				
			||||||
			CipherSuites:         cl.clusterCipherSuites,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.l.RLock()
 | 
					 | 
				
			||||||
		defer cl.l.RUnlock()
 | 
					 | 
				
			||||||
		for _, v := range clientHello.SupportedProtos {
 | 
					 | 
				
			||||||
			if handler, ok := cl.handlers[v]; ok {
 | 
					 | 
				
			||||||
				ca, err := handler.CALookup(ctx)
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					return nil, err
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				caPool.AddCert(ca)
 | 
					 | 
				
			||||||
				return ret, nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos)
 | 
					 | 
				
			||||||
		return nil, errors.New("unsupported protocol")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return &tls.Config{
 | 
					 | 
				
			||||||
		ClientAuth:           tls.RequireAndVerifyClientCert,
 | 
					 | 
				
			||||||
		GetCertificate:       serverLookup,
 | 
					 | 
				
			||||||
		GetClientCertificate: clientLookup,
 | 
					 | 
				
			||||||
		GetConfigForClient:   serverConfigLookup,
 | 
					 | 
				
			||||||
		MinVersion:           tls.VersionTLS12,
 | 
					 | 
				
			||||||
		CipherSuites:         cl.clusterCipherSuites,
 | 
					 | 
				
			||||||
	}, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Run starts the tcp listeners and will accept connections until stop is
 | 
					 | 
				
			||||||
// called.
 | 
					 | 
				
			||||||
func (cl *ClusterListener) Run(ctx context.Context) error {
 | 
					 | 
				
			||||||
	// Get our TLS config
 | 
					 | 
				
			||||||
	tlsConfig, err := cl.TLSConfig(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The server supports all of the possible protos
 | 
					 | 
				
			||||||
	tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for i, laddr := range cl.clusterListenerAddrs {
 | 
					 | 
				
			||||||
		// closeCh is used to shutdown the spawned goroutines once this
 | 
					 | 
				
			||||||
		// function returns
 | 
					 | 
				
			||||||
		closeCh := make(chan struct{})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if cl.logger.IsInfo() {
 | 
					 | 
				
			||||||
			cl.logger.Info("starting listener", "listener_address", laddr)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Create a TCP listener. We do this separately and specifically
 | 
					 | 
				
			||||||
		// with TCP so that we can set deadlines.
 | 
					 | 
				
			||||||
		tcpLn, err := net.ListenTCP("tcp", laddr)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			cl.logger.Error("error starting listener", "error", err)
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if laddr.String() != tcpLn.Addr().String() {
 | 
					 | 
				
			||||||
			// If we listened on port 0, record the port the OS gave us.
 | 
					 | 
				
			||||||
			cl.clusterListenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Wrap the listener with TLS
 | 
					 | 
				
			||||||
		tlsLn := tls.NewListener(tcpLn, tlsConfig)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if cl.logger.IsInfo() {
 | 
					 | 
				
			||||||
			cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		cl.shutdownWg.Add(1)
 | 
					 | 
				
			||||||
		// Start our listening loop
 | 
					 | 
				
			||||||
		go func(closeCh chan struct{}, tlsLn net.Listener) {
 | 
					 | 
				
			||||||
			defer func() {
 | 
					 | 
				
			||||||
				cl.shutdownWg.Done()
 | 
					 | 
				
			||||||
				tlsLn.Close()
 | 
					 | 
				
			||||||
				close(closeCh)
 | 
					 | 
				
			||||||
			}()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			for {
 | 
					 | 
				
			||||||
				if atomic.LoadUint32(cl.shutdown) > 0 {
 | 
					 | 
				
			||||||
					return
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// Set the deadline for the accept call. If it passes we'll get
 | 
					 | 
				
			||||||
				// an error, causing us to check the condition at the top
 | 
					 | 
				
			||||||
				// again.
 | 
					 | 
				
			||||||
				tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// Accept the connection
 | 
					 | 
				
			||||||
				conn, err := tlsLn.Accept()
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					if err, ok := err.(net.Error); ok && !err.Timeout() {
 | 
					 | 
				
			||||||
						cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					if conn != nil {
 | 
					 | 
				
			||||||
						conn.Close()
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				if conn == nil {
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// Type assert to TLS connection and handshake to populate the
 | 
					 | 
				
			||||||
				// connection state
 | 
					 | 
				
			||||||
				tlsConn := conn.(*tls.Conn)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// Set a deadline for the handshake. This will cause clients
 | 
					 | 
				
			||||||
				// that don't successfully auth to be kicked out quickly.
 | 
					 | 
				
			||||||
				// Cluster connections should be reliable so being marginally
 | 
					 | 
				
			||||||
				// aggressive here is fine.
 | 
					 | 
				
			||||||
				err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					if cl.logger.IsDebug() {
 | 
					 | 
				
			||||||
						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					tlsConn.Close()
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				err = tlsConn.Handshake()
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					if cl.logger.IsDebug() {
 | 
					 | 
				
			||||||
						cl.logger.Debug("error handshaking cluster connection", "error", err)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					tlsConn.Close()
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				// Now, set it back to unlimited
 | 
					 | 
				
			||||||
				err = tlsConn.SetDeadline(time.Time{})
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					if cl.logger.IsDebug() {
 | 
					 | 
				
			||||||
						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					tlsConn.Close()
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				cl.l.RLock()
 | 
					 | 
				
			||||||
				handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
 | 
					 | 
				
			||||||
				cl.l.RUnlock()
 | 
					 | 
				
			||||||
				if !ok {
 | 
					 | 
				
			||||||
					cl.logger.Debug("unknown negotiated protocol on cluster port")
 | 
					 | 
				
			||||||
					tlsConn.Close()
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
				if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
 | 
					 | 
				
			||||||
					cl.logger.Error("error handling cluster connection", "error", err)
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}(closeCh, tlsLn)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Stop stops the cluster listner
 | 
					 | 
				
			||||||
func (cl *ClusterListener) Stop() {
 | 
					 | 
				
			||||||
	// Set the shutdown flag. This will cause the listeners to shut down
 | 
					 | 
				
			||||||
	// within the deadline in clusterListenerAcceptDeadline
 | 
					 | 
				
			||||||
	atomic.StoreUint32(cl.shutdown, 1)
 | 
					 | 
				
			||||||
	cl.logger.Info("forwarding rpc listeners stopped")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Wait for them all to shut down
 | 
					 | 
				
			||||||
	cl.shutdownWg.Wait()
 | 
					 | 
				
			||||||
	cl.logger.Info("rpc listeners successfully shut down")
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// startClusterListener starts cluster request listeners during unseal. It
 | 
					// startClusterListener starts cluster request listeners during unseal. It
 | 
				
			||||||
// is assumed that the state lock is held while this is run. Right now this
 | 
					// is assumed that the state lock is held while this is run. Right now this
 | 
				
			||||||
// only starts cluster listeners. Once the listener is started handlers/clients
 | 
					// only starts cluster listeners. Once the listener is started handlers/clients
 | 
				
			||||||
@@ -589,27 +296,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	c.logger.Debug("starting cluster listeners")
 | 
						c.logger.Debug("starting cluster listeners")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Create the HTTP/2 server that will be shared by both RPC and regular
 | 
						c.clusterListener = cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener"))
 | 
				
			||||||
	// duties. Doing it this way instead of listening via the server and gRPC
 | 
					 | 
				
			||||||
	// allows us to re-use the same port via ALPN. We can just tell the server
 | 
					 | 
				
			||||||
	// to serve a given conn and which handler to use.
 | 
					 | 
				
			||||||
	h2Server := &http2.Server{
 | 
					 | 
				
			||||||
		// Our forwarding connections heartbeat regularly so anything else we
 | 
					 | 
				
			||||||
		// want to go away/get cleaned up pretty rapidly
 | 
					 | 
				
			||||||
		IdleTimeout: 5 * HeartbeatInterval,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	c.clusterListener = &ClusterListener{
 | 
					 | 
				
			||||||
		handlers:   make(map[string]ClusterHandler),
 | 
					 | 
				
			||||||
		clients:    make(map[string]ClusterClient),
 | 
					 | 
				
			||||||
		shutdown:   new(uint32),
 | 
					 | 
				
			||||||
		shutdownWg: &sync.WaitGroup{},
 | 
					 | 
				
			||||||
		server:     h2Server,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		clusterListenerAddrs: c.clusterListenerAddrs,
 | 
					 | 
				
			||||||
		clusterCipherSuites:  c.clusterCipherSuites,
 | 
					 | 
				
			||||||
		logger:               c.logger.Named("cluster-listener"),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := c.clusterListener.Run(ctx)
 | 
						err := c.clusterListener.Run(ctx)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -617,7 +304,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	if strings.HasSuffix(c.clusterAddr, ":0") {
 | 
						if strings.HasSuffix(c.clusterAddr, ":0") {
 | 
				
			||||||
		// If we listened on port 0, record the port the OS gave us.
 | 
							// If we listened on port 0, record the port the OS gave us.
 | 
				
			||||||
		c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.clusterListenerAddrs[0])
 | 
							c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0])
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -633,6 +320,7 @@ func (c *Core) stopClusterListener() {
 | 
				
			|||||||
	c.logger.Info("stopping cluster listeners")
 | 
						c.logger.Info("stopping cluster listeners")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.clusterListener.Stop()
 | 
						c.clusterListener.Stop()
 | 
				
			||||||
 | 
						c.clusterListener = nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.logger.Info("cluster listeners successfully shut down")
 | 
						c.logger.Info("cluster listeners successfully shut down")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										343
									
								
								vault/cluster/cluster.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										343
									
								
								vault/cluster/cluster.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,343 @@
 | 
				
			|||||||
 | 
					package cluster
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"crypto/x509"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log "github.com/hashicorp/go-hclog"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/sdk/helper/consts"
 | 
				
			||||||
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						// Making this a package var allows tests to modify
 | 
				
			||||||
 | 
						HeartbeatInterval = 5 * time.Second
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						ListenerAcceptDeadline = 500 * time.Millisecond
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Client is used to lookup a client certificate.
 | 
				
			||||||
 | 
					type Client interface {
 | 
				
			||||||
 | 
						ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Handler exposes functions for looking up TLS configuration and handing
 | 
				
			||||||
 | 
					// off a connection for a cluster listener application.
 | 
				
			||||||
 | 
					type Handler interface {
 | 
				
			||||||
 | 
						ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
 | 
				
			||||||
 | 
						CALookup(context.Context) (*x509.Certificate, error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Handoff is used to pass the connection lifetime off to
 | 
				
			||||||
 | 
						// the handler
 | 
				
			||||||
 | 
						Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error
 | 
				
			||||||
 | 
						Stop() error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Listener is the source of truth for cluster handlers and connection
 | 
				
			||||||
 | 
					// clients. It dynamically builds the cluster TLS information. It's also
 | 
				
			||||||
 | 
					// responsible for starting tcp listeners and accepting new cluster connections.
 | 
				
			||||||
 | 
					type Listener struct {
 | 
				
			||||||
 | 
						handlers   map[string]Handler
 | 
				
			||||||
 | 
						clients    map[string]Client
 | 
				
			||||||
 | 
						shutdown   *uint32
 | 
				
			||||||
 | 
						shutdownWg *sync.WaitGroup
 | 
				
			||||||
 | 
						server     *http2.Server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						listenerAddrs []*net.TCPAddr
 | 
				
			||||||
 | 
						cipherSuites  []uint16
 | 
				
			||||||
 | 
						logger        log.Logger
 | 
				
			||||||
 | 
						l             sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener {
 | 
				
			||||||
 | 
						// Create the HTTP/2 server that will be shared by both RPC and regular
 | 
				
			||||||
 | 
						// duties. Doing it this way instead of listening via the server and gRPC
 | 
				
			||||||
 | 
						// allows us to re-use the same port via ALPN. We can just tell the server
 | 
				
			||||||
 | 
						// to serve a given conn and which handler to use.
 | 
				
			||||||
 | 
						h2Server := &http2.Server{
 | 
				
			||||||
 | 
							// Our forwarding connections heartbeat regularly so anything else we
 | 
				
			||||||
 | 
							// want to go away/get cleaned up pretty rapidly
 | 
				
			||||||
 | 
							IdleTimeout: 5 * HeartbeatInterval,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &Listener{
 | 
				
			||||||
 | 
							handlers:   make(map[string]Handler),
 | 
				
			||||||
 | 
							clients:    make(map[string]Client),
 | 
				
			||||||
 | 
							shutdown:   new(uint32),
 | 
				
			||||||
 | 
							shutdownWg: &sync.WaitGroup{},
 | 
				
			||||||
 | 
							server:     h2Server,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							listenerAddrs: addrs,
 | 
				
			||||||
 | 
							cipherSuites:  cipherSuites,
 | 
				
			||||||
 | 
							logger:        logger,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cl *Listener) Addrs() []*net.TCPAddr {
 | 
				
			||||||
 | 
						return cl.listenerAddrs
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddClient adds a new client for an ALPN name
 | 
				
			||||||
 | 
					func (cl *Listener) AddClient(alpn string, client Client) {
 | 
				
			||||||
 | 
						cl.l.Lock()
 | 
				
			||||||
 | 
						cl.clients[alpn] = client
 | 
				
			||||||
 | 
						cl.l.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RemoveClient removes the client for the specified ALPN name
 | 
				
			||||||
 | 
					func (cl *Listener) RemoveClient(alpn string) {
 | 
				
			||||||
 | 
						cl.l.Lock()
 | 
				
			||||||
 | 
						delete(cl.clients, alpn)
 | 
				
			||||||
 | 
						cl.l.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddHandler registers a new cluster handler for the provided ALPN name.
 | 
				
			||||||
 | 
					func (cl *Listener) AddHandler(alpn string, handler Handler) {
 | 
				
			||||||
 | 
						cl.l.Lock()
 | 
				
			||||||
 | 
						cl.handlers[alpn] = handler
 | 
				
			||||||
 | 
						cl.l.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// StopHandler stops the cluster handler for the provided ALPN name, it also
 | 
				
			||||||
 | 
					// calls stop on the handler.
 | 
				
			||||||
 | 
					func (cl *Listener) StopHandler(alpn string) {
 | 
				
			||||||
 | 
						cl.l.Lock()
 | 
				
			||||||
 | 
						handler, ok := cl.handlers[alpn]
 | 
				
			||||||
 | 
						delete(cl.handlers, alpn)
 | 
				
			||||||
 | 
						cl.l.Unlock()
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							handler.Stop()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Server returns the http2 server that the cluster listener is using
 | 
				
			||||||
 | 
					func (cl *Listener) Server() *http2.Server {
 | 
				
			||||||
 | 
						return cl.server
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TLSConfig returns a tls config object that uses dynamic lookups to correctly
 | 
				
			||||||
 | 
					// authenticate registered handlers/clients
 | 
				
			||||||
 | 
					func (cl *Listener) TLSConfig(ctx context.Context) (*tls.Config, error) {
 | 
				
			||||||
 | 
						serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
				
			||||||
 | 
							cl.logger.Debug("performing server cert lookup")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.l.RLock()
 | 
				
			||||||
 | 
							defer cl.l.RUnlock()
 | 
				
			||||||
 | 
							for _, v := range clientHello.SupportedProtos {
 | 
				
			||||||
 | 
								if handler, ok := cl.handlers[v]; ok {
 | 
				
			||||||
 | 
									return handler.ServerLookup(ctx, clientHello)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos)
 | 
				
			||||||
 | 
							return nil, errors.New("unsupported protocol")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
 | 
				
			||||||
 | 
							cl.logger.Debug("performing client cert lookup")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.l.RLock()
 | 
				
			||||||
 | 
							defer cl.l.RUnlock()
 | 
				
			||||||
 | 
							for _, client := range cl.clients {
 | 
				
			||||||
 | 
								cert, err := client.ClientLookup(ctx, requestInfo)
 | 
				
			||||||
 | 
								if err == nil && cert != nil {
 | 
				
			||||||
 | 
									return cert, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.logger.Warn("no client information found")
 | 
				
			||||||
 | 
							return nil, errors.New("no client cert found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
 | 
				
			||||||
 | 
							caPool := x509.NewCertPool()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							ret := &tls.Config{
 | 
				
			||||||
 | 
								ClientAuth:           tls.RequireAndVerifyClientCert,
 | 
				
			||||||
 | 
								GetCertificate:       serverLookup,
 | 
				
			||||||
 | 
								GetClientCertificate: clientLookup,
 | 
				
			||||||
 | 
								MinVersion:           tls.VersionTLS12,
 | 
				
			||||||
 | 
								RootCAs:              caPool,
 | 
				
			||||||
 | 
								ClientCAs:            caPool,
 | 
				
			||||||
 | 
								NextProtos:           clientHello.SupportedProtos,
 | 
				
			||||||
 | 
								CipherSuites:         cl.cipherSuites,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.l.RLock()
 | 
				
			||||||
 | 
							defer cl.l.RUnlock()
 | 
				
			||||||
 | 
							for _, v := range clientHello.SupportedProtos {
 | 
				
			||||||
 | 
								if handler, ok := cl.handlers[v]; ok {
 | 
				
			||||||
 | 
									ca, err := handler.CALookup(ctx)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return nil, err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									caPool.AddCert(ca)
 | 
				
			||||||
 | 
									return ret, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos)
 | 
				
			||||||
 | 
							return nil, errors.New("unsupported protocol")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &tls.Config{
 | 
				
			||||||
 | 
							ClientAuth:           tls.RequireAndVerifyClientCert,
 | 
				
			||||||
 | 
							GetCertificate:       serverLookup,
 | 
				
			||||||
 | 
							GetClientCertificate: clientLookup,
 | 
				
			||||||
 | 
							GetConfigForClient:   serverConfigLookup,
 | 
				
			||||||
 | 
							MinVersion:           tls.VersionTLS12,
 | 
				
			||||||
 | 
							CipherSuites:         cl.cipherSuites,
 | 
				
			||||||
 | 
						}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Run starts the tcp listeners and will accept connections until stop is
 | 
				
			||||||
 | 
					// called.
 | 
				
			||||||
 | 
					func (cl *Listener) Run(ctx context.Context) error {
 | 
				
			||||||
 | 
						// Get our TLS config
 | 
				
			||||||
 | 
						tlsConfig, err := cl.TLSConfig(ctx)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The server supports all of the possible protos
 | 
				
			||||||
 | 
						tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i, laddr := range cl.listenerAddrs {
 | 
				
			||||||
 | 
							// closeCh is used to shutdown the spawned goroutines once this
 | 
				
			||||||
 | 
							// function returns
 | 
				
			||||||
 | 
							closeCh := make(chan struct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if cl.logger.IsInfo() {
 | 
				
			||||||
 | 
								cl.logger.Info("starting listener", "listener_address", laddr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create a TCP listener. We do this separately and specifically
 | 
				
			||||||
 | 
							// with TCP so that we can set deadlines.
 | 
				
			||||||
 | 
							tcpLn, err := net.ListenTCP("tcp", laddr)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								cl.logger.Error("error starting listener", "error", err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if laddr.String() != tcpLn.Addr().String() {
 | 
				
			||||||
 | 
								// If we listened on port 0, record the port the OS gave us.
 | 
				
			||||||
 | 
								cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Wrap the listener with TLS
 | 
				
			||||||
 | 
							tlsLn := tls.NewListener(tcpLn, tlsConfig)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if cl.logger.IsInfo() {
 | 
				
			||||||
 | 
								cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cl.shutdownWg.Add(1)
 | 
				
			||||||
 | 
							// Start our listening loop
 | 
				
			||||||
 | 
							go func(closeCh chan struct{}, tlsLn net.Listener) {
 | 
				
			||||||
 | 
								defer func() {
 | 
				
			||||||
 | 
									cl.shutdownWg.Done()
 | 
				
			||||||
 | 
									tlsLn.Close()
 | 
				
			||||||
 | 
									close(closeCh)
 | 
				
			||||||
 | 
								}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for {
 | 
				
			||||||
 | 
									if atomic.LoadUint32(cl.shutdown) > 0 {
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Set the deadline for the accept call. If it passes we'll get
 | 
				
			||||||
 | 
									// an error, causing us to check the condition at the top
 | 
				
			||||||
 | 
									// again.
 | 
				
			||||||
 | 
									tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Accept the connection
 | 
				
			||||||
 | 
									conn, err := tlsLn.Accept()
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										if err, ok := err.(net.Error); ok && !err.Timeout() {
 | 
				
			||||||
 | 
											cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										if conn != nil {
 | 
				
			||||||
 | 
											conn.Close()
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if conn == nil {
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Type assert to TLS connection and handshake to populate the
 | 
				
			||||||
 | 
									// connection state
 | 
				
			||||||
 | 
									tlsConn := conn.(*tls.Conn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Set a deadline for the handshake. This will cause clients
 | 
				
			||||||
 | 
									// that don't successfully auth to be kicked out quickly.
 | 
				
			||||||
 | 
									// Cluster connections should be reliable so being marginally
 | 
				
			||||||
 | 
									// aggressive here is fine.
 | 
				
			||||||
 | 
									err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										if cl.logger.IsDebug() {
 | 
				
			||||||
 | 
											cl.logger.Debug("error setting deadline for cluster connection", "error", err)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										tlsConn.Close()
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tlsConn.Handshake()
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										if cl.logger.IsDebug() {
 | 
				
			||||||
 | 
											cl.logger.Debug("error handshaking cluster connection", "error", err)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										tlsConn.Close()
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Now, set it back to unlimited
 | 
				
			||||||
 | 
									err = tlsConn.SetDeadline(time.Time{})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										if cl.logger.IsDebug() {
 | 
				
			||||||
 | 
											cl.logger.Debug("error setting deadline for cluster connection", "error", err)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										tlsConn.Close()
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									cl.l.RLock()
 | 
				
			||||||
 | 
									handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
 | 
				
			||||||
 | 
									cl.l.RUnlock()
 | 
				
			||||||
 | 
									if !ok {
 | 
				
			||||||
 | 
										cl.logger.Debug("unknown negotiated protocol on cluster port")
 | 
				
			||||||
 | 
										tlsConn.Close()
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
 | 
				
			||||||
 | 
										cl.logger.Error("error handling cluster connection", "error", err)
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}(closeCh, tlsLn)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Stop stops the cluster listner
 | 
				
			||||||
 | 
					func (cl *Listener) Stop() {
 | 
				
			||||||
 | 
						// Set the shutdown flag. This will cause the listeners to shut down
 | 
				
			||||||
 | 
						// within the deadline in clusterListenerAcceptDeadline
 | 
				
			||||||
 | 
						atomic.StoreUint32(cl.shutdown, 1)
 | 
				
			||||||
 | 
						cl.logger.Info("forwarding rpc listeners stopped")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Wait for them all to shut down
 | 
				
			||||||
 | 
						cl.shutdownWg.Wait()
 | 
				
			||||||
 | 
						cl.logger.Info("rpc listeners successfully shut down")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -101,17 +101,17 @@ func TestCluster_ListenForRequests(t *testing.T) {
 | 
				
			|||||||
	// Wait for core to become active
 | 
						// Wait for core to become active
 | 
				
			||||||
	TestWaitActive(t, cores[0].Core)
 | 
						TestWaitActive(t, cores[0].Core)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
 | 
				
			||||||
 | 
						addrs := cores[0].clusterListener.Addrs()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
 | 
						// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
 | 
				
			||||||
	checkListenersFunc := func(expectFail bool) {
 | 
						checkListenersFunc := func(expectFail bool) {
 | 
				
			||||||
		cores[0].clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
 | 
							parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
 | 
				
			||||||
		dialer := cores[0].getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 | 
							dialer := cores[0].getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 | 
				
			||||||
		for i := range cores[0].Listeners {
 | 
							for i := range cores[0].Listeners {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			clnAddr := cores[0].clusterListener.clusterListenerAddrs[i]
 | 
								clnAddr := addrs[i]
 | 
				
			||||||
			netConn, err := dialer(clnAddr.String(), 0)
 | 
								netConn, err := dialer(clnAddr.String(), 0)
 | 
				
			||||||
			conn := netConn.(*tls.Conn)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				if expectFail {
 | 
									if expectFail {
 | 
				
			||||||
					t.Logf("testing %s unsuccessful as expected", clnAddr)
 | 
										t.Logf("testing %s unsuccessful as expected", clnAddr)
 | 
				
			||||||
@@ -122,6 +122,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
 | 
				
			|||||||
			if expectFail {
 | 
								if expectFail {
 | 
				
			||||||
				t.Fatalf("testing %s not unsuccessful as expected", clnAddr)
 | 
									t.Fatalf("testing %s not unsuccessful as expected", clnAddr)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								conn := netConn.(*tls.Conn)
 | 
				
			||||||
			err = conn.Handshake()
 | 
								err = conn.Handshake()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Fatal(err)
 | 
									t.Fatal(err)
 | 
				
			||||||
@@ -130,7 +131,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
 | 
				
			|||||||
			switch {
 | 
								switch {
 | 
				
			||||||
			case connState.Version != tls.VersionTLS12:
 | 
								case connState.Version != tls.VersionTLS12:
 | 
				
			||||||
				t.Fatal("version mismatch")
 | 
									t.Fatal("version mismatch")
 | 
				
			||||||
			case connState.NegotiatedProtocol != requestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
 | 
								case connState.NegotiatedProtocol != consts.RequestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
 | 
				
			||||||
				t.Fatal("bad protocol negotiation")
 | 
									t.Fatal("bad protocol negotiation")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			t.Logf("testing %s successful", clnAddr)
 | 
								t.Logf("testing %s successful", clnAddr)
 | 
				
			||||||
@@ -155,7 +156,8 @@ func TestCluster_ListenForRequests(t *testing.T) {
 | 
				
			|||||||
	checkListenersFunc(true)
 | 
						checkListenersFunc(true)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// After this period it should be active again
 | 
						// After this period it should be active again
 | 
				
			||||||
	time.Sleep(manualStepDownSleepPeriod)
 | 
						TestWaitActive(t, cores[0].Core)
 | 
				
			||||||
 | 
						cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
 | 
				
			||||||
	checkListenersFunc(false)
 | 
						checkListenersFunc(false)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = cores[0].Core.Seal(cluster.RootToken)
 | 
						err = cores[0].Core.Seal(cluster.RootToken)
 | 
				
			||||||
@@ -382,12 +384,12 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
 | 
				
			|||||||
	// Wait for core to become active
 | 
						// Wait for core to become active
 | 
				
			||||||
	TestWaitActive(t, core.Core)
 | 
						TestWaitActive(t, core.Core)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	core.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core.Core})
 | 
						core.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
 | 
						parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
 | 
				
			||||||
	dialer := core.getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 | 
						dialer := core.getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	netConn, err := dialer(core.clusterListener.clusterListenerAddrs[0].String(), 0)
 | 
						netConn, err := dialer(core.clusterListener.Addrs()[0].String(), 0)
 | 
				
			||||||
	conn := netConn.(*tls.Conn)
 | 
						conn := netConn.(*tls.Conn)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -38,6 +38,7 @@ import (
 | 
				
			|||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/physical"
 | 
						"github.com/hashicorp/vault/sdk/physical"
 | 
				
			||||||
	"github.com/hashicorp/vault/shamir"
 | 
						"github.com/hashicorp/vault/shamir"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/vault/cluster"
 | 
				
			||||||
	"github.com/hashicorp/vault/vault/seal"
 | 
						"github.com/hashicorp/vault/vault/seal"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -419,7 +420,7 @@ type Core struct {
 | 
				
			|||||||
	loadCaseSensitiveIdentityStore bool
 | 
						loadCaseSensitiveIdentityStore bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// clusterListener starts up and manages connections on the cluster ports
 | 
						// clusterListener starts up and manages connections on the cluster ports
 | 
				
			||||||
	clusterListener *ClusterListener
 | 
						clusterListener *cluster.Listener
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Telemetry objects
 | 
						// Telemetry objects
 | 
				
			||||||
	metricsHelper *metricsutil.MetricsHelper
 | 
						metricsHelper *metricsutil.MetricsHelper
 | 
				
			||||||
@@ -592,7 +593,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 | 
				
			|||||||
		maxLeaseTTL:                  conf.MaxLeaseTTL,
 | 
							maxLeaseTTL:                  conf.MaxLeaseTTL,
 | 
				
			||||||
		cachingDisabled:              conf.DisableCache,
 | 
							cachingDisabled:              conf.DisableCache,
 | 
				
			||||||
		clusterName:                  conf.ClusterName,
 | 
							clusterName:                  conf.ClusterName,
 | 
				
			||||||
		clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second),
 | 
							clusterPeerClusterAddrsCache: cache.New(3*cluster.HeartbeatInterval, time.Second),
 | 
				
			||||||
		enableMlock:                  !conf.DisableMlock,
 | 
							enableMlock:                  !conf.DisableMlock,
 | 
				
			||||||
		rawEnabled:                   conf.EnableRaw,
 | 
							rawEnabled:                   conf.EnableRaw,
 | 
				
			||||||
		replicationState:             new(uint32),
 | 
							replicationState:             new(uint32),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,6 +10,7 @@ import (
 | 
				
			|||||||
	"github.com/hashicorp/vault/sdk/helper/license"
 | 
						"github.com/hashicorp/vault/sdk/helper/license"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/physical"
 | 
						"github.com/hashicorp/vault/sdk/physical"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/vault/cluster"
 | 
				
			||||||
	"github.com/hashicorp/vault/vault/replication"
 | 
						"github.com/hashicorp/vault/vault/replication"
 | 
				
			||||||
	cache "github.com/patrickmn/go-cache"
 | 
						cache "github.com/patrickmn/go-cache"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -111,5 +112,5 @@ func (c *Core) invalidateSentinelPolicy(PolicyType, string) {}
 | 
				
			|||||||
func (c *Core) removePerfStandbySecondary(context.Context, string) {}
 | 
					func (c *Core) removePerfStandbySecondary(context.Context, string) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) {
 | 
					func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) {
 | 
				
			||||||
	return nil, cache.New(2*HeartbeatInterval, 1*time.Second), make(chan struct{}), nil
 | 
						return nil, cache.New(2*cluster.HeartbeatInterval, 1*time.Second), make(chan struct{}), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,6 +16,8 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	log "github.com/hashicorp/go-hclog"
 | 
						log "github.com/hashicorp/go-hclog"
 | 
				
			||||||
	"github.com/hashicorp/vault/helper/forwarding"
 | 
						"github.com/hashicorp/vault/helper/forwarding"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/sdk/helper/consts"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/vault/cluster"
 | 
				
			||||||
	"github.com/hashicorp/vault/vault/replication"
 | 
						"github.com/hashicorp/vault/vault/replication"
 | 
				
			||||||
	cache "github.com/patrickmn/go-cache"
 | 
						cache "github.com/patrickmn/go-cache"
 | 
				
			||||||
	"golang.org/x/net/http2"
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
@@ -23,27 +25,6 @@ import (
 | 
				
			|||||||
	"google.golang.org/grpc/keepalive"
 | 
						"google.golang.org/grpc/keepalive"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	clusterListenerAcceptDeadline = 500 * time.Millisecond
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// PerformanceReplicationALPN is the negotiated protocol used for
 | 
					 | 
				
			||||||
	// performance replication.
 | 
					 | 
				
			||||||
	PerformanceReplicationALPN = "replication_v1"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// DRReplicationALPN is the negotiated protocol used for
 | 
					 | 
				
			||||||
	// dr replication.
 | 
					 | 
				
			||||||
	DRReplicationALPN = "replication_dr_v1"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	perfStandbyALPN = "perf_standby_v1"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	requestForwardingALPN = "req_fw_sb-act_v1"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var (
 | 
					 | 
				
			||||||
	// Making this a package var allows tests to modify
 | 
					 | 
				
			||||||
	HeartbeatInterval = 5 * time.Second
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type requestForwardingHandler struct {
 | 
					type requestForwardingHandler struct {
 | 
				
			||||||
	fws         *http2.Server
 | 
						fws         *http2.Server
 | 
				
			||||||
	fwRPCServer *grpc.Server
 | 
						fwRPCServer *grpc.Server
 | 
				
			||||||
@@ -65,7 +46,7 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	fwRPCServer := grpc.NewServer(
 | 
						fwRPCServer := grpc.NewServer(
 | 
				
			||||||
		grpc.KeepaliveParams(keepalive.ServerParameters{
 | 
							grpc.KeepaliveParams(keepalive.ServerParameters{
 | 
				
			||||||
			Time: 2 * HeartbeatInterval,
 | 
								Time: 2 * cluster.HeartbeatInterval,
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		grpc.MaxRecvMsgSize(math.MaxInt32),
 | 
							grpc.MaxRecvMsgSize(math.MaxInt32),
 | 
				
			||||||
		grpc.MaxSendMsgSize(math.MaxInt32),
 | 
							grpc.MaxSendMsgSize(math.MaxInt32),
 | 
				
			||||||
@@ -190,7 +171,7 @@ func (rf *requestForwardingHandler) Handoff(ctx context.Context, shutdownWg *syn
 | 
				
			|||||||
// Stop stops the request forwarding server and closes connections.
 | 
					// Stop stops the request forwarding server and closes connections.
 | 
				
			||||||
func (rf *requestForwardingHandler) Stop() error {
 | 
					func (rf *requestForwardingHandler) Stop() error {
 | 
				
			||||||
	// Give some time for existing RPCs to drain.
 | 
						// Give some time for existing RPCs to drain.
 | 
				
			||||||
	time.Sleep(clusterListenerAcceptDeadline)
 | 
						time.Sleep(cluster.ListenerAcceptDeadline)
 | 
				
			||||||
	close(rf.stopCh)
 | 
						close(rf.stopCh)
 | 
				
			||||||
	rf.fwRPCServer.Stop()
 | 
						rf.fwRPCServer.Stop()
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
@@ -198,16 +179,16 @@ func (rf *requestForwardingHandler) Stop() error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Starts the listeners and servers necessary to handle forwarded requests
 | 
					// Starts the listeners and servers necessary to handle forwarded requests
 | 
				
			||||||
func (c *Core) startForwarding(ctx context.Context) error {
 | 
					func (c *Core) startForwarding(ctx context.Context) error {
 | 
				
			||||||
	c.logger.Debug("cluster listener setup function")
 | 
						c.logger.Debug("request forwarding setup function")
 | 
				
			||||||
	defer c.logger.Debug("leaving cluster listener setup function")
 | 
						defer c.logger.Debug("leaving request forwarding setup function")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Clean up in case we have transitioned from a client to a server
 | 
						// Clean up in case we have transitioned from a client to a server
 | 
				
			||||||
	c.requestForwardingConnectionLock.Lock()
 | 
						c.requestForwardingConnectionLock.Lock()
 | 
				
			||||||
	c.clearForwardingClients()
 | 
						c.clearForwardingClients()
 | 
				
			||||||
	c.requestForwardingConnectionLock.Unlock()
 | 
						c.requestForwardingConnectionLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Resolve locally to avoid races
 | 
					 | 
				
			||||||
	if c.ha == nil || c.clusterListener == nil {
 | 
						if c.ha == nil || c.clusterListener == nil {
 | 
				
			||||||
 | 
							c.logger.Debug("request forwarding not setup")
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -221,15 +202,15 @@ func (c *Core) startForwarding(ctx context.Context) error {
 | 
				
			|||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.clusterListener.AddHandler(requestForwardingALPN, handler)
 | 
						c.clusterListener.AddHandler(consts.RequestForwardingALPN, handler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Core) stopForwarding() {
 | 
					func (c *Core) stopForwarding() {
 | 
				
			||||||
	if c.clusterListener != nil {
 | 
						if c.clusterListener != nil {
 | 
				
			||||||
		c.clusterListener.StopHandler(requestForwardingALPN)
 | 
							c.clusterListener.StopHandler(consts.RequestForwardingALPN)
 | 
				
			||||||
		c.clusterListener.StopHandler(perfStandbyALPN)
 | 
							c.clusterListener.StopHandler(consts.PerfStandbyALPN)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -264,7 +245,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c.clusterListener != nil {
 | 
						if c.clusterListener != nil {
 | 
				
			||||||
		c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{
 | 
							c.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
 | 
				
			||||||
			core: c,
 | 
								core: c,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -275,10 +256,10 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 | 
				
			|||||||
	// the TLS state.
 | 
						// the TLS state.
 | 
				
			||||||
	dctx, cancelFunc := context.WithCancel(ctx)
 | 
						dctx, cancelFunc := context.WithCancel(ctx)
 | 
				
			||||||
	c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
 | 
						c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
 | 
				
			||||||
		grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
 | 
							grpc.WithDialer(c.getGRPCDialer(ctx, consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
 | 
				
			||||||
		grpc.WithInsecure(), // it's not, we handle it in the dialer
 | 
							grpc.WithInsecure(), // it's not, we handle it in the dialer
 | 
				
			||||||
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
 | 
							grpc.WithKeepaliveParams(keepalive.ClientParameters{
 | 
				
			||||||
			Time: 2 * HeartbeatInterval,
 | 
								Time: 2 * cluster.HeartbeatInterval,
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		grpc.WithDefaultCallOptions(
 | 
							grpc.WithDefaultCallOptions(
 | 
				
			||||||
			grpc.MaxCallRecvMsgSize(math.MaxInt32),
 | 
								grpc.MaxCallRecvMsgSize(math.MaxInt32),
 | 
				
			||||||
@@ -294,7 +275,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 | 
				
			|||||||
	c.rpcForwardingClient = &forwardingClient{
 | 
						c.rpcForwardingClient = &forwardingClient{
 | 
				
			||||||
		RequestForwardingClient: NewRequestForwardingClient(c.rpcClientConn),
 | 
							RequestForwardingClient: NewRequestForwardingClient(c.rpcClientConn),
 | 
				
			||||||
		core:                    c,
 | 
							core:                    c,
 | 
				
			||||||
		echoTicker:              time.NewTicker(HeartbeatInterval),
 | 
							echoTicker:              time.NewTicker(cluster.HeartbeatInterval),
 | 
				
			||||||
		echoContext:             dctx,
 | 
							echoContext:             dctx,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.rpcForwardingClient.startHeartbeat()
 | 
						c.rpcForwardingClient.startHeartbeat()
 | 
				
			||||||
@@ -319,7 +300,7 @@ func (c *Core) clearForwardingClients() {
 | 
				
			|||||||
	c.rpcForwardingClient = nil
 | 
						c.rpcForwardingClient = nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c.clusterListener != nil {
 | 
						if c.clusterListener != nil {
 | 
				
			||||||
		c.clusterListener.RemoveClient(requestForwardingALPN)
 | 
							c.clusterListener.RemoveClient(consts.RequestForwardingALPN)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
 | 
						c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										14
									
								
								vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go
									
									
									
										generated
									
									
										vendored
									
									
								
							@@ -11,4 +11,18 @@ const (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// AuthHeaderName is the name of the header containing the token.
 | 
						// AuthHeaderName is the name of the header containing the token.
 | 
				
			||||||
	AuthHeaderName = "X-Vault-Token"
 | 
						AuthHeaderName = "X-Vault-Token"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// PerformanceReplicationALPN is the negotiated protocol used for
 | 
				
			||||||
 | 
						// performance replication.
 | 
				
			||||||
 | 
						PerformanceReplicationALPN = "replication_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// DRReplicationALPN is the negotiated protocol used for
 | 
				
			||||||
 | 
						// dr replication.
 | 
				
			||||||
 | 
						DRReplicationALPN = "replication_dr_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						PerfStandbyALPN = "perf_standby_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						RequestForwardingALPN = "req_fw_sb-act_v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						RaftStorageALPN = "raft_storage_v1"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user