diff --git a/rust/connlib/libs/tunnel/src/dns.rs b/rust/connlib/libs/tunnel/src/dns.rs index cafaf8299..8514eadc2 100644 --- a/rust/connlib/libs/tunnel/src/dns.rs +++ b/rust/connlib/libs/tunnel/src/dns.rs @@ -60,7 +60,7 @@ where let mut pkt = MutableIpPacket::new(&mut res_buf)?; let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?); pkt.as_udp()?.set_checksum(udp_checksum); - pkt.set_checksum(); + pkt.set_ipv4_checksum(); Some(res_buf) } diff --git a/rust/connlib/libs/tunnel/src/ip_packet.rs b/rust/connlib/libs/tunnel/src/ip_packet.rs index a1484d63f..0d988c287 100644 --- a/rust/connlib/libs/tunnel/src/ip_packet.rs +++ b/rust/connlib/libs/tunnel/src/ip_packet.rs @@ -4,9 +4,10 @@ use domain::base::message::Message; use pnet_packet::{ icmpv6::{self, MutableIcmpv6Packet}, ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}, - ipv4::{checksum, Ipv4Packet, MutableIpv4Packet}, + ipv4::{self, Ipv4Packet, MutableIpv4Packet}, ipv6::{Ipv6Packet, MutableIpv6Packet}, - udp::{ipv4_checksum, ipv6_checksum, MutableUdpPacket, UdpPacket}, + tcp::{self, MutableTcpPacket, TcpPacket}, + udp::{self, MutableUdpPacket, UdpPacket}, MutablePacket, Packet, PacketSize, }; @@ -37,12 +38,46 @@ impl<'a> MutableIpPacket<'a> { } } - pub(crate) fn set_checksum(&mut self) { + pub(crate) fn update_checksum(&mut self) { + // Note: neither ipv6 nor icmp have a checksum. + self.set_icmpv6_checksum(); + self.set_udp_checksum(); + self.set_tcp_checksum(); + // Note: Ipv4 checksum should be set after the others, + // since it's in an upper layer. + self.set_ipv4_checksum(); + } + + pub(crate) fn set_ipv4_checksum(&mut self) { if let Self::MutableIpv4Packet(p) = self { - p.set_checksum(checksum(&p.to_immutable())); + p.set_checksum(ipv4::checksum(&p.to_immutable())); } } + fn set_udp_checksum(&mut self) { + let checksum = if let Some(p) = self.as_immutable_udp() { + self.to_immutable().udp_checksum(&p.to_immutable()) + } else { + return; + }; + + self.as_udp() + .expect("Developer error: we can only get a checksum if the packet is udp") + .set_checksum(checksum); + } + + fn set_tcp_checksum(&mut self) { + let checksum = if let Some(p) = self.as_immutable_tcp() { + self.to_immutable().tcp_checksum(&p.to_immutable()) + } else { + return; + }; + + self.as_tcp() + .expect("Developer error: we can only get a checksum if the packet is tcp") + .set_checksum(checksum); + } + pub(crate) fn to_immutable(&self) -> IpPacket { match self { Self::MutableIpv4Packet(p) => p.to_immutable().into(), @@ -57,7 +92,14 @@ impl<'a> MutableIpPacket<'a> { .flatten() } - pub fn set_icmpv6_checksum(&mut self) { + fn as_tcp(&mut self) -> Option { + self.to_immutable() + .is_tcp() + .then(|| MutableTcpPacket::new(self.payload_mut())) + .flatten() + } + + fn set_icmpv6_checksum(&mut self) { let (src_addr, dst_addr) = match self { MutableIpPacket::MutableIpv4Packet(_) => return, MutableIpPacket::MutableIpv6Packet(p) => (p.get_source(), p.get_destination()), @@ -82,6 +124,13 @@ impl<'a> MutableIpPacket<'a> { .flatten() } + pub(crate) fn as_immutable_tcp(&self) -> Option { + self.to_immutable() + .is_tcp() + .then(|| TcpPacket::new(self.payload())) + .flatten() + } + pub(crate) fn swap_src_dst(&mut self) { match self { Self::MutableIpv4Packet(p) => { @@ -152,6 +201,10 @@ impl<'a> IpPacket<'a> { self.next_header() == IpNextHeaderProtocols::Udp } + fn is_tcp(&self) -> bool { + self.next_header() == IpNextHeaderProtocols::Tcp + } + pub(crate) fn as_udp(&self) -> Option { self.is_udp() .then(|| UdpPacket::new(self.payload())) @@ -174,8 +227,15 @@ impl<'a> IpPacket<'a> { pub(crate) fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 { match self { - Self::Ipv4Packet(p) => ipv4_checksum(dgm, &p.get_source(), &p.get_destination()), - Self::Ipv6Packet(p) => ipv6_checksum(dgm, &p.get_source(), &p.get_destination()), + Self::Ipv4Packet(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()), + Self::Ipv6Packet(p) => udp::ipv6_checksum(dgm, &p.get_source(), &p.get_destination()), + } + } + + fn tcp_checksum(&self, pkt: &TcpPacket<'_>) -> u16 { + match self { + Self::Ipv4Packet(p) => tcp::ipv4_checksum(pkt, &p.get_source(), &p.get_destination()), + Self::Ipv6Packet(p) => tcp::ipv6_checksum(pkt, &p.get_source(), &p.get_destination()), } } } diff --git a/rust/connlib/libs/tunnel/src/lib.rs b/rust/connlib/libs/tunnel/src/lib.rs index f41728269..06b411733 100644 --- a/rust/connlib/libs/tunnel/src/lib.rs +++ b/rust/connlib/libs/tunnel/src/lib.rs @@ -559,8 +559,7 @@ where } } - packet.set_checksum(); - packet.set_icmpv6_checksum(); + packet.update_checksum(); } ( peer.tunnel.lock().encapsulate(&src[..res], &mut dst[..]), diff --git a/rust/connlib/libs/tunnel/src/resource_sender.rs b/rust/connlib/libs/tunnel/src/resource_sender.rs index 0d30a69c5..2da833906 100644 --- a/rust/connlib/libs/tunnel/src/resource_sender.rs +++ b/rust/connlib/libs/tunnel/src/resource_sender.rs @@ -19,16 +19,15 @@ where async fn update_and_send_packet(&self, packet: &mut [u8], dst_addr: IpAddr) { let Some(mut pkt) = MutableIpPacket::new(packet) else { return }; pkt.set_dst(dst_addr); - pkt.set_checksum(); - pkt.set_icmpv6_checksum(); + pkt.update_checksum(); match dst_addr { IpAddr::V4(addr) => { - tracing::trace!("Sending to packet to {addr}"); + tracing::trace!("Sending packet to {addr}"); self.write4_device_infallible(packet).await; } IpAddr::V6(addr) => { - tracing::trace!("Sending to packet to {addr}"); + tracing::trace!("Sending packet to {addr}"); self.write6_device_infallible(packet).await; } }