mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
refactor(rust): remove forced callback indirection (#9362)
As relict from very early designs of `connlib`, the `Callbacks` trait is still present and defines how the host app receives events from a running `Session`. Callbacks are not a great design pattern however because they force the running code, i.e. `connlib`s event-loop to execute unknown code. For example, if that code panics, all of `connlib` is taken down. Additionally, not all consumers may want to receive events via callbacks. The GUI and headless client for example already have their own event-loop in which they process all kinds of things. Having to deal with the `Callbacks` interface introduces an odd indirection here. To fix this, we instead return an `EventStream` when constructing a `Session`. This essentially aligns the API of `Session` with that of a channel. You receive two handles, one for sending in commands and one for receiving events. A `Session` will automatically spawn itself onto the given runtime so progress is made even if one does not poll on these channel handles. This greatly simplifies the code: - We get to delete the `Callbacks` interface. - We can delete the threaded callback adapter. This was only necessary because we didn't want to block `connlib` with the handling of the event. By using a channel for events, this is automatically guaranteed. - The GUI and headless client can directly integrate the event handling in their event-loop, without having to create an indirection with a channel. - It is now clear that only the Apple and Android FFI layers actually use callbacks to communicate these events. - We net-delete 100 LoC
This commit is contained in:
31
rust/Cargo.lock
generated
31
rust/Cargo.lock
generated
@@ -1135,7 +1135,6 @@ dependencies = [
|
||||
"firezone-tunnel",
|
||||
"ip_network",
|
||||
"phoenix-channel",
|
||||
"rayon",
|
||||
"secrecy",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -1377,16 +1376,6 @@ dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
@@ -5571,26 +5560,6 @@ version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539"
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.12"
|
||||
|
||||
@@ -129,7 +129,6 @@ quote = "1.0"
|
||||
rand = "0.8.5"
|
||||
rand_core = "0.6.4"
|
||||
rangemap = "1.5.1"
|
||||
rayon = "1.10.0"
|
||||
reqwest = { version = "0.12.9", default-features = false }
|
||||
resolv-conf = "0.7.3"
|
||||
ringbuffer = "0.15.0"
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
use crate::tun::Tun;
|
||||
use anyhow::{Context as _, Result};
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
|
||||
use client_shared::{DisconnectError, Session, V4RouteList, V6RouteList};
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use firezone_logging::{err_with_src, sentry_layer};
|
||||
@@ -165,7 +165,7 @@ fn init_logging(log_dir: &Path, log_filter: String) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Callbacks for CallbackHandler {
|
||||
impl CallbackHandler {
|
||||
fn on_set_interface_config(
|
||||
&self,
|
||||
tunnel_address_v4: Ipv4Addr,
|
||||
@@ -382,14 +382,39 @@ fn connect(
|
||||
},
|
||||
tcp_socket_factory,
|
||||
)?;
|
||||
let session = Session::connect(
|
||||
let (session, mut event_stream) = Session::connect(
|
||||
Arc::new(protected_tcp_socket_factory(callbacks.clone())),
|
||||
Arc::new(protected_udp_socket_factory(callbacks.clone())),
|
||||
callbacks,
|
||||
portal,
|
||||
runtime.handle().clone(),
|
||||
);
|
||||
|
||||
runtime.spawn(async move {
|
||||
while let Some(event) = event_stream.next().await {
|
||||
match event {
|
||||
client_shared::Event::TunInterfaceUpdated {
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
search_domain,
|
||||
ipv4_routes,
|
||||
ipv6_routes,
|
||||
} => callbacks.on_set_interface_config(
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
search_domain,
|
||||
ipv4_routes,
|
||||
ipv6_routes,
|
||||
),
|
||||
client_shared::Event::ResourcesUpdated(resource_views) => {
|
||||
callbacks.on_update_resources(resource_views)
|
||||
}
|
||||
client_shared::Event::Disconnected(error) => callbacks.on_disconnect(error),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(SessionWrapper {
|
||||
inner: session,
|
||||
runtime,
|
||||
|
||||
@@ -8,7 +8,7 @@ mod tun;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
|
||||
use client_shared::{DisconnectError, Event, Session, V4RouteList, V6RouteList};
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use firezone_logging::err_with_src;
|
||||
@@ -127,15 +127,14 @@ pub struct WrappedSession {
|
||||
unsafe impl Send for ffi::CallbackHandler {}
|
||||
unsafe impl Sync for ffi::CallbackHandler {}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CallbackHandler {
|
||||
// Generated Swift opaque type wrappers have a `Drop` impl that decrements the
|
||||
// refcount, but there's no way to generate a `Clone` impl that increments the
|
||||
// recount. Instead, we just wrap it in an `Arc`.
|
||||
inner: Arc<ffi::CallbackHandler>,
|
||||
inner: ffi::CallbackHandler,
|
||||
}
|
||||
|
||||
impl Callbacks for CallbackHandler {
|
||||
impl CallbackHandler {
|
||||
fn on_set_interface_config(
|
||||
&self,
|
||||
tunnel_address_v4: Ipv4Addr,
|
||||
@@ -293,17 +292,48 @@ impl WrappedSession {
|
||||
},
|
||||
Arc::new(socket_factory::tcp),
|
||||
)?;
|
||||
let session = Session::connect(
|
||||
let (session, mut event_stream) = Session::connect(
|
||||
Arc::new(socket_factory::tcp),
|
||||
Arc::new(socket_factory::udp),
|
||||
CallbackHandler {
|
||||
inner: Arc::new(callback_handler),
|
||||
},
|
||||
portal,
|
||||
runtime.handle().clone(),
|
||||
);
|
||||
session.set_tun(Box::new(Tun::new()?));
|
||||
|
||||
runtime.spawn(async move {
|
||||
let callback_handler = CallbackHandler {
|
||||
inner: callback_handler,
|
||||
};
|
||||
|
||||
while let Some(event) = event_stream.next().await {
|
||||
match event {
|
||||
Event::TunInterfaceUpdated {
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
search_domain,
|
||||
ipv4_routes,
|
||||
ipv6_routes,
|
||||
} => {
|
||||
callback_handler.on_set_interface_config(
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
search_domain,
|
||||
ipv4_routes,
|
||||
ipv6_routes,
|
||||
);
|
||||
}
|
||||
Event::ResourcesUpdated(resource_views) => {
|
||||
callback_handler.on_update_resources(resource_views);
|
||||
}
|
||||
Event::Disconnected(error) => {
|
||||
callback_handler.on_disconnect(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
inner: session,
|
||||
runtime,
|
||||
|
||||
@@ -14,7 +14,6 @@ firezone-logging = { workspace = true }
|
||||
firezone-tunnel = { workspace = true }
|
||||
ip_network = { workspace = true }
|
||||
phoenix-channel = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
secrecy = { workspace = true }
|
||||
serde = { workspace = true, features = ["std", "derive"] }
|
||||
snownet = { workspace = true }
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Traits that will be used by connlib to callback the client upper layers.
|
||||
pub trait Callbacks: Clone + Send + Sync {
|
||||
/// Called when the tunnel address is set.
|
||||
///
|
||||
/// The first time this is called, the Resources list is also ready,
|
||||
/// the routes are also ready, and the Client can consider the tunnel
|
||||
/// to be ready for incoming traffic.
|
||||
fn on_set_interface_config(
|
||||
&self,
|
||||
_: Ipv4Addr,
|
||||
_: Ipv6Addr,
|
||||
_: Vec<IpAddr>,
|
||||
_: Option<DomainName>,
|
||||
_: Vec<Ipv4Network>,
|
||||
_: Vec<Ipv6Network>,
|
||||
) {
|
||||
}
|
||||
|
||||
/// Called when the resource list changes.
|
||||
///
|
||||
/// This may not be called if a Client has no Resources, which can
|
||||
/// happen to new accounts, or when removing and re-adding Resources,
|
||||
/// or if all Resources for a user are disabled by policy.
|
||||
fn on_update_resources(&self, _: Vec<ResourceView>) {}
|
||||
|
||||
/// Called when the tunnel is disconnected.
|
||||
fn on_disconnect(&self, _: DisconnectError) {}
|
||||
}
|
||||
|
||||
/// Unified error type to use across connlib.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[error("{0:#}")]
|
||||
pub struct DisconnectError(anyhow::Error);
|
||||
|
||||
impl From<anyhow::Error> for DisconnectError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl DisconnectError {
|
||||
pub fn is_authentication_error(&self) -> bool {
|
||||
let Some(e) = self.0.downcast_ref::<phoenix_channel::Error>() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
e.is_authentication_error()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BackgroundCallbacks<C> {
|
||||
inner: C,
|
||||
threadpool: Arc<rayon::ThreadPool>,
|
||||
}
|
||||
|
||||
impl<C> BackgroundCallbacks<C> {
|
||||
pub fn new(callbacks: C) -> Self {
|
||||
Self {
|
||||
inner: callbacks,
|
||||
threadpool: Arc::new(
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(1)
|
||||
.thread_name(|_| "connlib callbacks".to_owned())
|
||||
.build()
|
||||
.expect("Unable to create thread-pool"),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Callbacks for BackgroundCallbacks<C>
|
||||
where
|
||||
C: Callbacks + 'static,
|
||||
{
|
||||
fn on_set_interface_config(
|
||||
&self,
|
||||
ipv4_addr: Ipv4Addr,
|
||||
ipv6_addr: Ipv6Addr,
|
||||
dns_addresses: Vec<IpAddr>,
|
||||
search_domain: Option<DomainName>,
|
||||
route_list_4: Vec<Ipv4Network>,
|
||||
route_list_6: Vec<Ipv6Network>,
|
||||
) {
|
||||
let callbacks = self.inner.clone();
|
||||
|
||||
self.threadpool.spawn(move || {
|
||||
callbacks.on_set_interface_config(
|
||||
ipv4_addr,
|
||||
ipv6_addr,
|
||||
dns_addresses,
|
||||
search_domain,
|
||||
route_list_4,
|
||||
route_list_6,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
fn on_update_resources(&self, resources: Vec<ResourceView>) {
|
||||
let callbacks = self.inner.clone();
|
||||
|
||||
self.threadpool.spawn(move || {
|
||||
callbacks.on_update_resources(resources);
|
||||
});
|
||||
}
|
||||
|
||||
fn on_disconnect(&self, error: DisconnectError) {
|
||||
let callbacks = self.inner.clone();
|
||||
|
||||
self.threadpool.spawn(move || {
|
||||
callbacks.on_disconnect(error);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Messages that connlib can produce and send to the headless Client, Tunnel service, or GUI process.
|
||||
///
|
||||
/// i.e. callbacks
|
||||
// The names are CamelCase versions of the connlib callbacks.
|
||||
#[expect(clippy::enum_variant_names)]
|
||||
pub enum ConnlibMsg {
|
||||
OnDisconnect {
|
||||
error_msg: String,
|
||||
is_authentication_error: bool,
|
||||
},
|
||||
/// Use this as `TunnelReady`, per `callbacks.rs`
|
||||
OnSetInterfaceConfig {
|
||||
ipv4: Ipv4Addr,
|
||||
ipv6: Ipv6Addr,
|
||||
dns: Vec<IpAddr>,
|
||||
search_domain: Option<DomainName>,
|
||||
ipv4_routes: Vec<Ipv4Network>,
|
||||
ipv6_routes: Vec<Ipv6Network>,
|
||||
},
|
||||
OnUpdateResources(Vec<ResourceView>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelCallbackHandler {
|
||||
cb_tx: mpsc::Sender<ConnlibMsg>,
|
||||
}
|
||||
|
||||
impl ChannelCallbackHandler {
|
||||
pub fn new() -> (Self, mpsc::Receiver<ConnlibMsg>) {
|
||||
let (cb_tx, cb_rx) = mpsc::channel(1_000);
|
||||
|
||||
(Self { cb_tx }, cb_rx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Callbacks for ChannelCallbackHandler {
|
||||
fn on_disconnect(&self, error: DisconnectError) {
|
||||
self.cb_tx
|
||||
.try_send(ConnlibMsg::OnDisconnect {
|
||||
error_msg: error.to_string(),
|
||||
is_authentication_error: error.is_authentication_error(),
|
||||
})
|
||||
.expect("should be able to send OnDisconnect");
|
||||
}
|
||||
|
||||
fn on_set_interface_config(
|
||||
&self,
|
||||
ipv4: Ipv4Addr,
|
||||
ipv6: Ipv6Addr,
|
||||
dns: Vec<IpAddr>,
|
||||
search_domain: Option<DomainName>,
|
||||
ipv4_routes: Vec<Ipv4Network>,
|
||||
ipv6_routes: Vec<Ipv6Network>,
|
||||
) {
|
||||
self.cb_tx
|
||||
.try_send(ConnlibMsg::OnSetInterfaceConfig {
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
search_domain,
|
||||
ipv4_routes,
|
||||
ipv6_routes,
|
||||
})
|
||||
.expect("Should be able to send OnSetInterfaceConfig");
|
||||
}
|
||||
|
||||
fn on_update_resources(&self, resources: Vec<ResourceView>) {
|
||||
tracing::debug!(len = resources.len(), "New resource list");
|
||||
self.cb_tx
|
||||
.try_send(ConnlibMsg::OnUpdateResources(resources))
|
||||
.expect("Should be able to send OnUpdateResources");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use phoenix_channel::StatusCode;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn printing_disconnect_error_contains_401() {
|
||||
let disconnect_error = DisconnectError::from(anyhow::Error::new(
|
||||
phoenix_channel::Error::Client(StatusCode::UNAUTHORIZED),
|
||||
));
|
||||
|
||||
assert!(disconnect_error.to_string().contains("401 Unauthorized")); // Apple client relies on this.
|
||||
}
|
||||
|
||||
// Make sure it's okay to store a bunch of these to mitigate #5880
|
||||
#[test]
|
||||
fn callback_msg_size() {
|
||||
assert_eq!(std::mem::size_of::<ConnlibMsg>(), 120)
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,16 @@
|
||||
use crate::{PHOENIX_TOPIC, callbacks::Callbacks};
|
||||
use crate::PHOENIX_TOPIC;
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_model::{PublicKey, ResourceId};
|
||||
use connlib_model::{PublicKey, ResourceId, ResourceView};
|
||||
use dns_types::DomainName;
|
||||
use firezone_tunnel::messages::RelaysPresence;
|
||||
use firezone_tunnel::messages::client::{
|
||||
EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates,
|
||||
GatewaysIceCandidates, IngressMessages, InitClient,
|
||||
};
|
||||
use firezone_tunnel::{ClientTunnel, IpConfig};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::time::Instant;
|
||||
use std::{
|
||||
collections::BTreeSet,
|
||||
@@ -15,14 +18,15 @@ use std::{
|
||||
net::IpAddr,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tun::Tun;
|
||||
|
||||
pub struct Eventloop<C: Callbacks> {
|
||||
pub struct Eventloop {
|
||||
tunnel: ClientTunnel,
|
||||
callbacks: C,
|
||||
|
||||
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
|
||||
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
|
||||
event_tx: tokio::sync::mpsc::Sender<Event>,
|
||||
}
|
||||
|
||||
/// Commands that can be sent to the [`Eventloop`].
|
||||
@@ -33,31 +37,62 @@ pub enum Command {
|
||||
SetDisabledResources(BTreeSet<ResourceId>),
|
||||
}
|
||||
|
||||
impl<C: Callbacks> Eventloop<C> {
|
||||
pub enum Event {
|
||||
TunInterfaceUpdated {
|
||||
ipv4: Ipv4Addr,
|
||||
ipv6: Ipv6Addr,
|
||||
dns: Vec<IpAddr>,
|
||||
search_domain: Option<DomainName>,
|
||||
ipv4_routes: Vec<Ipv4Network>,
|
||||
ipv6_routes: Vec<Ipv6Network>,
|
||||
},
|
||||
ResourcesUpdated(Vec<ResourceView>),
|
||||
Disconnected(DisconnectError),
|
||||
}
|
||||
|
||||
/// Unified error type to use across connlib.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[error("{0:#}")]
|
||||
pub struct DisconnectError(anyhow::Error);
|
||||
|
||||
impl From<anyhow::Error> for DisconnectError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl DisconnectError {
|
||||
pub fn is_authentication_error(&self) -> bool {
|
||||
let Some(e) = self.0.downcast_ref::<phoenix_channel::Error>() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
e.is_authentication_error()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eventloop {
|
||||
pub(crate) fn new(
|
||||
tunnel: ClientTunnel,
|
||||
callbacks: C,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
|
||||
cmd_rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
|
||||
event_tx: tokio::sync::mpsc::Sender<Event>,
|
||||
) -> Self {
|
||||
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
|
||||
|
||||
Self {
|
||||
tunnel,
|
||||
portal,
|
||||
rx,
|
||||
callbacks,
|
||||
cmd_rx,
|
||||
event_tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Eventloop<C>
|
||||
where
|
||||
C: Callbacks + 'static,
|
||||
{
|
||||
impl Eventloop {
|
||||
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
loop {
|
||||
match self.rx.poll_recv(cx) {
|
||||
match self.cmd_rx.poll_recv(cx) {
|
||||
Poll::Ready(None) => return Poll::Ready(Ok(())),
|
||||
Poll::Ready(Some(Command::SetDns(dns))) => {
|
||||
self.tunnel.state_mut().update_system_resolvers(dns);
|
||||
@@ -84,7 +119,22 @@ where
|
||||
|
||||
match self.tunnel.poll_next_event(cx) {
|
||||
Poll::Ready(Ok(event)) => {
|
||||
self.handle_tunnel_event(event);
|
||||
let Some(e) = self.handle_tunnel_event(event) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
match self.event_tx.try_send(e) {
|
||||
Ok(()) => {}
|
||||
Err(TrySendError::Closed(_)) => {
|
||||
tracing::debug!("Event receiver dropped, exiting event loop");
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
Err(TrySendError::Full(_)) => {
|
||||
tracing::warn!("App cannot keep up with connlib events, dropping");
|
||||
}
|
||||
};
|
||||
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
@@ -123,7 +173,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) {
|
||||
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) -> Option<Event> {
|
||||
match event {
|
||||
firezone_tunnel::ClientEvent::AddedIceCandidates {
|
||||
conn_id: gateway,
|
||||
@@ -138,6 +188,8 @@ where
|
||||
candidates,
|
||||
}),
|
||||
);
|
||||
|
||||
None
|
||||
}
|
||||
firezone_tunnel::ClientEvent::RemovedIceCandidates {
|
||||
conn_id: gateway,
|
||||
@@ -152,6 +204,8 @@ where
|
||||
candidates,
|
||||
}),
|
||||
);
|
||||
|
||||
None
|
||||
}
|
||||
firezone_tunnel::ClientEvent::ConnectionIntent {
|
||||
connected_gateway_ids,
|
||||
@@ -164,21 +218,21 @@ where
|
||||
connected_gateway_ids,
|
||||
},
|
||||
);
|
||||
|
||||
None
|
||||
}
|
||||
firezone_tunnel::ClientEvent::ResourcesChanged { resources } => {
|
||||
self.callbacks.on_update_resources(resources)
|
||||
Some(Event::ResourcesUpdated(resources))
|
||||
}
|
||||
firezone_tunnel::ClientEvent::TunInterfaceUpdated(config) => {
|
||||
let dns_servers = config.dns_by_sentinel.left_values().copied().collect();
|
||||
|
||||
self.callbacks.on_set_interface_config(
|
||||
config.ip.v4,
|
||||
config.ip.v6,
|
||||
dns_servers,
|
||||
config.search_domain,
|
||||
Vec::from_iter(config.ipv4_routes),
|
||||
Vec::from_iter(config.ipv6_routes),
|
||||
);
|
||||
Some(Event::TunInterfaceUpdated {
|
||||
ipv4: config.ip.v4,
|
||||
ipv6: config.ip.v6,
|
||||
dns: config.dns_by_sentinel.left_values().copied().collect(),
|
||||
search_domain: config.search_domain,
|
||||
ipv4_routes: Vec::from_iter(config.ipv4_routes),
|
||||
ipv6_routes: Vec::from_iter(config.ipv6_routes),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
//! Main connlib library for clients.
|
||||
pub use crate::serde_routelist::{V4RouteList, V6RouteList};
|
||||
use callbacks::BackgroundCallbacks;
|
||||
pub use callbacks::{Callbacks, ChannelCallbackHandler, ConnlibMsg, DisconnectError};
|
||||
pub use connlib_model::StaticSecret;
|
||||
pub use eventloop::Eventloop;
|
||||
pub use eventloop::{DisconnectError, Event};
|
||||
pub use firezone_tunnel::messages::client::{IngressMessages, ResourceDescription};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_model::ResourceId;
|
||||
use eventloop::Command;
|
||||
use eventloop::{Command, Eventloop};
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
|
||||
use tokio::task::JoinHandle;
|
||||
use tun::Tun;
|
||||
|
||||
mod callbacks;
|
||||
mod eventloop;
|
||||
mod serde_routelist;
|
||||
|
||||
@@ -31,34 +29,36 @@ const PHOENIX_TOPIC: &str = "client";
|
||||
/// To stop the session, simply drop this struct.
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
channel: tokio::sync::mpsc::UnboundedSender<Command>,
|
||||
channel: UnboundedSender<Command>,
|
||||
}
|
||||
|
||||
pub struct EventStream {
|
||||
channel: Receiver<Event>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Creates a new [`Session`].
|
||||
///
|
||||
/// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key.
|
||||
pub fn connect<CB: Callbacks + 'static>(
|
||||
pub fn connect(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
callbacks: CB,
|
||||
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
handle: tokio::runtime::Handle,
|
||||
) -> Self {
|
||||
let callbacks = BackgroundCallbacks::new(callbacks); // Run all callbacks on a background thread to avoid blocking the main connlib task.
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
) -> (Self, EventStream) {
|
||||
let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (event_tx, event_rx) = tokio::sync::mpsc::channel(1000);
|
||||
|
||||
let connect_handle = handle.spawn(connect(
|
||||
tcp_socket_factory,
|
||||
udp_socket_factory,
|
||||
callbacks.clone(),
|
||||
portal,
|
||||
rx,
|
||||
cmd_rx,
|
||||
event_tx.clone(),
|
||||
));
|
||||
handle.spawn(connect_supervisor(connect_handle, callbacks));
|
||||
handle.spawn(connect_supervisor(connect_handle, event_tx));
|
||||
|
||||
Self { channel: tx }
|
||||
(Self { channel: cmd_tx }, EventStream { channel: event_rx })
|
||||
}
|
||||
|
||||
/// Reset a [`Session`].
|
||||
@@ -107,6 +107,16 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
impl EventStream {
|
||||
pub fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Event>> {
|
||||
self.channel.poll_recv(cx)
|
||||
}
|
||||
|
||||
pub async fn next(&mut self) -> Option<Event> {
|
||||
self.channel.recv().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Session {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!("`Session` dropped")
|
||||
@@ -116,18 +126,15 @@ impl Drop for Session {
|
||||
/// Connects to the portal and starts a tunnel.
|
||||
///
|
||||
/// When this function exits, the tunnel failed unrecoverably and you need to call it again.
|
||||
async fn connect<CB>(
|
||||
async fn connect(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
callbacks: CB,
|
||||
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
rx: UnboundedReceiver<Command>,
|
||||
) -> Result<()>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
cmd_rx: UnboundedReceiver<Command>,
|
||||
event_tx: Sender<Event>,
|
||||
) -> Result<()> {
|
||||
let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory);
|
||||
let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx);
|
||||
let mut eventloop = Eventloop::new(tunnel, portal, cmd_rx, event_tx);
|
||||
|
||||
std::future::poll_fn(|cx| eventloop.poll(cx)).await?;
|
||||
|
||||
@@ -135,18 +142,27 @@ where
|
||||
}
|
||||
|
||||
/// A supervisor task that handles, when [`connect`] exits.
|
||||
async fn connect_supervisor<CB>(connect_handle: JoinHandle<Result<()>>, callbacks: CB)
|
||||
where
|
||||
CB: Callbacks,
|
||||
{
|
||||
async fn connect_supervisor(
|
||||
connect_handle: JoinHandle<Result<()>>,
|
||||
event_tx: tokio::sync::mpsc::Sender<Event>,
|
||||
) {
|
||||
let task = async {
|
||||
connect_handle.await.context("connlib crashed")??;
|
||||
|
||||
Ok(())
|
||||
};
|
||||
|
||||
match task.await {
|
||||
Ok(()) => tracing::info!("connlib exited gracefully"),
|
||||
Err(e) => callbacks.on_disconnect(e),
|
||||
let error = match task.await {
|
||||
Ok(()) => {
|
||||
tracing::info!("connlib exited gracefully");
|
||||
|
||||
return;
|
||||
}
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
match event_tx.send(Event::Disconnected(error)).await {
|
||||
Ok(()) => (),
|
||||
Err(_) => tracing::debug!("Event stream closed before we could send disconnected event"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ use crate::{
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use atomicwrites::{AtomicFile, OverwriteBehavior};
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use client_shared::ConnlibMsg;
|
||||
use connlib_model::{ResourceId, ResourceView};
|
||||
use firezone_bin_shared::{
|
||||
DnsControlMethod, DnsController, TunDeviceManager,
|
||||
@@ -31,7 +30,7 @@ use std::{
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{sync::mpsc, time::Instant};
|
||||
use tokio::time::Instant;
|
||||
use url::Url;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
@@ -179,12 +178,12 @@ struct Handler<'a> {
|
||||
}
|
||||
|
||||
struct Session {
|
||||
cb_rx: mpsc::Receiver<ConnlibMsg>,
|
||||
event_stream: client_shared::EventStream,
|
||||
connlib: client_shared::Session,
|
||||
}
|
||||
|
||||
enum Event {
|
||||
Callback(ConnlibMsg),
|
||||
Connlib(client_shared::Event),
|
||||
CallbackChannelClosed,
|
||||
Ipc(ClientMsg),
|
||||
IpcDisconnected,
|
||||
@@ -247,8 +246,8 @@ impl<'a> Handler<'a> {
|
||||
async fn run(&mut self, signals: &mut signals::Terminate) -> HandlerOk {
|
||||
let ret = loop {
|
||||
match poll_fn(|cx| self.next_event(cx, signals)).await {
|
||||
Event::Callback(x) => {
|
||||
if let Err(error) = self.handle_connlib_cb(x).await {
|
||||
Event::Connlib(x) => {
|
||||
if let Err(error) = self.handle_connlib_event(x).await {
|
||||
tracing::error!("Error while handling connlib callback: {error:#}");
|
||||
continue;
|
||||
}
|
||||
@@ -309,10 +308,9 @@ impl<'a> Handler<'a> {
|
||||
});
|
||||
}
|
||||
if let Some(session) = self.session.as_mut() {
|
||||
// `tokio::sync::mpsc::Receiver::recv` is cancel-safe.
|
||||
if let Poll::Ready(option) = session.cb_rx.poll_recv(cx) {
|
||||
if let Poll::Ready(option) = session.event_stream.poll_next(cx) {
|
||||
return Poll::Ready(match option {
|
||||
Some(x) => Event::Callback(x),
|
||||
Some(x) => Event::Connlib(x),
|
||||
None => Event::CallbackChannelClosed,
|
||||
});
|
||||
}
|
||||
@@ -320,21 +318,18 @@ impl<'a> Handler<'a> {
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
async fn handle_connlib_cb(&mut self, msg: ConnlibMsg) -> Result<()> {
|
||||
async fn handle_connlib_event(&mut self, msg: client_shared::Event) -> Result<()> {
|
||||
match msg {
|
||||
ConnlibMsg::OnDisconnect {
|
||||
error_msg,
|
||||
is_authentication_error,
|
||||
} => {
|
||||
client_shared::Event::Disconnected(error) => {
|
||||
let _ = self.session.take();
|
||||
self.dns_controller.deactivate()?;
|
||||
self.send_ipc(ServerMsg::OnDisconnect {
|
||||
error_msg,
|
||||
is_authentication_error,
|
||||
error_msg: error.to_string(),
|
||||
is_authentication_error: error.is_authentication_error(),
|
||||
})
|
||||
.await?
|
||||
}
|
||||
ConnlibMsg::OnSetInterfaceConfig {
|
||||
client_shared::Event::TunInterfaceUpdated {
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
@@ -352,7 +347,7 @@ impl<'a> Handler<'a> {
|
||||
|
||||
self.send_ipc(ServerMsg::TunnelReady).await?;
|
||||
}
|
||||
ConnlibMsg::OnUpdateResources(resources) => {
|
||||
client_shared::Event::ResourcesUpdated(resources) => {
|
||||
// On every resources update, flush DNS to mitigate <https://github.com/firezone/firezone/issues/5052>
|
||||
self.dns_controller.flush()?;
|
||||
self.send_ipc(ServerMsg::OnUpdateResources(resources))
|
||||
@@ -472,7 +467,6 @@ impl<'a> Handler<'a> {
|
||||
.context("Failed to create `LoginUrl`")?;
|
||||
|
||||
self.last_connlib_start_instant = Some(Instant::now());
|
||||
let (callbacks, cb_rx) = client_shared::ChannelCallbackHandler::new();
|
||||
|
||||
// Synchronous DNS resolution here
|
||||
let portal = PhoenixChannel::disconnected(
|
||||
@@ -490,10 +484,9 @@ impl<'a> Handler<'a> {
|
||||
|
||||
// Read the resolvers before starting connlib, in case connlib's startup interferes.
|
||||
let dns = self.dns_controller.system_resolvers();
|
||||
let connlib = client_shared::Session::connect(
|
||||
let (connlib, event_stream) = client_shared::Session::connect(
|
||||
Arc::new(tcp_socket_factory),
|
||||
Arc::new(udp_socket_factory),
|
||||
callbacks,
|
||||
portal,
|
||||
tokio::runtime::Handle::current(),
|
||||
);
|
||||
@@ -510,7 +503,10 @@ impl<'a> Handler<'a> {
|
||||
};
|
||||
connlib.set_tun(tun);
|
||||
|
||||
let session = Session { cb_rx, connlib };
|
||||
let session = Session {
|
||||
event_stream,
|
||||
connlib,
|
||||
};
|
||||
self.session = Some(session);
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use clap::Parser;
|
||||
use client_shared::{ChannelCallbackHandler, ConnlibMsg, Session};
|
||||
use firezone_bin_shared::{
|
||||
DnsControlMethod, DnsController, TOKEN_ENV_KEY, TunDeviceManager, device_id, device_info,
|
||||
new_dns_notifier, new_network_notifier,
|
||||
@@ -15,7 +14,6 @@ use firezone_bin_shared::{
|
||||
use firezone_logging::telemetry_span;
|
||||
use firezone_telemetry::Telemetry;
|
||||
use firezone_telemetry::otel;
|
||||
use futures::StreamExt as _;
|
||||
use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider};
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use phoenix_channel::get_user_agent;
|
||||
@@ -26,7 +24,6 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[path = "linux.rs"]
|
||||
@@ -221,8 +218,6 @@ fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (callbacks, cb_rx) = ChannelCallbackHandler::new();
|
||||
|
||||
// The name matches that in `ipc_service.rs`
|
||||
let mut last_connlib_start_instant = Some(Instant::now());
|
||||
|
||||
@@ -261,10 +256,9 @@ fn main() -> Result<()> {
|
||||
},
|
||||
Arc::new(tcp_socket_factory),
|
||||
)?;
|
||||
let session = Session::connect(
|
||||
let (session, mut event_stream) = client_shared::Session::connect(
|
||||
Arc::new(tcp_socket_factory),
|
||||
Arc::new(udp_socket_factory),
|
||||
callbacks,
|
||||
portal,
|
||||
rt.handle().clone(),
|
||||
);
|
||||
@@ -273,7 +267,6 @@ fn main() -> Result<()> {
|
||||
let mut hangup = signals::Hangup::new()?;
|
||||
|
||||
let mut tun_device = TunDeviceManager::new(ip_packet::MAX_IP_SIZE, 1)?;
|
||||
let mut cb_rx = ReceiverStream::new(cb_rx).fuse();
|
||||
|
||||
let tokio_handle = tokio::runtime::Handle::current();
|
||||
|
||||
@@ -294,7 +287,7 @@ fn main() -> Result<()> {
|
||||
drop(connect_span);
|
||||
|
||||
let result = loop {
|
||||
let cb = tokio::select! {
|
||||
let event = tokio::select! {
|
||||
() = terminate.recv() => {
|
||||
tracing::info!("Caught SIGINT / SIGTERM / Ctrl+C");
|
||||
break Ok(());
|
||||
@@ -318,20 +311,17 @@ fn main() -> Result<()> {
|
||||
session.reset();
|
||||
continue;
|
||||
},
|
||||
cb = cb_rx.next() => cb.context("cb_rx unexpectedly ran empty")?,
|
||||
event = event_stream.next() => event.context("event stream unexpectedly ran empty")?,
|
||||
};
|
||||
|
||||
match cb {
|
||||
match event {
|
||||
// TODO: Headless Client shouldn't be using messages labelled `Ipc`
|
||||
ConnlibMsg::OnDisconnect {
|
||||
error_msg,
|
||||
is_authentication_error: _,
|
||||
} => break Err(anyhow!(error_msg).context("Firezone disconnected")),
|
||||
ConnlibMsg::OnUpdateResources(_) => {
|
||||
client_shared::Event::Disconnected(error) => break Err(anyhow!(error).context("Firezone disconnected")),
|
||||
client_shared::Event::ResourcesUpdated(_) => {
|
||||
// On every Resources update, flush DNS to mitigate <https://github.com/firezone/firezone/issues/5052>
|
||||
dns_controller.flush()?;
|
||||
}
|
||||
ConnlibMsg::OnSetInterfaceConfig {
|
||||
client_shared::Event::TunInterfaceUpdated {
|
||||
ipv4,
|
||||
ipv6,
|
||||
dns,
|
||||
|
||||
Reference in New Issue
Block a user