diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 3b92bc898..8c6af57f6 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -19,9 +19,9 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; pub const PHOENIX_TOPIC: &str = "gateway"; -pub struct Eventloop { +pub struct Eventloop<'a> { tunnel: Arc>, - control_rx: mpsc::Receiver<(ClientId, RTCIceCandidate)>, + control_rx: &'a mut mpsc::Receiver<(ClientId, RTCIceCandidate)>, portal: PhoenixChannel, // TODO: Strongly type request reference (currently `String`) @@ -32,12 +32,12 @@ pub struct Eventloop { print_stats_timer: tokio::time::Interval, } -impl Eventloop { +impl<'a> Eventloop<'a> { pub(crate) fn new( tunnel: Arc>, - control_rx: mpsc::Receiver<(ClientId, RTCIceCandidate)>, + control_rx: &'a mut mpsc::Receiver<(ClientId, RTCIceCandidate)>, portal: PhoenixChannel, - ) -> Self { + ) -> Eventloop<'a> { Self { tunnel, control_rx, @@ -54,7 +54,7 @@ impl Eventloop { } } -impl Eventloop { +impl Eventloop<'_> { #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")] pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 63be23b5a..2de06e8c8 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -1,20 +1,24 @@ use crate::control::ControlSignaler; use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use crate::messages::InitGateway; -use anyhow::{Context as _, Result}; +use anyhow::{Context, Result}; +use backoff::backoff::Backoff; use backoff::ExponentialBackoffBuilder; -use boringtun::x25519::StaticSecret; use clap::Parser; +use connlib_shared::messages::ClientId; use connlib_shared::{get_device_id, get_user_agent, login_url, Callbacks, Mode}; use firezone_tunnel::Tunnel; -use futures::{future, TryFutureExt}; +use futures::future; use headless_utils::{setup_global_subscriber, CommonArgs}; use phoenix_channel::SecureUrl; use secrecy::{Secret, SecretString}; use std::convert::Infallible; +use std::pin::pin; use std::sync::Arc; -use std::time::Duration; +use tokio::sync::mpsc; use tracing_subscriber::layer; +use url::Url; +use webrtc::ice_transport::ice_candidate::RTCIceCandidate; mod control; mod eventloop; @@ -32,35 +36,41 @@ async fn main() -> Result<()> { get_device_id(), )?; - tokio::spawn(backoff::future::retry_notify( - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(None) - .build(), - move || { - connect( - private_key.clone(), - Secret::new(SecureUrl::from_url(connect_url.clone())), - ) - .map_err(backoff::Error::transient) - }, - |error, t: Duration| { - tracing::warn!(retry_in = ?t, "Error connecting to portal: {error}"); - }, - )); + // Note: This channel is only needed because [`Tunnel`] does not (yet) have a synchronous, poll-like interface. If it would have, ICE candidates would be emitted as events and we could just hand them to the phoenix channel. + let (control_tx, mut control_rx) = mpsc::channel(1); + let signaler = ControlSignaler::new(control_tx); + let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?); - tokio::signal::ctrl_c().await?; + let mut backoff = ExponentialBackoffBuilder::default() + .with_max_elapsed_time(None) + .build(); + + let eventloop = async { + loop { + let error = match run(tunnel.clone(), &mut control_rx, connect_url.clone()).await { + Err(e) => e, + Ok(never) => match never {}, + }; + + let t = backoff.next_backoff().expect("the exponential backoff reconnect loop should run indefinetly"); + tracing::warn!(retry_in = ?t, "Error connecting to portal: {error:#}"); + + tokio::time::sleep(t).await; + } + }; + + future::select(pin!(eventloop), pin!(tokio::signal::ctrl_c())).await; Ok(()) } -async fn connect(private_key: StaticSecret, connect_url: Secret) -> Result { - // Note: This is only needed because [`Tunnel`] does not (yet) have a synchronous, poll-like interface. If it would have, ICE candidates would be emitted as events and we could just hand them to the phoenix channel. - let (control_tx, control_rx) = tokio::sync::mpsc::channel(1); - let signaler = ControlSignaler::new(control_tx); - let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?); - - let (channel, init) = phoenix_channel::init::( - connect_url, +async fn run( + tunnel: Arc>, + control_rx: &mut mpsc::Receiver<(ClientId, RTCIceCandidate)>, + connect_url: Url, +) -> Result { + let (portal, init) = phoenix_channel::init::( + Secret::new(SecureUrl::from_url(connect_url)), get_user_agent(), PHOENIX_TOPIC, (), @@ -72,7 +82,7 @@ async fn connect(private_key: StaticSecret, connect_url: Secret) -> R .await .context("Failed to set interface")?; - let mut eventloop = Eventloop::new(tunnel, control_rx, channel); + let mut eventloop = Eventloop::new(tunnel, control_rx, portal); future::poll_fn(|cx| eventloop.poll(cx)).await } @@ -81,7 +91,7 @@ async fn connect(private_key: StaticSecret, connect_url: Secret) -> R struct CallbackHandler; impl Callbacks for CallbackHandler { - type Error = std::convert::Infallible; + type Error = Infallible; } #[derive(Parser)]