From 3669f010c40c84777100f135b8611f4c19173560 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 20 Apr 2024 01:05:29 +1000 Subject: [PATCH] chore: extract common `ip-packet` crate (#4702) With the introduction of `snownet`, we temporarily duplicated the `IpPacket` abstraction from `firezone-tunnel` because there was no common place to put it. Overtime, these have grown in size and we needed to convert back and forth between time. Lately, we've also been adding more tests to both `snownet` and `firezone-tunnel` that needed to create `IpPacket`s as test data. This seems like an appropriate time to do away with this duplication by introducing a dedicated crate that acts as a facade for the `pnet_packet` crate, extending it with the functionality that we need. Resolves: #3926. --------- Signed-off-by: Thomas Eizinger Co-authored-by: Jamil --- rust/Cargo.lock | 12 +- rust/Cargo.toml | 2 + rust/connlib/snownet/Cargo.toml | 2 +- rust/connlib/snownet/src/ip_packet.rs | 167 --------- rust/connlib/snownet/src/lib.rs | 2 - rust/connlib/snownet/src/node.rs | 24 +- rust/connlib/snownet/tests/lib.rs | 97 +----- rust/connlib/tunnel/Cargo.toml | 2 +- rust/connlib/tunnel/src/client.rs | 6 +- rust/connlib/tunnel/src/device_channel.rs | 11 +- rust/connlib/tunnel/src/dns.rs | 16 +- rust/connlib/tunnel/src/gateway.rs | 6 +- rust/connlib/tunnel/src/io.rs | 2 +- rust/connlib/tunnel/src/ip_packet.rs | 396 ---------------------- rust/connlib/tunnel/src/lib.rs | 1 - rust/connlib/tunnel/src/peer.rs | 3 +- rust/connlib/tunnel/src/tests.rs | 96 +----- rust/ip-packet/Cargo.toml | 15 + rust/ip-packet/src/lib.rs | 377 ++++++++++++++++++++ rust/ip-packet/src/make.rs | 91 +++++ rust/snownet-tests/Cargo.toml | 1 + rust/snownet-tests/src/main.rs | 3 +- 22 files changed, 548 insertions(+), 784 deletions(-) delete mode 100644 rust/connlib/snownet/src/ip_packet.rs delete mode 100644 rust/connlib/tunnel/src/ip_packet.rs create mode 100644 rust/ip-packet/Cargo.toml create mode 100644 rust/ip-packet/src/lib.rs create mode 100644 rust/ip-packet/src/make.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 770504651..6da129b84 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2031,6 +2031,7 @@ dependencies = [ "futures-util", "hex", "hickory-resolver", + "ip-packet", "ip_network", "ip_network_table", "itertools 0.12.1", @@ -2038,7 +2039,6 @@ dependencies = [ "log", "netlink-packet-core", "netlink-packet-route", - "pnet_packet", "pretty_assertions", "proptest", "proptest-state-machine", @@ -3119,6 +3119,13 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ip-packet" +version = "1.0.0" +dependencies = [ + "pnet_packet", +] + [[package]] name = "ip_network" version = "0.4.1" @@ -5688,8 +5695,8 @@ dependencies = [ "bytes", "firezone-relay", "hex", + "ip-packet", "once_cell", - "pnet_packet", "rand 0.8.5", "secrecy", "str0m", @@ -5707,6 +5714,7 @@ dependencies = [ "boringtun", "futures", "hex", + "ip-packet", "pnet_packet", "rand 0.8.5", "redis 0.25.3", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 91fb4188c..fe44009d8 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -16,6 +16,7 @@ members = [ "gui-client/src-tauri", "http-health-check", "http-test-server", + "ip-packet", ] resolver = "2" @@ -49,6 +50,7 @@ connlib-shared = { path = "connlib/shared"} firezone-tunnel = { path = "connlib/tunnel"} phoenix-channel = { path = "phoenix-channel"} http-health-check = { path = "http-health-check"} +ip-packet = { path = "ip-packet"} [workspace.lints] clippy.dbg_macro = "warn" diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index f7806ed15..34749c01e 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] boringtun = { workspace = true } -pnet_packet = { version = "0.34" } +ip-packet = { workspace = true } rand = "0.8" secrecy = { workspace = true } str0m = { workspace = true } diff --git a/rust/connlib/snownet/src/ip_packet.rs b/rust/connlib/snownet/src/ip_packet.rs deleted file mode 100644 index db5e44985..000000000 --- a/rust/connlib/snownet/src/ip_packet.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::net::IpAddr; - -use pnet_packet::{ - ip::IpNextHeaderProtocols, - ipv4::{Ipv4Packet, MutableIpv4Packet}, - ipv6::{Ipv6Packet, MutableIpv6Packet}, - Packet, -}; - -macro_rules! for_both { - ($this:ident, |$name:ident| $body:expr) => { - match $this { - Self::Ipv4($name) => $body, - Self::Ipv6($name) => $body, - } - }; -} - -#[derive(Debug, PartialEq)] -pub enum IpPacket<'a> { - Ipv4(Ipv4Packet<'a>), - Ipv6(Ipv6Packet<'a>), -} - -#[derive(Debug, PartialEq)] -pub enum MutableIpPacket<'a> { - Ipv4(MutableIpv4Packet<'a>), - Ipv6(MutableIpv6Packet<'a>), -} - -impl<'a> MutableIpPacket<'a> { - pub fn new(buf: &'a mut [u8]) -> Option { - match buf[0] >> 4 { - 4 => Some(MutableIpPacket::Ipv4(MutableIpv4Packet::new(buf)?)), - 6 => Some(MutableIpPacket::Ipv6(MutableIpv6Packet::new(buf)?)), - _ => None, - } - } - - pub fn owned(data: Vec) -> Option> { - let packet = match data[0] >> 4 { - 4 => MutableIpv4Packet::owned(data)?.into(), - 6 => MutableIpv6Packet::owned(data)?.into(), - _ => return None, - }; - - Some(packet) - } - - pub fn to_owned(&self) -> MutableIpPacket<'static> { - match self { - MutableIpPacket::Ipv4(i) => MutableIpv4Packet::owned(i.packet().to_vec()) - .expect("owned packet is still valid") - .into(), - MutableIpPacket::Ipv6(i) => MutableIpv6Packet::owned(i.packet().to_vec()) - .expect("owned packet is still valid") - .into(), - } - } - - pub fn to_immutable(&self) -> IpPacket { - for_both!(self, |i| i.to_immutable().into()) - } - - pub fn source(&self) -> IpAddr { - for_both!(self, |i| i.get_source().into()) - } - - pub fn destination(&self) -> IpAddr { - for_both!(self, |i| i.get_destination().into()) - } -} - -impl<'a> IpPacket<'a> { - pub fn new(buf: &'a [u8]) -> Option { - match buf[0] >> 4 { - 4 => Some(IpPacket::Ipv4(Ipv4Packet::new(buf)?)), - 6 => Some(IpPacket::Ipv6(Ipv6Packet::new(buf)?)), - _ => None, - } - } - - pub fn to_owned(&self) -> IpPacket<'static> { - match self { - IpPacket::Ipv4(i) => Ipv4Packet::owned(i.packet().to_vec()) - .expect("owned packet is still valid") - .into(), - IpPacket::Ipv6(i) => Ipv6Packet::owned(i.packet().to_vec()) - .expect("owned packet is still valid") - .into(), - } - } - - pub fn source(&self) -> IpAddr { - for_both!(self, |i| i.get_source().into()) - } - - pub fn destination(&self) -> IpAddr { - for_both!(self, |i| i.get_destination().into()) - } - - pub fn udp_payload(&self) -> &[u8] { - debug_assert_eq!( - match self { - IpPacket::Ipv4(i) => i.get_next_level_protocol(), - IpPacket::Ipv6(i) => i.get_next_header(), - }, - IpNextHeaderProtocols::Udp - ); - - for_both!(self, |i| &i.payload()[8..]) - } -} - -impl<'a> From> for IpPacket<'a> { - fn from(value: Ipv4Packet<'a>) -> Self { - Self::Ipv4(value) - } -} - -impl<'a> From> for IpPacket<'a> { - fn from(value: Ipv6Packet<'a>) -> Self { - Self::Ipv6(value) - } -} - -impl<'a> From> for MutableIpPacket<'a> { - fn from(value: MutableIpv4Packet<'a>) -> Self { - Self::Ipv4(value) - } -} - -impl<'a> From> for MutableIpPacket<'a> { - fn from(value: MutableIpv6Packet<'a>) -> Self { - Self::Ipv6(value) - } -} - -impl pnet_packet::Packet for MutableIpPacket<'_> { - fn packet(&self) -> &[u8] { - for_both!(self, |i| i.packet()) - } - - fn payload(&self) -> &[u8] { - for_both!(self, |i| i.payload()) - } -} - -impl pnet_packet::Packet for IpPacket<'_> { - fn packet(&self) -> &[u8] { - for_both!(self, |i| i.packet()) - } - - fn payload(&self) -> &[u8] { - for_both!(self, |i| i.payload()) - } -} - -impl pnet_packet::MutablePacket for MutableIpPacket<'_> { - fn packet_mut(&mut self) -> &mut [u8] { - for_both!(self, |i| i.packet_mut()) - } - - fn payload_mut(&mut self) -> &mut [u8] { - for_both!(self, |i| i.payload_mut()) - } -} diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index 6c0886cae..6dde0cf94 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -4,14 +4,12 @@ mod allocation; mod backoff; mod channel_data; mod index; -mod ip_packet; mod node; mod ringbuffer; mod stats; mod stun_binding; mod utils; -pub use ip_packet::{IpPacket, MutableIpPacket}; pub use node::{ Answer, Client, ClientNode, Credentials, Error, Event, Node, Offer, Server, ServerNode, Transmit, diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 51ec2f5c5..ae1ed37e2 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -1,14 +1,22 @@ +use crate::allocation::{Allocation, Socket}; +use crate::index::IndexLfsr; +use crate::stats::{ConnectionStats, NodeStats}; +use crate::stun_binding::StunBinding; +use crate::utils::earliest; +use boringtun::noise::errors::WireGuardError; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::{fmt, slice}; -use pnet_packet::ipv4::MutableIpv4Packet; -use pnet_packet::ipv6::MutableIpv6Packet; -use pnet_packet::Packet; +use ip_packet::ipv4::MutableIpv4Packet; +use ip_packet::ipv6::MutableIpv6Packet; +use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; use rand::random; use secrecy::{ExposeSecret, Secret}; +use std::borrow::Cow; use std::hash::Hash; use std::marker::PhantomData; +use std::ops::ControlFlow; use std::time::{Duration, Instant}; use std::{ collections::{HashMap, HashSet, VecDeque}, @@ -18,16 +26,6 @@ use std::{ use str0m::ice::{IceAgent, IceAgentEvent, IceCreds, StunMessage, StunPacket}; use str0m::net::Protocol; use str0m::{Candidate, CandidateKind, IceConnectionState}; - -use crate::allocation::{Allocation, Socket}; -use crate::index::IndexLfsr; -use crate::stats::{ConnectionStats, NodeStats}; -use crate::stun_binding::StunBinding; -use crate::utils::earliest; -use crate::{IpPacket, MutableIpPacket}; -use boringtun::noise::errors::WireGuardError; -use std::borrow::Cow; -use std::ops::ControlFlow; use stun_codec::rfc5389::attributes::{Realm, Username}; use tracing::{field, info_span, Span}; diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index c50abd340..d8305ee6c 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -1,7 +1,8 @@ use boringtun::x25519::{PublicKey, StaticSecret}; use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket}; +use ip_packet::*; use rand::rngs::OsRng; -use snownet::{Answer, ClientNode, Event, IpPacket, MutableIpPacket, ServerNode, Transmit}; +use snownet::{Answer, ClientNode, Event, ServerNode, Transmit}; use std::{ collections::{HashSet, VecDeque}, iter, @@ -747,8 +748,11 @@ impl TestNode { let transmit = self .span .in_scope(|| { - self.node - .encapsulate(id, icmp_request_packet(src, dst).to_immutable(), now) + self.node.encapsulate( + id, + ip_packet::make::icmp_request_packet(src, dst).to_immutable(), + now, + ) }) .unwrap() .unwrap() @@ -924,90 +928,3 @@ fn progress( } } } - -fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { - match (source, dst) { - (IpAddr::V4(src), IpAddr::V4(dst)) => { - use pnet_packet::{ - icmp::{ - echo_request::{IcmpCodes, MutableEchoRequestPacket}, - IcmpTypes, MutableIcmpPacket, - }, - ip::IpNextHeaderProtocols, - ipv4::MutableIpv4Packet, - MutablePacket as _, Packet as _, - }; - - let mut buf = vec![0u8; 60]; - - let mut ipv4_packet = MutableIpv4Packet::new(&mut buf[..]).unwrap(); - ipv4_packet.set_version(4); - ipv4_packet.set_header_length(5); - ipv4_packet.set_total_length(60); - ipv4_packet.set_ttl(64); - ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Icmp); - ipv4_packet.set_source(src); - ipv4_packet.set_destination(dst); - ipv4_packet.set_checksum(pnet_packet::ipv4::checksum(&ipv4_packet.to_immutable())); - - let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap(); - icmp_packet.set_icmp_type(IcmpTypes::EchoRequest); - icmp_packet.set_icmp_code(IcmpCodes::NoCode); - icmp_packet.set_checksum(0); - - let mut echo_request_packet = - MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); - echo_request_packet.set_sequence_number(1); - echo_request_packet.set_identifier(0); - echo_request_packet.set_checksum(pnet_packet::util::checksum( - echo_request_packet.to_immutable().packet(), - 2, - )); - - MutableIpPacket::owned(buf).unwrap() - } - (IpAddr::V6(src), IpAddr::V6(dst)) => { - use pnet_packet::{ - icmpv6::{ - echo_request::MutableEchoRequestPacket, Icmpv6Code, Icmpv6Types, - MutableIcmpv6Packet, - }, - ip::IpNextHeaderProtocols, - ipv6::MutableIpv6Packet, - MutablePacket as _, - }; - - let mut buf = vec![0u8; 128]; - - let mut ipv6_packet = MutableIpv6Packet::new(&mut buf[..]).unwrap(); - - ipv6_packet.set_version(6); - ipv6_packet.set_payload_length(16); - ipv6_packet.set_next_header(IpNextHeaderProtocols::Icmpv6); - ipv6_packet.set_hop_limit(64); - ipv6_packet.set_source(src); - ipv6_packet.set_destination(dst); - - let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap(); - - icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest); - icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); // No code for echo request - - let mut echo_request_packet = - MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); - echo_request_packet.set_identifier(0); - echo_request_packet.set_sequence_number(1); - echo_request_packet.set_checksum(0); - - let checksum = pnet_packet::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst); - MutableEchoRequestPacket::new(icmp_packet.payload_mut()) - .unwrap() - .set_checksum(checksum); - - MutableIpPacket::owned(buf).unwrap() - } - (IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => { - panic!("IPs must be of the same version") - } - } -} diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 1860a98cc..66357858d 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -23,7 +23,6 @@ ip_network_table = { version = "0.2", default-features = false } domain = { workspace = true } boringtun = { workspace = true } chrono = { workspace = true } -pnet_packet = { version = "0.34" } futures-bounded = { workspace = true } hickory-resolver = { workspace = true, features = ["tokio-runtime"] } bimap = "0.6" @@ -32,6 +31,7 @@ snownet = { workspace = true } quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main"} hex = "0.4.3" proptest = { version = "1.4.0", optional = true } +ip-packet = { workspace = true } # Needed for Android logging until tracing is fixed log = "0.4" diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 9fb92d254..9882dcdda 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,4 +1,3 @@ -use crate::ip_packet::{IpPacket, MutableIpPacket}; use crate::peer::{PacketTransformClient, Peer}; use crate::peer_store::PeerStore; use crate::{dns, dns::DnsQuery}; @@ -13,6 +12,7 @@ use connlib_shared::{Callbacks, Dname, PublicKey, StaticSecret}; use domain::base::Rtype; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; +use ip_packet::{IpPacket, MutableIpPacket}; use itertools::Itertools; use crate::utils::{earliest, stun, turn}; @@ -331,7 +331,7 @@ impl ClientState { let transmit = self .node - .encapsulate(peer.conn_id, packet.as_immutable().into(), Instant::now()) + .encapsulate(peer.conn_id, packet.as_immutable(), Instant::now()) .inspect_err(|e| tracing::debug!("Failed to encapsulate: {e}")) .ok()??; @@ -362,7 +362,7 @@ impl ClientState { return None; }; - let packet = match peer.untransform(packet.into()) { + let packet = match peer.untransform(packet) { Ok(packet) => packet, Err(e) => { tracing::warn!(%conn_id, %local, %from, "Failed to transform packet: {e}"); diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index 7eb823716..271d4a2bd 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -25,11 +25,10 @@ use tun_android as tun; #[cfg(target_family = "unix")] mod utils; -use crate::ip_packet::{IpPacket, MutableIpPacket}; use connlib_shared::{error::ConnlibError, messages::Interface, Callbacks, Error}; use connlib_shared::{Cidrv4, Cidrv6}; use ip_network::IpNetwork; -use pnet_packet::Packet; +use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; use std::collections::HashSet; use std::io; use std::net::IpAddr; @@ -133,7 +132,7 @@ impl Device { buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>> { - use pnet_packet::Packet as _; + use ip_packet::Packet as _; let Some(tun) = self.tun.as_mut() else { self.waker = Some(cx.waker().clone()); @@ -167,7 +166,7 @@ impl Device { buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>> { - use pnet_packet::Packet as _; + use ip_packet::Packet as _; let Some(tun) = self.tun.as_mut() else { self.waker = Some(cx.waker().clone()); @@ -215,8 +214,8 @@ impl Device { tracing::trace!(target: "wire", to = "device", dst = %packet.destination(), src = %packet.source(), bytes = %packet.packet().len()); match packet { - IpPacket::Ipv4Packet(msg) => self.tun()?.write4(msg.packet()), - IpPacket::Ipv6Packet(msg) => self.tun()?.write6(msg.packet()), + IpPacket::Ipv4(msg) => self.tun()?.write4(msg.packet()), + IpPacket::Ipv6(msg) => self.tun()?.write6(msg.packet()), } } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 9a330854f..1a866a405 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,5 +1,4 @@ use crate::client::DnsResource; -use crate::ip_packet::{to_dns, IpPacket, MutableIpPacket}; use connlib_shared::messages::{DnsServer, ResourceDescriptionDns}; use connlib_shared::Dname; use domain::base::RelativeDname; @@ -11,8 +10,10 @@ use hickory_resolver::lookup::Lookup; use hickory_resolver::proto::error::{ProtoError, ProtoErrorKind}; use hickory_resolver::proto::op::{Message as TrustDnsMessage, MessageType}; use hickory_resolver::proto::rr::RecordType; +use ip_packet::udp::UdpPacket; +use ip_packet::Packet as _; +use ip_packet::{udp::MutableUdpPacket, IpPacket, MutableIpPacket, MutablePacket, PacketSize}; use itertools::Itertools; -use pnet_packet::{udp::MutableUdpPacket, MutablePacket, Packet as UdpPacket, PacketSize}; use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -21,6 +22,7 @@ const UDP_HEADER_SIZE: usize = 8; const REVERSE_DNS_ADDRESS_END: &str = "arpa"; const REVERSE_DNS_ADDRESS_V4: &str = "in-addr"; const REVERSE_DNS_ADDRESS_V6: &str = "ip6"; +const DNS_PORT: u16 = 53; /// Tells the Client how to reply to a single DNS query #[derive(Debug)] @@ -39,7 +41,7 @@ pub struct DnsQuery<'a> { pub record_type: RecordType, // We could be much more efficient with this field, // we only need the header to create the response. - pub query: crate::ip_packet::IpPacket<'a>, + pub query: ip_packet::IpPacket<'a>, } impl<'a> DnsQuery<'a> { @@ -50,7 +52,7 @@ impl<'a> DnsQuery<'a> { query, } = self; let buf = query.packet().to_vec(); - let query = crate::ip_packet::IpPacket::owned(buf) + let query = ip_packet::IpPacket::owned(buf) .expect("We are constructing the ip packet from an ip packet"); DnsQuery { @@ -275,6 +277,12 @@ where Some(answer_builder.finish()) } +pub fn to_dns<'a>(pkt: &'a UdpPacket<'a>) -> Option<&'a Message<[u8]>> { + (pkt.get_destination() == DNS_PORT) + .then(|| Message::from_slice(pkt.payload()).ok()) + .flatten() +} + // No object safety =_= #[derive(Clone)] enum RecordData { diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 9c64467fa..5999fb9ef 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,4 +1,3 @@ -use crate::ip_packet::{IpPacket, MutableIpPacket}; use crate::peer::{PacketTransformGateway, Peer}; use crate::peer_store::PeerStore; use crate::utils::{earliest, stun, turn}; @@ -11,6 +10,7 @@ use connlib_shared::messages::{ }; use connlib_shared::{Callbacks, Dname, Error, Result, StaticSecret}; use ip_network::IpNetwork; +use ip_packet::{IpPacket, MutableIpPacket}; use secrecy::{ExposeSecret as _, Secret}; use snownet::ServerNode; use std::collections::{HashSet, VecDeque}; @@ -215,7 +215,7 @@ impl GatewayState { let transmit = self .node - .encapsulate(peer.conn_id, packet.as_immutable().into(), Instant::now()) + .encapsulate(peer.conn_id, packet.as_immutable(), Instant::now()) .inspect_err(|e| tracing::debug!("Failed to encapsulate: {e}")) .ok()??; @@ -246,7 +246,7 @@ impl GatewayState { return None; }; - let packet = match peer.untransform(packet.into()) { + let packet = match peer.untransform(packet) { Ok(packet) => packet, Err(e) => { // Note: this can happen with apps such as cURL that if started before the tunnel routes are address diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 5ad0dbff9..4f58c0347 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,7 +1,6 @@ use crate::{ device_channel::Device, dns::{self, DnsQuery}, - ip_packet::{IpPacket, MutableIpPacket}, sockets::{Received, Sockets}, }; use bytes::Bytes; @@ -12,6 +11,7 @@ use hickory_resolver::{ config::{NameServerConfig, Protocol, ResolverConfig}, TokioAsyncResolver, }; +use ip_packet::{IpPacket, MutableIpPacket}; use quinn_udp::Transmit; use std::{ collections::HashMap, diff --git a/rust/connlib/tunnel/src/ip_packet.rs b/rust/connlib/tunnel/src/ip_packet.rs deleted file mode 100644 index 25af87c85..000000000 --- a/rust/connlib/tunnel/src/ip_packet.rs +++ /dev/null @@ -1,396 +0,0 @@ -use std::net::IpAddr; - -use domain::base::message::Message; -use pnet_packet::{ - icmpv6::{self, MutableIcmpv6Packet}, - ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, - ipv4::{self, Ipv4Packet, MutableIpv4Packet}, - ipv6::{Ipv6Packet, MutableIpv6Packet}, - tcp::{self, MutableTcpPacket, TcpPacket}, - udp::{self, MutableUdpPacket, UdpPacket}, - MutablePacket, Packet, PacketSize, -}; - -const DNS_PORT: u16 = 53; - -#[derive(Debug, PartialEq)] -pub enum MutableIpPacket<'a> { - MutableIpv4Packet(MutableIpv4Packet<'a>), - MutableIpv6Packet(MutableIpv6Packet<'a>), -} - -// no std::mem:;swap? no problem -macro_rules! swap_src_dst { - ($p:expr) => { - let src = $p.get_source(); - let dst = $p.get_destination(); - $p.set_source(dst); - $p.set_destination(src); - }; -} - -impl<'a> MutableIpPacket<'a> { - #[inline] - pub(crate) fn new(data: &mut [u8]) -> Option { - let packet = match data[0] >> 4 { - 4 => MutableIpv4Packet::new(data)?.into(), - 6 => MutableIpv6Packet::new(data)?.into(), - _ => return None, - }; - - Some(packet) - } - - #[cfg(test)] - pub(crate) fn owned(data: Vec) -> Option> { - let packet = match data[0] >> 4 { - 4 => MutableIpv4Packet::owned(data)?.into(), - 6 => MutableIpv6Packet::owned(data)?.into(), - _ => return None, - }; - - Some(packet) - } - - #[inline] - pub(crate) fn source(&self) -> IpAddr { - match self { - MutableIpPacket::MutableIpv4Packet(i) => i.get_source().into(), - MutableIpPacket::MutableIpv6Packet(i) => i.get_source().into(), - } - } - - #[inline] - pub(crate) fn destination(&self) -> IpAddr { - match self { - MutableIpPacket::MutableIpv4Packet(i) => i.get_destination().into(), - MutableIpPacket::MutableIpv6Packet(i) => i.get_destination().into(), - } - } - - #[inline] - pub(crate) fn update_checksum(&mut self) { - // Note: neither ipv6 nor icmp have a checksum. - self.set_icmpv6_checksum(); - self.set_udp_checksum(); - self.set_tcp_checksum(); - // Note: Ipv4 checksum should be set after the others, - // since it's in an upper layer. - self.set_ipv4_checksum(); - } - - pub(crate) fn set_ipv4_checksum(&mut self) { - if let Self::MutableIpv4Packet(p) = self { - p.set_checksum(ipv4::checksum(&p.to_immutable())); - } - } - - fn set_udp_checksum(&mut self) { - let checksum = if let Some(p) = self.as_immutable_udp() { - self.to_immutable().udp_checksum(&p.to_immutable()) - } else { - return; - }; - - self.as_udp() - .expect("Developer error: we can only get a checksum if the packet is udp") - .set_checksum(checksum); - } - - fn set_tcp_checksum(&mut self) { - let checksum = if let Some(p) = self.as_immutable_tcp() { - self.to_immutable().tcp_checksum(&p.to_immutable()) - } else { - return; - }; - - self.as_tcp() - .expect("Developer error: we can only get a checksum if the packet is tcp") - .set_checksum(checksum); - } - - pub(crate) fn to_immutable(&self) -> IpPacket { - match self { - Self::MutableIpv4Packet(p) => p.to_immutable().into(), - Self::MutableIpv6Packet(p) => p.to_immutable().into(), - } - } - - pub(crate) fn into_immutable(self) -> IpPacket<'a> { - match self { - Self::MutableIpv4Packet(p) => p.consume_to_immutable().into(), - Self::MutableIpv6Packet(p) => p.consume_to_immutable().into(), - } - } - - pub(crate) fn as_immutable(&self) -> IpPacket<'_> { - match self { - Self::MutableIpv4Packet(p) => IpPacket::Ipv4Packet(p.to_immutable()), - Self::MutableIpv6Packet(p) => IpPacket::Ipv6Packet(p.to_immutable()), - } - } - - pub(crate) fn as_udp(&mut self) -> Option { - self.to_immutable() - .is_udp() - .then(|| MutableUdpPacket::new(self.payload_mut())) - .flatten() - } - - fn as_tcp(&mut self) -> Option { - self.to_immutable() - .is_tcp() - .then(|| MutableTcpPacket::new(self.payload_mut())) - .flatten() - } - - fn set_icmpv6_checksum(&mut self) { - let (src_addr, dst_addr) = match self { - MutableIpPacket::MutableIpv4Packet(_) => return, - MutableIpPacket::MutableIpv6Packet(p) => (p.get_source(), p.get_destination()), - }; - if let Some(mut pkt) = self.as_icmpv6() { - let checksum = icmpv6::checksum(&pkt.to_immutable(), &src_addr, &dst_addr); - pkt.set_checksum(checksum); - } - } - - fn as_icmpv6(&mut self) -> Option { - self.to_immutable() - .is_icmpv6() - .then(|| MutableIcmpv6Packet::new(self.payload_mut())) - .flatten() - } - - pub(crate) fn as_immutable_udp(&self) -> Option { - self.to_immutable() - .is_udp() - .then(|| UdpPacket::new(self.payload())) - .flatten() - } - - pub(crate) fn as_immutable_tcp(&self) -> Option { - self.to_immutable() - .is_tcp() - .then(|| TcpPacket::new(self.payload())) - .flatten() - } - - pub(crate) fn swap_src_dst(&mut self) { - match self { - Self::MutableIpv4Packet(p) => { - swap_src_dst!(p); - } - Self::MutableIpv6Packet(p) => { - swap_src_dst!(p); - } - } - } - - #[inline] - pub(crate) fn set_dst(&mut self, dst: IpAddr) { - match (self, dst) { - (Self::MutableIpv4Packet(p), IpAddr::V4(d)) => p.set_destination(d), - (Self::MutableIpv6Packet(p), IpAddr::V6(d)) => p.set_destination(d), - _ => {} - } - } - - #[inline] - pub(crate) fn set_src(&mut self, src: IpAddr) { - match (self, src) { - (Self::MutableIpv4Packet(p), IpAddr::V4(s)) => p.set_source(s), - (Self::MutableIpv6Packet(p), IpAddr::V6(s)) => p.set_source(s), - _ => {} - } - } - - pub(crate) fn set_len(&mut self, total_len: usize, payload_len: usize) { - match self { - Self::MutableIpv4Packet(p) => p.set_total_length(total_len as u16), - Self::MutableIpv6Packet(p) => p.set_payload_length(payload_len as u16), - } - } -} - -#[derive(Debug, PartialEq)] -pub enum IpPacket<'a> { - Ipv4Packet(Ipv4Packet<'a>), - Ipv6Packet(Ipv6Packet<'a>), -} - -// TODO: Create our own `ip_packet` crate that `snownet and `firezone-tunnel` can depend on. -impl<'a> From> for snownet::IpPacket<'a> { - fn from(value: IpPacket<'a>) -> Self { - match value { - IpPacket::Ipv4Packet(p) => Self::Ipv4(p), - IpPacket::Ipv6Packet(p) => Self::Ipv6(p), - } - } -} - -impl<'a> From> for MutableIpPacket<'a> { - fn from(value: snownet::MutableIpPacket<'a>) -> Self { - match value { - snownet::MutableIpPacket::Ipv4(p) => Self::MutableIpv4Packet(p), - snownet::MutableIpPacket::Ipv6(p) => Self::MutableIpv6Packet(p), - } - } -} - -impl<'a> IpPacket<'a> { - pub(crate) fn owned(data: Vec) -> Option> { - let packet = match data[0] >> 4 { - 4 => Ipv4Packet::owned(data)?.into(), - 6 => Ipv6Packet::owned(data)?.into(), - _ => return None, - }; - - Some(packet) - } - - pub(crate) fn to_owned(&self) -> IpPacket<'static> { - // This should never fail as the provided buffer is a vec (unless oom) - IpPacket::owned(self.packet().to_vec()).unwrap() - } - - pub(crate) fn is_icmpv6(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Icmpv6 - } - - pub(crate) fn next_header(&self) -> IpNextHeaderProtocol { - match self { - Self::Ipv4Packet(p) => p.get_next_level_protocol(), - Self::Ipv6Packet(p) => p.get_next_header(), - } - } - - fn is_udp(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Udp - } - - fn is_tcp(&self) -> bool { - self.next_header() == IpNextHeaderProtocols::Tcp - } - - pub(crate) fn as_udp(&self) -> Option { - self.is_udp() - .then(|| UdpPacket::new(self.payload())) - .flatten() - } - - pub fn source(&self) -> IpAddr { - match self { - Self::Ipv4Packet(p) => p.get_source().into(), - Self::Ipv6Packet(p) => p.get_source().into(), - } - } - - pub fn destination(&self) -> IpAddr { - match self { - Self::Ipv4Packet(p) => p.get_destination().into(), - Self::Ipv6Packet(p) => p.get_destination().into(), - } - } - - pub(crate) fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 { - match self { - Self::Ipv4Packet(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()), - Self::Ipv6Packet(p) => udp::ipv6_checksum(dgm, &p.get_source(), &p.get_destination()), - } - } - - fn tcp_checksum(&self, pkt: &TcpPacket<'_>) -> u16 { - match self { - Self::Ipv4Packet(p) => tcp::ipv4_checksum(pkt, &p.get_source(), &p.get_destination()), - Self::Ipv6Packet(p) => tcp::ipv6_checksum(pkt, &p.get_source(), &p.get_destination()), - } - } -} - -pub(crate) fn to_dns<'a>(pkt: &'a UdpPacket<'a>) -> Option<&'a Message<[u8]>> { - (pkt.get_destination() == DNS_PORT) - .then(|| Message::from_slice(pkt.payload()).ok()) - .flatten() -} - -impl<'a> Packet for IpPacket<'a> { - fn packet(&self) -> &[u8] { - match self { - Self::Ipv4Packet(p) => p.packet(), - Self::Ipv6Packet(p) => p.packet(), - } - } - - fn payload(&self) -> &[u8] { - match self { - Self::Ipv4Packet(p) => p.payload(), - Self::Ipv6Packet(p) => p.payload(), - } - } -} - -impl<'a> PacketSize for IpPacket<'a> { - fn packet_size(&self) -> usize { - match self { - Self::Ipv4Packet(p) => p.packet_size(), - Self::Ipv6Packet(p) => p.packet_size(), - } - } -} - -impl<'a> Packet for MutableIpPacket<'a> { - fn packet(&self) -> &[u8] { - match self { - Self::MutableIpv4Packet(p) => p.packet(), - Self::MutableIpv6Packet(p) => p.packet(), - } - } - - fn payload(&self) -> &[u8] { - match self { - Self::MutableIpv4Packet(p) => p.payload(), - Self::MutableIpv6Packet(p) => p.payload(), - } - } -} - -impl<'a> MutablePacket for MutableIpPacket<'a> { - fn packet_mut(&mut self) -> &mut [u8] { - match self { - Self::MutableIpv4Packet(p) => p.packet_mut(), - Self::MutableIpv6Packet(p) => p.packet_mut(), - } - } - - fn payload_mut(&mut self) -> &mut [u8] { - match self { - Self::MutableIpv4Packet(p) => p.payload_mut(), - Self::MutableIpv6Packet(p) => p.payload_mut(), - } - } -} - -impl<'a> From> for IpPacket<'a> { - fn from(pkt: Ipv4Packet<'a>) -> Self { - Self::Ipv4Packet(pkt) - } -} - -impl<'a> From> for IpPacket<'a> { - fn from(pkt: Ipv6Packet<'a>) -> Self { - Self::Ipv6Packet(pkt) - } -} - -impl<'a> From> for MutableIpPacket<'a> { - fn from(pkt: MutableIpv4Packet<'a>) -> Self { - Self::MutableIpv4Packet(pkt) - } -} - -impl<'a> From> for MutableIpPacket<'a> { - fn from(pkt: MutableIpv6Packet<'a>) -> Self { - Self::MutableIpv6Packet(pkt) - } -} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 3406a722d..1db6c8f27 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -25,7 +25,6 @@ mod device_channel; mod dns; mod gateway; mod io; -mod ip_packet; mod peer; mod peer_store; mod sockets; diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 0ad1eeeb7..3a3365135 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -8,10 +8,9 @@ use connlib_shared::messages::{DnsServer, ResourceId}; use connlib_shared::{Error, Result}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use pnet_packet::Packet; +use ip_packet::{MutableIpPacket, Packet}; use crate::client::IpProvider; -use crate::ip_packet::MutableIpPacket; type ExpiryingResource = (ResourceId, Option>); diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 66ad548c3..b829810cf 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -1,4 +1,4 @@ -use crate::{ip_packet::MutableIpPacket, ClientEvent, ClientState, GatewayState}; +use crate::{ClientEvent, ClientState, GatewayState}; use connlib_shared::{ messages::{ResourceDescription, ResourceDescriptionCidr, ResourceId}, proptest::cidr_resource, @@ -219,9 +219,10 @@ impl TunnelTest { src: impl Into, dst: impl Into, ) { - let _maybe_transmit = self - .client - .encapsulate(icmp_request_packet(src.into(), dst.into()), self.now); + let _maybe_transmit = self.client.encapsulate( + ip_packet::make::icmp_request_packet(src.into(), dst.into()), + self.now, + ); // TODO: Handle transmit (send to relay / gateway) } @@ -370,90 +371,3 @@ enum Transition { /// Advance time by this many milliseconds. Tick { millis: u64 }, } - -fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { - match (source, dst) { - (IpAddr::V4(src), IpAddr::V4(dst)) => { - use pnet_packet::{ - icmp::{ - echo_request::{IcmpCodes, MutableEchoRequestPacket}, - IcmpTypes, MutableIcmpPacket, - }, - ip::IpNextHeaderProtocols, - ipv4::MutableIpv4Packet, - MutablePacket as _, Packet as _, - }; - - let mut buf = vec![0u8; 60]; - - let mut ipv4_packet = MutableIpv4Packet::new(&mut buf[..]).unwrap(); - ipv4_packet.set_version(4); - ipv4_packet.set_header_length(5); - ipv4_packet.set_total_length(60); - ipv4_packet.set_ttl(64); - ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Icmp); - ipv4_packet.set_source(src); - ipv4_packet.set_destination(dst); - ipv4_packet.set_checksum(pnet_packet::ipv4::checksum(&ipv4_packet.to_immutable())); - - let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap(); - icmp_packet.set_icmp_type(IcmpTypes::EchoRequest); - icmp_packet.set_icmp_code(IcmpCodes::NoCode); - icmp_packet.set_checksum(0); - - let mut echo_request_packet = - MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); - echo_request_packet.set_sequence_number(1); - echo_request_packet.set_identifier(0); - echo_request_packet.set_checksum(pnet_packet::util::checksum( - echo_request_packet.to_immutable().packet(), - 2, - )); - - MutableIpPacket::owned(buf).unwrap() - } - (IpAddr::V6(src), IpAddr::V6(dst)) => { - use pnet_packet::{ - icmpv6::{ - echo_request::MutableEchoRequestPacket, Icmpv6Code, Icmpv6Types, - MutableIcmpv6Packet, - }, - ip::IpNextHeaderProtocols, - ipv6::MutableIpv6Packet, - MutablePacket as _, - }; - - let mut buf = vec![0u8; 128]; - - let mut ipv6_packet = MutableIpv6Packet::new(&mut buf[..]).unwrap(); - - ipv6_packet.set_version(6); - ipv6_packet.set_payload_length(16); - ipv6_packet.set_next_header(IpNextHeaderProtocols::Icmpv6); - ipv6_packet.set_hop_limit(64); - ipv6_packet.set_source(src); - ipv6_packet.set_destination(dst); - - let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap(); - - icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest); - icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); // No code for echo request - - let mut echo_request_packet = - MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); - echo_request_packet.set_identifier(0); - echo_request_packet.set_sequence_number(1); - echo_request_packet.set_checksum(0); - - let checksum = pnet_packet::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst); - MutableEchoRequestPacket::new(icmp_packet.payload_mut()) - .unwrap() - .set_checksum(checksum); - - MutableIpPacket::owned(buf).unwrap() - } - (IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => { - panic!("IPs must be of the same version") - } - } -} diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml new file mode 100644 index 000000000..5d9294148 --- /dev/null +++ b/rust/ip-packet/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "ip-packet" +# mark:automatic-version +version = "1.0.0" +edition = "2021" +authors = ["Firezone, Inc."] +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +pnet_packet = { version = "0.34" } + +[lints] +workspace = true diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs new file mode 100644 index 000000000..6c0d0744f --- /dev/null +++ b/rust/ip-packet/src/lib.rs @@ -0,0 +1,377 @@ +pub mod make; + +pub use pnet_packet::*; + +use pnet_packet::{ + icmpv6::MutableIcmpv6Packet, + ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, + ipv4::{Ipv4Packet, MutableIpv4Packet}, + ipv6::{Ipv6Packet, MutableIpv6Packet}, + tcp::{MutableTcpPacket, TcpPacket}, + udp::{MutableUdpPacket, UdpPacket}, +}; +use std::net::IpAddr; + +macro_rules! for_both { + ($this:ident, |$name:ident| $body:expr) => { + match $this { + Self::Ipv4($name) => $body, + Self::Ipv6($name) => $body, + } + }; +} + +// no std::mem::swap? no problem +macro_rules! swap_src_dst { + ($p:expr) => { + let src = $p.get_source(); + let dst = $p.get_destination(); + $p.set_source(dst); + $p.set_destination(src); + }; +} + +#[derive(Debug, PartialEq)] +pub enum IpPacket<'a> { + Ipv4(Ipv4Packet<'a>), + Ipv6(Ipv6Packet<'a>), +} + +#[derive(Debug, PartialEq)] +pub enum MutableIpPacket<'a> { + Ipv4(MutableIpv4Packet<'a>), + Ipv6(MutableIpv6Packet<'a>), +} + +impl<'a> MutableIpPacket<'a> { + pub fn new(buf: &'a mut [u8]) -> Option { + match buf[0] >> 4 { + 4 => Some(MutableIpPacket::Ipv4(MutableIpv4Packet::new(buf)?)), + 6 => Some(MutableIpPacket::Ipv6(MutableIpv6Packet::new(buf)?)), + _ => None, + } + } + + pub fn owned(data: Vec) -> Option> { + let packet = match data[0] >> 4 { + 4 => MutableIpv4Packet::owned(data)?.into(), + 6 => MutableIpv6Packet::owned(data)?.into(), + _ => return None, + }; + + Some(packet) + } + + pub fn to_owned(&self) -> MutableIpPacket<'static> { + match self { + MutableIpPacket::Ipv4(i) => MutableIpv4Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + MutableIpPacket::Ipv6(i) => MutableIpv6Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + } + } + + pub fn to_immutable(&self) -> IpPacket { + for_both!(self, |i| i.to_immutable().into()) + } + + pub fn source(&self) -> IpAddr { + for_both!(self, |i| i.get_source().into()) + } + + pub fn destination(&self) -> IpAddr { + for_both!(self, |i| i.get_destination().into()) + } + + #[inline] + pub fn update_checksum(&mut self) { + // Note: neither ipv6 nor icmp have a checksum. + self.set_icmpv6_checksum(); + self.set_udp_checksum(); + self.set_tcp_checksum(); + // Note: Ipv4 checksum should be set after the others, + // since it's in an upper layer. + self.set_ipv4_checksum(); + } + + pub fn set_ipv4_checksum(&mut self) { + if let Self::Ipv4(p) = self { + p.set_checksum(ipv4::checksum(&p.to_immutable())); + } + } + + fn set_udp_checksum(&mut self) { + let checksum = if let Some(p) = self.as_immutable_udp() { + self.to_immutable().udp_checksum(&p.to_immutable()) + } else { + return; + }; + + self.as_udp() + .expect("Developer error: we can only get a checksum if the packet is udp") + .set_checksum(checksum); + } + + fn set_tcp_checksum(&mut self) { + let checksum = if let Some(p) = self.as_immutable_tcp() { + self.to_immutable().tcp_checksum(&p.to_immutable()) + } else { + return; + }; + + self.as_tcp() + .expect("Developer error: we can only get a checksum if the packet is tcp") + .set_checksum(checksum); + } + + pub fn into_immutable(self) -> IpPacket<'a> { + match self { + Self::Ipv4(p) => p.consume_to_immutable().into(), + Self::Ipv6(p) => p.consume_to_immutable().into(), + } + } + + pub fn as_immutable(&self) -> IpPacket<'_> { + match self { + Self::Ipv4(p) => IpPacket::Ipv4(p.to_immutable()), + Self::Ipv6(p) => IpPacket::Ipv6(p.to_immutable()), + } + } + + pub fn as_udp(&mut self) -> Option { + self.to_immutable() + .is_udp() + .then(|| MutableUdpPacket::new(self.payload_mut())) + .flatten() + } + + fn as_tcp(&mut self) -> Option { + self.to_immutable() + .is_tcp() + .then(|| MutableTcpPacket::new(self.payload_mut())) + .flatten() + } + + fn set_icmpv6_checksum(&mut self) { + let (src_addr, dst_addr) = match self { + MutableIpPacket::Ipv4(_) => return, + MutableIpPacket::Ipv6(p) => (p.get_source(), p.get_destination()), + }; + if let Some(mut pkt) = self.as_icmpv6() { + let checksum = icmpv6::checksum(&pkt.to_immutable(), &src_addr, &dst_addr); + pkt.set_checksum(checksum); + } + } + + fn as_icmpv6(&mut self) -> Option { + self.to_immutable() + .is_icmpv6() + .then(|| MutableIcmpv6Packet::new(self.payload_mut())) + .flatten() + } + + pub fn as_immutable_udp(&self) -> Option { + self.to_immutable() + .is_udp() + .then(|| UdpPacket::new(self.payload())) + .flatten() + } + + pub fn as_immutable_tcp(&self) -> Option { + self.to_immutable() + .is_tcp() + .then(|| TcpPacket::new(self.payload())) + .flatten() + } + + pub fn swap_src_dst(&mut self) { + match self { + Self::Ipv4(p) => { + swap_src_dst!(p); + } + Self::Ipv6(p) => { + swap_src_dst!(p); + } + } + } + + #[inline] + pub fn set_dst(&mut self, dst: IpAddr) { + match (self, dst) { + (Self::Ipv4(p), IpAddr::V4(d)) => p.set_destination(d), + (Self::Ipv6(p), IpAddr::V6(d)) => p.set_destination(d), + _ => {} + } + } + + #[inline] + pub fn set_src(&mut self, src: IpAddr) { + match (self, src) { + (Self::Ipv4(p), IpAddr::V4(s)) => p.set_source(s), + (Self::Ipv6(p), IpAddr::V6(s)) => p.set_source(s), + _ => {} + } + } + + pub fn set_len(&mut self, total_len: usize, payload_len: usize) { + match self { + Self::Ipv4(p) => p.set_total_length(total_len as u16), + Self::Ipv6(p) => p.set_payload_length(payload_len as u16), + } + } +} + +impl<'a> IpPacket<'a> { + pub fn new(buf: &'a [u8]) -> Option { + match buf[0] >> 4 { + 4 => Some(IpPacket::Ipv4(Ipv4Packet::new(buf)?)), + 6 => Some(IpPacket::Ipv6(Ipv6Packet::new(buf)?)), + _ => None, + } + } + + pub fn to_owned(&self) -> IpPacket<'static> { + match self { + IpPacket::Ipv4(i) => Ipv4Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + IpPacket::Ipv6(i) => Ipv6Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + } + } + + pub fn source(&self) -> IpAddr { + for_both!(self, |i| i.get_source().into()) + } + + pub fn destination(&self) -> IpAddr { + for_both!(self, |i| i.get_destination().into()) + } + + pub fn udp_payload(&self) -> &[u8] { + debug_assert_eq!( + match self { + IpPacket::Ipv4(i) => i.get_next_level_protocol(), + IpPacket::Ipv6(i) => i.get_next_header(), + }, + IpNextHeaderProtocols::Udp + ); + + for_both!(self, |i| &i.payload()[8..]) + } + + pub fn owned(data: Vec) -> Option> { + let packet = match data[0] >> 4 { + 4 => Ipv4Packet::owned(data)?.into(), + 6 => Ipv6Packet::owned(data)?.into(), + _ => return None, + }; + + Some(packet) + } + + pub fn is_icmpv6(&self) -> bool { + self.next_header() == IpNextHeaderProtocols::Icmpv6 + } + + pub fn next_header(&self) -> IpNextHeaderProtocol { + match self { + Self::Ipv4(p) => p.get_next_level_protocol(), + Self::Ipv6(p) => p.get_next_header(), + } + } + + fn is_udp(&self) -> bool { + self.next_header() == IpNextHeaderProtocols::Udp + } + + fn is_tcp(&self) -> bool { + self.next_header() == IpNextHeaderProtocols::Tcp + } + + pub fn as_udp(&self) -> Option { + self.is_udp() + .then(|| UdpPacket::new(self.payload())) + .flatten() + } + + pub fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 { + match self { + Self::Ipv4(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()), + Self::Ipv6(p) => udp::ipv6_checksum(dgm, &p.get_source(), &p.get_destination()), + } + } + + fn tcp_checksum(&self, pkt: &TcpPacket<'_>) -> u16 { + match self { + Self::Ipv4(p) => tcp::ipv4_checksum(pkt, &p.get_source(), &p.get_destination()), + Self::Ipv6(p) => tcp::ipv6_checksum(pkt, &p.get_source(), &p.get_destination()), + } + } +} + +impl<'a> From> for IpPacket<'a> { + fn from(value: Ipv4Packet<'a>) -> Self { + Self::Ipv4(value) + } +} + +impl<'a> From> for IpPacket<'a> { + fn from(value: Ipv6Packet<'a>) -> Self { + Self::Ipv6(value) + } +} + +impl<'a> From> for MutableIpPacket<'a> { + fn from(value: MutableIpv4Packet<'a>) -> Self { + Self::Ipv4(value) + } +} + +impl<'a> From> for MutableIpPacket<'a> { + fn from(value: MutableIpv6Packet<'a>) -> Self { + Self::Ipv6(value) + } +} + +impl pnet_packet::Packet for MutableIpPacket<'_> { + fn packet(&self) -> &[u8] { + for_both!(self, |i| i.packet()) + } + + fn payload(&self) -> &[u8] { + for_both!(self, |i| i.payload()) + } +} + +impl pnet_packet::Packet for IpPacket<'_> { + fn packet(&self) -> &[u8] { + for_both!(self, |i| i.packet()) + } + + fn payload(&self) -> &[u8] { + for_both!(self, |i| i.payload()) + } +} + +impl pnet_packet::MutablePacket for MutableIpPacket<'_> { + fn packet_mut(&mut self) -> &mut [u8] { + for_both!(self, |i| i.packet_mut()) + } + + fn payload_mut(&mut self) -> &mut [u8] { + for_both!(self, |i| i.payload_mut()) + } +} + +impl<'a> PacketSize for IpPacket<'a> { + fn packet_size(&self) -> usize { + match self { + Self::Ipv4(p) => p.packet_size(), + Self::Ipv6(p) => p.packet_size(), + } + } +} diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs new file mode 100644 index 000000000..40b4f794a --- /dev/null +++ b/rust/ip-packet/src/make.rs @@ -0,0 +1,91 @@ +//! Factory module for making all kinds of packets. + +use crate::MutableIpPacket; +use std::net::IpAddr; + +pub fn icmp_request_packet(source: IpAddr, dst: IpAddr) -> MutableIpPacket<'static> { + match (source, dst) { + (IpAddr::V4(src), IpAddr::V4(dst)) => { + use crate::{ + icmp::{ + echo_request::{IcmpCodes, MutableEchoRequestPacket}, + IcmpTypes, MutableIcmpPacket, + }, + ip::IpNextHeaderProtocols, + ipv4::MutableIpv4Packet, + MutablePacket as _, Packet as _, + }; + + let mut buf = vec![0u8; 60]; + + let mut ipv4_packet = MutableIpv4Packet::new(&mut buf[..]).unwrap(); + ipv4_packet.set_version(4); + ipv4_packet.set_header_length(5); + ipv4_packet.set_total_length(60); + ipv4_packet.set_ttl(64); + ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Icmp); + ipv4_packet.set_source(src); + ipv4_packet.set_destination(dst); + ipv4_packet.set_checksum(crate::ipv4::checksum(&ipv4_packet.to_immutable())); + + let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap(); + icmp_packet.set_icmp_type(IcmpTypes::EchoRequest); + icmp_packet.set_icmp_code(IcmpCodes::NoCode); + icmp_packet.set_checksum(0); + + let mut echo_request_packet = + MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); + echo_request_packet.set_sequence_number(1); + echo_request_packet.set_identifier(0); + echo_request_packet.set_checksum(crate::util::checksum( + echo_request_packet.to_immutable().packet(), + 2, + )); + + MutableIpPacket::owned(buf).unwrap() + } + (IpAddr::V6(src), IpAddr::V6(dst)) => { + use crate::{ + icmpv6::{ + echo_request::MutableEchoRequestPacket, Icmpv6Code, Icmpv6Types, + MutableIcmpv6Packet, + }, + ip::IpNextHeaderProtocols, + ipv6::MutableIpv6Packet, + MutablePacket as _, + }; + + let mut buf = vec![0u8; 128]; + + let mut ipv6_packet = MutableIpv6Packet::new(&mut buf[..]).unwrap(); + + ipv6_packet.set_version(6); + ipv6_packet.set_payload_length(16); + ipv6_packet.set_next_header(IpNextHeaderProtocols::Icmpv6); + ipv6_packet.set_hop_limit(64); + ipv6_packet.set_source(src); + ipv6_packet.set_destination(dst); + + let mut icmp_packet = MutableIcmpv6Packet::new(&mut buf[40..]).unwrap(); + + icmp_packet.set_icmpv6_type(Icmpv6Types::EchoRequest); + icmp_packet.set_icmpv6_code(Icmpv6Code::new(0)); // No code for echo request + + let mut echo_request_packet = + MutableEchoRequestPacket::new(icmp_packet.payload_mut()).unwrap(); + echo_request_packet.set_identifier(0); + echo_request_packet.set_sequence_number(1); + echo_request_packet.set_checksum(0); + + let checksum = crate::icmpv6::checksum(&icmp_packet.to_immutable(), &src, &dst); + MutableEchoRequestPacket::new(icmp_packet.payload_mut()) + .unwrap() + .set_checksum(checksum); + + MutableIpPacket::owned(buf).unwrap() + } + (IpAddr::V6(_), IpAddr::V4(_)) | (IpAddr::V4(_), IpAddr::V6(_)) => { + panic!("IPs must be of the same version") + } + } +} diff --git a/rust/snownet-tests/Cargo.toml b/rust/snownet-tests/Cargo.toml index ac9f350a5..99206decd 100644 --- a/rust/snownet-tests/Cargo.toml +++ b/rust/snownet-tests/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" anyhow = "1" boringtun = { workspace = true } snownet = { workspace = true } +ip-packet = { workspace = true } futures = "0.3" hex = "0.4" pnet_packet = { version = "0.34" } diff --git a/rust/snownet-tests/src/main.rs b/rust/snownet-tests/src/main.rs index bb617ff23..bb6e22221 100644 --- a/rust/snownet-tests/src/main.rs +++ b/rust/snownet-tests/src/main.rs @@ -9,10 +9,11 @@ use std::{ use anyhow::{bail, Context as _, Result}; use boringtun::x25519::{PublicKey, StaticSecret}; use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt}; +use ip_packet::IpPacket; use pnet_packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet}; use redis::{aio::MultiplexedConnection, AsyncCommands}; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{Answer, ClientNode, Credentials, IpPacket, Node, Offer, ServerNode}; +use snownet::{Answer, ClientNode, Credentials, Node, Offer, ServerNode}; use tokio::{io::ReadBuf, net::UdpSocket}; use tracing_subscriber::EnvFilter;