feat(connlib): support DNS over TCP (#6944)

At present, `connlib` only supports DNS over UDP on port 53. Responses
over UDP are size-constrained on the IP MTU and thus, not all DNS
responses fit into a UDP packet. RFC9210 therefore mandates that all DNS
resolvers must also support DNS over TCP to overcome this limitation
[0].

Handling UDP packets is easy, handling TCP streams is more difficult
because we need to effectively implement a valid TCP state machine.

Building on top of a lot of earlier work (linked in issue), this is
relatively easy because we can now simply import
`dns_over_tcp::{Client,Server}` which do the heavy lifting of sending
and receiving the correct packets for us.

The main aspects of the integration that are worth pointing out are:

- We can handle at most 10 concurrent DNS TCP connections _per defined
resolver_. The assumption here is that most applications will first
query for DNS records over UDP and only fall back to TCP if the response
is truncated. Additionally, we assume that clients will close the TCP
connections once they no longer need it.
- Errors on the TCP stream to an upstream resolver result in `SERVFAIL`
responses to the client.
- All TCP connections to upstream resolvers get reset when we roam, all
currently ongoing queries will be answered with `SERVFAIL`.
- Upon network reset (i.e. roaming), we also re-allocate new local ports
for all TCP sockets, similar to our UDP sockets.

Resolves: #6140.

[0]: https://www.ietf.org/rfc/rfc9210.html#section-3-5
This commit is contained in:
Thomas Eizinger
2024-10-18 14:40:50 +11:00
committed by GitHub
parent 3365981e1b
commit 9de1119b69
28 changed files with 850 additions and 342 deletions

View File

@@ -106,6 +106,7 @@ jobs:
# Too noisy can cause flaky tests due to the amount of data
rust_log: debug
- name: dns-nm
- name: tcp-dns
- name: relay-graceful-shutdown
- name: systemd/dns-systemd-resolved
steps:

2
rust/Cargo.lock generated
View File

@@ -2095,6 +2095,7 @@ dependencies = [
"connlib-model",
"derivative",
"divan",
"dns-over-tcp",
"domain",
"firezone-logging",
"firezone-relay",
@@ -3145,7 +3146,6 @@ name = "ip-packet"
version = "0.1.0"
dependencies = [
"anyhow",
"domain",
"etherparse",
"proptest",
"test-strategy",

View File

@@ -13,6 +13,7 @@ chrono = { workspace = true }
connlib-model = { workspace = true }
derivative = "2.2.0"
divan = { version = "0.1.14", optional = true }
dns-over-tcp = { workspace = true }
domain = { workspace = true }
firezone-logging = { workspace = true }
futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] }

View File

@@ -118,3 +118,9 @@ cc b4dd2a98e4e6aa29f875fa7b8e3af451b1ce6ef8b4e9d6c4cd29fcb68e9249de
cc 0a717e57a998e97be9134007c6a102c5ebaba5c477c95003eaa8f3c4503f88f1
cc 1ead95151ff4ea386b990d1ec7c81a33a816bd8f81d3e3b54abf181e9ff7f3c7
cc 879b2d7d9592265e8cb2799fc0a5d6ab19c6637f53a3181d9613ac3be3e4e532
cc a5f733ee61b9a545b93f5eccb71631918250f8b0657b2479c5f2e85c10fd013d
cc a5f733ee61b9a545b93f5eccb71631918250f8b0657b2479c5f2e85c10fd013d
cc 33cd1cba9c6ecf15d6ff86c3114752f2437e432c77f671f67b08116d2b507131
cc d9793b201ec425bd77f9849ea48e63677014aeb4a91a55be9371b81e644b7a24
cc 8fcbd19c41f0483d9b81aac2ab7440bb23d7796ef9f6bf346f73f0d633f65baa
cc 4494e475d22ff9a318d676f10c79f545982b7787d145925c3719fe47e9868acc

View File

@@ -1,6 +1,5 @@
mod resource;
use domain::base::iana::Rcode;
pub(crate) use resource::{CidrResource, Resource};
#[cfg(all(feature = "proptest", test))]
pub(crate) use resource::{DnsResource, InternetResource};
@@ -24,7 +23,7 @@ use itertools::Itertools;
use crate::peer::GatewayOnClient;
use crate::utils::earliest;
use crate::ClientEvent;
use domain::base::{Message, MessageBuilder};
use domain::base::Message;
use lru::LruCache;
use secrecy::{ExposeSecret as _, Secret};
use snownet::{ClientNode, EncryptBuffer, RelaySocket, Transmit};
@@ -74,6 +73,9 @@ const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
/// We only store [`GatewayId`]s so the memory footprint is negligible.
const MAX_REMEMBERED_GATEWAYS: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(100) };
/// How many concurrent TCP DNS clients we can server _per_ sentinel DNS server IP.
const NUM_CONCURRENT_TCP_DNS_CLIENTS: usize = 10;
/// A sans-IO implementation of a Client's functionality.
///
/// Internally, this composes a [`snownet::ClientNode`] with firezone's policy engine around resources.
@@ -123,6 +125,12 @@ pub struct ClientState {
/// Resources that have been disabled by the UI
disabled_resources: BTreeSet<ResourceId>,
tcp_dns_client: dns_over_tcp::Client,
tcp_dns_server: dns_over_tcp::Server,
/// Tracks the socket on which we received a TCP DNS query by the ID of the recursive DNS query we issued.
tcp_dns_sockets_by_upstream_and_query_id:
HashMap<(SocketAddr, u16), dns_over_tcp::SocketHandle>,
/// Stores the gateways we recently connected to.
///
/// We use this as a hint to the portal to re-connect us to the same gateway for a resource.
@@ -141,7 +149,11 @@ struct AwaitingConnectionDetails {
}
impl ClientState {
pub(crate) fn new(known_hosts: BTreeMap<String, Vec<IpAddr>>, seed: [u8; 32]) -> Self {
pub(crate) fn new(
known_hosts: BTreeMap<String, Vec<IpAddr>>,
seed: [u8; 32],
now: Instant,
) -> Self {
Self {
awaiting_connection_details: Default::default(),
resources_gateways: Default::default(),
@@ -164,6 +176,9 @@ impl ClientState {
recently_connected_gateways: LruCache::new(MAX_REMEMBERED_GATEWAYS),
upstream_dns: Default::default(),
buffered_dns_queries: Default::default(),
tcp_dns_client: dns_over_tcp::Client::new(now, seed),
tcp_dns_server: dns_over_tcp::Server::new(now),
tcp_dns_sockets_by_upstream_and_query_id: Default::default(),
}
}
@@ -283,10 +298,119 @@ impl ClientState {
now: Instant,
buffer: &mut EncryptBuffer,
) -> Option<snownet::EncryptedPacket> {
let packet = match self.try_handle_dns(packet, now) {
let non_dns_packet = match self.try_handle_dns(packet, now) {
ControlFlow::Break(()) => return None,
ControlFlow::Continue(non_dns_packet) => non_dns_packet,
};
self.encapsulate(non_dns_packet, now, buffer)
}
/// Handles UDP packets received on the network interface.
///
/// Most of these packets will be WireGuard encrypted IP packets and will thus yield an [`IpPacket`].
/// Some of them will however be handled internally, for example, TURN control packets exchanged with relays.
///
/// In case this function returns `None`, you should call [`ClientState::handle_timeout`] next to fully advance the internal state.
pub(crate) fn handle_network_input(
&mut self,
local: SocketAddr,
from: SocketAddr,
packet: &[u8],
now: Instant,
) -> Option<IpPacket> {
let (gid, packet) = self.node.decapsulate(
local,
from,
packet.as_ref(),
now,
)
.inspect_err(|e| tracing::debug!(%local, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}"))
.ok()??;
if self.tcp_dns_client.accepts(&packet) {
self.tcp_dns_client.handle_inbound(packet);
return None;
}
let Some(peer) = self.peers.get_mut(&gid) else {
tracing::error!(%gid, "Couldn't find connection by ID");
return None;
};
peer.ensure_allowed_src(&packet)
.inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}"))
.ok()?;
let packet = maybe_mangle_dns_response_from_cidr_resource(
packet,
&self.dns_mapping,
&mut self.mangled_dns_queries,
now,
);
Some(packet)
}
pub(crate) fn handle_dns_response(&mut self, response: dns::RecursiveResponse) {
let qid = response.query.header().id();
let server = response.server;
let domain = response
.query
.sole_question()
.ok()
.map(|q| q.into_qname())
.map(tracing::field::display);
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, domain).entered();
match (response.transport, response.message) {
(dns::Transport::Udp { .. }, Err(e)) if e.kind() == io::ErrorKind::TimedOut => {
tracing::debug!("Recursive UDP DNS query timed out")
}
(dns::Transport::Udp { source }, result) => {
let message = result
.inspect(|message| {
tracing::trace!("Received recursive UDP DNS response");
if message.header().tc() {
tracing::debug!("Upstream DNS server had to truncate response");
}
})
.unwrap_or_else(|e| {
tracing::debug!("Recursive UDP DNS query failed: {e}");
dns::servfail(response.query.for_slice_ref())
});
self.try_queue_udp_dns_response(server, source, &message)
.log_unwrap_debug("Failed to queue UDP DNS response");
}
(dns::Transport::Tcp { source }, result) => {
let message = result
.inspect(|_| {
tracing::trace!("Received recursive TCP DNS response");
})
.unwrap_or_else(|e| {
tracing::debug!("Recursive TCP DNS query failed: {e}");
dns::servfail(response.query.for_slice_ref())
});
self.tcp_dns_server
.send_message(source, message)
.log_unwrap_debug("Failed to send TCP DNS response");
}
}
}
fn encapsulate(
&mut self,
packet: IpPacket,
now: Instant,
buffer: &mut EncryptBuffer,
) -> Option<snownet::EncryptedPacket> {
let dst = packet.destination();
if is_definitely_not_a_resource(dst) {
@@ -326,88 +450,6 @@ impl ClientState {
Some(transmit)
}
/// Handles UDP packets received on the network interface.
///
/// Most of these packets will be WireGuard encrypted IP packets and will thus yield an [`IpPacket`].
/// Some of them will however be handled internally, for example, TURN control packets exchanged with relays.
///
/// In case this function returns `None`, you should call [`ClientState::handle_timeout`] next to fully advance the internal state.
pub(crate) fn handle_network_input(
&mut self,
local: SocketAddr,
from: SocketAddr,
packet: &[u8],
now: Instant,
) -> Option<IpPacket> {
let (gid, packet) = self.node.decapsulate(
local,
from,
packet.as_ref(),
now,
)
.inspect_err(|e| tracing::debug!(%local, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}"))
.ok()??;
let Some(peer) = self.peers.get_mut(&gid) else {
tracing::error!(%gid, "Couldn't find connection by ID");
return None;
};
peer.ensure_allowed_src(&packet)
.inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}"))
.ok()?;
let packet = maybe_mangle_dns_response_from_cidr_resource(
packet,
&self.dns_mapping,
&mut self.mangled_dns_queries,
now,
);
Some(packet)
}
pub(crate) fn handle_dns_response(&mut self, response: dns::RecursiveResponse) {
let qid = response.query.header().id();
let server = response.server;
let domain = response
.query
.sole_question()
.ok()
.map(|q| q.into_qname())
.map(tracing::field::display);
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, domain).entered();
match (response.transport, response.message) {
(dns::Transport::Udp { .. }, Err(e)) if e.kind() == io::ErrorKind::TimedOut => {
tracing::debug!("Recursive DNS query timed out")
}
(dns::Transport::Udp { source }, result) => {
let message = result
.inspect(|message| {
tracing::trace!("Received recursive DNS response");
if message.header().tc() {
tracing::debug!("Upstream DNS server had to truncate response");
}
})
.unwrap_or_else(|e| {
tracing::debug!("Recursive DNS query failed: {e}");
MessageBuilder::new_vec()
.start_answer(&response.query, Rcode::SERVFAIL)
.expect("original query is valid")
.into_message()
});
self.try_queue_udp_dns_response(server, source, &message)
.log_unwrap_debug("Failed to queue UDP DNS response");
}
}
}
fn try_queue_udp_dns_response(
&mut self,
from: SocketAddr,
@@ -570,49 +612,18 @@ impl ClientState {
}
/// Handles UDP & TCP packets targeted at our stub resolver.
fn try_handle_dns(&mut self, mut packet: IpPacket, now: Instant) -> ControlFlow<(), IpPacket> {
fn try_handle_dns(&mut self, packet: IpPacket, now: Instant) -> ControlFlow<(), IpPacket> {
let dst = packet.destination();
let Some(upstream) = self.dns_mapping.get_by_left(&dst).map(|s| s.address()) else {
return ControlFlow::Continue(packet); // Not for our DNS resolver.
};
let (datagram, message) = match parse_udp_dns_message(&packet) {
Ok((datagram, message)) => (datagram, message),
Err(e) => {
tracing::trace!(?packet, "Failed to parse DNS query: {e:#}");
return ControlFlow::Break(());
}
};
let source = SocketAddr::new(packet.source(), datagram.source_port());
match self.stub_resolver.handle(message) {
dns::ResolveStrategy::LocalResponse(response) => {
self.try_queue_udp_dns_response(upstream, source, &response)
.log_unwrap_debug("Failed to queue UDP DNS response");
}
dns::ResolveStrategy::Recurse => {
let query_id = message.header().id();
if self.should_forward_dns_query_to_gateway(upstream.ip()) {
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query via tunnel");
self.mangled_dns_queries
.insert((upstream, message.header().id()), now + IDS_EXPIRE);
packet.set_dst(upstream.ip());
packet.update_checksum();
return ControlFlow::Continue(packet);
}
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host");
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_udp(source, upstream, message));
}
if self.tcp_dns_server.accepts(&packet) {
self.tcp_dns_server.handle_inbound(packet);
return ControlFlow::Break(());
}
ControlFlow::Break(())
self.handle_udp_dns_query(upstream, packet, now)
}
pub fn on_connection_failed(&mut self, resource: ResourceId) {
@@ -693,6 +704,36 @@ impl ClientState {
self.mangled_dns_queries.clear();
}
fn initialise_tcp_dns_client(&mut self) {
let Some(tun_config) = self.tun_config.as_ref() else {
return;
};
self.tcp_dns_client
.set_source_interface(tun_config.ip4, tun_config.ip6);
let upstream_resolvers = self
.dns_mapping
.right_values()
.map(|s| s.address())
.collect();
if let Err(e) = self.tcp_dns_client.set_resolvers(upstream_resolvers) {
tracing::warn!("Failed to connect to upstream DNS resolvers over TCP: {e:#}");
}
}
fn initialise_tcp_dns_server(&mut self) {
let sentinel_sockets = self
.dns_mapping
.left_values()
.map(|ip| SocketAddr::new(*ip, DNS_PORT))
.collect();
self.tcp_dns_server
.set_listen_addresses::<NUM_CONCURRENT_TCP_DNS_CLIENTS>(sentinel_sockets);
}
pub fn set_disabled_resources(&mut self, new_disabled_resources: BTreeSet<ResourceId>) {
let current_disabled_resources = self.disabled_resources.clone();
@@ -805,16 +846,23 @@ impl ClientState {
}
pub fn poll_packets(&mut self) -> Option<IpPacket> {
self.buffered_packets.pop_front()
self.buffered_packets
.pop_front()
.or_else(|| self.tcp_dns_server.poll_outbound())
}
pub fn poll_timeout(&mut self) -> Option<Instant> {
// The number of mangled DNS queries is expected to be fairly small because we only track them whilst connecting to a CIDR resource that is a DNS server.
// Thus, sorting these values on-demand even within `poll_timeout` is expected to be performant enough.
let next_dns_query_expiry = self.mangled_dns_queries.values().min().copied();
let next_node_timeout = self.node.poll_timeout();
earliest(next_dns_query_expiry, next_node_timeout)
earliest(
earliest(
self.tcp_dns_client.poll_timeout(),
self.tcp_dns_server.poll_timeout(),
),
earliest(self.node.poll_timeout(), next_dns_query_expiry),
)
}
pub fn handle_timeout(&mut self, now: Instant) {
@@ -822,6 +870,163 @@ impl ClientState {
self.drain_node_events();
self.mangled_dns_queries.retain(|_, exp| now < *exp);
self.advance_dns_tcp_sockets(now);
}
/// Advance the TCP DNS server and client state machines.
///
/// Receiving something on a TCP server socket may trigger packets to be sent on the TCP client socket and vice versa.
/// Therefore, we loop here until non of the `poll-X` functions return anything anymore.
fn advance_dns_tcp_sockets(&mut self, now: Instant) {
loop {
self.tcp_dns_server.handle_timeout(now);
self.tcp_dns_client.handle_timeout(now);
// Check if have any pending TCP DNS queries.
if let Some(query) = self.tcp_dns_server.poll_queries() {
self.handle_tcp_dns_query(query);
continue;
}
// Check if the client wants to emit any packets.
if let Some(packet) = self.tcp_dns_client.poll_outbound() {
let mut buffer = snownet::EncryptBuffer::new();
// All packets from the TCP DNS client _should_ go through the tunnel.
let Some(encryped_packet) = self.encapsulate(packet, now, &mut buffer) else {
continue;
};
let transmit = encryped_packet.to_transmit(&buffer).into_owned();
self.buffered_transmits.push_back(transmit);
continue;
}
// Check if the client has assembled a response to a query.
if let Some(query_result) = self.tcp_dns_client.poll_query_result() {
let server = query_result.server;
let qid = query_result.query.header().id();
let known_sockets = &mut self.tcp_dns_sockets_by_upstream_and_query_id;
let Some(source) = known_sockets.remove(&(server, qid)) else {
tracing::debug!(?known_sockets, %server, %qid, "Failed to find TCP socket handle for query result");
continue;
};
self.handle_dns_response(dns::RecursiveResponse {
server,
query: query_result.query,
message: query_result
.result
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{e:#}"))),
transport: dns::Transport::Tcp { source },
});
continue;
}
break;
}
}
fn handle_udp_dns_query(
&mut self,
upstream: SocketAddr,
mut packet: IpPacket,
now: Instant,
) -> ControlFlow<(), IpPacket> {
let (datagram, message) = match parse_udp_dns_message(&packet) {
Ok((datagram, message)) => (datagram, message),
Err(e) => {
tracing::trace!(?packet, "Failed to parse DNS query: {e:#}");
return ControlFlow::Break(());
}
};
let source = SocketAddr::new(packet.source(), datagram.source_port());
match self.stub_resolver.handle(message) {
dns::ResolveStrategy::LocalResponse(response) => {
self.try_queue_udp_dns_response(upstream, source, &response)
.log_unwrap_debug("Failed to queue UDP DNS response");
}
dns::ResolveStrategy::Recurse => {
let query_id = message.header().id();
if self.should_forward_dns_query_to_gateway(upstream.ip()) {
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query via tunnel");
self.mangled_dns_queries
.insert((upstream, message.header().id()), now + IDS_EXPIRE);
packet.set_dst(upstream.ip());
packet.update_checksum();
return ControlFlow::Continue(packet);
}
let query_id = message.header().id();
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host");
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_udp(source, upstream, message));
}
}
ControlFlow::Break(())
}
fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query) {
let message = query.message;
let Some(upstream) = self.dns_mapping.get_by_left(&query.local.ip()) else {
tracing::debug!("Received TCP packet for non-sentinel IP");
debug_assert!(
false,
"We only dispatch packets to sentinel IPs to the TCP DNS server"
);
return;
};
let server = upstream.address();
match self.stub_resolver.handle(message.for_slice_ref()) {
dns::ResolveStrategy::LocalResponse(response) => {
self.tcp_dns_server
.send_message(query.socket, response)
.log_unwrap_debug("Failed to send TCP DNS response");
}
dns::ResolveStrategy::Recurse => {
let query_id = message.header().id();
if self.should_forward_dns_query_to_gateway(server.ip()) {
match self.tcp_dns_client.send_query(server, message.clone()) {
Ok(()) => {}
Err(e) => {
tracing::debug!("Failed to send recursive TCP DNS query {e:#}");
self.tcp_dns_server
.send_message(query.socket, dns::servfail(message.for_slice_ref()))
.log_unwrap_debug("Failed to send TCP DNS response");
return;
}
};
let existing = self
.tcp_dns_sockets_by_upstream_and_query_id
.insert((server, query_id), query.socket);
debug_assert!(existing.is_none(), "Query IDs should be unique");
return;
}
tracing::trace!(%server, %query_id, "Forwarding TCP DNS query");
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_tcp(query.socket, server, message));
}
};
}
fn maybe_update_tun_routes(&mut self) {
@@ -885,6 +1090,9 @@ impl ClientState {
self.tun_config = Some(new_tun_config.clone());
self.buffered_events
.push_back(ClientEvent::TunInterfaceUpdated(new_tun_config));
self.initialise_tcp_dns_client();
self.initialise_tcp_dns_server();
}
fn drain_node_events(&mut self) {
@@ -965,6 +1173,11 @@ impl ClientState {
self.node.reset();
self.recently_connected_gateways.clear(); // Ensure we don't have sticky gateways when we roam.
self.drain_node_events();
// Resetting the client will trigger a failed `QueryResult` for each one that is in-progress.
// Failed queries get translated into `SERVFAIL` responses to the client.
// This will also allocate new local ports for our outgoing TCP connections.
self.initialise_tcp_dns_client();
}
pub(crate) fn poll_transmit(&mut self) -> Option<snownet::Transmit<'static>> {
@@ -1480,7 +1693,7 @@ mod tests {
impl ClientState {
pub fn for_test() -> ClientState {
ClientState::new(BTreeMap::new(), rand::random())
ClientState::new(BTreeMap::new(), rand::random(), Instant::now())
}
}

View File

@@ -1,6 +1,7 @@
use crate::client::IpProvider;
use anyhow::{Context, Result};
use connlib_model::{DomainName, ResourceId};
use dns_over_tcp::SocketHandle;
use domain::rdata::AllRecordData;
use domain::{
base::{
@@ -46,7 +47,7 @@ pub struct StubResolver {
}
/// A query that needs to be forwarded to an upstream DNS server for resolution.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct RecursiveQuery {
pub server: SocketAddr,
pub message: Message<Vec<u8>>,
@@ -70,14 +71,29 @@ impl RecursiveQuery {
transport: Transport::Udp { source },
}
}
pub(crate) fn via_tcp(
source: SocketHandle,
server: SocketAddr,
message: Message<Vec<u8>>,
) -> Self {
Self {
server,
message,
transport: Transport::Tcp { source },
}
}
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug)]
pub(crate) enum Transport {
Udp {
/// The original source we received the DNS query on.
source: SocketAddr,
},
Tcp {
source: SocketHandle,
},
}
/// Tells the Client how to reply to a single DNS query
@@ -259,12 +275,7 @@ impl StubResolver {
Err(e) => {
tracing::trace!("Failed to handle DNS query: {e:#}");
let response = MessageBuilder::new_vec()
.start_answer(&message, Rcode::SERVFAIL)
.unwrap()
.into_message();
ResolveStrategy::LocalResponse(response)
ResolveStrategy::LocalResponse(servfail(message))
}
}
}
@@ -335,6 +346,13 @@ impl StubResolver {
}
}
pub fn servfail(message: Message<&[u8]>) -> Message<Vec<u8>> {
MessageBuilder::new_vec()
.start_answer(&message, Rcode::SERVFAIL)
.expect("should always be able to create a heap-allocated SERVFAIL message")
.into_message()
}
fn to_a_records(ips: impl Iterator<Item = IpAddr>) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
ips.filter_map(get_v4)
.map(domain::rdata::A::new)

View File

@@ -17,7 +17,10 @@ use std::{
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use tokio::sync::mpsc;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::mpsc,
};
use tun::Tun;
/// Bundles together all side-effects that connlib needs to have access to.
@@ -26,7 +29,7 @@ pub struct Io {
sockets: Sockets,
unwritten_packet: Option<EncryptedPacket>,
_tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
dns_queries: FuturesTupleSet<io::Result<Message<Vec<u8>>>, DnsQueryMetaData>,
@@ -87,7 +90,7 @@ impl Io {
inbound_packet_rx,
timeout: None,
sockets,
_tcp_socket_factory: tcp_socket_factory,
tcp_socket_factory,
udp_socket_factory,
unwritten_packet: None,
dns_queries: FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000),
@@ -117,19 +120,20 @@ impl Io {
match self.dns_queries.poll_unpin(cx) {
Poll::Ready((result, meta)) => {
let response = result
.map(|result| dns::RecursiveResponse {
server: meta.server,
query: meta.query.clone(),
message: result,
transport: meta.transport,
})
.unwrap_or_else(|_| dns::RecursiveResponse {
let response = match result {
Ok(result) => dns::RecursiveResponse {
server: meta.server,
query: meta.query,
message: Err(io::Error::from(io::ErrorKind::TimedOut)),
message: result,
transport: meta.transport,
});
},
Err(e @ futures_bounded::Timeout { .. }) => dns::RecursiveResponse {
server: meta.server,
query: meta.query,
message: Err(io::Error::new(io::ErrorKind::TimedOut, e)),
transport: meta.transport,
},
};
return Poll::Ready(Ok(Input::DnsResponse(response)));
}
@@ -255,6 +259,48 @@ impl Io {
tracing::debug!("Failed to queue UDP DNS query")
}
}
dns::Transport::Tcp { .. } => {
let factory = self.tcp_socket_factory.clone();
let server = query.server;
let meta = DnsQueryMetaData {
query: query.message.clone(),
server,
transport: query.transport,
};
if self
.dns_queries
.try_push(
async move {
let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
let mut tcp_stream = tcp_socket.connect(server).await?;
let query = query.message.into_octets();
let dns_message_length = (query.len() as u16).to_be_bytes();
tcp_stream.write_all(&dns_message_length).await?;
tcp_stream.write_all(&query).await?;
let mut response_length = [0u8; 2];
tcp_stream.read_exact(&mut response_length).await?;
let response_length = u16::from_be_bytes(response_length) as usize;
// A u16 is at most 65k, meaning we are okay to allocate here based on what the remote is sending.
let mut response = vec![0u8; response_length];
tcp_stream.read_exact(&mut response).await?;
let message = Message::from_octets(response)
.map_err(|_| io::Error::other("Failed to parse DNS message"))?;
Ok(message)
},
meta,
)
.is_err()
{
tracing::debug!("Failed to queue TCP DNS query")
}
}
}
}

View File

@@ -95,7 +95,7 @@ impl ClientTunnel {
) -> Self {
Self {
io: Io::new(tcp_socket_factory, udp_socket_factory),
role_state: ClientState::new(known_hosts, rand::random()),
role_state: ClientState::new(known_hosts, rand::random(), Instant::now()),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
encrypt_buf: Default::default(),

View File

@@ -159,31 +159,31 @@ pub(crate) fn assert_routes_are_valid(ref_client: &RefClient, sim_client: &SimCl
}
}
pub(crate) fn assert_dns_packets_properties(ref_client: &RefClient, sim_client: &SimClient) {
pub(crate) fn assert_udp_dns_packets_properties(ref_client: &RefClient, sim_client: &SimClient) {
let unexpected_dns_replies = find_unexpected_entries(
&ref_client.expected_dns_handshakes,
&sim_client.received_dns_responses,
&ref_client.expected_udp_dns_handshakes,
&sim_client.received_udp_dns_responses,
|(_, id_a), (_, id_b)| id_a == id_b,
);
if !unexpected_dns_replies.is_empty() {
tracing::error!(target: "assertions", ?unexpected_dns_replies, "❌ Unexpected DNS replies on client");
tracing::error!(target: "assertions", ?unexpected_dns_replies, "❌ Unexpected UDP DNS replies on client");
}
for (dns_server, query_id) in ref_client.expected_dns_handshakes.iter() {
for (dns_server, query_id) in ref_client.expected_udp_dns_handshakes.iter() {
let _guard =
tracing::info_span!(target: "assertions", "dns", %query_id, %dns_server).entered();
tracing::info_span!(target: "assertions", "udp_dns", %query_id, %dns_server).entered();
let key = &(*dns_server, *query_id);
let queries = &sim_client.sent_dns_queries;
let responses = &sim_client.received_dns_responses;
let queries = &sim_client.sent_udp_dns_queries;
let responses = &sim_client.received_udp_dns_responses;
let Some(client_sent_query) = queries.get(key) else {
tracing::error!(target: "assertions", ?queries, "❌ Missing DNS query on client");
tracing::error!(target: "assertions", ?queries, "❌ Missing UDP DNS query on client");
continue;
};
let Some(client_received_response) = responses.get(key) else {
tracing::error!(target: "assertions", ?responses, "❌ Missing DNS response on client");
tracing::error!(target: "assertions", ?responses, "❌ Missing UDP DNS response on client");
continue;
};
@@ -192,6 +192,26 @@ pub(crate) fn assert_dns_packets_properties(ref_client: &RefClient, sim_client:
}
}
pub(crate) fn assert_tcp_dns(ref_client: &RefClient, sim_client: &SimClient) {
for (dns_server, query_id) in ref_client.expected_tcp_dns_handshakes.iter() {
let _guard =
tracing::info_span!(target: "assertions", "tcp_dns", %query_id, %dns_server).entered();
let key = &(*dns_server, *query_id);
let queries = &sim_client.sent_tcp_dns_queries;
let responses = &sim_client.received_tcp_dns_responses;
if queries.get(key).is_none() {
tracing::error!(target: "assertions", ?queries, "❌ Missing TCP DNS query on client");
continue;
};
if responses.get(key).is_none() {
tracing::error!(target: "assertions", ?responses, "❌ Missing TCP DNS response on client");
continue;
};
}
}
fn assert_correct_src_and_dst_ips(
client_sent_request: &IpPacket,
client_received_reply: &IpPacket,

View File

@@ -1,6 +1,6 @@
use std::{
collections::{BTreeMap, BTreeSet, VecDeque},
net::IpAddr,
net::{IpAddr, SocketAddr},
time::Instant,
};
@@ -14,12 +14,47 @@ use domain::{
};
use ip_packet::IpPacket;
pub struct TcpDnsServerResource {
server: dns_over_tcp::Server,
}
#[derive(Debug, Default)]
pub struct UdpDnsServerResource {
inbound_packets: VecDeque<IpPacket>,
outbound_packets: VecDeque<IpPacket>,
}
impl TcpDnsServerResource {
pub fn new(socket: SocketAddr, now: Instant) -> Self {
let mut server = dns_over_tcp::Server::new(now);
server.set_listen_addresses::<5>(BTreeSet::from([socket]));
Self { server }
}
pub fn handle_input(&mut self, packet: IpPacket) {
self.server.handle_inbound(packet);
}
pub fn handle_timeout(
&mut self,
global_dns_records: &BTreeMap<DomainName, BTreeSet<IpAddr>>,
now: Instant,
) {
self.server.handle_timeout(now);
while let Some(query) = self.server.poll_queries() {
let response = handle_dns_query(query.message.for_slice(), global_dns_records);
self.server.send_message(query.socket, response).unwrap();
}
}
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
self.server.poll_outbound()
}
}
impl UdpDnsServerResource {
pub fn handle_input(&mut self, packet: IpPacket) {
self.inbound_packets.push_back(packet);

View File

@@ -345,6 +345,10 @@ impl ReferenceState {
if connected_resources.is_empty() {
connected_resources.insert(resource);
}
// TCP has retries so we will also be connected to those for sure.
if query.transport == DnsTransport::Tcp {
connected_resources.insert(resource);
}
}
continue;

View File

@@ -3,7 +3,7 @@ use super::{
sim_net::{any_ip_stack, any_port, host, Host},
sim_relay::{map_explode, SimRelay},
strategies::latency,
transition::DnsQuery,
transition::{DnsQuery, DnsTransport},
IcmpIdentifier, IcmpSeq, QueryId,
};
use crate::{
@@ -15,7 +15,7 @@ use crate::{proptest::*, ClientState};
use bimap::BiMap;
use connlib_model::{ClientId, GatewayId, RelayId, ResourceId};
use domain::{
base::{Message, Rtype, ToName},
base::{iana::Opcode, Message, MessageBuilder, Question, Rtype, ToName},
rdata::AllRecordData,
};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
@@ -44,34 +44,45 @@ pub(crate) struct SimClient {
pub(crate) dns_records: HashMap<DomainName, Vec<IpAddr>>,
/// Bi-directional mapping between connlib's sentinel DNS IPs and the effective DNS servers.
pub(crate) dns_by_sentinel: BiMap<IpAddr, SocketAddr>,
dns_by_sentinel: BiMap<IpAddr, SocketAddr>,
pub(crate) ipv4_routes: BTreeSet<Ipv4Network>,
pub(crate) ipv6_routes: BTreeSet<Ipv6Network>,
pub(crate) sent_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>,
pub(crate) received_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>,
pub(crate) sent_udp_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>,
pub(crate) received_udp_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>,
pub(crate) sent_tcp_dns_queries: HashSet<(SocketAddr, QueryId)>,
pub(crate) received_tcp_dns_responses: BTreeSet<(SocketAddr, QueryId)>,
pub(crate) sent_icmp_requests: HashMap<(u16, u16), IpPacket>,
pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket>,
pub(crate) tcp_dns_client: dns_over_tcp::Client,
enc_buffer: EncryptBuffer,
}
impl SimClient {
pub(crate) fn new(id: ClientId, sut: ClientState) -> Self {
pub(crate) fn new(id: ClientId, sut: ClientState, now: Instant) -> Self {
let mut tcp_dns_client = dns_over_tcp::Client::new(now, [0u8; 32]);
tcp_dns_client.set_source_interface(Ipv4Addr::LOCALHOST, Ipv6Addr::LOCALHOST);
Self {
id,
sut,
dns_records: Default::default(),
dns_by_sentinel: Default::default(),
sent_dns_queries: Default::default(),
received_dns_responses: Default::default(),
sent_udp_dns_queries: Default::default(),
received_udp_dns_responses: Default::default(),
sent_tcp_dns_queries: Default::default(),
received_tcp_dns_responses: Default::default(),
sent_icmp_requests: Default::default(),
received_icmp_replies: Default::default(),
enc_buffer: Default::default(),
ipv4_routes: Default::default(),
ipv6_routes: Default::default(),
tcp_dns_client,
}
}
@@ -80,36 +91,85 @@ impl SimClient {
self.dns_by_sentinel.right_values().copied().collect()
}
pub(crate) fn set_new_dns_servers(&mut self, mapping: BiMap<IpAddr, SocketAddr>) {
if self.dns_by_sentinel != mapping {
self.tcp_dns_client
.set_resolvers(
mapping
.left_values()
.map(|ip| SocketAddr::new(*ip, 53))
.collect(),
)
.unwrap();
}
self.dns_by_sentinel = mapping;
}
pub(crate) fn dns_mapping(&self) -> &BiMap<IpAddr, SocketAddr> {
&self.dns_by_sentinel
}
pub(crate) fn send_dns_query_for(
&mut self,
domain: DomainName,
r_type: Rtype,
query_id: u16,
dns_server: SocketAddr,
upstream: SocketAddr,
dns_transport: DnsTransport,
now: Instant,
) -> Option<Transmit<'static>> {
let Some(dns_server) = self.dns_by_sentinel.get_by_right(&dns_server).copied() else {
tracing::error!(%dns_server, "Unknown DNS server");
let Some(sentinel) = self.dns_by_sentinel.get_by_right(&upstream).copied() else {
tracing::error!(%upstream, "Unknown DNS server");
return None;
};
tracing::debug!(%dns_server, %domain, "Sending DNS query");
tracing::debug!(%sentinel, %domain, "Sending DNS query");
let src = self
.sut
.tunnel_ip_for(dns_server)
.tunnel_ip_for(sentinel)
.expect("tunnel should be initialised");
let packet = ip_packet::make::dns_query(
domain,
r_type,
SocketAddr::new(src, 9999), // An application would pick a random source port that is free.
SocketAddr::new(dns_server, 53),
query_id,
)
.unwrap();
// Create the DNS query message
let mut msg_builder = MessageBuilder::new_vec();
self.encapsulate(packet, now)
msg_builder.header_mut().set_opcode(Opcode::QUERY);
msg_builder.header_mut().set_rd(true);
msg_builder.header_mut().set_id(query_id);
// Create the query
let mut question_builder = msg_builder.question();
question_builder
.push(Question::new_in(domain, r_type))
.unwrap();
let message = question_builder.into_message();
match dns_transport {
DnsTransport::Udp => {
let packet = ip_packet::make::udp_packet(
src,
sentinel,
9999, // An application would pick a free source port.
53,
message.as_octets().to_vec(),
)
.unwrap();
self.sent_udp_dns_queries
.insert((upstream, query_id), packet.clone());
self.encapsulate(packet, now)
}
DnsTransport::Tcp => {
self.tcp_dns_client
.send_query(SocketAddr::new(sentinel, 53), message)
.unwrap();
self.sent_tcp_dns_queries.insert((upstream, query_id));
None
}
}
}
pub(crate) fn encapsulate(
@@ -131,24 +191,6 @@ impl SimClient {
}
}
{
if let Some(udp) = packet.as_udp() {
if let Ok(message) = Message::from_slice(udp.payload()) {
debug_assert!(
!message.header().qr(),
"every DNS message sent from the client should be a DNS query"
);
// Map back to upstream socket so we can assert on it correctly.
let sentinel = SocketAddr::from((packet.destination(), udp.destination_port()));
let upstream = self.upstream_dns_by_sentinel(&sentinel).unwrap();
self.sent_dns_queries
.insert((upstream, message.header().id()), packet.clone());
}
}
}
let Some(enc_packet) = self.sut.handle_tun_input(packet, now, &mut self.enc_buffer) else {
self.sut.handle_timeout(now); // If we handled the packet internally, make sure to advance state.
return None;
@@ -191,6 +233,11 @@ impl SimClient {
}
}
if self.tcp_dns_client.accepts(&packet) {
self.tcp_dns_client.handle_inbound(packet);
return;
}
if let Some(udp) = packet.as_udp() {
if udp.source_port() == 53 {
let message = Message::from_slice(udp.payload())
@@ -203,36 +250,9 @@ impl SimClient {
return;
};
self.received_dns_responses
self.received_udp_dns_responses
.insert((upstream, message.header().id()), packet.clone());
for record in message.answer().unwrap() {
let record = record.unwrap();
let domain = record.owner().to_name();
#[expect(clippy::wildcard_enum_match_arm)]
let ip = match record
.into_any_record::<AllRecordData<_, _>>()
.unwrap()
.data()
{
AllRecordData::A(a) => IpAddr::from(a.addr()),
AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()),
AllRecordData::Ptr(_) => {
continue;
}
unhandled => {
panic!("Unexpected record data: {unhandled:?}")
}
};
self.dns_records.entry(domain).or_default().push(ip);
}
// Ensure all IPs are always sorted.
for ips in self.dns_records.values_mut() {
ips.sort()
}
self.handle_dns_response(message);
return;
}
@@ -259,6 +279,36 @@ impl SimClient {
Some(*socket)
}
pub(crate) fn handle_dns_response(&mut self, message: &Message<[u8]>) {
for record in message.answer().unwrap() {
let record = record.unwrap();
let domain = record.owner().to_name();
#[expect(clippy::wildcard_enum_match_arm)]
let ip = match record
.into_any_record::<AllRecordData<_, _>>()
.unwrap()
.data()
{
AllRecordData::A(a) => IpAddr::from(a.addr()),
AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()),
AllRecordData::Ptr(_) => {
continue;
}
unhandled => {
panic!("Unexpected record data: {unhandled:?}")
}
};
self.dns_records.entry(domain).or_default().push(ip);
}
// Ensure all IPs are always sorted.
for ips in self.dns_records.values_mut() {
ips.sort()
}
}
}
/// Reference state for a particular client.
@@ -327,17 +377,20 @@ pub struct RefClient {
#[derivative(Debug = "ignore")]
pub(crate) expected_icmp_handshakes:
BTreeMap<GatewayId, BTreeMap<u64, (ResourceDst, IcmpSeq, IcmpIdentifier)>>,
/// The expected DNS handshakes.
/// The expected UDP DNS handshakes.
#[derivative(Debug = "ignore")]
pub(crate) expected_dns_handshakes: VecDeque<(SocketAddr, QueryId)>,
pub(crate) expected_udp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>,
/// The expected TCP DNS handshakes.
#[derivative(Debug = "ignore")]
pub(crate) expected_tcp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>,
}
impl RefClient {
/// Initialize the [`ClientState`].
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self) -> SimClient {
let mut client_state = ClientState::new(self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed.
pub(crate) fn init(self, now: Instant) -> SimClient {
let mut client_state = ClientState::new(self.known_hosts, self.key.0, now); // Cheating a bit here by reusing the key as seed.
client_state.update_interface_config(Interface {
ipv4: self.tunnel_ip4,
ipv6: self.tunnel_ip6,
@@ -345,7 +398,7 @@ impl RefClient {
});
client_state.update_system_resolvers(self.system_dns_resolvers.clone());
SimClient::new(self.id, client_state)
SimClient::new(self.id, client_state, now)
}
pub(crate) fn disconnect_resource(&mut self, resource: &ResourceId) {
@@ -624,8 +677,16 @@ impl RefClient {
.or_default()
.insert(query.r_type);
self.expected_dns_handshakes
.push_back((query.dns_server, query.query_id));
match query.transport {
DnsTransport::Udp => {
self.expected_udp_dns_handshakes
.push_back((query.dns_server, query.query_id));
}
DnsTransport::Tcp => {
self.expected_tcp_dns_handshakes
.push_back((query.dns_server, query.query_id));
}
}
}
pub(crate) fn ipv4_cidr_resource_dsts(&self) -> Vec<Ipv4Network> {
@@ -930,7 +991,8 @@ fn ref_client(
connected_dns_resources: Default::default(),
connected_internet_resource: Default::default(),
expected_icmp_handshakes: Default::default(),
expected_dns_handshakes: Default::default(),
expected_udp_dns_handshakes: Default::default(),
expected_tcp_dns_handshakes: Default::default(),
disabled_resources: Default::default(),
resources: Default::default(),
ipv4_routes: Default::default(),

View File

@@ -1,5 +1,5 @@
use super::{
dns_server_resource::UdpDnsServerResource,
dns_server_resource::{TcpDnsServerResource, UdpDnsServerResource},
reference::{private_key, PrivateKey},
sim_net::{any_port, dual_ip_stack, host, Host},
sim_relay::{map_explode, SimRelay},
@@ -28,6 +28,7 @@ pub(crate) struct SimGateway {
pub(crate) received_icmp_requests: BTreeMap<u64, IpPacket>,
udp_dns_server_resources: HashMap<SocketAddr, UdpDnsServerResource>,
tcp_dns_server_resources: HashMap<SocketAddr, TcpDnsServerResource>,
}
impl SimGateway {
@@ -38,6 +39,7 @@ impl SimGateway {
received_icmp_requests: Default::default(),
enc_buffer: Default::default(),
udp_dns_server_resources: Default::default(),
tcp_dns_server_resources: Default::default(),
}
}
@@ -70,8 +72,14 @@ impl SimGateway {
std::iter::from_fn(|| s.poll_outbound())
});
let tcp_server_packets = self.tcp_dns_server_resources.values_mut().flat_map(|s| {
s.handle_timeout(global_dns_records, now);
std::iter::from_fn(|| s.poll_outbound())
});
udp_server_packets
.chain(tcp_server_packets)
.filter_map(|packet| {
Some(
self.sut
@@ -83,12 +91,18 @@ impl SimGateway {
.collect()
}
pub(crate) fn deploy_new_dns_servers(&mut self, dns_servers: impl Iterator<Item = SocketAddr>) {
pub(crate) fn deploy_new_dns_servers(
&mut self,
dns_servers: impl Iterator<Item = SocketAddr>,
now: Instant,
) {
self.udp_dns_server_resources.clear();
for server in dns_servers {
self.udp_dns_server_resources
.insert(server, UdpDnsServerResource::default());
self.tcp_dns_server_resources
.insert(server, TcpDnsServerResource::new(server, now));
}
}
@@ -117,6 +131,15 @@ impl SimGateway {
}
}
if let Some(tcp) = packet.as_tcp() {
let socket = SocketAddr::new(packet.destination(), tcp.destination_port());
if let Some(server) = self.tcp_dns_server_resources.get_mut(&socket) {
server.handle_input(packet);
return None;
}
}
tracing::error!(?packet, "Unhandled packet");
None
}

View File

@@ -115,6 +115,9 @@ pub(crate) fn dns_servers() -> impl Strategy<Value = BTreeSet<SocketAddr>> {
.prop_filter("must not be in IPv4 resources range", |ip| {
!crate::client::IPV4_RESOURCES.contains(*ip)
})
.prop_filter("must be addressable IP", |ip| {
!ip.is_unspecified() && !ip.is_multicast() && !ip.is_broadcast()
})
.prop_map(|ip| SocketAddr::from((ip, 53))),
1..4,
);
@@ -126,6 +129,9 @@ pub(crate) fn dns_servers() -> impl Strategy<Value = BTreeSet<SocketAddr>> {
.prop_filter("must not be in IPv6 resources range", |ip| {
!crate::client::IPV6_RESOURCES.contains(*ip)
})
.prop_filter("must be addressable IP", |ip| {
!ip.is_unspecified() && !ip.is_multicast()
})
.prop_map(|ip| SocketAddr::from((ip, 53))),
1..4,
);

View File

@@ -47,9 +47,10 @@ impl TunnelTest {
// Initialize the system under test from our reference state.
pub(crate) fn init_test(ref_state: &ReferenceState, flux_capacitor: FluxCapacitor) -> Self {
// Construct client, gateway and relay from the initial state.
let mut client = ref_state
.client
.map(|ref_client, _, _| ref_client.init(), debug_span!("client"));
let mut client = ref_state.client.map(
|ref_client, _, _| ref_client.init(flux_capacitor.now()),
debug_span!("client"),
);
let mut gateways = ref_state
.gateways
@@ -203,10 +204,11 @@ impl TunnelTest {
r_type,
dns_server,
query_id,
transport,
} in queries
{
let transmit = state.client.exec_mut(|sim| {
sim.send_dns_query_for(domain, r_type, query_id, dns_server, now)
sim.send_dns_query_for(domain, r_type, query_id, dns_server, transport, now)
});
buffered_transmits.push_from(transmit, &state.client, now);
@@ -342,7 +344,8 @@ impl TunnelTest {
sim_gateways,
&ref_state.global_dns_records,
);
assert_dns_packets_properties(ref_client, sim_client);
assert_udp_dns_packets_properties(ref_client, sim_client);
assert_tcp_dns(ref_client, sim_client);
assert_known_hosts_are_valid(ref_client, sim_client);
assert_dns_servers_are_valid(ref_client, sim_client);
assert_routes_are_valid(ref_client, sim_client);
@@ -388,8 +391,10 @@ impl TunnelTest {
let server = query.server;
let transport = query.transport;
let response =
self.on_recursive_dns_query(query.clone(), &ref_state.global_dns_records);
let response = self.on_recursive_dns_query(
query.message.for_slice_ref(),
&ref_state.global_dns_records,
);
self.client.exec_mut(|c| {
c.sut.handle_dns_response(dns::RecursiveResponse {
server,
@@ -486,6 +491,33 @@ impl TunnelTest {
) {
let now = self.flux_capacitor.now();
// Handle the TCP DNS client, i.e. simulate applications making TCP DNS queries.
self.client.exec_mut(|c| {
c.tcp_dns_client.handle_timeout(now);
while let Some(result) = c.tcp_dns_client.poll_query_result() {
match result.result {
Ok(message) => {
let upstream = c.dns_mapping().get_by_left(&result.server.ip()).unwrap();
c.received_tcp_dns_responses
.insert((*upstream, result.query.header().id()));
c.handle_dns_response(message.for_slice())
}
Err(e) => {
tracing::error!("TCP DNS query failed: {e:#}");
}
}
}
});
while let Some(transmit) = self.client.exec_mut(|c| {
let packet = c.tcp_dns_client.poll_outbound()?;
c.encapsulate(packet, now)
}) {
buffered_transmits.push_from(transmit, &self.client, now)
}
// Handle the client's `Transmit`s and timeout.
while let Some(transmit) = self.client.poll_transmit(now) {
self.client.exec_mut(|c| c.receive(transmit, now))
}
@@ -495,6 +527,7 @@ impl TunnelTest {
}
});
// Handle all gateway `Transmit`s and timeouts.
for (_, gateway) in self.gateways.iter_mut() {
for transmit in gateway.exec_mut(|g| g.advance_resources(global_dns_records, now)) {
buffered_transmits.push_from(transmit, gateway, now);
@@ -517,6 +550,7 @@ impl TunnelTest {
});
}
// Handle all relay `Transmit`s and timeouts.
for (_, relay) in self.relays.iter_mut() {
while let Some(transmit) = relay.poll_transmit(now) {
let Some(reply) = relay.exec_mut(|r| r.receive(transmit, now)) else {
@@ -682,7 +716,7 @@ impl TunnelTest {
tracing::warn!("Unimplemented");
}
ClientEvent::TunInterfaceUpdated(config) => {
if self.client.inner().dns_by_sentinel == config.dns_by_sentinel
if self.client.inner().dns_mapping() == &config.dns_by_sentinel
&& self.client.inner().ipv4_routes == config.ipv4_routes
&& self.client.inner().ipv6_routes == config.ipv6_routes
{
@@ -691,16 +725,19 @@ impl TunnelTest {
);
}
if self.client.inner().dns_by_sentinel != config.dns_by_sentinel {
if self.client.inner().dns_mapping() != &config.dns_by_sentinel {
for gateway in self.gateways.values_mut() {
gateway.exec_mut(|g| {
g.deploy_new_dns_servers(config.dns_by_sentinel.right_values().copied())
g.deploy_new_dns_servers(
config.dns_by_sentinel.right_values().copied(),
now,
)
})
}
}
self.client.exec_mut(|c| {
c.dns_by_sentinel = config.dns_by_sentinel;
c.set_new_dns_servers(config.dns_by_sentinel);
c.ipv4_routes = config.ipv4_routes;
c.ipv6_routes = config.ipv6_routes;
});
@@ -778,11 +815,9 @@ impl TunnelTest {
fn on_recursive_dns_query(
&self,
query: crate::dns::RecursiveQuery,
query: Message<&[u8]>,
global_dns_records: &BTreeMap<DomainName, BTreeSet<IpAddr>>,
) -> Message<Vec<u8>> {
let query = query.message;
let response = MessageBuilder::new_vec();
let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap();

View File

@@ -98,6 +98,13 @@ pub(crate) struct DnsQuery {
/// The DNS query ID.
pub(crate) query_id: u16,
pub(crate) dns_server: SocketAddr,
pub(crate) transport: DnsTransport,
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum DnsTransport {
Udp,
Tcp,
}
pub(crate) fn ping_random_ip<I>(
@@ -202,9 +209,17 @@ pub(crate) fn dns_queries(
query_type(),
Just(query_id),
ptr_query_ip(),
dns_transport(),
)
.prop_map(
|(mut domain, dns_server, r_type, query_id, maybe_reverse_record)| {
|(
mut domain,
dns_server,
r_type,
query_id,
maybe_reverse_record,
transport,
)| {
if matches!(r_type, Rtype::PTR) {
domain =
DomainName::reverse_from_addr(maybe_reverse_record).unwrap();
@@ -215,6 +230,7 @@ pub(crate) fn dns_queries(
r_type,
query_id,
dns_server,
transport,
}
},
)
@@ -231,6 +247,10 @@ fn ptr_query_ip() -> impl Strategy<Value = IpAddr> {
]
}
fn dns_transport() -> impl Strategy<Value = DnsTransport> {
prop_oneof![Just(DnsTransport::Udp), Just(DnsTransport::Tcp),]
}
pub(crate) fn query_type() -> impl Strategy<Value = Rtype> {
prop_oneof![
Just(Rtype::A),

View File

@@ -1,5 +1,5 @@
use std::{
collections::{BTreeSet, HashMap, HashSet, VecDeque},
collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::{Duration, Instant},
};
@@ -14,7 +14,7 @@ use ip_packet::IpPacket;
use rand::{rngs::StdRng, Rng, SeedableRng};
use smoltcp::{
iface::{Interface, PollResult, SocketSet},
socket::tcp::{self, Socket},
socket::tcp,
};
/// A sans-io DNS-over-TCP client.
@@ -33,7 +33,7 @@ pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
source_ips: Option<(Ipv4Addr, Ipv6Addr)>,
sockets: SocketSet<'static>,
sockets_by_remote: HashMap<SocketAddr, smoltcp::iface::SocketHandle>,
sockets_by_remote: BTreeMap<SocketAddr, smoltcp::iface::SocketHandle>,
local_ports_by_socket: HashMap<smoltcp::iface::SocketHandle, u16>,
/// Queries we should send to a DNS resolver.
pending_queries_by_remote: HashMap<SocketAddr, VecDeque<Message<Vec<u8>>>>,
@@ -95,45 +95,17 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
self.sockets = SocketSet::new(vec![]);
self.sockets_by_remote.clear();
self.local_ports_by_socket.clear();
self.abort_all_pending_and_sent_queries();
self.query_results
.extend(
self.pending_queries_by_remote
.drain()
.flat_map(|(server, queries)| {
into_failed_results(server, queries, || anyhow!("Aborted"))
}),
);
self.query_results
.extend(
self.sent_queries_by_remote
.drain()
.flat_map(|(server, queries)| {
into_failed_results(server, queries.into_values(), || anyhow!("Aborted"))
}),
);
// Second, try to allocate a unique port per resolver.
let unique_ports = self.sample_unique_ports(resolvers.len())?;
// Second, try to create all new sockets.
let new_sockets = std::iter::zip(self.sample_unique_ports(resolvers.len())?, resolvers)
.map(|(port, server)| {
let local_endpoint = match server {
SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port),
SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port),
};
let socket = create_tcp_socket();
Ok((server, local_endpoint, socket))
})
.collect::<Result<Vec<_>>>()?;
// Third, if everything was successful, change the local state.
for (server, local_endpoint, socket) in new_sockets {
let handle = self.sockets.add(socket);
self.sockets_by_remote.insert(server, handle);
self.local_ports_by_socket
.insert(handle, local_endpoint.port());
}
// Third, initialise the sockets.
self.init_sockets(
std::iter::zip(unique_ports, resolvers),
ipv4_source,
ipv6_source,
);
Ok(())
}
@@ -231,7 +203,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
}
for (remote, handle) in self.sockets_by_remote.iter_mut() {
let socket = self.sockets.get_mut::<Socket>(*handle);
let socket = self.sockets.get_mut::<tcp::Socket>(*handle);
let server = *remote;
let pending_queries = self.pending_queries_by_remote.entry(server).or_default();
@@ -292,6 +264,52 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
Some(self.last_now + Duration::from(poll_in))
}
fn abort_all_pending_and_sent_queries(&mut self) {
let aborted_pending_queries =
self.pending_queries_by_remote
.drain()
.flat_map(|(server, queries)| {
into_failed_results(server, queries, || anyhow!("Aborted"))
});
let aborted_sent_queries =
self.sent_queries_by_remote
.drain()
.flat_map(|(server, queries)| {
into_failed_results(server, queries.into_values(), || anyhow!("Aborted"))
});
self.query_results
.extend(aborted_pending_queries.chain(aborted_sent_queries));
}
fn init_sockets(
&mut self,
ports_and_resolvers: impl IntoIterator<Item = (u16, SocketAddr)>,
ipv4_source: Ipv4Addr,
ipv6_source: Ipv6Addr,
) {
let new_sockets = ports_and_resolvers
.into_iter()
.map(|(port, server)| {
let local_endpoint = match server {
SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port),
SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port),
};
let socket = create_tcp_socket();
(server, local_endpoint, socket)
})
.collect::<Vec<_>>();
for (server, local_endpoint, socket) in new_sockets {
let handle = self.sockets.add(socket);
self.sockets_by_remote.insert(server, handle);
self.local_ports_by_socket
.insert(handle, local_endpoint.port());
}
}
fn sample_unique_ports(&mut self, num_ports: usize) -> Result<impl Iterator<Item = u16>> {
let mut ports = HashSet::with_capacity(num_ports);
let range = MIN_PORT..=MAX_PORT;
@@ -312,7 +330,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
}
fn send_pending_queries(
socket: &mut Socket,
socket: &mut tcp::Socket,
server: SocketAddr,
pending_queries: &mut VecDeque<Message<Vec<u8>>>,
sent_queries: &mut HashMap<u16, Message<Vec<u8>>>,
@@ -348,7 +366,7 @@ fn send_pending_queries(
}
fn recv_responses(
socket: &mut Socket,
socket: &mut tcp::Socket,
server: SocketAddr,
pending_queries: &mut VecDeque<Message<Vec<u8>>>,
sent_queries: &mut HashMap<u16, Message<Vec<u8>>>,

View File

@@ -6,7 +6,7 @@ mod stub_device;
mod time;
pub use client::{Client, QueryResult};
pub use server::{Server, SocketHandle};
pub use server::{Query, Server, SocketHandle};
fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> {
/// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX.

View File

@@ -70,8 +70,23 @@ impl Server {
/// The constant configures, how many concurrent clients you would like to be able to serve per listen address.
pub fn set_listen_addresses<const NUM_CONCURRENT_CLIENTS: usize>(
&mut self,
addresses: Vec<SocketAddr>,
addresses: BTreeSet<SocketAddr>,
) {
let current_listen_endpoints = self
.listen_endpoints
.values()
.copied()
.collect::<BTreeSet<_>>();
if current_listen_endpoints == addresses {
tracing::debug!(
?current_listen_endpoints,
"Already listening on this exact set of addresses"
);
return;
}
assert!(NUM_CONCURRENT_CLIENTS > 0);
let mut sockets =
@@ -143,13 +158,6 @@ impl Server {
Ok(())
}
/// Resets the socket associated with the given handle.
///
/// Use this if you encountered an error while processing a previously emitted DNS query.
pub fn reset(&mut self, handle: SocketHandle) {
self.sockets.get_mut::<tcp::Socket>(handle.0).abort();
}
/// Inform the server that time advanced.
///
/// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible.

View File

@@ -25,7 +25,7 @@ fn smoke() {
.unwrap();
let mut dns_server = dns_over_tcp::Server::new(Instant::now());
dns_server.set_listen_addresses::<1>(vec![resolver_addr]);
dns_server.set_listen_addresses::<1>(BTreeSet::from([resolver_addr]));
for id in 0..5 {
dns_client

View File

@@ -1,4 +1,5 @@
use std::{
collections::BTreeSet,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
process::Stdio,
task::{ready, Context, Poll},
@@ -36,7 +37,7 @@ async fn smoke() {
let listen_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53));
let mut dns_server = dns_over_tcp::Server::new(Instant::now());
dns_server.set_listen_addresses::<CLIENT_CONCURRENCY>(vec![listen_addr]);
dns_server.set_listen_addresses::<CLIENT_CONCURRENCY>(BTreeSet::from([listen_addr]));
let mut eventloop = Eventloop::new(Box::new(tun), dns_server);
tokio::spawn(std::future::poll_fn(move |cx| eventloop.poll(cx)));

View File

@@ -11,7 +11,6 @@ proptest = ["dep:proptest"]
[dependencies]
anyhow = "1.0.86"
domain = "0.10.1"
etherparse = "0.15"
proptest = { version = "1", optional = true }
thiserror = "1"

View File

@@ -2,9 +2,8 @@
use crate::{IpPacket, IpPacketBuf};
use anyhow::{Context, Result};
use domain::base::{iana::Opcode, MessageBuilder, Name, Question, Rtype};
use etherparse::PacketBuilder;
use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
/// Helper macro to turn a [`PacketBuilder`] into an [`IpPacket`].
#[macro_export]
@@ -151,31 +150,6 @@ where
}
}
pub fn dns_query(
domain: Name<Vec<u8>>,
kind: Rtype,
src: SocketAddr,
dst: SocketAddr,
id: u16,
) -> Result<IpPacket, IpVersionMismatch> {
// Create the DNS query message
let mut msg_builder = MessageBuilder::new_vec();
msg_builder.header_mut().set_opcode(Opcode::QUERY);
msg_builder.header_mut().set_rd(true);
msg_builder.header_mut().set_id(id);
// Create the query
let mut question_builder = msg_builder.question();
question_builder
.push(Question::new_in(domain, kind))
.unwrap();
let payload = question_builder.finish();
udp_packet(src.ip(), dst.ip(), src.port(), dst.port(), payload)
}
#[derive(thiserror::Error, Debug)]
#[error("IPs must be of the same version")]
pub struct IpVersionMismatch;

14
scripts/tests/tcp-dns.sh Executable file
View File

@@ -0,0 +1,14 @@
#!/usr/bin/env bash
source "./scripts/tests/lib.sh"
client sh -c "apk add bind-tools" # The compat tests run using the production image which doesn't have `dig`.
echo "Resolving DNS resource over TCP"
client sh -c "dig +tcp dns.httpbin"
echo "Resolving non-DNS resource over TCP"
client sh -c "dig +tcp example.com"
echo "Testing TCP fallback"
client sh -c "dig 2048.size.dns.netmeister.org"

View File

@@ -11,7 +11,9 @@ export default function Android() {
title="Android"
>
{/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */}
<Unreleased></Unreleased>
<Unreleased>
<ChangeItem>Handles DNS queries over TCP correctly.</ChangeItem>
</Unreleased>
<Entry version="1.3.5" date={new Date("2024-10-03")}>
<ChangeItem pull="6831">
Ensures Firefox doesn't attempt to use DNS over HTTPS when Firezone is

View File

@@ -11,7 +11,9 @@ export default function Apple() {
title="macOS / iOS"
>
{/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */}
<Unreleased></Unreleased>
<Unreleased>
<ChangeItem>Handles DNS queries over TCP correctly.</ChangeItem>
</Unreleased>
<Entry version="1.3.6" date={new Date("2024-10-02")}>
<ChangeItem pull="6831">
Ensures Firefox doesn't attempt to use DNS over HTTPS when Firezone is

View File

@@ -15,9 +15,7 @@ export default function GUI({ title }: { title: string }) {
<Entries href={href} arches={arches} title={title}>
{/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */}
<Unreleased>
<ChangeItem enable={title === "Linux GUI"}>
This is a maintenance release with no user-facing changes.
</ChangeItem>
<ChangeItem>Handles DNS queries over TCP correctly.</ChangeItem>
<ChangeItem enable={title === "Windows"} pull="7009">
The IPC service `firezone-client-ipc.exe` is now signed.
</ChangeItem>

View File

@@ -11,7 +11,9 @@ export default function Headless() {
return (
<Entries href={href} arches={arches} title="Linux headless">
{/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */}
<Unreleased></Unreleased>
<Unreleased>
<ChangeItem>Handles DNS queries over TCP correctly.</ChangeItem>
</Unreleased>
<Entry version="1.3.4" date={new Date("2024-10-02")}>
<ChangeItem pull="6831">
Ensures Firefox doesn't attempt to use DNS over HTTPS when Firezone is