diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 05e722ecf..9064f4ed0 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -11,7 +11,6 @@ use connlib_shared::{ use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, PeerSocket}; use ip_network::Ipv4Network; use ip_network_table::IpNetworkTable; -use ip_packet::IpPacket; use pretty_assertions::assert_eq; use proptest::{ arbitrary::any, @@ -59,8 +58,8 @@ struct TunnelTest { relay: SimRelay>, portal: SimPortal, - client_received_packets: VecDeque>, - gateway_received_icmp_packets: VecDeque<(Instant, IpAddr, IpAddr)>, + gateway_received_icmp_requests: VecDeque<(Instant, IpAddr, IpAddr)>, + client_received_icmp_replies: VecDeque<(Instant, IpAddr, IpAddr)>, #[allow(dead_code)] logger: DefaultGuard, @@ -82,7 +81,8 @@ struct ReferenceState { /// The IP ranges we are connected to. connected_resources: IpNetworkTable<()>, - gateway_received_icmp_packets: VecDeque<(Instant, IpAddr, IpAddr)>, + gateway_received_icmp_requests: VecDeque<(Instant, IpAddr, IpAddr)>, + client_received_icmp_replies: VecDeque<(Instant, IpAddr, IpAddr)>, } /// The possible transitions of the state machine. @@ -159,9 +159,9 @@ impl StateMachineTest for TunnelTest { gateway, portal, logger, - client_received_packets: Default::default(), - gateway_received_icmp_packets: Default::default(), relay, + gateway_received_icmp_requests: Default::default(), + client_received_icmp_replies: Default::default(), }; let mut buffered_transmits = VecDeque::new(); @@ -218,8 +218,12 @@ impl StateMachineTest for TunnelTest { // Assert: Check that our actual state is equivalent to our expectation (the reference state). assert_eq!( - state.gateway_received_icmp_packets, - ref_state.gateway_received_icmp_packets + state.gateway_received_icmp_requests, + ref_state.gateway_received_icmp_requests + ); + assert_eq!( + state.client_received_icmp_replies, + ref_state.client_received_icmp_replies ); assert!(buffered_transmits.is_empty()); // Sanity check to ensure we handled all packets. @@ -277,7 +281,8 @@ impl ReferenceStateMachine for ReferenceState { relay, client_cidr_resources: IpNetworkTable::new(), connected_resources: Default::default(), - gateway_received_icmp_packets: Default::default(), + gateway_received_icmp_requests: Default::default(), + client_received_icmp_replies: Default::default(), }) .boxed() } @@ -426,7 +431,9 @@ impl TunnelTest { continue; } - if let ControlFlow::Break(_) = self.try_handle_gateway(dst, src, &payload) { + if let ControlFlow::Break(_) = + self.try_handle_gateway(dst, src, &payload, buffered_transmits) + { continue; } @@ -536,7 +543,10 @@ impl TunnelTest { return; } - if self.try_handle_gateway(dst, src, payload).is_break() { + if self + .try_handle_gateway(dst, src, payload, buffered_transmits) + .is_break() + { return; } @@ -577,7 +587,11 @@ impl TunnelTest { .state .decapsulate(dst, src, payload, self.now, &mut buffer) }) { - self.client_received_packets.push_back(packet.to_owned()); + self.client_received_icmp_replies.push_back(( + self.now, + packet.source(), + packet.destination(), + )); }; ControlFlow::Break(()) @@ -588,6 +602,7 @@ impl TunnelTest { dst: SocketAddr, src: SocketAddr, payload: &[u8], + buffered_transmits: &mut VecDeque<(Transmit<'static>, Option)>, ) -> ControlFlow<()> { let mut buffer = [0u8; 200]; // In these tests, we only send ICMP packets which are very small. @@ -602,11 +617,29 @@ impl TunnelTest { }) { // TODO: Assert that it is an ICMP packet. - self.gateway_received_icmp_packets.push_back(( - self.now, - packet.source(), - packet.destination(), - )); + let packet_src = packet.source(); + let packet_dst = packet.destination(); + + assert_eq!( + packet_src, + self.client.tunnel_ip(packet_src), + "ICMP packet to originate from client" + ); + + self.gateway_received_icmp_requests + .push_back((self.now, packet_src, packet_dst)); + + if let Some(transmit) = self.gateway.span.in_scope(|| { + self.gateway.state.encapsulate( + ip_packet::make::icmp_response_packet(packet_dst, packet_src), + self.now, + ) + }) { + let transmit = transmit.into_owned(); + let dst = transmit.dst; + + buffered_transmits.push_back((transmit, self.gateway.sending_socket_for(dst))); + }; }; ControlFlow::Break(()) @@ -776,8 +809,10 @@ impl ReferenceState { if self.connected_resources.longest_match(dst).is_some() { tracing::debug!("Connected to resource, expecting packet to be routed to gateway"); - self.gateway_received_icmp_packets + self.gateway_received_icmp_requests .push_back((self.now, src, dst)); + self.client_received_icmp_replies + .push_back((self.now, dst, src)); return; } diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index 9fef2eca2..2200a20e7 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -11,8 +11,27 @@ use pnet_packet::{ use crate::MutableIpPacket; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -pub fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { - match (source, dst) { +pub fn icmp_request_packet(src: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { + icmp_packet(src, dst, 1, 0, IcmpKind::Request) +} + +pub fn icmp_response_packet(src: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { + icmp_packet(src, dst, 1, 0, IcmpKind::Response) +} + +enum IcmpKind { + Request, + Response, +} + +fn icmp_packet( + src: IpAddr, + dst: IpAddr, + seq: u16, + identifier: u16, + kind: IcmpKind, +) -> MutableIpPacket<'static> { + match (src, dst) { (IpAddr::V4(src), IpAddr::V4(dst)) => { use crate::{ icmp::{ @@ -28,14 +47,24 @@ pub fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'stat ipv4_header(src, dst, IpNextHeaderProtocols::Icmp, &mut buf[..]); let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap(); - icmp_packet.set_icmp_type(IcmpTypes::EchoRequest); - icmp_packet.set_icmp_code(IcmpCodes::NoCode); + + match kind { + IcmpKind::Request => { + icmp_packet.set_icmp_type(IcmpTypes::EchoRequest); + icmp_packet.set_icmp_code(IcmpCodes::NoCode); + } + IcmpKind::Response => { + icmp_packet.set_icmp_type(IcmpTypes::EchoReply); + icmp_packet.set_icmp_code(IcmpCodes::NoCode); + } + } + icmp_packet.set_checksum(0); let mut echo_request_packet = MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); - echo_request_packet.set_sequence_number(1); - echo_request_packet.set_identifier(0); + echo_request_packet.set_sequence_number(seq); + echo_request_packet.set_identifier(identifier); echo_request_packet.set_checksum(crate::util::checksum( echo_request_packet.to_immutable().packet(), 2, @@ -59,13 +88,21 @@ pub fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'stat let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap(); - icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest); - icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); // No code for echo request + match kind { + IcmpKind::Request => { + icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest); + icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); + } + IcmpKind::Response => { + icmp_packet.set_icmpv6_type(Icmpv6Types::EchoReply); + icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); + } + } let mut echo_request_packet = MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); - echo_request_packet.set_identifier(0); - echo_request_packet.set_sequence_number(1); + echo_request_packet.set_identifier(identifier); + echo_request_packet.set_sequence_number(seq); echo_request_packet.set_checksum(0); let checksum = crate::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst);