diff --git a/rust/relay/server/src/sockets.rs b/rust/relay/server/src/sockets.rs index 1b0484be3..27f9427c2 100644 --- a/rust/relay/server/src/sockets.rs +++ b/rust/relay/server/src/sockets.rs @@ -1,10 +1,10 @@ use anyhow::{Result, bail}; use std::{ borrow::Cow, - collections::{HashMap, VecDeque}, + collections::{BTreeSet, HashMap, VecDeque}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - task::{Context, Poll, Waker, ready}, + task::{Context, Poll, Waker}, time::Duration, }; use stun_codec::rfc8656::attributes::AddressFamily; @@ -20,11 +20,15 @@ pub struct Sockets { /// [`mio`] operates with a concept of [`mio::Token`]s so we need to store our sockets indexed by those tokens. inner: HashMap, - /// Which socket we should still be reading from. + /// Which sockets we should still be reading from. /// /// [`mio`] sends us a signal when a socket is ready for reading. /// We must read from it until it returns [`io::ErrorKind::WouldBlock`]. - current_ready_socket: Option, + /// + /// We store each socket with the number of packets that we read. + /// This allows us to always prioritize a socket that we haven't read any packets + /// from but is ready. + current_ready_sockets: BTreeSet<(usize, mio::Token)>, /// If we are waiting to flush packets, this waker tracks the suspended task. flush_waker: Option, @@ -63,7 +67,7 @@ impl Sockets { inner: Default::default(), cmd_tx, event_rx, - current_ready_socket: None, + current_ready_sockets: Default::default(), pending_packets: Default::default(), flush_waker: None, } @@ -157,42 +161,18 @@ impl Sockets { cx: &mut Context<'_>, ) -> Poll, Error>> { loop { - if let Some(current) = self.current_ready_socket { - if let Some(socket) = self.inner.get(¤t) { - let (num_bytes, from) = match socket.recv_from(buf) { - Ok(ok) => ok, - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - self.current_ready_socket = None; - continue; - } - Err(e) => { - self.current_ready_socket = None; - return Poll::Ready(Err(Error::Io(e))); - } - }; - - let (port, _) = token_to_port_and_address_family(current); - - return Poll::Ready(Ok(Received { - port, - from, - packet: &buf[..num_bytes], - })); - } - } - - match ready!(self.event_rx.poll_recv(cx)) { - Some(Event::NewSocket(token, socket)) => { + match self.event_rx.poll_recv(cx) { + Poll::Ready(Some(Event::NewSocket(token, socket))) => { self.inner.insert(token, socket); continue; } - Some(Event::SocketReady { + Poll::Ready(Some(Event::SocketReady { token, readable, writeable, - }) => { + })) => { if readable { - self.current_ready_socket = Some(token); + self.current_ready_sockets.insert((0, token)); } if writeable { @@ -203,13 +183,41 @@ impl Sockets { continue; } - Some(Event::Crashed(error)) => { + Poll::Ready(Some(Event::Crashed(error))) => { return Poll::Ready(Err(Error::MioTaskCrashed(error))); } - None => { + Poll::Ready(None) => { panic!("must not poll `Sockets` after mio task exited") } - }; + Poll::Pending => {} + } + + // Read from all sockets in order of least packets read so far. + while let Some((num_packets, current)) = self.current_ready_sockets.pop_first() { + let Some(socket) = self.inner.get(¤t) else { + continue; + }; + + let (num_bytes, from) = match socket.recv_from(buf) { + Ok(ok) => ok, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => return Poll::Ready(Err(Error::Io(e))), + }; + + // Bump the number of packets and return. + self.current_ready_sockets + .insert((num_packets + 1, current)); + + let (port, _) = token_to_port_and_address_family(current); + + return Poll::Ready(Ok(Received { + port, + from, + packet: &buf[..num_bytes], + })); + } + + return Poll::Pending; // This is okay because we only get here if `event_rx` returned pending. } } }