diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 7f924677b..69297aa53 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1058,7 +1058,7 @@ mod proptests { let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, *dport, payload.clone()), Protocol::Udp { dport } => udp_packet(src, dest, sport, *dport, payload.clone()), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); assert!(peer.ensure_allowed_dst(&packet).is_ok()); @@ -1091,7 +1091,7 @@ mod proptests { let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); assert!(peer.ensure_allowed_dst(&packet).is_ok()); @@ -1133,7 +1133,7 @@ mod proptests { let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); assert!(peer.ensure_allowed_dst(&packet).is_ok()); @@ -1148,7 +1148,7 @@ mod proptests { let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); assert!(peer.ensure_allowed_dst(&packet).is_ok()); @@ -1191,7 +1191,7 @@ mod proptests { let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); @@ -1222,7 +1222,7 @@ mod proptests { let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); @@ -1260,14 +1260,14 @@ mod proptests { let packet_allowed = match protocol_allowed { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); let packet_rejected = match protocol_removed { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload), - Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index caa93d0bd..de4900438 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -33,7 +33,7 @@ pub(crate) fn assert_icmp_packets_properties( .flatten() .collect(), &sim_client.received_icmp_replies, - |(_, seq_a, id_a), (seq_b, id_b)| seq_a == seq_b && id_a == id_b, + |(_, (_, seq_a, id_a)), (seq_b, id_b)| seq_a == seq_b && id_a == id_b, ); if !unexpected_icmp_replies.is_empty() { @@ -61,9 +61,7 @@ pub(crate) fn assert_icmp_packets_properties( for (gateway, expected_icmp_handshakes) in &ref_client.expected_icmp_handshakes { let received_icmp_requests = &sim_gateways.get(gateway).unwrap().received_icmp_requests; - for ((resource_dst, seq, identifier), gateway_received_request) in - expected_icmp_handshakes.iter().zip(received_icmp_requests) - { + for (payload, (resource_dst, seq, identifier)) in expected_icmp_handshakes { let _guard = tracing::info_span!(target: "assertions", "icmp", %seq, %identifier).entered(); @@ -78,9 +76,13 @@ pub(crate) fn assert_icmp_packets_properties( tracing::error!(target: "assertions", "❌ Missing ICMP reply on client"); continue; }; - assert_correct_src_and_dst_ips(client_sent_request, client_received_reply); + let Some(gateway_received_request) = received_icmp_requests.get(payload) else { + tracing::error!(target: "assertions", "❌ Missing ICMP request on gateway"); + continue; + }; + { let expected = ref_client.tunnel_ip_for(gateway_received_request.source()); let actual = gateway_received_request.source(); @@ -211,7 +213,7 @@ fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket<'_>, let actual = gateway_received_request.destination(); if actual != *expected { - tracing::error!(target: "assertions", %actual, %expected, "❌ Unknown resource IP"); + tracing::error!(target: "assertions", %actual, %expected, "❌ Incorrect resource destination"); } else { tracing::info!(target: "assertions", ip = %actual, "✅ ICMP request targets correct resource"); } diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index d7825e74b..aa1c6cb9b 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -366,13 +366,19 @@ impl ReferenceStateMachine for ReferenceState { dst, seq, identifier, + payload, } => { state.client.exec_mut(|client| { // If the Internet Resource is active, all packets are expected to be routed. if client.active_internet_resource().is_some() { - client.on_icmp_packet_to_internet(*src, *dst, *seq, *identifier, |r| { - state.portal.gateway_for_resource(r).copied() - }) + client.on_icmp_packet_to_internet( + *src, + *dst, + *seq, + *identifier, + *payload, + |r| state.portal.gateway_for_resource(r).copied(), + ) } }); } @@ -381,10 +387,10 @@ impl ReferenceStateMachine for ReferenceState { dst, seq, identifier, - .. + payload, } => { state.client.exec_mut(|client| { - client.on_icmp_packet_to_cidr(*src, *dst, *seq, *identifier, |r| { + client.on_icmp_packet_to_cidr(*src, *dst, *seq, *identifier, *payload, |r| { state.portal.gateway_for_resource(r).copied() }) }); @@ -394,9 +400,10 @@ impl ReferenceStateMachine for ReferenceState { dst, seq, identifier, + payload, .. } => state.client.exec_mut(|client| { - client.on_icmp_packet_to_dns(*src, dst.clone(), *seq, *identifier, |r| { + client.on_icmp_packet_to_dns(*src, dst.clone(), *seq, *identifier, *payload, |r| { state.portal.gateway_for_resource(r).copied() }) }), @@ -478,10 +485,13 @@ impl ReferenceStateMachine for ReferenceState { dst, seq, identifier, + payload, .. } => { - let is_valid_icmp_packet = - state.client.inner().is_valid_icmp_packet(seq, identifier); + let is_valid_icmp_packet = state + .client + .inner() + .is_valid_icmp_packet(seq, identifier, payload); let is_cidr_resource = state.client.inner().cidr_resource_by_ip(*dst).is_some(); is_valid_icmp_packet && !is_cidr_resource @@ -490,6 +500,7 @@ impl ReferenceStateMachine for ReferenceState { seq, identifier, dst, + payload, .. } => { let ref_client = state.client.inner(); @@ -500,7 +511,7 @@ impl ReferenceStateMachine for ReferenceState { return false; }; - ref_client.is_valid_icmp_packet(seq, identifier) + ref_client.is_valid_icmp_packet(seq, identifier, payload) && state.gateways.contains_key(gateway) } Transition::SendICMPPacketToDnsResource { @@ -508,6 +519,7 @@ impl ReferenceStateMachine for ReferenceState { identifier, dst, src, + payload, .. } => { let ref_client = state.client.inner(); @@ -518,7 +530,7 @@ impl ReferenceStateMachine for ReferenceState { return false; }; - ref_client.is_valid_icmp_packet(seq, identifier) + ref_client.is_valid_icmp_packet(seq, identifier, payload) && ref_client.dns_records.get(dst).is_some_and(|r| match src { IpAddr::V4(_) => r.contains(&Rtype::A), IpAddr::V6(_) => r.contains(&Rtype::AAAA), diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index b140ab4c1..c44b5a93b 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -299,11 +299,9 @@ pub struct RefClient { pub(crate) disabled_resources: BTreeSet, /// The expected ICMP handshakes. - /// - /// This is indexed by gateway because our assertions rely on the order of the sent packets. #[derivative(Debug = "ignore")] pub(crate) expected_icmp_handshakes: - BTreeMap>, + BTreeMap>, /// The expected DNS handshakes. #[derivative(Debug = "ignore")] pub(crate) expected_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, @@ -399,6 +397,7 @@ impl RefClient { dst: IpAddr, seq: u16, identifier: u16, + payload: u64, gateway_by_resource: impl Fn(ResourceId) -> Option, ) { tracing::Span::current().record("dst", tracing::field::display(dst)); @@ -420,7 +419,7 @@ impl RefClient { self.expected_icmp_handshakes .entry(gateway) .or_default() - .push_back((ResourceDst::Internet(dst), seq, identifier)); + .insert(payload, (ResourceDst::Internet(dst), seq, identifier)); return; } @@ -436,6 +435,7 @@ impl RefClient { dst: IpAddr, seq: u16, identifier: u16, + payload: u64, gateway_by_resource: impl Fn(ResourceId) -> Option, ) { tracing::Span::current().record("dst", tracing::field::display(dst)); @@ -461,7 +461,7 @@ impl RefClient { self.expected_icmp_handshakes .entry(gateway) .or_default() - .push_back((ResourceDst::Cidr(dst), seq, identifier)); + .insert(payload, (ResourceDst::Cidr(dst), seq, identifier)); return; } @@ -477,6 +477,7 @@ impl RefClient { dst: DomainName, seq: u16, identifier: u16, + payload: u64, gateway_by_resource: impl Fn(ResourceId) -> Option, ) { tracing::Span::current().record("dst", tracing::field::display(&dst)); @@ -502,7 +503,7 @@ impl RefClient { self.expected_icmp_handshakes .entry(gateway) .or_default() - .push_back((ResourceDst::Dns(dst), seq, identifier)); + .insert(payload, (ResourceDst::Dns(dst), seq, identifier)); return; } @@ -595,11 +596,13 @@ impl RefClient { .map(|(domain, ips)| (domain.clone(), ips.clone())) } - /// An ICMP packet is valid if we didn't yet send an ICMP packet with the same seq and identifier. - pub(crate) fn is_valid_icmp_packet(&self, seq: &u16, identifier: &u16) -> bool { + /// An ICMP packet is valid if we didn't yet send an ICMP packet with the same seq, identifier and payload. + pub(crate) fn is_valid_icmp_packet(&self, seq: &u16, identifier: &u16, payload: &u64) -> bool { self.expected_icmp_handshakes.values().flatten().all( - |(_, existing_seq, existing_identifer)| { - existing_seq != seq && existing_identifer != identifier + |(existig_payload, (_, existing_seq, existing_identifer))| { + existing_seq != seq + && existing_identifer != identifier + && existig_payload != payload }, ) } diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index cc1055bc7..022443dd7 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -13,7 +13,7 @@ use ip_packet::IpPacket; use proptest::prelude::*; use snownet::Transmit; use std::{ - collections::{BTreeMap, BTreeSet, VecDeque}, + collections::{BTreeMap, BTreeSet}, net::IpAddr, time::Instant, }; @@ -23,7 +23,8 @@ pub(crate) struct SimGateway { id: GatewayId, pub(crate) sut: GatewayState, - pub(crate) received_icmp_requests: VecDeque>, + /// The received ICMP packets, indexed by our custom ICMP payload. + pub(crate) received_icmp_requests: BTreeMap>, buffer: Vec, } @@ -69,13 +70,18 @@ impl SimGateway { // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? - if packet.as_icmp().is_some() { - self.received_icmp_requests.push_back(packet.clone()); + if let Some(icmp) = packet.as_icmp() { + if let Some(request) = icmp.as_echo_request() { + let payload = u64::from_be_bytes(*request.payload().first_chunk().unwrap()); + tracing::debug!(%payload, "Received ICMP request"); - let echo_response = ip_packet::make::icmp_response_packet(packet); - let transmit = self.sut.encapsulate(echo_response, now)?.into_owned(); + self.received_icmp_requests.insert(payload, packet.clone()); - return Some(transmit); + let echo_response = ip_packet::make::icmp_response_packet(packet); + let transmit = self.sut.encapsulate(echo_response, now)?.into_owned(); + + return Some(transmit); + } } if packet.as_udp().is_some() { diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 4e3b6522a..8ff7659b6 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -150,16 +150,23 @@ impl TunnelTest { dst, seq, identifier, + payload, } | Transition::SendICMPPacketToCidrResource { src, dst, seq, identifier, - .. + payload, } => { - let packet = - ip_packet::make::icmp_request_packet(src, dst, seq, identifier).unwrap(); + let packet = ip_packet::make::icmp_request_packet( + src, + dst, + seq, + identifier, + &payload.to_be_bytes(), + ) + .unwrap(); let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); @@ -170,6 +177,7 @@ impl TunnelTest { dst, seq, identifier, + payload, resolved_ip, .. } => { @@ -186,8 +194,14 @@ impl TunnelTest { }); let dst = *resolved_ip.select(available_ips); - let packet = - ip_packet::make::icmp_request_packet(src, dst, seq, identifier).unwrap(); + let packet = ip_packet::make::icmp_request_packet( + src, + dst, + seq, + identifier, + &payload.to_be_bytes(), + ) + .unwrap(); let transmit = state .client diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 5a04ebba7..a6a838459 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -31,6 +31,7 @@ pub(crate) enum Transition { dst: IpAddr, seq: u16, identifier: u16, + payload: u64, }, /// Send an ICMP packet to a CIDR resource. SendICMPPacketToCidrResource { @@ -38,6 +39,7 @@ pub(crate) enum Transition { dst: IpAddr, seq: u16, identifier: u16, + payload: u64, }, /// Send an ICMP packet to a DNS resource. SendICMPPacketToDnsResource { @@ -48,6 +50,7 @@ pub(crate) enum Transition { seq: u16, identifier: u16, + payload: u64, }, /// Send a DNS query. @@ -103,15 +106,17 @@ where dst.prop_map(Into::into), any::(), any::(), + any::(), ) - .prop_map( - |(src, dst, seq, identifier)| Transition::SendICMPPacketToNonResourceIp { + .prop_map(|(src, dst, seq, identifier, payload)| { + Transition::SendICMPPacketToNonResourceIp { src, dst, seq, identifier, - }, - ) + payload, + } + }) } pub(crate) fn icmp_to_cidr_resource( @@ -126,15 +131,17 @@ where any::(), any::(), src.prop_map(Into::into), + any::(), ) - .prop_map( - |(dst, seq, identifier, src)| Transition::SendICMPPacketToCidrResource { + .prop_map(|(dst, seq, identifier, src, payload)| { + Transition::SendICMPPacketToCidrResource { src, dst, seq, identifier, - }, - ) + payload, + } + }) } pub(crate) fn icmp_to_dns_resource( @@ -150,14 +157,16 @@ where any::(), src.prop_map(Into::into), any::(), + any::(), ) - .prop_map(|(dst, seq, identifier, src, resolved_ip)| { + .prop_map(|(dst, seq, identifier, src, resolved_ip, payload)| { Transition::SendICMPPacketToDnsResource { src, dst, resolved_ip, seq, identifier, + payload, } }) } diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index e5d1d29cc..a6feb9f3a 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -1518,6 +1518,10 @@ impl<'a> IcmpEchoRequest<'a> { pub fn identifier(&self) -> u16 { for_both!(self, |i| i.get_identifier()) } + + pub fn payload(&self) -> &[u8] { + for_both!(self, |i| i.payload()) + } } impl<'a> IcmpEchoReply<'a> { diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index c2e6e5e92..f5b9f71bf 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -22,8 +22,9 @@ pub fn icmp_request_packet( dst: impl Into, seq: u16, identifier: u16, + payload: &[u8], ) -> Result, IpVersionMismatch> { - icmp_packet(src, dst.into(), seq, identifier, IcmpKind::Request) + icmp_packet(src, dst.into(), seq, identifier, payload, IcmpKind::Request) } pub fn icmp_reply_packet( @@ -31,8 +32,16 @@ pub fn icmp_reply_packet( dst: impl Into, seq: u16, identifier: u16, + payload: &[u8], ) -> Result, IpVersionMismatch> { - icmp_packet(src, dst.into(), seq, identifier, IcmpKind::Response) + icmp_packet( + src, + dst.into(), + seq, + identifier, + payload, + IcmpKind::Response, + ) } pub fn icmp_response_packet(packet: IpPacket<'static>) -> MutableIpPacket<'static> { @@ -46,6 +55,7 @@ pub fn icmp_response_packet(packet: IpPacket<'static>) -> MutableIpPacket<'stati packet.source(), echo_request.sequence(), echo_request.identifier(), + echo_request.payload(), IcmpKind::Response, ) .expect("src and dst come from the same packet") @@ -62,6 +72,7 @@ pub(crate) fn icmp4_packet_with_options( dst: Ipv4Addr, seq: u16, identifier: u16, + payload: &[u8], kind: IcmpKind, ip_header_length: u8, ) -> MutableIpPacket<'static> { @@ -75,7 +86,7 @@ pub(crate) fn icmp4_packet_with_options( }; let ip_header_bytes = ip_header_length * 4; - let mut buf = vec![0u8; 60 + ip_header_bytes as usize]; + let mut buf = vec![0u8; 60 + payload.len() + ip_header_bytes as usize]; ipv4_header( src, @@ -104,6 +115,7 @@ pub(crate) fn icmp4_packet_with_options( let mut echo_request_packet = MutableEchoRequestPacket::new(icmp_packet.packet_mut()).unwrap(); echo_request_packet.set_sequence_number(seq); echo_request_packet.set_identifier(identifier); + echo_request_packet.set_payload(payload); let mut result = MutableIpPacket::owned(buf).unwrap(); result.update_checksum(); @@ -115,11 +127,12 @@ pub(crate) fn icmp_packet( dst: IpAddr, seq: u16, identifier: u16, + payload: &[u8], kind: IcmpKind, ) -> Result, IpVersionMismatch> { match (src, dst) { (IpAddr::V4(src), IpAddr::V4(dst)) => Ok(icmp4_packet_with_options( - src, dst, seq, identifier, kind, 5, + src, dst, seq, identifier, payload, kind, 5, )), (IpAddr::V6(src), IpAddr::V6(dst)) => { use crate::{ @@ -131,7 +144,7 @@ pub(crate) fn icmp_packet( MutablePacket as _, }; - let mut buf = vec![0u8; 128 + 20]; + let mut buf = vec![0u8; 128 + 20 + payload.len()]; ipv6_header(src, dst, IpNextHeaderProtocols::Icmpv6, &mut buf[20..]); @@ -153,6 +166,7 @@ pub(crate) fn icmp_packet( echo_request_packet.set_identifier(identifier); echo_request_packet.set_sequence_number(seq); echo_request_packet.set_checksum(0); + echo_request_packet.set_payload(payload); let mut result = MutableIpPacket::owned(buf).unwrap(); result.update_checksum(); diff --git a/rust/ip-packet/src/proptest.rs b/rust/ip-packet/src/proptest.rs index ba34576d0..6d8e6f2a1 100644 --- a/rust/ip-packet/src/proptest.rs +++ b/rust/ip-packet/src/proptest.rs @@ -27,10 +27,10 @@ pub fn tcp_packet() -> impl Strategy> { pub fn icmp_request_packet() -> impl Strategy> { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { - crate::make::icmp_request_packet(IpAddr::V4(saddr), daddr, sport, dport).unwrap() + crate::make::icmp_request_packet(IpAddr::V4(saddr), daddr, sport, dport, &[]).unwrap() }), (ip6_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { - crate::make::icmp_request_packet(IpAddr::V6(saddr), daddr, sport, dport).unwrap() + crate::make::icmp_request_packet(IpAddr::V6(saddr), daddr, sport, dport, &[]).unwrap() }), ] } diff --git a/rust/ip-packet/src/proptests.rs b/rust/ip-packet/src/proptests.rs index 37c36b122..f6cbfccb7 100644 --- a/rust/ip-packet/src/proptests.rs +++ b/rust/ip-packet/src/proptests.rs @@ -69,7 +69,7 @@ fn icmp_packet_v4() -> impl Strategy> { any::(), ) .prop_map(|(src, dst, id, seq, kind)| { - icmp_packet(src.into(), dst.into(), id, seq, kind).unwrap() + icmp_packet(src.into(), dst.into(), id, seq, &[], kind).unwrap() }) } @@ -83,7 +83,7 @@ fn icmp_packet_v4_header_options() -> impl Strategy impl Strategy> { any::(), ) .prop_map(|(src, dst, id, seq, kind)| { - icmp_packet(src.into(), dst.into(), id, seq, kind).unwrap() + icmp_packet(src.into(), dst.into(), id, seq, &[], kind).unwrap() }) }