mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Move cluster logic out of vault package (#6601)
* Move cluster logic out of vault package * Dedup heartbeat and fix tests * Fix test
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
|
"github.com/hashicorp/vault/vault/cluster"
|
||||||
|
|
||||||
log "github.com/hashicorp/go-hclog"
|
log "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
@@ -302,7 +303,7 @@ func GetPerfReplicatedClusters(t testing.T, conf *vault.CoreConfig, opts *vault.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set this lower so that state populates quickly to standby nodes
|
// Set this lower so that state populates quickly to standby nodes
|
||||||
vault.HeartbeatInterval = 2 * time.Second
|
cluster.HeartbeatInterval = 2 * time.Second
|
||||||
|
|
||||||
opts1 := *opts
|
opts1 := *opts
|
||||||
opts1.Logger = logger.Named("perf-pri")
|
opts1.Logger = logger.Named("perf-pri")
|
||||||
@@ -325,7 +326,7 @@ func GetFourReplicatedClusters(t testing.T, handlerFunc func(*vault.HandlerPrope
|
|||||||
Level: log.Trace,
|
Level: log.Trace,
|
||||||
})
|
})
|
||||||
// Set this lower so that state populates quickly to standby nodes
|
// Set this lower so that state populates quickly to standby nodes
|
||||||
vault.HeartbeatInterval = 2 * time.Second
|
cluster.HeartbeatInterval = 2 * time.Second
|
||||||
|
|
||||||
ret.PerfPrimaryCluster, _ = GetClusterAndCore(t, logger.Named("perf-pri"), handlerFunc)
|
ret.PerfPrimaryCluster, _ = GetClusterAndCore(t, logger.Named("perf-pri"), handlerFunc)
|
||||||
|
|
||||||
|
|||||||
@@ -11,4 +11,18 @@ const (
|
|||||||
|
|
||||||
// AuthHeaderName is the name of the header containing the token.
|
// AuthHeaderName is the name of the header containing the token.
|
||||||
AuthHeaderName = "X-Vault-Token"
|
AuthHeaderName = "X-Vault-Token"
|
||||||
|
|
||||||
|
// PerformanceReplicationALPN is the negotiated protocol used for
|
||||||
|
// performance replication.
|
||||||
|
PerformanceReplicationALPN = "replication_v1"
|
||||||
|
|
||||||
|
// DRReplicationALPN is the negotiated protocol used for
|
||||||
|
// dr replication.
|
||||||
|
DRReplicationALPN = "replication_dr_v1"
|
||||||
|
|
||||||
|
PerfStandbyALPN = "perf_standby_v1"
|
||||||
|
|
||||||
|
RequestForwardingALPN = "req_fw_sb-act_v1"
|
||||||
|
|
||||||
|
RaftStorageALPN = "raft_storage_v1"
|
||||||
)
|
)
|
||||||
|
|||||||
320
vault/cluster.go
320
vault/cluster.go
@@ -16,16 +16,13 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
log "github.com/hashicorp/go-hclog"
|
|
||||||
uuid "github.com/hashicorp/go-uuid"
|
uuid "github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"golang.org/x/net/http2"
|
"github.com/hashicorp/vault/vault/cluster"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -282,296 +279,6 @@ func (c *Core) setupCluster(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClusterClient is used to lookup a client certificate.
|
|
||||||
type ClusterClient interface {
|
|
||||||
ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClusterHandler exposes functions for looking up TLS configuration and handing
|
|
||||||
// off a connection for a cluster listener application.
|
|
||||||
type ClusterHandler interface {
|
|
||||||
ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
|
|
||||||
CALookup(context.Context) (*x509.Certificate, error)
|
|
||||||
|
|
||||||
// Handoff is used to pass the connection lifetime off to
|
|
||||||
// the handler
|
|
||||||
Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error
|
|
||||||
Stop() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClusterListener is the source of truth for cluster handlers and connection
|
|
||||||
// clients. It dynamically builds the cluster TLS information. It's also
|
|
||||||
// responsible for starting tcp listeners and accepting new cluster connections.
|
|
||||||
type ClusterListener struct {
|
|
||||||
handlers map[string]ClusterHandler
|
|
||||||
clients map[string]ClusterClient
|
|
||||||
shutdown *uint32
|
|
||||||
shutdownWg *sync.WaitGroup
|
|
||||||
server *http2.Server
|
|
||||||
|
|
||||||
clusterListenerAddrs []*net.TCPAddr
|
|
||||||
clusterCipherSuites []uint16
|
|
||||||
logger log.Logger
|
|
||||||
l sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddClient adds a new client for an ALPN name
|
|
||||||
func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) {
|
|
||||||
cl.l.Lock()
|
|
||||||
cl.clients[alpn] = client
|
|
||||||
cl.l.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveClient removes the client for the specified ALPN name
|
|
||||||
func (cl *ClusterListener) RemoveClient(alpn string) {
|
|
||||||
cl.l.Lock()
|
|
||||||
delete(cl.clients, alpn)
|
|
||||||
cl.l.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHandler registers a new cluster handler for the provided ALPN name.
|
|
||||||
func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) {
|
|
||||||
cl.l.Lock()
|
|
||||||
cl.handlers[alpn] = handler
|
|
||||||
cl.l.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StopHandler stops the cluster handler for the provided ALPN name, it also
|
|
||||||
// calls stop on the handler.
|
|
||||||
func (cl *ClusterListener) StopHandler(alpn string) {
|
|
||||||
cl.l.Lock()
|
|
||||||
handler, ok := cl.handlers[alpn]
|
|
||||||
delete(cl.handlers, alpn)
|
|
||||||
cl.l.Unlock()
|
|
||||||
if ok {
|
|
||||||
handler.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server returns the http2 server that the cluster listener is using
|
|
||||||
func (cl *ClusterListener) Server() *http2.Server {
|
|
||||||
return cl.server
|
|
||||||
}
|
|
||||||
|
|
||||||
// TLSConfig returns a tls config object that uses dynamic lookups to correctly
|
|
||||||
// authenticate registered handlers/clients
|
|
||||||
func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) {
|
|
||||||
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
cl.logger.Debug("performing server cert lookup")
|
|
||||||
|
|
||||||
cl.l.RLock()
|
|
||||||
defer cl.l.RUnlock()
|
|
||||||
for _, v := range clientHello.SupportedProtos {
|
|
||||||
if handler, ok := cl.handlers[v]; ok {
|
|
||||||
return handler.ServerLookup(ctx, clientHello)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos)
|
|
||||||
return nil, errors.New("unsupported protocol")
|
|
||||||
}
|
|
||||||
|
|
||||||
clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
||||||
cl.logger.Debug("performing client cert lookup")
|
|
||||||
|
|
||||||
cl.l.RLock()
|
|
||||||
defer cl.l.RUnlock()
|
|
||||||
for _, client := range cl.clients {
|
|
||||||
cert, err := client.ClientLookup(ctx, requestInfo)
|
|
||||||
if err == nil && cert != nil {
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.logger.Warn("no client information found")
|
|
||||||
return nil, errors.New("no client cert found")
|
|
||||||
}
|
|
||||||
|
|
||||||
serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
caPool := x509.NewCertPool()
|
|
||||||
|
|
||||||
ret := &tls.Config{
|
|
||||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
||||||
GetCertificate: serverLookup,
|
|
||||||
GetClientCertificate: clientLookup,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
RootCAs: caPool,
|
|
||||||
ClientCAs: caPool,
|
|
||||||
NextProtos: clientHello.SupportedProtos,
|
|
||||||
CipherSuites: cl.clusterCipherSuites,
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.l.RLock()
|
|
||||||
defer cl.l.RUnlock()
|
|
||||||
for _, v := range clientHello.SupportedProtos {
|
|
||||||
if handler, ok := cl.handlers[v]; ok {
|
|
||||||
ca, err := handler.CALookup(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
caPool.AddCert(ca)
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos)
|
|
||||||
return nil, errors.New("unsupported protocol")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tls.Config{
|
|
||||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
||||||
GetCertificate: serverLookup,
|
|
||||||
GetClientCertificate: clientLookup,
|
|
||||||
GetConfigForClient: serverConfigLookup,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
CipherSuites: cl.clusterCipherSuites,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run starts the tcp listeners and will accept connections until stop is
|
|
||||||
// called.
|
|
||||||
func (cl *ClusterListener) Run(ctx context.Context) error {
|
|
||||||
// Get our TLS config
|
|
||||||
tlsConfig, err := cl.TLSConfig(ctx)
|
|
||||||
if err != nil {
|
|
||||||
cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The server supports all of the possible protos
|
|
||||||
tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}
|
|
||||||
|
|
||||||
for i, laddr := range cl.clusterListenerAddrs {
|
|
||||||
// closeCh is used to shutdown the spawned goroutines once this
|
|
||||||
// function returns
|
|
||||||
closeCh := make(chan struct{})
|
|
||||||
|
|
||||||
if cl.logger.IsInfo() {
|
|
||||||
cl.logger.Info("starting listener", "listener_address", laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a TCP listener. We do this separately and specifically
|
|
||||||
// with TCP so that we can set deadlines.
|
|
||||||
tcpLn, err := net.ListenTCP("tcp", laddr)
|
|
||||||
if err != nil {
|
|
||||||
cl.logger.Error("error starting listener", "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if laddr.String() != tcpLn.Addr().String() {
|
|
||||||
// If we listened on port 0, record the port the OS gave us.
|
|
||||||
cl.clusterListenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap the listener with TLS
|
|
||||||
tlsLn := tls.NewListener(tcpLn, tlsConfig)
|
|
||||||
|
|
||||||
if cl.logger.IsInfo() {
|
|
||||||
cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.shutdownWg.Add(1)
|
|
||||||
// Start our listening loop
|
|
||||||
go func(closeCh chan struct{}, tlsLn net.Listener) {
|
|
||||||
defer func() {
|
|
||||||
cl.shutdownWg.Done()
|
|
||||||
tlsLn.Close()
|
|
||||||
close(closeCh)
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
if atomic.LoadUint32(cl.shutdown) > 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the deadline for the accept call. If it passes we'll get
|
|
||||||
// an error, causing us to check the condition at the top
|
|
||||||
// again.
|
|
||||||
tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
|
|
||||||
|
|
||||||
// Accept the connection
|
|
||||||
conn, err := tlsLn.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if err, ok := err.(net.Error); ok && !err.Timeout() {
|
|
||||||
cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type assert to TLS connection and handshake to populate the
|
|
||||||
// connection state
|
|
||||||
tlsConn := conn.(*tls.Conn)
|
|
||||||
|
|
||||||
// Set a deadline for the handshake. This will cause clients
|
|
||||||
// that don't successfully auth to be kicked out quickly.
|
|
||||||
// Cluster connections should be reliable so being marginally
|
|
||||||
// aggressive here is fine.
|
|
||||||
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
|
|
||||||
if err != nil {
|
|
||||||
if cl.logger.IsDebug() {
|
|
||||||
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
|
|
||||||
}
|
|
||||||
tlsConn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tlsConn.Handshake()
|
|
||||||
if err != nil {
|
|
||||||
if cl.logger.IsDebug() {
|
|
||||||
cl.logger.Debug("error handshaking cluster connection", "error", err)
|
|
||||||
}
|
|
||||||
tlsConn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, set it back to unlimited
|
|
||||||
err = tlsConn.SetDeadline(time.Time{})
|
|
||||||
if err != nil {
|
|
||||||
if cl.logger.IsDebug() {
|
|
||||||
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
|
|
||||||
}
|
|
||||||
tlsConn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
cl.l.RLock()
|
|
||||||
handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
|
|
||||||
cl.l.RUnlock()
|
|
||||||
if !ok {
|
|
||||||
cl.logger.Debug("unknown negotiated protocol on cluster port")
|
|
||||||
tlsConn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
|
|
||||||
cl.logger.Error("error handling cluster connection", "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(closeCh, tlsLn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the cluster listner
|
|
||||||
func (cl *ClusterListener) Stop() {
|
|
||||||
// Set the shutdown flag. This will cause the listeners to shut down
|
|
||||||
// within the deadline in clusterListenerAcceptDeadline
|
|
||||||
atomic.StoreUint32(cl.shutdown, 1)
|
|
||||||
cl.logger.Info("forwarding rpc listeners stopped")
|
|
||||||
|
|
||||||
// Wait for them all to shut down
|
|
||||||
cl.shutdownWg.Wait()
|
|
||||||
cl.logger.Info("rpc listeners successfully shut down")
|
|
||||||
}
|
|
||||||
|
|
||||||
// startClusterListener starts cluster request listeners during unseal. It
|
// startClusterListener starts cluster request listeners during unseal. It
|
||||||
// is assumed that the state lock is held while this is run. Right now this
|
// is assumed that the state lock is held while this is run. Right now this
|
||||||
// only starts cluster listeners. Once the listener is started handlers/clients
|
// only starts cluster listeners. Once the listener is started handlers/clients
|
||||||
@@ -589,27 +296,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
|
|||||||
|
|
||||||
c.logger.Debug("starting cluster listeners")
|
c.logger.Debug("starting cluster listeners")
|
||||||
|
|
||||||
// Create the HTTP/2 server that will be shared by both RPC and regular
|
c.clusterListener = cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener"))
|
||||||
// duties. Doing it this way instead of listening via the server and gRPC
|
|
||||||
// allows us to re-use the same port via ALPN. We can just tell the server
|
|
||||||
// to serve a given conn and which handler to use.
|
|
||||||
h2Server := &http2.Server{
|
|
||||||
// Our forwarding connections heartbeat regularly so anything else we
|
|
||||||
// want to go away/get cleaned up pretty rapidly
|
|
||||||
IdleTimeout: 5 * HeartbeatInterval,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.clusterListener = &ClusterListener{
|
|
||||||
handlers: make(map[string]ClusterHandler),
|
|
||||||
clients: make(map[string]ClusterClient),
|
|
||||||
shutdown: new(uint32),
|
|
||||||
shutdownWg: &sync.WaitGroup{},
|
|
||||||
server: h2Server,
|
|
||||||
|
|
||||||
clusterListenerAddrs: c.clusterListenerAddrs,
|
|
||||||
clusterCipherSuites: c.clusterCipherSuites,
|
|
||||||
logger: c.logger.Named("cluster-listener"),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.clusterListener.Run(ctx)
|
err := c.clusterListener.Run(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -617,7 +304,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
if strings.HasSuffix(c.clusterAddr, ":0") {
|
if strings.HasSuffix(c.clusterAddr, ":0") {
|
||||||
// If we listened on port 0, record the port the OS gave us.
|
// If we listened on port 0, record the port the OS gave us.
|
||||||
c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.clusterListenerAddrs[0])
|
c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0])
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -633,6 +320,7 @@ func (c *Core) stopClusterListener() {
|
|||||||
c.logger.Info("stopping cluster listeners")
|
c.logger.Info("stopping cluster listeners")
|
||||||
|
|
||||||
c.clusterListener.Stop()
|
c.clusterListener.Stop()
|
||||||
|
c.clusterListener = nil
|
||||||
|
|
||||||
c.logger.Info("cluster listeners successfully shut down")
|
c.logger.Info("cluster listeners successfully shut down")
|
||||||
}
|
}
|
||||||
|
|||||||
343
vault/cluster/cluster.go
Normal file
343
vault/cluster/cluster.go
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
package cluster
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Making this a package var allows tests to modify
|
||||||
|
HeartbeatInterval = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ListenerAcceptDeadline = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is used to lookup a client certificate.
|
||||||
|
type Client interface {
|
||||||
|
ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler exposes functions for looking up TLS configuration and handing
|
||||||
|
// off a connection for a cluster listener application.
|
||||||
|
type Handler interface {
|
||||||
|
ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
|
CALookup(context.Context) (*x509.Certificate, error)
|
||||||
|
|
||||||
|
// Handoff is used to pass the connection lifetime off to
|
||||||
|
// the handler
|
||||||
|
Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listener is the source of truth for cluster handlers and connection
|
||||||
|
// clients. It dynamically builds the cluster TLS information. It's also
|
||||||
|
// responsible for starting tcp listeners and accepting new cluster connections.
|
||||||
|
type Listener struct {
|
||||||
|
handlers map[string]Handler
|
||||||
|
clients map[string]Client
|
||||||
|
shutdown *uint32
|
||||||
|
shutdownWg *sync.WaitGroup
|
||||||
|
server *http2.Server
|
||||||
|
|
||||||
|
listenerAddrs []*net.TCPAddr
|
||||||
|
cipherSuites []uint16
|
||||||
|
logger log.Logger
|
||||||
|
l sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener {
|
||||||
|
// Create the HTTP/2 server that will be shared by both RPC and regular
|
||||||
|
// duties. Doing it this way instead of listening via the server and gRPC
|
||||||
|
// allows us to re-use the same port via ALPN. We can just tell the server
|
||||||
|
// to serve a given conn and which handler to use.
|
||||||
|
h2Server := &http2.Server{
|
||||||
|
// Our forwarding connections heartbeat regularly so anything else we
|
||||||
|
// want to go away/get cleaned up pretty rapidly
|
||||||
|
IdleTimeout: 5 * HeartbeatInterval,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Listener{
|
||||||
|
handlers: make(map[string]Handler),
|
||||||
|
clients: make(map[string]Client),
|
||||||
|
shutdown: new(uint32),
|
||||||
|
shutdownWg: &sync.WaitGroup{},
|
||||||
|
server: h2Server,
|
||||||
|
|
||||||
|
listenerAddrs: addrs,
|
||||||
|
cipherSuites: cipherSuites,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cl *Listener) Addrs() []*net.TCPAddr {
|
||||||
|
return cl.listenerAddrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddClient adds a new client for an ALPN name
|
||||||
|
func (cl *Listener) AddClient(alpn string, client Client) {
|
||||||
|
cl.l.Lock()
|
||||||
|
cl.clients[alpn] = client
|
||||||
|
cl.l.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveClient removes the client for the specified ALPN name
|
||||||
|
func (cl *Listener) RemoveClient(alpn string) {
|
||||||
|
cl.l.Lock()
|
||||||
|
delete(cl.clients, alpn)
|
||||||
|
cl.l.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHandler registers a new cluster handler for the provided ALPN name.
|
||||||
|
func (cl *Listener) AddHandler(alpn string, handler Handler) {
|
||||||
|
cl.l.Lock()
|
||||||
|
cl.handlers[alpn] = handler
|
||||||
|
cl.l.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopHandler stops the cluster handler for the provided ALPN name, it also
|
||||||
|
// calls stop on the handler.
|
||||||
|
func (cl *Listener) StopHandler(alpn string) {
|
||||||
|
cl.l.Lock()
|
||||||
|
handler, ok := cl.handlers[alpn]
|
||||||
|
delete(cl.handlers, alpn)
|
||||||
|
cl.l.Unlock()
|
||||||
|
if ok {
|
||||||
|
handler.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server returns the http2 server that the cluster listener is using
|
||||||
|
func (cl *Listener) Server() *http2.Server {
|
||||||
|
return cl.server
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfig returns a tls config object that uses dynamic lookups to correctly
|
||||||
|
// authenticate registered handlers/clients
|
||||||
|
func (cl *Listener) TLSConfig(ctx context.Context) (*tls.Config, error) {
|
||||||
|
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
cl.logger.Debug("performing server cert lookup")
|
||||||
|
|
||||||
|
cl.l.RLock()
|
||||||
|
defer cl.l.RUnlock()
|
||||||
|
for _, v := range clientHello.SupportedProtos {
|
||||||
|
if handler, ok := cl.handlers[v]; ok {
|
||||||
|
return handler.ServerLookup(ctx, clientHello)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos)
|
||||||
|
return nil, errors.New("unsupported protocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
|
cl.logger.Debug("performing client cert lookup")
|
||||||
|
|
||||||
|
cl.l.RLock()
|
||||||
|
defer cl.l.RUnlock()
|
||||||
|
for _, client := range cl.clients {
|
||||||
|
cert, err := client.ClientLookup(ctx, requestInfo)
|
||||||
|
if err == nil && cert != nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cl.logger.Warn("no client information found")
|
||||||
|
return nil, errors.New("no client cert found")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
caPool := x509.NewCertPool()
|
||||||
|
|
||||||
|
ret := &tls.Config{
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
GetCertificate: serverLookup,
|
||||||
|
GetClientCertificate: clientLookup,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
RootCAs: caPool,
|
||||||
|
ClientCAs: caPool,
|
||||||
|
NextProtos: clientHello.SupportedProtos,
|
||||||
|
CipherSuites: cl.cipherSuites,
|
||||||
|
}
|
||||||
|
|
||||||
|
cl.l.RLock()
|
||||||
|
defer cl.l.RUnlock()
|
||||||
|
for _, v := range clientHello.SupportedProtos {
|
||||||
|
if handler, ok := cl.handlers[v]; ok {
|
||||||
|
ca, err := handler.CALookup(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
caPool.AddCert(ca)
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos)
|
||||||
|
return nil, errors.New("unsupported protocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
GetCertificate: serverLookup,
|
||||||
|
GetClientCertificate: clientLookup,
|
||||||
|
GetConfigForClient: serverConfigLookup,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
CipherSuites: cl.cipherSuites,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the tcp listeners and will accept connections until stop is
|
||||||
|
// called.
|
||||||
|
func (cl *Listener) Run(ctx context.Context) error {
|
||||||
|
// Get our TLS config
|
||||||
|
tlsConfig, err := cl.TLSConfig(ctx)
|
||||||
|
if err != nil {
|
||||||
|
cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The server supports all of the possible protos
|
||||||
|
tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN}
|
||||||
|
|
||||||
|
for i, laddr := range cl.listenerAddrs {
|
||||||
|
// closeCh is used to shutdown the spawned goroutines once this
|
||||||
|
// function returns
|
||||||
|
closeCh := make(chan struct{})
|
||||||
|
|
||||||
|
if cl.logger.IsInfo() {
|
||||||
|
cl.logger.Info("starting listener", "listener_address", laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a TCP listener. We do this separately and specifically
|
||||||
|
// with TCP so that we can set deadlines.
|
||||||
|
tcpLn, err := net.ListenTCP("tcp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
cl.logger.Error("error starting listener", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if laddr.String() != tcpLn.Addr().String() {
|
||||||
|
// If we listened on port 0, record the port the OS gave us.
|
||||||
|
cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the listener with TLS
|
||||||
|
tlsLn := tls.NewListener(tcpLn, tlsConfig)
|
||||||
|
|
||||||
|
if cl.logger.IsInfo() {
|
||||||
|
cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
|
||||||
|
}
|
||||||
|
|
||||||
|
cl.shutdownWg.Add(1)
|
||||||
|
// Start our listening loop
|
||||||
|
go func(closeCh chan struct{}, tlsLn net.Listener) {
|
||||||
|
defer func() {
|
||||||
|
cl.shutdownWg.Done()
|
||||||
|
tlsLn.Close()
|
||||||
|
close(closeCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if atomic.LoadUint32(cl.shutdown) > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the deadline for the accept call. If it passes we'll get
|
||||||
|
// an error, causing us to check the condition at the top
|
||||||
|
// again.
|
||||||
|
tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
|
||||||
|
|
||||||
|
// Accept the connection
|
||||||
|
conn, err := tlsLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if err, ok := err.(net.Error); ok && !err.Timeout() {
|
||||||
|
cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type assert to TLS connection and handshake to populate the
|
||||||
|
// connection state
|
||||||
|
tlsConn := conn.(*tls.Conn)
|
||||||
|
|
||||||
|
// Set a deadline for the handshake. This will cause clients
|
||||||
|
// that don't successfully auth to be kicked out quickly.
|
||||||
|
// Cluster connections should be reliable so being marginally
|
||||||
|
// aggressive here is fine.
|
||||||
|
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
if cl.logger.IsDebug() {
|
||||||
|
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
|
||||||
|
}
|
||||||
|
tlsConn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
if cl.logger.IsDebug() {
|
||||||
|
cl.logger.Debug("error handshaking cluster connection", "error", err)
|
||||||
|
}
|
||||||
|
tlsConn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, set it back to unlimited
|
||||||
|
err = tlsConn.SetDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
if cl.logger.IsDebug() {
|
||||||
|
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
|
||||||
|
}
|
||||||
|
tlsConn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cl.l.RLock()
|
||||||
|
handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
|
||||||
|
cl.l.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
cl.logger.Debug("unknown negotiated protocol on cluster port")
|
||||||
|
tlsConn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
|
||||||
|
cl.logger.Error("error handling cluster connection", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(closeCh, tlsLn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the cluster listner
|
||||||
|
func (cl *Listener) Stop() {
|
||||||
|
// Set the shutdown flag. This will cause the listeners to shut down
|
||||||
|
// within the deadline in clusterListenerAcceptDeadline
|
||||||
|
atomic.StoreUint32(cl.shutdown, 1)
|
||||||
|
cl.logger.Info("forwarding rpc listeners stopped")
|
||||||
|
|
||||||
|
// Wait for them all to shut down
|
||||||
|
cl.shutdownWg.Wait()
|
||||||
|
cl.logger.Info("rpc listeners successfully shut down")
|
||||||
|
}
|
||||||
@@ -101,17 +101,17 @@ func TestCluster_ListenForRequests(t *testing.T) {
|
|||||||
// Wait for core to become active
|
// Wait for core to become active
|
||||||
TestWaitActive(t, cores[0].Core)
|
TestWaitActive(t, cores[0].Core)
|
||||||
|
|
||||||
|
cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
|
||||||
|
addrs := cores[0].clusterListener.Addrs()
|
||||||
|
|
||||||
// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
|
// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
|
||||||
checkListenersFunc := func(expectFail bool) {
|
checkListenersFunc := func(expectFail bool) {
|
||||||
cores[0].clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
|
|
||||||
|
|
||||||
parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
|
parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
|
||||||
dialer := cores[0].getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
|
dialer := cores[0].getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
|
||||||
for i := range cores[0].Listeners {
|
for i := range cores[0].Listeners {
|
||||||
|
|
||||||
clnAddr := cores[0].clusterListener.clusterListenerAddrs[i]
|
clnAddr := addrs[i]
|
||||||
netConn, err := dialer(clnAddr.String(), 0)
|
netConn, err := dialer(clnAddr.String(), 0)
|
||||||
conn := netConn.(*tls.Conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if expectFail {
|
if expectFail {
|
||||||
t.Logf("testing %s unsuccessful as expected", clnAddr)
|
t.Logf("testing %s unsuccessful as expected", clnAddr)
|
||||||
@@ -122,6 +122,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
|
|||||||
if expectFail {
|
if expectFail {
|
||||||
t.Fatalf("testing %s not unsuccessful as expected", clnAddr)
|
t.Fatalf("testing %s not unsuccessful as expected", clnAddr)
|
||||||
}
|
}
|
||||||
|
conn := netConn.(*tls.Conn)
|
||||||
err = conn.Handshake()
|
err = conn.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -130,7 +131,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
|
|||||||
switch {
|
switch {
|
||||||
case connState.Version != tls.VersionTLS12:
|
case connState.Version != tls.VersionTLS12:
|
||||||
t.Fatal("version mismatch")
|
t.Fatal("version mismatch")
|
||||||
case connState.NegotiatedProtocol != requestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
|
case connState.NegotiatedProtocol != consts.RequestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
|
||||||
t.Fatal("bad protocol negotiation")
|
t.Fatal("bad protocol negotiation")
|
||||||
}
|
}
|
||||||
t.Logf("testing %s successful", clnAddr)
|
t.Logf("testing %s successful", clnAddr)
|
||||||
@@ -155,7 +156,8 @@ func TestCluster_ListenForRequests(t *testing.T) {
|
|||||||
checkListenersFunc(true)
|
checkListenersFunc(true)
|
||||||
|
|
||||||
// After this period it should be active again
|
// After this period it should be active again
|
||||||
time.Sleep(manualStepDownSleepPeriod)
|
TestWaitActive(t, cores[0].Core)
|
||||||
|
cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
|
||||||
checkListenersFunc(false)
|
checkListenersFunc(false)
|
||||||
|
|
||||||
err = cores[0].Core.Seal(cluster.RootToken)
|
err = cores[0].Core.Seal(cluster.RootToken)
|
||||||
@@ -382,12 +384,12 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
|
|||||||
// Wait for core to become active
|
// Wait for core to become active
|
||||||
TestWaitActive(t, core.Core)
|
TestWaitActive(t, core.Core)
|
||||||
|
|
||||||
core.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core.Core})
|
core.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
|
||||||
|
|
||||||
parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
|
parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
|
||||||
dialer := core.getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
|
dialer := core.getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
|
||||||
|
|
||||||
netConn, err := dialer(core.clusterListener.clusterListenerAddrs[0].String(), 0)
|
netConn, err := dialer(core.clusterListener.Addrs()[0].String(), 0)
|
||||||
conn := netConn.(*tls.Conn)
|
conn := netConn.(*tls.Conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ import (
|
|||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/sdk/physical"
|
"github.com/hashicorp/vault/sdk/physical"
|
||||||
"github.com/hashicorp/vault/shamir"
|
"github.com/hashicorp/vault/shamir"
|
||||||
|
"github.com/hashicorp/vault/vault/cluster"
|
||||||
"github.com/hashicorp/vault/vault/seal"
|
"github.com/hashicorp/vault/vault/seal"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -419,7 +420,7 @@ type Core struct {
|
|||||||
loadCaseSensitiveIdentityStore bool
|
loadCaseSensitiveIdentityStore bool
|
||||||
|
|
||||||
// clusterListener starts up and manages connections on the cluster ports
|
// clusterListener starts up and manages connections on the cluster ports
|
||||||
clusterListener *ClusterListener
|
clusterListener *cluster.Listener
|
||||||
|
|
||||||
// Telemetry objects
|
// Telemetry objects
|
||||||
metricsHelper *metricsutil.MetricsHelper
|
metricsHelper *metricsutil.MetricsHelper
|
||||||
@@ -592,7 +593,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||||
cachingDisabled: conf.DisableCache,
|
cachingDisabled: conf.DisableCache,
|
||||||
clusterName: conf.ClusterName,
|
clusterName: conf.ClusterName,
|
||||||
clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second),
|
clusterPeerClusterAddrsCache: cache.New(3*cluster.HeartbeatInterval, time.Second),
|
||||||
enableMlock: !conf.DisableMlock,
|
enableMlock: !conf.DisableMlock,
|
||||||
rawEnabled: conf.EnableRaw,
|
rawEnabled: conf.EnableRaw,
|
||||||
replicationState: new(uint32),
|
replicationState: new(uint32),
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/hashicorp/vault/sdk/helper/license"
|
"github.com/hashicorp/vault/sdk/helper/license"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/sdk/physical"
|
"github.com/hashicorp/vault/sdk/physical"
|
||||||
|
"github.com/hashicorp/vault/vault/cluster"
|
||||||
"github.com/hashicorp/vault/vault/replication"
|
"github.com/hashicorp/vault/vault/replication"
|
||||||
cache "github.com/patrickmn/go-cache"
|
cache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
@@ -111,5 +112,5 @@ func (c *Core) invalidateSentinelPolicy(PolicyType, string) {}
|
|||||||
func (c *Core) removePerfStandbySecondary(context.Context, string) {}
|
func (c *Core) removePerfStandbySecondary(context.Context, string) {}
|
||||||
|
|
||||||
func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) {
|
func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) {
|
||||||
return nil, cache.New(2*HeartbeatInterval, 1*time.Second), make(chan struct{}), nil
|
return nil, cache.New(2*cluster.HeartbeatInterval, 1*time.Second), make(chan struct{}), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import (
|
|||||||
|
|
||||||
log "github.com/hashicorp/go-hclog"
|
log "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/helper/forwarding"
|
"github.com/hashicorp/vault/helper/forwarding"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
|
"github.com/hashicorp/vault/vault/cluster"
|
||||||
"github.com/hashicorp/vault/vault/replication"
|
"github.com/hashicorp/vault/vault/replication"
|
||||||
cache "github.com/patrickmn/go-cache"
|
cache "github.com/patrickmn/go-cache"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
@@ -23,27 +25,6 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
clusterListenerAcceptDeadline = 500 * time.Millisecond
|
|
||||||
|
|
||||||
// PerformanceReplicationALPN is the negotiated protocol used for
|
|
||||||
// performance replication.
|
|
||||||
PerformanceReplicationALPN = "replication_v1"
|
|
||||||
|
|
||||||
// DRReplicationALPN is the negotiated protocol used for
|
|
||||||
// dr replication.
|
|
||||||
DRReplicationALPN = "replication_dr_v1"
|
|
||||||
|
|
||||||
perfStandbyALPN = "perf_standby_v1"
|
|
||||||
|
|
||||||
requestForwardingALPN = "req_fw_sb-act_v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// Making this a package var allows tests to modify
|
|
||||||
HeartbeatInterval = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type requestForwardingHandler struct {
|
type requestForwardingHandler struct {
|
||||||
fws *http2.Server
|
fws *http2.Server
|
||||||
fwRPCServer *grpc.Server
|
fwRPCServer *grpc.Server
|
||||||
@@ -65,7 +46,7 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch
|
|||||||
|
|
||||||
fwRPCServer := grpc.NewServer(
|
fwRPCServer := grpc.NewServer(
|
||||||
grpc.KeepaliveParams(keepalive.ServerParameters{
|
grpc.KeepaliveParams(keepalive.ServerParameters{
|
||||||
Time: 2 * HeartbeatInterval,
|
Time: 2 * cluster.HeartbeatInterval,
|
||||||
}),
|
}),
|
||||||
grpc.MaxRecvMsgSize(math.MaxInt32),
|
grpc.MaxRecvMsgSize(math.MaxInt32),
|
||||||
grpc.MaxSendMsgSize(math.MaxInt32),
|
grpc.MaxSendMsgSize(math.MaxInt32),
|
||||||
@@ -190,7 +171,7 @@ func (rf *requestForwardingHandler) Handoff(ctx context.Context, shutdownWg *syn
|
|||||||
// Stop stops the request forwarding server and closes connections.
|
// Stop stops the request forwarding server and closes connections.
|
||||||
func (rf *requestForwardingHandler) Stop() error {
|
func (rf *requestForwardingHandler) Stop() error {
|
||||||
// Give some time for existing RPCs to drain.
|
// Give some time for existing RPCs to drain.
|
||||||
time.Sleep(clusterListenerAcceptDeadline)
|
time.Sleep(cluster.ListenerAcceptDeadline)
|
||||||
close(rf.stopCh)
|
close(rf.stopCh)
|
||||||
rf.fwRPCServer.Stop()
|
rf.fwRPCServer.Stop()
|
||||||
return nil
|
return nil
|
||||||
@@ -198,16 +179,16 @@ func (rf *requestForwardingHandler) Stop() error {
|
|||||||
|
|
||||||
// Starts the listeners and servers necessary to handle forwarded requests
|
// Starts the listeners and servers necessary to handle forwarded requests
|
||||||
func (c *Core) startForwarding(ctx context.Context) error {
|
func (c *Core) startForwarding(ctx context.Context) error {
|
||||||
c.logger.Debug("cluster listener setup function")
|
c.logger.Debug("request forwarding setup function")
|
||||||
defer c.logger.Debug("leaving cluster listener setup function")
|
defer c.logger.Debug("leaving request forwarding setup function")
|
||||||
|
|
||||||
// Clean up in case we have transitioned from a client to a server
|
// Clean up in case we have transitioned from a client to a server
|
||||||
c.requestForwardingConnectionLock.Lock()
|
c.requestForwardingConnectionLock.Lock()
|
||||||
c.clearForwardingClients()
|
c.clearForwardingClients()
|
||||||
c.requestForwardingConnectionLock.Unlock()
|
c.requestForwardingConnectionLock.Unlock()
|
||||||
|
|
||||||
// Resolve locally to avoid races
|
|
||||||
if c.ha == nil || c.clusterListener == nil {
|
if c.ha == nil || c.clusterListener == nil {
|
||||||
|
c.logger.Debug("request forwarding not setup")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,15 +202,15 @@ func (c *Core) startForwarding(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.clusterListener.AddHandler(requestForwardingALPN, handler)
|
c.clusterListener.AddHandler(consts.RequestForwardingALPN, handler)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Core) stopForwarding() {
|
func (c *Core) stopForwarding() {
|
||||||
if c.clusterListener != nil {
|
if c.clusterListener != nil {
|
||||||
c.clusterListener.StopHandler(requestForwardingALPN)
|
c.clusterListener.StopHandler(consts.RequestForwardingALPN)
|
||||||
c.clusterListener.StopHandler(perfStandbyALPN)
|
c.clusterListener.StopHandler(consts.PerfStandbyALPN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,7 +245,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.clusterListener != nil {
|
if c.clusterListener != nil {
|
||||||
c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{
|
c.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
|
||||||
core: c,
|
core: c,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -275,10 +256,10 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
|
|||||||
// the TLS state.
|
// the TLS state.
|
||||||
dctx, cancelFunc := context.WithCancel(ctx)
|
dctx, cancelFunc := context.WithCancel(ctx)
|
||||||
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
|
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
|
||||||
grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
|
grpc.WithDialer(c.getGRPCDialer(ctx, consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
|
||||||
grpc.WithInsecure(), // it's not, we handle it in the dialer
|
grpc.WithInsecure(), // it's not, we handle it in the dialer
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 2 * HeartbeatInterval,
|
Time: 2 * cluster.HeartbeatInterval,
|
||||||
}),
|
}),
|
||||||
grpc.WithDefaultCallOptions(
|
grpc.WithDefaultCallOptions(
|
||||||
grpc.MaxCallRecvMsgSize(math.MaxInt32),
|
grpc.MaxCallRecvMsgSize(math.MaxInt32),
|
||||||
@@ -294,7 +275,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
|
|||||||
c.rpcForwardingClient = &forwardingClient{
|
c.rpcForwardingClient = &forwardingClient{
|
||||||
RequestForwardingClient: NewRequestForwardingClient(c.rpcClientConn),
|
RequestForwardingClient: NewRequestForwardingClient(c.rpcClientConn),
|
||||||
core: c,
|
core: c,
|
||||||
echoTicker: time.NewTicker(HeartbeatInterval),
|
echoTicker: time.NewTicker(cluster.HeartbeatInterval),
|
||||||
echoContext: dctx,
|
echoContext: dctx,
|
||||||
}
|
}
|
||||||
c.rpcForwardingClient.startHeartbeat()
|
c.rpcForwardingClient.startHeartbeat()
|
||||||
@@ -319,7 +300,7 @@ func (c *Core) clearForwardingClients() {
|
|||||||
c.rpcForwardingClient = nil
|
c.rpcForwardingClient = nil
|
||||||
|
|
||||||
if c.clusterListener != nil {
|
if c.clusterListener != nil {
|
||||||
c.clusterListener.RemoveClient(requestForwardingALPN)
|
c.clusterListener.RemoveClient(consts.RequestForwardingALPN)
|
||||||
}
|
}
|
||||||
c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
|
c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
|
||||||
}
|
}
|
||||||
|
|||||||
14
vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go
generated
vendored
14
vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go
generated
vendored
@@ -11,4 +11,18 @@ const (
|
|||||||
|
|
||||||
// AuthHeaderName is the name of the header containing the token.
|
// AuthHeaderName is the name of the header containing the token.
|
||||||
AuthHeaderName = "X-Vault-Token"
|
AuthHeaderName = "X-Vault-Token"
|
||||||
|
|
||||||
|
// PerformanceReplicationALPN is the negotiated protocol used for
|
||||||
|
// performance replication.
|
||||||
|
PerformanceReplicationALPN = "replication_v1"
|
||||||
|
|
||||||
|
// DRReplicationALPN is the negotiated protocol used for
|
||||||
|
// dr replication.
|
||||||
|
DRReplicationALPN = "replication_dr_v1"
|
||||||
|
|
||||||
|
PerfStandbyALPN = "perf_standby_v1"
|
||||||
|
|
||||||
|
RequestForwardingALPN = "req_fw_sb-act_v1"
|
||||||
|
|
||||||
|
RaftStorageALPN = "raft_storage_v1"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user