diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ffa30956d..a08cb31c0 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2027,6 +2027,7 @@ dependencies = [ "moka", "nix 0.29.0", "phoenix-channel", + "resolv-conf", "rustls", "secrecy", "serde", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 47ee4a2d0..293e2c3b7 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -93,6 +93,7 @@ rand_core = "0.6.4" rangemap = "1.5.1" rayon = "1.10.0" reqwest = { version = "0.12.9", default-features = false } +resolv-conf = "0.7.0" rtnetlink = { version = "0.14.1", default-features = false, features = ["tokio_socket"] } rustls = { version = "0.23.21", default-features = false, features = ["ring"] } sadness-generator = "0.6.0" diff --git a/rust/connlib/l4-tcp-dns-server/lib.rs b/rust/connlib/l4-tcp-dns-server/lib.rs index fa41fba62..2a05d41e5 100644 --- a/rust/connlib/l4-tcp-dns-server/lib.rs +++ b/rust/connlib/l4-tcp-dns-server/lib.rs @@ -118,9 +118,15 @@ impl Server { continue; }; - self.tcp_streams_by_remote.insert(from, stream); // Store the stream so we can send a response back later. + let local = stream.local_addr()?; + + // Store the stream so we can send a response back later. + // We don't need to index by the local address because we only ever listen on a single socket. + self.tcp_streams_by_remote.insert(from, stream); + return Poll::Ready(Ok(Query { - source: from, + local, + remote: from, message, })); } @@ -179,7 +185,8 @@ async fn read_tcp_query( } pub struct Query { - pub source: SocketAddr, + pub local: SocketAddr, + pub remote: SocketAddr, pub message: Message>, } @@ -220,7 +227,7 @@ mod tests { let query = poll_fn(|cx| server.poll(cx)).await.unwrap(); server - .send_response(query.source, empty_dns_response(query.message)) + .send_response(query.remote, empty_dns_response(query.message)) .unwrap(); } }); diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index a52454619..254a5bf8d 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -86,7 +86,7 @@ impl RecursiveQuery { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub(crate) enum Transport { Udp { /// The original source we received the DNS query on. diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 153eeb4fb..15bda0115 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,4 +1,7 @@ mod gso_queue; +mod nameserver_set; +mod tcp_dns; +mod udp_dns; use crate::{device_channel::Device, dns, sockets::Sockets}; use anyhow::Result; @@ -8,17 +11,17 @@ use futures::FutureExt as _; use futures_bounded::FuturesTupleSet; use gso_queue::GsoQueue; use ip_packet::{IpPacket, MAX_FZ_PAYLOAD}; +use nameserver_set::NameserverSet; use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket}; use std::{ - collections::VecDeque, + collections::{BTreeSet, VecDeque}, io, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, pin::Pin, sync::Arc, task::{ready, Context, Poll}, time::{Duration, Instant}, }; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::Instrument; use tun::Tun; @@ -45,6 +48,8 @@ pub struct Io { sockets: Sockets, gso_queue: GsoQueue, + nameservers: NameserverSet, + udp_dns_server: l4_udp_dns_server::Server, tcp_dns_server: l4_tcp_dns_server::Server, @@ -100,14 +105,19 @@ impl Io { pub fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, + nameservers: BTreeSet, ) -> Self { let mut sockets = Sockets::default(); sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context. + let mut nameservers = NameserverSet::new(nameservers, udp_socket_factory.clone()); + nameservers.evaluate(); + Self { outbound_packet_buffer: VecDeque::with_capacity(10), // It is unlikely that we process more than 10 packets after 1 GRO call. timeout: None, sockets, + nameservers, tcp_socket_factory, udp_socket_factory, dns_queries: FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000), @@ -136,6 +146,15 @@ impl Io { self.sockets.poll_has_sockets(cx) } + pub fn fastest_nameserver(&self) -> io::Result { + let ns = self + .nameservers + .fastest() + .ok_or(io::Error::other(NoNameserverAvailable))?; + + Ok(ns) + } + pub fn poll<'b>( &mut self, cx: &mut Context<'_>, @@ -146,6 +165,7 @@ impl Io { >, > { ready!(self.flush_send_queue(cx)?); + ready!(self.nameservers.poll(cx)); if let Poll::Ready(network) = self.sockets @@ -255,6 +275,7 @@ impl Io { self.sockets.rebind(self.udp_socket_factory.as_ref()); self.gso_queue.clear(); self.dns_queries = FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000); + self.nameservers.evaluate(); } pub fn reset_timeout(&mut self, timeout: Instant) { @@ -274,39 +295,19 @@ impl Io { } pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) { + let meta = DnsQueryMetaData { + query: query.message.clone(), + server: query.server, + transport: query.transport, + }; + match query.transport { dns::Transport::Udp { .. } => { - let factory = self.udp_socket_factory.clone(); - let server = query.server; - let bind_addr = match query.server { - SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), - SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - }; - let meta = DnsQueryMetaData { - query: query.message.clone(), - server, - transport: query.transport, - }; - if self .dns_queries .try_push( - async move { - // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. - const BUF_SIZE: usize = 1500; - - let udp_socket = factory(&bind_addr)?; - - let response = udp_socket - .handshake::(server, query.message.as_slice()) - .await?; - - let message = Message::from_octets(response) - .map_err(|_| io::Error::other("Failed to parse DNS message"))?; - - Ok(message) - } - .instrument(telemetry_span!("recursive_udp_dns_query")), + udp_dns::send(self.udp_socket_factory.clone(), query.server, query.message) + .instrument(telemetry_span!("recursive_udp_dns_query")), meta, ) .is_err() @@ -315,41 +316,11 @@ impl Io { } } dns::Transport::Tcp { .. } => { - let factory = self.tcp_socket_factory.clone(); - let server = query.server; - let meta = DnsQueryMetaData { - query: query.message.clone(), - server, - transport: query.transport, - }; - if self .dns_queries .try_push( - async move { - let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. - let mut tcp_stream = tcp_socket.connect(server).await?; - - let query = query.message.into_octets(); - let dns_message_length = (query.len() as u16).to_be_bytes(); - - tcp_stream.write_all(&dns_message_length).await?; - tcp_stream.write_all(&query).await?; - - let mut response_length = [0u8; 2]; - tcp_stream.read_exact(&mut response_length).await?; - let response_length = u16::from_be_bytes(response_length) as usize; - - // A u16 is at most 65k, meaning we are okay to allocate here based on what the remote is sending. - let mut response = vec![0u8; response_length]; - tcp_stream.read_exact(&mut response).await?; - - let message = Message::from_octets(response) - .map_err(|_| io::Error::other("Failed to parse DNS message"))?; - - Ok(message) - } - .instrument(telemetry_span!("recursive_tcp_dns_query")), + tcp_dns::send(self.tcp_socket_factory.clone(), query.server, query.message) + .instrument(telemetry_span!("recursive_tcp_dns_query")), meta, ) .is_err() @@ -377,6 +348,10 @@ impl Io { } } +#[derive(Debug, thiserror::Error)] +#[error("No nameserver available to handle DNS query")] +pub struct NoNameserverAvailable; + fn is_max_wg_packet_size(d: &DatagramIn) -> bool { let len = d.packet.len(); if len > MAX_FZ_PAYLOAD { @@ -446,6 +421,7 @@ mod tests { let mut io = Io::new( Arc::new(|_| Err(io::Error::other("not implemented"))), Arc::new(|_| Err(io::Error::other("not implemented"))), + BTreeSet::new(), ); io.set_tun(Box::new(DummyTun)); diff --git a/rust/connlib/tunnel/src/io/nameserver_set.rs b/rust/connlib/tunnel/src/io/nameserver_set.rs new file mode 100644 index 000000000..bc9ce28f9 --- /dev/null +++ b/rust/connlib/tunnel/src/io/nameserver_set.rs @@ -0,0 +1,167 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + io, + net::{IpAddr, SocketAddr}, + sync::{Arc, LazyLock}, + task::{ready, Context, Poll}, + time::{Duration, Instant}, +}; + +use connlib_model::DomainName; +use domain::base::{iana::Rcode, Message, MessageBuilder, Question, Rtype}; +use futures_bounded::FuturesTupleSet; +use socket_factory::{SocketFactory, UdpSocket}; + +use crate::io::udp_dns; + +const MAX_DNS_SERVERS: usize = 10; // We don't bother selecting from more than 10 servers. +const DNS_TIMEOUT: Duration = Duration::from_secs(2); // Every sensible DNS servers should respond within 2s. + +static FIREZONE_DEV: LazyLock = LazyLock::new(|| { + DomainName::vec_from_str("firezone.dev").expect("static domain should always parse") +}); + +pub struct NameserverSet { + inner: BTreeSet, + nameserver_by_rtt: BTreeMap, + + socket_factory: Arc>, + queries: FuturesTupleSet>>, QueryMetaData>, +} + +struct QueryMetaData { + nameserver: IpAddr, + start: Instant, +} + +impl NameserverSet { + pub fn new( + inner: BTreeSet, + udp_socket_factory: Arc>, + ) -> Self { + Self { + queries: FuturesTupleSet::new(DNS_TIMEOUT, MAX_DNS_SERVERS), + inner, + socket_factory: udp_socket_factory, + nameserver_by_rtt: Default::default(), + } + } + + pub fn evaluate(&mut self) { + self.nameserver_by_rtt.clear(); + let start = Instant::now(); + + for nameserver in self.inner.iter().copied() { + if self + .queries + .try_push( + udp_dns::send( + self.socket_factory.clone(), + SocketAddr::new(nameserver, crate::dns::DNS_PORT), + query_firezone_dev(), + ), + QueryMetaData { nameserver, start }, + ) + .is_err() + { + tracing::debug!(%nameserver, "Failed to queue another DNS query"); + } + } + } + + pub fn fastest(&self) -> Option { + let (_, ns) = self.nameserver_by_rtt.first_key_value()?; + + Some(*ns) + } + + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.queries.is_empty() { + return Poll::Ready(()); + } + + loop { + match ready!(self.queries.poll_unpin(cx)) { + (Ok(Ok(response)), meta) if response.header().rcode() == Rcode::NOERROR => { + let rtt = meta.start.elapsed(); + + tracing::debug!(nameserver = %meta.nameserver, ?rtt, ?response, "DNS query completed"); + + self.nameserver_by_rtt.insert(rtt, meta.nameserver); + } + (Ok(Ok(response)), meta) => { + tracing::debug!(nameserver = %meta.nameserver, ?response, "DNS query failed"); + } + (Ok(Err(e)), meta) => { + tracing::debug!(nameserver = %meta.nameserver, "DNS query failed: {e}"); + } + (Err(_), meta) => { + tracing::debug!(nameserver = %meta.nameserver, "DNS query timed out after {DNS_TIMEOUT:?}"); + } + } + + let Some(fastest) = self.fastest() else { + continue; + }; + + if self.queries.is_empty() { + tracing::info!(%fastest, "Evaluated fastest nameserver"); + + return Poll::Ready(()); + } + } + } +} + +fn query_firezone_dev() -> Message> { + let mut builder = MessageBuilder::new_vec().question(); + builder.header_mut().set_random_id(); + builder.header_mut().set_rd(true); + builder.header_mut().set_qr(false); + + builder + .push(Question::new_in(FIREZONE_DEV.clone(), Rtype::A)) + .expect("static question should always be valid"); + + builder.into_message() +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use super::*; + + #[tokio::test] + #[ignore = "Needs Internet"] + async fn can_evaluate_fastest_nameserver() { + let _guard = firezone_logging::test("debug"); + + let mut set = NameserverSet::new( + BTreeSet::from([ + Ipv4Addr::new(1, 1, 1, 1).into(), + Ipv4Addr::new(8, 8, 8, 8).into(), + Ipv4Addr::new(8, 8, 4, 4).into(), + Ipv4Addr::new(9, 9, 9, 9).into(), + Ipv4Addr::new(100, 100, 100, 100).into(), // Also include an unreachable server. + ]), + Arc::new(socket_factory::udp), + ); + set.evaluate(); + + std::future::poll_fn(|cx| set.poll(cx)).await; + + assert!(set.fastest().is_some()); + } + + #[tokio::test] + async fn can_handle_no_servers() { + let _guard = firezone_logging::test("debug"); + + let mut set = NameserverSet::new(BTreeSet::default(), Arc::new(socket_factory::udp)); + + std::future::poll_fn(|cx| set.poll(cx)).await; + + assert!(set.fastest().is_none()); + } +} diff --git a/rust/connlib/tunnel/src/io/tcp_dns.rs b/rust/connlib/tunnel/src/io/tcp_dns.rs new file mode 100644 index 000000000..120baf443 --- /dev/null +++ b/rust/connlib/tunnel/src/io/tcp_dns.rs @@ -0,0 +1,41 @@ +use std::{io, net::SocketAddr, sync::Arc}; + +use domain::base::{Message, ToName as _}; +use socket_factory::{SocketFactory, TcpSocket}; +use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; + +pub async fn send( + factory: Arc>, + server: SocketAddr, + query: Message>, +) -> io::Result>> { + let domain = query + .sole_question() + .expect("all queries should be for a single name") + .qname() + .to_vec(); + + tracing::trace!(target: "wire::dns::recursive::tcp", %server, %domain); + + let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. + let mut tcp_stream = tcp_socket.connect(server).await?; + + let query = query.into_octets(); + let dns_message_length = (query.len() as u16).to_be_bytes(); + + tcp_stream.write_all(&dns_message_length).await?; + tcp_stream.write_all(&query).await?; + + let mut response_length = [0u8; 2]; + tcp_stream.read_exact(&mut response_length).await?; + let response_length = u16::from_be_bytes(response_length) as usize; + + // A u16 is at most 65k, meaning we are okay to allocate here based on what the remote is sending. + let mut response = vec![0u8; response_length]; + tcp_stream.read_exact(&mut response).await?; + + let message = Message::from_octets(response) + .map_err(|_| io::Error::other("Failed to parse DNS message"))?; + + Ok(message) +} diff --git a/rust/connlib/tunnel/src/io/udp_dns.rs b/rust/connlib/tunnel/src/io/udp_dns.rs new file mode 100644 index 000000000..c926d76a5 --- /dev/null +++ b/rust/connlib/tunnel/src/io/udp_dns.rs @@ -0,0 +1,40 @@ +use std::{ + io, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::Arc, +}; + +use domain::base::{Message, ToName as _}; +use socket_factory::{SocketFactory, UdpSocket}; + +pub async fn send( + factory: Arc>, + server: SocketAddr, + query: Message>, +) -> io::Result>> { + let domain = query + .sole_question() + .expect("all queries should be for a single name") + .qname() + .to_vec(); + let bind_addr = match server { + SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + }; + + tracing::trace!(target: "wire::dns::recursive::udp", %server, %domain); + + // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. + const BUF_SIZE: usize = 1500; + + let udp_socket = factory(&bind_addr)?; + + let response = udp_socket + .handshake::(server, query.as_slice()) + .await?; + + let message = Message::from_octets(response) + .map_err(|_| io::Error::other("Failed to parse DNS message"))?; + + Ok(message) +} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 5131b6e72..ace40c69b 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -66,6 +66,7 @@ pub type ClientTunnel = Tunnel; pub use client::ClientState; pub use gateway::{DnsResourceNatEntry, GatewayState, ResolveDnsRequest}; +pub use io::NoNameserverAvailable; pub use utils::turn; /// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway. @@ -107,7 +108,11 @@ impl ClientTunnel { udp_socket_factory: Arc>, ) -> Self { Self { - io: Io::new(tcp_socket_factory, udp_socket_factory), + io: Io::new( + tcp_socket_factory, + udp_socket_factory.clone(), + BTreeSet::default(), + ), role_state: ClientState::new(rand::random(), Instant::now()), buffers: Buffers::default(), } @@ -216,9 +221,10 @@ impl GatewayTunnel { pub fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, + nameservers: BTreeSet, ) -> Self { Self { - io: Io::new(tcp_socket_factory, udp_socket_factory), + io: Io::new(tcp_socket_factory, udp_socket_factory.clone(), nameservers), role_state: GatewayState::new(rand::random(), Instant::now()), buffers: Buffers::default(), } @@ -246,8 +252,23 @@ impl GatewayTunnel { } match self.io.poll(cx, &mut self.buffers)? { - Poll::Ready(io::Input::DnsResponse(_)) => { - unreachable!("Gateway doesn't use user-space DNS resolution") + Poll::Ready(io::Input::DnsResponse(response)) => { + let message = response.message.unwrap_or_else(|e| { + tracing::debug!("DNS query failed: {e}"); + + dns::servfail(response.query.for_slice_ref()) + }); + + match response.transport { + dns::Transport::Udp { source } => { + self.io.send_udp_dns_response(source, message)?; + } + dns::Transport::Tcp { remote, .. } => { + self.io.send_tcp_dns_response(remote, message)?; + } + } + + continue; } Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout, Utc::now()); @@ -296,14 +317,25 @@ impl GatewayTunnel { continue; } - Poll::Ready(io::Input::UdpDnsQuery(query)) => self.io.send_udp_dns_response( - query.source, - dns::servfail(query.message.for_slice_ref()), - )?, - Poll::Ready(io::Input::TcpDnsQuery(query)) => self.io.send_tcp_dns_response( - query.source, - dns::servfail(query.message.for_slice_ref()), - )?, + Poll::Ready(io::Input::UdpDnsQuery(query)) => { + let nameserver = self.io.fastest_nameserver()?; + + self.io.send_dns_query(dns::RecursiveQuery::via_udp( + query.source, + SocketAddr::new(nameserver, dns::DNS_PORT), + query.message.for_slice_ref(), + )); + } + Poll::Ready(io::Input::TcpDnsQuery(query)) => { + let nameserver = self.io.fastest_nameserver()?; + + self.io.send_dns_query(dns::RecursiveQuery::via_tcp( + query.local, + query.remote, + SocketAddr::new(nameserver, dns::DNS_PORT), + query.message, + )); + } Poll::Pending => {} } diff --git a/rust/gateway/Cargo.toml b/rust/gateway/Cargo.toml index ffb42a971..bcc01a268 100644 --- a/rust/gateway/Cargo.toml +++ b/rust/gateway/Cargo.toml @@ -29,6 +29,7 @@ libc = { workspace = true, features = ["std", "const-extern-fn", "extra_traits"] moka = { workspace = true, features = ["future"] } nix = { workspace = true } phoenix-channel = { workspace = true } +resolv-conf = { workspace = true } rustls = { workspace = true } secrecy = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 2408fe3a2..ceecfe868 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -123,6 +123,12 @@ impl Eventloop { continue; } + if e.root_cause() + .is::() + { + return Poll::Ready(Err(Error::NoNameserversAvailable(e))); + } + telemetry_event!("Tunnel error: {e:#}"); continue; } @@ -544,6 +550,8 @@ pub enum Error { UpdateTun(#[source] anyhow::Error), #[error("{0:#}")] BindDnsSockets(#[source] anyhow::Error), + #[error("{0:#}")] + NoNameserversAvailable(#[source] anyhow::Error), } async fn resolve(domain: DomainName) -> Result> { diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 3032ad8b4..a4e96e4f6 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -16,10 +16,10 @@ use phoenix_channel::LoginUrl; use futures::{future, TryFutureExt}; use phoenix_channel::PhoenixChannel; use secrecy::{Secret, SecretString}; -use std::path::Path; use std::pin::pin; use std::process::ExitCode; use std::sync::Arc; +use std::{collections::BTreeSet, path::Path}; use tokio::io::AsyncWriteExt; use tokio::signal::ctrl_c; use tracing_subscriber::layer; @@ -113,7 +113,21 @@ async fn try_main(cli: Cli) -> Result { ) .context("Failed to construct URL for logging into portal")?; - let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory)); + let resolv_conf = resolv_conf::Config::parse( + std::fs::read_to_string("/etc/resolv.conf").context("Failed to read /etc/resolv.conf")?, + ) + .context("Failed to parse /etc/resolv.conf")?; + let nameservers = resolv_conf + .nameservers + .into_iter() + .map(|ip| ip.into()) + .collect::>(); + + let mut tunnel = GatewayTunnel::new( + Arc::new(tcp_socket_factory), + Arc::new(udp_socket_factory), + nameservers, + ); let portal = PhoenixChannel::disconnected( Secret::new(login), get_user_agent(None, env!("CARGO_PKG_VERSION")),