From b73b0cf2b7ebde1a15f327f9423d4605fbcde698 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 5 Feb 2024 21:55:23 +1100 Subject: [PATCH] feat(snownet): return `MutableIpPacket` from `decapsulate` (#3555) The user is already passing us a mutable buffer so we might as well give them a `MutableIpPacket` to allow them to further mutate it. Extracted out of #3391. --- rust/connlib/snownet/src/ip_packet.rs | 82 ++++++++++++++++++++++++++- rust/connlib/snownet/src/lib.rs | 2 +- rust/connlib/snownet/src/node.rs | 14 +++-- rust/snownet-tests/src/main.rs | 2 +- 4 files changed, 89 insertions(+), 11 deletions(-) diff --git a/rust/connlib/snownet/src/ip_packet.rs b/rust/connlib/snownet/src/ip_packet.rs index 91196b25a..b2338af58 100644 --- a/rust/connlib/snownet/src/ip_packet.rs +++ b/rust/connlib/snownet/src/ip_packet.rs @@ -1,12 +1,17 @@ use std::net::IpAddr; -use pnet_packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, ipv6::Ipv6Packet, Packet}; +use pnet_packet::{ + ip::IpNextHeaderProtocols, + ipv4::{Ipv4Packet, MutableIpv4Packet}, + ipv6::{Ipv6Packet, MutableIpv6Packet}, + Packet, +}; macro_rules! for_both { ($this:ident, |$name:ident| $body:expr) => { match $this { - IpPacket::Ipv4($name) => $body, - IpPacket::Ipv6($name) => $body, + Self::Ipv4($name) => $body, + Self::Ipv6($name) => $body, } }; } @@ -17,6 +22,45 @@ pub enum IpPacket<'a> { Ipv6(Ipv6Packet<'a>), } +#[derive(Debug, PartialEq)] +pub enum MutableIpPacket<'a> { + Ipv4(MutableIpv4Packet<'a>), + Ipv6(MutableIpv6Packet<'a>), +} + +impl<'a> MutableIpPacket<'a> { + pub fn new(buf: &'a mut [u8]) -> Option { + match buf[0] >> 4 { + 4 => Some(MutableIpPacket::Ipv4(MutableIpv4Packet::new(buf)?)), + 6 => Some(MutableIpPacket::Ipv6(MutableIpv6Packet::new(buf)?)), + _ => None, + } + } + + pub fn to_owned(&self) -> MutableIpPacket<'static> { + match self { + MutableIpPacket::Ipv4(i) => MutableIpv4Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + MutableIpPacket::Ipv6(i) => MutableIpv6Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + } + } + + pub fn to_immutable(&self) -> IpPacket { + for_both!(self, |i| i.to_immutable().into()) + } + + pub fn source(&self) -> IpAddr { + for_both!(self, |i| i.get_source().into()) + } + + pub fn destination(&self) -> IpAddr { + for_both!(self, |i| i.get_destination().into()) + } +} + impl<'a> IpPacket<'a> { pub fn new(buf: &'a [u8]) -> Option { match buf[0] >> 4 { @@ -70,6 +114,28 @@ impl<'a> From> for IpPacket<'a> { } } +impl<'a> From> for MutableIpPacket<'a> { + fn from(value: MutableIpv4Packet<'a>) -> Self { + Self::Ipv4(value) + } +} + +impl<'a> From> for MutableIpPacket<'a> { + fn from(value: MutableIpv6Packet<'a>) -> Self { + Self::Ipv6(value) + } +} + +impl pnet_packet::Packet for MutableIpPacket<'_> { + fn packet(&self) -> &[u8] { + for_both!(self, |i| i.packet()) + } + + fn payload(&self) -> &[u8] { + for_both!(self, |i| i.payload()) + } +} + impl pnet_packet::Packet for IpPacket<'_> { fn packet(&self) -> &[u8] { for_both!(self, |i| i.packet()) @@ -79,3 +145,13 @@ impl pnet_packet::Packet for IpPacket<'_> { for_both!(self, |i| i.payload()) } } + +impl pnet_packet::MutablePacket for MutableIpPacket<'_> { + fn packet_mut(&mut self) -> &mut [u8] { + for_both!(self, |i| i.packet_mut()) + } + + fn payload_mut(&mut self) -> &mut [u8] { + for_both!(self, |i| i.payload_mut()) + } +} diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index 28950fb91..547c9bd87 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -10,5 +10,5 @@ mod node; mod stun_binding; pub use info::ConnectionInfo; -pub use ip_packet::IpPacket; +pub use ip_packet::{IpPacket, MutableIpPacket}; pub use node::{Answer, ClientNode, Credentials, Error, Event, Node, Offer, ServerNode, Transmit}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 17bc06492..4898ad2d3 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -2,8 +2,8 @@ use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::{fmt, slice}; -use pnet_packet::ipv4::Ipv4Packet; -use pnet_packet::ipv6::Ipv6Packet; +use pnet_packet::ipv4::MutableIpv4Packet; +use pnet_packet::ipv6::MutableIpv6Packet; use pnet_packet::Packet; use rand::random; use secrecy::{ExposeSecret, Secret}; @@ -23,7 +23,7 @@ use crate::allocation::Allocation; use crate::index::IndexLfsr; use crate::info::ConnectionInfo; use crate::stun_binding::StunBinding; -use crate::IpPacket; +use crate::{IpPacket, MutableIpPacket}; use boringtun::noise::errors::WireGuardError; use std::borrow::Cow; use stun_codec::rfc5389::attributes::{Realm, Username}; @@ -189,7 +189,7 @@ where packet: &[u8], now: Instant, buffer: &'s mut [u8], - ) -> Result)>, Error> { + ) -> Result)>, Error> { self.add_local_as_host_candidate(local)?; // First, check if a `StunBinding` wants the packet @@ -268,7 +268,8 @@ where TunnResult::WriteToTunnelV4(packet, ip) => { conn.set_remote_from_wg_activity(local, from, remote_socket); - let ipv4_packet = Ipv4Packet::new(packet).expect("boringtun verifies validity"); + let ipv4_packet = + MutableIpv4Packet::new(packet).expect("boringtun verifies validity"); debug_assert_eq!(ipv4_packet.get_source(), ip); Ok(Some((*id, ipv4_packet.into()))) @@ -276,7 +277,8 @@ where TunnResult::WriteToTunnelV6(packet, ip) => { conn.set_remote_from_wg_activity(local, from, remote_socket); - let ipv6_packet = Ipv6Packet::new(packet).expect("boringtun verifies validity"); + let ipv6_packet = + MutableIpv6Packet::new(packet).expect("boringtun verifies validity"); debug_assert_eq!(ipv6_packet.get_source(), ip); Ok(Some((*id, ipv6_packet.into()))) diff --git a/rust/snownet-tests/src/main.rs b/rust/snownet-tests/src/main.rs index ffb6f5974..cc76b7fe0 100644 --- a/rust/snownet-tests/src/main.rs +++ b/rust/snownet-tests/src/main.rs @@ -428,7 +428,7 @@ impl Eventloop { )? { return Poll::Ready(Ok(Event::Incoming { conn, - packet: packet.to_owned(), + packet: packet.to_immutable().to_owned(), })); }