From 58db5f06397e6b66e368004b61020dcfdfad9acf Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 17 Jul 2024 07:00:40 +1000 Subject: [PATCH] refactor(connlib): remove `Callbacks` from `Tunnel` (#5885) Following the removal of the return type from the callback functions in #5839, we can now move the use of the `Callbacks` one layer up the stack and decouple them entirely from the `Tunnel`. --------- Signed-off-by: Thomas Eizinger Co-authored-by: Gabi --- rust/connlib/clients/shared/src/eventloop.rs | 33 +++--- rust/connlib/clients/shared/src/lib.rs | 5 +- rust/connlib/tunnel/src/client.rs | 104 ++++++++++--------- rust/connlib/tunnel/src/gateway.rs | 7 +- rust/connlib/tunnel/src/lib.rs | 37 +++---- rust/connlib/tunnel/src/tests/sut.rs | 5 +- rust/gateway/src/eventloop.rs | 5 +- rust/gateway/src/main.rs | 11 +- 8 files changed, 100 insertions(+), 107 deletions(-) diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 247c9d03a..0bd362025 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -19,7 +19,8 @@ use std::{ }; pub struct Eventloop { - tunnel: ClientTunnel, + tunnel: ClientTunnel, + callbacks: C, portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, rx: tokio::sync::mpsc::UnboundedReceiver, @@ -37,7 +38,8 @@ pub enum Command { impl Eventloop { pub(crate) fn new( - tunnel: ClientTunnel, + tunnel: ClientTunnel, + callbacks: C, portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, rx: tokio::sync::mpsc::UnboundedReceiver, ) -> Self { @@ -46,6 +48,7 @@ impl Eventloop { portal, connection_intents: SentConnectionIntents::default(), rx, + callbacks, } } } @@ -153,20 +156,20 @@ where } } firezone_tunnel::ClientEvent::ResourcesChanged { resources } => { - // Note: This may look a bit weird: We are reading an event from the tunnel and yet delegate back to the tunnel here. - // Couldn't the tunnel just do this internally? - // Technically, yes. - // But, we are only accessing the callbacks here which _eventually_ will be removed from `Tunnel`. - // At that point, the tunnel has to emit this event and we need to handle it without delegating back to the tunnel. - // We only access the callbacks here because `Tunnel` already has them and the callbacks are the current way of talking to the UI. - // At a later point, we will probably map to another event here that gets pushed further up. - - self.tunnel.callbacks.on_update_resources(resources) + self.callbacks.on_update_resources(resources) } - firezone_tunnel::ClientEvent::DnsServersChanged { .. } => { - // Unhandled for now. - // As we decouple the core of connlib from the callbacks, this is where we will hook into the DNS server change and notify our clients to set new DNS servers on their platform. - // See https://github.com/firezone/firezone/issues/5106 for details. + firezone_tunnel::ClientEvent::TunInterfaceUpdated { + ip4, + ip6, + dns_by_sentinel, + } => { + let dns_servers = dns_by_sentinel.left_values().copied().collect(); + + self.callbacks + .on_set_interface_config(ip4, ip6, dns_servers); + } + firezone_tunnel::ClientEvent::TunRoutesUpdated { ip4, ip6 } => { + self.callbacks.on_update_routes(ip4, ip6); } } } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index f3110d265..31a4b30bb 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -145,7 +145,6 @@ where private_key, tcp_socket_factory.clone(), udp_socket_factory, - callbacks, HashMap::from([(url.host().to_string(), addrs)]), )?; @@ -160,7 +159,7 @@ where tcp_socket_factory, ); - let mut eventloop = Eventloop::new(tunnel, portal, rx); + let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx); std::future::poll_fn(|cx| eventloop.poll(cx)) .await @@ -241,12 +240,10 @@ mod tests { use std::{collections::HashMap, sync::Arc}; let (private_key, _public_key) = connlib_shared::keypair(); - let callbacks = Callbacks::default(); let mut tunnel = firezone_tunnel::ClientTunnel::new( private_key, Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp), - callbacks, HashMap::new(), ) .unwrap(); diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index c3b42c8ee..8f3f86bb3 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -13,7 +13,7 @@ use connlib_shared::messages::{ GatewayId, Interface as InterfaceConfig, IpDnsServer, Key, Offer, Relay, RelayId, RequestConnection, ResourceId, ReuseConnection, }; -use connlib_shared::{callbacks, Callbacks, DomainName, PublicKey, StaticSecret}; +use connlib_shared::{callbacks, DomainName, PublicKey, StaticSecret}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; @@ -45,18 +45,22 @@ const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120"; // is 30 seconds. See resolvconf(5) timeout. const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); -impl ClientTunnel -where - CB: Callbacks + 'static, -{ +impl ClientTunnel { pub fn set_resources(&mut self, resources: Vec) { self.role_state.set_resources(resources); - self.callbacks.on_update_routes( - self.role_state.routes().filter_map(utils::ipv4).collect(), - self.role_state.routes().filter_map(utils::ipv6).collect(), - ); - self.callbacks - .on_update_resources(self.role_state.resources()); + + // FIXME: It would be good to add this event from _within_ `ClientState` but we don't want to emit duplicates. + self.role_state + .buffered_events + .push_back(ClientEvent::TunRoutesUpdated { + ip4: self.role_state.routes().filter_map(utils::ipv4).collect(), + ip6: self.role_state.routes().filter_map(utils::ipv6).collect(), + }); + self.role_state + .buffered_events + .push_back(ClientEvent::ResourcesChanged { + resources: self.role_state.resources(), + }); } pub fn set_tun(&mut self, tun: Tun) { @@ -72,23 +76,33 @@ where pub fn add_resources(&mut self, resources: &[ResourceDescription]) { self.role_state.add_resources(resources); - self.callbacks.on_update_routes( - self.role_state.routes().filter_map(utils::ipv4).collect(), - self.role_state.routes().filter_map(utils::ipv6).collect(), - ); - self.callbacks - .on_update_resources(self.role_state.resources()); + self.role_state + .buffered_events + .push_back(ClientEvent::TunRoutesUpdated { + ip4: self.role_state.routes().filter_map(utils::ipv4).collect(), + ip6: self.role_state.routes().filter_map(utils::ipv6).collect(), + }); + self.role_state + .buffered_events + .push_back(ClientEvent::ResourcesChanged { + resources: self.role_state.resources(), + }); } pub fn remove_resources(&mut self, ids: &[ResourceId]) { self.role_state.remove_resources(ids); - self.callbacks.on_update_routes( - self.role_state.routes().filter_map(utils::ipv4).collect(), - self.role_state.routes().filter_map(utils::ipv6).collect(), - ); - self.callbacks - .on_update_resources(self.role_state.resources()) + self.role_state + .buffered_events + .push_back(ClientEvent::TunRoutesUpdated { + ip4: self.role_state.routes().filter_map(utils::ipv4).collect(), + ip6: self.role_state.routes().filter_map(utils::ipv6).collect(), + }); + self.role_state + .buffered_events + .push_back(ClientEvent::ResourcesChanged { + resources: self.role_state.resources(), + }); } /// Updates the system's dns @@ -103,10 +117,6 @@ where self.io .set_upstream_dns_servers(self.role_state.dns_mapping()); - - if let Some(config) = self.role_state.interface_config.as_ref().cloned() { - self.update_device(config, self.role_state.dns_mapping()); - }; } #[tracing::instrument(level = "trace", skip(self))] @@ -114,35 +124,16 @@ where &mut self, config: InterfaceConfig, ) -> connlib_shared::Result<()> { - let dns_changed = self.role_state.update_interface_config(config.clone()); + let dns_changed = self.role_state.update_interface_config(config); if dns_changed { self.io .set_upstream_dns_servers(self.role_state.dns_mapping()); } - self.update_device(config, self.role_state.dns_mapping()); - Ok(()) } - pub(crate) fn update_device( - &mut self, - config: InterfaceConfig, - dns_mapping: BiMap, - ) { - // We can just sort in here because sentinel ips are created in order - let dns_config = dns_mapping.left_values().copied().sorted().collect(); - - self.callbacks - .clone() - .on_set_interface_config(config.ipv4, config.ipv6, dns_config); - self.callbacks.on_update_routes( - self.role_state.routes().filter_map(utils::ipv4).collect(), - self.role_state.routes().filter_map(utils::ipv6).collect(), - ); - } - pub fn cleanup_connection(&mut self, id: ResourceId) { self.role_state.on_connection_failed(id); } @@ -152,8 +143,11 @@ where self.role_state.on_connection_failed(id); - self.callbacks - .on_update_resources(self.role_state.resources()); + self.role_state + .buffered_events + .push_back(ClientEvent::ResourcesChanged { + resources: self.role_state.resources(), + }); } pub fn add_ice_candidate(&mut self, conn_id: GatewayId, ice_candidate: String) { @@ -1016,16 +1010,26 @@ impl ClientState { .collect_vec(), ); + let ip4 = config.ipv4; + let ip6 = config.ipv6; + self.set_dns_mapping(dns_mapping); self.buffered_events - .push_back(ClientEvent::DnsServersChanged { + .push_back(ClientEvent::TunInterfaceUpdated { + ip4, + ip6, dns_by_sentinel: self .dns_mapping .iter() .map(|(sentinel_dns, effective_dns)| (*sentinel_dns, effective_dns.address())) .collect(), }); + self.buffered_events + .push_back(ClientEvent::TunRoutesUpdated { + ip4: self.routes().filter_map(utils::ipv4).collect(), + ip6: self.routes().filter_map(utils::ipv6).collect(), + }); true } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index e66dea6e4..8ffdcb3cb 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -8,7 +8,7 @@ use connlib_shared::messages::{ gateway::ResolvedResourceDescriptionDns, gateway::ResourceDescription, Answer, ClientId, Key, Offer, RelayId, ResourceId, }; -use connlib_shared::{Callbacks, DomainName, Error, Result, StaticSecret}; +use connlib_shared::{DomainName, Error, Result, StaticSecret}; use ip_packet::{IpPacket, MutableIpPacket}; use secrecy::{ExposeSecret as _, Secret}; use snownet::{RelaySocket, ServerNode}; @@ -18,10 +18,7 @@ use std::time::{Duration, Instant}; const EXPIRE_RESOURCES_INTERVAL: Duration = Duration::from_secs(1); -impl GatewayTunnel -where - CB: Callbacks + 'static, -{ +impl GatewayTunnel { pub fn set_tun(&mut self, tun: Tun) { self.io.device_mut().set_tun(tun); } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index c5118fbda..394c92458 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -8,12 +8,12 @@ use chrono::Utc; use connlib_shared::{ callbacks, messages::{ClientId, GatewayId, Relay, RelayId, ResourceId, ReuseConnection}, - Callbacks, DomainName, Result, + DomainName, Result, }; use io::Io; use std::{ collections::{HashMap, HashSet}, - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, task::{Context, Poll}, time::Instant, @@ -35,6 +35,7 @@ mod sockets; mod utils; pub use device_channel::Tun; +use ip_network::{Ipv4Network, Ipv6Network}; #[cfg(all(test, feature = "proptest"))] mod tests; @@ -44,16 +45,14 @@ const MTU: usize = 1280; const REALM: &str = "firezone"; -pub type GatewayTunnel = Tunnel; -pub type ClientTunnel = Tunnel; +pub type GatewayTunnel = Tunnel; +pub type ClientTunnel = Tunnel; /// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway. /// /// Most of connlib's functionality is implemented as a pure state machine in [`ClientState`] and [`GatewayState`]. /// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`](crate::sockets::Sockets) or time and pass it to the respective state. -pub struct Tunnel { - pub callbacks: CB, - +pub struct Tunnel { /// (pure) state that differs per role, either [`ClientState`] or [`GatewayState`]. role_state: TRoleState, @@ -71,20 +70,15 @@ pub struct Tunnel { device_read_buf: Box<[u8; MTU + 20]>, } -impl ClientTunnel -where - CB: Callbacks + 'static, -{ +impl ClientTunnel { pub fn new( private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - callbacks: CB, known_hosts: HashMap>, ) -> std::io::Result { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, - callbacks, role_state: ClientState::new(private_key, known_hosts), write_buf: Box::new([0u8; MTU + 16 + 20]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), @@ -175,14 +169,10 @@ where } } -impl GatewayTunnel -where - CB: Callbacks + 'static, -{ - pub fn new(private_key: StaticSecret, callbacks: CB) -> std::io::Result { +impl GatewayTunnel { + pub fn new(private_key: StaticSecret) -> std::io::Result { Ok(Self { io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?, - callbacks, role_state: GatewayState::new(private_key), write_buf: Box::new([0u8; MTU + 20 + 16]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), @@ -282,7 +272,10 @@ pub enum ClientEvent { ResourcesChanged { resources: Vec, }, - DnsServersChanged { + // TODO: Make this more fine-granular. + TunInterfaceUpdated { + ip4: Ipv4Addr, + ip6: Ipv6Addr, /// The map of DNS servers that connlib will use. /// /// - The "left" values are the connlib-assigned, proxy (or "sentinel") IPs. @@ -291,6 +284,10 @@ pub enum ClientEvent { /// Otherwise, we will use the DNS servers configured on the system. dns_by_sentinel: BiMap, }, + TunRoutesUpdated { + ip4: Vec, + ip6: Vec, + }, } #[derive(Debug, Clone)] diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index b72dde7d9..9794062fd 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -689,10 +689,13 @@ impl TunnelTest { ClientEvent::ResourcesChanged { .. } => { tracing::warn!("Unimplemented"); } - ClientEvent::DnsServersChanged { dns_by_sentinel } => { + ClientEvent::TunInterfaceUpdated { + dns_by_sentinel, .. + } => { self.client .exec_mut(|c| c.dns_by_sentinel = dns_by_sentinel); } + ClientEvent::TunRoutesUpdated { .. } => {} } } diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 5485dc5dc..98e9ab581 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -2,7 +2,6 @@ use crate::messages::{ AllowAccess, ClientIceCandidates, ClientsIceCandidates, ConnectionReady, EgressMessages, IngressMessages, RejectAccess, RequestConnection, }; -use crate::CallbackHandler; use anyhow::Result; use boringtun::x25519::PublicKey; use connlib_shared::messages::{ @@ -40,7 +39,7 @@ enum ResolveTrigger { } pub struct Eventloop { - tunnel: GatewayTunnel, + tunnel: GatewayTunnel, portal: PhoenixChannel<(), IngressMessages, ()>, tun_device_channel: mpsc::Sender, @@ -49,7 +48,7 @@ pub struct Eventloop { impl Eventloop { pub(crate) fn new( - tunnel: GatewayTunnel, + tunnel: GatewayTunnel, portal: PhoenixChannel<(), IngressMessages, ()>, tun_device_channel: mpsc::Sender, ) -> Self { diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index c631ec8a5..40e1dbf9c 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -2,9 +2,7 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use anyhow::{Context, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_shared::{ - get_user_agent, keypair, messages::Interface, Callbacks, LoginUrl, StaticSecret, -}; +use connlib_shared::{get_user_agent, keypair, messages::Interface, LoginUrl, StaticSecret}; use firezone_bin_shared::{setup_global_subscriber, CommonArgs, TunDeviceManager}; use firezone_tunnel::{GatewayTunnel, Tun}; @@ -102,7 +100,7 @@ async fn get_firezone_id(env_id: Option) -> Result { } async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { - let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?; + let mut tunnel = GatewayTunnel::new(private_key)?; let portal = PhoenixChannel::connect( Secret::new(login), get_user_agent(None, env!("CARGO_PKG_VERSION")), @@ -154,11 +152,6 @@ async fn update_device_task( } } -#[derive(Clone)] -struct CallbackHandler; - -impl Callbacks for CallbackHandler {} - #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Cli {