mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Add context to storage backends and wire it through a lot of places (#3817)
This commit is contained in:
committed by
Jeff Mitchell
parent
2864fbd697
commit
8142b42d95
@@ -1,6 +1,8 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
@@ -14,13 +16,13 @@ type Backend interface {
|
||||
// request is authorized but before the request is executed. The arguments
|
||||
// MUST not be modified in anyway. They should be deep copied if this is
|
||||
// a possibility.
|
||||
LogRequest(*logical.Auth, *logical.Request, error) error
|
||||
LogRequest(context.Context, *logical.Auth, *logical.Request, error) error
|
||||
|
||||
// LogResponse is used to synchronously log a response. This is done after
|
||||
// the request is processed but before the response is sent. The arguments
|
||||
// MUST not be modified in anyway. They should be deep copied if this is
|
||||
// a possibility.
|
||||
LogResponse(*logical.Auth, *logical.Request, *logical.Response, error) error
|
||||
LogResponse(context.Context, *logical.Auth, *logical.Request, *logical.Response, error) error
|
||||
|
||||
// GetHash is used to return the given data with the backend's hash,
|
||||
// so that a caller can determine if a value in the audit log matches
|
||||
@@ -28,10 +30,10 @@ type Backend interface {
|
||||
GetHash(string) (string, error)
|
||||
|
||||
// Reload is called on SIGHUP for supporting backends.
|
||||
Reload() error
|
||||
Reload(context.Context) error
|
||||
|
||||
// Invalidate is called for path invalidation
|
||||
Invalidate()
|
||||
Invalidate(context.Context)
|
||||
}
|
||||
|
||||
type BackendConfig struct {
|
||||
@@ -46,4 +48,4 @@ type BackendConfig struct {
|
||||
}
|
||||
|
||||
// Factory is the factory function to create an audit backend.
|
||||
type Factory func(*BackendConfig) (Backend, error)
|
||||
type Factory func(context.Context, *BackendConfig) (Backend, error)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"reflect"
|
||||
@@ -94,7 +95,7 @@ func TestCopy_response(t *testing.T) {
|
||||
|
||||
func TestHashString(t *testing.T) {
|
||||
inmemStorage := &logical.InmemStorage{}
|
||||
inmemStorage.Put(&logical.StorageEntry{
|
||||
inmemStorage.Put(context.Background(), &logical.StorageEntry{
|
||||
Key: "salt",
|
||||
Value: []byte("foo"),
|
||||
})
|
||||
@@ -192,7 +193,7 @@ func TestHash(t *testing.T) {
|
||||
}
|
||||
|
||||
inmemStorage := &logical.InmemStorage{}
|
||||
inmemStorage.Put(&logical.StorageEntry{
|
||||
inmemStorage.Put(context.Background(), &logical.StorageEntry{
|
||||
Key: "salt",
|
||||
Value: []byte("foo"),
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
if conf.SaltConfig == nil {
|
||||
return nil, fmt.Errorf("nil salt config")
|
||||
}
|
||||
@@ -168,7 +169,12 @@ func (b *Backend) GetHash(data string) (string, error) {
|
||||
return audit.HashString(salt, data), nil
|
||||
}
|
||||
|
||||
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
|
||||
func (b *Backend) LogRequest(
|
||||
_ context.Context,
|
||||
auth *logical.Auth,
|
||||
req *logical.Request,
|
||||
outerErr error) error {
|
||||
|
||||
b.fileLock.Lock()
|
||||
defer b.fileLock.Unlock()
|
||||
|
||||
@@ -199,6 +205,7 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
|
||||
}
|
||||
|
||||
func (b *Backend) LogResponse(
|
||||
_ context.Context,
|
||||
auth *logical.Auth,
|
||||
req *logical.Request,
|
||||
resp *logical.Response,
|
||||
@@ -264,7 +271,7 @@ func (b *Backend) open() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Backend) Reload() error {
|
||||
func (b *Backend) Reload(_ context.Context) error {
|
||||
switch b.path {
|
||||
case "stdout", "discard":
|
||||
return nil
|
||||
@@ -288,7 +295,7 @@ func (b *Backend) Reload() error {
|
||||
return b.open()
|
||||
}
|
||||
|
||||
func (b *Backend) Invalidate() {
|
||||
func (b *Backend) Invalidate(_ context.Context) {
|
||||
b.saltMutex.Lock()
|
||||
defer b.saltMutex.Unlock()
|
||||
b.salt = nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -33,7 +34,7 @@ func TestAuditFile_fileModeNew(t *testing.T) {
|
||||
"mode": modeStr,
|
||||
}
|
||||
|
||||
_, err = Factory(&audit.BackendConfig{
|
||||
_, err = Factory(context.Background(), &audit.BackendConfig{
|
||||
SaltConfig: &salt.Config{},
|
||||
SaltView: &logical.InmemStorage{},
|
||||
Config: config,
|
||||
@@ -72,7 +73,7 @@ func TestAuditFile_fileModeExisting(t *testing.T) {
|
||||
"path": f.Name(),
|
||||
}
|
||||
|
||||
_, err = Factory(&audit.BackendConfig{
|
||||
_, err = Factory(context.Background(), &audit.BackendConfig{
|
||||
Config: config,
|
||||
SaltConfig: &salt.Config{},
|
||||
SaltView: &logical.InmemStorage{},
|
||||
|
||||
@@ -2,6 +2,7 @@ package socket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
if conf.SaltConfig == nil {
|
||||
return nil, fmt.Errorf("nil salt config")
|
||||
}
|
||||
@@ -128,7 +129,7 @@ func (b *Backend) GetHash(data string) (string, error) {
|
||||
return audit.HashString(salt, data), nil
|
||||
}
|
||||
|
||||
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
|
||||
func (b *Backend) LogRequest(ctx context.Context, auth *logical.Auth, req *logical.Request, outerErr error) error {
|
||||
var buf bytes.Buffer
|
||||
if err := b.formatter.FormatRequest(&buf, b.formatConfig, auth, req, outerErr); err != nil {
|
||||
return err
|
||||
@@ -137,21 +138,21 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
err := b.write(buf.Bytes())
|
||||
err := b.write(ctx, buf.Bytes())
|
||||
if err != nil {
|
||||
rErr := b.reconnect()
|
||||
rErr := b.reconnect(ctx)
|
||||
if rErr != nil {
|
||||
err = multierror.Append(err, rErr)
|
||||
} else {
|
||||
// Try once more after reconnecting
|
||||
err = b.write(buf.Bytes())
|
||||
err = b.write(ctx, buf.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
|
||||
func (b *Backend) LogResponse(ctx context.Context, auth *logical.Auth, req *logical.Request,
|
||||
resp *logical.Response, outerErr error) error {
|
||||
var buf bytes.Buffer
|
||||
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, outerErr); err != nil {
|
||||
@@ -161,23 +162,23 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
err := b.write(buf.Bytes())
|
||||
err := b.write(ctx, buf.Bytes())
|
||||
if err != nil {
|
||||
rErr := b.reconnect()
|
||||
rErr := b.reconnect(ctx)
|
||||
if rErr != nil {
|
||||
err = multierror.Append(err, rErr)
|
||||
} else {
|
||||
// Try once more after reconnecting
|
||||
err = b.write(buf.Bytes())
|
||||
err = b.write(ctx, buf.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) write(buf []byte) error {
|
||||
func (b *Backend) write(ctx context.Context, buf []byte) error {
|
||||
if b.connection == nil {
|
||||
if err := b.reconnect(); err != nil {
|
||||
if err := b.reconnect(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -195,13 +196,14 @@ func (b *Backend) write(buf []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) reconnect() error {
|
||||
func (b *Backend) reconnect(ctx context.Context) error {
|
||||
if b.connection != nil {
|
||||
b.connection.Close()
|
||||
b.connection = nil
|
||||
}
|
||||
|
||||
conn, err := net.Dial(b.socketType, b.address)
|
||||
dialer := net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, b.socketType, b.address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -211,11 +213,11 @@ func (b *Backend) reconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Backend) Reload() error {
|
||||
func (b *Backend) Reload(ctx context.Context) error {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
err := b.reconnect()
|
||||
err := b.reconnect(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -240,7 +242,7 @@ func (b *Backend) Salt() (*salt.Salt, error) {
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Invalidate() {
|
||||
func (b *Backend) Invalidate(_ context.Context) {
|
||||
b.saltMutex.Lock()
|
||||
defer b.saltMutex.Unlock()
|
||||
b.salt = nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package syslog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
if conf.SaltConfig == nil {
|
||||
return nil, fmt.Errorf("nil salt config")
|
||||
}
|
||||
@@ -115,7 +116,7 @@ func (b *Backend) GetHash(data string) (string, error) {
|
||||
return audit.HashString(salt, data), nil
|
||||
}
|
||||
|
||||
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
|
||||
func (b *Backend) LogRequest(_ context.Context, auth *logical.Auth, req *logical.Request, outerErr error) error {
|
||||
var buf bytes.Buffer
|
||||
if err := b.formatter.FormatRequest(&buf, b.formatConfig, auth, req, outerErr); err != nil {
|
||||
return err
|
||||
@@ -126,7 +127,7 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) error {
|
||||
func (b *Backend) LogResponse(_ context.Context, auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) error {
|
||||
var buf bytes.Buffer
|
||||
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, err); err != nil {
|
||||
return err
|
||||
@@ -137,7 +138,7 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *lo
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) Reload() error {
|
||||
func (b *Backend) Reload(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -161,7 +162,7 @@ func (b *Backend) Salt() (*salt.Salt, error) {
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Invalidate() {
|
||||
func (b *Backend) Invalidate(_ context.Context) {
|
||||
b.saltMutex.Lock()
|
||||
defer b.saltMutex.Unlock()
|
||||
b.salt = nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package appId
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
@@ -8,12 +9,12 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b, err := Backend(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -115,7 +116,7 @@ func (b *backend) Salt() (*salt.Salt, error) {
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(_ context.Context, key string) {
|
||||
switch key {
|
||||
case salt.DefaultLocation:
|
||||
b.SaltMutex.Lock()
|
||||
|
||||
@@ -14,13 +14,13 @@ func TestBackend_basic(t *testing.T) {
|
||||
var b *backend
|
||||
var err error
|
||||
var storage logical.Storage
|
||||
factory := func(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b, err = Backend(conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage = conf.StorageView
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
|
||||
@@ -84,7 +84,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
|
||||
userId := data.Get("user_id").(string)
|
||||
|
||||
var displayName string
|
||||
if dispName, resp, err := b.verifyCredentials(req, appId, userId); err != nil {
|
||||
if dispName, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
@@ -93,7 +93,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
|
||||
}
|
||||
|
||||
// Get the policies associated with the app
|
||||
policies, err := b.MapAppId.Policies(req.Storage, appId)
|
||||
policies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -131,14 +131,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
|
||||
// Skipping CIDR verification to enable renewal from machines other than
|
||||
// the ones encompassed by CIDR block.
|
||||
if _, resp, err := b.verifyCredentials(req, appId, userId); err != nil {
|
||||
if _, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Get the policies associated with the app
|
||||
mapPolicies, err := b.MapAppId.Policies(req.Storage, appId)
|
||||
mapPolicies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -149,14 +149,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
|
||||
}
|
||||
|
||||
func (b *backend) verifyCredentials(req *logical.Request, appId, userId string) (string, *logical.Response, error) {
|
||||
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, appId, userId string) (string, *logical.Response, error) {
|
||||
// Ensure both appId and userId are provided
|
||||
if appId == "" || userId == "" {
|
||||
return "", logical.ErrorResponse("missing 'app_id' or 'user_id'"), nil
|
||||
}
|
||||
|
||||
// Look up the apps that this user is allowed to access
|
||||
appsMap, err := b.MapUserId.Get(req.Storage, userId)
|
||||
appsMap, err := b.MapUserId.Get(ctx, req.Storage, userId)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@@ -205,7 +205,7 @@ func (b *backend) verifyCredentials(req *logical.Request, appId, userId string)
|
||||
}
|
||||
|
||||
// Get the raw data associated with the app
|
||||
appRaw, err := b.MapAppId.Get(req.Storage, appId)
|
||||
appRaw, err := b.MapAppId.Get(ctx, req.Storage, appId)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package approle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/locksutil"
|
||||
@@ -49,12 +50,12 @@ type backend struct {
|
||||
secretIDListingLock sync.RWMutex
|
||||
}
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b, err := Backend(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -125,7 +126,7 @@ func (b *backend) Salt() (*salt.Salt, error) {
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(_ context.Context, key string) {
|
||||
switch key {
|
||||
case salt.DefaultLocation:
|
||||
b.saltMutex.Lock()
|
||||
@@ -139,9 +140,9 @@ func (b *backend) invalidate(key string) {
|
||||
// This could mean that the SecretID may live in the backend upto 1 min after its
|
||||
// expiration. The deletion of SecretIDs are not security sensitive and it is okay
|
||||
// to delay the removal of SecretIDs by a minute.
|
||||
func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
|
||||
// Initiate clean-up of expired SecretID entries
|
||||
b.tidySecretID(req.Storage)
|
||||
b.tidySecretID(ctx, req.Storage)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package approle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -17,7 +18,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
|
||||
if b == nil {
|
||||
t.Fatalf("failed to create backend")
|
||||
}
|
||||
err = b.Backend.Setup(config)
|
||||
err = b.Backend.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (b *backend) pathLoginUpdateAliasLookahead(ctx context.Context, req *logica
|
||||
// Returns the Auth object indicating the authentication and authorization information
|
||||
// if the credentials provided are validated by the backend.
|
||||
func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, roleName, metadata, _, err := b.validateCredentials(req, data)
|
||||
role, roleName, metadata, _, err := b.validateCredentials(ctx, req, data)
|
||||
if err != nil || role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to validate credentials: %v", err)), nil
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, data
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Ensure that the Role still exists.
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err)
|
||||
}
|
||||
|
||||
@@ -523,7 +523,7 @@ func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Reque
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -538,7 +538,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
roles, err := req.Storage.List("role/")
|
||||
roles, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -557,7 +557,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Get the role entry
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -580,7 +580,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
|
||||
|
||||
// Listing works one level at a time. Get the first level of data
|
||||
// which could then be used to get the actual SecretID storage entries.
|
||||
secretIDHMACs, err := req.Storage.List(fmt.Sprintf("secret_id/%s/", roleNameHMAC))
|
||||
secretIDHMACs, err := req.Storage.List(ctx, fmt.Sprintf("secret_id/%s/", roleNameHMAC))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -606,7 +606,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
|
||||
secretIDLock.RLock()
|
||||
|
||||
result := secretIDStorageEntry{}
|
||||
if entry, err := req.Storage.Get(entryIndex); err != nil {
|
||||
if entry, err := req.Storage.Get(ctx, entryIndex); err != nil {
|
||||
secretIDLock.RUnlock()
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
@@ -643,7 +643,7 @@ func validateRoleConstraints(role *roleStorageEntry) error {
|
||||
|
||||
// setRoleEntry persists the role and creates an index from roleID to role
|
||||
// name.
|
||||
func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
|
||||
func (b *backend) setRoleEntry(ctx context.Context, s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
|
||||
if roleName == "" {
|
||||
return fmt.Errorf("missing role name")
|
||||
}
|
||||
@@ -667,7 +667,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
|
||||
}
|
||||
|
||||
// Check if the index from the role_id to role already exists
|
||||
roleIDIndex, err := b.roleIDEntry(s, role.RoleID)
|
||||
roleIDIndex, err := b.roleIDEntry(ctx, s, role.RoleID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read role_id index: %v", err)
|
||||
}
|
||||
@@ -680,13 +680,13 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
|
||||
// When role_id is getting updated, delete the old index before
|
||||
// a new one is created
|
||||
if previousRoleID != "" && previousRoleID != role.RoleID {
|
||||
if err = b.roleIDEntryDelete(s, previousRoleID); err != nil {
|
||||
if err = b.roleIDEntryDelete(ctx, s, previousRoleID); err != nil {
|
||||
return fmt.Errorf("failed to delete previous role ID index")
|
||||
}
|
||||
}
|
||||
|
||||
// Save the role entry only after all the validations
|
||||
if err = s.Put(entry); err != nil {
|
||||
if err = s.Put(ctx, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -697,20 +697,20 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
|
||||
|
||||
// Create a storage entry for reverse mapping of RoleID to role.
|
||||
// Note that secondary index is created when the roleLock is held.
|
||||
return b.setRoleIDEntry(s, role.RoleID, &roleIDStorageEntry{
|
||||
return b.setRoleIDEntry(ctx, s, role.RoleID, &roleIDStorageEntry{
|
||||
Name: roleName,
|
||||
})
|
||||
}
|
||||
|
||||
// roleEntry reads the role from storage
|
||||
func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) {
|
||||
func (b *backend) roleEntry(ctx context.Context, s logical.Storage, roleName string) (*roleStorageEntry, error) {
|
||||
if roleName == "" {
|
||||
return nil, fmt.Errorf("missing role_name")
|
||||
}
|
||||
|
||||
var role roleStorageEntry
|
||||
|
||||
if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil {
|
||||
if entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
return nil, nil
|
||||
@@ -734,7 +734,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
|
||||
defer lock.Unlock()
|
||||
|
||||
// Check if the role already exists
|
||||
role, err := b.roleEntry(req.Storage, roleName)
|
||||
role, err := b.roleEntry(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -855,7 +855,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
|
||||
}
|
||||
|
||||
// Store the entry.
|
||||
return resp, b.setRoleEntry(req.Storage, roleName, role, previousRoleID)
|
||||
return resp, b.setRoleEntry(ctx, req.Storage, roleName, role, previousRoleID)
|
||||
}
|
||||
|
||||
// pathRoleRead grabs a read lock and reads the options set on the role from the storage
|
||||
@@ -869,7 +869,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
|
||||
lock.RLock()
|
||||
lockRelease := lock.RUnlock
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, err
|
||||
@@ -902,7 +902,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
|
||||
|
||||
// For sanity, verify that the index still exists. If the index is missing,
|
||||
// add one and return a warning so it can be reported.
|
||||
roleIDIndex, err := b.roleIDEntry(req.Storage, role.RoleID)
|
||||
roleIDIndex, err := b.roleIDEntry(ctx, req.Storage, role.RoleID)
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, err
|
||||
@@ -915,7 +915,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
|
||||
lockRelease = lock.Unlock
|
||||
|
||||
// Check again if the index is missing
|
||||
roleIDIndex, err = b.roleIDEntry(req.Storage, role.RoleID)
|
||||
roleIDIndex, err = b.roleIDEntry(ctx, req.Storage, role.RoleID)
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, err
|
||||
@@ -923,7 +923,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
|
||||
|
||||
if roleIDIndex == nil {
|
||||
// Create a new index
|
||||
err = b.setRoleIDEntry(req.Storage, role.RoleID, &roleIDStorageEntry{
|
||||
err = b.setRoleIDEntry(ctx, req.Storage, role.RoleID, &roleIDStorageEntry{
|
||||
Name: roleName,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -950,7 +950,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -959,17 +959,17 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
|
||||
}
|
||||
|
||||
// Just before the role is deleted, remove all the SecretIDs issued as part of the role.
|
||||
if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil {
|
||||
if err = b.flushRoleSecrets(ctx, req.Storage, roleName, role.HMACKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role %q: %v", roleName, err)
|
||||
}
|
||||
|
||||
// Delete the reverse mapping from RoleID to the role
|
||||
if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil {
|
||||
if err = b.roleIDEntryDelete(ctx, req.Storage, role.RoleID); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role %q: %v", roleName, err)
|
||||
}
|
||||
|
||||
// After deleting the SecretIDs and the RoleID, delete the role itself
|
||||
if err = req.Storage.Delete("role/" + strings.ToLower(roleName)); err != nil {
|
||||
if err = req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -993,7 +993,7 @@ func (b *backend) pathRoleSecretIDLookupUpdate(ctx context.Context, req *logical
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Fetch the role
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1020,16 +1020,16 @@ func (b *backend) pathRoleSecretIDLookupUpdate(ctx context.Context, req *logical
|
||||
// Create the index at which the secret_id would've been stored
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
|
||||
|
||||
return b.secretIDCommon(req.Storage, entryIndex, secretIDHMAC)
|
||||
return b.secretIDCommon(ctx, req.Storage, entryIndex, secretIDHMAC)
|
||||
}
|
||||
|
||||
func (b *backend) secretIDCommon(s logical.Storage, entryIndex, secretIDHMAC string) (*logical.Response, error) {
|
||||
func (b *backend) secretIDCommon(ctx context.Context, s logical.Storage, entryIndex, secretIDHMAC string) (*logical.Response, error) {
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
result := secretIDStorageEntry{}
|
||||
if entry, err := s.Get(entryIndex); err != nil {
|
||||
if entry, err := s.Get(ctx, entryIndex); err != nil {
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
return nil, nil
|
||||
@@ -1075,7 +1075,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
|
||||
roleLock.RLock()
|
||||
defer roleLock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1100,7 +1100,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
|
||||
defer lock.Unlock()
|
||||
|
||||
result := secretIDStorageEntry{}
|
||||
if entry, err := req.Storage.Get(entryIndex); err != nil {
|
||||
if entry, err := req.Storage.Get(ctx, entryIndex); err != nil {
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
return nil, nil
|
||||
@@ -1109,12 +1109,12 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
|
||||
}
|
||||
|
||||
// Delete the accessor of the SecretID first
|
||||
if err := b.deleteSecretIDAccessorEntry(req.Storage, result.SecretIDAccessor); err != nil {
|
||||
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, result.SecretIDAccessor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete the storage entry that corresponds to the SecretID
|
||||
if err := req.Storage.Delete(entryIndex); err != nil {
|
||||
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
|
||||
}
|
||||
|
||||
@@ -1142,7 +1142,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1150,7 +1150,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
|
||||
return nil, fmt.Errorf("role %q does not exist", roleName)
|
||||
}
|
||||
|
||||
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
|
||||
accessorEntry, err := b.secretIDAccessorEntry(ctx, req.Storage, secretIDAccessor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1165,7 +1165,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
|
||||
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
|
||||
|
||||
return b.secretIDCommon(req.Storage, entryIndex, accessorEntry.SecretIDHMAC)
|
||||
return b.secretIDCommon(ctx, req.Storage, entryIndex, accessorEntry.SecretIDHMAC)
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1183,7 +1183,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
|
||||
// Get the role details to fetch the RoleID and accessor to get
|
||||
// the HMACed SecretID.
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1191,7 +1191,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
|
||||
return nil, fmt.Errorf("role %q does not exist", roleName)
|
||||
}
|
||||
|
||||
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
|
||||
accessorEntry, err := b.secretIDAccessorEntry(ctx, req.Storage, secretIDAccessor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1211,12 +1211,12 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
|
||||
defer lock.Unlock()
|
||||
|
||||
// Delete the accessor of the SecretID first
|
||||
if err := b.deleteSecretIDAccessorEntry(req.Storage, secretIDAccessor); err != nil {
|
||||
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, secretIDAccessor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete the storage entry that corresponds to the SecretID
|
||||
if err := req.Storage.Delete(entryIndex); err != nil {
|
||||
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
|
||||
}
|
||||
|
||||
@@ -1234,7 +1234,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical.
|
||||
defer lock.Unlock()
|
||||
|
||||
// Re-read the role after grabbing the lock
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1257,7 +1257,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical.
|
||||
}
|
||||
}
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleBoundCIDRListRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1270,7 +1270,7 @@ func (b *backend) pathRoleBoundCIDRListRead(ctx context.Context, req *logical.Re
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1293,7 +1293,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical.
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1304,7 +1304,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical.
|
||||
// Deleting a field implies setting the value to it's default value.
|
||||
role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1317,7 +1317,7 @@ func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.R
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1327,7 +1327,7 @@ func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.R
|
||||
|
||||
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
|
||||
role.BindSecretID = bindSecretIDRaw.(bool)
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing bind_secret_id"), nil
|
||||
}
|
||||
@@ -1343,7 +1343,7 @@ func (b *backend) pathRoleBindSecretIDRead(ctx context.Context, req *logical.Req
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1366,7 +1366,7 @@ func (b *backend) pathRoleBindSecretIDDelete(ctx context.Context, req *logical.R
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1377,7 +1377,7 @@ func (b *backend) pathRoleBindSecretIDDelete(ctx context.Context, req *logical.R
|
||||
// Deleting a field implies setting the value to it's default value.
|
||||
role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1390,7 +1390,7 @@ func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Reque
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1405,7 +1405,7 @@ func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Reque
|
||||
|
||||
role.Policies = policyutil.ParsePolicies(policiesRaw)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRolePoliciesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1418,7 +1418,7 @@ func (b *backend) pathRolePoliciesRead(ctx context.Context, req *logical.Request
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1441,7 +1441,7 @@ func (b *backend) pathRolePoliciesDelete(ctx context.Context, req *logical.Reque
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1451,7 +1451,7 @@ func (b *backend) pathRolePoliciesDelete(ctx context.Context, req *logical.Reque
|
||||
|
||||
role.Policies = []string{}
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1464,7 +1464,7 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logica
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1477,7 +1477,7 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logica
|
||||
if role.SecretIDNumUses < 0 {
|
||||
return logical.ErrorResponse("secret_id_num_uses cannot be negative"), nil
|
||||
}
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing secret_id_num_uses"), nil
|
||||
}
|
||||
@@ -1493,7 +1493,7 @@ func (b *backend) pathRoleRoleIDUpdate(ctx context.Context, req *logical.Request
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1507,7 +1507,7 @@ func (b *backend) pathRoleRoleIDUpdate(ctx context.Context, req *logical.Request
|
||||
return logical.ErrorResponse("missing role_id"), nil
|
||||
}
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, previousRoleID)
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, previousRoleID)
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRoleIDRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1520,7 +1520,7 @@ func (b *backend) pathRoleRoleIDRead(ctx context.Context, req *logical.Request,
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1543,7 +1543,7 @@ func (b *backend) pathRoleSecretIDNumUsesRead(ctx context.Context, req *logical.
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1566,7 +1566,7 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(ctx context.Context, req *logica
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1576,7 +1576,7 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(ctx context.Context, req *logica
|
||||
|
||||
role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1589,7 +1589,7 @@ func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Re
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1599,7 +1599,7 @@ func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Re
|
||||
|
||||
if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok {
|
||||
role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int))
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing secret_id_ttl"), nil
|
||||
}
|
||||
@@ -1615,7 +1615,7 @@ func (b *backend) pathRoleSecretIDTTLRead(ctx context.Context, req *logical.Requ
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1639,7 +1639,7 @@ func (b *backend) pathRoleSecretIDTTLDelete(ctx context.Context, req *logical.Re
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1649,7 +1649,7 @@ func (b *backend) pathRoleSecretIDTTLDelete(ctx context.Context, req *logical.Re
|
||||
|
||||
role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1662,7 +1662,7 @@ func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1675,7 +1675,7 @@ func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request
|
||||
if role.Period > b.System().MaxLeaseTTL() {
|
||||
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
|
||||
}
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing period"), nil
|
||||
}
|
||||
@@ -1691,7 +1691,7 @@ func (b *backend) pathRolePeriodRead(ctx context.Context, req *logical.Request,
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1715,7 +1715,7 @@ func (b *backend) pathRolePeriodDelete(ctx context.Context, req *logical.Request
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1725,7 +1725,7 @@ func (b *backend) pathRolePeriodDelete(ctx context.Context, req *logical.Request
|
||||
|
||||
role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1738,7 +1738,7 @@ func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.R
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1748,7 +1748,7 @@ func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.R
|
||||
|
||||
if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok {
|
||||
role.TokenNumUses = tokenNumUsesRaw.(int)
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing token_num_uses"), nil
|
||||
}
|
||||
@@ -1764,7 +1764,7 @@ func (b *backend) pathRoleTokenNumUsesRead(ctx context.Context, req *logical.Req
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1787,7 +1787,7 @@ func (b *backend) pathRoleTokenNumUsesDelete(ctx context.Context, req *logical.R
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1797,7 +1797,7 @@ func (b *backend) pathRoleTokenNumUsesDelete(ctx context.Context, req *logical.R
|
||||
|
||||
role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1810,7 +1810,7 @@ func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Reque
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1823,7 +1823,7 @@ func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Reque
|
||||
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
|
||||
return logical.ErrorResponse("token_ttl should not be greater than token_max_ttl"), nil
|
||||
}
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing token_ttl"), nil
|
||||
}
|
||||
@@ -1839,7 +1839,7 @@ func (b *backend) pathRoleTokenTTLRead(ctx context.Context, req *logical.Request
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1863,7 +1863,7 @@ func (b *backend) pathRoleTokenTTLDelete(ctx context.Context, req *logical.Reque
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1873,7 +1873,7 @@ func (b *backend) pathRoleTokenTTLDelete(ctx context.Context, req *logical.Reque
|
||||
|
||||
role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1886,7 +1886,7 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Re
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1899,7 +1899,7 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Re
|
||||
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
|
||||
return logical.ErrorResponse("token_max_ttl should be greater than or equal to token_ttl"), nil
|
||||
}
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
} else {
|
||||
return logical.ErrorResponse("missing token_max_ttl"), nil
|
||||
}
|
||||
@@ -1915,7 +1915,7 @@ func (b *backend) pathRoleTokenMaxTTLRead(ctx context.Context, req *logical.Requ
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
@@ -1939,7 +1939,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(ctx context.Context, req *logical.Re
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1949,7 +1949,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(ctx context.Context, req *logical.Re
|
||||
|
||||
role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleSecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -1978,7 +1978,7 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2026,7 +2026,7 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req
|
||||
roleName = strings.ToLower(roleName)
|
||||
}
|
||||
|
||||
if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
|
||||
if secretIDStorage, err = b.registerSecretIDEntry(ctx, req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
|
||||
return nil, fmt.Errorf("failed to store secret_id: %v", err)
|
||||
}
|
||||
|
||||
@@ -2047,7 +2047,7 @@ func (b *backend) roleLock(roleName string) *locksutil.LockEntry {
|
||||
}
|
||||
|
||||
// setRoleIDEntry creates a storage entry that maps RoleID to Role
|
||||
func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *roleIDStorageEntry) error {
|
||||
func (b *backend) setRoleIDEntry(ctx context.Context, s logical.Storage, roleID string, roleIDEntry *roleIDStorageEntry) error {
|
||||
lock := b.roleIDLock(roleID)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
@@ -2062,14 +2062,14 @@ func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = s.Put(entry); err != nil {
|
||||
if err = s.Put(ctx, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// roleIDEntry is used to read the storage entry that maps RoleID to Role
|
||||
func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageEntry, error) {
|
||||
func (b *backend) roleIDEntry(ctx context.Context, s logical.Storage, roleID string) (*roleIDStorageEntry, error) {
|
||||
if roleID == "" {
|
||||
return nil, fmt.Errorf("missing roleID")
|
||||
}
|
||||
@@ -2086,7 +2086,7 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
|
||||
}
|
||||
entryIndex := "role_id/" + salt.SaltID(roleID)
|
||||
|
||||
if entry, err := s.Get(entryIndex); err != nil {
|
||||
if entry, err := s.Get(ctx, entryIndex); err != nil {
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
return nil, nil
|
||||
@@ -2099,7 +2099,7 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
|
||||
|
||||
// roleIDEntryDelete is used to remove the secondary index that maps the
|
||||
// RoleID to the Role itself.
|
||||
func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
|
||||
func (b *backend) roleIDEntryDelete(ctx context.Context, s logical.Storage, roleID string) error {
|
||||
if roleID == "" {
|
||||
return fmt.Errorf("missing roleID")
|
||||
}
|
||||
@@ -2114,7 +2114,7 @@ func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
|
||||
}
|
||||
entryIndex := "role_id/" + salt.SaltID(roleID)
|
||||
|
||||
return s.Delete(entryIndex)
|
||||
return s.Delete(ctx, entryIndex)
|
||||
}
|
||||
|
||||
var roleHelp = map[string][2]string{
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestApprole_RoleNameLowerCasing(t *testing.T) {
|
||||
Policies: []string{"default"},
|
||||
BindSecretID: true,
|
||||
}
|
||||
err = b.setRoleEntry(storage, "testRoleName", role, "")
|
||||
err = b.setRoleEntry(context.Background(), storage, "testRoleName", role, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -208,7 +208,7 @@ func TestAppRole_RoleReadSetIndex(t *testing.T) {
|
||||
roleID := resp.Data["role_id"].(string)
|
||||
|
||||
// Delete the role ID index
|
||||
err = b.roleIDEntryDelete(storage, roleID)
|
||||
err = b.roleIDEntryDelete(context.Background(), storage, roleID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -225,7 +225,7 @@ func TestAppRole_RoleReadSetIndex(t *testing.T) {
|
||||
t.Fatalf("bad: expected a warning in the response")
|
||||
}
|
||||
|
||||
roleIDIndex, err := b.roleIDEntry(storage, roleID)
|
||||
roleIDIndex, err := b.roleIDEntry(context.Background(), storage, roleID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func pathTidySecretID(b *backend) *framework.Path {
|
||||
}
|
||||
|
||||
// tidySecretID is used to delete entries in the whitelist that are expired.
|
||||
func (b *backend) tidySecretID(s logical.Storage) error {
|
||||
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error {
|
||||
grabbed := atomic.CompareAndSwapUint32(&b.tidySecretIDCASGuard, 0, 1)
|
||||
if grabbed {
|
||||
defer atomic.StoreUint32(&b.tidySecretIDCASGuard, 0)
|
||||
@@ -33,7 +33,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
|
||||
return fmt.Errorf("SecretID tidy operation already running")
|
||||
}
|
||||
|
||||
roleNameHMACs, err := s.List("secret_id/")
|
||||
roleNameHMACs, err := s.List(ctx, "secret_id/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -41,7 +41,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
|
||||
var result error
|
||||
for _, roleNameHMAC := range roleNameHMACs {
|
||||
// roleNameHMAC will already have a '/' suffix. Don't append another one.
|
||||
secretIDHMACs, err := s.List(fmt.Sprintf("secret_id/%s", roleNameHMAC))
|
||||
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("secret_id/%s", roleNameHMAC))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
|
||||
lock.Lock()
|
||||
// roleNameHMAC will already have a '/' suffix. Don't append another one.
|
||||
entryIndex := fmt.Sprintf("secret_id/%s%s", roleNameHMAC, secretIDHMAC)
|
||||
secretIDEntry, err := s.Get(entryIndex)
|
||||
secretIDEntry, err := s.Get(ctx, entryIndex)
|
||||
if err != nil {
|
||||
lock.Unlock()
|
||||
return fmt.Errorf("error fetching SecretID %s: %s", secretIDHMAC, err)
|
||||
@@ -77,7 +77,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
|
||||
|
||||
// ExpirationTime not being set indicates non-expiring SecretIDs
|
||||
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
|
||||
if err := s.Delete(entryIndex); err != nil {
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
lock.Unlock()
|
||||
return fmt.Errorf("error deleting SecretID %s from storage: %s", secretIDHMAC, err)
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
|
||||
|
||||
// pathTidySecretIDUpdate is used to delete the expired SecretID entries
|
||||
func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
return nil, b.tidySecretID(req.Storage)
|
||||
return nil, b.tidySecretID(ctx, req.Storage)
|
||||
}
|
||||
|
||||
const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries."
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package approle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -68,9 +69,9 @@ type secretIDAccessorStorageEntry struct {
|
||||
}
|
||||
|
||||
// Checks if the Role represented by the RoleID still exists
|
||||
func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorageEntry, string, error) {
|
||||
func (b *backend) validateRoleID(ctx context.Context, s logical.Storage, roleID string) (*roleStorageEntry, string, error) {
|
||||
// Look for the storage entry that maps the roleID to role
|
||||
roleIDIndex, err := b.roleIDEntry(s, roleID)
|
||||
roleIDIndex, err := b.roleIDEntry(ctx, s, roleID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -82,7 +83,7 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(s, roleIDIndex.Name)
|
||||
role, err := b.roleEntry(ctx, s, roleIDIndex.Name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -94,7 +95,7 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
|
||||
}
|
||||
|
||||
// Validates the supplied RoleID and SecretID
|
||||
func (b *backend) validateCredentials(req *logical.Request, data *framework.FieldData) (*roleStorageEntry, string, map[string]string, string, error) {
|
||||
func (b *backend) validateCredentials(ctx context.Context, req *logical.Request, data *framework.FieldData) (*roleStorageEntry, string, map[string]string, string, error) {
|
||||
metadata := make(map[string]string)
|
||||
// RoleID must be supplied during every login
|
||||
roleID := strings.TrimSpace(data.Get("role_id").(string))
|
||||
@@ -103,7 +104,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
|
||||
}
|
||||
|
||||
// Validate the RoleID and get the Role entry
|
||||
role, roleName, err := b.validateRoleID(req.Storage, roleID)
|
||||
role, roleName, err := b.validateRoleID(ctx, req.Storage, roleID)
|
||||
if err != nil {
|
||||
return nil, "", metadata, "", err
|
||||
}
|
||||
@@ -132,7 +133,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
|
||||
// Check if the SecretID supplied is valid. If use limit was specified
|
||||
// on the SecretID, it will be decremented in this call.
|
||||
var valid bool
|
||||
valid, metadata, err = b.validateBindSecretID(req, roleName, secretID, role.HMACKey, role.BoundCIDRList)
|
||||
valid, metadata, err = b.validateBindSecretID(ctx, req, roleName, secretID, role.HMACKey, role.BoundCIDRList)
|
||||
if err != nil {
|
||||
return nil, "", metadata, "", err
|
||||
}
|
||||
@@ -160,7 +161,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
|
||||
}
|
||||
|
||||
// validateBindSecretID is used to determine if the given SecretID is a valid one.
|
||||
func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
func (b *backend) validateBindSecretID(ctx context.Context, req *logical.Request, roleName, secretID,
|
||||
hmacKey, roleBoundCIDRList string) (bool, map[string]string, error) {
|
||||
secretIDHMAC, err := createHMAC(hmacKey, secretID)
|
||||
if err != nil {
|
||||
@@ -180,7 +181,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.RLock()
|
||||
|
||||
result, err := b.nonLockedSecretIDStorageEntry(req.Storage, roleNameHMAC, secretIDHMAC)
|
||||
result, err := b.nonLockedSecretIDStorageEntry(ctx, req.Storage, roleNameHMAC, secretIDHMAC)
|
||||
if err != nil {
|
||||
lock.RUnlock()
|
||||
return false, nil, err
|
||||
@@ -225,7 +226,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
defer lock.Unlock()
|
||||
|
||||
// Lock switching may change the data. Refresh the contents.
|
||||
result, err = b.nonLockedSecretIDStorageEntry(req.Storage, roleNameHMAC, secretIDHMAC)
|
||||
result, err = b.nonLockedSecretIDStorageEntry(ctx, req.Storage, roleNameHMAC, secretIDHMAC)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
@@ -238,10 +239,10 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
// requests to use the same SecretID will fail.
|
||||
if result.SecretIDNumUses == 1 {
|
||||
// Delete the secret IDs accessor first
|
||||
if err := b.deleteSecretIDAccessorEntry(req.Storage, result.SecretIDAccessor); err != nil {
|
||||
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, result.SecretIDAccessor); err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
if err := req.Storage.Delete(entryIndex); err != nil {
|
||||
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
|
||||
return false, nil, fmt.Errorf("failed to delete secret ID: %v", err)
|
||||
}
|
||||
} else {
|
||||
@@ -250,7 +251,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
result.LastUpdatedTime = time.Now()
|
||||
if entry, err := logical.StorageEntryJSON(entryIndex, &result); err != nil {
|
||||
return false, nil, fmt.Errorf("failed to decrement the use count for secret ID %q", secretID)
|
||||
} else if err = req.Storage.Put(entry); err != nil {
|
||||
} else if err = req.Storage.Put(ctx, entry); err != nil {
|
||||
return false, nil, fmt.Errorf("failed to decrement the use count for secret ID %q", secretID)
|
||||
}
|
||||
}
|
||||
@@ -320,7 +321,7 @@ func (b *backend) secretIDAccessorLock(secretIDAccessor string) *locksutil.LockE
|
||||
// storage. The entry will be indexed based on the given HMACs of both role
|
||||
// name and the secret ID. This method will not acquire secret ID lock to fetch
|
||||
// the storage entry. Locks need to be acquired before calling this method.
|
||||
func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC, secretIDHMAC string) (*secretIDStorageEntry, error) {
|
||||
func (b *backend) nonLockedSecretIDStorageEntry(ctx context.Context, s logical.Storage, roleNameHMAC, secretIDHMAC string) (*secretIDStorageEntry, error) {
|
||||
if secretIDHMAC == "" {
|
||||
return nil, fmt.Errorf("missing secret ID HMAC")
|
||||
}
|
||||
@@ -332,7 +333,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
|
||||
// Prepare the storage index at which the secret ID will be stored
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
|
||||
|
||||
entry, err := s.Get(entryIndex)
|
||||
entry, err := s.Get(ctx, entryIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -360,7 +361,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
|
||||
}
|
||||
|
||||
if persistNeeded {
|
||||
if err := b.nonLockedSetSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC, &result); err != nil {
|
||||
if err := b.nonLockedSetSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to upgrade role storage entry %s", err)
|
||||
}
|
||||
}
|
||||
@@ -373,7 +374,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
|
||||
// role name and the secret ID. This method will not acquire secret ID lock to
|
||||
// create/update the storage entry. Locks need to be acquired before calling
|
||||
// this method.
|
||||
func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHMAC, secretIDHMAC string, secretEntry *secretIDStorageEntry) error {
|
||||
func (b *backend) nonLockedSetSecretIDStorageEntry(ctx context.Context, s logical.Storage, roleNameHMAC, secretIDHMAC string, secretEntry *secretIDStorageEntry) error {
|
||||
if secretIDHMAC == "" {
|
||||
return fmt.Errorf("missing secret ID HMAC")
|
||||
}
|
||||
@@ -390,7 +391,7 @@ func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHM
|
||||
|
||||
if entry, err := logical.StorageEntryJSON(entryIndex, secretEntry); err != nil {
|
||||
return err
|
||||
} else if err = s.Put(entry); err != nil {
|
||||
} else if err = s.Put(ctx, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -398,7 +399,7 @@ func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHM
|
||||
}
|
||||
|
||||
// registerSecretIDEntry creates a new storage entry for the given SecretID.
|
||||
func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, hmacKey string, secretEntry *secretIDStorageEntry) (*secretIDStorageEntry, error) {
|
||||
func (b *backend) registerSecretIDEntry(ctx context.Context, s logical.Storage, roleName, secretID, hmacKey string, secretEntry *secretIDStorageEntry) (*secretIDStorageEntry, error) {
|
||||
secretIDHMAC, err := createHMAC(hmacKey, secretID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of secret ID: %v", err)
|
||||
@@ -411,7 +412,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.RLock()
|
||||
|
||||
entry, err := b.nonLockedSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC)
|
||||
entry, err := b.nonLockedSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC)
|
||||
if err != nil {
|
||||
lock.RUnlock()
|
||||
return nil, err
|
||||
@@ -428,7 +429,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
|
||||
defer lock.Unlock()
|
||||
|
||||
// But before saving a new entry, check if the secretID entry was created during the lock switch.
|
||||
entry, err = b.nonLockedSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC)
|
||||
entry, err = b.nonLockedSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -457,11 +458,11 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
|
||||
}
|
||||
|
||||
// Before storing the SecretID, store its accessor.
|
||||
if err := b.createSecretIDAccessorEntry(s, secretEntry, secretIDHMAC); err != nil {
|
||||
if err := b.createSecretIDAccessorEntry(ctx, s, secretEntry, secretIDHMAC); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := b.nonLockedSetSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC, secretEntry); err != nil {
|
||||
if err := b.nonLockedSetSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC, secretEntry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -470,7 +471,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
|
||||
|
||||
// secretIDAccessorEntry is used to read the storage entry that maps an
|
||||
// accessor to a secret_id.
|
||||
func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor string) (*secretIDAccessorStorageEntry, error) {
|
||||
func (b *backend) secretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) (*secretIDAccessorStorageEntry, error) {
|
||||
if secretIDAccessor == "" {
|
||||
return nil, fmt.Errorf("missing secretIDAccessor")
|
||||
}
|
||||
@@ -488,7 +489,7 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
|
||||
accessorLock.RLock()
|
||||
defer accessorLock.RUnlock()
|
||||
|
||||
if entry, err := s.Get(entryIndex); err != nil {
|
||||
if entry, err := s.Get(ctx, entryIndex); err != nil {
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
return nil, nil
|
||||
@@ -502,7 +503,7 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
|
||||
// createSecretIDAccessorEntry creates an identifier for the SecretID. A storage index,
|
||||
// mapping the accessor to the SecretID is also created. This method should
|
||||
// be called when the lock for the corresponding SecretID is held.
|
||||
func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretIDStorageEntry, secretIDHMAC string) error {
|
||||
func (b *backend) createSecretIDAccessorEntry(ctx context.Context, s logical.Storage, entry *secretIDStorageEntry, secretIDHMAC string) error {
|
||||
// Create a random accessor
|
||||
accessorUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
@@ -525,7 +526,7 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
|
||||
SecretIDHMAC: secretIDHMAC,
|
||||
}); err != nil {
|
||||
return err
|
||||
} else if err = s.Put(entry); err != nil {
|
||||
} else if err = s.Put(ctx, entry); err != nil {
|
||||
return fmt.Errorf("failed to persist accessor index entry: %v", err)
|
||||
}
|
||||
|
||||
@@ -533,7 +534,7 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
|
||||
}
|
||||
|
||||
// deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID.
|
||||
func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error {
|
||||
func (b *backend) deleteSecretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) error {
|
||||
salt, err := b.Salt()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -545,7 +546,7 @@ func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccesso
|
||||
defer accessorLock.Unlock()
|
||||
|
||||
// Delete the accessor of the SecretID first
|
||||
if err := s.Delete(accessorEntryIndex); err != nil {
|
||||
if err := s.Delete(ctx, accessorEntryIndex); err != nil {
|
||||
return fmt.Errorf("failed to delete accessor storage entry: %v", err)
|
||||
}
|
||||
|
||||
@@ -554,7 +555,7 @@ func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccesso
|
||||
|
||||
// flushRoleSecrets deletes all the SecretIDs that belong to the given
|
||||
// RoleID.
|
||||
func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string) error {
|
||||
func (b *backend) flushRoleSecrets(ctx context.Context, s logical.Storage, roleName, hmacKey string) error {
|
||||
roleNameHMAC, err := createHMAC(hmacKey, roleName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create HMAC of role_name: %v", err)
|
||||
@@ -564,7 +565,7 @@ func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string)
|
||||
b.secretIDListingLock.RLock()
|
||||
defer b.secretIDListingLock.RUnlock()
|
||||
|
||||
secretIDHMACs, err := s.List(fmt.Sprintf("secret_id/%s/", roleNameHMAC))
|
||||
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("secret_id/%s/", roleNameHMAC))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -573,7 +574,7 @@ func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string)
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.Lock()
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
|
||||
if err := s.Delete(entryIndex); err != nil {
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
lock.Unlock()
|
||||
return fmt.Errorf("error deleting SecretID %q from storage: %v", secretIDHMAC, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package awsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,12 +14,12 @@ import (
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b, err := Backend(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -73,7 +74,7 @@ type backend struct {
|
||||
// accounts using their IAM instance profile to get their credentials.
|
||||
defaultAWSAccountID string
|
||||
|
||||
resolveArnToUniqueIDFunc func(logical.Storage, string) (string, error)
|
||||
resolveArnToUniqueIDFunc func(context.Context, logical.Storage, string) (string, error)
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
@@ -138,13 +139,13 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
// not once in a minute, but once in an hour, controlled by 'tidyCooldownPeriod'.
|
||||
// Tidying of blacklist and whitelist are by default enabled. This can be
|
||||
// changed using `config/tidy/roletags` and `config/tidy/identities` endpoints.
|
||||
func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
|
||||
// Run the tidy operations for the first time. Then run it when current
|
||||
// time matches the nextTidyTime.
|
||||
if b.nextTidyTime.IsZero() || !time.Now().Before(b.nextTidyTime) {
|
||||
// safety_buffer defaults to 180 days for roletag blacklist
|
||||
safety_buffer := 15552000
|
||||
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(req.Storage)
|
||||
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -160,12 +161,12 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
}
|
||||
// tidy role tags if explicitly not disabled
|
||||
if !skipBlacklistTidy {
|
||||
b.tidyBlacklistRoleTag(req.Storage, safety_buffer)
|
||||
b.tidyBlacklistRoleTag(ctx, req.Storage, safety_buffer)
|
||||
}
|
||||
|
||||
// reset the safety_buffer to 72h
|
||||
safety_buffer = 259200
|
||||
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(req.Storage)
|
||||
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -181,7 +182,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
}
|
||||
// tidy identities if explicitly not disabled
|
||||
if !skipWhitelistTidy {
|
||||
b.tidyWhitelistIdentity(req.Storage, safety_buffer)
|
||||
b.tidyWhitelistIdentity(ctx, req.Storage, safety_buffer)
|
||||
}
|
||||
|
||||
// Update the time at which to run the tidy functions again.
|
||||
@@ -190,7 +191,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "config/client":
|
||||
b.configMutex.Lock()
|
||||
@@ -203,7 +204,7 @@ func (b *backend) invalidate(key string) {
|
||||
|
||||
// Putting this here so we can inject a fake resolver into the backend for unit testing
|
||||
// purposes
|
||||
func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (string, error) {
|
||||
func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storage, arn string) (string, error) {
|
||||
entity, err := parseIamArn(arn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -223,7 +224,7 @@ func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (strin
|
||||
if region == nil {
|
||||
return "", fmt.Errorf("Unable to resolve partition %q to a region", entity.Partition)
|
||||
}
|
||||
iamClient, err := b.clientIAM(s, region.ID(), entity.AccountNumber)
|
||||
iamClient, err := b.clientIAM(ctx, s, region.ID(), entity.AccountNumber)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -30,7 +30,8 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,7 +56,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
||||
}
|
||||
|
||||
// read the created role entry
|
||||
roleEntry, err := b.lockedAWSRole(storage, "abcd-123")
|
||||
roleEntry, err := b.lockedAWSRole(context.Background(), storage, "abcd-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -83,7 +84,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
||||
}
|
||||
|
||||
// parse the created role tag
|
||||
rTag2, err := b.parseAndVerifyRoleTagValue(storage, val)
|
||||
rTag2, err := b.parseAndVerifyRoleTagValue(context.Background(), storage, val)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -122,7 +123,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
||||
}
|
||||
|
||||
// get the entry of the newly created role entry
|
||||
roleEntry2, err := b.lockedAWSRole(storage, "ami-6789")
|
||||
roleEntry2, err := b.lockedAWSRole(context.Background(), storage, "ami-6789")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -254,7 +255,8 @@ func TestBackend_ConfigTidyIdentities(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -308,7 +310,8 @@ func TestBackend_ConfigTidyRoleTags(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -362,7 +365,8 @@ func TestBackend_TidyIdentities(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -387,7 +391,8 @@ func TestBackend_TidyRoleTags(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -412,7 +417,8 @@ func TestBackend_ConfigClient(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -549,7 +555,8 @@ func TestBackend_pathConfigCertificate(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -704,7 +711,8 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -763,7 +771,7 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
|
||||
tagValue := resp.Data["tag_value"].(string)
|
||||
|
||||
// parse the value and check if the verifiable values match
|
||||
rTag, err := b.parseAndVerifyRoleTagValue(storage, tagValue)
|
||||
rTag, err := b.parseAndVerifyRoleTagValue(context.Background(), storage, tagValue)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -785,7 +793,8 @@ func TestBackend_PathRoleTag(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -850,7 +859,8 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -939,7 +949,7 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
|
||||
}
|
||||
|
||||
// try to read the deleted entry
|
||||
tagEntry, err := b.lockedBlacklistRoleTagEntry(storage, tag)
|
||||
tagEntry, err := b.lockedBlacklistRoleTagEntry(context.Background(), storage, tag)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -998,7 +1008,8 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1190,7 +1201,8 @@ func TestBackend_pathStsConfig(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1338,7 +1350,8 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1442,11 +1455,11 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
|
||||
}
|
||||
|
||||
fakeArn := "arn:aws:iam::123456789012:role/somePath/FakeRole"
|
||||
fakeArnResolver := func(s logical.Storage, arn string) (string, error) {
|
||||
fakeArnResolver := func(ctx context.Context, s logical.Storage, arn string) (string, error) {
|
||||
if arn == fakeArn {
|
||||
return fmt.Sprintf("FakeUniqueIdFor%s", fakeArn), nil
|
||||
}
|
||||
return b.resolveArnToRealUniqueId(s, arn)
|
||||
return b.resolveArnToRealUniqueId(context.Background(), s, arn)
|
||||
}
|
||||
b.resolveArnToUniqueIDFunc = fakeArnResolver
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package awsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
@@ -21,13 +22,13 @@ import (
|
||||
// * Static credentials from 'config/client'
|
||||
// * Environment variables
|
||||
// * Instance metadata role
|
||||
func (b *backend) getRawClientConfig(s logical.Storage, region, clientType string) (*aws.Config, error) {
|
||||
func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
|
||||
credsConfig := &awsutil.CredentialsConfig{
|
||||
Region: region,
|
||||
}
|
||||
|
||||
// Read the configured secret key and access key
|
||||
config, err := b.nonLockedClientConfigEntry(s)
|
||||
config, err := b.nonLockedClientConfigEntry(ctx, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -71,9 +72,9 @@ func (b *backend) getRawClientConfig(s logical.Storage, region, clientType strin
|
||||
// It uses getRawClientConfig to obtain config for the runtime environemnt, and if
|
||||
// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
|
||||
// credentials. The credentials will expire after 15 minutes but will auto-refresh.
|
||||
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
|
||||
func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
|
||||
|
||||
config, err := b.getRawClientConfig(s, region, clientType)
|
||||
config, err := b.getRawClientConfig(ctx, s, region, clientType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID,
|
||||
return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
|
||||
}
|
||||
|
||||
stsConfig, err := b.getRawClientConfig(s, region, "sts")
|
||||
stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
|
||||
if stsConfig == nil {
|
||||
return nil, fmt.Errorf("could not configure STS client")
|
||||
}
|
||||
@@ -160,9 +161,9 @@ func (b *backend) setCachedUserId(userId, arn string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string, error) {
|
||||
func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
|
||||
// Check if an STS configuration exists for the AWS account
|
||||
sts, err := b.lockedAwsStsEntry(s, accountID)
|
||||
sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error fetching STS config for account ID %q: %q\n", accountID, err)
|
||||
}
|
||||
@@ -174,8 +175,8 @@ func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string
|
||||
}
|
||||
|
||||
// clientEC2 creates a client to interact with AWS EC2 API
|
||||
func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.EC2, error) {
|
||||
stsRole, err := b.stsRoleForAccount(s, accountID)
|
||||
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
|
||||
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -198,7 +199,7 @@ func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.E
|
||||
|
||||
// Create an AWS config object using a chain of providers
|
||||
var awsConfig *aws.Config
|
||||
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "ec2")
|
||||
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -223,8 +224,8 @@ func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.E
|
||||
}
|
||||
|
||||
// clientIAM creates a client to interact with AWS IAM API
|
||||
func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.IAM, error) {
|
||||
stsRole, err := b.stsRoleForAccount(s, accountID)
|
||||
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
|
||||
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -247,7 +248,7 @@ func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.I
|
||||
|
||||
// Create an AWS config object using a chain of providers
|
||||
var awsConfig *aws.Config
|
||||
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "iam")
|
||||
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@@ -131,7 +130,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(ctx context.Context, req *
|
||||
return false, fmt.Errorf("missing cert_name")
|
||||
}
|
||||
|
||||
entry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
|
||||
entry, err := b.lockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -143,7 +142,7 @@ func (b *backend) pathCertificatesList(ctx context.Context, req *logical.Request
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
certs, err := req.Storage.List("config/certificate/")
|
||||
certs, err := req.Storage.List(ctx, "config/certificate/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -174,7 +173,7 @@ func decodePEMAndParseCertificate(certificate string) (*x509.Certificate, error)
|
||||
// the PKCS7 signatures of the instance identity documents. This method will
|
||||
// append the certificates registered using `config/certificate/<cert_name>`
|
||||
// endpoint, along with the default certificate in the backend.
|
||||
func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509.Certificate, error) {
|
||||
func (b *backend) awsPublicCertificates(ctx context.Context, s logical.Storage, isPkcs bool) ([]*x509.Certificate, error) {
|
||||
// Lock at beginning and use internal method so that we are consistent as
|
||||
// we iterate through
|
||||
b.configMutex.RLock()
|
||||
@@ -195,14 +194,14 @@ func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509
|
||||
certs = append(certs, decodedCert)
|
||||
|
||||
// Get the list of all the registered certificates
|
||||
registeredCerts, err := s.List("config/certificate/")
|
||||
registeredCerts, err := s.List(ctx, "config/certificate/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate through each certificate, parse and append it to a slice
|
||||
for _, cert := range registeredCerts {
|
||||
certEntry, err := b.nonLockedAWSPublicCertificateEntry(s, cert)
|
||||
certEntry, err := b.nonLockedAWSPublicCertificateEntry(ctx, s, cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -226,7 +225,7 @@ func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509
|
||||
// lockedSetAWSPublicCertificateEntry is used to store the AWS public key in
|
||||
// the storage. This method acquires lock before creating or updating a storage
|
||||
// entry.
|
||||
func (b *backend) lockedSetAWSPublicCertificateEntry(s logical.Storage, certName string, certEntry *awsPublicCert) error {
|
||||
func (b *backend) lockedSetAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string, certEntry *awsPublicCert) error {
|
||||
if certName == "" {
|
||||
return fmt.Errorf("missing certificate name")
|
||||
}
|
||||
@@ -238,13 +237,13 @@ func (b *backend) lockedSetAWSPublicCertificateEntry(s logical.Storage, certName
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
return b.nonLockedSetAWSPublicCertificateEntry(s, certName, certEntry)
|
||||
return b.nonLockedSetAWSPublicCertificateEntry(ctx, s, certName, certEntry)
|
||||
}
|
||||
|
||||
// nonLockedSetAWSPublicCertificateEntry is used to store the AWS public key in
|
||||
// the storage. This method does not acquire lock before reading the storage.
|
||||
// If locking is desired, use lockedSetAWSPublicCertificateEntry instead.
|
||||
func (b *backend) nonLockedSetAWSPublicCertificateEntry(s logical.Storage, certName string, certEntry *awsPublicCert) error {
|
||||
func (b *backend) nonLockedSetAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string, certEntry *awsPublicCert) error {
|
||||
if certName == "" {
|
||||
return fmt.Errorf("missing certificate name")
|
||||
}
|
||||
@@ -261,24 +260,24 @@ func (b *backend) nonLockedSetAWSPublicCertificateEntry(s logical.Storage, certN
|
||||
return fmt.Errorf("failed to create storage entry for AWS public key certificate")
|
||||
}
|
||||
|
||||
return s.Put(entry)
|
||||
return s.Put(ctx, entry)
|
||||
}
|
||||
|
||||
// lockedAWSPublicCertificateEntry is used to get the configured AWS Public Key
|
||||
// that is used to verify the PKCS#7 signature of the instance identity
|
||||
// document.
|
||||
func (b *backend) lockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
func (b *backend) lockedAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.nonLockedAWSPublicCertificateEntry(s, certName)
|
||||
return b.nonLockedAWSPublicCertificateEntry(ctx, s, certName)
|
||||
}
|
||||
|
||||
// nonLockedAWSPublicCertificateEntry reads the certificate information from
|
||||
// the storage. This method does not acquire lock before reading the storage.
|
||||
// If locking is desired, use lockedAWSPublicCertificateEntry instead.
|
||||
func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
entry, err := s.Get("config/certificate/" + certName)
|
||||
func (b *backend) nonLockedAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
entry, err := s.Get(ctx, "config/certificate/"+certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -298,7 +297,7 @@ func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName
|
||||
}
|
||||
|
||||
if persistNeeded {
|
||||
if err := b.nonLockedSetAWSPublicCertificateEntry(s, certName, &certEntry); err != nil {
|
||||
if err := b.nonLockedSetAWSPublicCertificateEntry(ctx, s, certName, &certEntry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -318,7 +317,7 @@ func (b *backend) pathConfigCertificateDelete(ctx context.Context, req *logical.
|
||||
return logical.ErrorResponse("missing cert_name"), nil
|
||||
}
|
||||
|
||||
return nil, req.Storage.Delete("config/certificate/" + certName)
|
||||
return nil, req.Storage.Delete(ctx, "config/certificate/"+certName)
|
||||
}
|
||||
|
||||
// pathConfigCertificateRead is used to view the configured AWS Public Key that
|
||||
@@ -329,7 +328,7 @@ func (b *backend) pathConfigCertificateRead(ctx context.Context, req *logical.Re
|
||||
return logical.ErrorResponse("missing cert_name"), nil
|
||||
}
|
||||
|
||||
certificateEntry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
|
||||
certificateEntry, err := b.lockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -338,7 +337,10 @@ func (b *backend) pathConfigCertificateRead(ctx context.Context, req *logical.Re
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: structs.New(certificateEntry).Map(),
|
||||
Data: map[string]interface{}{
|
||||
"aws_public_cert": certificateEntry.AWSPublicCert,
|
||||
"type": certificateEntry.Type,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -354,7 +356,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(ctx context.Context, req *lo
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
// Check if there is already a certificate entry registered
|
||||
certEntry, err := b.nonLockedAWSPublicCertificateEntry(req.Storage, certName)
|
||||
certEntry, err := b.nonLockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -406,7 +408,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(ctx context.Context, req *lo
|
||||
}
|
||||
|
||||
// If none of the checks fail, save the provided certificate
|
||||
if err := b.nonLockedSetAWSPublicCertificateEntry(req.Storage, certName, certEntry); err != nil {
|
||||
if err := b.nonLockedSetAWSPublicCertificateEntry(ctx, req.Storage, certName, certEntry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ func pathConfigClient(b *backend) *framework.Path {
|
||||
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
|
||||
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
|
||||
func (b *backend) pathConfigClientExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.lockedClientConfigEntry(req.Storage)
|
||||
entry, err := b.lockedClientConfigEntry(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -74,16 +74,16 @@ func (b *backend) pathConfigClientExistenceCheck(ctx context.Context, req *logic
|
||||
}
|
||||
|
||||
// Fetch the client configuration required to access the AWS API, after acquiring an exclusive lock.
|
||||
func (b *backend) lockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
|
||||
func (b *backend) lockedClientConfigEntry(ctx context.Context, s logical.Storage) (*clientConfig, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.nonLockedClientConfigEntry(s)
|
||||
return b.nonLockedClientConfigEntry(ctx, s)
|
||||
}
|
||||
|
||||
// Fetch the client configuration required to access the AWS API.
|
||||
func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
|
||||
entry, err := s.Get("config/client")
|
||||
func (b *backend) nonLockedClientConfigEntry(ctx context.Context, s logical.Storage) (*clientConfig, error) {
|
||||
entry, err := s.Get(ctx, "config/client")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig,
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
clientConfig, err := b.lockedClientConfigEntry(req.Storage)
|
||||
clientConfig, err := b.lockedClientConfigEntry(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (b *backend) pathConfigClientDelete(ctx context.Context, req *logical.Reque
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
if err := req.Storage.Delete("config/client"); err != nil {
|
||||
if err := req.Storage.Delete(ctx, "config/client"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
configEntry, err := b.nonLockedClientConfigEntry(req.Storage)
|
||||
configEntry, err := b.nonLockedClientConfigEntry(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -231,7 +231,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
||||
}
|
||||
|
||||
if changedCreds || changedOtherConfig || req.Operation == logical.CreateOperation {
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,8 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func (b *backend) pathConfigStsExistenceCheck(ctx context.Context, req *logical.
|
||||
return false, fmt.Errorf("missing account_id")
|
||||
}
|
||||
|
||||
entry, err := b.lockedAwsStsEntry(req.Storage, accountID)
|
||||
entry, err := b.lockedAwsStsEntry(ctx, req.Storage, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func (b *backend) pathConfigStsExistenceCheck(ctx context.Context, req *logical.
|
||||
func (b *backend) pathStsList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
sts, err := req.Storage.List("config/sts/")
|
||||
sts, err := req.Storage.List(ctx, "config/sts/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -88,7 +88,7 @@ func (b *backend) pathStsList(ctx context.Context, req *logical.Request, data *f
|
||||
// nonLockedSetAwsStsEntry creates or updates an STS role association with the given accountID
|
||||
// This method does not acquire the write lock before creating or updating. If locking is
|
||||
// desired, use lockedSetAwsStsEntry instead
|
||||
func (b *backend) nonLockedSetAwsStsEntry(s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
|
||||
func (b *backend) nonLockedSetAwsStsEntry(ctx context.Context, s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
|
||||
if accountID == "" {
|
||||
return fmt.Errorf("missing AWS account ID")
|
||||
}
|
||||
@@ -106,12 +106,12 @@ func (b *backend) nonLockedSetAwsStsEntry(s logical.Storage, accountID string, s
|
||||
return fmt.Errorf("failed to create storage entry for AWS STS configuration")
|
||||
}
|
||||
|
||||
return s.Put(entry)
|
||||
return s.Put(ctx, entry)
|
||||
}
|
||||
|
||||
// lockedSetAwsStsEntry creates or updates an STS role association with the given accountID
|
||||
// This method acquires the write lock before creating or updating the STS entry.
|
||||
func (b *backend) lockedSetAwsStsEntry(s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
|
||||
func (b *backend) lockedSetAwsStsEntry(ctx context.Context, s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
|
||||
if accountID == "" {
|
||||
return fmt.Errorf("missing AWS account ID")
|
||||
}
|
||||
@@ -123,14 +123,14 @@ func (b *backend) lockedSetAwsStsEntry(s logical.Storage, accountID string, stsE
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
return b.nonLockedSetAwsStsEntry(s, accountID, stsEntry)
|
||||
return b.nonLockedSetAwsStsEntry(ctx, s, accountID, stsEntry)
|
||||
}
|
||||
|
||||
// nonLockedAwsStsEntry returns the STS role associated with the given accountID.
|
||||
// This method does not acquire the read lock before returning information. If locking is
|
||||
// desired, use lockedAwsStsEntry instead
|
||||
func (b *backend) nonLockedAwsStsEntry(s logical.Storage, accountID string) (*awsStsEntry, error) {
|
||||
entry, err := s.Get("config/sts/" + accountID)
|
||||
func (b *backend) nonLockedAwsStsEntry(ctx context.Context, s logical.Storage, accountID string) (*awsStsEntry, error) {
|
||||
entry, err := s.Get(ctx, "config/sts/"+accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -147,11 +147,11 @@ func (b *backend) nonLockedAwsStsEntry(s logical.Storage, accountID string) (*aw
|
||||
|
||||
// lockedAwsStsEntry returns the STS role associated with the given accountID.
|
||||
// This method acquires the read lock before returning the association.
|
||||
func (b *backend) lockedAwsStsEntry(s logical.Storage, accountID string) (*awsStsEntry, error) {
|
||||
func (b *backend) lockedAwsStsEntry(ctx context.Context, s logical.Storage, accountID string) (*awsStsEntry, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.nonLockedAwsStsEntry(s, accountID)
|
||||
return b.nonLockedAwsStsEntry(ctx, s, accountID)
|
||||
}
|
||||
|
||||
// pathConfigStsRead is used to return information about an STS role/AWS accountID association
|
||||
@@ -161,7 +161,7 @@ func (b *backend) pathConfigStsRead(ctx context.Context, req *logical.Request, d
|
||||
return logical.ErrorResponse("missing account id"), nil
|
||||
}
|
||||
|
||||
stsEntry, err := b.lockedAwsStsEntry(req.Storage, accountID)
|
||||
stsEntry, err := b.lockedAwsStsEntry(ctx, req.Storage, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -185,7 +185,7 @@ func (b *backend) pathConfigStsCreateUpdate(ctx context.Context, req *logical.Re
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
// Check if an STS role is already registered
|
||||
stsEntry, err := b.nonLockedAwsStsEntry(req.Storage, accountID)
|
||||
stsEntry, err := b.nonLockedAwsStsEntry(ctx, req.Storage, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -206,7 +206,7 @@ func (b *backend) pathConfigStsCreateUpdate(ctx context.Context, req *logical.Re
|
||||
}
|
||||
|
||||
// save the provided STS role
|
||||
if err := b.nonLockedSetAwsStsEntry(req.Storage, accountID, stsEntry); err != nil {
|
||||
if err := b.nonLockedSetAwsStsEntry(ctx, req.Storage, accountID, stsEntry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -223,7 +223,7 @@ func (b *backend) pathConfigStsDelete(ctx context.Context, req *logical.Request,
|
||||
return logical.ErrorResponse("missing account id"), nil
|
||||
}
|
||||
|
||||
return nil, req.Storage.Delete("config/sts/" + accountID)
|
||||
return nil, req.Storage.Delete(ctx, "config/sts/"+accountID)
|
||||
}
|
||||
|
||||
const pathConfigStsSyn = `
|
||||
|
||||
@@ -45,22 +45,22 @@ expiration, before it is removed from the backend storage.`,
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigTidyIdentityWhitelistExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.lockedConfigTidyIdentities(req.Storage)
|
||||
entry, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) lockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
func (b *backend) lockedConfigTidyIdentities(ctx context.Context, s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.nonLockedConfigTidyIdentities(s)
|
||||
return b.nonLockedConfigTidyIdentities(ctx, s)
|
||||
}
|
||||
|
||||
func (b *backend) nonLockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
entry, err := s.Get(identityWhitelistConfigPath)
|
||||
func (b *backend) nonLockedConfigTidyIdentities(ctx context.Context, s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
entry, err := s.Get(ctx, identityWhitelistConfigPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
configEntry, err := b.nonLockedConfigTidyIdentities(req.Storage)
|
||||
configEntry, err := b.nonLockedConfigTidyIdentities(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -106,7 +106,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigTidyIdentityWhitelistRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
clientConfig, err := b.lockedConfigTidyIdentities(req.Storage)
|
||||
clientConfig, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -131,7 +131,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistDelete(ctx context.Context, req
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
return nil, req.Storage.Delete(identityWhitelistConfigPath)
|
||||
return nil, req.Storage.Delete(ctx, identityWhitelistConfigPath)
|
||||
}
|
||||
|
||||
type tidyWhitelistIdentityConfig struct {
|
||||
|
||||
@@ -47,22 +47,22 @@ Defaults to 4320h (180 days).`,
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigTidyRoletagBlacklistExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.lockedConfigTidyRoleTags(req.Storage)
|
||||
entry, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) lockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
func (b *backend) lockedConfigTidyRoleTags(ctx context.Context, s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.nonLockedConfigTidyRoleTags(s)
|
||||
return b.nonLockedConfigTidyRoleTags(ctx, s)
|
||||
}
|
||||
|
||||
func (b *backend) nonLockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
entry, err := s.Get(roletagBlacklistConfigPath)
|
||||
func (b *backend) nonLockedConfigTidyRoleTags(ctx context.Context, s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
entry, err := s.Get(ctx, roletagBlacklistConfigPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
configEntry, err := b.nonLockedConfigTidyRoleTags(req.Storage)
|
||||
configEntry, err := b.nonLockedConfigTidyRoleTags(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -107,7 +107,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigTidyRoletagBlacklistRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
clientConfig, err := b.lockedConfigTidyRoleTags(req.Storage)
|
||||
clientConfig, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistDelete(ctx context.Context, req
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
return nil, req.Storage.Delete(roletagBlacklistConfigPath)
|
||||
return nil, req.Storage.Delete(ctx, roletagBlacklistConfigPath)
|
||||
}
|
||||
|
||||
type tidyBlacklistRoleTagConfig struct {
|
||||
|
||||
@@ -46,7 +46,7 @@ func pathListIdentityWhitelist(b *backend) *framework.Path {
|
||||
// pathWhitelistIdentitiesList is used to list all the instance IDs that are present
|
||||
// in the identity whitelist. This will list both valid and expired entries.
|
||||
func (b *backend) pathWhitelistIdentitiesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
identities, err := req.Storage.List("whitelist/identity/")
|
||||
identities, err := req.Storage.List(ctx, "whitelist/identity/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -54,8 +54,8 @@ func (b *backend) pathWhitelistIdentitiesList(ctx context.Context, req *logical.
|
||||
}
|
||||
|
||||
// Fetch an item from the whitelist given an instance ID.
|
||||
func whitelistIdentityEntry(s logical.Storage, instanceID string) (*whitelistIdentity, error) {
|
||||
entry, err := s.Get("whitelist/identity/" + instanceID)
|
||||
func whitelistIdentityEntry(ctx context.Context, s logical.Storage, instanceID string) (*whitelistIdentity, error) {
|
||||
entry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -72,13 +72,13 @@ func whitelistIdentityEntry(s logical.Storage, instanceID string) (*whitelistIde
|
||||
|
||||
// Stores an instance ID and the information required to validate further login/renewal attempts from
|
||||
// the same instance ID.
|
||||
func setWhitelistIdentityEntry(s logical.Storage, instanceID string, identity *whitelistIdentity) error {
|
||||
func setWhitelistIdentityEntry(ctx context.Context, s logical.Storage, instanceID string, identity *whitelistIdentity) error {
|
||||
entry, err := logical.StorageEntryJSON("whitelist/identity/"+instanceID, identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.Put(entry); err != nil {
|
||||
if err := s.Put(ctx, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -91,7 +91,7 @@ func (b *backend) pathIdentityWhitelistDelete(ctx context.Context, req *logical.
|
||||
return logical.ErrorResponse("missing instance_id"), nil
|
||||
}
|
||||
|
||||
return nil, req.Storage.Delete("whitelist/identity/" + instanceID)
|
||||
return nil, req.Storage.Delete(ctx, "whitelist/identity/"+instanceID)
|
||||
}
|
||||
|
||||
// pathIdentityWhitelistRead is used to view an entry in the identity whitelist given an instance ID.
|
||||
@@ -101,7 +101,7 @@ func (b *backend) pathIdentityWhitelistRead(ctx context.Context, req *logical.Re
|
||||
return logical.ErrorResponse("missing instance_id"), nil
|
||||
}
|
||||
|
||||
entry, err := whitelistIdentityEntry(req.Storage, instanceID)
|
||||
entry, err := whitelistIdentityEntry(ctx, req.Storage, instanceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -154,9 +154,9 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str
|
||||
|
||||
// validateInstance queries the status of the EC2 instance using AWS EC2 API
|
||||
// and checks if the instance is running and is healthy
|
||||
func (b *backend) validateInstance(s logical.Storage, instanceID, region, accountID string) (*ec2.Instance, error) {
|
||||
func (b *backend) validateInstance(ctx context.Context, s logical.Storage, instanceID, region, accountID string) (*ec2.Instance, error) {
|
||||
// Create an EC2 client to pull the instance information
|
||||
ec2Client, err := b.clientEC2(s, region, accountID)
|
||||
ec2Client, err := b.clientEC2(ctx, s, region, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func validateMetadata(clientNonce, pendingTime string, storedIdentity *whitelist
|
||||
// Verifies the integrity of the instance identity document using its SHA256
|
||||
// RSA signature. After verification, returns the unmarshaled instance identity
|
||||
// document.
|
||||
func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityBytes, signatureBytes []byte) (*identityDocument, error) {
|
||||
func (b *backend) verifyInstanceIdentitySignature(ctx context.Context, s logical.Storage, identityBytes, signatureBytes []byte) (*identityDocument, error) {
|
||||
if len(identityBytes) == 0 {
|
||||
return nil, fmt.Errorf("missing instance identity document")
|
||||
}
|
||||
@@ -270,7 +270,7 @@ func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityByt
|
||||
// certificate and all the registered certificates via
|
||||
// 'config/certificate/<cert_name>' endpoint, for verifying the RSA
|
||||
// digest.
|
||||
publicCerts, err := b.awsPublicCertificates(s, false)
|
||||
publicCerts, err := b.awsPublicCertificates(ctx, s, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityByt
|
||||
// Verifies the correctness of the authenticated attributes present in the PKCS#7
|
||||
// signature. After verification, extracts the instance identity document from the
|
||||
// signature, parses it and returns it.
|
||||
func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
|
||||
func (b *backend) parseIdentityDocument(ctx context.Context, s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
|
||||
// Insert the header and footer for the signature to be able to pem decode it
|
||||
pkcs7B64 = fmt.Sprintf("-----BEGIN PKCS7-----\n%s\n-----END PKCS7-----", pkcs7B64)
|
||||
|
||||
@@ -316,7 +316,7 @@ func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*id
|
||||
// Get the public certificates that are used to verify the signature.
|
||||
// This returns a slice of certificates containing the default certificate
|
||||
// and all the registered certificates via 'config/certificate/<cert_name>' endpoint
|
||||
publicCerts, err := b.awsPublicCertificates(s, true)
|
||||
publicCerts, err := b.awsPublicCertificates(ctx, s, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -372,7 +372,7 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, dat
|
||||
// error that means the instance doesn't meet the role requirements
|
||||
// The second error return value indicates whether there's an error in even
|
||||
// trying to validate those requirements
|
||||
func (b *backend) verifyInstanceMeetsRoleRequirements(
|
||||
func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context,
|
||||
s logical.Storage, instance *ec2.Instance, roleEntry *awsRoleEntry, roleName string, identityDoc *identityDocument) (error, error) {
|
||||
|
||||
switch {
|
||||
@@ -470,7 +470,7 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(
|
||||
}
|
||||
|
||||
// Use instance profile ARN to fetch the associated role ARN
|
||||
iamClient, err := b.clientIAM(s, identityDoc.Region, identityDoc.AccountID)
|
||||
iamClient, err := b.clientIAM(ctx, s, identityDoc.Region, identityDoc.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not fetch IAM client: %v", err)
|
||||
} else if iamClient == nil {
|
||||
@@ -530,7 +530,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
// Verify the signature of the identity document and unmarshal it
|
||||
var identityDocParsed *identityDocument
|
||||
if pkcs7B64 != "" {
|
||||
identityDocParsed, err = b.parseIdentityDocument(req.Storage, pkcs7B64)
|
||||
identityDocParsed, err = b.parseIdentityDocument(ctx, req.Storage, pkcs7B64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -538,7 +538,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
return logical.ErrorResponse("failed to verify the instance identity document using pkcs7"), nil
|
||||
}
|
||||
} else {
|
||||
identityDocParsed, err = b.verifyInstanceIdentitySignature(req.Storage, identityDocBytes, signatureBytes)
|
||||
identityDocParsed, err = b.verifyInstanceIdentitySignature(ctx, req.Storage, identityDocBytes, signatureBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -566,7 +566,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
// Get the entry for the role used by the instance
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -581,7 +581,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
// Validate the instance ID by making a call to AWS EC2 DescribeInstances API
|
||||
// and fetching the instance description. Validation succeeds only if the
|
||||
// instance is in 'running' state.
|
||||
instance, err := b.validateInstance(req.Storage, identityDocParsed.InstanceID, identityDocParsed.Region, identityDocParsed.AccountID)
|
||||
instance, err := b.validateInstance(ctx, req.Storage, identityDocParsed.InstanceID, identityDocParsed.Region, identityDocParsed.AccountID)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %v", err)), nil
|
||||
}
|
||||
@@ -592,7 +592,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
return logical.ErrorResponse(fmt.Sprintf("Region %q does not satisfy the constraint on role %q", identityDocParsed.Region, roleName)), nil
|
||||
}
|
||||
|
||||
validationError, err := b.verifyInstanceMeetsRoleRequirements(req.Storage, instance, roleEntry, roleName, identityDocParsed)
|
||||
validationError, err := b.verifyInstanceMeetsRoleRequirements(ctx, req.Storage, instance, roleEntry, roleName, identityDocParsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -601,7 +601,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
// Get the entry from the identity whitelist, if there is one
|
||||
storedIdentity, err := whitelistIdentityEntry(req.Storage, identityDocParsed.InstanceID)
|
||||
storedIdentity, err := whitelistIdentityEntry(ctx, req.Storage, identityDocParsed.InstanceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -682,7 +682,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
rTagMaxTTL := time.Duration(0)
|
||||
var roleTagResp *roleTagLoginResponse
|
||||
if roleEntry.RoleTag != "" {
|
||||
roleTagResp, err := b.handleRoleTagLogin(req.Storage, roleName, roleEntry, instance)
|
||||
roleTagResp, err := b.handleRoleTagLogin(ctx, req.Storage, roleName, roleEntry, instance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -750,7 +750,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
return logical.ErrorResponse("client nonce exceeding the limit of 128 characters"), nil
|
||||
}
|
||||
|
||||
if err = setWhitelistIdentityEntry(req.Storage, identityDocParsed.InstanceID, storedIdentity); err != nil {
|
||||
if err = setWhitelistIdentityEntry(ctx, req.Storage, identityDocParsed.InstanceID, storedIdentity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -800,7 +800,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
|
||||
// handleRoleTagLogin is used to fetch the role tag of the instance and
|
||||
// verifies it to be correct. Then the policies for the login request will be
|
||||
// set off of the role tag, if certain creteria satisfies.
|
||||
func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEntry *awsRoleEntry, instance *ec2.Instance) (*roleTagLoginResponse, error) {
|
||||
func (b *backend) handleRoleTagLogin(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry, instance *ec2.Instance) (*roleTagLoginResponse, error) {
|
||||
if roleEntry == nil {
|
||||
return nil, fmt.Errorf("nil role entry")
|
||||
}
|
||||
@@ -832,7 +832,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEnt
|
||||
}
|
||||
|
||||
// Parse the role tag into a struct, extract the plaintext part of it and verify its HMAC
|
||||
rTag, err := b.parseAndVerifyRoleTagValue(s, rTagValue)
|
||||
rTag, err := b.parseAndVerifyRoleTagValue(ctx, s, rTagValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -849,7 +849,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEnt
|
||||
}
|
||||
|
||||
// Check if the role tag is blacklisted
|
||||
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(s, rTagValue)
|
||||
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(ctx, s, rTagValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -896,7 +896,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
|
||||
if roleName == "" {
|
||||
return nil, fmt.Errorf("error retrieving role_name during renewal")
|
||||
}
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -924,7 +924,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no inferred AWS region in auth metadata")
|
||||
}
|
||||
_, err := b.validateInstance(req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
|
||||
_, err := b.validateInstance(ctx, req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify instance ID %q: %v", instanceID, err)
|
||||
}
|
||||
@@ -956,7 +956,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing ARN %q: %v", canonicalArn, err)
|
||||
}
|
||||
fullArn, err = b.fullArn(entity, req.Storage)
|
||||
fullArn, err = b.fullArn(ctx, entity, req.Storage)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error looking up full ARN of entity %v: %v", entity, err)
|
||||
}
|
||||
@@ -1008,12 +1008,12 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
|
||||
}
|
||||
|
||||
// Cross check that the instance is still in 'running' state
|
||||
_, err := b.validateInstance(req.Storage, instanceID, region, accountID)
|
||||
_, err := b.validateInstance(ctx, req.Storage, instanceID, region, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify instance ID %q: %q", instanceID, err)
|
||||
}
|
||||
|
||||
storedIdentity, err := whitelistIdentityEntry(req.Storage, instanceID)
|
||||
storedIdentity, err := whitelistIdentityEntry(ctx, req.Storage, instanceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1022,7 +1022,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
|
||||
}
|
||||
|
||||
// Ensure that role entry is not deleted
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, storedIdentity.Role)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, storedIdentity.Role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1061,7 +1061,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
|
||||
|
||||
// Updating the expiration time is required for the tidy operation on the
|
||||
// whitelist identity storage items
|
||||
if err = setWhitelistIdentityEntry(req.Storage, instanceID, storedIdentity); err != nil {
|
||||
if err = setWhitelistIdentityEntry(ctx, req.Storage, instanceID, storedIdentity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1127,7 +1127,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
||||
return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil
|
||||
}
|
||||
|
||||
config, err := b.lockedClientConfigEntry(req.Storage)
|
||||
config, err := b.lockedClientConfigEntry(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("error getting configuration"), nil
|
||||
}
|
||||
@@ -1175,7 +1175,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
||||
roleName = entity.FriendlyName
|
||||
}
|
||||
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1200,7 +1200,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
||||
if strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
|
||||
fullArn := b.getCachedUserId(callerUniqueId)
|
||||
if fullArn == "" {
|
||||
fullArn, err = b.fullArn(entity, req.Storage)
|
||||
fullArn, err = b.fullArn(ctx, entity, req.Storage)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error looking up full ARN of entity %v: %v", entity, err)), nil
|
||||
}
|
||||
@@ -1224,7 +1224,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
||||
inferredEntityType := ""
|
||||
inferredEntityID := ""
|
||||
if roleEntry.InferredEntityType == ec2EntityType {
|
||||
instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
|
||||
instance, err := b.validateInstance(ctx, req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to verify %s as a valid EC2 instance in region %s", entity.SessionInfo, roleEntry.InferredAWSRegion)), nil
|
||||
}
|
||||
@@ -1239,7 +1239,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
|
||||
PendingTime: instance.LaunchTime.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
validationError, err := b.verifyInstanceMeetsRoleRequirements(req.Storage, instance, roleEntry, roleName, identityDoc)
|
||||
validationError, err := b.verifyInstanceMeetsRoleRequirements(ctx, req.Storage, instance, roleEntry, roleName, identityDoc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1587,9 +1587,9 @@ func (e *iamEntity) canonicalArn() string {
|
||||
}
|
||||
|
||||
// This returns the "full" ARN of an iamEntity, how it would be referred to in AWS proper
|
||||
func (b *backend) fullArn(e *iamEntity, s logical.Storage) (string, error) {
|
||||
func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage) (string, error) {
|
||||
// Not assuming path is reliable for any entity types
|
||||
client, err := b.clientIAM(s, getAnyRegionForAwsPartition(e.Partition).ID(), e.AccountNumber)
|
||||
client, err := b.clientIAM(ctx, s, getAnyRegionForAwsPartition(e.Partition).ID(), e.AccountNumber)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating IAM client: %v", err)
|
||||
}
|
||||
|
||||
@@ -204,7 +204,7 @@ func pathListRoles(b *backend) *framework.Path {
|
||||
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
|
||||
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
|
||||
func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
entry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -213,13 +213,13 @@ func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Reque
|
||||
|
||||
// lockedAWSRole returns the properties set on the given role. This method
|
||||
// acquires the read lock before reading the role from the storage.
|
||||
func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEntry, error) {
|
||||
func (b *backend) lockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
|
||||
if roleName == "" {
|
||||
return nil, fmt.Errorf("missing role name")
|
||||
}
|
||||
|
||||
b.roleMutex.RLock()
|
||||
roleEntry, err := b.nonLockedAWSRole(s, roleName)
|
||||
roleEntry, err := b.nonLockedAWSRole(ctx, s, roleName)
|
||||
// we manually unlock rather than defer the unlock because we might need to grab
|
||||
// a read/write lock in the upgrade path
|
||||
b.roleMutex.RUnlock()
|
||||
@@ -229,7 +229,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
|
||||
if roleEntry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
needUpgrade, err := b.upgradeRoleEntry(s, roleEntry)
|
||||
needUpgrade, err := b.upgradeRoleEntry(ctx, s, roleEntry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error upgrading roleEntry: %v", err)
|
||||
}
|
||||
@@ -238,7 +238,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
|
||||
defer b.roleMutex.Unlock()
|
||||
// Now that we have a R/W lock, we need to re-read the role entry in case it was
|
||||
// written to between releasing the read lock and acquiring the write lock
|
||||
roleEntry, err = b.nonLockedAWSRole(s, roleName)
|
||||
roleEntry, err = b.nonLockedAWSRole(ctx, s, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -247,11 +247,11 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
|
||||
return nil, nil
|
||||
}
|
||||
// now re-check to see if we need to upgrade
|
||||
if needUpgrade, err = b.upgradeRoleEntry(s, roleEntry); err != nil {
|
||||
if needUpgrade, err = b.upgradeRoleEntry(ctx, s, roleEntry); err != nil {
|
||||
return nil, fmt.Errorf("error upgrading roleEntry: %v", err)
|
||||
}
|
||||
if needUpgrade {
|
||||
if err = b.nonLockedSetAWSRole(s, roleName, roleEntry); err != nil {
|
||||
if err = b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry); err != nil {
|
||||
return nil, fmt.Errorf("error saving upgraded roleEntry: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
|
||||
|
||||
// lockedSetAWSRole creates or updates a role in the storage. This method
|
||||
// acquires the write lock before creating or updating the role at the storage.
|
||||
func (b *backend) lockedSetAWSRole(s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
|
||||
func (b *backend) lockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
|
||||
if roleName == "" {
|
||||
return fmt.Errorf("missing role name")
|
||||
}
|
||||
@@ -273,13 +273,13 @@ func (b *backend) lockedSetAWSRole(s logical.Storage, roleName string, roleEntry
|
||||
b.roleMutex.Lock()
|
||||
defer b.roleMutex.Unlock()
|
||||
|
||||
return b.nonLockedSetAWSRole(s, roleName, roleEntry)
|
||||
return b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry)
|
||||
}
|
||||
|
||||
// nonLockedSetAWSRole creates or updates a role in the storage. This method
|
||||
// does not acquire the write lock before reading the role from the storage. If
|
||||
// locking is desired, use lockedSetAWSRole instead.
|
||||
func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
|
||||
func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string,
|
||||
roleEntry *awsRoleEntry) error {
|
||||
if roleName == "" {
|
||||
return fmt.Errorf("missing role name")
|
||||
@@ -294,7 +294,7 @@ func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.Put(entry); err != nil {
|
||||
if err := s.Put(ctx, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
|
||||
|
||||
// If needed, updates the role entry and returns a bool indicating if it was updated
|
||||
// (and thus needs to be persisted)
|
||||
func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
|
||||
func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
|
||||
if roleEntry == nil {
|
||||
return false, fmt.Errorf("received nil roleEntry")
|
||||
}
|
||||
@@ -331,7 +331,7 @@ func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (
|
||||
roleEntry.BoundIamPrincipalARN != "" &&
|
||||
roleEntry.BoundIamPrincipalID == "" &&
|
||||
!strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
|
||||
principalId, err := b.resolveArnToUniqueIDFunc(s, roleEntry.BoundIamPrincipalARN)
|
||||
principalId, err := b.resolveArnToUniqueIDFunc(ctx, s, roleEntry.BoundIamPrincipalARN)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -349,12 +349,12 @@ func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (
|
||||
// This method also does NOT check to see if a role upgrade is required. It is
|
||||
// the responsibility of the caller to check if a role upgrade is required and,
|
||||
// if so, to upgrade the role
|
||||
func (b *backend) nonLockedAWSRole(s logical.Storage, roleName string) (*awsRoleEntry, error) {
|
||||
func (b *backend) nonLockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
|
||||
if roleName == "" {
|
||||
return nil, fmt.Errorf("missing role name")
|
||||
}
|
||||
|
||||
entry, err := s.Get("role/" + strings.ToLower(roleName))
|
||||
entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -380,7 +380,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
|
||||
b.roleMutex.Lock()
|
||||
defer b.roleMutex.Unlock()
|
||||
|
||||
return nil, req.Storage.Delete("role/" + strings.ToLower(roleName))
|
||||
return nil, req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
|
||||
}
|
||||
|
||||
// pathRoleList is used to list all the AMI IDs registered with Vault.
|
||||
@@ -388,7 +388,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
|
||||
b.roleMutex.RLock()
|
||||
defer b.roleMutex.RUnlock()
|
||||
|
||||
roles, err := req.Storage.List("role/")
|
||||
roles, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -397,7 +397,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
|
||||
|
||||
// pathRoleRead is used to view the information registered for a given AMI ID.
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -431,19 +431,19 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
|
||||
b.roleMutex.Lock()
|
||||
defer b.roleMutex.Unlock()
|
||||
|
||||
roleEntry, err := b.nonLockedAWSRole(req.Storage, roleName)
|
||||
roleEntry, err := b.nonLockedAWSRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if roleEntry == nil {
|
||||
roleEntry = &awsRoleEntry{}
|
||||
} else {
|
||||
needUpdate, err := b.upgradeRoleEntry(req.Storage, roleEntry)
|
||||
needUpdate, err := b.upgradeRoleEntry(ctx, req.Storage, roleEntry)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to update roleEntry: %v", err)), nil
|
||||
}
|
||||
if needUpdate {
|
||||
err = b.nonLockedSetAWSRole(req.Storage, roleName, roleEntry)
|
||||
err = b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to save upgraded roleEntry: %v", err)), nil
|
||||
}
|
||||
@@ -501,7 +501,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
|
||||
// to re-resolve the ARN to the unique ID, in case an entity was deleted and
|
||||
// recreated
|
||||
if roleEntry.ResolveAWSUniqueIDs && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
|
||||
principalID, err := b.resolveArnToUniqueIDFunc(req.Storage, principalARN)
|
||||
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, principalARN)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed updating the unique ID of ARN %#v: %#v", principalARN, err)), nil
|
||||
}
|
||||
@@ -512,7 +512,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
|
||||
}
|
||||
} else if roleEntry.ResolveAWSUniqueIDs && roleEntry.BoundIamPrincipalARN != "" && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
|
||||
// we're turning on resolution on this role, so ensure we update it
|
||||
principalID, err := b.resolveArnToUniqueIDFunc(req.Storage, roleEntry.BoundIamPrincipalARN)
|
||||
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, roleEntry.BoundIamPrincipalARN)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unable to resolve ARN %#v to internal ID: %#v", roleEntry.BoundIamPrincipalARN, err)), nil
|
||||
}
|
||||
@@ -731,7 +731,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
|
||||
}
|
||||
}
|
||||
|
||||
if err := b.nonLockedSetAWSRole(req.Storage, roleName, roleEntry); err != nil {
|
||||
if err := b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ func (b *backend) pathRoleTagUpdate(ctx context.Context, req *logical.Request, d
|
||||
}
|
||||
|
||||
// Fetch the role entry
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -288,7 +288,7 @@ func prepareRoleTagPlaintextValue(rTag *roleTag) (string, error) {
|
||||
|
||||
// Parses the tag from string form into a struct form. This method
|
||||
// also verifies the correctness of the parsed role tag.
|
||||
func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*roleTag, error) {
|
||||
func (b *backend) parseAndVerifyRoleTagValue(ctx context.Context, s logical.Storage, tag string) (*roleTag, error) {
|
||||
tagItems := strings.Split(tag, ":")
|
||||
|
||||
// Tag must contain version, nonce, policies and HMAC
|
||||
@@ -349,7 +349,7 @@ func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*ro
|
||||
return nil, fmt.Errorf("missing role name")
|
||||
}
|
||||
|
||||
roleEntry, err := b.lockedAWSRole(s, rTag.Role)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, s, rTag.Role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -20,7 +20,8 @@ func TestBackend_pathRoleEc2(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -164,7 +165,7 @@ func Test_enableIamIDResolution(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -239,7 +240,7 @@ func TestBackend_pathIam(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -403,7 +404,7 @@ func TestBackend_pathRoleMixedTypes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -509,7 +510,8 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -635,7 +637,8 @@ func TestAwsEc2_RoleDurationSeconds(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Setup(config)
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -679,6 +682,6 @@ func TestAwsEc2_RoleDurationSeconds(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveArnToFakeUniqueId(s logical.Storage, arn string) (string, error) {
|
||||
func resolveArnToFakeUniqueId(ctx context.Context, s logical.Storage, arn string) (string, error) {
|
||||
return "FakeUniqueId1", nil
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (b *backend) pathRoletagBlacklistsList(ctx context.Context, req *logical.Re
|
||||
b.blacklistMutex.RLock()
|
||||
defer b.blacklistMutex.RUnlock()
|
||||
|
||||
tags, err := req.Storage.List("blacklist/roletag/")
|
||||
tags, err := req.Storage.List(ctx, "blacklist/roletag/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -71,15 +71,15 @@ func (b *backend) pathRoletagBlacklistsList(ctx context.Context, req *logical.Re
|
||||
|
||||
// Fetch an entry from the role tag blacklist for a given tag.
|
||||
// This method takes a role tag in its original form and not a base64 encoded form.
|
||||
func (b *backend) lockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
func (b *backend) lockedBlacklistRoleTagEntry(ctx context.Context, s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
b.blacklistMutex.RLock()
|
||||
defer b.blacklistMutex.RUnlock()
|
||||
|
||||
return b.nonLockedBlacklistRoleTagEntry(s, tag)
|
||||
return b.nonLockedBlacklistRoleTagEntry(ctx, s, tag)
|
||||
}
|
||||
|
||||
func (b *backend) nonLockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
entry, err := s.Get("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
|
||||
func (b *backend) nonLockedBlacklistRoleTagEntry(ctx context.Context, s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
entry, err := s.Get(ctx, "blacklist/roletag/"+base64.StdEncoding.EncodeToString([]byte(tag)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func (b *backend) pathRoletagBlacklistDelete(ctx context.Context, req *logical.R
|
||||
return logical.ErrorResponse("missing role_tag"), nil
|
||||
}
|
||||
|
||||
return nil, req.Storage.Delete("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
|
||||
return nil, req.Storage.Delete(ctx, "blacklist/roletag/"+base64.StdEncoding.EncodeToString([]byte(tag)))
|
||||
}
|
||||
|
||||
// If the given role tag is blacklisted, returns the details of the blacklist entry.
|
||||
@@ -115,7 +115,7 @@ func (b *backend) pathRoletagBlacklistRead(ctx context.Context, req *logical.Req
|
||||
return logical.ErrorResponse("missing role_tag"), nil
|
||||
}
|
||||
|
||||
entry, err := b.lockedBlacklistRoleTagEntry(req.Storage, tag)
|
||||
entry, err := b.lockedBlacklistRoleTagEntry(ctx, req.Storage, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
|
||||
}
|
||||
|
||||
// Parse and verify the role tag from string form to a struct form and verify it.
|
||||
rTag, err := b.parseAndVerifyRoleTagValue(req.Storage, tag)
|
||||
rTag, err := b.parseAndVerifyRoleTagValue(ctx, req.Storage, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
|
||||
}
|
||||
|
||||
// Get the entry for the role mentioned in the role tag.
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, rTag.Role)
|
||||
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, rTag.Role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
|
||||
defer b.blacklistMutex.Unlock()
|
||||
|
||||
// Check if the role tag is already blacklisted. If yes, update it.
|
||||
blEntry, err := b.nonLockedBlacklistRoleTagEntry(req.Storage, tag)
|
||||
blEntry, err := b.nonLockedBlacklistRoleTagEntry(ctx, req.Storage, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -211,7 +211,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
|
||||
}
|
||||
|
||||
// Store the blacklist entry.
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ expiration, before it is removed from the backend storage.`,
|
||||
}
|
||||
|
||||
// tidyWhitelistIdentity is used to delete entries in the whitelist that are expired.
|
||||
func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) error {
|
||||
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error {
|
||||
grabbed := atomic.CompareAndSwapUint32(&b.tidyWhitelistCASGuard, 0, 1)
|
||||
if grabbed {
|
||||
defer atomic.StoreUint32(&b.tidyWhitelistCASGuard, 0)
|
||||
@@ -42,13 +42,13 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
|
||||
|
||||
bufferDuration := time.Duration(safety_buffer) * time.Second
|
||||
|
||||
identities, err := s.List("whitelist/identity/")
|
||||
identities, err := s.List(ctx, "whitelist/identity/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, instanceID := range identities {
|
||||
identityEntry, err := s.Get("whitelist/identity/" + instanceID)
|
||||
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching identity of instanceID %s: %s", instanceID, err)
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
|
||||
}
|
||||
|
||||
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
|
||||
if err := s.Delete("whitelist/identity" + instanceID); err != nil {
|
||||
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
|
||||
return fmt.Errorf("error deleting identity of instanceID %s from storage: %s", instanceID, err)
|
||||
}
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
|
||||
|
||||
// pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired.
|
||||
func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
return nil, b.tidyWhitelistIdentity(req.Storage, data.Get("safety_buffer").(int))
|
||||
return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
|
||||
}
|
||||
|
||||
const pathTidyIdentityWhitelistSyn = `
|
||||
|
||||
@@ -32,7 +32,7 @@ expiration, before it is removed from the backend storage.`,
|
||||
}
|
||||
|
||||
// tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist.
|
||||
func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) error {
|
||||
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error {
|
||||
grabbed := atomic.CompareAndSwapUint32(&b.tidyBlacklistCASGuard, 0, 1)
|
||||
if grabbed {
|
||||
defer atomic.StoreUint32(&b.tidyBlacklistCASGuard, 0)
|
||||
@@ -41,13 +41,13 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
|
||||
}
|
||||
|
||||
bufferDuration := time.Duration(safety_buffer) * time.Second
|
||||
tags, err := s.List("blacklist/roletag/")
|
||||
tags, err := s.List(ctx, "blacklist/roletag/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
tagEntry, err := s.Get("blacklist/roletag/" + tag)
|
||||
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching tag %s: %s", tag, err)
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
|
||||
}
|
||||
|
||||
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
|
||||
if err := s.Delete("blacklist/roletag" + tag); err != nil {
|
||||
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
|
||||
return fmt.Errorf("error deleting tag %s from storage: %s", tag, err)
|
||||
}
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
|
||||
|
||||
// pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist.
|
||||
func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
return nil, b.tidyBlacklistRoleTag(req.Storage, data.Get("safety_buffer").(int))
|
||||
return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
|
||||
}
|
||||
|
||||
const pathTidyRoletagBlacklistSyn = `
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -8,9 +9,9 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -50,7 +51,7 @@ type backend struct {
|
||||
crlUpdateMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(_ context.Context, key string) {
|
||||
switch {
|
||||
case strings.HasPrefix(key, "crls/"):
|
||||
b.crlUpdateMutex.Lock()
|
||||
|
||||
@@ -306,7 +306,7 @@ func TestBackend_NonCAExpiry(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -366,7 +366,7 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -449,7 +449,7 @@ func TestBackend_CRLs(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -586,7 +586,7 @@ func TestBackend_CRLs(t *testing.T) {
|
||||
}
|
||||
|
||||
func testFactory(t *testing.T) logical.Backend {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: 1000 * time.Second,
|
||||
MaxLeaseTTLVal: 1800 * time.Second,
|
||||
@@ -1135,7 +1135,7 @@ func testConnState(certPath, keyPath, rootCertPath string) (tls.ConnectionState,
|
||||
func Test_Renew(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
|
||||
lb, err := Factory(&logical.BackendConfig{
|
||||
lb, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: 300 * time.Second,
|
||||
MaxLeaseTTLVal: 1800 * time.Second,
|
||||
|
||||
@@ -101,8 +101,8 @@ TTL will be set to the value of this parameter.`,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Cert(s logical.Storage, n string) (*CertEntry, error) {
|
||||
entry, err := s.Get("cert/" + strings.ToLower(n))
|
||||
func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertEntry, error) {
|
||||
entry, err := s.Get(ctx, "cert/"+strings.ToLower(n))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -118,7 +118,7 @@ func (b *backend) Cert(s logical.Storage, n string) (*CertEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathCertDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("cert/" + strings.ToLower(d.Get("name").(string)))
|
||||
err := req.Storage.Delete(ctx, "cert/"+strings.ToLower(d.Get("name").(string)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -126,7 +126,7 @@ func (b *backend) pathCertDelete(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathCertList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
certs, err := req.Storage.List("cert/")
|
||||
certs, err := req.Storage.List(ctx, "cert/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func (b *backend) pathCertList(ctx context.Context, req *logical.Request, d *fra
|
||||
}
|
||||
|
||||
func (b *backend) pathCertRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
cert, err := b.Cert(req.Storage, strings.ToLower(d.Get("name").(string)))
|
||||
cert, err := b.Cert(ctx, req.Storage, strings.ToLower(d.Get("name").(string)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -245,7 +245,7 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -35,15 +35,15 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Config returns the configuration for this backend.
|
||||
func (b *backend) Config(s logical.Storage) (*config, error) {
|
||||
entry, err := s.Get("config")
|
||||
func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error) {
|
||||
entry, err := s.Get(ctx, "config")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ using the same name as specified here.`,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) populateCRLs(storage logical.Storage) error {
|
||||
func (b *backend) populateCRLs(ctx context.Context, storage logical.Storage) error {
|
||||
b.crlUpdateMutex.Lock()
|
||||
defer b.crlUpdateMutex.Unlock()
|
||||
|
||||
@@ -52,7 +52,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
||||
|
||||
b.crls = map[string]CRLInfo{}
|
||||
|
||||
keys, err := storage.List("crls/")
|
||||
keys, err := storage.List(ctx, "crls/")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listing CRLs: %v", err)
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
entry, err := storage.Get("crls/" + key)
|
||||
entry, err := storage.Get(ctx, "crls/"+key)
|
||||
if err != nil {
|
||||
b.crls = nil
|
||||
return fmt.Errorf("error loading CRL %s: %v", key, err)
|
||||
@@ -129,7 +129,7 @@ func (b *backend) pathCRLDelete(ctx context.Context, req *logical.Request, d *fr
|
||||
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
|
||||
}
|
||||
|
||||
if err := b.populateCRLs(req.Storage); err != nil {
|
||||
if err := b.populateCRLs(ctx, req.Storage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (b *backend) pathCRLDelete(ctx context.Context, req *logical.Request, d *fr
|
||||
)), nil
|
||||
}
|
||||
|
||||
if err := req.Storage.Delete("crls/" + name); err != nil {
|
||||
if err := req.Storage.Delete(ctx, "crls/"+name); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"error deleting crl %s: %v", name, err),
|
||||
), nil
|
||||
@@ -160,7 +160,7 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, d *fram
|
||||
return logical.ErrorResponse(`"name" parameter must be set`), nil
|
||||
}
|
||||
|
||||
if err := b.populateCRLs(req.Storage); err != nil {
|
||||
if err := b.populateCRLs(ctx, req.Storage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
|
||||
return logical.ErrorResponse("parsed CRL is nil"), nil
|
||||
}
|
||||
|
||||
if err := b.populateCRLs(req.Storage); err != nil {
|
||||
if err := b.populateCRLs(ctx, req.Storage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = req.Storage.Put(entry); err != nil {
|
||||
if err = req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
|
||||
|
||||
func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
var matched *ParsedCert
|
||||
if verifyResp, resp, err := b.verifyCredentials(req, data); err != nil {
|
||||
if verifyResp, resp, err := b.verifyCredentials(ctx, req, data); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
@@ -128,14 +128,14 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
|
||||
}
|
||||
|
||||
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
config, err := b.Config(req.Storage)
|
||||
config, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !config.DisableBinding {
|
||||
var matched *ParsedCert
|
||||
if verifyResp, resp, err := b.verifyCredentials(req, d); err != nil {
|
||||
if verifyResp, resp, err := b.verifyCredentials(ctx, req, d); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
@@ -162,7 +162,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
|
||||
}
|
||||
// Get the cert and use its TTL
|
||||
cert, err := b.Cert(req.Storage, req.Auth.Metadata["cert_name"])
|
||||
cert, err := b.Cert(ctx, req.Storage, req.Auth.Metadata["cert_name"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -188,7 +188,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
return framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(ctx, req, d)
|
||||
}
|
||||
|
||||
func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
|
||||
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
|
||||
// Get the connection state
|
||||
if req.Connection == nil || req.Connection.ConnState == nil {
|
||||
return nil, logical.ErrorResponse("tls connection required"), nil
|
||||
@@ -209,7 +209,14 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
|
||||
}
|
||||
|
||||
// Load the trusted certificates
|
||||
roots, trusted, trustedNonCAs := b.loadTrustedCerts(req.Storage, certName)
|
||||
roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName)
|
||||
|
||||
// Get the list of full chains matching the connection and validates the
|
||||
// certificate itself
|
||||
trustedChains, err := validateConnState(roots, connState)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// If trustedNonCAs is not empty it means that client had registered a non-CA cert
|
||||
// with the backend.
|
||||
@@ -225,12 +232,8 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
|
||||
}
|
||||
}
|
||||
|
||||
// Get the list of full chains matching the connection
|
||||
trustedChains, err := validateConnState(roots, connState)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// If no trusted chain was found, client is not authenticated
|
||||
// This check happens after checking for a matching configured non-CA certs
|
||||
if len(trustedChains) == 0 {
|
||||
return nil, logical.ErrorResponse("invalid certificate or no client certificate supplied"), nil
|
||||
}
|
||||
@@ -328,11 +331,11 @@ func (b *backend) matchesCertificateExtenions(clientCert *x509.Certificate, conf
|
||||
}
|
||||
|
||||
// loadTrustedCerts is used to load all the trusted certificates from the backend
|
||||
func (b *backend) loadTrustedCerts(store logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
|
||||
func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
|
||||
pool = x509.NewCertPool()
|
||||
trusted = make([]*ParsedCert, 0)
|
||||
trustedNonCAs = make([]*ParsedCert, 0)
|
||||
names, err := store.List("cert/")
|
||||
names, err := storage.List(ctx, "cert/")
|
||||
if err != nil {
|
||||
b.Logger().Error("cert: failed to list trusted certs", "error", err)
|
||||
return
|
||||
@@ -342,7 +345,7 @@ func (b *backend) loadTrustedCerts(store logical.Storage, certName string) (pool
|
||||
if certName != "" && name != certName {
|
||||
continue
|
||||
}
|
||||
entry, err := b.Cert(store, strings.TrimPrefix(name, "cert/"))
|
||||
entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/"))
|
||||
if err != nil {
|
||||
b.Logger().Error("cert: failed to load trusted cert", "name", name, "error", err)
|
||||
continue
|
||||
@@ -423,9 +426,6 @@ func validateConnState(roots *x509.CertPool, cs *tls.ConnectionState) ([][]*x509
|
||||
if len(certs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if certs[0].IsCA {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
func TestBackend_Config(t *testing.T) {
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
maxLeaseTTLVal := time.Hour * 24 * 2
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
@@ -92,7 +93,7 @@ func testConfigWrite(t *testing.T, d map[string]interface{}) logicaltest.TestSte
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
maxLeaseTTLVal := time.Hour * 24 * 32
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@@ -87,7 +86,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -95,7 +94,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
config, err := b.Config(req.Storage)
|
||||
config, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -108,14 +107,19 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data
|
||||
config.MaxTTL /= time.Second
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: structs.New(config).Map(),
|
||||
Data: map[string]interface{}{
|
||||
"organization": config.Organization,
|
||||
"base_url": config.BaseURL,
|
||||
"ttl": config.TTL,
|
||||
"max_ttl": config.MaxTTL,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Config returns the configuration for this backend.
|
||||
func (b *backend) Config(s logical.Storage) (*config, error) {
|
||||
entry, err := s.Get("config")
|
||||
func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error) {
|
||||
entry, err := s.Get(ctx, "config")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
|
||||
token := data.Get("token").(string)
|
||||
|
||||
var verifyResp *verifyCredentialsResp
|
||||
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
|
||||
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
@@ -54,7 +54,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
|
||||
token := data.Get("token").(string)
|
||||
|
||||
var verifyResp *verifyCredentialsResp
|
||||
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
|
||||
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
@@ -62,7 +62,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
|
||||
verifyResp = verifyResponse
|
||||
}
|
||||
|
||||
config, err := b.Config(req.Storage)
|
||||
config, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
token := tokenRaw.(string)
|
||||
|
||||
var verifyResp *verifyCredentialsResp
|
||||
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
|
||||
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
@@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
return nil, fmt.Errorf("policies do not match")
|
||||
}
|
||||
|
||||
config, err := b.Config(req.Storage)
|
||||
config, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -150,8 +150,8 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
|
||||
config, err := b.Config(req.Storage)
|
||||
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
|
||||
config, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
|
||||
}
|
||||
|
||||
// Get the user
|
||||
user, _, err := client.Users.Get(context.Background(), "")
|
||||
user, _, err := client.Users.Get(ctx, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -188,7 +188,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
|
||||
|
||||
var allOrgs []*github.Organization
|
||||
for {
|
||||
orgs, resp, err := client.Organizations.List(context.Background(), "", orgOpt)
|
||||
orgs, resp, err := client.Organizations.List(ctx, "", orgOpt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -218,7 +218,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
|
||||
|
||||
var allTeams []*github.Team
|
||||
for {
|
||||
teams, resp, err := client.Organizations.ListUserTeams(context.Background(), teamOpt)
|
||||
teams, resp, err := client.Organizations.ListUserTeams(ctx, teamOpt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -242,13 +242,13 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
|
||||
}
|
||||
}
|
||||
|
||||
groupPoliciesList, err := b.TeamMap.Policies(req.Storage, teamNames...)
|
||||
groupPoliciesList, err := b.TeamMap.Policies(ctx, req.Storage, teamNames...)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userPoliciesList, err := b.UserMap.Policies(req.Storage, []string{*user.Login}...)
|
||||
userPoliciesList, err := b.UserMap.Policies(ctx, req.Storage, []string{*user.Login}...)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -2,6 +2,7 @@ package ldap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"text/template"
|
||||
|
||||
@@ -12,9 +13,9 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -92,9 +93,9 @@ func EscapeLDAPValue(input string) string {
|
||||
return input
|
||||
}
|
||||
|
||||
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
|
||||
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
|
||||
|
||||
cfg, err := b.Config(req)
|
||||
cfg, err := b.Config(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -172,7 +173,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
|
||||
var allGroups []string
|
||||
// Import the custom added groups from ldap backend
|
||||
user, err := b.User(req.Storage, username)
|
||||
user, err := b.User(ctx, req.Storage, username)
|
||||
if err == nil && user != nil && user.Groups != nil {
|
||||
if b.Logger().IsDebug() {
|
||||
b.Logger().Debug("auth/ldap: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
|
||||
@@ -185,7 +186,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
// Retrieve policies
|
||||
var policies []string
|
||||
for _, groupName := range allGroups {
|
||||
group, err := b.Group(req.Storage, groupName)
|
||||
group, err := b.Group(ctx, req.Storage, groupName)
|
||||
if err == nil && group != nil {
|
||||
policies = append(policies, group.Policies...)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
|
||||
t.Fatalf("failed to create backend")
|
||||
}
|
||||
|
||||
err := b.Backend.Setup(config)
|
||||
err := b.Backend.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func TestLdapAuthBackend_UserPolicies(t *testing.T) {
|
||||
func factory(t *testing.T) logical.Backend {
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
maxLeaseTTLVal := time.Hour * 24 * 32
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
|
||||
@@ -130,7 +130,7 @@ Default: cn`,
|
||||
/*
|
||||
* Construct ConfigEntry struct using stored configuration.
|
||||
*/
|
||||
func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
|
||||
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
|
||||
// Schema for ConfigEntry
|
||||
fd, err := b.getConfigFieldData()
|
||||
if err != nil {
|
||||
@@ -143,7 +143,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storedConfig, err := req.Storage.Get("config")
|
||||
storedConfig, err := req.Storage.Get(ctx, "config")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
cfg, err := b.Config(req)
|
||||
cfg, err := b.Config(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -299,7 +299,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -47,8 +47,8 @@ func pathGroups(b *backend) *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
|
||||
entry, err := s.Get("group/" + n)
|
||||
func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*GroupEntry, error) {
|
||||
entry, err := s.Get(ctx, "group/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("group/" + d.Get("name").(string))
|
||||
err := req.Storage.Delete(ctx, "group/"+d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *
|
||||
}
|
||||
|
||||
func (b *backend) pathGroupRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
group, err := b.Group(req.Storage, d.Get("name").(string))
|
||||
group, err := b.Group(ctx, req.Storage, d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -97,7 +97,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
groups, err := req.Storage.List("group/")
|
||||
groups, err := req.Storage.List(ctx, "group/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
|
||||
username := d.Get("username").(string)
|
||||
password := d.Get("password").(string)
|
||||
|
||||
policies, resp, groupNames, err := b.Login(req, username, password)
|
||||
policies, resp, groupNames, err := b.Login(ctx, req, username, password)
|
||||
// Handle an internal error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,7 +102,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
username := req.Auth.Metadata["username"]
|
||||
password := req.Auth.InternalData["password"].(string)
|
||||
|
||||
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
|
||||
loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
|
||||
if len(loginPolicies) == 0 {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -54,8 +54,8 @@ func pathUsers(b *backend) *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
|
||||
entry, err := s.Get("user/" + n)
|
||||
func (b *backend) User(ctx context.Context, s logical.Storage, n string) (*UserEntry, error) {
|
||||
entry, err := s.Get(ctx, "user/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("user/" + d.Get("name").(string))
|
||||
err := req.Storage.Delete(ctx, "user/"+d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
user, err := b.User(req.Storage, d.Get("name").(string))
|
||||
user, err := b.User(ctx, req.Storage, d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
}
|
||||
|
||||
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
users, err := req.Storage.List("user/")
|
||||
users, err := req.Storage.List(ctx, "user/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package okta
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/chrismalek/oktasdk-go/okta"
|
||||
@@ -9,9 +10,9 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -54,8 +55,8 @@ type backend struct {
|
||||
*framework.Backend
|
||||
}
|
||||
|
||||
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
|
||||
cfg, err := b.Config(req.Storage)
|
||||
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
|
||||
cfg, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -110,7 +111,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
}
|
||||
|
||||
// Import the custom added groups from okta backend
|
||||
user, err := b.User(req.Storage, username)
|
||||
user, err := b.User(ctx, req.Storage, username)
|
||||
if err != nil {
|
||||
if b.Logger().IsDebug() {
|
||||
b.Logger().Debug("auth/okta: error looking up user", "error", err)
|
||||
@@ -126,7 +127,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
// Retrieve policies
|
||||
var policies []string
|
||||
for _, groupName := range allGroups {
|
||||
entry, _, err := b.Group(req.Storage, groupName)
|
||||
entry, _, err := b.Group(ctx, req.Storage, groupName)
|
||||
if err != nil {
|
||||
if b.Logger().IsDebug() {
|
||||
b.Logger().Debug("auth/okta: error looking up group policies", "error", err)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package okta
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
func TestBackend_Config(t *testing.T) {
|
||||
defaultLeaseTTLVal := time.Hour * 12
|
||||
maxLeaseTTLVal := time.Hour * 24
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: logformat.NewVaultLogger(log.LevelTrace),
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
|
||||
@@ -69,8 +69,8 @@ func pathConfig(b *backend) *framework.Path {
|
||||
}
|
||||
|
||||
// Config returns the configuration for this backend.
|
||||
func (b *backend) Config(s logical.Storage) (*ConfigEntry, error) {
|
||||
entry, err := s.Get("config")
|
||||
func (b *backend) Config(ctx context.Context, s logical.Storage) (*ConfigEntry, error) {
|
||||
entry, err := s.Get(ctx, "config")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func (b *backend) Config(s logical.Storage) (*ConfigEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
cfg, err := b.Config(req.Storage)
|
||||
cfg, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
cfg, err := b.Config(req.Storage)
|
||||
cfg, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -193,7 +193,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(jsonCfg); err != nil {
|
||||
if err := req.Storage.Put(ctx, jsonCfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -201,7 +201,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) {
|
||||
cfg, err := b.Config(req.Storage)
|
||||
cfg, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -50,20 +50,20 @@ func pathGroups(b *backend) *framework.Path {
|
||||
|
||||
// We look up groups in a case-insensitive manner since Okta is case-preserving
|
||||
// but case-insensitive for comparisons
|
||||
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, string, error) {
|
||||
func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*GroupEntry, string, error) {
|
||||
canonicalName := n
|
||||
entry, err := s.Get("group/" + n)
|
||||
entry, err := s.Get(ctx, "group/"+n)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if entry == nil {
|
||||
entries, err := s.List("group/")
|
||||
entries, err := s.List(ctx, "group/")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
for _, groupName := range entries {
|
||||
if strings.ToLower(groupName) == strings.ToLower(n) {
|
||||
entry, err = s.Get("group/" + groupName)
|
||||
entry, err = s.Get(ctx, "group/"+groupName)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -90,12 +90,12 @@ func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *
|
||||
return logical.ErrorResponse("'name' must be supplied"), nil
|
||||
}
|
||||
|
||||
entry, canonicalName, err := b.Group(req.Storage, name)
|
||||
entry, canonicalName, err := b.Group(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry != nil {
|
||||
err := req.Storage.Delete("group/" + canonicalName)
|
||||
err := req.Storage.Delete(ctx, "group/"+canonicalName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func (b *backend) pathGroupRead(ctx context.Context, req *logical.Request, d *fr
|
||||
return logical.ErrorResponse("'name' must be supplied"), nil
|
||||
}
|
||||
|
||||
group, _, err := b.Group(req.Storage, name)
|
||||
group, _, err := b.Group(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
|
||||
|
||||
// Check for an existing group, possibly lowercased so that we keep using
|
||||
// existing user set values
|
||||
_, canonicalName, err := b.Group(req.Storage, name)
|
||||
_, canonicalName, err := b.Group(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -149,7 +149,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
groups, err := req.Storage.List("group/")
|
||||
groups, err := req.Storage.List(ctx, "group/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
|
||||
username := d.Get("username").(string)
|
||||
password := d.Get("password").(string)
|
||||
|
||||
policies, resp, groupNames, err := b.Login(req, username, password)
|
||||
policies, resp, groupNames, err := b.Login(ctx, req, username, password)
|
||||
// Handle an internal error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -72,7 +72,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
|
||||
|
||||
sort.Strings(policies)
|
||||
|
||||
cfg, err := b.getConfig(req)
|
||||
cfg, err := b.getConfig(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
username := req.Auth.Metadata["username"]
|
||||
password := req.Auth.InternalData["password"].(string)
|
||||
|
||||
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
|
||||
loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
|
||||
if len(loginPolicies) == 0 {
|
||||
return resp, err
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
return nil, fmt.Errorf("policies have changed, not renewing")
|
||||
}
|
||||
|
||||
cfg, err := b.getConfig(req)
|
||||
cfg, err := b.getConfig(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -144,9 +144,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
|
||||
}
|
||||
|
||||
func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) {
|
||||
func (b *backend) getConfig(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
|
||||
|
||||
cfg, err := b.Config(req.Storage)
|
||||
cfg, err := b.Config(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -51,8 +51,8 @@ func pathUsers(b *backend) *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
|
||||
entry, err := s.Get("user/" + n)
|
||||
func (b *backend) User(ctx context.Context, s logical.Storage, n string) (*UserEntry, error) {
|
||||
entry, err := s.Get(ctx, "user/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
|
||||
return logical.ErrorResponse("Error empty name"), nil
|
||||
}
|
||||
|
||||
err := req.Storage.Delete("user/" + name)
|
||||
err := req.Storage.Delete(ctx, "user/"+name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -88,7 +88,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
|
||||
return logical.ErrorResponse("Error empty name"), nil
|
||||
}
|
||||
|
||||
user, err := b.User(req.Storage, name)
|
||||
user, err := b.User(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
}
|
||||
|
||||
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
users, err := req.Storage.List("user/")
|
||||
users, err := req.Storage.List(ctx, "user/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package radius
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/vault/helper/mfa"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package radius
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -17,7 +18,7 @@ const (
|
||||
)
|
||||
|
||||
func TestBackend_Config(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -70,7 +71,7 @@ func TestBackend_Config(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_users(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -98,7 +99,7 @@ func TestBackend_acceptance(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
|
||||
@@ -65,7 +65,7 @@ func pathConfig(b *backend) *framework.Path {
|
||||
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
|
||||
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
|
||||
func (b *backend) configExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.Config(req)
|
||||
entry, err := b.Config(ctx, req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -75,9 +75,8 @@ func (b *backend) configExistenceCheck(ctx context.Context, req *logical.Request
|
||||
/*
|
||||
* Construct ConfigEntry struct using stored configuration.
|
||||
*/
|
||||
func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
|
||||
|
||||
storedConfig, err := req.Storage.Get("config")
|
||||
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
|
||||
storedConfig, err := req.Storage.Get(ctx, "config")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -96,7 +95,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
cfg, err := b.Config(req)
|
||||
cfg, err := b.Config(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,7 +112,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
|
||||
|
||||
func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Build a ConfigEntry struct out of the supplied FieldData
|
||||
cfg, err := b.Config(req)
|
||||
cfg, err := b.Config(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -190,7 +189,7 @@ func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Reque
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
|
||||
return logical.ErrorResponse("password cannot be empty"), nil
|
||||
}
|
||||
|
||||
policies, resp, err := b.RadiusLogin(req, username, password)
|
||||
policies, resp, err := b.RadiusLogin(ctx, req, username, password)
|
||||
// Handle an internal error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -117,7 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
var resp *logical.Response
|
||||
var loginPolicies []string
|
||||
|
||||
loginPolicies, resp, err = b.RadiusLogin(req, username, password)
|
||||
loginPolicies, resp, err = b.RadiusLogin(ctx, req, username, password)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
return resp, err
|
||||
}
|
||||
@@ -129,9 +129,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
|
||||
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
|
||||
}
|
||||
|
||||
func (b *backend) RadiusLogin(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
|
||||
func (b *backend) RadiusLogin(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
|
||||
|
||||
cfg, err := b.Config(req)
|
||||
cfg, err := b.Config(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (b *backend) RadiusLogin(req *logical.Request, username string, password st
|
||||
|
||||
var policies []string
|
||||
// Retrieve user entry from storage
|
||||
user, err := b.user(req.Storage, username)
|
||||
user, err := b.user(ctx, req.Storage, username)
|
||||
if err != nil {
|
||||
return policies, logical.ErrorResponse("could not retrieve user entry from storage"), err
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func pathUsers(b *backend) *framework.Path {
|
||||
}
|
||||
|
||||
func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
userEntry, err := b.user(req.Storage, data.Get("name").(string))
|
||||
userEntry, err := b.user(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -61,12 +61,12 @@ func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request,
|
||||
return userEntry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
|
||||
func (b *backend) user(ctx context.Context, s logical.Storage, username string) (*UserEntry, error) {
|
||||
if username == "" {
|
||||
return nil, fmt.Errorf("missing username")
|
||||
}
|
||||
|
||||
entry, err := s.Get("user/" + strings.ToLower(username))
|
||||
entry, err := s.Get(ctx, "user/"+strings.ToLower(username))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("user/" + d.Get("name").(string))
|
||||
err := req.Storage.Delete(ctx, "user/"+d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
user, err := b.user(req.Storage, d.Get("name").(string))
|
||||
user, err := b.user(ctx, req.Storage, d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
}
|
||||
|
||||
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
users, err := req.Storage.List("user/")
|
||||
users, err := req.Storage.List(ctx, "user/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package userpass
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/vault/helper/mfa"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package userpass
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
@@ -45,7 +46,7 @@ func TestBackend_TTLDurations(t *testing.T) {
|
||||
data5 := map[string]interface{}{
|
||||
"password": "password",
|
||||
}
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -69,7 +70,7 @@ func TestBackend_TTLDurations(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -92,7 +93,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_userCrud(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -115,7 +116,7 @@ func TestBackend_userCrud(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_userCreateOperation(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -136,7 +137,7 @@ func TestBackend_userCreateOperation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_passwordUpdate(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
@@ -161,7 +162,7 @@ func TestBackend_passwordUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_policiesUpdate(t *testing.T) {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: testSysTTL,
|
||||
|
||||
@@ -61,7 +61,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
|
||||
}
|
||||
|
||||
// Get the user and validate auth
|
||||
user, err := b.user(req.Storage, username)
|
||||
user, err := b.user(ctx, req.Storage, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
|
||||
|
||||
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the user
|
||||
user, err := b.user(req.Storage, req.Auth.Metadata["username"])
|
||||
user, err := b.user(ctx, req.Storage, req.Auth.Metadata["username"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func pathUserPassword(b *backend) *framework.Path {
|
||||
func (b *backend) pathUserPasswordUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
username := d.Get("username").(string)
|
||||
|
||||
userEntry, err := b.user(req.Storage, username)
|
||||
userEntry, err := b.user(ctx, req.Storage, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func (b *backend) pathUserPasswordUpdate(ctx context.Context, req *logical.Reque
|
||||
return logical.ErrorResponse(userErr.Error()), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
return nil, b.setUser(req.Storage, username, userEntry)
|
||||
return nil, b.setUser(ctx, req.Storage, username, userEntry)
|
||||
}
|
||||
|
||||
func (b *backend) updateUserPassword(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) (error, error) {
|
||||
|
||||
@@ -35,7 +35,7 @@ func pathUserPolicies(b *backend) *framework.Path {
|
||||
func (b *backend) pathUserPoliciesUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
username := d.Get("username").(string)
|
||||
|
||||
userEntry, err := b.user(req.Storage, username)
|
||||
userEntry, err := b.user(ctx, req.Storage, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func (b *backend) pathUserPoliciesUpdate(ctx context.Context, req *logical.Reque
|
||||
|
||||
userEntry.Policies = policyutil.ParsePolicies(d.Get("policies"))
|
||||
|
||||
return nil, b.setUser(req.Storage, username, userEntry)
|
||||
return nil, b.setUser(ctx, req.Storage, username, userEntry)
|
||||
}
|
||||
|
||||
const pathUserPoliciesHelpSyn = `
|
||||
|
||||
@@ -69,7 +69,7 @@ func pathUsers(b *backend) *framework.Path {
|
||||
}
|
||||
|
||||
func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
userEntry, err := b.user(req.Storage, data.Get("username").(string))
|
||||
userEntry, err := b.user(ctx, req.Storage, data.Get("username").(string))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -77,12 +77,12 @@ func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request,
|
||||
return userEntry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
|
||||
func (b *backend) user(ctx context.Context, s logical.Storage, username string) (*UserEntry, error) {
|
||||
if username == "" {
|
||||
return nil, fmt.Errorf("missing username")
|
||||
}
|
||||
|
||||
entry, err := s.Get("user/" + strings.ToLower(username))
|
||||
entry, err := s.Get(ctx, "user/"+strings.ToLower(username))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,17 +98,17 @@ func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) setUser(s logical.Storage, username string, userEntry *UserEntry) error {
|
||||
func (b *backend) setUser(ctx context.Context, s logical.Storage, username string, userEntry *UserEntry) error {
|
||||
entry, err := logical.StorageEntryJSON("user/"+username, userEntry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.Put(entry)
|
||||
return s.Put(ctx, entry)
|
||||
}
|
||||
|
||||
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
users, err := req.Storage.List("user/")
|
||||
users, err := req.Storage.List(ctx, "user/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *fra
|
||||
}
|
||||
|
||||
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("user/" + strings.ToLower(d.Get("username").(string)))
|
||||
err := req.Storage.Delete(ctx, "user/"+strings.ToLower(d.Get("username").(string)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -125,7 +125,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
user, err := b.user(req.Storage, strings.ToLower(d.Get("username").(string)))
|
||||
user, err := b.user(ctx, req.Storage, strings.ToLower(d.Get("username").(string)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -144,7 +144,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
|
||||
|
||||
func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
username := strings.ToLower(d.Get("username").(string))
|
||||
userEntry, err := b.user(req.Storage, username)
|
||||
userEntry, err := b.user(ctx, req.Storage, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,7 +182,7 @@ func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d
|
||||
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
|
||||
}
|
||||
|
||||
return nil, b.setUser(req.Storage, username, userEntry)
|
||||
return nil, b.setUser(ctx, req.Storage, username, userEntry)
|
||||
}
|
||||
|
||||
func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -8,9 +9,9 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package aws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
func getBackend(t *testing.T) logical.Backend {
|
||||
be, _ := Factory(logical.TestBackendConfig())
|
||||
be, _ := Factory(context.Background(), logical.TestBackendConfig())
|
||||
return be
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -13,11 +14,11 @@ import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
|
||||
func getRootConfig(ctx context.Context, s logical.Storage, clientType string) (*aws.Config, error) {
|
||||
credsConfig := &awsutil.CredentialsConfig{}
|
||||
var endpoint string
|
||||
|
||||
entry, err := s.Get("config/root")
|
||||
entry, err := s.Get(ctx, "config/root")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -63,8 +64,8 @@ func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func clientIAM(s logical.Storage) (*iam.IAM, error) {
|
||||
awsConfig, err := getRootConfig(s, "iam")
|
||||
func clientIAM(ctx context.Context, s logical.Storage) (*iam.IAM, error) {
|
||||
awsConfig, err := getRootConfig(ctx, s, "iam")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -77,8 +78,8 @@ func clientIAM(s logical.Storage) (*iam.IAM, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func clientSTS(s logical.Storage) (*sts.STS, error) {
|
||||
awsConfig, err := getRootConfig(s, "sts")
|
||||
func clientSTS(ctx context.Context, s logical.Storage) (*sts.STS, error) {
|
||||
awsConfig, err := getRootConfig(ctx, s, "sts")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -35,8 +35,8 @@ func pathConfigLease(b *backend) *framework.Path {
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(ctx, "config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *f
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *f
|
||||
}
|
||||
|
||||
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.Lease(req.Storage)
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -60,7 +60,7 @@ func pathConfigRootWrite(ctx context.Context, req *logical.Request, data *framew
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ func pathRoles() *framework.Path {
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("policy/")
|
||||
entries, err := req.Storage.List(ctx, "policy/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *fra
|
||||
}
|
||||
|
||||
func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("policy/" + d.Get("name").(string))
|
||||
err := req.Storage.Delete(ctx, "policy/"+d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.Fie
|
||||
}
|
||||
|
||||
func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get("policy/" + d.Get("name").(string))
|
||||
entry, err := req.Storage.Get(ctx, "policy/"+d.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -125,7 +125,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
|
||||
"Error compacting policy: %s", err)), nil
|
||||
}
|
||||
// Write the policy into storage
|
||||
err := req.Storage.Put(&logical.StorageEntry{
|
||||
err := req.Storage.Put(ctx, &logical.StorageEntry{
|
||||
Key: "policy/" + d.Get("name").(string),
|
||||
Value: buf.Bytes(),
|
||||
})
|
||||
@@ -134,7 +134,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
|
||||
}
|
||||
} else {
|
||||
// Write the arn ref into storage
|
||||
err := req.Storage.Put(&logical.StorageEntry{
|
||||
err := req.Storage.Put(ctx, &logical.StorageEntry{
|
||||
Key: "policy/" + d.Get("name").(string),
|
||||
Value: []byte(d.Get("arn").(string)),
|
||||
})
|
||||
|
||||
@@ -15,7 +15,7 @@ func TestBackend_PathListRoles(t *testing.T) {
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
||||
b := Backend()
|
||||
if err := b.Setup(config); err != nil {
|
||||
if err := b.Setup(context.Background(), config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
|
||||
ttl := int64(d.Get("ttl").(int))
|
||||
|
||||
// Read the policy
|
||||
policy, err := req.Storage.Get("policy/" + policyName)
|
||||
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving role: %s", err)
|
||||
}
|
||||
@@ -57,6 +57,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
|
||||
if strings.HasPrefix(policyValue, "arn:") {
|
||||
if strings.Contains(policyValue, ":role/") {
|
||||
return b.assumeRole(
|
||||
ctx,
|
||||
req.Storage,
|
||||
req.DisplayName, policyName, policyValue,
|
||||
ttl,
|
||||
@@ -69,6 +70,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
|
||||
}
|
||||
// Use the helper to create the secret
|
||||
return b.secretTokenCreate(
|
||||
ctx,
|
||||
req.Storage,
|
||||
req.DisplayName, policyName, policyValue,
|
||||
ttl,
|
||||
|
||||
@@ -34,7 +34,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
|
||||
policyName := d.Get("name").(string)
|
||||
|
||||
// Read the policy
|
||||
policy, err := req.Storage.Get("policy/" + policyName)
|
||||
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving role: %s", err)
|
||||
}
|
||||
@@ -45,10 +45,10 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
|
||||
|
||||
// Use the helper to create the secret
|
||||
return b.secretAccessKeysCreate(
|
||||
req.Storage, req.DisplayName, policyName, string(policy.Value))
|
||||
ctx, req.Storage, req.DisplayName, policyName, string(policy.Value))
|
||||
}
|
||||
|
||||
func pathUserRollback(req *logical.Request, _kind string, data interface{}) error {
|
||||
func pathUserRollback(ctx context.Context, req *logical.Request, _kind string, data interface{}) error {
|
||||
var entry walUser
|
||||
if err := mapstructure.Decode(data, &entry); err != nil {
|
||||
return err
|
||||
@@ -56,7 +56,7 @@ func pathUserRollback(req *logical.Request, _kind string, data interface{}) erro
|
||||
username := entry.UserName
|
||||
|
||||
// Get the client
|
||||
client, err := clientIAM(req.Storage)
|
||||
client, err := clientIAM(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -11,11 +12,11 @@ var walRollbackMap = map[string]framework.WALRollbackFunc{
|
||||
"user": pathUserRollback,
|
||||
}
|
||||
|
||||
func walRollback(req *logical.Request, kind string, data interface{}) error {
|
||||
func walRollback(ctx context.Context, req *logical.Request, kind string, data interface{}) error {
|
||||
f, ok := walRollbackMap[kind]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown type to rollback")
|
||||
}
|
||||
|
||||
return f(req, kind, data)
|
||||
return f(ctx, req, kind, data)
|
||||
}
|
||||
|
||||
@@ -65,10 +65,10 @@ func genUsername(displayName, policyName, userType string) (ret string, warning
|
||||
return
|
||||
}
|
||||
|
||||
func (b *backend) secretTokenCreate(s logical.Storage,
|
||||
func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage,
|
||||
displayName, policyName, policy string,
|
||||
lifeTimeInSeconds int64) (*logical.Response, error) {
|
||||
STSClient, err := clientSTS(s)
|
||||
STSClient, err := clientSTS(ctx, s)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
@@ -110,10 +110,10 @@ func (b *backend) secretTokenCreate(s logical.Storage,
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) assumeRole(s logical.Storage,
|
||||
func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
|
||||
displayName, policyName, policy string,
|
||||
lifeTimeInSeconds int64) (*logical.Response, error) {
|
||||
STSClient, err := clientSTS(s)
|
||||
STSClient, err := clientSTS(ctx, s)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
@@ -156,9 +156,10 @@ func (b *backend) assumeRole(s logical.Storage,
|
||||
}
|
||||
|
||||
func (b *backend) secretAccessKeysCreate(
|
||||
ctx context.Context,
|
||||
s logical.Storage,
|
||||
displayName, policyName string, policy string) (*logical.Response, error) {
|
||||
client, err := clientIAM(s)
|
||||
client, err := clientIAM(ctx, s)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
@@ -169,7 +170,7 @@ func (b *backend) secretAccessKeysCreate(
|
||||
// the user is created because if switch the order then the WAL put
|
||||
// can fail, which would put us in an awkward position: we have a user
|
||||
// we need to rollback but can't put the WAL entry to do the rollback.
|
||||
walId, err := framework.PutWAL(s, "user", &walUser{
|
||||
walId, err := framework.PutWAL(ctx, s, "user", &walUser{
|
||||
UserName: username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -221,7 +222,7 @@ func (b *backend) secretAccessKeysCreate(
|
||||
// Remove the WAL entry, we succeeded! If we fail, we don't return
|
||||
// the secret because it'll get rolled back anyways, so we have to return
|
||||
// an error here.
|
||||
if err := framework.DeleteWAL(s, walId); err != nil {
|
||||
if err := framework.DeleteWAL(ctx, s, walId); err != nil {
|
||||
return nil, fmt.Errorf("Failed to commit WAL entry: %s", err)
|
||||
}
|
||||
|
||||
@@ -236,7 +237,7 @@ func (b *backend) secretAccessKeysCreate(
|
||||
"is_sts": false,
|
||||
})
|
||||
|
||||
lease, err := b.Lease(s)
|
||||
lease, err := b.Lease(ctx, s)
|
||||
if err != nil || lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
@@ -262,7 +263,7 @@ func (b *backend) secretAccessKeysRenew(ctx context.Context, req *logical.Reques
|
||||
}
|
||||
}
|
||||
|
||||
lease, err := b.Lease(req.Storage)
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -302,7 +303,7 @@ func secretAccessKeysRevoke(ctx context.Context, req *logical.Request, d *framew
|
||||
}
|
||||
|
||||
// Use the user rollback mechanism to delete this user
|
||||
err := pathUserRollback(req, "user", map[string]interface{}{
|
||||
err := pathUserRollback(ctx, req, "user", map[string]interface{}{
|
||||
"username": username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -11,9 +12,9 @@ import (
|
||||
)
|
||||
|
||||
// Factory creates a new backend
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -43,7 +44,7 @@ func Backend() *backend {
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
|
||||
Clean: func() {
|
||||
Clean: func(_ context.Context) {
|
||||
b.ResetDB(nil)
|
||||
},
|
||||
BackendType: logical.TypeLogical,
|
||||
@@ -77,7 +78,7 @@ type sessionConfig struct {
|
||||
}
|
||||
|
||||
// DB returns the database connection.
|
||||
func (b *backend) DB(s logical.Storage) (*gocql.Session, error) {
|
||||
func (b *backend) DB(ctx context.Context, s logical.Storage) (*gocql.Session, error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
@@ -86,7 +87,7 @@ func (b *backend) DB(s logical.Storage) (*gocql.Session, error) {
|
||||
return b.session, nil
|
||||
}
|
||||
|
||||
entry, err := s.Get("config/connection")
|
||||
entry, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -120,7 +121,7 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
|
||||
b.session = newSession
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(_ context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB(nil)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -82,7 +83,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
}
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -106,7 +107,7 @@ func TestBackend_roleCrud(t *testing.T) {
|
||||
}
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ take precedence.`,
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get("config/connection")
|
||||
entry, err := req.Storage.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -196,7 +196,7 @@ func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := getRole(req.Storage, name)
|
||||
role, err := getRole(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
session, err := b.DB(req.Storage)
|
||||
session, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -75,8 +75,8 @@ template values are '{{username}}' and
|
||||
}
|
||||
}
|
||||
|
||||
func getRole(s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + n)
|
||||
func getRole(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func getRole(s logical.Storage, n string) (*roleEntry, error) {
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("role/" + data.Get("name").(string))
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := getRole(req.Storage, data.Get("name").(string))
|
||||
role, err := getRole(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entryJSON); err != nil {
|
||||
if err := req.Storage.Put(ctx, entryJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d
|
||||
return nil, fmt.Errorf("error converting role internal data to string")
|
||||
}
|
||||
|
||||
role, err := getRole(req.Storage, roleName)
|
||||
role, err := getRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load role: %s", err)
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d
|
||||
return nil, fmt.Errorf("error converting username internal data to string")
|
||||
}
|
||||
|
||||
session, err := b.DB(req.Storage)
|
||||
session, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting session")
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package consul
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
|
||||
@@ -83,7 +83,7 @@ func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
|
||||
func TestBackend_config_access(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -130,7 +130,7 @@ func TestBackend_config_access(t *testing.T) {
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
func TestBackend_renew_revoke(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func TestBackend_renew_revoke(t *testing.T) {
|
||||
func TestBackend_management(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -289,7 +289,7 @@ func TestBackend_management(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_crud(t *testing.T) {
|
||||
b, _ := Factory(logical.TestBackendConfig())
|
||||
b, _ := Factory(context.Background(), logical.TestBackendConfig())
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
@@ -304,7 +304,7 @@ func TestBackend_crud(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBackend_role_lease(t *testing.T) {
|
||||
b, _ := Factory(logical.TestBackendConfig())
|
||||
b, _ := Factory(context.Background(), logical.TestBackendConfig())
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package consul
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func client(s logical.Storage) (*api.Client, error, error) {
|
||||
conf, userErr, intErr := readConfigAccess(s)
|
||||
func client(ctx context.Context, s logical.Storage) (*api.Client, error, error) {
|
||||
conf, userErr, intErr := readConfigAccess(ctx, s)
|
||||
if intErr != nil {
|
||||
return nil, nil, intErr
|
||||
}
|
||||
|
||||
@@ -40,8 +40,8 @@ func pathConfigAccess() *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func readConfigAccess(storage logical.Storage) (*accessConfig, error, error) {
|
||||
entry, err := storage.Get("config/access")
|
||||
func readConfigAccess(ctx context.Context, storage logical.Storage) (*accessConfig, error, error) {
|
||||
entry, err := storage.Get(ctx, "config/access")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func readConfigAccess(storage logical.Storage) (*accessConfig, error, error) {
|
||||
}
|
||||
|
||||
func pathConfigAccessRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
conf, userErr, intErr := readConfigAccess(req.Storage)
|
||||
conf, userErr, intErr := readConfigAccess(ctx, req.Storage)
|
||||
if intErr != nil {
|
||||
return nil, intErr
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func pathConfigAccessWrite(ctx context.Context, req *logical.Request, data *fram
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ Defaults to 'client'.`,
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("policy/")
|
||||
entries, err := req.Storage.List(ctx, "policy/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *fra
|
||||
func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
entry, err := req.Storage.Get("policy/" + name)
|
||||
entry, err := req.Storage.Get(ctx, "policy/"+name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
|
||||
|
||||
func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
if err := req.Storage.Delete("policy/" + name); err != nil {
|
||||
if err := req.Storage.Delete(ctx, "policy/"+name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
|
||||
@@ -29,7 +29,7 @@ func pathToken(b *backend) *framework.Path {
|
||||
func (b *backend) pathTokenRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
role := d.Get("role").(string)
|
||||
|
||||
entry, err := req.Storage.Get("policy/" + role)
|
||||
entry, err := req.Storage.Get(ctx, "policy/"+role)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving role: %s", err)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func (b *backend) pathTokenRead(ctx context.Context, req *logical.Request, d *fr
|
||||
}
|
||||
|
||||
// Get the consul client
|
||||
c, userErr, intErr := client(req.Storage)
|
||||
c, userErr, intErr := client(ctx, req.Storage)
|
||||
if intErr != nil {
|
||||
return nil, intErr
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func (b *backend) secretTokenRenew(ctx context.Context, req *logical.Request, d
|
||||
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
|
||||
}
|
||||
|
||||
entry, err := req.Storage.Get("policy/" + role)
|
||||
entry, err := req.Storage.Get(ctx, "policy/"+role)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving role: %s", err)
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func (b *backend) secretTokenRenew(ctx context.Context, req *logical.Request, d
|
||||
}
|
||||
|
||||
func secretTokenRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
c, userErr, intErr := client(req.Storage)
|
||||
c, userErr, intErr := client(ctx, req.Storage)
|
||||
if intErr != nil {
|
||||
return nil, intErr
|
||||
}
|
||||
|
||||
@@ -16,9 +16,9 @@ import (
|
||||
|
||||
const databaseConfigPath = "database/config/"
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend(conf)
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -66,7 +66,7 @@ type databaseBackend struct {
|
||||
}
|
||||
|
||||
// closeAllDBs closes all connections from all database types
|
||||
func (b *databaseBackend) closeAllDBs() {
|
||||
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
@@ -94,12 +94,12 @@ func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, na
|
||||
return db, nil
|
||||
}
|
||||
|
||||
config, err := b.DatabaseConfig(s, name)
|
||||
config, err := b.DatabaseConfig(ctx, s, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
|
||||
db, err = dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -115,8 +115,8 @@ func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, na
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) {
|
||||
entry, err := s.Get(fmt.Sprintf("config/%s", name))
|
||||
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
|
||||
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration: %s", err)
|
||||
}
|
||||
@@ -147,8 +147,8 @@ type upgradeCheck struct {
|
||||
Statements upgradeStatements `json:"statments"`
|
||||
}
|
||||
|
||||
func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + roleName)
|
||||
func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry,
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) invalidate(key string) {
|
||||
func (b *databaseBackend) invalidate(ctx context.Context, key string) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
|
||||
@@ -133,11 +133,11 @@ func TestBackend_RoleUpgrade(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := storage.Put(entry); err != nil {
|
||||
if err := storage.Put(context.Background(), entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
role, err := backend.Role(storage, "test")
|
||||
role, err := backend.Role(context.Background(), storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -152,11 +152,11 @@ func TestBackend_RoleUpgrade(t *testing.T) {
|
||||
Key: "role/test",
|
||||
Value: []byte(badJSON),
|
||||
}
|
||||
if err := storage.Put(entry); err != nil {
|
||||
if err := storage.Put(context.Background(), entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
role, err = backend.Role(storage, "test")
|
||||
role, err = backend.Role(context.Background(), storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -177,11 +177,11 @@ func TestBackend_config_connection(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
defer b.Cleanup(context.Background())
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
@@ -241,11 +241,11 @@ func TestBackend_basic(t *testing.T) {
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
defer b.Cleanup(context.Background())
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
@@ -399,11 +399,11 @@ func TestBackend_connectionCrud(t *testing.T) {
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
defer b.Cleanup(context.Background())
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
@@ -544,11 +544,11 @@ func TestBackend_roleCrud(t *testing.T) {
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
defer b.Cleanup(context.Background())
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
@@ -656,11 +656,11 @@ func TestBackend_allowedRoles(t *testing.T) {
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
|
||||
b, err := Factory(config)
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Cleanup()
|
||||
defer b.Cleanup(context.Background())
|
||||
|
||||
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
|
||||
defer cleanup()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
@@ -30,13 +31,13 @@ func (dc *DatabasePluginClient) Close() error {
|
||||
// newPluginClient returns a databaseRPCClient with a connection to a running
|
||||
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
|
||||
// plugin is killed on call of Close().
|
||||
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (Database, error) {
|
||||
func newPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (Database, error) {
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
var pluginMap = map[string]plugin.Plugin{
|
||||
"database": new(DatabasePlugin),
|
||||
}
|
||||
|
||||
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
client, err := pluginRunner.Run(ctx, sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -26,9 +26,9 @@ type Database interface {
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
// Look for plugin in the plugin catalog
|
||||
pluginRunner, err := sys.LookupPlugin(pluginName)
|
||||
pluginRunner, err := sys.LookupPlugin(ctx, pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||
|
||||
} else {
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = newPluginClient(sys, pluginRunner, logger)
|
||||
db, err = newPluginClient(ctx, sys, pluginRunner, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func TestPlugin_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
dbRaw, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -200,7 +200,7 @@ func TestPlugin_RenewUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -234,7 +234,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{})
|
||||
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -276,7 +276,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
dbRaw, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -300,7 +300,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -340,7 +340,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -374,7 +374,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
db, err := dbplugin.PluginFactory(context.Background(), "test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ func pathListPluginConnection(b *databaseBackend) *framework.Path {
|
||||
|
||||
func (b *databaseBackend) connectionListHandler() framework.OperationFunc {
|
||||
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("config/")
|
||||
entries, err := req.Storage.List(ctx, "config/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
|
||||
return logical.ErrorResponse(respErrEmptyName), nil
|
||||
}
|
||||
|
||||
entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name))
|
||||
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read connection configuration")
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
|
||||
return logical.ErrorResponse(respErrEmptyName), nil
|
||||
}
|
||||
|
||||
err := req.Storage.Delete(fmt.Sprintf("config/%s", name))
|
||||
err := req.Storage.Delete(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to delete connection configuration")
|
||||
}
|
||||
@@ -226,7 +226,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
AllowedRoles: allowedRoles,
|
||||
}
|
||||
|
||||
db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
|
||||
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
}
|
||||
@@ -252,7 +252,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(req.Storage, name)
|
||||
role, err := b.Role(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName)
|
||||
dbConfig, err := b.DatabaseConfig(ctx, req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
||||
|
||||
func (b *databaseBackend) pathRoleDelete() framework.OperationFunc {
|
||||
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("role/" + data.Get("name").(string))
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (b *databaseBackend) pathRoleDelete() framework.OperationFunc {
|
||||
|
||||
func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
|
||||
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(req.Storage, data.Get("name").(string))
|
||||
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
|
||||
|
||||
func (b *databaseBackend) pathRoleList() framework.OperationFunc {
|
||||
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("role/")
|
||||
entries, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
role, err := b.Role(ctx, req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
||||
return nil, fmt.Errorf("no role name was provided")
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
role, err := b.Role(ctx, req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -11,9 +12,9 @@ import (
|
||||
"gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
@@ -59,7 +60,7 @@ type backend struct {
|
||||
}
|
||||
|
||||
// Session returns the database connection.
|
||||
func (b *backend) Session(s logical.Storage) (*mgo.Session, error) {
|
||||
func (b *backend) Session(ctx context.Context, s logical.Storage) (*mgo.Session, error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
@@ -70,7 +71,7 @@ func (b *backend) Session(s logical.Storage) (*mgo.Session, error) {
|
||||
b.session.Close()
|
||||
}
|
||||
|
||||
connConfigJSON, err := s.Get("config/connection")
|
||||
connConfigJSON, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func (b *backend) Session(s logical.Storage) (*mgo.Session, error) {
|
||||
}
|
||||
|
||||
// ResetSession forces creation of a new connection next time Session() is called.
|
||||
func (b *backend) ResetSession() {
|
||||
func (b *backend) ResetSession(_ context.Context) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
@@ -110,16 +111,16 @@ func (b *backend) ResetSession() {
|
||||
b.session = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetSession()
|
||||
b.ResetSession(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseConfig returns the lease configuration
|
||||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
func (b *backend) LeaseConfig(ctx context.Context, s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(ctx, "config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user