refactor(connlib): remove pnet_packet (#6659)

As the final step in removing `pnet_packet`, we need to introduce `-Mut`
equivalent slices for UDP, TCP and ICMP packets. As a starting point,
introducing `UpdHeaderSliceMut` and `TcpHeaderSliceMut` is fairly
trivial. The ICMP variants are a bit trickier because those are
different for IPv4 and IPv6. Additionally, ICMP for IPv4 is quite
complex because it can have a variable header length. Additionally. for
both variants, the values in byte range 5-8 are semantically different
depending on the ICMP code.

This requires us to design an API that balances ergonomics and
correctness. Technically, an ICMP identifier and sequence can only be
set if the ICMP code is "echo request" or "echo reply". However, adding
an additional parsing step to guarantee this in the type system is quite
verbose.

The trade-off implemented in this PR allows to us to directly write to
the byte 5-8 using the `set_identifier` and `set_sequence` functions. To
catch errors early, this functions have debug-assertions built in that
ensure that the packet is indeed an ICMP echo packet.

Resolves: #6366.
This commit is contained in:
Thomas Eizinger
2024-09-11 19:52:48 -04:00
committed by GitHub
parent 133c2565b2
commit 7adbf9c6af
15 changed files with 558 additions and 327 deletions

49
rust/Cargo.lock generated
View File

@@ -3086,7 +3086,6 @@ dependencies = [
"anyhow",
"domain",
"etherparse",
"pnet_packet",
"proptest",
"test-strategy",
"thiserror",
@@ -3822,12 +3821,6 @@ dependencies = [
"memoffset 0.9.1",
]
[[package]]
name = "no-std-net"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65"
[[package]]
name = "nodrop"
version = "0.1.14"
@@ -4528,48 +4521,6 @@ dependencies = [
"time",
]
[[package]]
name = "pnet_base"
version = "0.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffc190d4067df16af3aba49b3b74c469e611cad6314676eaf1157f31aa0fb2f7"
dependencies = [
"no-std-net",
]
[[package]]
name = "pnet_macros"
version = "0.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13325ac86ee1a80a480b0bc8e3d30c25d133616112bb16e86f712dcf8a71c863"
dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.72",
]
[[package]]
name = "pnet_macros_support"
version = "0.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eed67a952585d509dd0003049b1fc56b982ac665c8299b124b90ea2bdb3134ab"
dependencies = [
"pnet_base",
]
[[package]]
name = "pnet_packet"
version = "0.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c96ebadfab635fcc23036ba30a7d33a80c39e8461b8bd7dc7bb186acb96560f"
dependencies = [
"glob",
"pnet_base",
"pnet_macros",
"pnet_macros_support",
]
[[package]]
name = "png"
version = "0.17.13"

View File

@@ -24,7 +24,7 @@ mod platform {
mod platform {
use anyhow::Result;
use firezone_bin_shared::TunDeviceManager;
use ip_packet::{IpPacket, Packet as _};
use ip_packet::IpPacket;
use std::{
future::poll_fn,
net::{Ipv4Addr, Ipv6Addr},

View File

@@ -9,7 +9,7 @@ use boringtun::x25519::PublicKey;
use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret};
use core::fmt;
use hex_display::HexDisplayExt;
use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, Packet as _};
use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket};
use rand::rngs::StdRng;
use rand::seq::IteratorRandom;
use rand::{random, SeedableRng};

View File

@@ -1,4 +1,4 @@
use ip_packet::{IpPacket, Packet as _};
use ip_packet::IpPacket;
use std::io;
use std::task::{Context, Poll, Waker};
use tun::Tun;
@@ -31,8 +31,6 @@ impl Device {
buf: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<IpPacket<'b>>> {
use ip_packet::Packet as _;
let Some(tun) = self.tun.as_mut() else {
self.waker = Some(cx.waker().clone());
return Poll::Pending;

View File

@@ -77,7 +77,7 @@ impl AllowRules {
return self.udp.contains(&udp.destination_port());
}
if packet.is_icmp_v4_or_v6() {
if packet.is_icmp() || packet.is_icmpv6() {
return self.icmp;
}

View File

@@ -24,7 +24,7 @@ use domain::{
};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
use ip_packet::IpPacket;
use ip_packet::{Icmpv4Type, Icmpv6Type, IpPacket};
use itertools::Itertools as _;
use prop::collection;
use proptest::prelude::*;
@@ -123,12 +123,17 @@ impl SimClient {
packet: IpPacket<'static>,
now: Instant,
) -> Option<snownet::Transmit<'static>> {
{
if let Some(icmp) = packet.as_icmp() {
let echo_request = icmp.echo_request_header().expect("to be echo request");
if let Some(icmp) = packet.as_icmpv4() {
if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() {
self.sent_icmp_requests
.insert((echo_request.seq, echo_request.id), packet.clone());
.insert((echo.seq, echo.id), packet.clone());
}
}
if let Some(icmp) = packet.as_icmpv6() {
if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() {
self.sent_icmp_requests
.insert((echo.seq, echo.id), packet.clone());
}
}
@@ -175,14 +180,23 @@ impl SimClient {
/// Process an IP packet received on the client.
pub(crate) fn on_received_packet(&mut self, packet: IpPacket<'static>) {
if let Some(icmp) = packet.as_icmp() {
let echo_reply = icmp.echo_reply_header().expect("to be echo reply");
if let Some(icmp) = packet.as_icmpv4() {
if let Icmpv4Type::EchoReply(echo) = icmp.icmp_type() {
self.received_icmp_replies
.insert((echo.seq, echo.id), packet.clone());
self.received_icmp_replies
.insert((echo_reply.seq, echo_reply.id), packet);
return;
}
}
return;
};
if let Some(icmp) = packet.as_icmpv6() {
if let Icmpv6Type::EchoReply(echo) = icmp.icmp_type() {
self.received_icmp_replies
.insert((echo.seq, echo.id), packet.clone());
return;
}
}
if let Some(udp) = packet.as_udp() {
if udp.source_port() == 53 {
@@ -225,7 +239,7 @@ impl SimClient {
}
}
tracing::error!("Unhandled packet");
tracing::error!(?packet, "Unhandled packet");
}
pub(crate) fn update_relays<'a>(

View File

@@ -9,7 +9,7 @@ use connlib_shared::{
messages::{GatewayId, RelayId},
DomainName,
};
use ip_packet::IpPacket;
use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket};
use proptest::prelude::*;
use snownet::{EncryptBuffer, Transmit};
use std::{
@@ -70,29 +70,15 @@ impl SimGateway {
) -> Option<Transmit<'static>> {
// TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`?
if let Some(icmp) = packet.as_icmp() {
if let Some(echo_request) = icmp.echo_request_header() {
let payload = icmp.payload();
let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap());
tracing::debug!(%echo_id, "Received ICMP request");
if let Some(icmp) = packet.as_icmpv4() {
if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() {
return self.handle_icmp_request(&packet, echo, icmp.payload(), now);
}
}
self.received_icmp_requests.insert(echo_id, packet.clone());
let echo_response = ip_packet::make::icmp_reply_packet(
packet.destination(),
packet.source(),
echo_request.seq,
echo_request.id,
payload,
)
.expect("src and dst are taken from incoming packet");
let transmit = self
.sut
.encapsulate(echo_response, now, &mut self.enc_buffer)?
.to_transmit(&self.enc_buffer)
.into_owned();
return Some(transmit);
if let Some(icmp) = packet.as_icmpv6() {
if let Icmpv6Type::EchoRequest(echo) = icmp.icmp_type() {
return self.handle_icmp_request(&packet, echo, icmp.payload(), now);
}
}
@@ -110,7 +96,7 @@ impl SimGateway {
return Some(transmit);
}
tracing::error!("Unhandled packet");
tracing::error!(?packet, "Unhandled packet");
None
}
@@ -126,6 +112,35 @@ impl SimGateway {
now,
)
}
fn handle_icmp_request(
&mut self,
packet: &IpPacket<'static>,
echo: IcmpEchoHeader,
payload: &[u8],
now: Instant,
) -> Option<Transmit<'static>> {
let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap());
self.received_icmp_requests.insert(echo_id, packet.clone());
tracing::debug!(%echo_id, "Received ICMP request");
let echo_response = ip_packet::make::icmp_reply_packet(
packet.destination(),
packet.source(),
echo.seq,
echo.id,
payload,
)
.expect("src and dst are taken from incoming packet");
let transmit = self
.sut
.encapsulate(echo_response, now, &mut self.enc_buffer)?
.to_transmit(&self.enc_buffer)
.into_owned();
Some(transmit)
}
}
/// Reference state for a particular gateway.

View File

@@ -13,7 +13,6 @@ proptest = ["dep:proptest"]
anyhow = "1.0.86"
domain = "0.10.1"
etherparse = "0.15"
pnet_packet = { version = "0.35" }
proptest = { version = "1", optional = true }
thiserror = "1"
tracing = "0.1"

View File

@@ -0,0 +1,81 @@
use crate::slice_utils::write_to_offset_unchecked;
use etherparse::{
icmpv4::{TYPE_ECHO_REPLY, TYPE_ECHO_REQUEST},
Icmpv4Slice,
};
pub struct Icmpv4HeaderSliceMut<'a> {
slice: &'a mut [u8],
}
impl<'a> Icmpv4HeaderSliceMut<'a> {
/// Creates a new [`Icmpv4HeaderSliceMut`].
pub fn from_slice(slice: &'a mut [u8]) -> Result<Self, etherparse::err::LenError> {
Icmpv4Slice::from_slice(slice)?;
Ok(Self { slice })
}
pub fn set_checksum(&mut self, checksum: u16) {
// Safety: Slice is at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 2, checksum.to_be_bytes()) };
}
pub fn set_identifier(&mut self, id: u16) {
debug_assert!(
self.is_echo_request_or_reply(),
"ICMP identifier only exists for echo requests and replies"
);
// Safety: Slice is at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 4, id.to_be_bytes()) };
}
pub fn set_sequence(&mut self, seq: u16) {
debug_assert!(
self.is_echo_request_or_reply(),
"ICMP sequence only exists for echo requests and replies"
);
// Safety: Slice is at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 6, seq.to_be_bytes()) };
}
fn is_echo_request_or_reply(&self) -> bool {
let ty = self.slice[0];
ty == TYPE_ECHO_REPLY || ty == TYPE_ECHO_REQUEST
}
}
#[cfg(test)]
mod tests {
use super::*;
use etherparse::{Icmpv4Type, PacketBuilder};
#[test]
fn smoke() {
let mut buf = Vec::new();
PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0)
.icmpv4_echo_request(10, 20)
.write(&mut buf, &[])
.unwrap();
let mut slice = Icmpv4HeaderSliceMut::from_slice(&mut buf[20..]).unwrap();
slice.set_identifier(30);
slice.set_sequence(40);
slice.set_checksum(50);
let slice = Icmpv4Slice::from_slice(&buf[20..]).unwrap();
let Icmpv4Type::EchoRequest(header) = slice.header().icmp_type else {
panic!("Unexpected ICMP header");
};
assert_eq!(header.id, 30);
assert_eq!(header.seq, 40);
assert_eq!(slice.checksum(), 50);
}
}

View File

@@ -0,0 +1,81 @@
use crate::slice_utils::write_to_offset_unchecked;
use etherparse::{
icmpv6::{TYPE_ECHO_REPLY, TYPE_ECHO_REQUEST},
Icmpv6Slice,
};
pub struct Icmpv6EchoHeaderSliceMut<'a> {
slice: &'a mut [u8],
}
impl<'a> Icmpv6EchoHeaderSliceMut<'a> {
/// Creates a new [`Icmpv6EchoHeaderSliceMut`].
pub fn from_slice(slice: &'a mut [u8]) -> Result<Self, etherparse::err::LenError> {
Icmpv6Slice::from_slice(slice)?;
Ok(Self { slice })
}
pub fn set_checksum(&mut self, checksum: u16) {
// Safety: Slice is at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 2, checksum.to_be_bytes()) };
}
pub fn set_identifier(&mut self, id: u16) {
debug_assert!(
self.is_echo_request_or_reply(),
"ICMP identifier only exists for echo requests and replies"
);
// Safety: Slice is at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 4, id.to_be_bytes()) };
}
pub fn set_sequence(&mut self, seq: u16) {
debug_assert!(
self.is_echo_request_or_reply(),
"ICMP sequence only exists for echo requests and replies"
);
// Safety: Slice is at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 6, seq.to_be_bytes()) };
}
fn is_echo_request_or_reply(&self) -> bool {
let ty = self.slice[0];
ty == TYPE_ECHO_REPLY || ty == TYPE_ECHO_REQUEST
}
}
#[cfg(test)]
mod tests {
use super::*;
use etherparse::{Icmpv6Type, PacketBuilder};
#[test]
fn smoke() {
let mut buf = Vec::new();
PacketBuilder::ipv6([0u8; 16], [0u8; 16], 0)
.icmpv6_echo_request(10, 20)
.write(&mut buf, &[])
.unwrap();
let mut slice = Icmpv6EchoHeaderSliceMut::from_slice(&mut buf[40..]).unwrap();
slice.set_identifier(30);
slice.set_sequence(40);
slice.set_checksum(50);
let slice = Icmpv6Slice::from_slice(&buf[40..]).unwrap();
let Icmpv6Type::EchoRequest(header) = slice.header().icmp_type else {
panic!("Unexpected ICMP header");
};
assert_eq!(header.id, 30);
assert_eq!(header.seq, 40);
assert_eq!(slice.checksum(), 50);
}
}

View File

@@ -1,5 +1,7 @@
pub mod make;
mod icmpv4_header_slice_mut;
mod icmpv6_header_slice_mut;
mod ipv4_header_slice_mut;
mod ipv6_header_slice_mut;
mod nat46;
@@ -7,31 +9,24 @@ mod nat64;
#[cfg(feature = "proptest")]
pub mod proptest;
mod slice_utils;
mod tcp_header_slice_mut;
mod udp_header_slice_mut;
pub use pnet_packet::*;
pub use etherparse::*;
#[cfg(all(test, feature = "proptest"))]
mod proptests;
use etherparse::{
IcmpEchoHeader, Icmpv4Slice, Icmpv4Type, Icmpv6Slice, Icmpv6Type, IpNumber, Ipv4Header,
Ipv4HeaderSlice, Ipv6Header, Ipv6HeaderSlice, TcpSlice, UdpSlice,
};
use icmpv4_header_slice_mut::Icmpv4HeaderSliceMut;
use icmpv6_header_slice_mut::Icmpv6EchoHeaderSliceMut;
use ipv4_header_slice_mut::Ipv4HeaderSliceMut;
use ipv6_header_slice_mut::Ipv6HeaderSliceMut;
use pnet_packet::{
icmp::{
echo_reply::MutableEchoReplyPacket, echo_request::MutableEchoRequestPacket, IcmpTypes,
MutableIcmpPacket,
},
icmpv6::{Icmpv6Types, MutableIcmpv6Packet},
tcp::MutableTcpPacket,
udp::MutableUdpPacket,
};
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
ops::{Deref, DerefMut},
};
use tcp_header_slice_mut::TcpHeaderSliceMut;
use udp_header_slice_mut::UdpHeaderSliceMut;
macro_rules! for_both {
($this:ident, |$name:ident| $body:expr) => {
@@ -79,81 +74,21 @@ impl Protocol {
}
}
#[derive(Debug, PartialEq)]
pub enum IcmpPacket<'a> {
Ipv4(Icmpv4Slice<'a>),
Ipv6(Icmpv6Slice<'a>),
}
impl<'a> IcmpPacket<'a> {
pub fn icmp_type(&self) -> IcmpType {
match self {
IcmpPacket::Ipv4(v4) => IcmpType::V4(v4.icmp_type()),
IcmpPacket::Ipv6(v6) => IcmpType::V6(v6.icmp_type()),
}
}
pub fn identifier(&self) -> Option<u16> {
Some(self.echo_request_header().or(self.echo_reply_header())?.id)
}
pub fn sequence(&self) -> Option<u16> {
Some(self.echo_request_header().or(self.echo_reply_header())?.seq)
}
pub fn payload(&self) -> &[u8] {
match self {
IcmpPacket::Ipv4(v4) => v4.payload(),
IcmpPacket::Ipv6(v6) => v6.payload(),
}
}
pub fn echo_request_header(&self) -> Option<IcmpEchoHeader> {
#[allow(
clippy::wildcard_enum_match_arm,
reason = "We won't ever need to use other ICMP types here."
)]
match self {
IcmpPacket::Ipv4(v4) => match v4.header().icmp_type {
Icmpv4Type::EchoRequest(echo) => Some(echo),
_ => None,
},
IcmpPacket::Ipv6(v6) => match v6.header().icmp_type {
Icmpv6Type::EchoRequest(echo) => Some(echo),
_ => None,
},
}
}
pub fn echo_reply_header(&self) -> Option<IcmpEchoHeader> {
#[allow(
clippy::wildcard_enum_match_arm,
reason = "We won't ever need to use other ICMP types here."
)]
match self {
IcmpPacket::Ipv4(v4) => match v4.header().icmp_type {
Icmpv4Type::EchoReply(echo) => Some(echo),
_ => None,
},
IcmpPacket::Ipv6(v6) => match v6.header().icmp_type {
Icmpv6Type::EchoReply(echo) => Some(echo),
_ => None,
},
}
}
}
pub enum IcmpType {
V4(Icmpv4Type),
V6(Icmpv6Type),
}
#[derive(Debug, PartialEq, Clone)]
#[derive(PartialEq, Clone)]
pub enum IpPacket<'a> {
Ipv4(ConvertibleIpv4Packet<'a>),
Ipv6(ConvertibleIpv6Packet<'a>),
}
impl<'a> std::fmt::Debug for IpPacket<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Ipv4(arg0) => arg0.ip_header().to_header().fmt(f),
Self::Ipv6(arg0) => arg0.header().to_header().fmt(f),
}
}
}
#[derive(Debug, PartialEq)]
enum MaybeOwned<'a> {
RefMut(&'a mut [u8]),
@@ -222,11 +157,11 @@ impl<'a> ConvertibleIpv4Packet<'a> {
}
fn ip_header(&self) -> Ipv4HeaderSlice {
Ipv4HeaderSlice::from_slice(&self.buf[20..]).expect("we checked this during `new`")
Ipv4HeaderSlice::from_slice(self.packet()).expect("we checked this during `new`")
}
fn ip_header_mut(&mut self) -> Ipv4HeaderSliceMut {
Ipv4HeaderSliceMut::from_slice(&mut self.buf[20..]).expect("we checked this during `new`")
Ipv4HeaderSliceMut::from_slice(self.packet_mut()).expect("we checked this during `new`")
}
pub fn get_source(&self) -> Ipv4Addr {
@@ -253,27 +188,14 @@ impl<'a> ConvertibleIpv4Packet<'a> {
fn header_length(&self) -> usize {
(self.ip_header().ihl() * 4) as usize
}
}
impl<'a> Packet for ConvertibleIpv4Packet<'a> {
fn packet(&self) -> &[u8] {
pub fn packet(&self) -> &[u8] {
&self.buf[20..]
}
fn payload(&self) -> &[u8] {
&self.buf[(self.header_length() + 20)..]
}
}
impl<'a> MutablePacket for ConvertibleIpv4Packet<'a> {
fn packet_mut(&mut self) -> &mut [u8] {
&mut self.buf[20..]
}
fn payload_mut(&mut self) -> &mut [u8] {
let header_len = self.header_length();
&mut self.buf[(header_len + 20)..]
}
}
#[derive(Debug, PartialEq, Clone)]
@@ -299,11 +221,12 @@ impl<'a> ConvertibleIpv6Packet<'a> {
}
fn header(&self) -> Ipv6HeaderSlice {
Ipv6HeaderSlice::from_slice(&self.buf).expect("We checked this in `new` / `owned`")
Ipv6HeaderSlice::from_slice(self.packet()).expect("We checked this in `new` / `owned`")
}
fn header_mut(&mut self) -> Ipv6HeaderSliceMut {
Ipv6HeaderSliceMut::from_slice(&mut self.buf).expect("We checked this in `new` / `owned`")
Ipv6HeaderSliceMut::from_slice(self.packet_mut())
.expect("We checked this in `new` / `owned`")
}
pub fn get_source(&self) -> Ipv6Addr {
@@ -325,26 +248,14 @@ impl<'a> ConvertibleIpv6Packet<'a> {
Some(ConvertibleIpv4Packet { buf: self.buf })
}
}
impl<'a> Packet for ConvertibleIpv6Packet<'a> {
fn packet(&self) -> &[u8] {
pub fn packet(&self) -> &[u8] {
&self.buf
}
fn payload(&self) -> &[u8] {
&self.buf[Ipv6Header::LEN..]
}
}
impl<'a> MutablePacket for ConvertibleIpv6Packet<'a> {
fn packet_mut(&mut self) -> &mut [u8] {
&mut self.buf
}
fn payload_mut(&mut self) -> &mut [u8] {
&mut self.buf[Ipv6Header::LEN..]
}
}
pub fn ipv4_embedded(ip: Ipv4Addr) -> Ipv6Addr {
@@ -444,11 +355,20 @@ impl<'a> IpPacket<'a> {
return Ok(Protocol::Udp(p.source_port()));
}
if let Some(p) = self.as_icmp() {
let id = p.identifier().ok_or_else(|| match p.icmp_type() {
IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4),
IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6),
})?;
if let Some(p) = self.as_icmpv4() {
let id = self
.icmpv4_echo_header()
.ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv4Type(p.icmp_type()))?
.id;
return Ok(Protocol::Icmp(id));
}
if let Some(p) = self.as_icmpv6() {
let id = self
.icmpv6_echo_header()
.ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv6Type(p.icmp_type()))?
.id;
return Ok(Protocol::Icmp(id));
}
@@ -467,11 +387,20 @@ impl<'a> IpPacket<'a> {
return Ok(Protocol::Udp(p.destination_port()));
}
if let Some(p) = self.as_icmp() {
let id = p.identifier().ok_or_else(|| match p.icmp_type() {
IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4),
IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6),
})?;
if let Some(p) = self.as_icmpv4() {
let id = self
.icmpv4_echo_header()
.ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv4Type(p.icmp_type()))?
.id;
return Ok(Protocol::Icmp(id));
}
if let Some(p) = self.as_icmpv6() {
let id = self
.icmpv6_echo_header()
.ok_or_else(|| UnsupportedProtocol::UnsupportedIcmpv6Type(p.icmp_type()))?
.id;
return Ok(Protocol::Icmp(id));
}
@@ -483,11 +412,11 @@ impl<'a> IpPacket<'a> {
pub fn set_source_protocol(&mut self, v: u16) {
if let Some(mut p) = self.as_tcp_mut() {
p.set_source(v);
p.set_source_port(v);
}
if let Some(mut p) = self.as_udp_mut() {
p.set_source(v);
p.set_source_port(v);
}
self.set_icmp_identifier(v);
@@ -495,51 +424,23 @@ impl<'a> IpPacket<'a> {
pub fn set_destination_protocol(&mut self, v: u16) {
if let Some(mut p) = self.as_tcp_mut() {
p.set_destination(v);
p.set_destination_port(v);
}
if let Some(mut p) = self.as_udp_mut() {
p.set_destination(v);
p.set_destination_port(v);
}
self.set_icmp_identifier(v);
}
fn set_icmp_identifier(&mut self, v: u16) {
if let Some(mut p) = self.as_icmp_mut() {
if p.get_icmp_type() == IcmpTypes::EchoReply {
let Some(mut echo_reply) = MutableEchoReplyPacket::new(p.packet_mut()) else {
return;
};
echo_reply.set_identifier(v)
}
if p.get_icmp_type() == IcmpTypes::EchoRequest {
let Some(mut echo_request) = MutableEchoRequestPacket::new(p.packet_mut()) else {
return;
};
echo_request.set_identifier(v);
}
if let Some(mut p) = self.as_icmpv4_mut() {
p.set_identifier(v);
}
if let Some(mut p) = self.as_icmpv6() {
if p.get_icmpv6_type() == Icmpv6Types::EchoReply {
let Some(mut echo_reply) =
icmpv6::echo_reply::MutableEchoReplyPacket::new(p.packet_mut())
else {
return;
};
echo_reply.set_identifier(v)
}
if p.get_icmpv6_type() == Icmpv6Types::EchoRequest {
let Some(mut echo_request) =
icmpv6::echo_request::MutableEchoRequestPacket::new(p.packet_mut())
else {
return;
};
echo_request.set_identifier(v);
}
if let Some(mut p) = self.as_icmpv6_mut() {
p.set_identifier(v);
}
}
@@ -605,76 +506,132 @@ impl<'a> IpPacket<'a> {
}
pub fn as_udp(&self) -> Option<UdpSlice> {
self.is_udp()
.then(|| UdpSlice::from_slice(self.payload()).ok())
.flatten()
if !self.is_udp() {
return None;
}
UdpSlice::from_slice(self.payload()).ok()
}
pub fn as_udp_mut(&mut self) -> Option<MutableUdpPacket> {
self.is_udp()
.then(|| MutableUdpPacket::new(self.payload_mut()))
.flatten()
pub fn as_udp_mut(&mut self) -> Option<UdpHeaderSliceMut> {
if !self.is_udp() {
return None;
}
UdpHeaderSliceMut::from_slice(self.payload_mut()).ok()
}
pub fn as_tcp(&self) -> Option<TcpSlice> {
self.is_tcp()
.then(|| TcpSlice::from_slice(self.payload()).ok())
.flatten()
}
pub fn as_tcp_mut(&mut self) -> Option<MutableTcpPacket> {
self.is_tcp()
.then(|| MutableTcpPacket::new(self.payload_mut()))
.flatten()
}
pub fn is_icmp_v4_or_v6(&self) -> bool {
match self {
IpPacket::Ipv4(v4) => v4.ip_header().protocol() == IpNumber::ICMP,
IpPacket::Ipv6(v6) => v6.header().next_header() == IpNumber::IPV6_ICMP,
if !self.is_tcp() {
return None;
}
TcpSlice::from_slice(self.payload()).ok()
}
pub fn as_tcp_mut(&mut self) -> Option<TcpHeaderSliceMut> {
if !self.is_tcp() {
return None;
}
TcpHeaderSliceMut::from_slice(self.payload_mut()).ok()
}
fn set_icmpv6_checksum(&mut self) {
let (src_addr, dst_addr) = match self {
IpPacket::Ipv4(_) => return,
IpPacket::Ipv6(p) => (p.get_source(), p.get_destination()),
let Some(i) = self.as_icmpv6() else {
return;
};
if let Some(mut pkt) = self.as_icmpv6() {
let checksum = icmpv6::checksum(&pkt.to_immutable(), &src_addr, &dst_addr);
pkt.set_checksum(checksum);
}
let IpPacket::Ipv6(p) = &self else {
return;
};
let checksum = i
.icmp_type()
.calc_checksum(
p.get_source().octets(),
p.get_destination().octets(),
i.payload(),
)
.expect("Payload came from the original packet");
let Some(mut i) = self.as_icmpv6_mut() else {
return;
};
i.set_checksum(checksum);
}
fn set_icmpv4_checksum(&mut self) {
if let Some(mut pkt) = self.as_icmp_mut() {
let checksum = icmp::checksum(&pkt.to_immutable());
pkt.set_checksum(checksum);
let Some(i) = self.as_icmpv4() else {
return;
};
let checksum = i.icmp_type().calc_checksum(i.payload());
let Some(mut i) = self.as_icmpv4_mut() else {
return;
};
i.set_checksum(checksum);
}
pub fn as_icmpv4(&self) -> Option<Icmpv4Slice> {
if !self.is_icmp() {
return None;
}
Icmpv4Slice::from_slice(self.payload()).ok()
}
pub fn as_icmp(&self) -> Option<IcmpPacket> {
match self {
Self::Ipv4(v4) if self.is_icmp() => Some(IcmpPacket::Ipv4(
Icmpv4Slice::from_slice(v4.payload()).ok()?,
)),
Self::Ipv6(v6) if self.is_icmpv6() => Some(IcmpPacket::Ipv6(
Icmpv6Slice::from_slice(v6.payload()).ok()?,
)),
Self::Ipv4(_) | Self::Ipv6(_) => None,
pub fn as_icmpv4_mut(&mut self) -> Option<Icmpv4HeaderSliceMut> {
if !self.is_icmp() {
return None;
}
Icmpv4HeaderSliceMut::from_slice(self.payload_mut()).ok()
}
pub fn as_icmp_mut(&mut self) -> Option<MutableIcmpPacket> {
self.is_icmp()
.then(|| MutableIcmpPacket::new(self.payload_mut()))
.flatten()
pub fn as_icmpv6(&self) -> Option<Icmpv6Slice> {
if !self.is_icmpv6() {
return None;
}
Icmpv6Slice::from_slice(self.payload()).ok()
}
fn as_icmpv6(&mut self) -> Option<MutableIcmpv6Packet> {
self.is_icmpv6()
.then(|| MutableIcmpv6Packet::new(self.payload_mut()))
.flatten()
pub fn as_icmpv6_mut(&mut self) -> Option<Icmpv6EchoHeaderSliceMut> {
if !self.is_icmpv6() {
return None;
}
Icmpv6EchoHeaderSliceMut::from_slice(self.payload_mut()).ok()
}
fn icmpv4_echo_header(&self) -> Option<IcmpEchoHeader> {
let p = self.as_icmpv4()?;
use Icmpv4Type::*;
let icmp_type = p.icmp_type();
let (EchoReply(header) | EchoRequest(header)) = icmp_type else {
return None;
};
Some(header)
}
fn icmpv6_echo_header(&self) -> Option<IcmpEchoHeader> {
let p = self.as_icmpv6()?;
use Icmpv6Type::*;
let icmp_type = p.icmp_type();
let (EchoReply(header) | EchoRequest(header)) = icmp_type else {
return None;
};
Some(header)
}
pub fn translate_destination(
@@ -790,13 +747,46 @@ impl<'a> IpPacket<'a> {
self.next_header() == IpNumber::TCP
}
fn is_icmp(&self) -> bool {
pub fn is_icmp(&self) -> bool {
self.next_header() == IpNumber::ICMP
}
fn is_icmpv6(&self) -> bool {
pub fn is_icmpv6(&self) -> bool {
self.next_header() == IpNumber::IPV6_ICMP
}
fn header_length(&self) -> usize {
match self {
IpPacket::Ipv4(v4) => v4.header_length(),
IpPacket::Ipv6(v6) => v6.header().header_len(),
}
}
pub fn packet(&self) -> &[u8] {
match self {
IpPacket::Ipv4(v4) => v4.packet(),
IpPacket::Ipv6(v6) => v6.packet(),
}
}
fn packet_mut(&mut self) -> &mut [u8] {
match self {
IpPacket::Ipv4(v4) => v4.packet_mut(),
IpPacket::Ipv6(v6) => v6.packet_mut(),
}
}
fn payload(&self) -> &[u8] {
let start = self.header_length();
&self.packet()[start..]
}
fn payload_mut(&mut self) -> &mut [u8] {
let start = self.header_length();
&mut self.packet_mut()[start..]
}
}
impl<'a> From<ConvertibleIpv4Packet<'a>> for IpPacket<'a> {
@@ -811,26 +801,6 @@ impl<'a> From<ConvertibleIpv6Packet<'a>> for IpPacket<'a> {
}
}
impl pnet_packet::Packet for IpPacket<'_> {
fn packet(&self) -> &[u8] {
for_both!(self, |i| i.packet())
}
fn payload(&self) -> &[u8] {
for_both!(self, |i| i.payload())
}
}
impl pnet_packet::MutablePacket for IpPacket<'_> {
fn packet_mut(&mut self) -> &mut [u8] {
for_both!(self, |i| i.packet_mut())
}
fn payload_mut(&mut self) -> &mut [u8] {
for_both!(self, |i| i.payload_mut())
}
}
#[derive(Debug, thiserror::Error)]
pub enum UnsupportedProtocol {
#[error("Unsupported IP protocol: {0:?}")]

View File

@@ -1,6 +1,5 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use pnet_packet::Packet;
use proptest::arbitrary::any;
use proptest::prop_oneof;
use proptest::strategy::Strategy;

View File

@@ -8,7 +8,7 @@ pub unsafe fn write_to_offset_unchecked<const N: usize>(
offset: usize,
bytes: [u8; N],
) {
debug_assert!(offset + N < slice.len());
debug_assert!(offset + N <= slice.len());
let (_front, rest) = unsafe { slice.split_at_mut_unchecked(offset) };
let (target, _rest) = unsafe { rest.split_at_mut_unchecked(N) };

View File

@@ -0,0 +1,58 @@
use crate::slice_utils::write_to_offset_unchecked;
use etherparse::TcpHeaderSlice;
pub struct TcpHeaderSliceMut<'a> {
slice: &'a mut [u8],
}
impl<'a> TcpHeaderSliceMut<'a> {
/// Creates a new [`TcpHeaderSliceMut`].
pub fn from_slice(slice: &'a mut [u8]) -> Result<Self, etherparse::err::tcp::HeaderSliceError> {
TcpHeaderSlice::from_slice(slice)?;
Ok(Self { slice })
}
pub fn set_source_port(&mut self, src: u16) {
// Safety: Slice it at least of length 20 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) };
}
pub fn set_destination_port(&mut self, dst: u16) {
// Safety: Slice it at least of length 20 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 2, dst.to_be_bytes()) };
}
pub fn set_checksum(&mut self, checksum: u16) {
// Safety: Slice it at least of length 20 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 16, checksum.to_be_bytes()) };
}
}
#[cfg(test)]
mod tests {
use super::*;
use etherparse::PacketBuilder;
#[test]
fn smoke() {
let mut buf = Vec::new();
PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0)
.tcp(10, 20, 0, 0)
.write(&mut buf, &[])
.unwrap();
let mut slice = TcpHeaderSliceMut::from_slice(&mut buf[20..]).unwrap();
slice.set_source_port(30);
slice.set_destination_port(40);
slice.set_checksum(50);
let slice = TcpHeaderSlice::from_slice(&buf[20..]).unwrap();
assert_eq!(slice.source_port(), 30);
assert_eq!(slice.destination_port(), 40);
assert_eq!(slice.checksum(), 50);
}
}

View File

@@ -0,0 +1,65 @@
use crate::slice_utils::write_to_offset_unchecked;
use etherparse::UdpHeaderSlice;
pub struct UdpHeaderSliceMut<'a> {
slice: &'a mut [u8],
}
impl<'a> UdpHeaderSliceMut<'a> {
/// Creates a new [`UdpHeaderSliceMut`].
pub fn from_slice(slice: &'a mut [u8]) -> Result<Self, etherparse::err::LenError> {
UdpHeaderSlice::from_slice(slice)?;
Ok(Self { slice })
}
pub fn set_source_port(&mut self, src: u16) {
// Safety: Slice it at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 0, src.to_be_bytes()) };
}
pub fn set_destination_port(&mut self, dst: u16) {
// Safety: Slice it at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 2, dst.to_be_bytes()) };
}
pub fn set_length(&mut self, length: u16) {
// Safety: Slice it at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 4, length.to_be_bytes()) };
}
pub fn set_checksum(&mut self, checksum: u16) {
// Safety: Slice it at least of length 8 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 6, checksum.to_be_bytes()) };
}
}
#[cfg(test)]
mod tests {
use super::*;
use etherparse::PacketBuilder;
#[test]
fn smoke() {
let mut buf = Vec::new();
PacketBuilder::ipv4([0u8; 4], [0u8; 4], 0)
.udp(10, 20)
.write(&mut buf, &[])
.unwrap();
let mut slice = UdpHeaderSliceMut::from_slice(&mut buf[20..]).unwrap();
slice.set_source_port(30);
slice.set_destination_port(40);
slice.set_length(50);
slice.set_checksum(60);
let slice = UdpHeaderSlice::from_slice(&buf[20..]).unwrap();
assert_eq!(slice.source_port(), 30);
assert_eq!(slice.destination_port(), 40);
assert_eq!(slice.length(), 50);
assert_eq!(slice.checksum(), 60);
}
}