diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index ca6ccff5b..6dc4ae59c 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -255,10 +255,6 @@ impl Io { self.timeout = None; // Clear the timeout. - // Piggy back onto the timeout we already have. - // It is not important when we call this, just needs to be called occasionally. - self.gso_queue.handle_timeout(now); - return Poll::Ready(Ok(Input::Timeout(now))); } } @@ -350,8 +346,7 @@ impl Io { payload: &[u8], ecn: Ecn, ) { - self.gso_queue - .enqueue(src, dst, payload, ecn, Instant::now()); + self.gso_queue.enqueue(src, dst, payload, ecn); self.packet_counter.add( 1, diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index db78a2af3..a9ea44810 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -1,8 +1,7 @@ use std::{ - collections::BTreeMap, + collections::{BTreeMap, VecDeque}, net::SocketAddr, sync::Arc, - time::{Duration, Instant}, }; use bytes::BytesMut; @@ -14,12 +13,14 @@ use super::MAX_INBOUND_PACKET_BATCH; const MAX_SEGMENT_SIZE: usize = ip_packet::MAX_IP_SIZE + ip_packet::WG_OVERHEAD + ip_packet::DATA_CHANNEL_OVERHEAD; +type Buffer = lockfree_object_pool::SpinLockOwnedReusable; + /// Holds UDP datagrams that we need to send, indexed by src, dst and segment size. /// /// Calling [`Io::send_network`](super::Io::send_network) will copy the provided payload into this buffer. /// The buffer is then flushed using GSO in a single syscall. pub struct GsoQueue { - inner: BTreeMap, + inner: BTreeMap>, buffer_pool: Arc>, } @@ -38,47 +39,38 @@ impl GsoQueue { } } - pub fn handle_timeout(&mut self, now: Instant) { - self.inner.retain(|_, b| { - if !b.inner.is_empty() { - return true; - } - - now.duration_since(b.last_access) < Duration::from_secs(60) - }) - } - - pub fn enqueue( - &mut self, - src: Option, - dst: SocketAddr, - payload: &[u8], - ecn: Ecn, - now: Instant, - ) { - let segment_size = payload.len(); + pub fn enqueue(&mut self, src: Option, dst: SocketAddr, payload: &[u8], ecn: Ecn) { + let payload_len = payload.len(); debug_assert!( - segment_size <= MAX_SEGMENT_SIZE, + payload_len <= MAX_SEGMENT_SIZE, "MAX_SEGMENT_SIZE is miscalculated" ); - let buffer = self - .inner - .entry(Key { - src, - dst, - segment_size, - }) - .or_insert_with(|| DatagramBuffer { - inner: self.buffer_pool.pull_owned(), - last_access: now, - ecn, - }); + let batches = self.inner.entry(Connection { src, dst, ecn }).or_default(); - buffer.inner.extend_from_slice(payload); - buffer.last_access = now; - buffer.ecn = ecn; + let Some((batch_size, buffer)) = batches.back_mut() else { + let mut buffer = self.buffer_pool.pull_owned(); + buffer.extend_from_slice(payload); + + batches.push_back((payload_len, buffer)); + + return; + }; + let batch_size = *batch_size; + + // A batch is considered "ongoing" if so far we have only pushed packets of the same length. + let batch_is_ongoing = buffer.len() % batch_size == 0; + + if batch_is_ongoing && payload_len <= batch_size { + buffer.extend_from_slice(payload); + return; + } + + let mut buffer = self.buffer_pool.pull_owned(); + buffer.extend_from_slice(payload); + + batches.push_back((payload_len, buffer)); } pub fn datagrams( @@ -94,15 +86,9 @@ impl GsoQueue { } #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)] -struct Key { - segment_size: usize, // `segment_size` comes first to ensure that the datagrams are flushed to the socket in descending order. +struct Connection { src: Option, dst: SocketAddr, -} - -struct DatagramBuffer { - inner: lockfree_object_pool::SpinLockOwnedReusable, - last_access: Instant, ecn: Ecn, } @@ -115,15 +101,24 @@ impl Iterator for DrainDatagramsIter<'_> { type Item = DatagramOut>; fn next(&mut self) -> Option { - let (key, buffer) = self.queue.inner.pop_last()?; + loop { + let mut entry = self.queue.inner.first_entry()?; - Some(DatagramOut { - src: key.src, - dst: key.dst, - packet: buffer.inner, - segment_size: Some(key.segment_size), - ecn: buffer.ecn, - }) + let connection = *entry.key(); + + let Some((segment_size, buffer)) = entry.get_mut().pop_front() else { + entry.remove(); + continue; + }; + + return Some(DatagramOut { + src: connection.src, + dst: connection.dst, + packet: buffer, + segment_size: Some(segment_size), + ecn: connection.ecn, + }); + } } } @@ -133,37 +128,11 @@ mod tests { use super::*; - #[test] - fn send_queue_gcs_after_1_minute() { - let now = Instant::now(); - let mut send_queue = GsoQueue::new(); - - send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct, now); - for _entry in send_queue.datagrams() {} - - send_queue.handle_timeout(now + Duration::from_secs(60)); - - assert_eq!(send_queue.inner.len(), 0); - } - - #[test] - fn does_not_gc_unsent_items() { - let now = Instant::now(); - let mut send_queue = GsoQueue::new(); - - send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct, now); - - send_queue.handle_timeout(now + Duration::from_secs(60)); - - assert_eq!(send_queue.inner.len(), 1); - } - #[test] fn dropping_datagram_iterator_does_not_drop_items() { - let now = Instant::now(); let mut send_queue = GsoQueue::new(); - send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct, now); + send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct); let datagrams = send_queue.datagrams(); drop(datagrams); @@ -176,34 +145,87 @@ mod tests { } #[test] - fn prioritises_large_packets() { - let now = Instant::now(); + fn appends_items_of_same_batch() { let mut send_queue = GsoQueue::new(); - send_queue.enqueue( - None, - DST_1, - b"foobarfoobarfoobarfoobarfoobarfoobarfoobarfoobar", - Ecn::NonEct, - now, - ); - send_queue.enqueue(None, DST_2, b"barbaz", Ecn::NonEct, now); - send_queue.enqueue(None, DST_3, b"barbaz1234", Ecn::NonEct, now); - send_queue.enqueue(None, DST_4, b"b", Ecn::NonEct, now); - send_queue.enqueue(None, DST_5, b"barbazfoobafoobarfoobar", Ecn::NonEct, now); - send_queue.enqueue(None, DST_2, b"baz", Ecn::NonEct, now); + send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"barbaz", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"foobaz", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"foo", Ecn::NonEct); let datagrams = send_queue.datagrams().collect::>(); - let is_sorted = datagrams.is_sorted_by(|a, b| a.segment_size >= b.segment_size); + assert_eq!(datagrams.len(), 1); + assert_eq!(datagrams[0].packet.as_ref(), b"foobarbarbazfoobazfoo"); + assert_eq!(datagrams[0].segment_size, Some(6)); + } - assert!(is_sorted); - assert_eq!(datagrams[0].segment_size, Some(48)); + #[test] + fn starts_new_batch_for_new_dst() { + let mut send_queue = GsoQueue::new(); + + send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"barbaz", Ecn::NonEct); + + send_queue.enqueue(None, DST_2, b"barbarba", Ecn::NonEct); + send_queue.enqueue(None, DST_2, b"foofoo", Ecn::NonEct); + + let datagrams = send_queue.datagrams().collect::>(); + + assert_eq!(datagrams.len(), 2); + assert_eq!(datagrams[0].packet.as_ref(), b"foobarbarbaz"); + assert_eq!(datagrams[0].segment_size, Some(6)); + assert_eq!(datagrams[0].dst, DST_1); + assert_eq!(datagrams[1].packet.as_ref(), b"barbarbafoofoo"); + assert_eq!(datagrams[1].segment_size, Some(8)); + assert_eq!(datagrams[1].dst, DST_2); + } + + #[test] + fn continues_batch_for_old_dst() { + let mut send_queue = GsoQueue::new(); + + send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"barbaz", Ecn::NonEct); + + send_queue.enqueue(None, DST_2, b"barbarba", Ecn::NonEct); + send_queue.enqueue(None, DST_2, b"foofoo", Ecn::NonEct); + + send_queue.enqueue(None, DST_1, b"foobaz", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"bazfoo", Ecn::NonEct); + + let datagrams = send_queue.datagrams().collect::>(); + + assert_eq!(datagrams.len(), 2); + assert_eq!(datagrams[0].packet.as_ref(), b"foobarbarbazfoobazbazfoo"); + assert_eq!(datagrams[0].segment_size, Some(6)); + assert_eq!(datagrams[0].dst, DST_1); + assert_eq!(datagrams[1].packet.as_ref(), b"barbarbafoofoo"); + assert_eq!(datagrams[1].segment_size, Some(8)); + assert_eq!(datagrams[1].dst, DST_2); + } + + #[test] + fn starts_new_batch_after_single_item_less_than_segment_length() { + let mut send_queue = GsoQueue::new(); + + send_queue.enqueue(None, DST_1, b"foobar", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"barbaz", Ecn::NonEct); + send_queue.enqueue(None, DST_1, b"bar", Ecn::NonEct); + + send_queue.enqueue(None, DST_1, b"barbaz", Ecn::NonEct); + + let datagrams = send_queue.datagrams().collect::>(); + + assert_eq!(datagrams.len(), 2); + assert_eq!(datagrams[0].packet.as_ref(), b"foobarbarbazbar"); + assert_eq!(datagrams[0].segment_size, Some(6)); + assert_eq!(datagrams[0].dst, DST_1); + assert_eq!(datagrams[1].packet.as_ref(), b"barbaz"); + assert_eq!(datagrams[1].segment_size, Some(6)); + assert_eq!(datagrams[1].dst, DST_1); } const DST_1: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1111)); const DST_2: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2222)); - const DST_3: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3333)); - const DST_4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 4444)); - const DST_5: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5555)); } diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index b2c8cc4f6..b8c984f43 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -1046,7 +1046,7 @@ fn extract_l4_proto(payload: &[u8], protocol: IpNumber) -> Result for details. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum Ecn { NonEct = 0b00, Ect1 = 0b01,