mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	VAULT-19255 - Add event based static secret cache updater to Vault Proxy (#23560)
* VAULT-19255 first pass at structure for event updater * VAULT-19255 some more work, committign before rebase * VAULT-19255 Mostly finish event updating scaffolding * VAULT-19255 some additional coverage, clean-up, etc * VAULT-19255 some clean-up * VAULT-19255 fix tests * VAULT-19255 more WIP event system integration * VAULT-19255 More WIP * VAULT-19255 more discovery * VAULT-19255 add new test, some clean up * VAULT-19255 fix bug, extra clean-up * VAULT-19255 fix bugs, and clean up * VAULT-19255 clean imports, add more godocs * VAULT-19255 add config for test * VAULT-19255 typo * VAULT-19255 don't do the kv refactor in this PR * VAULT-19255 update docs * VAULT-19255 PR feedback * VAULT-19255 More specific error messages
This commit is contained in:
		| @@ -274,7 +274,7 @@ func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ( | ||||
| // Evict removes an index from the cache based on index name and value. | ||||
| func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error { | ||||
| 	index, err := c.Get(indexName, indexValues...) | ||||
| 	if err == ErrCacheItemNotFound { | ||||
| 	if errors.Is(err, ErrCacheItemNotFound) { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if err != nil { | ||||
|   | ||||
							
								
								
									
										68
									
								
								command/agentproxyshared/cache/lease_cache.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										68
									
								
								command/agentproxyshared/cache/lease_cache.go
									
									
									
									
										vendored
									
									
								
							| @@ -205,25 +205,8 @@ func (c *LeaseCache) checkCacheForStaticSecretRequest(id string, req *SendReques | ||||
| // If a token is provided, it will validate that the token is allowed to retrieve this | ||||
| // cache entry, and return nil if it isn't. | ||||
| func (c *LeaseCache) checkCacheForRequest(id string, req *SendRequest) (*SendResponse, error) { | ||||
| 	var token string | ||||
| 	if req != nil { | ||||
| 		token = req.Token | ||||
| 		// HEAD and OPTIONS are included as future-proofing, since neither of those modify the resource either. | ||||
| 		if req.Request.Method != http.MethodGet && req.Request.Method != http.MethodHead && req.Request.Method != http.MethodOptions { | ||||
| 			// This must be an update to the resource, so we should short-circuit and invalidate the cache | ||||
| 			// as we know the cache is now stale. | ||||
| 			c.logger.Debug("evicting index from cache, as non-GET received", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path) | ||||
| 			err := c.db.Evict(cachememdb.IndexNameID, id) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			return nil, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	index, err := c.db.Get(cachememdb.IndexNameID, id) | ||||
| 	if err == cachememdb.ErrCacheItemNotFound { | ||||
| 	if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	if err != nil { | ||||
| @@ -233,8 +216,17 @@ func (c *LeaseCache) checkCacheForRequest(id string, req *SendRequest) (*SendRes | ||||
| 	index.IndexLock.RLock() | ||||
| 	defer index.IndexLock.RUnlock() | ||||
|  | ||||
| 	var token string | ||||
| 	if req != nil { | ||||
| 		// Req will be non-nil if we're checking for a static secret. | ||||
| 		// Token might still be "" if it's going to an unauthenticated | ||||
| 		// endpoint, or similar. For static secrets, we only care about | ||||
| 		// requests with tokens attached, as KV is authenticated. | ||||
| 		token = req.Token | ||||
| 	} | ||||
|  | ||||
| 	if token != "" { | ||||
| 		// This is a static secret check. We need to ensure that this token | ||||
| 		// We are checking for a static secret. We need to ensure that this token | ||||
| 		// has previously demonstrated access to this static secret. | ||||
| 		// We could check the capabilities cache here, but since these | ||||
| 		// indexes should be in sync, this saves us an extra cache get. | ||||
| @@ -381,7 +373,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, | ||||
| 	} | ||||
|  | ||||
| 	// Check if the response for this request is already in the static secret cache | ||||
| 	if staticSecretCacheId != "" { | ||||
| 	if staticSecretCacheId != "" && req.Request.Method == http.MethodGet { | ||||
| 		cachedResp, err = c.checkCacheForStaticSecretRequest(staticSecretCacheId, req) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| @@ -446,7 +438,9 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, | ||||
|  | ||||
| 	// There shouldn't be a situation where secret.MountType == "kv" and | ||||
| 	// staticSecretCacheId == "", but just in case. | ||||
| 	if c.cacheStaticSecrets && secret.MountType == "kv" && staticSecretCacheId != "" { | ||||
| 	// We restrict this to GETs as those are all we want to cache. | ||||
| 	if c.cacheStaticSecrets && secret.MountType == "kv" && | ||||
| 		staticSecretCacheId != "" && req.Request.Method == http.MethodGet { | ||||
| 		index.Type = cacheboltdb.StaticSecretType | ||||
| 		index.ID = staticSecretCacheId | ||||
| 		err := c.cacheStaticSecret(ctx, req, resp, index) | ||||
| @@ -475,7 +469,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, | ||||
| 	case secret.LeaseID != "": | ||||
| 		c.logger.Debug("processing lease response", "method", req.Request.Method, "path", req.Request.URL.Path) | ||||
| 		entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) | ||||
| 		if err == cachememdb.ErrCacheItemNotFound { | ||||
| 		if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 			// If the lease belongs to a token that is not managed by the lease cache, | ||||
| 			// return the response without caching it. | ||||
| 			c.logger.Debug("pass-through lease response; token not managed by lease cache", "method", req.Request.Method, "path", req.Request.URL.Path) | ||||
| @@ -501,7 +495,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, | ||||
| 		var parentCtx context.Context | ||||
| 		if !secret.Auth.Orphan { | ||||
| 			entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) | ||||
| 			if err == cachememdb.ErrCacheItemNotFound { | ||||
| 			if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 				// If the lease belongs to a token that is not managed by the lease cache, | ||||
| 				// return the response without caching it. | ||||
| 				c.logger.Debug("pass-through lease response; parent token not managed by lease cache", "method", req.Request.Method, "path", req.Request.URL.Path) | ||||
| @@ -564,7 +558,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, | ||||
|  | ||||
| 	if index.Type != cacheboltdb.StaticSecretType { | ||||
| 		// Store the index in the cache | ||||
| 		c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path) | ||||
| 		c.logger.Debug("storing dynamic secret response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path, "id", index.ID) | ||||
| 		err = c.Set(ctx, index) | ||||
| 		if err != nil { | ||||
| 			c.logger.Error("failed to cache the proxied response", "error", err) | ||||
| @@ -587,7 +581,7 @@ func (c *LeaseCache) cacheStaticSecret(ctx context.Context, req *SendRequest, re | ||||
| 	} | ||||
|  | ||||
| 	// The index already exists, so all we need to do is add our token | ||||
| 	// to the index's allowed token list, then re-store it | ||||
| 	// to the index's allowed token list, then re-store it. | ||||
| 	if indexFromCache != nil { | ||||
| 		// We must hold a lock for the index while it's being updated. | ||||
| 		// We keep the two locking mechanisms distinct, so that it's only writes | ||||
| @@ -627,7 +621,7 @@ func (c *LeaseCache) cacheStaticSecret(ctx context.Context, req *SendRequest, re | ||||
|  | ||||
| func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendRequest, index *cachememdb.Index) error { | ||||
| 	// Store the index in the cache | ||||
| 	c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path) | ||||
| 	c.logger.Debug("storing static secret response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path, "id", index.ID) | ||||
| 	err := c.Set(ctx, index) | ||||
| 	if err != nil { | ||||
| 		c.logger.Error("failed to cache the proxied response", "error", err) | ||||
| @@ -663,7 +657,7 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques | ||||
| // capabilities entry from the cache, or create a new, empty one. | ||||
| func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, error) { | ||||
| 	// The index ID is a hash of the token. | ||||
| 	indexId := hex.EncodeToString(cryptoutil.Blake2b256Hash(token)) | ||||
| 	indexId := hashStaticSecretIndex(token) | ||||
| 	indexFromCache, err := c.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) | ||||
| 	if err != nil && err != cachememdb.ErrCacheItemNotFound { | ||||
| 		return nil, err | ||||
| @@ -860,6 +854,12 @@ func getStaticSecretPathFromRequest(req *SendRequest) string { | ||||
| 	return canonicalizeStaticSecretPath(path, namespace) | ||||
| } | ||||
|  | ||||
| // hashStaticSecretIndex is a simple function that hashes the path into | ||||
| // a function. This is kept as a helper function for ease of use by downstream functions. | ||||
| func hashStaticSecretIndex(unhashedIndex string) string { | ||||
| 	return hex.EncodeToString(cryptoutil.Blake2b256Hash(unhashedIndex)) | ||||
| } | ||||
|  | ||||
| // computeStaticSecretCacheIndex results in a value that uniquely identifies a static | ||||
| // secret's cached ID. Notably, we intentionally ignore headers (for example, | ||||
| // the X-Vault-Token header) to remain agnostic to which token is being | ||||
| @@ -871,7 +871,7 @@ func computeStaticSecretCacheIndex(req *SendRequest) string { | ||||
| 	if path == "" { | ||||
| 		return path | ||||
| 	} | ||||
| 	return hex.EncodeToString(cryptoutil.Blake2b256Hash(path)) | ||||
| 	return hashStaticSecretIndex(path) | ||||
| } | ||||
|  | ||||
| // HandleCacheClear returns a handlerFunc that can perform cache clearing operations. | ||||
| @@ -973,7 +973,7 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) | ||||
|  | ||||
| 		// Get the context for the given token and cancel its context | ||||
| 		index, err := c.db.Get(cachememdb.IndexNameToken, in.Token) | ||||
| 		if err == cachememdb.ErrCacheItemNotFound { | ||||
| 		if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 			return nil | ||||
| 		} | ||||
| 		if err != nil { | ||||
| @@ -992,7 +992,7 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) | ||||
| 		// Get the cached index and cancel the corresponding lifetime watcher | ||||
| 		// context | ||||
| 		index, err := c.db.Get(cachememdb.IndexNameTokenAccessor, in.TokenAccessor) | ||||
| 		if err == cachememdb.ErrCacheItemNotFound { | ||||
| 		if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 			return nil | ||||
| 		} | ||||
| 		if err != nil { | ||||
| @@ -1011,7 +1011,7 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) | ||||
| 		// Get the cached index and cancel the corresponding lifetime watcher | ||||
| 		// context | ||||
| 		index, err := c.db.Get(cachememdb.IndexNameLease, in.Lease) | ||||
| 		if err == cachememdb.ErrCacheItemNotFound { | ||||
| 		if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 			return nil | ||||
| 		} | ||||
| 		if err != nil { | ||||
| @@ -1147,7 +1147,7 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque | ||||
|  | ||||
| 		// Kill the lifetime watchers of the revoked token | ||||
| 		index, err := c.db.Get(cachememdb.IndexNameToken, token) | ||||
| 		if err == cachememdb.ErrCacheItemNotFound { | ||||
| 		if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 			return true, nil | ||||
| 		} | ||||
| 		if err != nil { | ||||
| @@ -1395,7 +1395,7 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { | ||||
| 	switch { | ||||
| 	case secret.LeaseID != "": | ||||
| 		entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) | ||||
| 		if err == cachememdb.ErrCacheItemNotFound { | ||||
| 		if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 			return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| @@ -1409,7 +1409,7 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { | ||||
| 		var parentCtx context.Context | ||||
| 		if !secret.Auth.Orphan { | ||||
| 			entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) | ||||
| 			if err == cachememdb.ErrCacheItemNotFound { | ||||
| 			if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 				// If parent token is not managed by the cache, child shouldn't be | ||||
| 				// either. | ||||
| 				if entry == nil { | ||||
|   | ||||
							
								
								
									
										385
									
								
								command/agentproxyshared/cache/static_secret_cache_updater.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										385
									
								
								command/agentproxyshared/cache/static_secret_cache_updater.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,385 @@ | ||||
| // Copyright (c) HashiCorp, Inc. | ||||
| // SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| package cache | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/go-hclog" | ||||
| 	"github.com/hashicorp/vault/api" | ||||
| 	"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb" | ||||
| 	"github.com/hashicorp/vault/command/agentproxyshared/sink" | ||||
| 	"github.com/hashicorp/vault/helper/useragent" | ||||
| 	"golang.org/x/exp/maps" | ||||
| 	"nhooyr.io/websocket" | ||||
| ) | ||||
|  | ||||
| // Example Event: | ||||
| //{ | ||||
| //  "id": "a3be9fb1-b514-519f-5b25-b6f144a8c1ce", | ||||
| //  "source": "https://vaultproject.io/", | ||||
| //  "specversion": "1.0", | ||||
| //  "type": "*", | ||||
| //  "data": { | ||||
| //    "event": { | ||||
| //      "id": "a3be9fb1-b514-519f-5b25-b6f144a8c1ce", | ||||
| //      "metadata": { | ||||
| //        "current_version": "1", | ||||
| //        "data_path": "secret/data/foo", | ||||
| //        "modified": "true", | ||||
| //        "oldest_version": "0", | ||||
| //        "operation": "data-write", | ||||
| //        "path": "secret/data/foo" | ||||
| //      } | ||||
| //    }, | ||||
| //    "event_type": "kv-v2/data-write", | ||||
| //    "plugin_info": { | ||||
| //      "mount_class": "secret", | ||||
| //      "mount_accessor": "kv_5dc4d18e", | ||||
| //      "mount_path": "secret/", | ||||
| //      "plugin": "kv" | ||||
| //    } | ||||
| //  }, | ||||
| //  "datacontentype": "application/cloudevents", | ||||
| //  "time": "2023-09-12T15:19:49.394915-07:00" | ||||
| //} | ||||
|  | ||||
| // StaticSecretCacheUpdater is a struct that utilizes | ||||
| // the event system to keep the static secret cache up to date. | ||||
| type StaticSecretCacheUpdater struct { | ||||
| 	client     *api.Client | ||||
| 	leaseCache *LeaseCache | ||||
| 	logger     hclog.Logger | ||||
| 	tokenSink  sink.Sink | ||||
| } | ||||
|  | ||||
| // StaticSecretCacheUpdaterConfig is the configuration for initializing a new | ||||
| // StaticSecretCacheUpdater. | ||||
| type StaticSecretCacheUpdaterConfig struct { | ||||
| 	Client     *api.Client | ||||
| 	LeaseCache *LeaseCache | ||||
| 	Logger     hclog.Logger | ||||
| 	// TokenSink is a token sync that will have the latest | ||||
| 	// token from auto-auth in it, to be used in event system | ||||
| 	// connections. | ||||
| 	TokenSink sink.Sink | ||||
| } | ||||
|  | ||||
| // NewStaticSecretCacheUpdater creates a new instance of a StaticSecretCacheUpdater. | ||||
| func NewStaticSecretCacheUpdater(conf *StaticSecretCacheUpdaterConfig) (*StaticSecretCacheUpdater, error) { | ||||
| 	if conf == nil { | ||||
| 		return nil, errors.New("nil configuration provided") | ||||
| 	} | ||||
|  | ||||
| 	if conf.LeaseCache == nil { | ||||
| 		return nil, fmt.Errorf("nil Lease Cache (a required parameter): %v", conf) | ||||
| 	} | ||||
|  | ||||
| 	if conf.Logger == nil { | ||||
| 		return nil, fmt.Errorf("nil Logger (a required parameter): %v", conf) | ||||
| 	} | ||||
|  | ||||
| 	if conf.Client == nil { | ||||
| 		return nil, fmt.Errorf("nil API client (a required parameter): %v", conf) | ||||
| 	} | ||||
|  | ||||
| 	if conf.TokenSink == nil { | ||||
| 		return nil, fmt.Errorf("nil token sink (a required parameter): %v", conf) | ||||
| 	} | ||||
|  | ||||
| 	return &StaticSecretCacheUpdater{ | ||||
| 		client:     conf.Client, | ||||
| 		leaseCache: conf.LeaseCache, | ||||
| 		logger:     conf.Logger, | ||||
| 		tokenSink:  conf.TokenSink, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // streamStaticSecretEvents streams static secret events and updates | ||||
| // the cache when updates are notified. This method will return errors in cases | ||||
| // of failed updates, malformed events, and other. | ||||
| // For best results, the caller of this function should retry on error with backoff, | ||||
| // if it is desired for the cache to always remain up to date. | ||||
| func (updater *StaticSecretCacheUpdater) streamStaticSecretEvents(ctx context.Context) error { | ||||
| 	// First, ensure our token is up-to-date: | ||||
| 	updater.client.SetToken(updater.tokenSink.(sink.SinkReader).Token()) | ||||
| 	conn, err := updater.openWebSocketConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error when opening event stream: %w", err) | ||||
| 	} | ||||
| 	defer conn.Close(websocket.StatusNormalClosure, "") | ||||
|  | ||||
| 	// before we check for events, update all of our cached | ||||
| 	// kv secrets, in case we missed any events | ||||
| 	// TODO: to be implemented in a future PR | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return nil | ||||
| 		default: | ||||
| 			_, message, err := conn.Read(ctx) | ||||
| 			if err != nil { | ||||
| 				// The caller of this function should make the decision on if to retry. If it does, then | ||||
| 				// the websocket connection will be retried, and we will check for missed events. | ||||
| 				return fmt.Errorf("error when attempting to read from event stream, reopening websocket: %w", err) | ||||
| 			} | ||||
| 			updater.logger.Trace("received event", "message", string(message)) | ||||
| 			messageMap := make(map[string]interface{}) | ||||
| 			err = json.Unmarshal(message, &messageMap) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("error when unmarshaling event, message: %s\nerror: %w", string(message), err) | ||||
| 			} | ||||
| 			data, ok := messageMap["data"].(map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				return fmt.Errorf("unexpected event format when decoding 'data' element, message: %s\nerror: %w", string(message), err) | ||||
| 			} | ||||
| 			event, ok := data["event"].(map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				return fmt.Errorf("unexpected event format when decoding 'event' element, message: %s\nerror: %w", string(message), err) | ||||
| 			} | ||||
| 			metadata, ok := event["metadata"].(map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				return fmt.Errorf("unexpected event format when decoding 'metadata' element, message: %s\nerror: %w", string(message), err) | ||||
| 			} | ||||
| 			modified, ok := metadata["modified"].(string) | ||||
| 			if ok && modified == "true" { | ||||
| 				path, ok := metadata["path"].(string) | ||||
| 				if !ok { | ||||
| 					return fmt.Errorf("unexpected event format when decoding 'path' element, message: %s\nerror: %w", string(message), err) | ||||
| 				} | ||||
| 				err := updater.updateStaticSecret(ctx, path) | ||||
| 				if err != nil { | ||||
| 					// While we are kind of 'missing' an event this way, re-calling this function will | ||||
| 					// result in the secret remaining up to date. | ||||
| 					return fmt.Errorf("error updating static secret: path: %q, message: %s error: %w", path, message, err) | ||||
| 				} | ||||
| 			} else { | ||||
| 				// This is an event we're not interested in, ignore it and | ||||
| 				// carry on. | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // updateStaticSecret checks for updates for a static secret on the path given, | ||||
| // and updates the cache if appropriate | ||||
| func (updater *StaticSecretCacheUpdater) updateStaticSecret(ctx context.Context, path string) error { | ||||
| 	// We clone the client, as we won't be using the same token. | ||||
| 	client, err := updater.client.Clone() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	indexId := hashStaticSecretIndex(path) | ||||
|  | ||||
| 	updater.logger.Debug("received update static secret request", "path", path, "indexId", indexId) | ||||
|  | ||||
| 	index, err := updater.leaseCache.db.Get(cachememdb.IndexNameID, indexId) | ||||
| 	if errors.Is(err, cachememdb.ErrCacheItemNotFound) { | ||||
| 		// This event doesn't correspond to a secret in our cache | ||||
| 		// so this is a no-op. | ||||
| 		return nil | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// We use a raw request so that we can store all the | ||||
| 	// request information, just like we do in the Proxier Send methods. | ||||
| 	request := client.NewRequest(http.MethodGet, "/v1/"+path) | ||||
| 	if request.Headers == nil { | ||||
| 		request.Headers = make(http.Header) | ||||
| 	} | ||||
| 	request.Headers.Set("User-Agent", useragent.ProxyString()) | ||||
|  | ||||
| 	var resp *api.Response | ||||
| 	var tokensToRemove []string | ||||
| 	var successfulAttempt bool | ||||
| 	for _, token := range maps.Keys(index.Tokens) { | ||||
| 		client.SetToken(token) | ||||
| 		request.Headers.Set(api.AuthHeaderName, token) | ||||
| 		resp, err = client.RawRequestWithContext(ctx, request) | ||||
| 		if err != nil { | ||||
| 			updater.logger.Trace("received error when trying to update cache", "path", path, "err", err, "token", token) | ||||
| 			// We cannot access this secret with this token for whatever reason, | ||||
| 			// so token for removal. | ||||
| 			tokensToRemove = append(tokensToRemove, token) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			// We got our updated secret! | ||||
| 			successfulAttempt = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if successfulAttempt { | ||||
| 		// We need to update the index, so first, hold the lock. | ||||
| 		index.IndexLock.Lock() | ||||
| 		defer index.IndexLock.Unlock() | ||||
|  | ||||
| 		// First, remove the tokens we noted couldn't access the secret from the token index | ||||
| 		for _, token := range tokensToRemove { | ||||
| 			delete(index.Tokens, token) | ||||
| 		} | ||||
|  | ||||
| 		sendResponse, err := NewSendResponse(resp, nil) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Serialize the response to store it in the cached index | ||||
| 		var respBytes bytes.Buffer | ||||
| 		err = sendResponse.Response.Write(&respBytes) | ||||
| 		if err != nil { | ||||
| 			updater.logger.Error("failed to serialize response", "error", err) | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Set the index's Response | ||||
| 		index.Response = respBytes.Bytes() | ||||
| 		index.LastRenewed = time.Now().UTC() | ||||
|  | ||||
| 		// Lastly, store the secret | ||||
| 		updater.logger.Debug("storing response into the cache due to event update", "path", path) | ||||
| 		err = updater.leaseCache.db.Set(index) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		// No token could successfully update the secret, or secret was deleted. | ||||
| 		// We should evict the cache instead of re-storing the secret. | ||||
| 		updater.logger.Debug("evicting response from cache", "path", path) | ||||
| 		err = updater.leaseCache.db.Evict(cachememdb.IndexNameID, indexId) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // openWebSocketConnection opens a websocket connection to the event system for | ||||
| // the events that the static secret cache updater is interested in. | ||||
| func (updater *StaticSecretCacheUpdater) openWebSocketConnection(ctx context.Context) (*websocket.Conn, error) { | ||||
| 	// We parse this into a URL object to get the specific host and scheme | ||||
| 	// information without nasty string parsing. | ||||
| 	vaultURL, err := url.Parse(updater.client.Address()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	vaultHost := vaultURL.Host | ||||
| 	// If we're using https, use wss, otherwise ws | ||||
| 	scheme := "wss" | ||||
| 	if vaultURL.Scheme == "http" { | ||||
| 		scheme = "ws" | ||||
| 	} | ||||
|  | ||||
| 	webSocketURL := url.URL{ | ||||
| 		Path:   "/v1/sys/events/subscribe/kv*", | ||||
| 		Host:   vaultHost, | ||||
| 		Scheme: scheme, | ||||
| 	} | ||||
| 	query := webSocketURL.Query() | ||||
| 	query.Set("json", "true") | ||||
| 	webSocketURL.RawQuery = query.Encode() | ||||
|  | ||||
| 	updater.client.AddHeader(api.AuthHeaderName, updater.client.Token()) | ||||
| 	updater.client.AddHeader(api.NamespaceHeaderName, updater.client.Namespace()) | ||||
|  | ||||
| 	// Populate these now to avoid recreating them in the upcoming for loop. | ||||
| 	headers := updater.client.Headers() | ||||
| 	wsURL := webSocketURL.String() | ||||
| 	httpClient := updater.client.CloneConfig().HttpClient | ||||
|  | ||||
| 	// We do ten attempts, to ensure we follow forwarding to the leader. | ||||
| 	var conn *websocket.Conn | ||||
| 	for attempt := 0; attempt < 10; attempt++ { | ||||
| 		var resp *http.Response | ||||
| 		conn, resp, err = websocket.Dial(ctx, wsURL, &websocket.DialOptions{ | ||||
| 			HTTPClient: httpClient, | ||||
| 			HTTPHeader: headers, | ||||
| 		}) | ||||
| 		if err == nil { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		switch { | ||||
| 		case resp == nil: | ||||
| 			break | ||||
| 		case resp.StatusCode == http.StatusTemporaryRedirect: | ||||
| 			wsURL = resp.Header.Get("Location") | ||||
| 			continue | ||||
| 		default: | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error returned when opening event stream web socket to %s, ensure auto-auth token"+ | ||||
| 			" has correct permissions and Vault is version 1.16 or above: %w", wsURL, err) | ||||
| 	} | ||||
|  | ||||
| 	if conn == nil { | ||||
| 		return nil, errors.New(fmt.Sprintf("too many redirects as part of establishing web socket connection to %s", wsURL)) | ||||
| 	} | ||||
|  | ||||
| 	return conn, nil | ||||
| } | ||||
|  | ||||
| // Run is intended to be the method called by Vault Proxy, that runs the subsystem. | ||||
| // Once a token is provided to the sink, we will start the websocket and start consuming | ||||
| // events and updating secrets. | ||||
| // Run will shut down gracefully when the context is cancelled. | ||||
| func (updater *StaticSecretCacheUpdater) Run(ctx context.Context) error { | ||||
| 	updater.logger.Info("starting static secret cache updater subsystem") | ||||
| 	defer func() { | ||||
| 		updater.logger.Info("static secret cache updater subsystem stopped") | ||||
| 	}() | ||||
|  | ||||
| tokenLoop: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return nil | ||||
| 		default: | ||||
| 			// Wait for the auto-auth token to be populated... | ||||
| 			if updater.tokenSink.(sink.SinkReader).Token() != "" { | ||||
| 				break tokenLoop | ||||
| 			} | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	shouldBackoff := false | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return nil | ||||
| 		default: | ||||
| 			// If we're erroring and the context isn't done, we should add | ||||
| 			// a little backoff to make sure we don't accidentally overload | ||||
| 			// Vault or similar. | ||||
| 			if shouldBackoff { | ||||
| 				time.Sleep(10 * time.Second) | ||||
| 			} | ||||
| 			err := updater.streamStaticSecretEvents(ctx) | ||||
| 			if err != nil { | ||||
| 				updater.logger.Warn("error occurred during streaming static secret cache update events:", err) | ||||
| 				shouldBackoff = true | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										581
									
								
								command/agentproxyshared/cache/static_secret_cache_updater_test.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										581
									
								
								command/agentproxyshared/cache/static_secret_cache_updater_test.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,581 @@ | ||||
| // Copyright (c) HashiCorp, Inc. | ||||
| // SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| package cache | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/helper/testhelpers/minimal" | ||||
|  | ||||
| 	"github.com/hashicorp/go-hclog" | ||||
| 	kv "github.com/hashicorp/vault-plugin-secrets-kv" | ||||
| 	"github.com/hashicorp/vault/api" | ||||
| 	"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb" | ||||
| 	"github.com/hashicorp/vault/command/agentproxyshared/sink" | ||||
| 	vaulthttp "github.com/hashicorp/vault/http" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/logging" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| 	"github.com/hashicorp/vault/vault" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"go.uber.org/atomic" | ||||
| 	"nhooyr.io/websocket" | ||||
| ) | ||||
|  | ||||
| // Avoiding a circular dependency in the test. | ||||
| type mockSink struct { | ||||
| 	token *atomic.String | ||||
| } | ||||
|  | ||||
| func (m *mockSink) Token() string { | ||||
| 	return m.token.Load() | ||||
| } | ||||
|  | ||||
| func (m *mockSink) WriteToken(token string) error { | ||||
| 	m.token.Store(token) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func newMockSink(t *testing.T) sink.Sink { | ||||
| 	t.Helper() | ||||
|  | ||||
| 	return &mockSink{ | ||||
| 		token: atomic.NewString(""), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // testNewStaticSecretCacheUpdater returns a new StaticSecretCacheUpdater | ||||
| // for use in tests. | ||||
| func testNewStaticSecretCacheUpdater(t *testing.T, client *api.Client) *StaticSecretCacheUpdater { | ||||
| 	t.Helper() | ||||
|  | ||||
| 	lc := testNewLeaseCache(t, []*SendResponse{}) | ||||
| 	tokenSink := newMockSink(t) | ||||
| 	tokenSink.WriteToken(client.Token()) | ||||
|  | ||||
| 	updater, err := NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ | ||||
| 		Client:     client, | ||||
| 		LeaseCache: lc, | ||||
| 		Logger:     logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), | ||||
| 		TokenSink:  tokenSink, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	return updater | ||||
| } | ||||
|  | ||||
| // TestNewStaticSecretCacheUpdater tests the NewStaticSecretCacheUpdater method, | ||||
| // to ensure it errors out when appropriate. | ||||
| func TestNewStaticSecretCacheUpdater(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	lc := testNewLeaseCache(t, []*SendResponse{}) | ||||
| 	config := api.DefaultConfig() | ||||
| 	logger := logging.NewVaultLogger(hclog.Trace).Named("cache.updater") | ||||
| 	client, err := api.NewClient(config) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	tokenSink := newMockSink(t) | ||||
|  | ||||
| 	// Expect an error if any of the arguments are nil: | ||||
| 	updater, err := NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ | ||||
| 		Client:     nil, | ||||
| 		LeaseCache: lc, | ||||
| 		Logger:     logger, | ||||
| 		TokenSink:  tokenSink, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Nil(t, updater) | ||||
|  | ||||
| 	updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ | ||||
| 		Client:     client, | ||||
| 		LeaseCache: nil, | ||||
| 		Logger:     logger, | ||||
| 		TokenSink:  tokenSink, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Nil(t, updater) | ||||
|  | ||||
| 	updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ | ||||
| 		Client:     client, | ||||
| 		LeaseCache: lc, | ||||
| 		Logger:     nil, | ||||
| 		TokenSink:  tokenSink, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Nil(t, updater) | ||||
|  | ||||
| 	updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ | ||||
| 		Client:     client, | ||||
| 		LeaseCache: lc, | ||||
| 		Logger:     logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), | ||||
| 		TokenSink:  nil, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Nil(t, updater) | ||||
|  | ||||
| 	// Don't expect an error if the arguments are as expected | ||||
| 	updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ | ||||
| 		Client:     client, | ||||
| 		LeaseCache: lc, | ||||
| 		Logger:     logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), | ||||
| 		TokenSink:  tokenSink, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, updater) | ||||
| } | ||||
|  | ||||
| // TestOpenWebSocketConnection tests that the openWebSocketConnection function | ||||
| // works as expected. This uses a TLS enabled (wss) WebSocket connection. | ||||
| func TestOpenWebSocketConnection(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := minimal.NewTestSoloCluster(t, nil) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
| 	updater.tokenSink.WriteToken(client.Token()) | ||||
|  | ||||
| 	conn, err := updater.openWebSocketConnection(context.Background()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, conn) | ||||
| } | ||||
|  | ||||
| // TestOpenWebSocketConnectionReceivesEventsDefaultMount tests that the openWebSocketConnection function | ||||
| // works as expected with the default KVV1 mount, and then the connection can be used to receive an event. | ||||
| // This acts as more of an event system sanity check than a test of the updater | ||||
| // logic. It's still important coverage, though. | ||||
| // As of right now, it does not pass since the default kv mount is LeasedPassthroughBackend. | ||||
| // If that is changed, this test will be unskipped. | ||||
| func TestOpenWebSocketConnectionReceivesEventsDefaultMount(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	t.Skip("This test won't finish, as the default KV mount is LeasedPassthroughBackend in tests, and therefore does not send events") | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
|  | ||||
| 	conn, err := updater.openWebSocketConnection(context.Background()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, conn) | ||||
|  | ||||
| 	t.Cleanup(func() { | ||||
| 		conn.Close(websocket.StatusNormalClosure, "") | ||||
| 	}) | ||||
|  | ||||
| 	makeData := func(i int) map[string]interface{} { | ||||
| 		return map[string]interface{}{ | ||||
| 			"foo": fmt.Sprintf("bar%d", i), | ||||
| 		} | ||||
| 	} | ||||
| 	// Put a secret, which should trigger an event | ||||
| 	err = client.KVv1("secret").Put(context.Background(), "foo", makeData(100)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < 5; i++ { | ||||
| 		// Do a fresh PUT just to refresh the secret and send a new message | ||||
| 		err = client.KVv1("secret").Put(context.Background(), "foo", makeData(i)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// This method blocks until it gets a secret, so this test | ||||
| 		// will only pass if we're receiving events correctly. | ||||
| 		_, message, err := conn.Read(context.Background()) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		t.Log(string(message)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestOpenWebSocketConnectionReceivesEventsKVV1 tests that the openWebSocketConnection function | ||||
| // works as expected with KVV1, and then the connection can be used to receive an event. | ||||
| // This acts as more of an event system sanity check than a test of the updater | ||||
| // logic. It's still important coverage, though. | ||||
| func TestOpenWebSocketConnectionReceivesEventsKVV1(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||
| 		LogicalBackends: map[string]logical.Factory{ | ||||
| 			"kv": kv.Factory, | ||||
| 		}, | ||||
| 	}, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
|  | ||||
| 	conn, err := updater.openWebSocketConnection(context.Background()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, conn) | ||||
|  | ||||
| 	t.Cleanup(func() { | ||||
| 		conn.Close(websocket.StatusNormalClosure, "") | ||||
| 	}) | ||||
|  | ||||
| 	err = client.Sys().Mount("secret-v1", &api.MountInput{ | ||||
| 		Type: "kv", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	makeData := func(i int) map[string]interface{} { | ||||
| 		return map[string]interface{}{ | ||||
| 			"foo": fmt.Sprintf("bar%d", i), | ||||
| 		} | ||||
| 	} | ||||
| 	// Put a secret, which should trigger an event | ||||
| 	err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(100)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < 5; i++ { | ||||
| 		// Do a fresh PUT just to refresh the secret and send a new message | ||||
| 		err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(i)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// This method blocks until it gets a secret, so this test | ||||
| 		// will only pass if we're receiving events correctly. | ||||
| 		_, _, err := conn.Read(context.Background()) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestOpenWebSocketConnectionReceivesEvents tests that the openWebSocketConnection function | ||||
| // works as expected with KVV2, and then the connection can be used to receive an event. | ||||
| // This acts as more of an event system sanity check than a test of the updater | ||||
| // logic. It's still important coverage, though. | ||||
| func TestOpenWebSocketConnectionReceivesEventsKVV2(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||
| 		LogicalBackends: map[string]logical.Factory{ | ||||
| 			"kv": kv.VersionedKVFactory, | ||||
| 		}, | ||||
| 	}, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
|  | ||||
| 	conn, err := updater.openWebSocketConnection(context.Background()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, conn) | ||||
|  | ||||
| 	t.Cleanup(func() { | ||||
| 		conn.Close(websocket.StatusNormalClosure, "") | ||||
| 	}) | ||||
|  | ||||
| 	makeData := func(i int) map[string]interface{} { | ||||
| 		return map[string]interface{}{ | ||||
| 			"foo": fmt.Sprintf("bar%d", i), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = client.Sys().Mount("secret-v2", &api.MountInput{ | ||||
| 		Type: "kv-v2", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Put a secret, which should trigger an event | ||||
| 	_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(100)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < 5; i++ { | ||||
| 		// Do a fresh PUT just to refresh the secret and send a new message | ||||
| 		_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(i)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// This method blocks until it gets a secret, so this test | ||||
| 		// will only pass if we're receiving events correctly. | ||||
| 		_, _, err := conn.Read(context.Background()) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestOpenWebSocketConnectionTestServer tests that the openWebSocketConnection function | ||||
| // works as expected using vaulthttp.TestServer. This server isn't TLS enabled, so tests | ||||
| // the ws path (as opposed to the wss) path. | ||||
| func TestOpenWebSocketConnectionTestServer(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	core := vault.TestCoreWithConfig(t, &vault.CoreConfig{}) | ||||
| 	ln, addr := vaulthttp.TestServer(t, core) | ||||
| 	defer ln.Close() | ||||
|  | ||||
| 	keys, rootToken := vault.TestCoreInit(t, core) | ||||
| 	for _, key := range keys { | ||||
| 		_, err := core.Unseal(key) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	config := api.DefaultConfig() | ||||
| 	config.Address = addr | ||||
| 	client, err := api.NewClient(config) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	client.SetToken(rootToken) | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
|  | ||||
| 	conn, err := updater.openWebSocketConnection(context.Background()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, conn) | ||||
| } | ||||
|  | ||||
| // Test_StreamStaticSecretEvents_UpdatesCacheWithNewSecrets tests that an event will | ||||
| // properly update the corresponding secret in Proxy's cache. This is a little more end-to-end-y | ||||
| // than TestUpdateStaticSecret, and essentially is testing a similar thing, though is | ||||
| // ensuring that updateStaticSecret gets called by the event arriving | ||||
| // (as part of streamStaticSecretEvents) instead of testing calling it explicitly. | ||||
| func Test_StreamStaticSecretEvents_UpdatesCacheWithNewSecrets(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||
| 		LogicalBackends: map[string]logical.Factory{ | ||||
| 			"kv": kv.VersionedKVFactory, | ||||
| 		}, | ||||
| 	}, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
| 	leaseCache := updater.leaseCache | ||||
|  | ||||
| 	wg := &sync.WaitGroup{} | ||||
| 	runStreamStaticSecretEvents := func() { | ||||
| 		wg.Add(1) | ||||
| 		err := updater.streamStaticSecretEvents(context.Background()) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
| 	go runStreamStaticSecretEvents() | ||||
|  | ||||
| 	// First, create the secret in the cache that we expect to be updated: | ||||
| 	path := "secret-v2/data/foo" | ||||
| 	indexId := hashStaticSecretIndex(path) | ||||
| 	initialTime := time.Now().UTC() | ||||
| 	// pre-populate the leaseCache with a secret to update | ||||
| 	index := &cachememdb.Index{ | ||||
| 		Namespace:   "root/", | ||||
| 		RequestPath: path, | ||||
| 		LastRenewed: initialTime, | ||||
| 		ID:          indexId, | ||||
| 		// Valid token provided, so update should work. | ||||
| 		Tokens:   map[string]struct{}{client.Token(): {}}, | ||||
| 		Response: []byte{}, | ||||
| 	} | ||||
| 	err := leaseCache.db.Set(index) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	secretData := map[string]interface{}{ | ||||
| 		"foo": "bar", | ||||
| 	} | ||||
|  | ||||
| 	err = client.Sys().Mount("secret-v2", &api.MountInput{ | ||||
| 		Type: "kv-v2", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Put a secret, which should trigger an event | ||||
| 	_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", secretData) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Wait for the event to arrive. Events are usually much, much faster | ||||
| 	// than this, but we make it five seconds to protect against CI flakiness. | ||||
| 	time.Sleep(5 * time.Second) | ||||
|  | ||||
| 	// Then, do a GET to see if the event got updated | ||||
| 	newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, newIndex) | ||||
| 	require.NotEqual(t, []byte{}, newIndex.Response) | ||||
| 	require.Truef(t, initialTime.Before(newIndex.LastRenewed), "last updated time not updated on index") | ||||
| 	require.Equal(t, index.RequestPath, newIndex.RequestPath) | ||||
| 	require.Equal(t, index.Tokens, newIndex.Tokens) | ||||
|  | ||||
| 	wg.Done() | ||||
| } | ||||
|  | ||||
| // TestUpdateStaticSecret tests that updateStaticSecret works as expected, reaching out | ||||
| // to Vault to get an updated secret when called. | ||||
| func TestUpdateStaticSecret(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
| 	leaseCache := updater.leaseCache | ||||
|  | ||||
| 	path := "secret/foo" | ||||
| 	indexId := hashStaticSecretIndex(path) | ||||
| 	initialTime := time.Now().UTC() | ||||
| 	// pre-populate the leaseCache with a secret to update | ||||
| 	index := &cachememdb.Index{ | ||||
| 		Namespace:   "root/", | ||||
| 		RequestPath: "secret/foo", | ||||
| 		LastRenewed: initialTime, | ||||
| 		ID:          indexId, | ||||
| 		// Valid token provided, so update should work. | ||||
| 		Tokens:   map[string]struct{}{client.Token(): {}}, | ||||
| 		Response: []byte{}, | ||||
| 	} | ||||
| 	err := leaseCache.db.Set(index) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	secretData := map[string]interface{}{ | ||||
| 		"foo": "bar", | ||||
| 	} | ||||
|  | ||||
| 	// create the secret in Vault. n.b. the test cluster has already mounted the KVv1 backend at "secret" | ||||
| 	err = client.KVv1("secret").Put(context.Background(), "foo", secretData) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// attempt the update | ||||
| 	err = updater.updateStaticSecret(context.Background(), path) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.NotNil(t, newIndex) | ||||
| 	require.Truef(t, initialTime.Before(newIndex.LastRenewed), "last updated time not updated on index") | ||||
| 	require.NotEqual(t, []byte{}, newIndex.Response) | ||||
| 	require.Equal(t, index.RequestPath, newIndex.RequestPath) | ||||
| 	require.Equal(t, index.Tokens, newIndex.Tokens) | ||||
| } | ||||
|  | ||||
| // TestUpdateStaticSecret_EvictsIfInvalidTokens tests that updateStaticSecret will | ||||
| // evict secrets from the cache if no valid tokens are left. | ||||
| func TestUpdateStaticSecret_EvictsIfInvalidTokens(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
| 	leaseCache := updater.leaseCache | ||||
|  | ||||
| 	path := "secret/foo" | ||||
| 	indexId := hashStaticSecretIndex(path) | ||||
| 	renewTime := time.Now().UTC() | ||||
|  | ||||
| 	// pre-populate the leaseCache with a secret to update | ||||
| 	index := &cachememdb.Index{ | ||||
| 		Namespace:   "root/", | ||||
| 		RequestPath: "secret/foo", | ||||
| 		LastRenewed: renewTime, | ||||
| 		ID:          indexId, | ||||
| 		// Note: invalid Tokens value provided, so this secret cannot be updated, and must be evicted | ||||
| 		Tokens: map[string]struct{}{"invalid token": {}}, | ||||
| 	} | ||||
| 	err := leaseCache.db.Set(index) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	secretData := map[string]interface{}{ | ||||
| 		"foo": "bar", | ||||
| 	} | ||||
|  | ||||
| 	// create the secret in Vault. n.b. the test cluster has already mounted the KVv1 backend at "secret" | ||||
| 	err = client.KVv1("secret").Put(context.Background(), "foo", secretData) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// attempt the update | ||||
| 	err = updater.updateStaticSecret(context.Background(), path) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) | ||||
| 	require.Equal(t, cachememdb.ErrCacheItemNotFound, err) | ||||
| 	require.Nil(t, newIndex) | ||||
| } | ||||
|  | ||||
| // TestUpdateStaticSecret_HandlesNonCachedPaths tests that updateStaticSecret | ||||
| // doesn't fail or error if we try and give it an update to a path that isn't cached. | ||||
| func TestUpdateStaticSecret_HandlesNonCachedPaths(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	// We need a valid cluster for the connection to succeed. | ||||
| 	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	client := cluster.Cores[0].Client | ||||
|  | ||||
| 	updater := testNewStaticSecretCacheUpdater(t, client) | ||||
|  | ||||
| 	path := "secret/foo" | ||||
|  | ||||
| 	// attempt the update | ||||
| 	err := updater.updateStaticSecret(context.Background(), path) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.Nil(t, err) | ||||
| } | ||||
| @@ -433,6 +433,8 @@ func (c *ProxyCommand) Run(args []string) int { | ||||
| 	ctx, cancelFunc := context.WithCancel(context.Background()) | ||||
| 	defer cancelFunc() | ||||
|  | ||||
| 	var updater *cache.StaticSecretCacheUpdater | ||||
|  | ||||
| 	// Parse proxy cache configurations | ||||
| 	if config.Cache != nil { | ||||
| 		cacheLogger := c.logger.Named("cache") | ||||
| @@ -463,6 +465,33 @@ func (c *ProxyCommand) Run(args []string) int { | ||||
| 				defer deferFunc() | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// If we're caching static secrets, we need to start the updater, too | ||||
| 		if config.Cache.CacheStaticSecrets { | ||||
| 			staticSecretCacheUpdaterLogger := c.logger.Named("cache.staticsecretcacheupdater") | ||||
| 			inmemSink, err := inmem.New(&sink.SinkConfig{ | ||||
| 				Logger: staticSecretCacheUpdaterLogger, | ||||
| 			}, leaseCache) | ||||
| 			if err != nil { | ||||
| 				c.UI.Error(fmt.Sprintf("Error creating inmem sink for static secret updater susbsystem: %v", err)) | ||||
| 				return 1 | ||||
| 			} | ||||
| 			sinks = append(sinks, &sink.SinkConfig{ | ||||
| 				Logger: staticSecretCacheUpdaterLogger, | ||||
| 				Sink:   inmemSink, | ||||
| 			}) | ||||
|  | ||||
| 			updater, err = cache.NewStaticSecretCacheUpdater(&cache.StaticSecretCacheUpdaterConfig{ | ||||
| 				Client:     client, | ||||
| 				LeaseCache: leaseCache, | ||||
| 				Logger:     staticSecretCacheUpdaterLogger, | ||||
| 				TokenSink:  inmemSink, | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				c.UI.Error(fmt.Sprintf("Error creating static secret cache updater: %v", err)) | ||||
| 				return 1 | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var listeners []net.Listener | ||||
| @@ -500,7 +529,7 @@ func (c *ProxyCommand) Run(args []string) int { | ||||
| 		var inmemSink sink.Sink | ||||
| 		if config.APIProxy != nil { | ||||
| 			if config.APIProxy.UseAutoAuthToken { | ||||
| 				apiProxyLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink") | ||||
| 				apiProxyLogger.Debug("configuring inmem auto-auth sink") | ||||
| 				inmemSink, err = inmem.New(&sink.SinkConfig{ | ||||
| 					Logger: apiProxyLogger, | ||||
| 				}, leaseCache) | ||||
| @@ -699,6 +728,16 @@ func (c *ProxyCommand) Run(args []string) int { | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// Add the static secret cache updater, if appropriate | ||||
| 	if updater != nil { | ||||
| 		g.Add(func() error { | ||||
| 			err := updater.Run(ctx) | ||||
| 			return err | ||||
| 		}, func(error) { | ||||
| 			cancelFunc() | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// Server configuration output | ||||
| 	padding := 24 | ||||
| 	sort.Strings(infoKeys) | ||||
|   | ||||
| @@ -247,12 +247,17 @@ func (c *Config) ValidateConfig() error { | ||||
| 	} | ||||
|  | ||||
| 	if c.AutoAuth != nil { | ||||
| 		cacheStaticSecrets := c.Cache != nil && c.Cache.CacheStaticSecrets | ||||
| 		if len(c.AutoAuth.Sinks) == 0 && | ||||
| 			(c.APIProxy == nil || !c.APIProxy.UseAutoAuthToken) { | ||||
| 			return fmt.Errorf("auto_auth requires at least one sink or api_proxy.use_auto_auth_token=true") | ||||
| 			(c.APIProxy == nil || !c.APIProxy.UseAutoAuthToken) && !cacheStaticSecrets { | ||||
| 			return fmt.Errorf("auto_auth requires at least one sink, api_proxy.use_auto_auth_token=true, or cache.cache_static_secrets=true") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if c.Cache != nil && c.Cache.CacheStaticSecrets && c.AutoAuth == nil { | ||||
| 		return fmt.Errorf("cache.cache_static_secrets=true requires an auto-auth block configured, to use the token to connect with Vault's event system") | ||||
| 	} | ||||
|  | ||||
| 	if c.AutoAuth == nil && c.Cache == nil && len(c.Listeners) == 0 { | ||||
| 		return fmt.Errorf("no auto_auth, cache, or listener block found in config") | ||||
| 	} | ||||
|   | ||||
| @@ -117,3 +117,16 @@ func TestLoadConfigFile_ProxyCache(t *testing.T) { | ||||
| 		t.Fatal(diff) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth tests that loading | ||||
| // a config file with static secret caching enabled but no auto auth will fail. | ||||
| func TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth(t *testing.T) { | ||||
| 	cfg, err := LoadConfigFile("./test-fixtures/config-cache-static-no-auto-auth.hcl") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := cfg.ValidateConfig(); err == nil { | ||||
| 		t.Fatalf("expected error, as static secret caching requires auto-auth") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -0,0 +1,18 @@ | ||||
| # Copyright (c) HashiCorp, Inc. | ||||
| # SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| pid_file = "./pidfile" | ||||
|  | ||||
| cache { | ||||
|     cache_static_secrets = true | ||||
| } | ||||
|  | ||||
| listener "tcp" { | ||||
|     address = "127.0.0.1:8300" | ||||
|     tls_disable = true | ||||
| } | ||||
|  | ||||
| vault { | ||||
| 	address = "http://127.0.0.1:1111" | ||||
| 	tls_skip_verify = "true" | ||||
| } | ||||
| @@ -703,6 +703,20 @@ func TestProxy_Cache_StaticSecret(t *testing.T) { | ||||
| 	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) | ||||
| 	os.Unsetenv(api.EnvVaultAddress) | ||||
|  | ||||
| 	tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) | ||||
| 	defer os.Remove(tokenFileName) | ||||
| 	// We need auto-auth so that the event system can run. | ||||
| 	// For ease, we use the token file path with the root token. | ||||
| 	autoAuthConfig := fmt.Sprintf(` | ||||
| auto_auth { | ||||
|     method { | ||||
| 		type = "token_file" | ||||
|         config = { | ||||
|             token_file_path = "%s" | ||||
|         } | ||||
|     } | ||||
| }`, tokenFileName) | ||||
|  | ||||
| 	cacheConfig := ` | ||||
| cache { | ||||
| 	cache_static_secrets = true | ||||
| @@ -723,13 +737,14 @@ vault { | ||||
| } | ||||
| %s | ||||
| %s | ||||
| %s | ||||
| log_level = "trace" | ||||
| `, serverClient.Address(), cacheConfig, listenConfig) | ||||
| `, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) | ||||
| 	configPath := makeTempFile(t, "config.hcl", config) | ||||
| 	defer os.Remove(configPath) | ||||
|  | ||||
| 	// Start proxy | ||||
| 	_, cmd := testProxyCommand(t, logger) | ||||
| 	ui, cmd := testProxyCommand(t, logger) | ||||
| 	cmd.startedCh = make(chan struct{}) | ||||
|  | ||||
| 	wg := &sync.WaitGroup{} | ||||
| @@ -743,6 +758,8 @@ log_level = "trace" | ||||
| 	case <-cmd.startedCh: | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 		t.Errorf("timeout") | ||||
| 		t.Errorf("stdout: %s", ui.OutputWriter.String()) | ||||
| 		t.Errorf("stderr: %s", ui.ErrorWriter.String()) | ||||
| 	} | ||||
|  | ||||
| 	proxyClient, err := api.NewClient(api.DefaultConfig()) | ||||
| @@ -804,15 +821,18 @@ log_level = "trace" | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| // TestProxy_Cache_StaticSecretInvalidation Tests that the cache successfully caches a static secret | ||||
| // going through the Proxy, and that it gets invalidated by a POST. | ||||
| func TestProxy_Cache_StaticSecretInvalidation(t *testing.T) { | ||||
| // TestProxy_Cache_EventSystemUpdatesCacheKVV1 Tests that the cache successfully caches a static secret | ||||
| // going through the Proxy, and then the cache gets updated on a POST to the KVV1 secret due to an | ||||
| // event. | ||||
| func TestProxy_Cache_EventSystemUpdatesCacheKVV1(t *testing.T) { | ||||
| 	logger := logging.NewVaultLogger(hclog.Trace) | ||||
| 	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ | ||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||
| 		LogicalBackends: map[string]logical.Factory{ | ||||
| 			"kv": logicalKv.Factory, | ||||
| 		}, | ||||
| 	}, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
| 	cluster.Start() | ||||
| 	defer cluster.Cleanup() | ||||
|  | ||||
| 	serverClient := cluster.Cores[0].Client | ||||
|  | ||||
| @@ -821,6 +841,20 @@ func TestProxy_Cache_StaticSecretInvalidation(t *testing.T) { | ||||
| 	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) | ||||
| 	os.Unsetenv(api.EnvVaultAddress) | ||||
|  | ||||
| 	tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) | ||||
| 	defer os.Remove(tokenFileName) | ||||
| 	// We need auto-auth so that the event system can run. | ||||
| 	// For ease, we use the token file path with the root token. | ||||
| 	autoAuthConfig := fmt.Sprintf(` | ||||
| auto_auth { | ||||
|     method { | ||||
| 		type = "token_file" | ||||
|         config = { | ||||
|             token_file_path = "%s" | ||||
|         } | ||||
|     } | ||||
| }`, tokenFileName) | ||||
|  | ||||
| 	cacheConfig := ` | ||||
| cache { | ||||
| 	cache_static_secrets = true | ||||
| @@ -841,13 +875,14 @@ vault { | ||||
| } | ||||
| %s | ||||
| %s | ||||
| %s | ||||
| log_level = "trace" | ||||
| `, serverClient.Address(), cacheConfig, listenConfig) | ||||
| `, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) | ||||
| 	configPath := makeTempFile(t, "config.hcl", config) | ||||
| 	defer os.Remove(configPath) | ||||
|  | ||||
| 	// Start proxy | ||||
| 	_, cmd := testProxyCommand(t, logger) | ||||
| 	ui, cmd := testProxyCommand(t, logger) | ||||
| 	cmd.startedCh = make(chan struct{}) | ||||
|  | ||||
| 	wg := &sync.WaitGroup{} | ||||
| @@ -861,6 +896,8 @@ log_level = "trace" | ||||
| 	case <-cmd.startedCh: | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 		t.Errorf("timeout") | ||||
| 		t.Errorf("stdout: %s", ui.OutputWriter.String()) | ||||
| 		t.Errorf("stderr: %s", ui.ErrorWriter.String()) | ||||
| 	} | ||||
|  | ||||
| 	proxyClient, err := api.NewClient(api.DefaultConfig()) | ||||
| @@ -882,14 +919,27 @@ log_level = "trace" | ||||
| 		"bar": "baz", | ||||
| 	} | ||||
|  | ||||
| 	// Wait for the event system to successfully connect. | ||||
| 	// This is longer than it needs to be to account for unnatural slowness/avoiding | ||||
| 	// flakiness. | ||||
| 	time.Sleep(5 * time.Second) | ||||
|  | ||||
| 	// Mount the KVV2 engine | ||||
| 	err = serverClient.Sys().Mount("secret-v1", &api.MountInput{ | ||||
| 		Type: "kv", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Create kvv1 secret | ||||
| 	err = serverClient.KVv1("secret").Put(context.Background(), "my-secret", secretData) | ||||
| 	err = serverClient.KVv1("secret-v1").Put(context.Background(), "my-secret", secretData) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// We use raw requests so we can check the headers for cache hit/miss. | ||||
| 	req := proxyClient.NewRequest(http.MethodGet, "/v1/secret/my-secret") | ||||
| 	req := proxyClient.NewRequest(http.MethodGet, "/v1/secret-v1/my-secret") | ||||
| 	resp1, err := proxyClient.RawRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| @@ -899,27 +949,23 @@ log_level = "trace" | ||||
| 	require.Equal(t, "MISS", cacheValue) | ||||
|  | ||||
| 	// Update the secret using the proxy client | ||||
| 	err = proxyClient.KVv1("secret").Put(context.Background(), "my-secret", secretData2) | ||||
| 	err = proxyClient.KVv1("secret-v1").Put(context.Background(), "my-secret", secretData2) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Give some time for the event to actually get sent and the cache to be updated. | ||||
| 	// This is longer than it needs to be to account for unnatural slowness/avoiding | ||||
| 	// flakiness. | ||||
| 	time.Sleep(5 * time.Second) | ||||
|  | ||||
| 	// We expect this to be a cache hit, with the new value | ||||
| 	resp2, err := proxyClient.RawRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	cacheValue = resp2.Header.Get("X-Cache") | ||||
| 	// This should miss too, as we just updated it | ||||
| 	require.Equal(t, "MISS", cacheValue) | ||||
|  | ||||
| 	resp3, err := proxyClient.RawRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	cacheValue = resp3.Header.Get("X-Cache") | ||||
| 	// This should hit, as the third request should get the cached value | ||||
| 	require.Equal(t, "HIT", cacheValue) | ||||
|  | ||||
| 	// Lastly, we check to make sure the actual data we received is | ||||
| @@ -936,11 +982,175 @@ log_level = "trace" | ||||
| 	} | ||||
| 	require.Equal(t, secretData2, secret2.Data) | ||||
|  | ||||
| 	secret3, err := api.ParseSecret(resp3.Body) | ||||
| 	close(cmd.ShutdownCh) | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| // TestProxy_Cache_EventSystemUpdatesCacheKVV2 Tests that the cache successfully caches a static secret | ||||
| // going through the Proxy for a KVV2 secret, and then the cache gets updated on a POST to the secret due to an | ||||
| // event. | ||||
| func TestProxy_Cache_EventSystemUpdatesCacheKVV2(t *testing.T) { | ||||
| 	logger := logging.NewVaultLogger(hclog.Trace) | ||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||
| 		LogicalBackends: map[string]logical.Factory{ | ||||
| 			"kv": logicalKv.VersionedKVFactory, | ||||
| 		}, | ||||
| 	}, &vault.TestClusterOptions{ | ||||
| 		HandlerFunc: vaulthttp.Handler, | ||||
| 	}) | ||||
|  | ||||
| 	serverClient := cluster.Cores[0].Client | ||||
|  | ||||
| 	// Unset the environment variable so that proxy picks up the right test | ||||
| 	// cluster address | ||||
| 	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) | ||||
| 	os.Unsetenv(api.EnvVaultAddress) | ||||
|  | ||||
| 	tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) | ||||
| 	defer os.Remove(tokenFileName) | ||||
| 	// We need auto-auth so that the event system can run. | ||||
| 	// For ease, we use the token file path with the root token. | ||||
| 	autoAuthConfig := fmt.Sprintf(` | ||||
| auto_auth { | ||||
|     method { | ||||
| 		type = "token_file" | ||||
|         config = { | ||||
|             token_file_path = "%s" | ||||
|         } | ||||
|     } | ||||
| }`, tokenFileName) | ||||
|  | ||||
| 	cacheConfig := ` | ||||
| cache { | ||||
| 	cache_static_secrets = true | ||||
| } | ||||
| ` | ||||
| 	listenAddr := generateListenerAddress(t) | ||||
| 	listenConfig := fmt.Sprintf(` | ||||
| listener "tcp" { | ||||
|   address = "%s" | ||||
|   tls_disable = true | ||||
| } | ||||
| `, listenAddr) | ||||
|  | ||||
| 	config := fmt.Sprintf(` | ||||
| vault { | ||||
|   address = "%s" | ||||
|   tls_skip_verify = true | ||||
| } | ||||
| %s | ||||
| %s | ||||
| %s | ||||
| log_level = "trace" | ||||
| `, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) | ||||
| 	configPath := makeTempFile(t, "config.hcl", config) | ||||
| 	defer os.Remove(configPath) | ||||
|  | ||||
| 	// Start proxy | ||||
| 	ui, cmd := testProxyCommand(t, logger) | ||||
| 	cmd.startedCh = make(chan struct{}) | ||||
|  | ||||
| 	wg := &sync.WaitGroup{} | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		cmd.Run([]string{"-config", configPath}) | ||||
| 		wg.Done() | ||||
| 	}() | ||||
|  | ||||
| 	select { | ||||
| 	case <-cmd.startedCh: | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 		t.Errorf("timeout") | ||||
| 		t.Errorf("stdout: %s", ui.OutputWriter.String()) | ||||
| 		t.Errorf("stderr: %s", ui.ErrorWriter.String()) | ||||
| 	} | ||||
|  | ||||
| 	proxyClient, err := api.NewClient(api.DefaultConfig()) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	require.Equal(t, secret2.Data, secret3.Data) | ||||
| 	proxyClient.SetToken(serverClient.Token()) | ||||
| 	proxyClient.SetMaxRetries(0) | ||||
| 	err = proxyClient.SetAddress("http://" + listenAddr) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	secretData := map[string]interface{}{ | ||||
| 		"foo": "bar", | ||||
| 	} | ||||
|  | ||||
| 	secretData2 := map[string]interface{}{ | ||||
| 		"bar": "baz", | ||||
| 	} | ||||
|  | ||||
| 	// Wait for the event system to successfully connect. | ||||
| 	// This is longer than it needs to be to account for unnatural slowness/avoiding | ||||
| 	// flakiness. | ||||
| 	time.Sleep(5 * time.Second) | ||||
|  | ||||
| 	// Mount the KVV2 engine | ||||
| 	err = serverClient.Sys().Mount("secret-v2", &api.MountInput{ | ||||
| 		Type: "kv-v2", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Create kvv2 secret | ||||
| 	_, err = serverClient.KVv2("secret-v2").Put(context.Background(), "my-secret", secretData) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// We use raw requests so we can check the headers for cache hit/miss. | ||||
| 	req := proxyClient.NewRequest(http.MethodGet, "/v1/secret-v2/data/my-secret") | ||||
| 	resp1, err := proxyClient.RawRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	cacheValue := resp1.Header.Get("X-Cache") | ||||
| 	require.Equal(t, "MISS", cacheValue) | ||||
|  | ||||
| 	// Update the secret using the proxy client | ||||
| 	_, err = proxyClient.KVv2("secret-v2").Put(context.Background(), "my-secret", secretData2) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Give some time for the event to actually get sent and the cache to be updated. | ||||
| 	// This is longer than it needs to be to account for unnatural slowness/avoiding | ||||
| 	// flakiness. | ||||
| 	time.Sleep(5 * time.Second) | ||||
|  | ||||
| 	// We expect this to be a cache hit, with the new value | ||||
| 	resp2, err := proxyClient.RawRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	cacheValue = resp2.Header.Get("X-Cache") | ||||
| 	require.Equal(t, "HIT", cacheValue) | ||||
|  | ||||
| 	// Lastly, we check to make sure the actual data we received is | ||||
| 	// as we expect. We must use ParseSecret due to the raw requests. | ||||
| 	secret1, err := api.ParseSecret(resp1.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	data, ok := secret1.Data["data"] | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, secretData, data) | ||||
|  | ||||
| 	secret2, err := api.ParseSecret(resp2.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	data2, ok := secret2.Data["data"] | ||||
| 	require.True(t, ok) | ||||
| 	// We expect that the cached value got updated by the event system. | ||||
| 	require.Equal(t, secretData2, data2) | ||||
|  | ||||
| 	close(cmd.ShutdownCh) | ||||
| 	wg.Wait() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Violet Hynes
					Violet Hynes