From 06f7fb5dc3df677bb74cdb82a275f8e7bb54a609 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 28 Aug 2015 06:28:35 -0700 Subject: [PATCH] Add base_url option to GitHub auth provider to allow selecting a custom endpoint. Fixes #572. --- builtin/credential/github/backend_test.go | 25 ++++++++++++++++++++--- builtin/credential/github/path_config.go | 25 ++++++++++++++++++++--- builtin/credential/github/path_login.go | 12 ++++++++++- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/builtin/credential/github/backend_test.go b/builtin/credential/github/backend_test.go index 9f3ed114cf..bb9373a23a 100644 --- a/builtin/credential/github/backend_test.go +++ b/builtin/credential/github/backend_test.go @@ -14,9 +14,13 @@ func TestBackend_basic(t *testing.T) { Backend: Backend(), Steps: []logicaltest.TestStep{ testAccStepConfig(t), - testAccMap(t, "default", "foo"), - testAccMap(t, "oWnErs", "bar"), - testAccLogin(t, []string{"bar", "foo"}), + testAccMap(t, "default", "root"), + testAccMap(t, "oWnErs", "root"), + testAccLogin(t, []string{"root"}), + testAccStepConfigWithBaseURL(t), + testAccMap(t, "default", "root"), + testAccMap(t, "oWnErs", "root"), + testAccLogin(t, []string{"root"}), }, }) } @@ -29,6 +33,10 @@ func testAccPreCheck(t *testing.T) { if v := os.Getenv("GITHUB_ORG"); v == "" { t.Fatal("GITHUB_ORG must be set for acceptance tests") } + + if v := os.Getenv("GITHUB_BASEURL"); v == "" { + t.Fatal("GITHUB_BASEURL must be set for acceptance tests (use 'https://api.github.com' if you don't know what you're doing)") + } } func testAccStepConfig(t *testing.T) logicaltest.TestStep { @@ -41,6 +49,17 @@ func testAccStepConfig(t *testing.T) logicaltest.TestStep { } } +func testAccStepConfigWithBaseURL(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "config", + Data: map[string]interface{}{ + "organization": os.Getenv("GITHUB_ORG"), + "base_url": os.Getenv("GITHUB_BASEURL"), + }, + } +} + func testAccMap(t *testing.T, k string, v string) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.WriteOperation, diff --git a/builtin/credential/github/path_config.go b/builtin/credential/github/path_config.go index 0c6db566ae..b923e57a9d 100644 --- a/builtin/credential/github/path_config.go +++ b/builtin/credential/github/path_config.go @@ -2,6 +2,7 @@ package github import ( "fmt" + "net/url" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -15,6 +16,13 @@ func pathConfig() *framework.Path { Type: framework.TypeString, Description: "The organization users must be part of", }, + + "base_url": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `The API endpoint to use. Useful if you +are running GitHub Enterprise or an +API-compatible authentication server.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -25,9 +33,19 @@ func pathConfig() *framework.Path { func pathConfigWrite( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - entry, err := logical.StorageEntryJSON("config", config{ + conf := config{ Org: data.Get("organization").(string), - }) + } + baseURL := data.Get("base_url").(string) + if len(baseURL) != 0 { + _, err := url.Parse(baseURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil + } + conf.BaseURL = baseURL + } + + entry, err := logical.StorageEntryJSON("config", conf) if err != nil { return nil, err } @@ -57,5 +75,6 @@ func (b *backend) Config(s logical.Storage) (*config, error) { } type config struct { - Org string `json:"organization"` + Org string `json:"organization"` + BaseURL string `json:"base_url"` } diff --git a/builtin/credential/github/path_login.go b/builtin/credential/github/path_login.go index 69b815b191..854fa3cd81 100644 --- a/builtin/credential/github/path_login.go +++ b/builtin/credential/github/path_login.go @@ -1,6 +1,9 @@ package github import ( + "fmt" + "net/url" + "github.com/google/go-github/github" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -39,6 +42,14 @@ func (b *backend) pathLogin( return nil, err } + if config.BaseURL != "" { + parsedURL, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("Successfully parsed base_url when set but failing to parse now: %s", err) + } + client.BaseURL = parsedURL + } + // Get the user user, _, err := client.Users.Get("") if err != nil { @@ -108,7 +119,6 @@ func (b *backend) pathLogin( } } - policiesList, err := b.Map.Policies(req.Storage, teamNames...) if err != nil { return nil, err