mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat(linux): multi-threaded TUN device operations (#7449)
## Context At present, we only have a single thread that reads and writes to the TUN device on all platforms. On Linux, it is possible to open the file descriptor of a TUN device multiple times by setting the `IFF_MULTI_QUEUE` option using `ioctl`. Using multi-queue, we can then spawn multiple threads that concurrently read and write to the TUN device. This is critical for achieving a better throughput. ## Solution `IFF_MULTI_QUEUE` is a Linux-only thing and therefore only applies to headless-client, GUI-client on Linux and the Gateway (it may also be possible on Android, I haven't tried). As such, we need to first change our internal abstractions a bit to move the creation of the TUN thread to the `Tun` abstraction itself. For this, we change the interface of `Tun` to the following: - `poll_recv_many`: An API, inspired by tokio's `mpsc::Receiver` where multiple items in a channel can be batch-received. - `poll_send_ready`: Mimics the API of `Sink` to check whether more items can be written. - `send`: Mimics the API of `Sink` to actually send an item. With these APIs in place, we can implement various (performance) improvements for the different platforms. - On Linux, this allows us to spawn multiple threads to read and write from the TUN device and send all packets into the same channel. The `Io` component of `connlib` then uses `poll_recv_many` to read batches of up to 100 packets at once. This ties in well with #7210 because we can then use GSO to send the encrypted packets in single syscalls to the OS. - On Windows, we already have a dedicated recv thread because `WinTun`'s most-convenient API uses blocking IO. As such, we can now also tie into that by batch-receiving from this channel. - In addition to using multiple threads, this API now also uses correct readiness checks on Linux, Darwin and Android to uphold backpressure in case we cannot write to the TUN device. ## Configuration Local testing has shown that 2 threads give the best performance for a local `iperf3` run. I suspect this is because there is only so much traffic that a single application (i.e. `iperf3`) can generate. With more than 2 threads, the throughput actually drops drastically because `connlib`'s main thread is too busy with lock-contention and triggering `Waker`s for the TUN threads (which mostly idle around if there are 4+ of them). I've made it configurable on the Gateway though so we can experiment with this during concurrent speedtests etc. In addition, switching `connlib` to a single-threaded tokio runtime further increased the throughput. I suspect due to less task / context switching. ## Results Local testing with `iperf3` shows some very promising results. We now achieve a throughput of 2+ Gbit/s. ``` Connecting to host 172.20.0.110, port 5201 Reverse mode, remote host 172.20.0.110 is sending [ 5] local 100.80.159.34 port 57040 connected to 172.20.0.110 port 5201 [ ID] Interval Transfer Bitrate [ 5] 0.00-1.00 sec 274 MBytes 2.30 Gbits/sec [ 5] 1.00-2.00 sec 279 MBytes 2.34 Gbits/sec [ 5] 2.00-3.00 sec 216 MBytes 1.82 Gbits/sec [ 5] 3.00-4.00 sec 224 MBytes 1.88 Gbits/sec [ 5] 4.00-5.00 sec 234 MBytes 1.96 Gbits/sec [ 5] 5.00-6.00 sec 238 MBytes 2.00 Gbits/sec [ 5] 6.00-7.00 sec 229 MBytes 1.92 Gbits/sec [ 5] 7.00-8.00 sec 222 MBytes 1.86 Gbits/sec [ 5] 8.00-9.00 sec 223 MBytes 1.87 Gbits/sec [ 5] 9.00-10.00 sec 217 MBytes 1.82 Gbits/sec - - - - - - - - - - - - - - - - - - - - - - - - - [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 2.30 GBytes 1.98 Gbits/sec 22247 sender [ 5] 0.00-10.00 sec 2.30 GBytes 1.98 Gbits/sec receiver iperf Done. ``` This is a pretty solid improvement over what is in `main`: ``` Connecting to host 172.20.0.110, port 5201 [ 5] local 100.65.159.3 port 56970 connected to 172.20.0.110 port 5201 [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-1.00 sec 90.4 MBytes 758 Mbits/sec 1800 106 KBytes [ 5] 1.00-2.00 sec 93.4 MBytes 783 Mbits/sec 1550 51.6 KBytes [ 5] 2.00-3.00 sec 92.6 MBytes 777 Mbits/sec 1350 76.8 KBytes [ 5] 3.00-4.00 sec 92.9 MBytes 779 Mbits/sec 1800 56.4 KBytes [ 5] 4.00-5.00 sec 93.4 MBytes 783 Mbits/sec 1650 69.6 KBytes [ 5] 5.00-6.00 sec 90.6 MBytes 760 Mbits/sec 1500 73.2 KBytes [ 5] 6.00-7.00 sec 87.6 MBytes 735 Mbits/sec 1400 76.8 KBytes [ 5] 7.00-8.00 sec 92.6 MBytes 777 Mbits/sec 1600 82.7 KBytes [ 5] 8.00-9.00 sec 91.1 MBytes 764 Mbits/sec 1500 70.8 KBytes [ 5] 9.00-10.00 sec 92.0 MBytes 771 Mbits/sec 1550 85.1 KBytes - - - - - - - - - - - - - - - - - - - - - - - - - [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 917 MBytes 769 Mbits/sec 15700 sender [ 5] 0.00-10.00 sec 916 MBytes 768 Mbits/sec receiver iperf Done. ```
This commit is contained in:
38
rust/Cargo.lock
generated
38
rust/Cargo.lock
generated
@@ -1042,6 +1042,9 @@ dependencies = [
|
||||
"connlib-model",
|
||||
"firezone-logging",
|
||||
"firezone-telemetry",
|
||||
"flume",
|
||||
"futures",
|
||||
"ip-packet",
|
||||
"ip_network",
|
||||
"jni",
|
||||
"libc",
|
||||
@@ -1070,6 +1073,9 @@ dependencies = [
|
||||
"connlib-model",
|
||||
"firezone-logging",
|
||||
"firezone-telemetry",
|
||||
"flume",
|
||||
"futures",
|
||||
"ip-packet",
|
||||
"ip_network",
|
||||
"libc",
|
||||
"oslog",
|
||||
@@ -1912,6 +1918,7 @@ dependencies = [
|
||||
"axum",
|
||||
"clap",
|
||||
"firezone-logging",
|
||||
"flume",
|
||||
"futures",
|
||||
"hex-literal",
|
||||
"ip-packet",
|
||||
@@ -2224,7 +2231,6 @@ dependencies = [
|
||||
"test-strategy",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"tun",
|
||||
@@ -2250,6 +2256,18 @@ dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"nanorand",
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@@ -2581,8 +2599,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3670,6 +3690,15 @@ version = "0.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc0287524726960e07b119cebd01678f852f147742ae0d925e6a520dca956126"
|
||||
|
||||
[[package]]
|
||||
name = "nanorand"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
|
||||
dependencies = [
|
||||
"getrandom 0.2.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-dialog"
|
||||
version = "0.7.0"
|
||||
@@ -5953,6 +5982,9 @@ name = "spin"
|
||||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
@@ -7125,8 +7157,12 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
name = "tun"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"flume",
|
||||
"futures",
|
||||
"ip-packet",
|
||||
"libc",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -122,6 +122,7 @@ thiserror = "1.0.68"
|
||||
time = "0.3.36"
|
||||
tokio = "1.41"
|
||||
tokio-stream = "0.1.16"
|
||||
flume = { version = "0.11.1", features = ["async"] }
|
||||
tokio-tungstenite = "0.23.1"
|
||||
tokio-util = "0.7.11"
|
||||
tracing = { version = "0.1.40" }
|
||||
|
||||
@@ -13,6 +13,7 @@ clap = { workspace = true, features = ["derive", "env"] }
|
||||
firezone-logging = { workspace = true }
|
||||
futures = { workspace = true, features = ["std", "async-await"] }
|
||||
hex-literal = { workspace = true }
|
||||
ip-packet = { workspace = true }
|
||||
ip_network = { workspace = true, features = ["serde"] }
|
||||
socket-factory = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
@@ -24,6 +25,7 @@ tun = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
flume = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
netlink-packet-core = { version = "0.7" }
|
||||
netlink-packet-route = { version = "0.19" }
|
||||
|
||||
@@ -28,7 +28,6 @@ mod platform {
|
||||
mod platform {
|
||||
use anyhow::Result;
|
||||
use firezone_bin_shared::TunDeviceManager;
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use std::{
|
||||
future::poll_fn,
|
||||
net::{Ipv4Addr, Ipv6Addr},
|
||||
@@ -47,10 +46,11 @@ mod platform {
|
||||
const REQ_LEN: usize = 1_000;
|
||||
const RESP_CODE: u8 = 43;
|
||||
const SERVER_PORT: u16 = 3000;
|
||||
const NUM_THREADS: usize = 1; // Note: Unused on Windows.
|
||||
|
||||
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
|
||||
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
|
||||
let mut device_manager = TunDeviceManager::new(MTU)?;
|
||||
let mut device_manager = TunDeviceManager::new(MTU, NUM_THREADS)?;
|
||||
let mut tun = device_manager.make_tun()?;
|
||||
|
||||
device_manager.set_ips(ipv4, ipv6).await?;
|
||||
@@ -62,15 +62,14 @@ mod platform {
|
||||
let server_task = tokio::spawn(async move {
|
||||
tracing::debug!("Server task entered");
|
||||
let mut requests_served = 0;
|
||||
// We aren't interested in allocator speed or doing any processing,
|
||||
// so just cache the response packet
|
||||
let mut response_pkt = None;
|
||||
|
||||
let mut time_spent = Duration::from_millis(0);
|
||||
loop {
|
||||
let mut req_buf = IpPacketBuf::new();
|
||||
let n = poll_fn(|cx| tun.poll_read(req_buf.buf(), cx)).await?;
|
||||
let mut buf = Vec::with_capacity(1);
|
||||
poll_fn(|cx| tun.poll_recv_many(cx, &mut buf, 1)).await;
|
||||
let original_pkt = buf.remove(0);
|
||||
|
||||
let start = Instant::now();
|
||||
let original_pkt = IpPacket::new(req_buf, n).unwrap();
|
||||
let Some(original_udp) = original_pkt.as_udp() else {
|
||||
continue;
|
||||
};
|
||||
@@ -81,21 +80,16 @@ mod platform {
|
||||
panic!("Wrong request code");
|
||||
}
|
||||
|
||||
// Only generate the response packet on the first loop,
|
||||
// then just reuse it.
|
||||
let res_buf = response_pkt
|
||||
.get_or_insert_with(|| {
|
||||
ip_packet::make::udp_packet(
|
||||
original_pkt.destination(),
|
||||
original_pkt.source(),
|
||||
original_udp.destination_port(),
|
||||
original_udp.source_port(),
|
||||
vec![RESP_CODE],
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.packet();
|
||||
tun.write4(res_buf)?;
|
||||
tun.send(
|
||||
ip_packet::make::udp_packet(
|
||||
original_pkt.destination(),
|
||||
original_pkt.source(),
|
||||
original_udp.destination_port(),
|
||||
original_udp.source_port(),
|
||||
vec![RESP_CODE],
|
||||
)
|
||||
.unwrap(),
|
||||
)?;
|
||||
requests_served += 1;
|
||||
time_spent += start.elapsed();
|
||||
if requests_served >= NUM_REQUESTS {
|
||||
|
||||
@@ -42,7 +42,7 @@ mod tests {
|
||||
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
|
||||
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
|
||||
|
||||
let mut device_manager = TunDeviceManager::new(1280).unwrap();
|
||||
let mut device_manager = TunDeviceManager::new(1280, 1).unwrap();
|
||||
let _tun = device_manager.make_tun().unwrap();
|
||||
device_manager.set_ips(ipv4, ipv6).await.unwrap();
|
||||
|
||||
@@ -72,7 +72,7 @@ mod tests {
|
||||
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
|
||||
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
|
||||
|
||||
let mut device_manager = TunDeviceManager::new(1280).unwrap();
|
||||
let mut device_manager = TunDeviceManager::new(1280, 1).unwrap();
|
||||
let _tun = device_manager.make_tun().unwrap();
|
||||
device_manager.set_ips(ipv4, ipv6).await.unwrap();
|
||||
|
||||
@@ -125,7 +125,7 @@ mod tests {
|
||||
/// Checks for regressions in issue #4765, un-initializing Wintun
|
||||
/// Redundant but harmless on Linux.
|
||||
fn tunnel_drop() {
|
||||
let mut tun_device_manager = TunDeviceManager::new(1280).unwrap();
|
||||
let mut tun_device_manager = TunDeviceManager::new(1280, 1).unwrap();
|
||||
|
||||
// Each cycle takes about half a second, so this will take a fair bit to run.
|
||||
for _ in 0..50 {
|
||||
|
||||
@@ -3,16 +3,18 @@
|
||||
use crate::FIREZONE_MARK;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use firezone_logging::std_dyn_err;
|
||||
use futures::task::AtomicWaker;
|
||||
use futures::TryStreamExt;
|
||||
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use libc::{
|
||||
close, fcntl, makedev, mknod, open, EEXIST, ENOENT, F_GETFL, F_SETFL, O_NONBLOCK, O_RDWR,
|
||||
S_IFCHR,
|
||||
fcntl, makedev, mknod, open, EEXIST, ENOENT, F_GETFL, F_SETFL, O_NONBLOCK, O_RDWR, S_IFCHR,
|
||||
};
|
||||
use netlink_packet_route::route::{RouteProtocol, RouteScope};
|
||||
use netlink_packet_route::rule::RuleAction;
|
||||
use rtnetlink::{new_connection, Error::NetlinkError, Handle, RouteAddRequest, RuleAddRequest};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
@@ -21,13 +23,12 @@ use std::{
|
||||
use std::{
|
||||
ffi::CStr,
|
||||
fs, io,
|
||||
os::{
|
||||
fd::{AsRawFd, RawFd},
|
||||
unix::fs::PermissionsExt,
|
||||
},
|
||||
os::{fd::RawFd, unix::fs::PermissionsExt},
|
||||
};
|
||||
use tokio::io::unix::AsyncFd;
|
||||
use tokio::sync::mpsc;
|
||||
use tun::ioctl;
|
||||
use tun::unix::TunFd;
|
||||
|
||||
const TUNSETIFF: libc::c_ulong = 0x4004_54ca;
|
||||
const TUN_DEV_MAJOR: u32 = 10;
|
||||
@@ -40,6 +41,7 @@ const FIREZONE_TABLE: u32 = 0x2021_fd00;
|
||||
/// For lack of a better name
|
||||
pub struct TunDeviceManager {
|
||||
mtu: u32,
|
||||
num_threads: usize,
|
||||
connection: Connection,
|
||||
routes: HashSet<IpNetwork>,
|
||||
}
|
||||
@@ -61,7 +63,7 @@ impl TunDeviceManager {
|
||||
/// Creates a new managed tunnel device.
|
||||
///
|
||||
/// Panics if called without a Tokio runtime.
|
||||
pub fn new(mtu: usize) -> Result<Self> {
|
||||
pub fn new(mtu: usize, num_threads: usize) -> Result<Self> {
|
||||
let (cxn, handle, _) = new_connection()?;
|
||||
let task = tokio::spawn(cxn);
|
||||
let connection = Connection { handle, task };
|
||||
@@ -70,11 +72,12 @@ impl TunDeviceManager {
|
||||
connection,
|
||||
routes: Default::default(),
|
||||
mtu: mtu as u32,
|
||||
num_threads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_tun(&mut self) -> Result<Tun> {
|
||||
Ok(Tun::new()?)
|
||||
Ok(Tun::new(self.num_threads)?)
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
@@ -284,62 +287,100 @@ async fn remove_route(route: &IpNetwork, idx: u32, handle: &Handle) {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Tun {
|
||||
fd: AsyncFd<RawFd>,
|
||||
outbound_tx: flume::Sender<IpPacket>,
|
||||
outbound_capacity_waker: Arc<AtomicWaker>,
|
||||
inbound_rx: mpsc::Receiver<IpPacket>,
|
||||
}
|
||||
|
||||
impl Tun {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
pub fn new(num_threads: usize) -> io::Result<Self> {
|
||||
create_tun_device()?;
|
||||
|
||||
let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
|
||||
-1 => return Err(get_last_error()),
|
||||
fd => fd,
|
||||
};
|
||||
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
|
||||
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
|
||||
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
|
||||
|
||||
// Safety: We just opened the file descriptor.
|
||||
unsafe {
|
||||
ioctl::exec(
|
||||
fd,
|
||||
TUNSETIFF,
|
||||
&mut ioctl::Request::<ioctl::SetTunFlagsPayload>::new(TunDeviceManager::IFACE_NAME),
|
||||
)?;
|
||||
for n in 0..num_threads {
|
||||
let fd = AsyncFd::new(open_tun()?)?;
|
||||
let outbound_rx = outbound_rx.clone().into_stream();
|
||||
let inbound_tx = inbound_tx.clone();
|
||||
let outbound_capacity_waker = outbound_capacity_waker.clone();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name(format!("TUN send/recv {n}/{num_threads}"))
|
||||
.spawn(move || {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(tun::unix::send_recv_tun(
|
||||
fd,
|
||||
inbound_tx,
|
||||
outbound_rx,
|
||||
outbound_capacity_waker,
|
||||
read,
|
||||
write,
|
||||
));
|
||||
|
||||
io::Result::Ok(())
|
||||
})
|
||||
.map_err(io::Error::other)?;
|
||||
}
|
||||
|
||||
set_non_blocking(fd)?;
|
||||
|
||||
// Safety: We just opened the fd.
|
||||
unsafe { Self::from_fd(fd) }
|
||||
}
|
||||
|
||||
/// Create a new [`Tun`] from a raw file descriptor.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The file descriptor must be open.
|
||||
unsafe fn from_fd(fd: RawFd) -> io::Result<Self> {
|
||||
Ok(Tun {
|
||||
fd: AsyncFd::new(fd)?,
|
||||
Ok(Self {
|
||||
outbound_tx,
|
||||
outbound_capacity_waker,
|
||||
inbound_rx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Tun {
|
||||
fn drop(&mut self) {
|
||||
unsafe { close(self.fd.as_raw_fd()) };
|
||||
fn open_tun() -> Result<TunFd, io::Error> {
|
||||
let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
|
||||
-1 => return Err(get_last_error()),
|
||||
fd => fd,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
ioctl::exec(
|
||||
fd,
|
||||
TUNSETIFF,
|
||||
&mut ioctl::Request::<ioctl::SetTunFlagsPayload>::new(TunDeviceManager::IFACE_NAME),
|
||||
)?;
|
||||
}
|
||||
|
||||
set_non_blocking(fd)?;
|
||||
|
||||
// Safety: We are not closing the FD.
|
||||
let fd = unsafe { TunFd::new(fd) };
|
||||
|
||||
Ok(fd)
|
||||
}
|
||||
|
||||
impl tun::Tun for Tun {
|
||||
fn write4(&self, buf: &[u8]) -> io::Result<usize> {
|
||||
write(self.fd.as_raw_fd(), buf)
|
||||
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
if self.outbound_tx.is_full() {
|
||||
self.outbound_capacity_waker.register(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn write6(&self, buf: &[u8]) -> io::Result<usize> {
|
||||
write(self.fd.as_raw_fd(), buf)
|
||||
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
|
||||
self.outbound_tx
|
||||
.try_send(packet)
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
|
||||
tun::unix::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
|
||||
fn poll_recv_many(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
buf: &mut Vec<IpPacket>,
|
||||
max: usize,
|
||||
) -> Poll<usize> {
|
||||
self.inbound_rx.poll_recv_many(cx, buf, max)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
@@ -389,7 +430,9 @@ fn create_tun_device() -> io::Result<()> {
|
||||
}
|
||||
|
||||
/// Read from the given file descriptor in the buffer.
|
||||
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
|
||||
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
|
||||
let dst = dst.buf();
|
||||
|
||||
// Safety: Within this module, the file descriptor is always valid.
|
||||
match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
@@ -397,8 +440,10 @@ fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Write the buffer to the given file descriptor.
|
||||
fn write(fd: RawFd, buf: &[u8]) -> io::Result<usize> {
|
||||
/// Write the packet to the given file descriptor.
|
||||
fn write(fd: RawFd, packet: &IpPacket) -> io::Result<usize> {
|
||||
let buf = packet.packet();
|
||||
|
||||
// Safety: Within this module, the file descriptor is always valid.
|
||||
match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } {
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::TUNNEL_NAME;
|
||||
use anyhow::{Context as _, Result};
|
||||
use firezone_logging::{anyhow_dyn_err, std_dyn_err};
|
||||
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use ring::digest;
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
@@ -12,7 +13,7 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
process::{Command, Stdio},
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use windows::Win32::{
|
||||
@@ -47,7 +48,7 @@ pub struct TunDeviceManager {
|
||||
|
||||
impl TunDeviceManager {
|
||||
#[expect(clippy::unnecessary_wraps, reason = "Fallible on Linux")]
|
||||
pub fn new(mtu: usize) -> Result<Self> {
|
||||
pub fn new(mtu: usize, _num_threads: usize) -> Result<Self> {
|
||||
Ok(Self {
|
||||
iface_idx: None,
|
||||
routes: HashSet::default(),
|
||||
@@ -190,7 +191,7 @@ pub struct Tun {
|
||||
/// The index of our network adapter, we can use this when asking Windows to add / remove routes / DNS rules
|
||||
/// It's stable across app restarts and I'm assuming across system reboots too.
|
||||
iface_idx: u32,
|
||||
packet_rx: mpsc::Receiver<wintun::Packet>,
|
||||
inbound_rx: mpsc::Receiver<IpPacket>,
|
||||
recv_thread: Option<std::thread::JoinHandle<()>>,
|
||||
session: Arc<wintun::Session>,
|
||||
}
|
||||
@@ -198,10 +199,10 @@ pub struct Tun {
|
||||
impl Drop for Tun {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!(
|
||||
channel_capacity = self.packet_rx.capacity(),
|
||||
channel_capacity = self.inbound_rx.capacity(),
|
||||
"Shutting down packet channel..."
|
||||
);
|
||||
self.packet_rx.close(); // This avoids a deadlock when we join the worker thread, see PR 5571
|
||||
self.inbound_rx.close(); // This avoids a deadlock when we join the worker thread, see PR 5571
|
||||
if let Err(error) = self.session.shutdown() {
|
||||
tracing::error!(error = std_dyn_err(&error), "wintun::Session::shutdown");
|
||||
}
|
||||
@@ -241,16 +242,14 @@ impl Tun {
|
||||
.start_session(RING_BUFFER_SIZE)
|
||||
.context("Failed to start session")?,
|
||||
);
|
||||
// 4 is a nice power of two. Wintun already queues packets for us, so we don't
|
||||
// need much capacity here.
|
||||
let (packet_tx, packet_rx) = mpsc::channel(4);
|
||||
let recv_thread = start_recv_thread(packet_tx, Arc::clone(&session))
|
||||
let (inbound_tx, inbound_rx) = mpsc::channel(1000); // We want to be able to batch-receive from this.
|
||||
let recv_thread = start_recv_thread(inbound_tx, Arc::clone(&session))
|
||||
.context("Failed to start recv thread")?;
|
||||
|
||||
Ok(Self {
|
||||
iface_idx,
|
||||
recv_thread: Some(recv_thread),
|
||||
packet_rx,
|
||||
inbound_rx,
|
||||
session: Arc::clone(&session),
|
||||
})
|
||||
}
|
||||
@@ -258,72 +257,60 @@ impl Tun {
|
||||
pub fn iface_idx(&self) -> u32 {
|
||||
self.iface_idx
|
||||
}
|
||||
|
||||
// Moves packets from the Internet towards the user
|
||||
fn write(&self, bytes: &[u8]) -> io::Result<usize> {
|
||||
let len = bytes
|
||||
.len()
|
||||
.try_into()
|
||||
.map_err(|_| io::Error::other("Packet too large; length does not fit into u16"))?;
|
||||
|
||||
let Ok(mut pkt) = self.session.allocate_send_packet(len) else {
|
||||
// Ring buffer is full, just drop the packet since we're at the IP layer
|
||||
return Ok(0);
|
||||
};
|
||||
|
||||
pkt.bytes_mut().copy_from_slice(bytes);
|
||||
// `send_packet` cannot fail to enqueue the packet, since we already allocated
|
||||
// space in the ring buffer.
|
||||
self.session.send_packet(pkt);
|
||||
Ok(bytes.len())
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::Tun for Tun {
|
||||
// Moves packets from the user towards the Internet
|
||||
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
|
||||
let pkt = ready!(self.packet_rx.poll_recv(cx));
|
||||
|
||||
match pkt {
|
||||
Some(pkt) => {
|
||||
let bytes = pkt.bytes();
|
||||
let len = bytes.len();
|
||||
if len > buf.len() {
|
||||
// This shouldn't happen now that we set IPv4 and IPv6 MTU
|
||||
// If it does, something is wrong.
|
||||
tracing::warn!("Packet is too long to read ({len} bytes)");
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
buf[0..len].copy_from_slice(bytes);
|
||||
Poll::Ready(Ok(len))
|
||||
}
|
||||
None => {
|
||||
tracing::error!("error receiving packet from mpsc channel");
|
||||
Poll::Ready(Err(std::io::ErrorKind::Other.into()))
|
||||
}
|
||||
}
|
||||
/// Receive a batch of packets up to `max`.
|
||||
fn poll_recv_many(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
buf: &mut Vec<IpPacket>,
|
||||
max: usize,
|
||||
) -> Poll<usize> {
|
||||
self.inbound_rx.poll_recv_many(cx, buf, max)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
TUNNEL_NAME
|
||||
}
|
||||
|
||||
fn write4(&self, bytes: &[u8]) -> io::Result<usize> {
|
||||
self.write(bytes)
|
||||
/// Check if more packets can be sent.
|
||||
fn poll_send_ready(&mut self, _: &mut Context) -> Poll<io::Result<()>> {
|
||||
// TODO: Figure out how we can do readiness checks on `wintun`.
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn write6(&self, bytes: &[u8]) -> io::Result<usize> {
|
||||
self.write(bytes)
|
||||
/// Send a packet.
|
||||
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
|
||||
let bytes = packet.packet();
|
||||
|
||||
let len = bytes
|
||||
.len()
|
||||
.try_into()
|
||||
.map_err(|_| io::Error::other("Packet too large; length does not fit into u16"))?;
|
||||
|
||||
let mut pkt = self
|
||||
.session
|
||||
.allocate_send_packet(len)
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
pkt.bytes_mut().copy_from_slice(bytes);
|
||||
// `send_packet` cannot fail to enqueue the packet, since we already allocated
|
||||
// space in the ring buffer.
|
||||
self.session.send_packet(pkt);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Moves packets from the user towards the Internet
|
||||
fn start_recv_thread(
|
||||
packet_tx: mpsc::Sender<wintun::Packet>,
|
||||
packet_tx: mpsc::Sender<IpPacket>,
|
||||
session: Arc<wintun::Session>,
|
||||
) -> io::Result<std::thread::JoinHandle<()>> {
|
||||
std::thread::Builder::new()
|
||||
.name("Firezone wintun worker".into())
|
||||
.name("TUN recv".into())
|
||||
.spawn(move || loop {
|
||||
let pkt = match session.receive_blocking() {
|
||||
Ok(pkt) => pkt,
|
||||
@@ -339,6 +326,26 @@ fn start_recv_thread(
|
||||
}
|
||||
};
|
||||
|
||||
let mut ip_packet_buf = IpPacketBuf::new();
|
||||
|
||||
let src = pkt.bytes();
|
||||
let dst = ip_packet_buf.buf();
|
||||
|
||||
if src.len() > dst.len() {
|
||||
tracing::warn!(len = %src.len(), "Received too large packet");
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[..src.len()].copy_from_slice(src);
|
||||
|
||||
let pkt = match IpPacket::new(ip_packet_buf, src.len()) {
|
||||
Ok(pkt) => pkt,
|
||||
Err(e) => {
|
||||
tracing::debug!("Failed to parse IP packet: {e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Use `blocking_send` so that if connlib is behind by a few packets,
|
||||
// Wintun will queue up new packets in its ring buffer while we
|
||||
// wait for our MPSC channel to clear.
|
||||
@@ -352,7 +359,7 @@ fn start_recv_thread(
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ connlib-client-shared = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
firezone-telemetry = { workspace = true }
|
||||
flume = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
ip-packet = { workspace = true }
|
||||
ip_network = { workspace = true }
|
||||
jni = { workspace = true, features = ["invocation"] }
|
||||
libc = { workspace = true }
|
||||
@@ -27,7 +30,7 @@ secrecy = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
socket-factory = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread"] }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "sync"] }
|
||||
tracing = { workspace = true, features = ["std", "attributes"] }
|
||||
tracing-appender = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
@@ -1,39 +1,51 @@
|
||||
use futures::task::AtomicWaker;
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
io,
|
||||
os::fd::{AsRawFd, RawFd},
|
||||
};
|
||||
use std::{io, os::fd::RawFd};
|
||||
use tokio::io::unix::AsyncFd;
|
||||
use tokio::sync::mpsc;
|
||||
use tun::ioctl;
|
||||
use tun::unix::TunFd;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Tun {
|
||||
fd: AsyncFd<RawFd>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl Drop for Tun {
|
||||
fn drop(&mut self) {
|
||||
unsafe { libc::close(self.fd.as_raw_fd()) };
|
||||
}
|
||||
outbound_tx: flume::Sender<IpPacket>,
|
||||
outbound_capacity_waker: Arc<AtomicWaker>,
|
||||
inbound_rx: mpsc::Receiver<IpPacket>,
|
||||
}
|
||||
|
||||
impl tun::Tun for Tun {
|
||||
fn write4(&self, src: &[u8]) -> std::io::Result<usize> {
|
||||
write(self.fd.as_raw_fd(), src)
|
||||
}
|
||||
|
||||
fn write6(&self, src: &[u8]) -> std::io::Result<usize> {
|
||||
write(self.fd.as_raw_fd(), src)
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
|
||||
tun::unix::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
if self.outbound_tx.is_full() {
|
||||
self.outbound_capacity_waker.register(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
|
||||
self.outbound_tx
|
||||
.try_send(packet)
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_recv_many(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
buf: &mut Vec<IpPacket>,
|
||||
max: usize,
|
||||
) -> Poll<usize> {
|
||||
self.inbound_rx.poll_recv_many(cx, buf, max)
|
||||
}
|
||||
}
|
||||
|
||||
impl Tun {
|
||||
@@ -41,13 +53,49 @@ impl Tun {
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The file descriptor must be open.
|
||||
/// - The file descriptor must be open.
|
||||
/// - The file descriptor must not get closed by anyone else.
|
||||
pub unsafe fn from_fd(fd: RawFd) -> io::Result<Self> {
|
||||
let name = interface_name(fd)?;
|
||||
|
||||
// Safety: We are forwarding the safety requirements to the caller.
|
||||
let fd = unsafe { TunFd::new(fd) };
|
||||
|
||||
let fd = AsyncFd::new(fd)?;
|
||||
|
||||
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
|
||||
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
|
||||
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
|
||||
|
||||
// TODO: Test whether we can set `IFF_MULTI_QUEUE` on Android devices.
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("TUN send/recv".to_owned())
|
||||
.spawn({
|
||||
let outbound_capacity_waker = outbound_capacity_waker.clone();
|
||||
|| {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(tun::unix::send_recv_tun(
|
||||
fd,
|
||||
inbound_tx,
|
||||
outbound_rx.into_stream(),
|
||||
outbound_capacity_waker,
|
||||
read,
|
||||
write,
|
||||
));
|
||||
|
||||
io::Result::Ok(())
|
||||
}
|
||||
})
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(Tun {
|
||||
fd: AsyncFd::new(fd)?,
|
||||
name,
|
||||
outbound_tx,
|
||||
inbound_rx,
|
||||
outbound_capacity_waker,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -67,7 +115,9 @@ unsafe fn interface_name(fd: RawFd) -> io::Result<String> {
|
||||
}
|
||||
|
||||
/// Read from the given file descriptor in the buffer.
|
||||
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
|
||||
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
|
||||
let dst = dst.buf();
|
||||
|
||||
// Safety: Within this module, the file descriptor is always valid.
|
||||
match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
@@ -75,10 +125,12 @@ fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Write the buffer to the given file descriptor.
|
||||
fn write(fd: RawFd, buf: &[u8]) -> io::Result<usize> {
|
||||
/// Write the packet to the given file descriptor.
|
||||
fn write(fd: RawFd, packet: &IpPacket) -> io::Result<usize> {
|
||||
let buf = packet.packet();
|
||||
|
||||
// Safety: Within this module, the file descriptor is always valid.
|
||||
match unsafe { libc::write(fd.as_raw_fd(), buf.as_ptr() as _, buf.len() as _) } {
|
||||
match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } {
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
n => Ok(n as usize),
|
||||
}
|
||||
|
||||
@@ -15,6 +15,9 @@ connlib-client-shared = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
firezone-telemetry = { workspace = true }
|
||||
flume = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
ip-packet = { workspace = true }
|
||||
ip_network = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
phoenix-channel = { workspace = true }
|
||||
@@ -23,7 +26,7 @@ secrecy = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
socket-factory = { workspace = true }
|
||||
swift-bridge = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread"] }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "sync"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-appender = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
@@ -1,76 +1,96 @@
|
||||
use futures::task::AtomicWaker;
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use libc::{fcntl, iovec, msghdr, recvmsg, AF_INET, AF_INET6, F_GETFL, F_SETFL, O_NONBLOCK};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
io,
|
||||
os::fd::{AsRawFd as _, RawFd},
|
||||
};
|
||||
use tokio::io::unix::AsyncFd;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Tun {
|
||||
name: String,
|
||||
fd: AsyncFd<RawFd>,
|
||||
outbound_capacity_waker: Arc<AtomicWaker>,
|
||||
outbound_tx: flume::Sender<IpPacket>,
|
||||
inbound_rx: mpsc::Receiver<IpPacket>,
|
||||
}
|
||||
|
||||
impl Tun {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
let fd = search_for_tun_fd()?;
|
||||
set_non_blocking(fd)?;
|
||||
|
||||
let name = name(fd)?;
|
||||
|
||||
Ok(Self {
|
||||
let fd = AsyncFd::new(fd)?;
|
||||
|
||||
let (inbound_tx, inbound_rx) = mpsc::channel(1000);
|
||||
let (outbound_tx, outbound_rx) = flume::bounded(1000); // flume is an MPMC channel, therefore perfect for workstealing outbound packets.
|
||||
let outbound_capacity_waker = Arc::new(AtomicWaker::new());
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("TUN send/recv".to_owned())
|
||||
.spawn({
|
||||
let outbound_capacity_waker = outbound_capacity_waker.clone();
|
||||
|| {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(tun::unix::send_recv_tun(
|
||||
fd,
|
||||
inbound_tx,
|
||||
outbound_rx.into_stream(),
|
||||
outbound_capacity_waker,
|
||||
read,
|
||||
write,
|
||||
));
|
||||
|
||||
io::Result::Ok(())
|
||||
}
|
||||
})
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(Tun {
|
||||
name,
|
||||
fd: AsyncFd::new(fd)?,
|
||||
outbound_tx,
|
||||
inbound_rx,
|
||||
outbound_capacity_waker,
|
||||
})
|
||||
}
|
||||
|
||||
fn write(&self, src: &[u8], af: u8) -> io::Result<usize> {
|
||||
let mut hdr = [0, 0, 0, af];
|
||||
let mut iov = [
|
||||
iovec {
|
||||
iov_base: hdr.as_mut_ptr() as _,
|
||||
iov_len: hdr.len(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: src.as_ptr() as _,
|
||||
iov_len: src.len(),
|
||||
},
|
||||
];
|
||||
|
||||
let msg_hdr = msghdr {
|
||||
msg_name: std::ptr::null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: &mut iov[0],
|
||||
msg_iovlen: iov.len() as _,
|
||||
msg_control: std::ptr::null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
};
|
||||
|
||||
match unsafe { libc::sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) } {
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
n => Ok(n as usize),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tun::Tun for Tun {
|
||||
fn write4(&self, src: &[u8]) -> io::Result<usize> {
|
||||
self.write(src, AF_INET as u8)
|
||||
}
|
||||
|
||||
fn write6(&self, src: &[u8]) -> io::Result<usize> {
|
||||
self.write(src, AF_INET6 as u8)
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
|
||||
tun::unix::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
if self.outbound_tx.is_full() {
|
||||
self.outbound_capacity_waker.register(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn send(&mut self, packet: IpPacket) -> io::Result<()> {
|
||||
self.outbound_tx
|
||||
.try_send(packet)
|
||||
.map_err(io::Error::other)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_recv_many(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
buf: &mut Vec<IpPacket>,
|
||||
max: usize,
|
||||
) -> Poll<usize> {
|
||||
self.inbound_rx.poll_recv_many(cx, buf, max)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_last_error() -> io::Error {
|
||||
@@ -87,7 +107,9 @@ fn set_non_blocking(fd: RawFd) -> io::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
|
||||
fn read(fd: RawFd, dst: &mut IpPacketBuf) -> io::Result<usize> {
|
||||
let dst = dst.buf();
|
||||
|
||||
let mut hdr = [0u8; 4];
|
||||
|
||||
let mut iov = [
|
||||
@@ -119,6 +141,41 @@ fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
|
||||
}
|
||||
}
|
||||
|
||||
fn write(fd: RawFd, src: &IpPacket) -> io::Result<usize> {
|
||||
let af = match src {
|
||||
IpPacket::Ipv4(_) => AF_INET,
|
||||
IpPacket::Ipv6(_) => AF_INET6,
|
||||
};
|
||||
let src = src.packet();
|
||||
|
||||
let mut hdr = [0, 0, 0, af];
|
||||
let mut iov = [
|
||||
iovec {
|
||||
iov_base: hdr.as_mut_ptr() as _,
|
||||
iov_len: hdr.len(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: src.as_ptr() as _,
|
||||
iov_len: src.len(),
|
||||
},
|
||||
];
|
||||
|
||||
let msg_hdr = msghdr {
|
||||
msg_name: std::ptr::null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: &mut iov[0],
|
||||
msg_iovlen: iov.len() as _,
|
||||
msg_control: std::ptr::null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
};
|
||||
|
||||
match unsafe { libc::sendmsg(fd.as_raw_fd(), &msg_hdr, 0) } {
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
n => Ok(n as usize),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
||||
fn name(fd: RawFd) -> io::Result<String> {
|
||||
use libc::{getsockopt, socklen_t, IF_NAMESIZE, SYSPROTO_CONTROL, UTUN_OPT_IFNAME};
|
||||
|
||||
@@ -38,7 +38,6 @@ socket-factory = { workspace = true }
|
||||
socket2 = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true, features = ["attributes"] }
|
||||
tun = { workspace = true }
|
||||
uuid = { workspace = true, features = ["std", "v4"] }
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use domain::base::iana::Rcode;
|
||||
use domain::base::{Message, ParsedName, Rtype};
|
||||
use domain::rdata::AllRecordData;
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use ip_packet::IpPacket;
|
||||
use itertools::Itertools;
|
||||
use std::io;
|
||||
use std::task::{Context, Poll, Waker};
|
||||
@@ -31,47 +31,46 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<IpPacket>> {
|
||||
pub(crate) fn poll_read_many(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut Vec<IpPacket>,
|
||||
max: usize,
|
||||
) -> Poll<usize> {
|
||||
let Some(tun) = self.tun.as_mut() else {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
let mut ip_packet = IpPacketBuf::new();
|
||||
let n = std::task::ready!(tun.poll_read(ip_packet.buf(), cx))?;
|
||||
let n = std::task::ready!(tun.poll_recv_many(cx, buf, max));
|
||||
|
||||
if n == 0 {
|
||||
self.tun = None;
|
||||
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"device is closed",
|
||||
)));
|
||||
}
|
||||
|
||||
let packet = IpPacket::new(ip_packet, n).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("Failed to parse IP packet: {e:#}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
if tracing::event_enabled!(target: "wire::dns::qry", Level::TRACE) {
|
||||
if let Some((qtype, qname, qid)) = parse_dns_query(&packet) {
|
||||
tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
|
||||
for packet in &buf[..n] {
|
||||
if tracing::event_enabled!(target: "wire::dns::qry", Level::TRACE) {
|
||||
if let Some((qtype, qname, qid)) = parse_dns_query(packet) {
|
||||
tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if packet.is_fz_p2p_control() {
|
||||
tracing::warn!("Packet matches heuristics of FZ-internal p2p control protocol");
|
||||
}
|
||||
|
||||
tracing::trace!(target: "wire::dev::recv", dst = %packet.destination(), src = %packet.source(), bytes = %packet.packet().len());
|
||||
}
|
||||
|
||||
if packet.is_fz_p2p_control() {
|
||||
tracing::warn!("Packet matches heuristics of FZ-internal p2p control protocol");
|
||||
}
|
||||
|
||||
tracing::trace!(target: "wire::dev::recv", dst = %packet.destination(), src = %packet.source(), bytes = %packet.packet().len());
|
||||
|
||||
Poll::Ready(Ok(packet))
|
||||
Poll::Ready(n)
|
||||
}
|
||||
|
||||
pub fn write(&self, packet: IpPacket) -> io::Result<usize> {
|
||||
pub fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let Some(tun) = self.tun.as_mut() else {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
tun.poll_send_ready(cx)
|
||||
}
|
||||
|
||||
pub fn send(&mut self, packet: IpPacket) -> io::Result<()> {
|
||||
if tracing::event_enabled!(target: "wire::dns::res", Level::TRACE) {
|
||||
if let Some((qtype, qname, records, rcode, qid)) = parse_dns_response(&packet) {
|
||||
tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
|
||||
@@ -85,18 +84,17 @@ impl Device {
|
||||
"FZ p2p control protocol packets should never leave `connlib`"
|
||||
);
|
||||
|
||||
match packet {
|
||||
IpPacket::Ipv4(msg) => self.tun()?.write4(msg.packet()),
|
||||
IpPacket::Ipv6(msg) => self.tun()?.write6(msg.packet()),
|
||||
}
|
||||
self.tun()?.send(packet)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tun(&self) -> io::Result<&dyn Tun> {
|
||||
fn tun(&mut self) -> io::Result<&mut dyn Tun> {
|
||||
Ok(self
|
||||
.tun
|
||||
.as_ref()
|
||||
.as_mut()
|
||||
.ok_or_else(io_error_not_initialized)?
|
||||
.as_ref())
|
||||
.as_mut())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,7 @@ mod gso_queue;
|
||||
|
||||
use crate::{device_channel::Device, dns, sockets::Sockets};
|
||||
use domain::base::Message;
|
||||
use firezone_logging::{err_with_src, telemetry_event, telemetry_span};
|
||||
use futures::{
|
||||
future::{self, Either},
|
||||
stream, Stream, StreamExt,
|
||||
};
|
||||
use firezone_logging::{telemetry_event, telemetry_span};
|
||||
use futures_bounded::FuturesTupleSet;
|
||||
use futures_util::FutureExt as _;
|
||||
use gso_queue::GsoQueue;
|
||||
@@ -21,11 +17,7 @@ use std::{
|
||||
task::{ready, Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
sync::mpsc,
|
||||
};
|
||||
use tokio_util::sync::PollSender;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::Instrument;
|
||||
use tun::Tun;
|
||||
|
||||
@@ -33,7 +25,7 @@ use tun::Tun;
|
||||
///
|
||||
/// Reading IP packets from the channel in batches allows us to process (i.e. encrypt) them as a batch.
|
||||
/// UDP datagrams of the same size and destination can then be sent in a single syscall using GSO.
|
||||
const MAX_INBOUND_PACKET_BATCH: usize = 50;
|
||||
const MAX_INBOUND_PACKET_BATCH: usize = 100;
|
||||
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
|
||||
|
||||
/// Bundles together all side-effects that connlib needs to have access to.
|
||||
@@ -49,10 +41,8 @@ pub struct Io {
|
||||
|
||||
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
|
||||
|
||||
tun_tx: mpsc::Sender<Box<dyn Tun>>,
|
||||
tun: Device,
|
||||
outbound_packet_buffer: VecDeque<IpPacket>,
|
||||
outbound_packet_tx: PollSender<IpPacket>,
|
||||
inbound_packet_rx: mpsc::Receiver<IpPacket>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -86,7 +76,6 @@ pub enum Input<D, I> {
|
||||
}
|
||||
|
||||
const DNS_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const IP_CHANNEL_SIZE: usize = 1000;
|
||||
|
||||
impl Io {
|
||||
/// Creates a new I/O abstraction
|
||||
@@ -99,32 +88,15 @@ impl Io {
|
||||
let mut sockets = Sockets::default();
|
||||
sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context.
|
||||
|
||||
let (inbound_packet_tx, inbound_packet_rx) = mpsc::channel(IP_CHANNEL_SIZE);
|
||||
let (outbound_packet_tx, outbound_packet_rx) = mpsc::channel(IP_CHANNEL_SIZE);
|
||||
let (tun_tx, tun_rx) = mpsc::channel(10);
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("connlib-tun-send-recv".to_string())
|
||||
.spawn(|| {
|
||||
futures::executor::block_on(tun_send_recv(
|
||||
tun_rx,
|
||||
outbound_packet_rx,
|
||||
inbound_packet_tx,
|
||||
))
|
||||
})
|
||||
.expect("Failed to spawn tun_send_recv thread");
|
||||
|
||||
Self {
|
||||
tun_tx,
|
||||
outbound_packet_buffer: VecDeque::with_capacity(10), // It is unlikely that we process more than 10 packets after 1 GRO call.
|
||||
outbound_packet_tx: PollSender::new(outbound_packet_tx),
|
||||
inbound_packet_rx,
|
||||
timeout: None,
|
||||
sockets,
|
||||
tcp_socket_factory,
|
||||
udp_socket_factory,
|
||||
dns_queries: FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000),
|
||||
gso_queue: GsoQueue::new(),
|
||||
tun: Device::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,8 +123,8 @@ impl Io {
|
||||
}
|
||||
|
||||
if let Poll::Ready(num_packets) =
|
||||
self.inbound_packet_rx
|
||||
.poll_recv_many(cx, &mut buffers.ip, MAX_INBOUND_PACKET_BATCH)
|
||||
self.tun
|
||||
.poll_read_many(cx, &mut buffers.ip, MAX_INBOUND_PACKET_BATCH)
|
||||
{
|
||||
return Poll::Ready(Ok(Input::Device(buffers.ip.drain(..num_packets))));
|
||||
}
|
||||
@@ -209,11 +181,8 @@ impl Io {
|
||||
}
|
||||
|
||||
loop {
|
||||
// First, acquire a slot in the channel.
|
||||
ready!(self
|
||||
.outbound_packet_tx
|
||||
.poll_reserve(cx)
|
||||
.map_err(|_| io::ErrorKind::BrokenPipe)?);
|
||||
// First, check if we can send more packets.
|
||||
ready!(self.tun.poll_send_ready(cx))?;
|
||||
|
||||
// Second, check if we have any buffer packets.
|
||||
let Some(packet) = self.outbound_packet_buffer.pop_front() else {
|
||||
@@ -221,20 +190,14 @@ impl Io {
|
||||
};
|
||||
|
||||
// Third, send the packet.
|
||||
self.outbound_packet_tx
|
||||
.send_item(packet)
|
||||
.map_err(|_| io::ErrorKind::BrokenPipe)?;
|
||||
self.tun.send(packet)?;
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
pub fn set_tun(&mut self, tun: Box<dyn Tun>) {
|
||||
// If we can't set a new TUN device, shut down connlib.
|
||||
|
||||
self.tun_tx
|
||||
.try_send(tun)
|
||||
.expect("Channel to set new TUN device should always have capacity");
|
||||
self.tun.set_tun(tun);
|
||||
}
|
||||
|
||||
pub fn send_tun(&mut self, packet: IpPacket) {
|
||||
@@ -350,82 +313,6 @@ impl Io {
|
||||
}
|
||||
}
|
||||
|
||||
async fn tun_send_recv(
|
||||
mut tun_rx: mpsc::Receiver<Box<dyn Tun>>,
|
||||
mut outbound_packet_rx: mpsc::Receiver<IpPacket>,
|
||||
inbound_packet_tx: mpsc::Sender<IpPacket>,
|
||||
) {
|
||||
let mut device = Device::new();
|
||||
|
||||
let mut command_stream = stream::select_all([
|
||||
new_tun_stream(&mut tun_rx),
|
||||
outgoing_packet_stream(&mut outbound_packet_rx),
|
||||
]);
|
||||
|
||||
loop {
|
||||
match future::select(
|
||||
command_stream.next(),
|
||||
future::poll_fn(|cx| device.poll_read(cx)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Either::Left((Some(Command::SendPacket(p)), _)) => {
|
||||
if let Err(e) = device.write(p) {
|
||||
tracing::debug!("Failed to write TUN packet: {}", err_with_src(&e));
|
||||
};
|
||||
}
|
||||
Either::Left((Some(Command::UpdateTun(tun)), _)) => {
|
||||
device.set_tun(tun);
|
||||
}
|
||||
Either::Left((None, _)) => {
|
||||
tracing::debug!("Command stream closed");
|
||||
return;
|
||||
}
|
||||
Either::Right((Ok(p), _)) => {
|
||||
if inbound_packet_tx.send(p).await.is_err() {
|
||||
tracing::debug!("Inbound packet channel closed");
|
||||
return;
|
||||
};
|
||||
}
|
||||
Either::Right((Err(e), _)) => {
|
||||
tracing::debug!(
|
||||
"Failed to read packet from TUN device: {}",
|
||||
err_with_src(&e)
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::large_enum_variant,
|
||||
reason = "We purposely don't want to allocate each IP packet."
|
||||
)]
|
||||
enum Command {
|
||||
UpdateTun(Box<dyn Tun>),
|
||||
SendPacket(IpPacket),
|
||||
}
|
||||
|
||||
fn new_tun_stream(
|
||||
tun_rx: &mut mpsc::Receiver<Box<dyn Tun>>,
|
||||
) -> Pin<Box<dyn Stream<Item = Command> + '_>> {
|
||||
Box::pin(stream::poll_fn(|cx| {
|
||||
tun_rx
|
||||
.poll_recv(cx)
|
||||
.map(|maybe_t| maybe_t.map(Command::UpdateTun))
|
||||
}))
|
||||
}
|
||||
|
||||
fn outgoing_packet_stream(
|
||||
outbound_packet_rx: &mut mpsc::Receiver<IpPacket>,
|
||||
) -> Pin<Box<dyn Stream<Item = Command> + '_>> {
|
||||
Box::pin(stream::poll_fn(|cx| {
|
||||
outbound_packet_rx
|
||||
.poll_recv(cx)
|
||||
.map(|maybe_p| maybe_p.map(Command::SendPacket))
|
||||
}))
|
||||
}
|
||||
|
||||
fn is_max_wg_packet_size(d: &DatagramIn) -> bool {
|
||||
let len = d.packet.len();
|
||||
if len > MAX_DATAGRAM_PAYLOAD {
|
||||
@@ -444,14 +331,6 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn max_ip_channel_size_is_reasonable() {
|
||||
let one_ip_packet = std::mem::size_of::<IpPacket>();
|
||||
let max_channel_size = IP_CHANNEL_SIZE * one_ip_packet;
|
||||
|
||||
assert_eq!(max_channel_size, 1_360_000); // 1.36MB is fine, we only have 2 of these channels, meaning less than 3MB additional buffer in total.
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn timer_is_reset_after_it_fires() {
|
||||
let now = Instant::now();
|
||||
@@ -460,6 +339,7 @@ mod tests {
|
||||
Arc::new(|_| Err(io::Error::other("not implemented"))),
|
||||
Arc::new(|_| Err(io::Error::other("not implemented"))),
|
||||
);
|
||||
io.set_tun(Box::new(DummyTun));
|
||||
|
||||
io.reset_timeout(now + Duration::from_secs(1));
|
||||
|
||||
@@ -494,4 +374,29 @@ mod tests {
|
||||
udp4: Vec::new(),
|
||||
udp6: Vec::new(),
|
||||
};
|
||||
|
||||
struct DummyTun;
|
||||
|
||||
impl Tun for DummyTun {
|
||||
fn poll_send_ready(&mut self, _: &mut Context) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn send(&mut self, _: IpPacket) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_recv_many(
|
||||
&mut self,
|
||||
_: &mut Context,
|
||||
_: &mut Vec<IpPacket>,
|
||||
_: usize,
|
||||
) -> Poll<usize> {
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"dummy"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ use anyhow::{Context as _, Result};
|
||||
use domain::base::{iana::Rcode, MessageBuilder};
|
||||
use firezone_bin_shared::TunDeviceManager;
|
||||
use ip_network::Ipv4Network;
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use tokio::task::JoinSet;
|
||||
use tun::Tun;
|
||||
|
||||
@@ -24,7 +23,7 @@ async fn smoke() {
|
||||
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
|
||||
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
|
||||
|
||||
let mut device_manager = TunDeviceManager::new(1280).unwrap();
|
||||
let mut device_manager = TunDeviceManager::new(1280, 1).unwrap();
|
||||
let tun = device_manager.make_tun().unwrap();
|
||||
device_manager.set_ips(ipv4, ipv6).await.unwrap();
|
||||
device_manager
|
||||
@@ -100,11 +99,10 @@ impl Eventloop {
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<()> {
|
||||
loop {
|
||||
ready!(self.tun.poll_send_ready(cx)).unwrap();
|
||||
|
||||
if let Some(packet) = self.dns_server.poll_outbound() {
|
||||
match packet {
|
||||
IpPacket::Ipv4(v4) => self.tun.write4(v4.packet()).unwrap(),
|
||||
IpPacket::Ipv6(v6) => self.tun.write6(v6.packet()).unwrap(),
|
||||
};
|
||||
self.tun.send(packet).unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -120,12 +118,12 @@ impl Eventloop {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut packet_buf = IpPacketBuf::default();
|
||||
let num_read = ready!(self.tun.poll_read(packet_buf.buf(), cx)).unwrap();
|
||||
let packet = IpPacket::new(packet_buf, num_read).unwrap();
|
||||
let mut buf = Vec::with_capacity(1);
|
||||
ready!(self.tun.poll_recv_many(cx, &mut buf, 1));
|
||||
let ip_packet = buf.remove(0);
|
||||
|
||||
if self.dns_server.accepts(&packet) {
|
||||
self.dns_server.handle_inbound(packet);
|
||||
if self.dns_server.accepts(&ip_packet) {
|
||||
self.dns_server.handle_inbound(ip_packet);
|
||||
self.dns_server.handle_timeout(Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ serde = { workspace = true, features = ["std", "derive"] }
|
||||
snownet = { workspace = true }
|
||||
socket-factory = { workspace = true }
|
||||
static_assertions = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread", "fs", "signal"] }
|
||||
tokio = { workspace = true, features = ["sync", "macros", "fs", "signal", "rt"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
url = { workspace = true }
|
||||
|
||||
@@ -47,7 +47,10 @@ fn main() {
|
||||
);
|
||||
}
|
||||
|
||||
let runtime = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("Failed to create tokio runtime");
|
||||
|
||||
match runtime.block_on(try_main(cli, &mut telemetry)) {
|
||||
Ok(()) => runtime.block_on(telemetry.stop()),
|
||||
@@ -79,7 +82,7 @@ async fn try_main(cli: Cli, telemetry: &mut Telemetry) -> Result<()> {
|
||||
cli.firezone_name,
|
||||
)?;
|
||||
|
||||
let task = tokio::spawn(run(login)).err_into();
|
||||
let task = tokio::spawn(run(login, cli.tun_threads)).err_into();
|
||||
|
||||
let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new));
|
||||
|
||||
@@ -122,7 +125,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn run(login: LoginUrl<PublicKeyParam>) -> Result<Infallible> {
|
||||
async fn run(login: LoginUrl<PublicKeyParam>, num_tun_threads: usize) -> Result<Infallible> {
|
||||
let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory));
|
||||
let portal = PhoenixChannel::disconnected(
|
||||
Secret::new(login),
|
||||
@@ -138,7 +141,7 @@ async fn run(login: LoginUrl<PublicKeyParam>) -> Result<Infallible> {
|
||||
)?;
|
||||
|
||||
let (sender, receiver) = mpsc::channel::<Interface>(10);
|
||||
let mut tun_device_manager = TunDeviceManager::new(ip_packet::PACKET_SIZE)?;
|
||||
let mut tun_device_manager = TunDeviceManager::new(ip_packet::PACKET_SIZE, num_tun_threads)?;
|
||||
let tun = tun_device_manager.make_tun()?;
|
||||
tunnel.set_tun(Box::new(tun));
|
||||
|
||||
@@ -203,6 +206,10 @@ struct Cli {
|
||||
/// Identifier generated by the portal to identify and display the device.
|
||||
#[arg(short = 'i', long, env = "FIREZONE_ID")]
|
||||
pub firezone_id: Option<String>,
|
||||
|
||||
/// How many threads to use for reading and writing to the TUN device.
|
||||
#[arg(long, env = "FIREZONE_NUM_TUN_THREADS", default_value_t = 2)]
|
||||
tun_threads: usize,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
|
||||
@@ -31,7 +31,7 @@ smbios-lib = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
# This actually relies on many other features in Tokio, so this will probably
|
||||
# fail to build outside the workspace. <https://github.com/firezone/firezone/pull/4328#discussion_r1540342142>
|
||||
tokio = { workspace = true, features = ["macros", "signal", "process", "time", "rt-multi-thread", "fs"] }
|
||||
tokio = { workspace = true, features = ["macros", "signal", "process", "time", "fs", "rt"] }
|
||||
tokio-stream = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -213,7 +213,7 @@ mod tests {
|
||||
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
|
||||
let mut tun_dev_manager = firezone_bin_shared::TunDeviceManager::new(1280).unwrap();
|
||||
let mut tun_dev_manager = firezone_bin_shared::TunDeviceManager::new(1280, 1).unwrap(); // Note: num_threads (`1`) is unused on windows.
|
||||
let _tun = tun_dev_manager.make_tun().unwrap();
|
||||
|
||||
rt.block_on(async {
|
||||
|
||||
@@ -177,7 +177,9 @@ fn run_debug_ipc_service(cli: Cli) -> Result<()> {
|
||||
if !platform::elevation_check()? {
|
||||
bail!("IPC service failed its elevation check, try running as admin / root");
|
||||
}
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let _guard = rt.enter();
|
||||
let mut signals = signals::Terminate::new()?;
|
||||
|
||||
@@ -202,7 +204,9 @@ fn run_smoke_test() -> Result<()> {
|
||||
if !platform::elevation_check()? {
|
||||
bail!("IPC service failed its elevation check, try running as admin / root");
|
||||
}
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let _guard = rt.enter();
|
||||
let mut dns_controller = DnsController {
|
||||
dns_control_method: Default::default(),
|
||||
@@ -325,7 +329,7 @@ impl<'a> Handler<'a> {
|
||||
.next_client_split()
|
||||
.await
|
||||
.context("Failed to wait for incoming IPC connection from a GUI")?;
|
||||
let tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE)?;
|
||||
let tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE, crate::NUM_TUN_THREADS)?;
|
||||
|
||||
Ok(Self {
|
||||
dns_controller,
|
||||
|
||||
@@ -10,7 +10,9 @@ pub(crate) fn run_ipc_service(cli: CliCommon) -> Result<()> {
|
||||
if !elevation_check()? {
|
||||
bail!("IPC service failed its elevation check, try running as admin / root");
|
||||
}
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let _guard = rt.enter();
|
||||
let mut signals = signals::Terminate::new()?;
|
||||
|
||||
|
||||
@@ -155,7 +155,9 @@ fn fallible_service_run(
|
||||
bail!("IPC service failed its elevation check, try running as admin / root");
|
||||
}
|
||||
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
|
||||
|
||||
let event_handler = move |control_event| -> ServiceControlHandlerResult {
|
||||
|
||||
@@ -46,6 +46,9 @@ pub type LogFilterReloader = tracing_subscriber::reload::Handle<EnvFilter, Regis
|
||||
/// Only used on Linux
|
||||
pub const FIREZONE_GROUP: &str = "firezone-client";
|
||||
|
||||
/// Empirically tested to have the best performance.
|
||||
pub const NUM_TUN_THREADS: usize = 2;
|
||||
|
||||
/// CLI args common to both the IPC service and the headless Client
|
||||
#[derive(clap::Parser)]
|
||||
pub struct CliCommon {
|
||||
|
||||
@@ -236,7 +236,10 @@ fn main() -> Result<()> {
|
||||
let mut terminate = signals::Terminate::new()?;
|
||||
let mut hangup = signals::Hangup::new()?;
|
||||
|
||||
let mut tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE)?;
|
||||
let mut tun_device = TunDeviceManager::new(
|
||||
ip_packet::PACKET_SIZE,
|
||||
firezone_headless_client::NUM_TUN_THREADS,
|
||||
)?;
|
||||
let mut cb_rx = ReceiverStream::new(cb_rx).fuse();
|
||||
|
||||
let tokio_handle = tokio::runtime::Handle::current();
|
||||
|
||||
@@ -5,9 +5,15 @@ edition = { workspace = true }
|
||||
license = { workspace = true }
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
ip-packet = { workspace = true }
|
||||
|
||||
[target.'cfg(target_family = "unix")'.dependencies]
|
||||
libc = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
flume = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -36,7 +36,7 @@ impl Request<SetTunFlagsPayload> {
|
||||
Self {
|
||||
name,
|
||||
payload: SetTunFlagsPayload {
|
||||
flags: (libc::IFF_TUN | libc::IFF_NO_PI) as _,
|
||||
flags: (libc::IFF_TUN | libc::IFF_NO_PI | libc::IFF_MULTI_QUEUE) as _,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,14 +3,27 @@ use std::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use ip_packet::IpPacket;
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
pub mod ioctl;
|
||||
#[cfg(target_family = "unix")]
|
||||
pub mod unix;
|
||||
|
||||
pub trait Tun: Send + Sync + 'static {
|
||||
fn write4(&self, buf: &[u8]) -> io::Result<usize>;
|
||||
fn write6(&self, buf: &[u8]) -> io::Result<usize>;
|
||||
fn poll_read(&mut self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>>;
|
||||
/// Check if more packets can be sent.
|
||||
fn poll_send_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>;
|
||||
/// Send a packet.
|
||||
fn send(&mut self, packet: IpPacket) -> io::Result<()>;
|
||||
|
||||
/// Receive a batch of packets up to `max`.
|
||||
fn poll_recv_many(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
buf: &mut Vec<IpPacket>,
|
||||
max: usize,
|
||||
) -> Poll<usize>;
|
||||
|
||||
/// The name of the TUN device.
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
|
||||
@@ -1,25 +1,107 @@
|
||||
use futures::future::Either;
|
||||
use futures::task::AtomicWaker;
|
||||
use futures::StreamExt as _;
|
||||
use ip_packet::{IpPacket, IpPacketBuf};
|
||||
use std::io;
|
||||
use std::os::fd::{AsRawFd, RawFd};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::Ready;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::unix::AsyncFd;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub fn poll_raw_fd(
|
||||
fd: &tokio::io::unix::AsyncFd<RawFd>,
|
||||
mut read: impl FnMut(RawFd) -> io::Result<usize>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
pub struct TunFd {
|
||||
inner: RawFd,
|
||||
}
|
||||
|
||||
impl TunFd {
|
||||
/// # Safety
|
||||
///
|
||||
/// You must not close this FD yourself.
|
||||
/// [`TunFd`] will close it for you.
|
||||
pub unsafe fn new(fd: RawFd) -> Self {
|
||||
Self { inner: fd }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for TunFd {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TunFd {
|
||||
fn drop(&mut self) {
|
||||
// Safety: We are the only ones closing the FD.
|
||||
unsafe { libc::close(self.inner) };
|
||||
}
|
||||
}
|
||||
|
||||
/// Concurrently reads and writes packets to the given TUN file-descriptor using the provided function pointers for the actual syscall.
|
||||
///
|
||||
/// - Every packet received on `outbound_rx` channel will be written to the file descriptor using the `write` syscall.
|
||||
/// - Every packet read using the `read` syscall will be sent into the `inbound_tx` channel.
|
||||
/// - Every time we read a packet from `outbound_rx`, we notify `outbound_capacity_waker` about the newly gained capacity.
|
||||
/// - In case any of the channels close, we exit the task.
|
||||
/// - IO errors are not fallible.
|
||||
pub async fn send_recv_tun<T>(
|
||||
fd: AsyncFd<T>,
|
||||
inbound_tx: mpsc::Sender<IpPacket>,
|
||||
mut outbound_rx: flume::r#async::RecvStream<'static, IpPacket>,
|
||||
outbound_capacity_waker: Arc<AtomicWaker>,
|
||||
read: impl Fn(RawFd, &mut IpPacketBuf) -> io::Result<usize>,
|
||||
write: impl Fn(RawFd, &IpPacket) -> io::Result<usize>,
|
||||
) where
|
||||
T: AsRawFd,
|
||||
{
|
||||
loop {
|
||||
let mut guard = std::task::ready!(fd.poll_read_ready(cx))?;
|
||||
let next_inbound_packet = pin!(fd.async_io(tokio::io::Interest::READABLE, |fd| {
|
||||
let mut ip_packet_buf = IpPacketBuf::new();
|
||||
|
||||
match read(guard.get_inner().as_raw_fd()) {
|
||||
Ok(n) => return Poll::Ready(Ok(n)),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
// a read has blocked, but a write might still succeed.
|
||||
// clear only the read readiness.
|
||||
guard.clear_ready_matching(Ready::READABLE);
|
||||
let len = read(fd.as_raw_fd(), &mut ip_packet_buf)?;
|
||||
|
||||
if len == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let packet = IpPacket::new(ip_packet_buf, len)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
|
||||
|
||||
Ok(Some(packet))
|
||||
}));
|
||||
let next_outbound_packet = pin!(outbound_rx.next());
|
||||
|
||||
match futures::future::select(next_inbound_packet, next_outbound_packet).await {
|
||||
Either::Left((Ok(None), _)) => {
|
||||
tracing::error!("TUN FD is closed");
|
||||
return;
|
||||
}
|
||||
Either::Left((Ok(Some(packet)), _)) => {
|
||||
if inbound_tx.send(packet).await.is_err() {
|
||||
tracing::debug!("Inbound packet receiver gone, shutting down task");
|
||||
|
||||
return;
|
||||
};
|
||||
}
|
||||
Either::Left((Err(e), _)) => {
|
||||
tracing::warn!("Failed to read from TUN FD: {e}");
|
||||
continue;
|
||||
}
|
||||
Err(e) => return Poll::Ready(Err(e)),
|
||||
Either::Right((Some(packet), _)) => {
|
||||
if let Err(e) = fd
|
||||
.async_io(tokio::io::Interest::WRITABLE, |fd| {
|
||||
write(fd.as_raw_fd(), &packet)
|
||||
})
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to write to TUN FD: {e}");
|
||||
};
|
||||
|
||||
outbound_capacity_waker.wake(); // We wrote a packet, notify about the new capacity.
|
||||
}
|
||||
Either::Right((None, _)) => {
|
||||
tracing::debug!("Outbound packet sender gone, shutting down task");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,12 @@ export default function GUI({ title }: { title: string }) {
|
||||
Makes use of the new control protocol, delivering faster and more
|
||||
robust connection establishment.
|
||||
</ChangeItem>
|
||||
{title == "Linux GUI" && (
|
||||
<ChangeItem pull="7449">
|
||||
Uses multiple threads to read & write to the TUN device, greatly
|
||||
improving performance.
|
||||
</ChangeItem>
|
||||
)}
|
||||
</Unreleased>
|
||||
<Entry version="1.3.13" date={new Date("2024-11-15")}>
|
||||
<ChangeItem pull="7334">
|
||||
|
||||
@@ -19,6 +19,11 @@ export default function Gateway() {
|
||||
Fixes cases where client applications such as ssh would fail to
|
||||
automatically determine the correct IP protocol version to use (4/6).
|
||||
</ChangeItem>
|
||||
<ChangeItem pull="7449">
|
||||
Uses multiple threads to read & write to the TUN device, greatly
|
||||
improving performance. The number of threads can be controlled with
|
||||
`FIREZONE_NUM_TUN_THREADS` and defaults to 2.
|
||||
</ChangeItem>
|
||||
</Unreleased>
|
||||
<Entry version="1.4.1" date={new Date("2024-11-15")}>
|
||||
<ChangeItem pull="7263">
|
||||
|
||||
@@ -23,6 +23,10 @@ export default function Headless() {
|
||||
Makes use of the new control protocol, delivering faster and more
|
||||
robust connection establishment.
|
||||
</ChangeItem>
|
||||
<ChangeItem pull="7449">
|
||||
Uses multiple threads to read & write to the TUN device, greatly
|
||||
improving performance.
|
||||
</ChangeItem>
|
||||
</Unreleased>
|
||||
<Entry version="1.3.7" date={new Date("2024-11-15")}>
|
||||
<ChangeItem pull="7334">
|
||||
|
||||
Reference in New Issue
Block a user