Add context to storage backends and wire it through a lot of places (#3817)

This commit is contained in:
Brian Kassouf
2018-01-18 22:44:44 -08:00
committed by Jeff Mitchell
parent 2864fbd697
commit 8142b42d95
341 changed files with 3417 additions and 3083 deletions

View File

@@ -1,6 +1,8 @@
package audit
import (
"context"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
)
@@ -14,13 +16,13 @@ type Backend interface {
// request is authorized but before the request is executed. The arguments
// MUST not be modified in anyway. They should be deep copied if this is
// a possibility.
LogRequest(*logical.Auth, *logical.Request, error) error
LogRequest(context.Context, *logical.Auth, *logical.Request, error) error
// LogResponse is used to synchronously log a response. This is done after
// the request is processed but before the response is sent. The arguments
// MUST not be modified in anyway. They should be deep copied if this is
// a possibility.
LogResponse(*logical.Auth, *logical.Request, *logical.Response, error) error
LogResponse(context.Context, *logical.Auth, *logical.Request, *logical.Response, error) error
// GetHash is used to return the given data with the backend's hash,
// so that a caller can determine if a value in the audit log matches
@@ -28,10 +30,10 @@ type Backend interface {
GetHash(string) (string, error)
// Reload is called on SIGHUP for supporting backends.
Reload() error
Reload(context.Context) error
// Invalidate is called for path invalidation
Invalidate()
Invalidate(context.Context)
}
type BackendConfig struct {
@@ -46,4 +48,4 @@ type BackendConfig struct {
}
// Factory is the factory function to create an audit backend.
type Factory func(*BackendConfig) (Backend, error)
type Factory func(context.Context, *BackendConfig) (Backend, error)

View File

@@ -1,6 +1,7 @@
package audit
import (
"context"
"crypto/sha256"
"fmt"
"reflect"
@@ -94,7 +95,7 @@ func TestCopy_response(t *testing.T) {
func TestHashString(t *testing.T) {
inmemStorage := &logical.InmemStorage{}
inmemStorage.Put(&logical.StorageEntry{
inmemStorage.Put(context.Background(), &logical.StorageEntry{
Key: "salt",
Value: []byte("foo"),
})
@@ -192,7 +193,7 @@ func TestHash(t *testing.T) {
}
inmemStorage := &logical.InmemStorage{}
inmemStorage.Put(&logical.StorageEntry{
inmemStorage.Put(context.Background(), &logical.StorageEntry{
Key: "salt",
Value: []byte("foo"),
})

View File

@@ -1,6 +1,7 @@
package file
import (
"context"
"fmt"
"io/ioutil"
"os"
@@ -14,7 +15,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config")
}
@@ -168,7 +169,12 @@ func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(salt, data), nil
}
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
func (b *Backend) LogRequest(
_ context.Context,
auth *logical.Auth,
req *logical.Request,
outerErr error) error {
b.fileLock.Lock()
defer b.fileLock.Unlock()
@@ -199,6 +205,7 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
}
func (b *Backend) LogResponse(
_ context.Context,
auth *logical.Auth,
req *logical.Request,
resp *logical.Response,
@@ -264,7 +271,7 @@ func (b *Backend) open() error {
return nil
}
func (b *Backend) Reload() error {
func (b *Backend) Reload(_ context.Context) error {
switch b.path {
case "stdout", "discard":
return nil
@@ -288,7 +295,7 @@ func (b *Backend) Reload() error {
return b.open()
}
func (b *Backend) Invalidate() {
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil

View File

@@ -1,6 +1,7 @@
package file
import (
"context"
"io/ioutil"
"os"
"path/filepath"
@@ -33,7 +34,7 @@ func TestAuditFile_fileModeNew(t *testing.T) {
"mode": modeStr,
}
_, err = Factory(&audit.BackendConfig{
_, err = Factory(context.Background(), &audit.BackendConfig{
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Config: config,
@@ -72,7 +73,7 @@ func TestAuditFile_fileModeExisting(t *testing.T) {
"path": f.Name(),
}
_, err = Factory(&audit.BackendConfig{
_, err = Factory(context.Background(), &audit.BackendConfig{
Config: config,
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},

View File

@@ -2,6 +2,7 @@ package socket
import (
"bytes"
"context"
"fmt"
"net"
"strconv"
@@ -15,7 +16,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config")
}
@@ -128,7 +129,7 @@ func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(salt, data), nil
}
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
func (b *Backend) LogRequest(ctx context.Context, auth *logical.Auth, req *logical.Request, outerErr error) error {
var buf bytes.Buffer
if err := b.formatter.FormatRequest(&buf, b.formatConfig, auth, req, outerErr); err != nil {
return err
@@ -137,21 +138,21 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
b.Lock()
defer b.Unlock()
err := b.write(buf.Bytes())
err := b.write(ctx, buf.Bytes())
if err != nil {
rErr := b.reconnect()
rErr := b.reconnect(ctx)
if rErr != nil {
err = multierror.Append(err, rErr)
} else {
// Try once more after reconnecting
err = b.write(buf.Bytes())
err = b.write(ctx, buf.Bytes())
}
}
return err
}
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
func (b *Backend) LogResponse(ctx context.Context, auth *logical.Auth, req *logical.Request,
resp *logical.Response, outerErr error) error {
var buf bytes.Buffer
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, outerErr); err != nil {
@@ -161,23 +162,23 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
b.Lock()
defer b.Unlock()
err := b.write(buf.Bytes())
err := b.write(ctx, buf.Bytes())
if err != nil {
rErr := b.reconnect()
rErr := b.reconnect(ctx)
if rErr != nil {
err = multierror.Append(err, rErr)
} else {
// Try once more after reconnecting
err = b.write(buf.Bytes())
err = b.write(ctx, buf.Bytes())
}
}
return err
}
func (b *Backend) write(buf []byte) error {
func (b *Backend) write(ctx context.Context, buf []byte) error {
if b.connection == nil {
if err := b.reconnect(); err != nil {
if err := b.reconnect(ctx); err != nil {
return err
}
}
@@ -195,13 +196,14 @@ func (b *Backend) write(buf []byte) error {
return err
}
func (b *Backend) reconnect() error {
func (b *Backend) reconnect(ctx context.Context) error {
if b.connection != nil {
b.connection.Close()
b.connection = nil
}
conn, err := net.Dial(b.socketType, b.address)
dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, b.socketType, b.address)
if err != nil {
return err
}
@@ -211,11 +213,11 @@ func (b *Backend) reconnect() error {
return nil
}
func (b *Backend) Reload() error {
func (b *Backend) Reload(ctx context.Context) error {
b.Lock()
defer b.Unlock()
err := b.reconnect()
err := b.reconnect(ctx)
return err
}
@@ -240,7 +242,7 @@ func (b *Backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *Backend) Invalidate() {
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil

View File

@@ -2,6 +2,7 @@ package syslog
import (
"bytes"
"context"
"fmt"
"strconv"
"sync"
@@ -12,7 +13,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config")
}
@@ -115,7 +116,7 @@ func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(salt, data), nil
}
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
func (b *Backend) LogRequest(_ context.Context, auth *logical.Auth, req *logical.Request, outerErr error) error {
var buf bytes.Buffer
if err := b.formatter.FormatRequest(&buf, b.formatConfig, auth, req, outerErr); err != nil {
return err
@@ -126,7 +127,7 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
return err
}
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) error {
func (b *Backend) LogResponse(_ context.Context, auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) error {
var buf bytes.Buffer
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, err); err != nil {
return err
@@ -137,7 +138,7 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *lo
return err
}
func (b *Backend) Reload() error {
func (b *Backend) Reload(_ context.Context) error {
return nil
}
@@ -161,7 +162,7 @@ func (b *Backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *Backend) Invalidate() {
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil

View File

@@ -1,6 +1,7 @@
package appId
import (
"context"
"sync"
"github.com/hashicorp/vault/helper/salt"
@@ -8,12 +9,12 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -115,7 +116,7 @@ func (b *backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case salt.DefaultLocation:
b.SaltMutex.Lock()

View File

@@ -14,13 +14,13 @@ func TestBackend_basic(t *testing.T) {
var b *backend
var err error
var storage logical.Storage
factory := func(conf *logical.BackendConfig) (logical.Backend, error) {
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err = Backend(conf)
if err != nil {
t.Fatal(err)
}
storage = conf.StorageView
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@@ -84,7 +84,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
userId := data.Get("user_id").(string)
var displayName string
if dispName, resp, err := b.verifyCredentials(req, appId, userId); err != nil {
if dispName, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@@ -93,7 +93,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
}
// Get the policies associated with the app
policies, err := b.MapAppId.Policies(req.Storage, appId)
policies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
if err != nil {
return nil, err
}
@@ -131,14 +131,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
// Skipping CIDR verification to enable renewal from machines other than
// the ones encompassed by CIDR block.
if _, resp, err := b.verifyCredentials(req, appId, userId); err != nil {
if _, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
}
// Get the policies associated with the app
mapPolicies, err := b.MapAppId.Policies(req.Storage, appId)
mapPolicies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
if err != nil {
return nil, err
}
@@ -149,14 +149,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
}
func (b *backend) verifyCredentials(req *logical.Request, appId, userId string) (string, *logical.Response, error) {
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, appId, userId string) (string, *logical.Response, error) {
// Ensure both appId and userId are provided
if appId == "" || userId == "" {
return "", logical.ErrorResponse("missing 'app_id' or 'user_id'"), nil
}
// Look up the apps that this user is allowed to access
appsMap, err := b.MapUserId.Get(req.Storage, userId)
appsMap, err := b.MapUserId.Get(ctx, req.Storage, userId)
if err != nil {
return "", nil, err
}
@@ -205,7 +205,7 @@ func (b *backend) verifyCredentials(req *logical.Request, appId, userId string)
}
// Get the raw data associated with the app
appRaw, err := b.MapAppId.Get(req.Storage, appId)
appRaw, err := b.MapAppId.Get(ctx, req.Storage, appId)
if err != nil {
return "", nil, err
}

View File

@@ -1,6 +1,7 @@
package approle
import (
"context"
"sync"
"github.com/hashicorp/vault/helper/locksutil"
@@ -49,12 +50,12 @@ type backend struct {
secretIDListingLock sync.RWMutex
}
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -125,7 +126,7 @@ func (b *backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case salt.DefaultLocation:
b.saltMutex.Lock()
@@ -139,9 +140,9 @@ func (b *backend) invalidate(key string) {
// This could mean that the SecretID may live in the backend upto 1 min after its
// expiration. The deletion of SecretIDs are not security sensitive and it is okay
// to delay the removal of SecretIDs by a minute.
func (b *backend) periodicFunc(req *logical.Request) error {
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
// Initiate clean-up of expired SecretID entries
b.tidySecretID(req.Storage)
b.tidySecretID(ctx, req.Storage)
return nil
}

View File

@@ -1,6 +1,7 @@
package approle
import (
"context"
"testing"
"github.com/hashicorp/vault/logical"
@@ -17,7 +18,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
if b == nil {
t.Fatalf("failed to create backend")
}
err = b.Backend.Setup(config)
err = b.Backend.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

View File

@@ -51,7 +51,7 @@ func (b *backend) pathLoginUpdateAliasLookahead(ctx context.Context, req *logica
// Returns the Auth object indicating the authentication and authorization information
// if the credentials provided are validated by the backend.
func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, roleName, metadata, _, err := b.validateCredentials(req, data)
role, roleName, metadata, _, err := b.validateCredentials(ctx, req, data)
if err != nil || role == nil {
return logical.ErrorResponse(fmt.Sprintf("failed to validate credentials: %v", err)), nil
}
@@ -93,7 +93,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, data
defer lock.RUnlock()
// Ensure that the Role still exists.
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err)
}

View File

@@ -523,7 +523,7 @@ func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Reque
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return false, err
}
@@ -538,7 +538,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
lock.RLock()
defer lock.RUnlock()
roles, err := req.Storage.List("role/")
roles, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
@@ -557,7 +557,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
defer lock.RUnlock()
// Get the role entry
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -580,7 +580,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
// Listing works one level at a time. Get the first level of data
// which could then be used to get the actual SecretID storage entries.
secretIDHMACs, err := req.Storage.List(fmt.Sprintf("secret_id/%s/", roleNameHMAC))
secretIDHMACs, err := req.Storage.List(ctx, fmt.Sprintf("secret_id/%s/", roleNameHMAC))
if err != nil {
return nil, err
}
@@ -606,7 +606,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
secretIDLock.RLock()
result := secretIDStorageEntry{}
if entry, err := req.Storage.Get(entryIndex); err != nil {
if entry, err := req.Storage.Get(ctx, entryIndex); err != nil {
secretIDLock.RUnlock()
return nil, err
} else if entry == nil {
@@ -643,7 +643,7 @@ func validateRoleConstraints(role *roleStorageEntry) error {
// setRoleEntry persists the role and creates an index from roleID to role
// name.
func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
func (b *backend) setRoleEntry(ctx context.Context, s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
if roleName == "" {
return fmt.Errorf("missing role name")
}
@@ -667,7 +667,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
}
// Check if the index from the role_id to role already exists
roleIDIndex, err := b.roleIDEntry(s, role.RoleID)
roleIDIndex, err := b.roleIDEntry(ctx, s, role.RoleID)
if err != nil {
return fmt.Errorf("failed to read role_id index: %v", err)
}
@@ -680,13 +680,13 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
// When role_id is getting updated, delete the old index before
// a new one is created
if previousRoleID != "" && previousRoleID != role.RoleID {
if err = b.roleIDEntryDelete(s, previousRoleID); err != nil {
if err = b.roleIDEntryDelete(ctx, s, previousRoleID); err != nil {
return fmt.Errorf("failed to delete previous role ID index")
}
}
// Save the role entry only after all the validations
if err = s.Put(entry); err != nil {
if err = s.Put(ctx, entry); err != nil {
return err
}
@@ -697,20 +697,20 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
// Create a storage entry for reverse mapping of RoleID to role.
// Note that secondary index is created when the roleLock is held.
return b.setRoleIDEntry(s, role.RoleID, &roleIDStorageEntry{
return b.setRoleIDEntry(ctx, s, role.RoleID, &roleIDStorageEntry{
Name: roleName,
})
}
// roleEntry reads the role from storage
func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) {
func (b *backend) roleEntry(ctx context.Context, s logical.Storage, roleName string) (*roleStorageEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role_name")
}
var role roleStorageEntry
if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil {
if entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName)); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@@ -734,7 +734,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
defer lock.Unlock()
// Check if the role already exists
role, err := b.roleEntry(req.Storage, roleName)
role, err := b.roleEntry(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -855,7 +855,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
// Store the entry.
return resp, b.setRoleEntry(req.Storage, roleName, role, previousRoleID)
return resp, b.setRoleEntry(ctx, req.Storage, roleName, role, previousRoleID)
}
// pathRoleRead grabs a read lock and reads the options set on the role from the storage
@@ -869,7 +869,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
lock.RLock()
lockRelease := lock.RUnlock
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
lockRelease()
return nil, err
@@ -902,7 +902,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
// For sanity, verify that the index still exists. If the index is missing,
// add one and return a warning so it can be reported.
roleIDIndex, err := b.roleIDEntry(req.Storage, role.RoleID)
roleIDIndex, err := b.roleIDEntry(ctx, req.Storage, role.RoleID)
if err != nil {
lockRelease()
return nil, err
@@ -915,7 +915,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
lockRelease = lock.Unlock
// Check again if the index is missing
roleIDIndex, err = b.roleIDEntry(req.Storage, role.RoleID)
roleIDIndex, err = b.roleIDEntry(ctx, req.Storage, role.RoleID)
if err != nil {
lockRelease()
return nil, err
@@ -923,7 +923,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
if roleIDIndex == nil {
// Create a new index
err = b.setRoleIDEntry(req.Storage, role.RoleID, &roleIDStorageEntry{
err = b.setRoleIDEntry(ctx, req.Storage, role.RoleID, &roleIDStorageEntry{
Name: roleName,
})
if err != nil {
@@ -950,7 +950,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -959,17 +959,17 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
}
// Just before the role is deleted, remove all the SecretIDs issued as part of the role.
if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil {
if err = b.flushRoleSecrets(ctx, req.Storage, roleName, role.HMACKey); err != nil {
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role %q: %v", roleName, err)
}
// Delete the reverse mapping from RoleID to the role
if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil {
if err = b.roleIDEntryDelete(ctx, req.Storage, role.RoleID); err != nil {
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role %q: %v", roleName, err)
}
// After deleting the SecretIDs and the RoleID, delete the role itself
if err = req.Storage.Delete("role/" + strings.ToLower(roleName)); err != nil {
if err = req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName)); err != nil {
return nil, err
}
@@ -993,7 +993,7 @@ func (b *backend) pathRoleSecretIDLookupUpdate(ctx context.Context, req *logical
defer lock.RUnlock()
// Fetch the role
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1020,16 +1020,16 @@ func (b *backend) pathRoleSecretIDLookupUpdate(ctx context.Context, req *logical
// Create the index at which the secret_id would've been stored
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
return b.secretIDCommon(req.Storage, entryIndex, secretIDHMAC)
return b.secretIDCommon(ctx, req.Storage, entryIndex, secretIDHMAC)
}
func (b *backend) secretIDCommon(s logical.Storage, entryIndex, secretIDHMAC string) (*logical.Response, error) {
func (b *backend) secretIDCommon(ctx context.Context, s logical.Storage, entryIndex, secretIDHMAC string) (*logical.Response, error) {
lock := b.secretIDLock(secretIDHMAC)
lock.RLock()
defer lock.RUnlock()
result := secretIDStorageEntry{}
if entry, err := s.Get(entryIndex); err != nil {
if entry, err := s.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@@ -1075,7 +1075,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
roleLock.RLock()
defer roleLock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1100,7 +1100,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
defer lock.Unlock()
result := secretIDStorageEntry{}
if entry, err := req.Storage.Get(entryIndex); err != nil {
if entry, err := req.Storage.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@@ -1109,12 +1109,12 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
}
// Delete the accessor of the SecretID first
if err := b.deleteSecretIDAccessorEntry(req.Storage, result.SecretIDAccessor); err != nil {
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, result.SecretIDAccessor); err != nil {
return nil, err
}
// Delete the storage entry that corresponds to the SecretID
if err := req.Storage.Delete(entryIndex); err != nil {
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
}
@@ -1142,7 +1142,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1150,7 +1150,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
return nil, fmt.Errorf("role %q does not exist", roleName)
}
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
accessorEntry, err := b.secretIDAccessorEntry(ctx, req.Storage, secretIDAccessor)
if err != nil {
return nil, err
}
@@ -1165,7 +1165,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
return b.secretIDCommon(req.Storage, entryIndex, accessorEntry.SecretIDHMAC)
return b.secretIDCommon(ctx, req.Storage, entryIndex, accessorEntry.SecretIDHMAC)
}
func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1183,7 +1183,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
// Get the role details to fetch the RoleID and accessor to get
// the HMACed SecretID.
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1191,7 +1191,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
return nil, fmt.Errorf("role %q does not exist", roleName)
}
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
accessorEntry, err := b.secretIDAccessorEntry(ctx, req.Storage, secretIDAccessor)
if err != nil {
return nil, err
}
@@ -1211,12 +1211,12 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
defer lock.Unlock()
// Delete the accessor of the SecretID first
if err := b.deleteSecretIDAccessorEntry(req.Storage, secretIDAccessor); err != nil {
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, secretIDAccessor); err != nil {
return nil, err
}
// Delete the storage entry that corresponds to the SecretID
if err := req.Storage.Delete(entryIndex); err != nil {
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
}
@@ -1234,7 +1234,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical.
defer lock.Unlock()
// Re-read the role after grabbing the lock
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1257,7 +1257,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical.
}
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleBoundCIDRListRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1270,7 +1270,7 @@ func (b *backend) pathRoleBoundCIDRListRead(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1293,7 +1293,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical.
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1304,7 +1304,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical.
// Deleting a field implies setting the value to it's default value.
role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1317,7 +1317,7 @@ func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1327,7 +1327,7 @@ func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.R
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
role.BindSecretID = bindSecretIDRaw.(bool)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing bind_secret_id"), nil
}
@@ -1343,7 +1343,7 @@ func (b *backend) pathRoleBindSecretIDRead(ctx context.Context, req *logical.Req
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1366,7 +1366,7 @@ func (b *backend) pathRoleBindSecretIDDelete(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1377,7 +1377,7 @@ func (b *backend) pathRoleBindSecretIDDelete(ctx context.Context, req *logical.R
// Deleting a field implies setting the value to it's default value.
role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1390,7 +1390,7 @@ func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1405,7 +1405,7 @@ func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Reque
role.Policies = policyutil.ParsePolicies(policiesRaw)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRolePoliciesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1418,7 +1418,7 @@ func (b *backend) pathRolePoliciesRead(ctx context.Context, req *logical.Request
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1441,7 +1441,7 @@ func (b *backend) pathRolePoliciesDelete(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1451,7 +1451,7 @@ func (b *backend) pathRolePoliciesDelete(ctx context.Context, req *logical.Reque
role.Policies = []string{}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1464,7 +1464,7 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logica
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1477,7 +1477,7 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logica
if role.SecretIDNumUses < 0 {
return logical.ErrorResponse("secret_id_num_uses cannot be negative"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing secret_id_num_uses"), nil
}
@@ -1493,7 +1493,7 @@ func (b *backend) pathRoleRoleIDUpdate(ctx context.Context, req *logical.Request
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1507,7 +1507,7 @@ func (b *backend) pathRoleRoleIDUpdate(ctx context.Context, req *logical.Request
return logical.ErrorResponse("missing role_id"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, previousRoleID)
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, previousRoleID)
}
func (b *backend) pathRoleRoleIDRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1520,7 +1520,7 @@ func (b *backend) pathRoleRoleIDRead(ctx context.Context, req *logical.Request,
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1543,7 +1543,7 @@ func (b *backend) pathRoleSecretIDNumUsesRead(ctx context.Context, req *logical.
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1566,7 +1566,7 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(ctx context.Context, req *logica
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1576,7 +1576,7 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(ctx context.Context, req *logica
role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1589,7 +1589,7 @@ func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1599,7 +1599,7 @@ func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Re
if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok {
role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing secret_id_ttl"), nil
}
@@ -1615,7 +1615,7 @@ func (b *backend) pathRoleSecretIDTTLRead(ctx context.Context, req *logical.Requ
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1639,7 +1639,7 @@ func (b *backend) pathRoleSecretIDTTLDelete(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1649,7 +1649,7 @@ func (b *backend) pathRoleSecretIDTTLDelete(ctx context.Context, req *logical.Re
role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1662,7 +1662,7 @@ func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1675,7 +1675,7 @@ func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request
if role.Period > b.System().MaxLeaseTTL() {
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing period"), nil
}
@@ -1691,7 +1691,7 @@ func (b *backend) pathRolePeriodRead(ctx context.Context, req *logical.Request,
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1715,7 +1715,7 @@ func (b *backend) pathRolePeriodDelete(ctx context.Context, req *logical.Request
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1725,7 +1725,7 @@ func (b *backend) pathRolePeriodDelete(ctx context.Context, req *logical.Request
role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1738,7 +1738,7 @@ func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1748,7 +1748,7 @@ func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.R
if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok {
role.TokenNumUses = tokenNumUsesRaw.(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing token_num_uses"), nil
}
@@ -1764,7 +1764,7 @@ func (b *backend) pathRoleTokenNumUsesRead(ctx context.Context, req *logical.Req
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1787,7 +1787,7 @@ func (b *backend) pathRoleTokenNumUsesDelete(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1797,7 +1797,7 @@ func (b *backend) pathRoleTokenNumUsesDelete(ctx context.Context, req *logical.R
role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1810,7 +1810,7 @@ func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1823,7 +1823,7 @@ func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Reque
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
return logical.ErrorResponse("token_ttl should not be greater than token_max_ttl"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing token_ttl"), nil
}
@@ -1839,7 +1839,7 @@ func (b *backend) pathRoleTokenTTLRead(ctx context.Context, req *logical.Request
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1863,7 +1863,7 @@ func (b *backend) pathRoleTokenTTLDelete(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1873,7 +1873,7 @@ func (b *backend) pathRoleTokenTTLDelete(ctx context.Context, req *logical.Reque
role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1886,7 +1886,7 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1899,7 +1899,7 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Re
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
return logical.ErrorResponse("token_max_ttl should be greater than or equal to token_ttl"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing token_max_ttl"), nil
}
@@ -1915,7 +1915,7 @@ func (b *backend) pathRoleTokenMaxTTLRead(ctx context.Context, req *logical.Requ
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@@ -1939,7 +1939,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -1949,7 +1949,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(ctx context.Context, req *logical.Re
role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleSecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -1978,7 +1978,7 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -2026,7 +2026,7 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req
roleName = strings.ToLower(roleName)
}
if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
if secretIDStorage, err = b.registerSecretIDEntry(ctx, req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
return nil, fmt.Errorf("failed to store secret_id: %v", err)
}
@@ -2047,7 +2047,7 @@ func (b *backend) roleLock(roleName string) *locksutil.LockEntry {
}
// setRoleIDEntry creates a storage entry that maps RoleID to Role
func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *roleIDStorageEntry) error {
func (b *backend) setRoleIDEntry(ctx context.Context, s logical.Storage, roleID string, roleIDEntry *roleIDStorageEntry) error {
lock := b.roleIDLock(roleID)
lock.Lock()
defer lock.Unlock()
@@ -2062,14 +2062,14 @@ func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *
if err != nil {
return err
}
if err = s.Put(entry); err != nil {
if err = s.Put(ctx, entry); err != nil {
return err
}
return nil
}
// roleIDEntry is used to read the storage entry that maps RoleID to Role
func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageEntry, error) {
func (b *backend) roleIDEntry(ctx context.Context, s logical.Storage, roleID string) (*roleIDStorageEntry, error) {
if roleID == "" {
return nil, fmt.Errorf("missing roleID")
}
@@ -2086,7 +2086,7 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
}
entryIndex := "role_id/" + salt.SaltID(roleID)
if entry, err := s.Get(entryIndex); err != nil {
if entry, err := s.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@@ -2099,7 +2099,7 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
// roleIDEntryDelete is used to remove the secondary index that maps the
// RoleID to the Role itself.
func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
func (b *backend) roleIDEntryDelete(ctx context.Context, s logical.Storage, roleID string) error {
if roleID == "" {
return fmt.Errorf("missing roleID")
}
@@ -2114,7 +2114,7 @@ func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
}
entryIndex := "role_id/" + salt.SaltID(roleID)
return s.Delete(entryIndex)
return s.Delete(ctx, entryIndex)
}
var roleHelp = map[string][2]string{

View File

@@ -26,7 +26,7 @@ func TestApprole_RoleNameLowerCasing(t *testing.T) {
Policies: []string{"default"},
BindSecretID: true,
}
err = b.setRoleEntry(storage, "testRoleName", role, "")
err = b.setRoleEntry(context.Background(), storage, "testRoleName", role, "")
if err != nil {
t.Fatal(err)
}
@@ -208,7 +208,7 @@ func TestAppRole_RoleReadSetIndex(t *testing.T) {
roleID := resp.Data["role_id"].(string)
// Delete the role ID index
err = b.roleIDEntryDelete(storage, roleID)
err = b.roleIDEntryDelete(context.Background(), storage, roleID)
if err != nil {
t.Fatal(err)
}
@@ -225,7 +225,7 @@ func TestAppRole_RoleReadSetIndex(t *testing.T) {
t.Fatalf("bad: expected a warning in the response")
}
roleIDIndex, err := b.roleIDEntry(storage, roleID)
roleIDIndex, err := b.roleIDEntry(context.Background(), storage, roleID)
if err != nil {
t.Fatal(err)
}

View File

@@ -25,7 +25,7 @@ func pathTidySecretID(b *backend) *framework.Path {
}
// tidySecretID is used to delete entries in the whitelist that are expired.
func (b *backend) tidySecretID(s logical.Storage) error {
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error {
grabbed := atomic.CompareAndSwapUint32(&b.tidySecretIDCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(&b.tidySecretIDCASGuard, 0)
@@ -33,7 +33,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
return fmt.Errorf("SecretID tidy operation already running")
}
roleNameHMACs, err := s.List("secret_id/")
roleNameHMACs, err := s.List(ctx, "secret_id/")
if err != nil {
return err
}
@@ -41,7 +41,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
var result error
for _, roleNameHMAC := range roleNameHMACs {
// roleNameHMAC will already have a '/' suffix. Don't append another one.
secretIDHMACs, err := s.List(fmt.Sprintf("secret_id/%s", roleNameHMAC))
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("secret_id/%s", roleNameHMAC))
if err != nil {
return err
}
@@ -52,7 +52,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
lock.Lock()
// roleNameHMAC will already have a '/' suffix. Don't append another one.
entryIndex := fmt.Sprintf("secret_id/%s%s", roleNameHMAC, secretIDHMAC)
secretIDEntry, err := s.Get(entryIndex)
secretIDEntry, err := s.Get(ctx, entryIndex)
if err != nil {
lock.Unlock()
return fmt.Errorf("error fetching SecretID %s: %s", secretIDHMAC, err)
@@ -77,7 +77,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
// ExpirationTime not being set indicates non-expiring SecretIDs
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
if err := s.Delete(entryIndex); err != nil {
if err := s.Delete(ctx, entryIndex); err != nil {
lock.Unlock()
return fmt.Errorf("error deleting SecretID %s from storage: %s", secretIDHMAC, err)
}
@@ -90,7 +90,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
// pathTidySecretIDUpdate is used to delete the expired SecretID entries
func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidySecretID(req.Storage)
return nil, b.tidySecretID(ctx, req.Storage)
}
const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries."

View File

@@ -1,6 +1,7 @@
package approle
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
@@ -68,9 +69,9 @@ type secretIDAccessorStorageEntry struct {
}
// Checks if the Role represented by the RoleID still exists
func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorageEntry, string, error) {
func (b *backend) validateRoleID(ctx context.Context, s logical.Storage, roleID string) (*roleStorageEntry, string, error) {
// Look for the storage entry that maps the roleID to role
roleIDIndex, err := b.roleIDEntry(s, roleID)
roleIDIndex, err := b.roleIDEntry(ctx, s, roleID)
if err != nil {
return nil, "", err
}
@@ -82,7 +83,7 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(s, roleIDIndex.Name)
role, err := b.roleEntry(ctx, s, roleIDIndex.Name)
if err != nil {
return nil, "", err
}
@@ -94,7 +95,7 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
}
// Validates the supplied RoleID and SecretID
func (b *backend) validateCredentials(req *logical.Request, data *framework.FieldData) (*roleStorageEntry, string, map[string]string, string, error) {
func (b *backend) validateCredentials(ctx context.Context, req *logical.Request, data *framework.FieldData) (*roleStorageEntry, string, map[string]string, string, error) {
metadata := make(map[string]string)
// RoleID must be supplied during every login
roleID := strings.TrimSpace(data.Get("role_id").(string))
@@ -103,7 +104,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
}
// Validate the RoleID and get the Role entry
role, roleName, err := b.validateRoleID(req.Storage, roleID)
role, roleName, err := b.validateRoleID(ctx, req.Storage, roleID)
if err != nil {
return nil, "", metadata, "", err
}
@@ -132,7 +133,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
// Check if the SecretID supplied is valid. If use limit was specified
// on the SecretID, it will be decremented in this call.
var valid bool
valid, metadata, err = b.validateBindSecretID(req, roleName, secretID, role.HMACKey, role.BoundCIDRList)
valid, metadata, err = b.validateBindSecretID(ctx, req, roleName, secretID, role.HMACKey, role.BoundCIDRList)
if err != nil {
return nil, "", metadata, "", err
}
@@ -160,7 +161,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
}
// validateBindSecretID is used to determine if the given SecretID is a valid one.
func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
func (b *backend) validateBindSecretID(ctx context.Context, req *logical.Request, roleName, secretID,
hmacKey, roleBoundCIDRList string) (bool, map[string]string, error) {
secretIDHMAC, err := createHMAC(hmacKey, secretID)
if err != nil {
@@ -180,7 +181,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
lock := b.secretIDLock(secretIDHMAC)
lock.RLock()
result, err := b.nonLockedSecretIDStorageEntry(req.Storage, roleNameHMAC, secretIDHMAC)
result, err := b.nonLockedSecretIDStorageEntry(ctx, req.Storage, roleNameHMAC, secretIDHMAC)
if err != nil {
lock.RUnlock()
return false, nil, err
@@ -225,7 +226,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
defer lock.Unlock()
// Lock switching may change the data. Refresh the contents.
result, err = b.nonLockedSecretIDStorageEntry(req.Storage, roleNameHMAC, secretIDHMAC)
result, err = b.nonLockedSecretIDStorageEntry(ctx, req.Storage, roleNameHMAC, secretIDHMAC)
if err != nil {
return false, nil, err
}
@@ -238,10 +239,10 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
// requests to use the same SecretID will fail.
if result.SecretIDNumUses == 1 {
// Delete the secret IDs accessor first
if err := b.deleteSecretIDAccessorEntry(req.Storage, result.SecretIDAccessor); err != nil {
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, result.SecretIDAccessor); err != nil {
return false, nil, err
}
if err := req.Storage.Delete(entryIndex); err != nil {
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
return false, nil, fmt.Errorf("failed to delete secret ID: %v", err)
}
} else {
@@ -250,7 +251,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
result.LastUpdatedTime = time.Now()
if entry, err := logical.StorageEntryJSON(entryIndex, &result); err != nil {
return false, nil, fmt.Errorf("failed to decrement the use count for secret ID %q", secretID)
} else if err = req.Storage.Put(entry); err != nil {
} else if err = req.Storage.Put(ctx, entry); err != nil {
return false, nil, fmt.Errorf("failed to decrement the use count for secret ID %q", secretID)
}
}
@@ -320,7 +321,7 @@ func (b *backend) secretIDAccessorLock(secretIDAccessor string) *locksutil.LockE
// storage. The entry will be indexed based on the given HMACs of both role
// name and the secret ID. This method will not acquire secret ID lock to fetch
// the storage entry. Locks need to be acquired before calling this method.
func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC, secretIDHMAC string) (*secretIDStorageEntry, error) {
func (b *backend) nonLockedSecretIDStorageEntry(ctx context.Context, s logical.Storage, roleNameHMAC, secretIDHMAC string) (*secretIDStorageEntry, error) {
if secretIDHMAC == "" {
return nil, fmt.Errorf("missing secret ID HMAC")
}
@@ -332,7 +333,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
// Prepare the storage index at which the secret ID will be stored
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
entry, err := s.Get(entryIndex)
entry, err := s.Get(ctx, entryIndex)
if err != nil {
return nil, err
}
@@ -360,7 +361,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
}
if persistNeeded {
if err := b.nonLockedSetSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC, &result); err != nil {
if err := b.nonLockedSetSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC, &result); err != nil {
return nil, fmt.Errorf("failed to upgrade role storage entry %s", err)
}
}
@@ -373,7 +374,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
// role name and the secret ID. This method will not acquire secret ID lock to
// create/update the storage entry. Locks need to be acquired before calling
// this method.
func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHMAC, secretIDHMAC string, secretEntry *secretIDStorageEntry) error {
func (b *backend) nonLockedSetSecretIDStorageEntry(ctx context.Context, s logical.Storage, roleNameHMAC, secretIDHMAC string, secretEntry *secretIDStorageEntry) error {
if secretIDHMAC == "" {
return fmt.Errorf("missing secret ID HMAC")
}
@@ -390,7 +391,7 @@ func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHM
if entry, err := logical.StorageEntryJSON(entryIndex, secretEntry); err != nil {
return err
} else if err = s.Put(entry); err != nil {
} else if err = s.Put(ctx, entry); err != nil {
return err
}
@@ -398,7 +399,7 @@ func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHM
}
// registerSecretIDEntry creates a new storage entry for the given SecretID.
func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, hmacKey string, secretEntry *secretIDStorageEntry) (*secretIDStorageEntry, error) {
func (b *backend) registerSecretIDEntry(ctx context.Context, s logical.Storage, roleName, secretID, hmacKey string, secretEntry *secretIDStorageEntry) (*secretIDStorageEntry, error) {
secretIDHMAC, err := createHMAC(hmacKey, secretID)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of secret ID: %v", err)
@@ -411,7 +412,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
lock := b.secretIDLock(secretIDHMAC)
lock.RLock()
entry, err := b.nonLockedSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC)
entry, err := b.nonLockedSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC)
if err != nil {
lock.RUnlock()
return nil, err
@@ -428,7 +429,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
defer lock.Unlock()
// But before saving a new entry, check if the secretID entry was created during the lock switch.
entry, err = b.nonLockedSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC)
entry, err = b.nonLockedSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC)
if err != nil {
return nil, err
}
@@ -457,11 +458,11 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
}
// Before storing the SecretID, store its accessor.
if err := b.createSecretIDAccessorEntry(s, secretEntry, secretIDHMAC); err != nil {
if err := b.createSecretIDAccessorEntry(ctx, s, secretEntry, secretIDHMAC); err != nil {
return nil, err
}
if err := b.nonLockedSetSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC, secretEntry); err != nil {
if err := b.nonLockedSetSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC, secretEntry); err != nil {
return nil, err
}
@@ -470,7 +471,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
// secretIDAccessorEntry is used to read the storage entry that maps an
// accessor to a secret_id.
func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor string) (*secretIDAccessorStorageEntry, error) {
func (b *backend) secretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) (*secretIDAccessorStorageEntry, error) {
if secretIDAccessor == "" {
return nil, fmt.Errorf("missing secretIDAccessor")
}
@@ -488,7 +489,7 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
accessorLock.RLock()
defer accessorLock.RUnlock()
if entry, err := s.Get(entryIndex); err != nil {
if entry, err := s.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@@ -502,7 +503,7 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
// createSecretIDAccessorEntry creates an identifier for the SecretID. A storage index,
// mapping the accessor to the SecretID is also created. This method should
// be called when the lock for the corresponding SecretID is held.
func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretIDStorageEntry, secretIDHMAC string) error {
func (b *backend) createSecretIDAccessorEntry(ctx context.Context, s logical.Storage, entry *secretIDStorageEntry, secretIDHMAC string) error {
// Create a random accessor
accessorUUID, err := uuid.GenerateUUID()
if err != nil {
@@ -525,7 +526,7 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
SecretIDHMAC: secretIDHMAC,
}); err != nil {
return err
} else if err = s.Put(entry); err != nil {
} else if err = s.Put(ctx, entry); err != nil {
return fmt.Errorf("failed to persist accessor index entry: %v", err)
}
@@ -533,7 +534,7 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
}
// deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID.
func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error {
func (b *backend) deleteSecretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) error {
salt, err := b.Salt()
if err != nil {
return err
@@ -545,7 +546,7 @@ func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccesso
defer accessorLock.Unlock()
// Delete the accessor of the SecretID first
if err := s.Delete(accessorEntryIndex); err != nil {
if err := s.Delete(ctx, accessorEntryIndex); err != nil {
return fmt.Errorf("failed to delete accessor storage entry: %v", err)
}
@@ -554,7 +555,7 @@ func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccesso
// flushRoleSecrets deletes all the SecretIDs that belong to the given
// RoleID.
func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string) error {
func (b *backend) flushRoleSecrets(ctx context.Context, s logical.Storage, roleName, hmacKey string) error {
roleNameHMAC, err := createHMAC(hmacKey, roleName)
if err != nil {
return fmt.Errorf("failed to create HMAC of role_name: %v", err)
@@ -564,7 +565,7 @@ func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string)
b.secretIDListingLock.RLock()
defer b.secretIDListingLock.RUnlock()
secretIDHMACs, err := s.List(fmt.Sprintf("secret_id/%s/", roleNameHMAC))
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("secret_id/%s/", roleNameHMAC))
if err != nil {
return err
}
@@ -573,7 +574,7 @@ func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string)
lock := b.secretIDLock(secretIDHMAC)
lock.Lock()
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
if err := s.Delete(entryIndex); err != nil {
if err := s.Delete(ctx, entryIndex); err != nil {
lock.Unlock()
return fmt.Errorf("error deleting SecretID %q from storage: %v", secretIDHMAC, err)
}

View File

@@ -1,6 +1,7 @@
package awsauth
import (
"context"
"fmt"
"sync"
"time"
@@ -13,12 +14,12 @@ import (
"github.com/patrickmn/go-cache"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -73,7 +74,7 @@ type backend struct {
// accounts using their IAM instance profile to get their credentials.
defaultAWSAccountID string
resolveArnToUniqueIDFunc func(logical.Storage, string) (string, error)
resolveArnToUniqueIDFunc func(context.Context, logical.Storage, string) (string, error)
}
func Backend(conf *logical.BackendConfig) (*backend, error) {
@@ -138,13 +139,13 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
// not once in a minute, but once in an hour, controlled by 'tidyCooldownPeriod'.
// Tidying of blacklist and whitelist are by default enabled. This can be
// changed using `config/tidy/roletags` and `config/tidy/identities` endpoints.
func (b *backend) periodicFunc(req *logical.Request) error {
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
// Run the tidy operations for the first time. Then run it when current
// time matches the nextTidyTime.
if b.nextTidyTime.IsZero() || !time.Now().Before(b.nextTidyTime) {
// safety_buffer defaults to 180 days for roletag blacklist
safety_buffer := 15552000
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(req.Storage)
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return err
}
@@ -160,12 +161,12 @@ func (b *backend) periodicFunc(req *logical.Request) error {
}
// tidy role tags if explicitly not disabled
if !skipBlacklistTidy {
b.tidyBlacklistRoleTag(req.Storage, safety_buffer)
b.tidyBlacklistRoleTag(ctx, req.Storage, safety_buffer)
}
// reset the safety_buffer to 72h
safety_buffer = 259200
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(req.Storage)
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return err
}
@@ -181,7 +182,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
}
// tidy identities if explicitly not disabled
if !skipWhitelistTidy {
b.tidyWhitelistIdentity(req.Storage, safety_buffer)
b.tidyWhitelistIdentity(ctx, req.Storage, safety_buffer)
}
// Update the time at which to run the tidy functions again.
@@ -190,7 +191,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
return nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/client":
b.configMutex.Lock()
@@ -203,7 +204,7 @@ func (b *backend) invalidate(key string) {
// Putting this here so we can inject a fake resolver into the backend for unit testing
// purposes
func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (string, error) {
func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storage, arn string) (string, error) {
entity, err := parseIamArn(arn)
if err != nil {
return "", err
@@ -223,7 +224,7 @@ func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (strin
if region == nil {
return "", fmt.Errorf("Unable to resolve partition %q to a region", entity.Partition)
}
iamClient, err := b.clientIAM(s, region.ID(), entity.AccountNumber)
iamClient, err := b.clientIAM(ctx, s, region.ID(), entity.AccountNumber)
if err != nil {
return "", err
}

View File

@@ -30,7 +30,8 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -55,7 +56,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// read the created role entry
roleEntry, err := b.lockedAWSRole(storage, "abcd-123")
roleEntry, err := b.lockedAWSRole(context.Background(), storage, "abcd-123")
if err != nil {
t.Fatal(err)
}
@@ -83,7 +84,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// parse the created role tag
rTag2, err := b.parseAndVerifyRoleTagValue(storage, val)
rTag2, err := b.parseAndVerifyRoleTagValue(context.Background(), storage, val)
if err != nil {
t.Fatal(err)
}
@@ -122,7 +123,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// get the entry of the newly created role entry
roleEntry2, err := b.lockedAWSRole(storage, "ami-6789")
roleEntry2, err := b.lockedAWSRole(context.Background(), storage, "ami-6789")
if err != nil {
t.Fatal(err)
}
@@ -254,7 +255,8 @@ func TestBackend_ConfigTidyIdentities(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -308,7 +310,8 @@ func TestBackend_ConfigTidyRoleTags(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -362,7 +365,8 @@ func TestBackend_TidyIdentities(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -387,7 +391,8 @@ func TestBackend_TidyRoleTags(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -412,7 +417,8 @@ func TestBackend_ConfigClient(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -549,7 +555,8 @@ func TestBackend_pathConfigCertificate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -704,7 +711,8 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -763,7 +771,7 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
tagValue := resp.Data["tag_value"].(string)
// parse the value and check if the verifiable values match
rTag, err := b.parseAndVerifyRoleTagValue(storage, tagValue)
rTag, err := b.parseAndVerifyRoleTagValue(context.Background(), storage, tagValue)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -785,7 +793,8 @@ func TestBackend_PathRoleTag(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -850,7 +859,8 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -939,7 +949,7 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
}
// try to read the deleted entry
tagEntry, err := b.lockedBlacklistRoleTagEntry(storage, tag)
tagEntry, err := b.lockedBlacklistRoleTagEntry(context.Background(), storage, tag)
if err != nil {
t.Fatal(err)
}
@@ -998,7 +1008,8 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -1190,7 +1201,8 @@ func TestBackend_pathStsConfig(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -1338,7 +1350,8 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -1442,11 +1455,11 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
}
fakeArn := "arn:aws:iam::123456789012:role/somePath/FakeRole"
fakeArnResolver := func(s logical.Storage, arn string) (string, error) {
fakeArnResolver := func(ctx context.Context, s logical.Storage, arn string) (string, error) {
if arn == fakeArn {
return fmt.Sprintf("FakeUniqueIdFor%s", fakeArn), nil
}
return b.resolveArnToRealUniqueId(s, arn)
return b.resolveArnToRealUniqueId(context.Background(), s, arn)
}
b.resolveArnToUniqueIDFunc = fakeArnResolver

View File

@@ -1,6 +1,7 @@
package awsauth
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
@@ -21,13 +22,13 @@ import (
// * Static credentials from 'config/client'
// * Environment variables
// * Instance metadata role
func (b *backend) getRawClientConfig(s logical.Storage, region, clientType string) (*aws.Config, error) {
func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{
Region: region,
}
// Read the configured secret key and access key
config, err := b.nonLockedClientConfigEntry(s)
config, err := b.nonLockedClientConfigEntry(ctx, s)
if err != nil {
return nil, err
}
@@ -71,9 +72,9 @@ func (b *backend) getRawClientConfig(s logical.Storage, region, clientType strin
// It uses getRawClientConfig to obtain config for the runtime environemnt, and if
// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
// credentials. The credentials will expire after 15 minutes but will auto-refresh.
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
config, err := b.getRawClientConfig(s, region, clientType)
config, err := b.getRawClientConfig(ctx, s, region, clientType)
if err != nil {
return nil, err
}
@@ -81,7 +82,7 @@ func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID,
return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
}
stsConfig, err := b.getRawClientConfig(s, region, "sts")
stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
if stsConfig == nil {
return nil, fmt.Errorf("could not configure STS client")
}
@@ -160,9 +161,9 @@ func (b *backend) setCachedUserId(userId, arn string) {
}
}
func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string, error) {
func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
// Check if an STS configuration exists for the AWS account
sts, err := b.lockedAwsStsEntry(s, accountID)
sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
if err != nil {
return "", fmt.Errorf("error fetching STS config for account ID %q: %q\n", accountID, err)
}
@@ -174,8 +175,8 @@ func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string
}
// clientEC2 creates a client to interact with AWS EC2 API
func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.EC2, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
}
@@ -198,7 +199,7 @@ func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.E
// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "ec2")
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
if err != nil {
return nil, err
@@ -223,8 +224,8 @@ func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.E
}
// clientIAM creates a client to interact with AWS IAM API
func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.IAM, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
}
@@ -247,7 +248,7 @@ func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.I
// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "iam")
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
if err != nil {
return nil, err

View File

@@ -9,7 +9,6 @@ import (
"math/big"
"strings"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@@ -131,7 +130,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(ctx context.Context, req *
return false, fmt.Errorf("missing cert_name")
}
entry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
entry, err := b.lockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
if err != nil {
return false, err
}
@@ -143,7 +142,7 @@ func (b *backend) pathCertificatesList(ctx context.Context, req *logical.Request
b.configMutex.RLock()
defer b.configMutex.RUnlock()
certs, err := req.Storage.List("config/certificate/")
certs, err := req.Storage.List(ctx, "config/certificate/")
if err != nil {
return nil, err
}
@@ -174,7 +173,7 @@ func decodePEMAndParseCertificate(certificate string) (*x509.Certificate, error)
// the PKCS7 signatures of the instance identity documents. This method will
// append the certificates registered using `config/certificate/<cert_name>`
// endpoint, along with the default certificate in the backend.
func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509.Certificate, error) {
func (b *backend) awsPublicCertificates(ctx context.Context, s logical.Storage, isPkcs bool) ([]*x509.Certificate, error) {
// Lock at beginning and use internal method so that we are consistent as
// we iterate through
b.configMutex.RLock()
@@ -195,14 +194,14 @@ func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509
certs = append(certs, decodedCert)
// Get the list of all the registered certificates
registeredCerts, err := s.List("config/certificate/")
registeredCerts, err := s.List(ctx, "config/certificate/")
if err != nil {
return nil, err
}
// Iterate through each certificate, parse and append it to a slice
for _, cert := range registeredCerts {
certEntry, err := b.nonLockedAWSPublicCertificateEntry(s, cert)
certEntry, err := b.nonLockedAWSPublicCertificateEntry(ctx, s, cert)
if err != nil {
return nil, err
}
@@ -226,7 +225,7 @@ func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509
// lockedSetAWSPublicCertificateEntry is used to store the AWS public key in
// the storage. This method acquires lock before creating or updating a storage
// entry.
func (b *backend) lockedSetAWSPublicCertificateEntry(s logical.Storage, certName string, certEntry *awsPublicCert) error {
func (b *backend) lockedSetAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string, certEntry *awsPublicCert) error {
if certName == "" {
return fmt.Errorf("missing certificate name")
}
@@ -238,13 +237,13 @@ func (b *backend) lockedSetAWSPublicCertificateEntry(s logical.Storage, certName
b.configMutex.Lock()
defer b.configMutex.Unlock()
return b.nonLockedSetAWSPublicCertificateEntry(s, certName, certEntry)
return b.nonLockedSetAWSPublicCertificateEntry(ctx, s, certName, certEntry)
}
// nonLockedSetAWSPublicCertificateEntry is used to store the AWS public key in
// the storage. This method does not acquire lock before reading the storage.
// If locking is desired, use lockedSetAWSPublicCertificateEntry instead.
func (b *backend) nonLockedSetAWSPublicCertificateEntry(s logical.Storage, certName string, certEntry *awsPublicCert) error {
func (b *backend) nonLockedSetAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string, certEntry *awsPublicCert) error {
if certName == "" {
return fmt.Errorf("missing certificate name")
}
@@ -261,24 +260,24 @@ func (b *backend) nonLockedSetAWSPublicCertificateEntry(s logical.Storage, certN
return fmt.Errorf("failed to create storage entry for AWS public key certificate")
}
return s.Put(entry)
return s.Put(ctx, entry)
}
// lockedAWSPublicCertificateEntry is used to get the configured AWS Public Key
// that is used to verify the PKCS#7 signature of the instance identity
// document.
func (b *backend) lockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
func (b *backend) lockedAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string) (*awsPublicCert, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedAWSPublicCertificateEntry(s, certName)
return b.nonLockedAWSPublicCertificateEntry(ctx, s, certName)
}
// nonLockedAWSPublicCertificateEntry reads the certificate information from
// the storage. This method does not acquire lock before reading the storage.
// If locking is desired, use lockedAWSPublicCertificateEntry instead.
func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
entry, err := s.Get("config/certificate/" + certName)
func (b *backend) nonLockedAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string) (*awsPublicCert, error) {
entry, err := s.Get(ctx, "config/certificate/"+certName)
if err != nil {
return nil, err
}
@@ -298,7 +297,7 @@ func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName
}
if persistNeeded {
if err := b.nonLockedSetAWSPublicCertificateEntry(s, certName, &certEntry); err != nil {
if err := b.nonLockedSetAWSPublicCertificateEntry(ctx, s, certName, &certEntry); err != nil {
return nil, err
}
}
@@ -318,7 +317,7 @@ func (b *backend) pathConfigCertificateDelete(ctx context.Context, req *logical.
return logical.ErrorResponse("missing cert_name"), nil
}
return nil, req.Storage.Delete("config/certificate/" + certName)
return nil, req.Storage.Delete(ctx, "config/certificate/"+certName)
}
// pathConfigCertificateRead is used to view the configured AWS Public Key that
@@ -329,7 +328,7 @@ func (b *backend) pathConfigCertificateRead(ctx context.Context, req *logical.Re
return logical.ErrorResponse("missing cert_name"), nil
}
certificateEntry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
certificateEntry, err := b.lockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
if err != nil {
return nil, err
}
@@ -338,7 +337,10 @@ func (b *backend) pathConfigCertificateRead(ctx context.Context, req *logical.Re
}
return &logical.Response{
Data: structs.New(certificateEntry).Map(),
Data: map[string]interface{}{
"aws_public_cert": certificateEntry.AWSPublicCert,
"type": certificateEntry.Type,
},
}, nil
}
@@ -354,7 +356,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(ctx context.Context, req *lo
defer b.configMutex.Unlock()
// Check if there is already a certificate entry registered
certEntry, err := b.nonLockedAWSPublicCertificateEntry(req.Storage, certName)
certEntry, err := b.nonLockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
if err != nil {
return nil, err
}
@@ -406,7 +408,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(ctx context.Context, req *lo
}
// If none of the checks fail, save the provided certificate
if err := b.nonLockedSetAWSPublicCertificateEntry(req.Storage, certName, certEntry); err != nil {
if err := b.nonLockedSetAWSPublicCertificateEntry(ctx, req.Storage, certName, certEntry); err != nil {
return nil, err
}

View File

@@ -66,7 +66,7 @@ func pathConfigClient(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathConfigClientExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedClientConfigEntry(req.Storage)
entry, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return false, err
}
@@ -74,16 +74,16 @@ func (b *backend) pathConfigClientExistenceCheck(ctx context.Context, req *logic
}
// Fetch the client configuration required to access the AWS API, after acquiring an exclusive lock.
func (b *backend) lockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
func (b *backend) lockedClientConfigEntry(ctx context.Context, s logical.Storage) (*clientConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedClientConfigEntry(s)
return b.nonLockedClientConfigEntry(ctx, s)
}
// Fetch the client configuration required to access the AWS API.
func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
entry, err := s.Get("config/client")
func (b *backend) nonLockedClientConfigEntry(ctx context.Context, s logical.Storage) (*clientConfig, error) {
entry, err := s.Get(ctx, "config/client")
if err != nil {
return nil, err
}
@@ -99,7 +99,7 @@ func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig,
}
func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.lockedClientConfigEntry(req.Storage)
clientConfig, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -117,7 +117,7 @@ func (b *backend) pathConfigClientDelete(ctx context.Context, req *logical.Reque
b.configMutex.Lock()
defer b.configMutex.Unlock()
if err := req.Storage.Delete("config/client"); err != nil {
if err := req.Storage.Delete(ctx, "config/client"); err != nil {
return nil, err
}
@@ -139,7 +139,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.nonLockedClientConfigEntry(req.Storage)
configEntry, err := b.nonLockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -231,7 +231,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
}
if changedCreds || changedOtherConfig || req.Operation == logical.CreateOperation {
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
}

View File

@@ -16,7 +16,8 @@ func TestBackend_pathConfigClient(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

View File

@@ -66,7 +66,7 @@ func (b *backend) pathConfigStsExistenceCheck(ctx context.Context, req *logical.
return false, fmt.Errorf("missing account_id")
}
entry, err := b.lockedAwsStsEntry(req.Storage, accountID)
entry, err := b.lockedAwsStsEntry(ctx, req.Storage, accountID)
if err != nil {
return false, err
}
@@ -78,7 +78,7 @@ func (b *backend) pathConfigStsExistenceCheck(ctx context.Context, req *logical.
func (b *backend) pathStsList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
sts, err := req.Storage.List("config/sts/")
sts, err := req.Storage.List(ctx, "config/sts/")
if err != nil {
return nil, err
}
@@ -88,7 +88,7 @@ func (b *backend) pathStsList(ctx context.Context, req *logical.Request, data *f
// nonLockedSetAwsStsEntry creates or updates an STS role association with the given accountID
// This method does not acquire the write lock before creating or updating. If locking is
// desired, use lockedSetAwsStsEntry instead
func (b *backend) nonLockedSetAwsStsEntry(s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
func (b *backend) nonLockedSetAwsStsEntry(ctx context.Context, s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
if accountID == "" {
return fmt.Errorf("missing AWS account ID")
}
@@ -106,12 +106,12 @@ func (b *backend) nonLockedSetAwsStsEntry(s logical.Storage, accountID string, s
return fmt.Errorf("failed to create storage entry for AWS STS configuration")
}
return s.Put(entry)
return s.Put(ctx, entry)
}
// lockedSetAwsStsEntry creates or updates an STS role association with the given accountID
// This method acquires the write lock before creating or updating the STS entry.
func (b *backend) lockedSetAwsStsEntry(s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
func (b *backend) lockedSetAwsStsEntry(ctx context.Context, s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
if accountID == "" {
return fmt.Errorf("missing AWS account ID")
}
@@ -123,14 +123,14 @@ func (b *backend) lockedSetAwsStsEntry(s logical.Storage, accountID string, stsE
b.configMutex.Lock()
defer b.configMutex.Unlock()
return b.nonLockedSetAwsStsEntry(s, accountID, stsEntry)
return b.nonLockedSetAwsStsEntry(ctx, s, accountID, stsEntry)
}
// nonLockedAwsStsEntry returns the STS role associated with the given accountID.
// This method does not acquire the read lock before returning information. If locking is
// desired, use lockedAwsStsEntry instead
func (b *backend) nonLockedAwsStsEntry(s logical.Storage, accountID string) (*awsStsEntry, error) {
entry, err := s.Get("config/sts/" + accountID)
func (b *backend) nonLockedAwsStsEntry(ctx context.Context, s logical.Storage, accountID string) (*awsStsEntry, error) {
entry, err := s.Get(ctx, "config/sts/"+accountID)
if err != nil {
return nil, err
}
@@ -147,11 +147,11 @@ func (b *backend) nonLockedAwsStsEntry(s logical.Storage, accountID string) (*aw
// lockedAwsStsEntry returns the STS role associated with the given accountID.
// This method acquires the read lock before returning the association.
func (b *backend) lockedAwsStsEntry(s logical.Storage, accountID string) (*awsStsEntry, error) {
func (b *backend) lockedAwsStsEntry(ctx context.Context, s logical.Storage, accountID string) (*awsStsEntry, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedAwsStsEntry(s, accountID)
return b.nonLockedAwsStsEntry(ctx, s, accountID)
}
// pathConfigStsRead is used to return information about an STS role/AWS accountID association
@@ -161,7 +161,7 @@ func (b *backend) pathConfigStsRead(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse("missing account id"), nil
}
stsEntry, err := b.lockedAwsStsEntry(req.Storage, accountID)
stsEntry, err := b.lockedAwsStsEntry(ctx, req.Storage, accountID)
if err != nil {
return nil, err
}
@@ -185,7 +185,7 @@ func (b *backend) pathConfigStsCreateUpdate(ctx context.Context, req *logical.Re
defer b.configMutex.Unlock()
// Check if an STS role is already registered
stsEntry, err := b.nonLockedAwsStsEntry(req.Storage, accountID)
stsEntry, err := b.nonLockedAwsStsEntry(ctx, req.Storage, accountID)
if err != nil {
return nil, err
}
@@ -206,7 +206,7 @@ func (b *backend) pathConfigStsCreateUpdate(ctx context.Context, req *logical.Re
}
// save the provided STS role
if err := b.nonLockedSetAwsStsEntry(req.Storage, accountID, stsEntry); err != nil {
if err := b.nonLockedSetAwsStsEntry(ctx, req.Storage, accountID, stsEntry); err != nil {
return nil, err
}
@@ -223,7 +223,7 @@ func (b *backend) pathConfigStsDelete(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("missing account id"), nil
}
return nil, req.Storage.Delete("config/sts/" + accountID)
return nil, req.Storage.Delete(ctx, "config/sts/"+accountID)
}
const pathConfigStsSyn = `

View File

@@ -45,22 +45,22 @@ expiration, before it is removed from the backend storage.`,
}
func (b *backend) pathConfigTidyIdentityWhitelistExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedConfigTidyIdentities(req.Storage)
entry, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) lockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
func (b *backend) lockedConfigTidyIdentities(ctx context.Context, s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedConfigTidyIdentities(s)
return b.nonLockedConfigTidyIdentities(ctx, s)
}
func (b *backend) nonLockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
entry, err := s.Get(identityWhitelistConfigPath)
func (b *backend) nonLockedConfigTidyIdentities(ctx context.Context, s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
entry, err := s.Get(ctx, identityWhitelistConfigPath)
if err != nil {
return nil, err
}
@@ -79,7 +79,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.nonLockedConfigTidyIdentities(req.Storage)
configEntry, err := b.nonLockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -106,7 +106,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -114,7 +114,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
}
func (b *backend) pathConfigTidyIdentityWhitelistRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.lockedConfigTidyIdentities(req.Storage)
clientConfig, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -131,7 +131,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistDelete(ctx context.Context, req
b.configMutex.Lock()
defer b.configMutex.Unlock()
return nil, req.Storage.Delete(identityWhitelistConfigPath)
return nil, req.Storage.Delete(ctx, identityWhitelistConfigPath)
}
type tidyWhitelistIdentityConfig struct {

View File

@@ -47,22 +47,22 @@ Defaults to 4320h (180 days).`,
}
func (b *backend) pathConfigTidyRoletagBlacklistExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedConfigTidyRoleTags(req.Storage)
entry, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) lockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
func (b *backend) lockedConfigTidyRoleTags(ctx context.Context, s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedConfigTidyRoleTags(s)
return b.nonLockedConfigTidyRoleTags(ctx, s)
}
func (b *backend) nonLockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
entry, err := s.Get(roletagBlacklistConfigPath)
func (b *backend) nonLockedConfigTidyRoleTags(ctx context.Context, s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
entry, err := s.Get(ctx, roletagBlacklistConfigPath)
if err != nil {
return nil, err
}
@@ -82,7 +82,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.nonLockedConfigTidyRoleTags(req.Storage)
configEntry, err := b.nonLockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -107,7 +107,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -115,7 +115,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
}
func (b *backend) pathConfigTidyRoletagBlacklistRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.lockedConfigTidyRoleTags(req.Storage)
clientConfig, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -132,7 +132,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistDelete(ctx context.Context, req
b.configMutex.Lock()
defer b.configMutex.Unlock()
return nil, req.Storage.Delete(roletagBlacklistConfigPath)
return nil, req.Storage.Delete(ctx, roletagBlacklistConfigPath)
}
type tidyBlacklistRoleTagConfig struct {

View File

@@ -46,7 +46,7 @@ func pathListIdentityWhitelist(b *backend) *framework.Path {
// pathWhitelistIdentitiesList is used to list all the instance IDs that are present
// in the identity whitelist. This will list both valid and expired entries.
func (b *backend) pathWhitelistIdentitiesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
identities, err := req.Storage.List("whitelist/identity/")
identities, err := req.Storage.List(ctx, "whitelist/identity/")
if err != nil {
return nil, err
}
@@ -54,8 +54,8 @@ func (b *backend) pathWhitelistIdentitiesList(ctx context.Context, req *logical.
}
// Fetch an item from the whitelist given an instance ID.
func whitelistIdentityEntry(s logical.Storage, instanceID string) (*whitelistIdentity, error) {
entry, err := s.Get("whitelist/identity/" + instanceID)
func whitelistIdentityEntry(ctx context.Context, s logical.Storage, instanceID string) (*whitelistIdentity, error) {
entry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
if err != nil {
return nil, err
}
@@ -72,13 +72,13 @@ func whitelistIdentityEntry(s logical.Storage, instanceID string) (*whitelistIde
// Stores an instance ID and the information required to validate further login/renewal attempts from
// the same instance ID.
func setWhitelistIdentityEntry(s logical.Storage, instanceID string, identity *whitelistIdentity) error {
func setWhitelistIdentityEntry(ctx context.Context, s logical.Storage, instanceID string, identity *whitelistIdentity) error {
entry, err := logical.StorageEntryJSON("whitelist/identity/"+instanceID, identity)
if err != nil {
return err
}
if err := s.Put(entry); err != nil {
if err := s.Put(ctx, entry); err != nil {
return err
}
return nil
@@ -91,7 +91,7 @@ func (b *backend) pathIdentityWhitelistDelete(ctx context.Context, req *logical.
return logical.ErrorResponse("missing instance_id"), nil
}
return nil, req.Storage.Delete("whitelist/identity/" + instanceID)
return nil, req.Storage.Delete(ctx, "whitelist/identity/"+instanceID)
}
// pathIdentityWhitelistRead is used to view an entry in the identity whitelist given an instance ID.
@@ -101,7 +101,7 @@ func (b *backend) pathIdentityWhitelistRead(ctx context.Context, req *logical.Re
return logical.ErrorResponse("missing instance_id"), nil
}
entry, err := whitelistIdentityEntry(req.Storage, instanceID)
entry, err := whitelistIdentityEntry(ctx, req.Storage, instanceID)
if err != nil {
return nil, err
}

View File

@@ -154,9 +154,9 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str
// validateInstance queries the status of the EC2 instance using AWS EC2 API
// and checks if the instance is running and is healthy
func (b *backend) validateInstance(s logical.Storage, instanceID, region, accountID string) (*ec2.Instance, error) {
func (b *backend) validateInstance(ctx context.Context, s logical.Storage, instanceID, region, accountID string) (*ec2.Instance, error) {
// Create an EC2 client to pull the instance information
ec2Client, err := b.clientEC2(s, region, accountID)
ec2Client, err := b.clientEC2(ctx, s, region, accountID)
if err != nil {
return nil, err
}
@@ -256,7 +256,7 @@ func validateMetadata(clientNonce, pendingTime string, storedIdentity *whitelist
// Verifies the integrity of the instance identity document using its SHA256
// RSA signature. After verification, returns the unmarshaled instance identity
// document.
func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityBytes, signatureBytes []byte) (*identityDocument, error) {
func (b *backend) verifyInstanceIdentitySignature(ctx context.Context, s logical.Storage, identityBytes, signatureBytes []byte) (*identityDocument, error) {
if len(identityBytes) == 0 {
return nil, fmt.Errorf("missing instance identity document")
}
@@ -270,7 +270,7 @@ func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityByt
// certificate and all the registered certificates via
// 'config/certificate/<cert_name>' endpoint, for verifying the RSA
// digest.
publicCerts, err := b.awsPublicCertificates(s, false)
publicCerts, err := b.awsPublicCertificates(ctx, s, false)
if err != nil {
return nil, err
}
@@ -297,7 +297,7 @@ func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityByt
// Verifies the correctness of the authenticated attributes present in the PKCS#7
// signature. After verification, extracts the instance identity document from the
// signature, parses it and returns it.
func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
func (b *backend) parseIdentityDocument(ctx context.Context, s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
// Insert the header and footer for the signature to be able to pem decode it
pkcs7B64 = fmt.Sprintf("-----BEGIN PKCS7-----\n%s\n-----END PKCS7-----", pkcs7B64)
@@ -316,7 +316,7 @@ func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*id
// Get the public certificates that are used to verify the signature.
// This returns a slice of certificates containing the default certificate
// and all the registered certificates via 'config/certificate/<cert_name>' endpoint
publicCerts, err := b.awsPublicCertificates(s, true)
publicCerts, err := b.awsPublicCertificates(ctx, s, true)
if err != nil {
return nil, err
}
@@ -372,7 +372,7 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, dat
// error that means the instance doesn't meet the role requirements
// The second error return value indicates whether there's an error in even
// trying to validate those requirements
func (b *backend) verifyInstanceMeetsRoleRequirements(
func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context,
s logical.Storage, instance *ec2.Instance, roleEntry *awsRoleEntry, roleName string, identityDoc *identityDocument) (error, error) {
switch {
@@ -470,7 +470,7 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(
}
// Use instance profile ARN to fetch the associated role ARN
iamClient, err := b.clientIAM(s, identityDoc.Region, identityDoc.AccountID)
iamClient, err := b.clientIAM(ctx, s, identityDoc.Region, identityDoc.AccountID)
if err != nil {
return nil, fmt.Errorf("could not fetch IAM client: %v", err)
} else if iamClient == nil {
@@ -530,7 +530,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
// Verify the signature of the identity document and unmarshal it
var identityDocParsed *identityDocument
if pkcs7B64 != "" {
identityDocParsed, err = b.parseIdentityDocument(req.Storage, pkcs7B64)
identityDocParsed, err = b.parseIdentityDocument(ctx, req.Storage, pkcs7B64)
if err != nil {
return nil, err
}
@@ -538,7 +538,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("failed to verify the instance identity document using pkcs7"), nil
}
} else {
identityDocParsed, err = b.verifyInstanceIdentitySignature(req.Storage, identityDocBytes, signatureBytes)
identityDocParsed, err = b.verifyInstanceIdentitySignature(ctx, req.Storage, identityDocBytes, signatureBytes)
if err != nil {
return nil, err
}
@@ -566,7 +566,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
}
// Get the entry for the role used by the instance
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -581,7 +581,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
// Validate the instance ID by making a call to AWS EC2 DescribeInstances API
// and fetching the instance description. Validation succeeds only if the
// instance is in 'running' state.
instance, err := b.validateInstance(req.Storage, identityDocParsed.InstanceID, identityDocParsed.Region, identityDocParsed.AccountID)
instance, err := b.validateInstance(ctx, req.Storage, identityDocParsed.InstanceID, identityDocParsed.Region, identityDocParsed.AccountID)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %v", err)), nil
}
@@ -592,7 +592,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(fmt.Sprintf("Region %q does not satisfy the constraint on role %q", identityDocParsed.Region, roleName)), nil
}
validationError, err := b.verifyInstanceMeetsRoleRequirements(req.Storage, instance, roleEntry, roleName, identityDocParsed)
validationError, err := b.verifyInstanceMeetsRoleRequirements(ctx, req.Storage, instance, roleEntry, roleName, identityDocParsed)
if err != nil {
return nil, err
}
@@ -601,7 +601,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
}
// Get the entry from the identity whitelist, if there is one
storedIdentity, err := whitelistIdentityEntry(req.Storage, identityDocParsed.InstanceID)
storedIdentity, err := whitelistIdentityEntry(ctx, req.Storage, identityDocParsed.InstanceID)
if err != nil {
return nil, err
}
@@ -682,7 +682,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
rTagMaxTTL := time.Duration(0)
var roleTagResp *roleTagLoginResponse
if roleEntry.RoleTag != "" {
roleTagResp, err := b.handleRoleTagLogin(req.Storage, roleName, roleEntry, instance)
roleTagResp, err := b.handleRoleTagLogin(ctx, req.Storage, roleName, roleEntry, instance)
if err != nil {
return nil, err
}
@@ -750,7 +750,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("client nonce exceeding the limit of 128 characters"), nil
}
if err = setWhitelistIdentityEntry(req.Storage, identityDocParsed.InstanceID, storedIdentity); err != nil {
if err = setWhitelistIdentityEntry(ctx, req.Storage, identityDocParsed.InstanceID, storedIdentity); err != nil {
return nil, err
}
@@ -800,7 +800,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
// handleRoleTagLogin is used to fetch the role tag of the instance and
// verifies it to be correct. Then the policies for the login request will be
// set off of the role tag, if certain creteria satisfies.
func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEntry *awsRoleEntry, instance *ec2.Instance) (*roleTagLoginResponse, error) {
func (b *backend) handleRoleTagLogin(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry, instance *ec2.Instance) (*roleTagLoginResponse, error) {
if roleEntry == nil {
return nil, fmt.Errorf("nil role entry")
}
@@ -832,7 +832,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEnt
}
// Parse the role tag into a struct, extract the plaintext part of it and verify its HMAC
rTag, err := b.parseAndVerifyRoleTagValue(s, rTagValue)
rTag, err := b.parseAndVerifyRoleTagValue(ctx, s, rTagValue)
if err != nil {
return nil, err
}
@@ -849,7 +849,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEnt
}
// Check if the role tag is blacklisted
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(s, rTagValue)
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(ctx, s, rTagValue)
if err != nil {
return nil, err
}
@@ -896,7 +896,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if roleName == "" {
return nil, fmt.Errorf("error retrieving role_name during renewal")
}
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -924,7 +924,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if !ok {
return nil, fmt.Errorf("no inferred AWS region in auth metadata")
}
_, err := b.validateInstance(req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
_, err := b.validateInstance(ctx, req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
if err != nil {
return nil, fmt.Errorf("failed to verify instance ID %q: %v", instanceID, err)
}
@@ -956,7 +956,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if err != nil {
return nil, fmt.Errorf("error parsing ARN %q: %v", canonicalArn, err)
}
fullArn, err = b.fullArn(entity, req.Storage)
fullArn, err = b.fullArn(ctx, entity, req.Storage)
if err != nil {
return nil, fmt.Errorf("error looking up full ARN of entity %v: %v", entity, err)
}
@@ -1008,12 +1008,12 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
}
// Cross check that the instance is still in 'running' state
_, err := b.validateInstance(req.Storage, instanceID, region, accountID)
_, err := b.validateInstance(ctx, req.Storage, instanceID, region, accountID)
if err != nil {
return nil, fmt.Errorf("failed to verify instance ID %q: %q", instanceID, err)
}
storedIdentity, err := whitelistIdentityEntry(req.Storage, instanceID)
storedIdentity, err := whitelistIdentityEntry(ctx, req.Storage, instanceID)
if err != nil {
return nil, err
}
@@ -1022,7 +1022,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
}
// Ensure that role entry is not deleted
roleEntry, err := b.lockedAWSRole(req.Storage, storedIdentity.Role)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, storedIdentity.Role)
if err != nil {
return nil, err
}
@@ -1061,7 +1061,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
// Updating the expiration time is required for the tidy operation on the
// whitelist identity storage items
if err = setWhitelistIdentityEntry(req.Storage, instanceID, storedIdentity); err != nil {
if err = setWhitelistIdentityEntry(ctx, req.Storage, instanceID, storedIdentity); err != nil {
return nil, err
}
@@ -1127,7 +1127,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil
}
config, err := b.lockedClientConfigEntry(req.Storage)
config, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return logical.ErrorResponse("error getting configuration"), nil
}
@@ -1175,7 +1175,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
roleName = entity.FriendlyName
}
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -1200,7 +1200,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
if strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
fullArn := b.getCachedUserId(callerUniqueId)
if fullArn == "" {
fullArn, err = b.fullArn(entity, req.Storage)
fullArn, err = b.fullArn(ctx, entity, req.Storage)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error looking up full ARN of entity %v: %v", entity, err)), nil
}
@@ -1224,7 +1224,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
inferredEntityType := ""
inferredEntityID := ""
if roleEntry.InferredEntityType == ec2EntityType {
instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
instance, err := b.validateInstance(ctx, req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to verify %s as a valid EC2 instance in region %s", entity.SessionInfo, roleEntry.InferredAWSRegion)), nil
}
@@ -1239,7 +1239,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
PendingTime: instance.LaunchTime.Format(time.RFC3339),
}
validationError, err := b.verifyInstanceMeetsRoleRequirements(req.Storage, instance, roleEntry, roleName, identityDoc)
validationError, err := b.verifyInstanceMeetsRoleRequirements(ctx, req.Storage, instance, roleEntry, roleName, identityDoc)
if err != nil {
return nil, err
}
@@ -1587,9 +1587,9 @@ func (e *iamEntity) canonicalArn() string {
}
// This returns the "full" ARN of an iamEntity, how it would be referred to in AWS proper
func (b *backend) fullArn(e *iamEntity, s logical.Storage) (string, error) {
func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage) (string, error) {
// Not assuming path is reliable for any entity types
client, err := b.clientIAM(s, getAnyRegionForAwsPartition(e.Partition).ID(), e.AccountNumber)
client, err := b.clientIAM(ctx, s, getAnyRegionForAwsPartition(e.Partition).ID(), e.AccountNumber)
if err != nil {
return "", fmt.Errorf("error creating IAM client: %v", err)
}

View File

@@ -204,7 +204,7 @@ func pathListRoles(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
entry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return false, err
}
@@ -213,13 +213,13 @@ func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Reque
// lockedAWSRole returns the properties set on the given role. This method
// acquires the read lock before reading the role from the storage.
func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEntry, error) {
func (b *backend) lockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}
b.roleMutex.RLock()
roleEntry, err := b.nonLockedAWSRole(s, roleName)
roleEntry, err := b.nonLockedAWSRole(ctx, s, roleName)
// we manually unlock rather than defer the unlock because we might need to grab
// a read/write lock in the upgrade path
b.roleMutex.RUnlock()
@@ -229,7 +229,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
if roleEntry == nil {
return nil, nil
}
needUpgrade, err := b.upgradeRoleEntry(s, roleEntry)
needUpgrade, err := b.upgradeRoleEntry(ctx, s, roleEntry)
if err != nil {
return nil, fmt.Errorf("error upgrading roleEntry: %v", err)
}
@@ -238,7 +238,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
defer b.roleMutex.Unlock()
// Now that we have a R/W lock, we need to re-read the role entry in case it was
// written to between releasing the read lock and acquiring the write lock
roleEntry, err = b.nonLockedAWSRole(s, roleName)
roleEntry, err = b.nonLockedAWSRole(ctx, s, roleName)
if err != nil {
return nil, err
}
@@ -247,11 +247,11 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
return nil, nil
}
// now re-check to see if we need to upgrade
if needUpgrade, err = b.upgradeRoleEntry(s, roleEntry); err != nil {
if needUpgrade, err = b.upgradeRoleEntry(ctx, s, roleEntry); err != nil {
return nil, fmt.Errorf("error upgrading roleEntry: %v", err)
}
if needUpgrade {
if err = b.nonLockedSetAWSRole(s, roleName, roleEntry); err != nil {
if err = b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry); err != nil {
return nil, fmt.Errorf("error saving upgraded roleEntry: %v", err)
}
}
@@ -261,7 +261,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
// lockedSetAWSRole creates or updates a role in the storage. This method
// acquires the write lock before creating or updating the role at the storage.
func (b *backend) lockedSetAWSRole(s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
func (b *backend) lockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
}
@@ -273,13 +273,13 @@ func (b *backend) lockedSetAWSRole(s logical.Storage, roleName string, roleEntry
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
return b.nonLockedSetAWSRole(s, roleName, roleEntry)
return b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry)
}
// nonLockedSetAWSRole creates or updates a role in the storage. This method
// does not acquire the write lock before reading the role from the storage. If
// locking is desired, use lockedSetAWSRole instead.
func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string,
roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
@@ -294,7 +294,7 @@ func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
return err
}
if err := s.Put(entry); err != nil {
if err := s.Put(ctx, entry); err != nil {
return err
}
@@ -303,7 +303,7 @@ func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
// If needed, updates the role entry and returns a bool indicating if it was updated
// (and thus needs to be persisted)
func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
if roleEntry == nil {
return false, fmt.Errorf("received nil roleEntry")
}
@@ -331,7 +331,7 @@ func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (
roleEntry.BoundIamPrincipalARN != "" &&
roleEntry.BoundIamPrincipalID == "" &&
!strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
principalId, err := b.resolveArnToUniqueIDFunc(s, roleEntry.BoundIamPrincipalARN)
principalId, err := b.resolveArnToUniqueIDFunc(ctx, s, roleEntry.BoundIamPrincipalARN)
if err != nil {
return false, err
}
@@ -349,12 +349,12 @@ func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (
// This method also does NOT check to see if a role upgrade is required. It is
// the responsibility of the caller to check if a role upgrade is required and,
// if so, to upgrade the role
func (b *backend) nonLockedAWSRole(s logical.Storage, roleName string) (*awsRoleEntry, error) {
func (b *backend) nonLockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}
entry, err := s.Get("role/" + strings.ToLower(roleName))
entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, err
}
@@ -380,7 +380,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
return nil, req.Storage.Delete("role/" + strings.ToLower(roleName))
return nil, req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
}
// pathRoleList is used to list all the AMI IDs registered with Vault.
@@ -388,7 +388,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
b.roleMutex.RLock()
defer b.roleMutex.RUnlock()
roles, err := req.Storage.List("role/")
roles, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
@@ -397,7 +397,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
// pathRoleRead is used to view the information registered for a given AMI ID.
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleEntry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return nil, err
}
@@ -431,19 +431,19 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
roleEntry, err := b.nonLockedAWSRole(req.Storage, roleName)
roleEntry, err := b.nonLockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
if roleEntry == nil {
roleEntry = &awsRoleEntry{}
} else {
needUpdate, err := b.upgradeRoleEntry(req.Storage, roleEntry)
needUpdate, err := b.upgradeRoleEntry(ctx, req.Storage, roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to update roleEntry: %v", err)), nil
}
if needUpdate {
err = b.nonLockedSetAWSRole(req.Storage, roleName, roleEntry)
err = b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to save upgraded roleEntry: %v", err)), nil
}
@@ -501,7 +501,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
// to re-resolve the ARN to the unique ID, in case an entity was deleted and
// recreated
if roleEntry.ResolveAWSUniqueIDs && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
principalID, err := b.resolveArnToUniqueIDFunc(req.Storage, principalARN)
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, principalARN)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed updating the unique ID of ARN %#v: %#v", principalARN, err)), nil
}
@@ -512,7 +512,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
} else if roleEntry.ResolveAWSUniqueIDs && roleEntry.BoundIamPrincipalARN != "" && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
// we're turning on resolution on this role, so ensure we update it
principalID, err := b.resolveArnToUniqueIDFunc(req.Storage, roleEntry.BoundIamPrincipalARN)
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, roleEntry.BoundIamPrincipalARN)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("unable to resolve ARN %#v to internal ID: %#v", roleEntry.BoundIamPrincipalARN, err)), nil
}
@@ -731,7 +731,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
}
if err := b.nonLockedSetAWSRole(req.Storage, roleName, roleEntry); err != nil {
if err := b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry); err != nil {
return nil, err
}

View File

@@ -77,7 +77,7 @@ func (b *backend) pathRoleTagUpdate(ctx context.Context, req *logical.Request, d
}
// Fetch the role entry
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@@ -288,7 +288,7 @@ func prepareRoleTagPlaintextValue(rTag *roleTag) (string, error) {
// Parses the tag from string form into a struct form. This method
// also verifies the correctness of the parsed role tag.
func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*roleTag, error) {
func (b *backend) parseAndVerifyRoleTagValue(ctx context.Context, s logical.Storage, tag string) (*roleTag, error) {
tagItems := strings.Split(tag, ":")
// Tag must contain version, nonce, policies and HMAC
@@ -349,7 +349,7 @@ func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*ro
return nil, fmt.Errorf("missing role name")
}
roleEntry, err := b.lockedAWSRole(s, rTag.Role)
roleEntry, err := b.lockedAWSRole(ctx, s, rTag.Role)
if err != nil {
return nil, err
}

View File

@@ -20,7 +20,8 @@ func TestBackend_pathRoleEc2(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -164,7 +165,7 @@ func Test_enableIamIDResolution(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -239,7 +240,7 @@ func TestBackend_pathIam(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -403,7 +404,7 @@ func TestBackend_pathRoleMixedTypes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -509,7 +510,8 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -635,7 +637,8 @@ func TestAwsEc2_RoleDurationSeconds(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -679,6 +682,6 @@ func TestAwsEc2_RoleDurationSeconds(t *testing.T) {
}
}
func resolveArnToFakeUniqueId(s logical.Storage, arn string) (string, error) {
func resolveArnToFakeUniqueId(ctx context.Context, s logical.Storage, arn string) (string, error) {
return "FakeUniqueId1", nil
}

View File

@@ -50,7 +50,7 @@ func (b *backend) pathRoletagBlacklistsList(ctx context.Context, req *logical.Re
b.blacklistMutex.RLock()
defer b.blacklistMutex.RUnlock()
tags, err := req.Storage.List("blacklist/roletag/")
tags, err := req.Storage.List(ctx, "blacklist/roletag/")
if err != nil {
return nil, err
}
@@ -71,15 +71,15 @@ func (b *backend) pathRoletagBlacklistsList(ctx context.Context, req *logical.Re
// Fetch an entry from the role tag blacklist for a given tag.
// This method takes a role tag in its original form and not a base64 encoded form.
func (b *backend) lockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
func (b *backend) lockedBlacklistRoleTagEntry(ctx context.Context, s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
b.blacklistMutex.RLock()
defer b.blacklistMutex.RUnlock()
return b.nonLockedBlacklistRoleTagEntry(s, tag)
return b.nonLockedBlacklistRoleTagEntry(ctx, s, tag)
}
func (b *backend) nonLockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
entry, err := s.Get("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
func (b *backend) nonLockedBlacklistRoleTagEntry(ctx context.Context, s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
entry, err := s.Get(ctx, "blacklist/roletag/"+base64.StdEncoding.EncodeToString([]byte(tag)))
if err != nil {
return nil, err
}
@@ -104,7 +104,7 @@ func (b *backend) pathRoletagBlacklistDelete(ctx context.Context, req *logical.R
return logical.ErrorResponse("missing role_tag"), nil
}
return nil, req.Storage.Delete("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
return nil, req.Storage.Delete(ctx, "blacklist/roletag/"+base64.StdEncoding.EncodeToString([]byte(tag)))
}
// If the given role tag is blacklisted, returns the details of the blacklist entry.
@@ -115,7 +115,7 @@ func (b *backend) pathRoletagBlacklistRead(ctx context.Context, req *logical.Req
return logical.ErrorResponse("missing role_tag"), nil
}
entry, err := b.lockedBlacklistRoleTagEntry(req.Storage, tag)
entry, err := b.lockedBlacklistRoleTagEntry(ctx, req.Storage, tag)
if err != nil {
return nil, err
}
@@ -154,7 +154,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
}
// Parse and verify the role tag from string form to a struct form and verify it.
rTag, err := b.parseAndVerifyRoleTagValue(req.Storage, tag)
rTag, err := b.parseAndVerifyRoleTagValue(ctx, req.Storage, tag)
if err != nil {
return nil, err
}
@@ -163,7 +163,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
}
// Get the entry for the role mentioned in the role tag.
roleEntry, err := b.lockedAWSRole(req.Storage, rTag.Role)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, rTag.Role)
if err != nil {
return nil, err
}
@@ -175,7 +175,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
defer b.blacklistMutex.Unlock()
// Check if the role tag is already blacklisted. If yes, update it.
blEntry, err := b.nonLockedBlacklistRoleTagEntry(req.Storage, tag)
blEntry, err := b.nonLockedBlacklistRoleTagEntry(ctx, req.Storage, tag)
if err != nil {
return nil, err
}
@@ -211,7 +211,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
}
// Store the blacklist entry.
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -32,7 +32,7 @@ expiration, before it is removed from the backend storage.`,
}
// tidyWhitelistIdentity is used to delete entries in the whitelist that are expired.
func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) error {
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error {
grabbed := atomic.CompareAndSwapUint32(&b.tidyWhitelistCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(&b.tidyWhitelistCASGuard, 0)
@@ -42,13 +42,13 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
bufferDuration := time.Duration(safety_buffer) * time.Second
identities, err := s.List("whitelist/identity/")
identities, err := s.List(ctx, "whitelist/identity/")
if err != nil {
return err
}
for _, instanceID := range identities {
identityEntry, err := s.Get("whitelist/identity/" + instanceID)
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
if err != nil {
return fmt.Errorf("error fetching identity of instanceID %s: %s", instanceID, err)
}
@@ -67,7 +67,7 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete("whitelist/identity" + instanceID); err != nil {
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
return fmt.Errorf("error deleting identity of instanceID %s from storage: %s", instanceID, err)
}
}
@@ -78,7 +78,7 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
// pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired.
func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidyWhitelistIdentity(req.Storage, data.Get("safety_buffer").(int))
return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
}
const pathTidyIdentityWhitelistSyn = `

View File

@@ -32,7 +32,7 @@ expiration, before it is removed from the backend storage.`,
}
// tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist.
func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) error {
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error {
grabbed := atomic.CompareAndSwapUint32(&b.tidyBlacklistCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(&b.tidyBlacklistCASGuard, 0)
@@ -41,13 +41,13 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
}
bufferDuration := time.Duration(safety_buffer) * time.Second
tags, err := s.List("blacklist/roletag/")
tags, err := s.List(ctx, "blacklist/roletag/")
if err != nil {
return err
}
for _, tag := range tags {
tagEntry, err := s.Get("blacklist/roletag/" + tag)
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
if err != nil {
return fmt.Errorf("error fetching tag %s: %s", tag, err)
}
@@ -66,7 +66,7 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete("blacklist/roletag" + tag); err != nil {
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
return fmt.Errorf("error deleting tag %s from storage: %s", tag, err)
}
}
@@ -77,7 +77,7 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
// pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist.
func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidyBlacklistRoleTag(req.Storage, data.Get("safety_buffer").(int))
return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
}
const pathTidyRoletagBlacklistSyn = `

View File

@@ -1,6 +1,7 @@
package cert
import (
"context"
"strings"
"sync"
@@ -8,9 +9,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -50,7 +51,7 @@ type backend struct {
crlUpdateMutex *sync.RWMutex
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch {
case strings.HasPrefix(key, "crls/"):
b.crlUpdateMutex.Lock()

View File

@@ -306,7 +306,7 @@ func TestBackend_NonCAExpiry(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -366,7 +366,7 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -449,7 +449,7 @@ func TestBackend_CRLs(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -586,7 +586,7 @@ func TestBackend_CRLs(t *testing.T) {
}
func testFactory(t *testing.T) logical.Backend {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 1000 * time.Second,
MaxLeaseTTLVal: 1800 * time.Second,
@@ -1135,7 +1135,7 @@ func testConnState(certPath, keyPath, rootCertPath string) (tls.ConnectionState,
func Test_Renew(t *testing.T) {
storage := &logical.InmemStorage{}
lb, err := Factory(&logical.BackendConfig{
lb, err := Factory(context.Background(), &logical.BackendConfig{
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 300 * time.Second,
MaxLeaseTTLVal: 1800 * time.Second,

View File

@@ -101,8 +101,8 @@ TTL will be set to the value of this parameter.`,
}
}
func (b *backend) Cert(s logical.Storage, n string) (*CertEntry, error) {
entry, err := s.Get("cert/" + strings.ToLower(n))
func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertEntry, error) {
entry, err := s.Get(ctx, "cert/"+strings.ToLower(n))
if err != nil {
return nil, err
}
@@ -118,7 +118,7 @@ func (b *backend) Cert(s logical.Storage, n string) (*CertEntry, error) {
}
func (b *backend) pathCertDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("cert/" + strings.ToLower(d.Get("name").(string)))
err := req.Storage.Delete(ctx, "cert/"+strings.ToLower(d.Get("name").(string)))
if err != nil {
return nil, err
}
@@ -126,7 +126,7 @@ func (b *backend) pathCertDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathCertList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
certs, err := req.Storage.List("cert/")
certs, err := req.Storage.List(ctx, "cert/")
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func (b *backend) pathCertList(ctx context.Context, req *logical.Request, d *fra
}
func (b *backend) pathCertRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cert, err := b.Cert(req.Storage, strings.ToLower(d.Get("name").(string)))
cert, err := b.Cert(ctx, req.Storage, strings.ToLower(d.Get("name").(string)))
if err != nil {
return nil, err
}
@@ -245,7 +245,7 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -35,15 +35,15 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
// Config returns the configuration for this backend.
func (b *backend) Config(s logical.Storage) (*config, error) {
entry, err := s.Get("config")
func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
}

View File

@@ -42,7 +42,7 @@ using the same name as specified here.`,
}
}
func (b *backend) populateCRLs(storage logical.Storage) error {
func (b *backend) populateCRLs(ctx context.Context, storage logical.Storage) error {
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
@@ -52,7 +52,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
b.crls = map[string]CRLInfo{}
keys, err := storage.List("crls/")
keys, err := storage.List(ctx, "crls/")
if err != nil {
return fmt.Errorf("error listing CRLs: %v", err)
}
@@ -61,7 +61,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
}
for _, key := range keys {
entry, err := storage.Get("crls/" + key)
entry, err := storage.Get(ctx, "crls/"+key)
if err != nil {
b.crls = nil
return fmt.Errorf("error loading CRL %s: %v", key, err)
@@ -129,7 +129,7 @@ func (b *backend) pathCRLDelete(ctx context.Context, req *logical.Request, d *fr
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
return nil, err
}
@@ -143,7 +143,7 @@ func (b *backend) pathCRLDelete(ctx context.Context, req *logical.Request, d *fr
)), nil
}
if err := req.Storage.Delete("crls/" + name); err != nil {
if err := req.Storage.Delete(ctx, "crls/"+name); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"error deleting crl %s: %v", name, err),
), nil
@@ -160,7 +160,7 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, d *fram
return logical.ErrorResponse(`"name" parameter must be set`), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
return nil, err
}
@@ -198,7 +198,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
return logical.ErrorResponse("parsed CRL is nil"), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
return nil, err
}
@@ -216,7 +216,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
if err != nil {
return nil, err
}
if err = req.Storage.Put(entry); err != nil {
if err = req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -60,7 +60,7 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var matched *ParsedCert
if verifyResp, resp, err := b.verifyCredentials(req, data); err != nil {
if verifyResp, resp, err := b.verifyCredentials(ctx, req, data); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@@ -128,14 +128,14 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
}
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
if !config.DisableBinding {
var matched *ParsedCert
if verifyResp, resp, err := b.verifyCredentials(req, d); err != nil {
if verifyResp, resp, err := b.verifyCredentials(ctx, req, d); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@@ -162,7 +162,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
}
// Get the cert and use its TTL
cert, err := b.Cert(req.Storage, req.Auth.Metadata["cert_name"])
cert, err := b.Cert(ctx, req.Storage, req.Auth.Metadata["cert_name"])
if err != nil {
return nil, err
}
@@ -188,7 +188,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(ctx, req, d)
}
func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
// Get the connection state
if req.Connection == nil || req.Connection.ConnState == nil {
return nil, logical.ErrorResponse("tls connection required"), nil
@@ -209,7 +209,14 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
}
// Load the trusted certificates
roots, trusted, trustedNonCAs := b.loadTrustedCerts(req.Storage, certName)
roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName)
// Get the list of full chains matching the connection and validates the
// certificate itself
trustedChains, err := validateConnState(roots, connState)
if err != nil {
return nil, nil, err
}
// If trustedNonCAs is not empty it means that client had registered a non-CA cert
// with the backend.
@@ -225,12 +232,8 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
}
}
// Get the list of full chains matching the connection
trustedChains, err := validateConnState(roots, connState)
if err != nil {
return nil, nil, err
}
// If no trusted chain was found, client is not authenticated
// This check happens after checking for a matching configured non-CA certs
if len(trustedChains) == 0 {
return nil, logical.ErrorResponse("invalid certificate or no client certificate supplied"), nil
}
@@ -328,11 +331,11 @@ func (b *backend) matchesCertificateExtenions(clientCert *x509.Certificate, conf
}
// loadTrustedCerts is used to load all the trusted certificates from the backend
func (b *backend) loadTrustedCerts(store logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
pool = x509.NewCertPool()
trusted = make([]*ParsedCert, 0)
trustedNonCAs = make([]*ParsedCert, 0)
names, err := store.List("cert/")
names, err := storage.List(ctx, "cert/")
if err != nil {
b.Logger().Error("cert: failed to list trusted certs", "error", err)
return
@@ -342,7 +345,7 @@ func (b *backend) loadTrustedCerts(store logical.Storage, certName string) (pool
if certName != "" && name != certName {
continue
}
entry, err := b.Cert(store, strings.TrimPrefix(name, "cert/"))
entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/"))
if err != nil {
b.Logger().Error("cert: failed to load trusted cert", "name", name, "error", err)
continue
@@ -423,9 +426,6 @@ func validateConnState(roots *x509.CertPool, cs *tls.ConnectionState) ([][]*x509
if len(certs) == 0 {
return nil, nil
}
if certs[0].IsCA {
return nil, nil
}
opts := x509.VerifyOptions{
Roots: roots,

View File

@@ -11,9 +11,9 @@ import (
"golang.org/x/oauth2"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@@ -1,6 +1,7 @@
package github
import (
"context"
"fmt"
"os"
"strings"
@@ -14,7 +15,7 @@ import (
func TestBackend_Config(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 2
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
@@ -92,7 +93,7 @@ func testConfigWrite(t *testing.T, d map[string]interface{}) logicaltest.TestSte
func TestBackend_basic(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,

View File

@@ -6,7 +6,6 @@ import (
"net/url"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@@ -87,7 +86,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -95,7 +94,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -108,14 +107,19 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data
config.MaxTTL /= time.Second
resp := &logical.Response{
Data: structs.New(config).Map(),
Data: map[string]interface{}{
"organization": config.Organization,
"base_url": config.BaseURL,
"ttl": config.TTL,
"max_ttl": config.MaxTTL,
},
}
return resp, nil
}
// Config returns the configuration for this backend.
func (b *backend) Config(s logical.Storage) (*config, error) {
entry, err := s.Get("config")
func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
}

View File

@@ -33,7 +33,7 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
token := data.Get("token").(string)
var verifyResp *verifyCredentialsResp
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@@ -54,7 +54,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
token := data.Get("token").(string)
var verifyResp *verifyCredentialsResp
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@@ -62,7 +62,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
verifyResp = verifyResponse
}
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -117,7 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
token := tokenRaw.(string)
var verifyResp *verifyCredentialsResp
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return nil, fmt.Errorf("policies do not match")
}
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -150,8 +150,8 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return resp, nil
}
func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
config, err := b.Config(req.Storage)
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, nil, err
}
@@ -174,7 +174,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
}
// Get the user
user, _, err := client.Users.Get(context.Background(), "")
user, _, err := client.Users.Get(ctx, "")
if err != nil {
return nil, nil, err
}
@@ -188,7 +188,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
var allOrgs []*github.Organization
for {
orgs, resp, err := client.Organizations.List(context.Background(), "", orgOpt)
orgs, resp, err := client.Organizations.List(ctx, "", orgOpt)
if err != nil {
return nil, nil, err
}
@@ -218,7 +218,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
var allTeams []*github.Team
for {
teams, resp, err := client.Organizations.ListUserTeams(context.Background(), teamOpt)
teams, resp, err := client.Organizations.ListUserTeams(ctx, teamOpt)
if err != nil {
return nil, nil, err
}
@@ -242,13 +242,13 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
}
}
groupPoliciesList, err := b.TeamMap.Policies(req.Storage, teamNames...)
groupPoliciesList, err := b.TeamMap.Policies(ctx, req.Storage, teamNames...)
if err != nil {
return nil, nil, err
}
userPoliciesList, err := b.UserMap.Policies(req.Storage, []string{*user.Login}...)
userPoliciesList, err := b.UserMap.Policies(ctx, req.Storage, []string{*user.Login}...)
if err != nil {
return nil, nil, err

View File

@@ -2,6 +2,7 @@ package ldap
import (
"bytes"
"context"
"fmt"
"text/template"
@@ -12,9 +13,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -92,9 +93,9 @@ func EscapeLDAPValue(input string) string {
return input
}
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, nil, nil, err
}
@@ -172,7 +173,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
var allGroups []string
// Import the custom added groups from ldap backend
user, err := b.User(req.Storage, username)
user, err := b.User(ctx, req.Storage, username)
if err == nil && user != nil && user.Groups != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/ldap: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
@@ -185,7 +186,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
// Retrieve policies
var policies []string
for _, groupName := range allGroups {
group, err := b.Group(req.Storage, groupName)
group, err := b.Group(ctx, req.Storage, groupName)
if err == nil && group != nil {
policies = append(policies, group.Policies...)
}

View File

@@ -23,7 +23,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
t.Fatalf("failed to create backend")
}
err := b.Backend.Setup(config)
err := b.Backend.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -120,7 +120,7 @@ func TestLdapAuthBackend_UserPolicies(t *testing.T) {
func factory(t *testing.T) logical.Backend {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,

View File

@@ -130,7 +130,7 @@ Default: cn`,
/*
* Construct ConfigEntry struct using stored configuration.
*/
func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
// Schema for ConfigEntry
fd, err := b.getConfigFieldData()
if err != nil {
@@ -143,7 +143,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
return nil, err
}
storedConfig, err := req.Storage.Get("config")
storedConfig, err := req.Storage.Get(ctx, "config")
if err != nil {
return nil, err
}
@@ -165,7 +165,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, err
}
@@ -299,7 +299,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -47,8 +47,8 @@ func pathGroups(b *backend) *framework.Path {
}
}
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
entry, err := s.Get("group/" + n)
func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*GroupEntry, error) {
entry, err := s.Get(ctx, "group/"+n)
if err != nil {
return nil, err
}
@@ -65,7 +65,7 @@ func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
}
func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("group/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "group/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -74,7 +74,7 @@ func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *
}
func (b *backend) pathGroupRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
group, err := b.Group(req.Storage, d.Get("name").(string))
group, err := b.Group(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -97,7 +97,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -105,7 +105,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List("group/")
groups, err := req.Storage.List(ctx, "group/")
if err != nil {
return nil, err
}

View File

@@ -54,7 +54,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
username := d.Get("username").(string)
password := d.Get("password").(string)
policies, resp, groupNames, err := b.Login(req, username, password)
policies, resp, groupNames, err := b.Login(ctx, req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@@ -102,7 +102,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}

View File

@@ -54,8 +54,8 @@ func pathUsers(b *backend) *framework.Path {
}
}
func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get("user/" + n)
func (b *backend) User(ctx context.Context, s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get(ctx, "user/"+n)
if err != nil {
return nil, err
}
@@ -72,7 +72,7 @@ func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
}
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("user/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "user/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -81,7 +81,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
user, err := b.User(req.Storage, d.Get("name").(string))
user, err := b.User(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -113,7 +113,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -121,7 +121,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package okta
import (
"context"
"fmt"
"github.com/chrismalek/oktasdk-go/okta"
@@ -9,9 +10,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -54,8 +55,8 @@ type backend struct {
*framework.Backend
}
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(req.Storage)
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, nil, nil, err
}
@@ -110,7 +111,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
}
// Import the custom added groups from okta backend
user, err := b.User(req.Storage, username)
user, err := b.User(ctx, req.Storage, username)
if err != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: error looking up user", "error", err)
@@ -126,7 +127,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
// Retrieve policies
var policies []string
for _, groupName := range allGroups {
entry, _, err := b.Group(req.Storage, groupName)
entry, _, err := b.Group(ctx, req.Storage, groupName)
if err != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: error looking up group policies", "error", err)

View File

@@ -1,6 +1,7 @@
package okta
import (
"context"
"fmt"
"os"
"strings"
@@ -19,7 +20,7 @@ import (
func TestBackend_Config(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: logformat.NewVaultLogger(log.LevelTrace),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,

View File

@@ -69,8 +69,8 @@ func pathConfig(b *backend) *framework.Path {
}
// Config returns the configuration for this backend.
func (b *backend) Config(s logical.Storage) (*ConfigEntry, error) {
entry, err := s.Get("config")
func (b *backend) Config(ctx context.Context, s logical.Storage) (*ConfigEntry, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
}
@@ -89,7 +89,7 @@ func (b *backend) Config(s logical.Storage) (*ConfigEntry, error) {
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -116,7 +116,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -193,7 +193,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
if err != nil {
return nil, err
}
if err := req.Storage.Put(jsonCfg); err != nil {
if err := req.Storage.Put(ctx, jsonCfg); err != nil {
return nil, err
}
@@ -201,7 +201,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
}
func (b *backend) pathConfigExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return false, err
}

View File

@@ -50,20 +50,20 @@ func pathGroups(b *backend) *framework.Path {
// We look up groups in a case-insensitive manner since Okta is case-preserving
// but case-insensitive for comparisons
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, string, error) {
func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*GroupEntry, string, error) {
canonicalName := n
entry, err := s.Get("group/" + n)
entry, err := s.Get(ctx, "group/"+n)
if err != nil {
return nil, "", err
}
if entry == nil {
entries, err := s.List("group/")
entries, err := s.List(ctx, "group/")
if err != nil {
return nil, "", err
}
for _, groupName := range entries {
if strings.ToLower(groupName) == strings.ToLower(n) {
entry, err = s.Get("group/" + groupName)
entry, err = s.Get(ctx, "group/"+groupName)
if err != nil {
return nil, "", err
}
@@ -90,12 +90,12 @@ func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *
return logical.ErrorResponse("'name' must be supplied"), nil
}
entry, canonicalName, err := b.Group(req.Storage, name)
entry, canonicalName, err := b.Group(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if entry != nil {
err := req.Storage.Delete("group/" + canonicalName)
err := req.Storage.Delete(ctx, "group/"+canonicalName)
if err != nil {
return nil, err
}
@@ -110,7 +110,7 @@ func (b *backend) pathGroupRead(ctx context.Context, req *logical.Request, d *fr
return logical.ErrorResponse("'name' must be supplied"), nil
}
group, _, err := b.Group(req.Storage, name)
group, _, err := b.Group(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@@ -133,7 +133,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
// Check for an existing group, possibly lowercased so that we keep using
// existing user set values
_, canonicalName, err := b.Group(req.Storage, name)
_, canonicalName, err := b.Group(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@@ -149,7 +149,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -157,7 +157,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List("group/")
groups, err := req.Storage.List(ctx, "group/")
if err != nil {
return nil, err
}

View File

@@ -56,7 +56,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
username := d.Get("username").(string)
password := d.Get("password").(string)
policies, resp, groupNames, err := b.Login(req, username, password)
policies, resp, groupNames, err := b.Login(ctx, req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@@ -72,7 +72,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
sort.Strings(policies)
cfg, err := b.getConfig(req)
cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
}
@@ -112,7 +112,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}
@@ -121,7 +121,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return nil, fmt.Errorf("policies have changed, not renewing")
}
cfg, err := b.getConfig(req)
cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
}
@@ -144,9 +144,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) {
func (b *backend) getConfig(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}

View File

@@ -51,8 +51,8 @@ func pathUsers(b *backend) *framework.Path {
}
}
func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get("user/" + n)
func (b *backend) User(ctx context.Context, s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get(ctx, "user/"+n)
if err != nil {
return nil, err
}
@@ -74,7 +74,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
return logical.ErrorResponse("Error empty name"), nil
}
err := req.Storage.Delete("user/" + name)
err := req.Storage.Delete(ctx, "user/"+name)
if err != nil {
return nil, err
}
@@ -88,7 +88,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
return logical.ErrorResponse("Error empty name"), nil
}
user, err := b.User(req.Storage, name)
user, err := b.User(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@@ -121,7 +121,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -129,7 +129,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}

View File

@@ -1,14 +1,16 @@
package radius
import (
"context"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@@ -1,6 +1,7 @@
package radius
import (
"context"
"fmt"
"os"
"reflect"
@@ -17,7 +18,7 @@ const (
)
func TestBackend_Config(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -70,7 +71,7 @@ func TestBackend_Config(t *testing.T) {
}
func TestBackend_users(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -98,7 +99,7 @@ func TestBackend_acceptance(t *testing.T) {
return
}
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,

View File

@@ -65,7 +65,7 @@ func pathConfig(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) configExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.Config(req)
entry, err := b.Config(ctx, req)
if err != nil {
return false, err
}
@@ -75,9 +75,8 @@ func (b *backend) configExistenceCheck(ctx context.Context, req *logical.Request
/*
* Construct ConfigEntry struct using stored configuration.
*/
func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
storedConfig, err := req.Storage.Get("config")
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
storedConfig, err := req.Storage.Get(ctx, "config")
if err != nil {
return nil, err
}
@@ -96,7 +95,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, err
}
@@ -113,7 +112,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Build a ConfigEntry struct out of the supplied FieldData
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, err
}
@@ -190,7 +189,7 @@ func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Reque
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -76,7 +76,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
return logical.ErrorResponse("password cannot be empty"), nil
}
policies, resp, err := b.RadiusLogin(req, username, password)
policies, resp, err := b.RadiusLogin(ctx, req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@@ -117,7 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
var resp *logical.Response
var loginPolicies []string
loginPolicies, resp, err = b.RadiusLogin(req, username, password)
loginPolicies, resp, err = b.RadiusLogin(ctx, req, username, password)
if err != nil || (resp != nil && resp.IsError()) {
return resp, err
}
@@ -129,9 +129,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
}
func (b *backend) RadiusLogin(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
func (b *backend) RadiusLogin(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, nil, err
}
@@ -163,7 +163,7 @@ func (b *backend) RadiusLogin(req *logical.Request, username string, password st
var policies []string
// Retrieve user entry from storage
user, err := b.user(req.Storage, username)
user, err := b.user(ctx, req.Storage, username)
if err != nil {
return policies, logical.ErrorResponse("could not retrieve user entry from storage"), err
}

View File

@@ -53,7 +53,7 @@ func pathUsers(b *backend) *framework.Path {
}
func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
userEntry, err := b.user(req.Storage, data.Get("name").(string))
userEntry, err := b.user(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return false, err
}
@@ -61,12 +61,12 @@ func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request,
return userEntry != nil, nil
}
func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
func (b *backend) user(ctx context.Context, s logical.Storage, username string) (*UserEntry, error) {
if username == "" {
return nil, fmt.Errorf("missing username")
}
entry, err := s.Get("user/" + strings.ToLower(username))
entry, err := s.Get(ctx, "user/"+strings.ToLower(username))
if err != nil {
return nil, err
}
@@ -83,7 +83,7 @@ func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
}
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("user/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "user/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -92,7 +92,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
user, err := b.user(req.Storage, d.Get("name").(string))
user, err := b.user(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -123,7 +123,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -131,7 +131,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}

View File

@@ -1,14 +1,16 @@
package userpass
import (
"context"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@@ -1,6 +1,7 @@
package userpass
import (
"context"
"fmt"
"reflect"
"testing"
@@ -45,7 +46,7 @@ func TestBackend_TTLDurations(t *testing.T) {
data5 := map[string]interface{}{
"password": "password",
}
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -69,7 +70,7 @@ func TestBackend_TTLDurations(t *testing.T) {
}
func TestBackend_basic(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -92,7 +93,7 @@ func TestBackend_basic(t *testing.T) {
}
func TestBackend_userCrud(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -115,7 +116,7 @@ func TestBackend_userCrud(t *testing.T) {
}
func TestBackend_userCreateOperation(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -136,7 +137,7 @@ func TestBackend_userCreateOperation(t *testing.T) {
}
func TestBackend_passwordUpdate(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@@ -161,7 +162,7 @@ func TestBackend_passwordUpdate(t *testing.T) {
}
func TestBackend_policiesUpdate(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,

View File

@@ -61,7 +61,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
}
// Get the user and validate auth
user, err := b.user(req.Storage, username)
user, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@@ -102,7 +102,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the user
user, err := b.user(req.Storage, req.Auth.Metadata["username"])
user, err := b.user(ctx, req.Storage, req.Auth.Metadata["username"])
if err != nil {
return nil, err
}

View File

@@ -37,7 +37,7 @@ func pathUserPassword(b *backend) *framework.Path {
func (b *backend) pathUserPasswordUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := d.Get("username").(string)
userEntry, err := b.user(req.Storage, username)
userEntry, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@@ -53,7 +53,7 @@ func (b *backend) pathUserPasswordUpdate(ctx context.Context, req *logical.Reque
return logical.ErrorResponse(userErr.Error()), logical.ErrInvalidRequest
}
return nil, b.setUser(req.Storage, username, userEntry)
return nil, b.setUser(ctx, req.Storage, username, userEntry)
}
func (b *backend) updateUserPassword(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) (error, error) {

View File

@@ -35,7 +35,7 @@ func pathUserPolicies(b *backend) *framework.Path {
func (b *backend) pathUserPoliciesUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := d.Get("username").(string)
userEntry, err := b.user(req.Storage, username)
userEntry, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@@ -45,7 +45,7 @@ func (b *backend) pathUserPoliciesUpdate(ctx context.Context, req *logical.Reque
userEntry.Policies = policyutil.ParsePolicies(d.Get("policies"))
return nil, b.setUser(req.Storage, username, userEntry)
return nil, b.setUser(ctx, req.Storage, username, userEntry)
}
const pathUserPoliciesHelpSyn = `

View File

@@ -69,7 +69,7 @@ func pathUsers(b *backend) *framework.Path {
}
func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
userEntry, err := b.user(req.Storage, data.Get("username").(string))
userEntry, err := b.user(ctx, req.Storage, data.Get("username").(string))
if err != nil {
return false, err
}
@@ -77,12 +77,12 @@ func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request,
return userEntry != nil, nil
}
func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
func (b *backend) user(ctx context.Context, s logical.Storage, username string) (*UserEntry, error) {
if username == "" {
return nil, fmt.Errorf("missing username")
}
entry, err := s.Get("user/" + strings.ToLower(username))
entry, err := s.Get(ctx, "user/"+strings.ToLower(username))
if err != nil {
return nil, err
}
@@ -98,17 +98,17 @@ func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
return &result, nil
}
func (b *backend) setUser(s logical.Storage, username string, userEntry *UserEntry) error {
func (b *backend) setUser(ctx context.Context, s logical.Storage, username string, userEntry *UserEntry) error {
entry, err := logical.StorageEntryJSON("user/"+username, userEntry)
if err != nil {
return err
}
return s.Put(entry)
return s.Put(ctx, entry)
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}
@@ -116,7 +116,7 @@ func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *fra
}
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("user/" + strings.ToLower(d.Get("username").(string)))
err := req.Storage.Delete(ctx, "user/"+strings.ToLower(d.Get("username").(string)))
if err != nil {
return nil, err
}
@@ -125,7 +125,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
user, err := b.user(req.Storage, strings.ToLower(d.Get("username").(string)))
user, err := b.user(ctx, req.Storage, strings.ToLower(d.Get("username").(string)))
if err != nil {
return nil, err
}
@@ -144,7 +144,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := strings.ToLower(d.Get("username").(string))
userEntry, err := b.user(req.Storage, username)
userEntry, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@@ -182,7 +182,7 @@ func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
}
return nil, b.setUser(req.Storage, username, userEntry)
return nil, b.setUser(ctx, req.Storage, username, userEntry)
}
func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {

View File

@@ -1,6 +1,7 @@
package aws
import (
"context"
"strings"
"time"
@@ -8,9 +9,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@@ -2,6 +2,7 @@ package aws
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
@@ -22,7 +23,7 @@ import (
)
func getBackend(t *testing.T) logical.Backend {
be, _ := Factory(logical.TestBackendConfig())
be, _ := Factory(context.Background(), logical.TestBackendConfig())
return be
}

View File

@@ -1,6 +1,7 @@
package aws
import (
"context"
"fmt"
"os"
@@ -13,11 +14,11 @@ import (
"github.com/hashicorp/vault/logical"
)
func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
func getRootConfig(ctx context.Context, s logical.Storage, clientType string) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{}
var endpoint string
entry, err := s.Get("config/root")
entry, err := s.Get(ctx, "config/root")
if err != nil {
return nil, err
}
@@ -63,8 +64,8 @@ func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
}, nil
}
func clientIAM(s logical.Storage) (*iam.IAM, error) {
awsConfig, err := getRootConfig(s, "iam")
func clientIAM(ctx context.Context, s logical.Storage) (*iam.IAM, error) {
awsConfig, err := getRootConfig(ctx, s, "iam")
if err != nil {
return nil, err
}
@@ -77,8 +78,8 @@ func clientIAM(s logical.Storage) (*iam.IAM, error) {
return client, nil
}
func clientSTS(s logical.Storage) (*sts.STS, error) {
awsConfig, err := getRootConfig(s, "sts")
func clientSTS(ctx context.Context, s logical.Storage) (*sts.STS, error) {
awsConfig, err := getRootConfig(ctx, s, "sts")
if err != nil {
return nil, err
}

View File

@@ -35,8 +35,8 @@ func pathConfigLease(b *backend) *framework.Path {
}
// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}
@@ -82,7 +82,7 @@ func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -90,7 +90,7 @@ func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(req.Storage)
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err

View File

@@ -60,7 +60,7 @@ func pathConfigRootWrite(ctx context.Context, req *logical.Request, data *framew
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -58,7 +58,7 @@ func pathRoles() *framework.Path {
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("policy/")
entries, err := req.Storage.List(ctx, "policy/")
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *fra
}
func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("policy/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "policy/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -75,7 +75,7 @@ func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.Fie
}
func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("policy/" + d.Get("name").(string))
entry, err := req.Storage.Get(ctx, "policy/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@@ -125,7 +125,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
"Error compacting policy: %s", err)), nil
}
// Write the policy into storage
err := req.Storage.Put(&logical.StorageEntry{
err := req.Storage.Put(ctx, &logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: buf.Bytes(),
})
@@ -134,7 +134,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
}
} else {
// Write the arn ref into storage
err := req.Storage.Put(&logical.StorageEntry{
err := req.Storage.Put(ctx, &logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: []byte(d.Get("arn").(string)),
})

View File

@@ -15,7 +15,7 @@ func TestBackend_PathListRoles(t *testing.T) {
config.StorageView = &logical.InmemStorage{}
b := Backend()
if err := b.Setup(config); err != nil {
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}

View File

@@ -45,7 +45,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
ttl := int64(d.Get("ttl").(int))
// Read the policy
policy, err := req.Storage.Get("policy/" + policyName)
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
@@ -57,6 +57,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
if strings.HasPrefix(policyValue, "arn:") {
if strings.Contains(policyValue, ":role/") {
return b.assumeRole(
ctx,
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,
@@ -69,6 +70,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
}
// Use the helper to create the secret
return b.secretTokenCreate(
ctx,
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,

View File

@@ -34,7 +34,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
policyName := d.Get("name").(string)
// Read the policy
policy, err := req.Storage.Get("policy/" + policyName)
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
@@ -45,10 +45,10 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
// Use the helper to create the secret
return b.secretAccessKeysCreate(
req.Storage, req.DisplayName, policyName, string(policy.Value))
ctx, req.Storage, req.DisplayName, policyName, string(policy.Value))
}
func pathUserRollback(req *logical.Request, _kind string, data interface{}) error {
func pathUserRollback(ctx context.Context, req *logical.Request, _kind string, data interface{}) error {
var entry walUser
if err := mapstructure.Decode(data, &entry); err != nil {
return err
@@ -56,7 +56,7 @@ func pathUserRollback(req *logical.Request, _kind string, data interface{}) erro
username := entry.UserName
// Get the client
client, err := clientIAM(req.Storage)
client, err := clientIAM(ctx, req.Storage)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package aws
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
@@ -11,11 +12,11 @@ var walRollbackMap = map[string]framework.WALRollbackFunc{
"user": pathUserRollback,
}
func walRollback(req *logical.Request, kind string, data interface{}) error {
func walRollback(ctx context.Context, req *logical.Request, kind string, data interface{}) error {
f, ok := walRollbackMap[kind]
if !ok {
return fmt.Errorf("unknown type to rollback")
}
return f(req, kind, data)
return f(ctx, req, kind, data)
}

View File

@@ -65,10 +65,10 @@ func genUsername(displayName, policyName, userType string) (ret string, warning
return
}
func (b *backend) secretTokenCreate(s logical.Storage,
func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(s)
STSClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
@@ -110,10 +110,10 @@ func (b *backend) secretTokenCreate(s logical.Storage,
return resp, nil
}
func (b *backend) assumeRole(s logical.Storage,
func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(s)
STSClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
@@ -156,9 +156,10 @@ func (b *backend) assumeRole(s logical.Storage,
}
func (b *backend) secretAccessKeysCreate(
ctx context.Context,
s logical.Storage,
displayName, policyName string, policy string) (*logical.Response, error) {
client, err := clientIAM(s)
client, err := clientIAM(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
@@ -169,7 +170,7 @@ func (b *backend) secretAccessKeysCreate(
// the user is created because if switch the order then the WAL put
// can fail, which would put us in an awkward position: we have a user
// we need to rollback but can't put the WAL entry to do the rollback.
walId, err := framework.PutWAL(s, "user", &walUser{
walId, err := framework.PutWAL(ctx, s, "user", &walUser{
UserName: username,
})
if err != nil {
@@ -221,7 +222,7 @@ func (b *backend) secretAccessKeysCreate(
// Remove the WAL entry, we succeeded! If we fail, we don't return
// the secret because it'll get rolled back anyways, so we have to return
// an error here.
if err := framework.DeleteWAL(s, walId); err != nil {
if err := framework.DeleteWAL(ctx, s, walId); err != nil {
return nil, fmt.Errorf("Failed to commit WAL entry: %s", err)
}
@@ -236,7 +237,7 @@ func (b *backend) secretAccessKeysCreate(
"is_sts": false,
})
lease, err := b.Lease(s)
lease, err := b.Lease(ctx, s)
if err != nil || lease == nil {
lease = &configLease{}
}
@@ -262,7 +263,7 @@ func (b *backend) secretAccessKeysRenew(ctx context.Context, req *logical.Reques
}
}
lease, err := b.Lease(req.Storage)
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
@@ -302,7 +303,7 @@ func secretAccessKeysRevoke(ctx context.Context, req *logical.Request, d *framew
}
// Use the user rollback mechanism to delete this user
err := pathUserRollback(req, "user", map[string]interface{}{
err := pathUserRollback(ctx, req, "user", map[string]interface{}{
"username": username,
})
if err != nil {

View File

@@ -1,6 +1,7 @@
package cassandra
import (
"context"
"fmt"
"strings"
"sync"
@@ -11,9 +12,9 @@ import (
)
// Factory creates a new backend
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -43,7 +44,7 @@ func Backend() *backend {
Invalidate: b.invalidate,
Clean: func() {
Clean: func(_ context.Context) {
b.ResetDB(nil)
},
BackendType: logical.TypeLogical,
@@ -77,7 +78,7 @@ type sessionConfig struct {
}
// DB returns the database connection.
func (b *backend) DB(s logical.Storage) (*gocql.Session, error) {
func (b *backend) DB(ctx context.Context, s logical.Storage) (*gocql.Session, error) {
b.lock.Lock()
defer b.lock.Unlock()
@@ -86,7 +87,7 @@ func (b *backend) DB(s logical.Storage) (*gocql.Session, error) {
return b.session, nil
}
entry, err := s.Get("config/connection")
entry, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
@@ -120,7 +121,7 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
b.session = newSession
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case "config/connection":
b.ResetDB(nil)

View File

@@ -1,6 +1,7 @@
package cassandra
import (
"context"
"fmt"
"log"
"os"
@@ -82,7 +83,7 @@ func TestBackend_basic(t *testing.T) {
}
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -106,7 +107,7 @@ func TestBackend_roleCrud(t *testing.T) {
}
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}

View File

@@ -87,7 +87,7 @@ take precedence.`,
}
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("config/connection")
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
@@ -196,7 +196,7 @@ func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request,
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -36,7 +36,7 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
name := data.Get("name").(string)
// Get the role
role, err := getRole(req.Storage, name)
role, err := getRole(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@@ -57,7 +57,7 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
}
// Get our connection
session, err := b.DB(req.Storage)
session, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}

View File

@@ -75,8 +75,8 @@ template values are '{{username}}' and
}
}
func getRole(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n)
func getRole(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, err
}
@@ -93,7 +93,7 @@ func getRole(s logical.Storage, n string) (*roleEntry, error) {
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("role/" + data.Get("name").(string))
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
@@ -102,7 +102,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
}
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := getRole(req.Storage, data.Get("name").(string))
role, err := getRole(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
@@ -148,7 +148,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data
if err != nil {
return nil, err
}
if err := req.Storage.Put(entryJSON); err != nil {
if err := req.Storage.Put(ctx, entryJSON); err != nil {
return nil, err
}

View File

@@ -42,7 +42,7 @@ func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d
return nil, fmt.Errorf("error converting role internal data to string")
}
role, err := getRole(req.Storage, roleName)
role, err := getRole(ctx, req.Storage, roleName)
if err != nil {
return nil, fmt.Errorf("unable to load role: %s", err)
}
@@ -61,7 +61,7 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d
return nil, fmt.Errorf("error converting username internal data to string")
}
session, err := b.DB(req.Storage)
session, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, fmt.Errorf("error getting session")
}

View File

@@ -1,13 +1,15 @@
package consul
import (
"context"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@@ -83,7 +83,7 @@ func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
func TestBackend_config_access(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -130,7 +130,7 @@ func TestBackend_config_access(t *testing.T) {
func TestBackend_basic(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -157,7 +157,7 @@ func TestBackend_basic(t *testing.T) {
func TestBackend_renew_revoke(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -264,7 +264,7 @@ func TestBackend_renew_revoke(t *testing.T) {
func TestBackend_management(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@@ -289,7 +289,7 @@ func TestBackend_management(t *testing.T) {
}
func TestBackend_crud(t *testing.T) {
b, _ := Factory(logical.TestBackendConfig())
b, _ := Factory(context.Background(), logical.TestBackendConfig())
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
@@ -304,7 +304,7 @@ func TestBackend_crud(t *testing.T) {
}
func TestBackend_role_lease(t *testing.T) {
b, _ := Factory(logical.TestBackendConfig())
b, _ := Factory(context.Background(), logical.TestBackendConfig())
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{

View File

@@ -1,14 +1,15 @@
package consul
import (
"context"
"fmt"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/vault/logical"
)
func client(s logical.Storage) (*api.Client, error, error) {
conf, userErr, intErr := readConfigAccess(s)
func client(ctx context.Context, s logical.Storage) (*api.Client, error, error) {
conf, userErr, intErr := readConfigAccess(ctx, s)
if intErr != nil {
return nil, nil, intErr
}

View File

@@ -40,8 +40,8 @@ func pathConfigAccess() *framework.Path {
}
}
func readConfigAccess(storage logical.Storage) (*accessConfig, error, error) {
entry, err := storage.Get("config/access")
func readConfigAccess(ctx context.Context, storage logical.Storage) (*accessConfig, error, error) {
entry, err := storage.Get(ctx, "config/access")
if err != nil {
return nil, nil, err
}
@@ -60,7 +60,7 @@ func readConfigAccess(storage logical.Storage) (*accessConfig, error, error) {
}
func pathConfigAccessRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
conf, userErr, intErr := readConfigAccess(req.Storage)
conf, userErr, intErr := readConfigAccess(ctx, req.Storage)
if intErr != nil {
return nil, intErr
}
@@ -89,7 +89,7 @@ func pathConfigAccessWrite(ctx context.Context, req *logical.Request, data *fram
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -59,7 +59,7 @@ Defaults to 'client'.`,
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("policy/")
entries, err := req.Storage.List(ctx, "policy/")
if err != nil {
return nil, err
}
@@ -70,7 +70,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *fra
func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
entry, err := req.Storage.Get("policy/" + name)
entry, err := req.Storage.Get(ctx, "policy/"+name)
if err != nil {
return nil, err
}
@@ -142,7 +142,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@@ -151,7 +151,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if err := req.Storage.Delete("policy/" + name); err != nil {
if err := req.Storage.Delete(ctx, "policy/"+name); err != nil {
return nil, err
}
return nil, nil

View File

@@ -29,7 +29,7 @@ func pathToken(b *backend) *framework.Path {
func (b *backend) pathTokenRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
role := d.Get("role").(string)
entry, err := req.Storage.Get("policy/" + role)
entry, err := req.Storage.Get(ctx, "policy/"+role)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
@@ -47,7 +47,7 @@ func (b *backend) pathTokenRead(ctx context.Context, req *logical.Request, d *fr
}
// Get the consul client
c, userErr, intErr := client(req.Storage)
c, userErr, intErr := client(ctx, req.Storage)
if intErr != nil {
return nil, intErr
}

View File

@@ -38,7 +38,7 @@ func (b *backend) secretTokenRenew(ctx context.Context, req *logical.Request, d
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
}
entry, err := req.Storage.Get("policy/" + role)
entry, err := req.Storage.Get(ctx, "policy/"+role)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
@@ -55,7 +55,7 @@ func (b *backend) secretTokenRenew(ctx context.Context, req *logical.Request, d
}
func secretTokenRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
c, userErr, intErr := client(req.Storage)
c, userErr, intErr := client(ctx, req.Storage)
if intErr != nil {
return nil, intErr
}

View File

@@ -16,9 +16,9 @@ import (
const databaseConfigPath = "database/config/"
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend(conf)
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -66,7 +66,7 @@ type databaseBackend struct {
}
// closeAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs() {
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
b.Lock()
defer b.Unlock()
@@ -94,12 +94,12 @@ func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, na
return db, nil
}
config, err := b.DatabaseConfig(s, name)
config, err := b.DatabaseConfig(ctx, s, name)
if err != nil {
return nil, err
}
db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
db, err = dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}
@@ -115,8 +115,8 @@ func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, na
return db, nil
}
func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) {
entry, err := s.Get(fmt.Sprintf("config/%s", name))
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %s", err)
}
@@ -147,8 +147,8 @@ type upgradeCheck struct {
Statements upgradeStatements `json:"statments"`
}
func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) {
entry, err := s.Get("role/" + roleName)
func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+roleName)
if err != nil {
return nil, err
}
@@ -177,7 +177,7 @@ func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry,
return &result, nil
}
func (b *databaseBackend) invalidate(key string) {
func (b *databaseBackend) invalidate(ctx context.Context, key string) {
b.Lock()
defer b.Unlock()

View File

@@ -133,11 +133,11 @@ func TestBackend_RoleUpgrade(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := storage.Put(entry); err != nil {
if err := storage.Put(context.Background(), entry); err != nil {
t.Fatal(err)
}
role, err := backend.Role(storage, "test")
role, err := backend.Role(context.Background(), storage, "test")
if err != nil {
t.Fatal(err)
}
@@ -152,11 +152,11 @@ func TestBackend_RoleUpgrade(t *testing.T) {
Key: "role/test",
Value: []byte(badJSON),
}
if err := storage.Put(entry); err != nil {
if err := storage.Put(context.Background(), entry); err != nil {
t.Fatal(err)
}
role, err = backend.Role(storage, "test")
role, err = backend.Role(context.Background(), storage, "test")
if err != nil {
t.Fatal(err)
}
@@ -177,11 +177,11 @@ func TestBackend_config_connection(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
defer b.Cleanup(context.Background())
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
@@ -241,11 +241,11 @@ func TestBackend_basic(t *testing.T) {
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
defer b.Cleanup(context.Background())
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
@@ -399,11 +399,11 @@ func TestBackend_connectionCrud(t *testing.T) {
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
defer b.Cleanup(context.Background())
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
@@ -544,11 +544,11 @@ func TestBackend_roleCrud(t *testing.T) {
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
defer b.Cleanup(context.Background())
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
@@ -656,11 +656,11 @@ func TestBackend_allowedRoles(t *testing.T) {
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
defer b.Cleanup(context.Background())
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()

View File

@@ -1,6 +1,7 @@
package dbplugin
import (
"context"
"errors"
"sync"
@@ -30,13 +31,13 @@ func (dc *DatabasePluginClient) Close() error {
// newPluginClient returns a databaseRPCClient with a connection to a running
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
// plugin is killed on call of Close().
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (Database, error) {
func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (Database, error) {
// pluginMap is the map of plugins we can dispense.
var pluginMap = map[string]plugin.Plugin{
"database": new(DatabasePlugin),
}
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
client, err := pluginRunner.Run(ctx, sys, pluginMap, handshakeConfig, []string{}, logger)
if err != nil {
return nil, err
}

View File

@@ -26,9 +26,9 @@ type Database interface {
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(pluginName)
pluginRunner, err := sys.LookupPlugin(ctx, pluginName)
if err != nil {
return nil, err
}
@@ -53,7 +53,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
} else {
// create a DatabasePluginClient instance
db, err = newPluginClient(sys, pluginRunner, logger)
db, err = newPluginClient(ctx, sys, pluginRunner, logger)
if err != nil {
return nil, err
}

View File

@@ -136,7 +136,7 @@ func TestPlugin_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
dbRaw, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -160,7 +160,7 @@ func TestPlugin_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -200,7 +200,7 @@ func TestPlugin_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -234,7 +234,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -276,7 +276,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
dbRaw, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -300,7 +300,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -340,7 +340,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -374,7 +374,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@@ -131,7 +131,7 @@ func pathListPluginConnection(b *databaseBackend) *framework.Path {
func (b *databaseBackend) connectionListHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("config/")
entries, err := req.Storage.List(ctx, "config/")
if err != nil {
return nil, err
}
@@ -148,7 +148,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyName), nil
}
entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name))
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, errors.New("failed to read connection configuration")
}
@@ -174,7 +174,7 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyName), nil
}
err := req.Storage.Delete(fmt.Sprintf("config/%s", name))
err := req.Storage.Delete(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, errors.New("failed to delete connection configuration")
}
@@ -226,7 +226,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
AllowedRoles: allowedRoles,
}
db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
@@ -252,7 +252,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -35,7 +35,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
name := data.Get("name").(string)
// Get the role
role, err := b.Role(req.Storage, name)
role, err := b.Role(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@@ -43,7 +43,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName)
dbConfig, err := b.DatabaseConfig(ctx, req.Storage, role.DBName)
if err != nil {
return nil, err
}

View File

@@ -87,7 +87,7 @@ func pathRoles(b *databaseBackend) *framework.Path {
func (b *databaseBackend) pathRoleDelete() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("role/" + data.Get("name").(string))
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
@@ -98,7 +98,7 @@ func (b *databaseBackend) pathRoleDelete() framework.OperationFunc {
func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(req.Storage, data.Get("name").(string))
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
@@ -122,7 +122,7 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
func (b *databaseBackend) pathRoleList() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("role/")
entries, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
@@ -172,7 +172,7 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@@ -34,7 +34,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
}
role, err := b.Role(req.Storage, roleNameRaw.(string))
role, err := b.Role(ctx, req.Storage, roleNameRaw.(string))
if err != nil {
return nil, err
}
@@ -99,7 +99,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
return nil, fmt.Errorf("no role name was provided")
}
role, err := b.Role(req.Storage, roleNameRaw.(string))
role, err := b.Role(ctx, req.Storage, roleNameRaw.(string))
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package mongodb
import (
"context"
"fmt"
"strings"
"sync"
@@ -11,9 +12,9 @@ import (
"gopkg.in/mgo.v2"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@@ -59,7 +60,7 @@ type backend struct {
}
// Session returns the database connection.
func (b *backend) Session(s logical.Storage) (*mgo.Session, error) {
func (b *backend) Session(ctx context.Context, s logical.Storage) (*mgo.Session, error) {
b.lock.Lock()
defer b.lock.Unlock()
@@ -70,7 +71,7 @@ func (b *backend) Session(s logical.Storage) (*mgo.Session, error) {
b.session.Close()
}
connConfigJSON, err := s.Get("config/connection")
connConfigJSON, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
@@ -99,7 +100,7 @@ func (b *backend) Session(s logical.Storage) (*mgo.Session, error) {
}
// ResetSession forces creation of a new connection next time Session() is called.
func (b *backend) ResetSession() {
func (b *backend) ResetSession(_ context.Context) {
b.lock.Lock()
defer b.lock.Unlock()
@@ -110,16 +111,16 @@ func (b *backend) ResetSession() {
b.session = nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/connection":
b.ResetSession()
b.ResetSession(ctx)
}
}
// LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
func (b *backend) LeaseConfig(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}

Some files were not shown because too many files have changed in this diff Show More