diff --git a/elixir/apps/domain/lib/domain/billing/stripe/api_client.ex b/elixir/apps/domain/lib/domain/billing/stripe/api_client.ex index d9e227208..23b2ffdfe 100644 --- a/elixir/apps/domain/lib/domain/billing/stripe/api_client.ex +++ b/elixir/apps/domain/lib/domain/billing/stripe/api_client.ex @@ -1,5 +1,6 @@ defmodule Domain.Billing.Stripe.APIClient do use Supervisor + require Logger @pool_name __MODULE__.Finch @@ -37,7 +38,7 @@ defmodule Domain.Billing.Stripe.APIClient do |> put_if_not_nil("email", email) |> URI.encode_query(:www_form) - request(api_token, :post, "customers", body) + request_with_retry(api_token, :post, "customers", body) end def update_customer(api_token, customer_id, name, metadata) do @@ -51,11 +52,11 @@ defmodule Domain.Billing.Stripe.APIClient do |> Map.put("name", name) |> URI.encode_query(:www_form) - request(api_token, :post, "customers/#{customer_id}", body) + request_with_retry(api_token, :post, "customers/#{customer_id}", body) end def fetch_customer(api_token, customer_id) do - request(api_token, :get, "customers/#{customer_id}", "") + request_with_retry(api_token, :get, "customers/#{customer_id}", "") end def list_all_subscriptions(api_token, page_after \\ nil, acc \\ []) do @@ -66,7 +67,7 @@ defmodule Domain.Billing.Stripe.APIClient do "" end - case request(api_token, :get, "subscriptions#{query_params}", "") do + case request_with_retry(api_token, :get, "subscriptions#{query_params}", "") do {:ok, %{"has_more" => true, "data" => data}} -> page_after = List.last(data)["id"] list_all_subscriptions(api_token, page_after, acc ++ data) @@ -80,12 +81,12 @@ defmodule Domain.Billing.Stripe.APIClient do end def fetch_product(api_token, product_id) do - request(api_token, :get, "products/#{product_id}", "") + request_with_retry(api_token, :get, "products/#{product_id}", "") end def create_billing_portal_session(api_token, customer_id, return_url) do body = URI.encode_query(%{"customer" => customer_id, "return_url" => return_url}, :www_form) - request(api_token, :post, "billing_portal/sessions", body) + request_with_retry(api_token, :post, "billing_portal/sessions", body) end def create_subscription(api_token, customer_id, price_id) do @@ -98,7 +99,67 @@ defmodule Domain.Billing.Stripe.APIClient do :www_form ) - request(api_token, :post, "subscriptions", body) + request_with_retry(api_token, :post, "subscriptions", body) + end + + def request_with_retry(api_token, method, path, body) do + max_retries = fetch_retry_config(:max_retries, 3) + base_delay = fetch_retry_config(:base_delay_ms, 1000) + max_delay = fetch_retry_config(:max_delay_ms, 10_000) + + do_request_with_retry(api_token, method, path, body, 0, max_retries, base_delay, max_delay) + end + + defp do_request_with_retry( + api_token, + method, + path, + body, + attempt, + max_retries, + base_delay, + max_delay + ) do + case request(api_token, method, path, body) do + {:ok, response} -> + {:ok, response} + + {:error, {429, response}} when attempt < max_retries -> + delay = calculate_delay(attempt, base_delay, max_delay) + + Logger.warning( + "Rate limited by Stripe API (429), retrying request.", + request_delay: "#{delay}ms", + attempt_num: "#{attempt + 1} of #{max_retries}", + response: inspect(response) + ) + + Process.sleep(delay) + + do_request_with_retry( + api_token, + method, + path, + body, + attempt + 1, + max_retries, + base_delay, + max_delay + ) + + {:error, reason} -> + {:error, reason} + end + end + + defp calculate_delay(attempt, base_delay, max_delay) do + # Exponential backoff with jitter + exponential_delay = base_delay * :math.pow(2, attempt) + jitter = :rand.uniform() * 0.1 * exponential_delay + + (exponential_delay + jitter) + |> round() + |> min(max_delay) end def request(api_token, method, path, body) do @@ -141,4 +202,13 @@ defmodule Domain.Billing.Stripe.APIClient do Domain.Config.fetch_env!(:domain, __MODULE__) |> Keyword.fetch!(key) end + + defp fetch_retry_config(key, default) do + config = Domain.Config.fetch_env!(:domain, __MODULE__) + + case Keyword.get(config, :retry_config) do + nil -> default + retry_config -> Keyword.get(retry_config, key, default) + end + end end diff --git a/elixir/apps/domain/test/domain/billing/stripe/api_client_test.exs b/elixir/apps/domain/test/domain/billing/stripe/api_client_test.exs new file mode 100644 index 000000000..a8b853b73 --- /dev/null +++ b/elixir/apps/domain/test/domain/billing/stripe/api_client_test.exs @@ -0,0 +1,43 @@ +defmodule Domain.Billing.Stripe.APIClientTest do + use Domain.DataCase, async: true + alias Domain.Mocks.Stripe + import Domain.Billing.Stripe.APIClient + + describe "client retry logic" do + test "retries on 429 rate limit responses" do + bypass = Bypass.open() + account = Fixtures.Accounts.create_account() + + # Configure to return 429 on first 2 requests, then succeed + Stripe.enable_rate_limiting(2) + Stripe.mock_create_customer_endpoint(bypass, account) + + # This should succeed after 2 retries + {:ok, _customer} = + create_customer("secret_token_123", account.name, "test@example.com", %{}) + + # Verify 3 total requests were made (2 failures + 1 success) + assert Stripe.get_request_count() == 3 + + Domain.Mocks.Stripe.disable_rate_limiting() + end + + test "gives up after max retries exceeded" do + bypass = Bypass.open() + account = Fixtures.Accounts.create_account() + + # Configure to always return 429 + Stripe.configure_rate_limiting(fn _count -> true end) + Stripe.mock_create_customer_endpoint(bypass, account) + + # This should fail after exhausting retries + {:error, {429, _}} = + create_customer("secret_token_123", account.name, "test@example.com", %{}) + + # Should have made max_retries + 1 attempts + assert Stripe.get_request_count() == 4 + + Domain.Mocks.Stripe.disable_rate_limiting() + end + end +end diff --git a/elixir/apps/domain/test/support/mocks/stripe.ex b/elixir/apps/domain/test/support/mocks/stripe.ex index a6803fdea..dc6175213 100644 --- a/elixir/apps/domain/test/support/mocks/stripe.ex +++ b/elixir/apps/domain/test/support/mocks/stripe.ex @@ -13,21 +13,30 @@ defmodule Domain.Mocks.Stripe do test_pid = self() Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn -> - conn = Plug.Conn.fetch_query_params(conn) - conn = fetch_request_params(conn) - send(test_pid, {:bypass_request, conn}) + # Store test PID in connection for rate limiting + conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)} - email = Map.get(conn.params, "email", "foo@example.com") + case check_rate_limit(conn) do + {:rate_limited, response} -> + response - resp = - Map.merge( - customer_object("cus_NffrFeUfNV2Hib", account.name, email, %{ - "account_id" => account.id - }), - resp - ) + :ok -> + conn = Plug.Conn.fetch_query_params(conn) + conn = fetch_request_params(conn) + send(test_pid, {:bypass_request, conn}) - Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + email = Map.get(conn.params, "email", "foo@example.com") + + resp = + Map.merge( + customer_object("cus_NffrFeUfNV2Hib", account.name, email, %{ + "account_id" => account.id + }), + resp + ) + + Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + end end) override_endpoint_url("http://localhost:#{bypass.port}") @@ -49,10 +58,18 @@ defmodule Domain.Mocks.Stripe do test_pid = self() Bypass.expect(bypass, "POST", customer_endpoint_path, fn conn -> - conn = Plug.Conn.fetch_query_params(conn) - conn = fetch_request_params(conn) - send(test_pid, {:bypass_request, conn}) - Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)} + + case check_rate_limit(conn) do + {:rate_limited, response} -> + response + + :ok -> + conn = Plug.Conn.fetch_query_params(conn) + conn = fetch_request_params(conn) + send(test_pid, {:bypass_request, conn}) + Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + end end) override_endpoint_url("http://localhost:#{bypass.port}") @@ -74,9 +91,17 @@ defmodule Domain.Mocks.Stripe do test_pid = self() Bypass.expect(bypass, "GET", customer_endpoint_path, fn conn -> - conn = Plug.Conn.fetch_query_params(conn) - send(test_pid, {:bypass_request, conn}) - Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)} + + case check_rate_limit(conn) do + {:rate_limited, response} -> + response + + :ok -> + conn = Plug.Conn.fetch_query_params(conn) + send(test_pid, {:bypass_request, conn}) + Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + end end) override_endpoint_url("http://localhost:#{bypass.port}") @@ -115,9 +140,17 @@ defmodule Domain.Mocks.Stripe do test_pid = self() Bypass.expect(bypass, "GET", product_endpoint_path, fn conn -> - conn = Plug.Conn.fetch_query_params(conn) - send(test_pid, {:bypass_request, conn}) - Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)} + + case check_rate_limit(conn) do + {:rate_limited, response} -> + response + + :ok -> + conn = Plug.Conn.fetch_query_params(conn) + send(test_pid, {:bypass_request, conn}) + Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + end end) override_endpoint_url("http://localhost:#{bypass.port}") @@ -150,10 +183,18 @@ defmodule Domain.Mocks.Stripe do test_pid = self() Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn -> - conn = Plug.Conn.fetch_query_params(conn) - conn = fetch_request_params(conn) - send(test_pid, {:bypass_request, conn}) - Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)} + + case check_rate_limit(conn) do + {:rate_limited, response} -> + response + + :ok -> + conn = Plug.Conn.fetch_query_params(conn) + conn = fetch_request_params(conn) + send(test_pid, {:bypass_request, conn}) + Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + end end) override_endpoint_url("http://localhost:#{bypass.port}") @@ -173,10 +214,18 @@ defmodule Domain.Mocks.Stripe do test_pid = self() Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn -> - conn = Plug.Conn.fetch_query_params(conn) - conn = fetch_request_params(conn) - send(test_pid, {:bypass_request, conn}) - Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)} + + case check_rate_limit(conn) do + {:rate_limited, response} -> + response + + :ok -> + conn = Plug.Conn.fetch_query_params(conn) + conn = fetch_request_params(conn) + send(test_pid, {:bypass_request, conn}) + Plug.Conn.send_resp(conn, 200, Jason.encode!(resp)) + end end) override_endpoint_url("http://localhost:#{bypass.port}") @@ -184,6 +233,160 @@ defmodule Domain.Mocks.Stripe do bypass end + # Rate limiting control functions + def enable_rate_limiting(rate_limit_count \\ 2) do + ensure_ets_table() + test_pid = self() + + state = %{ + rate_limit_enabled: true, + rate_limit_count: rate_limit_count, + request_count: 0 + } + + :ets.insert(:stripe_mock_state, {test_pid, state}) + end + + def disable_rate_limiting do + ensure_ets_table() + test_pid = self() + :ets.delete(:stripe_mock_state, test_pid) + end + + def reset_rate_limit_counter do + ensure_ets_table() + test_pid = self() + + case :ets.lookup(:stripe_mock_state, test_pid) do + [{^test_pid, state}] -> + updated_state = Map.put(state, :request_count, 0) + :ets.insert(:stripe_mock_state, {test_pid, updated_state}) + + [] -> + :ok + end + end + + def get_request_count do + ensure_ets_table() + test_pid = self() + + case :ets.lookup(:stripe_mock_state, test_pid) do + [{^test_pid, state}] -> Map.get(state, :request_count, 0) + [] -> 0 + end + end + + # Rate limiting with flexible patterns + def configure_rate_limiting(pattern) when is_function(pattern, 1) do + ensure_ets_table() + test_pid = self() + + state = %{ + rate_limit_pattern: pattern, + request_count: 0 + } + + :ets.insert(:stripe_mock_state, {test_pid, state}) + end + + def configure_rate_limiting(opts) when is_list(opts) do + ensure_ets_table() + test_pid = self() + fail_on_attempts = Keyword.get(opts, :fail_on_attempts, [1, 2]) + pattern = fn count -> count in fail_on_attempts end + + state = %{ + rate_limit_pattern: pattern, + request_count: 0 + } + + :ets.insert(:stripe_mock_state, {test_pid, state}) + end + + defp ensure_ets_table do + case :ets.whereis(:stripe_mock_state) do + :undefined -> + :ets.new(:stripe_mock_state, [:named_table, :public, :set]) + + _table -> + :ok + end + end + + defp check_rate_limit(conn) do + ensure_ets_table() + + # Get the test process PID from the connection + test_pid = get_test_pid(conn) + + # Get current state and increment request count + {current_count, state} = + case :ets.lookup(:stripe_mock_state, test_pid) do + [{^test_pid, existing_state}] -> + new_count = Map.get(existing_state, :request_count, 0) + 1 + updated_state = Map.put(existing_state, :request_count, new_count) + :ets.insert(:stripe_mock_state, {test_pid, updated_state}) + {new_count, updated_state} + + [] -> + {1, %{request_count: 1}} + end + + cond do + # Check for custom pattern function + pattern_fn = Map.get(state, :rate_limit_pattern) -> + if pattern_fn.(current_count) do + {:rate_limited, send_rate_limit_response(conn)} + else + :ok + end + + # Check for simple rate limiting + Map.get(state, :rate_limit_enabled) -> + rate_limit_count = Map.get(state, :rate_limit_count, 2) + + if current_count <= rate_limit_count do + {:rate_limited, send_rate_limit_response(conn)} + else + :ok + end + + true -> + :ok + end + end + + defp get_test_pid(conn) do + # Look for the test PID in the connection's private assigns + case Map.get(conn.private, :test_pid) do + nil -> + # Fallback: try to find any active test process + case :ets.first(:stripe_mock_state) do + :"$end_of_table" -> self() + pid -> pid + end + + pid -> + pid + end + end + + defp send_rate_limit_response(conn) do + error_response = %{ + "error" => %{ + "code" => "lock_timeout", + "doc_url" => "https://stripe.com/docs/error-codes/lock-timeout", + "message" => "Error message here", + "request_log_url" => "https://dashboard.stripe.com/logs/req_ABC123DEF456", + "type" => "invalid_request_error" + } + } + + conn + |> Plug.Conn.send_resp(429, Jason.encode!(error_response)) + end + def customer_object(id, name, email \\ nil, metadata \\ %{}) do %{ "id" => id, diff --git a/elixir/config/config.exs b/elixir/config/config.exs index 5e470e94e..f143b638e 100644 --- a/elixir/config/config.exs +++ b/elixir/config/config.exs @@ -92,7 +92,12 @@ config :domain, Domain.Auth.Adapters.Okta.APIClient, finch_transport_opts: [] config :domain, Domain.Billing.Stripe.APIClient, endpoint: "https://api.stripe.com", - finch_transport_opts: [] + finch_transport_opts: [], + retry_config: [ + max_retries: 3, + base_delay_ms: 1000, + max_delay_ms: 10_000 + ] config :domain, Domain.Billing, enabled: true, diff --git a/elixir/config/test.exs b/elixir/config/test.exs index af6a2b598..3d29ce99b 100644 --- a/elixir/config/test.exs +++ b/elixir/config/test.exs @@ -26,6 +26,15 @@ config :domain, Domain.Events.ReplicationConnection, database: "firezone_test#{partition_suffix}" ] +config :domain, Domain.Billing.Stripe.APIClient, + endpoint: "https://api.stripe.com", + finch_transport_opts: [], + retry_config: [ + max_retries: 3, + base_delay_ms: 100, + max_delay_ms: 1000 + ] + config :domain, Domain.Telemetry, enabled: false config :domain, Domain.ConnectivityChecks, enabled: false