mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
refactor(portal): add retry logic to Stripe API client (#9466)
Why: * We've seen some Stripe API requests come back with 429 responses, which likely could be retried and succeed. This commit adds some basic retry logic to our Stripe API client.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
defmodule Domain.Billing.Stripe.APIClient do
|
||||
use Supervisor
|
||||
require Logger
|
||||
|
||||
@pool_name __MODULE__.Finch
|
||||
|
||||
@@ -37,7 +38,7 @@ defmodule Domain.Billing.Stripe.APIClient do
|
||||
|> put_if_not_nil("email", email)
|
||||
|> URI.encode_query(:www_form)
|
||||
|
||||
request(api_token, :post, "customers", body)
|
||||
request_with_retry(api_token, :post, "customers", body)
|
||||
end
|
||||
|
||||
def update_customer(api_token, customer_id, name, metadata) do
|
||||
@@ -51,11 +52,11 @@ defmodule Domain.Billing.Stripe.APIClient do
|
||||
|> Map.put("name", name)
|
||||
|> URI.encode_query(:www_form)
|
||||
|
||||
request(api_token, :post, "customers/#{customer_id}", body)
|
||||
request_with_retry(api_token, :post, "customers/#{customer_id}", body)
|
||||
end
|
||||
|
||||
def fetch_customer(api_token, customer_id) do
|
||||
request(api_token, :get, "customers/#{customer_id}", "")
|
||||
request_with_retry(api_token, :get, "customers/#{customer_id}", "")
|
||||
end
|
||||
|
||||
def list_all_subscriptions(api_token, page_after \\ nil, acc \\ []) do
|
||||
@@ -66,7 +67,7 @@ defmodule Domain.Billing.Stripe.APIClient do
|
||||
""
|
||||
end
|
||||
|
||||
case request(api_token, :get, "subscriptions#{query_params}", "") do
|
||||
case request_with_retry(api_token, :get, "subscriptions#{query_params}", "") do
|
||||
{:ok, %{"has_more" => true, "data" => data}} ->
|
||||
page_after = List.last(data)["id"]
|
||||
list_all_subscriptions(api_token, page_after, acc ++ data)
|
||||
@@ -80,12 +81,12 @@ defmodule Domain.Billing.Stripe.APIClient do
|
||||
end
|
||||
|
||||
def fetch_product(api_token, product_id) do
|
||||
request(api_token, :get, "products/#{product_id}", "")
|
||||
request_with_retry(api_token, :get, "products/#{product_id}", "")
|
||||
end
|
||||
|
||||
def create_billing_portal_session(api_token, customer_id, return_url) do
|
||||
body = URI.encode_query(%{"customer" => customer_id, "return_url" => return_url}, :www_form)
|
||||
request(api_token, :post, "billing_portal/sessions", body)
|
||||
request_with_retry(api_token, :post, "billing_portal/sessions", body)
|
||||
end
|
||||
|
||||
def create_subscription(api_token, customer_id, price_id) do
|
||||
@@ -98,7 +99,67 @@ defmodule Domain.Billing.Stripe.APIClient do
|
||||
:www_form
|
||||
)
|
||||
|
||||
request(api_token, :post, "subscriptions", body)
|
||||
request_with_retry(api_token, :post, "subscriptions", body)
|
||||
end
|
||||
|
||||
def request_with_retry(api_token, method, path, body) do
|
||||
max_retries = fetch_retry_config(:max_retries, 3)
|
||||
base_delay = fetch_retry_config(:base_delay_ms, 1000)
|
||||
max_delay = fetch_retry_config(:max_delay_ms, 10_000)
|
||||
|
||||
do_request_with_retry(api_token, method, path, body, 0, max_retries, base_delay, max_delay)
|
||||
end
|
||||
|
||||
defp do_request_with_retry(
|
||||
api_token,
|
||||
method,
|
||||
path,
|
||||
body,
|
||||
attempt,
|
||||
max_retries,
|
||||
base_delay,
|
||||
max_delay
|
||||
) do
|
||||
case request(api_token, method, path, body) do
|
||||
{:ok, response} ->
|
||||
{:ok, response}
|
||||
|
||||
{:error, {429, response}} when attempt < max_retries ->
|
||||
delay = calculate_delay(attempt, base_delay, max_delay)
|
||||
|
||||
Logger.warning(
|
||||
"Rate limited by Stripe API (429), retrying request.",
|
||||
request_delay: "#{delay}ms",
|
||||
attempt_num: "#{attempt + 1} of #{max_retries}",
|
||||
response: inspect(response)
|
||||
)
|
||||
|
||||
Process.sleep(delay)
|
||||
|
||||
do_request_with_retry(
|
||||
api_token,
|
||||
method,
|
||||
path,
|
||||
body,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
base_delay,
|
||||
max_delay
|
||||
)
|
||||
|
||||
{:error, reason} ->
|
||||
{:error, reason}
|
||||
end
|
||||
end
|
||||
|
||||
defp calculate_delay(attempt, base_delay, max_delay) do
|
||||
# Exponential backoff with jitter
|
||||
exponential_delay = base_delay * :math.pow(2, attempt)
|
||||
jitter = :rand.uniform() * 0.1 * exponential_delay
|
||||
|
||||
(exponential_delay + jitter)
|
||||
|> round()
|
||||
|> min(max_delay)
|
||||
end
|
||||
|
||||
def request(api_token, method, path, body) do
|
||||
@@ -141,4 +202,13 @@ defmodule Domain.Billing.Stripe.APIClient do
|
||||
Domain.Config.fetch_env!(:domain, __MODULE__)
|
||||
|> Keyword.fetch!(key)
|
||||
end
|
||||
|
||||
defp fetch_retry_config(key, default) do
|
||||
config = Domain.Config.fetch_env!(:domain, __MODULE__)
|
||||
|
||||
case Keyword.get(config, :retry_config) do
|
||||
nil -> default
|
||||
retry_config -> Keyword.get(retry_config, key, default)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
defmodule Domain.Billing.Stripe.APIClientTest do
|
||||
use Domain.DataCase, async: true
|
||||
alias Domain.Mocks.Stripe
|
||||
import Domain.Billing.Stripe.APIClient
|
||||
|
||||
describe "client retry logic" do
|
||||
test "retries on 429 rate limit responses" do
|
||||
bypass = Bypass.open()
|
||||
account = Fixtures.Accounts.create_account()
|
||||
|
||||
# Configure to return 429 on first 2 requests, then succeed
|
||||
Stripe.enable_rate_limiting(2)
|
||||
Stripe.mock_create_customer_endpoint(bypass, account)
|
||||
|
||||
# This should succeed after 2 retries
|
||||
{:ok, _customer} =
|
||||
create_customer("secret_token_123", account.name, "test@example.com", %{})
|
||||
|
||||
# Verify 3 total requests were made (2 failures + 1 success)
|
||||
assert Stripe.get_request_count() == 3
|
||||
|
||||
Domain.Mocks.Stripe.disable_rate_limiting()
|
||||
end
|
||||
|
||||
test "gives up after max retries exceeded" do
|
||||
bypass = Bypass.open()
|
||||
account = Fixtures.Accounts.create_account()
|
||||
|
||||
# Configure to always return 429
|
||||
Stripe.configure_rate_limiting(fn _count -> true end)
|
||||
Stripe.mock_create_customer_endpoint(bypass, account)
|
||||
|
||||
# This should fail after exhausting retries
|
||||
{:error, {429, _}} =
|
||||
create_customer("secret_token_123", account.name, "test@example.com", %{})
|
||||
|
||||
# Should have made max_retries + 1 attempts
|
||||
assert Stripe.get_request_count() == 4
|
||||
|
||||
Domain.Mocks.Stripe.disable_rate_limiting()
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -13,21 +13,30 @@ defmodule Domain.Mocks.Stripe do
|
||||
test_pid = self()
|
||||
|
||||
Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
# Store test PID in connection for rate limiting
|
||||
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
|
||||
|
||||
email = Map.get(conn.params, "email", "foo@example.com")
|
||||
case check_rate_limit(conn) do
|
||||
{:rate_limited, response} ->
|
||||
response
|
||||
|
||||
resp =
|
||||
Map.merge(
|
||||
customer_object("cus_NffrFeUfNV2Hib", account.name, email, %{
|
||||
"account_id" => account.id
|
||||
}),
|
||||
resp
|
||||
)
|
||||
:ok ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
email = Map.get(conn.params, "email", "foo@example.com")
|
||||
|
||||
resp =
|
||||
Map.merge(
|
||||
customer_object("cus_NffrFeUfNV2Hib", account.name, email, %{
|
||||
"account_id" => account.id
|
||||
}),
|
||||
resp
|
||||
)
|
||||
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
end
|
||||
end)
|
||||
|
||||
override_endpoint_url("http://localhost:#{bypass.port}")
|
||||
@@ -49,10 +58,18 @@ defmodule Domain.Mocks.Stripe do
|
||||
test_pid = self()
|
||||
|
||||
Bypass.expect(bypass, "POST", customer_endpoint_path, fn conn ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
|
||||
|
||||
case check_rate_limit(conn) do
|
||||
{:rate_limited, response} ->
|
||||
response
|
||||
|
||||
:ok ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
end
|
||||
end)
|
||||
|
||||
override_endpoint_url("http://localhost:#{bypass.port}")
|
||||
@@ -74,9 +91,17 @@ defmodule Domain.Mocks.Stripe do
|
||||
test_pid = self()
|
||||
|
||||
Bypass.expect(bypass, "GET", customer_endpoint_path, fn conn ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
|
||||
|
||||
case check_rate_limit(conn) do
|
||||
{:rate_limited, response} ->
|
||||
response
|
||||
|
||||
:ok ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
end
|
||||
end)
|
||||
|
||||
override_endpoint_url("http://localhost:#{bypass.port}")
|
||||
@@ -115,9 +140,17 @@ defmodule Domain.Mocks.Stripe do
|
||||
test_pid = self()
|
||||
|
||||
Bypass.expect(bypass, "GET", product_endpoint_path, fn conn ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
|
||||
|
||||
case check_rate_limit(conn) do
|
||||
{:rate_limited, response} ->
|
||||
response
|
||||
|
||||
:ok ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
end
|
||||
end)
|
||||
|
||||
override_endpoint_url("http://localhost:#{bypass.port}")
|
||||
@@ -150,10 +183,18 @@ defmodule Domain.Mocks.Stripe do
|
||||
test_pid = self()
|
||||
|
||||
Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
|
||||
|
||||
case check_rate_limit(conn) do
|
||||
{:rate_limited, response} ->
|
||||
response
|
||||
|
||||
:ok ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
end
|
||||
end)
|
||||
|
||||
override_endpoint_url("http://localhost:#{bypass.port}")
|
||||
@@ -173,10 +214,18 @@ defmodule Domain.Mocks.Stripe do
|
||||
test_pid = self()
|
||||
|
||||
Bypass.expect(bypass, "POST", customers_endpoint_path, fn conn ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
conn = %{conn | private: Map.put(conn.private, :test_pid, test_pid)}
|
||||
|
||||
case check_rate_limit(conn) do
|
||||
{:rate_limited, response} ->
|
||||
response
|
||||
|
||||
:ok ->
|
||||
conn = Plug.Conn.fetch_query_params(conn)
|
||||
conn = fetch_request_params(conn)
|
||||
send(test_pid, {:bypass_request, conn})
|
||||
Plug.Conn.send_resp(conn, 200, Jason.encode!(resp))
|
||||
end
|
||||
end)
|
||||
|
||||
override_endpoint_url("http://localhost:#{bypass.port}")
|
||||
@@ -184,6 +233,160 @@ defmodule Domain.Mocks.Stripe do
|
||||
bypass
|
||||
end
|
||||
|
||||
# Rate limiting control functions
|
||||
def enable_rate_limiting(rate_limit_count \\ 2) do
|
||||
ensure_ets_table()
|
||||
test_pid = self()
|
||||
|
||||
state = %{
|
||||
rate_limit_enabled: true,
|
||||
rate_limit_count: rate_limit_count,
|
||||
request_count: 0
|
||||
}
|
||||
|
||||
:ets.insert(:stripe_mock_state, {test_pid, state})
|
||||
end
|
||||
|
||||
def disable_rate_limiting do
|
||||
ensure_ets_table()
|
||||
test_pid = self()
|
||||
:ets.delete(:stripe_mock_state, test_pid)
|
||||
end
|
||||
|
||||
def reset_rate_limit_counter do
|
||||
ensure_ets_table()
|
||||
test_pid = self()
|
||||
|
||||
case :ets.lookup(:stripe_mock_state, test_pid) do
|
||||
[{^test_pid, state}] ->
|
||||
updated_state = Map.put(state, :request_count, 0)
|
||||
:ets.insert(:stripe_mock_state, {test_pid, updated_state})
|
||||
|
||||
[] ->
|
||||
:ok
|
||||
end
|
||||
end
|
||||
|
||||
def get_request_count do
|
||||
ensure_ets_table()
|
||||
test_pid = self()
|
||||
|
||||
case :ets.lookup(:stripe_mock_state, test_pid) do
|
||||
[{^test_pid, state}] -> Map.get(state, :request_count, 0)
|
||||
[] -> 0
|
||||
end
|
||||
end
|
||||
|
||||
# Rate limiting with flexible patterns
|
||||
def configure_rate_limiting(pattern) when is_function(pattern, 1) do
|
||||
ensure_ets_table()
|
||||
test_pid = self()
|
||||
|
||||
state = %{
|
||||
rate_limit_pattern: pattern,
|
||||
request_count: 0
|
||||
}
|
||||
|
||||
:ets.insert(:stripe_mock_state, {test_pid, state})
|
||||
end
|
||||
|
||||
def configure_rate_limiting(opts) when is_list(opts) do
|
||||
ensure_ets_table()
|
||||
test_pid = self()
|
||||
fail_on_attempts = Keyword.get(opts, :fail_on_attempts, [1, 2])
|
||||
pattern = fn count -> count in fail_on_attempts end
|
||||
|
||||
state = %{
|
||||
rate_limit_pattern: pattern,
|
||||
request_count: 0
|
||||
}
|
||||
|
||||
:ets.insert(:stripe_mock_state, {test_pid, state})
|
||||
end
|
||||
|
||||
defp ensure_ets_table do
|
||||
case :ets.whereis(:stripe_mock_state) do
|
||||
:undefined ->
|
||||
:ets.new(:stripe_mock_state, [:named_table, :public, :set])
|
||||
|
||||
_table ->
|
||||
:ok
|
||||
end
|
||||
end
|
||||
|
||||
defp check_rate_limit(conn) do
|
||||
ensure_ets_table()
|
||||
|
||||
# Get the test process PID from the connection
|
||||
test_pid = get_test_pid(conn)
|
||||
|
||||
# Get current state and increment request count
|
||||
{current_count, state} =
|
||||
case :ets.lookup(:stripe_mock_state, test_pid) do
|
||||
[{^test_pid, existing_state}] ->
|
||||
new_count = Map.get(existing_state, :request_count, 0) + 1
|
||||
updated_state = Map.put(existing_state, :request_count, new_count)
|
||||
:ets.insert(:stripe_mock_state, {test_pid, updated_state})
|
||||
{new_count, updated_state}
|
||||
|
||||
[] ->
|
||||
{1, %{request_count: 1}}
|
||||
end
|
||||
|
||||
cond do
|
||||
# Check for custom pattern function
|
||||
pattern_fn = Map.get(state, :rate_limit_pattern) ->
|
||||
if pattern_fn.(current_count) do
|
||||
{:rate_limited, send_rate_limit_response(conn)}
|
||||
else
|
||||
:ok
|
||||
end
|
||||
|
||||
# Check for simple rate limiting
|
||||
Map.get(state, :rate_limit_enabled) ->
|
||||
rate_limit_count = Map.get(state, :rate_limit_count, 2)
|
||||
|
||||
if current_count <= rate_limit_count do
|
||||
{:rate_limited, send_rate_limit_response(conn)}
|
||||
else
|
||||
:ok
|
||||
end
|
||||
|
||||
true ->
|
||||
:ok
|
||||
end
|
||||
end
|
||||
|
||||
defp get_test_pid(conn) do
|
||||
# Look for the test PID in the connection's private assigns
|
||||
case Map.get(conn.private, :test_pid) do
|
||||
nil ->
|
||||
# Fallback: try to find any active test process
|
||||
case :ets.first(:stripe_mock_state) do
|
||||
:"$end_of_table" -> self()
|
||||
pid -> pid
|
||||
end
|
||||
|
||||
pid ->
|
||||
pid
|
||||
end
|
||||
end
|
||||
|
||||
defp send_rate_limit_response(conn) do
|
||||
error_response = %{
|
||||
"error" => %{
|
||||
"code" => "lock_timeout",
|
||||
"doc_url" => "https://stripe.com/docs/error-codes/lock-timeout",
|
||||
"message" => "Error message here",
|
||||
"request_log_url" => "https://dashboard.stripe.com/logs/req_ABC123DEF456",
|
||||
"type" => "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
conn
|
||||
|> Plug.Conn.send_resp(429, Jason.encode!(error_response))
|
||||
end
|
||||
|
||||
def customer_object(id, name, email \\ nil, metadata \\ %{}) do
|
||||
%{
|
||||
"id" => id,
|
||||
|
||||
@@ -92,7 +92,12 @@ config :domain, Domain.Auth.Adapters.Okta.APIClient, finch_transport_opts: []
|
||||
|
||||
config :domain, Domain.Billing.Stripe.APIClient,
|
||||
endpoint: "https://api.stripe.com",
|
||||
finch_transport_opts: []
|
||||
finch_transport_opts: [],
|
||||
retry_config: [
|
||||
max_retries: 3,
|
||||
base_delay_ms: 1000,
|
||||
max_delay_ms: 10_000
|
||||
]
|
||||
|
||||
config :domain, Domain.Billing,
|
||||
enabled: true,
|
||||
|
||||
@@ -26,6 +26,15 @@ config :domain, Domain.Events.ReplicationConnection,
|
||||
database: "firezone_test#{partition_suffix}"
|
||||
]
|
||||
|
||||
config :domain, Domain.Billing.Stripe.APIClient,
|
||||
endpoint: "https://api.stripe.com",
|
||||
finch_transport_opts: [],
|
||||
retry_config: [
|
||||
max_retries: 3,
|
||||
base_delay_ms: 100,
|
||||
max_delay_ms: 1000
|
||||
]
|
||||
|
||||
config :domain, Domain.Telemetry, enabled: false
|
||||
|
||||
config :domain, Domain.ConnectivityChecks, enabled: false
|
||||
|
||||
Reference in New Issue
Block a user