diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml index 5743d7264..a827a11be 100644 --- a/.github/workflows/_rust.yml +++ b/.github/workflows/_rust.yml @@ -105,9 +105,13 @@ jobs: cargo test --all-features ${{ steps.setup-rust.outputs.packages }} -- --include-ignored --nocapture # Poor man's test coverage testing: Grep the generated logs for specific patterns / lines. - rg --count --no-ignore SendICMPPacketToCidrResource $TESTCASES_DIR - rg --count --no-ignore SendICMPPacketToDnsResource $TESTCASES_DIR + rg --count --no-ignore SendIcmpPacket $TESTCASES_DIR + rg --count --no-ignore SendUdpPacket $TESTCASES_DIR + rg --count --no-ignore SendTcpPayload $TESTCASES_DIR rg --count --no-ignore SendDnsQueries $TESTCASES_DIR + rg --count --no-ignore "Packet for DNS resource" $TESTCASES_DIR + rg --count --no-ignore "Packet for CIDR resource" $TESTCASES_DIR + rg --count --no-ignore "Packet for Internet resource" $TESTCASES_DIR rg --count --no-ignore "Performed IP-NAT46" $TESTCASES_DIR rg --count --no-ignore "Performed IP-NAT64" $TESTCASES_DIR diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index e0f520424..da0a2bb15 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -793,17 +793,28 @@ impl ClientState { let maybe_dns_resource_id = self .stub_resolver .resolve_resource_by_ip(&destination) - .filter(|resource| self.is_resource_enabled(resource)); + .filter(|resource| self.is_resource_enabled(resource)) + .inspect( + |resource| tracing::trace!(target: "tunnel_test_coverage", %destination, %resource, "Packet for DNS resource"), + ); // We don't need to filter from here because resources are removed from the active_cidr_resources as soon as they are disabled. let maybe_cidr_resource_id = self .active_cidr_resources .longest_match(destination) - .map(|(_, res)| res.id); + .map(|(_, res)| res.id) + .inspect( + |resource| tracing::trace!(target: "tunnel_test_coverage", %destination, %resource, "Packet for CIDR resource"), + ); maybe_dns_resource_id .or(maybe_cidr_resource_id) .or(self.internet_resource) + .inspect(|r| { + if Some(*r) == self.internet_resource { + tracing::trace!(target: "tunnel_test_coverage", %destination, "Packet for Internet resource") + } + }) } pub fn update_system_resolvers(&mut self, new_dns: Vec) { diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 64634855f..8ae960a63 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -29,8 +29,6 @@ mod sut; mod transition; type QueryId = u16; -type IcmpSeq = u16; -type IcmpIdentifier = u16; #[test] #[expect(clippy::print_stdout, clippy::print_stderr)] @@ -197,7 +195,7 @@ fn init_logging( fn log_file_filter() -> EnvFilter { let default_filter = - "debug,firezone_tunnel=trace,firezone_tunnel::tests=debug,ip_packet=trace".to_owned(); + "debug,firezone_tunnel=trace,firezone_tunnel::tests=debug,tunnel_test_coverage=trace,ip_packet=trace".to_owned(); let env_filter = std::env::var("RUST_LOG").unwrap_or_default(); EnvFilter::new([default_filter, env_filter].join(",")) diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index 51f69ff26..a77db634f 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -1,18 +1,19 @@ use super::{ sim_client::{RefClient, SimClient}, sim_gateway::SimGateway, + transition::{Destination, ReplyTo}, }; -use crate::tests::reference::ResourceDst; use connlib_model::{DomainName, GatewayId}; use ip_packet::IpPacket; use itertools::Itertools; use std::{ collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, VecDeque}, + hash::Hash, marker::PhantomData, net::IpAddr, sync::atomic::{AtomicBool, Ordering}, }; -use tracing::{Level, Subscriber}; +use tracing::{Level, Span, Subscriber}; use tracing_subscriber::Layer; /// Asserts the following properties for all ICMP handshakes: @@ -24,63 +25,144 @@ use tracing_subscriber::Layer; pub(crate) fn assert_icmp_packets_properties( ref_client: &RefClient, sim_client: &SimClient, - sim_gateways: HashMap, + sim_gateways: &BTreeMap, global_dns_records: &BTreeMap>, ) { - let unexpected_icmp_replies = find_unexpected_entries( - &ref_client - .expected_icmp_handshakes - .values() - .flatten() - .collect(), + let received_icmp_requests = sim_gateways + .iter() + .map(|(g, s)| (*g, &s.received_icmp_requests)) + .collect(); + + assert_packets_properties( + ref_client, + &sim_client.sent_icmp_requests, + &received_icmp_requests, + &ref_client.expected_icmp_handshakes, &sim_client.received_icmp_replies, - |(_, (_, seq_a, id_a)), (seq_b, id_b)| seq_a == seq_b && id_a == id_b, + "ICMP", + global_dns_records, + |seq, identifier| tracing::info_span!(target: "assertions", "ICMP", ?seq, ?identifier), + ); +} + +/// Asserts the following properties for all UDP handshakes: +/// 1. An UDP request on the client MUST result in an UDP response using the flipped src & dst IP and sport and dport. +/// 2. An UDP request on the gateway MUST target the intended resource: +/// - For CIDR resources, that is the actual CIDR resource IP. +/// - For DNS resources, the IP must match one of the resolved IPs for the domain. +/// 3. For DNS resources, the mapping of proxy IP to actual resource IP must be stable. +pub(crate) fn assert_udp_packets_properties( + ref_client: &RefClient, + sim_client: &SimClient, + sim_gateways: &BTreeMap, + global_dns_records: &BTreeMap>, +) { + let received_udp_requests = sim_gateways + .iter() + .map(|(g, s)| (*g, &s.received_udp_requests)) + .collect(); + + assert_packets_properties( + ref_client, + &sim_client.sent_udp_requests, + &received_udp_requests, + &ref_client.expected_udp_handshakes, + &sim_client.received_udp_replies, + "UDP", + global_dns_records, + |sport, dport| tracing::info_span!(target: "assertions", "UDP", ?sport, ?dport), + ); +} + +/// Asserts the following properties for all TCP handshakes: +/// 1. An TCP request on the client MUST result in an TCP response using the flipped src & dst IP and sport and dport. +/// 2. An TCP request on the gateway MUST target the intended resource: +/// - For CIDR resources, that is the actual CIDR resource IP. +/// - For DNS resources, the IP must match one of the resolved IPs for the domain. +/// 3. For DNS resources, the mapping of proxy IP to actual resource IP must be stable. +pub(crate) fn assert_tcp_packets_properties( + ref_client: &RefClient, + sim_client: &SimClient, + sim_gateways: &BTreeMap, + global_dns_records: &BTreeMap>, +) { + let received_tcp_requests = sim_gateways + .iter() + .map(|(g, s)| (*g, &s.received_tcp_requests)) + .collect(); + + assert_packets_properties( + ref_client, + &sim_client.sent_tcp_requests, + &received_tcp_requests, + &ref_client.expected_tcp_exchanges, + &sim_client.received_tcp_replies, + "TCP", + global_dns_records, + |sport, dport| tracing::info_span!(target: "assertions", "TCP", ?sport, ?dport), + ); +} + +#[expect(clippy::too_many_arguments)] +fn assert_packets_properties( + ref_client: &RefClient, + sent_requests: &HashMap<(T, U), IpPacket>, + received_requests: &BTreeMap>, + expected_handshakes: &BTreeMap>, + received_replies: &BTreeMap<(T, U), IpPacket>, + packet_protocol: &str, + global_dns_records: &BTreeMap>, + make_span: impl Fn(T, U) -> Span, +) where + T: Copy + std::fmt::Debug, + U: Copy + std::fmt::Debug, + (T, U): ReplyTo + Hash + Eq + Ord, +{ + let unexpected_replies = find_unexpected_entries( + &expected_handshakes.values().flatten().collect(), + received_replies, + |(_, (_, t_a, u_a)), b| (*t_a, *u_a) == b.reply_to(), ); - if !unexpected_icmp_replies.is_empty() { - tracing::error!(target: "assertions", ?unexpected_icmp_replies, "❌ Unexpected ICMP replies on client"); + if !unexpected_replies.is_empty() { + tracing::error!(target: "assertions", ?unexpected_replies, "❌ Unexpected {packet_protocol} replies on client"); } - for (gid, expected_icmp_handshakes) in ref_client.expected_icmp_handshakes.iter() { - let gateway = sim_gateways.get(gid).unwrap(); + for (gid, expected_handshakes) in expected_handshakes.iter() { + let received_requests = received_requests.get(gid).unwrap(); - let num_expected_handshakes = expected_icmp_handshakes.len(); - let num_actual_handshakes = gateway.received_icmp_requests.len(); + let num_expected_handshakes = expected_handshakes.len(); + let num_actual_handshakes = received_requests.len(); if num_expected_handshakes != num_actual_handshakes { - tracing::error!(target: "assertions", %num_expected_handshakes, %num_actual_handshakes, %gid, "❌ Unexpected ICMP requests"); + tracing::error!(target: "assertions", %num_expected_handshakes, %num_actual_handshakes, %gid, "❌ Unexpected {packet_protocol} requests"); } else { - tracing::info!(target: "assertions", %num_expected_handshakes, %gid, "✅ Performed the expected ICMP handshakes"); + tracing::info!(target: "assertions", %num_expected_handshakes, %gid, "✅ Performed the expected {packet_protocol} handshakes"); } } let mut mapping = HashMap::new(); - // Assert properties of the individual ICMP handshakes per gateway. + // Assert properties of the individual handshakes per gateway. // Due to connlib's implementation of NAT64, we cannot match the packets sent by the client to the packets arriving at the resource by port or ICMP identifier. - // Thus, we rely on the _order_ here which is why the packets are indexed by gateway in the `RefClient`. - for (gateway, expected_icmp_handshakes) in &ref_client.expected_icmp_handshakes { - let received_icmp_requests = &sim_gateways.get(gateway).unwrap().received_icmp_requests; + // Thus, we rely on a custom u64 payload attached to all packets to uniquely identify every individual packet. + for (gateway, expected_handshakes) in expected_handshakes { + let received_requests = received_requests.get(gateway).unwrap(); + for (payload, (resource_dst, t, u)) in expected_handshakes { + let _guard = make_span(*t, *u).entered(); - for (payload, (resource_dst, seq, identifier)) in expected_icmp_handshakes { - let _guard = - tracing::info_span!(target: "assertions", "icmp", %seq, %identifier).entered(); - - let Some(client_sent_request) = sim_client.sent_icmp_requests.get(&(*seq, *identifier)) - else { - tracing::error!(target: "assertions", "❌ Missing ICMP request on client"); + let Some(client_sent_request) = sent_requests.get(&(*t, *u)) else { + tracing::error!(target: "assertions", "❌ Missing {packet_protocol} request on client"); continue; }; - let Some(client_received_reply) = - sim_client.received_icmp_replies.get(&(*seq, *identifier)) - else { - tracing::error!(target: "assertions", "❌ Missing ICMP reply on client"); + let Some(client_received_reply) = received_replies.get(&(*t, *u).reply_to()) else { + tracing::error!(target: "assertions", "❌ Missing {packet_protocol} reply on client"); continue; }; assert_correct_src_and_dst_ips(client_sent_request, client_received_reply); - let Some(gateway_received_request) = received_icmp_requests.get(payload) else { - tracing::error!(target: "assertions", "❌ Missing ICMP request on gateway"); + let Some(gateway_received_request) = received_requests.get(payload) else { + tracing::error!(target: "assertions", "❌ Missing {packet_protocol} request on gateway"); continue; }; @@ -89,19 +171,19 @@ pub(crate) fn assert_icmp_packets_properties( let actual = gateway_received_request.source(); if expected != actual { - tracing::error!(target: "assertions", %expected, %actual, "❌ Unexpected request source"); + tracing::error!(target: "assertions", %expected, %actual, "❌ Unexpected {packet_protocol} request source"); } } match resource_dst { - ResourceDst::Cidr(resource_dst) => { + Destination::IpAddr(resource_dst) => { assert_destination_is_cdir_resource(gateway_received_request, resource_dst) } - ResourceDst::Dns(domain) => { + Destination::DomainName { name, .. } => { assert_destination_is_dns_resource( gateway_received_request, global_dns_records, - domain, + name, ); assert_proxy_ip_mapping_is_stable( @@ -110,9 +192,6 @@ pub(crate) fn assert_icmp_packets_properties( &mut mapping, ) } - ResourceDst::Internet(resource_dst) => { - assert_destination_is_cdir_resource(gateway_received_request, resource_dst) - } } } } @@ -324,11 +403,11 @@ fn assert_proxy_ip_mapping_is_stable( fn find_unexpected_entries<'a, E, K, V>( expected: &VecDeque, actual: &'a BTreeMap, - is_equal: impl Fn(&E, &K) -> bool, + is_expected: impl Fn(&E, &K) -> bool, ) -> Vec<&'a V> { actual .iter() - .filter(|(k, _)| !expected.iter().any(|e| is_equal(e, k))) + .filter(|(k, _)| !expected.iter().any(|e| is_expected(e, k))) .map(|(_, v)| v) .collect() } diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 65851710a..30236410e 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -6,7 +6,10 @@ use crate::{client, DomainName}; use crate::{dns::is_subdomain, proptest::relay_id}; use connlib_model::{GatewayId, RelayId, ResourceId, StaticSecret}; use domain::base::Rtype; +use ip_network::{Ipv4Network, Ipv6Network}; +use prop::sample::select; use proptest::{prelude::*, sample}; +use std::net::{Ipv4Addr, Ipv6Addr}; use std::{ collections::{BTreeMap, BTreeSet, HashSet}, fmt, iter, @@ -35,13 +38,6 @@ pub(crate) struct ReferenceState { pub(crate) network: RoutingTable, } -#[derive(Debug, Clone)] -pub(crate) enum ResourceDst { - Internet(IpAddr), - Cidr(IpAddr), - Dns(DomainName), -} - /// Implementation of our reference state machine. /// /// The logic in here represents what we expect the [`ClientState`] & [`GatewayState`] to do. @@ -189,40 +185,52 @@ impl ReferenceState { 10, state.client.inner().ipv4_cidr_resource_dsts(), |ip4_resources| { - icmp_to_cidr_resource( - packet_source_v4(state.client.inner().tunnel_ip4), - sample::select(ip4_resources).prop_flat_map(crate::proptest::host_v4), - ) + let tunnel_ip4 = state.client.inner().tunnel_ip4; + + prop_oneof![ + icmp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)), + udp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)), + tcp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)), + ] }, ) .with_if_not_empty( 10, state.client.inner().ipv6_cidr_resource_dsts(), |ip6_resources| { - icmp_to_cidr_resource( - packet_source_v6(state.client.inner().tunnel_ip6), - sample::select(ip6_resources).prop_flat_map(crate::proptest::host_v6), - ) + let tunnel_ip6 = state.client.inner().tunnel_ip6; + + prop_oneof![ + icmp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)), + udp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)), + tcp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)), + ] }, ) .with_if_not_empty( 10, state.client.inner().resolved_v4_domains(), |dns_v4_domains| { - icmp_to_dns_resource( - packet_source_v4(state.client.inner().tunnel_ip4), - sample::select(dns_v4_domains), - ) + let tunnel_ip4 = state.client.inner().tunnel_ip4; + + prop_oneof![ + icmp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains.clone())), + udp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains.clone())), + tcp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains)), + ] }, ) .with_if_not_empty( 10, state.client.inner().resolved_v6_domains(), |dns_v6_domains| { - icmp_to_dns_resource( - packet_source_v6(state.client.inner().tunnel_ip6), - sample::select(dns_v6_domains), - ) + let tunnel_ip6 = state.client.inner().tunnel_ip6; + + prop_oneof![ + icmp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains.clone()),), + udp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains.clone()),), + tcp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains),), + ] }, ) .with_if_not_empty( @@ -240,10 +248,22 @@ impl ReferenceState { .inner() .resolved_ip4_for_non_resources(&state.global_dns_records), |resolved_non_resource_ip4s| { - ping_random_ip( - packet_source_v4(state.client.inner().tunnel_ip4), - sample::select(resolved_non_resource_ip4s), - ) + let tunnel_ip4 = state.client.inner().tunnel_ip4; + + prop_oneof![ + icmp_packet( + packet_source_v4(tunnel_ip4), + select(resolved_non_resource_ip4s.clone()), + ), + udp_packet( + packet_source_v4(tunnel_ip4), + select(resolved_non_resource_ip4s.clone()), + ), + tcp_packet( + packet_source_v4(tunnel_ip4), + select(resolved_non_resource_ip4s), + ), + ] }, ) .with_if_not_empty( @@ -253,10 +273,22 @@ impl ReferenceState { .inner() .resolved_ip6_for_non_resources(&state.global_dns_records), |resolved_non_resource_ip6s| { - ping_random_ip( - packet_source_v6(state.client.inner().tunnel_ip6), - sample::select(resolved_non_resource_ip6s), - ) + let tunnel_ip6 = state.client.inner().tunnel_ip6; + + prop_oneof![ + icmp_packet( + packet_source_v6(tunnel_ip6), + select(resolved_non_resource_ip6s.clone()), + ), + udp_packet( + packet_source_v6(tunnel_ip6), + select(resolved_non_resource_ip6s.clone()), + ), + tcp_packet( + packet_source_v6(tunnel_ip6), + select(resolved_non_resource_ip6s), + ), + ] }, ) .boxed() @@ -379,7 +411,7 @@ impl ReferenceState { } } } - Transition::SendICMPPacketToNonResourceIp { + Transition::SendIcmpPacket { src, dst, seq, @@ -387,44 +419,37 @@ impl ReferenceState { payload, } => { state.client.exec_mut(|client| { - // If the Internet Resource is active, all packets are expected to be routed. - if client.active_internet_resource().is_some() { - client.on_icmp_packet_to_internet( - *src, - *dst, - *seq, - *identifier, - *payload, - |r| state.portal.gateway_for_resource(r).copied(), - ) - } - }); - } - Transition::SendICMPPacketToCidrResource { - src, - dst, - seq, - identifier, - payload, - } => { - state.client.exec_mut(|client| { - client.on_icmp_packet_to_cidr(*src, *dst, *seq, *identifier, *payload, |r| { + client.on_icmp_packet(*src, dst.clone(), *seq, *identifier, *payload, |r| { state.portal.gateway_for_resource(r).copied() }) }); } - Transition::SendICMPPacketToDnsResource { + Transition::SendUdpPacket { src, dst, - seq, - identifier, + sport, + dport, payload, - .. - } => state.client.exec_mut(|client| { - client.on_icmp_packet_to_dns(*src, dst.clone(), *seq, *identifier, *payload, |r| { - state.portal.gateway_for_resource(r).copied() - }) - }), + } => { + state.client.exec_mut(|client| { + client.on_udp_packet(*src, dst.clone(), *sport, *dport, *payload, |r| { + state.portal.gateway_for_resource(r).copied() + }) + }); + } + Transition::SendTcpPayload { + src, + dst, + sport, + dport, + payload, + } => { + state.client.exec_mut(|client| { + client.on_tcp_packet(*src, dst.clone(), *sport, *dport, *payload, |r| { + state.portal.gateway_for_resource(r).copied() + }) + }); + } Transition::UpdateSystemDnsServers(servers) => { state .client @@ -487,61 +512,75 @@ impl ReferenceState { .iter() .all(|r| state.client.inner().has_resource(*r)) } - Transition::SendICMPPacketToNonResourceIp { - dst, - seq, - identifier, - payload, - .. - } => { - let is_valid_icmp_packet = state - .client - .inner() - .is_valid_icmp_packet(seq, identifier, payload); - let is_cidr_resource = state.client.inner().cidr_resource_by_ip(*dst).is_some(); - - is_valid_icmp_packet && !is_cidr_resource - } - Transition::SendICMPPacketToCidrResource { - seq, - identifier, - dst, - payload, - .. - } => { - let ref_client = state.client.inner(); - let Some(rid) = ref_client.cidr_resource_by_ip(*dst) else { - return false; - }; - let Some(gateway) = state.portal.gateway_for_resource(rid) else { - return false; - }; - - ref_client.is_valid_icmp_packet(seq, identifier, payload) - && state.gateways.contains_key(gateway) - } - Transition::SendICMPPacketToDnsResource { - seq, - identifier, - dst, + Transition::SendIcmpPacket { src, + dst: Destination::DomainName { name, .. }, + seq, + identifier, + payload, + } => { + let ref_client = state.client.inner(); + + ref_client.is_valid_icmp_packet(seq, identifier, payload) + && state.is_valid_dst_domain(name, src) + } + Transition::SendUdpPacket { + src, + dst: Destination::DomainName { name, .. }, + sport, + dport, + payload, + } => { + let ref_client = state.client.inner(); + + ref_client.is_valid_udp_packet(sport, dport, payload) + && state.is_valid_dst_domain(name, src) + } + Transition::SendTcpPayload { + src, + dst: Destination::DomainName { name, .. }, + sport, + dport, + payload, + } => { + let ref_client = state.client.inner(); + + ref_client.is_valid_tcp_packet(sport, dport, payload) + && state.is_valid_dst_domain(name, src) + } + Transition::SendIcmpPacket { + dst: Destination::IpAddr(dst), + seq, + identifier, payload, .. } => { let ref_client = state.client.inner(); - let Some(resource) = ref_client.dns_resource_by_domain(dst) else { - return false; - }; - let Some(gateway) = state.portal.gateway_for_resource(resource) else { - return false; - }; ref_client.is_valid_icmp_packet(seq, identifier, payload) - && ref_client.dns_records.get(dst).is_some_and(|r| match src { - IpAddr::V4(_) => r.contains(&Rtype::A), - IpAddr::V6(_) => r.contains(&Rtype::AAAA), - }) - && state.gateways.contains_key(gateway) + && state.is_valid_dst_ip(*dst) + } + Transition::SendUdpPacket { + dst: Destination::IpAddr(dst), + sport, + dport, + payload, + .. + } => { + let ref_client = state.client.inner(); + + ref_client.is_valid_udp_packet(sport, dport, payload) && state.is_valid_dst_ip(*dst) + } + Transition::SendTcpPayload { + dst: Destination::IpAddr(dst), + sport, + dport, + payload, + .. + } => { + let ref_client = state.client.inner(); + + ref_client.is_valid_tcp_packet(sport, dport, payload) && state.is_valid_dst_ip(*dst) } Transition::UpdateSystemDnsServers(servers) => { if servers.is_empty() { @@ -624,6 +663,37 @@ impl ReferenceState { Transition::PartitionRelaysFromPortal => true, } } + + fn is_valid_dst_ip(&self, dst: IpAddr) -> bool { + let Some(rid) = self.client.inner().cidr_resource_by_ip(dst) else { + // As long as the packet is valid it's always valid to send to a non-resource + return true; + }; + let Some(gateway) = self.portal.gateway_for_resource(rid) else { + return false; + }; + + self.gateways.contains_key(gateway) + } + + fn is_valid_dst_domain(&self, name: &DomainName, src: &IpAddr) -> bool { + let Some(resource) = self.client.inner().dns_resource_by_domain(name) else { + return false; + }; + let Some(gateway) = self.portal.gateway_for_resource(resource) else { + return false; + }; + + self.client + .inner() + .dns_records + .get(name) + .is_some_and(|r| match src { + IpAddr::V4(_) => r.contains(&Rtype::A), + IpAddr::V6(_) => r.contains(&Rtype::AAAA), + }) + && self.gateways.contains_key(gateway) + } } /// Several helper functions to make the reference state more readable. @@ -680,6 +750,14 @@ impl ReferenceState { } } +fn select_host_v4(hosts: &[Ipv4Network]) -> impl Strategy { + sample::select(hosts.to_vec()).prop_flat_map(crate::proptest::host_v4) +} + +fn select_host_v6(hosts: &[Ipv6Network]) -> impl Strategy { + sample::select(hosts.to_vec()).prop_flat_map(crate::proptest::host_v6) +} + pub(crate) fn private_key() -> impl Strategy { any::<[u8; 32]>().prop_map(PrivateKey) } diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 587c7d906..78301627a 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -1,10 +1,10 @@ use super::{ - reference::{private_key, PrivateKey, ResourceDst}, + reference::{private_key, PrivateKey}, sim_net::{any_ip_stack, any_port, host, Host}, sim_relay::{map_explode, SimRelay}, strategies::latency, - transition::{DnsQuery, DnsTransport}, - IcmpIdentifier, IcmpSeq, QueryId, + transition::{DPort, Destination, DnsQuery, DnsTransport, Identifier, SPort, Seq}, + QueryId, }; use crate::{ client::{CidrResource, DnsResource, InternetResource, Resource}, @@ -55,8 +55,14 @@ pub(crate) struct SimClient { pub(crate) sent_tcp_dns_queries: HashSet<(SocketAddr, QueryId)>, pub(crate) received_tcp_dns_responses: BTreeSet<(SocketAddr, QueryId)>, - pub(crate) sent_icmp_requests: HashMap<(u16, u16), IpPacket>, - pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket>, + pub(crate) sent_icmp_requests: HashMap<(Seq, Identifier), IpPacket>, + pub(crate) received_icmp_replies: BTreeMap<(Seq, Identifier), IpPacket>, + + pub(crate) sent_tcp_requests: HashMap<(SPort, DPort), IpPacket>, + pub(crate) received_tcp_replies: BTreeMap<(SPort, DPort), IpPacket>, + + pub(crate) sent_udp_requests: HashMap<(SPort, DPort), IpPacket>, + pub(crate) received_udp_replies: BTreeMap<(SPort, DPort), IpPacket>, pub(crate) tcp_dns_client: dns_over_tcp::Client, @@ -79,6 +85,10 @@ impl SimClient { received_tcp_dns_responses: Default::default(), sent_icmp_requests: Default::default(), received_icmp_replies: Default::default(), + sent_tcp_requests: Default::default(), + received_tcp_replies: Default::default(), + sent_udp_requests: Default::default(), + received_udp_replies: Default::default(), enc_buffer: Default::default(), ipv4_routes: Default::default(), ipv6_routes: Default::default(), @@ -177,19 +187,7 @@ impl SimClient { packet: IpPacket, now: Instant, ) -> Option> { - if let Some(icmp) = packet.as_icmpv4() { - if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { - self.sent_icmp_requests - .insert((echo.seq, echo.id), packet.clone()); - } - } - - if let Some(icmp) = packet.as_icmpv6() { - if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { - self.sent_icmp_requests - .insert((echo.seq, echo.id), packet.clone()); - } - } + self.update_sent_requests(&packet); let Some(enc_packet) = self.sut.handle_tun_input(packet, now, &mut self.enc_buffer) else { self.sut.handle_timeout(now); // If we handled the packet internally, make sure to advance state. @@ -199,6 +197,43 @@ impl SimClient { Some(enc_packet.to_transmit(&self.enc_buffer).into_owned()) } + fn update_sent_requests(&mut self, packet: &IpPacket) { + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { + self.sent_icmp_requests + .insert((Seq(echo.seq), Identifier(echo.id)), packet.clone()); + return; + } + } + + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { + self.sent_icmp_requests + .insert((Seq(echo.seq), Identifier(echo.id)), packet.clone()); + return; + } + } + + if let Some(tcp) = packet.as_tcp() { + self.sent_tcp_requests.insert( + (SPort(tcp.source_port()), DPort(tcp.destination_port())), + packet.clone(), + ); + return; + } + + if let Some(udp) = packet.as_udp() { + self.sent_udp_requests.insert( + (SPort(udp.source_port()), DPort(udp.destination_port())), + packet.clone(), + ); + + return; + } + + tracing::error!("Sent a request with an unknown transport protocol"); + } + pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { let Some(packet) = self.sut.handle_network_input( transmit.dst, @@ -215,29 +250,6 @@ impl SimClient { /// Process an IP packet received on the client. pub(crate) fn on_received_packet(&mut self, packet: IpPacket) { - if let Some(icmp) = packet.as_icmpv4() { - if let Icmpv4Type::EchoReply(echo) = icmp.icmp_type() { - self.received_icmp_replies - .insert((echo.seq, echo.id), packet.clone()); - - return; - } - } - - if let Some(icmp) = packet.as_icmpv6() { - if let Icmpv6Type::EchoReply(echo) = icmp.icmp_type() { - self.received_icmp_replies - .insert((echo.seq, echo.id), packet.clone()); - - return; - } - } - - if self.tcp_dns_client.accepts(&packet) { - self.tcp_dns_client.handle_inbound(packet); - return; - } - if let Some(udp) = packet.as_udp() { if udp.source_port() == 53 { let message = Message::from_slice(udp.payload()) @@ -256,6 +268,41 @@ impl SimClient { return; } + + self.received_udp_replies.insert( + (SPort(udp.source_port()), DPort(udp.destination_port())), + packet.clone(), + ); + return; + } + + if self.tcp_dns_client.accepts(&packet) { + self.tcp_dns_client.handle_inbound(packet); + return; + } + + if let Some(tcp) = packet.as_tcp() { + self.received_tcp_replies.insert( + (SPort(tcp.source_port()), DPort(tcp.destination_port())), + packet.clone(), + ); + return; + } + + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoReply(echo) = icmp.icmp_type() { + self.received_icmp_replies + .insert((Seq(echo.seq), Identifier(echo.id)), packet.clone()); + return; + } + } + + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoReply(echo) = icmp.icmp_type() { + self.received_icmp_replies + .insert((Seq(echo.seq), Identifier(echo.id)), packet.clone()); + return; + } } tracing::error!(?packet, "Unhandled packet"); @@ -376,7 +423,18 @@ pub struct RefClient { /// The expected ICMP handshakes. #[derivative(Debug = "ignore")] pub(crate) expected_icmp_handshakes: - BTreeMap>, + BTreeMap>, + + /// The expected UDP handshakes. + #[derivative(Debug = "ignore")] + pub(crate) expected_udp_handshakes: + BTreeMap>, + + /// The expected TCP exchanges. + #[derivative(Debug = "ignore")] + pub(crate) expected_tcp_exchanges: + BTreeMap>, + /// The expected UDP DNS handshakes. #[derivative(Debug = "ignore")] pub(crate) expected_udp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, @@ -517,100 +575,74 @@ impl RefClient { } } - #[tracing::instrument(level = "debug", skip_all, fields(dst, resource))] - pub(crate) fn on_icmp_packet_to_internet( + pub(crate) fn on_icmp_packet( &mut self, src: IpAddr, - dst: IpAddr, - seq: u16, - identifier: u16, + dst: Destination, + seq: Seq, + identifier: Identifier, payload: u64, gateway_by_resource: impl Fn(ResourceId) -> Option, ) { - tracing::Span::current().record("dst", tracing::field::display(dst)); + self.on_packet( + src, + dst.clone(), + (dst, seq, identifier), + |ref_client| &mut ref_client.expected_icmp_handshakes, + payload, + gateway_by_resource, + ); + } - // Second, if we are not yet connected, check if we have a resource for this IP. - let Some(rid) = self.active_internet_resource() else { - tracing::debug!("No internet resource"); - return; - }; - tracing::Span::current().record("resource", tracing::field::display(rid)); + pub(crate) fn on_udp_packet( + &mut self, + src: IpAddr, + dst: Destination, + sport: SPort, + dport: DPort, + payload: u64, + gateway_by_resource: impl Fn(ResourceId) -> Option, + ) { + self.on_packet( + src, + dst.clone(), + (dst, sport, dport), + |ref_client| &mut ref_client.expected_udp_handshakes, + payload, + gateway_by_resource, + ); + } - let Some(gateway) = gateway_by_resource(rid) else { - tracing::error!("No gateway for resource"); - return; - }; - - if self.is_connected_to_internet(rid) && self.is_tunnel_ip(src) { - tracing::debug!("Connected to Internet resource, expecting packet to be routed"); - self.expected_icmp_handshakes - .entry(gateway) - .or_default() - .insert(payload, (ResourceDst::Internet(dst), seq, identifier)); - return; - } - - // If we have a resource, the first packet will initiate a connection to the gateway. - tracing::debug!("Not connected to resource, expecting to trigger connection intent"); - self.connected_internet_resource = true; + pub(crate) fn on_tcp_packet( + &mut self, + src: IpAddr, + dst: Destination, + sport: SPort, + dport: DPort, + payload: u64, + gateway_by_resource: impl Fn(ResourceId) -> Option, + ) { + self.on_packet( + src, + dst.clone(), + (dst, sport, dport), + |ref_client| &mut ref_client.expected_tcp_exchanges, + payload, + gateway_by_resource, + ); } #[tracing::instrument(level = "debug", skip_all, fields(dst, resource))] - pub(crate) fn on_icmp_packet_to_cidr( + fn on_packet( &mut self, src: IpAddr, - dst: IpAddr, - seq: u16, - identifier: u16, + dst: Destination, + packet_id: E, + map: impl FnOnce(&mut Self) -> &mut BTreeMap>, payload: u64, gateway_by_resource: impl Fn(ResourceId) -> Option, ) { - tracing::Span::current().record("dst", tracing::field::display(dst)); - - // Second, if we are not yet connected, check if we have a resource for this IP. - let Some(rid) = self.cidr_resource_by_ip(dst) else { - tracing::debug!("No resource corresponds to IP"); - return; - }; - tracing::Span::current().record("resource", tracing::field::display(rid)); - - if self.disabled_resources.contains(&rid) { - return; - } - - let Some(gateway) = gateway_by_resource(rid) else { - tracing::error!("No gateway for resource"); - return; - }; - - if self.is_connected_to_internet_or_cidr(rid) && self.is_tunnel_ip(src) { - tracing::debug!("Connected to CIDR resource, expecting packet to be routed"); - self.expected_icmp_handshakes - .entry(gateway) - .or_default() - .insert(payload, (ResourceDst::Cidr(dst), seq, identifier)); - return; - } - - // If we have a resource, the first packet will initiate a connection to the gateway. - tracing::debug!("Not connected to resource, expecting to trigger connection intent"); - self.connect_to_internet_or_cidr_resource(rid, gateway); - } - - #[tracing::instrument(level = "debug", skip_all, fields(dst, resource))] - pub(crate) fn on_icmp_packet_to_dns( - &mut self, - src: IpAddr, - dst: DomainName, - seq: u16, - identifier: u16, - payload: u64, - gateway_by_resource: impl Fn(ResourceId) -> Option, - ) { - tracing::Span::current().record("dst", tracing::field::display(&dst)); - - let Some(resource) = self.dns_resource_by_domain(&dst) else { - tracing::debug!("No resource corresponds to IP"); + let Some(resource) = self.resource_by_dst(&dst) else { return; }; @@ -621,28 +653,40 @@ impl RefClient { return; }; - if self - .connected_dns_resources - .contains(&(resource, dst.clone())) - && self.is_tunnel_ip(src) - { - tracing::debug!("Connected to DNS resource, expecting packet to be routed"); - self.expected_icmp_handshakes + if self.is_connected_to_resource(resource, &dst) && self.is_tunnel_ip(src) { + tracing::debug!("Connected to resource, expecting packet to be routed"); + map(self) .entry(gateway) .or_default() - .insert(payload, (ResourceDst::Dns(dst), seq, identifier)); + .insert(payload, packet_id); return; } - debug_assert!( - self.dns_records.iter().any(|(name, _)| name == &dst), - "Should only sample ICMPs to domains that we resolved" - ); + if let Destination::DomainName { name: dst, .. } = &dst { + debug_assert!( + self.dns_records.iter().any(|(name, _)| name == dst), + "Should only sample domains that we resolved" + ); + } tracing::debug!("Not connected to resource, expecting to trigger connection intent"); - if !self.disabled_resources.contains(&resource) { - self.connected_dns_resources.insert((resource, dst)); - self.connected_gateways.insert(gateway); + self.connect_to_resource(resource, dst, gateway); + } + + fn connect_to_resource( + &mut self, + resource: ResourceId, + destination: Destination, + gateway: GatewayId, + ) { + match destination { + Destination::DomainName { name, .. } => { + if !self.disabled_resources.contains(&resource) { + self.connected_dns_resources.insert((resource, name)); + self.connected_gateways.insert(gateway); + } + } + Destination::IpAddr(_) => self.connect_to_internet_or_cidr_resource(resource, gateway), } } @@ -703,6 +747,19 @@ impl RefClient { .collect_vec() } + fn is_connected_to_resource(&self, resource: ResourceId, destination: &Destination) -> bool { + if self.is_connected_to_internet_or_cidr(resource) { + return true; + } + + let Destination::DomainName { name, .. } = destination else { + return false; + }; + + self.connected_dns_resources + .contains(&(resource, name.clone())) + } + fn is_connected_to_internet(&self, id: ResourceId) -> bool { self.active_internet_resource() == Some(id) && self.connected_internet_resource } @@ -731,6 +788,23 @@ impl RefClient { (is_known_host || is_dns_resource) && is_suppported_type } + fn resource_by_dst(&self, destination: &Destination) -> Option { + match destination { + Destination::DomainName { name, .. } => { + if let Some(id) = self.dns_resource_by_domain(name) { + return Some(id); + } + } + Destination::IpAddr(addr) => { + if let Some(id) = self.cidr_resource_by_ip(*addr) { + return Some(id); + } + } + } + + self.active_internet_resource() + } + pub(crate) fn dns_resource_by_domain(&self, domain: &DomainName) -> Option { self.resources .iter() @@ -751,16 +825,39 @@ impl RefClient { } /// An ICMP packet is valid if we didn't yet send an ICMP packet with the same seq, identifier and payload. - pub(crate) fn is_valid_icmp_packet(&self, seq: &u16, identifier: &u16, payload: &u64) -> bool { + pub(crate) fn is_valid_icmp_packet( + &self, + seq: &Seq, + identifier: &Identifier, + payload: &u64, + ) -> bool { self.expected_icmp_handshakes.values().flatten().all( - |(existig_payload, (_, existing_seq, existing_identifer))| { + |(existig_payload, (_, existing_seq, existing_identifier))| { existing_seq != seq - && existing_identifer != identifier + && existing_identifier != identifier && existig_payload != payload }, ) } + /// An UDP packet is valid if we didn't yet send an UDP packet with the same sport, dport and payload. + pub(crate) fn is_valid_udp_packet(&self, sport: &SPort, dport: &DPort, payload: &u64) -> bool { + self.expected_udp_handshakes.values().flatten().all( + |(existig_payload, (_, existing_sport, existing_dport))| { + existing_dport != dport && existing_sport != sport && existig_payload != payload + }, + ) + } + + /// An TCP packet is valid if we didn't yet send an TCP packet with the same sport, dport and payload. + pub(crate) fn is_valid_tcp_packet(&self, sport: &SPort, dport: &DPort, payload: &u64) -> bool { + self.expected_tcp_exchanges.values().flatten().all( + |(existig_payload, (_, existing_sport, existing_dport))| { + existing_dport != dport && existing_sport != sport && existig_payload != payload + }, + ) + } + pub(crate) fn resolved_v4_domains(&self) -> Vec { self.resolved_domains() .filter_map(|(domain, records)| { @@ -991,6 +1088,8 @@ fn ref_client( connected_dns_resources: Default::default(), connected_internet_resource: Default::default(), expected_icmp_handshakes: Default::default(), + expected_udp_handshakes: Default::default(), + expected_tcp_exchanges: Default::default(), expected_udp_dns_handshakes: Default::default(), expected_tcp_dns_handshakes: Default::default(), disabled_resources: Default::default(), diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index ec4b89ba4..0bc9da5ec 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -27,6 +27,12 @@ pub(crate) struct SimGateway { /// The received ICMP packets, indexed by our custom ICMP payload. pub(crate) received_icmp_requests: BTreeMap, + /// The received UDP packets, indexed by our custom UDP payload. + pub(crate) received_udp_requests: BTreeMap, + + /// The received TCP packets, indexed by our custom TCP payload. + pub(crate) received_tcp_requests: BTreeMap, + udp_dns_server_resources: HashMap, tcp_dns_server_resources: HashMap, } @@ -40,6 +46,8 @@ impl SimGateway { enc_buffer: Default::default(), udp_dns_server_resources: Default::default(), tcp_dns_server_resources: Default::default(), + received_udp_requests: Default::default(), + received_tcp_requests: Default::default(), } } @@ -112,12 +120,20 @@ impl SimGateway { if let Some(icmp) = packet.as_icmpv4() { if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { + let packet_id = u64::from_be_bytes(*icmp.payload().first_chunk().unwrap()); + tracing::debug!(%packet_id, "Received ICMP request"); + self.received_icmp_requests + .insert(packet_id, packet.clone()); return self.handle_icmp_request(&packet, echo, icmp.payload(), now); } } if let Some(icmp) = packet.as_icmpv6() { if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { + let packet_id = u64::from_be_bytes(*icmp.payload().first_chunk().unwrap()); + tracing::debug!(%packet_id, "Received ICMP request"); + self.received_icmp_requests + .insert(packet_id, packet.clone()); return self.handle_icmp_request(&packet, echo, icmp.payload(), now); } } @@ -125,6 +141,7 @@ impl SimGateway { if let Some(udp) = packet.as_udp() { let socket = SocketAddr::new(packet.destination(), udp.destination_port()); + // NOTE: we can make this assumption because port 53 is excluded from non-dns query packets if let Some(server) = self.udp_dns_server_resources.get_mut(&socket) { server.handle_input(packet); return None; @@ -134,12 +151,24 @@ impl SimGateway { if let Some(tcp) = packet.as_tcp() { let socket = SocketAddr::new(packet.destination(), tcp.destination_port()); + // NOTE: we can make this assumption because port 53 is excluded from non-dns query packets if let Some(server) = self.tcp_dns_server_resources.get_mut(&socket) { server.handle_input(packet); return None; } } + if let Some(reply) = ip_packet::make::echo_reply(packet.clone()) { + self.request_received(&packet); + let transmit = self + .sut + .handle_tun_input(reply, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); + + return Some(transmit); + } + tracing::error!(?packet, "Unhandled packet"); None } @@ -157,6 +186,20 @@ impl SimGateway { ) } + fn request_received(&mut self, packet: &IpPacket) { + if let Some(udp) = packet.as_udp() { + let packet_id = u64::from_be_bytes(*udp.payload().first_chunk().unwrap()); + tracing::debug!(%packet_id, "Received UDP request"); + self.received_udp_requests.insert(packet_id, packet.clone()); + } + + if let Some(tcp) = packet.as_tcp() { + let packet_id = u64::from_be_bytes(*tcp.payload().first_chunk().unwrap()); + tracing::debug!(%packet_id, "Received TCP request"); + self.received_tcp_requests.insert(packet_id, packet.clone()); + } + } + fn handle_icmp_request( &mut self, packet: &IpPacket, @@ -164,11 +207,6 @@ impl SimGateway { payload: &[u8], now: Instant, ) -> Option> { - let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap()); - self.received_icmp_requests.insert(echo_id, packet.clone()); - - tracing::debug!(%echo_id, "Received ICMP request"); - let echo_response = ip_packet::make::icmp_reply_packet( packet.destination(), packet.source(), diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index d86a70551..62d286bb2 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -151,7 +151,7 @@ fn cidr_resource_outside_reserved_ranges( ) -> impl Strategy { cidr_resource(any_ip_network(8), sites.prop_map(|s| vec![s])) .prop_filter( - "tests doesn't support yet CIDR resources overlapping DNS resources", + "tests doesn't support CIDR resources overlapping DNS resources", |r| { // This works because CIDR resources' host mask is always <8 while IP resource is 21 let is_ip4_reserved = IpNetwork::V4(IPV4_RESOURCES) diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index d0449fc90..b2051a643 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -5,7 +5,7 @@ use super::sim_gateway::SimGateway; use super::sim_net::{Host, HostId, RoutingTable}; use super::sim_relay::SimRelay; use super::stub_portal::StubPortal; -use super::transition::DnsQuery; +use super::transition::{Destination, DnsQuery}; use crate::client::Resource; use crate::dns::{self, is_subdomain}; use crate::gateway::DnsResourceNatEntry; @@ -134,61 +134,70 @@ impl TunnelTest { Transition::DisableResources(resources) => state .client .exec_mut(|c| c.sut.set_disabled_resources(resources)), - Transition::SendICMPPacketToNonResourceIp { - src, - dst, - seq, - identifier, - payload, - } - | Transition::SendICMPPacketToCidrResource { + Transition::SendIcmpPacket { src, dst, seq, identifier, payload, + .. } => { + let dst = address_from_destination(&dst, &state, &src); + let packet = ip_packet::make::icmp_request_packet( src, dst, - seq, - identifier, + seq.0, + identifier.0, &payload.to_be_bytes(), ) .unwrap(); - let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); + let transmit = state + .client + .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); buffered_transmits.push_from(transmit, &state.client, now); } - Transition::SendICMPPacketToDnsResource { + Transition::SendUdpPacket { src, dst, - seq, - identifier, + sport, + dport, payload, - resolved_ip, - .. } => { - let available_ips = state - .client - .inner() - .dns_records - .get(&dst) - .unwrap() - .iter() - .filter(|ip| match ip { - IpAddr::V4(_) => src.is_ipv4(), - IpAddr::V6(_) => src.is_ipv6(), - }); - let dst = *resolved_ip.select(available_ips); + let dst = address_from_destination(&dst, &state, &src); - let packet = ip_packet::make::icmp_request_packet( + let packet = ip_packet::make::udp_packet( src, dst, - seq, - identifier, - &payload.to_be_bytes(), + sport.0, + dport.0, + payload.to_be_bytes().to_vec(), + ) + .unwrap(); + + let transmit = state + .client + .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + + buffered_transmits.push_from(transmit, &state.client, now); + } + Transition::SendTcpPayload { + src, + dst, + sport, + dport, + payload, + } => { + let dst = address_from_destination(&dst, &state, &src); + + let packet = ip_packet::make::tcp_packet( + src, + dst, + sport.0, + dport.0, + payload.to_be_bytes().to_vec(), ) .unwrap(); @@ -341,7 +350,19 @@ impl TunnelTest { assert_icmp_packets_properties( ref_client, sim_client, - sim_gateways, + &sim_gateways, + &ref_state.global_dns_records, + ); + assert_udp_packets_properties( + ref_client, + sim_client, + &sim_gateways, + &ref_state.global_dns_records, + ); + assert_tcp_packets_properties( + ref_client, + sim_client, + &sim_gateways, &ref_state.global_dns_records, ); assert_udp_dns_packets_properties(ref_client, sim_client); @@ -886,6 +907,26 @@ impl TunnelTest { } } +fn address_from_destination(destination: &Destination, state: &TunnelTest, src: &IpAddr) -> IpAddr { + match destination { + Destination::DomainName { resolved_ip, name } => { + let available_ips = state + .client + .inner() + .dns_records + .get(name) + .unwrap() + .iter() + .filter(|ip| match ip { + IpAddr::V4(_) => src.is_ipv4(), + IpAddr::V6(_) => src.is_ipv6(), + }); + *resolved_ip.select(available_ips) + } + Destination::IpAddr(addr) => *addr, + } +} + fn on_gateway_event( src: GatewayId, event: GatewayEvent, diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 474e53d3a..ab47fde28 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -18,7 +18,6 @@ use std::{ /// The possible transitions of the state machine. #[derive(Clone, derivative::Derivative)] #[derivative(Debug)] -#[expect(clippy::large_enum_variant)] pub(crate) enum Transition { /// Activate a resource on the client. ActivateResource(Resource), @@ -26,31 +25,28 @@ pub(crate) enum Transition { DeactivateResource(ResourceId), /// Client-side disable resource DisableResources(BTreeSet), - /// Send an ICMP packet to non-resource IP. - SendICMPPacketToNonResourceIp { + /// Send an ICMP packet to destination (IP resource, DNS resource or IP non-resource). + SendIcmpPacket { src: IpAddr, - dst: IpAddr, - seq: u16, - identifier: u16, + dst: Destination, + seq: Seq, + identifier: Identifier, payload: u64, }, - /// Send an ICMP packet to a CIDR resource. - SendICMPPacketToCidrResource { + /// Send an UDP packet to destination (IP resource, DNS resource or IP non-resource). + SendUdpPacket { src: IpAddr, - dst: IpAddr, - seq: u16, - identifier: u16, + dst: Destination, + sport: SPort, + dport: DPort, payload: u64, }, - /// Send an ICMP packet to a DNS resource. - SendICMPPacketToDnsResource { + /// Send an TCP payload to destination (IP resource, DNS resource or IP non-resource). + SendTcpPayload { src: IpAddr, - dst: DomainName, - #[derivative(Debug = "ignore")] - resolved_ip: sample::Selector, - - seq: u16, - identifier: u16, + dst: Destination, + sport: SPort, + dport: DPort, payload: u64, }, @@ -107,78 +103,163 @@ pub(crate) enum DnsTransport { Tcp, } -pub(crate) fn ping_random_ip( +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct Seq(pub u16); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct Identifier(pub u16); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct SPort(pub u16); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct DPort(pub u16); + +#[derive(Debug, Clone)] +#[expect(clippy::large_enum_variant)] +pub(crate) enum Destination { + DomainName { + resolved_ip: sample::Selector, + name: DomainName, + }, + IpAddr(IpAddr), +} + +/// Helper enum +#[derive(Debug, Clone)] +enum PacketDestination { + DomainName(DomainName), + IpAddr(IpAddr), +} + +pub(crate) trait ReplyTo { + fn reply_to(self) -> Self; +} + +impl ReplyTo for (SPort, DPort) { + fn reply_to(self) -> Self { + (SPort(self.1 .0), DPort(self.0 .0)) + } +} + +impl ReplyTo for (Seq, Identifier) { + fn reply_to(self) -> Self { + self + } +} + +impl From for PacketDestination { + fn from(name: DomainName) -> Self { + PacketDestination::DomainName(name) + } +} + +impl From for PacketDestination { + fn from(addr: Ipv4Addr) -> Self { + PacketDestination::IpAddr(addr.into()) + } +} + +impl From for PacketDestination { + fn from(addr: Ipv6Addr) -> Self { + PacketDestination::IpAddr(addr.into()) + } +} + +impl From for PacketDestination { + fn from(addr: IpAddr) -> Self { + PacketDestination::IpAddr(addr) + } +} + +impl PacketDestination { + fn into_destination(self, resolved_ip: sample::Selector) -> Destination { + match self { + PacketDestination::DomainName(name) => Destination::DomainName { resolved_ip, name }, + PacketDestination::IpAddr(addr) => Destination::IpAddr(addr), + } + } +} + +#[expect(private_bounds)] +pub(crate) fn icmp_packet( src: impl Strategy, - dst: impl Strategy, + dst: impl Strategy, ) -> impl Strategy where I: Into, + D: Into, { ( src.prop_map(Into::into), dst.prop_map(Into::into), any::(), any::(), - any::(), - ) - .prop_map(|(src, dst, seq, identifier, payload)| { - Transition::SendICMPPacketToNonResourceIp { - src, - dst, - seq, - identifier, - payload, - } - }) -} - -pub(crate) fn icmp_to_cidr_resource( - src: impl Strategy, - dst: impl Strategy, -) -> impl Strategy -where - I: Into, -{ - ( - dst.prop_map(Into::into), - any::(), - any::(), - src.prop_map(Into::into), - any::(), - ) - .prop_map(|(dst, seq, identifier, src, payload)| { - Transition::SendICMPPacketToCidrResource { - src, - dst, - seq, - identifier, - payload, - } - }) -} - -pub(crate) fn icmp_to_dns_resource( - src: impl Strategy, - dst: impl Strategy, -) -> impl Strategy -where - I: Into, -{ - ( - dst, - any::(), - any::(), - src.prop_map(Into::into), any::(), any::(), ) - .prop_map(|(dst, seq, identifier, src, resolved_ip, payload)| { - Transition::SendICMPPacketToDnsResource { + .prop_map(|(src, dst, seq, identifier, resolved_ip, payload)| { + Transition::SendIcmpPacket { src, - dst, - resolved_ip, - seq, - identifier, + dst: dst.into_destination(resolved_ip), + seq: Seq(seq), + identifier: Identifier(identifier), + payload, + } + }) +} + +#[expect(private_bounds)] +pub(crate) fn udp_packet( + src: impl Strategy, + dst: impl Strategy, +) -> impl Strategy +where + I: Into, + D: Into, +{ + ( + src.prop_map(Into::into), + dst.prop_map(Into::into), + any::(), + any::(), + any::(), + any::(), + ) + .prop_map( + |(src, dst, sport, dport, resolved_ip, payload)| Transition::SendUdpPacket { + src, + dst: dst.into_destination(resolved_ip), + sport: SPort(sport), + dport: DPort(dport), + payload, + }, + ) +} + +#[expect(private_bounds)] +pub(crate) fn tcp_packet( + src: impl Strategy, + dst: impl Strategy, +) -> impl Strategy +where + I: Into, + D: Into, +{ + ( + src.prop_map(Into::into), + dst.prop_map(Into::into), + any::(), + any::(), + any::(), + any::(), + ) + .prop_map(|(src, dst, sport, dport, resolved_ip, payload)| { + Transition::SendTcpPayload { + src, + dst: dst.into_destination(resolved_ip), + sport: SPort(sport), + dport: DPort(dport), payload, } }) diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index b57791dc8..ccd2a7b27 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -98,6 +98,36 @@ pub fn icmp_reply_packet( } } +pub fn echo_reply(mut req: IpPacket) -> Option { + if !req.is_udp() && !req.is_tcp() { + return None; + } + + if let Some(mut packet) = req.as_tcp_mut() { + let original_src = packet.get_source_port(); + let original_dst = packet.get_destination_port(); + + packet.set_source_port(original_dst); + packet.set_destination_port(original_src); + } + + if let Some(mut packet) = req.as_udp_mut() { + let original_src = packet.get_source_port(); + let original_dst = packet.get_destination_port(); + + packet.set_source_port(original_dst); + packet.set_destination_port(original_src); + } + + let original_src = req.source(); + let original_dst = req.destination(); + + req.set_dst(original_src); + req.set_src(original_dst); + + Some(req) +} + pub fn tcp_packet( saddr: IP, daddr: IP, diff --git a/rust/ip-packet/src/tcp_header_slice_mut.rs b/rust/ip-packet/src/tcp_header_slice_mut.rs index 0a896af6d..69fb3132c 100644 --- a/rust/ip-packet/src/tcp_header_slice_mut.rs +++ b/rust/ip-packet/src/tcp_header_slice_mut.rs @@ -13,6 +13,14 @@ impl<'a> TcpHeaderSliceMut<'a> { Ok(Self { slice }) } + pub fn get_source_port(&self) -> u16 { + u16::from_be_bytes([self.slice[0], self.slice[1]]) + } + + pub fn get_destination_port(&self) -> u16 { + u16::from_be_bytes([self.slice[2], self.slice[3]]) + } + pub fn set_source_port(&mut self, src: u16) { // Safety: Slice it at least of length 20 as checked in the ctor. unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) }; diff --git a/rust/ip-packet/src/udp_header_slice_mut.rs b/rust/ip-packet/src/udp_header_slice_mut.rs index 986b88604..5b1610580 100644 --- a/rust/ip-packet/src/udp_header_slice_mut.rs +++ b/rust/ip-packet/src/udp_header_slice_mut.rs @@ -13,6 +13,14 @@ impl<'a> UdpHeaderSliceMut<'a> { Ok(Self { slice }) } + pub fn get_source_port(&self) -> u16 { + u16::from_be_bytes([self.slice[0], self.slice[1]]) + } + + pub fn get_destination_port(&self) -> u16 { + u16::from_be_bytes([self.slice[2], self.slice[3]]) + } + pub fn set_source_port(&mut self, src: u16) { // Safety: Slice it at least of length 8 as checked in the ctor. unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) };