Add UpgradeValue path to tokenutil (#7041)

This drastically reduces boilerplate for upgrading existing values
This commit is contained in:
Jeff Mitchell
2019-07-02 09:52:05 -04:00
committed by GitHub
parent bf5e9ec99d
commit 50704e612c
7 changed files with 188 additions and 284 deletions

View File

@@ -954,36 +954,12 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
// handle upgrade cases
{
policiesRaw, ok := data.GetOk("token_policies")
if !ok {
policiesRaw, ok = data.GetOk("policies")
if ok {
role.Policies = policyutil.ParsePolicies(policiesRaw)
role.TokenPolicies = role.Policies
}
} else {
_, ok = data.GetOk("policies")
if ok {
role.Policies = role.TokenPolicies
} else {
role.Policies = nil
}
if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
periodRaw, ok := data.GetOk("token_period")
if !ok {
periodRaw, ok = data.GetOk("period")
if ok {
role.Period = time.Duration(periodRaw.(int)) * time.Second
role.TokenPeriod = role.Period
}
} else {
_, ok = data.GetOk("period")
if ok {
role.Period = role.TokenPeriod
} else {
role.Period = 0
}
if err := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}

View File

@@ -11,7 +11,6 @@ import (
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/copystructure"
@@ -740,71 +739,32 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
// Handle upgrade cases
{
policiesRaw, ok := data.GetOk("token_policies")
if !ok {
policiesRaw, ok = data.GetOk("policies")
if ok {
roleEntry.Policies = policyutil.ParsePolicies(policiesRaw)
roleEntry.TokenPolicies = roleEntry.Policies
}
} else {
_, ok = data.GetOk("policies")
if ok {
roleEntry.Policies = roleEntry.TokenPolicies
} else {
roleEntry.Policies = nil
}
if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &roleEntry.Policies, &roleEntry.TokenPolicies); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
ttlRaw, ok := data.GetOk("token_ttl")
if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &roleEntry.TTL, &roleEntry.TokenTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Special case here for old lease value
_, ok := data.GetOk("token_ttl")
if !ok {
ttlRaw, ok = data.GetOk("ttl")
if !ok {
ttlRaw, ok = data.GetOk("lease")
}
if ok {
roleEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second
roleEntry.TokenTTL = roleEntry.TTL
}
} else {
_, ok = data.GetOk("ttl")
if ok {
roleEntry.TTL = roleEntry.TokenTTL
} else {
roleEntry.TTL = 0
if !ok {
ttlRaw, ok := data.GetOk("lease")
if ok {
roleEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second
roleEntry.TokenTTL = roleEntry.TTL
}
}
}
maxTTLRaw, ok := data.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = data.GetOk("max_ttl")
if ok {
roleEntry.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
roleEntry.TokenMaxTTL = roleEntry.MaxTTL
}
} else {
_, ok = data.GetOk("max_ttl")
if ok {
roleEntry.MaxTTL = roleEntry.TokenMaxTTL
} else {
roleEntry.MaxTTL = 0
}
if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &roleEntry.MaxTTL, &roleEntry.TokenMaxTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
periodRaw, ok := data.GetOk("token_period")
if !ok {
periodRaw, ok = data.GetOk("period")
if ok {
roleEntry.Period = time.Duration(periodRaw.(int)) * time.Second
roleEntry.TokenPeriod = roleEntry.Period
}
} else {
_, ok = data.GetOk("period")
if ok {
roleEntry.Period = roleEntry.TokenPeriod
} else {
roleEntry.Period = 0
}
if err := tokenutil.UpgradeValue(data, "period", "token_period", &roleEntry.Period, &roleEntry.TokenPeriod); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}

View File

@@ -9,8 +9,6 @@ import (
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -287,93 +285,37 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
// Handle upgrade cases
{
policiesRaw, ok := d.GetOk("token_policies")
if !ok {
policiesRaw, ok = d.GetOk("policies")
if ok {
cert.Policies = policyutil.ParsePolicies(policiesRaw)
cert.TokenPolicies = cert.Policies
}
} else {
_, ok = d.GetOk("policies")
if ok {
cert.Policies = cert.TokenPolicies
} else {
cert.Policies = nil
}
if err := tokenutil.UpgradeValue(d, "policies", "token_policies", &cert.Policies, &cert.TokenPolicies); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
ttlRaw, ok := d.GetOk("token_ttl")
if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &cert.TTL, &cert.TokenTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Special case here for old lease value
_, ok := d.GetOk("token_ttl")
if !ok {
ttlRaw, ok = d.GetOk("ttl")
if !ok {
ttlRaw, ok = d.GetOk("lease")
}
if ok {
cert.TTL = time.Duration(ttlRaw.(int)) * time.Second
cert.TokenTTL = cert.TTL
}
} else {
_, ok = d.GetOk("ttl")
if ok {
cert.TTL = cert.TokenTTL
} else {
cert.TTL = 0
}
}
maxTTLRaw, ok := d.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = d.GetOk("max_ttl")
if ok {
cert.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
cert.TokenMaxTTL = cert.MaxTTL
}
} else {
_, ok = d.GetOk("max_ttl")
if ok {
cert.MaxTTL = cert.TokenMaxTTL
} else {
cert.MaxTTL = 0
}
}
periodRaw, ok := d.GetOk("token_period")
if !ok {
periodRaw, ok = d.GetOk("period")
if ok {
cert.Period = time.Duration(periodRaw.(int)) * time.Second
cert.TokenPeriod = cert.Period
}
} else {
_, ok = d.GetOk("period")
if ok {
cert.Period = cert.TokenPeriod
} else {
cert.Period = 0
}
}
boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs")
if !ok {
boundCIDRsRaw, ok = d.GetOk("bound_cidrs")
if ok {
boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
if !ok {
ttlRaw, ok := d.GetOk("lease")
if ok {
cert.TTL = time.Duration(ttlRaw.(int)) * time.Second
cert.TokenTTL = cert.TTL
}
cert.BoundCIDRs = boundCIDRs
cert.TokenBoundCIDRs = cert.BoundCIDRs
}
} else {
_, ok = d.GetOk("bound_cidrs")
if ok {
cert.BoundCIDRs = cert.TokenBoundCIDRs
} else {
cert.BoundCIDRs = nil
}
}
if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &cert.MaxTTL, &cert.TokenMaxTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
if err := tokenutil.UpgradeValue(d, "period", "token_period", &cert.Period, &cert.TokenPeriod); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
if err := tokenutil.UpgradeValue(d, "bound_cidrs", "token_bound_cidrs", &cert.BoundCIDRs, &cert.TokenBoundCIDRs); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}
var resp logical.Response

View File

@@ -86,36 +86,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
// Handle upgrade cases
{
ttlRaw, ok := data.GetOk("token_ttl")
if !ok {
ttlRaw, ok = data.GetOk("ttl")
if ok {
c.TTL = time.Duration(ttlRaw.(int)) * time.Second
c.TokenTTL = c.TTL
}
} else {
_, ok = data.GetOk("ttl")
if ok {
c.TTL = c.TokenTTL
} else {
c.TTL = 0
}
if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &c.TTL, &c.TokenTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
maxTTLRaw, ok := data.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = data.GetOk("max_ttl")
if ok {
c.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
c.TokenMaxTTL = c.MaxTTL
}
} else {
_, ok = data.GetOk("max_ttl")
if ok {
c.MaxTTL = c.TokenMaxTTL
} else {
c.MaxTTL = 0
}
if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &c.MaxTTL, &c.TokenMaxTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}

View File

@@ -228,36 +228,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
// Handle upgrade cases
{
ttlRaw, ok := d.GetOk("token_ttl")
if !ok {
ttlRaw, ok = d.GetOk("ttl")
if ok {
cfg.TTL = time.Duration(ttlRaw.(int)) * time.Second
cfg.TokenTTL = cfg.TTL
}
} else {
_, ok = d.GetOk("ttl")
if ok {
cfg.TTL = cfg.TokenTTL
} else {
cfg.TTL = 0
}
if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &cfg.TTL, &cfg.TokenTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
maxTTLRaw, ok := d.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = d.GetOk("max_ttl")
if ok {
cfg.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
cfg.TokenMaxTTL = cfg.MaxTTL
}
} else {
_, ok = d.GetOk("max_ttl")
if ok {
cfg.MaxTTL = cfg.TokenMaxTTL
} else {
cfg.MaxTTL = 0
}
if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &cfg.MaxTTL, &cfg.TokenMaxTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}

View File

@@ -8,8 +8,6 @@ import (
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -210,72 +208,20 @@ func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d
// handle upgrade cases
{
policiesRaw, ok := d.GetOk("token_policies")
if !ok {
policiesRaw, ok = d.GetOk("policies")
if ok {
userEntry.Policies = policyutil.ParsePolicies(policiesRaw)
userEntry.TokenPolicies = userEntry.Policies
}
} else {
_, ok = d.GetOk("policies")
if ok {
userEntry.Policies = userEntry.TokenPolicies
} else {
userEntry.Policies = nil
}
if err := tokenutil.UpgradeValue(d, "policies", "token_policies", &userEntry.Policies, &userEntry.TokenPolicies); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
ttlRaw, ok := d.GetOk("token_ttl")
if !ok {
ttlRaw, ok = d.GetOk("ttl")
if ok {
userEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second
userEntry.TokenTTL = userEntry.TTL
}
} else {
_, ok = d.GetOk("ttl")
if ok {
userEntry.TTL = userEntry.TokenTTL
} else {
userEntry.TTL = 0
}
if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &userEntry.TTL, &userEntry.TokenTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
maxTTLRaw, ok := d.GetOk("token_max_ttl")
if !ok {
maxTTLRaw, ok = d.GetOk("max_ttl")
if ok {
userEntry.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
userEntry.TokenMaxTTL = userEntry.TokenMaxTTL
}
} else {
_, ok = d.GetOk("max_ttl")
if ok {
userEntry.MaxTTL = userEntry.TokenMaxTTL
} else {
userEntry.MaxTTL = 0
}
if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &userEntry.MaxTTL, &userEntry.TokenMaxTTL); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs")
if !ok {
boundCIDRsRaw, ok = d.GetOk("bound_cidrs")
if ok {
boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
userEntry.BoundCIDRs = boundCIDRs
userEntry.TokenBoundCIDRs = userEntry.BoundCIDRs
}
} else {
_, ok = d.GetOk("bound_cidrs")
if ok {
userEntry.BoundCIDRs = userEntry.TokenBoundCIDRs
} else {
userEntry.BoundCIDRs = nil
}
if err := tokenutil.UpgradeValue(d, "bound_cidrs", "token_bound_cirs", &userEntry.BoundCIDRs, &userEntry.TokenBoundCIDRs); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}

View File

@@ -8,6 +8,7 @@ import (
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -266,6 +267,133 @@ func DeprecationText(param string) string {
return fmt.Sprintf("Use %q instead. If this and %q are both specified, only %q will be used.", param, param, param)
}
func upgradeDurationValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *time.Duration) error {
_, ok := d.GetOk(newKey)
if !ok {
raw, ok := d.GetOk(oldKey)
if ok {
*oldVal = time.Duration(raw.(int)) * time.Second
*newVal = *oldVal
}
} else {
_, ok = d.GetOk(oldKey)
if ok {
*oldVal = *newVal
} else {
*oldVal = 0
}
}
return nil
}
func upgradeIntValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *int) error {
_, ok := d.GetOk(newKey)
if !ok {
raw, ok := d.GetOk(oldKey)
if ok {
*oldVal = raw.(int)
*newVal = *oldVal
}
} else {
_, ok = d.GetOk(oldKey)
if ok {
*oldVal = *newVal
} else {
*oldVal = 0
}
}
return nil
}
func upgradeStringSliceValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *[]string) error {
_, ok := d.GetOk(newKey)
if !ok {
raw, ok := d.GetOk(oldKey)
if ok {
// Special case: if we're looking at "token_policies" parse the policies
if newKey == "token_policies" {
*oldVal = policyutil.ParsePolicies(raw)
} else {
*oldVal = raw.([]string)
}
*newVal = *oldVal
}
} else {
_, ok = d.GetOk(oldKey)
if ok {
*oldVal = *newVal
} else {
*oldVal = nil
}
}
return nil
}
func upgradeSockAddrSliceValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *[]*sockaddr.SockAddrMarshaler) error {
_, ok := d.GetOk(newKey)
if !ok {
raw, ok := d.GetOk(oldKey)
if ok {
boundCIDRs, err := parseutil.ParseAddrs(raw)
if err != nil {
return err
}
*oldVal = boundCIDRs
*newVal = *oldVal
}
} else {
_, ok = d.GetOk(oldKey)
if ok {
*oldVal = *newVal
} else {
*oldVal = nil
}
}
return nil
}
// UpgradeValue takes in old/new data keys and old/new values and calls out to
// a helper function to perform upgrades in a standardized way. It reqiures
// pointers in all cases so that we can set directly into the target struct.
func UpgradeValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal interface{}) error {
switch typedOldVal := oldVal.(type) {
case *time.Duration:
typedNewVal, ok := newVal.(*time.Duration)
if !ok {
return errors.New("mismatch in value types in tokenutil.UpgradeValue")
}
return upgradeDurationValue(d, oldKey, newKey, typedOldVal, typedNewVal)
case *int:
typedNewVal, ok := newVal.(*int)
if !ok {
return errors.New("mismatch in value types in tokenutil.UpgradeValue")
}
return upgradeIntValue(d, oldKey, newKey, typedOldVal, typedNewVal)
case *[]string:
typedNewVal, ok := newVal.(*[]string)
if !ok {
return errors.New("mismatch in value types in tokenutil.UpgradeValue")
}
return upgradeStringSliceValue(d, oldKey, newKey, typedOldVal, typedNewVal)
case *[]*sockaddr.SockAddrMarshaler:
typedNewVal, ok := newVal.(*[]*sockaddr.SockAddrMarshaler)
if !ok {
return errors.New("mismatch in value types in tokenutil.UpgradeValue")
}
return upgradeSockAddrSliceValue(d, oldKey, newKey, typedOldVal, typedNewVal)
default:
return errors.New("unhandled type in tokenutil.UpgradeValue")
}
}
const (
tokenPeriodHelp = `If set, tokens created via this role
will have no max lifetime; instead, their