diff --git a/rust/connlib/ip-packet/src/make.rs b/rust/connlib/ip-packet/src/make.rs index 4be25cfdd..b639d45dc 100644 --- a/rust/connlib/ip-packet/src/make.rs +++ b/rust/connlib/ip-packet/src/make.rs @@ -109,28 +109,44 @@ pub fn tcp_packet( daddr: IP, sport: u16, dport: u16, + flags: TcpFlags, payload: Vec, ) -> Result where IP: Into, { + let TcpFlags { rst } = flags; + match (saddr.into(), daddr.into()) { (IpAddr::V4(src), IpAddr::V4(dst)) => { - let packet = + let mut packet = PacketBuilder::ipv4(src.octets(), dst.octets(), 64).tcp(sport, dport, 0, 128); + if rst { + packet = packet.rst(); + } + build!(packet, payload) } (IpAddr::V6(src), IpAddr::V6(dst)) => { - let packet = + let mut packet = PacketBuilder::ipv6(src.octets(), dst.octets(), 64).tcp(sport, dport, 0, 128); + if rst { + packet = packet.rst(); + } + build!(packet, payload) } _ => bail!(IpVersionMismatch), } } +#[derive(Debug, Default, Clone, Copy)] +pub struct TcpFlags { + pub rst: bool, +} + pub fn udp_packet( saddr: IP, daddr: IP, diff --git a/rust/connlib/ip-packet/src/proptest.rs b/rust/connlib/ip-packet/src/proptest.rs index 446192b0f..7e8d9b86b 100644 --- a/rust/connlib/ip-packet/src/proptest.rs +++ b/rust/connlib/ip-packet/src/proptest.rs @@ -1,5 +1,5 @@ -use crate::IpPacket; -use proptest::{arbitrary::any, prop_oneof, strategy::Strategy}; +use crate::{IpPacket, make::TcpFlags}; +use proptest::{arbitrary::any, prelude::Just, prop_oneof, strategy::Strategy}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; pub fn udp_packet() -> impl Strategy { @@ -13,14 +13,20 @@ pub fn udp_packet() -> impl Strategy { ] } -pub fn tcp_packet() -> impl Strategy { +pub fn tcp_packet( + flags: impl Strategy + Clone, +) -> impl Strategy { prop_oneof![ - (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { - crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap() - }), - (ip6_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { - crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap() - }), + (ip4_tuple(), any::(), any::(), flags.clone()).prop_map( + |((saddr, daddr), sport, dport, flags)| { + crate::make::tcp_packet(saddr, daddr, sport, dport, flags, Vec::new()).unwrap() + } + ), + (ip6_tuple(), any::(), any::(), flags).prop_map( + |((saddr, daddr), sport, dport, flags)| { + crate::make::tcp_packet(saddr, daddr, sport, dport, flags, Vec::new()).unwrap() + } + ), ] } @@ -36,7 +42,11 @@ pub fn icmp_request_packet() -> impl Strategy { } pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy { - prop_oneof![udp_packet(), tcp_packet(), icmp_request_packet()] + prop_oneof![ + udp_packet(), + tcp_packet(Just(TcpFlags::default())), + icmp_request_packet() + ] } fn ip4_tuple() -> impl Strategy { diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 6157fc3dc..5cefee1a4 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -688,6 +688,7 @@ mod tests { use chrono::Utc; use connlib_model::{ClientId, ResourceId}; use ip_network::{IpNetwork, Ipv4Network}; + use ip_packet::make::TcpFlags; use super::ClientOnGateway; @@ -727,6 +728,7 @@ mod tests { cidr_v4_resource().hosts().next().unwrap(), 5401, 80, + TcpFlags::default(), vec![0; 100], ) .unwrap(); @@ -801,6 +803,7 @@ mod tests { gateway_tun_ipv4(), 5401, 80, + TcpFlags::default(), vec![0; 100], ) .unwrap(); @@ -810,6 +813,7 @@ mod tests { client_tun_ipv4(), 80, 5401, + TcpFlags::default(), vec![0; 100], ) .unwrap(); @@ -1188,7 +1192,7 @@ mod proptests { Filter, PortRange, ResourceDescription, ResourceDescriptionCidr, }; use crate::proptest::*; - use ip_packet::make::{icmp_request_packet, tcp_packet, udp_packet}; + use ip_packet::make::{TcpFlags, icmp_request_packet, tcp_packet, udp_packet}; use itertools::Itertools as _; use proptest::{ arbitrary::any, @@ -1235,7 +1239,14 @@ mod proptests { }; let packet = match protocol { - Protocol::Tcp { dport } => tcp_packet(src, *dest, sport, *dport, payload.clone()), + Protocol::Tcp { dport } => tcp_packet( + src, + *dest, + sport, + *dport, + TcpFlags::default(), + payload.clone(), + ), Protocol::Udp { dport } => udp_packet(src, *dest, sport, *dport, payload.clone()), Protocol::Icmp => icmp_request_packet(src, *dest, 1, 0, &[]), } @@ -1290,7 +1301,14 @@ mod proptests { for (_, protocol) in protocol_config { let packet = match protocol { - Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Tcp { dport } => tcp_packet( + src, + dest, + sport, + dport, + TcpFlags::default(), + payload.clone(), + ), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } @@ -1331,7 +1349,9 @@ mod proptests { gateway_tun(), ); let packet = match protocol { - Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload), + Protocol::Tcp { dport } => { + tcp_packet(src, dest, sport, dport, TcpFlags::default(), payload) + } Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload), Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } @@ -1387,14 +1407,23 @@ mod proptests { ); let packet_allowed = match protocol_allowed { - Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Tcp { dport } => tcp_packet( + src, + dest, + sport, + dport, + TcpFlags::default(), + payload.clone(), + ), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } .unwrap(); let packet_rejected = match protocol_removed { - Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload), + Protocol::Tcp { dport } => { + tcp_packet(src, dest, sport, dport, TcpFlags::default(), payload) + } Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload), Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]), } diff --git a/rust/connlib/tunnel/src/peer/nat_table.rs b/rust/connlib/tunnel/src/peer/nat_table.rs index d6cd22088..8a3c79a67 100644 --- a/rust/connlib/tunnel/src/peer/nat_table.rs +++ b/rust/connlib/tunnel/src/peer/nat_table.rs @@ -61,12 +61,23 @@ impl NatTable { let inside = (src, dst); - if let Some(outside) = self.table.get_by_left(&inside) { + if let Some(outside) = self.table.get_by_left(&inside).copied() { if outside.1 == outside_dst { tracing::trace!(?inside, ?outside, "Translating outgoing packet"); - self.last_seen.insert(*outside, now); - return Ok(*outside); + if packet.as_tcp().is_some_and(|tcp| tcp.rst()) { + tracing::debug!( + ?inside, + ?outside, + "Witnessed outgoing TCP RST, removing NAT session" + ); + + self.table.remove_by_left(&inside); + self.expired.insert(outside); + } + + self.last_seen.insert(outside, now); + return Ok(outside); } tracing::trace!(?inside, ?outside, "Outgoing packet for expired translation"); @@ -84,6 +95,7 @@ impl NatTable { self.table.insert(inside, outside); self.last_seen.insert(outside, now); + self.expired.remove(&outside); tracing::debug!(?inside, ?outside, "New NAT session"); @@ -118,7 +130,20 @@ impl NatTable { let outside = (packet.destination_protocol()?, packet.source()); - if let Some((proto, src)) = self.translate_incoming_inner(&outside, now) { + if let Some(inside) = self.translate_incoming_inner(&outside, now) { + if packet.as_tcp().is_some_and(|tcp| tcp.rst()) { + tracing::debug!( + ?inside, + ?outside, + "Witnessed incoming TCP RST, removing NAT session" + ); + + self.table.remove_by_right(&outside); + self.expired.insert(outside); + } + + let (proto, src) = inside; + return Ok(TranslateIncomingResult::Ok { proto, src }); } @@ -215,7 +240,7 @@ pub enum TranslateIncomingResult { #[cfg(all(test, feature = "proptest"))] mod tests { use super::*; - use ip_packet::{IpPacket, proptest::*}; + use ip_packet::{IpPacket, make::TcpFlags, proptest::*}; use proptest::prelude::*; #[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })] @@ -322,4 +347,51 @@ mod tests { assert_eq!(responses, original_src_p_and_dst); } + + #[test_strategy::proptest] + fn outgoing_tcp_rst_removes_nat_mapping( + #[strategy(tcp_packet(Just(TcpFlags::default())))] req: IpPacket, + #[strategy(tcp_packet(Just(TcpFlags { rst: true })))] mut rst: IpPacket, + #[strategy(any::())] outside_dst: IpAddr, + ) { + let _guard = firezone_logging::test("trace"); + + proptest::prop_assume!(req.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response. + proptest::prop_assume!(rst.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response. + rst.set_source_protocol(req.source_protocol().unwrap().value()); + rst.set_destination_protocol(req.destination_protocol().unwrap().value()); + rst.set_dst(req.destination()).unwrap(); + + let mut table = NatTable::default(); + + let outside = table + .translate_outgoing(&req, outside_dst, Instant::now()) + .unwrap(); + + let mut response = req.clone(); + response.set_destination_protocol(outside.0.value()); + response.set_src(outside.1).unwrap(); + + match table.translate_incoming(&response, Instant::now()).unwrap() { + TranslateIncomingResult::Ok { .. } => {} + result @ (TranslateIncomingResult::NoNatSession + | TranslateIncomingResult::ExpiredNatSession + | TranslateIncomingResult::DestinationUnreachable(_)) => { + panic!("Wrong result: {result:?}") + } + }; + + table + .translate_outgoing(&rst, outside_dst, Instant::now()) + .unwrap(); + + match table.translate_incoming(&response, Instant::now()).unwrap() { + TranslateIncomingResult::ExpiredNatSession => {} + result @ (TranslateIncomingResult::NoNatSession + | TranslateIncomingResult::Ok { .. } + | TranslateIncomingResult::DestinationUnreachable(_)) => { + panic!("Wrong result: {result:?}") + } + }; + } } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 8686ae73a..08d887e3d 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -20,6 +20,7 @@ use bufferpool::BufferPool; use connlib_model::{ClientId, GatewayId, PublicKey, RelayId}; use dns_types::ResponseCode; use dns_types::prelude::*; +use ip_packet::make::TcpFlags; use rand::SeedableRng; use rand::distributions::DistString; use sha2::Digest; @@ -199,6 +200,7 @@ impl TunnelTest { dst, sport.0, dport.0, + TcpFlags::default(), payload.to_be_bytes().to_vec(), ) .unwrap();