Add base_url option to GitHub auth provider to allow selecting a custom endpoint. Fixes #572.

This commit is contained in:
Jeff Mitchell
2015-08-28 06:28:35 -07:00
parent 4d3f68a631
commit 06f7fb5dc3
3 changed files with 55 additions and 7 deletions

View File

@@ -14,9 +14,13 @@ func TestBackend_basic(t *testing.T) {
Backend: Backend(), Backend: Backend(),
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepConfig(t), testAccStepConfig(t),
testAccMap(t, "default", "foo"), testAccMap(t, "default", "root"),
testAccMap(t, "oWnErs", "bar"), testAccMap(t, "oWnErs", "root"),
testAccLogin(t, []string{"bar", "foo"}), testAccLogin(t, []string{"root"}),
testAccStepConfigWithBaseURL(t),
testAccMap(t, "default", "root"),
testAccMap(t, "oWnErs", "root"),
testAccLogin(t, []string{"root"}),
}, },
}) })
} }
@@ -29,6 +33,10 @@ func testAccPreCheck(t *testing.T) {
if v := os.Getenv("GITHUB_ORG"); v == "" { if v := os.Getenv("GITHUB_ORG"); v == "" {
t.Fatal("GITHUB_ORG must be set for acceptance tests") t.Fatal("GITHUB_ORG must be set for acceptance tests")
} }
if v := os.Getenv("GITHUB_BASEURL"); v == "" {
t.Fatal("GITHUB_BASEURL must be set for acceptance tests (use 'https://api.github.com' if you don't know what you're doing)")
}
} }
func testAccStepConfig(t *testing.T) logicaltest.TestStep { func testAccStepConfig(t *testing.T) logicaltest.TestStep {
@@ -41,6 +49,17 @@ func testAccStepConfig(t *testing.T) logicaltest.TestStep {
} }
} }
func testAccStepConfigWithBaseURL(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "config",
Data: map[string]interface{}{
"organization": os.Getenv("GITHUB_ORG"),
"base_url": os.Getenv("GITHUB_BASEURL"),
},
}
}
func testAccMap(t *testing.T, k string, v string) logicaltest.TestStep { func testAccMap(t *testing.T, k string, v string) logicaltest.TestStep {
return logicaltest.TestStep{ return logicaltest.TestStep{
Operation: logical.WriteOperation, Operation: logical.WriteOperation,

View File

@@ -2,6 +2,7 @@ package github
import ( import (
"fmt" "fmt"
"net/url"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@@ -15,6 +16,13 @@ func pathConfig() *framework.Path {
Type: framework.TypeString, Type: framework.TypeString,
Description: "The organization users must be part of", Description: "The organization users must be part of",
}, },
"base_url": &framework.FieldSchema{
Type: framework.TypeString,
Description: `The API endpoint to use. Useful if you
are running GitHub Enterprise or an
API-compatible authentication server.`,
},
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -25,9 +33,19 @@ func pathConfig() *framework.Path {
func pathConfigWrite( func pathConfigWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := logical.StorageEntryJSON("config", config{ conf := config{
Org: data.Get("organization").(string), Org: data.Get("organization").(string),
}) }
baseURL := data.Get("base_url").(string)
if len(baseURL) != 0 {
_, err := url.Parse(baseURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil
}
conf.BaseURL = baseURL
}
entry, err := logical.StorageEntryJSON("config", conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,4 +76,5 @@ func (b *backend) Config(s logical.Storage) (*config, error) {
type config struct { type config struct {
Org string `json:"organization"` Org string `json:"organization"`
BaseURL string `json:"base_url"`
} }

View File

@@ -1,6 +1,9 @@
package github package github
import ( import (
"fmt"
"net/url"
"github.com/google/go-github/github" "github.com/google/go-github/github"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@@ -39,6 +42,14 @@ func (b *backend) pathLogin(
return nil, err return nil, err
} }
if config.BaseURL != "" {
parsedURL, err := url.Parse(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("Successfully parsed base_url when set but failing to parse now: %s", err)
}
client.BaseURL = parsedURL
}
// Get the user // Get the user
user, _, err := client.Users.Get("") user, _, err := client.Users.Get("")
if err != nil { if err != nil {
@@ -108,7 +119,6 @@ func (b *backend) pathLogin(
} }
} }
policiesList, err := b.Map.Policies(req.Storage, teamNames...) policiesList, err := b.Map.Policies(req.Storage, teamNames...)
if err != nil { if err != nil {
return nil, err return nil, err