backport of commit c67242463c (#20830)

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
hc-github-team-secure-vault-core
2023-05-29 11:02:27 -04:00
committed by GitHub
parent fd40c5509b
commit 0ca00475cd
5 changed files with 8 additions and 42 deletions

3
changelog/20826.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:change
core: Revert #19676 (VAULT_GRPC_MIN_CONNECT_TIMEOUT env var) as we decided it was unnecessary.
```

View File

@@ -330,8 +330,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
c.clusterListener.Store(cluster.NewListener(networkLayer, c.clusterListener.Store(cluster.NewListener(networkLayer,
c.clusterCipherSuites, c.clusterCipherSuites,
listenerLogger, listenerLogger,
5*c.clusterHeartbeatInterval, 5*c.clusterHeartbeatInterval))
c.grpcMinConnectTimeout))
c.AddLogger(listenerLogger) c.AddLogger(listenerLogger)

View File

@@ -75,10 +75,9 @@ type Listener struct {
logger log.Logger logger log.Logger
l sync.RWMutex l sync.RWMutex
tlsConnectionLoggingLevel log.Level tlsConnectionLoggingLevel log.Level
grpcMinConnectTimeout time.Duration
} }
func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger, idleTimeout, grpcMinConnectTimeout time.Duration) *Listener { func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger, idleTimeout time.Duration) *Listener {
var maxStreams uint32 = math.MaxUint32 var maxStreams uint32 = math.MaxUint32
if override := os.Getenv("VAULT_GRPC_MAX_STREAMS"); override != "" { if override := os.Getenv("VAULT_GRPC_MAX_STREAMS"); override != "" {
i, err := strconv.ParseUint(override, 10, 32) i, err := strconv.ParseUint(override, 10, 32)
@@ -115,7 +114,6 @@ func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Lo
cipherSuites: cipherSuites, cipherSuites: cipherSuites,
logger: logger, logger: logger,
tlsConnectionLoggingLevel: log.LevelFromString(os.Getenv("VAULT_CLUSTER_TLS_SESSION_LOG_LEVEL")), tlsConnectionLoggingLevel: log.LevelFromString(os.Getenv("VAULT_CLUSTER_TLS_SESSION_LOG_LEVEL")),
grpcMinConnectTimeout: grpcMinConnectTimeout,
} }
} }
@@ -466,21 +464,10 @@ func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string,
} }
tlsConfig.NextProtos = []string{alpn} tlsConfig.NextProtos = []string{alpn}
args := []interface{}{ cl.logger.Debug("creating rpc dialer", "address", addr, "alpn", alpn, "host", tlsConfig.ServerName)
"address", addr,
"alpn", alpn,
"host", tlsConfig.ServerName,
"timeout", fmt.Sprintf("%s", timeout),
}
if cl.grpcMinConnectTimeout != 0 {
args = append(args, "timeout_env_override", fmt.Sprintf("%s", cl.grpcMinConnectTimeout))
}
cl.logger.Debug("creating rpc dialer", args...)
start := time.Now()
conn, err := cl.networkLayer.Dial(addr, timeout, tlsConfig) conn, err := cl.networkLayer.Dial(addr, timeout, tlsConfig)
if err != nil { if err != nil {
cl.logger.Debug("dial failure", "address", addr, "alpn", alpn, "host", tlsConfig.ServerName, "duration", fmt.Sprintf("%s", time.Since(start)), "error", err)
return nil, err return nil, err
} }
cl.logTLSSessionStart(conn.RemoteAddr().String(), conn.ConnectionState()) cl.logTLSSessionStart(conn.RemoteAddr().String(), conn.ConnectionState())

View File

@@ -695,9 +695,6 @@ type Core struct {
// if populated, the callback is called for every request // if populated, the callback is called for every request
// for testing purposes // for testing purposes
requestResponseCallback func(logical.Backend, *logical.Request, *logical.Response) requestResponseCallback func(logical.Backend, *logical.Request, *logical.Response)
// if populated, override the default gRPC min connect timeout (currently 20s in grpc 1.51)
grpcMinConnectTimeout time.Duration
} }
// c.stateLock needs to be held in read mode before calling this function. // c.stateLock needs to be held in read mode before calling this function.
@@ -1282,16 +1279,6 @@ func NewCore(conf *CoreConfig) (*Core, error) {
c.events.Start() c.events.Start()
} }
minConnectTimeoutRaw := os.Getenv("VAULT_GRPC_MIN_CONNECT_TIMEOUT")
if minConnectTimeoutRaw != "" {
dur, err := time.ParseDuration(minConnectTimeoutRaw)
if err != nil {
c.logger.Warn("VAULT_GRPC_MIN_CONNECT_TIMEOUT contains non-duration value, ignoring")
} else if dur != 0 {
c.grpcMinConnectTimeout = dur
}
}
return c, nil return c, nil
} }

View File

@@ -25,7 +25,6 @@ import (
"github.com/hashicorp/vault/vault/replication" "github.com/hashicorp/vault/vault/replication"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
) )
@@ -279,8 +278,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
// ALPN header right. It's just "insecure" because GRPC isn't managing // ALPN header right. It's just "insecure" because GRPC isn't managing
// the TLS state. // the TLS state.
dctx, cancelFunc := context.WithCancel(ctx) dctx, cancelFunc := context.WithCancel(ctx)
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
opts := []grpc.DialOption{
grpc.WithDialer(clusterListener.GetDialerFunc(ctx, consts.RequestForwardingALPN)), grpc.WithDialer(clusterListener.GetDialerFunc(ctx, consts.RequestForwardingALPN)),
grpc.WithInsecure(), // it's not, we handle it in the dialer grpc.WithInsecure(), // it's not, we handle it in the dialer
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
@@ -289,15 +287,7 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
grpc.WithDefaultCallOptions( grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(math.MaxInt32), grpc.MaxCallRecvMsgSize(math.MaxInt32),
grpc.MaxCallSendMsgSize(math.MaxInt32), grpc.MaxCallSendMsgSize(math.MaxInt32),
), ))
}
if c.grpcMinConnectTimeout != 0 {
opts = append(opts, grpc.WithConnectParams(grpc.ConnectParams{
MinConnectTimeout: c.grpcMinConnectTimeout,
Backoff: backoff.DefaultConfig,
}))
}
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host, opts...)
if err != nil { if err != nil {
cancelFunc() cancelFunc()
c.logger.Error("err setting up forwarding rpc client", "error", err) c.logger.Error("err setting up forwarding rpc client", "error", err)