mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	github auth: use org id to verify creds (#13332)
* github auth: use org id to verify creds * add check for required org param; add test case * update UTs * add nil check for org * add changelog * fix typo in ut * set org ID if it is unset; add more ut coverage * add optional organization_id * move client instantiation * refactor parse URL; add UT for setting org ID * fix comment in UT * add nil check * don't update org name on change; return warning * refactor verifyCredentials * error when unable to fetch org ID on config write; add warnings * fix bug in log message * update UT and small refactor * update comments and log msg * use getter for org ID
This commit is contained in:
		 John-Michael Faircloth
					John-Michael Faircloth
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							9b86b8cf90
						
					
				
				
					commit
					524ded982b
				
			| @@ -2,6 +2,7 @@ package github | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -70,6 +71,9 @@ func testLoginWrite(t *testing.T, d map[string]interface{}, expectedTTL time.Dur | |||||||
| 		ErrorOk:   true, | 		ErrorOk:   true, | ||||||
| 		Data:      d, | 		Data:      d, | ||||||
| 		Check: func(resp *logical.Response) error { | 		Check: func(resp *logical.Response) error { | ||||||
|  | 			if resp == nil { | ||||||
|  | 				return errors.New("expected a response but got nil") | ||||||
|  | 			} | ||||||
| 			if resp.IsError() && expectFail { | 			if resp.IsError() && expectFail { | ||||||
| 				return nil | 				return nil | ||||||
| 			} | 			} | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/google/go-github/github" | ||||||
| 	"github.com/hashicorp/vault/sdk/framework" | 	"github.com/hashicorp/vault/sdk/framework" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/tokenutil" | 	"github.com/hashicorp/vault/sdk/helper/tokenutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| @@ -19,8 +20,12 @@ func pathConfig(b *backend) *framework.Path { | |||||||
| 			"organization": { | 			"organization": { | ||||||
| 				Type:        framework.TypeString, | 				Type:        framework.TypeString, | ||||||
| 				Description: "The organization users must be part of", | 				Description: "The organization users must be part of", | ||||||
|  | 				Required:    true, | ||||||
|  | 			}, | ||||||
|  | 			"organization_id": { | ||||||
|  | 				Type:        framework.TypeInt64, | ||||||
|  | 				Description: "The ID of the organization users must be part of", | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| 			"base_url": { | 			"base_url": { | ||||||
| 				Type: framework.TypeString, | 				Type: framework.TypeString, | ||||||
| 				Description: `The API endpoint to use. Useful if you | 				Description: `The API endpoint to use. Useful if you | ||||||
| @@ -55,6 +60,7 @@ API-compatible authentication server.`, | |||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	var resp logical.Response | ||||||
| 	c, err := b.Config(ctx, req.Storage) | 	c, err := b.Config(ctx, req.Storage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -66,19 +72,47 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat | |||||||
| 	if organizationRaw, ok := data.GetOk("organization"); ok { | 	if organizationRaw, ok := data.GetOk("organization"); ok { | ||||||
| 		c.Organization = organizationRaw.(string) | 		c.Organization = organizationRaw.(string) | ||||||
| 	} | 	} | ||||||
|  | 	if c.Organization == "" { | ||||||
|  | 		return logical.ErrorResponse("organization is a required parameter"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if organizationRaw, ok := data.GetOk("organization_id"); ok { | ||||||
|  | 		c.OrganizationID = organizationRaw.(int64) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var parsedURL *url.URL | ||||||
| 	if baseURLRaw, ok := data.GetOk("base_url"); ok { | 	if baseURLRaw, ok := data.GetOk("base_url"); ok { | ||||||
| 		baseURL := baseURLRaw.(string) | 		baseURL := baseURLRaw.(string) | ||||||
| 		_, err := url.Parse(baseURL) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return logical.ErrorResponse(fmt.Sprintf("Error parsing given base_url: %s", err)), nil |  | ||||||
| 		} |  | ||||||
| 		if !strings.HasSuffix(baseURL, "/") { | 		if !strings.HasSuffix(baseURL, "/") { | ||||||
| 			baseURL += "/" | 			baseURL += "/" | ||||||
| 		} | 		} | ||||||
|  | 		parsedURL, err = url.Parse(baseURL) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return logical.ErrorResponse(fmt.Sprintf("error parsing given base_url: %s", err)), nil | ||||||
|  | 		} | ||||||
| 		c.BaseURL = baseURL | 		c.BaseURL = baseURL | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if c.OrganizationID == 0 { | ||||||
|  | 		client, err := b.Client("") | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		// ensure our client has the BaseURL if it was provided | ||||||
|  | 		if parsedURL != nil { | ||||||
|  | 			client.BaseURL = parsedURL | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// we want to set the Org ID in the config so we can use that to verify | ||||||
|  | 		// the credentials on login | ||||||
|  | 		err = c.setOrganizationID(ctx, client) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errorMsg := fmt.Errorf("unable to fetch the organization_id, you must manually set it in the config: %s", err) | ||||||
|  | 			b.Logger().Error(errorMsg.Error()) | ||||||
|  | 			return nil, errorMsg | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if err := c.ParseTokenFields(req, data); err != nil { | 	if err := c.ParseTokenFields(req, data); err != nil { | ||||||
| 		return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest | 		return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest | ||||||
| 	} | 	} | ||||||
| @@ -103,7 +137,11 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil, nil | 	if len(resp.Warnings) == 0 { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| @@ -116,8 +154,9 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	d := map[string]interface{}{ | 	d := map[string]interface{}{ | ||||||
| 		"organization": config.Organization, | 		"organization_id": config.OrganizationID, | ||||||
| 		"base_url":     config.BaseURL, | 		"organization":    config.Organization, | ||||||
|  | 		"base_url":        config.BaseURL, | ||||||
| 	} | 	} | ||||||
| 	config.PopulateTokenData(d) | 	config.PopulateTokenData(d) | ||||||
|  |  | ||||||
| @@ -163,8 +202,25 @@ func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error | |||||||
| type config struct { | type config struct { | ||||||
| 	tokenutil.TokenParams | 	tokenutil.TokenParams | ||||||
|  |  | ||||||
| 	Organization string        `json:"organization" structs:"organization" mapstructure:"organization"` | 	OrganizationID int64         `json:"organization_id" structs:"organization_id" mapstructure:"organization_id"` | ||||||
| 	BaseURL      string        `json:"base_url" structs:"base_url" mapstructure:"base_url"` | 	Organization   string        `json:"organization" structs:"organization" mapstructure:"organization"` | ||||||
| 	TTL          time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` | 	BaseURL        string        `json:"base_url" structs:"base_url" mapstructure:"base_url"` | ||||||
| 	MaxTTL       time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"` | 	TTL            time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` | ||||||
|  | 	MaxTTL         time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *config) setOrganizationID(ctx context.Context, client *github.Client) error { | ||||||
|  | 	org, _, err := client.Organizations.Get(ctx, c.Organization) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	orgID := org.GetID() | ||||||
|  | 	if orgID == 0 { | ||||||
|  | 		return fmt.Errorf("organization_id not found for %s", c.Organization) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.OrganizationID = orgID | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										214
									
								
								builtin/credential/github/path_config_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								builtin/credential/github/path_config_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,214 @@ | |||||||
|  | package github | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { | ||||||
|  | 	t.Helper() | ||||||
|  | 	config := logical.TestBackendConfig() | ||||||
|  | 	config.StorageView = &logical.InmemStorage{} | ||||||
|  |  | ||||||
|  | 	b := Backend() | ||||||
|  | 	if b == nil { | ||||||
|  | 		t.Fatalf("failed to create backend") | ||||||
|  | 	} | ||||||
|  | 	err := b.Backend.Setup(context.Background(), config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	return b, config.StorageView | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // setupTestServer configures httptest server to intercept and respond to the | ||||||
|  | // request to base_url | ||||||
|  | func setupTestServer(t *testing.T) *httptest.Server { | ||||||
|  | 	t.Helper() | ||||||
|  | 	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		var resp string | ||||||
|  | 		if strings.Contains(r.URL.String(), "/user/orgs") { | ||||||
|  | 			resp = string(listOrgResponse) | ||||||
|  | 		} else if strings.Contains(r.URL.String(), "/user/teams") { | ||||||
|  | 			resp = string(listUserTeamsResponse) | ||||||
|  | 		} else if strings.Contains(r.URL.String(), "/user") { | ||||||
|  | 			resp = getUserResponse | ||||||
|  | 		} else if strings.Contains(r.URL.String(), "/orgs/") { | ||||||
|  | 			resp = getOrgResponse | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		w.Header().Add("Content-Type", "application/json") | ||||||
|  | 		fmt.Fprintln(w, resp) | ||||||
|  | 	})) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_WriteReadConfig tests that we can successfully read and write | ||||||
|  | // the github auth config | ||||||
|  | func TestGitHub_WriteReadConfig(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	// use a test server to return our mock GH org info | ||||||
|  | 	ts := setupTestServer(t) | ||||||
|  | 	defer ts.Close() | ||||||
|  |  | ||||||
|  | 	// Write the config | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"organization": "foo-org", | ||||||
|  | 			"base_url":     ts.URL, // base_url will call the test server | ||||||
|  | 		}, | ||||||
|  | 		Storage: s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Nil(t, resp) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// Read the config | ||||||
|  | 	resp, err = b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// the ID should be set, we grab it from the GET /orgs API | ||||||
|  | 	assert.Equal(t, int64(12345), resp.Data["organization_id"]) | ||||||
|  | 	assert.Equal(t, "foo-org", resp.Data["organization"]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_WriteReadConfig_OrgID tests that we can successfully read and | ||||||
|  | // write the github auth config with an organization_id param | ||||||
|  | func TestGitHub_WriteReadConfig_OrgID(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	// Write the config and pass in organization_id | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"organization":    "foo-org", | ||||||
|  | 			"organization_id": 98765, | ||||||
|  | 		}, | ||||||
|  | 		Storage: s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Nil(t, resp) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// Read the config | ||||||
|  | 	resp, err = b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// the ID should be set to what was written in the config | ||||||
|  | 	assert.Equal(t, int64(98765), resp.Data["organization_id"]) | ||||||
|  | 	assert.Equal(t, "foo-org", resp.Data["organization"]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_ErrorNoOrgID tests that an error is returned when we cannot fetch | ||||||
|  | // the org ID for the given org name | ||||||
|  | func TestGitHub_ErrorNoOrgID(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  | 	// use a test server to return our mock GH org info | ||||||
|  | 	ts := func() *httptest.Server { | ||||||
|  | 		return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 			w.Header().Add("Content-Type", "application/json") | ||||||
|  | 			resp := `{ "id": 0 }` | ||||||
|  | 			fmt.Fprintln(w, resp) | ||||||
|  | 		})) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer ts().Close() | ||||||
|  |  | ||||||
|  | 	// Write the config | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"organization": "foo-org", | ||||||
|  | 			"base_url":     ts().URL, // base_url will call the test server | ||||||
|  | 		}, | ||||||
|  | 		Storage: s, | ||||||
|  | 	}) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 	assert.Nil(t, resp) | ||||||
|  | 	assert.Equal(t, errors.New( | ||||||
|  | 		"unable to fetch the organization_id, you must manually set it in the config: organization_id not found for foo-org", | ||||||
|  | 	), err) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_WriteConfig_ErrorNoOrg tests that an error is returned when the | ||||||
|  | // required "organization" parameter is not provided | ||||||
|  | func TestGitHub_WriteConfig_ErrorNoOrg(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	// Write the config | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Data:      map[string]interface{}{}, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Error(t, resp.Error()) | ||||||
|  | 	assert.Equal(t, errors.New("organization is a required parameter"), resp.Error()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // https://docs.github.com/en/rest/reference/users#get-the-authenticated-user | ||||||
|  | // Note: many of the fields have been omitted | ||||||
|  | var getUserResponse = ` | ||||||
|  | { | ||||||
|  | 	"login": "user-foo", | ||||||
|  | 	"id": 6789, | ||||||
|  | 	"description": "A great user. The very best user.", | ||||||
|  | 	"name": "foo name", | ||||||
|  | 	"company": "foo-company", | ||||||
|  | 	"type": "User" | ||||||
|  | } | ||||||
|  | ` | ||||||
|  |  | ||||||
|  | // https://docs.github.com/en/rest/reference/orgs#get-an-organization | ||||||
|  | // Note: many of the fields have been omitted, we only care about 'login' and 'id' | ||||||
|  | var getOrgResponse = ` | ||||||
|  | { | ||||||
|  | 	"login": "foo-org", | ||||||
|  | 	"id": 12345, | ||||||
|  | 	"description": "A great org. The very best org.", | ||||||
|  | 	"name": "foo-display-name", | ||||||
|  | 	"company": "foo-company", | ||||||
|  | 	"type": "Organization" | ||||||
|  | } | ||||||
|  | ` | ||||||
|  |  | ||||||
|  | // https://docs.github.com/en/rest/reference/orgs#list-organizations-for-the-authenticated-user | ||||||
|  | var listOrgResponse = []byte(fmt.Sprintf(`[%v]`, getOrgResponse)) | ||||||
|  |  | ||||||
|  | // https://docs.github.com/en/rest/reference/teams#list-teams-for-the-authenticated-user | ||||||
|  | // Note: many of the fields have been omitted | ||||||
|  | var listUserTeamsResponse = []byte(fmt.Sprintf(`[ | ||||||
|  | { | ||||||
|  |     "id": 1, | ||||||
|  |     "node_id": "MDQ6VGVhbTE=", | ||||||
|  |     "name": "Foo team", | ||||||
|  |     "slug": "foo-team", | ||||||
|  |     "description": "A great team. The very best team.", | ||||||
|  |     "permission": "admin", | ||||||
|  |     "organization": %v | ||||||
|  |   } | ||||||
|  | ]`, getOrgResponse)) | ||||||
| @@ -2,9 +2,9 @@ package github | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" |  | ||||||
|  |  | ||||||
| 	"github.com/google/go-github/github" | 	"github.com/google/go-github/github" | ||||||
| 	"github.com/hashicorp/vault/sdk/framework" | 	"github.com/hashicorp/vault/sdk/framework" | ||||||
| @@ -33,16 +33,13 @@ func pathLogin(b *backend) *framework.Path { | |||||||
| func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	token := data.Get("token").(string) | 	token := data.Get("token").(string) | ||||||
|  |  | ||||||
| 	var verifyResp *verifyCredentialsResp | 	verifyResp, err := b.verifyCredentials(ctx, req, token) | ||||||
| 	if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if resp != nil { |  | ||||||
| 		return resp, nil |  | ||||||
| 	} else { |  | ||||||
| 		verifyResp = verifyResponse |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &logical.Response{ | 	return &logical.Response{ | ||||||
|  | 		Warnings: verifyResp.Warnings, | ||||||
| 		Auth: &logical.Auth{ | 		Auth: &logical.Auth{ | ||||||
| 			Alias: &logical.Alias{ | 			Alias: &logical.Alias{ | ||||||
| 				Name: *verifyResp.User.Login, | 				Name: *verifyResp.User.Login, | ||||||
| @@ -54,13 +51,9 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ | |||||||
| func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	token := data.Get("token").(string) | 	token := data.Get("token").(string) | ||||||
|  |  | ||||||
| 	var verifyResp *verifyCredentialsResp | 	verifyResp, err := b.verifyCredentials(ctx, req, token) | ||||||
| 	if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if resp != nil { |  | ||||||
| 		return resp, nil |  | ||||||
| 	} else { |  | ||||||
| 		verifyResp = verifyResponse |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	auth := &logical.Auth{ | 	auth := &logical.Auth{ | ||||||
| @@ -84,7 +77,8 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	resp := &logical.Response{ | 	resp := &logical.Response{ | ||||||
| 		Auth: auth, | 		Warnings: verifyResp.Warnings, | ||||||
|  | 		Auth:     auth, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, teamName := range verifyResp.TeamNames { | 	for _, teamName := range verifyResp.TeamNames { | ||||||
| @@ -110,14 +104,11 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f | |||||||
| 	} | 	} | ||||||
| 	token := tokenRaw.(string) | 	token := tokenRaw.(string) | ||||||
|  |  | ||||||
| 	var verifyResp *verifyCredentialsResp | 	verifyResp, err := b.verifyCredentials(ctx, req, token) | ||||||
| 	if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if resp != nil { |  | ||||||
| 		return resp, nil |  | ||||||
| 	} else { |  | ||||||
| 		verifyResp = verifyResponse |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.TokenPolicies) { | 	if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.TokenPolicies) { | ||||||
| 		return nil, fmt.Errorf("policies do not match") | 		return nil, fmt.Errorf("policies do not match") | ||||||
| 	} | 	} | ||||||
| @@ -126,6 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f | |||||||
| 	resp.Auth.Period = verifyResp.Config.TokenPeriod | 	resp.Auth.Period = verifyResp.Config.TokenPeriod | ||||||
| 	resp.Auth.TTL = verifyResp.Config.TokenTTL | 	resp.Auth.TTL = verifyResp.Config.TokenTTL | ||||||
| 	resp.Auth.MaxTTL = verifyResp.Config.TokenMaxTTL | 	resp.Auth.MaxTTL = verifyResp.Config.TokenMaxTTL | ||||||
|  | 	resp.Warnings = verifyResp.Warnings | ||||||
|  |  | ||||||
| 	// Remove old aliases | 	// Remove old aliases | ||||||
| 	resp.Auth.GroupAliases = nil | 	resp.Auth.GroupAliases = nil | ||||||
| @@ -139,48 +131,64 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f | |||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) { | func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, token string) (*verifyCredentialsResp, error) { | ||||||
|  | 	var warnings []string | ||||||
| 	config, err := b.Config(ctx, req.Storage) | 	config, err := b.Config(ctx, req.Storage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if config == nil { | 	if config == nil { | ||||||
| 		return nil, logical.ErrorResponse("configuration has not been set"), nil | 		return nil, errors.New("configuration has not been set") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check for a CIDR match. | 	// Check for a CIDR match. | ||||||
| 	if len(config.TokenBoundCIDRs) > 0 { | 	if len(config.TokenBoundCIDRs) > 0 { | ||||||
| 		if req.Connection == nil { | 		if req.Connection == nil { | ||||||
| 			b.Logger().Warn("token bound CIDRs found but no connection information available for validation") | 			b.Logger().Error("token bound CIDRs found but no connection information available for validation") | ||||||
| 			return nil, nil, logical.ErrPermissionDenied | 			return nil, logical.ErrPermissionDenied | ||||||
| 		} | 		} | ||||||
| 		if !cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, config.TokenBoundCIDRs) { | 		if !cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, config.TokenBoundCIDRs) { | ||||||
| 			return nil, nil, logical.ErrPermissionDenied | 			return nil, logical.ErrPermissionDenied | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if config.Organization == "" { |  | ||||||
| 		return nil, logical.ErrorResponse( |  | ||||||
| 			"organization not found in configuration"), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	client, err := b.Client(token) | 	client, err := b.Client(token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if config.BaseURL != "" { | 	if config.BaseURL != "" { | ||||||
| 		parsedURL, err := url.Parse(config.BaseURL) | 		parsedURL, err := url.Parse(config.BaseURL) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, fmt.Errorf("successfully parsed base_url when set but failing to parse now: %w", err) | 			return nil, fmt.Errorf("successfully parsed base_url when set but failing to parse now: %w", err) | ||||||
| 		} | 		} | ||||||
| 		client.BaseURL = parsedURL | 		client.BaseURL = parsedURL | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if config.OrganizationID == 0 { | ||||||
|  | 		// Previously we did not verify using the Org ID. So if the Org ID is | ||||||
|  | 		// not set, we will trust-on-first-use and set it now. | ||||||
|  | 		err = config.setOrganizationID(ctx, client) | ||||||
|  | 		if err != nil { | ||||||
|  | 			b.Logger().Error("failed to set the organization_id on login", "error", err) | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		entry, err := logical.StorageEntryJSON("config", config) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := req.Storage.Put(ctx, entry); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		b.Logger().Info("set ID on a trust-on-first-use basis", "organization_id", config.OrganizationID) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Get the user | 	// Get the user | ||||||
| 	user, _, err := client.Users.Get(ctx, "") | 	user, _, err := client.Users.Get(ctx, "") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Verify that the user is part of the organization | 	// Verify that the user is part of the organization | ||||||
| @@ -194,7 +202,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, t | |||||||
| 	for { | 	for { | ||||||
| 		orgs, resp, err := client.Organizations.List(ctx, "", orgOpt) | 		orgs, resp, err := client.Organizations.List(ctx, "", orgOpt) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		allOrgs = append(allOrgs, orgs...) | 		allOrgs = append(allOrgs, orgs...) | ||||||
| 		if resp.NextPage == 0 { | 		if resp.NextPage == 0 { | ||||||
| @@ -203,14 +211,27 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, t | |||||||
| 		orgOpt.Page = resp.NextPage | 		orgOpt.Page = resp.NextPage | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	orgLoginName := "" | ||||||
| 	for _, o := range allOrgs { | 	for _, o := range allOrgs { | ||||||
| 		if strings.EqualFold(*o.Login, config.Organization) { | 		if o.GetID() == config.OrganizationID { | ||||||
| 			org = o | 			org = o | ||||||
|  | 			orgLoginName = *o.Login | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if org == nil { | 	if org == nil { | ||||||
| 		return nil, logical.ErrorResponse("user is not part of required org"), nil | 		return nil, errors.New("user is not part of required org") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if orgLoginName != config.Organization { | ||||||
|  | 		warningMsg := fmt.Sprintf( | ||||||
|  | 			"the organization name has changed to %q. It is recommended to verify and update the organization name in the config: %s=%d", | ||||||
|  | 			orgLoginName, | ||||||
|  | 			"organization_id", | ||||||
|  | 			config.OrganizationID, | ||||||
|  | 		) | ||||||
|  | 		b.Logger().Warn(warningMsg) | ||||||
|  | 		warnings = append(warnings, warningMsg) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get the teams that this user is part of to determine the policies | 	// Get the teams that this user is part of to determine the policies | ||||||
| @@ -224,7 +245,7 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, t | |||||||
| 	for { | 	for { | ||||||
| 		teams, resp, err := client.Teams.ListUserTeams(ctx, teamOpt) | 		teams, resp, err := client.Teams.ListUserTeams(ctx, teamOpt) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		allTeams = append(allTeams, teams...) | 		allTeams = append(allTeams, teams...) | ||||||
| 		if resp.NextPage == 0 { | 		if resp.NextPage == 0 { | ||||||
| @@ -248,21 +269,24 @@ func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, t | |||||||
|  |  | ||||||
| 	groupPoliciesList, err := b.TeamMap.Policies(ctx, req.Storage, teamNames...) | 	groupPoliciesList, err := b.TeamMap.Policies(ctx, req.Storage, teamNames...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	userPoliciesList, err := b.UserMap.Policies(ctx, req.Storage, []string{*user.Login}...) | 	userPoliciesList, err := b.UserMap.Policies(ctx, req.Storage, []string{*user.Login}...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &verifyCredentialsResp{ | 	verifyResp := &verifyCredentialsResp{ | ||||||
| 		User:      user, | 		User:      user, | ||||||
| 		Org:       org, | 		Org:       org, | ||||||
| 		Policies:  append(groupPoliciesList, userPoliciesList...), | 		Policies:  append(groupPoliciesList, userPoliciesList...), | ||||||
| 		TeamNames: teamNames, | 		TeamNames: teamNames, | ||||||
| 		Config:    config, | 		Config:    config, | ||||||
| 	}, nil, nil | 		Warnings:  warnings, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return verifyResp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type verifyCredentialsResp struct { | type verifyCredentialsResp struct { | ||||||
| @@ -271,6 +295,9 @@ type verifyCredentialsResp struct { | |||||||
| 	Policies  []string | 	Policies  []string | ||||||
| 	TeamNames []string | 	TeamNames []string | ||||||
|  |  | ||||||
|  | 	// Warnings to send back to the caller | ||||||
|  | 	Warnings []string | ||||||
|  |  | ||||||
| 	// This is just a cache to send back to the caller | 	// This is just a cache to send back to the caller | ||||||
| 	Config *config | 	Config *config | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										186
									
								
								builtin/credential/github/path_login_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								builtin/credential/github/path_login_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,186 @@ | |||||||
|  | package github | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // TestGitHub_Login tests that we can successfully login with the given config | ||||||
|  | func TestGitHub_Login(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	// use a test server to return our mock GH org info | ||||||
|  | 	ts := setupTestServer(t) | ||||||
|  | 	defer ts.Close() | ||||||
|  |  | ||||||
|  | 	// Write the config | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"organization": "foo-org", | ||||||
|  | 			"base_url":     ts.URL, // base_url will call the test server | ||||||
|  | 		}, | ||||||
|  | 		Storage: s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// Read the config | ||||||
|  | 	resp, err = b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// attempt a login | ||||||
|  | 	resp, err = b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	expectedMetaData := map[string]string{ | ||||||
|  | 		"org":      "foo-org", | ||||||
|  | 		"username": "user-foo", | ||||||
|  | 	} | ||||||
|  | 	assert.Equal(t, expectedMetaData, resp.Auth.Metadata) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_Login_OrgInvalid tests that we cannot login with an ID other than | ||||||
|  | // what is set in the config | ||||||
|  | func TestGitHub_Login_OrgInvalid(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  | 	ctx := namespace.RootContext(nil) | ||||||
|  |  | ||||||
|  | 	// use a test server to return our mock GH org info | ||||||
|  | 	ts := setupTestServer(t) | ||||||
|  | 	defer ts.Close() | ||||||
|  |  | ||||||
|  | 	// write and store config | ||||||
|  | 	config := config{ | ||||||
|  | 		Organization:   "foo-org", | ||||||
|  | 		OrganizationID: 9999, | ||||||
|  | 		BaseURL:        ts.URL + "/", // base_url will call the test server | ||||||
|  | 	} | ||||||
|  | 	entry, err := logical.StorageEntryJSON("config", config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed creating storage entry") | ||||||
|  | 	} | ||||||
|  | 	if err := s.Put(ctx, entry); err != nil { | ||||||
|  | 		t.Fatalf("writing to in mem storage failed") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// attempt a login | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	assert.Nil(t, resp) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 	assert.Equal(t, errors.New("user is not part of required org"), err) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_Login_OrgNameChanged tests that we can successfully login with the | ||||||
|  | // given config and emit a warning when the organization name has changed | ||||||
|  | func TestGitHub_Login_OrgNameChanged(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  | 	ctx := namespace.RootContext(nil) | ||||||
|  |  | ||||||
|  | 	// use a test server to return our mock GH org info | ||||||
|  | 	ts := setupTestServer(t) | ||||||
|  | 	defer ts.Close() | ||||||
|  |  | ||||||
|  | 	// write and store config | ||||||
|  | 	// the name does not match what the API will return but the ID does | ||||||
|  | 	config := config{ | ||||||
|  | 		Organization:   "old-name", | ||||||
|  | 		OrganizationID: 12345, | ||||||
|  | 		BaseURL:        ts.URL + "/", // base_url will call the test server | ||||||
|  | 	} | ||||||
|  | 	entry, err := logical.StorageEntryJSON("config", config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed creating storage entry") | ||||||
|  | 	} | ||||||
|  | 	if err := s.Put(ctx, entry); err != nil { | ||||||
|  | 		t.Fatalf("writing to in mem storage failed") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// attempt a login | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Nil(t, resp.Error()) | ||||||
|  | 	assert.Equal( | ||||||
|  | 		t, | ||||||
|  | 		[]string{"the organization name has changed to \"foo-org\". It is recommended to verify and update the organization name in the config: organization_id=12345"}, | ||||||
|  | 		resp.Warnings, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestGitHub_Login_NoOrgID tests that we can successfully login with the given | ||||||
|  | // config when no organization ID is present and write the fetched ID to the | ||||||
|  | // config | ||||||
|  | func TestGitHub_Login_NoOrgID(t *testing.T) { | ||||||
|  | 	b, s := createBackendWithStorage(t) | ||||||
|  | 	ctx := namespace.RootContext(nil) | ||||||
|  |  | ||||||
|  | 	// use a test server to return our mock GH org info | ||||||
|  | 	ts := setupTestServer(t) | ||||||
|  | 	defer ts.Close() | ||||||
|  |  | ||||||
|  | 	// write and store config without Org ID | ||||||
|  | 	config := config{ | ||||||
|  | 		Organization: "foo-org", | ||||||
|  | 		BaseURL:      ts.URL + "/", // base_url will call the test server | ||||||
|  | 	} | ||||||
|  | 	entry, err := logical.StorageEntryJSON("config", config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed creating storage entry") | ||||||
|  | 	} | ||||||
|  | 	if err := s.Put(ctx, entry); err != nil { | ||||||
|  | 		t.Fatalf("writing to in mem storage failed") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// attempt a login | ||||||
|  | 	resp, err := b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	expectedMetaData := map[string]string{ | ||||||
|  | 		"org":      "foo-org", | ||||||
|  | 		"username": "user-foo", | ||||||
|  | 	} | ||||||
|  | 	assert.Equal(t, expectedMetaData, resp.Auth.Metadata) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// Read the config | ||||||
|  | 	resp, err = b.HandleRequest(context.Background(), &logical.Request{ | ||||||
|  | 		Path:      "config", | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Storage:   s, | ||||||
|  | 	}) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NoError(t, resp.Error()) | ||||||
|  |  | ||||||
|  | 	// the ID should be set, we grab it from the GET /orgs API | ||||||
|  | 	assert.Equal(t, int64(12345), resp.Data["organization_id"]) | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								changelog/13332.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/13332.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | ```release-note:bug | ||||||
|  | auth/github: Use the Organization ID instead of the Organization name to verify the org membership. | ||||||
|  | ``` | ||||||
		Reference in New Issue
	
	Block a user