Add wrapping through core and change to use TTL instead of Duration.

This commit is contained in:
Jeff Mitchell
2016-05-02 00:08:07 -04:00
parent 778d000b5f
commit 21c0e4ee42
15 changed files with 576 additions and 363 deletions

View File

@@ -21,7 +21,7 @@ const (
// WrapHeaderName is the name of the header containing a directive to wrap the // WrapHeaderName is the name of the header containing a directive to wrap the
// response. // response.
WrapDurationHeaderName = "X-Vault-Wrap-Duration" WrapTTLHeaderName = "X-Vault-Wrap-TTL"
) )
// Handler returns an http.Handler for the API. This can be used on // Handler returns an http.Handler for the API. This can be used on
@@ -161,29 +161,29 @@ func requestAuth(r *http.Request, req *logical.Request) *logical.Request {
return req return req
} }
// requestWrapDuration adds the WrapDuration value to the logical.Request if it // requestWrapTTL adds the WrapTTL value to the logical.Request if it
// exists. // exists.
func requestWrapDuration(r *http.Request, req *logical.Request) (*logical.Request, error) { func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, error) {
// First try for the header value // First try for the header value
wrapDuration := r.Header.Get(WrapDurationHeaderName) wrapTTL := r.Header.Get(WrapTTLHeaderName)
if wrapDuration == "" { if wrapTTL == "" {
return req, nil return req, nil
} }
// If it has an allowed suffix parse as a duration string // If it has an allowed suffix parse as a duration string
if strings.HasSuffix(wrapDuration, "s") || strings.HasSuffix(wrapDuration, "m") || strings.HasSuffix(wrapDuration, "h") { if strings.HasSuffix(wrapTTL, "s") || strings.HasSuffix(wrapTTL, "m") || strings.HasSuffix(wrapTTL, "h") {
dur, err := time.ParseDuration(wrapDuration) dur, err := time.ParseDuration(wrapTTL)
if err != nil { if err != nil {
return req, err return req, err
} }
req.WrapDuration = dur req.WrapTTL = dur
} else { } else {
// Parse as a straight number of seconds // Parse as a straight number of seconds
seconds, err := strconv.ParseInt(wrapDuration, 10, 64) seconds, err := strconv.ParseInt(wrapTTL, 10, 64)
if err != nil { if err != nil {
return req, err return req, err
} }
req.WrapDuration = time.Duration(time.Duration(seconds) * time.Second) req.WrapTTL = time.Duration(time.Duration(seconds) * time.Second)
} }
return req, nil return req, nil

View File

@@ -76,9 +76,9 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
Data: data, Data: data,
Connection: getConnection(r), Connection: getConnection(r),
}) })
req, err = requestWrapDuration(r, req) req, err = requestWrapTTL(r, req)
if err != nil { if err != nil {
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-Duration header: {{err}}", err)) respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err))
return return
} }
@@ -130,7 +130,16 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl
return return
} }
httpResp = logical.SanitizeResponse(resp) if resp.WrapInfo.Token != "" {
httpResp = logical.HTTPResponse{
WrapInfo: &logical.HTTPWrapInfo{
Token: resp.WrapInfo.Token,
TTL: int(resp.WrapInfo.TTL.Seconds()),
},
}
} else {
httpResp = logical.SanitizeResponse(resp)
}
} }
// Respond // Respond

View File

@@ -40,8 +40,9 @@ func TestLogical(t *testing.T) {
"data": map[string]interface{}{ "data": map[string]interface{}{
"data": "bar", "data": "bar",
}, },
"auth": nil, "auth": nil,
"warnings": nilWarnings, "wrap_info": nil,
"warnings": nilWarnings,
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual) testResponseBody(t, resp, &actual)
@@ -139,8 +140,9 @@ func TestLogical_StandbyRedirect(t *testing.T) {
"creation_ttl": float64(0), "creation_ttl": float64(0),
"role": "", "role": "",
}, },
"warnings": nilWarnings, "warnings": nilWarnings,
"auth": nil, "wrap_info": nil,
"auth": nil,
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
@@ -177,6 +179,7 @@ func TestLogical_CreateToken(t *testing.T) {
"renewable": false, "renewable": false,
"lease_duration": float64(0), "lease_duration": float64(0),
"data": nil, "data": nil,
"wrap_info": nil,
"auth": map[string]interface{}{ "auth": map[string]interface{}{
"policies": []interface{}{"root"}, "policies": []interface{}{"root"},
"metadata": nil, "metadata": nil,

View File

@@ -17,8 +17,8 @@ func TestSysPolicies(t *testing.T) {
var actual map[string]interface{} var actual map[string]interface{}
expected := map[string]interface{}{ expected := map[string]interface{}{
"policies": []interface{}{"default", "root"}, "policies": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
"keys": []interface{}{"default", "root"}, "keys": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual) testResponseBody(t, resp, &actual)
@@ -62,14 +62,19 @@ func TestSysWritePolicy(t *testing.T) {
var actual map[string]interface{} var actual map[string]interface{}
expected := map[string]interface{}{ expected := map[string]interface{}{
"policies": []interface{}{"default", "foo", "root"}, "policies": []interface{}{"cubbyhole-response-wrapping", "default", "foo", "root"},
"keys": []interface{}{"default", "foo", "root"}, "keys": []interface{}{"cubbyhole-response-wrapping", "default", "foo", "root"},
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual) testResponseBody(t, resp, &actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, expected) t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, expected)
} }
resp = testHttpPost(t, token, addr+"/v1/sys/policy/cubbyhole-response-wrapping", map[string]interface{}{
"rules": ``,
})
testResponseStatus(t, resp, 400)
} }
func TestSysDeletePolicy(t *testing.T) { func TestSysDeletePolicy(t *testing.T) {
@@ -86,12 +91,17 @@ func TestSysDeletePolicy(t *testing.T) {
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/foo") resp = testHttpDelete(t, token, addr+"/v1/sys/policy/foo")
testResponseStatus(t, resp, 204) testResponseStatus(t, resp, 204)
// Also attempt to delete these since they should not be allowed (ignore
// responses, if they exist later that's sufficient)
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/default")
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/cubbyhole-response-wrapping")
resp = testHttpGet(t, token, addr+"/v1/sys/policy") resp = testHttpGet(t, token, addr+"/v1/sys/policy")
var actual map[string]interface{} var actual map[string]interface{}
expected := map[string]interface{}{ expected := map[string]interface{}{
"policies": []interface{}{"default", "root"}, "policies": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
"keys": []interface{}{"default", "root"}, "keys": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
} }
testResponseStatus(t, resp, 200) testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual) testResponseBody(t, resp, &actual)

View File

@@ -54,9 +54,9 @@ type Request struct {
// request path with the MountPoint trimmed off. // request path with the MountPoint trimmed off.
MountPoint string MountPoint string
// WrapDuration contains the requested TTL of the token used to wrap the // WrapTTL contains the requested TTL of the token used to wrap the
// response in a cubbyhole. // response in a cubbyhole.
WrapDuration time.Duration WrapTTL time.Duration
} }
// Get returns a data field and guards for nil Data // Get returns a data field and guards for nil Data

View File

@@ -30,7 +30,7 @@ const (
type WrapInfo struct { type WrapInfo struct {
// Setting to non-zero specifies that the response should be wrapped. // Setting to non-zero specifies that the response should be wrapped.
// Specifies the desired TTL of the wrapping token. // Specifies the desired TTL of the wrapping token.
Duration time.Duration TTL time.Duration
// The token containing the wrapped response // The token containing the wrapped response
Token string Token string
@@ -132,6 +132,11 @@ func (r *Response) ClearWarnings() {
r.warnings = make([]string, 0, 1) r.warnings = make([]string, 0, 1)
} }
// Copies the warnings from the other response to this one
func (r *Response) CloneWarnings(other *Response) {
r.warnings = other.warnings
}
// IsError returns true if this response seems to indicate an error. // IsError returns true if this response seems to indicate an error.
func (r *Response) IsError() bool { func (r *Response) IsError() bool {
return r != nil && len(r.Data) == 1 && r.Data["error"] != nil return r != nil && len(r.Data) == 1 && r.Data["error"] != nil

View File

@@ -5,6 +5,7 @@ func SanitizeResponse(input *Response) *HTTPResponse {
Data: input.Data, Data: input.Data,
Warnings: input.Warnings(), Warnings: input.Warnings(),
} }
if input.Secret != nil { if input.Secret != nil {
logicalResp.LeaseID = input.Secret.LeaseID logicalResp.LeaseID = input.Secret.LeaseID
logicalResp.Renewable = input.Secret.Renewable logicalResp.Renewable = input.Secret.Renewable
@@ -32,6 +33,7 @@ type HTTPResponse struct {
Renewable bool `json:"renewable"` Renewable bool `json:"renewable"`
LeaseDuration int `json:"lease_duration"` LeaseDuration int `json:"lease_duration"`
Data map[string]interface{} `json:"data"` Data map[string]interface{} `json:"data"`
WrapInfo *HTTPWrapInfo `json:"wrap_info"`
Warnings []string `json:"warnings"` Warnings []string `json:"warnings"`
Auth *HTTPAuth `json:"auth"` Auth *HTTPAuth `json:"auth"`
} }
@@ -44,3 +46,8 @@ type HTTPAuth struct {
LeaseDuration int `json:"lease_duration"` LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"` Renewable bool `json:"renewable"`
} }
type HTTPWrapInfo struct {
Token string `json:"token"`
TTL int `json:"ttl"`
}

View File

@@ -7,8 +7,6 @@ import (
"log" "log"
"net/url" "net/url"
"os" "os"
"sort"
"strings"
"sync" "sync"
"time" "time"
@@ -18,7 +16,6 @@ import (
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/mlock" "github.com/hashicorp/vault/helper/mlock"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/physical"
"github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/shamir"
@@ -374,303 +371,6 @@ func (c *Core) Shutdown() error {
return c.sealInternal() return c.sealInternal()
} }
// HandleRequest is used to handle a new incoming request
func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return nil, ErrSealed
}
if c.standby {
return nil, ErrStandby
}
// Allowing writing to a path ending in / makes it extremely difficult to
// understand user intent for the filesystem-like backends (generic,
// cubbyhole) -- did they want a key named foo/ or did they want to write
// to a directory foo/ with no (or forgotten) key, or...? It also affects
// lookup, because paths ending in / are considered prefixes by some
// backends. Basically, it's all just terrible, so don't allow it.
if strings.HasSuffix(req.Path, "/") &&
(req.Operation == logical.UpdateOperation ||
req.Operation == logical.CreateOperation) {
return logical.ErrorResponse("cannot write to a path ending in '/'"), nil
}
var auth *logical.Auth
if c.router.LoginPath(req.Path) {
resp, auth, err = c.handleLoginRequest(req)
} else {
resp, auth, err = c.handleRequest(req)
}
// Ensure we don't leak internal data
if resp != nil {
if resp.Secret != nil {
resp.Secret.InternalData = nil
}
if resp.Auth != nil {
resp.Auth.InternalData = nil
}
}
// Create an audit trail of the response
if err := c.auditBroker.LogResponse(auth, req, resp, err); err != nil {
c.logger.Printf("[ERR] core: failed to audit response (request path: %s): %v",
req.Path, err)
return nil, ErrInternalError
}
return
}
func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, retAuth *logical.Auth, retErr error) {
defer metrics.MeasureSince([]string{"core", "handle_request"}, time.Now())
// Validate the token
auth, te, err := c.checkToken(req)
if te != nil {
defer func() {
// Attempt to use the token (decrement num_uses)
// If a secret was generated and num_uses is currently 1, it will be
// immediately revoked; in that case, don't return the leased
// credentials as they are now invalid.
if retResp != nil &&
te != nil && te.NumUses == 1 &&
retResp.Secret != nil &&
// Some backends return a TTL even without a Lease ID
retResp.Secret.LeaseID != "" {
retResp = logical.ErrorResponse("Secret cannot be returned; token had one use left, so leased credentials were immediately revoked.")
}
if err := c.tokenStore.UseToken(te); err != nil {
c.logger.Printf("[ERR] core: failed to use token: %v", err)
retResp = nil
retAuth = nil
retErr = ErrInternalError
}
}()
}
if err != nil {
// If it is an internal error we return that, otherwise we
// return invalid request so that the status codes can be correct
var errType error
switch err {
case ErrInternalError, logical.ErrPermissionDenied:
errType = err
default:
errType = logical.ErrInvalidRequest
}
if err := c.auditBroker.LogRequest(auth, req, err); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path (%s): %v",
req.Path, err)
}
return logical.ErrorResponse(err.Error()), nil, errType
}
// Attach the display name
req.DisplayName = auth.DisplayName
// Create an audit trail of the request
if err := c.auditBroker.LogRequest(auth, req, nil); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path (%s): %v",
req.Path, err)
return nil, auth, ErrInternalError
}
// Route the request
resp, err := c.router.Route(req)
// If there is a secret, we must register it with the expiration manager.
// We exclude renewal of a lease, since it does not need to be re-registered
if resp != nil && resp.Secret != nil && !strings.HasPrefix(req.Path, "sys/renew/") {
// Get the SystemView for the mount
sysView := c.router.MatchingSystemView(req.Path)
if sysView == nil {
c.logger.Println("[ERR] core: unable to retrieve system view from router")
return nil, auth, ErrInternalError
}
// Apply the default lease if none given
if resp.Secret.TTL == 0 {
resp.Secret.TTL = sysView.DefaultLeaseTTL()
}
// Limit the lease duration
maxTTL := sysView.MaxLeaseTTL()
if resp.Secret.TTL > maxTTL {
resp.Secret.TTL = maxTTL
}
// Generic mounts should return the TTL but not register
// for a lease as this provides a massive slowdown
registerLease := true
matchingBackend := c.router.MatchingBackend(req.Path)
if matchingBackend == nil {
c.logger.Println("[ERR] core: unable to retrieve generic backend from router")
return nil, auth, ErrInternalError
}
if ptbe, ok := matchingBackend.(*PassthroughBackend); ok {
if !ptbe.GeneratesLeases() {
registerLease = false
resp.Secret.Renewable = false
}
}
if registerLease {
leaseID, err := c.expiration.Register(req, resp)
if err != nil {
c.logger.Printf(
"[ERR] core: failed to register lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
}
resp.Secret.LeaseID = leaseID
}
}
// Only the token store is allowed to return an auth block, for any
// other request this is an internal error. We exclude renewal of a token,
// since it does not need to be re-registered
if resp != nil && resp.Auth != nil && !strings.HasPrefix(req.Path, "auth/token/renew") {
if !strings.HasPrefix(req.Path, "auth/token/") {
c.logger.Printf(
"[ERR] core: unexpected Auth response for non-token backend "+
"(request path: %s)", req.Path)
return nil, auth, ErrInternalError
}
// Register with the expiration manager. We use the token's actual path
// here because roles allow suffixes.
te, err := c.tokenStore.Lookup(resp.Auth.ClientToken)
if err != nil {
c.logger.Printf("[ERR] core: failed to lookup token: %v", err)
return nil, nil, ErrInternalError
}
if err := c.expiration.RegisterAuth(te.Path, resp.Auth); err != nil {
c.logger.Printf("[ERR] core: failed to register token lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
}
}
// Return the response and error
return resp, auth, err
}
// handleLoginRequest is used to handle a login request, which is an
// unauthenticated request to the backend.
func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *logical.Auth, error) {
defer metrics.MeasureSince([]string{"core", "handle_login_request"}, time.Now())
// Create an audit trail of the request, auth is not available on login requests
if err := c.auditBroker.LogRequest(nil, req, nil); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path %s: %v",
req.Path, err)
return nil, nil, ErrInternalError
}
// Route the request
resp, err := c.router.Route(req)
// A login request should never return a secret!
if resp != nil && resp.Secret != nil {
c.logger.Printf("[ERR] core: unexpected Secret response for login path"+
"(request path: %s)", req.Path)
return nil, nil, ErrInternalError
}
// If the response generated an authentication, then generate the token
var auth *logical.Auth
if resp != nil && resp.Auth != nil {
auth = resp.Auth
// Determine the source of the login
source := c.router.MatchingMount(req.Path)
source = strings.TrimPrefix(source, credentialRoutePrefix)
source = strings.Replace(source, "/", "-", -1)
// Prepend the source to the display name
auth.DisplayName = strings.TrimSuffix(source+auth.DisplayName, "-")
sysView := c.router.MatchingSystemView(req.Path)
if sysView == nil {
c.logger.Printf("[ERR] core: unable to look up sys view for login path"+
"(request path: %s)", req.Path)
return nil, nil, ErrInternalError
}
// Set the default lease if non-provided, root tokens are exempt
if auth.TTL == 0 && !strutil.StrListContains(auth.Policies, "root") {
auth.TTL = sysView.DefaultLeaseTTL()
}
// Limit the lease duration
if auth.TTL > sysView.MaxLeaseTTL() {
auth.TTL = sysView.MaxLeaseTTL()
}
// Generate a token
te := TokenEntry{
Path: req.Path,
Policies: auth.Policies,
Meta: auth.Metadata,
DisplayName: auth.DisplayName,
CreationTime: time.Now().Unix(),
TTL: auth.TTL,
}
if strutil.StrListSubset(te.Policies, []string{"root"}) {
te.Policies = []string{"root"}
} else {
// Use a map to filter out/prevent duplicates
policyMap := map[string]bool{}
for _, policy := range te.Policies {
if policy == "" {
// Don't allow a policy with no name, even though it is a valid
// slice member
continue
}
policyMap[policy] = true
}
// Add the default policy
policyMap["default"] = true
te.Policies = []string{}
for k, _ := range policyMap {
te.Policies = append(te.Policies, k)
}
sort.Strings(te.Policies)
}
if err := c.tokenStore.create(&te); err != nil {
c.logger.Printf("[ERR] core: failed to create token: %v", err)
return nil, auth, ErrInternalError
}
// Populate the client token and accessor
auth.ClientToken = te.ID
auth.Accessor = te.Accessor
auth.Policies = te.Policies
// Register with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
c.logger.Printf("[ERR] core: failed to register token lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
}
// Attach the display name, might be used by audit backends
req.DisplayName = auth.DisplayName
}
return resp, auth, err
}
func (c *Core) fetchACLandTokenEntry(req *logical.Request) (*ACL, *TokenEntry, error) { func (c *Core) fetchACLandTokenEntry(req *logical.Request) (*ACL, *TokenEntry, error) {
defer metrics.MeasureSince([]string{"core", "fetch_acl_and_token"}, time.Now()) defer metrics.MeasureSince([]string{"core", "fetch_acl_and_token"}, time.Now())

View File

@@ -1086,6 +1086,7 @@ func (b *SystemBackend) handlePolicySet(
func (b *SystemBackend) handlePolicyDelete( func (b *SystemBackend) handlePolicyDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string) name := data.Get("name").(string)
if err := b.Core.policyStore.DeletePolicy(name); err != nil { if err := b.Core.policyStore.DeletePolicy(name); err != nil {
return handleError(err) return handleError(err)
} }

View File

@@ -610,8 +610,8 @@ func TestSystemBackend_policyList(t *testing.T) {
} }
exp := map[string]interface{}{ exp := map[string]interface{}{
"keys": []string{"default", "root"}, "keys": []string{"cubbyhole-response-wrapping", "default", "root"},
"policies": []string{"default", "root"}, "policies": []string{"cubbyhole-response-wrapping", "default", "root"},
} }
if !reflect.DeepEqual(resp.Data, exp) { if !reflect.DeepEqual(resp.Data, exp) {
t.Fatalf("got: %#v expect: %#v", resp.Data, exp) t.Fatalf("got: %#v expect: %#v", resp.Data, exp)
@@ -663,8 +663,8 @@ func TestSystemBackend_policyCRUD(t *testing.T) {
} }
exp = map[string]interface{}{ exp = map[string]interface{}{
"keys": []string{"default", "foo", "root"}, "keys": []string{"cubbyhole-response-wrapping", "default", "foo", "root"},
"policies": []string{"default", "foo", "root"}, "policies": []string{"cubbyhole-response-wrapping", "default", "foo", "root"},
} }
if !reflect.DeepEqual(resp.Data, exp) { if !reflect.DeepEqual(resp.Data, exp) {
t.Fatalf("got: %#v expect: %#v", resp.Data, exp) t.Fatalf("got: %#v expect: %#v", resp.Data, exp)
@@ -698,8 +698,8 @@ func TestSystemBackend_policyCRUD(t *testing.T) {
} }
exp = map[string]interface{}{ exp = map[string]interface{}{
"keys": []string{"default", "root"}, "keys": []string{"cubbyhole-response-wrapping", "default", "root"},
"policies": []string{"default", "root"}, "policies": []string{"cubbyhole-response-wrapping", "default", "root"},
} }
if !reflect.DeepEqual(resp.Data, exp) { if !reflect.DeepEqual(resp.Data, exp) {
t.Fatalf("got: %#v expect: %#v", resp.Data, exp) t.Fatalf("got: %#v expect: %#v", resp.Data, exp)

View File

@@ -17,6 +17,14 @@ const (
// policyCacheSize is the number of policies that are kept cached // policyCacheSize is the number of policies that are kept cached
policyCacheSize = 1024 policyCacheSize = 1024
// cubbyholeResponseWrappingPolicy is the policy that ensures cubbyhole
// response wrapping can always succeed
cubbyholeResponseWrappingPolicy = `
path "cubbyhole/response" {
capabilities = ["create", "read"]
}
`
) )
// PolicyStore is used to provide durable storage of policy, and to // PolicyStore is used to provide durable storage of policy, and to
@@ -63,6 +71,19 @@ func (c *Core) setupPolicyStore() error {
return err return err
} }
} }
// Ensure that the cubbyhole response wrapping policy exists
policy, err = c.policyStore.GetPolicy("cubbyhole-response-wrapping")
if err != nil {
return errwrap.Wrapf("error fetching default policy from store: {{err}}", err)
}
if policy == nil || policy.Raw != cubbyholeResponseWrappingPolicy {
err := c.policyStore.createCubbyholeResponseWrappingPolicy()
if err != nil {
return err
}
}
return nil return nil
} }
@@ -79,10 +100,17 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error {
if p.Name == "root" { if p.Name == "root" {
return fmt.Errorf("cannot update root policy") return fmt.Errorf("cannot update root policy")
} }
if p.Name == "cubbyhole-response-wrapping" {
return fmt.Errorf("cannot update cubbyhole-response-wrapping policy")
}
if p.Name == "" { if p.Name == "" {
return fmt.Errorf("policy name missing") return fmt.Errorf("policy name missing")
} }
return ps.setPolicyInternal(p)
}
func (ps *PolicyStore) setPolicyInternal(p *Policy) error {
// Create the entry // Create the entry
entry, err := logical.StorageEntryJSON(p.Name, &PolicyEntry{ entry, err := logical.StorageEntryJSON(p.Name, &PolicyEntry{
Version: 2, Version: 2,
@@ -174,6 +202,9 @@ func (ps *PolicyStore) DeletePolicy(name string) error {
if name == "default" { if name == "default" {
return fmt.Errorf("cannot delete default policy") return fmt.Errorf("cannot delete default policy")
} }
if name == "cubbyhole-response-wrapping" {
return fmt.Errorf("cannot delete cubbyhole-response-wrapping policy")
}
if err := ps.view.Delete(name); err != nil { if err := ps.view.Delete(name); err != nil {
return fmt.Errorf("failed to delete policy: %v", err) return fmt.Errorf("failed to delete policy: %v", err)
} }
@@ -235,5 +266,19 @@ path "cubbyhole" {
} }
policy.Name = "default" policy.Name = "default"
return ps.SetPolicy(policy) return ps.setPolicyInternal(policy)
}
func (ps *PolicyStore) createCubbyholeResponseWrappingPolicy() error {
policy, err := Parse(cubbyholeResponseWrappingPolicy)
if err != nil {
return errwrap.Wrapf("error parsing cubbyhole-response-wrapping policy: {{err}}", err)
}
if policy == nil {
return fmt.Errorf("parsing cubbyhole-response-wrapping policy resulted in nil policy")
}
policy.Name = "cubbyhole-response-wrapping"
return ps.setPolicyInternal(policy)
} }

View File

@@ -110,6 +110,32 @@ func TestPolicyStore_CRUD(t *testing.T) {
} }
} }
// Test predefined policy handling
func TestPolicyStore_Predefined(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
// Ensure both default policies are created
err := core.setupPolicyStore()
if err != nil {
t.Fatalf("err: %v", err)
}
// List should be two elements
out, err := core.policyStore.ListPolicies()
if err != nil {
t.Fatalf("err: %v", err)
}
if len(out) != 2 || out[0] != "cubbyhole-response-wrapping" || out[1] != "default" {
t.Fatalf("bad: %v", out)
}
p, err := core.policyStore.GetPolicy("cubbyhole-response-wrapping")
if err != nil {
t.Fatalf("err: %v", err)
}
if p.Raw != cubbyholeResponseWrappingPolicy {
t.Fatalf("bad: expected\n%s\ngot\n%s\n", cubbyholeResponseWrappingPolicy, p.Raw)
}
}
func TestPolicyStore_ACL(t *testing.T) { func TestPolicyStore_ACL(t *testing.T) {
ps := mockPolicyStore(t) ps := mockPolicyStore(t)

407
vault/request_handling.go Normal file
View File

@@ -0,0 +1,407 @@
package vault
import (
"sort"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
var (
// Value for memoizing whether cubbyhole is mounted, e.g. if we are in normal operation and not test mode
cubbyholeMounted *bool
// mutex to ensure the same
cubbyholeMountedMutex sync.Mutex
)
// HandleRequest is used to handle a new incoming request
func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) {
c.stateLock.RLock()
defer c.stateLock.RUnlock()
if c.sealed {
return nil, ErrSealed
}
if c.standby {
return nil, ErrStandby
}
// Allowing writing to a path ending in / makes it extremely difficult to
// understand user intent for the filesystem-like backends (generic,
// cubbyhole) -- did they want a key named foo/ or did they want to write
// to a directory foo/ with no (or forgotten) key, or...? It also affects
// lookup, because paths ending in / are considered prefixes by some
// backends. Basically, it's all just terrible, so don't allow it.
if strings.HasSuffix(req.Path, "/") &&
(req.Operation == logical.UpdateOperation ||
req.Operation == logical.CreateOperation) {
return logical.ErrorResponse("cannot write to a path ending in '/'"), nil
}
var auth *logical.Auth
if c.router.LoginPath(req.Path) {
resp, auth, err = c.handleLoginRequest(req)
} else {
resp, auth, err = c.handleRequest(req)
}
// Ensure we don't leak internal data
if resp != nil {
if resp.Secret != nil {
resp.Secret.InternalData = nil
}
if resp.Auth != nil {
resp.Auth.InternalData = nil
}
}
// In order to wrap, we need cubbyhole to be mounted, so we ensure that
// cubbyhole is actually mounted, as it may not be during tests. We memoize
// this response, since cubbyhole cannot be mounted or unmounted during
// normal operation.
if cubbyholeMounted == nil {
cubbyholeMountedMutex.Lock()
cubbyholeMounted = new(bool)
// Ensure it wasn't changed by another goroutine
if cubbyholeMounted == nil {
if c.router.MatchingMount("cubbyhole") != "" {
*cubbyholeMounted = true
} else {
*cubbyholeMounted = false
}
}
cubbyholeMountedMutex.Unlock()
}
// We are wrapping if there is anything to wrap (not a nil response) and a
// TTL was specified for the token, plus if cubbyhole is mounted (which
// will be the case normally)
wrapping := *cubbyholeMounted && resp != nil && resp.WrapInfo.TTL != 0
// If we are wrapping, the first part happens before auditing so that
// resp.WrapInfo.Token can contain the HMAC'd wrapping token ID in the
// audit logs, so that it can be determined from the audit logs whether the
// token was ever actually used.
if wrapping {
// Create the wrapping token
te := TokenEntry{
Path: req.Path,
Policies: []string{"cubbyhole-response-wrapping"},
CreationTime: time.Now().Unix(),
TTL: resp.WrapInfo.TTL,
NumUses: 1,
}
if err := c.tokenStore.create(&te); err != nil {
c.logger.Printf("[ERR] core: failed to create wrapping token: %v", err)
return nil, ErrInternalError
}
resp.WrapInfo.Token = te.ID
httpResponse := logical.SanitizeResponse(resp)
cubbyReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "cubbyhole/response",
ClientToken: te.ID,
Data: map[string]interface{}{
"response": httpResponse,
},
}
_, err = c.router.Route(cubbyReq)
if err != nil {
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", err)
return nil, ErrInternalError
}
auth := &logical.Auth{
ClientToken: te.ID,
Policies: []string{"cubbyhole-response-wrapping"},
LeaseOptions: logical.LeaseOptions{
TTL: te.TTL,
Renewable: false,
},
}
// Register the wrapped token with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
c.logger.Printf("[ERR] core: failed to register cubbyhole wrapping token lease "+
"(request path: %s): %v", req.Path, err)
return nil, ErrInternalError
}
}
// Create an audit trail of the response
if err := c.auditBroker.LogResponse(auth, req, resp, err); err != nil {
c.logger.Printf("[ERR] core: failed to audit response (request path: %s): %v",
req.Path, err)
return nil, ErrInternalError
}
// If we are wrapping, now is when we create a new response object with the
// wrapped information, since the original response has been audit logged
if wrapping {
wrappingResp := &logical.Response{
WrapInfo: logical.WrapInfo{
Token: resp.WrapInfo.Token,
},
}
wrappingResp.CloneWarnings(resp)
resp = wrappingResp
}
return
}
func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, retAuth *logical.Auth, retErr error) {
defer metrics.MeasureSince([]string{"core", "handle_request"}, time.Now())
// Validate the token
auth, te, err := c.checkToken(req)
if te != nil {
defer func() {
// Attempt to use the token (decrement num_uses)
// If a secret was generated and num_uses is currently 1, it will be
// immediately revoked; in that case, don't return the leased
// credentials as they are now invalid.
if retResp != nil &&
te != nil && te.NumUses == 1 &&
retResp.Secret != nil &&
// Some backends return a TTL even without a Lease ID
retResp.Secret.LeaseID != "" {
retResp = logical.ErrorResponse("Secret cannot be returned; token had one use left, so leased credentials were immediately revoked.")
}
if err := c.tokenStore.UseToken(te); err != nil {
c.logger.Printf("[ERR] core: failed to use token: %v", err)
retResp = nil
retAuth = nil
retErr = ErrInternalError
}
}()
}
if err != nil {
// If it is an internal error we return that, otherwise we
// return invalid request so that the status codes can be correct
var errType error
switch err {
case ErrInternalError, logical.ErrPermissionDenied:
errType = err
default:
errType = logical.ErrInvalidRequest
}
if err := c.auditBroker.LogRequest(auth, req, err); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path (%s): %v",
req.Path, err)
}
return logical.ErrorResponse(err.Error()), nil, errType
}
// Attach the display name
req.DisplayName = auth.DisplayName
// Create an audit trail of the request
if err := c.auditBroker.LogRequest(auth, req, nil); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path (%s): %v",
req.Path, err)
return nil, auth, ErrInternalError
}
// Route the request
resp, err := c.router.Route(req)
// If there is a secret, we must register it with the expiration manager.
// We exclude renewal of a lease, since it does not need to be re-registered
if resp != nil && resp.Secret != nil && !strings.HasPrefix(req.Path, "sys/renew/") {
// Get the SystemView for the mount
sysView := c.router.MatchingSystemView(req.Path)
if sysView == nil {
c.logger.Println("[ERR] core: unable to retrieve system view from router")
return nil, auth, ErrInternalError
}
// Apply the default lease if none given
if resp.Secret.TTL == 0 {
resp.Secret.TTL = sysView.DefaultLeaseTTL()
}
// Limit the lease duration
maxTTL := sysView.MaxLeaseTTL()
if resp.Secret.TTL > maxTTL {
resp.Secret.TTL = maxTTL
}
// Generic mounts should return the TTL but not register
// for a lease as this provides a massive slowdown
registerLease := true
matchingBackend := c.router.MatchingBackend(req.Path)
if matchingBackend == nil {
c.logger.Println("[ERR] core: unable to retrieve generic backend from router")
return nil, auth, ErrInternalError
}
if ptbe, ok := matchingBackend.(*PassthroughBackend); ok {
if !ptbe.GeneratesLeases() {
registerLease = false
resp.Secret.Renewable = false
}
}
if registerLease {
leaseID, err := c.expiration.Register(req, resp)
if err != nil {
c.logger.Printf(
"[ERR] core: failed to register lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
}
resp.Secret.LeaseID = leaseID
}
}
// Only the token store is allowed to return an auth block, for any
// other request this is an internal error. We exclude renewal of a token,
// since it does not need to be re-registered
if resp != nil && resp.Auth != nil && !strings.HasPrefix(req.Path, "auth/token/renew") {
if !strings.HasPrefix(req.Path, "auth/token/") {
c.logger.Printf(
"[ERR] core: unexpected Auth response for non-token backend "+
"(request path: %s)", req.Path)
return nil, auth, ErrInternalError
}
// Register with the expiration manager. We use the token's actual path
// here because roles allow suffixes.
te, err := c.tokenStore.Lookup(resp.Auth.ClientToken)
if err != nil {
c.logger.Printf("[ERR] core: failed to lookup token: %v", err)
return nil, nil, ErrInternalError
}
if err := c.expiration.RegisterAuth(te.Path, resp.Auth); err != nil {
c.logger.Printf("[ERR] core: failed to register token lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
}
}
// Return the response and error
return resp, auth, err
}
// handleLoginRequest is used to handle a login request, which is an
// unauthenticated request to the backend.
func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *logical.Auth, error) {
defer metrics.MeasureSince([]string{"core", "handle_login_request"}, time.Now())
// Create an audit trail of the request, auth is not available on login requests
if err := c.auditBroker.LogRequest(nil, req, nil); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path %s: %v",
req.Path, err)
return nil, nil, ErrInternalError
}
// Route the request
resp, err := c.router.Route(req)
// A login request should never return a secret!
if resp != nil && resp.Secret != nil {
c.logger.Printf("[ERR] core: unexpected Secret response for login path"+
"(request path: %s)", req.Path)
return nil, nil, ErrInternalError
}
// If the response generated an authentication, then generate the token
var auth *logical.Auth
if resp != nil && resp.Auth != nil {
auth = resp.Auth
// Determine the source of the login
source := c.router.MatchingMount(req.Path)
source = strings.TrimPrefix(source, credentialRoutePrefix)
source = strings.Replace(source, "/", "-", -1)
// Prepend the source to the display name
auth.DisplayName = strings.TrimSuffix(source+auth.DisplayName, "-")
sysView := c.router.MatchingSystemView(req.Path)
if sysView == nil {
c.logger.Printf("[ERR] core: unable to look up sys view for login path"+
"(request path: %s)", req.Path)
return nil, nil, ErrInternalError
}
// Set the default lease if non-provided, root tokens are exempt
if auth.TTL == 0 && !strutil.StrListContains(auth.Policies, "root") {
auth.TTL = sysView.DefaultLeaseTTL()
}
// Limit the lease duration
if auth.TTL > sysView.MaxLeaseTTL() {
auth.TTL = sysView.MaxLeaseTTL()
}
// Generate a token
te := TokenEntry{
Path: req.Path,
Policies: auth.Policies,
Meta: auth.Metadata,
DisplayName: auth.DisplayName,
CreationTime: time.Now().Unix(),
TTL: auth.TTL,
}
if strutil.StrListSubset(te.Policies, []string{"root"}) {
te.Policies = []string{"root"}
} else {
// Use a map to filter out/prevent duplicates
policyMap := map[string]bool{}
for _, policy := range te.Policies {
if policy == "" {
// Don't allow a policy with no name, even though it is a valid
// slice member
continue
}
policyMap[policy] = true
}
// Add the default policy
policyMap["default"] = true
te.Policies = []string{}
for k, _ := range policyMap {
te.Policies = append(te.Policies, k)
}
sort.Strings(te.Policies)
}
if err := c.tokenStore.create(&te); err != nil {
c.logger.Printf("[ERR] core: failed to create token: %v", err)
return nil, auth, ErrInternalError
}
// Populate the client token and accessor
auth.ClientToken = te.ID
auth.Accessor = te.Accessor
auth.Policies = te.Policies
// Register with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
c.logger.Printf("[ERR] core: failed to register token lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
}
// Attach the display name, might be used by audit backends
req.DisplayName = auth.DisplayName
}
return resp, auth, err
}

View File

@@ -261,19 +261,19 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l
// If either of the request or response requested wrapping, ensure that // If either of the request or response requested wrapping, ensure that
// the lowest value is what ends up in the response. // the lowest value is what ends up in the response.
switch { switch {
case req.WrapDuration == 0 && resp.WrapInfo.Duration == 0: case req.WrapTTL == 0 && resp.WrapInfo.TTL == 0:
case req.WrapDuration != 0 && resp.WrapInfo.Duration != 0: case req.WrapTTL != 0 && resp.WrapInfo.TTL != 0:
if req.WrapDuration < resp.WrapInfo.Duration { if req.WrapTTL < resp.WrapInfo.TTL {
resp.WrapInfo.Duration = req.WrapDuration resp.WrapInfo.TTL = req.WrapTTL
} }
case req.WrapDuration != 0: case req.WrapTTL != 0:
resp.WrapInfo.Duration = req.WrapDuration resp.WrapInfo.TTL = req.WrapTTL
// Only case left is that only resp defines it, which doesn't need to // Only case left is that only resp defines it, which doesn't need to
// be explicitly handled // be explicitly handled
} }
// Now set the mount point if we are wrapping // Now set the mount point if we are wrapping
if resp.WrapInfo.Duration != 0 { if resp.WrapInfo.TTL != 0 {
resp.WrapInfo.MountPoint = mount resp.WrapInfo.MountPoint = mount
} }
} }

View File

@@ -20,7 +20,7 @@ type NoopBackend struct {
Requests []*logical.Request Requests []*logical.Request
Response *logical.Response Response *logical.Response
WrapDuration time.Duration WrapTTL time.Duration
} }
func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, error) { func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, error) {
@@ -34,12 +34,12 @@ func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, er
return nil, fmt.Errorf("missing view") return nil, fmt.Errorf("missing view")
} }
if n.Response == nil && (req.WrapDuration != 0 || n.WrapDuration != 0) { if n.Response == nil && (req.WrapTTL != 0 || n.WrapTTL != 0) {
n.Response = &logical.Response{} n.Response = &logical.Response{}
} }
if n.WrapDuration != 0 { if n.WrapTTL != 0 {
n.Response.WrapInfo.Duration = n.WrapDuration n.Response.WrapInfo.TTL = n.WrapTTL
} }
return n.Response, nil return n.Response, nil
@@ -420,10 +420,10 @@ func TestRouter_Wrapping(t *testing.T) {
// Just in the request // Just in the request
req = &logical.Request{ req = &logical.Request{
Path: "wraptest/foo", Path: "wraptest/foo",
ClientToken: root, ClientToken: root,
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
WrapDuration: time.Duration(15 * time.Second), WrapTTL: time.Duration(15 * time.Second),
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@@ -432,13 +432,13 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.Duration != time.Duration(15*time.Second) || if resp.WrapInfo.TTL != time.Duration(15*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" { resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
// Just in the response // Just in the response
n.WrapDuration = time.Duration(15 * time.Second) n.WrapTTL = time.Duration(15 * time.Second)
req = &logical.Request{ req = &logical.Request{
Path: "wraptest/foo", Path: "wraptest/foo",
ClientToken: root, ClientToken: root,
@@ -451,18 +451,18 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.Duration != time.Duration(15*time.Second) || if resp.WrapInfo.TTL != time.Duration(15*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" { resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
// In both, with request less // In both, with request less
n.WrapDuration = time.Duration(15 * time.Second) n.WrapTTL = time.Duration(15 * time.Second)
req = &logical.Request{ req = &logical.Request{
Path: "wraptest/foo", Path: "wraptest/foo",
ClientToken: root, ClientToken: root,
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
WrapDuration: time.Duration(10 * time.Second), WrapTTL: time.Duration(10 * time.Second),
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@@ -471,18 +471,18 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.Duration != time.Duration(10*time.Second) || if resp.WrapInfo.TTL != time.Duration(10*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" { resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
// In both, with response less // In both, with response less
n.WrapDuration = time.Duration(10 * time.Second) n.WrapTTL = time.Duration(10 * time.Second)
req = &logical.Request{ req = &logical.Request{
Path: "wraptest/foo", Path: "wraptest/foo",
ClientToken: root, ClientToken: root,
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
WrapDuration: time.Duration(15 * time.Second), WrapTTL: time.Duration(15 * time.Second),
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@@ -491,7 +491,7 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.Duration != time.Duration(10*time.Second) || if resp.WrapInfo.TTL != time.Duration(10*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" { resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }