refactor(connlib): increase UDP queues on desktop platforms (#10235)

On desktop platforms, we can easily afford to have larger queues here
despite each item in there being 65k. Benchmarking showed that we do
sometimes fill these up.

Related: #7452
This commit is contained in:
Thomas Eizinger
2025-08-21 08:56:14 +00:00
committed by GitHub
parent a109c1a2ef
commit f85ae75ae0
3 changed files with 36 additions and 13 deletions

1
rust/Cargo.lock generated
View File

@@ -2675,6 +2675,7 @@ dependencies = [
"test-strategy",
"thiserror 2.0.15",
"tokio",
"tokio-util",
"tracing",
"tracing-subscriber",
"tun",

View File

@@ -50,6 +50,7 @@ socket-factory = { workspace = true }
socket2 = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true, features = ["attributes"] }
tun = { workspace = true }
uuid = { workspace = true, features = ["std", "v4"] }

View File

@@ -1,6 +1,6 @@
use crate::otel;
use anyhow::{Context as _, Result};
use futures::{SinkExt, StreamExt, ready};
use futures::{SinkExt, ready};
use gat_lending_iterator::LendingIterator;
use socket_factory::DatagramOut;
use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket};
@@ -11,6 +11,8 @@ use std::{
sync::Arc,
task::{Context, Poll, Waker},
};
use tokio::sync::mpsc;
use tokio_util::sync::PollSender;
const DEFAULT_LISTEN_PORT: u16 = EPHEMERAL_PORT_RANGE_START + FIRE;
const EPHEMERAL_PORT_RANGE_START: u16 = 49152;
@@ -173,6 +175,17 @@ where
}
}
/// How big the queue for incoming and outgoing UDP batches is at most.
///
/// On mobile platforms, we are memory-constrained and thus cannot afford to process big batches of packets.
const QUEUE_SIZE: usize = {
if cfg!(any(target_os = "ios", target_os = "android")) {
10
} else {
1000
}
};
struct ThreadedUdpSocket {
thread_name: String,
join_handle: std::thread::JoinHandle<()>,
@@ -180,14 +193,14 @@ struct ThreadedUdpSocket {
}
struct Channels {
outbound_tx: flume::r#async::SendSink<'static, DatagramOut>,
inbound_rx: flume::r#async::RecvStream<'static, Result<DatagramSegmentIter>>,
outbound_tx: PollSender<DatagramOut>,
inbound_rx: mpsc::Receiver<Result<DatagramSegmentIter>>,
}
impl ThreadedUdpSocket {
fn new(sf: Arc<dyn SocketFactory<UdpSocket>>, preferred_addr: SocketAddr) -> io::Result<Self> {
let (outbound_tx, outbound_rx) = flume::bounded(10);
let (inbound_tx, inbound_rx) = flume::bounded(10);
let (outbound_tx, mut outbound_rx) = mpsc::channel(QUEUE_SIZE);
let (inbound_tx, inbound_rx) = mpsc::channel(QUEUE_SIZE);
let (error_tx, error_rx) = flume::bounded(0);
let thread_name = match preferred_addr {
@@ -244,7 +257,7 @@ impl ThreadedUdpSocket {
let socket = socket.clone();
async move {
while let Ok(datagram) = outbound_rx.recv_async().await {
while let Some(datagram) = outbound_rx.recv().await {
tokio::task::yield_now().await;
if let Err(e) = socket.send(datagram).await {
@@ -261,7 +274,7 @@ impl ThreadedUdpSocket {
}
// We use the inbound_tx channel to send the error back to the main thread.
if inbound_tx.send_async(Err(e)).await.is_err() {
if inbound_tx.send(Err(e)).await.is_err() {
tracing::debug!(
"Channel for inbound datagrams closed; exiting UDP thread"
);
@@ -297,7 +310,7 @@ impl ThreadedUdpSocket {
);
}
if inbound_tx.send_async(result).await.is_err() {
if inbound_tx.send(result).await.is_err() {
tracing::debug!(
"Channel for inbound datagrams closed; exiting UDP thread"
);
@@ -317,8 +330,8 @@ impl ThreadedUdpSocket {
thread_name,
join_handle,
channels: Some(Channels {
outbound_tx: outbound_tx.into_sink(),
inbound_rx: inbound_rx.into_stream(),
outbound_tx: PollSender::new(outbound_tx),
inbound_rx,
}),
})
}
@@ -340,8 +353,8 @@ impl ThreadedUdpSocket {
}
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<Result<DatagramSegmentIter>> {
let iter = ready!(self.channels_mut()?.inbound_rx.poll_next_unpin(cx))
.ok_or(UdpSocketThreadStopped)?;
let iter =
ready!(self.channels_mut()?.inbound_rx.poll_recv(cx)).ok_or(UdpSocketThreadStopped)?;
Poll::Ready(iter)
}
@@ -353,7 +366,15 @@ impl ThreadedUdpSocket {
fn queue_lengths(&self) -> (usize, usize) {
self.channels
.as_ref()
.map(|c| (c.inbound_rx.len(), c.outbound_tx.len()))
.map(|c| {
(
c.inbound_rx.len(),
c.outbound_tx
.get_ref()
.map(|c| QUEUE_SIZE - c.capacity())
.unwrap_or_default(),
)
})
.unwrap_or_default()
}
}