chore(snownet): assert that we can send ICMP packets through the tunnel (#4675)

This is extracted out of #4568 to make that PR smaller. Plus, I'd like
to use these new assertions in #4615.
This commit is contained in:
Thomas Eizinger
2024-04-19 12:31:32 +10:00
committed by GitHub
parent 95219376b9
commit 022e431be2
3 changed files with 207 additions and 40 deletions

View File

@@ -37,6 +37,16 @@ impl<'a> MutableIpPacket<'a> {
}
}
pub fn owned(data: Vec<u8>) -> Option<MutableIpPacket<'static>> {
let packet = match data[0] >> 4 {
4 => MutableIpv4Packet::owned(data)?.into(),
6 => MutableIpv6Packet::owned(data)?.into(),
_ => return None,
};
Some(packet)
}
pub fn to_owned(&self) -> MutableIpPacket<'static> {
match self {
MutableIpPacket::Ipv4(i) => MutableIpv4Packet::owned(i.packet().to_vec())

View File

@@ -157,10 +157,11 @@ where
(&self.private_key).into()
}
pub fn is_connected_to(&self, key: PublicKey) -> bool {
self.connections
.iter_established()
.any(|(_, c)| c.remote_pub_key == key && c.tunnel.time_since_last_handshake().is_some())
pub fn connection_id(&self, key: PublicKey) -> Option<TId> {
self.connections.iter_established().find_map(|(id, c)| {
(c.remote_pub_key == key && c.tunnel.time_since_last_handshake().is_some())
.then_some(id)
})
}
pub fn stats(&self) -> (NodeStats, impl Iterator<Item = (TId, ConnectionStats)> + '_) {

View File

@@ -1,9 +1,9 @@
use boringtun::x25519::{PublicKey, StaticSecret};
use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket};
use rand::rngs::OsRng;
use snownet::{Answer, ClientNode, Event, MutableIpPacket, ServerNode, Transmit};
use snownet::{Answer, ClientNode, Event, IpPacket, MutableIpPacket, ServerNode, Transmit};
use std::{
collections::HashSet,
collections::{HashSet, VecDeque},
iter,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
time::{Duration, Instant, SystemTime},
@@ -35,26 +35,34 @@ fn smoke_direct() {
progress(&mut alice, &mut bob, &mut [], &firewall, &mut clock);
}
alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now);
progress(&mut alice, &mut bob, &mut [], &firewall, &mut clock);
assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1);
bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now);
progress(&mut alice, &mut bob, &mut [], &firewall, &mut clock);
assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1);
}
#[test]
fn smoke_relayed() {
let _guard = setup_tracing();
let mut clock = Clock::new();
let firewall = Firewall::default()
.with_block_rule("1.1.1.1:80", "2.2.2.2:80")
.with_block_rule("2.2.2.2:80", "1.1.1.1:80");
let (alice, bob) = alice_and_bob();
let mut relays = [TestRelay::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
debug_span!("Roger"),
let mut relays = [(
1,
TestRelay::new(IpAddr::V4(Ipv4Addr::LOCALHOST), debug_span!("Roger")),
)];
let mut alice = TestNode::new(debug_span!("Alice"), alice, "1.1.1.1:80")
.with_relays(&mut relays, clock.now);
let mut bob =
TestNode::new(debug_span!("Bob"), bob, "2.2.2.2:80").with_relays(&mut relays, clock.now);
let firewall = Firewall::default()
.with_block_rule(&alice, &bob)
.with_block_rule(&bob, &alice);
handshake(&mut alice, &mut bob, &clock);
@@ -65,6 +73,14 @@ fn smoke_relayed() {
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
}
alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now);
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1);
bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now);
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1);
}
#[test]
@@ -75,9 +91,9 @@ fn reconnect_discovers_new_interface() {
let (alice, bob) = alice_and_bob();
let mut relays = [TestRelay::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
debug_span!("Roger"),
let mut relays = [(
1,
TestRelay::new(IpAddr::V4(Ipv4Addr::LOCALHOST), debug_span!("Roger")),
)];
let mut alice = TestNode::new(debug_span!("Alice"), alice, "1.1.1.1:80")
.with_relays(&mut relays, clock.now);
@@ -102,6 +118,14 @@ fn reconnect_discovers_new_interface() {
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
}
alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now);
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1);
bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now);
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1);
assert!(alice
.signalled_candidates()
.any(|(_, c, _)| c.addr().to_string() == "10.0.0.1:80"));
@@ -305,13 +329,19 @@ fn s(socket: &str) -> SocketAddr {
socket.parse().unwrap()
}
fn ip(ip: &str) -> IpAddr {
ip.parse().unwrap()
}
const RELAY: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10000));
// Heavily inspired by https://github.com/algesten/str0m/blob/7ed5143381cf095f7074689cc254b8c9e50d25c5/src/ice/mod.rs#L547-L647.
struct TestNode {
node: EitherNode,
transmits: VecDeque<Transmit<'static>>,
span: Span,
received_packets: Vec<MutableIpPacket<'static>>,
received_packets: Vec<IpPacket<'static>>,
/// The primary interface we use to send packets (e.g. to relays).
primary: SocketAddr,
/// All local interfaces.
@@ -373,9 +403,8 @@ impl Clock {
}
impl Firewall {
fn with_block_rule(mut self, from: &str, to: &str) -> Self {
self.blocked
.push((from.parse().unwrap(), to.parse().unwrap()));
fn with_block_rule(mut self, from: &TestNode, to: &TestNode) -> Self {
self.blocked.push((from.primary, to.primary));
self
}
@@ -551,7 +580,7 @@ impl From<ServerNode<u64, u64>> for EitherNode {
}
impl EitherNode {
fn poll_transmit(&mut self) -> Option<Transmit> {
fn poll_transmit(&mut self) -> Option<Transmit<'static>> {
match self {
EitherNode::Client(n) => n.poll_transmit(),
EitherNode::Server(n) => n.poll_transmit(),
@@ -586,10 +615,10 @@ impl EitherNode {
}
}
fn is_connected_to(&self, key: PublicKey) -> bool {
fn connection_id(&self, key: PublicKey) -> Option<u64> {
match self {
EitherNode::Client(n) => n.is_connected_to(key),
EitherNode::Server(n) => n.is_connected_to(key),
EitherNode::Client(n) => n.connection_id(key),
EitherNode::Server(n) => n.connection_id(key),
}
}
@@ -614,6 +643,18 @@ impl EitherNode {
}
}
fn encapsulate<'s>(
&'s mut self,
connection: u64,
packet: IpPacket<'_>,
now: Instant,
) -> Result<Option<Transmit<'s>>, snownet::Error> {
match self {
EitherNode::Client(n) => n.encapsulate(connection, packet, now),
EitherNode::Server(n) => n.encapsulate(connection, packet, now),
}
}
fn decapsulate<'s>(
&mut self,
local: SocketAddr,
@@ -655,10 +696,11 @@ impl TestNode {
primary,
local: vec![primary],
events: Default::default(),
transmits: Default::default(),
}
}
fn with_relays(mut self, relays: &mut [TestRelay], now: Instant) -> Self {
fn with_relays(mut self, relays: &mut [(u64, TestRelay)], now: Instant) -> Self {
let username = match self.node {
EitherNode::Server(_) => "server",
EitherNode::Client(_) => "client",
@@ -666,12 +708,11 @@ impl TestNode {
let turn_servers = relays
.iter()
.enumerate()
.map(|(idx, relay)| {
let (username, password) = relay.make_credentials(username);
(
idx as u64,
*idx,
relay.listen_addr,
username,
password,
@@ -680,10 +721,10 @@ impl TestNode {
})
.collect::<HashSet<_>>();
match &mut self.node {
self.span.in_scope(|| match &mut self.node {
EitherNode::Server(s) => s.upsert_turn_servers(&turn_servers, now),
EitherNode::Client(c) => c.upsert_turn_servers(&turn_servers, now),
}
});
self
}
@@ -694,7 +735,26 @@ impl TestNode {
}
fn is_connected_to(&self, other: &TestNode) -> bool {
self.node.is_connected_to(other.node.public_key())
self.node.connection_id(other.node.public_key()).is_some()
}
fn ping(&mut self, src: IpAddr, dst: IpAddr, other: &TestNode, now: Instant) {
let id = self
.node
.connection_id(other.node.public_key())
.expect("cannot ping not-connected node");
let transmit = self
.span
.in_scope(|| {
self.node
.encapsulate(id, icmp_request_packet(src, dst).to_immutable(), now)
})
.unwrap()
.unwrap()
.into_owned();
self.transmits.push_back(transmit);
}
fn signalled_candidates(&self) -> impl Iterator<Item = (u64, Candidate, Instant)> + '_ {
@@ -711,6 +771,12 @@ impl TestNode {
})
}
fn packets_from(&self, src: IpAddr) -> impl Iterator<Item = &IpPacket<'static>> {
self.received_packets
.iter()
.filter(move |p| p.source() == src)
}
fn failed_connections(&self) -> impl Iterator<Item = (u64, Instant)> + '_ {
self.events.iter().filter_map(|(e, instant)| match e {
Event::ConnectionFailed(id) => Some((*id, *instant)),
@@ -728,7 +794,7 @@ impl TestNode {
})
.unwrap()
{
self.received_packets.push(packet.to_owned())
self.received_packets.push(packet.to_immutable().to_owned())
}
}
@@ -740,7 +806,9 @@ impl TestNode {
Event::SignalIceCandidate {
connection,
candidate,
} => other.node.add_remote_candidate(connection, candidate, now),
} => other
.span
.in_scope(|| other.node.add_remote_candidate(connection, candidate, now)),
Event::ConnectionEstablished(_) => {}
Event::ConnectionFailed(_) => {}
};
@@ -750,22 +818,23 @@ impl TestNode {
fn drain_transmits(
&mut self,
other: &mut TestNode,
relays: &mut [TestRelay],
relays: &mut [(u64, TestRelay)],
firewall: &Firewall,
now: Instant,
) {
while let Some(trans) = self.span.in_scope(|| self.node.poll_transmit()) {
for trans in iter::from_fn(|| self.node.poll_transmit()).chain(self.transmits.drain(..)) {
let payload = &trans.payload;
let dst = trans.dst;
if let Some(relay) = relays.iter_mut().find(|r| r.wants(trans.dst)) {
if let Some((_, relay)) = relays.iter_mut().find(|(_, r)| r.wants(trans.dst)) {
relay.handle_packet(payload, self.primary, dst, other, now);
continue;
}
let src = trans
.src
.expect("transmits without `src` should always be handled by relays");
let Some(src) = trans.src else {
tracing::debug!(target: "router", %dst, "Unknown relay");
continue;
};
// Wasn't traffic for the relay, let's check our firewall.
if firewall.blocked.contains(&(src, dst)) {
@@ -817,7 +886,7 @@ fn handshake(client: &mut TestNode, server: &mut TestNode, clock: &Clock) {
fn progress(
a1: &mut TestNode,
a2: &mut TestNode,
relays: &mut [TestRelay],
relays: &mut [(u64, TestRelay)],
firewall: &Firewall,
clock: &mut Clock,
) {
@@ -829,7 +898,7 @@ fn progress(
a1.drain_transmits(a2, relays, firewall, clock.now);
a2.drain_transmits(a1, relays, firewall, clock.now);
for relay in relays.iter_mut() {
for (_, relay) in relays.iter_mut() {
relay.drain_messages(a1, a2, clock.now);
}
@@ -845,7 +914,7 @@ fn progress(
}
}
for relay in relays {
for (_, relay) in relays {
if let Some(timeout) = relay.inner.poll_timeout() {
if clock.now >= timeout {
relay
@@ -855,3 +924,90 @@ fn progress(
}
}
}
fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> {
match (source, dst) {
(IpAddr::V4(src), IpAddr::V4(dst)) => {
use pnet_packet::{
icmp::{
echo_request::{IcmpCodes, MutableEchoRequestPacket},
IcmpTypes, MutableIcmpPacket,
},
ip::IpNextHeaderProtocols,
ipv4::MutableIpv4Packet,
MutablePacket as _, Packet as _,
};
let mut buf = vec![0u8; 60];
let mut ipv4_packet = MutableIpv4Packet::new(&mut buf[..]).unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length(60);
ipv4_packet.set_ttl(64);
ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Icmp);
ipv4_packet.set_source(src);
ipv4_packet.set_destination(dst);
ipv4_packet.set_checksum(pnet_packet::ipv4::checksum(&ipv4_packet.to_immutable()));
let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap();
icmp_packet.set_icmp_type(IcmpTypes::EchoRequest);
icmp_packet.set_icmp_code(IcmpCodes::NoCode);
icmp_packet.set_checksum(0);
let mut echo_request_packet =
MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap();
echo_request_packet.set_sequence_number(1);
echo_request_packet.set_identifier(0);
echo_request_packet.set_checksum(pnet_packet::util::checksum(
echo_request_packet.to_immutable().packet(),
2,
));
MutableIpPacket::owned(buf).unwrap()
}
(IpAddr::V6(src), IpAddr::V6(dst)) => {
use pnet_packet::{
icmpv6::{
echo_request::MutableEchoRequestPacket, Icmpv6Code, Icmpv6Types,
MutableIcmpv6Packet,
},
ip::IpNextHeaderProtocols,
ipv6::MutableIpv6Packet,
MutablePacket as _,
};
let mut buf = vec![0u8; 128];
let mut ipv6_packet = MutableIpv6Packet::new(&mut buf[..]).unwrap();
ipv6_packet.set_version(6);
ipv6_packet.set_payload_length(16);
ipv6_packet.set_next_header(IpNextHeaderProtocols::Icmpv6);
ipv6_packet.set_hop_limit(64);
ipv6_packet.set_source(src);
ipv6_packet.set_destination(dst);
let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap();
icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest);
icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); // No code for echo request
let mut echo_request_packet =
MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap();
echo_request_packet.set_identifier(0);
echo_request_packet.set_sequence_number(1);
echo_request_packet.set_checksum(0);
let checksum = pnet_packet::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst);
MutableEchoRequestPacket::new(icmp_packet.payload_mut())
.unwrap()
.set_checksum(checksum);
MutableIpPacket::owned(buf).unwrap()
}
(IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => {
panic!("IPs must be of the same version")
}
}
}