From 17dfdb63d437d30e5a6e39bc5e675362ce49ad24 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 31 Jul 2023 23:39:31 +0200 Subject: [PATCH] feat(relay): handle failed allocations (#1831) This patch series refactors how we handle allocations in the relay to make it easier to forward a failure to the `Server`. Each allocation runs in a separate task (to allow for parallelization). If the allocation fails, this channel is automatically closed. Previously, this would erroneously trigger a `debug_assert!`. Now, we invoke a callback on `Server` to allow it to clean up its internal resources for the allocation. At the same time, we simplify the buffering around data that is destined for a certain allocation. Instead of having an additional buffer in the event-loop, we increase the channel size to 10. Any exceeding items will be dropped to avoid memory growth. This means that the `Server` is never blocked on a slow allocation. Given that we are running on top of an unreliable protocol anyway, I'd say this is fine. --- rust/relay/src/allocation.rs | 103 ++++++++++++++++++++++++++ rust/relay/src/lib.rs | 2 + rust/relay/src/main.rs | 135 +++++------------------------------ rust/relay/src/server.rs | 5 ++ 4 files changed, 129 insertions(+), 116 deletions(-) create mode 100644 rust/relay/src/allocation.rs diff --git a/rust/relay/src/allocation.rs b/rust/relay/src/allocation.rs new file mode 100644 index 000000000..fadaf8094 --- /dev/null +++ b/rust/relay/src/allocation.rs @@ -0,0 +1,103 @@ +use crate::server::AllocationId; +use crate::udp_socket::UdpSocket; +use anyhow::{bail, Result}; +use futures::channel::mpsc; +use futures::{SinkExt, StreamExt}; +use std::convert::Infallible; +use std::net::{Ipv4Addr, SocketAddr}; +use tokio::task; + +/// The maximum amount of items that can be buffered in the channel to the allocation task. +const MAX_BUFFERED_ITEMS: usize = 10; + +pub struct Allocation { + id: AllocationId, + + /// The handle to the task that is running the allocation. + /// + /// Stored here to make resource-cleanup easy. + handle: task::JoinHandle<()>, + sender: mpsc::Sender<(Vec, SocketAddr)>, +} + +impl Allocation { + pub fn new( + relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, + id: AllocationId, + listen_ip4_addr: Ipv4Addr, + port: u16, + ) -> Self { + let (client_to_peer_sender, client_to_peer_receiver) = mpsc::channel(MAX_BUFFERED_ITEMS); + + let task = tokio::spawn(async move { + let Err(e) = forward_incoming_relay_data(relay_data_sender, client_to_peer_receiver, id, listen_ip4_addr, port).await else { + unreachable!() + }; + + tracing::warn!("Allocation task for {id} failed: {e}"); + + // With the task stopping, the channel will be closed and any attempt to send data to it will fail. + }); + + Self { + id, + handle: task, + sender: client_to_peer_sender, + } + } + + /// Send data to a peer on this allocation. + /// + /// In case the channel is full, we will simply drop the packet and log a warning. + /// In normal operation, this should not happen but if for some reason, the allocation task cannot keep up with the incoming data, we need to drop packets somewhere to avoid unbounded memory growth. + /// + /// All our data is relayed over UDP which by design is an unreliable protocol. + /// Thus, any application running on top of this relay must already account for potential packet loss. + pub fn send(&mut self, data: Vec, recipient: SocketAddr) -> Result<()> { + match self.sender.try_send((data, recipient)) { + Ok(()) => Ok(()), + Err(e) if e.is_disconnected() => { + tracing::warn!(allocation = %self.id, %recipient, "Channel to allocation is disconnected"); + bail!("Channel to allocation {} is disconnected", self.id) + } + Err(e) if e.is_full() => { + tracing::warn!(allocation = %self.id, "Send buffer for allocation is full, dropping packet"); + Ok(()) + } + Err(_) => { + // Fail in debug, but not in release mode. + debug_assert!(false, "TrySendError only has two variants"); + Ok(()) + } + } + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + self.handle.abort(); + } +} + +async fn forward_incoming_relay_data( + mut relayed_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, + mut client_to_peer_receiver: mpsc::Receiver<(Vec, SocketAddr)>, + id: AllocationId, + listen_ip4_addr: Ipv4Addr, + port: u16, +) -> Result { + let mut socket = UdpSocket::bind((listen_ip4_addr, port)).await?; + + loop { + tokio::select! { + result = socket.recv() => { + let (data, sender) = result?; + relayed_data_sender.send((data.to_vec(), sender, id)).await?; + } + + Some((data, recipient)) = client_to_peer_receiver.next() => { + socket.send_to(&data, recipient).await?; + } + } + } +} diff --git a/rust/relay/src/lib.rs b/rust/relay/src/lib.rs index 8202a4574..b6dfeab43 100644 --- a/rust/relay/src/lib.rs +++ b/rust/relay/src/lib.rs @@ -1,3 +1,4 @@ +mod allocation; mod auth; mod rfc8656; mod server; @@ -10,6 +11,7 @@ pub mod metrics; #[cfg(feature = "proptest")] pub mod proptest; +pub use allocation::Allocation; pub use server::{ Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, CreatePermission, Refresh, Server, diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 2b28e9f75..58d8cb15e 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -6,14 +6,14 @@ use phoenix_channel::{Error, Event, PhoenixChannel}; use prometheus_client::registry::Registry; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use relay::{AllocationId, Command, Server, Sleep, UdpSocket}; -use std::collections::{HashMap, VecDeque}; +use relay::{Allocation, AllocationId, Command, Server, Sleep, UdpSocket}; +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::convert::Infallible; use std::net::{Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; use std::time::SystemTime; -use tokio::task; use tracing::level_filters::LevelFilter; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -194,16 +194,6 @@ struct Eventloop { relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, sleep: Sleep, - - allocation_send_buffer: VecDeque<(Vec, SocketAddr, AllocationId)>, -} - -struct Allocation { - /// The handle to the task that is running the allocation. - /// - /// Stored here to make resource-cleanup easy. - handle: task::JoinHandle<()>, - sender: mpsc::Sender<(Vec, SocketAddr)>, } impl Eventloop @@ -237,7 +227,6 @@ where relay_data_sender, relay_data_receiver, sleep: Sleep::default(), - allocation_send_buffer: Default::default(), }) } @@ -282,63 +271,31 @@ where Pin::new(&mut self.sleep).reset(deadline); } Command::ForwardData { id, data, receiver } => { - self.allocation_send_buffer.push_back((data, receiver, id)); + let mut allocation = match self.allocations.entry(id) { + Entry::Occupied(entry) => entry, + Entry::Vacant(_) => { + tracing::debug!(allocation = %id, "Unknown allocation"); + continue; + } + }; + + if allocation.get_mut().send(data, receiver).is_err() { + self.server.handle_allocation_failed(id); + allocation.remove(); + } } } continue; // Attempt to process more commands. } - // Priority 2: Forward data to allocations. - if let Some((data, receiver, id)) = self.allocation_send_buffer.pop_front() { - let Some(allocation) = self.allocations.get_mut(&id) else { - tracing::debug!("Unknown allocation {id}"); - continue; - }; - - match allocation.sender.poll_ready(cx) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(_)) => { - debug_assert!( - false, - "poll_ready to never fail because we own the other end of the channel" - ); - } - Poll::Pending => { - // Same as above, we need to yield early if we cannot send data. - // The task will be woken up once there is space in the channel. - - self.allocation_send_buffer.push_front((data, receiver, id)); - return Poll::Pending; - } - } - - match allocation.sender.try_send((data, receiver)) { - Ok(()) => {} - Err(e) if e.is_full() => { - let (data, receiver) = e.into_inner(); - - self.allocation_send_buffer.push_front((data, receiver, id)); - return Poll::Pending; - } - Err(_) => { - debug_assert!( - false, - "try_send to never fail because we own the other end of the channel" - ); - } - }; - - continue; - } - - // Priority 3: Handle time-sensitive tasks: + // Priority 2: Handle time-sensitive tasks: if self.sleep.poll_unpin(cx).is_ready() { self.server.handle_deadline_reached(SystemTime::now()); continue; // Handle potentially new commands. } - // Priority 4: Handle relayed data (we prioritize latency for existing allocations over making new ones) + // Priority 3: Handle relayed data (we prioritize latency for existing allocations over making new ones) if let Poll::Ready(Some((data, sender, allocation))) = self.relay_data_receiver.poll_next_unpin(cx) { @@ -346,7 +303,7 @@ where continue; // Handle potentially new commands. } - // Priority 5: Accept new allocations / answer STUN requests etc + // Priority 4: Accept new allocations / answer STUN requests etc if let Poll::Ready(Some((buffer, sender))) = self.inbound_data_receiver.poll_next_unpin(cx) { @@ -355,7 +312,7 @@ where continue; // Handle potentially new commands. } - // Priority 6: Handle portal messages + // Priority 5: Handle portal messages match self.channel.as_mut().map(|c| c.poll(cx)) { Some(Poll::Ready(Ok(Event::InboundMessage { msg: InboundPortalMessage::Init {}, @@ -404,60 +361,6 @@ where } } -impl Allocation { - fn new( - relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, - id: AllocationId, - listen_ip4_addr: Ipv4Addr, - port: u16, - ) -> Self { - let (client_to_peer_sender, client_to_peer_receiver) = mpsc::channel(1); - - let task = tokio::spawn(async move { - let Err(e) = forward_incoming_relay_data(relay_data_sender, client_to_peer_receiver, id, listen_ip4_addr, port).await else { - unreachable!() - }; - - // TODO: Do we need to clean this up in the server? It will eventually timeout if not refreshed. - tracing::warn!("Allocation task for {id} failed: {e}"); - }); - - Self { - handle: task, - sender: client_to_peer_sender, - } - } -} - -impl Drop for Allocation { - fn drop(&mut self) { - self.handle.abort(); - } -} - -async fn forward_incoming_relay_data( - mut relayed_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, - mut client_to_peer_receiver: mpsc::Receiver<(Vec, SocketAddr)>, - id: AllocationId, - listen_ip4_addr: Ipv4Addr, - port: u16, -) -> Result { - let mut socket = UdpSocket::bind((listen_ip4_addr, port)).await?; - - loop { - tokio::select! { - result = socket.recv() => { - let (data, sender) = result?; - relayed_data_sender.send((data.to_vec(), sender, id)).await?; - } - - Some((data, recipient)) = client_to_peer_receiver.next() => { - socket.send_to(&data, recipient).await?; - } - } - } -} - async fn main_udp_socket_task( listen_ip4_address: Ipv4Addr, mut inbound_data_sender: mpsc::Sender<(Vec, SocketAddr)>, diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 0f3686785..05c3dd4e8 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -372,6 +372,11 @@ where } } + /// An allocation failed. + pub fn handle_allocation_failed(&mut self, id: AllocationId) { + self.delete_allocation(id) + } + /// Return the next command to be executed. pub fn next_command(&mut self) -> Option { self.pending_commands.pop_front()