feat(gateway): free TCP NAT bindings on RSTs (#9682)

Whenever we see a TCP packet with the RST bit set, we clear the current
NAT binding and move it to the `expired` list.
This commit is contained in:
Thomas Eizinger
2025-06-26 15:20:01 +01:00
committed by GitHub
parent 1acfcd5678
commit 5f38ccaeab
5 changed files with 152 additions and 23 deletions

View File

@@ -109,28 +109,44 @@ pub fn tcp_packet<IP>(
daddr: IP,
sport: u16,
dport: u16,
flags: TcpFlags,
payload: Vec<u8>,
) -> Result<IpPacket>
where
IP: Into<IpAddr>,
{
let TcpFlags { rst } = flags;
match (saddr.into(), daddr.into()) {
(IpAddr::V4(src), IpAddr::V4(dst)) => {
let packet =
let mut packet =
PacketBuilder::ipv4(src.octets(), dst.octets(), 64).tcp(sport, dport, 0, 128);
if rst {
packet = packet.rst();
}
build!(packet, payload)
}
(IpAddr::V6(src), IpAddr::V6(dst)) => {
let packet =
let mut packet =
PacketBuilder::ipv6(src.octets(), dst.octets(), 64).tcp(sport, dport, 0, 128);
if rst {
packet = packet.rst();
}
build!(packet, payload)
}
_ => bail!(IpVersionMismatch),
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct TcpFlags {
pub rst: bool,
}
pub fn udp_packet<IP>(
saddr: IP,
daddr: IP,

View File

@@ -1,5 +1,5 @@
use crate::IpPacket;
use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};
use crate::{IpPacket, make::TcpFlags};
use proptest::{arbitrary::any, prelude::Just, prop_oneof, strategy::Strategy};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
pub fn udp_packet() -> impl Strategy<Value = IpPacket> {
@@ -13,14 +13,20 @@ pub fn udp_packet() -> impl Strategy<Value = IpPacket> {
]
}
pub fn tcp_packet() -> impl Strategy<Value = IpPacket> {
pub fn tcp_packet(
flags: impl Strategy<Value = TcpFlags> + Clone,
) -> impl Strategy<Value = IpPacket> {
prop_oneof![
(ip4_tuple(), any::<u16>(), any::<u16>()).prop_map(|((saddr, daddr), sport, dport)| {
crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap()
}),
(ip6_tuple(), any::<u16>(), any::<u16>()).prop_map(|((saddr, daddr), sport, dport)| {
crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap()
}),
(ip4_tuple(), any::<u16>(), any::<u16>(), flags.clone()).prop_map(
|((saddr, daddr), sport, dport, flags)| {
crate::make::tcp_packet(saddr, daddr, sport, dport, flags, Vec::new()).unwrap()
}
),
(ip6_tuple(), any::<u16>(), any::<u16>(), flags).prop_map(
|((saddr, daddr), sport, dport, flags)| {
crate::make::tcp_packet(saddr, daddr, sport, dport, flags, Vec::new()).unwrap()
}
),
]
}
@@ -36,7 +42,11 @@ pub fn icmp_request_packet() -> impl Strategy<Value = IpPacket> {
}
pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy<Value = IpPacket> {
prop_oneof![udp_packet(), tcp_packet(), icmp_request_packet()]
prop_oneof![
udp_packet(),
tcp_packet(Just(TcpFlags::default())),
icmp_request_packet()
]
}
fn ip4_tuple() -> impl Strategy<Value = (Ipv4Addr, Ipv4Addr)> {

View File

@@ -688,6 +688,7 @@ mod tests {
use chrono::Utc;
use connlib_model::{ClientId, ResourceId};
use ip_network::{IpNetwork, Ipv4Network};
use ip_packet::make::TcpFlags;
use super::ClientOnGateway;
@@ -727,6 +728,7 @@ mod tests {
cidr_v4_resource().hosts().next().unwrap(),
5401,
80,
TcpFlags::default(),
vec![0; 100],
)
.unwrap();
@@ -801,6 +803,7 @@ mod tests {
gateway_tun_ipv4(),
5401,
80,
TcpFlags::default(),
vec![0; 100],
)
.unwrap();
@@ -810,6 +813,7 @@ mod tests {
client_tun_ipv4(),
80,
5401,
TcpFlags::default(),
vec![0; 100],
)
.unwrap();
@@ -1188,7 +1192,7 @@ mod proptests {
Filter, PortRange, ResourceDescription, ResourceDescriptionCidr,
};
use crate::proptest::*;
use ip_packet::make::{icmp_request_packet, tcp_packet, udp_packet};
use ip_packet::make::{TcpFlags, icmp_request_packet, tcp_packet, udp_packet};
use itertools::Itertools as _;
use proptest::{
arbitrary::any,
@@ -1235,7 +1239,14 @@ mod proptests {
};
let packet = match protocol {
Protocol::Tcp { dport } => tcp_packet(src, *dest, sport, *dport, payload.clone()),
Protocol::Tcp { dport } => tcp_packet(
src,
*dest,
sport,
*dport,
TcpFlags::default(),
payload.clone(),
),
Protocol::Udp { dport } => udp_packet(src, *dest, sport, *dport, payload.clone()),
Protocol::Icmp => icmp_request_packet(src, *dest, 1, 0, &[]),
}
@@ -1290,7 +1301,14 @@ mod proptests {
for (_, protocol) in protocol_config {
let packet = match protocol {
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()),
Protocol::Tcp { dport } => tcp_packet(
src,
dest,
sport,
dport,
TcpFlags::default(),
payload.clone(),
),
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()),
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
}
@@ -1331,7 +1349,9 @@ mod proptests {
gateway_tun(),
);
let packet = match protocol {
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload),
Protocol::Tcp { dport } => {
tcp_packet(src, dest, sport, dport, TcpFlags::default(), payload)
}
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload),
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
}
@@ -1387,14 +1407,23 @@ mod proptests {
);
let packet_allowed = match protocol_allowed {
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()),
Protocol::Tcp { dport } => tcp_packet(
src,
dest,
sport,
dport,
TcpFlags::default(),
payload.clone(),
),
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()),
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
}
.unwrap();
let packet_rejected = match protocol_removed {
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload),
Protocol::Tcp { dport } => {
tcp_packet(src, dest, sport, dport, TcpFlags::default(), payload)
}
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload),
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
}

View File

@@ -61,12 +61,23 @@ impl NatTable {
let inside = (src, dst);
if let Some(outside) = self.table.get_by_left(&inside) {
if let Some(outside) = self.table.get_by_left(&inside).copied() {
if outside.1 == outside_dst {
tracing::trace!(?inside, ?outside, "Translating outgoing packet");
self.last_seen.insert(*outside, now);
return Ok(*outside);
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
tracing::debug!(
?inside,
?outside,
"Witnessed outgoing TCP RST, removing NAT session"
);
self.table.remove_by_left(&inside);
self.expired.insert(outside);
}
self.last_seen.insert(outside, now);
return Ok(outside);
}
tracing::trace!(?inside, ?outside, "Outgoing packet for expired translation");
@@ -84,6 +95,7 @@ impl NatTable {
self.table.insert(inside, outside);
self.last_seen.insert(outside, now);
self.expired.remove(&outside);
tracing::debug!(?inside, ?outside, "New NAT session");
@@ -118,7 +130,20 @@ impl NatTable {
let outside = (packet.destination_protocol()?, packet.source());
if let Some((proto, src)) = self.translate_incoming_inner(&outside, now) {
if let Some(inside) = self.translate_incoming_inner(&outside, now) {
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
tracing::debug!(
?inside,
?outside,
"Witnessed incoming TCP RST, removing NAT session"
);
self.table.remove_by_right(&outside);
self.expired.insert(outside);
}
let (proto, src) = inside;
return Ok(TranslateIncomingResult::Ok { proto, src });
}
@@ -215,7 +240,7 @@ pub enum TranslateIncomingResult {
#[cfg(all(test, feature = "proptest"))]
mod tests {
use super::*;
use ip_packet::{IpPacket, proptest::*};
use ip_packet::{IpPacket, make::TcpFlags, proptest::*};
use proptest::prelude::*;
#[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })]
@@ -322,4 +347,51 @@ mod tests {
assert_eq!(responses, original_src_p_and_dst);
}
#[test_strategy::proptest]
fn outgoing_tcp_rst_removes_nat_mapping(
#[strategy(tcp_packet(Just(TcpFlags::default())))] req: IpPacket,
#[strategy(tcp_packet(Just(TcpFlags { rst: true })))] mut rst: IpPacket,
#[strategy(any::<IpAddr>())] outside_dst: IpAddr,
) {
let _guard = firezone_logging::test("trace");
proptest::prop_assume!(req.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response.
proptest::prop_assume!(rst.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response.
rst.set_source_protocol(req.source_protocol().unwrap().value());
rst.set_destination_protocol(req.destination_protocol().unwrap().value());
rst.set_dst(req.destination()).unwrap();
let mut table = NatTable::default();
let outside = table
.translate_outgoing(&req, outside_dst, Instant::now())
.unwrap();
let mut response = req.clone();
response.set_destination_protocol(outside.0.value());
response.set_src(outside.1).unwrap();
match table.translate_incoming(&response, Instant::now()).unwrap() {
TranslateIncomingResult::Ok { .. } => {}
result @ (TranslateIncomingResult::NoNatSession
| TranslateIncomingResult::ExpiredNatSession
| TranslateIncomingResult::DestinationUnreachable(_)) => {
panic!("Wrong result: {result:?}")
}
};
table
.translate_outgoing(&rst, outside_dst, Instant::now())
.unwrap();
match table.translate_incoming(&response, Instant::now()).unwrap() {
TranslateIncomingResult::ExpiredNatSession => {}
result @ (TranslateIncomingResult::NoNatSession
| TranslateIncomingResult::Ok { .. }
| TranslateIncomingResult::DestinationUnreachable(_)) => {
panic!("Wrong result: {result:?}")
}
};
}
}

View File

@@ -20,6 +20,7 @@ use bufferpool::BufferPool;
use connlib_model::{ClientId, GatewayId, PublicKey, RelayId};
use dns_types::ResponseCode;
use dns_types::prelude::*;
use ip_packet::make::TcpFlags;
use rand::SeedableRng;
use rand::distributions::DistString;
use sha2::Digest;
@@ -199,6 +200,7 @@ impl TunnelTest {
dst,
sport.0,
dport.0,
TcpFlags::default(),
payload.to_be_bytes().to_vec(),
)
.unwrap();