From 407d20d81738901239e90b967d545414215cda04 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 12 Mar 2024 19:10:56 +1100 Subject: [PATCH] refactor(connlib): use `phoenix-channel` crate for clients (#3682) Depends-On: #4048. Depends-On: #4015. Resolves: #2158. --------- Co-authored-by: conectado --- rust/Cargo.lock | 1 + rust/connlib/clients/android/src/lib.rs | 6 +- rust/connlib/clients/apple/Cargo.toml | 1 + rust/connlib/clients/apple/src/lib.rs | 2 +- rust/connlib/clients/shared/src/control.rs | 517 ----------------- rust/connlib/clients/shared/src/eventloop.rs | 564 +++++++++++++++++++ rust/connlib/clients/shared/src/lib.rs | 219 ++----- rust/connlib/clients/shared/src/messages.rs | 39 +- rust/connlib/shared/src/error.rs | 3 + rust/connlib/shared/src/lib.rs | 1 - rust/gateway/src/messages.rs | 2 +- rust/gui-client/src-tauri/src/client/gui.rs | 2 +- rust/phoenix-channel/src/heartbeat.rs | 6 +- rust/phoenix-channel/src/lib.rs | 67 ++- rust/relay/src/main.rs | 8 +- 15 files changed, 698 insertions(+), 740 deletions(-) delete mode 100644 rust/connlib/clients/shared/src/control.rs create mode 100644 rust/connlib/clients/shared/src/eventloop.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3747694db..b027e3301 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1117,6 +1117,7 @@ dependencies = [ "tracing-appender", "tracing-oslog", "tracing-subscriber", + "url", "walkdir", ] diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index a42ed5b7e..f1557d1f2 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -398,7 +398,7 @@ fn connect( log_dir: JString, log_filter: JString, callback_handler: GlobalRef, -) -> Result, ConnectError> { +) -> Result { let api_url = string_from_jstring!(env, api_url); let secret = SecretString::from(string_from_jstring!(env, token)); let device_id = string_from_jstring!(env, device_id); @@ -451,7 +451,7 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co log_dir: JString, log_filter: JString, callback_handler: JObject, -) -> *const Session { +) -> *const Session { let Ok(callback_handler) = env.new_global_ref(callback_handler) else { return std::ptr::null(); }; @@ -489,7 +489,7 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_disconnect( mut env: JNIEnv, _: JClass, - session: *mut Session, + session: *mut Session, ) { catch_and_throw(&mut env, |_| { Box::from_raw(session).disconnect(); diff --git a/rust/connlib/clients/apple/Cargo.toml b/rust/connlib/clients/apple/Cargo.toml index 1df5ae77e..26b88fc3c 100644 --- a/rust/connlib/clients/apple/Cargo.toml +++ b/rust/connlib/clients/apple/Cargo.toml @@ -23,6 +23,7 @@ tracing = { workspace = true } tracing-oslog = { git = "https://github.com/Absolucy/tracing-oslog", branch = "main" } # Waiting for a release. tracing-subscriber = "0.3" tracing-appender = "0.2" +url = "2.5.0" [lib] name = "connlib" diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index cd7a90855..dcd8f4337 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -78,7 +78,7 @@ mod ffi { } /// This is used by the apple client to interact with our code. -pub struct WrappedSession(Session); +pub struct WrappedSession(Session); // SAFETY: `CallbackHandler.swift` promises to be thread-safe. // TODO: Uphold that promise! diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs deleted file mode 100644 index 8b511215a..000000000 --- a/rust/connlib/clients/shared/src/control.rs +++ /dev/null @@ -1,517 +0,0 @@ -use bimap::BiMap; -use connlib_shared::control::{ChannelError, ErrorReply}; -use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer}; -use connlib_shared::IpProvider; -use firezone_tunnel::ClientTunnel; -use ip_network::IpNetwork; -use std::io; -use std::net::IpAddr; -use std::path::PathBuf; -use std::str::FromStr; - -// TODO: These are used in the `upload` function, which is currently disabled. -// See the comment there for more information. -// use async_compression::tokio::bufread::GzipEncoder; -// use tokio_util::codec::{BytesCodec, FramedRead}; -// use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE}; -// use tokio::io::BufReader; - -use crate::messages::{ - BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages, - GatewayIceCandidates, InitClient, Messages, -}; -use connlib_shared::{ - control::{PhoenixSenderWithTopic, Reference}, - messages::{GatewayId, ResourceDescription, ResourceId}, - Callbacks, - Error::{self}, - Result, -}; - -use firezone_tunnel::Request; -use std::collections::HashMap; -use url::Url; - -const DNS_PORT: u16 = 53; -const DNS_SENTINELS_V4: &str = "100.100.111.0/24"; -const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120"; - -pub struct ControlPlane { - pub tunnel: ClientTunnel, - pub phoenix_channel: PhoenixSenderWithTopic, - pub tunnel_init: bool, - - pub next_request_id: usize, - pub sent_connection_intents: SentConnectionIntents, -} - -#[derive(Default)] -pub struct SentConnectionIntents { - inner: HashMap, -} - -impl SentConnectionIntents { - fn register_new_intent(&mut self, id: usize, resource: ResourceId) { - self.inner.insert(id, resource); - } - - /// To be called when we receive the connection details for a particular resource. - /// - /// Returns whether we should accept them. - fn handle_connection_details_received(&mut self, reference: usize, r: ResourceId) -> bool { - let has_more_recent_intent = self - .inner - .iter() - .any(|(req, resource)| req > &reference && resource == &r); - - if has_more_recent_intent { - return false; - } - - debug_assert!(self - .inner - .get(&reference) - .is_some_and(|resource| resource == &r)); - self.inner.retain(|_, v| v != &r); - - true - } - - fn handle_error(&mut self, reference: usize) -> Option { - self.inner.remove(&reference) - } -} - -fn effective_dns_servers( - upstream_dns: Vec, - default_resolvers: Vec, -) -> Vec { - if !upstream_dns.is_empty() { - return upstream_dns; - } - - let mut dns_servers = default_resolvers - .into_iter() - .filter(|ip| !IpNetwork::from_str(DNS_SENTINELS_V4).unwrap().contains(*ip)) - .filter(|ip| !IpNetwork::from_str(DNS_SENTINELS_V6).unwrap().contains(*ip)) - .peekable(); - - if dns_servers.peek().is_none() { - tracing::error!("No system default DNS servers available! Can't initialize resolver. DNS will be broken."); - return Vec::new(); - } - - dns_servers - .map(|ip| { - DnsServer::IpPort(IpDnsServer { - address: (ip, DNS_PORT).into(), - }) - }) - .collect() -} - -fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap { - let mut ip_provider = IpProvider::new( - DNS_SENTINELS_V4.parse().unwrap(), - DNS_SENTINELS_V6.parse().unwrap(), - ); - - dns.iter() - .cloned() - .map(|i| { - ( - ip_provider - .get_proxy_ip_for(&i.ip()) - .expect("We only support up to 256 IpV4 DNS servers and 256 IpV6 DNS servers"), - i, - ) - }) - .collect() -} - -impl ControlPlane { - async fn init( - &mut self, - InitClient { - interface, - resources, - }: InitClient, - ) -> Result<()> { - let effective_dns_servers = effective_dns_servers( - interface.upstream_dns.clone(), - self.tunnel - .callbacks() - .get_system_default_resolvers() - .ok() - .flatten() - .unwrap_or_default(), - ); - - let sentinel_mapping = sentinel_dns_mapping(&effective_dns_servers); - - if !self.tunnel_init { - if let Err(e) = self - .tunnel - .set_interface(&interface, sentinel_mapping.clone()) - { - tracing::error!(error = ?e, "Error initializing interface"); - return Err(e); - } else { - self.tunnel_init = true; - tracing::info!("Firezone Started!"); - } - - for resource_description in resources { - self.add_resource(resource_description); - } - } else { - tracing::info!("Firezone reinitializated"); - } - Ok(()) - } - - pub fn connect( - &mut self, - Connect { - gateway_payload, - resource_id, - gateway_public_key, - .. - }: Connect, - ) { - match gateway_payload { - GatewayResponse::ConnectionAccepted(gateway_payload) => { - if let Err(e) = self.tunnel.received_offer_response( - resource_id, - gateway_payload.ice_parameters, - gateway_payload.domain_response, - gateway_public_key.0.into(), - ) { - tracing::debug!(error = ?e, "Error accepting connection: {e:#?}"); - } - } - GatewayResponse::ResourceAccepted(gateway_payload) => { - if let Err(e) = self - .tunnel - .received_domain_parameters(resource_id, gateway_payload.domain_response) - { - tracing::debug!(error = ?e, "Error accepting resource: {e:#?}"); - } - } - } - } - - pub fn add_resource(&mut self, resource_description: ResourceDescription) { - if let Err(e) = self.tunnel.add_resource(resource_description) { - tracing::error!(message = "Can't add resource", error = ?e); - } - } - - #[tracing::instrument(level = "trace", skip(self))] - fn resource_deleted(&mut self, id: ResourceId) { - self.tunnel.remove_resource(id); - } - - fn connection_details( - &mut self, - ConnectionDetails { - gateway_id, - resource_id, - relays, - .. - }: ConnectionDetails, - reference: Option, - ) { - let Some(reference) = reference.as_ref().and_then(|r| r.parse::().ok()) else { - tracing::warn!(?reference, "Failed to parse reference as usize"); - return; - }; - - if !self - .sent_connection_intents - .handle_connection_details_received(reference, resource_id) - { - tracing::debug!("Discarding stale connection details"); - - return; - } - - let mut control_signaler = self.phoenix_channel.clone(); - - let err = match self - .tunnel - .request_connection(resource_id, gateway_id, relays) - { - Ok(Request::NewConnection(connection_request)) => { - tokio::spawn(async move { - // TODO: create a reference number and keep track for the response - // Note: We used to clean up connections here upon failures with the _channel_ to the underlying portal connection. - // This is deemed unnecessary during the migration period to `phoenix-channel` because it means that the receiver is deallocated at which point we are probably shutting down? - let _ = control_signaler - .send_with_ref( - EgressMessages::RequestConnection(connection_request), - resource_id, - ) - .await; - }); - return; - } - Ok(Request::ReuseConnection(connection_request)) => { - tokio::spawn(async move { - // TODO: create a reference number and keep track for the response - // Note: We used to clean up connections here upon failures with the _channel_ to the underlying portal connection. - // This is deemed unnecessary during the migration period to `phoenix-channel` because it means that the receiver is deallocated at which point we are probably shutting down? - let _ = control_signaler - .send_with_ref( - EgressMessages::ReuseConnection(connection_request), - resource_id, - ) - .await; - }); - return; - } - Err(err) => err, - }; - - self.tunnel.cleanup_connection(resource_id); - tracing::error!("Error request connection details: {err}"); - } - - #[tracing::instrument(level = "trace", skip_all, fields(gateway = %gateway_id))] - fn add_ice_candidate( - &mut self, - GatewayIceCandidates { - gateway_id, - candidates, - }: GatewayIceCandidates, - ) { - for candidate in candidates { - self.tunnel.add_ice_candidate(gateway_id, candidate) - } - } - - #[tracing::instrument(level = "trace", skip(self, msg))] - pub async fn handle_message( - &mut self, - msg: Messages, - reference: Option, - ) -> Result<()> { - match msg { - Messages::Init(init) => self.init(init).await?, - Messages::ConfigChanged(_update) => { - tracing::info!("Runtime config updates not yet implemented"); - } - Messages::ConnectionDetails(connection_details) => { - self.connection_details(connection_details, reference) - } - Messages::Connect(connect) => self.connect(connect), - Messages::ResourceCreatedOrUpdated(resource) => self.add_resource(resource), - Messages::ResourceDeleted(resource) => self.resource_deleted(resource.0), - Messages::IceCandidates(ice_candidate) => self.add_ice_candidate(ice_candidate), - Messages::SignedLogUrl(url) => { - let Some(path) = self.tunnel.callbacks().roll_log_file() else { - return Ok(()); - }; - - tokio::spawn(async move { - if let Err(e) = upload(path.clone(), url).await { - tracing::warn!( - "Failed to upload log file at path {path_display}: {e}. Not retrying.", - path_display = path.display() - ); - } - }); - } - } - Ok(()) - } - - // Errors here means we need to disconnect - #[tracing::instrument(level = "trace", skip(self))] - pub async fn handle_error( - &mut self, - reply_error: ChannelError, - reference: Option, - topic: String, - ) -> Result<()> { - match (reply_error, reference) { - (ChannelError::ErrorReply(ErrorReply::Offline), Some(reference)) => { - let Ok(request_id) = reference.parse::() else { - return Ok(()); - }; - - let Some(resource) = self.sent_connection_intents.handle_error(request_id) else { - return Ok(()); - }; - - tracing::debug!(%resource, "Resource is offline"); - - self.tunnel.cleanup_connection(resource); - } - (ChannelError::ErrorReply(ErrorReply::UnmatchedTopic), _) => { - if let Err(e) = self.phoenix_channel.get_sender().join_topic(topic).await { - tracing::debug!(err = ?e, "couldn't join topic: {e:#?}"); - } - } - (ChannelError::ErrorReply(ErrorReply::TokenExpired), _) - | (ChannelError::ErrorMsg(Error::ClosedByPortal), _) => { - return Err(Error::ClosedByPortal); - } - _ => {} - } - Ok(()) - } - - pub async fn request_log_upload_url(&mut self) { - tracing::info!("Requesting log upload URL from portal"); - - let _ = self - .phoenix_channel - .send(EgressMessages::CreateLogSink {}) - .await; - } - - pub async fn handle_tunnel_event(&mut self, event: Result>) { - match event { - Ok(firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate }) => { - if let Err(e) = self - .phoenix_channel - .send(EgressMessages::BroadcastIceCandidates( - BroadcastGatewayIceCandidates { - gateway_ids: vec![conn_id], - candidates: vec![candidate], - }, - )) - .await - { - tracing::error!("Failed to signal ICE candidate: {e}") - } - } - Ok(firezone_tunnel::Event::ConnectionIntent { - resource, - connected_gateway_ids, - }) => { - let id = self.next_request_id; - self.next_request_id += 1; - self.sent_connection_intents - .register_new_intent(id, resource); - - if let Err(e) = self - .phoenix_channel - .clone() - .send_with_ref( - EgressMessages::PrepareConnection { - resource_id: resource, - connected_gateway_ids, - }, - id, - ) - .await - { - tracing::error!("Failed to prepare connection: {e}"); - - // TODO: Clean up connection in `ClientState` here? - } - } - Ok(firezone_tunnel::Event::RefreshResources { connections }) => { - let mut control_signaler = self.phoenix_channel.clone(); - tokio::spawn(async move { - for connection in connections { - let resource_id = connection.resource_id; - if let Err(err) = control_signaler - .send_with_ref(EgressMessages::ReuseConnection(connection), resource_id) - .await - { - tracing::warn!(%resource_id, ?err, "failed to refresh resource dns: {err:#?}"); - } - } - }); - } - Ok(firezone_tunnel::Event::StopPeer(_)) => { - // This should never bubbled up - // TODO: we might want to segregate events further - } - Ok(firezone_tunnel::Event::SendPacket(_)) => { - unimplemented!("Handled internally"); - } - Err(e) => { - tracing::error!("Tunnel failed: {e:#?}"); - } - } - } -} - -async fn upload(_path: PathBuf, _url: Url) -> io::Result<()> { - // TODO: Log uploads are disabled by default for GA until we expose a way to opt in - // to the user. See https://github.com/firezone/firezone/issues/3910 - - // tracing::info!(path = %path.display(), %url, "Uploading log file"); - // - // let file = tokio::fs::File::open(&path).await?; - // - // let response = reqwest::Client::new() - // .put(url) - // .header(CONTENT_TYPE, "text/plain") - // .header(CONTENT_ENCODING, "gzip") - // .body(reqwest::Body::wrap_stream(FramedRead::new( - // GzipEncoder::new(BufReader::new(file)), - // BytesCodec::default(), - // ))) - // .send() - // .await - // .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - // - // let status_code = response.status(); - // - // if !status_code.is_success() { - // let body = response - // .text() - // .await - // .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - // - // tracing::warn!(%body, %status_code, "Failed to upload logs"); - // - // return Err(io::Error::new( - // io::ErrorKind::Other, - // "portal returned non-successful exit code", - // )); - // } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn discards_old_connection_intent() { - let mut intents = SentConnectionIntents::default(); - - let resource = ResourceId::random(); - - intents.register_new_intent(1, resource); - intents.register_new_intent(2, resource); - - let should_accept = intents.handle_connection_details_received(1, resource); - - assert!(!should_accept); - } - - #[test] - fn allows_unrelated_intents() { - let mut intents = SentConnectionIntents::default(); - - let resource1 = ResourceId::random(); - let resource2 = ResourceId::random(); - - intents.register_new_intent(1, resource1); - intents.register_new_intent(2, resource2); - - let should_accept_1 = intents.handle_connection_details_received(1, resource1); - let should_accept_2 = intents.handle_connection_details_received(2, resource2); - - assert!(should_accept_1); - assert!(should_accept_2); - } -} diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs new file mode 100644 index 000000000..290e04b21 --- /dev/null +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -0,0 +1,564 @@ +use crate::{ + messages::{ + BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages, + GatewayIceCandidates, IngressMessages, InitClient, RemoveResource, ReplyMessages, + }, + PHOENIX_TOPIC, +}; +use anyhow::Result; +use bimap::BiMap; +use connlib_shared::{ + messages::{ + ConnectionAccepted, DnsServer, GatewayId, GatewayResponse, IpDnsServer, ResourceAccepted, + ResourceId, + }, + Callbacks, IpProvider, +}; +use firezone_tunnel::ClientTunnel; +use ip_network::IpNetwork; +use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; +use std::{ + collections::HashMap, + convert::Infallible, + io, + net::IpAddr, + path::PathBuf, + str::FromStr, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::{Instant, Interval, MissedTickBehavior}; +use url::Url; + +const DNS_PORT: u16 = 53; +const DNS_SENTINELS_V4: &str = "100.100.111.0/24"; +const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120"; + +pub struct Eventloop { + tunnel: ClientTunnel, + tunnel_init: bool, + + portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + connection_intents: SentConnectionIntents, + log_upload_interval: tokio::time::Interval, +} + +impl Eventloop { + pub(crate) fn new( + tunnel: ClientTunnel, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + ) -> Self { + Self { + tunnel, + portal, + tunnel_init: false, + connection_intents: SentConnectionIntents::default(), + log_upload_interval: upload_interval(), + } + } +} + +impl Eventloop +where + C: Callbacks + 'static, +{ + #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")] + pub fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.tunnel.poll_next_event(cx) { + Poll::Ready(Ok(event)) => { + self.handle_tunnel_event(event); + continue; + } + Poll::Ready(Err(e)) => { + tracing::error!("Tunnel failed: {e}"); + continue; + } + Poll::Pending => {} + } + + match self.portal.poll(cx)? { + Poll::Ready(event) => { + self.handle_portal_event(event); + continue; + } + Poll::Pending => {} + } + + if self.log_upload_interval.poll_tick(cx).is_ready() { + self.portal + .send(PHOENIX_TOPIC, EgressMessages::CreateLogSink {}); + continue; + } + + return Poll::Pending; + } + } + + fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event) { + match event { + firezone_tunnel::Event::SignalIceCandidate { + conn_id: gateway, + candidate, + } => { + tracing::debug!(%gateway, %candidate, "Sending ICE candidate to gateway"); + + self.portal.send( + PHOENIX_TOPIC, + EgressMessages::BroadcastIceCandidates(BroadcastGatewayIceCandidates { + gateway_ids: vec![gateway], + candidates: vec![candidate], + }), + ); + } + firezone_tunnel::Event::ConnectionIntent { + connected_gateway_ids, + resource, + .. + } => { + let id = self.portal.send( + PHOENIX_TOPIC, + EgressMessages::PrepareConnection { + resource_id: resource, + connected_gateway_ids, + }, + ); + self.connection_intents.register_new_intent(id, resource); + } + firezone_tunnel::Event::RefreshResources { connections } => { + for connection in connections { + self.portal + .send(PHOENIX_TOPIC, EgressMessages::ReuseConnection(connection)); + } + } + firezone_tunnel::Event::SendPacket { .. } | firezone_tunnel::Event::StopPeer { .. } => { + unreachable!("Handled internally") + } + } + } + + fn handle_portal_event( + &mut self, + event: phoenix_channel::Event, + ) { + match event { + phoenix_channel::Event::InboundMessage { msg, .. } => { + self.handle_portal_inbound_message(msg); + } + phoenix_channel::Event::SuccessResponse { res, req_id, .. } => { + self.handle_portal_success_reply(res, req_id); + } + phoenix_channel::Event::ErrorResponse { res, req_id, topic } => { + self.handle_portal_error_reply(res, topic, req_id); + } + phoenix_channel::Event::HeartbeatSent => {} + phoenix_channel::Event::JoinedRoom { .. } => {} + } + } + + fn handle_portal_inbound_message(&mut self, msg: IngressMessages) { + match msg { + IngressMessages::ConfigChanged(_) => { + tracing::warn!("Config changes are not yet implemented"); + } + IngressMessages::IceCandidates(GatewayIceCandidates { + gateway_id, + candidates, + }) => { + for candidate in candidates { + self.tunnel.add_ice_candidate(gateway_id, candidate) + } + } + IngressMessages::Init(InitClient { + interface, + resources, + }) => { + let effective_dns_servers = effective_dns_servers( + interface.upstream_dns.clone(), + self.tunnel + .callbacks() + .get_system_default_resolvers() + .ok() + .flatten() + .unwrap_or_default(), + ); + + let sentinel_mapping = sentinel_dns_mapping(&effective_dns_servers); + + if !self.tunnel_init { + if let Err(e) = self + .tunnel + .set_interface(&interface, sentinel_mapping.clone()) + { + tracing::warn!("Failed to set interface on tunnel: {e}"); + return; + } + + self.tunnel_init = true; + tracing::info!("Firezone Started!"); + + for resource_description in resources { + let _ = self.tunnel.add_resource(resource_description); + } + } else { + tracing::info!("Firezone reinitializated"); + } + } + IngressMessages::ResourceCreatedOrUpdated(resource) => { + let resource_id = resource.id(); + + if let Err(e) = self.tunnel.add_resource(resource) { + tracing::warn!(%resource_id, "Failed to add resource: {e}"); + } + } + IngressMessages::ResourceDeleted(RemoveResource(resource)) => { + self.tunnel.remove_resource(resource); + } + } + } + + fn handle_portal_success_reply(&mut self, res: ReplyMessages, req_id: OutboundRequestId) { + match res { + ReplyMessages::Connect(Connect { + gateway_payload: + GatewayResponse::ConnectionAccepted(ConnectionAccepted { + ice_parameters, + domain_response, + }), + gateway_public_key, + resource_id, + .. + }) => { + if let Err(e) = self.tunnel.received_offer_response( + resource_id, + ice_parameters, + domain_response, + gateway_public_key.0.into(), + ) { + tracing::warn!("Failed to accept connection: {e}"); + } + } + ReplyMessages::Connect(Connect { + gateway_payload: + GatewayResponse::ResourceAccepted(ResourceAccepted { domain_response }), + resource_id, + .. + }) => { + if let Err(e) = self + .tunnel + .received_domain_parameters(resource_id, domain_response) + { + tracing::warn!("Failed to accept resource: {e}"); + } + } + ReplyMessages::ConnectionDetails(ConnectionDetails { + gateway_id, + resource_id, + relays, + .. + }) => { + let should_accept = self + .connection_intents + .handle_connection_details_received(req_id, resource_id); + + if !should_accept { + tracing::debug!(%resource_id, "Ignoring stale connection details"); + return; + } + + match self + .tunnel + .request_connection(resource_id, gateway_id, relays) + { + Ok(firezone_tunnel::Request::NewConnection(connection_request)) => { + // TODO: keep track for the response + let _id = self.portal.send( + PHOENIX_TOPIC, + EgressMessages::RequestConnection(connection_request), + ); + } + Ok(firezone_tunnel::Request::ReuseConnection(connection_request)) => { + // TODO: keep track for the response + let _id = self.portal.send( + PHOENIX_TOPIC, + EgressMessages::ReuseConnection(connection_request), + ); + } + Err(e) => { + self.tunnel.cleanup_connection(resource_id); + tracing::warn!("Failed to request new connection: {e}"); + } + }; + } + ReplyMessages::SignedLogUrl(url) => { + let Some(path) = self.tunnel.callbacks().roll_log_file() else { + return; + }; + + tokio::spawn(async move { + if let Err(e) = upload(path.clone(), url).await { + tracing::warn!( + "Failed to upload log file at path {path_display}: {e}. Not retrying.", + path_display = path.display() + ); + } + }); + } + } + } + + fn handle_portal_error_reply( + &mut self, + res: ErrorReply, + topic: String, + req_id: OutboundRequestId, + ) { + match res { + ErrorReply::Offline => { + let Some(offline_resource) = self.connection_intents.handle_error(req_id) else { + return; + }; + + tracing::debug!(resource_id = %offline_resource, "Resource is offline"); + + self.tunnel.cleanup_connection(offline_resource); + } + + ErrorReply::Disabled => { + tracing::debug!(%req_id, "Functionality is disabled"); + } + ErrorReply::UnmatchedTopic => { + self.portal.join(topic, ()); + } + ErrorReply::NotFound | ErrorReply::Other => {} + } + } +} + +async fn upload(_path: PathBuf, _url: Url) -> io::Result<()> { + // TODO: Log uploads are disabled by default for GA until we expose a way to opt in + // to the user. See https://github.com/firezone/firezone/issues/3910 + + // tracing::info!(path = %path.display(), %url, "Uploading log file"); + + // let file = tokio::fs::File::open(&path).await?; + + // let response = reqwest::Client::new() + // .put(url) + // .header(CONTENT_TYPE, "text/plain") + // .header(CONTENT_ENCODING, "gzip") + // .body(reqwest::Body::wrap_stream(FramedRead::new( + // GzipEncoder::new(BufReader::new(file)), + // BytesCodec::default(), + // ))) + // .send() + // .await + // .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + // let status_code = response.status(); + + // if !status_code.is_success() { + // let body = response + // .text() + // .await + // .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + // tracing::warn!(%body, %status_code, "Failed to upload logs"); + + // return Err(io::Error::new( + // io::ErrorKind::Other, + // "portal returned non-successful exit code", + // )); + // } + + Ok(()) +} + +fn effective_dns_servers( + upstream_dns: Vec, + default_resolvers: Vec, +) -> Vec { + if !upstream_dns.is_empty() { + return upstream_dns; + } + + let mut dns_servers = default_resolvers + .into_iter() + .filter(|ip| !IpNetwork::from_str(DNS_SENTINELS_V4).unwrap().contains(*ip)) + .filter(|ip| !IpNetwork::from_str(DNS_SENTINELS_V6).unwrap().contains(*ip)) + .peekable(); + + if dns_servers.peek().is_none() { + tracing::error!("No system default DNS servers available! Can't initialize resolver. DNS will be broken."); + return Vec::new(); + } + + dns_servers + .map(|ip| { + DnsServer::IpPort(IpDnsServer { + address: (ip, DNS_PORT).into(), + }) + }) + .collect() +} + +fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap { + let mut ip_provider = IpProvider::new( + DNS_SENTINELS_V4.parse().unwrap(), + DNS_SENTINELS_V6.parse().unwrap(), + ); + + dns.iter() + .cloned() + .map(|i| { + ( + ip_provider + .get_proxy_ip_for(&i.ip()) + .expect("We only support up to 256 IpV4 DNS servers and 256 IpV6 DNS servers"), + i, + ) + }) + .collect() +} + +fn upload_interval() -> Interval { + let duration = upload_interval_duration_from_env_or_default(); + let mut interval = tokio::time::interval_at(Instant::now() + duration, duration); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + + interval +} + +/// Parses an interval from the _compile-time_ env variable `CONNLIB_LOG_UPLOAD_INTERVAL_SECS`. +/// +/// If not present or parsing as u64 fails, we fall back to a default interval of 5 minutes. +fn upload_interval_duration_from_env_or_default() -> Duration { + const DEFAULT: Duration = Duration::from_secs(60 * 5); + + let Some(interval) = option_env!("CONNLIB_LOG_UPLOAD_INTERVAL_SECS") else { + tracing::warn!(interval = ?DEFAULT, "Env variable `CONNLIB_LOG_UPLOAD_INTERVAL_SECS` was not set during compile-time, falling back to default"); + + return DEFAULT; + }; + + let interval = match interval.parse() { + Ok(i) => i, + Err(e) => { + tracing::warn!(interval = ?DEFAULT, "Failed to parse `CONNLIB_LOG_UPLOAD_INTERVAL_SECS` as u64: {e}"); + return DEFAULT; + } + }; + + tracing::info!( + ?interval, + "Using upload interval specified at compile-time via `CONNLIB_LOG_UPLOAD_INTERVAL_SECS`" + ); + + Duration::from_secs(interval) +} + +#[derive(Default)] +struct SentConnectionIntents { + inner: HashMap, +} + +impl SentConnectionIntents { + fn register_new_intent(&mut self, id: OutboundRequestId, resource: ResourceId) { + self.inner.insert(id, resource); + } + + /// To be called when we receive the connection details for a particular resource. + /// + /// Returns whether we should accept them. + fn handle_connection_details_received( + &mut self, + reference: OutboundRequestId, + r: ResourceId, + ) -> bool { + let has_more_recent_intent = self + .inner + .iter() + .any(|(req, resource)| req > &reference && resource == &r); + + if has_more_recent_intent { + return false; + } + + let has_intent = self + .inner + .get(&reference) + .is_some_and(|resource| resource == &r); + + if !has_intent { + return false; + } + + self.inner.retain(|_, v| v != &r); + + true + } + + fn handle_error(&mut self, req: OutboundRequestId) -> Option { + self.inner.remove(&req) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn discards_old_connection_intent() { + let mut intents = SentConnectionIntents::default(); + + let resource = ResourceId::random(); + + intents.register_new_intent(OutboundRequestId::for_test(1), resource); + intents.register_new_intent(OutboundRequestId::for_test(2), resource); + + let should_accept = + intents.handle_connection_details_received(OutboundRequestId::for_test(1), resource); + + assert!(!should_accept); + } + + #[test] + fn allows_unrelated_intents() { + let mut intents = SentConnectionIntents::default(); + + let resource1 = ResourceId::random(); + let resource2 = ResourceId::random(); + + intents.register_new_intent(OutboundRequestId::for_test(1), resource1); + intents.register_new_intent(OutboundRequestId::for_test(2), resource2); + + let should_accept_1 = + intents.handle_connection_details_received(OutboundRequestId::for_test(1), resource1); + let should_accept_2 = + intents.handle_connection_details_received(OutboundRequestId::for_test(2), resource2); + + assert!(should_accept_1); + assert!(should_accept_2); + } + + #[test] + fn handles_out_of_order_responses() { + let mut intents = SentConnectionIntents::default(); + + let resource = ResourceId::random(); + + intents.register_new_intent(OutboundRequestId::for_test(1), resource); + intents.register_new_intent(OutboundRequestId::for_test(2), resource); + + let should_accept_2 = + intents.handle_connection_details_received(OutboundRequestId::for_test(2), resource); + let should_accept_1 = + intents.handle_connection_details_received(OutboundRequestId::for_test(1), resource); + + assert!(should_accept_2); + assert!(!should_accept_1); + } +} diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index d00f42bc6..b01f07546 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -1,28 +1,25 @@ //! Main connlib library for clients. pub use connlib_shared::messages::ResourceDescription; -pub use connlib_shared::{keypair, Callbacks, Error, LoginUrl, LoginUrlError}; +pub use connlib_shared::{keypair, Callbacks, Error, LoginUrl, LoginUrlError, StaticSecret}; pub use tracing_appender::non_blocking::WorkerGuard; -use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; -use connlib_shared::StaticSecret; -use connlib_shared::{control::PhoenixChannel, CallbackErrorFacade, Result}; -use control::ControlPlane; +use backoff::ExponentialBackoffBuilder; +use connlib_shared::{get_user_agent, CallbackErrorFacade}; use firezone_tunnel::Tunnel; -use messages::IngressMessages; -use messages::Messages; -use messages::ReplyMessages; -use secrecy::Secret; -use std::future::poll_fn; +use phoenix_channel::PhoenixChannel; use std::time::Duration; -use tokio::time::{Interval, MissedTickBehavior}; -use tokio::{runtime::Runtime, time::Instant}; -mod control; +mod eventloop; pub mod file_logger; mod messages; +const PHOENIX_TOPIC: &str = "client"; + struct StopRuntime; +pub use eventloop::Eventloop; +use secrecy::Secret; + /// Max interval to retry connections to the portal if it's down or the client has network /// connectivity changes. Set this to something short so that the end-user experiences /// minimal disruption to their Firezone resources when switching networks. @@ -31,27 +28,11 @@ const MAX_RECONNECT_INTERVAL: Duration = Duration::from_secs(5); /// A session is the entry-point for connlib, maintains the runtime and the tunnel. /// /// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. -pub struct Session { +pub struct Session { runtime_stopper: tokio::sync::mpsc::Sender, - pub callbacks: CallbackErrorFacade, } -macro_rules! fatal_error { - ($result:expr, $rt:expr, $cb:expr) => { - match $result { - Ok(res) => res, - Err(err) => { - Self::disconnect_inner($rt, $cb, Some(err)); - return; - } - } - }; -} - -impl Session -where - CB: Callbacks + 'static, -{ +impl Session { /// Starts a session in the background. /// /// This will: @@ -65,13 +46,13 @@ where /// /// * `device_id` - The cleartext device ID. connlib will obscure this with a hash internally. // TODO: token should be something like SecretString but we need to think about FFI compatibility - pub fn connect( + pub fn connect( url: LoginUrl, private_key: StaticSecret, os_version_override: Option, callbacks: CB, max_partition_time: Option, - ) -> Result { + ) -> connlib_shared::Result { // TODO: We could use tokio::runtime::current() to get the current runtime // which could work with swift-rust that already runs a runtime. But IDK if that will work // in all platforms, a couple of new threads shouldn't bother none. @@ -108,15 +89,14 @@ where })); } - Self::connect_inner( - &runtime, - tx.clone(), + runtime.spawn(connect( url, private_key, os_version_override, - callbacks.clone(), + callbacks, max_partition_time, - ); + )); + std::thread::spawn(move || { rx.blocking_recv(); runtime.shutdown_background(); @@ -124,110 +104,10 @@ where Ok(Self { runtime_stopper: tx, - callbacks, }) } - // TODO: Refactor this when we refactor PhoenixChannel. - // See https://github.com/firezone/firezone/issues/2158 - #[allow(clippy::too_many_arguments)] - fn connect_inner( - runtime: &Runtime, - runtime_stopper: tokio::sync::mpsc::Sender, - url: LoginUrl, - private_key: StaticSecret, - os_version_override: Option, - callbacks: CallbackErrorFacade, - max_partition_time: Option, - ) { - runtime.spawn(async move { - // This is kinda hacky, the buffer size is 1 so that we make sure that we - // process one message at a time, blocking if a previous message haven't been processed - // to force queue ordering. - let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1); - - let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(Secret::new(url), os_version_override, move |msg, reference, topic| { - let control_plane_sender = control_plane_sender.clone(); - async move { - tracing::trace!(?msg); - if let Err(e) = control_plane_sender.send((msg, reference, topic)).await { - tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up."); - } - } - }); - - let tunnel = fatal_error!( - Tunnel::new(private_key, callbacks.clone()), - runtime_stopper, - &callbacks - ); - - let mut control_plane = ControlPlane { - tunnel, - phoenix_channel: connection.sender_with_topic("client".to_owned()), - tunnel_init: false, - next_request_id: 0, - sent_connection_intents: Default::default(), - }; - - tokio::spawn({ - let runtime_stopper = runtime_stopper.clone(); - let callbacks = callbacks.clone(); - async move { - let mut upload_logs_interval = upload_interval(); - loop { - tokio::select! { - Some((msg, reference, topic)) = control_plane_receiver.recv() => { - match msg { - Ok(msg) => control_plane.handle_message(msg, reference).await?, - Err(err) => { - if let Err(e) = control_plane.handle_error(err, reference, topic).await { - Self::disconnect_inner(runtime_stopper, &callbacks, Some(e)); - break; - } - }, - } - }, - event = poll_fn(|cx| control_plane.tunnel.poll_next_event(cx)) => control_plane.handle_tunnel_event(event).await, - _ = upload_logs_interval.tick() => control_plane.request_log_upload_url().await, - else => break - } - } - - Result::Ok(()) - }}); - - tokio::spawn(async move { - let mut exponential_backoff = ExponentialBackoffBuilder::default().with_max_elapsed_time(max_partition_time).with_max_interval(MAX_RECONNECT_INTERVAL).build(); - loop { - // `connection.start` calls the callback only after connecting - tracing::debug!("Attempting connection to portal..."); - let result = connection.start(vec!["client".to_owned()], || exponential_backoff.reset()).await; - tracing::warn!("Disconnected from the portal"); - if let Err(e) = &result { - if e.is_http_client_error() { - tracing::error!(error = ?e, "Connection to portal failed. Is your token valid?"); - fatal_error!(result, runtime_stopper, &callbacks); - } else { - tracing::error!(error = ?e, "Connection to portal failed. Starting retries with backoff timer."); - } - } - if let Some(t) = exponential_backoff.next_backoff() { - tracing::debug!("Connection to portal failed. Retrying connection to portal in {:?}", t); - tokio::time::sleep(t).await; - } else { - tracing::error!("Connection to portal failed, giving up!"); - Self::disconnect_inner(runtime_stopper, &callbacks, None); - break; - } - } - - }); - - }); - } - - fn disconnect_inner( + fn disconnect_inner( runtime_stopper: tokio::sync::mpsc::Sender, callbacks: &CallbackErrorFacade, error: Option, @@ -267,38 +147,45 @@ where } } -fn upload_interval() -> Interval { - let duration = upload_interval_duration_from_env_or_default(); - let mut interval = tokio::time::interval_at(Instant::now() + duration, duration); - interval.set_missed_tick_behavior(MissedTickBehavior::Skip); - - interval -} - -/// Parses an interval from the _compile-time_ env variable `CONNLIB_LOG_UPLOAD_INTERVAL_SECS`. +/// Connects to the portal and starts a tunnel. /// -/// If not present or parsing as u64 fails, we fall back to a default interval of 5 minutes. -fn upload_interval_duration_from_env_or_default() -> Duration { - const DEFAULT: Duration = Duration::from_secs(60 * 5); - - let Some(interval) = option_env!("CONNLIB_LOG_UPLOAD_INTERVAL_SECS") else { - tracing::warn!(interval = ?DEFAULT, "Env variable `CONNLIB_LOG_UPLOAD_INTERVAL_SECS` was not set during compile-time, falling back to default"); - - return DEFAULT; - }; - - let interval = match interval.parse() { - Ok(i) => i, +/// When this function exits, the tunnel failed unrecoverably and you need to call it again. +async fn connect( + url: LoginUrl, + private_key: StaticSecret, + os_version_override: Option, + callbacks: CB, + max_partition_time: Option, +) where + CB: Callbacks + 'static, +{ + let tunnel = match Tunnel::new(private_key, callbacks.clone()) { + Ok(tunnel) => tunnel, Err(e) => { - tracing::warn!(interval = ?DEFAULT, "Failed to parse `CONNLIB_LOG_UPLOAD_INTERVAL_SECS` as u64: {e}"); - return DEFAULT; + tracing::error!("Failed to make tunnel: {e}"); + let _ = callbacks.on_disconnect(&e); + return; } }; - tracing::info!( - ?interval, - "Using upload interval specified at compile-time via `CONNLIB_LOG_UPLOAD_INTERVAL_SECS`" + let portal = PhoenixChannel::connect( + Secret::new(url), + get_user_agent(os_version_override), + PHOENIX_TOPIC, + (), + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(max_partition_time) + .with_max_interval(MAX_RECONNECT_INTERVAL) + .build(), ); - Duration::from_secs(interval) + let mut eventloop = Eventloop::new(tunnel, portal); + + match std::future::poll_fn(|cx| eventloop.poll(cx)).await { + Ok(never) => match never {}, + Err(e) => { + tracing::error!("Eventloop failed: {e}"); + let _ = callbacks.on_disconnect(&Error::PortalConnectionFailed); // TMP Error until we have a narrower API for `onDisconnect` + } + } } diff --git a/rust/connlib/clients/shared/src/messages.rs b/rust/connlib/clients/shared/src/messages.rs index dcebbb9fa..a317808e4 100644 --- a/rust/connlib/clients/shared/src/messages.rs +++ b/rust/connlib/clients/shared/src/messages.rs @@ -152,13 +152,11 @@ pub enum EgressMessages { mod test { use std::collections::HashSet; - use connlib_shared::{ - control::PhoenixMessage, - messages::{ - DnsServer, Interface, IpDnsServer, Relay, ResourceDescription, ResourceDescriptionCidr, - ResourceDescriptionDns, Stun, Turn, - }, + use connlib_shared::messages::{ + DnsServer, Interface, IpDnsServer, Relay, ResourceDescription, ResourceDescriptionCidr, + ResourceDescriptionDns, Stun, Turn, }; + use phoenix_channel::{OutboundRequestId, PhoenixMessage}; use chrono::DateTime; @@ -171,7 +169,7 @@ mod test { #[test] fn connection_ready_deserialization() { let message = r#"{ - "ref": "0", + "ref": 0, "topic": "client", "event": "phx_reply", "payload": { @@ -204,7 +202,7 @@ mod test { #[test] fn config_updated() { - let m = PhoenixMessage::new( + let m = PhoenixMessage::new_message( "client", IngressMessages::ConfigChanged(ConfigUpdate { interface: Interface { @@ -243,7 +241,7 @@ mod test { #[test] fn init_phoenix_message() { - let m = PhoenixMessage::new( + let m = PhoenixMessage::new_message( "client", IngressMessages::Init(InitClient { interface: Interface { @@ -301,7 +299,7 @@ mod test { #[test] fn list_relays_message() { - let m = PhoenixMessage::::new( + let m = PhoenixMessage::::new_message( "client", EgressMessages::PrepareConnection { resource_id: "f16ecfa0-a94f-4bfd-a2ef-1cc1f2ef3da3".parse().unwrap(), @@ -326,7 +324,7 @@ mod test { #[test] fn connection_details_reply() { - let m = PhoenixMessage::::new_ok_reply( + let m = PhoenixMessage::::new_ok_reply( "client", ReplyMessages::ConnectionDetails(ConnectionDetails { gateway_id: "73037362-715d-4a83-a749-f18eadd970e6".parse().unwrap(), @@ -396,31 +394,16 @@ mod test { assert_eq!(m, reply_message); } - #[test] - fn create_log_sink_error_response() { - let json = r#"{"event":"phx_reply","ref":"unique_log_sink_ref","topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#; - - let actual = - serde_json::from_str::>(json).unwrap(); - let expected = PhoenixMessage::new_err_reply( - "client", - connlib_shared::control::ErrorReply::Disabled, - "unique_log_sink_ref".to_owned(), - ); - - assert_eq!(actual, expected) - } - #[test] fn create_log_sink_ok_response() { - let json = r#"{"event":"phx_reply","ref":"unique_log_sink_ref","topic":"client","payload":{"status":"ok","response":"https://storage.googleapis.com/foo/bar"}}"#; + let json = r#"{"event":"phx_reply","ref":2,"topic":"client","payload":{"status":"ok","response":"https://storage.googleapis.com/foo/bar"}}"#; let actual = serde_json::from_str::>(json).unwrap(); let expected = PhoenixMessage::new_ok_reply( "client", ReplyMessages::SignedLogUrl("https://storage.googleapis.com/foo/bar".parse().unwrap()), - "unique_log_sink_ref".to_owned(), + Some(OutboundRequestId::for_test(2)), ); assert_eq!(actual, expected) diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index 9faae5136..d6075ee94 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -177,6 +177,9 @@ pub enum ConnlibError { // Error variants for `systemd-resolved` DNS control #[error("Failed to control system DNS with `resolvectl`")] ResolvectlFailed, + + #[error("connection to the portal failed")] + PortalConnectionFailed, } impl ConnlibError { diff --git a/rust/connlib/shared/src/lib.rs b/rust/connlib/shared/src/lib.rs index ba3ada1b1..41118cc6f 100644 --- a/rust/connlib/shared/src/lib.rs +++ b/rust/connlib/shared/src/lib.rs @@ -5,7 +5,6 @@ mod callbacks; mod callbacks_error_facade; -pub mod control; pub mod error; pub mod messages; diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index 6981694dd..48d8c8ca9 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -147,8 +147,8 @@ pub struct ConnectionReady { #[cfg(test)] mod test { use super::*; - use connlib_shared::control::PhoenixMessage; use phoenix_channel::InitMessage; + use phoenix_channel::PhoenixMessage; #[test] fn request_connection_message() { diff --git a/rust/gui-client/src-tauri/src/client/gui.rs b/rust/gui-client/src-tauri/src/client/gui.rs index 2c46d3ecd..6ca14a284 100644 --- a/rust/gui-client/src-tauri/src/client/gui.rs +++ b/rust/gui-client/src-tauri/src/client/gui.rs @@ -504,7 +504,7 @@ struct Controller { /// Everything related to a signed-in user session struct Session { callback_handler: CallbackHandler, - connlib: connlib_client_shared::Session, + connlib: connlib_client_shared::Session, } impl Controller { diff --git a/rust/phoenix-channel/src/heartbeat.rs b/rust/phoenix-channel/src/heartbeat.rs index 5d36c27e8..6f61bba7f 100644 --- a/rust/phoenix-channel/src/heartbeat.rs +++ b/rust/phoenix-channel/src/heartbeat.rs @@ -92,7 +92,7 @@ mod tests { let mut heartbeat = Heartbeat::new(Duration::from_millis(10)); let _ = poll_fn(|cx| heartbeat.poll(cx)).await; - heartbeat.set_id(OutboundRequestId::new(1)); + heartbeat.set_id(OutboundRequestId::for_test(1)); let result = poll_fn(|cx| heartbeat.poll(cx)).await; assert!(result.is_err()); @@ -103,8 +103,8 @@ mod tests { let mut heartbeat = Heartbeat::new(Duration::from_millis(10)); let _ = poll_fn(|cx| heartbeat.poll(cx)).await; - heartbeat.set_id(OutboundRequestId::new(1)); - heartbeat.maybe_handle_reply(OutboundRequestId::new(1)); + heartbeat.set_id(OutboundRequestId::for_test(1)); + heartbeat.maybe_handle_reply(OutboundRequestId::for_test(1)); let result = poll_fn(|cx| heartbeat.poll(cx)).await; assert!(result.is_ok()); diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 95779b3cf..1f3eb3316 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -164,12 +164,13 @@ impl fmt::Display for InternalError { } } -#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize)] +/// A strict-monotonically increasing ID for outbound requests. +#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize, PartialOrd, Ord)] pub struct OutboundRequestId(u64); impl OutboundRequestId { - #[cfg(test)] - pub(crate) fn new(id: u64) -> Self { + // Should only be used for unit-testing. + pub fn for_test(id: u64) -> Self { Self(id) } @@ -355,7 +356,7 @@ where return Poll::Ready(Ok(Event::ErrorResponse { topic: message.topic, req_id, - reason, + res: reason, })); } (Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => { @@ -458,10 +459,10 @@ where let request_id = self.fetch_add_request_id(); // We don't care about the reply type when serializing - let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new( + let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( topic, payload, - request_id.copy(), + Some(request_id.copy()), )) .expect("we should always be able to serialize a join topic message"); @@ -505,15 +506,15 @@ pub enum Event { /// The response received for an outbound request. res: TOutboundRes, }, + ErrorResponse { + topic: String, + req_id: OutboundRequestId, + res: ErrorReply, + }, JoinedRoom { topic: String, }, HeartbeatSent, - ErrorResponse { - topic: String, - req_id: OutboundRequestId, - reason: ErrorReply, - }, /// The server sent us a message, most likely this is a broadcast to all connected clients. InboundMessage { topic: String, @@ -565,6 +566,7 @@ enum OkReply { NoMessage(Empty), } +// TODO: I think this should also be a type-parameter. /// This represents the info we have about the error #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[serde(rename_all = "snake_case")] @@ -585,11 +587,40 @@ pub enum DisconnectReason { } impl PhoenixMessage { - pub fn new(topic: impl Into, payload: T, reference: OutboundRequestId) -> Self { + pub fn new_message( + topic: impl Into, + payload: T, + reference: Option, + ) -> Self { Self { topic: topic.into(), payload: Payload::Message(payload), - reference: Some(reference), + reference, + } + } + + pub fn new_ok_reply( + topic: impl Into, + payload: R, + reference: Option, + ) -> Self { + Self { + topic: topic.into(), + payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))), + reference, + } + } + + #[cfg(test)] + fn new_err_reply( + topic: impl Into, + reason: ErrorReply, + reference: Option, + ) -> Self { + Self { + topic: topic.into(), + payload: Payload::Reply(Reply::Error { reason }), + reference, } } } @@ -769,4 +800,14 @@ mod tests { }); assert_eq!(actual_reply, expected_reply); } + + #[test] + fn disabled_err_reply() { + let json = r#"{"event":"phx_reply","ref":null,"topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#; + + let actual = serde_json::from_str::>(json).unwrap(); + let expected = PhoenixMessage::new_err_reply("client", ErrorReply::Disabled, None); + + assert_eq!(actual, expected) + } } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 0966447eb..849777ec1 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -499,12 +499,8 @@ where tracing::info!(target: "relay", "Successfully joined room '{topic}'"); continue; } - Some(Poll::Ready(Ok(Event::ErrorResponse { - topic, - req_id, - reason, - }))) => { - tracing::warn!(target: "relay", "Request with ID {req_id} on topic {topic} failed: {reason:?}"); + Some(Poll::Ready(Ok(Event::ErrorResponse { topic, req_id, res }))) => { + tracing::warn!(target: "relay", "Request with ID {req_id} on topic {topic} failed: {res:?}"); continue; } Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {