refactor(portal): add retry logic to Stripe API client (#9466)

Why:

* We've seen some Stripe API requests come back with 429 (rate limited)
responses, which could likely be retried successfully. This commit adds
basic retry logic with exponential backoff to our Stripe API client.
This commit is contained in:
Brian Manifold
2025-06-09 16:11:33 -07:00
committed by GitHub
parent 0e5e2296a8
commit 6d425d5677
5 changed files with 368 additions and 38 deletions

View File

@@ -1,5 +1,6 @@
defmodule Domain.Billing.Stripe.APIClient do
use Supervisor
require Logger
@pool_name __MODULE__.Finch
@@ -37,7 +38,7 @@ defmodule Domain.Billing.Stripe.APIClient do
|> put_if_not_nil("email", email)
|> URI.encode_query(:www_form)
request(api_token, :post, "customers", body)
request_with_retry(api_token, :post, "customers", body)
end
def update_customer(api_token, customer_id, name, metadata) do
@@ -51,11 +52,11 @@ defmodule Domain.Billing.Stripe.APIClient do
|> Map.put("name", name)
|> URI.encode_query(:www_form)
request(api_token, :post, "customers/#{customer_id}", body)
request_with_retry(api_token, :post, "customers/#{customer_id}", body)
end
def fetch_customer(api_token, customer_id) do
request(api_token, :get, "customers/#{customer_id}", "")
request_with_retry(api_token, :get, "customers/#{customer_id}", "")
end
def list_all_subscriptions(api_token, page_after \\ nil, acc \\ []) do
@@ -66,7 +67,7 @@ defmodule Domain.Billing.Stripe.APIClient do
""
end
case request(api_token, :get, "subscriptions#{query_params}", "") do
case request_with_retry(api_token, :get, "subscriptions#{query_params}", "") do
{:ok, %{"has_more" => true, "data" => data}} ->
page_after = List.last(data)["id"]
list_all_subscriptions(api_token, page_after, acc ++ data)
@@ -80,12 +81,12 @@ defmodule Domain.Billing.Stripe.APIClient do
end
def fetch_product(api_token, product_id) do
request(api_token, :get, "products/#{product_id}", "")
request_with_retry(api_token, :get, "products/#{product_id}", "")
end
def create_billing_portal_session(api_token, customer_id, return_url) do
body = URI.encode_query(%{"customer" => customer_id, "return_url" => return_url}, :www_form)
request(api_token, :post, "billing_portal/sessions", body)
request_with_retry(api_token, :post, "billing_portal/sessions", body)
end
def create_subscription(api_token, customer_id, price_id) do
@@ -98,7 +99,67 @@ defmodule Domain.Billing.Stripe.APIClient do
:www_form
)
request(api_token, :post, "subscriptions", body)
request_with_retry(api_token, :post, "subscriptions", body)
end
# Performs a Stripe API request, retrying when the API answers with HTTP 429.
#
# Retry behaviour comes from the `:retry_config` keyword list in this module's
# application config; when a key is missing the fallbacks below are used
# (3 retries, 1000 ms base delay, 10_000 ms maximum delay).
#
# Returns whatever the final attempt of `request/4` returned.
def request_with_retry(api_token, method, path, body) do
  do_request_with_retry(
    api_token,
    method,
    path,
    body,
    # Attempt counter starts at zero; it is incremented once per retry.
    0,
    fetch_retry_config(:max_retries, 3),
    fetch_retry_config(:base_delay_ms, 1000),
    fetch_retry_config(:max_delay_ms, 10_000)
  )
end
# Recursive worker behind `request_with_retry/4`.
#
# `attempt` counts the retries already performed (0 on the first call).
# A 429 error is retried only while `attempt < max_retries`; once the budget
# is used up the 429 error tuple is handed back to the caller unchanged.
# Successful responses and all other errors are returned immediately.
defp do_request_with_retry(
       api_token,
       method,
       path,
       body,
       attempt,
       max_retries,
       base_delay,
       max_delay
     ) do
  case request(api_token, method, path, body) do
    {:error, {429, response}} when attempt < max_retries ->
      backoff_ms = calculate_delay(attempt, base_delay, max_delay)
      next_attempt = attempt + 1

      Logger.warning(
        "Rate limited by Stripe API (429), retrying request.",
        request_delay: "#{backoff_ms}ms",
        attempt_num: "#{next_attempt} of #{max_retries}",
        response: inspect(response)
      )

      # Blocks the calling process for the backoff period before retrying.
      Process.sleep(backoff_ms)

      do_request_with_retry(
        api_token,
        method,
        path,
        body,
        next_attempt,
        max_retries,
        base_delay,
        max_delay
      )

    {:ok, _response} = success ->
      success

    {:error, _reason} = failure ->
      failure
  end
end
# Computes the delay before the next retry: exponential backoff
# (`base_delay * 2^attempt`) plus up to 10% random jitter, rounded to an
# integer and capped at `max_delay`.
defp calculate_delay(attempt, base_delay, max_delay) do
  backoff = base_delay * :math.pow(2, attempt)
  # :rand.uniform/0 yields a float, so the jitter is a fraction of the backoff.
  jitter = :rand.uniform() * 0.1 * backoff

  # The cap is applied after the jitter, so the result never exceeds max_delay.
  min(round(backoff + jitter), max_delay)
end
def request(api_token, method, path, body) do
@@ -141,4 +202,13 @@ defmodule Domain.Billing.Stripe.APIClient do
Domain.Config.fetch_env!(:domain, __MODULE__)
|> Keyword.fetch!(key)
end
# Reads a single retry setting from the optional `:retry_config` keyword list
# in this module's application config. Falls back to `default` when either the
# `:retry_config` entry or the requested `key` is absent.
defp fetch_retry_config(key, default) do
  retry_config =
    :domain
    |> Domain.Config.fetch_env!(__MODULE__)
    |> Keyword.get(:retry_config)

  if is_nil(retry_config) do
    default
  else
    Keyword.get(retry_config, key, default)
  end
end
end

View File

@@ -0,0 +1,43 @@
defmodule Domain.Billing.Stripe.APIClientTest do
  use Domain.DataCase, async: true
  alias Domain.Mocks.Stripe
  import Domain.Billing.Stripe.APIClient

  describe "client retry logic" do
    test "retries on 429 rate limit responses" do
      bypass = Bypass.open()
      account = Fixtures.Accounts.create_account()

      # Configure to return 429 on first 2 requests, then succeed
      Stripe.enable_rate_limiting(2)
      Stripe.mock_create_customer_endpoint(bypass, account)

      # This should succeed after 2 retries. Using `assert` (instead of a bare
      # match) reports a proper ExUnit failure with a diff rather than a
      # MatchError when the result has an unexpected shape.
      assert {:ok, _customer} =
               create_customer("secret_token_123", account.name, "test@example.com", %{})

      # Verify 3 total requests were made (2 failures + 1 success)
      assert Stripe.get_request_count() == 3

      # Must run in the test process: the mock keys its state on self().
      Stripe.disable_rate_limiting()
    end

    test "gives up after max retries exceeded" do
      bypass = Bypass.open()
      account = Fixtures.Accounts.create_account()

      # Configure to always return 429
      Stripe.configure_rate_limiting(fn _count -> true end)
      Stripe.mock_create_customer_endpoint(bypass, account)

      # This should fail after exhausting retries
      assert {:error, {429, _}} =
               create_customer("secret_token_123", account.name, "test@example.com", %{})

      # Should have made max_retries + 1 attempts (max_retries is 3 in the
      # test configuration).
      assert Stripe.get_request_count() == 4

      # Must run in the test process: the mock keys its state on self().
      Stripe.disable_rate_limiting()
    end
  end
end

View File

@@ -13,21 +13,30 @@ defmodule Domain.Mocks.Stripe do
test_pid = self()
Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
# Store test PID in connection for rate limiting
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
email = Map.get(conn.params, "email", "foo@example.com")
case check_rate_limit(conn) do
{:rate_limited, response} ->
response
resp =
Map.merge(
customer_object("cus_NffrFeUfNV2Hib", account.name, email, %{
"account_id" => account.id
}),
resp
)
:ok ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
email = Map.get(conn.params, "email", "foo@example.com")
resp =
Map.merge(
customer_object("cus_NffrFeUfNV2Hib", account.name, email, %{
"account_id" => account.id
}),
resp
)
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
end
end)
override_endpoint_url("http://localhost:#{bypass.port}")
@@ -49,10 +58,18 @@ defmodule Domain.Mocks.Stripe do
test_pid = self()
Bypass.expect(bypass, "POST", customer_endpoint_path, fn conn ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
case check_rate_limit(conn) do
{:rate_limited, response} ->
response
:ok ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
end
end)
override_endpoint_url("http://localhost:#{bypass.port}")
@@ -74,9 +91,17 @@ defmodule Domain.Mocks.Stripe do
test_pid = self()
Bypass.expect(bypass, "GET", customer_endpoint_path, fn conn ->
conn = Plug.Conn.fetch_query_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
case check_rate_limit(conn) do
{:rate_limited, response} ->
response
:ok ->
conn = Plug.Conn.fetch_query_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
end
end)
override_endpoint_url("http://localhost:#{bypass.port}")
@@ -115,9 +140,17 @@ defmodule Domain.Mocks.Stripe do
test_pid = self()
Bypass.expect(bypass, "GET", product_endpoint_path, fn conn ->
conn = Plug.Conn.fetch_query_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
case check_rate_limit(conn) do
{:rate_limited, response} ->
response
:ok ->
conn = Plug.Conn.fetch_query_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
end
end)
override_endpoint_url("http://localhost:#{bypass.port}")
@@ -150,10 +183,18 @@ defmodule Domain.Mocks.Stripe do
test_pid = self()
Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
case check_rate_limit(conn) do
{:rate_limited, response} ->
response
:ok ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
end
end)
override_endpoint_url("http://localhost:#{bypass.port}")
@@ -173,10 +214,18 @@ defmodule Domain.Mocks.Stripe do
test_pid = self()
Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
case check_rate_limit(conn) do
{:rate_limited, response} ->
response
:ok ->
conn = Plug.Conn.fetch_query_params(conn)
conn = fetch_request_params(conn)
send(test_pid, {:bypass_request, conn})
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
end
end)
override_endpoint_url("http://localhost:#{bypass.port}")
@@ -184,6 +233,160 @@ defmodule Domain.Mocks.Stripe do
bypass
end
# Rate limiting control functions
# Enables count-based rate limiting for the calling process.
#
# Stores a fresh state (request counter reset to 0) keyed by `self()`, so any
# previously stored state for this process is replaced. `rate_limit_count` is
# the number of leading requests the mock should answer with a 429.
def enable_rate_limiting(rate_limit_count \\ 2) do
  ensure_ets_table()

  :ets.insert(
    :stripe_mock_state,
    {self(),
     %{
       rate_limit_enabled: true,
       rate_limit_count: rate_limit_count,
       request_count: 0
     }}
  )
end
# Removes all rate limiting state (including the request counter) stored for
# the calling process.
def disable_rate_limiting do
  ensure_ets_table()
  :ets.delete(:stripe_mock_state, self())
end
# Sets the stored request counter for the calling process back to 0 while
# leaving the rest of its rate limiting state untouched. Does nothing (and
# returns `:ok`) when no state is stored for this process.
def reset_rate_limit_counter do
  ensure_ets_table()
  owner = self()

  case :ets.lookup(:stripe_mock_state, owner) do
    [] ->
      :ok

    [{^owner, stored}] ->
      :ets.insert(:stripe_mock_state, {owner, Map.put(stored, :request_count, 0)})
  end
end
# Returns the number of requests counted for the calling process, or 0 when
# no rate limiting state is stored for it.
def get_request_count do
  ensure_ets_table()
  owner = self()

  case :ets.lookup(:stripe_mock_state, owner) do
    [] -> 0
    [{^owner, stored}] -> Map.get(stored, :request_count, 0)
  end
end
# Rate limiting with flexible patterns
# Installs a custom rate limiting rule for the calling process.
#
# Accepts either:
#   * a 1-arity function that receives the (1-based) request number and
#     returns a truthy value when that request should be rate limited, or
#   * a keyword list with `:fail_on_attempts` - the request numbers to rate
#     limit (defaults to `[1, 2]`).
#
# Either form replaces any state already stored for this process and resets
# its request counter to 0.
def configure_rate_limiting(pattern) when is_function(pattern, 1) do
  ensure_ets_table()

  :ets.insert(
    :stripe_mock_state,
    {self(), %{rate_limit_pattern: pattern, request_count: 0}}
  )
end

def configure_rate_limiting(opts) when is_list(opts) do
  attempts_to_fail = Keyword.get(opts, :fail_on_attempts, [1, 2])

  # The list form is sugar for a membership predicate; reuse the function clause.
  configure_rate_limiting(fn count -> count in attempts_to_fail end)
end
# Makes sure the named, public `:stripe_mock_state` ETS table exists,
# creating it on first use.
#
# The existence check and the creation are two separate steps, so two
# processes can both observe `:undefined` and both try to create the table;
# `:ets.new/2` raises ArgumentError for the loser because the name is already
# taken. That race is tolerated here, since all we need is for the table to
# exist afterwards.
#
# NOTE(review): an ETS table is owned by the process that created it and is
# deleted when that process exits. Whichever process gets here first becomes
# the owner - confirm that this lifetime is long enough for all users, or
# create the table from a long-lived process instead.
defp ensure_ets_table do
  case :ets.whereis(:stripe_mock_state) do
    :undefined ->
      try do
        :ets.new(:stripe_mock_state, [:named_table, :public, :set])
      rescue
        # Lost the creation race: another process created the table first.
        ArgumentError -> :ok
      end

    _table ->
      :ok
  end
end
# Counts the incoming request against the owning test process and decides
# whether it should be rate limited.
#
# Returns `{:rate_limited, conn}` (with the 429 response already sent) when
# the stored configuration says this request must fail, otherwise `:ok`.
# When no state is stored for the process, nothing is persisted and the
# request is allowed.
defp check_rate_limit(conn) do
  ensure_ets_table()

  # The state is keyed by the test process that set up the mock.
  owner = get_test_pid(conn)

  # Bump and persist the request counter when state exists for this owner.
  {seen, state} =
    case :ets.lookup(:stripe_mock_state, owner) do
      [{^owner, stored}] ->
        bumped = Map.get(stored, :request_count, 0) + 1
        next_state = Map.put(stored, :request_count, bumped)
        :ets.insert(:stripe_mock_state, {owner, next_state})
        {bumped, next_state}

      [] ->
        {1, %{request_count: 1}}
    end

  pattern_fn = Map.get(state, :rate_limit_pattern)

  limited? =
    cond do
      # A custom pattern function takes precedence over the simple counter.
      pattern_fn ->
        pattern_fn.(seen)

      # Simple mode: the first `rate_limit_count` requests are rejected.
      Map.get(state, :rate_limit_enabled) ->
        seen <= Map.get(state, :rate_limit_count, 2)

      true ->
        false
    end

  if limited? do
    {:rate_limited, send_rate_limit_response(conn)}
  else
    :ok
  end
end
# Resolves which process's rate limiting state applies to `conn`.
#
# Prefers the `:test_pid` stored in `conn.private`. When that is missing it
# falls back to the first key in the state table, and finally to `self()`
# when the table is empty.
#
# NOTE(review): the table fallback picks an arbitrary entry, which may belong
# to a different test when several tests store state at the same time.
defp get_test_pid(conn) do
  with nil <- Map.get(conn.private, :test_pid),
       :"$end_of_table" <- :ets.first(:stripe_mock_state) do
    self()
  end
end
# Sends an HTTP 429 response whose JSON body is a fixed Stripe-style error
# object, and returns the resulting conn.
defp send_rate_limit_response(conn) do
  payload =
    Jason.encode!(%{
      "error" => %{
        "code" => "lock_timeout",
        "doc_url" => "https://stripe.com/docs/error-codes/lock-timeout",
        "message" => "Error message here",
        "request_log_url" => "https://dashboard.stripe.com/logs/req_ABC123DEF456",
        "type" => "invalid_request_error"
      }
    })

  Plug.Conn.send_resp(conn, 429, payload)
end
def customer_object(id, name, email \\ nil, metadata \\ %{}) do
%{
"id" => id,

View File

@@ -92,7 +92,12 @@ config :domain, Domain.Auth.Adapters.Okta.APIClient, finch_transport_opts: []
config :domain, Domain.Billing.Stripe.APIClient,
endpoint: "https://api.stripe.com",
finch_transport_opts: []
finch_transport_opts: [],
retry_config: [
max_retries: 3,
base_delay_ms: 1000,
max_delay_ms: 10_000
]
config :domain, Domain.Billing,
enabled: true,

View File

@@ -26,6 +26,15 @@ config :domain, Domain.Events.ReplicationConnection,
database: "firezone_test#{partition_suffix}"
]
# Test settings for the Stripe API client.
# The retry delays are lower than the client's built-in fallbacks
# (1000 ms base / 10_000 ms max) - presumably so tests that exercise the
# retry path do not spend seconds sleeping; max_retries stays at 3.
config :domain, Domain.Billing.Stripe.APIClient,
endpoint: "https://api.stripe.com",
finch_transport_opts: [],
retry_config: [
max_retries: 3,
base_delay_ms: 100,
max_delay_ms: 1000
]
config :domain, Domain.Telemetry, enabled: false
config :domain, Domain.ConnectivityChecks, enabled: false