From 5f718ad982404d47daecbae12fa8e0f21db6d87b Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 3 Apr 2024 09:00:36 +1100 Subject: [PATCH] refactor(relay): reduce allocations during relaying (#4453) Previously, we would allocate each message twice: 1. When receiving the original packet. 2. When forming the resulting channel-data message. We can optimise this to only one allocation each by: 1. Carrying around the original `ChannelData` message for traffic from clients to peers. 2. Pre-allocating enough space for the channel-data header for traffic from peers to clients. Local flamegraphing still shows most of user-space activity as allocations. I did occasionally see a throughput of ~10GBps with these patches. I'd like to still work towards #4095 to ensure we handle anything time-sensitive better. --- rust/relay/src/allocation.rs | 18 +++---- rust/relay/src/lib.rs | 2 +- rust/relay/src/main.rs | 14 ++--- rust/relay/src/server.rs | 69 +++++++++++++++++++------ rust/relay/src/server/channel_data.rs | 55 +++++++++++++------- rust/relay/src/server/client_message.rs | 14 ++--- rust/relay/tests/regression.rs | 34 ++++++------ 7 files changed, 130 insertions(+), 76 deletions(-) diff --git a/rust/relay/src/allocation.rs b/rust/relay/src/allocation.rs index 9132aef93..b9bd49446 100644 --- a/rust/relay/src/allocation.rs +++ b/rust/relay/src/allocation.rs @@ -1,6 +1,6 @@ -use crate::server::AllocationId; +use crate::server::{AllocationId, ClientToPeer}; use crate::udp_socket::UdpSocket; -use crate::{AddressFamily, PeerSocket}; +use crate::{AddressFamily, PeerSocket, PeerToClient}; use anyhow::{bail, Result}; use futures::channel::mpsc; use futures::{SinkExt, StreamExt}; @@ -17,12 +17,12 @@ pub struct Allocation { /// /// Stored here to make resource-cleanup easy. handle: task::JoinHandle<()>, - sender: mpsc::Sender<(Vec, PeerSocket)>, + sender: mpsc::Sender<(ClientToPeer, PeerSocket)>, } impl Allocation { pub fn new( - relay_data_sender: mpsc::Sender<(Vec, PeerSocket, AllocationId)>, + relay_data_sender: mpsc::Sender<(PeerToClient, PeerSocket, AllocationId)>, id: AllocationId, family: AddressFamily, port: u16, @@ -61,7 +61,7 @@ impl Allocation { /// /// All our data is relayed over UDP which by design is an unreliable protocol. /// Thus, any application running on top of this relay must already account for potential packet loss. - pub fn send(&mut self, data: Vec, recipient: PeerSocket) -> Result<()> { + pub fn send(&mut self, data: ClientToPeer, recipient: PeerSocket) -> Result<()> { match self.sender.try_send((data, recipient)) { Ok(()) => Ok(()), Err(e) if e.is_disconnected() => { @@ -88,8 +88,8 @@ impl Drop for Allocation { } async fn forward_incoming_relay_data( - mut relayed_data_sender: mpsc::Sender<(Vec, PeerSocket, AllocationId)>, - mut client_to_peer_receiver: mpsc::Receiver<(Vec, PeerSocket)>, + mut relayed_data_sender: mpsc::Sender<(PeerToClient, PeerSocket, AllocationId)>, + mut client_to_peer_receiver: mpsc::Receiver<(ClientToPeer, PeerSocket)>, id: AllocationId, family: AddressFamily, port: u16, @@ -100,11 +100,11 @@ async fn forward_incoming_relay_data( tokio::select! { result = socket.recv() => { let (data, sender) = result?; - relayed_data_sender.send((data.to_vec(), PeerSocket::new(sender), id)).await?; + relayed_data_sender.send((PeerToClient::new(data), PeerSocket::new(sender), id)).await?; } Some((data, recipient)) = client_to_peer_receiver.next() => { - socket.send_to(&data, recipient.into_socket()).await?; + socket.send_to(data.data(), recipient.into_socket()).await?; } } } diff --git a/rust/relay/src/lib.rs b/rust/relay/src/lib.rs index 764a5f9b6..dcb5f90d3 100644 --- a/rust/relay/src/lib.rs +++ b/rust/relay/src/lib.rs @@ -13,7 +13,7 @@ pub use allocation::Allocation; pub use net_ext::IpAddrExt; pub use server::{ Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, - CreatePermission, Refresh, Server, + CreatePermission, PeerToClient, Refresh, Server, }; pub use sleep::Sleep; pub use stun_codec::rfc8656::attributes::AddressFamily; diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 602583eda..dec8c8abe 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -2,8 +2,8 @@ use anyhow::{anyhow, bail, Context, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; use firezone_relay::{ - AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket, Server, - Sleep, UdpSocket, + AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket, + PeerToClient, Server, Sleep, UdpSocket, }; use futures::channel::mpsc; use futures::{future, FutureExt, SinkExt, StreamExt}; @@ -305,8 +305,8 @@ struct Eventloop { server: Server, channel: Option>, allocations: HashMap<(AllocationId, AddressFamily), Allocation>, - relay_data_sender: mpsc::Sender<(Vec, PeerSocket, AllocationId)>, - relay_data_receiver: mpsc::Receiver<(Vec, PeerSocket, AllocationId)>, + relay_data_sender: mpsc::Sender<(PeerToClient, PeerSocket, AllocationId)>, + relay_data_receiver: mpsc::Receiver<(PeerToClient, PeerSocket, AllocationId)>, sleep: Sleep, stats_log_interval: tokio::time::Interval, @@ -431,7 +431,7 @@ where Pin::new(&mut self.sleep).reset(deadline); } - Command::ForwardData { id, data, receiver } => { + Command::ForwardDataClientToPeer { id, data, receiver } => { let span = tracing::debug_span!("Command::ForwardData", %id, %receiver); let _guard = span.enter(); @@ -463,7 +463,7 @@ where if let Poll::Ready(Some((data, sender, allocation))) = self.relay_data_receiver.poll_next_unpin(cx) { - self.server.handle_peer_traffic(&data, sender, allocation); + self.server.handle_peer_traffic(data, sender, allocation); continue; // Handle potentially new commands. } @@ -471,7 +471,7 @@ where if let Poll::Ready(Some((buffer, sender))) = self.inbound_data_receiver.poll_next_unpin(cx) { - self.server.handle_client_input(&buffer, sender, now); + self.server.handle_client_input(buffer, sender, now); continue; // Handle potentially new commands. } diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 4d6392f96..a313c5bbd 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -11,6 +11,7 @@ use crate::net_ext::IpAddrExt; use crate::{ClientSocket, IpStack, PeerSocket, TimeEvents}; use anyhow::Result; use bytecodec::EncodeExt; +use bytes::BytesMut; use core::fmt; use opentelemetry::metrics::{Counter, Unit, UpDownCounter}; use opentelemetry::KeyValue; @@ -107,15 +108,50 @@ pub enum Command { family: AddressFamily, }, - ForwardData { + ForwardDataClientToPeer { id: AllocationId, - data: Vec, + data: ClientToPeer, receiver: PeerSocket, }, + /// At the latest, the [`Server`] needs to be woken at the specified deadline to execute time-based actions correctly. Wake { deadline: SystemTime }, } +#[derive(Debug, PartialEq)] +pub struct ClientToPeer(ChannelData); + +#[derive(Debug, PartialEq)] +pub struct PeerToClient { + buf: BytesMut, +} + +impl PeerToClient { + pub fn new(msg: &[u8]) -> Self { + let mut buf = BytesMut::zeroed(msg.len() + 4); + buf[4..].copy_from_slice(msg); + + Self { buf } + } + + fn len(&self) -> usize { + self.buf.len() - 4 + } + + fn header_mut(&mut self) -> &mut [u8] { + &mut self.buf[..4] + } +} + +impl ClientToPeer { + /// Extract the data to forward to the peer. + /// + /// Data from clients arrives in [`ChannelData`] messages and we only forward the actual payload. + pub fn data(&self) -> &[u8] { + self.0.data() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub struct AllocationId(u64); @@ -230,7 +266,7 @@ where /// /// After calling this method, you should call [`Server::next_command`] until it returns `None`. #[tracing::instrument(level = "debug", skip_all, fields(transaction_id, %sender, allocation, channel, recipient, peer))] - pub fn handle_client_input(&mut self, bytes: &[u8], sender: ClientSocket, now: SystemTime) { + pub fn handle_client_input(&mut self, bytes: Vec, sender: ClientSocket, now: SystemTime) { tracing::trace!(target: "wire", num_bytes = %bytes.len()); match self.decoder.decode(bytes) { @@ -324,11 +360,11 @@ where #[tracing::instrument(level = "debug", skip_all, fields(%sender, %allocation, recipient, channel))] pub fn handle_peer_traffic( &mut self, - bytes: &[u8], + mut msg: PeerToClient, sender: PeerSocket, allocation: AllocationId, ) { - tracing::trace!(target: "wire", num_bytes = %bytes.len()); + tracing::trace!(target: "wire", num_bytes = %msg.len()); let Some(client) = self.clients_by_allocation.get(&allocation).copied() else { tracing::debug!(target: "relay", "unknown allocation"); @@ -341,7 +377,7 @@ where .channel_numbers_by_client_and_peer .get(&(client, sender)) else { - tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", bytes.len()); + tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", msg.len()); return; }; @@ -365,15 +401,15 @@ where return; } - tracing::trace!(target: "wire", num_bytes = %bytes.len()); + tracing::trace!(target: "wire", num_bytes = %msg.len()); - self.data_relayed_counter.add(bytes.len() as u64, &[]); - self.data_relayed += bytes.len() as u64; + self.data_relayed_counter.add(msg.len() as u64, &[]); + self.data_relayed += msg.len() as u64; - let data = ChannelData::new(*channel_number, bytes).to_bytes(); + channel_data::encode_to_slice(*channel_number, msg.len() as u16, msg.header_mut()); self.pending_commands.push_back(Command::SendMessage { - payload: data, + payload: msg.buf.freeze().into(), recipient: client, }) } @@ -774,11 +810,12 @@ where self.data_relayed_counter.add(data.len() as u64, &[]); self.data_relayed += data.len() as u64; - self.pending_commands.push_back(Command::ForwardData { - id: channel.allocation, - data: data.to_vec(), - receiver: channel.peer_address, - }); + self.pending_commands + .push_back(Command::ForwardDataClientToPeer { + id: channel.allocation, + data: ClientToPeer(message), + receiver: channel.peer_address, + }); } fn verify_auth( diff --git a/rust/relay/src/server/channel_data.rs b/rust/relay/src/server/channel_data.rs index d1781ab3a..0e1f3732f 100644 --- a/rust/relay/src/server/channel_data.rs +++ b/rust/relay/src/server/channel_data.rs @@ -3,22 +3,23 @@ use std::io; const HEADER_LEN: usize = 4; -#[derive(Debug, PartialEq)] -pub struct ChannelData<'a> { +#[derive(Debug, PartialEq, Clone)] +pub struct ChannelData { channel: u16, - data: &'a [u8], + length: usize, + msg: Vec, } -impl<'a> ChannelData<'a> { - pub fn parse(data: &'a [u8]) -> Result { - if data.len() < HEADER_LEN { +impl ChannelData { + pub fn parse(msg: Vec) -> Result { + if msg.len() < HEADER_LEN { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "channel data messages are at least 4 bytes long", )); } - let (header, payload) = data.split_at(HEADER_LEN); + let (header, payload) = msg.split_at(HEADER_LEN); let channel_number = u16::from_be_bytes([header[0], header[1]]); if !(0x4000..=0x7FFF).contains(&channel_number) { @@ -41,20 +42,30 @@ impl<'a> ChannelData<'a> { Ok(ChannelData { channel: channel_number, - data: &payload[..length], + msg, + length, }) } - pub fn new(channel: u16, data: &'a [u8]) -> Self { + pub fn new(channel: u16, data: &[u8]) -> Self { debug_assert!(channel > 0x400); debug_assert!(channel < 0x7FFF); debug_assert!(data.len() <= u16::MAX as usize); - ChannelData { channel, data } + + let length = data.len(); + + let msg = to_bytes(channel, length as u16, data); + + ChannelData { + channel, + msg, + length, + } } // Panics if self.data.len() > u16::MAX - pub fn to_bytes(&self) -> Vec { - to_bytes(self.channel, self.data.len() as u16, self.data) + pub fn into_msg(self) -> Vec { + self.msg } pub fn channel(&self) -> u16 { @@ -62,15 +73,21 @@ impl<'a> ChannelData<'a> { } pub fn data(&self) -> &[u8] { - self.data + let (_, payload) = self.msg.split_at(HEADER_LEN); + + &payload[..self.length] } } +pub fn encode_to_slice(channel: u16, data_len: u16, mut header: impl BufMut) { + header.put_u16(channel); + header.put_u16(data_len); +} + fn to_bytes(channel: u16, len: u16, payload: &[u8]) -> Vec { let mut message = BytesMut::with_capacity(HEADER_LEN + (len as usize)); - message.put_u16(channel); - message.put_u16(len); + encode_to_slice(channel, len, &mut message); message.put_slice(payload); message.freeze().into() @@ -87,9 +104,9 @@ mod tests { payload: Vec, ) { let channel_data = ChannelData::new(channel.value(), &payload); - let encoded = channel_data.to_bytes(); + let encoded = channel_data.clone().into_msg(); - let parsed = ChannelData::parse(&encoded).unwrap(); + let parsed = ChannelData::parse(encoded).unwrap(); assert_eq!(channel_data, parsed) } @@ -100,9 +117,9 @@ mod tests { #[strategy(crate::proptest::channel_payload())] payload: (Vec, u16), ) { let encoded = to_bytes(channel.value(), payload.1, &payload.0); - let parsed = ChannelData::parse(&encoded).unwrap(); + let parsed = ChannelData::parse(encoded).unwrap(); assert_eq!(channel.value(), parsed.channel); - assert_eq!(&payload.0[..(payload.1 as usize)], parsed.data) + assert_eq!(&payload.0[..(payload.1 as usize)], parsed.data()) } } diff --git a/rust/relay/src/server/client_message.rs b/rust/relay/src/server/client_message.rs index 8c6ef6528..daf519158 100644 --- a/rust/relay/src/server/client_message.rs +++ b/rust/relay/src/server/client_message.rs @@ -33,14 +33,14 @@ pub struct Decoder { } impl Decoder { - pub fn decode<'a>( + pub fn decode( &mut self, - input: &'a [u8], - ) -> Result, Message>, Error> { + input: Vec, + ) -> Result>, Error> { // De-multiplex as per . match input.first() { Some(0..=3) => { - let message = match self.stun_message_decoder.decode_from_bytes(input)? { + let message = match self.stun_message_decoder.decode_from_bytes(&input)? { Ok(message) => message, Err(broken_message) => { let method = broken_message.method(); @@ -88,8 +88,8 @@ impl Decoder { } #[derive(derive_more::From)] -pub enum ClientMessage<'a> { - ChannelData(ChannelData<'a>), +pub enum ClientMessage { + ChannelData(ChannelData), Binding(Binding), Allocate(Allocate), Refresh(Refresh), @@ -97,7 +97,7 @@ pub enum ClientMessage<'a> { CreatePermission(CreatePermission), } -impl<'a> ClientMessage<'a> { +impl ClientMessage { pub fn transaction_id(&self) -> Option { match self { ClientMessage::Binding(request) => Some(request.transaction_id), diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index a1a35a136..ad38fbcb2 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -1,7 +1,7 @@ use bytecodec::{DecodeExt, EncodeExt}; use firezone_relay::{ AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, - ClientMessage, ClientSocket, Command, IpStack, PeerSocket, Refresh, Server, + ClientMessage, ClientSocket, Command, IpStack, PeerSocket, PeerToClient, Refresh, Server, }; use rand::rngs::mock::StepRng; use secrecy::SecretString; @@ -607,7 +607,7 @@ impl TestServer { } Input::Peer(peer, data, port) => { self.server - .handle_peer_traffic(&data, peer, self.id_to_port[&port]); + .handle_peer_traffic(data, peer, self.id_to_port[&port]); } } @@ -699,7 +699,7 @@ impl TestServer { Output::SendChannelData((peer, channeldata)), Command::SendMessage { recipient, payload }, ) => { - let expected_channel_data = hex::encode(channeldata.to_bytes()); + let expected_channel_data = hex::encode(channeldata.into_msg()); let actual_message = hex::encode(payload); assert_eq!(expected_channel_data, actual_message); @@ -707,13 +707,13 @@ impl TestServer { } ( Output::Forward((peer, expected_data, port)), - Command::ForwardData { + Command::ForwardDataClientToPeer { id, data: actual_data, receiver, }, ) => { - assert_eq!(hex::encode(expected_data), hex::encode(actual_data)); + assert_eq!(hex::encode(expected_data), hex::encode(actual_data.data())); assert_eq!(receiver, peer); assert_eq!(self.id_to_port[&port], id); } @@ -800,39 +800,39 @@ fn parse_message(message: &[u8]) -> Message { .unwrap() } -enum Input<'a> { - Client(ClientSocket, ClientMessage<'a>, SystemTime), - Peer(PeerSocket, Vec, u16), +enum Input { + Client(ClientSocket, ClientMessage, SystemTime), + Peer(PeerSocket, PeerToClient, u16), Time(SystemTime), } -fn from_client<'a>( +fn from_client( from: impl Into, - message: impl Into>, + message: impl Into, now: SystemTime, -) -> Input<'a> { +) -> Input { Input::Client(ClientSocket::new(from.into()), message.into(), now) } -fn from_peer<'a>(from: impl Into, data: &[u8], port: u16) -> Input<'a> { - Input::Peer(PeerSocket::new(from.into()), data.to_vec(), port) +fn from_peer(from: impl Into, data: &[u8], port: u16) -> Input { + Input::Peer(PeerSocket::new(from.into()), PeerToClient::new(data), port) } -fn forward_time_to<'a>(when: SystemTime) -> Input<'a> { +fn forward_time_to(when: SystemTime) -> Input { Input::Time(when) } #[derive(Debug)] -enum Output<'a> { +enum Output { SendMessage((ClientSocket, Message)), - SendChannelData((ClientSocket, ChannelData<'a>)), + SendChannelData((ClientSocket, ChannelData)), Forward((PeerSocket, Vec, u16)), Wake(SystemTime), CreateAllocation(u16, AddressFamily), FreeAllocation(u16, AddressFamily), } -fn send_message<'a>(source: impl Into, message: Message) -> Output<'a> { +fn send_message(source: impl Into, message: Message) -> Output { Output::SendMessage((ClientSocket::new(source.into()), message)) }