diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 52b91bcf2..b1c4f1f8f 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -7,6 +7,7 @@ use bufferpool::BufferPool; use bytecodec::{DecodeExt as _, EncodeExt as _}; use firezone_logging::err_with_src; use hex_display::HexDisplayExt as _; +use ip_packet::Ecn; use rand::random; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use std::{ @@ -1138,6 +1139,7 @@ impl Allocation { src: None, dst, payload: self.buffer_pool.pull_initialised(&encode(message)), + ecn: Ecn::NonEct, }); true diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index beef650a6..48ffc653c 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -12,7 +12,7 @@ use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use bufferpool::{Buffer, BufferPool}; use core::fmt; use hex_display::HexDisplayExt; -use ip_packet::{IpPacket, IpPacketBuf}; +use ip_packet::{Ecn, IpPacket, IpPacketBuf}; use itertools::Itertools; use rand::rngs::StdRng; use rand::seq::IteratorRandom; @@ -1691,6 +1691,8 @@ pub struct Transmit { pub dst: SocketAddr, /// The data that should be sent. pub payload: Buffer>, + /// The ECN bits to set for the UDP packet. + pub ecn: Ecn, } impl fmt::Debug for Transmit { @@ -2252,6 +2254,7 @@ where src: Some(source), dst, payload: self.buffer_pool.pull_initialised(&Vec::from(stun_packet)), + ecn: Ecn::NonEct, }); continue; }; @@ -2273,6 +2276,7 @@ where src: None, dst: encode_ok.socket, payload: self.buffer_pool.pull_initialised(&data_channel_packet), + ecn: Ecn::NonEct, }); } } @@ -2366,6 +2370,7 @@ where src: Some(source), dst: remote, payload: buffer, + ecn: packet.ecn(), })), PeerSocket::RelayToPeer { dest: peer } | PeerSocket::RelayToRelay { dest: peer } => { let Some(allocation) = allocations.get_mut(&self.relay.id) else { @@ -2384,6 +2389,7 @@ where src: None, dst: encode_ok.socket, payload: buffer, + ecn: packet.ecn(), })) } } @@ -2594,6 +2600,7 @@ where src: Some(source), dst: remote, payload: buffer_pool.pull_initialised(message), + ecn: Ecn::NonEct, }, PeerSocket::RelayToPeer { dest: peer } | PeerSocket::RelayToRelay { dest: peer } => { let allocation = allocations.get_mut(&relay)?; @@ -2605,6 +2612,7 @@ where src: None, dst: encode_ok.socket, payload: buffer_pool.pull_initialised(&channel_data), + ecn: Ecn::NonEct, } } }; diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index d2eba066b..1586cae24 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -14,7 +14,6 @@ use futures::{FutureExt, future::BoxFuture}; use gat_lending_iterator::LendingIterator; use io::{Buffers, Io}; use ip_network::{Ipv4Network, Ipv6Network}; -use ip_packet::Ecn; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::BTreeSet, @@ -148,7 +147,7 @@ impl ClientTunnel { // Drain all UDP packets that need to be sent. while let Some(trans) = self.role_state.poll_transmit() { self.io - .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); + .send_network(trans.src, trans.dst, &trans.payload, trans.ecn); } // Return a future that "owns" our IO, polling it until all packets have been flushed. @@ -185,7 +184,7 @@ impl ClientTunnel { // Drain all buffered transmits. while let Some(trans) = self.role_state.poll_transmit() { self.io - .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); + .send_network(trans.src, trans.dst, &trans.payload, trans.ecn); ready = true; } @@ -222,15 +221,13 @@ impl ClientTunnel { if let Some(packets) = device { for packet in packets { - let ecn = packet.ecn(); - match self.role_state.handle_tun_input(packet, now) { Some(transmit) => { self.io.send_network( transmit.src, transmit.dst, &transmit.payload, - ecn, + transmit.ecn, ); } None => { @@ -321,7 +318,7 @@ impl GatewayTunnel { // Drain all UDP packets that need to be sent. while let Some(trans) = self.role_state.poll_transmit() { self.io - .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); + .send_network(trans.src, trans.dst, &trans.payload, trans.ecn); } // Return a future that "owns" our IO, polling it until all packets have been flushed. @@ -352,7 +349,7 @@ impl GatewayTunnel { // Drain all buffered transmits. while let Some(trans) = self.role_state.poll_transmit() { self.io - .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct); + .send_network(trans.src, trans.dst, &trans.payload, trans.ecn); ready = true; } @@ -400,15 +397,13 @@ impl GatewayTunnel { if let Some(packets) = device { for packet in packets { - let ecn = packet.ecn(); - match self.role_state.handle_tun_input(packet, now) { Ok(Some(transmit)) => { self.io.send_network( transmit.src, transmit.dst, &transmit.payload, - ecn, + transmit.ecn, ); } Ok(None) => { diff --git a/rust/connlib/tunnel/src/tests/sim_relay.rs b/rust/connlib/tunnel/src/tests/sim_relay.rs index b56bc7595..8d645902a 100644 --- a/rust/connlib/tunnel/src/tests/sim_relay.rs +++ b/rust/connlib/tunnel/src/tests/sim_relay.rs @@ -5,6 +5,7 @@ use super::{ use bufferpool::Buffer; use connlib_model::RelayId; use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket}; +use ip_packet::Ecn; use proptest::prelude::*; use rand::{SeedableRng as _, rngs::StdRng}; use secrecy::SecretString; @@ -151,6 +152,7 @@ impl SimRelay { src: Some(src), dst, payload, + ecn: Ecn::NonEct, }) } @@ -176,6 +178,7 @@ impl SimRelay { src: Some(sending_socket), dst: receiving_socket, payload, + ecn: Ecn::NonEct, }) } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 9b9c66cd7..58ebc3e6a 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -19,6 +19,7 @@ use bufferpool::BufferPool; use connlib_model::{ClientId, GatewayId, PublicKey, RelayId}; use dns_types::ResponseCode; use dns_types::prelude::*; +use ip_packet::Ecn; use rand::SeedableRng; use rand::distributions::DistString; use sha2::Digest; @@ -529,6 +530,7 @@ impl TunnelTest { src: Some(src), dst, payload: self.buffer_pool.pull_initialised(&payload), + ecn: Ecn::NonEct, }, relay, now,