diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 7decab0b5..6bf6a4c9a 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -366,9 +366,11 @@ fn connect( get_user_agent(Some(os_version), env!("CARGO_PKG_VERSION")), "client", (), - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) - .build(), + || { + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) + .build() + }, tcp_socket_factory, )?; let session = Session::connect( diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 00f82fc11..49a1c73cb 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -242,9 +242,11 @@ impl WrappedSession { get_user_agent(os_version_override, env!("CARGO_PKG_VERSION")), "client", (), - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) - .build(), + || { + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) + .build() + }, Arc::new(socket_factory::tcp), )?; let session = Session::connect( diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index f7c3cc2e3..cebd788f1 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -129,9 +129,11 @@ async fn run(login: LoginUrl) -> Result { get_user_agent(None, env!("CARGO_PKG_VERSION")), PHOENIX_TOPIC, (), - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(None) - .build(), + || { + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(None) + .build() + }, Arc::new(tcp_socket_factory), )?; diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index a8756b8e4..f5de4195f 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -585,9 +585,11 @@ impl<'a> Handler<'a> { get_user_agent(None, "1.3.14"), "client", (), - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(Some(Duration::from_secs(60 * 60 * 24 * 30))) - .build(), + || { + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(Some(Duration::from_secs(60 * 60 * 24 * 30))) + .build() + }, Arc::new(tcp_socket_factory), )?; // Turn this `io::Error` directly into an `Error` so we can distinguish it from others in the GUI client. diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index 327f135cc..cbddd927d 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -218,9 +218,11 @@ fn main() -> Result<()> { get_user_agent(None, env!("CARGO_PKG_VERSION")), "client", (), - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(max_partition_time) - .build(), + move || { + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(max_partition_time) + .build() + }, Arc::new(tcp_socket_factory), )?; let session = Session::connect( diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 25f86bd2e..a7f4e2e98 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -54,7 +54,8 @@ pub struct PhoenixChannel { url_prototype: Secret>, last_url: Option, user_agent: String, - reconnect_backoff: ExponentialBackoff, + make_reconnect_backoff: Box ExponentialBackoff + Send>, + reconnect_backoff: Option, resolved_addresses: Vec, @@ -135,8 +136,8 @@ pub enum Error { Client(StatusCode), #[error("token expired")] TokenExpired, - #[error("max retries reached")] - MaxRetriesReached, + #[error("max retries reached: {final_error}")] + MaxRetriesReached { final_error: String }, #[error("login failed: {0}")] LoginFailed(ErrorReply), } @@ -146,7 +147,7 @@ impl Error { match self { Error::Client(s) => s == &StatusCode::UNAUTHORIZED || s == &StatusCode::FORBIDDEN, Error::TokenExpired => true, - Error::MaxRetriesReached => false, + Error::MaxRetriesReached { .. } => false, Error::LoginFailed(_) => false, } } @@ -259,7 +260,7 @@ where user_agent: String, login: &'static str, init_req: TInitReq, - reconnect_backoff: ExponentialBackoff, + make_reconnect_backoff: impl Fn() -> ExponentialBackoff + Send + 'static, socket_factory: Arc>, ) -> io::Result { let next_request_id = Arc::new(AtomicU64::new(0)); @@ -276,7 +277,8 @@ where .collect(); Ok(Self { - reconnect_backoff, + make_reconnect_backoff: Box::new(make_reconnect_backoff), + reconnect_backoff: None, url_prototype: url, user_agent, state: State::Closed, @@ -332,7 +334,7 @@ where } // 1. Reset the backoff. - self.reconnect_backoff.reset(); + self.reconnect_backoff = None; // 2. Set state to `Connecting` without a timer. let user_agent = self.user_agent.clone(); @@ -391,7 +393,7 @@ where State::Connected(stream) => stream, State::Connecting(future) => match future.poll_unpin(cx) { Poll::Ready(Ok(stream)) => { - self.reconnect_backoff.reset(); + self.reconnect_backoff = None; self.heartbeat.reset(); self.state = State::Connected(stream); @@ -408,10 +410,18 @@ where return Poll::Ready(Err(Error::Client(r.status()))); } Poll::Ready(Err(e)) => { - let Some(backoff) = self.reconnect_backoff.next_backoff() else { - tracing::warn!("Reconnect backoff expired"); - return Poll::Ready(Err(Error::MaxRetriesReached)); - }; + let socket_addresses = self.socket_addresses(); + + // If we don't have a backoff yet, this is the first error so create one. + let reconnect_backoff = self + .reconnect_backoff + .get_or_insert_with(|| (self.make_reconnect_backoff)()); + + let backoff = reconnect_backoff.next_backoff().ok_or_else(|| { + Error::MaxRetriesReached { + final_error: err_with_src(&e).to_string(), + } + })?; let secret_url = self .last_url @@ -420,9 +430,8 @@ where .clone(); let user_agent = self.user_agent.clone(); let socket_factory = self.socket_factory.clone(); - let socket_addresses = self.socket_addresses(); - tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {}", err_with_src(&e)); + tracing::debug!(?backoff, max_elapsed_time = ?reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {}", err_with_src(&e)); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index ffb350a22..a0bc98b73 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -185,9 +185,11 @@ async fn try_main(args: Args) -> Result<()> { JoinMessage { stamp_secret: server.auth_secret().expose_secret().to_string(), }, - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) - .build(), + || { + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) + .build() + }, Arc::new(socket_factory::tcp), )?; channel.connect(NoParams);