mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Add UpgradeValue path to tokenutil (#7041)
This drastically reduces boilerplate for upgrading existing values
This commit is contained in:
		@@ -954,36 +954,12 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// handle upgrade cases
 | 
						// handle upgrade cases
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		policiesRaw, ok := data.GetOk("token_policies")
 | 
							if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			policiesRaw, ok = data.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				role.Policies = policyutil.ParsePolicies(policiesRaw)
 | 
					 | 
				
			||||||
				role.TokenPolicies = role.Policies
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				role.Policies = role.TokenPolicies
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				role.Policies = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		periodRaw, ok := data.GetOk("token_period")
 | 
							if err := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			periodRaw, ok = data.GetOk("period")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				role.Period = time.Duration(periodRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				role.TokenPeriod = role.Period
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("period")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				role.Period = role.TokenPeriod
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				role.Period = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,7 +11,6 @@ import (
 | 
				
			|||||||
	uuid "github.com/hashicorp/go-uuid"
 | 
						uuid "github.com/hashicorp/go-uuid"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/framework"
 | 
						"github.com/hashicorp/vault/sdk/framework"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/consts"
 | 
						"github.com/hashicorp/vault/sdk/helper/consts"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/policyutil"
 | 
					 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/tokenutil"
 | 
						"github.com/hashicorp/vault/sdk/helper/tokenutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
	"github.com/mitchellh/copystructure"
 | 
						"github.com/mitchellh/copystructure"
 | 
				
			||||||
@@ -740,71 +739,32 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Handle upgrade cases
 | 
						// Handle upgrade cases
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		policiesRaw, ok := data.GetOk("token_policies")
 | 
							if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &roleEntry.Policies, &roleEntry.TokenPolicies); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			policiesRaw, ok = data.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.Policies = policyutil.ParsePolicies(policiesRaw)
 | 
					 | 
				
			||||||
				roleEntry.TokenPolicies = roleEntry.Policies
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.Policies = roleEntry.TokenPolicies
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				roleEntry.Policies = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		ttlRaw, ok := data.GetOk("token_ttl")
 | 
							if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &roleEntry.TTL, &roleEntry.TokenTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			ttlRaw, ok = data.GetOk("ttl")
 | 
					 | 
				
			||||||
			if !ok {
 | 
					 | 
				
			||||||
				ttlRaw, ok = data.GetOk("lease")
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							// Special case here for old lease value
 | 
				
			||||||
 | 
							_, ok := data.GetOk("token_ttl")
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								_, ok = data.GetOk("ttl")
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									ttlRaw, ok := data.GetOk("lease")
 | 
				
			||||||
				if ok {
 | 
									if ok {
 | 
				
			||||||
					roleEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
										roleEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
				
			||||||
					roleEntry.TokenTTL = roleEntry.TTL
 | 
										roleEntry.TokenTTL = roleEntry.TTL
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.TTL = roleEntry.TokenTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				roleEntry.TTL = 0
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		maxTTLRaw, ok := data.GetOk("token_max_ttl")
 | 
							if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &roleEntry.MaxTTL, &roleEntry.TokenMaxTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			maxTTLRaw, ok = data.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				roleEntry.TokenMaxTTL = roleEntry.MaxTTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.MaxTTL = roleEntry.TokenMaxTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				roleEntry.MaxTTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		periodRaw, ok := data.GetOk("token_period")
 | 
							if err := tokenutil.UpgradeValue(data, "period", "token_period", &roleEntry.Period, &roleEntry.TokenPeriod); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			periodRaw, ok = data.GetOk("period")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.Period = time.Duration(periodRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				roleEntry.TokenPeriod = roleEntry.Period
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("period")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				roleEntry.Period = roleEntry.TokenPeriod
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				roleEntry.Period = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,8 +9,6 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	sockaddr "github.com/hashicorp/go-sockaddr"
 | 
						sockaddr "github.com/hashicorp/go-sockaddr"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/framework"
 | 
						"github.com/hashicorp/vault/sdk/framework"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/parseutil"
 | 
					 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/policyutil"
 | 
					 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/tokenutil"
 | 
						"github.com/hashicorp/vault/sdk/helper/tokenutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -287,93 +285,37 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Handle upgrade cases
 | 
						// Handle upgrade cases
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		policiesRaw, ok := d.GetOk("token_policies")
 | 
							if err := tokenutil.UpgradeValue(d, "policies", "token_policies", &cert.Policies, &cert.TokenPolicies); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			policiesRaw, ok = d.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.Policies = policyutil.ParsePolicies(policiesRaw)
 | 
					 | 
				
			||||||
				cert.TokenPolicies = cert.Policies
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.Policies = cert.TokenPolicies
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cert.Policies = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		ttlRaw, ok := d.GetOk("token_ttl")
 | 
							if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &cert.TTL, &cert.TokenTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			ttlRaw, ok = d.GetOk("ttl")
 | 
					 | 
				
			||||||
			if !ok {
 | 
					 | 
				
			||||||
				ttlRaw, ok = d.GetOk("lease")
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							// Special case here for old lease value
 | 
				
			||||||
 | 
							_, ok := d.GetOk("token_ttl")
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								_, ok = d.GetOk("ttl")
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									ttlRaw, ok := d.GetOk("lease")
 | 
				
			||||||
				if ok {
 | 
									if ok {
 | 
				
			||||||
					cert.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
										cert.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
				
			||||||
					cert.TokenTTL = cert.TTL
 | 
										cert.TokenTTL = cert.TTL
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.TTL = cert.TokenTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cert.TTL = 0
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		maxTTLRaw, ok := d.GetOk("token_max_ttl")
 | 
							if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &cert.MaxTTL, &cert.TokenMaxTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			maxTTLRaw, ok = d.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				cert.TokenMaxTTL = cert.MaxTTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.MaxTTL = cert.TokenMaxTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cert.MaxTTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		periodRaw, ok := d.GetOk("token_period")
 | 
							if err := tokenutil.UpgradeValue(d, "period", "token_period", &cert.Period, &cert.TokenPeriod); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			periodRaw, ok = d.GetOk("period")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.Period = time.Duration(periodRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				cert.TokenPeriod = cert.Period
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("period")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.Period = cert.TokenPeriod
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cert.Period = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs")
 | 
							if err := tokenutil.UpgradeValue(d, "bound_cidrs", "token_bound_cidrs", &cert.BoundCIDRs, &cert.TokenBoundCIDRs); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			boundCIDRsRaw, ok = d.GetOk("bound_cidrs")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw)
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
				cert.BoundCIDRs = boundCIDRs
 | 
					 | 
				
			||||||
				cert.TokenBoundCIDRs = cert.BoundCIDRs
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("bound_cidrs")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cert.BoundCIDRs = cert.TokenBoundCIDRs
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cert.BoundCIDRs = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var resp logical.Response
 | 
						var resp logical.Response
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -86,36 +86,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Handle upgrade cases
 | 
						// Handle upgrade cases
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		ttlRaw, ok := data.GetOk("token_ttl")
 | 
							if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &c.TTL, &c.TokenTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			ttlRaw, ok = data.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				c.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				c.TokenTTL = c.TTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				c.TTL = c.TokenTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				c.TTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		maxTTLRaw, ok := data.GetOk("token_max_ttl")
 | 
							if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &c.MaxTTL, &c.TokenMaxTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			maxTTLRaw, ok = data.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				c.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				c.TokenMaxTTL = c.MaxTTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = data.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				c.MaxTTL = c.TokenMaxTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				c.MaxTTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -228,36 +228,12 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Handle upgrade cases
 | 
						// Handle upgrade cases
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		ttlRaw, ok := d.GetOk("token_ttl")
 | 
							if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &cfg.TTL, &cfg.TokenTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			ttlRaw, ok = d.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cfg.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				cfg.TokenTTL = cfg.TTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cfg.TTL = cfg.TokenTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cfg.TTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		maxTTLRaw, ok := d.GetOk("token_max_ttl")
 | 
							if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &cfg.MaxTTL, &cfg.TokenMaxTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			maxTTLRaw, ok = d.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cfg.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				cfg.TokenMaxTTL = cfg.MaxTTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				cfg.MaxTTL = cfg.TokenMaxTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				cfg.MaxTTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,8 +8,6 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	sockaddr "github.com/hashicorp/go-sockaddr"
 | 
						sockaddr "github.com/hashicorp/go-sockaddr"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/framework"
 | 
						"github.com/hashicorp/vault/sdk/framework"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/parseutil"
 | 
					 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/policyutil"
 | 
					 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/tokenutil"
 | 
						"github.com/hashicorp/vault/sdk/helper/tokenutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -210,72 +208,20 @@ func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// handle upgrade cases
 | 
						// handle upgrade cases
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		policiesRaw, ok := d.GetOk("token_policies")
 | 
							if err := tokenutil.UpgradeValue(d, "policies", "token_policies", &userEntry.Policies, &userEntry.TokenPolicies); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			policiesRaw, ok = d.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.Policies = policyutil.ParsePolicies(policiesRaw)
 | 
					 | 
				
			||||||
				userEntry.TokenPolicies = userEntry.Policies
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("policies")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.Policies = userEntry.TokenPolicies
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				userEntry.Policies = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		ttlRaw, ok := d.GetOk("token_ttl")
 | 
							if err := tokenutil.UpgradeValue(d, "ttl", "token_ttl", &userEntry.TTL, &userEntry.TokenTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			ttlRaw, ok = d.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.TTL = time.Duration(ttlRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				userEntry.TokenTTL = userEntry.TTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.TTL = userEntry.TokenTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				userEntry.TTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		maxTTLRaw, ok := d.GetOk("token_max_ttl")
 | 
							if err := tokenutil.UpgradeValue(d, "max_ttl", "token_max_ttl", &userEntry.MaxTTL, &userEntry.TokenMaxTTL); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			maxTTLRaw, ok = d.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second
 | 
					 | 
				
			||||||
				userEntry.TokenMaxTTL = userEntry.TokenMaxTTL
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("max_ttl")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.MaxTTL = userEntry.TokenMaxTTL
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				userEntry.MaxTTL = 0
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		boundCIDRsRaw, ok := d.GetOk("token_bound_cidrs")
 | 
							if err := tokenutil.UpgradeValue(d, "bound_cidrs", "token_bound_cirs", &userEntry.BoundCIDRs, &userEntry.TokenBoundCIDRs); err != nil {
 | 
				
			||||||
		if !ok {
 | 
								return logical.ErrorResponse(err.Error()), nil
 | 
				
			||||||
			boundCIDRsRaw, ok = d.GetOk("bound_cidrs")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				boundCIDRs, err := parseutil.ParseAddrs(boundCIDRsRaw)
 | 
					 | 
				
			||||||
				if err != nil {
 | 
					 | 
				
			||||||
					return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				userEntry.BoundCIDRs = boundCIDRs
 | 
					 | 
				
			||||||
				userEntry.TokenBoundCIDRs = userEntry.BoundCIDRs
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			_, ok = d.GetOk("bound_cidrs")
 | 
					 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
				userEntry.BoundCIDRs = userEntry.TokenBoundCIDRs
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				userEntry.BoundCIDRs = nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
	sockaddr "github.com/hashicorp/go-sockaddr"
 | 
						sockaddr "github.com/hashicorp/go-sockaddr"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/framework"
 | 
						"github.com/hashicorp/vault/sdk/framework"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/parseutil"
 | 
						"github.com/hashicorp/vault/sdk/helper/parseutil"
 | 
				
			||||||
 | 
						"github.com/hashicorp/vault/sdk/helper/policyutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/helper/strutil"
 | 
						"github.com/hashicorp/vault/sdk/helper/strutil"
 | 
				
			||||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
						"github.com/hashicorp/vault/sdk/logical"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -266,6 +267,133 @@ func DeprecationText(param string) string {
 | 
				
			|||||||
	return fmt.Sprintf("Use %q instead. If this and %q are both specified, only %q will be used.", param, param, param)
 | 
						return fmt.Sprintf("Use %q instead. If this and %q are both specified, only %q will be used.", param, param, param)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func upgradeDurationValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *time.Duration) error {
 | 
				
			||||||
 | 
						_, ok := d.GetOk(newKey)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							raw, ok := d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								*oldVal = time.Duration(raw.(int)) * time.Second
 | 
				
			||||||
 | 
								*newVal = *oldVal
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							_, ok = d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								*oldVal = *newVal
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								*oldVal = 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func upgradeIntValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *int) error {
 | 
				
			||||||
 | 
						_, ok := d.GetOk(newKey)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							raw, ok := d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								*oldVal = raw.(int)
 | 
				
			||||||
 | 
								*newVal = *oldVal
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							_, ok = d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								*oldVal = *newVal
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								*oldVal = 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func upgradeStringSliceValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *[]string) error {
 | 
				
			||||||
 | 
						_, ok := d.GetOk(newKey)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							raw, ok := d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								// Special case: if we're looking at "token_policies" parse the policies
 | 
				
			||||||
 | 
								if newKey == "token_policies" {
 | 
				
			||||||
 | 
									*oldVal = policyutil.ParsePolicies(raw)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									*oldVal = raw.([]string)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								*newVal = *oldVal
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							_, ok = d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								*oldVal = *newVal
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								*oldVal = nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func upgradeSockAddrSliceValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal *[]*sockaddr.SockAddrMarshaler) error {
 | 
				
			||||||
 | 
						_, ok := d.GetOk(newKey)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							raw, ok := d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								boundCIDRs, err := parseutil.ParseAddrs(raw)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								*oldVal = boundCIDRs
 | 
				
			||||||
 | 
								*newVal = *oldVal
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							_, ok = d.GetOk(oldKey)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								*oldVal = *newVal
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								*oldVal = nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UpgradeValue takes in old/new data keys and old/new values and calls out to
 | 
				
			||||||
 | 
					// a helper function to perform upgrades in a standardized way. It reqiures
 | 
				
			||||||
 | 
					// pointers in all cases so that we can set directly into the target struct.
 | 
				
			||||||
 | 
					func UpgradeValue(d *framework.FieldData, oldKey, newKey string, oldVal, newVal interface{}) error {
 | 
				
			||||||
 | 
						switch typedOldVal := oldVal.(type) {
 | 
				
			||||||
 | 
						case *time.Duration:
 | 
				
			||||||
 | 
							typedNewVal, ok := newVal.(*time.Duration)
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return errors.New("mismatch in value types in tokenutil.UpgradeValue")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return upgradeDurationValue(d, oldKey, newKey, typedOldVal, typedNewVal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						case *int:
 | 
				
			||||||
 | 
							typedNewVal, ok := newVal.(*int)
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return errors.New("mismatch in value types in tokenutil.UpgradeValue")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return upgradeIntValue(d, oldKey, newKey, typedOldVal, typedNewVal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						case *[]string:
 | 
				
			||||||
 | 
							typedNewVal, ok := newVal.(*[]string)
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return errors.New("mismatch in value types in tokenutil.UpgradeValue")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return upgradeStringSliceValue(d, oldKey, newKey, typedOldVal, typedNewVal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						case *[]*sockaddr.SockAddrMarshaler:
 | 
				
			||||||
 | 
							typedNewVal, ok := newVal.(*[]*sockaddr.SockAddrMarshaler)
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return errors.New("mismatch in value types in tokenutil.UpgradeValue")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return upgradeSockAddrSliceValue(d, oldKey, newKey, typedOldVal, typedNewVal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return errors.New("unhandled type in tokenutil.UpgradeValue")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	tokenPeriodHelp = `If set, tokens created via this role
 | 
						tokenPeriodHelp = `If set, tokens created via this role
 | 
				
			||||||
will have no max lifetime; instead, their
 | 
					will have no max lifetime; instead, their
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user