refactor(connlib): remove async from the Device API (#2815)

At present, the definition of `Device` is heavily nested with
conditional code. I've found this hard to understand and navigate.
Recent refactorings now made it possible to remove a lot of these layers
so we primarily deal with two concepts:

- A `Device` which offers async read and non-blocking write functions
- A `Tun` abstraction which is platform-specific

Instead of dedicated modules, I chose to feature-flag individual
functions on `Device` with `#[cfg(target_family = "unix")]` and
`#[cfg(target_family = "windows")]`. I find this easier to understand
because the code is right next to each other.

In addition, changing the module hierarchy of `Device` allows us to
remove `async` from the public API which is only introduced by the use
of `rtnetlink` in Linux. Instead of making functions across all `Tun`
implementations `async`, we embed a "worker" within the `linux::Tun`
implementation that gets polled before `poll_read`.

---------

Co-authored-by: Gabi <gabrielalejandro7@gmail.com>
This commit is contained in:
Thomas Eizinger
2023-12-13 06:47:26 +11:00
committed by GitHub
parent cd3114cc1d
commit 0de16d3676
19 changed files with 757 additions and 878 deletions

View File

@@ -88,7 +88,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
{
let mut init = self.tunnel_init.lock().await;
if !*init {
if let Err(e) = self.tunnel.set_interface(&interface).await {
if let Err(e) = self.tunnel.set_interface(&interface) {
tracing::error!(error = ?e, "Error initializing interface");
return Err(e);
} else {
@@ -103,7 +103,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
for resource_description in resources {
self.add_resource(resource_description).await;
self.add_resource(resource_description);
}
Ok(())
}
@@ -141,8 +141,8 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(&self, resource_description: ResourceDescription) {
if let Err(e) = self.tunnel.add_resource(resource_description).await {
pub fn add_resource(&self, resource_description: ResourceDescription) {
if let Err(e) = self.tunnel.add_resource(resource_description) {
tracing::error!(message = "Can't add resource", error = ?e);
let _ = self.tunnel.callbacks().on_error(&e);
}
@@ -240,7 +240,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
self.connection_details(connection_details, reference)
}
Messages::Connect(connect) => self.connect(connect),
Messages::ResourceAdded(resource) => self.add_resource(resource).await,
Messages::ResourceAdded(resource) => self.add_resource(resource),
Messages::ResourceRemoved(resource) => self.remove_resource(resource.id),
Messages::ResourceUpdated(resource) => self.update_resource(resource),
Messages::IceCandidates(ice_candidate) => self.add_ice_candidate(ice_candidate).await,

View File

@@ -1,5 +1,5 @@
use crate::bounded_queue::BoundedQueue;
use crate::device_channel::{create_iface, Packet};
use crate::device_channel::{Device, Packet};
use crate::ip_packet::{IpPacket, MutableIpPacket};
use crate::peer::PacketTransformClient;
use crate::{
@@ -60,7 +60,7 @@ where
/// Once added, when a packet for the resource is intercepted a new data channel will be created
/// and packets will be wrapped with wireguard and sent through it.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(
pub fn add_resource(
&self,
resource_description: ResourceDescription,
) -> connlib_shared::Result<()> {
@@ -72,7 +72,7 @@ where
.insert(dns.address.clone(), dns.clone());
}
ResourceDescription::Cidr(cidr) => {
self.add_route(cidr.address).await?;
self.add_route(cidr.address)?;
self.role_state
.lock()
@@ -110,13 +110,13 @@ where
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
pub fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
if !config.upstream_dns.is_empty() {
self.role_state.lock().dns_strategy = DnsFallbackStrategy::UpstreamResolver;
}
let dns_strategy = self.role_state.lock().dns_strategy;
let device = Arc::new(create_iface(config, self.callbacks(), dns_strategy).await?);
let device = Arc::new(Device::new(config, self.callbacks(), dns_strategy)?);
self.device.store(Some(device.clone()));
self.no_device_waker.wake();
@@ -124,12 +124,12 @@ where
// TODO: the requirement for the DNS_SENTINEL means you NEED ipv4 stack
// we are trying to support ipv4 and ipv6, so we should have an ipv6 dns sentinel
// alternative.
self.add_route(DNS_SENTINEL.into()).await?;
self.add_route(DNS_SENTINEL.into())?;
// Note: I'm just assuming this needs to succeed since we already require ipv4 stack due to the dns sentinel
// TODO: change me when we don't require ipv4
self.add_route(IPV4_RESOURCES.parse().unwrap()).await?;
self.add_route(IPV4_RESOURCES.parse().unwrap())?;
if let Err(e) = self.add_route(IPV6_RESOURCES.parse().unwrap()).await {
if let Err(e) = self.add_route(IPV6_RESOURCES.parse().unwrap()) {
tracing::warn!(err = ?e, "ipv6 not supported");
}
@@ -148,14 +148,13 @@ where
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> {
pub fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> {
let maybe_new_device = self
.device
.load()
.as_ref()
.ok_or(Error::ControlProtocolError)?
.add_route(route, self.callbacks())
.await?;
.add_route(route, self.callbacks())?;
if let Some(new_device) = maybe_new_device {
self.device.swap(Some(Arc::new(new_device)));

View File

@@ -216,16 +216,11 @@ where
tracing::trace!(?peer_config.ips, "new_data_channel_open");
let device = self.device.load().clone().ok_or(Error::NoIface)?;
let callbacks = self.callbacks.clone();
let ips = peer_config.ips.clone();
// Worst thing if this is not run before peers_by_ip is that some packets are lost to the default route
tokio::spawn(async move {
for ip in ips {
if let Ok(res) = device.add_route(ip, &callbacks).await {
assert!(res.is_none(), "gateway does not run on android and thus never produces a new device upon `add_route`");
}
for ip in &peer_config.ips {
if let Ok(res) = device.add_route(*ip, &callbacks) {
assert!(res.is_none(), "gateway does not run on android and thus never produces a new device upon `add_route`");
}
});
}
let peer = Arc::new(Peer::new(
self.private_key.clone(),

View File

@@ -1,39 +1,78 @@
#![allow(clippy::module_inception)]
#![cfg_attr(target_family = "windows", allow(dead_code))] // TODO: Remove when windows is fully implemented.
#[cfg(target_family = "unix")]
#[path = "device_channel/device_channel_unix.rs"]
mod device_channel;
#[cfg(any(target_os = "macos", target_os = "ios"))]
#[path = "device_channel/tun_darwin.rs"]
mod tun;
#[cfg(target_os = "linux")]
#[path = "device_channel/tun_linux.rs"]
mod tun;
#[cfg(target_family = "windows")]
#[path = "device_channel/device_channel_win.rs"]
mod device_channel;
#[path = "device_channel/tun_windows.rs"]
mod tun;
// TODO: Android and linux are nearly identical; use a common tunnel module?
#[cfg(target_os = "android")]
#[path = "device_channel/tun_android.rs"]
mod tun;
use crate::ip_packet::MutableIpPacket;
use crate::DnsFallbackStrategy;
use connlib_shared::error::ConnlibError;
use connlib_shared::messages::Interface;
use connlib_shared::{Callbacks, Error};
pub(crate) use device_channel::*;
use ip_network::IpNetwork;
use std::borrow::Cow;
use std::io;
use std::task::{ready, Context, Poll};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tun::Tun;
pub struct Device {
config: IfaceConfig,
io: DeviceIo,
mtu: AtomicUsize,
tun: Tun,
}
impl Device {
#[cfg(target_family = "unix")]
pub(crate) fn new(
config: &Interface,
callbacks: &impl Callbacks<Error = Error>,
dns: DnsFallbackStrategy,
) -> Result<Device, ConnlibError> {
let tun = Tun::new(config, callbacks, dns)?;
let mtu = AtomicUsize::new(ioctl::interface_mtu_by_name(tun.name())?);
Ok(Device { mtu, tun })
}
#[cfg(target_family = "windows")]
pub(crate) fn new(
_: &Interface,
_: &impl Callbacks<Error = Error>,
_: DnsFallbackStrategy,
) -> Result<Device, ConnlibError> {
Ok(Device {
tun: Tun::new(),
mtu: AtomicUsize::default(), // Dummy value for now.
})
}
#[cfg(target_family = "unix")]
pub(crate) fn poll_read<'b>(
&self,
buf: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<MutableIpPacket<'b>>>> {
let res = ready!(self.io.poll_read(&mut buf[..self.config.mtu()], cx))?;
let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?;
if res == 0 {
if n == 0 {
return Poll::Ready(Ok(None));
}
Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..res]).ok_or_else(
Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..n]).ok_or_else(
|| {
io::Error::new(
io::ErrorKind::InvalidInput,
@@ -43,20 +82,60 @@ impl Device {
)?)))
}
pub(crate) async fn add_route(
#[cfg(target_family = "windows")]
pub(crate) fn poll_read<'b>(
&self,
_: &'b mut [u8],
_: &mut Context<'_>,
) -> Poll<io::Result<Option<MutableIpPacket<'b>>>> {
Poll::Pending
}
pub(crate) fn mtu(&self) -> usize {
self.mtu.load(Ordering::Relaxed)
}
#[cfg(target_family = "unix")]
pub(crate) fn add_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>, Error> {
self.config.add_route(route, callbacks).await
let Some(tun) = self.tun.add_route(route, callbacks)? else {
return Ok(None);
};
let mtu = AtomicUsize::new(ioctl::interface_mtu_by_name(tun.name())?);
Ok(Some(Device { mtu, tun }))
}
pub(crate) async fn refresh_mtu(&self) -> Result<usize, Error> {
self.config.refresh_mtu().await
#[cfg(target_family = "windows")]
pub(crate) fn add_route(
&self,
_: IpNetwork,
_: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>, Error> {
Ok(None)
}
#[cfg(target_family = "unix")]
pub(crate) fn refresh_mtu(&self) -> Result<usize, Error> {
let mtu = ioctl::interface_mtu_by_name(self.tun.name())?;
self.mtu.store(mtu, Ordering::Relaxed);
Ok(mtu)
}
#[cfg(target_family = "windows")]
pub(crate) fn refresh_mtu(&self) -> Result<usize, Error> {
Ok(0)
}
pub fn write(&self, packet: Packet<'_>) -> io::Result<usize> {
self.io.write(packet)
match packet {
Packet::Ipv4(msg) => self.tun.write4(&msg),
Packet::Ipv6(msg) => self.tun.write6(&msg),
}
}
}
@@ -64,3 +143,98 @@ pub enum Packet<'a> {
Ipv4(Cow<'a, [u8]>),
Ipv6(Cow<'a, [u8]>),
}
#[cfg(target_family = "unix")]
mod ioctl {
use super::*;
use std::os::fd::RawFd;
use tun::SIOCGIFMTU;
pub(crate) fn interface_mtu_by_name(name: &str) -> Result<usize, ConnlibError> {
let socket = Socket::ip4()?;
let request = Request::<GetInterfaceMtuPayload>::new(name)?;
// Safety: The file descriptor is open.
unsafe {
exec(socket.fd, SIOCGIFMTU, &request)?;
}
Ok(request.payload.mtu as usize)
}
/// Executes the `ioctl` syscall on the given file descriptor with the provided request.
///
/// # Safety
///
/// The file descriptor must be open.
pub(crate) unsafe fn exec<P>(
fd: RawFd,
code: libc::c_ulong,
req: &Request<P>,
) -> Result<(), ConnlibError> {
let ret = unsafe { libc::ioctl(fd, code as _, req) };
if ret < 0 {
return Err(io::Error::last_os_error().into());
}
Ok(())
}
/// Represents a control request to an IO device, addresses by the device's name.
///
/// The payload MUST also be `#[repr(C)]` and its layout depends on the particular request you are sending.
#[repr(C)]
pub(crate) struct Request<P> {
pub(crate) name: [std::ffi::c_uchar; libc::IF_NAMESIZE],
pub(crate) payload: P,
}
/// A socket newtype which closes the file descriptor on drop.
struct Socket {
fd: RawFd,
}
impl Socket {
fn ip4() -> io::Result<Socket> {
// Safety: All provided parameters are constants.
let fd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, libc::IPPROTO_IP) };
if fd == -1 {
return Err(io::Error::last_os_error());
}
Ok(Self { fd })
}
}
impl Drop for Socket {
fn drop(&mut self) {
// Safety: This is the only call to `close` and it happens when `Guard` is being dropped.
unsafe { libc::close(self.fd) };
}
}
impl Request<GetInterfaceMtuPayload> {
fn new(name: &str) -> io::Result<Self> {
if name.len() > libc::IF_NAMESIZE {
return Err(io::ErrorKind::InvalidInput.into());
}
let mut request = Request {
name: [0u8; libc::IF_NAMESIZE],
payload: Default::default(),
};
request.name[..name.len()].copy_from_slice(name.as_bytes());
Ok(request)
}
}
#[derive(Default)]
#[repr(C)]
struct GetInterfaceMtuPayload {
mtu: libc::c_int,
}
}

View File

@@ -1,99 +0,0 @@
use std::io;
use std::sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
};
use std::task::{ready, Context, Poll};
use connlib_shared::{messages::Interface, Callbacks, Error, Result};
use ip_network::IpNetwork;
use tokio::io::{unix::AsyncFd, Ready};
use tun::{IfaceDevice, IfaceStream};
use crate::device_channel::{Device, Packet};
use crate::DnsFallbackStrategy;
mod tun;
pub(crate) struct IfaceConfig {
mtu: AtomicUsize,
iface: IfaceDevice,
}
pub(crate) struct DeviceIo(Arc<AsyncFd<IfaceStream>>);
impl DeviceIo {
pub fn poll_read(&self, out: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.0.poll_read_ready(cx))?;
match guard.get_inner().read(out) {
Ok(n) => return Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// a read has blocked, but a write might still succeed.
// clear only the read readiness.
guard.clear_ready_matching(Ready::READABLE);
continue;
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
// Note: write is synchronous because it's non-blocking
// and some losiness is acceptable and increseases performance
// since we don't block the reading loops.
pub fn write(&self, packet: Packet<'_>) -> io::Result<usize> {
match packet {
Packet::Ipv4(msg) => self.0.get_ref().write4(&msg),
Packet::Ipv6(msg) => self.0.get_ref().write6(&msg),
}
}
}
impl IfaceConfig {
pub(crate) fn mtu(&self) -> usize {
self.mtu.load(Relaxed)
}
pub(crate) async fn refresh_mtu(&self) -> Result<usize> {
let mtu = self.iface.mtu().await?;
self.mtu.store(mtu, Relaxed);
Ok(mtu)
}
pub(crate) async fn add_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>> {
let Some((iface, stream)) = self.iface.add_route(route, callbacks).await? else {
return Ok(None);
};
let io = DeviceIo(stream);
let mtu = iface.mtu().await?;
let config = IfaceConfig {
iface,
mtu: AtomicUsize::new(mtu),
};
Ok(Some(Device { io, config }))
}
}
pub(crate) async fn create_iface(
config: &Interface,
callbacks: &impl Callbacks<Error = Error>,
fallback_strategy: DnsFallbackStrategy,
) -> Result<Device> {
let (iface, stream) = IfaceDevice::new(config, callbacks, fallback_strategy).await?;
iface.up().await?;
let io = DeviceIo(stream);
let mtu = iface.mtu().await?;
let config = IfaceConfig {
iface,
mtu: AtomicUsize::new(mtu),
};
Ok(Device { io, config })
}

View File

@@ -1,60 +0,0 @@
use crate::device_channel::Packet;
use crate::Device;
use crate::DnsFallbackStrategy;
use connlib_shared::{messages::Interface, Callbacks, Result};
use ip_network::IpNetwork;
use std::task::{Context, Poll};
// TODO: Fill all this out. These are just stubs to test the GUI.
pub(crate) struct DeviceIo;
pub(crate) struct IfaceConfig;
impl DeviceIo {
pub fn poll_read(&self, _: &mut [u8], _: &mut Context<'_>) -> Poll<std::io::Result<usize>> {
// Incoming packets will never appear
Poll::Pending
}
pub fn write(&self, packet: Packet<'_>) -> std::io::Result<usize> {
// All outgoing packets are successfully written to the void
match packet {
Packet::Ipv4(msg) => Ok(msg.len()),
Packet::Ipv6(msg) => Ok(msg.len()),
}
}
}
const BOGUS_MTU: usize = 1_500;
impl IfaceConfig {
pub(crate) fn mtu(&self) -> usize {
BOGUS_MTU
}
pub(crate) async fn refresh_mtu(&self) -> Result<usize> {
Ok(BOGUS_MTU)
}
pub(crate) async fn add_route(
&self,
_: IpNetwork,
_: &impl Callbacks,
) -> Result<Option<Device>> {
let io = DeviceIo {};
let config = IfaceConfig {};
Ok(Some(Device { io, config }))
}
}
pub(crate) async fn create_iface(
_: &Interface,
_: &impl Callbacks,
_: DnsFallbackStrategy,
) -> Result<Device> {
Ok(Device {
config: IfaceConfig {},
io: DeviceIo {},
})
}

View File

@@ -1,16 +0,0 @@
#![allow(clippy::module_inception)]
#[cfg(any(target_os = "macos", target_os = "ios"))]
#[path = "tun/tun_darwin.rs"]
mod tun;
#[cfg(target_os = "linux")]
#[path = "tun/tun_linux.rs"]
mod tun;
// TODO: Android and linux are nearly identical; use a common tunnel module?
#[cfg(target_os = "android")]
#[path = "tun/tun_android.rs"]
mod tun;
pub(crate) use tun::*;

View File

@@ -1,35 +0,0 @@
use std::os::fd::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicBool, Ordering};
#[derive(Debug)]
pub(crate) struct Closeable {
closed: AtomicBool,
value: RawFd,
}
impl AsRawFd for Closeable {
fn as_raw_fd(&self) -> RawFd {
self.value.as_raw_fd()
}
}
impl Closeable {
pub(crate) fn new(fd: RawFd) -> Self {
Self {
closed: AtomicBool::new(false),
value: fd,
}
}
pub(crate) fn with<U>(&self, f: impl FnOnce(RawFd) -> U) -> std::io::Result<U> {
if self.closed.load(Ordering::Acquire) {
return Err(std::io::Error::from_raw_os_error(9));
}
Ok(f(self.value))
}
pub(crate) fn close(&self) {
self.closed.store(true, Ordering::Release);
}
}

View File

@@ -1,196 +0,0 @@
use closeable::Closeable;
use connlib_shared::{
messages::Interface as InterfaceConfig, Callbacks, Error, Result, DNS_SENTINEL,
};
use ip_network::IpNetwork;
use libc::{
close, ioctl, read, sockaddr, sockaddr_in, write, AF_INET, IFNAMSIZ, IPPROTO_IP, SIOCGIFMTU,
SOCK_STREAM,
};
use std::{
ffi::{c_int, c_short, c_uchar},
io,
os::fd::{AsRawFd, RawFd},
sync::Arc,
};
use tokio::io::unix::AsyncFd;
use crate::DnsFallbackStrategy;
mod closeable;
mod wrapped_socket;
#[repr(C)]
union IfrIfru {
ifru_addr: sockaddr,
ifru_addr_v4: sockaddr_in,
ifru_addr_v6: sockaddr_in,
ifru_dstaddr: sockaddr,
ifru_broadaddr: sockaddr,
ifru_flags: c_short,
ifru_metric: c_int,
ifru_mtu: c_int,
ifru_phys: c_int,
ifru_media: c_int,
ifru_intval: c_int,
ifru_wake_flags: u32,
ifru_route_refcnt: u32,
ifru_cap: [c_int; 2],
ifru_functional_type: u32,
}
#[repr(C)]
pub struct ifreq {
ifr_name: [c_uchar; IFNAMSIZ],
ifr_ifru: IfrIfru,
}
const TUNGETIFF: u64 = 0x800454d2;
#[derive(Debug)]
pub(crate) struct IfaceDevice(Arc<AsyncFd<IfaceStream>>);
#[derive(Debug)]
pub(crate) struct IfaceStream {
fd: Closeable,
}
impl AsRawFd for IfaceStream {
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}
impl Drop for IfaceStream {
fn drop(&mut self) {
unsafe { close(self.fd.as_raw_fd()) };
}
}
impl IfaceStream {
fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
match self
.fd
.with(|fd| unsafe { write(fd, buf.as_ptr() as _, buf.len() as _) })?
{
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
pub fn write4(&self, src: &[u8]) -> std::io::Result<usize> {
self.write(src)
}
pub fn write6(&self, src: &[u8]) -> std::io::Result<usize> {
self.write(src)
}
pub fn read(&self, dst: &mut [u8]) -> std::io::Result<usize> {
// We don't read(or write) again from the fd because the given fd number might be reclaimed
// so this could make an spurious read/write to another fd and we DEFINITELY don't want that
match self
.fd
.with(|fd| unsafe { read(fd, dst.as_mut_ptr() as _, dst.len()) })?
{
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
pub fn close(&self) {
self.fd.close();
}
}
impl IfaceDevice {
pub async fn new(
config: &InterfaceConfig,
callbacks: &impl Callbacks<Error = Error>,
fallback_strategy: DnsFallbackStrategy,
) -> Result<(Self, Arc<AsyncFd<IfaceStream>>)> {
let fd = callbacks
.on_set_interface_config(
config.ipv4,
config.ipv6,
DNS_SENTINEL,
fallback_strategy.to_string(),
)?
.ok_or(Error::NoFd)?;
let iface_stream = Arc::new(AsyncFd::new(IfaceStream {
fd: Closeable::new(fd.into()),
})?);
let this = Self(Arc::clone(&iface_stream));
Ok((this, iface_stream))
}
fn name(&self) -> Result<String> {
let mut ifr = ifreq {
ifr_name: [0; IFNAMSIZ],
ifr_ifru: unsafe { std::mem::zeroed() },
};
match self
.0
.get_ref()
.fd
.with(|fd| unsafe { ioctl(fd, TUNGETIFF as _, &mut ifr) })?
{
0 => {
let name_cstr = unsafe { std::ffi::CStr::from_ptr(ifr.ifr_name.as_ptr() as _) };
Ok(name_cstr.to_string_lossy().into_owned())
}
_ => Err(get_last_error()),
}
}
/// Get the current MTU value
pub async fn mtu(&self) -> Result<usize> {
let socket = wrapped_socket::WrappedSocket::new(AF_INET, SOCK_STREAM, IPPROTO_IP);
let fd = match socket.as_raw_fd() {
-1 => return Err(get_last_error()),
fd => fd,
};
let name = self.name()?;
let iface_name: &[u8] = name.as_ref();
let mut ifr = ifreq {
ifr_name: [0; IFNAMSIZ],
ifr_ifru: IfrIfru { ifru_mtu: 0 },
};
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
if unsafe { ioctl(fd, SIOCGIFMTU as _, &ifr) } < 0 {
return Err(get_last_error());
}
let mtu = unsafe { ifr.ifr_ifru.ifru_mtu };
Ok(mtu as _)
}
pub async fn add_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
self.0.get_ref().close();
let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?;
let iface_stream = Arc::new(AsyncFd::new(IfaceStream {
fd: Closeable::new(fd.into()),
})?);
let this = Self(Arc::clone(&iface_stream));
Ok(Some((this, iface_stream)))
}
pub async fn up(&self) -> Result<()> {
Ok(())
}
}
fn get_last_error() -> Error {
Error::Io(io::Error::last_os_error())
}

View File

@@ -1,273 +0,0 @@
use connlib_shared::{messages::Interface as InterfaceConfig, Callbacks, Error, Result};
use futures::TryStreamExt;
use ip_network::IpNetwork;
use libc::{
close, fcntl, ioctl, open, read, sockaddr, sockaddr_in, write, F_GETFL, F_SETFL,
IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, IFNAMSIZ, O_NONBLOCK, O_RDWR,
};
use netlink_packet_route::{rtnl::link::nlas::Nla, RT_SCOPE_UNIVERSE};
use rtnetlink::{new_connection, Error::NetlinkError, Handle};
use std::{
ffi::{c_int, c_short, c_uchar},
io,
os::fd::{AsRawFd, RawFd},
sync::Arc,
};
use tokio::io::unix::AsyncFd;
use crate::DnsFallbackStrategy;
const IFACE_NAME: &str = "tun-firezone";
const TUNSETIFF: u64 = 0x4004_54ca;
const TUN_FILE: &[u8] = b"/dev/net/tun\0";
const RT_PROT_STATIC: u8 = 4;
const DEFAULT_MTU: u32 = 1280;
const FILE_ALREADY_EXISTS: i32 = -17;
#[repr(C)]
union IfrIfru {
ifru_addr: sockaddr,
ifru_addr_v4: sockaddr_in,
ifru_addr_v6: sockaddr_in,
ifru_dstaddr: sockaddr,
ifru_broadaddr: sockaddr,
ifru_flags: c_short,
ifru_metric: c_int,
ifru_mtu: c_int,
ifru_phys: c_int,
ifru_media: c_int,
ifru_intval: c_int,
ifru_wake_flags: u32,
ifru_route_refcnt: u32,
ifru_cap: [c_int; 2],
ifru_functional_type: u32,
}
#[repr(C)]
pub struct ifreq {
ifr_name: [c_uchar; IFNAMSIZ],
ifr_ifru: IfrIfru,
}
#[derive(Debug)]
pub struct IfaceDevice {
handle: Handle,
connection: tokio::task::JoinHandle<()>,
interface_index: u32,
}
#[derive(Debug)]
pub struct IfaceStream(RawFd);
impl AsRawFd for IfaceStream {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
impl Drop for IfaceStream {
fn drop(&mut self) {
unsafe { close(self.0) };
}
}
impl Drop for IfaceDevice {
fn drop(&mut self) {
self.connection.abort();
}
}
impl IfaceStream {
fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
match unsafe { write(self.0, buf.as_ptr() as _, buf.len() as _) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
pub fn write4(&self, buf: &[u8]) -> std::io::Result<usize> {
self.write(buf)
}
pub fn write6(&self, buf: &[u8]) -> std::io::Result<usize> {
self.write(buf)
}
pub fn read(&self, dst: &mut [u8]) -> std::io::Result<usize> {
match unsafe { read(self.0, dst.as_mut_ptr() as _, dst.len()) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
}
impl IfaceDevice {
pub async fn new(
config: &InterfaceConfig,
cb: &impl Callbacks,
_: DnsFallbackStrategy,
) -> Result<(Self, Arc<AsyncFd<IfaceStream>>)> {
debug_assert!(IFACE_NAME.as_bytes().len() < IFNAMSIZ);
let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
-1 => return Err(get_last_error()),
fd => fd,
};
let mut ifr = ifreq {
ifr_name: [0; IFNAMSIZ],
ifr_ifru: IfrIfru {
ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _,
},
};
ifr.ifr_name[..IFACE_NAME.as_bytes().len()].copy_from_slice(IFACE_NAME.as_bytes());
if unsafe { ioctl(fd, TUNSETIFF as _, &ifr) } < 0 {
return Err(get_last_error());
}
let (connection, handle, _) = new_connection()?;
let join_handle = tokio::spawn(connection);
let interface_index = handle
.link()
.get()
.match_name(IFACE_NAME.to_string())
.execute()
.try_next()
.await?
.ok_or(Error::NoIface)?
.header
.index;
set_non_blocking(fd)?;
let this = Self {
handle,
connection: join_handle,
interface_index,
};
this.set_iface_config(config, cb).await?;
Ok((this, Arc::new(AsyncFd::new(IfaceStream(fd))?)))
}
/// Get the current MTU value
pub async fn mtu(&self) -> Result<usize> {
while let Ok(Some(msg)) = self
.handle
.link()
.get()
.match_index(self.interface_index)
.execute()
.try_next()
.await
{
for nla in msg.nlas {
if let Nla::Mtu(mtu) = nla {
return Ok(mtu as usize);
}
}
}
Err(Error::NoMtu)
}
pub async fn add_route(
&self,
route: IpNetwork,
_: &impl Callbacks,
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
let req = self
.handle
.route()
.add()
.output_interface(self.interface_index)
.protocol(RT_PROT_STATIC)
.scope(RT_SCOPE_UNIVERSE);
let res = match route {
IpNetwork::V4(ipnet) => {
req.v4()
.destination_prefix(ipnet.network_address(), ipnet.netmask())
.execute()
.await
}
IpNetwork::V6(ipnet) => {
req.v6()
.destination_prefix(ipnet.network_address(), ipnet.netmask())
.execute()
.await
}
};
match res {
Ok(_) => Ok(None),
Err(NetlinkError(err)) if err.raw_code() == FILE_ALREADY_EXISTS => Ok(None),
Err(err) => Err(err.into()),
}
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_iface_config(
&self,
config: &InterfaceConfig,
_: &impl Callbacks,
) -> Result<()> {
let ips = self
.handle
.address()
.get()
.set_link_index_filter(self.interface_index)
.execute();
ips.try_for_each(|ip| self.handle.address().del(ip).execute())
.await?;
self.handle
.link()
.set(self.interface_index)
.mtu(DEFAULT_MTU)
.execute()
.await?;
let res_v4 = self
.handle
.address()
.add(self.interface_index, config.ipv4.into(), 32)
.execute()
.await;
let res_v6 = self
.handle
.address()
.add(self.interface_index, config.ipv6.into(), 128)
.execute()
.await;
Ok(res_v4.or(res_v6)?)
}
pub async fn up(&self) -> Result<()> {
self.handle
.link()
.set(self.interface_index)
.up()
.execute()
.await?;
Ok(())
}
}
fn get_last_error() -> Error {
Error::Io(io::Error::last_os_error())
}
fn set_non_blocking(fd: RawFd) -> Result<()> {
match unsafe { fcntl(fd, F_GETFL) } {
-1 => Err(get_last_error()),
flags => match unsafe { fcntl(fd, F_SETFL, flags | O_NONBLOCK) } {
-1 => Err(get_last_error()),
_ => Ok(()),
},
}
}

View File

@@ -1,28 +0,0 @@
use libc::{close, socket};
use std::ffi::c_int;
use std::os::fd::RawFd;
pub struct WrappedSocket {
fd: RawFd,
}
impl WrappedSocket {
pub fn new(domain: c_int, sock_type: c_int, protocol: c_int) -> Self {
let fd = unsafe { socket(domain, sock_type, protocol) };
Self { fd }
}
pub fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for WrappedSocket {
fn drop(&mut self) {
if self.fd == -1 {
return;
}
unsafe { close(self.fd) };
}
}

View File

@@ -0,0 +1,168 @@
use crate::device_channel::ioctl;
use crate::DnsFallbackStrategy;
use connlib_shared::{
messages::Interface as InterfaceConfig, Callbacks, Error, Result, DNS_SENTINEL,
};
use ip_network::IpNetwork;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use std::{
io,
os::fd::{AsRawFd, RawFd},
};
use tokio::io::unix::AsyncFd;
mod utils;
pub(crate) const SIOCGIFMTU: libc::c_ulong = libc::SIOCGIFMTU;
#[derive(Debug)]
pub(crate) struct Tun {
fd: Closeable,
name: String,
}
impl Drop for Tun {
fn drop(&mut self) {
unsafe { libc::close(self.fd.fd.as_raw_fd()) };
}
}
impl Tun {
pub fn write4(&self, src: &[u8]) -> std::io::Result<usize> {
self.fd.with(|fd| write(*fd.get_ref(), src))?
}
pub fn write6(&self, src: &[u8]) -> std::io::Result<usize> {
self.fd.with(|fd| write(*fd.get_ref(), src))?
}
pub fn poll_read(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
self.fd
.with(|fd| utils::poll_raw_fd(&fd, |fd| read(fd, buf), cx))?
}
pub fn close(&self) {
self.fd.close();
}
pub fn new(
config: &InterfaceConfig,
callbacks: &impl Callbacks<Error = Error>,
fallback_strategy: DnsFallbackStrategy,
) -> Result<Self> {
let fd = callbacks
.on_set_interface_config(
config.ipv4,
config.ipv6,
DNS_SENTINEL,
fallback_strategy.to_string(),
)?
.ok_or(Error::NoFd)?;
// Safety: File descriptor is open.
let name = unsafe { interface_name(fd)? };
Ok(Tun {
fd: Closeable::new(AsyncFd::new(fd)?),
name,
})
}
pub fn name(&self) -> &str {
self.name.as_str()
}
pub fn add_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Self>> {
self.fd.close();
let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?;
let name = unsafe { interface_name(fd)? };
Ok(Some(Tun {
fd: Closeable::new(AsyncFd::new(fd)?),
name,
}))
}
}
/// Retrieves the name of the interface pointed to by the provided file descriptor.
///
/// # Safety
///
/// The file descriptor must be open.
unsafe fn interface_name(fd: RawFd) -> Result<String> {
const TUNGETIFF: libc::c_ulong = 0x800454d2;
let request = ioctl::Request::<GetInterfaceNamePayload>::new();
ioctl::exec(fd, TUNGETIFF, &request)?;
Ok(request.name().to_string())
}
impl ioctl::Request<GetInterfaceNamePayload> {
fn new() -> Self {
Self {
name: [0u8; libc::IF_NAMESIZE],
payload: Default::default(),
}
}
fn name(&self) -> std::borrow::Cow<'_, str> {
// Safety: The memory of `self.name` is always initialized.
let cstr = unsafe { std::ffi::CStr::from_ptr(self.name.as_ptr() as _) };
cstr.to_string_lossy()
}
}
#[derive(Default)]
#[repr(C)]
struct GetInterfaceNamePayload;
/// Read from the given file descriptor in the buffer.
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
/// Write the buffer to the given file descriptor.
fn write(fd: RawFd, buf: &[u8]) -> io::Result<usize> {
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::write(fd.as_raw_fd(), buf.as_ptr() as _, buf.len() as _) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
#[derive(Debug)]
struct Closeable {
closed: AtomicBool,
fd: AsyncFd<RawFd>,
}
impl Closeable {
fn new(fd: AsyncFd<RawFd>) -> Self {
Self {
closed: AtomicBool::new(false),
fd: fd,
}
}
fn with<U>(&self, f: impl FnOnce(&AsyncFd<RawFd>) -> U) -> std::io::Result<U> {
if self.closed.load(Ordering::Acquire) {
return Err(std::io::Error::from_raw_os_error(9));
}
Ok(f(&self.fd))
}
fn close(&self) {
self.closed.store(true, Ordering::Release);
}
}

View File

@@ -1,72 +1,34 @@
use crate::DnsFallbackStrategy;
use connlib_shared::{
messages::Interface as InterfaceConfig, Callbacks, Error, Result, DNS_SENTINEL,
};
use ip_network::IpNetwork;
use libc::{
ctl_info, fcntl, getpeername, getsockopt, ioctl, iovec, msghdr, recvmsg, sendmsg, sockaddr,
sockaddr_ctl, sockaddr_in, socklen_t, AF_INET, AF_INET6, AF_SYSTEM, CTLIOCGINFO, F_GETFL,
F_SETFL, IF_NAMESIZE, IPPROTO_IP, O_NONBLOCK, SOCK_STREAM, SYSPROTO_CONTROL, UTUN_OPT_IFNAME,
ctl_info, fcntl, getpeername, getsockopt, ioctl, iovec, msghdr, recvmsg, sendmsg, sockaddr_ctl,
socklen_t, AF_INET, AF_INET6, AF_SYSTEM, CTLIOCGINFO, F_GETFL, F_SETFL, IF_NAMESIZE,
O_NONBLOCK, SYSPROTO_CONTROL, UTUN_OPT_IFNAME,
};
use std::task::{Context, Poll};
use std::{
ffi::{c_int, c_short, c_uchar},
io,
mem::size_of,
os::fd::{AsRawFd, RawFd},
sync::Arc,
};
use tokio::io::unix::AsyncFd;
use crate::DnsFallbackStrategy;
mod utils;
const CTL_NAME: &[u8] = b"com.apple.net.utun_control";
const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933;
/// `libc` for darwin doesn't define this constant so we declare it here.
pub(crate) const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933;
#[derive(Debug)]
pub(crate) struct IfaceDevice {
pub(crate) struct Tun {
name: String,
fd: AsyncFd<RawFd>,
}
#[derive(Debug)]
pub(crate) struct IfaceStream {
fd: RawFd,
}
mod wrapped_socket;
impl AsRawFd for IfaceStream {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
// For some reason this is not available in libc for darwin :c
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct ifreq {
ifr_name: [c_uchar; IF_NAMESIZE],
ifr_ifru: IfrIfru,
}
#[repr(C)]
union IfrIfru {
ifru_addr: sockaddr,
ifru_addr_v4: sockaddr_in,
ifru_addr_v6: sockaddr_in,
ifru_dstaddr: sockaddr,
ifru_broadaddr: sockaddr,
ifru_flags: c_short,
ifru_metric: c_int,
ifru_mtu: c_int,
ifru_phys: c_int,
ifru_media: c_int,
ifru_intval: c_int,
ifru_wake_flags: u32,
ifru_route_refcnt: u32,
ifru_cap: [c_int; 2],
ifru_functional_type: u32,
}
impl IfaceStream {
impl Tun {
pub fn write4(&self, src: &[u8]) -> std::io::Result<usize> {
self.write(src, AF_INET as u8)
}
@@ -75,35 +37,8 @@ impl IfaceStream {
self.write(src, AF_INET6 as u8)
}
pub fn read(&self, dst: &mut [u8]) -> std::io::Result<usize> {
let mut hdr = [0u8; 4];
let mut iov = [
iovec {
iov_base: hdr.as_mut_ptr() as _,
iov_len: hdr.len(),
},
iovec {
iov_base: dst.as_mut_ptr() as _,
iov_len: dst.len(),
},
];
let mut msg_hdr = msghdr {
msg_name: std::ptr::null_mut(),
msg_namelen: 0,
msg_iov: &mut iov[0],
msg_iovlen: iov.len() as _,
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
match unsafe { recvmsg(self.fd, &mut msg_hdr, 0) } {
-1 => Err(io::Error::last_os_error()),
0..=4 => Ok(0),
n => Ok((n - 4) as usize),
}
pub fn poll_read(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
utils::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
}
fn write(&self, src: &[u8], af: u8) -> std::io::Result<usize> {
@@ -129,19 +64,17 @@ impl IfaceStream {
msg_flags: 0,
};
match unsafe { sendmsg(self.fd, &msg_hdr, 0) } {
match unsafe { sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
}
impl IfaceDevice {
pub async fn new(
pub fn new(
config: &InterfaceConfig,
callbacks: &impl Callbacks<Error = Error>,
fallback_strategy: DnsFallbackStrategy,
) -> Result<(Self, Arc<AsyncFd<IfaceStream>>)> {
) -> Result<Self> {
let mut info = ctl_info {
ctl_id: 0,
ctl_name: [0; 96],
@@ -207,51 +140,28 @@ impl IfaceDevice {
set_non_blocking(fd)?;
return Ok((
Self { name: name(fd)? },
Arc::new(AsyncFd::new(IfaceStream { fd })?),
));
return Ok(Self {
name: name(fd)?,
fd: AsyncFd::new(fd)?,
});
}
}
Err(get_last_error())
}
/// Get the current MTU value
pub async fn mtu(&self) -> Result<usize> {
let socket = wrapped_socket::WrappedSocket::new(AF_INET, SOCK_STREAM, IPPROTO_IP);
let fd = match socket.as_raw_fd() {
-1 => return Err(get_last_error()),
fd => fd,
};
let iface_name: &[u8] = self.name.as_ref();
let mut ifr = ifreq {
ifr_name: [0; IF_NAMESIZE],
ifr_ifru: IfrIfru { ifru_mtu: 0 },
};
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
if unsafe { ioctl(fd, SIOCGIFMTU, &ifr) } < 0 {
return Err(get_last_error());
}
Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _)
}
pub async fn add_route(
pub fn add_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
) -> Result<Option<Self>> {
// This will always be None in macos
callbacks.on_add_route(route)?;
Ok(None)
}
pub async fn up(&self) -> Result<()> {
Ok(())
pub fn name(&self) -> &str {
self.name.as_str()
}
}
@@ -269,6 +179,38 @@ fn set_non_blocking(fd: RawFd) -> Result<()> {
}
}
fn read(fd: RawFd, dst: &mut [u8]) -> std::io::Result<usize> {
let mut hdr = [0u8; 4];
let mut iov = [
iovec {
iov_base: hdr.as_mut_ptr() as _,
iov_len: hdr.len(),
},
iovec {
iov_base: dst.as_mut_ptr() as _,
iov_len: dst.len(),
},
];
let mut msg_hdr = msghdr {
msg_name: std::ptr::null_mut(),
msg_namelen: 0,
msg_iov: &mut iov[0],
msg_iovlen: iov.len() as _,
msg_control: std::ptr::null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
// Safety: Within this module, the file descriptor is always valid.
match unsafe { recvmsg(fd, &mut msg_hdr, 0) } {
-1 => Err(io::Error::last_os_error()),
0..=4 => Ok(0),
n => Ok((n - 4) as usize),
}
}
fn name(fd: RawFd) -> Result<String> {
let mut tunnel_name = [0u8; IF_NAMESIZE];
let mut tunnel_name_len = tunnel_name.len() as socklen_t;

View File

@@ -0,0 +1,271 @@
use crate::device_channel::ioctl;
use crate::DnsFallbackStrategy;
use connlib_shared::{messages::Interface as InterfaceConfig, Callbacks, Error, Result};
use futures::TryStreamExt;
use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use ip_network::IpNetwork;
use libc::{
close, fcntl, open, F_GETFL, F_SETFL, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, O_NONBLOCK, O_RDWR,
};
use netlink_packet_route::RT_SCOPE_UNIVERSE;
use parking_lot::Mutex;
use rtnetlink::{new_connection, Error::NetlinkError, Handle};
use std::task::{Context, Poll};
use std::{
fmt, io,
os::fd::{AsRawFd, RawFd},
};
use tokio::io::unix::AsyncFd;
mod utils;
pub(crate) const SIOCGIFMTU: libc::c_ulong = libc::SIOCGIFMTU;
const IFACE_NAME: &str = "tun-firezone";
const TUNSETIFF: u64 = 0x4004_54ca;
const TUN_FILE: &[u8] = b"/dev/net/tun\0";
const RT_PROT_STATIC: u8 = 4;
const DEFAULT_MTU: u32 = 1280;
const FILE_ALREADY_EXISTS: i32 = -17;
pub struct Tun {
handle: Handle,
connection: tokio::task::JoinHandle<()>,
fd: AsyncFd<RawFd>,
worker: Mutex<Option<BoxFuture<'static, Result<()>>>>,
}
impl fmt::Debug for Tun {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tun")
.field("handle", &self.handle)
.field("connection", &self.connection)
.field("fd", &self.fd)
.finish_non_exhaustive()
}
}
impl Drop for Tun {
fn drop(&mut self) {
unsafe { close(self.fd.as_raw_fd()) };
self.connection.abort();
}
}
impl Tun {
pub fn write4(&self, buf: &[u8]) -> io::Result<usize> {
write(self.fd.as_raw_fd(), buf)
}
pub fn write6(&self, buf: &[u8]) -> io::Result<usize> {
write(self.fd.as_raw_fd(), buf)
}
pub fn poll_read(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
let mut guard = self.worker.lock();
if let Some(worker) = guard.as_mut() {
match worker.poll_unpin(cx) {
Poll::Ready(Ok(())) => {
*guard = None;
}
Poll::Ready(Err(e)) => {
*guard = None;
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
}
Poll::Pending => return Poll::Pending,
}
}
utils::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx)
}
pub fn new(
config: &InterfaceConfig,
_: &impl Callbacks,
_: DnsFallbackStrategy,
) -> Result<Self> {
let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } {
-1 => return Err(get_last_error()),
fd => fd,
};
// Safety: We just opened the file descriptor.
unsafe {
ioctl::exec(fd, TUNSETIFF, &ioctl::Request::<SetTunFlagsPayload>::new())?;
}
set_non_blocking(fd)?;
let (connection, handle, _) = new_connection()?;
let join_handle = tokio::spawn(connection);
Ok(Self {
handle: handle.clone(),
connection: join_handle,
fd: AsyncFd::new(fd)?,
worker: Mutex::new(Some(set_iface_config(config.clone(), handle).boxed())),
})
}
pub fn add_route(&self, route: IpNetwork, _: &impl Callbacks) -> Result<Option<Self>> {
let handle = self.handle.clone();
let add_route_worker = async move {
let index = handle
.link()
.get()
.match_name(IFACE_NAME.to_string())
.execute()
.try_next()
.await?
.ok_or(Error::NoIface)?
.header
.index;
let req = handle
.route()
.add()
.output_interface(index)
.protocol(RT_PROT_STATIC)
.scope(RT_SCOPE_UNIVERSE);
let res = match route {
IpNetwork::V4(ipnet) => {
req.v4()
.destination_prefix(ipnet.network_address(), ipnet.netmask())
.execute()
.await
}
IpNetwork::V6(ipnet) => {
req.v6()
.destination_prefix(ipnet.network_address(), ipnet.netmask())
.execute()
.await
}
};
match res {
Ok(_) => Ok(()),
Err(NetlinkError(err)) if err.raw_code() == FILE_ALREADY_EXISTS => Ok(()),
Err(err) => Err(err.into()),
}
};
let mut guard = self.worker.lock();
match guard.take() {
None => *guard = Some(add_route_worker.boxed()),
Some(current_worker) => {
*guard = Some(
async move {
current_worker.await?;
add_route_worker.await?;
Ok(())
}
.boxed(),
)
}
}
Ok(None)
}
pub fn name(&self) -> &str {
IFACE_NAME
}
}
#[tracing::instrument(level = "trace", skip(handle))]
async fn set_iface_config(config: InterfaceConfig, handle: Handle) -> Result<()> {
let index = handle
.link()
.get()
.match_name(IFACE_NAME.to_string())
.execute()
.try_next()
.await?
.ok_or(Error::NoIface)?
.header
.index;
let ips = handle
.address()
.get()
.set_link_index_filter(index)
.execute();
ips.try_for_each(|ip| handle.address().del(ip).execute())
.await?;
handle.link().set(index).mtu(DEFAULT_MTU).execute().await?;
let res_v4 = handle
.address()
.add(index, config.ipv4.into(), 32)
.execute()
.await;
let res_v6 = handle
.address()
.add(index, config.ipv6.into(), 128)
.execute()
.await;
handle.link().set(index).up().execute().await?;
Ok(res_v4.or(res_v6)?)
}
fn get_last_error() -> Error {
Error::Io(io::Error::last_os_error())
}
fn set_non_blocking(fd: RawFd) -> Result<()> {
match unsafe { fcntl(fd, F_GETFL) } {
-1 => Err(get_last_error()),
flags => match unsafe { fcntl(fd, F_SETFL, flags | O_NONBLOCK) } {
-1 => Err(get_last_error()),
_ => Ok(()),
},
}
}
/// Read from the given file descriptor in the buffer.
fn read(fd: RawFd, dst: &mut [u8]) -> io::Result<usize> {
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
/// Write the buffer to the given file descriptor.
fn write(fd: RawFd, buf: &[u8]) -> io::Result<usize> {
// Safety: Within this module, the file descriptor is always valid.
match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } {
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
impl ioctl::Request<SetTunFlagsPayload> {
fn new() -> Self {
let name_as_bytes = IFACE_NAME.as_bytes();
debug_assert!(name_as_bytes.len() < libc::IF_NAMESIZE);
let mut name = [0u8; libc::IF_NAMESIZE];
name[..name_as_bytes.len()].copy_from_slice(name_as_bytes);
Self {
name,
payload: SetTunFlagsPayload {
flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _,
},
}
}
}
#[repr(C)]
struct SetTunFlagsPayload {
flags: std::ffi::c_short,
}

View File

@@ -0,0 +1,15 @@
pub struct Tun {}
impl Tun {
pub fn new() -> Self {
Self {}
}
pub fn write4(&self, _: &[u8]) -> std::io::Result<usize> {
Ok(0)
}
pub fn write6(&self, _: &[u8]) -> std::io::Result<usize> {
Ok(0)
}
}

View File

@@ -0,0 +1,26 @@
use std::io;
use std::os::fd::{AsRawFd, RawFd};
use std::task::{Context, Poll};
use tokio::io::Ready;
#[cfg(target_family = "unix")]
pub fn poll_raw_fd(
fd: &tokio::io::unix::AsyncFd<RawFd>,
mut read: impl FnMut(RawFd) -> io::Result<usize>,
cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
loop {
let mut guard = std::task::ready!(fd.poll_read_ready(cx))?;
match read(guard.get_inner().as_raw_fd()) {
Ok(n) => return Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// a read has blocked, but a write might still succeed.
// clear only the read readiness.
guard.clear_ready_matching(Ready::READABLE);
continue;
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}

View File

@@ -1,4 +1,4 @@
use crate::device_channel::create_iface;
use crate::device_channel::Device;
use crate::peer::PacketTransformGateway;
use crate::{
ConnectedPeer, DnsFallbackStrategy, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
@@ -22,10 +22,13 @@ where
{
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
pub fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
// Note: the dns fallback strategy is irrelevant for gateways
let device =
Arc::new(create_iface(config, self.callbacks(), DnsFallbackStrategy::default()).await?);
let device = Arc::new(Device::new(
config,
self.callbacks(),
DnsFallbackStrategy::default(),
)?);
self.device.store(Some(device.clone()));
self.no_device_waker.wake();

View File

@@ -398,16 +398,10 @@ where
continue;
};
tokio::spawn({
let callbacks = self.callbacks.clone();
async move {
if let Err(e) = device.refresh_mtu().await {
tracing::error!(error = ?e, "refresh_mtu");
let _ = callbacks.on_error(&e);
}
}
});
if let Err(e) = device.refresh_mtu() {
tracing::error!(error = ?e, "refresh_mtu");
let _ = self.callbacks.on_error(&e);
}
}
if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {

View File

@@ -65,7 +65,6 @@ async fn run(
tunnel
.set_interface(&init.interface)
.await
.context("Failed to set interface")?;
let mut eventloop = Eventloop::new(tunnel, portal);