From d282b641c5bc2b83bd57ee3867426e50a4317121 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 14 Nov 2025 07:37:54 +1100 Subject: [PATCH] refactor(connlib): use `anyhow::Error` for recursive DNS (#10871) With the introduction of DoH, we will need a more advanced error type for recursive DNS responses. In particular, a DoH query might fail because the underlying TCP connection got closed. With #10856, the HTTP client no longer supports retries but instead needs to be recreated. In order to accurately detect this failure case, we need `anyhow`'s downcasting abilities. This PR prepares the already existing code for that by switching from `io::Error` to `anyhow::Error`. --- rust/connlib/tunnel/src/client.rs | 19 +++++++++---------- rust/connlib/tunnel/src/dns.rs | 3 +-- rust/connlib/tunnel/src/io.rs | 7 +++++-- rust/connlib/tunnel/src/io/nameserver_set.rs | 4 ++-- rust/connlib/tunnel/src/io/tcp_dns.rs | 7 ++++--- rust/connlib/tunnel/src/io/udp_dns.rs | 6 +++--- 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 6aa5273b0..3ecf331b4 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -31,7 +31,7 @@ use connlib_model::{ GatewayId, IceCandidate, PublicKey, RelayId, ResourceId, ResourceStatus, ResourceView, }; use connlib_model::{Site, SiteId}; -use firezone_logging::{err_with_src, unwrap_or_debug, unwrap_or_warn}; +use firezone_logging::{unwrap_or_debug, unwrap_or_warn}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; use ip_packet::{IpPacket, MAX_UDP_PAYLOAD}; @@ -508,7 +508,10 @@ impl ClientState { let _span = tracing::debug_span!("handle_dns_response", %qid, %server, local = %response.local, %domain).entered(); match (response.transport, response.message) { - (dns::Transport::Udp, Err(e)) if e.kind() == io::ErrorKind::TimedOut => { + (dns::Transport::Udp, Err(e)) + if e.downcast_ref::() + .is_some_and(|e| e.kind() == io::ErrorKind::TimedOut) => + { tracing::debug!("Recursive UDP DNS query timed out") } (dns::Transport::Udp, result) => { @@ -523,7 +526,7 @@ impl ClientState { self.dns_cache.insert(domain, message, now); }) .unwrap_or_else(|e| { - tracing::debug!("Recursive UDP DNS query failed: {}", err_with_src(&e)); + tracing::debug!("Recursive UDP DNS query failed: {e:#}"); dns_types::Response::servfail(&response.query) }); @@ -541,7 +544,7 @@ impl ClientState { self.dns_cache.insert(domain, message, now); }) .unwrap_or_else(|e| { - tracing::debug!("Recursive TCP DNS query failed: {}", err_with_src(&e)); + tracing::debug!("Recursive TCP DNS query failed: {e:#}"); dns_types::Response::servfail(&response.query) }); @@ -1187,9 +1190,7 @@ impl ClientState { local, remote, query: query_result.query, - message: query_result - .result - .map_err(|e| io::Error::other(format!("{e:#}"))), + message: query_result.result, transport: dns::Transport::Udp, }, now, @@ -1217,9 +1218,7 @@ impl ClientState { local, remote, query: query_result.query, - message: query_result - .result - .map_err(|e| io::Error::other(format!("{e:#}"))), + message: query_result.result, transport: dns::Transport::Tcp, }, now, diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 5ab316075..3d730bf18 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -9,7 +9,6 @@ use firezone_logging::err_with_src; use itertools::Itertools; use pattern::{Candidate, Pattern}; use std::collections::{BTreeSet, VecDeque}; -use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::{ collections::{BTreeMap, HashMap}, @@ -86,7 +85,7 @@ pub(crate) struct RecursiveResponse { pub query: dns_types::Query, /// The result of forwarding the DNS query. - pub message: io::Result, + pub message: Result, /// The transport we used. pub transport: Transport, diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index c5b535879..b6d64acea 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -55,7 +55,7 @@ pub struct Io { tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - dns_queries: FuturesTupleSet, DnsQueryMetaData>, + dns_queries: FuturesTupleSet, DnsQueryMetaData>, timeout: Option>>, @@ -311,7 +311,10 @@ impl Io { Err(e @ futures_bounded::Timeout { .. }) => dns::RecursiveResponse { server: meta.server, query: meta.query, - message: Err(io::Error::new(io::ErrorKind::TimedOut, e)), + message: Err(anyhow::Error::new(io::Error::new( + io::ErrorKind::TimedOut, + e, + ))), transport: meta.transport, local: meta.local, remote: meta.remote, diff --git a/rust/connlib/tunnel/src/io/nameserver_set.rs b/rust/connlib/tunnel/src/io/nameserver_set.rs index b6c363d7d..b3670cecf 100644 --- a/rust/connlib/tunnel/src/io/nameserver_set.rs +++ b/rust/connlib/tunnel/src/io/nameserver_set.rs @@ -1,12 +1,12 @@ use std::{ collections::{BTreeMap, BTreeSet}, - io, net::{IpAddr, SocketAddr}, sync::Arc, task::{Context, Poll, ready}, time::{Duration, Instant}, }; +use anyhow::Result; use dns_types::{DomainNameRef, Query, RecordType, ResponseCode, prelude::*}; use futures_bounded::FuturesTupleSet; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; @@ -27,7 +27,7 @@ pub struct NameserverSet { tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - queries: FuturesTupleSet, QueryMetaData>, + queries: FuturesTupleSet, QueryMetaData>, } struct QueryMetaData { diff --git a/rust/connlib/tunnel/src/io/tcp_dns.rs b/rust/connlib/tunnel/src/io/tcp_dns.rs index 1f5a52391..75d83b4c7 100644 --- a/rust/connlib/tunnel/src/io/tcp_dns.rs +++ b/rust/connlib/tunnel/src/io/tcp_dns.rs @@ -1,5 +1,6 @@ -use std::{io, net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc}; +use anyhow::Result; use socket_factory::{SocketFactory, TcpSocket}; use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; @@ -7,7 +8,7 @@ pub async fn send( factory: Arc>, server: SocketAddr, query: dns_types::Query, -) -> io::Result { +) -> Result { tracing::trace!(target: "wire::dns::recursive::tcp", %server, domain = %query.domain()); let tcp_socket = factory.bind(server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. @@ -27,7 +28,7 @@ pub async fn send( let mut response = vec![0u8; response_length]; tcp_stream.read_exact(&mut response).await?; - let message = dns_types::Response::parse(&response).map_err(io::Error::other)?; + let message = dns_types::Response::parse(&response)?; Ok(message) } diff --git a/rust/connlib/tunnel/src/io/udp_dns.rs b/rust/connlib/tunnel/src/io/udp_dns.rs index 14d0bea1c..62bb88d97 100644 --- a/rust/connlib/tunnel/src/io/udp_dns.rs +++ b/rust/connlib/tunnel/src/io/udp_dns.rs @@ -1,16 +1,16 @@ use std::{ - io, net::{Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, }; +use anyhow::Result; use socket_factory::{SocketFactory, UdpSocket}; pub async fn send( factory: Arc>, server: SocketAddr, query: dns_types::Query, -) -> io::Result { +) -> Result { tracing::trace!(target: "wire::dns::recursive::udp", %server, domain = %query.domain()); let bind_addr = match server { @@ -27,7 +27,7 @@ pub async fn send( .handshake::(server, &query.into_bytes()) .await?; - let response = dns_types::Response::parse(&response).map_err(io::Error::other)?; + let response = dns_types::Response::parse(&response)?; Ok(response) }