diff --git a/rust/Cargo.lock b/rust/Cargo.lock index f4dfbe328..15a22bb1c 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -4612,6 +4612,7 @@ dependencies = [ name = "phoenix-channel" version = "1.0.0" dependencies = [ + "anyhow", "backoff", "base64 0.21.7", "futures", diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 394dd94a4..5f97d4cda 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -17,7 +17,7 @@ pub const PHOENIX_TOPIC: &str = "gateway"; pub struct Eventloop { tunnel: Arc>, - portal: PhoenixChannel, + portal: PhoenixChannel<(), IngressMessages, EgressMessages>, // TODO: Strongly type request reference (currently `String`) connection_request_tasks: @@ -31,7 +31,7 @@ pub struct Eventloop { impl Eventloop { pub(crate) fn new( tunnel: Arc>, - portal: PhoenixChannel, + portal: PhoenixChannel<(), IngressMessages, EgressMessages>, ) -> Self { Self { tunnel, @@ -234,6 +234,13 @@ impl Eventloop { } continue; } + Poll::Ready(phoenix_channel::Event::InboundMessage { + msg: IngressMessages::Init(_), + .. + }) => { + // TODO: Handle `init` message during operation. + continue; + } _ => {} } diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 9df064a41..3576deadb 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -82,7 +82,7 @@ async fn run(connect_url: Url, private_key: StaticSecret) -> Result let tunnel: Arc> = Arc::new(Tunnel::new(private_key, CallbackHandler).await?); - let (portal, init) = phoenix_channel::init::( + let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>( Secret::new(SecureUrl::from_url(connect_url.clone())), get_user_agent(None), PHOENIX_TOPIC, diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index fad800148..93fae8fa1 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -97,6 +97,7 @@ pub enum IngressMessages { RequestConnection(RequestConnection), AllowAccess(AllowAccess), IceCandidates(ClientIceCandidates), + Init(InitGateway), } /// A client's ice candidate message. diff --git a/rust/phoenix-channel/Cargo.toml b/rust/phoenix-channel/Cargo.toml index d81ac3be6..a9e006b9e 100644 --- a/rust/phoenix-channel/Cargo.toml +++ b/rust/phoenix-channel/Cargo.toml @@ -19,3 +19,4 @@ serde_json = "1.0.108" thiserror = "1.0.50" tokio = { version = "1.33.0", features = ["net", "time"] } backoff = "0.4.0" +anyhow = "1" diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 37f1c8525..db35899f7 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -24,7 +24,7 @@ const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); // TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway // See https://github.com/firezone/firezone/issues/2158 -pub struct PhoenixChannel { +pub struct PhoenixChannel { state: State, pending_messages: Vec, next_request_id: u64, @@ -39,6 +39,9 @@ pub struct PhoenixChannel { secret_url: Secret, user_agent: String, reconnect_backoff: ExponentialBackoff, + + login: &'static str, + init_req: TInitReq, } enum State { @@ -52,35 +55,40 @@ enum State { /// Additionally, you must already provide any query parameters required for authentication. #[tracing::instrument(level = "debug", skip(payload, secret_url, reconnect_backoff))] #[allow(clippy::type_complexity)] -pub async fn init( +pub async fn init( secret_url: Secret, user_agent: String, login_topic: &'static str, - payload: impl Serialize, + payload: TInitReq, reconnect_backoff: ExponentialBackoff, ) -> Result< - Result<(PhoenixChannel, TInitM), UnexpectedEventDuringInit>, + Result< + ( + PhoenixChannel, + TInitRes, + ), + UnexpectedEventDuringInit, + >, Error, > where - TInitM: DeserializeOwned + fmt::Debug, + TInitReq: Serialize + Clone, + TInitRes: DeserializeOwned + fmt::Debug, TInboundMsg: DeserializeOwned, TOutboundRes: DeserializeOwned, { - let mut channel = PhoenixChannel::, ()>::connect( + let mut channel = PhoenixChannel::<_, InitMessage, ()>::connect( secret_url, user_agent, + login_topic, + payload, reconnect_backoff, - ); // No reconnection on `init`. - channel.join(login_topic, payload); + ); tracing::info!("Connected to portal, waiting for `init` message"); let (channel, init_message) = loop { match future::poll_fn(|cx| channel.poll(cx)).await? { - Event::JoinedRoom { topic } if topic == login_topic => { - tracing::info!("Joined {login_topic} room on portal") - } Event::InboundMessage { topic, msg: InitMessage::Init(msg), @@ -157,8 +165,9 @@ impl secrecy::Zeroize for SecureUrl { } } -impl PhoenixChannel +impl PhoenixChannel where + TInitReq: Serialize + Clone, TInboundMsg: DeserializeOwned, TOutboundRes: DeserializeOwned, { @@ -166,12 +175,16 @@ where /// /// The provided URL must contain a host. /// Additionally, you must already provide any query parameters required for authentication. + /// + /// Once the connection is established, pub fn connect( secret_url: Secret, user_agent: String, + login: &'static str, + init_req: TInitReq, reconnect_backoff: ExponentialBackoff, ) -> Self { - Self { + let mut phoenix_channel = Self { reconnect_backoff, secret_url: secret_url.clone(), user_agent: user_agent.clone(), @@ -185,7 +198,12 @@ where next_request_id: 0, next_heartbeat: Box::pin(tokio::time::sleep(HEARTBEAT_INTERVAL)), pending_join_requests: Default::default(), - } + login, + init_req: init_req.clone(), + }; + phoenix_channel.join(login, init_req); + + phoenix_channel } /// Join the provided room. @@ -216,6 +234,7 @@ where self.state = State::Connected(stream); tracing::info!("Connected to portal"); + self.join(self.login, self.init_req.clone()); continue; } @@ -243,10 +262,11 @@ where tracing::warn!("Reconnect backoff expired"); return Poll::Ready(Err(e)); }; + let secret_url = self.secret_url.clone(); let user_agent = self.user_agent.clone(); - tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal"); + tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e)); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; @@ -267,7 +287,7 @@ where match stream.start_send_unpin(message) { Ok(()) => {} Err(e) => { - self.reconnect_on_transient_error(e); + self.reconnect_on_transient_error(Error::WebSocket(e)); } } continue; @@ -289,6 +309,10 @@ where >(&text) { Ok(m) => m, + Err(e) if e.is_io() || e.is_eof() => { + self.reconnect_on_transient_error(Error::Serde(e)); + continue; + } Err(e) => { tracing::warn!("Failed to deserialize message {text}: {e}"); continue; @@ -301,7 +325,7 @@ where return Poll::Ready(Ok(Event::InboundMessage { topic: message.topic, msg, - })) + })); } Some(reference) => { return Poll::Ready(Ok(Event::InboundReq { @@ -328,6 +352,8 @@ where OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); if self.pending_join_requests.remove(&req_id) { + tracing::info!("Joined {} room on portal", message.topic); + // For `phx_join` requests, `reply` is empty so we can safely ignore it. return Poll::Ready(Ok(Event::JoinedRoom { topic: message.topic, @@ -370,7 +396,7 @@ where } } Poll::Ready(Some(Err(e))) => { - self.reconnect_on_transient_error(e); + self.reconnect_on_transient_error(Error::WebSocket(e)); continue; } _ => (), @@ -392,7 +418,7 @@ where tracing::trace!("Flushed websocket"); } Poll::Ready(Err(e)) => { - self.reconnect_on_transient_error(e); + self.reconnect_on_transient_error(Error::WebSocket(e)); continue; } Poll::Pending => {} @@ -405,9 +431,8 @@ where /// Sets the channels state to [`State::Connecting`] with the given error. /// /// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error. - fn reconnect_on_transient_error(&mut self, e: tokio_tungstenite::tungstenite::Error) { - tracing::info!("Websocket disconnected: {e:#?}"); - self.state = State::Connecting(future::ready(Err(Error::WebSocket(e))).boxed()) + fn reconnect_on_transient_error(&mut self, e: Error) { + self.state = State::Connecting(future::ready(Err(e)).boxed()) } fn send_message( @@ -436,7 +461,7 @@ where /// Cast this instance of [PhoenixChannel] to new message types. fn cast( self, - ) -> PhoenixChannel { + ) -> PhoenixChannel { PhoenixChannel { state: self.state, pending_messages: self.pending_messages, @@ -447,6 +472,8 @@ where secret_url: self.secret_url, user_agent: self.user_agent, reconnect_backoff: self.reconnect_backoff, + login: self.login, + init_req: self.init_req, } } } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 1a06434a9..2304905fe 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -246,7 +246,7 @@ async fn connect_to_portal( token: &SecretString, mut url: Url, stamp_secret: &SecretString, -) -> Result>> { +) -> Result>> { use secrecy::ExposeSecret; if !url.path().is_empty() { @@ -266,7 +266,7 @@ async fn connect_to_portal( .append_pair("ipv6", &public_ip6_addr.to_string()); } - let (channel, Init {}) = phoenix_channel::init::( + let (channel, Init {}) = phoenix_channel::init::<_, Init, _, _>( Secret::from(SecureUrl::from_url(url)), format!("relay/{}", env!("CARGO_PKG_VERSION")), "relay", @@ -285,7 +285,7 @@ async fn connect_to_portal( #[derive(serde::Deserialize, Debug)] struct Init {} -#[derive(serde::Serialize, PartialEq, Debug)] +#[derive(serde::Serialize, PartialEq, Debug, Clone)] struct JoinMessage { stamp_secret: String, } @@ -315,7 +315,7 @@ struct Eventloop { outbound_ip4_data_sender: mpsc::Sender<(Vec, SocketAddr)>, outbound_ip6_data_sender: mpsc::Sender<(Vec, SocketAddr)>, server: Server, - channel: Option>, + channel: Option>, allocations: HashMap<(AllocationId, AddressFamily), Allocation>, relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, @@ -328,7 +328,7 @@ where { fn new( server: Server, - channel: Option>, + channel: Option>, public_address: IpStack, ) -> Result { let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);