feat(gateway): improve state tracking of DNS resource NAT (#10868)

Right now, the state tracking within the DNS resource NAT table is
pretty simple:

- We map from inside to outside and back
- When we see a TCP RST, we remove the entry immediately

To improve our logs a bit and make the NAT table more robust, we extend
it by:

- Tracking the time of the last inbound and outbound packet
- Tracking FIN and RST flags

This allows us to fully observe e.g. a TCP shutdown where both parties
send TCP FIN. It also allows us to remove entries that have never been
confirmed after a shorter amount of time.

Resolves: #10795

---------

Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Thomas Eizinger
2025-11-19 10:48:04 +11:00
committed by GitHub
parent f735855344
commit 35b28692de

View File

@@ -16,32 +16,66 @@ use std::time::{Duration, Instant};
/// Thus, purely an L3 NAT would not be sufficient as it would be impossible to map back to the proxy IP.
#[derive(Default, Debug)]
pub(crate) struct NatTable {
pub(crate) table: BiMap<(Protocol, IpAddr), (Protocol, IpAddr)>,
pub(crate) last_seen: BTreeMap<(Protocol, IpAddr), Instant>,
table: BiMap<Inside, Outside>,
state_by_inside: BTreeMap<Inside, EntryState>,
// We don't bother with proactively freeing this because a single entry is only ~20 bytes and it gets cleaned up once the connection to the client goes away.
expired: HashSet<(Protocol, IpAddr)>,
expired: HashSet<Outside>,
}
/// Key of a NAT entry on the "inside" of the NAT.
///
/// `translate_outgoing` builds it from an outgoing packet's source protocol
/// and its destination address. Wrapping the pair in a newtype keeps it from
/// being confused with an [`Outside`] key, which has the same shape.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)]
struct Inside(Protocol, IpAddr);

impl Inside {
    /// Unwraps the newtype into its raw `(Protocol, IpAddr)` pair.
    fn into_inner(self) -> (Protocol, IpAddr) {
        let Self(protocol, ip) = self;

        (protocol, ip)
    }
}
/// Key of a NAT entry on the "outside" of the NAT.
///
/// `translate_outgoing` builds it from the (possibly re-assigned) source
/// protocol value and the translated destination address. Wrapping the pair
/// in a newtype keeps it from being confused with an [`Inside`] key, which
/// has the same shape.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)]
struct Outside(Protocol, IpAddr);

impl Outside {
    /// Unwraps the newtype into its raw `(Protocol, IpAddr)` pair.
    fn into_inner(self) -> (Protocol, IpAddr) {
        let Self(protocol, ip) = self;

        (protocol, ip)
    }
}
/// How long a TCP entry may stay idle (no packet in either direction) before it is removed: 2 hours.
pub(crate) const TCP_TTL: Duration = Duration::from_secs(2 * 60 * 60);
/// How long a UDP entry may stay idle before it is removed: 2 minutes.
pub(crate) const UDP_TTL: Duration = Duration::from_secs(2 * 60);
/// How long an ICMP entry may stay idle before it is removed: 2 minutes.
pub(crate) const ICMP_TTL: Duration = Duration::from_secs(2 * 60);
/// How long an entry that has never seen an incoming packet is kept after its last outgoing packet: 1 minute.
pub(crate) const UNCONFIRMED_TTL: Duration = Duration::from_secs(60);
impl NatTable {
pub(crate) fn handle_timeout(&mut self, now: Instant) {
for (outside, e) in self.last_seen.iter() {
let ttl = match outside.0 {
Protocol::Tcp(_) => TCP_TTL,
Protocol::Udp(_) => UDP_TTL,
Protocol::Icmp(_) => ICMP_TTL,
let expired = self.state_by_inside.extract_if(.., |inside, state| {
state
.remove_at(inside.0)
.is_some_and(|remove_at| now >= remove_at)
});
for (inside, state) in expired {
let Some((_, outside)) = self.table.remove_by_left(&inside) else {
continue;
};
if now.duration_since(*e) >= ttl
&& let Some((inside, _)) = self.table.remove_by_right(outside)
{
tracing::debug!(?inside, ?outside, ?ttl, "NAT session expired");
self.expired.insert(*outside);
}
self.expired.insert(outside);
let last_outgoing = now.duration_since(state.last_outgoing);
let last_incoming = state.last_incoming.map(|t| now.duration_since(t));
tracing::debug!(
?inside,
?outside,
?last_outgoing,
?last_incoming,
fin_tx = %state.outgoing_fin,
fin_rx = %state.incoming_fin,
rst_tx = %state.outgoing_rst,
rst_rx = %state.incoming_rst,
"NAT entry removed"
);
}
}
@@ -54,43 +88,41 @@ impl NatTable {
let src = packet.source_protocol()?;
let dst = packet.destination();
let inside = (src, dst);
let inside = Inside(src, dst);
if let Some(outside) = self.table.get_by_left(&inside).copied() {
tracing::trace!(?inside, ?outside, "Translating outgoing packet");
if let Some(outside) = self.table.get_by_left(&inside).copied()
&& let Some(state) = self.state_by_inside.get_mut(&inside)
{
tracing::trace!(?inside, ?outside, ?state, "Translating outgoing packet");
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
tracing::debug!(
?inside,
?outside,
"Witnessed outgoing TCP RST, removing NAT session"
);
self.table.remove_by_left(&inside);
self.expired.insert(outside);
state.outgoing_rst = true;
}
self.last_seen.insert(outside, now);
return Ok(outside);
if packet.as_tcp().is_some_and(|tcp| tcp.fin()) {
state.outgoing_fin = true;
}
state.last_outgoing = now;
return Ok(outside.into_inner());
}
// Find the first available public port, starting from the port of the to-be-mapped packet.
// This will re-assign the same port in most cases, even after the mapping expires.
let outside = (src.value()..=u16::MAX)
.chain(1..src.value())
.map(|p| (src.with_value(p), outside_dst))
.map(|p| Outside(src.with_value(p), outside_dst))
.find(|outside| !self.table.contains_right(outside))
.context("Exhausted NAT")?;
let inside = (src, dst);
self.table.insert(inside, outside);
self.last_seen.insert(outside, now);
self.state_by_inside.insert(inside, EntryState::new(now));
self.expired.remove(&outside);
tracing::debug!(?inside, ?outside, "New NAT session");
Ok(outside)
Ok(outside.into_inner())
}
pub(crate) fn translate_incoming(
@@ -99,9 +131,11 @@ impl NatTable {
now: Instant,
) -> Result<TranslateIncomingResult> {
if let Some((failed_packet, icmp_error)) = packet.icmp_error()? {
let outside = (failed_packet.src_proto(), failed_packet.dst());
let outside = Outside(failed_packet.src_proto(), failed_packet.dst());
if let Some((inside_proto, inside_dst)) = self.translate_incoming_inner(&outside, now) {
if let Some(Inside(inside_proto, inside_dst)) =
self.translate_incoming_inner(&outside, now)
{
return Ok(TranslateIncomingResult::IcmpError(IcmpErrorPrototype {
inside_dst,
inside_proto,
@@ -117,21 +151,20 @@ impl NatTable {
return Ok(TranslateIncomingResult::NoNatSession);
}
let outside = (packet.destination_protocol()?, packet.source());
let outside = Outside(packet.destination_protocol()?, packet.source());
if let Some(inside) = self.translate_incoming_inner(&outside, now) {
if let Some(inside) = self.translate_incoming_inner(&outside, now)
&& let Some(state) = self.state_by_inside.get_mut(&inside)
{
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
tracing::debug!(
?inside,
?outside,
"Witnessed incoming TCP RST, removing NAT session"
);
self.table.remove_by_right(&outside);
self.expired.insert(outside);
state.incoming_rst = true;
}
let (proto, src) = inside;
if packet.as_tcp().is_some_and(|tcp| tcp.fin()) {
state.incoming_fin = true;
}
let (proto, src) = inside.into_inner();
return Ok(TranslateIncomingResult::Ok { proto, src });
}
@@ -143,20 +176,96 @@ impl NatTable {
Ok(TranslateIncomingResult::NoNatSession)
}
fn translate_incoming_inner(
&mut self,
outside: &(Protocol, IpAddr),
now: Instant,
) -> Option<(Protocol, IpAddr)> {
fn translate_incoming_inner(&mut self, outside: &Outside, now: Instant) -> Option<Inside> {
let inside = self.table.get_by_right(outside)?;
let state = self.state_by_inside.get_mut(inside)?;
tracing::trace!(?inside, ?outside, "Translating incoming packet");
self.last_seen.insert(*inside, now);
tracing::trace!(?inside, ?outside, ?state, "Translating incoming packet");
let prev_last_incoming = state.last_incoming.replace(now);
if prev_last_incoming.is_none() {
tracing::debug!(?inside, ?outside, "NAT session confirmed");
}
Some(*inside)
}
}
/// Per-entry bookkeeping used to decide when a NAT entry can be removed.
///
/// The derived `Debug` output is included in trace logs, so the field order is part of what gets logged.
#[derive(Debug)]
struct EntryState {
    /// Time of the most recent outgoing packet; initialised when the entry is created.
    last_outgoing: Instant,
    /// Time of the most recent incoming packet; `None` until the first one is seen (entry is "unconfirmed").
    last_incoming: Option<Instant>,
    /// Set once an outgoing TCP packet with the RST flag has been translated.
    outgoing_rst: bool,
    /// Set once an incoming TCP packet with the RST flag has been translated.
    incoming_rst: bool,
    /// Set once an outgoing TCP packet with the FIN flag has been translated.
    outgoing_fin: bool,
    /// Set once an incoming TCP packet with the FIN flag has been translated.
    incoming_fin: bool,
}
impl EntryState {
    /// Creates the state for a fresh entry whose first outgoing packet was seen at `last_outgoing`.
    ///
    /// The entry starts out unconfirmed (no incoming packet yet) with no FIN/RST flags recorded.
    fn new(last_outgoing: Instant) -> Self {
        Self {
            last_outgoing,
            last_incoming: None,
            outgoing_rst: false,
            incoming_rst: false,
            outgoing_fin: false,
            incoming_fin: false,
        }
    }

    /// Deadline based purely on inactivity: the most recent packet plus the protocol's TTL.
    fn ttl_timeout(&self, protocol: Protocol) -> Instant {
        let idle_limit = match protocol {
            Protocol::Tcp(_) => TCP_TTL,
            Protocol::Udp(_) => UDP_TTL,
            Protocol::Icmp(_) => ICMP_TTL,
        };

        self.last_packet() + idle_limit
    }

    /// Deadline for entries that never received an incoming packet.
    ///
    /// Returns `None` as soon as at least one incoming packet has been recorded.
    fn unconfirmed_timeout(&self) -> Option<Instant> {
        self.last_incoming
            .is_none()
            .then(|| self.last_outgoing + UNCONFIRMED_TTL)
    }

    /// Deadline after FIN has been seen in *both* directions.
    ///
    /// Returns `None` while either direction has not sent a FIN yet.
    fn fin_timeout(&self) -> Option<Instant> {
        let both_sides_finished = self.outgoing_fin && self.incoming_fin;

        // Keep NAT open for a few more seconds.
        both_sides_finished.then(|| self.last_packet() + Duration::from_secs(5))
    }

    /// Deadline after an RST has been seen in *either* direction.
    ///
    /// Returns `None` if no RST was recorded.
    fn rst_timeout(&self) -> Option<Instant> {
        let reset_seen = self.outgoing_rst || self.incoming_rst;

        // Close immediately.
        reset_seen.then(|| self.last_packet())
    }

    /// Earliest of all applicable deadlines, i.e. the point in time at which the entry should be removed.
    ///
    /// The TTL deadline always applies, so this currently always yields `Some`.
    fn remove_at(&self, protocol: Protocol) -> Option<Instant> {
        let deadlines = [
            Some(self.ttl_timeout(protocol)),
            self.unconfirmed_timeout(),
            self.fin_timeout(),
            self.rst_timeout(),
        ];

        deadlines.into_iter().flatten().min()
    }

    /// Time of the most recent packet in either direction.
    fn last_packet(&self) -> Instant {
        match self.last_incoming {
            Some(last_incoming) => self.last_outgoing.max(last_incoming),
            None => self.last_outgoing,
        }
    }
}
/// A prototype for an ICMP error packet.
///
/// A packet coming in from the "outside" of the NAT may be an ICMP error.
@@ -259,9 +368,15 @@ mod tests {
response.set_src(new_dst_ip).unwrap();
// Update time.
table.handle_timeout(sent_at + response_delay);
table.handle_timeout(sent_at + Duration::from_secs(1));
// Translate in
// Confirm mapping
table
.translate_incoming(&response.clone(), sent_at + Duration::from_secs(1))
.unwrap();
// Simulate another packet after _response_delay_
table.handle_timeout(sent_at + response_delay);
let translate_incoming = table
.translate_incoming(&response, sent_at + response_delay)
.unwrap();
@@ -352,16 +467,17 @@ mod tests {
rst.set_dst(req.destination()).unwrap();
let mut table = NatTable::default();
let mut now = Instant::now();
let outside = table
.translate_outgoing(&req, outside_dst, Instant::now())
.unwrap();
let outside = table.translate_outgoing(&req, outside_dst, now).unwrap();
let mut response = req.clone();
response.set_destination_protocol(outside.0.value());
response.set_src(outside.1).unwrap();
match table.translate_incoming(&response, Instant::now()).unwrap() {
now += Duration::from_secs(1);
match table.translate_incoming(&response, now).unwrap() {
TranslateIncomingResult::Ok { .. } => {}
result @ (TranslateIncomingResult::NoNatSession
| TranslateIncomingResult::ExpiredNatSession
@@ -370,11 +486,14 @@ mod tests {
}
};
table
.translate_outgoing(&rst, outside_dst, Instant::now())
.unwrap();
now += Duration::from_secs(1);
match table.translate_incoming(&response, Instant::now()).unwrap() {
table.translate_outgoing(&rst, outside_dst, now).unwrap();
now += Duration::from_secs(1);
table.handle_timeout(now);
match table.translate_incoming(&response, now).unwrap() {
TranslateIncomingResult::ExpiredNatSession => {}
result @ (TranslateIncomingResult::NoNatSession
| TranslateIncomingResult::Ok { .. }