VAULT-31749: Interceptors to reject requests from removed Raft nodes (#28875)
* initial interceptors
* tests and request handling
* remove comment
* test comments
* changelog
* pr fixes
* reuse existing method
* fix test
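For orientation before the diff: the change adds gRPC interceptors on the request-forwarding server that read the calling node's ID from request metadata and reject the call with a FailedPrecondition status when that node has been removed from the raft cluster; matching client interceptors attach the node ID and react to a rejection by removing themselves and shutting down. Below is a minimal standalone sketch of the server-side pattern; the metadata key and error message mirror this change, but the structure is illustrative only and is not the PR's code.

package sketch

import (
	"context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// rejectRemovedNodes is a simplified unary server interceptor: it pulls a node
// ID out of the incoming gRPC metadata and fails the request when the supplied
// lookup reports that node as removed. isRemoved is a stand-in for the backend
// check used in the real change.
func rejectRemovedNodes(isRemoved func(ctx context.Context, nodeID string) (bool, error)) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if ok {
			if ids := md.Get("ha_node_id"); len(ids) > 0 {
				removed, err := isRemoved(ctx, ids[0])
				if err != nil {
					return nil, err
				}
				if removed {
					return nil, status.Error(codes.FailedPrecondition, "node is not in HA cluster membership")
				}
			}
		}
		return handler(ctx, req)
	}
}

The real interceptors, their stream-side counterparts, and the wiring into grpc.NewServer and the forwarding client connection all appear in the diff below.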
changelog/28875.txt (new file, +3)
@@ -0,0 +1,3 @@
```release-note:change
storage/raft: Do not allow nodes that have been removed from the raft cluster configuration to respond to requests. Shutdown and seal raft nodes when they are removed.
```
@@ -988,9 +988,14 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
	// ErrCannotForward and we simply fall back
	statusCode, header, retBytes, err := core.ForwardRequest(r)
	if err != nil {
-		if err == vault.ErrCannotForward {
+		switch {
+		case errors.Is(err, vault.ErrCannotForward):
			core.Logger().Trace("cannot forward request (possibly disabled on active node), falling back to redirection to standby")
-		} else {
+		case errors.Is(err, vault.StatusNotHAMember):
+			core.Logger().Trace("this node is not a member of the HA cluster", "error", err)
+			respondError(w, http.StatusInternalServerError, err)
+			return
+		default:
			core.Logger().Error("forward request error", "error", err)
		}
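A note on the errors.Is check above: vault.StatusNotHAMember is a gRPC status error built with status.Errorf (see the request-forwarding changes below), and ForwardRequest wraps it with %w before returning. grpc-go status errors implement Is by comparing the underlying status (code and message) rather than pointer identity, so the sentinel still matches after the error has crossed the wire and been wrapped. A small standalone sketch of that behavior, illustrative only and not part of the change:

package sketch

import (
	"errors"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Sentinel mirroring vault.StatusNotHAMember (same code and message).
var statusNotHAMember = status.Errorf(codes.FailedPrecondition, "node is not in HA cluster membership")

// MatchesAfterWrapping simulates the client side: the server's status error is
// re-created from the wire (a distinct value) and then wrapped with %w, yet
// errors.Is still matches because gRPC status errors compare by code and
// message, not by identity.
func MatchesAfterWrapping() bool {
	fromWire := status.Errorf(codes.FailedPrecondition, "node is not in HA cluster membership")
	wrapped := fmt.Errorf("error during forwarding RPC request: %w", fromWire)
	return errors.Is(wrapped, statusNotHAMember) // true
}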
@@ -4579,16 +4579,8 @@ func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
// RemovableNodeHABackend interface. The value of the `ok` result will be false
// if the HA and underlyingPhysical backends are nil or do not support this operation.
func (c *Core) IsRemovedFromCluster() (removed, ok bool) {
-	var haBackend any
-	if c.ha != nil {
-		haBackend = c.ha
-	} else if c.underlyingPhysical != nil {
-		haBackend = c.underlyingPhysical
-	} else {
-		return false, false
-	}
-	removableNodeHA, ok := haBackend.(physical.RemovableNodeHABackend)
-	if !ok {
+	removableNodeHA := c.getRemovableHABackend()
+	if removableNodeHA == nil {
		return false, false
	}
@@ -3720,7 +3720,7 @@ func TestCore_IsRemovedFromCluster(t *testing.T) {
	core.underlyingPhysical = mockHA
	removed, ok = core.IsRemovedFromCluster()
	if removed || !ok {
-		t.Fatalf("expected removed and ok to be false, got removed: %v, ok: %v", removed, ok)
+		t.Fatalf("expected removed to be false and ok to be true, got removed: %v, ok: %v", removed, ok)
	}

	// Test case where HA backend is nil, but the underlying physical is there, supports RemovableNodeHABackend, and is removed

@@ -3731,6 +3731,7 @@ func TestCore_IsRemovedFromCluster(t *testing.T) {
	}

	// Test case where HA backend does not support RemovableNodeHABackend
	core.underlyingPhysical = &MockHABackend{}
+	core.ha = &MockHABackend{}
	removed, ok = core.IsRemovedFromCluster()
	if removed || ok {
@@ -1360,3 +1360,33 @@ func TestRaft_Join_InitStatus(t *testing.T) {
		verifyInitStatus(i, true)
	}
}
+
+// TestRaftCluster_Removed creates a 3 node raft cluster and then removes one of
+// the nodes. The test verifies that a write on the removed node errors, and that
+// the removed node is sealed.
+func TestRaftCluster_Removed(t *testing.T) {
+	t.Parallel()
+	cluster, _ := raftCluster(t, nil)
+	defer cluster.Cleanup()
+
+	follower := cluster.Cores[2]
+	followerClient := follower.Client
+	_, err := followerClient.Logical().Write("secret/foo", map[string]interface{}{
+		"test": "data",
+	})
+	require.NoError(t, err)
+
+	_, err = cluster.Cores[0].Client.Logical().Write("/sys/storage/raft/remove-peer", map[string]interface{}{
+		"server_id": follower.NodeID,
+	})
+	followerClient.SetCheckRedirect(func(request *http.Request, requests []*http.Request) error {
+		require.Fail(t, "request caused a redirect", request.URL.Path)
+		return fmt.Errorf("no redirects allowed")
+	})
+	require.NoError(t, err)
+	_, err = followerClient.Logical().Write("secret/foo", map[string]interface{}{
+		"test": "other_data",
+	})
+	require.Error(t, err)
+	require.True(t, follower.Sealed())
+}
vault/ha.go (+13)
@@ -1223,3 +1223,16 @@ func (c *Core) SetNeverBecomeActive(on bool) {
		atomic.StoreUint32(c.neverBecomeActive, 0)
	}
}
+
+func (c *Core) getRemovableHABackend() physical.RemovableNodeHABackend {
+	var haBackend physical.RemovableNodeHABackend
+	if removableHA, ok := c.ha.(physical.RemovableNodeHABackend); ok {
+		haBackend = removableHA
+	}
+
+	if removableHA, ok := c.underlyingPhysical.(physical.RemovableNodeHABackend); ok {
+		haBackend = removableHA
+	}
+
+	return haBackend
+}
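One detail worth calling out in getRemovableHABackend above: neither c.ha nor c.underlyingPhysical is nil-checked, because a type assertion on a nil interface value does not panic, it simply reports ok == false. And since the underlyingPhysical assertion runs second, it takes precedence when both backends implement the interface. A tiny standalone illustration of the assertion behavior, using stand-in types rather than Vault's:

package sketch

// RemovableNodeHABackend stands in for physical.RemovableNodeHABackend; the
// real interface lives in the Vault SDK.
type RemovableNodeHABackend interface {
	NodeID() string
}

// HABackend stands in for the general backend interface type.
type HABackend interface{}

// pickRemovable mirrors the shape of getRemovableHABackend: asserting a nil
// interface value is safe and just yields ok == false, and the second check
// wins when both arguments satisfy the interface.
func pickRemovable(ha, underlying HABackend) RemovableNodeHABackend {
	var out RemovableNodeHABackend
	if r, ok := ha.(RemovableNodeHABackend); ok { // ha may be nil; no panic
		out = r
	}
	if r, ok := underlying.(RemovableNodeHABackend); ok {
		out = r
	}
	return out
}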
@@ -22,13 +22,116 @@ import (
	"github.com/hashicorp/vault/helper/forwarding"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/sdk/physical"
	"github.com/hashicorp/vault/vault/cluster"
	"github.com/hashicorp/vault/vault/replication"
	"golang.org/x/net/http2"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)
+
+var (
+	NotHAMember       = "node is not in HA cluster membership"
+	StatusNotHAMember = status.Errorf(codes.FailedPrecondition, NotHAMember)
+)
+
+const haNodeIDKey = "ha_node_id"
+
+func haIDFromContext(ctx context.Context) (string, bool) {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return "", false
+	}
+	res := md.Get(haNodeIDKey)
+	if len(res) == 0 {
+		return "", false
+	}
+	return res[0], true
+}
+
+// haMembershipServerCheck extracts the client's HA node ID from the context
+// and checks if this client has been removed. The function returns
+// StatusNotHAMember if the client has been removed
+func haMembershipServerCheck(ctx context.Context, c *Core, haBackend physical.RemovableNodeHABackend) error {
+	if haBackend == nil {
+		return nil
+	}
+	nodeID, ok := haIDFromContext(ctx)
+	if !ok {
+		return nil
+	}
+	removed, err := haBackend.IsNodeRemoved(ctx, nodeID)
+	if err != nil {
+		c.logger.Error("failed to check if node is removed", "error", err)
+		return err
+	}
+	if removed {
+		return StatusNotHAMember
+	}
+	return nil
+}
+
+func haMembershipUnaryServerInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
+		err = haMembershipServerCheck(ctx, c, haBackend)
+		if err != nil {
+			return nil, err
+		}
+		return handler(ctx, req)
+	}
+}
+
+func haMembershipStreamServerInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.StreamServerInterceptor {
+	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		err := haMembershipServerCheck(ss.Context(), c, haBackend)
+		if err != nil {
+			return err
+		}
+		return handler(srv, ss)
+	}
+}
+
+// haMembershipClientCheck checks if the given error from the server
+// is StatusNotHAMember. If so, the client will mark itself as removed
+// and shutdown
+func haMembershipClientCheck(err error, c *Core, haBackend physical.RemovableNodeHABackend) {
+	if !errors.Is(err, StatusNotHAMember) {
+		return
+	}
+	removeErr := haBackend.RemoveSelf()
+	if removeErr != nil {
+		c.logger.Debug("failed to remove self", "error", removeErr)
+	}
+	go c.ShutdownCoreError(errors.New("node removed from HA configuration"))
+}
+
+func haMembershipUnaryClientInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.UnaryClientInterceptor {
+	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		if haBackend == nil {
+			return invoker(ctx, method, req, reply, cc, opts...)
+		}
+		ctx = metadata.AppendToOutgoingContext(ctx, haNodeIDKey, haBackend.NodeID())
+		err := invoker(ctx, method, req, reply, cc, opts...)
+		haMembershipClientCheck(err, c, haBackend)
+		return err
+	}
+}
+
+func haMembershipStreamClientInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.StreamClientInterceptor {
+	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+		if haBackend == nil {
+			return streamer(ctx, desc, cc, method, opts...)
+		}
+		ctx = metadata.AppendToOutgoingContext(ctx, haNodeIDKey, haBackend.NodeID())
+		stream, err := streamer(ctx, desc, cc, method, opts...)
+		haMembershipClientCheck(err, c, haBackend)
+		return stream, err
+	}
+}

type requestForwardingHandler struct {
	fws *http2.Server
	fwRPCServer *grpc.Server
@@ -47,6 +150,7 @@ type requestForwardingClusterClient struct {
func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots chan struct{}, perfStandbyRepCluster *replication.Cluster) (*requestForwardingHandler, error) {
	// Resolve locally to avoid races
	ha := c.ha != nil
+	removableHABackend := c.getRemovableHABackend()

	fwRPCServer := grpc.NewServer(
		grpc.KeepaliveParams(keepalive.ServerParameters{

@@ -54,6 +158,8 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch
		}),
		grpc.MaxRecvMsgSize(math.MaxInt32),
		grpc.MaxSendMsgSize(math.MaxInt32),
+		grpc.StreamInterceptor(haMembershipStreamServerInterceptor(c, removableHABackend)),
+		grpc.UnaryInterceptor(haMembershipUnaryServerInterceptor(c, removableHABackend)),
	)

	if ha && c.clusterHandler != nil {

@@ -274,6 +380,8 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
		core: c,
	})

+	removableHABackend := c.getRemovableHABackend()
+
	// Set up grpc forwarding handling
	// It's not really insecure, but we have to dial manually to get the
	// ALPN header right. It's just "insecure" because GRPC isn't managing

@@ -285,6 +393,8 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time: 2 * c.clusterHeartbeatInterval,
		}),
+		grpc.WithStreamInterceptor(haMembershipStreamClientInterceptor(c, removableHABackend)),
+		grpc.WithUnaryInterceptor(haMembershipUnaryClientInterceptor(c, removableHABackend)),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(math.MaxInt32),
			grpc.MaxCallSendMsgSize(math.MaxInt32),

@@ -374,6 +484,10 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
	if err != nil {
		metrics.IncrCounter([]string{"ha", "rpc", "client", "forward", "errors"}, 1)
		c.logger.Error("error during forwarded RPC request", "error", err)
+
+		if errors.Is(err, StatusNotHAMember) {
+			return 0, nil, nil, fmt.Errorf("error during forwarding RPC request: %w", err)
+		}
		return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
	}
vault/request_forwarding_test.go (new file, +131)
@@ -0,0 +1,131 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package vault

import (
	"context"
	"errors"
	"testing"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/sdk/physical"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc/metadata"
)

// Test_haIDFromContext verifies that the HA node ID gets correctly extracted
// from a gRPC context
func Test_haIDFromContext(t *testing.T) {
	testCases := []struct {
		name   string
		md     metadata.MD
		wantID string
		wantOk bool
	}{
		{
			name:   "no ID",
			md:     metadata.MD{},
			wantID: "",
			wantOk: false,
		},
		{
			name:   "with ID",
			md:     metadata.MD{haNodeIDKey: {"node_id"}},
			wantID: "node_id",
			wantOk: true,
		},
		{
			name:   "with empty string ID",
			md:     metadata.MD{haNodeIDKey: {""}},
			wantID: "",
			wantOk: true,
		},
		{
			name:   "with empty ID",
			md:     metadata.MD{haNodeIDKey: {}},
			wantID: "",
			wantOk: false,
		},

		{
			name:   "with multiple IDs",
			md:     metadata.MD{haNodeIDKey: {"1", "2"}},
			wantID: "1",
			wantOk: true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ctx := metadata.NewIncomingContext(context.Background(), tc.md)
			id, ok := haIDFromContext(ctx)
			require.Equal(t, tc.wantID, id)
			require.Equal(t, tc.wantOk, ok)
		})
	}
}

type mockHARemovableNodeBackend struct {
	physical.RemovableNodeHABackend
	isRemoved func(context.Context, string) (bool, error)
}

func (m *mockHARemovableNodeBackend) IsNodeRemoved(ctx context.Context, nodeID string) (bool, error) {
	return m.isRemoved(ctx, nodeID)
}

func newMockHARemovableNodeBackend(isRemoved func(context.Context, string) (bool, error)) physical.RemovableNodeHABackend {
	return &mockHARemovableNodeBackend{isRemoved: isRemoved}
}
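The mock above relies on interface embedding: embedding physical.RemovableNodeHABackend in the struct satisfies the full interface while only IsNodeRemoved is actually overridden; any other method would hit the nil embedded value and panic, which is acceptable for a test that only exercises IsNodeRemoved. A generic standalone sketch of the same pattern, with illustrative types rather than Vault's:

package sketch

import "context"

// Store is a stand-in for a wide interface that a test only partially needs.
type Store interface {
	Get(ctx context.Context, key string) (string, error)
	Put(ctx context.Context, key, value string) error
	Delete(ctx context.Context, key string) error
}

// partialStore embeds Store, so it satisfies the whole interface without
// implementing every method. Only Get is overridden; calling Put or Delete
// would panic on the nil embedded value, which keeps the mock honest about
// what the test actually exercises.
type partialStore struct {
	Store
	get func(ctx context.Context, key string) (string, error)
}

func (p *partialStore) Get(ctx context.Context, key string) (string, error) {
	return p.get(ctx, key)
}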
// Test_haMembershipServerCheck verifies that the correct error is returned
// when the context contains a removed node ID
func Test_haMembershipServerCheck(t *testing.T) {
	nodeIDCtx := metadata.NewIncomingContext(context.Background(), metadata.MD{haNodeIDKey: {"node_id"}})
	otherErr := errors.New("error checking")
	testCases := []struct {
		name      string
		nodeIDCtx context.Context
		haBackend physical.RemovableNodeHABackend
		wantError error
	}{
		{
			name:      "nil backend",
			haBackend: nil,
			nodeIDCtx: nodeIDCtx,
		}, {
			name: "no node ID context",
			haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
				return false, nil
			}),
			nodeIDCtx: context.Background(),
		}, {
			name: "node removed",
			haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
				return true, nil
			}),
			nodeIDCtx: nodeIDCtx,
			wantError: StatusNotHAMember,
		}, {
			name: "node removed err",
			haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
				return false, otherErr
			}),
			nodeIDCtx: nodeIDCtx,
			wantError: otherErr,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := &Core{
				logger: hclog.NewNullLogger(),
			}
			err := haMembershipServerCheck(tc.nodeIDCtx, c, tc.haBackend)
			if tc.wantError != nil {
				require.EqualError(t, err, tc.wantError.Error())
			} else {
				require.NoError(t, err)
			}
		})
	}
}