mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	named Login MFA methods (#18610)
* named MFA method configurations * fix a test * CL * fix an issue with same config name different ID and add a test * feedback * feedback on test * consistent use of passcode for all MFA methods (#18611) * make use of passcode factor consistent for all MFA types * improved type for MFA factors * add method name to login CLI * minor refactoring * only accept MFA method name with its namespace path in the login request MFA header * fix a bug * fixing an ErrorOrNil return value * more informative error message * Apply suggestions from code review Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com> * feedback * test refactor a bit * adding godoc for a test * feedback * remove sanitize method name * guard a possbile nil ref Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
		
							
								
								
									
										4
									
								
								changelog/18610.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								changelog/18610.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | ```release-note:improvement | ||||||
|  | auth: Allow naming login MFA methods and using those names instead of IDs in satisfying MFA requirement for requests. | ||||||
|  | Make passcode arguments consistent across login MFA method types. | ||||||
|  | ``` | ||||||
| @@ -557,6 +557,9 @@ func (t TableFormatter) OutputSecret(ui cli.Ui, secret *api.Secret) error { | |||||||
| 				for _, constraint := range constraintSet.Any { | 				for _, constraint := range constraintSet.Any { | ||||||
| 					out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_id %s %s", k, constraint.Type, hopeDelim, constraint.ID)) | 					out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_id %s %s", k, constraint.Type, hopeDelim, constraint.ID)) | ||||||
| 					out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_uses_passcode %s %t", k, constraint.Type, hopeDelim, constraint.UsesPasscode)) | 					out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_uses_passcode %s %t", k, constraint.Type, hopeDelim, constraint.UsesPasscode)) | ||||||
|  | 					if constraint.Name != "" { | ||||||
|  | 						out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_name %s %s", k, constraint.Type, hopeDelim, constraint.Name)) | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} else { // Token information only makes sense if no further MFA requirement (i.e. if we actually have a token) | 		} else { // Token information only makes sense if no further MFA requirement (i.e. if we actually have a token) | ||||||
|   | |||||||
| @@ -1,8 +1,11 @@ | |||||||
| package command | package command | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/mitchellh/cli" | 	"github.com/mitchellh/cli" | ||||||
|  |  | ||||||
| @@ -37,10 +40,7 @@ func testLoginCommand(tb testing.TB) (*cli.MockUi, *LoginCommand) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestLoginCommand_Run(t *testing.T) { | func TestCustomPath(t *testing.T) { | ||||||
| 	t.Parallel() |  | ||||||
|  |  | ||||||
| 	t.Run("custom_path", func(t *testing.T) { |  | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -91,9 +91,10 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if l, exp := len(storedToken), minTokenLengthExternal+vault.TokenPrefixLength; l < exp { | 	if l, exp := len(storedToken), minTokenLengthExternal+vault.TokenPrefixLength; l < exp { | ||||||
| 		t.Errorf("expected token to be %d characters, was %d: %q", exp, l, storedToken) | 		t.Errorf("expected token to be %d characters, was %d: %q", exp, l, storedToken) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("no_store", func(t *testing.T) { | // Do not persist the token to the token helper | ||||||
|  | func TestNoStore(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -137,9 +138,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if exp := ""; storedToken != exp { | 	if exp := ""; storedToken != exp { | ||||||
| 		t.Errorf("expected %q to be %q", storedToken, exp) | 		t.Errorf("expected %q to be %q", storedToken, exp) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("stores", func(t *testing.T) { | func TestStores(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -177,9 +178,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if storedToken != token { | 	if storedToken != token { | ||||||
| 		t.Errorf("expected %q to be %q", storedToken, token) | 		t.Errorf("expected %q to be %q", storedToken, token) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("token_only", func(t *testing.T) { | func TestTokenOnly(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -223,9 +224,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { | 	if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { | ||||||
| 		t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) | 		t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("failure_no_store", func(t *testing.T) { | func TestFailureNoStore(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -255,9 +256,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { | 	if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { | ||||||
| 		t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) | 		t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("wrap_auto_unwrap", func(t *testing.T) { | func TestWrapAutoUnwrap(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -314,9 +315,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if secret.WrapInfo != nil { | 	if secret.WrapInfo != nil { | ||||||
| 		t.Errorf("expected to be unwrapped: %#v", secret) | 		t.Errorf("expected to be unwrapped: %#v", secret) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("wrap_token_only", func(t *testing.T) { | func TestWrapTokenOnly(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -375,9 +376,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { | 	if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { | ||||||
| 		t.Fatalf("expected secret to have auth: %#v", secret) | 		t.Fatalf("expected secret to have auth: %#v", secret) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("wrap_no_store", func(t *testing.T) { | func TestWrapNoStore(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServer(t) | 	client, closer := testVaultServer(t) | ||||||
| @@ -427,94 +428,9 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if !strings.Contains(output, expected) { | 	if !strings.Contains(output, expected) { | ||||||
| 		t.Errorf("expected %q to contain %q", output, expected) | 		t.Errorf("expected %q to contain %q", output, expected) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("login_mfa_single_phase", func(t *testing.T) { | func TestCommunicationFailure(t *testing.T) { | ||||||
| 		t.Parallel() |  | ||||||
|  |  | ||||||
| 		client, closer := testVaultServer(t) |  | ||||||
| 		defer closer() |  | ||||||
|  |  | ||||||
| 		ui, cmd := testLoginCommand(t) |  | ||||||
|  |  | ||||||
| 		userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client) |  | ||||||
| 		cmd.client = userclient |  | ||||||
|  |  | ||||||
| 		enginePath := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) |  | ||||||
| 		totpCode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath) |  | ||||||
|  |  | ||||||
| 		// login command bails early for test clients, so we have to explicitly set this |  | ||||||
| 		cmd.client.SetMFACreds([]string{methodID + ":" + totpCode}) |  | ||||||
| 		code := cmd.Run([]string{ |  | ||||||
| 			"-method", "userpass", |  | ||||||
| 			"username=testuser1", |  | ||||||
| 			"password=testpassword", |  | ||||||
| 		}) |  | ||||||
| 		if exp := 0; code != exp { |  | ||||||
| 			t.Errorf("expected %d to be %d", code, exp) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		tokenHelper, err := cmd.TokenHelper() |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(err) |  | ||||||
| 		} |  | ||||||
| 		storedToken, err := tokenHelper.Get() |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(err) |  | ||||||
| 		} |  | ||||||
| 		output = ui.OutputWriter.String() + ui.ErrorWriter.String() |  | ||||||
| 		t.Logf("\n%+v", output) |  | ||||||
| 		if !strings.Contains(output, storedToken) { |  | ||||||
| 			t.Fatalf("expected stored token: %q, got: %q", storedToken, output) |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	t.Run("login_mfa_two_phase", func(t *testing.T) { |  | ||||||
| 		t.Parallel() |  | ||||||
|  |  | ||||||
| 		client, closer := testVaultServer(t) |  | ||||||
| 		defer closer() |  | ||||||
|  |  | ||||||
| 		ui, cmd := testLoginCommand(t) |  | ||||||
|  |  | ||||||
| 		userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client) |  | ||||||
| 		cmd.client = userclient |  | ||||||
|  |  | ||||||
| 		_ = testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) |  | ||||||
|  |  | ||||||
| 		// clear the MFA creds just to be sure |  | ||||||
| 		cmd.client.SetMFACreds([]string{}) |  | ||||||
|  |  | ||||||
| 		code := cmd.Run([]string{ |  | ||||||
| 			"-method", "userpass", |  | ||||||
| 			"username=testuser1", |  | ||||||
| 			"password=testpassword", |  | ||||||
| 		}) |  | ||||||
| 		if exp := 0; code != exp { |  | ||||||
| 			t.Errorf("expected %d to be %d", code, exp) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		expected := methodID |  | ||||||
| 		output = ui.OutputWriter.String() + ui.ErrorWriter.String() |  | ||||||
| 		t.Logf("\n%+v", output) |  | ||||||
| 		if !strings.Contains(output, expected) { |  | ||||||
| 			t.Fatalf("expected stored token: %q, got: %q", expected, output) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		tokenHelper, err := cmd.TokenHelper() |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(err) |  | ||||||
| 		} |  | ||||||
| 		storedToken, err := tokenHelper.Get() |  | ||||||
| 		if storedToken != "" { |  | ||||||
| 			t.Fatal("expected empty stored token") |  | ||||||
| 		} |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(err) |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	t.Run("communication_failure", func(t *testing.T) { |  | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	client, closer := testVaultServerBad(t) | 	client, closer := testVaultServerBad(t) | ||||||
| @@ -535,12 +451,163 @@ func TestLoginCommand_Run(t *testing.T) { | |||||||
| 	if !strings.Contains(combined, expected) { | 	if !strings.Contains(combined, expected) { | ||||||
| 		t.Errorf("expected %q to contain %q", combined, expected) | 		t.Errorf("expected %q to contain %q", combined, expected) | ||||||
| 	} | 	} | ||||||
| 	}) | } | ||||||
|  |  | ||||||
| 	t.Run("no_tabs", func(t *testing.T) { | func TestNoTabs(t *testing.T) { | ||||||
| 	t.Parallel() | 	t.Parallel() | ||||||
|  |  | ||||||
| 	_, cmd := testLoginCommand(t) | 	_, cmd := testLoginCommand(t) | ||||||
| 	assertNoTabs(t, cmd) | 	assertNoTabs(t, cmd) | ||||||
| 	}) | } | ||||||
|  |  | ||||||
|  | func TestLoginMFASinglePhase(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	client, closer := testVaultServer(t) | ||||||
|  | 	defer closer() | ||||||
|  |  | ||||||
|  | 	methodName := "foo" | ||||||
|  | 	waitPeriod := 5 | ||||||
|  | 	userClient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, methodName, waitPeriod) | ||||||
|  | 	enginePath := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) | ||||||
|  |  | ||||||
|  | 	runCommand := func(methodIdentifier string) { | ||||||
|  | 		// the time required for the totp engine to generate a new code | ||||||
|  | 		time.Sleep(time.Duration(waitPeriod) * time.Second) | ||||||
|  | 		totpCode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath) | ||||||
|  | 		ui, cmd := testLoginCommand(t) | ||||||
|  | 		cmd.client = userClient | ||||||
|  | 		// login command bails early for test clients, so we have to explicitly set this | ||||||
|  | 		cmd.client.SetMFACreds([]string{methodIdentifier + ":" + totpCode}) | ||||||
|  | 		code := cmd.Run([]string{ | ||||||
|  | 			"-method", "userpass", | ||||||
|  | 			"username=testuser1", | ||||||
|  | 			"password=testpassword", | ||||||
|  | 		}) | ||||||
|  | 		if exp := 0; code != exp { | ||||||
|  | 			t.Errorf("expected %d to be %d", code, exp) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		tokenHelper, err := cmd.TokenHelper() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		storedToken, err := tokenHelper.Get() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		if storedToken == "" { | ||||||
|  | 			t.Fatal("expected non-empty stored token") | ||||||
|  | 		} | ||||||
|  | 		output = ui.OutputWriter.String() | ||||||
|  | 		if !strings.Contains(output, storedToken) { | ||||||
|  | 			t.Fatalf("expected stored token: %q, got: %q", storedToken, output) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	runCommand(methodID) | ||||||
|  | 	runCommand(methodName) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestLoginMFATwoPhase(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	client, closer := testVaultServer(t) | ||||||
|  | 	defer closer() | ||||||
|  |  | ||||||
|  | 	ui, cmd := testLoginCommand(t) | ||||||
|  |  | ||||||
|  | 	userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, "", 5) | ||||||
|  | 	cmd.client = userclient | ||||||
|  |  | ||||||
|  | 	_ = testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) | ||||||
|  |  | ||||||
|  | 	// clear the MFA creds just to be sure | ||||||
|  | 	cmd.client.SetMFACreds([]string{}) | ||||||
|  |  | ||||||
|  | 	code := cmd.Run([]string{ | ||||||
|  | 		"-method", "userpass", | ||||||
|  | 		"username=testuser1", | ||||||
|  | 		"password=testpassword", | ||||||
|  | 	}) | ||||||
|  | 	if exp := 0; code != exp { | ||||||
|  | 		t.Errorf("expected %d to be %d", code, exp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	expected := methodID | ||||||
|  | 	output = ui.OutputWriter.String() | ||||||
|  | 	if !strings.Contains(output, expected) { | ||||||
|  | 		t.Fatalf("expected stored token: %q, got: %q", expected, output) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tokenHelper, err := cmd.TokenHelper() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	storedToken, err := tokenHelper.Get() | ||||||
|  | 	if storedToken != "" { | ||||||
|  | 		t.Fatal("expected empty stored token") | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestLoginMFATwoPhaseNonInteractiveMethodName(t *testing.T) { | ||||||
|  | 	t.Parallel() | ||||||
|  |  | ||||||
|  | 	client, closer := testVaultServer(t) | ||||||
|  | 	defer closer() | ||||||
|  |  | ||||||
|  | 	ui, cmd := testLoginCommand(t) | ||||||
|  |  | ||||||
|  | 	methodName := "foo" | ||||||
|  | 	waitPeriod := 5 | ||||||
|  | 	userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, methodName, waitPeriod) | ||||||
|  | 	cmd.client = userclient | ||||||
|  |  | ||||||
|  | 	engineName := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) | ||||||
|  |  | ||||||
|  | 	// clear the MFA creds just to be sure | ||||||
|  | 	cmd.client.SetMFACreds([]string{}) | ||||||
|  |  | ||||||
|  | 	code := cmd.Run([]string{ | ||||||
|  | 		"-method", "userpass", | ||||||
|  | 		"-non-interactive", | ||||||
|  | 		"username=testuser1", | ||||||
|  | 		"password=testpassword", | ||||||
|  | 	}) | ||||||
|  | 	if exp := 0; code != exp { | ||||||
|  | 		t.Errorf("expected %d to be %d", code, exp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	output = ui.OutputWriter.String() | ||||||
|  |  | ||||||
|  | 	reqIdReg := regexp.MustCompile(`mfa_request_id\s+([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})\s+mfa_constraint`) | ||||||
|  | 	reqIDRaw := reqIdReg.FindAllStringSubmatch(output, -1) | ||||||
|  | 	if len(reqIDRaw) == 0 || len(reqIDRaw[0]) < 2 { | ||||||
|  | 		t.Fatal("failed to MFA request ID from output") | ||||||
|  | 	} | ||||||
|  | 	mfaReqID := reqIDRaw[0][1] | ||||||
|  |  | ||||||
|  | 	validateFunc := func(methodIdentifier string) { | ||||||
|  | 		// the time required for the totp engine to generate a new code | ||||||
|  | 		time.Sleep(time.Duration(waitPeriod) * time.Second) | ||||||
|  | 		totpPasscode1 := "passcode=" + testhelpers.GetTOTPCodeFromEngine(t, client, engineName) | ||||||
|  |  | ||||||
|  | 		secret, err := cmd.client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ | ||||||
|  | 			"mfa_request_id": mfaReqID, | ||||||
|  | 			"mfa_payload": map[string][]string{ | ||||||
|  | 				methodIdentifier: {totpPasscode1}, | ||||||
|  | 			}, | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("mfa validation failed: %v", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if secret.Auth == nil || secret.Auth.ClientToken == "" { | ||||||
|  | 			t.Fatalf("mfa validation did not return a client token") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	validateFunc(methodName) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -33,7 +33,7 @@ const ( | |||||||
| 	GenerateRecovery | 	GenerateRecovery | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Generates a root token on the target cluster. | // GenerateRoot generates a root token on the target cluster. | ||||||
| func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string { | func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string { | ||||||
| 	t.Helper() | 	t.Helper() | ||||||
| 	token, err := GenerateRootWithError(t, cluster, kind) | 	token, err := GenerateRootWithError(t, cluster, kind) | ||||||
| @@ -767,6 +767,21 @@ func SetNonRootToken(client *api.Client) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // RetryUntilAtCadence runs f until it returns a nil result or the timeout is reached. | ||||||
|  | // If a nil result hasn't been obtained by timeout, calls t.Fatal. | ||||||
|  | func RetryUntilAtCadence(t testing.T, timeout, sleepTime time.Duration, f func() error) { | ||||||
|  | 	t.Helper() | ||||||
|  | 	deadline := time.Now().Add(timeout) | ||||||
|  | 	var err error | ||||||
|  | 	for time.Now().Before(deadline) { | ||||||
|  | 		if err = f(); err == nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		time.Sleep(sleepTime) | ||||||
|  | 	} | ||||||
|  | 	t.Fatalf("did not complete before deadline, err: %v", err) | ||||||
|  | } | ||||||
|  |  | ||||||
| // RetryUntil runs f until it returns a nil result or the timeout is reached. | // RetryUntil runs f until it returns a nil result or the timeout is reached. | ||||||
| // If a nil result hasn't been obtained by timeout, calls t.Fatal. | // If a nil result hasn't been obtained by timeout, calls t.Fatal. | ||||||
| func RetryUntil(t testing.T, timeout time.Duration, f func() error) { | func RetryUntil(t testing.T, timeout time.Duration, f func() error) { | ||||||
| @@ -942,7 +957,7 @@ func GetTOTPCodeFromEngine(t testing.T, client *api.Client, enginePath string) s | |||||||
|  |  | ||||||
| // SetupLoginMFATOTP setups up a TOTP MFA using some basic configuration and | // SetupLoginMFATOTP setups up a TOTP MFA using some basic configuration and | ||||||
| // returns all relevant information to the client. | // returns all relevant information to the client. | ||||||
| func SetupLoginMFATOTP(t testing.T, client *api.Client) (*api.Client, string, string) { | func SetupLoginMFATOTP(t testing.T, client *api.Client, methodName string, waitPeriod int) (*api.Client, string, string) { | ||||||
| 	t.Helper() | 	t.Helper() | ||||||
| 	// Mount the totp secrets engine | 	// Mount the totp secrets engine | ||||||
| 	SetupTOTPMount(t, client) | 	SetupTOTPMount(t, client) | ||||||
| @@ -956,13 +971,14 @@ func SetupLoginMFATOTP(t testing.T, client *api.Client) (*api.Client, string, st | |||||||
| 	// Configure a default TOTP method | 	// Configure a default TOTP method | ||||||
| 	totpConfig := map[string]interface{}{ | 	totpConfig := map[string]interface{}{ | ||||||
| 		"issuer":                  "yCorp", | 		"issuer":                  "yCorp", | ||||||
| 		"period":                  20, | 		"period":                  waitPeriod, | ||||||
| 		"algorithm":               "SHA256", | 		"algorithm":               "SHA256", | ||||||
| 		"digits":                  6, | 		"digits":                  6, | ||||||
| 		"skew":                    1, | 		"skew":                    1, | ||||||
| 		"key_size":                20, | 		"key_size":                20, | ||||||
| 		"qr_size":                 200, | 		"qr_size":                 200, | ||||||
| 		"max_validation_attempts": 5, | 		"max_validation_attempts": 5, | ||||||
|  | 		"method_name":             methodName, | ||||||
| 	} | 	} | ||||||
| 	methodID := SetupTOTPMethod(t, client, totpConfig) | 	methodID := SetupTOTPMethod(t, client, totpConfig) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -318,6 +318,7 @@ type MFAMethodID struct { | |||||||
| 	Type         string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` | 	Type         string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` | ||||||
| 	ID           string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` | 	ID           string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` | ||||||
| 	UsesPasscode bool   `protobuf:"varint,3,opt,name=uses_passcode,json=usesPasscode,proto3" json:"uses_passcode,omitempty"` | 	UsesPasscode bool   `protobuf:"varint,3,opt,name=uses_passcode,json=usesPasscode,proto3" json:"uses_passcode,omitempty"` | ||||||
|  | 	Name         string `protobuf:"bytes,4,opt,name=name,proto3" json:"name,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func (x *MFAMethodID) Reset() { | func (x *MFAMethodID) Reset() { | ||||||
| @@ -373,6 +374,13 @@ func (x *MFAMethodID) GetUsesPasscode() bool { | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (x *MFAMethodID) GetName() string { | ||||||
|  | 	if x != nil { | ||||||
|  | 		return x.Name | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
| type MFAConstraintAny struct { | type MFAConstraintAny struct { | ||||||
| 	state         protoimpl.MessageState | 	state         protoimpl.MessageState | ||||||
| 	sizeCache     protoimpl.SizeCache | 	sizeCache     protoimpl.SizeCache | ||||||
| @@ -537,34 +545,35 @@ var file_sdk_logical_identity_proto_rawDesc = []byte{ | |||||||
| 	0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, | 	0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, | ||||||
| 	0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, | 	0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, | ||||||
| 	0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, | 	0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, | ||||||
| 	0x01, 0x22, 0x56, 0x0a, 0x0b, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x49, 0x44, | 	0x01, 0x22, 0x6a, 0x0a, 0x0b, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x49, 0x44, | ||||||
| 	0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, | 	0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, | ||||||
| 	0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, | 	0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, | ||||||
| 	0x52, 0x02, 0x69, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x75, 0x73, 0x65, 0x73, 0x5f, 0x70, 0x61, 0x73, | 	0x52, 0x02, 0x69, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x75, 0x73, 0x65, 0x73, 0x5f, 0x70, 0x61, 0x73, | ||||||
| 	0x73, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x75, 0x73, 0x65, | 	0x73, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x75, 0x73, 0x65, | ||||||
| 	0x73, 0x50, 0x61, 0x73, 0x73, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x3a, 0x0a, 0x10, 0x4d, 0x46, 0x41, | 	0x73, 0x50, 0x61, 0x73, 0x73, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, | ||||||
| 	0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, 0x79, 0x12, 0x26, 0x0a, | 	0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x3a, 0x0a, | ||||||
| 	0x03, 0x61, 0x6e, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6c, 0x6f, 0x67, | 	0x10, 0x4d, 0x46, 0x41, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, | ||||||
| 	0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x49, 0x44, | 	0x79, 0x12, 0x26, 0x0a, 0x03, 0x61, 0x6e, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, | ||||||
| 	0x52, 0x03, 0x61, 0x6e, 0x79, 0x22, 0xea, 0x01, 0x0a, 0x0e, 0x4d, 0x46, 0x41, 0x52, 0x65, 0x71, | 	0x2e, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, | ||||||
| 	0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x6d, 0x66, 0x61, 0x5f, | 	0x6f, 0x64, 0x49, 0x44, 0x52, 0x03, 0x61, 0x6e, 0x79, 0x22, 0xea, 0x01, 0x0a, 0x0e, 0x4d, 0x46, | ||||||
| 	0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, | 	0x41, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0e, | ||||||
| 	0x52, 0x0c, 0x6d, 0x66, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x54, | 	0x6d, 0x66, 0x61, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, | ||||||
| 	0x0a, 0x0f, 0x6d, 0x66, 0x61, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, | 	0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6d, 0x66, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, | ||||||
| 	0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, | 	0x49, 0x64, 0x12, 0x54, 0x0a, 0x0f, 0x6d, 0x66, 0x61, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x72, | ||||||
| 	0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, | 	0x61, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x6c, 0x6f, | ||||||
| 	0x2e, 0x4d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x45, | 	0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, | ||||||
| 	0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x6d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, | 	0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, | ||||||
| 	0x69, 0x6e, 0x74, 0x73, 0x1a, 0x5c, 0x0a, 0x13, 0x4d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, | 	0x6e, 0x74, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x6d, 0x66, 0x61, 0x43, 0x6f, 0x6e, | ||||||
| 	0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, | 	0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x1a, 0x5c, 0x0a, 0x13, 0x4d, 0x66, 0x61, 0x43, | ||||||
| 	0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2f, 0x0a, | 	0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, | ||||||
| 	0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6c, | 	0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, | ||||||
| 	0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, | 	0x79, 0x12, 0x2f, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, | ||||||
| 	0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, | 	0x32, 0x19, 0x2e, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x43, 0x6f, | ||||||
| 	0x38, 0x01, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, | 	0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, | ||||||
| 	0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, | 	0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, | ||||||
| 	0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x62, 0x06, 0x70, 0x72, | 	0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, | ||||||
| 	0x6f, 0x74, 0x6f, 0x33, | 	0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, | ||||||
|  | 	0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, | ||||||
| } | } | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
|   | |||||||
| @@ -79,6 +79,7 @@ message MFAMethodID { | |||||||
| 	string type = 1; | 	string type = 1; | ||||||
| 	string id = 2; | 	string id = 2; | ||||||
| 	bool uses_passcode = 3; | 	bool uses_passcode = 3; | ||||||
|  | 	string name = 4; | ||||||
| } | } | ||||||
|  |  | ||||||
| message MFAConstraintAny { | message MFAConstraintAny { | ||||||
|   | |||||||
| @@ -93,16 +93,17 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { | |||||||
| 	// Creating two users in the userpass auth mount | 	// Creating two users in the userpass auth mount | ||||||
| 	userClient1, entityID1, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity1, testuser1) | 	userClient1, entityID1, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity1, testuser1) | ||||||
| 	userClient2, entityID2, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity2, testuser2) | 	userClient2, entityID2, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity2, testuser2) | ||||||
|  | 	waitPeriod := 5 | ||||||
| 	totpConfig := map[string]interface{}{ | 	totpConfig := map[string]interface{}{ | ||||||
| 		"issuer":                  "yCorp", | 		"issuer":                  "yCorp", | ||||||
| 		"period":                  5, | 		"period":                  waitPeriod, | ||||||
| 		"algorithm":               "SHA1", | 		"algorithm":               "SHA1", | ||||||
| 		"digits":                  6, | 		"digits":                  6, | ||||||
| 		"skew":                    1, | 		"skew":                    1, | ||||||
| 		"key_size":                10, | 		"key_size":                10, | ||||||
| 		"qr_size":                 100, | 		"qr_size":                 100, | ||||||
| 		"max_validation_attempts": 3, | 		"max_validation_attempts": 3, | ||||||
|  | 		"method_name":             "foo", | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	methodID := testhelpers.SetupTOTPMethod(t, client, totpConfig) | 	methodID := testhelpers.SetupTOTPMethod(t, client, totpConfig) | ||||||
| @@ -123,22 +124,7 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { | |||||||
| 	userpassPath := fmt.Sprintf("auth/userpass/login/%s", testuser1) | 	userpassPath := fmt.Sprintf("auth/userpass/login/%s", testuser1) | ||||||
|  |  | ||||||
| 	// MFA single-phase login | 	// MFA single-phase login | ||||||
| 	time.Sleep(5 * time.Second) | 	verifyLoginRequest := func(secret *api.Secret) { | ||||||
| 	var secret *api.Secret |  | ||||||
| 	testhelpers.RetryUntil(t, 20*time.Second, func() error { |  | ||||||
| 		var err error |  | ||||||
| 		totpPasscode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) |  | ||||||
|  |  | ||||||
| 		userClient1.AddHeader("X-Vault-MFA", fmt.Sprintf("%s:%s", methodID, totpPasscode)) |  | ||||||
| 		secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{ |  | ||||||
| 			"password": "testpassword", |  | ||||||
| 		}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return fmt.Errorf("MFA failed: %w", err) |  | ||||||
| 		} |  | ||||||
| 		return nil |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 		userpassToken := secret.Auth.ClientToken | 		userpassToken := secret.Auth.ClientToken | ||||||
| 		userClient1.SetToken(client.Token()) | 		userClient1.SetToken(client.Token()) | ||||||
| 		secret, err := userClient1.Logical().WriteWithContext(context.Background(), "auth/token/lookup", map[string]interface{}{ | 		secret, err := userClient1.Logical().WriteWithContext(context.Background(), "auth/token/lookup", map[string]interface{}{ | ||||||
| @@ -152,11 +138,44 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { | |||||||
| 		if entityIDCheck != entityID1 { | 		if entityIDCheck != entityID1 { | ||||||
| 			t.Fatalf("different entityID assigned") | 			t.Fatalf("different entityID assigned") | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// helper function to clear the MFA request header | ||||||
|  | 	clearMFARequestHeaders := func(c *api.Client) { | ||||||
|  | 		headers := c.Headers() | ||||||
|  | 		headers.Del("X-Vault-MFA") | ||||||
|  | 		c.SetHeaders(headers) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var secret *api.Secret | ||||||
|  | 	var err error | ||||||
|  | 	var methodIdentifier string | ||||||
|  |  | ||||||
|  | 	singlePhaseLoginFunc := func() error { | ||||||
|  | 		totpPasscode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) | ||||||
|  | 		userClient1.AddHeader("X-Vault-MFA", fmt.Sprintf("%s:%s", methodIdentifier, totpPasscode)) | ||||||
|  | 		defer clearMFARequestHeaders(userClient1) | ||||||
|  | 		secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{ | ||||||
|  | 			"password": "testpassword", | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("MFA failed for identifier %s: %v", methodIdentifier, err) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// single phase login for both method name and method ID | ||||||
|  | 	methodIdentifier = totpConfig["method_name"].(string) | ||||||
|  | 	testhelpers.RetryUntilAtCadence(t, 20*time.Second, 100*time.Millisecond, singlePhaseLoginFunc) | ||||||
|  | 	verifyLoginRequest(secret) | ||||||
|  |  | ||||||
|  | 	methodIdentifier = methodID | ||||||
|  | 	// Need to wait a bit longer to avoid hitting maximum allowed consecutive | ||||||
|  | 	// failed TOTP validation | ||||||
|  | 	testhelpers.RetryUntilAtCadence(t, 20*time.Second, time.Duration(waitPeriod)*time.Second, singlePhaseLoginFunc) | ||||||
|  | 	verifyLoginRequest(secret) | ||||||
|  |  | ||||||
| 	// Two-phase login | 	// Two-phase login | ||||||
| 	headers := userClient1.Headers() |  | ||||||
| 	headers.Del("X-Vault-MFA") |  | ||||||
| 	userClient1.SetHeaders(headers) |  | ||||||
| 	secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{ | 	secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{ | ||||||
| 		"password": "testpassword", | 		"password": "testpassword", | ||||||
| 	}) | 	}) | ||||||
| @@ -191,26 +210,43 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// validation | 	// validation | ||||||
| 	time.Sleep(5 * time.Second) | 	var mfaReqID string | ||||||
| 	var totpPasscode1 string | 	var totpPasscode1 string | ||||||
| 	testhelpers.RetryUntil(t, 20*time.Second, func() error { | 	mfaValidateFunc := func() error { | ||||||
| 		totpPasscode1 = testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) | 		totpPasscode1 = testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) | ||||||
|  |  | ||||||
| 		secret, err = userClient1.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ | 		secret, err = userClient1.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ | ||||||
| 			"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, | 			"mfa_request_id": mfaReqID, | ||||||
| 			"mfa_payload": map[string][]string{ | 			"mfa_payload": map[string][]string{ | ||||||
| 				methodID: {totpPasscode1}, | 				methodIdentifier: {totpPasscode1}, | ||||||
| 			}, | 			}, | ||||||
| 		}) | 		}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return fmt.Errorf("MFA failed: %w", err) | 			return fmt.Errorf("MFA failed: %v", err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return nil |  | ||||||
| 	}) |  | ||||||
| 		if secret.Auth == nil || secret.Auth.ClientToken == "" { | 		if secret.Auth == nil || secret.Auth.ClientToken == "" { | ||||||
| 			t.Fatalf("successful mfa validation did not return a client token") | 			t.Fatalf("successful mfa validation did not return a client token") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	methodIdentifier = methodID | ||||||
|  | 	mfaReqID = secret.Auth.MFARequirement.MFARequestID | ||||||
|  | 	testhelpers.RetryUntilAtCadence(t, 20*time.Second, time.Duration(waitPeriod)*time.Second, mfaValidateFunc) | ||||||
|  |  | ||||||
|  | 	// two phase login with method name | ||||||
|  | 	secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{ | ||||||
|  | 		"password": "testpassword", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("MFA failed: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	methodIdentifier = totpConfig["method_name"].(string) | ||||||
|  | 	mfaReqID = secret.Auth.MFARequirement.MFARequestID | ||||||
|  | 	testhelpers.RetryUntilAtCadence(t, 20*time.Second, time.Duration(waitPeriod)*time.Second, mfaValidateFunc) | ||||||
|  |  | ||||||
|  | 	// checking audit log | ||||||
| 	if noop.Req == nil { | 	if noop.Req == nil { | ||||||
| 		t.Fatalf("no request was logged in audit log") | 		t.Fatalf("no request was logged in audit log") | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/vault" | 	"github.com/hashicorp/vault/vault" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // TestLoginMFA_Method_CRUD tests creating/reading/updating/deleting a method config for all of the MFA providers | // TestLoginMFA_Method_CRUD tests creating/reading/updating/deleting a method config for all the MFA providers | ||||||
| func TestLoginMFA_Method_CRUD(t *testing.T) { | func TestLoginMFA_Method_CRUD(t *testing.T) { | ||||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||||
| 		CredentialBackends: map[string]logical.Factory{ | 		CredentialBackends: map[string]logical.Factory{ | ||||||
| @@ -216,6 +216,126 @@ func TestLoginMFA_Method_CRUD(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestLoginMFAMethodName(t *testing.T) { | ||||||
|  | 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||||
|  | 		CredentialBackends: map[string]logical.Factory{ | ||||||
|  | 			"userpass": userpass.Factory, | ||||||
|  | 		}, | ||||||
|  | 	}, &vault.TestClusterOptions{ | ||||||
|  | 		HandlerFunc: vaulthttp.Handler, | ||||||
|  | 	}) | ||||||
|  | 	cluster.Start() | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	core := cluster.Cores[0].Core | ||||||
|  | 	vault.TestWaitActive(t, core) | ||||||
|  | 	client := cluster.Cores[0].Client | ||||||
|  |  | ||||||
|  | 	// Enable userpass authentication | ||||||
|  | 	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ | ||||||
|  | 		Type: "userpass", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to enable userpass auth: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	auths, err := client.Sys().ListAuth() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	mountAccessor := auths["userpass/"].Accessor | ||||||
|  |  | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		methodType string | ||||||
|  | 		configData map[string]interface{} | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			"totp", | ||||||
|  | 			map[string]interface{}{ | ||||||
|  | 				"issuer":      "yCorp", | ||||||
|  | 				"method_name": "totp-method", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"duo", | ||||||
|  | 			map[string]interface{}{ | ||||||
|  | 				"mount_accessor":  mountAccessor, | ||||||
|  | 				"secret_key":      "lol-secret", | ||||||
|  | 				"integration_key": "integration-key", | ||||||
|  | 				"api_hostname":    "some-hostname", | ||||||
|  | 				"method_name":     "duo-method", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"okta", | ||||||
|  | 			map[string]interface{}{ | ||||||
|  | 				"mount_accessor": mountAccessor, | ||||||
|  | 				"base_url":       "example.com", | ||||||
|  | 				"org_name":       "my-org", | ||||||
|  | 				"api_token":      "lol-token", | ||||||
|  | 				"method_name":    "okta-method", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"pingid", | ||||||
|  | 			map[string]interface{}{ | ||||||
|  | 				"mount_accessor":       mountAccessor, | ||||||
|  | 				"settings_file_base64": "I0F1dG8tR2VuZXJhdGVkIGZyb20gUGluZ09uZSwgZG93bmxvYWRlZCBieSBpZD1bU1NPXSBlbWFpbD1baGFtaWRAaGFzaGljb3JwLmNvbV0KI1dlZCBEZWMgMTUgMTM6MDg6NDQgTVNUIDIwMjEKdXNlX2Jhc2U2NF9rZXk9YlhrdGMyVmpjbVYwTFd0bGVRPT0KdXNlX3NpZ25hdHVyZT10cnVlCnRva2VuPWxvbC10b2tlbgppZHBfdXJsPWh0dHBzOi8vaWRweG55bDNtLnBpbmdpZGVudGl0eS5jb20vcGluZ2lkCm9yZ19hbGlhcz1sb2wtb3JnLWFsaWFzCmFkbWluX3VybD1odHRwczovL2lkcHhueWwzbS5waW5naWRlbnRpdHkuY29tL3BpbmdpZAphdXRoZW50aWNhdG9yX3VybD1odHRwczovL2F1dGhlbnRpY2F0b3IucGluZ29uZS5jb20vcGluZ2lkL3BwbQ==", | ||||||
|  | 				"method_name":          "pingid-method", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		t.Run(tc.methodType, func(t *testing.T) { | ||||||
|  | 			// create a new method config | ||||||
|  | 			myPath := fmt.Sprintf("identity/mfa/method/%s", tc.methodType) | ||||||
|  | 			resp, err := client.Logical().Write(myPath, tc.configData) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			methodId := resp.Data["method_id"] | ||||||
|  | 			if methodId == "" { | ||||||
|  | 				t.Fatal("method id is empty") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// creating an MFA config with the same name should not return a new method ID | ||||||
|  | 			resp, err = client.Logical().Write(myPath, tc.configData) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  | 			if methodId != resp.Data["method_id"] { | ||||||
|  | 				t.Fatal("trying to create a new MFA config with the same name should not result in a new MFA config") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			originalName := tc.configData["method_name"] | ||||||
|  |  | ||||||
|  | 			// create a new MFA config name | ||||||
|  | 			tc.configData["method_name"] = "newName" | ||||||
|  | 			resp, err = client.Logical().Write(myPath, tc.configData) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			myNewPath := fmt.Sprintf("%s/%s", myPath, methodId) | ||||||
|  |  | ||||||
|  | 			// Updating an existing MFA config with another config's name | ||||||
|  | 			resp, err = client.Logical().Write(myNewPath, tc.configData) | ||||||
|  | 			if err == nil { | ||||||
|  | 				t.Fatalf("expected a failure for configuring an MFA method with an existing MFA method name, %v", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Create a method with a / in the name | ||||||
|  | 			tc.configData["method_name"] = fmt.Sprintf("ns1/%s", originalName) | ||||||
|  | 			_, err = client.Logical().Write(myNewPath, tc.configData) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // TestLoginMFA_ListAllMFAConfigs tests listing all configs globally | // TestLoginMFA_ListAllMFAConfigs tests listing all configs globally | ||||||
| func TestLoginMFA_ListAllMFAConfigsGlobally(t *testing.T) { | func TestLoginMFA_ListAllMFAConfigsGlobally(t *testing.T) { | ||||||
| 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | 	cluster := vault.NewTestCluster(t, &vault.CoreConfig{ | ||||||
|   | |||||||
| @@ -170,6 +170,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path { | |||||||
| 		{ | 		{ | ||||||
| 			Pattern: "mfa/method/totp" + genericOptionalUUIDRegex("method_id"), | 			Pattern: "mfa/method/totp" + genericOptionalUUIDRegex("method_id"), | ||||||
| 			Fields: map[string]*framework.FieldSchema{ | 			Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 				"method_name": { | ||||||
|  | 					Type:        framework.TypeString, | ||||||
|  | 					Description: `The unique name identifier for this MFA method.`, | ||||||
|  | 				}, | ||||||
| 				"method_id": { | 				"method_id": { | ||||||
| 					Type:        framework.TypeString, | 					Type:        framework.TypeString, | ||||||
| 					Description: `The unique identifier for this MFA method.`, | 					Description: `The unique identifier for this MFA method.`, | ||||||
| @@ -298,6 +302,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path { | |||||||
| 		{ | 		{ | ||||||
| 			Pattern: "mfa/method/okta" + genericOptionalUUIDRegex("method_id"), | 			Pattern: "mfa/method/okta" + genericOptionalUUIDRegex("method_id"), | ||||||
| 			Fields: map[string]*framework.FieldSchema{ | 			Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 				"method_name": { | ||||||
|  | 					Type:        framework.TypeString, | ||||||
|  | 					Description: `The unique name identifier for this MFA method.`, | ||||||
|  | 				}, | ||||||
| 				"method_id": { | 				"method_id": { | ||||||
| 					Type:        framework.TypeString, | 					Type:        framework.TypeString, | ||||||
| 					Description: `The unique identifier for this MFA method.`, | 					Description: `The unique identifier for this MFA method.`, | ||||||
| @@ -354,6 +362,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path { | |||||||
| 		{ | 		{ | ||||||
| 			Pattern: "mfa/method/duo" + genericOptionalUUIDRegex("method_id"), | 			Pattern: "mfa/method/duo" + genericOptionalUUIDRegex("method_id"), | ||||||
| 			Fields: map[string]*framework.FieldSchema{ | 			Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 				"method_name": { | ||||||
|  | 					Type:        framework.TypeString, | ||||||
|  | 					Description: `The unique name identifier for this MFA method.`, | ||||||
|  | 				}, | ||||||
| 				"method_id": { | 				"method_id": { | ||||||
| 					Type:        framework.TypeString, | 					Type:        framework.TypeString, | ||||||
| 					Description: `The unique identifier for this MFA method.`, | 					Description: `The unique identifier for this MFA method.`, | ||||||
| @@ -410,6 +422,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path { | |||||||
| 		{ | 		{ | ||||||
| 			Pattern: "mfa/method/pingid" + genericOptionalUUIDRegex("method_id"), | 			Pattern: "mfa/method/pingid" + genericOptionalUUIDRegex("method_id"), | ||||||
| 			Fields: map[string]*framework.FieldSchema{ | 			Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 				"method_name": { | ||||||
|  | 					Type:        framework.TypeString, | ||||||
|  | 					Description: `The unique name identifier for this MFA method.`, | ||||||
|  | 				}, | ||||||
| 				"method_id": { | 				"method_id": { | ||||||
| 					Type:        framework.TypeString, | 					Type:        framework.TypeString, | ||||||
| 					Description: `The unique identifier for this MFA method.`, | 					Description: `The unique identifier for this MFA method.`, | ||||||
|   | |||||||
| @@ -269,6 +269,7 @@ func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *lo | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	methodID := d.Get("method_id").(string) | 	methodID := d.Get("method_id").(string) | ||||||
|  | 	methodName := d.Get("method_name").(string) | ||||||
|  |  | ||||||
| 	b := i.mfaBackend | 	b := i.mfaBackend | ||||||
| 	b.mfaLock.Lock() | 	b.mfaLock.Lock() | ||||||
| @@ -286,6 +287,23 @@ func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *lo | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// check if an MFA method configuration exists with that method name | ||||||
|  | 	if methodName != "" { | ||||||
|  | 		namedMfaConfig, err := b.MemDBMFAConfigByName(ctx, methodName) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if namedMfaConfig != nil { | ||||||
|  | 			if mConfig == nil { | ||||||
|  | 				mConfig = namedMfaConfig | ||||||
|  | 			} else { | ||||||
|  | 				if mConfig.ID != namedMfaConfig.ID { | ||||||
|  | 					return nil, fmt.Errorf("a login MFA method configuration with the method name %s already exists", methodName) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if mConfig == nil { | 	if mConfig == nil { | ||||||
| 		configID, err := uuid.GenerateUUID() | 		configID, err := uuid.GenerateUUID() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -298,6 +316,11 @@ func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *lo | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Updating the method config name | ||||||
|  | 	if methodName != "" { | ||||||
|  | 		mConfig.Name = methodName | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	mfaNs, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID) | 	mfaNs, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -647,6 +670,50 @@ func (b *LoginMFABackend) loginMFAMethodExistenceCheck(eConfig *mfa.MFAEnforceme | |||||||
| 	return aggErr.ErrorOrNil() | 	return aggErr.ErrorOrNil() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // sanitizeMFACredsWithLoginEnforcementMethodIDs updates the MFACred map | ||||||
|  | // looping through the matched login enforcement configurations, and | ||||||
|  | // replacing MFA method names with MFA method IDs | ||||||
|  | func (b *LoginMFABackend) sanitizeMFACredsWithLoginEnforcementMethodIDs(ctx context.Context, mfaCredsMap logical.MFACreds, mfaMethodIDs []string) (logical.MFACreds, error) { | ||||||
|  | 	sanitizedMfaCreds := make(logical.MFACreds, 0) | ||||||
|  | 	var multiError *multierror.Error | ||||||
|  | 	for _, methodID := range mfaMethodIDs { | ||||||
|  | 		val, ok := mfaCredsMap[methodID] | ||||||
|  | 		if ok { | ||||||
|  | 			sanitizedMfaCreds[methodID] = val | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		mConfig, err := b.MemDBMFAConfigByID(methodID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		// method name in the MFACredsMap should be the method full name, | ||||||
|  | 		// i.e., namespacePath+name. This is because, a user in a child | ||||||
|  | 		// namespace can reference an MFA method ID in a parent namespace | ||||||
|  | 		configNS, err := NamespaceByID(ctx, mConfig.NamespaceID, b.Core) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if configNS != nil { | ||||||
|  | 			val, ok = mfaCredsMap[configNS.Path+mConfig.Name] | ||||||
|  | 			if ok { | ||||||
|  | 				sanitizedMfaCreds[mConfig.ID] = val | ||||||
|  | 			} else { | ||||||
|  | 				multiError = multierror.Append(multiError, fmt.Errorf("failed to find MFA credentials associated with an MFA method ID %v, method name %v", methodID, configNS.Path+mConfig.Name)) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			multiError = multierror.Append(multiError, fmt.Errorf("failed to find the namespace associated with an MFA method ID %v", mConfig.ID)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// we don't need to find every MFA method identifiers in the MFA header | ||||||
|  | 	// So, don't return errors if that is the case. | ||||||
|  | 	if len(sanitizedMfaCreds) > 0 { | ||||||
|  | 		return sanitizedMfaCreds, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return sanitizedMfaCreds, multiError | ||||||
|  | } | ||||||
|  |  | ||||||
| func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) { | func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) { | ||||||
| 	// mfaReqID is the ID of the login request | 	// mfaReqID is the ID of the login request | ||||||
| 	mfaReqID := d.Get("mfa_request_id").(string) | 	mfaReqID := d.Get("mfa_request_id").(string) | ||||||
| @@ -655,13 +722,13 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// a map of methodID to passcode | 	// a map of methodID to passcode | ||||||
| 	methodIDToPasscodeInterface := d.Get("mfa_payload") | 	mfaPayload := d.Get("mfa_payload") | ||||||
| 	if methodIDToPasscodeInterface == nil { | 	if mfaPayload == nil { | ||||||
| 		return logical.ErrorResponse("missing mfa payload"), nil | 		return logical.ErrorResponse("missing mfa payload"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var mfaCreds logical.MFACreds | 	var mfaCreds logical.MFACreds | ||||||
| 	err := mapstructure.Decode(methodIDToPasscodeInterface, &mfaCreds) | 	err := mapstructure.Decode(mfaPayload, &mfaCreds) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return logical.ErrorResponse("invalid mfa payload"), nil | 		return logical.ErrorResponse("invalid mfa payload"), nil | ||||||
| 	} | 	} | ||||||
| @@ -1574,11 +1641,19 @@ func parseOktaConfig(mConfig *mfa.Config, d *framework.FieldData) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Core) validateLoginMFA(ctx context.Context, eConfig *mfa.MFAEnforcementConfig, entity *identity.Entity, requestConnRemoteAddr string, mfaCredsMap logical.MFACreds) error { | func (c *Core) validateLoginMFA(ctx context.Context, eConfig *mfa.MFAEnforcementConfig, entity *identity.Entity, requestConnRemoteAddr string, mfaCredsMap logical.MFACreds) error { | ||||||
|  | 	sanitizedMfaCreds, err := c.loginMFABackend.sanitizeMFACredsWithLoginEnforcementMethodIDs(ctx, mfaCredsMap, eConfig.MFAMethodIDs) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to sanitize MFA creds, %w", err) | ||||||
|  | 	} | ||||||
|  | 	if len(sanitizedMfaCreds) == 0 && len(eConfig.MFAMethodIDs) > 0 { | ||||||
|  | 		return fmt.Errorf("login MFA validation failed for methodID: %v", eConfig.MFAMethodIDs) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var retErr error | 	var retErr error | ||||||
| 	for _, methodID := range eConfig.MFAMethodIDs { | 	for _, methodID := range eConfig.MFAMethodIDs { | ||||||
| 		// as configID is the same as methodID, and methodID is unique, we can | 		// as configID is the same as methodID, and methodID is unique, we can | ||||||
| 		// use it to retrieve the MFACreds | 		// use it to retrieve the MFACreds | ||||||
| 		mfaCreds, ok := mfaCredsMap[methodID] | 		mfaCreds, ok := sanitizedMfaCreds[methodID] | ||||||
| 		if !ok || mfaCreds == nil { | 		if !ok || mfaCreds == nil { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| @@ -1634,6 +1709,11 @@ func (c *Core) validateLoginMFAInternal(ctx context.Context, methodID string, en | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	mfaFactors, err := parseMfaFactors(mfaCreds) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to parse MFA factor, %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	switch mConfig.Type { | 	switch mConfig.Type { | ||||||
| 	case mfaMethodTypeTOTP: | 	case mfaMethodTypeTOTP: | ||||||
| 		// Get the MFA secret data required to validate the supplied credentials | 		// Get the MFA secret data required to validate the supplied credentials | ||||||
| @@ -1645,17 +1725,13 @@ func (c *Core) validateLoginMFAInternal(ctx context.Context, methodID string, en | |||||||
| 			return fmt.Errorf("MFA secret for method name %q not present in entity %q", mConfig.Name, entity.ID) | 			return fmt.Errorf("MFA secret for method name %q not present in entity %q", mConfig.Name, entity.ID) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if mfaCreds == nil { | 		return c.validateTOTP(ctx, mfaFactors, entityMFASecret, mConfig.ID, entity.ID, c.loginMFABackend.usedCodes, mConfig.GetTOTPConfig().MaxValidationAttempts) | ||||||
| 			return fmt.Errorf("MFA credentials not supplied") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return c.validateTOTP(ctx, mfaCreds, entityMFASecret, mConfig.ID, entity.ID, c.loginMFABackend.usedCodes, mConfig.GetTOTPConfig().MaxValidationAttempts) |  | ||||||
|  |  | ||||||
| 	case mfaMethodTypeOkta: | 	case mfaMethodTypeOkta: | ||||||
| 		return c.validateOkta(ctx, mConfig, finalUsername) | 		return c.validateOkta(ctx, mConfig, finalUsername) | ||||||
|  |  | ||||||
| 	case mfaMethodTypeDuo: | 	case mfaMethodTypeDuo: | ||||||
| 		return c.validateDuo(ctx, mfaCreds, mConfig, finalUsername, reqConnectionRemoteAddress) | 		return c.validateDuo(ctx, mfaFactors, mConfig, finalUsername, reqConnectionRemoteAddress) | ||||||
|  |  | ||||||
| 	case mfaMethodTypePingID: | 	case mfaMethodTypePingID: | ||||||
| 		return c.validatePingID(ctx, mConfig, finalUsername) | 		return c.validatePingID(ctx, mConfig, finalUsername) | ||||||
| @@ -1764,23 +1840,52 @@ func formatUsername(format string, alias *identity.Alias, entity *identity.Entit | |||||||
| 	return username | 	return username | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Core) validateDuo(ctx context.Context, creds []string, mConfig *mfa.Config, username, reqConnectionRemoteAddr string) error { | type MFAFactor struct { | ||||||
|  | 	passcode string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseMfaFactors(creds []string) (*MFAFactor, error) { | ||||||
|  | 	mfaFactor := &MFAFactor{} | ||||||
|  |  | ||||||
|  | 	for _, cred := range creds { | ||||||
|  | 		switch { | ||||||
|  | 		case cred == "": // for the case of push notification | ||||||
|  | 			continue | ||||||
|  | 		case strings.HasPrefix(cred, "passcode="): | ||||||
|  | 			if mfaFactor.passcode != "" { | ||||||
|  | 				return nil, fmt.Errorf("found multiple passcodes for the same MFA method") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			splits := strings.SplitN(cred, "=", 2) | ||||||
|  | 			if splits[1] == "" { | ||||||
|  | 				return nil, fmt.Errorf("invalid passcode") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			mfaFactor.passcode = splits[1] | ||||||
|  | 		case strings.Contains(cred, "="): | ||||||
|  | 			return nil, fmt.Errorf("found an invalid MFA cred: %v", cred) | ||||||
|  | 		default: | ||||||
|  | 			// a non-empty cred that does not match the above | ||||||
|  | 			// means it is a passcode | ||||||
|  | 			if mfaFactor.passcode != "" { | ||||||
|  | 				return nil, fmt.Errorf("found multiple passcodes for the same MFA method") | ||||||
|  | 			} | ||||||
|  | 			mfaFactor.passcode = cred | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return mfaFactor, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Core) validateDuo(ctx context.Context, mfaFactors *MFAFactor, mConfig *mfa.Config, username, reqConnectionRemoteAddr string) error { | ||||||
| 	duoConfig := mConfig.GetDuoConfig() | 	duoConfig := mConfig.GetDuoConfig() | ||||||
| 	if duoConfig == nil { | 	if duoConfig == nil { | ||||||
| 		return fmt.Errorf("failed to get Duo configuration for method %q", mConfig.Name) | 		return fmt.Errorf("failed to get Duo configuration for method %q", mConfig.Name) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	passcode := "" | 	var passcode string | ||||||
| 	for _, cred := range creds { | 	if mfaFactors != nil { | ||||||
| 		if strings.HasPrefix(cred, "passcode") { | 		passcode = mfaFactors.passcode | ||||||
| 			splits := strings.SplitN(cred, "=", 2) |  | ||||||
| 			if len(splits) != 2 { |  | ||||||
| 				return fmt.Errorf("invalid credential %q", cred) |  | ||||||
| 			} |  | ||||||
| 			if splits[0] == "passcode" { |  | ||||||
| 				passcode = splits[1] |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	client := duoapi.NewDuoApi( | 	client := duoapi.NewDuoApi( | ||||||
| @@ -2229,21 +2334,18 @@ func (c *Core) validatePingID(ctx context.Context, mConfig *mfa.Config, username | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Core) validateTOTP(ctx context.Context, creds []string, entityMethodSecret *mfa.Secret, configID, entityID string, usedCodes *cache.Cache, maximumValidationAttempts uint32) error { | func (c *Core) validateTOTP(ctx context.Context, mfaFactors *MFAFactor, entityMethodSecret *mfa.Secret, configID, entityID string, usedCodes *cache.Cache, maximumValidationAttempts uint32) error { | ||||||
| 	if len(creds) == 0 { | 	if mfaFactors.passcode == "" { | ||||||
| 		return fmt.Errorf("missing TOTP passcode") | 		return fmt.Errorf("MFA credentials not supplied") | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(creds) > 1 { |  | ||||||
| 		return fmt.Errorf("more than one TOTP passcode supplied") |  | ||||||
| 	} | 	} | ||||||
|  | 	passcode := mfaFactors.passcode | ||||||
|  |  | ||||||
| 	totpSecret := entityMethodSecret.GetTOTPSecret() | 	totpSecret := entityMethodSecret.GetTOTPSecret() | ||||||
| 	if totpSecret == nil { | 	if totpSecret == nil { | ||||||
| 		return fmt.Errorf("entity does not contain the TOTP secret") | 		return fmt.Errorf("entity does not contain the TOTP secret") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	usedName := fmt.Sprintf("%s_%s", configID, creds[0]) | 	usedName := fmt.Sprintf("%s_%s", configID, passcode) | ||||||
|  |  | ||||||
| 	_, ok := usedCodes.Get(usedName) | 	_, ok := usedCodes.Get(usedName) | ||||||
| 	if ok { | 	if ok { | ||||||
| @@ -2290,7 +2392,7 @@ func (c *Core) validateTOTP(ctx context.Context, creds []string, entityMethodSec | |||||||
| 		Algorithm: otplib.Algorithm(int(totpSecret.Algorithm)), | 		Algorithm: otplib.Algorithm(int(totpSecret.Algorithm)), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	valid, err := totplib.ValidateCustom(creds[0], key, time.Now(), validateOpts) | 	valid, err := totplib.ValidateCustom(passcode, key, time.Now(), validateOpts) | ||||||
| 	if err != nil && err != otplib.ErrValidateInputInvalidLength { | 	if err != nil && err != otplib.ErrValidateInputInvalidLength { | ||||||
| 		return errwrap.Wrapf("failed to validate TOTP passcode: {{err}}", err) | 		return errwrap.Wrapf("failed to validate TOTP passcode: {{err}}", err) | ||||||
| 	} | 	} | ||||||
| @@ -2340,6 +2442,21 @@ func loginMFAConfigTableSchema() *memdb.TableSchema { | |||||||
| 					Field: "Type", | 					Field: "Type", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			"name": { | ||||||
|  | 				Name:         "name", | ||||||
|  | 				Unique:       true, | ||||||
|  | 				AllowMissing: true, | ||||||
|  | 				Indexer: &memdb.CompoundIndex{ | ||||||
|  | 					Indexes: []memdb.Indexer{ | ||||||
|  | 						&memdb.StringFieldIndex{ | ||||||
|  | 							Field: "NamespaceID", | ||||||
|  | 						}, | ||||||
|  | 						&memdb.StringFieldIndex{ | ||||||
|  | 							Field: "Name", | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -2487,6 +2604,47 @@ func (b *LoginMFABackend) MemDBMFAConfigByID(mConfigID string) (*mfa.Config, err | |||||||
| 	return b.MemDBMFAConfigByIDInTxn(txn, mConfigID) | 	return b.MemDBMFAConfigByIDInTxn(txn, mConfigID) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (b *LoginMFABackend) MemDBMFAConfigByNameInTxn(ctx context.Context, txn *memdb.Txn, mConfigName string) (*mfa.Config, error) { | ||||||
|  | 	if mConfigName == "" { | ||||||
|  | 		return nil, fmt.Errorf("missing config name") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if txn == nil { | ||||||
|  | 		return nil, fmt.Errorf("txn is nil") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ns, err := namespace.FromContext(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	mConfigRaw, err := txn.First(b.methodTable, "name", ns.ID, mConfigName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to fetch MFA config from memdb using name: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if mConfigRaw == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	mConfig, ok := mConfigRaw.(*mfa.Config) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, fmt.Errorf("failed to declare the type of fetched MFA config") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return mConfig.Clone() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *LoginMFABackend) MemDBMFAConfigByName(ctx context.Context, name string) (*mfa.Config, error) { | ||||||
|  | 	if name == "" { | ||||||
|  | 		return nil, fmt.Errorf("missing config name") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn := b.db.Txn(false) | ||||||
|  |  | ||||||
|  | 	return b.MemDBMFAConfigByNameInTxn(ctx, txn, name) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId string) (*mfa.MFAEnforcementConfig, error) { | func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId string) (*mfa.MFAEnforcementConfig, error) { | ||||||
| 	if name == "" { | 	if name == "" { | ||||||
| 		return nil, fmt.Errorf("missing config name") | 		return nil, fmt.Errorf("missing config name") | ||||||
|   | |||||||
							
								
								
									
										61
									
								
								vault/login_mfa_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								vault/login_mfa_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | package vault | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestParseFactors(t *testing.T) { | ||||||
|  | 	testcases := []struct { | ||||||
|  | 		name                string | ||||||
|  | 		invalidMFAHeaderVal []string | ||||||
|  | 		expectedError       string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			"two headers with passcode", | ||||||
|  | 			[]string{"passcode", "foo"}, | ||||||
|  | 			"found multiple passcodes for the same MFA method", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"single header with passcode=", | ||||||
|  | 			[]string{"passcode="}, | ||||||
|  | 			"invalid passcode", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"single invalid header", | ||||||
|  | 			[]string{"foo="}, | ||||||
|  | 			"found an invalid MFA cred", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"single header equal char", | ||||||
|  | 			[]string{"=="}, | ||||||
|  | 			"found an invalid MFA cred", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"two headers with passcode=", | ||||||
|  | 			[]string{"passcode=foo", "foo"}, | ||||||
|  | 			"found multiple passcodes for the same MFA method", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"two headers invalid name", | ||||||
|  | 			[]string{"passcode=foo", "passcode=bar"}, | ||||||
|  | 			"found multiple passcodes for the same MFA method", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"two headers, two invalid", | ||||||
|  | 			[]string{"foo", "bar"}, | ||||||
|  | 			"found multiple passcodes for the same MFA method", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tc := range testcases { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			_, err := parseMfaFactors(tc.invalidMFAHeaderVal) | ||||||
|  | 			if err == nil { | ||||||
|  | 				t.Fatal("nil error returned") | ||||||
|  | 			} | ||||||
|  | 			if !strings.Contains(err.Error(), tc.expectedError) { | ||||||
|  | 				t.Fatalf("expected %s, got %v", tc.expectedError, err) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -2059,6 +2059,7 @@ func (c *Core) buildMfaEnforcementResponse(eConfig *mfa.MFAEnforcementConfig) (* | |||||||
| 			Type:         mConfig.Type, | 			Type:         mConfig.Type, | ||||||
| 			ID:           methodID, | 			ID:           methodID, | ||||||
| 			UsesPasscode: mConfig.Type == mfaMethodTypeTOTP || duoUsePasscode, | 			UsesPasscode: mConfig.Type == mfaMethodTypeTOTP || duoUsePasscode, | ||||||
|  | 			Name:         mConfig.Name, | ||||||
| 		} | 		} | ||||||
| 		mfaAny.Any = append(mfaAny.Any, mfaMethod) | 		mfaAny.Any = append(mfaAny.Any, mfaMethod) | ||||||
| 	} | 	} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Hamid Ghaf
					Hamid Ghaf