mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-12-15 01:47:13 +00:00
Merge branch 'master' into acl-parameters-permission
This commit is contained in:
@@ -7,7 +7,7 @@ services:
|
|||||||
- docker
|
- docker
|
||||||
|
|
||||||
go:
|
go:
|
||||||
- 1.8rc2
|
- 1.8
|
||||||
|
|
||||||
matrix:
|
matrix:
|
||||||
allow_failures:
|
allow_failures:
|
||||||
|
|||||||
17
CHANGELOG.md
17
CHANGELOG.md
@@ -1,5 +1,16 @@
|
|||||||
## Next (Unreleased)
|
## Next (Unreleased)
|
||||||
|
|
||||||
|
DEPRECATIONS/CHANGES:
|
||||||
|
|
||||||
|
* List Operations Always Use Trailing Slash: Any list operation, whether via
|
||||||
|
the `GET` or `LIST` HTTP verb, will now internally canonicalize the path to
|
||||||
|
have a trailing slash. This makes policy writing more predictable, as it
|
||||||
|
means clients will no longer work or fail based on which client they're
|
||||||
|
using or which HTTP verb they're using. However, it also means that policies
|
||||||
|
allowing `list` capability must be carefully checked to ensure that they
|
||||||
|
contain a trailing slash; some policies may need to be split into multiple
|
||||||
|
stanzas to accommodate.
|
||||||
|
|
||||||
IMPROVEMENTS:
|
IMPROVEMENTS:
|
||||||
|
|
||||||
* auth/ldap: Use the value of the `LOGNAME` or `USER` env vars for the
|
* auth/ldap: Use the value of the `LOGNAME` or `USER` env vars for the
|
||||||
@@ -7,14 +18,20 @@ IMPROVEMENTS:
|
|||||||
[GH-2154]
|
[GH-2154]
|
||||||
* audit: Support adding a configurable prefix (such as `@cee`) before each
|
* audit: Support adding a configurable prefix (such as `@cee`) before each
|
||||||
line [GH-2359]
|
line [GH-2359]
|
||||||
|
* core: Canonicalize list operations to use a trailing slash [GH-2390]
|
||||||
|
* secret/pki: O (Organization) values can now be set to role-defined values
|
||||||
|
for issued/signed certificates [GH-2369]
|
||||||
|
|
||||||
BUG FIXES:
|
BUG FIXES:
|
||||||
|
|
||||||
|
* audit: When auditing headers use case-insensitive comparisons [GH-2362]
|
||||||
* auth/aws-ec2: Return role period in seconds and not nanoseconds [GH-2374]
|
* auth/aws-ec2: Return role period in seconds and not nanoseconds [GH-2374]
|
||||||
* auth/okta: Fix panic if user had no local groups and/or policies set
|
* auth/okta: Fix panic if user had no local groups and/or policies set
|
||||||
[GH-2367]
|
[GH-2367]
|
||||||
* command/server: Fix parsing of redirect address when port is not mentioned
|
* command/server: Fix parsing of redirect address when port is not mentioned
|
||||||
[GH-2354]
|
[GH-2354]
|
||||||
|
* physical/postgresql: Fix listing returning incorrect results if there were
|
||||||
|
multiple levels of children [GH-2393]
|
||||||
|
|
||||||
## 0.6.5 (February 7th, 2017)
|
## 0.6.5 (February 7th, 2017)
|
||||||
|
|
||||||
|
|||||||
5
Makefile
5
Makefile
@@ -24,6 +24,11 @@ dev-dynamic: generate
|
|||||||
test: generate
|
test: generate
|
||||||
CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=10m -parallel=4
|
CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=10m -parallel=4
|
||||||
|
|
||||||
|
testcompile: generate
|
||||||
|
@for pkg in $(TEST) ; do \
|
||||||
|
go test -v -c -tags='$(BUILD_TAGS)' $$pkg -parallel=4 ; \
|
||||||
|
done
|
||||||
|
|
||||||
# testacc runs acceptance tests
|
# testacc runs acceptance tests
|
||||||
testacc: generate
|
testacc: generate
|
||||||
@if [ "$(TEST)" = "./..." ]; then \
|
@if [ "$(TEST)" = "./..." ]; then \
|
||||||
|
|||||||
@@ -56,9 +56,9 @@ All documentation is available on the [Vault website](https://www.vaultproject.i
|
|||||||
Developing Vault
|
Developing Vault
|
||||||
--------------------
|
--------------------
|
||||||
|
|
||||||
If you wish to work on Vault itself or any of its built-in systems,
|
If you wish to work on Vault itself or any of its built-in systems, you'll
|
||||||
you'll first need [Go](https://www.golang.org) installed on your
|
first need [Go](https://www.golang.org) installed on your machine (version 1.8+
|
||||||
machine (version 1.8+ is *required*).
|
is *required*).
|
||||||
|
|
||||||
For local dev first make sure Go is properly installed, including setting up a
|
For local dev first make sure Go is properly installed, including setting up a
|
||||||
[GOPATH](https://golang.org/doc/code.html#GOPATH). Next, clone this repository
|
[GOPATH](https://golang.org/doc/code.html#GOPATH). Next, clone this repository
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package api
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/fatih/structs"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -71,13 +72,18 @@ func (c *Sys) ListAudit() (map[string]*Audit, error) {
|
|||||||
return mounts, nil
|
return mounts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DEPRECATED: Use EnableAuditWithOptions instead
|
||||||
func (c *Sys) EnableAudit(
|
func (c *Sys) EnableAudit(
|
||||||
path string, auditType string, desc string, opts map[string]string) error {
|
path string, auditType string, desc string, opts map[string]string) error {
|
||||||
body := map[string]interface{}{
|
return c.EnableAuditWithOptions(path, &EnableAuditOptions{
|
||||||
"type": auditType,
|
Type: auditType,
|
||||||
"description": desc,
|
Description: desc,
|
||||||
"options": opts,
|
Options: opts,
|
||||||
}
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) error {
|
||||||
|
body := structs.Map(options)
|
||||||
|
|
||||||
r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/sys/audit/%s", path))
|
r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/sys/audit/%s", path))
|
||||||
if err := r.SetJSONBody(body); err != nil {
|
if err := r.SetJSONBody(body); err != nil {
|
||||||
@@ -106,9 +112,17 @@ func (c *Sys) DisableAudit(path string) error {
|
|||||||
// individually documented because the map almost directly to the raw HTTP API
|
// individually documented because the map almost directly to the raw HTTP API
|
||||||
// documentation. Please refer to that documentation for more details.
|
// documentation. Please refer to that documentation for more details.
|
||||||
|
|
||||||
|
type EnableAuditOptions struct {
|
||||||
|
Type string `json:"type" structs:"type"`
|
||||||
|
Description string `json:"description" structs:"description"`
|
||||||
|
Options map[string]string `json:"options" structs:"options"`
|
||||||
|
Local bool `json:"local" structs:"local"`
|
||||||
|
}
|
||||||
|
|
||||||
type Audit struct {
|
type Audit struct {
|
||||||
Path string
|
Path string
|
||||||
Type string
|
Type string
|
||||||
Description string
|
Description string
|
||||||
Options map[string]string
|
Options map[string]string
|
||||||
|
Local bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package api
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/fatih/structs"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,11 +43,16 @@ func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
|
|||||||
return mounts, nil
|
return mounts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DEPRECATED: Use EnableAuthWithOptions instead
|
||||||
func (c *Sys) EnableAuth(path, authType, desc string) error {
|
func (c *Sys) EnableAuth(path, authType, desc string) error {
|
||||||
body := map[string]string{
|
return c.EnableAuthWithOptions(path, &EnableAuthOptions{
|
||||||
"type": authType,
|
Type: authType,
|
||||||
"description": desc,
|
Description: desc,
|
||||||
}
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Sys) EnableAuthWithOptions(path string, options *EnableAuthOptions) error {
|
||||||
|
body := structs.Map(options)
|
||||||
|
|
||||||
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/auth/%s", path))
|
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/auth/%s", path))
|
||||||
if err := r.SetJSONBody(body); err != nil {
|
if err := r.SetJSONBody(body); err != nil {
|
||||||
@@ -75,10 +81,17 @@ func (c *Sys) DisableAuth(path string) error {
|
|||||||
// individually documentd because the map almost directly to the raw HTTP API
|
// individually documentd because the map almost directly to the raw HTTP API
|
||||||
// documentation. Please refer to that documentation for more details.
|
// documentation. Please refer to that documentation for more details.
|
||||||
|
|
||||||
|
type EnableAuthOptions struct {
|
||||||
|
Type string `json:"type" structs:"type"`
|
||||||
|
Description string `json:"description" structs:"description"`
|
||||||
|
Local bool `json:"local" structs:"local"`
|
||||||
|
}
|
||||||
|
|
||||||
type AuthMount struct {
|
type AuthMount struct {
|
||||||
Type string `json:"type" structs:"type" mapstructure:"type"`
|
Type string `json:"type" structs:"type" mapstructure:"type"`
|
||||||
Description string `json:"description" structs:"description" mapstructure:"description"`
|
Description string `json:"description" structs:"description" mapstructure:"description"`
|
||||||
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
|
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
|
||||||
|
Local bool `json:"local" structs:"local" mapstructure:"local"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfigOutput struct {
|
type AuthConfigOutput struct {
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ type MountInput struct {
|
|||||||
Type string `json:"type" structs:"type"`
|
Type string `json:"type" structs:"type"`
|
||||||
Description string `json:"description" structs:"description"`
|
Description string `json:"description" structs:"description"`
|
||||||
Config MountConfigInput `json:"config" structs:"config"`
|
Config MountConfigInput `json:"config" structs:"config"`
|
||||||
|
Local bool `json:"local" structs:"local"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MountConfigInput struct {
|
type MountConfigInput struct {
|
||||||
@@ -134,6 +135,7 @@ type MountOutput struct {
|
|||||||
Type string `json:"type" structs:"type"`
|
Type string `json:"type" structs:"type"`
|
||||||
Description string `json:"description" structs:"description"`
|
Description string `json:"description" structs:"description"`
|
||||||
Config MountConfigOutput `json:"config" structs:"config"`
|
Config MountConfigOutput `json:"config" structs:"config"`
|
||||||
|
Local bool `json:"local" structs:"local"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MountConfigOutput struct {
|
type MountConfigOutput struct {
|
||||||
|
|||||||
121
audit/format.go
121
audit/format.go
@@ -27,7 +27,11 @@ func (f *AuditFormatter) FormatRequest(
|
|||||||
config FormatterConfig,
|
config FormatterConfig,
|
||||||
auth *logical.Auth,
|
auth *logical.Auth,
|
||||||
req *logical.Request,
|
req *logical.Request,
|
||||||
err error) error {
|
inErr error) error {
|
||||||
|
|
||||||
|
if req == nil {
|
||||||
|
return fmt.Errorf("request to request-audit a nil request")
|
||||||
|
}
|
||||||
|
|
||||||
if w == nil {
|
if w == nil {
|
||||||
return fmt.Errorf("writer for audit request is nil")
|
return fmt.Errorf("writer for audit request is nil")
|
||||||
@@ -49,22 +53,26 @@ func (f *AuditFormatter) FormatRequest(
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the structures
|
// Copy the auth structure
|
||||||
cp, err := copystructure.Copy(auth)
|
if auth != nil {
|
||||||
if err != nil {
|
cp, err := copystructure.Copy(auth)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
auth = cp.(*logical.Auth)
|
||||||
}
|
}
|
||||||
auth = cp.(*logical.Auth)
|
|
||||||
|
|
||||||
cp, err = copystructure.Copy(req)
|
cp, err := copystructure.Copy(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
req = cp.(*logical.Request)
|
req = cp.(*logical.Request)
|
||||||
|
|
||||||
// Hash any sensitive information
|
// Hash any sensitive information
|
||||||
if err := Hash(config.Salt, auth); err != nil {
|
if auth != nil {
|
||||||
return err
|
if err := Hash(config.Salt, auth); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache and restore accessor in the request
|
// Cache and restore accessor in the request
|
||||||
@@ -85,8 +93,8 @@ func (f *AuditFormatter) FormatRequest(
|
|||||||
auth = new(logical.Auth)
|
auth = new(logical.Auth)
|
||||||
}
|
}
|
||||||
var errString string
|
var errString string
|
||||||
if err != nil {
|
if inErr != nil {
|
||||||
errString = err.Error()
|
errString = inErr.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
reqEntry := &AuditRequestEntry{
|
reqEntry := &AuditRequestEntry{
|
||||||
@@ -107,6 +115,7 @@ func (f *AuditFormatter) FormatRequest(
|
|||||||
Path: req.Path,
|
Path: req.Path,
|
||||||
Data: req.Data,
|
Data: req.Data,
|
||||||
RemoteAddr: getRemoteAddr(req),
|
RemoteAddr: getRemoteAddr(req),
|
||||||
|
ReplicationCluster: req.ReplicationCluster,
|
||||||
Headers: req.Headers,
|
Headers: req.Headers,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -128,7 +137,11 @@ func (f *AuditFormatter) FormatResponse(
|
|||||||
auth *logical.Auth,
|
auth *logical.Auth,
|
||||||
req *logical.Request,
|
req *logical.Request,
|
||||||
resp *logical.Response,
|
resp *logical.Response,
|
||||||
err error) error {
|
inErr error) error {
|
||||||
|
|
||||||
|
if req == nil {
|
||||||
|
return fmt.Errorf("request to response-audit a nil request")
|
||||||
|
}
|
||||||
|
|
||||||
if w == nil {
|
if w == nil {
|
||||||
return fmt.Errorf("writer for audit request is nil")
|
return fmt.Errorf("writer for audit request is nil")
|
||||||
@@ -150,37 +163,43 @@ func (f *AuditFormatter) FormatResponse(
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the structure
|
// Copy the auth structure
|
||||||
cp, err := copystructure.Copy(auth)
|
if auth != nil {
|
||||||
if err != nil {
|
cp, err := copystructure.Copy(auth)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
auth = cp.(*logical.Auth)
|
||||||
}
|
}
|
||||||
auth = cp.(*logical.Auth)
|
|
||||||
|
|
||||||
cp, err = copystructure.Copy(req)
|
cp, err := copystructure.Copy(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
req = cp.(*logical.Request)
|
req = cp.(*logical.Request)
|
||||||
|
|
||||||
cp, err = copystructure.Copy(resp)
|
if resp != nil {
|
||||||
if err != nil {
|
cp, err := copystructure.Copy(resp)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp = cp.(*logical.Response)
|
||||||
}
|
}
|
||||||
resp = cp.(*logical.Response)
|
|
||||||
|
|
||||||
// Hash any sensitive information
|
// Hash any sensitive information
|
||||||
|
|
||||||
// Cache and restore accessor in the auth
|
// Cache and restore accessor in the auth
|
||||||
var accessor, wrappedAccessor string
|
if auth != nil {
|
||||||
if !config.HMACAccessor && auth != nil && auth.Accessor != "" {
|
var accessor string
|
||||||
accessor = auth.Accessor
|
if !config.HMACAccessor && auth.Accessor != "" {
|
||||||
}
|
accessor = auth.Accessor
|
||||||
if err := Hash(config.Salt, auth); err != nil {
|
}
|
||||||
return err
|
if err := Hash(config.Salt, auth); err != nil {
|
||||||
}
|
return err
|
||||||
if accessor != "" {
|
}
|
||||||
auth.Accessor = accessor
|
if accessor != "" {
|
||||||
|
auth.Accessor = accessor
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache and restore accessor in the request
|
// Cache and restore accessor in the request
|
||||||
@@ -196,21 +215,23 @@ func (f *AuditFormatter) FormatResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cache and restore accessor in the response
|
// Cache and restore accessor in the response
|
||||||
accessor = ""
|
if resp != nil {
|
||||||
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
|
var accessor, wrappedAccessor string
|
||||||
accessor = resp.Auth.Accessor
|
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
|
||||||
}
|
accessor = resp.Auth.Accessor
|
||||||
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
|
}
|
||||||
wrappedAccessor = resp.WrapInfo.WrappedAccessor
|
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
|
||||||
}
|
wrappedAccessor = resp.WrapInfo.WrappedAccessor
|
||||||
if err := Hash(config.Salt, resp); err != nil {
|
}
|
||||||
return err
|
if err := Hash(config.Salt, resp); err != nil {
|
||||||
}
|
return err
|
||||||
if accessor != "" {
|
}
|
||||||
resp.Auth.Accessor = accessor
|
if accessor != "" {
|
||||||
}
|
resp.Auth.Accessor = accessor
|
||||||
if wrappedAccessor != "" {
|
}
|
||||||
resp.WrapInfo.WrappedAccessor = wrappedAccessor
|
if wrappedAccessor != "" {
|
||||||
|
resp.WrapInfo.WrappedAccessor = wrappedAccessor
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,8 +243,8 @@ func (f *AuditFormatter) FormatResponse(
|
|||||||
resp = new(logical.Response)
|
resp = new(logical.Response)
|
||||||
}
|
}
|
||||||
var errString string
|
var errString string
|
||||||
if err != nil {
|
if inErr != nil {
|
||||||
errString = err.Error()
|
errString = inErr.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
var respAuth *AuditAuth
|
var respAuth *AuditAuth
|
||||||
@@ -276,6 +297,7 @@ func (f *AuditFormatter) FormatResponse(
|
|||||||
Path: req.Path,
|
Path: req.Path,
|
||||||
Data: req.Data,
|
Data: req.Data,
|
||||||
RemoteAddr: getRemoteAddr(req),
|
RemoteAddr: getRemoteAddr(req),
|
||||||
|
ReplicationCluster: req.ReplicationCluster,
|
||||||
Headers: req.Headers,
|
Headers: req.Headers,
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -312,14 +334,15 @@ type AuditRequestEntry struct {
|
|||||||
type AuditResponseEntry struct {
|
type AuditResponseEntry struct {
|
||||||
Time string `json:"time,omitempty"`
|
Time string `json:"time,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Error string `json:"error"`
|
|
||||||
Auth AuditAuth `json:"auth"`
|
Auth AuditAuth `json:"auth"`
|
||||||
Request AuditRequest `json:"request"`
|
Request AuditRequest `json:"request"`
|
||||||
Response AuditResponse `json:"response"`
|
Response AuditResponse `json:"response"`
|
||||||
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuditRequest struct {
|
type AuditRequest struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
ReplicationCluster string `json:"replication_cluster,omitempty"`
|
||||||
Operation logical.Operation `json:"operation"`
|
Operation logical.Operation `json:"operation"`
|
||||||
ClientToken string `json:"client_token"`
|
ClientToken string `json:"client_token"`
|
||||||
ClientTokenAccessor string `json:"client_token_accessor"`
|
ClientTokenAccessor string `json:"client_token_accessor"`
|
||||||
|
|||||||
55
audit/format_test.go
Normal file
55
audit/format_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/salt"
|
||||||
|
"github.com/hashicorp/vault/logical"
|
||||||
|
)
|
||||||
|
|
||||||
|
type noopFormatWriter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatRequestErrors(t *testing.T) {
|
||||||
|
salter, _ := salt.NewSalt(nil, nil)
|
||||||
|
config := FormatterConfig{
|
||||||
|
Salt: salter,
|
||||||
|
}
|
||||||
|
formatter := AuditFormatter{
|
||||||
|
AuditFormatWriter: &noopFormatWriter{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := formatter.FormatRequest(ioutil.Discard, config, nil, nil, nil); err == nil {
|
||||||
|
t.Fatal("expected error due to nil request")
|
||||||
|
}
|
||||||
|
if err := formatter.FormatRequest(nil, config, nil, &logical.Request{}, nil); err == nil {
|
||||||
|
t.Fatal("expected error due to nil writer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatResponseErrors(t *testing.T) {
|
||||||
|
salter, _ := salt.NewSalt(nil, nil)
|
||||||
|
config := FormatterConfig{
|
||||||
|
Salt: salter,
|
||||||
|
}
|
||||||
|
formatter := AuditFormatter{
|
||||||
|
AuditFormatWriter: &noopFormatWriter{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := formatter.FormatResponse(ioutil.Discard, config, nil, nil, nil, nil); err == nil {
|
||||||
|
t.Fatal("expected error due to nil request")
|
||||||
|
}
|
||||||
|
if err := formatter.FormatResponse(nil, config, nil, &logical.Request{}, nil, nil); err == nil {
|
||||||
|
t.Fatal("expected error due to nil writer")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -157,10 +157,15 @@ func (b *Backend) open() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change the file mode in case the log file already existed
|
// Change the file mode in case the log file already existed. We special
|
||||||
err = os.Chmod(b.path, b.mode)
|
// case /dev/null since we can't chmod it
|
||||||
if err != nil {
|
switch b.path {
|
||||||
return err
|
case "/dev/null":
|
||||||
|
default:
|
||||||
|
err = os.Chmod(b.path, b.mode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -17,20 +17,10 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||||
// Initialize the salt
|
|
||||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
|
||||||
HashFunc: salt.SHA1Hash,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var b backend
|
var b backend
|
||||||
b.Salt = salt
|
|
||||||
b.MapAppId = &framework.PolicyMap{
|
b.MapAppId = &framework.PolicyMap{
|
||||||
PathMap: framework.PathMap{
|
PathMap: framework.PathMap{
|
||||||
Name: "app-id",
|
Name: "app-id",
|
||||||
Salt: salt,
|
|
||||||
Schema: map[string]*framework.FieldSchema{
|
Schema: map[string]*framework.FieldSchema{
|
||||||
"display_name": &framework.FieldSchema{
|
"display_name": &framework.FieldSchema{
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeString,
|
||||||
@@ -48,7 +38,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
|||||||
|
|
||||||
b.MapUserId = &framework.PathMap{
|
b.MapUserId = &framework.PathMap{
|
||||||
Name: "user-id",
|
Name: "user-id",
|
||||||
Salt: salt,
|
|
||||||
Schema: map[string]*framework.FieldSchema{
|
Schema: map[string]*framework.FieldSchema{
|
||||||
"cidr_block": &framework.FieldSchema{
|
"cidr_block": &framework.FieldSchema{
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeString,
|
||||||
@@ -81,17 +70,11 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
|||||||
),
|
),
|
||||||
|
|
||||||
AuthRenew: b.pathLoginRenew,
|
AuthRenew: b.pathLoginRenew,
|
||||||
|
|
||||||
|
Init: b.initialize,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Since the salt is new in 0.2, we need to handle this by migrating
|
b.view = conf.StorageView
|
||||||
// any existing keys to use the salt. We can deprecate this eventually,
|
|
||||||
// but for now we want a smooth upgrade experience by automatically
|
|
||||||
// upgrading to use salting.
|
|
||||||
if salt.DidGenerate() {
|
|
||||||
if err := b.upgradeToSalted(conf.StorageView); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.Backend, nil
|
return b.Backend, nil
|
||||||
}
|
}
|
||||||
@@ -100,10 +83,36 @@ type backend struct {
|
|||||||
*framework.Backend
|
*framework.Backend
|
||||||
|
|
||||||
Salt *salt.Salt
|
Salt *salt.Salt
|
||||||
|
view logical.Storage
|
||||||
MapAppId *framework.PolicyMap
|
MapAppId *framework.PolicyMap
|
||||||
MapUserId *framework.PathMap
|
MapUserId *framework.PathMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) initialize() error {
|
||||||
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
|
HashFunc: salt.SHA1Hash,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.Salt = salt
|
||||||
|
|
||||||
|
b.MapAppId.Salt = salt
|
||||||
|
b.MapUserId.Salt = salt
|
||||||
|
|
||||||
|
// Since the salt is new in 0.2, we need to handle this by migrating
|
||||||
|
// any existing keys to use the salt. We can deprecate this eventually,
|
||||||
|
// but for now we want a smooth upgrade experience by automatically
|
||||||
|
// upgrading to use salting.
|
||||||
|
if salt.DidGenerate() {
|
||||||
|
if err := b.upgradeToSalted(b.view); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// upgradeToSalted is used to upgrade the non-salted keys prior to
|
// upgradeToSalted is used to upgrade the non-salted keys prior to
|
||||||
// Vault 0.2 to be salted. This is done on mount time and is only
|
// Vault 0.2 to be salted. This is done on mount time and is only
|
||||||
// done once. It can be deprecated eventually, but should be around
|
// done once. It can be deprecated eventually, but should be around
|
||||||
|
|||||||
@@ -72,6 +72,10 @@ func TestBackend_upgradeToSalted(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
err = backend.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Check the keys have been upgraded
|
// Check the keys have been upgraded
|
||||||
out, err := inm.Get("struct/map/app-id/foo")
|
out, err := inm.Get("struct/map/app-id/foo")
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ type backend struct {
|
|||||||
// by this backend.
|
// by this backend.
|
||||||
salt *salt.Salt
|
salt *salt.Salt
|
||||||
|
|
||||||
|
// The view to use when creating the salt
|
||||||
|
view logical.Storage
|
||||||
|
|
||||||
// Guard to clean-up the expired SecretID entries
|
// Guard to clean-up the expired SecretID entries
|
||||||
tidySecretIDCASGuard uint32
|
tidySecretIDCASGuard uint32
|
||||||
|
|
||||||
@@ -57,18 +60,9 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||||
// Initialize the salt
|
|
||||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
|
||||||
HashFunc: salt.SHA256Hash,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a backend object
|
// Create a backend object
|
||||||
b := &backend{
|
b := &backend{
|
||||||
// Set the salt object for the backend
|
view: conf.StorageView,
|
||||||
salt: salt,
|
|
||||||
|
|
||||||
// Create the map of locks to modify the registered roles
|
// Create the map of locks to modify the registered roles
|
||||||
roleLocksMap: make(map[string]*sync.RWMutex, 257),
|
roleLocksMap: make(map[string]*sync.RWMutex, 257),
|
||||||
@@ -83,6 +77,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||||||
secretIDAccessorLocksMap: make(map[string]*sync.RWMutex, 257),
|
secretIDAccessorLocksMap: make(map[string]*sync.RWMutex, 257),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
// Create 256 locks each for managing RoleID and SecretIDs. This will avoid
|
// Create 256 locks each for managing RoleID and SecretIDs. This will avoid
|
||||||
// a superfluous number of locks directly proportional to the number of RoleID
|
// a superfluous number of locks directly proportional to the number of RoleID
|
||||||
// and SecretIDs. These locks can be accessed by indexing based on the first two
|
// and SecretIDs. These locks can be accessed by indexing based on the first two
|
||||||
@@ -129,10 +125,22 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||||||
pathTidySecretID(b),
|
pathTidySecretID(b),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Init: b.initialize,
|
||||||
}
|
}
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) initialize() error {
|
||||||
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
|
HashFunc: salt.SHA256Hash,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.salt = salt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// periodicFunc of the backend will be invoked once a minute by the RollbackManager.
|
// periodicFunc of the backend will be invoked once a minute by the RollbackManager.
|
||||||
// RoleRole backend utilizes this function to delete expired SecretID entries.
|
// RoleRole backend utilizes this function to delete expired SecretID entries.
|
||||||
// This could mean that the SecretID may live in the backend upto 1 min after its
|
// This could mean that the SecretID may live in the backend upto 1 min after its
|
||||||
|
|||||||
@@ -21,5 +21,9 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
err = b.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
return b, config.StorageView
|
return b, config.StorageView
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
|
|||||||
return nil, "", metadata, fmt.Errorf("failed to verify the CIDR restrictions set on the role: %v", err)
|
return nil, "", metadata, fmt.Errorf("failed to verify the CIDR restrictions set on the role: %v", err)
|
||||||
}
|
}
|
||||||
if !belongs {
|
if !belongs {
|
||||||
return nil, "", metadata, fmt.Errorf("source address unauthorized through CIDR restrictions on the role")
|
return nil, "", metadata, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the role", req.Connection.RemoteAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,7 +199,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
|
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
|
||||||
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err)
|
return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,7 +261,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
|
if belongs, err := cidrutil.IPBelongsToCIDRBlocksSlice(req.Connection.RemoteAddr, result.CIDRList); !belongs || err != nil {
|
||||||
return false, nil, fmt.Errorf("source address unauthorized through CIDR restrictions on the secret ID: %v", err)
|
return false, nil, fmt.Errorf("source address %q unauthorized through CIDR restrictions on the secret ID: %v", req.Connection.RemoteAddr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ type backend struct {
|
|||||||
*framework.Backend
|
*framework.Backend
|
||||||
Salt *salt.Salt
|
Salt *salt.Salt
|
||||||
|
|
||||||
|
// Used during initialization to set the salt
|
||||||
|
view logical.Storage
|
||||||
|
|
||||||
// Lock to make changes to any of the backend's configuration endpoints.
|
// Lock to make changes to any of the backend's configuration endpoints.
|
||||||
configMutex sync.RWMutex
|
configMutex sync.RWMutex
|
||||||
|
|
||||||
@@ -59,18 +62,11 @@ type backend struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
|
||||||
HashFunc: salt.SHA256Hash,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b := &backend{
|
b := &backend{
|
||||||
// Setting the periodic func to be run once in an hour.
|
// Setting the periodic func to be run once in an hour.
|
||||||
// If there is a real need, this can be made configurable.
|
// If there is a real need, this can be made configurable.
|
||||||
tidyCooldownPeriod: time.Hour,
|
tidyCooldownPeriod: time.Hour,
|
||||||
Salt: salt,
|
view: conf.StorageView,
|
||||||
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
|
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
|
||||||
IAMClientsMap: make(map[string]map[string]*iam.IAM),
|
IAMClientsMap: make(map[string]map[string]*iam.IAM),
|
||||||
}
|
}
|
||||||
@@ -83,6 +79,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||||||
Unauthenticated: []string{
|
Unauthenticated: []string{
|
||||||
"login",
|
"login",
|
||||||
},
|
},
|
||||||
|
LocalStorage: []string{
|
||||||
|
"whitelist/identity/",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Paths: []*framework.Path{
|
Paths: []*framework.Path{
|
||||||
pathLogin(b),
|
pathLogin(b),
|
||||||
@@ -104,11 +103,26 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||||||
pathIdentityWhitelist(b),
|
pathIdentityWhitelist(b),
|
||||||
pathTidyIdentityWhitelist(b),
|
pathTidyIdentityWhitelist(b),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
|
|
||||||
|
Init: b.initialize,
|
||||||
}
|
}
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) initialize() error {
|
||||||
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
|
HashFunc: salt.SHA256Hash,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.Salt = salt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// periodicFunc performs the tasks that the backend wishes to do periodically.
|
// periodicFunc performs the tasks that the backend wishes to do periodically.
|
||||||
// Currently this will be triggered once in a minute by the RollbackManager.
|
// Currently this will be triggered once in a minute by the RollbackManager.
|
||||||
//
|
//
|
||||||
@@ -169,6 +183,16 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/client":
|
||||||
|
b.configMutex.Lock()
|
||||||
|
defer b.configMutex.Unlock()
|
||||||
|
b.flushCachedEC2Clients()
|
||||||
|
b.flushCachedIAMClients()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
|
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
|
||||||
created nonce to authenticates the EC2 instance with Vault.
|
created nonce to authenticates the EC2 instance with Vault.
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
@@ -13,7 +14,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return b, err
|
return b, err
|
||||||
}
|
}
|
||||||
return b, b.populateCRLs(conf.StorageView)
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Backend() *backend {
|
func Backend() *backend {
|
||||||
@@ -36,9 +37,10 @@ func Backend() *backend {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
AuthRenew: b.pathLoginRenew,
|
AuthRenew: b.pathLoginRenew,
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
|
|
||||||
b.crls = map[string]CRLInfo{}
|
|
||||||
b.crlUpdateMutex = &sync.RWMutex{}
|
b.crlUpdateMutex = &sync.RWMutex{}
|
||||||
|
|
||||||
return &b
|
return &b
|
||||||
@@ -52,6 +54,15 @@ type backend struct {
|
|||||||
crlUpdateMutex *sync.RWMutex
|
crlUpdateMutex *sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(key, "crls/"):
|
||||||
|
b.crlUpdateMutex.Lock()
|
||||||
|
defer b.crlUpdateMutex.Unlock()
|
||||||
|
b.crls = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
The "cert" credential provider allows authentication using
|
The "cert" credential provider allows authentication using
|
||||||
TLS client certificates. A client connects to Vault and uses
|
TLS client certificates. A client connects to Vault and uses
|
||||||
|
|||||||
@@ -45,6 +45,12 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
|||||||
b.crlUpdateMutex.Lock()
|
b.crlUpdateMutex.Lock()
|
||||||
defer b.crlUpdateMutex.Unlock()
|
defer b.crlUpdateMutex.Unlock()
|
||||||
|
|
||||||
|
if b.crls != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.crls = map[string]CRLInfo{}
|
||||||
|
|
||||||
keys, err := storage.List("crls/")
|
keys, err := storage.List("crls/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error listing CRLs: %v", err)
|
return fmt.Errorf("error listing CRLs: %v", err)
|
||||||
@@ -56,6 +62,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
|||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
entry, err := storage.Get("crls/" + key)
|
entry, err := storage.Get("crls/" + key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
b.crls = nil
|
||||||
return fmt.Errorf("error loading CRL %s: %v", key, err)
|
return fmt.Errorf("error loading CRL %s: %v", key, err)
|
||||||
}
|
}
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
@@ -64,6 +71,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
|
|||||||
var crlInfo CRLInfo
|
var crlInfo CRLInfo
|
||||||
err = entry.DecodeJSON(&crlInfo)
|
err = entry.DecodeJSON(&crlInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
b.crls = nil
|
||||||
return fmt.Errorf("error decoding CRL %s: %v", key, err)
|
return fmt.Errorf("error decoding CRL %s: %v", key, err)
|
||||||
}
|
}
|
||||||
b.crls[key] = crlInfo
|
b.crls[key] = crlInfo
|
||||||
@@ -121,6 +129,10 @@ func (b *backend) pathCRLDelete(
|
|||||||
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
|
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := b.populateCRLs(req.Storage); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
b.crlUpdateMutex.Lock()
|
b.crlUpdateMutex.Lock()
|
||||||
defer b.crlUpdateMutex.Unlock()
|
defer b.crlUpdateMutex.Unlock()
|
||||||
|
|
||||||
@@ -131,8 +143,7 @@ func (b *backend) pathCRLDelete(
|
|||||||
)), nil
|
)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := req.Storage.Delete("crls/" + name)
|
if err := req.Storage.Delete("crls/" + name); err != nil {
|
||||||
if err != nil {
|
|
||||||
return logical.ErrorResponse(fmt.Sprintf(
|
return logical.ErrorResponse(fmt.Sprintf(
|
||||||
"error deleting crl %s: %v", name, err),
|
"error deleting crl %s: %v", name, err),
|
||||||
), nil
|
), nil
|
||||||
@@ -150,6 +161,10 @@ func (b *backend) pathCRLRead(
|
|||||||
return logical.ErrorResponse(`"name" parameter must be set`), nil
|
return logical.ErrorResponse(`"name" parameter must be set`), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := b.populateCRLs(req.Storage); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
b.crlUpdateMutex.RLock()
|
b.crlUpdateMutex.RLock()
|
||||||
defer b.crlUpdateMutex.RUnlock()
|
defer b.crlUpdateMutex.RUnlock()
|
||||||
|
|
||||||
@@ -185,6 +200,10 @@ func (b *backend) pathCRLWrite(
|
|||||||
return logical.ErrorResponse("parsed CRL is nil"), nil
|
return logical.ErrorResponse("parsed CRL is nil"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := b.populateCRLs(req.Storage); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
b.crlUpdateMutex.Lock()
|
b.crlUpdateMutex.Lock()
|
||||||
defer b.crlUpdateMutex.Unlock()
|
defer b.crlUpdateMutex.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,12 @@ func Backend() *backend {
|
|||||||
b.Backend = &framework.Backend{
|
b.Backend = &framework.Backend{
|
||||||
Help: strings.TrimSpace(backendHelp),
|
Help: strings.TrimSpace(backendHelp),
|
||||||
|
|
||||||
|
PathsSpecial: &logical.Paths{
|
||||||
|
LocalStorage: []string{
|
||||||
|
framework.WALPrefix,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
Paths: []*framework.Path{
|
Paths: []*framework.Path{
|
||||||
pathConfigRoot(),
|
pathConfigRoot(),
|
||||||
pathConfigLease(&b),
|
pathConfigLease(&b),
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ func Backend() *backend {
|
|||||||
secretCreds(&b),
|
secretCreds(&b),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
|
|
||||||
Clean: func() {
|
Clean: func() {
|
||||||
b.ResetDB(nil)
|
b.ResetDB(nil)
|
||||||
},
|
},
|
||||||
@@ -107,6 +109,13 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
|
|||||||
b.session = newSession
|
b.session = newSession
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/connection":
|
||||||
|
b.ResetDB(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
The Cassandra backend dynamically generates database users.
|
The Cassandra backend dynamically generates database users.
|
||||||
|
|
||||||
|
|||||||
@@ -421,7 +421,7 @@ seed_provider:
|
|||||||
parameters:
|
parameters:
|
||||||
# seeds is actually a comma-delimited list of addresses.
|
# seeds is actually a comma-delimited list of addresses.
|
||||||
# Ex: "<ip1>,<ip2>,<ip3>"
|
# Ex: "<ip1>,<ip2>,<ip3>"
|
||||||
- seeds: "172.17.0.2"
|
- seeds: "172.17.0.3"
|
||||||
|
|
||||||
# For workloads with more data than can fit in memory, Cassandra's
|
# For workloads with more data than can fit in memory, Cassandra's
|
||||||
# bottleneck will be reads that need to fetch data from
|
# bottleneck will be reads that need to fetch data from
|
||||||
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
|
|||||||
#
|
#
|
||||||
# Setting listen_address to 0.0.0.0 is always wrong.
|
# Setting listen_address to 0.0.0.0 is always wrong.
|
||||||
#
|
#
|
||||||
listen_address: 172.17.0.2
|
listen_address: 172.17.0.3
|
||||||
|
|
||||||
# Set listen_address OR listen_interface, not both. Interfaces must correspond
|
# Set listen_address OR listen_interface, not both. Interfaces must correspond
|
||||||
# to a single address, IP aliasing is not supported.
|
# to a single address, IP aliasing is not supported.
|
||||||
@@ -586,7 +586,7 @@ listen_address: 172.17.0.2
|
|||||||
|
|
||||||
# Address to broadcast to other Cassandra nodes
|
# Address to broadcast to other Cassandra nodes
|
||||||
# Leaving this blank will set it to the same value as listen_address
|
# Leaving this blank will set it to the same value as listen_address
|
||||||
broadcast_address: 172.17.0.2
|
broadcast_address: 172.17.0.3
|
||||||
|
|
||||||
# When using multiple physical network interfaces, set this
|
# When using multiple physical network interfaces, set this
|
||||||
# to true to listen on broadcast_address in addition to
|
# to true to listen on broadcast_address in addition to
|
||||||
@@ -668,7 +668,7 @@ rpc_port: 9160
|
|||||||
# be set to 0.0.0.0. If left blank, this will be set to the value of
|
# be set to 0.0.0.0. If left blank, this will be set to the value of
|
||||||
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
|
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
|
||||||
# be set.
|
# be set.
|
||||||
broadcast_rpc_address: 172.17.0.2
|
broadcast_rpc_address: 172.17.0.3
|
||||||
|
|
||||||
# enable or disable keepalive on rpc/native connections
|
# enable or disable keepalive on rpc/native connections
|
||||||
rpc_keepalive: true
|
rpc_keepalive: true
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ func Backend() *framework.Backend {
|
|||||||
},
|
},
|
||||||
|
|
||||||
Clean: b.ResetSession,
|
Clean: b.ResetSession,
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.Backend
|
return b.Backend
|
||||||
@@ -97,6 +99,13 @@ func (b *backend) ResetSession() {
|
|||||||
b.session = nil
|
b.session = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/connection":
|
||||||
|
b.ResetSession()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LeaseConfig returns the lease configuration
|
// LeaseConfig returns the lease configuration
|
||||||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
||||||
entry, err := s.Get("config/lease")
|
entry, err := s.Get("config/lease")
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ func Backend() *backend {
|
|||||||
secretCreds(&b),
|
secretCreds(&b),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
|
|
||||||
Clean: b.ResetDB,
|
Clean: b.ResetDB,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,6 +114,13 @@ func (b *backend) ResetDB() {
|
|||||||
b.db = nil
|
b.db = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/connection":
|
||||||
|
b.ResetDB()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LeaseConfig returns the lease configuration
|
// LeaseConfig returns the lease configuration
|
||||||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
||||||
entry, err := s.Get("config/lease")
|
entry, err := s.Get("config/lease")
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ func Backend() *backend {
|
|||||||
secretCreds(&b),
|
secretCreds(&b),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
|
|
||||||
Clean: b.ResetDB,
|
Clean: b.ResetDB,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,6 +107,13 @@ func (b *backend) ResetDB() {
|
|||||||
b.db = nil
|
b.db = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/connection":
|
||||||
|
b.ResetDB()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Lease returns the lease information
|
// Lease returns the lease information
|
||||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||||
entry, err := s.Get("config/lease")
|
entry, err := s.Get("config/lease")
|
||||||
|
|||||||
@@ -29,6 +29,12 @@ func Backend() *backend {
|
|||||||
"crl/pem",
|
"crl/pem",
|
||||||
"crl",
|
"crl",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
LocalStorage: []string{
|
||||||
|
"revoked/",
|
||||||
|
"crl",
|
||||||
|
"certs/",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
Paths: []*framework.Path{
|
Paths: []*framework.Path{
|
||||||
|
|||||||
@@ -1478,6 +1478,27 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
|
||||||
|
var certBundle certutil.CertBundle
|
||||||
|
return func(resp *logical.Response) error {
|
||||||
|
err := mapstructure.Decode(resp.Data, &certBundle)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error checking generated certificate: %s", err)
|
||||||
|
}
|
||||||
|
cert := parsedCertBundle.Certificate
|
||||||
|
|
||||||
|
expected := strutil.ParseDedupAndSortStrings(role.Organization, ",")
|
||||||
|
if !reflect.DeepEqual(cert.Subject.Organization, expected) {
|
||||||
|
return fmt.Errorf("Error: returned certificate has Organization of %s but %s was specified in the role.", cert.Subject.Organization, expected)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Returns a TestCheckFunc that performs various validity checks on the
|
// Returns a TestCheckFunc that performs various validity checks on the
|
||||||
// returned certificate information, mostly within checkCertsAndPrivateKey
|
// returned certificate information, mostly within checkCertsAndPrivateKey
|
||||||
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
|
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
|
||||||
@@ -1755,6 +1776,14 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
|||||||
roleVals.OU = "foo,bar"
|
roleVals.OU = "foo,bar"
|
||||||
addTests(getOuCheck(roleVals))
|
addTests(getOuCheck(roleVals))
|
||||||
}
|
}
|
||||||
|
// Organization tests
|
||||||
|
{
|
||||||
|
roleVals.Organization = "system:masters"
|
||||||
|
addTests(getOrganizationCheck(roleVals))
|
||||||
|
|
||||||
|
roleVals.Organization = "foo,bar"
|
||||||
|
addTests(getOrganizationCheck(roleVals))
|
||||||
|
}
|
||||||
// IP SAN tests
|
// IP SAN tests
|
||||||
{
|
{
|
||||||
issueVals.IPSANs = "127.0.0.1,::1"
|
issueVals.IPSANs = "127.0.0.1,::1"
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ const (
|
|||||||
type creationBundle struct {
|
type creationBundle struct {
|
||||||
CommonName string
|
CommonName string
|
||||||
OU []string
|
OU []string
|
||||||
|
Organization []string
|
||||||
DNSNames []string
|
DNSNames []string
|
||||||
EmailAddresses []string
|
EmailAddresses []string
|
||||||
IPAddresses []net.IP
|
IPAddresses []net.IP
|
||||||
@@ -581,6 +582,14 @@ func generateCreationBundle(b *backend,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set O (organization) values if specified in the role
|
||||||
|
organization := []string{}
|
||||||
|
{
|
||||||
|
if role.Organization != "" {
|
||||||
|
organization = strutil.ParseDedupAndSortStrings(role.Organization, ",")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Read in alternate names -- DNS and email addresses
|
// Read in alternate names -- DNS and email addresses
|
||||||
dnsNames := []string{}
|
dnsNames := []string{}
|
||||||
emailAddresses := []string{}
|
emailAddresses := []string{}
|
||||||
@@ -728,6 +737,7 @@ func generateCreationBundle(b *backend,
|
|||||||
creationBundle := &creationBundle{
|
creationBundle := &creationBundle{
|
||||||
CommonName: cn,
|
CommonName: cn,
|
||||||
OU: ou,
|
OU: ou,
|
||||||
|
Organization: organization,
|
||||||
DNSNames: dnsNames,
|
DNSNames: dnsNames,
|
||||||
EmailAddresses: emailAddresses,
|
EmailAddresses: emailAddresses,
|
||||||
IPAddresses: ipAddresses,
|
IPAddresses: ipAddresses,
|
||||||
@@ -820,6 +830,7 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle
|
|||||||
subject := pkix.Name{
|
subject := pkix.Name{
|
||||||
CommonName: creationInfo.CommonName,
|
CommonName: creationInfo.CommonName,
|
||||||
OrganizationalUnit: creationInfo.OU,
|
OrganizationalUnit: creationInfo.OU,
|
||||||
|
Organization: creationInfo.Organization,
|
||||||
}
|
}
|
||||||
|
|
||||||
certTemplate := &x509.Certificate{
|
certTemplate := &x509.Certificate{
|
||||||
@@ -983,6 +994,7 @@ func signCertificate(creationInfo *creationBundle,
|
|||||||
subject := pkix.Name{
|
subject := pkix.Name{
|
||||||
CommonName: creationInfo.CommonName,
|
CommonName: creationInfo.CommonName,
|
||||||
OrganizationalUnit: creationInfo.OU,
|
OrganizationalUnit: creationInfo.OU,
|
||||||
|
Organization: creationInfo.Organization,
|
||||||
}
|
}
|
||||||
|
|
||||||
certTemplate := &x509.Certificate{
|
certTemplate := &x509.Certificate{
|
||||||
|
|||||||
@@ -172,6 +172,13 @@ Names. Defaults to true.`,
|
|||||||
Type: framework.TypeString,
|
Type: framework.TypeString,
|
||||||
Default: "",
|
Default: "",
|
||||||
Description: `If set, the OU (OrganizationalUnit) will be set to
|
Description: `If set, the OU (OrganizationalUnit) will be set to
|
||||||
|
this value in certificates issued by this role.`,
|
||||||
|
},
|
||||||
|
|
||||||
|
"organization": &framework.FieldSchema{
|
||||||
|
Type: framework.TypeString,
|
||||||
|
Default: "",
|
||||||
|
Description: `If set, the O (Organization) will be set to
|
||||||
this value in certificates issued by this role.`,
|
this value in certificates issued by this role.`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -336,6 +343,7 @@ func (b *backend) pathRoleCreate(
|
|||||||
UseCSRCommonName: data.Get("use_csr_common_name").(bool),
|
UseCSRCommonName: data.Get("use_csr_common_name").(bool),
|
||||||
KeyUsage: data.Get("key_usage").(string),
|
KeyUsage: data.Get("key_usage").(string),
|
||||||
OU: data.Get("ou").(string),
|
OU: data.Get("ou").(string),
|
||||||
|
Organization: data.Get("organization").(string),
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.KeyType == "rsa" && entry.KeyBits < 2048 {
|
if entry.KeyType == "rsa" && entry.KeyBits < 2048 {
|
||||||
@@ -451,6 +459,7 @@ type roleEntry struct {
|
|||||||
MaxPathLength *int `json:",omitempty" structs:",omitempty"`
|
MaxPathLength *int `json:",omitempty" structs:",omitempty"`
|
||||||
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
|
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
|
||||||
OU string `json:"ou" structs:"ou" mapstructure:"ou"`
|
OU string `json:"ou" structs:"ou" mapstructure:"ou"`
|
||||||
|
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathListRolesHelpSyn = `List the existing roles in this backend`
|
const pathListRolesHelpSyn = `List the existing roles in this backend`
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ func Backend(conf *logical.BackendConfig) *backend {
|
|||||||
},
|
},
|
||||||
|
|
||||||
Clean: b.ResetDB,
|
Clean: b.ResetDB,
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
|
|
||||||
b.logger = conf.Logger
|
b.logger = conf.Logger
|
||||||
@@ -126,6 +128,13 @@ func (b *backend) ResetDB() {
|
|||||||
b.db = nil
|
b.db = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/connection":
|
||||||
|
b.ResetDB()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Lease returns the lease information
|
// Lease returns the lease information
|
||||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||||
entry, err := s.Get("config/lease")
|
entry, err := s.Get("config/lease")
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ func Backend() *backend {
|
|||||||
},
|
},
|
||||||
|
|
||||||
Clean: b.resetClient,
|
Clean: b.resetClient,
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &b
|
return &b
|
||||||
@@ -99,6 +101,13 @@ func (b *backend) resetClient() {
|
|||||||
b.client = nil
|
b.client = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
switch key {
|
||||||
|
case "config/connection":
|
||||||
|
b.resetClient()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Lease returns the lease information
|
// Lease returns the lease information
|
||||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||||
entry, err := s.Get("config/lease")
|
entry, err := s.Get("config/lease")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
type backend struct {
|
type backend struct {
|
||||||
*framework.Backend
|
*framework.Backend
|
||||||
|
view logical.Storage
|
||||||
salt *salt.Salt
|
salt *salt.Salt
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,15 +23,8 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
|
||||||
HashFunc: salt.SHA256Hash,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var b backend
|
var b backend
|
||||||
b.salt = salt
|
b.view = conf.StorageView
|
||||||
b.Backend = &framework.Backend{
|
b.Backend = &framework.Backend{
|
||||||
Help: strings.TrimSpace(backendHelp),
|
Help: strings.TrimSpace(backendHelp),
|
||||||
|
|
||||||
@@ -38,6 +32,10 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||||||
Unauthenticated: []string{
|
Unauthenticated: []string{
|
||||||
"verify",
|
"verify",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
LocalStorage: []string{
|
||||||
|
"otp/",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
Paths: []*framework.Path{
|
Paths: []*framework.Path{
|
||||||
@@ -54,10 +52,23 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||||||
secretDynamicKey(&b),
|
secretDynamicKey(&b),
|
||||||
secretOTP(&b),
|
secretOTP(&b),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
Init: b.Initialize,
|
||||||
}
|
}
|
||||||
return &b, nil
|
return &b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) Initialize() error {
|
||||||
|
salt, err := salt.NewSalt(b.view, &salt.Config{
|
||||||
|
HashFunc: salt.SHA256Hash,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.salt = salt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
The SSH backend generates credentials allowing clients to establish SSH
|
The SSH backend generates credentials allowing clients to establish SSH
|
||||||
connections to remote hosts.
|
connections to remote hosts.
|
||||||
|
|||||||
@@ -73,6 +73,10 @@ func TestBackend_allowed_users(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
err = b.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
roleData := map[string]interface{}{
|
roleData := map[string]interface{}{
|
||||||
"key_type": "otp",
|
"key_type": "otp",
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package transit
|
package transit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/keysutil"
|
"github.com/hashicorp/vault/helper/keysutil"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
@@ -39,6 +41,8 @@ func Backend(conf *logical.BackendConfig) *backend {
|
|||||||
},
|
},
|
||||||
|
|
||||||
Secrets: []*framework.Secret{},
|
Secrets: []*framework.Secret{},
|
||||||
|
|
||||||
|
Invalidate: b.invalidate,
|
||||||
}
|
}
|
||||||
|
|
||||||
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
|
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
|
||||||
@@ -50,3 +54,14 @@ type backend struct {
|
|||||||
*framework.Backend
|
*framework.Backend
|
||||||
lm *keysutil.LockManager
|
lm *keysutil.LockManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *backend) invalidate(key string) {
|
||||||
|
if b.Logger().IsTrace() {
|
||||||
|
b.Logger().Trace("transit: invalidating key", "key", key)
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(key, "policy/"):
|
||||||
|
name := strings.TrimPrefix(key, "policy/")
|
||||||
|
b.lm.InvalidatePolicy(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package command
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/http"
|
"github.com/hashicorp/vault/http"
|
||||||
"github.com/hashicorp/vault/meta"
|
"github.com/hashicorp/vault/meta"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
@@ -44,3 +45,42 @@ func TestAuditDisable(t *testing.T) {
|
|||||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuditDisableWithOptions(t *testing.T) {
|
||||||
|
core, _, token := vault.TestCoreUnsealed(t)
|
||||||
|
ln, addr := http.TestServer(t, core)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
ui := new(cli.MockUi)
|
||||||
|
c := &AuditDisableCommand{
|
||||||
|
Meta: meta.Meta{
|
||||||
|
ClientToken: token,
|
||||||
|
Ui: ui,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"-address", addr,
|
||||||
|
"noop",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run once to get the client
|
||||||
|
c.Run(args)
|
||||||
|
|
||||||
|
// Get the client
|
||||||
|
client, err := c.Client()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %#v", err)
|
||||||
|
}
|
||||||
|
if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{
|
||||||
|
Type: "noop",
|
||||||
|
Description: "noop",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("err: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run again
|
||||||
|
if code := c.Run(args); code != 0 {
|
||||||
|
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/helper/kv-builder"
|
"github.com/hashicorp/vault/helper/kv-builder"
|
||||||
"github.com/hashicorp/vault/meta"
|
"github.com/hashicorp/vault/meta"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
@@ -21,9 +22,11 @@ type AuditEnableCommand struct {
|
|||||||
|
|
||||||
func (c *AuditEnableCommand) Run(args []string) int {
|
func (c *AuditEnableCommand) Run(args []string) int {
|
||||||
var desc, path string
|
var desc, path string
|
||||||
|
var local bool
|
||||||
flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault)
|
flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault)
|
||||||
flags.StringVar(&desc, "description", "", "")
|
flags.StringVar(&desc, "description", "", "")
|
||||||
flags.StringVar(&path, "path", "", "")
|
flags.StringVar(&path, "path", "", "")
|
||||||
|
flags.BoolVar(&local, "local", false, "")
|
||||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||||
if err := flags.Parse(args); err != nil {
|
if err := flags.Parse(args); err != nil {
|
||||||
return 1
|
return 1
|
||||||
@@ -68,7 +71,12 @@ func (c *AuditEnableCommand) Run(args []string) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Sys().EnableAudit(path, auditType, desc, opts)
|
err = client.Sys().EnableAuditWithOptions(path, &api.EnableAuditOptions{
|
||||||
|
Type: auditType,
|
||||||
|
Description: desc,
|
||||||
|
Options: opts,
|
||||||
|
Local: local,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Ui.Error(fmt.Sprintf(
|
c.Ui.Error(fmt.Sprintf(
|
||||||
"Error enabling audit backend: %s", err))
|
"Error enabling audit backend: %s", err))
|
||||||
@@ -113,6 +121,9 @@ Audit Enable Options:
|
|||||||
is purely for referencing this audit backend. By
|
is purely for referencing this audit backend. By
|
||||||
default this will be the backend type.
|
default this will be the backend type.
|
||||||
|
|
||||||
|
-local Mark the mount as a local mount. Local mounts
|
||||||
|
are not replicated nor (if a secondary)
|
||||||
|
removed by replication.
|
||||||
`
|
`
|
||||||
return strings.TrimSpace(helpText)
|
return strings.TrimSpace(helpText)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,16 +48,19 @@ func (c *AuditListCommand) Run(args []string) int {
|
|||||||
}
|
}
|
||||||
sort.Strings(paths)
|
sort.Strings(paths)
|
||||||
|
|
||||||
columns := []string{"Path | Type | Description | Options"}
|
columns := []string{"Path | Type | Description | Replication Behavior | Options"}
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
audit := audits[path]
|
audit := audits[path]
|
||||||
opts := make([]string, 0, len(audit.Options))
|
opts := make([]string, 0, len(audit.Options))
|
||||||
for k, v := range audit.Options {
|
for k, v := range audit.Options {
|
||||||
opts = append(opts, k+"="+v)
|
opts = append(opts, k+"="+v)
|
||||||
}
|
}
|
||||||
|
replicatedBehavior := "replicated"
|
||||||
|
if audit.Local {
|
||||||
|
replicatedBehavior = "local"
|
||||||
|
}
|
||||||
columns = append(columns, fmt.Sprintf(
|
columns = append(columns, fmt.Sprintf(
|
||||||
"%s | %s | %s | %s", audit.Path, audit.Type, audit.Description, strings.Join(opts, " ")))
|
"%s | %s | %s | %s | %s", audit.Path, audit.Type, audit.Description, replicatedBehavior, strings.Join(opts, " ")))
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Ui.Output(columnize.SimpleFormat(columns))
|
c.Ui.Output(columnize.SimpleFormat(columns))
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package command
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/http"
|
"github.com/hashicorp/vault/http"
|
||||||
"github.com/hashicorp/vault/meta"
|
"github.com/hashicorp/vault/meta"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
@@ -34,7 +35,11 @@ func TestAuditList(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %#v", err)
|
t.Fatalf("err: %#v", err)
|
||||||
}
|
}
|
||||||
if err := client.Sys().EnableAudit("foo", "noop", "", nil); err != nil {
|
if err := client.Sys().EnableAuditWithOptions("foo", &api.EnableAuditOptions{
|
||||||
|
Type: "noop",
|
||||||
|
Description: "noop",
|
||||||
|
Options: nil,
|
||||||
|
}); err != nil {
|
||||||
t.Fatalf("err: %#v", err)
|
t.Fatalf("err: %#v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ func (c *AuthCommand) listMethods() int {
|
|||||||
}
|
}
|
||||||
sort.Strings(paths)
|
sort.Strings(paths)
|
||||||
|
|
||||||
columns := []string{"Path | Type | Default TTL | Max TTL | Description"}
|
columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
auth := auth[path]
|
auth := auth[path]
|
||||||
defTTL := "system"
|
defTTL := "system"
|
||||||
@@ -292,8 +292,12 @@ func (c *AuthCommand) listMethods() int {
|
|||||||
if auth.Config.MaxLeaseTTL != 0 {
|
if auth.Config.MaxLeaseTTL != 0 {
|
||||||
maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL)
|
maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL)
|
||||||
}
|
}
|
||||||
|
replicatedBehavior := "replicated"
|
||||||
|
if auth.Local {
|
||||||
|
replicatedBehavior = "local"
|
||||||
|
}
|
||||||
columns = append(columns, fmt.Sprintf(
|
columns = append(columns, fmt.Sprintf(
|
||||||
"%s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, auth.Description))
|
"%s | %s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, replicatedBehavior, auth.Description))
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Ui.Output(columnize.SimpleFormat(columns))
|
c.Ui.Output(columnize.SimpleFormat(columns))
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package command
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/http"
|
"github.com/hashicorp/vault/http"
|
||||||
"github.com/hashicorp/vault/meta"
|
"github.com/hashicorp/vault/meta"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
@@ -52,3 +53,50 @@ func TestAuthDisable(t *testing.T) {
|
|||||||
t.Fatal("should not have noop mount")
|
t.Fatal("should not have noop mount")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthDisableWithOptions(t *testing.T) {
|
||||||
|
core, _, token := vault.TestCoreUnsealed(t)
|
||||||
|
ln, addr := http.TestServer(t, core)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
ui := new(cli.MockUi)
|
||||||
|
c := &AuthDisableCommand{
|
||||||
|
Meta: meta.Meta{
|
||||||
|
ClientToken: token,
|
||||||
|
Ui: ui,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"-address", addr,
|
||||||
|
"noop",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the command once to setup the client, it will fail
|
||||||
|
c.Run(args)
|
||||||
|
|
||||||
|
client, err := c.Client()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Sys().EnableAuthWithOptions("noop", &api.EnableAuthOptions{
|
||||||
|
Type: "noop",
|
||||||
|
Description: "",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("err: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if code := c.Run(args); code != 0 {
|
||||||
|
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
mounts, err := client.Sys().ListAuth()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := mounts["noop"]; ok {
|
||||||
|
t.Fatal("should not have noop mount")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/meta"
|
"github.com/hashicorp/vault/meta"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,9 +15,11 @@ type AuthEnableCommand struct {
|
|||||||
|
|
||||||
func (c *AuthEnableCommand) Run(args []string) int {
|
func (c *AuthEnableCommand) Run(args []string) int {
|
||||||
var description, path string
|
var description, path string
|
||||||
|
var local bool
|
||||||
flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault)
|
flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault)
|
||||||
flags.StringVar(&description, "description", "", "")
|
flags.StringVar(&description, "description", "", "")
|
||||||
flags.StringVar(&path, "path", "", "")
|
flags.StringVar(&path, "path", "", "")
|
||||||
|
flags.BoolVar(&local, "local", false, "")
|
||||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||||
if err := flags.Parse(args); err != nil {
|
if err := flags.Parse(args); err != nil {
|
||||||
return 1
|
return 1
|
||||||
@@ -44,7 +47,11 @@ func (c *AuthEnableCommand) Run(args []string) int {
|
|||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Sys().EnableAuth(path, authType, description); err != nil {
|
if err := client.Sys().EnableAuthWithOptions(path, &api.EnableAuthOptions{
|
||||||
|
Type: authType,
|
||||||
|
Description: description,
|
||||||
|
Local: local,
|
||||||
|
}); err != nil {
|
||||||
c.Ui.Error(fmt.Sprintf(
|
c.Ui.Error(fmt.Sprintf(
|
||||||
"Error: %s", err))
|
"Error: %s", err))
|
||||||
return 2
|
return 2
|
||||||
@@ -82,6 +89,9 @@ Auth Enable Options:
|
|||||||
to the type of the mount. This will make the auth
|
to the type of the mount. This will make the auth
|
||||||
provider available at "/auth/<path>"
|
provider available at "/auth/<path>"
|
||||||
|
|
||||||
|
-local Mark the mount as a local mount. Local mounts
|
||||||
|
are not replicated nor (if a secondary)
|
||||||
|
removed by replication.
|
||||||
`
|
`
|
||||||
return strings.TrimSpace(helpText)
|
return strings.TrimSpace(helpText)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,11 +15,13 @@ type MountCommand struct {
|
|||||||
|
|
||||||
func (c *MountCommand) Run(args []string) int {
|
func (c *MountCommand) Run(args []string) int {
|
||||||
var description, path, defaultLeaseTTL, maxLeaseTTL string
|
var description, path, defaultLeaseTTL, maxLeaseTTL string
|
||||||
|
var local bool
|
||||||
flags := c.Meta.FlagSet("mount", meta.FlagSetDefault)
|
flags := c.Meta.FlagSet("mount", meta.FlagSetDefault)
|
||||||
flags.StringVar(&description, "description", "", "")
|
flags.StringVar(&description, "description", "", "")
|
||||||
flags.StringVar(&path, "path", "", "")
|
flags.StringVar(&path, "path", "", "")
|
||||||
flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "")
|
flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "")
|
||||||
flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "")
|
flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "")
|
||||||
|
flags.BoolVar(&local, "local", false, "")
|
||||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||||
if err := flags.Parse(args); err != nil {
|
if err := flags.Parse(args); err != nil {
|
||||||
return 1
|
return 1
|
||||||
@@ -54,6 +56,7 @@ func (c *MountCommand) Run(args []string) int {
|
|||||||
DefaultLeaseTTL: defaultLeaseTTL,
|
DefaultLeaseTTL: defaultLeaseTTL,
|
||||||
MaxLeaseTTL: maxLeaseTTL,
|
MaxLeaseTTL: maxLeaseTTL,
|
||||||
},
|
},
|
||||||
|
Local: local,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Sys().Mount(path, mountInfo); err != nil {
|
if err := client.Sys().Mount(path, mountInfo); err != nil {
|
||||||
@@ -102,6 +105,10 @@ Mount Options:
|
|||||||
the previously set value. Set to '0' to
|
the previously set value. Set to '0' to
|
||||||
explicitly set it to use the global default.
|
explicitly set it to use the global default.
|
||||||
|
|
||||||
|
-local Mark the mount as a local mount. Local mounts
|
||||||
|
are not replicated nor (if a secondary)
|
||||||
|
removed by replication.
|
||||||
|
|
||||||
`
|
`
|
||||||
return strings.TrimSpace(helpText)
|
return strings.TrimSpace(helpText)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (c *MountsCommand) Run(args []string) int {
|
|||||||
}
|
}
|
||||||
sort.Strings(paths)
|
sort.Strings(paths)
|
||||||
|
|
||||||
columns := []string{"Path | Type | Default TTL | Max TTL | Description"}
|
columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"}
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
mount := mounts[path]
|
mount := mounts[path]
|
||||||
defTTL := "system"
|
defTTL := "system"
|
||||||
@@ -63,8 +63,12 @@ func (c *MountsCommand) Run(args []string) int {
|
|||||||
case mount.Config.MaxLeaseTTL != 0:
|
case mount.Config.MaxLeaseTTL != 0:
|
||||||
maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL)
|
maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL)
|
||||||
}
|
}
|
||||||
|
replicatedBehavior := "replicated"
|
||||||
|
if mount.Local {
|
||||||
|
replicatedBehavior = "local"
|
||||||
|
}
|
||||||
columns = append(columns, fmt.Sprintf(
|
columns = append(columns, fmt.Sprintf(
|
||||||
"%s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, mount.Description))
|
"%s | %s | %s | %s | %s | %s", path, mount.Type, defTTL, maxTTL, replicatedBehavior, mount.Description))
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Ui.Output(columnize.SimpleFormat(columns))
|
c.Ui.Output(columnize.SimpleFormat(columns))
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ type ServerCommand struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ServerCommand) Run(args []string) int {
|
func (c *ServerCommand) Run(args []string) int {
|
||||||
var dev, verifyOnly, devHA bool
|
var dev, verifyOnly, devHA, devTransactional bool
|
||||||
var configPath []string
|
var configPath []string
|
||||||
var logLevel, devRootTokenID, devListenAddress string
|
var logLevel, devRootTokenID, devListenAddress string
|
||||||
flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
|
flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
|
||||||
@@ -70,7 +70,8 @@ func (c *ServerCommand) Run(args []string) int {
|
|||||||
flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
|
flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
|
||||||
flags.StringVar(&logLevel, "log-level", "info", "")
|
flags.StringVar(&logLevel, "log-level", "info", "")
|
||||||
flags.BoolVar(&verifyOnly, "verify-only", false, "")
|
flags.BoolVar(&verifyOnly, "verify-only", false, "")
|
||||||
flags.BoolVar(&devHA, "dev-ha", false, "")
|
flags.BoolVar(&devHA, "ha", false, "")
|
||||||
|
flags.BoolVar(&devTransactional, "transactional", false, "")
|
||||||
flags.Usage = func() { c.Ui.Output(c.Help()) }
|
flags.Usage = func() { c.Ui.Output(c.Help()) }
|
||||||
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
|
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
|
||||||
if err := flags.Parse(args); err != nil {
|
if err := flags.Parse(args); err != nil {
|
||||||
@@ -122,7 +123,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||||||
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
|
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
|
||||||
}
|
}
|
||||||
|
|
||||||
if devHA {
|
if devHA || devTransactional {
|
||||||
dev = true
|
dev = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,7 +144,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||||||
// Load the configuration
|
// Load the configuration
|
||||||
var config *server.Config
|
var config *server.Config
|
||||||
if dev {
|
if dev {
|
||||||
config = server.DevConfig(devHA)
|
config = server.DevConfig(devHA, devTransactional)
|
||||||
if devListenAddress != "" {
|
if devListenAddress != "" {
|
||||||
config.Listeners[0].Config["address"] = devListenAddress
|
config.Listeners[0].Config["address"] = devListenAddress
|
||||||
}
|
}
|
||||||
@@ -235,6 +236,9 @@ func (c *ServerCommand) Run(args []string) int {
|
|||||||
ClusterName: config.ClusterName,
|
ClusterName: config.ClusterName,
|
||||||
CacheSize: config.CacheSize,
|
CacheSize: config.CacheSize,
|
||||||
}
|
}
|
||||||
|
if dev {
|
||||||
|
coreConfig.DevToken = devRootTokenID
|
||||||
|
}
|
||||||
|
|
||||||
var disableClustering bool
|
var disableClustering bool
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DevConfig is a Config that is used for dev mode of Vault.
|
// DevConfig is a Config that is used for dev mode of Vault.
|
||||||
func DevConfig(ha bool) *Config {
|
func DevConfig(ha, transactional bool) *Config {
|
||||||
ret := &Config{
|
ret := &Config{
|
||||||
DisableCache: false,
|
DisableCache: false,
|
||||||
DisableMlock: true,
|
DisableMlock: true,
|
||||||
@@ -63,7 +63,12 @@ func DevConfig(ha bool) *Config {
|
|||||||
DefaultLeaseTTL: 32 * 24 * time.Hour,
|
DefaultLeaseTTL: 32 * 24 * time.Hour,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ha {
|
switch {
|
||||||
|
case ha && transactional:
|
||||||
|
ret.Backend.Type = "inmem_transactional_ha"
|
||||||
|
case !ha && transactional:
|
||||||
|
ret.Backend.Type = "inmem_transactional"
|
||||||
|
case ha && !transactional:
|
||||||
ret.Backend.Type = "inmem_ha"
|
ret.Backend.Type = "inmem_ha"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func TestServer_CommonHA(t *testing.T) {
|
|||||||
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
|
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
|
||||||
|
|
||||||
if code := c.Run(args); code != 0 {
|
if code := c.Run(args); code != 0 {
|
||||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(ui.OutputWriter.String(), "(HA available)") {
|
if !strings.Contains(ui.OutputWriter.String(), "(HA available)") {
|
||||||
@@ -61,7 +61,7 @@ func TestServer_GoodSeparateHA(t *testing.T) {
|
|||||||
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
|
args := []string{"-config", tmpfile.Name(), "-verify-only", "true"}
|
||||||
|
|
||||||
if code := c.Run(args); code != 0 {
|
if code := c.Run(args); code != 0 {
|
||||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(ui.OutputWriter.String(), "HA Backend:") {
|
if !strings.Contains(ui.OutputWriter.String(), "HA Backend:") {
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (c *StatusCommand) Run(args []string) int {
|
|||||||
"Key Shares: %d\n"+
|
"Key Shares: %d\n"+
|
||||||
"Key Threshold: %d\n"+
|
"Key Threshold: %d\n"+
|
||||||
"Unseal Progress: %d\n"+
|
"Unseal Progress: %d\n"+
|
||||||
"Unseal Nonce: %v"+
|
"Unseal Nonce: %v\n"+
|
||||||
"Version: %s",
|
"Version: %s",
|
||||||
sealStatus.Sealed,
|
sealStatus.Sealed,
|
||||||
sealStatus.N,
|
sealStatus.N,
|
||||||
|
|||||||
@@ -14,6 +14,16 @@ func TestCIDRUtil_IPBelongsToCIDR(t *testing.T) {
|
|||||||
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
|
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ip = "10.197.192.6"
|
||||||
|
cidr = "10.197.192.0/18"
|
||||||
|
belongs, err = IPBelongsToCIDR(ip, cidr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !belongs {
|
||||||
|
t.Fatalf("expected IP %q to belong to CIDR %q", ip, cidr)
|
||||||
|
}
|
||||||
|
|
||||||
ip = "192.168.25.30"
|
ip = "192.168.25.30"
|
||||||
cidr = "192.168.26.30/24"
|
cidr = "192.168.26.30/24"
|
||||||
belongs, err = IPBelongsToCIDR(ip, cidr)
|
belongs, err = IPBelongsToCIDR(ip, cidr)
|
||||||
@@ -44,6 +54,17 @@ func TestCIDRUtil_IPBelongsToCIDRBlocksString(t *testing.T) {
|
|||||||
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
|
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ip = "10.197.192.6"
|
||||||
|
cidrList = "1.2.3.0/8,10.197.192.0/18,10.197.193.0/24"
|
||||||
|
|
||||||
|
belongs, err = IPBelongsToCIDRBlocksString(ip, cidrList, ",")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !belongs {
|
||||||
|
t.Fatalf("expected IP %q to belong to one of the CIDRs in %q", ip, cidrList)
|
||||||
|
}
|
||||||
|
|
||||||
ip = "192.168.27.29"
|
ip = "192.168.27.29"
|
||||||
cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"
|
cidrList = "172.169.100.200/18,192.168.0.0.0/16,10.10.20.20/24"
|
||||||
|
|
||||||
|
|||||||
7
helper/consts/consts.go
Normal file
7
helper/consts/consts.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package consts
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ExpirationRestoreWorkerCount specifies the numer of workers to use while
|
||||||
|
// restoring leases into the expiration manager
|
||||||
|
ExpirationRestoreWorkerCount = 64
|
||||||
|
)
|
||||||
13
helper/consts/error.go
Normal file
13
helper/consts/error.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package consts
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrSealed is returned if an operation is performed on a sealed barrier.
|
||||||
|
// No operation is expected to succeed before unsealing
|
||||||
|
ErrSealed = errors.New("Vault is sealed")
|
||||||
|
|
||||||
|
// ErrStandby is returned if an operation is performed on a standby Vault.
|
||||||
|
// No operation is expected to succeed until active.
|
||||||
|
ErrStandby = errors.New("Vault is in standby mode")
|
||||||
|
)
|
||||||
20
helper/consts/replication.go
Normal file
20
helper/consts/replication.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package consts
|
||||||
|
|
||||||
|
type ReplicationState uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
ReplicationDisabled ReplicationState = iota
|
||||||
|
ReplicationPrimary
|
||||||
|
ReplicationSecondary
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r ReplicationState) String() string {
|
||||||
|
switch r {
|
||||||
|
case ReplicationSecondary:
|
||||||
|
return "secondary"
|
||||||
|
case ReplicationPrimary:
|
||||||
|
return "primary"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "disabled"
|
||||||
|
}
|
||||||
@@ -71,6 +71,15 @@ func (lm *LockManager) CacheActive() bool {
|
|||||||
return lm.cache != nil
|
return lm.cache != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (lm *LockManager) InvalidatePolicy(name string) {
|
||||||
|
// Check if it's in our cache. If so, return right away.
|
||||||
|
if lm.CacheActive() {
|
||||||
|
lm.cacheMutex.Lock()
|
||||||
|
defer lm.cacheMutex.Unlock()
|
||||||
|
delete(lm.cache, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
||||||
lm.locksMutex.RLock()
|
lm.locksMutex.RLock()
|
||||||
lock := lm.locks[name]
|
lock := lm.locks[name]
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/duration"
|
"github.com/hashicorp/vault/helper/duration"
|
||||||
"github.com/hashicorp/vault/helper/jsonutil"
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
@@ -206,11 +207,11 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
|
|||||||
// case of an error.
|
// case of an error.
|
||||||
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
|
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
|
||||||
resp, err := core.HandleRequest(r)
|
resp, err := core.HandleRequest(r)
|
||||||
if errwrap.Contains(err, vault.ErrStandby.Error()) {
|
if errwrap.Contains(err, consts.ErrStandby.Error()) {
|
||||||
respondStandby(core, w, rawReq.URL)
|
respondStandby(core, w, rawReq.URL)
|
||||||
return resp, false
|
return resp, false
|
||||||
}
|
}
|
||||||
if respondErrorCommon(w, resp, err) {
|
if respondErrorCommon(w, r, resp, err) {
|
||||||
return resp, false
|
return resp, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,20 +311,7 @@ func requestWrapInfo(r *http.Request, req *logical.Request) (*logical.Request, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func respondError(w http.ResponseWriter, status int, err error) {
|
func respondError(w http.ResponseWriter, status int, err error) {
|
||||||
// Adjust status code when sealed
|
logical.AdjustErrorStatusCode(&status, err)
|
||||||
if errwrap.Contains(err, vault.ErrSealed.Error()) {
|
|
||||||
status = http.StatusServiceUnavailable
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust status code on
|
|
||||||
if errwrap.Contains(err, "http: request body too large") {
|
|
||||||
status = http.StatusRequestEntityTooLarge
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow HTTPCoded error passthrough to specify a code
|
|
||||||
if t, ok := err.(logical.HTTPCodedError); ok {
|
|
||||||
status = t.Code()
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Add("Content-Type", "application/json")
|
w.Header().Add("Content-Type", "application/json")
|
||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
@@ -337,42 +325,13 @@ func respondError(w http.ResponseWriter, status int, err error) {
|
|||||||
enc.Encode(resp)
|
enc.Encode(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool {
|
func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool {
|
||||||
// If there are no errors return
|
statusCode, newErr := logical.RespondErrorCommon(req, resp, err)
|
||||||
if err == nil && (resp == nil || !resp.IsError()) {
|
if newErr == nil && statusCode == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start out with internal server error since in most of these cases there
|
respondError(w, statusCode, newErr)
|
||||||
// won't be a response so this won't be overridden
|
|
||||||
statusCode := http.StatusInternalServerError
|
|
||||||
// If we actually have a response, start out with bad request
|
|
||||||
if resp != nil {
|
|
||||||
statusCode = http.StatusBadRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, check the error itself; if it has a specific logical error, set the
|
|
||||||
// appropriate code
|
|
||||||
if err != nil {
|
|
||||||
switch {
|
|
||||||
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
|
||||||
statusCode = http.StatusBadRequest
|
|
||||||
case errwrap.Contains(err, logical.ErrPermissionDenied.Error()):
|
|
||||||
statusCode = http.StatusForbidden
|
|
||||||
case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()):
|
|
||||||
statusCode = http.StatusMethodNotAllowed
|
|
||||||
case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()):
|
|
||||||
statusCode = http.StatusNotFound
|
|
||||||
case errwrap.Contains(err, logical.ErrInvalidRequest.Error()):
|
|
||||||
statusCode = http.StatusBadRequest
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp != nil && resp.IsError() {
|
|
||||||
err = fmt.Errorf("%s", resp.Data["error"].(string))
|
|
||||||
}
|
|
||||||
|
|
||||||
respondError(w, statusCode, err)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/go-cleanhttp"
|
"github.com/hashicorp/go-cleanhttp"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
)
|
)
|
||||||
@@ -80,6 +81,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -88,6 +90,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -96,6 +99,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
@@ -105,6 +109,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -113,6 +118,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -121,6 +127,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -223,7 +230,7 @@ func TestHandler_error(t *testing.T) {
|
|||||||
// vault.ErrSealed is a special case
|
// vault.ErrSealed is a special case
|
||||||
w3 := httptest.NewRecorder()
|
w3 := httptest.NewRecorder()
|
||||||
|
|
||||||
respondError(w3, 400, vault.ErrSealed)
|
respondError(w3, 400, consts.ErrSealed)
|
||||||
|
|
||||||
if w3.Code != 503 {
|
if w3.Code != 503 {
|
||||||
t.Fatalf("expected 503, got %d", w3.Code)
|
t.Fatalf("expected 503, got %d", w3.Code)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
resp, err := core.HandleRequest(lreq)
|
resp, err := core.HandleRequest(lreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondErrorCommon(w, resp, err)
|
respondErrorCommon(w, lreq, resp, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
|||||||
return nil, http.StatusMethodNotAllowed, nil
|
return nil, http.StatusMethodNotAllowed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if op == logical.ListOperation {
|
||||||
|
if !strings.HasSuffix(path, "/") {
|
||||||
|
path += "/"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Parse the request if we can
|
// Parse the request if we can
|
||||||
var data map[string]interface{}
|
var data map[string]interface{}
|
||||||
if op == logical.UpdateOperation {
|
if op == logical.UpdateOperation {
|
||||||
@@ -109,40 +115,13 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
|
|||||||
|
|
||||||
// Make the internal request. We attach the connection info
|
// Make the internal request. We attach the connection info
|
||||||
// as well in case this is an authentication request that requires
|
// as well in case this is an authentication request that requires
|
||||||
// it. Vault core handles stripping this if we need to.
|
// it. Vault core handles stripping this if we need to. This also
|
||||||
|
// handles all error cases; if we hit respondLogical, the request is a
|
||||||
|
// success.
|
||||||
resp, ok := request(core, w, r, req)
|
resp, ok := request(core, w, r, req)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch {
|
|
||||||
case req.Operation == logical.ReadOperation:
|
|
||||||
if resp == nil {
|
|
||||||
respondError(w, http.StatusNotFound, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basically: if we have empty "keys" or no keys at all, 404. This
|
|
||||||
// provides consistency with GET.
|
|
||||||
case req.Operation == logical.ListOperation && resp.WrapInfo == nil:
|
|
||||||
if resp == nil || len(resp.Data) == 0 {
|
|
||||||
respondError(w, http.StatusNotFound, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
keysRaw, ok := resp.Data["keys"]
|
|
||||||
if !ok || keysRaw == nil {
|
|
||||||
respondError(w, http.StatusNotFound, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
keys, ok := keysRaw.([]string)
|
|
||||||
if !ok {
|
|
||||||
respondError(w, http.StatusInternalServerError, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(keys) == 0 {
|
|
||||||
respondError(w, http.StatusNotFound, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the proper response
|
// Build the proper response
|
||||||
respondLogical(w, r, req, dataOnly, resp)
|
respondLogical(w, r, req, dataOnly, resp)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -101,7 +103,7 @@ func TestLogical_StandbyRedirect(t *testing.T) {
|
|||||||
|
|
||||||
// Attempt to fix raciness in this test by giving the first core a chance
|
// Attempt to fix raciness in this test by giving the first core a chance
|
||||||
// to grab the lock
|
// to grab the lock
|
||||||
time.Sleep(time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
// Create a second HA Vault
|
// Create a second HA Vault
|
||||||
conf2 := &vault.CoreConfig{
|
conf2 := &vault.CoreConfig{
|
||||||
@@ -252,3 +254,42 @@ func TestLogical_RequestSizeLimit(t *testing.T) {
|
|||||||
})
|
})
|
||||||
testResponseStatus(t, resp, 413)
|
testResponseStatus(t, resp, 413)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogical_ListSuffix(t *testing.T) {
|
||||||
|
core, _, _ := vault.TestCoreUnsealed(t)
|
||||||
|
req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil)
|
||||||
|
lreq, status, err := buildLogicalRequest(core, nil, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if status != 0 {
|
||||||
|
t.Fatalf("got status %d", status)
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(lreq.Path, "/") {
|
||||||
|
t.Fatal("trailing slash found on path")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil)
|
||||||
|
lreq, status, err = buildLogicalRequest(core, nil, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if status != 0 {
|
||||||
|
t.Fatalf("got status %d", status)
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(lreq.Path, "/") {
|
||||||
|
t.Fatal("trailing slash not found on path")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil)
|
||||||
|
lreq, status, err = buildLogicalRequest(core, nil, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if status != 0 {
|
||||||
|
t.Fatalf("got status %d", status)
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(lreq.Path, "/") {
|
||||||
|
t.Fatal("trailing slash not found on path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ func TestSysAudit(t *testing.T) {
|
|||||||
"type": "noop",
|
"type": "noop",
|
||||||
"description": "",
|
"description": "",
|
||||||
"options": map[string]interface{}{},
|
"options": map[string]interface{}{},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"noop/": map[string]interface{}{
|
"noop/": map[string]interface{}{
|
||||||
@@ -42,6 +43,7 @@ func TestSysAudit(t *testing.T) {
|
|||||||
"type": "noop",
|
"type": "noop",
|
||||||
"description": "",
|
"description": "",
|
||||||
"options": map[string]interface{}{},
|
"options": map[string]interface{}{},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ func TestSysAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"token/": map[string]interface{}{
|
"token/": map[string]interface{}{
|
||||||
@@ -41,6 +42,7 @@ func TestSysAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -83,6 +85,7 @@ func TestSysEnableAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"token/": map[string]interface{}{
|
"token/": map[string]interface{}{
|
||||||
"description": "token based credentials",
|
"description": "token based credentials",
|
||||||
@@ -91,6 +94,7 @@ func TestSysEnableAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"foo/": map[string]interface{}{
|
"foo/": map[string]interface{}{
|
||||||
@@ -100,6 +104,7 @@ func TestSysEnableAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"token/": map[string]interface{}{
|
"token/": map[string]interface{}{
|
||||||
"description": "token based credentials",
|
"description": "token based credentials",
|
||||||
@@ -108,6 +113,7 @@ func TestSysEnableAuth(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -153,6 +159,7 @@ func TestSysDisableAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"description": "token based credentials",
|
"description": "token based credentials",
|
||||||
"type": "token",
|
"type": "token",
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"token/": map[string]interface{}{
|
"token/": map[string]interface{}{
|
||||||
@@ -162,6 +169,7 @@ func TestSysDisableAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"description": "token based credentials",
|
"description": "token based credentials",
|
||||||
"type": "token",
|
"type": "token",
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ func TestSysMounts(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -41,6 +42,7 @@ func TestSysMounts(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -49,6 +51,7 @@ func TestSysMounts(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
@@ -58,6 +61,7 @@ func TestSysMounts(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -66,6 +70,7 @@ func TestSysMounts(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -74,6 +79,7 @@ func TestSysMounts(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -114,6 +120,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -122,6 +129,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -130,6 +138,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -138,6 +147,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"foo/": map[string]interface{}{
|
"foo/": map[string]interface{}{
|
||||||
@@ -147,6 +157,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -155,6 +166,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -163,6 +175,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -171,6 +184,7 @@ func TestSysMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -233,6 +247,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -241,6 +256,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -249,6 +265,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -257,6 +274,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"bar/": map[string]interface{}{
|
"bar/": map[string]interface{}{
|
||||||
@@ -266,6 +284,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -274,6 +293,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -282,6 +302,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -290,6 +311,7 @@ func TestSysRemount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -333,6 +355,7 @@ func TestSysUnmount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -341,6 +364,7 @@ func TestSysUnmount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -349,6 +373,7 @@ func TestSysUnmount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
@@ -358,6 +383,7 @@ func TestSysUnmount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -366,6 +392,7 @@ func TestSysUnmount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -374,6 +401,7 @@ func TestSysUnmount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -414,6 +442,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -422,6 +451,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -430,6 +460,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -438,6 +469,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"foo/": map[string]interface{}{
|
"foo/": map[string]interface{}{
|
||||||
@@ -447,6 +479,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -455,6 +488,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -463,6 +497,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -471,6 +506,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 200)
|
testResponseStatus(t, resp, 200)
|
||||||
@@ -532,6 +568,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("259196400"),
|
"default_lease_ttl": json.Number("259196400"),
|
||||||
"max_lease_ttl": json.Number("259200000"),
|
"max_lease_ttl": json.Number("259200000"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -540,6 +577,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -548,6 +586,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -556,6 +595,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"foo/": map[string]interface{}{
|
"foo/": map[string]interface{}{
|
||||||
@@ -565,6 +605,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("259196400"),
|
"default_lease_ttl": json.Number("259196400"),
|
||||||
"max_lease_ttl": json.Number("259200000"),
|
"max_lease_ttl": json.Number("259200000"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"secret/": map[string]interface{}{
|
"secret/": map[string]interface{}{
|
||||||
"description": "generic secret storage",
|
"description": "generic secret storage",
|
||||||
@@ -573,6 +614,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"sys/": map[string]interface{}{
|
"sys/": map[string]interface{}{
|
||||||
"description": "system endpoints used for control, policy and debugging",
|
"description": "system endpoints used for control, policy and debugging",
|
||||||
@@ -581,6 +623,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": false,
|
||||||
},
|
},
|
||||||
"cubbyhole/": map[string]interface{}{
|
"cubbyhole/": map[string]interface{}{
|
||||||
"description": "per-token private secret storage",
|
"description": "per-token private secret storage",
|
||||||
@@ -589,6 +632,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||||||
"default_lease_ttl": json.Number("0"),
|
"default_lease_ttl": json.Number("0"),
|
||||||
"max_lease_ttl": json.Number("0"),
|
"max_lease_ttl": json.Number("0"),
|
||||||
},
|
},
|
||||||
|
"local": true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/pgpkeys"
|
"github.com/hashicorp/vault/helper/pgpkeys"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
)
|
)
|
||||||
@@ -19,6 +20,13 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
repState := core.ReplicationState()
|
||||||
|
if repState == consts.ReplicationSecondary {
|
||||||
|
respondError(w, http.StatusBadRequest,
|
||||||
|
fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case recovery && !core.SealAccess().RecoveryKeySupported():
|
case recovery && !core.SealAccess().RecoveryKeySupported():
|
||||||
respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"))
|
respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"))
|
||||||
@@ -108,7 +116,7 @@ func handleSysRekeyInitPut(core *vault.Core, recovery bool, w http.ResponseWrite
|
|||||||
// Right now we don't support this, but the rest of the code is ready for
|
// Right now we don't support this, but the rest of the code is ready for
|
||||||
// when we do, hence the check below for this to be false if
|
// when we do, hence the check below for this to be false if
|
||||||
// StoredShares is greater than zero
|
// StoredShares is greater than zero
|
||||||
if core.SealAccess().StoredKeysSupported() {
|
if core.SealAccess().StoredKeysSupported() && !recovery {
|
||||||
respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying of barrier not supported when stored key support is available"))
|
respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying of barrier not supported when stored key support is available"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
"github.com/hashicorp/vault/version"
|
"github.com/hashicorp/vault/version"
|
||||||
@@ -126,7 +127,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
|
|||||||
case errwrap.Contains(err, vault.ErrBarrierInvalidKey.Error()):
|
case errwrap.Contains(err, vault.ErrBarrierInvalidKey.Error()):
|
||||||
case errwrap.Contains(err, vault.ErrBarrierNotInit.Error()):
|
case errwrap.Contains(err, vault.ErrBarrierNotInit.Error()):
|
||||||
case errwrap.Contains(err, vault.ErrBarrierSealed.Error()):
|
case errwrap.Contains(err, vault.ErrBarrierSealed.Error()):
|
||||||
case errwrap.Contains(err, vault.ErrStandby.Error()):
|
case errwrap.Contains(err, consts.ErrStandby.Error()):
|
||||||
default:
|
default:
|
||||||
respondError(w, http.StatusInternalServerError, err)
|
respondError(w, http.StatusInternalServerError, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
// is present on the Request structure for credential backends.
|
// is present on the Request structure for credential backends.
|
||||||
type Connection struct {
|
type Connection struct {
|
||||||
// RemoteAddr is the network address that sent the request.
|
// RemoteAddr is the network address that sent the request.
|
||||||
RemoteAddr string
|
RemoteAddr string `json:"remote_addr"`
|
||||||
|
|
||||||
// ConnState is the TLS connection state if applicable.
|
// ConnState is the TLS connection state if applicable.
|
||||||
ConnState *tls.ConnectionState
|
ConnState *tls.ConnectionState
|
||||||
|
|||||||
@@ -21,3 +21,27 @@ func (e *codedError) Error() string {
|
|||||||
func (e *codedError) Code() int {
|
func (e *codedError) Code() int {
|
||||||
return e.code
|
return e.code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Struct to identify user input errors. This is helpful in responding the
|
||||||
|
// appropriate status codes to clients from the HTTP endpoints.
|
||||||
|
type StatusBadRequest struct {
|
||||||
|
Err string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementing error interface
|
||||||
|
func (s *StatusBadRequest) Error() string {
|
||||||
|
return s.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is a new type declared to not cause potential compatibility problems if
|
||||||
|
// the logic around the HTTPCodedError interface changes; in particular for
|
||||||
|
// logical request paths it is basically ignored, and changing that behavior
|
||||||
|
// might cause unforseen issues.
|
||||||
|
type ReplicationCodedError struct {
|
||||||
|
Msg string
|
||||||
|
Code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReplicationCodedError) Error() string {
|
||||||
|
return r.Msg
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package framework
|
package framework
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/hashicorp/vault/helper/duration"
|
||||||
"github.com/hashicorp/vault/helper/errutil"
|
"github.com/hashicorp/vault/helper/errutil"
|
||||||
"github.com/hashicorp/vault/helper/logformat"
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
@@ -534,7 +536,40 @@ type FieldSchema struct {
|
|||||||
// the zero value of the type.
|
// the zero value of the type.
|
||||||
func (s *FieldSchema) DefaultOrZero() interface{} {
|
func (s *FieldSchema) DefaultOrZero() interface{} {
|
||||||
if s.Default != nil {
|
if s.Default != nil {
|
||||||
return s.Default
|
switch s.Type {
|
||||||
|
case TypeDurationSecond:
|
||||||
|
var result int
|
||||||
|
switch inp := s.Default.(type) {
|
||||||
|
case nil:
|
||||||
|
return s.Type.Zero()
|
||||||
|
case int:
|
||||||
|
result = inp
|
||||||
|
case int64:
|
||||||
|
result = int(inp)
|
||||||
|
case float32:
|
||||||
|
result = int(inp)
|
||||||
|
case float64:
|
||||||
|
result = int(inp)
|
||||||
|
case string:
|
||||||
|
dur, err := duration.ParseDurationSecond(inp)
|
||||||
|
if err != nil {
|
||||||
|
return s.Type.Zero()
|
||||||
|
}
|
||||||
|
result = int(dur.Seconds())
|
||||||
|
case json.Number:
|
||||||
|
valInt64, err := inp.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return s.Type.Zero()
|
||||||
|
}
|
||||||
|
result = int(valInt64)
|
||||||
|
default:
|
||||||
|
return s.Type.Zero()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
default:
|
||||||
|
return s.Default
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.Type.Zero()
|
return s.Type.Zero()
|
||||||
|
|||||||
@@ -554,6 +554,16 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) {
|
|||||||
60,
|
60,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"default duration int64": {
|
||||||
|
&FieldSchema{Type: TypeDurationSecond, Default: int64(60)},
|
||||||
|
60,
|
||||||
|
},
|
||||||
|
|
||||||
|
"default duration string": {
|
||||||
|
&FieldSchema{Type: TypeDurationSecond, Default: "60s"},
|
||||||
|
60,
|
||||||
|
},
|
||||||
|
|
||||||
"default duration not set": {
|
"default duration not set": {
|
||||||
&FieldSchema{Type: TypeDurationSecond},
|
&FieldSchema{Type: TypeDurationSecond},
|
||||||
0,
|
0,
|
||||||
|
|||||||
@@ -80,22 +80,3 @@ type Paths struct {
|
|||||||
// indicates that these paths should not be replicated
|
// indicates that these paths should not be replicated
|
||||||
LocalStorage []string
|
LocalStorage []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReplicationState uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
ReplicationDisabled ReplicationState = iota
|
|
||||||
ReplicationPrimary
|
|
||||||
ReplicationSecondary
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r ReplicationState) String() string {
|
|
||||||
switch r {
|
|
||||||
case ReplicationSecondary:
|
|
||||||
return "secondary"
|
|
||||||
case ReplicationPrimary:
|
|
||||||
return "primary"
|
|
||||||
}
|
|
||||||
|
|
||||||
return "disabled"
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ type Request struct {
|
|||||||
// Id is the uuid associated with each request
|
// Id is the uuid associated with each request
|
||||||
ID string `json:"id" structs:"id" mapstructure:"id"`
|
ID string `json:"id" structs:"id" mapstructure:"id"`
|
||||||
|
|
||||||
|
// If set, the name given to the replication secondary where this request
|
||||||
|
// originated
|
||||||
|
ReplicationCluster string `json:"replication_cluster" structs:"replication_cluster", mapstructure:"replication_cluster"`
|
||||||
|
|
||||||
// Operation is the requested operation type
|
// Operation is the requested operation type
|
||||||
Operation Operation `json:"operation" structs:"operation" mapstructure:"operation"`
|
Operation Operation `json:"operation" structs:"operation" mapstructure:"operation"`
|
||||||
|
|
||||||
@@ -38,7 +42,7 @@ type Request struct {
|
|||||||
Data map[string]interface{} `json:"map" structs:"data" mapstructure:"data"`
|
Data map[string]interface{} `json:"map" structs:"data" mapstructure:"data"`
|
||||||
|
|
||||||
// Storage can be used to durably store and retrieve state.
|
// Storage can be used to durably store and retrieve state.
|
||||||
Storage Storage `json:"storage" structs:"storage" mapstructure:"storage"`
|
Storage Storage `json:"-"`
|
||||||
|
|
||||||
// Secret will be non-nil only for Revoke and Renew operations
|
// Secret will be non-nil only for Revoke and Renew operations
|
||||||
// to represent the secret that was returned prior.
|
// to represent the secret that was returned prior.
|
||||||
|
|||||||
111
logical/response_util.go
Normal file
111
logical/response_util.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package logical
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
|
multierror "github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RespondErrorCommon pulls most of the functionality from http's
|
||||||
|
// respondErrorCommon and some of http's handleLogical and makes it available
|
||||||
|
// to both the http package and elsewhere.
|
||||||
|
func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
|
||||||
|
if err == nil && (resp == nil || !resp.IsError()) {
|
||||||
|
switch {
|
||||||
|
case req.Operation == ReadOperation:
|
||||||
|
if resp == nil {
|
||||||
|
return http.StatusNotFound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basically: if we have empty "keys" or no keys at all, 404. This
|
||||||
|
// provides consistency with GET.
|
||||||
|
case req.Operation == ListOperation && resp.WrapInfo == nil:
|
||||||
|
if resp == nil || len(resp.Data) == 0 {
|
||||||
|
return http.StatusNotFound, nil
|
||||||
|
}
|
||||||
|
keysRaw, ok := resp.Data["keys"]
|
||||||
|
if !ok || keysRaw == nil {
|
||||||
|
return http.StatusNotFound, nil
|
||||||
|
}
|
||||||
|
keys, ok := keysRaw.([]string)
|
||||||
|
if !ok {
|
||||||
|
return http.StatusInternalServerError, nil
|
||||||
|
}
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return http.StatusNotFound, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if errwrap.ContainsType(err, new(ReplicationCodedError)) {
|
||||||
|
var allErrors error
|
||||||
|
codedErr := errwrap.GetType(err, new(ReplicationCodedError)).(*ReplicationCodedError)
|
||||||
|
errwrap.Walk(err, func(inErr error) {
|
||||||
|
newErr, ok := inErr.(*ReplicationCodedError)
|
||||||
|
if !ok {
|
||||||
|
allErrors = multierror.Append(allErrors, newErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if allErrors != nil {
|
||||||
|
return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
|
||||||
|
}
|
||||||
|
return codedErr.Code, errors.New(codedErr.Msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start out with internal server error since in most of these cases there
|
||||||
|
// won't be a response so this won't be overridden
|
||||||
|
statusCode := http.StatusInternalServerError
|
||||||
|
// If we actually have a response, start out with bad request
|
||||||
|
if resp != nil {
|
||||||
|
statusCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, check the error itself; if it has a specific logical error, set the
|
||||||
|
// appropriate code
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case errwrap.ContainsType(err, new(StatusBadRequest)):
|
||||||
|
statusCode = http.StatusBadRequest
|
||||||
|
case errwrap.Contains(err, ErrPermissionDenied.Error()):
|
||||||
|
statusCode = http.StatusForbidden
|
||||||
|
case errwrap.Contains(err, ErrUnsupportedOperation.Error()):
|
||||||
|
statusCode = http.StatusMethodNotAllowed
|
||||||
|
case errwrap.Contains(err, ErrUnsupportedPath.Error()):
|
||||||
|
statusCode = http.StatusNotFound
|
||||||
|
case errwrap.Contains(err, ErrInvalidRequest.Error()):
|
||||||
|
statusCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil && resp.IsError() {
|
||||||
|
err = fmt.Errorf("%s", resp.Data["error"].(string))
|
||||||
|
}
|
||||||
|
|
||||||
|
return statusCode, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdjustErrorStatusCode adjusts the status that will be sent in error
|
||||||
|
// conditions in a way that can be shared across http's respondError and other
|
||||||
|
// locations.
|
||||||
|
func AdjustErrorStatusCode(status *int, err error) {
|
||||||
|
// Adjust status code when sealed
|
||||||
|
if errwrap.Contains(err, consts.ErrSealed.Error()) {
|
||||||
|
*status = http.StatusServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust status code on
|
||||||
|
if errwrap.Contains(err, "http: request body too large") {
|
||||||
|
*status = http.StatusRequestEntityTooLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow HTTPCoded error passthrough to specify a code
|
||||||
|
if t, ok := err.(HTTPCodedError); ok {
|
||||||
|
*status = t.Code()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,10 @@
|
|||||||
package logical
|
package logical
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
|
)
|
||||||
|
|
||||||
// SystemView exposes system configuration information in a safe way
|
// SystemView exposes system configuration information in a safe way
|
||||||
// for logical backends to consume
|
// for logical backends to consume
|
||||||
@@ -32,7 +36,7 @@ type SystemView interface {
|
|||||||
CachingDisabled() bool
|
CachingDisabled() bool
|
||||||
|
|
||||||
// ReplicationState indicates the state of cluster replication
|
// ReplicationState indicates the state of cluster replication
|
||||||
ReplicationState() ReplicationState
|
ReplicationState() consts.ReplicationState
|
||||||
}
|
}
|
||||||
|
|
||||||
type StaticSystemView struct {
|
type StaticSystemView struct {
|
||||||
@@ -42,7 +46,7 @@ type StaticSystemView struct {
|
|||||||
TaintedVal bool
|
TaintedVal bool
|
||||||
CachingDisabledVal bool
|
CachingDisabledVal bool
|
||||||
Primary bool
|
Primary bool
|
||||||
ReplicationStateVal ReplicationState
|
ReplicationStateVal consts.ReplicationState
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
|
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
|
||||||
@@ -65,6 +69,6 @@ func (d StaticSystemView) CachingDisabled() bool {
|
|||||||
return d.CachingDisabledVal
|
return d.CachingDisabledVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d StaticSystemView) ReplicationState() ReplicationState {
|
func (d StaticSystemView) ReplicationState() consts.ReplicationState {
|
||||||
return d.ReplicationStateVal
|
return d.ReplicationStateVal
|
||||||
}
|
}
|
||||||
|
|||||||
17
meta/meta.go
17
meta/meta.go
@@ -23,6 +23,21 @@ const (
|
|||||||
FlagSetDefault = FlagSetServer
|
FlagSetDefault = FlagSetServer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
additionalOptionsUsage = func() string {
|
||||||
|
return `
|
||||||
|
-wrap-ttl="" Indicates that the response should be wrapped in a
|
||||||
|
cubbyhole token with the requested TTL. The response
|
||||||
|
can be fetched by calling the "sys/wrapping/unwrap"
|
||||||
|
endpoint, passing in the wrappping token's ID. This
|
||||||
|
is a numeric string with an optional suffix
|
||||||
|
"s", "m", or "h"; if no suffix is specified it will
|
||||||
|
be parsed as seconds. May also be specified via
|
||||||
|
VAULT_WRAP_TTL.
|
||||||
|
`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// Meta contains the meta-options and functionality that nearly every
|
// Meta contains the meta-options and functionality that nearly every
|
||||||
// Vault command inherits.
|
// Vault command inherits.
|
||||||
type Meta struct {
|
type Meta struct {
|
||||||
@@ -188,6 +203,6 @@ func GeneralOptionsUsage() string {
|
|||||||
if VAULT_SKIP_VERIFY is set.
|
if VAULT_SKIP_VERIFY is set.
|
||||||
`
|
`
|
||||||
|
|
||||||
general += AdditionalOptionsUsage()
|
general += additionalOptionsUsage()
|
||||||
return general
|
return general
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
// +build !vault
|
|
||||||
|
|
||||||
package meta
|
|
||||||
|
|
||||||
func AdditionalOptionsUsage() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
// +build vault
|
|
||||||
|
|
||||||
package meta
|
|
||||||
|
|
||||||
func AdditionalOptionsUsage() string {
|
|
||||||
return `
|
|
||||||
-wrap-ttl="" Indicates that the response should be wrapped in a
|
|
||||||
cubbyhole token with the requested TTL. The response
|
|
||||||
can be fetched by calling the "sys/wrapping/unwrap"
|
|
||||||
endpoint, passing in the wrappping token's ID. This
|
|
||||||
is a numeric string with an optional suffix
|
|
||||||
"s", "m", or "h"; if no suffix is specified it will
|
|
||||||
be parsed as seconds. May also be specified via
|
|
||||||
VAULT_WRAP_TTL.
|
|
||||||
`
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,15 @@
|
|||||||
package physical
|
package physical
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/golang-lru"
|
"github.com/hashicorp/golang-lru"
|
||||||
|
"github.com/hashicorp/vault/helper/locksutil"
|
||||||
|
"github.com/hashicorp/vault/helper/strutil"
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,8 +23,11 @@ const (
|
|||||||
// Vault are for policy objects so there is a large read reduction
|
// Vault are for policy objects so there is a large read reduction
|
||||||
// by using a simple write-through cache.
|
// by using a simple write-through cache.
|
||||||
type Cache struct {
|
type Cache struct {
|
||||||
backend Backend
|
backend Backend
|
||||||
lru *lru.TwoQueueCache
|
transactional Transactional
|
||||||
|
lru *lru.TwoQueueCache
|
||||||
|
locks map[string]*sync.RWMutex
|
||||||
|
logger log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCache returns a physical cache of the given size.
|
// NewCache returns a physical cache of the given size.
|
||||||
@@ -34,16 +43,58 @@ func NewCache(b Backend, size int, logger log.Logger) *Cache {
|
|||||||
c := &Cache{
|
c := &Cache{
|
||||||
backend: b,
|
backend: b,
|
||||||
lru: cache,
|
lru: cache,
|
||||||
|
locks: make(map[string]*sync.RWMutex, 256),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
if err := locksutil.CreateLocks(c.locks, 256); err != nil {
|
||||||
|
logger.Error("physical/cache: error creating locks", "error", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if txnl, ok := c.backend.(Transactional); ok {
|
||||||
|
c.transactional = txnl
|
||||||
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cache) lockHashForKey(key string) string {
|
||||||
|
hf := sha1.New()
|
||||||
|
hf.Write([]byte(key))
|
||||||
|
return strings.ToLower(hex.EncodeToString(hf.Sum(nil))[:2])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) lockForKey(key string) *sync.RWMutex {
|
||||||
|
return c.locks[c.lockHashForKey(key)]
|
||||||
|
}
|
||||||
|
|
||||||
// Purge is used to clear the cache
|
// Purge is used to clear the cache
|
||||||
func (c *Cache) Purge() {
|
func (c *Cache) Purge() {
|
||||||
|
// Lock the world
|
||||||
|
lockHashes := make([]string, 0, len(c.locks))
|
||||||
|
for hash := range c.locks {
|
||||||
|
lockHashes = append(lockHashes, hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort and deduplicate. This ensures we don't try to grab the same lock
|
||||||
|
// twice, and enforcing a sort means we'll not have multiple goroutines
|
||||||
|
// deadlock by acquiring in different orders.
|
||||||
|
lockHashes = strutil.RemoveDuplicates(lockHashes)
|
||||||
|
|
||||||
|
for _, lockHash := range lockHashes {
|
||||||
|
lock := c.locks[lockHash]
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
c.lru.Purge()
|
c.lru.Purge()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Put(entry *Entry) error {
|
func (c *Cache) Put(entry *Entry) error {
|
||||||
|
lock := c.lockForKey(entry.Key)
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
err := c.backend.Put(entry)
|
err := c.backend.Put(entry)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.lru.Add(entry.Key, entry)
|
c.lru.Add(entry.Key, entry)
|
||||||
@@ -52,6 +103,10 @@ func (c *Cache) Put(entry *Entry) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Get(key string) (*Entry, error) {
|
func (c *Cache) Get(key string) (*Entry, error) {
|
||||||
|
lock := c.lockForKey(key)
|
||||||
|
lock.RLock()
|
||||||
|
defer lock.RUnlock()
|
||||||
|
|
||||||
// Check the LRU first
|
// Check the LRU first
|
||||||
if raw, ok := c.lru.Get(key); ok {
|
if raw, ok := c.lru.Get(key); ok {
|
||||||
if raw == nil {
|
if raw == nil {
|
||||||
@@ -79,6 +134,10 @@ func (c *Cache) Get(key string) (*Entry, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Delete(key string) error {
|
func (c *Cache) Delete(key string) error {
|
||||||
|
lock := c.lockForKey(key)
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
err := c.backend.Delete(key)
|
err := c.backend.Delete(key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.lru.Remove(key)
|
c.lru.Remove(key)
|
||||||
@@ -87,6 +146,45 @@ func (c *Cache) Delete(key string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) List(prefix string) ([]string, error) {
|
func (c *Cache) List(prefix string) ([]string, error) {
|
||||||
// Always pass-through as this would be difficult to cache.
|
// Always pass-through as this would be difficult to cache. For the same
|
||||||
|
// reason we don't lock as we can't reasonably know which locks to readlock
|
||||||
|
// ahead of time.
|
||||||
return c.backend.List(prefix)
|
return c.backend.List(prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Transaction(txns []TxnEntry) error {
|
||||||
|
if c.transactional == nil {
|
||||||
|
return fmt.Errorf("physical/cache: underlying backend does not support transactions")
|
||||||
|
}
|
||||||
|
|
||||||
|
var lockHashes []string
|
||||||
|
for _, txn := range txns {
|
||||||
|
lockHashes = append(lockHashes, c.lockHashForKey(txn.Entry.Key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort and deduplicate. This ensures we don't try to grab the same lock
|
||||||
|
// twice, and enforcing a sort means we'll not have multiple goroutines
|
||||||
|
// deadlock by acquiring in different orders.
|
||||||
|
lockHashes = strutil.RemoveDuplicates(lockHashes)
|
||||||
|
|
||||||
|
for _, lockHash := range lockHashes {
|
||||||
|
lock := c.locks[lockHash]
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.transactional.Transaction(txns); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, txn := range txns {
|
||||||
|
switch txn.Operation {
|
||||||
|
case PutOperation:
|
||||||
|
c.lru.Add(txn.Entry.Key, txn.Entry)
|
||||||
|
case DeleteOperation:
|
||||||
|
c.lru.Remove(txn.Entry.Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package physical
|
package physical
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
@@ -21,6 +22,8 @@ import (
|
|||||||
"github.com/hashicorp/consul/lib"
|
"github.com/hashicorp/consul/lib"
|
||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/go-cleanhttp"
|
"github.com/hashicorp/go-cleanhttp"
|
||||||
|
multierror "github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/strutil"
|
"github.com/hashicorp/vault/helper/strutil"
|
||||||
"github.com/hashicorp/vault/helper/tlsutil"
|
"github.com/hashicorp/vault/helper/tlsutil"
|
||||||
)
|
)
|
||||||
@@ -154,6 +157,10 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
|
|||||||
|
|
||||||
// Configure the client
|
// Configure the client
|
||||||
consulConf := api.DefaultConfig()
|
consulConf := api.DefaultConfig()
|
||||||
|
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
|
||||||
|
tr := cleanhttp.DefaultPooledTransport()
|
||||||
|
tr.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
||||||
|
consulConf.HttpClient.Transport = tr
|
||||||
|
|
||||||
if addr, ok := conf["address"]; ok {
|
if addr, ok := conf["address"]; ok {
|
||||||
consulConf.Address = addr
|
consulConf.Address = addr
|
||||||
@@ -179,7 +186,7 @@ func newConsulBackend(conf map[string]string, logger log.Logger) (Backend, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
transport := cleanhttp.DefaultPooledTransport()
|
transport := cleanhttp.DefaultPooledTransport()
|
||||||
transport.MaxIdleConnsPerHost = 4
|
transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
||||||
transport.TLSClientConfig = tlsClientConfig
|
transport.TLSClientConfig = tlsClientConfig
|
||||||
consulConf.HttpClient.Transport = transport
|
consulConf.HttpClient.Transport = transport
|
||||||
logger.Debug("physical/consul: configured TLS")
|
logger.Debug("physical/consul: configured TLS")
|
||||||
@@ -284,17 +291,59 @@ func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
|
|||||||
return tlsClientConfig, nil
|
return tlsClientConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Used to run multiple entries via a transaction
|
||||||
|
func (c *ConsulBackend) Transaction(txns []TxnEntry) error {
|
||||||
|
if len(txns) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ops := make([]*api.KVTxnOp, 0, len(txns))
|
||||||
|
|
||||||
|
for _, op := range txns {
|
||||||
|
cop := &api.KVTxnOp{
|
||||||
|
Key: c.path + op.Entry.Key,
|
||||||
|
}
|
||||||
|
switch op.Operation {
|
||||||
|
case DeleteOperation:
|
||||||
|
cop.Verb = api.KVDelete
|
||||||
|
case PutOperation:
|
||||||
|
cop.Verb = api.KVSet
|
||||||
|
cop.Value = op.Entry.Value
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
ops = append(ops, cop)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, resp, _, err := c.kv.Txn(ops, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var retErr *multierror.Error
|
||||||
|
for _, res := range resp.Errors {
|
||||||
|
retErr = multierror.Append(retErr, errors.New(res.What))
|
||||||
|
}
|
||||||
|
|
||||||
|
return retErr
|
||||||
|
}
|
||||||
|
|
||||||
// Put is used to insert or update an entry
|
// Put is used to insert or update an entry
|
||||||
func (c *ConsulBackend) Put(entry *Entry) error {
|
func (c *ConsulBackend) Put(entry *Entry) error {
|
||||||
defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
|
defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
|
||||||
|
|
||||||
|
c.permitPool.Acquire()
|
||||||
|
defer c.permitPool.Release()
|
||||||
|
|
||||||
pair := &api.KVPair{
|
pair := &api.KVPair{
|
||||||
Key: c.path + entry.Key,
|
Key: c.path + entry.Key,
|
||||||
Value: entry.Value,
|
Value: entry.Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.permitPool.Acquire()
|
|
||||||
defer c.permitPool.Release()
|
|
||||||
|
|
||||||
_, err := c.kv.Put(pair, nil)
|
_, err := c.kv.Put(pair, nil)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
104
physical/file.go
104
physical/file.go
@@ -22,12 +22,17 @@ import (
|
|||||||
// and non-performant. It is meant mostly for local testing and development.
|
// and non-performant. It is meant mostly for local testing and development.
|
||||||
// It can be improved in the future.
|
// It can be improved in the future.
|
||||||
type FileBackend struct {
|
type FileBackend struct {
|
||||||
Path string
|
sync.RWMutex
|
||||||
l sync.Mutex
|
path string
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
permitPool *PermitPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// newFileBackend constructs a Filebackend using the given directory
|
type TransactionalFileBackend struct {
|
||||||
|
FileBackend
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFileBackend constructs a FileBackend using the given directory
|
||||||
func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
|
func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
|
||||||
path, ok := conf["path"]
|
path, ok := conf["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -35,20 +40,44 @@ func newFileBackend(conf map[string]string, logger log.Logger) (Backend, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &FileBackend{
|
return &FileBackend{
|
||||||
Path: path,
|
path: path,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
permitPool: NewPermitPool(DefaultParallelOperations),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransactionalFileBackend(conf map[string]string, logger log.Logger) (Backend, error) {
|
||||||
|
path, ok := conf["path"]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("'path' must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pool of size 1 so only one operation runs at a time
|
||||||
|
return &TransactionalFileBackend{
|
||||||
|
FileBackend: FileBackend{
|
||||||
|
path: path,
|
||||||
|
logger: logger,
|
||||||
|
permitPool: NewPermitPool(1),
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *FileBackend) Delete(path string) error {
|
func (b *FileBackend) Delete(path string) error {
|
||||||
|
b.permitPool.Acquire()
|
||||||
|
defer b.permitPool.Release()
|
||||||
|
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
|
||||||
|
return b.DeleteInternal(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *FileBackend) DeleteInternal(path string) error {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
b.l.Lock()
|
basePath, key := b.expandPath(path)
|
||||||
defer b.l.Unlock()
|
|
||||||
|
|
||||||
basePath, key := b.path(path)
|
|
||||||
fullPath := filepath.Join(basePath, key)
|
fullPath := filepath.Join(basePath, key)
|
||||||
|
|
||||||
err := os.Remove(fullPath)
|
err := os.Remove(fullPath)
|
||||||
@@ -66,7 +95,7 @@ func (b *FileBackend) Delete(path string) error {
|
|||||||
func (b *FileBackend) cleanupLogicalPath(path string) error {
|
func (b *FileBackend) cleanupLogicalPath(path string) error {
|
||||||
nodes := strings.Split(path, fmt.Sprintf("%c", os.PathSeparator))
|
nodes := strings.Split(path, fmt.Sprintf("%c", os.PathSeparator))
|
||||||
for i := len(nodes) - 1; i > 0; i-- {
|
for i := len(nodes) - 1; i > 0; i-- {
|
||||||
fullPath := filepath.Join(b.Path, filepath.Join(nodes[:i]...))
|
fullPath := filepath.Join(b.path, filepath.Join(nodes[:i]...))
|
||||||
|
|
||||||
dir, err := os.Open(fullPath)
|
dir, err := os.Open(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -96,10 +125,17 @@ func (b *FileBackend) cleanupLogicalPath(path string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *FileBackend) Get(k string) (*Entry, error) {
|
func (b *FileBackend) Get(k string) (*Entry, error) {
|
||||||
b.l.Lock()
|
b.permitPool.Acquire()
|
||||||
defer b.l.Unlock()
|
defer b.permitPool.Release()
|
||||||
|
|
||||||
path, key := b.path(k)
|
b.RLock()
|
||||||
|
defer b.RUnlock()
|
||||||
|
|
||||||
|
return b.GetInternal(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *FileBackend) GetInternal(k string) (*Entry, error) {
|
||||||
|
path, key := b.expandPath(k)
|
||||||
path = filepath.Join(path, key)
|
path = filepath.Join(path, key)
|
||||||
|
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
@@ -121,10 +157,17 @@ func (b *FileBackend) Get(k string) (*Entry, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *FileBackend) Put(entry *Entry) error {
|
func (b *FileBackend) Put(entry *Entry) error {
|
||||||
path, key := b.path(entry.Key)
|
b.permitPool.Acquire()
|
||||||
|
defer b.permitPool.Release()
|
||||||
|
|
||||||
b.l.Lock()
|
b.Lock()
|
||||||
defer b.l.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
|
return b.PutInternal(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *FileBackend) PutInternal(entry *Entry) error {
|
||||||
|
path, key := b.expandPath(entry.Key)
|
||||||
|
|
||||||
// Make the parent tree
|
// Make the parent tree
|
||||||
if err := os.MkdirAll(path, 0755); err != nil {
|
if err := os.MkdirAll(path, 0755); err != nil {
|
||||||
@@ -145,10 +188,17 @@ func (b *FileBackend) Put(entry *Entry) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *FileBackend) List(prefix string) ([]string, error) {
|
func (b *FileBackend) List(prefix string) ([]string, error) {
|
||||||
b.l.Lock()
|
b.permitPool.Acquire()
|
||||||
defer b.l.Unlock()
|
defer b.permitPool.Release()
|
||||||
|
|
||||||
path := b.Path
|
b.RLock()
|
||||||
|
defer b.RUnlock()
|
||||||
|
|
||||||
|
return b.ListInternal(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *FileBackend) ListInternal(prefix string) ([]string, error) {
|
||||||
|
path := b.path
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
path = filepath.Join(path, prefix)
|
path = filepath.Join(path, prefix)
|
||||||
}
|
}
|
||||||
@@ -180,9 +230,19 @@ func (b *FileBackend) List(prefix string) ([]string, error) {
|
|||||||
return names, nil
|
return names, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *FileBackend) path(k string) (string, string) {
|
func (b *FileBackend) expandPath(k string) (string, string) {
|
||||||
path := filepath.Join(b.Path, k)
|
path := filepath.Join(b.path, k)
|
||||||
key := filepath.Base(path)
|
key := filepath.Base(path)
|
||||||
path = filepath.Dir(path)
|
path = filepath.Dir(path)
|
||||||
return path, "_" + key
|
return path, "_" + key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *TransactionalFileBackend) Transaction(txns []TxnEntry) error {
|
||||||
|
b.permitPool.Acquire()
|
||||||
|
defer b.permitPool.Release()
|
||||||
|
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
|
||||||
|
return genericTransactionHandler(b, txns)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,12 +13,16 @@ import (
|
|||||||
// for testing and development situations where the data is not
|
// for testing and development situations where the data is not
|
||||||
// expected to be durable.
|
// expected to be durable.
|
||||||
type InmemBackend struct {
|
type InmemBackend struct {
|
||||||
|
sync.RWMutex
|
||||||
root *radix.Tree
|
root *radix.Tree
|
||||||
l sync.RWMutex
|
|
||||||
permitPool *PermitPool
|
permitPool *PermitPool
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TransactionalInmemBackend struct {
|
||||||
|
InmemBackend
|
||||||
|
}
|
||||||
|
|
||||||
// NewInmem constructs a new in-memory backend
|
// NewInmem constructs a new in-memory backend
|
||||||
func NewInmem(logger log.Logger) *InmemBackend {
|
func NewInmem(logger log.Logger) *InmemBackend {
|
||||||
in := &InmemBackend{
|
in := &InmemBackend{
|
||||||
@@ -29,14 +33,31 @@ func NewInmem(logger log.Logger) *InmemBackend {
|
|||||||
return in
|
return in
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Basically for now just creates a permit pool of size 1 so only one operation
|
||||||
|
// can run at a time
|
||||||
|
func NewTransactionalInmem(logger log.Logger) *TransactionalInmemBackend {
|
||||||
|
in := &TransactionalInmemBackend{
|
||||||
|
InmemBackend: InmemBackend{
|
||||||
|
root: radix.New(),
|
||||||
|
permitPool: NewPermitPool(1),
|
||||||
|
logger: logger,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return in
|
||||||
|
}
|
||||||
|
|
||||||
// Put is used to insert or update an entry
|
// Put is used to insert or update an entry
|
||||||
func (i *InmemBackend) Put(entry *Entry) error {
|
func (i *InmemBackend) Put(entry *Entry) error {
|
||||||
i.permitPool.Acquire()
|
i.permitPool.Acquire()
|
||||||
defer i.permitPool.Release()
|
defer i.permitPool.Release()
|
||||||
|
|
||||||
i.l.Lock()
|
i.Lock()
|
||||||
defer i.l.Unlock()
|
defer i.Unlock()
|
||||||
|
|
||||||
|
return i.PutInternal(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InmemBackend) PutInternal(entry *Entry) error {
|
||||||
i.root.Insert(entry.Key, entry)
|
i.root.Insert(entry.Key, entry)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -46,9 +67,13 @@ func (i *InmemBackend) Get(key string) (*Entry, error) {
|
|||||||
i.permitPool.Acquire()
|
i.permitPool.Acquire()
|
||||||
defer i.permitPool.Release()
|
defer i.permitPool.Release()
|
||||||
|
|
||||||
i.l.RLock()
|
i.RLock()
|
||||||
defer i.l.RUnlock()
|
defer i.RUnlock()
|
||||||
|
|
||||||
|
return i.GetInternal(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InmemBackend) GetInternal(key string) (*Entry, error) {
|
||||||
if raw, ok := i.root.Get(key); ok {
|
if raw, ok := i.root.Get(key); ok {
|
||||||
return raw.(*Entry), nil
|
return raw.(*Entry), nil
|
||||||
}
|
}
|
||||||
@@ -60,9 +85,13 @@ func (i *InmemBackend) Delete(key string) error {
|
|||||||
i.permitPool.Acquire()
|
i.permitPool.Acquire()
|
||||||
defer i.permitPool.Release()
|
defer i.permitPool.Release()
|
||||||
|
|
||||||
i.l.Lock()
|
i.Lock()
|
||||||
defer i.l.Unlock()
|
defer i.Unlock()
|
||||||
|
|
||||||
|
return i.DeleteInternal(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InmemBackend) DeleteInternal(key string) error {
|
||||||
i.root.Delete(key)
|
i.root.Delete(key)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -73,9 +102,13 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
|
|||||||
i.permitPool.Acquire()
|
i.permitPool.Acquire()
|
||||||
defer i.permitPool.Release()
|
defer i.permitPool.Release()
|
||||||
|
|
||||||
i.l.RLock()
|
i.RLock()
|
||||||
defer i.l.RUnlock()
|
defer i.RUnlock()
|
||||||
|
|
||||||
|
return i.ListInternal(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InmemBackend) ListInternal(prefix string) ([]string, error) {
|
||||||
var out []string
|
var out []string
|
||||||
seen := make(map[string]interface{})
|
seen := make(map[string]interface{})
|
||||||
walkFn := func(s string, v interface{}) bool {
|
walkFn := func(s string, v interface{}) bool {
|
||||||
@@ -96,3 +129,14 @@ func (i *InmemBackend) List(prefix string) ([]string, error) {
|
|||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Implements the transaction interface
|
||||||
|
func (t *TransactionalInmemBackend) Transaction(txns []TxnEntry) error {
|
||||||
|
t.permitPool.Acquire()
|
||||||
|
defer t.permitPool.Release()
|
||||||
|
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
return genericTransactionHandler(t, txns)
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,19 +8,40 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type InmemHABackend struct {
|
type InmemHABackend struct {
|
||||||
InmemBackend
|
Backend
|
||||||
locks map[string]string
|
locks map[string]string
|
||||||
l sync.Mutex
|
l sync.Mutex
|
||||||
cond *sync.Cond
|
cond *sync.Cond
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TransactionalInmemHABackend struct {
|
||||||
|
Transactional
|
||||||
|
InmemHABackend
|
||||||
|
}
|
||||||
|
|
||||||
// NewInmemHA constructs a new in-memory HA backend. This is only for testing.
|
// NewInmemHA constructs a new in-memory HA backend. This is only for testing.
|
||||||
func NewInmemHA(logger log.Logger) *InmemHABackend {
|
func NewInmemHA(logger log.Logger) *InmemHABackend {
|
||||||
in := &InmemHABackend{
|
in := &InmemHABackend{
|
||||||
InmemBackend: *NewInmem(logger),
|
Backend: NewInmem(logger),
|
||||||
locks: make(map[string]string),
|
locks: make(map[string]string),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
}
|
||||||
|
in.cond = sync.NewCond(&in.l)
|
||||||
|
return in
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransactionalInmemHA(logger log.Logger) *TransactionalInmemHABackend {
|
||||||
|
transInmem := NewTransactionalInmem(logger)
|
||||||
|
inmemHA := InmemHABackend{
|
||||||
|
Backend: transInmem,
|
||||||
|
locks: make(map[string]string),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
in := &TransactionalInmemHABackend{
|
||||||
|
InmemHABackend: inmemHA,
|
||||||
|
Transactional: transInmem,
|
||||||
}
|
}
|
||||||
in.cond = sync.NewCond(&in.l)
|
in.cond = sync.NewCond(&in.l)
|
||||||
return in
|
return in
|
||||||
|
|||||||
@@ -9,6 +9,16 @@ import (
|
|||||||
|
|
||||||
const DefaultParallelOperations = 128
|
const DefaultParallelOperations = 128
|
||||||
|
|
||||||
|
// The operation type
|
||||||
|
type Operation string
|
||||||
|
|
||||||
|
const (
|
||||||
|
DeleteOperation Operation = "delete"
|
||||||
|
GetOperation = "get"
|
||||||
|
ListOperation = "list"
|
||||||
|
PutOperation = "put"
|
||||||
|
)
|
||||||
|
|
||||||
// ShutdownSignal
|
// ShutdownSignal
|
||||||
type ShutdownChannel chan struct{}
|
type ShutdownChannel chan struct{}
|
||||||
|
|
||||||
@@ -121,20 +131,27 @@ var builtinBackends = map[string]Factory{
|
|||||||
"inmem": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
"inmem": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||||
return NewInmem(logger), nil
|
return NewInmem(logger), nil
|
||||||
},
|
},
|
||||||
|
"inmem_transactional": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||||
|
return NewTransactionalInmem(logger), nil
|
||||||
|
},
|
||||||
"inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
"inmem_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||||
return NewInmemHA(logger), nil
|
return NewInmemHA(logger), nil
|
||||||
},
|
},
|
||||||
"consul": newConsulBackend,
|
"inmem_transactional_ha": func(_ map[string]string, logger log.Logger) (Backend, error) {
|
||||||
"zookeeper": newZookeeperBackend,
|
return NewTransactionalInmemHA(logger), nil
|
||||||
"file": newFileBackend,
|
},
|
||||||
"s3": newS3Backend,
|
"file_transactional": newTransactionalFileBackend,
|
||||||
"azure": newAzureBackend,
|
"consul": newConsulBackend,
|
||||||
"dynamodb": newDynamoDBBackend,
|
"zookeeper": newZookeeperBackend,
|
||||||
"etcd": newEtcdBackend,
|
"file": newFileBackend,
|
||||||
"mysql": newMySQLBackend,
|
"s3": newS3Backend,
|
||||||
"postgresql": newPostgreSQLBackend,
|
"azure": newAzureBackend,
|
||||||
"swift": newSwiftBackend,
|
"dynamodb": newDynamoDBBackend,
|
||||||
"gcs": newGCSBackend,
|
"etcd": newEtcdBackend,
|
||||||
|
"mysql": newMySQLBackend,
|
||||||
|
"postgresql": newPostgreSQLBackend,
|
||||||
|
"swift": newSwiftBackend,
|
||||||
|
"gcs": newGCSBackend,
|
||||||
}
|
}
|
||||||
|
|
||||||
// PermitPool is used to limit maximum outstanding requests
|
// PermitPool is used to limit maximum outstanding requests
|
||||||
|
|||||||
@@ -71,7 +71,8 @@ func newPostgreSQLBackend(conf map[string]string, logger log.Logger) (Backend, e
|
|||||||
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
||||||
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
||||||
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
|
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
|
||||||
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
|
"UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
|
||||||
|
quoted_table + " WHERE parent_path LIKE concat($1, '%')",
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
121
physical/transactions.go
Normal file
121
physical/transactions.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package physical
|
||||||
|
|
||||||
|
import multierror "github.com/hashicorp/go-multierror"
|
||||||
|
|
||||||
|
// TxnEntry is an operation that takes atomically as part of
|
||||||
|
// a transactional update. Only supported by Transactional backends.
|
||||||
|
type TxnEntry struct {
|
||||||
|
Operation Operation
|
||||||
|
Entry *Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transactional is an optional interface for backends that
|
||||||
|
// support doing transactional updates of multiple keys. This is
|
||||||
|
// required for some features such as replication.
|
||||||
|
type Transactional interface {
|
||||||
|
// The function to run a transaction
|
||||||
|
Transaction([]TxnEntry) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type PseudoTransactional interface {
|
||||||
|
// An internal function should do no locking or permit pool acquisition.
|
||||||
|
// Depending on the backend and if it natively supports transactions, these
|
||||||
|
// may simply chain to the normal backend functions.
|
||||||
|
GetInternal(string) (*Entry, error)
|
||||||
|
PutInternal(*Entry) error
|
||||||
|
DeleteInternal(string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the transaction interface
|
||||||
|
func genericTransactionHandler(t PseudoTransactional, txns []TxnEntry) (retErr error) {
|
||||||
|
rollbackStack := make([]TxnEntry, 0, len(txns))
|
||||||
|
var dirty bool
|
||||||
|
|
||||||
|
// We walk the transactions in order; each successful operation goes into a
|
||||||
|
// LIFO for rollback if we hit an error along the way
|
||||||
|
TxnWalk:
|
||||||
|
for _, txn := range txns {
|
||||||
|
switch txn.Operation {
|
||||||
|
case DeleteOperation:
|
||||||
|
entry, err := t.GetInternal(txn.Entry.Key)
|
||||||
|
if err != nil {
|
||||||
|
retErr = multierror.Append(retErr, err)
|
||||||
|
dirty = true
|
||||||
|
break TxnWalk
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
// Nothing to delete or roll back
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rollbackEntry := TxnEntry{
|
||||||
|
Operation: PutOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: entry.Key,
|
||||||
|
Value: entry.Value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = t.DeleteInternal(txn.Entry.Key)
|
||||||
|
if err != nil {
|
||||||
|
retErr = multierror.Append(retErr, err)
|
||||||
|
dirty = true
|
||||||
|
break TxnWalk
|
||||||
|
}
|
||||||
|
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
|
||||||
|
|
||||||
|
case PutOperation:
|
||||||
|
entry, err := t.GetInternal(txn.Entry.Key)
|
||||||
|
if err != nil {
|
||||||
|
retErr = multierror.Append(retErr, err)
|
||||||
|
dirty = true
|
||||||
|
break TxnWalk
|
||||||
|
}
|
||||||
|
// Nothing existed so in fact rolling back requires a delete
|
||||||
|
var rollbackEntry TxnEntry
|
||||||
|
if entry == nil {
|
||||||
|
rollbackEntry = TxnEntry{
|
||||||
|
Operation: DeleteOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: txn.Entry.Key,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rollbackEntry = TxnEntry{
|
||||||
|
Operation: PutOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: entry.Key,
|
||||||
|
Value: entry.Value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = t.PutInternal(txn.Entry)
|
||||||
|
if err != nil {
|
||||||
|
retErr = multierror.Append(retErr, err)
|
||||||
|
dirty = true
|
||||||
|
break TxnWalk
|
||||||
|
}
|
||||||
|
rollbackStack = append([]TxnEntry{rollbackEntry}, rollbackStack...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need to roll back because we hit an error along the way
|
||||||
|
if dirty {
|
||||||
|
// While traversing this, if we get an error, we continue anyways in
|
||||||
|
// best-effort fashion
|
||||||
|
for _, txn := range rollbackStack {
|
||||||
|
switch txn.Operation {
|
||||||
|
case DeleteOperation:
|
||||||
|
err := t.DeleteInternal(txn.Entry.Key)
|
||||||
|
if err != nil {
|
||||||
|
retErr = multierror.Append(retErr, err)
|
||||||
|
}
|
||||||
|
case PutOperation:
|
||||||
|
err := t.PutInternal(txn.Entry)
|
||||||
|
if err != nil {
|
||||||
|
retErr = multierror.Append(retErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
254
physical/transactions_test.go
Normal file
254
physical/transactions_test.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
package physical
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
radix "github.com/armon/go-radix"
|
||||||
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type faultyPseudo struct {
|
||||||
|
underlying InmemBackend
|
||||||
|
faultyPaths map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) Get(key string) (*Entry, error) {
|
||||||
|
return f.underlying.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) Put(entry *Entry) error {
|
||||||
|
return f.underlying.Put(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) Delete(key string) error {
|
||||||
|
return f.underlying.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) GetInternal(key string) (*Entry, error) {
|
||||||
|
if _, ok := f.faultyPaths[key]; ok {
|
||||||
|
return nil, fmt.Errorf("fault")
|
||||||
|
}
|
||||||
|
return f.underlying.GetInternal(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) PutInternal(entry *Entry) error {
|
||||||
|
if _, ok := f.faultyPaths[entry.Key]; ok {
|
||||||
|
return fmt.Errorf("fault")
|
||||||
|
}
|
||||||
|
return f.underlying.PutInternal(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) DeleteInternal(key string) error {
|
||||||
|
if _, ok := f.faultyPaths[key]; ok {
|
||||||
|
return fmt.Errorf("fault")
|
||||||
|
}
|
||||||
|
return f.underlying.DeleteInternal(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) List(prefix string) ([]string, error) {
|
||||||
|
return f.underlying.List(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *faultyPseudo) Transaction(txns []TxnEntry) error {
|
||||||
|
f.underlying.permitPool.Acquire()
|
||||||
|
defer f.underlying.permitPool.Release()
|
||||||
|
|
||||||
|
f.underlying.Lock()
|
||||||
|
defer f.underlying.Unlock()
|
||||||
|
|
||||||
|
return genericTransactionHandler(f, txns)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo {
|
||||||
|
out := &faultyPseudo{
|
||||||
|
underlying: InmemBackend{
|
||||||
|
root: radix.New(),
|
||||||
|
permitPool: NewPermitPool(1),
|
||||||
|
logger: logger,
|
||||||
|
},
|
||||||
|
faultyPaths: make(map[string]struct{}, len(faultyPaths)),
|
||||||
|
}
|
||||||
|
for _, v := range faultyPaths {
|
||||||
|
out.faultyPaths[v] = struct{}{}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPseudo_Basic(t *testing.T) {
|
||||||
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||||
|
p := newFaultyPseudo(logger, nil)
|
||||||
|
testBackend(t, p)
|
||||||
|
testBackend_ListPrefix(t, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPseudo_SuccessfulTransaction(t *testing.T) {
|
||||||
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||||
|
p := newFaultyPseudo(logger, nil)
|
||||||
|
|
||||||
|
txns := setupPseudo(p, t)
|
||||||
|
|
||||||
|
if err := p.Transaction(txns); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := p.List("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"foo", "zip"}
|
||||||
|
|
||||||
|
sort.Strings(keys)
|
||||||
|
sort.Strings(expected)
|
||||||
|
if !reflect.DeepEqual(keys, expected) {
|
||||||
|
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err := p.Get("foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
t.Fatal("got nil entry")
|
||||||
|
}
|
||||||
|
if entry.Value == nil {
|
||||||
|
t.Fatal("got nil value")
|
||||||
|
}
|
||||||
|
if string(entry.Value) != "bar3" {
|
||||||
|
t.Fatal("updates did not apply correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err = p.Get("zip")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
t.Fatal("got nil entry")
|
||||||
|
}
|
||||||
|
if entry.Value == nil {
|
||||||
|
t.Fatal("got nil value")
|
||||||
|
}
|
||||||
|
if string(entry.Value) != "zap3" {
|
||||||
|
t.Fatal("updates did not apply correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPseudo_FailedTransaction(t *testing.T) {
|
||||||
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||||
|
p := newFaultyPseudo(logger, []string{"zip"})
|
||||||
|
|
||||||
|
txns := setupPseudo(p, t)
|
||||||
|
|
||||||
|
if err := p.Transaction(txns); err == nil {
|
||||||
|
t.Fatal("expected error during transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := p.List("")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"foo", "zip", "deleteme", "deleteme2"}
|
||||||
|
|
||||||
|
sort.Strings(keys)
|
||||||
|
sort.Strings(expected)
|
||||||
|
if !reflect.DeepEqual(keys, expected) {
|
||||||
|
t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err := p.Get("foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
t.Fatal("got nil entry")
|
||||||
|
}
|
||||||
|
if entry.Value == nil {
|
||||||
|
t.Fatal("got nil value")
|
||||||
|
}
|
||||||
|
if string(entry.Value) != "bar" {
|
||||||
|
t.Fatal("values did not rollback correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err = p.Get("zip")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
t.Fatal("got nil entry")
|
||||||
|
}
|
||||||
|
if entry.Value == nil {
|
||||||
|
t.Fatal("got nil value")
|
||||||
|
}
|
||||||
|
if string(entry.Value) != "zap" {
|
||||||
|
t.Fatal("values did not rollback correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupPseudo(p *faultyPseudo, t *testing.T) []TxnEntry {
|
||||||
|
// Add a few keys so that we test rollback with deletion
|
||||||
|
if err := p.Put(&Entry{
|
||||||
|
Key: "foo",
|
||||||
|
Value: []byte("bar"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Put(&Entry{
|
||||||
|
Key: "zip",
|
||||||
|
Value: []byte("zap"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Put(&Entry{
|
||||||
|
Key: "deleteme",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Put(&Entry{
|
||||||
|
Key: "deleteme2",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
txns := []TxnEntry{
|
||||||
|
TxnEntry{
|
||||||
|
Operation: PutOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: "foo",
|
||||||
|
Value: []byte("bar2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TxnEntry{
|
||||||
|
Operation: DeleteOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: "deleteme",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TxnEntry{
|
||||||
|
Operation: PutOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: "foo",
|
||||||
|
Value: []byte("bar3"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TxnEntry{
|
||||||
|
Operation: DeleteOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: "deleteme2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TxnEntry{
|
||||||
|
Operation: PutOperation,
|
||||||
|
Entry: &Entry{
|
||||||
|
Key: "zip",
|
||||||
|
Value: []byte("zap3"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return txns
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q \
|
|||||||
git mercurial bzr \
|
git mercurial bzr \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
ENV GOVERSION 1.8rc3
|
ENV GOVERSION 1.8
|
||||||
RUN mkdir /goroot && mkdir /gopath
|
RUN mkdir /goroot && mkdir /gopath
|
||||||
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
|
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
|
||||||
| tar xvzf - -C /goroot --strip-components=1
|
| tar xvzf - -C /goroot --strip-components=1
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package vault
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -26,6 +25,10 @@ const (
|
|||||||
// can only be viewed or modified after an unseal.
|
// can only be viewed or modified after an unseal.
|
||||||
coreAuditConfigPath = "core/audit"
|
coreAuditConfigPath = "core/audit"
|
||||||
|
|
||||||
|
// coreLocalAuditConfigPath is used to store audit information for local
|
||||||
|
// (non-replicated) mounts
|
||||||
|
coreLocalAuditConfigPath = "core/local-audit"
|
||||||
|
|
||||||
// auditBarrierPrefix is the prefix to the UUID used in the
|
// auditBarrierPrefix is the prefix to the UUID used in the
|
||||||
// barrier view for the audit backends.
|
// barrier view for the audit backends.
|
||||||
auditBarrierPrefix = "audit/"
|
auditBarrierPrefix = "audit/"
|
||||||
@@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new UUID and view
|
// Generate a new UUID and view
|
||||||
entryUUID, err := uuid.GenerateUUID()
|
if entry.UUID == "" {
|
||||||
if err != nil {
|
entryUUID, err := uuid.GenerateUUID()
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
entry.UUID = entryUUID
|
||||||
}
|
}
|
||||||
entry.UUID = entryUUID
|
viewPath := auditBarrierPrefix + entry.UUID + "/"
|
||||||
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
|
view := NewBarrierView(c.barrier, viewPath)
|
||||||
|
|
||||||
// Lookup the new backend
|
// Lookup the new backend
|
||||||
backend, err := c.newAuditBackend(entry, view, entry.Options)
|
backend, err := c.newAuditBackend(entry, view, entry.Options)
|
||||||
@@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {
|
|||||||
|
|
||||||
c.removeAuditReloadFunc(entry)
|
c.removeAuditReloadFunc(entry)
|
||||||
|
|
||||||
|
// When unmounting all entries the JSON code will load back up from storage
|
||||||
|
// as a nil slice, which kills tests...just set it nil explicitly
|
||||||
|
if len(newTable.Entries) == 0 {
|
||||||
|
newTable.Entries = nil
|
||||||
|
}
|
||||||
|
|
||||||
// Update the audit table
|
// Update the audit table
|
||||||
if err := c.persistAudit(newTable); err != nil {
|
if err := c.persistAudit(newTable); err != nil {
|
||||||
return true, errors.New("failed to update audit table")
|
return true, errors.New("failed to update audit table")
|
||||||
@@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) {
|
|||||||
if c.logger.IsInfo() {
|
if c.logger.IsInfo() {
|
||||||
c.logger.Info("core: disabled audit backend", "path", path)
|
c.logger.Info("core: disabled audit backend", "path", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadAudits is invoked as part of postUnseal to load the audit table
|
// loadAudits is invoked as part of postUnseal to load the audit table
|
||||||
func (c *Core) loadAudits() error {
|
func (c *Core) loadAudits() error {
|
||||||
auditTable := &MountTable{}
|
auditTable := &MountTable{}
|
||||||
|
localAuditTable := &MountTable{}
|
||||||
|
|
||||||
// Load the existing audit table
|
// Load the existing audit table
|
||||||
raw, err := c.barrier.Get(coreAuditConfigPath)
|
raw, err := c.barrier.Get(coreAuditConfigPath)
|
||||||
@@ -144,6 +158,11 @@ func (c *Core) loadAudits() error {
|
|||||||
c.logger.Error("core: failed to read audit table", "error", err)
|
c.logger.Error("core: failed to read audit table", "error", err)
|
||||||
return errLoadAuditFailed
|
return errLoadAuditFailed
|
||||||
}
|
}
|
||||||
|
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("core: failed to read local audit table", "error", err)
|
||||||
|
return errLoadAuditFailed
|
||||||
|
}
|
||||||
|
|
||||||
c.auditLock.Lock()
|
c.auditLock.Lock()
|
||||||
defer c.auditLock.Unlock()
|
defer c.auditLock.Unlock()
|
||||||
@@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
|
|||||||
}
|
}
|
||||||
c.audit = auditTable
|
c.audit = auditTable
|
||||||
}
|
}
|
||||||
|
if rawLocal != nil {
|
||||||
|
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
|
||||||
|
c.logger.Error("core: failed to decode local audit table", "error", err)
|
||||||
|
return errLoadAuditFailed
|
||||||
|
}
|
||||||
|
c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
|
||||||
|
}
|
||||||
|
|
||||||
// Done if we have restored the audit table
|
// Done if we have restored the audit table
|
||||||
if c.audit != nil {
|
if c.audit != nil {
|
||||||
@@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nonLocalAudit := &MountTable{
|
||||||
|
Type: auditTableType,
|
||||||
|
}
|
||||||
|
|
||||||
|
localAudit := &MountTable{
|
||||||
|
Type: auditTableType,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range table.Entries {
|
||||||
|
if entry.Local {
|
||||||
|
localAudit.Entries = append(localAudit.Entries, entry)
|
||||||
|
} else {
|
||||||
|
nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Marshal the table
|
// Marshal the table
|
||||||
raw, err := json.Marshal(table)
|
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: failed to encode audit table", "error", err)
|
c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an entry
|
// Create an entry
|
||||||
entry := &Entry{
|
entry := &Entry{
|
||||||
Key: coreAuditConfigPath,
|
Key: coreAuditConfigPath,
|
||||||
Value: raw,
|
Value: compressedBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write to the physical backend
|
// Write to the physical backend
|
||||||
@@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error {
|
|||||||
c.logger.Error("core: failed to persist audit table", "error", err)
|
c.logger.Error("core: failed to persist audit table", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Repeat with local audit
|
||||||
|
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = &Entry{
|
||||||
|
Key: coreLocalAuditConfigPath,
|
||||||
|
Value: compressedBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.barrier.Put(entry); err != nil {
|
||||||
|
c.logger.Error("core: failed to persist local audit table", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {
|
|||||||
|
|
||||||
for _, entry := range c.audit.Entries {
|
for _, entry := range c.audit.Entries {
|
||||||
// Create a barrier view using the UUID
|
// Create a barrier view using the UUID
|
||||||
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
|
viewPath := auditBarrierPrefix + entry.UUID + "/"
|
||||||
|
view := NewBarrierView(c.barrier, viewPath)
|
||||||
|
|
||||||
// Initialize the backend
|
// Initialize the backend
|
||||||
audit, err := c.newAuditBackend(entry, view, entry.Options)
|
audit, err := c.newAuditBackend(entry, view, entry.Options)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/audit"
|
"github.com/hashicorp/vault/audit"
|
||||||
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/helper/logformat"
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
@@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the local table actually gets populated as expected with local
|
||||||
|
// entries, and that upon reading the entries from both are recombined
|
||||||
|
// correctly
|
||||||
|
func TestCore_EnableAudit_Local(t *testing.T) {
|
||||||
|
c, _, _ := TestCoreUnsealed(t)
|
||||||
|
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||||
|
return &NoopAudit{
|
||||||
|
Config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||||
|
return nil, fmt.Errorf("failing enabling")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.audit = &MountTable{
|
||||||
|
Type: auditTableType,
|
||||||
|
Entries: []*MountEntry{
|
||||||
|
&MountEntry{
|
||||||
|
Table: auditTableType,
|
||||||
|
Path: "noop/",
|
||||||
|
Type: "noop",
|
||||||
|
UUID: "abcd",
|
||||||
|
},
|
||||||
|
&MountEntry{
|
||||||
|
Table: auditTableType,
|
||||||
|
Path: "noop2/",
|
||||||
|
Type: "noop",
|
||||||
|
UUID: "bcde",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both should set up successfully
|
||||||
|
err := c.setupAudits()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if rawLocal == nil {
|
||||||
|
t.Fatal("expected non-nil local audit")
|
||||||
|
}
|
||||||
|
localAuditTable := &MountTable{}
|
||||||
|
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(localAuditTable.Entries) > 0 {
|
||||||
|
t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.audit.Entries[1].Local = true
|
||||||
|
if err := c.persistAudit(c.audit); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if rawLocal == nil {
|
||||||
|
t.Fatal("expected non-nil local audit")
|
||||||
|
}
|
||||||
|
localAuditTable = &MountTable{}
|
||||||
|
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(localAuditTable.Entries) != 1 {
|
||||||
|
t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldAudit := c.audit
|
||||||
|
if err := c.loadAudits(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(oldAudit, c.audit) {
|
||||||
|
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.audit.Entries) != 2 {
|
||||||
|
t.Fatalf("expected two audit entries, got %#v", localAuditTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCore_DisableAudit(t *testing.T) {
|
func TestCore_DisableAudit(t *testing.T) {
|
||||||
c, keys, _ := TestCoreUnsealed(t)
|
c, keys, _ := TestCoreUnsealed(t)
|
||||||
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
|
||||||
@@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {
|
|||||||
|
|
||||||
// Verify matching mount tables
|
// Verify matching mount tables
|
||||||
if !reflect.DeepEqual(c.audit, c2.audit) {
|
if !reflect.DeepEqual(c.audit, c2.audit) {
|
||||||
t.Fatalf("mismatch: %v %v", c.audit, c2.audit)
|
t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package vault
|
package vault
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,6 +16,10 @@ const (
|
|||||||
// can only be viewed or modified after an unseal.
|
// can only be viewed or modified after an unseal.
|
||||||
coreAuthConfigPath = "core/auth"
|
coreAuthConfigPath = "core/auth"
|
||||||
|
|
||||||
|
// coreLocalAuthConfigPath is used to store credential configuration for
|
||||||
|
// local (non-replicated) mounts
|
||||||
|
coreLocalAuthConfigPath = "core/local-auth"
|
||||||
|
|
||||||
// credentialBarrierPrefix is the prefix to the UUID used in the
|
// credentialBarrierPrefix is the prefix to the UUID used in the
|
||||||
// barrier view for the credential backends.
|
// barrier view for the credential backends.
|
||||||
credentialBarrierPrefix = "auth/"
|
credentialBarrierPrefix = "auth/"
|
||||||
@@ -71,16 +74,25 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new UUID and view
|
// Generate a new UUID and view
|
||||||
entryUUID, err := uuid.GenerateUUID()
|
if entry.UUID == "" {
|
||||||
|
entryUUID, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
entry.UUID = entryUUID
|
||||||
|
}
|
||||||
|
|
||||||
|
viewPath := credentialBarrierPrefix + entry.UUID + "/"
|
||||||
|
view := NewBarrierView(c.barrier, viewPath)
|
||||||
|
sysView := c.mountEntrySysView(entry)
|
||||||
|
|
||||||
|
// Create the new backend
|
||||||
|
backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
entry.UUID = entryUUID
|
|
||||||
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
|
||||||
|
|
||||||
// Create the new backend
|
if err := backend.Initialize(); err != nil {
|
||||||
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) {
|
|||||||
fullPath := credentialRoutePrefix + path
|
fullPath := credentialRoutePrefix + path
|
||||||
view := c.router.MatchingStorageView(fullPath)
|
view := c.router.MatchingStorageView(fullPath)
|
||||||
if view == nil {
|
if view == nil {
|
||||||
return false, fmt.Errorf("no matching backend")
|
return false, fmt.Errorf("no matching backend %s", fullPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark the entry as tainted
|
// Mark the entry as tainted
|
||||||
@@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error {
|
|||||||
// loadCredentials is invoked as part of postUnseal to load the auth table
|
// loadCredentials is invoked as part of postUnseal to load the auth table
|
||||||
func (c *Core) loadCredentials() error {
|
func (c *Core) loadCredentials() error {
|
||||||
authTable := &MountTable{}
|
authTable := &MountTable{}
|
||||||
|
localAuthTable := &MountTable{}
|
||||||
|
|
||||||
// Load the existing mount table
|
// Load the existing mount table
|
||||||
raw, err := c.barrier.Get(coreAuthConfigPath)
|
raw, err := c.barrier.Get(coreAuthConfigPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: failed to read auth table", "error", err)
|
c.logger.Error("core: failed to read auth table", "error", err)
|
||||||
return errLoadAuthFailed
|
return errLoadAuthFailed
|
||||||
}
|
}
|
||||||
|
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("core: failed to read local auth table", "error", err)
|
||||||
|
return errLoadAuthFailed
|
||||||
|
}
|
||||||
|
|
||||||
c.authLock.Lock()
|
c.authLock.Lock()
|
||||||
defer c.authLock.Unlock()
|
defer c.authLock.Unlock()
|
||||||
@@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error {
|
|||||||
}
|
}
|
||||||
c.auth = authTable
|
c.auth = authTable
|
||||||
}
|
}
|
||||||
|
if rawLocal != nil {
|
||||||
|
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil {
|
||||||
|
c.logger.Error("core: failed to decode local auth table", "error", err)
|
||||||
|
return errLoadAuthFailed
|
||||||
|
}
|
||||||
|
c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...)
|
||||||
|
}
|
||||||
|
|
||||||
// Done if we have restored the auth table
|
// Done if we have restored the auth table
|
||||||
if c.auth != nil {
|
if c.auth != nil {
|
||||||
@@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nonLocalAuth := &MountTable{
|
||||||
|
Type: credentialTableType,
|
||||||
|
}
|
||||||
|
|
||||||
|
localAuth := &MountTable{
|
||||||
|
Type: credentialTableType,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range table.Entries {
|
||||||
|
if entry.Local {
|
||||||
|
localAuth.Entries = append(localAuth.Entries, entry)
|
||||||
|
} else {
|
||||||
|
nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Marshal the table
|
// Marshal the table
|
||||||
raw, err := json.Marshal(table)
|
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: failed to encode auth table", "error", err)
|
c.logger.Error("core: failed to encode and/or compress auth table", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an entry
|
// Create an entry
|
||||||
entry := &Entry{
|
entry := &Entry{
|
||||||
Key: coreAuthConfigPath,
|
Key: coreAuthConfigPath,
|
||||||
Value: raw,
|
Value: compressedBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write to the physical backend
|
// Write to the physical backend
|
||||||
@@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error {
|
|||||||
c.logger.Error("core: failed to persist auth table", "error", err)
|
c.logger.Error("core: failed to persist auth table", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Repeat with local auth
|
||||||
|
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("core: failed to encode and/or compress local auth table", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = &Entry{
|
||||||
|
Key: coreLocalAuthConfigPath,
|
||||||
|
Value: compressedBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.barrier.Put(entry); err != nil {
|
||||||
|
c.logger.Error("core: failed to persist local auth table", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a barrier view using the UUID
|
// Create a barrier view using the UUID
|
||||||
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
viewPath := credentialBarrierPrefix + entry.UUID + "/"
|
||||||
|
view = NewBarrierView(c.barrier, viewPath)
|
||||||
|
sysView := c.mountEntrySysView(entry)
|
||||||
|
|
||||||
// Initialize the backend
|
// Initialize the backend
|
||||||
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
|
backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
||||||
return errLoadAuthFailed
|
return errLoadAuthFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := backend.Initialize(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Mount the backend
|
// Mount the backend
|
||||||
path := credentialRoutePrefix + entry.Path
|
path := credentialRoutePrefix + entry.Path
|
||||||
err = c.router.Mount(backend, path, entry, view)
|
err = c.router.Mount(backend, path, entry, view)
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package vault
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the local table actually gets populated as expected with local
|
||||||
|
// entries, and that upon reading the entries from both are recombined
|
||||||
|
// correctly
|
||||||
|
func TestCore_EnableCredential_Local(t *testing.T) {
|
||||||
|
c, _, _ := TestCoreUnsealed(t)
|
||||||
|
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||||
|
return &NoopBackend{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.auth = &MountTable{
|
||||||
|
Type: credentialTableType,
|
||||||
|
Entries: []*MountEntry{
|
||||||
|
&MountEntry{
|
||||||
|
Table: credentialTableType,
|
||||||
|
Path: "noop/",
|
||||||
|
Type: "noop",
|
||||||
|
UUID: "abcd",
|
||||||
|
},
|
||||||
|
&MountEntry{
|
||||||
|
Table: credentialTableType,
|
||||||
|
Path: "noop2/",
|
||||||
|
Type: "noop",
|
||||||
|
UUID: "bcde",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both should set up successfully
|
||||||
|
err := c.setupCredentials()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if rawLocal == nil {
|
||||||
|
t.Fatal("expected non-nil local credential")
|
||||||
|
}
|
||||||
|
localCredentialTable := &MountTable{}
|
||||||
|
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(localCredentialTable.Entries) > 0 {
|
||||||
|
t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.auth.Entries[1].Local = true
|
||||||
|
if err := c.persistAuth(c.auth); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if rawLocal == nil {
|
||||||
|
t.Fatal("expected non-nil local credential")
|
||||||
|
}
|
||||||
|
localCredentialTable = &MountTable{}
|
||||||
|
if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(localCredentialTable.Entries) != 1 {
|
||||||
|
t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldCredential := c.auth
|
||||||
|
if err := c.loadCredentials(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(oldCredential, c.auth) {
|
||||||
|
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.auth.Entries) != 2 {
|
||||||
|
t.Fatalf("expected two credential entries, got %#v", localCredentialTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCore_EnableCredential_twice_409(t *testing.T) {
|
func TestCore_EnableCredential_twice_409(t *testing.T) {
|
||||||
c, _, _ := TestCoreUnsealed(t)
|
c, _, _ := TestCoreUnsealed(t)
|
||||||
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||||
@@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
existed, err := c.disableCredential("foo")
|
existed, err := c.disableCredential("foo")
|
||||||
if existed || err.Error() != "no matching backend" {
|
if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) {
|
||||||
t.Fatalf("existed: %v; err: %v", existed, err)
|
t.Fatalf("existed: %v; err: %v", existed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -86,6 +86,11 @@ type SecurityBarrier interface {
|
|||||||
// VerifyMaster is used to check if the given key matches the master key
|
// VerifyMaster is used to check if the given key matches the master key
|
||||||
VerifyMaster(key []byte) error
|
VerifyMaster(key []byte) error
|
||||||
|
|
||||||
|
// SetMasterKey is used to directly set a new master key. This is used in
|
||||||
|
// repliated scenarios due to the chicken and egg problem of reloading the
|
||||||
|
// keyring from disk before we have the master key to decrypt it.
|
||||||
|
SetMasterKey(key []byte) error
|
||||||
|
|
||||||
// ReloadKeyring is used to re-read the underlying keyring.
|
// ReloadKeyring is used to re-read the underlying keyring.
|
||||||
// This is used for HA deployments to ensure the latest keyring
|
// This is used for HA deployments to ensure the latest keyring
|
||||||
// is present in the leader.
|
// is present in the leader.
|
||||||
@@ -119,8 +124,14 @@ type SecurityBarrier interface {
|
|||||||
// Rekey is used to change the master key used to protect the keyring
|
// Rekey is used to change the master key used to protect the keyring
|
||||||
Rekey([]byte) error
|
Rekey([]byte) error
|
||||||
|
|
||||||
|
// For replication we must send over the keyring, so this must be available
|
||||||
|
Keyring() (*Keyring, error)
|
||||||
|
|
||||||
// SecurityBarrier must provide the storage APIs
|
// SecurityBarrier must provide the storage APIs
|
||||||
BarrierStorage
|
BarrierStorage
|
||||||
|
|
||||||
|
// SecurityBarrier must provide the encryption APIs
|
||||||
|
BarrierEncryptor
|
||||||
}
|
}
|
||||||
|
|
||||||
// BarrierStorage is the storage only interface required for a Barrier.
|
// BarrierStorage is the storage only interface required for a Barrier.
|
||||||
@@ -139,6 +150,14 @@ type BarrierStorage interface {
|
|||||||
List(prefix string) ([]string, error)
|
List(prefix string) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BarrierEncryptor is the in memory only interface that does not actually
|
||||||
|
// use the underlying barrier. It is used for lower level modules like the
|
||||||
|
// Write-Ahead-Log and Merkle index to allow them to use the barrier.
|
||||||
|
type BarrierEncryptor interface {
|
||||||
|
Encrypt(key string, plaintext []byte) ([]byte, error)
|
||||||
|
Decrypt(key string, ciphertext []byte) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Entry is used to represent data stored by the security barrier
|
// Entry is used to represent data stored by the security barrier
|
||||||
type Entry struct {
|
type Entry struct {
|
||||||
Key string
|
Key string
|
||||||
|
|||||||
@@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) {
|
|||||||
func (b *AESGCMBarrier) Rekey(key []byte) error {
|
func (b *AESGCMBarrier) Rekey(key []byte) error {
|
||||||
b.l.Lock()
|
b.l.Lock()
|
||||||
defer b.l.Unlock()
|
defer b.l.Unlock()
|
||||||
if b.sealed {
|
|
||||||
return ErrBarrierSealed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the key size
|
newKeyring, err := b.updateMasterKeyCommon(key)
|
||||||
min, max := b.KeyLength()
|
if err != nil {
|
||||||
if len(key) < min || len(key) > max {
|
return err
|
||||||
return fmt.Errorf("Key size must be %d or %d", min, max)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a new encryption key
|
|
||||||
newKeyring := b.keyring.SetMasterKey(key)
|
|
||||||
|
|
||||||
// Persist the new keyring
|
// Persist the new keyring
|
||||||
if err := b.persistKeyring(newKeyring); err != nil {
|
if err := b.persistKeyring(newKeyring); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMasterKey updates the keyring's in-memory master key but does not persist
|
||||||
|
// anything to storage
|
||||||
|
func (b *AESGCMBarrier) SetMasterKey(key []byte) error {
|
||||||
|
b.l.Lock()
|
||||||
|
defer b.l.Unlock()
|
||||||
|
|
||||||
|
newKeyring, err := b.updateMasterKeyCommon(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap the keyrings
|
||||||
|
oldKeyring := b.keyring
|
||||||
|
b.keyring = newKeyring
|
||||||
|
oldKeyring.Zeroize(false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs common tasks related to updating the master key; note that the lock
|
||||||
|
// must be held before calling this function
|
||||||
|
func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
|
||||||
|
if b.sealed {
|
||||||
|
return nil, ErrBarrierSealed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the key size
|
||||||
|
min, max := b.KeyLength()
|
||||||
|
if len(key) < min || len(key) > max {
|
||||||
|
return nil, fmt.Errorf("Key size must be %d or %d", min, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.keyring.SetMasterKey(key), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Put is used to insert or update an entry
|
// Put is used to insert or update an entry
|
||||||
func (b *AESGCMBarrier) Put(entry *Entry) error {
|
func (b *AESGCMBarrier) Put(entry *Entry) error {
|
||||||
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
|
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
|
||||||
@@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro
|
|||||||
return nil, fmt.Errorf("version bytes mis-match")
|
return nil, fmt.Errorf("version bytes mis-match")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface
|
||||||
|
func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) {
|
||||||
|
b.l.RLock()
|
||||||
|
defer b.l.RUnlock()
|
||||||
|
if b.sealed {
|
||||||
|
return nil, ErrBarrierSealed
|
||||||
|
}
|
||||||
|
|
||||||
|
term := b.keyring.ActiveTerm()
|
||||||
|
primary, err := b.aeadForTerm(term)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext := b.encrypt(key, term, primary, plaintext)
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface
|
||||||
|
func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) {
|
||||||
|
b.l.RLock()
|
||||||
|
defer b.l.RUnlock()
|
||||||
|
if b.sealed {
|
||||||
|
return nil, ErrBarrierSealed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt the ciphertext
|
||||||
|
plain, err := b.decryptKeyring(key, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decryption failed: %v", err)
|
||||||
|
}
|
||||||
|
return plain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *AESGCMBarrier) Keyring() (*Keyring, error) {
|
||||||
|
b.l.RLock()
|
||||||
|
defer b.l.RUnlock()
|
||||||
|
if b.sealed {
|
||||||
|
return nil, ErrBarrierSealed
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.keyring.Clone(), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// mockBarrier returns a physical backend, security barrier, and master key
|
// mockBarrier returns a physical backend, security barrier, and master key
|
||||||
func mockBarrier(t *testing.T) (physical.Backend, SecurityBarrier, []byte) {
|
func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) {
|
||||||
|
|
||||||
inm := physical.NewInmem(logger)
|
inm := physical.NewInmem(logger)
|
||||||
b, err := NewAESGCMBarrier(inm)
|
b, err := NewAESGCMBarrier(inm)
|
||||||
@@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) {
|
|||||||
t.Fatalf("key length protection failed")
|
t.Fatalf("key length protection failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncrypt_BarrierEncryptor(t *testing.T) {
|
||||||
|
inm := physical.NewInmem(logger)
|
||||||
|
b, err := NewAESGCMBarrier(inm)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize and unseal
|
||||||
|
key, _ := b.GenerateKey()
|
||||||
|
b.Initialize(key)
|
||||||
|
b.Unseal(key)
|
||||||
|
|
||||||
|
cipher, err := b.Encrypt("foo", []byte("quick brown fox"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plain, err := b.Decrypt("foo", cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(plain) != "quick brown fox" {
|
||||||
|
t.Fatalf("bad: %s", plain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
|
|||||||
|
|
||||||
// logical.Storage impl.
|
// logical.Storage impl.
|
||||||
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
||||||
if v.readonly {
|
|
||||||
return logical.ErrReadOnly
|
|
||||||
}
|
|
||||||
if err := v.sanityCheck(entry.Key); err != nil {
|
if err := v.sanityCheck(entry.Key); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expandedKey := v.expandKey(entry.Key)
|
||||||
|
|
||||||
|
if v.readonly {
|
||||||
|
return logical.ErrReadOnly
|
||||||
|
}
|
||||||
|
|
||||||
nested := &Entry{
|
nested := &Entry{
|
||||||
Key: v.expandKey(entry.Key),
|
Key: expandedKey,
|
||||||
Value: entry.Value,
|
Value: entry.Value,
|
||||||
}
|
}
|
||||||
return v.barrier.Put(nested)
|
return v.barrier.Put(nested)
|
||||||
@@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
|
|||||||
|
|
||||||
// logical.Storage impl.
|
// logical.Storage impl.
|
||||||
func (v *BarrierView) Delete(key string) error {
|
func (v *BarrierView) Delete(key string) error {
|
||||||
if v.readonly {
|
|
||||||
return logical.ErrReadOnly
|
|
||||||
}
|
|
||||||
if err := v.sanityCheck(key); err != nil {
|
if err := v.sanityCheck(key); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return v.barrier.Delete(v.expandKey(key))
|
|
||||||
|
expandedKey := v.expandKey(key)
|
||||||
|
|
||||||
|
if v.readonly {
|
||||||
|
return logical.ErrReadOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return v.barrier.Delete(expandedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubView constructs a nested sub-view using the given prefix
|
// SubView constructs a nested sub-view using the given prefix
|
||||||
|
|||||||
@@ -1,27 +1,19 @@
|
|||||||
package vault
|
package vault
|
||||||
|
|
||||||
import "sort"
|
import (
|
||||||
|
"sort"
|
||||||
|
|
||||||
// Struct to identify user input errors.
|
"github.com/hashicorp/vault/logical"
|
||||||
// This is helpful in responding the appropriate status codes to clients
|
)
|
||||||
// from the HTTP endpoints.
|
|
||||||
type StatusBadRequest struct {
|
|
||||||
Err string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implementing error interface
|
|
||||||
func (s *StatusBadRequest) Error() string {
|
|
||||||
return s.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Capabilities is used to fetch the capabilities of the given token on the given path
|
// Capabilities is used to fetch the capabilities of the given token on the given path
|
||||||
func (c *Core) Capabilities(token, path string) ([]string, error) {
|
func (c *Core) Capabilities(token, path string) ([]string, error) {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
return nil, &StatusBadRequest{Err: "missing path"}
|
return nil, &logical.StatusBadRequest{Err: "missing path"}
|
||||||
}
|
}
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return nil, &StatusBadRequest{Err: "missing token"}
|
return nil, &logical.StatusBadRequest{Err: "missing token"}
|
||||||
}
|
}
|
||||||
|
|
||||||
te, err := c.tokenStore.Lookup(token)
|
te, err := c.tokenStore.Lookup(token)
|
||||||
@@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if te == nil {
|
if te == nil {
|
||||||
return nil, &StatusBadRequest{Err: "invalid token"}
|
return nil, &logical.StatusBadRequest{Err: "invalid token"}
|
||||||
}
|
}
|
||||||
|
|
||||||
if te.Policies == nil {
|
if te.Policies == nil {
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ var (
|
|||||||
|
|
||||||
// This can be one of a few key types so the different params may or may not be filled
|
// This can be one of a few key types so the different params may or may not be filled
|
||||||
type clusterKeyParams struct {
|
type clusterKeyParams struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type" structs:"type" mapstructure:"type"`
|
||||||
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
|
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
|
||||||
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
|
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
|
||||||
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
|
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
|
||||||
@@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
|
|||||||
c.logger.Info("core/stopClusterListener: success")
|
c.logger.Info("core/stopClusterListener: success")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClusterTLSConfig generates a TLS configuration based on the local cluster
|
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
|
||||||
// key and cert.
|
// cluster key and cert.
|
||||||
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
||||||
cluster, err := c.Cluster()
|
cluster, err := c.Cluster()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if cluster == nil {
|
if cluster == nil {
|
||||||
return nil, fmt.Errorf("cluster information is nil")
|
return nil, fmt.Errorf("local cluster information is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent data races with the TLS parameters
|
// Prevent data races with the TLS parameters
|
||||||
c.clusterParamsLock.Lock()
|
c.clusterParamsLock.Lock()
|
||||||
defer c.clusterParamsLock.Unlock()
|
defer c.clusterParamsLock.Unlock()
|
||||||
|
|
||||||
if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
|
forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0
|
||||||
return nil, fmt.Errorf("cluster certificate is nil")
|
|
||||||
|
var parsedCert *x509.Certificate
|
||||||
|
if forwarding {
|
||||||
|
parsedCert, err = x509.ParseCertificate(c.localClusterCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is idempotent, so be sure it's been added
|
||||||
|
c.clusterCertPool.AddCert(parsedCert)
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedCert, err := x509.ParseCertificate(c.localClusterCert)
|
nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
if err != nil {
|
c.clusterParamsLock.RLock()
|
||||||
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
|
defer c.clusterParamsLock.RUnlock()
|
||||||
}
|
|
||||||
|
|
||||||
// This is idempotent, so be sure it's been added
|
if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
|
||||||
c.clusterCertPool.AddCert(parsedCert)
|
return &tls.Certificate{
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{
|
|
||||||
tls.Certificate{
|
|
||||||
Certificate: [][]byte{c.localClusterCert},
|
Certificate: [][]byte{c.localClusterCert},
|
||||||
PrivateKey: c.localClusterPrivateKey,
|
PrivateKey: c.localClusterPrivateKey,
|
||||||
},
|
}, nil
|
||||||
},
|
}
|
||||||
RootCAs: c.clusterCertPool,
|
|
||||||
ServerName: parsedCert.Subject.CommonName,
|
return nil, nil
|
||||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
}
|
||||||
ClientCAs: c.clusterCertPool,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
var clientCertificates []tls.Certificate
|
||||||
|
if forwarding {
|
||||||
|
clientCertificates = append(clientCertificates, tls.Certificate{
|
||||||
|
Certificate: [][]byte{c.localClusterCert},
|
||||||
|
PrivateKey: c.localClusterPrivateKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
// We need this here for the client side
|
||||||
|
Certificates: clientCertificates,
|
||||||
|
RootCAs: c.clusterCertPool,
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
ClientCAs: c.clusterCertPool,
|
||||||
|
GetCertificate: nameLookup,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
if forwarding {
|
||||||
|
tlsConfig.ServerName = parsedCert.Subject.CommonName
|
||||||
}
|
}
|
||||||
|
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/logformat"
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/physical"
|
"github.com/hashicorp/vault/physical"
|
||||||
@@ -100,7 +101,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
|
|||||||
checkListenersFunc := func(expectFail bool) {
|
checkListenersFunc := func(expectFail bool) {
|
||||||
tlsConfig, err := cores[0].ClusterTLSConfig()
|
tlsConfig, err := cores[0].ClusterTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() != ErrSealed.Error() {
|
if err.Error() != consts.ErrSealed.Error() {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tlsConfig = lastTLSConfig
|
tlsConfig = lastTLSConfig
|
||||||
|
|||||||
168
vault/core.go
168
vault/core.go
@@ -1,9 +1,9 @@
|
|||||||
package vault
|
package vault
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
|
"crypto/subtle"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/audit"
|
"github.com/hashicorp/vault/audit"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/errutil"
|
"github.com/hashicorp/vault/helper/errutil"
|
||||||
"github.com/hashicorp/vault/helper/jsonutil"
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/helper/logformat"
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
@@ -56,17 +57,14 @@ const (
|
|||||||
// leaderPrefixCleanDelay is how long to wait between deletions
|
// leaderPrefixCleanDelay is how long to wait between deletions
|
||||||
// of orphaned leader keys, to prevent slamming the backend.
|
// of orphaned leader keys, to prevent slamming the backend.
|
||||||
leaderPrefixCleanDelay = 200 * time.Millisecond
|
leaderPrefixCleanDelay = 200 * time.Millisecond
|
||||||
|
|
||||||
|
// coreKeyringCanaryPath is used as a canary to indicate to replicated
|
||||||
|
// clusters that they need to perform a rekey operation synchronously; this
|
||||||
|
// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
|
||||||
|
coreKeyringCanaryPath = "core/canary-keyring"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrSealed is returned if an operation is performed on
|
|
||||||
// a sealed barrier. No operation is expected to succeed before unsealing
|
|
||||||
ErrSealed = errors.New("Vault is sealed")
|
|
||||||
|
|
||||||
// ErrStandby is returned if an operation is performed on
|
|
||||||
// a standby Vault. No operation is expected to succeed until active.
|
|
||||||
ErrStandby = errors.New("Vault is in standby mode")
|
|
||||||
|
|
||||||
// ErrAlreadyInit is returned if the core is already
|
// ErrAlreadyInit is returned if the core is already
|
||||||
// initialized. This prevents a re-initialization.
|
// initialized. This prevents a re-initialization.
|
||||||
ErrAlreadyInit = errors.New("Vault is already initialized")
|
ErrAlreadyInit = errors.New("Vault is already initialized")
|
||||||
@@ -87,6 +85,12 @@ var (
|
|||||||
// step down of the active node, to prevent instantly regrabbing the lock.
|
// step down of the active node, to prevent instantly regrabbing the lock.
|
||||||
// It's var not const so that tests can manipulate it.
|
// It's var not const so that tests can manipulate it.
|
||||||
manualStepDownSleepPeriod = 10 * time.Second
|
manualStepDownSleepPeriod = 10 * time.Second
|
||||||
|
|
||||||
|
// Functions only in the Enterprise version
|
||||||
|
enterprisePostUnseal = enterprisePostUnsealImpl
|
||||||
|
enterprisePreSeal = enterprisePreSealImpl
|
||||||
|
startReplication = startReplicationImpl
|
||||||
|
stopReplication = stopReplicationImpl
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReloadFunc are functions that are called when a reload is requested.
|
// ReloadFunc are functions that are called when a reload is requested.
|
||||||
@@ -133,6 +137,11 @@ type unlockInformation struct {
|
|||||||
// interface for API handlers and is responsible for managing the logical and physical
|
// interface for API handlers and is responsible for managing the logical and physical
|
||||||
// backends, router, security barrier, and audit trails.
|
// backends, router, security barrier, and audit trails.
|
||||||
type Core struct {
|
type Core struct {
|
||||||
|
// N.B.: This is used to populate a dev token down replication, as
|
||||||
|
// otherwise, after replication is started, a dev would have to go through
|
||||||
|
// the generate-root process simply to talk to the new follower cluster.
|
||||||
|
devToken string
|
||||||
|
|
||||||
// HABackend may be available depending on the physical backend
|
// HABackend may be available depending on the physical backend
|
||||||
ha physical.HABackend
|
ha physical.HABackend
|
||||||
|
|
||||||
@@ -268,7 +277,7 @@ type Core struct {
|
|||||||
//
|
//
|
||||||
// Name
|
// Name
|
||||||
clusterName string
|
clusterName string
|
||||||
// Used to modify cluster TLS params
|
// Used to modify cluster parameters
|
||||||
clusterParamsLock sync.RWMutex
|
clusterParamsLock sync.RWMutex
|
||||||
// The private key stored in the barrier used for establishing
|
// The private key stored in the barrier used for establishing
|
||||||
// mutually-authenticated connections between Vault cluster members
|
// mutually-authenticated connections between Vault cluster members
|
||||||
@@ -310,11 +319,13 @@ type Core struct {
|
|||||||
|
|
||||||
// replicationState keeps the current replication state cached for quick
|
// replicationState keeps the current replication state cached for quick
|
||||||
// lookup
|
// lookup
|
||||||
replicationState logical.ReplicationState
|
replicationState consts.ReplicationState
|
||||||
}
|
}
|
||||||
|
|
||||||
// CoreConfig is used to parameterize a core
|
// CoreConfig is used to parameterize a core
|
||||||
type CoreConfig struct {
|
type CoreConfig struct {
|
||||||
|
DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"`
|
||||||
|
|
||||||
LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`
|
LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"`
|
||||||
|
|
||||||
CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
|
CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"`
|
||||||
@@ -390,6 +401,30 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||||||
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
|
conf.Logger = logformat.NewVaultLogger(log.LevelTrace)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup the core
|
||||||
|
c := &Core{
|
||||||
|
redirectAddr: conf.RedirectAddr,
|
||||||
|
clusterAddr: conf.ClusterAddr,
|
||||||
|
physical: conf.Physical,
|
||||||
|
seal: conf.Seal,
|
||||||
|
router: NewRouter(),
|
||||||
|
sealed: true,
|
||||||
|
standby: true,
|
||||||
|
logger: conf.Logger,
|
||||||
|
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||||
|
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||||
|
cachingDisabled: conf.DisableCache,
|
||||||
|
clusterName: conf.ClusterName,
|
||||||
|
clusterCertPool: x509.NewCertPool(),
|
||||||
|
clusterListenerShutdownCh: make(chan struct{}),
|
||||||
|
clusterListenerShutdownSuccessCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the physical backend in a cache layer if enabled and not already wrapped
|
||||||
|
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
|
||||||
|
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
|
||||||
|
}
|
||||||
|
|
||||||
if !conf.DisableMlock {
|
if !conf.DisableMlock {
|
||||||
// Ensure our memory usage is locked into physical RAM
|
// Ensure our memory usage is locked into physical RAM
|
||||||
if err := mlock.LockMemory(); err != nil {
|
if err := mlock.LockMemory(); err != nil {
|
||||||
@@ -407,36 +442,12 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Construct a new AES-GCM barrier
|
// Construct a new AES-GCM barrier
|
||||||
barrier, err := NewAESGCMBarrier(conf.Physical)
|
var err error
|
||||||
|
c.barrier, err = NewAESGCMBarrier(c.physical)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("barrier setup failed: %v", err)
|
return nil, fmt.Errorf("barrier setup failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup the core
|
|
||||||
c := &Core{
|
|
||||||
redirectAddr: conf.RedirectAddr,
|
|
||||||
clusterAddr: conf.ClusterAddr,
|
|
||||||
physical: conf.Physical,
|
|
||||||
seal: conf.Seal,
|
|
||||||
barrier: barrier,
|
|
||||||
router: NewRouter(),
|
|
||||||
sealed: true,
|
|
||||||
standby: true,
|
|
||||||
logger: conf.Logger,
|
|
||||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
|
||||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
|
||||||
cachingDisabled: conf.DisableCache,
|
|
||||||
clusterName: conf.ClusterName,
|
|
||||||
clusterCertPool: x509.NewCertPool(),
|
|
||||||
clusterListenerShutdownCh: make(chan struct{}),
|
|
||||||
clusterListenerShutdownSuccessCh: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap the backend in a cache unless disabled
|
|
||||||
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
|
|
||||||
c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger)
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||||
c.ha = conf.HAPhysical
|
c.ha = conf.HAPhysical
|
||||||
}
|
}
|
||||||
@@ -518,10 +529,10 @@ func (c *Core) LookupToken(token string) (*TokenEntry, error) {
|
|||||||
c.stateLock.RLock()
|
c.stateLock.RLock()
|
||||||
defer c.stateLock.RUnlock()
|
defer c.stateLock.RUnlock()
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return nil, ErrSealed
|
return nil, consts.ErrSealed
|
||||||
}
|
}
|
||||||
if c.standby {
|
if c.standby {
|
||||||
return nil, ErrStandby
|
return nil, consts.ErrStandby
|
||||||
}
|
}
|
||||||
|
|
||||||
// Many tests don't have a token store running
|
// Many tests don't have a token store running
|
||||||
@@ -656,7 +667,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
|
|||||||
|
|
||||||
// Check if sealed
|
// Check if sealed
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return false, "", ErrSealed
|
return false, "", consts.ErrSealed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if HA enabled
|
// Check if HA enabled
|
||||||
@@ -803,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
masterKey, err := c.unsealPart(config, key)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if masterKey != nil {
|
||||||
|
return c.unsealInternal(masterKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) {
|
||||||
// Check if we already have this piece
|
// Check if we already have this piece
|
||||||
if c.unlockInfo != nil {
|
if c.unlockInfo != nil {
|
||||||
for _, existing := range c.unlockInfo.Parts {
|
for _, existing := range c.unlockInfo.Parts {
|
||||||
if bytes.Equal(existing, key) {
|
if subtle.ConstantTimeCompare(existing, key) == 1 {
|
||||||
return false, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
uuid, err := uuid.GenerateUUID()
|
uuid, err := uuid.GenerateUUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.unlockInfo = &unlockInformation{
|
c.unlockInfo = &unlockInformation{
|
||||||
Nonce: uuid,
|
Nonce: uuid,
|
||||||
@@ -828,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) {
|
|||||||
if c.logger.IsDebug() {
|
if c.logger.IsDebug() {
|
||||||
c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
|
c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce)
|
||||||
}
|
}
|
||||||
return false, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Best-effort memzero of unlock parts once we're done with them
|
||||||
|
defer func() {
|
||||||
|
for i, _ := range c.unlockInfo.Parts {
|
||||||
|
memzero(c.unlockInfo.Parts[i])
|
||||||
|
}
|
||||||
|
c.unlockInfo = nil
|
||||||
|
}()
|
||||||
|
|
||||||
// Recover the master key
|
// Recover the master key
|
||||||
var masterKey []byte
|
var masterKey []byte
|
||||||
|
var err error
|
||||||
if config.SecretThreshold == 1 {
|
if config.SecretThreshold == 1 {
|
||||||
masterKey = c.unlockInfo.Parts[0]
|
masterKey = make([]byte, len(c.unlockInfo.Parts[0]))
|
||||||
c.unlockInfo = nil
|
copy(masterKey, c.unlockInfo.Parts[0])
|
||||||
} else {
|
} else {
|
||||||
masterKey, err = shamir.Combine(c.unlockInfo.Parts)
|
masterKey, err = shamir.Combine(c.unlockInfo.Parts)
|
||||||
c.unlockInfo = nil
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to compute master key: %v", err)
|
return nil, fmt.Errorf("failed to compute master key: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
defer memzero(masterKey)
|
|
||||||
|
|
||||||
return c.unsealInternal(masterKey)
|
return masterKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This must be called with the state write lock held
|
||||||
func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
|
func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
|
||||||
|
defer memzero(masterKey)
|
||||||
|
|
||||||
// Attempt to unlock
|
// Attempt to unlock
|
||||||
if err := c.barrier.Unseal(masterKey); err != nil {
|
if err := c.barrier.Unseal(masterKey); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
@@ -867,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) {
|
|||||||
c.logger.Warn("core: vault is sealed")
|
c.logger.Warn("core: vault is sealed")
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.postUnseal(); err != nil {
|
if err := c.postUnseal(); err != nil {
|
||||||
c.logger.Error("core: post-unseal setup failed", "error", err)
|
c.logger.Error("core: post-unseal setup failed", "error", err)
|
||||||
c.barrier.Seal()
|
c.barrier.Seal()
|
||||||
c.logger.Warn("core: vault is sealed")
|
c.logger.Warn("core: vault is sealed")
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.standby = false
|
c.standby = false
|
||||||
} else {
|
} else {
|
||||||
// Go to standby mode, wait until we are active to unseal
|
// Go to standby mode, wait until we are active to unseal
|
||||||
@@ -1168,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) {
|
|||||||
if purgable, ok := c.physical.(physical.Purgable); ok {
|
if purgable, ok := c.physical.(physical.Purgable); ok {
|
||||||
purgable.Purge()
|
purgable.Purge()
|
||||||
}
|
}
|
||||||
|
|
||||||
// HA mode requires us to handle keyring rotation and rekeying
|
// HA mode requires us to handle keyring rotation and rekeying
|
||||||
if c.ha != nil {
|
if c.ha != nil {
|
||||||
// We want to reload these from disk so that in case of a rekey we're
|
// We want to reload these from disk so that in case of a rekey we're
|
||||||
@@ -1190,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := enterprisePostUnseal(c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := c.ensureWrappingKey(); err != nil {
|
if err := c.ensureWrappingKey(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1251,6 +1290,7 @@ func (c *Core) preSeal() error {
|
|||||||
c.metricsCh = nil
|
c.metricsCh = nil
|
||||||
}
|
}
|
||||||
var result error
|
var result error
|
||||||
|
|
||||||
if c.ha != nil {
|
if c.ha != nil {
|
||||||
c.stopClusterListener()
|
c.stopClusterListener()
|
||||||
}
|
}
|
||||||
@@ -1273,6 +1313,10 @@ func (c *Core) preSeal() error {
|
|||||||
if err := c.unloadMounts(); err != nil {
|
if err := c.unloadMounts(); err != nil {
|
||||||
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
|
result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err))
|
||||||
}
|
}
|
||||||
|
if err := enterprisePreSeal(c); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Purge the backend if supported
|
// Purge the backend if supported
|
||||||
if purgable, ok := c.physical.(physical.Purgable); ok {
|
if purgable, ok := c.physical.(physical.Purgable); ok {
|
||||||
purgable.Purge()
|
purgable.Purge()
|
||||||
@@ -1281,6 +1325,22 @@ func (c *Core) preSeal() error {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func enterprisePostUnsealImpl(c *Core) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func enterprisePreSealImpl(c *Core) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startReplicationImpl(c *Core) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopReplicationImpl(c *Core) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// runStandby is a long running routine that is used when an HA backend
|
// runStandby is a long running routine that is used when an HA backend
|
||||||
// is enabled. It waits until we are leader and switches this Vault to
|
// is enabled. It waits until we are leader and switches this Vault to
|
||||||
// active.
|
// active.
|
||||||
@@ -1599,6 +1659,14 @@ func (c *Core) emitMetrics(stopCh chan struct{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Core) ReplicationState() consts.ReplicationState {
|
||||||
|
var state consts.ReplicationState
|
||||||
|
c.clusterParamsLock.RLock()
|
||||||
|
state = c.replicationState
|
||||||
|
c.clusterParamsLock.RUnlock()
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Core) SealAccess() *SealAccess {
|
func (c *Core) SealAccess() *SealAccess {
|
||||||
sa := &SealAccess{}
|
sa := &SealAccess{}
|
||||||
sa.SetSeal(c.seal)
|
sa.SetSeal(c.seal)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/audit"
|
"github.com/hashicorp/vault/audit"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/logformat"
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/physical"
|
"github.com/hashicorp/vault/physical"
|
||||||
@@ -198,7 +199,7 @@ func TestCore_Route_Sealed(t *testing.T) {
|
|||||||
Path: "sys/mounts",
|
Path: "sys/mounts",
|
||||||
}
|
}
|
||||||
_, err := c.HandleRequest(req)
|
_, err := c.HandleRequest(req)
|
||||||
if err != ErrSealed {
|
if err != consts.ErrSealed {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1541,7 +1542,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
|||||||
|
|
||||||
// Request should fail in standby mode
|
// Request should fail in standby mode
|
||||||
_, err = core2.HandleRequest(req)
|
_, err = core2.HandleRequest(req)
|
||||||
if err != ErrStandby {
|
if err != consts.ErrStandby {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package vault
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -79,8 +80,8 @@ func (d dynamicSystemView) CachingDisabled() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Checks if this is a primary Vault instance.
|
// Checks if this is a primary Vault instance.
|
||||||
func (d dynamicSystemView) ReplicationState() logical.ReplicationState {
|
func (d dynamicSystemView) ReplicationState() consts.ReplicationState {
|
||||||
var state logical.ReplicationState
|
var state consts.ReplicationState
|
||||||
d.core.clusterParamsLock.RLock()
|
d.core.clusterParamsLock.RLock()
|
||||||
state = d.core.replicationState
|
state = d.core.replicationState
|
||||||
d.core.clusterParamsLock.RUnlock()
|
d.core.clusterParamsLock.RUnlock()
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/jsonutil"
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
)
|
)
|
||||||
@@ -125,46 +126,114 @@ func (m *ExpirationManager) Restore() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to scan for leases: %v", err)
|
return fmt.Errorf("failed to scan for leases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.logger.Debug("expiration: leases collected", "num_existing", len(existing))
|
m.logger.Debug("expiration: leases collected", "num_existing", len(existing))
|
||||||
|
|
||||||
// Restore each key
|
// Make the channels used for the worker pool
|
||||||
for i, leaseID := range existing {
|
broker := make(chan string)
|
||||||
if i%500 == 0 {
|
quit := make(chan bool)
|
||||||
m.logger.Trace("expiration: leases loading", "progress", i)
|
// Buffer these channels to prevent deadlocks
|
||||||
}
|
errs := make(chan error, len(existing))
|
||||||
// Load the entry
|
result := make(chan *leaseEntry, len(existing))
|
||||||
le, err := m.loadEntry(leaseID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there is no entry, nothing to restore
|
// Use a wait group
|
||||||
if le == nil {
|
wg := &sync.WaitGroup{}
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there is no expiry time, don't do anything
|
// Create 64 workers to distribute work to
|
||||||
if le.ExpireTime.IsZero() {
|
for i := 0; i < consts.ExpirationRestoreWorkerCount; i++ {
|
||||||
continue
|
wg.Add(1)
|
||||||
}
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
// Determine the remaining time to expiration
|
for {
|
||||||
expires := le.ExpireTime.Sub(time.Now())
|
select {
|
||||||
if expires <= 0 {
|
case leaseID, ok := <-broker:
|
||||||
expires = minRevokeDelay
|
// broker has been closed, we are done
|
||||||
}
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Setup revocation timer
|
le, err := m.loadEntry(leaseID)
|
||||||
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
|
if err != nil {
|
||||||
m.expireID(le.LeaseID)
|
errs <- err
|
||||||
})
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write results out to the result channel
|
||||||
|
result <- le
|
||||||
|
|
||||||
|
// quit early
|
||||||
|
case <-quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Distribute the collected keys to the workers in a go routine
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i, leaseID := range existing {
|
||||||
|
if i%500 == 0 {
|
||||||
|
m.logger.Trace("expiration: leases loading", "progress", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-quit:
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
broker <- leaseID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the broker, causing worker routines to exit
|
||||||
|
close(broker)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Restore each key by pulling from the result chan
|
||||||
|
for i := 0; i < len(existing); i++ {
|
||||||
|
select {
|
||||||
|
case err := <-errs:
|
||||||
|
// Close all go routines
|
||||||
|
close(quit)
|
||||||
|
|
||||||
|
return err
|
||||||
|
|
||||||
|
case le := <-result:
|
||||||
|
|
||||||
|
// If there is no entry, nothing to restore
|
||||||
|
if le == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no expiry time, don't do anything
|
||||||
|
if le.ExpireTime.IsZero() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the remaining time to expiration
|
||||||
|
expires := le.ExpireTime.Sub(time.Now())
|
||||||
|
if expires <= 0 {
|
||||||
|
expires = minRevokeDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup revocation timer
|
||||||
|
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
|
||||||
|
m.expireID(le.LeaseID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let all go routines finish
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
if len(m.pending) > 0 {
|
if len(m.pending) > 0 {
|
||||||
if m.logger.IsInfo() {
|
if m.logger.IsInfo() {
|
||||||
m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending))
|
m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,23 +2,131 @@ package vault
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
|
"github.com/hashicorp/vault/helper/logformat"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
|
"github.com/hashicorp/vault/physical"
|
||||||
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testImagePull sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockExpiration returns a mock expiration manager
|
// mockExpiration returns a mock expiration manager
|
||||||
func mockExpiration(t *testing.T) *ExpirationManager {
|
func mockExpiration(t testing.TB) *ExpirationManager {
|
||||||
_, ts, _, _ := TestCoreWithTokenStore(t)
|
_, ts, _, _ := TestCoreWithTokenStore(t)
|
||||||
return ts.expiration
|
return ts.expiration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockBackendExpiration(t testing.TB, backend physical.Backend) (*Core, *ExpirationManager) {
|
||||||
|
c, ts, _, _ := TestCoreWithBackendTokenStore(t, backend)
|
||||||
|
return c, ts.expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExpiration_Restore_Etcd(b *testing.B) {
|
||||||
|
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
|
||||||
|
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
|
||||||
|
|
||||||
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||||
|
physicalBackend, err := physical.NewBackend("etcd", logger, map[string]string{
|
||||||
|
"address": addr,
|
||||||
|
"path": randPath,
|
||||||
|
"max_parallel": "256",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExpiration_Restore_Consul(b *testing.B) {
|
||||||
|
addr := os.Getenv("PHYSICAL_BACKEND_BENCHMARK_ADDR")
|
||||||
|
randPath := fmt.Sprintf("vault-%d/", time.Now().Unix())
|
||||||
|
|
||||||
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||||
|
physicalBackend, err := physical.NewBackend("consul", logger, map[string]string{
|
||||||
|
"address": addr,
|
||||||
|
"path": randPath,
|
||||||
|
"max_parallel": "256",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkExpirationBackend(b, physicalBackend, 10000) // 10,000 leases
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExpiration_Restore_InMem(b *testing.B) {
|
||||||
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||||
|
benchmarkExpirationBackend(b, physical.NewInmem(logger), 100000) // 100,000 Leases
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, numLeases int) {
|
||||||
|
c, exp := mockBackendExpiration(b, physicalBackend)
|
||||||
|
noop := &NoopBackend{}
|
||||||
|
view := NewBarrierView(c.barrier, "logical/")
|
||||||
|
meUUID, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: meUUID}, view)
|
||||||
|
|
||||||
|
// Register fake leases
|
||||||
|
for i := 0; i < numLeases; i++ {
|
||||||
|
pathUUID, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &logical.Request{
|
||||||
|
Operation: logical.ReadOperation,
|
||||||
|
Path: "prod/aws/" + pathUUID,
|
||||||
|
}
|
||||||
|
resp := &logical.Response{
|
||||||
|
Secret: &logical.Secret{
|
||||||
|
LeaseOptions: logical.LeaseOptions{
|
||||||
|
TTL: 400 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"access_key": "xyz",
|
||||||
|
"secret_key": "abcd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err = exp.Register(req, resp)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop everything
|
||||||
|
err = exp.Stop()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
err = exp.Restore()
|
||||||
|
// Restore
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
}
|
||||||
|
|
||||||
func TestExpiration_Restore(t *testing.T) {
|
func TestExpiration_Restore(t *testing.T) {
|
||||||
exp := mockExpiration(t)
|
exp := mockExpiration(t)
|
||||||
noop := &NoopBackend{}
|
noop := &NoopBackend{}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/pgpkeys"
|
"github.com/hashicorp/vault/helper/pgpkeys"
|
||||||
"github.com/hashicorp/vault/helper/xor"
|
"github.com/hashicorp/vault/helper/xor"
|
||||||
"github.com/hashicorp/vault/shamir"
|
"github.com/hashicorp/vault/shamir"
|
||||||
@@ -34,10 +35,10 @@ func (c *Core) GenerateRootProgress() (int, error) {
|
|||||||
c.stateLock.RLock()
|
c.stateLock.RLock()
|
||||||
defer c.stateLock.RUnlock()
|
defer c.stateLock.RUnlock()
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return 0, ErrSealed
|
return 0, consts.ErrSealed
|
||||||
}
|
}
|
||||||
if c.standby {
|
if c.standby {
|
||||||
return 0, ErrStandby
|
return 0, consts.ErrStandby
|
||||||
}
|
}
|
||||||
|
|
||||||
c.generateRootLock.Lock()
|
c.generateRootLock.Lock()
|
||||||
@@ -52,10 +53,10 @@ func (c *Core) GenerateRootConfiguration() (*GenerateRootConfig, error) {
|
|||||||
c.stateLock.RLock()
|
c.stateLock.RLock()
|
||||||
defer c.stateLock.RUnlock()
|
defer c.stateLock.RUnlock()
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return nil, ErrSealed
|
return nil, consts.ErrSealed
|
||||||
}
|
}
|
||||||
if c.standby {
|
if c.standby {
|
||||||
return nil, ErrStandby
|
return nil, consts.ErrStandby
|
||||||
}
|
}
|
||||||
|
|
||||||
c.generateRootLock.Lock()
|
c.generateRootLock.Lock()
|
||||||
@@ -101,10 +102,10 @@ func (c *Core) GenerateRootInit(otp, pgpKey string) error {
|
|||||||
c.stateLock.RLock()
|
c.stateLock.RLock()
|
||||||
defer c.stateLock.RUnlock()
|
defer c.stateLock.RUnlock()
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return ErrSealed
|
return consts.ErrSealed
|
||||||
}
|
}
|
||||||
if c.standby {
|
if c.standby {
|
||||||
return ErrStandby
|
return consts.ErrStandby
|
||||||
}
|
}
|
||||||
|
|
||||||
c.generateRootLock.Lock()
|
c.generateRootLock.Lock()
|
||||||
@@ -170,10 +171,10 @@ func (c *Core) GenerateRootUpdate(key []byte, nonce string) (*GenerateRootResult
|
|||||||
c.stateLock.RLock()
|
c.stateLock.RLock()
|
||||||
defer c.stateLock.RUnlock()
|
defer c.stateLock.RUnlock()
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return nil, ErrSealed
|
return nil, consts.ErrSealed
|
||||||
}
|
}
|
||||||
if c.standby {
|
if c.standby {
|
||||||
return nil, ErrStandby
|
return nil, consts.ErrStandby
|
||||||
}
|
}
|
||||||
|
|
||||||
c.generateRootLock.Lock()
|
c.generateRootLock.Lock()
|
||||||
@@ -308,10 +309,10 @@ func (c *Core) GenerateRootCancel() error {
|
|||||||
c.stateLock.RLock()
|
c.stateLock.RLock()
|
||||||
defer c.stateLock.RUnlock()
|
defer c.stateLock.RUnlock()
|
||||||
if c.sealed {
|
if c.sealed {
|
||||||
return ErrSealed
|
return consts.ErrSealed
|
||||||
}
|
}
|
||||||
if c.standby {
|
if c.standby {
|
||||||
return ErrStandby
|
return consts.ErrStandby
|
||||||
}
|
}
|
||||||
|
|
||||||
c.generateRootLock.Lock()
|
c.generateRootLock.Lock()
|
||||||
|
|||||||
@@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
|
|||||||
return nil, fmt.Errorf("error initializing seal: %v", err)
|
return nil, fmt.Errorf("error initializing seal: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.seal.SetBarrierConfig(barrierConfig)
|
|
||||||
if err != nil {
|
|
||||||
c.logger.Error("core: failed to save barrier configuration", "error", err)
|
|
||||||
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
|
barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: error generating shares", "error", err)
|
c.logger.Error("core: error generating shares", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are storing shares, pop them out of the returned results and push
|
|
||||||
// them through the seal
|
|
||||||
if barrierConfig.StoredShares > 0 {
|
|
||||||
var keysToStore [][]byte
|
|
||||||
for i := 0; i < barrierConfig.StoredShares; i++ {
|
|
||||||
keysToStore = append(keysToStore, barrierUnsealKeys[0])
|
|
||||||
barrierUnsealKeys = barrierUnsealKeys[1:]
|
|
||||||
}
|
|
||||||
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
|
|
||||||
c.logger.Error("core: failed to store keys", "error", err)
|
|
||||||
return nil, fmt.Errorf("failed to store keys: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
results := &InitResult{
|
|
||||||
SecretShares: barrierUnsealKeys,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the barrier
|
// Initialize the barrier
|
||||||
if err := c.barrier.Initialize(barrierKey); err != nil {
|
if err := c.barrier.Initialize(barrierKey); err != nil {
|
||||||
c.logger.Error("core: failed to initialize barrier", "error", err)
|
c.logger.Error("core: failed to initialize barrier", "error", err)
|
||||||
@@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) {
|
|||||||
|
|
||||||
// Ensure the barrier is re-sealed
|
// Ensure the barrier is re-sealed
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// Defers are LIFO so we need to run this here too to ensure the stop
|
||||||
|
// happens before sealing. preSeal also stops, so we just make the
|
||||||
|
// stopping safe against multiple calls.
|
||||||
if err := c.barrier.Seal(); err != nil {
|
if err := c.barrier.Seal(); err != nil {
|
||||||
c.logger.Error("core: failed to seal barrier", "error", err)
|
c.logger.Error("core: failed to seal barrier", "error", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
err = c.seal.SetBarrierConfig(barrierConfig)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("core: failed to save barrier configuration", "error", err)
|
||||||
|
return nil, fmt.Errorf("barrier configuration saving failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are storing shares, pop them out of the returned results and push
|
||||||
|
// them through the seal
|
||||||
|
if barrierConfig.StoredShares > 0 {
|
||||||
|
var keysToStore [][]byte
|
||||||
|
for i := 0; i < barrierConfig.StoredShares; i++ {
|
||||||
|
keysToStore = append(keysToStore, barrierUnsealKeys[0])
|
||||||
|
barrierUnsealKeys = barrierUnsealKeys[1:]
|
||||||
|
}
|
||||||
|
if err := c.seal.SetStoredKeys(keysToStore); err != nil {
|
||||||
|
c.logger.Error("core: failed to store keys", "error", err)
|
||||||
|
return nil, fmt.Errorf("failed to store keys: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results := &InitResult{
|
||||||
|
SecretShares: barrierUnsealKeys,
|
||||||
|
}
|
||||||
|
|
||||||
// Perform initial setup
|
// Perform initial setup
|
||||||
if err := c.setupCluster(); err != nil {
|
if err := c.setupCluster(); err != nil {
|
||||||
c.logger.Error("core: cluster setup failed during init", "error", err)
|
c.logger.Error("core: cluster setup failed during init", "error", err)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user