diff --git a/helper/testhelpers/testhelpers.go b/helper/testhelpers/testhelpers.go index bb39906b54..3a02fbcd75 100644 --- a/helper/testhelpers/testhelpers.go +++ b/helper/testhelpers/testhelpers.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/cluster" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" @@ -302,7 +303,7 @@ func GetPerfReplicatedClusters(t testing.T, conf *vault.CoreConfig, opts *vault. } // Set this lower so that state populates quickly to standby nodes - vault.HeartbeatInterval = 2 * time.Second + cluster.HeartbeatInterval = 2 * time.Second opts1 := *opts opts1.Logger = logger.Named("perf-pri") @@ -325,7 +326,7 @@ func GetFourReplicatedClusters(t testing.T, handlerFunc func(*vault.HandlerPrope Level: log.Trace, }) // Set this lower so that state populates quickly to standby nodes - vault.HeartbeatInterval = 2 * time.Second + cluster.HeartbeatInterval = 2 * time.Second ret.PerfPrimaryCluster, _ = GetClusterAndCore(t, logger.Named("perf-pri"), handlerFunc) diff --git a/sdk/helper/consts/consts.go b/sdk/helper/consts/consts.go index 972a69f47b..769a785836 100644 --- a/sdk/helper/consts/consts.go +++ b/sdk/helper/consts/consts.go @@ -11,4 +11,18 @@ const ( // AuthHeaderName is the name of the header containing the token. AuthHeaderName = "X-Vault-Token" + + // PerformanceReplicationALPN is the negotiated protocol used for + // performance replication. + PerformanceReplicationALPN = "replication_v1" + + // DRReplicationALPN is the negotiated protocol used for + // dr replication. + DRReplicationALPN = "replication_dr_v1" + + PerfStandbyALPN = "perf_standby_v1" + + RequestForwardingALPN = "req_fw_sb-act_v1" + + RaftStorageALPN = "raft_storage_v1" ) diff --git a/vault/cluster.go b/vault/cluster.go index 78f2daf3fa..bd04981d30 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -16,16 +16,13 @@ import ( "net" "net/http" "strings" - "sync" - "sync/atomic" "time" "github.com/hashicorp/errwrap" - log "github.com/hashicorp/go-hclog" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" - "golang.org/x/net/http2" + "github.com/hashicorp/vault/vault/cluster" ) const ( @@ -282,296 +279,6 @@ func (c *Core) setupCluster(ctx context.Context) error { return nil } -// ClusterClient is used to lookup a client certificate. -type ClusterClient interface { - ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error) -} - -// ClusterHandler exposes functions for looking up TLS configuration and handing -// off a connection for a cluster listener application. -type ClusterHandler interface { - ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error) - CALookup(context.Context) (*x509.Certificate, error) - - // Handoff is used to pass the connection lifetime off to - // the handler - Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error - Stop() error -} - -// ClusterListener is the source of truth for cluster handlers and connection -// clients. It dynamically builds the cluster TLS information. It's also -// responsible for starting tcp listeners and accepting new cluster connections. 
-type ClusterListener struct { - handlers map[string]ClusterHandler - clients map[string]ClusterClient - shutdown *uint32 - shutdownWg *sync.WaitGroup - server *http2.Server - - clusterListenerAddrs []*net.TCPAddr - clusterCipherSuites []uint16 - logger log.Logger - l sync.RWMutex -} - -// AddClient adds a new client for an ALPN name -func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) { - cl.l.Lock() - cl.clients[alpn] = client - cl.l.Unlock() -} - -// RemoveClient removes the client for the specified ALPN name -func (cl *ClusterListener) RemoveClient(alpn string) { - cl.l.Lock() - delete(cl.clients, alpn) - cl.l.Unlock() -} - -// AddHandler registers a new cluster handler for the provided ALPN name. -func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) { - cl.l.Lock() - cl.handlers[alpn] = handler - cl.l.Unlock() -} - -// StopHandler stops the cluster handler for the provided ALPN name, it also -// calls stop on the handler. -func (cl *ClusterListener) StopHandler(alpn string) { - cl.l.Lock() - handler, ok := cl.handlers[alpn] - delete(cl.handlers, alpn) - cl.l.Unlock() - if ok { - handler.Stop() - } -} - -// Server returns the http2 server that the cluster listener is using -func (cl *ClusterListener) Server() *http2.Server { - return cl.server -} - -// TLSConfig returns a tls config object that uses dynamic lookups to correctly -// authenticate registered handlers/clients -func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) { - serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cl.logger.Debug("performing server cert lookup") - - cl.l.RLock() - defer cl.l.RUnlock() - for _, v := range clientHello.SupportedProtos { - if handler, ok := cl.handlers[v]; ok { - return handler.ServerLookup(ctx, clientHello) - } - } - - cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos) - return nil, errors.New("unsupported protocol") - } - - clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { - cl.logger.Debug("performing client cert lookup") - - cl.l.RLock() - defer cl.l.RUnlock() - for _, client := range cl.clients { - cert, err := client.ClientLookup(ctx, requestInfo) - if err == nil && cert != nil { - return cert, nil - } - } - - cl.logger.Warn("no client information found") - return nil, errors.New("no client cert found") - } - - serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { - caPool := x509.NewCertPool() - - ret := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - GetCertificate: serverLookup, - GetClientCertificate: clientLookup, - MinVersion: tls.VersionTLS12, - RootCAs: caPool, - ClientCAs: caPool, - NextProtos: clientHello.SupportedProtos, - CipherSuites: cl.clusterCipherSuites, - } - - cl.l.RLock() - defer cl.l.RUnlock() - for _, v := range clientHello.SupportedProtos { - if handler, ok := cl.handlers[v]; ok { - ca, err := handler.CALookup(ctx) - if err != nil { - return nil, err - } - - caPool.AddCert(ca) - return ret, nil - } - } - - cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos) - return nil, errors.New("unsupported protocol") - } - - return &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - GetCertificate: serverLookup, - GetClientCertificate: clientLookup, - GetConfigForClient: serverConfigLookup, - MinVersion: tls.VersionTLS12, - CipherSuites: cl.clusterCipherSuites, - }, nil -} - -// Run starts the tcp listeners and will 
accept connections until stop is -// called. -func (cl *ClusterListener) Run(ctx context.Context) error { - // Get our TLS config - tlsConfig, err := cl.TLSConfig(ctx) - if err != nil { - cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err) - return err - } - - // The server supports all of the possible protos - tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN} - - for i, laddr := range cl.clusterListenerAddrs { - // closeCh is used to shutdown the spawned goroutines once this - // function returns - closeCh := make(chan struct{}) - - if cl.logger.IsInfo() { - cl.logger.Info("starting listener", "listener_address", laddr) - } - - // Create a TCP listener. We do this separately and specifically - // with TCP so that we can set deadlines. - tcpLn, err := net.ListenTCP("tcp", laddr) - if err != nil { - cl.logger.Error("error starting listener", "error", err) - continue - } - if laddr.String() != tcpLn.Addr().String() { - // If we listened on port 0, record the port the OS gave us. - cl.clusterListenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr) - } - - // Wrap the listener with TLS - tlsLn := tls.NewListener(tcpLn, tlsConfig) - - if cl.logger.IsInfo() { - cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr()) - } - - cl.shutdownWg.Add(1) - // Start our listening loop - go func(closeCh chan struct{}, tlsLn net.Listener) { - defer func() { - cl.shutdownWg.Done() - tlsLn.Close() - close(closeCh) - }() - - for { - if atomic.LoadUint32(cl.shutdown) > 0 { - return - } - - // Set the deadline for the accept call. If it passes we'll get - // an error, causing us to check the condition at the top - // again. - tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline)) - - // Accept the connection - conn, err := tlsLn.Accept() - if err != nil { - if err, ok := err.(net.Error); ok && !err.Timeout() { - cl.logger.Debug("non-timeout error accepting on cluster port", "error", err) - } - if conn != nil { - conn.Close() - } - continue - } - if conn == nil { - continue - } - - // Type assert to TLS connection and handshake to populate the - // connection state - tlsConn := conn.(*tls.Conn) - - // Set a deadline for the handshake. This will cause clients - // that don't successfully auth to be kicked out quickly. - // Cluster connections should be reliable so being marginally - // aggressive here is fine. 
- err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second)) - if err != nil { - if cl.logger.IsDebug() { - cl.logger.Debug("error setting deadline for cluster connection", "error", err) - } - tlsConn.Close() - continue - } - - err = tlsConn.Handshake() - if err != nil { - if cl.logger.IsDebug() { - cl.logger.Debug("error handshaking cluster connection", "error", err) - } - tlsConn.Close() - continue - } - - // Now, set it back to unlimited - err = tlsConn.SetDeadline(time.Time{}) - if err != nil { - if cl.logger.IsDebug() { - cl.logger.Debug("error setting deadline for cluster connection", "error", err) - } - tlsConn.Close() - continue - } - - cl.l.RLock() - handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol] - cl.l.RUnlock() - if !ok { - cl.logger.Debug("unknown negotiated protocol on cluster port") - tlsConn.Close() - continue - } - - if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil { - cl.logger.Error("error handling cluster connection", "error", err) - continue - } - } - }(closeCh, tlsLn) - } - - return nil -} - -// Stop stops the cluster listner -func (cl *ClusterListener) Stop() { - // Set the shutdown flag. This will cause the listeners to shut down - // within the deadline in clusterListenerAcceptDeadline - atomic.StoreUint32(cl.shutdown, 1) - cl.logger.Info("forwarding rpc listeners stopped") - - // Wait for them all to shut down - cl.shutdownWg.Wait() - cl.logger.Info("rpc listeners successfully shut down") -} - // startClusterListener starts cluster request listeners during unseal. It // is assumed that the state lock is held while this is run. Right now this // only starts cluster listeners. Once the listener is started handlers/clients @@ -589,27 +296,7 @@ func (c *Core) startClusterListener(ctx context.Context) error { c.logger.Debug("starting cluster listeners") - // Create the HTTP/2 server that will be shared by both RPC and regular - // duties. Doing it this way instead of listening via the server and gRPC - // allows us to re-use the same port via ALPN. We can just tell the server - // to serve a given conn and which handler to use. - h2Server := &http2.Server{ - // Our forwarding connections heartbeat regularly so anything else we - // want to go away/get cleaned up pretty rapidly - IdleTimeout: 5 * HeartbeatInterval, - } - - c.clusterListener = &ClusterListener{ - handlers: make(map[string]ClusterHandler), - clients: make(map[string]ClusterClient), - shutdown: new(uint32), - shutdownWg: &sync.WaitGroup{}, - server: h2Server, - - clusterListenerAddrs: c.clusterListenerAddrs, - clusterCipherSuites: c.clusterCipherSuites, - logger: c.logger.Named("cluster-listener"), - } + c.clusterListener = cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")) err := c.clusterListener.Run(ctx) if err != nil { @@ -617,7 +304,7 @@ func (c *Core) startClusterListener(ctx context.Context) error { } if strings.HasSuffix(c.clusterAddr, ":0") { // If we listened on port 0, record the port the OS gave us. 
- c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.clusterListenerAddrs[0]) + c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0]) } return nil } @@ -633,6 +320,7 @@ func (c *Core) stopClusterListener() { c.logger.Info("stopping cluster listeners") c.clusterListener.Stop() + c.clusterListener = nil c.logger.Info("cluster listeners successfully shut down") } diff --git a/vault/cluster/cluster.go b/vault/cluster/cluster.go new file mode 100644 index 0000000000..72ce23efef --- /dev/null +++ b/vault/cluster/cluster.go @@ -0,0 +1,343 @@ +package cluster + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/consts" + "golang.org/x/net/http2" +) + +var ( + // Making this a package var allows tests to modify + HeartbeatInterval = 5 * time.Second +) + +const ( + ListenerAcceptDeadline = 500 * time.Millisecond +) + +// Client is used to lookup a client certificate. +type Client interface { + ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error) +} + +// Handler exposes functions for looking up TLS configuration and handing +// off a connection for a cluster listener application. +type Handler interface { + ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error) + CALookup(context.Context) (*x509.Certificate, error) + + // Handoff is used to pass the connection lifetime off to + // the handler + Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error + Stop() error +} + +// Listener is the source of truth for cluster handlers and connection +// clients. It dynamically builds the cluster TLS information. It's also +// responsible for starting tcp listeners and accepting new cluster connections. +type Listener struct { + handlers map[string]Handler + clients map[string]Client + shutdown *uint32 + shutdownWg *sync.WaitGroup + server *http2.Server + + listenerAddrs []*net.TCPAddr + cipherSuites []uint16 + logger log.Logger + l sync.RWMutex +} + +func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener { + // Create the HTTP/2 server that will be shared by both RPC and regular + // duties. Doing it this way instead of listening via the server and gRPC + // allows us to re-use the same port via ALPN. We can just tell the server + // to serve a given conn and which handler to use. + h2Server := &http2.Server{ + // Our forwarding connections heartbeat regularly so anything else we + // want to go away/get cleaned up pretty rapidly + IdleTimeout: 5 * HeartbeatInterval, + } + + return &Listener{ + handlers: make(map[string]Handler), + clients: make(map[string]Client), + shutdown: new(uint32), + shutdownWg: &sync.WaitGroup{}, + server: h2Server, + + listenerAddrs: addrs, + cipherSuites: cipherSuites, + logger: logger, + } +} + +func (cl *Listener) Addrs() []*net.TCPAddr { + return cl.listenerAddrs +} + +// AddClient adds a new client for an ALPN name +func (cl *Listener) AddClient(alpn string, client Client) { + cl.l.Lock() + cl.clients[alpn] = client + cl.l.Unlock() +} + +// RemoveClient removes the client for the specified ALPN name +func (cl *Listener) RemoveClient(alpn string) { + cl.l.Lock() + delete(cl.clients, alpn) + cl.l.Unlock() +} + +// AddHandler registers a new cluster handler for the provided ALPN name. 
+func (cl *Listener) AddHandler(alpn string, handler Handler) {
+	cl.l.Lock()
+	cl.handlers[alpn] = handler
+	cl.l.Unlock()
+}
+
+// StopHandler stops the cluster handler for the provided ALPN name; it also
+// calls Stop on the handler.
+func (cl *Listener) StopHandler(alpn string) {
+	cl.l.Lock()
+	handler, ok := cl.handlers[alpn]
+	delete(cl.handlers, alpn)
+	cl.l.Unlock()
+	if ok {
+		handler.Stop()
+	}
+}
+
+// Server returns the http2 server that the cluster listener is using
+func (cl *Listener) Server() *http2.Server {
+	return cl.server
+}
+
+// TLSConfig returns a tls config object that uses dynamic lookups to correctly
+// authenticate registered handlers/clients
+func (cl *Listener) TLSConfig(ctx context.Context) (*tls.Config, error) {
+	serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		cl.logger.Debug("performing server cert lookup")
+
+		cl.l.RLock()
+		defer cl.l.RUnlock()
+		for _, v := range clientHello.SupportedProtos {
+			if handler, ok := cl.handlers[v]; ok {
+				return handler.ServerLookup(ctx, clientHello)
+			}
+		}
+
+		cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos)
+		return nil, errors.New("unsupported protocol")
+	}
+
+	clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+		cl.logger.Debug("performing client cert lookup")
+
+		cl.l.RLock()
+		defer cl.l.RUnlock()
+		for _, client := range cl.clients {
+			cert, err := client.ClientLookup(ctx, requestInfo)
+			if err == nil && cert != nil {
+				return cert, nil
+			}
+		}
+
+		cl.logger.Warn("no client information found")
+		return nil, errors.New("no client cert found")
+	}
+
+	serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
+		caPool := x509.NewCertPool()
+
+		ret := &tls.Config{
+			ClientAuth:           tls.RequireAndVerifyClientCert,
+			GetCertificate:       serverLookup,
+			GetClientCertificate: clientLookup,
+			MinVersion:           tls.VersionTLS12,
+			RootCAs:              caPool,
+			ClientCAs:            caPool,
+			NextProtos:           clientHello.SupportedProtos,
+			CipherSuites:         cl.cipherSuites,
+		}
+
+		cl.l.RLock()
+		defer cl.l.RUnlock()
+		for _, v := range clientHello.SupportedProtos {
+			if handler, ok := cl.handlers[v]; ok {
+				ca, err := handler.CALookup(ctx)
+				if err != nil {
+					return nil, err
+				}
+
+				caPool.AddCert(ca)
+				return ret, nil
+			}
+		}
+
+		cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos)
+		return nil, errors.New("unsupported protocol")
+	}
+
+	return &tls.Config{
+		ClientAuth:           tls.RequireAndVerifyClientCert,
+		GetCertificate:       serverLookup,
+		GetClientCertificate: clientLookup,
+		GetConfigForClient:   serverConfigLookup,
+		MinVersion:           tls.VersionTLS12,
+		CipherSuites:         cl.cipherSuites,
+	}, nil
+}
+
+// Run starts the tcp listeners and will accept connections until stop is
+// called.
+func (cl *Listener) Run(ctx context.Context) error { + // Get our TLS config + tlsConfig, err := cl.TLSConfig(ctx) + if err != nil { + cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err) + return err + } + + // The server supports all of the possible protos + tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN} + + for i, laddr := range cl.listenerAddrs { + // closeCh is used to shutdown the spawned goroutines once this + // function returns + closeCh := make(chan struct{}) + + if cl.logger.IsInfo() { + cl.logger.Info("starting listener", "listener_address", laddr) + } + + // Create a TCP listener. We do this separately and specifically + // with TCP so that we can set deadlines. + tcpLn, err := net.ListenTCP("tcp", laddr) + if err != nil { + cl.logger.Error("error starting listener", "error", err) + continue + } + if laddr.String() != tcpLn.Addr().String() { + // If we listened on port 0, record the port the OS gave us. + cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr) + } + + // Wrap the listener with TLS + tlsLn := tls.NewListener(tcpLn, tlsConfig) + + if cl.logger.IsInfo() { + cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr()) + } + + cl.shutdownWg.Add(1) + // Start our listening loop + go func(closeCh chan struct{}, tlsLn net.Listener) { + defer func() { + cl.shutdownWg.Done() + tlsLn.Close() + close(closeCh) + }() + + for { + if atomic.LoadUint32(cl.shutdown) > 0 { + return + } + + // Set the deadline for the accept call. If it passes we'll get + // an error, causing us to check the condition at the top + // again. + tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline)) + + // Accept the connection + conn, err := tlsLn.Accept() + if err != nil { + if err, ok := err.(net.Error); ok && !err.Timeout() { + cl.logger.Debug("non-timeout error accepting on cluster port", "error", err) + } + if conn != nil { + conn.Close() + } + continue + } + if conn == nil { + continue + } + + // Type assert to TLS connection and handshake to populate the + // connection state + tlsConn := conn.(*tls.Conn) + + // Set a deadline for the handshake. This will cause clients + // that don't successfully auth to be kicked out quickly. + // Cluster connections should be reliable so being marginally + // aggressive here is fine. 
+				err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
+				if err != nil {
+					if cl.logger.IsDebug() {
+						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
+					}
+					tlsConn.Close()
+					continue
+				}
+
+				err = tlsConn.Handshake()
+				if err != nil {
+					if cl.logger.IsDebug() {
+						cl.logger.Debug("error handshaking cluster connection", "error", err)
+					}
+					tlsConn.Close()
+					continue
+				}
+
+				// Now, set it back to unlimited
+				err = tlsConn.SetDeadline(time.Time{})
+				if err != nil {
+					if cl.logger.IsDebug() {
+						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
+					}
+					tlsConn.Close()
+					continue
+				}
+
+				cl.l.RLock()
+				handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
+				cl.l.RUnlock()
+				if !ok {
+					cl.logger.Debug("unknown negotiated protocol on cluster port")
+					tlsConn.Close()
+					continue
+				}
+
+				if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
+					cl.logger.Error("error handling cluster connection", "error", err)
+					continue
+				}
+			}
+		}(closeCh, tlsLn)
+	}
+
+	return nil
+}
+
+// Stop stops the cluster listener
+func (cl *Listener) Stop() {
+	// Set the shutdown flag. This will cause the listeners to shut down
+	// within the deadline in ListenerAcceptDeadline
+	atomic.StoreUint32(cl.shutdown, 1)
+	cl.logger.Info("forwarding rpc listeners stopped")
+
+	// Wait for them all to shut down
+	cl.shutdownWg.Wait()
+	cl.logger.Info("rpc listeners successfully shut down")
+}
diff --git a/vault/cluster_test.go b/vault/cluster_test.go
index d56c6d30f8..5344a729af 100644
--- a/vault/cluster_test.go
+++ b/vault/cluster_test.go
@@ -101,17 +101,17 @@ func TestCluster_ListenForRequests(t *testing.T) {
 	// Wait for core to become active
 	TestWaitActive(t, cores[0].Core)
 
+	cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+	addrs := cores[0].clusterListener.Addrs()
+
 	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
 	checkListenersFunc := func(expectFail bool) {
-		cores[0].clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
-
 		parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
-		dialer := cores[0].getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
+		dialer := cores[0].getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 		for i := range cores[0].Listeners {
-			clnAddr := cores[0].clusterListener.clusterListenerAddrs[i]
+			clnAddr := addrs[i]
 			netConn, err := dialer(clnAddr.String(), 0)
-			conn := netConn.(*tls.Conn)
 			if err != nil {
 				if expectFail {
 					t.Logf("testing %s unsuccessful as expected", clnAddr)
@@ -122,6 +122,7 @@
 			if expectFail {
 				t.Fatalf("testing %s not unsuccessful as expected", clnAddr)
 			}
+			conn := netConn.(*tls.Conn)
 			err = conn.Handshake()
 			if err != nil {
 				t.Fatal(err)
@@ -130,7 +131,7 @@
 			switch {
 			case connState.Version != tls.VersionTLS12:
 				t.Fatal("version mismatch")
-			case connState.NegotiatedProtocol != requestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
+			case connState.NegotiatedProtocol != consts.RequestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
 				t.Fatal("bad protocol negotiation")
 			}
 			t.Logf("testing %s successful", clnAddr)
@@ -155,7 +156,8 @@ func TestCluster_ListenForRequests(t *testing.T) {
checkListenersFunc(true) // After this period it should be active again - time.Sleep(manualStepDownSleepPeriod) + TestWaitActive(t, cores[0].Core) + cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core}) checkListenersFunc(false) err = cores[0].Core.Seal(cluster.RootToken) @@ -382,12 +384,12 @@ func TestCluster_CustomCipherSuites(t *testing.T) { // Wait for core to become active TestWaitActive(t, core.Core) - core.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core.Core}) + core.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core}) parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate) - dialer := core.getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert) + dialer := core.getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert) - netConn, err := dialer(core.clusterListener.clusterListenerAddrs[0].String(), 0) + netConn, err := dialer(core.clusterListener.Addrs()[0].String(), 0) conn := netConn.(*tls.Conn) if err != nil { t.Fatal(err) diff --git a/vault/core.go b/vault/core.go index 042054c208..33bbcd7c99 100644 --- a/vault/core.go +++ b/vault/core.go @@ -38,6 +38,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/shamir" + "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/seal" ) @@ -419,7 +420,7 @@ type Core struct { loadCaseSensitiveIdentityStore bool // clusterListener starts up and manages connections on the cluster ports - clusterListener *ClusterListener + clusterListener *cluster.Listener // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -592,7 +593,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { maxLeaseTTL: conf.MaxLeaseTTL, cachingDisabled: conf.DisableCache, clusterName: conf.ClusterName, - clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second), + clusterPeerClusterAddrsCache: cache.New(3*cluster.HeartbeatInterval, time.Second), enableMlock: !conf.DisableMlock, rawEnabled: conf.EnableRaw, replicationState: new(uint32), diff --git a/vault/core_util.go b/vault/core_util.go index 0ec09b5a31..7fba9c6e5b 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/license" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" + "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/replication" cache "github.com/patrickmn/go-cache" ) @@ -111,5 +112,5 @@ func (c *Core) invalidateSentinelPolicy(PolicyType, string) {} func (c *Core) removePerfStandbySecondary(context.Context, string) {} func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) { - return nil, cache.New(2*HeartbeatInterval, 1*time.Second), make(chan struct{}), nil + return nil, cache.New(2*cluster.HeartbeatInterval, 1*time.Second), make(chan struct{}), nil } diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index da743cf4d0..03b4ab813f 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -16,6 +16,8 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/forwarding" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/replication" cache 
"github.com/patrickmn/go-cache" "golang.org/x/net/http2" @@ -23,27 +25,6 @@ import ( "google.golang.org/grpc/keepalive" ) -const ( - clusterListenerAcceptDeadline = 500 * time.Millisecond - - // PerformanceReplicationALPN is the negotiated protocol used for - // performance replication. - PerformanceReplicationALPN = "replication_v1" - - // DRReplicationALPN is the negotiated protocol used for - // dr replication. - DRReplicationALPN = "replication_dr_v1" - - perfStandbyALPN = "perf_standby_v1" - - requestForwardingALPN = "req_fw_sb-act_v1" -) - -var ( - // Making this a package var allows tests to modify - HeartbeatInterval = 5 * time.Second -) - type requestForwardingHandler struct { fws *http2.Server fwRPCServer *grpc.Server @@ -65,7 +46,7 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch fwRPCServer := grpc.NewServer( grpc.KeepaliveParams(keepalive.ServerParameters{ - Time: 2 * HeartbeatInterval, + Time: 2 * cluster.HeartbeatInterval, }), grpc.MaxRecvMsgSize(math.MaxInt32), grpc.MaxSendMsgSize(math.MaxInt32), @@ -190,7 +171,7 @@ func (rf *requestForwardingHandler) Handoff(ctx context.Context, shutdownWg *syn // Stop stops the request forwarding server and closes connections. func (rf *requestForwardingHandler) Stop() error { // Give some time for existing RPCs to drain. - time.Sleep(clusterListenerAcceptDeadline) + time.Sleep(cluster.ListenerAcceptDeadline) close(rf.stopCh) rf.fwRPCServer.Stop() return nil @@ -198,16 +179,16 @@ func (rf *requestForwardingHandler) Stop() error { // Starts the listeners and servers necessary to handle forwarded requests func (c *Core) startForwarding(ctx context.Context) error { - c.logger.Debug("cluster listener setup function") - defer c.logger.Debug("leaving cluster listener setup function") + c.logger.Debug("request forwarding setup function") + defer c.logger.Debug("leaving request forwarding setup function") // Clean up in case we have transitioned from a client to a server c.requestForwardingConnectionLock.Lock() c.clearForwardingClients() c.requestForwardingConnectionLock.Unlock() - // Resolve locally to avoid races if c.ha == nil || c.clusterListener == nil { + c.logger.Debug("request forwarding not setup") return nil } @@ -221,15 +202,15 @@ func (c *Core) startForwarding(ctx context.Context) error { return err } - c.clusterListener.AddHandler(requestForwardingALPN, handler) + c.clusterListener.AddHandler(consts.RequestForwardingALPN, handler) return nil } func (c *Core) stopForwarding() { if c.clusterListener != nil { - c.clusterListener.StopHandler(requestForwardingALPN) - c.clusterListener.StopHandler(perfStandbyALPN) + c.clusterListener.StopHandler(consts.RequestForwardingALPN) + c.clusterListener.StopHandler(consts.PerfStandbyALPN) } } @@ -264,7 +245,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd } if c.clusterListener != nil { - c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{ + c.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{ core: c, }) } @@ -275,10 +256,10 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd // the TLS state. 
dctx, cancelFunc := context.WithCancel(ctx) c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host, - grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)), + grpc.WithDialer(c.getGRPCDialer(ctx, consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)), grpc.WithInsecure(), // it's not, we handle it in the dialer grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 2 * HeartbeatInterval, + Time: 2 * cluster.HeartbeatInterval, }), grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(math.MaxInt32), @@ -294,7 +275,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd c.rpcForwardingClient = &forwardingClient{ RequestForwardingClient: NewRequestForwardingClient(c.rpcClientConn), core: c, - echoTicker: time.NewTicker(HeartbeatInterval), + echoTicker: time.NewTicker(cluster.HeartbeatInterval), echoContext: dctx, } c.rpcForwardingClient.startHeartbeat() @@ -319,7 +300,7 @@ func (c *Core) clearForwardingClients() { c.rpcForwardingClient = nil if c.clusterListener != nil { - c.clusterListener.RemoveClient(requestForwardingALPN) + c.clusterListener.RemoveClient(consts.RequestForwardingALPN) } c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil)) } diff --git a/vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go b/vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go index 972a69f47b..769a785836 100644 --- a/vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go +++ b/vendor/github.com/hashicorp/vault/sdk/helper/consts/consts.go @@ -11,4 +11,18 @@ const ( // AuthHeaderName is the name of the header containing the token. AuthHeaderName = "X-Vault-Token" + + // PerformanceReplicationALPN is the negotiated protocol used for + // performance replication. + PerformanceReplicationALPN = "replication_v1" + + // DRReplicationALPN is the negotiated protocol used for + // dr replication. + DRReplicationALPN = "replication_dr_v1" + + PerfStandbyALPN = "perf_standby_v1" + + RequestForwardingALPN = "req_fw_sb-act_v1" + + RaftStorageALPN = "raft_storage_v1" )
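
Reviewer note: for anyone consuming the extracted package, the sketch below shows how a cluster application might register itself against the new `cluster.Listener`. It is a minimal illustration assuming only the API visible in this diff (`NewListener`, `AddHandler`, `Run`, `Stop`, `Addrs`, and the `Handler` interface); `echoHandler`, the `echo_v1` ALPN name, and the certificate plumbing are hypothetical stand-ins, not part of this change.

```go
package main

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"net"
	"sync"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/vault/cluster"
)

// echoHandler is a hypothetical cluster.Handler for an invented "echo_v1"
// ALPN protocol. Certificate and CA provisioning is elided.
type echoHandler struct {
	cert *tls.Certificate  // server certificate presented during the handshake
	ca   *x509.Certificate // CA used to verify client certificates
}

func (h *echoHandler) ServerLookup(_ context.Context, _ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	return h.cert, nil
}

func (h *echoHandler) CALookup(_ context.Context) (*x509.Certificate, error) {
	return h.ca, nil
}

// Handoff takes over the connection's lifetime. A real handler would serve
// the conn (e.g. via the shared http2.Server); here we just hold it open
// until the listener's accept loop closes stopCh.
func (h *echoHandler) Handoff(_ context.Context, wg *sync.WaitGroup, stopCh chan struct{}, conn *tls.Conn) error {
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer conn.Close()
		<-stopCh
	}()
	return nil
}

func (h *echoHandler) Stop() error { return nil }

func main() {
	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}

	// nil cipher suites lets crypto/tls fall back to its defaults.
	ln := cluster.NewListener([]*net.TCPAddr{laddr}, nil, log.Default().Named("cluster-listener"))
	ln.AddHandler("echo_v1", &echoHandler{ /* certs elided */ })

	// Run returns after spawning the accept loops.
	if err := ln.Run(context.Background()); err != nil {
		panic(err)
	}
	defer ln.Stop() // flips the shutdown flag, then waits for the loops to drain

	// Port 0 was rewritten by Run to the port the OS assigned.
	log.Default().Info("cluster listener up", "addr", ln.Addrs()[0])
	// ... block until shutdown (elided) ...
}
```

Because `serverConfigLookup` mirrors `clientHello.SupportedProtos` back into `NextProtos`, any ALPN with a registered handler can negotiate, which is what lets additional protocols such as `consts.RaftStorageALPN` plug in later without touching the listener itself.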