mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-12-12 15:05:39 +00:00
Merge branch 'master' into acl-parameters-permission
This commit is contained in:
@@ -7,7 +7,7 @@ services:
|
||||
- docker
|
||||
|
||||
go:
|
||||
- 1.8rc2
|
||||
- 1.8
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
|
||||
17
CHANGELOG.md
17
CHANGELOG.md
@@ -1,5 +1,16 @@
|
||||
## Next (Unreleased)
|
||||
|
||||
DEPRECATIONS/CHANGES:
|
||||
|
||||
* List Operations Always Use Trailing Slash: Any list operation, whether via
|
||||
the `GET` or `LIST` HTTP verb, will now internally canonicalize the path to
|
||||
have a trailing slash. This makes policy writing more predictable, as it
|
||||
means clients will no longer work or fail based on which client they're
|
||||
using or which HTTP verb they're using. However, it also means that policies
|
||||
allowing `list` capability must be carefully checked to ensure that they
|
||||
contain a trailing slash; some policies may need to be split into multiple
|
||||
stanzas to accommodate.
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* auth/ldap: Use the value of the `LOGNAME` or `USER` env vars for the
|
||||
@@ -7,14 +18,20 @@ IMPROVEMENTS:
|
||||
[GH-2154]
|
||||
* audit: Support adding a configurable prefix (such as `@cee`) before each
|
||||
line [GH-2359]
|
||||
* core: Canonicalize list operations to use a trailing slash [GH-2390]
|
||||
* secret/pki: O (Organization) values can now be set to role-defined values
|
||||
for issued/signed certificates [GH-2369]
|
||||
|
||||
BUG FIXES:
|
||||
|
||||
* audit: When auditing headers use case-insensitive comparisons [GH-2362]
|
||||
* auth/aws-ec2: Return role period in seconds and not nanoseconds [GH-2374]
|
||||
* auth/okta: Fix panic if user had no local groups and/or policies set
|
||||
[GH-2367]
|
||||
* command/server: Fix parsing of redirect address when port is not mentioned
|
||||
[GH-2354]
|
||||
* physical/postgresql: Fix listing returning incorrect results if there were
|
||||
multiple levels of children [GH-2393]
|
||||
|
||||
## 0.6.5 (February 7th, 2017)
|
||||
|
||||
|
||||
5
Makefile
5
Makefile
@@ -24,6 +24,11 @@ dev-dynamic: generate
|
||||
test: generate
|
||||
CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=10m -parallel=4
|
||||
|
||||
testcompile: generate
|
||||
@for pkg in $(TEST) ; do \
|
||||
go test -v -c -tags='$(BUILD_TAGS)' $$pkg -parallel=4 ; \
|
||||
done
|
||||
|
||||
# testacc runs acceptance tests
|
||||
testacc: generate
|
||||
@if [ "$(TEST)" = "./..." ]; then \
|
||||
|
||||
@@ -56,9 +56,9 @@ All documentation is available on the [Vault website](https://www.vaultproject.i
|
||||
Developing Vault
|
||||
--------------------
|
||||
|
||||
If you wish to work on Vault itself or any of its built-in systems,
|
||||
you'll first need [Go](https://www.golang.org) installed on your
|
||||
machine (version 1.8+ is *required*).
|
||||
If you wish to work on Vault itself or any of its built-in systems, you'll
|
||||
first need [Go](https://www.golang.org) installed on your machine (version 1.8+
|
||||
is *required*).
|
||||
|
||||
For local dev first make sure Go is properly installed, including setting up a
|
||||
[GOPATH](https://golang.org/doc/code.html#GOPATH). Next, clone this repository
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
@@ -71,13 +72,18 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) {
|
||||
return mounts, nil
|
||||
}
|
||||
|
||||
// DEPRECATED: Use EnableAuditWithOptions instead
|
||||
func (c *Sys) EnableAudit(
|
||||
path string, auditType string, desc string, opts map[string]string) error {
|
||||
body := map[string]interface{}{
|
||||
"type": auditType,
|
||||
"description": desc,
|
||||
"options": opts,
|
||||
}
|
||||
return c.EnableAuditWithOptions(path, &EnableAuditOptions{
|
||||
Type: auditType,
|
||||
Description: desc,
|
||||
Options: opts,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) error {
|
||||
body := structs.Map(options)
|
||||
|
||||
r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/sys/audit/%s", path))
|
||||
if err := r.SetJSONBody(body); err != nil {
|
||||
@@ -106,9 +112,17 @@ func (c *Sys) DisableAudit(path string) error {
|
||||
// individually documented because the map almost directly to the raw HTTP API
|
||||
// documentation. Please refer to that documentation for more details.
|
||||
|
||||
type EnableAuditOptions struct {
|
||||
Type string `json:"type" structs:"type"`
|
||||
Description string `json:"description" structs:"description"`
|
||||
Options map[string]string `json:"options" structs:"options"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
}
|
||||
|
||||
type Audit struct {
|
||||
Path string
|
||||
Type string
|
||||
Description string
|
||||
Options map[string]string
|
||||
Local bool
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
@@ -42,11 +43,16 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
|
||||
return mounts, nil
|
||||
}
|
||||
|
||||
// DEPRECATED: Use EnableAuthWithOptions instead
|
||||
func (c *Sys) EnableAuth(path, authType, desc string) error {
|
||||
body := map[string]string{
|
||||
"type": authType,
|
||||
"description": desc,
|
||||
}
|
||||
return c.EnableAuthWithOptions(path, &EnableAuthOptions{
|
||||
Type: authType,
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Sys) EnableAuthWithOptions(path string, options *EnableAuthOptions) error {
|
||||
body := structs.Map(options)
|
||||
|
||||
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/auth/%s", path))
|
||||
if err := r.SetJSONBody(body); err != nil {
|
||||
@@ -75,10 +81,17 @@ func (c *Sys) DisableAuth(path string) error {
|
||||
// individually documentd because the map almost directly to the raw HTTP API
|
||||
// documentation. Please refer to that documentation for more details.
|
||||
|
||||
type EnableAuthOptions struct {
|
||||
Type string `json:"type" structs:"type"`
|
||||
Description string `json:"description" structs:"description"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
}
|
||||
|
||||
type AuthMount struct {
|
||||
Type string `json:"type" structs:"type" mapstructure:"type"`
|
||||
Description string `json:"description" structs:"description" mapstructure:"description"`
|
||||
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
|
||||
Local bool `json:"local" structs:"local" mapstructure:"local"`
|
||||
}
|
||||
|
||||
type AuthConfigOutput struct {
|
||||
|
||||
@@ -123,6 +123,7 @@ type MountInput struct {
|
||||
Type string `json:"type" structs:"type"`
|
||||
Description string `json:"description" structs:"description"`
|
||||
Config MountConfigInput `json:"config" structs:"config"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
}
|
||||
|
||||
type MountConfigInput struct {
|
||||
@@ -134,6 +135,7 @@ type MountOutput struct {
|
||||
Type string `json:"type" structs:"type"`
|
||||
Description string `json:"description" structs:"description"`
|
||||
Config MountConfigOutput `json:"config" structs:"config"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
}
|
||||
|
||||
type MountConfigOutput struct {
|
||||
|
||||
121
audit/format.go
121
audit/format.go
@@ -27,7 +27,11 @@ func (f *AuditFormatter) FormatRequest(
|
||||
config FormatterConfig,
|
||||
auth *logical.Auth,
|
||||
req *logical.Request,
|
||||
err error) error {
|
||||
inErr error) error {
|
||||
|
||||
if req == nil {
|
||||
return fmt.Errorf("request to request-audit a nil request")
|
||||
}
|
||||
|
||||
if w == nil {
|
||||
return fmt.Errorf("writer for audit request is nil")
|
||||
@@ -49,22 +53,26 @@ func (f *AuditFormatter) FormatRequest(
|
||||
}()
|
||||
}
|
||||
|
||||
// Copy the structures
|
||||
cp, err := copystructure.Copy(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
// Copy the auth structure
|
||||
if auth != nil {
|
||||
cp, err := copystructure.Copy(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
auth = cp.(*logical.Auth)
|
||||
}
|
||||
auth = cp.(*logical.Auth)
|
||||
|
||||
cp, err = copystructure.Copy(req)
|
||||
cp, err := copystructure.Copy(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req = cp.(*logical.Request)
|
||||
|
||||
// Hash any sensitive information
|
||||
if err := Hash(config.Salt, auth); err != nil {
|
||||
return err
|
||||
if auth != nil {
|
||||
if err := Hash(config.Salt, auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Cache and restore accessor in the request
|
||||
@@ -85,8 +93,8 @@ func (f *AuditFormatter) FormatRequest(
|
||||
auth = new(logical.Auth)
|
||||
}
|
||||
var errString string
|
||||
if err != nil {
|
||||
errString = err.Error()
|
||||
if inErr != nil {
|
||||
errString = inErr.Error()
|
||||
}
|
||||
|
||||
reqEntry := &AuditRequestEntry{
|
||||
@@ -107,6 +115,7 @@ func (f *AuditFormatter) FormatRequest(
|
||||
Path: req.Path,
|
||||
Data: req.Data,
|
||||
RemoteAddr: getRemoteAddr(req),
|
||||
ReplicationCluster: req.ReplicationCluster,
|
||||
Headers: req.Headers,
|
||||
},
|
||||
}
|
||||
@@ -128,7 +137,11 @@ func (f *AuditFormatter) FormatResponse(
|
||||
auth *logical.Auth,
|
||||
req *logical.Request,
|
||||
resp *logical.Response,
|
||||
err error) error {
|
||||
inErr error) error {
|
||||
|
||||
if req == nil {
|
||||
return fmt.Errorf("request to response-audit a nil request")
|
||||
}
|
||||
|
||||
if w == nil {
|
||||
return fmt.Errorf("writer for audit request is nil")
|
||||
@@ -150,37 +163,43 @@ func (f *AuditFormatter) FormatResponse(
|
||||
}()
|
||||
}
|
||||
|
||||
// Copy the structure
|
||||
cp, err := copystructure.Copy(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
// Copy the auth structure
|
||||
if auth != nil {
|
||||
cp, err := copystructure.Copy(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
auth = cp.(*logical.Auth)
|
||||
}
|
||||
auth = cp.(*logical.Auth)
|
||||
|
||||
cp, err = copystructure.Copy(req)
|
||||
cp, err := copystructure.Copy(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req = cp.(*logical.Request)
|
||||
|
||||
cp, err = copystructure.Copy(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
if resp != nil {
|
||||
cp, err := copystructure.Copy(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp = cp.(*logical.Response)
|
||||
}
|
||||
resp = cp.(*logical.Response)
|
||||
|
||||
// Hash any sensitive information
|
||||
|
||||
// Cache and restore accessor in the auth
|
||||
var accessor, wrappedAccessor string
|
||||
if !config.HMACAccessor && auth != nil && auth.Accessor != "" {
|
||||
accessor = auth.Accessor
|
||||
}
|
||||
if err := Hash(config.Salt, auth); err != nil {
|
||||
return err
|
||||
}
|
||||
if accessor != "" {
|
||||
auth.Accessor = accessor
|
||||
if auth != nil {
|
||||
var accessor string
|
||||
if !config.HMACAccessor && auth.Accessor != "" {
|
||||
accessor = auth.Accessor
|
||||
}
|
||||
if err := Hash(config.Salt, auth); err != nil {
|
||||
return err
|
||||
}
|
||||
if accessor != "" {
|
||||
auth.Accessor = accessor
|
||||
}
|
||||
}
|
||||
|
||||
// Cache and restore accessor in the request
|
||||
@@ -196,21 +215,23 @@ func (f *AuditFormatter) FormatResponse(
|
||||
}
|
||||
|
||||
// Cache and restore accessor in the response
|
||||
accessor = ""
|
||||
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
|
||||
accessor = resp.Auth.Accessor
|
||||
}
|
||||
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
|
||||
wrappedAccessor = resp.WrapInfo.WrappedAccessor
|
||||
}
|
||||
if err := Hash(config.Salt, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
if accessor != "" {
|
||||
resp.Auth.Accessor = accessor
|
||||
}
|
||||
if wrappedAccessor != "" {
|
||||
resp.WrapInfo.WrappedAccessor = wrappedAccessor
|
||||
if resp != nil {
|
||||
var accessor, wrappedAccessor string
|
||||
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
|
||||
accessor = resp.Auth.Accessor
|
||||
}
|
||||
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
|
||||
wrappedAccessor = resp.WrapInfo.WrappedAccessor
|
||||
}
|
||||
if err := Hash(config.Salt, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
if accessor != "" {
|
||||
resp.Auth.Accessor = accessor
|
||||
}
|
||||
if wrappedAccessor != "" {
|
||||
resp.WrapInfo.WrappedAccessor = wrappedAccessor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,8 +243,8 @@ func (f *AuditFormatter) FormatResponse(
|
||||
resp = new(logical.Response)
|
||||
}
|
||||
var errString string
|
||||
if err != nil {
|
||||
errString = err.Error()
|
||||
if inErr != nil {
|
||||
errString = inErr.Error()
|
||||
}
|
||||
|
||||
var respAuth *AuditAuth
|
||||
@@ -276,6 +297,7 @@ func (f *AuditFormatter) FormatResponse(
|
||||
Path: req.Path,
|
||||
Data: req.Data,
|
||||
RemoteAddr: getRemoteAddr(req),
|
||||
ReplicationCluster: req.ReplicationCluster,
|
||||
Headers: req.Headers,
|
||||
},
|
||||
|
||||
@@ -312,14 +334,15 @@ type AuditRequestEntry struct {
|
||||
type AuditResponseEntry struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Error string `json:"error"`
|
||||
Auth AuditAuth `json:"auth"`
|
||||
Request AuditRequest `json:"request"`
|
||||
Response AuditResponse `json:"response"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type AuditRequest struct {
|
||||
ID string `json:"id"`
|
||||
ReplicationCluster string `json:"replication_cluster,omitempty"`
|
||||
Operation logical.Operation `json:"operation"`
|
||||
ClientToken string `json:"client_token"`
|
||||
ClientTokenAccessor string `json:"client_token_accessor"`
|
||||
|
||||
55
audit/format_test.go
Normal file
55
audit/format_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
type noopFormatWriter struct {
|
||||
}
|
||||
|
||||
func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFormatRequestErrors(t *testing.T) {
|
||||
salter, _ := salt.NewSalt(nil, nil)
|
||||
config := FormatterConfig{
|
||||
Salt: salter,
|
||||
}
|
||||
formatter := AuditFormatter{
|
||||
AuditFormatWriter: &noopFormatWriter{},
|
||||
}
|
||||
|
||||
if err := formatter.FormatRequest(ioutil.Discard, config, nil, nil, nil); err == nil {
|
||||
t.Fatal("expected error due to nil request")
|
||||
}
|
||||
if err := formatter.FormatRequest(nil, config, nil, &logical.Request{}, nil); err == nil {
|
||||
t.Fatal("expected error due to nil writer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResponseErrors(t *testing.T) {
|
||||
salter, _ := salt.NewSalt(nil, nil)
|
||||
config := FormatterConfig{
|
||||
Salt: salter,
|
||||
}
|
||||
formatter := AuditFormatter{
|
||||
AuditFormatWriter: &noopFormatWriter{},
|
||||
}
|
||||
|
||||
if err := formatter.FormatResponse(ioutil.Discard, config, nil, nil, nil, nil); err == nil {
|
||||
t.Fatal("expected error due to nil request")
|
||||
}
|
||||
if err := formatter.FormatResponse(nil, config, nil, &logical.Request{}, nil, nil); err == nil {
|
||||
t.Fatal("expected error due to nil writer")
|
||||
}
|
||||
}
|
||||
@@ -157,10 +157,15 @@ func (b *Backend) open() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Change the file mode in case the log file already existed
|
||||
err = os.Chmod(b.path, b.mode)
|
||||
if err != nil {
|
||||
return err
|
||||
// Change the file mode in case the log file already existed. We special
|
||||
// case /dev/null since we can't chmod it
|
||||
switch b.path {
|
||||
case "/dev/null":
|
||||
default:
|
||||
err = os.Chmod(b.path, b.mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -17,20 +17,10 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||
// Initialize the salt
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA1Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var b backend
|
||||
b.Salt = salt
|
||||
b.MapAppId = &framework.PolicyMap{
|
||||
PathMap: framework.PathMap{
|
||||
Name: "app-id",
|
||||
Salt: salt,
|
||||
Schema: map[string]*framework.FieldSchema{
|
||||
"display_name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
@@ -48,7 +38,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||
|
||||
b.MapUserId = &framework.PathMap{
|
||||
Name: "user-id",
|
||||
Salt: salt,
|
||||
Schema: map[string]*framework.FieldSchema{
|
||||
"cidr_block": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
@@ -81,17 +70,11 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||
),
|
||||
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
|
||||
Init: b.initialize,
|
||||
}
|
||||
|
||||
// Since the salt is new in 0.2, we need to handle this by migrating
|
||||
// any existing keys to use the salt. We can deprecate this eventually,
|
||||
// but for now we want a smooth upgrade experience by automatically
|
||||
// upgrading to use salting.
|
||||
if salt.DidGenerate() {
|
||||
if err := b.upgradeToSalted(conf.StorageView); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
b.view = conf.StorageView
|
||||
|
||||
return b.Backend, nil
|
||||
}
|
||||
@@ -100,10 +83,36 @@ type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
Salt *salt.Salt
|
||||
view logical.Storage
|
||||
MapAppId *framework.PolicyMap
|
||||
MapUserId *framework.PathMap
|
||||
}
|
||||
|
||||
func (b *backend) initialize() error {
|
||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||
HashFunc: salt.SHA1Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.Salt = salt
|
||||
|
||||
b.MapAppId.Salt = salt
|
||||
b.MapUserId.Salt = salt
|
||||
|
||||
// Since the salt is new in 0.2, we need to handle this by migrating
|
||||
// any existing keys to use the salt. We can deprecate this eventually,
|
||||
// but for now we want a smooth upgrade experience by automatically
|
||||
// upgrading to use salting.
|
||||
if salt.DidGenerate() {
|
||||
if err := b.upgradeToSalted(b.view); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeToSalted is used to upgrade the non-salted keys prior to
|
||||
// Vault 0.2 to be salted. This is done on mount time and is only
|
||||
// done once. It can be deprecated eventually, but should be around
|
||||
|
||||
@@ -72,6 +72,10 @@ func TestBackend_upgradeToSalted(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
err = backend.Initialize()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the keys have been upgraded
|
||||
out, err := inm.Get("struct/map/app-id/foo")
|
||||
|
||||
@@ -17,6 +17,9 @@ type backend struct {
|
||||
// by this backend.
|
||||
salt *salt.Salt
|
||||
|
||||
// The view to use when creating the salt
|
||||
view logical.Storage
|
||||
|
||||
// Guard to clean-up the expired SecretID entries
|
||||
tidySecretIDCASGuard uint32
|
||||
|
||||
@@ -57,18 +60,9 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
// Initialize the salt
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a backend object
|
||||
b := &backend{
|
||||
// Set the salt object for the backend
|
||||
salt: salt,
|
||||
view: conf.StorageView,
|
||||
|
||||
// Create the map of locks to modify the registered roles
|
||||
roleLocksMap: make(map[string]*sync.RWMutex, 257),
|
||||
@@ -83,6 +77,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
secretIDAccessorLocksMap: make(map[string]*sync.RWMutex, 257),
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Create 256 locks each for managing RoleID and SecretIDs. This will avoid
|
||||
// a superfluous number of locks directly proportional to the number of RoleID
|
||||
// and SecretIDs. These locks can be accessed by indexing based on the first two
|
||||
@@ -129,10 +125,22 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
pathTidySecretID(b),
|
||||
},
|
||||
),
|
||||
Init: b.initialize,
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *backend) initialize() error {
|
||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.salt = salt
|
||||
return nil
|
||||
}
|
||||
|
||||
// periodicFunc of the backend will be invoked once a minute by the RollbackManager.
|
||||
// RoleRole backend utilizes this function to delete expired SecretID entries.
|
||||
// This could mean that the SecretID may live in the backend upto 1 min after its
|
||||
|
||||
@@ -21,5 +21,9 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return b, config.StorageView
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
|
||||
return nil, "", metadata, fmt.Errorf("failed to verify the CIDR restrictions set on the role: %v", err)
|
||||
}
|
||||
if !belongs {
|
||||
return nil, "", metadata, fmt.Errorf("source address unauthorized through CIDR restrictions on the role")
|
||||
return nil, "", metadata, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the role", req.Connection.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,7 +199,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
}
|
||||
|
||||
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
|
||||
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err)
|
||||
return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,7 +261,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
||||
}
|
||||
|
||||
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
|
||||
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err)
|
||||
return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ type backend struct {
|
||||
*framework.Backend
|
||||
Salt *salt.Salt
|
||||
|
||||
// Used during initialization to set the salt
|
||||
view logical.Storage
|
||||
|
||||
// Lock to make changes to any of the backend's configuration endpoints.
|
||||
configMutex sync.RWMutex
|
||||
|
||||
@@ -59,18 +62,11 @@ type backend struct {
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := &backend{
|
||||
// Setting the periodic func to be run once in an hour.
|
||||
// If there is a real need, this can be made configurable.
|
||||
tidyCooldownPeriod: time.Hour,
|
||||
Salt: salt,
|
||||
view: conf.StorageView,
|
||||
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
|
||||
IAMClientsMap: make(map[string]map[string]*iam.IAM),
|
||||
}
|
||||
@@ -83,6 +79,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
Unauthenticated: []string{
|
||||
"login",
|
||||
},
|
||||
LocalStorage: []string{
|
||||
"whitelist/identity/",
|
||||
},
|
||||
},
|
||||
Paths: []*framework.Path{
|
||||
pathLogin(b),
|
||||
@@ -104,11 +103,26 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
pathIdentityWhitelist(b),
|
||||
pathTidyIdentityWhitelist(b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
|
||||
Init: b.initialize,
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *backend) initialize() error {
|
||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.Salt = salt
|
||||
return nil
|
||||
}
|
||||
|
||||
// periodicFunc performs the tasks that the backend wishes to do periodically.
|
||||
// Currently this will be triggered once in a minute by the RollbackManager.
|
||||
//
|
||||
@@ -169,6 +183,16 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/client":
|
||||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
b.flushCachedEC2Clients()
|
||||
b.flushCachedIAMClients()
|
||||
}
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
|
||||
created nonce to authenticates the EC2 instance with Vault.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -13,7 +14,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
if err != nil {
|
||||
return b, err
|
||||
}
|
||||
return b, b.populateCRLs(conf.StorageView)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend() *backend {
|
||||
@@ -36,9 +37,10 @@ func Backend() *backend {
|
||||
}),
|
||||
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
}
|
||||
|
||||
b.crls = map[string]CRLInfo{}
|
||||
b.crlUpdateMutex = &sync.RWMutex{}
|
||||
|
||||
return &b
|
||||
@@ -52,6 +54,15 @@ type backend struct {
|
||||
crlUpdateMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch {
|
||||
case strings.HasPrefix(key, "crls/"):
|
||||
b.crlUpdateMutex.Lock()
|
||||
defer b.crlUpdateMutex.Unlock()
|
||||
b.crls = nil
|
||||
}
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The "cert" credential provider allows authentication using
|
||||
TLS client certificates. A client connects to Vault and uses
|
||||
|
||||
@@ -45,6 +45,12 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
||||
b.crlUpdateMutex.Lock()
|
||||
defer b.crlUpdateMutex.Unlock()
|
||||
|
||||
if b.crls != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.crls = map[string]CRLInfo{}
|
||||
|
||||
keys, err := storage.List("crls/")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listing CRLs: %v", err)
|
||||
@@ -56,6 +62,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
||||
for _, key := range keys {
|
||||
entry, err := storage.Get("crls/" + key)
|
||||
if err != nil {
|
||||
b.crls = nil
|
||||
return fmt.Errorf("error loading CRL %s: %v", key, err)
|
||||
}
|
||||
if entry == nil {
|
||||
@@ -64,6 +71,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
||||
var crlInfo CRLInfo
|
||||
err = entry.DecodeJSON(&crlInfo)
|
||||
if err != nil {
|
||||
b.crls = nil
|
||||
return fmt.Errorf("error decoding CRL %s: %v", key, err)
|
||||
}
|
||||
b.crls[key] = crlInfo
|
||||
@@ -121,6 +129,10 @@ func (b *backend) pathCRLDelete(
|
||||
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
|
||||
}
|
||||
|
||||
if err := b.populateCRLs(req.Storage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.crlUpdateMutex.Lock()
|
||||
defer b.crlUpdateMutex.Unlock()
|
||||
|
||||
@@ -131,8 +143,7 @@ func (b *backend) pathCRLDelete(
|
||||
)), nil
|
||||
}
|
||||
|
||||
err := req.Storage.Delete("crls/" + name)
|
||||
if err != nil {
|
||||
if err := req.Storage.Delete("crls/" + name); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"error deleting crl %s: %v", name, err),
|
||||
), nil
|
||||
@@ -150,6 +161,10 @@ func (b *backend) pathCRLRead(
|
||||
return logical.ErrorResponse(`"name" parameter must be set`), nil
|
||||
}
|
||||
|
||||
if err := b.populateCRLs(req.Storage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.crlUpdateMutex.RLock()
|
||||
defer b.crlUpdateMutex.RUnlock()
|
||||
|
||||
@@ -185,6 +200,10 @@ func (b *backend) pathCRLWrite(
|
||||
return logical.ErrorResponse("parsed CRL is nil"), nil
|
||||
}
|
||||
|
||||
if err := b.populateCRLs(req.Storage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.crlUpdateMutex.Lock()
|
||||
defer b.crlUpdateMutex.Unlock()
|
||||
|
||||
|
||||
@@ -17,6 +17,12 @@ func Backend() *backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
LocalStorage: []string{
|
||||
framework.WALPrefix,
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigRoot(),
|
||||
pathConfigLease(&b),
|
||||
|
||||
@@ -31,6 +31,8 @@ func Backend() *backend {
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
|
||||
Clean: func() {
|
||||
b.ResetDB(nil)
|
||||
},
|
||||
@@ -107,6 +109,13 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
|
||||
b.session = newSession
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB(nil)
|
||||
}
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The Cassandra backend dynamically generates database users.
|
||||
|
||||
|
||||
@@ -421,7 +421,7 @@ seed_provider:
|
||||
parameters:
|
||||
# seeds is actually a comma-delimited list of addresses.
|
||||
# Ex: "<ip1>,<ip2>,<ip3>"
|
||||
- seeds: "172.17.0.2"
|
||||
- seeds: "172.17.0.3"
|
||||
|
||||
# For workloads with more data than can fit in memory, Cassandra's
|
||||
# bottleneck will be reads that need to fetch data from
|
||||
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
|
||||
#
|
||||
# Setting listen_address to 0.0.0.0 is always wrong.
|
||||
#
|
||||
listen_address: 172.17.0.2
|
||||
listen_address: 172.17.0.3
|
||||
|
||||
# Set listen_address OR listen_interface, not both. Interfaces must correspond
|
||||
# to a single address, IP aliasing is not supported.
|
||||
@@ -586,7 +586,7 @@ listen_address: 172.17.0.2
|
||||
|
||||
# Address to broadcast to other Cassandra nodes
|
||||
# Leaving this blank will set it to the same value as listen_address
|
||||
broadcast_address: 172.17.0.2
|
||||
broadcast_address: 172.17.0.3
|
||||
|
||||
# When using multiple physical network interfaces, set this
|
||||
# to true to listen on broadcast_address in addition to
|
||||
@@ -668,7 +668,7 @@ rpc_port: 9160
|
||||
# be set to 0.0.0.0. If left blank, this will be set to the value of
|
||||
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
|
||||
# be set.
|
||||
broadcast_rpc_address: 172.17.0.2
|
||||
broadcast_rpc_address: 172.17.0.3
|
||||
|
||||
# enable or disable keepalive on rpc/native connections
|
||||
rpc_keepalive: true
|
||||
|
||||
@@ -33,6 +33,8 @@ func Backend() *framework.Backend {
|
||||
},
|
||||
|
||||
Clean: b.ResetSession,
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
}
|
||||
|
||||
return b.Backend
|
||||
@@ -97,6 +99,13 @@ func (b *backend) ResetSession() {
|
||||
b.session = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetSession()
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseConfig returns the lease configuration
|
||||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
|
||||
@@ -32,6 +32,8 @@ func Backend() *backend {
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
|
||||
Clean: b.ResetDB,
|
||||
}
|
||||
|
||||
@@ -112,6 +114,13 @@ func (b *backend) ResetDB() {
|
||||
b.db = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB()
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseConfig returns the lease configuration
|
||||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
|
||||
@@ -32,6 +32,8 @@ func Backend() *backend {
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
|
||||
Clean: b.ResetDB,
|
||||
}
|
||||
|
||||
@@ -105,6 +107,13 @@ func (b *backend) ResetDB() {
|
||||
b.db = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB()
|
||||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
|
||||
@@ -29,6 +29,12 @@ func Backend() *backend {
|
||||
"crl/pem",
|
||||
"crl",
|
||||
},
|
||||
|
||||
LocalStorage: []string{
|
||||
"revoked/",
|
||||
"crl",
|
||||
"certs/",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
|
||||
@@ -1478,6 +1478,27 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
}
|
||||
}
|
||||
|
||||
getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
|
||||
var certBundle certutil.CertBundle
|
||||
return func(resp *logical.Response) error {
|
||||
err := mapstructure.Decode(resp.Data, &certBundle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error checking generated certificate: %s", err)
|
||||
}
|
||||
cert := parsedCertBundle.Certificate
|
||||
|
||||
expected := strutil.ParseDedupAndSortStrings(role.Organization, ",")
|
||||
if !reflect.DeepEqual(cert.Subject.Organization, expected) {
|
||||
return fmt.Errorf("Error: returned certificate has Organization of %s but %s was specified in the role.", cert.Subject.Organization, expected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a TestCheckFunc that performs various validity checks on the
|
||||
// returned certificate information, mostly within checkCertsAndPrivateKey
|
||||
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
|
||||
@@ -1755,6 +1776,14 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
roleVals.OU = "foo,bar"
|
||||
addTests(getOuCheck(roleVals))
|
||||
}
|
||||
// Organization tests
|
||||
{
|
||||
roleVals.Organization = "system:masters"
|
||||
addTests(getOrganizationCheck(roleVals))
|
||||
|
||||
roleVals.Organization = "foo,bar"
|
||||
addTests(getOrganizationCheck(roleVals))
|
||||
}
|
||||
// IP SAN tests
|
||||
{
|
||||
issueVals.IPSANs = "127.0.0.1,::1"
|
||||
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
type creationBundle struct {
|
||||
CommonName string
|
||||
OU []string
|
||||
Organization []string
|
||||
DNSNames []string
|
||||
EmailAddresses []string
|
||||
IPAddresses []net.IP
|
||||
@@ -581,6 +582,14 @@ func generateCreationBundle(b *backend,
|
||||
}
|
||||
}
|
||||
|
||||
// Set O (organization) values if specified in the role
|
||||
organization := []string{}
|
||||
{
|
||||
if role.Organization != "" {
|
||||
organization = strutil.ParseDedupAndSortStrings(role.Organization, ",")
|
||||
}
|
||||
}
|
||||
|
||||
// Read in alternate names -- DNS and email addresses
|
||||
dnsNames := []string{}
|
||||
emailAddresses := []string{}
|
||||
@@ -728,6 +737,7 @@ func generateCreationBundle(b *backend,
|
||||
creationBundle := &creationBundle{
|
||||
CommonName: cn,
|
||||
OU: ou,
|
||||
Organization: organization,
|
||||
DNSNames: dnsNames,
|
||||
EmailAddresses: emailAddresses,
|
||||
IPAddresses: ipAddresses,
|
||||
@@ -820,6 +830,7 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle
|
||||
subject := pkix.Name{
|
||||
CommonName: creationInfo.CommonName,
|
||||
OrganizationalUnit: creationInfo.OU,
|
||||
Organization: creationInfo.Organization,
|
||||
}
|
||||
|
||||
certTemplate := &x509.Certificate{
|
||||
@@ -983,6 +994,7 @@ func signCertificate(creationInfo *creationBundle,
|
||||
subject := pkix.Name{
|
||||
CommonName: creationInfo.CommonName,
|
||||
OrganizationalUnit: creationInfo.OU,
|
||||
Organization: creationInfo.Organization,
|
||||
}
|
||||
|
||||
certTemplate := &x509.Certificate{
|
||||
|
||||
@@ -172,6 +172,13 @@ Names. Defaults to true.`,
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
Description: `If set, the OU (OrganizationalUnit) will be set to
|
||||
this value in certificates issued by this role.`,
|
||||
},
|
||||
|
||||
"organization": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
Description: `If set, the O (Organization) will be set to
|
||||
this value in certificates issued by this role.`,
|
||||
},
|
||||
},
|
||||
@@ -336,6 +343,7 @@ func (b *backend) pathRoleCreate(
|
||||
UseCSRCommonName: data.Get("use_csr_common_name").(bool),
|
||||
KeyUsage: data.Get("key_usage").(string),
|
||||
OU: data.Get("ou").(string),
|
||||
Organization: data.Get("organization").(string),
|
||||
}
|
||||
|
||||
if entry.KeyType == "rsa" && entry.KeyBits < 2048 {
|
||||
@@ -451,6 +459,7 @@ type roleEntry struct {
|
||||
MaxPathLength *int `json:",omitempty" structs:",omitempty"`
|
||||
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
|
||||
OU string `json:"ou" structs:"ou" mapstructure:"ou"`
|
||||
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
|
||||
}
|
||||
|
||||
const pathListRolesHelpSyn = `List the existing roles in this backend`
|
||||
|
||||
@@ -34,6 +34,8 @@ func Backend(conf *logical.BackendConfig) *backend {
|
||||
},
|
||||
|
||||
Clean: b.ResetDB,
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
}
|
||||
|
||||
b.logger = conf.Logger
|
||||
@@ -126,6 +128,13 @@ func (b *backend) ResetDB() {
|
||||
b.db = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB()
|
||||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
|
||||
@@ -35,6 +35,8 @@ func Backend() *backend {
|
||||
},
|
||||
|
||||
Clean: b.resetClient,
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
}
|
||||
|
||||
return &b
|
||||
@@ -99,6 +101,13 @@ func (b *backend) resetClient() {
|
||||
b.client = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.resetClient()
|
||||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
view logical.Storage
|
||||
salt *salt.Salt
|
||||
}
|
||||
|
||||
@@ -22,15 +23,8 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var b backend
|
||||
b.salt = salt
|
||||
b.view = conf.StorageView
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
@@ -38,6 +32,10 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
Unauthenticated: []string{
|
||||
"verify",
|
||||
},
|
||||
|
||||
LocalStorage: []string{
|
||||
"otp/",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
@@ -54,10 +52,23 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
secretDynamicKey(&b),
|
||||
secretOTP(&b),
|
||||
},
|
||||
|
||||
Init: b.Initialize,
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
func (b *backend) Initialize() error {
|
||||
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.salt = salt
|
||||
return nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The SSH backend generates credentials allowing clients to establish SSH
|
||||
connections to remote hosts.
|
||||
|
||||
@@ -73,6 +73,10 @@ func TestBackend_allowed_users(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = b.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
roleData := map[string]interface{}{
|
||||
"key_type": "otp",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
@@ -39,6 +41,8 @@ func Backend(conf *logical.BackendConfig) *backend {
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
}
|
||||
|
||||
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
|
||||
@@ -50,3 +54,14 @@ type backend struct {
|
||||
*framework.Backend
|
||||
lm *keysutil.LockManager
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(key string) {
|
||||
if b.Logger().IsTrace() {
|
||||
b.Logger().Trace("transit: invalidating key", "key", key)
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(key, "policy/"):
|
||||
name := strings.TrimPrefix(key, "policy/")
|
||||
b.lm.InvalidatePolicy(name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package command
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
@@ -44,3 +45,42 @@ func TestAuditDisable(t *testing.T) {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditDisableWithOptions(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &AuditDisableCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
"noop",
|
||||
}
|
||||
|
||||
// Run once to get the client
|
||||
c.Run(args)
|
||||
|
||||
// Get the client
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %#v", err)
|
||||
}
|
||||
if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{
|
||||
Type: "noop",
|
||||
Description: "noop",
|
||||
}); err != nil {
|
||||
t.Fatalf("err: %#v", err)
|
||||
}
|
||||
|
||||
// Run again
|
||||
if code := c.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/kv-builder"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
@@ -21,9 +22,11 @@ type AuditEnableCommand struct {
|
||||
|
||||
func (c *AuditEnableCommand) Run(args []string) int {
|
||||
var desc, path string
|
||||
var local bool
|
||||
flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault)
|
||||
flags.StringVar(&desc, "description", "", "")
|
||||
flags.StringVar(&path, "path", "", "")
|
||||
flags.BoolVar(&local, "local", false, "")
|
||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return 1
|
||||
@@ -68,7 +71,12 @@ func (c *AuditEnableCommand) Run(args []string) int {
|
||||
return 1
|
||||
}
|
||||
|
||||
err = client.Sys().EnableAudit(path, auditType, desc, opts)
|
||||
err = client.Sys().EnableAuditWithOptions(path, &api.EnableAuditOptions{
|
||||
Type: auditType,
|
||||
Description: desc,
|
||||
Options: opts,
|
||||
Local: local,
|
||||
})
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error enabling audit backend: %s", err))
|
||||
@@ -113,6 +121,9 @@ Audit Enable Options:
|
||||
is purely for referencing this audit backend. By
|
||||
default this will be the backend type.
|
||||
|
||||
-local Mark the mount as a local mount. Local mounts
|
||||
are not replicated nor (if a secondary)
|
||||
removed by replication.
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
||||
@@ -48,16 +48,19 @@ func (c *AuditListCommand) Run(args []string) int {
|
||||
}
|
||||
sort.Strings(paths)
|
||||
|
||||
columns := []string{"Path | Type | Description | Options"}
|
||||
columns := []string{"Path | Type | Description | Replication Behavior | Options"}
|
||||
for _, path := range paths {
|
||||
audit := audits[path]
|
||||
opts := make([]string, 0, len(audit.Options))
|
||||
for k, v := range audit.Options {
|
||||
opts = append(opts, k+"="+v)
|
||||
}
|
||||
|
||||
replicatedBehavior := "replicated"
|
||||
if audit.Local {
|
||||
replicatedBehavior = "local"
|
||||
}
|
||||
columns = append(columns, fmt.Sprintf(
|
||||
"%s | %s | %s | %s", audit.Path, audit.Type, audit.Description, strings.Join(opts, " ")))
|
||||
"%s | %s | %s | %s | %s", audit.Path, audit.Type, audit.Description, replicatedBehavior, strings.Join(opts, " ")))
|
||||
}
|
||||
|
||||
c.Ui.Output(columnize.SimpleFormat(columns))
|
||||
|
||||
@@ -3,6 +3,7 @@ package command
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
@@ -34,7 +35,11 @@ func TestAuditList(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("err: %#v", err)
|
||||
}
|
||||
if err := client.Sys().EnableAudit("foo", "noop", "", nil); err != nil {
|
||||
if err := client.Sys().EnableAuditWithOptions("foo", &api.EnableAuditOptions{
|
||||
Type: "noop",
|
||||
Description: "noop",
|
||||
Options: nil,
|
||||
}); err != nil {
|
||||
t.Fatalf("err: %#v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ func (c *AuthCommand) listMethods() int {
|
||||
}
|
||||
sort.Strings(paths)
|
||||
|
||||
columns := []string{"Path | Type | Default TTL | Max TTL | Description"}
|
||||
columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
|
||||
for _, path := range paths {
|
||||
auth := auth[path]
|
||||
defTTL := "system"
|
||||
@@ -292,8 +292,12 @@ func (c *AuthCommand) listMethods() int {
|
||||
if auth.Config.MaxLeaseTTL != 0 {
|
||||
maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL)
|
||||
}
|
||||
replicatedBehavior := "replicated"
|
||||
if auth.Local {
|
||||
replicatedBehavior = "local"
|
||||
}
|
||||
columns = append(columns, fmt.Sprintf(
|
||||
"%s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, auth.Description))
|
||||
"%s | %s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, replicatedBehavior, auth.Description))
|
||||
}
|
||||
|
||||
c.Ui.Output(columnize.SimpleFormat(columns))
|
||||
|
||||
@@ -3,6 +3,7 @@ package command
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
@@ -52,3 +53,50 @@ func TestAuthDisable(t *testing.T) {
|
||||
t.Fatal("should not have noop mount")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthDisableWithOptions(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &AuthDisableCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
"noop",
|
||||
}
|
||||
|
||||
// Run the command once to setup the client, it will fail
|
||||
c.Run(args)
|
||||
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := client.Sys().EnableAuthWithOptions("noop", &api.EnableAuthOptions{
|
||||
Type: "noop",
|
||||
Description: "",
|
||||
}); err != nil {
|
||||
t.Fatalf("err: %#v", err)
|
||||
}
|
||||
|
||||
if code := c.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
|
||||
mounts, err := client.Sys().ListAuth()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if _, ok := mounts["noop"]; ok {
|
||||
t.Fatal("should not have noop mount")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
)
|
||||
|
||||
@@ -14,9 +15,11 @@ type AuthEnableCommand struct {
|
||||
|
||||
func (c *AuthEnableCommand) Run(args []string) int {
|
||||
var description, path string
|
||||
var local bool
|
||||
flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault)
|
||||
flags.StringVar(&description, "description", "", "")
|
||||
flags.StringVar(&path, "path", "", "")
|
||||
flags.BoolVar(&local, "local", false, "")
|
||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return 1
|
||||
@@ -44,7 +47,11 @@ func (c *AuthEnableCommand) Run(args []string) int {
|
||||
return 2
|
||||
}
|
||||
|
||||
if err := client.Sys().EnableAuth(path, authType, description); err != nil {
|
||||
if err := client.Sys().EnableAuthWithOptions(path, &api.EnableAuthOptions{
|
||||
Type: authType,
|
||||
Description: description,
|
||||
Local: local,
|
||||
}); err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error: %s", err))
|
||||
return 2
|
||||
@@ -82,6 +89,9 @@ Auth Enable Options:
|
||||
to the type of the mount. This will make the auth
|
||||
provider available at "/auth/<path>"
|
||||
|
||||
-local Mark the mount as a local mount. Local mounts
|
||||
are not replicated nor (if a secondary)
|
||||
removed by replication.
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,13 @@ type MountCommand struct {
|
||||
|
||||
func (c *MountCommand) Run(args []string) int {
|
||||
var description, path, defaultLeaseTTL, maxLeaseTTL string
|
||||
var local bool
|
||||
flags := c.Meta.FlagSet("mount", meta.FlagSetDefault)
|
||||
flags.StringVar(&description, "description", "", "")
|
||||
flags.StringVar(&path, "path", "", "")
|
||||
flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "")
|
||||
flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "")
|
||||
flags.BoolVar(&local, "local", false, "")
|
||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return 1
|
||||
@@ -54,6 +56,7 @@ func (c *MountCommand) Run(args []string) int {
|
||||
DefaultLeaseTTL: defaultLeaseTTL,
|
||||
MaxLeaseTTL: maxLeaseTTL,
|
||||
},
|
||||
Local: local,
|
||||
}
|
||||
|
||||
if err := client.Sys().Mount(path, mountInfo); err != nil {
|
||||
@@ -102,6 +105,10 @@ Mount Options:
|
||||
the previously set value. Set to '0' to
|
||||
explicitly set it to use the global default.
|
||||
|
||||
-local Mark the mount as a local mount. Local mounts
|
||||
are not replicated nor (if a secondary)
|
||||
removed by replication.
|
||||
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func (c *MountsCommand) Run(args []string) int {
|
||||
}
|
||||
sort.Strings(paths)
|
||||
|
||||
columns := []string{"Path | Type | Default TTL | Max TTL | Description"}
|
||||
columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
|
||||
for _, path := range paths {
|
||||
mount := mounts[path]
|
||||
defTTL := "system"
|
||||
@@ -63,8 +63,12 @@ func (c *MountsCommand) Run(args []string) int {
|
||||
case mount.Config.MaxLeaseTTL != 0:
|
||||
maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL)
|
||||
}
|
||||
replicatedBehavior := "replicated"
|
||||
if mount.Local {
|
||||
replicatedBehavior = "local"
|
||||
}
|
||||
columns = append(columns, fmt.Sprintf(
|
||||
"%s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, mount.Description))
|
||||
"%s | %s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, replicatedBehavior, mount.Description))
|
||||
}
|
||||
|
||||
c.Ui.Output(columnize.SimpleFormat(columns))
|
||||
|
||||
@@ -61,7 +61,7 @@ type ServerCommand struct {
|
||||
}
|
||||
|
||||
func (c *ServerCommand) Run(args []string) int {
|
||||
var dev, verifyOnly, devHA bool
|
||||
var dev, verifyOnly, devHA, devTransactional bool
|
||||
var configPath []string
|
||||
var logLevel, devRootTokenID, devListenAddress string
|
||||
flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
|
||||
@@ -70,7 +70,8 @@ func (c *ServerCommand) Run(args []string) int {
|
||||
flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
|
||||
flags.StringVar(&logLevel, "log-level", "info", "")
|
||||
flags.BoolVar(&verifyOnly, "verify-only", false, "")
|
||||
flags.BoolVar(&devHA, "dev-ha", false, "")
|
||||
flags.BoolVar(&devHA, "ha", false, "")
|
||||
flags.BoolVar(&devTransactional, "transactional", false, "")
|
||||
flags.Usage = func() { c.Ui.Output(c.Help()) }
|
||||
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
|
||||
if err := flags.Parse(args); err != nil {
|
||||
@@ -122,7 +123,7 @@ func (c *ServerCommand) Run(args []string) int {
|
||||
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
|
||||
}
|
||||
|
||||
if devHA {
|
||||
if devHA || devTransactional {
|
||||
dev = true
|
||||
}
|
||||
|
||||
@@ -143,7 +144,7 @@ func (c *ServerCommand) Run(args []string) int {
|
||||
// Load the configuration
|
||||
var config *server.Config
|
||||
if dev {
|
||||
config = server.DevConfig(devHA)
|
||||
config = server.DevConfig(devHA, devTransactional)
|
||||
if devListenAddress != "" {
|
||||
config.Listeners[0].Config["address"] = devListenAddress
|
||||
}
|
||||
@@ -235,6 +236,9 @@ func (c *ServerCommand) Run(args []string) int {
|
||||
ClusterName: config.ClusterName,
|
||||
CacheSize: config.CacheSize,
|
||||
}
|
||||
if dev {
|
||||
coreConfig.DevToken = devRootTokenID
|
||||
}
|
||||
|
||||
var disableClustering bool
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type Config struct {
|
||||
}
|
||||
|
||||
// DevConfig is a Config that is used for dev mode of Vault.
|
||||
func DevConfig(ha bool) *Config {
|
||||
func DevConfig(ha, transactional bool) *Config {
|
||||
ret := &Config{
|
||||
DisableCache: false,
|
||||
DisableMlock: true,
|
||||
@@ -63,7 +63,12 @@ func DevConfig(ha bool) *Config {
|
||||
DefaultLeaseTTL: 32 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
if ha {
|
||||
switch {
|
||||
case ha && transactional:
|
||||
ret.Backend.Type = "inmem_transactional_ha"
|
||||
case !ha && transactional:
|
||||
ret.Backend.Type = "inmem_transactional"
|
||||
case ha && !transactional:
|
||||
ret.Backend.Type = "inmem_ha"
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ func TestServer_CommonHA(t *testing.T) {
|
||||
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
|
||||
|
||||
if code := c.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(ui.OutputWriter.String(), "(HA available)") {
|
||||
@@ -61,7 +61,7 @@ func TestServer_GoodSeparateHA(t *testing.T) {
|
||||
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
|
||||
|
||||
if code := c.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(ui.OutputWriter.String(), "HA Backend:") {
|
||||
|
||||
@@ -40,7 +40,7 @@ func (c *StatusCommand) Run(args []string) int {
|
||||
"Key Shares: %d\n"+
|
||||
"Key Threshold: %d\n"+
|
||||
"Unseal Progress: %d\n"+
|
||||
"Unseal Nonce: %v"+
|
||||
"Unseal Nonce: %v\n"+
|
||||
"Version: %s",
|
||||
sealStatus.Sealed,
|
||||
sealStatus.N,
|
||||
|
||||
@@ -14,6 +14,16 @@ func TestCIDRUtil_IPBelongsToCIDR(t *testing.T) {
|
||||
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
|
||||
}
|
||||
|
||||
ip = "10.197.192.6"
|
||||
cidr = "10.197.192.0/18"
|
||||
belongs, err = IPBelongsToCIDR(ip, cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !belongs {
|
||||
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
|
||||
}
|
||||
|
||||
ip = "192.168.25.30"
|
||||
cidr = "192.168.26.30/24"
|
||||
belongs, err = IPBelongsToCIDR(ip, cidr)
|
||||
@@ -44,6 +54,17 @@ func TestCIDRUtil_IPBelongsToCIDRBlocksString(t *testing.T) {
|
||||
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
|
||||
}
|
||||
|
||||
ip = "10.197.192.6"
|
||||
cidrList = "1.2.3.0/8,10.197.192.0/18,10.197.193.0/24"
|
||||
|
||||
belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !belongs {
|
||||
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
|
||||
}
|
||||
|
||||
ip = "192.168.27.29"
|
||||
cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"
|
||||
|
||||
|
||||
7
helper/consts/consts.go
Normal file
7
helper/consts/consts.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package consts
|
||||
|
||||
const (
|
||||
// ExpirationRestoreWorkerCount specifies the numer of workers to use while
|
||||
// restoring leases into the expiration manager
|
||||
ExpirationRestoreWorkerCount = 64
|
||||
)
|
||||
13
helper/consts/error.go
Normal file
13
helper/consts/error.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package consts
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrSealed is returned if an operation is performed on a sealed barrier.
|
||||
// No operation is expected to succeed before unsealing
|
||||
ErrSealed = errors.New("Vault is sealed")
|
||||
|
||||
// ErrStandby is returned if an operation is performed on a standby Vault.
|
||||
// No operation is expected to succeed until active.
|
||||
ErrStandby = errors.New("Vault is in standby mode")
|
||||
)
|
||||
20
helper/consts/replication.go
Normal file
20
helper/consts/replication.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package consts
|
||||
|
||||
type ReplicationState uint32
|
||||
|
||||
const (
|
||||
ReplicationDisabled ReplicationState = iota
|
||||
ReplicationPrimary
|
||||
ReplicationSecondary
|
||||
)
|
||||
|
||||
func (r ReplicationState) String() string {
|
||||
switch r {
|
||||
case ReplicationSecondary:
|
||||
return "secondary"
|
||||
case ReplicationPrimary:
|
||||
return "primary"
|
||||
}
|
||||
|
||||
return "disabled"
|
||||
}
|
||||
@@ -71,6 +71,15 @@ func (lm *LockManager) CacheActive() bool {
|
||||
return lm.cache != nil
|
||||
}
|
||||
|
||||
func (lm *LockManager) InvalidatePolicy(name string) {
|
||||
// Check if it's in our cache. If so, return right away.
|
||||
if lm.CacheActive() {
|
||||
lm.cacheMutex.Lock()
|
||||
defer lm.cacheMutex.Unlock()
|
||||
delete(lm.cache, name)
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
||||
lm.locksMutex.RLock()
|
||||
lock := lm.locks[name]
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/duration"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -206,11 +207,11 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
|
||||
// case of an error.
|
||||
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
|
||||
resp, err := core.HandleRequest(r)
|
||||
if errwrap.Contains(err, vault.ErrStandby.Error()) {
|
||||
if errwrap.Contains(err, consts.ErrStandby.Error()) {
|
||||
respondStandby(core, w, rawReq.URL)
|
||||
return resp, false
|
||||
}
|
||||
if respondErrorCommon(w, resp, err) {
|
||||
if respondErrorCommon(w, r, resp, err) {
|
||||
return resp, false
|
||||
}
|
||||
|
||||
@@ -310,20 +311,7 @@ func requestWrapInfo(r *http.Request, req *logical.Request) (*logical.Request, e
|
||||
}
|
||||
|
||||
func respondError(w http.ResponseWriter, status int, err error) {
|
||||
// Adjust status code when sealed
|
||||
if errwrap.Contains(err, vault.ErrSealed.Error()) {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// Adjust status code on
|
||||
if errwrap.Contains(err, "http: request body too large") {
|
||||
status = http.StatusRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// Allow HTTPCoded error passthrough to specify a code
|
||||
if t, ok := err.(logical.HTTPCodedError); ok {
|
||||
status = t.Code()
|
||||
}
|
||||
logical.AdjustErrorStatusCode(&status, err)
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
@@ -337,42 +325,13 @@ func respondError(w http.ResponseWriter, status int, err error) {
|
||||
enc.Encode(resp)
|
||||
}
|
||||
|
||||
func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool {
|
||||
// If there are no errors return
|
||||
if err == nil && (resp == nil || !resp.IsError()) {
|
||||
func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool {
|
||||
statusCode, newErr := logical.RespondErrorCommon(req, resp, err)
|
||||
if newErr == nil && statusCode == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Start out with internal server error since in most of these cases there
|
||||
// won't be a response so this won't be overridden
|
||||
statusCode := http.StatusInternalServerError
|
||||
// If we actually have a response, start out with bad request
|
||||
if resp != nil {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
|
||||
// Now, check the error itself; if it has a specific logical error, set the
|
||||
// appropriate code
|
||||
if err != nil {
|
||||
switch {
|
||||
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
statusCode = http.StatusBadRequest
|
||||
case errwrap.Contains(err, logical.ErrPermissionDenied.Error()):
|
||||
statusCode = http.StatusForbidden
|
||||
case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()):
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()):
|
||||
statusCode = http.StatusNotFound
|
||||
case errwrap.Contains(err, logical.ErrInvalidRequest.Error()):
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
if resp != nil && resp.IsError() {
|
||||
err = fmt.Errorf("%s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
respondError(w, statusCode, err)
|
||||
respondError(w, statusCode, newErr)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
@@ -80,6 +81,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -88,6 +90,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -96,6 +99,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
@@ -105,6 +109,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -113,6 +118,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -121,6 +127,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -223,7 +230,7 @@ func TestHandler_error(t *testing.T) {
|
||||
// vault.ErrSealed is a special case
|
||||
w3 := httptest.NewRecorder()
|
||||
|
||||
respondError(w3, 400, vault.ErrSealed)
|
||||
respondError(w3, 400, consts.ErrSealed)
|
||||
|
||||
if w3.Code != 503 {
|
||||
t.Fatalf("expected 503, got %d", w3.Code)
|
||||
|
||||
@@ -35,7 +35,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
resp, err := core.HandleRequest(lreq)
|
||||
if err != nil {
|
||||
respondErrorCommon(w, resp, err)
|
||||
respondErrorCommon(w, lreq, resp, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
||||
return nil, http.StatusMethodNotAllowed, nil
|
||||
}
|
||||
|
||||
if op == logical.ListOperation {
|
||||
if !strings.HasSuffix(path, "/") {
|
||||
path += "/"
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the request if we can
|
||||
var data map[string]interface{}
|
||||
if op == logical.UpdateOperation {
|
||||
@@ -109,40 +115,13 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
|
||||
|
||||
// Make the internal request. We attach the connection info
|
||||
// as well in case this is an authentication request that requires
|
||||
// it. Vault core handles stripping this if we need to.
|
||||
// it. Vault core handles stripping this if we need to. This also
|
||||
// handles all error cases; if we hit respondLogical, the request is a
|
||||
// success.
|
||||
resp, ok := request(core, w, r, req)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
switch {
|
||||
case req.Operation == logical.ReadOperation:
|
||||
if resp == nil {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Basically: if we have empty "keys" or no keys at all, 404. This
|
||||
// provides consistency with GET.
|
||||
case req.Operation == logical.ListOperation && resp.WrapInfo == nil:
|
||||
if resp == nil || len(resp.Data) == 0 {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
}
|
||||
keysRaw, ok := resp.Data["keys"]
|
||||
if !ok || keysRaw == nil {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
}
|
||||
keys, ok := keysRaw.([]string)
|
||||
if !ok {
|
||||
respondError(w, http.StatusInternalServerError, nil)
|
||||
return
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Build the proper response
|
||||
respondLogical(w, r, req, dataOnly, resp)
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -101,7 +103,7 @@ func TestLogical_StandbyRedirect(t *testing.T) {
|
||||
|
||||
// Attempt to fix raciness in this test by giving the first core a chance
|
||||
// to grab the lock
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Create a second HA Vault
|
||||
conf2 := &vault.CoreConfig{
|
||||
@@ -252,3 +254,42 @@ func TestLogical_RequestSizeLimit(t *testing.T) {
|
||||
})
|
||||
testResponseStatus(t, resp, 413)
|
||||
}
|
||||
|
||||
func TestLogical_ListSuffix(t *testing.T) {
|
||||
core, _, _ := vault.TestCoreUnsealed(t)
|
||||
req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil)
|
||||
lreq, status, err := buildLogicalRequest(core, nil, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if status != 0 {
|
||||
t.Fatalf("got status %d", status)
|
||||
}
|
||||
if strings.HasSuffix(lreq.Path, "/") {
|
||||
t.Fatal("trailing slash found on path")
|
||||
}
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil)
|
||||
lreq, status, err = buildLogicalRequest(core, nil, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if status != 0 {
|
||||
t.Fatalf("got status %d", status)
|
||||
}
|
||||
if !strings.HasSuffix(lreq.Path, "/") {
|
||||
t.Fatal("trailing slash not found on path")
|
||||
}
|
||||
|
||||
req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil)
|
||||
lreq, status, err = buildLogicalRequest(core, nil, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if status != 0 {
|
||||
t.Fatalf("got status %d", status)
|
||||
}
|
||||
if !strings.HasSuffix(lreq.Path, "/") {
|
||||
t.Fatal("trailing slash not found on path")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ func TestSysAudit(t *testing.T) {
|
||||
"type": "noop",
|
||||
"description": "",
|
||||
"options": map[string]interface{}{},
|
||||
"local": false,
|
||||
},
|
||||
},
|
||||
"noop/": map[string]interface{}{
|
||||
@@ -42,6 +43,7 @@ func TestSysAudit(t *testing.T) {
|
||||
"type": "noop",
|
||||
"description": "",
|
||||
"options": map[string]interface{}{},
|
||||
"local": false,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
|
||||
@@ -32,6 +32,7 @@ func TestSysAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
},
|
||||
"token/": map[string]interface{}{
|
||||
@@ -41,6 +42,7 @@ func TestSysAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -83,6 +85,7 @@ func TestSysEnableAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"token/": map[string]interface{}{
|
||||
"description": "token based credentials",
|
||||
@@ -91,6 +94,7 @@ func TestSysEnableAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
},
|
||||
"foo/": map[string]interface{}{
|
||||
@@ -100,6 +104,7 @@ func TestSysEnableAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"token/": map[string]interface{}{
|
||||
"description": "token based credentials",
|
||||
@@ -108,6 +113,7 @@ func TestSysEnableAuth(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -153,6 +159,7 @@ func TestSysDisableAuth(t *testing.T) {
|
||||
},
|
||||
"description": "token based credentials",
|
||||
"type": "token",
|
||||
"local": false,
|
||||
},
|
||||
},
|
||||
"token/": map[string]interface{}{
|
||||
@@ -162,6 +169,7 @@ func TestSysDisableAuth(t *testing.T) {
|
||||
},
|
||||
"description": "token based credentials",
|
||||
"type": "token",
|
||||
"local": false,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
|
||||
@@ -33,6 +33,7 @@ func TestSysMounts(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -41,6 +42,7 @@ func TestSysMounts(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -49,6 +51,7 @@ func TestSysMounts(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
@@ -58,6 +61,7 @@ func TestSysMounts(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -66,6 +70,7 @@ func TestSysMounts(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -74,6 +79,7 @@ func TestSysMounts(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -114,6 +120,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -122,6 +129,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -130,6 +138,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -138,6 +147,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"foo/": map[string]interface{}{
|
||||
@@ -147,6 +157,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -155,6 +166,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -163,6 +175,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -171,6 +184,7 @@ func TestSysMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -233,6 +247,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -241,6 +256,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -249,6 +265,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -257,6 +274,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"bar/": map[string]interface{}{
|
||||
@@ -266,6 +284,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -274,6 +293,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -282,6 +302,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -290,6 +311,7 @@ func TestSysRemount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -333,6 +355,7 @@ func TestSysUnmount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -341,6 +364,7 @@ func TestSysUnmount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -349,6 +373,7 @@ func TestSysUnmount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
@@ -358,6 +383,7 @@ func TestSysUnmount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -366,6 +392,7 @@ func TestSysUnmount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -374,6 +401,7 @@ func TestSysUnmount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -414,6 +442,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -422,6 +451,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -430,6 +460,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -438,6 +469,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"foo/": map[string]interface{}{
|
||||
@@ -447,6 +479,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -455,6 +488,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -463,6 +497,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -471,6 +506,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
@@ -532,6 +568,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("259196400"),
|
||||
"max_lease_ttl": json.Number("259200000"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -540,6 +577,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -548,6 +586,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -556,6 +595,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
},
|
||||
"foo/": map[string]interface{}{
|
||||
@@ -565,6 +605,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("259196400"),
|
||||
"max_lease_ttl": json.Number("259200000"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"secret/": map[string]interface{}{
|
||||
"description": "generic secret storage",
|
||||
@@ -573,6 +614,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"sys/": map[string]interface{}{
|
||||
"description": "system endpoints used for control, policy and debugging",
|
||||
@@ -581,6 +623,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": false,
|
||||
},
|
||||
"cubbyhole/": map[string]interface{}{
|
||||
"description": "per-token private secret storage",
|
||||
@@ -589,6 +632,7 @@ func TestSysTuneMount(t *testing.T) {
|
||||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
},
|
||||
"local": true,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/pgpkeys"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
@@ -19,6 +20,13 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
repState := core.ReplicationState()
|
||||
if repState == consts.ReplicationSecondary {
|
||||
respondError(w, http.StatusBadRequest,
|
||||
fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"))
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case recovery && !core.SealAccess().RecoveryKeySupported():
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"))
|
||||
@@ -108,7 +116,7 @@ func handleSysRekeyInitPut(core *vault.Core, recovery bool, w http.ResponseWrite
|
||||
// Right now we don't support this, but the rest of the code is ready for
|
||||
// when we do, hence the check below for this to be false if
|
||||
// StoredShares is greater than zero
|
||||
if core.SealAccess().StoredKeysSupported() {
|
||||
if core.SealAccess().StoredKeysSupported() && !recovery {
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying of barrier not supported when stored key support is available"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/hashicorp/vault/version"
|
||||
@@ -126,7 +127,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
|
||||
case errwrap.Contains(err, vault.ErrBarrierInvalidKey.Error()):
|
||||
case errwrap.Contains(err, vault.ErrBarrierNotInit.Error()):
|
||||
case errwrap.Contains(err, vault.ErrBarrierSealed.Error()):
|
||||
case errwrap.Contains(err, vault.ErrStandby.Error()):
|
||||
case errwrap.Contains(err, consts.ErrStandby.Error()):
|
||||
default:
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
// is present on the Request structure for credential backends.
|
||||
type Connection struct {
|
||||
// RemoteAddr is the network address that sent the request.
|
||||
RemoteAddr string
|
||||
RemoteAddr string `json:"remote_addr"`
|
||||
|
||||
// ConnState is the TLS connection state if applicable.
|
||||
ConnState *tls.ConnectionState
|
||||
|
||||
@@ -21,3 +21,27 @@ func (e *codedError) Error() string {
|
||||
func (e *codedError) Code() int {
|
||||
return e.code
|
||||
}
|
||||
|
||||
// Struct to identify user input errors. This is helpful in responding the
|
||||
// appropriate status codes to clients from the HTTP endpoints.
|
||||
type StatusBadRequest struct {
|
||||
Err string
|
||||
}
|
||||
|
||||
// Implementing error interface
|
||||
func (s *StatusBadRequest) Error() string {
|
||||
return s.Err
|
||||
}
|
||||
|
||||
// This is a new type declared to not cause potential compatibility problems if
|
||||
// the logic around the HTTPCodedError interface changes; in particular for
|
||||
// logical request paths it is basically ignored, and changing that behavior
|
||||
// might cause unforseen issues.
|
||||
type ReplicationCodedError struct {
|
||||
Msg string
|
||||
Code int
|
||||
}
|
||||
|
||||
func (r *ReplicationCodedError) Error() string {
|
||||
return r.Msg
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package framework
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"regexp"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/helper/duration"
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -534,7 +536,40 @@ type FieldSchema struct {
|
||||
// the zero value of the type.
|
||||
func (s *FieldSchema) DefaultOrZero() interface{} {
|
||||
if s.Default != nil {
|
||||
return s.Default
|
||||
switch s.Type {
|
||||
case TypeDurationSecond:
|
||||
var result int
|
||||
switch inp := s.Default.(type) {
|
||||
case nil:
|
||||
return s.Type.Zero()
|
||||
case int:
|
||||
result = inp
|
||||
case int64:
|
||||
result = int(inp)
|
||||
case float32:
|
||||
result = int(inp)
|
||||
case float64:
|
||||
result = int(inp)
|
||||
case string:
|
||||
dur, err := duration.ParseDurationSecond(inp)
|
||||
if err != nil {
|
||||
return s.Type.Zero()
|
||||
}
|
||||
result = int(dur.Seconds())
|
||||
case json.Number:
|
||||
valInt64, err := inp.Int64()
|
||||
if err != nil {
|
||||
return s.Type.Zero()
|
||||
}
|
||||
result = int(valInt64)
|
||||
default:
|
||||
return s.Type.Zero()
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
return s.Default
|
||||
}
|
||||
}
|
||||
|
||||
return s.Type.Zero()
|
||||
|
||||
@@ -554,6 +554,16 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) {
|
||||
60,
|
||||
},
|
||||
|
||||
"default duration int64": {
|
||||
&FieldSchema{Type: TypeDurationSecond, Default: int64(60)},
|
||||
60,
|
||||
},
|
||||
|
||||
"default duration string": {
|
||||
&FieldSchema{Type: TypeDurationSecond, Default: "60s"},
|
||||
60,
|
||||
},
|
||||
|
||||
"default duration not set": {
|
||||
&FieldSchema{Type: TypeDurationSecond},
|
||||
0,
|
||||
|
||||
@@ -80,22 +80,3 @@ type Paths struct {
|
||||
// indicates that these paths should not be replicated
|
||||
LocalStorage []string
|
||||
}
|
||||
|
||||
type ReplicationState uint32
|
||||
|
||||
const (
|
||||
ReplicationDisabled ReplicationState = iota
|
||||
ReplicationPrimary
|
||||
ReplicationSecondary
|
||||
)
|
||||
|
||||
func (r ReplicationState) String() string {
|
||||
switch r {
|
||||
case ReplicationSecondary:
|
||||
return "secondary"
|
||||
case ReplicationPrimary:
|
||||
return "primary"
|
||||
}
|
||||
|
||||
return "disabled"
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ type Request struct {
|
||||
// Id is the uuid associated with each request
|
||||
ID string `json:"id" structs:"id" mapstructure:"id"`
|
||||
|
||||
// If set, the name given to the replication secondary where this request
|
||||
// originated
|
||||
ReplicationCluster string `json:"replication_cluster" structs:"replication_cluster", mapstructure:"replication_cluster"`
|
||||
|
||||
// Operation is the requested operation type
|
||||
Operation Operation `json:"operation" structs:"operation" mapstructure:"operation"`
|
||||
|
||||
@@ -38,7 +42,7 @@ type Request struct {
|
||||
Data map[string]interface{} `json:"map" structs:"data" mapstructure:"data"`
|
||||
|
||||
// Storage can be used to durably store and retrieve state.
|
||||
Storage Storage `json:"storage" structs:"storage" mapstructure:"storage"`
|
||||
Storage Storage `json:"-"`
|
||||
|
||||
// Secret will be non-nil only for Revoke and Renew operations
|
||||
// to represent the secret that was returned prior.
|
||||
|
||||
111
logical/response_util.go
Normal file
111
logical/response_util.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package logical
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
)
|
||||
|
||||
// RespondErrorCommon pulls most of the functionality from http's
|
||||
// respondErrorCommon and some of http's handleLogical and makes it available
|
||||
// to both the http package and elsewhere.
|
||||
func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
|
||||
if err == nil && (resp == nil || !resp.IsError()) {
|
||||
switch {
|
||||
case req.Operation == ReadOperation:
|
||||
if resp == nil {
|
||||
return http.StatusNotFound, nil
|
||||
}
|
||||
|
||||
// Basically: if we have empty "keys" or no keys at all, 404. This
|
||||
// provides consistency with GET.
|
||||
case req.Operation == ListOperation && resp.WrapInfo == nil:
|
||||
if resp == nil || len(resp.Data) == 0 {
|
||||
return http.StatusNotFound, nil
|
||||
}
|
||||
keysRaw, ok := resp.Data["keys"]
|
||||
if !ok || keysRaw == nil {
|
||||
return http.StatusNotFound, nil
|
||||
}
|
||||
keys, ok := keysRaw.([]string)
|
||||
if !ok {
|
||||
return http.StatusInternalServerError, nil
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return http.StatusNotFound, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if errwrap.ContainsType(err, new(ReplicationCodedError)) {
|
||||
var allErrors error
|
||||
codedErr := errwrap.GetType(err, new(ReplicationCodedError)).(*ReplicationCodedError)
|
||||
errwrap.Walk(err, func(inErr error) {
|
||||
newErr, ok := inErr.(*ReplicationCodedError)
|
||||
if !ok {
|
||||
allErrors = multierror.Append(allErrors, newErr)
|
||||
}
|
||||
})
|
||||
if allErrors != nil {
|
||||
return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
|
||||
}
|
||||
return codedErr.Code, errors.New(codedErr.Msg)
|
||||
}
|
||||
|
||||
// Start out with internal server error since in most of these cases there
|
||||
// won't be a response so this won't be overridden
|
||||
statusCode := http.StatusInternalServerError
|
||||
// If we actually have a response, start out with bad request
|
||||
if resp != nil {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
|
||||
// Now, check the error itself; if it has a specific logical error, set the
|
||||
// appropriate code
|
||||
if err != nil {
|
||||
switch {
|
||||
case errwrap.ContainsType(err, new(StatusBadRequest)):
|
||||
statusCode = http.StatusBadRequest
|
||||
case errwrap.Contains(err, ErrPermissionDenied.Error()):
|
||||
statusCode = http.StatusForbidden
|
||||
case errwrap.Contains(err, ErrUnsupportedOperation.Error()):
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case errwrap.Contains(err, ErrUnsupportedPath.Error()):
|
||||
statusCode = http.StatusNotFound
|
||||
case errwrap.Contains(err, ErrInvalidRequest.Error()):
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
if resp != nil && resp.IsError() {
|
||||
err = fmt.Errorf("%s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
return statusCode, err
|
||||
}
|
||||
|
||||
// AdjustErrorStatusCode adjusts the status that will be sent in error
|
||||
// conditions in a way that can be shared across http's respondError and other
|
||||
// locations.
|
||||
func AdjustErrorStatusCode(status *int, err error) {
|
||||
// Adjust status code when sealed
|
||||
if errwrap.Contains(err, consts.ErrSealed.Error()) {
|
||||
*status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// Adjust status code on
|
||||
if errwrap.Contains(err, "http: request body too large") {
|
||||
*status = http.StatusRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// Allow HTTPCoded error passthrough to specify a code
|
||||
if t, ok := err.(HTTPCodedError); ok {
|
||||
*status = t.Code()
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
package logical
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
)
|
||||
|
||||
// SystemView exposes system configuration information in a safe way
|
||||
// for logical backends to consume
|
||||
@@ -32,7 +36,7 @@ type SystemView interface {
|
||||
CachingDisabled() bool
|
||||
|
||||
// ReplicationState indicates the state of cluster replication
|
||||
ReplicationState() ReplicationState
|
||||
ReplicationState() consts.ReplicationState
|
||||
}
|
||||
|
||||
type StaticSystemView struct {
|
||||
@@ -42,7 +46,7 @@ type StaticSystemView struct {
|
||||
TaintedVal bool
|
||||
CachingDisabledVal bool
|
||||
Primary bool
|
||||
ReplicationStateVal ReplicationState
|
||||
ReplicationStateVal consts.ReplicationState
|
||||
}
|
||||
|
||||
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
|
||||
@@ -65,6 +69,6 @@ func (d StaticSystemView) CachingDisabled() bool {
|
||||
return d.CachingDisabledVal
|
||||
}
|
||||
|
||||
func (d StaticSystemView) ReplicationState() ReplicationState {
|
||||
func (d StaticSystemView) ReplicationState() consts.ReplicationState {
|
||||
return d.ReplicationStateVal
|
||||
}
|
||||
|
||||
17
meta/meta.go
17
meta/meta.go
@@ -23,6 +23,21 @@ const (
|
||||
FlagSetDefault = FlagSetServer
|
||||
)
|
||||
|
||||
var (
|
||||
additionalOptionsUsage = func() string {
|
||||
return `
|
||||
-wrap-ttl="" Indicates that the response should be wrapped in a
|
||||
cubbyhole token with the requested TTL. The response
|
||||
can be fetched by calling the "sys/wrapping/unwrap"
|
||||
endpoint, passing in the wrappping token's ID. This
|
||||
is a numeric string with an optional suffix
|
||||
"s", "m", or "h"; if no suffix is specified it will
|
||||
be parsed as seconds. May also be specified via
|
||||
VAULT_WRAP_TTL.
|
||||
`
|
||||
}
|
||||
)
|
||||
|
||||
// Meta contains the meta-options and functionality that nearly every
|
||||
// Vault command inherits.
|
||||
type Meta struct {
|
||||
@@ -188,6 +203,6 @@ func GeneralOptionsUsage() string {
|
||||
if VAULT_SKIP_VERIFY is set.
|
||||
`
|
||||
|
||||
general += AdditionalOptionsUsage()
|
||||
general += additionalOptionsUsage()
|
||||
return general
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
// +build !vault
|
||||
|
||||
package meta
|
||||
|
||||
func AdditionalOptionsUsage() string {
|
||||
return ""
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// +build vault
|
||||
|
||||
package meta
|
||||
|
||||
func AdditionalOptionsUsage() string {
|
||||
return `
|
||||
-wrap-ttl="" Indicates that the response should be wrapped in a
|
||||
cubbyhole token with the requested TTL. The response
|
||||
can be fetched by calling the "sys/wrapping/unwrap"
|
||||
endpoint, passing in the wrappping token's ID. This
|
||||
is a numeric string with an optional suffix
|
||||
"s", "m", or "h"; if no suffix is specified it will
|
||||
be parsed as seconds. May also be specified via
|
||||
VAULT_WRAP_TTL.
|
||||
`
|
||||
}
|
||||
@@ -1,9 +1,15 @@
|
||||
package physical
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/hashicorp/vault/helper/locksutil"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
@@ -17,8 +23,11 @@ const (
|
||||
// Vault are for policy objects so there is a large read reduction
|
||||
// by using a simple write-through cache.
|
||||
type Cache struct {
|
||||
backend Backend
|
||||
lru *lru.TwoQueueCache
|
||||
backend Backend
|
||||
transactional Transactional
|
||||
lru *lru.TwoQueueCache
|
||||
locks map[string]*sync.RWMutex
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewCache returns a physical cache of the given size.
|
||||
@@ -34,16 +43,58 @@ func NewCache(b Backend, size int, logger log.Logger) *Cache {
|
||||
c := &Cache{
|
||||
backend: b,
|
||||
lru: cache,
|
||||
locks: make(map[string]*sync.RWMutex, 256),
|
||||
logger: logger,
|
||||
}
|
||||
if err := locksutil.CreateLocks(c.locks, 256); err != nil {
|
||||
logger.Error("physical/cache: error creating locks", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if txnl, ok := c.backend.(Transactional); ok {
|
||||
c.transactional = txnl
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Cache) lockHashForKey(key string) string {
|
||||
hf := sha1.New()
|
||||
hf.Write([]byte(key))
|
||||
return strings.ToLower(hex.EncodeToString(hf.Sum(nil))[:2])
|
||||
}
|
||||
|
||||
func (c *Cache) lockForKey(key string) *sync.RWMutex {
|
||||
return c.locks[c.lockHashForKey(key)]
|
||||
}
|
||||
|
||||
// Purge is used to clear the cache
|
||||
func (c *Cache) Purge() {
|
||||
// Lock the world
|
||||
lockHashes := make([]string, 0, len(c.locks))
|
||||
for hash := range c.locks {
|
||||
lockHashes = append(lockHashes, hash)
|
||||
}
|
||||
|
||||
// Sort and deduplicate. This ensures we don't try to grab the same lock
|
||||
// twice, and enforcing a sort means we'll not have multiple goroutines
|
||||
// deadlock by acquiring in different orders.
|
||||
lockHashes = strutil.RemoveDuplicates(lockHashes)
|
||||
|
||||
for _, lockHash := range lockHashes {
|
||||
lock := c.locks[lockHash]
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
}
|
||||
|
||||
c.lru.Purge()
|
||||
}
|
||||
|
||||
func (c *Cache) Put(entry *Entry) error {
|
||||
lock := c.lockForKey(entry.Key)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err := c.backend.Put(entry)
|
||||
if err == nil {
|
||||
c.lru.Add(entry.Key, entry)
|
||||
@@ -52,6 +103,10 @@ func (c *Cache) Put(entry *Entry) error {
|
||||
}
|
||||
|
||||
func (c *Cache) Get(key string) (*Entry, error) {
|
||||
lock := c.lockForKey(key)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Check the LRU first
|
||||
if raw, ok := c.lru.Get(key); ok {
|
||||
if raw == nil {
|
||||
@@ -79,6 +134,10 @@ func (c *Cache) Get(key string) (*Entry, error) {
|
||||
}
|
||||
|
||||
func (c *Cache) Delete(key string) error {
|
||||
lock := c.lockForKey(key)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err := c.backend.Delete(key)
|
||||
if err == nil {
|
||||
c.lru.Remove(key)
|
||||
@@ -87,6 +146,45 @@ func (c *Cache) Delete(key string) error {
|
||||
}
|
||||
|
||||
func (c *Cache) List(prefix string) ([]string, error) {
|
||||
// Always pass-through as this would be difficult to cache.
|
||||
// Always pass-through as this would be difficult to cache. For the same
|
||||
// reason we don't lock as we can't reasonably know which locks to readlock
|
||||
// ahead of time.
|
||||
return c.backend.List(prefix)
|
||||
}
|
||||
|
||||
func (c *Cache) Transaction(txns []TxnEntry) error {
|
||||
if c.transactional == nil {
|
||||
return fmt.Errorf("physical/cache: underlying backend does not support transactions")
|
||||
}
|
||||
|
||||
var lockHashes []string
|
||||
for _, txn := range txns {
|
||||
lockHashes = append(lockHashes, c.lockHashForKey(txn.Entry.Key))
|
||||
}
|
||||
|
||||
// Sort and deduplicate. This ensures we don't try to grab the same lock
|
||||
// twice, and enforcing a sort means we'll not have multiple goroutines
|
||||
// deadlock by acquiring in different orders.
|
||||
lockHashes = strutil.RemoveDuplicates(lockHashes)
|
||||
|
||||
for _, lockHash := range lockHashes {
|
||||
lock := c.locks[lockHash]
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
}
|
||||
|
||||
if err := c.transactional.Transaction(txns); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, txn := range txns {
|
||||
switch txn.Operation {
|
||||
case PutOperation:
|
||||
c.lru.Add(txn.Entry.Key, txn.Entry)
|
||||
case DeleteOperation:
|
||||
c.lru.Remove(txn.Entry.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package physical
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -21,6 +22,8 @@ import (
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
)
|
||||
@@ -154,6 +157,10 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
|
||||
|
||||
// Configure the client
|
||||
consulConf := api.DefaultConfig()
|
||||
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
|
||||
tr := cleanhttp.DefaultPooledTransport()
|
||||
tr.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
||||
consulConf.HttpClient.Transport = tr
|
||||
|
||||
if addr, ok := conf["address"]; ok {
|
||||
consulConf.Address = addr
|
||||
@@ -179,7 +186,7 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
|
||||
}
|
||||
|
||||
transport := cleanhttp.DefaultPooledTransport()
|
||||
transport.MaxIdleConnsPerHost = 4
|
||||
transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
||||
transport.TLSClientConfig = tlsClientConfig
|
||||
consulConf.HttpClient.Transport = transport
|
||||
logger.Debug("physical/consul: configured TLS")
|
||||
@@ -284,17 +291,59 @@ func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
|
||||
return tlsClientConfig, nil
|
||||
}
|
||||
|
||||
// Used to run multiple entries via a transaction
|
||||
func (c *ConsulBackend) Transaction(txns []TxnEntry) error {
|
||||
if len(txns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ops := make([]*api.KVTxnOp, 0, len(txns))
|
||||
|
||||
for _, op := range txns {
|
||||
cop := &api.KVTxnOp{
|
||||
Key: c.path + op.Entry.Key,
|
||||
}
|
||||
switch op.Operation {
|
||||
case DeleteOperation:
|
||||
cop.Verb = api.KVDelete
|
||||
case PutOperation:
|
||||
cop.Verb = api.KVSet
|
||||
cop.Value = op.Entry.Value
|
||||
default:
|
||||
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
|
||||
}
|
||||
|
||||
ops = append(ops, cop)
|
||||
}
|
||||
|
||||
ok, resp, _, err := c.kv.Txn(ops, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var retErr *multierror.Error
|
||||
for _, res := range resp.Errors {
|
||||
retErr = multierror.Append(retErr, errors.New(res.What))
|
||||
}
|
||||
|
||||
return retErr
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry
|
||||
func (c *ConsulBackend) Put(entry *Entry) error {
|
||||
defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
pair := &api.KVPair{
|
||||
Key: c.path + entry.Key,
|
||||
Value: entry.Value,
|
||||
}
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
_, err := c.kv.Put(pair, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
104
physical/file.go
104
physical/file.go
@@ -22,12 +22,17 @@ import (
|
||||
// and non-performant. It is meant mostly for local testing and development.
|
||||
// It can be improved in the future.
|
||||
type FileBackend struct {
|
||||
Path string
|
||||
l sync.Mutex
|
||||
logger log.Logger
|
||||
sync.RWMutex
|
||||
path string
|
||||
logger log.Logger
|
||||
permitPool *PermitPool
|
||||
}
|
||||
|
||||
// newFileBackend constructs a Filebackend using the given directory
|
||||
type TransactionalFileBackend struct {
|
||||
FileBackend
|
||||
}
|
||||
|
||||
// newFileBackend constructs a FileBackend using the given directory
|
||||
func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
|
||||
path, ok := conf["path"]
|
||||
if !ok {
|
||||
@@ -35,20 +40,44 @@ func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error)
|
||||
}
|
||||
|
||||
return &FileBackend{
|
||||
Path: path,
|
||||
logger: logger,
|
||||
path: path,
|
||||
logger: logger,
|
||||
permitPool: NewPermitPool(DefaultParallelOperations),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTransactionalFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
|
||||
path, ok := conf["path"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("'path' must be set")
|
||||
}
|
||||
|
||||
// Create a pool of size 1 so only one operation runs at a time
|
||||
return &TransactionalFileBackend{
|
||||
FileBackend: FileBackend{
|
||||
path: path,
|
||||
logger: logger,
|
||||
permitPool: NewPermitPool(1),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *FileBackend) Delete(path string) error {
|
||||
b.permitPool.Acquire()
|
||||
defer b.permitPool.Release()
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
return b.DeleteInternal(path)
|
||||
}
|
||||
|
||||
func (b *FileBackend) DeleteInternal(path string) error {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
basePath, key := b.path(path)
|
||||
basePath, key := b.expandPath(path)
|
||||
fullPath := filepath.Join(basePath, key)
|
||||
|
||||
err := os.Remove(fullPath)
|
||||
@@ -66,7 +95,7 @@ func (b *FileBackend) Delete(path string) error {
|
||||
func (b *FileBackend) cleanupLogicalPath(path string) error {
|
||||
nodes := strings.Split(path, fmt.Sprintf("%c", os.PathSeparator))
|
||||
for i := len(nodes) - 1; i > 0; i-- {
|
||||
fullPath := filepath.Join(b.Path, filepath.Join(nodes[:i]...))
|
||||
fullPath := filepath.Join(b.path, filepath.Join(nodes[:i]...))
|
||||
|
||||
dir, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
@@ -96,10 +125,17 @@ func (b *FileBackend) cleanupLogicalPath(path string) error {
|
||||
}
|
||||
|
||||
func (b *FileBackend) Get(k string) (*Entry, error) {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
b.permitPool.Acquire()
|
||||
defer b.permitPool.Release()
|
||||
|
||||
path, key := b.path(k)
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
|
||||
return b.GetInternal(k)
|
||||
}
|
||||
|
||||
func (b *FileBackend) GetInternal(k string) (*Entry, error) {
|
||||
path, key := b.expandPath(k)
|
||||
path = filepath.Join(path, key)
|
||||
|
||||
f, err := os.Open(path)
|
||||
@@ -121,10 +157,17 @@ func (b *FileBackend) Get(k string) (*Entry, error) {
|
||||
}
|
||||
|
||||
func (b *FileBackend) Put(entry *Entry) error {
|
||||
path, key := b.path(entry.Key)
|
||||
b.permitPool.Acquire()
|
||||
defer b.permitPool.Release()
|
||||
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
return b.PutInternal(entry)
|
||||
}
|
||||
|
||||
func (b *FileBackend) PutInternal(entry *Entry) error {
|
||||
path, key := b.expandPath(entry.Key)
|
||||
|
||||
// Make the parent tree
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
@@ -145,10 +188,17 @@ func (b *FileBackend) Put(entry *Entry) error {
|
||||
}
|
||||
|
||||
func (b *FileBackend) List(prefix string) ([]string, error) {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
b.permitPool.Acquire()
|
||||
defer b.permitPool.Release()
|
||||
|
||||
path := b.Path
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
|
||||
return b.ListInternal(prefix)
|
||||
}
|
||||
|
||||
func (b *FileBackend) ListInternal(prefix string) ([]string, error) {
|
||||
path := b.path
|
||||
if prefix != "" {
|
||||
path = filepath.Join(path, prefix)
|
||||
}
|
||||
@@ -180,9 +230,19 @@ func (b *FileBackend) List(prefix string) ([]string, error) {
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func (b *FileBackend) path(k string) (string, string) {
|
||||
path := filepath.Join(b.Path, k)
|
||||
func (b *FileBackend) expandPath(k string) (string, string) {
|
||||
path := filepath.Join(b.path, k)
|
||||
key := filepath.Base(path)
|
||||
path = filepath.Dir(path)
|
||||
return path, "_" + key
|
||||
}
|
||||
|
||||
func (b *TransactionalFileBackend) Transaction(txns []TxnEntry) error {
|
||||
b.permitPool.Acquire()
|
||||
defer b.permitPool.Release()
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
return genericTransactionHandler(b, txns)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,16 @@ import (
|
||||
// for testing and development situations where the data is not
|
||||
// expected to be durable.
|
||||
type InmemBackend struct {
|
||||
sync.RWMutex
|
||||
root *radix.Tree
|
||||
l sync.RWMutex
|
||||
permitPool *PermitPool
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
type TransactionalInmemBackend struct {
|
||||
InmemBackend
|
||||
}
|
||||
|
||||
// NewInmem constructs a new in-memory backend
|
||||
func NewInmem(logger log.Logger) *InmemBackend {
|
||||
in := &InmemBackend{
|
||||
@@ -29,14 +33,31 @@ func NewInmem(logger log.Logger) *InmemBackend {
|
||||
return in
|
||||
}
|
||||
|
||||
// Basically for now just creates a permit pool of size 1 so only one operation
|
||||
// can run at a time
|
||||
func NewTransactionalInmem(logger log.Logger) *TransactionalInmemBackend {
|
||||
in := &TransactionalInmemBackend{
|
||||
InmemBackend: InmemBackend{
|
||||
root: radix.New(),
|
||||
permitPool: NewPermitPool(1),
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry
|
||||
func (i *InmemBackend) Put(entry *Entry) error {
|
||||
i.permitPool.Acquire()
|
||||
defer i.permitPool.Release()
|
||||
|
||||
i.l.Lock()
|
||||
defer i.l.Unlock()
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.PutInternal(entry)
|
||||
}
|
||||
|
||||
func (i *InmemBackend) PutInternal(entry *Entry) error {
|
||||
i.root.Insert(entry.Key, entry)
|
||||
return nil
|
||||
}
|
||||
@@ -46,9 +67,13 @@ func (i *InmemBackend) Get(key string) (*Entry, error) {
|
||||
i.permitPool.Acquire()
|
||||
defer i.permitPool.Release()
|
||||
|
||||
i.l.RLock()
|
||||
defer i.l.RUnlock()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
return i.GetInternal(key)
|
||||
}
|
||||
|
||||
func (i *InmemBackend) GetInternal(key string) (*Entry, error) {
|
||||
if raw, ok := i.root.Get(key); ok {
|
||||
return raw.(*Entry), nil
|
||||
}
|
||||
@@ -60,9 +85,13 @@ func (i *InmemBackend) Delete(key string) error {
|
||||
i.permitPool.Acquire()
|
||||
defer i.permitPool.Release()
|
||||
|
||||
i.l.Lock()
|
||||
defer i.l.Unlock()
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.DeleteInternal(key)
|
||||
}
|
||||
|
||||
func (i *InmemBackend) DeleteInternal(key string) error {
|
||||
i.root.Delete(key)
|
||||
return nil
|
||||
}
|
||||
@@ -73,9 +102,13 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
|
||||
i.permitPool.Acquire()
|
||||
defer i.permitPool.Release()
|
||||
|
||||
i.l.RLock()
|
||||
defer i.l.RUnlock()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
return i.ListInternal(prefix)
|
||||
}
|
||||
|
||||
func (i *InmemBackend) ListInternal(prefix string) ([]string, error) {
|
||||
var out []string
|
||||
seen := make(map[string]interface{})
|
||||
walkFn := func(s string, v interface{}) bool {
|
||||
@@ -96,3 +129,14 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Implements the transaction interface
|
||||
func (t *TransactionalInmemBackend) Transaction(txns []TxnEntry) error {
|
||||
t.permitPool.Acquire()
|
||||
defer t.permitPool.Release()
|
||||
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
return genericTransactionHandler(t, txns)
|
||||
}
|
||||
|
||||
@@ -8,19 +8,40 @@ import (
|
||||
)
|
||||
|
||||
type InmemHABackend struct {
|
||||
InmemBackend
|
||||
Backend
|
||||
locks map[string]string
|
||||
l sync.Mutex
|
||||
cond *sync.Cond
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
type TransactionalInmemHABackend struct {
|
||||
Transactional
|
||||
InmemHABackend
|
||||
}
|
||||
|
||||
// NewInmemHA constructs a new in-memory HA backend. This is only for testing.
|
||||
func NewInmemHA(logger log.Logger) *InmemHABackend {
|
||||
in := &InmemHABackend{
|
||||
InmemBackend: *NewInmem(logger),
|
||||
locks: make(map[string]string),
|
||||
logger: logger,
|
||||
Backend: NewInmem(logger),
|
||||
locks: make(map[string]string),
|
||||
logger: logger,
|
||||
}
|
||||
in.cond = sync.NewCond(&in.l)
|
||||
return in
|
||||
}
|
||||
|
||||
func NewTransactionalInmemHA(logger log.Logger) *TransactionalInmemHABackend {
|
||||
transInmem := NewTransactionalInmem(logger)
|
||||
inmemHA := InmemHABackend{
|
||||
Backend: transInmem,
|
||||
locks: make(map[string]string),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
in := &TransactionalInmemHABackend{
|
||||
InmemHABackend: inmemHA,
|
||||
Transactional: transInmem,
|
||||
}
|
||||
in.cond = sync.NewCond(&in.l)
|
||||
return in
|
||||
|
||||
@@ -9,6 +9,16 @@ import (
|
||||
|
||||
const DefaultParallelOperations = 128
|
||||
|
||||
// The operation type
|
||||
type Operation string
|
||||
|
||||
const (
|
||||
DeleteOperation Operation = "delete"
|
||||
GetOperation = "get"
|
||||
ListOperation = "list"
|
||||
PutOperation = "put"
|
||||
)
|
||||
|
||||
// ShutdownSignal
|
||||
type ShutdownChannel chan struct{}
|
||||
|
||||
@@ -121,20 +131,27 @@ var builtinBackends = map[string]Factory{
|
||||
"inmem": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||
return NewInmem(logger), nil
|
||||
},
|
||||
"inmem_transactional": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||
return NewTransactionalInmem(logger), nil
|
||||
},
|
||||
"inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||
return NewInmemHA(logger), nil
|
||||
},
|
||||
"consul": newConsulBackend,
|
||||
"zookeeper": newZookeeperBackend,
|
||||
"file": newFileBackend,
|
||||
"s3": newS3Backend,
|
||||
"azure": newAzureBackend,
|
||||
"dynamodb": newDynamoDBBackend,
|
||||
"etcd": newEtcdBackend,
|
||||
"mysql": newMySQLBackend,
|
||||
"postgresql": newPostgreSQLBackend,
|
||||
"swift": newSwiftBackend,
|
||||
"gcs": newGCSBackend,
|
||||
"inmem_transactional_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||
return NewTransactionalInmemHA(logger), nil
|
||||
},
|
||||
"file_transactional": newTransactionalFileBackend,
|
||||
"consul": newConsulBackend,
|
||||
"zookeeper": newZookeeperBackend,
|
||||
"file": newFileBackend,
|
||||
"s3": newS3Backend,
|
||||
"azure": newAzureBackend,
|
||||
"dynamodb": newDynamoDBBackend,
|
||||
"etcd": newEtcdBackend,
|
||||
"mysql": newMySQLBackend,
|
||||
"postgresql": newPostgreSQLBackend,
|
||||
"swift": newSwiftBackend,
|
||||
"gcs": newGCSBackend,
|
||||
}
|
||||
|
||||
// PermitPool is used to limit maximum outstanding requests
|
||||
|
||||
@@ -71,7 +71,8 @@ func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, e
|
||||
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
||||
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
||||
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
|
||||
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
|
||||
"UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
|
||||
quoted_table + " WHERE parent_path LIKE concat($1, '%')",
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
|
||||
121
physical/transactions.go
Normal file
121
physical/transactions.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package physical
|
||||
|
||||
import multierror "github.com/hashicorp/go-multierror"
|
||||
|
||||
// TxnEntry is an operation that takes atomically as part of
|
||||
// a transactional update. Only supported by Transactional backends.
|
||||
type TxnEntry struct {
|
||||
Operation Operation
|
||||
Entry *Entry
|
||||
}
|
||||
|
||||
// Transactional is an optional interface for backends that
|
||||
// support doing transactional updates of multiple keys. This is
|
||||
// required for some features such as replication.
|
||||
type Transactional interface {
|
||||
// The function to run a transaction
|
||||
Transaction([]TxnEntry) error
|
||||
}
|
||||
|
||||
type PseudoTransactional interface {
|
||||
// An internal function should do no locking or permit pool acquisition.
|
||||
// Depending on the backend and if it natively supports transactions, these
|
||||
// may simply chain to the normal backend functions.
|
||||
GetInternal(string) (*Entry, error)
|
||||
PutInternal(*Entry) error
|
||||
DeleteInternal(string) error
|
||||
}
|
||||
|
||||
// Implements the transaction interface
|
||||
func genericTransactionHandler(t PseudoTransactional, txns []TxnEntry) (retErr error) {
|
||||
rollbackStack := make([]TxnEntry, 0, len(txns))
|
||||
var dirty bool
|
||||
|
||||
// We walk the transactions in order; each successful operation goes into a
|
||||
// LIFO for rollback if we hit an error along the way
|
||||
TxnWalk:
|
||||
for _, txn := range txns {
|
||||
switch txn.Operation {
|
||||
case DeleteOperation:
|
||||
entry, err := t.GetInternal(txn.Entry.Key)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
dirty = true
|
||||
break TxnWalk
|
||||
}
|
||||
if entry == nil {
|
||||
// Nothing to delete or roll back
|
||||
continue
|
||||
}
|
||||
rollbackEntry := TxnEntry{
|
||||
Operation: PutOperation,
|
||||
Entry: &Entry{
|
||||
Key: entry.Key,
|
||||
Value: entry.Value,
|
||||
},
|
||||
}
|
||||
err = t.DeleteInternal(txn.Entry.Key)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
dirty = true
|
||||
break TxnWalk
|
||||
}
|
||||
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
|
||||
|
||||
case PutOperation:
|
||||
entry, err := t.GetInternal(txn.Entry.Key)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
dirty = true
|
||||
break TxnWalk
|
||||
}
|
||||
// Nothing existed so in fact rolling back requires a delete
|
||||
var rollbackEntry TxnEntry
|
||||
if entry == nil {
|
||||
rollbackEntry = TxnEntry{
|
||||
Operation: DeleteOperation,
|
||||
Entry: &Entry{
|
||||
Key: txn.Entry.Key,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
rollbackEntry = TxnEntry{
|
||||
Operation: PutOperation,
|
||||
Entry: &Entry{
|
||||
Key: entry.Key,
|
||||
Value: entry.Value,
|
||||
},
|
||||
}
|
||||
}
|
||||
err = t.PutInternal(txn.Entry)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
dirty = true
|
||||
break TxnWalk
|
||||
}
|
||||
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
|
||||
}
|
||||
}
|
||||
|
||||
// Need to roll back because we hit an error along the way
|
||||
if dirty {
|
||||
// While traversing this, if we get an error, we continue anyways in
|
||||
// best-effort fashion
|
||||
for _, txn := range rollbackStack {
|
||||
switch txn.Operation {
|
||||
case DeleteOperation:
|
||||
err := t.DeleteInternal(txn.Entry.Key)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
}
|
||||
case PutOperation:
|
||||
err := t.PutInternal(txn.Entry)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
254
physical/transactions_test.go
Normal file
254
physical/transactions_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package physical
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
radix "github.com/armon/go-radix"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
type faultyPseudo struct {
|
||||
underlying InmemBackend
|
||||
faultyPaths map[string]struct{}
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) Get(key string) (*Entry, error) {
|
||||
return f.underlying.Get(key)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) Put(entry *Entry) error {
|
||||
return f.underlying.Put(entry)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) Delete(key string) error {
|
||||
return f.underlying.Delete(key)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) GetInternal(key string) (*Entry, error) {
|
||||
if _, ok := f.faultyPaths[key]; ok {
|
||||
return nil, fmt.Errorf("fault")
|
||||
}
|
||||
return f.underlying.GetInternal(key)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) PutInternal(entry *Entry) error {
|
||||
if _, ok := f.faultyPaths[entry.Key]; ok {
|
||||
return fmt.Errorf("fault")
|
||||
}
|
||||
return f.underlying.PutInternal(entry)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) DeleteInternal(key string) error {
|
||||
if _, ok := f.faultyPaths[key]; ok {
|
||||
return fmt.Errorf("fault")
|
||||
}
|
||||
return f.underlying.DeleteInternal(key)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) List(prefix string) ([]string, error) {
|
||||
return f.underlying.List(prefix)
|
||||
}
|
||||
|
||||
func (f *faultyPseudo) Transaction(txns []TxnEntry) error {
|
||||
f.underlying.permitPool.Acquire()
|
||||
defer f.underlying.permitPool.Release()
|
||||
|
||||
f.underlying.Lock()
|
||||
defer f.underlying.Unlock()
|
||||
|
||||
return genericTransactionHandler(f, txns)
|
||||
}
|
||||
|
||||
func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo {
|
||||
out := &faultyPseudo{
|
||||
underlying: InmemBackend{
|
||||
root: radix.New(),
|
||||
permitPool: NewPermitPool(1),
|
||||
logger: logger,
|
||||
},
|
||||
faultyPaths: make(map[string]struct{}, len(faultyPaths)),
|
||||
}
|
||||
for _, v := range faultyPaths {
|
||||
out.faultyPaths[v] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestPseudo_Basic(t *testing.T) {
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
p := newFaultyPseudo(logger, nil)
|
||||
testBackend(t, p)
|
||||
testBackend_ListPrefix(t, p)
|
||||
}
|
||||
|
||||
func TestPseudo_SuccessfulTransaction(t *testing.T) {
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
p := newFaultyPseudo(logger, nil)
|
||||
|
||||
txns := setupPseudo(p, t)
|
||||
|
||||
if err := p.Transaction(txns); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keys, err := p.List("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := []string{"foo", "zip"}
|
||||
|
||||
sort.Strings(keys)
|
||||
sort.Strings(expected)
|
||||
if !reflect.DeepEqual(keys, expected) {
|
||||
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
|
||||
}
|
||||
|
||||
entry, err := p.Get("foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entry == nil {
|
||||
t.Fatal("got nil entry")
|
||||
}
|
||||
if entry.Value == nil {
|
||||
t.Fatal("got nil value")
|
||||
}
|
||||
if string(entry.Value) != "bar3" {
|
||||
t.Fatal("updates did not apply correctly")
|
||||
}
|
||||
|
||||
entry, err = p.Get("zip")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entry == nil {
|
||||
t.Fatal("got nil entry")
|
||||
}
|
||||
if entry.Value == nil {
|
||||
t.Fatal("got nil value")
|
||||
}
|
||||
if string(entry.Value) != "zap3" {
|
||||
t.Fatal("updates did not apply correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPseudo_FailedTransaction(t *testing.T) {
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
p := newFaultyPseudo(logger, []string{"zip"})
|
||||
|
||||
txns := setupPseudo(p, t)
|
||||
|
||||
if err := p.Transaction(txns); err == nil {
|
||||
t.Fatal("expected error during transaction")
|
||||
}
|
||||
|
||||
keys, err := p.List("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := []string{"foo", "zip", "deleteme", "deleteme2"}
|
||||
|
||||
sort.Strings(keys)
|
||||
sort.Strings(expected)
|
||||
if !reflect.DeepEqual(keys, expected) {
|
||||
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
|
||||
}
|
||||
|
||||
entry, err := p.Get("foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entry == nil {
|
||||
t.Fatal("got nil entry")
|
||||
}
|
||||
if entry.Value == nil {
|
||||
t.Fatal("got nil value")
|
||||
}
|
||||
if string(entry.Value) != "bar" {
|
||||
t.Fatal("values did not rollback correctly")
|
||||
}
|
||||
|
||||
entry, err = p.Get("zip")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if entry == nil {
|
||||
t.Fatal("got nil entry")
|
||||
}
|
||||
if entry.Value == nil {
|
||||
t.Fatal("got nil value")
|
||||
}
|
||||
if string(entry.Value) != "zap" {
|
||||
t.Fatal("values did not rollback correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func setupPseudo(p *faultyPseudo, t *testing.T) []TxnEntry {
|
||||
// Add a few keys so that we test rollback with deletion
|
||||
if err := p.Put(&Entry{
|
||||
Key: "foo",
|
||||
Value: []byte("bar"),
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := p.Put(&Entry{
|
||||
Key: "zip",
|
||||
Value: []byte("zap"),
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := p.Put(&Entry{
|
||||
Key: "deleteme",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := p.Put(&Entry{
|
||||
Key: "deleteme2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
txns := []TxnEntry{
|
||||
TxnEntry{
|
||||
Operation: PutOperation,
|
||||
Entry: &Entry{
|
||||
Key: "foo",
|
||||
Value: []byte("bar2"),
|
||||
},
|
||||
},
|
||||
TxnEntry{
|
||||
Operation: DeleteOperation,
|
||||
Entry: &Entry{
|
||||
Key: "deleteme",
|
||||
},
|
||||
},
|
||||
TxnEntry{
|
||||
Operation: PutOperation,
|
||||
Entry: &Entry{
|
||||
Key: "foo",
|
||||
Value: []byte("bar3"),
|
||||
},
|
||||
},
|
||||
TxnEntry{
|
||||
Operation: DeleteOperation,
|
||||
Entry: &Entry{
|
||||
Key: "deleteme2",
|
||||
},
|
||||
},
|
||||
TxnEntry{
|
||||
Operation: PutOperation,
|
||||
Entry: &Entry{
|
||||
Key: "zip",
|
||||
Value: []byte("zap3"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return txns
|
||||
}
|
||||
@@ -10,7 +10,7 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q \
|
||||
git mercurial bzr \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV GOVERSION 1.8rc3
|
||||
ENV GOVERSION 1.8
|
||||
RUN mkdir /goroot && mkdir /gopath
|
||||
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
|
||||
| tar xvzf - -C /goroot --strip-components=1
|
||||
|
||||
@@ -2,7 +2,6 @@ package vault
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -26,6 +25,10 @@ const (
|
||||
// can only be viewed or modified after an unseal.
|
||||
coreAuditConfigPath = "core/audit"
|
||||
|
||||
// coreLocalAuditConfigPath is used to store audit information for local
|
||||
// (non-replicated) mounts
|
||||
coreLocalAuditConfigPath = "core/local-audit"
|
||||
|
||||
// auditBarrierPrefix is the prefix to the UUID used in the
|
||||
// barrier view for the audit backends.
|
||||
auditBarrierPrefix = "audit/"
|
||||
@@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
|
||||
}
|
||||
|
||||
// Generate a new UUID and view
|
||||
entryUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
if entry.UUID == "" {
|
||||
entryUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry.UUID = entryUUID
|
||||
}
|
||||
entry.UUID = entryUUID
|
||||
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
|
||||
viewPath := auditBarrierPrefix + entry.UUID + "/"
|
||||
view := NewBarrierView(c.barrier, viewPath)
|
||||
|
||||
// Lookup the new backend
|
||||
backend, err := c.newAuditBackend(entry, view, entry.Options)
|
||||
@@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {
|
||||
|
||||
c.removeAuditReloadFunc(entry)
|
||||
|
||||
// When unmounting all entries the JSON code will load back up from storage
|
||||
// as a nil slice, which kills tests...just set it nil explicitly
|
||||
if len(newTable.Entries) == 0 {
|
||||
newTable.Entries = nil
|
||||
}
|
||||
|
||||
// Update the audit table
|
||||
if err := c.persistAudit(newTable); err != nil {
|
||||
return true, errors.New("failed to update audit table")
|
||||
@@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) {
|
||||
if c.logger.IsInfo() {
|
||||
c.logger.Info("core: disabled audit backend", "path", path)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// loadAudits is invoked as part of postUnseal to load the audit table
|
||||
func (c *Core) loadAudits() error {
|
||||
auditTable := &MountTable{}
|
||||
localAuditTable := &MountTable{}
|
||||
|
||||
// Load the existing audit table
|
||||
raw, err := c.barrier.Get(coreAuditConfigPath)
|
||||
@@ -144,6 +158,11 @@ func (c *Core) loadAudits() error {
|
||||
c.logger.Error("core: failed to read audit table", "error", err)
|
||||
return errLoadAuditFailed
|
||||
}
|
||||
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to read local audit table", "error", err)
|
||||
return errLoadAuditFailed
|
||||
}
|
||||
|
||||
c.auditLock.Lock()
|
||||
defer c.auditLock.Unlock()
|
||||
@@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
|
||||
}
|
||||
c.audit = auditTable
|
||||
}
|
||||
if rawLocal != nil {
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
|
||||
c.logger.Error("core: failed to decode local audit table", "error", err)
|
||||
return errLoadAuditFailed
|
||||
}
|
||||
c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
|
||||
}
|
||||
|
||||
// Done if we have restored the audit table
|
||||
if c.audit != nil {
|
||||
@@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error {
|
||||
}
|
||||
}
|
||||
|
||||
nonLocalAudit := &MountTable{
|
||||
Type: auditTableType,
|
||||
}
|
||||
|
||||
localAudit := &MountTable{
|
||||
Type: auditTableType,
|
||||
}
|
||||
|
||||
for _, entry := range table.Entries {
|
||||
if entry.Local {
|
||||
localAudit.Entries = append(localAudit.Entries, entry)
|
||||
} else {
|
||||
nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the table
|
||||
raw, err := json.Marshal(table)
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to encode audit table", "error", err)
|
||||
c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an entry
|
||||
entry := &Entry{
|
||||
Key: coreAuditConfigPath,
|
||||
Value: raw,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
// Write to the physical backend
|
||||
@@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error {
|
||||
c.logger.Error("core: failed to persist audit table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Repeat with local audit
|
||||
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
entry = &Entry{
|
||||
Key: coreLocalAuditConfigPath,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
if err := c.barrier.Put(entry); err != nil {
|
||||
c.logger.Error("core: failed to persist local audit table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {
|
||||
|
||||
for _, entry := range c.audit.Entries {
|
||||
// Create a barrier view using the UUID
|
||||
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
|
||||
viewPath := auditBarrierPrefix + entry.UUID + "/"
|
||||
view := NewBarrierView(c.barrier, viewPath)
|
||||
|
||||
// Initialize the backend
|
||||
audit, err := c.newAuditBackend(entry, view, entry.Options)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
@@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the local table actually gets populated as expected with local
|
||||
// entries, and that upon reading the entries from both are recombined
|
||||
// correctly
|
||||
func TestCore_EnableAudit_Local(t *testing.T) {
|
||||
c, _, _ := TestCoreUnsealed(t)
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return &NoopAudit{
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
return nil, fmt.Errorf("failing enabling")
|
||||
}
|
||||
|
||||
c.audit = &MountTable{
|
||||
Type: auditTableType,
|
||||
Entries: []*MountEntry{
|
||||
&MountEntry{
|
||||
Table: auditTableType,
|
||||
Path: "noop/",
|
||||
Type: "noop",
|
||||
UUID: "abcd",
|
||||
},
|
||||
&MountEntry{
|
||||
Table: auditTableType,
|
||||
Path: "noop2/",
|
||||
Type: "noop",
|
||||
UUID: "bcde",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Both should set up successfully
|
||||
err := c.setupAudits()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawLocal == nil {
|
||||
t.Fatal("expected non-nil local audit")
|
||||
}
|
||||
localAuditTable := &MountTable{}
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(localAuditTable.Entries) > 0 {
|
||||
t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
|
||||
}
|
||||
|
||||
c.audit.Entries[1].Local = true
|
||||
if err := c.persistAudit(c.audit); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawLocal == nil {
|
||||
t.Fatal("expected non-nil local audit")
|
||||
}
|
||||
localAuditTable = &MountTable{}
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(localAuditTable.Entries) != 1 {
|
||||
t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
|
||||
}
|
||||
|
||||
oldAudit := c.audit
|
||||
if err := c.loadAudits(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(oldAudit, c.audit) {
|
||||
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
|
||||
}
|
||||
|
||||
if len(c.audit.Entries) != 2 {
|
||||
t.Fatalf("expected two audit entries, got %#v", localAuditTable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCore_DisableAudit(t *testing.T) {
|
||||
c, keys, _ := TestCoreUnsealed(t)
|
||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||
@@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {
|
||||
|
||||
// Verify matching mount tables
|
||||
if !reflect.DeepEqual(c.audit, c2.audit) {
|
||||
t.Fatalf("mismatch: %v %v", c.audit, c2.audit)
|
||||
t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -17,6 +16,10 @@ const (
|
||||
// can only be viewed or modified after an unseal.
|
||||
coreAuthConfigPath = "core/auth"
|
||||
|
||||
// coreLocalAuthConfigPath is used to store credential configuration for
|
||||
// local (non-replicated) mounts
|
||||
coreLocalAuthConfigPath = "core/local-auth"
|
||||
|
||||
// credentialBarrierPrefix is the prefix to the UUID used in the
|
||||
// barrier view for the credential backends.
|
||||
credentialBarrierPrefix = "auth/"
|
||||
@@ -71,16 +74,25 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
||||
}
|
||||
|
||||
// Generate a new UUID and view
|
||||
entryUUID, err := uuid.GenerateUUID()
|
||||
if entry.UUID == "" {
|
||||
entryUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry.UUID = entryUUID
|
||||
}
|
||||
|
||||
viewPath := credentialBarrierPrefix + entry.UUID + "/"
|
||||
view := NewBarrierView(c.barrier, viewPath)
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
|
||||
// Create the new backend
|
||||
backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry.UUID = entryUUID
|
||||
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
||||
|
||||
// Create the new backend
|
||||
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
|
||||
if err != nil {
|
||||
if err := backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
|
||||
fullPath := credentialRoutePrefix + path
|
||||
view := c.router.MatchingStorageView(fullPath)
|
||||
if view == nil {
|
||||
return false, fmt.Errorf("no matching backend")
|
||||
return false, fmt.Errorf("no matching backend %s", fullPath)
|
||||
}
|
||||
|
||||
// Mark the entry as tainted
|
||||
@@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error {
|
||||
// loadCredentials is invoked as part of postUnseal to load the auth table
|
||||
func (c *Core) loadCredentials() error {
|
||||
authTable := &MountTable{}
|
||||
localAuthTable := &MountTable{}
|
||||
|
||||
// Load the existing mount table
|
||||
raw, err := c.barrier.Get(coreAuthConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to read auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to read local auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
@@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error {
|
||||
}
|
||||
c.auth = authTable
|
||||
}
|
||||
if rawLocal != nil {
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil {
|
||||
c.logger.Error("core: failed to decode local auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...)
|
||||
}
|
||||
|
||||
// Done if we have restored the auth table
|
||||
if c.auth != nil {
|
||||
@@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error {
|
||||
}
|
||||
}
|
||||
|
||||
nonLocalAuth := &MountTable{
|
||||
Type: credentialTableType,
|
||||
}
|
||||
|
||||
localAuth := &MountTable{
|
||||
Type: credentialTableType,
|
||||
}
|
||||
|
||||
for _, entry := range table.Entries {
|
||||
if entry.Local {
|
||||
localAuth.Entries = append(localAuth.Entries, entry)
|
||||
} else {
|
||||
nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the table
|
||||
raw, err := json.Marshal(table)
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to encode auth table", "error", err)
|
||||
c.logger.Error("core: failed to encode and/or compress auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an entry
|
||||
entry := &Entry{
|
||||
Key: coreAuthConfigPath,
|
||||
Value: raw,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
// Write to the physical backend
|
||||
@@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error {
|
||||
c.logger.Error("core: failed to persist auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Repeat with local auth
|
||||
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to encode and/or compress local auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
entry = &Entry{
|
||||
Key: coreLocalAuthConfigPath,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
if err := c.barrier.Put(entry); err != nil {
|
||||
c.logger.Error("core: failed to persist local auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error {
|
||||
}
|
||||
|
||||
// Create a barrier view using the UUID
|
||||
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
||||
viewPath := credentialBarrierPrefix + entry.UUID + "/"
|
||||
view = NewBarrierView(c.barrier, viewPath)
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
|
||||
// Initialize the backend
|
||||
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
|
||||
backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
|
||||
if err := backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mount the backend
|
||||
path := credentialRoutePrefix + entry.Path
|
||||
err = c.router.Mount(backend, path, entry, view)
|
||||
|
||||
@@ -2,8 +2,10 @@ package vault
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
@@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the local table actually gets populated as expected with local
|
||||
// entries, and that upon reading the entries from both are recombined
|
||||
// correctly
|
||||
func TestCore_EnableCredential_Local(t *testing.T) {
|
||||
c, _, _ := TestCoreUnsealed(t)
|
||||
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
return &NoopBackend{}, nil
|
||||
}
|
||||
|
||||
c.auth = &MountTable{
|
||||
Type: credentialTableType,
|
||||
Entries: []*MountEntry{
|
||||
&MountEntry{
|
||||
Table: credentialTableType,
|
||||
Path: "noop/",
|
||||
Type: "noop",
|
||||
UUID: "abcd",
|
||||
},
|
||||
&MountEntry{
|
||||
Table: credentialTableType,
|
||||
Path: "noop2/",
|
||||
Type: "noop",
|
||||
UUID: "bcde",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Both should set up successfully
|
||||
err := c.setupCredentials()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawLocal == nil {
|
||||
t.Fatal("expected non-nil local credential")
|
||||
}
|
||||
localCredentialTable := &MountTable{}
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(localCredentialTable.Entries) > 0 {
|
||||
t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable)
|
||||
}
|
||||
|
||||
c.auth.Entries[1].Local = true
|
||||
if err := c.persistAuth(c.auth); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawLocal == nil {
|
||||
t.Fatal("expected non-nil local credential")
|
||||
}
|
||||
localCredentialTable = &MountTable{}
|
||||
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(localCredentialTable.Entries) != 1 {
|
||||
t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable)
|
||||
}
|
||||
|
||||
oldCredential := c.auth
|
||||
if err := c.loadCredentials(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(oldCredential, c.auth) {
|
||||
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth)
|
||||
}
|
||||
|
||||
if len(c.auth.Entries) != 2 {
|
||||
t.Fatalf("expected two credential entries, got %#v", localCredentialTable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCore_EnableCredential_twice_409(t *testing.T) {
|
||||
c, _, _ := TestCoreUnsealed(t)
|
||||
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
@@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) {
|
||||
}
|
||||
|
||||
existed, err := c.disableCredential("foo")
|
||||
if existed || err.Error() != "no matching backend" {
|
||||
if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) {
|
||||
t.Fatalf("existed: %v; err: %v", existed, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,11 @@ type SecurityBarrier interface {
|
||||
// VerifyMaster is used to check if the given key matches the master key
|
||||
VerifyMaster(key []byte) error
|
||||
|
||||
// SetMasterKey is used to directly set a new master key. This is used in
|
||||
// repliated scenarios due to the chicken and egg problem of reloading the
|
||||
// keyring from disk before we have the master key to decrypt it.
|
||||
SetMasterKey(key []byte) error
|
||||
|
||||
// ReloadKeyring is used to re-read the underlying keyring.
|
||||
// This is used for HA deployments to ensure the latest keyring
|
||||
// is present in the leader.
|
||||
@@ -119,8 +124,14 @@ type SecurityBarrier interface {
|
||||
// Rekey is used to change the master key used to protect the keyring
|
||||
Rekey([]byte) error
|
||||
|
||||
// For replication we must send over the keyring, so this must be available
|
||||
Keyring() (*Keyring, error)
|
||||
|
||||
// SecurityBarrier must provide the storage APIs
|
||||
BarrierStorage
|
||||
|
||||
// SecurityBarrier must provide the encryption APIs
|
||||
BarrierEncryptor
|
||||
}
|
||||
|
||||
// BarrierStorage is the storage only interface required for a Barrier.
|
||||
@@ -139,6 +150,14 @@ type BarrierStorage interface {
|
||||
List(prefix string) ([]string, error)
|
||||
}
|
||||
|
||||
// BarrierEncryptor is the in memory only interface that does not actually
|
||||
// use the underlying barrier. It is used for lower level modules like the
|
||||
// Write-Ahead-Log and Merkle index to allow them to use the barrier.
|
||||
type BarrierEncryptor interface {
|
||||
Encrypt(key string, plaintext []byte) ([]byte, error)
|
||||
Decrypt(key string, ciphertext []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// Entry is used to represent data stored by the security barrier
|
||||
type Entry struct {
|
||||
Key string
|
||||
|
||||
@@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) {
|
||||
func (b *AESGCMBarrier) Rekey(key []byte) error {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
if b.sealed {
|
||||
return ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Verify the key size
|
||||
min, max := b.KeyLength()
|
||||
if len(key) < min || len(key) > max {
|
||||
return fmt.Errorf("Key size must be %d or %d", min, max)
|
||||
newKeyring, err := b.updateMasterKeyCommon(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add a new encryption key
|
||||
newKeyring := b.keyring.SetMasterKey(key)
|
||||
|
||||
// Persist the new keyring
|
||||
if err := b.persistKeyring(newKeyring); err != nil {
|
||||
return err
|
||||
@@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMasterKey updates the keyring's in-memory master key but does not persist
|
||||
// anything to storage
|
||||
func (b *AESGCMBarrier) SetMasterKey(key []byte) error {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
newKeyring, err := b.updateMasterKeyCommon(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Swap the keyrings
|
||||
oldKeyring := b.keyring
|
||||
b.keyring = newKeyring
|
||||
oldKeyring.Zeroize(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performs common tasks related to updating the master key; note that the lock
|
||||
// must be held before calling this function
|
||||
func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Verify the key size
|
||||
min, max := b.KeyLength()
|
||||
if len(key) < min || len(key) > max {
|
||||
return nil, fmt.Errorf("Key size must be %d or %d", min, max)
|
||||
}
|
||||
|
||||
return b.keyring.SetMasterKey(key), nil
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry
|
||||
func (b *AESGCMBarrier) Put(entry *Entry) error {
|
||||
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
|
||||
@@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro
|
||||
return nil, fmt.Errorf("version bytes mis-match")
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface
|
||||
func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
term := b.keyring.ActiveTerm()
|
||||
primary, err := b.aeadForTerm(term)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := b.encrypt(key, term, primary, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface
|
||||
func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Decrypt the ciphertext
|
||||
plain, err := b.decryptKeyring(key, ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %v", err)
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Keyring() (*Keyring, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
return b.keyring.Clone(), nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ var (
|
||||
)
|
||||
|
||||
// mockBarrier returns a physical backend, security barrier, and master key
|
||||
func mockBarrier(t *testing.T) (physical.Backend, SecurityBarrier, []byte) {
|
||||
func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) {
|
||||
|
||||
inm := physical.NewInmem(logger)
|
||||
b, err := NewAESGCMBarrier(inm)
|
||||
@@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) {
|
||||
t.Fatalf("key length protection failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_BarrierEncryptor(t *testing.T) {
|
||||
inm := physical.NewInmem(logger)
|
||||
b, err := NewAESGCMBarrier(inm)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Initialize and unseal
|
||||
key, _ := b.GenerateKey()
|
||||
b.Initialize(key)
|
||||
b.Unseal(key)
|
||||
|
||||
cipher, err := b.Encrypt("foo", []byte("quick brown fox"))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
plain, err := b.Decrypt("foo", cipher)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if string(plain) != "quick brown fox" {
|
||||
t.Fatalf("bad: %s", plain)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
|
||||
|
||||
// logical.Storage impl.
|
||||
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
if err := v.sanityCheck(entry.Key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expandedKey := v.expandKey(entry.Key)
|
||||
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
|
||||
nested := &Entry{
|
||||
Key: v.expandKey(entry.Key),
|
||||
Key: expandedKey,
|
||||
Value: entry.Value,
|
||||
}
|
||||
return v.barrier.Put(nested)
|
||||
@@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
||||
|
||||
// logical.Storage impl.
|
||||
func (v *BarrierView) Delete(key string) error {
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
if err := v.sanityCheck(key); err != nil {
|
||||
return err
|
||||
}
|
||||
return v.barrier.Delete(v.expandKey(key))
|
||||
|
||||
expandedKey := v.expandKey(key)
|
||||
|
||||
if v.readonly {
|
||||
return logical.ErrReadOnly
|
||||
}
|
||||
|
||||
|
||||
return v.barrier.Delete(expandedKey)
|
||||
}
|
||||
|
||||
// SubView constructs a nested sub-view using the given prefix
|
||||
|
||||
@@ -1,27 +1,19 @@
|
||||
package vault
|
||||
|
||||
import "sort"
|
||||
import (
|
||||
"sort"
|
||||
|
||||
// Struct to identify user input errors.
|
||||
// This is helpful in responding the appropriate status codes to clients
|
||||
// from the HTTP endpoints.
|
||||
type StatusBadRequest struct {
|
||||
Err string
|
||||
}
|
||||
|
||||
// Implementing error interface
|
||||
func (s *StatusBadRequest) Error() string {
|
||||
return s.Err
|
||||
}
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// Capabilities is used to fetch the capabilities of the given token on the given path
|
||||
func (c *Core) Capabilities(token, path string) ([]string, error) {
|
||||
if path == "" {
|
||||
return nil, &StatusBadRequest{Err: "missing path"}
|
||||
return nil, &logical.StatusBadRequest{Err: "missing path"}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return nil, &StatusBadRequest{Err: "missing token"}
|
||||
return nil, &logical.StatusBadRequest{Err: "missing token"}
|
||||
}
|
||||
|
||||
te, err := c.tokenStore.Lookup(token)
|
||||
@@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
if te == nil {
|
||||
return nil, &StatusBadRequest{Err: "invalid token"}
|
||||
return nil, &logical.StatusBadRequest{Err: "invalid token"}
|
||||
}
|
||||
|
||||
if te.Policies == nil {
|
||||
|
||||
@@ -43,7 +43,7 @@ var (
|
||||
|
||||
// This can be one of a few key types so the different params may or may not be filled
|
||||
type clusterKeyParams struct {
|
||||
Type string `json:"type"`
|
||||
Type string `json:"type" structs:"type" mapstructure:"type"`
|
||||
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
|
||||
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
|
||||
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
|
||||
@@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
|
||||
c.logger.Info("core/stopClusterListener: success")
|
||||
}
|
||||
|
||||
// ClusterTLSConfig generates a TLS configuration based on the local cluster
|
||||
// key and cert.
|
||||
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
|
||||
// cluster key and cert.
|
||||
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
||||
cluster, err := c.Cluster()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cluster == nil {
|
||||
return nil, fmt.Errorf("cluster information is nil")
|
||||
return nil, fmt.Errorf("local cluster information is nil")
|
||||
}
|
||||
|
||||
// Prevent data races with the TLS parameters
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
|
||||
return nil, fmt.Errorf("cluster certificate is nil")
|
||||
forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0
|
||||
|
||||
var parsedCert *x509.Certificate
|
||||
if forwarding {
|
||||
parsedCert, err = x509.ParseCertificate(c.localClusterCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||
}
|
||||
|
||||
// This is idempotent, so be sure it's been added
|
||||
c.clusterCertPool.AddCert(parsedCert)
|
||||
}
|
||||
|
||||
parsedCert, err := x509.ParseCertificate(c.localClusterCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||
}
|
||||
nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
c.clusterParamsLock.RLock()
|
||||
defer c.clusterParamsLock.RUnlock()
|
||||
|
||||
// This is idempotent, so be sure it's been added
|
||||
c.clusterCertPool.AddCert(parsedCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
tls.Certificate{
|
||||
if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{c.localClusterCert},
|
||||
PrivateKey: c.localClusterPrivateKey,
|
||||
},
|
||||
},
|
||||
RootCAs: c.clusterCertPool,
|
||||
ServerName: parsedCert.Subject.CommonName,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: c.clusterCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var clientCertificates []tls.Certificate
|
||||
if forwarding {
|
||||
clientCertificates = append(clientCertificates, tls.Certificate{
|
||||
Certificate: [][]byte{c.localClusterCert},
|
||||
PrivateKey: c.localClusterPrivateKey,
|
||||
})
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
// We need this here for the client side
|
||||
Certificates: clientCertificates,
|
||||
RootCAs: c.clusterCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: c.clusterCertPool,
|
||||
GetCertificate: nameLookup,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
if forwarding {
|
||||
tlsConfig.ServerName = parsedCert.Subject.CommonName
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
@@ -100,7 +101,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
|
||||
checkListenersFunc := func(expectFail bool) {
|
||||
tlsConfig, err := cores[0].ClusterTLSConfig()
|
||||
if err != nil {
|
||||
if err.Error() != ErrSealed.Error() {
|
||||
if err.Error() != consts.ErrSealed.Error() {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsConfig = lastTLSConfig
|
||||
|
||||
168
vault/core.go
168
vault/core.go
@@ -1,9 +1,9 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
@@ -56,17 +57,14 @@ const (
|
||||
// leaderPrefixCleanDelay is how long to wait between deletions
|
||||
// of orphaned leader keys, to prevent slamming the backend.
|
||||
leaderPrefixCleanDelay = 200 * time.Millisecond
|
||||
|
||||
// coreKeyringCanaryPath is used as a canary to indicate to replicated
|
||||
// clusters that they need to perform a rekey operation synchronously; this
|
||||
// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
|
||||
coreKeyringCanaryPath = "core/canary-keyring"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrSealed is returned if an operation is performed on
|
||||
// a sealed barrier. No operation is expected to succeed before unsealing
|
||||
ErrSealed = errors.New("Vault is sealed")
|
||||
|
||||
// ErrStandby is returned if an operation is performed on
|
||||
// a standby Vault. No operation is expected to succeed until active.
|
||||
ErrStandby = errors.New("Vault is in standby mode")
|
||||
|
||||
// ErrAlreadyInit is returned if the core is already
|
||||
// initialized. This prevents a re-initialization.
|
||||
ErrAlreadyInit = errors.New("Vault is already initialized")
|
||||
@@ -87,6 +85,12 @@ var (
|
||||
// step down of the active node, to prevent instantly regrabbing the lock.
|
||||
// It's var not const so that tests can manipulate it.
|
||||
manualStepDownSleepPeriod = 10 * time.Second
|
||||
|
||||
// Functions only in the Enterprise version
|
||||
enterprisePostUnseal = enterprisePostUnsealImpl
|
||||
enterprisePreSeal = enterprisePreSealImpl
|
||||
startReplication = startReplicationImpl
|
||||
stopReplication = stopReplicationImpl
|
||||
)
|
||||
|
||||
// ReloadFunc are functions that are called when a reload is requested.
|
||||
@@ -133,6 +137,11 @@ type unlockInformation struct {
|
||||
// interface for API handlers and is responsible for managing the logical and physical
|
||||
// backends, router, security barrier, and audit trails.
|
||||
type Core struct {
|
||||
// N.B.: This is used to populate a dev token down replication, as
|
||||
// otherwise, after replication is started, a dev would have to go through
|
||||
// the generate-root process simply to talk to the new follower cluster.
|
||||
devToken string
|
||||
|
||||
// HABackend may be available depending on the physical backend
|
||||
ha physical.HABackend
|
||||
|
||||
@@ -268,7 +277,7 @@ type Core struct {
|
||||
//
|
||||
// Name
|
||||
clusterName string
|
||||
// Used to modify cluster TLS params
|
||||
// Used to modify cluster parameters
|
||||
clusterParamsLock sync.RWMutex
|
||||
// The private key stored in the barrier used for establishing
|
||||
// mutually-authenticated connections between Vault cluster members
|
||||
@@ -310,11 +319,13 @@ type Core struct {
|
||||
|
||||
// replicationState keeps the current replication state cached for quick
|
||||
// lookup
|
||||
replicationState logical.ReplicationState
|
||||
replicationState consts.ReplicationState
|
||||
}
|
||||
|
||||
// CoreConfig is used to parameterize a core
|
||||
type CoreConfig struct {
|
||||
DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"`
|
||||
|
||||
LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`
|
||||
|
||||
CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
|
||||
@@ -390,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
||||
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
|
||||
}
|
||||
|
||||
// Setup the core
|
||||
c := &Core{
|
||||
redirectAddr: conf.RedirectAddr,
|
||||
clusterAddr: conf.ClusterAddr,
|
||||
physical: conf.Physical,
|
||||
seal: conf.Seal,
|
||||
router: NewRouter(),
|
||||
sealed: true,
|
||||
standby: true,
|
||||
logger: conf.Logger,
|
||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
cachingDisabled: conf.DisableCache,
|
||||
clusterName: conf.ClusterName,
|
||||
clusterCertPool: x509.NewCertPool(),
|
||||
clusterListenerShutdownCh: make(chan struct{}),
|
||||
clusterListenerShutdownSuccessCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Wrap the physical backend in a cache layer if enabled and not already wrapped
|
||||
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
|
||||
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
|
||||
}
|
||||
|
||||
if !conf.DisableMlock {
|
||||
// Ensure our memory usage is locked into physical RAM
|
||||
if err := mlock.LockMemory(); err != nil {
|
||||
@@ -407,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
||||
}
|
||||
|
||||
// Construct a new AES-GCM barrier
|
||||
barrier, err := NewAESGCMBarrier(conf.Physical)
|
||||
var err error
|
||||
c.barrier, err = NewAESGCMBarrier(c.physical)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier setup failed: %v", err)
|
||||
}
|
||||
|
||||
// Setup the core
|
||||
c := &Core{
|
||||
redirectAddr: conf.RedirectAddr,
|
||||
clusterAddr: conf.ClusterAddr,
|
||||
physical: conf.Physical,
|
||||
seal: conf.Seal,
|
||||
barrier: barrier,
|
||||
router: NewRouter(),
|
||||
sealed: true,
|
||||
standby: true,
|
||||
logger: conf.Logger,
|
||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
cachingDisabled: conf.DisableCache,
|
||||
clusterName: conf.ClusterName,
|
||||
clusterCertPool: x509.NewCertPool(),
|
||||
clusterListenerShutdownCh: make(chan struct{}),
|
||||
clusterListenerShutdownSuccessCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Wrap the backend in a cache unless disabled
|
||||
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
|
||||
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
|
||||
}
|
||||
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||
c.ha = conf.HAPhysical
|
||||
}
|
||||
@@ -518,10 +529,10 @@ func (c *Core) LookupToken(token string) (*TokenEntry, error) {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.sealed {
|
||||
return nil, ErrSealed
|
||||
return nil, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
return nil, ErrStandby
|
||||
return nil, consts.ErrStandby
|
||||
}
|
||||
|
||||
// Many tests don't have a token store running
|
||||
@@ -656,7 +667,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
|
||||
|
||||
// Check if sealed
|
||||
if c.sealed {
|
||||
return false, "", ErrSealed
|
||||
return false, "", consts.ErrSealed
|
||||
}
|
||||
|
||||
// Check if HA enabled
|
||||
@@ -803,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
masterKey, err := c.unsealPart(config, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if masterKey != nil {
|
||||
return c.unsealInternal(masterKey)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) {
|
||||
// Check if we already have this piece
|
||||
if c.unlockInfo != nil {
|
||||
for _, existing := range c.unlockInfo.Parts {
|
||||
if bytes.Equal(existing, key) {
|
||||
return false, nil
|
||||
if subtle.ConstantTimeCompare(existing, key) == 1 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
uuid, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, err
|
||||
}
|
||||
c.unlockInfo = &unlockInformation{
|
||||
Nonce: uuid,
|
||||
@@ -828,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) {
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
|
||||
}
|
||||
return false, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Best-effort memzero of unlock parts once we're done with them
|
||||
defer func() {
|
||||
for i, _ := range c.unlockInfo.Parts {
|
||||
memzero(c.unlockInfo.Parts[i])
|
||||
}
|
||||
c.unlockInfo = nil
|
||||
}()
|
||||
|
||||
// Recover the master key
|
||||
var masterKey []byte
|
||||
var err error
|
||||
if config.SecretThreshold == 1 {
|
||||
masterKey = c.unlockInfo.Parts[0]
|
||||
c.unlockInfo = nil
|
||||
masterKey = make([]byte, len(c.unlockInfo.Parts[0]))
|
||||
copy(masterKey, c.unlockInfo.Parts[0])
|
||||
} else {
|
||||
masterKey, err = shamir.Combine(c.unlockInfo.Parts)
|
||||
c.unlockInfo = nil
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to compute master key: %v", err)
|
||||
return nil, fmt.Errorf("failed to compute master key: %v", err)
|
||||
}
|
||||
}
|
||||
defer memzero(masterKey)
|
||||
|
||||
return c.unsealInternal(masterKey)
|
||||
return masterKey, nil
|
||||
}
|
||||
|
||||
// This must be called with the state write lock held
|
||||
func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
|
||||
defer memzero(masterKey)
|
||||
|
||||
// Attempt to unlock
|
||||
if err := c.barrier.Unseal(masterKey); err != nil {
|
||||
return false, err
|
||||
@@ -867,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
|
||||
c.logger.Warn("core: vault is sealed")
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := c.postUnseal(); err != nil {
|
||||
c.logger.Error("core: post-unseal setup failed", "error", err)
|
||||
c.barrier.Seal()
|
||||
c.logger.Warn("core: vault is sealed")
|
||||
return false, err
|
||||
}
|
||||
|
||||
c.standby = false
|
||||
} else {
|
||||
// Go to standby mode, wait until we are active to unseal
|
||||
@@ -1168,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) {
|
||||
if purgable, ok := c.physical.(physical.Purgable); ok {
|
||||
purgable.Purge()
|
||||
}
|
||||
|
||||
// HA mode requires us to handle keyring rotation and rekeying
|
||||
if c.ha != nil {
|
||||
// We want to reload these from disk so that in case of a rekey we're
|
||||
@@ -1190,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := enterprisePostUnseal(c); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.ensureWrappingKey(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1251,6 +1290,7 @@ func (c *Core) preSeal() error {
|
||||
c.metricsCh = nil
|
||||
}
|
||||
var result error
|
||||
|
||||
if c.ha != nil {
|
||||
c.stopClusterListener()
|
||||
}
|
||||
@@ -1273,6 +1313,10 @@ func (c *Core) preSeal() error {
|
||||
if err := c.unloadMounts(); err != nil {
|
||||
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
|
||||
}
|
||||
if err := enterprisePreSeal(c); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// Purge the backend if supported
|
||||
if purgable, ok := c.physical.(physical.Purgable); ok {
|
||||
purgable.Purge()
|
||||
@@ -1281,6 +1325,22 @@ func (c *Core) preSeal() error {
|
||||
return result
|
||||
}
|
||||
|
||||
func enterprisePostUnsealImpl(c *Core) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enterprisePreSealImpl(c *Core) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func startReplicationImpl(c *Core) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopReplicationImpl(c *Core) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runStandby is a long running routine that is used when an HA backend
|
||||
// is enabled. It waits until we are leader and switches this Vault to
|
||||
// active.
|
||||
@@ -1599,6 +1659,14 @@ func (c *Core) emitMetrics(stopCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Core) ReplicationState() consts.ReplicationState {
|
||||
var state consts.ReplicationState
|
||||
c.clusterParamsLock.RLock()
|
||||
state = c.replicationState
|
||||
c.clusterParamsLock.RUnlock()
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Core) SealAccess() *SealAccess {
|
||||
sa := &SealAccess{}
|
||||
sa.SetSeal(c.seal)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
@@ -198,7 +199,7 @@ func TestCore_Route_Sealed(t *testing.T) {
|
||||
Path: "sys/mounts",
|
||||
}
|
||||
_, err := c.HandleRequest(req)
|
||||
if err != ErrSealed {
|
||||
if err != consts.ErrSealed {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
@@ -1541,7 +1542,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
||||
|
||||
// Request should fail in standby mode
|
||||
_, err = core2.HandleRequest(req)
|
||||
if err != ErrStandby {
|
||||
if err != consts.ErrStandby {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package vault
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
@@ -79,8 +80,8 @@ func (d dynamicSystemView) CachingDisabled() bool {
|
||||
}
|
||||
|
||||
// Checks if this is a primary Vault instance.
|
||||
func (d dynamicSystemView) ReplicationState() logical.ReplicationState {
|
||||
var state logical.ReplicationState
|
||||
func (d dynamicSystemView) ReplicationState() consts.ReplicationState {
|
||||
var state consts.ReplicationState
|
||||
d.core.clusterParamsLock.RLock()
|
||||
state = d.core.replicationState
|
||||
d.core.clusterParamsLock.RUnlock()
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
@@ -125,46 +126,114 @@ func (m *ExpirationManager) Restore() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan for leases: %v", err)
|
||||
}
|
||||
|
||||
m.logger.Debug("expiration: leases collected", "num_existing", len(existing))
|
||||
|
||||
// Restore each key
|
||||
for i, leaseID := range existing {
|
||||
if i%500 == 0 {
|
||||
m.logger.Trace("expiration: leases loading", "progress", i)
|
||||
}
|
||||
// Load the entry
|
||||
le, err := m.loadEntry(leaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Make the channels used for the worker pool
|
||||
broker := make(chan string)
|
||||
quit := make(chan bool)
|
||||
// Buffer these channels to prevent deadlocks
|
||||
errs := make(chan error, len(existing))
|
||||
result := make(chan *leaseEntry, len(existing))
|
||||
|
||||
// If there is no entry, nothing to restore
|
||||
if le == nil {
|
||||
continue
|
||||
}
|
||||
// Use a wait group
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
// If there is no expiry time, don't do anything
|
||||
if le.ExpireTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
// Create 64 workers to distribute work to
|
||||
for i := 0; i < consts.ExpirationRestoreWorkerCount; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Determine the remaining time to expiration
|
||||
expires := le.ExpireTime.Sub(time.Now())
|
||||
if expires <= 0 {
|
||||
expires = minRevokeDelay
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case leaseID, ok := <-broker:
|
||||
// broker has been closed, we are done
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Setup revocation timer
|
||||
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
|
||||
m.expireID(le.LeaseID)
|
||||
})
|
||||
le, err := m.loadEntry(leaseID)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// Write results out to the result channel
|
||||
result <- le
|
||||
|
||||
// quit early
|
||||
case <-quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Distribute the collected keys to the workers in a go routine
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i, leaseID := range existing {
|
||||
if i%500 == 0 {
|
||||
m.logger.Trace("expiration: leases loading", "progress", i)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
return
|
||||
|
||||
default:
|
||||
broker <- leaseID
|
||||
}
|
||||
}
|
||||
|
||||
// Close the broker, causing worker routines to exit
|
||||
close(broker)
|
||||
}()
|
||||
|
||||
// Restore each key by pulling from the result chan
|
||||
for i := 0; i < len(existing); i++ {
|
||||
select {
|
||||
case err := <-errs:
|
||||
// Close all go routines
|
||||
close(quit)
|
||||
|
||||
return err
|
||||
|
||||
case le := <-result:
|
||||
|
||||
// If there is no entry, nothing to restore
|
||||
if le == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// If there is no expiry time, don't do anything
|
||||
if le.ExpireTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine the remaining time to expiration
|
||||
expires := le.ExpireTime.Sub(time.Now())
|
||||
if expires <= 0 {
|
||||
expires = minRevokeDelay
|
||||
}
|
||||
|
||||
// Setup revocation timer
|
||||
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
|
||||
m.expireID(le.LeaseID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Let all go routines finish
|
||||
wg.Wait()
|
||||
|
||||
if len(m.pending) > 0 {
|
||||
if m.logger.IsInfo() {
|
||||
m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,23 +2,131 @@ package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/logformat"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
var (
|
||||
testImagePull sync.Once
|
||||
)
|
||||
|
||||
// mockExpiration returns a mock expiration manager
|
||||
func mockExpiration(t *testing.T) *ExpirationManager {
|
||||
func mockExpiration(t testing.TB) *ExpirationManager {
|
||||
_, ts, _, _ := TestCoreWithTokenStore(t)
|
||||
return ts.expiration
|
||||
}
|
||||
|
||||
func mockBackendExpiration(t testing.TB, backend physical.Backend) (*Core, *ExpirationManager) {
|
||||
c, ts, _, _ := TestCoreWithBackendTokenStore(t, backend)
|
||||
return c, ts.expiration
|
||||
}
|
||||
|
||||
func BenchmarkExpiration_Restore_Etcd(b *testing.B) {
|
||||
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
|
||||
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
|
||||
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
physicalBackend, err := physical.NewBackend("etcd", logger, map[string]string{
|
||||
"address": addr,
|
||||
"path": randPath,
|
||||
"max_parallel": "256",
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
|
||||
}
|
||||
|
||||
func BenchmarkExpiration_Restore_Consul(b *testing.B) {
|
||||
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
|
||||
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
|
||||
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
physicalBackend, err := physical.NewBackend("consul", logger, map[string]string{
|
||||
"address": addr,
|
||||
"path": randPath,
|
||||
"max_parallel": "256",
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
|
||||
}
|
||||
|
||||
func BenchmarkExpiration_Restore_InMem(b *testing.B) {
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
benchmarkExpirationBackend(b, physical.NewInmem(logger), 100000) // 100,000 Leases
|
||||
}
|
||||
|
||||
func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, numLeases int) {
|
||||
c, exp := mockBackendExpiration(b, physicalBackend)
|
||||
noop := &NoopBackend{}
|
||||
view := NewBarrierView(c.barrier, "logical/")
|
||||
meUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: meUUID}, view)
|
||||
|
||||
// Register fake leases
|
||||
for i := 0; i < numLeases; i++ {
|
||||
pathUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
req := &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "prod/aws/" + pathUUID,
|
||||
}
|
||||
resp := &logical.Response{
|
||||
Secret: &logical.Secret{
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
TTL: 400 * time.Second,
|
||||
},
|
||||
},
|
||||
Data: map[string]interface{}{
|
||||
"access_key": "xyz",
|
||||
"secret_key": "abcd",
|
||||
},
|
||||
}
|
||||
_, err = exp.Register(req, resp)
|
||||
if err != nil {
|
||||
b.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop everything
|
||||
err = exp.Stop()
|
||||
if err != nil {
|
||||
b.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err = exp.Restore()
|
||||
// Restore
|
||||
if err != nil {
|
||||
b.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
b.StopTimer()
|
||||
}
|
||||
|
||||
func TestExpiration_Restore(t *testing.T) {
|
||||
exp := mockExpiration(t)
|
||||
noop := &NoopBackend{}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/pgpkeys"
|
||||
"github.com/hashicorp/vault/helper/xor"
|
||||
"github.com/hashicorp/vault/shamir"
|
||||
@@ -34,10 +35,10 @@ func (c *Core) GenerateRootProgress() (int, error) {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.sealed {
|
||||
return 0, ErrSealed
|
||||
return 0, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
return 0, ErrStandby
|
||||
return 0, consts.ErrStandby
|
||||
}
|
||||
|
||||
c.generateRootLock.Lock()
|
||||
@@ -52,10 +53,10 @@ func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.sealed {
|
||||
return nil, ErrSealed
|
||||
return nil, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
return nil, ErrStandby
|
||||
return nil, consts.ErrStandby
|
||||
}
|
||||
|
||||
c.generateRootLock.Lock()
|
||||
@@ -101,10 +102,10 @@ func (c *Core) GenerateRootInit(otp, pgpKey string) error {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.sealed {
|
||||
return ErrSealed
|
||||
return consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
return ErrStandby
|
||||
return consts.ErrStandby
|
||||
}
|
||||
|
||||
c.generateRootLock.Lock()
|
||||
@@ -170,10 +171,10 @@ func (c *Core) GenerateRootUpdate(key []byte, nonce string) (*GenerateRootResult
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.sealed {
|
||||
return nil, ErrSealed
|
||||
return nil, consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
return nil, ErrStandby
|
||||
return nil, consts.ErrStandby
|
||||
}
|
||||
|
||||
c.generateRootLock.Lock()
|
||||
@@ -308,10 +309,10 @@ func (c *Core) GenerateRootCancel() error {
|
||||
c.stateLock.RLock()
|
||||
defer c.stateLock.RUnlock()
|
||||
if c.sealed {
|
||||
return ErrSealed
|
||||
return consts.ErrSealed
|
||||
}
|
||||
if c.standby {
|
||||
return ErrStandby
|
||||
return consts.ErrStandby
|
||||
}
|
||||
|
||||
c.generateRootLock.Lock()
|
||||
|
||||
@@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
|
||||
return nil, fmt.Errorf("error initializing seal: %v", err)
|
||||
}
|
||||
|
||||
err = c.seal.SetBarrierConfig(barrierConfig)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to save barrier configuration", "error", err)
|
||||
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
|
||||
}
|
||||
|
||||
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
|
||||
if err != nil {
|
||||
c.logger.Error("core: error generating shares", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we are storing shares, pop them out of the returned results and push
|
||||
// them through the seal
|
||||
if barrierConfig.StoredShares > 0 {
|
||||
var keysToStore [][]byte
|
||||
for i := 0; i < barrierConfig.StoredShares; i++ {
|
||||
keysToStore = append(keysToStore, barrierUnsealKeys[0])
|
||||
barrierUnsealKeys = barrierUnsealKeys[1:]
|
||||
}
|
||||
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
|
||||
c.logger.Error("core: failed to store keys", "error", err)
|
||||
return nil, fmt.Errorf("failed to store keys: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
results := &InitResult{
|
||||
SecretShares: barrierUnsealKeys,
|
||||
}
|
||||
|
||||
// Initialize the barrier
|
||||
if err := c.barrier.Initialize(barrierKey); err != nil {
|
||||
c.logger.Error("core: failed to initialize barrier", "error", err)
|
||||
@@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
|
||||
|
||||
// Ensure the barrier is re-sealed
|
||||
defer func() {
|
||||
// Defers are LIFO so we need to run this here too to ensure the stop
|
||||
// happens before sealing. preSeal also stops, so we just make the
|
||||
// stopping safe against multiple calls.
|
||||
if err := c.barrier.Seal(); err != nil {
|
||||
c.logger.Error("core: failed to seal barrier", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = c.seal.SetBarrierConfig(barrierConfig)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to save barrier configuration", "error", err)
|
||||
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
|
||||
}
|
||||
|
||||
// If we are storing shares, pop them out of the returned results and push
|
||||
// them through the seal
|
||||
if barrierConfig.StoredShares > 0 {
|
||||
var keysToStore [][]byte
|
||||
for i := 0; i < barrierConfig.StoredShares; i++ {
|
||||
keysToStore = append(keysToStore, barrierUnsealKeys[0])
|
||||
barrierUnsealKeys = barrierUnsealKeys[1:]
|
||||
}
|
||||
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
|
||||
c.logger.Error("core: failed to store keys", "error", err)
|
||||
return nil, fmt.Errorf("failed to store keys: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
results := &InitResult{
|
||||
SecretShares: barrierUnsealKeys,
|
||||
}
|
||||
|
||||
// Perform initial setup
|
||||
if err := c.setupCluster(); err != nil {
|
||||
c.logger.Error("core: cluster setup failed during init", "error", err)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user