diff --git a/rust/Cargo.lock b/rust/Cargo.lock index dca77f8d7..4aac70edc 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -3086,7 +3086,6 @@ dependencies = [ "anyhow", "domain", "etherparse", - "pnet_packet", "proptest", "test-strategy", "thiserror", @@ -3822,12 +3821,6 @@ dependencies = [ "memoffset 0.9.1", ] -[[package]] -name = "no-std-net" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65" - [[package]] name = "nodrop" version = "0.1.14" @@ -4528,48 +4521,6 @@ dependencies = [ "time", ] -[[package]] -name = "pnet_base" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc190d4067df16af3aba49b3b74c469e611cad6314676eaf1157f31aa0fb2f7" -dependencies = [ - "no-std-net", -] - -[[package]] -name = "pnet_macros" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13325ac86ee1a80a480b0bc8e3d30c25d133616112bb16e86f712dcf8a71c863" -dependencies = [ - "proc-macro2", - "quote", - "regex", - "syn 2.0.72", -] - -[[package]] -name = "pnet_macros_support" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed67a952585d509dd0003049b1fc56b982ac665c8299b124b90ea2bdb3134ab" -dependencies = [ - "pnet_base", -] - -[[package]] -name = "pnet_packet" -version = "0.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c96ebadfab635fcc23036ba30a7d33a80c39e8461b8bd7dc7bb186acb96560f" -dependencies = [ - "glob", - "pnet_base", - "pnet_macros", - "pnet_macros_support", -] - [[package]] name = "png" version = "0.17.13" diff --git a/rust/bin-shared/benches/tunnel.rs b/rust/bin-shared/benches/tunnel.rs index 7fb275e64..9e49a901c 100644 --- a/rust/bin-shared/benches/tunnel.rs +++ b/rust/bin-shared/benches/tunnel.rs @@ -24,7 +24,7 @@ mod platform { mod platform { use anyhow::Result; use firezone_bin_shared::TunDeviceManager; - use ip_packet::{IpPacket, Packet as _}; + use ip_packet::IpPacket; use std::{ future::poll_fn, net::{Ipv4Addr, Ipv6Addr}, diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index a9106e1a2..6e070f7aa 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -9,7 +9,7 @@ use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::fmt; use hex_display::HexDisplayExt; -use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, Packet as _}; +use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket}; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::{random, SeedableRng}; diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index e72d43621..1844f27c6 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -1,4 +1,4 @@ -use ip_packet::{IpPacket, Packet as _}; +use ip_packet::IpPacket; use std::io; use std::task::{Context, Poll, Waker}; use tun::Tun; @@ -31,8 +31,6 @@ impl Device { buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>> { - use ip_packet::Packet as _; - let Some(tun) = self.tun.as_mut() else { self.waker = Some(cx.waker().clone()); return Poll::Pending; diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 9e8708149..0d8cb910c 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -77,7 +77,7 @@ impl AllowRules { return self.udp.contains(&udp.destination_port()); } - if packet.is_icmp_v4_or_v6() { + if packet.is_icmp() || packet.is_icmpv6() { return self.icmp; } diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 429f61fe1..111e30471 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -24,7 +24,7 @@ use domain::{ }; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; -use ip_packet::IpPacket; +use ip_packet::{Icmpv4Type, Icmpv6Type, IpPacket}; use itertools::Itertools as _; use prop::collection; use proptest::prelude::*; @@ -123,12 +123,17 @@ impl SimClient { packet: IpPacket<'static>, now: Instant, ) -> Option> { - { - if let Some(icmp) = packet.as_icmp() { - let echo_request = icmp.echo_request_header().expect("to be echo request"); - + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { self.sent_icmp_requests - .insert((echo_request.seq, echo_request.id), packet.clone()); + .insert((echo.seq, echo.id), packet.clone()); + } + } + + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { + self.sent_icmp_requests + .insert((echo.seq, echo.id), packet.clone()); } } @@ -175,14 +180,23 @@ impl SimClient { /// Process an IP packet received on the client. pub(crate) fn on_received_packet(&mut self, packet: IpPacket<'static>) { - if let Some(icmp) = packet.as_icmp() { - let echo_reply = icmp.echo_reply_header().expect("to be echo reply"); + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoReply(echo) = icmp.icmp_type() { + self.received_icmp_replies + .insert((echo.seq, echo.id), packet.clone()); - self.received_icmp_replies - .insert((echo_reply.seq, echo_reply.id), packet); + return; + } + } - return; - }; + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoReply(echo) = icmp.icmp_type() { + self.received_icmp_replies + .insert((echo.seq, echo.id), packet.clone()); + + return; + } + } if let Some(udp) = packet.as_udp() { if udp.source_port() == 53 { @@ -225,7 +239,7 @@ impl SimClient { } } - tracing::error!("Unhandled packet"); + tracing::error!(?packet, "Unhandled packet"); } pub(crate) fn update_relays<'a>( diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 901fb78fd..062e7cb6a 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -9,7 +9,7 @@ use connlib_shared::{ messages::{GatewayId, RelayId}, DomainName, }; -use ip_packet::IpPacket; +use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket}; use proptest::prelude::*; use snownet::{EncryptBuffer, Transmit}; use std::{ @@ -70,29 +70,15 @@ impl SimGateway { ) -> Option> { // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? - if let Some(icmp) = packet.as_icmp() { - if let Some(echo_request) = icmp.echo_request_header() { - let payload = icmp.payload(); - let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap()); - tracing::debug!(%echo_id, "Received ICMP request"); + if let Some(icmp) = packet.as_icmpv4() { + if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { + return self.handle_icmp_request(&packet, echo, icmp.payload(), now); + } + } - self.received_icmp_requests.insert(echo_id, packet.clone()); - - let echo_response = ip_packet::make::icmp_reply_packet( - packet.destination(), - packet.source(), - echo_request.seq, - echo_request.id, - payload, - ) - .expect("src and dst are taken from incoming packet"); - let transmit = self - .sut - .encapsulate(echo_response, now, &mut self.enc_buffer)? - .to_transmit(&self.enc_buffer) - .into_owned(); - - return Some(transmit); + if let Some(icmp) = packet.as_icmpv6() { + if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() { + return self.handle_icmp_request(&packet, echo, icmp.payload(), now); } } @@ -110,7 +96,7 @@ impl SimGateway { return Some(transmit); } - tracing::error!("Unhandled packet"); + tracing::error!(?packet, "Unhandled packet"); None } @@ -126,6 +112,35 @@ impl SimGateway { now, ) } + + fn handle_icmp_request( + &mut self, + packet: &IpPacket<'static>, + echo: IcmpEchoHeader, + payload: &[u8], + now: Instant, + ) -> Option> { + let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap()); + self.received_icmp_requests.insert(echo_id, packet.clone()); + + tracing::debug!(%echo_id, "Received ICMP request"); + + let echo_response = ip_packet::make::icmp_reply_packet( + packet.destination(), + packet.source(), + echo.seq, + echo.id, + payload, + ) + .expect("src and dst are taken from incoming packet"); + let transmit = self + .sut + .encapsulate(echo_response, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); + + Some(transmit) + } } /// Reference state for a particular gateway. diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 8a3cf5837..74748a8bf 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -13,7 +13,6 @@ proptest = ["dep:proptest"] anyhow = "1.0.86" domain = "0.10.1" etherparse = "0.15" -pnet_packet = { version = "0.35" } proptest = { version = "1", optional = true } thiserror = "1" tracing = "0.1" diff --git a/rust/ip-packet/src/icmpv4_header_slice_mut.rs b/rust/ip-packet/src/icmpv4_header_slice_mut.rs new file mode 100644 index 000000000..34c3f8a39 --- /dev/null +++ b/rust/ip-packet/src/icmpv4_header_slice_mut.rs @@ -0,0 +1,81 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::{ + icmpv4::{TYPE_ECHO_REPLY, TYPE_ECHO_REQUEST}, + Icmpv4Slice, +}; + +pub struct Icmpv4HeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> Icmpv4HeaderSliceMut<'a> { + /// Creates a new [`Icmpv4HeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + Icmpv4Slice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, checksum.to_be_bytes()) }; + } + + pub fn set_identifier(&mut self, id: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP identifier only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 4, id.to_be_bytes()) }; + } + + pub fn set_sequence(&mut self, seq: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP sequence only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 6, seq.to_be_bytes()) }; + } + + fn is_echo_request_or_reply(&self) -> bool { + let ty = self.slice[0]; + + ty == TYPE_ECHO_REPLY || ty == TYPE_ECHO_REQUEST + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::{Icmpv4Type, PacketBuilder}; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0) + .icmpv4_echo_request(10, 20) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = Icmpv4HeaderSliceMut::from_slice(&mut buf[20..]).unwrap(); + + slice.set_identifier(30); + slice.set_sequence(40); + slice.set_checksum(50); + + let slice = Icmpv4Slice::from_slice(&buf[20..]).unwrap(); + + let Icmpv4Type::EchoRequest(header) = slice.header().icmp_type else { + panic!("Unexpected ICMP header"); + }; + + assert_eq!(header.id, 30); + assert_eq!(header.seq, 40); + assert_eq!(slice.checksum(), 50); + } +} diff --git a/rust/ip-packet/src/icmpv6_header_slice_mut.rs b/rust/ip-packet/src/icmpv6_header_slice_mut.rs new file mode 100644 index 000000000..a9d7dd793 --- /dev/null +++ b/rust/ip-packet/src/icmpv6_header_slice_mut.rs @@ -0,0 +1,81 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::{ + icmpv6::{TYPE_ECHO_REPLY, TYPE_ECHO_REQUEST}, + Icmpv6Slice, +}; + +pub struct Icmpv6EchoHeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> Icmpv6EchoHeaderSliceMut<'a> { + /// Creates a new [`Icmpv6EchoHeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + Icmpv6Slice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, checksum.to_be_bytes()) }; + } + + pub fn set_identifier(&mut self, id: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP identifier only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 4, id.to_be_bytes()) }; + } + + pub fn set_sequence(&mut self, seq: u16) { + debug_assert!( + self.is_echo_request_or_reply(), + "ICMP sequence only exists for echo requests and replies" + ); + + // Safety: Slice is at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 6, seq.to_be_bytes()) }; + } + + fn is_echo_request_or_reply(&self) -> bool { + let ty = self.slice[0]; + + ty == TYPE_ECHO_REPLY || ty == TYPE_ECHO_REQUEST + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::{Icmpv6Type, PacketBuilder}; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv6([0u8; 16], [0u8; 16], 0) + .icmpv6_echo_request(10, 20) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = Icmpv6EchoHeaderSliceMut::from_slice(&mut buf[40..]).unwrap(); + + slice.set_identifier(30); + slice.set_sequence(40); + slice.set_checksum(50); + + let slice = Icmpv6Slice::from_slice(&buf[40..]).unwrap(); + + let Icmpv6Type::EchoRequest(header) = slice.header().icmp_type else { + panic!("Unexpected ICMP header"); + }; + + assert_eq!(header.id, 30); + assert_eq!(header.seq, 40); + assert_eq!(slice.checksum(), 50); + } +} diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 53d668518..dbbe1cc9f 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -1,5 +1,7 @@ pub mod make; +mod icmpv4_header_slice_mut; +mod icmpv6_header_slice_mut; mod ipv4_header_slice_mut; mod ipv6_header_slice_mut; mod nat46; @@ -7,31 +9,24 @@ mod nat64; #[cfg(feature = "proptest")] pub mod proptest; mod slice_utils; +mod tcp_header_slice_mut; +mod udp_header_slice_mut; -pub use pnet_packet::*; +pub use etherparse::*; #[cfg(all(test, feature = "proptest"))] mod proptests; -use etherparse::{ - IcmpEchoHeader, Icmpv4Slice, Icmpv4Type, Icmpv6Slice, Icmpv6Type, IpNumber, Ipv4Header, - Ipv4HeaderSlice, Ipv6Header, Ipv6HeaderSlice, TcpSlice, UdpSlice, -}; +use icmpv4_header_slice_mut::Icmpv4HeaderSliceMut; +use icmpv6_header_slice_mut::Icmpv6EchoHeaderSliceMut; use ipv4_header_slice_mut::Ipv4HeaderSliceMut; use ipv6_header_slice_mut::Ipv6HeaderSliceMut; -use pnet_packet::{ - icmp::{ - echo_reply::MutableEchoReplyPacket, echo_request::MutableEchoRequestPacket, IcmpTypes, - MutableIcmpPacket, - }, - icmpv6::{Icmpv6Types, MutableIcmpv6Packet}, - tcp::MutableTcpPacket, - udp::MutableUdpPacket, -}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, ops::{Deref, DerefMut}, }; +use tcp_header_slice_mut::TcpHeaderSliceMut; +use udp_header_slice_mut::UdpHeaderSliceMut; macro_rules! for_both { ($this:ident, |$name:ident| $body:expr) => { @@ -79,81 +74,21 @@ impl Protocol { } } -#[derive(Debug, PartialEq)] -pub enum IcmpPacket<'a> { - Ipv4(Icmpv4Slice<'a>), - Ipv6(Icmpv6Slice<'a>), -} - -impl<'a> IcmpPacket<'a> { - pub fn icmp_type(&self) -> IcmpType { - match self { - IcmpPacket::Ipv4(v4) => IcmpType::V4(v4.icmp_type()), - IcmpPacket::Ipv6(v6) => IcmpType::V6(v6.icmp_type()), - } - } - - pub fn identifier(&self) -> Option { - Some(self.echo_request_header().or(self.echo_reply_header())?.id) - } - - pub fn sequence(&self) -> Option { - Some(self.echo_request_header().or(self.echo_reply_header())?.seq) - } - - pub fn payload(&self) -> &[u8] { - match self { - IcmpPacket::Ipv4(v4) => v4.payload(), - IcmpPacket::Ipv6(v6) => v6.payload(), - } - } - - pub fn echo_request_header(&self) -> Option { - #[allow( - clippy::wildcard_enum_match_arm, - reason = "We won't ever need to use other ICMP types here." - )] - match self { - IcmpPacket::Ipv4(v4) => match v4.header().icmp_type { - Icmpv4Type::EchoRequest(echo) => Some(echo), - _ => None, - }, - IcmpPacket::Ipv6(v6) => match v6.header().icmp_type { - Icmpv6Type::EchoRequest(echo) => Some(echo), - _ => None, - }, - } - } - - pub fn echo_reply_header(&self) -> Option { - #[allow( - clippy::wildcard_enum_match_arm, - reason = "We won't ever need to use other ICMP types here." - )] - match self { - IcmpPacket::Ipv4(v4) => match v4.header().icmp_type { - Icmpv4Type::EchoReply(echo) => Some(echo), - _ => None, - }, - IcmpPacket::Ipv6(v6) => match v6.header().icmp_type { - Icmpv6Type::EchoReply(echo) => Some(echo), - _ => None, - }, - } - } -} - -pub enum IcmpType { - V4(Icmpv4Type), - V6(Icmpv6Type), -} - -#[derive(Debug, PartialEq, Clone)] +#[derive(PartialEq, Clone)] pub enum IpPacket<'a> { Ipv4(ConvertibleIpv4Packet<'a>), Ipv6(ConvertibleIpv6Packet<'a>), } +impl<'a> std::fmt::Debug for IpPacket<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ipv4(arg0) => arg0.ip_header().to_header().fmt(f), + Self::Ipv6(arg0) => arg0.header().to_header().fmt(f), + } + } +} + #[derive(Debug, PartialEq)] enum MaybeOwned<'a> { RefMut(&'a mut [u8]), @@ -222,11 +157,11 @@ impl<'a> ConvertibleIpv4Packet<'a> { } fn ip_header(&self) -> Ipv4HeaderSlice { - Ipv4HeaderSlice::from_slice(&self.buf[20..]).expect("we checked this during `new`") + Ipv4HeaderSlice::from_slice(self.packet()).expect("we checked this during `new`") } fn ip_header_mut(&mut self) -> Ipv4HeaderSliceMut { - Ipv4HeaderSliceMut::from_slice(&mut self.buf[20..]).expect("we checked this during `new`") + Ipv4HeaderSliceMut::from_slice(self.packet_mut()).expect("we checked this during `new`") } pub fn get_source(&self) -> Ipv4Addr { @@ -253,27 +188,14 @@ impl<'a> ConvertibleIpv4Packet<'a> { fn header_length(&self) -> usize { (self.ip_header().ihl() * 4) as usize } -} -impl<'a> Packet for ConvertibleIpv4Packet<'a> { - fn packet(&self) -> &[u8] { + pub fn packet(&self) -> &[u8] { &self.buf[20..] } - fn payload(&self) -> &[u8] { - &self.buf[(self.header_length() + 20)..] - } -} - -impl<'a> MutablePacket for ConvertibleIpv4Packet<'a> { fn packet_mut(&mut self) -> &mut [u8] { &mut self.buf[20..] } - - fn payload_mut(&mut self) -> &mut [u8] { - let header_len = self.header_length(); - &mut self.buf[(header_len + 20)..] - } } #[derive(Debug, PartialEq, Clone)] @@ -299,11 +221,12 @@ impl<'a> ConvertibleIpv6Packet<'a> { } fn header(&self) -> Ipv6HeaderSlice { - Ipv6HeaderSlice::from_slice(&self.buf).expect("We checked this in `new` / `owned`") + Ipv6HeaderSlice::from_slice(self.packet()).expect("We checked this in `new` / `owned`") } fn header_mut(&mut self) -> Ipv6HeaderSliceMut { - Ipv6HeaderSliceMut::from_slice(&mut self.buf).expect("We checked this in `new` / `owned`") + Ipv6HeaderSliceMut::from_slice(self.packet_mut()) + .expect("We checked this in `new` / `owned`") } pub fn get_source(&self) -> Ipv6Addr { @@ -325,26 +248,14 @@ impl<'a> ConvertibleIpv6Packet<'a> { Some(ConvertibleIpv4Packet { buf: self.buf }) } -} -impl<'a> Packet for ConvertibleIpv6Packet<'a> { - fn packet(&self) -> &[u8] { + pub fn packet(&self) -> &[u8] { &self.buf } - fn payload(&self) -> &[u8] { - &self.buf[Ipv6Header::LEN..] - } -} - -impl<'a> MutablePacket for ConvertibleIpv6Packet<'a> { fn packet_mut(&mut self) -> &mut [u8] { &mut self.buf } - - fn payload_mut(&mut self) -> &mut [u8] { - &mut self.buf[Ipv6Header::LEN..] - } } pub fn ipv4_embedded(ip: Ipv4Addr) -> Ipv6Addr { @@ -444,11 +355,20 @@ impl<'a> IpPacket<'a> { return Ok(Protocol::Udp(p.source_port())); } - if let Some(p) = self.as_icmp() { - let id = p.identifier().ok_or_else(|| match p.icmp_type() { - IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4), - IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6), - })?; + if let Some(p) = self.as_icmpv4() { + let id = self + .icmpv4_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv4Type(p.icmp_type()))? + .id; + + return Ok(Protocol::Icmp(id)); + } + + if let Some(p) = self.as_icmpv6() { + let id = self + .icmpv6_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv6Type(p.icmp_type()))? + .id; return Ok(Protocol::Icmp(id)); } @@ -467,11 +387,20 @@ impl<'a> IpPacket<'a> { return Ok(Protocol::Udp(p.destination_port())); } - if let Some(p) = self.as_icmp() { - let id = p.identifier().ok_or_else(|| match p.icmp_type() { - IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4), - IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6), - })?; + if let Some(p) = self.as_icmpv4() { + let id = self + .icmpv4_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv4Type(p.icmp_type()))? + .id; + + return Ok(Protocol::Icmp(id)); + } + + if let Some(p) = self.as_icmpv6() { + let id = self + .icmpv6_echo_header() + .ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv6Type(p.icmp_type()))? + .id; return Ok(Protocol::Icmp(id)); } @@ -483,11 +412,11 @@ impl<'a> IpPacket<'a> { pub fn set_source_protocol(&mut self, v: u16) { if let Some(mut p) = self.as_tcp_mut() { - p.set_source(v); + p.set_source_port(v); } if let Some(mut p) = self.as_udp_mut() { - p.set_source(v); + p.set_source_port(v); } self.set_icmp_identifier(v); @@ -495,51 +424,23 @@ impl<'a> IpPacket<'a> { pub fn set_destination_protocol(&mut self, v: u16) { if let Some(mut p) = self.as_tcp_mut() { - p.set_destination(v); + p.set_destination_port(v); } if let Some(mut p) = self.as_udp_mut() { - p.set_destination(v); + p.set_destination_port(v); } self.set_icmp_identifier(v); } fn set_icmp_identifier(&mut self, v: u16) { - if let Some(mut p) = self.as_icmp_mut() { - if p.get_icmp_type() == IcmpTypes::EchoReply { - let Some(mut echo_reply) = MutableEchoReplyPacket::new(p.packet_mut()) else { - return; - }; - echo_reply.set_identifier(v) - } - - if p.get_icmp_type() == IcmpTypes::EchoRequest { - let Some(mut echo_request) = MutableEchoRequestPacket::new(p.packet_mut()) else { - return; - }; - echo_request.set_identifier(v); - } + if let Some(mut p) = self.as_icmpv4_mut() { + p.set_identifier(v); } - if let Some(mut p) = self.as_icmpv6() { - if p.get_icmpv6_type() == Icmpv6Types::EchoReply { - let Some(mut echo_reply) = - icmpv6::echo_reply::MutableEchoReplyPacket::new(p.packet_mut()) - else { - return; - }; - echo_reply.set_identifier(v) - } - - if p.get_icmpv6_type() == Icmpv6Types::EchoRequest { - let Some(mut echo_request) = - icmpv6::echo_request::MutableEchoRequestPacket::new(p.packet_mut()) - else { - return; - }; - echo_request.set_identifier(v); - } + if let Some(mut p) = self.as_icmpv6_mut() { + p.set_identifier(v); } } @@ -605,76 +506,132 @@ impl<'a> IpPacket<'a> { } pub fn as_udp(&self) -> Option { - self.is_udp() - .then(|| UdpSlice::from_slice(self.payload()).ok()) - .flatten() + if !self.is_udp() { + return None; + } + + UdpSlice::from_slice(self.payload()).ok() } - pub fn as_udp_mut(&mut self) -> Option { - self.is_udp() - .then(|| MutableUdpPacket::new(self.payload_mut())) - .flatten() + pub fn as_udp_mut(&mut self) -> Option { + if !self.is_udp() { + return None; + } + + UdpHeaderSliceMut::from_slice(self.payload_mut()).ok() } pub fn as_tcp(&self) -> Option { - self.is_tcp() - .then(|| TcpSlice::from_slice(self.payload()).ok()) - .flatten() - } - - pub fn as_tcp_mut(&mut self) -> Option { - self.is_tcp() - .then(|| MutableTcpPacket::new(self.payload_mut())) - .flatten() - } - - pub fn is_icmp_v4_or_v6(&self) -> bool { - match self { - IpPacket::Ipv4(v4) => v4.ip_header().protocol() == IpNumber::ICMP, - IpPacket::Ipv6(v6) => v6.header().next_header() == IpNumber::IPV6_ICMP, + if !self.is_tcp() { + return None; } + + TcpSlice::from_slice(self.payload()).ok() + } + + pub fn as_tcp_mut(&mut self) -> Option { + if !self.is_tcp() { + return None; + } + + TcpHeaderSliceMut::from_slice(self.payload_mut()).ok() } fn set_icmpv6_checksum(&mut self) { - let (src_addr, dst_addr) = match self { - IpPacket::Ipv4(_) => return, - IpPacket::Ipv6(p) => (p.get_source(), p.get_destination()), + let Some(i) = self.as_icmpv6() else { + return; }; - if let Some(mut pkt) = self.as_icmpv6() { - let checksum = icmpv6::checksum(&pkt.to_immutable(), &src_addr, &dst_addr); - pkt.set_checksum(checksum); - } + + let IpPacket::Ipv6(p) = &self else { + return; + }; + + let checksum = i + .icmp_type() + .calc_checksum( + p.get_source().octets(), + p.get_destination().octets(), + i.payload(), + ) + .expect("Payload came from the original packet"); + + let Some(mut i) = self.as_icmpv6_mut() else { + return; + }; + + i.set_checksum(checksum); } fn set_icmpv4_checksum(&mut self) { - if let Some(mut pkt) = self.as_icmp_mut() { - let checksum = icmp::checksum(&pkt.to_immutable()); - pkt.set_checksum(checksum); + let Some(i) = self.as_icmpv4() else { + return; + }; + + let checksum = i.icmp_type().calc_checksum(i.payload()); + + let Some(mut i) = self.as_icmpv4_mut() else { + return; + }; + + i.set_checksum(checksum); + } + + pub fn as_icmpv4(&self) -> Option { + if !self.is_icmp() { + return None; } + + Icmpv4Slice::from_slice(self.payload()).ok() } - pub fn as_icmp(&self) -> Option { - match self { - Self::Ipv4(v4) if self.is_icmp() => Some(IcmpPacket::Ipv4( - Icmpv4Slice::from_slice(v4.payload()).ok()?, - )), - Self::Ipv6(v6) if self.is_icmpv6() => Some(IcmpPacket::Ipv6( - Icmpv6Slice::from_slice(v6.payload()).ok()?, - )), - Self::Ipv4(_) | Self::Ipv6(_) => None, + pub fn as_icmpv4_mut(&mut self) -> Option { + if !self.is_icmp() { + return None; } + + Icmpv4HeaderSliceMut::from_slice(self.payload_mut()).ok() } - pub fn as_icmp_mut(&mut self) -> Option { - self.is_icmp() - .then(|| MutableIcmpPacket::new(self.payload_mut())) - .flatten() + pub fn as_icmpv6(&self) -> Option { + if !self.is_icmpv6() { + return None; + } + + Icmpv6Slice::from_slice(self.payload()).ok() } - fn as_icmpv6(&mut self) -> Option { - self.is_icmpv6() - .then(|| MutableIcmpv6Packet::new(self.payload_mut())) - .flatten() + pub fn as_icmpv6_mut(&mut self) -> Option { + if !self.is_icmpv6() { + return None; + } + + Icmpv6EchoHeaderSliceMut::from_slice(self.payload_mut()).ok() + } + + fn icmpv4_echo_header(&self) -> Option { + let p = self.as_icmpv4()?; + + use Icmpv4Type::*; + let icmp_type = p.icmp_type(); + + let (EchoReply(header) | EchoRequest(header)) = icmp_type else { + return None; + }; + + Some(header) + } + + fn icmpv6_echo_header(&self) -> Option { + let p = self.as_icmpv6()?; + + use Icmpv6Type::*; + let icmp_type = p.icmp_type(); + + let (EchoReply(header) | EchoRequest(header)) = icmp_type else { + return None; + }; + + Some(header) } pub fn translate_destination( @@ -790,13 +747,46 @@ impl<'a> IpPacket<'a> { self.next_header() == IpNumber::TCP } - fn is_icmp(&self) -> bool { + pub fn is_icmp(&self) -> bool { self.next_header() == IpNumber::ICMP } - fn is_icmpv6(&self) -> bool { + pub fn is_icmpv6(&self) -> bool { self.next_header() == IpNumber::IPV6_ICMP } + + fn header_length(&self) -> usize { + match self { + IpPacket::Ipv4(v4) => v4.header_length(), + IpPacket::Ipv6(v6) => v6.header().header_len(), + } + } + + pub fn packet(&self) -> &[u8] { + match self { + IpPacket::Ipv4(v4) => v4.packet(), + IpPacket::Ipv6(v6) => v6.packet(), + } + } + + fn packet_mut(&mut self) -> &mut [u8] { + match self { + IpPacket::Ipv4(v4) => v4.packet_mut(), + IpPacket::Ipv6(v6) => v6.packet_mut(), + } + } + + fn payload(&self) -> &[u8] { + let start = self.header_length(); + + &self.packet()[start..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let start = self.header_length(); + + &mut self.packet_mut()[start..] + } } impl<'a> From> for IpPacket<'a> { @@ -811,26 +801,6 @@ impl<'a> From> for IpPacket<'a> { } } -impl pnet_packet::Packet for IpPacket<'_> { - fn packet(&self) -> &[u8] { - for_both!(self, |i| i.packet()) - } - - fn payload(&self) -> &[u8] { - for_both!(self, |i| i.payload()) - } -} - -impl pnet_packet::MutablePacket for IpPacket<'_> { - fn packet_mut(&mut self) -> &mut [u8] { - for_both!(self, |i| i.packet_mut()) - } - - fn payload_mut(&mut self) -> &mut [u8] { - for_both!(self, |i| i.payload_mut()) - } -} - #[derive(Debug, thiserror::Error)] pub enum UnsupportedProtocol { #[error("Unsupported IP protocol: {0:?}")] diff --git a/rust/ip-packet/src/proptests.rs b/rust/ip-packet/src/proptests.rs index d7eaf49c0..c2a69ef96 100644 --- a/rust/ip-packet/src/proptests.rs +++ b/rust/ip-packet/src/proptests.rs @@ -1,6 +1,5 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use pnet_packet::Packet; use proptest::arbitrary::any; use proptest::prop_oneof; use proptest::strategy::Strategy; diff --git a/rust/ip-packet/src/slice_utils.rs b/rust/ip-packet/src/slice_utils.rs index 558131917..8a8e216a2 100644 --- a/rust/ip-packet/src/slice_utils.rs +++ b/rust/ip-packet/src/slice_utils.rs @@ -8,7 +8,7 @@ pub unsafe fn write_to_offset_unchecked( offset: usize, bytes: [u8; N], ) { - debug_assert!(offset + N < slice.len()); + debug_assert!(offset + N <= slice.len()); let (_front, rest) = unsafe { slice.split_at_mut_unchecked(offset) }; let (target, _rest) = unsafe { rest.split_at_mut_unchecked(N) }; diff --git a/rust/ip-packet/src/tcp_header_slice_mut.rs b/rust/ip-packet/src/tcp_header_slice_mut.rs new file mode 100644 index 000000000..0a896af6d --- /dev/null +++ b/rust/ip-packet/src/tcp_header_slice_mut.rs @@ -0,0 +1,58 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::TcpHeaderSlice; + +pub struct TcpHeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> TcpHeaderSliceMut<'a> { + /// Creates a new [`TcpHeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + TcpHeaderSlice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_source_port(&mut self, src: u16) { + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) }; + } + + pub fn set_destination_port(&mut self, dst: u16) { + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, dst.to_be_bytes()) }; + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 16, checksum.to_be_bytes()) }; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::PacketBuilder; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0) + .tcp(10, 20, 0, 0) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = TcpHeaderSliceMut::from_slice(&mut buf[20..]).unwrap(); + + slice.set_source_port(30); + slice.set_destination_port(40); + slice.set_checksum(50); + + let slice = TcpHeaderSlice::from_slice(&buf[20..]).unwrap(); + + assert_eq!(slice.source_port(), 30); + assert_eq!(slice.destination_port(), 40); + assert_eq!(slice.checksum(), 50); + } +} diff --git a/rust/ip-packet/src/udp_header_slice_mut.rs b/rust/ip-packet/src/udp_header_slice_mut.rs new file mode 100644 index 000000000..986b88604 --- /dev/null +++ b/rust/ip-packet/src/udp_header_slice_mut.rs @@ -0,0 +1,65 @@ +use crate::slice_utils::write_to_offset_unchecked; +use etherparse::UdpHeaderSlice; + +pub struct UdpHeaderSliceMut<'a> { + slice: &'a mut [u8], +} + +impl<'a> UdpHeaderSliceMut<'a> { + /// Creates a new [`UdpHeaderSliceMut`]. + pub fn from_slice(slice: &'a mut [u8]) -> Result { + UdpHeaderSlice::from_slice(slice)?; + + Ok(Self { slice }) + } + + pub fn set_source_port(&mut self, src: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) }; + } + + pub fn set_destination_port(&mut self, dst: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 2, dst.to_be_bytes()) }; + } + + pub fn set_length(&mut self, length: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 4, length.to_be_bytes()) }; + } + + pub fn set_checksum(&mut self, checksum: u16) { + // Safety: Slice it at least of length 8 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 6, checksum.to_be_bytes()) }; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etherparse::PacketBuilder; + + #[test] + fn smoke() { + let mut buf = Vec::new(); + + PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0) + .udp(10, 20) + .write(&mut buf, &[]) + .unwrap(); + + let mut slice = UdpHeaderSliceMut::from_slice(&mut buf[20..]).unwrap(); + + slice.set_source_port(30); + slice.set_destination_port(40); + slice.set_length(50); + slice.set_checksum(60); + + let slice = UdpHeaderSlice::from_slice(&buf[20..]).unwrap(); + + assert_eq!(slice.source_port(), 30); + assert_eq!(slice.destination_port(), 40); + assert_eq!(slice.length(), 50); + assert_eq!(slice.checksum(), 60); + } +}