refactor(connlib): use dedicated UDP DNS client (#10850)

By default, DNS queries are sent over UDP by most systems. UDP is an
easy to understand protocol because each packet stands by itself and at
least as far as UDP is concerned, the payload is contained within a
single packet.

In Firezone, we receive all DNS traffic on the TUN device as IP packets.
Processing the UDP packets is trivial as each query is contained within
a single IP packet. For TCP, we first need to assemble the TCP stream
before we can read the entire query.

In case a DNS query is not for a Firezone DNS resource, we want to
forward it to the specified upstream resolver, either directly from the
system or - in case the specified upstream resolver is an IP resource -
through the tunnel as an IP packet. Specifically, the forwarding of UDP
DNS packets through the tunnel currently happens like this:

IP packet -> read UDP payload -> parse DNS query -> mangle original
destination IP to new upstream -> send through tunnel

For TCP DNS queries, it is not quite as easy as we have to decode the
incoming TCP stream first before we can parse the DNS query. Thus, when
we want to then forward the query, we need to open our own TCP stream to
the upstream resolver and encode the DNS query onto that stream, sending
each IP packet from the TCP client through the tunnel.

The difference in these designs makes several code paths in connlib hard
to follow.

Therefore - and despite the simplicity of DNS over UDP - we already
created our own "Layer 3 UDP DNS"-client. This PR now integrates this
client into the tunnel. Using this new client, we can simplify the
processing of UDP DNS queries because we never have to "go back" to the
original IP packet. Instead, when a DNS query needs to be forwarded to
an usptream resolver through the tunnel, we simply tell the Layer 3 UDP
DNS client to make a new DNS query. The processing of the resulting IP
packet then happens in a different place, right next to where we also
process the IP packets of the TCP DNS client.

That simplifications unlocks further refactorings where we now only
process DNS queries in a single place and the transport we received it
over is a simple function parameter with the control flow for both of
them being identical.

Related: #4668
This commit is contained in:
Thomas Eizinger
2025-11-11 14:53:25 +11:00
committed by GitHub
parent de7d3bff89
commit 0008539b65
8 changed files with 211 additions and 356 deletions

1
rust/Cargo.lock generated
View File

@@ -2688,6 +2688,7 @@ dependencies = [
"ip_network_table",
"itertools 0.14.0",
"l3-tcp",
"l3-udp-dns-client",
"l4-tcp-dns-server",
"l4-udp-dns-server",
"lru",

View File

@@ -8,7 +8,7 @@ use anyhow::{Context as _, Result, anyhow, bail};
use ip_packet::IpPacket;
use rand::{Rng, SeedableRng, rngs::StdRng};
const TIMEOUT: Duration = Duration::from_secs(5);
const TIMEOUT: Duration = Duration::from_secs(30);
/// A sans-io DNS-over-UDP client.
pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {

View File

@@ -89,7 +89,7 @@ impl Server {
}
if let Poll::Ready(Some(result)) = self.reading_udp_queries.poll_next_unpin(cx) {
let (from, message) = result
let (remote, message) = result
.context("Failed to read UDP DNS query")
.map_err(anyhow_to_io)?;
@@ -102,7 +102,7 @@ impl Server {
return Poll::Ready(Ok(Query {
local,
from,
remote,
message,
}));
}
@@ -144,7 +144,7 @@ async fn read_udp_query(socket: Arc<UdpSocket>) -> Result<(SocketAddr, dns_types
pub struct Query {
pub local: SocketAddr,
pub from: SocketAddr,
pub remote: SocketAddr,
pub message: dns_types::Query,
}
@@ -192,7 +192,7 @@ mod tests {
let query = poll_fn(|cx| server.poll(cx)).await.unwrap();
server
.send_response(query.from, dns_types::Response::no_error(&query.message))
.send_response(query.remote, dns_types::Response::no_error(&query.message))
.unwrap();
}
});
@@ -224,7 +224,7 @@ mod tests {
let query = poll_fn(|cx| server.poll(cx)).await.unwrap();
server
.send_response(query.from, dns_types::Response::no_error(&query.message))
.send_response(query.remote, dns_types::Response::no_error(&query.message))
.unwrap();
}
});

View File

@@ -17,7 +17,7 @@ bufferpool = { workspace = true }
bytes = { workspace = true, features = ["std"] }
chrono = { workspace = true }
connlib-model = { workspace = true }
derive_more = { workspace = true, features = ["debug", "from"] }
derive_more = { workspace = true, features = ["debug", "from", "display"] }
divan = { workspace = true, optional = true }
dns-over-tcp = { workspace = true }
dns-types = { workspace = true }
@@ -33,6 +33,7 @@ ip-packet = { workspace = true }
ip_network = { workspace = true }
ip_network_table = { workspace = true }
itertools = { workspace = true, features = ["use_std"] }
l3-udp-dns-client = { workspace = true }
l4-tcp-dns-server = { workspace = true }
l4-udp-dns-server = { workspace = true }
lru = { workspace = true }

View File

@@ -21,7 +21,6 @@ use secrecy::ExposeSecret as _;
use crate::client::dns_cache::DnsCache;
use crate::dns::{DnsResourceRecord, StubResolver};
use crate::expiring_map::{self, ExpiringMap};
use crate::messages::Interface as InterfaceConfig;
use crate::messages::{IceCredentials, SecretKey};
use crate::peer_store::PeerStore;
@@ -80,10 +79,6 @@ pub(crate) const DNS_SENTINELS_V6: Ipv6Network = match Ipv6Network::new(
Err(_) => unreachable!(),
};
// The max time a dns request can be configured to live in resolvconf
// is 30 seconds. See resolvconf(5) timeout.
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
/// How many gateways we at most remember that we connected to.
///
/// 100 has been chosen as a pretty arbitrary value.
@@ -122,8 +117,6 @@ pub struct ClientState {
/// Manages the DNS configuration.
dns_config: DnsConfig,
/// UDP DNS queries that had their destination IP mangled to redirect them to another DNS resolver through the tunnel.
udp_dns_sockets_by_upstream_and_query_id: ExpiringMap<(SocketAddr, u16), SocketAddr>,
/// Manages internal dns records and emits forwarding event when not internally handled
stub_resolver: StubResolver,
/// Caches responses from DNS servers.
@@ -132,10 +125,12 @@ pub struct ClientState {
/// Configuration of the TUN device, when it is up.
tun_config: Option<TunConfig>,
udp_dns_client: l3_udp_dns_client::Client,
tcp_dns_client: dns_over_tcp::Client,
tcp_dns_server: dns_over_tcp::Server,
/// Tracks the TCP stream (i.e. socket-pair) on which we received a TCP DNS query by the ID of the recursive DNS query we issued.
tcp_dns_streams_by_upstream_and_query_id: HashMap<(SocketAddr, u16), (SocketAddr, SocketAddr)>,
/// Tracks the UDP/TCP stream (i.e. socket-pair) on which we received a DNS query by the ID of the recursive DNS query we issued.
dns_streams_by_upstream_and_query_id:
HashMap<(dns::Transport, SocketAddr, u16), (SocketAddr, SocketAddr)>,
/// Stores the gateways we recently connected to.
///
@@ -152,8 +147,7 @@ pub struct ClientState {
struct PendingFlow {
last_intent_sent_at: Instant,
resource_packets: UniquePacketBuffer,
udp_dns_queries: AllocRingBuffer<IpPacket>,
tcp_dns_queries: AllocRingBuffer<dns_over_tcp::Query>,
dns_queries: AllocRingBuffer<DnsQueryForSite>,
}
impl PendingFlow {
@@ -170,8 +164,7 @@ impl PendingFlow {
Self::CAPACITY_POW_2,
"pending-flow-resources",
),
udp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2),
tcp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2),
dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2),
};
this.push(trigger);
@@ -181,11 +174,8 @@ impl PendingFlow {
fn push(&mut self, trigger: ConnectionTrigger) {
match trigger {
ConnectionTrigger::PacketForResource(packet) => self.resource_packets.push(packet),
ConnectionTrigger::UdpDnsQueryForSite(packet) => {
self.udp_dns_queries.enqueue(packet);
}
ConnectionTrigger::TcpDnsQueryForSite(query) => {
self.tcp_dns_queries.enqueue(query);
ConnectionTrigger::DnsQueryForSite(query) => {
self.dns_queries.enqueue(query);
}
ConnectionTrigger::IcmpDestinationUnreachableProhibited => {}
}
@@ -211,16 +201,16 @@ impl ClientState {
node: ClientNode::new(seed, now),
sites_status: Default::default(),
gateways_site: Default::default(),
udp_dns_sockets_by_upstream_and_query_id: Default::default(),
stub_resolver: StubResolver::new(records),
dns_cache: Default::default(),
buffered_transmits: Default::default(),
is_internet_resource_active,
recently_connected_gateways: LruCache::new(MAX_REMEMBERED_GATEWAYS),
buffered_dns_queries: Default::default(),
udp_dns_client: l3_udp_dns_client::Client::new(now, seed),
tcp_dns_client: dns_over_tcp::Client::new(now, seed),
tcp_dns_server: dns_over_tcp::Server::new(now),
tcp_dns_streams_by_upstream_and_query_id: Default::default(),
dns_streams_by_upstream_and_query_id: Default::default(),
pending_flows: Default::default(),
dns_resource_nat: Default::default(),
pending_tun_update: Default::default(),
@@ -442,6 +432,11 @@ impl ClientState {
.inspect_err(|e| tracing::debug!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate: {e:#}"))
.ok()??;
if self.udp_dns_client.accepts(&packet) {
self.udp_dns_client.handle_inbound(packet);
return None;
}
if self.tcp_dns_client.accepts(&packet) {
self.tcp_dns_client.handle_inbound(packet);
return None;
@@ -488,13 +483,6 @@ impl ClientState {
.inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}"))
.ok()?;
let packet = maybe_mangle_dns_response_from_upstream_dns_server(
packet,
&mut self.udp_dns_sockets_by_upstream_and_query_id,
&mut self.dns_cache,
now,
);
if feature_flags::icmp_error_unreachable_prohibited_create_new_flow()
&& let Ok(Some((failed_packet, error))) = packet.icmp_error()
&& error.is_unreachable_prohibited()
@@ -517,7 +505,7 @@ impl ClientState {
let server = response.server;
let domain = response.query.domain();
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, %domain).entered();
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, local = %response.local, %domain).entered();
match (response.transport, response.message) {
(dns::Transport::Udp, Err(e)) if e.kind() == io::ErrorKind::TimedOut => {
@@ -753,7 +741,7 @@ impl ClientState {
// If we are making this connection because we want to send a DNS query to the Gateway,
// mark it as "used" through the DNS resource ID.
if !pending_flow.udp_dns_queries.is_empty() || !pending_flow.tcp_dns_queries.is_empty() {
if !pending_flow.dns_queries.is_empty() {
self.peers.add_ips_with_resource(
&gid,
[
@@ -765,34 +753,19 @@ impl ClientState {
}
// 2. Buffered UDP DNS queries for the Gateway
for packet in pending_flow.udp_dns_queries {
for query in pending_flow.dns_queries {
let gateway = self.peers.get(&gid).context("Unknown peer")?; // If this error happens we have a bug: We just inserted it above.
let upstream = gateway.tun_dns_server_endpoint(packet.destination());
let packet =
self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet);
let upstream = gateway.tun_dns_server_endpoint(query.local.ip());
encapsulate_and_buffer(
packet,
gid,
self.forward_dns_query_to_new_upstream_via_tunnel(
query.local,
query.remote,
upstream,
query.message,
query.transport,
now,
&mut self.node,
&mut self.buffered_transmits,
)
}
// 3. Buffered TCP DNS queries for the Gateway
for query in pending_flow.tcp_dns_queries {
let server = match query.local {
SocketAddr::V4(_) => {
SocketAddr::new(gateway_tun.v4.into(), crate::gateway::TUN_DNS_PORT)
}
SocketAddr::V6(_) => {
SocketAddr::new(gateway_tun.v6.into(), crate::gateway::TUN_DNS_PORT)
}
};
self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query);
);
}
Ok(Ok(()))
@@ -832,7 +805,9 @@ impl ClientState {
return ControlFlow::Break(());
}
self.handle_udp_dns_query(upstream, packet, now)
self.handle_udp_dns_query(upstream, packet, now);
ControlFlow::Break(())
}
pub fn on_connection_failed(&mut self, resource: ResourceId) {
@@ -903,6 +878,8 @@ impl ClientState {
return;
};
self.udp_dns_client
.set_source_interface(tun_config.ip.v4, tun_config.ip.v6);
self.tcp_dns_client
.set_source_interface(tun_config.ip.v4, tun_config.ip.v6);
self.tcp_dns_client.reset();
@@ -1104,9 +1081,9 @@ impl ClientState {
pub fn poll_timeout(&mut self) -> Option<(Instant, &'static str)> {
iter::empty()
.chain(
self.udp_dns_sockets_by_upstream_and_query_id
self.udp_dns_client
.poll_timeout()
.map(|instant| (instant, "DNS socket timeout")),
.map(|instant| (instant, "UDP DNS client")),
)
.chain(
self.dns_cache
@@ -1131,43 +1108,57 @@ impl ClientState {
self.node.handle_timeout(now);
self.drain_node_events();
self.udp_dns_sockets_by_upstream_and_query_id
.handle_timeout(now);
while let Some(event) = self.udp_dns_sockets_by_upstream_and_query_id.poll_event() {
let expiring_map::Event::EntryExpired { key, value } = event;
tracing::debug!(
?key,
?value,
"Mapping entry for forwarded DNS query expired"
);
}
self.advance_dns_tcp_sockets(now);
self.advance_dns_clients_and_servers(now);
self.send_dns_resource_nat_packets(now);
self.dns_cache.handle_timeout(now);
}
/// Advance the TCP DNS server and client state machines.
/// Advance the DNS server and client state machines.
///
/// Receiving something on a TCP server socket may trigger packets to be sent on the TCP client socket and vice versa.
/// Receiving something on a UDP/TCP server socket may trigger packets to be sent on the UDP/TCP client socket and vice versa.
/// Therefore, we loop here until non of the `poll-X` functions return anything anymore.
fn advance_dns_tcp_sockets(&mut self, now: Instant) {
fn advance_dns_clients_and_servers(&mut self, now: Instant) {
loop {
self.tcp_dns_server.handle_timeout(now);
self.tcp_dns_client.handle_timeout(now);
self.udp_dns_client.handle_timeout(now);
// Check if have any pending TCP DNS queries.
if let Some(query) = self.tcp_dns_server.poll_queries() {
self.handle_tcp_dns_query(query, now);
let Some(upstream) = self
.dns_config
.mapping()
.upstream_by_sentinel(query.local.ip())
else {
// This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request.
continue;
};
if let Some(response) = self.handle_dns_query(
query.message,
query.local,
query.remote,
upstream,
dns::Transport::Tcp,
now,
) {
unwrap_or_debug!(
self.tcp_dns_server
.send_message(query.local, query.remote, response),
"Failed to send TCP DNS response: {}"
);
}
continue;
}
// Check if the client wants to emit any packets.
if let Some(packet) = self.tcp_dns_client.poll_outbound() {
// All packets from the TCP DNS client _should_ go through the tunnel.
// Check if the clients wants to emit any packets.
if let Some(packet) = self
.tcp_dns_client
.poll_outbound()
.or_else(|| self.udp_dns_client.poll_outbound())
{
// All packets from the DNS clients _should_ go through the tunnel.
let Some(transmit) = self.encapsulate(packet, now) else {
continue;
};
@@ -1176,13 +1167,45 @@ impl ClientState {
continue;
}
// Check if the client has assembled a response to a query.
// Check if the UDP DNS client has assembled a response to a query.
if let Some(query_result) = self.udp_dns_client.poll_query_result() {
let server = query_result.server;
let qid = query_result.query.id();
let known_sockets = &mut self.dns_streams_by_upstream_and_query_id;
let Some((local, remote)) =
known_sockets.remove(&(dns::Transport::Udp, server, qid))
else {
tracing::warn!(?known_sockets, %server, %qid, "Failed to find UDP socket handle for query result");
continue;
};
self.handle_dns_response(
dns::RecursiveResponse {
server,
local,
remote,
query: query_result.query,
message: query_result
.result
.map_err(|e| io::Error::other(format!("{e:#}"))),
transport: dns::Transport::Udp,
},
now,
);
continue;
}
// Check if the TCP DNS client has assembled a response to a query.
if let Some(query_result) = self.tcp_dns_client.poll_query_result() {
let server = query_result.server;
let qid = query_result.query.id();
let known_sockets = &mut self.tcp_dns_streams_by_upstream_and_query_id;
let known_sockets = &mut self.dns_streams_by_upstream_and_query_id;
let Some((local, remote)) = known_sockets.remove(&(server, qid)) else {
let Some((local, remote)) =
known_sockets.remove(&(dns::Transport::Tcp, server, qid))
else {
tracing::warn!(?known_sockets, %server, %qid, "Failed to find TCP socket handle for query result");
continue;
@@ -1222,16 +1245,11 @@ impl ClientState {
}
}
fn handle_udp_dns_query(
&mut self,
upstream: SocketAddr,
packet: IpPacket,
now: Instant,
) -> ControlFlow<(), IpPacket> {
fn handle_udp_dns_query(&mut self, upstream: SocketAddr, packet: IpPacket, now: Instant) {
let Some(datagram) = packet.as_udp() else {
tracing::debug!(?packet, "Not a UDP packet");
return ControlFlow::Break(());
return;
};
if datagram.destination_port() != DNS_PORT {
@@ -1239,81 +1257,28 @@ impl ClientState {
?packet,
"UDP DNS queries are only supported on port {DNS_PORT}"
);
return ControlFlow::Break(());
return;
}
let message = match dns_types::Query::parse(datagram.payload()) {
Ok(message) => message,
Err(e) => {
tracing::warn!(?packet, "Failed to parse DNS query: {e:#}");
return ControlFlow::Break(());
return;
}
};
let destination = SocketAddr::new(packet.destination(), datagram.destination_port());
let source = SocketAddr::new(packet.source(), datagram.source_port());
let local = SocketAddr::new(packet.destination(), datagram.destination_port());
let remote = SocketAddr::new(packet.source(), datagram.source_port());
if let Some(response) = self.dns_cache.try_answer(&message, now) {
unwrap_or_debug!(
self.try_queue_udp_dns_response(destination, source, response),
if let Some(response) =
self.handle_dns_query(message, local, remote, upstream, dns::Transport::Udp, now)
{
unwrap_or_warn!(
self.try_queue_udp_dns_response(local, remote, response),
"Failed to queue UDP DNS response: {}"
);
return ControlFlow::Break(());
}
match self.stub_resolver.handle(&message) {
dns::ResolveStrategy::LocalResponse(response) => {
self.dns_resource_nat.recreate(message.domain());
self.update_dns_resource_nat(now, iter::empty());
self.dns_cache.insert(message.domain(), &response, now);
unwrap_or_debug!(
self.try_queue_udp_dns_response(destination, source, response),
"Failed to queue UDP DNS response: {}"
);
}
dns::ResolveStrategy::RecurseLocal => {
if self.should_forward_dns_query_to_gateway(upstream.ip()) {
let packet = self
.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet);
return ControlFlow::Continue(packet);
}
let query_id = message.id();
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host");
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_udp(
destination,
source,
upstream,
message,
));
}
dns::ResolveStrategy::RecurseSite(resource) => {
let Some(gateway) =
peer_by_resource_mut(&self.resources_gateways, &mut self.peers, resource)
else {
self.on_not_connected_resource(
resource,
ConnectionTrigger::UdpDnsQueryForSite(packet),
now,
);
return ControlFlow::Break(());
};
let upstream = gateway.tun_dns_server_endpoint(packet.destination());
let packet =
self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet);
return ControlFlow::Continue(packet);
}
}
ControlFlow::Break(())
};
}
fn handle_llmnr_dns_query(&mut self, packet: IpPacket, now: Instant) {
@@ -1373,97 +1338,47 @@ impl ClientState {
}
}
fn mangle_udp_dns_query_to_new_upstream_through_tunnel(
fn handle_dns_query(
&mut self,
message: dns_types::Query,
local: SocketAddr,
remote: SocketAddr,
upstream: SocketAddr,
transport: dns::Transport,
now: Instant,
mut packet: IpPacket,
) -> IpPacket {
let dst_ip = packet.destination();
let datagram = packet
.as_udp()
.expect("to be a valid UDP packet at this point");
) -> Option<dns_types::Response> {
let query_id = message.id();
let dst_port = datagram.destination_port();
let query_id = dns_types::Query::parse(datagram.payload())
.expect("to be a valid DNS query at this point")
.id();
let connlib_dns_server = SocketAddr::new(dst_ip, dst_port);
self.udp_dns_sockets_by_upstream_and_query_id.insert(
(upstream, query_id),
connlib_dns_server,
now,
IDS_EXPIRE,
);
if let Err(e) = packet.set_dst(upstream.ip()) {
tracing::warn!("Failed to set destination IP for UDP DNS query: {e:#}");
}
// TODO: Remove this once we disallow non-standard DNS ports: https://github.com/firezone/firezone/issues/8330
packet
.as_udp_mut()
.expect("to be a valid UDP packet at this point")
.set_destination_port(upstream.port());
packet.update_checksum();
tracing::trace!(%upstream, %connlib_dns_server, %query_id, "Forwarding UDP DNS query via tunnel");
packet
}
fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query, now: Instant) {
let query_id = query.message.id();
let Some(server) = self
.dns_config
.mapping()
.upstream_by_sentinel(query.local.ip())
else {
// This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request.
return;
};
if let Some(response) = self.dns_cache.try_answer(&query.message, now) {
unwrap_or_debug!(
self.tcp_dns_server
.send_message(query.local, query.remote, response),
"Failed to send TCP DNS response: {}"
);
return;
if let Some(response) = self.dns_cache.try_answer(&message, now) {
return Some(response);
}
match self.stub_resolver.handle(&query.message) {
match self.stub_resolver.handle(&message) {
dns::ResolveStrategy::LocalResponse(response) => {
self.dns_resource_nat.recreate(query.message.domain());
self.dns_resource_nat.recreate(message.domain());
self.update_dns_resource_nat(now, iter::empty());
self.dns_cache
.insert(query.message.domain(), &response, now);
self.dns_cache.insert(message.domain(), &response, now);
unwrap_or_debug!(
self.tcp_dns_server
.send_message(query.local, query.remote, response),
"Failed to send TCP DNS response: {}"
);
return Some(response);
}
dns::ResolveStrategy::RecurseLocal => {
if self.should_forward_dns_query_to_gateway(server.ip()) {
self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query);
if self.should_forward_dns_query_to_gateway(upstream.ip()) {
self.forward_dns_query_to_new_upstream_via_tunnel(
local, remote, upstream, message, transport, now,
);
return;
return None;
}
tracing::trace!(%server, %query_id, "Forwarding TCP DNS query");
tracing::trace!(%upstream, %query_id, "Forwarding {transport} DNS query");
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_tcp(
query.local,
query.remote,
server,
query.message,
));
self.buffered_dns_queries.push_back(dns::RecursiveQuery {
server: upstream,
local,
remote,
message,
transport,
});
}
dns::ResolveStrategy::RecurseSite(resource) => {
let Some(gateway) =
@@ -1471,54 +1386,62 @@ impl ClientState {
else {
self.on_not_connected_resource(
resource,
ConnectionTrigger::TcpDnsQueryForSite(query),
ConnectionTrigger::DnsQueryForSite(DnsQueryForSite {
local,
remote,
transport,
message,
}),
now,
);
return;
return None;
};
let server = gateway.tun_dns_server_endpoint(query.local.ip());
let server = gateway.tun_dns_server_endpoint(local.ip());
self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query);
self.forward_dns_query_to_new_upstream_via_tunnel(
local, remote, server, message, transport, now,
);
}
};
None
}
fn forward_tcp_dns_query_to_new_upstream_via_tunnel(
fn forward_dns_query_to_new_upstream_via_tunnel(
&mut self,
local: SocketAddr,
remote: SocketAddr,
server: SocketAddr,
query: dns_over_tcp::Query,
query: dns_types::Query,
transport: dns::Transport,
now: Instant,
) {
let query_id = query.message.id();
let query_id = query.id();
match self
.tcp_dns_client
.send_query(server, query.message.clone())
{
let result = match transport {
dns::Transport::Udp => self.udp_dns_client.send_query(server, query, now),
dns::Transport::Tcp => self.tcp_dns_client.send_query(server, query),
};
match result {
Ok(()) => {}
Err(e) => {
tracing::warn!(
"Failed to send recursive TCP DNS query to upstream resolver: {e:#}"
);
unwrap_or_debug!(
self.tcp_dns_server.send_message(
query.local,
query.remote,
dns_types::Response::servfail(&query.message)
),
"Failed to send TCP DNS response: {}"
"Failed to send recursive {transport} DNS query to upstream resolver: {e:#}"
);
return;
}
};
tracing::trace!(%server, %local, %query_id, "Forwarding {transport} DNS query via tunnel");
let existing = self
.tcp_dns_streams_by_upstream_and_query_id
.insert((server, query_id), (query.local, query.remote));
.dns_streams_by_upstream_and_query_id
.insert((transport, server, query_id), (local, remote));
if let Some((existing_local, existing_remote)) = existing
&& (existing_local != query.local || existing_remote != query.remote)
&& (existing_local != local || existing_remote != remote)
{
debug_assert!(false, "Query IDs should be unique");
}
@@ -1982,73 +1905,30 @@ fn is_definitely_not_a_resource(ip: IpAddr) -> bool {
false
}
fn maybe_mangle_dns_response_from_upstream_dns_server(
mut packet: IpPacket,
udp_dns_sockets_by_upstream_and_query_id: &mut ExpiringMap<(SocketAddr, u16), SocketAddr>,
dns_cache: &mut DnsCache,
now: Instant,
) -> IpPacket {
let src_ip = packet.source();
let Some(udp) = packet.as_udp() else {
return packet;
};
let src_port = udp.source_port();
let src_socket = SocketAddr::new(src_ip, src_port);
let Ok(message) = dns_types::Response::parse(udp.payload()) else {
return packet;
};
let Some(expiring_map::Entry {
value: original_dst,
..
}) = udp_dns_sockets_by_upstream_and_query_id.remove(&(src_socket, message.id()))
else {
return packet;
};
dns_cache.insert(message.domain(), &message, now);
tracing::trace!(server = %src_ip, query_id = %message.id(), domain = %message.domain(), "Received UDP DNS response via tunnel");
if let Err(e) = packet.set_src(original_dst.ip()) {
tracing::warn!("Failed to set source IP for UDP DNS query: {e:#}");
}
packet
.as_udp_mut()
.expect("we parsed it as a UDP packet earlier")
.set_source_port(original_dst.port());
packet.update_checksum();
packet
}
/// What triggered us to establish a connection to a Gateway.
enum ConnectionTrigger {
/// A packet received on the TUN device with a destination IP that maps to one of our resources.
PacketForResource(IpPacket),
/// A UDP DNS query that needs to be resolved within a particular site that we aren't connected to yet.
///
/// This packet isn't mangled yet to point to the Gateway's TUN device IP because at the time of buffering, that IP is unknown.
UdpDnsQueryForSite(IpPacket),
/// A TCP DNS query that needs to be resolved within a particular site that we aren't connected to yet.
TcpDnsQueryForSite(dns_over_tcp::Query),
/// A DNS query that needs to be resolved within a particular site that we aren't connected to yet.
DnsQueryForSite(DnsQueryForSite),
/// We have received an ICMP error that is marked as "access prohibited".
///
/// Most likely, the Gateway is filtering these packets because the Client doesn't have access (anymore).
IcmpDestinationUnreachableProhibited,
}
struct DnsQueryForSite {
local: SocketAddr,
remote: SocketAddr,
transport: dns::Transport,
message: dns_types::Query,
}
impl ConnectionTrigger {
fn name(&self) -> &'static str {
match self {
ConnectionTrigger::PacketForResource(_) => "packet-for-resource",
ConnectionTrigger::UdpDnsQueryForSite(_) => "udp-dns-query-for-site",
ConnectionTrigger::TcpDnsQueryForSite(_) => "tcp-dns-query-for-site",
ConnectionTrigger::DnsQueryForSite(_) => "dns-query-for-site",
ConnectionTrigger::IcmpDestinationUnreachableProhibited => {
"icmp-destination-unreachable-prohibited"
}

View File

@@ -92,41 +92,11 @@ pub(crate) struct RecursiveResponse {
pub transport: Transport,
}
impl RecursiveQuery {
pub(crate) fn via_udp(
local: SocketAddr,
remote: SocketAddr,
server: SocketAddr,
message: dns_types::Query,
) -> Self {
Self {
server,
local,
remote,
message,
transport: Transport::Udp,
}
}
pub(crate) fn via_tcp(
local: SocketAddr,
remote: SocketAddr,
server: SocketAddr,
message: dns_types::Query,
) -> Self {
Self {
server,
local,
remote,
message,
transport: Transport::Tcp,
}
}
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::Display)]
pub(crate) enum Transport {
#[display("UDP")]
Udp,
#[display("TCP")]
Tcp,
}

View File

@@ -53,6 +53,7 @@ where
self.inner.get(key)
}
#[cfg(test)]
pub fn remove(&mut self, key: &K) -> Option<Entry<V>> {
self.expiration.retain(|_, keys| {
keys.retain(|k| k != key);

View File

@@ -456,17 +456,18 @@ impl GatewayTunnel {
for query in udp_dns_queries {
if let Some(nameserver) = self.io.fastest_nameserver() {
self.io.send_dns_query(dns::RecursiveQuery::via_udp(
query.local,
query.from,
SocketAddr::new(nameserver, dns::DNS_PORT),
query.message,
));
self.io.send_dns_query(dns::RecursiveQuery {
server: SocketAddr::new(nameserver, dns::DNS_PORT),
local: query.local,
remote: query.remote,
message: query.message,
transport: dns::Transport::Udp,
});
} else {
tracing::warn!(query = ?query.message, "No nameserver available to handle UDP DNS query");
if let Err(e) = self.io.send_udp_dns_response(
query.from,
query.remote,
query.local,
dns_types::Response::servfail(&query.message),
) {
@@ -479,12 +480,13 @@ impl GatewayTunnel {
for query in tcp_dns_queries {
if let Some(nameserver) = self.io.fastest_nameserver() {
self.io.send_dns_query(dns::RecursiveQuery::via_tcp(
query.local,
query.remote,
SocketAddr::new(nameserver, dns::DNS_PORT),
query.message,
));
self.io.send_dns_query(dns::RecursiveQuery {
server: SocketAddr::new(nameserver, dns::DNS_PORT),
local: query.local,
remote: query.remote,
message: query.message,
transport: dns::Transport::Tcp,
});
} else {
tracing::warn!(query = ?query.message, "No nameserver available to handle TCP DNS query");