diff --git a/rust/Cargo.lock b/rust/Cargo.lock index f22b41ba8..1baa1c09a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1713,6 +1713,20 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -2318,6 +2332,7 @@ dependencies = [ "bufferpool", "bytes", "clap", + "dashmap", "dirs", "dns-types", "firezone-logging", @@ -6855,7 +6870,6 @@ dependencies = [ "ip-packet", "libc", "opentelemetry", - "parking_lot", "quinn-udp", "socket2", "tokio", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index afe677aee..f8297e2cd 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -63,6 +63,7 @@ chrono = { version = "0.4", default-features = false, features = ["std", "clock" clap = "4.5.41" client-shared = { path = "client-shared" } connlib-model = { path = "connlib/model" } +dashmap = "6.1.0" derive_more = "2.0.1" difference = "2.0.0" dirs = "6.0.0" diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index 919eeefe7..8f4227ca3 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -43,6 +43,7 @@ rtnetlink = { workspace = true } zbus = { workspace = true } # Can't use `zbus`'s `tokio` feature here, or it will break toast popups all the way over in `gui-client`. [target.'cfg(windows)'.dependencies] +dashmap = { workspace = true } ipconfig = "0.3.2" itertools = { workspace = true } known-folders = { workspace = true } diff --git a/rust/bin-shared/src/linux.rs b/rust/bin-shared/src/linux.rs index 2cfe29b1f..7956b4d94 100644 --- a/rust/bin-shared/src/linux.rs +++ b/rust/bin-shared/src/linux.rs @@ -2,16 +2,23 @@ use std::{io, net::SocketAddr}; use crate::FIREZONE_MARK; use nix::sys::socket::{setsockopt, sockopt}; -use socket_factory::{TcpSocket, UdpSocket}; +use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -pub fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result { +pub fn tcp_socket_factory(socket_addr: SocketAddr) -> io::Result { let socket = socket_factory::tcp(socket_addr)?; setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; Ok(socket) } -pub fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result { - let socket = socket_factory::udp(socket_addr)?; - setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; - Ok(socket) +#[derive(Default)] +pub struct UdpSocketFactory {} + +impl SocketFactory for UdpSocketFactory { + fn bind(&self, local: SocketAddr) -> io::Result { + let socket = socket_factory::udp(local)?; + setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; + Ok(socket) + } + + fn reset(&self) {} } diff --git a/rust/bin-shared/src/macos.rs b/rust/bin-shared/src/macos.rs index 43b7b4a9f..2641614e9 100644 --- a/rust/bin-shared/src/macos.rs +++ b/rust/bin-shared/src/macos.rs @@ -1,2 +1,16 @@ +use socket_factory::{SocketFactory, UdpSocket}; +use std::io; +use std::net::SocketAddr; + pub use socket_factory::tcp as tcp_socket_factory; -pub use socket_factory::udp as udp_socket_factory; + +#[derive(Default)] +pub struct UdpSocketFactory {} + +impl SocketFactory for UdpSocketFactory { + fn bind(&self, local: SocketAddr) -> io::Result { + socket_factory::udp(local) + } + + fn reset(&self) {} +} diff --git a/rust/bin-shared/src/windows.rs b/rust/bin-shared/src/windows.rs index fb085fccc..7dd26f9ea 100644 --- a/rust/bin-shared/src/windows.rs +++ b/rust/bin-shared/src/windows.rs @@ -1,6 +1,8 @@ use crate::TUNNEL_NAME; use anyhow::Result; +use dashmap::DashMap; use firezone_logging::err_with_src; +use socket_factory::SocketFactory; use socket_factory::{TcpSocket, UdpSocket}; use std::{ cmp::Ordering, @@ -8,6 +10,7 @@ use std::{ mem::MaybeUninit, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, ptr::null, + sync::Arc, }; use uuid::Uuid; use windows::Win32::NetworkManagement::{ @@ -87,7 +90,7 @@ pub mod error { pub const EPT_S_NOT_REGISTERED: HRESULT = HRESULT::from_win32(0x06D9); } -pub fn tcp_socket_factory(addr: &SocketAddr) -> io::Result { +pub fn tcp_socket_factory(addr: SocketAddr) -> io::Result { delete_all_routing_entries_matching(addr.ip())?; let route = get_best_non_tunnel_route(addr.ip())?; @@ -104,13 +107,61 @@ pub fn tcp_socket_factory(addr: &SocketAddr) -> io::Result { Ok(socket) } -pub fn udp_socket_factory(src_addr: &SocketAddr) -> io::Result { - let source_ip_resolver = |dst| Ok(get_best_non_tunnel_route(dst)?.addr); +/// A UDP socket factory with a src IP cache. +/// +/// On Windows, we need to manually compute and set the source IP for all UDP packets. +/// Determining this mapping requires several syscalls and therefore is too expensive to perform on every packet. +/// To speed things up, we therefore implement a cache across all UDP sockets created by a given [`UdpSocketFactory`]. +/// +/// This cache needs to be reset whenever we are roaming networks which happens in the [`SocketFactory::reset`] function. +/// +/// As most of the time we will only read from the cache, we use a [`DashMap`] (a concurrent hash-map). +pub struct UdpSocketFactory { + src_ip_cache: Arc>, +} - let socket = - socket_factory::udp(src_addr)?.with_source_ip_resolver(Box::new(source_ip_resolver)); +impl SocketFactory for UdpSocketFactory { + fn bind(&self, local: SocketAddr) -> io::Result { + let src_ip_cache = self.src_ip_cache.clone(); - Ok(socket) + let source_ip_resolver = move |dst| { + // First, try to get the existing entry (this is only using a read-lock internally so quite fast.) + if let Some(addr) = src_ip_cache.get(&dst) { + return Ok(*addr.value()); + } + + // If we don't have an entry, compute it. + let addr = get_best_non_tunnel_route(dst)?.addr; + + // Insert the result. + // This may be a possible data-race if two sockets want to resolve the same IP at the same time. + // It doesn't matter though as the result of `get_best_non_tunnel_route` should be deterministic. + src_ip_cache.insert(dst, addr); + + Ok(addr) + }; + + let socket = + socket_factory::udp(local)?.with_source_ip_resolver(Box::new(source_ip_resolver)); + + Ok(socket) + } + + fn reset(&self) { + self.src_ip_cache.clear() + } +} + +impl Default for UdpSocketFactory { + fn default() -> Self { + Self { + // This cache is expected to be quite small as there are only so many different IPs a client will talk to: + // - All connected Gateways + // - All DNS servers + // The capacity is only a guideline for the initial memory-consumption, it can also be outgrown. + src_ip_cache: Arc::new(DashMap::with_capacity(16)), + } + } } fn delete_all_routing_entries_matching(addr: IpAddr) -> io::Result<()> { diff --git a/rust/bin-shared/tests/no_packet_loops_tcp.rs b/rust/bin-shared/tests/no_packet_loops_tcp.rs index 80034b3e8..3c3ea6531 100644 --- a/rust/bin-shared/tests/no_packet_loops_tcp.rs +++ b/rust/bin-shared/tests/no_packet_loops_tcp.rs @@ -29,7 +29,7 @@ async fn no_packet_loops_tcp() { .unwrap(); let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([1, 1, 1, 1]), 80)); - let socket = tcp_socket_factory(&remote).unwrap(); + let socket = tcp_socket_factory(remote).unwrap(); let mut stream = socket.connect(remote).await.unwrap(); // Send an HTTP request diff --git a/rust/bin-shared/tests/no_packet_loops_udp.rs b/rust/bin-shared/tests/no_packet_loops_udp.rs index 2c46a3557..cc5c8e5aa 100644 --- a/rust/bin-shared/tests/no_packet_loops_udp.rs +++ b/rust/bin-shared/tests/no_packet_loops_udp.rs @@ -2,11 +2,12 @@ use bufferpool::BufferPool; use bytes::BytesMut; -use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory}; +use firezone_bin_shared::{TunDeviceManager, platform::UdpSocketFactory}; use gat_lending_iterator::LendingIterator as _; use ip_network::Ipv4Network; use ip_packet::Ecn; use socket_factory::DatagramOut; +use socket_factory::SocketFactory as _; use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, time::Duration, @@ -36,9 +37,12 @@ async fn no_packet_loops_udp() { .await .unwrap(); + let factory = UdpSocketFactory::default(); + // Make a socket. - let socket = - udp_socket_factory(&SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))).unwrap(); + let socket = factory + .bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))) + .unwrap(); // Send a STUN request. socket diff --git a/rust/connlib/phoenix-channel/src/lib.rs b/rust/connlib/phoenix-channel/src/lib.rs index cd583ac61..bf70ea50f 100644 --- a/rust/connlib/phoenix-channel/src/lib.rs +++ b/rust/connlib/phoenix-channel/src/lib.rs @@ -113,7 +113,7 @@ async fn connect( let mut errors = Vec::with_capacity(addresses.len()); for addr in addresses { - let Ok(socket) = socket_factory(&addr) else { + let Ok(socket) = socket_factory.bind(addr) else { continue; }; diff --git a/rust/connlib/socket-factory/Cargo.toml b/rust/connlib/socket-factory/Cargo.toml index 509b15212..08efa7f13 100644 --- a/rust/connlib/socket-factory/Cargo.toml +++ b/rust/connlib/socket-factory/Cargo.toml @@ -12,7 +12,6 @@ derive_more = { workspace = true, features = ["debug"] } gat-lending-iterator = { workspace = true } ip-packet = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } -parking_lot = { workspace = true } quinn-udp = { workspace = true } socket2 = { workspace = true } tokio = { workspace = true, features = ["net"] } diff --git a/rust/connlib/socket-factory/src/lib.rs b/rust/connlib/socket-factory/src/lib.rs index 9189dc20d..59e5199e6 100644 --- a/rust/connlib/socket-factory/src/lib.rs +++ b/rust/connlib/socket-factory/src/lib.rs @@ -4,9 +4,7 @@ use bytes::{Buf as _, BytesMut}; use gat_lending_iterator::LendingIterator; use ip_packet::{Ecn, Ipv4Header, Ipv6Header, UdpHeader}; use opentelemetry::KeyValue; -use parking_lot::Mutex; use quinn_udp::{EcnCodepoint, Transmit, UdpSockRef}; -use std::collections::HashMap; use std::io; use std::io::IoSliceMut; use std::ops::Deref; @@ -16,19 +14,30 @@ use std::{ }; use std::any::Any; -use std::collections::hash_map::Entry; use std::pin::Pin; use tokio::io::Interest; -pub trait SocketFactory: Fn(&SocketAddr) -> io::Result + Send + Sync + 'static {} +pub trait SocketFactory: Send + Sync + 'static { + fn bind(&self, local: SocketAddr) -> io::Result; + fn reset(&self); +} pub const SEND_BUFFER_SIZE: usize = ONE_MB; pub const RECV_BUFFER_SIZE: usize = 10 * ONE_MB; const ONE_MB: usize = 1024 * 1024; -impl SocketFactory for F where F: Fn(&SocketAddr) -> io::Result + Send + Sync + 'static {} +impl SocketFactory for F +where + F: Fn(SocketAddr) -> io::Result + Send + Sync + 'static, +{ + fn bind(&self, local: SocketAddr) -> io::Result { + (self)(local) + } -pub fn tcp(addr: &SocketAddr) -> io::Result { + fn reset(&self) {} +} + +pub fn tcp(addr: SocketAddr) -> io::Result { let socket = match addr { SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?, SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?, @@ -42,8 +51,8 @@ pub fn tcp(addr: &SocketAddr) -> io::Result { }) } -pub fn udp(std_addr: &SocketAddr) -> io::Result { - let addr = socket2::SockAddr::from(*std_addr); +pub fn udp(std_addr: SocketAddr) -> io::Result { + let addr = socket2::SockAddr::from(std_addr); let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?; // Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag @@ -151,9 +160,6 @@ pub struct UdpSocket { source_ip_resolver: Option std::io::Result + Send + Sync + 'static>>, - /// A cache of source IPs by their destination IPs. - src_by_dst_cache: Mutex>, - /// A buffer pool for batches of incoming UDP packets. buffer_pool: BufferPool>, @@ -171,7 +177,6 @@ impl UdpSocket { port, inner, source_ip_resolver: None, - src_by_dst_cache: Default::default(), buffer_pool: BufferPool::new( u16::MAX as usize, match socket_addr.ip() { @@ -215,8 +220,7 @@ impl UdpSocket { /// Configures a new source IP resolver for this UDP socket. /// /// In case [`DatagramOut::src`] is [`None`], this function will be used to set a source IP given the destination IP of the datagram. - /// The resulting IPs will be cached. - /// To evict this cache, drop the [`UdpSocket`] and make a new one. + /// If set, this function will be called for _every_ packet and should therefore be fast. /// /// Errors during resolution result in the packet being dropped. pub fn with_source_ip_resolver( @@ -469,24 +473,13 @@ impl UdpSocket { /// Attempt to resolve the source IP to use for sending to the given destination IP. fn resolve_source_for(&self, dst: IpAddr) -> std::io::Result> { - let src = match self.src_by_dst_cache.lock().entry(dst) { - Entry::Occupied(occ) => *occ.get(), - Entry::Vacant(vac) => { - // Caching errors could be a good idea to not incur in multiple calls for the resolver which can be costly - // For some cases like hosts ipv4-only stack trying to send ipv6 packets this can happen quite often but doing this is also a risk - // that in case that the adapter for some reason is temporarily unavailable it'd prevent the system from recovery. - - let Some(resolver) = self.source_ip_resolver.as_ref() else { - // If we don't have a resolver, let the operating system decide. - return Ok(None); - }; - - let src = (resolver)(dst)?; - - *vac.insert(src) - } + let Some(resolver) = self.source_ip_resolver.as_ref() else { + // If we don't have a resolver, let the operating system decide. + return Ok(None); }; + let src = (resolver)(dst)?; + Ok(Some(src)) } } diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 2a4e32cbc..5baf1d656 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -321,6 +321,8 @@ impl Io { } pub fn reset(&mut self) { + self.tcp_socket_factory.reset(); + self.udp_socket_factory.reset(); self.sockets.rebind(self.udp_socket_factory.clone()); self.gso_queue.clear(); self.dns_queries = FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000); diff --git a/rust/connlib/tunnel/src/io/tcp_dns.rs b/rust/connlib/tunnel/src/io/tcp_dns.rs index ee575db59..1f5a52391 100644 --- a/rust/connlib/tunnel/src/io/tcp_dns.rs +++ b/rust/connlib/tunnel/src/io/tcp_dns.rs @@ -10,7 +10,7 @@ pub async fn send( ) -> io::Result { tracing::trace!(target: "wire::dns::recursive::tcp", %server, domain = %query.domain()); - let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. + let tcp_socket = factory.bind(server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. let mut tcp_stream = tcp_socket.connect(server).await?; let query = query.into_bytes(); diff --git a/rust/connlib/tunnel/src/io/udp_dns.rs b/rust/connlib/tunnel/src/io/udp_dns.rs index 3e9727d69..14d0bea1c 100644 --- a/rust/connlib/tunnel/src/io/udp_dns.rs +++ b/rust/connlib/tunnel/src/io/udp_dns.rs @@ -21,7 +21,7 @@ pub async fn send( // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. const BUF_SIZE: usize = 1500; - let udp_socket = factory(&bind_addr)?; + let udp_socket = factory.bind(bind_addr)?; let response = udp_socket .handshake::(server, &query.into_bytes()) diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index fa2d02b4a..c07b6d717 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -311,7 +311,7 @@ fn listen( let mut last_err = None; for addr in addresses { - match sf(addr) { + match sf.bind(*addr) { Ok(s) => return Ok(s), Err(e) => { tracing::debug!(%addr, "Failed to listen on UDP socket: {e}"); diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 914bf18e9..6b2bcd813 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -4,7 +4,7 @@ use backoff::ExponentialBackoffBuilder; use clap::Parser; use firezone_bin_shared::{ TunDeviceManager, device_id, http_health_check, - platform::{tcp_socket_factory, udp_socket_factory}, + platform::{UdpSocketFactory, tcp_socket_factory}, }; use firezone_telemetry::{ @@ -165,7 +165,7 @@ async fn try_main(cli: Cli, telemetry: &mut Telemetry) -> Result<()> { let mut tunnel = GatewayTunnel::new( Arc::new(tcp_socket_factory), - Arc::new(udp_socket_factory), + Arc::new(UdpSocketFactory::default()), nameservers, ); let portal = PhoenixChannel::disconnected( diff --git a/rust/gui-client/src-tauri/src/service.rs b/rust/gui-client/src-tauri/src/service.rs index e8e6ea969..2af1bf535 100644 --- a/rust/gui-client/src-tauri/src/service.rs +++ b/rust/gui-client/src-tauri/src/service.rs @@ -10,7 +10,7 @@ use firezone_bin_shared::{ DnsControlMethod, DnsController, TunDeviceManager, device_id::{self, DeviceId}, device_info, known_dirs, - platform::{tcp_socket_factory, udp_socket_factory}, + platform::{UdpSocketFactory, tcp_socket_factory}, signals, }; use firezone_logging::{FilterReloadHandle, err_with_src}; @@ -636,7 +636,7 @@ impl<'a> Handler<'a> { let dns = self.dns_controller.system_resolvers(); let (connlib, event_stream) = client_shared::Session::connect( Arc::new(tcp_socket_factory), - Arc::new(udp_socket_factory), + Arc::new(UdpSocketFactory::default()), portal, tokio::runtime::Handle::current(), ); diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index 721a8d4c6..f1699dcb3 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -8,7 +8,7 @@ use clap::Parser; use firezone_bin_shared::{ DnsControlMethod, DnsController, TOKEN_ENV_KEY, TunDeviceManager, device_id, device_info, new_dns_notifier, new_network_notifier, - platform::{tcp_socket_factory, udp_socket_factory}, + platform::{UdpSocketFactory, tcp_socket_factory}, signals, }; use firezone_telemetry::{Telemetry, analytics, otel}; @@ -265,7 +265,7 @@ fn main() -> Result<()> { )?; let (session, mut event_stream) = client_shared::Session::connect( Arc::new(tcp_socket_factory), - Arc::new(udp_socket_factory), + Arc::new(UdpSocketFactory::default()), portal, rt.handle().clone(), );