diff --git a/rust/Cargo.lock b/rust/Cargo.lock index be3457c96..47f316c6a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -831,6 +831,18 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "bufferpool" +version = "0.1.0" +dependencies = [ + "bytes", + "lockfree-object-pool", + "opentelemetry", + "opentelemetry_sdk", + "tokio", + "tracing", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -2163,6 +2175,8 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "bufferpool", + "bytes", "clap", "firezone-logging", "flume", @@ -2477,6 +2491,7 @@ dependencies = [ "base64 0.22.1", "bimap", "boringtun", + "bufferpool", "bytes", "chrono", "connlib-model", @@ -2499,7 +2514,6 @@ dependencies = [ "itertools 0.13.0", "l4-tcp-dns-server", "l4-udp-dns-server", - "lockfree-object-pool", "lru", "opentelemetry", "proptest", @@ -3501,9 +3515,9 @@ name = "ip-packet" version = "0.1.0" dependencies = [ "anyhow", + "bufferpool", "etherparse", "etherparse-ext", - "lockfree-object-pool", "proptest", "test-strategy", "thiserror 1.0.69", @@ -6457,6 +6471,7 @@ name = "snownet" version = "0.1.0" dependencies = [ "boringtun", + "bufferpool", "bytecodec", "bytes", "derive_more 1.0.0", @@ -6465,7 +6480,6 @@ dependencies = [ "hex-display", "ip-packet", "itertools 0.13.0", - "lockfree-object-pool", "once_cell", "rand 0.8.5", "ringbuffer", @@ -6482,12 +6496,12 @@ name = "socket-factory" version = "0.1.0" dependencies = [ "anyhow", + "bufferpool", "bytes", "derive_more 1.0.0", "firezone-logging", "gat-lending-iterator", "ip-packet", - "lockfree-object-pool", "opentelemetry", "parking_lot", "quinn-udp", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 05909cf18..f16e24a44 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "bin-shared", + "bufferpool", "connlib/clients/android", "connlib/clients/apple", "connlib/clients/shared", @@ -97,6 +98,7 @@ keyring = "3.6.2" known-folders = "1.2.0" l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" } l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" } +bufferpool = { path = "bufferpool" } libc = "0.2.171" lockfree-object-pool = "0.1.6" log = "0.4" diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index 5ef744d53..1a4880828 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -23,6 +23,8 @@ tracing = { workspace = true } tun = { workspace = true } [dev-dependencies] +bufferpool = { workspace = true } +bytes = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [target.'cfg(target_os = "linux")'.dependencies] diff --git a/rust/bin-shared/tests/no_packet_loops_udp.rs b/rust/bin-shared/tests/no_packet_loops_udp.rs index a0889e591..3ddc22ba2 100644 --- a/rust/bin-shared/tests/no_packet_loops_udp.rs +++ b/rust/bin-shared/tests/no_packet_loops_udp.rs @@ -1,5 +1,7 @@ #![allow(clippy::unwrap_used)] +use bufferpool::BufferPool; +use bytes::BytesMut; use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory}; use gat_lending_iterator::LendingIterator as _; use ip_network::Ipv4Network; @@ -19,6 +21,8 @@ async fn no_packet_loops_udp() { let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + let bufferpool = BufferPool::::new(0, "test"); + let mut device_manager = TunDeviceManager::new(1280, 1).unwrap(); let _tun = device_manager.make_tun().unwrap(); device_manager.set_ips(ipv4, ipv6).await.unwrap(); @@ -45,7 +49,9 @@ async fn no_packet_loops_udp() { .send(DatagramOut { src: None, dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, - packet: &hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), + packet: bufferpool.pull_initialised( + hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), + ), segment_size: None, ecn: Ecn::NonEct, }) diff --git a/rust/bufferpool/Cargo.toml b/rust/bufferpool/Cargo.toml new file mode 100644 index 000000000..fbf3d07ac --- /dev/null +++ b/rust/bufferpool/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "bufferpool" +version = "0.1.0" +edition = { workspace = true } +license = { workspace = true } + +[lib] +path = "lib.rs" + +[dependencies] +bytes = { workspace = true } +lockfree-object-pool = { workspace = true } +opentelemetry = { workspace = true, features = ["metrics"] } +tracing = { workspace = true } + +[dev-dependencies] +opentelemetry_sdk = { workspace = true, features = ["testing"] } +tokio = { workspace = true, features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/rust/bufferpool/lib.rs b/rust/bufferpool/lib.rs new file mode 100644 index 000000000..b5920e8bf --- /dev/null +++ b/rust/bufferpool/lib.rs @@ -0,0 +1,348 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use bytes::BytesMut; +use opentelemetry::{KeyValue, metrics::UpDownCounter}; + +#[derive(Clone)] +pub struct BufferPool { + inner: Arc>>, +} + +impl BufferPool +where + B: Buf, +{ + pub fn new(capacity: usize, tag: &'static str) -> Self { + let buffer_counter = opentelemetry::global::meter("connlib") + .i64_up_down_counter("system.buffer.count") + .with_description("The number of buffers allocated in the pool.") + .with_unit("{buffers}") + .init(); + + Self { + inner: Arc::new(lockfree_object_pool::MutexObjectPool::new( + move || { + BufferStorage::new( + B::with_capacity(capacity), + buffer_counter.clone(), + [ + KeyValue::new("system.buffer.pool.name", tag), + KeyValue::new("system.buffer.pool.buffer_size", capacity as i64), + ], + ) + }, + |_| {}, + )), + } + } + + pub fn pull(&self) -> Buffer { + Buffer { + inner: self.inner.pull_owned(), + pool: self.inner.clone(), + } + } +} + +impl BufferPool +where + B: Buf + DerefMut, +{ + pub fn pull_initialised(&self, data: &[u8]) -> Buffer { + let mut buffer = self.pull(); + let len = data.len(); + + buffer.resize_to(len); + buffer.copy_from_slice(data); + + buffer + } +} + +pub struct Buffer { + inner: lockfree_object_pool::MutexOwnedReusable>, + pool: Arc>>, +} + +impl Buffer> { + /// Shifts the start of the buffer to the right by N bytes, returning the bytes removed from the front of the buffer. + pub fn shift_start_right(&mut self, num: usize) -> Vec { + let num_to_end = self.split_off(num); + + std::mem::replace(&mut self.inner.inner, num_to_end) + } + + /// Shifts the start of the buffer to the left by N bytes, returning a slice to the added bytes at the front of the buffer. + pub fn shift_start_left(&mut self, num: usize) -> &mut [u8] { + let current_len = self.len(); + + self.resize(current_len + num, 0); + self.copy_within(..current_len, num); + + &mut self[..num] + } +} + +impl Clone for Buffer +where + B: Buf, +{ + fn clone(&self) -> Self { + let mut copy = self.pool.pull_owned(); + + self.inner.inner.clone(&mut copy); + + Self { + inner: copy, + pool: self.pool.clone(), + } + } +} + +impl PartialEq for Buffer +where + B: Deref, +{ + fn eq(&self, other: &Self) -> bool { + self.as_ref() == other.as_ref() + } +} + +impl Eq for Buffer where B: Deref {} + +impl PartialOrd for Buffer +where + B: Deref, +{ + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Buffer +where + B: Deref, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.as_ref().cmp(other.as_ref()) + } +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Buffer").finish() + } +} + +impl Deref for Buffer { + type Target = B; + + fn deref(&self) -> &Self::Target { + self.inner.deref() + } +} + +impl DerefMut for Buffer { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner.deref_mut() + } +} + +pub trait Buf: Sized { + fn with_capacity(capacity: usize) -> Self; + fn clone(&self, dst: &mut Self); + fn resize_to(&mut self, len: usize); +} + +impl Buf for Vec { + fn with_capacity(capacity: usize) -> Self { + vec![0; capacity] + } + + fn clone(&self, dst: &mut Self) { + dst.resize(self.len(), 0); + dst.copy_from_slice(self); + } + + fn resize_to(&mut self, len: usize) { + self.resize(len, 0); + } +} + +impl Buf for BytesMut { + fn with_capacity(capacity: usize) -> Self { + BytesMut::zeroed(capacity) + } + + fn clone(&self, dst: &mut Self) { + dst.resize(self.len(), 0); + dst.copy_from_slice(self); + } + + fn resize_to(&mut self, len: usize) { + self.resize(len, 0); + } +} + +/// A wrapper around a buffer `B` that keeps track of how many buffers there are in a counter. +struct BufferStorage { + inner: B, + + attributes: [KeyValue; 2], + counter: UpDownCounter, +} + +impl Drop for BufferStorage { + fn drop(&mut self) { + self.counter.add(-1, &self.attributes); + } +} + +impl BufferStorage { + fn new(inner: B, counter: UpDownCounter, attributes: [KeyValue; 2]) -> Self { + counter.add(1, &attributes); + + Self { + inner, + counter, + attributes, + } + } +} + +impl Deref for BufferStorage { + type Target = B; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for BufferStorage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use opentelemetry::global; + use opentelemetry_sdk::{ + metrics::{PeriodicReader, SdkMeterProvider, data::Sum}, + testing::metrics::InMemoryMetricsExporter, + }; + + use super::*; + + #[test] + fn buffer_can_be_cloned() { + let pool = BufferPool::>::new(1024, "test"); + + let buffer = pool.pull_initialised(b"hello world"); + + #[allow(clippy::redundant_clone)] + let buffer2 = buffer.clone(); + + assert_eq!(&buffer2[..], &buffer[..]); + } + + #[test] + fn cloned_buffer_owns_its_own_memory() { + let pool = BufferPool::>::new(1024, "test"); + + let buffer = pool.pull_initialised(b"hello world"); + + let buffer2 = buffer.clone(); + drop(buffer); + + assert_eq!(&buffer2[..11], b"hello world"); + } + + #[test] + fn initialised_buffer_is_only_as_long_as_content() { + let pool = BufferPool::>::new(1024, "test"); + + let buffer = pool.pull_initialised(b"hello world"); + + assert_eq!(buffer.len(), 11); + } + + #[test] + fn shift_start_right() { + let pool = BufferPool::>::new(1024, "test"); + + let mut buffer = pool.pull_initialised(b"hello world"); + + let front = buffer.shift_start_right(5); + + assert_eq!(front, b"hello"); + assert_eq!(&*buffer, b" world"); + } + + #[test] + fn shift_start_left() { + let pool = BufferPool::>::new(1024, "test"); + + let mut buffer = pool.pull_initialised(b"hello world"); + + let front = buffer.shift_start_left(5); + front.copy_from_slice(b"12345"); + + assert_eq!(&*buffer, b"12345hello world"); + } + + #[tokio::test] + async fn buffer_pool_metrics() { + let (_provider, exporter) = init_meter_provider(); + + let pool = BufferPool::>::new(1024, "test"); + + let buffer1 = pool.pull_initialised(b"hello world"); + let buffer2 = pool.pull_initialised(b"hello world"); + let buffer3 = pool.pull_initialised(b"hello world"); + + tokio::time::sleep(Duration::from_millis(10)).await; // Wait for metrics to be exported. + + assert_eq!(get_num_buffers(&exporter), 3); + + drop(pool); + drop(buffer1); + drop(buffer2); + drop(buffer3); + + tokio::time::sleep(Duration::from_millis(10)).await; // Wait for metrics to be exported. + + assert_eq!(get_num_buffers(&exporter), 0); + } + + fn get_num_buffers(exporter: &InMemoryMetricsExporter) -> i64 { + let metrics = exporter.get_finished_metrics().unwrap(); + + let metric = &metrics.iter().last().unwrap().scope_metrics[0].metrics[0]; + let sum = metric.data.as_any().downcast_ref::>().unwrap(); + + sum.data_points[0].value + } + + fn init_meter_provider() -> (SdkMeterProvider, InMemoryMetricsExporter) { + let exporter = InMemoryMetricsExporter::default(); + + let provider = SdkMeterProvider::builder() + .with_reader( + PeriodicReader::builder(exporter.clone(), opentelemetry_sdk::runtime::Tokio) + .with_interval(Duration::from_millis(1)) + .build(), + ) + .build(); + global::set_meter_provider(provider.clone()); + + (provider, exporter) + } +} diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index d4b95c8c4..c70b5683a 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -6,6 +6,7 @@ license = { workspace = true } [dependencies] boringtun = { workspace = true } +bufferpool = { workspace = true } bytecodec = { workspace = true } bytes = { workspace = true } derive_more = { workspace = true, features = ["debug"] } @@ -14,7 +15,6 @@ hex = { workspace = true } hex-display = { workspace = true } ip-packet = { workspace = true } itertools = { workspace = true } -lockfree-object-pool = { workspace = true } once_cell = { workspace = true } rand = { workspace = true } ringbuffer = { workspace = true } diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 00932e901..c55fa050f 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -2,6 +2,7 @@ use crate::{ backoff::{self, ExponentialBackoff}, node::{SessionId, Transmit}, }; +use bufferpool::BufferPool; use bytecodec::{DecodeExt as _, EncodeExt as _}; use firezone_logging::err_with_src; use hex_display::HexDisplayExt as _; @@ -47,7 +48,6 @@ const BINDING_INTERVAL: Duration = Duration::from_secs(25); /// Represents a TURN allocation that refreshes itself. /// /// Allocations have a lifetime and need to be continuously refreshed to stay active. -#[derive(Debug)] pub struct Allocation { /// The known sockets of the relay. server: RelaySocket, @@ -77,7 +77,7 @@ pub struct Allocation { /// When we received the allocation and how long it is valid. allocation_lifetime: Option<(Instant, Duration)>, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, events: VecDeque, sent_requests: BTreeMap, ExponentialBackoff)>, @@ -88,6 +88,8 @@ pub struct Allocation { credentials: Option, explicit_failure: Option, + + buffer_pool: BufferPool>, } #[derive(derive_more::Debug, Clone, Copy)] @@ -212,6 +214,7 @@ impl Allocation { realm: Realm, now: Instant, session_id: SessionId, + buffer_pool: BufferPool>, ) -> Self { let mut allocation = Self { server, @@ -235,6 +238,7 @@ impl Allocation { software: Software::new(format!("snownet; session={session_id}")) .expect("description has less then 128 chars"), explicit_failure: Default::default(), + buffer_pool, }; allocation.send_binding_requests(now); @@ -742,7 +746,7 @@ impl Allocation { self.events.pop_front() } - pub fn poll_transmit(&mut self) -> Option> { + pub fn poll_transmit(&mut self) -> Option { self.buffered_transmits.pop_front() } @@ -1082,10 +1086,11 @@ impl Allocation { self.sent_requests .insert(id, (dst, message.clone(), backoff)); + self.buffered_transmits.push_back(Transmit { src: None, dst, - payload: encode(message).into(), + payload: self.buffer_pool.pull_initialised(&encode(message)), }); true @@ -2818,6 +2823,7 @@ mod tests { Realm::new("firezone".to_owned()).unwrap(), start, SessionId::default(), + BufferPool::new(500, "test"), ) } @@ -2832,6 +2838,7 @@ mod tests { Realm::new("firezone".to_owned()).unwrap(), start, SessionId::default(), + BufferPool::new(500, "test"), ) } diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index 8940a1b19..4569d0f8b 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -16,8 +16,8 @@ pub use allocation::RelaySocket; #[allow(deprecated)] // Rust bug: `expect` doesn't seem to work on imports? pub use node::{Answer, Offer}; pub use node::{ - Client, ClientNode, Credentials, EncryptedPacket, Error, Event, HANDSHAKE_TIMEOUT, - NoTurnServers, Node, Server, ServerNode, Transmit, + Client, ClientNode, Credentials, Error, Event, HANDSHAKE_TIMEOUT, NoTurnServers, Node, Server, + ServerNode, Transmit, }; pub use stats::{ConnectionStats, NodeStats}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 46d83613b..8bdfd60c5 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -7,6 +7,7 @@ use boringtun::noise::errors::WireGuardError; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; +use bufferpool::{Buffer, BufferPool}; use core::fmt; use firezone_logging::err_with_src; use hex_display::HexDisplayExt; @@ -17,7 +18,6 @@ use rand::{Rng, RngCore, SeedableRng, random}; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use secrecy::{ExposeSecret, Secret}; use sha2::Digest; -use std::borrow::Cow; use std::collections::btree_map::Entry; use std::collections::{BTreeMap, BTreeSet}; use std::hash::Hash; @@ -118,7 +118,7 @@ pub struct Node { rate_limiter: Arc, /// Host and server-reflexive candidates that are shared between all connections. shared_candidates: CandidateSet, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, next_rate_limiter_reset: Option, @@ -128,9 +128,7 @@ pub struct Node { pending_events: VecDeque>, stats: NodeStats, - // All access to [`Node`] happens in the same thread, so we should never get contention which makes a spinlock ideal. - // This is wrapped in an `Arc` so we can use `pull_owned`. - buffer_pool: Arc>>, + buffer_pool: BufferPool>, mode: T, rng: StdRng, @@ -184,10 +182,7 @@ where allocations: Default::default(), connections: Default::default(), stats: Default::default(), - buffer_pool: Arc::new(lockfree_object_pool::SpinLockObjectPool::new( - || vec![0; ip_packet::MAX_FZ_PAYLOAD], - |v| v.fill(0), - )), + buffer_pool: BufferPool::new(ip_packet::MAX_FZ_PAYLOAD, "snownet"), } } @@ -447,7 +442,7 @@ where connection: TId, packet: IpPacket, now: Instant, - ) -> Result, Error> { + ) -> Result, Error> { let conn = self .connections .get_established_mut(&connection) @@ -462,7 +457,8 @@ where return Ok(None); } - let mut buffer = self.buffer_pool.pull_owned(); + let mut buffer = self.buffer_pool.pull(); + buffer.resize(ip_packet::MAX_FZ_PAYLOAD, 0); // Encode the packet with an offset of 4 bytes, in case we need to wrap it in a channel-data message. let Some(packet_len) = conn @@ -500,13 +496,16 @@ where | PeerSocket::PeerToRelay { source, dest: remote, - } => Ok(Some(EncryptedPacket { - src: Some(source), - dst: remote, - packet_start, - packet_len, - buffer, - })), + } => { + buffer.copy_within(packet_start..packet_end, 0); + buffer.truncate(packet_len); + + Ok(Some(Transmit { + src: Some(source), + dst: remote, + payload: buffer, + })) + } PeerSocket::RelayToPeer { relay, dest: peer } | PeerSocket::RelayToRelay { relay, dest: peer } => { let Some(allocation) = self.allocations.get_mut(&relay) else { @@ -519,12 +518,12 @@ where return Ok(None); }; - Ok(Some(EncryptedPacket { + buffer.truncate(packet_end); + + Ok(Some(Transmit { src: None, dst: encode_ok.socket, - packet_start: 0, - packet_len: packet_end, - buffer, + payload: buffer, })) } } @@ -616,7 +615,7 @@ where /// Returns buffered data that needs to be sent on the socket. #[must_use] - pub fn poll_transmit(&mut self) -> Option> { + pub fn poll_transmit(&mut self) -> Option { let allocation_transmits = &mut self .allocations .values_mut() @@ -679,6 +678,7 @@ where realm, now, self.session_id.clone(), + self.buffer_pool.clone(), )); tracing::info!(%rid, address = ?server, "Added new TURN server"); @@ -706,6 +706,7 @@ where realm, now, self.session_id.clone(), + self.buffer_pool.clone(), )); tracing::info!(%rid, address = ?server, "Replaced TURN server"); @@ -770,6 +771,7 @@ where }, possible_sockets: BTreeSet::default(), span: info_span!(parent: tracing::Span::none(), "connection", %cid), + buffer_pool: self.buffer_pool.clone(), } } @@ -1494,38 +1496,8 @@ pub enum Event { ConnectionClosed(TId), } -pub struct EncryptedPacket { - pub(crate) src: Option, - pub(crate) dst: SocketAddr, - pub(crate) packet_start: usize, - pub(crate) packet_len: usize, - pub(crate) buffer: lockfree_object_pool::SpinLockOwnedReusable>, -} - -impl EncryptedPacket { - pub fn to_transmit(&self) -> Transmit<'_> { - Transmit { - src: self.src, - dst: self.dst, - payload: Cow::Borrowed(self.payload()), - } - } - - pub fn src(&self) -> Option { - self.src - } - - pub fn dst(&self) -> SocketAddr { - self.dst - } - - pub fn payload(&self) -> &[u8] { - &self.buffer[self.packet_start..(self.packet_start + self.packet_len)] - } -} - #[derive(Clone, PartialEq, PartialOrd, Eq, Ord)] -pub struct Transmit<'a> { +pub struct Transmit { /// The local interface from which this packet should be sent. /// /// If `None`, it can be sent from any interface. @@ -1536,10 +1508,10 @@ pub struct Transmit<'a> { /// The remote the packet should be sent to. pub dst: SocketAddr, /// The data that should be sent. - pub payload: Cow<'a, [u8]>, + pub payload: Buffer>, } -impl fmt::Debug for Transmit<'_> { +impl fmt::Debug for Transmit { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Transmit") .field("src", &self.src) @@ -1549,16 +1521,6 @@ impl fmt::Debug for Transmit<'_> { } } -impl Transmit<'_> { - pub fn into_owned(self) -> Transmit<'static> { - Transmit { - src: self.src, - dst: self.dst, - payload: Cow::Owned(self.payload.into_owned()), - } - } -} - struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>, @@ -1624,6 +1586,7 @@ struct Connection { buffer: Vec, span: tracing::Span, + buffer_pool: BufferPool>, } enum ConnectionState { @@ -1833,7 +1796,7 @@ where cid: TId, now: Instant, allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, ) where TId: Copy + Ord + fmt::Display, RId: Copy + Ord + fmt::Display, @@ -1913,7 +1876,13 @@ where tracing::debug!(%num_buffered, "Flushing packets buffered during ICE"); transmits.extend(buffered.into_iter().flat_map(|packet| { - make_owned_transmit(remote_socket, &packet, allocations, now) + make_owned_transmit( + remote_socket, + &packet, + &self.buffer_pool, + allocations, + now, + ) })); self.state = ConnectionState::Connected { peer_socket: remote_socket, @@ -1984,7 +1953,7 @@ where transmits.push_back(Transmit { src: Some(source), dst, - payload: Cow::Owned(stun_packet.into()), + payload: self.buffer_pool.pull_initialised(&Vec::from(stun_packet)), }); continue; }; @@ -2005,7 +1974,7 @@ where transmits.push_back(Transmit { src: None, dst: encode_ok.socket, - payload: Cow::Owned(data_channel_packet), + payload: self.buffer_pool.pull_initialised(&data_channel_packet), }); } } @@ -2014,7 +1983,7 @@ where &mut self, now: Instant, allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, ) { // Don't update wireguard timers until we are connected. let Some(peer_socket) = self.socket() else { @@ -2038,7 +2007,13 @@ where tracing::warn!(?e); } TunnResult::WriteToNetwork(b) => { - transmits.extend(make_owned_transmit(peer_socket, b, allocations, now)); + transmits.extend(make_owned_transmit( + peer_socket, + b, + &self.buffer_pool, + allocations, + now, + )); } TunnResult::WriteToTunnelV4(..) | TunnResult::WriteToTunnelV6(..) => { panic!("Unexpected result from update_timers") @@ -2072,7 +2047,7 @@ where &mut self, packet: &[u8], allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, now: Instant, ) -> ControlFlow, IpPacket> { let _guard = self.span.enter(); @@ -2132,6 +2107,7 @@ where transmits.extend(make_owned_transmit( *peer_socket, bytes, + &self.buffer_pool, allocations, now, )); @@ -2143,6 +2119,7 @@ where transmits.extend(make_owned_transmit( *peer_socket, packet, + &self.buffer_pool, allocations, now, )); @@ -2165,7 +2142,7 @@ where fn force_handshake( &mut self, allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, now: Instant, ) where RId: Copy, @@ -2188,7 +2165,13 @@ where .socket() .expect("cannot force handshake while not connected"); - transmits.extend(make_owned_transmit(socket, bytes, allocations, now)); + transmits.extend(make_owned_transmit( + socket, + bytes, + &self.buffer_pool, + allocations, + now, + )); } fn socket(&self) -> Option> { @@ -2212,9 +2195,10 @@ where fn make_owned_transmit( socket: PeerSocket, message: &[u8], + buffer_pool: &BufferPool>, allocations: &mut BTreeMap, now: Instant, -) -> Option> +) -> Option where RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug, { @@ -2229,19 +2213,19 @@ where } => Transmit { src: Some(source), dst: remote, - payload: Cow::Owned(message.into()), + payload: buffer_pool.pull_initialised(message), }, PeerSocket::RelayToPeer { relay, dest: peer } | PeerSocket::RelayToRelay { relay, dest: peer } => { let allocation = allocations.get_mut(&relay)?; - let mut buffer = channel_data_packet_buffer(message); - let encode_ok = allocation.encode_channel_data_header(peer, &mut buffer, now)?; + let mut channel_data = channel_data_packet_buffer(message); + let encode_ok = allocation.encode_channel_data_header(peer, &mut channel_data, now)?; Transmit { src: None, dst: encode_ok.socket, - payload: Cow::Owned(buffer), + payload: buffer_pool.pull_initialised(&channel_data), } } }; diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 6bbd71b4e..2f40e8a54 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -9,6 +9,7 @@ anyhow = { workspace = true } base64 = { workspace = true, features = ["std"] } bimap = { workspace = true } boringtun = { workspace = true } +bufferpool = { workspace = true } bytes = { workspace = true, features = ["std"] } chrono = { workspace = true } connlib-model = { workspace = true } @@ -30,7 +31,6 @@ ip_network_table = { workspace = true } itertools = { workspace = true, features = ["use_std"] } l4-tcp-dns-server = { workspace = true } l4-udp-dns-server = { workspace = true } -lockfree-object-pool = { workspace = true } lru = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } proptest = { workspace = true, optional = true } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 1305ef9d8..8ddc8f668 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -145,7 +145,7 @@ pub struct ClientState { buffered_events: VecDeque, buffered_packets: VecDeque, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, buffered_dns_queries: VecDeque, } @@ -509,7 +509,7 @@ impl ClientState { &mut self, packet: IpPacket, now: Instant, - ) -> Option { + ) -> Option { let non_dns_packet = match self.try_handle_dns(packet, now) { ControlFlow::Break(()) => return None, ControlFlow::Continue(non_dns_packet) => non_dns_packet, @@ -625,7 +625,7 @@ impl ClientState { } } - fn encapsulate(&mut self, packet: IpPacket, now: Instant) -> Option { + fn encapsulate(&mut self, packet: IpPacket, now: Instant) -> Option { let dst = packet.destination(); if is_definitely_not_a_resource(dst) { @@ -1165,11 +1165,10 @@ impl ClientState { // Check if the client wants to emit any packets. if let Some(packet) = self.tcp_dns_client.poll_outbound() { // All packets from the TCP DNS client _should_ go through the tunnel. - let Some(encryped_packet) = self.encapsulate(packet, now) else { + let Some(transmit) = self.encapsulate(packet, now) else { continue; }; - let transmit = encryped_packet.to_transmit().into_owned(); self.buffered_transmits.push_back(transmit); continue; } @@ -1633,7 +1632,7 @@ impl ClientState { self.tcp_dns_client.reset(); } - pub(crate) fn poll_transmit(&mut self) -> Option> { + pub(crate) fn poll_transmit(&mut self) -> Option { self.buffered_transmits .pop_front() .or_else(|| self.node.poll_transmit()) @@ -1895,9 +1894,9 @@ fn encapsulate_and_buffer( gid: GatewayId, now: Instant, node: &mut ClientNode, - buffered_transmits: &mut VecDeque>, + buffered_transmits: &mut VecDeque, ) { - let Some(enc_packet) = node + let Some(transmit) = node .encapsulate(gid, packet, now) .inspect_err(|e| tracing::debug!(%gid, "Failed to encapsulate: {e}")) .ok() @@ -1906,7 +1905,7 @@ fn encapsulate_and_buffer( return; }; - buffered_transmits.push_back(enc_packet.to_transmit().into_owned()); + buffered_transmits.push_back(transmit); } fn handle_p2p_control_packet( @@ -1914,7 +1913,7 @@ fn handle_p2p_control_packet( fz_p2p_control: ip_packet::FzP2pControlSlice, dns_resource_nat_by_gateway: &mut BTreeMap<(GatewayId, DomainName), DnsResourceNatState>, node: &mut ClientNode, - buffered_transmits: &mut VecDeque>, + buffered_transmits: &mut VecDeque, now: Instant, ) { use p2p_control::dns_resource_nat; diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index e1fccf151..0e242d8a6 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -38,7 +38,7 @@ pub struct GatewayState { tun_ip_config: Option, buffered_events: VecDeque, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, } #[derive(Debug)] @@ -84,7 +84,7 @@ impl GatewayState { &mut self, packet: IpPacket, now: Instant, - ) -> Result> { + ) -> Result> { let dst = packet.destination(); if !crate::is_peer(dst) { @@ -411,7 +411,7 @@ impl GatewayState { } } - pub(crate) fn poll_transmit(&mut self) -> Option> { + pub(crate) fn poll_transmit(&mut self) -> Option { self.buffered_transmits .pop_front() .or_else(|| self.node.poll_transmit()) @@ -497,15 +497,12 @@ fn encrypt_packet( cid: ClientId, node: &mut ServerNode, now: Instant, -) -> Result>> { - let Some(encrypted_packet) = node +) -> Result> { + let transmit = node .encapsulate(cid, packet, now) - .context("Failed to encapsulate packet")? - else { - return Ok(None); - }; + .context("Failed to encapsulate packet")?; - Ok(Some(encrypted_packet.to_transmit().into_owned())) + Ok(transmit) } /// Opaque request struct for when a domain name needs to be resolved. diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index a9ea44810..5ed02f852 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -1,9 +1,9 @@ use std::{ collections::{BTreeMap, VecDeque}, net::SocketAddr, - sync::Arc, }; +use bufferpool::{Buffer, BufferPool}; use bytes::BytesMut; use ip_packet::Ecn; use socket_factory::DatagramOut; @@ -13,29 +13,20 @@ use super::MAX_INBOUND_PACKET_BATCH; const MAX_SEGMENT_SIZE: usize = ip_packet::MAX_IP_SIZE + ip_packet::WG_OVERHEAD + ip_packet::DATA_CHANNEL_OVERHEAD; -type Buffer = lockfree_object_pool::SpinLockOwnedReusable; - /// Holds UDP datagrams that we need to send, indexed by src, dst and segment size. /// /// Calling [`Io::send_network`](super::Io::send_network) will copy the provided payload into this buffer. /// The buffer is then flushed using GSO in a single syscall. pub struct GsoQueue { - inner: BTreeMap>, - buffer_pool: Arc>, + inner: BTreeMap)>>, + buffer_pool: BufferPool, } impl GsoQueue { pub fn new() -> Self { Self { inner: Default::default(), - buffer_pool: Arc::new(lockfree_object_pool::SpinLockObjectPool::new( - || { - tracing::debug!("Initialising new buffer for GSO queue"); - - BytesMut::with_capacity(MAX_SEGMENT_SIZE * MAX_INBOUND_PACKET_BATCH) - }, - |b| b.clear(), - )), + buffer_pool: BufferPool::new(MAX_SEGMENT_SIZE * MAX_INBOUND_PACKET_BATCH, "gso-queue"), } } @@ -50,10 +41,7 @@ impl GsoQueue { let batches = self.inner.entry(Connection { src, dst, ecn }).or_default(); let Some((batch_size, buffer)) = batches.back_mut() else { - let mut buffer = self.buffer_pool.pull_owned(); - buffer.extend_from_slice(payload); - - batches.push_back((payload_len, buffer)); + batches.push_back((payload_len, self.buffer_pool.pull_initialised(payload))); return; }; @@ -67,16 +55,10 @@ impl GsoQueue { return; } - let mut buffer = self.buffer_pool.pull_owned(); - buffer.extend_from_slice(payload); - - batches.push_back((payload_len, buffer)); + batches.push_back((payload_len, self.buffer_pool.pull_initialised(payload))); } - pub fn datagrams( - &mut self, - ) -> impl Iterator>> + '_ - { + pub fn datagrams(&mut self) -> impl Iterator + '_ { DrainDatagramsIter { queue: self } } @@ -98,7 +80,7 @@ struct DrainDatagramsIter<'a> { } impl Iterator for DrainDatagramsIter<'_> { - type Item = DatagramOut>; + type Item = DatagramOut; fn next(&mut self) -> Option { loop { diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index e98be8e75..aff71f203 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -177,13 +177,13 @@ impl ClientTunnel { for packet in packets { let ecn = packet.ecn(); - let Some(packet) = self.role_state.handle_tun_input(packet, now) else { + let Some(transmit) = self.role_state.handle_tun_input(packet, now) else { self.role_state.handle_timeout(now); continue; }; self.io - .send_network(packet.src(), packet.dst(), packet.payload(), ecn); + .send_network(transmit.src, transmit.dst, &transmit.payload, ecn); } continue; @@ -308,13 +308,13 @@ impl GatewayTunnel { for packet in packets { let ecn = packet.ecn(); - let Some(packet) = self.role_state.handle_tun_input(packet, now)? else { + let Some(transmit) = self.role_state.handle_tun_input(packet, now)? else { self.role_state.handle_timeout(now, Utc::now()); continue; }; self.io - .send_network(packet.src(), packet.dst(), packet.payload(), ecn); + .send_network(transmit.src, transmit.dst, &transmit.payload, ecn); } continue; diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 891afd90c..ec9ca6633 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,7 +1,7 @@ use anyhow::Result; -use bytes::BytesMut; use futures::{SinkExt, StreamExt, ready}; use gat_lending_iterator::LendingIterator; +use socket_factory::DatagramOut; use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket}; use std::{ io, @@ -11,9 +11,6 @@ use std::{ task::{Context, Poll, Waker}, }; -type DatagramOut = - socket_factory::DatagramOut>; - const UNSPECIFIED_V4_SOCKET: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0); const UNSPECIFIED_V6_SOCKET: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0); diff --git a/rust/connlib/tunnel/src/tests/buffered_transmits.rs b/rust/connlib/tunnel/src/tests/buffered_transmits.rs index 4b71d0aa3..1965f0dd2 100644 --- a/rust/connlib/tunnel/src/tests/buffered_transmits.rs +++ b/rust/connlib/tunnel/src/tests/buffered_transmits.rs @@ -10,7 +10,7 @@ use std::{ #[derive(Debug, Clone, Default)] pub(crate) struct BufferedTransmits { // Transmits are stored in reverse ordering to emit the earliest first. - inner: BinaryHeap>>>, + inner: BinaryHeap>>, } #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] @@ -23,7 +23,7 @@ impl BufferedTransmits { /// Pushes a new [`Transmit`] from a given [`Host`]. pub(crate) fn push_from( &mut self, - transmit: impl Into>>, + transmit: impl Into>, sending_host: &Host, now: Instant, ) { @@ -58,7 +58,7 @@ impl BufferedTransmits { pub(crate) fn push( &mut self, - transmit: impl Into>>, + transmit: impl Into>, latency: Duration, now: Instant, ) { @@ -74,7 +74,7 @@ impl BufferedTransmits { })); } - pub(crate) fn pop(&mut self, now: Instant) -> Option> { + pub(crate) fn pop(&mut self, now: Instant) -> Option { let next = self.inner.peek()?.0.at; if next > now { @@ -86,7 +86,7 @@ impl BufferedTransmits { Some(next.value) } - pub(crate) fn drain(&mut self) -> impl Iterator, Instant)> + '_ { + pub(crate) fn drain(&mut self) -> impl Iterator + '_ { self.inner .drain() .map(|Reverse(ByTime { at, value })| (value, at)) diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index c29fe5a4a..3d6f8a271 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -122,7 +122,7 @@ impl SimClient { upstream: SocketAddr, dns_transport: DnsTransport, now: Instant, - ) -> Option> { + ) -> Option { let Some(sentinel) = self.dns_by_sentinel.get_by_right(&upstream).copied() else { tracing::error!(%upstream, "Unknown DNS server"); return None; @@ -167,15 +167,15 @@ impl SimClient { &mut self, packet: IpPacket, now: Instant, - ) -> Option> { + ) -> Option { self.update_sent_requests(&packet); - let Some(enc_packet) = self.sut.handle_tun_input(packet, now) else { + let Some(transmit) = self.sut.handle_tun_input(packet, now) else { self.sut.handle_timeout(now); // If we handled the packet internally, make sure to advance state. return None; }; - Some(enc_packet.to_transmit().into_owned()) + Some(transmit) } fn update_sent_requests(&mut self, packet: &IpPacket) { diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index b45ba2f9a..933b7607c 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -64,7 +64,7 @@ impl SimGateway { unreachable_hosts: &UnreachableHosts, now: Instant, utc_now: DateTime, - ) -> Option> { + ) -> Option { let Some(packet) = self .sut .handle_network_input(transmit.dst, transmit.src.unwrap(), &transmit.payload, now) @@ -83,7 +83,7 @@ impl SimGateway { &mut self, global_dns_records: &DnsRecords, now: Instant, - ) -> Vec> { + ) -> Vec { let Some(ip_config) = self.sut.tunnel_ip_config() else { tracing::error!("Tunnel IP configuration not set"); return Vec::new(); @@ -116,15 +116,7 @@ impl SimGateway { udp_server_packets .chain(tcp_server_packets) - .filter_map(|packet| { - Some( - self.sut - .handle_tun_input(packet, now) - .unwrap()? - .to_transmit() - .into_owned(), - ) - }) + .filter_map(|packet| self.sut.handle_tun_input(packet, now).unwrap()) .collect() } @@ -165,7 +157,7 @@ impl SimGateway { packet: IpPacket, unreachable_hosts: &UnreachableHosts, now: Instant, - ) -> Option> { + ) -> Option { // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? let dst_ip = packet.destination(); @@ -220,12 +212,7 @@ impl SimGateway { if let Some(reply) = icmp_error.or_else(|| echo_reply(packet.clone())) { self.request_received(&packet); - let transmit = self - .sut - .handle_tun_input(reply, now) - .unwrap()? - .to_transmit() - .into_owned(); + let transmit = self.sut.handle_tun_input(reply, now).unwrap()?; return Some(transmit); } @@ -268,7 +255,7 @@ impl SimGateway { payload: &[u8], icmp_error: Option, now: Instant, - ) -> Option> { + ) -> Option { let reply = icmp_error.unwrap_or_else(|| { ip_packet::make::icmp_reply_packet( packet.destination(), @@ -280,12 +267,7 @@ impl SimGateway { .expect("src and dst are taken from incoming packet") }); - let transmit = self - .sut - .handle_tun_input(reply, now) - .unwrap()? - .to_transmit() - .into_owned(); + let transmit = self.sut.handle_tun_input(reply, now).unwrap()?; Some(transmit) } diff --git a/rust/connlib/tunnel/src/tests/sim_net.rs b/rust/connlib/tunnel/src/tests/sim_net.rs index 716c9cc37..0d1fdf4b5 100644 --- a/rust/connlib/tunnel/src/tests/sim_net.rs +++ b/rust/connlib/tunnel/src/tests/sim_net.rs @@ -123,11 +123,11 @@ impl Host { self.latency } - pub(crate) fn receive(&mut self, transmit: Transmit<'static>, now: Instant) { + pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { self.inbox.push(transmit, self.latency, now); } - pub(crate) fn poll_transmit(&mut self, now: Instant) -> Option> { + pub(crate) fn poll_transmit(&mut self, now: Instant) -> Option { self.inbox.pop(now) } } diff --git a/rust/connlib/tunnel/src/tests/sim_relay.rs b/rust/connlib/tunnel/src/tests/sim_relay.rs index 1a2d180d0..1802daac6 100644 --- a/rust/connlib/tunnel/src/tests/sim_relay.rs +++ b/rust/connlib/tunnel/src/tests/sim_relay.rs @@ -2,6 +2,7 @@ use super::{ sim_net::{Host, dual_ip_stack, host}, strategies::latency, }; +use bufferpool::Buffer; use connlib_model::RelayId; use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket}; use proptest::prelude::*; @@ -9,7 +10,6 @@ use rand::{SeedableRng as _, rngs::StdRng}; use secrecy::SecretString; use snownet::{RelaySocket, Transmit}; use std::{ - borrow::Cow, collections::HashSet, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, time::{Duration, Instant, SystemTime}, @@ -18,7 +18,6 @@ use std::{ pub(crate) struct SimRelay { pub(crate) sut: firezone_relay::Server, pub(crate) allocations: HashSet<(AddressFamily, AllocationPort)>, - buffer: Vec, created_at: SystemTime, } @@ -52,7 +51,6 @@ impl SimRelay { Self { sut, allocations: Default::default(), - buffer: vec![0u8; (1 << 16) - 1], created_at: SystemTime::now(), } } @@ -90,13 +88,9 @@ impl SimRelay { } } - pub(crate) fn receive( - &mut self, - transmit: Transmit, - now: Instant, - ) -> Option> { + pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) -> Option { let dst = transmit.dst; - let payload = &transmit.payload; + let payload = transmit.payload; let sender = transmit.src.unwrap(); if self @@ -115,13 +109,13 @@ impl SimRelay { fn handle_client_input( &mut self, - payload: &[u8], + mut payload: Buffer>, client: ClientSocket, now: Instant, - ) -> Option> { - let (port, peer) = self.sut.handle_client_input(payload, client, now)?; + ) -> Option { + let (port, peer) = self.sut.handle_client_input(&payload, client, now)?; - let payload = &payload[4..]; + payload.shift_start_right(4); // The `dst` of the relayed packet is what TURN calls a "peer". let dst = peer.into_socket(); @@ -156,24 +150,22 @@ impl SimRelay { Some(Transmit { src: Some(src), dst, - payload: Cow::Owned(payload.to_vec()), + payload, }) } fn handle_peer_traffic( &mut self, - payload: &[u8], + mut payload: Buffer>, peer: PeerSocket, port: AllocationPort, - ) -> Option> { - let (client, channel) = self.sut.handle_peer_traffic(payload, peer, port)?; + ) -> Option { + let (client, channel) = self.sut.handle_peer_traffic(&payload, peer, port)?; - let full_length = firezone_relay::ChannelData::encode_header_to_slice( - channel, - payload.len() as u16, - &mut self.buffer[..4], - ); - self.buffer[4..full_length].copy_from_slice(payload); + let data_len = payload.len() as u16; + let header = payload.shift_start_left(4); + + firezone_relay::ChannelData::encode_header_to_slice(channel, data_len, header); let receiving_socket = client.into_socket(); let sending_socket = self @@ -183,7 +175,7 @@ impl SimRelay { Some(Transmit { src: Some(sending_socket), dst: receiving_socket, - payload: Cow::Owned(self.buffer[..full_length].to_vec()), + payload, }) } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 06bf4674a..1c92b0ee4 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -16,6 +16,7 @@ use crate::tests::flux_capacitor::FluxCapacitor; use crate::tests::transition::Transition; use crate::utils::earliest; use crate::{ClientEvent, GatewayEvent, dns, messages::Interface}; +use bufferpool::BufferPool; use connlib_model::{ClientId, GatewayId, PublicKey, RelayId}; use dns_types::ResponseCode; use dns_types::prelude::*; @@ -41,6 +42,8 @@ pub(crate) struct TunnelTest { gateways: BTreeMap>, relays: BTreeMap>, + buffer_pool: BufferPool>, + drop_direct_client_traffic: bool, network: RoutingTable, } @@ -91,6 +94,7 @@ impl TunnelTest { client, gateways, relays, + buffer_pool: BufferPool::new(1024, "test"), }; let mut buffered_transmits = BufferedTransmits::default(); @@ -155,9 +159,7 @@ impl TunnelTest { ) .unwrap(); - let transmit = state - .client - .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); buffered_transmits.push_from(transmit, &state.client, now); } @@ -179,9 +181,7 @@ impl TunnelTest { ) .unwrap(); - let transmit = state - .client - .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); buffered_transmits.push_from(transmit, &state.client, now); } @@ -203,9 +203,7 @@ impl TunnelTest { ) .unwrap(); - let transmit = state - .client - .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); buffered_transmits.push_from(transmit, &state.client, now); } @@ -459,7 +457,7 @@ impl TunnelTest { Transmit { src: Some(src), dst, - payload: payload.into(), + payload: self.buffer_pool.pull_initialised(&payload), }, relay, now, @@ -633,7 +631,7 @@ impl TunnelTest { /// It takes a [`Transmit`] and checks, which host accepts it, i.e. has configured the correct IP address. /// /// Currently, the network topology of our tests are a single subnet without NAT. - fn dispatch_transmit(&mut self, transmit: Transmit<'static>, at: Instant) { + fn dispatch_transmit(&mut self, transmit: Transmit, at: Instant) { let src = transmit .src .expect("`src` should always be set in these tests"); diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 1a9f960d5..d65526db8 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -12,9 +12,9 @@ proptest = ["dep:proptest"] [dependencies] anyhow = { workspace = true } +bufferpool = { workspace = true } etherparse = { workspace = true, features = ["std"] } etherparse-ext = { workspace = true } -lockfree-object-pool = { workspace = true } proptest = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true } diff --git a/rust/ip-packet/src/buffer_pool.rs b/rust/ip-packet/src/buffer_pool.rs deleted file mode 100644 index 39faeffd1..000000000 --- a/rust/ip-packet/src/buffer_pool.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::{ - ops::{Deref, DerefMut}, - sync::{Arc, LazyLock}, -}; - -use crate::MAX_FZ_PAYLOAD; - -type BufferPool = Arc>>; - -static BUFFER_POOL: LazyLock = LazyLock::new(|| { - Arc::new(lockfree_object_pool::MutexObjectPool::new( - || vec![0; MAX_FZ_PAYLOAD], - |v| v.fill(0), - )) -}); - -pub struct Buffer(lockfree_object_pool::MutexOwnedReusable>); - -impl Clone for Buffer { - fn clone(&self) -> Self { - let mut copy = Buffer::default(); - - copy.0.resize(self.len(), 0); - copy.copy_from_slice(self); - - copy - } -} - -impl PartialEq for Buffer { - fn eq(&self, other: &Self) -> bool { - self.as_ref() == other.as_ref() - } -} - -impl std::fmt::Debug for Buffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("Buffer").finish() - } -} - -impl Deref for Buffer { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.0[..] - } -} - -impl DerefMut for Buffer { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0[..] - } -} - -impl Default for Buffer { - fn default() -> Self { - Self(BUFFER_POOL.pull_owned()) - } -} - -impl Drop for Buffer { - fn drop(&mut self) { - debug_assert_eq!( - self.0.capacity(), - MAX_FZ_PAYLOAD, - "Buffer should never re-allocate" - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn buffer_can_be_cloned() { - let mut buffer = Buffer::default(); - buffer[..11].copy_from_slice(b"hello world"); - - let buffer2 = buffer.clone(); - - assert_eq!(&buffer2[..], &buffer[..]); - } - - #[test] - fn cloned_buffer_owns_its_own_memory() { - let mut buffer = Buffer::default(); - buffer[..11].copy_from_slice(b"hello world"); - - let buffer2 = buffer.clone(); - drop(buffer); - - assert_eq!(&buffer2[..11], b"hello world"); - } -} diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index b8c984f43..7b0fb2c7a 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -2,7 +2,6 @@ pub mod make; -mod buffer_pool; mod fz_p2p_control; mod fz_p2p_control_slice; mod icmp_dest_unreachable; @@ -13,7 +12,7 @@ mod nat64; #[allow(clippy::unwrap_used)] pub mod proptest; -use buffer_pool::Buffer; +use bufferpool::{Buffer, BufferPool}; pub use etherparse::*; pub use fz_p2p_control::EventType as FzP2pEventType; pub use fz_p2p_control_slice::FzP2pControlSlice; @@ -24,6 +23,7 @@ mod proptests; use anyhow::{Context as _, Result, bail}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::sync::LazyLock; use etherparse_ext::Icmpv4HeaderSliceMut; use etherparse_ext::Icmpv6EchoHeaderSliceMut; @@ -32,6 +32,9 @@ use etherparse_ext::Ipv6HeaderSliceMut; use etherparse_ext::TcpHeaderSliceMut; use etherparse_ext::UdpHeaderSliceMut; +static BUFFER_POOL: LazyLock>> = + LazyLock::new(|| BufferPool::new(MAX_FZ_PAYLOAD, "ip-packet")); + /// The maximum size of an IP packet we can handle. pub const MAX_IP_SIZE: usize = 1280; /// The maximum payload an IP packet can have. @@ -116,9 +119,16 @@ pub enum Layer4Protocol { } /// A buffer for reading a new [`IpPacket`] from the network. -#[derive(Default)] pub struct IpPacketBuf { - inner: Buffer, + inner: Buffer>, +} + +impl Default for IpPacketBuf { + fn default() -> Self { + Self { + inner: BUFFER_POOL.pull(), + } + } } impl IpPacketBuf { @@ -200,7 +210,7 @@ impl std::fmt::Debug for IpPacket { #[derive(Debug, PartialEq, Clone)] pub struct ConvertibleIpv4Packet { - buf: Buffer, + buf: Buffer>, start: usize, len: usize, } @@ -280,7 +290,7 @@ impl ConvertibleIpv4Packet { #[derive(Debug, PartialEq, Clone)] pub struct ConvertibleIpv6Packet { - buf: Buffer, + buf: Buffer>, start: usize, len: usize, } diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index 03f787a90..bfeba2dd3 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -6,12 +6,12 @@ license = { workspace = true } [dependencies] anyhow = { workspace = true } +bufferpool = { workspace = true } bytes = { workspace = true } derive_more = { workspace = true, features = ["debug"] } firezone-logging = { workspace = true } gat-lending-iterator = { workspace = true } ip-packet = { workspace = true } -lockfree-object-pool = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } parking_lot = { workspace = true } quinn-udp = { workspace = true } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index fa0b4fb3d..29c2e0a4e 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,5 +1,6 @@ use anyhow::{Context as _, Result}; -use bytes::Buf as _; +use bufferpool::{Buffer, BufferPool}; +use bytes::{Buf as _, BytesMut}; use firezone_logging::err_with_src; use gat_lending_iterator::LendingIterator; use ip_packet::Ecn; @@ -10,7 +11,6 @@ use std::collections::HashMap; use std::io; use std::io::IoSliceMut; use std::ops::Deref; -use std::sync::Arc; use std::{ net::{IpAddr, SocketAddr}, task::{Context, Poll, ready}, @@ -156,7 +156,7 @@ pub struct UdpSocket { src_by_dst_cache: Mutex>, /// A buffer pool for batches of incoming UDP packets. - buffer_pool: Arc>>, + buffer_pool: BufferPool>, gro_batch_histogram: opentelemetry::metrics::Histogram, port: u16, @@ -164,7 +164,8 @@ pub struct UdpSocket { impl UdpSocket { fn new(inner: tokio::net::UdpSocket) -> io::Result { - let port = inner.local_addr()?.port(); + let socket_addr = inner.local_addr()?; + let port = socket_addr.port(); Ok(UdpSocket { state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?, @@ -172,10 +173,13 @@ impl UdpSocket { inner, source_ip_resolver: Box::new(|_| Ok(None)), src_by_dst_cache: Default::default(), - buffer_pool: Arc::new(lockfree_object_pool::MutexObjectPool::new( - || vec![0u8; u16::MAX as usize], - |_| {}, - )), + buffer_pool: BufferPool::new( + u16::MAX as usize, + match socket_addr.ip() { + IpAddr::V4(_) => "udp-socket-v4", + IpAddr::V6(_) => "udp-socket-v6", + }, + ), gro_batch_histogram: opentelemetry::global::meter("connlib") .u64_histogram("system.network.packets.batch_count") .with_description( @@ -249,10 +253,10 @@ pub struct DatagramIn<'a> { } /// An outbound UDP datagram. -pub struct DatagramOut { +pub struct DatagramOut { pub src: Option, pub dst: SocketAddr, - pub packet: B, + pub packet: Buffer, pub segment_size: Option, pub ecn: Ecn, } @@ -268,7 +272,7 @@ impl UdpSocket { } = self; // Stack-allocate arrays for buffers and meta. The size is implied from the const-generic default on `DatagramSegmentIter`. - let mut bufs = std::array::from_fn(|_| self.buffer_pool.pull_owned()); + let mut bufs = std::array::from_fn(|_| self.buffer_pool.pull()); let mut meta = std::array::from_fn(|_| quinn_udp::RecvMeta::default()); loop { @@ -302,14 +306,11 @@ impl UdpSocket { self.inner.poll_send_ready(cx) } - pub async fn send(&self, datagram: DatagramOut) -> Result<()> - where - B: Deref, - { + pub async fn send(&self, datagram: DatagramOut) -> Result<()> { let Some(transmit) = self.prepare_transmit( datagram.dst, datagram.src.map(|s| s.ip()), - datagram.packet.deref().chunk(), + datagram.packet.chunk(), datagram.segment_size, datagram.ecn, )? @@ -497,10 +498,7 @@ impl UdpSocket { /// When [`quinn_udp`] returns us the buffers, it will have populated the [`quinn_udp::RecvMeta`]s accordingly. /// Thus, our main job within this iterator is to loop over the `buffers` and `meta` pair-wise, inspect the `meta` and segment the data within the buffer accordingly. #[derive(derive_more::Debug)] -pub struct DatagramSegmentIter< - const N: usize = { quinn_udp::BATCH_SIZE }, - B = lockfree_object_pool::MutexOwnedReusable>, -> { +pub struct DatagramSegmentIter>> { #[debug(skip)] buffers: [B; N], metas: [quinn_udp::RecvMeta; N],