diff --git a/rust/Cargo.lock b/rust/Cargo.lock index bab9efc75..42d767379 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1135,7 +1135,6 @@ dependencies = [ "firezone-tunnel", "ip_network", "phoenix-channel", - "rayon", "secrecy", "serde", "serde_json", @@ -1377,16 +1376,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-deque" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -5571,26 +5560,6 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - [[package]] name = "redox_syscall" version = "0.5.12" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 98025aec9..25ecfa333 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -129,7 +129,6 @@ quote = "1.0" rand = "0.8.5" rand_core = "0.6.4" rangemap = "1.5.1" -rayon = "1.10.0" reqwest = { version = "0.12.9", default-features = false } resolv-conf = "0.7.3" ringbuffer = "0.15.0" diff --git a/rust/android-client-ffi/src/lib.rs b/rust/android-client-ffi/src/lib.rs index 47fb6953b..008fe8747 100644 --- a/rust/android-client-ffi/src/lib.rs +++ b/rust/android-client-ffi/src/lib.rs @@ -8,7 +8,7 @@ use crate::tun::Tun; use anyhow::{Context as _, Result}; use backoff::ExponentialBackoffBuilder; -use client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; +use client_shared::{DisconnectError, Session, V4RouteList, V6RouteList}; use connlib_model::ResourceView; use dns_types::DomainName; use firezone_logging::{err_with_src, sentry_layer}; @@ -165,7 +165,7 @@ fn init_logging(log_dir: &Path, log_filter: String) -> Result<()> { Ok(()) } -impl Callbacks for CallbackHandler { +impl CallbackHandler { fn on_set_interface_config( &self, tunnel_address_v4: Ipv4Addr, @@ -382,14 +382,39 @@ fn connect( }, tcp_socket_factory, )?; - let session = Session::connect( + let (session, mut event_stream) = Session::connect( Arc::new(protected_tcp_socket_factory(callbacks.clone())), Arc::new(protected_udp_socket_factory(callbacks.clone())), - callbacks, portal, runtime.handle().clone(), ); + runtime.spawn(async move { + while let Some(event) = event_stream.next().await { + match event { + client_shared::Event::TunInterfaceUpdated { + ipv4, + ipv6, + dns, + search_domain, + ipv4_routes, + ipv6_routes, + } => callbacks.on_set_interface_config( + ipv4, + ipv6, + dns, + search_domain, + ipv4_routes, + ipv6_routes, + ), + client_shared::Event::ResourcesUpdated(resource_views) => { + callbacks.on_update_resources(resource_views) + } + client_shared::Event::Disconnected(error) => callbacks.on_disconnect(error), + } + } + }); + Ok(SessionWrapper { inner: session, runtime, diff --git a/rust/apple-client-ffi/src/lib.rs b/rust/apple-client-ffi/src/lib.rs index 30607c301..6ff2e5662 100644 --- a/rust/apple-client-ffi/src/lib.rs +++ b/rust/apple-client-ffi/src/lib.rs @@ -8,7 +8,7 @@ mod tun; use anyhow::Context; use anyhow::Result; use backoff::ExponentialBackoffBuilder; -use client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; +use client_shared::{DisconnectError, Event, Session, V4RouteList, V6RouteList}; use connlib_model::ResourceView; use dns_types::DomainName; use firezone_logging::err_with_src; @@ -127,15 +127,14 @@ pub struct WrappedSession { unsafe impl Send for ffi::CallbackHandler {} unsafe impl Sync for ffi::CallbackHandler {} -#[derive(Clone)] pub struct CallbackHandler { // Generated Swift opaque type wrappers have a `Drop` impl that decrements the // refcount, but there's no way to generate a `Clone` impl that increments the // recount. Instead, we just wrap it in an `Arc`. - inner: Arc, + inner: ffi::CallbackHandler, } -impl Callbacks for CallbackHandler { +impl CallbackHandler { fn on_set_interface_config( &self, tunnel_address_v4: Ipv4Addr, @@ -293,17 +292,48 @@ impl WrappedSession { }, Arc::new(socket_factory::tcp), )?; - let session = Session::connect( + let (session, mut event_stream) = Session::connect( Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp), - CallbackHandler { - inner: Arc::new(callback_handler), - }, portal, runtime.handle().clone(), ); session.set_tun(Box::new(Tun::new()?)); + runtime.spawn(async move { + let callback_handler = CallbackHandler { + inner: callback_handler, + }; + + while let Some(event) = event_stream.next().await { + match event { + Event::TunInterfaceUpdated { + ipv4, + ipv6, + dns, + search_domain, + ipv4_routes, + ipv6_routes, + } => { + callback_handler.on_set_interface_config( + ipv4, + ipv6, + dns, + search_domain, + ipv4_routes, + ipv6_routes, + ); + } + Event::ResourcesUpdated(resource_views) => { + callback_handler.on_update_resources(resource_views); + } + Event::Disconnected(error) => { + callback_handler.on_disconnect(error); + } + } + } + }); + Ok(Self { inner: session, runtime, diff --git a/rust/client-shared/Cargo.toml b/rust/client-shared/Cargo.toml index 07e6e59f3..118c69ee8 100644 --- a/rust/client-shared/Cargo.toml +++ b/rust/client-shared/Cargo.toml @@ -14,7 +14,6 @@ firezone-logging = { workspace = true } firezone-tunnel = { workspace = true } ip_network = { workspace = true } phoenix-channel = { workspace = true } -rayon = { workspace = true } secrecy = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } snownet = { workspace = true } diff --git a/rust/client-shared/src/callbacks.rs b/rust/client-shared/src/callbacks.rs deleted file mode 100644 index 3fce776b6..000000000 --- a/rust/client-shared/src/callbacks.rs +++ /dev/null @@ -1,219 +0,0 @@ -use connlib_model::ResourceView; -use dns_types::DomainName; -use ip_network::{Ipv4Network, Ipv6Network}; -use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr}, - sync::Arc, -}; -use tokio::sync::mpsc; - -/// Traits that will be used by connlib to callback the client upper layers. -pub trait Callbacks: Clone + Send + Sync { - /// Called when the tunnel address is set. - /// - /// The first time this is called, the Resources list is also ready, - /// the routes are also ready, and the Client can consider the tunnel - /// to be ready for incoming traffic. - fn on_set_interface_config( - &self, - _: Ipv4Addr, - _: Ipv6Addr, - _: Vec, - _: Option, - _: Vec, - _: Vec, - ) { - } - - /// Called when the resource list changes. - /// - /// This may not be called if a Client has no Resources, which can - /// happen to new accounts, or when removing and re-adding Resources, - /// or if all Resources for a user are disabled by policy. - fn on_update_resources(&self, _: Vec) {} - - /// Called when the tunnel is disconnected. - fn on_disconnect(&self, _: DisconnectError) {} -} - -/// Unified error type to use across connlib. -#[derive(thiserror::Error, Debug)] -#[error("{0:#}")] -pub struct DisconnectError(anyhow::Error); - -impl From for DisconnectError { - fn from(e: anyhow::Error) -> Self { - Self(e) - } -} - -impl DisconnectError { - pub fn is_authentication_error(&self) -> bool { - let Some(e) = self.0.downcast_ref::() else { - return false; - }; - - e.is_authentication_error() - } -} - -#[derive(Debug, Clone)] -pub struct BackgroundCallbacks { - inner: C, - threadpool: Arc, -} - -impl BackgroundCallbacks { - pub fn new(callbacks: C) -> Self { - Self { - inner: callbacks, - threadpool: Arc::new( - rayon::ThreadPoolBuilder::new() - .num_threads(1) - .thread_name(|_| "connlib callbacks".to_owned()) - .build() - .expect("Unable to create thread-pool"), - ), - } - } -} - -impl Callbacks for BackgroundCallbacks -where - C: Callbacks + 'static, -{ - fn on_set_interface_config( - &self, - ipv4_addr: Ipv4Addr, - ipv6_addr: Ipv6Addr, - dns_addresses: Vec, - search_domain: Option, - route_list_4: Vec, - route_list_6: Vec, - ) { - let callbacks = self.inner.clone(); - - self.threadpool.spawn(move || { - callbacks.on_set_interface_config( - ipv4_addr, - ipv6_addr, - dns_addresses, - search_domain, - route_list_4, - route_list_6, - ); - }); - } - - fn on_update_resources(&self, resources: Vec) { - let callbacks = self.inner.clone(); - - self.threadpool.spawn(move || { - callbacks.on_update_resources(resources); - }); - } - - fn on_disconnect(&self, error: DisconnectError) { - let callbacks = self.inner.clone(); - - self.threadpool.spawn(move || { - callbacks.on_disconnect(error); - }); - } -} - -/// Messages that connlib can produce and send to the headless Client, Tunnel service, or GUI process. -/// -/// i.e. callbacks -// The names are CamelCase versions of the connlib callbacks. -#[expect(clippy::enum_variant_names)] -pub enum ConnlibMsg { - OnDisconnect { - error_msg: String, - is_authentication_error: bool, - }, - /// Use this as `TunnelReady`, per `callbacks.rs` - OnSetInterfaceConfig { - ipv4: Ipv4Addr, - ipv6: Ipv6Addr, - dns: Vec, - search_domain: Option, - ipv4_routes: Vec, - ipv6_routes: Vec, - }, - OnUpdateResources(Vec), -} - -#[derive(Clone)] -pub struct ChannelCallbackHandler { - cb_tx: mpsc::Sender, -} - -impl ChannelCallbackHandler { - pub fn new() -> (Self, mpsc::Receiver) { - let (cb_tx, cb_rx) = mpsc::channel(1_000); - - (Self { cb_tx }, cb_rx) - } -} - -impl Callbacks for ChannelCallbackHandler { - fn on_disconnect(&self, error: DisconnectError) { - self.cb_tx - .try_send(ConnlibMsg::OnDisconnect { - error_msg: error.to_string(), - is_authentication_error: error.is_authentication_error(), - }) - .expect("should be able to send OnDisconnect"); - } - - fn on_set_interface_config( - &self, - ipv4: Ipv4Addr, - ipv6: Ipv6Addr, - dns: Vec, - search_domain: Option, - ipv4_routes: Vec, - ipv6_routes: Vec, - ) { - self.cb_tx - .try_send(ConnlibMsg::OnSetInterfaceConfig { - ipv4, - ipv6, - dns, - search_domain, - ipv4_routes, - ipv6_routes, - }) - .expect("Should be able to send OnSetInterfaceConfig"); - } - - fn on_update_resources(&self, resources: Vec) { - tracing::debug!(len = resources.len(), "New resource list"); - self.cb_tx - .try_send(ConnlibMsg::OnUpdateResources(resources)) - .expect("Should be able to send OnUpdateResources"); - } -} - -#[cfg(test)] -mod tests { - use phoenix_channel::StatusCode; - - use super::*; - - #[test] - fn printing_disconnect_error_contains_401() { - let disconnect_error = DisconnectError::from(anyhow::Error::new( - phoenix_channel::Error::Client(StatusCode::UNAUTHORIZED), - )); - - assert!(disconnect_error.to_string().contains("401 Unauthorized")); // Apple client relies on this. - } - - // Make sure it's okay to store a bunch of these to mitigate #5880 - #[test] - fn callback_msg_size() { - assert_eq!(std::mem::size_of::(), 120) - } -} diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 453f4381f..b0b779b7c 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -1,13 +1,16 @@ -use crate::{PHOENIX_TOPIC, callbacks::Callbacks}; +use crate::PHOENIX_TOPIC; use anyhow::{Context as _, Result}; -use connlib_model::{PublicKey, ResourceId}; +use connlib_model::{PublicKey, ResourceId, ResourceView}; +use dns_types::DomainName; use firezone_tunnel::messages::RelaysPresence; use firezone_tunnel::messages::client::{ EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates, GatewaysIceCandidates, IngressMessages, InitClient, }; use firezone_tunnel::{ClientTunnel, IpConfig}; +use ip_network::{Ipv4Network, Ipv6Network}; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam}; +use std::net::{Ipv4Addr, Ipv6Addr}; use std::time::Instant; use std::{ collections::BTreeSet, @@ -15,14 +18,15 @@ use std::{ net::IpAddr, task::{Context, Poll}, }; +use tokio::sync::mpsc::error::TrySendError; use tun::Tun; -pub struct Eventloop { +pub struct Eventloop { tunnel: ClientTunnel, - callbacks: C, portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, - rx: tokio::sync::mpsc::UnboundedReceiver, + cmd_rx: tokio::sync::mpsc::UnboundedReceiver, + event_tx: tokio::sync::mpsc::Sender, } /// Commands that can be sent to the [`Eventloop`]. @@ -33,31 +37,62 @@ pub enum Command { SetDisabledResources(BTreeSet), } -impl Eventloop { +pub enum Event { + TunInterfaceUpdated { + ipv4: Ipv4Addr, + ipv6: Ipv6Addr, + dns: Vec, + search_domain: Option, + ipv4_routes: Vec, + ipv6_routes: Vec, + }, + ResourcesUpdated(Vec), + Disconnected(DisconnectError), +} + +/// Unified error type to use across connlib. +#[derive(thiserror::Error, Debug)] +#[error("{0:#}")] +pub struct DisconnectError(anyhow::Error); + +impl From for DisconnectError { + fn from(e: anyhow::Error) -> Self { + Self(e) + } +} + +impl DisconnectError { + pub fn is_authentication_error(&self) -> bool { + let Some(e) = self.0.downcast_ref::() else { + return false; + }; + + e.is_authentication_error() + } +} + +impl Eventloop { pub(crate) fn new( tunnel: ClientTunnel, - callbacks: C, mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, - rx: tokio::sync::mpsc::UnboundedReceiver, + cmd_rx: tokio::sync::mpsc::UnboundedReceiver, + event_tx: tokio::sync::mpsc::Sender, ) -> Self { portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); Self { tunnel, portal, - rx, - callbacks, + cmd_rx, + event_tx, } } } -impl Eventloop -where - C: Callbacks + 'static, -{ +impl Eventloop { pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match self.rx.poll_recv(cx) { + match self.cmd_rx.poll_recv(cx) { Poll::Ready(None) => return Poll::Ready(Ok(())), Poll::Ready(Some(Command::SetDns(dns))) => { self.tunnel.state_mut().update_system_resolvers(dns); @@ -84,7 +119,22 @@ where match self.tunnel.poll_next_event(cx) { Poll::Ready(Ok(event)) => { - self.handle_tunnel_event(event); + let Some(e) = self.handle_tunnel_event(event) else { + continue; + }; + + match self.event_tx.try_send(e) { + Ok(()) => {} + Err(TrySendError::Closed(_)) => { + tracing::debug!("Event receiver dropped, exiting event loop"); + + return Poll::Ready(Ok(())); + } + Err(TrySendError::Full(_)) => { + tracing::warn!("App cannot keep up with connlib events, dropping"); + } + }; + continue; } Poll::Ready(Err(e)) => { @@ -123,7 +173,7 @@ where } } - fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) { + fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) -> Option { match event { firezone_tunnel::ClientEvent::AddedIceCandidates { conn_id: gateway, @@ -138,6 +188,8 @@ where candidates, }), ); + + None } firezone_tunnel::ClientEvent::RemovedIceCandidates { conn_id: gateway, @@ -152,6 +204,8 @@ where candidates, }), ); + + None } firezone_tunnel::ClientEvent::ConnectionIntent { connected_gateway_ids, @@ -164,21 +218,21 @@ where connected_gateway_ids, }, ); + + None } firezone_tunnel::ClientEvent::ResourcesChanged { resources } => { - self.callbacks.on_update_resources(resources) + Some(Event::ResourcesUpdated(resources)) } firezone_tunnel::ClientEvent::TunInterfaceUpdated(config) => { - let dns_servers = config.dns_by_sentinel.left_values().copied().collect(); - - self.callbacks.on_set_interface_config( - config.ip.v4, - config.ip.v6, - dns_servers, - config.search_domain, - Vec::from_iter(config.ipv4_routes), - Vec::from_iter(config.ipv6_routes), - ); + Some(Event::TunInterfaceUpdated { + ipv4: config.ip.v4, + ipv6: config.ip.v6, + dns: config.dns_by_sentinel.left_values().copied().collect(), + search_domain: config.search_domain, + ipv4_routes: Vec::from_iter(config.ipv4_routes), + ipv6_routes: Vec::from_iter(config.ipv6_routes), + }) } } } diff --git a/rust/client-shared/src/lib.rs b/rust/client-shared/src/lib.rs index 2c9b5915f..98498a8cd 100644 --- a/rust/client-shared/src/lib.rs +++ b/rust/client-shared/src/lib.rs @@ -1,25 +1,23 @@ //! Main connlib library for clients. pub use crate::serde_routelist::{V4RouteList, V6RouteList}; -use callbacks::BackgroundCallbacks; -pub use callbacks::{Callbacks, ChannelCallbackHandler, ConnlibMsg, DisconnectError}; pub use connlib_model::StaticSecret; -pub use eventloop::Eventloop; +pub use eventloop::{DisconnectError, Event}; pub use firezone_tunnel::messages::client::{IngressMessages, ResourceDescription}; -use anyhow::{Context, Result}; +use anyhow::{Context as _, Result}; use connlib_model::ResourceId; -use eventloop::Command; +use eventloop::{Command, Eventloop}; use firezone_tunnel::ClientTunnel; use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::collections::BTreeSet; use std::net::IpAddr; use std::sync::Arc; -use tokio::sync::mpsc::UnboundedReceiver; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::task::JoinHandle; use tun::Tun; -mod callbacks; mod eventloop; mod serde_routelist; @@ -31,34 +29,36 @@ const PHOENIX_TOPIC: &str = "client"; /// To stop the session, simply drop this struct. #[derive(Clone)] pub struct Session { - channel: tokio::sync::mpsc::UnboundedSender, + channel: UnboundedSender, +} + +pub struct EventStream { + channel: Receiver, } impl Session { /// Creates a new [`Session`]. /// /// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key. - pub fn connect( + pub fn connect( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - callbacks: CB, portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, handle: tokio::runtime::Handle, - ) -> Self { - let callbacks = BackgroundCallbacks::new(callbacks); // Run all callbacks on a background thread to avoid blocking the main connlib task. - - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + ) -> (Self, EventStream) { + let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel(); + let (event_tx, event_rx) = tokio::sync::mpsc::channel(1000); let connect_handle = handle.spawn(connect( tcp_socket_factory, udp_socket_factory, - callbacks.clone(), portal, - rx, + cmd_rx, + event_tx.clone(), )); - handle.spawn(connect_supervisor(connect_handle, callbacks)); + handle.spawn(connect_supervisor(connect_handle, event_tx)); - Self { channel: tx } + (Self { channel: cmd_tx }, EventStream { channel: event_rx }) } /// Reset a [`Session`]. @@ -107,6 +107,16 @@ impl Session { } } +impl EventStream { + pub fn poll_next(&mut self, cx: &mut Context) -> Poll> { + self.channel.poll_recv(cx) + } + + pub async fn next(&mut self) -> Option { + self.channel.recv().await + } +} + impl Drop for Session { fn drop(&mut self) { tracing::debug!("`Session` dropped") @@ -116,18 +126,15 @@ impl Drop for Session { /// Connects to the portal and starts a tunnel. /// /// When this function exits, the tunnel failed unrecoverably and you need to call it again. -async fn connect( +async fn connect( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - callbacks: CB, portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, - rx: UnboundedReceiver, -) -> Result<()> -where - CB: Callbacks + 'static, -{ + cmd_rx: UnboundedReceiver, + event_tx: Sender, +) -> Result<()> { let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory); - let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx); + let mut eventloop = Eventloop::new(tunnel, portal, cmd_rx, event_tx); std::future::poll_fn(|cx| eventloop.poll(cx)).await?; @@ -135,18 +142,27 @@ where } /// A supervisor task that handles, when [`connect`] exits. -async fn connect_supervisor(connect_handle: JoinHandle>, callbacks: CB) -where - CB: Callbacks, -{ +async fn connect_supervisor( + connect_handle: JoinHandle>, + event_tx: tokio::sync::mpsc::Sender, +) { let task = async { connect_handle.await.context("connlib crashed")??; Ok(()) }; - match task.await { - Ok(()) => tracing::info!("connlib exited gracefully"), - Err(e) => callbacks.on_disconnect(e), + let error = match task.await { + Ok(()) => { + tracing::info!("connlib exited gracefully"); + + return; + } + Err(e) => e, + }; + + match event_tx.send(Event::Disconnected(error)).await { + Ok(()) => (), + Err(_) => tracing::debug!("Event stream closed before we could send disconnected event"), } } diff --git a/rust/gui-client/src-tauri/src/service.rs b/rust/gui-client/src-tauri/src/service.rs index 046ea10ca..ff310f721 100644 --- a/rust/gui-client/src-tauri/src/service.rs +++ b/rust/gui-client/src-tauri/src/service.rs @@ -5,7 +5,6 @@ use crate::{ use anyhow::{Context as _, Result, bail}; use atomicwrites::{AtomicFile, OverwriteBehavior}; use backoff::ExponentialBackoffBuilder; -use client_shared::ConnlibMsg; use connlib_model::{ResourceId, ResourceView}; use firezone_bin_shared::{ DnsControlMethod, DnsController, TunDeviceManager, @@ -31,7 +30,7 @@ use std::{ sync::Arc, time::Duration, }; -use tokio::{sync::mpsc, time::Instant}; +use tokio::time::Instant; use url::Url; #[cfg(target_os = "linux")] @@ -179,12 +178,12 @@ struct Handler<'a> { } struct Session { - cb_rx: mpsc::Receiver, + event_stream: client_shared::EventStream, connlib: client_shared::Session, } enum Event { - Callback(ConnlibMsg), + Connlib(client_shared::Event), CallbackChannelClosed, Ipc(ClientMsg), IpcDisconnected, @@ -247,8 +246,8 @@ impl<'a> Handler<'a> { async fn run(&mut self, signals: &mut signals::Terminate) -> HandlerOk { let ret = loop { match poll_fn(|cx| self.next_event(cx, signals)).await { - Event::Callback(x) => { - if let Err(error) = self.handle_connlib_cb(x).await { + Event::Connlib(x) => { + if let Err(error) = self.handle_connlib_event(x).await { tracing::error!("Error while handling connlib callback: {error:#}"); continue; } @@ -309,10 +308,9 @@ impl<'a> Handler<'a> { }); } if let Some(session) = self.session.as_mut() { - // `tokio::sync::mpsc::Receiver::recv` is cancel-safe. - if let Poll::Ready(option) = session.cb_rx.poll_recv(cx) { + if let Poll::Ready(option) = session.event_stream.poll_next(cx) { return Poll::Ready(match option { - Some(x) => Event::Callback(x), + Some(x) => Event::Connlib(x), None => Event::CallbackChannelClosed, }); } @@ -320,21 +318,18 @@ impl<'a> Handler<'a> { Poll::Pending } - async fn handle_connlib_cb(&mut self, msg: ConnlibMsg) -> Result<()> { + async fn handle_connlib_event(&mut self, msg: client_shared::Event) -> Result<()> { match msg { - ConnlibMsg::OnDisconnect { - error_msg, - is_authentication_error, - } => { + client_shared::Event::Disconnected(error) => { let _ = self.session.take(); self.dns_controller.deactivate()?; self.send_ipc(ServerMsg::OnDisconnect { - error_msg, - is_authentication_error, + error_msg: error.to_string(), + is_authentication_error: error.is_authentication_error(), }) .await? } - ConnlibMsg::OnSetInterfaceConfig { + client_shared::Event::TunInterfaceUpdated { ipv4, ipv6, dns, @@ -352,7 +347,7 @@ impl<'a> Handler<'a> { self.send_ipc(ServerMsg::TunnelReady).await?; } - ConnlibMsg::OnUpdateResources(resources) => { + client_shared::Event::ResourcesUpdated(resources) => { // On every resources update, flush DNS to mitigate self.dns_controller.flush()?; self.send_ipc(ServerMsg::OnUpdateResources(resources)) @@ -472,7 +467,6 @@ impl<'a> Handler<'a> { .context("Failed to create `LoginUrl`")?; self.last_connlib_start_instant = Some(Instant::now()); - let (callbacks, cb_rx) = client_shared::ChannelCallbackHandler::new(); // Synchronous DNS resolution here let portal = PhoenixChannel::disconnected( @@ -490,10 +484,9 @@ impl<'a> Handler<'a> { // Read the resolvers before starting connlib, in case connlib's startup interferes. let dns = self.dns_controller.system_resolvers(); - let connlib = client_shared::Session::connect( + let (connlib, event_stream) = client_shared::Session::connect( Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory), - callbacks, portal, tokio::runtime::Handle::current(), ); @@ -510,7 +503,10 @@ impl<'a> Handler<'a> { }; connlib.set_tun(tun); - let session = Session { cb_rx, connlib }; + let session = Session { + event_stream, + connlib, + }; self.session = Some(session); Ok(()) diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index 0e16ee203..6f0a8185c 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -5,7 +5,6 @@ use anyhow::{Context as _, Result, anyhow}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use client_shared::{ChannelCallbackHandler, ConnlibMsg, Session}; use firezone_bin_shared::{ DnsControlMethod, DnsController, TOKEN_ENV_KEY, TunDeviceManager, device_id, device_info, new_dns_notifier, new_network_notifier, @@ -15,7 +14,6 @@ use firezone_bin_shared::{ use firezone_logging::telemetry_span; use firezone_telemetry::Telemetry; use firezone_telemetry::otel; -use futures::StreamExt as _; use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider}; use phoenix_channel::PhoenixChannel; use phoenix_channel::get_user_agent; @@ -26,7 +24,6 @@ use std::{ sync::Arc, }; use tokio::time::Instant; -use tokio_stream::wrappers::ReceiverStream; #[cfg(target_os = "linux")] #[path = "linux.rs"] @@ -221,8 +218,6 @@ fn main() -> Result<()> { return Ok(()); } - let (callbacks, cb_rx) = ChannelCallbackHandler::new(); - // The name matches that in `ipc_service.rs` let mut last_connlib_start_instant = Some(Instant::now()); @@ -261,10 +256,9 @@ fn main() -> Result<()> { }, Arc::new(tcp_socket_factory), )?; - let session = Session::connect( + let (session, mut event_stream) = client_shared::Session::connect( Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory), - callbacks, portal, rt.handle().clone(), ); @@ -273,7 +267,6 @@ fn main() -> Result<()> { let mut hangup = signals::Hangup::new()?; let mut tun_device = TunDeviceManager::new(ip_packet::MAX_IP_SIZE, 1)?; - let mut cb_rx = ReceiverStream::new(cb_rx).fuse(); let tokio_handle = tokio::runtime::Handle::current(); @@ -294,7 +287,7 @@ fn main() -> Result<()> { drop(connect_span); let result = loop { - let cb = tokio::select! { + let event = tokio::select! { () = terminate.recv() => { tracing::info!("Caught SIGINT / SIGTERM / Ctrl+C"); break Ok(()); @@ -318,20 +311,17 @@ fn main() -> Result<()> { session.reset(); continue; }, - cb = cb_rx.next() => cb.context("cb_rx unexpectedly ran empty")?, + event = event_stream.next() => event.context("event stream unexpectedly ran empty")?, }; - match cb { + match event { // TODO: Headless Client shouldn't be using messages labelled `Ipc` - ConnlibMsg::OnDisconnect { - error_msg, - is_authentication_error: _, - } => break Err(anyhow!(error_msg).context("Firezone disconnected")), - ConnlibMsg::OnUpdateResources(_) => { + client_shared::Event::Disconnected(error) => break Err(anyhow!(error).context("Firezone disconnected")), + client_shared::Event::ResourcesUpdated(_) => { // On every Resources update, flush DNS to mitigate dns_controller.flush()?; } - ConnlibMsg::OnSetInterfaceConfig { + client_shared::Event::TunInterfaceUpdated { ipv4, ipv6, dns,