diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index a0bc98b73..7db3ccd91 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -16,10 +16,11 @@ use phoenix_channel::{Event, LoginUrl, NoParams, PhoenixChannel}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use secrecy::{Secret, SecretString}; +use std::borrow::Cow; use std::net::{Ipv4Addr, Ipv6Addr}; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::task::Poll; +use std::task::{ready, Poll}; use std::time::{Duration, Instant}; use tokio::signal::unix; use tracing::{level_filters::LevelFilter, Subscriber}; @@ -418,6 +419,8 @@ where return Poll::Ready(Ok(())); } + ready!(self.sockets.flush(cx))?; + // Priority 1: Execute the pending commands of the server. if let Some(next_command) = self.server.next_command() { match next_command { @@ -425,7 +428,7 @@ where if let Err(e) = self.sockets.try_send( self.server.listen_port(), recipient.into_socket(), - &payload, + Cow::Owned(payload), ) { tracing::warn!(target: "relay", error = std_dyn_err(&e), %recipient, "Failed to send message"); } @@ -493,10 +496,11 @@ where .expect("valid ChannelData if we should relay it") .data(); // When relaying data from a client to peer, we need to forward only the channel-data's payload. - if let Err(e) = - self.sockets - .try_send(port.value(), peer.into_socket(), payload) - { + if let Err(e) = self.sockets.try_send( + port.value(), + peer.into_socket(), + Cow::Borrowed(payload), + ) { tracing::warn!(target: "relay", error = std_dyn_err(&e), %peer, "Failed to relay data to peer"); } }; @@ -521,7 +525,7 @@ where if let Err(e) = self.sockets.try_send( self.server.listen_port(), // Packets coming in from peers always go out on the TURN port client.into_socket(), - &self.buffer[..total_length], + Cow::Borrowed(&self.buffer[..total_length]), ) { tracing::warn!(target: "relay", error = std_dyn_err(&e), %client, "Failed to relay data to client"); }; diff --git a/rust/relay/src/sockets.rs b/rust/relay/src/sockets.rs index 12a9d58c0..7b8b2f606 100644 --- a/rust/relay/src/sockets.rs +++ b/rust/relay/src/sockets.rs @@ -1,9 +1,10 @@ use anyhow::{bail, Result}; use std::{ - collections::HashMap, + borrow::Cow, + collections::{HashMap, VecDeque}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - task::{ready, Context, Poll}, + task::{ready, Context, Poll, Waker}, time::Duration, }; use stun_codec::rfc8656::attributes::AddressFamily; @@ -25,8 +26,20 @@ pub struct Sockets { /// We must read from it until it returns [`io::ErrorKind::WouldBlock`]. current_ready_socket: Option, + /// If we are waiting to flush packets, this waker tracks the suspended task. + flush_waker: Option, + cmd_tx: mpsc::Sender, event_rx: mpsc::Receiver, + + pending_packets: VecDeque, +} + +/// A packet that could not be sent and is buffered until the socket is ready again. +struct PendingPacket { + src: u16, + dst: SocketAddr, + payload: Vec, } impl Default for Sockets { @@ -51,6 +64,8 @@ impl Sockets { cmd_tx, event_rx, current_ready_socket: None, + pending_packets: Default::default(), + flush_waker: None, } } @@ -83,7 +98,41 @@ impl Sockets { Ok(()) } - pub fn try_send(&self, port: u16, dest: SocketAddr, msg: &[u8]) -> io::Result<()> { + /// Flush all buffered packets. + pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll> { + while let Some(packet) = self.pending_packets.pop_front() { + match self.try_send_internal(packet.src, packet.dst, &packet.payload) { + Ok(()) => continue, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.flush_waker = Some(cx.waker().clone()); + self.pending_packets.push_front(packet); + + return Poll::Pending; + } + Err(e) => return Poll::Ready(Err(e)), + }; + } + + Poll::Ready(Ok(())) + } + + pub fn try_send(&mut self, port: u16, dest: SocketAddr, msg: Cow<'_, [u8]>) -> io::Result<()> { + match self.try_send_internal(port, dest, msg.as_ref()) { + Ok(()) => Ok(()), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.pending_packets.push_back(PendingPacket { + src: port, + dst: dest, + payload: msg.into_owned(), + }); + + Ok(()) + } + Err(e) => Err(e), + } + } + + fn try_send_internal(&mut self, port: u16, dest: SocketAddr, msg: &[u8]) -> io::Result<()> { let address_family = match dest { SocketAddr::V4(_) => AddressFamily::V4, SocketAddr::V6(_) => AddressFamily::V6, @@ -137,8 +186,21 @@ impl Sockets { self.inner.insert(token, socket); continue; } - Some(Event::SocketReady(ready)) => { - self.current_ready_socket = Some(ready); + Some(Event::SocketReady { + token, + readable, + writeable, + }) => { + if readable { + self.current_ready_socket = Some(token); + } + + if writeable { + if let Some(waker) = self.flush_waker.take() { + waker.wake(); + } + } + continue; } Some(Event::Crashed(error)) => { @@ -173,7 +235,11 @@ enum Command { enum Event { NewSocket(mio::Token, mio::net::UdpSocket), - SocketReady(mio::Token), + SocketReady { + token: mio::Token, + readable: bool, + writeable: bool, + }, Crashed(anyhow::Error), } @@ -199,7 +265,11 @@ fn mio_worker_task( // Send all events into the channel, block as necessary. for event in events.iter() { - event_tx.blocking_send(Event::SocketReady(event.token()))?; + event_tx.blocking_send(Event::SocketReady { + token: event.token(), + readable: event.is_readable(), + writeable: event.is_writable(), + })?; } loop { @@ -210,8 +280,11 @@ fn mio_worker_task( let mut socket = mio::net::UdpSocket::from_std(make_wildcard_socket(af, port)?); let token = token_from_port_and_address_family(port, af); - poll.registry() - .register(&mut socket, token, mio::Interest::READABLE)?; + poll.registry().register( + &mut socket, + token, + mio::Interest::READABLE | mio::Interest::WRITABLE, + )?; event_tx.blocking_send(Event::NewSocket(token, socket))?; }