diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 46d4bd6ff..b00521464 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -3182,6 +3182,7 @@ version = "0.1.0" dependencies = [ "anyhow", "etherparse", + "lockfree-object-pool", "proptest", "test-strategy", "thiserror", @@ -5907,6 +5908,7 @@ dependencies = [ "hex-display", "ip-packet", "itertools 0.13.0", + "lockfree-object-pool", "once_cell", "rand 0.8.5", "ringbuffer", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index f8c397b7f..bdb187b08 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -109,6 +109,7 @@ subprocess = "0.2.9" subtle = "2.5.0" swift-bridge = "0.1.57" swift-bridge-build = "0.1.57" +lockfree-object-pool = "0.1.6" tauri = "2.0.3" tauri-build = "2.0.1" tauri-plugin-dialog = "2.0.1" diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index 3ce758d5f..a7f53d379 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -14,6 +14,7 @@ hex = { workspace = true } hex-display = { workspace = true } ip-packet = { workspace = true } itertools = { workspace = true } +lockfree-object-pool = { workspace = true } once_cell = { workspace = true } rand = { workspace = true } ringbuffer = { workspace = true } diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 62ade3273..99c7ba56c 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -8,7 +8,6 @@ use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; use firezone_logging::{err_with_src, std_dyn_err}; use hex_display::HexDisplayExt as _; -use ip_packet::MAX_DATAGRAM_PAYLOAD; use rand::random; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use std::{ @@ -767,7 +766,7 @@ impl Allocation { pub fn encode_to_encrypted_packet( &self, peer: SocketAddr, - mut buffer: [u8; MAX_DATAGRAM_PAYLOAD], + mut buffer: lockfree_object_pool::SpinLockOwnedReusable>, buffer_len: usize, now: Instant, ) -> Option { diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index fb80e207b..3a18ff137 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -10,9 +10,7 @@ use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::fmt; use firezone_logging::err_with_src; use hex_display::HexDisplayExt; -use ip_packet::{ - ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, IpPacketBuf, MAX_DATAGRAM_PAYLOAD, -}; +use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, IpPacketBuf}; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::{random, Rng, SeedableRng}; @@ -126,6 +124,9 @@ pub struct Node { pending_events: VecDeque>, stats: NodeStats, + // All access to [`Node`] happens in the same thread, so we should never get contention which makes a spinlock ideal. + // This is wrapped in an `Arc` so we can use `pull_owned`. + buffer_pool: Arc>>, mode: T, rng: StdRng, @@ -180,6 +181,10 @@ where allocations: Default::default(), connections: Default::default(), stats: Default::default(), + buffer_pool: Arc::new(lockfree_object_pool::SpinLockObjectPool::new( + || vec![0; ip_packet::MAX_DATAGRAM_PAYLOAD], + |v| v.fill(0), + )), } } @@ -438,11 +443,11 @@ where .get_established_mut(&connection) .ok_or(Error::NotConnected)?; - let mut buffer = EncryptBuffer::default(); + let mut buffer = self.buffer_pool.pull_owned(); // Encode the packet with an offset of 4 bytes, in case we need to wrap it in a channel-data message. let Some(packet_len) = conn - .encapsulate(packet.packet(), &mut buffer.inner[4..], now)? + .encapsulate(packet.packet(), &mut buffer[4..], now)? .map(|p| p.len()) // Mapping to len() here terminate the mutable borrow of buffer, allowing re-borrowing further down. else { @@ -454,7 +459,7 @@ where let socket = match &mut conn.state { ConnectionState::Connecting { buffered, .. } => { - buffered.push(buffer.inner[packet_start..packet_end].to_vec()); + buffered.push(buffer[packet_start..packet_end].to_vec()); let num_buffered = buffered.len(); let _guard = conn.span.enter(); @@ -477,7 +482,7 @@ where dst: remote, packet_start, packet_len, - buffer: buffer.inner, + buffer, })), PeerSocket::Relay { relay, dest: peer } => { let Some(allocation) = self.allocations.get(&relay) else { @@ -485,7 +490,7 @@ where return Ok(None); }; let Some(enc_packet) = - allocation.encode_to_encrypted_packet(peer, buffer.inner, packet_end, now) + allocation.encode_to_encrypted_packet(peer, buffer, packet_end, now) else { tracing::warn!(%peer, "No channel"); return Ok(None); @@ -1473,31 +1478,12 @@ pub enum Event { ConnectionClosed(TId), } -struct EncryptBuffer { - inner: [u8; MAX_DATAGRAM_PAYLOAD], -} - -impl EncryptBuffer { - fn new() -> Self { - Self { - inner: [0u8; MAX_DATAGRAM_PAYLOAD], - } - } -} - -impl Default for EncryptBuffer { - fn default() -> Self { - Self::new() - } -} - -#[derive(Debug, Clone)] pub struct EncryptedPacket { pub(crate) src: Option, pub(crate) dst: SocketAddr, pub(crate) packet_start: usize, pub(crate) packet_len: usize, - pub(crate) buffer: [u8; MAX_DATAGRAM_PAYLOAD], + pub(crate) buffer: lockfree_object_pool::SpinLockOwnedReusable>, } impl EncryptedPacket { diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index d2158418d..5c1524c87 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -13,6 +13,7 @@ proptest = ["dep:proptest"] [dependencies] anyhow = { workspace = true } etherparse = { workspace = true } +lockfree-object-pool = { workspace = true } proptest = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true } diff --git a/rust/ip-packet/src/buffer_pool.rs b/rust/ip-packet/src/buffer_pool.rs new file mode 100644 index 000000000..ed7d8e785 --- /dev/null +++ b/rust/ip-packet/src/buffer_pool.rs @@ -0,0 +1,96 @@ +use std::{ + ops::{Deref, DerefMut}, + sync::{Arc, LazyLock}, +}; + +use crate::MAX_DATAGRAM_PAYLOAD; + +type BufferPool = Arc>>; + +static BUFFER_POOL: LazyLock = LazyLock::new(|| { + Arc::new(lockfree_object_pool::MutexObjectPool::new( + || vec![0; MAX_DATAGRAM_PAYLOAD], + |v| v.fill(0), + )) +}); + +pub struct Buffer(lockfree_object_pool::MutexOwnedReusable>); + +impl Clone for Buffer { + fn clone(&self) -> Self { + let mut copy = Buffer::default(); + + copy.0.resize(self.len(), 0); + copy.copy_from_slice(self); + + copy + } +} + +impl PartialEq for Buffer { + fn eq(&self, other: &Self) -> bool { + self.as_ref() == other.as_ref() + } +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Buffer").finish() + } +} + +impl Deref for Buffer { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl DerefMut for Buffer { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0[..] + } +} + +impl Default for Buffer { + fn default() -> Self { + Self(BUFFER_POOL.pull_owned()) + } +} + +impl Drop for Buffer { + fn drop(&mut self) { + debug_assert_eq!( + self.0.capacity(), + MAX_DATAGRAM_PAYLOAD, + "Buffer should never re-allocate" + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn buffer_can_be_cloned() { + let mut buffer = Buffer::default(); + buffer[..11].copy_from_slice(b"hello world"); + + let buffer2 = buffer.clone(); + + assert_eq!(&buffer2[..], &buffer[..]); + } + + #[test] + fn cloned_buffer_owns_its_own_memory() { + let mut buffer = Buffer::default(); + buffer[..11].copy_from_slice(b"hello world"); + + let buffer2 = buffer.clone(); + drop(buffer); + + assert_eq!(&buffer2[..11], b"hello world"); + } +} diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 39c6fd55d..06a20b3ad 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -2,6 +2,7 @@ pub mod make; +mod buffer_pool; mod fz_p2p_control; mod fz_p2p_control_slice; mod icmp_dest_unreachable; @@ -18,6 +19,7 @@ mod slice_utils; mod tcp_header_slice_mut; mod udp_header_slice_mut; +use buffer_pool::Buffer; pub use etherparse::*; pub use fz_p2p_control::EventType as FzP2pEventType; pub use fz_p2p_control_slice::FzP2pControlSlice; @@ -101,15 +103,14 @@ pub enum Layer4Protocol { } /// A buffer for reading a new [`IpPacket`] from the network. +#[derive(Default)] pub struct IpPacketBuf { - inner: [u8; MAX_DATAGRAM_PAYLOAD], + inner: Buffer, } impl IpPacketBuf { pub fn new() -> Self { - Self { - inner: [0u8; MAX_DATAGRAM_PAYLOAD], - } + Self::default() } pub fn buf(&mut self) -> &mut [u8] { @@ -117,12 +118,6 @@ impl IpPacketBuf { } } -impl Default for IpPacketBuf { - fn default() -> Self { - Self::new() - } -} - #[derive(PartialEq, Clone)] pub enum IpPacket { Ipv4(ConvertibleIpv4Packet), @@ -168,7 +163,7 @@ impl std::fmt::Debug for IpPacket { #[derive(Debug, PartialEq, Clone)] pub struct ConvertibleIpv4Packet { - buf: [u8; MAX_DATAGRAM_PAYLOAD], + buf: Buffer, start: usize, len: usize, } @@ -248,7 +243,7 @@ impl ConvertibleIpv4Packet { #[derive(Debug, PartialEq, Clone)] pub struct ConvertibleIpv6Packet { - buf: [u8; MAX_DATAGRAM_PAYLOAD], + buf: Buffer, start: usize, len: usize, }