mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Removed region parameter from config/client endpoint.
Region to create ec2 client objects is fetched from the identity document. Maintaining a map of cached clients indexed by region.
This commit is contained in:
@@ -55,6 +55,8 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
}
|
||||
|
||||
b.EC2ClientsMap = make(map[string]*ec2.EC2)
|
||||
|
||||
return b.Backend, nil
|
||||
}
|
||||
|
||||
@@ -64,7 +66,7 @@ type backend struct {
|
||||
|
||||
configMutex sync.RWMutex
|
||||
|
||||
ec2Client *ec2.EC2
|
||||
EC2ClientsMap map[string]*ec2.EC2
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
|
||||
@@ -2,7 +2,6 @@ package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
// * Static credentials from 'config/client'
|
||||
// * Environment variables
|
||||
// * Instance metadata role
|
||||
func (b *backend) getClientConfig(s logical.Storage) (*aws.Config, error) {
|
||||
func (b *backend) getClientConfig(s logical.Storage, region string) (*aws.Config, error) {
|
||||
// Read the configured secret key and access key
|
||||
config, err := clientConfigEntry(s)
|
||||
if err != nil {
|
||||
@@ -29,13 +28,8 @@ func (b *backend) getClientConfig(s logical.Storage) (*aws.Config, error) {
|
||||
}
|
||||
|
||||
var providers []credentials.Provider
|
||||
region := os.Getenv("AWS_REGION")
|
||||
|
||||
if config != nil {
|
||||
if config.Region != "" {
|
||||
region = config.Region
|
||||
}
|
||||
|
||||
switch {
|
||||
case config.AccessKey != "" && config.SecretKey != "":
|
||||
providers = append(providers, &credentials.StaticProvider{
|
||||
@@ -75,13 +69,23 @@ func (b *backend) getClientConfig(s logical.Storage) (*aws.Config, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend.
|
||||
func (b *backend) flushCachedEC2Clients() {
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
for region, _ := range b.EC2ClientsMap {
|
||||
delete(b.EC2ClientsMap, region)
|
||||
}
|
||||
}
|
||||
|
||||
// clientEC2 creates a client to interact with AWS EC2 API.
|
||||
func (b *backend) clientEC2(s logical.Storage, recreate bool) (*ec2.EC2, error) {
|
||||
func (b *backend) clientEC2(s logical.Storage, region string, recreate bool) (*ec2.EC2, error) {
|
||||
if !recreate {
|
||||
b.configMutex.RLock()
|
||||
if b.ec2Client != nil {
|
||||
if b.EC2ClientsMap[region] != nil {
|
||||
defer b.configMutex.RUnlock()
|
||||
return b.ec2Client, nil
|
||||
return b.EC2ClientsMap[region], nil
|
||||
}
|
||||
b.configMutex.RUnlock()
|
||||
}
|
||||
@@ -89,11 +93,11 @@ func (b *backend) clientEC2(s logical.Storage, recreate bool) (*ec2.EC2, error)
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
awsConfig, err := b.getClientConfig(s)
|
||||
awsConfig, err := b.getClientConfig(s, region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.ec2Client = ec2.New(session.New(awsConfig))
|
||||
return b.ec2Client, nil
|
||||
b.EC2ClientsMap[region] = ec2.New(session.New(awsConfig))
|
||||
return b.EC2ClientsMap[region], nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
@@ -21,11 +19,6 @@ func pathConfigClient(b *backend) *framework.Path {
|
||||
Type: framework.TypeString,
|
||||
Description: "AWS Secret key with permissions to query EC2 instance metadata.",
|
||||
},
|
||||
|
||||
"region": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Region for API calls. Defaults to the value of the AWS_REGION env var. Required.",
|
||||
},
|
||||
},
|
||||
|
||||
ExistenceCheck: b.pathConfigClientExistenceCheck,
|
||||
@@ -104,10 +97,8 @@ func (b *backend) pathConfigClientDelete(
|
||||
|
||||
b.configMutex.Unlock()
|
||||
|
||||
_, err = b.clientEC2(req.Storage, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client with updated credentials: %s", err)
|
||||
}
|
||||
// Remove all the cached EC2 client objects in the backend.
|
||||
b.flushCachedEC2Clients()
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -127,21 +118,14 @@ func (b *backend) pathConfigClientCreateUpdate(
|
||||
configEntry = &clientConfig{}
|
||||
}
|
||||
|
||||
regionStr, ok := data.GetOk("region")
|
||||
if ok {
|
||||
configEntry.Region = regionStr.(string)
|
||||
} else if req.Operation == logical.CreateOperation {
|
||||
configEntry.Region = data.Get("region").(string)
|
||||
}
|
||||
|
||||
changedCreds := false
|
||||
|
||||
accessKeyStr, ok := data.GetOk("access_key")
|
||||
if ok {
|
||||
if configEntry.AccessKey != accessKeyStr.(string) {
|
||||
changedCreds = true
|
||||
configEntry.AccessKey = accessKeyStr.(string)
|
||||
}
|
||||
configEntry.AccessKey = accessKeyStr.(string)
|
||||
} else if req.Operation == logical.CreateOperation {
|
||||
// Use the default
|
||||
configEntry.AccessKey = data.Get("access_key").(string)
|
||||
@@ -151,8 +135,8 @@ func (b *backend) pathConfigClientCreateUpdate(
|
||||
if ok {
|
||||
if configEntry.SecretKey != secretKeyStr.(string) {
|
||||
changedCreds = true
|
||||
configEntry.SecretKey = secretKeyStr.(string)
|
||||
}
|
||||
configEntry.SecretKey = secretKeyStr.(string)
|
||||
} else if req.Operation == logical.CreateOperation {
|
||||
configEntry.SecretKey = data.Get("secret_key").(string)
|
||||
}
|
||||
@@ -170,13 +154,8 @@ func (b *backend) pathConfigClientCreateUpdate(
|
||||
// We have to be careful here to re-lock as we have a deferred unlock
|
||||
// queued up and unlocking an unlocked mutex leads to a panic
|
||||
b.configMutex.Unlock()
|
||||
_, err = b.clientEC2(req.Storage, true)
|
||||
b.flushCachedEC2Clients()
|
||||
b.configMutex.Lock()
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"error creating client with updated credentials: %s", err),
|
||||
), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
@@ -187,7 +166,6 @@ func (b *backend) pathConfigClientCreateUpdate(
|
||||
type clientConfig struct {
|
||||
AccessKey string `json:"access_key" structs:"access_key" mapstructure:"access_key"`
|
||||
SecretKey string `json:"secret_key" structs:"secret_key" mapstructure:"secret_key"`
|
||||
Region string `json:"region" structs:"region" mapstructure:"region"`
|
||||
}
|
||||
|
||||
const pathConfigClientHelpSyn = `
|
||||
|
||||
@@ -38,18 +38,18 @@ func pathLogin(b *backend) *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
// validateInstanceID queries the status of the EC2 instance using AWS EC2 API and
|
||||
// validateInstance queries the status of the EC2 instance using AWS EC2 API and
|
||||
// checks if the instance is running and is healthy.
|
||||
func (b *backend) validateInstanceID(s logical.Storage, instanceID string) error {
|
||||
func (b *backend) validateInstance(s logical.Storage, identityDoc *identityDocument) error {
|
||||
// Create an EC2 client to pull the instance information
|
||||
ec2Client, err := b.clientEC2(s, false)
|
||||
ec2Client, err := b.clientEC2(s, identityDoc.Region, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the status of the instance
|
||||
instanceStatus, err := ec2Client.DescribeInstanceStatus(&ec2.DescribeInstanceStatusInput{
|
||||
InstanceIds: []*string{aws.String(instanceID)},
|
||||
InstanceIds: []*string{aws.String(identityDoc.InstanceID)},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -186,12 +186,9 @@ func (b *backend) pathLoginUpdate(
|
||||
}
|
||||
|
||||
// Validate the instance ID.
|
||||
//TODO: uncomment this block, until the API invoking problem is resolved.
|
||||
/*
|
||||
if err := b.validateInstanceID(req.Storage, identityDoc.InstanceID); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %s", err)), nil
|
||||
}
|
||||
*/
|
||||
if err := b.validateInstance(req.Storage, identityDoc); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %s", err)), nil
|
||||
}
|
||||
|
||||
// Get the entry for the AMI used by the instance.
|
||||
imageEntry, err := awsImage(req.Storage, identityDoc.AmiID)
|
||||
@@ -311,8 +308,8 @@ func (b *backend) pathLoginUpdate(
|
||||
|
||||
// fetchRoleTagValue creates an AWS EC2 client and queries the tags
|
||||
// attached to the instance identified by the given instanceID.
|
||||
func (b *backend) fetchRoleTagValue(s logical.Storage, tagKey string) (string, error) {
|
||||
ec2Client, err := b.clientEC2(s, false)
|
||||
func (b *backend) fetchRoleTagValue(s logical.Storage, region string, tagKey string) (string, error) {
|
||||
ec2Client, err := b.clientEC2(s, region, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -350,7 +347,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, identityDoc *identityDoc
|
||||
// NOTE: If AWS adds the instance tags as meta-data in the instance identity
|
||||
// document, then it is better to look this information there instead of making
|
||||
// another API call. Currently, we don't have an option but make this call.
|
||||
rTagValue, err := b.fetchRoleTagValue(s, imageEntry.RoleTag)
|
||||
rTagValue, err := b.fetchRoleTagValue(s, identityDoc.Region, imageEntry.RoleTag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user