From 8bd8098cabc0de2f9d3f4b359e7731dd11a19555 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 29 Jan 2025 15:48:48 +0000 Subject: [PATCH] refactor(connlib): don't re-implement waker for TUN thread (#7944) Within `connlib` - on UNIX platforms - we have dedicated threads that read from and write to the TUN device. These threads are connected with `connlib`'s main thread via bounded channels: one in each direction. When these channels are full, `connlib`'s main thread will suspend and not read any network packets from the sockets in order to maintain back-pressure. Reading more packets from the socket would mean most likely sending more packets out the TUN device. When debugging #7763, it became apparent that _something_ must be wrong with these threads and that somehow, we either consider them as full or aren't emptying them and as a result, we don't read _any_ network packets from our sockets. To maintain back-pressure here, we currently use our own `AtomicWaker` construct that is shared with the TUN thread(s). This is unnecessary. We can also directly convert the `flume::Sender` into a `flume::async::SendSink` and therefore directly access a `poll` interface. --- .../src/tun_device_manager/linux.rs | 32 ++++--------- rust/connlib/clients/android/src/tun.rs | 47 +++++++------------ rust/connlib/clients/apple/src/tun.rs | 47 +++++++------------ rust/tun/src/unix.rs | 5 -- 4 files changed, 44 insertions(+), 87 deletions(-) diff --git a/rust/bin-shared/src/tun_device_manager/linux.rs b/rust/bin-shared/src/tun_device_manager/linux.rs index 4e3ae9c34..68612d7c2 100644 --- a/rust/bin-shared/src/tun_device_manager/linux.rs +++ b/rust/bin-shared/src/tun_device_manager/linux.rs @@ -3,8 +3,7 @@ use crate::FIREZONE_MARK; use anyhow::{anyhow, Context as _, Result}; use firezone_logging::std_dyn_err; -use futures::task::AtomicWaker; -use futures::TryStreamExt; +use futures::{SinkExt, TryStreamExt}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_packet::{IpPacket, IpPacketBuf}; use libc::{ @@ -15,7 +14,6 @@ use netlink_packet_route::route::{RouteProtocol, RouteScope}; use netlink_packet_route::rule::RuleAction; use rtnetlink::{new_connection, Error::NetlinkError, Handle, RouteAddRequest, RuleAddRequest}; use std::path::Path; -use std::sync::Arc; use std::task::{Context, Poll}; use std::{ collections::HashSet, @@ -298,8 +296,7 @@ async fn remove_route(route: &IpNetwork, idx: u32, handle: &Handle) { #[derive(Debug)] pub struct Tun { - outbound_tx: flume::Sender, - outbound_capacity_waker: Arc, + outbound_tx: flume::r#async::SendSink<'static, IpPacket>, inbound_rx: mpsc::Receiver, } @@ -309,26 +306,17 @@ impl Tun { let (inbound_tx, inbound_rx) = mpsc::channel(1000); let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets. - let outbound_capacity_waker = Arc::new(AtomicWaker::new()); for n in 0..num_threads { let fd = open_tun()?; let outbound_rx = outbound_rx.clone().into_stream(); let inbound_tx = inbound_tx.clone(); - let outbound_capacity_waker = outbound_capacity_waker.clone(); std::thread::Builder::new() .name(format!("TUN send/recv {n}/{num_threads}")) .spawn(move || { firezone_logging::unwrap_or_warn!( - tun::unix::send_recv_tun( - fd, - inbound_tx, - outbound_rx, - outbound_capacity_waker, - read, - write, - ), + tun::unix::send_recv_tun(fd, inbound_tx, outbound_rx, read, write), "Failed to send / recv from TUN device" ) }) @@ -336,8 +324,7 @@ impl Tun { } Ok(Self { - outbound_tx, - outbound_capacity_waker, + outbound_tx: outbound_tx.into_sink(), inbound_rx, }) } @@ -367,17 +354,14 @@ fn open_tun() -> Result { impl tun::Tun for Tun { fn poll_send_ready(&mut self, cx: &mut Context) -> Poll> { - if self.outbound_tx.is_full() { - self.outbound_capacity_waker.register(cx.waker()); - return Poll::Pending; - } - - Poll::Ready(Ok(())) + self.outbound_tx + .poll_ready_unpin(cx) + .map_err(io::Error::other) } fn send(&mut self, packet: IpPacket) -> io::Result<()> { self.outbound_tx - .try_send(packet) + .start_send_unpin(packet) .map_err(io::Error::other)?; Ok(()) diff --git a/rust/connlib/clients/android/src/tun.rs b/rust/connlib/clients/android/src/tun.rs index 78d70b340..163491656 100644 --- a/rust/connlib/clients/android/src/tun.rs +++ b/rust/connlib/clients/android/src/tun.rs @@ -1,6 +1,5 @@ -use futures::task::AtomicWaker; +use futures::SinkExt as _; use ip_packet::{IpPacket, IpPacketBuf}; -use std::sync::Arc; use std::task::{Context, Poll}; use std::{io, os::fd::RawFd}; use tokio::sync::mpsc; @@ -10,8 +9,7 @@ use tun::unix::TunFd; #[derive(Debug)] pub struct Tun { name: String, - outbound_tx: flume::Sender, - outbound_capacity_waker: Arc, + outbound_tx: flume::r#async::SendSink<'static, IpPacket>, inbound_rx: mpsc::Receiver, } @@ -21,17 +19,14 @@ impl tun::Tun for Tun { } fn poll_send_ready(&mut self, cx: &mut Context) -> Poll> { - if self.outbound_tx.is_full() { - self.outbound_capacity_waker.register(cx.waker()); - return Poll::Pending; - } - - Poll::Ready(Ok(())) + self.outbound_tx + .poll_ready_unpin(cx) + .map_err(io::Error::other) } fn send(&mut self, packet: IpPacket) -> io::Result<()> { self.outbound_tx - .try_send(packet) + .start_send_unpin(packet) .map_err(io::Error::other)?; Ok(()) @@ -62,35 +57,29 @@ impl Tun { let (inbound_tx, inbound_rx) = mpsc::channel(1000); let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets. - let outbound_capacity_waker = Arc::new(AtomicWaker::new()); // TODO: Test whether we can set `IFF_MULTI_QUEUE` on Android devices. std::thread::Builder::new() .name("TUN send/recv".to_owned()) - .spawn({ - let outbound_capacity_waker = outbound_capacity_waker.clone(); - || { - firezone_logging::unwrap_or_warn!( - tun::unix::send_recv_tun( - fd, - inbound_tx, - outbound_rx.into_stream(), - outbound_capacity_waker, - read, - write, - ), - "Failed to send / recv from TUN device" - ) - } + .spawn(|| { + firezone_logging::unwrap_or_warn!( + tun::unix::send_recv_tun( + fd, + inbound_tx, + outbound_rx.into_stream(), + read, + write, + ), + "Failed to send / recv from TUN device" + ) }) .map_err(io::Error::other)?; Ok(Tun { name, - outbound_tx, + outbound_tx: outbound_tx.into_sink(), inbound_rx, - outbound_capacity_waker, }) } } diff --git a/rust/connlib/clients/apple/src/tun.rs b/rust/connlib/clients/apple/src/tun.rs index b36728b0c..e9f1e5c24 100644 --- a/rust/connlib/clients/apple/src/tun.rs +++ b/rust/connlib/clients/apple/src/tun.rs @@ -1,7 +1,6 @@ -use futures::task::AtomicWaker; +use futures::SinkExt as _; use ip_packet::{IpPacket, IpPacketBuf}; use libc::{fcntl, iovec, msghdr, recvmsg, AF_INET, AF_INET6, F_GETFL, F_SETFL, O_NONBLOCK}; -use std::sync::Arc; use std::task::{Context, Poll}; use std::{ io, @@ -12,8 +11,7 @@ use tokio::sync::mpsc; #[derive(Debug)] pub struct Tun { name: String, - outbound_capacity_waker: Arc, - outbound_tx: flume::Sender, + outbound_tx: flume::r#async::SendSink<'static, IpPacket>, inbound_rx: mpsc::Receiver, } @@ -25,33 +23,27 @@ impl Tun { let (inbound_tx, inbound_rx) = mpsc::channel(1000); let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets. - let outbound_capacity_waker = Arc::new(AtomicWaker::new()); std::thread::Builder::new() .name("TUN send/recv".to_owned()) - .spawn({ - let outbound_capacity_waker = outbound_capacity_waker.clone(); - move || { - firezone_logging::unwrap_or_warn!( - tun::unix::send_recv_tun( - fd, - inbound_tx, - outbound_rx.into_stream(), - outbound_capacity_waker, - read, - write, - ), - "Failed to send / recv from TUN device" - ) - } + .spawn(move || { + firezone_logging::unwrap_or_warn!( + tun::unix::send_recv_tun( + fd, + inbound_tx, + outbound_rx.into_stream(), + read, + write, + ), + "Failed to send / recv from TUN device" + ) }) .map_err(io::Error::other)?; Ok(Tun { name, - outbound_tx, + outbound_tx: outbound_tx.into_sink(), inbound_rx, - outbound_capacity_waker, }) } } @@ -62,17 +54,14 @@ impl tun::Tun for Tun { } fn poll_send_ready(&mut self, cx: &mut Context) -> Poll> { - if self.outbound_tx.is_full() { - self.outbound_capacity_waker.register(cx.waker()); - return Poll::Pending; - } - - Poll::Ready(Ok(())) + self.outbound_tx + .poll_ready_unpin(cx) + .map_err(io::Error::other) } fn send(&mut self, packet: IpPacket) -> io::Result<()> { self.outbound_tx - .try_send(packet) + .start_send_unpin(packet) .map_err(io::Error::other)?; Ok(()) diff --git a/rust/tun/src/unix.rs b/rust/tun/src/unix.rs index e19569921..ef589404e 100644 --- a/rust/tun/src/unix.rs +++ b/rust/tun/src/unix.rs @@ -1,12 +1,10 @@ use anyhow::{Context as _, Result}; use futures::future::Either; -use futures::task::AtomicWaker; use futures::StreamExt as _; use ip_packet::{IpPacket, IpPacketBuf}; use std::io; use std::os::fd::{AsRawFd, RawFd}; use std::pin::pin; -use std::sync::Arc; use tokio::io::unix::AsyncFd; use tokio::sync::mpsc; @@ -50,7 +48,6 @@ pub fn send_recv_tun( fd: T, inbound_tx: mpsc::Sender, mut outbound_rx: flume::r#async::RecvStream<'static, IpPacket>, - outbound_capacity_waker: Arc, read: impl Fn(RawFd, &mut IpPacketBuf) -> io::Result, write: impl Fn(RawFd, &IpPacket) -> io::Result, ) -> Result<()> @@ -108,8 +105,6 @@ where { tracing::warn!("Failed to write to TUN FD: {e}"); }; - - outbound_capacity_waker.wake(); // We wrote a packet, notify about the new capacity. } Either::Left((None, _)) => { tracing::debug!("Outbound packet sender gone, shutting down task");