refactor(connlib): remove Callbacks from Tunnel (#5885)

Following the removal of the return type from the callback functions in
#5839, we can now move the use of the `Callbacks` one layer up the stack
and decouple them entirely from the `Tunnel`.

---------

Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
Co-authored-by: Gabi <gabrielalejandro7@gmail.com>
This commit is contained in:
Thomas Eizinger
2024-07-17 07:00:40 +10:00
committed by GitHub
parent 0e2a13148f
commit 58db5f0639
8 changed files with 100 additions and 107 deletions

View File

@@ -19,7 +19,8 @@ use std::{
};
pub struct Eventloop<C: Callbacks> {
tunnel: ClientTunnel<C>,
tunnel: ClientTunnel,
callbacks: C,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
@@ -37,7 +38,8 @@ pub enum Command {
impl<C: Callbacks> Eventloop<C> {
pub(crate) fn new(
tunnel: ClientTunnel<C>,
tunnel: ClientTunnel,
callbacks: C,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
) -> Self {
@@ -46,6 +48,7 @@ impl<C: Callbacks> Eventloop<C> {
portal,
connection_intents: SentConnectionIntents::default(),
rx,
callbacks,
}
}
}
@@ -153,20 +156,20 @@ where
}
}
firezone_tunnel::ClientEvent::ResourcesChanged { resources } => {
// Note: This may look a bit weird: We are reading an event from the tunnel and yet delegate back to the tunnel here.
// Couldn't the tunnel just do this internally?
// Technically, yes.
// But, we are only accessing the callbacks here which _eventually_ will be removed from `Tunnel`.
// At that point, the tunnel has to emit this event and we need to handle it without delegating back to the tunnel.
// We only access the callbacks here because `Tunnel` already has them and the callbacks are the current way of talking to the UI.
// At a later point, we will probably map to another event here that gets pushed further up.
self.tunnel.callbacks.on_update_resources(resources)
self.callbacks.on_update_resources(resources)
}
firezone_tunnel::ClientEvent::DnsServersChanged { .. } => {
// Unhandled for now.
// As we decouple the core of connlib from the callbacks, this is where we will hook into the DNS server change and notify our clients to set new DNS servers on their platform.
// See https://github.com/firezone/firezone/issues/5106 for details.
firezone_tunnel::ClientEvent::TunInterfaceUpdated {
ip4,
ip6,
dns_by_sentinel,
} => {
let dns_servers = dns_by_sentinel.left_values().copied().collect();
self.callbacks
.on_set_interface_config(ip4, ip6, dns_servers);
}
firezone_tunnel::ClientEvent::TunRoutesUpdated { ip4, ip6 } => {
self.callbacks.on_update_routes(ip4, ip6);
}
}
}

View File

@@ -145,7 +145,6 @@ where
private_key,
tcp_socket_factory.clone(),
udp_socket_factory,
callbacks,
HashMap::from([(url.host().to_string(), addrs)]),
)?;
@@ -160,7 +159,7 @@ where
tcp_socket_factory,
);
let mut eventloop = Eventloop::new(tunnel, portal, rx);
let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx);
std::future::poll_fn(|cx| eventloop.poll(cx))
.await
@@ -241,12 +240,10 @@ mod tests {
use std::{collections::HashMap, sync::Arc};
let (private_key, _public_key) = connlib_shared::keypair();
let callbacks = Callbacks::default();
let mut tunnel = firezone_tunnel::ClientTunnel::new(
private_key,
Arc::new(socket_factory::tcp),
Arc::new(socket_factory::udp),
callbacks,
HashMap::new(),
)
.unwrap();

View File

@@ -13,7 +13,7 @@ use connlib_shared::messages::{
GatewayId, Interface as InterfaceConfig, IpDnsServer, Key, Offer, Relay, RelayId,
RequestConnection, ResourceId, ReuseConnection,
};
use connlib_shared::{callbacks, Callbacks, DomainName, PublicKey, StaticSecret};
use connlib_shared::{callbacks, DomainName, PublicKey, StaticSecret};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
use ip_packet::{IpPacket, MutableIpPacket, Packet as _};
@@ -45,18 +45,22 @@ const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120";
// is 30 seconds. See resolvconf(5) timeout.
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
impl<CB> ClientTunnel<CB>
where
CB: Callbacks + 'static,
{
impl ClientTunnel {
pub fn set_resources(&mut self, resources: Vec<ResourceDescription>) {
self.role_state.set_resources(resources);
self.callbacks.on_update_routes(
self.role_state.routes().filter_map(utils::ipv4).collect(),
self.role_state.routes().filter_map(utils::ipv6).collect(),
);
self.callbacks
.on_update_resources(self.role_state.resources());
// FIXME: It would be good to add this event from _within_ `ClientState` but we don't want to emit duplicates.
self.role_state
.buffered_events
.push_back(ClientEvent::TunRoutesUpdated {
ip4: self.role_state.routes().filter_map(utils::ipv4).collect(),
ip6: self.role_state.routes().filter_map(utils::ipv6).collect(),
});
self.role_state
.buffered_events
.push_back(ClientEvent::ResourcesChanged {
resources: self.role_state.resources(),
});
}
pub fn set_tun(&mut self, tun: Tun) {
@@ -72,23 +76,33 @@ where
pub fn add_resources(&mut self, resources: &[ResourceDescription]) {
self.role_state.add_resources(resources);
self.callbacks.on_update_routes(
self.role_state.routes().filter_map(utils::ipv4).collect(),
self.role_state.routes().filter_map(utils::ipv6).collect(),
);
self.callbacks
.on_update_resources(self.role_state.resources());
self.role_state
.buffered_events
.push_back(ClientEvent::TunRoutesUpdated {
ip4: self.role_state.routes().filter_map(utils::ipv4).collect(),
ip6: self.role_state.routes().filter_map(utils::ipv6).collect(),
});
self.role_state
.buffered_events
.push_back(ClientEvent::ResourcesChanged {
resources: self.role_state.resources(),
});
}
pub fn remove_resources(&mut self, ids: &[ResourceId]) {
self.role_state.remove_resources(ids);
self.callbacks.on_update_routes(
self.role_state.routes().filter_map(utils::ipv4).collect(),
self.role_state.routes().filter_map(utils::ipv6).collect(),
);
self.callbacks
.on_update_resources(self.role_state.resources())
self.role_state
.buffered_events
.push_back(ClientEvent::TunRoutesUpdated {
ip4: self.role_state.routes().filter_map(utils::ipv4).collect(),
ip6: self.role_state.routes().filter_map(utils::ipv6).collect(),
});
self.role_state
.buffered_events
.push_back(ClientEvent::ResourcesChanged {
resources: self.role_state.resources(),
});
}
/// Updates the system's dns
@@ -103,10 +117,6 @@ where
self.io
.set_upstream_dns_servers(self.role_state.dns_mapping());
if let Some(config) = self.role_state.interface_config.as_ref().cloned() {
self.update_device(config, self.role_state.dns_mapping());
};
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -114,35 +124,16 @@ where
&mut self,
config: InterfaceConfig,
) -> connlib_shared::Result<()> {
let dns_changed = self.role_state.update_interface_config(config.clone());
let dns_changed = self.role_state.update_interface_config(config);
if dns_changed {
self.io
.set_upstream_dns_servers(self.role_state.dns_mapping());
}
self.update_device(config, self.role_state.dns_mapping());
Ok(())
}
pub(crate) fn update_device(
&mut self,
config: InterfaceConfig,
dns_mapping: BiMap<IpAddr, DnsServer>,
) {
// We can just sort in here because sentinel ips are created in order
let dns_config = dns_mapping.left_values().copied().sorted().collect();
self.callbacks
.clone()
.on_set_interface_config(config.ipv4, config.ipv6, dns_config);
self.callbacks.on_update_routes(
self.role_state.routes().filter_map(utils::ipv4).collect(),
self.role_state.routes().filter_map(utils::ipv6).collect(),
);
}
pub fn cleanup_connection(&mut self, id: ResourceId) {
self.role_state.on_connection_failed(id);
}
@@ -152,8 +143,11 @@ where
self.role_state.on_connection_failed(id);
self.callbacks
.on_update_resources(self.role_state.resources());
self.role_state
.buffered_events
.push_back(ClientEvent::ResourcesChanged {
resources: self.role_state.resources(),
});
}
pub fn add_ice_candidate(&mut self, conn_id: GatewayId, ice_candidate: String) {
@@ -1016,16 +1010,26 @@ impl ClientState {
.collect_vec(),
);
let ip4 = config.ipv4;
let ip6 = config.ipv6;
self.set_dns_mapping(dns_mapping);
self.buffered_events
.push_back(ClientEvent::DnsServersChanged {
.push_back(ClientEvent::TunInterfaceUpdated {
ip4,
ip6,
dns_by_sentinel: self
.dns_mapping
.iter()
.map(|(sentinel_dns, effective_dns)| (*sentinel_dns, effective_dns.address()))
.collect(),
});
self.buffered_events
.push_back(ClientEvent::TunRoutesUpdated {
ip4: self.routes().filter_map(utils::ipv4).collect(),
ip6: self.routes().filter_map(utils::ipv6).collect(),
});
true
}

View File

@@ -8,7 +8,7 @@ use connlib_shared::messages::{
gateway::ResolvedResourceDescriptionDns, gateway::ResourceDescription, Answer, ClientId, Key,
Offer, RelayId, ResourceId,
};
use connlib_shared::{Callbacks, DomainName, Error, Result, StaticSecret};
use connlib_shared::{DomainName, Error, Result, StaticSecret};
use ip_packet::{IpPacket, MutableIpPacket};
use secrecy::{ExposeSecret as _, Secret};
use snownet::{RelaySocket, ServerNode};
@@ -18,10 +18,7 @@ use std::time::{Duration, Instant};
const EXPIRE_RESOURCES_INTERVAL: Duration = Duration::from_secs(1);
impl<CB> GatewayTunnel<CB>
where
CB: Callbacks + 'static,
{
impl GatewayTunnel {
pub fn set_tun(&mut self, tun: Tun) {
self.io.device_mut().set_tun(tun);
}

View File

@@ -8,12 +8,12 @@ use chrono::Utc;
use connlib_shared::{
callbacks,
messages::{ClientId, GatewayId, Relay, RelayId, ResourceId, ReuseConnection},
Callbacks, DomainName, Result,
DomainName, Result,
};
use io::Io;
use std::{
collections::{HashMap, HashSet},
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
task::{Context, Poll},
time::Instant,
@@ -35,6 +35,7 @@ mod sockets;
mod utils;
pub use device_channel::Tun;
use ip_network::{Ipv4Network, Ipv6Network};
#[cfg(all(test, feature = "proptest"))]
mod tests;
@@ -44,16 +45,14 @@ const MTU: usize = 1280;
const REALM: &str = "firezone";
pub type GatewayTunnel<CB> = Tunnel<CB, GatewayState>;
pub type ClientTunnel<CB> = Tunnel<CB, ClientState>;
pub type GatewayTunnel = Tunnel<GatewayState>;
pub type ClientTunnel = Tunnel<ClientState>;
/// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway.
///
/// Most of connlib's functionality is implemented as a pure state machine in [`ClientState`] and [`GatewayState`].
/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`](crate::sockets::Sockets) or time and pass it to the respective state.
pub struct Tunnel<CB: Callbacks, TRoleState> {
pub callbacks: CB,
pub struct Tunnel<TRoleState> {
/// (pure) state that differs per role, either [`ClientState`] or [`GatewayState`].
role_state: TRoleState,
@@ -71,20 +70,15 @@ pub struct Tunnel<CB: Callbacks, TRoleState> {
device_read_buf: Box<[u8; MTU + 20]>,
}
impl<CB> ClientTunnel<CB>
where
CB: Callbacks + 'static,
{
impl ClientTunnel {
pub fn new(
private_key: StaticSecret,
tcp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::TcpSocket>>,
udp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::UdpSocket>>,
callbacks: CB,
known_hosts: HashMap<String, Vec<IpAddr>>,
) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(tcp_socket_factory, udp_socket_factory)?,
callbacks,
role_state: ClientState::new(private_key, known_hosts),
write_buf: Box::new([0u8; MTU + 16 + 20]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
@@ -175,14 +169,10 @@ where
}
}
impl<CB> GatewayTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn new(private_key: StaticSecret, callbacks: CB) -> std::io::Result<Self> {
impl GatewayTunnel {
pub fn new(private_key: StaticSecret) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?,
callbacks,
role_state: GatewayState::new(private_key),
write_buf: Box::new([0u8; MTU + 20 + 16]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
@@ -282,7 +272,10 @@ pub enum ClientEvent {
ResourcesChanged {
resources: Vec<callbacks::ResourceDescription>,
},
DnsServersChanged {
// TODO: Make this more fine-granular.
TunInterfaceUpdated {
ip4: Ipv4Addr,
ip6: Ipv6Addr,
/// The map of DNS servers that connlib will use.
///
/// - The "left" values are the connlib-assigned, proxy (or "sentinel") IPs.
@@ -291,6 +284,10 @@ pub enum ClientEvent {
/// Otherwise, we will use the DNS servers configured on the system.
dns_by_sentinel: BiMap<IpAddr, SocketAddr>,
},
TunRoutesUpdated {
ip4: Vec<Ipv4Network>,
ip6: Vec<Ipv6Network>,
},
}
#[derive(Debug, Clone)]

View File

@@ -689,10 +689,13 @@ impl TunnelTest {
ClientEvent::ResourcesChanged { .. } => {
tracing::warn!("Unimplemented");
}
ClientEvent::DnsServersChanged { dns_by_sentinel } => {
ClientEvent::TunInterfaceUpdated {
dns_by_sentinel, ..
} => {
self.client
.exec_mut(|c| c.dns_by_sentinel = dns_by_sentinel);
}
ClientEvent::TunRoutesUpdated { .. } => {}
}
}

View File

@@ -2,7 +2,6 @@ use crate::messages::{
AllowAccess, ClientIceCandidates, ClientsIceCandidates, ConnectionReady, EgressMessages,
IngressMessages, RejectAccess, RequestConnection,
};
use crate::CallbackHandler;
use anyhow::Result;
use boringtun::x25519::PublicKey;
use connlib_shared::messages::{
@@ -40,7 +39,7 @@ enum ResolveTrigger {
}
pub struct Eventloop {
tunnel: GatewayTunnel<CallbackHandler>,
tunnel: GatewayTunnel,
portal: PhoenixChannel<(), IngressMessages, ()>,
tun_device_channel: mpsc::Sender<Interface>,
@@ -49,7 +48,7 @@ pub struct Eventloop {
impl Eventloop {
pub(crate) fn new(
tunnel: GatewayTunnel<CallbackHandler>,
tunnel: GatewayTunnel,
portal: PhoenixChannel<(), IngressMessages, ()>,
tun_device_channel: mpsc::Sender<Interface>,
) -> Self {

View File

@@ -2,9 +2,7 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
use anyhow::{Context, Result};
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use connlib_shared::{
get_user_agent, keypair, messages::Interface, Callbacks, LoginUrl, StaticSecret,
};
use connlib_shared::{get_user_agent, keypair, messages::Interface, LoginUrl, StaticSecret};
use firezone_bin_shared::{setup_global_subscriber, CommonArgs, TunDeviceManager};
use firezone_tunnel::{GatewayTunnel, Tun};
@@ -102,7 +100,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
}
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?;
let mut tunnel = GatewayTunnel::new(private_key)?;
let portal = PhoenixChannel::connect(
Secret::new(login),
get_user_agent(None, env!("CARGO_PKG_VERSION")),
@@ -154,11 +152,6 @@ async fn update_device_task(
}
}
#[derive(Clone)]
struct CallbackHandler;
impl Callbacks for CallbackHandler {}
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {