Github backend: enable auth renewals

This commit is contained in:
vishalnayak
2015-10-02 13:33:19 -04:00
parent df90e664ad
commit bf017d28d1
5 changed files with 54 additions and 17 deletions

View File

@@ -35,9 +35,11 @@ func Backend() *framework.Backend {
}, },
Paths: append([]*framework.Path{ Paths: append([]*framework.Path{
pathConfig(), pathConfig(&b),
pathLogin(&b), pathLogin(&b),
}, b.Map.Paths()...), }, b.Map.Paths()...),
AuthRenew: b.pathLoginRenew,
} }
return b.Backend return b.Backend

View File

@@ -3,12 +3,13 @@ package github
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"time"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
func pathConfig() *framework.Path { func pathConfig(b *backend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "config", Pattern: "config",
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
@@ -23,29 +24,47 @@ func pathConfig() *framework.Path {
are running GitHub Enterprise or an are running GitHub Enterprise or an
API-compatible authentication server.`, API-compatible authentication server.`,
}, },
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Duration after which authentication will be expired`,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum duration after which authentication will be expired`,
},
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathConfigWrite, logical.WriteOperation: b.pathConfigWrite,
}, },
} }
} }
func pathConfigWrite( func (b *backend) pathConfigWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
conf := config{ organization := data.Get("organization").(string)
Org: data.Get("organization").(string),
}
baseURL := data.Get("base_url").(string) baseURL := data.Get("base_url").(string)
if len(baseURL) != 0 { if len(baseURL) != 0 {
_, err := url.Parse(baseURL) _, err := url.Parse(baseURL)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil
} }
conf.BaseURL = baseURL
} }
entry, err := logical.StorageEntryJSON("config", conf) ttlStr := data.Get("ttl").(string)
maxTTLStr := data.Get("max_ttl").(string)
ttl, maxTTL, err := b.SanitizeTTL(ttlStr, maxTTLStr)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
}
entry, err := logical.StorageEntryJSON("config", config{
Org: organization,
BaseURL: baseURL,
TTL: ttl,
MaxTTL: maxTTL,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -75,6 +94,8 @@ func (b *backend) Config(s logical.Storage) (*config, error) {
} }
type config struct { type config struct {
Org string `json:"organization"` Org string `json:"organization"`
BaseURL string `json:"base_url"` BaseURL string `json:"base_url"`
TTL time.Duration `json:"ttl"`
MaxTTL time.Duration `json:"max_ttl"`
} }

View File

@@ -132,6 +132,20 @@ func (b *backend) pathLogin(
"org": *org.Login, "org": *org.Login,
}, },
DisplayName: *user.Login, DisplayName: *user.Login,
LeaseOptions: logical.LeaseOptions{
TTL: config.TTL,
GracePeriod: config.TTL / 10,
Renewable: config.TTL > 0,
},
}, },
}, nil }, nil
} }
func (b *backend) pathLoginRenew(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
config, err := b.Config(req.Storage)
if err != nil {
return nil, err
}
return framework.LeaseExtend(config.MaxTTL, 0, false)(req, d)
}

View File

@@ -33,6 +33,7 @@ type AuthCommand struct {
} }
func (c *AuthCommand) Run(args []string) int { func (c *AuthCommand) Run(args []string) int {
var format string
var method string var method string
var methods, methodHelp, noVerify bool var methods, methodHelp, noVerify bool
flags := c.Meta.FlagSet("auth", FlagSetDefault) flags := c.Meta.FlagSet("auth", FlagSetDefault)
@@ -40,6 +41,7 @@ func (c *AuthCommand) Run(args []string) int {
flags.BoolVar(&methodHelp, "method-help", false, "") flags.BoolVar(&methodHelp, "method-help", false, "")
flags.BoolVar(&noVerify, "no-verify", false, "") flags.BoolVar(&noVerify, "no-verify", false, "")
flags.StringVar(&method, "method", "", "method") flags.StringVar(&method, "method", "", "method")
flags.StringVar(&format, "format", "table", "")
flags.Usage = func() { c.Ui.Error(c.Help()) } flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil { if err := flags.Parse(args); err != nil {
return 1 return 1
@@ -202,12 +204,10 @@ func (c *AuthCommand) Run(args []string) int {
} }
output := "Successfully authenticated!" output := "Successfully authenticated!"
if secret.LeaseDuration > 0 { output += fmt.Sprintf("\ntoken: %s", secret.Data["id"])
output += fmt.Sprintf("\nThe token's lifetime is %d seconds.", secret.LeaseDuration) output += fmt.Sprintf("\ntoken_duration: %d", int(secret.Data["ttl"].(float64)))
}
if len(policies) > 0 { if len(policies) > 0 {
output += fmt.Sprintf("\nThe policies that are associated with this token\narelisted below:\n\n%s", strings.Join(policies, ", ")) output += fmt.Sprintf("\ntoken_policies: [%s]", strings.Join(policies, ", "))
} }
c.Ui.Output(output) c.Ui.Output(output)

View File

@@ -721,7 +721,7 @@ func (ts *TokenStore) handleLookup(
} }
// Generate a response. We purposely omit the parent reference otherwise // Generate a response. We purposely omit the parent reference otherwise
// you could escalade your privileges. // you could escalate your privileges.
resp := &logical.Response{ resp := &logical.Response{
Data: map[string]interface{}{ Data: map[string]interface{}{
"id": out.ID, "id": out.ID,