chore(connlib): add sans-IO DNS-over-TCP client (#7007)

This brings us one step closer to completing #6140. In Firezone, users
can define custom upstream DNS servers that take priority over
system-defined DNS servers. The IPs of these servers could also be
resources, meaning the DNS queries must be sent through the WireGuard
tunnel to the gateway.

For UDP DNS queries, that is easy because each query is only a single
packet. For TCP DNS queries, we need to have a dedicated TCP-capable DNS
server that parses all incoming queries. If they are required to be
forwarded to the gateway, we then need a TCP-capable DNS client that can
send them to the actual upstream DNS server.

This PR implements such a DNS client. The design is tailored for what we
need in `connlib`: We maintain a permanent TCP connection to each
upstream DNS server and send queries to them. Most likely, users will
only have a handful of DNS servers defined. TCP requires a three-way
handshake before any application data can be sent, maintaining a
connection should therefore greatly improve DNS resolution latency.

DNS resolvers are encouraged to keep TCP connections open but may close
them if they run out of resources. We only re-connect once we have more
queries to send in order to not spam the resolver with connections.

Resolves: #7000.

---------

Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Thomas Eizinger
2024-10-12 09:04:45 +11:00
committed by GitHub
parent 7838da9739
commit 274cc86557
9 changed files with 747 additions and 151 deletions

2
rust/Cargo.lock generated
View File

@@ -1912,9 +1912,11 @@ dependencies = [
"domain",
"firezone-bin-shared",
"firezone-logging",
"futures",
"ip-packet",
"ip_network",
"itertools 0.13.0",
"rand 0.8.5",
"smoltcp",
"tokio",
"tracing",

View File

@@ -9,12 +9,14 @@ anyhow = "1.0"
domain = { workspace = true }
ip-packet = { workspace = true }
itertools = "0.13"
rand = "0.8"
smoltcp = { version = "0.11", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
tracing = { workspace = true }
[dev-dependencies]
firezone-bin-shared = { workspace = true }
firezone-logging = { workspace = true }
futures = "0.3"
ip_network = { version = "0.4", default-features = false }
tokio = { workspace = true, features = ["process", "rt", "macros"] }
tun = { workspace = true }

View File

@@ -0,0 +1,417 @@
use std::{
collections::{BTreeSet, HashMap, HashSet, VecDeque},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
ops::RangeInclusive,
time::Instant,
};
use crate::{codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice};
use anyhow::{anyhow, bail, Context as _, Result};
use domain::{base::Message, dep::octseq::OctetsInto};
use ip_packet::IpPacket;
use rand::{rngs::StdRng, Rng, SeedableRng};
use smoltcp::{
iface::{Interface, SocketSet},
socket::tcp::{self, Socket},
};
/// A sans-io DNS-over-TCP client.
///
/// The client maintains a single TCP connection for each configured resolver.
/// If the TCP connection fails for some reason, we try to establish a new one.
///
/// One of the design goals of this client is to always provide a result for a query.
/// If the TCP connection fails, we report all currently pending queries to that resolver as failed.
///
/// There are however currently no timeouts.
/// If the upstream resolver refuses to answer, we don't fail the query.
pub struct Client {
device: InMemoryDevice,
interface: Interface,
source_ips: Option<(Ipv4Addr, Ipv6Addr)>,
/// The port range we are allowed to use on our local interface for outgoing connections.
port_range: RangeInclusive<u16>,
sockets: SocketSet<'static>,
sockets_by_remote: HashMap<SocketAddr, smoltcp::iface::SocketHandle>,
local_ports_by_socket: HashMap<smoltcp::iface::SocketHandle, u16>,
/// Queries we should send to a DNS resolver.
pending_queries_by_remote: HashMap<SocketAddr, VecDeque<Message<Vec<u8>>>>,
/// Queries we have sent to a DNS resolver and are waiting for a reply.
sent_queries_by_remote: HashMap<SocketAddr, HashMap<u16, Message<Vec<u8>>>>,
query_results: VecDeque<QueryResult>,
rng: StdRng,
}
#[derive(Debug)]
pub struct QueryResult {
pub query: Message<Vec<u8>>,
pub server: SocketAddr,
pub result: Result<Message<Vec<u8>>>,
}
impl Client {
pub fn new(now: Instant, port_range: RangeInclusive<u16>, seed: [u8; 32]) -> Self {
assert!(!port_range.contains(&0), "0 port must not be possible");
let mut device = InMemoryDevice::default();
let interface = create_interface(&mut device, now);
Self {
device,
interface,
sockets: SocketSet::new(Vec::default()),
source_ips: None,
port_range,
sent_queries_by_remote: Default::default(),
query_results: Default::default(),
rng: StdRng::from_seed(seed),
sockets_by_remote: Default::default(),
local_ports_by_socket: Default::default(),
pending_queries_by_remote: Default::default(),
}
}
/// Sets the IPv4 and IPv6 source ips to use for outgoing packets.
pub fn set_source_interface(&mut self, v4: Ipv4Addr, v6: Ipv6Addr) {
self.source_ips = Some((v4, v6));
}
/// Connect to the specified DNS resolvers.
///
/// All currently pending queries will be reported as failed.
pub fn connect_to_resolvers(&mut self, resolvers: BTreeSet<SocketAddr>) -> Result<()> {
let (ipv4_source, ipv6_source) = self.source_ips.context("Missing source IPs")?;
// First, clear all local state.
self.sockets = SocketSet::new(vec![]);
self.sockets_by_remote.clear();
self.local_ports_by_socket.clear();
self.query_results
.extend(
self.pending_queries_by_remote
.drain()
.flat_map(|(server, queries)| {
into_failed_results(server, queries, || anyhow!("Aborted"))
}),
);
self.query_results
.extend(
self.sent_queries_by_remote
.drain()
.flat_map(|(server, queries)| {
into_failed_results(server, queries.into_values(), || anyhow!("Aborted"))
}),
);
// Second, try to create all new sockets.
let new_sockets = std::iter::zip(self.sample_unique_ports(resolvers.len())?, resolvers).map(|(port, server)| {
let local_endpoint = match server {
SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port),
SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port),
};
let mut socket = create_tcp_socket();
socket
.connect(self.interface.context(), server, local_endpoint)
.context("Failed to connect socket")?;
tracing::info!(local = %local_endpoint, remote = %server, "Connecting to DNS resolver");
Ok((server, local_endpoint, socket))
})
.collect::<Result<Vec<_>>>()?;
// Third, if everything was successful, change the local state.
for (server, local_endpoint, socket) in new_sockets {
let handle = self.sockets.add(socket);
self.sockets_by_remote.insert(server, handle);
self.local_ports_by_socket
.insert(handle, local_endpoint.port());
}
Ok(())
}
/// Send the given DNS query to the target server.
///
/// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them.
pub fn send_query(&mut self, server: SocketAddr, message: Message<Vec<u8>>) -> Result<()> {
anyhow::ensure!(!message.header().qr(), "Message is a DNS response!");
anyhow::ensure!(
self.sockets_by_remote.contains_key(&server),
"Unknown DNS resolver"
);
self.pending_queries_by_remote
.entry(server)
.or_default()
.push_back(message);
Ok(())
}
/// Checks whether this client can handle the given packet.
///
/// Only TCP packets originating from one of the connected DNS resolvers are accepted.
pub fn accepts(&self, packet: &IpPacket) -> bool {
let Some(tcp) = packet.as_tcp() else {
tracing::trace!(?packet, "Not a TCP packet");
return false;
};
let Some((ipv4_source, ipv6_source)) = self.source_ips else {
tracing::trace!("No source interface");
return false;
};
// If the packet doesn't match our source interface, we don't want it.
match packet.destination() {
IpAddr::V4(v4) if v4 != ipv4_source => return false,
IpAddr::V6(v6) if v6 != ipv6_source => return false,
_ => {}
}
let remote = SocketAddr::new(packet.source(), tcp.source_port());
let has_socket = self.sockets_by_remote.contains_key(&remote);
if !has_socket && tracing::enabled!(tracing::Level::TRACE) {
let open_sockets = BTreeSet::from_iter(self.sockets_by_remote.keys().copied());
tracing::trace!(%remote, ?open_sockets, "No open socket for remote");
}
has_socket
}
/// Handle the [`IpPacket`].
///
/// This function only inserts the packet into a buffer.
/// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called.
pub fn handle_inbound(&mut self, packet: IpPacket) {
debug_assert!(self.accepts(&packet));
self.device.receive(packet);
}
/// Returns [`IpPacket`]s that should be sent.
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
self.device.next_send()
}
/// Returns the next [`QueryResult`].
pub fn poll_query_result(&mut self) -> Option<QueryResult> {
self.query_results.pop_front()
}
/// Inform the client that time advanced.
///
/// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible.
pub fn handle_timeout(&mut self, now: Instant) {
let Some((ipv4_source, ipv6_source)) = self.source_ips else {
return;
};
let changed = self.interface.poll(
smoltcp::time::Instant::from(now),
&mut self.device,
&mut self.sockets,
);
if !changed && self.pending_queries_by_remote.is_empty() {
return;
}
for (remote, handle) in self.sockets_by_remote.iter_mut() {
let socket = self.sockets.get_mut::<Socket>(*handle);
let server = *remote;
// First, attempt to send all pending queries on this socket.
send_pending_queries(
socket,
server,
&mut self.pending_queries_by_remote,
&mut self.sent_queries_by_remote,
&mut self.query_results,
);
// Second, attempt to receive responses.
recv_responses(
socket,
server,
&mut self.pending_queries_by_remote,
&mut self.sent_queries_by_remote,
&mut self.query_results,
);
let has_pending_dns_queries = !self
.pending_queries_by_remote
.entry(server)
.or_default()
.is_empty();
// Third, if the socket got closed, reconnect it.
if matches!(socket.state(), tcp::State::Closed) && has_pending_dns_queries {
let local_port = self
.local_ports_by_socket
.get(handle)
.expect("must always have a port for each socket");
let local_endpoint = match server {
SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), *local_port),
SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), *local_port),
};
tracing::info!(local = %local_endpoint, remote = %server, "Re-connecting to DNS resolver");
socket
.connect(self.interface.context(), server, local_endpoint)
.expect(
"re-connecting a closed socket with the same parameters should always work",
);
}
}
}
fn sample_unique_ports(&mut self, num_ports: usize) -> Result<impl Iterator<Item = u16>> {
let mut ports = HashSet::with_capacity(num_ports);
if num_ports > self.port_range.len() {
bail!(
"Port range only provides {} values but we need {num_ports}",
self.port_range.len()
)
}
while ports.len() < num_ports {
ports.insert(self.rng.gen_range(self.port_range.clone()));
}
Ok(ports.into_iter())
}
}
fn send_pending_queries(
socket: &mut Socket,
server: SocketAddr,
pending_queries_by_remote: &mut HashMap<SocketAddr, VecDeque<Message<Vec<u8>>>>,
sent_queries_by_remote: &mut HashMap<SocketAddr, HashMap<u16, Message<Vec<u8>>>>,
query_results: &mut VecDeque<QueryResult>,
) {
let pending_queries = pending_queries_by_remote.entry(server).or_default();
let sent_queries = sent_queries_by_remote.entry(server).or_default();
loop {
if !socket.can_send() {
break;
}
let Some(query) = pending_queries.pop_front() else {
break;
};
match codec::try_send(socket, query.for_slice_ref()).context("Failed to send DNS query") {
Ok(()) => {
let replaced = sent_queries.insert(query.header().id(), query).is_some();
debug_assert!(!replaced, "Query ID is not unique");
}
Err(e) => {
// We failed to send the query, declare the socket as failed.
socket.abort();
query_results.extend(into_failed_results(
server,
pending_queries
.drain(..)
.chain(sent_queries.drain().map(|(_, query)| query)),
|| anyhow!("{e:#}"),
));
query_results.push_back(QueryResult {
query,
server,
result: Err(e),
});
}
}
}
}
fn recv_responses(
socket: &mut Socket,
server: SocketAddr,
pending_queries_by_remote: &mut HashMap<SocketAddr, VecDeque<Message<Vec<u8>>>>,
sent_queries_by_remote: &mut HashMap<SocketAddr, HashMap<u16, Message<Vec<u8>>>>,
query_results: &mut VecDeque<QueryResult>,
) {
let Some(result) = try_recv_response(socket)
.context("Failed to receive DNS response")
.transpose()
else {
return; // No messages on this socket, continue.
};
let pending_queries = pending_queries_by_remote.entry(server).or_default();
let sent_queries = sent_queries_by_remote.entry(server).or_default();
let new_results = result
.and_then(|response| {
let query = sent_queries
.remove(&response.header().id())
.context("DNS resolver sent response for unknown query")?;
Ok(vec![QueryResult {
query,
server,
result: Ok(response.octets_into()),
}])
})
.unwrap_or_else(|e| {
socket.abort();
into_failed_results(
server,
pending_queries
.drain(..)
.chain(sent_queries.drain().map(|(_, query)| query)),
|| anyhow!("{e:#}"),
)
.collect()
});
query_results.extend(new_results);
}
fn into_failed_results(
server: SocketAddr,
iter: impl IntoIterator<Item = Message<Vec<u8>>>,
make_error: impl Fn() -> anyhow::Error,
) -> impl Iterator<Item = QueryResult> {
iter.into_iter().map(move |query| QueryResult {
query,
server,
result: Err(make_error()),
})
}
fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result<Option<Message<&'b [u8]>>> {
anyhow::ensure!(socket.is_active(), "Socket is not active");
if !socket.can_recv() {
tracing::trace!("Not yet ready to receive next message");
return Ok(None);
}
let Some(message) = codec::try_recv(socket)? else {
return Ok(None);
};
anyhow::ensure!(message.header().qr(), "DNS message is a query!");
Ok(Some(message))
}

View File

@@ -0,0 +1,141 @@
//! Implements sending and receiving of DNS messages over TCP.
//!
//! TCP's stream-oriented nature requires us to know how long the encoded DNS message is before we can read it.
//! For this purpose, DNS messages over TCP are prefixed using a big-endian encoded u16.
//!
//! Source: <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
use anyhow::{Context as _, Result};
use domain::{
base::{iana::Rcode, Message, ParsedName, Rtype},
rdata::AllRecordData,
};
use itertools::Itertools as _;
use smoltcp::socket::tcp;
pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()> {
let response = message.as_slice();
let dns_message_length = (response.len() as u16).to_be_bytes();
let written = socket
.send_slice(&dns_message_length)
.context("Failed to write TCP DNS length header")?;
anyhow::ensure!(
written == 2,
"Not enough space in write buffer for TCP DNS length header"
);
let written = socket
.send_slice(response)
.context("Failed to write DNS message")?;
anyhow::ensure!(
written == response.len(),
"Not enough space in write buffer for DNS message"
);
if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) {
if let Some(ParsedMessage {
qid,
qname,
qtype,
response,
rcode,
records,
}) = parse(message)
{
if response {
let records = records.into_iter().join(" | ");
tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
} else {
tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string());
}
}
}
Ok(())
}
pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result<Option<Message<&'b [u8]>>> {
let maybe_message = socket
.recv(|r| {
// DNS over TCP has a 2-byte length prefix at the start, see <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
let Some((header, message)) = r.split_first_chunk::<2>() else {
return (0, None);
};
let dns_message_length = u16::from_be_bytes(*header) as usize;
if message.len() < dns_message_length {
return (0, None); // Don't consume any bytes unless we can read the full message at once.
}
(2 + dns_message_length, Some(Message::from_octets(message)))
})
.context("Failed to recv TCP data")?
.transpose()
.context("Failed to parse DNS message")?;
if tracing::event_enabled!(target: "wire::dns::tcp::recv", tracing::Level::TRACE) {
if let Some(ParsedMessage {
qid,
qname,
qtype,
rcode,
response,
records,
}) = maybe_message.and_then(parse)
{
if response {
let records = records.into_iter().join(" | ");
tracing::trace!(target: "wire::dns::tcp::recv", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
} else {
tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string());
}
}
}
Ok(maybe_message)
}
fn parse(message: Message<&[u8]>) -> Option<ParsedMessage<'_>> {
let question = message.sole_question().ok()?;
let answers = message.answer().ok()?;
let qtype = question.qtype();
let qname = question.into_qname();
let qid = message.header().id();
let response = message.header().qr();
let rcode = message.header().rcode();
let records = answers
.into_iter()
.filter_map(|r| {
let data = r
.ok()?
.into_any_record::<AllRecordData<_, _>>()
.ok()?
.data()
.clone();
Some(data)
})
.collect();
Some(ParsedMessage {
qid,
qname,
rcode,
qtype,
response,
records,
})
}
struct ParsedMessage<'a> {
qid: u16,
qname: ParsedName<&'a [u8]>,
qtype: Rtype,
rcode: Rcode,
response: bool,
records: Vec<AllRecordData<&'a [u8], ParsedName<&'a [u8]>>>,
}

View File

@@ -0,0 +1,44 @@
use std::time::Instant;
use smoltcp::{
iface::{Config, Interface, Route},
wire::{HardwareAddress, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr},
};
use crate::stub_device::InMemoryDevice;
const IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1);
const IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1);
/// Creates a smoltcp [`Interface`].
///
/// smoltcp's abstractions allow to directly plug it in a TUN device.
/// As a result, it has all the features you'd expect from a network interface:
/// - Setting IP addresses
/// - Defining routes
///
/// In our implementation, we don't want to use any of that.
/// Our device is entirely backed by in-memory buffers and we and selectively feed IP packets to it.
/// Therefore, we configure it to:
/// - Accept any packet
/// - Define dummy IPs (localhost for IPv4 and IPv6)
/// - Define catch-all routes (0.0.0.0/0) that routes all traffic to the interface
pub fn create_interface(device: &mut InMemoryDevice, now: Instant) -> Interface {
let mut interface = Interface::new(Config::new(HardwareAddress::Ip), device, now.into());
// Accept packets with any destination IP, not just our interface.
interface.set_any_ip(true);
// Set our interface IPs. These are just dummies and don't show up anywhere!
interface.update_ip_addrs(|ips| {
ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into()).unwrap();
ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into()).unwrap();
});
// Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface.
interface.routes_mut().update(|routes| {
routes.push(Route::new_ipv4_gateway(IP4_ADDR)).unwrap();
routes.push(Route::new_ipv6_gateway(IP6_ADDR)).unwrap();
});
interface
}

View File

@@ -1,4 +1,21 @@
mod client;
mod codec;
mod interface;
mod server;
mod stub_device;
pub use client::{Client, QueryResult};
pub use server::{Server, SocketHandle};
fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> {
/// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX.
/// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages.
/// Being able to buffer at least one of them means we can handle the extreme case.
/// In practice, this allows the OS to queue multiple queries even if we can't immediately process them.
const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize;
smoltcp::socket::tcp::Socket::new(
smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]),
smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]),
)
}

View File

@@ -4,16 +4,14 @@ use std::{
time::Instant,
};
use crate::stub_device::InMemoryDevice;
use crate::{codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice};
use anyhow::{Context as _, Result};
use domain::{base::Message, dep::octseq::OctetsInto as _, rdata::AllRecordData};
use domain::{base::Message, dep::octseq::OctetsInto as _};
use ip_packet::IpPacket;
use itertools::Itertools as _;
use smoltcp::{
iface::{Config, Interface, Route, SocketSet},
iface::{Interface, SocketSet},
socket::tcp,
storage::RingBuffer,
wire::{HardwareAddress, IpEndpoint, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr},
wire::IpEndpoint,
};
/// A sans-IO implementation of DNS-over-TCP server.
@@ -43,34 +41,10 @@ pub struct Query {
pub local: SocketAddr,
}
const SERVER_IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1);
const SERVER_IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1);
impl Server {
pub fn new(now: Instant) -> Self {
let mut device = InMemoryDevice::default();
let mut interface =
Interface::new(Config::new(HardwareAddress::Ip), &mut device, now.into());
// Accept packets with any destination IP, not just our interface.
interface.set_any_ip(true);
// Set our interface IPs. These are just dummies and don't show up anywhere!
interface.update_ip_addrs(|ips| {
ips.push(Ipv4Cidr::new(SERVER_IP4_ADDR, 32).into()).unwrap();
ips.push(Ipv6Cidr::new(SERVER_IP6_ADDR, 128).into())
.unwrap();
});
// Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface.
interface.routes_mut().update(|routes| {
routes
.push(Route::new_ipv4_gateway(SERVER_IP4_ADDR))
.unwrap();
routes
.push(Route::new_ipv6_gateway(SERVER_IP6_ADDR))
.unwrap();
});
let interface = create_interface(&mut device, now);
Self {
device,
@@ -98,7 +72,12 @@ impl Server {
for listen_endpoint in addresses {
for _ in 0..NUM_CONCURRENT_CLIENTS {
let handle = sockets.add(create_tcp_socket(listen_endpoint));
let mut socket = create_tcp_socket();
socket
.listen(listen_endpoint)
.expect("A fresh socket should always be able to listen");
let handle = sockets.add(socket);
listen_endpoints.insert(handle, listen_endpoint);
}
@@ -149,40 +128,9 @@ impl Server {
pub fn send_message(&mut self, socket: SocketHandle, message: Message<Vec<u8>>) -> Result<()> {
let socket = self.sockets.get_mut::<tcp::Socket>(socket.0);
let result = write_tcp_dns_response(socket, message.for_slice_ref());
if result.is_err() {
socket.abort();
}
result.context("Failed to write DNS response")?; // Bail before logging in case we failed to write the response.
if tracing::event_enabled!(target: "wire::dns::res", tracing::Level::TRACE) {
if let Ok(question) = message.sole_question() {
let qtype = question.qtype();
let qname = question.into_qname();
let rcode = message.header().rcode();
if let Ok(record_section) = message.answer() {
let records = record_section
.into_iter()
.filter_map(|r| {
let data = r
.ok()?
.into_any_record::<AllRecordData<_, _>>()
.ok()?
.data()
.clone();
Some(data)
})
.join(" | ");
let qid = message.header().id();
tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
}
}
}
write_tcp_dns_response(socket, message.for_slice_ref())
.inspect_err(|_| socket.abort()) // Abort socket on error.
.context("Failed to write DNS response")?;
Ok(())
}
@@ -211,28 +159,20 @@ impl Server {
for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() {
let listen = self.listen_endpoints.get(&handle).copied().unwrap();
match try_recv_query(socket, listen) {
Ok(Some(message)) => {
if tracing::event_enabled!(target: "wire::dns::qry", tracing::Level::TRACE) {
if let Ok(question) = message.sole_question() {
let qtype = question.qtype();
let qname = question.into_qname();
let qid = message.header().id();
tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
}
while let Some(result) = try_recv_query(socket, listen).transpose() {
match result {
Ok(message) => {
self.received_queries.push_back(Query {
message: message.octets_into(),
socket: SocketHandle(handle),
local: listen,
});
}
Err(e) => {
tracing::debug!("Error on receiving DNS query: {e}");
socket.abort();
break;
}
self.received_queries.push_back(Query {
message,
socket: SocketHandle(handle),
local: listen,
});
}
Ok(None) => {}
Err(e) => {
tracing::debug!("Error on receiving DNS query: {e}");
socket.abort();
}
}
}
@@ -249,28 +189,10 @@ impl Server {
}
}
fn create_tcp_socket(listen_endpoint: SocketAddr) -> tcp::Socket<'static> {
/// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX.
/// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages.
/// Being able to buffer at least one of them means we can handle the extreme case.
/// In practice, this allows the OS to queue multiple queries even if we can't immediately process them.
const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize;
let mut socket = tcp::Socket::new(
RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]),
RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]),
);
socket
.listen(listen_endpoint)
.expect("A fresh socket should always be able to listen");
socket
}
fn try_recv_query(
socket: &mut tcp::Socket,
fn try_recv_query<'b>(
socket: &'b mut tcp::Socket,
listen: SocketAddr,
) -> Result<Option<Message<Vec<u8>>>> {
) -> Result<Option<Message<&'b [u8]>>> {
// smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket.
// to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing.
{
@@ -306,56 +228,19 @@ fn try_recv_query(
return Ok(None);
}
// Read a DNS message from the socket.
let Some(message) = socket
.recv(|r| {
// DNS over TCP has a 2-byte length prefix at the start, see <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
let Some((header, message)) = r.split_first_chunk::<2>() else {
return (0, None);
};
let dns_message_length = u16::from_be_bytes(*header) as usize;
if message.len() < dns_message_length {
return (0, None); // Don't consume any bytes unless we can read the full message at once.
}
(2 + dns_message_length, Some(Message::from_octets(message)))
})
.context("Failed to recv TCP data")?
.transpose()
.context("Failed to parse DNS message")?
else {
let Some(message) = codec::try_recv(socket)? else {
return Ok(None);
};
anyhow::ensure!(!message.header().qr(), "DNS message is a response!");
Ok(Some(message.octets_into()))
Ok(Some(message))
}
fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> {
anyhow::ensure!(response.header().qr(), "DNS message is a query!");
let response = response.as_slice();
let dns_message_length = (response.len() as u16).to_be_bytes();
let written = socket
.send_slice(&dns_message_length)
.context("Failed to write TCP DNS length header")?;
anyhow::ensure!(
written == 2,
"Not enough space in write buffer for TCP DNS length header"
);
let written = socket
.send_slice(response)
.context("Failed to write DNS message")?;
anyhow::ensure!(
written == response.len(),
"Not enough space in write buffer for DNS response"
);
codec::try_send(socket, response)?;
Ok(())
}

View File

@@ -0,0 +1,89 @@
use std::{
collections::BTreeSet,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
time::Instant,
};
use dns_over_tcp::QueryResult;
use domain::base::{iana::Rcode, Message, MessageBuilder, Name, Rtype};
#[test]
fn smoke() {
let _guard = firezone_logging::test(
"netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,smoltcp=trace,debug",
);
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
let resolver_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53));
let mut dns_client = dns_over_tcp::Client::new(Instant::now(), 49152..=65535, [0u8; 32]);
dns_client.set_source_interface(ipv4, ipv6);
dns_client
.connect_to_resolvers(BTreeSet::from_iter([resolver_addr]))
.unwrap();
let mut dns_server = dns_over_tcp::Server::new(Instant::now());
dns_server.set_listen_addresses::<1>(vec![resolver_addr]);
for id in 0..5 {
dns_client
.send_query(resolver_addr, a_query("example.com", id))
.unwrap();
}
let results = std::iter::from_fn(|| progress(&mut dns_client, &mut dns_server))
.take(5)
.collect::<Vec<_>>();
for query_result in results {
let result = query_result.result.unwrap();
println!("{result:?}")
}
}
fn a_query(domain: &str, id: u16) -> Message<Vec<u8>> {
let mut builder = MessageBuilder::new_vec().question();
builder.header_mut().set_id(id);
builder
.push((Name::vec_from_str(domain).unwrap(), Rtype::A))
.unwrap();
builder.into_message()
}
fn progress(
dns_client: &mut dns_over_tcp::Client,
dns_server: &mut dns_over_tcp::Server,
) -> Option<QueryResult> {
loop {
if let Some(packet) = dns_client.poll_outbound() {
dns_server.handle_inbound(packet);
continue;
}
if let Some(packet) = dns_server.poll_outbound() {
dns_client.handle_inbound(packet);
continue;
}
if let Some(query) = dns_server.poll_queries() {
let response = MessageBuilder::new_vec()
.start_answer(&query.message, Rcode::NXDOMAIN)
.unwrap()
.into_message();
dns_server.send_message(query.socket, response).unwrap();
continue;
}
if let Some(query) = dns_client.poll_query_result() {
return Some(query);
}
dns_client.handle_timeout(Instant::now());
dns_server.handle_timeout(Instant::now());
}
}

View File

@@ -16,10 +16,9 @@ use tun::Tun;
const CLIENT_CONCURRENCY: usize = 3;
#[tokio::test]
#[ignore = "Requires root"]
#[ignore = "Requires root & IP forwarding"]
async fn smoke() {
let _guard =
firezone_logging::test("netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,debug");
let _guard = firezone_logging::test("netlink_proto=off,wire::dns=trace,debug");
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);