firezone/rust/bin-shared/src/tun_device_manager/linux.rs
Thomas Eizinger 90cf191a7c feat(linux): multi-threaded TUN device operations (#7449)
## Context

At present, a single thread reads from and writes to the TUN device on
all platforms. On Linux, the file descriptor of a TUN device can be
opened multiple times by setting the `IFF_MULTI_QUEUE` flag via `ioctl`.
With multi-queue, we can spawn multiple threads that concurrently read
from and write to the TUN device. This is critical for achieving better
throughput.
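
To make the mechanism a bit more concrete, here is a minimal sketch of how the `TUNSETIFF` request number and the multi-queue flags are composed. The flag values and the `_IOW` encoding come from the Linux UAPI headers. The `IfReq` type, the helper names and the use of `IFF_NO_PI` are assumptions made for illustration; the actual payload in this commit is built by `ioctl::SetTunFlagsPayload`, which is not shown here.

```rust
// Flag values from <linux/if_tun.h>.
pub const IFF_TUN: i16 = 0x0001; // layer-3 device, carries IP packets
pub const IFF_MULTI_QUEUE: i16 = 0x0100; // allow several fds (queues) per device
pub const IFF_NO_PI: i16 = 0x1000; // no packet-info header (assumed for this sketch)

/// Maximum interface name length, including the trailing NUL.
pub const IFNAMSIZ: usize = 16;

/// Encodes `_IOW(ty, nr, size)` using the generic Linux layout
/// (as on x86-64 and aarch64): dir << 30 | size << 16 | type << 8 | nr.
pub const fn iow(ty: u8, nr: u8, size: u64) -> u64 {
    const IOC_WRITE: u64 = 1;
    (IOC_WRITE << 30) | (size << 16) | ((ty as u64) << 8) | (nr as u64)
}

/// `TUNSETIFF` is defined as `_IOW('T', 202, int)`.
pub const TUNSETIFF: u64 = iow(b'T', 202, 4);

/// Simplified stand-in for C's `struct ifreq` on 64-bit Linux (hypothetical type).
/// Only the name and the `flags` member of the union are modelled.
#[repr(C)]
pub struct IfReq {
    pub name: [u8; IFNAMSIZ],
    pub flags: i16,
    pub _pad: [u8; 22],
}

/// Builds the request that every queue (i.e. every `open` of `/dev/net/tun`)
/// would pass to `ioctl(fd, TUNSETIFF, &req)`, always with the same name.
pub fn multi_queue_request(name: &str) -> Option<IfReq> {
    let bytes = name.as_bytes();
    // The name must fit including its NUL terminator and must not contain NUL.
    if bytes.is_empty() || bytes.len() >= IFNAMSIZ || bytes.contains(&0) {
        return None;
    }

    let mut req = IfReq {
        name: [0; IFNAMSIZ],
        flags: IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE,
        _pad: [0; 22],
    };
    req.name[..bytes.len()].copy_from_slice(bytes);

    Some(req)
}
```

Opening `/dev/net/tun` N times and issuing this request on each file descriptor yields N queues of the same interface, one per thread.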

## Solution

`IFF_MULTI_QUEUE` is a Linux-only thing and therefore only applies to
headless-client, GUI-client on Linux and the Gateway (it may also be
possible on Android, I haven't tried). As such, we need to first change
our internal abstractions a bit to move the creation of the TUN thread
to the `Tun` abstraction itself. For this, we change the interface of
`Tun` to the following:

- `poll_recv_many`: Inspired by `poll_recv_many` on tokio's
`mpsc::Receiver`; batch-receives multiple items from a channel.
- `poll_send_ready`: Mimics `Sink::poll_ready` to check whether more
items can be written.
- `send`: Mimics `Sink::start_send` to actually send an item.
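
As a rough illustration of how these three methods fit together, here is a self-contained sketch using only the standard library. The trait mirrors the method signatures implemented further down in this file, but `Packet`, `LoopbackTun` and `noop_waker` are hypothetical names invented for this example; the real trait lives in the `tun` crate and operates on `IpPacket`.

```rust
use std::collections::VecDeque;
use std::io;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Stand-in for an IP packet (the real code uses `IpPacket`).
pub type Packet = Vec<u8>;

pub trait Tun {
    /// Is there room to queue another outbound packet?
    fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
    /// Queues a packet; only call after `poll_send_ready` returned `Ready`.
    fn send(&mut self, packet: Packet) -> io::Result<()>;
    /// Appends up to `max` received packets to `buf`.
    fn poll_recv_many(&mut self, cx: &mut Context<'_>, buf: &mut Vec<Packet>, max: usize)
        -> Poll<usize>;
}

/// Hypothetical in-memory device with a bounded outbound queue.
pub struct LoopbackTun {
    inbound: VecDeque<Packet>,
    outbound: VecDeque<Packet>,
    capacity: usize,
    waker: Option<Waker>,
}

impl LoopbackTun {
    pub fn new(capacity: usize, inbound: Vec<Packet>) -> Self {
        Self { inbound: inbound.into(), outbound: VecDeque::new(), capacity, waker: None }
    }
}

impl Tun for LoopbackTun {
    fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.outbound.len() >= self.capacity {
            // Remember who to wake once capacity frees up (backpressure).
            self.waker = Some(cx.waker().clone());
            return Poll::Pending;
        }
        Poll::Ready(Ok(()))
    }

    fn send(&mut self, packet: Packet) -> io::Result<()> {
        if self.outbound.len() >= self.capacity {
            return Err(io::Error::new(io::ErrorKind::WouldBlock, "outbound queue full"));
        }
        self.outbound.push_back(packet);
        Ok(())
    }

    fn poll_recv_many(&mut self, cx: &mut Context<'_>, buf: &mut Vec<Packet>, max: usize)
        -> Poll<usize> {
        if self.inbound.is_empty() {
            self.waker = Some(cx.waker().clone());
            return Poll::Pending;
        }
        let n = max.min(self.inbound.len());
        buf.extend(self.inbound.drain(..n));
        Poll::Ready(n)
    }
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// A waker that does nothing, handy for polling by hand.
pub fn noop_waker() -> Waker {
    Waker::from(Arc::new(NoopWake))
}
```

Splitting the readiness check from the actual send lets the caller hold on to a packet until there is room for it, instead of having to buffer or drop it.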

With these APIs in place, we can implement various (performance)
improvements for the different platforms.

- On Linux, this allows us to spawn multiple threads that read from and
write to the TUN device and send all packets into the same channel. The `Io`
component of `connlib` then uses `poll_recv_many` to read batches of up
to 100 packets at once. This ties in well with #7210 because we can then
use GSO to send the encrypted packets in single syscalls to the OS.
- On Windows, we already have a dedicated recv thread because `WinTun`'s
most-convenient API uses blocking IO. As such, we can now also tie into
that by batch-receiving from this channel.
- In addition to using multiple threads, this API now also uses correct
readiness checks on Linux, Darwin and Android to uphold backpressure in
case we cannot write to the TUN device.
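
The fan-in and batching described above can be sketched with the standard library alone. Note that this swaps in a blocking `std::sync::mpsc::sync_channel` for the tokio and flume channels used in this commit, and that `spawn_readers` and `recv_many` are hypothetical helpers: the threads here generate dummy packets instead of reading from a TUN queue.

```rust
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

/// Stand-in for an IP packet (the real code uses `IpPacket`).
pub type Packet = Vec<u8>;

/// Spawns `num_threads` producers that all send into the same bounded channel.
/// Each one pretends to read `per_thread` packets from its own TUN queue.
pub fn spawn_readers(num_threads: usize, per_thread: usize, capacity: usize) -> Receiver<Packet> {
    let (tx, rx) = sync_channel(capacity);

    for n in 0..num_threads {
        let tx = tx.clone();
        thread::spawn(move || {
            for i in 0..per_thread {
                // `send` blocks while the channel is full. That is the
                // backpressure: a slow consumer slows down the readers.
                if tx.send(vec![n as u8, i as u8]).is_err() {
                    return; // Receiver is gone, stop reading.
                }
            }
        });
    }

    rx
}

/// Moves up to `max` already-queued packets into `buf` without blocking
/// and returns how many were moved.
pub fn recv_many(rx: &Receiver<Packet>, buf: &mut Vec<Packet>, max: usize) -> usize {
    let mut n = 0;
    while n < max {
        match rx.try_recv() {
            Ok(packet) => {
                buf.push(packet);
                n += 1;
            }
            // Either nothing is queued right now or all senders are gone.
            Err(_) => break,
        }
    }
    n
}
```

A consumer would call `recv_many(&rx, &mut buf, 100)` in its loop, which corresponds to the batches of up to 100 packets that `Io` reads, and then process the whole batch at once.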

## Configuration

Local testing has shown that 2 threads give the best performance for a
local `iperf3` run. I suspect this is because there is only so much
traffic that a single application (i.e. `iperf3`) can generate. With
more than 2 threads, the throughput actually drops drastically because
`connlib`'s main thread is too busy with lock-contention and triggering
`Waker`s for the TUN threads (which mostly idle around if there are 4+
of them). I've made it configurable on the Gateway though so we can
experiment with this during concurrent speedtests etc.

In addition, switching `connlib` to a single-threaded tokio runtime
further increased the throughput. I suspect this is due to less task and
context switching.

## Results

Local testing with `iperf3` shows some very promising results. We now
achieve a throughput of roughly 2 Gbit/s (1.98 Gbit/s averaged over 10
seconds, peaking at 2.34 Gbit/s).

```
Connecting to host 172.20.0.110, port 5201
Reverse mode, remote host 172.20.0.110 is sending
[  5] local 100.80.159.34 port 57040 connected to 172.20.0.110 port 5201
[ ID] Interval           Transfer     Bitrate
[  5]   0.00-1.00   sec   274 MBytes  2.30 Gbits/sec
[  5]   1.00-2.00   sec   279 MBytes  2.34 Gbits/sec
[  5]   2.00-3.00   sec   216 MBytes  1.82 Gbits/sec
[  5]   3.00-4.00   sec   224 MBytes  1.88 Gbits/sec
[  5]   4.00-5.00   sec   234 MBytes  1.96 Gbits/sec
[  5]   5.00-6.00   sec   238 MBytes  2.00 Gbits/sec
[  5]   6.00-7.00   sec   229 MBytes  1.92 Gbits/sec
[  5]   7.00-8.00   sec   222 MBytes  1.86 Gbits/sec
[  5]   8.00-9.00   sec   223 MBytes  1.87 Gbits/sec
[  5]   9.00-10.00  sec   217 MBytes  1.82 Gbits/sec
- - - - - - - - - - - - - - - - - - - - - - - - -
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  2.30 GBytes  1.98 Gbits/sec  22247             sender
[  5]   0.00-10.00  sec  2.30 GBytes  1.98 Gbits/sec                  receiver

iperf Done.
```

This is a pretty solid improvement over what is in `main`, which averages
769 Mbit/s (roughly a 2.6x increase):

```
Connecting to host 172.20.0.110, port 5201
[  5] local 100.65.159.3 port 56970 connected to 172.20.0.110 port 5201
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-1.00   sec  90.4 MBytes   758 Mbits/sec  1800    106 KBytes
[  5]   1.00-2.00   sec  93.4 MBytes   783 Mbits/sec  1550   51.6 KBytes
[  5]   2.00-3.00   sec  92.6 MBytes   777 Mbits/sec  1350   76.8 KBytes
[  5]   3.00-4.00   sec  92.9 MBytes   779 Mbits/sec  1800   56.4 KBytes
[  5]   4.00-5.00   sec  93.4 MBytes   783 Mbits/sec  1650   69.6 KBytes
[  5]   5.00-6.00   sec  90.6 MBytes   760 Mbits/sec  1500   73.2 KBytes
[  5]   6.00-7.00   sec  87.6 MBytes   735 Mbits/sec  1400   76.8 KBytes
[  5]   7.00-8.00   sec  92.6 MBytes   777 Mbits/sec  1600   82.7 KBytes
[  5]   8.00-9.00   sec  91.1 MBytes   764 Mbits/sec  1500   70.8 KBytes
[  5]   9.00-10.00  sec  92.0 MBytes   771 Mbits/sec  1550   85.1 KBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec   917 MBytes   769 Mbits/sec  15700             sender
[  5]   0.00-10.00  sec   916 MBytes   768 Mbits/sec                  receiver

iperf Done.
```
2024-12-05 00:18:20 +00:00

//! Virtual network interface

use crate::FIREZONE_MARK;
use anyhow::{anyhow, Context as _, Result};
use firezone_logging::std_dyn_err;
use futures::task::AtomicWaker;
use futures::TryStreamExt;
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_packet::{IpPacket, IpPacketBuf};
use libc::{
    fcntl, makedev, mknod, open, EEXIST, ENOENT, F_GETFL, F_SETFL, O_NONBLOCK, O_RDWR, S_IFCHR,
};
use netlink_packet_route::route::{RouteProtocol, RouteScope};
use netlink_packet_route::rule::RuleAction;
use rtnetlink::{new_connection, Error::NetlinkError, Handle, RouteAddRequest, RuleAddRequest};
use std::path::Path;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{
    collections::HashSet,
    net::{Ipv4Addr, Ipv6Addr},
};
use std::{
    ffi::CStr,
    fs, io,
    os::{fd::RawFd, unix::fs::PermissionsExt},
};
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
use tun::ioctl;
use tun::unix::TunFd;

// `_IOW('T', 202, int)`, see `<linux/if_tun.h>`.
const TUNSETIFF: libc::c_ulong = 0x4004_54ca;
// `/dev/net/tun` is the character device with major 10 and minor 200.
const TUN_DEV_MAJOR: u32 = 10;
const TUN_DEV_MINOR: u32 = 200;
const TUN_FILE: &CStr = c"/dev/net/tun";

const FIREZONE_TABLE: u32 = 0x2021_fd00;

/// For lack of a better name
pub struct TunDeviceManager {
    mtu: u32,
    num_threads: usize,
    connection: Connection,
    routes: HashSet<IpNetwork>,
}

struct Connection {
    handle: Handle,
    task: tokio::task::JoinHandle<()>,
}

impl Drop for TunDeviceManager {
    fn drop(&mut self) {
        self.connection.task.abort();
    }
}

impl TunDeviceManager {
    pub const IFACE_NAME: &'static str = "tun-firezone";

    /// Creates a new managed tunnel device.
    ///
    /// Panics if called without a Tokio runtime.
    pub fn new(mtu: usize, num_threads: usize) -> Result<Self> {
        let (cxn, handle, _) = new_connection()?;
        let task = tokio::spawn(cxn);
        let connection = Connection { handle, task };

        Ok(Self {
            connection,
            routes: Default::default(),
            mtu: mtu as u32,
            num_threads,
        })
    }

    pub fn make_tun(&mut self) -> Result<Tun> {
        Ok(Tun::new(self.num_threads)?)
    }

    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn set_ips(&mut self, ipv4: Ipv4Addr, ipv6: Ipv6Addr) -> Result<()> {
        let name = Self::IFACE_NAME;

        let handle = &self.connection.handle;
        let index = handle
            .link()
            .get()
            .match_name(name.to_string())
            .execute()
            .try_next()
            .await?
            .ok_or_else(|| anyhow!("Interface '{name}' does not exist"))?
            .header
            .index;

        let ips = handle
            .address()
            .get()
            .set_link_index_filter(index)
            .execute();

        ips.try_for_each(|ip| handle.address().del(ip).execute())
            .await
            .context("Failed to delete existing addresses")?;

        handle
            .link()
            .set(index)
            .mtu(self.mtu)
            .execute()
            .await
            .context("Failed to set default MTU")?;

        let res_v4 = handle.address().add(index, ipv4.into(), 32).execute().await;
        let res_v6 = handle
            .address()
            .add(index, ipv6.into(), 128)
            .execute()
            .await;

        handle
            .link()
            .set(index)
            .up()
            .execute()
            .await
            .context("Failed to bring up interface")?;

        if res_v4.is_ok() {
            if let Err(e) = make_rule(handle).v4().execute().await {
                if !matches!(&e, NetlinkError(err) if err.raw_code() == -EEXIST) {
                    tracing::warn!(
                        "Couldn't add ip rule for ipv4: {e:?}, ipv4 packets won't be routed"
                    );
                }
                // TODO: Be smarter about this
            } else {
                tracing::debug!("Successfully created ip rule for ipv4");
            }
        }

        if res_v6.is_ok() {
            if let Err(e) = make_rule(handle).v6().execute().await {
                if !matches!(&e, NetlinkError(err) if err.raw_code() == -EEXIST) {
                    tracing::warn!(
                        "Couldn't add ip rule for ipv6: {e:?}, ipv6 packets won't be routed"
                    );
                }
                // TODO: Be smarter about this
            } else {
                tracing::debug!("Successfully created ip rule for ipv6");
            }
        }

        res_v4.or(res_v6)?;

        Ok(())
    }

    pub async fn set_routes(
        &mut self,
        ipv4: Vec<Ipv4Network>,
        ipv6: Vec<Ipv6Network>,
    ) -> Result<()> {
        let new_routes: HashSet<IpNetwork> = ipv4
            .into_iter()
            .map(IpNetwork::from)
            .chain(ipv6.into_iter().map(IpNetwork::from))
            .collect();

        tracing::info!(?new_routes, "Setting new routes");

        let handle = &self.connection.handle;
        let index = handle
            .link()
            .get()
            .match_name(Self::IFACE_NAME.to_string())
            .execute()
            .try_next()
            .await?
            .context("No interface")?
            .header
            .index;

        for route in self.routes.difference(&new_routes) {
            remove_route(route, index, handle).await;
        }

        for route in &new_routes {
            add_route(route, index, handle).await;
        }

        self.routes = new_routes;

        Ok(())
    }
}

fn make_rule(handle: &Handle) -> RuleAddRequest {
    let mut rule = handle
        .rule()
        .add()
        .fw_mark(FIREZONE_MARK)
        .table_id(FIREZONE_TABLE)
        .action(RuleAction::ToTable);

    rule.message_mut()
        .header
        .flags
        .push(netlink_packet_route::rule::RuleFlag::Invert);

    rule.message_mut()
        .attributes
        .push(netlink_packet_route::rule::RuleAttribute::Protocol(
            RouteProtocol::Kernel,
        ));

    rule
}

fn make_route(idx: u32, handle: &Handle) -> RouteAddRequest {
    handle
        .route()
        .add()
        .output_interface(idx)
        .protocol(RouteProtocol::Static)
        .scope(RouteScope::Universe)
        .table_id(FIREZONE_TABLE)
}

fn make_route_v4(idx: u32, handle: &Handle, route: Ipv4Network) -> RouteAddRequest<Ipv4Addr> {
    make_route(idx, handle)
        .v4()
        .destination_prefix(route.network_address(), route.netmask())
}

fn make_route_v6(idx: u32, handle: &Handle, route: Ipv6Network) -> RouteAddRequest<Ipv6Addr> {
    make_route(idx, handle)
        .v6()
        .destination_prefix(route.network_address(), route.netmask())
}

async fn add_route(route: &IpNetwork, idx: u32, handle: &Handle) {
    let res = match route {
        IpNetwork::V4(ipnet) => make_route_v4(idx, handle, *ipnet).execute().await,
        IpNetwork::V6(ipnet) => make_route_v6(idx, handle, *ipnet).execute().await,
    };

    let Err(err) = res else {
        tracing::debug!(%route, iface_idx = %idx, "Created new route");

        return;
    };

    // We expect this to be called often with an already existing route since set_routes always calls for all routes
    if matches!(&err, NetlinkError(err) if err.raw_code() == -EEXIST) {
        return;
    }

    tracing::warn!(error = std_dyn_err(&err), %route, "Failed to add route");
}

async fn remove_route(route: &IpNetwork, idx: u32, handle: &Handle) {
    let message = match route {
        IpNetwork::V4(ipnet) => make_route_v4(idx, handle, *ipnet).message_mut().clone(),
        IpNetwork::V6(ipnet) => make_route_v6(idx, handle, *ipnet).message_mut().clone(),
    };

    let res = handle.route().del(message).execute().await;

    let Err(err) = res else {
        tracing::debug!(%route, iface_idx = %idx, "Removed route");

        return;
    };

    // Our view of the current routes may be stale. Removing a route that no longer exists shouldn't print a warning.
    if matches!(&err, NetlinkError(err) if err.raw_code() == -ENOENT) {
        return;
    }

    tracing::warn!(error = std_dyn_err(&err), %route, "Failed to remove route");
}

#[derive(Debug)]
pub struct Tun {
    outbound_tx: flume::Sender<IpPacket>,
    outbound_capacity_waker: Arc<AtomicWaker>,
    inbound_rx: mpsc::Receiver<IpPacket>,
}

impl Tun {
    pub fn new(num_threads: usize) -> io::Result<Self> {
        create_tun_device()?;

        let (inbound_tx, inbound_rx) = mpsc::channel(1000);
        let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
        let outbound_capacity_waker = Arc::new(AtomicWaker::new());

        for n in 0..num_threads {
            let fd = AsyncFd::new(open_tun()?)?;
            let outbound_rx = outbound_rx.clone().into_stream();
            let inbound_tx = inbound_tx.clone();
            let outbound_capacity_waker = outbound_capacity_waker.clone();

            std::thread::Builder::new()
                .name(format!("TUN send/recv {n}/{num_threads}"))
                .spawn(move || {
                    tokio::runtime::Builder::new_current_thread()
                        .enable_all()
                        .build()?
                        .block_on(tun::unix::send_recv_tun(
                            fd,
                            inbound_tx,
                            outbound_rx,
                            outbound_capacity_waker,
                            read,
                            write,
                        ));

                    io::Result::Ok(())
                })
                .map_err(io::Error::other)?;
        }

        Ok(Self {
            outbound_tx,
            outbound_capacity_waker,
            inbound_rx,
        })
    }
}

fn open_tun() -> Result<TunFd, io::Error> {
    let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
        -1 => return Err(get_last_error()),
        fd => fd,
    };

    unsafe {
        ioctl::exec(
            fd,
            TUNSETIFF,
            &mut ioctl::Request::<ioctl::SetTunFlagsPayload>::new(TunDeviceManager::IFACE_NAME),
        )?;
    }

    set_non_blocking(fd)?;

    // Safety: We are not closing the FD.
    let fd = unsafe { TunFd::new(fd) };

    Ok(fd)
}

impl tun::Tun for Tun {
    fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if self.outbound_tx.is_full() {
            self.outbound_capacity_waker.register(cx.waker());
            return Poll::Pending;
        }

        Poll::Ready(Ok(()))
    }

    fn send(&mut self, packet: IpPacket) -> io::Result<()> {
        self.outbound_tx
            .try_send(packet)
            .map_err(io::Error::other)?;

        Ok(())
    }

    fn poll_recv_many(
        &mut self,
        cx: &mut Context,
        buf: &mut Vec<IpPacket>,
        max: usize,
    ) -> Poll<usize> {
        self.inbound_rx.poll_recv_many(cx, buf, max)
    }

    fn name(&self) -> &str {
        TunDeviceManager::IFACE_NAME
    }
}

fn get_last_error() -> io::Error {
    io::Error::last_os_error()
}

fn set_non_blocking(fd: RawFd) -> io::Result<()> {
    match unsafe { fcntl(fd, F_GETFL) } {
        -1 => Err(get_last_error()),
        flags => match unsafe { fcntl(fd, F_SETFL, flags | O_NONBLOCK) } {
            -1 => Err(get_last_error()),
            _ => Ok(()),
        },
    }
}

fn create_tun_device() -> io::Result<()> {
    let path = Path::new(TUN_FILE.to_str().expect("path is valid utf-8"));

    if path.exists() {
        return Ok(());
    }

    let parent_dir = path
        .parent()
        .expect("const-declared path always has a parent");
    fs::create_dir_all(parent_dir)?;
    let permissions = fs::Permissions::from_mode(0o751);
    fs::set_permissions(parent_dir, permissions)?;
    if unsafe {
        mknod(
            TUN_FILE.as_ptr() as _,
            S_IFCHR,
            makedev(TUN_DEV_MAJOR, TUN_DEV_MINOR),
        )
    } != 0
    {
        return Err(get_last_error());
    }

    Ok(())
}

/// Read from the given file descriptor into the buffer.
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
    let dst = dst.buf();

    // Safety: Within this module, the file descriptor is always valid.
    match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
        -1 => Err(io::Error::last_os_error()),
        n => Ok(n as usize),
    }
}

/// Write the packet to the given file descriptor.
fn write(fd: RawFd, packet: &IpPacket) -> io::Result<usize> {
    let buf = packet.packet();

    // Safety: Within this module, the file descriptor is always valid.
    match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } {
        -1 => Err(io::Error::last_os_error()),
        n => Ok(n as usize),
    }
}