Make WrapInfo a pointer to match secret/auth in response

This commit is contained in:
Jeff Mitchell
2016-05-07 19:17:51 -04:00
parent f3a3fc3d55
commit e36f66000e
5 changed files with 23 additions and 13 deletions

View File

@@ -155,7 +155,7 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl
return return
} }
if resp.WrapInfo.Token != "" { if resp.WrapInfo != nil && resp.WrapInfo.Token != "" {
httpResp = logical.HTTPResponse{ httpResp = logical.HTTPResponse{
WrapInfo: &logical.HTTPWrapInfo{ WrapInfo: &logical.HTTPWrapInfo{
Token: resp.WrapInfo.Token, Token: resp.WrapInfo.Token,

View File

@@ -66,7 +66,7 @@ type Response struct {
warnings []string warnings []string
// Information for wrapping the response in a cubbyhole // Information for wrapping the response in a cubbyhole
WrapInfo WrapInfo WrapInfo *WrapInfo
} }
func init() { func init() {

View File

@@ -78,7 +78,7 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err
// We are wrapping if there is anything to wrap (not a nil response) and a // We are wrapping if there is anything to wrap (not a nil response) and a
// TTL was specified for the token, plus if cubbyhole is mounted (which // TTL was specified for the token, plus if cubbyhole is mounted (which
// will be the case normally) // will be the case normally)
wrapping := cubbyholeMounted && resp != nil && resp.WrapInfo.TTL != 0 wrapping := cubbyholeMounted && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.TTL != 0
// If we are wrapping, the first part happens before auditing so that // If we are wrapping, the first part happens before auditing so that
// resp.WrapInfo.Token can contain the HMAC'd wrapping token ID in the // resp.WrapInfo.Token can contain the HMAC'd wrapping token ID in the

View File

@@ -256,15 +256,25 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l
// If either of the request or response requested wrapping, ensure that // If either of the request or response requested wrapping, ensure that
// the lowest value is what ends up in the response. // the lowest value is what ends up in the response.
switch { switch {
case req.WrapTTL == 0 && resp.WrapInfo.TTL == 0: case req.WrapTTL == 0 && (resp.WrapInfo == nil || resp.WrapInfo.TTL == 0):
case req.WrapTTL != 0 && resp.WrapInfo.TTL != 0: // Neither defines it, so do nothing
case req.WrapTTL != 0 && (resp.WrapInfo != nil && resp.WrapInfo.TTL != 0):
// Both define, so use the lowest
if req.WrapTTL < resp.WrapInfo.TTL { if req.WrapTTL < resp.WrapInfo.TTL {
resp.WrapInfo.TTL = req.WrapTTL resp.WrapInfo.TTL = req.WrapTTL
} }
case req.WrapTTL != 0: case req.WrapTTL != 0:
resp.WrapInfo.TTL = req.WrapTTL // Response wrap info doesn't exist, or its TTL is zero, so set
// Only case left is that only resp defines it, which doesn't need to // it to the request TTL
// be explicitly handled resp.WrapInfo = &logical.WrapInfo{
TTL: req.WrapTTL,
}
default:
// Only case left is that only resp defines it, which doesn't
// need to be explicitly handled
} }
} }

View File

@@ -39,7 +39,7 @@ func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, er
} }
if n.WrapTTL != 0 { if n.WrapTTL != 0 {
n.Response.WrapInfo.TTL = n.WrapTTL n.Response.WrapInfo = &logical.WrapInfo{TTL: n.WrapTTL}
} }
return n.Response, nil return n.Response, nil
@@ -432,7 +432,7 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.TTL != time.Duration(15*time.Second) { if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
@@ -450,7 +450,7 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.TTL != time.Duration(15*time.Second) { if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
@@ -469,7 +469,7 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.TTL != time.Duration(10*time.Second) { if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
@@ -488,7 +488,7 @@ func TestRouter_Wrapping(t *testing.T) {
if resp == nil { if resp == nil {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
} }
if resp.WrapInfo.TTL != time.Duration(10*time.Second) { if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
t.Fatalf("bad: %#v", resp) t.Fatalf("bad: %#v", resp)
} }
} }