diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 2c304a910..9698eabc2 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1084,7 +1084,6 @@ dependencies = [ "proptest", "rand 0.8.5", "rand_core 0.6.4", - "ring", "secrecy", "serde", "serde_json", @@ -1907,6 +1906,7 @@ dependencies = [ "firezone-bin-shared", "futures", "git-version", + "hex-literal", "humantime", "ip_network", "ipconfig", @@ -2644,6 +2644,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b53d6a634507c5d9fdee77261ae54a8d1ff7887f5304389025b03c3292a1756" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "hickory-proto" version = "0.24.0" diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index be33c03f4..ae11254dd 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -27,7 +27,7 @@ netlink-packet-route = { version = "0.19", default-features = false } rtnetlink = { workspace = true } libc = "0.2" -[target.'cfg(target_os = "windows")'.dependencies] +[target.'cfg(windows)'.dependencies] known-folders = "1.1.0" ring = "0.17" uuid = { version = "1.10.0", features = ["v4"] } diff --git a/rust/bin-shared/src/lib.rs b/rust/bin-shared/src/lib.rs index 69623a4b7..c5cb08378 100644 --- a/rust/bin-shared/src/lib.rs +++ b/rust/bin-shared/src/lib.rs @@ -9,6 +9,9 @@ use tracing_subscriber::{ fmt, prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Layer, Registry, }; +// wintun automatically appends " Tunnel" to this +pub const TUNNEL_NAME: &str = "Firezone"; + /// Bundle ID / App ID that the client uses to distinguish itself from other programs on the system /// /// e.g.
In ProgramData and AppData we use this to name our subdirectories for configs and data, diff --git a/rust/bin-shared/src/tun_device_manager/windows.rs b/rust/bin-shared/src/tun_device_manager/windows.rs index 31964b9ba..9d8dc7367 100644 --- a/rust/bin-shared/src/tun_device_manager/windows.rs +++ b/rust/bin-shared/src/tun_device_manager/windows.rs @@ -1,4 +1,5 @@ use crate::windows::CREATE_NO_WINDOW; +use crate::TUNNEL_NAME; use anyhow::{Context as _, Result}; use connlib_shared::DEFAULT_MTU; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; @@ -27,9 +28,6 @@ use windows::Win32::{ }; use wintun::Adapter; -// wintun automatically append " Tunnel" to this -pub(crate) const TUNNEL_NAME: &str = "Firezone"; - /// The ring buffer size used for Wintun. /// /// Must be a power of two within a certain range diff --git a/rust/connlib/shared/Cargo.toml b/rust/connlib/shared/Cargo.toml index acd6815b7..086c54099 100644 --- a/rust/connlib/shared/Cargo.toml +++ b/rust/connlib/shared/Cargo.toml @@ -25,7 +25,6 @@ phoenix-channel = { workspace = true } proptest = { version = "1", optional = true } rand = { version = "0.8", default-features = false, features = ["std"] } rand_core = { version = "0.6.4", default-features = false, features = ["std"] } -ring = "0.17" secrecy = { workspace = true, features = ["serde", "bytes"] } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index a8ee398f5..b4b33b1ab 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -26,7 +26,7 @@ socket-factory = { workspace = true } thiserror = { version = "1.0", default-features = false } # This actually relies on many other features in Tokio, so this will probably # fail to build outside the workspace. 
-tokio = { workspace = true, features = ["macros", "signal", "process"] } +tokio = { workspace = true, features = ["macros", "signal", "process", "time"] } tokio-stream = "0.1.15" tokio-util = { version = "0.7.11", features = ["codec"] } tracing = { workspace = true } @@ -40,6 +40,9 @@ tempfile = "3.10.1" [target.'cfg(target_os = "linux")'.dev-dependencies] mutants = "0.0.3" # Needed to mark functions as exempt from `cargo-mutants` testing +[target.'cfg(target_os = "windows")'.dev-dependencies] +hex-literal = "0.4.1" + [target.'cfg(target_os = "linux")'.dependencies] dirs = "5.0.1" libc = "0.2.150" @@ -59,7 +62,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } windows-service = "0.7.0" winreg = "0.52.0" -[target.'cfg(target_os = "windows")'.dependencies.windows] +[target.'cfg(windows)'.dependencies.windows] version = "0.57.0" features = [ # For DNS control and route control diff --git a/rust/headless-client/src/windows.rs b/rust/headless-client/src/windows.rs index cb7e3f405..f4966c16f 100644 --- a/rust/headless-client/src/windows.rs +++ b/rust/headless-client/src/windows.rs @@ -5,10 +5,207 @@ //! We must tell Windows explicitly when our service is stopping. 
use anyhow::Result; -use std::path::{Path, PathBuf}; +use std::{ + cmp::Ordering, + io, + mem::MaybeUninit, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + path::{Path, PathBuf}, + ptr::null, +}; -pub(crate) use socket_factory::tcp as tcp_socket_factory; -pub(crate) use socket_factory::udp as udp_socket_factory; +use windows::Win32::NetworkManagement::{IpHelper::GetAdaptersAddresses, Ndis::NET_LUID_LH}; +use windows::Win32::Networking::WinSock::SOCKADDR_INET; +use windows::Win32::{ + NetworkManagement::IpHelper::{ + GetBestRoute2, GET_ADAPTERS_ADDRESSES_FLAGS, IP_ADAPTER_ADDRESSES_LH, MIB_IPFORWARD_ROW2, + }, + Networking::WinSock::{ADDRESS_FAMILY, AF_UNSPEC}, +}; + +use firezone_bin_shared::TUNNEL_NAME; +use socket_factory::{TcpSocket, UdpSocket}; + +pub fn tcp_socket_factory(addr: &SocketAddr) -> io::Result { + let local = get_best_non_tunnel_route(addr.ip())?; + + let socket = socket_factory::tcp(addr)?; + socket.bind((local, 0).into())?; // To avoid routing loops, all TCP sockets are bound to the "best" source IP. + + Ok(socket) +} + +pub fn udp_socket_factory(src_addr: &SocketAddr) -> io::Result { + let source_ip_resolver = |dst| Ok(Some(get_best_non_tunnel_route(dst)?)); + + let socket = + socket_factory::udp(src_addr)?.with_source_ip_resolver(Box::new(source_ip_resolver)); + + Ok(socket) +} + +struct Adapters { + _buffer: Vec, + next: *const IP_ADAPTER_ADDRESSES_LH, +} + +impl Iterator for Adapters { + type Item = &'static IP_ADAPTER_ADDRESSES_LH; + + fn next(&mut self) -> Option { + // SAFETY: We expect windows to give us a valid linked list where each item of the list is actually an IP_ADAPTER_ADDRESSES_LH. NOTE(review): the `'static` item lifetime is not enforced by the compiler; the nodes presumably live inside `_buffer`, so yielded references must not outlive this iterator - confirm. + let adapter = unsafe { self.next.as_ref()? }; + + self.next = adapter.Next; + + Some(adapter) + } +} + +/// Finds the best route (i.e. source interface) for a given destination IP, excluding our TUN interface. +/// +/// To prevent routing loops on Windows, we need to explicitly set a source IP for all packets.
+/// Windows uses a computed metric per interface for routing. +/// We implement the same logic here, with the addition of explicitly filtering out our TUN interface. +/// +/// # Performance +/// +/// This function performs multiple syscalls and is thus fairly expensive. +/// It should **not** be called on a per-packet basis. +/// Callers should instead cache the result until network interfaces change. +fn get_best_non_tunnel_route(dst: IpAddr) -> io::Result { + let route = list_adapters()? + .filter(|adapter| !is_tun(adapter)) + .map(|adapter| adapter.Luid) + .filter_map(|luid| find_best_route_for_luid(&luid, dst).ok()) + .min() + .ok_or(io::Error::other("No route to host"))?; + + let src = route.addr; + + tracing::debug!(%src, %dst, "Resolved best route outside of tunnel interface"); + + Ok(src) +} + +fn list_adapters() -> io::Result { + use windows::Win32::Foundation::ERROR_BUFFER_OVERFLOW; + use windows::Win32::Foundation::WIN32_ERROR; + + // 15kB is recommended to almost never fail + let mut buffer: Vec = vec![0u8; 15000]; + let mut buffer_len = buffer.len() as u32; + // SAFETY: we just allocated buffer with the len we are passing + let mut res = unsafe { + GetAdaptersAddresses( + AF_UNSPEC.0 as u32, + GET_ADAPTERS_ADDRESSES_FLAGS(0), + Some(null()), + Some(buffer.as_mut_ptr() as *mut _), + &mut buffer_len as *mut _, + ) + }; + + // In case of a buffer overflow, buffer_len will contain the necessary length, so we retry once with a buffer of that size + if res == ERROR_BUFFER_OVERFLOW.0 { + buffer = vec![0u8; buffer_len as usize]; + // SAFETY: we just allocated buffer with the len we are passing + res = unsafe { + GetAdaptersAddresses( + AF_UNSPEC.0 as u32, + GET_ADAPTERS_ADDRESSES_FLAGS(0), + Some(null()), + Some(buffer.as_mut_ptr() as *mut _), + &mut buffer_len as *mut _, + ) + }; + } + + WIN32_ERROR(res).ok()?; + + let next = buffer.as_ptr() as *const _; + Ok(Adapters { + _buffer: buffer, + next, + }) +} + +fn is_tun(adapter: &IP_ADAPTER_ADDRESSES_LH) -> bool { + if adapter.FriendlyName.is_null() { +
return false; + } + + // SAFETY: It should be safe to call to_string since we checked it's not null and the reference should be valid + let friendly_name = unsafe { adapter.FriendlyName.to_string() }; + let Ok(friendly_name) = friendly_name else { + return false; + }; + + friendly_name == TUNNEL_NAME +} + +#[derive(PartialEq, Eq)] +struct Route { + metric: u32, + addr: IpAddr, +} + +impl Ord for Route { + fn cmp(&self, other: &Self) -> Ordering { + self.metric.cmp(&other.metric) + } +} + +impl PartialOrd for Route { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn find_best_route_for_luid(luid: &NET_LUID_LH, dst: IpAddr) -> Result { + let addr: SOCKADDR_INET = SocketAddr::from((dst, 0)).into(); + let mut best_route: MaybeUninit = MaybeUninit::zeroed(); + let mut best_src: MaybeUninit = MaybeUninit::zeroed(); + + // SAFETY: all pointers refer to live values of the correct types for the duration of the call + let res = unsafe { + GetBestRoute2( + Some(luid as *const _), + 0, + None, + &addr as *const _, + 0, + best_route.as_mut_ptr(), + best_src.as_mut_ptr(), + ) + }; + + res.ok()?; + + // SAFETY: the call succeeded, so we treat both out-values as initialized (they were also zeroed up front) + let best_route = unsafe { best_route.assume_init() }; + let best_src = unsafe { best_src.assume_init() }; + + Ok(Route { + // SAFETY: we expect to get a valid address + addr: unsafe { to_ip_addr(best_src, dst) } + .ok_or(io::Error::other("can't find a valid route"))?, + metric: best_route.Metric, + }) +} + +// SAFETY: si_family must always be set in the union, which will be the case for a valid SOCKADDR_INET (family 2 is read as IPv4, 23 as IPv6, and 0 falls back to the family of `dst`) +unsafe fn to_ip_addr(addr: SOCKADDR_INET, dst: IpAddr) -> Option { + match (addr.si_family, dst) { + (ADDRESS_FAMILY(0), IpAddr::V4(_)) | (ADDRESS_FAMILY(2), _) => { + Some(Ipv4Addr::from(addr.Ipv4.sin_addr).into()) + } + (ADDRESS_FAMILY(0), IpAddr::V6(_)) | (ADDRESS_FAMILY(23), _) => { + Some(Ipv6Addr::from(addr.Ipv6.sin6_addr).into()) + } + _ => None, + } +} // The return value is useful on
Linux #[allow(clippy::unnecessary_wraps)] @@ -29,3 +226,79 @@ pub(crate) fn default_token_path() -> std::path::PathBuf { pub(crate) fn notify_service_controller() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod test { + use super::*; + use firezone_bin_shared::TunDeviceManager; + use ip_network::Ipv4Network; + use socket_factory::DatagramOut; + use std::borrow::Cow; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; + use std::time::Duration; + + #[test] + fn best_route_ip4_does_not_panic_or_segfault() { + let _ = get_best_non_tunnel_route("8.8.8.8".parse().unwrap()); + } + + #[test] + fn best_route_ip6_does_not_panic_or_segfault() { + let _ = get_best_non_tunnel_route("2404:6800:4006:811::200e".parse().unwrap()); + } + + // Starts up a WinTUN device, adds a "full-route" (`0.0.0.0/0`) and checks if we can still send packets to IPs outside of our tunnel. + #[tokio::test] + #[ignore = "Needs admin & Internet"] + async fn no_packet_loops() { + let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); + let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + + let mut device_manager = TunDeviceManager::new().unwrap(); + let _tun = device_manager.make_tun().unwrap(); + device_manager.set_ips(ipv4, ipv6).await.unwrap(); + + // Configure `0.0.0.0/0` route. + device_manager + .set_routes( + vec![Ipv4Network::new(Ipv4Addr::UNSPECIFIED, 0).unwrap()], + vec![], + ) + .await + .unwrap(); + + // Make a socket. + let mut socket = + udp_socket_factory(&SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))) + .unwrap(); + + // Send a STUN request. 
+ socket + .send(DatagramOut { + src: None, + dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, + packet: Cow::Borrowed(&hex_literal::hex!( + "000100002112A4420123456789abcdef01234567" + )), + }) + .unwrap(); + + // First send seems to always result as would block + std::future::poll_fn(|cx| socket.poll_flush(cx)) + .await + .unwrap(); + + let task = std::future::poll_fn(|cx| { + let mut buf = [0u8; 1000]; + let result = std::task::ready!(socket.poll_recv_from(&mut buf, cx)); + + let _response = result.unwrap().next().unwrap(); + + std::task::Poll::Ready(()) + }); + + tokio::time::timeout(Duration::from_secs(10), task) + .await + .unwrap(); + } +} diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index 225895ed3..fe7967733 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,13 +1,15 @@ +use std::collections::HashMap; use std::{ borrow::Cow, collections::VecDeque, io::{self, IoSliceMut}, - net::SocketAddr, + net::{IpAddr, SocketAddr}, slice, task::{ready, Context, Poll}, }; use socket2::SockAddr; +use std::collections::hash_map::Entry; use tokio::io::Interest; pub trait SocketFactory: Fn(&SocketAddr) -> io::Result + Send + Sync + 'static {} @@ -52,6 +54,10 @@ impl TcpSocket { pub async fn connect(self, addr: SocketAddr) -> io::Result { self.inner.connect(addr).await } + + pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + self.inner.bind(addr) + } } #[cfg(unix)] @@ -71,6 +77,11 @@ impl std::os::fd::AsFd for TcpSocket { pub struct UdpSocket { inner: tokio::net::UdpSocket, state: quinn_udp::UdpSocketState, + source_ip_resolver: + Box std::io::Result> + Send + Sync + 'static>, + + /// A cache of source IPs by their destination IPs. 
+ src_by_dst_cache: HashMap, port: u16, @@ -85,9 +96,26 @@ impl UdpSocket { state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?, port, inner, + source_ip_resolver: Box::new(|_| Ok(None)), buffered_datagrams: VecDeque::new(), + src_by_dst_cache: Default::default(), }) } + + /// Configures a new source IP resolver for this UDP socket. + /// + /// In case [`DatagramOut::src`] is [`None`], this function will be used to set a source IP given the destination IP of the datagram. + /// The resulting IPs will be cached. + /// To evict this cache, drop the [`UdpSocket`] and make a new one. + /// + /// Errors during resolution result in the packet being dropped. + pub fn with_source_ip_resolver( + mut self, + resolver: Box std::io::Result> + Send + Sync + 'static>, + ) -> Self { + self.source_ip_resolver = resolver; + self + } } #[cfg(unix)] @@ -224,19 +252,54 @@ impl UdpSocket { } } - pub fn try_send(&self, transmit: &DatagramOut) -> io::Result<()> { + pub fn try_send(&mut self, transmit: &DatagramOut) -> io::Result<()> { + let destination = transmit.dst; + let src_ip = transmit.src.map(|s| s.ip()); + + let src_ip = match src_ip { + Some(src_ip) => Some(src_ip), + None => match self.resolve_source_for(destination.ip()) { + Ok(src_ip) => src_ip, + Err(e) => { + tracing::trace!( + dst = %transmit.dst.ip(), + "No available interface for packet: {e}" + ); + return Ok(()); + } + }, + }; + let transmit = quinn_udp::Transmit { - destination: transmit.dst, + destination, ecn: None, contents: &transmit.packet, segment_size: None, - src_ip: transmit.src.map(|s| s.ip()), + src_ip, }; self.inner.try_io(Interest::WRITABLE, || { self.state.send((&self.inner).into(), &transmit) }) } + + /// Attempt to resolve the source IP to use for sending to the given destination IP. 
+ fn resolve_source_for(&mut self, dst: IpAddr) -> std::io::Result> { + let src = match self.src_by_dst_cache.entry(dst) { + Entry::Occupied(occ) => *occ.get(), + Entry::Vacant(vac) => { + // Errors (and `None` results) are deliberately not cached, even though caching would avoid repeated, potentially costly resolver calls. + // Resolution can fail often, e.g. when an IPv4-only host tries to send IPv6 packets, but caching the failure is risky: + // if an adapter is only temporarily unavailable, a cached error would prevent the system from recovering. + let Some(src) = (self.source_ip_resolver)(dst)? else { + return Ok(None); + }; + *vac.insert(src) + } + }; + + Ok(Some(src)) + } } #[cfg(feature = "hickory")]