From 163bfa62a6df173326c852f32b04480a89ef3d22 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Thu, 19 Mar 2015 20:20:25 +0100 Subject: [PATCH] logical/framework: support renew --- logical/framework/backend.go | 25 +++++++++++++++++++---- logical/framework/backend_test.go | 33 +++++++++++++++++++++++++++++++ logical/request.go | 29 +++++++++++++++++++++++++++ vault/expiration.go | 24 ++++------------------ vault/expiration_test.go | 4 ++-- vault/logical_passthrough.go | 30 ++++++++++++++++------------ vault/logical_passthrough_test.go | 1 + 7 files changed, 107 insertions(+), 39 deletions(-) diff --git a/logical/framework/backend.go b/logical/framework/backend.go index 15a122d17e..b85bead433 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -51,8 +51,10 @@ func (b *Backend) HandleRequest(req *logical.Request) (*logical.Response, error) // Check for special cased global operations. These don't route // to a specific Path. switch req.Operation { + case logical.RenewOperation: + fallthrough case logical.RevokeOperation: - return b.handleRevoke(req) + return b.handleRevokeRenew(req) case logical.RollbackOperation: return b.handleRollback(req) } @@ -91,6 +93,7 @@ func (b *Backend) HandleRequest(req *logical.Request) (*logical.Response, error) // Call the callback with the request and the data return callback(&Request{ + Backend: b, LogicalRequest: req, Data: &FieldData{ Raw: raw, @@ -156,7 +159,7 @@ func (b *Backend) route(path string) (*Path, map[string]string) { return nil, nil } -func (b *Backend) handleRevoke( +func (b *Backend) handleRevokeRenew( req *logical.Request) (*logical.Response, error) { leaseRaw, ok := req.Data["previous_lease"] if !ok { @@ -178,7 +181,18 @@ func (b *Backend) handleRevoke( return nil, fmt.Errorf("secret is unsupported by this backend") } - if secret.Revoke == nil { + var fn OperationFunc + switch req.Operation { + case logical.RenewOperation: + fn = secret.Renew + case logical.RevokeOperation: + fn = secret.Revoke + default: + return nil, fmt.Errorf( + "invalid operation for revoke/renew: %s", req.Operation) + } + + if fn == nil { return nil, logical.ErrUnsupportedOperation } @@ -187,7 +201,10 @@ func (b *Backend) handleRevoke( data = raw.(map[string]interface{}) } - return secret.Revoke(&Request{ + return fn(&Request{ + Backend: b, + LogicalRequest: req, + Data: &FieldData{ Raw: data, Schema: secret.Fields, diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index 135216adb2..79dc1fc2e1 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -137,6 +137,39 @@ func TestBackendHandleRequest_help(t *testing.T) { } } +func TestBackendHandleRequest_renew(t *testing.T) { + var called uint32 + callback := func(*Request) (*logical.Response, error) { + atomic.AddUint32(&called, 1) + return nil, nil + } + + b := &Backend{ + Secrets: []*Secret{ + &Secret{ + Type: "foo", + Renew: callback, + }, + }, + } + + _, err := b.HandleRequest(&logical.Request{ + Operation: logical.RenewOperation, + Path: "/foo", + Data: map[string]interface{}{ + "previous_lease": &logical.Lease{ + VaultID: "foo-bar", + }, + }, + }) + if err != nil { + t.Fatalf("err: %s", err) + } + if v := atomic.LoadUint32(&called); v != 1 { + t.Fatalf("bad: %#v", v) + } +} + func TestBackendHandleRequest_revoke(t *testing.T) { var called uint32 callback := func(*Request) (*logical.Response, error) { diff --git a/logical/request.go b/logical/request.go index f521dace96..1845c97bfe 100644 --- a/logical/request.go +++ b/logical/request.go @@ -2,6 +2,7 @@ package logical import ( "errors" + "time" ) // Request is a struct that stores the parameters and context @@ -39,6 +40,34 @@ func (r *Request) GetString(key string) string { return s } +// RenewRequest creates the structure of the renew request. +func RenewRequest( + path string, increment time.Duration, + lease *Lease, data map[string]interface{}) *Request { + return &Request{ + Operation: RenewOperation, + Path: path, + Data: map[string]interface{}{ + "previous_lease": lease, + "previous_data": data, + "increment": increment, + }, + } +} + +// RevokeRequest creates the structure of the revoke request. +func RevokeRequest( + path string, lease *Lease, data map[string]interface{}) *Request { + return &Request{ + Operation: RevokeOperation, + Path: path, + Data: map[string]interface{}{ + "previous_lease": lease, + "previous_data": data, + }, + } +} + // Operation is an enum that is used to specify the type // of request being made type Operation string diff --git a/vault/expiration.go b/vault/expiration.go index f2d00e5477..e101628a4c 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -322,16 +322,8 @@ func (m *ExpirationManager) expireID(vaultID string) { // revokeEntry is used to attempt revocation of an internal entry func (m *ExpirationManager) revokeEntry(le *leaseEntry) error { - data := map[string]interface{}{ - "previous_lease": le.Lease, - "previous_data": le.Data, - } - req := &logical.Request{ - Operation: logical.RevokeOperation, - Path: le.Path, - Data: data, - } - _, err := m.router.Route(req) + _, err := m.router.Route(logical.RevokeRequest( + le.Path, le.Lease, le.Data)) if err != nil { return fmt.Errorf("failed to revoke entry: %v", err) } @@ -340,16 +332,8 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error { // renewEntry is used to attempt renew of an internal entry func (m *ExpirationManager) renewEntry(le *leaseEntry, increment time.Duration) (*logical.Response, error) { - data := map[string]interface{}{ - "previous": le.Data, - "increment": increment, - } - req := &logical.Request{ - Operation: logical.RenewOperation, - Path: le.Path, - Data: data, - } - resp, err := m.router.Route(req) + resp, err := m.router.Route(logical.RenewRequest( + le.Path, increment, le.Lease, le.Data)) if err != nil { return nil, fmt.Errorf("failed to renew entry: %v", err) } diff --git a/vault/expiration_test.go b/vault/expiration_test.go index aa9324e105..d7447ce984 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -409,7 +409,7 @@ func TestExpiration_revokeEntry(t *testing.T) { if req.Path != le.Path { t.Fatalf("Bad: %v", req) } - if !reflect.DeepEqual(req.Data["previous"], le.Data) { + if !reflect.DeepEqual(req.Data["previous_data"], le.Data) { t.Fatalf("Bad: %v", req) } } @@ -462,7 +462,7 @@ func TestExpiration_renewEntry(t *testing.T) { if req.Path != le.Path { t.Fatalf("Bad: %v", req) } - if !reflect.DeepEqual(req.Data["previous"], le.Data) { + if !reflect.DeepEqual(req.Data["previous_data"], le.Data) { t.Fatalf("Bad: %v", req) } if req.Data["increment"] != time.Second { diff --git a/vault/logical_passthrough.go b/vault/logical_passthrough.go index 84be9b8618..f9ab71f5d9 100644 --- a/vault/logical_passthrough.go +++ b/vault/logical_passthrough.go @@ -26,17 +26,24 @@ func PassthroughBackendFactory(map[string]string) (logical.Backend, error) { Callbacks: map[logical.Operation]framework.OperationFunc{ logical.ReadOperation: b.handleRead, - logical.RenewOperation: b.handleRead, logical.WriteOperation: b.handleWrite, logical.DeleteOperation: b.handleDelete, logical.ListOperation: b.handleList, - logical.RevokeOperation: b.handleRevoke, }, HelpSynopsis: strings.TrimSpace(passthroughHelpSynopsis), HelpDescription: strings.TrimSpace(passthroughHelpDescription), }, }, + + Secrets: []*framework.Secret{ + &framework.Secret{ + Type: "generic", + + Renew: b.handleRead, + Revoke: b.handleRevoke, + }, + }, }, nil } @@ -73,25 +80,22 @@ func (b *PassthroughBackend) handleRead( return nil, fmt.Errorf("json decoding failed: %v", err) } + // Generate the response + resp, err := raw.Backend.Secret("generic").Response(rawData) + if err != nil { + return nil, fmt.Errorf("read failed: %v", err) + } + // Check if there is a lease key leaseVal, ok := rawData["lease"].(string) - var lease *logical.Lease if ok { leaseDuration, err := time.ParseDuration(leaseVal) if err == nil { - lease = &logical.Lease{ - Renewable: false, - Duration: leaseDuration, - } + resp.Lease.Renewable = false + resp.Lease.Duration = leaseDuration } } - // Generate the response - resp := &logical.Response{ - IsSecret: true, - Lease: lease, - Data: rawData, - } return resp, nil } diff --git a/vault/logical_passthrough_test.go b/vault/logical_passthrough_test.go index a777ea1691..e330f6c829 100644 --- a/vault/logical_passthrough_test.go +++ b/vault/logical_passthrough_test.go @@ -69,6 +69,7 @@ func TestPassthroughBackend_Read(t *testing.T) { }, } + resp.Lease.VaultID = "" if !reflect.DeepEqual(resp, expected) { t.Fatalf("bad response.\n\nexpected: %#v\n\nGot: %#v", expected, resp) }