From 50704e612cb9ddc110f62dad6beefedf73c6e4ed Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 2 Jul 2019 09:52:05 -0400 Subject: [PATCH] Add UpgradeValue path to tokenutil (#7041) This drastically reduces boilerplate for upgrading existing values --- builtin/credential/approle/path_role.go | 32 +----- builtin/credential/aws/path_role.go | 74 +++---------- builtin/credential/cert/path_certs.go | 104 ++++-------------- builtin/credential/github/path_config.go | 32 +----- builtin/credential/okta/path_config.go | 32 +----- builtin/credential/userpass/path_users.go | 70 ++---------- sdk/helper/tokenutil/tokenutil.go | 128 ++++++++++++++++++++++ 7 files changed, 188 insertions(+), 284 deletions(-) diff --git a/builtin/credential/approle/path_role.go b/builtin/credential/approle/path_role.go index a47b8247c1..14d42ed74b 100644 --- a/builtin/credential/approle/path_role.go +++ b/builtin/credential/approle/path_role.go @@ -954,36 +954,12 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request // handle upgrade cases { - policiesRaw, ok := data.GetOk("token_policies") - if !ok { - policiesRaw, ok = data.GetOk("policies") - if ok { - role.Policies = policyutil.ParsePolicies(policiesRaw) - role.TokenPolicies = role.Policies - } - } else { - _, ok = data.GetOk("policies") - if ok { - role.Policies = role.TokenPolicies - } else { - role.Policies = nil - } + if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); err != nil { + return logical.ErrorResponse(err.Error()), nil } - periodRaw, ok := data.GetOk("token_period") - if !ok { - periodRaw, ok = data.GetOk("period") - if ok { - role.Period = time.Duration(periodRaw.(int)) * time.Second - role.TokenPeriod = role.Period - } - } else { - _, ok = data.GetOk("period") - if ok { - role.Period = role.TokenPeriod - } else { - role.Period = 0 - } + if err := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); err != nil { + return logical.ErrorResponse(err.Error()), nil } } diff --git a/builtin/credential/aws/path_role.go b/builtin/credential/aws/path_role.go index dc80e15657..dd1a75b383 100644 --- a/builtin/credential/aws/path_role.go +++ b/builtin/credential/aws/path_role.go @@ -11,7 +11,6 @@ import ( uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" - "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/copystructure" @@ -740,71 +739,32 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request // Handle upgrade cases { - policiesRaw, ok := data.GetOk("token_policies") - if !ok { - policiesRaw, ok = data.GetOk("policies") - if ok { - roleEntry.Policies = policyutil.ParsePolicies(policiesRaw) - roleEntry.TokenPolicies = roleEntry.Policies - } - } else { - _, ok = data.GetOk("policies") - if ok { - roleEntry.Policies = roleEntry.TokenPolicies - } else { - roleEntry.Policies = nil - } + if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &roleEntry.Policies, &roleEntry.TokenPolicies); err != nil { + return logical.ErrorResponse(err.Error()), nil } - ttlRaw, ok := data.GetOk("token_ttl") + if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &roleEntry.TTL, &roleEntry.TokenTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + // Special case here for old lease value + _, ok := data.GetOk("token_ttl") if !ok { - ttlRaw, ok = data.GetOk("ttl") - if !ok { - ttlRaw, ok = data.GetOk("lease") - } - if ok { - roleEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second - roleEntry.TokenTTL = roleEntry.TTL - } - } else { _, ok = data.GetOk("ttl") - if ok { - roleEntry.TTL = roleEntry.TokenTTL - } else { - roleEntry.TTL = 0 + if !ok { + ttlRaw, ok := data.GetOk("lease") + if ok { + roleEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second + roleEntry.TokenTTL = roleEntry.TTL + } } } - maxTTLRaw, ok := data.GetOk("token_max_ttl") - if !ok { - maxTTLRaw, ok = data.GetOk("max_ttl") - if ok { - roleEntry.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second - roleEntry.TokenMaxTTL = roleEntry.MaxTTL - } - } else { - _, ok = data.GetOk("max_ttl") - if ok { - roleEntry.MaxTTL = roleEntry.TokenMaxTTL - } else { - roleEntry.MaxTTL = 0 - } + if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &roleEntry.MaxTTL, &roleEntry.TokenMaxTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } - periodRaw, ok := data.GetOk("token_period") - if !ok { - periodRaw, ok = data.GetOk("period") - if ok { - roleEntry.Period = time.Duration(periodRaw.(int)) * time.Second - roleEntry.TokenPeriod = roleEntry.Period - } - } else { - _, ok = data.GetOk("period") - if ok { - roleEntry.Period = roleEntry.TokenPeriod - } else { - roleEntry.Period = 0 - } + if err := tokenutil.UpgradeValue(data, "period", "token_period", &roleEntry.Period, &roleEntry.TokenPeriod); err != nil { + return logical.ErrorResponse(err.Error()), nil } } diff --git a/builtin/credential/cert/path_certs.go b/builtin/credential/cert/path_certs.go index a286e8b11e..0a57f347be 100644 --- a/builtin/credential/cert/path_certs.go +++ b/builtin/credential/cert/path_certs.go @@ -9,8 +9,6 @@ import ( sockaddr "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/sdk/framework" - "github.com/hashicorp/vault/sdk/helper/parseutil" - "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -287,93 +285,37 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr // Handle upgrade cases { - policiesRaw, ok := d.GetOk("token_policies") - if !ok { - policiesRaw, ok = d.GetOk("policies") - if ok { - cert.Policies = policyutil.ParsePolicies(policiesRaw) - cert.TokenPolicies = cert.Policies - } - } else { - _, ok = d.GetOk("policies") - if ok { - cert.Policies = cert.TokenPolicies - } else { - cert.Policies = nil - } + if err := tokenutil.UpgradeValue(d, "policies", "token_policies", &cert.Policies, &cert.TokenPolicies); err != nil { + return logical.ErrorResponse(err.Error()), nil } - ttlRaw, ok := d.GetOk("token_ttl") + if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &cert.TTL, &cert.TokenTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + // Special case here for old lease value + _, ok := d.GetOk("token_ttl") if !ok { - ttlRaw, ok = d.GetOk("ttl") - if !ok { - ttlRaw, ok = d.GetOk("lease") - } - if ok { - cert.TTL = time.Duration(ttlRaw.(int)) * time.Second - cert.TokenTTL = cert.TTL - } - } else { _, ok = d.GetOk("ttl") - if ok { - cert.TTL = cert.TokenTTL - } else { - cert.TTL = 0 - } - } - - maxTTLRaw, ok := d.GetOk("token_max_ttl") - if !ok { - maxTTLRaw, ok = d.GetOk("max_ttl") - if ok { - cert.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second - cert.TokenMaxTTL = cert.MaxTTL - } - } else { - _, ok = d.GetOk("max_ttl") - if ok { - cert.MaxTTL = cert.TokenMaxTTL - } else { - cert.MaxTTL = 0 - } - } - - periodRaw, ok := d.GetOk("token_period") - if !ok { - periodRaw, ok = d.GetOk("period") - if ok { - cert.Period = time.Duration(periodRaw.(int)) * time.Second - cert.TokenPeriod = cert.Period - } - } else { - _, ok = d.GetOk("period") - if ok { - cert.Period = cert.TokenPeriod - } else { - cert.Period = 0 - } - } - - boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs") - if !ok { - boundCIDRsRaw, ok = d.GetOk("bound_cidrs") - if ok { - boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw) - if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + if !ok { + ttlRaw, ok := d.GetOk("lease") + if ok { + cert.TTL = time.Duration(ttlRaw.(int)) * time.Second + cert.TokenTTL = cert.TTL } - cert.BoundCIDRs = boundCIDRs - cert.TokenBoundCIDRs = cert.BoundCIDRs - } - } else { - _, ok = d.GetOk("bound_cidrs") - if ok { - cert.BoundCIDRs = cert.TokenBoundCIDRs - } else { - cert.BoundCIDRs = nil } } + if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &cert.MaxTTL, &cert.TokenMaxTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(d, "period", "token_period", &cert.Period, &cert.TokenPeriod); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(d, "bound_cidrs", "token_bound_cidrs", &cert.BoundCIDRs, &cert.TokenBoundCIDRs); err != nil { + return logical.ErrorResponse(err.Error()), nil + } } var resp logical.Response diff --git a/builtin/credential/github/path_config.go b/builtin/credential/github/path_config.go index be793b7978..3f66cefca6 100644 --- a/builtin/credential/github/path_config.go +++ b/builtin/credential/github/path_config.go @@ -86,36 +86,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat // Handle upgrade cases { - ttlRaw, ok := data.GetOk("token_ttl") - if !ok { - ttlRaw, ok = data.GetOk("ttl") - if ok { - c.TTL = time.Duration(ttlRaw.(int)) * time.Second - c.TokenTTL = c.TTL - } - } else { - _, ok = data.GetOk("ttl") - if ok { - c.TTL = c.TokenTTL - } else { - c.TTL = 0 - } + if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &c.TTL, &c.TokenTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } - maxTTLRaw, ok := data.GetOk("token_max_ttl") - if !ok { - maxTTLRaw, ok = data.GetOk("max_ttl") - if ok { - c.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second - c.TokenMaxTTL = c.MaxTTL - } - } else { - _, ok = data.GetOk("max_ttl") - if ok { - c.MaxTTL = c.TokenMaxTTL - } else { - c.MaxTTL = 0 - } + if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &c.MaxTTL, &c.TokenMaxTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } } diff --git a/builtin/credential/okta/path_config.go b/builtin/credential/okta/path_config.go index b17aa59048..3cbf1041e2 100644 --- a/builtin/credential/okta/path_config.go +++ b/builtin/credential/okta/path_config.go @@ -228,36 +228,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d * // Handle upgrade cases { - ttlRaw, ok := d.GetOk("token_ttl") - if !ok { - ttlRaw, ok = d.GetOk("ttl") - if ok { - cfg.TTL = time.Duration(ttlRaw.(int)) * time.Second - cfg.TokenTTL = cfg.TTL - } - } else { - _, ok = d.GetOk("ttl") - if ok { - cfg.TTL = cfg.TokenTTL - } else { - cfg.TTL = 0 - } + if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &cfg.TTL, &cfg.TokenTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } - maxTTLRaw, ok := d.GetOk("token_max_ttl") - if !ok { - maxTTLRaw, ok = d.GetOk("max_ttl") - if ok { - cfg.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second - cfg.TokenMaxTTL = cfg.MaxTTL - } - } else { - _, ok = d.GetOk("max_ttl") - if ok { - cfg.MaxTTL = cfg.TokenMaxTTL - } else { - cfg.MaxTTL = 0 - } + if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &cfg.MaxTTL, &cfg.TokenMaxTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } } diff --git a/builtin/credential/userpass/path_users.go b/builtin/credential/userpass/path_users.go index aff97380ba..ae9af65b80 100644 --- a/builtin/credential/userpass/path_users.go +++ b/builtin/credential/userpass/path_users.go @@ -8,8 +8,6 @@ import ( sockaddr "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/sdk/framework" - "github.com/hashicorp/vault/sdk/helper/parseutil" - "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -210,72 +208,20 @@ func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d // handle upgrade cases { - policiesRaw, ok := d.GetOk("token_policies") - if !ok { - policiesRaw, ok = d.GetOk("policies") - if ok { - userEntry.Policies = policyutil.ParsePolicies(policiesRaw) - userEntry.TokenPolicies = userEntry.Policies - } - } else { - _, ok = d.GetOk("policies") - if ok { - userEntry.Policies = userEntry.TokenPolicies - } else { - userEntry.Policies = nil - } + if err := tokenutil.UpgradeValue(d, "policies", "token_policies", &userEntry.Policies, &userEntry.TokenPolicies); err != nil { + return logical.ErrorResponse(err.Error()), nil } - ttlRaw, ok := d.GetOk("token_ttl") - if !ok { - ttlRaw, ok = d.GetOk("ttl") - if ok { - userEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second - userEntry.TokenTTL = userEntry.TTL - } - } else { - _, ok = d.GetOk("ttl") - if ok { - userEntry.TTL = userEntry.TokenTTL - } else { - userEntry.TTL = 0 - } + if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &userEntry.TTL, &userEntry.TokenTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } - maxTTLRaw, ok := d.GetOk("token_max_ttl") - if !ok { - maxTTLRaw, ok = d.GetOk("max_ttl") - if ok { - userEntry.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second - userEntry.TokenMaxTTL = userEntry.TokenMaxTTL - } - } else { - _, ok = d.GetOk("max_ttl") - if ok { - userEntry.MaxTTL = userEntry.TokenMaxTTL - } else { - userEntry.MaxTTL = 0 - } + if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &userEntry.MaxTTL, &userEntry.TokenMaxTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil } - boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs") - if !ok { - boundCIDRsRaw, ok = d.GetOk("bound_cidrs") - if ok { - boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw) - if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - } - userEntry.BoundCIDRs = boundCIDRs - userEntry.TokenBoundCIDRs = userEntry.BoundCIDRs - } - } else { - _, ok = d.GetOk("bound_cidrs") - if ok { - userEntry.BoundCIDRs = userEntry.TokenBoundCIDRs - } else { - userEntry.BoundCIDRs = nil - } + if err := tokenutil.UpgradeValue(d, "bound_cidrs", "token_bound_cirs", &userEntry.BoundCIDRs, &userEntry.TokenBoundCIDRs); err != nil { + return logical.ErrorResponse(err.Error()), nil } } diff --git a/sdk/helper/tokenutil/tokenutil.go b/sdk/helper/tokenutil/tokenutil.go index ad45ba40db..eb247c0124 100644 --- a/sdk/helper/tokenutil/tokenutil.go +++ b/sdk/helper/tokenutil/tokenutil.go @@ -8,6 +8,7 @@ import ( sockaddr "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/parseutil" + "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -266,6 +267,133 @@ func DeprecationText(param string) string { return fmt.Sprintf("Use %q instead. If this and %q are both specified, only %q will be used.", param, param, param) } +func upgradeDurationValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *time.Duration) error { + _, ok := d.GetOk(newKey) + if !ok { + raw, ok := d.GetOk(oldKey) + if ok { + *oldVal = time.Duration(raw.(int)) * time.Second + *newVal = *oldVal + } + } else { + _, ok = d.GetOk(oldKey) + if ok { + *oldVal = *newVal + } else { + *oldVal = 0 + } + } + + return nil +} + +func upgradeIntValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *int) error { + _, ok := d.GetOk(newKey) + if !ok { + raw, ok := d.GetOk(oldKey) + if ok { + *oldVal = raw.(int) + *newVal = *oldVal + } + } else { + _, ok = d.GetOk(oldKey) + if ok { + *oldVal = *newVal + } else { + *oldVal = 0 + } + } + + return nil +} + +func upgradeStringSliceValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *[]string) error { + _, ok := d.GetOk(newKey) + if !ok { + raw, ok := d.GetOk(oldKey) + if ok { + // Special case: if we're looking at "token_policies" parse the policies + if newKey == "token_policies" { + *oldVal = policyutil.ParsePolicies(raw) + } else { + *oldVal = raw.([]string) + } + *newVal = *oldVal + } + } else { + _, ok = d.GetOk(oldKey) + if ok { + *oldVal = *newVal + } else { + *oldVal = nil + } + } + + return nil +} + +func upgradeSockAddrSliceValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *[]*sockaddr.SockAddrMarshaler) error { + _, ok := d.GetOk(newKey) + if !ok { + raw, ok := d.GetOk(oldKey) + if ok { + boundCIDRs, err := parseutil.ParseAddrs(raw) + if err != nil { + return err + } + *oldVal = boundCIDRs + *newVal = *oldVal + } + } else { + _, ok = d.GetOk(oldKey) + if ok { + *oldVal = *newVal + } else { + *oldVal = nil + } + } + + return nil +} + +// UpgradeValue takes in old/new data keys and old/new values and calls out to +// a helper function to perform upgrades in a standardized way. It reqiures +// pointers in all cases so that we can set directly into the target struct. +func UpgradeValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal interface{}) error { + switch typedOldVal := oldVal.(type) { + case *time.Duration: + typedNewVal, ok := newVal.(*time.Duration) + if !ok { + return errors.New("mismatch in value types in tokenutil.UpgradeValue") + } + return upgradeDurationValue(d, oldKey, newKey, typedOldVal, typedNewVal) + + case *int: + typedNewVal, ok := newVal.(*int) + if !ok { + return errors.New("mismatch in value types in tokenutil.UpgradeValue") + } + return upgradeIntValue(d, oldKey, newKey, typedOldVal, typedNewVal) + + case *[]string: + typedNewVal, ok := newVal.(*[]string) + if !ok { + return errors.New("mismatch in value types in tokenutil.UpgradeValue") + } + return upgradeStringSliceValue(d, oldKey, newKey, typedOldVal, typedNewVal) + + case *[]*sockaddr.SockAddrMarshaler: + typedNewVal, ok := newVal.(*[]*sockaddr.SockAddrMarshaler) + if !ok { + return errors.New("mismatch in value types in tokenutil.UpgradeValue") + } + return upgradeSockAddrSliceValue(d, oldKey, newKey, typedOldVal, typedNewVal) + + default: + return errors.New("unhandled type in tokenutil.UpgradeValue") + } +} + const ( tokenPeriodHelp = `If set, tokens created via this role will have no max lifetime; instead, their