We have many hand-written String() methods (and similar) for enums. These require more maintenance and are more error-prone than automatically generated methods, and the auto-generated versions can also be more efficient. Here, we switch to https://github.com/loggerhead/enumer, a fork of https://github.com/diegostamigni/enumer (no longer maintained), itself a fork of the mostly standard tool https://pkg.go.dev/golang.org/x/tools/cmd/stringer. We use this fork of enumer for Go 1.20+ compatibility and because we require the `-transform` flag to be able to generate constants that match our current code base. Some enums were not targeted for this change:
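
As a sketch of the mechanism (the directive below is the one used for GenerateRootKind in this file; flags such as `-transform` vary per enum), the hand-written String() method is deleted and replaced by a generate directive, after which running `go generate ./...` emits the method into a generated sibling file:

	//go:generate enumer -type=GenerateRootKind -trimprefix=GenerateRoot
	type GenerateRootKind int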
		
			
				
	
	
		
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package testcluster

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/helper/xor"
)

// Note that OSS standbys will not accept seal requests, and ent perf standbys
// may fail them as well if they haven't yet been able to get "elected" as perf
// standbys.
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	err := client.Sys().SealWithContext(ctx)
	if err != nil {
		return err
	}

	return NodeSealed(ctx, cluster, nodeIdx)
}

func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := SealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}

func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	for _, key := range cluster.GetBarrierOrRecoveryKeys() {
		_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
		if err != nil {
			return err
		}
	}

	return NodeHealthy(ctx, cluster, nodeIdx)
}

func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := UnsealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}
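
// Usage sketch (hypothetical test; assumes some concrete VaultCluster
// implementation, e.g. a docker-based cluster):
//
//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
//	defer cancel()
//	if err := SealAllNodes(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}
//	if err := UnsealAllNodes(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}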

func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case !health.Sealed:
			err = fmt.Errorf("unsealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
}

func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffer the channel: once n nodes have reported sealed we return (and
	// cancel the context), and the remaining pollers must be able to send
	// their final error without a receiver, or they would leak.
	errs := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(i int) {
			var err error
			for ctx.Err() == nil {
				err = NodeSealed(ctx, cluster, i)
				if err == nil {
					errs <- nil
					return
				}
				time.Sleep(100 * time.Millisecond)
			}
			if err == nil {
				err = ctx.Err()
			}
			errs <- err
		}(i)
	}

	var merr *multierror.Error
	var sealed int
	for range cluster.Nodes() {
		err := <-errs
		if err != nil {
			merr = multierror.Append(merr, err)
		} else {
			sealed++
			if sealed == n {
				return nil
			}
		}
	}

	return fmt.Errorf("%d cores were not sealed, errs: %v", n, merr.ErrorOrNil())
}
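
// Usage sketch (hypothetical): after sealing two of three nodes, wait until
// both report sealed without requiring the third to seal:
//
//	if err := WaitForNCoresSealed(ctx, cluster, 2); err != nil {
//		t.Fatal(err)
//	}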

func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case health == nil:
			err = fmt.Errorf("nil response to health check")
		case health.Sealed:
			err = fmt.Errorf("sealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
}

func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
	// Be robust to multiple nodes thinking they are active. This is possible in
	// certain network partition situations where the old leader has not
	// discovered it has lost leadership yet. In tests this is only likely to
	// come up when we are specifically provoking it, but it's possible it could
	// happen at any point if leadership flaps or connectivity suffers transient
	// errors, etc., so be robust against it. The best solution would be to have
	// some sort of epoch like the raft term that is guaranteed to be
	// monotonically increasing through elections, however we don't have that
	// abstraction for all HABackends in general. The best we have is the
	// ActiveTime. In a distributed systems textbook this would be bad to rely
	// on due to clock sync issues etc., but for our tests it's likely fine
	// because even if we are running separate Vault containers, they are all
	// using the same hardware clock in the system.
	leaderActiveTimes := make(map[int]time.Time)
	for i, node := range cluster.Nodes() {
		client := node.APIClient()
		ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
		resp, err := client.Sys().LeaderWithContext(ctx)
		cancel()
		if err != nil || resp == nil || !resp.IsSelf {
			continue
		}
		leaderActiveTimes[i] = resp.ActiveTime
	}
	if len(leaderActiveTimes) == 0 {
		return -1, fmt.Errorf("no leader found")
	}
	// At least one node thinks it is active. If multiple, pick the one with the
	// most recent ActiveTime. Note if there is only one then this just returns
	// it.
	var newestLeaderIdx int
	var newestActiveTime time.Time
	for i, at := range leaderActiveTimes {
		if at.After(newestActiveTime) {
			newestActiveTime = at
			newestLeaderIdx = i
		}
	}
	return newestLeaderIdx, nil
}

func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
	for ctx.Err() == nil {
		if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
			return idx, nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return -1, ctx.Err()
}
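
// These wait helpers poll until the passed context expires, so callers are
// expected to bound them with a deadline. A usage sketch (hypothetical):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//	leaderIdx, err := WaitForActiveNode(ctx, cluster)
//	if err != nil {
//		t.Fatalf("no active node: %v", err)
//	}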

func WaitForStandbyNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var err error
	for ctx.Err() == nil {
		var resp *api.LeaderResponse

		resp, err = client.Sys().LeaderWithContext(ctx)
		switch {
		case err != nil:
		case resp.IsSelf:
			return fmt.Errorf("waiting for standby but node is leader")
		case resp.LeaderAddress == "":
			err = fmt.Errorf("node doesn't know leader address")
		default:
			return nil
		}

		time.Sleep(100 * time.Millisecond)
	}
	if err == nil {
		err = ctx.Err()
	}
	return err
}

func WaitForActiveNodeAndStandbys(ctx context.Context, cluster VaultCluster) (int, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return 0, err
	}

	if len(cluster.Nodes()) == 1 {
		return 0, nil
	}

	errs := make(chan error)
	for i := range cluster.Nodes() {
		if i == leaderIdx {
			continue
		}
		go func(i int) {
			errs <- WaitForStandbyNode(ctx, cluster, i)
		}(i)
	}

	var merr *multierror.Error
	expectedStandbys := len(cluster.Nodes()) - 1
	for i := 0; i < expectedStandbys; i++ {
		merr = multierror.Append(merr, <-errs)
	}

	return leaderIdx, merr.ErrorOrNil()
}

func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
	logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
	// This WaitForActiveNode was added because after a Raft cluster is sealed
	// and then unsealed, when it comes up it may have a different leader than
	// Core0, making this helper fail.
	// A sleep before calling WaitForActiveNodeAndPerfStandbys seems to sort
	// things out, but so apparently does this. We should be able to eliminate
	// this call to WaitForActiveNode by reworking the logic in this method.
	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return fmt.Errorf("did not find leader: %w", err)
	}

	if len(cluster.Nodes()) == 1 {
		return nil
	}

	expectedStandbys := len(cluster.Nodes()) - 1

	mountPoint, err := uuid.GenerateUUID()
	if err != nil {
		return err
	}
	leaderClient := cluster.Nodes()[leaderIdx].APIClient()

	for ctx.Err() == nil {
		err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
			Type:  "kv",
			Local: true,
		})
		if err == nil {
			break
		}
		time.Sleep(1 * time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to mount KV engine: %w", err)
	}
	path := mountPoint + "/waitforactivenodeandperfstandbys"
	var standbys, actives int64
	errchan := make(chan error, len(cluster.Nodes()))
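	// Poll every node concurrently: the active node reports IsSelf, and a perf
	// standby is counted once its PerfStandbyLastRemoteWAL advances past the
	// first value observed, i.e. once it has replicated at least one of the KV
	// writes issued below.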
	for i := range cluster.Nodes() {
		go func(coreNo int) {
			node := cluster.Nodes()[coreNo]
			client := node.APIClient()
			val := 1
			var err error
			defer func() {
				errchan <- err
			}()

			var lastWAL uint64
			for ctx.Err() == nil {
				_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
					"bar": val,
				})
				val++
				time.Sleep(250 * time.Millisecond)
				if err != nil {
					continue
				}
				var leader *api.LeaderResponse
				leader, err = client.Sys().LeaderWithContext(ctx)
				if err != nil {
					logger.Trace("waiting for core", "core", coreNo, "err", err)
					continue
				}
				switch {
				case leader.IsSelf:
					logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
					atomic.AddInt64(&actives, 1)
					return
				case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
					switch {
					case lastWAL == 0:
						lastWAL = leader.PerfStandbyLastRemoteWAL
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
					case lastWAL < leader.PerfStandbyLastRemoteWAL:
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
						atomic.AddInt64(&standbys, 1)
						return
					}
				default:
					logger.Trace("waiting for core", "core", coreNo,
						"ha_enabled", leader.HAEnabled,
						"is_self", leader.IsSelf,
						"perf_standby", leader.PerfStandby,
						"perf_standby_remote_wal", leader.PerfStandbyLastRemoteWAL)
				}
			}
		}(i)
	}

	errs := make([]error, 0, len(cluster.Nodes()))
	for range cluster.Nodes() {
		errs = append(errs, <-errchan)
	}
	if actives != 1 || int(standbys) != expectedStandbys {
		return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
			expectedStandbys, actives, standbys, errs)
	}

	for ctx.Err() == nil {
		err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to unmount KV engine: %w", err)
	}
	return nil
}

func Clients(vc VaultCluster) []*api.Client {
	var ret []*api.Client
	for _, n := range vc.Nodes() {
		ret = append(ret, n.APIClient())
	}
	return ret
}

//go:generate enumer -type=GenerateRootKind -trimprefix=GenerateRoot
type GenerateRootKind int

const (
	GenerateRootRegular GenerateRootKind = iota
	GenerateRootDR
	GenerateRecovery
)
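
// Running `go generate` on the directive above emits the String() method and
// related helpers into a generated sibling file (a name like
// generaterootkind_enumer.go is conventional). A rough sketch of the behavior,
// not the actual generated code; note that -trimprefix only strips constants
// that actually carry the GenerateRoot prefix:
//
//	func (k GenerateRootKind) String() string {
//		switch k {
//		case GenerateRootRegular:
//			return "Regular"
//		case GenerateRootDR:
//			return "DR"
//		case GenerateRecovery:
//			return "GenerateRecovery"
//		default:
//			return fmt.Sprintf("GenerateRootKind(%d)", int(k))
//		}
//	}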

func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
	// If recovery keys are supported, use those to perform root token
	// generation instead.
	keys := cluster.GetBarrierOrRecoveryKeys()

	client := cluster.Nodes()[0].APIClient()

	var err error
	var status *api.GenerateRootStatusResponse
	switch kind {
	case GenerateRootRegular:
		status, err = client.Sys().GenerateRootInit("", "")
	case GenerateRootDR:
		status, err = client.Sys().GenerateDROperationTokenInit("", "")
	case GenerateRecovery:
		status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
	}
	if err != nil {
		return "", err
	}

	if status.Required > len(keys) {
		return "", fmt.Errorf("need more keys than we have: need %d, have %d", status.Required, len(keys))
	}

	otp := status.OTP

	for i, key := range keys {
		if i >= status.Required {
			break
		}

		strKey := base64.StdEncoding.EncodeToString(key)
		switch kind {
		case GenerateRootRegular:
			status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
		case GenerateRootDR:
			status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
		case GenerateRecovery:
			status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
		}
		if err != nil {
			return "", err
		}
	}
	if !status.Complete {
		return "", fmt.Errorf("generate root operation did not end successfully")
	}

	tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
	if err != nil {
		return "", err
	}
	tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
	if err != nil {
		return "", err
	}
	return string(tokenBytes), nil
}
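
// Usage sketch (hypothetical): mint a fresh root token and point a client at
// it.
//
//	token, err := GenerateRoot(cluster, GenerateRootRegular)
//	if err != nil {
//		t.Fatal(err)
//	}
//	cluster.Nodes()[0].APIClient().SetToken(token)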