mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
feat(connlib): introduce l3-udp-dns-client (#10764)
With #8263, we will stop receiving UDP and TCP DNS queries on the tunnel but use regular sockets instead. This means that for UDP DNS queries that need to be sent _through_ the tunnel, we actually need to make new IP packets again. For TCP, we already have a crate that does this for us because there, we need to manage an entire TCP stack. For UDP, the story is a bit simpler but there are still a few things involved. In particular, we need to set a source address for the packets and we need to sample a new random port for each query. The crate added in this PR does exactly that. It is not yet used anywhere but split out into a separate PR to reduce the reviewing burden of the larger refactor. Related: #8263 Related: #10758
This commit is contained in:
11
rust/Cargo.lock
generated
11
rust/Cargo.lock
generated
@@ -4127,6 +4127,17 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "l3-udp-dns-client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"dns-types",
|
||||
"ip-packet",
|
||||
"rand 0.8.5",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "l4-tcp-dns-server"
|
||||
version = "0.1.0"
|
||||
|
||||
@@ -10,6 +10,7 @@ members = [
|
||||
"connlib/etherparse-ext",
|
||||
"connlib/ip-packet",
|
||||
"connlib/l3-tcp",
|
||||
"connlib/l3-udp-dns-client",
|
||||
"connlib/l4-tcp-dns-server",
|
||||
"connlib/l4-udp-dns-server",
|
||||
"connlib/model",
|
||||
@@ -102,6 +103,7 @@ jni = "0.21.1"
|
||||
keyring = "3.6.3"
|
||||
known-folders = "1.3.1"
|
||||
l3-tcp = { path = "connlib/l3-tcp" }
|
||||
l3-udp-dns-client = { path = "connlib/l3-udp-dns-client" }
|
||||
l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" }
|
||||
l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" }
|
||||
libc = "0.2.176"
|
||||
|
||||
18
rust/connlib/l3-udp-dns-client/Cargo.toml
Normal file
18
rust/connlib/l3-udp-dns-client/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "l3-udp-dns-client"
|
||||
version = "0.1.0"
|
||||
edition = { workspace = true }
|
||||
license = { workspace = true }
|
||||
|
||||
[lib]
|
||||
path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
ip-packet = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
249
rust/connlib/l3-udp-dns-client/lib.rs
Normal file
249
rust/connlib/l3-udp-dns-client/lib.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use ip_packet::IpPacket;
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
|
||||
const TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// A sans-io DNS-over-UDP client.
|
||||
pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
|
||||
source_ips: Option<(Ipv4Addr, Ipv6Addr)>,
|
||||
|
||||
pending_queries_by_local_port: HashMap<u16, PendingQuery>,
|
||||
|
||||
scheduled_queries: VecDeque<IpPacket>,
|
||||
query_results: VecDeque<QueryResult>,
|
||||
|
||||
rng: StdRng,
|
||||
|
||||
_created_at: Instant,
|
||||
last_now: Instant,
|
||||
}
|
||||
|
||||
struct PendingQuery {
|
||||
message: dns_types::Query,
|
||||
expires_at: Instant,
|
||||
server: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct QueryResult {
|
||||
pub query: dns_types::Query,
|
||||
pub server: SocketAddr,
|
||||
pub result: Result<dns_types::Response>,
|
||||
}
|
||||
|
||||
impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
pub fn new(now: Instant, seed: [u8; 32]) -> Self {
|
||||
// Sadly, these can't be compile-time assertions :(
|
||||
assert!(MIN_PORT >= 49152, "Must use ephemeral port range");
|
||||
assert!(MIN_PORT < MAX_PORT, "Port range must not have length 0");
|
||||
|
||||
Self {
|
||||
source_ips: None,
|
||||
rng: StdRng::from_seed(seed),
|
||||
_created_at: now,
|
||||
last_now: now,
|
||||
pending_queries_by_local_port: Default::default(),
|
||||
scheduled_queries: Default::default(),
|
||||
query_results: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the IPv4 and IPv6 source ips to use for outgoing packets.
|
||||
pub fn set_source_interface(&mut self, v4: Ipv4Addr, v6: Ipv6Addr) {
|
||||
self.source_ips = Some((v4, v6));
|
||||
}
|
||||
|
||||
/// Send the given DNS query to the target server.
|
||||
///
|
||||
/// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them.
|
||||
pub fn send_query(
|
||||
&mut self,
|
||||
server: SocketAddr,
|
||||
message: dns_types::Query,
|
||||
now: Instant,
|
||||
) -> Result<()> {
|
||||
let local_port = self.sample_new_unique_port()?;
|
||||
|
||||
let (ipv4_source, ipv6_source) = self
|
||||
.source_ips
|
||||
.ok_or_else(|| anyhow!("No source interface set"))?;
|
||||
|
||||
let local_ip = match server {
|
||||
SocketAddr::V4(_) => IpAddr::V4(ipv4_source),
|
||||
SocketAddr::V6(_) => IpAddr::V6(ipv6_source),
|
||||
};
|
||||
|
||||
self.pending_queries_by_local_port.insert(
|
||||
local_port,
|
||||
PendingQuery {
|
||||
message: message.clone(),
|
||||
expires_at: now + TIMEOUT,
|
||||
server,
|
||||
},
|
||||
);
|
||||
|
||||
let payload = message.into_bytes();
|
||||
|
||||
let ip_packet =
|
||||
ip_packet::make::udp_packet(local_ip, server.ip(), local_port, server.port(), payload)
|
||||
.context("Failed to make IP packet")?;
|
||||
|
||||
self.scheduled_queries.push_back(ip_packet);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks whether this client can handle the given packet.
|
||||
///
|
||||
/// Only TCP packets originating from one of the connected DNS resolvers are accepted.
|
||||
pub fn accepts(&self, packet: &IpPacket) -> bool {
|
||||
let Some(udp) = packet.as_udp() else {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::trace!(?packet, "Not a UDP packet");
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
let Some((ipv4_source, ipv6_source)) = self.source_ips else {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::trace!("No source interface");
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// If the packet doesn't match our source interface, we don't want it.
|
||||
match packet.destination() {
|
||||
IpAddr::V4(v4) if v4 != ipv4_source => return false,
|
||||
IpAddr::V6(v6) if v6 != ipv6_source => return false,
|
||||
IpAddr::V4(_) | IpAddr::V6(_) => {}
|
||||
}
|
||||
|
||||
self.pending_queries_by_local_port
|
||||
.contains_key(&udp.destination_port())
|
||||
}
|
||||
|
||||
/// Handle the [`IpPacket`].
|
||||
///
|
||||
/// This function only inserts the packet into a buffer.
|
||||
/// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called.
|
||||
pub fn handle_inbound(&mut self, packet: IpPacket) {
|
||||
debug_assert!(self.accepts(&packet));
|
||||
|
||||
let Some(udp) = packet.as_udp() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let result =
|
||||
dns_types::Response::parse(udp.payload()).context("Failed to parse DNS response");
|
||||
let source = SocketAddr::new(packet.source(), udp.source_port());
|
||||
|
||||
if let Some(PendingQuery {
|
||||
message, server, ..
|
||||
}) = self
|
||||
.pending_queries_by_local_port
|
||||
.get(&udp.destination_port())
|
||||
&& let Ok(response) = result.as_ref()
|
||||
&& (response.id() != message.id() || source != *server)
|
||||
{
|
||||
tracing::debug!(%server, %source, query_id = %message.id(), response_id = %response.id(), "Response from server does not match query ID or original destination");
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(PendingQuery {
|
||||
message, server, ..
|
||||
}) = self
|
||||
.pending_queries_by_local_port
|
||||
.remove(&udp.destination_port())
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.query_results.push_back(QueryResult {
|
||||
query: message,
|
||||
server,
|
||||
result,
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns [`IpPacket`]s that should be sent.
|
||||
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
|
||||
self.scheduled_queries.pop_front()
|
||||
}
|
||||
|
||||
/// Returns the next [`QueryResult`].
|
||||
pub fn poll_query_result(&mut self) -> Option<QueryResult> {
|
||||
self.query_results.pop_front()
|
||||
}
|
||||
|
||||
/// Inform the client that time advanced.
|
||||
///
|
||||
/// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible.
|
||||
pub fn handle_timeout(&mut self, now: Instant) {
|
||||
self.last_now = now;
|
||||
|
||||
for (
|
||||
_,
|
||||
PendingQuery {
|
||||
message, server, ..
|
||||
},
|
||||
) in self
|
||||
.pending_queries_by_local_port
|
||||
.extract_if(|_, pending_query| now >= pending_query.expires_at)
|
||||
{
|
||||
self.query_results.push_back(QueryResult {
|
||||
query: message,
|
||||
server,
|
||||
result: Err(anyhow!("Timeout")),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::disallowed_methods,
|
||||
reason = "We don't care about the ordering of the Iterator here."
|
||||
)]
|
||||
pub fn poll_timeout(&mut self) -> Option<Instant> {
|
||||
self.pending_queries_by_local_port
|
||||
.values()
|
||||
.map(|p| p.expires_at)
|
||||
.min()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
tracing::debug!("Resetting state");
|
||||
|
||||
let aborted_pending_queries =
|
||||
self.pending_queries_by_local_port
|
||||
.drain()
|
||||
.map(|(_, pending_query)| QueryResult {
|
||||
query: pending_query.message,
|
||||
server: pending_query.server,
|
||||
result: Err(anyhow!("Timeout")),
|
||||
});
|
||||
|
||||
self.query_results.extend(aborted_pending_queries);
|
||||
}
|
||||
|
||||
fn sample_new_unique_port(&mut self) -> Result<u16> {
|
||||
let range = MIN_PORT..=MAX_PORT;
|
||||
|
||||
if self.pending_queries_by_local_port.len() == range.len() {
|
||||
bail!("All ports exhausted")
|
||||
}
|
||||
|
||||
loop {
|
||||
let port = self.rng.gen_range(range.clone());
|
||||
|
||||
if !self.pending_queries_by_local_port.contains_key(&port) {
|
||||
return Ok(port);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user