feat(connlib): mirror ECN bits on TUN device (#8511)

From the perspective of any application, Firezone is a layer-3 network
and will thus use the host's networking stack to form IP packets for
whichever application protocol is in use (UDP, TCP, etc). These packets
then get encapsulated into UDP packets by Firezone and sent to a
Gateway.

As a result of this design, the IP header seen by the networking stacks
of the Client and the receiving service are not visible to any
intermediary along the network path of the Client and Gateway.

In case this network path is congested and middleboxes such as routers
need to drop packets, they will look at the ECN bits in the IP header
(of the UDP packet generated by a Client or Gateway) and flip a bit in
case the previous value indicated support for ECN (`0x01` or `0x10`).
When received by a network stack that supports ECN, seeing `0x11` means
that the network path is congested and that it must reduce its
send/receive windows (or otherwise throttle the connection).

At present, this doesn't work with Firezone because of the
aforementioned encapsulation of IP packets. To support ECN, we need to
therefore:

- Copy ECN bits from a received IP packet to the datagram that
encapsulates it: This ensures that if the Client's network stack support
ECN, we mirror that support on the wire.
- Copy ECN bits from a received datagram to the IP packet the is sent to
the TUN device: This ensures that if the "Congestion Experienced" bit
get set along the network path between Client and Gateway, we reflect
that accordingly on the IP packet emitted by the TUN device.

Resolves: #3758

---------

Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
Co-authored-by: Jamil Bou Kheir <jamilbk@users.noreply.github.com>
This commit is contained in:
Thomas Eizinger
2025-03-27 07:55:51 +11:00
committed by GitHub
parent 41d89f4c12
commit 58fe527b0e
11 changed files with 162 additions and 20 deletions

2
.github/codespellrc vendored
View File

@@ -1,3 +1,3 @@
[codespell]
skip = ./elixir/apps/domain/lib/domain/name_generator.ex,./**/*.svg,./elixir/deps,./**/*.min.js,./kotlin/android/app/build,./kotlin/android/build,./e2e/pnpm-lock.yaml,./website/.next,./website/pnpm-lock.yaml,./rust/connlib/tunnel/testcases,./rust/gui-client/dist,./rust/target,Cargo.lock,./website/docs/reference/api/*.mdx,./**/erl_crash.dump,./cover,./vendor,*.json,seeds.exs,./**/node_modules,./deps,./priv/static,./priv/plts,./**/priv/static,./.git,./_build,*.cast,./**/proptest-regressions
ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded
ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded,ect

1
rust/Cargo.lock generated
View File

@@ -6122,6 +6122,7 @@ version = "0.1.0"
dependencies = [
"bytes",
"firezone-logging",
"ip-packet",
"quinn-udp",
"socket2",
"tokio",

View File

@@ -2,6 +2,7 @@
use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory};
use ip_network::Ipv4Network;
use ip_packet::Ecn;
use socket_factory::DatagramOut;
use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
@@ -45,6 +46,7 @@ async fn no_packet_loops_udp() {
dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com,
packet: &hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(),
segment_size: None,
ecn: Ecn::NonEct,
})
.unwrap();

View File

@@ -9,7 +9,7 @@ use firezone_logging::{telemetry_event, telemetry_span};
use futures::FutureExt as _;
use futures_bounded::FuturesTupleSet;
use gso_queue::GsoQueue;
use ip_packet::{IpPacket, MAX_FZ_PAYLOAD};
use ip_packet::{Ecn, IpPacket, MAX_FZ_PAYLOAD};
use nameserver_set::NameserverSet;
use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket};
use std::{
@@ -296,8 +296,15 @@ impl Io {
}
}
pub fn send_network(&mut self, src: Option<SocketAddr>, dst: SocketAddr, payload: &[u8]) {
self.gso_queue.enqueue(src, dst, payload, Instant::now())
pub fn send_network(
&mut self,
src: Option<SocketAddr>,
dst: SocketAddr,
payload: &[u8],
ecn: Ecn,
) {
self.gso_queue
.enqueue(src, dst, payload, ecn, Instant::now())
}
pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) {

View File

@@ -6,6 +6,7 @@ use std::{
};
use bytes::BytesMut;
use ip_packet::Ecn;
use socket_factory::DatagramOut;
use super::MAX_INBOUND_PACKET_BATCH;
@@ -55,6 +56,7 @@ impl GsoQueue {
src: Option<SocketAddr>,
dst: SocketAddr,
payload: &[u8],
ecn: Ecn,
now: Instant,
) {
let segment_size = payload.len();
@@ -74,6 +76,7 @@ impl GsoQueue {
.or_insert_with(|| DatagramBuffer {
inner: None,
last_access: now,
ecn,
});
buffer
@@ -81,6 +84,7 @@ impl GsoQueue {
.get_or_insert_with(|| self.buffer_pool.pull_owned())
.extend_from_slice(payload);
buffer.last_access = now;
buffer.ecn = ecn;
}
pub fn datagrams(
@@ -88,6 +92,7 @@ impl GsoQueue {
) -> impl Iterator<Item = DatagramOut<lockfree_object_pool::SpinLockOwnedReusable<BytesMut>>> + '_
{
self.inner.iter_mut().filter_map(|(key, buffer)| {
let ecn = buffer.ecn;
// It is really important that we `take` the buffer here, otherwise it is not returned to the pool after.
let buffer = buffer.inner.take()?;
@@ -100,6 +105,7 @@ impl GsoQueue {
dst: key.dst,
packet: buffer,
segment_size: Some(key.segment_size),
ecn,
})
})
}
@@ -119,6 +125,7 @@ struct Key {
struct DatagramBuffer {
inner: Option<lockfree_object_pool::SpinLockOwnedReusable<BytesMut>>,
last_access: Instant,
ecn: Ecn,
}
#[cfg(test)]
@@ -132,7 +139,7 @@ mod tests {
let now = Instant::now();
let mut send_queue = GsoQueue::new();
send_queue.enqueue(None, DST, b"foobar", now);
send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now);
for _entry in send_queue.datagrams() {}
send_queue.handle_timeout(now + Duration::from_secs(60));
@@ -145,7 +152,7 @@ mod tests {
let now = Instant::now();
let mut send_queue = GsoQueue::new();
send_queue.enqueue(None, DST, b"foobar", now);
send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now);
send_queue.handle_timeout(now + Duration::from_secs(60));
@@ -157,7 +164,7 @@ mod tests {
let now = Instant::now();
let mut send_queue = GsoQueue::new();
send_queue.enqueue(None, DST, b"foobar", now);
send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now);
let datagrams = send_queue.datagrams();
drop(datagrams);
@@ -174,8 +181,8 @@ mod tests {
let now = Instant::now();
let mut send_queue = GsoQueue::new();
send_queue.enqueue(None, DST, b"foobar", now);
send_queue.enqueue(None, DST_2, b"bar", now);
send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now);
send_queue.enqueue(None, DST_2, b"bar", Ecn::NonEct, now);
// Taking it from the iterator is "sending" ...
let _datagrams = send_queue.datagrams().collect::<Vec<_>>();

View File

@@ -12,6 +12,7 @@ use connlib_model::{ClientId, GatewayId, PublicKey, ResourceId, ResourceView};
use dns_types::DomainName;
use io::{Buffers, Io};
use ip_network::{Ipv4Network, Ipv6Network};
use ip_packet::Ecn;
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::{
collections::BTreeSet,
@@ -142,7 +143,8 @@ impl ClientTunnel {
}
if let Some(trans) = self.role_state.poll_transmit() {
self.io.send_network(trans.src, trans.dst, &trans.payload);
self.io
.send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct);
continue;
}
@@ -164,13 +166,15 @@ impl ClientTunnel {
let now = Instant::now();
for packet in packets {
let ecn = packet.ecn();
let Some(packet) = self.role_state.handle_tun_input(packet, now) else {
self.role_state.handle_timeout(now);
continue;
};
self.io
.send_network(packet.src(), packet.dst(), packet.payload());
.send_network(packet.src(), packet.dst(), packet.payload(), ecn);
}
continue;
@@ -189,7 +193,7 @@ impl ClientTunnel {
continue;
};
self.io.send_tun(packet);
self.io.send_tun(packet.with_ecn(received.ecn));
}
continue;
@@ -244,7 +248,8 @@ impl GatewayTunnel {
}
if let Some(trans) = self.role_state.poll_transmit() {
self.io.send_network(trans.src, trans.dst, &trans.payload);
self.io
.send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct);
continue;
}
@@ -279,6 +284,8 @@ impl GatewayTunnel {
let now = Instant::now();
for packet in packets {
let ecn = packet.ecn();
let Some(packet) = self
.role_state
.handle_tun_input(packet, now)
@@ -289,7 +296,7 @@ impl GatewayTunnel {
};
self.io
.send_network(packet.src(), packet.dst(), packet.payload());
.send_network(packet.src(), packet.dst(), packet.payload(), ecn);
}
continue;
@@ -313,7 +320,7 @@ impl GatewayTunnel {
continue;
};
self.io.send_tun(packet);
self.io.send_tun(packet.with_ecn(received.ecn));
}
continue;

View File

@@ -43,4 +43,12 @@ impl<'a> Ipv4HeaderSliceMut<'a> {
// Safety: Slice it at least of length 20 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 2, len) };
}
pub fn set_ecn(&mut self, ecn: u8) {
let current = self.slice[1];
let new = current & 0b1111_1100 | ecn;
// Safety: Slice it at least of length 20 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 1, [new]) };
}
}

View File

@@ -24,4 +24,27 @@ impl<'a> Ipv6HeaderSliceMut<'a> {
// Safety: Slice it at least of length 40 as checked in the ctor.
unsafe { write_to_offset_unchecked(self.slice, 24, dst) };
}
/// Sets the ECN bits in the IPv6 header.
///
/// Doing this is a bit trickier than for IPv4 due to the layout of the IPv6 header:
///
/// ```text
/// 0 8 16 24 32
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |Version| Traffic Class | Flow Label |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// ```
///
/// The Traffic Class field (of which the lower two bits are used for ECN) is split across
/// two bytes. Thus, to set the ECN bits, we actually need to set bit 3 & 4 of the second byte.
pub fn set_ecn(&mut self, ecn: u8) {
let mask = 0b1100_1111; // Mask to clear the ecn bits.
let ecn = ecn << 4; // Shift the ecn bits to the correct position (so they fit the mask above).
let second_byte = self.slice[1];
let new = second_byte & mask | ecn;
unsafe { write_to_offset_unchecked(self.slice, 1, [new]) };
}
}

View File

@@ -179,6 +179,19 @@ impl std::fmt::Debug for IpPacket {
.field("dst_port", &udp.destination_port());
}
match self.ecn() {
Ecn::NonEct => {}
Ecn::Ect1 => {
dbg.field("ecn", &"ECT(1)");
}
Ecn::Ect0 => {
dbg.field("ecn", &"ECT(0)");
}
Ecn::Ce => {
dbg.field("ecn", &"CE");
}
};
dbg.finish()
}
}
@@ -844,6 +857,30 @@ impl IpPacket {
}
}
pub fn with_ecn(mut self, ecn: Ecn) -> Self {
match &mut self {
IpPacket::Ipv4(ip) => ip.ip_header_mut().set_ecn(ecn as u8),
IpPacket::Ipv6(ip) => ip.header_mut().set_ecn(ecn as u8),
}
self
}
pub fn ecn(&self) -> Ecn {
let byte = match self {
IpPacket::Ipv4(ip) => ip.ip_header().ecn().value(),
IpPacket::Ipv6(ip) => ip.header().traffic_class(),
};
match byte & 0b00000011 {
0b00000000 => Ecn::NonEct,
0b00000001 => Ecn::Ect1,
0b00000010 => Ecn::Ect0,
0b00000011 => Ecn::Ce,
_ => unreachable!(),
}
}
pub fn ipv4_header(&self) -> Option<Ipv4Header> {
match self {
Self::Ipv4(p) => Some(p.ip_header().to_header()),
@@ -988,6 +1025,17 @@ fn extract_l4_proto(payload: &[u8], protocol: IpNumber) -> Result<Layer4Protocol
Ok(proto)
}
/// Models the possible ECN states.
///
/// See <https://www.rfc-editor.org/rfc/rfc3168#section-23.1> for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Ecn {
NonEct = 0b00,
Ect1 = 0b01,
Ect0 = 0b10,
Ce = 0b11,
}
impl From<ConvertibleIpv4Packet> for IpPacket {
fn from(value: ConvertibleIpv4Packet) -> Self {
Self::Ipv4(value)
@@ -1034,4 +1082,26 @@ mod tests {
assert_eq!(udp_payload, b"foobar");
}
#[test]
fn ipv4_ecn() {
let p = crate::make::udp_packet(Ipv4Addr::LOCALHOST, Ipv4Addr::LOCALHOST, 0, 0, vec![])
.unwrap();
assert_eq!(p.clone().with_ecn(Ecn::NonEct).ecn(), Ecn::NonEct);
assert_eq!(p.clone().with_ecn(Ecn::Ect0).ecn(), Ecn::Ect0);
assert_eq!(p.clone().with_ecn(Ecn::Ect1).ecn(), Ecn::Ect1);
assert_eq!(p.with_ecn(Ecn::Ce).ecn(), Ecn::Ce);
}
#[test]
fn ipv6_ecn() {
let p = crate::make::udp_packet(Ipv6Addr::LOCALHOST, Ipv6Addr::LOCALHOST, 0, 0, vec![])
.unwrap();
assert_eq!(p.clone().with_ecn(Ecn::NonEct).ecn(), Ecn::NonEct);
assert_eq!(p.clone().with_ecn(Ecn::Ect1).ecn(), Ecn::Ect1);
assert_eq!(p.clone().with_ecn(Ecn::Ect0).ecn(), Ecn::Ect0);
assert_eq!(p.with_ecn(Ecn::Ce).ecn(), Ecn::Ce);
}
}

View File

@@ -7,6 +7,7 @@ license = { workspace = true }
[dependencies]
bytes = { workspace = true }
firezone-logging = { workspace = true }
ip-packet = { workspace = true }
quinn-udp = { workspace = true }
socket2 = { workspace = true }
tokio = { workspace = true, features = ["net"] }

View File

@@ -1,5 +1,6 @@
use bytes::Buf as _;
use firezone_logging::err_with_src;
use ip_packet::Ecn;
use quinn_udp::Transmit;
use std::collections::HashMap;
use std::fmt;
@@ -198,6 +199,7 @@ pub struct DatagramIn<'a> {
pub local: SocketAddr,
pub from: SocketAddr,
pub packet: &'a [u8],
pub ecn: Ecn,
}
/// An outbound UDP datagram.
@@ -206,6 +208,7 @@ pub struct DatagramOut<B> {
pub dst: SocketAddr,
pub packet: B,
pub segment_size: Option<usize>,
pub ecn: Ecn,
}
impl UdpSocket {
@@ -257,7 +260,7 @@ impl UdpSocket {
let num_packets = meta.len / segment_size;
let trailing_bytes = meta.len % segment_size;
tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, %num_packets, %segment_size, %trailing_bytes);
tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes);
let iter = buffer[..meta.len]
.chunks(meta.stride)
@@ -265,6 +268,12 @@ impl UdpSocket {
local,
from: meta.addr,
packet,
ecn: match meta.ecn {
Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce,
Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0,
Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1,
None => Ecn::NonEct,
},
});
return Poll::Ready(Ok(iter));
@@ -285,6 +294,7 @@ impl UdpSocket {
datagram.src.map(|s| s.ip()),
datagram.packet.deref().chunk(),
datagram.segment_size,
datagram.ecn,
)?
else {
return Ok(());
@@ -305,7 +315,7 @@ impl UdpSocket {
{
let num_packets = transmit.contents.len() / segment_size;
tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, %num_packets, %segment_size);
tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, ecn = ?transmit.ecn, %num_packets, %segment_size);
self.inner.try_io(Interest::WRITABLE, || {
self.state.try_send((&self.inner).into(), &transmit)
@@ -315,7 +325,7 @@ impl UdpSocket {
None => {
let num_bytes = transmit.contents.len();
tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, %num_bytes);
tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, ecn = ?transmit.ecn, %num_bytes);
self.inner.try_io(Interest::WRITABLE, || {
self.state.try_send((&self.inner).into(), &transmit)
@@ -343,7 +353,7 @@ impl UdpSocket {
payload: &[u8],
) -> io::Result<Vec<u8>> {
let transmit = self
.prepare_transmit(dst, None, payload, None)?
.prepare_transmit(dst, None, payload, None, Ecn::NonEct)?
.ok_or_else(|| io::Error::other("Failed to prepare `Transmit`"))?;
self.inner
@@ -373,6 +383,7 @@ impl UdpSocket {
src_ip: Option<IpAddr>,
packet: &'a [u8],
segment_size: Option<usize>,
ecn: Ecn,
) -> io::Result<Option<quinn_udp::Transmit<'a>>> {
let src_ip = match src_ip {
Some(src_ip) => Some(src_ip),
@@ -390,7 +401,12 @@ impl UdpSocket {
let transmit = quinn_udp::Transmit {
destination: dst,
ecn: None,
ecn: match ecn {
Ecn::NonEct => None,
Ecn::Ect1 => Some(quinn_udp::EcnCodepoint::Ect1),
Ecn::Ect0 => Some(quinn_udp::EcnCodepoint::Ect0),
Ecn::Ce => Some(quinn_udp::EcnCodepoint::Ce),
},
contents: packet,
segment_size,
src_ip,