diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index eedb95f64..d7cdaa55d 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -39,8 +39,6 @@ const MAX_INBOUND_PACKET_BATCH: usize = { } }; -const MAX_UDP_SIZE: usize = (1 << 16) - 1; - /// Bundles together all side-effects that connlib needs to have access to. pub struct Io { /// The UDP sockets used to send & receive packets from the network. @@ -78,10 +76,12 @@ pub(crate) struct Buffers { impl Default for Buffers { fn default() -> Self { + const ONE_MB: usize = 1024 * 1024; + Self { ip: Vec::with_capacity(MAX_INBOUND_PACKET_BATCH), - udp4: Vec::from([0; MAX_UDP_SIZE]), - udp6: Vec::from([0; MAX_UDP_SIZE]), + udp4: vec![0; ONE_MB], + udp6: vec![0; ONE_MB], } } } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index 6bef6e26d..baec5a9d4 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -8,7 +8,6 @@ use std::ops::Deref; use std::{ io::{self, IoSliceMut}, net::{IpAddr, SocketAddr}, - slice, task::{Context, Poll, ready}, }; @@ -217,66 +216,70 @@ impl UdpSocket { buffer: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll> + fmt::Debug + use<'b>>> { + const NUM_BUFFERS: usize = 10; + let Self { port, inner, state, .. } = self; - let bufs = &mut [IoSliceMut::new(buffer)]; - let mut meta = quinn_udp::RecvMeta::default(); + let mut bufs = split_buffer_equal::(buffer).map(IoSliceMut::new); + let mut meta = std::array::from_fn::<_, NUM_BUFFERS, _>(|_| quinn_udp::RecvMeta::default()); loop { ready!(inner.poll_recv_ready(cx))?; if let Ok(len) = inner.try_io(Interest::READABLE, || { - state.recv((&inner).into(), bufs, slice::from_mut(&mut meta)) + state.recv((&inner).into(), &mut bufs, &mut meta) }) { - debug_assert_eq!(len, 1); + let bufs = split_buffer_equal::(buffer); + let port = *port; - if meta.len == 0 { - continue; - } - - let Some(local_ip) = meta.dst_ip else { - tracing::warn!("Skipping packet without local IP"); - continue; - }; - - match meta.stride.cmp(&meta.len) { - std::cmp::Ordering::Equal | std::cmp::Ordering::Less => {} - std::cmp::Ordering::Greater => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "stride ({}) is larger than buffer len ({})", - meta.stride, meta.len - ), - ))); + let datagrams = bufs.into_iter().zip(meta).take(len).flat_map(move |(buffer, meta)| { + if meta.len == 0 { + return None; } - } - let local = SocketAddr::new(local_ip, *port); + let Some(local_ip) = meta.dst_ip else { + tracing::warn!("Skipping packet without local IP"); + return None; + }; - let segment_size = meta.stride; - let num_packets = meta.len / segment_size; - let trailing_bytes = meta.len % segment_size; + match meta.stride.cmp(&meta.len) { + std::cmp::Ordering::Equal | std::cmp::Ordering::Less => {} + std::cmp::Ordering::Greater => { + tracing::warn!("stride ({}) is larger than buffer len ({})", meta.stride, meta.len); - tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes); + return None; + } + } - let iter = buffer[..meta.len] - .chunks(meta.stride) - .map(move |packet| DatagramIn { - local, - from: meta.addr, - packet, - ecn: match meta.ecn { - Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce, - Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0, - Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1, - None => Ecn::NonEct, - }, - }); + let local = SocketAddr::new(local_ip, port); - return Poll::Ready(Ok(iter)); + let segment_size = meta.stride; + let num_packets = meta.len / segment_size; + let trailing_bytes = meta.len % segment_size; + + tracing::trace!(target: "wire::net::recv", src = %meta.addr, dst = %local, ecn = ?meta.ecn, %num_packets, %segment_size, %trailing_bytes); + + let iter = buffer[..meta.len] + .chunks(meta.stride) + .map(move |packet| DatagramIn { + local, + from: meta.addr, + packet, + ecn: match meta.ecn { + Some(quinn_udp::EcnCodepoint::Ce) => Ecn::Ce, + Some(quinn_udp::EcnCodepoint::Ect0) => Ecn::Ect0, + Some(quinn_udp::EcnCodepoint::Ect1) => Ecn::Ect1, + None => Ecn::NonEct, + }, + }); + + Some(iter) + }) + .flatten(); + + return Poll::Ready(Ok(datagrams)); } } } @@ -433,3 +436,17 @@ impl UdpSocket { Ok(Some(src)) } } + +fn split_buffer_equal(s: &mut [u8]) -> [&mut [u8]; N] { + let chunk_size = s.len() / N; + let mut chunks: [&mut [u8]; N] = std::array::from_fn(|_| [].as_mut()); + + let mut rest = s; + for chunk in &mut chunks { + let (head, tail) = rest.split_at_mut(chunk_size); + *chunk = head; + rest = tail; + } + + chunks +}