diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 76d089e29..d7ac72a88 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -5738,8 +5738,7 @@ dependencies = [ [[package]] name = "proptest" version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40" +source = "git+https://github.com/firezone/proptest?branch=feat%2Fstate-machine-closure#ea2146e4c6116c855571e5a0f717eaaab8821ff0" dependencies = [ "bit-set", "bit-vec", @@ -5756,9 +5755,8 @@ dependencies = [ [[package]] name = "proptest-state-machine" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ad8988b889475b24ab3b485d0f3de21863077c2ff169b37e3c5b805fcff624" +version = "0.6.0" +source = "git+https://github.com/firezone/proptest?branch=feat%2Fstate-machine-closure#ea2146e4c6116c855571e5a0f717eaaab8821ff0" dependencies = [ "proptest", ] diff --git a/rust/Cargo.toml b/rust/Cargo.toml index be12244c3..3692871fa 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -136,8 +136,8 @@ parking_lot = "0.12.4" phoenix-channel = { path = "connlib/phoenix-channel" } png = "0.17.16" proc-macro2 = "1.0" -proptest = "1.7.0" -proptest-state-machine = "0.5.0" +proptest = "1.9.0" +proptest-state-machine = "0.6.0" quinn-udp = { version = "0.5.12", features = ["fast-apple-datapath"] } quote = "1.0" rand = "0.8.5" @@ -249,6 +249,8 @@ softbuffer = { git = "https://github.com/rust-windowing/softbuffer" } # Waiting str0m = { git = "https://github.com/algesten/str0m", branch = "main" } moka = { git = "https://github.com/moka-rs/moka", branch = "main" } # Waiting for release. quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" } # Waiting for release. +proptest = { git = "https://github.com/firezone/proptest", branch = "feat/state-machine-closure" } +proptest-state-machine = { git = "https://github.com/firezone/proptest", branch = "feat/state-machine-closure" } # Enforce `tracing-macros` to have released `tracing` version. [patch.'https://github.com/tokio-rs/tracing'] diff --git a/rust/connlib/tunnel/proptest-regressions/tests.txt b/rust/connlib/tunnel/proptest-regressions/tests.txt index 4e3a09edf..8b18e671c 100644 --- a/rust/connlib/tunnel/proptest-regressions/tests.txt +++ b/rust/connlib/tunnel/proptest-regressions/tests.txt @@ -222,3 +222,14 @@ cc 04193ee1047f542c469aa0893bf636df9c317943022d922e231de3e821b39486 cc e8520f159df085f7dbe6dce8b121336d33708af9f804a8a14bf6b5a3eb3a9d4d cc fd95c0fcd3af20d73849004cb642b09c5bacfa7ca25781d0268441d49fe3b6cf cc 4f65d01188bb870aa8f4893530f77ace3434c133cf24735619196df6d043f4cf +cc bfa39b9578b2d143e2ff3fcd8622c2af37d9a1408288c525138cc9e4c926e9e3 +cc 10a24fe019b05296d841546467ea60f02df0ee9350ae1c4e7bba5ea80425ca3f +cc 2ae65b3fa5b90531b325b05e0694ce6d80f12a85d805250d3e70557733a7ccc8 +cc bcec408954dca8a463d72d9ff588f1d2eacffac988010ce8830d69ca0aba8a25 +cc acb651429f1c625d5d6bfb25fe62cc951262c4d5cd8e372876a439e2627b19cd +cc 7a1d6686275361933c710a0f051607211de92cac46c50436ebd2f970dbcbbf23 +cc 27c05624dcd17a8118f6ae4f6d11dd324986b6cd01b998c804dc5eb27d1e2f06 +cc 4bf2050d07594df9fbaf3462fe9dc0463739e42fdd5f614c644c86d37163cdb6 +cc 0ca23286f6e2952919d254d7c524e65ece78672d558672f124ea67b37e82f7d5 +cc 3a579d80f7bff0b34ad67a2b9156c50eeb5c3f1861793a52539b80b75ca415b1 +cc 096d9ba59d5770265ec3dfb185560f28b0dcd2839584eef0918eebe12b63c01c diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 8294faae1..258989ec5 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -738,4 +738,8 @@ impl ResolveDnsRequest { pub fn domain(&self) -> &DomainName { &self.domain } + + pub fn client(&self) -> ClientId { + self.client + } } diff --git a/rust/connlib/tunnel/src/gateway/client_on_gateway.rs b/rust/connlib/tunnel/src/gateway/client_on_gateway.rs index 9e5df322f..a94e61abe 100644 --- a/rust/connlib/tunnel/src/gateway/client_on_gateway.rs +++ b/rust/connlib/tunnel/src/gateway/client_on_gateway.rs @@ -129,27 +129,18 @@ impl ClientOnGateway { .cycle(), ); - let mut ip_maps = ipv4_maps.chain(ipv6_maps); - - loop { - let Some(either_or_both) = ip_maps.next() else { - break; - }; + // Clear out all current translations for these proxy IPs. + for proxy_ip in proxy_ips.iter() { + self.permanent_translations.remove(proxy_ip); + } + for either_or_both in ipv4_maps.chain(ipv6_maps) { let (proxy_ip, maybe_real_ip) = match either_or_both { EitherOrBoth::Both(proxy_ip, real_ip) => (proxy_ip, Some(real_ip)), EitherOrBoth::Left(proxy_ip) => (proxy_ip, None), EitherOrBoth::Right(_) => break, }; - if let Some(state) = self.permanent_translations.get(proxy_ip) - && self.nat_table.has_entry_for_inside(*proxy_ip) - && state.resolved_ip != maybe_real_ip - { - tracing::debug!(%name, %proxy_ip, new_real_ip = ?maybe_real_ip, current_real_ip = ?state.resolved_ip, "Skipping DNS resource NAT entry because we have open NAT sessions for it"); - continue; - } - tracing::debug!(%name, %proxy_ip, real_ip = ?maybe_real_ip); self.permanent_translations.insert( @@ -1099,43 +1090,81 @@ mod tests { ) .unwrap(); - let request = ip_packet::make::udp_packet( - client_tun_ipv4(), - proxy_ip1(), - 1, - foo_allowed_port(), - vec![0, 0, 0, 0, 0, 0, 0, 0], - ) - .unwrap(); + { + let request = ip_packet::make::udp_packet( + client_tun_ipv4(), + proxy_ip1(), + 1, + foo_allowed_port(), + vec![0, 0, 0, 0, 0, 0, 0, 0], + ) + .unwrap(); - let result = peer.translate_outbound(request.clone(), now).unwrap(); + let result = peer.translate_outbound(request.clone(), now).unwrap(); - assert!(matches!(result, TranslateOutboundResult::Send(_))); + assert!(matches!(result, TranslateOutboundResult::Send(_))); - peer.setup_nat( - foo_name().parse().unwrap(), - foo_resource_id(), - BTreeSet::from([foo_real_ip2().into()]), // Setting up with a new IP! - BTreeSet::from([proxy_ip1().into(), proxy_ip2().into()]), - ) - .unwrap(); + peer.setup_nat( + foo_name().parse().unwrap(), + foo_resource_id(), + BTreeSet::from([foo_real_ip2().into()]), // Setting up with a new IP! + BTreeSet::from([proxy_ip1().into(), proxy_ip2().into()]), + ) + .unwrap(); - let result = peer.translate_outbound(request, now).unwrap(); + let result = peer.translate_outbound(request, now).unwrap(); - assert!(matches!(result, TranslateOutboundResult::Send(_))); + assert!(matches!(result, TranslateOutboundResult::Send(_))); - let response = ip_packet::make::udp_packet( - foo_real_ip1(), - client_tun_ipv4(), - foo_allowed_port(), - 1, - vec![0, 0, 0, 0, 0, 0, 0, 0], - ) - .unwrap(); + let response = ip_packet::make::udp_packet( + foo_real_ip1(), + client_tun_ipv4(), + foo_allowed_port(), + 1, + vec![0, 0, 0, 0, 0, 0, 0, 0], + ) + .unwrap(); - let response = peer.translate_inbound(response, now).unwrap(); + let response = peer.translate_inbound(response, now).unwrap(); - assert!(response.is_some()); + assert!(response.is_some()); + } + + { + let request = ip_packet::make::udp_packet( + client_tun_ipv4(), + proxy_ip1(), + 2, // Using a new source port + foo_allowed_port(), + vec![0, 0, 0, 0, 0, 0, 0, 0], + ) + .unwrap(); + + let result = peer.translate_outbound(request, now).unwrap(); + + let TranslateOutboundResult::Send(outside_packet) = result else { + panic!("Wrong result"); + }; + + assert_eq!( + outside_packet.destination(), + foo_real_ip2(), + "Request with a new source port should use new IP" + ); + + let response = ip_packet::make::udp_packet( + foo_real_ip2(), + client_tun_ipv4(), + foo_allowed_port(), + 2, + vec![0, 0, 0, 0, 0, 0, 0, 0], + ) + .unwrap(); + + let response = peer.translate_inbound(response, now).unwrap(); + + assert!(response.is_some()); + } } #[test] diff --git a/rust/connlib/tunnel/src/gateway/nat_table.rs b/rust/connlib/tunnel/src/gateway/nat_table.rs index 46edc06ad..d7ca033d7 100644 --- a/rust/connlib/tunnel/src/gateway/nat_table.rs +++ b/rust/connlib/tunnel/src/gateway/nat_table.rs @@ -45,11 +45,6 @@ impl NatTable { } } - /// Returns true if the NAT table has any entries with the given "inside" IP address. - pub(crate) fn has_entry_for_inside(&self, ip: IpAddr) -> bool { - self.table.left_values().any(|(_, c)| c == &ip) - } - pub(crate) fn translate_outgoing( &mut self, packet: &IpPacket, @@ -62,25 +57,21 @@ impl NatTable { let inside = (src, dst); if let Some(outside) = self.table.get_by_left(&inside).copied() { - if outside.1 == outside_dst { - tracing::trace!(?inside, ?outside, "Translating outgoing packet"); + tracing::trace!(?inside, ?outside, "Translating outgoing packet"); - if packet.as_tcp().is_some_and(|tcp| tcp.rst()) { - tracing::debug!( - ?inside, - ?outside, - "Witnessed outgoing TCP RST, removing NAT session" - ); + if packet.as_tcp().is_some_and(|tcp| tcp.rst()) { + tracing::debug!( + ?inside, + ?outside, + "Witnessed outgoing TCP RST, removing NAT session" + ); - self.table.remove_by_left(&inside); - self.expired.insert(outside); - } - - self.last_seen.insert(outside, now); - return Ok(outside); + self.table.remove_by_left(&inside); + self.expired.insert(outside); } - tracing::trace!(?inside, ?outside, "Outgoing packet for expired translation"); + self.last_seen.insert(outside, now); + return Ok(outside); } // Find the first available public port, starting from the port of the to-be-mapped packet. diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 8b686309b..7145106e9 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -1,5 +1,6 @@ use crate::tests::{flux_capacitor::FluxCapacitor, sut::TunnelTest}; use assertions::PanicOnErrorEvents; +use chrono::Utc; use core::fmt; use proptest::{ sample::SizeRange, @@ -8,7 +9,10 @@ use proptest::{ }; use proptest_state_machine::Sequential; use reference::ReferenceState; -use std::sync::atomic::{self, AtomicU32}; +use std::{ + sync::atomic::{self, AtomicU32}, + time::Instant, +}; use tracing_subscriber::{ EnvFilter, Layer, layer::SubscriberExt as _, util::SubscriberInitExt as _, }; @@ -46,20 +50,27 @@ fn tunnel_test() { let _ = std::fs::remove_dir_all("testcases"); let _ = std::fs::create_dir_all("testcases"); + let now = Instant::now(); + let utc_now = Utc::now(); + let flux_capacitor = FluxCapacitor::new(now, utc_now); + let test_runner = &mut TestRunner::new(config); let strategy = Sequential::new( SizeRange::new(5..=15), - ReferenceState::initial_state, + move || ReferenceState::initial_state(now), ReferenceState::is_valid_transition, - ReferenceState::transitions, - ReferenceState::apply, + move |state| ReferenceState::transitions(state, now), + { + let flux_capacitor = flux_capacitor.clone(); + + move |state, transition| ReferenceState::apply(state, transition, flux_capacitor.now()) + }, ); let result = test_runner.run( &strategy, |(mut ref_state, transitions, mut seen_counter)| { let test_index = test_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let flux_capacitor = FluxCapacitor::default(); let _guard = init_logging(flux_capacitor.clone(), test_index); @@ -78,7 +89,7 @@ fn tunnel_test() { println!("Running test case {test_index:04} with {num_transitions:02} transitions"); - let mut sut = TunnelTest::init_test(&ref_state, flux_capacitor); + let mut sut = TunnelTest::init_test(&ref_state, flux_capacitor.clone()); // Check the invariants on the initial state TunnelTest::check_invariants(&sut, &ref_state); @@ -99,7 +110,7 @@ fn tunnel_test() { ); // Apply the transition on the states - ref_state = ReferenceState::apply(ref_state, transition); + ref_state = ReferenceState::apply(ref_state, transition, flux_capacitor.now()); sut = TunnelTest::apply(sut, &ref_state, transition.clone()); // Check the invariants after the transition is applied @@ -129,9 +140,11 @@ fn tunnel_test() { #[test] fn reference_state_is_deterministic() { + let now = Instant::now(); + for n in 0..1000 { - let state1 = sample_from_strategy(n, ReferenceState::initial_state()); - let state2 = sample_from_strategy(n, ReferenceState::initial_state()); + let state1 = sample_from_strategy(n, ReferenceState::initial_state(now)); + let state2 = sample_from_strategy(n, ReferenceState::initial_state(now)); assert_eq!(format!("{state1:?}"), format!("{state2:?}")); } @@ -139,10 +152,12 @@ fn reference_state_is_deterministic() { #[test] fn transitions_are_deterministic() { + let now = Instant::now(); + for n in 0..1000 { - let state = sample_from_strategy(n, ReferenceState::initial_state()); - let transitions1 = sample_from_strategy(n, ReferenceState::transitions(&state)); - let transitions2 = sample_from_strategy(n, ReferenceState::transitions(&state)); + let state = sample_from_strategy(n, ReferenceState::initial_state(now)); + let transitions1 = sample_from_strategy(n, ReferenceState::transitions(&state, now)); + let transitions2 = sample_from_strategy(n, ReferenceState::transitions(&state, now)); assert_eq!(format!("{transitions1:?}"), format!("{transitions2:?}")); } diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index cd10ae646..ad6d7804a 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -5,6 +5,7 @@ use super::{ transition::{Destination, ReplyTo}, }; use connlib_model::GatewayId; +use dns_types::DomainName; use ip_packet::IpPacket; use itertools::Itertools; use std::{ @@ -13,6 +14,7 @@ use std::{ marker::PhantomData, net::{IpAddr, SocketAddr}, sync::atomic::{AtomicBool, Ordering}, + time::Instant, }; use tracing::{Level, Span, Subscriber}; use tracing_subscriber::Layer; @@ -33,10 +35,15 @@ pub(crate) fn assert_icmp_packets_properties( .iter() .map(|(g, s)| (*g, &s.received_icmp_requests)) .collect(); + let dns_query_timestamps = sim_gateways + .iter() + .map(|(g, s)| (*g, &s.dns_query_timestamps)) + .collect(); assert_packets_properties( ref_client, &sim_client.sent_icmp_requests, + &dns_query_timestamps, &received_icmp_requests, &ref_client.expected_icmp_handshakes, &sim_client.received_icmp_replies, @@ -62,10 +69,15 @@ pub(crate) fn assert_udp_packets_properties( .iter() .map(|(g, s)| (*g, &s.received_udp_requests)) .collect(); + let dns_query_timestamps = sim_gateways + .iter() + .map(|(g, s)| (*g, &s.dns_query_timestamps)) + .collect(); assert_packets_properties( ref_client, &sim_client.sent_udp_requests, + &dns_query_timestamps, &received_udp_requests, &ref_client.expected_udp_handshakes, &sim_client.received_udp_replies, @@ -133,8 +145,9 @@ pub(crate) fn assert_resource_status(ref_client: &RefClient, sim_client: &SimCli if expected_status_map != actual_status_map { for (resource, expected_status) in expected_status_map { match actual_status_map.get(resource) { - // For resources with TCP connections, the expected status might be off. - // The TCP client sends its own keep-alive's so we cannot always track the internal connection state. + // For resources with TCP connections, the expected status might be wrong. + // We generally expect them to always be online because the TCP client sends its own keep-alive's. + // However, if we have sent an ICMP error back, the client may have given up and therefore it is okay for the site to be in `Unknown` then. Some(&Online) if expected_status == &Unknown && tcp_resources.contains(resource) => {} Some(&Unknown) @@ -161,7 +174,8 @@ pub(crate) fn assert_resource_status(ref_client: &RefClient, sim_client: &SimCli fn assert_packets_properties( ref_client: &RefClient, sent_requests: &HashMap<(T, U), IpPacket>, - received_requests: &BTreeMap>, + dns_query_timestamps: &BTreeMap>>, + received_requests: &BTreeMap>, expected_handshakes: &BTreeMap>, received_replies: &BTreeMap<(T, U), IpPacket>, packet_protocol: &str, @@ -182,13 +196,14 @@ fn assert_packets_properties( tracing::error!(target: "assertions", ?unexpected_replies, ?expected_handshakes, ?received_replies, "❌ Unexpected {packet_protocol} replies on client"); } - let mut mapping = HashMap::new(); + let mut mappings = HashMap::new(); // Assert properties of the individual handshakes per gateway. // Due to connlib's implementation of NAT64, we cannot match the packets sent by the client to the packets arriving at the resource by port or ICMP identifier. // Thus, we rely on a custom u64 payload attached to all packets to uniquely identify every individual packet. for (gateway, expected_handshakes) in expected_handshakes { let received_requests = received_requests.get(gateway).unwrap(); + let dns_query_timestamps = dns_query_timestamps.get(gateway).unwrap(); let mut num_expected_handshakes = expected_handshakes.len(); @@ -205,7 +220,8 @@ fn assert_packets_properties( }; assert_correct_src_and_dst_ips(client_sent_request, client_received_reply); - let Some(gateway_received_request) = received_requests.get(payload) else { + let Some((packet_sent_at, gateway_received_request)) = received_requests.get(payload) + else { if client_received_reply .icmp_error() .ok() @@ -234,16 +250,38 @@ fn assert_packets_properties( assert_destination_is_cdir_resource(gateway_received_request, resource_dst) } Destination::DomainName { name, .. } => { + let Some(query_timestamps) = dns_query_timestamps.get(name) else { + tracing::error!(%name, "Should have resolved domain at least once"); + continue; + }; + + // To correct assert whether the packet was routed to the correct IP, we need to find the timestamp of the DNS query closest to the packet timestamp. + // In other words: Packets should always use the IPs that were most recently resolved when they were sent. + let Some(dns_record_snapshot) = query_timestamps + .iter() + .filter(|query_timestamp| *query_timestamp <= packet_sent_at) + .max() + else { + tracing::error!(%name, "Should have a relevant query timestamp"); + continue; + }; + + // Split the proxy IP mapping by DNS record snapshot. + // + // When we re-resolve DNS, the mapping is allowed to change. + let mapping = mappings.entry(dns_record_snapshot).or_default(); + assert_destination_is_dns_resource( gateway_received_request, global_dns_records, name, + *dns_record_snapshot, ); assert_proxy_ip_mapping_is_stable( client_sent_request, gateway_received_request, - &mut mapping, + mapping, ) } } @@ -415,10 +453,11 @@ fn assert_destination_is_dns_resource( gateway_received_request: &IpPacket, global_dns_records: &DnsRecords, domain: &dns_types::DomainName, + at: Instant, ) { let actual = gateway_received_request.destination(); let possible_resource_ips = global_dns_records - .domain_ips_iter(domain) + .domain_ips_iter(domain, at) .collect::>(); if !possible_resource_ips.contains(&actual) { diff --git a/rust/connlib/tunnel/src/tests/dns_records.rs b/rust/connlib/tunnel/src/tests/dns_records.rs index dc41eb468..d9102aace 100644 --- a/rust/connlib/tunnel/src/tests/dns_records.rs +++ b/rust/connlib/tunnel/src/tests/dns_records.rs @@ -1,7 +1,5 @@ -use std::{ - collections::{BTreeMap, BTreeSet}, - net::IpAddr, -}; +use std::collections::BTreeSet; +use std::{collections::BTreeMap, net::IpAddr, time::Instant}; use dns_types::prelude::*; use dns_types::{DomainName, OwnedRecordData, RecordType}; @@ -9,22 +7,26 @@ use itertools::Itertools; #[derive(Debug, Default, Clone)] pub(crate) struct DnsRecords { - inner: BTreeMap>, + inner: BTreeMap>>, } impl DnsRecords { - pub(crate) fn domain_ips_iter(&self, name: &DomainName) -> impl Iterator + '_ { + pub(crate) fn domain_ips_iter( + &self, + name: &DomainName, + at: Instant, + ) -> impl Iterator + '_ { #[expect(clippy::wildcard_enum_match_arm)] - self.domain_records_iter(name).filter_map(|r| match r { + self.domain_records_iter(name, at).filter_map(|r| match r { OwnedRecordData::A(a) => Some(a.addr().into()), OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()), _ => None, }) } - pub(crate) fn ips_iter(&self) -> impl Iterator + '_ { + pub(crate) fn ips_iter(&self, at: Instant) -> impl Iterator + '_ { #[expect(clippy::wildcard_enum_match_arm)] - self.inner.values().flatten().filter_map(|r| match r { + self.records_at(at).filter_map(|(_, r)| match r { OwnedRecordData::A(a) => Some(a.addr().into()), OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()), _ => None, @@ -34,8 +36,12 @@ impl DnsRecords { pub(crate) fn domain_records_iter( &self, name: &DomainName, + at: Instant, ) -> impl Iterator + '_ { - self.inner.get(name).cloned().into_iter().flatten() + let name = name.clone(); + + self.records_at(at) + .filter_map(move |(domain, records)| (domain == &name).then_some(records.clone())) } pub(crate) fn domains_iter(&self) -> impl Iterator + '_ { @@ -43,11 +49,18 @@ impl DnsRecords { } pub(crate) fn merge(&mut self, other: Self) { - self.inner.extend(other.inner); + for (domain, records) in other.inner { + for (timestamp, records) in records { + self.inner + .entry(domain.clone()) + .or_default() + .insert(timestamp, records); + } + } } - pub(crate) fn domain_rtypes(&self, name: &DomainName) -> Vec { - self.domain_records_iter(name) + pub(crate) fn domain_rtypes(&self, name: &DomainName, at: Instant) -> Vec { + self.domain_records_iter(name, at) .map(|r| r.rtype()) .dedup() .collect_vec() @@ -56,11 +69,25 @@ impl DnsRecords { pub(crate) fn is_empty(&self) -> bool { self.inner.is_empty() } + + fn records_at( + &self, + at: Instant, + ) -> impl Iterator + '_ { + self.inner.iter().flat_map(move |(domain, records)| { + records + .iter() + .filter(|(timestamp, _)| **timestamp <= at) + .max_by_key(|(timestamp, _)| **timestamp) + .into_iter() + .flat_map(move |(_, records)| records.iter().map(move |records| (domain, records))) + }) + } } impl From for DnsRecords where - BTreeMap>: From, + BTreeMap>>: From, { fn from(value: I) -> Self { Self { @@ -71,7 +98,7 @@ where impl FromIterator for DnsRecords where - BTreeMap>: FromIterator, + BTreeMap>>: FromIterator, { fn from_iter>(iter: T) -> Self { Self { @@ -79,3 +106,77 @@ where } } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use dns_types::DomainNameRef; + + use super::*; + + #[test] + fn returns_most_recent_records_at_timestamp() { + let now = Instant::now(); + + let mut dns_records = DnsRecords::default(); + + dns_records.merge(DnsRecords::from([( + EXAMPLE_COM.to_vec(), + BTreeMap::from([( + now, + BTreeSet::from([a_record("127.0.0.1"), a_record("127.0.0.2")]), + )]), + )])); + dns_records.merge(DnsRecords::from([( + EXAMPLE_COM.to_vec(), + BTreeMap::from([( + now + Duration::from_secs(5), + BTreeSet::from([a_record("127.0.0.3"), a_record("127.0.0.4")]), + )]), + )])); + dns_records.merge(DnsRecords::from([( + EXAMPLE_COM.to_vec(), + BTreeMap::from([( + now + Duration::from_secs(10), + BTreeSet::from([a_record("127.0.0.5"), a_record("127.0.0.6")]), + )]), + )])); + + assert_eq!( + dns_records + .domain_ips_iter(&EXAMPLE_COM.to_vec(), now) + .collect::>(), + vec![ip("127.0.0.1"), ip("127.0.0.2")] + ); + assert_eq!( + dns_records + .domain_ips_iter(&EXAMPLE_COM.to_vec(), now + Duration::from_secs(2)) + .collect::>(), + vec![ip("127.0.0.1"), ip("127.0.0.2")] + ); + assert_eq!( + dns_records + .domain_ips_iter(&EXAMPLE_COM.to_vec(), now + Duration::from_secs(7)) + .collect::>(), + vec![ip("127.0.0.3"), ip("127.0.0.4")] + ); + assert_eq!( + dns_records + .domain_ips_iter(&EXAMPLE_COM.to_vec(), now + Duration::from_secs(12)) + .collect::>(), + vec![ip("127.0.0.5"), ip("127.0.0.6")] + ); + } + + const EXAMPLE_COM: DomainNameRef = + unsafe { DomainNameRef::from_octets_unchecked(b"\x08example\x03com\x00") }; + + fn a_record(ip: &str) -> OwnedRecordData { + OwnedRecordData::A(ip.parse().unwrap()) + } + + fn ip(ip: &str) -> IpAddr { + ip.parse().unwrap() + } +} diff --git a/rust/connlib/tunnel/src/tests/dns_server_resource.rs b/rust/connlib/tunnel/src/tests/dns_server_resource.rs index 21c667ee3..bf60d8bff 100644 --- a/rust/connlib/tunnel/src/tests/dns_server_resource.rs +++ b/rust/connlib/tunnel/src/tests/dns_server_resource.rs @@ -35,7 +35,7 @@ impl TcpDnsServerResource { pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, now: Instant) { self.server.handle_timeout(now); while let Some(query) = self.server.poll_queries() { - let response = handle_dns_query(&query.message, global_dns_records); + let response = handle_dns_query(&query.message, global_dns_records, now); self.server .send_message(query.local, query.remote, response) @@ -53,12 +53,12 @@ impl UdpDnsServerResource { self.inbound_packets.push_back(packet); } - pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, _: Instant) { + pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, now: Instant) { while let Some(packet) = self.inbound_packets.pop_front() { let udp = packet.as_udp().unwrap(); let query = dns_types::Query::parse(udp.payload()).unwrap(); - let response = handle_dns_query(&query, global_dns_records); + let response = handle_dns_query(&query, global_dns_records, now); self.outbound_packets.push_back( ip_packet::make::udp_packet( @@ -81,13 +81,14 @@ impl UdpDnsServerResource { fn handle_dns_query( query: &dns_types::Query, global_dns_records: &DnsRecords, + at: Instant, ) -> dns_types::Response { const TTL: u32 = 1; // We deliberately chose a short TTL so we don't have to model the DNS cache in these tests. let domain = query.domain().to_vec(); let records = global_dns_records - .domain_records_iter(&domain) + .domain_records_iter(&domain, at) .filter(|r| r.rtype() == query.qtype()) .map(|rdata| (domain.clone(), TTL, rdata)); diff --git a/rust/connlib/tunnel/src/tests/flux_capacitor.rs b/rust/connlib/tunnel/src/tests/flux_capacitor.rs index 308781de6..8f4aceac4 100644 --- a/rust/connlib/tunnel/src/tests/flux_capacitor.rs +++ b/rust/connlib/tunnel/src/tests/flux_capacitor.rs @@ -20,19 +20,14 @@ impl FormatTime for FluxCapacitor { } } -impl Default for FluxCapacitor { - fn default() -> Self { - let start = Instant::now(); - let utc_start = Utc::now(); - +impl FluxCapacitor { + pub(crate) fn new(start: Instant, utc_start: DateTime) -> Self { Self { start, now: Arc::new(Mutex::new((start, utc_start))), } } -} -impl FluxCapacitor { const SMALL_TICK: Duration = Duration::from_millis(10); const LARGE_TICK: Duration = Duration::from_millis(100); diff --git a/rust/connlib/tunnel/src/tests/icmp_error_hosts.rs b/rust/connlib/tunnel/src/tests/icmp_error_hosts.rs index 554059293..e9376d4c8 100644 --- a/rust/connlib/tunnel/src/tests/icmp_error_hosts.rs +++ b/rust/connlib/tunnel/src/tests/icmp_error_hosts.rs @@ -1,6 +1,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, net::IpAddr, + time::Instant, }; use proptest::{prelude::*, sample}; @@ -21,9 +22,10 @@ impl IcmpErrorHosts { /// Samples a subset of the provided DNS records which we will generate ICMP errors. pub(crate) fn icmp_error_hosts( dns_resource_records: DnsRecords, + now: Instant, ) -> impl Strategy { // First, deduplicate all IPs. - let unique_ips = dns_resource_records.ips_iter().collect::>(); + let unique_ips = dns_resource_records.ips_iter(now).collect::>(); let ips = Vec::from_iter(unique_ips); Just(ips) @@ -31,7 +33,7 @@ pub(crate) fn icmp_error_hosts( .prop_flat_map(|ips| { let num_ips = ips.len(); - sample::subsequence(ips, 0..num_ips) // Pick a subset of IPs. + sample::subsequence(ips, num_ips / 2) // Pick a subset of IPs. }) .prop_flat_map(|ips| { ips.into_iter() diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 845d7015a..d3efbb2a4 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -12,8 +12,10 @@ use dns_types::{DomainName, RecordType}; use ip_network::{Ipv4Network, Ipv6Network}; use itertools::Itertools; use prop::sample::select; +use proptest::collection::btree_set; use proptest::{prelude::*, sample}; use std::net::{Ipv4Addr, Ipv6Addr}; +use std::time::Instant; use std::{ collections::{BTreeMap, BTreeSet, HashSet}, fmt, iter, @@ -53,14 +55,14 @@ pub(crate) struct ReferenceState { /// Care has to be taken that we don't implement things in a buggy way here. /// After all, if your test has bugs, it won't catch any in the actual implementation. impl ReferenceState { - pub(crate) fn initial_state() -> BoxedStrategy { + pub(crate) fn initial_state(start: Instant) -> BoxedStrategy { stub_portal() - .prop_flat_map(|portal| { - let gateways = portal.gateways(); - let dns_resource_records = portal.dns_resource_records(); + .prop_flat_map(move |portal| { + let gateways = portal.gateways(start); + let dns_resource_records = portal.dns_resource_records(start); let client = portal.client(system_dns_servers(), upstream_dns_servers()); let relays = relays(relay_id()); - let global_dns_records = global_dns_records(); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources. + let global_dns_records = global_dns_records(start); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources. let drop_direct_client_traffic = any::(); ( @@ -74,7 +76,7 @@ impl ReferenceState { ) }) .prop_flat_map( - |( + move |( client, gateways, portal, @@ -88,7 +90,7 @@ impl ReferenceState { Just(gateways), Just(portal), Just(dns_resource_records.clone()), - icmp_error_hosts(dns_resource_records), + icmp_error_hosts(dns_resource_records, start), Just(relays), Just(global_dns), Just(drop_direct_client_traffic), @@ -96,7 +98,7 @@ impl ReferenceState { }, ) .prop_flat_map( - |( + move |( client, gateways, portal, @@ -112,7 +114,7 @@ impl ReferenceState { Just(portal), Just(dns_resource_records.clone()), Just(icmp_error_hosts.clone()), - tcp_resources(dns_resource_records, icmp_error_hosts), + tcp_resources(dns_resource_records, icmp_error_hosts, start), Just(relays), Just(global_dns), Just(drop_direct_client_traffic), @@ -178,7 +180,7 @@ impl ReferenceState { }, ) .prop_map( - |( + move |( client, gateways, relays, @@ -209,7 +211,7 @@ impl ReferenceState { /// /// This is invoked by proptest repeatedly to explore further state transitions. /// Here, we should only generate [`Transition`]s that make sense for the current state. - pub(crate) fn transitions(state: &Self) -> BoxedStrategy { + pub(crate) fn transitions(state: &Self, now: Instant) -> BoxedStrategy { CompositeStrategy::default() .with( 1, @@ -344,9 +346,33 @@ impl ReferenceState { connect_tcp(Just(tunnel_ip6), select(dns_v6_domains)) }, ) + .with_if_not_empty( + 10, + state.resolved_v4_domains_with_icmp_errors(now), + |dns_v4_domains| { + let tunnel_ip4 = state.client.inner().tunnel_ip4; + + prop_oneof![ + icmp_packet(Just(tunnel_ip4), select(dns_v4_domains.clone())), + udp_packet(Just(tunnel_ip4), select(dns_v4_domains)), + ] + }, + ) + .with_if_not_empty( + 10, + state.resolved_v6_domains_with_icmp_errors(now), + |dns_v6_domains| { + let tunnel_ip6 = state.client.inner().tunnel_ip6; + + prop_oneof![ + icmp_packet(Just(tunnel_ip6), select(dns_v6_domains.clone()),), + udp_packet(Just(tunnel_ip6), select(dns_v6_domains),), + ] + }, + ) .with_if_not_empty( 5, - (state.all_domains(), state.reachable_dns_servers()), + (state.all_domains(now), state.reachable_dns_servers()), |(domains, dns_servers)| { dns_queries(sample::select(domains), sample::select(dns_servers)) .prop_map(Transition::SendDnsQueries) @@ -384,7 +410,7 @@ impl ReferenceState { state .client .inner() - .resolved_ip4_for_non_resources(&state.global_dns_records), + .resolved_ip4_for_non_resources(&state.global_dns_records, now), |resolved_non_resource_ip4s| { let tunnel_ip4 = state.client.inner().tunnel_ip4; @@ -399,7 +425,7 @@ impl ReferenceState { state .client .inner() - .resolved_ip6_for_non_resources(&state.global_dns_records), + .resolved_ip6_for_non_resources(&state.global_dns_records, now), |resolved_non_resource_ip6s| { let tunnel_ip6 = state.client.inner().tunnel_ip6; @@ -425,13 +451,17 @@ impl ReferenceState { udp_packet(Just(tunnel_ip6), select_host_v6(&gateway_ips)), ] }) + .with_if_not_empty(5, state.dns_resource_domains(), |domains| { + (sample::select(domains), btree_set(dns_record(), 1..6)) + .prop_map(|(domain, records)| Transition::UpdateDnsRecords { domain, records }) + }) .boxed() } /// Apply the transition to our reference state. /// /// Here is where we implement the "expected" logic. - pub(crate) fn apply(mut state: Self, transition: &Transition) -> Self { + pub(crate) fn apply(mut state: Self, transition: &Transition, now: Instant) -> Self { match transition { Transition::AddResource(resource) => { state.client.exec_mut(|client| match resource { @@ -591,6 +621,12 @@ impl ReferenceState { Transition::RestartClient(key) => state.client.exec_mut(|c| { c.restart(*key); }), + Transition::UpdateDnsRecords { domain, records } => { + state.global_dns_records.merge(DnsRecords::from([( + domain.clone(), + BTreeMap::from([(now, records.clone())]), + )])); + } }; state @@ -785,6 +821,7 @@ impl ReferenceState { // Also don't deactivate resources where we have TCP connections as those would get interrupted. has_resource && has_gateway_for_resource && !has_tcp_connection } + Transition::UpdateDnsRecords { .. } => true, } } @@ -833,21 +870,36 @@ impl ReferenceState { impl ReferenceState { // We surface what are the existing rtypes for a domain so that it's easier // for the proptests to hit an existing record. - fn all_domains(&self) -> Vec<(DomainName, Vec)> { + fn all_domains(&self, now: Instant) -> Vec<(DomainName, Vec)> { fn domains_and_rtypes( records: &DnsRecords, + at: Instant, ) -> impl Iterator)> { records .domains_iter() - .map(|d| (d.clone(), records.domain_rtypes(&d))) + .map(move |d| (d.clone(), records.domain_rtypes(&d, at))) } // We may have multiple gateways in a site, so we need to dedup. let unique_domains = self .gateways .values() - .flat_map(|g| domains_and_rtypes(g.inner().dns_records())) - .chain(domains_and_rtypes(&self.global_dns_records)) + .flat_map(|g| domains_and_rtypes(g.inner().dns_records(), now)) + .chain(domains_and_rtypes(&self.global_dns_records, now)) + .filter(|(_, rtypes)| !rtypes.is_empty()) + .collect::>(); + + Vec::from_iter(unique_domains) + } + + fn dns_resource_domains(&self) -> Vec { + // We may have multiple gateways in a site, so we need to dedup. + let unique_domains = self + .gateways + .values() + .flat_map(|g| g.inner().dns_records().domains_iter()) + .chain(self.global_dns_records.domains_iter()) + .filter(|d| self.client.inner().dns_resource_by_domain(d).is_some()) .collect::>(); Vec::from_iter(unique_domains) @@ -965,6 +1017,32 @@ impl ReferenceState { .collect() } + fn resolved_v4_domains_with_icmp_errors(&self, at: Instant) -> Vec { + self.client + .inner() + .resolved_v4_domains() + .into_iter() + .filter(|d| { + self.global_dns_records + .domain_ips_iter(d, at) + .any(|ip| self.icmp_error_hosts.icmp_error_for_ip(ip).is_some()) + }) + .collect() + } + + fn resolved_v6_domains_with_icmp_errors(&self, at: Instant) -> Vec { + self.client + .inner() + .resolved_v6_domains() + .into_iter() + .filter(|d| { + self.global_dns_records + .domain_ips_iter(d, at) + .any(|ip| self.icmp_error_hosts.icmp_error_for_ip(ip).is_some()) + }) + .collect() + } + fn deploy_new_relays(&mut self, new_relays: &BTreeMap>) { // Always take down all relays because we can't know which one was sampled for the connection. for relay in self.relays.values() { diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 9ff74dcfd..579d4b1f8 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -627,6 +627,16 @@ impl RefClient { for status in self.site_status.values_mut() { *status = ResourceStatus::Unknown; } + + // TCP connections will automatically re-create connections to Gateways. + for r in self + .expected_tcp_connections + .values() + .copied() + .collect::>() + { + self.set_resource_online(r); + } } pub(crate) fn add_internet_resource(&mut self, resource: InternetResource) { @@ -643,36 +653,36 @@ impl RefClient { pub(crate) fn add_cidr_resource(&mut self, r: CidrResource) { let address = r.address; let r = Resource::Cidr(r); + let rid = r.id(); - if let Some(existing) = self - .resources - .iter() - .find(|existing| existing.id() == r.id()) + if let Some(existing) = self.resources.iter().find(|existing| existing.id() == rid) && (existing.has_different_address(&r) || existing.has_different_site(&r)) { self.remove_resource(&existing.id()); } - self.resources.push(r.clone()); + self.resources.push(r); self.cidr_resources = self.recalculate_cidr_routes(); match address { IpNetwork::V4(v4) => { - self.ipv4_routes.insert(r.id(), v4); + self.ipv4_routes.insert(rid, v4); } IpNetwork::V6(v6) => { - self.ipv6_routes.insert(r.id(), v6); + self.ipv6_routes.insert(rid, v6); } } + + if self.expected_tcp_connections.values().contains(&rid) { + self.set_resource_online(rid); + } } pub(crate) fn add_dns_resource(&mut self, r: DnsResource) { let r = Resource::Dns(r); + let rid = r.id(); - if let Some(existing) = self - .resources - .iter() - .find(|existing| existing.id() == r.id()) + if let Some(existing) = self.resources.iter().find(|existing| existing.id() == rid) && (existing.has_different_address(&r) || existing.has_different_ip_stack(&r) || existing.has_different_site(&r)) @@ -681,6 +691,10 @@ impl RefClient { } self.resources.push(r); + + if self.expected_tcp_connections.values().contains(&rid) { + self.set_resource_online(rid); + } } /// Re-adds all resources in the order they have been initially added. @@ -700,10 +714,10 @@ impl RefClient { &self, has_failed_tcp_connection: impl Fn((SPort, DPort)) -> bool, ) -> (BTreeMap, BTreeSet) { - let maybe_online_sites = self + let resources_with_failed_tcp_connections = self .expected_tcp_connections .iter() - .filter(|((_, _, sport, dport), _)| !has_failed_tcp_connection((*sport, *dport))) + .filter(|((_, _, sport, dport), _)| has_failed_tcp_connection((*sport, *dport))) .filter_map(|(_, resource)| self.site_for_resource(*resource)) .flat_map(|site| { self.resources @@ -726,7 +740,7 @@ impl RefClient { }) .collect(); - (resource_status, maybe_online_sites) + (resource_status, resources_with_failed_tcp_connections) } pub(crate) fn tunnel_ip_for(&self, dst: IpAddr) -> IpAddr { @@ -1095,8 +1109,9 @@ impl RefClient { pub(crate) fn resolved_ip4_for_non_resources( &self, global_dns_records: &DnsRecords, + at: Instant, ) -> Vec { - self.resolved_ips_for_non_resources(global_dns_records) + self.resolved_ips_for_non_resources(global_dns_records, at) .filter_map(|ip| match ip { IpAddr::V4(v4) => Some(v4), IpAddr::V6(_) => None, @@ -1107,8 +1122,9 @@ impl RefClient { pub(crate) fn resolved_ip6_for_non_resources( &self, global_dns_records: &DnsRecords, + at: Instant, ) -> Vec { - self.resolved_ips_for_non_resources(global_dns_records) + self.resolved_ips_for_non_resources(global_dns_records, at) .filter_map(|ip| match ip { IpAddr::V6(v6) => Some(v6), IpAddr::V4(_) => None, @@ -1119,13 +1135,14 @@ impl RefClient { fn resolved_ips_for_non_resources<'a>( &'a self, global_dns_records: &'a DnsRecords, + at: Instant, ) -> impl Iterator + 'a { self.dns_records .iter() - .filter_map(|(domain, _)| { + .filter_map(move |(domain, _)| { self.dns_resource_by_domain(domain) .is_none() - .then_some(global_dns_records.domain_ips_iter(domain)) + .then_some(global_dns_records.domain_ips_iter(domain, at)) }) .flatten() } diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 6dd2e0729..9a4571895 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -11,6 +11,7 @@ use crate::{GatewayState, IpConfig}; use anyhow::{Result, bail}; use chrono::{DateTime, Utc}; use connlib_model::{GatewayId, RelayId}; +use dns_types::DomainName; use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket}; use proptest::prelude::*; use snownet::Transmit; @@ -27,10 +28,13 @@ pub(crate) struct SimGateway { pub(crate) sut: GatewayState, /// The received ICMP packets, indexed by our custom ICMP payload. - pub(crate) received_icmp_requests: BTreeMap, + pub(crate) received_icmp_requests: BTreeMap, /// The received UDP packets, indexed by our custom UDP payload. - pub(crate) received_udp_requests: BTreeMap, + pub(crate) received_udp_requests: BTreeMap, + + /// The times we resolved DNS records for a domain. + pub(crate) dns_query_timestamps: BTreeMap>, site_specific_dns_records: DnsRecords, udp_dns_server_resources: BTreeMap, @@ -55,6 +59,7 @@ impl SimGateway { udp_dns_server_resources: Default::default(), tcp_dns_server_resources: Default::default(), received_udp_requests: Default::default(), + dns_query_timestamps: Default::default(), tcp_resources: tcp_resources .into_iter() .map(|address| { @@ -194,7 +199,7 @@ impl SimGateway { let packet_id = u64::from_be_bytes(*icmp.payload().first_chunk().unwrap()); tracing::debug!(%packet_id, "Received ICMP request"); self.received_icmp_requests - .insert(packet_id, packet.clone()); + .insert(packet_id, (now, packet.clone())); return self.handle_icmp_request(&packet, echo, icmp.payload(), icmp_error, now); } @@ -204,7 +209,7 @@ impl SimGateway { let packet_id = u64::from_be_bytes(*icmp.payload().first_chunk().unwrap()); tracing::debug!(%packet_id, "Received ICMP request"); self.received_icmp_requests - .insert(packet_id, packet.clone()); + .insert(packet_id, (now, packet.clone())); return self.handle_icmp_request(&packet, echo, icmp.payload(), icmp_error, now); } @@ -234,7 +239,7 @@ impl SimGateway { } if let Some(reply) = icmp_error.or_else(|| echo_reply(packet.clone())) { - self.request_received(&packet); + self.request_received(&packet, now); let transmit = self.sut.handle_tun_input(reply, now).unwrap()?; return Some(transmit); @@ -257,11 +262,12 @@ impl SimGateway { ) } - fn request_received(&mut self, packet: &IpPacket) { + fn request_received(&mut self, packet: &IpPacket, now: Instant) { if let Some(udp) = packet.as_udp() { let packet_id = u64::from_be_bytes(*udp.payload().first_chunk().unwrap()); tracing::debug!(%packet_id, "Received UDP request"); - self.received_udp_requests.insert(packet_id, packet.clone()); + self.received_udp_requests + .insert(packet_id, (now, packet.clone())); } } diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index bbefaf310..b9fea2e9a 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -15,22 +15,24 @@ use prop::sample; use proptest::{collection, prelude::*}; use std::iter; use std::num::NonZeroU16; +use std::time::Instant; use std::{ collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Duration, }; -pub(crate) fn global_dns_records() -> impl Strategy { +pub(crate) fn global_dns_records(at: Instant) -> impl Strategy { collection::btree_map( domain_name(2..4), - collection::btree_set(dns_record(), 1..6), + collection::btree_set(dns_record(), 1..6) + .prop_map(move |records| BTreeMap::from([(at, records)])), 0..5, ) .prop_map_into() } -fn dns_record() -> impl Strategy { +pub(crate) fn dns_record() -> impl Strategy { prop_oneof![ 3 => non_reserved_ip().prop_map(dns_types::records::ip), 1 => collection::vec(txt_record(), 6..=10) @@ -141,6 +143,7 @@ pub(crate) fn stub_portal() -> impl Strategy { pub(crate) fn tcp_resources( dns_records: DnsRecords, imcp_error_hosts: IcmpErrorHosts, + at: Instant, ) -> impl Strategy>> { let all_domains = dns_records.domains_iter().collect::>(); @@ -153,7 +156,7 @@ pub(crate) fn tcp_resources( .into_iter() .filter(|(domain, _)| { dns_records - .domain_ips_iter(domain) + .domain_ips_iter(domain, at) .all(|ip| imcp_error_hosts.icmp_error_for_ip(ip).is_none()) }) .map({ @@ -161,7 +164,7 @@ pub(crate) fn tcp_resources( move |(domain, port)| { let addresses = dns_records - .domain_ips_iter(&domain) + .domain_ips_iter(&domain, at) .map(|address| SocketAddr::new(address, port.get())) .collect::>(); @@ -311,6 +314,7 @@ pub(crate) fn resolved_ips() -> impl Strategy> pub(crate) fn subdomain_records( base: String, subdomains: impl Strategy, + at: Instant, ) -> impl Strategy { collection::hash_map(subdomains, resolved_ips(), 1..4).prop_map(move |subdomain_ips| { subdomain_ips @@ -318,7 +322,7 @@ pub(crate) fn subdomain_records( .map(|(label, ips)| { let domain = format!("{label}.{base}"); - (domain.parse().unwrap(), ips) + (domain.parse().unwrap(), BTreeMap::from([(at, ips)])) }) .collect() }) diff --git a/rust/connlib/tunnel/src/tests/stub_portal.rs b/rust/connlib/tunnel/src/tests/stub_portal.rs index cad6d2682..558ae5691 100644 --- a/rust/connlib/tunnel/src/tests/stub_portal.rs +++ b/rust/connlib/tunnel/src/tests/stub_portal.rs @@ -24,6 +24,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, iter, net::{IpAddr, Ipv4Addr, Ipv6Addr}, + time::Instant, }; /// Stub implementation of the portal. @@ -261,6 +262,7 @@ impl StubPortal { pub(crate) fn gateways( &self, + at: Instant, ) -> impl Strategy>> + use<> { let dns_resources = self.dns_resources.clone(); @@ -273,7 +275,7 @@ impl StubPortal { ref_gateway_host( Just(*ipv4_addr), Just(*ipv6_addr), - site_specific_dns_records(dns_resources.clone(), *site_id), + site_specific_dns_records(dns_resources.clone(), *site_id, at), ), ) }) @@ -318,8 +320,11 @@ impl StubPortal { proptest::option::of(sample::select(possible_search_domains)) } - pub(crate) fn dns_resource_records(&self) -> impl Strategy + use<> { - dns_resource_records(self.dns_resources.clone().into_values()) + pub(crate) fn dns_resource_records( + &self, + at: Instant, + ) -> impl Strategy + use<> { + dns_resource_records(self.dns_resources.clone().into_values(), at) } } @@ -327,18 +332,20 @@ impl StubPortal { fn site_specific_dns_records( dns_resources: BTreeMap, site: SiteId, + at: Instant, ) -> impl Strategy { let dns_resources_in_site = dns_resources .into_values() .filter(move |resource| resource.sites.iter().any(|s| s.id == site)); - dns_resource_records(dns_resources_in_site).prop_flat_map(|records| { + dns_resource_records(dns_resources_in_site, at).prop_flat_map(move |records| { if records.is_empty() { Just(DnsRecords::default()).boxed() } else { collection::btree_map( sample::select(records.domains_iter().collect::>()), - collection::btree_set(site_specific_dns_record(), 1..6), + collection::btree_set(site_specific_dns_record(), 1..6) + .prop_map(move |records| BTreeMap::from([(at, records)])), 0..5, ) .prop_map_into() @@ -349,6 +356,7 @@ fn site_specific_dns_records( fn dns_resource_records( dns_resources: impl Iterator, + at: Instant, ) -> impl Strategy { dns_resources .map(|resource| { @@ -360,11 +368,14 @@ fn dns_resource_records( // For example, `*.example.com` and `app.example.com`. match address.split_once('.') { Some(("*" | "**", base)) => { - subdomain_records(base.to_owned(), domain_label()).boxed() + subdomain_records(base.to_owned(), domain_label(), at).boxed() } _ => resolved_ips() .prop_map(move |resolved_ips| { - DnsRecords::from([(address.parse().unwrap(), resolved_ips)]) + DnsRecords::from([( + address.parse().unwrap(), + BTreeMap::from([(at, resolved_ips)]), + )]) }) .boxed(), } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 7bcba0218..2b07fea7f 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -422,6 +422,7 @@ impl TunnelTest { c.update_relays(iter::empty(), state.relays.iter(), now); }) } + Transition::UpdateDnsRecords { .. } => {} }; state.advance(ref_state, &mut buffered_transmits); @@ -530,7 +531,7 @@ impl TunnelTest { let transport = query.transport; let response = - self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records); + self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records, now); self.client.exec_mut(|c| { c.sut.handle_dns_response( dns::RecursiveResponse { @@ -932,6 +933,7 @@ impl TunnelTest { &self, query: &dns_types::Query, global_dns_records: &DnsRecords, + now: Instant, ) -> dns_types::Response { const TTL: u32 = 1; // We deliberately chose a short TTL so we don't have to model the DNS cache in these tests. @@ -941,7 +943,7 @@ impl TunnelTest { let response = dns_types::ResponseBuilder::for_query(query, ResponseCode::NOERROR) .with_records( global_dns_records - .domain_records_iter(&domain) + .domain_records_iter(&domain, now) .filter(|record| qtype == record.rtype()) .map(|rdata| (domain.clone(), TTL, rdata)), ) @@ -1057,9 +1059,15 @@ fn on_gateway_event( } }), GatewayEvent::ResolveDns(r) => { - let resolved_ips = global_dns_records.domain_ips_iter(r.domain()).collect(); + let resolved_ips = global_dns_records + .domain_ips_iter(r.domain(), now) + .collect(); gateway.exec_mut(|g| { + g.dns_query_timestamps + .entry(r.domain().clone()) + .or_default() + .push(now); g.sut .handle_domain_resolved(r, Ok(resolved_ips), now) .unwrap() diff --git a/rust/connlib/tunnel/src/tests/tcp.rs b/rust/connlib/tunnel/src/tests/tcp.rs index 1fa461d72..7e170d063 100644 --- a/rust/connlib/tunnel/src/tests/tcp.rs +++ b/rust/connlib/tunnel/src/tests/tcp.rs @@ -77,10 +77,11 @@ impl Client { // TODO: Upstream ICMP error handling to `smoltcp`. if let Ok(Some((failed_packet, _))) = packet.icmp_error() && let Layer4Protocol::Tcp { dst, .. } = failed_packet.layer4_protocol() - && let Some(handle) = self - .sockets_by_remote - .get(&SocketAddr::new(failed_packet.dst(), dst)) + && let socket = SocketAddr::new(failed_packet.dst(), dst) + && let Some(handle) = self.sockets_by_remote.get(&socket) { + tracing::debug!(%socket, "Received ICMP error"); + self.sockets.get_mut::(*handle).abort(); } diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index cd2450dfc..9fa390993 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -3,7 +3,7 @@ use crate::{ proptest::{host_v4, host_v6}, }; use connlib_model::{RelayId, ResourceId, Site}; -use dns_types::{DomainName, RecordType}; +use dns_types::{DomainName, OwnedRecordData, RecordType}; use ip_network::IpNetwork; use super::{ @@ -14,7 +14,7 @@ use crate::messages::DnsServer; use prop::collection; use proptest::{prelude::*, sample}; use std::{ - collections::BTreeMap, + collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, num::NonZeroU16, }; @@ -102,6 +102,12 @@ pub(crate) enum Transition { /// De-authorize access to a resource whilst the Gateway is network-partitioned from the portal. DeauthorizeWhileGatewayIsPartitioned(ResourceId), + + /// De-authorize access to a resource whilst the Gateway is network-partitioned from the portal. + UpdateDnsRecords { + domain: DomainName, + records: BTreeSet, + }, } #[derive(Debug, Clone)] diff --git a/website/src/components/Changelog/Gateway.tsx b/website/src/components/Changelog/Gateway.tsx index 06aba5a0c..ddb5a93d7 100644 --- a/website/src/components/Changelog/Gateway.tsx +++ b/website/src/components/Changelog/Gateway.tsx @@ -26,6 +26,10 @@ export default function Gateway() { Adds a `--log-format` CLI option to output logs as JSON. + + Fixes an issue where packets for DNS resources would be routed to + stale IPs after DNS record changes. +