diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3fd2749e7..00215240d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2675,6 +2675,7 @@ dependencies = [ "test-strategy", "thiserror 2.0.15", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "tun", diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 86acfaf72..74ed0a2bd 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -50,6 +50,7 @@ socket-factory = { workspace = true } socket2 = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } +tokio-util = { workspace = true } tracing = { workspace = true, features = ["attributes"] } tun = { workspace = true } uuid = { workspace = true, features = ["std", "v4"] } diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 73422b6a3..83d3d83cc 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,6 +1,6 @@ use crate::otel; use anyhow::{Context as _, Result}; -use futures::{SinkExt, StreamExt, ready}; +use futures::{SinkExt, ready}; use gat_lending_iterator::LendingIterator; use socket_factory::DatagramOut; use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket}; @@ -11,6 +11,8 @@ use std::{ sync::Arc, task::{Context, Poll, Waker}, }; +use tokio::sync::mpsc; +use tokio_util::sync::PollSender; const DEFAULT_LISTEN_PORT: u16 = EPHEMERAL_PORT_RANGE_START + FIRE; const EPHEMERAL_PORT_RANGE_START: u16 = 49152; @@ -173,6 +175,17 @@ where } } +/// How big the queue for incoming and outgoing UDP batches is at most. +/// +/// On mobile platforms, we are memory-constrained and thus cannot afford to process big batches of packets. +const QUEUE_SIZE: usize = { + if cfg!(any(target_os = "ios", target_os = "android")) { + 10 + } else { + 1000 + } +}; + struct ThreadedUdpSocket { thread_name: String, join_handle: std::thread::JoinHandle<()>, @@ -180,14 +193,14 @@ struct ThreadedUdpSocket { } struct Channels { - outbound_tx: flume::r#async::SendSink<'static, DatagramOut>, - inbound_rx: flume::r#async::RecvStream<'static, Result>, + outbound_tx: PollSender, + inbound_rx: mpsc::Receiver>, } impl ThreadedUdpSocket { fn new(sf: Arc>, preferred_addr: SocketAddr) -> io::Result { - let (outbound_tx, outbound_rx) = flume::bounded(10); - let (inbound_tx, inbound_rx) = flume::bounded(10); + let (outbound_tx, mut outbound_rx) = mpsc::channel(QUEUE_SIZE); + let (inbound_tx, inbound_rx) = mpsc::channel(QUEUE_SIZE); let (error_tx, error_rx) = flume::bounded(0); let thread_name = match preferred_addr { @@ -244,7 +257,7 @@ impl ThreadedUdpSocket { let socket = socket.clone(); async move { - while let Ok(datagram) = outbound_rx.recv_async().await { + while let Some(datagram) = outbound_rx.recv().await { tokio::task::yield_now().await; if let Err(e) = socket.send(datagram).await { @@ -261,7 +274,7 @@ impl ThreadedUdpSocket { } // We use the inbound_tx channel to send the error back to the main thread. - if inbound_tx.send_async(Err(e)).await.is_err() { + if inbound_tx.send(Err(e)).await.is_err() { tracing::debug!( "Channel for inbound datagrams closed; exiting UDP thread" ); @@ -297,7 +310,7 @@ impl ThreadedUdpSocket { ); } - if inbound_tx.send_async(result).await.is_err() { + if inbound_tx.send(result).await.is_err() { tracing::debug!( "Channel for inbound datagrams closed; exiting UDP thread" ); @@ -317,8 +330,8 @@ impl ThreadedUdpSocket { thread_name, join_handle, channels: Some(Channels { - outbound_tx: outbound_tx.into_sink(), - inbound_rx: inbound_rx.into_stream(), + outbound_tx: PollSender::new(outbound_tx), + inbound_rx, }), }) } @@ -340,8 +353,8 @@ impl ThreadedUdpSocket { } fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll> { - let iter = ready!(self.channels_mut()?.inbound_rx.poll_next_unpin(cx)) - .ok_or(UdpSocketThreadStopped)?; + let iter = + ready!(self.channels_mut()?.inbound_rx.poll_recv(cx)).ok_or(UdpSocketThreadStopped)?; Poll::Ready(iter) } @@ -353,7 +366,15 @@ impl ThreadedUdpSocket { fn queue_lengths(&self) -> (usize, usize) { self.channels .as_ref() - .map(|c| (c.inbound_rx.len(), c.outbound_tx.len())) + .map(|c| { + ( + c.inbound_rx.len(), + c.outbound_tx + .get_ref() + .map(|c| QUEUE_SIZE - c.capacity()) + .unwrap_or_default(), + ) + }) .unwrap_or_default() } }