From 3b56664e02bde08f0dafb6c1849717850f674bc7 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 17 Aug 2024 00:15:58 +0100 Subject: [PATCH] test(rust): ensure deterministic proptests (#6319) For quite a while now, we have been making extensive use of property-based testing to ensure `connlib` works as intended. The idea of proptests is that - given a certain seed - we deterministically sample test inputs and assert properties on a given function. If the test fails, `proptest` prints the seed which can then be added to a regressions file to iterate on the test case and fix it. It is quite obvious that non-determinism in how the test input gets generated is no bueno and reduces the value we get out of these tests a fair bit. The `HashMap` and `HashSet` data structures are known to be non-deterministic in their iteration order. This causes non-determinism during the input generation because we make use of a lot of maps and sets to gradually build up the test input. We fix all uses of `HashMap` and `HashSet` by replacing them with `BTreeMap` and `BTreeSet`. To ensure this doesn't regress, we refactor `tunnel_test` to not make use of proptest's macros and instead, we initialise and run the test ourselves. This allows us to dump the sampled state and transitions into a file per test run. In CI, we then run a 2nd iteration of all regression tests and compare the sampled state and transitions with the previous run. They must match byte-for-byte. Finally, to discourage use of non-deterministic iteration, we ban the use of the iteration functions on `HashMap` and `HashSet` across the codebase. This doesn't catch iteration in a `for`-loop but it is better than not linting against it at all. --------- Signed-off-by: Thomas Eizinger Co-authored-by: Reactor Scram --- .github/workflows/_rust.yml | 38 ++++- rust/Cargo.lock | 2 - rust/clippy.toml | 6 + rust/connlib/clients/android/src/lib.rs | 4 +- rust/connlib/clients/shared/src/eventloop.rs | 6 +- rust/connlib/clients/shared/src/lib.rs | 6 +- rust/connlib/shared/src/messages/client.rs | 2 +- rust/connlib/snownet/src/allocation.rs | 6 +- rust/connlib/snownet/src/node.rs | 10 +- rust/connlib/tunnel/.gitignore | 1 + rust/connlib/tunnel/Cargo.toml | 3 - rust/connlib/tunnel/src/client.rs | 12 +- rust/connlib/tunnel/src/dns.rs | 20 +-- rust/connlib/tunnel/src/lib.rs | 4 +- rust/connlib/tunnel/src/peer.rs | 4 +- rust/connlib/tunnel/src/peer/nat_table.rs | 8 +- rust/connlib/tunnel/src/tests.rs | 131 ++++++++++++++++-- rust/connlib/tunnel/src/tests/assertions.rs | 18 +-- rust/connlib/tunnel/src/tests/reference.rs | 17 ++- .../tunnel/src/tests/run_count_appender.rs | 11 -- rust/connlib/tunnel/src/tests/sim_client.rs | 24 ++-- rust/connlib/tunnel/src/tests/sim_dns.rs | 4 +- rust/connlib/tunnel/src/tests/sim_gateway.rs | 6 +- rust/connlib/tunnel/src/tests/strategies.rs | 31 +++-- rust/connlib/tunnel/src/tests/stub_portal.rs | 30 ++-- rust/connlib/tunnel/src/tests/sut.rs | 45 ++---- rust/connlib/tunnel/src/tests/transition.rs | 4 +- rust/headless-client/src/ipc_service.rs | 4 +- rust/ip-packet/src/lib.rs | 2 +- rust/relay/src/lib.rs | 2 +- rust/relay/src/server.rs | 4 +- 31 files changed, 292 insertions(+), 173 deletions(-) create mode 100644 rust/connlib/tunnel/.gitignore delete mode 100644 rust/connlib/tunnel/src/tests/run_count_appender.rs diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml index 92c20b1b1..301b487a9 100644 --- a/.github/workflows/_rust.yml +++ b/.github/workflows/_rust.yml @@ -88,15 +88,47 @@ jobs: - uses: ./.github/actions/setup-rust id: setup-rust - uses: ./.github/actions/setup-tauri - - run: cargo test --all-features ${{ steps.setup-rust.outputs.packages }} -- --include-ignored + - name: "cargo test" + shell: bash + run: | + + # First, run all tests. + cargo test --all-features ${{ steps.setup-rust.outputs.packages }} -- --include-ignored --nocapture + + # Backup dumped state and transition samples + mv $TESTCASES_DIR $TESTCASES_BACKUP_DIR + + # Re-run only the regression seeds + PROPTEST_CASES=0 cargo test --all-features ${{ steps.setup-rust.outputs.packages }} -- tunnel_test --nocapture + + # Assert that sampled state and transitions don't change between runs + for file in "$TESTCASES_DIR"/*.{state,transitions}; do + filename=$(basename "$file") + + if ! diff "$file" "$TESTCASES_BACKUP_DIR/$filename"; then + echo "Found non-deterministic testcase: $filename" + exit 1 + fi + done env: # # Needed to create tunnel interfaces in unit tests CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: "sudo --preserve-env" PROPTEST_VERBOSE: 0 # Otherwise the output is very long. + PROPTEST_CASES: 1000 # Default is only 256. CARGO_PROFILE_TEST_OPT_LEVEL: 1 # Otherwise the tests take forever. - name: "cargo test" - shell: bash + TESTCASES_DIR: "connlib/tunnel/testcases" + TESTCASES_BACKUP_DIR: "connlib/tunnel/testcases_backup" + - name: Upload testcase data + if: ${{ failure() }} + uses: actions/upload-artifact@v4 + with: + overwrite: true + name: proptest-cases + path: | + rust/connlib/tunnel/testcases + rust/connlib/tunnel/testcases_backup + retention-days: 7 # Runs the Tauri client smoke test, built in debug mode. We can't run it in release # mode because of a known issue: diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ad78a9ec5..4fd612999 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2030,7 +2030,6 @@ dependencies = [ "derivative", "divan", "domain", - "firezone-logging", "firezone-relay", "futures", "futures-util", @@ -2055,7 +2054,6 @@ dependencies = [ "thiserror", "tokio", "tracing", - "tracing-appender", "tracing-subscriber", "tun", "uuid", diff --git a/rust/clippy.toml b/rust/clippy.toml index 76638cca4..c304c9b6e 100644 --- a/rust/clippy.toml +++ b/rust/clippy.toml @@ -1 +1,7 @@ avoid-breaking-exported-api = false # We don't publish anything to crates.io, hence we don't need to worry about breaking Rust API changes. +disallowed-methods = [ + { path = "std::collections::HashMap::iter", reason = "HashMap has non-deterministic iteration order, use BTreeMap instead" }, + { path = "std::collections::HashSet::iter", reason = "HashSet has non-deterministic iteration order, use BTreeSet instead" }, + { path = "std::collections::HashMap::into_iter", reason = "HashMap has non-deterministic iteration order, use BTreeMap instead" }, + { path = "std::collections::HashSet::into_iter", reason = "HashSet has non-deterministic iteration order, use BTreeSet instead" }, +] diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index dda1e01b9..82090151a 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -20,7 +20,7 @@ use jni::{ use phoenix_channel::PhoenixChannel; use secrecy::{Secret, SecretString}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -use std::{collections::HashSet, io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc}; +use std::{collections::BTreeSet, io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc}; use std::{ net::{Ipv4Addr, Ipv6Addr}, os::fd::RawFd, @@ -484,7 +484,7 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se }) .expect("Invalid string returned from android client"), ); - let disabled_resources: HashSet = + let disabled_resources: BTreeSet = serde_json::from_str(&disabled_resources).unwrap(); tracing::debug!("disabled resource: {disabled_resources:?}"); let session = &*(session_ptr as *const SessionWrapper); diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index ca407c77d..f3e9613c5 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -13,7 +13,7 @@ use connlib_shared::messages::{ use firezone_tunnel::ClientTunnel; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; use std::{ - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet}, net::IpAddr, task::{Context, Poll}, }; @@ -35,7 +35,7 @@ pub enum Command { Reset, SetDns(Vec), SetTun(Box), - SetDisabledResources(HashSet), + SetDisabledResources(BTreeSet), } impl Eventloop { @@ -348,7 +348,7 @@ where #[derive(Default)] struct SentConnectionIntents { - inner: HashMap, + inner: BTreeMap, } impl SentConnectionIntents { diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index aeeea7564..bf5ca477f 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -12,7 +12,7 @@ use firezone_tunnel::ClientTunnel; use messages::{IngressMessages, ReplyMessages}; use phoenix_channel::PhoenixChannel; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedReceiver; @@ -94,7 +94,7 @@ impl Session { let _ = self.channel.send(Command::SetDns(new_dns)); } - pub fn set_disabled_resources(&self, disabled_resources: HashSet) { + pub fn set_disabled_resources(&self, disabled_resources: BTreeSet) { let _ = self .channel .send(Command::SetDisabledResources(disabled_resources)); @@ -135,7 +135,7 @@ where private_key, tcp_socket_factory, udp_socket_factory, - HashMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), + BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), )?; let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx); diff --git a/rust/connlib/shared/src/messages/client.rs b/rust/connlib/shared/src/messages/client.rs index 96115fbd9..812d4e17d 100644 --- a/rust/connlib/shared/src/messages/client.rs +++ b/rust/connlib/shared/src/messages/client.rs @@ -12,7 +12,7 @@ use super::ResourceId; use itertools::Itertools; /// Description of a resource that maps to a DNS record. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ResourceDescriptionDns { /// Resource's id. pub id: ResourceId, diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index bb02895f1..c7d529051 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -10,7 +10,7 @@ use hex_display::HexDisplayExt as _; use rand::random; use std::{ borrow::Cow, - collections::{HashMap, VecDeque}, + collections::{BTreeMap, VecDeque}, net::{SocketAddr, SocketAddrV4, SocketAddrV6}, time::{Duration, Instant}, }; @@ -70,7 +70,7 @@ pub struct Allocation { buffered_transmits: VecDeque>, events: VecDeque, - sent_requests: HashMap< + sent_requests: BTreeMap< TransactionId, ( SocketAddr, @@ -1191,7 +1191,7 @@ stun_codec::define_attribute_enums!( #[derive(Debug)] struct ChannelBindings { - inner: HashMap, + inner: BTreeMap, next_channel: u16, } diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 491797294..9e232a887 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -18,7 +18,7 @@ use rand::{random, SeedableRng}; use secrecy::{ExposeSecret, Secret}; use sha2::Digest; use std::borrow::Cow; -use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::collections::{BTreeMap, BTreeSet}; use std::hash::Hash; use std::marker::PhantomData; use std::mem; @@ -86,7 +86,7 @@ pub struct Node { index: IndexLfsr, rate_limiter: Arc, - host_candidates: HashSet, + host_candidates: Vec, // `Candidate` doesn't implement `PartialOrd` so we cannot use a `BTreeSet`. Linear search is okay because we expect this vec to be <100 elements buffered_transmits: VecDeque>, next_rate_limiter_reset: Option, @@ -595,12 +595,12 @@ where fn add_local_as_host_candidate(&mut self, local: SocketAddr) -> Result<(), Error> { let host_candidate = Candidate::host(local, Protocol::Udp)?; - let is_new = self.host_candidates.insert(host_candidate.clone()); - - if !is_new { + if self.host_candidates.contains(&host_candidate) { return Ok(()); } + self.host_candidates.push(host_candidate.clone()); + for (cid, agent) in self.connections.agents_mut() { let _span = info_span!("connection", %cid).entered(); diff --git a/rust/connlib/tunnel/.gitignore b/rust/connlib/tunnel/.gitignore new file mode 100644 index 000000000..72285f6c7 --- /dev/null +++ b/rust/connlib/tunnel/.gitignore @@ -0,0 +1 @@ +testcases/ diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index e4af3c182..67df8aa93 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -12,7 +12,6 @@ chrono = { workspace = true } connlib-shared = { workspace = true } divan = { version = "0.1.14", optional = true } domain = { workspace = true } -firezone-logging = { workspace = true } futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] } futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] } glob = "0.3.1" @@ -37,7 +36,6 @@ uuid = { version = "1.10", default-features = false, features = ["std", "v4"] } [dev-dependencies] derivative = "2.2.0" -firezone-logging = { workspace = true } firezone-relay = { workspace = true, features = ["proptest"] } ip-packet = { workspace = true, features = ["proptest"] } proptest-state-machine = "0.3" @@ -45,7 +43,6 @@ rand = "0.8" serde_json = "1.0" test-case = "3.3.1" test-strategy = "0.3.1" -tracing-appender = "0.2.3" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [[bench]] diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index fd00bfdf1..e8b0c9c55 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -80,7 +80,7 @@ impl ClientTunnel { }); } - pub fn set_disabled_resources(&mut self, new_disabled_resources: HashSet) { + pub fn set_disabled_resources(&mut self, new_disabled_resources: BTreeSet) { self.role_state .set_disabled_resource(new_disabled_resources); @@ -279,7 +279,7 @@ pub struct ClientState { interface_config: Option, /// Resources that have been disabled by the UI - disabled_resources: HashSet, + disabled_resources: BTreeSet, buffered_events: VecDeque, buffered_packets: VecDeque>, @@ -296,7 +296,7 @@ pub(crate) struct AwaitingConnectionDetails { impl ClientState { pub(crate) fn new( private_key: impl Into, - known_hosts: HashMap>, + known_hosts: BTreeMap>, seed: [u8; 32], ) -> Self { Self { @@ -776,7 +776,7 @@ impl ClientState { self.mangled_dns_queries.clear(); } - pub fn set_disabled_resource(&mut self, new_disabled_resources: HashSet) { + pub fn set_disabled_resource(&mut self, new_disabled_resources: BTreeSet) { let current_disabled_resources = self.disabled_resources.clone(); // We set disabled_resources before anything else so that add_resource knows what resources are enabled right now. @@ -1185,7 +1185,7 @@ fn effective_dns_servers( .peekable(); if dns_servers.peek().is_none() { - tracing::error!("No system default DNS servers available! Can't initialize resolver. DNS interception will be disabled."); + tracing::warn!("No system default DNS servers available! Can't initialize resolver. DNS interception will be disabled."); return Vec::new(); } @@ -1425,7 +1425,7 @@ mod tests { pub fn for_test() -> ClientState { ClientState::new( StaticSecret::random_from_rng(OsRng), - HashMap::new(), + BTreeMap::new(), rand::random(), ) } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index ac6c96891..0d6d34aaf 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -10,7 +10,7 @@ use ip_packet::IpPacket; use ip_packet::Packet as _; use itertools::Itertools; use pattern::{Candidate, Pattern}; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; const DNS_TTL: u32 = 1; @@ -44,12 +44,12 @@ pub(crate) enum ResolveStrategy { } struct KnownHosts { - fqdn_to_ips: HashMap>, - ips_to_fqdn: HashMap, + fqdn_to_ips: BTreeMap>, + ips_to_fqdn: BTreeMap, } impl KnownHosts { - fn new(hosts: HashMap>) -> KnownHosts { + fn new(hosts: BTreeMap>) -> KnownHosts { KnownHosts { fqdn_to_ips: fqdn_to_ips_for_known_hosts(&hosts), ips_to_fqdn: ips_to_fqdn_for_known_hosts(&hosts), @@ -87,7 +87,7 @@ impl KnownHosts { } impl StubResolver { - pub(crate) fn new(known_hosts: HashMap>) -> StubResolver { + pub(crate) fn new(known_hosts: BTreeMap>) -> StubResolver { StubResolver { fqdn_to_ips: Default::default(), ips_to_fqdn: Default::default(), @@ -392,8 +392,8 @@ fn get_v6(ip: IpAddr) -> Option { } fn fqdn_to_ips_for_known_hosts( - hosts: &HashMap>, -) -> HashMap> { + hosts: &BTreeMap>, +) -> BTreeMap> { hosts .iter() .filter_map(|(d, a)| DomainName::vec_from_str(d).ok().map(|d| (d, a.clone()))) @@ -401,8 +401,8 @@ fn fqdn_to_ips_for_known_hosts( } fn ips_to_fqdn_for_known_hosts( - hosts: &HashMap>, -) -> HashMap { + hosts: &BTreeMap>, +) -> BTreeMap { hosts .iter() .filter_map(|(d, a)| { @@ -593,7 +593,7 @@ mod benches { fn match_domain_linear(bencher: divan::Bencher) { bencher .with_inputs(|| { - let mut resolver = StubResolver::new(HashMap::default()); + let mut resolver = StubResolver::new(BTreeMap::default()); let mut rng = rand::thread_rng(); for n in 0..NUM_RES { diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index cb50d5e01..06873090d 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -16,7 +16,7 @@ use ip_network::{Ipv4Network, Ipv6Network}; use rand::rngs::OsRng; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, task::{Context, Poll}, @@ -84,7 +84,7 @@ impl ClientTunnel { private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - known_hosts: HashMap>, + known_hosts: BTreeMap>, ) -> Result { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 5fa753ee6..7f924677b 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::time::{Duration, Instant}; @@ -598,7 +598,7 @@ pub struct ClientOnGateway { ipv6: Ipv6Addr, resources: HashMap>, filters: IpNetworkTable, - permanent_translations: HashMap, + permanent_translations: BTreeMap, nat_table: NatTable, buffered_events: VecDeque, } diff --git a/rust/connlib/tunnel/src/peer/nat_table.rs b/rust/connlib/tunnel/src/peer/nat_table.rs index 9bd8243f3..02f84071d 100644 --- a/rust/connlib/tunnel/src/peer/nat_table.rs +++ b/rust/connlib/tunnel/src/peer/nat_table.rs @@ -2,7 +2,7 @@ use anyhow::Context; use bimap::BiMap; use ip_packet::{IpPacket, Protocol}; -use std::collections::HashMap; +use std::collections::BTreeMap; use std::net::IpAddr; use std::time::{Duration, Instant}; @@ -19,7 +19,7 @@ use std::time::{Duration, Instant}; #[derive(Default, Debug)] pub(crate) struct NatTable { pub(crate) table: BiMap<(Protocol, IpAddr), (Protocol, IpAddr)>, - pub(crate) last_seen: HashMap<(Protocol, IpAddr), Instant>, + pub(crate) last_seen: BTreeMap<(Protocol, IpAddr), Instant>, } const TTL: Duration = Duration::from_secs(60); @@ -116,8 +116,6 @@ mod tests { ) { proptest::prop_assume!(packet.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response. - let _guard = firezone_logging::test("trace"); - let sent_at = Instant::now(); let mut table = NatTable::default(); let response_delay = Duration::from_secs(response_delay); @@ -166,8 +164,6 @@ mod tests { != packet2.as_immutable().source_protocol().unwrap() ); - let _guard = firezone_logging::test("trace"); - let mut table = NatTable::default(); let mut packets = [(packet1, outside_dst1), (packet2, outside_dst2)]; diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 02dd19b3c..ccbf41838 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -1,12 +1,19 @@ use crate::tests::sut::TunnelTest; -use proptest::test_runner::Config; +use assertions::PanicOnErrorEvents; +use proptest::test_runner::{Config, TestError, TestRunner}; +use proptest_state_machine::{ReferenceStateMachine, StateMachineTest}; +use reference::ReferenceState; +use std::sync::atomic::{self, AtomicU32}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{ + layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, Layer, +}; mod assertions; mod buffered_transmits; mod composite_strategy; mod flux_capacitor; mod reference; -mod run_count_appender; mod sim_client; mod sim_dns; mod sim_gateway; @@ -21,12 +28,118 @@ type QueryId = u16; type IcmpSeq = u16; type IcmpIdentifier = u16; -proptest_state_machine::prop_state_machine! { - #![proptest_config(Config { - cases: 1000, - .. Config::default() - })] +#[test] +#[allow(clippy::print_stdout, clippy::print_stderr)] +fn tunnel_test() { + let config = Config { + source_file: Some(file!()), + ..Default::default() + }; - #[test] - fn run_tunnel_test(sequential 1..10 => TunnelTest); + static TEST_INDEX: AtomicU32 = AtomicU32::new(0); + + let _ = std::fs::remove_dir_all("testcases"); + let _ = std::fs::create_dir_all("testcases"); + + let result = TestRunner::new(config).run( + &ReferenceState::sequential_strategy(5..15), + |(mut ref_state, transitions, mut seen_counter)| { + let test_index = TEST_INDEX.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let _guard = init_logging(&ref_state, test_index); + + std::fs::write( + format!("testcases/{test_index}.state"), + format!("{ref_state:#?}"), + ) + .unwrap(); + std::fs::write( + format!("testcases/{test_index}.transitions"), + format!("{transitions:#?}"), + ) + .unwrap(); + + let num_transitions = transitions.len(); + + println!("Running test case {test_index:04} with {num_transitions:02} transitions"); + + let mut sut = TunnelTest::init_test(&ref_state); + + // Check the invariants on the initial state + TunnelTest::check_invariants(&sut, &ref_state); + + for (ix, transition) in transitions.iter().enumerate() { + // The counter is `Some` only before shrinking. When it's `Some` it + // must be incremented before every transition that's being applied + // to inform the strategy that the transition has been applied for + // the first step of its shrinking process which removes any unseen + // transitions. + if let Some(seen_counter) = seen_counter.as_mut() { + seen_counter.fetch_add(1, atomic::Ordering::SeqCst); + } + + tracing::info!( + "\n\nApplying transition {}/{num_transitions}: {transition:?}\n", + ix + 1, + ); + + // Apply the transition on the states + ref_state = ReferenceState::apply(ref_state, transition); + sut = TunnelTest::apply(sut, &ref_state, transition.clone()); + + // Check the invariants after the transition is applied + TunnelTest::check_invariants(&sut, &ref_state); + } + + TunnelTest::teardown(sut); + + Ok(()) + }, + ); + + let Err(e) = result else { + return; + }; + + match e { + TestError::Abort(msg) => panic!("Test aborted: {msg}"), + TestError::Fail(msg, (ref_state, transitions, _)) => { + eprintln!("{ref_state:#?}"); + eprintln!("{transitions:#?}"); + + panic!("{msg}") + } + } +} + +/// Initialise logging for [`TunnelTest`]. +/// +/// Log-level can be controlled with `RUST_LOG`. +/// By default, `debug` logs will be written to the `testcases/` directory for each test run. +/// This allows us to download logs from CI. +/// For stdout, only the default log filter applies. +/// +/// Finally, we install [`PanicOnErrorEvents`] into the registry. +/// An `ERROR` log is treated as a fatal error and will fail the test. +fn init_logging(ref_state: &ReferenceState, test_index: u32) -> tracing::subscriber::DefaultGuard { + tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_test_writer() + .with_timer(ref_state.flux_capacitor.clone()) + .with_filter(EnvFilter::from_default_env()), + ) + .with( + tracing_subscriber::fmt::layer() + .with_writer(std::fs::File::create(format!("testcases/{test_index}.log")).unwrap()) + .with_timer(ref_state.flux_capacitor.clone()) + .with_ansi(false) + .with_filter( + EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env() + .unwrap(), + ), + ) + .with(PanicOnErrorEvents::new(test_index)) + .set_default() } diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index bd939a264..caa93d0bd 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -6,7 +6,7 @@ use crate::tests::reference::ResourceDst; use connlib_shared::{messages::GatewayId, DomainName}; use ip_packet::IpPacket; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet, VecDeque}, + collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, VecDeque}, marker::PhantomData, net::IpAddr, sync::atomic::{AtomicBool, Ordering}, @@ -24,7 +24,7 @@ pub(crate) fn assert_icmp_packets_properties( ref_client: &RefClient, sim_client: &SimClient, sim_gateways: HashMap, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, ) { let unexpected_icmp_replies = find_unexpected_entries( &ref_client @@ -219,7 +219,7 @@ fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket<'_>, fn assert_destination_is_dns_resource( gateway_received_request: &IpPacket<'_>, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, domain: &DomainName, ) { let actual = gateway_received_request.destination(); @@ -269,7 +269,7 @@ fn assert_proxy_ip_mapping_is_stable( fn find_unexpected_entries<'a, E, K, V>( expected: &VecDeque, - actual: &'a HashMap, + actual: &'a BTreeMap, is_equal: impl Fn(&E, &K) -> bool, ) -> Vec<&'a V> { actual @@ -283,13 +283,15 @@ fn find_unexpected_entries<'a, E, K, V>( pub(crate) struct PanicOnErrorEvents { subscriber: PhantomData, has_seen_error: AtomicBool, + index: u32, } -impl Default for PanicOnErrorEvents { - fn default() -> Self { +impl PanicOnErrorEvents { + pub(crate) fn new(index: u32) -> Self { Self { - subscriber: Default::default(), + subscriber: PhantomData, has_seen_error: Default::default(), + index, } } } @@ -297,7 +299,7 @@ impl Default for PanicOnErrorEvents { impl Drop for PanicOnErrorEvents { fn drop(&mut self) { if self.has_seen_error.load(Ordering::SeqCst) { - panic!("At least one assertion failed"); + panic!("Testcase {} failed", self.index); } } } diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 9aac80baa..ac281e237 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -1,6 +1,6 @@ use super::{ - composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*, - strategies::*, stub_portal::StubPortal, transition::*, + composite_strategy::CompositeStrategy, flux_capacitor::FluxCapacitor, sim_client::*, + sim_dns::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; use crate::dns::is_subdomain; use connlib_shared::{ @@ -14,7 +14,7 @@ use domain::base::Rtype; use proptest::{prelude::*, sample}; use proptest_state_machine::ReferenceStateMachine; use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashSet}, fmt, iter, net::{IpAddr, SocketAddr}, }; @@ -22,7 +22,8 @@ use std::{ /// The reference state machine of the tunnel. /// /// This is the "expected" part of our test. -#[derive(Clone, Debug)] +#[derive(Clone, derivative::Derivative)] +#[derivative(Debug)] pub(crate) struct ReferenceState { pub(crate) client: Host, pub(crate) gateways: BTreeMap>, @@ -36,9 +37,12 @@ pub(crate) struct ReferenceState { /// All IP addresses a domain resolves to in our test. /// /// This is used to e.g. mock DNS resolution on the gateway. - pub(crate) global_dns_records: BTreeMap>, + pub(crate) global_dns_records: BTreeMap>, pub(crate) network: RoutingTable, + + #[derivative(Debug = "ignore")] + pub(crate) flux_capacitor: FluxCapacitor, } #[derive(Debug, Clone)] @@ -161,6 +165,7 @@ impl ReferenceStateMachine for ReferenceState { global_dns_records, network, drop_direct_client_traffic, + flux_capacitor: FluxCapacitor::default(), } }, ) @@ -196,7 +201,7 @@ impl ReferenceStateMachine for ReferenceState { .with(1, Just(Transition::Idle)) .with_if_not_empty(1, state.client.inner().all_resource_ids(), |resources_id| { sample::subsequence(resources_id.clone(), resources_id.len()).prop_map( - |resources_id| Transition::DisableResources(HashSet::from_iter(resources_id)), + |resources_id| Transition::DisableResources(BTreeSet::from_iter(resources_id)), ) }) .with_if_not_empty( diff --git a/rust/connlib/tunnel/src/tests/run_count_appender.rs b/rust/connlib/tunnel/src/tests/run_count_appender.rs deleted file mode 100644 index 05ac915b0..000000000 --- a/rust/connlib/tunnel/src/tests/run_count_appender.rs +++ /dev/null @@ -1,11 +0,0 @@ -use std::sync::atomic::AtomicU32; -use tracing_appender::rolling::RollingFileAppender; - -/// A file appender that rolls over to a new file for every instance that is created within the same process. -#[allow(dead_code)] -pub(crate) fn appender() -> RollingFileAppender { - static RUN_COUNT: AtomicU32 = AtomicU32::new(0); - let run_count = RUN_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - tracing_appender::rolling::never(".", format!("run_{run_count:04}.log")) -} diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index aaec04d73..d4e46a254 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -47,10 +47,10 @@ pub(crate) struct SimClient { pub(crate) dns_by_sentinel: BiMap, pub(crate) sent_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket<'static>>, - pub(crate) received_dns_responses: HashMap<(SocketAddr, QueryId), IpPacket<'static>>, + pub(crate) received_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket<'static>>, pub(crate) sent_icmp_requests: HashMap<(u16, u16), IpPacket<'static>>, - pub(crate) received_icmp_replies: HashMap<(u16, u16), IpPacket<'static>>, + pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket<'static>>, buffer: Vec, } @@ -247,7 +247,7 @@ impl SimClient { pub struct RefClient { pub(crate) id: ClientId, pub(crate) key: PrivateKey, - pub(crate) known_hosts: HashMap>, + pub(crate) known_hosts: BTreeMap>, pub(crate) tunnel_ip4: Ipv4Addr, pub(crate) tunnel_ip6: Ipv6Addr, @@ -270,7 +270,7 @@ pub struct RefClient { /// The IPs assigned to a domain by connlib are an implementation detail that we don't want to model in these tests. /// Instead, we just remember what _kind_ of records we resolved to be able to sample a matching src IP. #[derivative(Debug = "ignore")] - pub(crate) dns_records: BTreeMap>, + pub(crate) dns_records: BTreeMap>, /// Whether we are connected to the gateway serving the Internet resource. pub(crate) connected_internet_resources: bool, @@ -285,14 +285,14 @@ pub struct RefClient { /// Actively disabled resources by the UI #[derivative(Debug = "ignore")] - pub(crate) disabled_resources: HashSet, + pub(crate) disabled_resources: BTreeSet, /// The expected ICMP handshakes. /// /// This is indexed by gateway because our assertions rely on the order of the sent packets. #[derivative(Debug = "ignore")] pub(crate) expected_icmp_handshakes: - HashMap>, + BTreeMap>, /// The expected DNS handshakes. #[derivative(Debug = "ignore")] pub(crate) expected_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, @@ -508,7 +508,7 @@ impl RefClient { .find(|id| !self.disabled_resources.contains(id)) } - fn resolved_domains(&self) -> impl Iterator)> + '_ { + fn resolved_domains(&self) -> impl Iterator)> + '_ { self.dns_records .iter() .filter(|(domain, _)| self.dns_resource_by_domain(domain).is_some()) @@ -574,7 +574,7 @@ impl RefClient { pub(crate) fn resolved_ip4_for_non_resources( &self, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, ) -> Vec { self.resolved_ips_for_non_resources(global_dns_records) .filter_map(|ip| match ip { @@ -586,7 +586,7 @@ impl RefClient { pub(crate) fn resolved_ip6_for_non_resources( &self, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, ) -> Vec { self.resolved_ips_for_non_resources(global_dns_records) .filter_map(|ip| match ip { @@ -598,7 +598,7 @@ impl RefClient { fn resolved_ips_for_non_resources<'a>( &'a self, - global_dns_records: &'a BTreeMap>, + global_dns_records: &'a BTreeMap>, ) -> impl Iterator + 'a { self.dns_records .iter() @@ -736,8 +736,8 @@ fn ref_client( ) } -fn known_hosts() -> impl Strategy>> { - collection::hash_map( +fn known_hosts() -> impl Strategy>> { + collection::btree_map( domain_name(2..4).prop_map(|d| d.parse().unwrap()), collection::vec(any::(), 1..6), 0..15, diff --git a/rust/connlib/tunnel/src/tests/sim_dns.rs b/rust/connlib/tunnel/src/tests/sim_dns.rs index 583bb8cab..b8b78f505 100644 --- a/rust/connlib/tunnel/src/tests/sim_dns.rs +++ b/rust/connlib/tunnel/src/tests/sim_dns.rs @@ -18,7 +18,7 @@ use proptest::{ use snownet::Transmit; use std::{ borrow::Cow, - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, BTreeSet}, fmt, net::{IpAddr, SocketAddr}, time::Instant, @@ -50,7 +50,7 @@ pub(crate) struct SimDns {} impl SimDns { pub(crate) fn receive( &mut self, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, transmit: Transmit, _now: Instant, ) -> Option> { diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 3ed441e7a..cc1055bc7 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -13,7 +13,7 @@ use ip_packet::IpPacket; use proptest::prelude::*; use snownet::Transmit; use std::{ - collections::{BTreeMap, HashSet, VecDeque}, + collections::{BTreeMap, BTreeSet, VecDeque}, net::IpAddr, time::Instant, }; @@ -40,7 +40,7 @@ impl SimGateway { pub(crate) fn receive( &mut self, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, transmit: Transmit, now: Instant, ) -> Option> { @@ -61,7 +61,7 @@ impl SimGateway { /// Process an IP packet received on the gateway. fn on_received_packet( &mut self, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, packet: IpPacket<'_>, now: Instant, ) -> Option> { diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index fe04b479c..8984fe5a5 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -20,15 +20,16 @@ use itertools::Itertools; use prop::sample; use proptest::{collection, prelude::*}; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Duration, }; -pub(crate) fn global_dns_records() -> impl Strategy>> { +pub(crate) fn global_dns_records() -> impl Strategy>> +{ collection::btree_map( domain_name(2..4).prop_map(|d| d.parse().unwrap()), - collection::hash_set(any::(), 1..6), + collection::btree_set(any::(), 1..6), 0..5, ) } @@ -56,13 +57,13 @@ pub(crate) fn latency(max: u64) -> impl Strategy { /// Similar as in production, the portal holds a list of DNS and CIDR resources (those are also sampled from the given sites). /// Via this site mapping, these resources are implicitly assigned to a gateway. pub(crate) fn stub_portal() -> impl Strategy { - collection::hash_set(site(), 1..=3) + collection::btree_set(site(), 1..=3) .prop_flat_map(|sites| { - let cidr_resources = collection::hash_set( + let cidr_resources = collection::btree_set( cidr_resource_outside_reserved_ranges(any_site(sites.clone())), 1..5, ); - let dns_resources = collection::hash_set( + let dns_resources = collection::btree_set( prop_oneof![ non_wildcard_dns_resource(any_site(sites.clone())), star_wildcard_dns_resource(any_site(sites.clone())), @@ -75,9 +76,9 @@ pub(crate) fn stub_portal() -> impl Strategy { // Assign between 1 and 3 gateways to each site. let gateways_by_site = sites .into_iter() - .map(|site| (Just(site.id), collection::hash_set(gateway_id(), 1..=3))) + .map(|site| (Just(site.id), collection::btree_set(gateway_id(), 1..=3))) .collect::>() - .prop_map(HashMap::from_iter); + .prop_map(BTreeMap::from_iter); let gateway_selector = any::(); @@ -116,11 +117,11 @@ pub(crate) fn relays() -> impl Strategy>> { /// /// We make sure to always have at least 1 IPv4 and 1 IPv6 DNS server. pub(crate) fn dns_servers() -> impl Strategy>> { - let ip4_dns_servers = collection::hash_set( + let ip4_dns_servers = collection::btree_set( any::().prop_map(|ip| SocketAddr::from((ip, 53))), 1..4, ); - let ip6_dns_servers = collection::hash_set( + let ip6_dns_servers = collection::btree_set( any::().prop_map(|ip| SocketAddr::from((ip, 53))), 1..4, ); @@ -129,7 +130,7 @@ pub(crate) fn dns_servers() -> impl Strategy impl Strategy) -> impl Strategy { +fn any_site(sites: BTreeSet) -> impl Strategy { sample::select(Vec::from_iter(sites)) } @@ -197,8 +198,8 @@ fn double_star_wildcard_dns_resource( }) } -pub(crate) fn resolved_ips() -> impl Strategy> { - collection::hash_set( +pub(crate) fn resolved_ips() -> impl Strategy> { + collection::btree_set( prop_oneof![ dns_resource_ip4s().prop_map_into(), dns_resource_ip6s().prop_map_into() @@ -211,7 +212,7 @@ pub(crate) fn resolved_ips() -> impl Strategy> { pub(crate) fn subdomain_records( base: String, subdomains: impl Strategy, -) -> impl Strategy>> { +) -> impl Strategy>> { collection::hash_map(subdomains, resolved_ips(), 1..4).prop_map(move |subdomain_ips| { subdomain_ips .into_iter() diff --git a/rust/connlib/tunnel/src/tests/stub_portal.rs b/rust/connlib/tunnel/src/tests/stub_portal.rs index 155049900..d38e970cb 100644 --- a/rust/connlib/tunnel/src/tests/stub_portal.rs +++ b/rust/connlib/tunnel/src/tests/stub_portal.rs @@ -16,7 +16,7 @@ use proptest::{ strategy::{Just, Strategy}, }; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashSet}, iter, net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; @@ -25,12 +25,12 @@ use std::{ #[derive(Clone, derivative::Derivative)] #[derivative(Debug)] pub(crate) struct StubPortal { - gateways_by_site: HashMap>, + gateways_by_site: BTreeMap>, #[derivative(Debug = "ignore")] - sites_by_resource: HashMap, - cidr_resources: HashMap, - dns_resources: HashMap, + sites_by_resource: BTreeMap, + cidr_resources: BTreeMap, + dns_resources: BTreeMap, internet_resource: client::ResourceDescriptionInternet, #[derivative(Debug = "ignore")] @@ -39,20 +39,20 @@ pub(crate) struct StubPortal { impl StubPortal { pub(crate) fn new( - gateways_by_site: HashMap>, + gateways_by_site: BTreeMap>, gateway_selector: Selector, - cidr_resources: HashSet, - dns_resources: HashSet, + cidr_resources: BTreeSet, + dns_resources: BTreeSet, internet_resource: client::ResourceDescriptionInternet, ) -> Self { let cidr_resources = cidr_resources .into_iter() .map(|r| (r.id, r)) - .collect::>(); + .collect::>(); let dns_resources = dns_resources .into_iter() .map(|r| (r.id, r)) - .collect::>(); + .collect::>(); let cidr_sites = cidr_resources.iter().map(|(id, r)| { ( @@ -87,7 +87,9 @@ impl StubPortal { Self { gateways_by_site, gateway_selector, - sites_by_resource: HashMap::from_iter(cidr_sites.chain(dns_sites).chain(internet_site)), + sites_by_resource: BTreeMap::from_iter( + cidr_sites.chain(dns_sites).chain(internet_site), + ), cidr_resources, dns_resources, internet_resource, @@ -205,7 +207,7 @@ impl StubPortal { pub(crate) fn dns_resource_records( &self, - ) -> impl Strategy>> { + ) -> impl Strategy>> { self.dns_resources .values() .map(|resource| { @@ -221,14 +223,14 @@ impl StubPortal { } _ => resolved_ips() .prop_map(move |resolved_ips| { - HashMap::from([(address.parse().unwrap(), resolved_ips)]) + BTreeMap::from([(address.parse().unwrap(), resolved_ips)]) }) .boxed(), } }) .collect::>() .prop_map(|records| { - let mut map = HashMap::default(); + let mut map = BTreeMap::default(); for record in records { map.extend(record) diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 9815777a4..f882f0b0e 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -21,16 +21,14 @@ use connlib_shared::{ use proptest_state_machine::{ReferenceStateMachine, StateMachineTest}; use secrecy::ExposeSecret as _; use snownet::Transmit; +use std::collections::BTreeSet; use std::iter; use std::{ - collections::{BTreeMap, HashSet}, + collections::BTreeMap, net::IpAddr, time::{Duration, Instant}, }; use tracing::debug_span; -use tracing::subscriber::DefaultGuard; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::{util::SubscriberInitExt as _, EnvFilter}; /// The actual system-under-test. /// @@ -45,9 +43,6 @@ pub(crate) struct TunnelTest { drop_direct_client_traffic: bool, network: RoutingTable, - - #[allow(dead_code)] - logger: DefaultGuard, } impl StateMachineTest for TunnelTest { @@ -58,16 +53,6 @@ impl StateMachineTest for TunnelTest { fn init_test( ref_state: &::State, ) -> Self::SystemUnderTest { - let flux_capacitor = FluxCapacitor::default(); - - let logger = tracing_subscriber::fmt() - .with_test_writer() - // .with_writer(crate::tests::run_count_appender::appender()) // Useful for diffing logs between runs. - .with_timer(flux_capacitor.clone()) - .with_env_filter(EnvFilter::from_default_env()) - .finish() - .set_default(); - // Construct client, gateway and relay from the initial state. let mut client = ref_state .client @@ -107,19 +92,21 @@ impl StateMachineTest for TunnelTest { .collect::>(); // Configure client and gateway with the relays. - client.exec_mut(|c| c.update_relays(iter::empty(), relays.iter(), flux_capacitor.now())); + client.exec_mut(|c| { + c.update_relays(iter::empty(), relays.iter(), ref_state.flux_capacitor.now()) + }); for gateway in gateways.values_mut() { - gateway - .exec_mut(|g| g.update_relays(iter::empty(), relays.iter(), flux_capacitor.now())); + gateway.exec_mut(|g| { + g.update_relays(iter::empty(), relays.iter(), ref_state.flux_capacitor.now()) + }); } let mut this = Self { - flux_capacitor, + flux_capacitor: ref_state.flux_capacitor.clone(), network: ref_state.network.clone(), drop_direct_client_traffic: ref_state.drop_direct_client_traffic, client, gateways, - logger, relays, dns_servers, }; @@ -345,16 +332,6 @@ impl StateMachineTest for TunnelTest { state: &Self::SystemUnderTest, ref_state: &::State, ) { - let _guard = tracing_subscriber::registry() - .with( - tracing_subscriber::fmt::layer() - .with_test_writer() - .with_timer(state.flux_capacitor.clone()), - ) - .with(PanicOnErrorEvents::default()) // Temporarily install a layer that panics when `_guard` goes out of scope if any of our assertions emitted an error. - .with(EnvFilter::from_default_env()) - .set_default(); - let ref_client = ref_state.client.inner(); let sim_client = state.client.inner(); let sim_gateways = state @@ -495,7 +472,7 @@ impl TunnelTest { fn handle_timeout( &mut self, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, buffered_transmits: &mut BufferedTransmits, ) { let now = self.flux_capacitor.now(); @@ -624,7 +601,7 @@ impl TunnelTest { src: ClientId, event: ClientEvent, portal: &StubPortal, - global_dns_records: &BTreeMap>, + global_dns_records: &BTreeMap>, ) { let now = self.flux_capacitor.now(); diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 9359e6de5..5a04ebba7 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -10,7 +10,7 @@ use domain::base::Rtype; use prop::collection; use proptest::{prelude::*, sample}; use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, }; @@ -24,7 +24,7 @@ pub(crate) enum Transition { /// Deactivate a resource on the client. DeactivateResource(ResourceId), /// Client-side disable resource - DisableResources(HashSet), + DisableResources(BTreeSet), /// Send an ICMP packet to non-resource IP. SendICMPPacketToNonResourceIp { src: IpAddr, diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index 0fc26167b..5b418bcdb 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -14,7 +14,7 @@ use futures::{ task::{Context, Poll}, Future as _, SinkExt as _, Stream as _, }; -use std::{collections::HashSet, net::IpAddr, path::PathBuf, pin::pin, sync::Arc, time::Duration}; +use std::{collections::BTreeSet, net::IpAddr, path::PathBuf, pin::pin, sync::Arc, time::Duration}; use tokio::{sync::mpsc, time::Instant}; use tracing::subscriber::set_global_default; use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry}; @@ -76,7 +76,7 @@ pub enum ClientMsg { Disconnect, Reset, SetDns(Vec), - SetDisabledResources(HashSet), + SetDisabledResources(BTreeSet), } /// Only called from the GUI Client's build of the IPC service diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index dcb5f758e..e5d1d29cc 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -45,7 +45,7 @@ macro_rules! swap_src_dst { }; } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum Protocol { /// Contains either the source or destination port. Tcp(u16), diff --git a/rust/relay/src/lib.rs b/rust/relay/src/lib.rs index b5eab3d25..ab250ee86 100644 --- a/rust/relay/src/lib.rs +++ b/rust/relay/src/lib.rs @@ -89,7 +89,7 @@ impl From<(Option, Option)> for IpStack { /// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.4): /// /// > A STUN client that implements this specification. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)] pub struct ClientSocket(SocketAddr); impl ClientSocket { diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index e7f84bd37..11becfcd0 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -18,7 +18,7 @@ use opentelemetry::KeyValue; use rand::Rng; use secrecy::SecretString; use smallvec::SmallVec; -use std::collections::{HashMap, VecDeque}; +use std::collections::{BTreeMap, HashMap, VecDeque}; use std::hash::Hash; use std::net::{IpAddr, SocketAddr}; use std::ops::RangeInclusive; @@ -68,7 +68,7 @@ pub struct Server { ports: RangeInclusive, /// Channel numbers are unique by client, thus indexed by both. - channels_by_client_and_number: HashMap<(ClientSocket, ChannelNumber), Channel>, + channels_by_client_and_number: BTreeMap<(ClientSocket, ChannelNumber), Channel>, /// Channel numbers are unique between clients and peers, thus indexed by both. channel_numbers_by_client_and_peer: HashMap<(ClientSocket, PeerSocket), ChannelNumber>,