mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
fix(relay): read from most-recently-ready socket first (#10148)
The relay uses `mio` to react to readiness events from multiple sockets at once. Including the control port 3478, the relay needs to also send and receive traffic from up to 16384 sockets (one for each possible allocation). We need to process readiness events from these sockets as fairly as possible. Under high-load, it may otherwise happen that we don't read packets from an allocation socket, resulting in ICE timeouts of the connection being relayed. To achieve this fairness, we collect all readiness tokens into a set and store it with the number of packets we have read so far from this socket. Then, we always read from the socket next that we have so far read the least amount of packets from.
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
use anyhow::{Result, bail};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{HashMap, VecDeque},
|
||||
collections::{BTreeSet, HashMap, VecDeque},
|
||||
io,
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
task::{Context, Poll, Waker, ready},
|
||||
task::{Context, Poll, Waker},
|
||||
time::Duration,
|
||||
};
|
||||
use stun_codec::rfc8656::attributes::AddressFamily;
|
||||
@@ -20,11 +20,15 @@ pub struct Sockets {
|
||||
/// [`mio`] operates with a concept of [`mio::Token`]s so we need to store our sockets indexed by those tokens.
|
||||
inner: HashMap<mio::Token, mio::net::UdpSocket>,
|
||||
|
||||
/// Which socket we should still be reading from.
|
||||
/// Which sockets we should still be reading from.
|
||||
///
|
||||
/// [`mio`] sends us a signal when a socket is ready for reading.
|
||||
/// We must read from it until it returns [`io::ErrorKind::WouldBlock`].
|
||||
current_ready_socket: Option<mio::Token>,
|
||||
///
|
||||
/// We store each socket with the number of packets that we read.
|
||||
/// This allows us to always prioritize a socket that we haven't read any packets
|
||||
/// from but is ready.
|
||||
current_ready_sockets: BTreeSet<(usize, mio::Token)>,
|
||||
|
||||
/// If we are waiting to flush packets, this waker tracks the suspended task.
|
||||
flush_waker: Option<Waker>,
|
||||
@@ -63,7 +67,7 @@ impl Sockets {
|
||||
inner: Default::default(),
|
||||
cmd_tx,
|
||||
event_rx,
|
||||
current_ready_socket: None,
|
||||
current_ready_sockets: Default::default(),
|
||||
pending_packets: Default::default(),
|
||||
flush_waker: None,
|
||||
}
|
||||
@@ -157,42 +161,18 @@ impl Sockets {
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Received<'b>, Error>> {
|
||||
loop {
|
||||
if let Some(current) = self.current_ready_socket {
|
||||
if let Some(socket) = self.inner.get(¤t) {
|
||||
let (num_bytes, from) = match socket.recv_from(buf) {
|
||||
Ok(ok) => ok,
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
self.current_ready_socket = None;
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
self.current_ready_socket = None;
|
||||
return Poll::Ready(Err(Error::Io(e)));
|
||||
}
|
||||
};
|
||||
|
||||
let (port, _) = token_to_port_and_address_family(current);
|
||||
|
||||
return Poll::Ready(Ok(Received {
|
||||
port,
|
||||
from,
|
||||
packet: &buf[..num_bytes],
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
match ready!(self.event_rx.poll_recv(cx)) {
|
||||
Some(Event::NewSocket(token, socket)) => {
|
||||
match self.event_rx.poll_recv(cx) {
|
||||
Poll::Ready(Some(Event::NewSocket(token, socket))) => {
|
||||
self.inner.insert(token, socket);
|
||||
continue;
|
||||
}
|
||||
Some(Event::SocketReady {
|
||||
Poll::Ready(Some(Event::SocketReady {
|
||||
token,
|
||||
readable,
|
||||
writeable,
|
||||
}) => {
|
||||
})) => {
|
||||
if readable {
|
||||
self.current_ready_socket = Some(token);
|
||||
self.current_ready_sockets.insert((0, token));
|
||||
}
|
||||
|
||||
if writeable {
|
||||
@@ -203,13 +183,41 @@ impl Sockets {
|
||||
|
||||
continue;
|
||||
}
|
||||
Some(Event::Crashed(error)) => {
|
||||
Poll::Ready(Some(Event::Crashed(error))) => {
|
||||
return Poll::Ready(Err(Error::MioTaskCrashed(error)));
|
||||
}
|
||||
None => {
|
||||
Poll::Ready(None) => {
|
||||
panic!("must not poll `Sockets` after mio task exited")
|
||||
}
|
||||
};
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// Read from all sockets in order of least packets read so far.
|
||||
while let Some((num_packets, current)) = self.current_ready_sockets.pop_first() {
|
||||
let Some(socket) = self.inner.get(¤t) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let (num_bytes, from) = match socket.recv_from(buf) {
|
||||
Ok(ok) => ok,
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
|
||||
Err(e) => return Poll::Ready(Err(Error::Io(e))),
|
||||
};
|
||||
|
||||
// Bump the number of packets and return.
|
||||
self.current_ready_sockets
|
||||
.insert((num_packets + 1, current));
|
||||
|
||||
let (port, _) = token_to_port_and_address_family(current);
|
||||
|
||||
return Poll::Ready(Ok(Received {
|
||||
port,
|
||||
from,
|
||||
packet: &buf[..num_bytes],
|
||||
}));
|
||||
}
|
||||
|
||||
return Poll::Pending; // This is okay because we only get here if `event_rx` returned pending.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user