diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index 5a86cc653..ac884e1bb 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -1,6 +1,5 @@ use std::{ collections::HashMap, - mem, net::SocketAddr, sync::Arc, time::{Duration, Instant}, @@ -40,7 +39,10 @@ impl GsoQueue { pub fn handle_timeout(&mut self, now: Instant) { self.inner.retain(|_, b| { - if !b.is_empty() { + if !{ + let this = &b; + this.inner.as_ref().is_none_or(|b| b.is_empty()) + } { return true; } @@ -62,32 +64,44 @@ impl GsoQueue { "MAX_SEGMENT_SIZE is miscalculated" ); - self.inner + let buffer = self + .inner .entry(Key { src, dst, segment_size, }) .or_insert_with(|| DatagramBuffer { - inner: self.buffer_pool.pull_owned(), + inner: None, last_access: now, - }) - .extend(payload, now); + }); + + buffer + .inner + .get_or_insert_with(|| self.buffer_pool.pull_owned()) + .extend_from_slice(payload); + buffer.last_access = now; } pub fn datagrams( &mut self, ) -> impl Iterator>> + '_ { - self.inner - .iter_mut() - .filter(|(_, b)| !b.is_empty()) - .map(|(key, buffer)| DatagramOut { + self.inner.iter_mut().filter_map(|(key, buffer)| { + // It is really important that we `take` the buffer here, otherwise it is not returned to the pool after. + let buffer = buffer.inner.take()?; + + if buffer.is_empty() { + return None; + } + + Some(DatagramOut { src: key.src, dst: key.dst, - packet: mem::replace(&mut buffer.inner, self.buffer_pool.pull_owned()), + packet: buffer, segment_size: Some(key.segment_size), }) + }) } pub fn clear(&mut self) { @@ -103,21 +117,10 @@ struct Key { } struct DatagramBuffer { - inner: lockfree_object_pool::SpinLockOwnedReusable, + inner: Option>, last_access: Instant, } -impl DatagramBuffer { - pub(crate) fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - pub(crate) fn extend(&mut self, payload: &[u8], now: Instant) { - self.inner.extend_from_slice(payload); - self.last_access = now; - } -} - #[cfg(test)] mod tests { use std::net::{Ipv4Addr, SocketAddrV4}; @@ -166,5 +169,22 @@ mod tests { assert_eq!(datagrams[0].packet.as_ref(), b"foobar"); } + #[test] + fn sending_datagrams_returns_buffers_to_pool() { + let now = Instant::now(); + let mut send_queue = GsoQueue::new(); + + send_queue.enqueue(None, DST, b"foobar", now); + send_queue.enqueue(None, DST_2, b"bar", now); + + // Taking it from the iterator is "sending" ... + let _datagrams = send_queue.datagrams().collect::>(); + + for buf in send_queue.inner.values() { + assert!(buf.inner.is_none()) + } + } + const DST: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1234)); + const DST_2: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5678)); }