From 58fe527b0e87ed810caec2b0296f530eb8a26ea2 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 27 Mar 2025 07:55:51 +1100 Subject: [PATCH] feat(connlib): mirror ECN bits on TUN device (#8511) From the perspective of any application, Firezone is a layer-3 network and will thus use the host's networking stack to form IP packets for whichever application protocol is in use (UDP, TCP, etc). These packets then get encapsulated into UDP packets by Firezone and sent to a Gateway. As a result of this design, the IP header seen by the networking stacks of the Client and the receiving service are not visible to any intermediary along the network path of the Client and Gateway. In case this network path is congested and middleboxes such as routers need to drop packets, they will look at the ECN bits in the IP header (of the UDP packet generated by a Client or Gateway) and flip a bit in case the previous value indicated support for ECN (`0x01` or `0x10`). When received by a network stack that supports ECN, seeing `0x11` means that the network path is congested and that it must reduce its send/receive windows (or otherwise throttle the connection). At present, this doesn't work with Firezone because of the aforementioned encapsulation of IP packets. To support ECN, we need to therefore: - Copy ECN bits from a received IP packet to the datagram that encapsulates it: This ensures that if the Client's network stack support ECN, we mirror that support on the wire. - Copy ECN bits from a received datagram to the IP packet the is sent to the TUN device: This ensures that if the "Congestion Experienced" bit get set along the network path between Client and Gateway, we reflect that accordingly on the IP packet emitted by the TUN device. Resolves: #3758 --------- Signed-off-by: Thomas Eizinger Co-authored-by: Jamil Bou Kheir --- .github/codespellrc | 2 +- rust/Cargo.lock | 1 + rust/bin-shared/tests/no_packet_loops_udp.rs | 2 + rust/connlib/tunnel/src/io.rs | 13 +++- rust/connlib/tunnel/src/io/gso_queue.rs | 17 +++-- rust/connlib/tunnel/src/lib.rs | 19 +++-- .../src/ipv4_header_slice_mut.rs | 8 +++ .../src/ipv6_header_slice_mut.rs | 23 ++++++ rust/ip-packet/src/lib.rs | 70 +++++++++++++++++++ rust/socket-factory/Cargo.toml | 1 + rust/socket-factory/src/lib.rs | 26 +++++-- 11 files changed, 162 insertions(+), 20 deletions(-) diff --git a/.github/codespellrc b/.github/codespellrc index ecedbc568..ec270a8d6 100644 --- a/.github/codespellrc +++ b/.github/codespellrc @@ -1,3 +1,3 @@ [codespell] skip = ./elixir/apps/domain/lib/domain/name_generator.ex,./**/*.svg,./elixir/deps,./**/*.min.js,./kotlin/android/app/build,./kotlin/android/build,./e2e/pnpm-lock.yaml,./website/.next,./website/pnpm-lock.yaml,./rust/connlib/tunnel/testcases,./rust/gui-client/dist,./rust/target,Cargo.lock,./website/docs/reference/api/*.mdx,./**/erl_crash.dump,./cover,./vendor,*.json,seeds.exs,./**/node_modules,./deps,./priv/static,./priv/plts,./**/priv/static,./.git,./_build,*.cast,./**/proptest-regressions -ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded +ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded,ect diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 06848b05a..aa2dc3705 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -6122,6 +6122,7 @@ version = "0.1.0" dependencies = [ "bytes", "firezone-logging", + "ip-packet", "quinn-udp", "socket2", "tokio", diff --git a/rust/bin-shared/tests/no_packet_loops_udp.rs b/rust/bin-shared/tests/no_packet_loops_udp.rs index e219e8d76..eea5b3fd0 100644 --- a/rust/bin-shared/tests/no_packet_loops_udp.rs +++ b/rust/bin-shared/tests/no_packet_loops_udp.rs @@ -2,6 +2,7 @@ use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory}; use ip_network::Ipv4Network; +use ip_packet::Ecn; use socket_factory::DatagramOut; use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, @@ -45,6 +46,7 @@ async fn no_packet_loops_udp() { dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, packet: &hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), segment_size: None, + ecn: Ecn::NonEct, }) .unwrap(); diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index b0411069b..eedb95f64 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -9,7 +9,7 @@ use firezone_logging::{telemetry_event, telemetry_span}; use futures::FutureExt as _; use futures_bounded::FuturesTupleSet; use gso_queue::GsoQueue; -use ip_packet::{IpPacket, MAX_FZ_PAYLOAD}; +use ip_packet::{Ecn, IpPacket, MAX_FZ_PAYLOAD}; use nameserver_set::NameserverSet; use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket}; use std::{ @@ -296,8 +296,15 @@ impl Io { } } - pub fn send_network(&mut self, src: Option, dst: SocketAddr, payload: &[u8]) { - self.gso_queue.enqueue(src, dst, payload, Instant::now()) + pub fn send_network( + &mut self, + src: Option, + dst: SocketAddr, + payload: &[u8], + ecn: Ecn, + ) { + self.gso_queue + .enqueue(src, dst, payload, ecn, Instant::now()) } pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) { diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index ac884e1bb..a051a4417 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -6,6 +6,7 @@ use std::{ }; use bytes::BytesMut; +use ip_packet::Ecn; use socket_factory::DatagramOut; use super::MAX_INBOUND_PACKET_BATCH; @@ -55,6 +56,7 @@ impl GsoQueue { src: Option, dst: SocketAddr, payload: &[u8], + ecn: Ecn, now: Instant, ) { let segment_size = payload.len(); @@ -74,6 +76,7 @@ impl GsoQueue { .or_insert_with(|| DatagramBuffer { inner: None, last_access: now, + ecn, }); buffer @@ -81,6 +84,7 @@ impl GsoQueue { .get_or_insert_with(|| self.buffer_pool.pull_owned()) .extend_from_slice(payload); buffer.last_access = now; + buffer.ecn = ecn; } pub fn datagrams( @@ -88,6 +92,7 @@ impl GsoQueue { ) -> impl Iterator>> + '_ { self.inner.iter_mut().filter_map(|(key, buffer)| { + let ecn = buffer.ecn; // It is really important that we `take` the buffer here, otherwise it is not returned to the pool after. let buffer = buffer.inner.take()?; @@ -100,6 +105,7 @@ impl GsoQueue { dst: key.dst, packet: buffer, segment_size: Some(key.segment_size), + ecn, }) }) } @@ -119,6 +125,7 @@ struct Key { struct DatagramBuffer { inner: Option>, last_access: Instant, + ecn: Ecn, } #[cfg(test)] @@ -132,7 +139,7 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); for _entry in send_queue.datagrams() {} send_queue.handle_timeout(now + Duration::from_secs(60)); @@ -145,7 +152,7 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); send_queue.handle_timeout(now + Duration::from_secs(60)); @@ -157,7 +164,7 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); let datagrams = send_queue.datagrams(); drop(datagrams); @@ -174,8 +181,8 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); - send_queue.enqueue(None, DST_2, b"bar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); + send_queue.enqueue(None, DST_2, b"bar", Ecn::NonEct, now); // Taking it from the iterator is "sending" ... let _datagrams = send_queue.datagrams().collect::>(); diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 3bef2cfb6..6725c9111 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -12,6 +12,7 @@ use connlib_model::{ClientId, GatewayId, PublicKey, ResourceId, ResourceView}; use dns_types::DomainName; use io::{Buffers, Io}; use ip_network::{Ipv4Network, Ipv6Network}; +use ip_packet::Ecn; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::BTreeSet, @@ -142,7 +143,8 @@ impl ClientTunnel { } if let Some(trans) = self.role_state.poll_transmit() { - self.io.send_network(trans.src, trans.dst, &trans.payload); + self.io + .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); continue; } @@ -164,13 +166,15 @@ impl ClientTunnel { let now = Instant::now(); for packet in packets { + let ecn = packet.ecn(); + let Some(packet) = self.role_state.handle_tun_input(packet, now) else { self.role_state.handle_timeout(now); continue; }; self.io - .send_network(packet.src(), packet.dst(), packet.payload()); + .send_network(packet.src(), packet.dst(), packet.payload(), ecn); } continue; @@ -189,7 +193,7 @@ impl ClientTunnel { continue; }; - self.io.send_tun(packet); + self.io.send_tun(packet.with_ecn(received.ecn)); } continue; @@ -244,7 +248,8 @@ impl GatewayTunnel { } if let Some(trans) = self.role_state.poll_transmit() { - self.io.send_network(trans.src, trans.dst, &trans.payload); + self.io + .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); continue; } @@ -279,6 +284,8 @@ impl GatewayTunnel { let now = Instant::now(); for packet in packets { + let ecn = packet.ecn(); + let Some(packet) = self .role_state .handle_tun_input(packet, now) @@ -289,7 +296,7 @@ impl GatewayTunnel { }; self.io - .send_network(packet.src(), packet.dst(), packet.payload()); + .send_network(packet.src(), packet.dst(), packet.payload(), ecn); } continue; @@ -313,7 +320,7 @@ impl GatewayTunnel { continue; }; - self.io.send_tun(packet); + self.io.send_tun(packet.with_ecn(received.ecn)); } continue; diff --git a/rust/etherparse-ext/src/ipv4_header_slice_mut.rs b/rust/etherparse-ext/src/ipv4_header_slice_mut.rs index 0808d580d..9a2fb32f5 100644 --- a/rust/etherparse-ext/src/ipv4_header_slice_mut.rs +++ b/rust/etherparse-ext/src/ipv4_header_slice_mut.rs @@ -43,4 +43,12 @@ impl<'a> Ipv4HeaderSliceMut<'a> { // Safety: Slice it at least of length 20 as checked in the ctor. unsafe { write_to_offset_unchecked(self.slice, 2, len) }; } + + pub fn set_ecn(&mut self, ecn: u8) { + let current = self.slice[1]; + let new = current & 0b1111_1100 | ecn; + + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 1, [new]) }; + } } diff --git a/rust/etherparse-ext/src/ipv6_header_slice_mut.rs b/rust/etherparse-ext/src/ipv6_header_slice_mut.rs index 40ff9c84d..5b31619c3 100644 --- a/rust/etherparse-ext/src/ipv6_header_slice_mut.rs +++ b/rust/etherparse-ext/src/ipv6_header_slice_mut.rs @@ -24,4 +24,27 @@ impl<'a> Ipv6HeaderSliceMut<'a> { // Safety: Slice it at least of length 40 as checked in the ctor. unsafe { write_to_offset_unchecked(self.slice, 24, dst) }; } + + /// Sets the ECN bits in the IPv6 header. + /// + /// Doing this is a bit trickier than for IPv4 due to the layout of the IPv6 header: + /// + /// ```text + /// 0 8 16 24 32 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |Version| Traffic Class | Flow Label | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + /// + /// The Traffic Class field (of which the lower two bits are used for ECN) is split across + /// two bytes. Thus, to set the ECN bits, we actually need to set bit 3 & 4 of the second byte. + pub fn set_ecn(&mut self, ecn: u8) { + let mask = 0b1100_1111; // Mask to clear the ecn bits. + let ecn = ecn << 4; // Shift the ecn bits to the correct position (so they fit the mask above). + + let second_byte = self.slice[1]; + let new = second_byte & mask | ecn; + + unsafe { write_to_offset_unchecked(self.slice, 1, [new]) }; + } } diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 807c574a1..999b99885 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -179,6 +179,19 @@ impl std::fmt::Debug for IpPacket { .field("dst_port", &udp.destination_port()); } + match self.ecn() { + Ecn::NonEct => {} + Ecn::Ect1 => { + dbg.field("ecn", &"ECT(1)"); + } + Ecn::Ect0 => { + dbg.field("ecn", &"ECT(0)"); + } + Ecn::Ce => { + dbg.field("ecn", &"CE"); + } + }; + dbg.finish() } } @@ -844,6 +857,30 @@ impl IpPacket { } } + pub fn with_ecn(mut self, ecn: Ecn) -> Self { + match &mut self { + IpPacket::Ipv4(ip) => ip.ip_header_mut().set_ecn(ecn as u8), + IpPacket::Ipv6(ip) => ip.header_mut().set_ecn(ecn as u8), + } + + self + } + + pub fn ecn(&self) -> Ecn { + let byte = match self { + IpPacket::Ipv4(ip) => ip.ip_header().ecn().value(), + IpPacket::Ipv6(ip) => ip.header().traffic_class(), + }; + + match byte & 0b00000011 { + 0b00000000 => Ecn::NonEct, + 0b00000001 => Ecn::Ect1, + 0b00000010 => Ecn::Ect0, + 0b00000011 => Ecn::Ce, + _ => unreachable!(), + } + } + pub fn ipv4_header(&self) -> Option { match self { Self::Ipv4(p) => Some(p.ip_header().to_header()), @@ -988,6 +1025,17 @@ fn extract_l4_proto(payload: &[u8], protocol: IpNumber) -> Result for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Ecn { + NonEct = 0b00, + Ect1 = 0b01, + Ect0 = 0b10, + Ce = 0b11, +} + impl From for IpPacket { fn from(value: ConvertibleIpv4Packet) -> Self { Self::Ipv4(value) @@ -1034,4 +1082,26 @@ mod tests { assert_eq!(udp_payload, b"foobar"); } + + #[test] + fn ipv4_ecn() { + let p = crate::make::udp_packet(Ipv4Addr::LOCALHOST, Ipv4Addr::LOCALHOST, 0, 0, vec![]) + .unwrap(); + + assert_eq!(p.clone().with_ecn(Ecn::NonEct).ecn(), Ecn::NonEct); + assert_eq!(p.clone().with_ecn(Ecn::Ect0).ecn(), Ecn::Ect0); + assert_eq!(p.clone().with_ecn(Ecn::Ect1).ecn(), Ecn::Ect1); + assert_eq!(p.with_ecn(Ecn::Ce).ecn(), Ecn::Ce); + } + + #[test] + fn ipv6_ecn() { + let p = crate::make::udp_packet(Ipv6Addr::LOCALHOST, Ipv6Addr::LOCALHOST, 0, 0, vec![]) + .unwrap(); + + assert_eq!(p.clone().with_ecn(Ecn::NonEct).ecn(), Ecn::NonEct); + assert_eq!(p.clone().with_ecn(Ecn::Ect1).ecn(), Ecn::Ect1); + assert_eq!(p.clone().with_ecn(Ecn::Ect0).ecn(), Ecn::Ect0); + assert_eq!(p.with_ecn(Ecn::Ce).ecn(), Ecn::Ce); + } } diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index e843b6ad3..8def03dfc 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -7,6 +7,7 @@ license = { workspace = true } [dependencies] bytes = { workspace = true } firezone-logging = { workspace = true } +ip-packet = { workspace = true } quinn-udp = { workspace = true } socket2 = { workspace = true } tokio = { workspace = true, features = ["net"] } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index b30f13c2f..6bef6e26d 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,5 +1,6 @@ use bytes::Buf as _; use firezone_logging::err_with_src; +use ip_packet::Ecn; use quinn_udp::Transmit; use std::collections::HashMap; use std::fmt; @@ -198,6 +199,7 @@ pub struct DatagramIn<'a> { pub local: SocketAddr, pub from: SocketAddr, pub packet: &'a [u8], + pub ecn: Ecn, } /// An outbound UDP datagram. @@ -206,6 +208,7 @@ pub struct DatagramOut { pub dst: SocketAddr, pub packet: B, pub segment_size: Option, + pub ecn: Ecn, } impl UdpSocket { @@ -257,7 +260,7 @@ impl UdpSocket { let num_packets = meta.len / segment_size; let trailing_bytes = meta.len % segment_size; - tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, %num_packets, %segment_size, %trailing_bytes); + tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes); let iter = buffer[..meta.len] .chunks(meta.stride) @@ -265,6 +268,12 @@ impl UdpSocket { local, from: meta.addr, packet, + ecn: match meta.ecn { + Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce, + Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0, + Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1, + None => Ecn::NonEct, + }, }); return Poll::Ready(Ok(iter)); @@ -285,6 +294,7 @@ impl UdpSocket { datagram.src.map(|s| s.ip()), datagram.packet.deref().chunk(), datagram.segment_size, + datagram.ecn, )? else { return Ok(()); @@ -305,7 +315,7 @@ impl UdpSocket { { let num_packets = transmit.contents.len() / segment_size; - tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, %num_packets, %segment_size); + tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, ecn = ?transmit.ecn, %num_packets, %segment_size); self.inner.try_io(Interest::WRITABLE, || { self.state.try_send((&self.inner).into(), &transmit) @@ -315,7 +325,7 @@ impl UdpSocket { None => { let num_bytes = transmit.contents.len(); - tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, %num_bytes); + tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, ecn = ?transmit.ecn, %num_bytes); self.inner.try_io(Interest::WRITABLE, || { self.state.try_send((&self.inner).into(), &transmit) @@ -343,7 +353,7 @@ impl UdpSocket { payload: &[u8], ) -> io::Result> { let transmit = self - .prepare_transmit(dst, None, payload, None)? + .prepare_transmit(dst, None, payload, None, Ecn::NonEct)? .ok_or_else(|| io::Error::other("Failed to prepare `Transmit`"))?; self.inner @@ -373,6 +383,7 @@ impl UdpSocket { src_ip: Option, packet: &'a [u8], segment_size: Option, + ecn: Ecn, ) -> io::Result>> { let src_ip = match src_ip { Some(src_ip) => Some(src_ip), @@ -390,7 +401,12 @@ impl UdpSocket { let transmit = quinn_udp::Transmit { destination: dst, - ecn: None, + ecn: match ecn { + Ecn::NonEct => None, + Ecn::Ect1 => Some(quinn_udp::EcnCodepoint::Ect1), + Ecn::Ect0 => Some(quinn_udp::EcnCodepoint::Ect0), + Ecn::Ce => Some(quinn_udp::EcnCodepoint::Ce), + }, contents: packet, segment_size, src_ip,