refactor(connlib): don't re-implement waker for TUN thread (#7944)

Within `connlib` - on UNIX platforms - we have dedicated threads that
read from and write to the TUN device. These threads are connected with
`connlib`'s main thread via bounded channels: one in each direction.
When these channels are full, `connlib`'s main thread will suspend and
not read any network packets from the sockets in order to maintain
back-pressure. Reading more packets from the sockets would most likely
mean sending more packets out of the TUN device.

When debugging #7763, it became apparent that _something_ must be wrong
with these threads and their channels: somehow, we either consider the
channels to be full or aren't emptying them and, as a result, we don't
read _any_ network packets from our sockets.

To maintain back-pressure here, we currently use our own `AtomicWaker`
construct that is shared with the TUN thread(s). This is unnecessary: we
can instead convert the `flume::Sender` into a
`flume::async::SendSink` (via `into_sink`) and thereby directly access a
`poll` interface.
This commit is contained in:
Thomas Eizinger
2025-01-29 15:48:48 +00:00
committed by GitHub
parent 287ea1e8b2
commit 8bd8098cab
4 changed files with 44 additions and 87 deletions

View File

@@ -3,8 +3,7 @@
use crate::FIREZONE_MARK;
use anyhow::{anyhow, Context as _, Result};
use firezone_logging::std_dyn_err;
use futures::task::AtomicWaker;
use futures::TryStreamExt;
use futures::{SinkExt, TryStreamExt};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_packet::{IpPacket, IpPacketBuf};
use libc::{
@@ -15,7 +14,6 @@ use netlink_packet_route::route::{RouteProtocol, RouteScope};
use netlink_packet_route::rule::RuleAction;
use rtnetlink::{new_connection, Error::NetlinkError, Handle, RouteAddRequest, RuleAddRequest};
use std::path::Path;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{
collections::HashSet,
@@ -298,8 +296,7 @@ async fn remove_route(route: &IpNetwork, idx: u32, handle: &Handle) {
#[derive(Debug)]
pub struct Tun {
outbound_tx: flume::Sender<IpPacket>,
outbound_capacity_waker: Arc<AtomicWaker>,
outbound_tx: flume::r#async::SendSink<'static, IpPacket>,
inbound_rx: mpsc::Receiver<IpPacket>,
}
@@ -309,26 +306,17 @@ impl Tun {
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
for n in 0..num_threads {
let fd = open_tun()?;
let outbound_rx = outbound_rx.clone().into_stream();
let inbound_tx = inbound_tx.clone();
let outbound_capacity_waker = outbound_capacity_waker.clone();
std::thread::Builder::new()
.name(format!("TUN send/recv {n}/{num_threads}"))
.spawn(move || {
firezone_logging::unwrap_or_warn!(
tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx,
outbound_capacity_waker,
read,
write,
),
tun::unix::send_recv_tun(fd, inbound_tx, outbound_rx, read, write),
"Failed to send / recv from TUN device"
)
})
@@ -336,8 +324,7 @@ impl Tun {
}
Ok(Self {
outbound_tx,
outbound_capacity_waker,
outbound_tx: outbound_tx.into_sink(),
inbound_rx,
})
}
@@ -367,17 +354,14 @@ fn open_tun() -> Result<TunFd, io::Error> {
impl tun::Tun for Tun {
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if self.outbound_tx.is_full() {
self.outbound_capacity_waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(Ok(()))
self.outbound_tx
.poll_ready_unpin(cx)
.map_err(io::Error::other)
}
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
self.outbound_tx
.try_send(packet)
.start_send_unpin(packet)
.map_err(io::Error::other)?;
Ok(())

View File

@@ -1,6 +1,5 @@
use futures::task::AtomicWaker;
use futures::SinkExt as _;
use ip_packet::{IpPacket, IpPacketBuf};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{io, os::fd::RawFd};
use tokio::sync::mpsc;
@@ -10,8 +9,7 @@ use tun::unix::TunFd;
#[derive(Debug)]
pub struct Tun {
name: String,
outbound_tx: flume::Sender<IpPacket>,
outbound_capacity_waker: Arc<AtomicWaker>,
outbound_tx: flume::r#async::SendSink<'static, IpPacket>,
inbound_rx: mpsc::Receiver<IpPacket>,
}
@@ -21,17 +19,14 @@ impl tun::Tun for Tun {
}
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if self.outbound_tx.is_full() {
self.outbound_capacity_waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(Ok(()))
self.outbound_tx
.poll_ready_unpin(cx)
.map_err(io::Error::other)
}
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
self.outbound_tx
.try_send(packet)
.start_send_unpin(packet)
.map_err(io::Error::other)?;
Ok(())
@@ -62,35 +57,29 @@ impl Tun {
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
// TODO: Test whether we can set `IFF_MULTI_QUEUE` on Android devices.
std::thread::Builder::new()
.name("TUN send/recv".to_owned())
.spawn({
let outbound_capacity_waker = outbound_capacity_waker.clone();
|| {
firezone_logging::unwrap_or_warn!(
tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx.into_stream(),
outbound_capacity_waker,
read,
write,
),
"Failed to send / recv from TUN device"
)
}
.spawn(|| {
firezone_logging::unwrap_or_warn!(
tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx.into_stream(),
read,
write,
),
"Failed to send / recv from TUN device"
)
})
.map_err(io::Error::other)?;
Ok(Tun {
name,
outbound_tx,
outbound_tx: outbound_tx.into_sink(),
inbound_rx,
outbound_capacity_waker,
})
}
}

View File

@@ -1,7 +1,6 @@
use futures::task::AtomicWaker;
use futures::SinkExt as _;
use ip_packet::{IpPacket, IpPacketBuf};
use libc::{fcntl, iovec, msghdr, recvmsg, AF_INET, AF_INET6, F_GETFL, F_SETFL, O_NONBLOCK};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{
io,
@@ -12,8 +11,7 @@ use tokio::sync::mpsc;
#[derive(Debug)]
pub struct Tun {
name: String,
outbound_capacity_waker: Arc<AtomicWaker>,
outbound_tx: flume::Sender<IpPacket>,
outbound_tx: flume::r#async::SendSink<'static, IpPacket>,
inbound_rx: mpsc::Receiver<IpPacket>,
}
@@ -25,33 +23,27 @@ impl Tun {
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
std::thread::Builder::new()
.name("TUN send/recv".to_owned())
.spawn({
let outbound_capacity_waker = outbound_capacity_waker.clone();
move || {
firezone_logging::unwrap_or_warn!(
tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx.into_stream(),
outbound_capacity_waker,
read,
write,
),
"Failed to send / recv from TUN device"
)
}
.spawn(move || {
firezone_logging::unwrap_or_warn!(
tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx.into_stream(),
read,
write,
),
"Failed to send / recv from TUN device"
)
})
.map_err(io::Error::other)?;
Ok(Tun {
name,
outbound_tx,
outbound_tx: outbound_tx.into_sink(),
inbound_rx,
outbound_capacity_waker,
})
}
}
@@ -62,17 +54,14 @@ impl tun::Tun for Tun {
}
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if self.outbound_tx.is_full() {
self.outbound_capacity_waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(Ok(()))
self.outbound_tx
.poll_ready_unpin(cx)
.map_err(io::Error::other)
}
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
self.outbound_tx
.try_send(packet)
.start_send_unpin(packet)
.map_err(io::Error::other)?;
Ok(())

View File

@@ -1,12 +1,10 @@
use anyhow::{Context as _, Result};
use futures::future::Either;
use futures::task::AtomicWaker;
use futures::StreamExt as _;
use ip_packet::{IpPacket, IpPacketBuf};
use std::io;
use std::os::fd::{AsRawFd, RawFd};
use std::pin::pin;
use std::sync::Arc;
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
@@ -50,7 +48,6 @@ pub fn send_recv_tun<T>(
fd: T,
inbound_tx: mpsc::Sender<IpPacket>,
mut outbound_rx: flume::r#async::RecvStream<'static, IpPacket>,
outbound_capacity_waker: Arc<AtomicWaker>,
read: impl Fn(RawFd, &mut IpPacketBuf) -> io::Result<usize>,
write: impl Fn(RawFd, &IpPacket) -> io::Result<usize>,
) -> Result<()>
@@ -108,8 +105,6 @@ where
{
tracing::warn!("Failed to write to TUN FD: {e}");
};
outbound_capacity_waker.wake(); // We wrote a packet, notify about the new capacity.
}
Either::Left((None, _)) => {
tracing::debug!("Outbound packet sender gone, shutting down task");