Vault SSH: Store roles as slice of strings

This commit is contained in:
vishalnayak
2015-08-31 17:03:46 -04:00
parent f67a12266e
commit 22ff8fc8ad
4 changed files with 28 additions and 25 deletions

View File

@@ -193,30 +193,36 @@ func TestSSHBackend_VerifyEcho(t *testing.T) {
} }
func TestSSHBackend_ConfigZeroAddressCRUD(t *testing.T) { func TestSSHBackend_ConfigZeroAddressCRUD(t *testing.T) {
zeroAddressData1 := map[string]interface{}{ req1 := map[string]interface{}{
"roles": testOTPRoleName, "roles": testOTPRoleName,
} }
zeroAddressData2 := map[string]interface{}{ resp1 := map[string]interface{}{
"roles": []string{testOTPRoleName},
}
req2 := map[string]interface{}{
"roles": fmt.Sprintf("%s,%s", testOTPRoleName, testDynamicRoleName), "roles": fmt.Sprintf("%s,%s", testOTPRoleName, testDynamicRoleName),
} }
zeroAddressData3 := map[string]interface{}{ resp2 := map[string]interface{}{
"roles": "", "roles": []string{testOTPRoleName, testDynamicRoleName},
}
resp3 := map[string]interface{}{
"roles": []string{},
} }
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory, Factory: Factory,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testRoleWrite(t, testOTPRoleName, testOTPRoleData), testRoleWrite(t, testOTPRoleName, testOTPRoleData),
testConfigZeroAddressWrite(t, zeroAddressData1), testConfigZeroAddressWrite(t, req1),
testConfigZeroAddressRead(t, zeroAddressData1), testConfigZeroAddressRead(t, resp1),
testNamedKeysWrite(t), testNamedKeysWrite(t),
testRoleWrite(t, testDynamicRoleName, testDynamicRoleData), testRoleWrite(t, testDynamicRoleName, testDynamicRoleData),
testConfigZeroAddressWrite(t, zeroAddressData2), testConfigZeroAddressWrite(t, req2),
testConfigZeroAddressRead(t, zeroAddressData2), testConfigZeroAddressRead(t, resp2),
testRoleDelete(t, testDynamicRoleName), testRoleDelete(t, testDynamicRoleName),
testConfigZeroAddressRead(t, zeroAddressData1), testConfigZeroAddressRead(t, resp1),
testRoleDelete(t, testOTPRoleName), testRoleDelete(t, testOTPRoleName),
testConfigZeroAddressRead(t, zeroAddressData3), testConfigZeroAddressRead(t, resp3),
testConfigZeroAddressDelete(t), testConfigZeroAddressDelete(t),
}, },
}) })

View File

@@ -10,7 +10,7 @@ import (
// Structure to hold roles that are allowed to accept any IP address. // Structure to hold roles that are allowed to accept any IP address.
type zeroAddressRoles struct { type zeroAddressRoles struct {
Roles string `json:"roles" mapstructure:"roles"` Roles []string `json:"roles" mapstructure:"roles"`
} }
func pathConfigZeroAddress(b *backend) *framework.Path { func pathConfigZeroAddress(b *backend) *framework.Path {
@@ -76,7 +76,7 @@ func (b *backend) pathConfigZeroAddressWrite(req *logical.Request, d *framework.
} }
} }
err := b.putZeroAddressRoles(req.Storage, roleNames) err := b.putZeroAddressRoles(req.Storage, roles)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -85,7 +85,7 @@ func (b *backend) pathConfigZeroAddressWrite(req *logical.Request, d *framework.
} }
// Stores the given list of roles at zeroaddress endpoint // Stores the given list of roles at zeroaddress endpoint
func (b *backend) putZeroAddressRoles(s logical.Storage, roles string) error { func (b *backend) putZeroAddressRoles(s logical.Storage, roles []string) error {
entry, err := logical.StorageEntryJSON("config/zeroaddress", &zeroAddressRoles{ entry, err := logical.StorageEntryJSON("config/zeroaddress", &zeroAddressRoles{
Roles: roles, Roles: roles,
}) })
@@ -137,31 +137,30 @@ func (b *backend) removeZeroAddressRole(s logical.Storage, roleName string) erro
// Removes a given role from the comma separated string // Removes a given role from the comma separated string
func (r *zeroAddressRoles) Remove(roleName string) error { func (r *zeroAddressRoles) Remove(roleName string) error {
var index int var index int
roles := strings.Split(r.Roles, ",") for i, role := range r.Roles {
for i, role := range roles {
if role == roleName { if role == roleName {
index = i index = i
break break
} }
} }
length := len(roles) length := len(r.Roles)
if index >= length || index < 0 { if index >= length || index < 0 {
return fmt.Errorf("invalid index [%d]", index) return fmt.Errorf("invalid index [%d]", index)
} }
// If slice has zero or one item, remove the item by setting slice to nil. // If slice has zero or one item, remove the item by setting slice to nil.
if length < 2 { if length < 2 {
r.Roles = "" r.Roles = nil
return nil return nil
} }
// Last item to be deleted // Last item to be deleted
if length-1 == index { if length-1 == index {
r.Roles = strings.Join(roles[:length-1], ",") r.Roles = r.Roles[:length-1]
return nil return nil
} }
// Delete the item by appending all items except the one at index // Delete the item by appending all items except the one at index
r.Roles = strings.Join(append(roles[:index], roles[index+1:]...), ",") r.Roles = append(r.Roles[:index], r.Roles[index+1:]...)
return nil return nil
} }

View File

@@ -97,7 +97,7 @@ func (b *backend) pathCredsCreateWrite(
if err != nil { if err != nil {
return nil, fmt.Errorf("error retrieving zero-address roles: %s", err) return nil, fmt.Errorf("error retrieving zero-address roles: %s", err)
} }
var zeroAddressRoles string var zeroAddressRoles []string
if zeroAddressEntry != nil { if zeroAddressEntry != nil {
zeroAddressRoles = zeroAddressEntry.Roles zeroAddressRoles = zeroAddressEntry.Roles
} }
@@ -256,10 +256,9 @@ func (b *backend) GenerateOTPCredential(req *logical.Request, username, ip strin
// excluded CIDR blocks and if IP is found there as well, an error is returned. // excluded CIDR blocks and if IP is found there as well, an error is returned.
// IP is valid only if it is encompassed by allowed CIDR blocks and not by // IP is valid only if it is encompassed by allowed CIDR blocks and not by
// excluded CIDR blocks. // excluded CIDR blocks.
func validateIP(ip, roleName, cidrList, excludeCidrList string, zeroAddressRoles string) error { func validateIP(ip, roleName, cidrList, excludeCidrList string, zeroAddressRoles []string) error {
// Search IP in the zero-address list // Search IP in the zero-address list
roles := strings.Split(zeroAddressRoles, ",") for _, role := range zeroAddressRoles {
for _, role := range roles {
if roleName == role { if roleName == role {
return nil return nil
} }

View File

@@ -3,7 +3,6 @@ package ssh
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@@ -57,7 +56,7 @@ func (b *backend) pathLookupWrite(req *logical.Request, d *framework.FieldData)
return nil, err return nil, err
} }
if zeroAddressEntry != nil { if zeroAddressEntry != nil {
matchingRoles = append(matchingRoles, strings.Split(zeroAddressEntry.Roles, ",")...) matchingRoles = append(matchingRoles, zeroAddressEntry.Roles...)
} }
// This list may potentially reveal more information than it is supposed to. // This list may potentially reveal more information than it is supposed to.