fix(gateway): always update translation table from DNS response (#10796)

For DNS resources, the Gateway maintains a per-peer NAT table from the
client-assigned proxy IPs to the real IPs of the domain. Whenever the
Client re-queries a DNS resource domain locally, we asynchronously ping
the Gateway to also re-query said domain. This allows us to detect
changes in the DNS records of DNS resources.

To avoid breaking existing connections, the mapping between proxy IPs
and real IPs is currently not updated if there are any active UDP or TCP
flows for a proxy IP.

This logic turns out to be unnecessarily restrictive as TCP flows can
linger around for up to 2h before they timeout if they are not closed
with a TCP RST. What we really need to do is always update the mapping
of proxy IP <> real IP but honor existing NAT table entries when we
route packets before creating new ones. This ensures that an existing
connection to a previously resolved IP remains intact, even if a later
DNS response for the same domain updates the mapping. At the same time,
new connections (i.e. with a different source port) will immediately use
the new destination IP.
This commit is contained in:
Thomas Eizinger
2025-11-06 22:52:28 +11:00
committed by GitHub
parent b5048ad779
commit 602844ae4a
21 changed files with 505 additions and 182 deletions

8
rust/Cargo.lock generated
View File

@@ -5738,8 +5738,7 @@ dependencies = [
[[package]]
name = "proptest"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40"
source = "git+https://github.com/firezone/proptest?branch=feat%2Fstate-machine-closure#ea2146e4c6116c855571e5a0f717eaaab8821ff0"
dependencies = [
"bit-set",
"bit-vec",
@@ -5756,9 +5755,8 @@ dependencies = [
[[package]]
name = "proptest-state-machine"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05ad8988b889475b24ab3b485d0f3de21863077c2ff169b37e3c5b805fcff624"
version = "0.6.0"
source = "git+https://github.com/firezone/proptest?branch=feat%2Fstate-machine-closure#ea2146e4c6116c855571e5a0f717eaaab8821ff0"
dependencies = [
"proptest",
]

View File

@@ -136,8 +136,8 @@ parking_lot = "0.12.4"
phoenix-channel = { path = "connlib/phoenix-channel" }
png = "0.17.16"
proc-macro2 = "1.0"
proptest = "1.7.0"
proptest-state-machine = "0.5.0"
proptest = "1.9.0"
proptest-state-machine = "0.6.0"
quinn-udp = { version = "0.5.12", features = ["fast-apple-datapath"] }
quote = "1.0"
rand = "0.8.5"
@@ -249,6 +249,8 @@ softbuffer = { git = "https://github.com/rust-windowing/softbuffer" } # Waiting
str0m = { git = "https://github.com/algesten/str0m", branch = "main" }
moka = { git = "https://github.com/moka-rs/moka", branch = "main" } # Waiting for release.
quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" } # Waiting for release.
proptest = { git = "https://github.com/firezone/proptest", branch = "feat/state-machine-closure" }
proptest-state-machine = { git = "https://github.com/firezone/proptest", branch = "feat/state-machine-closure" }
# Enforce `tracing-macros` to have released `tracing` version.
[patch.'https://github.com/tokio-rs/tracing']

View File

@@ -222,3 +222,14 @@ cc 04193ee1047f542c469aa0893bf636df9c317943022d922e231de3e821b39486
cc e8520f159df085f7dbe6dce8b121336d33708af9f804a8a14bf6b5a3eb3a9d4d
cc fd95c0fcd3af20d73849004cb642b09c5bacfa7ca25781d0268441d49fe3b6cf
cc 4f65d01188bb870aa8f4893530f77ace3434c133cf24735619196df6d043f4cf
cc bfa39b9578b2d143e2ff3fcd8622c2af37d9a1408288c525138cc9e4c926e9e3
cc 10a24fe019b05296d841546467ea60f02df0ee9350ae1c4e7bba5ea80425ca3f
cc 2ae65b3fa5b90531b325b05e0694ce6d80f12a85d805250d3e70557733a7ccc8
cc bcec408954dca8a463d72d9ff588f1d2eacffac988010ce8830d69ca0aba8a25
cc acb651429f1c625d5d6bfb25fe62cc951262c4d5cd8e372876a439e2627b19cd
cc 7a1d6686275361933c710a0f051607211de92cac46c50436ebd2f970dbcbbf23
cc 27c05624dcd17a8118f6ae4f6d11dd324986b6cd01b998c804dc5eb27d1e2f06
cc 4bf2050d07594df9fbaf3462fe9dc0463739e42fdd5f614c644c86d37163cdb6
cc 0ca23286f6e2952919d254d7c524e65ece78672d558672f124ea67b37e82f7d5
cc 3a579d80f7bff0b34ad67a2b9156c50eeb5c3f1861793a52539b80b75ca415b1
cc 096d9ba59d5770265ec3dfb185560f28b0dcd2839584eef0918eebe12b63c01c

View File

@@ -738,4 +738,8 @@ impl ResolveDnsRequest {
pub fn domain(&self) -> &DomainName {
&self.domain
}
pub fn client(&self) -> ClientId {
self.client
}
}

View File

@@ -129,27 +129,18 @@ impl ClientOnGateway {
.cycle(),
);
let mut ip_maps = ipv4_maps.chain(ipv6_maps);
loop {
let Some(either_or_both) = ip_maps.next() else {
break;
};
// Clear out all current translations for these proxy IPs.
for proxy_ip in proxy_ips.iter() {
self.permanent_translations.remove(proxy_ip);
}
for either_or_both in ipv4_maps.chain(ipv6_maps) {
let (proxy_ip, maybe_real_ip) = match either_or_both {
EitherOrBoth::Both(proxy_ip, real_ip) => (proxy_ip, Some(real_ip)),
EitherOrBoth::Left(proxy_ip) => (proxy_ip, None),
EitherOrBoth::Right(_) => break,
};
if let Some(state) = self.permanent_translations.get(proxy_ip)
&& self.nat_table.has_entry_for_inside(*proxy_ip)
&& state.resolved_ip != maybe_real_ip
{
tracing::debug!(%name, %proxy_ip, new_real_ip = ?maybe_real_ip, current_real_ip = ?state.resolved_ip, "Skipping DNS resource NAT entry because we have open NAT sessions for it");
continue;
}
tracing::debug!(%name, %proxy_ip, real_ip = ?maybe_real_ip);
self.permanent_translations.insert(
@@ -1099,43 +1090,81 @@ mod tests {
)
.unwrap();
let request = ip_packet::make::udp_packet(
client_tun_ipv4(),
proxy_ip1(),
1,
foo_allowed_port(),
vec![0, 0, 0, 0, 0, 0, 0, 0],
)
.unwrap();
{
let request = ip_packet::make::udp_packet(
client_tun_ipv4(),
proxy_ip1(),
1,
foo_allowed_port(),
vec![0, 0, 0, 0, 0, 0, 0, 0],
)
.unwrap();
let result = peer.translate_outbound(request.clone(), now).unwrap();
let result = peer.translate_outbound(request.clone(), now).unwrap();
assert!(matches!(result, TranslateOutboundResult::Send(_)));
assert!(matches!(result, TranslateOutboundResult::Send(_)));
peer.setup_nat(
foo_name().parse().unwrap(),
foo_resource_id(),
BTreeSet::from([foo_real_ip2().into()]), // Setting up with a new IP!
BTreeSet::from([proxy_ip1().into(), proxy_ip2().into()]),
)
.unwrap();
peer.setup_nat(
foo_name().parse().unwrap(),
foo_resource_id(),
BTreeSet::from([foo_real_ip2().into()]), // Setting up with a new IP!
BTreeSet::from([proxy_ip1().into(), proxy_ip2().into()]),
)
.unwrap();
let result = peer.translate_outbound(request, now).unwrap();
let result = peer.translate_outbound(request, now).unwrap();
assert!(matches!(result, TranslateOutboundResult::Send(_)));
assert!(matches!(result, TranslateOutboundResult::Send(_)));
let response = ip_packet::make::udp_packet(
foo_real_ip1(),
client_tun_ipv4(),
foo_allowed_port(),
1,
vec![0, 0, 0, 0, 0, 0, 0, 0],
)
.unwrap();
let response = ip_packet::make::udp_packet(
foo_real_ip1(),
client_tun_ipv4(),
foo_allowed_port(),
1,
vec![0, 0, 0, 0, 0, 0, 0, 0],
)
.unwrap();
let response = peer.translate_inbound(response, now).unwrap();
let response = peer.translate_inbound(response, now).unwrap();
assert!(response.is_some());
assert!(response.is_some());
}
{
let request = ip_packet::make::udp_packet(
client_tun_ipv4(),
proxy_ip1(),
2, // Using a new source port
foo_allowed_port(),
vec![0, 0, 0, 0, 0, 0, 0, 0],
)
.unwrap();
let result = peer.translate_outbound(request, now).unwrap();
let TranslateOutboundResult::Send(outside_packet) = result else {
panic!("Wrong result");
};
assert_eq!(
outside_packet.destination(),
foo_real_ip2(),
"Request with a new source port should use new IP"
);
let response = ip_packet::make::udp_packet(
foo_real_ip2(),
client_tun_ipv4(),
foo_allowed_port(),
2,
vec![0, 0, 0, 0, 0, 0, 0, 0],
)
.unwrap();
let response = peer.translate_inbound(response, now).unwrap();
assert!(response.is_some());
}
}
#[test]

View File

@@ -45,11 +45,6 @@ impl NatTable {
}
}
/// Returns true if the NAT table has any entries with the given "inside" IP address.
pub(crate) fn has_entry_for_inside(&self, ip: IpAddr) -> bool {
self.table.left_values().any(|(_, c)| c == &ip)
}
pub(crate) fn translate_outgoing(
&mut self,
packet: &IpPacket,
@@ -62,25 +57,21 @@ impl NatTable {
let inside = (src, dst);
if let Some(outside) = self.table.get_by_left(&inside).copied() {
if outside.1 == outside_dst {
tracing::trace!(?inside, ?outside, "Translating outgoing packet");
tracing::trace!(?inside, ?outside, "Translating outgoing packet");
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
tracing::debug!(
?inside,
?outside,
"Witnessed outgoing TCP RST, removing NAT session"
);
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
tracing::debug!(
?inside,
?outside,
"Witnessed outgoing TCP RST, removing NAT session"
);
self.table.remove_by_left(&inside);
self.expired.insert(outside);
}
self.last_seen.insert(outside, now);
return Ok(outside);
self.table.remove_by_left(&inside);
self.expired.insert(outside);
}
tracing::trace!(?inside, ?outside, "Outgoing packet for expired translation");
self.last_seen.insert(outside, now);
return Ok(outside);
}
// Find the first available public port, starting from the port of the to-be-mapped packet.

View File

@@ -1,5 +1,6 @@
use crate::tests::{flux_capacitor::FluxCapacitor, sut::TunnelTest};
use assertions::PanicOnErrorEvents;
use chrono::Utc;
use core::fmt;
use proptest::{
sample::SizeRange,
@@ -8,7 +9,10 @@ use proptest::{
};
use proptest_state_machine::Sequential;
use reference::ReferenceState;
use std::sync::atomic::{self, AtomicU32};
use std::{
sync::atomic::{self, AtomicU32},
time::Instant,
};
use tracing_subscriber::{
EnvFilter, Layer, layer::SubscriberExt as _, util::SubscriberInitExt as _,
};
@@ -46,20 +50,27 @@ fn tunnel_test() {
let _ = std::fs::remove_dir_all("testcases");
let _ = std::fs::create_dir_all("testcases");
let now = Instant::now();
let utc_now = Utc::now();
let flux_capacitor = FluxCapacitor::new(now, utc_now);
let test_runner = &mut TestRunner::new(config);
let strategy = Sequential::new(
SizeRange::new(5..=15),
ReferenceState::initial_state,
move || ReferenceState::initial_state(now),
ReferenceState::is_valid_transition,
ReferenceState::transitions,
ReferenceState::apply,
move |state| ReferenceState::transitions(state, now),
{
let flux_capacitor = flux_capacitor.clone();
move |state, transition| ReferenceState::apply(state, transition, flux_capacitor.now())
},
);
let result = test_runner.run(
&strategy,
|(mut ref_state, transitions, mut seen_counter)| {
let test_index = test_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let flux_capacitor = FluxCapacitor::default();
let _guard = init_logging(flux_capacitor.clone(), test_index);
@@ -78,7 +89,7 @@ fn tunnel_test() {
println!("Running test case {test_index:04} with {num_transitions:02} transitions");
let mut sut = TunnelTest::init_test(&ref_state, flux_capacitor);
let mut sut = TunnelTest::init_test(&ref_state, flux_capacitor.clone());
// Check the invariants on the initial state
TunnelTest::check_invariants(&sut, &ref_state);
@@ -99,7 +110,7 @@ fn tunnel_test() {
);
// Apply the transition on the states
ref_state = ReferenceState::apply(ref_state, transition);
ref_state = ReferenceState::apply(ref_state, transition, flux_capacitor.now());
sut = TunnelTest::apply(sut, &ref_state, transition.clone());
// Check the invariants after the transition is applied
@@ -129,9 +140,11 @@ fn tunnel_test() {
#[test]
fn reference_state_is_deterministic() {
let now = Instant::now();
for n in 0..1000 {
let state1 = sample_from_strategy(n, ReferenceState::initial_state());
let state2 = sample_from_strategy(n, ReferenceState::initial_state());
let state1 = sample_from_strategy(n, ReferenceState::initial_state(now));
let state2 = sample_from_strategy(n, ReferenceState::initial_state(now));
assert_eq!(format!("{state1:?}"), format!("{state2:?}"));
}
@@ -139,10 +152,12 @@ fn reference_state_is_deterministic() {
#[test]
fn transitions_are_deterministic() {
let now = Instant::now();
for n in 0..1000 {
let state = sample_from_strategy(n, ReferenceState::initial_state());
let transitions1 = sample_from_strategy(n, ReferenceState::transitions(&state));
let transitions2 = sample_from_strategy(n, ReferenceState::transitions(&state));
let state = sample_from_strategy(n, ReferenceState::initial_state(now));
let transitions1 = sample_from_strategy(n, ReferenceState::transitions(&state, now));
let transitions2 = sample_from_strategy(n, ReferenceState::transitions(&state, now));
assert_eq!(format!("{transitions1:?}"), format!("{transitions2:?}"));
}

View File

@@ -5,6 +5,7 @@ use super::{
transition::{Destination, ReplyTo},
};
use connlib_model::GatewayId;
use dns_types::DomainName;
use ip_packet::IpPacket;
use itertools::Itertools;
use std::{
@@ -13,6 +14,7 @@ use std::{
marker::PhantomData,
net::{IpAddr, SocketAddr},
sync::atomic::{AtomicBool, Ordering},
time::Instant,
};
use tracing::{Level, Span, Subscriber};
use tracing_subscriber::Layer;
@@ -33,10 +35,15 @@ pub(crate) fn assert_icmp_packets_properties(
.iter()
.map(|(g, s)| (*g, &s.received_icmp_requests))
.collect();
let dns_query_timestamps = sim_gateways
.iter()
.map(|(g, s)| (*g, &s.dns_query_timestamps))
.collect();
assert_packets_properties(
ref_client,
&sim_client.sent_icmp_requests,
&dns_query_timestamps,
&received_icmp_requests,
&ref_client.expected_icmp_handshakes,
&sim_client.received_icmp_replies,
@@ -62,10 +69,15 @@ pub(crate) fn assert_udp_packets_properties(
.iter()
.map(|(g, s)| (*g, &s.received_udp_requests))
.collect();
let dns_query_timestamps = sim_gateways
.iter()
.map(|(g, s)| (*g, &s.dns_query_timestamps))
.collect();
assert_packets_properties(
ref_client,
&sim_client.sent_udp_requests,
&dns_query_timestamps,
&received_udp_requests,
&ref_client.expected_udp_handshakes,
&sim_client.received_udp_replies,
@@ -133,8 +145,9 @@ pub(crate) fn assert_resource_status(ref_client: &RefClient, sim_client: &SimCli
if expected_status_map != actual_status_map {
for (resource, expected_status) in expected_status_map {
match actual_status_map.get(resource) {
// For resources with TCP connections, the expected status might be off.
// The TCP client sends its own keep-alive's so we cannot always track the internal connection state.
// For resources with TCP connections, the expected status might be wrong.
// We generally expect them to always be online because the TCP client sends its own keep-alive's.
// However, if we have sent an ICMP error back, the client may have given up and therefore it is okay for the site to be in `Unknown` then.
Some(&Online)
if expected_status == &Unknown && tcp_resources.contains(resource) => {}
Some(&Unknown)
@@ -161,7 +174,8 @@ pub(crate) fn assert_resource_status(ref_client: &RefClient, sim_client: &SimCli
fn assert_packets_properties<T, U>(
ref_client: &RefClient,
sent_requests: &HashMap<(T, U), IpPacket>,
received_requests: &BTreeMap<GatewayId, &BTreeMap<u64, IpPacket>>,
dns_query_timestamps: &BTreeMap<GatewayId, &BTreeMap<DomainName, Vec<Instant>>>,
received_requests: &BTreeMap<GatewayId, &BTreeMap<u64, (Instant, IpPacket)>>,
expected_handshakes: &BTreeMap<GatewayId, BTreeMap<u64, (Destination, T, U)>>,
received_replies: &BTreeMap<(T, U), IpPacket>,
packet_protocol: &str,
@@ -182,13 +196,14 @@ fn assert_packets_properties<T, U>(
tracing::error!(target: "assertions", ?unexpected_replies, ?expected_handshakes, ?received_replies, "❌ Unexpected {packet_protocol} replies on client");
}
let mut mapping = HashMap::new();
let mut mappings = HashMap::new();
// Assert properties of the individual handshakes per gateway.
// Due to connlib's implementation of NAT64, we cannot match the packets sent by the client to the packets arriving at the resource by port or ICMP identifier.
// Thus, we rely on a custom u64 payload attached to all packets to uniquely identify every individual packet.
for (gateway, expected_handshakes) in expected_handshakes {
let received_requests = received_requests.get(gateway).unwrap();
let dns_query_timestamps = dns_query_timestamps.get(gateway).unwrap();
let mut num_expected_handshakes = expected_handshakes.len();
@@ -205,7 +220,8 @@ fn assert_packets_properties<T, U>(
};
assert_correct_src_and_dst_ips(client_sent_request, client_received_reply);
let Some(gateway_received_request) = received_requests.get(payload) else {
let Some((packet_sent_at, gateway_received_request)) = received_requests.get(payload)
else {
if client_received_reply
.icmp_error()
.ok()
@@ -234,16 +250,38 @@ fn assert_packets_properties<T, U>(
assert_destination_is_cdir_resource(gateway_received_request, resource_dst)
}
Destination::DomainName { name, .. } => {
let Some(query_timestamps) = dns_query_timestamps.get(name) else {
tracing::error!(%name, "Should have resolved domain at least once");
continue;
};
// To correct assert whether the packet was routed to the correct IP, we need to find the timestamp of the DNS query closest to the packet timestamp.
// In other words: Packets should always use the IPs that were most recently resolved when they were sent.
let Some(dns_record_snapshot) = query_timestamps
.iter()
.filter(|query_timestamp| *query_timestamp <= packet_sent_at)
.max()
else {
tracing::error!(%name, "Should have a relevant query timestamp");
continue;
};
// Split the proxy IP mapping by DNS record snapshot.
//
// When we re-resolve DNS, the mapping is allowed to change.
let mapping = mappings.entry(dns_record_snapshot).or_default();
assert_destination_is_dns_resource(
gateway_received_request,
global_dns_records,
name,
*dns_record_snapshot,
);
assert_proxy_ip_mapping_is_stable(
client_sent_request,
gateway_received_request,
&mut mapping,
mapping,
)
}
}
@@ -415,10 +453,11 @@ fn assert_destination_is_dns_resource(
gateway_received_request: &IpPacket,
global_dns_records: &DnsRecords,
domain: &dns_types::DomainName,
at: Instant,
) {
let actual = gateway_received_request.destination();
let possible_resource_ips = global_dns_records
.domain_ips_iter(domain)
.domain_ips_iter(domain, at)
.collect::<Vec<_>>();
if !possible_resource_ips.contains(&actual) {

View File

@@ -1,7 +1,5 @@
use std::{
collections::{BTreeMap, BTreeSet},
net::IpAddr,
};
use std::collections::BTreeSet;
use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use dns_types::prelude::*;
use dns_types::{DomainName, OwnedRecordData, RecordType};
@@ -9,22 +7,26 @@ use itertools::Itertools;
#[derive(Debug, Default, Clone)]
pub(crate) struct DnsRecords {
inner: BTreeMap<DomainName, BTreeSet<OwnedRecordData>>,
inner: BTreeMap<DomainName, BTreeMap<Instant, BTreeSet<OwnedRecordData>>>,
}
impl DnsRecords {
pub(crate) fn domain_ips_iter(&self, name: &DomainName) -> impl Iterator<Item = IpAddr> + '_ {
pub(crate) fn domain_ips_iter(
&self,
name: &DomainName,
at: Instant,
) -> impl Iterator<Item = IpAddr> + '_ {
#[expect(clippy::wildcard_enum_match_arm)]
self.domain_records_iter(name).filter_map(|r| match r {
self.domain_records_iter(name, at).filter_map(|r| match r {
OwnedRecordData::A(a) => Some(a.addr().into()),
OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()),
_ => None,
})
}
pub(crate) fn ips_iter(&self) -> impl Iterator<Item = IpAddr> + '_ {
pub(crate) fn ips_iter(&self, at: Instant) -> impl Iterator<Item = IpAddr> + '_ {
#[expect(clippy::wildcard_enum_match_arm)]
self.inner.values().flatten().filter_map(|r| match r {
self.records_at(at).filter_map(|(_, r)| match r {
OwnedRecordData::A(a) => Some(a.addr().into()),
OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()),
_ => None,
@@ -34,8 +36,12 @@ impl DnsRecords {
pub(crate) fn domain_records_iter(
&self,
name: &DomainName,
at: Instant,
) -> impl Iterator<Item = OwnedRecordData> + '_ {
self.inner.get(name).cloned().into_iter().flatten()
let name = name.clone();
self.records_at(at)
.filter_map(move |(domain, records)| (domain == &name).then_some(records.clone()))
}
pub(crate) fn domains_iter(&self) -> impl Iterator<Item = DomainName> + '_ {
@@ -43,11 +49,18 @@ impl DnsRecords {
}
pub(crate) fn merge(&mut self, other: Self) {
self.inner.extend(other.inner);
for (domain, records) in other.inner {
for (timestamp, records) in records {
self.inner
.entry(domain.clone())
.or_default()
.insert(timestamp, records);
}
}
}
pub(crate) fn domain_rtypes(&self, name: &DomainName) -> Vec<RecordType> {
self.domain_records_iter(name)
pub(crate) fn domain_rtypes(&self, name: &DomainName, at: Instant) -> Vec<RecordType> {
self.domain_records_iter(name, at)
.map(|r| r.rtype())
.dedup()
.collect_vec()
@@ -56,11 +69,25 @@ impl DnsRecords {
pub(crate) fn is_empty(&self) -> bool {
self.inner.is_empty()
}
fn records_at(
&self,
at: Instant,
) -> impl Iterator<Item = (&DomainName, &OwnedRecordData)> + '_ {
self.inner.iter().flat_map(move |(domain, records)| {
records
.iter()
.filter(|(timestamp, _)| **timestamp <= at)
.max_by_key(|(timestamp, _)| **timestamp)
.into_iter()
.flat_map(move |(_, records)| records.iter().map(move |records| (domain, records)))
})
}
}
impl<I> From<I> for DnsRecords
where
BTreeMap<DomainName, BTreeSet<OwnedRecordData>>: From<I>,
BTreeMap<DomainName, BTreeMap<Instant, BTreeSet<OwnedRecordData>>>: From<I>,
{
fn from(value: I) -> Self {
Self {
@@ -71,7 +98,7 @@ where
impl<I> FromIterator<I> for DnsRecords
where
BTreeMap<DomainName, BTreeSet<OwnedRecordData>>: FromIterator<I>,
BTreeMap<DomainName, BTreeMap<Instant, BTreeSet<OwnedRecordData>>>: FromIterator<I>,
{
fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
Self {
@@ -79,3 +106,77 @@ where
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use dns_types::DomainNameRef;
use super::*;
#[test]
fn returns_most_recent_records_at_timestamp() {
let now = Instant::now();
let mut dns_records = DnsRecords::default();
dns_records.merge(DnsRecords::from([(
EXAMPLE_COM.to_vec(),
BTreeMap::from([(
now,
BTreeSet::from([a_record("127.0.0.1"), a_record("127.0.0.2")]),
)]),
)]));
dns_records.merge(DnsRecords::from([(
EXAMPLE_COM.to_vec(),
BTreeMap::from([(
now + Duration::from_secs(5),
BTreeSet::from([a_record("127.0.0.3"), a_record("127.0.0.4")]),
)]),
)]));
dns_records.merge(DnsRecords::from([(
EXAMPLE_COM.to_vec(),
BTreeMap::from([(
now + Duration::from_secs(10),
BTreeSet::from([a_record("127.0.0.5"), a_record("127.0.0.6")]),
)]),
)]));
assert_eq!(
dns_records
.domain_ips_iter(&EXAMPLE_COM.to_vec(), now)
.collect::<Vec<_>>(),
vec![ip("127.0.0.1"), ip("127.0.0.2")]
);
assert_eq!(
dns_records
.domain_ips_iter(&EXAMPLE_COM.to_vec(), now + Duration::from_secs(2))
.collect::<Vec<_>>(),
vec![ip("127.0.0.1"), ip("127.0.0.2")]
);
assert_eq!(
dns_records
.domain_ips_iter(&EXAMPLE_COM.to_vec(), now + Duration::from_secs(7))
.collect::<Vec<_>>(),
vec![ip("127.0.0.3"), ip("127.0.0.4")]
);
assert_eq!(
dns_records
.domain_ips_iter(&EXAMPLE_COM.to_vec(), now + Duration::from_secs(12))
.collect::<Vec<_>>(),
vec![ip("127.0.0.5"), ip("127.0.0.6")]
);
}
const EXAMPLE_COM: DomainNameRef =
unsafe { DomainNameRef::from_octets_unchecked(b"\x08example\x03com\x00") };
fn a_record(ip: &str) -> OwnedRecordData {
OwnedRecordData::A(ip.parse().unwrap())
}
fn ip(ip: &str) -> IpAddr {
ip.parse().unwrap()
}
}

View File

@@ -35,7 +35,7 @@ impl TcpDnsServerResource {
pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, now: Instant) {
self.server.handle_timeout(now);
while let Some(query) = self.server.poll_queries() {
let response = handle_dns_query(&query.message, global_dns_records);
let response = handle_dns_query(&query.message, global_dns_records, now);
self.server
.send_message(query.local, query.remote, response)
@@ -53,12 +53,12 @@ impl UdpDnsServerResource {
self.inbound_packets.push_back(packet);
}
pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, _: Instant) {
pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, now: Instant) {
while let Some(packet) = self.inbound_packets.pop_front() {
let udp = packet.as_udp().unwrap();
let query = dns_types::Query::parse(udp.payload()).unwrap();
let response = handle_dns_query(&query, global_dns_records);
let response = handle_dns_query(&query, global_dns_records, now);
self.outbound_packets.push_back(
ip_packet::make::udp_packet(
@@ -81,13 +81,14 @@ impl UdpDnsServerResource {
fn handle_dns_query(
query: &dns_types::Query,
global_dns_records: &DnsRecords,
at: Instant,
) -> dns_types::Response {
const TTL: u32 = 1; // We deliberately chose a short TTL so we don't have to model the DNS cache in these tests.
let domain = query.domain().to_vec();
let records = global_dns_records
.domain_records_iter(&domain)
.domain_records_iter(&domain, at)
.filter(|r| r.rtype() == query.qtype())
.map(|rdata| (domain.clone(), TTL, rdata));

View File

@@ -20,19 +20,14 @@ impl FormatTime for FluxCapacitor {
}
}
impl Default for FluxCapacitor {
fn default() -> Self {
let start = Instant::now();
let utc_start = Utc::now();
impl FluxCapacitor {
pub(crate) fn new(start: Instant, utc_start: DateTime<Utc>) -> Self {
Self {
start,
now: Arc::new(Mutex::new((start, utc_start))),
}
}
}
impl FluxCapacitor {
const SMALL_TICK: Duration = Duration::from_millis(10);
const LARGE_TICK: Duration = Duration::from_millis(100);

View File

@@ -1,6 +1,7 @@
use std::{
collections::{BTreeMap, BTreeSet},
net::IpAddr,
time::Instant,
};
use proptest::{prelude::*, sample};
@@ -21,9 +22,10 @@ impl IcmpErrorHosts {
/// Samples a subset of the provided DNS records which we will generate ICMP errors.
pub(crate) fn icmp_error_hosts(
dns_resource_records: DnsRecords,
now: Instant,
) -> impl Strategy<Value = IcmpErrorHosts> {
// First, deduplicate all IPs.
let unique_ips = dns_resource_records.ips_iter().collect::<BTreeSet<_>>();
let unique_ips = dns_resource_records.ips_iter(now).collect::<BTreeSet<_>>();
let ips = Vec::from_iter(unique_ips);
Just(ips)
@@ -31,7 +33,7 @@ pub(crate) fn icmp_error_hosts(
.prop_flat_map(|ips| {
let num_ips = ips.len();
sample::subsequence(ips, 0..num_ips) // Pick a subset of IPs.
sample::subsequence(ips, num_ips / 2) // Pick a subset of IPs.
})
.prop_flat_map(|ips| {
ips.into_iter()

View File

@@ -12,8 +12,10 @@ use dns_types::{DomainName, RecordType};
use ip_network::{Ipv4Network, Ipv6Network};
use itertools::Itertools;
use prop::sample::select;
use proptest::collection::btree_set;
use proptest::{prelude::*, sample};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::time::Instant;
use std::{
collections::{BTreeMap, BTreeSet, HashSet},
fmt, iter,
@@ -53,14 +55,14 @@ pub(crate) struct ReferenceState {
/// Care has to be taken that we don't implement things in a buggy way here.
/// After all, if your test has bugs, it won't catch any in the actual implementation.
impl ReferenceState {
pub(crate) fn initial_state() -> BoxedStrategy<Self> {
pub(crate) fn initial_state(start: Instant) -> BoxedStrategy<Self> {
stub_portal()
.prop_flat_map(|portal| {
let gateways = portal.gateways();
let dns_resource_records = portal.dns_resource_records();
.prop_flat_map(move |portal| {
let gateways = portal.gateways(start);
let dns_resource_records = portal.dns_resource_records(start);
let client = portal.client(system_dns_servers(), upstream_dns_servers());
let relays = relays(relay_id());
let global_dns_records = global_dns_records(); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources.
let global_dns_records = global_dns_records(start); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources.
let drop_direct_client_traffic = any::<bool>();
(
@@ -74,7 +76,7 @@ impl ReferenceState {
)
})
.prop_flat_map(
|(
move |(
client,
gateways,
portal,
@@ -88,7 +90,7 @@ impl ReferenceState {
Just(gateways),
Just(portal),
Just(dns_resource_records.clone()),
icmp_error_hosts(dns_resource_records),
icmp_error_hosts(dns_resource_records, start),
Just(relays),
Just(global_dns),
Just(drop_direct_client_traffic),
@@ -96,7 +98,7 @@ impl ReferenceState {
},
)
.prop_flat_map(
|(
move |(
client,
gateways,
portal,
@@ -112,7 +114,7 @@ impl ReferenceState {
Just(portal),
Just(dns_resource_records.clone()),
Just(icmp_error_hosts.clone()),
tcp_resources(dns_resource_records, icmp_error_hosts),
tcp_resources(dns_resource_records, icmp_error_hosts, start),
Just(relays),
Just(global_dns),
Just(drop_direct_client_traffic),
@@ -178,7 +180,7 @@ impl ReferenceState {
},
)
.prop_map(
|(
move |(
client,
gateways,
relays,
@@ -209,7 +211,7 @@ impl ReferenceState {
///
/// This is invoked by proptest repeatedly to explore further state transitions.
/// Here, we should only generate [`Transition`]s that make sense for the current state.
pub(crate) fn transitions(state: &Self) -> BoxedStrategy<Transition> {
pub(crate) fn transitions(state: &Self, now: Instant) -> BoxedStrategy<Transition> {
CompositeStrategy::default()
.with(
1,
@@ -344,9 +346,33 @@ impl ReferenceState {
connect_tcp(Just(tunnel_ip6), select(dns_v6_domains))
},
)
.with_if_not_empty(
10,
state.resolved_v4_domains_with_icmp_errors(now),
|dns_v4_domains| {
let tunnel_ip4 = state.client.inner().tunnel_ip4;
prop_oneof![
icmp_packet(Just(tunnel_ip4), select(dns_v4_domains.clone())),
udp_packet(Just(tunnel_ip4), select(dns_v4_domains)),
]
},
)
.with_if_not_empty(
10,
state.resolved_v6_domains_with_icmp_errors(now),
|dns_v6_domains| {
let tunnel_ip6 = state.client.inner().tunnel_ip6;
prop_oneof![
icmp_packet(Just(tunnel_ip6), select(dns_v6_domains.clone()),),
udp_packet(Just(tunnel_ip6), select(dns_v6_domains),),
]
},
)
.with_if_not_empty(
5,
(state.all_domains(), state.reachable_dns_servers()),
(state.all_domains(now), state.reachable_dns_servers()),
|(domains, dns_servers)| {
dns_queries(sample::select(domains), sample::select(dns_servers))
.prop_map(Transition::SendDnsQueries)
@@ -384,7 +410,7 @@ impl ReferenceState {
state
.client
.inner()
.resolved_ip4_for_non_resources(&state.global_dns_records),
.resolved_ip4_for_non_resources(&state.global_dns_records, now),
|resolved_non_resource_ip4s| {
let tunnel_ip4 = state.client.inner().tunnel_ip4;
@@ -399,7 +425,7 @@ impl ReferenceState {
state
.client
.inner()
.resolved_ip6_for_non_resources(&state.global_dns_records),
.resolved_ip6_for_non_resources(&state.global_dns_records, now),
|resolved_non_resource_ip6s| {
let tunnel_ip6 = state.client.inner().tunnel_ip6;
@@ -425,13 +451,17 @@ impl ReferenceState {
udp_packet(Just(tunnel_ip6), select_host_v6(&gateway_ips)),
]
})
.with_if_not_empty(5, state.dns_resource_domains(), |domains| {
(sample::select(domains), btree_set(dns_record(), 1..6))
.prop_map(|(domain, records)| Transition::UpdateDnsRecords { domain, records })
})
.boxed()
}
/// Apply the transition to our reference state.
///
/// Here is where we implement the "expected" logic.
pub(crate) fn apply(mut state: Self, transition: &Transition) -> Self {
pub(crate) fn apply(mut state: Self, transition: &Transition, now: Instant) -> Self {
match transition {
Transition::AddResource(resource) => {
state.client.exec_mut(|client| match resource {
@@ -591,6 +621,12 @@ impl ReferenceState {
Transition::RestartClient(key) => state.client.exec_mut(|c| {
c.restart(*key);
}),
Transition::UpdateDnsRecords { domain, records } => {
state.global_dns_records.merge(DnsRecords::from([(
domain.clone(),
BTreeMap::from([(now, records.clone())]),
)]));
}
};
state
@@ -785,6 +821,7 @@ impl ReferenceState {
// Also don't deactivate resources where we have TCP connections as those would get interrupted.
has_resource && has_gateway_for_resource && !has_tcp_connection
}
Transition::UpdateDnsRecords { .. } => true,
}
}
@@ -833,21 +870,36 @@ impl ReferenceState {
impl ReferenceState {
// We surface what are the existing rtypes for a domain so that it's easier
// for the proptests to hit an existing record.
fn all_domains(&self) -> Vec<(DomainName, Vec<RecordType>)> {
fn all_domains(&self, now: Instant) -> Vec<(DomainName, Vec<RecordType>)> {
fn domains_and_rtypes(
records: &DnsRecords,
at: Instant,
) -> impl Iterator<Item = (DomainName, Vec<RecordType>)> {
records
.domains_iter()
.map(|d| (d.clone(), records.domain_rtypes(&d)))
.map(move |d| (d.clone(), records.domain_rtypes(&d, at)))
}
// We may have multiple gateways in a site, so we need to dedup.
let unique_domains = self
.gateways
.values()
.flat_map(|g| domains_and_rtypes(g.inner().dns_records()))
.chain(domains_and_rtypes(&self.global_dns_records))
.flat_map(|g| domains_and_rtypes(g.inner().dns_records(), now))
.chain(domains_and_rtypes(&self.global_dns_records, now))
.filter(|(_, rtypes)| !rtypes.is_empty())
.collect::<BTreeSet<_>>();
Vec::from_iter(unique_domains)
}
fn dns_resource_domains(&self) -> Vec<DomainName> {
// We may have multiple gateways in a site, so we need to dedup.
let unique_domains = self
.gateways
.values()
.flat_map(|g| g.inner().dns_records().domains_iter())
.chain(self.global_dns_records.domains_iter())
.filter(|d| self.client.inner().dns_resource_by_domain(d).is_some())
.collect::<BTreeSet<_>>();
Vec::from_iter(unique_domains)
@@ -965,6 +1017,32 @@ impl ReferenceState {
.collect()
}
fn resolved_v4_domains_with_icmp_errors(&self, at: Instant) -> Vec<DomainName> {
self.client
.inner()
.resolved_v4_domains()
.into_iter()
.filter(|d| {
self.global_dns_records
.domain_ips_iter(d, at)
.any(|ip| self.icmp_error_hosts.icmp_error_for_ip(ip).is_some())
})
.collect()
}
fn resolved_v6_domains_with_icmp_errors(&self, at: Instant) -> Vec<DomainName> {
self.client
.inner()
.resolved_v6_domains()
.into_iter()
.filter(|d| {
self.global_dns_records
.domain_ips_iter(d, at)
.any(|ip| self.icmp_error_hosts.icmp_error_for_ip(ip).is_some())
})
.collect()
}
fn deploy_new_relays(&mut self, new_relays: &BTreeMap<RelayId, Host<u64>>) {
// Always take down all relays because we can't know which one was sampled for the connection.
for relay in self.relays.values() {

View File

@@ -627,6 +627,16 @@ impl RefClient {
for status in self.site_status.values_mut() {
*status = ResourceStatus::Unknown;
}
// TCP connections will automatically re-create connections to Gateways.
for r in self
.expected_tcp_connections
.values()
.copied()
.collect::<Vec<_>>()
{
self.set_resource_online(r);
}
}
pub(crate) fn add_internet_resource(&mut self, resource: InternetResource) {
@@ -643,36 +653,36 @@ impl RefClient {
pub(crate) fn add_cidr_resource(&mut self, r: CidrResource) {
let address = r.address;
let r = Resource::Cidr(r);
let rid = r.id();
if let Some(existing) = self
.resources
.iter()
.find(|existing| existing.id() == r.id())
if let Some(existing) = self.resources.iter().find(|existing| existing.id() == rid)
&& (existing.has_different_address(&r) || existing.has_different_site(&r))
{
self.remove_resource(&existing.id());
}
self.resources.push(r.clone());
self.resources.push(r);
self.cidr_resources = self.recalculate_cidr_routes();
match address {
IpNetwork::V4(v4) => {
self.ipv4_routes.insert(r.id(), v4);
self.ipv4_routes.insert(rid, v4);
}
IpNetwork::V6(v6) => {
self.ipv6_routes.insert(r.id(), v6);
self.ipv6_routes.insert(rid, v6);
}
}
if self.expected_tcp_connections.values().contains(&rid) {
self.set_resource_online(rid);
}
}
pub(crate) fn add_dns_resource(&mut self, r: DnsResource) {
let r = Resource::Dns(r);
let rid = r.id();
if let Some(existing) = self
.resources
.iter()
.find(|existing| existing.id() == r.id())
if let Some(existing) = self.resources.iter().find(|existing| existing.id() == rid)
&& (existing.has_different_address(&r)
|| existing.has_different_ip_stack(&r)
|| existing.has_different_site(&r))
@@ -681,6 +691,10 @@ impl RefClient {
}
self.resources.push(r);
if self.expected_tcp_connections.values().contains(&rid) {
self.set_resource_online(rid);
}
}
/// Re-adds all resources in the order they have been initially added.
@@ -700,10 +714,10 @@ impl RefClient {
&self,
has_failed_tcp_connection: impl Fn((SPort, DPort)) -> bool,
) -> (BTreeMap<ResourceId, ResourceStatus>, BTreeSet<ResourceId>) {
let maybe_online_sites = self
let resources_with_failed_tcp_connections = self
.expected_tcp_connections
.iter()
.filter(|((_, _, sport, dport), _)| !has_failed_tcp_connection((*sport, *dport)))
.filter(|((_, _, sport, dport), _)| has_failed_tcp_connection((*sport, *dport)))
.filter_map(|(_, resource)| self.site_for_resource(*resource))
.flat_map(|site| {
self.resources
@@ -726,7 +740,7 @@ impl RefClient {
})
.collect();
(resource_status, maybe_online_sites)
(resource_status, resources_with_failed_tcp_connections)
}
pub(crate) fn tunnel_ip_for(&self, dst: IpAddr) -> IpAddr {
@@ -1095,8 +1109,9 @@ impl RefClient {
pub(crate) fn resolved_ip4_for_non_resources(
&self,
global_dns_records: &DnsRecords,
at: Instant,
) -> Vec<Ipv4Addr> {
self.resolved_ips_for_non_resources(global_dns_records)
self.resolved_ips_for_non_resources(global_dns_records, at)
.filter_map(|ip| match ip {
IpAddr::V4(v4) => Some(v4),
IpAddr::V6(_) => None,
@@ -1107,8 +1122,9 @@ impl RefClient {
pub(crate) fn resolved_ip6_for_non_resources(
&self,
global_dns_records: &DnsRecords,
at: Instant,
) -> Vec<Ipv6Addr> {
self.resolved_ips_for_non_resources(global_dns_records)
self.resolved_ips_for_non_resources(global_dns_records, at)
.filter_map(|ip| match ip {
IpAddr::V6(v6) => Some(v6),
IpAddr::V4(_) => None,
@@ -1119,13 +1135,14 @@ impl RefClient {
fn resolved_ips_for_non_resources<'a>(
&'a self,
global_dns_records: &'a DnsRecords,
at: Instant,
) -> impl Iterator<Item = IpAddr> + 'a {
self.dns_records
.iter()
.filter_map(|(domain, _)| {
.filter_map(move |(domain, _)| {
self.dns_resource_by_domain(domain)
.is_none()
.then_some(global_dns_records.domain_ips_iter(domain))
.then_some(global_dns_records.domain_ips_iter(domain, at))
})
.flatten()
}

View File

@@ -11,6 +11,7 @@ use crate::{GatewayState, IpConfig};
use anyhow::{Result, bail};
use chrono::{DateTime, Utc};
use connlib_model::{GatewayId, RelayId};
use dns_types::DomainName;
use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket};
use proptest::prelude::*;
use snownet::Transmit;
@@ -27,10 +28,13 @@ pub(crate) struct SimGateway {
pub(crate) sut: GatewayState,
/// The received ICMP packets, indexed by our custom ICMP payload.
pub(crate) received_icmp_requests: BTreeMap<u64, IpPacket>,
pub(crate) received_icmp_requests: BTreeMap<u64, (Instant, IpPacket)>,
/// The received UDP packets, indexed by our custom UDP payload.
pub(crate) received_udp_requests: BTreeMap<u64, IpPacket>,
pub(crate) received_udp_requests: BTreeMap<u64, (Instant, IpPacket)>,
/// The times we resolved DNS records for a domain.
pub(crate) dns_query_timestamps: BTreeMap<DomainName, Vec<Instant>>,
site_specific_dns_records: DnsRecords,
udp_dns_server_resources: BTreeMap<SocketAddr, UdpDnsServerResource>,
@@ -55,6 +59,7 @@ impl SimGateway {
udp_dns_server_resources: Default::default(),
tcp_dns_server_resources: Default::default(),
received_udp_requests: Default::default(),
dns_query_timestamps: Default::default(),
tcp_resources: tcp_resources
.into_iter()
.map(|address| {
@@ -194,7 +199,7 @@ impl SimGateway {
let packet_id = u64::from_be_bytes(*icmp.payload().first_chunk().unwrap());
tracing::debug!(%packet_id, "Received ICMP request");
self.received_icmp_requests
.insert(packet_id, packet.clone());
.insert(packet_id, (now, packet.clone()));
return self.handle_icmp_request(&packet, echo, icmp.payload(), icmp_error, now);
}
@@ -204,7 +209,7 @@ impl SimGateway {
let packet_id = u64::from_be_bytes(*icmp.payload().first_chunk().unwrap());
tracing::debug!(%packet_id, "Received ICMP request");
self.received_icmp_requests
.insert(packet_id, packet.clone());
.insert(packet_id, (now, packet.clone()));
return self.handle_icmp_request(&packet, echo, icmp.payload(), icmp_error, now);
}
@@ -234,7 +239,7 @@ impl SimGateway {
}
if let Some(reply) = icmp_error.or_else(|| echo_reply(packet.clone())) {
self.request_received(&packet);
self.request_received(&packet, now);
let transmit = self.sut.handle_tun_input(reply, now).unwrap()?;
return Some(transmit);
@@ -257,11 +262,12 @@ impl SimGateway {
)
}
fn request_received(&mut self, packet: &IpPacket) {
fn request_received(&mut self, packet: &IpPacket, now: Instant) {
if let Some(udp) = packet.as_udp() {
let packet_id = u64::from_be_bytes(*udp.payload().first_chunk().unwrap());
tracing::debug!(%packet_id, "Received UDP request");
self.received_udp_requests.insert(packet_id, packet.clone());
self.received_udp_requests
.insert(packet_id, (now, packet.clone()));
}
}

View File

@@ -15,22 +15,24 @@ use prop::sample;
use proptest::{collection, prelude::*};
use std::iter;
use std::num::NonZeroU16;
use std::time::Instant;
use std::{
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::Duration,
};
pub(crate) fn global_dns_records() -> impl Strategy<Value = DnsRecords> {
pub(crate) fn global_dns_records(at: Instant) -> impl Strategy<Value = DnsRecords> {
collection::btree_map(
domain_name(2..4),
collection::btree_set(dns_record(), 1..6),
collection::btree_set(dns_record(), 1..6)
.prop_map(move |records| BTreeMap::from([(at, records)])),
0..5,
)
.prop_map_into()
}
fn dns_record() -> impl Strategy<Value = OwnedRecordData> {
pub(crate) fn dns_record() -> impl Strategy<Value = OwnedRecordData> {
prop_oneof![
3 => non_reserved_ip().prop_map(dns_types::records::ip),
1 => collection::vec(txt_record(), 6..=10)
@@ -141,6 +143,7 @@ pub(crate) fn stub_portal() -> impl Strategy<Value = StubPortal> {
pub(crate) fn tcp_resources(
dns_records: DnsRecords,
imcp_error_hosts: IcmpErrorHosts,
at: Instant,
) -> impl Strategy<Value = BTreeMap<DomainName, BTreeSet<SocketAddr>>> {
let all_domains = dns_records.domains_iter().collect::<Vec<_>>();
@@ -153,7 +156,7 @@ pub(crate) fn tcp_resources(
.into_iter()
.filter(|(domain, _)| {
dns_records
.domain_ips_iter(domain)
.domain_ips_iter(domain, at)
.all(|ip| imcp_error_hosts.icmp_error_for_ip(ip).is_none())
})
.map({
@@ -161,7 +164,7 @@ pub(crate) fn tcp_resources(
move |(domain, port)| {
let addresses = dns_records
.domain_ips_iter(&domain)
.domain_ips_iter(&domain, at)
.map(|address| SocketAddr::new(address, port.get()))
.collect::<BTreeSet<_>>();
@@ -311,6 +314,7 @@ pub(crate) fn resolved_ips() -> impl Strategy<Value = BTreeSet<OwnedRecordData>>
pub(crate) fn subdomain_records(
base: String,
subdomains: impl Strategy<Value = String>,
at: Instant,
) -> impl Strategy<Value = DnsRecords> {
collection::hash_map(subdomains, resolved_ips(), 1..4).prop_map(move |subdomain_ips| {
subdomain_ips
@@ -318,7 +322,7 @@ pub(crate) fn subdomain_records(
.map(|(label, ips)| {
let domain = format!("{label}.{base}");
(domain.parse().unwrap(), ips)
(domain.parse().unwrap(), BTreeMap::from([(at, ips)]))
})
.collect()
})

View File

@@ -24,6 +24,7 @@ use std::{
collections::{BTreeMap, BTreeSet},
iter,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
time::Instant,
};
/// Stub implementation of the portal.
@@ -261,6 +262,7 @@ impl StubPortal {
pub(crate) fn gateways(
&self,
at: Instant,
) -> impl Strategy<Value = BTreeMap<GatewayId, Host<RefGateway>>> + use<> {
let dns_resources = self.dns_resources.clone();
@@ -273,7 +275,7 @@ impl StubPortal {
ref_gateway_host(
Just(*ipv4_addr),
Just(*ipv6_addr),
site_specific_dns_records(dns_resources.clone(), *site_id),
site_specific_dns_records(dns_resources.clone(), *site_id, at),
),
)
})
@@ -318,8 +320,11 @@ impl StubPortal {
proptest::option::of(sample::select(possible_search_domains))
}
pub(crate) fn dns_resource_records(&self) -> impl Strategy<Value = DnsRecords> + use<> {
dns_resource_records(self.dns_resources.clone().into_values())
pub(crate) fn dns_resource_records(
&self,
at: Instant,
) -> impl Strategy<Value = DnsRecords> + use<> {
dns_resource_records(self.dns_resources.clone().into_values(), at)
}
}
@@ -327,18 +332,20 @@ impl StubPortal {
fn site_specific_dns_records(
dns_resources: BTreeMap<ResourceId, client::DnsResource>,
site: SiteId,
at: Instant,
) -> impl Strategy<Value = DnsRecords> {
let dns_resources_in_site = dns_resources
.into_values()
.filter(move |resource| resource.sites.iter().any(|s| s.id == site));
dns_resource_records(dns_resources_in_site).prop_flat_map(|records| {
dns_resource_records(dns_resources_in_site, at).prop_flat_map(move |records| {
if records.is_empty() {
Just(DnsRecords::default()).boxed()
} else {
collection::btree_map(
sample::select(records.domains_iter().collect::<Vec<_>>()),
collection::btree_set(site_specific_dns_record(), 1..6),
collection::btree_set(site_specific_dns_record(), 1..6)
.prop_map(move |records| BTreeMap::from([(at, records)])),
0..5,
)
.prop_map_into()
@@ -349,6 +356,7 @@ fn site_specific_dns_records(
fn dns_resource_records(
dns_resources: impl Iterator<Item = DnsResource>,
at: Instant,
) -> impl Strategy<Value = DnsRecords> {
dns_resources
.map(|resource| {
@@ -360,11 +368,14 @@ fn dns_resource_records(
// For example, `*.example.com` and `app.example.com`.
match address.split_once('.') {
Some(("*" | "**", base)) => {
subdomain_records(base.to_owned(), domain_label()).boxed()
subdomain_records(base.to_owned(), domain_label(), at).boxed()
}
_ => resolved_ips()
.prop_map(move |resolved_ips| {
DnsRecords::from([(address.parse().unwrap(), resolved_ips)])
DnsRecords::from([(
address.parse().unwrap(),
BTreeMap::from([(at, resolved_ips)]),
)])
})
.boxed(),
}

View File

@@ -422,6 +422,7 @@ impl TunnelTest {
c.update_relays(iter::empty(), state.relays.iter(), now);
})
}
Transition::UpdateDnsRecords { .. } => {}
};
state.advance(ref_state, &mut buffered_transmits);
@@ -530,7 +531,7 @@ impl TunnelTest {
let transport = query.transport;
let response =
self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records);
self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records, now);
self.client.exec_mut(|c| {
c.sut.handle_dns_response(
dns::RecursiveResponse {
@@ -932,6 +933,7 @@ impl TunnelTest {
&self,
query: &dns_types::Query,
global_dns_records: &DnsRecords,
now: Instant,
) -> dns_types::Response {
const TTL: u32 = 1; // We deliberately chose a short TTL so we don't have to model the DNS cache in these tests.
@@ -941,7 +943,7 @@ impl TunnelTest {
let response = dns_types::ResponseBuilder::for_query(query, ResponseCode::NOERROR)
.with_records(
global_dns_records
.domain_records_iter(&domain)
.domain_records_iter(&domain, now)
.filter(|record| qtype == record.rtype())
.map(|rdata| (domain.clone(), TTL, rdata)),
)
@@ -1057,9 +1059,15 @@ fn on_gateway_event(
}
}),
GatewayEvent::ResolveDns(r) => {
let resolved_ips = global_dns_records.domain_ips_iter(r.domain()).collect();
let resolved_ips = global_dns_records
.domain_ips_iter(r.domain(), now)
.collect();
gateway.exec_mut(|g| {
g.dns_query_timestamps
.entry(r.domain().clone())
.or_default()
.push(now);
g.sut
.handle_domain_resolved(r, Ok(resolved_ips), now)
.unwrap()

View File

@@ -77,10 +77,11 @@ impl Client {
// TODO: Upstream ICMP error handling to `smoltcp`.
if let Ok(Some((failed_packet, _))) = packet.icmp_error()
&& let Layer4Protocol::Tcp { dst, .. } = failed_packet.layer4_protocol()
&& let Some(handle) = self
.sockets_by_remote
.get(&SocketAddr::new(failed_packet.dst(), dst))
&& let socket = SocketAddr::new(failed_packet.dst(), dst)
&& let Some(handle) = self.sockets_by_remote.get(&socket)
{
tracing::debug!(%socket, "Received ICMP error");
self.sockets.get_mut::<l3_tcp::Socket>(*handle).abort();
}

View File

@@ -3,7 +3,7 @@ use crate::{
proptest::{host_v4, host_v6},
};
use connlib_model::{RelayId, ResourceId, Site};
use dns_types::{DomainName, RecordType};
use dns_types::{DomainName, OwnedRecordData, RecordType};
use ip_network::IpNetwork;
use super::{
@@ -14,7 +14,7 @@ use crate::messages::DnsServer;
use prop::collection;
use proptest::{prelude::*, sample};
use std::{
collections::BTreeMap,
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
num::NonZeroU16,
};
@@ -102,6 +102,12 @@ pub(crate) enum Transition {
/// De-authorize access to a resource whilst the Gateway is network-partitioned from the portal.
DeauthorizeWhileGatewayIsPartitioned(ResourceId),
/// De-authorize access to a resource whilst the Gateway is network-partitioned from the portal.
UpdateDnsRecords {
domain: DomainName,
records: BTreeSet<OwnedRecordData>,
},
}
#[derive(Debug, Clone)]

View File

@@ -26,6 +26,10 @@ export default function Gateway() {
<ChangeItem pull="10620">
Adds a `--log-format` CLI option to output logs as JSON.
</ChangeItem>
<ChangeItem pull="10796">
Fixes an issue where packets for DNS resources would be routed to
stale IPs after DNS record changes.
</ChangeItem>
</Unreleased>
<Entry version="1.4.17" date={new Date("2025-10-16")}>
<ChangeItem pull="10367">