mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 02:02:43 +00:00 
			
		
		
		
	[Vault-5248] MFA support for api login helpers (#14900)
* Add MFA support to login helpers
This commit is contained in:
		
							
								
								
									
										78
									
								
								api/auth.go
									
									
									
									
									
								
							
							
						
						
									
										78
									
								
								api/auth.go
									
									
									
									
									
								
							| @@ -31,16 +31,82 @@ func (a *Auth) Login(ctx context.Context, authMethod AuthMethod) (*Secret, error | ||||
| 	if authMethod == nil { | ||||
| 		return nil, fmt.Errorf("no auth method provided for login") | ||||
| 	} | ||||
| 	return a.login(ctx, authMethod) | ||||
| } | ||||
|  | ||||
| 	authSecret, err := authMethod.Login(ctx, a.c) | ||||
| // MFALogin is a wrapper that helps satisfy Vault's MFA implementation. | ||||
| // If optional credentials are provided a single-phase login will be attempted | ||||
| // and the resulting Secret will contain a ClientToken if the authentication is successful. | ||||
| // The client's token will also be set accordingly. | ||||
| // | ||||
| // If no credentials are provided a two-phase MFA login will be assumed and the resulting | ||||
| // Secret will have a MFARequirement containing the MFARequestID to be used in a follow-up | ||||
| // call to `sys/mfa/validate` or by passing it to the method (*Auth).MFAValidate. | ||||
| func (a *Auth) MFALogin(ctx context.Context, authMethod AuthMethod, creds ...string) (*Secret, error) { | ||||
| 	if len(creds) > 0 { | ||||
| 		a.c.SetMFACreds(creds) | ||||
| 		return a.login(ctx, authMethod) | ||||
| 	} | ||||
|  | ||||
| 	return a.twoPhaseMFALogin(ctx, authMethod) | ||||
| } | ||||
|  | ||||
| // MFAValidate validates an MFA request using the appropriate payload and a secret containing | ||||
| // Auth.MFARequirement, like the one returned by MFALogin when credentials are not provided. | ||||
| // Upon successful validation the client token will be set accordingly. | ||||
| // | ||||
| // The Secret returned is the authentication secret, which if desired can be | ||||
| // passed as input to the NewLifetimeWatcher method in order to start | ||||
| // automatically renewing the token. | ||||
| func (a *Auth) MFAValidate(ctx context.Context, mfaSecret *Secret, payload map[string]interface{}) (*Secret, error) { | ||||
| 	if mfaSecret == nil || mfaSecret.Auth == nil || mfaSecret.Auth.MFARequirement == nil { | ||||
| 		return nil, fmt.Errorf("secret does not contain MFARequirements") | ||||
| 	} | ||||
|  | ||||
| 	s, err := a.c.Sys().MFAValidateWithContext(ctx, mfaSecret.Auth.MFARequirement.GetMFARequestID(), payload) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return a.checkAndSetToken(s) | ||||
| } | ||||
|  | ||||
| // login performs the (*AuthMethod).Login() with the configured client and checks that a ClientToken is returned | ||||
| func (a *Auth) login(ctx context.Context, authMethod AuthMethod) (*Secret, error) { | ||||
| 	s, err := authMethod.Login(ctx, a.c) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("unable to log in to auth method: %w", err) | ||||
| 	} | ||||
| 	if authSecret == nil || authSecret.Auth == nil || authSecret.Auth.ClientToken == "" { | ||||
| 		return nil, fmt.Errorf("login response from auth method did not return client token") | ||||
|  | ||||
| 	return a.checkAndSetToken(s) | ||||
| } | ||||
|  | ||||
| // twoPhaseMFALogin performs the (*AuthMethod).Login() with the configured client | ||||
| // and checks that an MFARequirement is returned | ||||
| func (a *Auth) twoPhaseMFALogin(ctx context.Context, authMethod AuthMethod) (*Secret, error) { | ||||
| 	s, err := authMethod.Login(ctx, a.c) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("unable to log in: %w", err) | ||||
| 	} | ||||
| 	if s == nil || s.Auth == nil || s.Auth.MFARequirement == nil { | ||||
| 		if s != nil { | ||||
| 			s.Warnings = append(s.Warnings, "expected secret to contain MFARequirements") | ||||
| 		} | ||||
| 		return s, fmt.Errorf("assumed two-phase MFA login, returned secret is missing MFARequirements") | ||||
| 	} | ||||
|  | ||||
| 	a.c.SetToken(authSecret.Auth.ClientToken) | ||||
|  | ||||
| 	return authSecret, nil | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| func (a *Auth) checkAndSetToken(s *Secret) (*Secret, error) { | ||||
| 	if s == nil || s.Auth == nil || s.Auth.ClientToken == "" { | ||||
| 		if s != nil { | ||||
| 			s.Warnings = append(s.Warnings, "expected secret to contain ClientToken") | ||||
| 		} | ||||
| 		return s, fmt.Errorf("response did not return ClientToken, client token not set") | ||||
| 	} | ||||
|  | ||||
| 	a.c.SetToken(s.Auth.ClientToken) | ||||
|  | ||||
| 	return s, nil | ||||
| } | ||||
|   | ||||
							
								
								
									
										130
									
								
								api/auth_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								api/auth_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,130 @@ | ||||
| package api | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| type mockAuthMethod struct { | ||||
| 	mockedSecret *Secret | ||||
| 	mockedError  error | ||||
| } | ||||
|  | ||||
| func (m *mockAuthMethod) Login(_ context.Context, _ *Client) (*Secret, error) { | ||||
| 	return m.mockedSecret, m.mockedError | ||||
| } | ||||
|  | ||||
| func TestAuth_Login(t *testing.T) { | ||||
| 	a := &Auth{ | ||||
| 		c: &Client{}, | ||||
| 	} | ||||
|  | ||||
| 	m := mockAuthMethod{ | ||||
| 		mockedSecret: &Secret{ | ||||
| 			Auth: &SecretAuth{ | ||||
| 				ClientToken: "a-client-token", | ||||
| 			}, | ||||
| 		}, | ||||
| 		mockedError: nil, | ||||
| 	} | ||||
|  | ||||
| 	t.Run("Login should set token on success", func(t *testing.T) { | ||||
| 		if a.c.Token() != "" { | ||||
| 			t.Errorf("client token was %v expected to be unset", a.c.Token()) | ||||
| 		} | ||||
|  | ||||
| 		_, err := a.Login(context.Background(), &m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Login() error = %v", err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if a.c.Token() != m.mockedSecret.Auth.ClientToken { | ||||
| 			t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken) | ||||
| 			return | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestAuth_MFALoginSinglePhase(t *testing.T) { | ||||
| 	t.Run("MFALogin() should succeed if credentials are passed in", func(t *testing.T) { | ||||
| 		a := &Auth{ | ||||
| 			c: &Client{}, | ||||
| 		} | ||||
|  | ||||
| 		m := mockAuthMethod{ | ||||
| 			mockedSecret: &Secret{ | ||||
| 				Auth: &SecretAuth{ | ||||
| 					ClientToken: "a-client-token", | ||||
| 				}, | ||||
| 			}, | ||||
| 			mockedError: nil, | ||||
| 		} | ||||
|  | ||||
| 		_, err := a.MFALogin(context.Background(), &m, "testMethod:testPasscode") | ||||
| 		if err != nil { | ||||
| 			t.Errorf("MFALogin() error %v", err) | ||||
| 			return | ||||
| 		} | ||||
| 		if a.c.Token() != m.mockedSecret.Auth.ClientToken { | ||||
| 			t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken) | ||||
| 			return | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestAuth_MFALoginTwoPhase(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		a       *Auth | ||||
| 		m       *mockAuthMethod | ||||
| 		creds   *string | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "return MFARequirements", | ||||
| 			a: &Auth{ | ||||
| 				c: &Client{}, | ||||
| 			}, | ||||
| 			m: &mockAuthMethod{ | ||||
| 				mockedSecret: &Secret{ | ||||
| 					Auth: &SecretAuth{ | ||||
| 						MFARequirement: &logical.MFARequirement{ | ||||
| 							MFARequestID:   "a-req-id", | ||||
| 							MFAConstraints: nil, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "error if no MFARequirements", | ||||
| 			a: &Auth{ | ||||
| 				c: &Client{}, | ||||
| 			}, | ||||
| 			m: &mockAuthMethod{ | ||||
| 				mockedSecret: &Secret{ | ||||
| 					Auth: &SecretAuth{}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			secret, err := tt.a.MFALogin(context.Background(), tt.m) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("MFALogin() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			if secret.Auth.MFARequirement != tt.m.mockedSecret.Auth.MFARequirement { | ||||
| 				t.Errorf("MFALogin() returned %v, expected %v", secret.Auth.MFARequirement, tt.m.mockedSecret.Auth.MFARequirement) | ||||
| 				return | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										45
									
								
								api/sys_mfa.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								api/sys_mfa.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| package api | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func (c *Sys) MFAValidate(requestID string, payload map[string]interface{}) (*Secret, error) { | ||||
| 	return c.MFAValidateWithContext(context.Background(), requestID, payload) | ||||
| } | ||||
|  | ||||
| func (c *Sys) MFAValidateWithContext(ctx context.Context, requestID string, payload map[string]interface{}) (*Secret, error) { | ||||
| 	ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) | ||||
| 	defer cancelFunc() | ||||
|  | ||||
| 	body := map[string]interface{}{ | ||||
| 		"mfa_request_id": requestID, | ||||
| 		"mfa_payload":    payload, | ||||
| 	} | ||||
|  | ||||
| 	r := c.c.NewRequest(http.MethodPost, fmt.Sprintf("/v1/sys/mfa/validate")) | ||||
| 	if err := r.SetJSONBody(body); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to set request body: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	resp, err := c.c.rawRequestWithContext(ctx, r) | ||||
| 	if resp != nil { | ||||
| 		defer resp.Body.Close() | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	secret, err := ParseSecret(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to parse secret from response: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if secret == nil { | ||||
| 		return nil, fmt.Errorf("data from server response is empty") | ||||
| 	} | ||||
|  | ||||
| 	return secret, nil | ||||
| } | ||||
							
								
								
									
										3
									
								
								changelog/14900.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/14900.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| ```release-note:improvement | ||||
| api: Added MFALogin() for handling MFA flow when using login helpers. | ||||
| ``` | ||||
| @@ -259,8 +259,8 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { | ||||
| 	} | ||||
|  | ||||
| 	// passcode could be an empty string | ||||
| 	mfaPayload := map[string][]string{ | ||||
| 		methodInfo.methodID: {passcode}, | ||||
| 	mfaPayload := map[string]interface{}{ | ||||
| 		methodInfo.methodID: []string{passcode}, | ||||
| 	} | ||||
|  | ||||
| 	client, err := c.Client() | ||||
| @@ -269,12 +269,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { | ||||
| 		return 2 | ||||
| 	} | ||||
|  | ||||
| 	path := "sys/mfa/validate" | ||||
|  | ||||
| 	secret, err := client.Logical().Write(path, map[string]interface{}{ | ||||
| 		"mfa_request_id": reqID, | ||||
| 		"mfa_payload":    mfaPayload, | ||||
| 	}) | ||||
| 	secret, err := client.Sys().MFAValidate(reqID, mfaPayload) | ||||
| 	if err != nil { | ||||
| 		c.UI.Error(err.Error()) | ||||
| 		if secret != nil { | ||||
| @@ -285,7 +280,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { | ||||
| 	if secret == nil { | ||||
| 		// Don't output anything unless using the "table" format | ||||
| 		if Format(c.UI) == "table" { | ||||
| 			c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) | ||||
| 			c.UI.Info("Success! Data written to: sys/mfa/validate") | ||||
| 		} | ||||
| 		return 0 | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package identity | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"reflect" | ||||
| @@ -241,7 +242,6 @@ func mfaGenerateLoginDUOTest(client *api.Client) error { | ||||
| 			return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{ | ||||
| 		"password": "testpassword", | ||||
| 	}) | ||||
| @@ -272,12 +272,11 @@ func mfaGenerateLoginDUOTest(client *api.Client) error { | ||||
| 	} | ||||
|  | ||||
| 	// validation | ||||
| 	secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{ | ||||
| 		"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, | ||||
| 		"mfa_payload": map[string][]string{ | ||||
| 			methodID: {}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	secret, err = client.Sys().MFAValidateWithContext(context.Background(), | ||||
| 		secret.Auth.MFARequirement.MFARequestID, | ||||
| 		map[string]interface{}{ | ||||
| 			methodID: []string{}, | ||||
| 		}) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("MFA failed: %v", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package identity | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| @@ -322,12 +323,12 @@ func mfaGenerateOktaLoginMFATest(client *api.Client) error { | ||||
| 	} | ||||
|  | ||||
| 	// validation | ||||
| 	secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{ | ||||
| 		"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, | ||||
| 		"mfa_payload": map[string][]string{ | ||||
| 			methodID: {}, | ||||
| 	secret, err = client.Sys().MFAValidateWithContext(context.Background(), | ||||
| 		secret.Auth.MFARequirement.MFARequestID, | ||||
| 		map[string]interface{}{ | ||||
| 			methodID: []string{}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("MFA failed: %v", err) | ||||
| 	} | ||||
|   | ||||
| @@ -7,6 +7,8 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	upAuth "github.com/hashicorp/vault/api/auth/userpass" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/api" | ||||
| 	"github.com/hashicorp/vault/audit" | ||||
| 	"github.com/hashicorp/vault/builtin/credential/userpass" | ||||
| @@ -76,21 +78,27 @@ func doTwoPhaseLogin(client *api.Client, totpCodePath, methodID, username string | ||||
| 	} | ||||
| 	totpPasscode := totpResp.Data["code"].(string) | ||||
|  | ||||
| 	secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/login/%s", username), map[string]interface{}{ | ||||
| 		"password": "testpassword", | ||||
| 	}) | ||||
| 	upMethod, err := upAuth.NewUserpassAuth(username, &upAuth.Password{FromString: "testpassword"}) | ||||
|  | ||||
| 	mfaSecret, err := client.Auth().MFALogin(context.Background(), upMethod) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("first phase of login MFA failed: %v", err) | ||||
| 		t.Fatalf("failed to login with userpass auth method: %v", err) | ||||
| 	} | ||||
| 	secret, err = client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ | ||||
| 		"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, | ||||
| 		"mfa_payload": map[string][]string{ | ||||
| 			methodID: {totpPasscode}, | ||||
|  | ||||
| 	secret, err := client.Auth().MFAValidate( | ||||
| 		context.Background(), | ||||
| 		mfaSecret, | ||||
| 		map[string]interface{}{ | ||||
| 			methodID: []string{totpPasscode}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("MFA validation failed: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { | ||||
| 		t.Fatalf("MFA validation failed to return a ClientToken in secret: %v", secret) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Vinny Mannello
					Vinny Mannello