diff --git a/.github/codespellrc b/.github/codespellrc index ecedbc568..ec270a8d6 100644 --- a/.github/codespellrc +++ b/.github/codespellrc @@ -1,3 +1,3 @@ [codespell] skip = ./elixir/apps/domain/lib/domain/name_generator.ex,./**/*.svg,./elixir/deps,./**/*.min.js,./kotlin/android/app/build,./kotlin/android/build,./e2e/pnpm-lock.yaml,./website/.next,./website/pnpm-lock.yaml,./rust/connlib/tunnel/testcases,./rust/gui-client/dist,./rust/target,Cargo.lock,./website/docs/reference/api/*.mdx,./**/erl_crash.dump,./cover,./vendor,*.json,seeds.exs,./**/node_modules,./deps,./priv/static,./priv/plts,./**/priv/static,./.git,./_build,*.cast,./**/proptest-regressions -ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded +ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded,ect diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 06848b05a..aa2dc3705 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -6122,6 +6122,7 @@ version = "0.1.0" dependencies = [ "bytes", "firezone-logging", + "ip-packet", "quinn-udp", "socket2", "tokio", diff --git a/rust/bin-shared/tests/no_packet_loops_udp.rs b/rust/bin-shared/tests/no_packet_loops_udp.rs index e219e8d76..eea5b3fd0 100644 --- a/rust/bin-shared/tests/no_packet_loops_udp.rs +++ b/rust/bin-shared/tests/no_packet_loops_udp.rs @@ -2,6 +2,7 @@ use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory}; use ip_network::Ipv4Network; +use ip_packet::Ecn; use socket_factory::DatagramOut; use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, @@ -45,6 +46,7 @@ async fn no_packet_loops_udp() { dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, packet: &hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), segment_size: None, + ecn: Ecn::NonEct, }) .unwrap(); diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index b0411069b..eedb95f64 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -9,7 +9,7 @@ use firezone_logging::{telemetry_event, telemetry_span}; use futures::FutureExt as _; use futures_bounded::FuturesTupleSet; use gso_queue::GsoQueue; -use ip_packet::{IpPacket, MAX_FZ_PAYLOAD}; +use ip_packet::{Ecn, IpPacket, MAX_FZ_PAYLOAD}; use nameserver_set::NameserverSet; use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket}; use std::{ @@ -296,8 +296,15 @@ impl Io { } } - pub fn send_network(&mut self, src: Option, dst: SocketAddr, payload: &[u8]) { - self.gso_queue.enqueue(src, dst, payload, Instant::now()) + pub fn send_network( + &mut self, + src: Option, + dst: SocketAddr, + payload: &[u8], + ecn: Ecn, + ) { + self.gso_queue + .enqueue(src, dst, payload, ecn, Instant::now()) } pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) { diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index ac884e1bb..a051a4417 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -6,6 +6,7 @@ use std::{ }; use bytes::BytesMut; +use ip_packet::Ecn; use socket_factory::DatagramOut; use super::MAX_INBOUND_PACKET_BATCH; @@ -55,6 +56,7 @@ impl GsoQueue { src: Option, dst: SocketAddr, payload: &[u8], + ecn: Ecn, now: Instant, ) { let segment_size = payload.len(); @@ -74,6 +76,7 @@ impl GsoQueue { .or_insert_with(|| DatagramBuffer { inner: None, last_access: now, + ecn, }); buffer @@ -81,6 +84,7 @@ impl GsoQueue { .get_or_insert_with(|| self.buffer_pool.pull_owned()) .extend_from_slice(payload); buffer.last_access = now; + buffer.ecn = ecn; } pub fn datagrams( @@ -88,6 +92,7 @@ impl GsoQueue { ) -> impl Iterator>> + '_ { self.inner.iter_mut().filter_map(|(key, buffer)| { + let ecn = buffer.ecn; // It is really important that we `take` the buffer here, otherwise it is not returned to the pool after. let buffer = buffer.inner.take()?; @@ -100,6 +105,7 @@ impl GsoQueue { dst: key.dst, packet: buffer, segment_size: Some(key.segment_size), + ecn, }) }) } @@ -119,6 +125,7 @@ struct Key { struct DatagramBuffer { inner: Option>, last_access: Instant, + ecn: Ecn, } #[cfg(test)] @@ -132,7 +139,7 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); for _entry in send_queue.datagrams() {} send_queue.handle_timeout(now + Duration::from_secs(60)); @@ -145,7 +152,7 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); send_queue.handle_timeout(now + Duration::from_secs(60)); @@ -157,7 +164,7 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); let datagrams = send_queue.datagrams(); drop(datagrams); @@ -174,8 +181,8 @@ mod tests { let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST, b"foobar", now); - send_queue.enqueue(None, DST_2, b"bar", now); + send_queue.enqueue(None, DST, b"foobar", Ecn::NonEct, now); + send_queue.enqueue(None, DST_2, b"bar", Ecn::NonEct, now); // Taking it from the iterator is "sending" ... let _datagrams = send_queue.datagrams().collect::>(); diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 3bef2cfb6..6725c9111 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -12,6 +12,7 @@ use connlib_model::{ClientId, GatewayId, PublicKey, ResourceId, ResourceView}; use dns_types::DomainName; use io::{Buffers, Io}; use ip_network::{Ipv4Network, Ipv6Network}; +use ip_packet::Ecn; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::BTreeSet, @@ -142,7 +143,8 @@ impl ClientTunnel { } if let Some(trans) = self.role_state.poll_transmit() { - self.io.send_network(trans.src, trans.dst, &trans.payload); + self.io + .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); continue; } @@ -164,13 +166,15 @@ impl ClientTunnel { let now = Instant::now(); for packet in packets { + let ecn = packet.ecn(); + let Some(packet) = self.role_state.handle_tun_input(packet, now) else { self.role_state.handle_timeout(now); continue; }; self.io - .send_network(packet.src(), packet.dst(), packet.payload()); + .send_network(packet.src(), packet.dst(), packet.payload(), ecn); } continue; @@ -189,7 +193,7 @@ impl ClientTunnel { continue; }; - self.io.send_tun(packet); + self.io.send_tun(packet.with_ecn(received.ecn)); } continue; @@ -244,7 +248,8 @@ impl GatewayTunnel { } if let Some(trans) = self.role_state.poll_transmit() { - self.io.send_network(trans.src, trans.dst, &trans.payload); + self.io + .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); continue; } @@ -279,6 +284,8 @@ impl GatewayTunnel { let now = Instant::now(); for packet in packets { + let ecn = packet.ecn(); + let Some(packet) = self .role_state .handle_tun_input(packet, now) @@ -289,7 +296,7 @@ impl GatewayTunnel { }; self.io - .send_network(packet.src(), packet.dst(), packet.payload()); + .send_network(packet.src(), packet.dst(), packet.payload(), ecn); } continue; @@ -313,7 +320,7 @@ impl GatewayTunnel { continue; }; - self.io.send_tun(packet); + self.io.send_tun(packet.with_ecn(received.ecn)); } continue; diff --git a/rust/etherparse-ext/src/ipv4_header_slice_mut.rs b/rust/etherparse-ext/src/ipv4_header_slice_mut.rs index 0808d580d..9a2fb32f5 100644 --- a/rust/etherparse-ext/src/ipv4_header_slice_mut.rs +++ b/rust/etherparse-ext/src/ipv4_header_slice_mut.rs @@ -43,4 +43,12 @@ impl<'a> Ipv4HeaderSliceMut<'a> { // Safety: Slice it at least of length 20 as checked in the ctor. unsafe { write_to_offset_unchecked(self.slice, 2, len) }; } + + pub fn set_ecn(&mut self, ecn: u8) { + let current = self.slice[1]; + let new = current & 0b1111_1100 | ecn; + + // Safety: Slice it at least of length 20 as checked in the ctor. + unsafe { write_to_offset_unchecked(self.slice, 1, [new]) }; + } } diff --git a/rust/etherparse-ext/src/ipv6_header_slice_mut.rs b/rust/etherparse-ext/src/ipv6_header_slice_mut.rs index 40ff9c84d..5b31619c3 100644 --- a/rust/etherparse-ext/src/ipv6_header_slice_mut.rs +++ b/rust/etherparse-ext/src/ipv6_header_slice_mut.rs @@ -24,4 +24,27 @@ impl<'a> Ipv6HeaderSliceMut<'a> { // Safety: Slice it at least of length 40 as checked in the ctor. unsafe { write_to_offset_unchecked(self.slice, 24, dst) }; } + + /// Sets the ECN bits in the IPv6 header. + /// + /// Doing this is a bit trickier than for IPv4 due to the layout of the IPv6 header: + /// + /// ```text + /// 0 8 16 24 32 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |Version| Traffic Class | Flow Label | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + /// + /// The Traffic Class field (of which the lower two bits are used for ECN) is split across + /// two bytes. Thus, to set the ECN bits, we actually need to set bit 3 & 4 of the second byte. + pub fn set_ecn(&mut self, ecn: u8) { + let mask = 0b1100_1111; // Mask to clear the ecn bits. + let ecn = ecn << 4; // Shift the ecn bits to the correct position (so they fit the mask above). + + let second_byte = self.slice[1]; + let new = second_byte & mask | ecn; + + unsafe { write_to_offset_unchecked(self.slice, 1, [new]) }; + } } diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 807c574a1..999b99885 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -179,6 +179,19 @@ impl std::fmt::Debug for IpPacket { .field("dst_port", &udp.destination_port()); } + match self.ecn() { + Ecn::NonEct => {} + Ecn::Ect1 => { + dbg.field("ecn", &"ECT(1)"); + } + Ecn::Ect0 => { + dbg.field("ecn", &"ECT(0)"); + } + Ecn::Ce => { + dbg.field("ecn", &"CE"); + } + }; + dbg.finish() } } @@ -844,6 +857,30 @@ impl IpPacket { } } + pub fn with_ecn(mut self, ecn: Ecn) -> Self { + match &mut self { + IpPacket::Ipv4(ip) => ip.ip_header_mut().set_ecn(ecn as u8), + IpPacket::Ipv6(ip) => ip.header_mut().set_ecn(ecn as u8), + } + + self + } + + pub fn ecn(&self) -> Ecn { + let byte = match self { + IpPacket::Ipv4(ip) => ip.ip_header().ecn().value(), + IpPacket::Ipv6(ip) => ip.header().traffic_class(), + }; + + match byte & 0b00000011 { + 0b00000000 => Ecn::NonEct, + 0b00000001 => Ecn::Ect1, + 0b00000010 => Ecn::Ect0, + 0b00000011 => Ecn::Ce, + _ => unreachable!(), + } + } + pub fn ipv4_header(&self) -> Option { match self { Self::Ipv4(p) => Some(p.ip_header().to_header()), @@ -988,6 +1025,17 @@ fn extract_l4_proto(payload: &[u8], protocol: IpNumber) -> Result for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Ecn { + NonEct = 0b00, + Ect1 = 0b01, + Ect0 = 0b10, + Ce = 0b11, +} + impl From for IpPacket { fn from(value: ConvertibleIpv4Packet) -> Self { Self::Ipv4(value) @@ -1034,4 +1082,26 @@ mod tests { assert_eq!(udp_payload, b"foobar"); } + + #[test] + fn ipv4_ecn() { + let p = crate::make::udp_packet(Ipv4Addr::LOCALHOST, Ipv4Addr::LOCALHOST, 0, 0, vec![]) + .unwrap(); + + assert_eq!(p.clone().with_ecn(Ecn::NonEct).ecn(), Ecn::NonEct); + assert_eq!(p.clone().with_ecn(Ecn::Ect0).ecn(), Ecn::Ect0); + assert_eq!(p.clone().with_ecn(Ecn::Ect1).ecn(), Ecn::Ect1); + assert_eq!(p.with_ecn(Ecn::Ce).ecn(), Ecn::Ce); + } + + #[test] + fn ipv6_ecn() { + let p = crate::make::udp_packet(Ipv6Addr::LOCALHOST, Ipv6Addr::LOCALHOST, 0, 0, vec![]) + .unwrap(); + + assert_eq!(p.clone().with_ecn(Ecn::NonEct).ecn(), Ecn::NonEct); + assert_eq!(p.clone().with_ecn(Ecn::Ect1).ecn(), Ecn::Ect1); + assert_eq!(p.clone().with_ecn(Ecn::Ect0).ecn(), Ecn::Ect0); + assert_eq!(p.with_ecn(Ecn::Ce).ecn(), Ecn::Ce); + } } diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index e843b6ad3..8def03dfc 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -7,6 +7,7 @@ license = { workspace = true } [dependencies] bytes = { workspace = true } firezone-logging = { workspace = true } +ip-packet = { workspace = true } quinn-udp = { workspace = true } socket2 = { workspace = true } tokio = { workspace = true, features = ["net"] } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index b30f13c2f..6bef6e26d 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,5 +1,6 @@ use bytes::Buf as _; use firezone_logging::err_with_src; +use ip_packet::Ecn; use quinn_udp::Transmit; use std::collections::HashMap; use std::fmt; @@ -198,6 +199,7 @@ pub struct DatagramIn<'a> { pub local: SocketAddr, pub from: SocketAddr, pub packet: &'a [u8], + pub ecn: Ecn, } /// An outbound UDP datagram. @@ -206,6 +208,7 @@ pub struct DatagramOut { pub dst: SocketAddr, pub packet: B, pub segment_size: Option, + pub ecn: Ecn, } impl UdpSocket { @@ -257,7 +260,7 @@ impl UdpSocket { let num_packets = meta.len / segment_size; let trailing_bytes = meta.len % segment_size; - tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, %num_packets, %segment_size, %trailing_bytes); + tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes); let iter = buffer[..meta.len] .chunks(meta.stride) @@ -265,6 +268,12 @@ impl UdpSocket { local, from: meta.addr, packet, + ecn: match meta.ecn { + Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce, + Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0, + Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1, + None => Ecn::NonEct, + }, }); return Poll::Ready(Ok(iter)); @@ -285,6 +294,7 @@ impl UdpSocket { datagram.src.map(|s| s.ip()), datagram.packet.deref().chunk(), datagram.segment_size, + datagram.ecn, )? else { return Ok(()); @@ -305,7 +315,7 @@ impl UdpSocket { { let num_packets = transmit.contents.len() / segment_size; - tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, %num_packets, %segment_size); + tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, ecn = ?transmit.ecn, %num_packets, %segment_size); self.inner.try_io(Interest::WRITABLE, || { self.state.try_send((&self.inner).into(), &transmit) @@ -315,7 +325,7 @@ impl UdpSocket { None => { let num_bytes = transmit.contents.len(); - tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, %num_bytes); + tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, ecn = ?transmit.ecn, %num_bytes); self.inner.try_io(Interest::WRITABLE, || { self.state.try_send((&self.inner).into(), &transmit) @@ -343,7 +353,7 @@ impl UdpSocket { payload: &[u8], ) -> io::Result> { let transmit = self - .prepare_transmit(dst, None, payload, None)? + .prepare_transmit(dst, None, payload, None, Ecn::NonEct)? .ok_or_else(|| io::Error::other("Failed to prepare `Transmit`"))?; self.inner @@ -373,6 +383,7 @@ impl UdpSocket { src_ip: Option, packet: &'a [u8], segment_size: Option, + ecn: Ecn, ) -> io::Result>> { let src_ip = match src_ip { Some(src_ip) => Some(src_ip), @@ -390,7 +401,12 @@ impl UdpSocket { let transmit = quinn_udp::Transmit { destination: dst, - ecn: None, + ecn: match ecn { + Ecn::NonEct => None, + Ecn::Ect1 => Some(quinn_udp::EcnCodepoint::Ect1), + Ecn::Ect0 => Some(quinn_udp::EcnCodepoint::Ect0), + Ecn::Ce => Some(quinn_udp::EcnCodepoint::Ce), + }, contents: packet, segment_size, src_ip,