chore(connlib): implement IP translation according to RFC6145 (#5364)

As part of #4994, we need to translate IP packets between IPv4 and IPv6.
This PR introduces the `ConvertiblePacket` abstraction that implements
this.
This commit is contained in:
Gabi
2024-06-14 18:33:07 -03:00
committed by GitHub
parent 23bcf877a8
commit 8cc28499e9
11 changed files with 1293 additions and 112 deletions

3
rust/Cargo.lock generated
View File

@@ -3170,6 +3170,9 @@ version = "0.1.0"
dependencies = [
"hickory-proto",
"pnet_packet",
"proptest",
"test-strategy",
"thiserror",
]
[[package]]

View File

@@ -9,9 +9,9 @@ use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::PublicKey;
use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret};
use core::fmt;
use ip_packet::ipv4::MutableIpv4Packet;
use ip_packet::ipv6::MutableIpv6Packet;
use ip_packet::{IpPacket, MutableIpPacket, Packet as _};
use ip_packet::{
ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, MutableIpPacket, Packet as _,
};
use rand::random;
use secrecy::{ExposeSecret, Secret};
use std::borrow::Cow;
@@ -1643,7 +1643,7 @@ where
transmits: &mut VecDeque<Transmit<'static>>,
now: Instant,
) -> ControlFlow<Result<(), Error>, MutableIpPacket<'b>> {
match self.tunnel.decapsulate(None, packet, buffer) {
match self.tunnel.decapsulate(None, packet, &mut buffer[20..]) {
TunnResult::Done => ControlFlow::Break(Ok(())),
TunnResult::Err(e) => ControlFlow::Break(Err(Error::Decapsulate(e))),
@@ -1652,15 +1652,20 @@ where
// In our API, we parse the packets directly as an IpPacket.
// Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition.
TunnResult::WriteToTunnelV4(packet, ip) => {
let ipv4_packet =
MutableIpv4Packet::new(packet).expect("boringtun verifies validity");
let packet_len = packet.len();
let ipv4_packet = ConvertibleIpv4Packet::new(&mut buffer[..(packet_len + 20)])
.expect("boringtun verifies validity");
debug_assert_eq!(ipv4_packet.get_source(), ip);
ControlFlow::Continue(ipv4_packet.into())
}
TunnResult::WriteToTunnelV6(packet, ip) => {
let ipv6_packet =
MutableIpv6Packet::new(packet).expect("boringtun verifies validity");
// For ipv4 we need to use buffer to create the ip packet because we need the extra 20 bytes at the beginning
// for ipv6 we just need this to convince the borrow-checker that `packet`'s lifetime isn't `'b`, otherwise it's taken
// as `'b` for all branches.
let packet_len = packet.len();
let ipv6_packet = ConvertibleIpv6Packet::new(&mut buffer[20..(packet_len + 20)])
.expect("boringtun verifies validity");
debug_assert_eq!(ipv6_packet.get_source(), ip);
ControlFlow::Continue(ipv6_packet.into())

View File

@@ -133,7 +133,7 @@ impl Device {
return Poll::Pending;
};
let n = std::task::ready!(tun.poll_read(buf, cx))?;
let n = std::task::ready!(tun.poll_read(&mut buf[20..], cx))?;
if n == 0 {
return Poll::Ready(Err(io::Error::new(
@@ -142,7 +142,7 @@ impl Device {
)));
}
let packet = MutableIpPacket::new(&mut buf[..n]).ok_or_else(|| {
let packet = MutableIpPacket::new(&mut buf[..(n + 20)]).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"received bytes are not an IP packet",
@@ -167,7 +167,7 @@ impl Device {
return Poll::Pending;
};
let n = std::task::ready!(tun.poll_read(buf, cx))?;
let n = std::task::ready!(tun.poll_read(&mut buf[20..], cx))?;
if n == 0 {
return Poll::Ready(Err(io::Error::new(
@@ -176,7 +176,7 @@ impl Device {
)));
}
let packet = MutableIpPacket::new(&mut buf[..n]).ok_or_else(|| {
let packet = MutableIpPacket::new(&mut buf[..(n + 20)]).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"received bytes are not an IP packet",

View File

@@ -1,6 +1,7 @@
use crate::MTU;
use connlib_shared::{
windows::{CREATE_NO_WINDOW, TUNNEL_NAME},
Callbacks, Result, DEFAULT_MTU,
Callbacks, Result,
};
use ip_network::IpNetwork;
use std::{
@@ -83,7 +84,7 @@ impl Tun {
.stdout(Stdio::null())
.status()?;
set_iface_config(adapter.get_luid(), DEFAULT_MTU)?;
set_iface_config(adapter.get_luid(), MTU as u32)?;
let session = Arc::new(adapter.start_session(wintun::MAX_RING_CAPACITY)?);
let (packet_tx, packet_rx) = mpsc::channel(5);

View File

@@ -220,8 +220,11 @@ fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec<u8>) -> IpPack
let response_len = dns_answer.len();
let original_dgm = original_pkt.unwrap_as_udp();
let hdr_len = original_pkt.packet_size() - original_dgm.payload().len();
let mut res_buf = Vec::with_capacity(hdr_len + response_len);
let mut res_buf = Vec::with_capacity(hdr_len + response_len + 20);
// TODO: this is some weirdness due to how MutableIpPacket is implemented
// we need an extra 20 bytes padding.
res_buf.extend_from_slice(&[0; 20]);
res_buf.extend_from_slice(&original_pkt.packet()[..hdr_len]);
res_buf.append(&mut dns_answer);
@@ -245,6 +248,8 @@ fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec<u8>) -> IpPack
pkt.unwrap_as_udp().set_checksum(udp_checksum);
pkt.set_ipv4_checksum();
// TODO: more of this weirdness
res_buf.drain(0..20);
IpPacket::owned(res_buf).unwrap()
}

View File

@@ -38,7 +38,6 @@ mod utils;
mod tests;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const MTU: usize = 1280;
const REALM: &str = "firezone";
@@ -61,15 +60,13 @@ pub struct Tunnel<CB: Callbacks, TRoleState> {
/// Handles all side-effects.
io: Io,
// TODO: could we make these buffers smaller? Since all the valid packets will be at most
// MTU + Wireguard Header + optionally Data Channel + UDP header + IPV4/IPV6 header (1280 + 32 + 4 + 8 + 40 = 1364)
// or STUN control messages which afaik are smaller than that
ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>,
ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>,
// We need an extra 16 bytes on top of the MTU for write_buf since boringtun copies the extra AEAD tag before decrypting it
write_buf: Box<[u8; MTU + 16]>,
device_read_buf: Box<[u8; MTU]>,
write_buf: Box<[u8; MTU + 16 + 20]>,
// We have 20 extra bytes to be able to convert between ipv4 and ipv6
device_read_buf: Box<[u8; MTU + 20]>,
}
impl<CB> ClientTunnel<CB>
@@ -85,10 +82,10 @@ where
io: Io::new(sockets)?,
callbacks,
role_state: ClientState::new(private_key),
write_buf: Box::new([0u8; MTU + 16]),
write_buf: Box::new([0u8; MTU + 16 + 20]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; MTU]),
device_read_buf: Box::new([0u8; MTU + 20]),
})
}
@@ -187,10 +184,10 @@ where
io: Io::new(sockets)?,
callbacks,
role_state: GatewayState::new(private_key),
write_buf: Box::new([0u8; MTU + 16]),
write_buf: Box::new([0u8; MTU + 20 + 16]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; MTU]),
device_read_buf: Box::new([0u8; MTU + 20]),
})
}

View File

@@ -7,9 +7,17 @@ publish = false
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
proptest = ["dep:proptest"]
[dependencies]
pnet_packet = { version = "0.34" }
hickory-proto = { workspace = true }
proptest = { version = "1.4.0", optional = true }
thiserror = "1"
[dev-dependencies]
test-strategy = "0.3.1"
[lints]
workspace = true

View File

@@ -0,0 +1,11 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc 28493c42d9a80886807dbafa94ceb96ca9ac11e25c5f5e907a507468039b19e4 # shrinks to input = _CanTranslateDstPacketArgs { packet: Ipv6(ConvertibleIpv6Packet { buf: Owned([96, 0, 0, 0, 0, 8, 17, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 255, 222]) }), src_v4: 0.0.0.0, src_v6: ::ffff:127.0.0.1, dst: ::ffff:0.0.0.0 }
cc d0df47592e3f20988340dd52e6e2d5b1da3398085aa70423cbd302567d2a1162 # shrinks to input = _CanTranslateDstPacketArgs { packet: Ipv4(ConvertibleIpv4Packet { buf: Owned([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 69, 0, 0, 40, 0, 0, 0, 0, 64, 6, 122, 209, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 0, 0, 128, 175, 101, 0, 0]) }), src_v4: 0.0.0.0, src_v6: ::ffff:0.0.0.0, dst: ::ffff:0.0.0.0 }
cc 63b7eaae54d351ff5dbda1673d09efc202850e2b356a5f28df8aa59bdb075581 # shrinks to input = _CanTranslateSrcPacketArgs { packet: Ipv4(ConvertibleIpv4Packet { buf: Owned([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 69, 0, 0, 28, 0, 0, 0, 0, 64, 17, 122, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 255, 222]) }), dst_v4: 0.0.0.0, dst_v6: c:5aa3:3b04:ca41:bbf6:d619:2c8b:9da7, src: ::ffff:254.113.152.73 }
cc 4bbd1ef685070370226af2c039fba677cc79d9a1ca41f8573805de2d6a5cd27c # shrinks to input = _CanTranslateDstPacketArgs { packet: Ipv4(ConvertibleIpv4Packet { buf: Owned([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 69, 0, 0, 60, 0, 0, 192, 0, 64, 1, 186, 193, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 247, 255, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) }), src_v4: 0.0.0.0, src_v6: ::8:674, dst: ::ffff:127.0.0.1 }
cc 308072a334e0b7555485fe3181f8a817dd31d30dec1823eb6992f6265a56f078 # shrinks to input = _CanTranslateSrcPacketArgs { packet: Ipv4(ConvertibleIpv4Packet { buf: Owned([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 69, 0, 0, 60, 0, 0, 192, 0, 64, 1, 186, 193, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 247, 255, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) }), dst_v4: 0.0.0.0, dst_v6: ::ffff:0.0.0.0, src: 0.0.0.0 }

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ use hickory_proto::{
};
use pnet_packet::{
ip::IpNextHeaderProtocol,
ipv4::MutableIpv4Packet,
ipv4::{Ipv4Flags, MutableIpv4Packet},
ipv6::MutableIpv6Packet,
tcp::{self, MutableTcpPacket},
udp::{self, MutableUdpPacket},
@@ -23,6 +23,15 @@ pub fn icmp_request_packet(
icmp_packet(src, dst.into(), seq, identifier, IcmpKind::Request)
}
pub fn icmp_reply_packet(
src: IpAddr,
dst: impl Into<IpAddr>,
seq: u16,
identifier: u16,
) -> MutableIpPacket<'static> {
icmp_packet(src, dst.into(), seq, identifier, IcmpKind::Response)
}
pub fn icmp_response_packet(packet: IpPacket<'static>) -> MutableIpPacket<'static> {
let icmp = packet
.as_icmp()
@@ -38,12 +47,66 @@ pub fn icmp_response_packet(packet: IpPacket<'static>) -> MutableIpPacket<'stati
)
}
enum IcmpKind {
#[cfg_attr(test, derive(Debug, test_strategy::Arbitrary))]
pub(crate) enum IcmpKind {
Request,
Response,
}
fn icmp_packet(
pub(crate) fn icmp4_packet_with_options(
src: Ipv4Addr,
dst: Ipv4Addr,
seq: u16,
identifier: u16,
kind: IcmpKind,
ip_header_length: u8,
) -> MutableIpPacket<'static> {
use crate::{
icmp::{
echo_request::{IcmpCodes, MutableEchoRequestPacket},
IcmpTypes, MutableIcmpPacket,
},
ip::IpNextHeaderProtocols,
MutablePacket as _,
};
let ip_header_bytes = ip_header_length * 4;
let mut buf = vec![0u8; 60 + ip_header_bytes as usize];
ipv4_header(
src,
dst,
IpNextHeaderProtocols::Icmp,
ip_header_length,
&mut buf[20..],
);
let mut icmp_packet =
MutableIcmpPacket::new(&mut buf[(20 + ip_header_bytes as usize)..]).unwrap();
match kind {
IcmpKind::Request => {
icmp_packet.set_icmp_type(IcmpTypes::EchoRequest);
icmp_packet.set_icmp_code(IcmpCodes::NoCode);
}
IcmpKind::Response => {
icmp_packet.set_icmp_type(IcmpTypes::EchoReply);
icmp_packet.set_icmp_code(IcmpCodes::NoCode);
}
}
icmp_packet.set_checksum(0);
let mut echo_request_packet = MutableEchoRequestPacket::new(icmp_packet.packet_mut()).unwrap();
echo_request_packet.set_sequence_number(seq);
echo_request_packet.set_identifier(identifier);
let mut result = MutableIpPacket::owned(buf).unwrap();
result.update_checksum();
result
}
pub(crate) fn icmp_packet(
src: IpAddr,
dst: IpAddr,
seq: u16,
@@ -52,44 +115,7 @@ fn icmp_packet(
) -> MutableIpPacket<'static> {
match (src, dst) {
(IpAddr::V4(src), IpAddr::V4(dst)) => {
use crate::{
icmp::{
echo_request::{IcmpCodes, MutableEchoRequestPacket},
IcmpTypes, MutableIcmpPacket,
},
ip::IpNextHeaderProtocols,
MutablePacket as _, Packet as _,
};
let mut buf = vec![0u8; 60];
ipv4_header(src, dst, IpNextHeaderProtocols::Icmp, &mut buf[..]);
let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap();
match kind {
IcmpKind::Request => {
icmp_packet.set_icmp_type(IcmpTypes::EchoRequest);
icmp_packet.set_icmp_code(IcmpCodes::NoCode);
}
IcmpKind::Response => {
icmp_packet.set_icmp_type(IcmpTypes::EchoReply);
icmp_packet.set_icmp_code(IcmpCodes::NoCode);
}
}
icmp_packet.set_checksum(0);
let mut echo_request_packet =
MutableEchoRequestPacket::new(icmp_packet.packet_mut()).unwrap();
echo_request_packet.set_sequence_number(seq);
echo_request_packet.set_identifier(identifier);
echo_request_packet.set_checksum(crate::util::checksum(
echo_request_packet.to_immutable().packet(),
2,
));
MutableIpPacket::owned(buf).unwrap()
icmp4_packet_with_options(src, dst, seq, identifier, kind, 5)
}
(IpAddr::V6(src), IpAddr::V6(dst)) => {
use crate::{
@@ -101,11 +127,11 @@ fn icmp_packet(
MutablePacket as _,
};
let mut buf = vec![0u8; 128];
let mut buf = vec![0u8; 128 + 20];
ipv6_header(src, dst, IpNextHeaderProtocols::Icmpv6, &mut buf);
ipv6_header(src, dst, IpNextHeaderProtocols::Icmpv6, &mut buf[20..]);
let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap();
let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[60..]).unwrap();
match kind {
IcmpKind::Request => {
@@ -124,12 +150,9 @@ fn icmp_packet(
echo_request_packet.set_sequence_number(seq);
echo_request_packet.set_checksum(0);
let checksum = crate::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst);
MutableEchoRequestPacket::new(icmp_packet.packet_mut())
.unwrap()
.set_checksum(checksum);
MutableIpPacket::owned(buf).unwrap()
let mut result = MutableIpPacket::owned(buf).unwrap();
result.update_checksum();
result
}
(IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => {
panic!("IPs must be of the same version")
@@ -148,22 +171,23 @@ pub fn tcp_packet(
(IpAddr::V4(src), IpAddr::V4(dst)) => {
use crate::ip::IpNextHeaderProtocols;
let len = 20 + 20 + payload.len();
let len = 20 + 20 + payload.len() + 20;
let mut buf = vec![0u8; len];
ipv4_header(src, dst, IpNextHeaderProtocols::Tcp, &mut buf);
ipv4_header(src, dst, IpNextHeaderProtocols::Tcp, 5, &mut buf[20..]);
tcp_header(saddr, daddr, sport, dport, &payload, &mut buf[20..]);
tcp_header(saddr, daddr, sport, dport, &payload, &mut buf[40..]);
MutableIpPacket::owned(buf).unwrap()
}
(IpAddr::V6(src), IpAddr::V6(dst)) => {
use crate::ip::IpNextHeaderProtocols;
let mut buf = vec![0u8; 40 + 20 + payload.len()];
let mut buf = vec![0u8; 40 + 20 + payload.len() + 20];
ipv6_header(src, dst, IpNextHeaderProtocols::Tcp, &mut buf);
ipv6_header(src, dst, IpNextHeaderProtocols::Tcp, &mut buf[20..]);
tcp_header(saddr, daddr, sport, dport, &payload, &mut buf[40..]);
tcp_header(saddr, daddr, sport, dport, &payload, &mut buf[60..]);
MutableIpPacket::owned(buf).unwrap()
}
(IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => {
@@ -183,22 +207,22 @@ pub fn udp_packet(
(IpAddr::V4(src), IpAddr::V4(dst)) => {
use crate::ip::IpNextHeaderProtocols;
let len = 20 + 8 + payload.len();
let len = 20 + 8 + payload.len() + 20;
let mut buf = vec![0u8; len];
ipv4_header(src, dst, IpNextHeaderProtocols::Udp, &mut buf);
ipv4_header(src, dst, IpNextHeaderProtocols::Udp, 5, &mut buf[20..]);
udp_header(saddr, daddr, sport, dport, &payload, &mut buf[20..]);
udp_header(saddr, daddr, sport, dport, &payload, &mut buf[40..]);
MutableIpPacket::owned(buf).unwrap()
}
(IpAddr::V6(src), IpAddr::V6(dst)) => {
use crate::ip::IpNextHeaderProtocols;
let mut buf = vec![0u8; 40 + 8 + payload.len()];
let mut buf = vec![0u8; 40 + 8 + payload.len() + 20];
ipv6_header(src, dst, IpNextHeaderProtocols::Udp, &mut buf);
ipv6_header(src, dst, IpNextHeaderProtocols::Udp, &mut buf[20..]);
udp_header(saddr, daddr, sport, dport, &payload, &mut buf[40..]);
udp_header(saddr, daddr, sport, dport, &payload, &mut buf[60..]);
MutableIpPacket::owned(buf).unwrap()
}
(IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => {
@@ -297,11 +321,25 @@ pub fn dns_err_response(packet: IpPacket<'static>, code: ResponseCode) -> Mutabl
)
}
fn ipv4_header(src: Ipv4Addr, dst: Ipv4Addr, proto: IpNextHeaderProtocol, buf: &mut [u8]) {
fn ipv4_header(
src: Ipv4Addr,
dst: Ipv4Addr,
proto: IpNextHeaderProtocol,
// We allow setting the ip header length as a way to emulate ip options without having to set ip options
ip_header_length: u8,
buf: &mut [u8],
) {
assert!(ip_header_length >= 5);
assert!(ip_header_length <= 16);
let len = buf.len();
let mut ipv4_packet = MutableIpv4Packet::new(buf).unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
// TODO: packet conversion always set the flags like this.
// we still need to support fragmented packets for translated packet properly
ipv4_packet.set_flags(Ipv4Flags::DontFragment | !Ipv4Flags::MoreFragments);
ipv4_packet.set_header_length(ip_header_length);
ipv4_packet.set_total_length(len as u16);
ipv4_packet.set_ttl(64);
ipv4_packet.set_next_level_protocol(proto);

View File

@@ -0,0 +1,233 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use pnet_packet::Packet;
use proptest::arbitrary::any;
use proptest::prop_oneof;
use proptest::strategy::Strategy;
use crate::make::{icmp4_packet_with_options, icmp_packet, tcp_packet, udp_packet, IcmpKind};
use crate::MutableIpPacket;
fn tcp_packet_v4() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv4Addr>(),
any::<Ipv4Addr>(),
any::<u16>(),
any::<u16>(),
any::<Vec<u8>>(),
)
.prop_map(|(src, dst, sport, dport, payload)| {
tcp_packet(src.into(), dst.into(), sport, dport, payload)
})
}
fn tcp_packet_v6() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv6Addr>(),
any::<Ipv6Addr>(),
any::<u16>(),
any::<u16>(),
any::<Vec<u8>>(),
)
.prop_map(|(src, dst, sport, dport, payload)| {
tcp_packet(src.into(), dst.into(), sport, dport, payload)
})
}
fn udp_packet_v4() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv4Addr>(),
any::<Ipv4Addr>(),
any::<u16>(),
any::<u16>(),
any::<Vec<u8>>(),
)
.prop_map(|(src, dst, sport, dport, payload)| {
udp_packet(src.into(), dst.into(), sport, dport, payload)
})
}
fn udp_packet_v6() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv6Addr>(),
any::<Ipv6Addr>(),
any::<u16>(),
any::<u16>(),
any::<Vec<u8>>(),
)
.prop_map(|(src, dst, sport, dport, payload)| {
udp_packet(src.into(), dst.into(), sport, dport, payload)
})
}
fn icmp_packet_v4() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv4Addr>(),
any::<Ipv4Addr>(),
any::<u16>(),
any::<u16>(),
any::<IcmpKind>(),
)
.prop_map(|(src, dst, id, seq, kind)| icmp_packet(src.into(), dst.into(), id, seq, kind))
}
fn icmp_packet_v4_header_options() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv4Addr>(),
any::<Ipv4Addr>(),
any::<u16>(),
any::<u16>(),
any::<IcmpKind>(),
(5u8..15),
)
.prop_map(|(src, dst, id, seq, kind, header_length)| {
icmp4_packet_with_options(src, dst, id, seq, kind, header_length)
})
}
fn icmp_packet_v6() -> impl Strategy<Value = MutableIpPacket<'static>> {
(
any::<Ipv6Addr>(),
any::<Ipv6Addr>(),
any::<u16>(),
any::<u16>(),
any::<IcmpKind>(),
)
.prop_map(|(src, dst, id, seq, kind)| icmp_packet(src.into(), dst.into(), id, seq, kind))
}
fn packet() -> impl Strategy<Value = MutableIpPacket<'static>> {
prop_oneof![
tcp_packet_v4(),
tcp_packet_v6(),
udp_packet_v4(),
udp_packet_v6(),
icmp_packet_v4(),
icmp_packet_v6(),
]
}
#[test_strategy::proptest()]
fn can_translate_dst_packet_back_and_forth(
#[strategy(packet())] packet: MutableIpPacket<'static>,
#[strategy(any::<Ipv4Addr>())] src_v4: Ipv4Addr,
#[strategy(any::<Ipv6Addr>())] src_v6: Ipv6Addr,
#[strategy(any::<IpAddr>())] dst: IpAddr,
) {
let original_source = packet.source();
let original_destination = packet.destination();
let original_packet = packet.packet().to_vec();
let original_source_v4 = if let IpAddr::V4(v4) = original_source {
v4
} else {
Ipv4Addr::UNSPECIFIED
};
let original_source_v6 = if let IpAddr::V6(v6) = original_source {
v6
} else {
Ipv6Addr::UNSPECIFIED
};
let packet = packet.translate_destination(src_v4, src_v6, dst).unwrap();
assert!(packet.source() == IpAddr::from(src_v4) || packet.source() == IpAddr::from(src_v6) || packet.source() == original_source, "either the translated packet was set to one of the sources or it wasn't translated and it kept the old source");
assert_eq!(packet.destination(), dst);
let mut packet = packet
.translate_destination(original_source_v4, original_source_v6, original_destination)
.unwrap();
packet.update_checksum();
assert_eq!(packet.packet(), original_packet);
}
#[test_strategy::proptest()]
fn can_translate_src_packet_back_and_forth(
#[strategy(packet())] packet: MutableIpPacket<'static>,
#[strategy(any::<Ipv4Addr>())] dst_v4: Ipv4Addr,
#[strategy(any::<Ipv6Addr>())] dst_v6: Ipv6Addr,
#[strategy(any::<IpAddr>())] src: IpAddr,
) {
let original_source = packet.source();
let original_destination = packet.destination();
let original_packet = packet.packet().to_vec();
let original_destination_v4 = if let IpAddr::V4(v4) = original_destination {
v4
} else {
Ipv4Addr::UNSPECIFIED
};
let original_destination_v6 = if let IpAddr::V6(v6) = original_destination {
v6
} else {
Ipv6Addr::UNSPECIFIED
};
let packet = packet.translate_source(dst_v4, dst_v6, src).unwrap();
assert!(packet.destination() == IpAddr::from(dst_v4) || packet.destination() == IpAddr::from(dst_v6) || packet.destination() == original_destination, "either the translated packet was set to one of the destinations or it wasn't translated and it kept the old destination");
assert_eq!(packet.source(), src);
let mut packet = packet
.translate_source(
original_destination_v4,
original_destination_v6,
original_source,
)
.unwrap();
packet.update_checksum();
assert_eq!(packet.packet(), original_packet);
}
#[test_strategy::proptest()]
fn can_translate_dst_packet_with_options(
#[strategy(icmp_packet_v4_header_options())] packet: MutableIpPacket<'static>,
#[strategy(any::<Ipv4Addr>())] src_v4: Ipv4Addr,
#[strategy(any::<Ipv6Addr>())] src_v6: Ipv6Addr,
#[strategy(any::<IpAddr>())] dst: IpAddr,
) {
let source_protocol = packet.to_immutable().source_protocol().unwrap();
let destination_protocol = packet.to_immutable().destination_protocol().unwrap();
let source = packet.source();
let sequence = packet.to_immutable().as_icmp().and_then(|i| i.sequence());
let identifier = packet.to_immutable().as_icmp().and_then(|i| i.identifier());
let packet = packet.translate_destination(src_v4, src_v6, dst).unwrap();
let packet = packet.to_immutable().to_owned();
let icmp = packet.as_icmp().unwrap();
assert!(packet.source() == IpAddr::from(src_v4) || packet.source() == IpAddr::from(src_v6) || packet.source() == source, "either the translated packet was set to one of the sources or it wasn't translated and it kept the old source");
assert_eq!(packet.destination(), dst);
assert_eq!(source_protocol, packet.source_protocol().unwrap());
assert_eq!(destination_protocol, packet.destination_protocol().unwrap());
assert_eq!(sequence, icmp.sequence());
assert_eq!(identifier, icmp.identifier());
}
#[test_strategy::proptest()]
fn can_translate_src_packet_with_options(
#[strategy(icmp_packet_v4_header_options())] packet: MutableIpPacket<'static>,
#[strategy(any::<Ipv4Addr>())] dst_v4: Ipv4Addr,
#[strategy(any::<Ipv6Addr>())] dst_v6: Ipv6Addr,
#[strategy(any::<IpAddr>())] src: IpAddr,
) {
let source_protocol = packet.to_immutable().source_protocol().unwrap();
let destination_protocol = packet.to_immutable().destination_protocol().unwrap();
let destination = packet.destination();
let sequence = packet.to_immutable().as_icmp().and_then(|i| i.sequence());
let identifier = packet.to_immutable().as_icmp().and_then(|i| i.identifier());
let packet = packet.translate_source(dst_v4, dst_v6, src).unwrap();
let packet = packet.to_immutable().to_owned();
let icmp = packet.as_icmp().unwrap();
assert!(packet.destination() == IpAddr::from(dst_v4) || packet.destination() == IpAddr::from(dst_v6) || packet.destination() == destination, "either the translated packet was set to one of the destinations or it wasn't translated and it kept the old destination");
assert_eq!(packet.source(), src);
assert_eq!(source_protocol, packet.source_protocol().unwrap());
assert_eq!(destination_protocol, packet.destination_protocol().unwrap());
assert_eq!(sequence, icmp.sequence());
assert_eq!(identifier, icmp.identifier());
}