diff --git a/rust/Cargo.lock b/rust/Cargo.lock index e800d8bc7..a9ab889a1 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1338,6 +1338,7 @@ dependencies = [ "firezone-tunnel", "futures", "ip_network", + "l4-udp-dns-client", "libc", "parking_lot", "phoenix-channel", @@ -4188,6 +4189,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "l4-udp-dns-client" +version = "0.1.0" +dependencies = [ + "anyhow", + "dns-types", + "futures", + "socket-factory", + "tokio", + "tracing", +] + [[package]] name = "l4-udp-dns-server" version = "0.1.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3d9519616..33a543d20 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -13,6 +13,7 @@ members = [ "connlib/l3-tcp", "connlib/l3-udp-dns-client", "connlib/l4-tcp-dns-server", + "connlib/l4-udp-dns-client", "connlib/l4-udp-dns-server", "connlib/model", "connlib/phoenix-channel", @@ -112,6 +113,7 @@ known-folders = "1.4.0" l3-tcp = { path = "connlib/l3-tcp" } l3-udp-dns-client = { path = "connlib/l3-udp-dns-client" } l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" } +l4-udp-dns-client = { path = "connlib/l4-udp-dns-client" } l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" } libc = "0.2.176" libfuzzer-sys = "0.4" diff --git a/rust/client-shared/Cargo.toml b/rust/client-shared/Cargo.toml index 15e4f5dbb..e2a9bda54 100644 --- a/rust/client-shared/Cargo.toml +++ b/rust/client-shared/Cargo.toml @@ -14,6 +14,7 @@ firezone-logging = { workspace = true } firezone-tunnel = { workspace = true } futures = { workspace = true } ip_network = { workspace = true } +l4-udp-dns-client = { workspace = true } libc = { workspace = true } parking_lot = { workspace = true } phoenix-channel = { workspace = true } diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 9fd893e35..e4432ca51 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -1,7 +1,6 @@ use crate::PHOENIX_TOPIC; use anyhow::{Context as _, Result}; use connlib_model::{PublicKey, ResourceView}; -use dns_types::DomainName; use firezone_tunnel::messages::RelaysPresence; use firezone_tunnel::messages::client::{ EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates, @@ -10,25 +9,22 @@ use firezone_tunnel::messages::client::{ use firezone_tunnel::{ ClientEvent, ClientTunnel, DnsResourceRecord, IpConfig, TunConfig, TunnelError, }; -use futures::TryFutureExt; -use futures::stream::FuturesUnordered; +use l4_udp_dns_client::UdpDnsClient; use parking_lot::Mutex; use phoenix_channel::{ErrorReply, PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::ops::ControlFlow; use std::pin::pin; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use std::{ collections::BTreeSet, io, net::IpAddr, task::{Context, Poll}, }; -use std::{future, iter, mem}; +use std::{future, mem}; use tokio::sync::{mpsc, watch}; -use tokio_stream::StreamExt; use tun::Tun; /// In-memory cache for DNS resource records. @@ -126,7 +122,7 @@ impl Eventloop { portal, portal_event_tx, portal_cmd_rx, - UdpDnsClient::new(udp_socket_factory), + udp_socket_factory.clone(), )); Self { @@ -506,12 +502,14 @@ async fn phoenix_channel_event_loop( mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, - mut udp_dns_client: UdpDnsClient, + udp_socket_factory: Arc>, ) { use futures::future::Either; use futures::future::select; use std::future::poll_fn; + let mut udp_dns_client = UdpDnsClient::new(udp_socket_factory.clone(), vec![]); + loop { match select(poll_fn(|cx| portal.poll(cx)), pin!(cmd_rx.recv())).await { Either::Left((Ok(phoenix_channel::Event::InboundMessage { msg, .. }), _)) => { @@ -580,7 +578,7 @@ async fn phoenix_channel_event_loop( portal.connect(param); } Either::Right((Some(PortalCommand::UpdateDnsServers(servers)), _)) => { - udp_dns_client.servers = servers; + udp_dns_client = UdpDnsClient::new(udp_socket_factory.clone(), servers); } Either::Right((None, _)) => { tracing::debug!("Command channel closed: exiting phoenix-channel event-loop"); @@ -601,130 +599,3 @@ fn is_unreachable(e: &io::Error) -> bool { || e.kind() == io::ErrorKind::HostUnreachable || e.kind() == io::ErrorKind::AddrNotAvailable } - -struct UdpDnsClient { - socket_factory: Arc>, - servers: Vec, -} - -impl UdpDnsClient { - const TIMEOUT: Duration = Duration::from_secs(2); - - fn new(socket_factory: Arc>) -> Self { - Self { - socket_factory, - servers: Vec::default(), - } - } - - async fn resolve(&self, host: String) -> Result> { - let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?; - let servers = self.servers.clone(); - - let (a_records, aaaa_records) = self - .servers - .iter() - .map(|socket| { - futures::future::try_join( - self.send( - SocketAddr::new(*socket, 53), - dns_types::Query::new(host.clone(), dns_types::RecordType::A), - ), - self.send( - SocketAddr::new(*socket, 53), - dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA), - ), - ) - .map_err(|e| { - tracing::debug!(%host, "DNS query failed: {e:#}"); - - e - }) - }) - .collect::>() - .filter_map(|result| result.ok()) - .filter(|(a, b)| { - a.response_code() == dns_types::ResponseCode::NOERROR - && b.response_code() == dns_types::ResponseCode::NOERROR - }) - .next() - .await - .with_context(|| { - format!("All DNS servers ({servers:?}) failed to resolve portal host '{host}'") - })?; - - let ips = iter::empty() - .chain( - a_records - .records() - .filter_map(dns_types::records::extract_ip), - ) - .chain( - aaaa_records - .records() - .filter_map(dns_types::records::extract_ip), - ) - .collect(); - - Ok(ips) - } - - async fn send( - &self, - server: SocketAddr, - query: dns_types::Query, - ) -> io::Result { - let bind_addr = match server { - SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), - SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - }; - - // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. - const BUF_SIZE: usize = 1500; - - let udp_socket = self.socket_factory.bind(bind_addr)?; - - let response = tokio::time::timeout( - Self::TIMEOUT, - udp_socket.handshake::(server, &query.into_bytes()), - ) - .await??; - - let response = dns_types::Response::parse(&response).map_err(io::Error::other)?; - - Ok(response) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "Requires Internet"] - async fn udp_dns_client_can_resolve_host() { - let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); - client.servers = vec![IpAddr::from([1, 1, 1, 1])]; - - let ips = client.resolve("example.com".to_owned()).await.unwrap(); - - assert!(!ips.is_empty()) - } - - #[tokio::test] - #[ignore = "Requires Internet"] - async fn udp_dns_client_times_out_unreachable_host() { - let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); - client.servers = vec![IpAddr::from([2, 2, 2, 2])]; - - let now = Instant::now(); - - let error = client.resolve("example.com".to_owned()).await.unwrap_err(); - - assert_eq!( - error.to_string(), - "All DNS servers ([2.2.2.2]) failed to resolve portal host 'example.com'" - ); - assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) - } -} diff --git a/rust/connlib/l4-udp-dns-client/Cargo.toml b/rust/connlib/l4-udp-dns-client/Cargo.toml new file mode 100644 index 000000000..bc15ee0da --- /dev/null +++ b/rust/connlib/l4-udp-dns-client/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "l4-udp-dns-client" +version = "0.1.0" +edition = { workspace = true } +license = { workspace = true } + +[lib] +path = "lib.rs" + +[dependencies] +anyhow = { workspace = true } +dns-types = { workspace = true } +futures = { workspace = true } +socket-factory = { workspace = true } +tokio = { workspace = true, features = ["time"] } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt", "macros"] } + +[lints] +workspace = true diff --git a/rust/connlib/l4-udp-dns-client/lib.rs b/rust/connlib/l4-udp-dns-client/lib.rs new file mode 100644 index 000000000..8cfc9b873 --- /dev/null +++ b/rust/connlib/l4-udp-dns-client/lib.rs @@ -0,0 +1,158 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + +use std::{ + collections::BTreeSet, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use anyhow::{Context as _, Result}; +use dns_types::DomainName; +use futures::stream::FuturesUnordered; +use futures::stream::StreamExt as _; +use socket_factory::{SocketFactory, UdpSocket}; + +/// A UDP DNS client, specialised for resolving host names to IP addresses. +/// +/// The implementation uses a multi-shot approach where all configured upstream servers are contacted in parallel. +/// All successful responses are merged together. +pub struct UdpDnsClient { + socket_factory: Arc>, + servers: Vec, +} + +impl UdpDnsClient { + const TIMEOUT: Duration = Duration::from_secs(2); + + pub fn new(socket_factory: Arc>, servers: Vec) -> Self { + Self { + socket_factory, + servers, + } + } + + pub fn resolve(&self, host: String) -> impl Future>> + use<> { + let servers = self.servers.clone(); + let socket_factory = self.socket_factory.clone(); + + async move { + let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?; + + let ips = servers + .iter() + .flat_map(|socket| { + let socket = SocketAddr::new(*socket, 53); + + [ + send_query( + socket_factory.clone(), + socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::A), + ), + send_query( + socket_factory.clone(), + socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA), + ), + ] + }) + .collect::>() + .collect::>() + .await + .into_iter() + .flat_map(|result| result.inspect_err(|e| tracing::debug!("{e:#}")).ok()) + .filter(|response| response.response_code() == dns_types::ResponseCode::NOERROR) + .flat_map(|response| { + response + .records() + .filter_map(dns_types::records::extract_ip) + .collect::>() + }) + .collect::>(); // Make them unique. + + Ok(Vec::from_iter(ips)) + } + } +} + +async fn send_query( + socket_factory: Arc>, + server: SocketAddr, + query: dns_types::Query, +) -> Result { + let bind_addr = match server { + SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + }; + + // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. + const BUF_SIZE: usize = 1500; + + let udp_socket = socket_factory + .bind(bind_addr) + .context("Failed to bind UDP socket")?; + + let response = tokio::time::timeout( + UdpDnsClient::TIMEOUT, + udp_socket.handshake::(server, &query.into_bytes()), + ) + .await + .with_context(|| format!("DNS query to host {server} timed out"))??; + + let response = dns_types::Response::parse(&response).context("Failed to parse DNS response")?; + + Ok(response) +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use super::*; + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn can_resolve_host() { + let client = UdpDnsClient::new( + Arc::new(socket_factory::udp), + vec![IpAddr::from([1, 1, 1, 1])], + ); + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(!ips.is_empty()) + } + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn times_out_unreachable_host() { + let client = UdpDnsClient::new( + Arc::new(socket_factory::udp), + vec![IpAddr::from([2, 2, 2, 2])], + ); + + let now = Instant::now(); + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(ips.is_empty()); + assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) + } + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn returns_all_valid_records() { + let client = UdpDnsClient::new( + Arc::new(socket_factory::udp), + vec![IpAddr::from([1, 1, 1, 1]), IpAddr::from([2, 2, 2, 2])], + ); + + let now = Instant::now(); + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(!ips.is_empty()); + assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) // Still need to wait for the unreachable server. + } +}