Merge branch 'master' into acl-parameters-permission

This commit is contained in:
Brian Kassouf
2017-02-21 14:46:06 -08:00
287 changed files with 6237 additions and 2356 deletions

View File

@@ -7,7 +7,7 @@ services:
- docker - docker
go: go:
- 1.8rc2 - 1.8
matrix: matrix:
allow_failures: allow_failures:

View File

@@ -1,5 +1,16 @@
## Next (Unreleased) ## Next (Unreleased)
DEPRECATIONS/CHANGES:
* List Operations Always Use Trailing Slash: Any list operation, whether via
the `GET` or `LIST` HTTP verb, will now internally canonicalize the path to
have a trailing slash. This makes policy writing more predictable, as it
means clients will no longer work or fail based on which client they're
using or which HTTP verb they're using. However, it also means that policies
allowing `list` capability must be carefully checked to ensure that they
contain a trailing slash; some policies may need to be split into multiple
stanzas to accommodate.
IMPROVEMENTS: IMPROVEMENTS:
* auth/ldap: Use the value of the `LOGNAME` or `USER` env vars for the * auth/ldap: Use the value of the `LOGNAME` or `USER` env vars for the
@@ -7,14 +18,20 @@ IMPROVEMENTS:
[GH-2154] [GH-2154]
* audit: Support adding a configurable prefix (such as `@cee`) before each * audit: Support adding a configurable prefix (such as `@cee`) before each
line [GH-2359] line [GH-2359]
* core: Canonicalize list operations to use a trailing slash [GH-2390]
* secret/pki: O (Organization) values can now be set to role-defined values
for issued/signed certificates [GH-2369]
BUG FIXES: BUG FIXES:
* audit: When auditing headers use case-insensitive comparisons [GH-2362]
* auth/aws-ec2: Return role period in seconds and not nanoseconds [GH-2374] * auth/aws-ec2: Return role period in seconds and not nanoseconds [GH-2374]
* auth/okta: Fix panic if user had no local groups and/or policies set * auth/okta: Fix panic if user had no local groups and/or policies set
[GH-2367] [GH-2367]
* command/server: Fix parsing of redirect address when port is not mentioned * command/server: Fix parsing of redirect address when port is not mentioned
[GH-2354] [GH-2354]
* physical/postgresql: Fix listing returning incorrect results if there were
multiple levels of children [GH-2393]
## 0.6.5 (February 7th, 2017) ## 0.6.5 (February 7th, 2017)

View File

@@ -24,6 +24,11 @@ dev-dynamic: generate
test: generate test: generate
CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=10m -parallel=4 CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=10m -parallel=4
testcompile: generate
@for pkg in $(TEST) ; do \
go test -v -c -tags='$(BUILD_TAGS)' $$pkg -parallel=4 ; \
done
# testacc runs acceptance tests # testacc runs acceptance tests
testacc: generate testacc: generate
@if [ "$(TEST)" = "./..." ]; then \ @if [ "$(TEST)" = "./..." ]; then \

View File

@@ -56,9 +56,9 @@ All documentation is available on the [Vault website](https://www.vaultproject.i
Developing Vault Developing Vault
-------------------- --------------------
If you wish to work on Vault itself or any of its built-in systems, If you wish to work on Vault itself or any of its built-in systems, you'll
you'll first need [Go](https://www.golang.org) installed on your first need [Go](https://www.golang.org) installed on your machine (version 1.8+
machine (version 1.8+ is *required*). is *required*).
For local dev first make sure Go is properly installed, including setting up a For local dev first make sure Go is properly installed, including setting up a
[GOPATH](https://golang.org/doc/code.html#GOPATH). Next, clone this repository [GOPATH](https://golang.org/doc/code.html#GOPATH). Next, clone this repository

View File

@@ -3,6 +3,7 @@ package api
import ( import (
"fmt" "fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
@@ -71,13 +72,18 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) {
return mounts, nil return mounts, nil
} }
// DEPRECATED: Use EnableAuditWithOptions instead
func (c *Sys) EnableAudit( func (c *Sys) EnableAudit(
path string, auditType string, desc string, opts map[string]string) error { path string, auditType string, desc string, opts map[string]string) error {
body := map[string]interface{}{ return c.EnableAuditWithOptions(path, &EnableAuditOptions{
"type": auditType, Type: auditType,
"description": desc, Description: desc,
"options": opts, Options: opts,
} })
}
func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) error {
body := structs.Map(options)
r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/sys/audit/%s", path)) r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/sys/audit/%s", path))
if err := r.SetJSONBody(body); err != nil { if err := r.SetJSONBody(body); err != nil {
@@ -106,9 +112,17 @@ func (c *Sys) DisableAudit(path string) error {
// individually documented because the map almost directly to the raw HTTP API // individually documented because the map almost directly to the raw HTTP API
// documentation. Please refer to that documentation for more details. // documentation. Please refer to that documentation for more details.
type EnableAuditOptions struct {
Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"`
Options map[string]string `json:"options" structs:"options"`
Local bool `json:"local" structs:"local"`
}
type Audit struct { type Audit struct {
Path string Path string
Type string Type string
Description string Description string
Options map[string]string Options map[string]string
Local bool
} }

View File

@@ -3,6 +3,7 @@ package api
import ( import (
"fmt" "fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
@@ -42,11 +43,16 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
return mounts, nil return mounts, nil
} }
// DEPRECATED: Use EnableAuthWithOptions instead
func (c *Sys) EnableAuth(path, authType, desc string) error { func (c *Sys) EnableAuth(path, authType, desc string) error {
body := map[string]string{ return c.EnableAuthWithOptions(path, &EnableAuthOptions{
"type": authType, Type: authType,
"description": desc, Description: desc,
} })
}
func (c *Sys) EnableAuthWithOptions(path string, options *EnableAuthOptions) error {
body := structs.Map(options)
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/auth/%s", path)) r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/auth/%s", path))
if err := r.SetJSONBody(body); err != nil { if err := r.SetJSONBody(body); err != nil {
@@ -75,10 +81,17 @@ func (c *Sys) DisableAuth(path string) error {
// individually documentd because the map almost directly to the raw HTTP API // individually documentd because the map almost directly to the raw HTTP API
// documentation. Please refer to that documentation for more details. // documentation. Please refer to that documentation for more details.
type EnableAuthOptions struct {
Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"`
Local bool `json:"local" structs:"local"`
}
type AuthMount struct { type AuthMount struct {
Type string `json:"type" structs:"type" mapstructure:"type"` Type string `json:"type" structs:"type" mapstructure:"type"`
Description string `json:"description" structs:"description" mapstructure:"description"` Description string `json:"description" structs:"description" mapstructure:"description"`
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"` Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
Local bool `json:"local" structs:"local" mapstructure:"local"`
} }
type AuthConfigOutput struct { type AuthConfigOutput struct {

View File

@@ -123,6 +123,7 @@ type MountInput struct {
Type string `json:"type" structs:"type"` Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"` Description string `json:"description" structs:"description"`
Config MountConfigInput `json:"config" structs:"config"` Config MountConfigInput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
} }
type MountConfigInput struct { type MountConfigInput struct {
@@ -134,6 +135,7 @@ type MountOutput struct {
Type string `json:"type" structs:"type"` Type string `json:"type" structs:"type"`
Description string `json:"description" structs:"description"` Description string `json:"description" structs:"description"`
Config MountConfigOutput `json:"config" structs:"config"` Config MountConfigOutput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
} }
type MountConfigOutput struct { type MountConfigOutput struct {

View File

@@ -27,7 +27,11 @@ func (f *AuditFormatter) FormatRequest(
config FormatterConfig, config FormatterConfig,
auth *logical.Auth, auth *logical.Auth,
req *logical.Request, req *logical.Request,
err error) error { inErr error) error {
if req == nil {
return fmt.Errorf("request to request-audit a nil request")
}
if w == nil { if w == nil {
return fmt.Errorf("writer for audit request is nil") return fmt.Errorf("writer for audit request is nil")
@@ -49,22 +53,26 @@ func (f *AuditFormatter) FormatRequest(
}() }()
} }
// Copy the structures // Copy the auth structure
cp, err := copystructure.Copy(auth) if auth != nil {
if err != nil { cp, err := copystructure.Copy(auth)
return err if err != nil {
return err
}
auth = cp.(*logical.Auth)
} }
auth = cp.(*logical.Auth)
cp, err = copystructure.Copy(req) cp, err := copystructure.Copy(req)
if err != nil { if err != nil {
return err return err
} }
req = cp.(*logical.Request) req = cp.(*logical.Request)
// Hash any sensitive information // Hash any sensitive information
if err := Hash(config.Salt, auth); err != nil { if auth != nil {
return err if err := Hash(config.Salt, auth); err != nil {
return err
}
} }
// Cache and restore accessor in the request // Cache and restore accessor in the request
@@ -85,8 +93,8 @@ func (f *AuditFormatter) FormatRequest(
auth = new(logical.Auth) auth = new(logical.Auth)
} }
var errString string var errString string
if err != nil { if inErr != nil {
errString = err.Error() errString = inErr.Error()
} }
reqEntry := &AuditRequestEntry{ reqEntry := &AuditRequestEntry{
@@ -107,6 +115,7 @@ func (f *AuditFormatter) FormatRequest(
Path: req.Path, Path: req.Path,
Data: req.Data, Data: req.Data,
RemoteAddr: getRemoteAddr(req), RemoteAddr: getRemoteAddr(req),
ReplicationCluster: req.ReplicationCluster,
Headers: req.Headers, Headers: req.Headers,
}, },
} }
@@ -128,7 +137,11 @@ func (f *AuditFormatter) FormatResponse(
auth *logical.Auth, auth *logical.Auth,
req *logical.Request, req *logical.Request,
resp *logical.Response, resp *logical.Response,
err error) error { inErr error) error {
if req == nil {
return fmt.Errorf("request to response-audit a nil request")
}
if w == nil { if w == nil {
return fmt.Errorf("writer for audit request is nil") return fmt.Errorf("writer for audit request is nil")
@@ -150,37 +163,43 @@ func (f *AuditFormatter) FormatResponse(
}() }()
} }
// Copy the structure // Copy the auth structure
cp, err := copystructure.Copy(auth) if auth != nil {
if err != nil { cp, err := copystructure.Copy(auth)
return err if err != nil {
return err
}
auth = cp.(*logical.Auth)
} }
auth = cp.(*logical.Auth)
cp, err = copystructure.Copy(req) cp, err := copystructure.Copy(req)
if err != nil { if err != nil {
return err return err
} }
req = cp.(*logical.Request) req = cp.(*logical.Request)
cp, err = copystructure.Copy(resp) if resp != nil {
if err != nil { cp, err := copystructure.Copy(resp)
return err if err != nil {
return err
}
resp = cp.(*logical.Response)
} }
resp = cp.(*logical.Response)
// Hash any sensitive information // Hash any sensitive information
// Cache and restore accessor in the auth // Cache and restore accessor in the auth
var accessor, wrappedAccessor string if auth != nil {
if !config.HMACAccessor && auth != nil && auth.Accessor != "" { var accessor string
accessor = auth.Accessor if !config.HMACAccessor && auth.Accessor != "" {
} accessor = auth.Accessor
if err := Hash(config.Salt, auth); err != nil { }
return err if err := Hash(config.Salt, auth); err != nil {
} return err
if accessor != "" { }
auth.Accessor = accessor if accessor != "" {
auth.Accessor = accessor
}
} }
// Cache and restore accessor in the request // Cache and restore accessor in the request
@@ -196,21 +215,23 @@ func (f *AuditFormatter) FormatResponse(
} }
// Cache and restore accessor in the response // Cache and restore accessor in the response
accessor = "" if resp != nil {
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" { var accessor, wrappedAccessor string
accessor = resp.Auth.Accessor if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
} accessor = resp.Auth.Accessor
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { }
wrappedAccessor = resp.WrapInfo.WrappedAccessor if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
} wrappedAccessor = resp.WrapInfo.WrappedAccessor
if err := Hash(config.Salt, resp); err != nil { }
return err if err := Hash(config.Salt, resp); err != nil {
} return err
if accessor != "" { }
resp.Auth.Accessor = accessor if accessor != "" {
} resp.Auth.Accessor = accessor
if wrappedAccessor != "" { }
resp.WrapInfo.WrappedAccessor = wrappedAccessor if wrappedAccessor != "" {
resp.WrapInfo.WrappedAccessor = wrappedAccessor
}
} }
} }
@@ -222,8 +243,8 @@ func (f *AuditFormatter) FormatResponse(
resp = new(logical.Response) resp = new(logical.Response)
} }
var errString string var errString string
if err != nil { if inErr != nil {
errString = err.Error() errString = inErr.Error()
} }
var respAuth *AuditAuth var respAuth *AuditAuth
@@ -276,6 +297,7 @@ func (f *AuditFormatter) FormatResponse(
Path: req.Path, Path: req.Path,
Data: req.Data, Data: req.Data,
RemoteAddr: getRemoteAddr(req), RemoteAddr: getRemoteAddr(req),
ReplicationCluster: req.ReplicationCluster,
Headers: req.Headers, Headers: req.Headers,
}, },
@@ -312,14 +334,15 @@ type AuditRequestEntry struct {
type AuditResponseEntry struct { type AuditResponseEntry struct {
Time string `json:"time,omitempty"` Time string `json:"time,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Error string `json:"error"`
Auth AuditAuth `json:"auth"` Auth AuditAuth `json:"auth"`
Request AuditRequest `json:"request"` Request AuditRequest `json:"request"`
Response AuditResponse `json:"response"` Response AuditResponse `json:"response"`
Error string `json:"error"`
} }
type AuditRequest struct { type AuditRequest struct {
ID string `json:"id"` ID string `json:"id"`
ReplicationCluster string `json:"replication_cluster,omitempty"`
Operation logical.Operation `json:"operation"` Operation logical.Operation `json:"operation"`
ClientToken string `json:"client_token"` ClientToken string `json:"client_token"`
ClientTokenAccessor string `json:"client_token_accessor"` ClientTokenAccessor string `json:"client_token_accessor"`

55
audit/format_test.go Normal file
View File

@@ -0,0 +1,55 @@
package audit
import (
"io"
"io/ioutil"
"testing"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
)
type noopFormatWriter struct {
}
func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error {
return nil
}
func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) error {
return nil
}
func TestFormatRequestErrors(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil)
config := FormatterConfig{
Salt: salter,
}
formatter := AuditFormatter{
AuditFormatWriter: &noopFormatWriter{},
}
if err := formatter.FormatRequest(ioutil.Discard, config, nil, nil, nil); err == nil {
t.Fatal("expected error due to nil request")
}
if err := formatter.FormatRequest(nil, config, nil, &logical.Request{}, nil); err == nil {
t.Fatal("expected error due to nil writer")
}
}
func TestFormatResponseErrors(t *testing.T) {
salter, _ := salt.NewSalt(nil, nil)
config := FormatterConfig{
Salt: salter,
}
formatter := AuditFormatter{
AuditFormatWriter: &noopFormatWriter{},
}
if err := formatter.FormatResponse(ioutil.Discard, config, nil, nil, nil, nil); err == nil {
t.Fatal("expected error due to nil request")
}
if err := formatter.FormatResponse(nil, config, nil, &logical.Request{}, nil, nil); err == nil {
t.Fatal("expected error due to nil writer")
}
}

View File

@@ -157,10 +157,15 @@ func (b *Backend) open() error {
return err return err
} }
// Change the file mode in case the log file already existed // Change the file mode in case the log file already existed. We special
err = os.Chmod(b.path, b.mode) // case /dev/null since we can't chmod it
if err != nil { switch b.path {
return err case "/dev/null":
default:
err = os.Chmod(b.path, b.mode)
if err != nil {
return err
}
} }
return nil return nil

View File

@@ -17,20 +17,10 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
} }
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) { func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
// Initialize the salt
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA1Hash,
})
if err != nil {
return nil, err
}
var b backend var b backend
b.Salt = salt
b.MapAppId = &framework.PolicyMap{ b.MapAppId = &framework.PolicyMap{
PathMap: framework.PathMap{ PathMap: framework.PathMap{
Name: "app-id", Name: "app-id",
Salt: salt,
Schema: map[string]*framework.FieldSchema{ Schema: map[string]*framework.FieldSchema{
"display_name": &framework.FieldSchema{ "display_name": &framework.FieldSchema{
Type: framework.TypeString, Type: framework.TypeString,
@@ -48,7 +38,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
b.MapUserId = &framework.PathMap{ b.MapUserId = &framework.PathMap{
Name: "user-id", Name: "user-id",
Salt: salt,
Schema: map[string]*framework.FieldSchema{ Schema: map[string]*framework.FieldSchema{
"cidr_block": &framework.FieldSchema{ "cidr_block": &framework.FieldSchema{
Type: framework.TypeString, Type: framework.TypeString,
@@ -81,17 +70,11 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
), ),
AuthRenew: b.pathLoginRenew, AuthRenew: b.pathLoginRenew,
Init: b.initialize,
} }
// Since the salt is new in 0.2, we need to handle this by migrating b.view = conf.StorageView
// any existing keys to use the salt. We can deprecate this eventually,
// but for now we want a smooth upgrade experience by automatically
// upgrading to use salting.
if salt.DidGenerate() {
if err := b.upgradeToSalted(conf.StorageView); err != nil {
return nil, err
}
}
return b.Backend, nil return b.Backend, nil
} }
@@ -100,10 +83,36 @@ type backend struct {
*framework.Backend *framework.Backend
Salt *salt.Salt Salt *salt.Salt
view logical.Storage
MapAppId *framework.PolicyMap MapAppId *framework.PolicyMap
MapUserId *framework.PathMap MapUserId *framework.PathMap
} }
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA1Hash,
})
if err != nil {
return err
}
b.Salt = salt
b.MapAppId.Salt = salt
b.MapUserId.Salt = salt
// Since the salt is new in 0.2, we need to handle this by migrating
// any existing keys to use the salt. We can deprecate this eventually,
// but for now we want a smooth upgrade experience by automatically
// upgrading to use salting.
if salt.DidGenerate() {
if err := b.upgradeToSalted(b.view); err != nil {
return err
}
}
return nil
}
// upgradeToSalted is used to upgrade the non-salted keys prior to // upgradeToSalted is used to upgrade the non-salted keys prior to
// Vault 0.2 to be salted. This is done on mount time and is only // Vault 0.2 to be salted. This is done on mount time and is only
// done once. It can be deprecated eventually, but should be around // done once. It can be deprecated eventually, but should be around

View File

@@ -72,6 +72,10 @@ func TestBackend_upgradeToSalted(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
err = backend.Initialize()
if err != nil {
t.Fatalf("err: %v", err)
}
// Check the keys have been upgraded // Check the keys have been upgraded
out, err := inm.Get("struct/map/app-id/foo") out, err := inm.Get("struct/map/app-id/foo")

View File

@@ -17,6 +17,9 @@ type backend struct {
// by this backend. // by this backend.
salt *salt.Salt salt *salt.Salt
// The view to use when creating the salt
view logical.Storage
// Guard to clean-up the expired SecretID entries // Guard to clean-up the expired SecretID entries
tidySecretIDCASGuard uint32 tidySecretIDCASGuard uint32
@@ -57,18 +60,9 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
} }
func Backend(conf *logical.BackendConfig) (*backend, error) { func Backend(conf *logical.BackendConfig) (*backend, error) {
// Initialize the salt
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
// Create a backend object // Create a backend object
b := &backend{ b := &backend{
// Set the salt object for the backend view: conf.StorageView,
salt: salt,
// Create the map of locks to modify the registered roles // Create the map of locks to modify the registered roles
roleLocksMap: make(map[string]*sync.RWMutex, 257), roleLocksMap: make(map[string]*sync.RWMutex, 257),
@@ -83,6 +77,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
secretIDAccessorLocksMap: make(map[string]*sync.RWMutex, 257), secretIDAccessorLocksMap: make(map[string]*sync.RWMutex, 257),
} }
var err error
// Create 256 locks each for managing RoleID and SecretIDs. This will avoid // Create 256 locks each for managing RoleID and SecretIDs. This will avoid
// a superfluous number of locks directly proportional to the number of RoleID // a superfluous number of locks directly proportional to the number of RoleID
// and SecretIDs. These locks can be accessed by indexing based on the first two // and SecretIDs. These locks can be accessed by indexing based on the first two
@@ -129,10 +125,22 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathTidySecretID(b), pathTidySecretID(b),
}, },
), ),
Init: b.initialize,
} }
return b, nil return b, nil
} }
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.salt = salt
return nil
}
// periodicFunc of the backend will be invoked once a minute by the RollbackManager. // periodicFunc of the backend will be invoked once a minute by the RollbackManager.
// RoleRole backend utilizes this function to delete expired SecretID entries. // RoleRole backend utilizes this function to delete expired SecretID entries.
// This could mean that the SecretID may live in the backend upto 1 min after its // This could mean that the SecretID may live in the backend upto 1 min after its

View File

@@ -21,5 +21,9 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = b.Initialize()
if err != nil {
t.Fatal(err)
}
return b, config.StorageView return b, config.StorageView
} }

View File

@@ -143,7 +143,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
return nil, "", metadata, fmt.Errorf("failed to verify the CIDR restrictions set on the role: %v", err) return nil, "", metadata, fmt.Errorf("failed to verify the CIDR restrictions set on the role: %v", err)
} }
if !belongs { if !belongs {
return nil, "", metadata, fmt.Errorf("source address unauthorized through CIDR restrictions on the role") return nil, "", metadata, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the role", req.Connection.RemoteAddr)
} }
} }
@@ -199,7 +199,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
} }
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil { if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err) return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
} }
} }
@@ -261,7 +261,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
} }
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil { if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err) return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
} }
} }

View File

@@ -23,6 +23,9 @@ type backend struct {
*framework.Backend *framework.Backend
Salt *salt.Salt Salt *salt.Salt
// Used during initialization to set the salt
view logical.Storage
// Lock to make changes to any of the backend's configuration endpoints. // Lock to make changes to any of the backend's configuration endpoints.
configMutex sync.RWMutex configMutex sync.RWMutex
@@ -59,18 +62,11 @@ type backend struct {
} }
func Backend(conf *logical.BackendConfig) (*backend, error) { func Backend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
b := &backend{ b := &backend{
// Setting the periodic func to be run once in an hour. // Setting the periodic func to be run once in an hour.
// If there is a real need, this can be made configurable. // If there is a real need, this can be made configurable.
tidyCooldownPeriod: time.Hour, tidyCooldownPeriod: time.Hour,
Salt: salt, view: conf.StorageView,
EC2ClientsMap: make(map[string]map[string]*ec2.EC2), EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
IAMClientsMap: make(map[string]map[string]*iam.IAM), IAMClientsMap: make(map[string]map[string]*iam.IAM),
} }
@@ -83,6 +79,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
Unauthenticated: []string{ Unauthenticated: []string{
"login", "login",
}, },
LocalStorage: []string{
"whitelist/identity/",
},
}, },
Paths: []*framework.Path{ Paths: []*framework.Path{
pathLogin(b), pathLogin(b),
@@ -104,11 +103,26 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathIdentityWhitelist(b), pathIdentityWhitelist(b),
pathTidyIdentityWhitelist(b), pathTidyIdentityWhitelist(b),
}, },
Invalidate: b.invalidate,
Init: b.initialize,
} }
return b, nil return b, nil
} }
func (b *backend) initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.Salt = salt
return nil
}
// periodicFunc performs the tasks that the backend wishes to do periodically. // periodicFunc performs the tasks that the backend wishes to do periodically.
// Currently this will be triggered once in a minute by the RollbackManager. // Currently this will be triggered once in a minute by the RollbackManager.
// //
@@ -169,6 +183,16 @@ func (b *backend) periodicFunc(req *logical.Request) error {
return nil return nil
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/client":
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
}
}
const backendHelp = ` const backendHelp = `
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
created nonce to authenticates the EC2 instance with Vault. created nonce to authenticates the EC2 instance with Vault.

View File

@@ -1,6 +1,7 @@
package cert package cert
import ( import (
"strings"
"sync" "sync"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
@@ -13,7 +14,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
if err != nil { if err != nil {
return b, err return b, err
} }
return b, b.populateCRLs(conf.StorageView) return b, nil
} }
func Backend() *backend { func Backend() *backend {
@@ -36,9 +37,10 @@ func Backend() *backend {
}), }),
AuthRenew: b.pathLoginRenew, AuthRenew: b.pathLoginRenew,
Invalidate: b.invalidate,
} }
b.crls = map[string]CRLInfo{}
b.crlUpdateMutex = &sync.RWMutex{} b.crlUpdateMutex = &sync.RWMutex{}
return &b return &b
@@ -52,6 +54,15 @@ type backend struct {
crlUpdateMutex *sync.RWMutex crlUpdateMutex *sync.RWMutex
} }
func (b *backend) invalidate(key string) {
switch {
case strings.HasPrefix(key, "crls/"):
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
b.crls = nil
}
}
const backendHelp = ` const backendHelp = `
The "cert" credential provider allows authentication using The "cert" credential provider allows authentication using
TLS client certificates. A client connects to Vault and uses TLS client certificates. A client connects to Vault and uses

View File

@@ -45,6 +45,12 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
b.crlUpdateMutex.Lock() b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock() defer b.crlUpdateMutex.Unlock()
if b.crls != nil {
return nil
}
b.crls = map[string]CRLInfo{}
keys, err := storage.List("crls/") keys, err := storage.List("crls/")
if err != nil { if err != nil {
return fmt.Errorf("error listing CRLs: %v", err) return fmt.Errorf("error listing CRLs: %v", err)
@@ -56,6 +62,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
for _, key := range keys { for _, key := range keys {
entry, err := storage.Get("crls/" + key) entry, err := storage.Get("crls/" + key)
if err != nil { if err != nil {
b.crls = nil
return fmt.Errorf("error loading CRL %s: %v", key, err) return fmt.Errorf("error loading CRL %s: %v", key, err)
} }
if entry == nil { if entry == nil {
@@ -64,6 +71,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
var crlInfo CRLInfo var crlInfo CRLInfo
err = entry.DecodeJSON(&crlInfo) err = entry.DecodeJSON(&crlInfo)
if err != nil { if err != nil {
b.crls = nil
return fmt.Errorf("error decoding CRL %s: %v", key, err) return fmt.Errorf("error decoding CRL %s: %v", key, err)
} }
b.crls[key] = crlInfo b.crls[key] = crlInfo
@@ -121,6 +129,10 @@ func (b *backend) pathCRLDelete(
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
} }
if err := b.populateCRLs(req.Storage); err != nil {
return nil, err
}
b.crlUpdateMutex.Lock() b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock() defer b.crlUpdateMutex.Unlock()
@@ -131,8 +143,7 @@ func (b *backend) pathCRLDelete(
)), nil )), nil
} }
err := req.Storage.Delete("crls/" + name) if err := req.Storage.Delete("crls/" + name); err != nil {
if err != nil {
return logical.ErrorResponse(fmt.Sprintf( return logical.ErrorResponse(fmt.Sprintf(
"error deleting crl %s: %v", name, err), "error deleting crl %s: %v", name, err),
), nil ), nil
@@ -150,6 +161,10 @@ func (b *backend) pathCRLRead(
return logical.ErrorResponse(`"name" parameter must be set`), nil return logical.ErrorResponse(`"name" parameter must be set`), nil
} }
if err := b.populateCRLs(req.Storage); err != nil {
return nil, err
}
b.crlUpdateMutex.RLock() b.crlUpdateMutex.RLock()
defer b.crlUpdateMutex.RUnlock() defer b.crlUpdateMutex.RUnlock()
@@ -185,6 +200,10 @@ func (b *backend) pathCRLWrite(
return logical.ErrorResponse("parsed CRL is nil"), nil return logical.ErrorResponse("parsed CRL is nil"), nil
} }
if err := b.populateCRLs(req.Storage); err != nil {
return nil, err
}
b.crlUpdateMutex.Lock() b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock() defer b.crlUpdateMutex.Unlock()

View File

@@ -17,6 +17,12 @@ func Backend() *backend {
b.Backend = &framework.Backend{ b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp), Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
LocalStorage: []string{
framework.WALPrefix,
},
},
Paths: []*framework.Path{ Paths: []*framework.Path{
pathConfigRoot(), pathConfigRoot(),
pathConfigLease(&b), pathConfigLease(&b),

View File

@@ -31,6 +31,8 @@ func Backend() *backend {
secretCreds(&b), secretCreds(&b),
}, },
Invalidate: b.invalidate,
Clean: func() { Clean: func() {
b.ResetDB(nil) b.ResetDB(nil)
}, },
@@ -107,6 +109,13 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
b.session = newSession b.session = newSession
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB(nil)
}
}
const backendHelp = ` const backendHelp = `
The Cassandra backend dynamically generates database users. The Cassandra backend dynamically generates database users.

View File

@@ -421,7 +421,7 @@ seed_provider:
parameters: parameters:
# seeds is actually a comma-delimited list of addresses. # seeds is actually a comma-delimited list of addresses.
# Ex: "<ip1>,<ip2>,<ip3>" # Ex: "<ip1>,<ip2>,<ip3>"
- seeds: "172.17.0.2" - seeds: "172.17.0.3"
# For workloads with more data than can fit in memory, Cassandra's # For workloads with more data than can fit in memory, Cassandra's
# bottleneck will be reads that need to fetch data from # bottleneck will be reads that need to fetch data from
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
# #
# Setting listen_address to 0.0.0.0 is always wrong. # Setting listen_address to 0.0.0.0 is always wrong.
# #
listen_address: 172.17.0.2 listen_address: 172.17.0.3
# Set listen_address OR listen_interface, not both. Interfaces must correspond # Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported. # to a single address, IP aliasing is not supported.
@@ -586,7 +586,7 @@ listen_address: 172.17.0.2
# Address to broadcast to other Cassandra nodes # Address to broadcast to other Cassandra nodes
# Leaving this blank will set it to the same value as listen_address # Leaving this blank will set it to the same value as listen_address
broadcast_address: 172.17.0.2 broadcast_address: 172.17.0.3
# When using multiple physical network interfaces, set this # When using multiple physical network interfaces, set this
# to true to listen on broadcast_address in addition to # to true to listen on broadcast_address in addition to
@@ -668,7 +668,7 @@ rpc_port: 9160
# be set to 0.0.0.0. If left blank, this will be set to the value of # be set to 0.0.0.0. If left blank, this will be set to the value of
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must # rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
# be set. # be set.
broadcast_rpc_address: 172.17.0.2 broadcast_rpc_address: 172.17.0.3
# enable or disable keepalive on rpc/native connections # enable or disable keepalive on rpc/native connections
rpc_keepalive: true rpc_keepalive: true

View File

@@ -33,6 +33,8 @@ func Backend() *framework.Backend {
}, },
Clean: b.ResetSession, Clean: b.ResetSession,
Invalidate: b.invalidate,
} }
return b.Backend return b.Backend
@@ -97,6 +99,13 @@ func (b *backend) ResetSession() {
b.session = nil b.session = nil
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetSession()
}
}
// LeaseConfig returns the lease configuration // LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) { func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease") entry, err := s.Get("config/lease")

View File

@@ -32,6 +32,8 @@ func Backend() *backend {
secretCreds(&b), secretCreds(&b),
}, },
Invalidate: b.invalidate,
Clean: b.ResetDB, Clean: b.ResetDB,
} }
@@ -112,6 +114,13 @@ func (b *backend) ResetDB() {
b.db = nil b.db = nil
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB()
}
}
// LeaseConfig returns the lease configuration // LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) { func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease") entry, err := s.Get("config/lease")

View File

@@ -32,6 +32,8 @@ func Backend() *backend {
secretCreds(&b), secretCreds(&b),
}, },
Invalidate: b.invalidate,
Clean: b.ResetDB, Clean: b.ResetDB,
} }
@@ -105,6 +107,13 @@ func (b *backend) ResetDB() {
b.db = nil b.db = nil
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB()
}
}
// Lease returns the lease information // Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) { func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease") entry, err := s.Get("config/lease")

View File

@@ -29,6 +29,12 @@ func Backend() *backend {
"crl/pem", "crl/pem",
"crl", "crl",
}, },
LocalStorage: []string{
"revoked/",
"crl",
"certs/",
},
}, },
Paths: []*framework.Path{ Paths: []*framework.Path{

View File

@@ -1478,6 +1478,27 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
} }
} }
getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("Error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.ParseDedupAndSortStrings(role.Organization, ",")
if !reflect.DeepEqual(cert.Subject.Organization, expected) {
return fmt.Errorf("Error: returned certificate has Organization of %s but %s was specified in the role.", cert.Subject.Organization, expected)
}
return nil
}
}
// Returns a TestCheckFunc that performs various validity checks on the // Returns a TestCheckFunc that performs various validity checks on the
// returned certificate information, mostly within checkCertsAndPrivateKey // returned certificate information, mostly within checkCertsAndPrivateKey
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc { getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
@@ -1755,6 +1776,14 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals.OU = "foo,bar" roleVals.OU = "foo,bar"
addTests(getOuCheck(roleVals)) addTests(getOuCheck(roleVals))
} }
// Organization tests
{
roleVals.Organization = "system:masters"
addTests(getOrganizationCheck(roleVals))
roleVals.Organization = "foo,bar"
addTests(getOrganizationCheck(roleVals))
}
// IP SAN tests // IP SAN tests
{ {
issueVals.IPSANs = "127.0.0.1,::1" issueVals.IPSANs = "127.0.0.1,::1"

View File

@@ -35,6 +35,7 @@ const (
type creationBundle struct { type creationBundle struct {
CommonName string CommonName string
OU []string OU []string
Organization []string
DNSNames []string DNSNames []string
EmailAddresses []string EmailAddresses []string
IPAddresses []net.IP IPAddresses []net.IP
@@ -581,6 +582,14 @@ func generateCreationBundle(b *backend,
} }
} }
// Set O (organization) values if specified in the role
organization := []string{}
{
if role.Organization != "" {
organization = strutil.ParseDedupAndSortStrings(role.Organization, ",")
}
}
// Read in alternate names -- DNS and email addresses // Read in alternate names -- DNS and email addresses
dnsNames := []string{} dnsNames := []string{}
emailAddresses := []string{} emailAddresses := []string{}
@@ -728,6 +737,7 @@ func generateCreationBundle(b *backend,
creationBundle := &creationBundle{ creationBundle := &creationBundle{
CommonName: cn, CommonName: cn,
OU: ou, OU: ou,
Organization: organization,
DNSNames: dnsNames, DNSNames: dnsNames,
EmailAddresses: emailAddresses, EmailAddresses: emailAddresses,
IPAddresses: ipAddresses, IPAddresses: ipAddresses,
@@ -820,6 +830,7 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle
subject := pkix.Name{ subject := pkix.Name{
CommonName: creationInfo.CommonName, CommonName: creationInfo.CommonName,
OrganizationalUnit: creationInfo.OU, OrganizationalUnit: creationInfo.OU,
Organization: creationInfo.Organization,
} }
certTemplate := &x509.Certificate{ certTemplate := &x509.Certificate{
@@ -983,6 +994,7 @@ func signCertificate(creationInfo *creationBundle,
subject := pkix.Name{ subject := pkix.Name{
CommonName: creationInfo.CommonName, CommonName: creationInfo.CommonName,
OrganizationalUnit: creationInfo.OU, OrganizationalUnit: creationInfo.OU,
Organization: creationInfo.Organization,
} }
certTemplate := &x509.Certificate{ certTemplate := &x509.Certificate{

View File

@@ -172,6 +172,13 @@ Names. Defaults to true.`,
Type: framework.TypeString, Type: framework.TypeString,
Default: "", Default: "",
Description: `If set, the OU (OrganizationalUnit) will be set to Description: `If set, the OU (OrganizationalUnit) will be set to
this value in certificates issued by this role.`,
},
"organization": &framework.FieldSchema{
Type: framework.TypeString,
Default: "",
Description: `If set, the O (Organization) will be set to
this value in certificates issued by this role.`, this value in certificates issued by this role.`,
}, },
}, },
@@ -336,6 +343,7 @@ func (b *backend) pathRoleCreate(
UseCSRCommonName: data.Get("use_csr_common_name").(bool), UseCSRCommonName: data.Get("use_csr_common_name").(bool),
KeyUsage: data.Get("key_usage").(string), KeyUsage: data.Get("key_usage").(string),
OU: data.Get("ou").(string), OU: data.Get("ou").(string),
Organization: data.Get("organization").(string),
} }
if entry.KeyType == "rsa" && entry.KeyBits < 2048 { if entry.KeyType == "rsa" && entry.KeyBits < 2048 {
@@ -451,6 +459,7 @@ type roleEntry struct {
MaxPathLength *int `json:",omitempty" structs:",omitempty"` MaxPathLength *int `json:",omitempty" structs:",omitempty"`
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"` KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
OU string `json:"ou" structs:"ou" mapstructure:"ou"` OU string `json:"ou" structs:"ou" mapstructure:"ou"`
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
} }
const pathListRolesHelpSyn = `List the existing roles in this backend` const pathListRolesHelpSyn = `List the existing roles in this backend`

View File

@@ -34,6 +34,8 @@ func Backend(conf *logical.BackendConfig) *backend {
}, },
Clean: b.ResetDB, Clean: b.ResetDB,
Invalidate: b.invalidate,
} }
b.logger = conf.Logger b.logger = conf.Logger
@@ -126,6 +128,13 @@ func (b *backend) ResetDB() {
b.db = nil b.db = nil
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.ResetDB()
}
}
// Lease returns the lease information // Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) { func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease") entry, err := s.Get("config/lease")

View File

@@ -35,6 +35,8 @@ func Backend() *backend {
}, },
Clean: b.resetClient, Clean: b.resetClient,
Invalidate: b.invalidate,
} }
return &b return &b
@@ -99,6 +101,13 @@ func (b *backend) resetClient() {
b.client = nil b.client = nil
} }
func (b *backend) invalidate(key string) {
switch key {
case "config/connection":
b.resetClient()
}
}
// Lease returns the lease information // Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) { func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease") entry, err := s.Get("config/lease")

View File

@@ -10,6 +10,7 @@ import (
type backend struct { type backend struct {
*framework.Backend *framework.Backend
view logical.Storage
salt *salt.Salt salt *salt.Salt
} }
@@ -22,15 +23,8 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
} }
func Backend(conf *logical.BackendConfig) (*backend, error) { func Backend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
var b backend var b backend
b.salt = salt b.view = conf.StorageView
b.Backend = &framework.Backend{ b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp), Help: strings.TrimSpace(backendHelp),
@@ -38,6 +32,10 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
Unauthenticated: []string{ Unauthenticated: []string{
"verify", "verify",
}, },
LocalStorage: []string{
"otp/",
},
}, },
Paths: []*framework.Path{ Paths: []*framework.Path{
@@ -54,10 +52,23 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
secretDynamicKey(&b), secretDynamicKey(&b),
secretOTP(&b), secretOTP(&b),
}, },
Init: b.Initialize,
} }
return &b, nil return &b, nil
} }
func (b *backend) Initialize() error {
salt, err := salt.NewSalt(b.view, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return err
}
b.salt = salt
return nil
}
const backendHelp = ` const backendHelp = `
The SSH backend generates credentials allowing clients to establish SSH The SSH backend generates credentials allowing clients to establish SSH
connections to remote hosts. connections to remote hosts.

View File

@@ -73,6 +73,10 @@ func TestBackend_allowed_users(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = b.Initialize()
if err != nil {
t.Fatal(err)
}
roleData := map[string]interface{}{ roleData := map[string]interface{}{
"key_type": "otp", "key_type": "otp",

View File

@@ -1,6 +1,8 @@
package transit package transit
import ( import (
"strings"
"github.com/hashicorp/vault/helper/keysutil" "github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@@ -39,6 +41,8 @@ func Backend(conf *logical.BackendConfig) *backend {
}, },
Secrets: []*framework.Secret{}, Secrets: []*framework.Secret{},
Invalidate: b.invalidate,
} }
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled()) b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
@@ -50,3 +54,14 @@ type backend struct {
*framework.Backend *framework.Backend
lm *keysutil.LockManager lm *keysutil.LockManager
} }
func (b *backend) invalidate(key string) {
if b.Logger().IsTrace() {
b.Logger().Trace("transit: invalidating key", "key", key)
}
switch {
case strings.HasPrefix(key, "policy/"):
name := strings.TrimPrefix(key, "policy/")
b.lm.InvalidatePolicy(name)
}
}

View File

@@ -3,6 +3,7 @@ package command
import ( import (
"testing" "testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http" "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
@@ -44,3 +45,42 @@ func TestAuditDisable(t *testing.T) {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
} }
} }
func TestAuditDisableWithOptions(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &AuditDisableCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"noop",
}
// Run once to get the client
c.Run(args)
// Get the client
client, err := c.Client()
if err != nil {
t.Fatalf("err: %#v", err)
}
if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{
Type: "noop",
Description: "noop",
}); err != nil {
t.Fatalf("err: %#v", err)
}
// Run again
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
}

View File

@@ -6,6 +6,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/kv-builder" "github.com/hashicorp/vault/helper/kv-builder"
"github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/meta"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
@@ -21,9 +22,11 @@ type AuditEnableCommand struct {
func (c *AuditEnableCommand) Run(args []string) int { func (c *AuditEnableCommand) Run(args []string) int {
var desc, path string var desc, path string
var local bool
flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault) flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault)
flags.StringVar(&desc, "description", "", "") flags.StringVar(&desc, "description", "", "")
flags.StringVar(&path, "path", "", "") flags.StringVar(&path, "path", "", "")
flags.BoolVar(&local, "local", false, "")
flags.Usage = func() { c.Ui.Error(c.Help()) } flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil { if err := flags.Parse(args); err != nil {
return 1 return 1
@@ -68,7 +71,12 @@ func (c *AuditEnableCommand) Run(args []string) int {
return 1 return 1
} }
err = client.Sys().EnableAudit(path, auditType, desc, opts) err = client.Sys().EnableAuditWithOptions(path, &api.EnableAuditOptions{
Type: auditType,
Description: desc,
Options: opts,
Local: local,
})
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
"Error enabling audit backend: %s", err)) "Error enabling audit backend: %s", err))
@@ -113,6 +121,9 @@ Audit Enable Options:
is purely for referencing this audit backend. By is purely for referencing this audit backend. By
default this will be the backend type. default this will be the backend type.
-local Mark the mount as a local mount. Local mounts
are not replicated nor (if a secondary)
removed by replication.
` `
return strings.TrimSpace(helpText) return strings.TrimSpace(helpText)
} }

View File

@@ -48,16 +48,19 @@ func (c *AuditListCommand) Run(args []string) int {
} }
sort.Strings(paths) sort.Strings(paths)
columns := []string{"Path | Type | Description | Options"} columns := []string{"Path | Type | Description | Replication Behavior | Options"}
for _, path := range paths { for _, path := range paths {
audit := audits[path] audit := audits[path]
opts := make([]string, 0, len(audit.Options)) opts := make([]string, 0, len(audit.Options))
for k, v := range audit.Options { for k, v := range audit.Options {
opts = append(opts, k+"="+v) opts = append(opts, k+"="+v)
} }
replicatedBehavior := "replicated"
if audit.Local {
replicatedBehavior = "local"
}
columns = append(columns, fmt.Sprintf( columns = append(columns, fmt.Sprintf(
"%s | %s | %s | %s", audit.Path, audit.Type, audit.Description, strings.Join(opts, " "))) "%s | %s | %s | %s | %s", audit.Path, audit.Type, audit.Description, replicatedBehavior, strings.Join(opts, " ")))
} }
c.Ui.Output(columnize.SimpleFormat(columns)) c.Ui.Output(columnize.SimpleFormat(columns))

View File

@@ -3,6 +3,7 @@ package command
import ( import (
"testing" "testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http" "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
@@ -34,7 +35,11 @@ func TestAuditList(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %#v", err) t.Fatalf("err: %#v", err)
} }
if err := client.Sys().EnableAudit("foo", "noop", "", nil); err != nil { if err := client.Sys().EnableAuditWithOptions("foo", &api.EnableAuditOptions{
Type: "noop",
Description: "noop",
Options: nil,
}); err != nil {
t.Fatalf("err: %#v", err) t.Fatalf("err: %#v", err)
} }

View File

@@ -281,7 +281,7 @@ func (c *AuthCommand) listMethods() int {
} }
sort.Strings(paths) sort.Strings(paths)
columns := []string{"Path | Type | Default TTL | Max TTL | Description"} columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
for _, path := range paths { for _, path := range paths {
auth := auth[path] auth := auth[path]
defTTL := "system" defTTL := "system"
@@ -292,8 +292,12 @@ func (c *AuthCommand) listMethods() int {
if auth.Config.MaxLeaseTTL != 0 { if auth.Config.MaxLeaseTTL != 0 {
maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL) maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL)
} }
replicatedBehavior := "replicated"
if auth.Local {
replicatedBehavior = "local"
}
columns = append(columns, fmt.Sprintf( columns = append(columns, fmt.Sprintf(
"%s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, auth.Description)) "%s | %s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, replicatedBehavior, auth.Description))
} }
c.Ui.Output(columnize.SimpleFormat(columns)) c.Ui.Output(columnize.SimpleFormat(columns))

View File

@@ -3,6 +3,7 @@ package command
import ( import (
"testing" "testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http" "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
@@ -52,3 +53,50 @@ func TestAuthDisable(t *testing.T) {
t.Fatal("should not have noop mount") t.Fatal("should not have noop mount")
} }
} }
func TestAuthDisableWithOptions(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &AuthDisableCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"noop",
}
// Run the command once to setup the client, it will fail
c.Run(args)
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
if err := client.Sys().EnableAuthWithOptions("noop", &api.EnableAuthOptions{
Type: "noop",
Description: "",
}); err != nil {
t.Fatalf("err: %#v", err)
}
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
mounts, err := client.Sys().ListAuth()
if err != nil {
t.Fatalf("err: %s", err)
}
if _, ok := mounts["noop"]; ok {
t.Fatal("should not have noop mount")
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/meta"
) )
@@ -14,9 +15,11 @@ type AuthEnableCommand struct {
func (c *AuthEnableCommand) Run(args []string) int { func (c *AuthEnableCommand) Run(args []string) int {
var description, path string var description, path string
var local bool
flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault) flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault)
flags.StringVar(&description, "description", "", "") flags.StringVar(&description, "description", "", "")
flags.StringVar(&path, "path", "", "") flags.StringVar(&path, "path", "", "")
flags.BoolVar(&local, "local", false, "")
flags.Usage = func() { c.Ui.Error(c.Help()) } flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil { if err := flags.Parse(args); err != nil {
return 1 return 1
@@ -44,7 +47,11 @@ func (c *AuthEnableCommand) Run(args []string) int {
return 2 return 2
} }
if err := client.Sys().EnableAuth(path, authType, description); err != nil { if err := client.Sys().EnableAuthWithOptions(path, &api.EnableAuthOptions{
Type: authType,
Description: description,
Local: local,
}); err != nil {
c.Ui.Error(fmt.Sprintf( c.Ui.Error(fmt.Sprintf(
"Error: %s", err)) "Error: %s", err))
return 2 return 2
@@ -82,6 +89,9 @@ Auth Enable Options:
to the type of the mount. This will make the auth to the type of the mount. This will make the auth
provider available at "/auth/<path>" provider available at "/auth/<path>"
-local Mark the mount as a local mount. Local mounts
are not replicated nor (if a secondary)
removed by replication.
` `
return strings.TrimSpace(helpText) return strings.TrimSpace(helpText)
} }

View File

@@ -15,11 +15,13 @@ type MountCommand struct {
func (c *MountCommand) Run(args []string) int { func (c *MountCommand) Run(args []string) int {
var description, path, defaultLeaseTTL, maxLeaseTTL string var description, path, defaultLeaseTTL, maxLeaseTTL string
var local bool
flags := c.Meta.FlagSet("mount", meta.FlagSetDefault) flags := c.Meta.FlagSet("mount", meta.FlagSetDefault)
flags.StringVar(&description, "description", "", "") flags.StringVar(&description, "description", "", "")
flags.StringVar(&path, "path", "", "") flags.StringVar(&path, "path", "", "")
flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "") flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "")
flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "") flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "")
flags.BoolVar(&local, "local", false, "")
flags.Usage = func() { c.Ui.Error(c.Help()) } flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil { if err := flags.Parse(args); err != nil {
return 1 return 1
@@ -54,6 +56,7 @@ func (c *MountCommand) Run(args []string) int {
DefaultLeaseTTL: defaultLeaseTTL, DefaultLeaseTTL: defaultLeaseTTL,
MaxLeaseTTL: maxLeaseTTL, MaxLeaseTTL: maxLeaseTTL,
}, },
Local: local,
} }
if err := client.Sys().Mount(path, mountInfo); err != nil { if err := client.Sys().Mount(path, mountInfo); err != nil {
@@ -102,6 +105,10 @@ Mount Options:
the previously set value. Set to '0' to the previously set value. Set to '0' to
explicitly set it to use the global default. explicitly set it to use the global default.
-local Mark the mount as a local mount. Local mounts
are not replicated nor (if a secondary)
removed by replication.
` `
return strings.TrimSpace(helpText) return strings.TrimSpace(helpText)
} }

View File

@@ -42,7 +42,7 @@ func (c *MountsCommand) Run(args []string) int {
} }
sort.Strings(paths) sort.Strings(paths)
columns := []string{"Path | Type | Default TTL | Max TTL | Description"} columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
for _, path := range paths { for _, path := range paths {
mount := mounts[path] mount := mounts[path]
defTTL := "system" defTTL := "system"
@@ -63,8 +63,12 @@ func (c *MountsCommand) Run(args []string) int {
case mount.Config.MaxLeaseTTL != 0: case mount.Config.MaxLeaseTTL != 0:
maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL) maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL)
} }
replicatedBehavior := "replicated"
if mount.Local {
replicatedBehavior = "local"
}
columns = append(columns, fmt.Sprintf( columns = append(columns, fmt.Sprintf(
"%s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, mount.Description)) "%s | %s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, replicatedBehavior, mount.Description))
} }
c.Ui.Output(columnize.SimpleFormat(columns)) c.Ui.Output(columnize.SimpleFormat(columns))

View File

@@ -61,7 +61,7 @@ type ServerCommand struct {
} }
func (c *ServerCommand) Run(args []string) int { func (c *ServerCommand) Run(args []string) int {
var dev, verifyOnly, devHA bool var dev, verifyOnly, devHA, devTransactional bool
var configPath []string var configPath []string
var logLevel, devRootTokenID, devListenAddress string var logLevel, devRootTokenID, devListenAddress string
flags := c.Meta.FlagSet("server", meta.FlagSetDefault) flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
@@ -70,7 +70,8 @@ func (c *ServerCommand) Run(args []string) int {
flags.StringVar(&devListenAddress, "dev-listen-address", "", "") flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
flags.StringVar(&logLevel, "log-level", "info", "") flags.StringVar(&logLevel, "log-level", "info", "")
flags.BoolVar(&verifyOnly, "verify-only", false, "") flags.BoolVar(&verifyOnly, "verify-only", false, "")
flags.BoolVar(&devHA, "dev-ha", false, "") flags.BoolVar(&devHA, "ha", false, "")
flags.BoolVar(&devTransactional, "transactional", false, "")
flags.Usage = func() { c.Ui.Output(c.Help()) } flags.Usage = func() { c.Ui.Output(c.Help()) }
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config") flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
if err := flags.Parse(args); err != nil { if err := flags.Parse(args); err != nil {
@@ -122,7 +123,7 @@ func (c *ServerCommand) Run(args []string) int {
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS") devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
} }
if devHA { if devHA || devTransactional {
dev = true dev = true
} }
@@ -143,7 +144,7 @@ func (c *ServerCommand) Run(args []string) int {
// Load the configuration // Load the configuration
var config *server.Config var config *server.Config
if dev { if dev {
config = server.DevConfig(devHA) config = server.DevConfig(devHA, devTransactional)
if devListenAddress != "" { if devListenAddress != "" {
config.Listeners[0].Config["address"] = devListenAddress config.Listeners[0].Config["address"] = devListenAddress
} }
@@ -235,6 +236,9 @@ func (c *ServerCommand) Run(args []string) int {
ClusterName: config.ClusterName, ClusterName: config.ClusterName,
CacheSize: config.CacheSize, CacheSize: config.CacheSize,
} }
if dev {
coreConfig.DevToken = devRootTokenID
}
var disableClustering bool var disableClustering bool

View File

@@ -38,7 +38,7 @@ type Config struct {
} }
// DevConfig is a Config that is used for dev mode of Vault. // DevConfig is a Config that is used for dev mode of Vault.
func DevConfig(ha bool) *Config { func DevConfig(ha, transactional bool) *Config {
ret := &Config{ ret := &Config{
DisableCache: false, DisableCache: false,
DisableMlock: true, DisableMlock: true,
@@ -63,7 +63,12 @@ func DevConfig(ha bool) *Config {
DefaultLeaseTTL: 32 * 24 * time.Hour, DefaultLeaseTTL: 32 * 24 * time.Hour,
} }
if ha { switch {
case ha && transactional:
ret.Backend.Type = "inmem_transactional_ha"
case !ha && transactional:
ret.Backend.Type = "inmem_transactional"
case ha && !transactional:
ret.Backend.Type = "inmem_ha" ret.Backend.Type = "inmem_ha"
} }

View File

@@ -33,7 +33,7 @@ func TestServer_CommonHA(t *testing.T) {
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"} args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
if code := c.Run(args); code != 0 { if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
} }
if !strings.Contains(ui.OutputWriter.String(), "(HA available)") { if !strings.Contains(ui.OutputWriter.String(), "(HA available)") {
@@ -61,7 +61,7 @@ func TestServer_GoodSeparateHA(t *testing.T) {
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"} args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
if code := c.Run(args); code != 0 { if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
} }
if !strings.Contains(ui.OutputWriter.String(), "HA Backend:") { if !strings.Contains(ui.OutputWriter.String(), "HA Backend:") {

View File

@@ -40,7 +40,7 @@ func (c *StatusCommand) Run(args []string) int {
"Key Shares: %d\n"+ "Key Shares: %d\n"+
"Key Threshold: %d\n"+ "Key Threshold: %d\n"+
"Unseal Progress: %d\n"+ "Unseal Progress: %d\n"+
"Unseal Nonce: %v"+ "Unseal Nonce: %v\n"+
"Version: %s", "Version: %s",
sealStatus.Sealed, sealStatus.Sealed,
sealStatus.N, sealStatus.N,

View File

@@ -14,6 +14,16 @@ func TestCIDRUtil_IPBelongsToCIDR(t *testing.T) {
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr) t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
} }
ip = "10.197.192.6"
cidr = "10.197.192.0/18"
belongs, err = IPBelongsToCIDR(ip, cidr)
if err != nil {
t.Fatal(err)
}
if !belongs {
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
}
ip = "192.168.25.30" ip = "192.168.25.30"
cidr = "192.168.26.30/24" cidr = "192.168.26.30/24"
belongs, err = IPBelongsToCIDR(ip, cidr) belongs, err = IPBelongsToCIDR(ip, cidr)
@@ -44,6 +54,17 @@ func TestCIDRUtil_IPBelongsToCIDRBlocksString(t *testing.T) {
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList) t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
} }
ip = "10.197.192.6"
cidrList = "1.2.3.0/8,10.197.192.0/18,10.197.193.0/24"
belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",")
if err != nil {
t.Fatal(err)
}
if !belongs {
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
}
ip = "192.168.27.29" ip = "192.168.27.29"
cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24" cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"

7
helper/consts/consts.go Normal file
View File

@@ -0,0 +1,7 @@
package consts
const (
// ExpirationRestoreWorkerCount specifies the numer of workers to use while
// restoring leases into the expiration manager
ExpirationRestoreWorkerCount = 64
)

13
helper/consts/error.go Normal file
View File

@@ -0,0 +1,13 @@
package consts
import "errors"
var (
// ErrSealed is returned if an operation is performed on a sealed barrier.
// No operation is expected to succeed before unsealing
ErrSealed = errors.New("Vault is sealed")
// ErrStandby is returned if an operation is performed on a standby Vault.
// No operation is expected to succeed until active.
ErrStandby = errors.New("Vault is in standby mode")
)

View File

@@ -0,0 +1,20 @@
package consts
type ReplicationState uint32
const (
ReplicationDisabled ReplicationState = iota
ReplicationPrimary
ReplicationSecondary
)
func (r ReplicationState) String() string {
switch r {
case ReplicationSecondary:
return "secondary"
case ReplicationPrimary:
return "primary"
}
return "disabled"
}

View File

@@ -71,6 +71,15 @@ func (lm *LockManager) CacheActive() bool {
return lm.cache != nil return lm.cache != nil
} }
func (lm *LockManager) InvalidatePolicy(name string) {
// Check if it's in our cache. If so, return right away.
if lm.CacheActive() {
lm.cacheMutex.Lock()
defer lm.cacheMutex.Unlock()
delete(lm.cache, name)
}
}
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex { func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
lm.locksMutex.RLock() lm.locksMutex.RLock()
lock := lm.locks[name] lock := lm.locks[name]

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/duration" "github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
@@ -206,11 +207,11 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
// case of an error. // case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) { func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
resp, err := core.HandleRequest(r) resp, err := core.HandleRequest(r)
if errwrap.Contains(err, vault.ErrStandby.Error()) { if errwrap.Contains(err, consts.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL) respondStandby(core, w, rawReq.URL)
return resp, false return resp, false
} }
if respondErrorCommon(w, resp, err) { if respondErrorCommon(w, r, resp, err) {
return resp, false return resp, false
} }
@@ -310,20 +311,7 @@ func requestWrapInfo(r *http.Request, req *logical.Request) (*logical.Request, e
} }
func respondError(w http.ResponseWriter, status int, err error) { func respondError(w http.ResponseWriter, status int, err error) {
// Adjust status code when sealed logical.AdjustErrorStatusCode(&status, err)
if errwrap.Contains(err, vault.ErrSealed.Error()) {
status = http.StatusServiceUnavailable
}
// Adjust status code on
if errwrap.Contains(err, "http: request body too large") {
status = http.StatusRequestEntityTooLarge
}
// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(logical.HTTPCodedError); ok {
status = t.Code()
}
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
@@ -337,42 +325,13 @@ func respondError(w http.ResponseWriter, status int, err error) {
enc.Encode(resp) enc.Encode(resp)
} }
func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool {
// If there are no errors return statusCode, newErr := logical.RespondErrorCommon(req, resp, err)
if err == nil && (resp == nil || !resp.IsError()) { if newErr == nil && statusCode == 0 {
return false return false
} }
// Start out with internal server error since in most of these cases there respondError(w, statusCode, newErr)
// won't be a response so this won't be overridden
statusCode := http.StatusInternalServerError
// If we actually have a response, start out with bad request
if resp != nil {
statusCode = http.StatusBadRequest
}
// Now, check the error itself; if it has a specific logical error, set the
// appropriate code
if err != nil {
switch {
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
statusCode = http.StatusBadRequest
case errwrap.Contains(err, logical.ErrPermissionDenied.Error()):
statusCode = http.StatusForbidden
case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()):
statusCode = http.StatusMethodNotAllowed
case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()):
statusCode = http.StatusNotFound
case errwrap.Contains(err, logical.ErrInvalidRequest.Error()):
statusCode = http.StatusBadRequest
}
}
if resp != nil && resp.IsError() {
err = fmt.Errorf("%s", resp.Data["error"].(string))
}
respondError(w, statusCode, err)
return true return true
} }

View File

@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
@@ -80,6 +81,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -88,6 +90,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -96,6 +99,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
@@ -105,6 +109,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -113,6 +118,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -121,6 +127,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -223,7 +230,7 @@ func TestHandler_error(t *testing.T) {
// vault.ErrSealed is a special case // vault.ErrSealed is a special case
w3 := httptest.NewRecorder() w3 := httptest.NewRecorder()
respondError(w3, 400, vault.ErrSealed) respondError(w3, 400, consts.ErrSealed)
if w3.Code != 503 { if w3.Code != 503 {
t.Fatalf("expected 503, got %d", w3.Code) t.Fatalf("expected 503, got %d", w3.Code)

View File

@@ -35,7 +35,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
resp, err := core.HandleRequest(lreq) resp, err := core.HandleRequest(lreq)
if err != nil { if err != nil {
respondErrorCommon(w, resp, err) respondErrorCommon(w, lreq, resp, err)
return return
} }

View File

@@ -53,6 +53,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
return nil, http.StatusMethodNotAllowed, nil return nil, http.StatusMethodNotAllowed, nil
} }
if op == logical.ListOperation {
if !strings.HasSuffix(path, "/") {
path += "/"
}
}
// Parse the request if we can // Parse the request if we can
var data map[string]interface{} var data map[string]interface{}
if op == logical.UpdateOperation { if op == logical.UpdateOperation {
@@ -109,40 +115,13 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
// Make the internal request. We attach the connection info // Make the internal request. We attach the connection info
// as well in case this is an authentication request that requires // as well in case this is an authentication request that requires
// it. Vault core handles stripping this if we need to. // it. Vault core handles stripping this if we need to. This also
// handles all error cases; if we hit respondLogical, the request is a
// success.
resp, ok := request(core, w, r, req) resp, ok := request(core, w, r, req)
if !ok { if !ok {
return return
} }
switch {
case req.Operation == logical.ReadOperation:
if resp == nil {
respondError(w, http.StatusNotFound, nil)
return
}
// Basically: if we have empty "keys" or no keys at all, 404. This
// provides consistency with GET.
case req.Operation == logical.ListOperation && resp.WrapInfo == nil:
if resp == nil || len(resp.Data) == 0 {
respondError(w, http.StatusNotFound, nil)
return
}
keysRaw, ok := resp.Data["keys"]
if !ok || keysRaw == nil {
respondError(w, http.StatusNotFound, nil)
return
}
keys, ok := keysRaw.([]string)
if !ok {
respondError(w, http.StatusInternalServerError, nil)
return
}
if len(keys) == 0 {
respondError(w, http.StatusNotFound, nil)
return
}
}
// Build the proper response // Build the proper response
respondLogical(w, r, req, dataOnly, resp) respondLogical(w, r, req, dataOnly, resp)

View File

@@ -4,8 +4,10 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"io" "io"
"net/http"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -101,7 +103,7 @@ func TestLogical_StandbyRedirect(t *testing.T) {
// Attempt to fix raciness in this test by giving the first core a chance // Attempt to fix raciness in this test by giving the first core a chance
// to grab the lock // to grab the lock
time.Sleep(time.Second) time.Sleep(2 * time.Second)
// Create a second HA Vault // Create a second HA Vault
conf2 := &vault.CoreConfig{ conf2 := &vault.CoreConfig{
@@ -252,3 +254,42 @@ func TestLogical_RequestSizeLimit(t *testing.T) {
}) })
testResponseStatus(t, resp, 413) testResponseStatus(t, resp, 413)
} }
func TestLogical_ListSuffix(t *testing.T) {
core, _, _ := vault.TestCoreUnsealed(t)
req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil)
lreq, status, err := buildLogicalRequest(core, nil, req)
if err != nil {
t.Fatal(err)
}
if status != 0 {
t.Fatalf("got status %d", status)
}
if strings.HasSuffix(lreq.Path, "/") {
t.Fatal("trailing slash found on path")
}
req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil)
lreq, status, err = buildLogicalRequest(core, nil, req)
if err != nil {
t.Fatal(err)
}
if status != 0 {
t.Fatalf("got status %d", status)
}
if !strings.HasSuffix(lreq.Path, "/") {
t.Fatal("trailing slash not found on path")
}
req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil)
lreq, status, err = buildLogicalRequest(core, nil, req)
if err != nil {
t.Fatal(err)
}
if status != 0 {
t.Fatalf("got status %d", status)
}
if !strings.HasSuffix(lreq.Path, "/") {
t.Fatal("trailing slash not found on path")
}
}

View File

@@ -35,6 +35,7 @@ func TestSysAudit(t *testing.T) {
"type": "noop", "type": "noop",
"description": "", "description": "",
"options": map[string]interface{}{}, "options": map[string]interface{}{},
"local": false,
}, },
}, },
"noop/": map[string]interface{}{ "noop/": map[string]interface{}{
@@ -42,6 +43,7 @@ func TestSysAudit(t *testing.T) {
"type": "noop", "type": "noop",
"description": "", "description": "",
"options": map[string]interface{}{}, "options": map[string]interface{}{},
"local": false,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)

View File

@@ -32,6 +32,7 @@ func TestSysAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
}, },
"token/": map[string]interface{}{ "token/": map[string]interface{}{
@@ -41,6 +42,7 @@ func TestSysAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -83,6 +85,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"token/": map[string]interface{}{ "token/": map[string]interface{}{
"description": "token based credentials", "description": "token based credentials",
@@ -91,6 +94,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
}, },
"foo/": map[string]interface{}{ "foo/": map[string]interface{}{
@@ -100,6 +104,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"token/": map[string]interface{}{ "token/": map[string]interface{}{
"description": "token based credentials", "description": "token based credentials",
@@ -108,6 +113,7 @@ func TestSysEnableAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -153,6 +159,7 @@ func TestSysDisableAuth(t *testing.T) {
}, },
"description": "token based credentials", "description": "token based credentials",
"type": "token", "type": "token",
"local": false,
}, },
}, },
"token/": map[string]interface{}{ "token/": map[string]interface{}{
@@ -162,6 +169,7 @@ func TestSysDisableAuth(t *testing.T) {
}, },
"description": "token based credentials", "description": "token based credentials",
"type": "token", "type": "token",
"local": false,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)

View File

@@ -33,6 +33,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -41,6 +42,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -49,6 +51,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
@@ -58,6 +61,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -66,6 +70,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -74,6 +79,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -114,6 +120,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -122,6 +129,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -130,6 +138,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -138,6 +147,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"foo/": map[string]interface{}{ "foo/": map[string]interface{}{
@@ -147,6 +157,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -155,6 +166,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -163,6 +175,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -171,6 +184,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -233,6 +247,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -241,6 +256,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -249,6 +265,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -257,6 +274,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"bar/": map[string]interface{}{ "bar/": map[string]interface{}{
@@ -266,6 +284,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -274,6 +293,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -282,6 +302,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -290,6 +311,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -333,6 +355,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -341,6 +364,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -349,6 +373,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
@@ -358,6 +383,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -366,6 +392,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -374,6 +401,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -414,6 +442,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -422,6 +451,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -430,6 +460,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -438,6 +469,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"foo/": map[string]interface{}{ "foo/": map[string]interface{}{
@@ -447,6 +479,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -455,6 +488,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -463,6 +497,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -471,6 +506,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -532,6 +568,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("259196400"), "default_lease_ttl": json.Number("259196400"),
"max_lease_ttl": json.Number("259200000"), "max_lease_ttl": json.Number("259200000"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -540,6 +577,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -548,6 +586,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -556,6 +595,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
}, },
"foo/": map[string]interface{}{ "foo/": map[string]interface{}{
@@ -565,6 +605,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("259196400"), "default_lease_ttl": json.Number("259196400"),
"max_lease_ttl": json.Number("259200000"), "max_lease_ttl": json.Number("259200000"),
}, },
"local": false,
}, },
"secret/": map[string]interface{}{ "secret/": map[string]interface{}{
"description": "generic secret storage", "description": "generic secret storage",
@@ -573,6 +614,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"sys/": map[string]interface{}{ "sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging", "description": "system endpoints used for control, policy and debugging",
@@ -581,6 +623,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": false,
}, },
"cubbyhole/": map[string]interface{}{ "cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage", "description": "per-token private secret storage",
@@ -589,6 +632,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"), "default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"),
}, },
"local": true,
}, },
} }

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
@@ -19,6 +20,13 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler {
return return
} }
repState := core.ReplicationState()
if repState == consts.ReplicationSecondary {
respondError(w, http.StatusBadRequest,
fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"))
return
}
switch { switch {
case recovery && !core.SealAccess().RecoveryKeySupported(): case recovery && !core.SealAccess().RecoveryKeySupported():
respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported")) respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"))
@@ -108,7 +116,7 @@ func handleSysRekeyInitPut(core *vault.Core, recovery bool, w http.ResponseWrite
// Right now we don't support this, but the rest of the code is ready for // Right now we don't support this, but the rest of the code is ready for
// when we do, hence the check below for this to be false if // when we do, hence the check below for this to be false if
// StoredShares is greater than zero // StoredShares is greater than zero
if core.SealAccess().StoredKeysSupported() { if core.SealAccess().StoredKeysSupported() && !recovery {
respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying of barrier not supported when stored key support is available")) respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying of barrier not supported when stored key support is available"))
return return
} }

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/version" "github.com/hashicorp/vault/version"
@@ -126,7 +127,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
case errwrap.Contains(err, vault.ErrBarrierInvalidKey.Error()): case errwrap.Contains(err, vault.ErrBarrierInvalidKey.Error()):
case errwrap.Contains(err, vault.ErrBarrierNotInit.Error()): case errwrap.Contains(err, vault.ErrBarrierNotInit.Error()):
case errwrap.Contains(err, vault.ErrBarrierSealed.Error()): case errwrap.Contains(err, vault.ErrBarrierSealed.Error()):
case errwrap.Contains(err, vault.ErrStandby.Error()): case errwrap.Contains(err, consts.ErrStandby.Error()):
default: default:
respondError(w, http.StatusInternalServerError, err) respondError(w, http.StatusInternalServerError, err)
return return

View File

@@ -8,7 +8,7 @@ import (
// is present on the Request structure for credential backends. // is present on the Request structure for credential backends.
type Connection struct { type Connection struct {
// RemoteAddr is the network address that sent the request. // RemoteAddr is the network address that sent the request.
RemoteAddr string RemoteAddr string `json:"remote_addr"`
// ConnState is the TLS connection state if applicable. // ConnState is the TLS connection state if applicable.
ConnState *tls.ConnectionState ConnState *tls.ConnectionState

View File

@@ -21,3 +21,27 @@ func (e *codedError) Error() string {
func (e *codedError) Code() int { func (e *codedError) Code() int {
return e.code return e.code
} }
// Struct to identify user input errors. This is helpful in responding the
// appropriate status codes to clients from the HTTP endpoints.
type StatusBadRequest struct {
Err string
}
// Implementing error interface
func (s *StatusBadRequest) Error() string {
return s.Err
}
// This is a new type declared to not cause potential compatibility problems if
// the logic around the HTTPCodedError interface changes; in particular for
// logical request paths it is basically ignored, and changing that behavior
// might cause unforseen issues.
type ReplicationCodedError struct {
Msg string
Code int
}
func (r *ReplicationCodedError) Error() string {
return r.Msg
}

View File

@@ -1,6 +1,7 @@
package framework package framework
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"regexp" "regexp"
@@ -12,6 +13,7 @@ import (
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
@@ -534,7 +536,40 @@ type FieldSchema struct {
// the zero value of the type. // the zero value of the type.
func (s *FieldSchema) DefaultOrZero() interface{} { func (s *FieldSchema) DefaultOrZero() interface{} {
if s.Default != nil { if s.Default != nil {
return s.Default switch s.Type {
case TypeDurationSecond:
var result int
switch inp := s.Default.(type) {
case nil:
return s.Type.Zero()
case int:
result = inp
case int64:
result = int(inp)
case float32:
result = int(inp)
case float64:
result = int(inp)
case string:
dur, err := duration.ParseDurationSecond(inp)
if err != nil {
return s.Type.Zero()
}
result = int(dur.Seconds())
case json.Number:
valInt64, err := inp.Int64()
if err != nil {
return s.Type.Zero()
}
result = int(valInt64)
default:
return s.Type.Zero()
}
return result
default:
return s.Default
}
} }
return s.Type.Zero() return s.Type.Zero()

View File

@@ -554,6 +554,16 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) {
60, 60,
}, },
"default duration int64": {
&FieldSchema{Type: TypeDurationSecond, Default: int64(60)},
60,
},
"default duration string": {
&FieldSchema{Type: TypeDurationSecond, Default: "60s"},
60,
},
"default duration not set": { "default duration not set": {
&FieldSchema{Type: TypeDurationSecond}, &FieldSchema{Type: TypeDurationSecond},
0, 0,

View File

@@ -80,22 +80,3 @@ type Paths struct {
// indicates that these paths should not be replicated // indicates that these paths should not be replicated
LocalStorage []string LocalStorage []string
} }
type ReplicationState uint32
const (
ReplicationDisabled ReplicationState = iota
ReplicationPrimary
ReplicationSecondary
)
func (r ReplicationState) String() string {
switch r {
case ReplicationSecondary:
return "secondary"
case ReplicationPrimary:
return "primary"
}
return "disabled"
}

View File

@@ -25,6 +25,10 @@ type Request struct {
// Id is the uuid associated with each request // Id is the uuid associated with each request
ID string `json:"id" structs:"id" mapstructure:"id"` ID string `json:"id" structs:"id" mapstructure:"id"`
// If set, the name given to the replication secondary where this request
// originated
ReplicationCluster string `json:"replication_cluster" structs:"replication_cluster", mapstructure:"replication_cluster"`
// Operation is the requested operation type // Operation is the requested operation type
Operation Operation `json:"operation" structs:"operation" mapstructure:"operation"` Operation Operation `json:"operation" structs:"operation" mapstructure:"operation"`
@@ -38,7 +42,7 @@ type Request struct {
Data map[string]interface{} `json:"map" structs:"data" mapstructure:"data"` Data map[string]interface{} `json:"map" structs:"data" mapstructure:"data"`
// Storage can be used to durably store and retrieve state. // Storage can be used to durably store and retrieve state.
Storage Storage `json:"storage" structs:"storage" mapstructure:"storage"` Storage Storage `json:"-"`
// Secret will be non-nil only for Revoke and Renew operations // Secret will be non-nil only for Revoke and Renew operations
// to represent the secret that was returned prior. // to represent the secret that was returned prior.

111
logical/response_util.go Normal file
View File

@@ -0,0 +1,111 @@
package logical
import (
"errors"
"fmt"
"net/http"
"github.com/hashicorp/errwrap"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/consts"
)
// RespondErrorCommon pulls most of the functionality from http's
// respondErrorCommon and some of http's handleLogical and makes it available
// to both the http package and elsewhere.
func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
if err == nil && (resp == nil || !resp.IsError()) {
switch {
case req.Operation == ReadOperation:
if resp == nil {
return http.StatusNotFound, nil
}
// Basically: if we have empty "keys" or no keys at all, 404. This
// provides consistency with GET.
case req.Operation == ListOperation && resp.WrapInfo == nil:
if resp == nil || len(resp.Data) == 0 {
return http.StatusNotFound, nil
}
keysRaw, ok := resp.Data["keys"]
if !ok || keysRaw == nil {
return http.StatusNotFound, nil
}
keys, ok := keysRaw.([]string)
if !ok {
return http.StatusInternalServerError, nil
}
if len(keys) == 0 {
return http.StatusNotFound, nil
}
}
return 0, nil
}
if errwrap.ContainsType(err, new(ReplicationCodedError)) {
var allErrors error
codedErr := errwrap.GetType(err, new(ReplicationCodedError)).(*ReplicationCodedError)
errwrap.Walk(err, func(inErr error) {
newErr, ok := inErr.(*ReplicationCodedError)
if !ok {
allErrors = multierror.Append(allErrors, newErr)
}
})
if allErrors != nil {
return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
}
return codedErr.Code, errors.New(codedErr.Msg)
}
// Start out with internal server error since in most of these cases there
// won't be a response so this won't be overridden
statusCode := http.StatusInternalServerError
// If we actually have a response, start out with bad request
if resp != nil {
statusCode = http.StatusBadRequest
}
// Now, check the error itself; if it has a specific logical error, set the
// appropriate code
if err != nil {
switch {
case errwrap.ContainsType(err, new(StatusBadRequest)):
statusCode = http.StatusBadRequest
case errwrap.Contains(err, ErrPermissionDenied.Error()):
statusCode = http.StatusForbidden
case errwrap.Contains(err, ErrUnsupportedOperation.Error()):
statusCode = http.StatusMethodNotAllowed
case errwrap.Contains(err, ErrUnsupportedPath.Error()):
statusCode = http.StatusNotFound
case errwrap.Contains(err, ErrInvalidRequest.Error()):
statusCode = http.StatusBadRequest
}
}
if resp != nil && resp.IsError() {
err = fmt.Errorf("%s", resp.Data["error"].(string))
}
return statusCode, err
}
// AdjustErrorStatusCode adjusts the status that will be sent in error
// conditions in a way that can be shared across http's respondError and other
// locations.
func AdjustErrorStatusCode(status *int, err error) {
// Adjust status code when sealed
if errwrap.Contains(err, consts.ErrSealed.Error()) {
*status = http.StatusServiceUnavailable
}
// Adjust status code on
if errwrap.Contains(err, "http: request body too large") {
*status = http.StatusRequestEntityTooLarge
}
// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(HTTPCodedError); ok {
*status = t.Code()
}
}

View File

@@ -1,6 +1,10 @@
package logical package logical
import "time" import (
"time"
"github.com/hashicorp/vault/helper/consts"
)
// SystemView exposes system configuration information in a safe way // SystemView exposes system configuration information in a safe way
// for logical backends to consume // for logical backends to consume
@@ -32,7 +36,7 @@ type SystemView interface {
CachingDisabled() bool CachingDisabled() bool
// ReplicationState indicates the state of cluster replication // ReplicationState indicates the state of cluster replication
ReplicationState() ReplicationState ReplicationState() consts.ReplicationState
} }
type StaticSystemView struct { type StaticSystemView struct {
@@ -42,7 +46,7 @@ type StaticSystemView struct {
TaintedVal bool TaintedVal bool
CachingDisabledVal bool CachingDisabledVal bool
Primary bool Primary bool
ReplicationStateVal ReplicationState ReplicationStateVal consts.ReplicationState
} }
func (d StaticSystemView) DefaultLeaseTTL() time.Duration { func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
@@ -65,6 +69,6 @@ func (d StaticSystemView) CachingDisabled() bool {
return d.CachingDisabledVal return d.CachingDisabledVal
} }
func (d StaticSystemView) ReplicationState() ReplicationState { func (d StaticSystemView) ReplicationState() consts.ReplicationState {
return d.ReplicationStateVal return d.ReplicationStateVal
} }

View File

@@ -23,6 +23,21 @@ const (
FlagSetDefault = FlagSetServer FlagSetDefault = FlagSetServer
) )
var (
additionalOptionsUsage = func() string {
return `
-wrap-ttl="" Indicates that the response should be wrapped in a
cubbyhole token with the requested TTL. The response
can be fetched by calling the "sys/wrapping/unwrap"
endpoint, passing in the wrappping token's ID. This
is a numeric string with an optional suffix
"s", "m", or "h"; if no suffix is specified it will
be parsed as seconds. May also be specified via
VAULT_WRAP_TTL.
`
}
)
// Meta contains the meta-options and functionality that nearly every // Meta contains the meta-options and functionality that nearly every
// Vault command inherits. // Vault command inherits.
type Meta struct { type Meta struct {
@@ -188,6 +203,6 @@ func GeneralOptionsUsage() string {
if VAULT_SKIP_VERIFY is set. if VAULT_SKIP_VERIFY is set.
` `
general += AdditionalOptionsUsage() general += additionalOptionsUsage()
return general return general
} }

View File

@@ -1,7 +0,0 @@
// +build !vault
package meta
func AdditionalOptionsUsage() string {
return ""
}

View File

@@ -1,16 +0,0 @@
// +build vault
package meta
func AdditionalOptionsUsage() string {
return `
-wrap-ttl="" Indicates that the response should be wrapped in a
cubbyhole token with the requested TTL. The response
can be fetched by calling the "sys/wrapping/unwrap"
endpoint, passing in the wrappping token's ID. This
is a numeric string with an optional suffix
"s", "m", or "h"; if no suffix is specified it will
be parsed as seconds. May also be specified via
VAULT_WRAP_TTL.
`
}

View File

@@ -1,9 +1,15 @@
package physical package physical
import ( import (
"crypto/sha1"
"encoding/hex"
"fmt"
"strings" "strings"
"sync"
"github.com/hashicorp/golang-lru" "github.com/hashicorp/golang-lru"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/helper/strutil"
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
) )
@@ -17,8 +23,11 @@ const (
// Vault are for policy objects so there is a large read reduction // Vault are for policy objects so there is a large read reduction
// by using a simple write-through cache. // by using a simple write-through cache.
type Cache struct { type Cache struct {
backend Backend backend Backend
lru *lru.TwoQueueCache transactional Transactional
lru *lru.TwoQueueCache
locks map[string]*sync.RWMutex
logger log.Logger
} }
// NewCache returns a physical cache of the given size. // NewCache returns a physical cache of the given size.
@@ -34,16 +43,58 @@ func NewCache(b Backend, size int, logger log.Logger) *Cache {
c := &Cache{ c := &Cache{
backend: b, backend: b,
lru: cache, lru: cache,
locks: make(map[string]*sync.RWMutex, 256),
logger: logger,
} }
if err := locksutil.CreateLocks(c.locks, 256); err != nil {
logger.Error("physical/cache: error creating locks", "error", err)
return nil
}
if txnl, ok := c.backend.(Transactional); ok {
c.transactional = txnl
}
return c return c
} }
func (c *Cache) lockHashForKey(key string) string {
hf := sha1.New()
hf.Write([]byte(key))
return strings.ToLower(hex.EncodeToString(hf.Sum(nil))[:2])
}
func (c *Cache) lockForKey(key string) *sync.RWMutex {
return c.locks[c.lockHashForKey(key)]
}
// Purge is used to clear the cache // Purge is used to clear the cache
func (c *Cache) Purge() { func (c *Cache) Purge() {
// Lock the world
lockHashes := make([]string, 0, len(c.locks))
for hash := range c.locks {
lockHashes = append(lockHashes, hash)
}
// Sort and deduplicate. This ensures we don't try to grab the same lock
// twice, and enforcing a sort means we'll not have multiple goroutines
// deadlock by acquiring in different orders.
lockHashes = strutil.RemoveDuplicates(lockHashes)
for _, lockHash := range lockHashes {
lock := c.locks[lockHash]
lock.Lock()
defer lock.Unlock()
}
c.lru.Purge() c.lru.Purge()
} }
func (c *Cache) Put(entry *Entry) error { func (c *Cache) Put(entry *Entry) error {
lock := c.lockForKey(entry.Key)
lock.Lock()
defer lock.Unlock()
err := c.backend.Put(entry) err := c.backend.Put(entry)
if err == nil { if err == nil {
c.lru.Add(entry.Key, entry) c.lru.Add(entry.Key, entry)
@@ -52,6 +103,10 @@ func (c *Cache) Put(entry *Entry) error {
} }
func (c *Cache) Get(key string) (*Entry, error) { func (c *Cache) Get(key string) (*Entry, error) {
lock := c.lockForKey(key)
lock.RLock()
defer lock.RUnlock()
// Check the LRU first // Check the LRU first
if raw, ok := c.lru.Get(key); ok { if raw, ok := c.lru.Get(key); ok {
if raw == nil { if raw == nil {
@@ -79,6 +134,10 @@ func (c *Cache) Get(key string) (*Entry, error) {
} }
func (c *Cache) Delete(key string) error { func (c *Cache) Delete(key string) error {
lock := c.lockForKey(key)
lock.Lock()
defer lock.Unlock()
err := c.backend.Delete(key) err := c.backend.Delete(key)
if err == nil { if err == nil {
c.lru.Remove(key) c.lru.Remove(key)
@@ -87,6 +146,45 @@ func (c *Cache) Delete(key string) error {
} }
func (c *Cache) List(prefix string) ([]string, error) { func (c *Cache) List(prefix string) ([]string, error) {
// Always pass-through as this would be difficult to cache. // Always pass-through as this would be difficult to cache. For the same
// reason we don't lock as we can't reasonably know which locks to readlock
// ahead of time.
return c.backend.List(prefix) return c.backend.List(prefix)
} }
func (c *Cache) Transaction(txns []TxnEntry) error {
if c.transactional == nil {
return fmt.Errorf("physical/cache: underlying backend does not support transactions")
}
var lockHashes []string
for _, txn := range txns {
lockHashes = append(lockHashes, c.lockHashForKey(txn.Entry.Key))
}
// Sort and deduplicate. This ensures we don't try to grab the same lock
// twice, and enforcing a sort means we'll not have multiple goroutines
// deadlock by acquiring in different orders.
lockHashes = strutil.RemoveDuplicates(lockHashes)
for _, lockHash := range lockHashes {
lock := c.locks[lockHash]
lock.Lock()
defer lock.Unlock()
}
if err := c.transactional.Transaction(txns); err != nil {
return err
}
for _, txn := range txns {
switch txn.Operation {
case PutOperation:
c.lru.Add(txn.Entry.Key, txn.Entry)
case DeleteOperation:
c.lru.Remove(txn.Entry.Key)
}
}
return nil
}

View File

@@ -1,6 +1,7 @@
package physical package physical
import ( import (
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@@ -21,6 +22,8 @@ import (
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/helper/tlsutil" "github.com/hashicorp/vault/helper/tlsutil"
) )
@@ -154,6 +157,10 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
// Configure the client // Configure the client
consulConf := api.DefaultConfig() consulConf := api.DefaultConfig()
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
tr := cleanhttp.DefaultPooledTransport()
tr.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
consulConf.HttpClient.Transport = tr
if addr, ok := conf["address"]; ok { if addr, ok := conf["address"]; ok {
consulConf.Address = addr consulConf.Address = addr
@@ -179,7 +186,7 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
} }
transport := cleanhttp.DefaultPooledTransport() transport := cleanhttp.DefaultPooledTransport()
transport.MaxIdleConnsPerHost = 4 transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
transport.TLSClientConfig = tlsClientConfig transport.TLSClientConfig = tlsClientConfig
consulConf.HttpClient.Transport = transport consulConf.HttpClient.Transport = transport
logger.Debug("physical/consul: configured TLS") logger.Debug("physical/consul: configured TLS")
@@ -284,17 +291,59 @@ func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
return tlsClientConfig, nil return tlsClientConfig, nil
} }
// Used to run multiple entries via a transaction
func (c *ConsulBackend) Transaction(txns []TxnEntry) error {
if len(txns) == 0 {
return nil
}
ops := make([]*api.KVTxnOp, 0, len(txns))
for _, op := range txns {
cop := &api.KVTxnOp{
Key: c.path + op.Entry.Key,
}
switch op.Operation {
case DeleteOperation:
cop.Verb = api.KVDelete
case PutOperation:
cop.Verb = api.KVSet
cop.Value = op.Entry.Value
default:
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
}
ops = append(ops, cop)
}
ok, resp, _, err := c.kv.Txn(ops, nil)
if err != nil {
return err
}
if ok {
return nil
}
var retErr *multierror.Error
for _, res := range resp.Errors {
retErr = multierror.Append(retErr, errors.New(res.What))
}
return retErr
}
// Put is used to insert or update an entry // Put is used to insert or update an entry
func (c *ConsulBackend) Put(entry *Entry) error { func (c *ConsulBackend) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"consul", "put"}, time.Now()) defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
c.permitPool.Acquire()
defer c.permitPool.Release()
pair := &api.KVPair{ pair := &api.KVPair{
Key: c.path + entry.Key, Key: c.path + entry.Key,
Value: entry.Value, Value: entry.Value,
} }
c.permitPool.Acquire()
defer c.permitPool.Release()
_, err := c.kv.Put(pair, nil) _, err := c.kv.Put(pair, nil)
return err return err
} }

View File

@@ -22,12 +22,17 @@ import (
// and non-performant. It is meant mostly for local testing and development. // and non-performant. It is meant mostly for local testing and development.
// It can be improved in the future. // It can be improved in the future.
type FileBackend struct { type FileBackend struct {
Path string sync.RWMutex
l sync.Mutex path string
logger log.Logger logger log.Logger
permitPool *PermitPool
} }
// newFileBackend constructs a Filebackend using the given directory type TransactionalFileBackend struct {
FileBackend
}
// newFileBackend constructs a FileBackend using the given directory
func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) { func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
path, ok := conf["path"] path, ok := conf["path"]
if !ok { if !ok {
@@ -35,20 +40,44 @@ func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error)
} }
return &FileBackend{ return &FileBackend{
Path: path, path: path,
logger: logger, logger: logger,
permitPool: NewPermitPool(DefaultParallelOperations),
}, nil
}
func newTransactionalFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
path, ok := conf["path"]
if !ok {
return nil, fmt.Errorf("'path' must be set")
}
// Create a pool of size 1 so only one operation runs at a time
return &TransactionalFileBackend{
FileBackend: FileBackend{
path: path,
logger: logger,
permitPool: NewPermitPool(1),
},
}, nil }, nil
} }
func (b *FileBackend) Delete(path string) error { func (b *FileBackend) Delete(path string) error {
b.permitPool.Acquire()
defer b.permitPool.Release()
b.Lock()
defer b.Unlock()
return b.DeleteInternal(path)
}
func (b *FileBackend) DeleteInternal(path string) error {
if path == "" { if path == "" {
return nil return nil
} }
b.l.Lock() basePath, key := b.expandPath(path)
defer b.l.Unlock()
basePath, key := b.path(path)
fullPath := filepath.Join(basePath, key) fullPath := filepath.Join(basePath, key)
err := os.Remove(fullPath) err := os.Remove(fullPath)
@@ -66,7 +95,7 @@ func (b *FileBackend) Delete(path string) error {
func (b *FileBackend) cleanupLogicalPath(path string) error { func (b *FileBackend) cleanupLogicalPath(path string) error {
nodes := strings.Split(path, fmt.Sprintf("%c", os.PathSeparator)) nodes := strings.Split(path, fmt.Sprintf("%c", os.PathSeparator))
for i := len(nodes) - 1; i > 0; i-- { for i := len(nodes) - 1; i > 0; i-- {
fullPath := filepath.Join(b.Path, filepath.Join(nodes[:i]...)) fullPath := filepath.Join(b.path, filepath.Join(nodes[:i]...))
dir, err := os.Open(fullPath) dir, err := os.Open(fullPath)
if err != nil { if err != nil {
@@ -96,10 +125,17 @@ func (b *FileBackend) cleanupLogicalPath(path string) error {
} }
func (b *FileBackend) Get(k string) (*Entry, error) { func (b *FileBackend) Get(k string) (*Entry, error) {
b.l.Lock() b.permitPool.Acquire()
defer b.l.Unlock() defer b.permitPool.Release()
path, key := b.path(k) b.RLock()
defer b.RUnlock()
return b.GetInternal(k)
}
func (b *FileBackend) GetInternal(k string) (*Entry, error) {
path, key := b.expandPath(k)
path = filepath.Join(path, key) path = filepath.Join(path, key)
f, err := os.Open(path) f, err := os.Open(path)
@@ -121,10 +157,17 @@ func (b *FileBackend) Get(k string) (*Entry, error) {
} }
func (b *FileBackend) Put(entry *Entry) error { func (b *FileBackend) Put(entry *Entry) error {
path, key := b.path(entry.Key) b.permitPool.Acquire()
defer b.permitPool.Release()
b.l.Lock() b.Lock()
defer b.l.Unlock() defer b.Unlock()
return b.PutInternal(entry)
}
func (b *FileBackend) PutInternal(entry *Entry) error {
path, key := b.expandPath(entry.Key)
// Make the parent tree // Make the parent tree
if err := os.MkdirAll(path, 0755); err != nil { if err := os.MkdirAll(path, 0755); err != nil {
@@ -145,10 +188,17 @@ func (b *FileBackend) Put(entry *Entry) error {
} }
func (b *FileBackend) List(prefix string) ([]string, error) { func (b *FileBackend) List(prefix string) ([]string, error) {
b.l.Lock() b.permitPool.Acquire()
defer b.l.Unlock() defer b.permitPool.Release()
path := b.Path b.RLock()
defer b.RUnlock()
return b.ListInternal(prefix)
}
func (b *FileBackend) ListInternal(prefix string) ([]string, error) {
path := b.path
if prefix != "" { if prefix != "" {
path = filepath.Join(path, prefix) path = filepath.Join(path, prefix)
} }
@@ -180,9 +230,19 @@ func (b *FileBackend) List(prefix string) ([]string, error) {
return names, nil return names, nil
} }
func (b *FileBackend) path(k string) (string, string) { func (b *FileBackend) expandPath(k string) (string, string) {
path := filepath.Join(b.Path, k) path := filepath.Join(b.path, k)
key := filepath.Base(path) key := filepath.Base(path)
path = filepath.Dir(path) path = filepath.Dir(path)
return path, "_" + key return path, "_" + key
} }
func (b *TransactionalFileBackend) Transaction(txns []TxnEntry) error {
b.permitPool.Acquire()
defer b.permitPool.Release()
b.Lock()
defer b.Unlock()
return genericTransactionHandler(b, txns)
}

View File

@@ -13,12 +13,16 @@ import (
// for testing and development situations where the data is not // for testing and development situations where the data is not
// expected to be durable. // expected to be durable.
type InmemBackend struct { type InmemBackend struct {
sync.RWMutex
root *radix.Tree root *radix.Tree
l sync.RWMutex
permitPool *PermitPool permitPool *PermitPool
logger log.Logger logger log.Logger
} }
type TransactionalInmemBackend struct {
InmemBackend
}
// NewInmem constructs a new in-memory backend // NewInmem constructs a new in-memory backend
func NewInmem(logger log.Logger) *InmemBackend { func NewInmem(logger log.Logger) *InmemBackend {
in := &InmemBackend{ in := &InmemBackend{
@@ -29,14 +33,31 @@ func NewInmem(logger log.Logger) *InmemBackend {
return in return in
} }
// Basically for now just creates a permit pool of size 1 so only one operation
// can run at a time
func NewTransactionalInmem(logger log.Logger) *TransactionalInmemBackend {
in := &TransactionalInmemBackend{
InmemBackend: InmemBackend{
root: radix.New(),
permitPool: NewPermitPool(1),
logger: logger,
},
}
return in
}
// Put is used to insert or update an entry // Put is used to insert or update an entry
func (i *InmemBackend) Put(entry *Entry) error { func (i *InmemBackend) Put(entry *Entry) error {
i.permitPool.Acquire() i.permitPool.Acquire()
defer i.permitPool.Release() defer i.permitPool.Release()
i.l.Lock() i.Lock()
defer i.l.Unlock() defer i.Unlock()
return i.PutInternal(entry)
}
func (i *InmemBackend) PutInternal(entry *Entry) error {
i.root.Insert(entry.Key, entry) i.root.Insert(entry.Key, entry)
return nil return nil
} }
@@ -46,9 +67,13 @@ func (i *InmemBackend) Get(key string) (*Entry, error) {
i.permitPool.Acquire() i.permitPool.Acquire()
defer i.permitPool.Release() defer i.permitPool.Release()
i.l.RLock() i.RLock()
defer i.l.RUnlock() defer i.RUnlock()
return i.GetInternal(key)
}
func (i *InmemBackend) GetInternal(key string) (*Entry, error) {
if raw, ok := i.root.Get(key); ok { if raw, ok := i.root.Get(key); ok {
return raw.(*Entry), nil return raw.(*Entry), nil
} }
@@ -60,9 +85,13 @@ func (i *InmemBackend) Delete(key string) error {
i.permitPool.Acquire() i.permitPool.Acquire()
defer i.permitPool.Release() defer i.permitPool.Release()
i.l.Lock() i.Lock()
defer i.l.Unlock() defer i.Unlock()
return i.DeleteInternal(key)
}
func (i *InmemBackend) DeleteInternal(key string) error {
i.root.Delete(key) i.root.Delete(key)
return nil return nil
} }
@@ -73,9 +102,13 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
i.permitPool.Acquire() i.permitPool.Acquire()
defer i.permitPool.Release() defer i.permitPool.Release()
i.l.RLock() i.RLock()
defer i.l.RUnlock() defer i.RUnlock()
return i.ListInternal(prefix)
}
func (i *InmemBackend) ListInternal(prefix string) ([]string, error) {
var out []string var out []string
seen := make(map[string]interface{}) seen := make(map[string]interface{})
walkFn := func(s string, v interface{}) bool { walkFn := func(s string, v interface{}) bool {
@@ -96,3 +129,14 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
return out, nil return out, nil
} }
// Implements the transaction interface
func (t *TransactionalInmemBackend) Transaction(txns []TxnEntry) error {
t.permitPool.Acquire()
defer t.permitPool.Release()
t.Lock()
defer t.Unlock()
return genericTransactionHandler(t, txns)
}

View File

@@ -8,19 +8,40 @@ import (
) )
type InmemHABackend struct { type InmemHABackend struct {
InmemBackend Backend
locks map[string]string locks map[string]string
l sync.Mutex l sync.Mutex
cond *sync.Cond cond *sync.Cond
logger log.Logger logger log.Logger
} }
type TransactionalInmemHABackend struct {
Transactional
InmemHABackend
}
// NewInmemHA constructs a new in-memory HA backend. This is only for testing. // NewInmemHA constructs a new in-memory HA backend. This is only for testing.
func NewInmemHA(logger log.Logger) *InmemHABackend { func NewInmemHA(logger log.Logger) *InmemHABackend {
in := &InmemHABackend{ in := &InmemHABackend{
InmemBackend: *NewInmem(logger), Backend: NewInmem(logger),
locks: make(map[string]string), locks: make(map[string]string),
logger: logger, logger: logger,
}
in.cond = sync.NewCond(&in.l)
return in
}
func NewTransactionalInmemHA(logger log.Logger) *TransactionalInmemHABackend {
transInmem := NewTransactionalInmem(logger)
inmemHA := InmemHABackend{
Backend: transInmem,
locks: make(map[string]string),
logger: logger,
}
in := &TransactionalInmemHABackend{
InmemHABackend: inmemHA,
Transactional: transInmem,
} }
in.cond = sync.NewCond(&in.l) in.cond = sync.NewCond(&in.l)
return in return in

View File

@@ -9,6 +9,16 @@ import (
const DefaultParallelOperations = 128 const DefaultParallelOperations = 128
// The operation type
type Operation string
const (
DeleteOperation Operation = "delete"
GetOperation = "get"
ListOperation = "list"
PutOperation = "put"
)
// ShutdownSignal // ShutdownSignal
type ShutdownChannel chan struct{} type ShutdownChannel chan struct{}
@@ -121,20 +131,27 @@ var builtinBackends = map[string]Factory{
"inmem": func(_ map[string]string, logger log.Logger) (Backend, error) { "inmem": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewInmem(logger), nil return NewInmem(logger), nil
}, },
"inmem_transactional": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewTransactionalInmem(logger), nil
},
"inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) { "inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
return NewInmemHA(logger), nil return NewInmemHA(logger), nil
}, },
"consul": newConsulBackend, "inmem_transactional_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
"zookeeper": newZookeeperBackend, return NewTransactionalInmemHA(logger), nil
"file": newFileBackend, },
"s3": newS3Backend, "file_transactional": newTransactionalFileBackend,
"azure": newAzureBackend, "consul": newConsulBackend,
"dynamodb": newDynamoDBBackend, "zookeeper": newZookeeperBackend,
"etcd": newEtcdBackend, "file": newFileBackend,
"mysql": newMySQLBackend, "s3": newS3Backend,
"postgresql": newPostgreSQLBackend, "azure": newAzureBackend,
"swift": newSwiftBackend, "dynamodb": newDynamoDBBackend,
"gcs": newGCSBackend, "etcd": newEtcdBackend,
"mysql": newMySQLBackend,
"postgresql": newPostgreSQLBackend,
"swift": newSwiftBackend,
"gcs": newGCSBackend,
} }
// PermitPool is used to limit maximum outstanding requests // PermitPool is used to limit maximum outstanding requests

View File

@@ -71,7 +71,8 @@ func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, e
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" + list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1", "UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
quoted_table + " WHERE parent_path LIKE concat($1, '%')",
logger: logger, logger: logger,
} }

121
physical/transactions.go Normal file
View File

@@ -0,0 +1,121 @@
package physical
import multierror "github.com/hashicorp/go-multierror"
// TxnEntry is an operation that takes atomically as part of
// a transactional update. Only supported by Transactional backends.
type TxnEntry struct {
Operation Operation
Entry *Entry
}
// Transactional is an optional interface for backends that
// support doing transactional updates of multiple keys. This is
// required for some features such as replication.
type Transactional interface {
// The function to run a transaction
Transaction([]TxnEntry) error
}
type PseudoTransactional interface {
// An internal function should do no locking or permit pool acquisition.
// Depending on the backend and if it natively supports transactions, these
// may simply chain to the normal backend functions.
GetInternal(string) (*Entry, error)
PutInternal(*Entry) error
DeleteInternal(string) error
}
// Implements the transaction interface
func genericTransactionHandler(t PseudoTransactional, txns []TxnEntry) (retErr error) {
rollbackStack := make([]TxnEntry, 0, len(txns))
var dirty bool
// We walk the transactions in order; each successful operation goes into a
// LIFO for rollback if we hit an error along the way
TxnWalk:
for _, txn := range txns {
switch txn.Operation {
case DeleteOperation:
entry, err := t.GetInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
if entry == nil {
// Nothing to delete or roll back
continue
}
rollbackEntry := TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: entry.Key,
Value: entry.Value,
},
}
err = t.DeleteInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
case PutOperation:
entry, err := t.GetInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
// Nothing existed so in fact rolling back requires a delete
var rollbackEntry TxnEntry
if entry == nil {
rollbackEntry = TxnEntry{
Operation: DeleteOperation,
Entry: &Entry{
Key: txn.Entry.Key,
},
}
} else {
rollbackEntry = TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: entry.Key,
Value: entry.Value,
},
}
}
err = t.PutInternal(txn.Entry)
if err != nil {
retErr = multierror.Append(retErr, err)
dirty = true
break TxnWalk
}
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
}
}
// Need to roll back because we hit an error along the way
if dirty {
// While traversing this, if we get an error, we continue anyways in
// best-effort fashion
for _, txn := range rollbackStack {
switch txn.Operation {
case DeleteOperation:
err := t.DeleteInternal(txn.Entry.Key)
if err != nil {
retErr = multierror.Append(retErr, err)
}
case PutOperation:
err := t.PutInternal(txn.Entry)
if err != nil {
retErr = multierror.Append(retErr, err)
}
}
}
}
return
}

View File

@@ -0,0 +1,254 @@
package physical
import (
"fmt"
"reflect"
"sort"
"testing"
radix "github.com/armon/go-radix"
"github.com/hashicorp/vault/helper/logformat"
log "github.com/mgutz/logxi/v1"
)
type faultyPseudo struct {
underlying InmemBackend
faultyPaths map[string]struct{}
}
func (f *faultyPseudo) Get(key string) (*Entry, error) {
return f.underlying.Get(key)
}
func (f *faultyPseudo) Put(entry *Entry) error {
return f.underlying.Put(entry)
}
func (f *faultyPseudo) Delete(key string) error {
return f.underlying.Delete(key)
}
func (f *faultyPseudo) GetInternal(key string) (*Entry, error) {
if _, ok := f.faultyPaths[key]; ok {
return nil, fmt.Errorf("fault")
}
return f.underlying.GetInternal(key)
}
func (f *faultyPseudo) PutInternal(entry *Entry) error {
if _, ok := f.faultyPaths[entry.Key]; ok {
return fmt.Errorf("fault")
}
return f.underlying.PutInternal(entry)
}
func (f *faultyPseudo) DeleteInternal(key string) error {
if _, ok := f.faultyPaths[key]; ok {
return fmt.Errorf("fault")
}
return f.underlying.DeleteInternal(key)
}
func (f *faultyPseudo) List(prefix string) ([]string, error) {
return f.underlying.List(prefix)
}
func (f *faultyPseudo) Transaction(txns []TxnEntry) error {
f.underlying.permitPool.Acquire()
defer f.underlying.permitPool.Release()
f.underlying.Lock()
defer f.underlying.Unlock()
return genericTransactionHandler(f, txns)
}
func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo {
out := &faultyPseudo{
underlying: InmemBackend{
root: radix.New(),
permitPool: NewPermitPool(1),
logger: logger,
},
faultyPaths: make(map[string]struct{}, len(faultyPaths)),
}
for _, v := range faultyPaths {
out.faultyPaths[v] = struct{}{}
}
return out
}
func TestPseudo_Basic(t *testing.T) {
logger := logformat.NewVaultLogger(log.LevelTrace)
p := newFaultyPseudo(logger, nil)
testBackend(t, p)
testBackend_ListPrefix(t, p)
}
func TestPseudo_SuccessfulTransaction(t *testing.T) {
logger := logformat.NewVaultLogger(log.LevelTrace)
p := newFaultyPseudo(logger, nil)
txns := setupPseudo(p, t)
if err := p.Transaction(txns); err != nil {
t.Fatal(err)
}
keys, err := p.List("")
if err != nil {
t.Fatal(err)
}
expected := []string{"foo", "zip"}
sort.Strings(keys)
sort.Strings(expected)
if !reflect.DeepEqual(keys, expected) {
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
}
entry, err := p.Get("foo")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "bar3" {
t.Fatal("updates did not apply correctly")
}
entry, err = p.Get("zip")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "zap3" {
t.Fatal("updates did not apply correctly")
}
}
func TestPseudo_FailedTransaction(t *testing.T) {
logger := logformat.NewVaultLogger(log.LevelTrace)
p := newFaultyPseudo(logger, []string{"zip"})
txns := setupPseudo(p, t)
if err := p.Transaction(txns); err == nil {
t.Fatal("expected error during transaction")
}
keys, err := p.List("")
if err != nil {
t.Fatal(err)
}
expected := []string{"foo", "zip", "deleteme", "deleteme2"}
sort.Strings(keys)
sort.Strings(expected)
if !reflect.DeepEqual(keys, expected) {
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
}
entry, err := p.Get("foo")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "bar" {
t.Fatal("values did not rollback correctly")
}
entry, err = p.Get("zip")
if err != nil {
t.Fatal(err)
}
if entry == nil {
t.Fatal("got nil entry")
}
if entry.Value == nil {
t.Fatal("got nil value")
}
if string(entry.Value) != "zap" {
t.Fatal("values did not rollback correctly")
}
}
func setupPseudo(p *faultyPseudo, t *testing.T) []TxnEntry {
// Add a few keys so that we test rollback with deletion
if err := p.Put(&Entry{
Key: "foo",
Value: []byte("bar"),
}); err != nil {
t.Fatal(err)
}
if err := p.Put(&Entry{
Key: "zip",
Value: []byte("zap"),
}); err != nil {
t.Fatal(err)
}
if err := p.Put(&Entry{
Key: "deleteme",
}); err != nil {
t.Fatal(err)
}
if err := p.Put(&Entry{
Key: "deleteme2",
}); err != nil {
t.Fatal(err)
}
txns := []TxnEntry{
TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: "foo",
Value: []byte("bar2"),
},
},
TxnEntry{
Operation: DeleteOperation,
Entry: &Entry{
Key: "deleteme",
},
},
TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: "foo",
Value: []byte("bar3"),
},
},
TxnEntry{
Operation: DeleteOperation,
Entry: &Entry{
Key: "deleteme2",
},
},
TxnEntry{
Operation: PutOperation,
Entry: &Entry{
Key: "zip",
Value: []byte("zap3"),
},
},
}
return txns
}

View File

@@ -10,7 +10,7 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q \
git mercurial bzr \ git mercurial bzr \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
ENV GOVERSION 1.8rc3 ENV GOVERSION 1.8
RUN mkdir /goroot && mkdir /gopath RUN mkdir /goroot && mkdir /gopath
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \ RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
| tar xvzf - -C /goroot --strip-components=1 | tar xvzf - -C /goroot --strip-components=1

View File

@@ -2,7 +2,6 @@ package vault
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@@ -26,6 +25,10 @@ const (
// can only be viewed or modified after an unseal. // can only be viewed or modified after an unseal.
coreAuditConfigPath = "core/audit" coreAuditConfigPath = "core/audit"
// coreLocalAuditConfigPath is used to store audit information for local
// (non-replicated) mounts
coreLocalAuditConfigPath = "core/local-audit"
// auditBarrierPrefix is the prefix to the UUID used in the // auditBarrierPrefix is the prefix to the UUID used in the
// barrier view for the audit backends. // barrier view for the audit backends.
auditBarrierPrefix = "audit/" auditBarrierPrefix = "audit/"
@@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
} }
// Generate a new UUID and view // Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID() if entry.UUID == "" {
if err != nil { entryUUID, err := uuid.GenerateUUID()
return err if err != nil {
return err
}
entry.UUID = entryUUID
} }
entry.UUID = entryUUID viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") view := NewBarrierView(c.barrier, viewPath)
// Lookup the new backend // Lookup the new backend
backend, err := c.newAuditBackend(entry, view, entry.Options) backend, err := c.newAuditBackend(entry, view, entry.Options)
@@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {
c.removeAuditReloadFunc(entry) c.removeAuditReloadFunc(entry)
// When unmounting all entries the JSON code will load back up from storage
// as a nil slice, which kills tests...just set it nil explicitly
if len(newTable.Entries) == 0 {
newTable.Entries = nil
}
// Update the audit table // Update the audit table
if err := c.persistAudit(newTable); err != nil { if err := c.persistAudit(newTable); err != nil {
return true, errors.New("failed to update audit table") return true, errors.New("failed to update audit table")
@@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) {
if c.logger.IsInfo() { if c.logger.IsInfo() {
c.logger.Info("core: disabled audit backend", "path", path) c.logger.Info("core: disabled audit backend", "path", path)
} }
return true, nil return true, nil
} }
// loadAudits is invoked as part of postUnseal to load the audit table // loadAudits is invoked as part of postUnseal to load the audit table
func (c *Core) loadAudits() error { func (c *Core) loadAudits() error {
auditTable := &MountTable{} auditTable := &MountTable{}
localAuditTable := &MountTable{}
// Load the existing audit table // Load the existing audit table
raw, err := c.barrier.Get(coreAuditConfigPath) raw, err := c.barrier.Get(coreAuditConfigPath)
@@ -144,6 +158,11 @@ func (c *Core) loadAudits() error {
c.logger.Error("core: failed to read audit table", "error", err) c.logger.Error("core: failed to read audit table", "error", err)
return errLoadAuditFailed return errLoadAuditFailed
} }
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
c.logger.Error("core: failed to read local audit table", "error", err)
return errLoadAuditFailed
}
c.auditLock.Lock() c.auditLock.Lock()
defer c.auditLock.Unlock() defer c.auditLock.Unlock()
@@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
} }
c.audit = auditTable c.audit = auditTable
} }
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
c.logger.Error("core: failed to decode local audit table", "error", err)
return errLoadAuditFailed
}
c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
}
// Done if we have restored the audit table // Done if we have restored the audit table
if c.audit != nil { if c.audit != nil {
@@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error {
} }
} }
nonLocalAudit := &MountTable{
Type: auditTableType,
}
localAudit := &MountTable{
Type: auditTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localAudit.Entries = append(localAudit.Entries, entry)
} else {
nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
}
}
// Marshal the table // Marshal the table
raw, err := json.Marshal(table) compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
if err != nil { if err != nil {
c.logger.Error("core: failed to encode audit table", "error", err) c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
return err return err
} }
// Create an entry // Create an entry
entry := &Entry{ entry := &Entry{
Key: coreAuditConfigPath, Key: coreAuditConfigPath,
Value: raw, Value: compressedBytes,
} }
// Write to the physical backend // Write to the physical backend
@@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error {
c.logger.Error("core: failed to persist audit table", "error", err) c.logger.Error("core: failed to persist audit table", "error", err)
return err return err
} }
// Repeat with local audit
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalAuditConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local audit table", "error", err)
return err
}
return nil return nil
} }
@@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {
for _, entry := range c.audit.Entries { for _, entry := range c.audit.Entries {
// Create a barrier view using the UUID // Create a barrier view using the UUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
// Initialize the backend // Initialize the backend
audit, err := c.newAuditBackend(entry, view, entry.Options) audit, err := c.newAuditBackend(entry, view, entry.Options)

View File

@@ -11,6 +11,7 @@ import (
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
@@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
} }
} }
// Test that the local table actually gets populated as expected with local
// entries, and that upon reading the entries from both are recombined
// correctly
func TestCore_EnableAudit_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return nil, fmt.Errorf("failing enabling")
}
c.audit = &MountTable{
Type: auditTableType,
Entries: []*MountEntry{
&MountEntry{
Table: auditTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: auditTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupAudits()
if err != nil {
t.Fatal(err)
}
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) > 0 {
t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
}
c.audit.Entries[1].Local = true
if err := c.persistAudit(c.audit); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) != 1 {
t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
}
oldAudit := c.audit
if err := c.loadAudits(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(oldAudit, c.audit) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
}
if len(c.audit.Entries) != 2 {
t.Fatalf("expected two audit entries, got %#v", localAuditTable)
}
}
func TestCore_DisableAudit(t *testing.T) { func TestCore_DisableAudit(t *testing.T) {
c, keys, _ := TestCoreUnsealed(t) c, keys, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
@@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {
// Verify matching mount tables // Verify matching mount tables
if !reflect.DeepEqual(c.audit, c2.audit) { if !reflect.DeepEqual(c.audit, c2.audit) {
t.Fatalf("mismatch: %v %v", c.audit, c2.audit) t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
} }
} }

View File

@@ -1,7 +1,6 @@
package vault package vault
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@@ -17,6 +16,10 @@ const (
// can only be viewed or modified after an unseal. // can only be viewed or modified after an unseal.
coreAuthConfigPath = "core/auth" coreAuthConfigPath = "core/auth"
// coreLocalAuthConfigPath is used to store credential configuration for
// local (non-replicated) mounts
coreLocalAuthConfigPath = "core/local-auth"
// credentialBarrierPrefix is the prefix to the UUID used in the // credentialBarrierPrefix is the prefix to the UUID used in the
// barrier view for the credential backends. // barrier view for the credential backends.
credentialBarrierPrefix = "auth/" credentialBarrierPrefix = "auth/"
@@ -71,16 +74,25 @@ func (c *Core) enableCredential(entry *MountEntry) error {
} }
// Generate a new UUID and view // Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID() if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
viewPath := credentialBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
// Create the new backend
backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil)
if err != nil { if err != nil {
return err return err
} }
entry.UUID = entryUUID
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Create the new backend if err := backend.Initialize(); err != nil {
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
return err return err
} }
@@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
fullPath := credentialRoutePrefix + path fullPath := credentialRoutePrefix + path
view := c.router.MatchingStorageView(fullPath) view := c.router.MatchingStorageView(fullPath)
if view == nil { if view == nil {
return false, fmt.Errorf("no matching backend") return false, fmt.Errorf("no matching backend %s", fullPath)
} }
// Mark the entry as tainted // Mark the entry as tainted
@@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error {
// loadCredentials is invoked as part of postUnseal to load the auth table // loadCredentials is invoked as part of postUnseal to load the auth table
func (c *Core) loadCredentials() error { func (c *Core) loadCredentials() error {
authTable := &MountTable{} authTable := &MountTable{}
localAuthTable := &MountTable{}
// Load the existing mount table // Load the existing mount table
raw, err := c.barrier.Get(coreAuthConfigPath) raw, err := c.barrier.Get(coreAuthConfigPath)
if err != nil { if err != nil {
c.logger.Error("core: failed to read auth table", "error", err) c.logger.Error("core: failed to read auth table", "error", err)
return errLoadAuthFailed return errLoadAuthFailed
} }
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
c.logger.Error("core: failed to read local auth table", "error", err)
return errLoadAuthFailed
}
c.authLock.Lock() c.authLock.Lock()
defer c.authLock.Unlock() defer c.authLock.Unlock()
@@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error {
} }
c.auth = authTable c.auth = authTable
} }
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil {
c.logger.Error("core: failed to decode local auth table", "error", err)
return errLoadAuthFailed
}
c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...)
}
// Done if we have restored the auth table // Done if we have restored the auth table
if c.auth != nil { if c.auth != nil {
@@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error {
} }
} }
nonLocalAuth := &MountTable{
Type: credentialTableType,
}
localAuth := &MountTable{
Type: credentialTableType,
}
for _, entry := range table.Entries {
if entry.Local {
localAuth.Entries = append(localAuth.Entries, entry)
} else {
nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry)
}
}
// Marshal the table // Marshal the table
raw, err := json.Marshal(table) compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
if err != nil { if err != nil {
c.logger.Error("core: failed to encode auth table", "error", err) c.logger.Error("core: failed to encode and/or compress auth table", "error", err)
return err return err
} }
// Create an entry // Create an entry
entry := &Entry{ entry := &Entry{
Key: coreAuthConfigPath, Key: coreAuthConfigPath,
Value: raw, Value: compressedBytes,
} }
// Write to the physical backend // Write to the physical backend
@@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error {
c.logger.Error("core: failed to persist auth table", "error", err) c.logger.Error("core: failed to persist auth table", "error", err)
return err return err
} }
// Repeat with local auth
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local auth table", "error", err)
return err
}
entry = &Entry{
Key: coreLocalAuthConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local auth table", "error", err)
return err
}
return nil return nil
} }
@@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error {
} }
// Create a barrier view using the UUID // Create a barrier view using the UUID
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") viewPath := credentialBarrierPrefix + entry.UUID + "/"
view = NewBarrierView(c.barrier, viewPath)
sysView := c.mountEntrySysView(entry)
// Initialize the backend // Initialize the backend
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil) backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil)
if err != nil { if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err) c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
return errLoadAuthFailed return errLoadAuthFailed
} }
if err := backend.Initialize(); err != nil {
return err
}
// Mount the backend // Mount the backend
path := credentialRoutePrefix + entry.Path path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry, view) err = c.router.Mount(backend, path, entry, view)

View File

@@ -2,8 +2,10 @@ package vault
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) {
} }
} }
// Test that the local table actually gets populated as expected with local
// entries, and that upon reading the entries from both are recombined
// correctly
func TestCore_EnableCredential_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
return &NoopBackend{}, nil
}
c.auth = &MountTable{
Type: credentialTableType,
Entries: []*MountEntry{
&MountEntry{
Table: credentialTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: credentialTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}
// Both should set up successfully
err := c.setupCredentials()
if err != nil {
t.Fatal(err)
}
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local credential")
}
localCredentialTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
t.Fatal(err)
}
if len(localCredentialTable.Entries) > 0 {
t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable)
}
c.auth.Entries[1].Local = true
if err := c.persistAuth(c.auth); err != nil {
t.Fatal(err)
}
rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local credential")
}
localCredentialTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
t.Fatal(err)
}
if len(localCredentialTable.Entries) != 1 {
t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable)
}
oldCredential := c.auth
if err := c.loadCredentials(); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(oldCredential, c.auth) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth)
}
if len(c.auth.Entries) != 2 {
t.Fatalf("expected two credential entries, got %#v", localCredentialTable)
}
}
func TestCore_EnableCredential_twice_409(t *testing.T) { func TestCore_EnableCredential_twice_409(t *testing.T) {
c, _, _ := TestCoreUnsealed(t) c, _, _ := TestCoreUnsealed(t)
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
@@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) {
} }
existed, err := c.disableCredential("foo") existed, err := c.disableCredential("foo")
if existed || err.Error() != "no matching backend" { if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) {
t.Fatalf("existed: %v; err: %v", existed, err) t.Fatalf("existed: %v; err: %v", existed, err)
} }

View File

@@ -86,6 +86,11 @@ type SecurityBarrier interface {
// VerifyMaster is used to check if the given key matches the master key // VerifyMaster is used to check if the given key matches the master key
VerifyMaster(key []byte) error VerifyMaster(key []byte) error
// SetMasterKey is used to directly set a new master key. This is used in
// repliated scenarios due to the chicken and egg problem of reloading the
// keyring from disk before we have the master key to decrypt it.
SetMasterKey(key []byte) error
// ReloadKeyring is used to re-read the underlying keyring. // ReloadKeyring is used to re-read the underlying keyring.
// This is used for HA deployments to ensure the latest keyring // This is used for HA deployments to ensure the latest keyring
// is present in the leader. // is present in the leader.
@@ -119,8 +124,14 @@ type SecurityBarrier interface {
// Rekey is used to change the master key used to protect the keyring // Rekey is used to change the master key used to protect the keyring
Rekey([]byte) error Rekey([]byte) error
// For replication we must send over the keyring, so this must be available
Keyring() (*Keyring, error)
// SecurityBarrier must provide the storage APIs // SecurityBarrier must provide the storage APIs
BarrierStorage BarrierStorage
// SecurityBarrier must provide the encryption APIs
BarrierEncryptor
} }
// BarrierStorage is the storage only interface required for a Barrier. // BarrierStorage is the storage only interface required for a Barrier.
@@ -139,6 +150,14 @@ type BarrierStorage interface {
List(prefix string) ([]string, error) List(prefix string) ([]string, error)
} }
// BarrierEncryptor is the in memory only interface that does not actually
// use the underlying barrier. It is used for lower level modules like the
// Write-Ahead-Log and Merkle index to allow them to use the barrier.
type BarrierEncryptor interface {
Encrypt(key string, plaintext []byte) ([]byte, error)
Decrypt(key string, ciphertext []byte) ([]byte, error)
}
// Entry is used to represent data stored by the security barrier // Entry is used to represent data stored by the security barrier
type Entry struct { type Entry struct {
Key string Key string

View File

@@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) {
func (b *AESGCMBarrier) Rekey(key []byte) error { func (b *AESGCMBarrier) Rekey(key []byte) error {
b.l.Lock() b.l.Lock()
defer b.l.Unlock() defer b.l.Unlock()
if b.sealed {
return ErrBarrierSealed
}
// Verify the key size newKeyring, err := b.updateMasterKeyCommon(key)
min, max := b.KeyLength() if err != nil {
if len(key) < min || len(key) > max { return err
return fmt.Errorf("Key size must be %d or %d", min, max)
} }
// Add a new encryption key
newKeyring := b.keyring.SetMasterKey(key)
// Persist the new keyring // Persist the new keyring
if err := b.persistKeyring(newKeyring); err != nil { if err := b.persistKeyring(newKeyring); err != nil {
return err return err
@@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error {
return nil return nil
} }
// SetMasterKey updates the keyring's in-memory master key but does not persist
// anything to storage
func (b *AESGCMBarrier) SetMasterKey(key []byte) error {
b.l.Lock()
defer b.l.Unlock()
newKeyring, err := b.updateMasterKeyCommon(key)
if err != nil {
return err
}
// Swap the keyrings
oldKeyring := b.keyring
b.keyring = newKeyring
oldKeyring.Zeroize(false)
return nil
}
// Performs common tasks related to updating the master key; note that the lock
// must be held before calling this function
func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
if b.sealed {
return nil, ErrBarrierSealed
}
// Verify the key size
min, max := b.KeyLength()
if len(key) < min || len(key) > max {
return nil, fmt.Errorf("Key size must be %d or %d", min, max)
}
return b.keyring.SetMasterKey(key), nil
}
// Put is used to insert or update an entry // Put is used to insert or update an entry
func (b *AESGCMBarrier) Put(entry *Entry) error { func (b *AESGCMBarrier) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now()) defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
@@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro
return nil, fmt.Errorf("version bytes mis-match") return nil, fmt.Errorf("version bytes mis-match")
} }
} }
// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface
func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
term := b.keyring.ActiveTerm()
primary, err := b.aeadForTerm(term)
if err != nil {
return nil, err
}
ciphertext := b.encrypt(key, term, primary, plaintext)
return ciphertext, nil
}
// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface
func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
// Decrypt the ciphertext
plain, err := b.decryptKeyring(key, ciphertext)
if err != nil {
return nil, fmt.Errorf("decryption failed: %v", err)
}
return plain, nil
}
func (b *AESGCMBarrier) Keyring() (*Keyring, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
return nil, ErrBarrierSealed
}
return b.keyring.Clone(), nil
}

View File

@@ -15,7 +15,7 @@ var (
) )
// mockBarrier returns a physical backend, security barrier, and master key // mockBarrier returns a physical backend, security barrier, and master key
func mockBarrier(t *testing.T) (physical.Backend, SecurityBarrier, []byte) { func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) {
inm := physical.NewInmem(logger) inm := physical.NewInmem(logger)
b, err := NewAESGCMBarrier(inm) b, err := NewAESGCMBarrier(inm)
@@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) {
t.Fatalf("key length protection failed") t.Fatalf("key length protection failed")
} }
} }
func TestEncrypt_BarrierEncryptor(t *testing.T) {
inm := physical.NewInmem(logger)
b, err := NewAESGCMBarrier(inm)
if err != nil {
t.Fatalf("err: %v", err)
}
// Initialize and unseal
key, _ := b.GenerateKey()
b.Initialize(key)
b.Unseal(key)
cipher, err := b.Encrypt("foo", []byte("quick brown fox"))
if err != nil {
t.Fatalf("err: %v", err)
}
plain, err := b.Decrypt("foo", cipher)
if err != nil {
t.Fatalf("err: %v", err)
}
if string(plain) != "quick brown fox" {
t.Fatalf("bad: %s", plain)
}
}

View File

@@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
// logical.Storage impl. // logical.Storage impl.
func (v *BarrierView) Put(entry *logical.StorageEntry) error { func (v *BarrierView) Put(entry *logical.StorageEntry) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(entry.Key); err != nil { if err := v.sanityCheck(entry.Key); err != nil {
return err return err
} }
expandedKey := v.expandKey(entry.Key)
if v.readonly {
return logical.ErrReadOnly
}
nested := &Entry{ nested := &Entry{
Key: v.expandKey(entry.Key), Key: expandedKey,
Value: entry.Value, Value: entry.Value,
} }
return v.barrier.Put(nested) return v.barrier.Put(nested)
@@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
// logical.Storage impl. // logical.Storage impl.
func (v *BarrierView) Delete(key string) error { func (v *BarrierView) Delete(key string) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(key); err != nil { if err := v.sanityCheck(key); err != nil {
return err return err
} }
return v.barrier.Delete(v.expandKey(key))
expandedKey := v.expandKey(key)
if v.readonly {
return logical.ErrReadOnly
}
return v.barrier.Delete(expandedKey)
} }
// SubView constructs a nested sub-view using the given prefix // SubView constructs a nested sub-view using the given prefix

View File

@@ -1,27 +1,19 @@
package vault package vault
import "sort" import (
"sort"
// Struct to identify user input errors. "github.com/hashicorp/vault/logical"
// This is helpful in responding the appropriate status codes to clients )
// from the HTTP endpoints.
type StatusBadRequest struct {
Err string
}
// Implementing error interface
func (s *StatusBadRequest) Error() string {
return s.Err
}
// Capabilities is used to fetch the capabilities of the given token on the given path // Capabilities is used to fetch the capabilities of the given token on the given path
func (c *Core) Capabilities(token, path string) ([]string, error) { func (c *Core) Capabilities(token, path string) ([]string, error) {
if path == "" { if path == "" {
return nil, &StatusBadRequest{Err: "missing path"} return nil, &logical.StatusBadRequest{Err: "missing path"}
} }
if token == "" { if token == "" {
return nil, &StatusBadRequest{Err: "missing token"} return nil, &logical.StatusBadRequest{Err: "missing token"}
} }
te, err := c.tokenStore.Lookup(token) te, err := c.tokenStore.Lookup(token)
@@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) {
return nil, err return nil, err
} }
if te == nil { if te == nil {
return nil, &StatusBadRequest{Err: "invalid token"} return nil, &logical.StatusBadRequest{Err: "invalid token"}
} }
if te.Policies == nil { if te.Policies == nil {

View File

@@ -43,7 +43,7 @@ var (
// This can be one of a few key types so the different params may or may not be filled // This can be one of a few key types so the different params may or may not be filled
type clusterKeyParams struct { type clusterKeyParams struct {
Type string `json:"type"` Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"` X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"` Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"` D *big.Int `json:"d" structs:"d" mapstructure:"d"`
@@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
c.logger.Info("core/stopClusterListener: success") c.logger.Info("core/stopClusterListener: success")
} }
// ClusterTLSConfig generates a TLS configuration based on the local cluster // ClusterTLSConfig generates a TLS configuration based on the local/replicated
// key and cert. // cluster key and cert.
func (c *Core) ClusterTLSConfig() (*tls.Config, error) { func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
cluster, err := c.Cluster() cluster, err := c.Cluster()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cluster == nil { if cluster == nil {
return nil, fmt.Errorf("cluster information is nil") return nil, fmt.Errorf("local cluster information is nil")
} }
// Prevent data races with the TLS parameters // Prevent data races with the TLS parameters
c.clusterParamsLock.Lock() c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock() defer c.clusterParamsLock.Unlock()
if c.localClusterCert == nil || len(c.localClusterCert) == 0 { forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0
return nil, fmt.Errorf("cluster certificate is nil")
var parsedCert *x509.Certificate
if forwarding {
parsedCert, err = x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
}
// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
} }
parsedCert, err := x509.ParseCertificate(c.localClusterCert) nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if err != nil { c.clusterParamsLock.RLock()
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err) defer c.clusterParamsLock.RUnlock()
}
// This is idempotent, so be sure it's been added if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
c.clusterCertPool.AddCert(parsedCert) return &tls.Certificate{
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
tls.Certificate{
Certificate: [][]byte{c.localClusterCert}, Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey, PrivateKey: c.localClusterPrivateKey,
}, }, nil
}, }
RootCAs: c.clusterCertPool,
ServerName: parsedCert.Subject.CommonName, return nil, nil
ClientAuth: tls.RequireAndVerifyClientCert, }
ClientCAs: c.clusterCertPool,
MinVersion: tls.VersionTLS12, var clientCertificates []tls.Certificate
if forwarding {
clientCertificates = append(clientCertificates, tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
})
}
tlsConfig := &tls.Config{
// We need this here for the client side
Certificates: clientCertificates,
RootCAs: c.clusterCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
GetCertificate: nameLookup,
MinVersion: tls.VersionTLS12,
}
if forwarding {
tlsConfig.ServerName = parsedCert.Subject.CommonName
} }
return tlsConfig, nil return tlsConfig, nil

View File

@@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/physical"
@@ -100,7 +101,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
checkListenersFunc := func(expectFail bool) { checkListenersFunc := func(expectFail bool) {
tlsConfig, err := cores[0].ClusterTLSConfig() tlsConfig, err := cores[0].ClusterTLSConfig()
if err != nil { if err != nil {
if err.Error() != ErrSealed.Error() { if err.Error() != consts.ErrSealed.Error() {
t.Fatal(err) t.Fatal(err)
} }
tlsConfig = lastTLSConfig tlsConfig = lastTLSConfig

View File

@@ -1,9 +1,9 @@
package vault package vault
import ( import (
"bytes"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/subtle"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
@@ -23,6 +23,7 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
@@ -56,17 +57,14 @@ const (
// leaderPrefixCleanDelay is how long to wait between deletions // leaderPrefixCleanDelay is how long to wait between deletions
// of orphaned leader keys, to prevent slamming the backend. // of orphaned leader keys, to prevent slamming the backend.
leaderPrefixCleanDelay = 200 * time.Millisecond leaderPrefixCleanDelay = 200 * time.Millisecond
// coreKeyringCanaryPath is used as a canary to indicate to replicated
// clusters that they need to perform a rekey operation synchronously; this
// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
coreKeyringCanaryPath = "core/canary-keyring"
) )
var ( var (
// ErrSealed is returned if an operation is performed on
// a sealed barrier. No operation is expected to succeed before unsealing
ErrSealed = errors.New("Vault is sealed")
// ErrStandby is returned if an operation is performed on
// a standby Vault. No operation is expected to succeed until active.
ErrStandby = errors.New("Vault is in standby mode")
// ErrAlreadyInit is returned if the core is already // ErrAlreadyInit is returned if the core is already
// initialized. This prevents a re-initialization. // initialized. This prevents a re-initialization.
ErrAlreadyInit = errors.New("Vault is already initialized") ErrAlreadyInit = errors.New("Vault is already initialized")
@@ -87,6 +85,12 @@ var (
// step down of the active node, to prevent instantly regrabbing the lock. // step down of the active node, to prevent instantly regrabbing the lock.
// It's var not const so that tests can manipulate it. // It's var not const so that tests can manipulate it.
manualStepDownSleepPeriod = 10 * time.Second manualStepDownSleepPeriod = 10 * time.Second
// Functions only in the Enterprise version
enterprisePostUnseal = enterprisePostUnsealImpl
enterprisePreSeal = enterprisePreSealImpl
startReplication = startReplicationImpl
stopReplication = stopReplicationImpl
) )
// ReloadFunc are functions that are called when a reload is requested. // ReloadFunc are functions that are called when a reload is requested.
@@ -133,6 +137,11 @@ type unlockInformation struct {
// interface for API handlers and is responsible for managing the logical and physical // interface for API handlers and is responsible for managing the logical and physical
// backends, router, security barrier, and audit trails. // backends, router, security barrier, and audit trails.
type Core struct { type Core struct {
// N.B.: This is used to populate a dev token down replication, as
// otherwise, after replication is started, a dev would have to go through
// the generate-root process simply to talk to the new follower cluster.
devToken string
// HABackend may be available depending on the physical backend // HABackend may be available depending on the physical backend
ha physical.HABackend ha physical.HABackend
@@ -268,7 +277,7 @@ type Core struct {
// //
// Name // Name
clusterName string clusterName string
// Used to modify cluster TLS params // Used to modify cluster parameters
clusterParamsLock sync.RWMutex clusterParamsLock sync.RWMutex
// The private key stored in the barrier used for establishing // The private key stored in the barrier used for establishing
// mutually-authenticated connections between Vault cluster members // mutually-authenticated connections between Vault cluster members
@@ -310,11 +319,13 @@ type Core struct {
// replicationState keeps the current replication state cached for quick // replicationState keeps the current replication state cached for quick
// lookup // lookup
replicationState logical.ReplicationState replicationState consts.ReplicationState
} }
// CoreConfig is used to parameterize a core // CoreConfig is used to parameterize a core
type CoreConfig struct { type CoreConfig struct {
DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"`
LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"` LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`
CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"` CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
@@ -390,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) {
conf.Logger = logformat.NewVaultLogger(log.LevelTrace) conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
} }
// Setup the core
c := &Core{
redirectAddr: conf.RedirectAddr,
clusterAddr: conf.ClusterAddr,
physical: conf.Physical,
seal: conf.Seal,
router: NewRouter(),
sealed: true,
standby: true,
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the physical backend in a cache layer if enabled and not already wrapped
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if !conf.DisableMlock { if !conf.DisableMlock {
// Ensure our memory usage is locked into physical RAM // Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil { if err := mlock.LockMemory(); err != nil {
@@ -407,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
} }
// Construct a new AES-GCM barrier // Construct a new AES-GCM barrier
barrier, err := NewAESGCMBarrier(conf.Physical) var err error
c.barrier, err = NewAESGCMBarrier(c.physical)
if err != nil { if err != nil {
return nil, fmt.Errorf("barrier setup failed: %v", err) return nil, fmt.Errorf("barrier setup failed: %v", err)
} }
// Setup the core
c := &Core{
redirectAddr: conf.RedirectAddr,
clusterAddr: conf.ClusterAddr,
physical: conf.Physical,
seal: conf.Seal,
barrier: barrier,
router: NewRouter(),
sealed: true,
standby: true,
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterCertPool: x509.NewCertPool(),
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
}
// Wrap the backend in a cache unless disabled
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
}
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() { if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
c.ha = conf.HAPhysical c.ha = conf.HAPhysical
} }
@@ -518,10 +529,10 @@ func (c *Core) LookupToken(token string) (*TokenEntry, error) {
c.stateLock.RLock() c.stateLock.RLock()
defer c.stateLock.RUnlock() defer c.stateLock.RUnlock()
if c.sealed { if c.sealed {
return nil, ErrSealed return nil, consts.ErrSealed
} }
if c.standby { if c.standby {
return nil, ErrStandby return nil, consts.ErrStandby
} }
// Many tests don't have a token store running // Many tests don't have a token store running
@@ -656,7 +667,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
// Check if sealed // Check if sealed
if c.sealed { if c.sealed {
return false, "", ErrSealed return false, "", consts.ErrSealed
} }
// Check if HA enabled // Check if HA enabled
@@ -803,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) {
return true, nil return true, nil
} }
masterKey, err := c.unsealPart(config, key)
if err != nil {
return false, err
}
if masterKey != nil {
return c.unsealInternal(masterKey)
}
return false, nil
}
func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) {
// Check if we already have this piece // Check if we already have this piece
if c.unlockInfo != nil { if c.unlockInfo != nil {
for _, existing := range c.unlockInfo.Parts { for _, existing := range c.unlockInfo.Parts {
if bytes.Equal(existing, key) { if subtle.ConstantTimeCompare(existing, key) == 1 {
return false, nil return nil, nil
} }
} }
} else { } else {
uuid, err := uuid.GenerateUUID() uuid, err := uuid.GenerateUUID()
if err != nil { if err != nil {
return false, err return nil, err
} }
c.unlockInfo = &unlockInformation{ c.unlockInfo = &unlockInformation{
Nonce: uuid, Nonce: uuid,
@@ -828,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) {
if c.logger.IsDebug() { if c.logger.IsDebug() {
c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce) c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
} }
return false, nil return nil, nil
} }
// Best-effort memzero of unlock parts once we're done with them
defer func() {
for i, _ := range c.unlockInfo.Parts {
memzero(c.unlockInfo.Parts[i])
}
c.unlockInfo = nil
}()
// Recover the master key // Recover the master key
var masterKey []byte var masterKey []byte
var err error
if config.SecretThreshold == 1 { if config.SecretThreshold == 1 {
masterKey = c.unlockInfo.Parts[0] masterKey = make([]byte, len(c.unlockInfo.Parts[0]))
c.unlockInfo = nil copy(masterKey, c.unlockInfo.Parts[0])
} else { } else {
masterKey, err = shamir.Combine(c.unlockInfo.Parts) masterKey, err = shamir.Combine(c.unlockInfo.Parts)
c.unlockInfo = nil
if err != nil { if err != nil {
return false, fmt.Errorf("failed to compute master key: %v", err) return nil, fmt.Errorf("failed to compute master key: %v", err)
} }
} }
defer memzero(masterKey)
return c.unsealInternal(masterKey) return masterKey, nil
} }
// This must be called with the state write lock held
func (c *Core) unsealInternal(masterKey []byte) (bool, error) { func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
defer memzero(masterKey)
// Attempt to unlock // Attempt to unlock
if err := c.barrier.Unseal(masterKey); err != nil { if err := c.barrier.Unseal(masterKey); err != nil {
return false, err return false, err
@@ -867,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
c.logger.Warn("core: vault is sealed") c.logger.Warn("core: vault is sealed")
return false, err return false, err
} }
if err := c.postUnseal(); err != nil { if err := c.postUnseal(); err != nil {
c.logger.Error("core: post-unseal setup failed", "error", err) c.logger.Error("core: post-unseal setup failed", "error", err)
c.barrier.Seal() c.barrier.Seal()
c.logger.Warn("core: vault is sealed") c.logger.Warn("core: vault is sealed")
return false, err return false, err
} }
c.standby = false c.standby = false
} else { } else {
// Go to standby mode, wait until we are active to unseal // Go to standby mode, wait until we are active to unseal
@@ -1168,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) {
if purgable, ok := c.physical.(physical.Purgable); ok { if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge() purgable.Purge()
} }
// HA mode requires us to handle keyring rotation and rekeying // HA mode requires us to handle keyring rotation and rekeying
if c.ha != nil { if c.ha != nil {
// We want to reload these from disk so that in case of a rekey we're // We want to reload these from disk so that in case of a rekey we're
@@ -1190,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) {
return err return err
} }
} }
if err := enterprisePostUnseal(c); err != nil {
return err
}
if err := c.ensureWrappingKey(); err != nil { if err := c.ensureWrappingKey(); err != nil {
return err return err
} }
@@ -1251,6 +1290,7 @@ func (c *Core) preSeal() error {
c.metricsCh = nil c.metricsCh = nil
} }
var result error var result error
if c.ha != nil { if c.ha != nil {
c.stopClusterListener() c.stopClusterListener()
} }
@@ -1273,6 +1313,10 @@ func (c *Core) preSeal() error {
if err := c.unloadMounts(); err != nil { if err := c.unloadMounts(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err)) result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
} }
if err := enterprisePreSeal(c); err != nil {
result = multierror.Append(result, err)
}
// Purge the backend if supported // Purge the backend if supported
if purgable, ok := c.physical.(physical.Purgable); ok { if purgable, ok := c.physical.(physical.Purgable); ok {
purgable.Purge() purgable.Purge()
@@ -1281,6 +1325,22 @@ func (c *Core) preSeal() error {
return result return result
} }
func enterprisePostUnsealImpl(c *Core) error {
return nil
}
func enterprisePreSealImpl(c *Core) error {
return nil
}
func startReplicationImpl(c *Core) error {
return nil
}
func stopReplicationImpl(c *Core) error {
return nil
}
// runStandby is a long running routine that is used when an HA backend // runStandby is a long running routine that is used when an HA backend
// is enabled. It waits until we are leader and switches this Vault to // is enabled. It waits until we are leader and switches this Vault to
// active. // active.
@@ -1599,6 +1659,14 @@ func (c *Core) emitMetrics(stopCh chan struct{}) {
} }
} }
func (c *Core) ReplicationState() consts.ReplicationState {
var state consts.ReplicationState
c.clusterParamsLock.RLock()
state = c.replicationState
c.clusterParamsLock.RUnlock()
return state
}
func (c *Core) SealAccess() *SealAccess { func (c *Core) SealAccess() *SealAccess {
sa := &SealAccess{} sa := &SealAccess{}
sa.SetSeal(c.seal) sa.SetSeal(c.seal)

View File

@@ -8,6 +8,7 @@ import (
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/physical"
@@ -198,7 +199,7 @@ func TestCore_Route_Sealed(t *testing.T) {
Path: "sys/mounts", Path: "sys/mounts",
} }
_, err := c.HandleRequest(req) _, err := c.HandleRequest(req)
if err != ErrSealed { if err != consts.ErrSealed {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -1541,7 +1542,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
// Request should fail in standby mode // Request should fail in standby mode
_, err = core2.HandleRequest(req) _, err = core2.HandleRequest(req)
if err != ErrStandby { if err != consts.ErrStandby {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@@ -3,6 +3,7 @@ package vault
import ( import (
"time" "time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@@ -79,8 +80,8 @@ func (d dynamicSystemView) CachingDisabled() bool {
} }
// Checks if this is a primary Vault instance. // Checks if this is a primary Vault instance.
func (d dynamicSystemView) ReplicationState() logical.ReplicationState { func (d dynamicSystemView) ReplicationState() consts.ReplicationState {
var state logical.ReplicationState var state consts.ReplicationState
d.core.clusterParamsLock.RLock() d.core.clusterParamsLock.RLock()
state = d.core.replicationState state = d.core.replicationState
d.core.clusterParamsLock.RUnlock() d.core.clusterParamsLock.RUnlock()

View File

@@ -12,6 +12,7 @@ import (
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@@ -125,46 +126,114 @@ func (m *ExpirationManager) Restore() error {
if err != nil { if err != nil {
return fmt.Errorf("failed to scan for leases: %v", err) return fmt.Errorf("failed to scan for leases: %v", err)
} }
m.logger.Debug("expiration: leases collected", "num_existing", len(existing)) m.logger.Debug("expiration: leases collected", "num_existing", len(existing))
// Restore each key // Make the channels used for the worker pool
for i, leaseID := range existing { broker := make(chan string)
if i%500 == 0 { quit := make(chan bool)
m.logger.Trace("expiration: leases loading", "progress", i) // Buffer these channels to prevent deadlocks
} errs := make(chan error, len(existing))
// Load the entry result := make(chan *leaseEntry, len(existing))
le, err := m.loadEntry(leaseID)
if err != nil {
return err
}
// If there is no entry, nothing to restore // Use a wait group
if le == nil { wg := &sync.WaitGroup{}
continue
}
// If there is no expiry time, don't do anything // Create 64 workers to distribute work to
if le.ExpireTime.IsZero() { for i := 0; i < consts.ExpirationRestoreWorkerCount; i++ {
continue wg.Add(1)
} go func() {
defer wg.Done()
// Determine the remaining time to expiration for {
expires := le.ExpireTime.Sub(time.Now()) select {
if expires <= 0 { case leaseID, ok := <-broker:
expires = minRevokeDelay // broker has been closed, we are done
} if !ok {
return
}
// Setup revocation timer le, err := m.loadEntry(leaseID)
m.pending[le.LeaseID] = time.AfterFunc(expires, func() { if err != nil {
m.expireID(le.LeaseID) errs <- err
}) continue
}
// Write results out to the result channel
result <- le
// quit early
case <-quit:
return
}
}
}()
} }
// Distribute the collected keys to the workers in a go routine
wg.Add(1)
go func() {
defer wg.Done()
for i, leaseID := range existing {
if i%500 == 0 {
m.logger.Trace("expiration: leases loading", "progress", i)
}
select {
case <-quit:
return
default:
broker <- leaseID
}
}
// Close the broker, causing worker routines to exit
close(broker)
}()
// Restore each key by pulling from the result chan
for i := 0; i < len(existing); i++ {
select {
case err := <-errs:
// Close all go routines
close(quit)
return err
case le := <-result:
// If there is no entry, nothing to restore
if le == nil {
continue
}
// If there is no expiry time, don't do anything
if le.ExpireTime.IsZero() {
continue
}
// Determine the remaining time to expiration
expires := le.ExpireTime.Sub(time.Now())
if expires <= 0 {
expires = minRevokeDelay
}
// Setup revocation timer
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
m.expireID(le.LeaseID)
})
}
}
// Let all go routines finish
wg.Wait()
if len(m.pending) > 0 { if len(m.pending) > 0 {
if m.logger.IsInfo() { if m.logger.IsInfo() {
m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending)) m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending))
} }
} }
return nil return nil
} }

View File

@@ -2,23 +2,131 @@ package vault
import ( import (
"fmt" "fmt"
"os"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/physical"
log "github.com/mgutz/logxi/v1"
)
var (
testImagePull sync.Once
) )
// mockExpiration returns a mock expiration manager // mockExpiration returns a mock expiration manager
func mockExpiration(t *testing.T) *ExpirationManager { func mockExpiration(t testing.TB) *ExpirationManager {
_, ts, _, _ := TestCoreWithTokenStore(t) _, ts, _, _ := TestCoreWithTokenStore(t)
return ts.expiration return ts.expiration
} }
func mockBackendExpiration(t testing.TB, backend physical.Backend) (*Core, *ExpirationManager) {
c, ts, _, _ := TestCoreWithBackendTokenStore(t, backend)
return c, ts.expiration
}
func BenchmarkExpiration_Restore_Etcd(b *testing.B) {
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
logger := logformat.NewVaultLogger(log.LevelTrace)
physicalBackend, err := physical.NewBackend("etcd", logger, map[string]string{
"address": addr,
"path": randPath,
"max_parallel": "256",
})
if err != nil {
b.Fatalf("err: %s", err)
}
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
}
func BenchmarkExpiration_Restore_Consul(b *testing.B) {
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
logger := logformat.NewVaultLogger(log.LevelTrace)
physicalBackend, err := physical.NewBackend("consul", logger, map[string]string{
"address": addr,
"path": randPath,
"max_parallel": "256",
})
if err != nil {
b.Fatalf("err: %s", err)
}
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
}
func BenchmarkExpiration_Restore_InMem(b *testing.B) {
logger := logformat.NewVaultLogger(log.LevelTrace)
benchmarkExpirationBackend(b, physical.NewInmem(logger), 100000) // 100,000 Leases
}
func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, numLeases int) {
c, exp := mockBackendExpiration(b, physicalBackend)
noop := &NoopBackend{}
view := NewBarrierView(c.barrier, "logical/")
meUUID, err := uuid.GenerateUUID()
if err != nil {
b.Fatal(err)
}
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: meUUID}, view)
// Register fake leases
for i := 0; i < numLeases; i++ {
pathUUID, err := uuid.GenerateUUID()
if err != nil {
b.Fatal(err)
}
req := &logical.Request{
Operation: logical.ReadOperation,
Path: "prod/aws/" + pathUUID,
}
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
TTL: 400 * time.Second,
},
},
Data: map[string]interface{}{
"access_key": "xyz",
"secret_key": "abcd",
},
}
_, err = exp.Register(req, resp)
if err != nil {
b.Fatalf("err: %v", err)
}
}
// Stop everything
err = exp.Stop()
if err != nil {
b.Fatalf("err: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = exp.Restore()
// Restore
if err != nil {
b.Fatalf("err: %v", err)
}
}
b.StopTimer()
}
func TestExpiration_Restore(t *testing.T) { func TestExpiration_Restore(t *testing.T) {
exp := mockExpiration(t) exp := mockExpiration(t)
noop := &NoopBackend{} noop := &NoopBackend{}

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/helper/pgpkeys"
"github.com/hashicorp/vault/helper/xor" "github.com/hashicorp/vault/helper/xor"
"github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/shamir"
@@ -34,10 +35,10 @@ func (c *Core) GenerateRootProgress() (int, error) {
c.stateLock.RLock() c.stateLock.RLock()
defer c.stateLock.RUnlock() defer c.stateLock.RUnlock()
if c.sealed { if c.sealed {
return 0, ErrSealed return 0, consts.ErrSealed
} }
if c.standby { if c.standby {
return 0, ErrStandby return 0, consts.ErrStandby
} }
c.generateRootLock.Lock() c.generateRootLock.Lock()
@@ -52,10 +53,10 @@ func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) {
c.stateLock.RLock() c.stateLock.RLock()
defer c.stateLock.RUnlock() defer c.stateLock.RUnlock()
if c.sealed { if c.sealed {
return nil, ErrSealed return nil, consts.ErrSealed
} }
if c.standby { if c.standby {
return nil, ErrStandby return nil, consts.ErrStandby
} }
c.generateRootLock.Lock() c.generateRootLock.Lock()
@@ -101,10 +102,10 @@ func (c *Core) GenerateRootInit(otp, pgpKey string) error {
c.stateLock.RLock() c.stateLock.RLock()
defer c.stateLock.RUnlock() defer c.stateLock.RUnlock()
if c.sealed { if c.sealed {
return ErrSealed return consts.ErrSealed
} }
if c.standby { if c.standby {
return ErrStandby return consts.ErrStandby
} }
c.generateRootLock.Lock() c.generateRootLock.Lock()
@@ -170,10 +171,10 @@ func (c *Core) GenerateRootUpdate(key []byte, nonce string) (*GenerateRootResult
c.stateLock.RLock() c.stateLock.RLock()
defer c.stateLock.RUnlock() defer c.stateLock.RUnlock()
if c.sealed { if c.sealed {
return nil, ErrSealed return nil, consts.ErrSealed
} }
if c.standby { if c.standby {
return nil, ErrStandby return nil, consts.ErrStandby
} }
c.generateRootLock.Lock() c.generateRootLock.Lock()
@@ -308,10 +309,10 @@ func (c *Core) GenerateRootCancel() error {
c.stateLock.RLock() c.stateLock.RLock()
defer c.stateLock.RUnlock() defer c.stateLock.RUnlock()
if c.sealed { if c.sealed {
return ErrSealed return consts.ErrSealed
} }
if c.standby { if c.standby {
return ErrStandby return consts.ErrStandby
} }
c.generateRootLock.Lock() c.generateRootLock.Lock()

View File

@@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
return nil, fmt.Errorf("error initializing seal: %v", err) return nil, fmt.Errorf("error initializing seal: %v", err)
} }
err = c.seal.SetBarrierConfig(barrierConfig)
if err != nil {
c.logger.Error("core: failed to save barrier configuration", "error", err)
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
}
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig) barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
if err != nil { if err != nil {
c.logger.Error("core: error generating shares", "error", err) c.logger.Error("core: error generating shares", "error", err)
return nil, err return nil, err
} }
// If we are storing shares, pop them out of the returned results and push
// them through the seal
if barrierConfig.StoredShares > 0 {
var keysToStore [][]byte
for i := 0; i < barrierConfig.StoredShares; i++ {
keysToStore = append(keysToStore, barrierUnsealKeys[0])
barrierUnsealKeys = barrierUnsealKeys[1:]
}
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
c.logger.Error("core: failed to store keys", "error", err)
return nil, fmt.Errorf("failed to store keys: %v", err)
}
}
results := &InitResult{
SecretShares: barrierUnsealKeys,
}
// Initialize the barrier // Initialize the barrier
if err := c.barrier.Initialize(barrierKey); err != nil { if err := c.barrier.Initialize(barrierKey); err != nil {
c.logger.Error("core: failed to initialize barrier", "error", err) c.logger.Error("core: failed to initialize barrier", "error", err)
@@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
// Ensure the barrier is re-sealed // Ensure the barrier is re-sealed
defer func() { defer func() {
// Defers are LIFO so we need to run this here too to ensure the stop
// happens before sealing. preSeal also stops, so we just make the
// stopping safe against multiple calls.
if err := c.barrier.Seal(); err != nil { if err := c.barrier.Seal(); err != nil {
c.logger.Error("core: failed to seal barrier", "error", err) c.logger.Error("core: failed to seal barrier", "error", err)
} }
}() }()
err = c.seal.SetBarrierConfig(barrierConfig)
if err != nil {
c.logger.Error("core: failed to save barrier configuration", "error", err)
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
}
// If we are storing shares, pop them out of the returned results and push
// them through the seal
if barrierConfig.StoredShares > 0 {
var keysToStore [][]byte
for i := 0; i < barrierConfig.StoredShares; i++ {
keysToStore = append(keysToStore, barrierUnsealKeys[0])
barrierUnsealKeys = barrierUnsealKeys[1:]
}
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
c.logger.Error("core: failed to store keys", "error", err)
return nil, fmt.Errorf("failed to store keys: %v", err)
}
}
results := &InitResult{
SecretShares: barrierUnsealKeys,
}
// Perform initial setup // Perform initial setup
if err := c.setupCluster(); err != nil { if err := c.setupCluster(); err != nil {
c.logger.Error("core: cluster setup failed during init", "error", err) c.logger.Error("core: cluster setup failed during init", "error", err)

Some files were not shown because too many files have changed in this diff Show More