VAULT-31749: Interceptors to reject requests from removed Raft nodes (#28875)

* initial interceptors

* tests and request handling

* remove comment

* test comments

* changelog

* pr fixes

* reuse existing method

* fix test
Author: miagilepner
Date: 2024-11-18 10:18:32 +01:00
Committed by: GitHub
Parent: cb0448a785
Commit: dce93e3d6c
8 changed files with 302 additions and 13 deletions

changelog/28875.txt (new file)

@@ -0,0 +1,3 @@
```release-note:change
storage/raft: Do not allow nodes that have been removed from the raft cluster configuration to respond to requests. Shutdown and seal raft nodes when they are removed.
```


@@ -988,9 +988,14 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
// ErrCannotForward and we simply fall back
statusCode, header, retBytes, err := core.ForwardRequest(r)
if err != nil {
-if err == vault.ErrCannotForward {
+switch {
+case errors.Is(err, vault.ErrCannotForward):
core.Logger().Trace("cannot forward request (possibly disabled on active node), falling back to redirection to standby")
-} else {
+case errors.Is(err, vault.StatusNotHAMember):
+core.Logger().Trace("this node is not a member of the HA cluster", "error", err)
+respondError(w, http.StatusInternalServerError, err)
+return
+default:
core.Logger().Error("forward request error", "error", err)
}
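For context, the new StatusNotHAMember branch means a removed standby now answers with HTTP 500 instead of forwarding or redirecting. A minimal sketch of what a caller might observe through the official Go API client (hypothetical usage, not part of this diff):

```go
package main

import (
	"errors"
	"fmt"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, _ := api.NewClient(api.DefaultConfig())
	// A write against a node that was removed from the HA cluster should now
	// fail fast rather than being forwarded to the active node.
	_, err := client.Logical().Write("secret/foo", map[string]interface{}{"test": "data"})
	var respErr *api.ResponseError
	if errors.As(err, &respErr) {
		fmt.Println(respErr.StatusCode, respErr.Errors) // expected: 500 plus the membership error
	}
}
```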


@@ -4579,16 +4579,8 @@ func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
// RemovableNodeHABackend interface. The value of the `ok` result will be false
// if the HA and underlyingPhysical backends are nil or do not support this operation.
func (c *Core) IsRemovedFromCluster() (removed, ok bool) {
-var haBackend any
-if c.ha != nil {
-haBackend = c.ha
-} else if c.underlyingPhysical != nil {
-haBackend = c.underlyingPhysical
-} else {
-return false, false
-}
-removableNodeHA, ok := haBackend.(physical.RemovableNodeHABackend)
-if !ok {
+removableNodeHA := c.getRemovableHABackend()
+if removableNodeHA == nil {
return false, false
}


@@ -3720,7 +3720,7 @@ func TestCore_IsRemovedFromCluster(t *testing.T) {
core.underlyingPhysical = mockHA
removed, ok = core.IsRemovedFromCluster()
if removed || !ok {
t.Fatalf("expected removed and ok to be false, got removed: %v, ok: %v", removed, ok)
t.Fatalf("expected removed to be false and ok to be true, got removed: %v, ok: %v", removed, ok)
}
// Test case where HA backend is nil, but the underlying physical is there, supports RemovableNodeHABackend, and is removed
@@ -3731,6 +3731,7 @@ func TestCore_IsRemovedFromCluster(t *testing.T) {
}
// Test case where HA backend does not support RemovableNodeHABackend
core.underlyingPhysical = &MockHABackend{}
+core.ha = &MockHABackend{}
removed, ok = core.IsRemovedFromCluster()
if removed || ok {


@@ -1360,3 +1360,33 @@ func TestRaft_Join_InitStatus(t *testing.T) {
verifyInitStatus(i, true)
}
}
// TestRaftCluster_Removed creates a 3 node raft cluster and then removes one of
// the nodes. The test verifies that a write on the removed node errors, and that
// the removed node is sealed.
func TestRaftCluster_Removed(t *testing.T) {
t.Parallel()
cluster, _ := raftCluster(t, nil)
defer cluster.Cleanup()
follower := cluster.Cores[2]
followerClient := follower.Client
_, err := followerClient.Logical().Write("secret/foo", map[string]interface{}{
"test": "data",
})
require.NoError(t, err)
_, err = cluster.Cores[0].Client.Logical().Write("/sys/storage/raft/remove-peer", map[string]interface{}{
"server_id": follower.NodeID,
})
followerClient.SetCheckRedirect(func(request *http.Request, requests []*http.Request) error {
require.Fail(t, "request caused a redirect", request.URL.Path)
return fmt.Errorf("no redirects allowed")
})
require.NoError(t, err)
_, err = followerClient.Logical().Write("secret/foo", map[string]interface{}{
"test": "other_data",
})
require.Error(t, err)
require.True(t, follower.Sealed())
}


@@ -1223,3 +1223,16 @@ func (c *Core) SetNeverBecomeActive(on bool) {
atomic.StoreUint32(c.neverBecomeActive, 0)
}
}
func (c *Core) getRemovableHABackend() physical.RemovableNodeHABackend {
var haBackend physical.RemovableNodeHABackend
if removableHA, ok := c.ha.(physical.RemovableNodeHABackend); ok {
haBackend = removableHA
}
if removableHA, ok := c.underlyingPhysical.(physical.RemovableNodeHABackend); ok {
haBackend = removableHA
}
return haBackend
}
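For reference, this is the shape of physical.RemovableNodeHABackend as it can be inferred from the calls made in this change (NodeID, IsNodeRemoved, RemoveSelf); a sketch only, since the real interface in sdk/physical may declare additional methods:

```go
// Sketch inferred from usage in this diff, not copied from sdk/physical.
type RemovableNodeHABackend interface {
	physical.HABackend

	// NodeID returns this node's HA cluster identifier.
	NodeID() string
	// IsNodeRemoved reports whether the node with the given ID has been
	// removed from the cluster configuration.
	IsNodeRemoved(ctx context.Context, nodeID string) (bool, error)
	// RemoveSelf marks this node as removed locally.
	RemoveSelf() error
}
```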


@@ -22,13 +22,116 @@ import (
"github.com/hashicorp/vault/helper/forwarding"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/vault/cluster"
"github.com/hashicorp/vault/vault/replication"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
var (
NotHAMember = "node is not in HA cluster membership"
StatusNotHAMember = status.Errorf(codes.FailedPrecondition, NotHAMember)
)
const haNodeIDKey = "ha_node_id"
func haIDFromContext(ctx context.Context) (string, bool) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", false
}
res := md.Get(haNodeIDKey)
if len(res) == 0 {
return "", false
}
return res[0], true
}
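To make the flow concrete, here is a small test-style sketch (hypothetical, same package as the code above) of the metadata round trip the interceptors depend on: the client appends ha_node_id to its outgoing context, gRPC carries it as a header, and haIDFromContext reads it back on the server:

```go
func Test_haNodeIDRoundTrip(t *testing.T) {
	// Client side: attach this node's HA ID as outgoing gRPC metadata.
	ctx := metadata.AppendToOutgoingContext(context.Background(), haNodeIDKey, "node-1")

	// gRPC transmits the pairs as headers; simulate the server side by turning
	// the outgoing metadata into an incoming context.
	md, _ := metadata.FromOutgoingContext(ctx)
	serverCtx := metadata.NewIncomingContext(context.Background(), md)

	id, ok := haIDFromContext(serverCtx)
	require.True(t, ok)
	require.Equal(t, "node-1", id)
}
```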
// haMembershipServerCheck extracts the client's HA node ID from the context
// and checks whether that client has been removed. It returns
// StatusNotHAMember if the client has been removed.
func haMembershipServerCheck(ctx context.Context, c *Core, haBackend physical.RemovableNodeHABackend) error {
if haBackend == nil {
return nil
}
nodeID, ok := haIDFromContext(ctx)
if !ok {
return nil
}
removed, err := haBackend.IsNodeRemoved(ctx, nodeID)
if err != nil {
c.logger.Error("failed to check if node is removed", "error", err)
return err
}
if removed {
return StatusNotHAMember
}
return nil
}
func haMembershipUnaryServerInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
err = haMembershipServerCheck(ctx, c, haBackend)
if err != nil {
return nil, err
}
return handler(ctx, req)
}
}
func haMembershipStreamServerInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := haMembershipServerCheck(ss.Context(), c, haBackend)
if err != nil {
return err
}
return handler(srv, ss)
}
}
// haMembershipClientCheck checks whether the given error from the server
// is StatusNotHAMember. If so, the client marks itself as removed and
// shuts down.
func haMembershipClientCheck(err error, c *Core, haBackend physical.RemovableNodeHABackend) {
if !errors.Is(err, StatusNotHAMember) {
return
}
removeErr := haBackend.RemoveSelf()
if removeErr != nil {
c.logger.Debug("failed to remove self", "error", removeErr)
}
go c.ShutdownCoreError(errors.New("node removed from HA configuration"))
}
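One subtlety worth noting: the sentinel comparison above works across the wire because grpc-go's status errors implement Is by comparing code and message rather than pointer identity, so an error reconstructed from the response still matches StatusNotHAMember. A quick test-style sketch (hypothetical, assuming that grpc-go behavior):

```go
func Test_statusNotHAMemberMatches(t *testing.T) {
	// Rebuild the error the way a client receives it: a fresh status error
	// decoded from the response, not the sentinel value itself.
	wireErr := status.Error(codes.FailedPrecondition, NotHAMember)

	// *status.Error compares code and message in its Is method, so the
	// reconstructed error still matches the StatusNotHAMember sentinel.
	require.True(t, errors.Is(wireErr, StatusNotHAMember))
}
```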
func haMembershipUnaryClientInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if haBackend == nil {
return invoker(ctx, method, req, reply, cc, opts...)
}
ctx = metadata.AppendToOutgoingContext(ctx, haNodeIDKey, haBackend.NodeID())
err := invoker(ctx, method, req, reply, cc, opts...)
haMembershipClientCheck(err, c, haBackend)
return err
}
}
func haMembershipStreamClientInterceptor(c *Core, haBackend physical.RemovableNodeHABackend) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if haBackend == nil {
return streamer(ctx, desc, cc, method, opts...)
}
ctx = metadata.AppendToOutgoingContext(ctx, haNodeIDKey, haBackend.NodeID())
stream, err := streamer(ctx, desc, cc, method, opts...)
haMembershipClientCheck(err, c, haBackend)
return stream, err
}
}
type requestForwardingHandler struct {
fws *http2.Server
fwRPCServer *grpc.Server
@@ -47,6 +150,7 @@ type requestForwardingClusterClient struct {
func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots chan struct{}, perfStandbyRepCluster *replication.Cluster) (*requestForwardingHandler, error) {
// Resolve locally to avoid races
ha := c.ha != nil
removableHABackend := c.getRemovableHABackend()
fwRPCServer := grpc.NewServer(
grpc.KeepaliveParams(keepalive.ServerParameters{
@@ -54,6 +158,8 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch
}),
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
grpc.StreamInterceptor(haMembershipStreamServerInterceptor(c, removableHABackend)),
grpc.UnaryInterceptor(haMembershipUnaryServerInterceptor(c, removableHABackend)),
)
if ha && c.clusterHandler != nil {
@@ -274,6 +380,8 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
core: c,
})
removableHABackend := c.getRemovableHABackend()
// Set up grpc forwarding handling
// It's not really insecure, but we have to dial manually to get the
// ALPN header right. It's just "insecure" because GRPC isn't managing
@@ -285,6 +393,8 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 2 * c.clusterHeartbeatInterval,
}),
grpc.WithStreamInterceptor(haMembershipStreamClientInterceptor(c, removableHABackend)),
grpc.WithUnaryInterceptor(haMembershipUnaryClientInterceptor(c, removableHABackend)),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(math.MaxInt32),
grpc.MaxCallSendMsgSize(math.MaxInt32),
@@ -374,6 +484,10 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
if err != nil {
metrics.IncrCounter([]string{"ha", "rpc", "client", "forward", "errors"}, 1)
c.logger.Error("error during forwarded RPC request", "error", err)
if errors.Is(err, StatusNotHAMember) {
return 0, nil, nil, fmt.Errorf("error during forwarding RPC request: %w", err)
}
return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
}


@@ -0,0 +1,131 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package vault
import (
"context"
"errors"
"testing"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/physical"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
)
// Test_haIDFromContext verifies that the HA node ID gets correctly extracted
// from a gRPC context
func Test_haIDFromContext(t *testing.T) {
testCases := []struct {
name string
md metadata.MD
wantID string
wantOk bool
}{
{
name: "no ID",
md: metadata.MD{},
wantID: "",
wantOk: false,
},
{
name: "with ID",
md: metadata.MD{haNodeIDKey: {"node_id"}},
wantID: "node_id",
wantOk: true,
},
{
name: "with empty string ID",
md: metadata.MD{haNodeIDKey: {""}},
wantID: "",
wantOk: true,
},
{
name: "with empty ID",
md: metadata.MD{haNodeIDKey: {}},
wantID: "",
wantOk: false,
},
{
name: "with multiple IDs",
md: metadata.MD{haNodeIDKey: {"1", "2"}},
wantID: "1",
wantOk: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(), tc.md)
id, ok := haIDFromContext(ctx)
require.Equal(t, tc.wantID, id)
require.Equal(t, tc.wantOk, ok)
})
}
}
type mockHARemovableNodeBackend struct {
physical.RemovableNodeHABackend
isRemoved func(context.Context, string) (bool, error)
}
func (m *mockHARemovableNodeBackend) IsNodeRemoved(ctx context.Context, nodeID string) (bool, error) {
return m.isRemoved(ctx, nodeID)
}
func newMockHARemovableNodeBackend(isRemoved func(context.Context, string) (bool, error)) physical.RemovableNodeHABackend {
return &mockHARemovableNodeBackend{isRemoved: isRemoved}
}
// Test_haMembershipServerCheck verifies that the correct error is returned
// when the context contains a removed node ID
func Test_haMembershipServerCheck(t *testing.T) {
nodeIDCtx := metadata.NewIncomingContext(context.Background(), metadata.MD{haNodeIDKey: {"node_id"}})
otherErr := errors.New("error checking")
testCases := []struct {
name string
nodeIDCtx context.Context
haBackend physical.RemovableNodeHABackend
wantError error
}{
{
name: "nil backend",
haBackend: nil,
nodeIDCtx: nodeIDCtx,
}, {
name: "no node ID context",
haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
return false, nil
}),
nodeIDCtx: context.Background(),
}, {
name: "node removed",
haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
return true, nil
}),
nodeIDCtx: nodeIDCtx,
wantError: StatusNotHAMember,
}, {
name: "node removed err",
haBackend: newMockHARemovableNodeBackend(func(ctx context.Context, s string) (bool, error) {
return false, otherErr
}),
nodeIDCtx: nodeIDCtx,
wantError: otherErr,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := &Core{
logger: hclog.NewNullLogger(),
}
err := haMembershipServerCheck(tc.nodeIDCtx, c, tc.haBackend)
if tc.wantError != nil {
require.EqualError(t, err, tc.wantError.Error())
} else {
require.NoError(t, err)
}
})
}
}