From 6eab29a7702fd2e370556bc3a195cd7081c63b55 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 9 Apr 2025 15:17:46 +1000 Subject: [PATCH] feat(connlib): supply multiple buffers to UDP socket (#8733) At present, `connlib` uses `quinn-udp`'s GRO functionality to read multiple UDP packets within a single syscall. We are however only passing a single buffer and a single `RecvMeta` to the `recv` function. As a result, the function is limited to giving us only packets that originate from one particular IP. By supplying multiple buffers (and their according `RecvMeta`s), we can now read packets from up to 10 different IPs at once within a single syscall. To obtain multiple buffers, we need to split the provided buffer into equal chunks. To ensure that each buffer can still hold several packets, we increase the buffer size to 1MB. It is expected that is increases throughput especially on Gateways which receive UDP packets from many different IPs. --------- Signed-off-by: Thomas Eizinger Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- rust/connlib/tunnel/src/io.rs | 8 +-- rust/socket-factory/src/lib.rs | 105 +++++++++++++++++++-------------- 2 files changed, 65 insertions(+), 48 deletions(-) diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index eedb95f64..d7cdaa55d 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -39,8 +39,6 @@ const MAX_INBOUND_PACKET_BATCH: usize = { } }; -const MAX_UDP_SIZE: usize = (1 << 16) - 1; - /// Bundles together all side-effects that connlib needs to have access to. pub struct Io { /// The UDP sockets used to send & receive packets from the network. @@ -78,10 +76,12 @@ pub(crate) struct Buffers { impl Default for Buffers { fn default() -> Self { + const ONE_MB: usize = 1024 * 1024; + Self { ip: Vec::with_capacity(MAX_INBOUND_PACKET_BATCH), - udp4: Vec::from([0; MAX_UDP_SIZE]), - udp6: Vec::from([0; MAX_UDP_SIZE]), + udp4: vec![0; ONE_MB], + udp6: vec![0; ONE_MB], } } } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index 6bef6e26d..baec5a9d4 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -8,7 +8,6 @@ use std::ops::Deref; use std::{ io::{self, IoSliceMut}, net::{IpAddr, SocketAddr}, - slice, task::{Context, Poll, ready}, }; @@ -217,66 +216,70 @@ impl UdpSocket { buffer: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll> + fmt::Debug + use<'b>>> { + const NUM_BUFFERS: usize = 10; + let Self { port, inner, state, .. } = self; - let bufs = &mut [IoSliceMut::new(buffer)]; - let mut meta = quinn_udp::RecvMeta::default(); + let mut bufs = split_buffer_equal::(buffer).map(IoSliceMut::new); + let mut meta = std::array::from_fn::<_, NUM_BUFFERS, _>(|_| quinn_udp::RecvMeta::default()); loop { ready!(inner.poll_recv_ready(cx))?; if let Ok(len) = inner.try_io(Interest::READABLE, || { - state.recv((&inner).into(), bufs, slice::from_mut(&mut meta)) + state.recv((&inner).into(), &mut bufs, &mut meta) }) { - debug_assert_eq!(len, 1); + let bufs = split_buffer_equal::(buffer); + let port = *port; - if meta.len == 0 { - continue; - } - - let Some(local_ip) = meta.dst_ip else { - tracing::warn!("Skipping packet without local IP"); - continue; - }; - - match meta.stride.cmp(&meta.len) { - std::cmp::Ordering::Equal | std::cmp::Ordering::Less => {} - std::cmp::Ordering::Greater => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "stride ({}) is larger than buffer len ({})", - meta.stride, meta.len - ), - ))); + let datagrams = bufs.into_iter().zip(meta).take(len).flat_map(move |(buffer, meta)| { + if meta.len == 0 { + return None; } - } - let local = SocketAddr::new(local_ip, *port); + let Some(local_ip) = meta.dst_ip else { + tracing::warn!("Skipping packet without local IP"); + return None; + }; - let segment_size = meta.stride; - let num_packets = meta.len / segment_size; - let trailing_bytes = meta.len % segment_size; + match meta.stride.cmp(&meta.len) { + std::cmp::Ordering::Equal | std::cmp::Ordering::Less => {} + std::cmp::Ordering::Greater => { + tracing::warn!("stride ({}) is larger than buffer len ({})", meta.stride, meta.len); - tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes); + return None; + } + } - let iter = buffer[..meta.len] - .chunks(meta.stride) - .map(move |packet| DatagramIn { - local, - from: meta.addr, - packet, - ecn: match meta.ecn { - Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce, - Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0, - Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1, - None => Ecn::NonEct, - }, - }); + let local = SocketAddr::new(local_ip, port); - return Poll::Ready(Ok(iter)); + let segment_size = meta.stride; + let num_packets = meta.len / segment_size; + let trailing_bytes = meta.len % segment_size; + + tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes); + + let iter = buffer[..meta.len] + .chunks(meta.stride) + .map(move |packet| DatagramIn { + local, + from: meta.addr, + packet, + ecn: match meta.ecn { + Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce, + Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0, + Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1, + None => Ecn::NonEct, + }, + }); + + Some(iter) + }) + .flatten(); + + return Poll::Ready(Ok(datagrams)); } } } @@ -433,3 +436,17 @@ impl UdpSocket { Ok(Some(src)) } } + +fn split_buffer_equal(s: &mut [u8]) -> [&mut [u8]; N] { + let chunk_size = s.len() / N; + let mut chunks: [&mut [u8]; N] = std::array::from_fn(|_| [].as_mut()); + + let mut rest = s; + for chunk in &mut chunks { + let (head, tail) = rest.split_at_mut(chunk_size); + *chunk = head; + rest = tail; + } + + chunks +}