Merge branch 'master' into acl-parameters-permission

Brian Kassouf
2017-02-21 14:46:06 -08:00
287 changed files with 6237 additions and 2356 deletions

View File

@@ -7,7 +7,7 @@ services:
- docker
go:
- 1.8rc2
- 1.8
matrix:
allow_failures:

View File

@@ -1,5 +1,16 @@
## Next (Unreleased)
DEPRECATIONS/CHANGES:
* List Operations Always Use Trailing Slash: Any list operation, whether via
the `GET` or `LIST` HTTP verb, will now internally canonicalize the path to
have a trailing slash. This makes policy writing more predictable, as it
means policies will no longer work or fail based on which client or HTTP
verb is used. However, it also means that policies allowing the `list`
capability must be carefully checked to ensure that they contain a trailing
slash; some policies may need to be split into multiple stanzas to
accommodate. A minimal policy sketch follows this entry.
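
For illustration only (not part of the diff): a minimal Go sketch of the new requirement, assuming a reachable Vault picked up from `VAULT_ADDR`/`VAULT_TOKEN`; the policy name `list-foo` and path `secret/foo/` are made up for the example.

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

// Illustrative policy: note the trailing slash, now required for `list`
// to match after server-side canonicalization. The path is an assumption.
const listPolicy = `
path "secret/foo/" {
  capabilities = ["list"]
}
`

func main() {
	// Assumes VAULT_ADDR and VAULT_TOKEN are set in the environment.
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	if err := client.Sys().PutPolicy("list-foo", listPolicy); err != nil {
		log.Fatal(err)
	}

	// Whether this goes out as GET ...?list=true or the LIST verb, the server
	// canonicalizes the path to "secret/foo/", so the policy above matches.
	secret, err := client.Logical().List("secret/foo")
	if err != nil {
		log.Fatal(err)
	}
	if secret != nil {
		fmt.Println(secret.Data["keys"])
	}
}
```

Policies that previously matched `secret/foo` for list operations may need this trailing-slash form (or an additional stanza) after upgrading.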
IMPROVEMENTS:
* auth/ldap: Use the value of the `LOGNAME` or `USER` env vars for the
@@ -7,14 +18,20 @@ IMPROVEMENTS:
[GH-2154]
* audit: Support adding a configurable prefix (such as `@cee`) before each
line [GH-2359]
* core: Canonicalize list operations to use a trailing slash [GH-2390]
* secret/pki: O (Organization) values can now be set to role-defined values
for issued/signed certificates [GH-2369]
BUG FIXES:
* audit: When auditing headers use case-insensitive comparisons [GH-2362]
* auth/aws-ec2: Return role period in seconds and not nanoseconds [GH-2374]
* auth/okta: Fix panic if user had no local groups and/or policies set
[GH-2367]
* command/server: Fix parsing of redirect address when port is not mentioned
[GH-2354]
* physical/postgresql: Fix listing returning incorrect results if there were
multiple levels of children [GH-2393]
## 0.6.5 (February 7th, 2017)

View File

@@ -24,6 +24,11 @@ dev-dynamic: generate
test: generate
CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=10m -parallel=4
testcompile: generate
@for pkg in $(TEST) ; do \
go test -v -c -tags='$(BUILD_TAGS)' $$pkg -parallel=4 ; \
done
# testacc runs acceptance tests
testacc: generate
@if [ "$(TEST)" = "./..." ]; then \

View File

@@ -56,9 +56,9 @@ All documentation is available on the [Vault website](https://www.vaultproject.i
Developing Vault
--------------------
If you wish to work on Vault itself or any of its built-in systems,
you'll first need [Go](https://www.golang.org) installed on your
machine (version 1.8+ is *required*).
If you wish to work on Vault itself or any of its built-in systems, you'll
first need [Go](https://www.golang.org) installed on your machine (version 1.8+
is *required*).
For local dev first make sure Go is properly installed, including setting up a
[GOPATH](https://golang.org/doc/code.html#GOPATH). Next, clone this repository

View File

@@ -3,6 +3,7 @@ package api
import (
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
)
@@ -71,13 +72,18 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) {
return mounts, nil
}
// DEPRECATED: Use EnableAuditWithOptions instead
func (c *Sys) EnableAudit(
path string, auditType string, desc string, opts map[string]string) error {
body := map[string]interface{}{
"type": auditType,
"description": desc,
"options": opts,
}
return c.EnableAuditWithOptions(path, &EnableAuditOptions{
Type: auditType,
Description: desc,
Options: opts,
})
}
func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) error {
body := structs.Map(options)
r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/sys/audit/%s", path))
if err := r.SetJSONBody(body); err != nil {
@@ -106,9 +112,17 @@ func (c *Sys) DisableAudit(path string) error {
// individually documented because they map almost directly to the raw HTTP API
// documentation. Please refer to that documentation for more details.
type EnableAuditOptions struct {
Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"`
Options map[string]string `json:"options" structs:"options"`
Local bool `json:"local" structs:"local"`
}
type Audit struct {
Path string
Type string
Description string
Options map[string]string
Local bool
}
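
The hunk above deprecates `EnableAudit` in favor of `EnableAuditWithOptions`, whose options struct carries the new `Local` flag. As a rough usage sketch (not from this commit; it assumes a reachable Vault via `VAULT_ADDR`/`VAULT_TOKEN`, and the file path and description are made up):

```go
package main

import (
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	// Assumes VAULT_ADDR and VAULT_TOKEN are set in the environment.
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Enable a file audit backend at path "file/" and mark it local so it is
	// not replicated to (or removed from) secondaries.
	err = client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{
		Type:        "file",
		Description: "local file audit backend",
		Options: map[string]string{
			"file_path": "/var/log/vault_audit.log",
		},
		Local: true,
	})
	if err != nil {
		log.Fatal(err)
	}
}
```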

View File

@@ -3,6 +3,7 @@ package api
import (
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
)
@@ -42,11 +43,16 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
return mounts, nil
}
// DEPRECATED: Use EnableAuthWithOptions instead
func (c *Sys) EnableAuth(path, authType, desc string) error {
body := map[string]string{
"type": authType,
"description": desc,
}
return c.EnableAuthWithOptions(path, &EnableAuthOptions{
Type: authType,
Description: desc,
})
}
func (c *Sys) EnableAuthWithOptions(path string, options *EnableAuthOptions) error {
body := structs.Map(options)
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/auth/%s", path))
if err := r.SetJSONBody(body); err != nil {
@@ -75,10 +81,17 @@ func (c *Sys) DisableAuth(path string) error {
// individually documented because they map almost directly to the raw HTTP API
// documentation. Please refer to that documentation for more details.
type EnableAuthOptions struct {
Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"`
Local bool `json:"local" structs:"local"`
}
type AuthMount struct {
Type string `json:"type" structs:"type" mapstructure:"type"`
Description string `json:"description" structs:"description" mapstructure:"description"`
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
Local bool `json:"local" structs:"local" mapstructure:"local"`
}
type AuthConfigOutput struct {

View File

@@ -123,6 +123,7 @@ type MountInput struct {
Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"`
Config MountConfigInput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
}
type MountConfigInput struct {
@@ -134,6 +135,7 @@ type MountOutput struct {
Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"`
Config MountConfigOutput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
}
type MountConfigOutput struct {
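
The `Local` field added to `MountInput` above has a mount-side usage counterpart; a minimal sketch (not from this commit; mount path, backend type, and description are assumptions) of mounting a secret backend as a local, non-replicated mount via the Go client:

```go
package main

import (
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	// Assumes VAULT_ADDR and VAULT_TOKEN are set in the environment.
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Mount a generic backend at "scratch/" as a local (non-replicated) mount.
	err = client.Sys().Mount("scratch", &api.MountInput{
		Type:        "generic",
		Description: "local scratch space",
		Local:       true,
	})
	if err != nil {
		log.Fatal(err)
	}
}
```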

View File

@@ -27,7 +27,11 @@ func (f *AuditFormatter) FormatRequest(
config FormatterConfig,
auth *logical.Auth,
req *logical.Request,
err error) error {
inErr error) error {
if req == nil {
return fmt.Errorf("request to request-audit a nil request")
}
if w == nil {
return fmt.Errorf("writer for audit request is nil")
@@ -49,22 +53,26 @@ func (f *AuditFormatter) FormatRequest(
}()
}
// Copy the structures
cp, err := copystructure.Copy(auth)
if err != nil {
return err
// Copy the auth structure
if auth != nil {
cp, err := copystructure.Copy(auth)
if err != nil {
return err
}
auth = cp.(*logical.Auth)
}
auth = cp.(*logical.Auth)
cp, err = copystructure.Copy(req)
cp, err := copystructure.Copy(req)
if err != nil {
return err
}
req = cp.(*logical.Request)
// Hash any sensitive information
if err := Hash(config.Salt, auth); err != nil {
return err
if auth != nil {
if err := Hash(config.Salt, auth); err != nil {
return err
}
}
// Cache and restore accessor in the request
@@ -85,8 +93,8 @@ func (f *AuditFormatter) FormatRequest(
auth = new(logical.Auth)
}
var errString string
if err != nil {
errString = err.Error()
if inErr != nil {
errString = inErr.Error()
}
reqEntry := &AuditRequestEntry{
@@ -107,6 +115,7 @@ func (f *AuditFormatter) FormatRequest(
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
ReplicationCluster: req.ReplicationCluster,
Headers: req.Headers,
},
}
@@ -128,7 +137,11 @@ func (f *AuditFormatter) FormatResponse(
auth *logical.Auth,
req *logical.Request,
resp *logical.Response,
err error) error {
inErr error) error {
if req == nil {
return fmt.Errorf("request to response-audit a nil request")
}
if w == nil {
return fmt.Errorf("writer for audit request is nil")
@@ -150,37 +163,43 @@ func (f *AuditFormatter) FormatResponse(
}()
}
// Copy the structure
cp, err := copystructure.Copy(auth)
if err != nil {
return err
// Copy the auth structure
if auth != nil {
cp, err := copystructure.Copy(auth)
if err != nil {
return err
}
auth = cp.(*logical.Auth)
}
auth = cp.(*logical.Auth)
cp, err = copystructure.Copy(req)
cp, err := copystructure.Copy(req)
if err != nil {
return err
}
req = cp.(*logical.Request)
cp, err = copystructure.Copy(resp)
if err != nil {
return err
if resp != nil {
cp, err := copystructure.Copy(resp)
if err != nil {
return err
}
resp = cp.(*logical.Response)
}
resp = cp.(*logical.Response)
// Hash any sensitive information
// Cache and restore accessor in the auth
var accessor, wrappedAccessor string
if !config.HMACAccessor && auth != nil && auth.Accessor != "" {
accessor = auth.Accessor
}
if err := Hash(config.Salt, auth); err != nil {
return err
}
if accessor != "" {
auth.Accessor = accessor
if auth != nil {
var accessor string
if !config.HMACAccessor && auth.Accessor != "" {
accessor = auth.Accessor
}
if err := Hash(config.Salt, auth); err != nil {
return err
}
if accessor != "" {
auth.Accessor = accessor
}
}
// Cache and restore accessor in the request
@@ -196,21 +215,23 @@ func (f *AuditFormatter) FormatResponse(
}
// Cache and restore accessor in the response
accessor = ""
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
accessor = resp.Auth.Accessor
}
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
wrappedAccessor = resp.WrapInfo.WrappedAccessor
}
if err := Hash(config.Salt, resp); err != nil {
return err
}
if accessor != "" {
resp.Auth.Accessor = accessor
}
if wrappedAccessor != "" {
resp.WrapInfo.WrappedAccessor = wrappedAccessor
if resp != nil {
var accessor, wrappedAccessor string
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
accessor = resp.Auth.Accessor
}
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
wrappedAccessor = resp.WrapInfo.WrappedAccessor
}
if err := Hash(config.Salt, resp); err != nil {
return err
}
if accessor != "" {
resp.Auth.Accessor = accessor
}
if wrappedAccessor != "" {
resp.WrapInfo.WrappedAccessor = wrappedAccessor
}
}
}
@@ -222,8 +243,8 @@ func (f *AuditFormatter) FormatResponse(
resp = new(logical.Response)
}
var errString string
if err != nil {
errString = err.Error()
if inErr != nil {
errString = inErr.Error()
}
var respAuth *AuditAuth
@@ -276,6 +297,7 @@ func (f *AuditFormatter) FormatResponse(
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
ReplicationCluster: req.ReplicationCluster,
Headers: req.Headers,
},
@@ -312,14 +334,15 @@ type AuditRequestEntry struct {
type AuditResponseEntry struct {
Time string `json:"time,omitempty"`
Type string `json:"type"`
Error string `json:"error"`
Auth AuditAuth `json:"auth"`
Request AuditRequest `json:"request"`
Response AuditResponse `json:"response"`
Error string `json:"error"`
}
type AuditRequest struct {
ID string `json:"id"`
ReplicationCluster string `json:"replication_cluster,omitempty"`
Operation logical.Operation `json:"operation"`
ClientToken string `json:"client_token"`
ClientTokenAccessor string `json:"client_token_accessor"`

55
audit/format_test.go Normal file
View File

@@ -0,0 +1,55 @@
package audit
import (
"io"
"io/ioutil"
"testing"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
)
type noopFormatWriter struct {
}
func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error {
return nil
}
func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) error {
return nil
}
func TestFormatRequestErrors(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil)
config := FormatterConfig{
Salt: salter,
}
formatter := AuditFormatter{
AuditFormatWriter: &noopFormatWriter{},
}
if err := formatter.FormatRequest(ioutil.Discard, config, nil, nil, nil); err == nil {
t.Fatal("expected error due to nil request")
}
if err := formatter.FormatRequest(nil, config, nil, &logical.Request{}, nil); err == nil {
t.Fatal("expected error due to nil writer")
}
}
func TestFormatResponseErrors(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil)
config := FormatterConfig{
Salt: salter,
}
formatter := AuditFormatter{
AuditFormatWriter: &noopFormatWriter{},
}
if err := formatter.FormatResponse(ioutil.Discard, config, nil, nil, nil, nil); err == nil {
t.Fatal("expected error due to nil request")
}
if err := formatter.FormatResponse(nil, config, nil, &logical.Request{}, nil, nil); err == nil {
t.Fatal("expected error due to nil writer")
}
}

View File

@@ -157,10 +157,15 @@ func (b *Backend) open() error {
return err
}
// Change the file mode in case the log file already existed
err = os.Chmod(b.path, b.mode)
if err != nil {
return err
// Change the file mode in case the log file already existed. We special
// case /dev/null since we can't chmod it
switch b.path {
case "/dev/null":
default:
err = os.Chmod(b.path, b.mode)
if err != nil {
return err
}
}
return nil

View File

@@ -17,20 +17,10 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
}
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
// Initialize the salt
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA1Hash,
})
if err != nil {
return nil, err
}
var b backend
b.Salt = salt
b.MapAppId = &framework.PolicyMap{
PathMap: framework.PathMap{
Name: "app-id",
Salt: salt,
Schema: map[string]*framework.FieldSchema{
"display_name": &framework.FieldSchema{
Type: framework.TypeString,
@@ -48,7 +38,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
b.MapUserId = &framework.PathMap{
Name: "user-id",
Salt: salt,
Schema: map[string]*framework.FieldSchema{
"cidr_block": &framework.FieldSchema{
Type: framework.TypeString,
@@ -81,17 +70,11 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
),
AuthRenew: b.pathLoginRenew,
Init: b.initialize,
}
// Since the salt is new in 0.2, we need to handle this by migrating
// any existing keys to use the salt. We can deprecate this eventually,
// but for now we want a smooth upgrade experience by automatically
// upgrading to use salting.
if salt.DidGenerate() {
if err := b.upgradeToSalted(conf.StorageView); err != nil {
return nil, err
}
}
b.view = conf.StorageView
return b.Backend, nil
}
@@ -100,10 +83,36 @@ type backend struct {
*framework.Backend
Salt *salt.Salt
view logical.Storage
MapAppId *framework.PolicyMap
MapUserId *framework.PathMap
}
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA1Hash,
})
if err != nil {
return err
}
b.Salt = salt
b.MapAppId.Salt = salt
b.MapUserId.Salt = salt
// Since the salt is new in 0.2, we need to handle this by migrating
// any existing keys to use the salt. We can deprecate this eventually,
// but for now we want a smooth upgrade experience by automatically
// upgrading to use salting.
if salt.DidGenerate() {
if err := b.upgradeToSalted(b.view); err != nil {
return err
}
}
return nil
}
// upgradeToSalted is used to upgrade the non-salted keys prior to
// Vault 0.2 to be salted. This is done on mount time and is only
// done once. It can be deprecated eventually, but should be around
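
Aside (not part of the diff): the change above is the first instance of a pattern repeated in the approle, aws-ec2, and ssh backends below, where salt creation moves out of the factory into an `Init`/`Initialize` callback so storage is only read once the mount is actually set up. A stripped-down sketch of that shape, using hypothetical names and the `Init` field as it appears in this diff:

```go
package example

import (
	"github.com/hashicorp/vault/helper/salt"
	"github.com/hashicorp/vault/logical"
	"github.com/hashicorp/vault/logical/framework"
)

// exampleBackend is a hypothetical backend showing the deferred-salt pattern:
// the factory only records the storage view, and the salt is created later in
// the framework's Init callback.
type exampleBackend struct {
	*framework.Backend
	salt *salt.Salt
	view logical.Storage
}

func newExampleBackend(conf *logical.BackendConfig) *exampleBackend {
	b := &exampleBackend{
		view: conf.StorageView,
	}
	b.Backend = &framework.Backend{
		// Salt is created lazily in initialize, not here in the factory.
		Init: b.initialize,
	}
	return b
}

func (b *exampleBackend) initialize() error {
	s, err := salt.NewSalt(b.view, &salt.Config{
		HashFunc: salt.SHA256Hash,
	})
	if err != nil {
		return err
	}
	b.salt = s
	return nil
}
```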

View File

@@ -72,6 +72,10 @@ func TestBackend_upgradeToSalted(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
err = backend.Initialize()
if err != nil {
t.Fatalf("err: %v", err)
}
// Check the keys have been upgraded
out, err := inm.Get("struct/map/app-id/foo")

View File

@@ -17,6 +17,9 @@ type backend struct {
// by this backend.
salt *salt.Salt
// The view to use when creating the salt
view logical.Storage
// Guard to clean-up the expired SecretID entries
tidySecretIDCASGuard uint32
@@ -57,18 +60,9 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
}
func Backend(conf *logical.BackendConfig) (*backend, error) {
// Initialize the salt
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
// Create a backend object
b := &backend{
// Set the salt object for the backend
salt: salt,
view: conf.StorageView,
// Create the map of locks to modify the registered roles
roleLocksMap: make(map[string]*sync.RWMutex, 257),
@@ -83,6 +77,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
secretIDAccessorLocksMap: make(map[string]*sync.RWMutex, 257),
}
var err error
// Create 256 locks each for managing RoleID and SecretIDs. This will avoid
// a superfluous number of locks directly proportional to the number of RoleID
// and SecretIDs. These locks can be accessed by indexing based on the first two
@@ -129,10 +125,22 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathTidySecretID(b),
},
),
Init: b.initialize,
}
return b, nil
}
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.salt = salt
return nil
}
// periodicFunc of the backend will be invoked once a minute by the RollbackManager.
// AppRole backend utilizes this function to delete expired SecretID entries.
// This could mean that the SecretID may live in the backend up to 1 min after its

View File

@@ -21,5 +21,9 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
if err != nil {
t.Fatal(err)
}
err = b.Initialize()
if err != nil {
t.Fatal(err)
}
return b, config.StorageView
}

View File

@@ -143,7 +143,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
return nil, "", metadata, fmt.Errorf("failed to verify the CIDR restrictions set on the role: %v", err)
}
if !belongs {
return nil, "", metadata, fmt.Errorf("source address unauthorized through CIDR restrictions on the role")
return nil, "", metadata, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the role", req.Connection.RemoteAddr)
}
}
@@ -199,7 +199,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
}
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err)
return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
}
}
@@ -261,7 +261,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
}
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err)
return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
}
}

View File

@@ -23,6 +23,9 @@ type backend struct {
*framework.Backend
Salt *salt.Salt
// Used during initialization to set the salt
view logical.Storage
// Lock to make changes to any of the backend's configuration endpoints.
configMutex sync.RWMutex
@@ -59,18 +62,11 @@ type backend struct {
}
func Backend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
b := &backend{
// Setting the periodic func to be run once in an hour.
// If there is a real need, this can be made configurable.
tidyCooldownPeriod: time.Hour,
Salt: salt,
view: conf.StorageView,
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
IAMClientsMap: make(map[string]map[string]*iam.IAM),
}
@@ -83,6 +79,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
Unauthenticated: []string{
"login",
},
LocalStorage: []string{
"whitelist/identity/",
},
},
Paths: []*framework.Path{
pathLogin(b),
@@ -104,11 +103,26 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathIdentityWhitelist(b),
pathTidyIdentityWhitelist(b),
},
Invalidate: b.invalidate,
Init: b.initialize,
}
return b, nil
}
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.Salt = salt
return nil
}
// periodicFunc performs the tasks that the backend wishes to do periodically.
// Currently this will be triggered once in a minute by the RollbackManager.
//
@@ -169,6 +183,16 @@ func (b *backend) periodicFunc(req *logical.Request) error {
return nil
}
func (b *backend) invalidate(key string) {
switch key {
case "config/client":
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
}
}
const backendHelp = `
The aws-ec2 auth backend takes in a PKCS#7 signature of an AWS EC2 instance and a
client-created nonce to authenticate the EC2 instance with Vault.

View File

@@ -1,6 +1,7 @@
package cert
import (
"strings"
"sync"
"github.com/hashicorp/vault/logical"
@@ -13,7 +14,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
if err != nil {
return b, err
}
return b, b.populateCRLs(conf.StorageView)
return b, nil
}
func Backend() *backend {
@@ -36,9 +37,10 @@ func Backend() *backend {
}),
AuthRenew: b.pathLoginRenew,
Invalidate: b.invalidate,
}
b.crls = map[string]CRLInfo{}
b.crlUpdateMutex = &sync.RWMutex{}
return &b
@@ -52,6 +54,15 @@ type backend struct {
crlUpdateMutex *sync.RWMutex
}
func (b *backend) invalidate(key string) {
switch {
case strings.HasPrefix(key, "crls/"):
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
b.crls = nil
}
}
const backendHelp = `
The "cert" credential provider allows authentication using
TLS client certificates. A client connects to Vault and uses

View File

@@ -45,6 +45,12 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
if b.crls != nil {
return nil
}
b.crls = map[string]CRLInfo{}
keys, err := storage.List("crls/")
if err != nil {
return fmt.Errorf("error listing CRLs: %v", err)
@@ -56,6 +62,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
for _, key := range keys {
entry, err := storage.Get("crls/" + key)
if err != nil {
b.crls = nil
return fmt.Errorf("error loading CRL %s: %v", key, err)
}
if entry == nil {
@@ -64,6 +71,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
var crlInfo CRLInfo
err = entry.DecodeJSON(&crlInfo)
if err != nil {
b.crls = nil
return fmt.Errorf("error decoding CRL %s: %v", key, err)
}
b.crls[key] = crlInfo
@@ -121,6 +129,10 @@ func (b *backend) pathCRLDelete(
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
return nil, err
}
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
@@ -131,8 +143,7 @@ func (b *backend) pathCRLDelete(
)), nil
}
err := req.Storage.Delete("crls/" + name)
if err != nil {
if err := req.Storage.Delete("crls/" + name); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"error deleting crl %s: %v", name, err),
), nil
@@ -150,6 +161,10 @@ func (b *backend) pathCRLRead(
return logical.ErrorResponse(`"name" parameter must be set`), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
return nil, err
}
b.crlUpdateMutex.RLock()
defer b.crlUpdateMutex.RUnlock()
@@ -185,6 +200,10 @@ func (b *backend) pathCRLWrite(
return logical.ErrorResponse("parsed CRL is nil"), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
return nil, err
}
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()

View File

@@ -17,6 +17,12 @@ func Backend() *backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
LocalStorage: []string{
framework.WALPrefix,
},
},
Paths: []*framework.Path{
pathConfigRoot(),
pathConfigLease(&b),

View File

@@ -31,6 +31,8 @@ func Backend() *backend {
secretCreds(&b),
},
Invalidate: b.invalidate,
Clean: func() {
b.ResetDB(nil)
},
@@ -107,6 +109,13 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
b.session = newSession
}
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB(nil)
}
}
const backendHelp = `
The Cassandra backend dynamically generates database users.

View File

@@ -421,7 +421,7 @@ seed_provider:
parameters:
# seeds is actually a comma-delimited list of addresses.
# Ex: "<ip1>,<ip2>,<ip3>"
- seeds: "172.17.0.2"
- seeds: "172.17.0.3"
# For workloads with more data than can fit in memory, Cassandra's
# bottleneck will be reads that need to fetch data from
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
#
# Setting listen_address to 0.0.0.0 is always wrong.
#
listen_address: 172.17.0.2
listen_address: 172.17.0.3
# Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported.
@@ -586,7 +586,7 @@ listen_address: 172.17.0.2
# Address to broadcast to other Cassandra nodes
# Leaving this blank will set it to the same value as listen_address
broadcast_address: 172.17.0.2
broadcast_address: 172.17.0.3
# When using multiple physical network interfaces, set this
# to true to listen on broadcast_address in addition to
@@ -668,7 +668,7 @@ rpc_port: 9160
# be set to 0.0.0.0. If left blank, this will be set to the value of
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
# be set.
broadcast_rpc_address: 172.17.0.2
broadcast_rpc_address: 172.17.0.3
# enable or disable keepalive on rpc/native connections
rpc_keepalive: true

View File

@@ -33,6 +33,8 @@ func Backend() *framework.Backend {
},
Clean: b.ResetSession,
Invalidate: b.invalidate,
}
return b.Backend
@@ -97,6 +99,13 @@ func (b *backend) ResetSession() {
b.session = nil
}
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetSession()
}
}
// LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")

View File

@@ -32,6 +32,8 @@ func Backend() *backend {
secretCreds(&b),
},
Invalidate: b.invalidate,
Clean: b.ResetDB,
}
@@ -112,6 +114,13 @@ func (b *backend) ResetDB() {
b.db = nil
}
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB()
}
}
// LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")

View File

@@ -32,6 +32,8 @@ func Backend() *backend {
secretCreds(&b),
},
Invalidate: b.invalidate,
Clean: b.ResetDB,
}
@@ -105,6 +107,13 @@ func (b *backend) ResetDB() {
b.db = nil
}
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB()
}
}
// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")

View File

@@ -29,6 +29,12 @@ func Backend() *backend {
"crl/pem",
"crl",
},
LocalStorage: []string{
"revoked/",
"crl",
"certs/",
},
},
Paths: []*framework.Path{

View File

@@ -1478,6 +1478,27 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
}
}
getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("Error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.ParseDedupAndSortStrings(role.Organization, ",")
if !reflect.DeepEqual(cert.Subject.Organization, expected) {
return fmt.Errorf("Error: returned certificate has Organization of %s but %s was specified in the role.", cert.Subject.Organization, expected)
}
return nil
}
}
// Returns a TestCheckFunc that performs various validity checks on the
// returned certificate information, mostly within checkCertsAndPrivateKey
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
@@ -1755,6 +1776,14 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals.OU = "foo,bar"
addTests(getOuCheck(roleVals))
}
// Organization tests
{
roleVals.Organization = "system:masters"
addTests(getOrganizationCheck(roleVals))
roleVals.Organization = "foo,bar"
addTests(getOrganizationCheck(roleVals))
}
// IP SAN tests
{
issueVals.IPSANs = "127.0.0.1,::1"

View File

@@ -35,6 +35,7 @@ const (
type creationBundle struct {
CommonName string
OU []string
Organization []string
DNSNames []string
EmailAddresses []string
IPAddresses []net.IP
@@ -581,6 +582,14 @@ func generateCreationBundle(b *backend,
}
}
// Set O (organization) values if specified in the role
organization := []string{}
{
if role.Organization != "" {
organization = strutil.ParseDedupAndSortStrings(role.Organization, ",")
}
}
// Read in alternate names -- DNS and email addresses
dnsNames := []string{}
emailAddresses := []string{}
@@ -728,6 +737,7 @@ func generateCreationBundle(b *backend,
creationBundle := &creationBundle{
CommonName: cn,
OU: ou,
Organization: organization,
DNSNames: dnsNames,
EmailAddresses: emailAddresses,
IPAddresses: ipAddresses,
@@ -820,6 +830,7 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle
subject := pkix.Name{
CommonName: creationInfo.CommonName,
OrganizationalUnit: creationInfo.OU,
Organization: creationInfo.Organization,
}
certTemplate := &x509.Certificate{
@@ -983,6 +994,7 @@ func signCertificate(creationInfo *creationBundle,
subject := pkix.Name{
CommonName: creationInfo.CommonName,
OrganizationalUnit: creationInfo.OU,
Organization: creationInfo.Organization,
}
certTemplate := &x509.Certificate{

View File

@@ -172,6 +172,13 @@ Names. Defaults to true.`,
Type: framework.TypeString,
Default: "",
Description: `If set, the OU (OrganizationalUnit) will be set to
this value in certificates issued by this role.`,
},
"organization": &framework.FieldSchema{
Type: framework.TypeString,
Default: "",
Description: `If set, the O (Organization) will be set to
this value in certificates issued by this role.`,
},
},
@@ -336,6 +343,7 @@ func (b *backend) pathRoleCreate(
UseCSRCommonName: data.Get("use_csr_common_name").(bool),
KeyUsage: data.Get("key_usage").(string),
OU: data.Get("ou").(string),
Organization: data.Get("organization").(string),
}
if entry.KeyType == "rsa" && entry.KeyBits < 2048 {
@@ -451,6 +459,7 @@ type roleEntry struct {
MaxPathLength *int `json:",omitempty" structs:",omitempty"`
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
OU string `json:"ou" structs:"ou" mapstructure:"ou"`
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
}
const pathListRolesHelpSyn = `List the existing roles in this backend`
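
As a usage sketch for the new `organization` role field [GH-2369] (not from this commit; the mount path `pki`, role name, domains, and organization value are assumptions, and a configured `pki` mount is presumed):

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	// Assumes VAULT_ADDR and VAULT_TOKEN are set in the environment.
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Create a role whose issued certificates carry O=example-org.
	_, err = client.Logical().Write("pki/roles/example-dot-com", map[string]interface{}{
		"allowed_domains":  "example.com",
		"allow_subdomains": true,
		"organization":     "example-org",
	})
	if err != nil {
		log.Fatal(err)
	}

	// Issue a certificate; its Subject should include the role's Organization.
	secret, err := client.Logical().Write("pki/issue/example-dot-com", map[string]interface{}{
		"common_name": "www.example.com",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(secret.Data["serial_number"])
}
```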

View File

@@ -34,6 +34,8 @@ func Backend(conf *logical.BackendConfig) *backend {
},
Clean: b.ResetDB,
Invalidate: b.invalidate,
}
b.logger = conf.Logger
@@ -126,6 +128,13 @@ func (b *backend) ResetDB() {
b.db = nil
}
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB()
}
}
// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")

View File

@@ -35,6 +35,8 @@ func Backend() *backend {
},
Clean: b.resetClient,
Invalidate: b.invalidate,
}
return &b
@@ -99,6 +101,13 @@ func (b *backend) resetClient() {
b.client = nil
}
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.resetClient()
}
}
// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")

View File

@@ -10,6 +10,7 @@ import (
type backend struct {
*framework.Backend
view logical.Storage
salt *salt.Salt
}
@@ -22,15 +23,8 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
}
func Backend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
var b backend
b.salt = salt
b.view = conf.StorageView
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
@@ -38,6 +32,10 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
Unauthenticated: []string{
"verify",
},
LocalStorage: []string{
"otp/",
},
},
Paths: []*framework.Path{
@@ -54,10 +52,23 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
secretDynamicKey(&b),
secretOTP(&b),
},
Init: b.Initialize,
}
return &b, nil
}
func (b *backend) Initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.salt = salt
return nil
}
const backendHelp = `
The SSH backend generates credentials allowing clients to establish SSH
connections to remote hosts.

View File

@@ -73,6 +73,10 @@ func TestBackend_allowed_users(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Initialize()
if err != nil {
t.Fatal(err)
}
roleData := map[string]interface{}{
"key_type": "otp",

View File

@@ -1,6 +1,8 @@
package transit
import (
"strings"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@@ -39,6 +41,8 @@ func Backend(conf *logical.BackendConfig) *backend {
},
Secrets: []*framework.Secret{},
Invalidate: b.invalidate,
}
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
@@ -50,3 +54,14 @@ type backend struct {
*framework.Backend
lm *keysutil.LockManager
}
func (b *backend) invalidate(key string) {
if b.Logger().IsTrace() {
b.Logger().Trace("transit: invalidating key", "key", key)
}
switch {
case strings.HasPrefix(key, "policy/"):
name := strings.TrimPrefix(key, "policy/")
b.lm.InvalidatePolicy(name)
}
}

View File

@@ -3,6 +3,7 @@ package command
import (
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
@@ -44,3 +45,42 @@ func TestAuditDisable(t *testing.T) {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
}
func TestAuditDisableWithOptions(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &AuditDisableCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"noop",
}
// Run once to get the client
c.Run(args)
// Get the client
client, err := c.Client()
if err != nil {
t.Fatalf("err: %#v", err)
}
if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{
Type: "noop",
Description: "noop",
}); err != nil {
t.Fatalf("err: %#v", err)
}
// Run again
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
}

View File

@@ -6,6 +6,7 @@ import (
"os"
"strings"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/kv-builder"
"github.com/hashicorp/vault/meta"
"github.com/mitchellh/mapstructure"
@@ -21,9 +22,11 @@ type AuditEnableCommand struct {
func (c *AuditEnableCommand) Run(args []string) int {
var desc, path string
var local bool
flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault)
flags.StringVar(&desc, "description", "", "")
flags.StringVar(&path, "path", "", "")
flags.BoolVar(&local, "local", false, "")
flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil {
return 1
@@ -68,7 +71,12 @@ func (c *AuditEnableCommand) Run(args []string) int {
return 1
}
err = client.Sys().EnableAudit(path, auditType, desc, opts)
err = client.Sys().EnableAuditWithOptions(path, &api.EnableAuditOptions{
Type: auditType,
Description: desc,
Options: opts,
Local: local,
})
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error enabling audit backend: %s", err))
@@ -113,6 +121,9 @@ Audit Enable Options:
is purely for referencing this audit backend. By
default this will be the backend type.
-local Mark the mount as a local mount. Local mounts
are not replicated nor (if a secondary)
removed by replication.
`
return strings.TrimSpace(helpText)
}

View File

@@ -48,16 +48,19 @@ func (c *AuditListCommand) Run(args []string) int {
}
sort.Strings(paths)
columns := []string{"Path | Type | Description | Options"}
columns := []string{"Path | Type | Description | Replication Behavior | Options"}
for _, path := range paths {
audit := audits[path]
opts := make([]string, 0, len(audit.Options))
for k, v := range audit.Options {
opts = append(opts, k+"="+v)
}
replicatedBehavior := "replicated"
if audit.Local {
replicatedBehavior = "local"
}
columns = append(columns, fmt.Sprintf(
"%s | %s | %s | %s", audit.Path, audit.Type, audit.Description, strings.Join(opts, " ")))
"%s | %s | %s | %s | %s", audit.Path, audit.Type, audit.Description, replicatedBehavior, strings.Join(opts, " ")))
}
c.Ui.Output(columnize.SimpleFormat(columns))

View File

@@ -3,6 +3,7 @@ package command
import (
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
@@ -34,7 +35,11 @@ func TestAuditList(t *testing.T) {
if err != nil {
t.Fatalf("err: %#v", err)
}
if err := client.Sys().EnableAudit("foo", "noop", "", nil); err != nil {
if err := client.Sys().EnableAuditWithOptions("foo", &api.EnableAuditOptions{
Type: "noop",
Description: "noop",
Options: nil,
}); err != nil {
t.Fatalf("err: %#v", err)
}

View File

@@ -281,7 +281,7 @@ func (c *AuthCommand) listMethods() int {
}
sort.Strings(paths)
columns := []string{"Path | Type | Default TTL | Max TTL | Description"}
columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
for _, path := range paths {
auth := auth[path]
defTTL := "system"
@@ -292,8 +292,12 @@ func (c *AuthCommand) listMethods() int {
if auth.Config.MaxLeaseTTL != 0 {
maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL)
}
replicatedBehavior := "replicated"
if auth.Local {
replicatedBehavior = "local"
}
columns = append(columns, fmt.Sprintf(
"%s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, auth.Description))
"%s | %s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, replicatedBehavior, auth.Description))
}
c.Ui.Output(columnize.SimpleFormat(columns))

View File

@@ -3,6 +3,7 @@ package command
import (
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
@@ -52,3 +53,50 @@ func TestAuthDisable(t *testing.T) {
t.Fatal("should not have noop mount")
}
}
func TestAuthDisableWithOptions(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &AuthDisableCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"noop",
}
// Run the command once to set up the client; it will fail
c.Run(args)
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
if err := client.Sys().EnableAuthWithOptions("noop", &api.EnableAuthOptions{
Type: "noop",
Description: "",
}); err != nil {
t.Fatalf("err: %#v", err)
}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
mounts, err := client.Sys().ListAuth()
if err != nil {
t.Fatalf("err: %s", err)
}
if _, ok := mounts["noop"]; ok {
t.Fatal("should not have noop mount")
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/meta"
)
@@ -14,9 +15,11 @@ type AuthEnableCommand struct {
func (c *AuthEnableCommand) Run(args []string) int {
var description, path string
var local bool
flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault)
flags.StringVar(&description, "description", "", "")
flags.StringVar(&path, "path", "", "")
flags.BoolVar(&local, "local", false, "")
flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil {
return 1
@@ -44,7 +47,11 @@ func (c *AuthEnableCommand) Run(args []string) int {
return 2
}
if err := client.Sys().EnableAuth(path, authType, description); err != nil {
if err := client.Sys().EnableAuthWithOptions(path, &api.EnableAuthOptions{
Type: authType,
Description: description,
Local: local,
}); err != nil {
c.Ui.Error(fmt.Sprintf(
"Error: %s", err))
return 2
@@ -82,6 +89,9 @@ Auth Enable Options:
to the type of the mount. This will make the auth
provider available at "/auth/<path>"
-local Mark the mount as a local mount. Local mounts
are not replicated nor (if a secondary)
removed by replication.
`
return strings.TrimSpace(helpText)
}

View File

@@ -15,11 +15,13 @@ type MountCommand struct {
func (c *MountCommand) Run(args []string) int {
var description, path, defaultLeaseTTL, maxLeaseTTL string
var local bool
flags := c.Meta.FlagSet("mount", meta.FlagSetDefault)
flags.StringVar(&description, "description", "", "")
flags.StringVar(&path, "path", "", "")
flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "")
flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "")
flags.BoolVar(&local, "local", false, "")
flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil {
return 1
@@ -54,6 +56,7 @@ func (c *MountCommand) Run(args []string) int {
DefaultLeaseTTL: defaultLeaseTTL,
MaxLeaseTTL: maxLeaseTTL,
},
Local: local,
}
if err := client.Sys().Mount(path, mountInfo); err != nil {
@@ -102,6 +105,10 @@ Mount Options:
the previously set value. Set to '0' to
explicitly set it to use the global default.
-local Mark the mount as a local mount. Local mounts
are not replicated nor (if a secondary)
removed by replication.
`
return strings.TrimSpace(helpText)
}

View File

@@ -42,7 +42,7 @@ func (c *MountsCommand) Run(args []string) int {
}
sort.Strings(paths)
columns := []string{"Path | Type | Default TTL | Max TTL | Description"}
columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
for _, path := range paths {
mount := mounts[path]
defTTL := "system"
@@ -63,8 +63,12 @@ func (c *MountsCommand) Run(args []string) int {
case mount.Config.MaxLeaseTTL != 0:
maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL)
}
replicatedBehavior := "replicated"
if mount.Local {
replicatedBehavior = "local"
}
columns = append(columns, fmt.Sprintf(
"%s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, mount.Description))
"%s | %s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, replicatedBehavior, mount.Description))
}
c.Ui.Output(columnize.SimpleFormat(columns))

View File

@@ -61,7 +61,7 @@ type ServerCommand struct {
}
func (c *ServerCommand) Run(args []string) int {
var dev, verifyOnly, devHA bool
var dev, verifyOnly, devHA, devTransactional bool
var configPath []string
var logLevel, devRootTokenID, devListenAddress string
flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
@@ -70,7 +70,8 @@ func (c *ServerCommand) Run(args []string) int {
flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
flags.StringVar(&logLevel, "log-level", "info", "")
flags.BoolVar(&verifyOnly, "verify-only", false, "")
flags.BoolVar(&devHA, "dev-ha", false, "")
flags.BoolVar(&devHA, "ha", false, "")
flags.BoolVar(&devTransactional, "transactional", false, "")
flags.Usage = func() { c.Ui.Output(c.Help()) }
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
if err := flags.Parse(args); err != nil {
@@ -122,7 +123,7 @@ func (c *ServerCommand) Run(args []string) int {
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
}
if devHA {
if devHA || devTransactional {
dev = true
}
@@ -143,7 +144,7 @@ func (c *ServerCommand) Run(args []string) int {
// Load the configuration
var config *server.Config
if dev {
config = server.DevConfig(devHA)
config = server.DevConfig(devHA, devTransactional)
if devListenAddress != "" {
config.Listeners[0].Config["address"] = devListenAddress
}
@@ -235,6 +236,9 @@ func (c *ServerCommand) Run(args []string) int {
ClusterName: config.ClusterName,
CacheSize: config.CacheSize,
}
if dev {
coreConfig.DevToken = devRootTokenID
}
var disableClustering bool

View File

@@ -38,7 +38,7 @@ type Config struct {
}
// DevConfig is a Config that is used for dev mode of Vault.
func DevConfig(ha bool) *Config {
func DevConfig(ha, transactional bool) *Config {
ret := &Config{
DisableCache: false,
DisableMlock: true,
@@ -63,7 +63,12 @@ func DevConfig(ha bool) *Config {
DefaultLeaseTTL: 32 * 24 * time.Hour,
}
if ha {
switch {
case ha && transactional:
ret.Backend.Type = "inmem_transactional_ha"
case !ha && transactional:
ret.Backend.Type = "inmem_transactional"
case ha && !transactional:
ret.Backend.Type = "inmem_ha"
}

View File

@@ -33,7 +33,7 @@ func TestServer_CommonHA(t *testing.T) {
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
}
if !strings.Contains(ui.OutputWriter.String(), "(HA available)") {
@@ -61,7 +61,7 @@ func TestServer_GoodSeparateHA(t *testing.T) {
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
}
if !strings.Contains(ui.OutputWriter.String(), "HA Backend:") {

View File

@@ -40,7 +40,7 @@ func (c *StatusCommand) Run(args []string) int {
"Key Shares: %d\n"+
"Key Threshold: %d\n"+
"Unseal Progress: %d\n"+
"Unseal Nonce: %v"+
"Unseal Nonce: %v\n"+
"Version: %s",
sealStatus.Sealed,
sealStatus.N,

View File

@@ -14,6 +14,16 @@ func TestCIDRUtil_IPBelongsToCIDR(t *testing.T) {
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
}
ip = "10.197.192.6"
cidr = "10.197.192.0/18"
belongs, err = IPBelongsToCIDR(ip, cidr)
if err != nil {
t.Fatal(err)
}
if !belongs {
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
}
ip = "192.168.25.30"
cidr = "192.168.26.30/24"
belongs, err = IPBelongsToCIDR(ip, cidr)
@@ -44,6 +54,17 @@ func TestCIDRUtil_IPBelongsToCIDRBlocksString(t *testing.T) {
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
}
ip = "10.197.192.6"
cidrList = "1.2.3.0/8,10.197.192.0/18,10.197.193.0/24"
belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",")
if err != nil {
t.Fatal(err)
}
if !belongs {
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
}
ip = "192.168.27.29"
cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"

7
helper/consts/consts.go Normal file
View File

@@ -0,0 +1,7 @@
package consts
const (
// ExpirationRestoreWorkerCount specifies the number of workers to use while
// restoring leases into the expiration manager
ExpirationRestoreWorkerCount = 64
)

13
helper/consts/error.go Normal file
View File

@@ -0,0 +1,13 @@
package consts
import "errors"
var (
// ErrSealed is returned if an operation is performed on a sealed barrier.
// No operation is expected to succeed before unsealing
ErrSealed = errors.New("Vault is sealed")
// ErrStandby is returned if an operation is performed on a standby Vault.
// No operation is expected to succeed until active.
ErrStandby = errors.New("Vault is in standby mode")
)

View File

@@ -0,0 +1,20 @@
package consts
type ReplicationState uint32
const (
ReplicationDisabled ReplicationState = iota
ReplicationPrimary
ReplicationSecondary
)
func (r ReplicationState) String() string {
switch r {
case ReplicationSecondary:
return "secondary"
case ReplicationPrimary:
return "primary"
}
return "disabled"
}

View File

@@ -71,6 +71,15 @@ func (lm *LockManager) CacheActive() bool {
return lm.cache != nil
}
func (lm *LockManager) InvalidatePolicy(name string) {
// If the policy is cached, remove it so the next read reloads it from storage.
if lm.CacheActive() {
lm.cacheMutex.Lock()
defer lm.cacheMutex.Unlock()
delete(lm.cache, name)
}
}
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
lm.locksMutex.RLock()
lock := lm.locks[name]

View File

@@ -9,6 +9,7 @@ import (
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
@@ -206,11 +207,11 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
// case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
resp, err := core.HandleRequest(r)
if errwrap.Contains(err, vault.ErrStandby.Error()) {
if errwrap.Contains(err, consts.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL)
return resp, false
}
if respondErrorCommon(w, resp, err) {
if respondErrorCommon(w, r, resp, err) {
return resp, false
}
@@ -310,20 +311,7 @@ func requestWrapInfo(r *http.Request, req *logical.Request) (*logical.Request, e
}
func respondError(w http.ResponseWriter, status int, err error) {
// Adjust status code when sealed
if errwrap.Contains(err, vault.ErrSealed.Error()) {
status = http.StatusServiceUnavailable
}
// Adjust status code on
if errwrap.Contains(err, "http: request body too large") {
status = http.StatusRequestEntityTooLarge
}
// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(logical.HTTPCodedError); ok {
status = t.Code()
}
logical.AdjustErrorStatusCode(&status, err)
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)
@@ -337,42 +325,13 @@ func respondError(w http.ResponseWriter, status int, err error) {
enc.Encode(resp)
}
func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool {
// If there are no errors return
if err == nil && (resp == nil || !resp.IsError()) {
func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool {
statusCode, newErr := logical.RespondErrorCommon(req, resp, err)
if newErr == nil && statusCode == 0 {
return false
}
// Start out with internal server error since in most of these cases there
// won't be a response so this won't be overridden
statusCode := http.StatusInternalServerError
// If we actually have a response, start out with bad request
if resp != nil {
statusCode = http.StatusBadRequest
}
// Now, check the error itself; if it has a specific logical error, set the
// appropriate code
if err != nil {
switch {
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
statusCode = http.StatusBadRequest
case errwrap.Contains(err, logical.ErrPermissionDenied.Error()):
statusCode = http.StatusForbidden
case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()):
statusCode = http.StatusMethodNotAllowed
case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()):
statusCode = http.StatusNotFound
case errwrap.Contains(err, logical.ErrInvalidRequest.Error()):
statusCode = http.StatusBadRequest
}
}
if resp != nil && resp.IsError() {
err = fmt.Errorf("%s", resp.Data["error"].(string))
}
respondError(w, statusCode, err)
respondError(w, statusCode, newErr)
return true
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
@@ -80,6 +81,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -88,6 +90,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -96,6 +99,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"secret/": map[string]interface{}{
@@ -105,6 +109,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -113,6 +118,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -121,6 +127,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}
testResponseStatus(t, resp, 200)
@@ -223,7 +230,7 @@ func TestHandler_error(t *testing.T) {
// vault.ErrSealed is a special case
w3 := httptest.NewRecorder()
respondError(w3, 400, vault.ErrSealed)
respondError(w3, 400, consts.ErrSealed)
if w3.Code != 503 {
t.Fatalf("expected 503, got %d", w3.Code)

View File

@@ -35,7 +35,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
resp, err := core.HandleRequest(lreq)
if err != nil {
respondErrorCommon(w, resp, err)
respondErrorCommon(w, lreq, resp, err)
return
}

View File

@@ -53,6 +53,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
return nil, http.StatusMethodNotAllowed, nil
}
if op == logical.ListOperation {
if !strings.HasSuffix(path, "/") {
path += "/"
}
}
// Parse the request if we can
var data map[string]interface{}
if op == logical.UpdateOperation {
@@ -109,40 +115,13 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
// Make the internal request. We attach the connection info
// as well in case this is an authentication request that requires
// it. Vault core handles stripping this if we need to.
// it. Vault core handles stripping this if we need to. This also
// handles all error cases; if we hit respondLogical, the request is a
// success.
resp, ok := request(core, w, r, req)
if !ok {
return
}
switch {
case req.Operation == logical.ReadOperation:
if resp == nil {
respondError(w, http.StatusNotFound, nil)
return
}
// Basically: if we have empty "keys" or no keys at all, 404. This
// provides consistency with GET.
case req.Operation == logical.ListOperation && resp.WrapInfo == nil:
if resp == nil || len(resp.Data) == 0 {
respondError(w, http.StatusNotFound, nil)
return
}
keysRaw, ok := resp.Data["keys"]
if !ok || keysRaw == nil {
respondError(w, http.StatusNotFound, nil)
return
}
keys, ok := keysRaw.([]string)
if !ok {
respondError(w, http.StatusInternalServerError, nil)
return
}
if len(keys) == 0 {
respondError(w, http.StatusNotFound, nil)
return
}
}
// Build the proper response
respondLogical(w, r, req, dataOnly, resp)

View File

@@ -4,8 +4,10 @@ import (
"bytes"
"encoding/json"
"io"
"net/http"
"reflect"
"strconv"
"strings"
"testing"
"time"
@@ -101,7 +103,7 @@ func TestLogical_StandbyRedirect(t *testing.T) {
// Attempt to fix raciness in this test by giving the first core a chance
// to grab the lock
time.Sleep(time.Second)
time.Sleep(2 * time.Second)
// Create a second HA Vault
conf2 := &vault.CoreConfig{
@@ -252,3 +254,42 @@ func TestLogical_RequestSizeLimit(t *testing.T) {
})
testResponseStatus(t, resp, 413)
}
func TestLogical_ListSuffix(t *testing.T) {
core, _, _ := vault.TestCoreUnsealed(t)
req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil)
lreq, status, err := buildLogicalRequest(core, nil, req)
if err != nil {
t.Fatal(err)
}
if status != 0 {
t.Fatalf("got status %d", status)
}
if strings.HasSuffix(lreq.Path, "/") {
t.Fatal("trailing slash found on path")
}
req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil)
lreq, status, err = buildLogicalRequest(core, nil, req)
if err != nil {
t.Fatal(err)
}
if status != 0 {
t.Fatalf("got status %d", status)
}
if !strings.HasSuffix(lreq.Path, "/") {
t.Fatal("trailing slash not found on path")
}
req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil)
lreq, status, err = buildLogicalRequest(core, nil, req)
if err != nil {
t.Fatal(err)
}
if status != 0 {
t.Fatalf("got status %d", status)
}
if !strings.HasSuffix(lreq.Path, "/") {
t.Fatal("trailing slash not found on path")
}
}

View File

@@ -35,6 +35,7 @@ func TestSysAudit(t *testing.T) {
"type": "noop",
"description": "",
"options": map[string]interface{}{},
"local": false,
},
},
"noop/": map[string]interface{}{
@@ -42,6 +43,7 @@ func TestSysAudit(t *testing.T) {
"type": "noop",
"description": "",
"options": map[string]interface{}{},
"local": false,
},
}
testResponseStatus(t, resp, 200)

View File

@@ -32,6 +32,7 @@ func TestSysAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
},
"token/": map[string]interface{}{
@@ -41,6 +42,7 @@ func TestSysAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
@@ -83,6 +85,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"token/": map[string]interface{}{
"description": "token based credentials",
@@ -91,6 +94,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
},
"foo/": map[string]interface{}{
@@ -100,6 +104,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"token/": map[string]interface{}{
"description": "token based credentials",
@@ -108,6 +113,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
}
testResponseStatus(t, resp, 200)
@@ -153,6 +159,7 @@ func TestSysDisableAuth(t *testing.T) {
},
"description": "token based credentials",
"type": "token",
"local": false,
},
},
"token/": map[string]interface{}{
@@ -162,6 +169,7 @@ func TestSysDisableAuth(t *testing.T) {
},
"description": "token based credentials",
"type": "token",
"local": false,
},
}
testResponseStatus(t, resp, 200)

View File

@@ -33,6 +33,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -41,6 +42,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -49,6 +51,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"secret/": map[string]interface{}{
@@ -58,6 +61,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -66,6 +70,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -74,6 +79,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}
testResponseStatus(t, resp, 200)
@@ -114,6 +120,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -122,6 +129,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -130,6 +138,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -138,6 +147,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"foo/": map[string]interface{}{
@@ -147,6 +157,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -155,6 +166,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -163,6 +175,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -171,6 +184,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}
testResponseStatus(t, resp, 200)
@@ -233,6 +247,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -241,6 +256,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -249,6 +265,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -257,6 +274,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"bar/": map[string]interface{}{
@@ -266,6 +284,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -274,6 +293,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -282,6 +302,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -290,6 +311,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}
testResponseStatus(t, resp, 200)
@@ -333,6 +355,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -341,6 +364,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -349,6 +373,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"secret/": map[string]interface{}{
@@ -358,6 +383,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -366,6 +392,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -374,6 +401,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}
testResponseStatus(t, resp, 200)
@@ -414,6 +442,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -422,6 +451,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -430,6 +460,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -438,6 +469,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"foo/": map[string]interface{}{
@@ -447,6 +479,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -455,6 +488,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -463,6 +497,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -471,6 +506,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}
testResponseStatus(t, resp, 200)
@@ -532,6 +568,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("259196400"),
"max_lease_ttl": json.Number("259200000"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -540,6 +577,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -548,6 +586,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -556,6 +595,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
},
"foo/": map[string]interface{}{
@@ -565,6 +605,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("259196400"),
"max_lease_ttl": json.Number("259200000"),
},
"local": false,
},
"secret/": map[string]interface{}{
"description": "generic secret storage",
@@ -573,6 +614,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
@@ -581,6 +623,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": false,
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
@@ -589,6 +632,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
},
"local": true,
},
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/vault"
)
@@ -19,6 +20,13 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler {
return
}
repState := core.ReplicationState()
if repState == consts.ReplicationSecondary {
respondError(w, http.StatusBadRequest,
fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"))
return
}
switch {
case recovery && !core.SealAccess().RecoveryKeySupported():
respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"))
@@ -108,7 +116,7 @@ func handleSysRekeyInitPut(core *vault.Core, recovery bool, w http.ResponseWrite
// Right now we don't support this, but the rest of the code is ready for
// when we do, hence the check below for this to be false if
// StoredShares is greater than zero
if core.SealAccess().StoredKeysSupported() {
if core.SealAccess().StoredKeysSupported() && !recovery {
respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying of barrier not supported when stored key support is available"))
return
}

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/version"
@@ -126,7 +127,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
case errwrap.Contains(err, vault.ErrBarrierInvalidKey.Error()):
case errwrap.Contains(err, vault.ErrBarrierNotInit.Error()):
case errwrap.Contains(err, vault.ErrBarrierSealed.Error()):
case errwrap.Contains(err, vault.ErrStandby.Error()):
case errwrap.Contains(err, consts.ErrStandby.Error()):
default:
respondError(w, http.StatusInternalServerError, err)
return

View File

@@ -8,7 +8,7 @@ import (
// is present on the Request structure for credential backends.
type Connection struct {
// RemoteAddr is the network address that sent the request.
RemoteAddr string
RemoteAddr string `json:"remote_addr"`
// ConnState is the TLS connection state if applicable.
ConnState *tls.ConnectionState

View File

@@ -21,3 +21,27 @@ func (e *codedError) Error() string {
func (e *codedError) Code() int {
return e.code
}
// StatusBadRequest identifies user input errors. This is helpful for responding
// with the appropriate status codes to clients from the HTTP endpoints.
type StatusBadRequest struct {
Err string
}
// Implementing error interface
func (s *StatusBadRequest) Error() string {
return s.Err
}
// ReplicationCodedError is a new type, declared so as not to cause potential
// compatibility problems if the logic around the HTTPCodedError interface
// changes; in particular, for logical request paths it is basically ignored,
// and changing that behavior might cause unforeseen issues.
type ReplicationCodedError struct {
Msg string
Code int
}
func (r *ReplicationCodedError) Error() string {
return r.Msg
}
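
As a rough usage sketch (the wrapping format string and error message below are invented for illustration), a backend can wrap a user-input failure in StatusBadRequest and the HTTP layer can then detect it by type rather than by string matching, mapping it to a 400:

package main

import (
	"fmt"
	"net/http"

	"github.com/hashicorp/errwrap"
)

// StatusBadRequest mirrors the type added above, repeated here only so the
// sketch is self-contained.
type StatusBadRequest struct {
	Err string
}

func (s *StatusBadRequest) Error() string { return s.Err }

func main() {
	// A backend wraps a user-input problem...
	err := errwrap.Wrapf("parsing field: {{err}}", &StatusBadRequest{Err: "ttl must be positive"})

	// ...and the HTTP layer maps it to a 400 by type, as in RespondErrorCommon.
	status := http.StatusInternalServerError
	if errwrap.ContainsType(err, new(StatusBadRequest)) {
		status = http.StatusBadRequest
	}
	fmt.Println(status) // 400
}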

View File

@@ -1,6 +1,7 @@
package framework
import (
"encoding/json"
"fmt"
"io/ioutil"
"regexp"
@@ -12,6 +13,7 @@ import (
log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
@@ -534,7 +536,40 @@ type FieldSchema struct {
// the zero value of the type.
func (s *FieldSchema) DefaultOrZero() interface{} {
if s.Default != nil {
return s.Default
switch s.Type {
case TypeDurationSecond:
var result int
switch inp := s.Default.(type) {
case nil:
return s.Type.Zero()
case int:
result = inp
case int64:
result = int(inp)
case float32:
result = int(inp)
case float64:
result = int(inp)
case string:
dur, err := duration.ParseDurationSecond(inp)
if err != nil {
return s.Type.Zero()
}
result = int(dur.Seconds())
case json.Number:
valInt64, err := inp.Int64()
if err != nil {
return s.Type.Zero()
}
result = int(valInt64)
default:
return s.Type.Zero()
}
return result
default:
return s.Default
}
}
return s.Type.Zero()

View File

@@ -554,6 +554,16 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) {
60,
},
"default duration int64": {
&FieldSchema{Type: TypeDurationSecond, Default: int64(60)},
60,
},
"default duration string": {
&FieldSchema{Type: TypeDurationSecond, Default: "60s"},
60,
},
"default duration not set": {
&FieldSchema{Type: TypeDurationSecond},
0,

View File

@@ -80,22 +80,3 @@ type Paths struct {
// indicates that these paths should not be replicated
LocalStorage []string
}
type ReplicationState uint32
const (
ReplicationDisabled ReplicationState = iota
ReplicationPrimary
ReplicationSecondary
)
func (r ReplicationState) String() string {
switch r {
case ReplicationSecondary:
return "secondary"
case ReplicationPrimary:
return "primary"
}
return "disabled"
}

View File

@@ -25,6 +25,10 @@ type Request struct {
// Id is the uuid associated with each request
ID string `json:"id" structs:"id" mapstructure:"id"`
// If set, the name given to the replication secondary where this request
// originated
ReplicationCluster string `json:"replication_cluster" structs:"replication_cluster" mapstructure:"replication_cluster"`
// Operation is the requested operation type
Operation Operation `json:"operation" structs:"operation" mapstructure:"operation"`
@@ -38,7 +42,7 @@ type Request struct {
Data map[string]interface{} `json:"map" structs:"data" mapstructure:"data"`
// Storage can be used to durably store and retrieve state.
Storage Storage `json:"storage" structs:"storage" mapstructure:"storage"`
Storage Storage `json:"-"`
// Secret will be non-nil only for Revoke and Renew operations
// to represent the secret that was returned prior.

111
logical/response_util.go Normal file
View File

@@ -0,0 +1,111 @@
package logical
import (
"errors"
"fmt"
"net/http"
"github.com/hashicorp/errwrap"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/consts"
)
// RespondErrorCommon pulls most of the functionality from http's
// respondErrorCommon and some of http's handleLogical and makes it available
// to both the http package and elsewhere.
func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
if err == nil && (resp == nil || !resp.IsError()) {
switch {
case req.Operation == ReadOperation:
if resp == nil {
return http.StatusNotFound, nil
}
// Basically: if we have empty "keys" or no keys at all, 404. This
// provides consistency with GET.
case req.Operation == ListOperation && resp.WrapInfo == nil:
if resp == nil || len(resp.Data) == 0 {
return http.StatusNotFound, nil
}
keysRaw, ok := resp.Data["keys"]
if !ok || keysRaw == nil {
return http.StatusNotFound, nil
}
keys, ok := keysRaw.([]string)
if !ok {
return http.StatusInternalServerError, nil
}
if len(keys) == 0 {
return http.StatusNotFound, nil
}
}
return 0, nil
}
if errwrap.ContainsType(err, new(ReplicationCodedError)) {
var allErrors error
codedErr := errwrap.GetType(err, new(ReplicationCodedError)).(*ReplicationCodedError)
errwrap.Walk(err, func(inErr error) {
newErr, ok := inErr.(*ReplicationCodedError)
if !ok {
allErrors = multierror.Append(allErrors, newErr)
}
})
if allErrors != nil {
return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
}
return codedErr.Code, errors.New(codedErr.Msg)
}
// Start out with internal server error since in most of these cases there
// won't be a response so this won't be overridden
statusCode := http.StatusInternalServerError
// If we actually have a response, start out with bad request
if resp != nil {
statusCode = http.StatusBadRequest
}
// Now, check the error itself; if it has a specific logical error, set the
// appropriate code
if err != nil {
switch {
case errwrap.ContainsType(err, new(StatusBadRequest)):
statusCode = http.StatusBadRequest
case errwrap.Contains(err, ErrPermissionDenied.Error()):
statusCode = http.StatusForbidden
case errwrap.Contains(err, ErrUnsupportedOperation.Error()):
statusCode = http.StatusMethodNotAllowed
case errwrap.Contains(err, ErrUnsupportedPath.Error()):
statusCode = http.StatusNotFound
case errwrap.Contains(err, ErrInvalidRequest.Error()):
statusCode = http.StatusBadRequest
}
}
if resp != nil && resp.IsError() {
err = fmt.Errorf("%s", resp.Data["error"].(string))
}
return statusCode, err
}
// AdjustErrorStatusCode adjusts the status that will be sent in error
// conditions in a way that can be shared across http's respondError and other
// locations.
func AdjustErrorStatusCode(status *int, err error) {
// Adjust status code when sealed
if errwrap.Contains(err, consts.ErrSealed.Error()) {
*status = http.StatusServiceUnavailable
}
// Adjust status code when the request body is too large
if errwrap.Contains(err, "http: request body too large") {
*status = http.StatusRequestEntityTooLarge
}
// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(HTTPCodedError); ok {
*status = t.Code()
}
}

View File

@@ -1,6 +1,10 @@
package logical
import "time"
import (
"time"
"github.com/hashicorp/vault/helper/consts"
)
// SystemView exposes system configuration information in a safe way
// for logical backends to consume
@@ -32,7 +36,7 @@ type SystemView interface {
CachingDisabled() bool
// ReplicationState indicates the state of cluster replication
ReplicationState() ReplicationState
ReplicationState() consts.ReplicationState
}
type StaticSystemView struct {
@@ -42,7 +46,7 @@ type StaticSystemView struct {
TaintedVal bool
CachingDisabledVal bool
Primary bool
ReplicationStateVal ReplicationState
ReplicationStateVal consts.ReplicationState
}
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
@@ -65,6 +69,6 @@ func (d StaticSystemView) CachingDisabled() bool {
return d.CachingDisabledVal
}
func (d StaticSystemView) ReplicationState() ReplicationState {
func (d StaticSystemView) ReplicationState() consts.ReplicationState {
return d.ReplicationStateVal
}

View File

@@ -23,6 +23,21 @@ const (
FlagSetDefault = FlagSetServer
)
var (
additionalOptionsUsage = func() string {
return `
-wrap-ttl="" Indicates that the response should be wrapped in a
cubbyhole token with the requested TTL. The response
can be fetched by calling the "sys/wrapping/unwrap"
endpoint, passing in the wrapping token's ID. This
is a numeric string with an optional suffix
"s", "m", or "h"; if no suffix is specified it will
be parsed as seconds. May also be specified via
VAULT_WRAP_TTL.
`
}
)
// Meta contains the meta-options and functionality that nearly every
// Vault command inherits.
type Meta struct {
@@ -188,6 +203,6 @@ func GeneralOptionsUsage() string {
if VAULT_SKIP_VERIFY is set.
`
general += AdditionalOptionsUsage()
general += additionalOptionsUsage()
return general
}

View File

@@ -1,7 +0,0 @@
// +build !vault
package meta
func AdditionalOptionsUsage() string {
return ""
}

View File

@@ -1,16 +0,0 @@
// +build vault
package meta
func AdditionalOptionsUsage() string {
return `
-wrap-ttl="" Indicates that the response should be wrapped in a
cubbyhole token with the requested TTL. The response
can be fetched by calling the "sys/wrapping/unwrap"
endpoint, passing in the wrapping token's ID. This
is a numeric string with an optional suffix
"s", "m", or "h"; if no suffix is specified it will
be parsed as seconds. May also be specified via
VAULT_WRAP_TTL.
`
}

View File

@@ -1,9 +1,15 @@
package physical
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"strings"
"sync"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/helper/strutil"
log "github.com/mgutz/logxi/v1"
)
@@ -17,8 +23,11 @@ const (
// Vault are for policy objects so there is a large read reduction
// by using a simple write-through cache.
type Cache struct {
backend Backend
lru *lru.TwoQueueCache
backend Backend
transactional Transactional
lru *lru.TwoQueueCache
locks map[string]*sync.RWMutex
logger log.Logger
}
// NewCache returns a physical cache of the given size.
@@ -34,16 +43,58 @@ func NewCache(b Backend, size int, logger log.Logger) *Cache {
c := &Cache{
backend: b,
lru: cache,
locks: make(map[string]*sync.RWMutex, 256),
logger: logger,
}
if err := locksutil.CreateLocks(c.locks, 256); err != nil {
logger.Error("physical/cache: error creating locks", "error", err)
return nil
}
if txnl, ok := c.backend.(Transactional); ok {
c.transactional = txnl
}
return c
}
func (c *Cache) lockHashForKey(key string) string {
hf := sha1.New()
hf.Write([]byte(key))
return strings.ToLower(hex.EncodeToString(hf.Sum(nil))[:2])
}
func (c *Cache) lockForKey(key string) *sync.RWMutex {
return c.locks[c.lockHashForKey(key)]
}
// Purge is used to clear the cache
func (c *Cache) Purge() {
// Lock the world
lockHashes := make([]string, 0, len(c.locks))
for hash := range c.locks {
lockHashes = append(lockHashes, hash)
}
// Sort and deduplicate. This ensures we don't try to grab the same lock
// twice, and enforcing a sort means multiple goroutines won't deadlock by
// acquiring the locks in different orders.
lockHashes = strutil.RemoveDuplicates(lockHashes)
for _, lockHash := range lockHashes {
lock := c.locks[lockHash]
lock.Lock()
defer lock.Unlock()
}
c.lru.Purge()
}
func (c *Cache) Put(entry *Entry) error {
lock := c.lockForKey(entry.Key)
lock.Lock()
defer lock.Unlock()
err := c.backend.Put(entry)
if err == nil {
c.lru.Add(entry.Key, entry)
@@ -52,6 +103,10 @@ func (c *Cache) Put(entry *Entry) error {
}
func (c *Cache) Get(key string) (*Entry, error) {
lock := c.lockForKey(key)
lock.RLock()
defer lock.RUnlock()
// Check the LRU first
if raw, ok := c.lru.Get(key); ok {
if raw == nil {
@@ -79,6 +134,10 @@ func (c *Cache) Get(key string) (*Entry, error) {
}
func (c *Cache) Delete(key string) error {
lock := c.lockForKey(key)
lock.Lock()
defer lock.Unlock()
err := c.backend.Delete(key)
if err == nil {
c.lru.Remove(key)
@@ -87,6 +146,45 @@ func (c *Cache) Delete(key string) error {
}
func (c *Cache) List(prefix string) ([]string, error) {
// Always pass-through as this would be difficult to cache.
// Always pass-through as this would be difficult to cache. For the same
// reason we don't lock as we can't reasonably know which locks to readlock
// ahead of time.
return c.backend.List(prefix)
}
func (c *Cache) Transaction(txns []TxnEntry) error {
if c.transactional == nil {
return fmt.Errorf("physical/cache: underlying backend does not support transactions")
}
var lockHashes []string
for _, txn := range txns {
lockHashes = append(lockHashes, c.lockHashForKey(txn.Entry.Key))
}
// Sort and deduplicate. This ensures we don't try to grab the same lock
// twice, and enforcing a sort means multiple goroutines won't deadlock by
// acquiring the locks in different orders.
lockHashes = strutil.RemoveDuplicates(lockHashes)
for _, lockHash := range lockHashes {
lock := c.locks[lockHash]
lock.Lock()
defer lock.Unlock()
}
if err := c.transactional.Transaction(txns); err != nil {
return err
}
for _, txn := range txns {
switch txn.Operation {
case PutOperation:
c.lru.Add(txn.Entry.Key, txn.Entry)
case DeleteOperation:
c.lru.Remove(txn.Entry.Key)
}
}
return nil
}
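
The lock sharding above can be seen in isolation with a small sketch (the keys below are arbitrary examples): the first two hex characters of a SHA-1 over the key select one of 256 buckets, so the same key always maps to the same lock while unrelated keys rarely contend.

package main

import (
	"crypto/sha1"
	"encoding/hex"
	"fmt"
	"strings"
)

// lockHashForKey mirrors the cache's bucket selection: two hex characters
// give 256 possible lock shards.
func lockHashForKey(key string) string {
	hf := sha1.New()
	hf.Write([]byte(key))
	return strings.ToLower(hex.EncodeToString(hf.Sum(nil))[:2])
}

func main() {
	for _, k := range []string{"core/audit", "core/auth", "sys/policy/root"} {
		fmt.Printf("%-16s -> lock bucket %s\n", k, lockHashForKey(k))
	}
}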

View File

@@ -1,6 +1,7 @@
package physical
import (
"errors"
"fmt"
"io/ioutil"
"net"
@@ -21,6 +22,8 @@ import (
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/helper/tlsutil"
)
@@ -154,6 +157,10 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
// Configure the client
consulConf := api.DefaultConfig()
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
tr := cleanhttp.DefaultPooledTransport()
tr.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
consulConf.HttpClient.Transport = tr
if addr, ok := conf["address"]; ok {
consulConf.Address = addr
@@ -179,7 +186,7 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
}
transport := cleanhttp.DefaultPooledTransport()
transport.MaxIdleConnsPerHost = 4
transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
transport.TLSClientConfig = tlsClientConfig
consulConf.HttpClient.Transport = transport
logger.Debug("physical/consul: configured TLS")
@@ -284,17 +291,59 @@ func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
return tlsClientConfig, nil
}
// Used to run multiple entries via a transaction
func (c *ConsulBackend) Transaction(txns []TxnEntry) error {
if len(txns) == 0 {
return nil
}
ops := make([]*api.KVTxnOp, 0, len(txns))
for _, op := range txns {
cop := &api.KVTxnOp{
Key: c.path + op.Entry.Key,
}
switch op.Operation {
case DeleteOperation:
cop.Verb = api.KVDelete
case PutOperation:
cop.Verb = api.KVSet
cop.Value = op.Entry.Value
default:
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
}
ops = append(ops, cop)
}
ok, resp, _, err := c.kv.Txn(ops, nil)
if err != nil {
return err
}
if ok {
return nil
}
var retErr *multierror.Error
for _, res := range resp.Errors {
retErr = multierror.Append(retErr, errors.New(res.What))
}
return retErr
}
// Put is used to insert or update an entry
func (c *ConsulBackend) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
c.permitPool.Acquire()
defer c.permitPool.Release()
pair := &api.KVPair{
Key: c.path + entry.Key,
Value: entry.Value,
}
c.permitPool.Acquire()
defer c.permitPool.Release()
_, err := c.kv.Put(pair, nil)
return err
}

View File

@@ -22,12 +22,17 @@ import (
// and non-performant. It is meant mostly for local testing and development.
// It can be improved in the future.
type FileBackend struct {
Path string
l sync.Mutex
logger log.Logger
sync.RWMutex
path string
logger log.Logger
permitPool *PermitPool
}
// newFileBackend constructs a Filebackend using the given directory
type TransactionalFileBackend struct {
FileBackend
}
// newFileBackend constructs a FileBackend using the given directory
func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
path, ok := conf["path"]
if !ok {
@@ -35,20 +40,44 @@ func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error)
}
return &FileBackend{
Path: path,
logger: logger,
path: path,
logger: logger,
permitPool: NewPermitPool(DefaultParallelOperations),
}, nil
}
func newTransactionalFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
path, ok := conf["path"]
if !ok {
return nil, fmt.Errorf("'path' must be set")
}
// Create a pool of size 1 so only one operation runs at a time
return &TransactionalFileBackend{
FileBackend: FileBackend{
path: path,
logger: logger,
permitPool: NewPermitPool(1),
},
}, nil
}
func (b *FileBackend) Delete(path string) error {
b.permitPool.Acquire()
defer b.permitPool.Release()
b.Lock()
defer b.Unlock()
return b.DeleteInternal(path)
}
func (b *FileBackend) DeleteInternal(path string) error {
if path == "" {
return nil
}
b.l.Lock()
defer b.l.Unlock()
basePath, key := b.path(path)
basePath, key := b.expandPath(path)
fullPath := filepath.Join(basePath, key)
err := os.Remove(fullPath)
@@ -66,7 +95,7 @@ func (b *FileBackend) Delete(path string) error {
func (b *FileBackend) cleanupLogicalPath(path string) error {
nodes := strings.Split(path, fmt.Sprintf("%c", os.PathSeparator))
for i := len(nodes) - 1; i > 0; i-- {
fullPath := filepath.Join(b.Path, filepath.Join(nodes[:i]...))
fullPath := filepath.Join(b.path, filepath.Join(nodes[:i]...))
dir, err := os.Open(fullPath)
if err != nil {
@@ -96,10 +125,17 @@ func (b *FileBackend) cleanupLogicalPath(path string) error {
}
func (b *FileBackend) Get(k string) (*Entry, error) {
b.l.Lock()
defer b.l.Unlock()
b.permitPool.Acquire()
defer b.permitPool.Release()
path, key := b.path(k)
b.RLock()
defer b.RUnlock()
return b.GetInternal(k)
}
func (b *FileBackend) GetInternal(k string) (*Entry, error) {
path, key := b.expandPath(k)
path = filepath.Join(path, key)
f, err := os.Open(path)
@@ -121,10 +157,17 @@ func (b *FileBackend) Get(k string) (*Entry, error) {
}
func (b *FileBackend) Put(entry *Entry) error {
path, key := b.path(entry.Key)
b.permitPool.Acquire()
defer b.permitPool.Release()
b.l.Lock()
defer b.l.Unlock()
b.Lock()
defer b.Unlock()
return b.PutInternal(entry)
}
func (b *FileBackend) PutInternal(entry *Entry) error {
path, key := b.expandPath(entry.Key)
// Make the parent tree
if err := os.MkdirAll(path, 0755); err != nil {
@@ -145,10 +188,17 @@ func (b *FileBackend) Put(entry *Entry) error {
}
func (b *FileBackend) List(prefix string) ([]string, error) {
b.l.Lock()
defer b.l.Unlock()
b.permitPool.Acquire()
defer b.permitPool.Release()
path := b.Path
b.RLock()
defer b.RUnlock()
return b.ListInternal(prefix)
}
func (b *FileBackend) ListInternal(prefix string) ([]string, error) {
path := b.path
if prefix != "" {
path = filepath.Join(path, prefix)
}
@@ -180,9 +230,19 @@ func (b *FileBackend) List(prefix string) ([]string, error) {
return names, nil
}
func (b *FileBackend) path(k string) (string, string) {
path := filepath.Join(b.Path, k)
func (b *FileBackend) expandPath(k string) (string, string) {
path := filepath.Join(b.path, k)
key := filepath.Base(path)
path = filepath.Dir(path)
return path, "_" + key
}
func (b *TransactionalFileBackend) Transaction(txns []TxnEntry) error {
b.permitPool.Acquire()
defer b.permitPool.Release()
b.Lock()
defer b.Unlock()
return genericTransactionHandler(b, txns)
}

View File

@@ -13,12 +13,16 @@ import (
// for testing and development situations where the data is not
// expected to be durable.
type InmemBackend struct {
sync.RWMutex
root *radix.Tree
l sync.RWMutex
permitPool *PermitPool
logger log.Logger
}
type TransactionalInmemBackend struct {
InmemBackend
}
// NewInmem constructs a new in-memory backend
func NewInmem(logger log.Logger) *InmemBackend {
in := &InmemBackend{
@@ -29,14 +33,31 @@ func NewInmem(logger log.Logger) *InmemBackend {
return in
}
// Basically for now just creates a permit pool of size 1 so only one operation
// can run at a time
func NewTransactionalInmem(logger log.Logger) *TransactionalInmemBackend {
in := &TransactionalInmemBackend{
InmemBackend: InmemBackend{
root: radix.New(),
permitPool: NewPermitPool(1),
logger: logger,
},
}
return in
}
// Put is used to insert or update an entry
func (i *InmemBackend) Put(entry *Entry) error {
i.permitPool.Acquire()
defer i.permitPool.Release()
i.l.Lock()
defer i.l.Unlock()
i.Lock()
defer i.Unlock()
return i.PutInternal(entry)
}
func (i *InmemBackend) PutInternal(entry *Entry) error {
i.root.Insert(entry.Key, entry)
return nil
}
@@ -46,9 +67,13 @@ func (i *InmemBackend) Get(key string) (*Entry, error) {
i.permitPool.Acquire()
defer i.permitPool.Release()
i.l.RLock()
defer i.l.RUnlock()
i.RLock()
defer i.RUnlock()
return i.GetInternal(key)
}
func (i *InmemBackend) GetInternal(key string) (*Entry, error) {
if raw, ok := i.root.Get(key); ok {
return raw.(*Entry), nil
}
@@ -60,9 +85,13 @@ func (i *InmemBackend) Delete(key string) error {
i.permitPool.Acquire()
defer i.permitPool.Release()
i.l.Lock()
defer i.l.Unlock()
i.Lock()
defer i.Unlock()
return i.DeleteInternal(key)
}
func (i *InmemBackend) DeleteInternal(key string) error {
i.root.Delete(key)
return nil
}
@@ -73,9 +102,13 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
i.permitPool.Acquire()
defer i.permitPool.Release()
i.l.RLock()
defer i.l.RUnlock()
i.RLock()
defer i.RUnlock()
return i.ListInternal(prefix)
}
func (i *InmemBackend) ListInternal(prefix string) ([]string, error) {
var out []string
seen := make(map[string]interface{})
walkFn := func(s string, v interface{}) bool {
@@ -96,3 +129,14 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
return out, nil
}
// Implements the transaction interface
func (t *TransactionalInmemBackend) Transaction(txns []TxnEntry) error {
t.permitPool.Acquire()
defer t.permitPool.Release()
t.Lock()
defer t.Unlock()
return genericTransactionHandler(t, txns)
}

View File

@@ -8,19 +8,40 @@ import (
)
type InmemHABackend struct {
InmemBackend
Backend
locks map[string]string
l sync.Mutex
cond *sync.Cond
logger log.Logger
}
type TransactionalInmemHABackend struct {
Transactional
InmemHABackend
}
// NewInmemHA constructs a new in-memory HA backend. This is only for testing.
func NewInmemHA(logger log.Logger) *InmemHABackend {
in := &InmemHABackend{
InmemBackend: *NewInmem(logger),
locks: make(map[string]string),
logger: logger,
Backend: NewInmem(logger),
locks: make(map[string]string),
logger: logger,
}
in.cond = sync.NewCond(&in.l)
return in
}
func NewTransactionalInmemHA(logger log.Logger) *TransactionalInmemHABackend {
transInmem := NewTransactionalInmem(logger)
inmemHA := InmemHABackend{
Backend: transInmem,
locks: make(map[string]string),
logger: logger,
}
in := &TransactionalInmemHABackend{
InmemHABackend: inmemHA,
Transactional: transInmem,
}
in.cond = sync.NewCond(&in.l)
return in

View File

@@ -9,6 +9,16 @@ import (
const DefaultParallelOperations = 128
// The operation type
type Operation string
const (
DeleteOperation Operation = "delete"
GetOperation = "get"
ListOperation = "list"
PutOperation = "put"
)
// ShutdownSignal
type ShutdownChannel chan struct{}
@@ -121,20 +131,27 @@ var builtinBackends = map[string]Factory{
"inmem": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewInmem(logger), nil
},
"inmem_transactional": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewTransactionalInmem(logger), nil
},
"inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewInmemHA(logger), nil
},
"consul": newConsulBackend,
"zookeeper": newZookeeperBackend,
"file": newFileBackend,
"s3": newS3Backend,
"azure": newAzureBackend,
"dynamodb": newDynamoDBBackend,
"etcd": newEtcdBackend,
"mysql": newMySQLBackend,
"postgresql": newPostgreSQLBackend,
"swift": newSwiftBackend,
"gcs": newGCSBackend,
"inmem_transactional_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewTransactionalInmemHA(logger), nil
},
"file_transactional": newTransactionalFileBackend,
"consul": newConsulBackend,
"zookeeper": newZookeeperBackend,
"file": newFileBackend,
"s3": newS3Backend,
"azure": newAzureBackend,
"dynamodb": newDynamoDBBackend,
"etcd": newEtcdBackend,
"mysql": newMySQLBackend,
"postgresql": newPostgreSQLBackend,
"swift": newSwiftBackend,
"gcs": newGCSBackend,
}
// PermitPool is used to limit maximum outstanding requests

View File

@@ -71,7 +71,8 @@ func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, e
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
"UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
quoted_table + " WHERE parent_path LIKE concat($1, '%')",
logger: logger,
}
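
The corrected list query is meant to return direct children of the prefix as keys and to collapse deeper descendants into a single one-level sub-prefix ending in "/". A small Go sketch of those semantics (the stored paths are invented for illustration and are unrelated to the SQL itself):

package main

import (
	"fmt"
	"sort"
	"strings"
)

// listOneLevel sketches the List semantics the fixed query targets: direct
// children come back as keys, deeper descendants collapse to "child/".
func listOneLevel(stored []string, prefix string) []string {
	seen := map[string]struct{}{}
	for _, p := range stored {
		if !strings.HasPrefix(p, prefix) {
			continue
		}
		rest := strings.TrimPrefix(p, prefix)
		if i := strings.Index(rest, "/"); i >= 0 {
			rest = rest[:i+1] // keep one level and mark it as a sub-prefix
		}
		seen[rest] = struct{}{}
	}
	out := make([]string, 0, len(seen))
	for k := range seen {
		out = append(out, k)
	}
	sort.Strings(out)
	return out
}

func main() {
	stored := []string{"foo/bar", "foo/baz/zip", "foo/baz/zap/deep"}
	fmt.Println(listOneLevel(stored, "foo/")) // [bar baz/]
}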

121
physical/transactions.go Normal file
View File

@@ -0,0 +1,121 @@
package physical
import multierror "github.com/hashicorp/go-multierror"
// TxnEntry is an operation that takes place atomically as part of
// a transactional update. Only supported by Transactional backends.
type TxnEntry struct {
Operation Operation
Entry *Entry
}
// Transactional is an optional interface for backends that
// support doing transactional updates of multiple keys. This is
// required for some features such as replication.
type Transactional interface {
// The function to run a transaction
Transaction([]TxnEntry) error
}
type PseudoTransactional interface {
// An internal function should do no locking or permit pool acquisition.
// Depending on the backend and whether it natively supports transactions, these
// may simply chain to the normal backend functions.
GetInternal(string) (*Entry, error)
PutInternal(*Entry) error
DeleteInternal(string) error
}
// Implements the transaction interface
func genericTransactionHandler(t PseudoTransactional, txns []TxnEntry) (retErr error) {
rollbackStack := make([]TxnEntry, 0, len(txns))
var dirty bool
// We walk the transactions in order; each successful operation goes into a
// LIFO for rollback if we hit an error along the way
TxnWalk:
for _, txn := range txns {
switch txn.Operation {
case DeleteOperation:
entry, err := t.GetInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
if entry == nil {
// Nothing to delete or roll back
continue
}
rollbackEntry := TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: entry.Key,
Value: entry.Value,
},
}
err = t.DeleteInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
case PutOperation:
entry, err := t.GetInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
// Nothing existed so in fact rolling back requires a delete
var rollbackEntry TxnEntry
if entry == nil {
rollbackEntry = TxnEntry{
Operation: DeleteOperation,
Entry: &Entry{
Key: txn.Entry.Key,
},
}
} else {
rollbackEntry = TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: entry.Key,
Value: entry.Value,
},
}
}
err = t.PutInternal(txn.Entry)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
}
}
// Need to roll back because we hit an error along the way
if dirty {
// While traversing this, if we get an error, we continue anyway in
// best-effort fashion
for _, txn := range rollbackStack {
switch txn.Operation {
case DeleteOperation:
err := t.DeleteInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
}
case PutOperation:
err := t.PutInternal(txn.Entry)
if err != nil {
retErr = multierror.Append(retErr, err)
}
}
}
}
return
}
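
Stripped of the physical package's types, the rollback idea in genericTransactionHandler can be sketched on its own (the map and values below are purely illustrative): every applied step pushes its inverse onto a stack, and a failure replays that stack in LIFO order.

package main

import "fmt"

// op is an illustrative inverse operation recorded for rollback.
type op func()

func main() {
	var rollback []op
	push := func(undo op) { rollback = append([]op{undo}, rollback...) } // LIFO

	state := map[string]string{"foo": "bar"}

	// Apply step 1: overwrite foo, remembering its old value.
	old := state["foo"]
	state["foo"] = "bar2"
	push(func() { state["foo"] = old })

	// Apply step 2: create a new key; rolling it back means deleting it.
	state["zip"] = "zap"
	push(func() { delete(state, "zip") })

	// Step 3 fails, so unwind everything applied so far, most recent first.
	failed := true
	if failed {
		for _, undo := range rollback {
			undo()
		}
	}
	fmt.Println(state["foo"], len(state)) // bar 1
}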

View File

@@ -0,0 +1,254 @@
package physical
import (
"fmt"
"reflect"
"sort"
"testing"
radix "github.com/armon/go-radix"
"github.com/hashicorp/vault/helper/logformat"
log "github.com/mgutz/logxi/v1"
)
type faultyPseudo struct {
underlying InmemBackend
faultyPaths map[string]struct{}
}
func (f *faultyPseudo) Get(key string) (*Entry, error) {
return f.underlying.Get(key)
}
func (f *faultyPseudo) Put(entry *Entry) error {
return f.underlying.Put(entry)
}
func (f *faultyPseudo) Delete(key string) error {
return f.underlying.Delete(key)
}
func (f *faultyPseudo) GetInternal(key string) (*Entry, error) {
if _, ok := f.faultyPaths[key]; ok {
return nil, fmt.Errorf("fault")
}
return f.underlying.GetInternal(key)
}
func (f *faultyPseudo) PutInternal(entry *Entry) error {
if _, ok := f.faultyPaths[entry.Key]; ok {
return fmt.Errorf("fault")
}
return f.underlying.PutInternal(entry)
}
func (f *faultyPseudo) DeleteInternal(key string) error {
if _, ok := f.faultyPaths[key]; ok {
return fmt.Errorf("fault")
}
return f.underlying.DeleteInternal(key)
}
func (f *faultyPseudo) List(prefix string) ([]string, error) {
return f.underlying.List(prefix)
}
func (f *faultyPseudo) Transaction(txns []TxnEntry) error {
f.underlying.permitPool.Acquire()
defer f.underlying.permitPool.Release()
f.underlying.Lock()
defer f.underlying.Unlock()
return genericTransactionHandler(f, txns)
}
func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo {
out := &faultyPseudo{
underlying: InmemBackend{
root: radix.New(),
permitPool: NewPermitPool(1),
logger: logger,
},
faultyPaths: make(map[string]struct{}, len(faultyPaths)),
}
for _, v := range faultyPaths {
out.faultyPaths[v] = struct{}{}
}
return out
}
func TestPseudo_Basic(t *testing.T) {
logger := logformat.NewVaultLogger(log.LevelTrace)
p := newFaultyPseudo(logger, nil)
testBackend(t, p)
testBackend_ListPrefix(t, p)
}
func TestPseudo_SuccessfulTransaction(t *testing.T) {
logger := logformat.NewVaultLogger(log.LevelTrace)
p := newFaultyPseudo(logger, nil)
txns := setupPseudo(p, t)
if err := p.Transaction(txns); err != nil {
t.Fatal(err)
}
keys, err := p.List("")
if err != nil {
t.Fatal(err)
}
expected := []string{"foo", "zip"}
sort.Strings(keys)
sort.Strings(expected)
if !reflect.DeepEqual(keys, expected) {
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
}
entry, err := p.Get("foo")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "bar3" {
t.Fatal("updates did not apply correctly")
}
entry, err = p.Get("zip")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "zap3" {
t.Fatal("updates did not apply correctly")
}
}
func TestPseudo_FailedTransaction(t *testing.T) {
logger := logformat.NewVaultLogger(log.LevelTrace)
p := newFaultyPseudo(logger, []string{"zip"})
txns := setupPseudo(p, t)
if err := p.Transaction(txns); err == nil {
t.Fatal("expected error during transaction")
}
keys, err := p.List("")
if err != nil {
t.Fatal(err)
}
expected := []string{"foo", "zip", "deleteme", "deleteme2"}
sort.Strings(keys)
sort.Strings(expected)
if !reflect.DeepEqual(keys, expected) {
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
}
entry, err := p.Get("foo")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "bar" {
t.Fatal("values did not rollback correctly")
}
entry, err = p.Get("zip")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "zap" {
t.Fatal("values did not rollback correctly")
}
}
func setupPseudo(p *faultyPseudo, t *testing.T) []TxnEntry {
// Add a few keys so that we test rollback with deletion
if err := p.Put(&Entry{
Key: "foo",
Value: []byte("bar"),
}); err != nil {
t.Fatal(err)
}
if err := p.Put(&Entry{
Key: "zip",
Value: []byte("zap"),
}); err != nil {
t.Fatal(err)
}
if err := p.Put(&Entry{
Key: "deleteme",
}); err != nil {
t.Fatal(err)
}
if err := p.Put(&Entry{
Key: "deleteme2",
}); err != nil {
t.Fatal(err)
}
txns := []TxnEntry{
TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: "foo",
Value: []byte("bar2"),
},
},
TxnEntry{
Operation: DeleteOperation,
Entry: &Entry{
Key: "deleteme",
},
},
TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: "foo",
Value: []byte("bar3"),
},
},
TxnEntry{
Operation: DeleteOperation,
Entry: &Entry{
Key: "deleteme2",
},
},
TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: "zip",
Value: []byte("zap3"),
},
},
}
return txns
}

View File

@@ -10,7 +10,7 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q \
git mercurial bzr \
&& rm -rf /var/lib/apt/lists/*
ENV GOVERSION 1.8rc3
ENV GOVERSION 1.8
RUN mkdir /goroot && mkdir /gopath
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
| tar xvzf - -C /goroot --strip-components=1

View File

@@ -2,7 +2,6 @@ package vault
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"strings"
@@ -26,6 +25,10 @@ const (
// can only be viewed or modified after an unseal.
coreAuditConfigPath = "core/audit"
// coreLocalAuditConfigPath is used to store audit information for local
// (non-replicated) mounts
coreLocalAuditConfigPath = "core/local-audit"
// auditBarrierPrefix is the prefix to the UUID used in the
// barrier view for the audit backends.
auditBarrierPrefix = "audit/"
@@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
}
// Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
entry.UUID = entryUUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
// Lookup the new backend
backend, err := c.newAuditBackend(entry, view, entry.Options)
@@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {
c.removeAuditReloadFunc(entry)
// When unmounting all entries the JSON code will load back up from storage
// as a nil slice, which kills tests...just set it nil explicitly
if len(newTable.Entries) == 0 {
newTable.Entries = nil
}
// Update the audit table
if err := c.persistAudit(newTable); err != nil {
return true, errors.New("failed to update audit table")
@@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) {
if c.logger.IsInfo() {
c.logger.Info("core: disabled audit backend", "path", path)
}
return true, nil
}
// loadAudits is invoked as part of postUnseal to load the audit table
func (c *Core) loadAudits() error {
auditTable := &MountTable{}
localAuditTable := &MountTable{}
// Load the existing audit table
raw, err := c.barrier.Get(coreAuditConfigPath)
@@ -144,6 +158,11 @@ func (c *Core) loadAudits() error {
c.logger.Error("core: failed to read audit table", "error", err)
return errLoadAuditFailed
}
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
c.logger.Error("core: failed to read local audit table", "error", err)
return errLoadAuditFailed
}
c.auditLock.Lock()
defer c.auditLock.Unlock()
@@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
}
c.audit = auditTable
}
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
c.logger.Error("core: failed to decode local audit table", "error", err)
return errLoadAuditFailed
}
c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
}
// Done if we have restored the audit table
if c.audit != nil {
@@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error {
}
}
nonLocalAudit := &MountTable{
Type: auditTableType,
}
localAudit := &MountTable{
Type: auditTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localAudit.Entries = append(localAudit.Entries, entry)
} else {
nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
}
}
// Marshal the table
raw, err := json.Marshal(table)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode audit table", "error", err)
c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
return err
}
// Create an entry
entry := &Entry{
Key: coreAuditConfigPath,
Value: raw,
Value: compressedBytes,
}
// Write to the physical backend
@@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error {
c.logger.Error("core: failed to persist audit table", "error", err)
return err
}
// Repeat with local audit
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalAuditConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local audit table", "error", err)
return err
}
return nil
}
@@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {
for _, entry := range c.audit.Entries {
// Create a barrier view using the UUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
// Initialize the backend
audit, err := c.newAuditBackend(entry, view, entry.Options)

View File

@@ -11,6 +11,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
@@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
}
}
// Test that the local table actually gets populated as expected with local
// entries, and that upon reading, the entries from both are recombined
// correctly
func TestCore_EnableAudit_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return nil, fmt.Errorf("failing enabling")
}
c.audit = &MountTable{
Type: auditTableType,
Entries: []*MountEntry{
&MountEntry{
Table: auditTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: auditTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupAudits()
if err != nil {
t.Fatal(err)
}
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) > 0 {
t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
}
c.audit.Entries[1].Local = true
if err := c.persistAudit(c.audit); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) != 1 {
t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
}
oldAudit := c.audit
if err := c.loadAudits(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(oldAudit, c.audit) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
}
if len(c.audit.Entries) != 2 {
t.Fatalf("expected two audit entries, got %#v", localAuditTable)
}
}
func TestCore_DisableAudit(t *testing.T) {
c, keys, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
@@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {
// Verify matching mount tables
if !reflect.DeepEqual(c.audit, c2.audit) {
t.Fatalf("mismatch: %v %v", c.audit, c2.audit)
t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
}
}

View File

@@ -1,7 +1,6 @@
package vault
import (
"encoding/json"
"errors"
"fmt"
"strings"
@@ -17,6 +16,10 @@ const (
// can only be viewed or modified after an unseal.
coreAuthConfigPath = "core/auth"
// coreLocalAuthConfigPath is used to store credential configuration for
// local (non-replicated) mounts
coreLocalAuthConfigPath = "core/local-auth"
// credentialBarrierPrefix is the prefix to the UUID used in the
// barrier view for the credential backends.
credentialBarrierPrefix = "auth/"
@@ -71,16 +74,25 @@ func (c *Core) enableCredential(entry *MountEntry) error {
}
// Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID()
if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
viewPath := credentialBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
// Create the new backend
backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil)
if err != nil {
return err
}
entry.UUID = entryUUID
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Create the new backend
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
if err := backend.Initialize(); err != nil {
return err
}
@@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
fullPath := credentialRoutePrefix + path
view := c.router.MatchingStorageView(fullPath)
if view == nil {
return false, fmt.Errorf("no matching backend")
return false, fmt.Errorf("no matching backend %s", fullPath)
}
// Mark the entry as tainted
@@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error {
// loadCredentials is invoked as part of postUnseal to load the auth table
func (c *Core) loadCredentials() error {
authTable := &MountTable{}
localAuthTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreAuthConfigPath)
if err != nil {
c.logger.Error("core: failed to read auth table", "error", err)
return errLoadAuthFailed
}
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
c.logger.Error("core: failed to read local auth table", "error", err)
return errLoadAuthFailed
}
c.authLock.Lock()
defer c.authLock.Unlock()
@@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error {
}
c.auth = authTable
}
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil {
c.logger.Error("core: failed to decode local auth table", "error", err)
return errLoadAuthFailed
}
c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...)
}
// Done if we have restored the auth table
if c.auth != nil {
@@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error {
}
}
nonLocalAuth := &MountTable{
Type: credentialTableType,
}
localAuth := &MountTable{
Type: credentialTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localAuth.Entries = append(localAuth.Entries, entry)
} else {
nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry)
}
}
// Marshal the table
raw, err := json.Marshal(table)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
if err != nil {
c.logger.Error("core: failed to encode auth table", "error", err)
c.logger.Error("core: failed to encode and/or compress auth table", "error", err)
return err
}
// Create an entry
entry := &Entry{
Key: coreAuthConfigPath,
Value: raw,
Value: compressedBytes,
}
// Write to the physical backend
@@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error {
c.logger.Error("core: failed to persist auth table", "error", err)
return err
}
// Repeat with local auth
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local auth table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalAuthConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local auth table", "error", err)
return err
}
return nil
}
@@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error {
}
// Create a barrier view using the UUID
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
viewPath := credentialBarrierPrefix + entry.UUID + "/"
view = NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
// Initialize the backend
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil)
if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
return errLoadAuthFailed
}
if err := backend.Initialize(); err != nil {
return err
}
// Mount the backend
path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry, view)

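As a standalone illustration of the split that persistAuth now performs (local entries written under coreLocalAuthConfigPath, everything else under coreAuthConfigPath), here is a minimal sketch using simplified stand-in types rather than Vault's MountTable/MountEntry:

package main

import "fmt"

// mountEntry and mountTable are simplified stand-ins for Vault's MountEntry
// and MountTable types; only the fields needed for the split are included.
type mountEntry struct {
    Path  string
    Local bool
}

type mountTable struct {
    Entries []*mountEntry
}

// splitLocal mirrors the idea in persistAuth above: local entries go into one
// table, replicated (non-local) entries into another, and each table is then
// persisted under its own storage path.
func splitLocal(in *mountTable) (nonLocal, local *mountTable) {
    nonLocal, local = &mountTable{}, &mountTable{}
    for _, e := range in.Entries {
        if e.Local {
            local.Entries = append(local.Entries, e)
        } else {
            nonLocal.Entries = append(nonLocal.Entries, e)
        }
    }
    return nonLocal, local
}

func main() {
    t := &mountTable{Entries: []*mountEntry{
        {Path: "token/"},
        {Path: "ldap/", Local: true},
    }}
    nl, l := splitLocal(t)
    fmt.Println(len(nl.Entries), len(l.Entries)) // prints: 1 1
}
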
View File

@@ -2,8 +2,10 @@ package vault
import (
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
)
@@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) {
}
}
// Test that the local table actually gets populated as expected with local
// entries, and that when the tables are read back the entries from both are
// recombined correctly

func TestCore_EnableCredential_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
return &NoopBackend{}, nil
}
c.auth = &MountTable{
Type: credentialTableType,
Entries: []*MountEntry{
&MountEntry{
Table: credentialTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: credentialTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupCredentials()
if err != nil {
t.Fatal(err)
}
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local credential")
}
localCredentialTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
t.Fatal(err)
}
if len(localCredentialTable.Entries) > 0 {
t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable)
}
c.auth.Entries[1].Local = true
if err := c.persistAuth(c.auth); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local credential")
}
localCredentialTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
t.Fatal(err)
}
if len(localCredentialTable.Entries) != 1 {
t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable)
}
oldCredential := c.auth
if err := c.loadCredentials(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(oldCredential, c.auth) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth)
}
if len(c.auth.Entries) != 2 {
t.Fatalf("expected two credential entries, got %#v", localCredentialTable)
}
}
func TestCore_EnableCredential_twice_409(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
@@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) {
}
existed, err := c.disableCredential("foo")
if existed || err.Error() != "no matching backend" {
if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) {
t.Fatalf("existed: %v; err: %v", existed, err)
}

View File

@@ -86,6 +86,11 @@ type SecurityBarrier interface {
// VerifyMaster is used to check if the given key matches the master key
VerifyMaster(key []byte) error
// SetMasterKey is used to directly set a new master key. This is used in
// replicated scenarios due to the chicken and egg problem of reloading the
// keyring from disk before we have the master key to decrypt it.
SetMasterKey(key []byte) error
// ReloadKeyring is used to re-read the underlying keyring.
// This is used for HA deployments to ensure the latest keyring
// is present in the leader.
@@ -119,8 +124,14 @@ type SecurityBarrier interface {
// Rekey is used to change the master key used to protect the keyring
Rekey([]byte) error
// For replication we must send over the keyring, so this must be available
Keyring() (*Keyring, error)
// SecurityBarrier must provide the storage APIs
BarrierStorage
// SecurityBarrier must provide the encryption APIs
BarrierEncryptor
}
// BarrierStorage is the storage only interface required for a Barrier.
@@ -139,6 +150,14 @@ type BarrierStorage interface {
List(prefix string) ([]string, error)
}
// BarrierEncryptor is the in-memory-only interface that encrypts and decrypts
// without touching the underlying barrier storage. It is used by lower-level
// modules, such as the Write-Ahead-Log and Merkle index, to reuse the barrier's encryption.
type BarrierEncryptor interface {
Encrypt(key string, plaintext []byte) ([]byte, error)
Decrypt(key string, ciphertext []byte) ([]byte, error)
}
// Entry is used to represent data stored by the security barrier
type Entry struct {
Key string

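The BarrierEncryptor interface added above lets lower-level modules encrypt and decrypt through the barrier without touching barrier storage. A minimal sketch of programming against such an interface; the xorEncryptor below is a toy stand-in for illustration only, not real encryption and not part of Vault:

package main

import "fmt"

// BarrierEncryptor mirrors the interface above: encrypt and decrypt without
// going through barrier storage.
type BarrierEncryptor interface {
    Encrypt(key string, plaintext []byte) ([]byte, error)
    Decrypt(key string, ciphertext []byte) ([]byte, error)
}

// xorEncryptor is a toy implementation used only to show the shape of the
// interface; XOR with a fixed pad is NOT encryption.
type xorEncryptor struct{ pad byte }

func (x xorEncryptor) Encrypt(_ string, p []byte) ([]byte, error) {
    out := make([]byte, len(p))
    for i, b := range p {
        out[i] = b ^ x.pad
    }
    return out, nil
}

func (x xorEncryptor) Decrypt(k string, c []byte) ([]byte, error) { return x.Encrypt(k, c) }

// roundTrip is how a lower-level module (e.g. a WAL) could depend only on the
// encryptor, never on storage.
func roundTrip(e BarrierEncryptor, msg string) (string, error) {
    ct, err := e.Encrypt("wal/entry", []byte(msg))
    if err != nil {
        return "", err
    }
    pt, err := e.Decrypt("wal/entry", ct)
    return string(pt), err
}

func main() {
    out, _ := roundTrip(xorEncryptor{pad: 0x5a}, "hello")
    fmt.Println(out) // prints: hello
}
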
View File

@@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) {
func (b *AESGCMBarrier) Rekey(key []byte) error {
b.l.Lock()
defer b.l.Unlock()
if b.sealed {
return ErrBarrierSealed
}
// Verify the key size
min, max := b.KeyLength()
if len(key) < min || len(key) > max {
return fmt.Errorf("Key size must be %d or %d", min, max)
newKeyring, err := b.updateMasterKeyCommon(key)
if err != nil {
return err
}
// Add a new encryption key
newKeyring := b.keyring.SetMasterKey(key)
// Persist the new keyring
if err := b.persistKeyring(newKeyring); err != nil {
return err
@@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error {
return nil
}
// SetMasterKey updates the keyring's in-memory master key but does not persist
// anything to storage
func (b *AESGCMBarrier) SetMasterKey(key []byte) error {
b.l.Lock()
defer b.l.Unlock()
newKeyring, err := b.updateMasterKeyCommon(key)
if err != nil {
return err
}
// Swap the keyrings
oldKeyring := b.keyring
b.keyring = newKeyring
oldKeyring.Zeroize(false)
return nil
}
// Performs common tasks related to updating the master key; note that the lock
// must be held before calling this function
func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
if b.sealed {
return nil, ErrBarrierSealed
}
// Verify the key size
min, max := b.KeyLength()
if len(key) < min || len(key) > max {
return nil, fmt.Errorf("Key size must be %d or %d", min, max)
}
return b.keyring.SetMasterKey(key), nil
}
// Put is used to insert or update an entry
func (b *AESGCMBarrier) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
@@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro
return nil, fmt.Errorf("version bytes mis-match")
}
}
// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface
func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
term := b.keyring.ActiveTerm()
primary, err := b.aeadForTerm(term)
if err != nil {
return nil, err
}
ciphertext := b.encrypt(key, term, primary, plaintext)
return ciphertext, nil
}
// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface
func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
// Decrypt the ciphertext
plain, err := b.decryptKeyring(key, ciphertext)
if err != nil {
return nil, fmt.Errorf("decryption failed: %v", err)
}
return plain, nil
}
func (b *AESGCMBarrier) Keyring() (*Keyring, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
return b.keyring.Clone(), nil
}

View File

@@ -15,7 +15,7 @@ var (
)
// mockBarrier returns a physical backend, security barrier, and master key
func mockBarrier(t *testing.T) (physical.Backend, SecurityBarrier, []byte) {
func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) {
inm := physical.NewInmem(logger)
b, err := NewAESGCMBarrier(inm)
@@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) {
t.Fatalf("key length protection failed")
}
}
func TestEncrypt_BarrierEncryptor(t *testing.T) {
inm := physical.NewInmem(logger)
b, err := NewAESGCMBarrier(inm)
if err != nil {
t.Fatalf("err: %v", err)
}
// Initialize and unseal
key, _ := b.GenerateKey()
b.Initialize(key)
b.Unseal(key)
cipher, err := b.Encrypt("foo", []byte("quick brown fox"))
if err != nil {
t.Fatalf("err: %v", err)
}
plain, err := b.Decrypt("foo", cipher)
if err != nil {
t.Fatalf("err: %v", err)
}
if string(plain) != "quick brown fox" {
t.Fatalf("bad: %s", plain)
}
}

View File

@@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
// logical.Storage impl.
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(entry.Key); err != nil {
return err
}
expandedKey := v.expandKey(entry.Key)
if v.readonly {
return logical.ErrReadOnly
}
nested := &Entry{
Key: v.expandKey(entry.Key),
Key: expandedKey,
Value: entry.Value,
}
return v.barrier.Put(nested)
@@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
// logical.Storage impl.
func (v *BarrierView) Delete(key string) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(key); err != nil {
return err
}
return v.barrier.Delete(v.expandKey(key))
expandedKey := v.expandKey(key)
if v.readonly {
return logical.ErrReadOnly
}
return v.barrier.Delete(expandedKey)
}
// SubView constructs a nested sub-view using the given prefix

View File

@@ -1,27 +1,19 @@
package vault
import "sort"
import (
"sort"
// Struct to identify user input errors.
// This is helpful in responding with the appropriate status codes to clients
// from the HTTP endpoints.
type StatusBadRequest struct {
Err string
}
// Implementing error interface
func (s *StatusBadRequest) Error() string {
return s.Err
}
"github.com/hashicorp/vault/logical"
)
// Capabilities is used to fetch the capabilities of the given token on the given path
func (c *Core) Capabilities(token, path string) ([]string, error) {
if path == "" {
return nil, &StatusBadRequest{Err: "missing path"}
return nil, &logical.StatusBadRequest{Err: "missing path"}
}
if token == "" {
return nil, &StatusBadRequest{Err: "missing token"}
return nil, &logical.StatusBadRequest{Err: "missing token"}
}
te, err := c.tokenStore.Lookup(token)
@@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) {
return nil, err
}
if te == nil {
return nil, &StatusBadRequest{Err: "invalid token"}
return nil, &logical.StatusBadRequest{Err: "invalid token"}
}
if te.Policies == nil {

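Moving StatusBadRequest into the logical package lets callers outside this file map the error onto an HTTP status. A rough sketch of that mapping, using a local copy of the error type rather than Vault's actual HTTP handler code:

package main

import (
    "fmt"
    "net/http"
)

// StatusBadRequest is a local stand-in for the shared error type referenced
// in the diff above.
type StatusBadRequest struct{ Err string }

func (s *StatusBadRequest) Error() string { return s.Err }

// statusFor maps an error to an HTTP status code: user-input errors become
// 400, everything else becomes 500.
func statusFor(err error) int {
    if err == nil {
        return http.StatusOK
    }
    if _, ok := err.(*StatusBadRequest); ok {
        return http.StatusBadRequest
    }
    return http.StatusInternalServerError
}

func main() {
    fmt.Println(statusFor(&StatusBadRequest{Err: "missing path"})) // 400
    fmt.Println(statusFor(nil))                                    // 200
}
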
View File

@@ -43,7 +43,7 @@ var (
// This can be one of a few key types so the different params may or may not be filled
type clusterKeyParams struct {
Type string `json:"type"`
Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
@@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
c.logger.Info("core/stopClusterListener: success")
}
// ClusterTLSConfig generates a TLS configuration based on the local cluster
// key and cert.
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
// cluster key and cert.
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
cluster, err := c.Cluster()
if err != nil {
return nil, err
}
if cluster == nil {
return nil, fmt.Errorf("cluster information is nil")
return nil, fmt.Errorf("local cluster information is nil")
}
// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()
if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
return nil, fmt.Errorf("cluster certificate is nil")
forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0
var parsedCert *x509.Certificate
if forwarding {
parsedCert, err = x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
}
// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
}
parsedCert, err := x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
}
nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
c.clusterParamsLock.RLock()
defer c.clusterParamsLock.RUnlock()
// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
tls.Certificate{
if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
return &tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
},
},
RootCAs: c.clusterCertPool,
ServerName: parsedCert.Subject.CommonName,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
MinVersion: tls.VersionTLS12,
}, nil
}
return nil, nil
}
var clientCertificates []tls.Certificate
if forwarding {
clientCertificates = append(clientCertificates, tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
})
}
tlsConfig := &tls.Config{
// We need this here for the client side
Certificates: clientCertificates,
RootCAs: c.clusterCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
GetCertificate: nameLookup,
MinVersion: tls.VersionTLS12,
}
if forwarding {
tlsConfig.ServerName = parsedCert.Subject.CommonName
}
return tlsConfig, nil

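The nameLookup callback above serves the forwarding certificate only when the SNI name matches the local cluster certificate's common name. A self-contained sketch of the same GetCertificate pattern; the name "fw-abc123" and the nil certificate are placeholders, not Vault values:

package main

import (
    "crypto/tls"
    "fmt"
)

// buildConfig shows the GetCertificate pattern: hand out a certificate only
// when the client's SNI name matches, otherwise return nil so the handshake
// proceeds (or fails) according to the rest of the config.
func buildConfig(name string, cert *tls.Certificate) *tls.Config {
    return &tls.Config{
        ClientAuth: tls.RequireAndVerifyClientCert,
        MinVersion: tls.VersionTLS12,
        GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
            if cert != nil && hello.ServerName == name {
                return cert, nil
            }
            return nil, nil
        },
    }
}

func main() {
    cfg := buildConfig("fw-abc123", nil)
    c, _ := cfg.GetCertificate(&tls.ClientHelloInfo{ServerName: "other"})
    fmt.Println(c == nil) // true: no match, no certificate served
}
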
View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical"
@@ -100,7 +101,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
checkListenersFunc := func(expectFail bool) {
tlsConfig, err := cores[0].ClusterTLSConfig()
if err != nil {
if err.Error() != ErrSealed.Error() {
if err.Error() != consts.ErrSealed.Error() {
t.Fatal(err)
}
tlsConfig = lastTLSConfig

View File

@@ -1,9 +1,9 @@
package vault
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
@@ -23,6 +23,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat"
@@ -56,17 +57,14 @@ const (
// leaderPrefixCleanDelay is how long to wait between deletions
// of orphaned leader keys, to prevent slamming the backend.
leaderPrefixCleanDelay = 200 * time.Millisecond
// coreKeyringCanaryPath is used as a canary to indicate to replicated
// clusters that they need to perform a rekey operation synchronously; this
// is not named core/keyring-canary so that it is not skipped by logic that ignores core/keyring
coreKeyringCanaryPath = "core/canary-keyring"
)
var (
// ErrSealed is returned if an operation is performed on
// a sealed barrier. No operation is expected to succeed before unsealing
ErrSealed = errors.New("Vault is sealed")
// ErrStandby is returned if an operation is performed on
// a standby Vault. No operation is expected to succeed until active.
ErrStandby = errors.New("Vault is in standby mode")
// ErrAlreadyInit is returned if the core is already
// initialized. This prevents a re-initialization.
ErrAlreadyInit = errors.New("Vault is already initialized")
@@ -87,6 +85,12 @@ var (
// step down of the active node, to prevent instantly regrabbing the lock.
// It's var not const so that tests can manipulate it.
manualStepDownSleepPeriod = 10 * time.Second
// Functions only in the Enterprise version
enterprisePostUnseal = enterprisePostUnsealImpl
enterprisePreSeal = enterprisePreSealImpl
startReplication = startReplicationImpl
stopReplication = stopReplicationImpl
)
// ReloadFunc are functions that are called when a reload is requested.
@@ -133,6 +137,11 @@ type unlockInformation struct {
// interface for API handlers and is responsible for managing the logical and physical
// backends, router, security barrier, and audit trails.
type Core struct {
// N.B.: This is used to populate a dev token down replication, as
// otherwise, after replication is started, a dev would have to go through
// the generate-root process simply to talk to the new follower cluster.
devToken string
// HABackend may be available depending on the physical backend
ha physical.HABackend
@@ -268,7 +277,7 @@ type Core struct {
//
// Name
clusterName string
// Used to modify cluster TLS params
// Used to modify cluster parameters
clusterParamsLock sync.RWMutex
// The private key stored in the barrier used for establishing
// mutually-authenticated connections between Vault cluster members
@@ -310,11 +319,13 @@ type Core struct {
// replicationState keeps the current replication state cached for quick
// lookup
replicationState logical.ReplicationState
replicationState consts.ReplicationState
}
// CoreConfig is used to parameterize a core
type CoreConfig struct {
DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"`
LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`
CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
@@ -390,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) {
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
}
// Setup the core
c := &Core{
redirectAddr: conf.RedirectAddr,
clusterAddr: conf.ClusterAddr,
physical: conf.Physical,
seal: conf.Seal,
router: NewRouter(),
sealed: true,
standby: true,
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the physical backend in a cache layer if enabled and not already wrapped
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if !conf.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
@@ -407,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
}
// Construct a new AES-GCM barrier
barrier, err := NewAESGCMBarrier(conf.Physical)
var err error
c.barrier, err = NewAESGCMBarrier(c.physical)
if err != nil {
return nil, fmt.Errorf("barrier setup failed: %v", err)
}
// Setup the core
c := &Core{
redirectAddr: conf.RedirectAddr,
clusterAddr: conf.ClusterAddr,
physical: conf.Physical,
seal: conf.Seal,
barrier: barrier,
router: NewRouter(),
sealed: true,
standby: true,
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the backend in a cache unless disabled
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
c.ha = conf.HAPhysical
}
@@ -518,10 +529,10 @@ func (c *Core) LookupToken(token string) (*TokenEntry, error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return nil, ErrSealed
return nil, consts.ErrSealed
}
if c.standby {
return nil, ErrStandby
return nil, consts.ErrStandby
}
// Many tests don't have a token store running
@@ -656,7 +667,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
// Check if sealed
if c.sealed {
return false, "", ErrSealed
return false, "", consts.ErrSealed
}
// Check if HA enabled
@@ -803,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) {
return true, nil
}
masterKey, err := c.unsealPart(config, key)
if err != nil {
return false, err
}
if masterKey != nil {
return c.unsealInternal(masterKey)
}
return false, nil
}
func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) {
// Check if we already have this piece
if c.unlockInfo != nil {
for _, existing := range c.unlockInfo.Parts {
if bytes.Equal(existing, key) {
return false, nil
if subtle.ConstantTimeCompare(existing, key) == 1 {
return nil, nil
}
}
} else {
uuid, err := uuid.GenerateUUID()
if err != nil {
return false, err
return nil, err
}
c.unlockInfo = &unlockInformation{
Nonce: uuid,
@@ -828,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) {
if c.logger.IsDebug() {
c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
}
return false, nil
return nil, nil
}
// Best-effort memzero of unlock parts once we're done with them
defer func() {
for i, _ := range c.unlockInfo.Parts {
memzero(c.unlockInfo.Parts[i])
}
c.unlockInfo = nil
}()
// Recover the master key
var masterKey []byte
var err error
if config.SecretThreshold == 1 {
masterKey = c.unlockInfo.Parts[0]
c.unlockInfo = nil
masterKey = make([]byte, len(c.unlockInfo.Parts[0]))
copy(masterKey, c.unlockInfo.Parts[0])
} else {
masterKey, err = shamir.Combine(c.unlockInfo.Parts)
c.unlockInfo = nil
if err != nil {
return false, fmt.Errorf("failed to compute master key: %v", err)
return nil, fmt.Errorf("failed to compute master key: %v", err)
}
}
defer memzero(masterKey)
return c.unsealInternal(masterKey)
return masterKey, nil
}
// This must be called with the state write lock held
func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
defer memzero(masterKey)
// Attempt to unlock
if err := c.barrier.Unseal(masterKey); err != nil {
return false, err
@@ -867,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
c.logger.Warn("core: vault is sealed")
return false, err
}
if err := c.postUnseal(); err != nil {
c.logger.Error("core: post-unseal setup failed", "error", err)
c.barrier.Seal()
c.logger.Warn("core: vault is sealed")
return false, err
}
c.standby = false
} else {
// Go to standby mode, wait until we are active to unseal
@@ -1168,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) {
if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge()
}
// HA mode requires us to handle keyring rotation and rekeying
if c.ha != nil {
// We want to reload these from disk so that in case of a rekey we're
@@ -1190,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) {
return err
}
}
if err := enterprisePostUnseal(c); err != nil {
return err
}
if err := c.ensureWrappingKey(); err != nil {
return err
}
@@ -1251,6 +1290,7 @@ func (c *Core) preSeal() error {
c.metricsCh = nil
}
var result error
if c.ha != nil {
c.stopClusterListener()
}
@@ -1273,6 +1313,10 @@ func (c *Core) preSeal() error {
if err := c.unloadMounts(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
}
if err := enterprisePreSeal(c); err != nil {
result = multierror.Append(result, err)
}
// Purge the backend if supported
if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge()
@@ -1281,6 +1325,22 @@ func (c *Core) preSeal() error {
return result
}
func enterprisePostUnsealImpl(c *Core) error {
return nil
}
func enterprisePreSealImpl(c *Core) error {
return nil
}
func startReplicationImpl(c *Core) error {
return nil
}
func stopReplicationImpl(c *Core) error {
return nil
}
// runStandby is a long running routine that is used when an HA backend
// is enabled. It waits until we are leader and switches this Vault to
// active.
@@ -1599,6 +1659,14 @@ func (c *Core) emitMetrics(stopCh chan struct{}) {
}
}
func (c *Core) ReplicationState() consts.ReplicationState {
var state consts.ReplicationState
c.clusterParamsLock.RLock()
state = c.replicationState
c.clusterParamsLock.RUnlock()
return state
}
func (c *Core) SealAccess() *SealAccess {
sa := &SealAccess{}
sa.SetSeal(c.seal)

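The unsealPart refactor above also swaps bytes.Equal for subtle.ConstantTimeCompare when checking whether a key share was already supplied. A minimal sketch of that check in isolation (the helper name alreadyProvided is invented for this example):

package main

import (
    "crypto/subtle"
    "fmt"
)

// alreadyProvided reports whether key matches any previously supplied part,
// using a constant-time comparison so the check does not leak timing
// information about partially matching shares.
func alreadyProvided(parts [][]byte, key []byte) bool {
    for _, existing := range parts {
        if subtle.ConstantTimeCompare(existing, key) == 1 {
            return true
        }
    }
    return false
}

func main() {
    parts := [][]byte{[]byte("share-one")}
    fmt.Println(alreadyProvided(parts, []byte("share-one"))) // true
    fmt.Println(alreadyProvided(parts, []byte("share-two"))) // false
}
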
View File

@@ -8,6 +8,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical"
@@ -198,7 +199,7 @@ func TestCore_Route_Sealed(t *testing.T) {
Path: "sys/mounts",
}
_, err := c.HandleRequest(req)
if err != ErrSealed {
if err != consts.ErrSealed {
t.Fatalf("err: %v", err)
}
@@ -1541,7 +1542,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
// Request should fail in standby mode
_, err = core2.HandleRequest(req)
if err != ErrStandby {
if err != consts.ErrStandby {
t.Fatalf("err: %v", err)
}

View File

@@ -3,6 +3,7 @@ package vault
import (
"time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical"
)
@@ -79,8 +80,8 @@ func (d dynamicSystemView) CachingDisabled() bool {
}
// Checks if this is a primary Vault instance.
func (d dynamicSystemView) ReplicationState() logical.ReplicationState {
var state logical.ReplicationState
func (d dynamicSystemView) ReplicationState() consts.ReplicationState {
var state consts.ReplicationState
d.core.clusterParamsLock.RLock()
state = d.core.replicationState
d.core.clusterParamsLock.RUnlock()

View File

@@ -12,6 +12,7 @@ import (
log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
)
@@ -125,46 +126,114 @@ func (m *ExpirationManager) Restore() error {
if err != nil {
return fmt.Errorf("failed to scan for leases: %v", err)
}
m.logger.Debug("expiration: leases collected", "num_existing", len(existing))
// Restore each key
for i, leaseID := range existing {
if i%500 == 0 {
m.logger.Trace("expiration: leases loading", "progress", i)
}
// Load the entry
le, err := m.loadEntry(leaseID)
if err != nil {
return err
}
// Make the channels used for the worker pool
broker := make(chan string)
quit := make(chan bool)
// Buffer these channels to prevent deadlocks
errs := make(chan error, len(existing))
result := make(chan *leaseEntry, len(existing))
// If there is no entry, nothing to restore
if le == nil {
continue
}
// Use a wait group
wg := &sync.WaitGroup{}
// If there is no expiry time, don't do anything
if le.ExpireTime.IsZero() {
continue
}
// Create 64 workers to distribute work to
for i := 0; i < consts.ExpirationRestoreWorkerCount; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Determine the remaining time to expiration
expires := le.ExpireTime.Sub(time.Now())
if expires <= 0 {
expires = minRevokeDelay
}
for {
select {
case leaseID, ok := <-broker:
// broker has been closed, we are done
if !ok {
return
}
// Setup revocation timer
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
m.expireID(le.LeaseID)
})
le, err := m.loadEntry(leaseID)
if err != nil {
errs <- err
continue
}
// Write results out to the result channel
result <- le
// quit early
case <-quit:
return
}
}
}()
}
// Distribute the collected keys to the workers in a goroutine
wg.Add(1)
go func() {
defer wg.Done()
for i, leaseID := range existing {
if i%500 == 0 {
m.logger.Trace("expiration: leases loading", "progress", i)
}
select {
case <-quit:
return
default:
broker <- leaseID
}
}
// Close the broker, causing worker routines to exit
close(broker)
}()
// Restore each key by pulling from the result chan
for i := 0; i < len(existing); i++ {
select {
case err := <-errs:
// Close all goroutines
close(quit)
return err
case le := <-result:
// If there is no entry, nothing to restore
if le == nil {
continue
}
// If there is no expiry time, don't do anything
if le.ExpireTime.IsZero() {
continue
}
// Determine the remaining time to expiration
expires := le.ExpireTime.Sub(time.Now())
if expires <= 0 {
expires = minRevokeDelay
}
// Setup revocation timer
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
m.expireID(le.LeaseID)
})
}
}
// Let all goroutines finish
wg.Wait()
if len(m.pending) > 0 {
if m.logger.IsInfo() {
m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending))
}
}
return nil
}

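The rewritten Restore fans lease loading out to a fixed pool of workers fed by a broker channel, with buffered error/result channels and a quit channel for early exit. A stripped-down, generic sketch of the same channel topology, with loadEntry replaced by a caller-supplied function; this is not the Vault code:

package main

import (
    "fmt"
    "sync"
)

// restoreAll feeds IDs to a fixed number of workers over a broker channel.
// Results and errors come back on channels buffered to len(ids) so workers
// never block on send, and quit lets everything stop early on the first error.
func restoreAll(ids []string, workers int, load func(string) (int, error)) ([]int, error) {
    broker := make(chan string)
    quit := make(chan struct{})
    errs := make(chan error, len(ids))
    results := make(chan int, len(ids))

    var wg sync.WaitGroup
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for {
                select {
                case id, ok := <-broker:
                    if !ok {
                        return // broker closed, all work handed out
                    }
                    v, err := load(id)
                    if err != nil {
                        errs <- err
                        continue
                    }
                    results <- v
                case <-quit:
                    return
                }
            }
        }()
    }

    // Distributor: push every ID to the broker, then close it.
    wg.Add(1)
    go func() {
        defer wg.Done()
        for _, id := range ids {
            select {
            case <-quit:
                return
            default:
                broker <- id
            }
        }
        close(broker)
    }()

    out := make([]int, 0, len(ids))
    for i := 0; i < len(ids); i++ {
        select {
        case err := <-errs:
            close(quit) // stop workers and distributor, bail out early
            return nil, err
        case v := <-results:
            out = append(out, v)
        }
    }
    wg.Wait()
    return out, nil
}

func main() {
    vals, err := restoreAll([]string{"a", "bb", "ccc"}, 2, func(s string) (int, error) {
        return len(s), nil
    })
    fmt.Println(vals, err) // three results (order not guaranteed), nil error
}
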
View File

@@ -2,23 +2,131 @@ package vault
import (
"fmt"
"os"
"reflect"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/physical"
log "github.com/mgutz/logxi/v1"
)
var (
testImagePull sync.Once
)
// mockExpiration returns a mock expiration manager
func mockExpiration(t *testing.T) *ExpirationManager {
func mockExpiration(t testing.TB) *ExpirationManager {
_, ts, _, _ := TestCoreWithTokenStore(t)
return ts.expiration
}
func mockBackendExpiration(t testing.TB, backend physical.Backend) (*Core, *ExpirationManager) {
c, ts, _, _ := TestCoreWithBackendTokenStore(t, backend)
return c, ts.expiration
}
func BenchmarkExpiration_Restore_Etcd(b *testing.B) {
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
logger := logformat.NewVaultLogger(log.LevelTrace)
physicalBackend, err := physical.NewBackend("etcd", logger, map[string]string{
"address": addr,
"path": randPath,
"max_parallel": "256",
})
if err != nil {
b.Fatalf("err: %s", err)
}
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
}
func BenchmarkExpiration_Restore_Consul(b *testing.B) {
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
logger := logformat.NewVaultLogger(log.LevelTrace)
physicalBackend, err := physical.NewBackend("consul", logger, map[string]string{
"address": addr,
"path": randPath,
"max_parallel": "256",
})
if err != nil {
b.Fatalf("err: %s", err)
}
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
}
func BenchmarkExpiration_Restore_InMem(b *testing.B) {
logger := logformat.NewVaultLogger(log.LevelTrace)
benchmarkExpirationBackend(b, physical.NewInmem(logger), 100000) // 100,000 Leases
}
func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, numLeases int) {
c, exp := mockBackendExpiration(b, physicalBackend)
noop := &NoopBackend{}
view := NewBarrierView(c.barrier, "logical/")
meUUID, err := uuid.GenerateUUID()
if err != nil {
b.Fatal(err)
}
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: meUUID}, view)
// Register fake leases
for i := 0; i < numLeases; i++ {
pathUUID, err := uuid.GenerateUUID()
if err != nil {
b.Fatal(err)
}
req := &logical.Request{
Operation: logical.ReadOperation,
Path: "prod/aws/" + pathUUID,
}
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
TTL: 400 * time.Second,
},
},
Data: map[string]interface{}{
"access_key": "xyz",
"secret_key": "abcd",
},
}
_, err = exp.Register(req, resp)
if err != nil {
b.Fatalf("err: %v", err)
}
}
// Stop everything
err = exp.Stop()
if err != nil {
b.Fatalf("err: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = exp.Restore()
// Restore
if err != nil {
b.Fatalf("err: %v", err)
}
}
b.StopTimer()
}
func TestExpiration_Restore(t *testing.T) {
exp := mockExpiration(t)
noop := &NoopBackend{}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/helper/xor"
"github.com/hashicorp/vault/shamir"
@@ -34,10 +35,10 @@ func (c *Core) GenerateRootProgress() (int, error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return 0, ErrSealed
return 0, consts.ErrSealed
}
if c.standby {
return 0, ErrStandby
return 0, consts.ErrStandby
}
c.generateRootLock.Lock()
@@ -52,10 +53,10 @@ func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return nil, ErrSealed
return nil, consts.ErrSealed
}
if c.standby {
return nil, ErrStandby
return nil, consts.ErrStandby
}
c.generateRootLock.Lock()
@@ -101,10 +102,10 @@ func (c *Core) GenerateRootInit(otp, pgpKey string) error {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return ErrSealed
return consts.ErrSealed
}
if c.standby {
return ErrStandby
return consts.ErrStandby
}
c.generateRootLock.Lock()
@@ -170,10 +171,10 @@ func (c *Core) GenerateRootUpdate(key []byte, nonce string) (*GenerateRootResult
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return nil, ErrSealed
return nil, consts.ErrSealed
}
if c.standby {
return nil, ErrStandby
return nil, consts.ErrStandby
}
c.generateRootLock.Lock()
@@ -308,10 +309,10 @@ func (c *Core) GenerateRootCancel() error {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return ErrSealed
return consts.ErrSealed
}
if c.standby {
return ErrStandby
return consts.ErrStandby
}
c.generateRootLock.Lock()

View File

@@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
return nil, fmt.Errorf("error initializing seal: %v", err)
}
err = c.seal.SetBarrierConfig(barrierConfig)
if err != nil {
c.logger.Error("core: failed to save barrier configuration", "error", err)
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
}
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
if err != nil {
c.logger.Error("core: error generating shares", "error", err)
return nil, err
}
// If we are storing shares, pop them out of the returned results and push
// them through the seal
if barrierConfig.StoredShares > 0 {
var keysToStore [][]byte
for i := 0; i < barrierConfig.StoredShares; i++ {
keysToStore = append(keysToStore, barrierUnsealKeys[0])
barrierUnsealKeys = barrierUnsealKeys[1:]
}
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
c.logger.Error("core: failed to store keys", "error", err)
return nil, fmt.Errorf("failed to store keys: %v", err)
}
}
results := &InitResult{
SecretShares: barrierUnsealKeys,
}
// Initialize the barrier
if err := c.barrier.Initialize(barrierKey); err != nil {
c.logger.Error("core: failed to initialize barrier", "error", err)
@@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
// Ensure the barrier is re-sealed
defer func() {
// Defers are LIFO so we need to run this here too to ensure the stop
// happens before sealing. preSeal also stops, so we just make the
// stopping safe against multiple calls.
if err := c.barrier.Seal(); err != nil {
c.logger.Error("core: failed to seal barrier", "error", err)
}
}()
err = c.seal.SetBarrierConfig(barrierConfig)
if err != nil {
c.logger.Error("core: failed to save barrier configuration", "error", err)
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
}
// If we are storing shares, pop them out of the returned results and push
// them through the seal
if barrierConfig.StoredShares > 0 {
var keysToStore [][]byte
for i := 0; i < barrierConfig.StoredShares; i++ {
keysToStore = append(keysToStore, barrierUnsealKeys[0])
barrierUnsealKeys = barrierUnsealKeys[1:]
}
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
c.logger.Error("core: failed to store keys", "error", err)
return nil, fmt.Errorf("failed to store keys: %v", err)
}
}
results := &InitResult{
SecretShares: barrierUnsealKeys,
}
// Perform initial setup
if err := c.setupCluster(); err != nil {
c.logger.Error("core: cluster setup failed during init", "error", err)

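The "Defers are LIFO" comment in the hunk above relies on Go's defer ordering: the seal deferred later runs before the stop deferred earlier. A tiny sketch of that ordering, unrelated to Vault's types:

package main

import "fmt"

func main() {
    defer fmt.Println("runs last (registered first)")
    defer fmt.Println("runs first (registered last)")
    // Output:
    // runs first (registered last)
    // runs last (registered first)
}
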
Some files were not shown because too many files have changed in this diff Show More