From 5edd1953200032dd99746d5e5e2d7553eaa0771a Mon Sep 17 00:00:00 2001 From: Gabi Date: Mon, 26 Feb 2024 13:07:38 -0300 Subject: [PATCH] refactor(connlib): unify peer storage (#3738) Now that we have `&mut` access everywhere in the tunnel, the remaining shared-memory and locks are in how we store peers. To resolve this, we introduce a new `PeerStore` that allows us to look up peers by IP and by ID. --- rust/Cargo.lock | 2 - rust/connlib/clients/shared/src/control.rs | 4 - rust/connlib/clients/shared/src/lib.rs | 2 - rust/connlib/tunnel/Cargo.toml | 2 - rust/connlib/tunnel/src/client.rs | 54 +++---- rust/connlib/tunnel/src/control_protocol.rs | 18 +-- .../tunnel/src/control_protocol/client.rs | 48 +++--- .../tunnel/src/control_protocol/gateway.rs | 29 +--- rust/connlib/tunnel/src/gateway.rs | 55 +++---- rust/connlib/tunnel/src/lib.rs | 84 +++++------ rust/connlib/tunnel/src/peer.rs | 100 +++++-------- rust/connlib/tunnel/src/peer_store.rs | 141 ++++++++++++++++++ rust/gateway/src/eventloop.rs | 9 +- 13 files changed, 282 insertions(+), 266 deletions(-) create mode 100644 rust/connlib/tunnel/src/peer_store.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 257ef2f33..008990728 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1922,7 +1922,6 @@ name = "firezone-tunnel" version = "1.0.0" dependencies = [ "anyhow", - "arc-swap", "async-trait", "bimap", "boringtun", @@ -1942,7 +1941,6 @@ dependencies = [ "log", "netlink-packet-core", "netlink-packet-route", - "parking_lot", "pnet_packet", "quinn-udp", "rand_core 0.6.4", diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index c64cc1ff6..c09cacf4f 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -311,10 +311,6 @@ impl ControlPlane { Ok(()) } - pub async fn stats_event(&mut self) { - tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats()); - } - pub async fn request_log_upload_url(&mut self) { tracing::info!("Requesting log upload URL from portal"); diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 6b778ef50..35ccdea15 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -185,7 +185,6 @@ where let runtime_stopper = runtime_stopper.clone(); let callbacks = callbacks.clone(); async move { - let mut log_stats_interval = tokio::time::interval(Duration::from_secs(10)); let mut upload_logs_interval = upload_interval(); loop { tokio::select! { @@ -201,7 +200,6 @@ where } }, event = poll_fn(|cx| control_plane.tunnel.poll_next_event(cx)) => control_plane.handle_tunnel_event(event).await, - _ = log_stats_interval.tick() => control_plane.stats_event().await, _ = upload_logs_interval.tick() => control_plane.request_log_upload_url().await, else => break } diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 59438fd03..57c254ebc 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -14,7 +14,6 @@ serde = { version = "1.0", default-features = false, features = ["derive", "std" futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] } futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] } tracing = { workspace = true } -parking_lot = { version = "0.12", default-features = false } bytes = { version = "1.4", default-features = false, features = ["std"] } itertools = { version = "0.12", default-features = false, features = ["use_std"] } connlib-shared = { workspace = true } @@ -27,7 +26,6 @@ chrono = { workspace = true } pnet_packet = { version = "0.34" } futures-bounded = { workspace = true } hickory-resolver = { workspace = true, features = ["tokio-runtime"] } -arc-swap = "1.6.0" bimap = "0.6" resolv-conf = "0.7.0" socket2 = { version = "0.5" } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index edb3ff71b..c9bbebcfd 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,7 +1,8 @@ use crate::device_channel::{Device, Packet}; use crate::ip_packet::{IpPacket, MutableIpPacket}; -use crate::peer::{PacketTransformClient, Peer}; -use crate::{dns, dns::DnsQuery, peer_by_ip, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE}; +use crate::peer::PacketTransformClient; +use crate::peer_store::PeerStore; +use crate::{dns, dns::DnsQuery, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE}; use bimap::BiMap; use connlib_shared::error::{ConnlibError as Error, ConnlibError}; use connlib_shared::messages::{ @@ -22,7 +23,6 @@ use hickory_resolver::TokioAsyncResolver; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet, VecDeque}; use std::net::IpAddr; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use tokio::time::{Instant, Interval, MissedTickBehavior}; @@ -47,7 +47,7 @@ impl DnsResource { } } -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, { @@ -189,7 +189,7 @@ pub struct ClientState { pub resource_ids: HashMap, pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>, - pub peers_by_ip: IpNetworkTable>>, + pub peers: PeerStore, forwarded_dns_queries: FuturesTupleSet< Result, @@ -227,7 +227,7 @@ impl ClientState { Err(non_dns_packet) => non_dns_packet, }; - let Some(peer) = peer_by_ip(&self.peers_by_ip, dest) else { + let Some(peer) = self.peers.peer_by_ip_mut(dest) else { self.on_connection_intent_ip(dest); return None; }; @@ -309,7 +309,7 @@ impl ClientState { let domain = self.get_awaiting_connection_domain(&resource)?.clone(); - if self.is_connected_to(resource, &self.peers_by_ip, &domain) { + if self.is_connected_to(resource, &domain) { return Err(Error::UnexpectedConnectionDetails); } @@ -332,11 +332,11 @@ impl ClientState { self.resources_gateways.insert(resource, gateway); - let Some(peer) = self - .peers_by_ip - .iter() - .find_map(|(_, p)| (p.conn_id == gateway).then_some(p.clone())) - else { + if self + .peers + .add_ips(&gateway, &self.get_resource_ip(desc, &domain)) + .is_none() + { match self .gateway_awaiting_connection_timers // Note: we don't need to set a timer here because @@ -357,10 +357,6 @@ impl ClientState { return Ok(None); }; - for ip in self.get_resource_ip(desc, &domain) { - peer.add_allowed_ip(ip); - self.peers_by_ip.insert(ip, peer.clone()); - } self.awaiting_connection.remove(&resource); self.awaiting_connection_timers.remove(resource); @@ -531,19 +527,13 @@ impl ClientState { self.awaiting_connection.contains_key(&resource.id()) } - fn is_connected_to( - &self, - resource: ResourceId, - connected_peers: &IpNetworkTable>>, - domain: &Option, - ) -> bool { + fn is_connected_to(&self, resource: ResourceId, domain: &Option) -> bool { let Some(resource) = self.resource_ids.get(&resource) else { return false; }; let ips = self.get_resource_ip(resource, domain); - ips.iter() - .any(|ip| connected_peers.exact_match(*ip).is_some()) + ips.iter().any(|ip| self.peers.exact_match(*ip).is_some()) } fn get_resource_ip( @@ -571,7 +561,7 @@ impl ClientState { } pub fn cleanup_connected_gateway(&mut self, gateway_id: &GatewayId) { - self.peers_by_ip.retain(|_, p| p.conn_id != *gateway_id); + self.peers.remove(gateway_id); self.dns_resources_internal_ips.retain(|resource, _| { !self .resources_gateways @@ -668,20 +658,16 @@ impl ClientState { if self.refresh_dns_timer.poll_tick(cx).is_ready() { let mut connections = Vec::new(); - self.peers_by_ip - .iter() - .for_each(|p| p.1.transform.expire_dns_track()); + self.peers + .iter_mut() + .for_each(|p| p.transform.expire_dns_track()); for resource in self.dns_resources_internal_ips.keys() { let Some(gateway_id) = self.resources_gateways.get(&resource.id) else { continue; }; // filter inactive connections - if !self - .peers_by_ip - .iter() - .any(|(_, p)| &p.conn_id == gateway_id) - { + if self.peers.get(gateway_id).is_none() { continue; } @@ -761,7 +747,7 @@ impl Default for ClientState { dns_resources: Default::default(), cidr_resources: IpNetworkTable::new(), resource_ids: Default::default(), - peers_by_ip: IpNetworkTable::new(), + peers: Default::default(), deferred_dns_queries: Default::default(), refresh_dns_timer: interval, dns_mapping: Default::default(), diff --git a/rust/connlib/tunnel/src/control_protocol.rs b/rust/connlib/tunnel/src/control_protocol.rs index 400774bb4..879d469f5 100644 --- a/rust/connlib/tunnel/src/control_protocol.rs +++ b/rust/connlib/tunnel/src/control_protocol.rs @@ -1,13 +1,11 @@ -use ip_network::IpNetwork; -use ip_network_table::IpNetworkTable; -use std::{collections::HashSet, fmt, hash::Hash, net::SocketAddr, sync::Arc}; +use std::{collections::HashSet, fmt, hash::Hash, net::SocketAddr}; use connlib_shared::{ messages::{Relay, RequestConnection, ReuseConnection}, Callbacks, }; -use crate::{peer::Peer, Tunnel, REALM}; +use crate::{Tunnel, REALM}; mod client; pub mod gateway; @@ -18,7 +16,7 @@ pub enum Request { ReuseConnection(ReuseConnection), } -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, TId: Eq + Hash + Copy + fmt::Display, @@ -30,16 +28,6 @@ where } } -fn insert_peers( - peers_by_ip: &mut IpNetworkTable>>, - ips: &Vec, - peer: Arc>, -) { - for ip in ips { - peers_by_ip.insert(*ip, peer.clone()); - } -} - fn stun(relays: &[Relay], predicate: impl Fn(&SocketAddr) -> bool) -> HashSet { relays .iter() diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 782a115d7..61da858f1 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, net::IpAddr, sync::Arc}; +use std::{collections::HashSet, net::IpAddr}; use boringtun::x25519::PublicKey; use connlib_shared::{ @@ -23,9 +23,7 @@ use crate::{ }; use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel}; -use super::insert_peers; - -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, { @@ -108,24 +106,18 @@ where &domain_response.as_ref().map(|d| d.domain.clone()), )?; - let peer = Arc::new(Peer::new(ips.clone(), gateway_id, Default::default())); + let mut peer: Peer<_, PacketTransformClient> = + Peer::new(ips.clone(), gateway_id, Default::default()); + peer.transform.set_dns(self.role_state.dns_mapping()); + self.role_state.peers.insert(peer); let peer_ips = if let Some(domain_response) = domain_response { - self.dns_response(&resource_id, &domain_response, &peer)? + self.dns_response(&resource_id, &domain_response, &gateway_id)? } else { ips }; - peer.transform.set_dns(self.role_state.dns_mapping()); - - // cleaning up old state - self.role_state - .peers_by_ip - .retain(|_, p| p.conn_id != gateway_id); - self.connections_state - .peers_by_id - .insert(gateway_id, Arc::clone(&peer)); - insert_peers(&mut self.role_state.peers_by_ip, &peer_ips, peer); + self.role_state.peers.add_ips(&gateway_id, &peer_ips); Ok(()) } @@ -168,8 +160,14 @@ where &mut self, resource_id: &ResourceId, domain_response: &DomainResponse, - peer: &Peer, + peer_id: &GatewayId, ) -> Result> { + let peer = self + .role_state + .peers + .get_mut(peer_id) + .ok_or(Error::ControlProtocolError)?; + let resource_description = self .role_state .resource_ids @@ -199,9 +197,6 @@ where .insert(resource_description.clone(), addrs.clone()); let ips: Vec = addrs.iter().copied().map(Into::into).collect(); - for ip in &ips { - peer.add_allowed_ip(*ip); - } if let Some(device) = self.device.as_ref() { send_dns_answer( @@ -235,17 +230,10 @@ where .gateway_by_resource(&resource_id) .ok_or(Error::UnknownResource)?; - let Some(peer) = self - .role_state - .peers_by_ip - .iter_mut() - .find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p.clone())) - else { - return Err(Error::ControlProtocolError); - }; + let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?; + + self.role_state.peers.add_ips(&gateway_id, &peer_ips); - let peer_ips = self.dns_response(&resource_id, &domain_response, &peer)?; - insert_peers(&mut self.role_state.peers_by_ip, &peer_ips, peer); Ok(()) } } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index f87dd4607..6760c236f 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -1,5 +1,4 @@ use crate::{ - control_protocol::insert_peers, dns::is_subdomain, peer::{PacketTransformGateway, Peer}, Error, GatewayState, Tunnel, @@ -17,7 +16,6 @@ use connlib_shared::{ use ip_network::IpNetwork; use secrecy::{ExposeSecret as _, Secret}; use snownet::{Credentials, Server}; -use std::sync::Arc; /// Description of a resource that maps to a DNS record which had its domain already resolved. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -36,7 +34,7 @@ pub struct ResolvedResourceDescriptionDns { pub type ResourceDescription = connlib_shared::messages::ResourceDescription; -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, { @@ -125,14 +123,7 @@ where expires_at: Option>, domain: Option, ) -> Option { - let Some(peer) = self - .role_state - .peers_by_ip - .iter_mut() - .find_map(|(_, p)| (p.conn_id == client).then_some(p.clone())) - else { - return None; - }; + let peer = self.role_state.peers.get_mut(&client)?; let (addresses, resource_id) = match &resource { ResourceDescription::Dns(r) => { @@ -176,25 +167,15 @@ where ) -> Result<()> { tracing::trace!(?ips, "new_data_channel_open"); - let peer = Arc::new(Peer::new( - ips.clone(), - client_id, - PacketTransformGateway::default(), - )); + let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default()); for address in resource_addresses { peer.transform .add_resource(address, resource.clone(), expires_at); } - // cleaning up old state - self.role_state - .peers_by_ip - .retain(|_, p| p.conn_id != client_id); - self.connections_state - .peers_by_id - .insert(client_id, Arc::clone(&peer)); - insert_peers(&mut self.role_state.peers_by_ip, &ips, peer); + self.role_state.peers.insert(peer); + self.role_state.peers.add_ips(&client_id, &ips); Ok(()) } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 96cf829c7..ed1e6d0e0 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,21 +1,20 @@ -use crate::device_channel::Device; -use crate::ip_packet::MutableIpPacket; -use crate::peer::{PacketTransformGateway, Peer}; -use crate::{peer_by_ip, Tunnel}; -use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; -use connlib_shared::Callbacks; -use ip_network_table::IpNetworkTable; -use itertools::Itertools; -use snownet::Server; -use std::sync::Arc; use std::task::{ready, Context, Poll}; use std::time::Duration; + +use crate::device_channel::Device; +use crate::ip_packet::MutableIpPacket; +use crate::peer::PacketTransformGateway; +use crate::peer_store::PeerStore; +use crate::Tunnel; +use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; +use connlib_shared::Callbacks; +use snownet::Server; use tokio::time::{interval, Interval, MissedTickBehavior}; const PEERS_IPV4: &str = "100.64.0.0/11"; const PEERS_IPV6: &str = "fd00:2021:1111::/107"; -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, { @@ -38,16 +37,14 @@ where } /// Clean up a connection to a resource. - pub fn cleanup_connection(&mut self, id: ClientId) { - self.connections_state.peers_by_id.remove(&id); - self.role_state.peers_by_ip.retain(|_, p| p.conn_id != id); + pub fn cleanup_connection(&mut self, id: &ClientId) { + self.role_state.peers.remove(id); } } /// [`Tunnel`] state specific to gateways. pub struct GatewayState { - #[allow(clippy::type_complexity)] - pub peers_by_ip: IpNetworkTable>>, + pub peers: PeerStore, expire_interval: Interval, } @@ -58,29 +55,23 @@ impl GatewayState { ) -> Option<(ClientId, MutableIpPacket<'a>)> { let dest = packet.destination(); - let peer = peer_by_ip(&self.peers_by_ip, dest)?; + let peer = self.peers.peer_by_ip_mut(dest)?; let packet = peer.transform(packet)?; Some((peer.conn_id, packet)) } - pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { ready!(self.expire_interval.poll_tick(cx)); - Poll::Ready(self.expire_resources().collect_vec()) + self.expire_resources(); + Poll::Ready(()) } - fn expire_resources(&self) -> impl Iterator + '_ { - self.peers_by_ip - .iter() - .unique_by(|(_, p)| p.conn_id) - .for_each(|(_, p)| p.transform.expire_resources()); - self.peers_by_ip.iter().filter_map(|(_, p)| { - if p.transform.is_emptied() { - Some(p.conn_id) - } else { - None - } - }) + fn expire_resources(&mut self) { + self.peers + .iter_mut() + .for_each(|p| p.transform.expire_resources()); + self.peers.retain(|_, p| !p.transform.is_emptied()); } } @@ -89,7 +80,7 @@ impl Default for GatewayState { let mut expire_interval = interval(Duration::from_secs(1)); expire_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); Self { - peers_by_ip: IpNetworkTable::new(), + peers: Default::default(), expire_interval, } } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 20439c795..c07054436 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -10,18 +10,16 @@ use connlib_shared::{ }; use device_channel::Device; use futures_util::{future::BoxFuture, task::AtomicWaker, FutureExt}; -use ip_network_table::IpNetworkTable; -use peer::{PacketTransform, PacketTransformClient, PacketTransformGateway, Peer, PeerStats}; +use peer::PacketTransform; +use peer_store::PeerStore; use pnet_packet::Packet; use snownet::{IpPacket, Node, Server}; use sockets::{Received, Sockets}; use std::{ - collections::{HashMap, HashSet}, + collections::HashSet, fmt, hash::Hash, io, - net::IpAddr, - sync::Arc, task::{ready, Context, Poll}, time::Instant, }; @@ -37,6 +35,7 @@ mod dns; mod gateway; mod ip_packet; mod peer; +mod peer_store; mod sockets; const MAX_UDP_SIZE: usize = (1 << 16) - 1; @@ -47,12 +46,11 @@ const REALM: &str = "firezone"; #[cfg(target_os = "linux")] const FIREZONE_MARK: u32 = 0xfd002021; -pub type GatewayTunnel = Tunnel; -pub type ClientTunnel = - Tunnel; +pub type GatewayTunnel = Tunnel; +pub type ClientTunnel = Tunnel; /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets to communicate between peers. -pub struct Tunnel { +pub struct Tunnel { callbacks: CallbackErrorFacade, /// State that differs per role, i.e. clients vs gateways. @@ -61,12 +59,12 @@ pub struct Tunnel { device: Option, no_device_waker: AtomicWaker, - connections_state: ConnectionState, + connections_state: ConnectionState, read_buf: [u8; MAX_UDP_SIZE], } -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, { @@ -94,7 +92,10 @@ where _ => (), } - match self.connections_state.poll_sockets(device, cx)? { + match self + .connections_state + .poll_sockets(device, &mut self.role_state.peers, cx)? + { Poll::Ready(()) => { cx.waker().wake_by_ref(); } @@ -129,17 +130,14 @@ where } } -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, { pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>> { match self.role_state.poll(cx) { - Poll::Ready(ids) => { + Poll::Ready(()) => { cx.waker().wake_by_ref(); - for id in ids { - self.cleanup_connection(id); - } } Poll::Pending => {} } @@ -151,14 +149,17 @@ where match self.connections_state.poll_next_event(cx) { Poll::Ready(Event::StopPeer(id)) => { - self.role_state.peers_by_ip.retain(|_, p| p.conn_id != id); + self.role_state.peers.remove(&id); cx.waker().wake_by_ref(); } Poll::Ready(other) => return Poll::Ready(Ok(other)), _ => (), } - match self.connections_state.poll_sockets(device, cx)? { + match self + .connections_state + .poll_sockets(device, &mut self.role_state.peers, cx)? + { Poll::Ready(()) => { cx.waker().wake_by_ref(); } @@ -195,17 +196,10 @@ where } } -#[allow(dead_code)] -#[derive(Debug, Clone)] -pub struct TunnelStats { - peer_connections: HashMap>, -} - -impl Tunnel +impl Tunnel where CB: Callbacks + 'static, TId: Eq + Hash + Copy + fmt::Display, - TTransform: PacketTransform, TRoleState: Default, { /// Creates a new tunnel. @@ -242,34 +236,23 @@ where pub fn callbacks(&self) -> &CallbackErrorFacade { &self.callbacks } - - pub fn stats(&self) -> HashMap> { - self.connections_state - .peers_by_id - .iter() - .map(|(&id, p)| (id, p.stats())) - .collect() - } } -struct ConnectionState { +struct ConnectionState { pub node: Node, write_buf: Box<[u8; MAX_UDP_SIZE]>, - peers_by_id: HashMap>>, connection_pool_timeout: BoxFuture<'static, std::time::Instant>, sockets: Sockets, } -impl ConnectionState +impl ConnectionState where TId: Eq + Hash + Copy + fmt::Display, - TTransform: PacketTransform, { fn new(private_key: StaticSecret) -> Result { Ok(ConnectionState { node: Node::new(private_key, std::time::Instant::now()), write_buf: Box::new([0; MAX_UDP_SIZE]), - peers_by_id: HashMap::new(), connection_pool_timeout: sleep_until(std::time::Instant::now()).boxed(), sockets: Sockets::new()?, }) @@ -294,7 +277,16 @@ where Ok(()) } - fn poll_sockets(&mut self, device: &mut Device, cx: &mut Context<'_>) -> Poll> { + // TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation + fn poll_sockets( + &mut self, + device: &mut Device, + peer_store: &mut PeerStore, + cx: &mut Context<'_>, + ) -> Poll> + where + TTransform: PacketTransform, + { let received = match ready!(self.sockets.poll_recv_from(cx)) { Ok(received) => received, Err(e) => { @@ -332,7 +324,7 @@ where tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet"); - let Some(peer) = self.peers_by_id.get(&conn_id) else { + let Some(peer) = peer_store.get_mut(&conn_id) else { tracing::error!(%conn_id, %local, %from, "Couldn't find connection"); continue; @@ -378,7 +370,6 @@ where }); } Some(snownet::Event::ConnectionFailed(id)) => { - self.peers_by_id.remove(&id); return Poll::Ready(Event::StopPeer(id)); } _ => {} @@ -397,13 +388,6 @@ where } } -pub(crate) fn peer_by_ip( - peers_by_ip: &IpNetworkTable>>, - ip: IpAddr, -) -> Option<&Peer> { - peers_by_ip.longest_match(ip).map(|(_, peer)| peer.as_ref()) -} - pub enum Event { SignalIceCandidate { conn_id: TId, diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 4b0c893ef..0655c681c 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,10 +1,8 @@ use std::borrow::Cow; use std::collections::HashMap; use std::net::IpAddr; -use std::sync::Arc; use std::time::Instant; -use arc_swap::ArcSwap; use bimap::BiMap; use boringtun::noise::Tunn; use chrono::{DateTime, Utc}; @@ -13,7 +11,6 @@ use connlib_shared::IpProvider; use connlib_shared::{Error, Result}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use parking_lot::{Mutex, RwLock}; use pnet_packet::Packet; use crate::control_protocol::gateway::ResourceDescription; @@ -26,31 +23,16 @@ type ExpiryingResource = (ResourceDescription, Option>); const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); pub struct Peer { - allowed_ips: RwLock>, + allowed_ips: IpNetworkTable<()>, pub conn_id: TId, pub transform: TTransform, } -#[allow(dead_code)] -#[derive(Debug, Clone)] -pub struct PeerStats { - pub allowed_ips: Vec, - pub conn_id: TId, -} - impl Peer where TId: Copy, TTransform: PacketTransform, { - pub fn stats(&self) -> PeerStats { - let allowed_ips = self.allowed_ips.read().iter().map(|(ip, _)| ip).collect(); - PeerStats { - allowed_ips, - conn_id: self.conn_id, - } - } - pub(crate) fn new( ips: Vec, conn_id: TId, @@ -60,7 +42,6 @@ where for ip in ips { allowed_ips.insert(ip, ()); } - let allowed_ips = RwLock::new(allowed_ips); Peer { allowed_ips, @@ -69,21 +50,24 @@ where } } - pub(crate) fn add_allowed_ip(&self, ip: IpNetwork) { - self.allowed_ips.write().insert(ip, ()); + pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) { + self.allowed_ips.insert(ip, ()); } fn is_allowed(&self, addr: IpAddr) -> bool { - self.allowed_ips.read().longest_match(addr).is_some() + self.allowed_ips.longest_match(addr).is_some() } /// Sends the given packet to this peer by encapsulating it in a wireguard packet. - pub(crate) fn transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option> { + pub(crate) fn transform<'a>( + &mut self, + packet: MutableIpPacket<'a>, + ) -> Option> { self.transform.packet_transform(packet) } pub(crate) fn untransform<'b>( - &self, + &mut self, addr: IpAddr, packet: &'b mut [u8], ) -> Result> { @@ -98,86 +82,83 @@ where } pub struct PacketTransformGateway { - resources: RwLock>, + resources: IpNetworkTable, } impl Default for PacketTransformGateway { fn default() -> Self { Self { - resources: RwLock::new(IpNetworkTable::new()), + resources: IpNetworkTable::new(), } } } #[derive(Default)] pub struct PacketTransformClient { - translations: RwLock>, - dns_mapping: ArcSwap>, - mangled_dns_ids: Mutex>, + translations: BiMap, + dns_mapping: BiMap, + mangled_dns_ids: HashMap, } impl PacketTransformClient { pub fn get_or_assign_translation( - &self, + &mut self, ip: &IpAddr, ip_provider: &mut IpProvider, ) -> Option { - let mut translations = self.translations.write(); - if let Some(proxy_ip) = translations.get_by_right(ip) { + if let Some(proxy_ip) = self.translations.get_by_right(ip) { return Some(*proxy_ip); } let proxy_ip = ip_provider.get_proxy_ip_for(ip)?; - translations.insert(proxy_ip, *ip); + self.translations.insert(proxy_ip, *ip); Some(proxy_ip) } - pub fn expire_dns_track(&self) { + pub fn expire_dns_track(&mut self) { self.mangled_dns_ids - .lock() .retain(|_, exp| exp.elapsed() < IDS_EXPIRE); } - pub fn set_dns(&self, mapping: BiMap) { - self.dns_mapping.store(Arc::new(mapping)); + pub fn set_dns(&mut self, mapping: BiMap) { + self.dns_mapping = mapping; } } impl PacketTransformGateway { pub(crate) fn is_emptied(&self) -> bool { - self.resources.read().is_empty() + self.resources.is_empty() } - pub(crate) fn expire_resources(&self) { + pub(crate) fn expire_resources(&mut self) { self.resources - .write() .retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now())); } pub(crate) fn add_resource( - &self, + &mut self, ip: IpNetwork, resource: ResourceDescription, expires_at: Option>, ) { - self.resources.write().insert(ip, (resource, expires_at)); + self.resources.insert(ip, (resource, expires_at)); } } pub trait PacketTransform { fn packet_untransform<'a>( - &self, + &mut self, addr: &IpAddr, packet: &'a mut [u8], ) -> Result<(device_channel::Packet<'a>, IpAddr)>; - fn packet_transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option>; + fn packet_transform<'a>(&mut self, packet: MutableIpPacket<'a>) -> Option>; } impl PacketTransform for PacketTransformGateway { fn packet_untransform<'a>( - &self, + &mut self, addr: &IpAddr, packet: &'a mut [u8], ) -> Result<(device_channel::Packet<'a>, IpAddr)> { @@ -185,7 +166,7 @@ impl PacketTransform for PacketTransformGateway { return Err(Error::BadPacket); }; - if self.resources.read().longest_match(dst).is_some() { + if self.resources.longest_match(dst).is_some() { let packet = make_packet(packet, addr); Ok((packet, *addr)) } else { @@ -194,19 +175,18 @@ impl PacketTransform for PacketTransformGateway { } } - fn packet_transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option> { + fn packet_transform<'a>(&mut self, packet: MutableIpPacket<'a>) -> Option> { Some(packet) } } impl PacketTransform for PacketTransformClient { fn packet_untransform<'a>( - &self, + &mut self, addr: &IpAddr, packet: &'a mut [u8], ) -> Result<(device_channel::Packet<'a>, IpAddr)> { - let translations = self.translations.read(); - let mut src = *translations.get_by_right(addr).unwrap_or(addr); + let mut src = *self.translations.get_by_right(addr).unwrap_or(addr); let Some(mut pkt) = MutableIpPacket::new(packet) else { return Err(Error::BadPacket); @@ -216,14 +196,11 @@ impl PacketTransform for PacketTransformClient { if let Some(dgm) = pkt.as_udp() { if let Some(sentinel) = self .dns_mapping - .load() - .as_ref() .get_by_right(&(src, dgm.get_source()).into()) { if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) { if self .mangled_dns_ids - .lock() .remove(&message.header().id()) .is_some_and(|exp| exp.elapsed() < IDS_EXPIRE) { @@ -239,22 +216,19 @@ impl PacketTransform for PacketTransformClient { Ok((packet, original_src)) } - fn packet_transform<'a>(&self, mut packet: MutableIpPacket<'a>) -> Option> { - if let Some(translated_ip) = self.translations.read().get_by_left(&packet.destination()) { + fn packet_transform<'a>( + &mut self, + mut packet: MutableIpPacket<'a>, + ) -> Option> { + if let Some(translated_ip) = self.translations.get_by_left(&packet.destination()) { packet.set_dst(*translated_ip); packet.update_checksum(); } - if let Some(srv) = self - .dns_mapping - .load() - .as_ref() - .get_by_left(&packet.destination()) - { + if let Some(srv) = self.dns_mapping.get_by_left(&packet.destination()) { if let Some(dgm) = packet.as_udp() { if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) { self.mangled_dns_ids - .lock() .insert(message.header().id(), Instant::now()); packet.set_dst(srv.ip()); packet.update_checksum(); diff --git a/rust/connlib/tunnel/src/peer_store.rs b/rust/connlib/tunnel/src/peer_store.rs new file mode 100644 index 000000000..a39f4de97 --- /dev/null +++ b/rust/connlib/tunnel/src/peer_store.rs @@ -0,0 +1,141 @@ +use std::collections::HashMap; +use std::hash::Hash; +use std::net::IpAddr; + +use crate::peer::{PacketTransform, Peer}; +use ip_network::IpNetwork; +use ip_network_table::IpNetworkTable; + +pub struct PeerStore { + id_by_ip: IpNetworkTable, + peer_by_id: HashMap>, +} + +impl Default for PeerStore { + fn default() -> Self { + Self { + id_by_ip: IpNetworkTable::new(), + peer_by_id: HashMap::new(), + } + } +} + +impl PeerStore +where + TId: Hash + Eq + Clone + Copy, + TTransform: PacketTransform, +{ + pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer) -> bool) { + self.peer_by_id.retain(f); + self.id_by_ip + .retain(|_, id| self.peer_by_id.contains_key(id)); + } + + pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer> { + let peer = self.peer_by_id.get_mut(id)?; + + for ip in ips { + self.id_by_ip.insert(*ip, peer.conn_id); + peer.add_allowed_ip(*ip); + } + + Some(peer) + } + + pub fn insert(&mut self, peer: Peer) -> Option> { + self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id); + + self.peer_by_id.insert(peer.conn_id, peer) + } + + pub fn remove(&mut self, id: &TId) -> Option> { + self.id_by_ip.retain(|_, r_id| r_id != id); + self.peer_by_id.remove(id) + } + + pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer> { + let ip = self.id_by_ip.exact_match(ip)?; + self.peer_by_id.get(ip) + } + + pub fn get(&self, id: &TId) -> Option<&Peer> { + self.peer_by_id.get(id) + } + + pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer> { + self.peer_by_id.get_mut(id) + } + + pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer> { + let (_, id) = self.id_by_ip.longest_match(ip)?; + self.peer_by_id.get(id) + } + + pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer> { + let (_, id) = self.id_by_ip.longest_match(ip)?; + self.peer_by_id.get_mut(id) + } + + pub fn iter_mut(&mut self) -> impl Iterator> { + self.peer_by_id.values_mut() + } + + pub fn iter(&mut self) -> impl Iterator> { + self.peer_by_id.values() + } +} + +#[cfg(test)] +mod tests { + use crate::peer::{PacketTransformGateway, Peer}; + + use super::PeerStore; + + #[test] + fn can_insert_and_retrieve_peer() { + let mut peer_storage = PeerStore::default(); + peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + assert!(peer_storage.get(&0).is_some()); + } + + #[test] + fn can_insert_and_retrieve_peer_by_ip() { + let mut peer_storage = PeerStore::default(); + peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + + assert_eq!( + peer_storage + .peer_by_ip("100.0.0.1".parse().unwrap()) + .unwrap() + .conn_id, + 0 + ); + } + + #[test] + fn can_remove_peer() { + let mut peer_storage = PeerStore::default(); + peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + peer_storage.remove(&0); + + assert!(peer_storage.get(&0).is_none()); + assert!(peer_storage + .peer_by_ip("100.0.0.1".parse().unwrap()) + .is_none()) + } + + #[test] + fn inserting_peer_removes_previous_instances_of_same_id() { + let mut peer_storage = PeerStore::default(); + peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + + assert!(peer_storage.get(&0).is_some()); + assert!(peer_storage + .peer_by_ip("100.0.0.1".parse().unwrap()) + .is_none()) + } +} diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index ac56ac5b1..a5cbd1320 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -29,7 +29,6 @@ pub struct Eventloop { Result>, Either, >, - print_stats_timer: tokio::time::Interval, } impl Eventloop { @@ -41,7 +40,6 @@ impl Eventloop { tunnel, portal, resolve_tasks: futures_bounded::FuturesTupleSet::new(Duration::from_secs(60), 100), - print_stats_timer: tokio::time::interval(Duration::from_secs(10)), } } } @@ -104,7 +102,7 @@ impl Eventloop { Err(e) => { let client = req.client.id; - self.tunnel.cleanup_connection(client); + self.tunnel.cleanup_connection(&client); tracing::debug!(%client, "Connection request failed: {:#}", anyhow::Error::new(e)); continue; @@ -216,11 +214,6 @@ impl Eventloop { _ => {} } - if self.print_stats_timer.poll_tick(cx).is_ready() { - tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats()); - continue; - } - return Poll::Pending; } }