From e031dfdb4a7f0692d452efed1d361efc9048d2fc Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 30 Apr 2025 18:52:18 +1000 Subject: [PATCH] refactor(connlib): introduce our own `bufferpool` crate (#8928) We have been using buffer pools for a while all over `connlib` as a way to efficiently use heap-allocated memory. This PR harmonizes the usage of buffer pools across the codebase by introducing a dedicated `bufferpool` crate. This crate offers a convenient and easy-to-use API for all the things we (currently) need from buffer pools. As a nice bonus of having it all in one place, we can now also track metrics of how many buffers we have currently allocated. An example output from the local metrics exporter looks like this: ``` Name : system.buffer.count Description : The number of buffers allocated in the pool. Unit : {buffers} Type : Sum Sum DataPoints Monotonic : false Temporality : Cumulative DataPoint #0 StartTime : 2025-04-29 12:41:25.278436 EndTime : 2025-04-29 12:42:25.278088 Value : 96 Attributes : -> system.buffer.pool.name: udp-socket-v6 -> system.buffer.pool.buffer_size: 65535 DataPoint #1 StartTime : 2025-04-29 12:41:25.278436 EndTime : 2025-04-29 12:42:25.278088 Value : 7 Attributes : -> system.buffer.pool.buffer_size: 131600 -> system.buffer.pool.name: gso-queue DataPoint #2 StartTime : 2025-04-29 12:41:25.278436 EndTime : 2025-04-29 12:42:25.278088 Value : 128 Attributes : -> system.buffer.pool.name: udp-socket-v4 -> system.buffer.pool.buffer_size: 65535 DataPoint #3 StartTime : 2025-04-29 12:41:25.278436 EndTime : 2025-04-29 12:42:25.278088 Value : 8 Attributes : -> system.buffer.pool.buffer_size: 1336 -> system.buffer.pool.name: ip-packet DataPoint #4 StartTime : 2025-04-29 12:41:25.278436 EndTime : 2025-04-29 12:42:25.278088 Value : 9 Attributes : -> system.buffer.pool.buffer_size: 1336 -> system.buffer.pool.name: snownet ``` Resolves: #8385 --- rust/Cargo.lock | 22 +- rust/Cargo.toml | 2 + rust/bin-shared/Cargo.toml | 2 + rust/bin-shared/tests/no_packet_loops_udp.rs | 8 +- rust/bufferpool/Cargo.toml | 21 ++ rust/bufferpool/lib.rs | 348 ++++++++++++++++++ rust/connlib/snownet/Cargo.toml | 2 +- rust/connlib/snownet/src/allocation.rs | 15 +- rust/connlib/snownet/src/lib.rs | 4 +- rust/connlib/snownet/src/node.rs | 144 ++++---- rust/connlib/tunnel/Cargo.toml | 2 +- rust/connlib/tunnel/src/client.rs | 19 +- rust/connlib/tunnel/src/gateway.rs | 17 +- rust/connlib/tunnel/src/io/gso_queue.rs | 34 +- rust/connlib/tunnel/src/lib.rs | 8 +- rust/connlib/tunnel/src/sockets.rs | 5 +- .../tunnel/src/tests/buffered_transmits.rs | 10 +- rust/connlib/tunnel/src/tests/sim_client.rs | 8 +- rust/connlib/tunnel/src/tests/sim_gateway.rs | 32 +- rust/connlib/tunnel/src/tests/sim_net.rs | 4 +- rust/connlib/tunnel/src/tests/sim_relay.rs | 40 +- rust/connlib/tunnel/src/tests/sut.rs | 20 +- rust/ip-packet/Cargo.toml | 2 +- rust/ip-packet/src/buffer_pool.rs | 96 ----- rust/ip-packet/src/lib.rs | 22 +- rust/socket-factory/Cargo.toml | 2 +- rust/socket-factory/src/lib.rs | 38 +- 27 files changed, 585 insertions(+), 342 deletions(-) create mode 100644 rust/bufferpool/Cargo.toml create mode 100644 rust/bufferpool/lib.rs delete mode 100644 rust/ip-packet/src/buffer_pool.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index be3457c96..47f316c6a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -831,6 +831,18 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "bufferpool" +version = "0.1.0" +dependencies = [ + "bytes", + "lockfree-object-pool", + "opentelemetry", + "opentelemetry_sdk", + "tokio", + "tracing", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -2163,6 +2175,8 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "bufferpool", + "bytes", "clap", "firezone-logging", "flume", @@ -2477,6 +2491,7 @@ dependencies = [ "base64 0.22.1", "bimap", "boringtun", + "bufferpool", "bytes", "chrono", "connlib-model", @@ -2499,7 +2514,6 @@ dependencies = [ "itertools 0.13.0", "l4-tcp-dns-server", "l4-udp-dns-server", - "lockfree-object-pool", "lru", "opentelemetry", "proptest", @@ -3501,9 +3515,9 @@ name = "ip-packet" version = "0.1.0" dependencies = [ "anyhow", + "bufferpool", "etherparse", "etherparse-ext", - "lockfree-object-pool", "proptest", "test-strategy", "thiserror 1.0.69", @@ -6457,6 +6471,7 @@ name = "snownet" version = "0.1.0" dependencies = [ "boringtun", + "bufferpool", "bytecodec", "bytes", "derive_more 1.0.0", @@ -6465,7 +6480,6 @@ dependencies = [ "hex-display", "ip-packet", "itertools 0.13.0", - "lockfree-object-pool", "once_cell", "rand 0.8.5", "ringbuffer", @@ -6482,12 +6496,12 @@ name = "socket-factory" version = "0.1.0" dependencies = [ "anyhow", + "bufferpool", "bytes", "derive_more 1.0.0", "firezone-logging", "gat-lending-iterator", "ip-packet", - "lockfree-object-pool", "opentelemetry", "parking_lot", "quinn-udp", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 05909cf18..f16e24a44 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "bin-shared", + "bufferpool", "connlib/clients/android", "connlib/clients/apple", "connlib/clients/shared", @@ -97,6 +98,7 @@ keyring = "3.6.2" known-folders = "1.2.0" l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" } l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" } +bufferpool = { path = "bufferpool" } libc = "0.2.171" lockfree-object-pool = "0.1.6" log = "0.4" diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index 5ef744d53..1a4880828 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -23,6 +23,8 @@ tracing = { workspace = true } tun = { workspace = true } [dev-dependencies] +bufferpool = { workspace = true } +bytes = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [target.'cfg(target_os = "linux")'.dependencies] diff --git a/rust/bin-shared/tests/no_packet_loops_udp.rs b/rust/bin-shared/tests/no_packet_loops_udp.rs index a0889e591..3ddc22ba2 100644 --- a/rust/bin-shared/tests/no_packet_loops_udp.rs +++ b/rust/bin-shared/tests/no_packet_loops_udp.rs @@ -1,5 +1,7 @@ #![allow(clippy::unwrap_used)] +use bufferpool::BufferPool; +use bytes::BytesMut; use firezone_bin_shared::{TunDeviceManager, platform::udp_socket_factory}; use gat_lending_iterator::LendingIterator as _; use ip_network::Ipv4Network; @@ -19,6 +21,8 @@ async fn no_packet_loops_udp() { let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + let bufferpool = BufferPool::::new(0, "test"); + let mut device_manager = TunDeviceManager::new(1280, 1).unwrap(); let _tun = device_manager.make_tun().unwrap(); device_manager.set_ips(ipv4, ipv6).await.unwrap(); @@ -45,7 +49,9 @@ async fn no_packet_loops_udp() { .send(DatagramOut { src: None, dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(141, 101, 90, 0), 3478)), // stun.cloudflare.com, - packet: &hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), + packet: bufferpool.pull_initialised( + hex_literal::hex!("000100002112A4420123456789abcdef01234567").as_ref(), + ), segment_size: None, ecn: Ecn::NonEct, }) diff --git a/rust/bufferpool/Cargo.toml b/rust/bufferpool/Cargo.toml new file mode 100644 index 000000000..fbf3d07ac --- /dev/null +++ b/rust/bufferpool/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "bufferpool" +version = "0.1.0" +edition = { workspace = true } +license = { workspace = true } + +[lib] +path = "lib.rs" + +[dependencies] +bytes = { workspace = true } +lockfree-object-pool = { workspace = true } +opentelemetry = { workspace = true, features = ["metrics"] } +tracing = { workspace = true } + +[dev-dependencies] +opentelemetry_sdk = { workspace = true, features = ["testing"] } +tokio = { workspace = true, features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/rust/bufferpool/lib.rs b/rust/bufferpool/lib.rs new file mode 100644 index 000000000..b5920e8bf --- /dev/null +++ b/rust/bufferpool/lib.rs @@ -0,0 +1,348 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use bytes::BytesMut; +use opentelemetry::{KeyValue, metrics::UpDownCounter}; + +#[derive(Clone)] +pub struct BufferPool { + inner: Arc>>, +} + +impl BufferPool +where + B: Buf, +{ + pub fn new(capacity: usize, tag: &'static str) -> Self { + let buffer_counter = opentelemetry::global::meter("connlib") + .i64_up_down_counter("system.buffer.count") + .with_description("The number of buffers allocated in the pool.") + .with_unit("{buffers}") + .init(); + + Self { + inner: Arc::new(lockfree_object_pool::MutexObjectPool::new( + move || { + BufferStorage::new( + B::with_capacity(capacity), + buffer_counter.clone(), + [ + KeyValue::new("system.buffer.pool.name", tag), + KeyValue::new("system.buffer.pool.buffer_size", capacity as i64), + ], + ) + }, + |_| {}, + )), + } + } + + pub fn pull(&self) -> Buffer { + Buffer { + inner: self.inner.pull_owned(), + pool: self.inner.clone(), + } + } +} + +impl BufferPool +where + B: Buf + DerefMut, +{ + pub fn pull_initialised(&self, data: &[u8]) -> Buffer { + let mut buffer = self.pull(); + let len = data.len(); + + buffer.resize_to(len); + buffer.copy_from_slice(data); + + buffer + } +} + +pub struct Buffer { + inner: lockfree_object_pool::MutexOwnedReusable>, + pool: Arc>>, +} + +impl Buffer> { + /// Shifts the start of the buffer to the right by N bytes, returning the bytes removed from the front of the buffer. + pub fn shift_start_right(&mut self, num: usize) -> Vec { + let num_to_end = self.split_off(num); + + std::mem::replace(&mut self.inner.inner, num_to_end) + } + + /// Shifts the start of the buffer to the left by N bytes, returning a slice to the added bytes at the front of the buffer. + pub fn shift_start_left(&mut self, num: usize) -> &mut [u8] { + let current_len = self.len(); + + self.resize(current_len + num, 0); + self.copy_within(..current_len, num); + + &mut self[..num] + } +} + +impl Clone for Buffer +where + B: Buf, +{ + fn clone(&self) -> Self { + let mut copy = self.pool.pull_owned(); + + self.inner.inner.clone(&mut copy); + + Self { + inner: copy, + pool: self.pool.clone(), + } + } +} + +impl PartialEq for Buffer +where + B: Deref, +{ + fn eq(&self, other: &Self) -> bool { + self.as_ref() == other.as_ref() + } +} + +impl Eq for Buffer where B: Deref {} + +impl PartialOrd for Buffer +where + B: Deref, +{ + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Buffer +where + B: Deref, +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.as_ref().cmp(other.as_ref()) + } +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Buffer").finish() + } +} + +impl Deref for Buffer { + type Target = B; + + fn deref(&self) -> &Self::Target { + self.inner.deref() + } +} + +impl DerefMut for Buffer { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner.deref_mut() + } +} + +pub trait Buf: Sized { + fn with_capacity(capacity: usize) -> Self; + fn clone(&self, dst: &mut Self); + fn resize_to(&mut self, len: usize); +} + +impl Buf for Vec { + fn with_capacity(capacity: usize) -> Self { + vec![0; capacity] + } + + fn clone(&self, dst: &mut Self) { + dst.resize(self.len(), 0); + dst.copy_from_slice(self); + } + + fn resize_to(&mut self, len: usize) { + self.resize(len, 0); + } +} + +impl Buf for BytesMut { + fn with_capacity(capacity: usize) -> Self { + BytesMut::zeroed(capacity) + } + + fn clone(&self, dst: &mut Self) { + dst.resize(self.len(), 0); + dst.copy_from_slice(self); + } + + fn resize_to(&mut self, len: usize) { + self.resize(len, 0); + } +} + +/// A wrapper around a buffer `B` that keeps track of how many buffers there are in a counter. +struct BufferStorage { + inner: B, + + attributes: [KeyValue; 2], + counter: UpDownCounter, +} + +impl Drop for BufferStorage { + fn drop(&mut self) { + self.counter.add(-1, &self.attributes); + } +} + +impl BufferStorage { + fn new(inner: B, counter: UpDownCounter, attributes: [KeyValue; 2]) -> Self { + counter.add(1, &attributes); + + Self { + inner, + counter, + attributes, + } + } +} + +impl Deref for BufferStorage { + type Target = B; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for BufferStorage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use opentelemetry::global; + use opentelemetry_sdk::{ + metrics::{PeriodicReader, SdkMeterProvider, data::Sum}, + testing::metrics::InMemoryMetricsExporter, + }; + + use super::*; + + #[test] + fn buffer_can_be_cloned() { + let pool = BufferPool::>::new(1024, "test"); + + let buffer = pool.pull_initialised(b"hello world"); + + #[allow(clippy::redundant_clone)] + let buffer2 = buffer.clone(); + + assert_eq!(&buffer2[..], &buffer[..]); + } + + #[test] + fn cloned_buffer_owns_its_own_memory() { + let pool = BufferPool::>::new(1024, "test"); + + let buffer = pool.pull_initialised(b"hello world"); + + let buffer2 = buffer.clone(); + drop(buffer); + + assert_eq!(&buffer2[..11], b"hello world"); + } + + #[test] + fn initialised_buffer_is_only_as_long_as_content() { + let pool = BufferPool::>::new(1024, "test"); + + let buffer = pool.pull_initialised(b"hello world"); + + assert_eq!(buffer.len(), 11); + } + + #[test] + fn shift_start_right() { + let pool = BufferPool::>::new(1024, "test"); + + let mut buffer = pool.pull_initialised(b"hello world"); + + let front = buffer.shift_start_right(5); + + assert_eq!(front, b"hello"); + assert_eq!(&*buffer, b" world"); + } + + #[test] + fn shift_start_left() { + let pool = BufferPool::>::new(1024, "test"); + + let mut buffer = pool.pull_initialised(b"hello world"); + + let front = buffer.shift_start_left(5); + front.copy_from_slice(b"12345"); + + assert_eq!(&*buffer, b"12345hello world"); + } + + #[tokio::test] + async fn buffer_pool_metrics() { + let (_provider, exporter) = init_meter_provider(); + + let pool = BufferPool::>::new(1024, "test"); + + let buffer1 = pool.pull_initialised(b"hello world"); + let buffer2 = pool.pull_initialised(b"hello world"); + let buffer3 = pool.pull_initialised(b"hello world"); + + tokio::time::sleep(Duration::from_millis(10)).await; // Wait for metrics to be exported. + + assert_eq!(get_num_buffers(&exporter), 3); + + drop(pool); + drop(buffer1); + drop(buffer2); + drop(buffer3); + + tokio::time::sleep(Duration::from_millis(10)).await; // Wait for metrics to be exported. + + assert_eq!(get_num_buffers(&exporter), 0); + } + + fn get_num_buffers(exporter: &InMemoryMetricsExporter) -> i64 { + let metrics = exporter.get_finished_metrics().unwrap(); + + let metric = &metrics.iter().last().unwrap().scope_metrics[0].metrics[0]; + let sum = metric.data.as_any().downcast_ref::>().unwrap(); + + sum.data_points[0].value + } + + fn init_meter_provider() -> (SdkMeterProvider, InMemoryMetricsExporter) { + let exporter = InMemoryMetricsExporter::default(); + + let provider = SdkMeterProvider::builder() + .with_reader( + PeriodicReader::builder(exporter.clone(), opentelemetry_sdk::runtime::Tokio) + .with_interval(Duration::from_millis(1)) + .build(), + ) + .build(); + global::set_meter_provider(provider.clone()); + + (provider, exporter) + } +} diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index d4b95c8c4..c70b5683a 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -6,6 +6,7 @@ license = { workspace = true } [dependencies] boringtun = { workspace = true } +bufferpool = { workspace = true } bytecodec = { workspace = true } bytes = { workspace = true } derive_more = { workspace = true, features = ["debug"] } @@ -14,7 +15,6 @@ hex = { workspace = true } hex-display = { workspace = true } ip-packet = { workspace = true } itertools = { workspace = true } -lockfree-object-pool = { workspace = true } once_cell = { workspace = true } rand = { workspace = true } ringbuffer = { workspace = true } diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 00932e901..c55fa050f 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -2,6 +2,7 @@ use crate::{ backoff::{self, ExponentialBackoff}, node::{SessionId, Transmit}, }; +use bufferpool::BufferPool; use bytecodec::{DecodeExt as _, EncodeExt as _}; use firezone_logging::err_with_src; use hex_display::HexDisplayExt as _; @@ -47,7 +48,6 @@ const BINDING_INTERVAL: Duration = Duration::from_secs(25); /// Represents a TURN allocation that refreshes itself. /// /// Allocations have a lifetime and need to be continuously refreshed to stay active. -#[derive(Debug)] pub struct Allocation { /// The known sockets of the relay. server: RelaySocket, @@ -77,7 +77,7 @@ pub struct Allocation { /// When we received the allocation and how long it is valid. allocation_lifetime: Option<(Instant, Duration)>, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, events: VecDeque, sent_requests: BTreeMap, ExponentialBackoff)>, @@ -88,6 +88,8 @@ pub struct Allocation { credentials: Option, explicit_failure: Option, + + buffer_pool: BufferPool>, } #[derive(derive_more::Debug, Clone, Copy)] @@ -212,6 +214,7 @@ impl Allocation { realm: Realm, now: Instant, session_id: SessionId, + buffer_pool: BufferPool>, ) -> Self { let mut allocation = Self { server, @@ -235,6 +238,7 @@ impl Allocation { software: Software::new(format!("snownet; session={session_id}")) .expect("description has less then 128 chars"), explicit_failure: Default::default(), + buffer_pool, }; allocation.send_binding_requests(now); @@ -742,7 +746,7 @@ impl Allocation { self.events.pop_front() } - pub fn poll_transmit(&mut self) -> Option> { + pub fn poll_transmit(&mut self) -> Option { self.buffered_transmits.pop_front() } @@ -1082,10 +1086,11 @@ impl Allocation { self.sent_requests .insert(id, (dst, message.clone(), backoff)); + self.buffered_transmits.push_back(Transmit { src: None, dst, - payload: encode(message).into(), + payload: self.buffer_pool.pull_initialised(&encode(message)), }); true @@ -2818,6 +2823,7 @@ mod tests { Realm::new("firezone".to_owned()).unwrap(), start, SessionId::default(), + BufferPool::new(500, "test"), ) } @@ -2832,6 +2838,7 @@ mod tests { Realm::new("firezone".to_owned()).unwrap(), start, SessionId::default(), + BufferPool::new(500, "test"), ) } diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index 8940a1b19..4569d0f8b 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -16,8 +16,8 @@ pub use allocation::RelaySocket; #[allow(deprecated)] // Rust bug: `expect` doesn't seem to work on imports? pub use node::{Answer, Offer}; pub use node::{ - Client, ClientNode, Credentials, EncryptedPacket, Error, Event, HANDSHAKE_TIMEOUT, - NoTurnServers, Node, Server, ServerNode, Transmit, + Client, ClientNode, Credentials, Error, Event, HANDSHAKE_TIMEOUT, NoTurnServers, Node, Server, + ServerNode, Transmit, }; pub use stats::{ConnectionStats, NodeStats}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 46d83613b..8bdfd60c5 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -7,6 +7,7 @@ use boringtun::noise::errors::WireGuardError; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; +use bufferpool::{Buffer, BufferPool}; use core::fmt; use firezone_logging::err_with_src; use hex_display::HexDisplayExt; @@ -17,7 +18,6 @@ use rand::{Rng, RngCore, SeedableRng, random}; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use secrecy::{ExposeSecret, Secret}; use sha2::Digest; -use std::borrow::Cow; use std::collections::btree_map::Entry; use std::collections::{BTreeMap, BTreeSet}; use std::hash::Hash; @@ -118,7 +118,7 @@ pub struct Node { rate_limiter: Arc, /// Host and server-reflexive candidates that are shared between all connections. shared_candidates: CandidateSet, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, next_rate_limiter_reset: Option, @@ -128,9 +128,7 @@ pub struct Node { pending_events: VecDeque>, stats: NodeStats, - // All access to [`Node`] happens in the same thread, so we should never get contention which makes a spinlock ideal. - // This is wrapped in an `Arc` so we can use `pull_owned`. - buffer_pool: Arc>>, + buffer_pool: BufferPool>, mode: T, rng: StdRng, @@ -184,10 +182,7 @@ where allocations: Default::default(), connections: Default::default(), stats: Default::default(), - buffer_pool: Arc::new(lockfree_object_pool::SpinLockObjectPool::new( - || vec![0; ip_packet::MAX_FZ_PAYLOAD], - |v| v.fill(0), - )), + buffer_pool: BufferPool::new(ip_packet::MAX_FZ_PAYLOAD, "snownet"), } } @@ -447,7 +442,7 @@ where connection: TId, packet: IpPacket, now: Instant, - ) -> Result, Error> { + ) -> Result, Error> { let conn = self .connections .get_established_mut(&connection) @@ -462,7 +457,8 @@ where return Ok(None); } - let mut buffer = self.buffer_pool.pull_owned(); + let mut buffer = self.buffer_pool.pull(); + buffer.resize(ip_packet::MAX_FZ_PAYLOAD, 0); // Encode the packet with an offset of 4 bytes, in case we need to wrap it in a channel-data message. let Some(packet_len) = conn @@ -500,13 +496,16 @@ where | PeerSocket::PeerToRelay { source, dest: remote, - } => Ok(Some(EncryptedPacket { - src: Some(source), - dst: remote, - packet_start, - packet_len, - buffer, - })), + } => { + buffer.copy_within(packet_start..packet_end, 0); + buffer.truncate(packet_len); + + Ok(Some(Transmit { + src: Some(source), + dst: remote, + payload: buffer, + })) + } PeerSocket::RelayToPeer { relay, dest: peer } | PeerSocket::RelayToRelay { relay, dest: peer } => { let Some(allocation) = self.allocations.get_mut(&relay) else { @@ -519,12 +518,12 @@ where return Ok(None); }; - Ok(Some(EncryptedPacket { + buffer.truncate(packet_end); + + Ok(Some(Transmit { src: None, dst: encode_ok.socket, - packet_start: 0, - packet_len: packet_end, - buffer, + payload: buffer, })) } } @@ -616,7 +615,7 @@ where /// Returns buffered data that needs to be sent on the socket. #[must_use] - pub fn poll_transmit(&mut self) -> Option> { + pub fn poll_transmit(&mut self) -> Option { let allocation_transmits = &mut self .allocations .values_mut() @@ -679,6 +678,7 @@ where realm, now, self.session_id.clone(), + self.buffer_pool.clone(), )); tracing::info!(%rid, address = ?server, "Added new TURN server"); @@ -706,6 +706,7 @@ where realm, now, self.session_id.clone(), + self.buffer_pool.clone(), )); tracing::info!(%rid, address = ?server, "Replaced TURN server"); @@ -770,6 +771,7 @@ where }, possible_sockets: BTreeSet::default(), span: info_span!(parent: tracing::Span::none(), "connection", %cid), + buffer_pool: self.buffer_pool.clone(), } } @@ -1494,38 +1496,8 @@ pub enum Event { ConnectionClosed(TId), } -pub struct EncryptedPacket { - pub(crate) src: Option, - pub(crate) dst: SocketAddr, - pub(crate) packet_start: usize, - pub(crate) packet_len: usize, - pub(crate) buffer: lockfree_object_pool::SpinLockOwnedReusable>, -} - -impl EncryptedPacket { - pub fn to_transmit(&self) -> Transmit<'_> { - Transmit { - src: self.src, - dst: self.dst, - payload: Cow::Borrowed(self.payload()), - } - } - - pub fn src(&self) -> Option { - self.src - } - - pub fn dst(&self) -> SocketAddr { - self.dst - } - - pub fn payload(&self) -> &[u8] { - &self.buffer[self.packet_start..(self.packet_start + self.packet_len)] - } -} - #[derive(Clone, PartialEq, PartialOrd, Eq, Ord)] -pub struct Transmit<'a> { +pub struct Transmit { /// The local interface from which this packet should be sent. /// /// If `None`, it can be sent from any interface. @@ -1536,10 +1508,10 @@ pub struct Transmit<'a> { /// The remote the packet should be sent to. pub dst: SocketAddr, /// The data that should be sent. - pub payload: Cow<'a, [u8]>, + pub payload: Buffer>, } -impl fmt::Debug for Transmit<'_> { +impl fmt::Debug for Transmit { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Transmit") .field("src", &self.src) @@ -1549,16 +1521,6 @@ impl fmt::Debug for Transmit<'_> { } } -impl Transmit<'_> { - pub fn into_owned(self) -> Transmit<'static> { - Transmit { - src: self.src, - dst: self.dst, - payload: Cow::Owned(self.payload.into_owned()), - } - } -} - struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>, @@ -1624,6 +1586,7 @@ struct Connection { buffer: Vec, span: tracing::Span, + buffer_pool: BufferPool>, } enum ConnectionState { @@ -1833,7 +1796,7 @@ where cid: TId, now: Instant, allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, ) where TId: Copy + Ord + fmt::Display, RId: Copy + Ord + fmt::Display, @@ -1913,7 +1876,13 @@ where tracing::debug!(%num_buffered, "Flushing packets buffered during ICE"); transmits.extend(buffered.into_iter().flat_map(|packet| { - make_owned_transmit(remote_socket, &packet, allocations, now) + make_owned_transmit( + remote_socket, + &packet, + &self.buffer_pool, + allocations, + now, + ) })); self.state = ConnectionState::Connected { peer_socket: remote_socket, @@ -1984,7 +1953,7 @@ where transmits.push_back(Transmit { src: Some(source), dst, - payload: Cow::Owned(stun_packet.into()), + payload: self.buffer_pool.pull_initialised(&Vec::from(stun_packet)), }); continue; }; @@ -2005,7 +1974,7 @@ where transmits.push_back(Transmit { src: None, dst: encode_ok.socket, - payload: Cow::Owned(data_channel_packet), + payload: self.buffer_pool.pull_initialised(&data_channel_packet), }); } } @@ -2014,7 +1983,7 @@ where &mut self, now: Instant, allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, ) { // Don't update wireguard timers until we are connected. let Some(peer_socket) = self.socket() else { @@ -2038,7 +2007,13 @@ where tracing::warn!(?e); } TunnResult::WriteToNetwork(b) => { - transmits.extend(make_owned_transmit(peer_socket, b, allocations, now)); + transmits.extend(make_owned_transmit( + peer_socket, + b, + &self.buffer_pool, + allocations, + now, + )); } TunnResult::WriteToTunnelV4(..) | TunnResult::WriteToTunnelV6(..) => { panic!("Unexpected result from update_timers") @@ -2072,7 +2047,7 @@ where &mut self, packet: &[u8], allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, now: Instant, ) -> ControlFlow, IpPacket> { let _guard = self.span.enter(); @@ -2132,6 +2107,7 @@ where transmits.extend(make_owned_transmit( *peer_socket, bytes, + &self.buffer_pool, allocations, now, )); @@ -2143,6 +2119,7 @@ where transmits.extend(make_owned_transmit( *peer_socket, packet, + &self.buffer_pool, allocations, now, )); @@ -2165,7 +2142,7 @@ where fn force_handshake( &mut self, allocations: &mut BTreeMap, - transmits: &mut VecDeque>, + transmits: &mut VecDeque, now: Instant, ) where RId: Copy, @@ -2188,7 +2165,13 @@ where .socket() .expect("cannot force handshake while not connected"); - transmits.extend(make_owned_transmit(socket, bytes, allocations, now)); + transmits.extend(make_owned_transmit( + socket, + bytes, + &self.buffer_pool, + allocations, + now, + )); } fn socket(&self) -> Option> { @@ -2212,9 +2195,10 @@ where fn make_owned_transmit( socket: PeerSocket, message: &[u8], + buffer_pool: &BufferPool>, allocations: &mut BTreeMap, now: Instant, -) -> Option> +) -> Option where RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug, { @@ -2229,19 +2213,19 @@ where } => Transmit { src: Some(source), dst: remote, - payload: Cow::Owned(message.into()), + payload: buffer_pool.pull_initialised(message), }, PeerSocket::RelayToPeer { relay, dest: peer } | PeerSocket::RelayToRelay { relay, dest: peer } => { let allocation = allocations.get_mut(&relay)?; - let mut buffer = channel_data_packet_buffer(message); - let encode_ok = allocation.encode_channel_data_header(peer, &mut buffer, now)?; + let mut channel_data = channel_data_packet_buffer(message); + let encode_ok = allocation.encode_channel_data_header(peer, &mut channel_data, now)?; Transmit { src: None, dst: encode_ok.socket, - payload: Cow::Owned(buffer), + payload: buffer_pool.pull_initialised(&channel_data), } } }; diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 6bbd71b4e..2f40e8a54 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -9,6 +9,7 @@ anyhow = { workspace = true } base64 = { workspace = true, features = ["std"] } bimap = { workspace = true } boringtun = { workspace = true } +bufferpool = { workspace = true } bytes = { workspace = true, features = ["std"] } chrono = { workspace = true } connlib-model = { workspace = true } @@ -30,7 +31,6 @@ ip_network_table = { workspace = true } itertools = { workspace = true, features = ["use_std"] } l4-tcp-dns-server = { workspace = true } l4-udp-dns-server = { workspace = true } -lockfree-object-pool = { workspace = true } lru = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } proptest = { workspace = true, optional = true } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 1305ef9d8..8ddc8f668 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -145,7 +145,7 @@ pub struct ClientState { buffered_events: VecDeque, buffered_packets: VecDeque, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, buffered_dns_queries: VecDeque, } @@ -509,7 +509,7 @@ impl ClientState { &mut self, packet: IpPacket, now: Instant, - ) -> Option { + ) -> Option { let non_dns_packet = match self.try_handle_dns(packet, now) { ControlFlow::Break(()) => return None, ControlFlow::Continue(non_dns_packet) => non_dns_packet, @@ -625,7 +625,7 @@ impl ClientState { } } - fn encapsulate(&mut self, packet: IpPacket, now: Instant) -> Option { + fn encapsulate(&mut self, packet: IpPacket, now: Instant) -> Option { let dst = packet.destination(); if is_definitely_not_a_resource(dst) { @@ -1165,11 +1165,10 @@ impl ClientState { // Check if the client wants to emit any packets. if let Some(packet) = self.tcp_dns_client.poll_outbound() { // All packets from the TCP DNS client _should_ go through the tunnel. - let Some(encryped_packet) = self.encapsulate(packet, now) else { + let Some(transmit) = self.encapsulate(packet, now) else { continue; }; - let transmit = encryped_packet.to_transmit().into_owned(); self.buffered_transmits.push_back(transmit); continue; } @@ -1633,7 +1632,7 @@ impl ClientState { self.tcp_dns_client.reset(); } - pub(crate) fn poll_transmit(&mut self) -> Option> { + pub(crate) fn poll_transmit(&mut self) -> Option { self.buffered_transmits .pop_front() .or_else(|| self.node.poll_transmit()) @@ -1895,9 +1894,9 @@ fn encapsulate_and_buffer( gid: GatewayId, now: Instant, node: &mut ClientNode, - buffered_transmits: &mut VecDeque>, + buffered_transmits: &mut VecDeque, ) { - let Some(enc_packet) = node + let Some(transmit) = node .encapsulate(gid, packet, now) .inspect_err(|e| tracing::debug!(%gid, "Failed to encapsulate: {e}")) .ok() @@ -1906,7 +1905,7 @@ fn encapsulate_and_buffer( return; }; - buffered_transmits.push_back(enc_packet.to_transmit().into_owned()); + buffered_transmits.push_back(transmit); } fn handle_p2p_control_packet( @@ -1914,7 +1913,7 @@ fn handle_p2p_control_packet( fz_p2p_control: ip_packet::FzP2pControlSlice, dns_resource_nat_by_gateway: &mut BTreeMap<(GatewayId, DomainName), DnsResourceNatState>, node: &mut ClientNode, - buffered_transmits: &mut VecDeque>, + buffered_transmits: &mut VecDeque, now: Instant, ) { use p2p_control::dns_resource_nat; diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index e1fccf151..0e242d8a6 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -38,7 +38,7 @@ pub struct GatewayState { tun_ip_config: Option, buffered_events: VecDeque, - buffered_transmits: VecDeque>, + buffered_transmits: VecDeque, } #[derive(Debug)] @@ -84,7 +84,7 @@ impl GatewayState { &mut self, packet: IpPacket, now: Instant, - ) -> Result> { + ) -> Result> { let dst = packet.destination(); if !crate::is_peer(dst) { @@ -411,7 +411,7 @@ impl GatewayState { } } - pub(crate) fn poll_transmit(&mut self) -> Option> { + pub(crate) fn poll_transmit(&mut self) -> Option { self.buffered_transmits .pop_front() .or_else(|| self.node.poll_transmit()) @@ -497,15 +497,12 @@ fn encrypt_packet( cid: ClientId, node: &mut ServerNode, now: Instant, -) -> Result>> { - let Some(encrypted_packet) = node +) -> Result> { + let transmit = node .encapsulate(cid, packet, now) - .context("Failed to encapsulate packet")? - else { - return Ok(None); - }; + .context("Failed to encapsulate packet")?; - Ok(Some(encrypted_packet.to_transmit().into_owned())) + Ok(transmit) } /// Opaque request struct for when a domain name needs to be resolved. diff --git a/rust/connlib/tunnel/src/io/gso_queue.rs b/rust/connlib/tunnel/src/io/gso_queue.rs index a9ea44810..5ed02f852 100644 --- a/rust/connlib/tunnel/src/io/gso_queue.rs +++ b/rust/connlib/tunnel/src/io/gso_queue.rs @@ -1,9 +1,9 @@ use std::{ collections::{BTreeMap, VecDeque}, net::SocketAddr, - sync::Arc, }; +use bufferpool::{Buffer, BufferPool}; use bytes::BytesMut; use ip_packet::Ecn; use socket_factory::DatagramOut; @@ -13,29 +13,20 @@ use super::MAX_INBOUND_PACKET_BATCH; const MAX_SEGMENT_SIZE: usize = ip_packet::MAX_IP_SIZE + ip_packet::WG_OVERHEAD + ip_packet::DATA_CHANNEL_OVERHEAD; -type Buffer = lockfree_object_pool::SpinLockOwnedReusable; - /// Holds UDP datagrams that we need to send, indexed by src, dst and segment size. /// /// Calling [`Io::send_network`](super::Io::send_network) will copy the provided payload into this buffer. /// The buffer is then flushed using GSO in a single syscall. pub struct GsoQueue { - inner: BTreeMap>, - buffer_pool: Arc>, + inner: BTreeMap)>>, + buffer_pool: BufferPool, } impl GsoQueue { pub fn new() -> Self { Self { inner: Default::default(), - buffer_pool: Arc::new(lockfree_object_pool::SpinLockObjectPool::new( - || { - tracing::debug!("Initialising new buffer for GSO queue"); - - BytesMut::with_capacity(MAX_SEGMENT_SIZE * MAX_INBOUND_PACKET_BATCH) - }, - |b| b.clear(), - )), + buffer_pool: BufferPool::new(MAX_SEGMENT_SIZE * MAX_INBOUND_PACKET_BATCH, "gso-queue"), } } @@ -50,10 +41,7 @@ impl GsoQueue { let batches = self.inner.entry(Connection { src, dst, ecn }).or_default(); let Some((batch_size, buffer)) = batches.back_mut() else { - let mut buffer = self.buffer_pool.pull_owned(); - buffer.extend_from_slice(payload); - - batches.push_back((payload_len, buffer)); + batches.push_back((payload_len, self.buffer_pool.pull_initialised(payload))); return; }; @@ -67,16 +55,10 @@ impl GsoQueue { return; } - let mut buffer = self.buffer_pool.pull_owned(); - buffer.extend_from_slice(payload); - - batches.push_back((payload_len, buffer)); + batches.push_back((payload_len, self.buffer_pool.pull_initialised(payload))); } - pub fn datagrams( - &mut self, - ) -> impl Iterator>> + '_ - { + pub fn datagrams(&mut self) -> impl Iterator + '_ { DrainDatagramsIter { queue: self } } @@ -98,7 +80,7 @@ struct DrainDatagramsIter<'a> { } impl Iterator for DrainDatagramsIter<'_> { - type Item = DatagramOut>; + type Item = DatagramOut; fn next(&mut self) -> Option { loop { diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index e98be8e75..aff71f203 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -177,13 +177,13 @@ impl ClientTunnel { for packet in packets { let ecn = packet.ecn(); - let Some(packet) = self.role_state.handle_tun_input(packet, now) else { + let Some(transmit) = self.role_state.handle_tun_input(packet, now) else { self.role_state.handle_timeout(now); continue; }; self.io - .send_network(packet.src(), packet.dst(), packet.payload(), ecn); + .send_network(transmit.src, transmit.dst, &transmit.payload, ecn); } continue; @@ -308,13 +308,13 @@ impl GatewayTunnel { for packet in packets { let ecn = packet.ecn(); - let Some(packet) = self.role_state.handle_tun_input(packet, now)? else { + let Some(transmit) = self.role_state.handle_tun_input(packet, now)? else { self.role_state.handle_timeout(now, Utc::now()); continue; }; self.io - .send_network(packet.src(), packet.dst(), packet.payload(), ecn); + .send_network(transmit.src, transmit.dst, &transmit.payload, ecn); } continue; diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 891afd90c..ec9ca6633 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,7 +1,7 @@ use anyhow::Result; -use bytes::BytesMut; use futures::{SinkExt, StreamExt, ready}; use gat_lending_iterator::LendingIterator; +use socket_factory::DatagramOut; use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket}; use std::{ io, @@ -11,9 +11,6 @@ use std::{ task::{Context, Poll, Waker}, }; -type DatagramOut = - socket_factory::DatagramOut>; - const UNSPECIFIED_V4_SOCKET: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0); const UNSPECIFIED_V6_SOCKET: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0); diff --git a/rust/connlib/tunnel/src/tests/buffered_transmits.rs b/rust/connlib/tunnel/src/tests/buffered_transmits.rs index 4b71d0aa3..1965f0dd2 100644 --- a/rust/connlib/tunnel/src/tests/buffered_transmits.rs +++ b/rust/connlib/tunnel/src/tests/buffered_transmits.rs @@ -10,7 +10,7 @@ use std::{ #[derive(Debug, Clone, Default)] pub(crate) struct BufferedTransmits { // Transmits are stored in reverse ordering to emit the earliest first. - inner: BinaryHeap>>>, + inner: BinaryHeap>>, } #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] @@ -23,7 +23,7 @@ impl BufferedTransmits { /// Pushes a new [`Transmit`] from a given [`Host`]. pub(crate) fn push_from( &mut self, - transmit: impl Into>>, + transmit: impl Into>, sending_host: &Host, now: Instant, ) { @@ -58,7 +58,7 @@ impl BufferedTransmits { pub(crate) fn push( &mut self, - transmit: impl Into>>, + transmit: impl Into>, latency: Duration, now: Instant, ) { @@ -74,7 +74,7 @@ impl BufferedTransmits { })); } - pub(crate) fn pop(&mut self, now: Instant) -> Option> { + pub(crate) fn pop(&mut self, now: Instant) -> Option { let next = self.inner.peek()?.0.at; if next > now { @@ -86,7 +86,7 @@ impl BufferedTransmits { Some(next.value) } - pub(crate) fn drain(&mut self) -> impl Iterator, Instant)> + '_ { + pub(crate) fn drain(&mut self) -> impl Iterator + '_ { self.inner .drain() .map(|Reverse(ByTime { at, value })| (value, at)) diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index c29fe5a4a..3d6f8a271 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -122,7 +122,7 @@ impl SimClient { upstream: SocketAddr, dns_transport: DnsTransport, now: Instant, - ) -> Option> { + ) -> Option { let Some(sentinel) = self.dns_by_sentinel.get_by_right(&upstream).copied() else { tracing::error!(%upstream, "Unknown DNS server"); return None; @@ -167,15 +167,15 @@ impl SimClient { &mut self, packet: IpPacket, now: Instant, - ) -> Option> { + ) -> Option { self.update_sent_requests(&packet); - let Some(enc_packet) = self.sut.handle_tun_input(packet, now) else { + let Some(transmit) = self.sut.handle_tun_input(packet, now) else { self.sut.handle_timeout(now); // If we handled the packet internally, make sure to advance state. return None; }; - Some(enc_packet.to_transmit().into_owned()) + Some(transmit) } fn update_sent_requests(&mut self, packet: &IpPacket) { diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index b45ba2f9a..933b7607c 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -64,7 +64,7 @@ impl SimGateway { unreachable_hosts: &UnreachableHosts, now: Instant, utc_now: DateTime, - ) -> Option> { + ) -> Option { let Some(packet) = self .sut .handle_network_input(transmit.dst, transmit.src.unwrap(), &transmit.payload, now) @@ -83,7 +83,7 @@ impl SimGateway { &mut self, global_dns_records: &DnsRecords, now: Instant, - ) -> Vec> { + ) -> Vec { let Some(ip_config) = self.sut.tunnel_ip_config() else { tracing::error!("Tunnel IP configuration not set"); return Vec::new(); @@ -116,15 +116,7 @@ impl SimGateway { udp_server_packets .chain(tcp_server_packets) - .filter_map(|packet| { - Some( - self.sut - .handle_tun_input(packet, now) - .unwrap()? - .to_transmit() - .into_owned(), - ) - }) + .filter_map(|packet| self.sut.handle_tun_input(packet, now).unwrap()) .collect() } @@ -165,7 +157,7 @@ impl SimGateway { packet: IpPacket, unreachable_hosts: &UnreachableHosts, now: Instant, - ) -> Option> { + ) -> Option { // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? let dst_ip = packet.destination(); @@ -220,12 +212,7 @@ impl SimGateway { if let Some(reply) = icmp_error.or_else(|| echo_reply(packet.clone())) { self.request_received(&packet); - let transmit = self - .sut - .handle_tun_input(reply, now) - .unwrap()? - .to_transmit() - .into_owned(); + let transmit = self.sut.handle_tun_input(reply, now).unwrap()?; return Some(transmit); } @@ -268,7 +255,7 @@ impl SimGateway { payload: &[u8], icmp_error: Option, now: Instant, - ) -> Option> { + ) -> Option { let reply = icmp_error.unwrap_or_else(|| { ip_packet::make::icmp_reply_packet( packet.destination(), @@ -280,12 +267,7 @@ impl SimGateway { .expect("src and dst are taken from incoming packet") }); - let transmit = self - .sut - .handle_tun_input(reply, now) - .unwrap()? - .to_transmit() - .into_owned(); + let transmit = self.sut.handle_tun_input(reply, now).unwrap()?; Some(transmit) } diff --git a/rust/connlib/tunnel/src/tests/sim_net.rs b/rust/connlib/tunnel/src/tests/sim_net.rs index 716c9cc37..0d1fdf4b5 100644 --- a/rust/connlib/tunnel/src/tests/sim_net.rs +++ b/rust/connlib/tunnel/src/tests/sim_net.rs @@ -123,11 +123,11 @@ impl Host { self.latency } - pub(crate) fn receive(&mut self, transmit: Transmit<'static>, now: Instant) { + pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { self.inbox.push(transmit, self.latency, now); } - pub(crate) fn poll_transmit(&mut self, now: Instant) -> Option> { + pub(crate) fn poll_transmit(&mut self, now: Instant) -> Option { self.inbox.pop(now) } } diff --git a/rust/connlib/tunnel/src/tests/sim_relay.rs b/rust/connlib/tunnel/src/tests/sim_relay.rs index 1a2d180d0..1802daac6 100644 --- a/rust/connlib/tunnel/src/tests/sim_relay.rs +++ b/rust/connlib/tunnel/src/tests/sim_relay.rs @@ -2,6 +2,7 @@ use super::{ sim_net::{Host, dual_ip_stack, host}, strategies::latency, }; +use bufferpool::Buffer; use connlib_model::RelayId; use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket}; use proptest::prelude::*; @@ -9,7 +10,6 @@ use rand::{SeedableRng as _, rngs::StdRng}; use secrecy::SecretString; use snownet::{RelaySocket, Transmit}; use std::{ - borrow::Cow, collections::HashSet, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, time::{Duration, Instant, SystemTime}, @@ -18,7 +18,6 @@ use std::{ pub(crate) struct SimRelay { pub(crate) sut: firezone_relay::Server, pub(crate) allocations: HashSet<(AddressFamily, AllocationPort)>, - buffer: Vec, created_at: SystemTime, } @@ -52,7 +51,6 @@ impl SimRelay { Self { sut, allocations: Default::default(), - buffer: vec![0u8; (1 << 16) - 1], created_at: SystemTime::now(), } } @@ -90,13 +88,9 @@ impl SimRelay { } } - pub(crate) fn receive( - &mut self, - transmit: Transmit, - now: Instant, - ) -> Option> { + pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) -> Option { let dst = transmit.dst; - let payload = &transmit.payload; + let payload = transmit.payload; let sender = transmit.src.unwrap(); if self @@ -115,13 +109,13 @@ impl SimRelay { fn handle_client_input( &mut self, - payload: &[u8], + mut payload: Buffer>, client: ClientSocket, now: Instant, - ) -> Option> { - let (port, peer) = self.sut.handle_client_input(payload, client, now)?; + ) -> Option { + let (port, peer) = self.sut.handle_client_input(&payload, client, now)?; - let payload = &payload[4..]; + payload.shift_start_right(4); // The `dst` of the relayed packet is what TURN calls a "peer". let dst = peer.into_socket(); @@ -156,24 +150,22 @@ impl SimRelay { Some(Transmit { src: Some(src), dst, - payload: Cow::Owned(payload.to_vec()), + payload, }) } fn handle_peer_traffic( &mut self, - payload: &[u8], + mut payload: Buffer>, peer: PeerSocket, port: AllocationPort, - ) -> Option> { - let (client, channel) = self.sut.handle_peer_traffic(payload, peer, port)?; + ) -> Option { + let (client, channel) = self.sut.handle_peer_traffic(&payload, peer, port)?; - let full_length = firezone_relay::ChannelData::encode_header_to_slice( - channel, - payload.len() as u16, - &mut self.buffer[..4], - ); - self.buffer[4..full_length].copy_from_slice(payload); + let data_len = payload.len() as u16; + let header = payload.shift_start_left(4); + + firezone_relay::ChannelData::encode_header_to_slice(channel, data_len, header); let receiving_socket = client.into_socket(); let sending_socket = self @@ -183,7 +175,7 @@ impl SimRelay { Some(Transmit { src: Some(sending_socket), dst: receiving_socket, - payload: Cow::Owned(self.buffer[..full_length].to_vec()), + payload, }) } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 06bf4674a..1c92b0ee4 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -16,6 +16,7 @@ use crate::tests::flux_capacitor::FluxCapacitor; use crate::tests::transition::Transition; use crate::utils::earliest; use crate::{ClientEvent, GatewayEvent, dns, messages::Interface}; +use bufferpool::BufferPool; use connlib_model::{ClientId, GatewayId, PublicKey, RelayId}; use dns_types::ResponseCode; use dns_types::prelude::*; @@ -41,6 +42,8 @@ pub(crate) struct TunnelTest { gateways: BTreeMap>, relays: BTreeMap>, + buffer_pool: BufferPool>, + drop_direct_client_traffic: bool, network: RoutingTable, } @@ -91,6 +94,7 @@ impl TunnelTest { client, gateways, relays, + buffer_pool: BufferPool::new(1024, "test"), }; let mut buffered_transmits = BufferedTransmits::default(); @@ -155,9 +159,7 @@ impl TunnelTest { ) .unwrap(); - let transmit = state - .client - .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); buffered_transmits.push_from(transmit, &state.client, now); } @@ -179,9 +181,7 @@ impl TunnelTest { ) .unwrap(); - let transmit = state - .client - .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); buffered_transmits.push_from(transmit, &state.client, now); } @@ -203,9 +203,7 @@ impl TunnelTest { ) .unwrap(); - let transmit = state - .client - .exec_mut(|sim| Some(sim.encapsulate(packet, now)?.into_owned())); + let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); buffered_transmits.push_from(transmit, &state.client, now); } @@ -459,7 +457,7 @@ impl TunnelTest { Transmit { src: Some(src), dst, - payload: payload.into(), + payload: self.buffer_pool.pull_initialised(&payload), }, relay, now, @@ -633,7 +631,7 @@ impl TunnelTest { /// It takes a [`Transmit`] and checks, which host accepts it, i.e. has configured the correct IP address. /// /// Currently, the network topology of our tests are a single subnet without NAT. - fn dispatch_transmit(&mut self, transmit: Transmit<'static>, at: Instant) { + fn dispatch_transmit(&mut self, transmit: Transmit, at: Instant) { let src = transmit .src .expect("`src` should always be set in these tests"); diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 1a9f960d5..d65526db8 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -12,9 +12,9 @@ proptest = ["dep:proptest"] [dependencies] anyhow = { workspace = true } +bufferpool = { workspace = true } etherparse = { workspace = true, features = ["std"] } etherparse-ext = { workspace = true } -lockfree-object-pool = { workspace = true } proptest = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true } diff --git a/rust/ip-packet/src/buffer_pool.rs b/rust/ip-packet/src/buffer_pool.rs deleted file mode 100644 index 39faeffd1..000000000 --- a/rust/ip-packet/src/buffer_pool.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::{ - ops::{Deref, DerefMut}, - sync::{Arc, LazyLock}, -}; - -use crate::MAX_FZ_PAYLOAD; - -type BufferPool = Arc>>; - -static BUFFER_POOL: LazyLock = LazyLock::new(|| { - Arc::new(lockfree_object_pool::MutexObjectPool::new( - || vec![0; MAX_FZ_PAYLOAD], - |v| v.fill(0), - )) -}); - -pub struct Buffer(lockfree_object_pool::MutexOwnedReusable>); - -impl Clone for Buffer { - fn clone(&self) -> Self { - let mut copy = Buffer::default(); - - copy.0.resize(self.len(), 0); - copy.copy_from_slice(self); - - copy - } -} - -impl PartialEq for Buffer { - fn eq(&self, other: &Self) -> bool { - self.as_ref() == other.as_ref() - } -} - -impl std::fmt::Debug for Buffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("Buffer").finish() - } -} - -impl Deref for Buffer { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.0[..] - } -} - -impl DerefMut for Buffer { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0[..] - } -} - -impl Default for Buffer { - fn default() -> Self { - Self(BUFFER_POOL.pull_owned()) - } -} - -impl Drop for Buffer { - fn drop(&mut self) { - debug_assert_eq!( - self.0.capacity(), - MAX_FZ_PAYLOAD, - "Buffer should never re-allocate" - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn buffer_can_be_cloned() { - let mut buffer = Buffer::default(); - buffer[..11].copy_from_slice(b"hello world"); - - let buffer2 = buffer.clone(); - - assert_eq!(&buffer2[..], &buffer[..]); - } - - #[test] - fn cloned_buffer_owns_its_own_memory() { - let mut buffer = Buffer::default(); - buffer[..11].copy_from_slice(b"hello world"); - - let buffer2 = buffer.clone(); - drop(buffer); - - assert_eq!(&buffer2[..11], b"hello world"); - } -} diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index b8c984f43..7b0fb2c7a 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -2,7 +2,6 @@ pub mod make; -mod buffer_pool; mod fz_p2p_control; mod fz_p2p_control_slice; mod icmp_dest_unreachable; @@ -13,7 +12,7 @@ mod nat64; #[allow(clippy::unwrap_used)] pub mod proptest; -use buffer_pool::Buffer; +use bufferpool::{Buffer, BufferPool}; pub use etherparse::*; pub use fz_p2p_control::EventType as FzP2pEventType; pub use fz_p2p_control_slice::FzP2pControlSlice; @@ -24,6 +23,7 @@ mod proptests; use anyhow::{Context as _, Result, bail}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::sync::LazyLock; use etherparse_ext::Icmpv4HeaderSliceMut; use etherparse_ext::Icmpv6EchoHeaderSliceMut; @@ -32,6 +32,9 @@ use etherparse_ext::Ipv6HeaderSliceMut; use etherparse_ext::TcpHeaderSliceMut; use etherparse_ext::UdpHeaderSliceMut; +static BUFFER_POOL: LazyLock>> = + LazyLock::new(|| BufferPool::new(MAX_FZ_PAYLOAD, "ip-packet")); + /// The maximum size of an IP packet we can handle. pub const MAX_IP_SIZE: usize = 1280; /// The maximum payload an IP packet can have. @@ -116,9 +119,16 @@ pub enum Layer4Protocol { } /// A buffer for reading a new [`IpPacket`] from the network. -#[derive(Default)] pub struct IpPacketBuf { - inner: Buffer, + inner: Buffer>, +} + +impl Default for IpPacketBuf { + fn default() -> Self { + Self { + inner: BUFFER_POOL.pull(), + } + } } impl IpPacketBuf { @@ -200,7 +210,7 @@ impl std::fmt::Debug for IpPacket { #[derive(Debug, PartialEq, Clone)] pub struct ConvertibleIpv4Packet { - buf: Buffer, + buf: Buffer>, start: usize, len: usize, } @@ -280,7 +290,7 @@ impl ConvertibleIpv4Packet { #[derive(Debug, PartialEq, Clone)] pub struct ConvertibleIpv6Packet { - buf: Buffer, + buf: Buffer>, start: usize, len: usize, } diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index 03f787a90..bfeba2dd3 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -6,12 +6,12 @@ license = { workspace = true } [dependencies] anyhow = { workspace = true } +bufferpool = { workspace = true } bytes = { workspace = true } derive_more = { workspace = true, features = ["debug"] } firezone-logging = { workspace = true } gat-lending-iterator = { workspace = true } ip-packet = { workspace = true } -lockfree-object-pool = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } parking_lot = { workspace = true } quinn-udp = { workspace = true } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index fa0b4fb3d..29c2e0a4e 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,5 +1,6 @@ use anyhow::{Context as _, Result}; -use bytes::Buf as _; +use bufferpool::{Buffer, BufferPool}; +use bytes::{Buf as _, BytesMut}; use firezone_logging::err_with_src; use gat_lending_iterator::LendingIterator; use ip_packet::Ecn; @@ -10,7 +11,6 @@ use std::collections::HashMap; use std::io; use std::io::IoSliceMut; use std::ops::Deref; -use std::sync::Arc; use std::{ net::{IpAddr, SocketAddr}, task::{Context, Poll, ready}, @@ -156,7 +156,7 @@ pub struct UdpSocket { src_by_dst_cache: Mutex>, /// A buffer pool for batches of incoming UDP packets. - buffer_pool: Arc>>, + buffer_pool: BufferPool>, gro_batch_histogram: opentelemetry::metrics::Histogram, port: u16, @@ -164,7 +164,8 @@ pub struct UdpSocket { impl UdpSocket { fn new(inner: tokio::net::UdpSocket) -> io::Result { - let port = inner.local_addr()?.port(); + let socket_addr = inner.local_addr()?; + let port = socket_addr.port(); Ok(UdpSocket { state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?, @@ -172,10 +173,13 @@ impl UdpSocket { inner, source_ip_resolver: Box::new(|_| Ok(None)), src_by_dst_cache: Default::default(), - buffer_pool: Arc::new(lockfree_object_pool::MutexObjectPool::new( - || vec![0u8; u16::MAX as usize], - |_| {}, - )), + buffer_pool: BufferPool::new( + u16::MAX as usize, + match socket_addr.ip() { + IpAddr::V4(_) => "udp-socket-v4", + IpAddr::V6(_) => "udp-socket-v6", + }, + ), gro_batch_histogram: opentelemetry::global::meter("connlib") .u64_histogram("system.network.packets.batch_count") .with_description( @@ -249,10 +253,10 @@ pub struct DatagramIn<'a> { } /// An outbound UDP datagram. -pub struct DatagramOut { +pub struct DatagramOut { pub src: Option, pub dst: SocketAddr, - pub packet: B, + pub packet: Buffer, pub segment_size: Option, pub ecn: Ecn, } @@ -268,7 +272,7 @@ impl UdpSocket { } = self; // Stack-allocate arrays for buffers and meta. The size is implied from the const-generic default on `DatagramSegmentIter`. - let mut bufs = std::array::from_fn(|_| self.buffer_pool.pull_owned()); + let mut bufs = std::array::from_fn(|_| self.buffer_pool.pull()); let mut meta = std::array::from_fn(|_| quinn_udp::RecvMeta::default()); loop { @@ -302,14 +306,11 @@ impl UdpSocket { self.inner.poll_send_ready(cx) } - pub async fn send(&self, datagram: DatagramOut) -> Result<()> - where - B: Deref, - { + pub async fn send(&self, datagram: DatagramOut) -> Result<()> { let Some(transmit) = self.prepare_transmit( datagram.dst, datagram.src.map(|s| s.ip()), - datagram.packet.deref().chunk(), + datagram.packet.chunk(), datagram.segment_size, datagram.ecn, )? @@ -497,10 +498,7 @@ impl UdpSocket { /// When [`quinn_udp`] returns us the buffers, it will have populated the [`quinn_udp::RecvMeta`]s accordingly. /// Thus, our main job within this iterator is to loop over the `buffers` and `meta` pair-wise, inspect the `meta` and segment the data within the buffer accordingly. #[derive(derive_more::Debug)] -pub struct DatagramSegmentIter< - const N: usize = { quinn_udp::BATCH_SIZE }, - B = lockfree_object_pool::MutexOwnedReusable>, -> { +pub struct DatagramSegmentIter>> { #[debug(skip)] buffers: [B; N], metas: [quinn_udp::RecvMeta; N],