From 1a5c40bd75abfff485d393c7fe38588b91f0b843 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 14 Nov 2025 00:19:22 +1100 Subject: [PATCH] refactor(connlib): extract `l4-udp-dns-client` (#10854) In order to bootstrap DoH servers, we need a way of reliably resolving the domain of the DoH server to an IP address. Initially, I thought that this would be tricky to do if we have to integrate this into the Client's state machine. Whilst implementing DoH however, I realised that we can instead put this responsibility onto the IO layer of connlib. Similar to other cases, we can reuse external triggers as our retry mechanism in case of failure. In particular, we can simply issue UDP DNS queries for the DoH domain to all system-defined DNS resolvers every time we are told to send a DNS query over DoH but the corresponding client isn't initialized yet. In other words, instead of building a retry mechanism ourselves, we attempt to repair any kind of broken state once per DNS query that we receive. Performing this DNS resolution does require a bit of code. We already started to do something similar in #10817. In order to reuse that code, we extract it into a `l4-udp-dns-client` crate and slightly refactor its semantics. In particular, we now wait for the response of all upstream servers (but at most 2s) and combine the result. The resulting `UdpDnsClient` can now be used inside the Client's event-loop to re-resolve the portal URL and will also be used as part of our DoH implementation to bootstrap the connection to the DoH server. Related: #4668 --- rust/Cargo.lock | 13 ++ rust/Cargo.toml | 2 + rust/client-shared/Cargo.toml | 1 + rust/client-shared/src/eventloop.rs | 145 ++------------------ rust/connlib/l4-udp-dns-client/Cargo.toml | 22 +++ rust/connlib/l4-udp-dns-client/lib.rs | 158 ++++++++++++++++++++++ 6 files changed, 204 insertions(+), 137 deletions(-) create mode 100644 rust/connlib/l4-udp-dns-client/Cargo.toml create mode 100644 rust/connlib/l4-udp-dns-client/lib.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index e800d8bc7..a9ab889a1 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1338,6 +1338,7 @@ dependencies = [ "firezone-tunnel", "futures", "ip_network", + "l4-udp-dns-client", "libc", "parking_lot", "phoenix-channel", @@ -4188,6 +4189,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "l4-udp-dns-client" +version = "0.1.0" +dependencies = [ + "anyhow", + "dns-types", + "futures", + "socket-factory", + "tokio", + "tracing", +] + [[package]] name = "l4-udp-dns-server" version = "0.1.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3d9519616..33a543d20 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -13,6 +13,7 @@ members = [ "connlib/l3-tcp", "connlib/l3-udp-dns-client", "connlib/l4-tcp-dns-server", + "connlib/l4-udp-dns-client", "connlib/l4-udp-dns-server", "connlib/model", "connlib/phoenix-channel", @@ -112,6 +113,7 @@ known-folders = "1.4.0" l3-tcp = { path = "connlib/l3-tcp" } l3-udp-dns-client = { path = "connlib/l3-udp-dns-client" } l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" } +l4-udp-dns-client = { path = "connlib/l4-udp-dns-client" } l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" } libc = "0.2.176" libfuzzer-sys = "0.4" diff --git a/rust/client-shared/Cargo.toml b/rust/client-shared/Cargo.toml index 15e4f5dbb..e2a9bda54 100644 --- a/rust/client-shared/Cargo.toml +++ b/rust/client-shared/Cargo.toml @@ -14,6 +14,7 @@ firezone-logging = { workspace = true } firezone-tunnel = { workspace = true } futures = { workspace = true } ip_network = { workspace = true } +l4-udp-dns-client = { workspace = true } libc = { workspace = true } parking_lot = { workspace = true } phoenix-channel = { workspace = true } diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 9fd893e35..e4432ca51 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -1,7 +1,6 @@ use crate::PHOENIX_TOPIC; use anyhow::{Context as _, Result}; use connlib_model::{PublicKey, ResourceView}; -use dns_types::DomainName; use firezone_tunnel::messages::RelaysPresence; use firezone_tunnel::messages::client::{ EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates, @@ -10,25 +9,22 @@ use firezone_tunnel::messages::client::{ use firezone_tunnel::{ ClientEvent, ClientTunnel, DnsResourceRecord, IpConfig, TunConfig, TunnelError, }; -use futures::TryFutureExt; -use futures::stream::FuturesUnordered; +use l4_udp_dns_client::UdpDnsClient; use parking_lot::Mutex; use phoenix_channel::{ErrorReply, PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::ops::ControlFlow; use std::pin::pin; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use std::{ collections::BTreeSet, io, net::IpAddr, task::{Context, Poll}, }; -use std::{future, iter, mem}; +use std::{future, mem}; use tokio::sync::{mpsc, watch}; -use tokio_stream::StreamExt; use tun::Tun; /// In-memory cache for DNS resource records. @@ -126,7 +122,7 @@ impl Eventloop { portal, portal_event_tx, portal_cmd_rx, - UdpDnsClient::new(udp_socket_factory), + udp_socket_factory.clone(), )); Self { @@ -506,12 +502,14 @@ async fn phoenix_channel_event_loop( mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, - mut udp_dns_client: UdpDnsClient, + udp_socket_factory: Arc>, ) { use futures::future::Either; use futures::future::select; use std::future::poll_fn; + let mut udp_dns_client = UdpDnsClient::new(udp_socket_factory.clone(), vec![]); + loop { match select(poll_fn(|cx| portal.poll(cx)), pin!(cmd_rx.recv())).await { Either::Left((Ok(phoenix_channel::Event::InboundMessage { msg, .. }), _)) => { @@ -580,7 +578,7 @@ async fn phoenix_channel_event_loop( portal.connect(param); } Either::Right((Some(PortalCommand::UpdateDnsServers(servers)), _)) => { - udp_dns_client.servers = servers; + udp_dns_client = UdpDnsClient::new(udp_socket_factory.clone(), servers); } Either::Right((None, _)) => { tracing::debug!("Command channel closed: exiting phoenix-channel event-loop"); @@ -601,130 +599,3 @@ fn is_unreachable(e: &io::Error) -> bool { || e.kind() == io::ErrorKind::HostUnreachable || e.kind() == io::ErrorKind::AddrNotAvailable } - -struct UdpDnsClient { - socket_factory: Arc>, - servers: Vec, -} - -impl UdpDnsClient { - const TIMEOUT: Duration = Duration::from_secs(2); - - fn new(socket_factory: Arc>) -> Self { - Self { - socket_factory, - servers: Vec::default(), - } - } - - async fn resolve(&self, host: String) -> Result> { - let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?; - let servers = self.servers.clone(); - - let (a_records, aaaa_records) = self - .servers - .iter() - .map(|socket| { - futures::future::try_join( - self.send( - SocketAddr::new(*socket, 53), - dns_types::Query::new(host.clone(), dns_types::RecordType::A), - ), - self.send( - SocketAddr::new(*socket, 53), - dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA), - ), - ) - .map_err(|e| { - tracing::debug!(%host, "DNS query failed: {e:#}"); - - e - }) - }) - .collect::>() - .filter_map(|result| result.ok()) - .filter(|(a, b)| { - a.response_code() == dns_types::ResponseCode::NOERROR - && b.response_code() == dns_types::ResponseCode::NOERROR - }) - .next() - .await - .with_context(|| { - format!("All DNS servers ({servers:?}) failed to resolve portal host '{host}'") - })?; - - let ips = iter::empty() - .chain( - a_records - .records() - .filter_map(dns_types::records::extract_ip), - ) - .chain( - aaaa_records - .records() - .filter_map(dns_types::records::extract_ip), - ) - .collect(); - - Ok(ips) - } - - async fn send( - &self, - server: SocketAddr, - query: dns_types::Query, - ) -> io::Result { - let bind_addr = match server { - SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), - SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - }; - - // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. - const BUF_SIZE: usize = 1500; - - let udp_socket = self.socket_factory.bind(bind_addr)?; - - let response = tokio::time::timeout( - Self::TIMEOUT, - udp_socket.handshake::(server, &query.into_bytes()), - ) - .await??; - - let response = dns_types::Response::parse(&response).map_err(io::Error::other)?; - - Ok(response) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "Requires Internet"] - async fn udp_dns_client_can_resolve_host() { - let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); - client.servers = vec![IpAddr::from([1, 1, 1, 1])]; - - let ips = client.resolve("example.com".to_owned()).await.unwrap(); - - assert!(!ips.is_empty()) - } - - #[tokio::test] - #[ignore = "Requires Internet"] - async fn udp_dns_client_times_out_unreachable_host() { - let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); - client.servers = vec![IpAddr::from([2, 2, 2, 2])]; - - let now = Instant::now(); - - let error = client.resolve("example.com".to_owned()).await.unwrap_err(); - - assert_eq!( - error.to_string(), - "All DNS servers ([2.2.2.2]) failed to resolve portal host 'example.com'" - ); - assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) - } -} diff --git a/rust/connlib/l4-udp-dns-client/Cargo.toml b/rust/connlib/l4-udp-dns-client/Cargo.toml new file mode 100644 index 000000000..bc15ee0da --- /dev/null +++ b/rust/connlib/l4-udp-dns-client/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "l4-udp-dns-client" +version = "0.1.0" +edition = { workspace = true } +license = { workspace = true } + +[lib] +path = "lib.rs" + +[dependencies] +anyhow = { workspace = true } +dns-types = { workspace = true } +futures = { workspace = true } +socket-factory = { workspace = true } +tokio = { workspace = true, features = ["time"] } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt", "macros"] } + +[lints] +workspace = true diff --git a/rust/connlib/l4-udp-dns-client/lib.rs b/rust/connlib/l4-udp-dns-client/lib.rs new file mode 100644 index 000000000..8cfc9b873 --- /dev/null +++ b/rust/connlib/l4-udp-dns-client/lib.rs @@ -0,0 +1,158 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + +use std::{ + collections::BTreeSet, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use anyhow::{Context as _, Result}; +use dns_types::DomainName; +use futures::stream::FuturesUnordered; +use futures::stream::StreamExt as _; +use socket_factory::{SocketFactory, UdpSocket}; + +/// A UDP DNS client, specialised for resolving host names to IP addresses. +/// +/// The implementation uses a multi-shot approach where all configured upstream servers are contacted in parallel. +/// All successful responses are merged together. +pub struct UdpDnsClient { + socket_factory: Arc>, + servers: Vec, +} + +impl UdpDnsClient { + const TIMEOUT: Duration = Duration::from_secs(2); + + pub fn new(socket_factory: Arc>, servers: Vec) -> Self { + Self { + socket_factory, + servers, + } + } + + pub fn resolve(&self, host: String) -> impl Future>> + use<> { + let servers = self.servers.clone(); + let socket_factory = self.socket_factory.clone(); + + async move { + let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?; + + let ips = servers + .iter() + .flat_map(|socket| { + let socket = SocketAddr::new(*socket, 53); + + [ + send_query( + socket_factory.clone(), + socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::A), + ), + send_query( + socket_factory.clone(), + socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA), + ), + ] + }) + .collect::>() + .collect::>() + .await + .into_iter() + .flat_map(|result| result.inspect_err(|e| tracing::debug!("{e:#}")).ok()) + .filter(|response| response.response_code() == dns_types::ResponseCode::NOERROR) + .flat_map(|response| { + response + .records() + .filter_map(dns_types::records::extract_ip) + .collect::>() + }) + .collect::>(); // Make them unique. + + Ok(Vec::from_iter(ips)) + } + } +} + +async fn send_query( + socket_factory: Arc>, + server: SocketAddr, + query: dns_types::Query, +) -> Result { + let bind_addr = match server { + SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + }; + + // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. + const BUF_SIZE: usize = 1500; + + let udp_socket = socket_factory + .bind(bind_addr) + .context("Failed to bind UDP socket")?; + + let response = tokio::time::timeout( + UdpDnsClient::TIMEOUT, + udp_socket.handshake::(server, &query.into_bytes()), + ) + .await + .with_context(|| format!("DNS query to host {server} timed out"))??; + + let response = dns_types::Response::parse(&response).context("Failed to parse DNS response")?; + + Ok(response) +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use super::*; + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn can_resolve_host() { + let client = UdpDnsClient::new( + Arc::new(socket_factory::udp), + vec![IpAddr::from([1, 1, 1, 1])], + ); + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(!ips.is_empty()) + } + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn times_out_unreachable_host() { + let client = UdpDnsClient::new( + Arc::new(socket_factory::udp), + vec![IpAddr::from([2, 2, 2, 2])], + ); + + let now = Instant::now(); + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(ips.is_empty()); + assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) + } + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn returns_all_valid_records() { + let client = UdpDnsClient::new( + Arc::new(socket_factory::udp), + vec![IpAddr::from([1, 1, 1, 1]), IpAddr::from([2, 2, 2, 2])], + ); + + let now = Instant::now(); + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(!ips.is_empty()); + assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) // Still need to wait for the unreachable server. + } +}