From c3a45f53df46fcd7b82859e5d595155ea3480eff Mon Sep 17 00:00:00 2001 From: Gabi Date: Mon, 29 Jul 2024 19:25:42 -0300 Subject: [PATCH] fix(connlib): prevent routing loops on windows (#6032) In `connlib`, traffic is sent through sockets via one of three ways: 1. Direct p2p traffic between clients and gateways: For these, we always explicitly set the source IP (and thus interface). 2. UDP traffic to the relays: For these, we let the OS pick an appropriate source interface. 3. WebSocket traffic over TCP to the portal: For this too, we let the OS pick the source interface. For (2) and (3), it is possible to run into routing loops, depending on the routes that we have configured on the TUN device. In Linux, we can prevent routing loops by marking a socket [0] and repeating the mark when we add routes [1]. Packets sent via a marked socket won't be routed by a rule that contains this mark. On Android, we can do something similar by "protecting" a socket via a syscall on the Java side [2]. On Windows, routing works slightly different. There, the source interface is determined based on a computed metric [3] [4]. To prevent routing loops on Windows, we thus need to find the "next best" interface after our TUN interface. We can achieve this with a combination of several syscalls: 1. List all interfaces on the machine 2. Ask Windows for the best route on each interface, except our TUN interface. 3. Sort by Windows' routing metric and pick the lowest one (lower is better). Thanks to the abstraction of `SocketFactory` that we already previously introduced, Integrating this into `connlib` isn't too difficult: 1. For TCP sockets, we simply resolve the best route after creating the socket and then bind it to that local interface. That way, all packets will always going via that interface, regardless of which routes are present on our TUN interface. 2. UDP is connection-less so we need to decide per-packet, which interface to use. "Pick the best interface for me" is modelled in `connlib` via the `DatagramOut::src` field being `None`. - To ensure those packets don't cause a routing loop, we introduce a "source IP resolver" for our `UdpSocket`. This function gets called every time we need to send a packet without a source IP. - For improved performance, we cache these results. The Windows client uses this source IP resolver to use the above devised strategy to find a suitable source IP. - In case the source IP resolution fails, we don't send the packet. This is important, otherwise, the kernel might choose our TUN interface again and trigger a routing loop. The last remark to make here is that this also works for connection roaming. The TCP socket gets thrown away when we reconnect to the portal. Thus, the new socket will pick the new best interface as it is re-created. The UDP sockets also get thrown away as part of roaming. That clears the above cache which is what we want: Upon roaming, the best interface for a given destination IP will likely have changed. [0]: https://github.com/firezone/firezone/blob/59014a962240b32799b932a23f3b7e3ec30b20ee/rust/headless-client/src/linux.rs#L19-L29 [1]: https://github.com/firezone/firezone/blob/59014a962240b32799b932a23f3b7e3ec30b20ee/rust/bin-shared/src/tun_device_manager/linux.rs#L204-L224 [2]: https://github.com/firezone/firezone/blob/59014a962240b32799b932a23f3b7e3ec30b20ee/rust/connlib/clients/android/src/lib.rs#L535-L549 [3]: https://learn.microsoft.com/en-us/previous-versions/technet-magazine/cc137807(v=msdn.10)?redirectedfrom=MSDN [4]: https://learn.microsoft.com/en-us/windows-server/networking/technologies/network-subsystem/net-sub-interface-metric Fixes: #5955. --------- Signed-off-by: Thomas Eizinger Co-authored-by: Thomas Eizinger --- rust/Cargo.lock | 8 +- rust/bin-shared/Cargo.toml | 2 +- rust/bin-shared/src/lib.rs | 3 + .../src/tun_device_manager/windows.rs | 4 +- rust/connlib/shared/Cargo.toml | 1 - rust/headless-client/Cargo.toml | 7 +- rust/headless-client/src/windows.rs | 279 +++++++++++++++++- rust/socket-factory/src/lib.rs | 71 ++++- 8 files changed, 360 insertions(+), 15 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 2c304a910..9698eabc2 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1084,7 +1084,6 @@ dependencies = [ "proptest", "rand 0.8.5", "rand_core 0.6.4", - "ring", "secrecy", "serde", "serde_json", @@ -1907,6 +1906,7 @@ dependencies = [ "firezone-bin-shared", "futures", "git-version", + "hex-literal", "humantime", "ip_network", "ipconfig", @@ -2644,6 +2644,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b53d6a634507c5d9fdee77261ae54a8d1ff7887f5304389025b03c3292a1756" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "hickory-proto" version = "0.24.0" diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index be33c03f4..ae11254dd 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -27,7 +27,7 @@ netlink-packet-route = { version = "0.19", default-features = false } rtnetlink = { workspace = true } libc = "0.2" -[target.'cfg(target_os = "windows")'.dependencies] +[target.'cfg(windows)'.dependencies] known-folders = "1.1.0" ring = "0.17" uuid = { version = "1.10.0", features = ["v4"] } diff --git a/rust/bin-shared/src/lib.rs b/rust/bin-shared/src/lib.rs index 69623a4b7..c5cb08378 100644 --- a/rust/bin-shared/src/lib.rs +++ b/rust/bin-shared/src/lib.rs @@ -9,6 +9,9 @@ use tracing_subscriber::{ fmt, prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Layer, Registry, }; +// wintun automatically append " Tunnel" to this +pub const TUNNEL_NAME: &str = "Firezone"; + /// Bundle ID / App ID that the client uses to distinguish itself from other programs on the system /// /// e.g. In ProgramData and AppData we use this to name our subdirectories for configs and data, diff --git a/rust/bin-shared/src/tun_device_manager/windows.rs b/rust/bin-shared/src/tun_device_manager/windows.rs index 31964b9ba..9d8dc7367 100644 --- a/rust/bin-shared/src/tun_device_manager/windows.rs +++ b/rust/bin-shared/src/tun_device_manager/windows.rs @@ -1,4 +1,5 @@ use crate::windows::CREATE_NO_WINDOW; +use crate::TUNNEL_NAME; use anyhow::{Context as _, Result}; use connlib_shared::DEFAULT_MTU; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; @@ -27,9 +28,6 @@ use windows::Win32::{ }; use wintun::Adapter; -// wintun automatically append " Tunnel" to this -pub(crate) const TUNNEL_NAME: &str = "Firezone"; - /// The ring buffer size used for Wintun. /// /// Must be a power of two within a certain range diff --git a/rust/connlib/shared/Cargo.toml b/rust/connlib/shared/Cargo.toml index acd6815b7..086c54099 100644 --- a/rust/connlib/shared/Cargo.toml +++ b/rust/connlib/shared/Cargo.toml @@ -25,7 +25,6 @@ phoenix-channel = { workspace = true } proptest = { version = "1", optional = true } rand = { version = "0.8", default-features = false, features = ["std"] } rand_core = { version = "0.6.4", default-features = false, features = ["std"] } -ring = "0.17" secrecy = { workspace = true, features = ["serde", "bytes"] } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index a8ee398f5..b4b33b1ab 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -26,7 +26,7 @@ socket-factory = { workspace = true } thiserror = { version = "1.0", default-features = false } # This actually relies on many other features in Tokio, so this will probably # fail to build outside the workspace. -tokio = { workspace = true, features = ["macros", "signal", "process"] } +tokio = { workspace = true, features = ["macros", "signal", "process", "time"] } tokio-stream = "0.1.15" tokio-util = { version = "0.7.11", features = ["codec"] } tracing = { workspace = true } @@ -40,6 +40,9 @@ tempfile = "3.10.1" [target.'cfg(target_os = "linux")'.dev-dependencies] mutants = "0.0.3" # Needed to mark functions as exempt from `cargo-mutants` testing +[target.'cfg(target_os = "windows")'.dev-dependencies] +hex-literal = "0.4.1" + [target.'cfg(target_os = "linux")'.dependencies] dirs = "5.0.1" libc = "0.2.150" @@ -59,7 +62,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } windows-service = "0.7.0" winreg = "0.52.0" -[target.'cfg(target_os = "windows")'.dependencies.windows] +[target.'cfg(windows)'.dependencies.windows] version = "0.57.0" features = [ # For DNS control and route control diff --git a/rust/headless-client/src/windows.rs b/rust/headless-client/src/windows.rs index cb7e3f405..f4966c16f 100644 --- a/rust/headless-client/src/windows.rs +++ b/rust/headless-client/src/windows.rs @@ -5,10 +5,207 @@ //! We must tell Windows explicitly when our service is stopping. use anyhow::Result; -use std::path::{Path, PathBuf}; +use std::{ + cmp::Ordering, + io, + mem::MaybeUninit, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + path::{Path, PathBuf}, + ptr::null, +}; -pub(crate) use socket_factory::tcp as tcp_socket_factory; -pub(crate) use socket_factory::udp as udp_socket_factory; +use windows::Win32::NetworkManagement::{IpHelper::GetAdaptersAddresses, Ndis::NET_LUID_LH}; +use windows::Win32::Networking::WinSock::SOCKADDR_INET; +use windows::Win32::{ + NetworkManagement::IpHelper::{ + GetBestRoute2, GET_ADAPTERS_ADDRESSES_FLAGS, IP_ADAPTER_ADDRESSES_LH, MIB_IPFORWARD_ROW2, + }, + Networking::WinSock::{ADDRESS_FAMILY, AF_UNSPEC}, +}; + +use firezone_bin_shared::TUNNEL_NAME; +use socket_factory::{TcpSocket, UdpSocket}; + +pub fn tcp_socket_factory(addr: &SocketAddr) -> io::Result { + let local = get_best_non_tunnel_route(addr.ip())?; + + let socket = socket_factory::tcp(addr)?; + socket.bind((local, 0).into())?; // To avoid routing loops, all TCP sockets are bound to the "best" source IP. + + Ok(socket) +} + +pub fn udp_socket_factory(src_addr: &SocketAddr) -> io::Result { + let source_ip_resolver = |dst| Ok(Some(get_best_non_tunnel_route(dst)?)); + + let socket = + socket_factory::udp(src_addr)?.with_source_ip_resolver(Box::new(source_ip_resolver)); + + Ok(socket) +} + +struct Adapters { + _buffer: Vec, + next: *const IP_ADAPTER_ADDRESSES_LH, +} + +impl Iterator for Adapters { + type Item = &'static IP_ADAPTER_ADDRESSES_LH; + + fn next(&mut self) -> Option { + // SAFETY: We expect windows to give us a valid linked list where each item of the list is actually an IP_ADAPTER_ADDRESSES_LH. + let adapter = unsafe { self.next.as_ref()? }; + + self.next = adapter.Next; + + Some(adapter) + } +} + +/// Finds the best route (i.e. source interface) for a given destination IP, excluding our TUN interface. +/// +/// To prevent routing loops on Windows, we need to explicitly set a source IP for all packets. +/// Windows uses a computed metric per interface for routing. +/// We implement the same logic here, with the addition of explicitly filtering out our TUN interface. +/// +/// # Performance +/// +/// This function performs multiple syscalls and is thus fairly expensive. +/// It should **not** be called on a per-packet basis. +/// Callers should instead cache the result until network interfaces change. +fn get_best_non_tunnel_route(dst: IpAddr) -> io::Result { + let route = list_adapters()? + .filter(|adapter| !is_tun(adapter)) + .map(|adapter| adapter.Luid) + .filter_map(|luid| find_best_route_for_luid(&luid, dst).ok()) + .min() + .ok_or(io::Error::other("No route to host"))?; + + let src = route.addr; + + tracing::debug!(%src, %dst, "Resolved best route outside of tunnel interface"); + + Ok(src) +} + +fn list_adapters() -> io::Result { + use windows::Win32::Foundation::ERROR_BUFFER_OVERFLOW; + use windows::Win32::Foundation::WIN32_ERROR; + + // 15kB is recommended to almost never fail + let mut buffer: Vec = vec![0u8; 15000]; + let mut buffer_len = buffer.len() as u32; + // Safety we just allocated buffer with the len we are passing + let mut res = unsafe { + GetAdaptersAddresses( + AF_UNSPEC.0 as u32, + GET_ADAPTERS_ADDRESSES_FLAGS(0), + Some(null()), + Some(buffer.as_mut_ptr() as *mut _), + &mut buffer_len as *mut _, + ) + }; + + // In case of a buffer overflow buffer_len will contain the necessary length + if res == ERROR_BUFFER_OVERFLOW.0 { + buffer = vec![0u8; buffer_len as usize]; + // SAFETY: we just allocated buffer with the len we are passing + res = unsafe { + GetAdaptersAddresses( + AF_UNSPEC.0 as u32, + GET_ADAPTERS_ADDRESSES_FLAGS(0), + Some(null()), + Some(buffer.as_mut_ptr() as *mut _), + &mut buffer_len as *mut _, + ) + }; + } + + WIN32_ERROR(res).ok()?; + + let next = buffer.as_ptr() as *const _; + Ok(Adapters { + _buffer: buffer, + next, + }) +} + +fn is_tun(adapter: &IP_ADAPTER_ADDRESSES_LH) -> bool { + if adapter.FriendlyName.is_null() { + return false; + } + + // SAFETY: It should be safe to call to_string since we checked it's not null and the reference should be valid + let friendly_name = unsafe { adapter.FriendlyName.to_string() }; + let Ok(friendly_name) = friendly_name else { + return false; + }; + + friendly_name == TUNNEL_NAME +} + +#[derive(PartialEq, Eq)] +struct Route { + metric: u32, + addr: IpAddr, +} + +impl Ord for Route { + fn cmp(&self, other: &Self) -> Ordering { + self.metric.cmp(&other.metric) + } +} + +impl PartialOrd for Route { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn find_best_route_for_luid(luid: &NET_LUID_LH, dst: IpAddr) -> Result { + let addr: SOCKADDR_INET = SocketAddr::from((dst, 0)).into(); + let mut best_route: MaybeUninit = MaybeUninit::zeroed(); + let mut best_src: MaybeUninit = MaybeUninit::zeroed(); + + // SAFETY: all pointers w ejust allocated with the correct types so it must be safe + let res = unsafe { + GetBestRoute2( + Some(luid as *const _), + 0, + None, + &addr as *const _, + 0, + best_route.as_mut_ptr(), + best_src.as_mut_ptr(), + ) + }; + + res.ok()?; + + // SAFETY: we just successfully initialized these pointers + let best_route = unsafe { best_route.assume_init() }; + let best_src = unsafe { best_src.assume_init() }; + + Ok(Route { + // SAFETY: we expect to get a valid address + addr: unsafe { to_ip_addr(best_src, dst) } + .ok_or(io::Error::other("can't find a valid route"))?, + metric: best_route.Metric, + }) +} + +// SAFETY: si_family must be always set in the union, which will be the case for a valid SOCKADDR_INET +unsafe fn to_ip_addr(addr: SOCKADDR_INET, dst: IpAddr) -> Option { + match (addr.si_family, dst) { + (ADDRESS_FAMILY(0), IpAddr::V4(_)) | (ADDRESS_FAMILY(2), _) => { + Some(Ipv4Addr::from(addr.Ipv4.sin_addr).into()) + } + (ADDRESS_FAMILY(0), IpAddr::V6(_)) | (ADDRESS_FAMILY(23), _) => { + Some(Ipv6Addr::from(addr.Ipv6.sin6_addr).into()) + } + _ => None, + } +} // The return value is useful on Linux #[allow(clippy::unnecessary_wraps)] @@ -29,3 +226,79 @@ pub(crate) fn default_token_path() -> std::path::PathBuf { pub(crate) fn notify_service_controller() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod test { + use super::*; + use firezone_bin_shared::TunDeviceManager; + use ip_network::Ipv4Network; + use socket_factory::DatagramOut; + use std::borrow::Cow; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; + use std::time::Duration; + + #[test] + fn best_route_ip4_does_not_panic_or_segfault() { + let _ = get_best_non_tunnel_route("8.8.8.8".parse().unwrap()); + } + + #[test] + fn best_route_ip6_does_not_panic_or_segfault() { + let _ = get_best_non_tunnel_route("2404:6800:4006:811::200e".parse().unwrap()); + } + + // Starts up a WinTUN device, adds a "full-route" (`0.0.0.0/0`) and checks if we can still send packets to IPs outside of our tunnel. + #[tokio::test] + #[ignore = "Needs admin & Internet"] + async fn no_packet_loops() { + let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); + let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + + let mut device_manager = TunDeviceManager::new().unwrap(); + let _tun = device_manager.make_tun().unwrap(); + device_manager.set_ips(ipv4, ipv6).await.unwrap(); + + // Configure `0.0.0.0/0` route. + device_manager + .set_routes( + vec![Ipv4Network::new(Ipv4Addr::UNSPECIFIED, 0).unwrap()], + vec![], + ) + .await + .unwrap(); + + // Make a socket. + let mut socket = + udp_socket_factory(&SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))) + .unwrap(); + + // Send a STUN request. + socket + .send(DatagramOut { + src: None, + dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, + packet: Cow::Borrowed(&hex_literal::hex!( + "000100002112A4420123456789abcdef01234567" + )), + }) + .unwrap(); + + // First send seems to always result as would block + std::future::poll_fn(|cx| socket.poll_flush(cx)) + .await + .unwrap(); + + let task = std::future::poll_fn(|cx| { + let mut buf = [0u8; 1000]; + let result = std::task::ready!(socket.poll_recv_from(&mut buf, cx)); + + let _response = result.unwrap().next().unwrap(); + + std::task::Poll::Ready(()) + }); + + tokio::time::timeout(Duration::from_secs(10), task) + .await + .unwrap(); + } +} diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index 225895ed3..fe7967733 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,13 +1,15 @@ +use std::collections::HashMap; use std::{ borrow::Cow, collections::VecDeque, io::{self, IoSliceMut}, - net::SocketAddr, + net::{IpAddr, SocketAddr}, slice, task::{ready, Context, Poll}, }; use socket2::SockAddr; +use std::collections::hash_map::Entry; use tokio::io::Interest; pub trait SocketFactory: Fn(&SocketAddr) -> io::Result + Send + Sync + 'static {} @@ -52,6 +54,10 @@ impl TcpSocket { pub async fn connect(self, addr: SocketAddr) -> io::Result { self.inner.connect(addr).await } + + pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + self.inner.bind(addr) + } } #[cfg(unix)] @@ -71,6 +77,11 @@ impl std::os::fd::AsFd for TcpSocket { pub struct UdpSocket { inner: tokio::net::UdpSocket, state: quinn_udp::UdpSocketState, + source_ip_resolver: + Box std::io::Result> + Send + Sync + 'static>, + + /// A cache of source IPs by their destination IPs. + src_by_dst_cache: HashMap, port: u16, @@ -85,9 +96,26 @@ impl UdpSocket { state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?, port, inner, + source_ip_resolver: Box::new(|_| Ok(None)), buffered_datagrams: VecDeque::new(), + src_by_dst_cache: Default::default(), }) } + + /// Configures a new source IP resolver for this UDP socket. + /// + /// In case [`DatagramOut::src`] is [`None`], this function will be used to set a source IP given the destination IP of the datagram. + /// The resulting IPs will be cached. + /// To evict this cache, drop the [`UdpSocket`] and make a new one. + /// + /// Errors during resolution result in the packet being dropped. + pub fn with_source_ip_resolver( + mut self, + resolver: Box std::io::Result> + Send + Sync + 'static>, + ) -> Self { + self.source_ip_resolver = resolver; + self + } } #[cfg(unix)] @@ -224,19 +252,54 @@ impl UdpSocket { } } - pub fn try_send(&self, transmit: &DatagramOut) -> io::Result<()> { + pub fn try_send(&mut self, transmit: &DatagramOut) -> io::Result<()> { + let destination = transmit.dst; + let src_ip = transmit.src.map(|s| s.ip()); + + let src_ip = match src_ip { + Some(src_ip) => Some(src_ip), + None => match self.resolve_source_for(destination.ip()) { + Ok(src_ip) => src_ip, + Err(e) => { + tracing::trace!( + dst = %transmit.dst.ip(), + "No available interface for packet: {e}" + ); + return Ok(()); + } + }, + }; + let transmit = quinn_udp::Transmit { - destination: transmit.dst, + destination, ecn: None, contents: &transmit.packet, segment_size: None, - src_ip: transmit.src.map(|s| s.ip()), + src_ip, }; self.inner.try_io(Interest::WRITABLE, || { self.state.send((&self.inner).into(), &transmit) }) } + + /// Attempt to resolve the source IP to use for sending to the given destination IP. + fn resolve_source_for(&mut self, dst: IpAddr) -> std::io::Result> { + let src = match self.src_by_dst_cache.entry(dst) { + Entry::Occupied(occ) => *occ.get(), + Entry::Vacant(vac) => { + // Caching errors could be a good idea to not incur in multiple calls for the resolver which can be costly + // For some cases like hosts ipv4-only stack trying to send ipv6 packets this can happen quite often but doing this is also a risk + // that in case that the adapter for some reason is temporarily unavailable it'd prevent the system from recovery. + let Some(src) = (self.source_ip_resolver)(dst)? else { + return Ok(None); + }; + *vac.insert(src) + } + }; + + Ok(Some(src)) + } } #[cfg(feature = "hickory")]