From 7adbf9c6af3cf3e71792caa0fae4bdf833b6a0ac Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 11 Sep 2024 19:52:48 -0400 Subject: [PATCH] refactor(connlib): remove `pnet_packet` (#6659) As the final step in removing `pnet_packet`, we need to introduce `-Mut` equivalent slices for UDP, TCP and ICMP packets. As a starting point, introducing `UpdHeaderSliceMut` and `TcpHeaderSliceMut` is fairly trivial. The ICMP variants are a bit trickier because those are different for IPv4 and IPv6. Additionally, ICMP for IPv4 is quite complex because it can have a variable header length. Additionally. for both variants, the values in byte range 5-8 are semantically different depending on the ICMP code. This requires us to design an API that balances ergonomics and correctness. Technically, an ICMP identifier and sequence can only be set if the ICMP code is "echo request" or "echo reply". However, adding an additional parsing step to guarantee this in the type system is quite verbose. The trade-off implemented in this PR allows to us to directly write to the byte 5-8 using the `set_identifier` and `set_sequence` functions. To catch errors early, this functions have debug-assertions built in that ensure that the packet is indeed an ICMP echo packet. Resolves: #6366. --- rust/Cargo.lock | 49 -- rust/bin-shared/benches/tunnel.rs | 2 +- rust/connlib/snownet/src/node.rs | 2 +- rust/connlib/tunnel/src/device_channel.rs | 4 +- rust/connlib/tunnel/src/peer.rs | 2 +- rust/connlib/tunnel/src/tests/sim_client.rs | 40 +- rust/connlib/tunnel/src/tests/sim_gateway.rs | 63 ++- rust/ip-packet/Cargo.toml | 1 - rust/ip-packet/src/icmpv4_header_slice_mut.rs | 81 ++++ rust/ip-packet/src/icmpv6_header_slice_mut.rs | 81 ++++ rust/ip-packet/src/lib.rs | 434 ++++++++---------- rust/ip-packet/src/proptests.rs | 1 - rust/ip-packet/src/slice_utils.rs | 2 +- rust/ip-packet/src/tcp_header_slice_mut.rs | 58 +++ rust/ip-packet/src/udp_header_slice_mut.rs | 65 +++ 15 files changed, 558 insertions(+), 327 deletions(-) create mode 100644 rust/ip-packet/src/icmpv4_header_slice_mut.rs create mode 100644 rust/ip-packet/src/icmpv6_header_slice_mut.rs create mode 100644 rust/ip-packet/src/tcp_header_slice_mut.rs create mode 100644 rust/ip-packet/src/udp_header_slice_mut.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index dca77f8d7..4aac70edc 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -3086,7 +3086,6 @@ dependencies = [ "anyhow", "domain", "etherparse", - "pnet_packet", "proptest", "test-strategy", "thiserror", @@ -3822,12 +3821,6 @@ dependencies = [ "memoffset 0.9.1", ] -[[package]] -name = "no-std-net" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65" - [[package]] name = "nodrop" version = "0.1.14" @@ -4528,48 +4521,6 @@ dependencies = [ "time", ] -[[package]] -name = "pnet_base" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc190d4067df16af3aba49b3b74c469e611cad6314676eaf1157f31aa0fb2f7" -dependencies = [ - "no-std-net", -] - -[[package]] -name = "pnet_macros" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13325ac86ee1a80a480b0bc8e3d30c25d133616112bb16e86f712dcf8a71c863" -dependencies = [ - "proc-macro2", - "quote", - "regex", - "syn 2.0.72", -] - -[[package]] -name = "pnet_macros_support" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed67a952585d509dd0003049b1fc56b982ac665c8299b124b90ea2bdb3134ab" -dependencies = [ - "pnet_base", -] - -[[package]] -name = "pnet_packet" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c96ebadfab635fcc23036ba30a7d33a80c39e8461b8bd7dc7bb186acb96560f" -dependencies = [ - "glob", - "pnet_base", - "pnet_macros", - "pnet_macros_support", -] - [[package]] name = "png" version = "0.17.13" diff --git a/rust/bin-shared/benches/tunnel.rs b/rust/bin-shared/benches/tunnel.rs index 7fb275e64..9e49a901c 100644 --- a/rust/bin-shared/benches/tunnel.rs +++ b/rust/bin-shared/benches/tunnel.rs @@ -24,7 +24,7 @@ mod platform { mod platform { use anyhow::Result; use firezone_bin_shared::TunDeviceManager; - use ip_packet::{IpPacket, Packet as _}; + use ip_packet::IpPacket; use std::{ future::poll_fn, net::{Ipv4Addr, Ipv6Addr}, diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index a9106e1a2..6e070f7aa 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -9,7 +9,7 @@ use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::fmt; use hex_display::HexDisplayExt; -use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, Packet as _}; +use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket}; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::{random, SeedableRng}; diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index e72d43621..1844f27c6 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -1,4 +1,4 @@ -use ip_packet::{IpPacket, Packet as _}; +use ip_packet::IpPacket; use std::io; use std::task::{Context, Poll, Waker}; use tun::Tun; @@ -31,8 +31,6 @@ impl Device { buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>> { - use ip_packet::Packet as _; - let Some(tun) = self.tun.as_mut() else { self.waker = Some(cx.waker().clone()); return Poll::Pending; diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 9e8708149..0d8cb910c 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -77,7 +77,7 @@ impl AllowRules { return self.udp.contains(&udp.destination_port()); } - if packet.is_icmp_v4_or_v6() { + if packet.is_icmp() || packet.is_icmpv6() { return self.icmp; } diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 429f61fe1..111e30471 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -24,7 +24,7 @@ use domain::{ }; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; -use ip_packet::IpPacket; +use ip_packet::{Icmpv4Type, Icmpv6Type, IpPacket}; use itertools::Itertools as _; use prop::collection; use proptest::prelude::*; @@ -123,12 +123,17 @@ impl SimClient { packet: IpPacket<'static>, now: Instant, ) -> Option> { - { - if let Some(icmp) = packet.as_icmp() { - let echo_request = icmp.echo_request_header().expect("to be echo request"); - + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { self.sent_icmp_requests - .insert((echo_request.seq, echo_request.id), packet.clone()); + .insert((echo.seq, echo.id), packet.clone()); + } + } + + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { + self.sent_icmp_requests + .insert((echo.seq, echo.id), packet.clone()); } } @@ -175,14 +180,23 @@ impl SimClient { /// Process an IP packet received on the client. pub(crate) fn on_received_packet(&mut self, packet: IpPacket<'static>) { - if let Some(icmp) = packet.as_icmp() { - let echo_reply = icmp.echo_reply_header().expect("to be echo reply"); + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoReply(echo) = icmp.icmp_type() { + self.received_icmp_replies + .insert((echo.seq, echo.id), packet.clone()); - self.received_icmp_replies - .insert((echo_reply.seq, echo_reply.id), packet); + return; + } + } - return; - }; + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoReply(echo) = icmp.icmp_type() { + self.received_icmp_replies + .insert((echo.seq, echo.id), packet.clone()); + + return; + } + } if let Some(udp) = packet.as_udp() { if udp.source_port() == 53 { @@ -225,7 +239,7 @@ impl SimClient { } } - tracing::error!("Unhandled packet"); + tracing::error!(?packet, "Unhandled packet"); } pub(crate) fn update_relays<'a>( diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 901fb78fd..062e7cb6a 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -9,7 +9,7 @@ use connlib_shared::{ messages::{GatewayId, RelayId}, DomainName, }; -use ip_packet::IpPacket; +use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket}; use proptest::prelude::*; use snownet::{EncryptBuffer, Transmit}; use std::{ @@ -70,29 +70,15 @@ impl SimGateway { ) -> Option> { // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? - if let Some(icmp) = packet.as_icmp() { - if let Some(echo_request) = icmp.echo_request_header() { - let payload = icmp.payload(); - let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap()); - tracing::debug!(%echo_id, "Received ICMP request"); + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { + return self.handle_icmp_request(&packet, echo, icmp.payload(), now); + } + } - self.received_icmp_requests.insert(echo_id, packet.clone()); - - let echo_response = ip_packet::make::icmp_reply_packet( - packet.destination(), - packet.source(), - echo_request.seq, - echo_request.id, - payload, - ) - .expect("src and dst are taken from incoming packet"); - let transmit = self - .sut - .encapsulate(echo_response, now, &mut self.enc_buffer)? - .to_transmit(&self.enc_buffer) - .into_owned(); - - return Some(transmit); + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { + return self.handle_icmp_request(&packet, echo, icmp.payload(), now); } } @@ -110,7 +96,7 @@ impl SimGateway { return Some(transmit); } - tracing::error!("Unhandled packet"); + tracing::error!(?packet, "Unhandled packet"); None } @@ -126,6 +112,35 @@ impl SimGateway { now, ) } + + fn handle_icmp_request( + &mut self, + packet: &IpPacket<'static>, + echo: IcmpEchoHeader, + payload: &[u8], + now: Instant, + ) -> Option> { + let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap()); + self.received_icmp_requests.insert(echo_id, packet.clone()); + + tracing::debug!(%echo_id, "Received ICMP request"); + + let echo_response = ip_packet::make::icmp_reply_packet( + packet.destination(), + packet.source(), + echo.seq, + echo.id, + payload, + ) + .expect("src and dst are taken from incoming packet"); + let transmit = self + .sut + .encapsulate(echo_response, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); + + Some(transmit) + } } /// Reference state for a particular gateway. diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 8a3cf5837..74748a8bf 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -13,7 +13,6 @@ proptest = ["dep:proptest"] anyhow = "1.0.86" domain = "0.10.1" etherparse = "0.15" -pnet_packet = { version = "0.35" } proptest = { version = "1", optional = true } thiserror = "1" tracing = "0.1" diff --git a/rust/ip-packet/src/icmpv4_header_slice_mut.rs b/rust/ip-packet/src/icmpv4_header_slice_mut.rs new file mode 100644 index 000000000..34c3f8a39 --- /dev/null +++ b/rust/ip-packet/src/icmpv4_header_slice_mut.rs @@ -0,0 +1,81 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::{ + icmpv4::{TYPE_ECHO_REPLY, TYPE_ECHO_REQUEST}, + Icmpv4Slice, +}; + +pub struct Icmpv4HeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> Icmpv4HeaderSliceMut<'a> { + /// Creates a new [`Icmpv4HeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + Icmpv4Slice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, checksum.to_be_bytes()) }; + } + + pub fn set_identifier(&mut self, id: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP identifier only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 4, id.to_be_bytes()) }; + } + + pub fn set_sequence(&mut self, seq: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP sequence only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 6, seq.to_be_bytes()) }; + } + + fn is_echo_request_or_reply(&self) -> bool { + let ty = self.slice[0]; + + ty == TYPE_ECHO_REPLY || ty == TYPE_ECHO_REQUEST + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::{Icmpv4Type, PacketBuilder}; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0) + .icmpv4_echo_request(10, 20) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = Icmpv4HeaderSliceMut::from_slice(&mut buf[20..]).unwrap(); + + slice.set_identifier(30); + slice.set_sequence(40); + slice.set_checksum(50); + + let slice = Icmpv4Slice::from_slice(&buf[20..]).unwrap(); + + let Icmpv4Type::EchoRequest(header) = slice.header().icmp_type else { + panic!("Unexpected ICMP header"); + }; + + assert_eq!(header.id, 30); + assert_eq!(header.seq, 40); + assert_eq!(slice.checksum(), 50); + } +} diff --git a/rust/ip-packet/src/icmpv6_header_slice_mut.rs b/rust/ip-packet/src/icmpv6_header_slice_mut.rs new file mode 100644 index 000000000..a9d7dd793 --- /dev/null +++ b/rust/ip-packet/src/icmpv6_header_slice_mut.rs @@ -0,0 +1,81 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::{ + icmpv6::{TYPE_ECHO_REPLY, TYPE_ECHO_REQUEST}, + Icmpv6Slice, +}; + +pub struct Icmpv6EchoHeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> Icmpv6EchoHeaderSliceMut<'a> { + /// Creates a new [`Icmpv6EchoHeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + Icmpv6Slice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, checksum.to_be_bytes()) }; + } + + pub fn set_identifier(&mut self, id: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP identifier only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 4, id.to_be_bytes()) }; + } + + pub fn set_sequence(&mut self, seq: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP sequence only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 6, seq.to_be_bytes()) }; + } + + fn is_echo_request_or_reply(&self) -> bool { + let ty = self.slice[0]; + + ty == TYPE_ECHO_REPLY || ty == TYPE_ECHO_REQUEST + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::{Icmpv6Type, PacketBuilder}; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv6([0u8; 16], [0u8; 16], 0) + .icmpv6_echo_request(10, 20) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = Icmpv6EchoHeaderSliceMut::from_slice(&mut buf[40..]).unwrap(); + + slice.set_identifier(30); + slice.set_sequence(40); + slice.set_checksum(50); + + let slice = Icmpv6Slice::from_slice(&buf[40..]).unwrap(); + + let Icmpv6Type::EchoRequest(header) = slice.header().icmp_type else { + panic!("Unexpected ICMP header"); + }; + + assert_eq!(header.id, 30); + assert_eq!(header.seq, 40); + assert_eq!(slice.checksum(), 50); + } +} diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 53d668518..dbbe1cc9f 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -1,5 +1,7 @@ pub mod make; +mod icmpv4_header_slice_mut; +mod icmpv6_header_slice_mut; mod ipv4_header_slice_mut; mod ipv6_header_slice_mut; mod nat46; @@ -7,31 +9,24 @@ mod nat64; #[cfg(feature = "proptest")] pub mod proptest; mod slice_utils; +mod tcp_header_slice_mut; +mod udp_header_slice_mut; -pub use pnet_packet::*; +pub use etherparse::*; #[cfg(all(test, feature = "proptest"))] mod proptests; -use etherparse::{ - IcmpEchoHeader, Icmpv4Slice, Icmpv4Type, Icmpv6Slice, Icmpv6Type, IpNumber, Ipv4Header, - Ipv4HeaderSlice, Ipv6Header, Ipv6HeaderSlice, TcpSlice, UdpSlice, -}; +use icmpv4_header_slice_mut::Icmpv4HeaderSliceMut; +use icmpv6_header_slice_mut::Icmpv6EchoHeaderSliceMut; use ipv4_header_slice_mut::Ipv4HeaderSliceMut; use ipv6_header_slice_mut::Ipv6HeaderSliceMut; -use pnet_packet::{ - icmp::{ - echo_reply::MutableEchoReplyPacket, echo_request::MutableEchoRequestPacket, IcmpTypes, - MutableIcmpPacket, - }, - icmpv6::{Icmpv6Types, MutableIcmpv6Packet}, - tcp::MutableTcpPacket, - udp::MutableUdpPacket, -}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, ops::{Deref, DerefMut}, }; +use tcp_header_slice_mut::TcpHeaderSliceMut; +use udp_header_slice_mut::UdpHeaderSliceMut; macro_rules! for_both { ($this:ident, |$name:ident| $body:expr) => { @@ -79,81 +74,21 @@ impl Protocol { } } -#[derive(Debug, PartialEq)] -pub enum IcmpPacket<'a> { - Ipv4(Icmpv4Slice<'a>), - Ipv6(Icmpv6Slice<'a>), -} - -impl<'a> IcmpPacket<'a> { - pub fn icmp_type(&self) -> IcmpType { - match self { - IcmpPacket::Ipv4(v4) => IcmpType::V4(v4.icmp_type()), - IcmpPacket::Ipv6(v6) => IcmpType::V6(v6.icmp_type()), - } - } - - pub fn identifier(&self) -> Option { - Some(self.echo_request_header().or(self.echo_reply_header())?.id) - } - - pub fn sequence(&self) -> Option { - Some(self.echo_request_header().or(self.echo_reply_header())?.seq) - } - - pub fn payload(&self) -> &[u8] { - match self { - IcmpPacket::Ipv4(v4) => v4.payload(), - IcmpPacket::Ipv6(v6) => v6.payload(), - } - } - - pub fn echo_request_header(&self) -> Option { - #[allow( - clippy::wildcard_enum_match_arm, - reason = "We won't ever need to use other ICMP types here." - )] - match self { - IcmpPacket::Ipv4(v4) => match v4.header().icmp_type { - Icmpv4Type::EchoRequest(echo) => Some(echo), - _ => None, - }, - IcmpPacket::Ipv6(v6) => match v6.header().icmp_type { - Icmpv6Type::EchoRequest(echo) => Some(echo), - _ => None, - }, - } - } - - pub fn echo_reply_header(&self) -> Option { - #[allow( - clippy::wildcard_enum_match_arm, - reason = "We won't ever need to use other ICMP types here." - )] - match self { - IcmpPacket::Ipv4(v4) => match v4.header().icmp_type { - Icmpv4Type::EchoReply(echo) => Some(echo), - _ => None, - }, - IcmpPacket::Ipv6(v6) => match v6.header().icmp_type { - Icmpv6Type::EchoReply(echo) => Some(echo), - _ => None, - }, - } - } -} - -pub enum IcmpType { - V4(Icmpv4Type), - V6(Icmpv6Type), -} - -#[derive(Debug, PartialEq, Clone)] +#[derive(PartialEq, Clone)] pub enum IpPacket<'a> { Ipv4(ConvertibleIpv4Packet<'a>), Ipv6(ConvertibleIpv6Packet<'a>), } +impl<'a> std::fmt::Debug for IpPacket<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ipv4(arg0) => arg0.ip_header().to_header().fmt(f), + Self::Ipv6(arg0) => arg0.header().to_header().fmt(f), + } + } +} + #[derive(Debug, PartialEq)] enum MaybeOwned<'a> { RefMut(&'a mut [u8]), @@ -222,11 +157,11 @@ impl<'a> ConvertibleIpv4Packet<'a> { } fn ip_header(&self) -> Ipv4HeaderSlice { - Ipv4HeaderSlice::from_slice(&self.buf[20..]).expect("we checked this during `new`") + Ipv4HeaderSlice::from_slice(self.packet()).expect("we checked this during `new`") } fn ip_header_mut(&mut self) -> Ipv4HeaderSliceMut { - Ipv4HeaderSliceMut::from_slice(&mut self.buf[20..]).expect("we checked this during `new`") + Ipv4HeaderSliceMut::from_slice(self.packet_mut()).expect("we checked this during `new`") } pub fn get_source(&self) -> Ipv4Addr { @@ -253,27 +188,14 @@ impl<'a> ConvertibleIpv4Packet<'a> { fn header_length(&self) -> usize { (self.ip_header().ihl() * 4) as usize } -} -impl<'a> Packet for ConvertibleIpv4Packet<'a> { - fn packet(&self) -> &[u8] { + pub fn packet(&self) -> &[u8] { &self.buf[20..] } - fn payload(&self) -> &[u8] { - &self.buf[(self.header_length() + 20)..] - } -} - -impl<'a> MutablePacket for ConvertibleIpv4Packet<'a> { fn packet_mut(&mut self) -> &mut [u8] { &mut self.buf[20..] } - - fn payload_mut(&mut self) -> &mut [u8] { - let header_len = self.header_length(); - &mut self.buf[(header_len + 20)..] - } } #[derive(Debug, PartialEq, Clone)] @@ -299,11 +221,12 @@ impl<'a> ConvertibleIpv6Packet<'a> { } fn header(&self) -> Ipv6HeaderSlice { - Ipv6HeaderSlice::from_slice(&self.buf).expect("We checked this in `new` / `owned`") + Ipv6HeaderSlice::from_slice(self.packet()).expect("We checked this in `new` / `owned`") } fn header_mut(&mut self) -> Ipv6HeaderSliceMut { - Ipv6HeaderSliceMut::from_slice(&mut self.buf).expect("We checked this in `new` / `owned`") + Ipv6HeaderSliceMut::from_slice(self.packet_mut()) + .expect("We checked this in `new` / `owned`") } pub fn get_source(&self) -> Ipv6Addr { @@ -325,26 +248,14 @@ impl<'a> ConvertibleIpv6Packet<'a> { Some(ConvertibleIpv4Packet { buf: self.buf }) } -} -impl<'a> Packet for ConvertibleIpv6Packet<'a> { - fn packet(&self) -> &[u8] { + pub fn packet(&self) -> &[u8] { &self.buf } - fn payload(&self) -> &[u8] { - &self.buf[Ipv6Header::LEN..] - } -} - -impl<'a> MutablePacket for ConvertibleIpv6Packet<'a> { fn packet_mut(&mut self) -> &mut [u8] { &mut self.buf } - - fn payload_mut(&mut self) -> &mut [u8] { - &mut self.buf[Ipv6Header::LEN..] - } } pub fn ipv4_embedded(ip: Ipv4Addr) -> Ipv6Addr { @@ -444,11 +355,20 @@ impl<'a> IpPacket<'a> { return Ok(Protocol::Udp(p.source_port())); } - if let Some(p) = self.as_icmp() { - let id = p.identifier().ok_or_else(|| match p.icmp_type() { - IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4), - IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6), - })?; + if let Some(p) = self.as_icmpv4() { + let id = self + .icmpv4_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv4Type(p.icmp_type()))? + .id; + + return Ok(Protocol::Icmp(id)); + } + + if let Some(p) = self.as_icmpv6() { + let id = self + .icmpv6_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv6Type(p.icmp_type()))? + .id; return Ok(Protocol::Icmp(id)); } @@ -467,11 +387,20 @@ impl<'a> IpPacket<'a> { return Ok(Protocol::Udp(p.destination_port())); } - if let Some(p) = self.as_icmp() { - let id = p.identifier().ok_or_else(|| match p.icmp_type() { - IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4), - IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6), - })?; + if let Some(p) = self.as_icmpv4() { + let id = self + .icmpv4_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv4Type(p.icmp_type()))? + .id; + + return Ok(Protocol::Icmp(id)); + } + + if let Some(p) = self.as_icmpv6() { + let id = self + .icmpv6_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv6Type(p.icmp_type()))? + .id; return Ok(Protocol::Icmp(id)); } @@ -483,11 +412,11 @@ impl<'a> IpPacket<'a> { pub fn set_source_protocol(&mut self, v: u16) { if let Some(mut p) = self.as_tcp_mut() { - p.set_source(v); + p.set_source_port(v); } if let Some(mut p) = self.as_udp_mut() { - p.set_source(v); + p.set_source_port(v); } self.set_icmp_identifier(v); @@ -495,51 +424,23 @@ impl<'a> IpPacket<'a> { pub fn set_destination_protocol(&mut self, v: u16) { if let Some(mut p) = self.as_tcp_mut() { - p.set_destination(v); + p.set_destination_port(v); } if let Some(mut p) = self.as_udp_mut() { - p.set_destination(v); + p.set_destination_port(v); } self.set_icmp_identifier(v); } fn set_icmp_identifier(&mut self, v: u16) { - if let Some(mut p) = self.as_icmp_mut() { - if p.get_icmp_type() == IcmpTypes::EchoReply { - let Some(mut echo_reply) = MutableEchoReplyPacket::new(p.packet_mut()) else { - return; - }; - echo_reply.set_identifier(v) - } - - if p.get_icmp_type() == IcmpTypes::EchoRequest { - let Some(mut echo_request) = MutableEchoRequestPacket::new(p.packet_mut()) else { - return; - }; - echo_request.set_identifier(v); - } + if let Some(mut p) = self.as_icmpv4_mut() { + p.set_identifier(v); } - if let Some(mut p) = self.as_icmpv6() { - if p.get_icmpv6_type() == Icmpv6Types::EchoReply { - let Some(mut echo_reply) = - icmpv6::echo_reply::MutableEchoReplyPacket::new(p.packet_mut()) - else { - return; - }; - echo_reply.set_identifier(v) - } - - if p.get_icmpv6_type() == Icmpv6Types::EchoRequest { - let Some(mut echo_request) = - icmpv6::echo_request::MutableEchoRequestPacket::new(p.packet_mut()) - else { - return; - }; - echo_request.set_identifier(v); - } + if let Some(mut p) = self.as_icmpv6_mut() { + p.set_identifier(v); } } @@ -605,76 +506,132 @@ impl<'a> IpPacket<'a> { } pub fn as_udp(&self) -> Option { - self.is_udp() - .then(|| UdpSlice::from_slice(self.payload()).ok()) - .flatten() + if !self.is_udp() { + return None; + } + + UdpSlice::from_slice(self.payload()).ok() } - pub fn as_udp_mut(&mut self) -> Option { - self.is_udp() - .then(|| MutableUdpPacket::new(self.payload_mut())) - .flatten() + pub fn as_udp_mut(&mut self) -> Option { + if !self.is_udp() { + return None; + } + + UdpHeaderSliceMut::from_slice(self.payload_mut()).ok() } pub fn as_tcp(&self) -> Option { - self.is_tcp() - .then(|| TcpSlice::from_slice(self.payload()).ok()) - .flatten() - } - - pub fn as_tcp_mut(&mut self) -> Option { - self.is_tcp() - .then(|| MutableTcpPacket::new(self.payload_mut())) - .flatten() - } - - pub fn is_icmp_v4_or_v6(&self) -> bool { - match self { - IpPacket::Ipv4(v4) => v4.ip_header().protocol() == IpNumber::ICMP, - IpPacket::Ipv6(v6) => v6.header().next_header() == IpNumber::IPV6_ICMP, + if !self.is_tcp() { + return None; } + + TcpSlice::from_slice(self.payload()).ok() + } + + pub fn as_tcp_mut(&mut self) -> Option { + if !self.is_tcp() { + return None; + } + + TcpHeaderSliceMut::from_slice(self.payload_mut()).ok() } fn set_icmpv6_checksum(&mut self) { - let (src_addr, dst_addr) = match self { - IpPacket::Ipv4(_) => return, - IpPacket::Ipv6(p) => (p.get_source(), p.get_destination()), + let Some(i) = self.as_icmpv6() else { + return; }; - if let Some(mut pkt) = self.as_icmpv6() { - let checksum = icmpv6::checksum(&pkt.to_immutable(), &src_addr, &dst_addr); - pkt.set_checksum(checksum); - } + + let IpPacket::Ipv6(p) = &self else { + return; + }; + + let checksum = i + .icmp_type() + .calc_checksum( + p.get_source().octets(), + p.get_destination().octets(), + i.payload(), + ) + .expect("Payload came from the original packet"); + + let Some(mut i) = self.as_icmpv6_mut() else { + return; + }; + + i.set_checksum(checksum); } fn set_icmpv4_checksum(&mut self) { - if let Some(mut pkt) = self.as_icmp_mut() { - let checksum = icmp::checksum(&pkt.to_immutable()); - pkt.set_checksum(checksum); + let Some(i) = self.as_icmpv4() else { + return; + }; + + let checksum = i.icmp_type().calc_checksum(i.payload()); + + let Some(mut i) = self.as_icmpv4_mut() else { + return; + }; + + i.set_checksum(checksum); + } + + pub fn as_icmpv4(&self) -> Option { + if !self.is_icmp() { + return None; } + + Icmpv4Slice::from_slice(self.payload()).ok() } - pub fn as_icmp(&self) -> Option { - match self { - Self::Ipv4(v4) if self.is_icmp() => Some(IcmpPacket::Ipv4( - Icmpv4Slice::from_slice(v4.payload()).ok()?, - )), - Self::Ipv6(v6) if self.is_icmpv6() => Some(IcmpPacket::Ipv6( - Icmpv6Slice::from_slice(v6.payload()).ok()?, - )), - Self::Ipv4(_) | Self::Ipv6(_) => None, + pub fn as_icmpv4_mut(&mut self) -> Option { + if !self.is_icmp() { + return None; } + + Icmpv4HeaderSliceMut::from_slice(self.payload_mut()).ok() } - pub fn as_icmp_mut(&mut self) -> Option { - self.is_icmp() - .then(|| MutableIcmpPacket::new(self.payload_mut())) - .flatten() + pub fn as_icmpv6(&self) -> Option { + if !self.is_icmpv6() { + return None; + } + + Icmpv6Slice::from_slice(self.payload()).ok() } - fn as_icmpv6(&mut self) -> Option { - self.is_icmpv6() - .then(|| MutableIcmpv6Packet::new(self.payload_mut())) - .flatten() + pub fn as_icmpv6_mut(&mut self) -> Option { + if !self.is_icmpv6() { + return None; + } + + Icmpv6EchoHeaderSliceMut::from_slice(self.payload_mut()).ok() + } + + fn icmpv4_echo_header(&self) -> Option { + let p = self.as_icmpv4()?; + + use Icmpv4Type::*; + let icmp_type = p.icmp_type(); + + let (EchoReply(header) | EchoRequest(header)) = icmp_type else { + return None; + }; + + Some(header) + } + + fn icmpv6_echo_header(&self) -> Option { + let p = self.as_icmpv6()?; + + use Icmpv6Type::*; + let icmp_type = p.icmp_type(); + + let (EchoReply(header) | EchoRequest(header)) = icmp_type else { + return None; + }; + + Some(header) } pub fn translate_destination( @@ -790,13 +747,46 @@ impl<'a> IpPacket<'a> { self.next_header() == IpNumber::TCP } - fn is_icmp(&self) -> bool { + pub fn is_icmp(&self) -> bool { self.next_header() == IpNumber::ICMP } - fn is_icmpv6(&self) -> bool { + pub fn is_icmpv6(&self) -> bool { self.next_header() == IpNumber::IPV6_ICMP } + + fn header_length(&self) -> usize { + match self { + IpPacket::Ipv4(v4) => v4.header_length(), + IpPacket::Ipv6(v6) => v6.header().header_len(), + } + } + + pub fn packet(&self) -> &[u8] { + match self { + IpPacket::Ipv4(v4) => v4.packet(), + IpPacket::Ipv6(v6) => v6.packet(), + } + } + + fn packet_mut(&mut self) -> &mut [u8] { + match self { + IpPacket::Ipv4(v4) => v4.packet_mut(), + IpPacket::Ipv6(v6) => v6.packet_mut(), + } + } + + fn payload(&self) -> &[u8] { + let start = self.header_length(); + + &self.packet()[start..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let start = self.header_length(); + + &mut self.packet_mut()[start..] + } } impl<'a> From> for IpPacket<'a> { @@ -811,26 +801,6 @@ impl<'a> From> for IpPacket<'a> { } } -impl pnet_packet::Packet for IpPacket<'_> { - fn packet(&self) -> &[u8] { - for_both!(self, |i| i.packet()) - } - - fn payload(&self) -> &[u8] { - for_both!(self, |i| i.payload()) - } -} - -impl pnet_packet::MutablePacket for IpPacket<'_> { - fn packet_mut(&mut self) -> &mut [u8] { - for_both!(self, |i| i.packet_mut()) - } - - fn payload_mut(&mut self) -> &mut [u8] { - for_both!(self, |i| i.payload_mut()) - } -} - #[derive(Debug, thiserror::Error)] pub enum UnsupportedProtocol { #[error("Unsupported IP protocol: {0:?}")] diff --git a/rust/ip-packet/src/proptests.rs b/rust/ip-packet/src/proptests.rs index d7eaf49c0..c2a69ef96 100644 --- a/rust/ip-packet/src/proptests.rs +++ b/rust/ip-packet/src/proptests.rs @@ -1,6 +1,5 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use pnet_packet::Packet; use proptest::arbitrary::any; use proptest::prop_oneof; use proptest::strategy::Strategy; diff --git a/rust/ip-packet/src/slice_utils.rs b/rust/ip-packet/src/slice_utils.rs index 558131917..8a8e216a2 100644 --- a/rust/ip-packet/src/slice_utils.rs +++ b/rust/ip-packet/src/slice_utils.rs @@ -8,7 +8,7 @@ pub unsafe fn write_to_offset_unchecked( offset: usize, bytes: [u8; N], ) { - debug_assert!(offset + N < slice.len()); + debug_assert!(offset + N <= slice.len()); let (_front, rest) = unsafe { slice.split_at_mut_unchecked(offset) }; let (target, _rest) = unsafe { rest.split_at_mut_unchecked(N) }; diff --git a/rust/ip-packet/src/tcp_header_slice_mut.rs b/rust/ip-packet/src/tcp_header_slice_mut.rs new file mode 100644 index 000000000..0a896af6d --- /dev/null +++ b/rust/ip-packet/src/tcp_header_slice_mut.rs @@ -0,0 +1,58 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::TcpHeaderSlice; + +pub struct TcpHeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> TcpHeaderSliceMut<'a> { + /// Creates a new [`TcpHeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + TcpHeaderSlice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_source_port(&mut self, src: u16) { + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) }; + } + + pub fn set_destination_port(&mut self, dst: u16) { + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, dst.to_be_bytes()) }; + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 16, checksum.to_be_bytes()) }; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::PacketBuilder; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0) + .tcp(10, 20, 0, 0) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = TcpHeaderSliceMut::from_slice(&mut buf[20..]).unwrap(); + + slice.set_source_port(30); + slice.set_destination_port(40); + slice.set_checksum(50); + + let slice = TcpHeaderSlice::from_slice(&buf[20..]).unwrap(); + + assert_eq!(slice.source_port(), 30); + assert_eq!(slice.destination_port(), 40); + assert_eq!(slice.checksum(), 50); + } +} diff --git a/rust/ip-packet/src/udp_header_slice_mut.rs b/rust/ip-packet/src/udp_header_slice_mut.rs new file mode 100644 index 000000000..986b88604 --- /dev/null +++ b/rust/ip-packet/src/udp_header_slice_mut.rs @@ -0,0 +1,65 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::UdpHeaderSlice; + +pub struct UdpHeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> UdpHeaderSliceMut<'a> { + /// Creates a new [`UdpHeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + UdpHeaderSlice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_source_port(&mut self, src: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) }; + } + + pub fn set_destination_port(&mut self, dst: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, dst.to_be_bytes()) }; + } + + pub fn set_length(&mut self, length: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 4, length.to_be_bytes()) }; + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 6, checksum.to_be_bytes()) }; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::PacketBuilder; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0) + .udp(10, 20) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = UdpHeaderSliceMut::from_slice(&mut buf[20..]).unwrap(); + + slice.set_source_port(30); + slice.set_destination_port(40); + slice.set_length(50); + slice.set_checksum(60); + + let slice = UdpHeaderSlice::from_slice(&buf[20..]).unwrap(); + + assert_eq!(slice.source_port(), 30); + assert_eq!(slice.destination_port(), 40); + assert_eq!(slice.length(), 50); + assert_eq!(slice.checksum(), 60); + } +}