VAULT-30219 Bug fix for race condition (#28228)

This commit is contained in:
Violet Hynes
2024-08-30 09:45:58 -04:00
committed by GitHub
parent b5621aa368
commit bc7923ad29
2 changed files with 131 additions and 121 deletions

View File

@@ -4,6 +4,7 @@
package cache package cache
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -12,7 +13,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -218,18 +218,6 @@ func (updater *StaticSecretCacheUpdater) streamStaticSecretEvents(ctx context.Co
} }
} }
// For all other operations, we *only* care about the latest version.
// However, if we know the current version, we should update that too
currentVersion := 0
currentVersionString, ok := metadata["current_version"].(string)
if ok {
versionInt, err := strconv.Atoi(currentVersionString)
if err != nil {
return fmt.Errorf("unexpected event format when decoding 'current_version' element, message: %s\nerror: %w", string(message), err)
}
currentVersion = versionInt
}
// Note: For delete/destroy events, we continue through to updating the secret itself, too. // Note: For delete/destroy events, we continue through to updating the secret itself, too.
// This means that if the latest version of the secret gets deleted, then the cache keeps // This means that if the latest version of the secret gets deleted, then the cache keeps
// knowledge of which the latest version is. // knowledge of which the latest version is.
@@ -237,7 +225,7 @@ func (updater *StaticSecretCacheUpdater) streamStaticSecretEvents(ctx context.Co
// to update the secret will 404. This is consistent with other behaviour. For Proxy, this means // to update the secret will 404. This is consistent with other behaviour. For Proxy, this means
// the secret may be evicted. That's okay. // the secret may be evicted. That's okay.
err = updater.updateStaticSecret(ctx, path, currentVersion) err = updater.updateStaticSecret(ctx, path)
if err != nil { if err != nil {
// While we are kind of 'missing' an event this way, re-calling this function will // While we are kind of 'missing' an event this way, re-calling this function will
// result in the secret remaining up to date. // result in the secret remaining up to date.
@@ -363,7 +351,7 @@ func (updater *StaticSecretCacheUpdater) preEventStreamUpdate(ctx context.Contex
if index.Type != cacheboltdb.StaticSecretType { if index.Type != cacheboltdb.StaticSecretType {
continue continue
} }
err = updater.updateStaticSecret(ctx, index.RequestPath, 0) err = updater.updateStaticSecret(ctx, index.RequestPath)
if err != nil { if err != nil {
errs = multierror.Append(errs, err) errs = multierror.Append(errs, err)
} }
@@ -411,9 +399,9 @@ func (updater *StaticSecretCacheUpdater) handleDeleteDestroyVersions(path string
} }
// updateStaticSecret checks for updates for a static secret on the path given, // updateStaticSecret checks for updates for a static secret on the path given,
// and updates the cache if appropriate. If currentVersion is not 0, we will also update // and updates the cache if appropriate. For KVv2 secrets, we will also update
// will also update the version at index.Versions[currentVersion] with the same data. // the version at index.Versions[currentVersion] with the same data.
func (updater *StaticSecretCacheUpdater) updateStaticSecret(ctx context.Context, path string, currentVersion int) error { func (updater *StaticSecretCacheUpdater) updateStaticSecret(ctx context.Context, path string) error {
// We clone the client, as we won't be using the same token. // We clone the client, as we won't be using the same token.
client, err := updater.client.Clone() client, err := updater.client.Clone()
if err != nil { if err != nil {
@@ -514,19 +502,37 @@ func (updater *StaticSecretCacheUpdater) updateStaticSecret(ctx context.Context,
return err return err
} }
// Set the index's Response
index.Response = respBytes.Bytes() index.Response = respBytes.Bytes()
index.LastRenewed = time.Now().UTC() index.LastRenewed = time.Now().UTC()
if currentVersion != 0 {
// It should always be non-nil, but avoid a panic just in case. // For KVv2 secrets, let's also update index.Versions[version_of_secret]
if index.Versions == nil { // with the response we received from the current version.
index.Versions = map[int][]byte{} // Instead of relying on current_version in the event, we should
} // check the message we received, since it's possible the secret
index.Versions[currentVersion] = index.Response // got updated between receipt of the event and when we received
// the request for the secret.
// First, re-read secret into response so that we can parse it again:
reader := bufio.NewReader(bytes.NewReader(index.Response))
resp, err := http.ReadResponse(reader, nil)
if err != nil {
// This shouldn't happen, but log just in case it does. There's
// no real negative consequences of the following function though.
updater.logger.Warn("failed to deserialize response", "error", err)
} }
secret, err := api.ParseSecret(resp.Body)
if err != nil {
// This shouldn't happen, but log just in case it does. There's
// no real negative consequences of the following function though.
updater.logger.Warn("failed to serialize response", "error", err)
}
// In case of failures or KVv1 secrets, this function will simply fail silently,
// which is fine (and expected) since this could be arbitrary JSON.
updater.leaseCache.addToVersionListForCurrentVersionKVv2Secret(index, secret)
// Lastly, store the secret // Lastly, store the secret
updater.logger.Debug("storing response into the cache due to update", "path", path, "currentVersion", currentVersion) updater.logger.Debug("storing response into the cache due to update", "path", path)
err = updater.leaseCache.db.Set(index) err = updater.leaseCache.db.Set(index)
if err != nil { if err != nil {
return err return err

View File

@@ -65,9 +65,7 @@ func testNewStaticSecretCacheUpdater(t *testing.T, client *api.Client) *StaticSe
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"),
TokenSink: tokenSink, TokenSink: tokenSink,
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
return updater return updater
} }
@@ -80,9 +78,7 @@ func TestNewStaticSecretCacheUpdater(t *testing.T) {
config := api.DefaultConfig() config := api.DefaultConfig()
logger := logging.NewVaultLogger(hclog.Trace).Named("cache.updater") logger := logging.NewVaultLogger(hclog.Trace).Named("cache.updater")
client, err := api.NewClient(config) client, err := api.NewClient(config)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
tokenSink := newMockSink(t) tokenSink := newMockSink(t)
// Expect an error if any of the arguments are nil: // Expect an error if any of the arguments are nil:
@@ -129,9 +125,7 @@ func TestNewStaticSecretCacheUpdater(t *testing.T) {
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"),
TokenSink: tokenSink, TokenSink: tokenSink,
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
require.NotNil(t, updater) require.NotNil(t, updater)
} }
@@ -194,9 +188,7 @@ func TestOpenWebSocketConnection_BadPolicyToken(t *testing.T) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
case err := <-errCh: case err := <-errCh:
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
} }
}() }()
@@ -245,9 +237,7 @@ func TestOpenWebSocketConnection_AutoAuthSelfHeal(t *testing.T) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
case err := <-errCh: case err := <-errCh:
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
} }
}() }()
@@ -305,9 +295,7 @@ func TestOpenWebSocketConnectionReceivesEventsDefaultMount(t *testing.T) {
updater := testNewStaticSecretCacheUpdater(t, client) updater := testNewStaticSecretCacheUpdater(t, client)
conn, err := updater.openWebSocketConnection(context.Background()) conn, err := updater.openWebSocketConnection(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
require.NotNil(t, conn) require.NotNil(t, conn)
t.Cleanup(func() { t.Cleanup(func() {
@@ -321,24 +309,17 @@ func TestOpenWebSocketConnectionReceivesEventsDefaultMount(t *testing.T) {
} }
// Put a secret, which should trigger an event // Put a secret, which should trigger an event
err = client.KVv1("secret").Put(context.Background(), "foo", makeData(100)) err = client.KVv1("secret").Put(context.Background(), "foo", makeData(100))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
// Do a fresh PUT just to refresh the secret and send a new message // Do a fresh PUT just to refresh the secret and send a new message
err = client.KVv1("secret").Put(context.Background(), "foo", makeData(i)) err = client.KVv1("secret").Put(context.Background(), "foo", makeData(i))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// This method blocks until it gets a secret, so this test // This method blocks until it gets a secret, so this test
// will only pass if we're receiving events correctly. // will only pass if we're receiving events correctly.
_, message, err := conn.Read(context.Background()) _, _, err = conn.Read(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
t.Log(string(message))
} }
} }
@@ -364,9 +345,7 @@ func TestOpenWebSocketConnectionReceivesEventsKVV1(t *testing.T) {
updater := testNewStaticSecretCacheUpdater(t, client) updater := testNewStaticSecretCacheUpdater(t, client)
conn, err := updater.openWebSocketConnection(context.Background()) conn, err := updater.openWebSocketConnection(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
require.NotNil(t, conn) require.NotNil(t, conn)
t.Cleanup(func() { t.Cleanup(func() {
@@ -376,9 +355,7 @@ func TestOpenWebSocketConnectionReceivesEventsKVV1(t *testing.T) {
err = client.Sys().Mount("secret-v1", &api.MountInput{ err = client.Sys().Mount("secret-v1", &api.MountInput{
Type: "kv", Type: "kv",
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
makeData := func(i int) map[string]interface{} { makeData := func(i int) map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
@@ -387,23 +364,17 @@ func TestOpenWebSocketConnectionReceivesEventsKVV1(t *testing.T) {
} }
// Put a secret, which should trigger an event // Put a secret, which should trigger an event
err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(100)) err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(100))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
// Do a fresh PUT just to refresh the secret and send a new message // Do a fresh PUT just to refresh the secret and send a new message
err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(i)) err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(i))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// This method blocks until it gets a secret, so this test // This method blocks until it gets a secret, so this test
// will only pass if we're receiving events correctly. // will only pass if we're receiving events correctly.
_, _, err := conn.Read(context.Background()) _, _, err := conn.Read(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
} }
} }
@@ -429,9 +400,7 @@ func TestOpenWebSocketConnectionReceivesEventsKVV2(t *testing.T) {
updater := testNewStaticSecretCacheUpdater(t, client) updater := testNewStaticSecretCacheUpdater(t, client)
conn, err := updater.openWebSocketConnection(context.Background()) conn, err := updater.openWebSocketConnection(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
require.NotNil(t, conn) require.NotNil(t, conn)
t.Cleanup(func() { t.Cleanup(func() {
@@ -447,29 +416,21 @@ func TestOpenWebSocketConnectionReceivesEventsKVV2(t *testing.T) {
err = client.Sys().Mount("secret-v2", &api.MountInput{ err = client.Sys().Mount("secret-v2", &api.MountInput{
Type: "kv-v2", Type: "kv-v2",
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// Put a secret, which should trigger an event // Put a secret, which should trigger an event
_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(100)) _, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(100))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
// Do a fresh PUT just to refresh the secret and send a new message // Do a fresh PUT just to refresh the secret and send a new message
_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(i)) _, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(i))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// This method blocks until it gets a secret, so this test // This method blocks until it gets a secret, so this test
// will only pass if we're receiving events correctly. // will only pass if we're receiving events correctly.
_, _, err := conn.Read(context.Background()) _, _, err := conn.Read(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
} }
} }
@@ -489,24 +450,18 @@ func TestOpenWebSocketConnectionTestServer(t *testing.T) {
keys, rootToken := vault.TestCoreInit(t, core) keys, rootToken := vault.TestCoreInit(t, core)
for _, key := range keys { for _, key := range keys {
_, err := core.Unseal(key) _, err := core.Unseal(key)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
} }
config := api.DefaultConfig() config := api.DefaultConfig()
config.Address = addr config.Address = addr
client, err := api.NewClient(config) client, err := api.NewClient(config)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
client.SetToken(rootToken) client.SetToken(rootToken)
updater := testNewStaticSecretCacheUpdater(t, client) updater := testNewStaticSecretCacheUpdater(t, client)
conn, err := updater.openWebSocketConnection(context.Background()) conn, err := updater.openWebSocketConnection(context.Background())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
require.NotNil(t, conn) require.NotNil(t, conn)
} }
@@ -636,7 +591,7 @@ func TestUpdateStaticSecret(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// attempt the update // attempt the update
err = updater.updateStaticSecret(context.Background(), path, 0) err = updater.updateStaticSecret(context.Background(), path)
require.NoError(t, err) require.NoError(t, err)
newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId)
@@ -649,6 +604,72 @@ func TestUpdateStaticSecret(t *testing.T) {
require.Len(t, newIndex.Versions, 0) require.Len(t, newIndex.Versions, 0)
} }
// TestUpdateStaticSecret_KVv2 tests that updateStaticSecret works as expected, reaching out
// to Vault to get an updated secret when called. It should also update the corresponding
// version of that secret in the cache index's Versions field.
func TestUpdateStaticSecret_KVv2(t *testing.T) {
t.Parallel()
// We need a valid cluster for the connection to succeed.
cluster := vault.NewTestCluster(t, &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"kv": kv.VersionedKVFactory,
},
}, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
client := cluster.Cores[0].Client
updater := testNewStaticSecretCacheUpdater(t, client)
leaseCache := updater.leaseCache
path := "secret-v2/data/foo"
indexId := hashStaticSecretIndex(path)
initialTime := time.Now().UTC()
// pre-populate the leaseCache with a secret to update
index := &cachememdb.Index{
Namespace: "root/",
RequestPath: path,
LastRenewed: initialTime,
ID: indexId,
Versions: map[int][]byte{},
// Valid token provided, so update should work.
Tokens: map[string]struct{}{client.Token(): {}},
Response: []byte{},
}
err := leaseCache.db.Set(index)
require.NoError(t, err)
secretData := map[string]interface{}{
"foo": "bar",
}
err = client.Sys().Mount("secret-v2", &api.MountInput{
Type: "kv-v2",
})
require.NoError(t, err)
// create the secret in Vault
_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", secretData)
require.NoError(t, err)
// attempt the update
err = updater.updateStaticSecret(context.Background(), path)
require.NoError(t, err)
newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId)
require.NoError(t, err)
require.NotNil(t, newIndex)
require.Truef(t, initialTime.Before(newIndex.LastRenewed), "last updated time not updated on index")
require.NotEqual(t, []byte{}, newIndex.Response)
require.Equal(t, index.RequestPath, newIndex.RequestPath)
require.Equal(t, index.Tokens, newIndex.Tokens)
// It should have also updated version 1 with the same version.
require.Len(t, newIndex.Versions, 1)
require.NotNil(t, newIndex.Versions[1])
require.Equal(t, newIndex.Versions[1], newIndex.Response)
}
// TestUpdateStaticSecret_EvictsIfInvalidTokens tests that updateStaticSecret will // TestUpdateStaticSecret_EvictsIfInvalidTokens tests that updateStaticSecret will
// evict secrets from the cache if no valid tokens are left. // evict secrets from the cache if no valid tokens are left.
func TestUpdateStaticSecret_EvictsIfInvalidTokens(t *testing.T) { func TestUpdateStaticSecret_EvictsIfInvalidTokens(t *testing.T) {
@@ -676,9 +697,7 @@ func TestUpdateStaticSecret_EvictsIfInvalidTokens(t *testing.T) {
Tokens: map[string]struct{}{"invalid token": {}}, Tokens: map[string]struct{}{"invalid token": {}},
} }
err := leaseCache.db.Set(index) err := leaseCache.db.Set(index)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
secretData := map[string]interface{}{ secretData := map[string]interface{}{
"foo": "bar", "foo": "bar",
@@ -686,15 +705,11 @@ func TestUpdateStaticSecret_EvictsIfInvalidTokens(t *testing.T) {
// create the secret in Vault. n.b. the test cluster has already mounted the KVv1 backend at "secret" // create the secret in Vault. n.b. the test cluster has already mounted the KVv1 backend at "secret"
err = client.KVv1("secret").Put(context.Background(), "foo", secretData) err = client.KVv1("secret").Put(context.Background(), "foo", secretData)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// attempt the update // attempt the update
err = updater.updateStaticSecret(context.Background(), path, 0) err = updater.updateStaticSecret(context.Background(), path)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId)
require.Equal(t, cachememdb.ErrCacheItemNotFound, err) require.Equal(t, cachememdb.ErrCacheItemNotFound, err)
@@ -715,13 +730,8 @@ func TestUpdateStaticSecret_HandlesNonCachedPaths(t *testing.T) {
path := "secret/foo" path := "secret/foo"
// Attempt the update for with currentVersion 0 // Attempt the update
err := updater.updateStaticSecret(context.Background(), path, 0) err := updater.updateStaticSecret(context.Background(), path)
require.NoError(t, err)
require.Nil(t, err)
// Attempt a higher currentVersion just to be sure
err = updater.updateStaticSecret(context.Background(), path, 100)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, err) require.Nil(t, err)
} }
@@ -821,9 +831,7 @@ func TestPreEventStreamUpdateErrorUpdating(t *testing.T) {
Type: cacheboltdb.StaticSecretType, Type: cacheboltdb.StaticSecretType,
} }
err := leaseCache.db.Set(index) err := leaseCache.db.Set(index)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
secretData := map[string]interface{}{ secretData := map[string]interface{}{
"foo": "bar", "foo": "bar",
@@ -832,15 +840,11 @@ func TestPreEventStreamUpdateErrorUpdating(t *testing.T) {
err = client.Sys().Mount("secret-v2", &api.MountInput{ err = client.Sys().Mount("secret-v2", &api.MountInput{
Type: "kv-v2", Type: "kv-v2",
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// Put a secret (with different values to what's currently in the cache) // Put a secret (with different values to what's currently in the cache)
_, err = client.KVv2("secret-v2").Put(context.Background(), "foo", secretData) _, err = client.KVv2("secret-v2").Put(context.Background(), "foo", secretData)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// Seal Vault, so that the update will fail // Seal Vault, so that the update will fail
cluster.EnsureCoresSealed(t) cluster.EnsureCoresSealed(t)