refactor(rust): remove forced callback indirection (#9362)

As relict from very early designs of `connlib`, the `Callbacks` trait is
still present and defines how the host app receives events from a
running `Session`. Callbacks are not a great design pattern however
because they force the running code, i.e. `connlib`s event-loop to
execute unknown code. For example, if that code panics, all of `connlib`
is taken down. Additionally, not all consumers may want to receive
events via callbacks. The GUI and headless client for example already
have their own event-loop in which they process all kinds of things.
Having to deal with the `Callbacks` interface introduces an odd
indirection here.

To fix this, we instead return an `EventStream` when constructing a
`Session`. This essentially aligns the API of `Session` with that of a
channel. You receive two handles, one for sending in commands and one
for receiving events. A `Session` will automatically spawn itself onto
the given runtime so progress is made even if one does not poll on these
channel handles.

This greatly simplifies the code:

- We get to delete the `Callbacks` interface.
- We can delete the threaded callback adapter. This was only necessary
because we didn't want to block `connlib` with the handling of the
event. By using a channel for events, this is automatically guaranteed.
- The GUI and headless client can directly integrate the event handling
in their event-loop, without having to create an indirection with a
channel.
- It is now clear that only the Apple and Android FFI layers actually
use callbacks to communicate these events.
- We net-delete 100 LoC
This commit is contained in:
Thomas Eizinger
2025-06-02 19:28:04 +08:00
committed by GitHub
parent b7b296a102
commit 1914ea7076
10 changed files with 223 additions and 364 deletions

31
rust/Cargo.lock generated
View File

@@ -1135,7 +1135,6 @@ dependencies = [
"firezone-tunnel",
"ip_network",
"phoenix-channel",
"rayon",
"secrecy",
"serde",
"serde_json",
@@ -1377,16 +1376,6 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
@@ -5571,26 +5560,6 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539"
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "redox_syscall"
version = "0.5.12"

View File

@@ -129,7 +129,6 @@ quote = "1.0"
rand = "0.8.5"
rand_core = "0.6.4"
rangemap = "1.5.1"
rayon = "1.10.0"
reqwest = { version = "0.12.9", default-features = false }
resolv-conf = "0.7.3"
ringbuffer = "0.15.0"

View File

@@ -8,7 +8,7 @@
use crate::tun::Tun;
use anyhow::{Context as _, Result};
use backoff::ExponentialBackoffBuilder;
use client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
use client_shared::{DisconnectError, Session, V4RouteList, V6RouteList};
use connlib_model::ResourceView;
use dns_types::DomainName;
use firezone_logging::{err_with_src, sentry_layer};
@@ -165,7 +165,7 @@ fn init_logging(log_dir: &Path, log_filter: String) -> Result<()> {
Ok(())
}
impl Callbacks for CallbackHandler {
impl CallbackHandler {
fn on_set_interface_config(
&self,
tunnel_address_v4: Ipv4Addr,
@@ -382,14 +382,39 @@ fn connect(
},
tcp_socket_factory,
)?;
let session = Session::connect(
let (session, mut event_stream) = Session::connect(
Arc::new(protected_tcp_socket_factory(callbacks.clone())),
Arc::new(protected_udp_socket_factory(callbacks.clone())),
callbacks,
portal,
runtime.handle().clone(),
);
runtime.spawn(async move {
while let Some(event) = event_stream.next().await {
match event {
client_shared::Event::TunInterfaceUpdated {
ipv4,
ipv6,
dns,
search_domain,
ipv4_routes,
ipv6_routes,
} => callbacks.on_set_interface_config(
ipv4,
ipv6,
dns,
search_domain,
ipv4_routes,
ipv6_routes,
),
client_shared::Event::ResourcesUpdated(resource_views) => {
callbacks.on_update_resources(resource_views)
}
client_shared::Event::Disconnected(error) => callbacks.on_disconnect(error),
}
}
});
Ok(SessionWrapper {
inner: session,
runtime,

View File

@@ -8,7 +8,7 @@ mod tun;
use anyhow::Context;
use anyhow::Result;
use backoff::ExponentialBackoffBuilder;
use client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
use client_shared::{DisconnectError, Event, Session, V4RouteList, V6RouteList};
use connlib_model::ResourceView;
use dns_types::DomainName;
use firezone_logging::err_with_src;
@@ -127,15 +127,14 @@ pub struct WrappedSession {
unsafe impl Send for ffi::CallbackHandler {}
unsafe impl Sync for ffi::CallbackHandler {}
#[derive(Clone)]
pub struct CallbackHandler {
// Generated Swift opaque type wrappers have a `Drop` impl that decrements the
// refcount, but there's no way to generate a `Clone` impl that increments the
// recount. Instead, we just wrap it in an `Arc`.
inner: Arc<ffi::CallbackHandler>,
inner: ffi::CallbackHandler,
}
impl Callbacks for CallbackHandler {
impl CallbackHandler {
fn on_set_interface_config(
&self,
tunnel_address_v4: Ipv4Addr,
@@ -293,17 +292,48 @@ impl WrappedSession {
},
Arc::new(socket_factory::tcp),
)?;
let session = Session::connect(
let (session, mut event_stream) = Session::connect(
Arc::new(socket_factory::tcp),
Arc::new(socket_factory::udp),
CallbackHandler {
inner: Arc::new(callback_handler),
},
portal,
runtime.handle().clone(),
);
session.set_tun(Box::new(Tun::new()?));
runtime.spawn(async move {
let callback_handler = CallbackHandler {
inner: callback_handler,
};
while let Some(event) = event_stream.next().await {
match event {
Event::TunInterfaceUpdated {
ipv4,
ipv6,
dns,
search_domain,
ipv4_routes,
ipv6_routes,
} => {
callback_handler.on_set_interface_config(
ipv4,
ipv6,
dns,
search_domain,
ipv4_routes,
ipv6_routes,
);
}
Event::ResourcesUpdated(resource_views) => {
callback_handler.on_update_resources(resource_views);
}
Event::Disconnected(error) => {
callback_handler.on_disconnect(error);
}
}
}
});
Ok(Self {
inner: session,
runtime,

View File

@@ -14,7 +14,6 @@ firezone-logging = { workspace = true }
firezone-tunnel = { workspace = true }
ip_network = { workspace = true }
phoenix-channel = { workspace = true }
rayon = { workspace = true }
secrecy = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
snownet = { workspace = true }

View File

@@ -1,219 +0,0 @@
use connlib_model::ResourceView;
use dns_types::DomainName;
use ip_network::{Ipv4Network, Ipv6Network};
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
sync::Arc,
};
use tokio::sync::mpsc;
/// Traits that will be used by connlib to callback the client upper layers.
pub trait Callbacks: Clone + Send + Sync {
/// Called when the tunnel address is set.
///
/// The first time this is called, the Resources list is also ready,
/// the routes are also ready, and the Client can consider the tunnel
/// to be ready for incoming traffic.
fn on_set_interface_config(
&self,
_: Ipv4Addr,
_: Ipv6Addr,
_: Vec<IpAddr>,
_: Option<DomainName>,
_: Vec<Ipv4Network>,
_: Vec<Ipv6Network>,
) {
}
/// Called when the resource list changes.
///
/// This may not be called if a Client has no Resources, which can
/// happen to new accounts, or when removing and re-adding Resources,
/// or if all Resources for a user are disabled by policy.
fn on_update_resources(&self, _: Vec<ResourceView>) {}
/// Called when the tunnel is disconnected.
fn on_disconnect(&self, _: DisconnectError) {}
}
/// Unified error type to use across connlib.
#[derive(thiserror::Error, Debug)]
#[error("{0:#}")]
pub struct DisconnectError(anyhow::Error);
impl From<anyhow::Error> for DisconnectError {
fn from(e: anyhow::Error) -> Self {
Self(e)
}
}
impl DisconnectError {
pub fn is_authentication_error(&self) -> bool {
let Some(e) = self.0.downcast_ref::<phoenix_channel::Error>() else {
return false;
};
e.is_authentication_error()
}
}
#[derive(Debug, Clone)]
pub struct BackgroundCallbacks<C> {
inner: C,
threadpool: Arc<rayon::ThreadPool>,
}
impl<C> BackgroundCallbacks<C> {
pub fn new(callbacks: C) -> Self {
Self {
inner: callbacks,
threadpool: Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(1)
.thread_name(|_| "connlib callbacks".to_owned())
.build()
.expect("Unable to create thread-pool"),
),
}
}
}
impl<C> Callbacks for BackgroundCallbacks<C>
where
C: Callbacks + 'static,
{
fn on_set_interface_config(
&self,
ipv4_addr: Ipv4Addr,
ipv6_addr: Ipv6Addr,
dns_addresses: Vec<IpAddr>,
search_domain: Option<DomainName>,
route_list_4: Vec<Ipv4Network>,
route_list_6: Vec<Ipv6Network>,
) {
let callbacks = self.inner.clone();
self.threadpool.spawn(move || {
callbacks.on_set_interface_config(
ipv4_addr,
ipv6_addr,
dns_addresses,
search_domain,
route_list_4,
route_list_6,
);
});
}
fn on_update_resources(&self, resources: Vec<ResourceView>) {
let callbacks = self.inner.clone();
self.threadpool.spawn(move || {
callbacks.on_update_resources(resources);
});
}
fn on_disconnect(&self, error: DisconnectError) {
let callbacks = self.inner.clone();
self.threadpool.spawn(move || {
callbacks.on_disconnect(error);
});
}
}
/// Messages that connlib can produce and send to the headless Client, Tunnel service, or GUI process.
///
/// i.e. callbacks
// The names are CamelCase versions of the connlib callbacks.
#[expect(clippy::enum_variant_names)]
pub enum ConnlibMsg {
OnDisconnect {
error_msg: String,
is_authentication_error: bool,
},
/// Use this as `TunnelReady`, per `callbacks.rs`
OnSetInterfaceConfig {
ipv4: Ipv4Addr,
ipv6: Ipv6Addr,
dns: Vec<IpAddr>,
search_domain: Option<DomainName>,
ipv4_routes: Vec<Ipv4Network>,
ipv6_routes: Vec<Ipv6Network>,
},
OnUpdateResources(Vec<ResourceView>),
}
#[derive(Clone)]
pub struct ChannelCallbackHandler {
cb_tx: mpsc::Sender<ConnlibMsg>,
}
impl ChannelCallbackHandler {
pub fn new() -> (Self, mpsc::Receiver<ConnlibMsg>) {
let (cb_tx, cb_rx) = mpsc::channel(1_000);
(Self { cb_tx }, cb_rx)
}
}
impl Callbacks for ChannelCallbackHandler {
fn on_disconnect(&self, error: DisconnectError) {
self.cb_tx
.try_send(ConnlibMsg::OnDisconnect {
error_msg: error.to_string(),
is_authentication_error: error.is_authentication_error(),
})
.expect("should be able to send OnDisconnect");
}
fn on_set_interface_config(
&self,
ipv4: Ipv4Addr,
ipv6: Ipv6Addr,
dns: Vec<IpAddr>,
search_domain: Option<DomainName>,
ipv4_routes: Vec<Ipv4Network>,
ipv6_routes: Vec<Ipv6Network>,
) {
self.cb_tx
.try_send(ConnlibMsg::OnSetInterfaceConfig {
ipv4,
ipv6,
dns,
search_domain,
ipv4_routes,
ipv6_routes,
})
.expect("Should be able to send OnSetInterfaceConfig");
}
fn on_update_resources(&self, resources: Vec<ResourceView>) {
tracing::debug!(len = resources.len(), "New resource list");
self.cb_tx
.try_send(ConnlibMsg::OnUpdateResources(resources))
.expect("Should be able to send OnUpdateResources");
}
}
#[cfg(test)]
mod tests {
use phoenix_channel::StatusCode;
use super::*;
#[test]
fn printing_disconnect_error_contains_401() {
let disconnect_error = DisconnectError::from(anyhow::Error::new(
phoenix_channel::Error::Client(StatusCode::UNAUTHORIZED),
));
assert!(disconnect_error.to_string().contains("401 Unauthorized")); // Apple client relies on this.
}
// Make sure it's okay to store a bunch of these to mitigate #5880
#[test]
fn callback_msg_size() {
assert_eq!(std::mem::size_of::<ConnlibMsg>(), 120)
}
}

View File

@@ -1,13 +1,16 @@
use crate::{PHOENIX_TOPIC, callbacks::Callbacks};
use crate::PHOENIX_TOPIC;
use anyhow::{Context as _, Result};
use connlib_model::{PublicKey, ResourceId};
use connlib_model::{PublicKey, ResourceId, ResourceView};
use dns_types::DomainName;
use firezone_tunnel::messages::RelaysPresence;
use firezone_tunnel::messages::client::{
EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates,
GatewaysIceCandidates, IngressMessages, InitClient,
};
use firezone_tunnel::{ClientTunnel, IpConfig};
use ip_network::{Ipv4Network, Ipv6Network};
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::time::Instant;
use std::{
collections::BTreeSet,
@@ -15,14 +18,15 @@ use std::{
net::IpAddr,
task::{Context, Poll},
};
use tokio::sync::mpsc::error::TrySendError;
use tun::Tun;
pub struct Eventloop<C: Callbacks> {
pub struct Eventloop {
tunnel: ClientTunnel,
callbacks: C,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
event_tx: tokio::sync::mpsc::Sender<Event>,
}
/// Commands that can be sent to the [`Eventloop`].
@@ -33,31 +37,62 @@ pub enum Command {
SetDisabledResources(BTreeSet<ResourceId>),
}
impl<C: Callbacks> Eventloop<C> {
pub enum Event {
TunInterfaceUpdated {
ipv4: Ipv4Addr,
ipv6: Ipv6Addr,
dns: Vec<IpAddr>,
search_domain: Option<DomainName>,
ipv4_routes: Vec<Ipv4Network>,
ipv6_routes: Vec<Ipv6Network>,
},
ResourcesUpdated(Vec<ResourceView>),
Disconnected(DisconnectError),
}
/// Unified error type to use across connlib.
#[derive(thiserror::Error, Debug)]
#[error("{0:#}")]
pub struct DisconnectError(anyhow::Error);
impl From<anyhow::Error> for DisconnectError {
fn from(e: anyhow::Error) -> Self {
Self(e)
}
}
impl DisconnectError {
pub fn is_authentication_error(&self) -> bool {
let Some(e) = self.0.downcast_ref::<phoenix_channel::Error>() else {
return false;
};
e.is_authentication_error()
}
}
impl Eventloop {
pub(crate) fn new(
tunnel: ClientTunnel,
callbacks: C,
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
event_tx: tokio::sync::mpsc::Sender<Event>,
) -> Self {
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
Self {
tunnel,
portal,
rx,
callbacks,
cmd_rx,
event_tx,
}
}
}
impl<C> Eventloop<C>
where
C: Callbacks + 'static,
{
impl Eventloop {
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop {
match self.rx.poll_recv(cx) {
match self.cmd_rx.poll_recv(cx) {
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Command::SetDns(dns))) => {
self.tunnel.state_mut().update_system_resolvers(dns);
@@ -84,7 +119,22 @@ where
match self.tunnel.poll_next_event(cx) {
Poll::Ready(Ok(event)) => {
self.handle_tunnel_event(event);
let Some(e) = self.handle_tunnel_event(event) else {
continue;
};
match self.event_tx.try_send(e) {
Ok(()) => {}
Err(TrySendError::Closed(_)) => {
tracing::debug!("Event receiver dropped, exiting event loop");
return Poll::Ready(Ok(()));
}
Err(TrySendError::Full(_)) => {
tracing::warn!("App cannot keep up with connlib events, dropping");
}
};
continue;
}
Poll::Ready(Err(e)) => {
@@ -123,7 +173,7 @@ where
}
}
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) {
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) -> Option<Event> {
match event {
firezone_tunnel::ClientEvent::AddedIceCandidates {
conn_id: gateway,
@@ -138,6 +188,8 @@ where
candidates,
}),
);
None
}
firezone_tunnel::ClientEvent::RemovedIceCandidates {
conn_id: gateway,
@@ -152,6 +204,8 @@ where
candidates,
}),
);
None
}
firezone_tunnel::ClientEvent::ConnectionIntent {
connected_gateway_ids,
@@ -164,21 +218,21 @@ where
connected_gateway_ids,
},
);
None
}
firezone_tunnel::ClientEvent::ResourcesChanged { resources } => {
self.callbacks.on_update_resources(resources)
Some(Event::ResourcesUpdated(resources))
}
firezone_tunnel::ClientEvent::TunInterfaceUpdated(config) => {
let dns_servers = config.dns_by_sentinel.left_values().copied().collect();
self.callbacks.on_set_interface_config(
config.ip.v4,
config.ip.v6,
dns_servers,
config.search_domain,
Vec::from_iter(config.ipv4_routes),
Vec::from_iter(config.ipv6_routes),
);
Some(Event::TunInterfaceUpdated {
ipv4: config.ip.v4,
ipv6: config.ip.v6,
dns: config.dns_by_sentinel.left_values().copied().collect(),
search_domain: config.search_domain,
ipv4_routes: Vec::from_iter(config.ipv4_routes),
ipv6_routes: Vec::from_iter(config.ipv6_routes),
})
}
}
}

View File

@@ -1,25 +1,23 @@
//! Main connlib library for clients.
pub use crate::serde_routelist::{V4RouteList, V6RouteList};
use callbacks::BackgroundCallbacks;
pub use callbacks::{Callbacks, ChannelCallbackHandler, ConnlibMsg, DisconnectError};
pub use connlib_model::StaticSecret;
pub use eventloop::Eventloop;
pub use eventloop::{DisconnectError, Event};
pub use firezone_tunnel::messages::client::{IngressMessages, ResourceDescription};
use anyhow::{Context, Result};
use anyhow::{Context as _, Result};
use connlib_model::ResourceId;
use eventloop::Command;
use eventloop::{Command, Eventloop};
use firezone_tunnel::ClientTunnel;
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::collections::BTreeSet;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedReceiver;
use std::task::{Context, Poll};
use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
use tun::Tun;
mod callbacks;
mod eventloop;
mod serde_routelist;
@@ -31,34 +29,36 @@ const PHOENIX_TOPIC: &str = "client";
/// To stop the session, simply drop this struct.
#[derive(Clone)]
pub struct Session {
channel: tokio::sync::mpsc::UnboundedSender<Command>,
channel: UnboundedSender<Command>,
}
pub struct EventStream {
channel: Receiver<Event>,
}
impl Session {
/// Creates a new [`Session`].
///
/// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key.
pub fn connect<CB: Callbacks + 'static>(
pub fn connect(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
callbacks: CB,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
handle: tokio::runtime::Handle,
) -> Self {
let callbacks = BackgroundCallbacks::new(callbacks); // Run all callbacks on a background thread to avoid blocking the main connlib task.
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
) -> (Self, EventStream) {
let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
let (event_tx, event_rx) = tokio::sync::mpsc::channel(1000);
let connect_handle = handle.spawn(connect(
tcp_socket_factory,
udp_socket_factory,
callbacks.clone(),
portal,
rx,
cmd_rx,
event_tx.clone(),
));
handle.spawn(connect_supervisor(connect_handle, callbacks));
handle.spawn(connect_supervisor(connect_handle, event_tx));
Self { channel: tx }
(Self { channel: cmd_tx }, EventStream { channel: event_rx })
}
/// Reset a [`Session`].
@@ -107,6 +107,16 @@ impl Session {
}
}
impl EventStream {
pub fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Event>> {
self.channel.poll_recv(cx)
}
pub async fn next(&mut self) -> Option<Event> {
self.channel.recv().await
}
}
impl Drop for Session {
fn drop(&mut self) {
tracing::debug!("`Session` dropped")
@@ -116,18 +126,15 @@ impl Drop for Session {
/// Connects to the portal and starts a tunnel.
///
/// When this function exits, the tunnel failed unrecoverably and you need to call it again.
async fn connect<CB>(
async fn connect(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
callbacks: CB,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
rx: UnboundedReceiver<Command>,
) -> Result<()>
where
CB: Callbacks + 'static,
{
cmd_rx: UnboundedReceiver<Command>,
event_tx: Sender<Event>,
) -> Result<()> {
let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory);
let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx);
let mut eventloop = Eventloop::new(tunnel, portal, cmd_rx, event_tx);
std::future::poll_fn(|cx| eventloop.poll(cx)).await?;
@@ -135,18 +142,27 @@ where
}
/// A supervisor task that handles, when [`connect`] exits.
async fn connect_supervisor<CB>(connect_handle: JoinHandle<Result<()>>, callbacks: CB)
where
CB: Callbacks,
{
async fn connect_supervisor(
connect_handle: JoinHandle<Result<()>>,
event_tx: tokio::sync::mpsc::Sender<Event>,
) {
let task = async {
connect_handle.await.context("connlib crashed")??;
Ok(())
};
match task.await {
Ok(()) => tracing::info!("connlib exited gracefully"),
Err(e) => callbacks.on_disconnect(e),
let error = match task.await {
Ok(()) => {
tracing::info!("connlib exited gracefully");
return;
}
Err(e) => e,
};
match event_tx.send(Event::Disconnected(error)).await {
Ok(()) => (),
Err(_) => tracing::debug!("Event stream closed before we could send disconnected event"),
}
}

View File

@@ -5,7 +5,6 @@ use crate::{
use anyhow::{Context as _, Result, bail};
use atomicwrites::{AtomicFile, OverwriteBehavior};
use backoff::ExponentialBackoffBuilder;
use client_shared::ConnlibMsg;
use connlib_model::{ResourceId, ResourceView};
use firezone_bin_shared::{
DnsControlMethod, DnsController, TunDeviceManager,
@@ -31,7 +30,7 @@ use std::{
sync::Arc,
time::Duration,
};
use tokio::{sync::mpsc, time::Instant};
use tokio::time::Instant;
use url::Url;
#[cfg(target_os = "linux")]
@@ -179,12 +178,12 @@ struct Handler<'a> {
}
struct Session {
cb_rx: mpsc::Receiver<ConnlibMsg>,
event_stream: client_shared::EventStream,
connlib: client_shared::Session,
}
enum Event {
Callback(ConnlibMsg),
Connlib(client_shared::Event),
CallbackChannelClosed,
Ipc(ClientMsg),
IpcDisconnected,
@@ -247,8 +246,8 @@ impl<'a> Handler<'a> {
async fn run(&mut self, signals: &mut signals::Terminate) -> HandlerOk {
let ret = loop {
match poll_fn(|cx| self.next_event(cx, signals)).await {
Event::Callback(x) => {
if let Err(error) = self.handle_connlib_cb(x).await {
Event::Connlib(x) => {
if let Err(error) = self.handle_connlib_event(x).await {
tracing::error!("Error while handling connlib callback: {error:#}");
continue;
}
@@ -309,10 +308,9 @@ impl<'a> Handler<'a> {
});
}
if let Some(session) = self.session.as_mut() {
// `tokio::sync::mpsc::Receiver::recv` is cancel-safe.
if let Poll::Ready(option) = session.cb_rx.poll_recv(cx) {
if let Poll::Ready(option) = session.event_stream.poll_next(cx) {
return Poll::Ready(match option {
Some(x) => Event::Callback(x),
Some(x) => Event::Connlib(x),
None => Event::CallbackChannelClosed,
});
}
@@ -320,21 +318,18 @@ impl<'a> Handler<'a> {
Poll::Pending
}
async fn handle_connlib_cb(&mut self, msg: ConnlibMsg) -> Result<()> {
async fn handle_connlib_event(&mut self, msg: client_shared::Event) -> Result<()> {
match msg {
ConnlibMsg::OnDisconnect {
error_msg,
is_authentication_error,
} => {
client_shared::Event::Disconnected(error) => {
let _ = self.session.take();
self.dns_controller.deactivate()?;
self.send_ipc(ServerMsg::OnDisconnect {
error_msg,
is_authentication_error,
error_msg: error.to_string(),
is_authentication_error: error.is_authentication_error(),
})
.await?
}
ConnlibMsg::OnSetInterfaceConfig {
client_shared::Event::TunInterfaceUpdated {
ipv4,
ipv6,
dns,
@@ -352,7 +347,7 @@ impl<'a> Handler<'a> {
self.send_ipc(ServerMsg::TunnelReady).await?;
}
ConnlibMsg::OnUpdateResources(resources) => {
client_shared::Event::ResourcesUpdated(resources) => {
// On every resources update, flush DNS to mitigate <https://github.com/firezone/firezone/issues/5052>
self.dns_controller.flush()?;
self.send_ipc(ServerMsg::OnUpdateResources(resources))
@@ -472,7 +467,6 @@ impl<'a> Handler<'a> {
.context("Failed to create `LoginUrl`")?;
self.last_connlib_start_instant = Some(Instant::now());
let (callbacks, cb_rx) = client_shared::ChannelCallbackHandler::new();
// Synchronous DNS resolution here
let portal = PhoenixChannel::disconnected(
@@ -490,10 +484,9 @@ impl<'a> Handler<'a> {
// Read the resolvers before starting connlib, in case connlib's startup interferes.
let dns = self.dns_controller.system_resolvers();
let connlib = client_shared::Session::connect(
let (connlib, event_stream) = client_shared::Session::connect(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
callbacks,
portal,
tokio::runtime::Handle::current(),
);
@@ -510,7 +503,10 @@ impl<'a> Handler<'a> {
};
connlib.set_tun(tun);
let session = Session { cb_rx, connlib };
let session = Session {
event_stream,
connlib,
};
self.session = Some(session);
Ok(())

View File

@@ -5,7 +5,6 @@
use anyhow::{Context as _, Result, anyhow};
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use client_shared::{ChannelCallbackHandler, ConnlibMsg, Session};
use firezone_bin_shared::{
DnsControlMethod, DnsController, TOKEN_ENV_KEY, TunDeviceManager, device_id, device_info,
new_dns_notifier, new_network_notifier,
@@ -15,7 +14,6 @@ use firezone_bin_shared::{
use firezone_logging::telemetry_span;
use firezone_telemetry::Telemetry;
use firezone_telemetry::otel;
use futures::StreamExt as _;
use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider};
use phoenix_channel::PhoenixChannel;
use phoenix_channel::get_user_agent;
@@ -26,7 +24,6 @@ use std::{
sync::Arc,
};
use tokio::time::Instant;
use tokio_stream::wrappers::ReceiverStream;
#[cfg(target_os = "linux")]
#[path = "linux.rs"]
@@ -221,8 +218,6 @@ fn main() -> Result<()> {
return Ok(());
}
let (callbacks, cb_rx) = ChannelCallbackHandler::new();
// The name matches that in `ipc_service.rs`
let mut last_connlib_start_instant = Some(Instant::now());
@@ -261,10 +256,9 @@ fn main() -> Result<()> {
},
Arc::new(tcp_socket_factory),
)?;
let session = Session::connect(
let (session, mut event_stream) = client_shared::Session::connect(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
callbacks,
portal,
rt.handle().clone(),
);
@@ -273,7 +267,6 @@ fn main() -> Result<()> {
let mut hangup = signals::Hangup::new()?;
let mut tun_device = TunDeviceManager::new(ip_packet::MAX_IP_SIZE, 1)?;
let mut cb_rx = ReceiverStream::new(cb_rx).fuse();
let tokio_handle = tokio::runtime::Handle::current();
@@ -294,7 +287,7 @@ fn main() -> Result<()> {
drop(connect_span);
let result = loop {
let cb = tokio::select! {
let event = tokio::select! {
() = terminate.recv() => {
tracing::info!("Caught SIGINT / SIGTERM / Ctrl+C");
break Ok(());
@@ -318,20 +311,17 @@ fn main() -> Result<()> {
session.reset();
continue;
},
cb = cb_rx.next() => cb.context("cb_rx unexpectedly ran empty")?,
event = event_stream.next() => event.context("event stream unexpectedly ran empty")?,
};
match cb {
match event {
// TODO: Headless Client shouldn't be using messages labelled `Ipc`
ConnlibMsg::OnDisconnect {
error_msg,
is_authentication_error: _,
} => break Err(anyhow!(error_msg).context("Firezone disconnected")),
ConnlibMsg::OnUpdateResources(_) => {
client_shared::Event::Disconnected(error) => break Err(anyhow!(error).context("Firezone disconnected")),
client_shared::Event::ResourcesUpdated(_) => {
// On every Resources update, flush DNS to mitigate <https://github.com/firezone/firezone/issues/5052>
dns_controller.flush()?;
}
ConnlibMsg::OnSetInterfaceConfig {
client_shared::Event::TunInterfaceUpdated {
ipv4,
ipv6,
dns,