refactor(windows): share src IP cache across UDP sockets (#9976)

When looking through customer logs, we see a lot of "Resolved best route
outside of tunnel" messages. Those get logged every time we need to
rerun our re-implementation of Windows' weighting algorithm as to which
source interface / IP a packet should be sent from.

Currently, this gets cached in every socket instance so for the
peer-to-peer socket, this is only computed once per destination IP.
However, for DNS queries, we make a new socket for every query. Using a
new source port for each DNS query is recommended to avoid fingerprinting
of DNS queries. Using a new socket also means that we need to re-run this
algorithm every time we make a DNS query, which is why we see this log so
often.

To fix this, we need to share this cache across all UDP sockets. Cache
invalidation is one of the hardest problems in computer science and this
instance is no different. This cache needs to be reset every time we
roam as that changes the weighting of which source interface to use.

To achieve this, we extend the `SocketFactory` trait with a `reset`
method. This method is called whenever we roam and can then reset a
shared cache inside the `UdpSocketFactory`. The "source IP resolver"
function that is passed to the UDP socket now simply accesses this
shared cache and inserts a new entry when it needs to resolve the IP.

As an added benefit, this may speed up DNS queries on Windows a bit
(although I haven't benchmarked it). It should certainly drastically
reduce the number of syscalls we make on Windows.
This commit is contained in:
Thomas Eizinger
2025-07-24 11:36:53 +10:00
committed by GitHub
parent 409459f11c
commit 301d2137e5
18 changed files with 145 additions and 59 deletions

16
rust/Cargo.lock generated
View File

@@ -1713,6 +1713,20 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
@@ -2318,6 +2332,7 @@ dependencies = [
"bufferpool",
"bytes",
"clap",
"dashmap",
"dirs",
"dns-types",
"firezone-logging",
@@ -6855,7 +6870,6 @@ dependencies = [
"ip-packet",
"libc",
"opentelemetry",
"parking_lot",
"quinn-udp",
"socket2",
"tokio",

View File

@@ -63,6 +63,7 @@ chrono = { version = "0.4", default-features = false, features = ["std", "clock"
clap = "4.5.41"
client-shared = { path = "client-shared" }
connlib-model = { path = "connlib/model" }
dashmap = "6.1.0"
derive_more = "2.0.1"
difference = "2.0.0"
dirs = "6.0.0"

View File

@@ -43,6 +43,7 @@ rtnetlink = { workspace = true }
zbus = { workspace = true } # Can't use `zbus`'s `tokio` feature here, or it will break toast popups all the way over in `gui-client`.
[target.'cfg(windows)'.dependencies]
dashmap = { workspace = true }
ipconfig = "0.3.2"
itertools = { workspace = true }
known-folders = { workspace = true }

View File

@@ -2,16 +2,23 @@ use std::{io, net::SocketAddr};
use crate::FIREZONE_MARK;
use nix::sys::socket::{setsockopt, sockopt};
use socket_factory::{TcpSocket, UdpSocket};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
pub fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<TcpSocket> {
pub fn tcp_socket_factory(socket_addr: SocketAddr) -> io::Result<TcpSocket> {
let socket = socket_factory::tcp(socket_addr)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)
}
pub fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<UdpSocket> {
let socket = socket_factory::udp(socket_addr)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)
/// A [`SocketFactory`] for UDP sockets that sets [`FIREZONE_MARK`] on every socket it creates.
///
/// This factory is an empty struct and keeps no state, hence [`SocketFactory::reset`] is a no-op.
#[derive(Default)]
pub struct UdpSocketFactory {}
impl SocketFactory<UdpSocket> for UdpSocketFactory {
/// Creates a UDP socket for `local` via [`socket_factory::udp`] and sets its `Mark` socket option to [`FIREZONE_MARK`].
///
/// # Errors
///
/// Returns an error if creating the socket or setting the socket option fails.
fn bind(&self, local: SocketAddr) -> io::Result<UdpSocket> {
let socket = socket_factory::udp(local)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)
}
/// Nothing to reset: this factory keeps no state between calls to `bind`.
fn reset(&self) {}
}

View File

@@ -1,2 +1,16 @@
use socket_factory::{SocketFactory, UdpSocket};
use std::io;
use std::net::SocketAddr;
pub use socket_factory::tcp as tcp_socket_factory;
pub use socket_factory::udp as udp_socket_factory;
/// A [`SocketFactory`] for UDP sockets that forwards straight to [`socket_factory::udp`].
///
/// This factory is an empty struct and keeps no state, hence [`SocketFactory::reset`] is a no-op.
#[derive(Default)]
pub struct UdpSocketFactory {}
impl SocketFactory<UdpSocket> for UdpSocketFactory {
/// Creates a UDP socket for `local` without any additional configuration.
///
/// # Errors
///
/// Returns whatever error [`socket_factory::udp`] reports.
fn bind(&self, local: SocketAddr) -> io::Result<UdpSocket> {
socket_factory::udp(local)
}
/// Nothing to reset: this factory keeps no state between calls to `bind`.
fn reset(&self) {}
}

View File

@@ -1,6 +1,8 @@
use crate::TUNNEL_NAME;
use anyhow::Result;
use dashmap::DashMap;
use firezone_logging::err_with_src;
use socket_factory::SocketFactory;
use socket_factory::{TcpSocket, UdpSocket};
use std::{
cmp::Ordering,
@@ -8,6 +10,7 @@ use std::{
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
ptr::null,
sync::Arc,
};
use uuid::Uuid;
use windows::Win32::NetworkManagement::{
@@ -87,7 +90,7 @@ pub mod error {
pub const EPT_S_NOT_REGISTERED: HRESULT = HRESULT::from_win32(0x06D9);
}
pub fn tcp_socket_factory(addr: &SocketAddr) -> io::Result<TcpSocket> {
pub fn tcp_socket_factory(addr: SocketAddr) -> io::Result<TcpSocket> {
delete_all_routing_entries_matching(addr.ip())?;
let route = get_best_non_tunnel_route(addr.ip())?;
@@ -104,13 +107,61 @@ pub fn tcp_socket_factory(addr: &SocketAddr) -> io::Result<TcpSocket> {
Ok(socket)
}
pub fn udp_socket_factory(src_addr: &SocketAddr) -> io::Result<UdpSocket> {
let source_ip_resolver = |dst| Ok(get_best_non_tunnel_route(dst)?.addr);
/// A UDP socket factory with a src IP cache.
///
/// On Windows, we need to manually compute and set the source IP for all UDP packets.
/// Determining this mapping requires several syscalls and is therefore too expensive to perform on every packet.
/// To speed things up, we keep a single cache that is shared by all UDP sockets created by a given [`UdpSocketFactory`].
///
/// This cache needs to be reset whenever we roam between networks; this is done by [`SocketFactory::reset`].
///
/// As most of the time we will only read from the cache, we use a [`DashMap`] (a concurrent hash-map).
pub struct UdpSocketFactory {
/// Maps a destination IP to the source IP that packets to it should be sent from.
///
/// Wrapped in an [`Arc`] so every socket's source IP resolver can hold a handle to the same map.
src_ip_cache: Arc<DashMap<IpAddr, IpAddr>>,
}
let socket =
socket_factory::udp(src_addr)?.with_source_ip_resolver(Box::new(source_ip_resolver));
impl SocketFactory<UdpSocket> for UdpSocketFactory {
/// Creates a UDP socket for `local` whose source IP resolver is backed by the factory-wide src IP cache.
///
/// # Errors
///
/// Returns an error if [`socket_factory::udp`] fails to create the socket.
/// Errors of the resolver itself only surface later, when the resolver is invoked.
fn bind(&self, local: SocketAddr) -> io::Result<UdpSocket> {
// Cloning the `Arc` hands the closure below its own handle to the _same_ cache.
let src_ip_cache = self.src_ip_cache.clone();
Ok(socket)
let source_ip_resolver = move |dst| {
// First, try to get the existing entry (this is only using a read-lock internally so quite fast.)
if let Some(addr) = src_ip_cache.get(&dst) {
return Ok(*addr.value());
}
// If we don't have an entry, compute it.
// Errors are propagated to the caller via `?` and are never stored in the cache.
let addr = get_best_non_tunnel_route(dst)?.addr;
// Insert the result.
// Two sockets may race to resolve the same IP at the same time, in which case both compute and insert an entry.
// It doesn't matter though as the result of `get_best_non_tunnel_route` should be deterministic.
// NOTE(review): a resolution that started before `reset` may insert its (possibly stale) result after the cache was cleared - confirm this is acceptable.
src_ip_cache.insert(dst, addr);
Ok(addr)
};
let socket =
socket_factory::udp(local)?.with_source_ip_resolver(Box::new(source_ip_resolver));
Ok(socket)
}
/// Clears the src IP cache so that the next lookup for each destination IP is computed afresh.
fn reset(&self) {
self.src_ip_cache.clear()
}
}
impl Default for UdpSocketFactory {
/// Creates a factory with an empty src IP cache.
fn default() -> Self {
Self {
// This cache is expected to be quite small as there are only so many different IPs a client will talk to:
// - All connected Gateways
// - All DNS servers
// The capacity is only a hint for the initial allocation; the map grows beyond it if more entries are inserted.
src_ip_cache: Arc::new(DashMap::with_capacity(16)),
}
}
}
fn delete_all_routing_entries_matching(addr: IpAddr) -> io::Result<()> {

View File

@@ -29,7 +29,7 @@ async fn no_packet_loops_tcp() {
.unwrap();
let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([1, 1, 1, 1]), 80));
let socket = tcp_socket_factory(&remote).unwrap();
let socket = tcp_socket_factory(remote).unwrap();
let mut stream = socket.connect(remote).await.unwrap();
// Send an HTTP request

View File

@@ -2,11 +2,12 @@
use bufferpool::BufferPool;
use bytes::BytesMut;
use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory};
use firezone_bin_shared::{TunDeviceManager, platform::UdpSocketFactory};
use gat_lending_iterator::LendingIterator as _;
use ip_network::Ipv4Network;
use ip_packet::Ecn;
use socket_factory::DatagramOut;
use socket_factory::SocketFactory as _;
use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
time::Duration,
@@ -36,9 +37,12 @@ async fn no_packet_loops_udp() {
.await
.unwrap();
let factory = UdpSocketFactory::default();
// Make a socket.
let socket =
udp_socket_factory(&SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))).unwrap();
let socket = factory
.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))
.unwrap();
// Send a STUN request.
socket

View File

@@ -113,7 +113,7 @@ async fn connect(
let mut errors = Vec::with_capacity(addresses.len());
for addr in addresses {
let Ok(socket) = socket_factory(&addr) else {
let Ok(socket) = socket_factory.bind(addr) else {
continue;
};

View File

@@ -12,7 +12,6 @@ derive_more = { workspace = true, features = ["debug"] }
gat-lending-iterator = { workspace = true }
ip-packet = { workspace = true }
opentelemetry = { workspace = true, features = ["metrics"] }
parking_lot = { workspace = true }
quinn-udp = { workspace = true }
socket2 = { workspace = true }
tokio = { workspace = true, features = ["net"] }

View File

@@ -4,9 +4,7 @@ use bytes::{Buf as _, BytesMut};
use gat_lending_iterator::LendingIterator;
use ip_packet::{Ecn, Ipv4Header, Ipv6Header, UdpHeader};
use opentelemetry::KeyValue;
use parking_lot::Mutex;
use quinn_udp::{EcnCodepoint, Transmit, UdpSockRef};
use std::collections::HashMap;
use std::io;
use std::io::IoSliceMut;
use std::ops::Deref;
@@ -16,19 +14,30 @@ use std::{
};
use std::any::Any;
use std::collections::hash_map::Entry;
use std::pin::Pin;
use tokio::io::Interest;
pub trait SocketFactory<S>: Fn(&SocketAddr) -> io::Result<S> + Send + Sync + 'static {}
/// A factory for sockets of type `S`.
///
/// Implementations may keep state that is shared across all sockets they create (for example a cache).
/// [`SocketFactory::reset`] is the hook for discarding such state.
pub trait SocketFactory<S>: Send + Sync + 'static {
/// Creates a new socket for the given address.
///
/// NOTE(review): despite the parameter name, the TCP call sites pass the address they subsequently connect to - confirm the intended meaning for non-UDP factories.
fn bind(&self, local: SocketAddr) -> io::Result<S>;
/// Resets any internal state of the factory.
///
/// Intended to be called whenever we roam, so that state derived from the previous network can be dropped.
fn reset(&self);
}
pub const SEND_BUFFER_SIZE: usize = ONE_MB;
pub const RECV_BUFFER_SIZE: usize = 10 * ONE_MB;
const ONE_MB: usize = 1024 * 1024;
impl<F, S> SocketFactory<S> for F where F: Fn(&SocketAddr) -> io::Result<S> + Send + Sync + 'static {}
/// Allows any `Fn(SocketAddr) -> io::Result<S>` (plain functions as well as closures) to be used as a [`SocketFactory`].
impl<F, S> SocketFactory<S> for F
where
F: Fn(SocketAddr) -> io::Result<S> + Send + Sync + 'static,
{
/// Invokes the wrapped function with `local` and returns its result unchanged.
fn bind(&self, local: SocketAddr) -> io::Result<S> {
(self)(local)
}
pub fn tcp(addr: &SocketAddr) -> io::Result<TcpSocket> {
// No-op: this impl has no way to reset whatever state the wrapped function may hold.
fn reset(&self) {}
}
pub fn tcp(addr: SocketAddr) -> io::Result<TcpSocket> {
let socket = match addr {
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
@@ -42,8 +51,8 @@ pub fn tcp(addr: &SocketAddr) -> io::Result<TcpSocket> {
})
}
pub fn udp(std_addr: &SocketAddr) -> io::Result<UdpSocket> {
let addr = socket2::SockAddr::from(*std_addr);
pub fn udp(std_addr: SocketAddr) -> io::Result<UdpSocket> {
let addr = socket2::SockAddr::from(std_addr);
let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?;
// Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag
@@ -151,9 +160,6 @@ pub struct UdpSocket {
source_ip_resolver:
Option<Box<dyn Fn(IpAddr) -> std::io::Result<IpAddr> + Send + Sync + 'static>>,
/// A cache of source IPs by their destination IPs.
src_by_dst_cache: Mutex<HashMap<IpAddr, IpAddr>>,
/// A buffer pool for batches of incoming UDP packets.
buffer_pool: BufferPool<Vec<u8>>,
@@ -171,7 +177,6 @@ impl UdpSocket {
port,
inner,
source_ip_resolver: None,
src_by_dst_cache: Default::default(),
buffer_pool: BufferPool::new(
u16::MAX as usize,
match socket_addr.ip() {
@@ -215,8 +220,7 @@ impl UdpSocket {
/// Configures a new source IP resolver for this UDP socket.
///
/// In case [`DatagramOut::src`] is [`None`], this function will be used to set a source IP given the destination IP of the datagram.
/// The resulting IPs will be cached.
/// To evict this cache, drop the [`UdpSocket`] and make a new one.
/// If set, this function will be called for _every_ packet and should therefore be fast.
///
/// Errors during resolution result in the packet being dropped.
pub fn with_source_ip_resolver(
@@ -469,24 +473,13 @@ impl UdpSocket {
/// Attempt to resolve the source IP to use for sending to the given destination IP.
fn resolve_source_for(&self, dst: IpAddr) -> std::io::Result<Option<IpAddr>> {
let src = match self.src_by_dst_cache.lock().entry(dst) {
Entry::Occupied(occ) => *occ.get(),
Entry::Vacant(vac) => {
// Caching errors could be a good idea to not incur in multiple calls for the resolver which can be costly
// For some cases like hosts ipv4-only stack trying to send ipv6 packets this can happen quite often but doing this is also a risk
// that in case that the adapter for some reason is temporarily unavailable it'd prevent the system from recovery.
let Some(resolver) = self.source_ip_resolver.as_ref() else {
// If we don't have a resolver, let the operating system decide.
return Ok(None);
};
let src = (resolver)(dst)?;
*vac.insert(src)
}
let Some(resolver) = self.source_ip_resolver.as_ref() else {
// If we don't have a resolver, let the operating system decide.
return Ok(None);
};
let src = (resolver)(dst)?;
Ok(Some(src))
}
}

View File

@@ -321,6 +321,8 @@ impl Io {
}
pub fn reset(&mut self) {
self.tcp_socket_factory.reset();
self.udp_socket_factory.reset();
self.sockets.rebind(self.udp_socket_factory.clone());
self.gso_queue.clear();
self.dns_queries = FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000);

View File

@@ -10,7 +10,7 @@ pub async fn send(
) -> io::Result<dns_types::Response> {
tracing::trace!(target: "wire::dns::recursive::tcp", %server, domain = %query.domain());
let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
let tcp_socket = factory.bind(server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
let mut tcp_stream = tcp_socket.connect(server).await?;
let query = query.into_bytes();

View File

@@ -21,7 +21,7 @@ pub async fn send(
// To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet.
const BUF_SIZE: usize = 1500;
let udp_socket = factory(&bind_addr)?;
let udp_socket = factory.bind(bind_addr)?;
let response = udp_socket
.handshake::<BUF_SIZE>(server, &query.into_bytes())

View File

@@ -311,7 +311,7 @@ fn listen(
let mut last_err = None;
for addr in addresses {
match sf(addr) {
match sf.bind(*addr) {
Ok(s) => return Ok(s),
Err(e) => {
tracing::debug!(%addr, "Failed to listen on UDP socket: {e}");

View File

@@ -4,7 +4,7 @@ use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use firezone_bin_shared::{
TunDeviceManager, device_id, http_health_check,
platform::{tcp_socket_factory, udp_socket_factory},
platform::{UdpSocketFactory, tcp_socket_factory},
};
use firezone_telemetry::{
@@ -165,7 +165,7 @@ async fn try_main(cli: Cli, telemetry: &mut Telemetry) -> Result<()> {
let mut tunnel = GatewayTunnel::new(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
Arc::new(UdpSocketFactory::default()),
nameservers,
);
let portal = PhoenixChannel::disconnected(

View File

@@ -10,7 +10,7 @@ use firezone_bin_shared::{
DnsControlMethod, DnsController, TunDeviceManager,
device_id::{self, DeviceId},
device_info, known_dirs,
platform::{tcp_socket_factory, udp_socket_factory},
platform::{UdpSocketFactory, tcp_socket_factory},
signals,
};
use firezone_logging::{FilterReloadHandle, err_with_src};
@@ -636,7 +636,7 @@ impl<'a> Handler<'a> {
let dns = self.dns_controller.system_resolvers();
let (connlib, event_stream) = client_shared::Session::connect(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
Arc::new(UdpSocketFactory::default()),
portal,
tokio::runtime::Handle::current(),
);

View File

@@ -8,7 +8,7 @@ use clap::Parser;
use firezone_bin_shared::{
DnsControlMethod, DnsController, TOKEN_ENV_KEY, TunDeviceManager, device_id, device_info,
new_dns_notifier, new_network_notifier,
platform::{tcp_socket_factory, udp_socket_factory},
platform::{UdpSocketFactory, tcp_socket_factory},
signals,
};
use firezone_telemetry::{Telemetry, analytics, otel};
@@ -265,7 +265,7 @@ fn main() -> Result<()> {
)?;
let (session, mut event_stream) = client_shared::Session::connect(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
Arc::new(UdpSocketFactory::default()),
portal,
rt.handle().clone(),
);