refactor(rust): remove forced callback indirection (#9362)

As a relic from very early designs of `connlib`, the `Callbacks` trait is
still present and defines how the host app receives events from a
running `Session`. Callbacks are not a great design pattern, however,
because they force the running code, i.e. `connlib`'s event-loop, to
execute unknown code. For example, if that code panics, all of `connlib`
is taken down. Additionally, not all consumers may want to receive
events via callbacks. The GUI and headless client for example already
have their own event-loop in which they process all kinds of things.
Having to deal with the `Callbacks` interface introduces an odd
indirection here.

To fix this, we instead return an `EventStream` when constructing a
`Session`. This essentially aligns the API of `Session` with that of a
channel. You receive two handles, one for sending in commands and one
for receiving events. A `Session` will automatically spawn itself onto
the given runtime so progress is made even if one does not poll on these
channel handles.

This greatly simplifies the code:

- We get to delete the `Callbacks` interface.
- We can delete the threaded callback adapter. This was only necessary
because we didn't want to block `connlib` with the handling of the
event. By using a channel for events, this is automatically guaranteed.
- The GUI and headless client can directly integrate the event handling
in their event-loop, without having to create an indirection with a
channel.
- It is now clear that only the Apple and Android FFI layers actually
use callbacks to communicate these events.
- We net-delete more than 100 LoC.
This commit is contained in:
Thomas Eizinger
2025-06-02 19:28:04 +08:00
committed by GitHub
parent b7b296a102
commit 1914ea7076
10 changed files with 223 additions and 364 deletions

View File

@@ -14,7 +14,6 @@ firezone-logging = { workspace = true }
firezone-tunnel = { workspace = true }
ip_network = { workspace = true }
phoenix-channel = { workspace = true }
rayon = { workspace = true }
secrecy = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
snownet = { workspace = true }

View File

@@ -1,219 +0,0 @@
use connlib_model::ResourceView;
use dns_types::DomainName;
use ip_network::{Ipv4Network, Ipv6Network};
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
sync::Arc,
};
use tokio::sync::mpsc;
/// Hooks through which connlib calls back into the client's upper layers.
///
/// Every method comes with an empty default body, so an implementor only
/// overrides the notifications it is interested in.
pub trait Callbacks: Clone + Send + Sync {
    /// Invoked when the tunnel address is set.
    ///
    /// On the first invocation the Resources list and the routes are ready as
    /// well, so the Client can treat the tunnel as ready for incoming traffic.
    fn on_set_interface_config(
        &self,
        _ipv4: Ipv4Addr,
        _ipv6: Ipv6Addr,
        _dns: Vec<IpAddr>,
        _search_domain: Option<DomainName>,
        _ipv4_routes: Vec<Ipv4Network>,
        _ipv6_routes: Vec<Ipv6Network>,
    ) {
    }

    /// Invoked when the resource list changes.
    ///
    /// A Client without any Resources may never see this call. That can
    /// happen for new accounts, while Resources are removed and re-added,
    /// or when policy disables all Resources for a user.
    fn on_update_resources(&self, _resources: Vec<ResourceView>) {}

    /// Invoked when the tunnel is disconnected.
    fn on_disconnect(&self, _error: DisconnectError) {}
}
/// Unified error type to use across connlib.
#[derive(thiserror::Error, Debug)]
#[error("{0:#}")]
pub struct DisconnectError(anyhow::Error);
impl From<anyhow::Error> for DisconnectError {
fn from(e: anyhow::Error) -> Self {
Self(e)
}
}
impl DisconnectError {
pub fn is_authentication_error(&self) -> bool {
let Some(e) = self.0.downcast_ref::<phoenix_channel::Error>() else {
return false;
};
e.is_authentication_error()
}
}
/// Adapter that pairs a callback implementation with a dedicated thread pool.
#[derive(Debug, Clone)]
pub struct BackgroundCallbacks<C> {
    inner: C,
    threadpool: Arc<rayon::ThreadPool>,
}

impl<C> BackgroundCallbacks<C> {
    /// Wraps `callbacks` together with a freshly built rayon pool that has a
    /// single thread named "connlib callbacks".
    ///
    /// # Panics
    ///
    /// Panics if the thread pool cannot be created.
    pub fn new(callbacks: C) -> Self {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(1)
            .thread_name(|_| "connlib callbacks".to_owned())
            .build()
            .expect("Unable to create thread-pool");

        Self {
            inner: callbacks,
            threadpool: Arc::new(pool),
        }
    }
}
/// Hands every callback over to the thread pool, where it is invoked on a
/// clone of the wrapped implementation.
impl<C> Callbacks for BackgroundCallbacks<C>
where
    C: Callbacks + 'static,
{
    fn on_set_interface_config(
        &self,
        ipv4_addr: Ipv4Addr,
        ipv6_addr: Ipv6Addr,
        dns_addresses: Vec<IpAddr>,
        search_domain: Option<DomainName>,
        route_list_4: Vec<Ipv4Network>,
        route_list_6: Vec<Ipv6Network>,
    ) {
        let inner = self.inner.clone();
        let job = move || {
            inner.on_set_interface_config(
                ipv4_addr,
                ipv6_addr,
                dns_addresses,
                search_domain,
                route_list_4,
                route_list_6,
            );
        };

        self.threadpool.spawn(job);
    }

    fn on_update_resources(&self, resources: Vec<ResourceView>) {
        let inner = self.inner.clone();
        let job = move || inner.on_update_resources(resources);

        self.threadpool.spawn(job);
    }

    fn on_disconnect(&self, error: DisconnectError) {
        let inner = self.inner.clone();
        let job = move || inner.on_disconnect(error);

        self.threadpool.spawn(job);
    }
}
/// Messages that connlib can produce and send to the headless Client, Tunnel service, or GUI process.
///
/// Each variant mirrors one method of the `Callbacks` trait and carries that callback's arguments.
// The names are CamelCase versions of the connlib callbacks.
#[expect(clippy::enum_variant_names)]
pub enum ConnlibMsg {
/// Counterpart of `Callbacks::on_disconnect`.
OnDisconnect {
/// The disconnect error rendered as a string.
error_msg: String,
/// Whether the disconnect error reported itself as an authentication error.
is_authentication_error: bool,
},
/// Counterpart of `Callbacks::on_set_interface_config`.
///
/// Use this as `TunnelReady`, per `callbacks.rs`
OnSetInterfaceConfig {
ipv4: Ipv4Addr,
ipv6: Ipv6Addr,
dns: Vec<IpAddr>,
search_domain: Option<DomainName>,
ipv4_routes: Vec<Ipv4Network>,
ipv6_routes: Vec<Ipv6Network>,
},
/// Counterpart of `Callbacks::on_update_resources`; carries the new resource list.
OnUpdateResources(Vec<ResourceView>),
}
/// Callback handler that owns the sending half of a bounded tokio channel of
/// [`ConnlibMsg`]s.
#[derive(Clone)]
pub struct ChannelCallbackHandler {
    cb_tx: mpsc::Sender<ConnlibMsg>,
}

impl ChannelCallbackHandler {
    /// Creates the handler together with the receiving half of its channel.
    ///
    /// The channel buffers up to 1,000 messages.
    pub fn new() -> (Self, mpsc::Receiver<ConnlibMsg>) {
        let (sender, receiver) = mpsc::channel(1_000);
        let handler = Self { cb_tx: sender };

        (handler, receiver)
    }
}
impl Callbacks for ChannelCallbackHandler {
fn on_disconnect(&self, error: DisconnectError) {
self.cb_tx
.try_send(ConnlibMsg::OnDisconnect {
error_msg: error.to_string(),
is_authentication_error: error.is_authentication_error(),
})
.expect("should be able to send OnDisconnect");
}
fn on_set_interface_config(
&self,
ipv4: Ipv4Addr,
ipv6: Ipv6Addr,
dns: Vec<IpAddr>,
search_domain: Option<DomainName>,
ipv4_routes: Vec<Ipv4Network>,
ipv6_routes: Vec<Ipv6Network>,
) {
self.cb_tx
.try_send(ConnlibMsg::OnSetInterfaceConfig {
ipv4,
ipv6,
dns,
search_domain,
ipv4_routes,
ipv6_routes,
})
.expect("Should be able to send OnSetInterfaceConfig");
}
fn on_update_resources(&self, resources: Vec<ResourceView>) {
tracing::debug!(len = resources.len(), "New resource list");
self.cb_tx
.try_send(ConnlibMsg::OnUpdateResources(resources))
.expect("Should be able to send OnUpdateResources");
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use phoenix_channel::StatusCode;

    /// The Apple client relies on the rendered error containing "401 Unauthorized".
    #[test]
    fn printing_disconnect_error_contains_401() {
        let unauthorized = phoenix_channel::Error::Client(StatusCode::UNAUTHORIZED);
        let rendered = DisconnectError::from(anyhow::Error::new(unauthorized)).to_string();

        assert!(rendered.contains("401 Unauthorized"));
    }

    /// Make sure it's okay to store a bunch of these to mitigate #5880.
    #[test]
    fn callback_msg_size() {
        let size = std::mem::size_of::<ConnlibMsg>();

        assert_eq!(size, 120);
    }
}

View File

@@ -1,13 +1,16 @@
use crate::{PHOENIX_TOPIC, callbacks::Callbacks};
use crate::PHOENIX_TOPIC;
use anyhow::{Context as _, Result};
use connlib_model::{PublicKey, ResourceId};
use connlib_model::{PublicKey, ResourceId, ResourceView};
use dns_types::DomainName;
use firezone_tunnel::messages::RelaysPresence;
use firezone_tunnel::messages::client::{
EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates,
GatewaysIceCandidates, IngressMessages, InitClient,
};
use firezone_tunnel::{ClientTunnel, IpConfig};
use ip_network::{Ipv4Network, Ipv6Network};
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::time::Instant;
use std::{
collections::BTreeSet,
@@ -15,14 +18,15 @@ use std::{
net::IpAddr,
task::{Context, Poll},
};
use tokio::sync::mpsc::error::TrySendError;
use tun::Tun;
pub struct Eventloop<C: Callbacks> {
pub struct Eventloop {
tunnel: ClientTunnel,
callbacks: C,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
event_tx: tokio::sync::mpsc::Sender<Event>,
}
/// Commands that can be sent to the [`Eventloop`].
@@ -33,31 +37,62 @@ pub enum Command {
SetDisabledResources(BTreeSet<ResourceId>),
}
impl<C: Callbacks> Eventloop<C> {
/// Events that the [`Eventloop`] sends to the host app over its event channel.
pub enum Event {
/// Emitted when the tunnel reports `ClientEvent::TunInterfaceUpdated`.
TunInterfaceUpdated {
/// Taken from `config.ip.v4` of the tunnel's interface config.
ipv4: Ipv4Addr,
/// Taken from `config.ip.v6` of the tunnel's interface config.
ipv6: Ipv6Addr,
/// The left-hand values of the config's `dns_by_sentinel` mapping.
dns: Vec<IpAddr>,
search_domain: Option<DomainName>,
ipv4_routes: Vec<Ipv4Network>,
ipv6_routes: Vec<Ipv6Network>,
},
/// Emitted when the tunnel reports `ClientEvent::ResourcesChanged`; carries the new resource list.
ResourcesUpdated(Vec<ResourceView>),
/// Carries the error with which connlib stopped.
Disconnected(DisconnectError),
}
/// Unified error type to use across connlib.
#[derive(thiserror::Error, Debug)]
#[error("{0:#}")]
pub struct DisconnectError(anyhow::Error);
impl From<anyhow::Error> for DisconnectError {
fn from(e: anyhow::Error) -> Self {
Self(e)
}
}
impl DisconnectError {
pub fn is_authentication_error(&self) -> bool {
let Some(e) = self.0.downcast_ref::<phoenix_channel::Error>() else {
return false;
};
e.is_authentication_error()
}
}
impl Eventloop {
pub(crate) fn new(
tunnel: ClientTunnel,
callbacks: C,
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
event_tx: tokio::sync::mpsc::Sender<Event>,
) -> Self {
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
Self {
tunnel,
portal,
rx,
callbacks,
cmd_rx,
event_tx,
}
}
}
impl<C> Eventloop<C>
where
C: Callbacks + 'static,
{
impl Eventloop {
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop {
match self.rx.poll_recv(cx) {
match self.cmd_rx.poll_recv(cx) {
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Command::SetDns(dns))) => {
self.tunnel.state_mut().update_system_resolvers(dns);
@@ -84,7 +119,22 @@ where
match self.tunnel.poll_next_event(cx) {
Poll::Ready(Ok(event)) => {
self.handle_tunnel_event(event);
let Some(e) = self.handle_tunnel_event(event) else {
continue;
};
match self.event_tx.try_send(e) {
Ok(()) => {}
Err(TrySendError::Closed(_)) => {
tracing::debug!("Event receiver dropped, exiting event loop");
return Poll::Ready(Ok(()));
}
Err(TrySendError::Full(_)) => {
tracing::warn!("App cannot keep up with connlib events, dropping");
}
};
continue;
}
Poll::Ready(Err(e)) => {
@@ -123,7 +173,7 @@ where
}
}
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) {
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) -> Option<Event> {
match event {
firezone_tunnel::ClientEvent::AddedIceCandidates {
conn_id: gateway,
@@ -138,6 +188,8 @@ where
candidates,
}),
);
None
}
firezone_tunnel::ClientEvent::RemovedIceCandidates {
conn_id: gateway,
@@ -152,6 +204,8 @@ where
candidates,
}),
);
None
}
firezone_tunnel::ClientEvent::ConnectionIntent {
connected_gateway_ids,
@@ -164,21 +218,21 @@ where
connected_gateway_ids,
},
);
None
}
firezone_tunnel::ClientEvent::ResourcesChanged { resources } => {
self.callbacks.on_update_resources(resources)
Some(Event::ResourcesUpdated(resources))
}
firezone_tunnel::ClientEvent::TunInterfaceUpdated(config) => {
let dns_servers = config.dns_by_sentinel.left_values().copied().collect();
self.callbacks.on_set_interface_config(
config.ip.v4,
config.ip.v6,
dns_servers,
config.search_domain,
Vec::from_iter(config.ipv4_routes),
Vec::from_iter(config.ipv6_routes),
);
Some(Event::TunInterfaceUpdated {
ipv4: config.ip.v4,
ipv6: config.ip.v6,
dns: config.dns_by_sentinel.left_values().copied().collect(),
search_domain: config.search_domain,
ipv4_routes: Vec::from_iter(config.ipv4_routes),
ipv6_routes: Vec::from_iter(config.ipv6_routes),
})
}
}
}

View File

@@ -1,25 +1,23 @@
//! Main connlib library for clients.
pub use crate::serde_routelist::{V4RouteList, V6RouteList};
use callbacks::BackgroundCallbacks;
pub use callbacks::{Callbacks, ChannelCallbackHandler, ConnlibMsg, DisconnectError};
pub use connlib_model::StaticSecret;
pub use eventloop::Eventloop;
pub use eventloop::{DisconnectError, Event};
pub use firezone_tunnel::messages::client::{IngressMessages, ResourceDescription};
use anyhow::{Context, Result};
use anyhow::{Context as _, Result};
use connlib_model::ResourceId;
use eventloop::Command;
use eventloop::{Command, Eventloop};
use firezone_tunnel::ClientTunnel;
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::collections::BTreeSet;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedReceiver;
use std::task::{Context, Poll};
use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
use tun::Tun;
mod callbacks;
mod eventloop;
mod serde_routelist;
@@ -31,34 +29,36 @@ const PHOENIX_TOPIC: &str = "client";
/// To stop the session, simply drop this struct.
#[derive(Clone)]
pub struct Session {
channel: tokio::sync::mpsc::UnboundedSender<Command>,
channel: UnboundedSender<Command>,
}
/// Receiving half of the event channel that `Session::connect` hands out alongside the [`Session`].
///
/// Yields the [`Event`]s that connlib emits.
pub struct EventStream {
channel: Receiver<Event>,
}
impl Session {
/// Creates a new [`Session`].
///
/// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key.
pub fn connect<CB: Callbacks + 'static>(
pub fn connect(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
callbacks: CB,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
handle: tokio::runtime::Handle,
) -> Self {
let callbacks = BackgroundCallbacks::new(callbacks); // Run all callbacks on a background thread to avoid blocking the main connlib task.
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
) -> (Self, EventStream) {
let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
let (event_tx, event_rx) = tokio::sync::mpsc::channel(1000);
let connect_handle = handle.spawn(connect(
tcp_socket_factory,
udp_socket_factory,
callbacks.clone(),
portal,
rx,
cmd_rx,
event_tx.clone(),
));
handle.spawn(connect_supervisor(connect_handle, callbacks));
handle.spawn(connect_supervisor(connect_handle, event_tx));
Self { channel: tx }
(Self { channel: cmd_tx }, EventStream { channel: event_rx })
}
/// Reset a [`Session`].
@@ -107,6 +107,16 @@ impl Session {
}
}
impl EventStream {
/// Polls for the next [`Event`] by delegating to the channel's `poll_recv`.
///
/// Intended for callers that drive their own `poll`-based event-loop.
pub fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Event>> {
self.channel.poll_recv(cx)
}
/// Waits for the next [`Event`] by awaiting the channel's `recv`.
///
/// Per tokio's `mpsc::Receiver`, `None` is returned once the channel is closed and no buffered events remain.
pub async fn next(&mut self) -> Option<Event> {
self.channel.recv().await
}
}
impl Drop for Session {
fn drop(&mut self) {
tracing::debug!("`Session` dropped")
@@ -116,18 +126,15 @@ impl Drop for Session {
/// Connects to the portal and starts a tunnel.
///
/// When this function exits, the tunnel failed unrecoverably and you need to call it again.
async fn connect<CB>(
async fn connect(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
callbacks: CB,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
rx: UnboundedReceiver<Command>,
) -> Result<()>
where
CB: Callbacks + 'static,
{
cmd_rx: UnboundedReceiver<Command>,
event_tx: Sender<Event>,
) -> Result<()> {
let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory);
let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx);
let mut eventloop = Eventloop::new(tunnel, portal, cmd_rx, event_tx);
std::future::poll_fn(|cx| eventloop.poll(cx)).await?;
@@ -135,18 +142,27 @@ where
}
/// A supervisor task that handles, when [`connect`] exits.
async fn connect_supervisor<CB>(connect_handle: JoinHandle<Result<()>>, callbacks: CB)
where
CB: Callbacks,
{
async fn connect_supervisor(
connect_handle: JoinHandle<Result<()>>,
event_tx: tokio::sync::mpsc::Sender<Event>,
) {
let task = async {
connect_handle.await.context("connlib crashed")??;
Ok(())
};
match task.await {
Ok(()) => tracing::info!("connlib exited gracefully"),
Err(e) => callbacks.on_disconnect(e),
let error = match task.await {
Ok(()) => {
tracing::info!("connlib exited gracefully");
return;
}
Err(e) => e,
};
match event_tx.send(Event::Disconnected(error)).await {
Ok(()) => (),
Err(_) => tracing::debug!("Event stream closed before we could send disconnected event"),
}
}