feat(connlib): supply multiple buffers to UDP socket (#8733)

At present, `connlib` uses `quinn-udp`'s GRO functionality to read
multiple UDP packets within a single syscall. We are however only
passing a single buffer and a single `RecvMeta` to the `recv` function.
As a result, the function is limited to giving us only packets that
originate from one particular IP.

By supplying multiple buffers (and their according `RecvMeta`s), we can
now read packets from up to 10 different IPs at once within a single
syscall. To obtain multiple buffers, we need to split the provided
buffer into equal chunks. To ensure that each buffer can still hold
several packets, we increase the buffer size to 1MB.

It is expected that is increases throughput especially on Gateways which
receive UDP packets from many different IPs.

---------

Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Thomas Eizinger
2025-04-09 15:17:46 +10:00
committed by GitHub
parent dc92ee7251
commit 6eab29a770
2 changed files with 65 additions and 48 deletions

View File

@@ -39,8 +39,6 @@ const MAX_INBOUND_PACKET_BATCH: usize = {
}
};
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
/// Bundles together all side-effects that connlib needs to have access to.
pub struct Io {
/// The UDP sockets used to send & receive packets from the network.
@@ -78,10 +76,12 @@ pub(crate) struct Buffers {
impl Default for Buffers {
fn default() -> Self {
const ONE_MB: usize = 1024 * 1024;
Self {
ip: Vec::with_capacity(MAX_INBOUND_PACKET_BATCH),
udp4: Vec::from([0; MAX_UDP_SIZE]),
udp6: Vec::from([0; MAX_UDP_SIZE]),
udp4: vec![0; ONE_MB],
udp6: vec![0; ONE_MB],
}
}
}

View File

@@ -8,7 +8,6 @@ use std::ops::Deref;
use std::{
io::{self, IoSliceMut},
net::{IpAddr, SocketAddr},
slice,
task::{Context, Poll, ready},
};
@@ -217,66 +216,70 @@ impl UdpSocket {
buffer: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<impl Iterator<Item = DatagramIn<'b>> + fmt::Debug + use<'b>>> {
const NUM_BUFFERS: usize = 10;
let Self {
port, inner, state, ..
} = self;
let bufs = &mut [IoSliceMut::new(buffer)];
let mut meta = quinn_udp::RecvMeta::default();
let mut bufs = split_buffer_equal::<NUM_BUFFERS>(buffer).map(IoSliceMut::new);
let mut meta = std::array::from_fn::<_, NUM_BUFFERS, _>(|_| quinn_udp::RecvMeta::default());
loop {
ready!(inner.poll_recv_ready(cx))?;
if let Ok(len) = inner.try_io(Interest::READABLE, || {
state.recv((&inner).into(), bufs, slice::from_mut(&mut meta))
state.recv((&inner).into(), &mut bufs, &mut meta)
}) {
debug_assert_eq!(len, 1);
let bufs = split_buffer_equal::<NUM_BUFFERS>(buffer);
let port = *port;
if meta.len == 0 {
continue;
}
let Some(local_ip) = meta.dst_ip else {
tracing::warn!("Skipping packet without local IP");
continue;
};
match meta.stride.cmp(&meta.len) {
std::cmp::Ordering::Equal | std::cmp::Ordering::Less => {}
std::cmp::Ordering::Greater => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"stride ({}) is larger than buffer len ({})",
meta.stride, meta.len
),
)));
let datagrams = bufs.into_iter().zip(meta).take(len).flat_map(move |(buffer, meta)| {
if meta.len == 0 {
return None;
}
}
let local = SocketAddr::new(local_ip, *port);
let Some(local_ip) = meta.dst_ip else {
tracing::warn!("Skipping packet without local IP");
return None;
};
let segment_size = meta.stride;
let num_packets = meta.len / segment_size;
let trailing_bytes = meta.len % segment_size;
match meta.stride.cmp(&meta.len) {
std::cmp::Ordering::Equal | std::cmp::Ordering::Less => {}
std::cmp::Ordering::Greater => {
tracing::warn!("stride ({}) is larger than buffer len ({})", meta.stride, meta.len);
tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes);
return None;
}
}
let iter = buffer[..meta.len]
.chunks(meta.stride)
.map(move |packet| DatagramIn {
local,
from: meta.addr,
packet,
ecn: match meta.ecn {
Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce,
Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0,
Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1,
None => Ecn::NonEct,
},
});
let local = SocketAddr::new(local_ip, port);
return Poll::Ready(Ok(iter));
let segment_size = meta.stride;
let num_packets = meta.len / segment_size;
let trailing_bytes = meta.len % segment_size;
tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes);
let iter = buffer[..meta.len]
.chunks(meta.stride)
.map(move |packet| DatagramIn {
local,
from: meta.addr,
packet,
ecn: match meta.ecn {
Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce,
Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0,
Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1,
None => Ecn::NonEct,
},
});
Some(iter)
})
.flatten();
return Poll::Ready(Ok(datagrams));
}
}
}
@@ -433,3 +436,17 @@ impl UdpSocket {
Ok(Some(src))
}
}
fn split_buffer_equal<const N: usize>(s: &mut [u8]) -> [&mut [u8]; N] {
let chunk_size = s.len() / N;
let mut chunks: [&mut [u8]; N] = std::array::from_fn(|_| [].as_mut());
let mut rest = s;
for chunk in &mut chunks {
let (head, tail) = rest.split_at_mut(chunk_size);
*chunk = head;
rest = tail;
}
chunks
}