feat(linux): multi-threaded TUN device operations (#7449)

## Context

At present, we only have a single thread that reads and writes to the
TUN device on all platforms. On Linux, it is possible to open the file
descriptor of a TUN device multiple times by setting the
`IFF_MULTI_QUEUE` option using `ioctl`. Using multi-queue, we can then
spawn multiple threads that concurrently read and write to the TUN
device. This is critical for achieving a better throughput.

## Solution

`IFF_MULTI_QUEUE` is a Linux-only thing and therefore only applies to
headless-client, GUI-client on Linux and the Gateway (it may also be
possible on Android, I haven't tried). As such, we need to first change
our internal abstractions a bit to move the creation of the TUN thread
to the `Tun` abstraction itself. For this, we change the interface of
`Tun` to the following:

- `poll_recv_many`: An API, inspired by tokio's `mpsc::Receiver` where
multiple items in a channel can be batch-received.
- `poll_send_ready`: Mimics the API of `Sink` to check whether more
items can be written.
- `send`: Mimics the API of `Sink` to actually send an item.

With these APIs in place, we can implement various (performance)
improvements for the different platforms.

- On Linux, this allows us to spawn multiple threads to read and write
from the TUN device and send all packets into the same channel. The `Io`
component of `connlib` then uses `poll_recv_many` to read batches of up
to 100 packets at once. This ties in well with #7210 because we can then
use GSO to send the encrypted packets in single syscalls to the OS.
- On Windows, we already have a dedicated recv thread because `WinTun`'s
most-convenient API uses blocking IO. As such, we can now also tie into
that by batch-receiving from this channel.
- In addition to using multiple threads, this API now also uses correct
readiness checks on Linux, Darwin and Android to uphold backpressure in
case we cannot write to the TUN device.

## Configuration

Local testing has shown that 2 threads give the best performance for a
local `iperf3` run. I suspect this is because there is only so much
traffic that a single application (i.e. `iperf3`) can generate. With
more than 2 threads, the throughput actually drops drastically because
`connlib`'s main thread is too busy with lock-contention and triggering
`Waker`s for the TUN threads (which mostly idle around if there are 4+
of them). I've made it configurable on the Gateway though so we can
experiment with this during concurrent speedtests etc.

In addition, switching `connlib` to a single-threaded tokio runtime
further increased the throughput. I suspect due to less task / context
switching.

## Results

Local testing with `iperf3` shows some very promising results. We now
achieve a throughput of 2+ Gbit/s.

```
Connecting to host 172.20.0.110, port 5201
Reverse mode, remote host 172.20.0.110 is sending
[  5] local 100.80.159.34 port 57040 connected to 172.20.0.110 port 5201
[ ID] Interval           Transfer     Bitrate
[  5]   0.00-1.00   sec   274 MBytes  2.30 Gbits/sec
[  5]   1.00-2.00   sec   279 MBytes  2.34 Gbits/sec
[  5]   2.00-3.00   sec   216 MBytes  1.82 Gbits/sec
[  5]   3.00-4.00   sec   224 MBytes  1.88 Gbits/sec
[  5]   4.00-5.00   sec   234 MBytes  1.96 Gbits/sec
[  5]   5.00-6.00   sec   238 MBytes  2.00 Gbits/sec
[  5]   6.00-7.00   sec   229 MBytes  1.92 Gbits/sec
[  5]   7.00-8.00   sec   222 MBytes  1.86 Gbits/sec
[  5]   8.00-9.00   sec   223 MBytes  1.87 Gbits/sec
[  5]   9.00-10.00  sec   217 MBytes  1.82 Gbits/sec
- - - - - - - - - - - - - - - - - - - - - - - - -
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  2.30 GBytes  1.98 Gbits/sec  22247             sender
[  5]   0.00-10.00  sec  2.30 GBytes  1.98 Gbits/sec                  receiver

iperf Done.
```

This is a pretty solid improvement over what is in `main`:

```
Connecting to host 172.20.0.110, port 5201
[  5] local 100.65.159.3 port 56970 connected to 172.20.0.110 port 5201
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-1.00   sec  90.4 MBytes   758 Mbits/sec  1800    106 KBytes
[  5]   1.00-2.00   sec  93.4 MBytes   783 Mbits/sec  1550   51.6 KBytes
[  5]   2.00-3.00   sec  92.6 MBytes   777 Mbits/sec  1350   76.8 KBytes
[  5]   3.00-4.00   sec  92.9 MBytes   779 Mbits/sec  1800   56.4 KBytes
[  5]   4.00-5.00   sec  93.4 MBytes   783 Mbits/sec  1650   69.6 KBytes
[  5]   5.00-6.00   sec  90.6 MBytes   760 Mbits/sec  1500   73.2 KBytes
[  5]   6.00-7.00   sec  87.6 MBytes   735 Mbits/sec  1400   76.8 KBytes
[  5]   7.00-8.00   sec  92.6 MBytes   777 Mbits/sec  1600   82.7 KBytes
[  5]   8.00-9.00   sec  91.1 MBytes   764 Mbits/sec  1500   70.8 KBytes
[  5]   9.00-10.00  sec  92.0 MBytes   771 Mbits/sec  1550   85.1 KBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec   917 MBytes   769 Mbits/sec  15700             sender
[  5]   0.00-10.00  sec   916 MBytes   768 Mbits/sec                  receiver

iperf Done.
```
This commit is contained in:
Thomas Eizinger
2024-12-05 00:18:20 +00:00
committed by GitHub
parent 2f2ad2cffe
commit 90cf191a7c
31 changed files with 660 additions and 423 deletions

38
rust/Cargo.lock generated
View File

@@ -1042,6 +1042,9 @@ dependencies = [
"connlib-model",
"firezone-logging",
"firezone-telemetry",
"flume",
"futures",
"ip-packet",
"ip_network",
"jni",
"libc",
@@ -1070,6 +1073,9 @@ dependencies = [
"connlib-model",
"firezone-logging",
"firezone-telemetry",
"flume",
"futures",
"ip-packet",
"ip_network",
"libc",
"oslog",
@@ -1912,6 +1918,7 @@ dependencies = [
"axum",
"clap",
"firezone-logging",
"flume",
"futures",
"hex-literal",
"ip-packet",
@@ -2224,7 +2231,6 @@ dependencies = [
"test-strategy",
"thiserror",
"tokio",
"tokio-util",
"tracing",
"tracing-subscriber",
"tun",
@@ -2250,6 +2256,18 @@ dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "flume"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095"
dependencies = [
"futures-core",
"futures-sink",
"nanorand",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -2581,8 +2599,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@@ -3670,6 +3690,15 @@ version = "0.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc0287524726960e07b119cebd01678f852f147742ae0d925e6a520dca956126"
[[package]]
name = "nanorand"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
dependencies = [
"getrandom 0.2.15",
]
[[package]]
name = "native-dialog"
version = "0.7.0"
@@ -5953,6 +5982,9 @@ name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "stable_deref_trait"
@@ -7125,8 +7157,12 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
name = "tun"
version = "0.1.0"
dependencies = [
"flume",
"futures",
"ip-packet",
"libc",
"tokio",
"tracing",
]
[[package]]

View File

@@ -122,6 +122,7 @@ thiserror = "1.0.68"
time = "0.3.36"
tokio = "1.41"
tokio-stream = "0.1.16"
flume = { version = "0.11.1", features = ["async"] }
tokio-tungstenite = "0.23.1"
tokio-util = "0.7.11"
tracing = { version = "0.1.40" }

View File

@@ -13,6 +13,7 @@ clap = { workspace = true, features = ["derive", "env"] }
firezone-logging = { workspace = true }
futures = { workspace = true, features = ["std", "async-await"] }
hex-literal = { workspace = true }
ip-packet = { workspace = true }
ip_network = { workspace = true, features = ["serde"] }
socket-factory = { workspace = true }
thiserror = { workspace = true }
@@ -24,6 +25,7 @@ tun = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
[target.'cfg(target_os = "linux")'.dependencies]
flume = { workspace = true }
libc = { workspace = true }
netlink-packet-core = { version = "0.7" }
netlink-packet-route = { version = "0.19" }

View File

@@ -28,7 +28,6 @@ mod platform {
mod platform {
use anyhow::Result;
use firezone_bin_shared::TunDeviceManager;
use ip_packet::{IpPacket, IpPacketBuf};
use std::{
future::poll_fn,
net::{Ipv4Addr, Ipv6Addr},
@@ -47,10 +46,11 @@ mod platform {
const REQ_LEN: usize = 1_000;
const RESP_CODE: u8 = 43;
const SERVER_PORT: u16 = 3000;
const NUM_THREADS: usize = 1; // Note: Unused on Windows.
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
let mut device_manager = TunDeviceManager::new(MTU)?;
let mut device_manager = TunDeviceManager::new(MTU, NUM_THREADS)?;
let mut tun = device_manager.make_tun()?;
device_manager.set_ips(ipv4, ipv6).await?;
@@ -62,15 +62,14 @@ mod platform {
let server_task = tokio::spawn(async move {
tracing::debug!("Server task entered");
let mut requests_served = 0;
// We aren't interested in allocator speed or doing any processing,
// so just cache the response packet
let mut response_pkt = None;
let mut time_spent = Duration::from_millis(0);
loop {
let mut req_buf = IpPacketBuf::new();
let n = poll_fn(|cx| tun.poll_read(req_buf.buf(), cx)).await?;
let mut buf = Vec::with_capacity(1);
poll_fn(|cx| tun.poll_recv_many(cx, &mut buf, 1)).await;
let original_pkt = buf.remove(0);
let start = Instant::now();
let original_pkt = IpPacket::new(req_buf, n).unwrap();
let Some(original_udp) = original_pkt.as_udp() else {
continue;
};
@@ -81,21 +80,16 @@ mod platform {
panic!("Wrong request code");
}
// Only generate the response packet on the first loop,
// then just reuse it.
let res_buf = response_pkt
.get_or_insert_with(|| {
ip_packet::make::udp_packet(
original_pkt.destination(),
original_pkt.source(),
original_udp.destination_port(),
original_udp.source_port(),
vec![RESP_CODE],
)
.unwrap()
})
.packet();
tun.write4(res_buf)?;
tun.send(
ip_packet::make::udp_packet(
original_pkt.destination(),
original_pkt.source(),
original_udp.destination_port(),
original_udp.source_port(),
vec![RESP_CODE],
)
.unwrap(),
)?;
requests_served += 1;
time_spent += start.elapsed();
if requests_served >= NUM_REQUESTS {

View File

@@ -42,7 +42,7 @@ mod tests {
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
let mut device_manager = TunDeviceManager::new(1280).unwrap();
let mut device_manager = TunDeviceManager::new(1280, 1).unwrap();
let _tun = device_manager.make_tun().unwrap();
device_manager.set_ips(ipv4, ipv6).await.unwrap();
@@ -72,7 +72,7 @@ mod tests {
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
let mut device_manager = TunDeviceManager::new(1280).unwrap();
let mut device_manager = TunDeviceManager::new(1280, 1).unwrap();
let _tun = device_manager.make_tun().unwrap();
device_manager.set_ips(ipv4, ipv6).await.unwrap();
@@ -125,7 +125,7 @@ mod tests {
/// Checks for regressions in issue #4765, un-initializing Wintun
/// Redundant but harmless on Linux.
fn tunnel_drop() {
let mut tun_device_manager = TunDeviceManager::new(1280).unwrap();
let mut tun_device_manager = TunDeviceManager::new(1280, 1).unwrap();
// Each cycle takes about half a second, so this will take a fair bit to run.
for _ in 0..50 {

View File

@@ -3,16 +3,18 @@
use crate::FIREZONE_MARK;
use anyhow::{anyhow, Context as _, Result};
use firezone_logging::std_dyn_err;
use futures::task::AtomicWaker;
use futures::TryStreamExt;
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_packet::{IpPacket, IpPacketBuf};
use libc::{
close, fcntl, makedev, mknod, open, EEXIST, ENOENT, F_GETFL, F_SETFL, O_NONBLOCK, O_RDWR,
S_IFCHR,
fcntl, makedev, mknod, open, EEXIST, ENOENT, F_GETFL, F_SETFL, O_NONBLOCK, O_RDWR, S_IFCHR,
};
use netlink_packet_route::route::{RouteProtocol, RouteScope};
use netlink_packet_route::rule::RuleAction;
use rtnetlink::{new_connection, Error::NetlinkError, Handle, RouteAddRequest, RuleAddRequest};
use std::path::Path;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{
collections::HashSet,
@@ -21,13 +23,12 @@ use std::{
use std::{
ffi::CStr,
fs, io,
os::{
fd::{AsRawFd, RawFd},
unix::fs::PermissionsExt,
},
os::{fd::RawFd, unix::fs::PermissionsExt},
};
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
use tun::ioctl;
use tun::unix::TunFd;
const TUNSETIFF: libc::c_ulong = 0x4004_54ca;
const TUN_DEV_MAJOR: u32 = 10;
@@ -40,6 +41,7 @@ const FIREZONE_TABLE: u32 = 0x2021_fd00;
/// For lack of a better name
pub struct TunDeviceManager {
mtu: u32,
num_threads: usize,
connection: Connection,
routes: HashSet<IpNetwork>,
}
@@ -61,7 +63,7 @@ impl TunDeviceManager {
/// Creates a new managed tunnel device.
///
/// Panics if called without a Tokio runtime.
pub fn new(mtu: usize) -> Result<Self> {
pub fn new(mtu: usize, num_threads: usize) -> Result<Self> {
let (cxn, handle, _) = new_connection()?;
let task = tokio::spawn(cxn);
let connection = Connection { handle, task };
@@ -70,11 +72,12 @@ impl TunDeviceManager {
connection,
routes: Default::default(),
mtu: mtu as u32,
num_threads,
})
}
pub fn make_tun(&mut self) -> Result<Tun> {
Ok(Tun::new()?)
Ok(Tun::new(self.num_threads)?)
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -284,62 +287,100 @@ async fn remove_route(route: &IpNetwork, idx: u32, handle: &Handle) {
#[derive(Debug)]
pub struct Tun {
fd: AsyncFd<RawFd>,
outbound_tx: flume::Sender<IpPacket>,
outbound_capacity_waker: Arc<AtomicWaker>,
inbound_rx: mpsc::Receiver<IpPacket>,
}
impl Tun {
pub fn new() -> io::Result<Self> {
pub fn new(num_threads: usize) -> io::Result<Self> {
create_tun_device()?;
let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
-1 => return Err(get_last_error()),
fd => fd,
};
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
// Safety: We just opened the file descriptor.
unsafe {
ioctl::exec(
fd,
TUNSETIFF,
&mut ioctl::Request::<ioctl::SetTunFlagsPayload>::new(TunDeviceManager::IFACE_NAME),
)?;
for n in 0..num_threads {
let fd = AsyncFd::new(open_tun()?)?;
let outbound_rx = outbound_rx.clone().into_stream();
let inbound_tx = inbound_tx.clone();
let outbound_capacity_waker = outbound_capacity_waker.clone();
std::thread::Builder::new()
.name(format!("TUN send/recv {n}/{num_threads}"))
.spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?
.block_on(tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx,
outbound_capacity_waker,
read,
write,
));
io::Result::Ok(())
})
.map_err(io::Error::other)?;
}
set_non_blocking(fd)?;
// Safety: We just opened the fd.
unsafe { Self::from_fd(fd) }
}
/// Create a new [`Tun`] from a raw file descriptor.
///
/// # Safety
///
/// The file descriptor must be open.
unsafe fn from_fd(fd: RawFd) -> io::Result<Self> {
Ok(Tun {
fd: AsyncFd::new(fd)?,
Ok(Self {
outbound_tx,
outbound_capacity_waker,
inbound_rx,
})
}
}
impl Drop for Tun {
fn drop(&mut self) {
unsafe { close(self.fd.as_raw_fd()) };
fn open_tun() -> Result<TunFd, io::Error> {
let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
-1 => return Err(get_last_error()),
fd => fd,
};
unsafe {
ioctl::exec(
fd,
TUNSETIFF,
&mut ioctl::Request::<ioctl::SetTunFlagsPayload>::new(TunDeviceManager::IFACE_NAME),
)?;
}
set_non_blocking(fd)?;
// Safety: We are not closing the FD.
let fd = unsafe { TunFd::new(fd) };
Ok(fd)
}
impl tun::Tun for Tun {
fn write4(&self, buf: &[u8]) -> io::Result<usize> {
write(self.fd.as_raw_fd(), buf)
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if self.outbound_tx.is_full() {
self.outbound_capacity_waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(Ok(()))
}
fn write6(&self, buf: &[u8]) -> io::Result<usize> {
write(self.fd.as_raw_fd(), buf)
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
self.outbound_tx
.try_send(packet)
.map_err(io::Error::other)?;
Ok(())
}
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
tun::unix::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
fn poll_recv_many(
&mut self,
cx: &mut Context,
buf: &mut Vec<IpPacket>,
max: usize,
) -> Poll<usize> {
self.inbound_rx.poll_recv_many(cx, buf, max)
}
fn name(&self) -> &str {
@@ -389,7 +430,9 @@ fn create_tun_device() -> io::Result<()> {
}
/// Read from the given file descriptor in the buffer.
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
let dst = dst.buf();
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
-1 => Err(io::Error::last_os_error()),
@@ -397,8 +440,10 @@ fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
}
}
/// Write the buffer to the given file descriptor.
fn write(fd: RawFd, buf: &[u8]) -> io::Result<usize> {
/// Write the packet to the given file descriptor.
fn write(fd: RawFd, packet: &IpPacket) -> io::Result<usize> {
let buf = packet.packet();
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } {
-1 => Err(io::Error::last_os_error()),

View File

@@ -3,6 +3,7 @@ use crate::TUNNEL_NAME;
use anyhow::{Context as _, Result};
use firezone_logging::{anyhow_dyn_err, std_dyn_err};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_packet::{IpPacket, IpPacketBuf};
use ring::digest;
use std::{
collections::HashSet,
@@ -12,7 +13,7 @@ use std::{
path::{Path, PathBuf},
process::{Command, Stdio},
sync::Arc,
task::{ready, Context, Poll},
task::{Context, Poll},
};
use tokio::sync::mpsc;
use windows::Win32::{
@@ -47,7 +48,7 @@ pub struct TunDeviceManager {
impl TunDeviceManager {
#[expect(clippy::unnecessary_wraps, reason = "Fallible on Linux")]
pub fn new(mtu: usize) -> Result<Self> {
pub fn new(mtu: usize, _num_threads: usize) -> Result<Self> {
Ok(Self {
iface_idx: None,
routes: HashSet::default(),
@@ -190,7 +191,7 @@ pub struct Tun {
/// The index of our network adapter, we can use this when asking Windows to add / remove routes / DNS rules
/// It's stable across app restarts and I'm assuming across system reboots too.
iface_idx: u32,
packet_rx: mpsc::Receiver<wintun::Packet>,
inbound_rx: mpsc::Receiver<IpPacket>,
recv_thread: Option<std::thread::JoinHandle<()>>,
session: Arc<wintun::Session>,
}
@@ -198,10 +199,10 @@ pub struct Tun {
impl Drop for Tun {
fn drop(&mut self) {
tracing::debug!(
channel_capacity = self.packet_rx.capacity(),
channel_capacity = self.inbound_rx.capacity(),
"Shutting down packet channel..."
);
self.packet_rx.close(); // This avoids a deadlock when we join the worker thread, see PR 5571
self.inbound_rx.close(); // This avoids a deadlock when we join the worker thread, see PR 5571
if let Err(error) = self.session.shutdown() {
tracing::error!(error = std_dyn_err(&error), "wintun::Session::shutdown");
}
@@ -241,16 +242,14 @@ impl Tun {
.start_session(RING_BUFFER_SIZE)
.context("Failed to start session")?,
);
// 4 is a nice power of two. Wintun already queues packets for us, so we don't
// need much capacity here.
let (packet_tx, packet_rx) = mpsc::channel(4);
let recv_thread = start_recv_thread(packet_tx, Arc::clone(&session))
let (inbound_tx, inbound_rx) = mpsc::channel(1000); // We want to be able to batch-receive from this.
let recv_thread = start_recv_thread(inbound_tx, Arc::clone(&session))
.context("Failed to start recv thread")?;
Ok(Self {
iface_idx,
recv_thread: Some(recv_thread),
packet_rx,
inbound_rx,
session: Arc::clone(&session),
})
}
@@ -258,72 +257,60 @@ impl Tun {
pub fn iface_idx(&self) -> u32 {
self.iface_idx
}
// Moves packets from the Internet towards the user
fn write(&self, bytes: &[u8]) -> io::Result<usize> {
let len = bytes
.len()
.try_into()
.map_err(|_| io::Error::other("Packet too large; length does not fit into u16"))?;
let Ok(mut pkt) = self.session.allocate_send_packet(len) else {
// Ring buffer is full, just drop the packet since we're at the IP layer
return Ok(0);
};
pkt.bytes_mut().copy_from_slice(bytes);
// `send_packet` cannot fail to enqueue the packet, since we already allocated
// space in the ring buffer.
self.session.send_packet(pkt);
Ok(bytes.len())
}
}
impl tun::Tun for Tun {
// Moves packets from the user towards the Internet
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
let pkt = ready!(self.packet_rx.poll_recv(cx));
match pkt {
Some(pkt) => {
let bytes = pkt.bytes();
let len = bytes.len();
if len > buf.len() {
// This shouldn't happen now that we set IPv4 and IPv6 MTU
// If it does, something is wrong.
tracing::warn!("Packet is too long to read ({len} bytes)");
return Poll::Ready(Ok(0));
}
buf[0..len].copy_from_slice(bytes);
Poll::Ready(Ok(len))
}
None => {
tracing::error!("error receiving packet from mpsc channel");
Poll::Ready(Err(std::io::ErrorKind::Other.into()))
}
}
/// Receive a batch of packets up to `max`.
fn poll_recv_many(
&mut self,
cx: &mut Context,
buf: &mut Vec<IpPacket>,
max: usize,
) -> Poll<usize> {
self.inbound_rx.poll_recv_many(cx, buf, max)
}
fn name(&self) -> &str {
TUNNEL_NAME
}
fn write4(&self, bytes: &[u8]) -> io::Result<usize> {
self.write(bytes)
/// Check if more packets can be sent.
fn poll_send_ready(&mut self, _: &mut Context) -> Poll<io::Result<()>> {
// TODO: Figure out how we can do readiness checks on `wintun`.
Poll::Ready(Ok(()))
}
fn write6(&self, bytes: &[u8]) -> io::Result<usize> {
self.write(bytes)
/// Send a packet.
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
let bytes = packet.packet();
let len = bytes
.len()
.try_into()
.map_err(|_| io::Error::other("Packet too large; length does not fit into u16"))?;
let mut pkt = self
.session
.allocate_send_packet(len)
.map_err(io::Error::other)?;
pkt.bytes_mut().copy_from_slice(bytes);
// `send_packet` cannot fail to enqueue the packet, since we already allocated
// space in the ring buffer.
self.session.send_packet(pkt);
Ok(())
}
}
// Moves packets from the user towards the Internet
fn start_recv_thread(
packet_tx: mpsc::Sender<wintun::Packet>,
packet_tx: mpsc::Sender<IpPacket>,
session: Arc<wintun::Session>,
) -> io::Result<std::thread::JoinHandle<()>> {
std::thread::Builder::new()
.name("Firezone wintun worker".into())
.name("TUN recv".into())
.spawn(move || loop {
let pkt = match session.receive_blocking() {
Ok(pkt) => pkt,
@@ -339,6 +326,26 @@ fn start_recv_thread(
}
};
let mut ip_packet_buf = IpPacketBuf::new();
let src = pkt.bytes();
let dst = ip_packet_buf.buf();
if src.len() > dst.len() {
tracing::warn!(len = %src.len(), "Received too large packet");
continue;
}
dst[..src.len()].copy_from_slice(src);
let pkt = match IpPacket::new(ip_packet_buf, src.len()) {
Ok(pkt) => pkt,
Err(e) => {
tracing::debug!("Failed to parse IP packet: {e:#}");
continue;
}
};
// Use `blocking_send` so that if connlib is behind by a few packets,
// Wintun will queue up new packets in its ring buffer while we
// wait for our MPSC channel to clear.
@@ -352,7 +359,7 @@ fn start_recv_thread(
);
break;
}
}
};
})
}

View File

@@ -17,6 +17,9 @@ connlib-client-shared = { workspace = true }
connlib-model = { workspace = true }
firezone-logging = { workspace = true }
firezone-telemetry = { workspace = true }
flume = { workspace = true }
futures = { workspace = true }
ip-packet = { workspace = true }
ip_network = { workspace = true }
jni = { workspace = true, features = ["invocation"] }
libc = { workspace = true }
@@ -27,7 +30,7 @@ secrecy = { workspace = true }
serde_json = { workspace = true }
socket-factory = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio = { workspace = true, features = ["rt-multi-thread", "sync"] }
tracing = { workspace = true, features = ["std", "attributes"] }
tracing-appender = { workspace = true }
tracing-subscriber = { workspace = true }

View File

@@ -1,39 +1,51 @@
use futures::task::AtomicWaker;
use ip_packet::{IpPacket, IpPacketBuf};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{
io,
os::fd::{AsRawFd, RawFd},
};
use std::{io, os::fd::RawFd};
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
use tun::ioctl;
use tun::unix::TunFd;
#[derive(Debug)]
pub struct Tun {
fd: AsyncFd<RawFd>,
name: String,
}
impl Drop for Tun {
fn drop(&mut self) {
unsafe { libc::close(self.fd.as_raw_fd()) };
}
outbound_tx: flume::Sender<IpPacket>,
outbound_capacity_waker: Arc<AtomicWaker>,
inbound_rx: mpsc::Receiver<IpPacket>,
}
impl tun::Tun for Tun {
fn write4(&self, src: &[u8]) -> std::io::Result<usize> {
write(self.fd.as_raw_fd(), src)
}
fn write6(&self, src: &[u8]) -> std::io::Result<usize> {
write(self.fd.as_raw_fd(), src)
}
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
tun::unix::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
}
fn name(&self) -> &str {
self.name.as_str()
}
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if self.outbound_tx.is_full() {
self.outbound_capacity_waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(Ok(()))
}
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
self.outbound_tx
.try_send(packet)
.map_err(io::Error::other)?;
Ok(())
}
fn poll_recv_many(
&mut self,
cx: &mut Context,
buf: &mut Vec<IpPacket>,
max: usize,
) -> Poll<usize> {
self.inbound_rx.poll_recv_many(cx, buf, max)
}
}
impl Tun {
@@ -41,13 +53,49 @@ impl Tun {
///
/// # Safety
///
/// The file descriptor must be open.
/// - The file descriptor must be open.
/// - The file descriptor must not get closed by anyone else.
pub unsafe fn from_fd(fd: RawFd) -> io::Result<Self> {
let name = interface_name(fd)?;
// Safety: We are forwarding the safety requirements to the caller.
let fd = unsafe { TunFd::new(fd) };
let fd = AsyncFd::new(fd)?;
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
// TODO: Test whether we can set `IFF_MULTI_QUEUE` on Android devices.
std::thread::Builder::new()
.name("TUN send/recv".to_owned())
.spawn({
let outbound_capacity_waker = outbound_capacity_waker.clone();
|| {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?
.block_on(tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx.into_stream(),
outbound_capacity_waker,
read,
write,
));
io::Result::Ok(())
}
})
.map_err(io::Error::other)?;
Ok(Tun {
fd: AsyncFd::new(fd)?,
name,
outbound_tx,
inbound_rx,
outbound_capacity_waker,
})
}
}
@@ -67,7 +115,9 @@ unsafe fn interface_name(fd: RawFd) -> io::Result<String> {
}
/// Read from the given file descriptor in the buffer.
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
let dst = dst.buf();
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
-1 => Err(io::Error::last_os_error()),
@@ -75,10 +125,12 @@ fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
}
}
/// Write the buffer to the given file descriptor.
fn write(fd: RawFd, buf: &[u8]) -> io::Result<usize> {
/// Write the packet to the given file descriptor.
fn write(fd: RawFd, packet: &IpPacket) -> io::Result<usize> {
let buf = packet.packet();
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::write(fd.as_raw_fd(), buf.as_ptr() as _, buf.len() as _) } {
match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}

View File

@@ -15,6 +15,9 @@ connlib-client-shared = { workspace = true }
connlib-model = { workspace = true }
firezone-logging = { workspace = true }
firezone-telemetry = { workspace = true }
flume = { workspace = true }
futures = { workspace = true }
ip-packet = { workspace = true }
ip_network = { workspace = true }
libc = { workspace = true }
phoenix-channel = { workspace = true }
@@ -23,7 +26,7 @@ secrecy = { workspace = true }
serde_json = { workspace = true }
socket-factory = { workspace = true }
swift-bridge = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio = { workspace = true, features = ["rt-multi-thread", "sync"] }
tracing = { workspace = true }
tracing-appender = { workspace = true }
tracing-subscriber = { workspace = true }

View File

@@ -1,76 +1,96 @@
use futures::task::AtomicWaker;
use ip_packet::{IpPacket, IpPacketBuf};
use libc::{fcntl, iovec, msghdr, recvmsg, AF_INET, AF_INET6, F_GETFL, F_SETFL, O_NONBLOCK};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{
io,
os::fd::{AsRawFd as _, RawFd},
};
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
#[derive(Debug)]
pub struct Tun {
name: String,
fd: AsyncFd<RawFd>,
outbound_capacity_waker: Arc<AtomicWaker>,
outbound_tx: flume::Sender<IpPacket>,
inbound_rx: mpsc::Receiver<IpPacket>,
}
impl Tun {
pub fn new() -> io::Result<Self> {
let fd = search_for_tun_fd()?;
set_non_blocking(fd)?;
let name = name(fd)?;
Ok(Self {
let fd = AsyncFd::new(fd)?;
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
std::thread::Builder::new()
.name("TUN send/recv".to_owned())
.spawn({
let outbound_capacity_waker = outbound_capacity_waker.clone();
|| {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?
.block_on(tun::unix::send_recv_tun(
fd,
inbound_tx,
outbound_rx.into_stream(),
outbound_capacity_waker,
read,
write,
));
io::Result::Ok(())
}
})
.map_err(io::Error::other)?;
Ok(Tun {
name,
fd: AsyncFd::new(fd)?,
outbound_tx,
inbound_rx,
outbound_capacity_waker,
})
}
fn write(&self, src: &[u8], af: u8) -> io::Result<usize> {
let mut hdr = [0, 0, 0, af];
let mut iov = [
iovec {
iov_base: hdr.as_mut_ptr() as _,
iov_len: hdr.len(),
},
iovec {
iov_base: src.as_ptr() as _,
iov_len: src.len(),
},
];
let msg_hdr = msghdr {
msg_name: std::ptr::null_mut(),
msg_namelen: 0,
msg_iov: &mut iov[0],
msg_iovlen: iov.len() as _,
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
match unsafe { libc::sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
}
impl tun::Tun for Tun {
fn write4(&self, src: &[u8]) -> io::Result<usize> {
self.write(src, AF_INET as u8)
}
fn write6(&self, src: &[u8]) -> io::Result<usize> {
self.write(src, AF_INET6 as u8)
}
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
tun::unix::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
}
fn name(&self) -> &str {
self.name.as_str()
}
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if self.outbound_tx.is_full() {
self.outbound_capacity_waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(Ok(()))
}
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
self.outbound_tx
.try_send(packet)
.map_err(io::Error::other)?;
Ok(())
}
fn poll_recv_many(
&mut self,
cx: &mut Context,
buf: &mut Vec<IpPacket>,
max: usize,
) -> Poll<usize> {
self.inbound_rx.poll_recv_many(cx, buf, max)
}
}
fn get_last_error() -> io::Error {
@@ -87,7 +107,9 @@ fn set_non_blocking(fd: RawFd) -> io::Result<()> {
}
}
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
let dst = dst.buf();
let mut hdr = [0u8; 4];
let mut iov = [
@@ -119,6 +141,41 @@ fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
}
}
fn write(fd: RawFd, src: &IpPacket) -> io::Result<usize> {
let af = match src {
IpPacket::Ipv4(_) => AF_INET,
IpPacket::Ipv6(_) => AF_INET6,
};
let src = src.packet();
let mut hdr = [0, 0, 0, af];
let mut iov = [
iovec {
iov_base: hdr.as_mut_ptr() as _,
iov_len: hdr.len(),
},
iovec {
iov_base: src.as_ptr() as _,
iov_len: src.len(),
},
];
let msg_hdr = msghdr {
msg_name: std::ptr::null_mut(),
msg_namelen: 0,
msg_iov: &mut iov[0],
msg_iovlen: iov.len() as _,
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
match unsafe { libc::sendmsg(fd.as_raw_fd(), &msg_hdr, 0) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn name(fd: RawFd) -> io::Result<String> {
use libc::{getsockopt, socklen_t, IF_NAMESIZE, SYSPROTO_CONTROL, UTUN_OPT_IFNAME};

View File

@@ -38,7 +38,6 @@ socket-factory = { workspace = true }
socket2 = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true, features = ["attributes"] }
tun = { workspace = true }
uuid = { workspace = true, features = ["std", "v4"] }

View File

@@ -1,7 +1,7 @@
use domain::base::iana::Rcode;
use domain::base::{Message, ParsedName, Rtype};
use domain::rdata::AllRecordData;
use ip_packet::{IpPacket, IpPacketBuf};
use ip_packet::IpPacket;
use itertools::Itertools;
use std::io;
use std::task::{Context, Poll, Waker};
@@ -31,47 +31,46 @@ impl Device {
}
}
pub(crate) fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<IpPacket>> {
pub(crate) fn poll_read_many(
&mut self,
cx: &mut Context<'_>,
buf: &mut Vec<IpPacket>,
max: usize,
) -> Poll<usize> {
let Some(tun) = self.tun.as_mut() else {
self.waker = Some(cx.waker().clone());
return Poll::Pending;
};
let mut ip_packet = IpPacketBuf::new();
let n = std::task::ready!(tun.poll_read(ip_packet.buf(), cx))?;
let n = std::task::ready!(tun.poll_recv_many(cx, buf, max));
if n == 0 {
self.tun = None;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"device is closed",
)));
}
let packet = IpPacket::new(ip_packet, n).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Failed to parse IP packet: {e:#}"),
)
})?;
if tracing::event_enabled!(target: "wire::dns::qry", Level::TRACE) {
if let Some((qtype, qname, qid)) = parse_dns_query(&packet) {
tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
for packet in &buf[..n] {
if tracing::event_enabled!(target: "wire::dns::qry", Level::TRACE) {
if let Some((qtype, qname, qid)) = parse_dns_query(packet) {
tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
}
}
if packet.is_fz_p2p_control() {
tracing::warn!("Packet matches heuristics of FZ-internal p2p control protocol");
}
tracing::trace!(target: "wire::dev::recv", dst = %packet.destination(), src = %packet.source(), bytes = %packet.packet().len());
}
if packet.is_fz_p2p_control() {
tracing::warn!("Packet matches heuristics of FZ-internal p2p control protocol");
}
tracing::trace!(target: "wire::dev::recv", dst = %packet.destination(), src = %packet.source(), bytes = %packet.packet().len());
Poll::Ready(Ok(packet))
Poll::Ready(n)
}
pub fn write(&self, packet: IpPacket) -> io::Result<usize> {
pub fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let Some(tun) = self.tun.as_mut() else {
self.waker = Some(cx.waker().clone());
return Poll::Pending;
};
tun.poll_send_ready(cx)
}
pub fn send(&mut self, packet: IpPacket) -> io::Result<()> {
if tracing::event_enabled!(target: "wire::dns::res", Level::TRACE) {
if let Some((qtype, qname, records, rcode, qid)) = parse_dns_response(&packet) {
tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
@@ -85,18 +84,17 @@ impl Device {
"FZ p2p control protocol packets should never leave `connlib`"
);
match packet {
IpPacket::Ipv4(msg) => self.tun()?.write4(msg.packet()),
IpPacket::Ipv6(msg) => self.tun()?.write6(msg.packet()),
}
self.tun()?.send(packet)?;
Ok(())
}
fn tun(&self) -> io::Result<&dyn Tun> {
fn tun(&mut self) -> io::Result<&mut dyn Tun> {
Ok(self
.tun
.as_ref()
.as_mut()
.ok_or_else(io_error_not_initialized)?
.as_ref())
.as_mut())
}
}

View File

@@ -2,11 +2,7 @@ mod gso_queue;
use crate::{device_channel::Device, dns, sockets::Sockets};
use domain::base::Message;
use firezone_logging::{err_with_src, telemetry_event, telemetry_span};
use futures::{
future::{self, Either},
stream, Stream, StreamExt,
};
use firezone_logging::{telemetry_event, telemetry_span};
use futures_bounded::FuturesTupleSet;
use futures_util::FutureExt as _;
use gso_queue::GsoQueue;
@@ -21,11 +17,7 @@ use std::{
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::mpsc,
};
use tokio_util::sync::PollSender;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::Instrument;
use tun::Tun;
@@ -33,7 +25,7 @@ use tun::Tun;
///
/// Reading IP packets from the channel in batches allows us to process (i.e. encrypt) them as a batch.
/// UDP datagrams of the same size and destination can then be sent in a single syscall using GSO.
const MAX_INBOUND_PACKET_BATCH: usize = 50;
const MAX_INBOUND_PACKET_BATCH: usize = 100;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
/// Bundles together all side-effects that connlib needs to have access to.
@@ -49,10 +41,8 @@ pub struct Io {
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
tun_tx: mpsc::Sender<Box<dyn Tun>>,
tun: Device,
outbound_packet_buffer: VecDeque<IpPacket>,
outbound_packet_tx: PollSender<IpPacket>,
inbound_packet_rx: mpsc::Receiver<IpPacket>,
}
#[derive(Debug)]
@@ -86,7 +76,6 @@ pub enum Input<D, I> {
}
const DNS_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
const IP_CHANNEL_SIZE: usize = 1000;
impl Io {
/// Creates a new I/O abstraction
@@ -99,32 +88,15 @@ impl Io {
let mut sockets = Sockets::default();
sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context.
let (inbound_packet_tx, inbound_packet_rx) = mpsc::channel(IP_CHANNEL_SIZE);
let (outbound_packet_tx, outbound_packet_rx) = mpsc::channel(IP_CHANNEL_SIZE);
let (tun_tx, tun_rx) = mpsc::channel(10);
std::thread::Builder::new()
.name("connlib-tun-send-recv".to_string())
.spawn(|| {
futures::executor::block_on(tun_send_recv(
tun_rx,
outbound_packet_rx,
inbound_packet_tx,
))
})
.expect("Failed to spawn tun_send_recv thread");
Self {
tun_tx,
outbound_packet_buffer: VecDeque::with_capacity(10), // It is unlikely that we process more than 10 packets after 1 GRO call.
outbound_packet_tx: PollSender::new(outbound_packet_tx),
inbound_packet_rx,
timeout: None,
sockets,
tcp_socket_factory,
udp_socket_factory,
dns_queries: FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000),
gso_queue: GsoQueue::new(),
tun: Device::new(),
}
}
@@ -151,8 +123,8 @@ impl Io {
}
if let Poll::Ready(num_packets) =
self.inbound_packet_rx
.poll_recv_many(cx, &mut buffers.ip, MAX_INBOUND_PACKET_BATCH)
self.tun
.poll_read_many(cx, &mut buffers.ip, MAX_INBOUND_PACKET_BATCH)
{
return Poll::Ready(Ok(Input::Device(buffers.ip.drain(..num_packets))));
}
@@ -209,11 +181,8 @@ impl Io {
}
loop {
// First, acquire a slot in the channel.
ready!(self
.outbound_packet_tx
.poll_reserve(cx)
.map_err(|_| io::ErrorKind::BrokenPipe)?);
// First, check if we can send more packets.
ready!(self.tun.poll_send_ready(cx))?;
// Second, check if we have any buffer packets.
let Some(packet) = self.outbound_packet_buffer.pop_front() else {
@@ -221,20 +190,14 @@ impl Io {
};
// Third, send the packet.
self.outbound_packet_tx
.send_item(packet)
.map_err(|_| io::ErrorKind::BrokenPipe)?;
self.tun.send(packet)?;
}
Poll::Ready(Ok(()))
}
pub fn set_tun(&mut self, tun: Box<dyn Tun>) {
// If we can't set a new TUN device, shut down connlib.
self.tun_tx
.try_send(tun)
.expect("Channel to set new TUN device should always have capacity");
self.tun.set_tun(tun);
}
pub fn send_tun(&mut self, packet: IpPacket) {
@@ -350,82 +313,6 @@ impl Io {
}
}
async fn tun_send_recv(
mut tun_rx: mpsc::Receiver<Box<dyn Tun>>,
mut outbound_packet_rx: mpsc::Receiver<IpPacket>,
inbound_packet_tx: mpsc::Sender<IpPacket>,
) {
let mut device = Device::new();
let mut command_stream = stream::select_all([
new_tun_stream(&mut tun_rx),
outgoing_packet_stream(&mut outbound_packet_rx),
]);
loop {
match future::select(
command_stream.next(),
future::poll_fn(|cx| device.poll_read(cx)),
)
.await
{
Either::Left((Some(Command::SendPacket(p)), _)) => {
if let Err(e) = device.write(p) {
tracing::debug!("Failed to write TUN packet: {}", err_with_src(&e));
};
}
Either::Left((Some(Command::UpdateTun(tun)), _)) => {
device.set_tun(tun);
}
Either::Left((None, _)) => {
tracing::debug!("Command stream closed");
return;
}
Either::Right((Ok(p), _)) => {
if inbound_packet_tx.send(p).await.is_err() {
tracing::debug!("Inbound packet channel closed");
return;
};
}
Either::Right((Err(e), _)) => {
tracing::debug!(
"Failed to read packet from TUN device: {}",
err_with_src(&e)
);
}
};
}
}
#[expect(
clippy::large_enum_variant,
reason = "We purposely don't want to allocate each IP packet."
)]
enum Command {
UpdateTun(Box<dyn Tun>),
SendPacket(IpPacket),
}
fn new_tun_stream(
tun_rx: &mut mpsc::Receiver<Box<dyn Tun>>,
) -> Pin<Box<dyn Stream<Item = Command> + '_>> {
Box::pin(stream::poll_fn(|cx| {
tun_rx
.poll_recv(cx)
.map(|maybe_t| maybe_t.map(Command::UpdateTun))
}))
}
fn outgoing_packet_stream(
outbound_packet_rx: &mut mpsc::Receiver<IpPacket>,
) -> Pin<Box<dyn Stream<Item = Command> + '_>> {
Box::pin(stream::poll_fn(|cx| {
outbound_packet_rx
.poll_recv(cx)
.map(|maybe_p| maybe_p.map(Command::SendPacket))
}))
}
fn is_max_wg_packet_size(d: &DatagramIn) -> bool {
let len = d.packet.len();
if len > MAX_DATAGRAM_PAYLOAD {
@@ -444,14 +331,6 @@ mod tests {
use super::*;
#[test]
fn max_ip_channel_size_is_reasonable() {
let one_ip_packet = std::mem::size_of::<IpPacket>();
let max_channel_size = IP_CHANNEL_SIZE * one_ip_packet;
assert_eq!(max_channel_size, 1_360_000); // 1.36MB is fine, we only have 2 of these channels, meaning less than 3MB additional buffer in total.
}
#[tokio::test]
async fn timer_is_reset_after_it_fires() {
let now = Instant::now();
@@ -460,6 +339,7 @@ mod tests {
Arc::new(|_| Err(io::Error::other("not implemented"))),
Arc::new(|_| Err(io::Error::other("not implemented"))),
);
io.set_tun(Box::new(DummyTun));
io.reset_timeout(now + Duration::from_secs(1));
@@ -494,4 +374,29 @@ mod tests {
udp4: Vec::new(),
udp6: Vec::new(),
};
struct DummyTun;
impl Tun for DummyTun {
fn poll_send_ready(&mut self, _: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn send(&mut self, _: IpPacket) -> io::Result<()> {
Ok(())
}
fn poll_recv_many(
&mut self,
_: &mut Context,
_: &mut Vec<IpPacket>,
_: usize,
) -> Poll<usize> {
Poll::Pending
}
fn name(&self) -> &str {
"dummy"
}
}
}

View File

@@ -10,7 +10,6 @@ use anyhow::{Context as _, Result};
use domain::base::{iana::Rcode, MessageBuilder};
use firezone_bin_shared::TunDeviceManager;
use ip_network::Ipv4Network;
use ip_packet::{IpPacket, IpPacketBuf};
use tokio::task::JoinSet;
use tun::Tun;
@@ -24,7 +23,7 @@ async fn smoke() {
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
let mut device_manager = TunDeviceManager::new(1280).unwrap();
let mut device_manager = TunDeviceManager::new(1280, 1).unwrap();
let tun = device_manager.make_tun().unwrap();
device_manager.set_ips(ipv4, ipv6).await.unwrap();
device_manager
@@ -100,11 +99,10 @@ impl Eventloop {
fn poll(&mut self, cx: &mut Context) -> Poll<()> {
loop {
ready!(self.tun.poll_send_ready(cx)).unwrap();
if let Some(packet) = self.dns_server.poll_outbound() {
match packet {
IpPacket::Ipv4(v4) => self.tun.write4(v4.packet()).unwrap(),
IpPacket::Ipv6(v6) => self.tun.write6(v6.packet()).unwrap(),
};
self.tun.send(packet).unwrap();
continue;
}
@@ -120,12 +118,12 @@ impl Eventloop {
continue;
}
let mut packet_buf = IpPacketBuf::default();
let num_read = ready!(self.tun.poll_read(packet_buf.buf(), cx)).unwrap();
let packet = IpPacket::new(packet_buf, num_read).unwrap();
let mut buf = Vec::with_capacity(1);
ready!(self.tun.poll_recv_many(cx, &mut buf, 1));
let ip_packet = buf.remove(0);
if self.dns_server.accepts(&packet) {
self.dns_server.handle_inbound(packet);
if self.dns_server.accepts(&ip_packet) {
self.dns_server.handle_inbound(ip_packet);
self.dns_server.handle_timeout(Instant::now());
}
}

View File

@@ -33,7 +33,7 @@ serde = { workspace = true, features = ["std", "derive"] }
snownet = { workspace = true }
socket-factory = { workspace = true }
static_assertions = { workspace = true }
tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread", "fs", "signal"] }
tokio = { workspace = true, features = ["sync", "macros", "fs", "signal", "rt"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
url = { workspace = true }

View File

@@ -47,7 +47,10 @@ fn main() {
);
}
let runtime = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create tokio runtime");
match runtime.block_on(try_main(cli, &mut telemetry)) {
Ok(()) => runtime.block_on(telemetry.stop()),
@@ -79,7 +82,7 @@ async fn try_main(cli: Cli, telemetry: &mut Telemetry) -> Result<()> {
cli.firezone_name,
)?;
let task = tokio::spawn(run(login)).err_into();
let task = tokio::spawn(run(login, cli.tun_threads)).err_into();
let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new));
@@ -122,7 +125,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
Ok(id)
}
async fn run(login: LoginUrl<PublicKeyParam>) -> Result<Infallible> {
async fn run(login: LoginUrl<PublicKeyParam>, num_tun_threads: usize) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory));
let portal = PhoenixChannel::disconnected(
Secret::new(login),
@@ -138,7 +141,7 @@ async fn run(login: LoginUrl<PublicKeyParam>) -> Result<Infallible> {
)?;
let (sender, receiver) = mpsc::channel::<Interface>(10);
let mut tun_device_manager = TunDeviceManager::new(ip_packet::PACKET_SIZE)?;
let mut tun_device_manager = TunDeviceManager::new(ip_packet::PACKET_SIZE, num_tun_threads)?;
let tun = tun_device_manager.make_tun()?;
tunnel.set_tun(Box::new(tun));
@@ -203,6 +206,10 @@ struct Cli {
/// Identifier generated by the portal to identify and display the device.
#[arg(short = 'i', long, env = "FIREZONE_ID")]
pub firezone_id: Option<String>,
/// How many threads to use for reading and writing to the TUN device.
#[arg(long, env = "FIREZONE_NUM_TUN_THREADS", default_value_t = 2)]
tun_threads: usize,
}
impl Cli {

View File

@@ -31,7 +31,7 @@ smbios-lib = { workspace = true }
thiserror = { workspace = true }
# This actually relies on many other features in Tokio, so this will probably
# fail to build outside the workspace. <https://github.com/firezone/firezone/pull/4328#discussion_r1540342142>
tokio = { workspace = true, features = ["macros", "signal", "process", "time", "rt-multi-thread", "fs"] }
tokio = { workspace = true, features = ["macros", "signal", "process", "time", "fs", "rt"] }
tokio-stream = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }
tracing = { workspace = true }

View File

@@ -213,7 +213,7 @@ mod tests {
let rt = tokio::runtime::Runtime::new().unwrap();
let mut tun_dev_manager = firezone_bin_shared::TunDeviceManager::new(1280).unwrap();
let mut tun_dev_manager = firezone_bin_shared::TunDeviceManager::new(1280, 1).unwrap(); // Note: num_threads (`1`) is unused on windows.
let _tun = tun_dev_manager.make_tun().unwrap();
rt.block_on(async {

View File

@@ -177,7 +177,9 @@ fn run_debug_ipc_service(cli: Cli) -> Result<()> {
if !platform::elevation_check()? {
bail!("IPC service failed its elevation check, try running as admin / root");
}
let rt = tokio::runtime::Runtime::new()?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let _guard = rt.enter();
let mut signals = signals::Terminate::new()?;
@@ -202,7 +204,9 @@ fn run_smoke_test() -> Result<()> {
if !platform::elevation_check()? {
bail!("IPC service failed its elevation check, try running as admin / root");
}
let rt = tokio::runtime::Runtime::new()?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let _guard = rt.enter();
let mut dns_controller = DnsController {
dns_control_method: Default::default(),
@@ -325,7 +329,7 @@ impl<'a> Handler<'a> {
.next_client_split()
.await
.context("Failed to wait for incoming IPC connection from a GUI")?;
let tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE)?;
let tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE, crate::NUM_TUN_THREADS)?;
Ok(Self {
dns_controller,

View File

@@ -10,7 +10,9 @@ pub(crate) fn run_ipc_service(cli: CliCommon) -> Result<()> {
if !elevation_check()? {
bail!("IPC service failed its elevation check, try running as admin / root");
}
let rt = tokio::runtime::Runtime::new()?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let _guard = rt.enter();
let mut signals = signals::Terminate::new()?;

View File

@@ -155,7 +155,9 @@ fn fallible_service_run(
bail!("IPC service failed its elevation check, try running as admin / root");
}
let rt = tokio::runtime::Runtime::new()?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let event_handler = move |control_event| -> ServiceControlHandlerResult {

View File

@@ -46,6 +46,9 @@ pub type LogFilterReloader = tracing_subscriber::reload::Handle<EnvFilter, Regis
/// Only used on Linux
pub const FIREZONE_GROUP: &str = "firezone-client";
/// Empirically tested to have the best performance.
pub const NUM_TUN_THREADS: usize = 2;
/// CLI args common to both the IPC service and the headless Client
#[derive(clap::Parser)]
pub struct CliCommon {

View File

@@ -236,7 +236,10 @@ fn main() -> Result<()> {
let mut terminate = signals::Terminate::new()?;
let mut hangup = signals::Hangup::new()?;
let mut tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE)?;
let mut tun_device = TunDeviceManager::new(
ip_packet::PACKET_SIZE,
firezone_headless_client::NUM_TUN_THREADS,
)?;
let mut cb_rx = ReceiverStream::new(cb_rx).fuse();
let tokio_handle = tokio::runtime::Handle::current();

View File

@@ -5,9 +5,15 @@ edition = { workspace = true }
license = { workspace = true }
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ip-packet = { workspace = true }
[target.'cfg(target_family = "unix")'.dependencies]
libc = { workspace = true }
tokio = { workspace = true }
futures = { workspace = true }
flume = { workspace = true }
tracing = { workspace = true }
[lints]
workspace = true

View File

@@ -36,7 +36,7 @@ impl Request<SetTunFlagsPayload> {
Self {
name,
payload: SetTunFlagsPayload {
flags: (libc::IFF_TUN | libc::IFF_NO_PI) as _,
flags: (libc::IFF_TUN | libc::IFF_NO_PI | libc::IFF_MULTI_QUEUE) as _,
},
}
}

View File

@@ -3,14 +3,27 @@ use std::{
task::{Context, Poll},
};
use ip_packet::IpPacket;
#[cfg(target_family = "unix")]
pub mod ioctl;
#[cfg(target_family = "unix")]
pub mod unix;
pub trait Tun: Send + Sync + 'static {
fn write4(&self, buf: &[u8]) -> io::Result<usize>;
fn write6(&self, buf: &[u8]) -> io::Result<usize>;
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>>;
/// Check if more packets can be sent.
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>;
/// Send a packet.
fn send(&mut self, packet: IpPacket) -> io::Result<()>;
/// Receive a batch of packets up to `max`.
fn poll_recv_many(
&mut self,
cx: &mut Context,
buf: &mut Vec<IpPacket>,
max: usize,
) -> Poll<usize>;
/// The name of the TUN device.
fn name(&self) -> &str;
}

View File

@@ -1,25 +1,107 @@
use futures::future::Either;
use futures::task::AtomicWaker;
use futures::StreamExt as _;
use ip_packet::{IpPacket, IpPacketBuf};
use std::io;
use std::os::fd::{AsRawFd, RawFd};
use std::task::{Context, Poll};
use tokio::io::Ready;
use std::pin::pin;
use std::sync::Arc;
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
pub fn poll_raw_fd(
fd: &tokio::io::unix::AsyncFd<RawFd>,
mut read: impl FnMut(RawFd) -> io::Result<usize>,
cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
pub struct TunFd {
inner: RawFd,
}
impl TunFd {
/// # Safety
///
/// You must not close this FD yourself.
/// [`TunFd`] will close it for you.
pub unsafe fn new(fd: RawFd) -> Self {
Self { inner: fd }
}
}
impl AsRawFd for TunFd {
fn as_raw_fd(&self) -> RawFd {
self.inner
}
}
impl Drop for TunFd {
fn drop(&mut self) {
// Safety: We are the only ones closing the FD.
unsafe { libc::close(self.inner) };
}
}
/// Concurrently reads and writes packets to the given TUN file-descriptor using the provided function pointers for the actual syscall.
///
/// - Every packet received on `outbound_rx` channel will be written to the file descriptor using the `write` syscall.
/// - Every packet read using the `read` syscall will be sent into the `inbound_tx` channel.
/// - Every time we read a packet from `outbound_rx`, we notify `outbound_capacity_waker` about the newly gained capacity.
/// - In case any of the channels close, we exit the task.
/// - IO errors are not fallible.
pub async fn send_recv_tun<T>(
fd: AsyncFd<T>,
inbound_tx: mpsc::Sender<IpPacket>,
mut outbound_rx: flume::r#async::RecvStream<'static, IpPacket>,
outbound_capacity_waker: Arc<AtomicWaker>,
read: impl Fn(RawFd, &mut IpPacketBuf) -> io::Result<usize>,
write: impl Fn(RawFd, &IpPacket) -> io::Result<usize>,
) where
T: AsRawFd,
{
loop {
let mut guard = std::task::ready!(fd.poll_read_ready(cx))?;
let next_inbound_packet = pin!(fd.async_io(tokio::io::Interest::READABLE, |fd| {
let mut ip_packet_buf = IpPacketBuf::new();
match read(guard.get_inner().as_raw_fd()) {
Ok(n) => return Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// a read has blocked, but a write might still succeed.
// clear only the read readiness.
guard.clear_ready_matching(Ready::READABLE);
let len = read(fd.as_raw_fd(), &mut ip_packet_buf)?;
if len == 0 {
return Ok(None);
}
let packet = IpPacket::new(ip_packet_buf, len)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
Ok(Some(packet))
}));
let next_outbound_packet = pin!(outbound_rx.next());
match futures::future::select(next_inbound_packet, next_outbound_packet).await {
Either::Left((Ok(None), _)) => {
tracing::error!("TUN FD is closed");
return;
}
Either::Left((Ok(Some(packet)), _)) => {
if inbound_tx.send(packet).await.is_err() {
tracing::debug!("Inbound packet receiver gone, shutting down task");
return;
};
}
Either::Left((Err(e), _)) => {
tracing::warn!("Failed to read from TUN FD: {e}");
continue;
}
Err(e) => return Poll::Ready(Err(e)),
Either::Right((Some(packet), _)) => {
if let Err(e) = fd
.async_io(tokio::io::Interest::WRITABLE, |fd| {
write(fd.as_raw_fd(), &packet)
})
.await
{
tracing::warn!("Failed to write to TUN FD: {e}");
};
outbound_capacity_waker.wake(); // We wrote a packet, notify about the new capacity.
}
Either::Right((None, _)) => {
tracing::debug!("Outbound packet sender gone, shutting down task");
return;
}
}
}
}

View File

@@ -23,6 +23,12 @@ export default function GUI({ title }: { title: string }) {
Makes use of the new control protocol, delivering faster and more
robust connection establishment.
</ChangeItem>
{title == "Linux GUI" && (
<ChangeItem pull="7449">
Uses multiple threads to read & write to the TUN device, greatly
improving performance.
</ChangeItem>
)}
</Unreleased>
<Entry version="1.3.13" date={new Date("2024-11-15")}>
<ChangeItem pull="7334">

View File

@@ -19,6 +19,11 @@ export default function Gateway() {
Fixes cases where client applications such as ssh would fail to
automatically determine the correct IP protocol version to use (4/6).
</ChangeItem>
<ChangeItem pull="7449">
Uses multiple threads to read & write to the TUN device, greatly
improving performance. The number of threads can be controlled with
`FIREZONE_NUM_TUN_THREADS` and defaults to 2.
</ChangeItem>
</Unreleased>
<Entry version="1.4.1" date={new Date("2024-11-15")}>
<ChangeItem pull="7263">

View File

@@ -23,6 +23,10 @@ export default function Headless() {
Makes use of the new control protocol, delivering faster and more
robust connection establishment.
</ChangeItem>
<ChangeItem pull="7449">
Uses multiple threads to read & write to the TUN device, greatly
improving performance.
</ChangeItem>
</Unreleased>
<Entry version="1.3.7" date={new Date("2024-11-15")}>
<ChangeItem pull="7334">