feat(connlib): resolve SRV & TXT queries for resources in sites (#8335)

## Description

We want to resolve DNS queries of type SRV & TXT for DNS resources
within the network context of the site that is hosting the DNS resource
itself. This allows admins to e.g. deploy dedicated nameservers into
those sites and have them resolve their SRV and TXT records to names
that are scoped to that particular site.

SRV records themselves return more domains which - if they are
configured as DNS resources - will be intercepted and then routed to the
correct site.

Prior to this PR, SRV & TXT records got resolved by the DNS server
configured on the client (or the server defined in the Firezone portal),
even if the domain in question was a DNS resource. This effectively
meant that those SRV records have to be valid globally and could not be
specific to the site that the DNS resource is hosted in.

## Example

Say we have these wildcard DNS resources:

- `**.department-a.example.com`
- `**.department-b.example.com`

Each of these DNS resources is assigned to a different site. If we now
issue an SRV DNS query to `_my-service.department-a.example.com`, we may
receive back the following records:

- `_my-service.department-a.example.com. 86400 IN SRV 10 60 8080
my-service1.department-a.example.com.`
- `_my-service.department-a.example.com. 86400 IN SRV 10 60 8080
my-service2.department-a.example.com.`
- `_my-service.department-a.example.com. 86400 IN SRV 10 60 8080
my-service3.department-a.example.com.`

Notice how the SRV records point to domains that will also match the
wildcard DNS resource above! If that is the case, Firezone will also
intercept A & AAAA queries for this service (which are a natural
follow-up from an application making an SRV query). As a result, traffic
for `my-service1.department-a.example.com` will be routed to the same
site the DNS resource is defined in. If the returned domains don't match
the wildcard DNS resource, the traffic will either not be intercepted at
all (if it is not a DNS resource) or routed to whichever site defines
the corresponding DNS resource.

All of these scenarios may be what the admin wants. If the SRV records
defined for the DNS resource are globally valid (and e.g. not even
resources), then resolving them using the Client's system resolver may
be all that is needed. If the services are running in a dedicated site,
that traffic should indeed be routed to that site.

As such, Firezone itself cannot make any assumption about the structure
of these records at all. The only thing that is enabled with this PR is
that IF the structure happens to match the same DNS resource, it allows
admins to deploy site-specific services that resolve their concrete
domains via SRV records.

## Testing

The implementation is tested using our property-based testing framework.
In order to cover these cases, we introduce the notion of site-specific
DNS records which are sampled when we create each individual Gateway.
When selecting a domain to query for, all global DNS records and the
site-specific ones are merged and a domain name and query type is chosen
at random.

At present, this testing framework does not assert that the DNS response
itself is correct, i.e. that it actually returned the site-specific
record. We don't assert this for any other DNS queries, hence this is
left for a future extension. We do assert using our regression grep's
that we hit the codepath of querying an SRV or TXT record for a DNS
resource.

Related: #8221
This commit is contained in:
Thomas Eizinger
2025-03-04 23:41:32 +11:00
committed by GitHub
parent 1fe38bb272
commit 99d8fcb8fc
12 changed files with 453 additions and 132 deletions

View File

@@ -116,6 +116,7 @@ jobs:
rg --count --no-ignore "Performed IP-NAT64" $TESTCASES_DIR
rg --count --no-ignore "Too big DNS response, truncating" $TESTCASES_DIR
rg --count --no-ignore "Destination is unreachable" $TESTCASES_DIR
rg --count --no-ignore "Forwarding query for DNS resource to corresponding site" $TESTCASES_DIR
env:
# <https://github.com/rust-lang/cargo/issues/5999>

View File

@@ -159,3 +159,4 @@ cc a7f22e7cc2c79ffd580baf4bc8296557c67afe245ccf07e895e7cd2a969a228e
cc eca099d2fdef9adba841f523ce426089fda9bf7deb3bc43a86c4f09cf4b1199d
cc 2d4a7f40ce445d9b159941ba5cf94b635db018c6229a88e22796091e4c94b059
cc 16a8e929be616a64b36204ff393a1cf376db5559d051627ef4eff1055f9604a5
cc b5dc48d89cc4f0c61ed3b7c58338f8f9f06654a5948bad62869ea4bbecf270d8

View File

@@ -3,6 +3,7 @@ mod resource;
pub(crate) use resource::{CidrResource, Resource};
#[cfg(all(feature = "proptest", test))]
pub(crate) use resource::{DnsResource, InternetResource};
use ringbuffer::{AllocRingBuffer, RingBuffer};
use crate::dns::StubResolver;
use crate::expiring_map::ExpiringMap;
@@ -179,7 +180,9 @@ impl DnsResourceNatState {
struct PendingFlow {
last_intent_sent_at: Instant,
packets: UniquePacketBuffer,
resource_packets: UniquePacketBuffer,
udp_dns_queries: AllocRingBuffer<IpPacket>,
tcp_dns_queries: AllocRingBuffer<dns_over_tcp::Query>,
}
impl PendingFlow {
@@ -189,13 +192,25 @@ impl PendingFlow {
/// Thus, we may receive a fair few packets before we can send them.
const CAPACITY_POW_2: usize = 7; // 2^7 = 128
fn new(now: Instant, packet: IpPacket) -> Self {
let mut packets = UniquePacketBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2);
packets.push(packet);
Self {
fn new(now: Instant, trigger: ConnectionTrigger) -> Self {
let mut this = Self {
last_intent_sent_at: now,
packets,
resource_packets: UniquePacketBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2),
udp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2),
tcp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2),
};
this.push(trigger);
this
}
fn push(&mut self, trigger: ConnectionTrigger) {
match trigger {
ConnectionTrigger::PacketForResource(packet) => {
self.resource_packets.push(packet);
}
ConnectionTrigger::UdpDnsQueryForSite(packet) => self.udp_dns_queries.push(packet),
ConnectionTrigger::TcpDnsQueryForSite(query) => self.tcp_dns_queries.push(query),
}
}
}
@@ -524,7 +539,7 @@ impl ClientState {
.inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}"))
.ok()?;
let packet = maybe_mangle_dns_response_from_cidr_resource(
let packet = maybe_mangle_dns_response_from_upstream_dns_server(
packet,
&mut self.udp_dns_sockets_by_upstream_and_query_id,
);
@@ -749,7 +764,10 @@ impl ClientState {
self.peers.add_ip(&gateway_id, &gateway_tun.v4.into());
self.peers.add_ip(&gateway_id, &gateway_tun.v6.into());
let buffered_packets = pending_flow.packets;
// Deal with buffered packets
// 1. Buffered packets for resources
let buffered_resource_packets = pending_flow.resource_packets;
match resource {
Resource::Cidr(_) | Resource::Internet(_) => {
@@ -760,7 +778,7 @@ impl ClientState {
);
// For CIDR and Internet resources, we can directly queue the buffered packets.
for packet in buffered_packets {
for packet in buffered_resource_packets {
encapsulate_and_buffer(
packet,
gateway_id,
@@ -770,7 +788,40 @@ impl ClientState {
);
}
}
Resource::Dns(_) => self.update_dns_resource_nat(now, buffered_packets.into_iter()),
Resource::Dns(_) => {
self.update_dns_resource_nat(now, buffered_resource_packets.into_iter())
}
}
// 2. Buffered UDP DNS queries for the Gateway
for packet in pending_flow.udp_dns_queries {
let gateway = self.peers.get(&gateway_id).context("Unknown peer")?; // If this error happens we have a bug: We just inserted it above.
let upstream = gateway.tun_dns_server_endpoint(packet.destination());
let packet =
self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet);
encapsulate_and_buffer(
packet,
gateway_id,
now,
&mut self.node,
&mut self.buffered_transmits,
)
}
// 3. Buffered TCP DNS queries for the Gateway
for query in pending_flow.tcp_dns_queries {
let server = match query.local {
SocketAddr::V4(_) => {
SocketAddr::new(gateway_tun.v4.into(), crate::gateway::TUN_DNS_PORT)
}
SocketAddr::V6(_) => {
SocketAddr::new(gateway_tun.v6.into(), crate::gateway::TUN_DNS_PORT)
}
};
self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query);
}
Ok(Ok(()))
@@ -820,16 +871,23 @@ impl ClientState {
}
#[tracing::instrument(level = "debug", skip_all, fields(%resource))]
fn on_not_connected_resource(&mut self, resource: ResourceId, packet: IpPacket, now: Instant) {
fn on_not_connected_resource(
&mut self,
resource: ResourceId,
trigger: impl Into<ConnectionTrigger>,
now: Instant,
) {
let trigger = trigger.into();
debug_assert!(self.resources_by_id.contains_key(&resource));
match self.pending_flows.entry(resource) {
Entry::Vacant(v) => {
v.insert(PendingFlow::new(now, packet));
v.insert(PendingFlow::new(now, trigger));
}
Entry::Occupied(mut o) => {
let pending_flow = o.get_mut();
pending_flow.packets.push(packet);
pending_flow.push(trigger);
let time_since_last_intent = now.duration_since(pending_flow.last_intent_sent_at);
@@ -1102,7 +1160,7 @@ impl ClientState {
fn handle_udp_dns_query(
&mut self,
upstream: SocketAddr,
mut packet: IpPacket,
packet: IpPacket,
now: Instant,
) -> ControlFlow<(), IpPacket> {
let Some(datagram) = packet.as_udp() else {
@@ -1131,29 +1189,13 @@ impl ClientState {
"Failed to queue UDP DNS response: {}"
);
}
dns::ResolveStrategy::Recurse => {
let query_id = message.header().id();
dns::ResolveStrategy::RecurseLocal => {
if self.should_forward_dns_query_to_gateway(upstream.ip()) {
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query via tunnel");
self.udp_dns_sockets_by_upstream_and_query_id.insert(
(upstream, message.header().id()),
SocketAddr::new(packet.destination(), dns::DNS_PORT),
now + IDS_EXPIRE,
);
packet.set_dst(upstream.ip());
// TODO: Remove this once we disallow non-standard DNS ports: https://github.com/firezone/firezone/issues/8330
packet
.as_udp_mut()
.expect("we parsed it as a UDP packet earlier")
.set_destination_port(upstream.port());
packet.update_checksum();
let packet = self
.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet);
return ControlFlow::Continue(packet);
}
let query_id = message.header().id();
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host");
@@ -1161,13 +1203,70 @@ impl ClientState {
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_udp(source, upstream, message));
}
dns::ResolveStrategy::RecurseSite(resource) => {
let Some(gateway) =
peer_by_resource_mut(&self.resources_gateways, &mut self.peers, resource)
else {
self.on_not_connected_resource(
resource,
ConnectionTrigger::UdpDnsQueryForSite(packet),
now,
);
return ControlFlow::Break(());
};
let upstream = gateway.tun_dns_server_endpoint(packet.destination());
let packet =
self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet);
return ControlFlow::Continue(packet);
}
}
ControlFlow::Break(())
}
fn mangle_udp_dns_query_to_new_upstream_through_tunnel(
&mut self,
upstream: SocketAddr,
now: Instant,
mut packet: IpPacket,
) -> IpPacket {
let dst_ip = packet.destination();
let datagram = packet
.as_udp()
.expect("to be a valid UDP packet at this point");
let dst_port = datagram.destination_port();
let query_id = parse_udp_dns_message(&datagram)
.expect("to be a valid DNS query at this point")
.header()
.id();
let connlib_dns_server = SocketAddr::new(dst_ip, dst_port);
self.udp_dns_sockets_by_upstream_and_query_id.insert(
(upstream, query_id),
connlib_dns_server,
now + IDS_EXPIRE,
);
packet.set_dst(upstream.ip());
// TODO: Remove this once we disallow non-standard DNS ports: https://github.com/firezone/firezone/issues/8330
packet
.as_udp_mut()
.expect("to be a valid UDP packet at this point")
.set_destination_port(upstream.port());
packet.update_checksum();
tracing::trace!(%upstream, %connlib_dns_server, %query_id, "Forwarding UDP DNS query via tunnel");
packet
}
fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query, now: Instant) {
let message = query.message;
let query_id = query.message.header().id();
let Some(upstream) = self.dns_mapping.get_by_left(&query.local.ip()) else {
// This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request.
@@ -1175,7 +1274,7 @@ impl ClientState {
};
let server = upstream.address();
match self.stub_resolver.handle(message.for_slice_ref()) {
match self.stub_resolver.handle(query.message.for_slice_ref()) {
dns::ResolveStrategy::LocalResponse(response) => {
self.clear_dns_resource_nat_for_domain(response.for_slice_ref());
self.update_dns_resource_nat(now, iter::empty());
@@ -1185,31 +1284,9 @@ impl ClientState {
"Failed to send TCP DNS response: {}"
);
}
dns::ResolveStrategy::Recurse => {
let query_id = message.header().id();
dns::ResolveStrategy::RecurseLocal => {
if self.should_forward_dns_query_to_gateway(server.ip()) {
match self.tcp_dns_client.send_query(server, message.clone()) {
Ok(()) => {}
Err(e) => {
tracing::warn!("Failed to send recursive TCP DNS query: {e:#}");
unwrap_or_debug!(
self.tcp_dns_server.send_message(
query.socket,
dns::servfail(message.for_slice_ref())
),
"Failed to send TCP DNS response: {}"
);
return;
}
};
let existing = self
.tcp_dns_sockets_by_upstream_and_query_id
.insert((server, query_id), query.socket);
debug_assert!(existing.is_none(), "Query IDs should be unique");
self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query);
return;
}
@@ -1217,11 +1294,64 @@ impl ClientState {
tracing::trace!(%server, %query_id, "Forwarding TCP DNS query");
self.buffered_dns_queries
.push_back(dns::RecursiveQuery::via_tcp(query.socket, server, message));
.push_back(dns::RecursiveQuery::via_tcp(
query.socket,
server,
query.message,
));
}
dns::ResolveStrategy::RecurseSite(resource) => {
let Some(gateway) =
peer_by_resource_mut(&self.resources_gateways, &mut self.peers, resource)
else {
self.on_not_connected_resource(
resource,
ConnectionTrigger::TcpDnsQueryForSite(query),
now,
);
return;
};
let server = gateway.tun_dns_server_endpoint(query.local.ip());
self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query);
}
};
}
fn forward_tcp_dns_query_to_new_upstream_via_tunnel(
&mut self,
server: SocketAddr,
query: dns_over_tcp::Query,
) {
let query_id = query.message.header().id();
match self
.tcp_dns_client
.send_query(server, query.message.clone())
{
Ok(()) => {}
Err(e) => {
tracing::warn!(
"Failed to send recursive TCP DNS query to upstream resolver: {e:#}"
);
unwrap_or_debug!(
self.tcp_dns_server
.send_message(query.socket, dns::servfail(query.message.for_slice_ref())),
"Failed to send TCP DNS response: {}"
);
return;
}
};
let existing = self
.tcp_dns_sockets_by_upstream_and_query_id
.insert((server, query_id), query.socket);
debug_assert!(existing.is_none(), "Query IDs should be unique");
}
fn maybe_update_tun_routes(&mut self) {
let Some(config) = self.tun_config.clone() else {
return;
@@ -1790,7 +1920,7 @@ fn is_definitely_not_a_resource(ip: IpAddr) -> bool {
false
}
fn maybe_mangle_dns_response_from_cidr_resource(
fn maybe_mangle_dns_response_from_upstream_dns_server(
mut packet: IpPacket,
udp_dns_sockets_by_upstream_and_query_id: &mut ExpiringMap<(SocketAddr, u16), SocketAddr>,
) -> IpPacket {
@@ -1854,6 +1984,24 @@ fn truncate_dns_response(mut message: Message<Vec<u8>>) -> Vec<u8> {
message_bytes
}
/// What triggered us to establish a connection to a Gateway.
enum ConnectionTrigger {
/// A packet received on the TUN device with a destination IP that maps to one of our resources.
PacketForResource(IpPacket),
/// A UDP DNS query that needs to be resolved within a particular site that we aren't connected to yet.
///
/// This packet isn't mangled yet to point to the Gateway's TUN device IP because at the time of buffering, that IP is unknown.
UdpDnsQueryForSite(IpPacket),
/// A TCP DNS query that needs to be resolved within a particular site that we aren't connected to yet.
TcpDnsQueryForSite(dns_over_tcp::Query),
}
impl From<IpPacket> for ConnectionTrigger {
fn from(v: IpPacket) -> Self {
Self::PacketForResource(v)
}
}
pub struct IpProvider {
ipv4: Box<dyn Iterator<Item = Ipv4Addr> + Send + Sync>,
ipv6: Box<dyn Iterator<Item = Ipv6Addr> + Send + Sync>,

View File

@@ -102,8 +102,10 @@ pub(crate) enum Transport {
pub(crate) enum ResolveStrategy {
/// The query is for a Resource, we have an IP mapped already, and we can respond instantly
LocalResponse(Message<Vec<u8>>),
/// The query is for a non-Resource, forward it to an upstream or system resolver.
Recurse,
/// The query is for a non-Resource, forward it locally to an upstream or system resolver.
RecurseLocal,
/// The query is for a DNS resource but for a type that we don't intercept (i.e. SRV, TXT, ...), forward it to the site that hosts the DNS resource and resolve it there.
RecurseSite(ResourceId),
}
impl Default for StubResolver {
@@ -274,9 +276,14 @@ impl StubResolver {
(Rtype::AAAA, Some(resource)) => {
self.get_or_assign_aaaa_records(domain.clone(), resource)
}
(Rtype::SRV | Rtype::TXT, Some(resource)) => {
tracing::debug!(%qtype, %resource, "Forwarding query for DNS resource to corresponding site");
return Ok(ResolveStrategy::RecurseSite(resource));
}
(Rtype::PTR, _) => {
let Some(fqdn) = self.resource_address_name_by_reservse_dns(&domain) else {
return Ok(ResolveStrategy::Recurse);
return Ok(ResolveStrategy::RecurseLocal);
};
vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))]
@@ -288,7 +295,7 @@ impl StubResolver {
let response = build_dns_with_answer(message, domain, Vec::default())?;
return Ok(ResolveStrategy::LocalResponse(response));
}
_ => return Ok(ResolveStrategy::Recurse),
_ => return Ok(ResolveStrategy::RecurseLocal),
};
tracing::trace!(%qtype, %domain, records = ?resource_records, "Forming DNS response");

View File

@@ -15,6 +15,8 @@ use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
pub const TUN_DNS_PORT: u16 = 53535;
const EXPIRE_RESOURCES_INTERVAL: Duration = Duration::from_secs(1);
/// A SANS-IO implementation of a gateway's functionality.

View File

@@ -1,6 +1,6 @@
use std::collections::{hash_map, BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::iter;
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr};
use std::time::Instant;
use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES};
@@ -37,6 +37,17 @@ impl GatewayOnClient {
self.allowed_ips.insert(*ip, HashSet::from([*id]));
}
}
/// For a given destination IP, return the endpoint to which the DNS query should be sent.
pub(crate) fn tun_dns_server_endpoint(&self, dst: IpAddr) -> SocketAddr {
let new_dst_ip = match dst {
IpAddr::V4(_) => self.gateway_tun.v4.into(),
IpAddr::V6(_) => self.gateway_tun.v6.into(),
};
let new_dst_port = crate::gateway::TUN_DNS_PORT;
SocketAddr::new(new_dst_ip, new_dst_port)
}
}
impl GatewayOnClient {

View File

@@ -56,6 +56,10 @@ impl DnsRecords {
.dedup()
.collect_vec()
}
pub(crate) fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
impl<I> From<I> for DnsRecords

View File

@@ -701,10 +701,23 @@ impl ReferenceState {
// We surface what are the existing rtypes for a domain so that it's easier
// for the proptests to hit an existing record.
fn all_domains(&self) -> Vec<(DomainName, Vec<Rtype>)> {
self.global_dns_records
.domains_iter()
.map(|d| (d.clone(), self.global_dns_records.domain_rtypes(&d)))
.collect()
fn domains_and_rtypes(
records: &DnsRecords,
) -> impl Iterator<Item = (DomainName, Vec<Rtype>)> + use<'_> {
records
.domains_iter()
.map(|d| (d.clone(), records.domain_rtypes(&d)))
}
// We may have multiple gateways in a site, so we need to dedup.
let unique_domains = self
.gateways
.values()
.flat_map(|g| domains_and_rtypes(g.inner().dns_records()))
.chain(domains_and_rtypes(&self.global_dns_records))
.collect::<BTreeSet<_>>();
Vec::from_iter(unique_domains)
}
fn reachable_dns_servers(&self) -> Vec<SocketAddr> {

View File

@@ -358,6 +358,9 @@ impl SimClient {
AllRecordData::Txt(_) => {
continue;
}
AllRecordData::Srv(_) => {
continue;
}
unhandled => {
panic!("Unexpected record data: {unhandled:?}")
}
@@ -787,6 +790,11 @@ impl RefClient {
}
}
if let Some(resource) = self.is_site_specific_dns_query(query) {
self.set_resource_online(resource);
return;
}
if let Some(resource) = self.dns_query_via_resource(query) {
self.connect_to_internet_or_cidr_resource(resource);
self.set_resource_online(resource);
@@ -1018,6 +1026,14 @@ impl RefClient {
maybe_active_cidr_resource.or(maybe_active_internet_resource)
}
pub(crate) fn is_site_specific_dns_query(&self, query: &DnsQuery) -> Option<ResourceId> {
if !matches!(query.r_type, Rtype::SRV | Rtype::TXT) {
return None;
}
self.dns_resource_by_domain(&query.domain)
}
pub(crate) fn all_resource_ids(&self) -> Vec<ResourceId> {
self.resources.iter().map(|r| r.id()).collect()
}

View File

@@ -16,6 +16,7 @@ use proptest::prelude::*;
use snownet::Transmit;
use std::{
collections::{BTreeMap, HashMap},
iter,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::Instant,
};
@@ -34,15 +35,21 @@ pub(crate) struct SimGateway {
/// The received TCP packets, indexed by our custom TCP payload.
pub(crate) received_tcp_requests: BTreeMap<u64, IpPacket>,
site_specific_dns_records: DnsRecords,
udp_dns_server_resources: HashMap<SocketAddr, UdpDnsServerResource>,
tcp_dns_server_resources: HashMap<SocketAddr, TcpDnsServerResource>,
}
impl SimGateway {
pub(crate) fn new(id: GatewayId, sut: GatewayState) -> Self {
pub(crate) fn new(
id: GatewayId,
sut: GatewayState,
site_specific_dns_records: DnsRecords,
) -> Self {
Self {
id,
sut,
site_specific_dns_records,
received_icmp_requests: Default::default(),
udp_dns_server_resources: Default::default(),
tcp_dns_server_resources: Default::default(),
@@ -77,16 +84,35 @@ impl SimGateway {
global_dns_records: &DnsRecords,
now: Instant,
) -> Vec<Transmit<'static>> {
let udp_server_packets = self.udp_dns_server_resources.values_mut().flat_map(|s| {
s.handle_timeout(global_dns_records, now);
let Some(ip_config) = self.sut.tunnel_ip_config() else {
tracing::error!("Tunnel IP configuration not set");
return Vec::new();
};
std::iter::from_fn(|| s.poll_outbound())
});
let tcp_server_packets = self.tcp_dns_server_resources.values_mut().flat_map(|s| {
s.handle_timeout(global_dns_records, now);
let udp_server_packets =
self.udp_dns_server_resources
.iter_mut()
.flat_map(|(socket, server)| {
if ip_config.is_ip(socket.ip()) {
server.handle_timeout(&self.site_specific_dns_records, now);
} else {
server.handle_timeout(global_dns_records, now);
}
std::iter::from_fn(|| s.poll_outbound())
});
std::iter::from_fn(|| server.poll_outbound())
});
let tcp_server_packets =
self.tcp_dns_server_resources
.iter_mut()
.flat_map(|(socket, server)| {
if ip_config.is_ip(socket.ip()) {
server.handle_timeout(&self.site_specific_dns_records, now);
} else {
server.handle_timeout(global_dns_records, now);
}
std::iter::from_fn(|| server.poll_outbound())
});
udp_server_packets
.chain(tcp_server_packets)
@@ -109,7 +135,22 @@ impl SimGateway {
) {
self.udp_dns_server_resources.clear();
for server in dns_servers {
let tun_dns_server_port = 53535; // Hardcoded here so we think about backwards-compatibility when changing it.
let Some(ip_config) = self.sut.tunnel_ip_config() else {
tracing::error!("Tunnel IP configuration not set");
return;
};
for server in dns_servers
.chain(iter::once(SocketAddr::from((
ip_config.v4,
tun_dns_server_port,
))))
.chain(iter::once(SocketAddr::from((
ip_config.v6,
tun_dns_server_port,
))))
{
self.udp_dns_server_resources
.insert(server, UdpDnsServerResource::default());
self.tcp_dns_server_resources
@@ -255,6 +296,8 @@ pub struct RefGateway {
pub(crate) key: PrivateKey,
pub(crate) tunnel_ip4: Ipv4Addr,
pub(crate) tunnel_ip6: Ipv6Addr,
site_specific_dns_records: DnsRecords,
}
impl RefGateway {
@@ -262,24 +305,29 @@ impl RefGateway {
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self, id: GatewayId, now: Instant) -> SimGateway {
let mut sut = GatewayState::new(self.key.0, now);
let mut sut = GatewayState::new(self.key.0, now); // Cheating a bit here by reusing the key as seed.
sut.update_tun_device(IpConfig {
v4: self.tunnel_ip4,
v6: self.tunnel_ip6,
});
SimGateway::new(id, sut) // Cheating a bit here by reusing the key as seed.
SimGateway::new(id, sut, self.site_specific_dns_records)
}
pub fn dns_records(&self) -> &DnsRecords {
&self.site_specific_dns_records
}
}
pub(crate) fn ref_gateway_host(
tunnel_ip4s: impl Strategy<Value = Ipv4Addr>,
tunnel_ip6s: impl Strategy<Value = Ipv6Addr>,
site_specific_dns_records: impl Strategy<Value = DnsRecords>,
) -> impl Strategy<Value = Host<RefGateway>> {
host(
dual_ip_stack(),
any_port(),
ref_gateway(tunnel_ip4s, tunnel_ip6s),
ref_gateway(tunnel_ip4s, tunnel_ip6s, site_specific_dns_records),
latency(200), // We assume gateways have a somewhat decent Internet connection.
)
}
@@ -287,14 +335,22 @@ pub(crate) fn ref_gateway_host(
fn ref_gateway(
tunnel_ip4s: impl Strategy<Value = Ipv4Addr>,
tunnel_ip6s: impl Strategy<Value = Ipv6Addr>,
site_specific_dns_records: impl Strategy<Value = DnsRecords>,
) -> impl Strategy<Value = RefGateway> {
(private_key(), tunnel_ip4s, tunnel_ip6s).prop_map(move |(key, tunnel_ip4, tunnel_ip6)| {
RefGateway {
key,
tunnel_ip4,
tunnel_ip6,
}
})
(
private_key(),
tunnel_ip4s,
tunnel_ip6s,
site_specific_dns_records,
)
.prop_map(
move |(key, tunnel_ip4, tunnel_ip6, site_specific_dns_records)| RefGateway {
key,
tunnel_ip4,
tunnel_ip6,
site_specific_dns_records,
},
)
}
fn icmp_error_reply(packet: &IpPacket, error: IcmpError) -> Result<IpPacket> {

View File

@@ -37,6 +37,16 @@ fn dns_record() -> impl Strategy<Value = DomainRecord> {
]
}
pub(crate) fn site_specific_dns_record() -> impl Strategy<Value = DomainRecord> {
prop_oneof![
collection::vec(txt_record(), 6..=10)
.prop_map(|sections| { sections.into_iter().flatten().collect_vec() })
.prop_map(|o| domain::rdata::Txt::from_octets(o).unwrap())
.prop_map(DomainRecord::Txt),
srv_record()
]
}
// A maximum length txt record section
fn txt_record() -> impl Strategy<Value = Vec<u8>> {
"[a-z]{255}".prop_map(|s| {
@@ -50,6 +60,18 @@ fn txt_record() -> impl Strategy<Value = Vec<u8>> {
})
}
fn srv_record() -> impl Strategy<Value = DomainRecord> {
(
any::<u16>(),
any::<u16>(),
any::<u16>(),
domain_name(2..4).prop_map(|d| d.parse().unwrap()),
)
.prop_map(|(priority, weight, port, target)| {
DomainRecord::Srv(domain::rdata::Srv::new(priority, weight, port, target))
})
}
pub(crate) fn packet_source_v4(client: Ipv4Addr) -> impl Strategy<Value = Ipv4Addr> {
prop_oneof![
10 => Just(client),

View File

@@ -3,15 +3,19 @@ use super::{
sim_client::{ref_client_host, RefClient},
sim_gateway::{ref_gateway_host, RefGateway},
sim_net::Host,
strategies::{resolved_ips, subdomain_records},
strategies::{resolved_ips, site_specific_dns_record, subdomain_records},
};
use crate::messages::{gateway, DnsServer};
use crate::{client, proptest::*};
use crate::{
client::DnsResource,
messages::{gateway, DnsServer},
};
use connlib_model::GatewayId;
use connlib_model::{ResourceId, SiteId};
use itertools::Itertools;
use proptest::{
sample::Selector,
collection,
sample::{self, Selector},
strategy::{Just, Strategy},
};
use std::{
@@ -223,15 +227,22 @@ impl StubPortal {
}
pub(crate) fn gateways(&self) -> impl Strategy<Value = BTreeMap<GatewayId, Host<RefGateway>>> {
let dns_resources = self.dns_resources.clone();
self.gateways_by_site
.values()
.flatten()
.map(|(gid, ipv4_addr, ipv6_addr)| {
(
Just(*gid),
ref_gateway_host(Just(*ipv4_addr), Just(*ipv6_addr)),
)
}) // Map each ID to a strategy that samples a gateway.
.iter()
.flat_map(|(site_id, gateways)| {
gateways.iter().map(|(gid, ipv4_addr, ipv6_addr)| {
(
Just(*gid),
ref_gateway_host(
Just(*ipv4_addr),
Just(*ipv6_addr),
site_specific_dns_records(dns_resources.clone(), *site_id),
),
)
})
})
.collect::<Vec<_>>() // A `Vec<Strategy>` implements `Strategy<Value = Vec<_>>`
.prop_map(BTreeMap::from_iter)
}
@@ -250,39 +261,68 @@ impl StubPortal {
}
pub(crate) fn dns_resource_records(&self) -> impl Strategy<Value = DnsRecords> {
self.dns_resources
.values()
.map(|resource| {
let address = resource.address.clone();
// Only generate simple wildcard domains for these tests.
// The matching logic is extensively unit-tested so we don't need to cover all cases here.
// What we do want to cover is multiple domains pointing to the same resource.
// For example, `*.example.com` and `app.example.com`.
match address.split_once('.') {
Some(("*" | "**", base)) => {
subdomain_records(base.to_owned(), domain_label()).boxed()
}
_ => resolved_ips()
.prop_map(move |resolved_ips| {
DnsRecords::from([(address.parse().unwrap(), resolved_ips)])
})
.boxed(),
}
})
.collect::<Vec<_>>()
.prop_map(|records| {
let mut map = DnsRecords::default();
for record in records {
map.merge(record)
}
map
})
dns_resource_records(self.dns_resources.clone().into_values())
}
}
/// Generates site-specific DNS records for a particular site.
fn site_specific_dns_records(
dns_resources: BTreeMap<ResourceId, client::DnsResource>,
site: SiteId,
) -> impl Strategy<Value = DnsRecords> {
let dns_resources_in_site = dns_resources
.into_values()
.filter(move |resource| resource.sites.iter().any(|s| s.id == site));
dns_resource_records(dns_resources_in_site).prop_flat_map(|records| {
if records.is_empty() {
Just(DnsRecords::default()).boxed()
} else {
collection::btree_map(
sample::select(records.domains_iter().collect::<Vec<_>>()),
collection::btree_set(site_specific_dns_record(), 1..6),
0..5,
)
.prop_map_into()
.boxed()
}
})
}
fn dns_resource_records(
dns_resources: impl Iterator<Item = DnsResource>,
) -> impl Strategy<Value = DnsRecords> {
dns_resources
.map(|resource| {
let address = resource.address;
// Only generate simple wildcard domains for these tests.
// The matching logic is extensively unit-tested so we don't need to cover all cases here.
// What we do want to cover is multiple domains pointing to the same resource.
// For example, `*.example.com` and `app.example.com`.
match address.split_once('.') {
Some(("*" | "**", base)) => {
subdomain_records(base.to_owned(), domain_label()).boxed()
}
_ => resolved_ips()
.prop_map(move |resolved_ips| {
DnsRecords::from([(address.parse().unwrap(), resolved_ips)])
})
.boxed(),
}
})
.collect::<Vec<_>>()
.prop_map(|records| {
let mut map = DnsRecords::default();
for record in records {
map.merge(record)
}
map
})
}
/// An [`Iterator`] over the possible IPv4 addresses of a tunnel interface.
///
/// We use the CG-NAT range for IPv4.