From 133c2565b2a3df8aa433cfac45015c28ce4f217d Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 11 Sep 2024 18:32:49 -0400 Subject: [PATCH] refactor(connlib): merge `IpPacket` and `MutableIpPacket` (#6652) Currently, we have two structs for representing IP packets: `IpPacket` and `MutableIpPacket`. As the name suggests, they mostly differ in mutability. This design was originally inspired by the `pnet_packet` crate which we based our `IpPacket` on. With subsequent iterations, we added more and more functionality onto our `IpPacket`, like NAT64 & NAT46 translation. As a result of that, the `MutableIpPacket` is no longer directly based on `pnet_packet` but instead just keeps an internal buffer. This duplication can be resolved by merging the two structs into a single `IpPacket`. We do this by first replacing all usages of `IpPacket` with `MutableIpPacket`, deleting `IpPacket` and renaming `MutableIpPacket` to `IpPacket`. The final design now has different `self`-receivers: Some functions take `&self`, some `&mut self` and some consume the packet using `self`. This results in a more ergonomic usage of `IpPacket` across the codebase and deletes a fair bit of code. It also takes us one step closer towards using `etherparse` for all our IP packet interaction-needs. Lastly, I am currently exploring a performance-optimisation idea that stack-allocates all IP packets and for that, the current split between `IpPacket` and `MutableIpPacket` does not really work. Related: #6366. --- .github/workflows/_rust.yml | 4 +- rust/bin-shared/benches/tunnel.rs | 12 +- rust/connlib/snownet/src/node.rs | 10 +- rust/connlib/tunnel/src/client.rs | 29 +- rust/connlib/tunnel/src/device_channel.rs | 6 +- rust/connlib/tunnel/src/dns.rs | 23 +- rust/connlib/tunnel/src/gateway.rs | 8 +- rust/connlib/tunnel/src/io.rs | 4 +- rust/connlib/tunnel/src/peer.rs | 55 +- rust/connlib/tunnel/src/peer/nat_table.rs | 33 +- rust/connlib/tunnel/src/tests/assertions.rs | 12 +- rust/connlib/tunnel/src/tests/sim_client.rs | 32 +- rust/connlib/tunnel/src/tests/sim_gateway.rs | 19 +- rust/ip-packet/src/ipv4_header_slice_mut.rs | 13 +- rust/ip-packet/src/ipv6_header_slice_mut.rs | 13 +- rust/ip-packet/src/lib.rs | 662 +++++++------------ rust/ip-packet/src/make.rs | 28 +- rust/ip-packet/src/proptest.rs | 10 +- rust/ip-packet/src/proptests.rs | 37 +- 19 files changed, 386 insertions(+), 624 deletions(-) diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml index 8ab28d452..5995e3367 100644 --- a/.github/workflows/_rust.yml +++ b/.github/workflows/_rust.yml @@ -58,8 +58,8 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - rustup install --no-self-update nightly-2024-06-01 --profile minimal # The exact nightly version doesn't matter, just pin a random one. - cargo +nightly-2024-06-01 udeps --all-targets --all-features ${{ steps.setup-rust.outputs.packages }} + rustup install --no-self-update nightly-2024-09-01 --profile minimal # The exact nightly version doesn't matter, just pin a random one. + cargo +nightly-2024-09-01 udeps --all-targets --all-features ${{ steps.setup-rust.outputs.packages }} name: Check for unused dependencies - run: cargo fmt -- --check - run: cargo doc --all-features --no-deps --document-private-items ${{ steps.setup-rust.outputs.packages }} diff --git a/rust/bin-shared/benches/tunnel.rs b/rust/bin-shared/benches/tunnel.rs index 7e9ef20bd..7fb275e64 100644 --- a/rust/bin-shared/benches/tunnel.rs +++ b/rust/bin-shared/benches/tunnel.rs @@ -63,14 +63,14 @@ mod platform { let mut response_pkt = None; let mut time_spent = Duration::from_millis(0); loop { - let mut req_buf = [0u8; MTU]; - poll_fn(|cx| tun.poll_read(&mut req_buf, cx)).await?; + let mut req_buf = [0u8; MTU + 20]; + poll_fn(|cx| tun.poll_read(&mut req_buf[20..], cx)).await?; let start = Instant::now(); - let original_pkt = IpPacket::new(&req_buf).unwrap(); + let original_pkt = IpPacket::new(&mut req_buf).unwrap(); let Some(original_udp) = original_pkt.as_udp() else { continue; }; - if original_udp.get_destination() != SERVER_PORT { + if original_udp.destination_port() != SERVER_PORT { continue; } if original_udp.payload()[0] != REQ_CODE { @@ -84,8 +84,8 @@ mod platform { ip_packet::make::udp_packet( original_pkt.destination(), original_pkt.source(), - original_udp.get_destination(), - original_udp.get_source(), + original_udp.destination_port(), + original_udp.source_port(), vec![RESP_CODE], ) .unwrap() diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 4308755d3..a9106e1a2 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -9,9 +9,7 @@ use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::fmt; use hex_display::HexDisplayExt; -use ip_packet::{ - ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, MutableIpPacket, Packet as _, -}; +use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, Packet as _}; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::{random, SeedableRng}; @@ -294,7 +292,7 @@ where packet: &[u8], now: Instant, buffer: &'b mut [u8], - ) -> Result)>, Error> { + ) -> Result)>, Error> { self.add_local_as_host_candidate(local)?; let (from, packet, relayed) = match self.allocations_try_handle(from, local, packet, now) { @@ -716,7 +714,7 @@ where packet: &[u8], buffer: &'b mut [u8], now: Instant, - ) -> ControlFlow, (TId, MutableIpPacket<'b>)> { + ) -> ControlFlow, (TId, IpPacket<'b>)> { for (cid, conn) in self.connections.iter_established_mut() { if !conn.accepts(&from) { continue; @@ -1713,7 +1711,7 @@ where allocations: &mut BTreeMap, transmits: &mut VecDeque>, now: Instant, - ) -> ControlFlow, MutableIpPacket<'b>> { + ) -> ControlFlow, IpPacket<'b>> { let _guard = self.span.enter(); let control_flow = match self.tunnel.decapsulate(None, packet, &mut buffer[20..]) { diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index af711f101..e2557d1f5 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -13,7 +13,7 @@ use connlib_shared::messages::{ use connlib_shared::{callbacks, PublicKey, StaticSecret}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; -use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; +use ip_packet::IpPacket; use itertools::Itertools; use crate::peer::GatewayOnClient; @@ -390,7 +390,7 @@ impl ClientState { pub(crate) fn encapsulate( &mut self, - packet: MutableIpPacket<'_>, + packet: IpPacket<'_>, now: Instant, buffer: &mut EncryptBuffer, ) -> Option { @@ -439,7 +439,7 @@ impl ClientState { let transmit = self .node - .encapsulate(gid, packet.as_immutable(), now, buffer) + .encapsulate(gid, packet, now, buffer) .inspect_err(|e| tracing::debug!(%gid, "Failed to encapsulate: {e}")) .ok()??; @@ -485,7 +485,7 @@ impl ClientState { now, ); - Some(packet.into_immutable()) + Some(packet) } pub fn add_ice_candidate(&mut self, conn_id: GatewayId, ice_candidate: String, now: Instant) { @@ -618,13 +618,10 @@ impl ClientState { /// Returns `Err` if the packet is not a DNS query. fn try_handle_dns_query<'a>( &mut self, - packet: MutableIpPacket<'a>, + packet: IpPacket<'a>, now: Instant, - ) -> Result>, (MutableIpPacket<'a>, IpAddr)> { - match self - .stub_resolver - .handle(&self.dns_mapping, packet.as_immutable()) - { + ) -> Result>, (IpPacket<'a>, IpAddr)> { + match self.stub_resolver.handle(&self.dns_mapping, &packet) { Some(dns::ResolveStrategy::LocalResponse(query)) => Ok(Some(query)), Some(dns::ResolveStrategy::ForwardQuery { upstream: server, @@ -680,7 +677,7 @@ impl ClientState { .inspect_err(|_| tracing::warn!("Failed to find original dst for DNS response")) .ok()?; - Some(ip_packet.into_immutable()) + Some(ip_packet) } pub fn on_connection_failed(&mut self, resource: ResourceId) { @@ -1327,11 +1324,11 @@ fn is_definitely_not_a_resource(ip: IpAddr) -> bool { /// In case the given packet is a DNS query, change its source IP to that of the actual DNS server. fn maybe_mangle_dns_query_to_cidr_resource<'p>( - mut packet: MutableIpPacket<'p>, + mut packet: IpPacket<'p>, dns_mapping: &BiMap, mangeled_dns_queries: &mut HashMap, now: Instant, -) -> MutableIpPacket<'p> { +) -> IpPacket<'p> { let dst = packet.destination(); let Some(srv) = dns_mapping.get_by_left(&dst) else { @@ -1356,18 +1353,18 @@ fn maybe_mangle_dns_query_to_cidr_resource<'p>( } fn maybe_mangle_dns_response_from_cidr_resource<'p>( - mut packet: MutableIpPacket<'p>, + mut packet: IpPacket<'p>, dns_mapping: &BiMap, mangeled_dns_queries: &mut HashMap, now: Instant, -) -> MutableIpPacket<'p> { +) -> IpPacket<'p> { let src_ip = packet.source(); let Some(udp) = packet.as_udp() else { return packet; }; - let src_port = udp.get_source(); + let src_port = udp.source_port(); let Some(sentinel) = dns_mapping.get_by_right(&DnsServer::from((src_ip, src_port))) else { return packet; diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index 911e7a119..e72d43621 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -1,4 +1,4 @@ -use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; +use ip_packet::{IpPacket, Packet as _}; use std::io; use std::task::{Context, Poll, Waker}; use tun::Tun; @@ -30,7 +30,7 @@ impl Device { &mut self, buf: &'b mut [u8], cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { use ip_packet::Packet as _; let Some(tun) = self.tun.as_mut() else { @@ -47,7 +47,7 @@ impl Device { ))); } - let packet = MutableIpPacket::new(&mut buf[..(n + 20)]).ok_or_else(|| { + let packet = IpPacket::new(&mut buf[..(n + 20)]).ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "received bytes are not an IP packet", diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index d3ba5106b..0b234af3f 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -7,7 +7,6 @@ use domain::base::{ }; use domain::rdata::AllRecordData; use ip_packet::IpPacket; -use ip_packet::Packet as _; use itertools::Itertools; use pattern::{Candidate, Pattern}; use std::collections::{BTreeMap, HashMap}; @@ -209,13 +208,13 @@ impl StubResolver { pub(crate) fn handle( &mut self, dns_mapping: &bimap::BiMap, - packet: IpPacket, + packet: &IpPacket, ) -> Option { let upstream = dns_mapping.get_by_left(&packet.destination())?.address(); let datagram = packet.as_udp()?; // We only support DNS on port 53. - if datagram.get_destination() != DNS_PORT { + if datagram.destination_port() != DNS_PORT { return None; } @@ -241,12 +240,11 @@ impl StubResolver { let packet = ip_packet::make::udp_packet( packet.destination(), packet.source(), - datagram.get_destination(), - datagram.get_source(), + datagram.destination_port(), + datagram.source_port(), response, ) - .expect("src and dst come from the same packet") - .into_immutable(); + .expect("src and dst come from the same packet"); return Some(ResolveStrategy::LocalResponse(packet)); } @@ -260,7 +258,7 @@ impl StubResolver { upstream, query_id: message.header().id(), payload: message.into_octets().to_vec(), - original_src: SocketAddr::new(packet.source(), datagram.get_source()), + original_src: SocketAddr::new(packet.source(), datagram.source_port()), }) } (Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource), @@ -277,7 +275,7 @@ impl StubResolver { upstream, query_id: message.header().id(), payload: message.into_octets().to_vec(), - original_src: SocketAddr::new(packet.source(), datagram.get_source()), + original_src: SocketAddr::new(packet.source(), datagram.source_port()), }) } }; @@ -286,12 +284,11 @@ impl StubResolver { let packet = ip_packet::make::udp_packet( packet.destination(), packet.source(), - datagram.get_destination(), - datagram.get_source(), + datagram.destination_port(), + datagram.source_port(), response, ) - .expect("src and dst come from the same packet") - .into_immutable(); + .expect("src and dst come from the same packet"); Some(ResolveStrategy::LocalResponse(packet)) } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 61f844176..8d526e056 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -11,7 +11,7 @@ use connlib_shared::messages::{ }; use connlib_shared::{DomainName, StaticSecret}; use ip_network::{Ipv4Network, Ipv6Network}; -use ip_packet::{IpPacket, MutableIpPacket}; +use ip_packet::IpPacket; use secrecy::{ExposeSecret as _, Secret}; use snownet::{EncryptBuffer, RelaySocket, ServerNode}; use std::collections::{BTreeMap, BTreeSet, VecDeque}; @@ -157,7 +157,7 @@ impl GatewayState { pub(crate) fn encapsulate( &mut self, - packet: MutableIpPacket<'_>, + packet: IpPacket<'_>, now: Instant, buffer: &mut EncryptBuffer, ) -> Option { @@ -181,7 +181,7 @@ impl GatewayState { let transmit = self .node - .encapsulate(peer.id(), packet.as_immutable(), now, buffer) + .encapsulate(peer.id(), packet, now, buffer) .inspect_err(|e| tracing::debug!(%cid, "Failed to encapsulate: {e}")) .ok()??; @@ -217,7 +217,7 @@ impl GatewayState { .inspect_err(|e| tracing::debug!(%cid, "Invalid packet: {e:#}")) .ok()?; - Some(packet.into_immutable()) + Some(packet) } pub fn add_ice_candidate(&mut self, conn_id: ClientId, ice_candidate: String, now: Instant) { diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index f11dbc27c..098270fb3 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,6 +1,6 @@ use crate::{device_channel::Device, sockets::Sockets, BUF_SIZE}; use futures_util::FutureExt as _; -use ip_packet::{IpPacket, MutableIpPacket}; +use ip_packet::IpPacket; use snownet::{EncryptBuffer, EncryptedPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ @@ -29,7 +29,7 @@ pub struct Io { pub enum Input<'a, I> { Timeout(Instant), - Device(MutableIpPacket<'a>), + Device(IpPacket<'a>), Network(I), } diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index e9ebbb0a5..9e8708149 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -10,8 +10,7 @@ use connlib_shared::messages::{ use connlib_shared::DomainName; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use ip_packet::ip::IpNextHeaderProtocols; -use ip_packet::{IpPacket, MutableIpPacket}; +use ip_packet::IpPacket; use itertools::Itertools; use rangemap::RangeInclusiveSet; @@ -70,20 +69,19 @@ impl AllowRules { } fn is_allowed(&self, packet: &IpPacket) -> bool { - match packet.next_header() { - // Note: possible optimization here - // if we want to get the port here, and we assume correct formatting - // we can do packet.payload()[2..=3] (for UDP and TCP bytes 2 and 3 are the port) - // but it might be a bit harder to read - IpNextHeaderProtocols::Tcp => packet - .as_tcp() - .is_some_and(|p| self.tcp.contains(&p.get_destination())), - IpNextHeaderProtocols::Udp => packet - .as_udp() - .is_some_and(|p| self.udp.contains(&p.get_destination())), - IpNextHeaderProtocols::Icmp | IpNextHeaderProtocols::Icmpv6 => self.icmp, - _ => false, + if let Some(tcp) = packet.as_tcp() { + return self.tcp.contains(&tcp.destination_port()); } + + if let Some(udp) = packet.as_udp() { + return self.udp.contains(&udp.destination_port()); + } + + if packet.is_icmp_v4_or_v6() { + return self.icmp; + } + + false } fn add_filters<'a>(&mut self, filters: impl IntoIterator) { @@ -377,16 +375,16 @@ impl ClientOnGateway { fn transform_network_to_tun<'a>( &mut self, - packet: MutableIpPacket<'a>, + packet: IpPacket<'a>, now: Instant, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let Some(state) = self.permanent_translations.get_mut(&packet.destination()) else { return Ok(packet); }; let (source_protocol, real_ip) = self.nat_table - .translate_outgoing(packet.as_immutable(), state.resolved_ip, now)?; + .translate_outgoing(&packet, state.resolved_ip, now)?; let mut packet = packet .translate_destination(self.ipv4, self.ipv6, source_protocol, real_ip) @@ -400,9 +398,9 @@ impl ClientOnGateway { pub fn decapsulate<'a>( &mut self, - packet: MutableIpPacket<'a>, + packet: IpPacket<'a>, now: Instant, - ) -> anyhow::Result> { + ) -> anyhow::Result> { self.ensure_allowed_src(&packet)?; let packet = self.transform_network_to_tun(packet, now)?; @@ -414,13 +412,10 @@ impl ClientOnGateway { pub fn encapsulate<'a>( &mut self, - packet: MutableIpPacket<'a>, + packet: IpPacket<'a>, now: Instant, - ) -> anyhow::Result>> { - let Some((proto, ip)) = self - .nat_table - .translate_incoming(packet.as_immutable(), now)? - else { + ) -> anyhow::Result>> { + let Some((proto, ip)) = self.nat_table.translate_incoming(&packet, now)? else { return Ok(Some(packet)); }; @@ -438,7 +433,7 @@ impl ClientOnGateway { Ok(Some(packet)) } - fn ensure_allowed_src(&self, packet: &MutableIpPacket<'_>) -> anyhow::Result<()> { + fn ensure_allowed_src(&self, packet: &IpPacket<'_>) -> anyhow::Result<()> { let src = packet.source(); if !self.allowed_ips().contains(&src) { @@ -449,12 +444,12 @@ impl ClientOnGateway { } /// Check if an incoming packet arriving over the network is ok to be forwarded to the TUN device. - fn ensure_allowed_dst(&self, packet: &MutableIpPacket<'_>) -> anyhow::Result<()> { + fn ensure_allowed_dst(&mut self, packet: &IpPacket<'_>) -> anyhow::Result<()> { let dst = packet.destination(); if !self .filters .longest_match(dst) - .is_some_and(|(_, filter)| filter.is_allowed(&packet.to_immutable())) + .is_some_and(|(_, filter)| filter.is_allowed(packet)) { return Err(anyhow::Error::new(DstNotAllowed(dst))); }; @@ -468,7 +463,7 @@ impl ClientOnGateway { } impl GatewayOnClient { - pub(crate) fn ensure_allowed_src(&self, packet: &MutableIpPacket) -> anyhow::Result<()> { + pub(crate) fn ensure_allowed_src(&self, packet: &IpPacket) -> anyhow::Result<()> { let src = packet.source(); if self.allowed_ips.longest_match(src).is_none() { diff --git a/rust/connlib/tunnel/src/peer/nat_table.rs b/rust/connlib/tunnel/src/peer/nat_table.rs index 02f84071d..af2497eec 100644 --- a/rust/connlib/tunnel/src/peer/nat_table.rs +++ b/rust/connlib/tunnel/src/peer/nat_table.rs @@ -44,7 +44,7 @@ impl NatTable { pub(crate) fn translate_outgoing( &mut self, - packet: IpPacket, + packet: &IpPacket, outside_dst: IpAddr, now: Instant, ) -> anyhow::Result<(Protocol, IpAddr)> { @@ -84,7 +84,7 @@ impl NatTable { pub(crate) fn translate_incoming( &mut self, - packet: IpPacket, + packet: &IpPacket, now: Instant, ) -> anyhow::Result> { let outside = (packet.destination_protocol()?, packet.source()); @@ -105,12 +105,12 @@ impl NatTable { #[cfg(all(test, feature = "proptest"))] mod tests { use super::*; - use ip_packet::{proptest::*, MutableIpPacket}; + use ip_packet::{proptest::*, IpPacket}; use proptest::prelude::*; #[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })] fn translates_back_and_forth_packet( - #[strategy(udp_or_tcp_or_icmp_packet())] packet: MutableIpPacket<'static>, + #[strategy(udp_or_tcp_or_icmp_packet())] packet: IpPacket<'static>, #[strategy(any::())] outside_dst: IpAddr, #[strategy(0..120u64)] response_delay: u64, ) { @@ -121,12 +121,12 @@ mod tests { let response_delay = Duration::from_secs(response_delay); // Remember original src_p and dst - let src = packet.as_immutable().source_protocol().unwrap(); + let src = packet.source_protocol().unwrap(); let dst = packet.destination(); // Translate out let (new_source_protocol, new_dst_ip) = table - .translate_outgoing(packet.as_immutable(), outside_dst, sent_at) + .translate_outgoing(&packet, outside_dst, sent_at) .unwrap(); // Pretend we are getting a response. @@ -139,7 +139,7 @@ mod tests { // Translate in let translate_incoming = table - .translate_incoming(response.as_immutable(), sent_at + response_delay) + .translate_incoming(&response, sent_at + response_delay) .unwrap(); // Assert @@ -152,16 +152,15 @@ mod tests { #[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })] fn can_handle_multiple_packets( - #[strategy(udp_or_tcp_or_icmp_packet())] packet1: MutableIpPacket<'static>, + #[strategy(udp_or_tcp_or_icmp_packet())] packet1: IpPacket<'static>, #[strategy(any::())] outside_dst1: IpAddr, - #[strategy(udp_or_tcp_or_icmp_packet())] packet2: MutableIpPacket<'static>, + #[strategy(udp_or_tcp_or_icmp_packet())] packet2: IpPacket<'static>, #[strategy(any::())] outside_dst2: IpAddr, ) { proptest::prop_assume!(packet1.destination().is_ipv4() == outside_dst1.is_ipv4()); // Required for our test to simulate a response. proptest::prop_assume!(packet2.destination().is_ipv4() == outside_dst2.is_ipv4()); // Required for our test to simulate a response. proptest::prop_assume!( - packet1.as_immutable().source_protocol().unwrap() - != packet2.as_immutable().source_protocol().unwrap() + packet1.source_protocol().unwrap() != packet2.source_protocol().unwrap() ); let mut table = NatTable::default(); @@ -171,14 +170,12 @@ mod tests { // Remember original src_p and dst let original_src_p_and_dst = packets .clone() - .map(|(p, _)| (p.as_immutable().source_protocol().unwrap(), p.destination())); + .map(|(p, _)| (p.source_protocol().unwrap(), p.destination())); // Translate out - let new_src_p_and_dst = packets.clone().map(|(p, d)| { - table - .translate_outgoing(p.as_immutable(), d, Instant::now()) - .unwrap() - }); + let new_src_p_and_dst = packets + .clone() + .map(|(p, d)| table.translate_outgoing(&p, d, Instant::now()).unwrap()); // Pretend we are getting a response. for ((p, _), (new_src_p, new_d)) in packets.iter_mut().zip(new_src_p_and_dst) { @@ -189,7 +186,7 @@ mod tests { // Translate in let responses = packets.map(|(p, _)| { table - .translate_incoming(p.as_immutable(), Instant::now()) + .translate_incoming(&p, Instant::now()) .unwrap() .unwrap() }); diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index dec8269a3..889abf7f6 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -224,11 +224,11 @@ fn assert_correct_src_and_dst_udp_ports( client_sent_request: &IpPacket<'_>, client_received_reply: &IpPacket<'_>, ) { - let client_sent_request = client_sent_request.unwrap_as_udp(); - let client_received_reply = client_received_reply.unwrap_as_udp(); + let client_sent_request = client_sent_request.as_udp().unwrap(); + let client_received_reply = client_received_reply.as_udp().unwrap(); - let req_dst = client_sent_request.get_destination(); - let res_src = client_received_reply.get_source(); + let req_dst = client_sent_request.destination_port(); + let res_src = client_received_reply.source_port(); if req_dst != res_src { tracing::error!(target: "assertions", %req_dst, %res_src, "❌ req dst port != res src port"); @@ -236,8 +236,8 @@ fn assert_correct_src_and_dst_udp_ports( tracing::info!(target: "assertions", port = %req_dst, "✅ req dst port == res src port"); } - let req_src = client_sent_request.get_source(); - let res_dst = client_received_reply.get_destination(); + let req_src = client_sent_request.source_port(); + let res_dst = client_received_reply.destination_port(); if req_src != res_dst { tracing::error!(target: "assertions", %req_src, %res_dst, "❌ req src port != res dst port"); diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index d4b70f46d..429f61fe1 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -24,7 +24,7 @@ use domain::{ }; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; -use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; +use ip_packet::IpPacket; use itertools::Itertools as _; use prop::collection; use proptest::prelude::*; @@ -120,23 +120,19 @@ impl SimClient { pub(crate) fn encapsulate( &mut self, - packet: MutableIpPacket<'static>, + packet: IpPacket<'static>, now: Instant, ) -> Option> { { - let packet = packet.as_immutable().to_owned(); - if let Some(icmp) = packet.as_icmp() { - let echo_request = icmp.as_echo_request().expect("to be echo request"); + let echo_request = icmp.echo_request_header().expect("to be echo request"); self.sent_icmp_requests - .insert((echo_request.sequence(), echo_request.identifier()), packet); + .insert((echo_request.seq, echo_request.id), packet.clone()); } } { - let packet = packet.as_immutable().to_owned(); - if let Some(udp) = packet.as_udp() { if let Ok(message) = Message::from_slice(udp.payload()) { debug_assert!( @@ -145,11 +141,11 @@ impl SimClient { ); // Map back to upstream socket so we can assert on it correctly. - let sentinel = SocketAddr::from((packet.destination(), udp.get_destination())); + let sentinel = SocketAddr::from((packet.destination(), udp.destination_port())); let upstream = self.upstream_dns_by_sentinel(&sentinel).unwrap(); self.sent_dns_queries - .insert((upstream, message.header().id()), packet); + .insert((upstream, message.header().id()), packet.clone()); } } } @@ -178,29 +174,27 @@ impl SimClient { } /// Process an IP packet received on the client. - pub(crate) fn on_received_packet(&mut self, packet: IpPacket<'_>) { + pub(crate) fn on_received_packet(&mut self, packet: IpPacket<'static>) { if let Some(icmp) = packet.as_icmp() { - let echo_reply = icmp.as_echo_reply().expect("to be echo reply"); + let echo_reply = icmp.echo_reply_header().expect("to be echo reply"); - self.received_icmp_replies.insert( - (echo_reply.sequence(), echo_reply.identifier()), - packet.to_owned(), - ); + self.received_icmp_replies + .insert((echo_reply.seq, echo_reply.id), packet); return; }; if let Some(udp) = packet.as_udp() { - if udp.get_source() == 53 { + if udp.source_port() == 53 { let message = Message::from_slice(udp.payload()) .expect("ip packets on port 53 to be DNS packets"); // Map back to upstream socket so we can assert on it correctly. - let sentinel = SocketAddr::from((packet.source(), udp.get_source())); + let sentinel = SocketAddr::from((packet.source(), udp.source_port())); let upstream = self.upstream_dns_by_sentinel(&sentinel).unwrap(); self.received_dns_responses - .insert((upstream, message.header().id()), packet.to_owned()); + .insert((upstream, message.header().id()), packet.clone()); for record in message.answer().unwrap() { let record = record.unwrap(); diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index c7474b7de..901fb78fd 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -65,26 +65,25 @@ impl SimGateway { fn on_received_packet( &mut self, global_dns_records: &BTreeMap>, - packet: IpPacket<'_>, + packet: IpPacket<'static>, now: Instant, ) -> Option> { - let packet = packet.to_owned(); - // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? if let Some(icmp) = packet.as_icmp() { - if let Some(request) = icmp.as_echo_request() { - let payload = u64::from_be_bytes(*request.payload().first_chunk().unwrap()); - tracing::debug!(%payload, "Received ICMP request"); + if let Some(echo_request) = icmp.echo_request_header() { + let payload = icmp.payload(); + let echo_id = u64::from_be_bytes(*payload.first_chunk().unwrap()); + tracing::debug!(%echo_id, "Received ICMP request"); - self.received_icmp_requests.insert(payload, packet.clone()); + self.received_icmp_requests.insert(echo_id, packet.clone()); let echo_response = ip_packet::make::icmp_reply_packet( packet.destination(), packet.source(), - request.sequence(), - request.identifier(), - request.payload(), + echo_request.seq, + echo_request.id, + payload, ) .expect("src and dst are taken from incoming packet"); let transmit = self diff --git a/rust/ip-packet/src/ipv4_header_slice_mut.rs b/rust/ip-packet/src/ipv4_header_slice_mut.rs index a943e81f7..67b006b11 100644 --- a/rust/ip-packet/src/ipv4_header_slice_mut.rs +++ b/rust/ip-packet/src/ipv4_header_slice_mut.rs @@ -7,15 +7,12 @@ pub struct Ipv4HeaderSliceMut<'a> { impl<'a> Ipv4HeaderSliceMut<'a> { /// Creates a new [`Ipv4HeaderSliceMut`]. - /// - /// # Safety - /// - /// - The byte array must be at least of length 20. - /// - The IP version must be 4. - pub unsafe fn from_slice_unchecked(slice: &'a mut [u8]) -> Self { - debug_assert!(Ipv4HeaderSlice::from_slice(slice).is_ok()); // Debug asserts are no-ops in release mode, so this is still "unchecked". + pub fn from_slice( + slice: &'a mut [u8], + ) -> Result { + Ipv4HeaderSlice::from_slice(slice)?; - Self { slice } + Ok(Self { slice }) } pub fn set_checksum(&mut self, checksum: u16) { diff --git a/rust/ip-packet/src/ipv6_header_slice_mut.rs b/rust/ip-packet/src/ipv6_header_slice_mut.rs index f2004917f..40ff9c84d 100644 --- a/rust/ip-packet/src/ipv6_header_slice_mut.rs +++ b/rust/ip-packet/src/ipv6_header_slice_mut.rs @@ -7,15 +7,12 @@ pub struct Ipv6HeaderSliceMut<'a> { impl<'a> Ipv6HeaderSliceMut<'a> { /// Creates a new [`Ipv6HeaderSliceMut`]. - /// - /// # Safety - /// - /// - The byte array must be at least of length 40. - /// - The IP version must be 6. - pub unsafe fn from_slice_unchecked(slice: &'a mut [u8]) -> Self { - debug_assert!(Ipv6HeaderSlice::from_slice(slice).is_ok()); // Debug asserts are no-ops in release mode, so this is still "unchecked". + pub fn from_slice( + slice: &'a mut [u8], + ) -> Result { + Ipv6HeaderSlice::from_slice(slice)?; - Self { slice } + Ok(Self { slice }) } pub fn set_source(&mut self, src: [u8; 16]) { diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 409cb0f54..53d668518 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -13,8 +13,10 @@ pub use pnet_packet::*; #[cfg(all(test, feature = "proptest"))] mod proptests; -use domain::base::Message; -use etherparse::{Ipv4Header, Ipv4HeaderSlice, Ipv6Header, Ipv6HeaderSlice}; +use etherparse::{ + IcmpEchoHeader, Icmpv4Slice, Icmpv4Type, Icmpv6Slice, Icmpv6Type, IpNumber, Ipv4Header, + Ipv4HeaderSlice, Ipv6Header, Ipv6HeaderSlice, TcpSlice, UdpSlice, +}; use ipv4_header_slice_mut::Ipv4HeaderSliceMut; use ipv6_header_slice_mut::Ipv6HeaderSliceMut; use pnet_packet::{ @@ -23,11 +25,8 @@ use pnet_packet::{ MutableIcmpPacket, }, icmpv6::{Icmpv6Types, MutableIcmpv6Packet}, - ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, - ipv4::Ipv4Packet, - ipv6::Ipv6Packet, - tcp::{MutableTcpPacket, TcpPacket}, - udp::{MutableUdpPacket, UdpPacket}, + tcp::MutableTcpPacket, + udp::MutableUdpPacket, }; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -80,60 +79,77 @@ impl Protocol { } } -#[derive(Debug, PartialEq)] -pub enum IpPacket<'a> { - Ipv4(Ipv4Packet<'a>), - Ipv6(Ipv6Packet<'a>), -} - #[derive(Debug, PartialEq)] pub enum IcmpPacket<'a> { - Ipv4(icmp::IcmpPacket<'a>), - Ipv6(icmpv6::Icmpv6Packet<'a>), + Ipv4(Icmpv4Slice<'a>), + Ipv6(Icmpv6Slice<'a>), } impl<'a> IcmpPacket<'a> { pub fn icmp_type(&self) -> IcmpType { match self { - IcmpPacket::Ipv4(v4) => IcmpType::V4(v4.get_icmp_type()), - IcmpPacket::Ipv6(v6) => IcmpType::V6(v6.get_icmpv6_type()), + IcmpPacket::Ipv4(v4) => IcmpType::V4(v4.icmp_type()), + IcmpPacket::Ipv6(v6) => IcmpType::V6(v6.icmp_type()), } } pub fn identifier(&self) -> Option { - let request_id = self.as_echo_request().map(|r| r.identifier()); - let reply_id = self.as_echo_reply().map(|r| r.identifier()); - - request_id.or(reply_id) + Some(self.echo_request_header().or(self.echo_reply_header())?.id) } pub fn sequence(&self) -> Option { - let request_id = self.as_echo_request().map(|r| r.sequence()); - let reply_id = self.as_echo_reply().map(|r| r.sequence()); + Some(self.echo_request_header().or(self.echo_reply_header())?.seq) + } - request_id.or(reply_id) + pub fn payload(&self) -> &[u8] { + match self { + IcmpPacket::Ipv4(v4) => v4.payload(), + IcmpPacket::Ipv6(v6) => v6.payload(), + } + } + + pub fn echo_request_header(&self) -> Option { + #[allow( + clippy::wildcard_enum_match_arm, + reason = "We won't ever need to use other ICMP types here." + )] + match self { + IcmpPacket::Ipv4(v4) => match v4.header().icmp_type { + Icmpv4Type::EchoRequest(echo) => Some(echo), + _ => None, + }, + IcmpPacket::Ipv6(v6) => match v6.header().icmp_type { + Icmpv6Type::EchoRequest(echo) => Some(echo), + _ => None, + }, + } + } + + pub fn echo_reply_header(&self) -> Option { + #[allow( + clippy::wildcard_enum_match_arm, + reason = "We won't ever need to use other ICMP types here." + )] + match self { + IcmpPacket::Ipv4(v4) => match v4.header().icmp_type { + Icmpv4Type::EchoReply(echo) => Some(echo), + _ => None, + }, + IcmpPacket::Ipv6(v6) => match v6.header().icmp_type { + Icmpv6Type::EchoReply(echo) => Some(echo), + _ => None, + }, + } } } pub enum IcmpType { - V4(icmp::IcmpType), - V6(icmpv6::Icmpv6Type), -} - -#[derive(Debug, PartialEq)] -pub enum IcmpEchoRequest<'a> { - Ipv4(icmp::echo_request::EchoRequestPacket<'a>), - Ipv6(icmpv6::echo_request::EchoRequestPacket<'a>), -} - -#[derive(Debug, PartialEq)] -pub enum IcmpEchoReply<'a> { - Ipv4(icmp::echo_reply::EchoReplyPacket<'a>), - Ipv6(icmpv6::echo_reply::EchoReplyPacket<'a>), + V4(Icmpv4Type), + V6(Icmpv6Type), } #[derive(Debug, PartialEq, Clone)] -pub enum MutableIpPacket<'a> { +pub enum IpPacket<'a> { Ipv4(ConvertibleIpv4Packet<'a>), Ipv6(ConvertibleIpv6Packet<'a>), } @@ -206,17 +222,11 @@ impl<'a> ConvertibleIpv4Packet<'a> { } fn ip_header(&self) -> Ipv4HeaderSlice { - // TODO: Make `_unchecked` variant public upstream. Ipv4HeaderSlice::from_slice(&self.buf[20..]).expect("we checked this during `new`") } fn ip_header_mut(&mut self) -> Ipv4HeaderSliceMut { - // Safety: We checked this in `new` / `owned`. - unsafe { Ipv4HeaderSliceMut::from_slice_unchecked(&mut self.buf[20..]) } - } - - pub fn to_immutable(&self) -> Ipv4Packet { - Ipv4Packet::new(&self.buf[20..]).expect("when constructed we checked that this is some") + Ipv4HeaderSliceMut::from_slice(&mut self.buf[20..]).expect("we checked this during `new`") } pub fn get_source(&self) -> Ipv4Addr { @@ -227,18 +237,6 @@ impl<'a> ConvertibleIpv4Packet<'a> { self.ip_header().destination_addr() } - fn consume_to_immutable(self) -> Ipv4Packet<'a> { - match self.buf { - MaybeOwned::RefMut(buf) => { - Ipv4Packet::new(&buf[20..]).expect("when constructed we checked that this is some") - } - MaybeOwned::Owned(mut owned) => { - owned.drain(..20); - Ipv4Packet::owned(owned).expect("when constructed we checked that this is some") - } - } - } - fn consume_to_ipv6( mut self, src: Ipv6Addr, @@ -301,17 +299,11 @@ impl<'a> ConvertibleIpv6Packet<'a> { } fn header(&self) -> Ipv6HeaderSlice { - // FIXME: Make the `_unchecked` variant public upstream. Ipv6HeaderSlice::from_slice(&self.buf).expect("We checked this in `new` / `owned`") } fn header_mut(&mut self) -> Ipv6HeaderSliceMut { - // Safety: We checked this in `new` / `owned`. - unsafe { Ipv6HeaderSliceMut::from_slice_unchecked(&mut self.buf) } - } - - fn to_immutable(&self) -> Ipv6Packet { - Ipv6Packet::new(&self.buf).expect("when constructed we checked that this is some") + Ipv6HeaderSliceMut::from_slice(&mut self.buf).expect("We checked this in `new` / `owned`") } pub fn get_source(&self) -> Ipv6Addr { @@ -322,17 +314,6 @@ impl<'a> ConvertibleIpv6Packet<'a> { self.header().destination_addr() } - fn consume_to_immutable(self) -> Ipv6Packet<'a> { - match self.buf { - MaybeOwned::RefMut(buf) => { - Ipv6Packet::new(buf).expect("when constructed we checked that this is some") - } - MaybeOwned::Owned(owned) => { - Ipv6Packet::owned(owned).expect("when constructed we checked that this is some") - } - } - } - fn consume_to_ipv4( mut self, src: Ipv4Addr, @@ -398,19 +379,17 @@ pub fn ipv6_translated(ip: Ipv6Addr) -> Option { )) } -impl<'a> MutableIpPacket<'a> { +impl<'a> IpPacket<'a> { // TODO: this API is a bit akward, since you have to pass the extra prepended 20 bytes pub fn new(buf: &'a mut [u8]) -> Option { match buf[20] >> 4 { - 4 => Some(MutableIpPacket::Ipv4(ConvertibleIpv4Packet::new(buf)?)), - 6 => Some(MutableIpPacket::Ipv6(ConvertibleIpv6Packet::new( - &mut buf[20..], - )?)), + 4 => Some(IpPacket::Ipv4(ConvertibleIpv4Packet::new(buf)?)), + 6 => Some(IpPacket::Ipv6(ConvertibleIpv6Packet::new(&mut buf[20..])?)), _ => None, } } - pub(crate) fn owned(mut data: Vec) -> Option> { + pub(crate) fn owned(mut data: Vec) -> Option> { let packet = match data[20] >> 4 { 4 => ConvertibleIpv4Packet::owned(data)?.into(), 6 => { @@ -423,33 +402,28 @@ impl<'a> MutableIpPacket<'a> { Some(packet) } - pub fn to_immutable(&self) -> IpPacket { - for_both!(self, |i| i.to_immutable().into()) - } - - pub(crate) fn consume_to_ipv4( - self, - src: Ipv4Addr, - dst: Ipv4Addr, - ) -> Option> { + pub fn to_owned(&self) -> IpPacket<'static> { match self { - MutableIpPacket::Ipv4(pkt) => Some(MutableIpPacket::Ipv4(pkt)), - MutableIpPacket::Ipv6(pkt) => { - Some(MutableIpPacket::Ipv4(pkt.consume_to_ipv4(src, dst)?)) - } + IpPacket::Ipv4(i) => IpPacket::Ipv4(ConvertibleIpv4Packet { + buf: MaybeOwned::Owned(i.buf.to_vec()), + }), + IpPacket::Ipv6(i) => IpPacket::Ipv6(ConvertibleIpv6Packet { + buf: MaybeOwned::Owned(i.buf.to_vec()), + }), } } - pub(crate) fn consume_to_ipv6( - self, - src: Ipv6Addr, - dst: Ipv6Addr, - ) -> Option> { + pub(crate) fn consume_to_ipv4(self, src: Ipv4Addr, dst: Ipv4Addr) -> Option> { match self { - MutableIpPacket::Ipv4(pkt) => { - Some(MutableIpPacket::Ipv6(pkt.consume_to_ipv6(src, dst)?)) - } - MutableIpPacket::Ipv6(pkt) => Some(MutableIpPacket::Ipv6(pkt)), + IpPacket::Ipv4(pkt) => Some(IpPacket::Ipv4(pkt)), + IpPacket::Ipv6(pkt) => Some(IpPacket::Ipv4(pkt.consume_to_ipv4(src, dst)?)), + } + } + + pub(crate) fn consume_to_ipv6(self, src: Ipv6Addr, dst: Ipv6Addr) -> Option> { + match self { + IpPacket::Ipv4(pkt) => Some(IpPacket::Ipv6(pkt.consume_to_ipv6(src, dst)?)), + IpPacket::Ipv6(pkt) => Some(IpPacket::Ipv6(pkt)), } } @@ -461,12 +435,58 @@ impl<'a> MutableIpPacket<'a> { for_both!(self, |i| i.get_destination().into()) } + pub fn source_protocol(&self) -> Result { + if let Some(p) = self.as_tcp() { + return Ok(Protocol::Tcp(p.source_port())); + } + + if let Some(p) = self.as_udp() { + return Ok(Protocol::Udp(p.source_port())); + } + + if let Some(p) = self.as_icmp() { + let id = p.identifier().ok_or_else(|| match p.icmp_type() { + IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4), + IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6), + })?; + + return Ok(Protocol::Icmp(id)); + } + + Err(UnsupportedProtocol::UnsupportedIpPayload( + self.next_header(), + )) + } + + pub fn destination_protocol(&self) -> Result { + if let Some(p) = self.as_tcp() { + return Ok(Protocol::Tcp(p.destination_port())); + } + + if let Some(p) = self.as_udp() { + return Ok(Protocol::Udp(p.destination_port())); + } + + if let Some(p) = self.as_icmp() { + let id = p.identifier().ok_or_else(|| match p.icmp_type() { + IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4), + IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6), + })?; + + return Ok(Protocol::Icmp(id)); + } + + Err(UnsupportedProtocol::UnsupportedIpPayload( + self.next_header(), + )) + } + pub fn set_source_protocol(&mut self, v: u16) { - if let Some(mut p) = self.as_tcp() { + if let Some(mut p) = self.as_tcp_mut() { p.set_source(v); } - if let Some(mut p) = self.as_udp() { + if let Some(mut p) = self.as_udp_mut() { p.set_source(v); } @@ -474,11 +494,11 @@ impl<'a> MutableIpPacket<'a> { } pub fn set_destination_protocol(&mut self, v: u16) { - if let Some(mut p) = self.as_tcp() { + if let Some(mut p) = self.as_tcp_mut() { p.set_destination(v); } - if let Some(mut p) = self.as_udp() { + if let Some(mut p) = self.as_udp_mut() { p.set_destination(v); } @@ -486,7 +506,7 @@ impl<'a> MutableIpPacket<'a> { } fn set_icmp_identifier(&mut self, v: u16) { - if let Some(mut p) = self.as_icmp() { + if let Some(mut p) = self.as_icmp_mut() { if p.get_icmp_type() == IcmpTypes::EchoReply { let Some(mut echo_reply) = MutableEchoReplyPacket::new(p.packet_mut()) else { return; @@ -536,68 +556,89 @@ impl<'a> MutableIpPacket<'a> { } fn set_ipv4_checksum(&mut self) { - if let Self::Ipv4(p) = self { - let checksum = ipv4::checksum(&p.to_immutable()); - p.ip_header_mut().set_checksum(checksum); - } - } - - fn set_udp_checksum(&mut self) { - let checksum = if let Some(p) = self.as_immutable_udp() { - self.to_immutable().udp_checksum(&p.to_immutable()) - } else { + let Self::Ipv4(p) = self else { return; }; - self.as_udp() + let checksum = p.ip_header().to_header().calc_header_checksum(); + p.ip_header_mut().set_checksum(checksum); + } + + fn set_udp_checksum(&mut self) { + let Some(udp) = self.as_udp() else { + return; + }; + + let checksum = match &self { + IpPacket::Ipv4(v4) => udp + .to_header() + .calc_checksum_ipv4(&v4.ip_header().to_header(), udp.payload()), + IpPacket::Ipv6(v6) => udp + .to_header() + .calc_checksum_ipv6(&v6.header().to_header(), udp.payload()), + } + .expect("size of payload was previously checked to be okay"); + + self.as_udp_mut() .expect("Developer error: we can only get a UDP checksum if the packet is udp") .set_checksum(checksum); } fn set_tcp_checksum(&mut self) { - let checksum = if let Some(p) = self.as_immutable_tcp() { - self.to_immutable().tcp_checksum(&p.to_immutable()) - } else { + let Some(tcp) = self.as_tcp() else { return; }; - self.as_tcp() - .expect("Developer error: we can only get a TCP checksum if the packet is tcp") + let checksum = match &self { + IpPacket::Ipv4(v4) => tcp + .to_header() + .calc_checksum_ipv4(&v4.ip_header().to_header(), tcp.payload()), + IpPacket::Ipv6(v6) => tcp + .to_header() + .calc_checksum_ipv6(&v6.header().to_header(), tcp.payload()), + } + .expect("size of payload was previously checked to be okay"); + + self.as_tcp_mut() + .expect("Developer error: we can only get a UDP checksum if the packet is udp") .set_checksum(checksum); } - pub fn into_immutable(self) -> IpPacket<'a> { - match self { - Self::Ipv4(p) => p.consume_to_immutable().into(), - Self::Ipv6(p) => p.consume_to_immutable().into(), - } + pub fn as_udp(&self) -> Option { + self.is_udp() + .then(|| UdpSlice::from_slice(self.payload()).ok()) + .flatten() } - pub fn as_immutable(&self) -> IpPacket<'_> { - match self { - Self::Ipv4(p) => IpPacket::Ipv4(p.to_immutable()), - Self::Ipv6(p) => IpPacket::Ipv6(p.to_immutable()), - } - } - - pub fn as_udp(&mut self) -> Option { - self.to_immutable() - .is_udp() + pub fn as_udp_mut(&mut self) -> Option { + self.is_udp() .then(|| MutableUdpPacket::new(self.payload_mut())) .flatten() } - fn as_tcp(&mut self) -> Option { - self.to_immutable() - .is_tcp() + pub fn as_tcp(&self) -> Option { + self.is_tcp() + .then(|| TcpSlice::from_slice(self.payload()).ok()) + .flatten() + } + + pub fn as_tcp_mut(&mut self) -> Option { + self.is_tcp() .then(|| MutableTcpPacket::new(self.payload_mut())) .flatten() } + pub fn is_icmp_v4_or_v6(&self) -> bool { + match self { + IpPacket::Ipv4(v4) => v4.ip_header().protocol() == IpNumber::ICMP, + IpPacket::Ipv6(v6) => v6.header().next_header() == IpNumber::IPV6_ICMP, + } + } + fn set_icmpv6_checksum(&mut self) { let (src_addr, dst_addr) = match self { - MutableIpPacket::Ipv4(_) => return, - MutableIpPacket::Ipv6(p) => (p.get_source(), p.get_destination()), + IpPacket::Ipv4(_) => return, + IpPacket::Ipv6(p) => (p.get_source(), p.get_destination()), }; if let Some(mut pkt) = self.as_icmpv6() { let checksum = icmpv6::checksum(&pkt.to_immutable(), &src_addr, &dst_addr); @@ -606,50 +647,46 @@ impl<'a> MutableIpPacket<'a> { } fn set_icmpv4_checksum(&mut self) { - if let Some(mut pkt) = self.as_icmp() { + if let Some(mut pkt) = self.as_icmp_mut() { let checksum = icmp::checksum(&pkt.to_immutable()); pkt.set_checksum(checksum); } } - fn as_icmp(&mut self) -> Option { - self.to_immutable() - .is_icmp() + pub fn as_icmp(&self) -> Option { + match self { + Self::Ipv4(v4) if self.is_icmp() => Some(IcmpPacket::Ipv4( + Icmpv4Slice::from_slice(v4.payload()).ok()?, + )), + Self::Ipv6(v6) if self.is_icmpv6() => Some(IcmpPacket::Ipv6( + Icmpv6Slice::from_slice(v6.payload()).ok()?, + )), + Self::Ipv4(_) | Self::Ipv6(_) => None, + } + } + + pub fn as_icmp_mut(&mut self) -> Option { + self.is_icmp() .then(|| MutableIcmpPacket::new(self.payload_mut())) .flatten() } fn as_icmpv6(&mut self) -> Option { - self.to_immutable() - .is_icmpv6() + self.is_icmpv6() .then(|| MutableIcmpv6Packet::new(self.payload_mut())) .flatten() } - fn as_immutable_udp(&self) -> Option { - self.to_immutable() - .is_udp() - .then(|| UdpPacket::new(self.payload())) - .flatten() - } - - fn as_immutable_tcp(&self) -> Option { - self.to_immutable() - .is_tcp() - .then(|| TcpPacket::new(self.payload())) - .flatten() - } - pub fn translate_destination( mut self, src_v4: Ipv4Addr, src_v6: Ipv6Addr, src_proto: Protocol, dst: IpAddr, - ) -> Option> { + ) -> Option> { let mut packet = match (&self, dst) { - (&MutableIpPacket::Ipv4(_), IpAddr::V6(dst)) => self.consume_to_ipv6(src_v6, dst)?, - (&MutableIpPacket::Ipv6(_), IpAddr::V4(dst)) => self.consume_to_ipv4(src_v4, dst)?, + (&IpPacket::Ipv4(_), IpAddr::V6(dst)) => self.consume_to_ipv6(src_v6, dst)?, + (&IpPacket::Ipv6(_), IpAddr::V4(dst)) => self.consume_to_ipv4(src_v4, dst)?, _ => { self.set_dst(dst); self @@ -666,10 +703,10 @@ impl<'a> MutableIpPacket<'a> { dst_v6: Ipv6Addr, dst_proto: Protocol, src: IpAddr, - ) -> Option> { + ) -> Option> { let mut packet = match (&self, src) { - (&MutableIpPacket::Ipv4(_), IpAddr::V6(src)) => self.consume_to_ipv6(src, dst_v6)?, - (&MutableIpPacket::Ipv6(_), IpAddr::V4(src)) => self.consume_to_ipv4(src, dst_v4)?, + (&IpPacket::Ipv4(_), IpAddr::V6(src)) => self.consume_to_ipv6(src, dst_v6)?, + (&IpPacket::Ipv6(_), IpAddr::V4(src)) => self.consume_to_ipv4(src, dst_v4)?, _ => { self.set_src(src); self @@ -715,43 +752,22 @@ impl<'a> MutableIpPacket<'a> { } } } -} - -impl<'a> IpPacket<'a> { - pub fn new(buf: &'a [u8]) -> Option { - match buf[0] >> 4 { - 4 => Some(IpPacket::Ipv4(Ipv4Packet::new(buf)?)), - 6 => Some(IpPacket::Ipv6(Ipv6Packet::new(buf)?)), - _ => None, - } - } - - pub fn to_owned(&self) -> IpPacket<'static> { - match self { - IpPacket::Ipv4(i) => Ipv4Packet::owned(i.packet().to_vec()) - .expect("owned packet should still be valid") - .into(), - IpPacket::Ipv6(i) => Ipv6Packet::owned(i.packet().to_vec()) - .expect("owned packet should still be valid") - .into(), - } - } pub fn ipv4_header(&self) -> Option { match self { - IpPacket::Ipv4(p) => Some( + Self::Ipv4(p) => Some( Ipv4HeaderSlice::from_slice(p.packet()) .expect("Should be a valid packet") .to_header(), ), - IpPacket::Ipv6(_) => None, + Self::Ipv6(_) => None, } } pub fn ipv6_header(&self) -> Option { match self { - IpPacket::Ipv4(_) => None, - IpPacket::Ipv6(p) => Some( + Self::Ipv4(_) => None, + Self::Ipv6(p) => Some( Ipv6HeaderSlice::from_slice(p.packet()) .expect("Should be a valid packet") .to_header(), @@ -759,255 +775,42 @@ impl<'a> IpPacket<'a> { } } - pub fn source_protocol(&self) -> Result { - if let Some(p) = self.as_tcp() { - return Ok(Protocol::Tcp(p.get_source())); - } - - if let Some(p) = self.as_udp() { - return Ok(Protocol::Udp(p.get_source())); - } - - if let Some(p) = self.as_icmp() { - let id = p.identifier().ok_or_else(|| match p.icmp_type() { - IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4.0), - IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6.0), - })?; - - return Ok(Protocol::Icmp(id)); - } - - Err(UnsupportedProtocol::UnsupportedIpPayload( - self.next_header(), - )) - } - - pub fn destination_protocol(&self) -> Result { - if let Some(p) = self.as_tcp() { - return Ok(Protocol::Tcp(p.get_destination())); - } - - if let Some(p) = self.as_udp() { - return Ok(Protocol::Udp(p.get_destination())); - } - - if let Some(p) = self.as_icmp() { - let id = p.identifier().ok_or_else(|| match p.icmp_type() { - IcmpType::V4(v4) => UnsupportedProtocol::UnsupportedIcmpv4Type(v4.0), - IcmpType::V6(v6) => UnsupportedProtocol::UnsupportedIcmpv6Type(v6.0), - })?; - - return Ok(Protocol::Icmp(id)); - } - - Err(UnsupportedProtocol::UnsupportedIpPayload( - self.next_header(), - )) - } - - pub fn source(&self) -> IpAddr { - for_both!(self, |i| i.get_source().into()) - } - - pub fn destination(&self) -> IpAddr { - for_both!(self, |i| i.get_destination().into()) - } - - pub fn next_header(&self) -> IpNextHeaderProtocol { + fn next_header(&self) -> IpNumber { match self { - Self::Ipv4(p) => p.get_next_level_protocol(), - Self::Ipv6(p) => p.get_next_header(), + Self::Ipv4(p) => p.ip_header().protocol(), + Self::Ipv6(p) => p.header().next_header(), } } fn is_udp(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Udp + self.next_header() == IpNumber::UDP } fn is_tcp(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Tcp + self.next_header() == IpNumber::TCP } fn is_icmp(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Icmp + self.next_header() == IpNumber::ICMP } fn is_icmpv6(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Icmpv6 - } - - pub fn as_udp(&self) -> Option { - self.is_udp() - .then(|| UdpPacket::new(self.payload())) - .flatten() - } - - /// Unwrap this [`IpPacket`] as a [`UdpPacket`], panicking in case it is not. - pub fn unwrap_as_udp(&self) -> UdpPacket { - self.as_udp().expect("Packet is not a UDP packet") - } - - /// Unwrap this [`IpPacket`] as a DNS message, panicking in case it is not. - pub fn unwrap_as_dns(&self) -> Message> { - let udp = self.unwrap_as_udp(); - let message = match Message::from_octets(udp.payload().to_vec()) { - Ok(message) => message, - Err(e) => { - panic!("Failed to parse UDP payload as DNS message: {e}"); - } - }; - - message - } - - pub fn as_tcp(&self) -> Option { - self.is_tcp() - .then(|| TcpPacket::new(self.payload())) - .flatten() - } - - pub fn as_icmp(&self) -> Option { - match self { - IpPacket::Ipv4(v4) if v4.get_next_level_protocol() == IpNextHeaderProtocols::Icmp => { - Some(IcmpPacket::Ipv4(pnet_packet::icmp::IcmpPacket::new( - v4.payload(), - )?)) - } - IpPacket::Ipv6(v6) if v6.get_next_header() == IpNextHeaderProtocols::Icmpv6 => { - Some(IcmpPacket::Ipv6(icmpv6::Icmpv6Packet::new(v6.payload())?)) - } - IpPacket::Ipv4(_) | IpPacket::Ipv6(_) => None, - } - } - - fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 { - match self { - Self::Ipv4(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()), - Self::Ipv6(p) => udp::ipv6_checksum(dgm, &p.get_source(), &p.get_destination()), - } - } - - fn tcp_checksum(&self, pkt: &TcpPacket<'_>) -> u16 { - match self { - Self::Ipv4(p) => tcp::ipv4_checksum(pkt, &p.get_source(), &p.get_destination()), - Self::Ipv6(p) => tcp::ipv6_checksum(pkt, &p.get_source(), &p.get_destination()), - } + self.next_header() == IpNumber::IPV6_ICMP } } -impl<'a> IcmpPacket<'a> { - pub fn as_echo_request(&self) -> Option { - match self { - IcmpPacket::Ipv4(v4) if matches!(v4.get_icmp_type(), icmp::IcmpTypes::EchoRequest) => { - Some(IcmpEchoRequest::Ipv4( - icmp::echo_request::EchoRequestPacket::new(v4.packet())?, - )) - } - IcmpPacket::Ipv6(v6) - if matches!(v6.get_icmpv6_type(), icmpv6::Icmpv6Types::EchoRequest) => - { - Some(IcmpEchoRequest::Ipv6( - icmpv6::echo_request::EchoRequestPacket::new(v6.packet())?, - )) - } - IcmpPacket::Ipv4(_) | IcmpPacket::Ipv6(_) => None, - } - } - - pub fn as_echo_reply(&self) -> Option { - match self { - IcmpPacket::Ipv4(v4) if matches!(v4.get_icmp_type(), icmp::IcmpTypes::EchoReply) => { - Some(IcmpEchoReply::Ipv4(icmp::echo_reply::EchoReplyPacket::new( - v4.packet(), - )?)) - } - IcmpPacket::Ipv6(v6) - if matches!(v6.get_icmpv6_type(), icmpv6::Icmpv6Types::EchoReply) => - { - Some(IcmpEchoReply::Ipv6( - icmpv6::echo_reply::EchoReplyPacket::new(v6.packet())?, - )) - } - IcmpPacket::Ipv4(_) | IcmpPacket::Ipv6(_) => None, - } - } - - pub fn is_echo_reply(&self) -> bool { - self.as_echo_reply().is_some() - } - - pub fn is_echo_request(&self) -> bool { - self.as_echo_request().is_some() - } -} - -impl<'a> IcmpEchoRequest<'a> { - pub fn sequence(&self) -> u16 { - for_both!(self, |i| i.get_sequence_number()) - } - - pub fn identifier(&self) -> u16 { - for_both!(self, |i| i.get_identifier()) - } - - pub fn payload(&self) -> &[u8] { - for_both!(self, |i| i.payload()) - } -} - -impl<'a> IcmpEchoReply<'a> { - pub fn sequence(&self) -> u16 { - for_both!(self, |i| i.get_sequence_number()) - } - - pub fn identifier(&self) -> u16 { - for_both!(self, |i| i.get_identifier()) - } -} - -impl Clone for IpPacket<'static> { - fn clone(&self) -> Self { - match self { - Self::Ipv4(ip4) => Self::Ipv4(Ipv4Packet::owned(ip4.packet().to_vec()).unwrap()), - Self::Ipv6(ip6) => Self::Ipv6(Ipv6Packet::owned(ip6.packet().to_vec()).unwrap()), - } - } -} - -impl<'a> From> for IpPacket<'a> { - fn from(value: Ipv4Packet<'a>) -> Self { - Self::Ipv4(value) - } -} - -impl<'a> From> for IpPacket<'a> { - fn from(value: Ipv6Packet<'a>) -> Self { - Self::Ipv6(value) - } -} - -impl<'a> From> for MutableIpPacket<'a> { +impl<'a> From> for IpPacket<'a> { fn from(value: ConvertibleIpv4Packet<'a>) -> Self { Self::Ipv4(value) } } -impl<'a> From> for MutableIpPacket<'a> { +impl<'a> From> for IpPacket<'a> { fn from(value: ConvertibleIpv6Packet<'a>) -> Self { Self::Ipv6(value) } } -impl pnet_packet::Packet for MutableIpPacket<'_> { - fn packet(&self) -> &[u8] { - for_both!(self, |i| i.packet()) - } - - fn payload(&self) -> &[u8] { - for_both!(self, |i| i.payload()) - } -} - impl pnet_packet::Packet for IpPacket<'_> { fn packet(&self) -> &[u8] { for_both!(self, |i| i.packet()) @@ -1018,7 +821,7 @@ impl pnet_packet::Packet for IpPacket<'_> { } } -impl pnet_packet::MutablePacket for MutableIpPacket<'_> { +impl pnet_packet::MutablePacket for IpPacket<'_> { fn packet_mut(&mut self) -> &mut [u8] { for_both!(self, |i| i.packet_mut()) } @@ -1028,21 +831,12 @@ impl pnet_packet::MutablePacket for MutableIpPacket<'_> { } } -impl<'a> PacketSize for IpPacket<'a> { - fn packet_size(&self) -> usize { - match self { - Self::Ipv4(p) => p.packet_size(), - Self::Ipv6(p) => p.packet_size(), - } - } -} - #[derive(Debug, thiserror::Error)] pub enum UnsupportedProtocol { - #[error("Unsupported IP protocol: {0}")] - UnsupportedIpPayload(IpNextHeaderProtocol), - #[error("Unsupported ICMPv4 type: {0}")] - UnsupportedIcmpv4Type(u8), - #[error("Unsupported ICMPv6 type: {0}")] - UnsupportedIcmpv6Type(u8), + #[error("Unsupported IP protocol: {0:?}")] + UnsupportedIpPayload(IpNumber), + #[error("Unsupported ICMPv4 type: {0:?}")] + UnsupportedIcmpv4Type(Icmpv4Type), + #[error("Unsupported ICMPv6 type: {0:?}")] + UnsupportedIcmpv6Type(Icmpv6Type), } diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index ddeb8746c..187bcf234 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -1,17 +1,17 @@ //! Factory module for making all kinds of packets. -use crate::{IpPacket, MutableIpPacket}; +use crate::IpPacket; use domain::{ base::{ iana::{Class, Opcode, Rcode}, - MessageBuilder, Name, Question, Record, Rtype, ToName, Ttl, + Message, MessageBuilder, Name, Question, Record, Rtype, ToName, Ttl, }, rdata::AllRecordData, }; use etherparse::PacketBuilder; use std::net::{IpAddr, SocketAddr}; -/// Helper macro to turn a [`PacketBuilder`] into a [`MutableIpPacket`]. +/// Helper macro to turn a [`PacketBuilder`] into an [`IpPacket`]. #[macro_export] macro_rules! build { ($packet:expr, $payload:ident) => {{ @@ -22,7 +22,7 @@ macro_rules! build { .write(&mut std::io::Cursor::new(&mut buf[20..]), &$payload) .expect("Buffer should be big enough"); - MutableIpPacket::owned(buf).expect("Should be a valid IP packet") + IpPacket::owned(buf).expect("Should be a valid IP packet") }}; } @@ -32,7 +32,7 @@ pub fn icmp_request_packet( seq: u16, identifier: u16, payload: &[u8], -) -> Result, IpVersionMismatch> { +) -> Result, IpVersionMismatch> { match (src, dst.into()) { (IpAddr::V4(src), IpAddr::V4(dst)) => { let packet = PacketBuilder::ipv4(src.octets(), dst.octets(), 64) @@ -56,7 +56,7 @@ pub fn icmp_reply_packet( seq: u16, identifier: u16, payload: &[u8], -) -> Result, IpVersionMismatch> { +) -> Result, IpVersionMismatch> { match (src, dst.into()) { (IpAddr::V4(src), IpAddr::V4(dst)) => { let packet = PacketBuilder::ipv4(src.octets(), dst.octets(), 64) @@ -80,7 +80,7 @@ pub fn tcp_packet( sport: u16, dport: u16, payload: Vec, -) -> Result, IpVersionMismatch> +) -> Result, IpVersionMismatch> where IP: Into, { @@ -107,7 +107,7 @@ pub fn udp_packet( sport: u16, dport: u16, payload: Vec, -) -> Result, IpVersionMismatch> +) -> Result, IpVersionMismatch> where IP: Into, { @@ -132,7 +132,7 @@ pub fn dns_query( src: SocketAddr, dst: SocketAddr, id: u16, -) -> Result, IpVersionMismatch> { +) -> Result, IpVersionMismatch> { // Create the DNS query message let mut msg_builder = MessageBuilder::new_vec(); @@ -155,12 +155,12 @@ pub fn dns_query( pub fn dns_ok_response( packet: IpPacket<'static>, resolve: impl Fn(&Name>) -> I, -) -> MutableIpPacket<'static> +) -> IpPacket<'static> where I: Iterator, { - let udp = packet.unwrap_as_udp(); - let query = packet.unwrap_as_dns(); + let udp = packet.as_udp().unwrap(); + let query = Message::from_octets(udp.payload().to_vec()).unwrap(); let response = MessageBuilder::new_vec(); let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); @@ -194,8 +194,8 @@ where udp_packet( packet.destination(), packet.source(), - udp.get_destination(), - udp.get_source(), + udp.destination_port(), + udp.source_port(), payload, ) .expect("src and dst are retrieved from the same packet") diff --git a/rust/ip-packet/src/proptest.rs b/rust/ip-packet/src/proptest.rs index 6d8e6f2a1..c68d6b564 100644 --- a/rust/ip-packet/src/proptest.rs +++ b/rust/ip-packet/src/proptest.rs @@ -1,8 +1,8 @@ -use crate::MutableIpPacket; +use crate::IpPacket; use proptest::{arbitrary::any, prop_oneof, strategy::Strategy}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -pub fn udp_packet() -> impl Strategy> { +pub fn udp_packet() -> impl Strategy> { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { crate::make::udp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap() @@ -13,7 +13,7 @@ pub fn udp_packet() -> impl Strategy> { ] } -pub fn tcp_packet() -> impl Strategy> { +pub fn tcp_packet() -> impl Strategy> { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap() @@ -24,7 +24,7 @@ pub fn tcp_packet() -> impl Strategy> { ] } -pub fn icmp_request_packet() -> impl Strategy> { +pub fn icmp_request_packet() -> impl Strategy> { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { crate::make::icmp_request_packet(IpAddr::V4(saddr), daddr, sport, dport, &[]).unwrap() @@ -35,7 +35,7 @@ pub fn icmp_request_packet() -> impl Strategy> ] } -pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy> { +pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy> { prop_oneof![udp_packet(), tcp_packet(), icmp_request_packet()] } diff --git a/rust/ip-packet/src/proptests.rs b/rust/ip-packet/src/proptests.rs index 0c5554a9d..d7eaf49c0 100644 --- a/rust/ip-packet/src/proptests.rs +++ b/rust/ip-packet/src/proptests.rs @@ -5,13 +5,13 @@ use proptest::arbitrary::any; use proptest::prop_oneof; use proptest::strategy::Strategy; -use crate::{build, MutableIpPacket}; +use crate::{build, IpPacket}; use etherparse::{Ipv4Extensions, Ipv4Header, Ipv4Options, PacketBuilder}; use proptest::prelude::Just; const EMPTY_PAYLOAD: &[u8] = &[]; -fn tcp_packet_v4() -> impl Strategy> { +fn tcp_packet_v4() -> impl Strategy> { ( any::(), any::(), @@ -27,7 +27,7 @@ fn tcp_packet_v4() -> impl Strategy> { }) } -fn tcp_packet_v6() -> impl Strategy> { +fn tcp_packet_v6() -> impl Strategy> { ( any::(), any::(), @@ -43,7 +43,7 @@ fn tcp_packet_v6() -> impl Strategy> { }) } -fn udp_packet_v4() -> impl Strategy> { +fn udp_packet_v4() -> impl Strategy> { ( any::(), any::(), @@ -59,7 +59,7 @@ fn udp_packet_v4() -> impl Strategy> { }) } -fn udp_packet_v6() -> impl Strategy> { +fn udp_packet_v6() -> impl Strategy> { ( any::(), any::(), @@ -75,7 +75,7 @@ fn udp_packet_v6() -> impl Strategy> { }) } -fn icmp_request_packet_v4() -> impl Strategy> { +fn icmp_request_packet_v4() -> impl Strategy> { ( any::(), any::(), @@ -99,7 +99,7 @@ fn icmp_request_packet_v4() -> impl Strategy> { }) } -fn icmp_reply_packet_v4() -> impl Strategy> { +fn icmp_reply_packet_v4() -> impl Strategy> { ( any::(), any::(), @@ -123,7 +123,7 @@ fn icmp_reply_packet_v4() -> impl Strategy> { }) } -fn icmp_request_packet_v6() -> impl Strategy> { +fn icmp_request_packet_v6() -> impl Strategy> { ( any::(), any::(), @@ -138,7 +138,7 @@ fn icmp_request_packet_v6() -> impl Strategy> { }) } -fn icmp_reply_packet_v6() -> impl Strategy> { +fn icmp_reply_packet_v6() -> impl Strategy> { ( any::(), any::(), @@ -169,7 +169,7 @@ fn ipv4_options() -> impl Strategy { ] } -fn packet_v4() -> impl Strategy> { +fn packet_v4() -> impl Strategy> { prop_oneof![ tcp_packet_v4(), udp_packet_v4(), @@ -178,7 +178,7 @@ fn packet_v4() -> impl Strategy> { ] } -fn packet_v6() -> impl Strategy> { +fn packet_v6() -> impl Strategy> { prop_oneof![ tcp_packet_v6(), udp_packet_v6(), @@ -189,11 +189,11 @@ fn packet_v6() -> impl Strategy> { #[test_strategy::proptest()] fn nat_6446( - #[strategy(packet_v6())] packet_v6: MutableIpPacket<'static>, + #[strategy(packet_v6())] packet_v6: IpPacket<'static>, #[strategy(any::())] new_src: Ipv4Addr, #[strategy(any::())] new_dst: Ipv4Addr, ) { - let header = packet_v6.as_immutable().ipv6_header().unwrap(); + let header = packet_v6.ipv6_header().unwrap(); let payload = packet_v6.payload().to_vec(); let packet_v4 = packet_v6.consume_to_ipv4(new_src, new_dst).unwrap(); @@ -206,17 +206,17 @@ fn nat_6446( .unwrap(); new_packet_v6.update_checksum(); - assert_eq!(new_packet_v6.as_immutable().ipv6_header().unwrap(), header); + assert_eq!(new_packet_v6.ipv6_header().unwrap(), header); assert_eq!(new_packet_v6.payload(), payload); } #[test_strategy::proptest()] fn nat_4664( - #[strategy(packet_v4())] packet_v4: MutableIpPacket<'static>, + #[strategy(packet_v4())] packet_v4: IpPacket<'static>, #[strategy(any::())] new_src: Ipv6Addr, #[strategy(any::())] new_dst: Ipv6Addr, ) { - let header = packet_v4.as_immutable().ipv4_header().unwrap(); + let header = packet_v4.ipv4_header().unwrap(); let payload = packet_v4.payload().to_vec(); let packet_v6 = packet_v4.consume_to_ipv6(new_src, new_dst).unwrap(); @@ -236,9 +236,6 @@ fn nat_4664( }; header_without_options.header_checksum = header_without_options.calc_header_checksum(); - assert_eq!( - new_packet_v4.as_immutable().ipv4_header().unwrap(), - header_without_options - ); + assert_eq!(new_packet_v4.ipv4_header().unwrap(), header_without_options); assert_eq!(new_packet_v4.payload(), payload); }