Fix max_ttl not being honored in database backend when default_ttl is zero (#3814)

Fixes #3812
This commit is contained in:
Jeff Mitchell
2018-01-18 01:43:38 -05:00
committed by GitHub
parent b907a2e01f
commit 69eca11b62
7 changed files with 112 additions and 11 deletions

View File

@@ -9,6 +9,7 @@ import (
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
@@ -267,6 +268,71 @@ func TestBackend_basic(t *testing.T) {
} }
// Create a role // Create a role
data = map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"max_ttl": "10m",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Get creds
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "creds/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
credsResp, err := b.HandleRequest(context.Background(), req)
if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
// Test for #3812
if credsResp.Secret.TTL != 10*time.Minute {
t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL)
}
// Update the role with no max ttl
data = map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"default_ttl": "5m",
"max_ttl": 0,
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Get creds
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "creds/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
credsResp, err = b.HandleRequest(context.Background(), req)
if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
// Test for #3812
if credsResp.Secret.TTL != 5*time.Minute {
t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL)
}
// Update the role with a max ttl
data = map[string]interface{}{ data = map[string]interface{}{
"db_name": "plugin-test", "db_name": "plugin-test",
"creation_statements": testRole, "creation_statements": testRole,
@@ -283,7 +349,6 @@ func TestBackend_basic(t *testing.T) {
if err != nil || (resp != nil && resp.IsError()) { if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp) t.Fatalf("err:%s resp:%#v\n", err, resp)
} }
// Get creds // Get creds
data = map[string]interface{}{} data = map[string]interface{}{}
req = &logical.Request{ req = &logical.Request{
@@ -292,11 +357,14 @@ func TestBackend_basic(t *testing.T) {
Storage: config.StorageView, Storage: config.StorageView,
Data: data, Data: data,
} }
credsResp, err := b.HandleRequest(context.Background(), req) credsResp, err = b.HandleRequest(context.Background(), req)
if err != nil || (credsResp != nil && credsResp.IsError()) { if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, credsResp) t.Fatalf("err:%s resp:%#v\n", err, credsResp)
} }
// Test for #3812
if credsResp.Secret.TTL != 5*time.Minute {
t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL)
}
if !testCredsExist(t, credsResp, connURL) { if !testCredsExist(t, credsResp, connURL) {
t.Fatalf("Creds should exist") t.Fatalf("Creds should exist")
} }

View File

@@ -74,7 +74,12 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
} }
} }
expiration := time.Now().Add(role.DefaultTTL) ttl := role.DefaultTTL
if ttl == 0 || (role.MaxTTL > 0 && ttl > role.MaxTTL) {
ttl = role.MaxTTL
}
expiration := time.Now().Add(ttl)
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
DisplayName: req.DisplayName, DisplayName: req.DisplayName,
@@ -96,7 +101,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
"username": username, "username": username,
"role": name, "role": name,
}) })
resp.Secret.TTL = role.DefaultTTL resp.Secret.TTL = ttl
unlockFunc() unlockFunc()
return resp, nil return resp, nil

View File

@@ -95,7 +95,13 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
"username": username, "username": username,
"db": role.DB, "db": role.DB,
}) })
resp.Secret.TTL = leaseConfig.TTL
ttl := leaseConfig.TTL
if ttl == 0 || (leaseConfig.MaxTTL > 0 && ttl > leaseConfig.MaxTTL) {
ttl = leaseConfig.MaxTTL
}
resp.Secret.TTL = ttl
return resp, nil return resp, nil
} }

View File

@@ -115,7 +115,13 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
}, map[string]interface{}{ }, map[string]interface{}{
"username": username, "username": username,
}) })
resp.Secret.TTL = leaseConfig.TTL
ttl := leaseConfig.TTL
if ttl == 0 || (leaseConfig.TTLMax > 0 && ttl > leaseConfig.TTLMax) {
ttl = leaseConfig.TTLMax
}
resp.Secret.TTL = ttl
return resp, nil return resp, nil
} }

View File

@@ -129,7 +129,13 @@ func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request,
"username": username, "username": username,
"role": name, "role": name,
}) })
resp.Secret.TTL = lease.Lease
ttl := lease.Lease
if ttl == 0 || (lease.LeaseMax > 0 && ttl > lease.LeaseMax) {
ttl = lease.LeaseMax
}
resp.Secret.TTL = ttl
return resp, nil return resp, nil
} }

View File

@@ -63,6 +63,11 @@ func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request,
} }
} }
ttl := lease.Lease
if ttl == 0 || (lease.LeaseMax > 0 && ttl > lease.LeaseMax) {
ttl = lease.LeaseMax
}
// Generate the username, password and expiration. PG limits user to 63 characters // Generate the username, password and expiration. PG limits user to 63 characters
displayName := req.DisplayName displayName := req.DisplayName
if len(displayName) > 26 { if len(displayName) > 26 {
@@ -81,7 +86,7 @@ func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request,
return nil, err return nil, err
} }
expiration := time.Now(). expiration := time.Now().
Add(lease.Lease). Add(ttl).
Format("2006-01-02 15:04:05-0700") Format("2006-01-02 15:04:05-0700")
// Get our handle // Get our handle
@@ -142,7 +147,7 @@ func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request,
"username": username, "username": username,
"role": name, "role": name,
}) })
resp.Secret.TTL = lease.Lease resp.Secret.TTL = ttl
return resp, nil return resp, nil
} }

View File

@@ -103,8 +103,13 @@ func (b *backend) pathCredsRead(ctx context.Context, req *logical.Request, d *fr
if err != nil { if err != nil {
return nil, err return nil, err
} }
if lease != nil { if lease != nil {
resp.Secret.TTL = lease.TTL ttl := lease.TTL
if ttl == 0 || (lease.MaxTTL > 0 && ttl > lease.MaxTTL) {
ttl = lease.MaxTTL
}
resp.Secret.TTL = ttl
} }
return resp, nil return resp, nil