From 35b28692de30bfd288f545e9813fd837a74282e3 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 19 Nov 2025 10:48:04 +1100 Subject: [PATCH] feat(gateway): improve state tracking of DNS resource NAT (#10868) Right now, the state tracking within the DNS resource NAT table is pretty simple: - We map from inside to outside and back - When we see a TCP RST, we remove it immediately To improve our logs a bit and make the NAT table more robust, we extend it by: - Tracking last inbound and outbound packet - Tracking FIN and RST flags This allows us to fully observe e.g. a TCP shutdown where both parties send TCP FIN. It also allows us to remove entries that have never been confirmed after a shorter amount of time. Resolves: #10795 --------- Signed-off-by: Thomas Eizinger Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- rust/connlib/tunnel/src/gateway/nat_table.rs | 243 ++++++++++++++----- 1 file changed, 181 insertions(+), 62 deletions(-) diff --git a/rust/connlib/tunnel/src/gateway/nat_table.rs b/rust/connlib/tunnel/src/gateway/nat_table.rs index d7ca033d7..256e7b4af 100644 --- a/rust/connlib/tunnel/src/gateway/nat_table.rs +++ b/rust/connlib/tunnel/src/gateway/nat_table.rs @@ -16,32 +16,66 @@ use std::time::{Duration, Instant}; /// Thus, purely an L3 NAT would not be sufficient as it would be impossible to map back to the proxy IP. #[derive(Default, Debug)] pub(crate) struct NatTable { - pub(crate) table: BiMap<(Protocol, IpAddr), (Protocol, IpAddr)>, - pub(crate) last_seen: BTreeMap<(Protocol, IpAddr), Instant>, + table: BiMap<Inside, Outside>, + state_by_inside: BTreeMap<Inside, EntryState>, // We don't bother with proactively freeing this because a single entry is only ~20 bytes and it gets cleanup once the connection to the client goes away. 
- expired: HashSet<(Protocol, IpAddr)>, + expired: HashSet<Outside>, +} + +#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)] +struct Inside(Protocol, IpAddr); + +impl Inside { + fn into_inner(self) -> (Protocol, IpAddr) { + (self.0, self.1) + } +} + +#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)] +struct Outside(Protocol, IpAddr); + +impl Outside { + fn into_inner(self) -> (Protocol, IpAddr) { + (self.0, self.1) + } } pub(crate) const TCP_TTL: Duration = Duration::from_secs(60 * 60 * 2); pub(crate) const UDP_TTL: Duration = Duration::from_secs(60 * 2); pub(crate) const ICMP_TTL: Duration = Duration::from_secs(60 * 2); +pub(crate) const UNCONFIRMED_TTL: Duration = Duration::from_secs(60); + impl NatTable { pub(crate) fn handle_timeout(&mut self, now: Instant) { - for (outside, e) in self.last_seen.iter() { - let ttl = match outside.0 { - Protocol::Tcp(_) => TCP_TTL, - Protocol::Udp(_) => UDP_TTL, - Protocol::Icmp(_) => ICMP_TTL, + let expired = self.state_by_inside.extract_if(.., |inside, state| { + state + .remove_at(inside.0) + .is_some_and(|remove_at| now >= remove_at) + }); + + for (inside, state) in expired { + let Some((_, outside)) = self.table.remove_by_left(&inside) else { + continue; }; - if now.duration_since(*e) >= ttl - && let Some((inside, _)) = self.table.remove_by_right(outside) - { - tracing::debug!(?inside, ?outside, ?ttl, "NAT session expired"); - self.expired.insert(*outside); - } + self.expired.insert(outside); + + let last_outgoing = now.duration_since(state.last_outgoing); + let last_incoming = state.last_incoming.map(|t| now.duration_since(t)); + + tracing::debug!( + ?inside, + ?outside, + ?last_outgoing, + ?last_incoming, + fin_tx = %state.outgoing_fin, + fin_rx = %state.incoming_fin, + rst_tx = %state.outgoing_rst, + rst_rx = %state.incoming_rst, + "NAT entry removed" + ); } } @@ -54,43 +88,41 @@ impl NatTable { let src = packet.source_protocol()?; let dst = packet.destination(); - let inside = (src, dst); + 
let inside = Inside(src, dst); - if let Some(outside) = self.table.get_by_left(&inside).copied() { - tracing::trace!(?inside, ?outside, "Translating outgoing packet"); + if let Some(outside) = self.table.get_by_left(&inside).copied() + && let Some(state) = self.state_by_inside.get_mut(&inside) + { + tracing::trace!(?inside, ?outside, ?state, "Translating outgoing packet"); if packet.as_tcp().is_some_and(|tcp| tcp.rst()) { - tracing::debug!( - ?inside, - ?outside, - "Witnessed outgoing TCP RST, removing NAT session" - ); - - self.table.remove_by_left(&inside); - self.expired.insert(outside); + state.outgoing_rst = true; } - self.last_seen.insert(outside, now); - return Ok(outside); + if packet.as_tcp().is_some_and(|tcp| tcp.fin()) { + state.outgoing_fin = true; + } + + state.last_outgoing = now; + + return Ok(outside.into_inner()); } // Find the first available public port, starting from the port of the to-be-mapped packet. // This will re-assign the same port in most cases, even after the mapping expires. let outside = (src.value()..=u16::MAX) .chain(1..src.value()) - .map(|p| (src.with_value(p), outside_dst)) + .map(|p| Outside(src.with_value(p), outside_dst)) .find(|outside| !self.table.contains_right(outside)) .context("Exhausted NAT")?; - let inside = (src, dst); - self.table.insert(inside, outside); - self.last_seen.insert(outside, now); + self.state_by_inside.insert(inside, EntryState::new(now)); self.expired.remove(&outside); tracing::debug!(?inside, ?outside, "New NAT session"); - Ok(outside) + Ok(outside.into_inner()) } pub(crate) fn translate_incoming( @@ -99,9 +131,11 @@ impl NatTable { now: Instant, ) -> Result<TranslateIncomingResult> { if let Some((failed_packet, icmp_error)) = packet.icmp_error()? 
{ - let outside = (failed_packet.src_proto(), failed_packet.dst()); + let outside = Outside(failed_packet.src_proto(), failed_packet.dst()); - if let Some((inside_proto, inside_dst)) = self.translate_incoming_inner(&outside, now) { + if let Some(Inside(inside_proto, inside_dst)) = + self.translate_incoming_inner(&outside, now) + { return Ok(TranslateIncomingResult::IcmpError(IcmpErrorPrototype { inside_dst, inside_proto, @@ -117,21 +151,20 @@ impl NatTable { return Ok(TranslateIncomingResult::NoNatSession); } - let outside = (packet.destination_protocol()?, packet.source()); + let outside = Outside(packet.destination_protocol()?, packet.source()); - if let Some(inside) = self.translate_incoming_inner(&outside, now) { + if let Some(inside) = self.translate_incoming_inner(&outside, now) + && let Some(state) = self.state_by_inside.get_mut(&inside) + { if packet.as_tcp().is_some_and(|tcp| tcp.rst()) { - tracing::debug!( - ?inside, - ?outside, - "Witnessed incoming TCP RST, removing NAT session" - ); - - self.table.remove_by_right(&outside); - self.expired.insert(outside); + state.incoming_rst = true; } - let (proto, src) = inside; + if packet.as_tcp().is_some_and(|tcp| tcp.fin()) { + state.incoming_fin = true; + } + + let (proto, src) = inside.into_inner(); return Ok(TranslateIncomingResult::Ok { proto, src }); } @@ -143,20 +176,96 @@ impl NatTable { Ok(TranslateIncomingResult::NoNatSession) } - fn translate_incoming_inner( - &mut self, - outside: &(Protocol, IpAddr), - now: Instant, - ) -> Option<(Protocol, IpAddr)> { + fn translate_incoming_inner(&mut self, outside: &Outside, now: Instant) -> Option<Inside> { let inside = self.table.get_by_right(outside)?; + let state = self.state_by_inside.get_mut(inside)?; - tracing::trace!(?inside, ?outside, "Translating incoming packet"); - self.last_seen.insert(*inside, now); + tracing::trace!(?inside, ?outside, ?state, "Translating incoming packet"); + + let prev_last_incoming = state.last_incoming.replace(now); + if 
prev_last_incoming.is_none() { + tracing::debug!(?inside, ?outside, "NAT session confirmed"); + } Some(*inside) } } +#[derive(Debug)] +struct EntryState { + last_outgoing: Instant, + last_incoming: Option<Instant>, + + outgoing_rst: bool, + incoming_rst: bool, + outgoing_fin: bool, + incoming_fin: bool, +} + +impl EntryState { + fn new(last_outgoing: Instant) -> Self { + Self { + last_outgoing, + last_incoming: None, + outgoing_rst: false, + incoming_rst: false, + outgoing_fin: false, + incoming_fin: false, + } + } + + fn ttl_timeout(&self, protocol: Protocol) -> Instant { + let ttl = match protocol { + Protocol::Tcp(_) => TCP_TTL, + Protocol::Udp(_) => UDP_TTL, + Protocol::Icmp(_) => ICMP_TTL, + }; + + self.last_packet() + ttl + } + + fn unconfirmed_timeout(&self) -> Option<Instant> { + if self.last_incoming.is_some() { + return None; + } + + Some(self.last_outgoing + UNCONFIRMED_TTL) + } + + fn fin_timeout(&self) -> Option<Instant> { + if !self.outgoing_fin || !self.incoming_fin { + return None; + } + + Some(self.last_packet() + Duration::from_secs(5)) // Keep NAT open for a few more seconds. + } + + fn rst_timeout(&self) -> Option<Instant> { + if !self.outgoing_rst && !self.incoming_rst { + return None; + } + + Some(self.last_packet()) // Close immediately. + } + + fn remove_at(&self, protocol: Protocol) -> Option<Instant> { + std::iter::empty() + .chain(Some(self.ttl_timeout(protocol))) + .chain(self.unconfirmed_timeout()) + .chain(self.fin_timeout()) + .chain(self.rst_timeout()) + .min() + } + + fn last_packet(&self) -> Instant { + let Some(last_incoming) = self.last_incoming else { + return self.last_outgoing; + }; + + std::cmp::max(self.last_outgoing, last_incoming) + } +} + /// A prototype for an ICMP error packet. /// /// A packet coming in from the "outside" of the NAT may be an ICMP error. @@ -259,9 +368,15 @@ mod tests { response.set_src(new_dst_ip).unwrap(); // Update time. 
- table.handle_timeout(sent_at + response_delay); + table.handle_timeout(sent_at + Duration::from_secs(1)); - // Translate in + // Confirm mapping + table + .translate_incoming(&response.clone(), sent_at + Duration::from_secs(1)) + .unwrap(); + + // Simulate another packet after _response_delay_ + table.handle_timeout(sent_at + response_delay); let translate_incoming = table .translate_incoming(&response, sent_at + response_delay) .unwrap(); @@ -352,16 +467,17 @@ mod tests { rst.set_dst(req.destination()).unwrap(); let mut table = NatTable::default(); + let mut now = Instant::now(); - let outside = table - .translate_outgoing(&req, outside_dst, Instant::now()) - .unwrap(); + let outside = table.translate_outgoing(&req, outside_dst, now).unwrap(); let mut response = req.clone(); response.set_destination_protocol(outside.0.value()); response.set_src(outside.1).unwrap(); - match table.translate_incoming(&response, Instant::now()).unwrap() { + now += Duration::from_secs(1); + + match table.translate_incoming(&response, now).unwrap() { TranslateIncomingResult::Ok { .. } => {} result @ (TranslateIncomingResult::NoNatSession | TranslateIncomingResult::ExpiredNatSession @@ -370,11 +486,14 @@ mod tests { } }; - table - .translate_outgoing(&rst, outside_dst, Instant::now()) - .unwrap(); + now += Duration::from_secs(1); - match table.translate_incoming(&response, Instant::now()).unwrap() { + table.translate_outgoing(&rst, outside_dst, now).unwrap(); + + now += Duration::from_secs(1); + table.handle_timeout(now); + + match table.translate_incoming(&response, now).unwrap() { TranslateIncomingResult::ExpiredNatSession => {} result @ (TranslateIncomingResult::NoNatSession | TranslateIncomingResult::Ok { .. }