logical/framework: allow the lease max to come from existing lease

This commit is contained in:
Armon Dadgar
2015-06-17 14:24:12 -07:00
parent 2a894171ca
commit daf94d6721
3 changed files with 25 additions and 9 deletions

View File

@@ -215,7 +215,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
func TestBackendHandleRequest_renewExtend(t *testing.T) {
secret := &Secret{
Type: "foo",
Renew: LeaseExtend(0, 0),
Renew: LeaseExtend(0, 0, false),
DefaultDuration: 5 * time.Minute,
}
b := &Backend{

View File

@@ -14,13 +14,22 @@ import (
//
// maxSession is the maximum session length allowed since the original
// issue time. If this is zero, it is ignored.
func LeaseExtend(max, maxSession time.Duration) OperationFunc {
//
// maxFromLease controls if the maximum renewal period comes from the existing
// lease. This means the value of `max` will be replaced with the existing
// lease duration.
func LeaseExtend(max, maxSession time.Duration, maxFromLease bool) OperationFunc {
return func(req *logical.Request, data *FieldData) (*logical.Response, error) {
lease := detectLease(req)
if lease == nil {
return nil, fmt.Errorf("no lease options for request")
}
// Check if we should limit max
if maxFromLease {
max = lease.Lease
}
// Sanity check the desired increment
switch {
// Protect against negative leases

View File

@@ -15,6 +15,7 @@ func TestLeaseExtend(t *testing.T) {
MaxSession time.Duration
Request time.Duration
Result time.Duration
MaxFromLease bool
Error bool
}{
"valid request, good bounds": {
@@ -62,20 +63,26 @@ func TestLeaseExtend(t *testing.T) {
Request: -7 * time.Hour,
Error: true,
},
"max form lease, request too large": {
Request: 10 * time.Hour,
MaxFromLease: true,
Result: time.Hour,
},
}
for name, tc := range cases {
req := &logical.Request{
Auth: &logical.Auth{
LeaseOptions: logical.LeaseOptions{
Lease: 1 * time.Second,
Lease: 1 * time.Hour,
LeaseIssue: now,
LeaseIncrement: tc.Request,
},
},
}
callback := LeaseExtend(tc.Max, tc.MaxSession)
callback := LeaseExtend(tc.Max, tc.MaxSession, tc.MaxFromLease)
resp, err := callback(req, nil)
if (err != nil) != tc.Error {
t.Fatalf("bad: %s\nerr: %s", name, err)