diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 75f48b5f4..689ea7ffc 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2234,6 +2234,7 @@ dependencies = [ "ip_network", "ip_network_table", "itertools 0.13.0", + "lockfree-object-pool", "lru", "proptest", "proptest-state-machine", @@ -5936,6 +5937,7 @@ dependencies = [ name = "socket-factory" version = "0.1.0" dependencies = [ + "bytes", "firezone-logging", "quinn-udp", "socket2", diff --git a/rust/bin-shared/src/tun_device_manager.rs b/rust/bin-shared/src/tun_device_manager.rs index 10ede1834..185dee436 100644 --- a/rust/bin-shared/src/tun_device_manager.rs +++ b/rust/bin-shared/src/tun_device_manager.rs @@ -20,7 +20,6 @@ mod tests { use ip_network::Ipv4Network; use socket_factory::DatagramOut; use std::{ - borrow::Cow, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, time::Duration, }; @@ -101,9 +100,7 @@ mod tests { .send(DatagramOut { src: None, dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, - packet: Cow::Borrowed(&hex_literal::hex!( - "000100002112A4420123456789abcdef01234567" - )), + packet: &hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), segment_size: None, }) .unwrap(); diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 35d38671d..5a9dc30dd 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -26,6 +26,7 @@ ip-packet = { workspace = true } ip_network = { workspace = true } ip_network_table = { workspace = true } itertools = { workspace = true, features = ["use_std"] } +lockfree-object-pool = { workspace = true } lru = { workspace = true } proptest = { workspace = true, optional = true } rand = { workspace = true } diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index 2c9e0a033..8e10cefa2 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -1,7 +1,7 @@ use std::{ - borrow::Cow, - collections::BTreeMap, + collections::HashMap, net::SocketAddr, + sync::Arc, time::{Duration, Instant}, }; @@ -10,18 +10,30 @@ use socket_factory::DatagramOut; use super::MAX_INBOUND_PACKET_BATCH; +const MAX_SEGMENT_SIZE: usize = + ip_packet::MAX_IP_SIZE + ip_packet::WG_OVERHEAD + ip_packet::DATA_CHANNEL_OVERHEAD; + /// Holds UDP datagrams that we need to send, indexed by src, dst and segment size. /// /// Calling [`Io::send_network`](super::Io::send_network) will copy the provided payload into this buffer. /// The buffer is then flushed using GSO in a single syscall. pub struct GsoQueue { - inner: BTreeMap, + inner: HashMap, + buffer_pool: Arc>, } impl GsoQueue { pub fn new() -> Self { Self { inner: Default::default(), + buffer_pool: Arc::new(lockfree_object_pool::SpinLockObjectPool::new( + || { + tracing::debug!("Initialising new buffer for GSO queue"); + + BytesMut::with_capacity(MAX_SEGMENT_SIZE * MAX_INBOUND_PACKET_BATCH) + }, + |b| b.clear(), + )), } } @@ -42,13 +54,13 @@ impl GsoQueue { payload: &[u8], now: Instant, ) { + let buffer = self.buffer_pool.pull_owned(); let segment_size = payload.len(); - // At most, a single batch translates to packets all going to the same destination and length. - // Thus, to avoid a lot of re-allocations during sending, allocate enough space to store a quarter of the packets in a batch. - // Re-allocations happen by doubling the capacity, so this means we have at most 2 re-allocation. - // This number has been chosen empirically by observing how big the GSO batches typically are. - let capacity = segment_size * MAX_INBOUND_PACKET_BATCH / 4; + debug_assert!( + segment_size <= MAX_SEGMENT_SIZE, + "MAX_SEGMENT_SIZE is miscalculated" + ); self.inner .entry(Key { @@ -56,18 +68,24 @@ impl GsoQueue { dst, segment_size, }) - .or_insert_with(|| DatagramBuffer::new(now, capacity)) + .or_insert_with(|| DatagramBuffer { + inner: buffer, + last_access: now, + }) .extend(payload, now); } - pub fn datagrams(&mut self) -> impl Iterator> + '_ { + pub fn datagrams( + &mut self, + ) -> impl Iterator>> + '_ + { self.inner - .iter_mut() + .drain() .filter(|(_, b)| !b.is_empty()) .map(|(key, buffer)| DatagramOut { src: key.src, dst: key.dst, - packet: Cow::Owned(buffer.inner.split().freeze().into()), + packet: buffer.inner, segment_size: Some(key.segment_size), }) } @@ -85,18 +103,11 @@ struct Key { } struct DatagramBuffer { - inner: BytesMut, + inner: lockfree_object_pool::SpinLockOwnedReusable, last_access: Instant, } impl DatagramBuffer { - pub fn new(now: Instant, capacity: usize) -> Self { - Self { - inner: BytesMut::with_capacity(capacity), - last_access: now, - } - } - pub(crate) fn is_empty(&self) -> bool { self.inner.is_empty() } diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 82b34dcd2..f2ffb3a5b 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -2,6 +2,7 @@ use socket_factory::{DatagramIn, DatagramOut, SocketFactory, UdpSocket}; use std::{ io, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + ops::Deref, task::{ready, Context, Poll, Waker}, }; @@ -57,7 +58,10 @@ impl Sockets { Poll::Ready(Ok(())) } - pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> { + pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> + where + B: Deref, + { let socket = match datagram.dst { SocketAddr::V4(dst) => self.socket_v4.as_mut().ok_or_else(|| { io::Error::new( diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 674818a85..6c13ff1e6 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -61,11 +61,11 @@ pub const MAX_UDP_PAYLOAD: usize = MAX_IP_PAYLOAD - etherparse::UdpHeader::LEN; pub const MAX_FZ_PAYLOAD: usize = MAX_IP_SIZE + WG_OVERHEAD + NAT46_OVERHEAD + DATA_CHANNEL_OVERHEAD; /// Wireguard has a 32-byte overhead (4b message type + 4b receiver idx + 8b packet counter + 16b AEAD tag) -const WG_OVERHEAD: usize = 32; +pub const WG_OVERHEAD: usize = 32; /// In order to do NAT46 without copying, we need 20 extra byte in the buffer (IPv6 packets are 20 byte bigger than IPv4). pub(crate) const NAT46_OVERHEAD: usize = 20; /// TURN's data channels have a 4 byte overhead. -const DATA_CHANNEL_OVERHEAD: usize = 4; +pub const DATA_CHANNEL_OVERHEAD: usize = 4; macro_rules! for_both { ($this:ident, |$name:ident| $body:expr) => { diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index f79684bde..e843b6ad3 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -5,6 +5,7 @@ edition = { workspace = true } license = { workspace = true } [dependencies] +bytes = { workspace = true } firezone-logging = { workspace = true } quinn-udp = { workspace = true } socket2 = { workspace = true } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index 93957a198..a36cd8b34 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,9 +1,10 @@ +use bytes::Buf as _; use firezone_logging::std_dyn_err; use quinn_udp::Transmit; use std::collections::HashMap; use std::fmt; +use std::ops::Deref; use std::{ - borrow::Cow, io::{self, IoSliceMut}, net::{IpAddr, SocketAddr}, slice, @@ -200,10 +201,10 @@ pub struct DatagramIn<'a> { } /// An outbound UDP datagram. -pub struct DatagramOut<'a> { +pub struct DatagramOut { pub src: Option, pub dst: SocketAddr, - pub packet: Cow<'a, [u8]>, + pub packet: B, pub segment_size: Option, } @@ -275,11 +276,14 @@ impl UdpSocket { self.inner.poll_send_ready(cx) } - pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> { + pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> + where + B: Deref, + { let Some(transmit) = self.prepare_transmit( datagram.dst, datagram.src.map(|s| s.ip()), - &datagram.packet, + datagram.packet.deref().chunk(), datagram.segment_size, )? else {