From 487110d0b0e8c9557e5f32bae1323a9ea7c47d7b Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Tue, 24 Oct 2023 11:12:18 +1100
Subject: [PATCH] fix(gateway): stop reconnecting on client errors (#2464)

Co-authored-by: Jamil
---
 rust/gateway/src/main.rs        | 55 ++++++++++++++++++++++++++++++---
 rust/phoenix-channel/src/lib.rs |  2 +-
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs
index cb348428e..9bbda7e34 100644
--- a/rust/gateway/src/main.rs
+++ b/rust/gateway/src/main.rs
@@ -10,7 +10,10 @@ use futures::{future, TryFutureExt};
 use phoenix_channel::SecureUrl;
 use secrecy::{Secret, SecretString};
 use std::convert::Infallible;
+use std::pin::pin;
 use std::sync::Arc;
+use tokio::signal::ctrl_c;
+use tokio_tungstenite::tungstenite;
 use tracing_subscriber::layer;
 use url::Url;
 
@@ -30,17 +33,20 @@ async fn main() -> Result<()> {
     )?;
 
     let tunnel = Arc::new(Tunnel::new(private_key, CallbackHandler).await?);
-    tokio::spawn(backoff::future::retry_notify(
+    let task = pin!(backoff::future::retry_notify(
         ExponentialBackoffBuilder::default()
             .with_max_elapsed_time(None)
             .build(),
-        move || run(tunnel.clone(), connect_url.clone()).map_err(backoff::Error::transient),
+        move || run(tunnel.clone(), connect_url.clone()).map_err(to_backoff),
         |error, t| {
             tracing::warn!(retry_in = ?t, "Error connecting to portal: {error:#}");
         },
     ));
+    let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new));
 
-    tokio::signal::ctrl_c().await?;
+    future::try_select(task, ctrl_c)
+        .await
+        .map_err(|e| e.factor_first().0)?;
 
     Ok(())
 }
@@ -64,7 +70,19 @@ async fn run(
 
     let mut eventloop = Eventloop::new(tunnel, portal);
 
-    future::poll_fn(|cx| eventloop.poll(cx)).await
+    future::poll_fn(|cx| eventloop.poll(cx))
+        .await
+        .context("Eventloop failed")
+}
+
+/// Maps our [`anyhow::Error`] to either a permanent or transient [`backoff`] error.
+fn to_backoff(e: anyhow::Error) -> backoff::Error<anyhow::Error> {
+    // As per HTTP spec, retrying client-errors without modifying the request is pointless. Thus we abort the backoff.
+    if e.chain().any(is_client_error) {
+        return backoff::Error::permanent(e);
+    }
+
+    backoff::Error::transient(e)
 }
 
 #[derive(Clone)]
@@ -80,3 +98,32 @@ struct Cli {
     #[command(flatten)]
     common: CommonArgs,
 }
+
+/// Checks whether the given [`std::error::Error`] is in-fact an HTTP error with a 4xx status code.
+fn is_client_error(e: &(dyn std::error::Error + 'static)) -> bool {
+    let Some(tungstenite::Error::Http(r)) = e.downcast_ref() else {
+        return false;
+    };
+
+    r.status().is_client_error()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn backoff_permanent_on_client_error() {
+        let error =
+            anyhow::Error::new(phoenix_channel::Error::WebSocket(tungstenite::Error::Http(
+                tungstenite::http::Response::builder()
+                    .status(400)
+                    .body(None)
+                    .unwrap(),
+            )));
+
+        let backoff = to_backoff(error);
+
+        assert!(matches!(backoff, backoff::Error::Permanent(_)))
+    }
+}
diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs
index fe4db33db..a5ea19e62 100644
--- a/rust/phoenix-channel/src/lib.rs
+++ b/rust/phoenix-channel/src/lib.rs
@@ -94,7 +94,7 @@ pub struct UnexpectedEventDuringInit(String);
 pub enum Error {
     #[error("provided URI is missing a host")]
     MissingHost,
-    #[error(transparent)]
+    #[error("websocket failed")]
     WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
     #[error("failed to serialize message")]
     Serde(#[from] serde_json::Error),