From 022e431be291e453bbe7ff40ac362afc71bab890 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 19 Apr 2024 12:31:32 +1000 Subject: [PATCH] chore(snownet): assert that we can send ICMP packets through the tunnel (#4675) This is extracted out of #4568 to make that PR smaller. Plus, I'd like to use these new assertions in #4615. --- rust/connlib/snownet/src/ip_packet.rs | 10 ++ rust/connlib/snownet/src/node.rs | 9 +- rust/connlib/snownet/tests/lib.rs | 228 ++++++++++++++++++++++---- 3 files changed, 207 insertions(+), 40 deletions(-) diff --git a/rust/connlib/snownet/src/ip_packet.rs b/rust/connlib/snownet/src/ip_packet.rs index b2338af58..db5e44985 100644 --- a/rust/connlib/snownet/src/ip_packet.rs +++ b/rust/connlib/snownet/src/ip_packet.rs @@ -37,6 +37,16 @@ impl<'a> MutableIpPacket<'a> { } } + pub fn owned(data: Vec) -> Option> { + let packet = match data[0] >> 4 { + 4 => MutableIpv4Packet::owned(data)?.into(), + 6 => MutableIpv6Packet::owned(data)?.into(), + _ => return None, + }; + + Some(packet) + } + pub fn to_owned(&self) -> MutableIpPacket<'static> { match self { MutableIpPacket::Ipv4(i) => MutableIpv4Packet::owned(i.packet().to_vec()) diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index fac23e150..51ec2f5c5 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -157,10 +157,11 @@ where (&self.private_key).into() } - pub fn is_connected_to(&self, key: PublicKey) -> bool { - self.connections - .iter_established() - .any(|(_, c)| c.remote_pub_key == key && c.tunnel.time_since_last_handshake().is_some()) + pub fn connection_id(&self, key: PublicKey) -> Option { + self.connections.iter_established().find_map(|(id, c)| { + (c.remote_pub_key == key && c.tunnel.time_since_last_handshake().is_some()) + .then_some(id) + }) } pub fn stats(&self) -> (NodeStats, impl Iterator + '_) { diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 386c8f885..c50abd340 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -1,9 +1,9 @@ use boringtun::x25519::{PublicKey, StaticSecret}; use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket}; use rand::rngs::OsRng; -use snownet::{Answer, ClientNode, Event, MutableIpPacket, ServerNode, Transmit}; +use snownet::{Answer, ClientNode, Event, IpPacket, MutableIpPacket, ServerNode, Transmit}; use std::{ - collections::HashSet, + collections::{HashSet, VecDeque}, iter, net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, time::{Duration, Instant, SystemTime}, @@ -35,26 +35,34 @@ fn smoke_direct() { progress(&mut alice, &mut bob, &mut [], &firewall, &mut clock); } + + alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now); + progress(&mut alice, &mut bob, &mut [], &firewall, &mut clock); + assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1); + + bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now); + progress(&mut alice, &mut bob, &mut [], &firewall, &mut clock); + assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1); } #[test] fn smoke_relayed() { let _guard = setup_tracing(); let mut clock = Clock::new(); - let firewall = Firewall::default() - .with_block_rule("1.1.1.1:80", "2.2.2.2:80") - .with_block_rule("2.2.2.2:80", "1.1.1.1:80"); let (alice, bob) = alice_and_bob(); - let mut relays = [TestRelay::new( - IpAddr::V4(Ipv4Addr::LOCALHOST), - debug_span!("Roger"), + let mut relays = [( + 1, + TestRelay::new(IpAddr::V4(Ipv4Addr::LOCALHOST), debug_span!("Roger")), )]; let mut alice = TestNode::new(debug_span!("Alice"), alice, "1.1.1.1:80") .with_relays(&mut relays, clock.now); let mut bob = TestNode::new(debug_span!("Bob"), bob, "2.2.2.2:80").with_relays(&mut relays, clock.now); + let firewall = Firewall::default() + .with_block_rule(&alice, &bob) + .with_block_rule(&bob, &alice); handshake(&mut alice, &mut bob, &clock); @@ -65,6 +73,14 @@ fn smoke_relayed() { progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); } + + alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now); + progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); + assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1); + + bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now); + progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); + assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1); } #[test] @@ -75,9 +91,9 @@ fn reconnect_discovers_new_interface() { let (alice, bob) = alice_and_bob(); - let mut relays = [TestRelay::new( - IpAddr::V4(Ipv4Addr::LOCALHOST), - debug_span!("Roger"), + let mut relays = [( + 1, + TestRelay::new(IpAddr::V4(Ipv4Addr::LOCALHOST), debug_span!("Roger")), )]; let mut alice = TestNode::new(debug_span!("Alice"), alice, "1.1.1.1:80") .with_relays(&mut relays, clock.now); @@ -102,6 +118,14 @@ fn reconnect_discovers_new_interface() { progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); } + alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now); + progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); + assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1); + + bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now); + progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); + assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1); + assert!(alice .signalled_candidates() .any(|(_, c, _)| c.addr().to_string() == "10.0.0.1:80")); @@ -305,13 +329,19 @@ fn s(socket: &str) -> SocketAddr { socket.parse().unwrap() } +fn ip(ip: &str) -> IpAddr { + ip.parse().unwrap() +} + const RELAY: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10000)); // Heavily inspired by https://github.com/algesten/str0m/blob/7ed5143381cf095f7074689cc254b8c9e50d25c5/src/ice/mod.rs#L547-L647. struct TestNode { node: EitherNode, + transmits: VecDeque>, + span: Span, - received_packets: Vec>, + received_packets: Vec>, /// The primary interface we use to send packets (e.g. to relays). primary: SocketAddr, /// All local interfaces. @@ -373,9 +403,8 @@ impl Clock { } impl Firewall { - fn with_block_rule(mut self, from: &str, to: &str) -> Self { - self.blocked - .push((from.parse().unwrap(), to.parse().unwrap())); + fn with_block_rule(mut self, from: &TestNode, to: &TestNode) -> Self { + self.blocked.push((from.primary, to.primary)); self } @@ -551,7 +580,7 @@ impl From> for EitherNode { } impl EitherNode { - fn poll_transmit(&mut self) -> Option { + fn poll_transmit(&mut self) -> Option> { match self { EitherNode::Client(n) => n.poll_transmit(), EitherNode::Server(n) => n.poll_transmit(), @@ -586,10 +615,10 @@ impl EitherNode { } } - fn is_connected_to(&self, key: PublicKey) -> bool { + fn connection_id(&self, key: PublicKey) -> Option { match self { - EitherNode::Client(n) => n.is_connected_to(key), - EitherNode::Server(n) => n.is_connected_to(key), + EitherNode::Client(n) => n.connection_id(key), + EitherNode::Server(n) => n.connection_id(key), } } @@ -614,6 +643,18 @@ impl EitherNode { } } + fn encapsulate<'s>( + &'s mut self, + connection: u64, + packet: IpPacket<'_>, + now: Instant, + ) -> Result>, snownet::Error> { + match self { + EitherNode::Client(n) => n.encapsulate(connection, packet, now), + EitherNode::Server(n) => n.encapsulate(connection, packet, now), + } + } + fn decapsulate<'s>( &mut self, local: SocketAddr, @@ -655,10 +696,11 @@ impl TestNode { primary, local: vec![primary], events: Default::default(), + transmits: Default::default(), } } - fn with_relays(mut self, relays: &mut [TestRelay], now: Instant) -> Self { + fn with_relays(mut self, relays: &mut [(u64, TestRelay)], now: Instant) -> Self { let username = match self.node { EitherNode::Server(_) => "server", EitherNode::Client(_) => "client", @@ -666,12 +708,11 @@ impl TestNode { let turn_servers = relays .iter() - .enumerate() .map(|(idx, relay)| { let (username, password) = relay.make_credentials(username); ( - idx as u64, + *idx, relay.listen_addr, username, password, @@ -680,10 +721,10 @@ impl TestNode { }) .collect::>(); - match &mut self.node { + self.span.in_scope(|| match &mut self.node { EitherNode::Server(s) => s.upsert_turn_servers(&turn_servers, now), EitherNode::Client(c) => c.upsert_turn_servers(&turn_servers, now), - } + }); self } @@ -694,7 +735,26 @@ impl TestNode { } fn is_connected_to(&self, other: &TestNode) -> bool { - self.node.is_connected_to(other.node.public_key()) + self.node.connection_id(other.node.public_key()).is_some() + } + + fn ping(&mut self, src: IpAddr, dst: IpAddr, other: &TestNode, now: Instant) { + let id = self + .node + .connection_id(other.node.public_key()) + .expect("cannot ping not-connected node"); + + let transmit = self + .span + .in_scope(|| { + self.node + .encapsulate(id, icmp_request_packet(src, dst).to_immutable(), now) + }) + .unwrap() + .unwrap() + .into_owned(); + + self.transmits.push_back(transmit); } fn signalled_candidates(&self) -> impl Iterator + '_ { @@ -711,6 +771,12 @@ impl TestNode { }) } + fn packets_from(&self, src: IpAddr) -> impl Iterator> { + self.received_packets + .iter() + .filter(move |p| p.source() == src) + } + fn failed_connections(&self) -> impl Iterator + '_ { self.events.iter().filter_map(|(e, instant)| match e { Event::ConnectionFailed(id) => Some((*id, *instant)), @@ -728,7 +794,7 @@ impl TestNode { }) .unwrap() { - self.received_packets.push(packet.to_owned()) + self.received_packets.push(packet.to_immutable().to_owned()) } } @@ -740,7 +806,9 @@ impl TestNode { Event::SignalIceCandidate { connection, candidate, - } => other.node.add_remote_candidate(connection, candidate, now), + } => other + .span + .in_scope(|| other.node.add_remote_candidate(connection, candidate, now)), Event::ConnectionEstablished(_) => {} Event::ConnectionFailed(_) => {} }; @@ -750,22 +818,23 @@ impl TestNode { fn drain_transmits( &mut self, other: &mut TestNode, - relays: &mut [TestRelay], + relays: &mut [(u64, TestRelay)], firewall: &Firewall, now: Instant, ) { - while let Some(trans) = self.span.in_scope(|| self.node.poll_transmit()) { + for trans in iter::from_fn(|| self.node.poll_transmit()).chain(self.transmits.drain(..)) { let payload = &trans.payload; let dst = trans.dst; - if let Some(relay) = relays.iter_mut().find(|r| r.wants(trans.dst)) { + if let Some((_, relay)) = relays.iter_mut().find(|(_, r)| r.wants(trans.dst)) { relay.handle_packet(payload, self.primary, dst, other, now); continue; } - let src = trans - .src - .expect("transmits without `src` should always be handled by relays"); + let Some(src) = trans.src else { + tracing::debug!(target: "router", %dst, "Unknown relay"); + continue; + }; // Wasn't traffic for the relay, let's check our firewall. if firewall.blocked.contains(&(src, dst)) { @@ -817,7 +886,7 @@ fn handshake(client: &mut TestNode, server: &mut TestNode, clock: &Clock) { fn progress( a1: &mut TestNode, a2: &mut TestNode, - relays: &mut [TestRelay], + relays: &mut [(u64, TestRelay)], firewall: &Firewall, clock: &mut Clock, ) { @@ -829,7 +898,7 @@ fn progress( a1.drain_transmits(a2, relays, firewall, clock.now); a2.drain_transmits(a1, relays, firewall, clock.now); - for relay in relays.iter_mut() { + for (_, relay) in relays.iter_mut() { relay.drain_messages(a1, a2, clock.now); } @@ -845,7 +914,7 @@ fn progress( } } - for relay in relays { + for (_, relay) in relays { if let Some(timeout) = relay.inner.poll_timeout() { if clock.now >= timeout { relay @@ -855,3 +924,90 @@ fn progress( } } } + +fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { + match (source, dst) { + (IpAddr::V4(src), IpAddr::V4(dst)) => { + use pnet_packet::{ + icmp::{ + echo_request::{IcmpCodes, MutableEchoRequestPacket}, + IcmpTypes, MutableIcmpPacket, + }, + ip::IpNextHeaderProtocols, + ipv4::MutableIpv4Packet, + MutablePacket as _, Packet as _, + }; + + let mut buf = vec![0u8; 60]; + + let mut ipv4_packet = MutableIpv4Packet::new(&mut buf[..]).unwrap(); + ipv4_packet.set_version(4); + ipv4_packet.set_header_length(5); + ipv4_packet.set_total_length(60); + ipv4_packet.set_ttl(64); + ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Icmp); + ipv4_packet.set_source(src); + ipv4_packet.set_destination(dst); + ipv4_packet.set_checksum(pnet_packet::ipv4::checksum(&ipv4_packet.to_immutable())); + + let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap(); + icmp_packet.set_icmp_type(IcmpTypes::EchoRequest); + icmp_packet.set_icmp_code(IcmpCodes::NoCode); + icmp_packet.set_checksum(0); + + let mut echo_request_packet = + MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); + echo_request_packet.set_sequence_number(1); + echo_request_packet.set_identifier(0); + echo_request_packet.set_checksum(pnet_packet::util::checksum( + echo_request_packet.to_immutable().packet(), + 2, + )); + + MutableIpPacket::owned(buf).unwrap() + } + (IpAddr::V6(src), IpAddr::V6(dst)) => { + use pnet_packet::{ + icmpv6::{ + echo_request::MutableEchoRequestPacket, Icmpv6Code, Icmpv6Types, + MutableIcmpv6Packet, + }, + ip::IpNextHeaderProtocols, + ipv6::MutableIpv6Packet, + MutablePacket as _, + }; + + let mut buf = vec![0u8; 128]; + + let mut ipv6_packet = MutableIpv6Packet::new(&mut buf[..]).unwrap(); + + ipv6_packet.set_version(6); + ipv6_packet.set_payload_length(16); + ipv6_packet.set_next_header(IpNextHeaderProtocols::Icmpv6); + ipv6_packet.set_hop_limit(64); + ipv6_packet.set_source(src); + ipv6_packet.set_destination(dst); + + let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap(); + + icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest); + icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); // No code for echo request + + let mut echo_request_packet = + MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); + echo_request_packet.set_identifier(0); + echo_request_packet.set_sequence_number(1); + echo_request_packet.set_checksum(0); + + let checksum = pnet_packet::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst); + MutableEchoRequestPacket::new(icmp_packet.payload_mut()) + .unwrap() + .set_checksum(checksum); + + MutableIpPacket::owned(buf).unwrap() + } + (IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => { + panic!("IPs must be of the same version") + } + } +}