feat(connlib): introduce l3-udp-dns-client (#10764)

With #8263, we will stop receiving UDP and TCP DNS queries on the tunnel
but use regular sockets instead. This means that for UDP DNS queries
that need to be sent _through_ the tunnel, we actually need to make new
IP packets again. For TCP, we already have a crate that does this for us
because there, we need to manage an entire TCP stack.

For UDP, the story is a bit simpler but there are still a few things
involved. In particular, we need to set a source address for the packets
and we need to sample a new random port for each query.

The crate added in this PR does exactly that. It is not yet used
anywhere but split out into a separate PR to reduce the reviewing burden
of the larger refactor.

Related: #8263
Related: #10758
This commit is contained in:
Thomas Eizinger
2025-11-04 04:04:19 +11:00
committed by GitHub
parent 9e33e514c4
commit 1b7313622a
4 changed files with 280 additions and 0 deletions

11
rust/Cargo.lock generated
View File

@@ -4127,6 +4127,17 @@ dependencies = [
"tracing",
]
[[package]]
name = "l3-udp-dns-client"
version = "0.1.0"
dependencies = [
"anyhow",
"dns-types",
"ip-packet",
"rand 0.8.5",
"tracing",
]
[[package]]
name = "l4-tcp-dns-server"
version = "0.1.0"

View File

@@ -10,6 +10,7 @@ members = [
"connlib/etherparse-ext",
"connlib/ip-packet",
"connlib/l3-tcp",
"connlib/l3-udp-dns-client",
"connlib/l4-tcp-dns-server",
"connlib/l4-udp-dns-server",
"connlib/model",
@@ -102,6 +103,7 @@ jni = "0.21.1"
keyring = "3.6.3"
known-folders = "1.3.1"
l3-tcp = { path = "connlib/l3-tcp" }
l3-udp-dns-client = { path = "connlib/l3-udp-dns-client" }
l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" }
l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" }
libc = "0.2.176"

View File

@@ -0,0 +1,18 @@
[package]
name = "l3-udp-dns-client"
version = "0.1.0"
edition = { workspace = true }
license = { workspace = true }
[lib]
path = "lib.rs"
[dependencies]
anyhow = { workspace = true }
dns-types = { workspace = true }
ip-packet = { workspace = true }
rand = { workspace = true }
tracing = { workspace = true }
[lints]
workspace = true

View File

@@ -0,0 +1,249 @@
use std::{
collections::{HashMap, VecDeque},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::{Duration, Instant},
};
use anyhow::{Context as _, Result, anyhow, bail};
use ip_packet::IpPacket;
use rand::{Rng, SeedableRng, rngs::StdRng};
const TIMEOUT: Duration = Duration::from_secs(5);
/// A sans-io DNS-over-UDP client.
pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
source_ips: Option<(Ipv4Addr, Ipv6Addr)>,
pending_queries_by_local_port: HashMap<u16, PendingQuery>,
scheduled_queries: VecDeque<IpPacket>,
query_results: VecDeque<QueryResult>,
rng: StdRng,
_created_at: Instant,
last_now: Instant,
}
struct PendingQuery {
message: dns_types::Query,
expires_at: Instant,
server: SocketAddr,
}
#[derive(Debug)]
pub struct QueryResult {
pub query: dns_types::Query,
pub server: SocketAddr,
pub result: Result<dns_types::Response>,
}
impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
pub fn new(now: Instant, seed: [u8; 32]) -> Self {
// Sadly, these can't be compile-time assertions :(
assert!(MIN_PORT >= 49152, "Must use ephemeral port range");
assert!(MIN_PORT < MAX_PORT, "Port range must not have length 0");
Self {
source_ips: None,
rng: StdRng::from_seed(seed),
_created_at: now,
last_now: now,
pending_queries_by_local_port: Default::default(),
scheduled_queries: Default::default(),
query_results: Default::default(),
}
}
/// Sets the IPv4 and IPv6 source ips to use for outgoing packets.
pub fn set_source_interface(&mut self, v4: Ipv4Addr, v6: Ipv6Addr) {
self.source_ips = Some((v4, v6));
}
/// Send the given DNS query to the target server.
///
/// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them.
pub fn send_query(
&mut self,
server: SocketAddr,
message: dns_types::Query,
now: Instant,
) -> Result<()> {
let local_port = self.sample_new_unique_port()?;
let (ipv4_source, ipv6_source) = self
.source_ips
.ok_or_else(|| anyhow!("No source interface set"))?;
let local_ip = match server {
SocketAddr::V4(_) => IpAddr::V4(ipv4_source),
SocketAddr::V6(_) => IpAddr::V6(ipv6_source),
};
self.pending_queries_by_local_port.insert(
local_port,
PendingQuery {
message: message.clone(),
expires_at: now + TIMEOUT,
server,
},
);
let payload = message.into_bytes();
let ip_packet =
ip_packet::make::udp_packet(local_ip, server.ip(), local_port, server.port(), payload)
.context("Failed to make IP packet")?;
self.scheduled_queries.push_back(ip_packet);
Ok(())
}
/// Checks whether this client can handle the given packet.
///
/// Only TCP packets originating from one of the connected DNS resolvers are accepted.
pub fn accepts(&self, packet: &IpPacket) -> bool {
let Some(udp) = packet.as_udp() else {
#[cfg(debug_assertions)]
tracing::trace!(?packet, "Not a UDP packet");
return false;
};
let Some((ipv4_source, ipv6_source)) = self.source_ips else {
#[cfg(debug_assertions)]
tracing::trace!("No source interface");
return false;
};
// If the packet doesn't match our source interface, we don't want it.
match packet.destination() {
IpAddr::V4(v4) if v4 != ipv4_source => return false,
IpAddr::V6(v6) if v6 != ipv6_source => return false,
IpAddr::V4(_) | IpAddr::V6(_) => {}
}
self.pending_queries_by_local_port
.contains_key(&udp.destination_port())
}
/// Handle the [`IpPacket`].
///
/// This function only inserts the packet into a buffer.
/// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called.
pub fn handle_inbound(&mut self, packet: IpPacket) {
debug_assert!(self.accepts(&packet));
let Some(udp) = packet.as_udp() else {
return;
};
let result =
dns_types::Response::parse(udp.payload()).context("Failed to parse DNS response");
let source = SocketAddr::new(packet.source(), udp.source_port());
if let Some(PendingQuery {
message, server, ..
}) = self
.pending_queries_by_local_port
.get(&udp.destination_port())
&& let Ok(response) = result.as_ref()
&& (response.id() != message.id() || source != *server)
{
tracing::debug!(%server, %source, query_id = %message.id(), response_id = %response.id(), "Response from server does not match query ID or original destination");
return;
}
let Some(PendingQuery {
message, server, ..
}) = self
.pending_queries_by_local_port
.remove(&udp.destination_port())
else {
return;
};
self.query_results.push_back(QueryResult {
query: message,
server,
result,
});
}
/// Returns [`IpPacket`]s that should be sent.
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
self.scheduled_queries.pop_front()
}
/// Returns the next [`QueryResult`].
pub fn poll_query_result(&mut self) -> Option<QueryResult> {
self.query_results.pop_front()
}
/// Inform the client that time advanced.
///
/// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible.
pub fn handle_timeout(&mut self, now: Instant) {
self.last_now = now;
for (
_,
PendingQuery {
message, server, ..
},
) in self
.pending_queries_by_local_port
.extract_if(|_, pending_query| now >= pending_query.expires_at)
{
self.query_results.push_back(QueryResult {
query: message,
server,
result: Err(anyhow!("Timeout")),
});
}
}
#[expect(
clippy::disallowed_methods,
reason = "We don't care about the ordering of the Iterator here."
)]
pub fn poll_timeout(&mut self) -> Option<Instant> {
self.pending_queries_by_local_port
.values()
.map(|p| p.expires_at)
.min()
}
pub fn reset(&mut self) {
tracing::debug!("Resetting state");
let aborted_pending_queries =
self.pending_queries_by_local_port
.drain()
.map(|(_, pending_query)| QueryResult {
query: pending_query.message,
server: pending_query.server,
result: Err(anyhow!("Timeout")),
});
self.query_results.extend(aborted_pending_queries);
}
fn sample_new_unique_port(&mut self) -> Result<u16> {
let range = MIN_PORT..=MAX_PORT;
if self.pending_queries_by_local_port.len() == range.len() {
bail!("All ports exhausted")
}
loop {
let port = self.rng.gen_range(range.clone());
if !self.pending_queries_by_local_port.contains_key(&port) {
return Ok(port);
}
}
}
}