refactor(relay): reduce allocations during relaying (#4453)

Previously, we would allocate each message twice:

1. When receiving the original packet.
2. When forming the resulting channel-data message.

We can optimise this to only one allocation each by:

1. Carrying around the original `ChannelData` message for traffic from
clients to peers.
2. Pre-allocating enough space for the channel-data header for traffic
from peers to clients.

Local flamegraphing still shows most of user-space activity as
allocations. I did occasionally see a throughput of ~10GBps with these
patches. I'd like to still work towards #4095 to ensure we handle
anything time-sensitive better.
This commit is contained in:
Thomas Eizinger
2024-04-03 09:00:36 +11:00
committed by GitHub
parent 178e0e6170
commit 5f718ad982
7 changed files with 130 additions and 76 deletions

View File

@@ -1,6 +1,6 @@
use crate::server::AllocationId;
use crate::server::{AllocationId, ClientToPeer};
use crate::udp_socket::UdpSocket;
use crate::{AddressFamily, PeerSocket};
use crate::{AddressFamily, PeerSocket, PeerToClient};
use anyhow::{bail, Result};
use futures::channel::mpsc;
use futures::{SinkExt, StreamExt};
@@ -17,12 +17,12 @@ pub struct Allocation {
///
/// Stored here to make resource-cleanup easy.
handle: task::JoinHandle<()>,
sender: mpsc::Sender<(Vec<u8>, PeerSocket)>,
sender: mpsc::Sender<(ClientToPeer, PeerSocket)>,
}
impl Allocation {
pub fn new(
relay_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
relay_data_sender: mpsc::Sender<(PeerToClient, PeerSocket, AllocationId)>,
id: AllocationId,
family: AddressFamily,
port: u16,
@@ -61,7 +61,7 @@ impl Allocation {
///
/// All our data is relayed over UDP which by design is an unreliable protocol.
/// Thus, any application running on top of this relay must already account for potential packet loss.
pub fn send(&mut self, data: Vec<u8>, recipient: PeerSocket) -> Result<()> {
pub fn send(&mut self, data: ClientToPeer, recipient: PeerSocket) -> Result<()> {
match self.sender.try_send((data, recipient)) {
Ok(()) => Ok(()),
Err(e) if e.is_disconnected() => {
@@ -88,8 +88,8 @@ impl Drop for Allocation {
}
async fn forward_incoming_relay_data(
mut relayed_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
mut client_to_peer_receiver: mpsc::Receiver<(Vec<u8>, PeerSocket)>,
mut relayed_data_sender: mpsc::Sender<(PeerToClient, PeerSocket, AllocationId)>,
mut client_to_peer_receiver: mpsc::Receiver<(ClientToPeer, PeerSocket)>,
id: AllocationId,
family: AddressFamily,
port: u16,
@@ -100,11 +100,11 @@ async fn forward_incoming_relay_data(
tokio::select! {
result = socket.recv() => {
let (data, sender) = result?;
relayed_data_sender.send((data.to_vec(), PeerSocket::new(sender), id)).await?;
relayed_data_sender.send((PeerToClient::new(data), PeerSocket::new(sender), id)).await?;
}
Some((data, recipient)) = client_to_peer_receiver.next() => {
socket.send_to(&data, recipient.into_socket()).await?;
socket.send_to(data.data(), recipient.into_socket()).await?;
}
}
}

View File

@@ -13,7 +13,7 @@ pub use allocation::Allocation;
pub use net_ext::IpAddrExt;
pub use server::{
Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command,
CreatePermission, Refresh, Server,
CreatePermission, PeerToClient, Refresh, Server,
};
pub use sleep::Sleep;
pub use stun_codec::rfc8656::attributes::AddressFamily;

View File

@@ -2,8 +2,8 @@ use anyhow::{anyhow, bail, Context, Result};
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use firezone_relay::{
AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket, Server,
Sleep, UdpSocket,
AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket,
PeerToClient, Server, Sleep, UdpSocket,
};
use futures::channel::mpsc;
use futures::{future, FutureExt, SinkExt, StreamExt};
@@ -305,8 +305,8 @@ struct Eventloop<R> {
server: Server<R>,
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
relay_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(Vec<u8>, PeerSocket, AllocationId)>,
relay_data_sender: mpsc::Sender<(PeerToClient, PeerSocket, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(PeerToClient, PeerSocket, AllocationId)>,
sleep: Sleep,
stats_log_interval: tokio::time::Interval,
@@ -431,7 +431,7 @@ where
Pin::new(&mut self.sleep).reset(deadline);
}
Command::ForwardData { id, data, receiver } => {
Command::ForwardDataClientToPeer { id, data, receiver } => {
let span = tracing::debug_span!("Command::ForwardData", %id, %receiver);
let _guard = span.enter();
@@ -463,7 +463,7 @@ where
if let Poll::Ready(Some((data, sender, allocation))) =
self.relay_data_receiver.poll_next_unpin(cx)
{
self.server.handle_peer_traffic(&data, sender, allocation);
self.server.handle_peer_traffic(data, sender, allocation);
continue; // Handle potentially new commands.
}
@@ -471,7 +471,7 @@ where
if let Poll::Ready(Some((buffer, sender))) =
self.inbound_data_receiver.poll_next_unpin(cx)
{
self.server.handle_client_input(&buffer, sender, now);
self.server.handle_client_input(buffer, sender, now);
continue; // Handle potentially new commands.
}

View File

@@ -11,6 +11,7 @@ use crate::net_ext::IpAddrExt;
use crate::{ClientSocket, IpStack, PeerSocket, TimeEvents};
use anyhow::Result;
use bytecodec::EncodeExt;
use bytes::BytesMut;
use core::fmt;
use opentelemetry::metrics::{Counter, Unit, UpDownCounter};
use opentelemetry::KeyValue;
@@ -107,15 +108,50 @@ pub enum Command {
family: AddressFamily,
},
ForwardData {
ForwardDataClientToPeer {
id: AllocationId,
data: Vec<u8>,
data: ClientToPeer,
receiver: PeerSocket,
},
/// At the latest, the [`Server`] needs to be woken at the specified deadline to execute time-based actions correctly.
Wake { deadline: SystemTime },
}
#[derive(Debug, PartialEq)]
pub struct ClientToPeer(ChannelData);
#[derive(Debug, PartialEq)]
pub struct PeerToClient {
buf: BytesMut,
}
impl PeerToClient {
pub fn new(msg: &[u8]) -> Self {
let mut buf = BytesMut::zeroed(msg.len() + 4);
buf[4..].copy_from_slice(msg);
Self { buf }
}
fn len(&self) -> usize {
self.buf.len() - 4
}
fn header_mut(&mut self) -> &mut [u8] {
&mut self.buf[..4]
}
}
impl ClientToPeer {
/// Extract the data to forward to the peer.
///
/// Data from clients arrives in [`ChannelData`] messages and we only forward the actual payload.
pub fn data(&self) -> &[u8] {
self.0.data()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct AllocationId(u64);
@@ -230,7 +266,7 @@ where
///
/// After calling this method, you should call [`Server::next_command`] until it returns `None`.
#[tracing::instrument(level = "debug", skip_all, fields(transaction_id, %sender, allocation, channel, recipient, peer))]
pub fn handle_client_input(&mut self, bytes: &[u8], sender: ClientSocket, now: SystemTime) {
pub fn handle_client_input(&mut self, bytes: Vec<u8>, sender: ClientSocket, now: SystemTime) {
tracing::trace!(target: "wire", num_bytes = %bytes.len());
match self.decoder.decode(bytes) {
@@ -324,11 +360,11 @@ where
#[tracing::instrument(level = "debug", skip_all, fields(%sender, %allocation, recipient, channel))]
pub fn handle_peer_traffic(
&mut self,
bytes: &[u8],
mut msg: PeerToClient,
sender: PeerSocket,
allocation: AllocationId,
) {
tracing::trace!(target: "wire", num_bytes = %bytes.len());
tracing::trace!(target: "wire", num_bytes = %msg.len());
let Some(client) = self.clients_by_allocation.get(&allocation).copied() else {
tracing::debug!(target: "relay", "unknown allocation");
@@ -341,7 +377,7 @@ where
.channel_numbers_by_client_and_peer
.get(&(client, sender))
else {
tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", bytes.len());
tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", msg.len());
return;
};
@@ -365,15 +401,15 @@ where
return;
}
tracing::trace!(target: "wire", num_bytes = %bytes.len());
tracing::trace!(target: "wire", num_bytes = %msg.len());
self.data_relayed_counter.add(bytes.len() as u64, &[]);
self.data_relayed += bytes.len() as u64;
self.data_relayed_counter.add(msg.len() as u64, &[]);
self.data_relayed += msg.len() as u64;
let data = ChannelData::new(*channel_number, bytes).to_bytes();
channel_data::encode_to_slice(*channel_number, msg.len() as u16, msg.header_mut());
self.pending_commands.push_back(Command::SendMessage {
payload: data,
payload: msg.buf.freeze().into(),
recipient: client,
})
}
@@ -774,11 +810,12 @@ where
self.data_relayed_counter.add(data.len() as u64, &[]);
self.data_relayed += data.len() as u64;
self.pending_commands.push_back(Command::ForwardData {
id: channel.allocation,
data: data.to_vec(),
receiver: channel.peer_address,
});
self.pending_commands
.push_back(Command::ForwardDataClientToPeer {
id: channel.allocation,
data: ClientToPeer(message),
receiver: channel.peer_address,
});
}
fn verify_auth(

View File

@@ -3,22 +3,23 @@ use std::io;
const HEADER_LEN: usize = 4;
#[derive(Debug, PartialEq)]
pub struct ChannelData<'a> {
#[derive(Debug, PartialEq, Clone)]
pub struct ChannelData {
channel: u16,
data: &'a [u8],
length: usize,
msg: Vec<u8>,
}
impl<'a> ChannelData<'a> {
pub fn parse(data: &'a [u8]) -> Result<Self, io::Error> {
if data.len() < HEADER_LEN {
impl ChannelData {
pub fn parse(msg: Vec<u8>) -> Result<Self, io::Error> {
if msg.len() < HEADER_LEN {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"channel data messages are at least 4 bytes long",
));
}
let (header, payload) = data.split_at(HEADER_LEN);
let (header, payload) = msg.split_at(HEADER_LEN);
let channel_number = u16::from_be_bytes([header[0], header[1]]);
if !(0x4000..=0x7FFF).contains(&channel_number) {
@@ -41,20 +42,30 @@ impl<'a> ChannelData<'a> {
Ok(ChannelData {
channel: channel_number,
data: &payload[..length],
msg,
length,
})
}
pub fn new(channel: u16, data: &'a [u8]) -> Self {
pub fn new(channel: u16, data: &[u8]) -> Self {
debug_assert!(channel > 0x400);
debug_assert!(channel < 0x7FFF);
debug_assert!(data.len() <= u16::MAX as usize);
ChannelData { channel, data }
let length = data.len();
let msg = to_bytes(channel, length as u16, data);
ChannelData {
channel,
msg,
length,
}
}
// Panics if self.data.len() > u16::MAX
pub fn to_bytes(&self) -> Vec<u8> {
to_bytes(self.channel, self.data.len() as u16, self.data)
pub fn into_msg(self) -> Vec<u8> {
self.msg
}
pub fn channel(&self) -> u16 {
@@ -62,15 +73,21 @@ impl<'a> ChannelData<'a> {
}
pub fn data(&self) -> &[u8] {
self.data
let (_, payload) = self.msg.split_at(HEADER_LEN);
&payload[..self.length]
}
}
pub fn encode_to_slice(channel: u16, data_len: u16, mut header: impl BufMut) {
header.put_u16(channel);
header.put_u16(data_len);
}
fn to_bytes(channel: u16, len: u16, payload: &[u8]) -> Vec<u8> {
let mut message = BytesMut::with_capacity(HEADER_LEN + (len as usize));
message.put_u16(channel);
message.put_u16(len);
encode_to_slice(channel, len, &mut message);
message.put_slice(payload);
message.freeze().into()
@@ -87,9 +104,9 @@ mod tests {
payload: Vec<u8>,
) {
let channel_data = ChannelData::new(channel.value(), &payload);
let encoded = channel_data.to_bytes();
let encoded = channel_data.clone().into_msg();
let parsed = ChannelData::parse(&encoded).unwrap();
let parsed = ChannelData::parse(encoded).unwrap();
assert_eq!(channel_data, parsed)
}
@@ -100,9 +117,9 @@ mod tests {
#[strategy(crate::proptest::channel_payload())] payload: (Vec<u8>, u16),
) {
let encoded = to_bytes(channel.value(), payload.1, &payload.0);
let parsed = ChannelData::parse(&encoded).unwrap();
let parsed = ChannelData::parse(encoded).unwrap();
assert_eq!(channel.value(), parsed.channel);
assert_eq!(&payload.0[..(payload.1 as usize)], parsed.data)
assert_eq!(&payload.0[..(payload.1 as usize)], parsed.data())
}
}

View File

@@ -33,14 +33,14 @@ pub struct Decoder {
}
impl Decoder {
pub fn decode<'a>(
pub fn decode(
&mut self,
input: &'a [u8],
) -> Result<Result<ClientMessage<'a>, Message<Attribute>>, Error> {
input: Vec<u8>,
) -> Result<Result<ClientMessage, Message<Attribute>>, Error> {
// De-multiplex as per <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
match input.first() {
Some(0..=3) => {
let message = match self.stun_message_decoder.decode_from_bytes(input)? {
let message = match self.stun_message_decoder.decode_from_bytes(&input)? {
Ok(message) => message,
Err(broken_message) => {
let method = broken_message.method();
@@ -88,8 +88,8 @@ impl Decoder {
}
#[derive(derive_more::From)]
pub enum ClientMessage<'a> {
ChannelData(ChannelData<'a>),
pub enum ClientMessage {
ChannelData(ChannelData),
Binding(Binding),
Allocate(Allocate),
Refresh(Refresh),
@@ -97,7 +97,7 @@ pub enum ClientMessage<'a> {
CreatePermission(CreatePermission),
}
impl<'a> ClientMessage<'a> {
impl ClientMessage {
pub fn transaction_id(&self) -> Option<TransactionId> {
match self {
ClientMessage::Binding(request) => Some(request.transaction_id),

View File

@@ -1,7 +1,7 @@
use bytecodec::{DecodeExt, EncodeExt};
use firezone_relay::{
AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData,
ClientMessage, ClientSocket, Command, IpStack, PeerSocket, Refresh, Server,
ClientMessage, ClientSocket, Command, IpStack, PeerSocket, PeerToClient, Refresh, Server,
};
use rand::rngs::mock::StepRng;
use secrecy::SecretString;
@@ -607,7 +607,7 @@ impl TestServer {
}
Input::Peer(peer, data, port) => {
self.server
.handle_peer_traffic(&data, peer, self.id_to_port[&port]);
.handle_peer_traffic(data, peer, self.id_to_port[&port]);
}
}
@@ -699,7 +699,7 @@ impl TestServer {
Output::SendChannelData((peer, channeldata)),
Command::SendMessage { recipient, payload },
) => {
let expected_channel_data = hex::encode(channeldata.to_bytes());
let expected_channel_data = hex::encode(channeldata.into_msg());
let actual_message = hex::encode(payload);
assert_eq!(expected_channel_data, actual_message);
@@ -707,13 +707,13 @@ impl TestServer {
}
(
Output::Forward((peer, expected_data, port)),
Command::ForwardData {
Command::ForwardDataClientToPeer {
id,
data: actual_data,
receiver,
},
) => {
assert_eq!(hex::encode(expected_data), hex::encode(actual_data));
assert_eq!(hex::encode(expected_data), hex::encode(actual_data.data()));
assert_eq!(receiver, peer);
assert_eq!(self.id_to_port[&port], id);
}
@@ -800,39 +800,39 @@ fn parse_message(message: &[u8]) -> Message<Attribute> {
.unwrap()
}
enum Input<'a> {
Client(ClientSocket, ClientMessage<'a>, SystemTime),
Peer(PeerSocket, Vec<u8>, u16),
enum Input {
Client(ClientSocket, ClientMessage, SystemTime),
Peer(PeerSocket, PeerToClient, u16),
Time(SystemTime),
}
fn from_client<'a>(
fn from_client(
from: impl Into<SocketAddr>,
message: impl Into<ClientMessage<'a>>,
message: impl Into<ClientMessage>,
now: SystemTime,
) -> Input<'a> {
) -> Input {
Input::Client(ClientSocket::new(from.into()), message.into(), now)
}
fn from_peer<'a>(from: impl Into<SocketAddr>, data: &[u8], port: u16) -> Input<'a> {
Input::Peer(PeerSocket::new(from.into()), data.to_vec(), port)
fn from_peer(from: impl Into<SocketAddr>, data: &[u8], port: u16) -> Input {
Input::Peer(PeerSocket::new(from.into()), PeerToClient::new(data), port)
}
fn forward_time_to<'a>(when: SystemTime) -> Input<'a> {
fn forward_time_to(when: SystemTime) -> Input {
Input::Time(when)
}
#[derive(Debug)]
enum Output<'a> {
enum Output {
SendMessage((ClientSocket, Message<Attribute>)),
SendChannelData((ClientSocket, ChannelData<'a>)),
SendChannelData((ClientSocket, ChannelData)),
Forward((PeerSocket, Vec<u8>, u16)),
Wake(SystemTime),
CreateAllocation(u16, AddressFamily),
FreeAllocation(u16, AddressFamily),
}
fn send_message<'a>(source: impl Into<SocketAddr>, message: Message<Attribute>) -> Output<'a> {
fn send_message(source: impl Into<SocketAddr>, message: Message<Attribute>) -> Output {
Output::SendMessage((ClientSocket::new(source.into()), message))
}