chore: add sans-IO DNS-over-TCP implementation (#6997)

This splits out the actual DNS server from #6944 into a separate crate.
At present, it only contains a DNS server. Later, we will likely add a
DNS client to it as well because the proptests and connlib itself will
need a user-space DNS TCP client.

The implementation uses `smoltcp` but that is entirely encapsulated. The
`Server` struct exposes only a high-level interface for

- feeding inbound packets as well as retrieving outbound packets
- retrieving parsed DNS queries and sending DNS responses

Related: #6140.
This commit is contained in:
Thomas Eizinger
2024-10-11 08:05:12 +11:00
committed by GitHub
parent 8c4f6bdb0f
commit cd2dea7846
8 changed files with 696 additions and 1 deletions

89
rust/Cargo.lock generated
View File

@@ -1685,6 +1685,38 @@ dependencies = [
"uuid",
]
[[package]]
name = "defmt"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a99dd22262668b887121d4672af5a64b238f026099f1a2a1b322066c9ecfe9e0"
dependencies = [
"bitflags 1.3.2",
"defmt-macros",
]
[[package]]
name = "defmt-macros"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3a9f309eff1f79b3ebdf252954d90ae440599c26c2c553fe87a2d17195f2dcb"
dependencies = [
"defmt-parser",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.77",
]
[[package]]
name = "defmt-parser"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff4a5fefe330e8d7f31b16a318f9ce81000d8e35e69b93eae154d16d2278f70f"
dependencies = [
"thiserror",
]
[[package]]
name = "deranged"
version = "0.3.11"
@@ -1872,6 +1904,23 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "dns-over-tcp"
version = "0.1.0"
dependencies = [
"anyhow",
"domain",
"firezone-bin-shared",
"firezone-logging",
"ip-packet",
"ip_network",
"itertools 0.13.0",
"smoltcp",
"tokio",
"tracing",
"tun",
]
[[package]]
name = "domain"
version = "0.10.1"
@@ -3236,6 +3285,15 @@ dependencies = [
"crunchy",
]
[[package]]
name = "hash32"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606"
dependencies = [
"byteorder",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@@ -3267,6 +3325,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "heapless"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad"
dependencies = [
"hash32",
"stable_deref_trait",
]
[[package]]
name = "heck"
version = "0.3.3"
@@ -4327,6 +4395,12 @@ dependencies = [
"libc",
]
[[package]]
name = "managed"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d"
[[package]]
name = "markup5ever"
version = "0.11.0"
@@ -6947,6 +7021,21 @@ dependencies = [
"serde",
]
[[package]]
name = "smoltcp"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a1a996951e50b5971a2c8c0fa05a381480d70a933064245c4a223ddc87ccc97"
dependencies = [
"bitflags 1.3.2",
"byteorder",
"cfg-if",
"defmt",
"heapless",
"log",
"managed",
]
[[package]]
name = "snownet"
version = "0.1.0"

View File

@@ -7,6 +7,7 @@ members = [
"connlib/model",
"connlib/snownet",
"connlib/tunnel",
"dns-over-tcp",
"gateway",
"gui-client/src-common",
"gui-client/src-tauri",
@@ -51,6 +52,7 @@ firezone-bin-shared = { path = "bin-shared" }
firezone-logging = { path = "logging" }
firezone-telemetry = { path = "telemetry" }
snownet = { path = "connlib/snownet" }
dns-over-tcp = { path = "dns-over-tcp" }
firezone-relay = { path = "relay" }
connlib-model = { path = "connlib/model" }
firezone-tunnel = { path = "connlib/tunnel" }

View File

@@ -0,0 +1,20 @@
[package]
name = "dns-over-tcp"
version = "0.1.0"
edition = "2021"
description = "User-space implementation of DNS over TCP."
[dependencies]
anyhow = "1.0"
domain = { workspace = true }
ip-packet = { workspace = true }
itertools = "0.13"
smoltcp = { version = "0.11", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
tracing = { workspace = true }
[dev-dependencies]
firezone-bin-shared = { workspace = true }
firezone-logging = { workspace = true }
ip_network = { version = "0.4", default-features = false }
tokio = { workspace = true, features = ["process", "rt", "macros"] }
tun = { workspace = true }

View File

@@ -0,0 +1,4 @@
// `server` contains the sans-IO DNS-over-TCP server, `stub_device` the in-memory `smoltcp` device backing it.
mod server;
mod stub_device;

// `Query` is what `Server::poll_queries` returns; re-export it so that users of this crate can name the type.
pub use server::{Query, Server, SocketHandle};

View File

@@ -0,0 +1,361 @@
use std::{
collections::{BTreeSet, HashMap, VecDeque},
net::SocketAddr,
time::Instant,
};
use crate::stub_device::InMemoryDevice;
use anyhow::{Context as _, Result};
use domain::{base::Message, dep::octseq::OctetsInto as _, rdata::AllRecordData};
use ip_packet::IpPacket;
use itertools::Itertools as _;
use smoltcp::{
iface::{Config, Interface, Route, SocketSet},
socket::tcp,
storage::RingBuffer,
wire::{HardwareAddress, IpEndpoint, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr},
};
/// A sans-IO implementation of a DNS-over-TCP server.
///
/// Listens on a specified number of socket addresses, parses incoming DNS queries and allows writing back responses.
pub struct Server {
/// Packet queues through which `smoltcp` receives inbound and emits outbound IP packets.
device: InMemoryDevice,
/// The `smoltcp` interface that gets polled in [`Server::handle_timeout`].
interface: Interface,
/// All TCP sockets; replaced wholesale by [`Server::set_listen_addresses`].
sockets: SocketSet<'static>,
/// Maps every socket in `sockets` to the address it listens on.
listen_endpoints: HashMap<smoltcp::iface::SocketHandle, SocketAddr>,
/// Parsed queries waiting to be handed out via [`Server::poll_queries`].
received_queries: VecDeque<Query>,
}
/// Opaque handle to a TCP socket.
///
/// This purposely does not implement [`Clone`] or [`Copy`] to make them single-use:
/// [`Server::send_message`] and [`Server::reset`] both take the handle by value.
#[derive(Debug, PartialEq, Eq, Hash)]
#[must_use = "An active `SocketHandle` means a TCP socket is waiting for a reply somewhere"]
pub struct SocketHandle(smoltcp::iface::SocketHandle);
/// A DNS query received on one of the server's TCP sockets, as returned by [`Server::poll_queries`].
pub struct Query {
/// The parsed DNS message. Messages with the QR (response) bit set are rejected before a `Query` is created.
pub message: Message<Vec<u8>>,
/// Handle of the socket the query arrived on; pass it to [`Server::send_message`] or [`Server::reset`].
pub socket: SocketHandle,
/// The address of the socket that received the query.
pub local: SocketAddr,
}
/// IPv4 address assigned to the `smoltcp` interface and used as the catch-all IPv4 gateway in [`Server::new`].
const SERVER_IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1);
/// IPv6 address assigned to the `smoltcp` interface and used as the catch-all IPv6 gateway in [`Server::new`].
const SERVER_IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1);
impl Server {
pub fn new(now: Instant) -> Self {
let mut device = InMemoryDevice::default();
let mut interface =
Interface::new(Config::new(HardwareAddress::Ip), &mut device, now.into());
// Accept packets with any destination IP, not just our interface.
interface.set_any_ip(true);
// Set our interface IPs. These are just dummies and don't show up anywhere!
interface.update_ip_addrs(|ips| {
ips.push(Ipv4Cidr::new(SERVER_IP4_ADDR, 32).into()).unwrap();
ips.push(Ipv6Cidr::new(SERVER_IP6_ADDR, 128).into())
.unwrap();
});
// Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface.
interface.routes_mut().update(|routes| {
routes
.push(Route::new_ipv4_gateway(SERVER_IP4_ADDR))
.unwrap();
routes
.push(Route::new_ipv6_gateway(SERVER_IP6_ADDR))
.unwrap();
});
Self {
device,
interface,
sockets: SocketSet::new(Vec::default()),
listen_endpoints: Default::default(),
received_queries: Default::default(),
}
}
/// Replaces the set of addresses this server listens on.
///
/// All previously created sockets are dropped and any queries that were received but not yet polled are discarded.
/// `NUM_CONCURRENT_CLIENTS` configures how many sockets are created per listen address,
/// i.e. how many clients can be served concurrently on each of them.
///
/// # Panics
///
/// Panics if `NUM_CONCURRENT_CLIENTS` is zero.
pub fn set_listen_addresses<const NUM_CONCURRENT_CLIENTS: usize>(
    &mut self,
    addresses: Vec<SocketAddr>,
) {
    assert!(NUM_CONCURRENT_CLIENTS > 0);

    let num_sockets = addresses.len() * NUM_CONCURRENT_CLIENTS;
    let mut new_sockets = SocketSet::new(Vec::with_capacity(num_sockets));
    let mut new_endpoints = HashMap::with_capacity(addresses.len());

    for address in addresses {
        // One socket per concurrent client, all listening on the same address.
        new_endpoints.extend(
            std::iter::repeat_with(|| new_sockets.add(create_tcp_socket(address)))
                .take(NUM_CONCURRENT_CLIENTS)
                .map(|handle| (handle, address)),
        );

        tracing::info!(listen_endpoint = %address, concurrency = %NUM_CONCURRENT_CLIENTS, "Created listening TCP socket");
    }

    self.received_queries.clear();
    self.listen_endpoints = new_endpoints;
    self.sockets = new_sockets;
}
/// Tells whether `packet` is destined for this server.
///
/// That is the case only for TCP packets whose destination IP and port match one of the
/// addresses configured with [`Server::set_listen_addresses`].
pub fn accepts(&self, packet: &IpPacket) -> bool {
    let Some(tcp) = packet.as_tcp() else {
        tracing::trace!(?packet, "Not a TCP packet");
        return false;
    };

    let dst = SocketAddr::new(packet.destination(), tcp.destination_port());

    if self
        .listen_endpoints
        .values()
        .any(|endpoint| *endpoint == dst)
    {
        return true;
    }

    // Collecting the endpoints is only worth it if the event is actually going to be emitted.
    if tracing::enabled!(tracing::Level::TRACE) {
        let listen_endpoints = self
            .listen_endpoints
            .values()
            .copied()
            .collect::<BTreeSet<_>>();

        tracing::trace!(%dst, ?listen_endpoints, "No listening socket for destination");
    }

    false
}
/// Handle the [`IpPacket`].
///
/// This function only inserts the packet into a buffer.
/// To actually process the packets in the buffer, [`Server::handle_timeout`] must be called.
pub fn handle_inbound(&mut self, packet: IpPacket) {
debug_assert!(self.accepts(&packet));
self.device.receive(packet);
}
/// Send a message on the socket associated with the handle.
///
/// This fails if the socket is not writeable.
/// On any error, the TCP connection is automatically reset.
pub fn send_message(&mut self, socket: SocketHandle, message: Message<Vec<u8>>) -> Result<()> {
let socket = self.sockets.get_mut::<tcp::Socket>(socket.0);
let result = write_tcp_dns_response(socket, message.for_slice_ref());
if result.is_err() {
socket.abort();
}
result.context("Failed to write DNS response")?; // Bail before logging in case we failed to write the response.
if tracing::event_enabled!(target: "wire::dns::res", tracing::Level::TRACE) {
if let Ok(question) = message.sole_question() {
let qtype = question.qtype();
let qname = question.into_qname();
let rcode = message.header().rcode();
if let Ok(record_section) = message.answer() {
let records = record_section
.into_iter()
.filter_map(|r| {
let data = r
.ok()?
.into_any_record::<AllRecordData<_, _>>()
.ok()?
.data()
.clone();
Some(data)
})
.join(" | ");
let qid = message.header().id();
tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
}
}
}
Ok(())
}
/// Resets the socket associated with the given handle.
///
/// Use this if you encountered an error while processing a previously emitted DNS query.
pub fn reset(&mut self, handle: SocketHandle) {
self.sockets.get_mut::<tcp::Socket>(handle.0).abort();
}
/// Advances the server's state to `now`.
///
/// This polls the `smoltcp` interface, which consumes buffered inbound packets and produces outbound ones.
/// If polling reports a change, every socket is checked for a newly arrived DNS query;
/// those are then available via [`Server::poll_queries`].
/// Sockets on which receiving a query fails are reset.
pub fn handle_timeout(&mut self, now: Instant) {
    let now = smoltcp::time::Instant::from(now);

    if !self
        .interface
        .poll(now, &mut self.device, &mut self.sockets)
    {
        return;
    }

    for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() {
        // Every socket is registered together with its listen address in `set_listen_addresses`.
        let local = *self.listen_endpoints.get(&handle).unwrap();

        let message = match try_recv_query(socket, local) {
            Ok(Some(message)) => message,
            Ok(None) => continue,
            Err(e) => {
                tracing::debug!("Error on receiving DNS query: {e}");
                socket.abort();
                continue;
            }
        };

        if tracing::event_enabled!(target: "wire::dns::qry", tracing::Level::TRACE) {
            if let Ok(question) = message.sole_question() {
                let qid = message.header().id();
                let qtype = question.qtype();
                let qname = question.into_qname();

                tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
            }
        }

        let query = Query {
            message,
            socket: SocketHandle(handle),
            local,
        };
        self.received_queries.push_back(query);
    }
}
/// Returns [`IpPacket`]s that should be sent.
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
self.device.next_send()
}
/// Returns queries received from a DNS client.
pub fn poll_queries(&mut self) -> Option<Query> {
self.received_queries.pop_front()
}
}
/// Creates a TCP socket that listens on `listen_endpoint`.
///
/// Both the receive and the transmit buffer are sized to hold exactly one maximum-sized DNS-over-TCP frame.
fn create_tcp_socket(listen_endpoint: SocketAddr) -> tcp::Socket<'static> {
    /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX.
    /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages.
    /// Being able to buffer at least one of them means we can handle the extreme case.
    /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them.
    const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize;
    /// On the wire, every message is preceded by its 2-byte length prefix.
    const LENGTH_PREFIX_LENGTH: usize = std::mem::size_of::<u16>();
    /// The prefix travels through the same buffers as the message itself.
    /// Without accounting for it, the largest messages would never fit into a buffer in one piece.
    ///
    /// NOTE(review): buffers larger than `u16::MAX` presumably make `smoltcp` negotiate TCP window scaling; confirm this is acceptable.
    const BUFFER_LENGTH: usize = LENGTH_PREFIX_LENGTH + MAX_TCP_DNS_MSG_LENGTH;

    let mut socket = tcp::Socket::new(
        RingBuffer::new(vec![0u8; BUFFER_LENGTH]),
        RingBuffer::new(vec![0u8; BUFFER_LENGTH]),
    );
    socket
        .listen(listen_endpoint)
        .expect("A fresh socket should always be able to listen");

    socket
}
/// Attempts to read a single DNS query from `socket`.
///
/// Returns `Ok(None)` if the socket is not ready or no complete message has been buffered yet.
///
/// # Errors
///
/// Fails if the connection was accepted on an address other than `listen`, if receiving from the socket fails,
/// if the bytes cannot be parsed as a DNS message or if the message is a response instead of a query.
fn try_recv_query(
    socket: &mut tcp::Socket,
    listen: SocketAddr,
) -> Result<Option<Message<Vec<u8>>>> {
    // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket.
    // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing.
    {
        use smoltcp::socket::tcp::State::*;

        if matches!(socket.state(), Closed | TimeWait | CloseWait) {
            tracing::debug!(state = %socket.state(), "Resetting socket to listen state");

            socket.abort();
            socket
                .listen(listen)
                .expect("Can always listen after `abort()`");
        }
    }

    // We configure `smoltcp` with "any-ip", meaning packets to technically any IP will be routed here to us.
    if let Some(local) = socket.local_endpoint() {
        anyhow::ensure!(
            local == IpEndpoint::from(listen),
            "Bad destination socket: {local}"
        )
    }

    // Ensure we can recv, send and have space to send.
    if !socket.can_recv() || !socket.can_send() || socket.send_queue() > 0 {
        tracing::trace!(
            can_recv = %socket.can_recv(),
            can_send = %socket.can_send(),
            send_queue = %socket.send_queue(),
            "Not yet ready to receive next message"
        );

        return Ok(None);
    }

    // Read a DNS message from the socket.
    let Some(message) = socket
        .recv(|r| {
            // DNS over TCP has a 2-byte length prefix at the start, see <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
            let Some((header, rest)) = r.split_first_chunk::<2>() else {
                return (0, None);
            };
            let dns_message_length = usize::from(u16::from_be_bytes(*header));

            // Only hand the announced number of bytes to the parser.
            // Whatever follows belongs to the next message (clients may pipeline queries) and must stay in the buffer.
            let Some(message) = rest.get(..dns_message_length) else {
                return (0, None); // Don't consume any bytes unless we can read the full message at once.
            };

            (2 + dns_message_length, Some(Message::from_octets(message)))
        })
        .context("Failed to recv TCP data")?
        .transpose()
        .context("Failed to parse DNS message")?
    else {
        return Ok(None);
    };

    anyhow::ensure!(!message.header().qr(), "DNS message is a response!");

    Ok(Some(message.octets_into()))
}
/// Writes `response` to `socket`, framed with the 2-byte length prefix of DNS over TCP.
///
/// # Errors
///
/// Fails if `response` is not a DNS response, if it is too large to be described by the length prefix
/// or if the socket cannot take the entire frame.
/// In the latter case, parts of the frame may already have been written; the caller should reset the connection.
fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> {
    anyhow::ensure!(response.header().qr(), "DNS message is a query!");

    let response = response.as_slice();

    // A plain `as u16` cast would silently truncate and thus announce a wrong length to the client.
    let dns_message_length = u16::try_from(response.len())
        .context("DNS response is too large for the 2-byte TCP length prefix")?
        .to_be_bytes();

    let written = socket
        .send_slice(&dns_message_length)
        .context("Failed to write TCP DNS length header")?;

    anyhow::ensure!(
        written == 2,
        "Not enough space in write buffer for TCP DNS length header"
    );

    let written = socket
        .send_slice(response)
        .context("Failed to write DNS message")?;

    anyhow::ensure!(
        written == response.len(),
        "Not enough space in write buffer for DNS response"
    );

    Ok(())
}

View File

@@ -0,0 +1,86 @@
use std::collections::VecDeque;
use ip_packet::{IpPacket, IpPacketBuf};
/// An in-memory device for [`smoltcp`] that is entirely backed by buffers.
#[derive(Debug, Default)]
pub(crate) struct InMemoryDevice {
/// Packets queued via [`InMemoryDevice::receive`], waiting to be consumed by `smoltcp`.
inbound_packets: VecDeque<IpPacket>,
/// Packets emitted by `smoltcp`, waiting to be fetched via [`InMemoryDevice::next_send`].
outbound_packets: VecDeque<IpPacket>,
}
impl InMemoryDevice {
    /// Appends `packet` to the queue of packets that `smoltcp` will read from this device.
    pub(crate) fn receive(&mut self, packet: IpPacket) {
        let Self {
            inbound_packets, ..
        } = self;

        inbound_packets.push_back(packet);
    }

    /// Removes and returns the oldest packet that `smoltcp` wrote to this device, if any.
    pub(crate) fn next_send(&mut self) -> Option<IpPacket> {
        let Self {
            outbound_packets, ..
        } = self;

        outbound_packets.pop_front()
    }
}
impl smoltcp::phy::Device for InMemoryDevice {
    type RxToken<'a> = SmolRxToken;
    type TxToken<'a> = SmolTxToken<'a>;

    /// Hands the oldest inbound packet to `smoltcp`, together with a token for sending a reply.
    ///
    /// Returns `None` if no inbound packet is queued.
    fn receive(
        &mut self,
        _timestamp: smoltcp::time::Instant,
    ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
        let Self {
            inbound_packets,
            outbound_packets,
        } = self;

        let packet = inbound_packets.pop_front()?;

        Some((SmolRxToken { packet }, SmolTxToken { outbound_packets }))
    }

    /// Always yields a token: the outbound queue is an unbounded [`VecDeque`].
    fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
        let outbound_packets = &mut self.outbound_packets;

        Some(SmolTxToken { outbound_packets })
    }

    /// Describes this device as an IP-medium device with an MTU of [`ip_packet::PACKET_SIZE`].
    fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities {
        let mut capabilities = smoltcp::phy::DeviceCapabilities::default();
        capabilities.max_transmission_unit = ip_packet::PACKET_SIZE;
        capabilities.medium = smoltcp::phy::Medium::Ip;

        capabilities
    }
}
/// Token handed to `smoltcp` for sending a single packet; the packet ends up in the device's outbound queue.
pub(crate) struct SmolTxToken<'a> {
/// The outbound queue of the [`InMemoryDevice`] that created this token.
outbound_packets: &'a mut VecDeque<IpPacket>,
}
impl<'a> smoltcp::phy::TxToken for SmolTxToken<'a> {
    /// Lets `smoltcp` write a packet of `len` bytes and queues it as an outbound [`IpPacket`].
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds the size of an [`IpPacketBuf`] or if the bytes written by `f` don't form a valid IP packet.
    fn consume<R, F>(self, len: usize, f: F) -> R
    where
        F: FnOnce(&mut [u8]) -> R,
    {
        let mut ip_packet_buf = IpPacketBuf::new();

        // The contract of `TxToken::consume` is to call `f` with a buffer of exactly `len` bytes.
        // Handing out the whole scratch buffer would let `smoltcp` operate on trailing bytes that are not part of the packet.
        let result = f(&mut ip_packet_buf.buf()[..len]);

        let mut ip_packet = IpPacket::new(ip_packet_buf, len).unwrap();
        ip_packet.update_checksum();
        self.outbound_packets.push_back(ip_packet);

        result
    }
}
/// Token handed to `smoltcp` for receiving a single packet.
pub(crate) struct SmolRxToken {
/// The inbound packet that `smoltcp` gets to read when it consumes this token.
packet: IpPacket,
}
impl smoltcp::phy::RxToken for SmolRxToken {
    /// Gives `f` access to the raw bytes of the received packet and returns whatever `f` returns.
    fn consume<R, F>(mut self, f: F) -> R
    where
        F: FnOnce(&mut [u8]) -> R,
    {
        let bytes = self.packet.packet_mut();

        f(bytes)
    }
}

View File

@@ -0,0 +1,133 @@
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
process::Stdio,
task::{ready, Context, Poll},
time::Instant,
};
use anyhow::{Context as _, Result};
use domain::base::{iana::Rcode, MessageBuilder};
use firezone_bin_shared::TunDeviceManager;
use ip_network::Ipv4Network;
use ip_packet::{IpPacket, IpPacketBuf};
use tokio::task::JoinSet;
use tun::Tun;
/// Number of `dig` processes run in parallel; also the number of sockets the server creates for the listen address.
const CLIENT_CONCURRENCY: usize = 3;
/// End-to-end smoke test: routes a subnet to a TUN device, serves DNS over TCP on an address within it and queries that with `dig`.
#[tokio::test]
#[ignore = "Requires root"]
async fn smoke() {
    let _guard =
        firezone_logging::test("netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,debug");

    let tun_ip4 = Ipv4Addr::new(100, 90, 215, 97);
    let tun_ip6 = Ipv6Addr::new(0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f);

    let mut device_manager = TunDeviceManager::new(1280).unwrap();
    let tun = device_manager.make_tun().unwrap();
    device_manager.set_ips(tun_ip4, tun_ip6).await.unwrap();

    // Everything in this subnet, including the address the DNS server listens on, ends up on the TUN device.
    let dns_route = Ipv4Network::new(Ipv4Addr::new(100, 100, 111, 0), 24).unwrap();
    device_manager
        .set_routes(vec![dns_route], vec![])
        .await
        .unwrap();

    let listen_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53));

    let mut server = dns_over_tcp::Server::new(Instant::now());
    server.set_listen_addresses::<CLIENT_CONCURRENCY>(vec![listen_addr]);

    let mut eventloop = Eventloop::new(Box::new(tun), server);
    tokio::spawn(std::future::poll_fn(move |cx| eventloop.poll(cx)));

    // Running the queries multiple times ensures we can reuse sockets.
    for _ in 0..2 {
        run_queries(listen_addr.ip()).await;
    }
}
/// Runs [`CLIENT_CONCURRENCY`] `dig` processes in parallel against `dns_server` and asserts that all of them exit with code 0.
async fn run_queries(dns_server: IpAddr) {
    let mut clients = JoinSet::new();

    (0..CLIENT_CONCURRENCY).for_each(|_| {
        clients.spawn(dig(dns_server));
    });

    let results = clients.join_all().await;
    let exit_codes = results
        .into_iter()
        .collect::<Result<Vec<_>>>()
        .unwrap();

    for status in exit_codes {
        assert_eq!(status, 0)
    }
}
/// Queries `example.com` four times over a single TCP connection using `dig` and returns its exit code.
///
/// # Errors
///
/// Fails if `dig` cannot be run or if it terminated without an exit code (e.g. it was killed by a signal).
async fn dig(dns_server: IpAddr) -> Result<i32> {
    let server = format!("@{dns_server}");

    let exit_status = tokio::process::Command::new("dig")
        .args([
            "+tcp",
            "+tries=1",
            "+keepopen", // Reuse the TCP socket
            server.as_str(),
            "example.com",
            "example.com", // Querying more than one domain ensures a client can reuse a TCP connection
            "example.com",
            "example.com",
        ])
        // We only care about the exit code.
        // `status()` closes piped handles right after spawning, which would leave `dig` writing into a closed pipe.
        // Discarding the output explicitly avoids that.
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()
        .await
        .context("Failed to run `dig`")?
        .code()
        .context("Missing code")?;

    Ok(exit_status)
}
/// Minimal event loop that shuttles packets between a TUN device and the DNS server under test.
struct Eventloop {
/// The TUN device packets are read from and written to.
tun: Box<dyn Tun>,
/// The server under test.
dns_server: dns_over_tcp::Server,
}
impl Eventloop {
    /// Bundles the TUN device and the server into an event loop.
    fn new(tun: Box<dyn Tun>, dns_server: dns_over_tcp::Server) -> Self {
        Self { tun, dns_server }
    }

    /// Drives the event loop; only ever returns `Poll::Pending`, once the TUN device has nothing more to read.
    ///
    /// Per iteration, work is picked in this order: flush one outbound packet, answer one query, read one packet from the TUN device.
    fn poll(&mut self, cx: &mut Context) -> Poll<()> {
        loop {
            // 1. Packets produced by the server go straight to the TUN device.
            if let Some(outbound) = self.dns_server.poll_outbound() {
                match outbound {
                    IpPacket::Ipv4(v4) => self.tun.write4(v4.packet()).unwrap(),
                    IpPacket::Ipv6(v6) => self.tun.write6(v6.packet()).unwrap(),
                };
                continue;
            }

            // 2. Every query is answered with NXDOMAIN; the test only cares about the transport.
            if let Some(query) = self.dns_server.poll_queries() {
                let nxdomain = MessageBuilder::new_vec()
                    .start_answer(&query.message, Rcode::NXDOMAIN)
                    .unwrap()
                    .into_message();

                self.dns_server
                    .send_message(query.socket, nxdomain)
                    .unwrap();
                continue;
            }

            // 3. Nothing left to flush, read the next packet from the TUN device.
            let mut buf = IpPacketBuf::default();
            let len = ready!(self.tun.poll_read(buf.buf(), cx)).unwrap();
            let inbound = IpPacket::new(buf, len).unwrap();

            if !self.dns_server.accepts(&inbound) {
                continue;
            }

            self.dns_server.handle_inbound(inbound);
            self.dns_server.handle_timeout(Instant::now());
        }
    }
}

View File

@@ -789,7 +789,7 @@ impl IpPacket {
}
}
fn packet_mut(&mut self) -> &mut [u8] {
pub fn packet_mut(&mut self) -> &mut [u8] {
match self {
IpPacket::Ipv4(v4) => v4.packet_mut(),
IpPacket::Ipv6(v6) => v6.packet_mut(),