mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat(gateway): improve state tracking of DNS resource NAT (#10868)
Right now, the state tracking within the DNS resource NAT table is pretty simple: - We map from inside to outside and back - When we see a TCP RST, we remove it immediately To improve our logs a bit and make the NAT table more robust, we extend it by: - Tracking last inbound and outbound packet - Tracking FIN and RST flags This allows us to fully observe e.g. a TCP shutdown where both parties send TCP FIN. It also allows us to remove entries that have never been confirmed after a shorter amount of time. Resolves: #10795 --------- Signed-off-by: Thomas Eizinger <thomas@eizinger.io> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -16,32 +16,66 @@ use std::time::{Duration, Instant};
|
||||
/// Thus, purely an L3 NAT would not be sufficient as it would be impossible to map back to the proxy IP.
|
||||
#[derive(Default, Debug)]
|
||||
pub(crate) struct NatTable {
|
||||
pub(crate) table: BiMap<(Protocol, IpAddr), (Protocol, IpAddr)>,
|
||||
pub(crate) last_seen: BTreeMap<(Protocol, IpAddr), Instant>,
|
||||
table: BiMap<Inside, Outside>,
|
||||
state_by_inside: BTreeMap<Inside, EntryState>,
|
||||
|
||||
// We don't bother with proactively freeing this because a single entry is only ~20 bytes and it gets cleanup once the connection to the client goes away.
|
||||
expired: HashSet<(Protocol, IpAddr)>,
|
||||
expired: HashSet<Outside>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)]
|
||||
struct Inside(Protocol, IpAddr);
|
||||
|
||||
impl Inside {
|
||||
fn into_inner(self) -> (Protocol, IpAddr) {
|
||||
(self.0, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)]
|
||||
struct Outside(Protocol, IpAddr);
|
||||
|
||||
impl Outside {
|
||||
fn into_inner(self) -> (Protocol, IpAddr) {
|
||||
(self.0, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const TCP_TTL: Duration = Duration::from_secs(60 * 60 * 2);
|
||||
pub(crate) const UDP_TTL: Duration = Duration::from_secs(60 * 2);
|
||||
pub(crate) const ICMP_TTL: Duration = Duration::from_secs(60 * 2);
|
||||
|
||||
pub(crate) const UNCONFIRMED_TTL: Duration = Duration::from_secs(60);
|
||||
|
||||
impl NatTable {
|
||||
pub(crate) fn handle_timeout(&mut self, now: Instant) {
|
||||
for (outside, e) in self.last_seen.iter() {
|
||||
let ttl = match outside.0 {
|
||||
Protocol::Tcp(_) => TCP_TTL,
|
||||
Protocol::Udp(_) => UDP_TTL,
|
||||
Protocol::Icmp(_) => ICMP_TTL,
|
||||
let expired = self.state_by_inside.extract_if(.., |inside, state| {
|
||||
state
|
||||
.remove_at(inside.0)
|
||||
.is_some_and(|remove_at| now >= remove_at)
|
||||
});
|
||||
|
||||
for (inside, state) in expired {
|
||||
let Some((_, outside)) = self.table.remove_by_left(&inside) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if now.duration_since(*e) >= ttl
|
||||
&& let Some((inside, _)) = self.table.remove_by_right(outside)
|
||||
{
|
||||
tracing::debug!(?inside, ?outside, ?ttl, "NAT session expired");
|
||||
self.expired.insert(*outside);
|
||||
}
|
||||
self.expired.insert(outside);
|
||||
|
||||
let last_outgoing = now.duration_since(state.last_outgoing);
|
||||
let last_incoming = state.last_incoming.map(|t| now.duration_since(t));
|
||||
|
||||
tracing::debug!(
|
||||
?inside,
|
||||
?outside,
|
||||
?last_outgoing,
|
||||
?last_incoming,
|
||||
fin_tx = %state.outgoing_fin,
|
||||
fin_rx = %state.incoming_fin,
|
||||
rst_tx = %state.outgoing_rst,
|
||||
rst_rx = %state.incoming_rst,
|
||||
"NAT entry removed"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,43 +88,41 @@ impl NatTable {
|
||||
let src = packet.source_protocol()?;
|
||||
let dst = packet.destination();
|
||||
|
||||
let inside = (src, dst);
|
||||
let inside = Inside(src, dst);
|
||||
|
||||
if let Some(outside) = self.table.get_by_left(&inside).copied() {
|
||||
tracing::trace!(?inside, ?outside, "Translating outgoing packet");
|
||||
if let Some(outside) = self.table.get_by_left(&inside).copied()
|
||||
&& let Some(state) = self.state_by_inside.get_mut(&inside)
|
||||
{
|
||||
tracing::trace!(?inside, ?outside, ?state, "Translating outgoing packet");
|
||||
|
||||
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
|
||||
tracing::debug!(
|
||||
?inside,
|
||||
?outside,
|
||||
"Witnessed outgoing TCP RST, removing NAT session"
|
||||
);
|
||||
|
||||
self.table.remove_by_left(&inside);
|
||||
self.expired.insert(outside);
|
||||
state.outgoing_rst = true;
|
||||
}
|
||||
|
||||
self.last_seen.insert(outside, now);
|
||||
return Ok(outside);
|
||||
if packet.as_tcp().is_some_and(|tcp| tcp.fin()) {
|
||||
state.outgoing_fin = true;
|
||||
}
|
||||
|
||||
state.last_outgoing = now;
|
||||
|
||||
return Ok(outside.into_inner());
|
||||
}
|
||||
|
||||
// Find the first available public port, starting from the port of the to-be-mapped packet.
|
||||
// This will re-assign the same port in most cases, even after the mapping expires.
|
||||
let outside = (src.value()..=u16::MAX)
|
||||
.chain(1..src.value())
|
||||
.map(|p| (src.with_value(p), outside_dst))
|
||||
.map(|p| Outside(src.with_value(p), outside_dst))
|
||||
.find(|outside| !self.table.contains_right(outside))
|
||||
.context("Exhausted NAT")?;
|
||||
|
||||
let inside = (src, dst);
|
||||
|
||||
self.table.insert(inside, outside);
|
||||
self.last_seen.insert(outside, now);
|
||||
self.state_by_inside.insert(inside, EntryState::new(now));
|
||||
self.expired.remove(&outside);
|
||||
|
||||
tracing::debug!(?inside, ?outside, "New NAT session");
|
||||
|
||||
Ok(outside)
|
||||
Ok(outside.into_inner())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_incoming(
|
||||
@@ -99,9 +131,11 @@ impl NatTable {
|
||||
now: Instant,
|
||||
) -> Result<TranslateIncomingResult> {
|
||||
if let Some((failed_packet, icmp_error)) = packet.icmp_error()? {
|
||||
let outside = (failed_packet.src_proto(), failed_packet.dst());
|
||||
let outside = Outside(failed_packet.src_proto(), failed_packet.dst());
|
||||
|
||||
if let Some((inside_proto, inside_dst)) = self.translate_incoming_inner(&outside, now) {
|
||||
if let Some(Inside(inside_proto, inside_dst)) =
|
||||
self.translate_incoming_inner(&outside, now)
|
||||
{
|
||||
return Ok(TranslateIncomingResult::IcmpError(IcmpErrorPrototype {
|
||||
inside_dst,
|
||||
inside_proto,
|
||||
@@ -117,21 +151,20 @@ impl NatTable {
|
||||
return Ok(TranslateIncomingResult::NoNatSession);
|
||||
}
|
||||
|
||||
let outside = (packet.destination_protocol()?, packet.source());
|
||||
let outside = Outside(packet.destination_protocol()?, packet.source());
|
||||
|
||||
if let Some(inside) = self.translate_incoming_inner(&outside, now) {
|
||||
if let Some(inside) = self.translate_incoming_inner(&outside, now)
|
||||
&& let Some(state) = self.state_by_inside.get_mut(&inside)
|
||||
{
|
||||
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
|
||||
tracing::debug!(
|
||||
?inside,
|
||||
?outside,
|
||||
"Witnessed incoming TCP RST, removing NAT session"
|
||||
);
|
||||
|
||||
self.table.remove_by_right(&outside);
|
||||
self.expired.insert(outside);
|
||||
state.incoming_rst = true;
|
||||
}
|
||||
|
||||
let (proto, src) = inside;
|
||||
if packet.as_tcp().is_some_and(|tcp| tcp.fin()) {
|
||||
state.incoming_fin = true;
|
||||
}
|
||||
|
||||
let (proto, src) = inside.into_inner();
|
||||
|
||||
return Ok(TranslateIncomingResult::Ok { proto, src });
|
||||
}
|
||||
@@ -143,20 +176,96 @@ impl NatTable {
|
||||
Ok(TranslateIncomingResult::NoNatSession)
|
||||
}
|
||||
|
||||
fn translate_incoming_inner(
|
||||
&mut self,
|
||||
outside: &(Protocol, IpAddr),
|
||||
now: Instant,
|
||||
) -> Option<(Protocol, IpAddr)> {
|
||||
fn translate_incoming_inner(&mut self, outside: &Outside, now: Instant) -> Option<Inside> {
|
||||
let inside = self.table.get_by_right(outside)?;
|
||||
let state = self.state_by_inside.get_mut(inside)?;
|
||||
|
||||
tracing::trace!(?inside, ?outside, "Translating incoming packet");
|
||||
self.last_seen.insert(*inside, now);
|
||||
tracing::trace!(?inside, ?outside, ?state, "Translating incoming packet");
|
||||
|
||||
let prev_last_incoming = state.last_incoming.replace(now);
|
||||
if prev_last_incoming.is_none() {
|
||||
tracing::debug!(?inside, ?outside, "NAT session confirmed");
|
||||
}
|
||||
|
||||
Some(*inside)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EntryState {
|
||||
last_outgoing: Instant,
|
||||
last_incoming: Option<Instant>,
|
||||
|
||||
outgoing_rst: bool,
|
||||
incoming_rst: bool,
|
||||
outgoing_fin: bool,
|
||||
incoming_fin: bool,
|
||||
}
|
||||
|
||||
impl EntryState {
|
||||
fn new(last_outgoing: Instant) -> Self {
|
||||
Self {
|
||||
last_outgoing,
|
||||
last_incoming: None,
|
||||
outgoing_rst: false,
|
||||
incoming_rst: false,
|
||||
outgoing_fin: false,
|
||||
incoming_fin: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn ttl_timeout(&self, protocol: Protocol) -> Instant {
|
||||
let ttl = match protocol {
|
||||
Protocol::Tcp(_) => TCP_TTL,
|
||||
Protocol::Udp(_) => UDP_TTL,
|
||||
Protocol::Icmp(_) => ICMP_TTL,
|
||||
};
|
||||
|
||||
self.last_packet() + ttl
|
||||
}
|
||||
|
||||
fn unconfirmed_timeout(&self) -> Option<Instant> {
|
||||
if self.last_incoming.is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.last_outgoing + UNCONFIRMED_TTL)
|
||||
}
|
||||
|
||||
fn fin_timeout(&self) -> Option<Instant> {
|
||||
if !self.outgoing_fin || !self.incoming_fin {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.last_packet() + Duration::from_secs(5)) // Keep NAT open for a few more seconds.
|
||||
}
|
||||
|
||||
fn rst_timeout(&self) -> Option<Instant> {
|
||||
if !self.outgoing_rst && !self.incoming_rst {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.last_packet()) // Close immediately.
|
||||
}
|
||||
|
||||
fn remove_at(&self, protocol: Protocol) -> Option<Instant> {
|
||||
std::iter::empty()
|
||||
.chain(Some(self.ttl_timeout(protocol)))
|
||||
.chain(self.unconfirmed_timeout())
|
||||
.chain(self.fin_timeout())
|
||||
.chain(self.rst_timeout())
|
||||
.min()
|
||||
}
|
||||
|
||||
fn last_packet(&self) -> Instant {
|
||||
let Some(last_incoming) = self.last_incoming else {
|
||||
return self.last_outgoing;
|
||||
};
|
||||
|
||||
std::cmp::max(self.last_outgoing, last_incoming)
|
||||
}
|
||||
}
|
||||
|
||||
/// A prototype for an ICMP error packet.
|
||||
///
|
||||
/// A packet coming in from the "outside" of the NAT may be an ICMP error.
|
||||
@@ -259,9 +368,15 @@ mod tests {
|
||||
response.set_src(new_dst_ip).unwrap();
|
||||
|
||||
// Update time.
|
||||
table.handle_timeout(sent_at + response_delay);
|
||||
table.handle_timeout(sent_at + Duration::from_secs(1));
|
||||
|
||||
// Translate in
|
||||
// Confirm mapping
|
||||
table
|
||||
.translate_incoming(&response.clone(), sent_at + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
|
||||
// Simulate another packet after _response_delay_
|
||||
table.handle_timeout(sent_at + response_delay);
|
||||
let translate_incoming = table
|
||||
.translate_incoming(&response, sent_at + response_delay)
|
||||
.unwrap();
|
||||
@@ -352,16 +467,17 @@ mod tests {
|
||||
rst.set_dst(req.destination()).unwrap();
|
||||
|
||||
let mut table = NatTable::default();
|
||||
let mut now = Instant::now();
|
||||
|
||||
let outside = table
|
||||
.translate_outgoing(&req, outside_dst, Instant::now())
|
||||
.unwrap();
|
||||
let outside = table.translate_outgoing(&req, outside_dst, now).unwrap();
|
||||
|
||||
let mut response = req.clone();
|
||||
response.set_destination_protocol(outside.0.value());
|
||||
response.set_src(outside.1).unwrap();
|
||||
|
||||
match table.translate_incoming(&response, Instant::now()).unwrap() {
|
||||
now += Duration::from_secs(1);
|
||||
|
||||
match table.translate_incoming(&response, now).unwrap() {
|
||||
TranslateIncomingResult::Ok { .. } => {}
|
||||
result @ (TranslateIncomingResult::NoNatSession
|
||||
| TranslateIncomingResult::ExpiredNatSession
|
||||
@@ -370,11 +486,14 @@ mod tests {
|
||||
}
|
||||
};
|
||||
|
||||
table
|
||||
.translate_outgoing(&rst, outside_dst, Instant::now())
|
||||
.unwrap();
|
||||
now += Duration::from_secs(1);
|
||||
|
||||
match table.translate_incoming(&response, Instant::now()).unwrap() {
|
||||
table.translate_outgoing(&rst, outside_dst, now).unwrap();
|
||||
|
||||
now += Duration::from_secs(1);
|
||||
table.handle_timeout(now);
|
||||
|
||||
match table.translate_incoming(&response, now).unwrap() {
|
||||
TranslateIncomingResult::ExpiredNatSession => {}
|
||||
result @ (TranslateIncomingResult::NoNatSession
|
||||
| TranslateIncomingResult::Ok { .. }
|
||||
|
||||
Reference in New Issue
Block a user