refactor(connlib): use anyhow::Error for recursive DNS (#10871)

With the introduction of DoH, we will need a more advanced error type
for recursive DNS responses. In particular, a DoH query might fail
because the underlying TCP connection got closed. With #10856, the HTTP
client no longer supports retries but instead needs to be recreated.

In order to accurately detect this failure case, we need `anyhow`'s
downcasting abilities.

This PR prepares the already existing code for that by switching from
`io::Error` to `anyhow::Error`.
This commit is contained in:
Thomas Eizinger
2025-11-14 07:37:54 +11:00
committed by GitHub
parent 8f6f6666a1
commit d282b641c5
6 changed files with 24 additions and 22 deletions

View File

@@ -31,7 +31,7 @@ use connlib_model::{
GatewayId, IceCandidate, PublicKey, RelayId, ResourceId, ResourceStatus, ResourceView,
};
use connlib_model::{Site, SiteId};
use firezone_logging::{err_with_src, unwrap_or_debug, unwrap_or_warn};
use firezone_logging::{unwrap_or_debug, unwrap_or_warn};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
use ip_packet::{IpPacket, MAX_UDP_PAYLOAD};
@@ -508,7 +508,10 @@ impl ClientState {
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, local = %response.local, %domain).entered();
match (response.transport, response.message) {
(dns::Transport::Udp, Err(e)) if e.kind() == io::ErrorKind::TimedOut => {
(dns::Transport::Udp, Err(e))
if e.downcast_ref::<io::Error>()
.is_some_and(|e| e.kind() == io::ErrorKind::TimedOut) =>
{
tracing::debug!("Recursive UDP DNS query timed out")
}
(dns::Transport::Udp, result) => {
@@ -523,7 +526,7 @@ impl ClientState {
self.dns_cache.insert(domain, message, now);
})
.unwrap_or_else(|e| {
tracing::debug!("Recursive UDP DNS query failed: {}", err_with_src(&e));
tracing::debug!("Recursive UDP DNS query failed: {e:#}");
dns_types::Response::servfail(&response.query)
});
@@ -541,7 +544,7 @@ impl ClientState {
self.dns_cache.insert(domain, message, now);
})
.unwrap_or_else(|e| {
tracing::debug!("Recursive TCP DNS query failed: {}", err_with_src(&e));
tracing::debug!("Recursive TCP DNS query failed: {e:#}");
dns_types::Response::servfail(&response.query)
});
@@ -1187,9 +1190,7 @@ impl ClientState {
local,
remote,
query: query_result.query,
message: query_result
.result
.map_err(|e| io::Error::other(format!("{e:#}"))),
message: query_result.result,
transport: dns::Transport::Udp,
},
now,
@@ -1217,9 +1218,7 @@ impl ClientState {
local,
remote,
query: query_result.query,
message: query_result
.result
.map_err(|e| io::Error::other(format!("{e:#}"))),
message: query_result.result,
transport: dns::Transport::Tcp,
},
now,

View File

@@ -9,7 +9,6 @@ use firezone_logging::err_with_src;
use itertools::Itertools;
use pattern::{Candidate, Pattern};
use std::collections::{BTreeSet, VecDeque};
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::{
collections::{BTreeMap, HashMap},
@@ -86,7 +85,7 @@ pub(crate) struct RecursiveResponse {
pub query: dns_types::Query,
/// The result of forwarding the DNS query.
pub message: io::Result<dns_types::Response>,
pub message: Result<dns_types::Response>,
/// The transport we used.
pub transport: Transport,

View File

@@ -55,7 +55,7 @@ pub struct Io {
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
dns_queries: FuturesTupleSet<io::Result<dns_types::Response>, DnsQueryMetaData>,
dns_queries: FuturesTupleSet<Result<dns_types::Response>, DnsQueryMetaData>,
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
@@ -311,7 +311,10 @@ impl Io {
Err(e @ futures_bounded::Timeout { .. }) => dns::RecursiveResponse {
server: meta.server,
query: meta.query,
message: Err(io::Error::new(io::ErrorKind::TimedOut, e)),
message: Err(anyhow::Error::new(io::Error::new(
io::ErrorKind::TimedOut,
e,
))),
transport: meta.transport,
local: meta.local,
remote: meta.remote,

View File

@@ -1,12 +1,12 @@
use std::{
collections::{BTreeMap, BTreeSet},
io,
net::{IpAddr, SocketAddr},
sync::Arc,
task::{Context, Poll, ready},
time::{Duration, Instant},
};
use anyhow::Result;
use dns_types::{DomainNameRef, Query, RecordType, ResponseCode, prelude::*};
use futures_bounded::FuturesTupleSet;
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
@@ -27,7 +27,7 @@ pub struct NameserverSet {
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
queries: FuturesTupleSet<io::Result<dns_types::Response>, QueryMetaData>,
queries: FuturesTupleSet<Result<dns_types::Response>, QueryMetaData>,
}
struct QueryMetaData {

View File

@@ -1,5 +1,6 @@
use std::{io, net::SocketAddr, sync::Arc};
use std::{net::SocketAddr, sync::Arc};
use anyhow::Result;
use socket_factory::{SocketFactory, TcpSocket};
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
@@ -7,7 +8,7 @@ pub async fn send(
factory: Arc<dyn SocketFactory<TcpSocket>>,
server: SocketAddr,
query: dns_types::Query,
) -> io::Result<dns_types::Response> {
) -> Result<dns_types::Response> {
tracing::trace!(target: "wire::dns::recursive::tcp", %server, domain = %query.domain());
let tcp_socket = factory.bind(server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
@@ -27,7 +28,7 @@ pub async fn send(
let mut response = vec![0u8; response_length];
tcp_stream.read_exact(&mut response).await?;
let message = dns_types::Response::parse(&response).map_err(io::Error::other)?;
let message = dns_types::Response::parse(&response)?;
Ok(message)
}

View File

@@ -1,16 +1,16 @@
use std::{
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
};
use anyhow::Result;
use socket_factory::{SocketFactory, UdpSocket};
pub async fn send(
factory: Arc<dyn SocketFactory<UdpSocket>>,
server: SocketAddr,
query: dns_types::Query,
) -> io::Result<dns_types::Response> {
) -> Result<dns_types::Response> {
tracing::trace!(target: "wire::dns::recursive::udp", %server, domain = %query.domain());
let bind_addr = match server {
@@ -27,7 +27,7 @@ pub async fn send(
.handshake::<BUF_SIZE>(server, &query.into_bytes())
.await?;
let response = dns_types::Response::parse(&response).map_err(io::Error::other)?;
let response = dns_types::Response::parse(&response)?;
Ok(response)
}