Polish the code

This commit is contained in:
vishalnayak
2016-06-08 03:18:26 -04:00
parent 8b15722fb4
commit f216292e68
8 changed files with 160 additions and 313 deletions

View File

@@ -10,13 +10,13 @@ import (
"github.com/michaelklishin/rabbit-hole" "github.com/michaelklishin/rabbit-hole"
) )
// Factory creates and configures Backends // Factory creates and configures the backend
func Factory(conf *logical.BackendConfig) (logical.Backend, error) { func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend().Setup(conf) return Backend().Setup(conf)
} }
// Backend creates a new Backend // Creates a new backend with all the paths and secrets belonging to it
func Backend() *framework.Backend { func Backend() *backend {
var b backend var b backend
b.Backend = &framework.Backend{ b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp), Help: strings.TrimSpace(backendHelp),
@@ -25,7 +25,7 @@ func Backend() *framework.Backend {
pathConfigConnection(&b), pathConfigConnection(&b),
pathConfigLease(&b), pathConfigLease(&b),
pathListRoles(&b), pathListRoles(&b),
pathRoleCreate(&b), pathCreds(&b),
pathRoles(&b), pathRoles(&b),
}, },
@@ -33,37 +33,38 @@ func Backend() *framework.Backend {
secretCreds(&b), secretCreds(&b),
}, },
Clean: b.ResetClient, Clean: b.resetClient,
} }
return b.Backend return &b
} }
type backend struct { type backend struct {
*framework.Backend *framework.Backend
client *rabbithole.Client client *rabbithole.Client
lock sync.Mutex lock sync.RWMutex
} }
// DB returns the database connection. // DB returns the database connection.
func (b *backend) Client(s logical.Storage) (*rabbithole.Client, error) { func (b *backend) Client(s logical.Storage) (*rabbithole.Client, error) {
b.lock.Lock() b.lock.RLock()
defer b.lock.Unlock()
// If we already have a client, we got it! // If we already have a client, return it
if b.client != nil { if b.client != nil {
b.lock.RUnlock()
return b.client, nil return b.client, nil
} }
b.lock.RUnlock()
// Otherwise, attempt to make connection // Otherwise, attempt to make connection
entry, err := s.Get("config/connection") entry, err := s.Get("config/connection")
if err != nil { if err != nil {
return nil, err return nil, err
} }
if entry == nil { if entry == nil {
return nil, return nil, fmt.Errorf("configure the client connection with config/connection first")
fmt.Errorf("configure the client connection with config/connection first")
} }
var connConfig connectionConfig var connConfig connectionConfig
@@ -71,6 +72,14 @@ func (b *backend) Client(s logical.Storage) (*rabbithole.Client, error) {
return nil, err return nil, err
} }
b.lock.Lock()
defer b.lock.Unlock()
// If the client was creted during the lock switch, return it
if b.client != nil {
return b.client, nil
}
b.client, err = rabbithole.NewClient(connConfig.URI, connConfig.Username, connConfig.Password) b.client, err = rabbithole.NewClient(connConfig.URI, connConfig.Username, connConfig.Password)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -79,8 +88,8 @@ func (b *backend) Client(s logical.Storage) (*rabbithole.Client, error) {
return b.client, nil return b.client, nil
} }
// ResetClient forces a connection next time Client() is called. // resetClient forces a connection next time Client() is called.
func (b *backend) ResetClient() { func (b *backend) resetClient() {
b.lock.Lock() b.lock.Lock()
defer b.lock.Unlock() defer b.lock.Unlock()
@@ -89,7 +98,7 @@ func (b *backend) ResetClient() {
// Lease returns the lease information // Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) { func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get(leasePatternLabel) entry, err := s.Get("config/lease")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -42,22 +42,18 @@ func pathConfigConnection(b *backend) *framework.Path {
func (b *backend) pathConnectionUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathConnectionUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
uri := data.Get("connection_uri").(string) uri := data.Get("connection_uri").(string)
username := data.Get("username").(string)
password := data.Get("password").(string)
if uri == "" { if uri == "" {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse("missing connection_uri"), nil
"'connection_uri' is a required parameter.")), nil
} }
username := data.Get("username").(string)
if username == "" { if username == "" {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse("missing username"), nil
"'username' is a required parameter.")), nil
} }
password := data.Get("password").(string)
if password == "" { if password == "" {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse("missing password"), nil
"'password' is a required parameter.")), nil
} }
// Don't check the connection_url if verification is disabled // Don't check the connection_url if verification is disabled
@@ -66,15 +62,12 @@ func (b *backend) pathConnectionUpdate(req *logical.Request, data *framework.Fie
// Create RabbitMQ management client // Create RabbitMQ management client
client, err := rabbithole.NewClient(uri, username, password) client, err := rabbithole.NewClient(uri, username, password)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf( return nil, fmt.Errorf("failed to create client: %s", err)
"Error info: %s", err)), nil
} }
// Verify provided user is able to list users // Verify that configured credentials is capable of listing
_, err = client.ListUsers() if _, err = client.ListUsers(); err != nil {
if err != nil { return nil, fmt.Errorf("failed to validate the connection: %s", err)
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info by listing users: %s", err)), nil
} }
} }
@@ -92,15 +85,20 @@ func (b *backend) pathConnectionUpdate(req *logical.Request, data *framework.Fie
} }
// Reset the client connection // Reset the client connection
b.ResetClient() b.resetClient()
return nil, nil return nil, nil
} }
// connectionConfig contains the information required to make a connection to a RabbitMQ node
type connectionConfig struct { type connectionConfig struct {
// URI of the RabbitMQ server
URI string `json:"connection_uri"` URI string `json:"connection_uri"`
VerifyURI string `json:"verify_connection"`
// Username which has 'administrator' tag attached to it
Username string `json:"username"` Username string `json:"username"`
// Password for the Username
Password string `json:"password"` Password string `json:"password"`
} }

View File

@@ -1,38 +1,26 @@
package rabbitmq package rabbitmq
import ( import (
"errors"
"fmt"
"time" "time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
const (
leaseLabel = "ttl"
leaseMaxLabel = "ttl_max"
leasePatternLabel = "config/" + leaseLabel
)
func configFields() map[string]*framework.FieldSchema {
return map[string]*framework.FieldSchema{
leaseLabel: &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: "Default " + leaseLabel + " for roles.",
},
leaseMaxLabel: &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: "Maximum time a credential is valid for.",
},
}
}
func pathConfigLease(b *backend) *framework.Path { func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: leasePatternLabel, Pattern: "config/lease",
Fields: configFields(), Fields: map[string]*framework.FieldSchema{
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: "Duration before which the issued credentials needs renewal",
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration after which the issued credentials shoulw not be allowed to be renewed`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead, logical.ReadOperation: b.pathLeaseRead,
@@ -44,17 +32,11 @@ func pathConfigLease(b *backend) *framework.Path {
} }
} }
func (b *backend) pathLeaseUpdate( // Sets the lease configuration parameters
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { func (b *backend) pathLeaseUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
lease, leaseMax, err := validateLeases(d) entry, err := logical.StorageEntryJSON("config/lease", &configLease{
if err != nil { TTL: time.Second * time.Duration(d.Get("ttl").(int)),
return nil, err MaxTTL: time.Second * time.Duration(d.Get("ttl").(int)),
}
// Store it
entry, err := logical.StorageEntryJSON(leasePatternLabel, &configLease{
Lease: lease,
LeaseMax: leaseMax,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -66,10 +48,9 @@ func (b *backend) pathLeaseUpdate(
return nil, nil return nil, nil
} }
func (b *backend) pathLeaseRead( // Returns the lease configuration parameters
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathLeaseRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(req.Storage) lease, err := b.Lease(req.Storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -78,41 +59,19 @@ func (b *backend) pathLeaseRead(
} }
return &logical.Response{ return &logical.Response{
Data: map[string]interface{}{ Data: structs.New(lease).Map(),
leaseLabel: lease.Lease.String(),
leaseMaxLabel: lease.LeaseMax.String(),
},
}, nil }, nil
} }
// Lease configuration information for the secrets issued by this backend
type configLease struct { type configLease struct {
Lease time.Duration TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
LeaseMax time.Duration MaxTTL time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
} }
func validateLeases(data *framework.FieldData) (lease, leaseMax time.Duration, err error) { var pathConfigLeaseHelpSyn = "Configure the lease parameters for generated credentials"
leaseRaw := data.Get(leaseLabel).(int) var pathConfigLeaseHelpDesc = `
leaseMaxRaw := data.Get(leaseMaxLabel).(int) Sets the ttl and max_ttl values for the secrets to be issued by this backend.
Both ttl and max_ttl takes in an integet input as well as inputs like "1h".
if leaseRaw == 0 && leaseMaxRaw == 0 { `
err = errors.New(leaseLabel + " or " + leaseMaxLabel + " must have a value")
return
}
return time.Duration(leaseRaw) * time.Second, time.Duration(leaseMaxRaw) * time.Second, nil
}
var pathConfigLeaseHelpSyn = fmt.Sprintf(`
Configure the default %s information for generated credentials.
`, leaseLabel)
var pathConfigLeaseHelpDesc = fmt.Sprintf(`
This configures the default %s information used for credentials
generated by this backend. The %s specifies the duration that a
credential will be valid for, as well as the maximum session for
a set of credentials.
The format for the %s is "1h" or integer and then unit. The longest
unit is hour.
`, leaseLabel, leaseLabel, leaseLabel)

View File

@@ -1,53 +1,7 @@
package rabbitmq package rabbitmq
import ( import "testing"
"testing"
"github.com/hashicorp/vault/logical/framework" func TestBackend_config_lease(t *testing.T) {
)
type validateLeasesTestCase struct {
Lease int
LeaseMax int
Fail bool
}
func TestConfigLease_validateLeases(t *testing.T) {
cases := map[string]validateLeasesTestCase{
"Both lease and lease max": {
Lease: 60 * 60,
LeaseMax: 60 * 60,
},
"Just lease": {
Lease: 60 * 60,
LeaseMax: 0,
},
"No lease nor lease max": {
Lease: 0,
LeaseMax: 0,
Fail: true,
},
}
data := &framework.FieldData{
Schema: configFields(),
}
for name, c := range cases {
data.Raw = map[string]interface{}{
leaseLabel: c.Lease,
leaseMaxLabel: c.LeaseMax,
}
_, _, err := validateLeases(data)
if err != nil && c.Fail {
// This was expected
continue
} else if err != nil {
// This was unexpected
t.Errorf("Failed: %s", name)
} else if err == nil && c.Fail {
// This was unexpected
t.Errorf("Failed to fail: %s", name)
}
}
} }

View File

@@ -9,7 +9,7 @@ import (
"github.com/michaelklishin/rabbit-hole" "github.com/michaelklishin/rabbit-hole"
) )
func pathRoleCreate(b *backend) *framework.Path { func pathCreds(b *backend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"), Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
@@ -20,7 +20,7 @@ func pathRoleCreate(b *backend) *framework.Path {
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleCreateRead, logical.ReadOperation: b.pathCredsRead,
}, },
HelpSynopsis: pathRoleCreateReadHelpSyn, HelpSynopsis: pathRoleCreateReadHelpSyn,
@@ -28,12 +28,11 @@ func pathRoleCreate(b *backend) *framework.Path {
} }
} }
func (b *backend) pathRoleCreateRead( // Issues the credential based on the role name
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathCredsRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Validate name name := d.Get("name").(string)
name, err := validateName(data) if name == "" {
if err != nil { return logical.ErrorResponse("missing name"), nil
return nil, err
} }
// Get the role // Get the role
@@ -45,53 +44,40 @@ func (b *backend) pathRoleCreateRead(
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
} }
// Determine if we have a lease
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
// Ensure username is unique // Ensure username is unique
username := fmt.Sprintf("%s-%s", req.DisplayName, uuid.GenerateUUID()) username := fmt.Sprintf("%s-%s", req.DisplayName, uuid.GenerateUUID())
password := uuid.GenerateUUID() password := uuid.GenerateUUID()
// Get our connection // Get the client configuration
client, err := b.Client(req.Storage) client, err := b.Client(req.Storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if client == nil { if client == nil {
return logical.ErrorResponse("unable to get client"), nil return logical.ErrorResponse("failed to get the client"), nil
} }
// Create the user // Register the generated credentials in the backend, with the RabbitMQ server
_, err = client.PutUser(username, rabbithole.UserSettings{ if _, err = client.PutUser(username, rabbithole.UserSettings{
Password: password, Password: password,
Tags: role.Tags, Tags: role.Tags,
}) }); err != nil {
return nil, fmt.Errorf("failed to create a new user with the generated credentials")
if err != nil {
return nil, err
} }
// If the role had vhost permissions specified, assign those permissions
// to the created username for respective vhosts.
for vhost, permission := range role.VHosts { for vhost, permission := range role.VHosts {
_, err := client.UpdatePermissionsIn(vhost, username, rabbithole.Permissions{ if _, err := client.UpdatePermissionsIn(vhost, username, rabbithole.Permissions{
Configure: permission.Configure, Configure: permission.Configure,
Write: permission.Write, Write: permission.Write,
Read: permission.Read, Read: permission.Read,
}) }); err != nil {
if err != nil {
// Delete the user because it's in an unknown state // Delete the user because it's in an unknown state
_, rmErr := client.DeleteUser(username) if _, rmErr := client.DeleteUser(username); rmErr != nil {
if rmErr != nil { return nil, fmt.Errorf("failed to delete user:%s, err: %s. %s", username, err, rmErr)
return logical.ErrorResponse(fmt.Sprintf("failed to update user: %s, failed to delete user: %s, user: %s", err, rmErr, username)), rmErr
} }
return logical.ErrorResponse(fmt.Sprintf("failed to update user: %s, user: %s", err, username)), err return nil, fmt.Errorf("failed to update permissions to the %s user. err:%s", username, err)
} }
} }
@@ -102,7 +88,16 @@ func (b *backend) pathRoleCreateRead(
}, map[string]interface{}{ }, map[string]interface{}{
"username": username, "username": username,
}) })
resp.Secret.TTL = lease.Lease
// Determine if we have a lease
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease != nil {
resp.Secret.TTL = lease.TTL
}
return resp, nil return resp, nil
} }

View File

@@ -2,40 +2,19 @@ package rabbitmq
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
func rolesFields() map[string]*framework.FieldSchema {
return map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role.",
},
"tags": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Comma-separated list of tags for this role.",
},
"vhosts": &framework.FieldSchema{
Type: framework.TypeString,
Description: "A map of virtual hosts to permissions.",
},
}
}
func pathListRoles(b *backend) *framework.Path { func pathListRoles(b *backend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "roles/?$", Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList, logical.ListOperation: b.pathRoleList,
}, },
HelpSynopsis: pathRoleHelpSyn, HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc, HelpDescription: pathRoleHelpDesc,
} }
@@ -44,19 +23,31 @@ func pathListRoles(b *backend) *framework.Path {
func pathRoles(b *backend) *framework.Path { func pathRoles(b *backend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"), Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: rolesFields(), Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role.",
},
"tags": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Comma-separated list of tags for this role.",
},
"vhosts": &framework.FieldSchema{
Type: framework.TypeString,
Description: "A map of virtual hosts to permissions.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead, logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleUpdate, logical.UpdateOperation: b.pathRoleUpdate,
logical.DeleteOperation: b.pathRoleDelete, logical.DeleteOperation: b.pathRoleDelete,
}, },
HelpSynopsis: pathRoleHelpSyn, HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc, HelpDescription: pathRoleHelpDesc,
} }
} }
// Reads the role configuration from the storage
func (b *backend) Role(s logical.Storage, n string) (*roleEntry, error) { func (b *backend) Role(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n) entry, err := s.Get("role/" + n)
if err != nil { if err != nil {
@@ -74,28 +65,21 @@ func (b *backend) Role(s logical.Storage, n string) (*roleEntry, error) {
return &result, nil return &result, nil
} }
func (b *backend) pathRoleDelete( // Deletes an existing role
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
name, err := validateName(data) if name == "" {
if err != nil { return logical.ErrorResponse("missing name"), nil
return nil, err
} }
err = req.Storage.Delete("role/" + name) return nil, req.Storage.Delete("role/" + name)
if err != nil {
return nil, err
}
return nil, nil
} }
func (b *backend) pathRoleRead( // Reads an existing role
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
name, err := validateName(data) if name == "" {
if err != nil { return logical.ErrorResponse("missing name"), nil
return nil, err
} }
role, err := b.Role(req.Storage, name) role, err := b.Role(req.Storage, name)
@@ -107,31 +91,34 @@ func (b *backend) pathRoleRead(
} }
return &logical.Response{ return &logical.Response{
Data: map[string]interface{}{ Data: structs.New(role).Map(),
"tags": role.Tags,
"vhosts": role.VHosts,
},
}, nil }, nil
} }
// Lists all the roles registered with the backend
func (b *backend) pathRoleList( func (b *backend) pathRoleList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("role/") roles, err := req.Storage.List("role/")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return logical.ListResponse(entries), nil return logical.ListResponse(roles), nil
} }
func (b *backend) pathRoleUpdate( // Registers a new role with the backend
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *backend) pathRoleUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name, err := validateName(data) name := d.Get("name").(string)
if err != nil { if name == "" {
return nil, err return logical.ErrorResponse("missing name"), nil
}
tags := d.Get("tags").(string)
rawVHosts := d.Get("vhosts").(string)
if tags == "" && rawVHosts == "" {
return logical.ErrorResponse("both tags and vhosts not specified"), nil
} }
tags := data.Get("tags").(string)
rawVHosts := data.Get("vhosts").(string)
var vhosts map[string]vhostPermission var vhosts map[string]vhostPermission
if len(rawVHosts) > 0 { if len(rawVHosts) > 0 {
@@ -156,24 +143,17 @@ func (b *backend) pathRoleUpdate(
return nil, nil return nil, nil
} }
// Role that defines the capabilities of the credentials issued against it
type roleEntry struct { type roleEntry struct {
Tags string `json:"tags"` Tags string `json:"tags" structs:"tags" mapstructure:"tags"`
VHosts map[string]vhostPermission `json:"vhosts"` VHosts map[string]vhostPermission `json:"vhosts" structs:"vhosts" mapstructure:"vhosts"`
} }
// Structure representing the permissions of a vhost
type vhostPermission struct { type vhostPermission struct {
Configure string `json:"configure"` Configure string `json:"configure" structs:"configure" mapstructure:"configure"`
Write string `json:"write"` Write string `json:"write" structs:"write" mapstructure:"write"`
Read string `json:"read"` Read string `json:"read" structs:"write" mapstructure:"read"`
}
func validateName(data *framework.FieldData) (string, error) {
name := data.Get("name").(string)
if len(name) == 0 {
return "", errors.New("name is required")
}
return name, nil
} }
const pathRoleHelpSyn = ` const pathRoleHelpSyn = `

View File

@@ -1,42 +1 @@
package rabbitmq package rabbitmq
import (
"testing"
"github.com/hashicorp/vault/logical/framework"
)
type validateNameTestCase struct {
Name string
Fail bool
}
func TestRoles_validateName(t *testing.T) {
cases := map[string]validateNameTestCase{
"test name": {
Name: "test",
},
"empty name": {
Name: "",
Fail: true,
},
}
data := &framework.FieldData{
Schema: rolesFields(),
}
for name, c := range cases {
data.Raw = map[string]interface{}{
"name": c.Name,
}
actual, err := validateName(data)
if err != nil && !c.Fail {
t.Error(err)
}
if c.Name != actual {
t.Errorf("Fail: %s: expected %s, got %s", name, c.Name, actual)
}
}
}

View File

@@ -16,20 +16,19 @@ func secretCreds(b *backend) *framework.Secret {
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
"username": &framework.FieldSchema{ "username": &framework.FieldSchema{
Type: framework.TypeString, Type: framework.TypeString,
Description: "Username", Description: "RabbitMQ username",
}, },
"password": &framework.FieldSchema{ "password": &framework.FieldSchema{
Type: framework.TypeString, Type: framework.TypeString,
Description: "Password", Description: "Password for the RabbitMQ username",
}, },
}, },
Renew: b.secretCredsRenew, Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke, Revoke: b.secretCredsRevoke,
} }
} }
// Renew the previously issued secret
func (b *backend) secretCredsRenew( func (b *backend) secretCredsRenew(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the lease information // Get the lease information
@@ -41,15 +40,10 @@ func (b *backend) secretCredsRenew(
lease = &configLease{} lease = &configLease{}
} }
f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) return framework.LeaseExtend(lease.TTL, lease.MaxTTL, b.System())(req, d)
resp, err := f(req, d)
if err != nil {
return nil, err
}
return resp, nil
} }
// Revoke the previously issued secret
func (b *backend) secretCredsRevoke( func (b *backend) secretCredsRevoke(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data // Get the username from the internal data
@@ -65,9 +59,8 @@ func (b *backend) secretCredsRevoke(
return nil, err return nil, err
} }
_, err = client.DeleteUser(username) if _, err = client.DeleteUser(username); err != nil {
if err != nil { return nil, fmt.Errorf("could not delete user: %s", err)
return logical.ErrorResponse(fmt.Sprintf("could not delete user: %s", err)), nil
} }
return nil, nil return nil, nil