mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
feat(gateway): free TCP NAT bindings on RSTs (#9682)
Whenever we see a TCP packet with the RST bit set, we clear the current NAT binding and move it to the `expired` list.
This commit is contained in:
@@ -109,28 +109,44 @@ pub fn tcp_packet<IP>(
|
||||
daddr: IP,
|
||||
sport: u16,
|
||||
dport: u16,
|
||||
flags: TcpFlags,
|
||||
payload: Vec<u8>,
|
||||
) -> Result<IpPacket>
|
||||
where
|
||||
IP: Into<IpAddr>,
|
||||
{
|
||||
let TcpFlags { rst } = flags;
|
||||
|
||||
match (saddr.into(), daddr.into()) {
|
||||
(IpAddr::V4(src), IpAddr::V4(dst)) => {
|
||||
let packet =
|
||||
let mut packet =
|
||||
PacketBuilder::ipv4(src.octets(), dst.octets(), 64).tcp(sport, dport, 0, 128);
|
||||
|
||||
if rst {
|
||||
packet = packet.rst();
|
||||
}
|
||||
|
||||
build!(packet, payload)
|
||||
}
|
||||
(IpAddr::V6(src), IpAddr::V6(dst)) => {
|
||||
let packet =
|
||||
let mut packet =
|
||||
PacketBuilder::ipv6(src.octets(), dst.octets(), 64).tcp(sport, dport, 0, 128);
|
||||
|
||||
if rst {
|
||||
packet = packet.rst();
|
||||
}
|
||||
|
||||
build!(packet, payload)
|
||||
}
|
||||
_ => bail!(IpVersionMismatch),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
pub struct TcpFlags {
|
||||
pub rst: bool,
|
||||
}
|
||||
|
||||
pub fn udp_packet<IP>(
|
||||
saddr: IP,
|
||||
daddr: IP,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::IpPacket;
|
||||
use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};
|
||||
use crate::{IpPacket, make::TcpFlags};
|
||||
use proptest::{arbitrary::any, prelude::Just, prop_oneof, strategy::Strategy};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
pub fn udp_packet() -> impl Strategy<Value = IpPacket> {
|
||||
@@ -13,14 +13,20 @@ pub fn udp_packet() -> impl Strategy<Value = IpPacket> {
|
||||
]
|
||||
}
|
||||
|
||||
pub fn tcp_packet() -> impl Strategy<Value = IpPacket> {
|
||||
pub fn tcp_packet(
|
||||
flags: impl Strategy<Value = TcpFlags> + Clone,
|
||||
) -> impl Strategy<Value = IpPacket> {
|
||||
prop_oneof![
|
||||
(ip4_tuple(), any::<u16>(), any::<u16>()).prop_map(|((saddr, daddr), sport, dport)| {
|
||||
crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap()
|
||||
}),
|
||||
(ip6_tuple(), any::<u16>(), any::<u16>()).prop_map(|((saddr, daddr), sport, dport)| {
|
||||
crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap()
|
||||
}),
|
||||
(ip4_tuple(), any::<u16>(), any::<u16>(), flags.clone()).prop_map(
|
||||
|((saddr, daddr), sport, dport, flags)| {
|
||||
crate::make::tcp_packet(saddr, daddr, sport, dport, flags, Vec::new()).unwrap()
|
||||
}
|
||||
),
|
||||
(ip6_tuple(), any::<u16>(), any::<u16>(), flags).prop_map(
|
||||
|((saddr, daddr), sport, dport, flags)| {
|
||||
crate::make::tcp_packet(saddr, daddr, sport, dport, flags, Vec::new()).unwrap()
|
||||
}
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -36,7 +42,11 @@ pub fn icmp_request_packet() -> impl Strategy<Value = IpPacket> {
|
||||
}
|
||||
|
||||
pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy<Value = IpPacket> {
|
||||
prop_oneof![udp_packet(), tcp_packet(), icmp_request_packet()]
|
||||
prop_oneof![
|
||||
udp_packet(),
|
||||
tcp_packet(Just(TcpFlags::default())),
|
||||
icmp_request_packet()
|
||||
]
|
||||
}
|
||||
|
||||
fn ip4_tuple() -> impl Strategy<Value = (Ipv4Addr, Ipv4Addr)> {
|
||||
|
||||
@@ -688,6 +688,7 @@ mod tests {
|
||||
use chrono::Utc;
|
||||
use connlib_model::{ClientId, ResourceId};
|
||||
use ip_network::{IpNetwork, Ipv4Network};
|
||||
use ip_packet::make::TcpFlags;
|
||||
|
||||
use super::ClientOnGateway;
|
||||
|
||||
@@ -727,6 +728,7 @@ mod tests {
|
||||
cidr_v4_resource().hosts().next().unwrap(),
|
||||
5401,
|
||||
80,
|
||||
TcpFlags::default(),
|
||||
vec![0; 100],
|
||||
)
|
||||
.unwrap();
|
||||
@@ -801,6 +803,7 @@ mod tests {
|
||||
gateway_tun_ipv4(),
|
||||
5401,
|
||||
80,
|
||||
TcpFlags::default(),
|
||||
vec![0; 100],
|
||||
)
|
||||
.unwrap();
|
||||
@@ -810,6 +813,7 @@ mod tests {
|
||||
client_tun_ipv4(),
|
||||
80,
|
||||
5401,
|
||||
TcpFlags::default(),
|
||||
vec![0; 100],
|
||||
)
|
||||
.unwrap();
|
||||
@@ -1188,7 +1192,7 @@ mod proptests {
|
||||
Filter, PortRange, ResourceDescription, ResourceDescriptionCidr,
|
||||
};
|
||||
use crate::proptest::*;
|
||||
use ip_packet::make::{icmp_request_packet, tcp_packet, udp_packet};
|
||||
use ip_packet::make::{TcpFlags, icmp_request_packet, tcp_packet, udp_packet};
|
||||
use itertools::Itertools as _;
|
||||
use proptest::{
|
||||
arbitrary::any,
|
||||
@@ -1235,7 +1239,14 @@ mod proptests {
|
||||
};
|
||||
|
||||
let packet = match protocol {
|
||||
Protocol::Tcp { dport } => tcp_packet(src, *dest, sport, *dport, payload.clone()),
|
||||
Protocol::Tcp { dport } => tcp_packet(
|
||||
src,
|
||||
*dest,
|
||||
sport,
|
||||
*dport,
|
||||
TcpFlags::default(),
|
||||
payload.clone(),
|
||||
),
|
||||
Protocol::Udp { dport } => udp_packet(src, *dest, sport, *dport, payload.clone()),
|
||||
Protocol::Icmp => icmp_request_packet(src, *dest, 1, 0, &[]),
|
||||
}
|
||||
@@ -1290,7 +1301,14 @@ mod proptests {
|
||||
|
||||
for (_, protocol) in protocol_config {
|
||||
let packet = match protocol {
|
||||
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()),
|
||||
Protocol::Tcp { dport } => tcp_packet(
|
||||
src,
|
||||
dest,
|
||||
sport,
|
||||
dport,
|
||||
TcpFlags::default(),
|
||||
payload.clone(),
|
||||
),
|
||||
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()),
|
||||
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
|
||||
}
|
||||
@@ -1331,7 +1349,9 @@ mod proptests {
|
||||
gateway_tun(),
|
||||
);
|
||||
let packet = match protocol {
|
||||
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload),
|
||||
Protocol::Tcp { dport } => {
|
||||
tcp_packet(src, dest, sport, dport, TcpFlags::default(), payload)
|
||||
}
|
||||
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload),
|
||||
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
|
||||
}
|
||||
@@ -1387,14 +1407,23 @@ mod proptests {
|
||||
);
|
||||
|
||||
let packet_allowed = match protocol_allowed {
|
||||
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()),
|
||||
Protocol::Tcp { dport } => tcp_packet(
|
||||
src,
|
||||
dest,
|
||||
sport,
|
||||
dport,
|
||||
TcpFlags::default(),
|
||||
payload.clone(),
|
||||
),
|
||||
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()),
|
||||
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
let packet_rejected = match protocol_removed {
|
||||
Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload),
|
||||
Protocol::Tcp { dport } => {
|
||||
tcp_packet(src, dest, sport, dport, TcpFlags::default(), payload)
|
||||
}
|
||||
Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload),
|
||||
Protocol::Icmp => icmp_request_packet(src, dest, 1, 0, &[]),
|
||||
}
|
||||
|
||||
@@ -61,12 +61,23 @@ impl NatTable {
|
||||
|
||||
let inside = (src, dst);
|
||||
|
||||
if let Some(outside) = self.table.get_by_left(&inside) {
|
||||
if let Some(outside) = self.table.get_by_left(&inside).copied() {
|
||||
if outside.1 == outside_dst {
|
||||
tracing::trace!(?inside, ?outside, "Translating outgoing packet");
|
||||
|
||||
self.last_seen.insert(*outside, now);
|
||||
return Ok(*outside);
|
||||
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
|
||||
tracing::debug!(
|
||||
?inside,
|
||||
?outside,
|
||||
"Witnessed outgoing TCP RST, removing NAT session"
|
||||
);
|
||||
|
||||
self.table.remove_by_left(&inside);
|
||||
self.expired.insert(outside);
|
||||
}
|
||||
|
||||
self.last_seen.insert(outside, now);
|
||||
return Ok(outside);
|
||||
}
|
||||
|
||||
tracing::trace!(?inside, ?outside, "Outgoing packet for expired translation");
|
||||
@@ -84,6 +95,7 @@ impl NatTable {
|
||||
|
||||
self.table.insert(inside, outside);
|
||||
self.last_seen.insert(outside, now);
|
||||
self.expired.remove(&outside);
|
||||
|
||||
tracing::debug!(?inside, ?outside, "New NAT session");
|
||||
|
||||
@@ -118,7 +130,20 @@ impl NatTable {
|
||||
|
||||
let outside = (packet.destination_protocol()?, packet.source());
|
||||
|
||||
if let Some((proto, src)) = self.translate_incoming_inner(&outside, now) {
|
||||
if let Some(inside) = self.translate_incoming_inner(&outside, now) {
|
||||
if packet.as_tcp().is_some_and(|tcp| tcp.rst()) {
|
||||
tracing::debug!(
|
||||
?inside,
|
||||
?outside,
|
||||
"Witnessed incoming TCP RST, removing NAT session"
|
||||
);
|
||||
|
||||
self.table.remove_by_right(&outside);
|
||||
self.expired.insert(outside);
|
||||
}
|
||||
|
||||
let (proto, src) = inside;
|
||||
|
||||
return Ok(TranslateIncomingResult::Ok { proto, src });
|
||||
}
|
||||
|
||||
@@ -215,7 +240,7 @@ pub enum TranslateIncomingResult {
|
||||
#[cfg(all(test, feature = "proptest"))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ip_packet::{IpPacket, proptest::*};
|
||||
use ip_packet::{IpPacket, make::TcpFlags, proptest::*};
|
||||
use proptest::prelude::*;
|
||||
|
||||
#[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })]
|
||||
@@ -322,4 +347,51 @@ mod tests {
|
||||
|
||||
assert_eq!(responses, original_src_p_and_dst);
|
||||
}
|
||||
|
||||
#[test_strategy::proptest]
|
||||
fn outgoing_tcp_rst_removes_nat_mapping(
|
||||
#[strategy(tcp_packet(Just(TcpFlags::default())))] req: IpPacket,
|
||||
#[strategy(tcp_packet(Just(TcpFlags { rst: true })))] mut rst: IpPacket,
|
||||
#[strategy(any::<IpAddr>())] outside_dst: IpAddr,
|
||||
) {
|
||||
let _guard = firezone_logging::test("trace");
|
||||
|
||||
proptest::prop_assume!(req.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response.
|
||||
proptest::prop_assume!(rst.destination().is_ipv4() == outside_dst.is_ipv4()); // Required for our test to simulate a response.
|
||||
rst.set_source_protocol(req.source_protocol().unwrap().value());
|
||||
rst.set_destination_protocol(req.destination_protocol().unwrap().value());
|
||||
rst.set_dst(req.destination()).unwrap();
|
||||
|
||||
let mut table = NatTable::default();
|
||||
|
||||
let outside = table
|
||||
.translate_outgoing(&req, outside_dst, Instant::now())
|
||||
.unwrap();
|
||||
|
||||
let mut response = req.clone();
|
||||
response.set_destination_protocol(outside.0.value());
|
||||
response.set_src(outside.1).unwrap();
|
||||
|
||||
match table.translate_incoming(&response, Instant::now()).unwrap() {
|
||||
TranslateIncomingResult::Ok { .. } => {}
|
||||
result @ (TranslateIncomingResult::NoNatSession
|
||||
| TranslateIncomingResult::ExpiredNatSession
|
||||
| TranslateIncomingResult::DestinationUnreachable(_)) => {
|
||||
panic!("Wrong result: {result:?}")
|
||||
}
|
||||
};
|
||||
|
||||
table
|
||||
.translate_outgoing(&rst, outside_dst, Instant::now())
|
||||
.unwrap();
|
||||
|
||||
match table.translate_incoming(&response, Instant::now()).unwrap() {
|
||||
TranslateIncomingResult::ExpiredNatSession => {}
|
||||
result @ (TranslateIncomingResult::NoNatSession
|
||||
| TranslateIncomingResult::Ok { .. }
|
||||
| TranslateIncomingResult::DestinationUnreachable(_)) => {
|
||||
panic!("Wrong result: {result:?}")
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ use bufferpool::BufferPool;
|
||||
use connlib_model::{ClientId, GatewayId, PublicKey, RelayId};
|
||||
use dns_types::ResponseCode;
|
||||
use dns_types::prelude::*;
|
||||
use ip_packet::make::TcpFlags;
|
||||
use rand::SeedableRng;
|
||||
use rand::distributions::DistString;
|
||||
use sha2::Digest;
|
||||
@@ -199,6 +200,7 @@ impl TunnelTest {
|
||||
dst,
|
||||
sport.0,
|
||||
dport.0,
|
||||
TcpFlags::default(),
|
||||
payload.to_be_bytes().to_vec(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user