diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index cb49ac4c1..7c2dae3f8 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -88,7 +88,7 @@ impl ControlPlane { { let mut init = self.tunnel_init.lock().await; if !*init { - if let Err(e) = self.tunnel.set_interface(&interface).await { + if let Err(e) = self.tunnel.set_interface(&interface) { tracing::error!(error = ?e, "Error initializing interface"); return Err(e); } else { @@ -103,7 +103,7 @@ impl ControlPlane { } for resource_description in resources { - self.add_resource(resource_description).await; + self.add_resource(resource_description); } Ok(()) } @@ -141,8 +141,8 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_resource(&self, resource_description: ResourceDescription) { - if let Err(e) = self.tunnel.add_resource(resource_description).await { + pub fn add_resource(&self, resource_description: ResourceDescription) { + if let Err(e) = self.tunnel.add_resource(resource_description) { tracing::error!(message = "Can't add resource", error = ?e); let _ = self.tunnel.callbacks().on_error(&e); } @@ -240,7 +240,7 @@ impl ControlPlane { self.connection_details(connection_details, reference) } Messages::Connect(connect) => self.connect(connect), - Messages::ResourceAdded(resource) => self.add_resource(resource).await, + Messages::ResourceAdded(resource) => self.add_resource(resource), Messages::ResourceRemoved(resource) => self.remove_resource(resource.id), Messages::ResourceUpdated(resource) => self.update_resource(resource), Messages::IceCandidates(ice_candidate) => self.add_ice_candidate(ice_candidate).await, diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index eddb4f9c6..3c4fa6306 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,5 +1,5 @@ use crate::bounded_queue::BoundedQueue; -use crate::device_channel::{create_iface, Packet}; +use crate::device_channel::{Device, Packet}; use crate::ip_packet::{IpPacket, MutableIpPacket}; use crate::peer::PacketTransformClient; use crate::{ @@ -60,7 +60,7 @@ where /// Once added, when a packet for the resource is intercepted a new data channel will be created /// and packets will be wrapped with wireguard and sent through it. #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_resource( + pub fn add_resource( &self, resource_description: ResourceDescription, ) -> connlib_shared::Result<()> { @@ -72,7 +72,7 @@ where .insert(dns.address.clone(), dns.clone()); } ResourceDescription::Cidr(cidr) => { - self.add_route(cidr.address).await?; + self.add_route(cidr.address)?; self.role_state .lock() @@ -110,13 +110,13 @@ where /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] - pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { + pub fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { if !config.upstream_dns.is_empty() { self.role_state.lock().dns_strategy = DnsFallbackStrategy::UpstreamResolver; } let dns_strategy = self.role_state.lock().dns_strategy; - let device = Arc::new(create_iface(config, self.callbacks(), dns_strategy).await?); + let device = Arc::new(Device::new(config, self.callbacks(), dns_strategy)?); self.device.store(Some(device.clone())); self.no_device_waker.wake(); @@ -124,12 +124,12 @@ where // TODO: the requirement for the DNS_SENTINEL means you NEED ipv4 stack // we are trying to support ipv4 and ipv6, so we should have an ipv6 dns sentinel // alternative. - self.add_route(DNS_SENTINEL.into()).await?; + self.add_route(DNS_SENTINEL.into())?; // Note: I'm just assuming this needs to succeed since we already require ipv4 stack due to the dns sentinel // TODO: change me when we don't require ipv4 - self.add_route(IPV4_RESOURCES.parse().unwrap()).await?; + self.add_route(IPV4_RESOURCES.parse().unwrap())?; - if let Err(e) = self.add_route(IPV6_RESOURCES.parse().unwrap()).await { + if let Err(e) = self.add_route(IPV6_RESOURCES.parse().unwrap()) { tracing::warn!(err = ?e, "ipv6 not supported"); } @@ -148,14 +148,13 @@ where } #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> { + pub fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> { let maybe_new_device = self .device .load() .as_ref() .ok_or(Error::ControlProtocolError)? - .add_route(route, self.callbacks()) - .await?; + .add_route(route, self.callbacks())?; if let Some(new_device) = maybe_new_device { self.device.swap(Some(Arc::new(new_device))); diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index de5ca861f..4b823fbe5 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -216,16 +216,11 @@ where tracing::trace!(?peer_config.ips, "new_data_channel_open"); let device = self.device.load().clone().ok_or(Error::NoIface)?; let callbacks = self.callbacks.clone(); - let ips = peer_config.ips.clone(); - - // Worst thing if this is not run before peers_by_ip is that some packets are lost to the default route - tokio::spawn(async move { - for ip in ips { - if let Ok(res) = device.add_route(ip, &callbacks).await { - assert!(res.is_none(), "gateway does not run on android and thus never produces a new device upon `add_route`"); - } + for ip in &peer_config.ips { + if let Ok(res) = device.add_route(*ip, &callbacks) { + assert!(res.is_none(), "gateway does not run on android and thus never produces a new device upon `add_route`"); } - }); + } let peer = Arc::new(Peer::new( self.private_key.clone(), diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index aa59d997c..5dd153634 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -1,39 +1,78 @@ #![allow(clippy::module_inception)] +#![cfg_attr(target_family = "windows", allow(dead_code))] // TODO: Remove when windows is fully implemented. -#[cfg(target_family = "unix")] -#[path = "device_channel/device_channel_unix.rs"] -mod device_channel; +#[cfg(any(target_os = "macos", target_os = "ios"))] +#[path = "device_channel/tun_darwin.rs"] +mod tun; + +#[cfg(target_os = "linux")] +#[path = "device_channel/tun_linux.rs"] +mod tun; #[cfg(target_family = "windows")] -#[path = "device_channel/device_channel_win.rs"] -mod device_channel; +#[path = "device_channel/tun_windows.rs"] +mod tun; + +// TODO: Android and linux are nearly identical; use a common tunnel module? +#[cfg(target_os = "android")] +#[path = "device_channel/tun_android.rs"] +mod tun; use crate::ip_packet::MutableIpPacket; +use crate::DnsFallbackStrategy; +use connlib_shared::error::ConnlibError; +use connlib_shared::messages::Interface; use connlib_shared::{Callbacks, Error}; -pub(crate) use device_channel::*; use ip_network::IpNetwork; use std::borrow::Cow; use std::io; -use std::task::{ready, Context, Poll}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tun::Tun; pub struct Device { - config: IfaceConfig, - io: DeviceIo, + mtu: AtomicUsize, + tun: Tun, } impl Device { + #[cfg(target_family = "unix")] + pub(crate) fn new( + config: &Interface, + callbacks: &impl Callbacks, + dns: DnsFallbackStrategy, + ) -> Result { + let tun = Tun::new(config, callbacks, dns)?; + let mtu = AtomicUsize::new(ioctl::interface_mtu_by_name(tun.name())?); + + Ok(Device { mtu, tun }) + } + + #[cfg(target_family = "windows")] + pub(crate) fn new( + _: &Interface, + _: &impl Callbacks, + _: DnsFallbackStrategy, + ) -> Result { + Ok(Device { + tun: Tun::new(), + mtu: AtomicUsize::default(), // Dummy value for now. + }) + } + + #[cfg(target_family = "unix")] pub(crate) fn poll_read<'b>( &self, buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>>> { - let res = ready!(self.io.poll_read(&mut buf[..self.config.mtu()], cx))?; + let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?; - if res == 0 { + if n == 0 { return Poll::Ready(Ok(None)); } - Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..res]).ok_or_else( + Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..n]).ok_or_else( || { io::Error::new( io::ErrorKind::InvalidInput, @@ -43,20 +82,60 @@ impl Device { )?))) } - pub(crate) async fn add_route( + #[cfg(target_family = "windows")] + pub(crate) fn poll_read<'b>( + &self, + _: &'b mut [u8], + _: &mut Context<'_>, + ) -> Poll>>> { + Poll::Pending + } + + pub(crate) fn mtu(&self) -> usize { + self.mtu.load(Ordering::Relaxed) + } + + #[cfg(target_family = "unix")] + pub(crate) fn add_route( &self, route: IpNetwork, callbacks: &impl Callbacks, ) -> Result, Error> { - self.config.add_route(route, callbacks).await + let Some(tun) = self.tun.add_route(route, callbacks)? else { + return Ok(None); + }; + let mtu = AtomicUsize::new(ioctl::interface_mtu_by_name(tun.name())?); + + Ok(Some(Device { mtu, tun })) } - pub(crate) async fn refresh_mtu(&self) -> Result { - self.config.refresh_mtu().await + #[cfg(target_family = "windows")] + pub(crate) fn add_route( + &self, + _: IpNetwork, + _: &impl Callbacks, + ) -> Result, Error> { + Ok(None) + } + + #[cfg(target_family = "unix")] + pub(crate) fn refresh_mtu(&self) -> Result { + let mtu = ioctl::interface_mtu_by_name(self.tun.name())?; + self.mtu.store(mtu, Ordering::Relaxed); + + Ok(mtu) + } + + #[cfg(target_family = "windows")] + pub(crate) fn refresh_mtu(&self) -> Result { + Ok(0) } pub fn write(&self, packet: Packet<'_>) -> io::Result { - self.io.write(packet) + match packet { + Packet::Ipv4(msg) => self.tun.write4(&msg), + Packet::Ipv6(msg) => self.tun.write6(&msg), + } } } @@ -64,3 +143,98 @@ pub enum Packet<'a> { Ipv4(Cow<'a, [u8]>), Ipv6(Cow<'a, [u8]>), } + +#[cfg(target_family = "unix")] +mod ioctl { + use super::*; + use std::os::fd::RawFd; + use tun::SIOCGIFMTU; + + pub(crate) fn interface_mtu_by_name(name: &str) -> Result { + let socket = Socket::ip4()?; + let request = Request::::new(name)?; + + // Safety: The file descriptor is open. + unsafe { + exec(socket.fd, SIOCGIFMTU, &request)?; + } + + Ok(request.payload.mtu as usize) + } + + /// Executes the `ioctl` syscall on the given file descriptor with the provided request. + /// + /// # Safety + /// + /// The file descriptor must be open. + pub(crate) unsafe fn exec

( + fd: RawFd, + code: libc::c_ulong, + req: &Request

, + ) -> Result<(), ConnlibError> { + let ret = unsafe { libc::ioctl(fd, code as _, req) }; + + if ret < 0 { + return Err(io::Error::last_os_error().into()); + } + + Ok(()) + } + + /// Represents a control request to an IO device, addresses by the device's name. + /// + /// The payload MUST also be `#[repr(C)]` and its layout depends on the particular request you are sending. + #[repr(C)] + pub(crate) struct Request

{ + pub(crate) name: [std::ffi::c_uchar; libc::IF_NAMESIZE], + pub(crate) payload: P, + } + + /// A socket newtype which closes the file descriptor on drop. + struct Socket { + fd: RawFd, + } + + impl Socket { + fn ip4() -> io::Result { + // Safety: All provided parameters are constants. + let fd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, libc::IPPROTO_IP) }; + + if fd == -1 { + return Err(io::Error::last_os_error()); + } + + Ok(Self { fd }) + } + } + + impl Drop for Socket { + fn drop(&mut self) { + // Safety: This is the only call to `close` and it happens when `Guard` is being dropped. + unsafe { libc::close(self.fd) }; + } + } + + impl Request { + fn new(name: &str) -> io::Result { + if name.len() > libc::IF_NAMESIZE { + return Err(io::ErrorKind::InvalidInput.into()); + } + + let mut request = Request { + name: [0u8; libc::IF_NAMESIZE], + payload: Default::default(), + }; + + request.name[..name.len()].copy_from_slice(name.as_bytes()); + + Ok(request) + } + } + + #[derive(Default)] + #[repr(C)] + struct GetInterfaceMtuPayload { + mtu: libc::c_int, + } +} diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs deleted file mode 100644 index bcacf20b1..000000000 --- a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::io; -use std::sync::{ - atomic::{AtomicUsize, Ordering::Relaxed}, - Arc, -}; -use std::task::{ready, Context, Poll}; - -use connlib_shared::{messages::Interface, Callbacks, Error, Result}; -use ip_network::IpNetwork; -use tokio::io::{unix::AsyncFd, Ready}; - -use tun::{IfaceDevice, IfaceStream}; - -use crate::device_channel::{Device, Packet}; -use crate::DnsFallbackStrategy; - -mod tun; - -pub(crate) struct IfaceConfig { - mtu: AtomicUsize, - iface: IfaceDevice, -} - -pub(crate) struct DeviceIo(Arc>); - -impl DeviceIo { - pub fn poll_read(&self, out: &mut [u8], cx: &mut Context<'_>) -> Poll> { - loop { - let mut guard = ready!(self.0.poll_read_ready(cx))?; - - match guard.get_inner().read(out) { - Ok(n) => return Poll::Ready(Ok(n)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - // a read has blocked, but a write might still succeed. - // clear only the read readiness. - guard.clear_ready_matching(Ready::READABLE); - continue; - } - Err(e) => return Poll::Ready(Err(e)), - } - } - } - - // Note: write is synchronous because it's non-blocking - // and some losiness is acceptable and increseases performance - // since we don't block the reading loops. - pub fn write(&self, packet: Packet<'_>) -> io::Result { - match packet { - Packet::Ipv4(msg) => self.0.get_ref().write4(&msg), - Packet::Ipv6(msg) => self.0.get_ref().write6(&msg), - } - } -} - -impl IfaceConfig { - pub(crate) fn mtu(&self) -> usize { - self.mtu.load(Relaxed) - } - - pub(crate) async fn refresh_mtu(&self) -> Result { - let mtu = self.iface.mtu().await?; - self.mtu.store(mtu, Relaxed); - Ok(mtu) - } - - pub(crate) async fn add_route( - &self, - route: IpNetwork, - callbacks: &impl Callbacks, - ) -> Result> { - let Some((iface, stream)) = self.iface.add_route(route, callbacks).await? else { - return Ok(None); - }; - let io = DeviceIo(stream); - let mtu = iface.mtu().await?; - let config = IfaceConfig { - iface, - mtu: AtomicUsize::new(mtu), - }; - Ok(Some(Device { io, config })) - } -} - -pub(crate) async fn create_iface( - config: &Interface, - callbacks: &impl Callbacks, - fallback_strategy: DnsFallbackStrategy, -) -> Result { - let (iface, stream) = IfaceDevice::new(config, callbacks, fallback_strategy).await?; - iface.up().await?; - let io = DeviceIo(stream); - let mtu = iface.mtu().await?; - let config = IfaceConfig { - iface, - mtu: AtomicUsize::new(mtu), - }; - - Ok(Device { io, config }) -} diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_win.rs b/rust/connlib/tunnel/src/device_channel/device_channel_win.rs deleted file mode 100644 index 2d2a0c4ef..000000000 --- a/rust/connlib/tunnel/src/device_channel/device_channel_win.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::device_channel::Packet; -use crate::Device; -use crate::DnsFallbackStrategy; -use connlib_shared::{messages::Interface, Callbacks, Result}; -use ip_network::IpNetwork; -use std::task::{Context, Poll}; - -// TODO: Fill all this out. These are just stubs to test the GUI. - -pub(crate) struct DeviceIo; - -pub(crate) struct IfaceConfig; - -impl DeviceIo { - pub fn poll_read(&self, _: &mut [u8], _: &mut Context<'_>) -> Poll> { - // Incoming packets will never appear - Poll::Pending - } - - pub fn write(&self, packet: Packet<'_>) -> std::io::Result { - // All outgoing packets are successfully written to the void - match packet { - Packet::Ipv4(msg) => Ok(msg.len()), - Packet::Ipv6(msg) => Ok(msg.len()), - } - } -} - -const BOGUS_MTU: usize = 1_500; - -impl IfaceConfig { - pub(crate) fn mtu(&self) -> usize { - BOGUS_MTU - } - - pub(crate) async fn refresh_mtu(&self) -> Result { - Ok(BOGUS_MTU) - } - - pub(crate) async fn add_route( - &self, - _: IpNetwork, - _: &impl Callbacks, - ) -> Result> { - let io = DeviceIo {}; - let config = IfaceConfig {}; - Ok(Some(Device { io, config })) - } -} - -pub(crate) async fn create_iface( - _: &Interface, - _: &impl Callbacks, - _: DnsFallbackStrategy, -) -> Result { - Ok(Device { - config: IfaceConfig {}, - io: DeviceIo {}, - }) -} diff --git a/rust/connlib/tunnel/src/device_channel/tun.rs b/rust/connlib/tunnel/src/device_channel/tun.rs deleted file mode 100644 index 13a68a032..000000000 --- a/rust/connlib/tunnel/src/device_channel/tun.rs +++ /dev/null @@ -1,16 +0,0 @@ -#![allow(clippy::module_inception)] - -#[cfg(any(target_os = "macos", target_os = "ios"))] -#[path = "tun/tun_darwin.rs"] -mod tun; - -#[cfg(target_os = "linux")] -#[path = "tun/tun_linux.rs"] -mod tun; - -// TODO: Android and linux are nearly identical; use a common tunnel module? -#[cfg(target_os = "android")] -#[path = "tun/tun_android.rs"] -mod tun; - -pub(crate) use tun::*; diff --git a/rust/connlib/tunnel/src/device_channel/tun/closeable.rs b/rust/connlib/tunnel/src/device_channel/tun/closeable.rs deleted file mode 100644 index 1107310e6..000000000 --- a/rust/connlib/tunnel/src/device_channel/tun/closeable.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::os::fd::{AsRawFd, RawFd}; -use std::sync::atomic::{AtomicBool, Ordering}; - -#[derive(Debug)] -pub(crate) struct Closeable { - closed: AtomicBool, - value: RawFd, -} - -impl AsRawFd for Closeable { - fn as_raw_fd(&self) -> RawFd { - self.value.as_raw_fd() - } -} - -impl Closeable { - pub(crate) fn new(fd: RawFd) -> Self { - Self { - closed: AtomicBool::new(false), - value: fd, - } - } - - pub(crate) fn with(&self, f: impl FnOnce(RawFd) -> U) -> std::io::Result { - if self.closed.load(Ordering::Acquire) { - return Err(std::io::Error::from_raw_os_error(9)); - } - - Ok(f(self.value)) - } - - pub(crate) fn close(&self) { - self.closed.store(true, Ordering::Release); - } -} diff --git a/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs b/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs deleted file mode 100644 index 914c0466c..000000000 --- a/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs +++ /dev/null @@ -1,196 +0,0 @@ -use closeable::Closeable; -use connlib_shared::{ - messages::Interface as InterfaceConfig, Callbacks, Error, Result, DNS_SENTINEL, -}; -use ip_network::IpNetwork; -use libc::{ - close, ioctl, read, sockaddr, sockaddr_in, write, AF_INET, IFNAMSIZ, IPPROTO_IP, SIOCGIFMTU, - SOCK_STREAM, -}; -use std::{ - ffi::{c_int, c_short, c_uchar}, - io, - os::fd::{AsRawFd, RawFd}, - sync::Arc, -}; -use tokio::io::unix::AsyncFd; - -use crate::DnsFallbackStrategy; - -mod closeable; -mod wrapped_socket; - -#[repr(C)] -union IfrIfru { - ifru_addr: sockaddr, - ifru_addr_v4: sockaddr_in, - ifru_addr_v6: sockaddr_in, - ifru_dstaddr: sockaddr, - ifru_broadaddr: sockaddr, - ifru_flags: c_short, - ifru_metric: c_int, - ifru_mtu: c_int, - ifru_phys: c_int, - ifru_media: c_int, - ifru_intval: c_int, - ifru_wake_flags: u32, - ifru_route_refcnt: u32, - ifru_cap: [c_int; 2], - ifru_functional_type: u32, -} - -#[repr(C)] -pub struct ifreq { - ifr_name: [c_uchar; IFNAMSIZ], - ifr_ifru: IfrIfru, -} - -const TUNGETIFF: u64 = 0x800454d2; - -#[derive(Debug)] -pub(crate) struct IfaceDevice(Arc>); - -#[derive(Debug)] -pub(crate) struct IfaceStream { - fd: Closeable, -} - -impl AsRawFd for IfaceStream { - fn as_raw_fd(&self) -> RawFd { - self.fd.as_raw_fd() - } -} - -impl Drop for IfaceStream { - fn drop(&mut self) { - unsafe { close(self.fd.as_raw_fd()) }; - } -} - -impl IfaceStream { - fn write(&self, buf: &[u8]) -> std::io::Result { - match self - .fd - .with(|fd| unsafe { write(fd, buf.as_ptr() as _, buf.len() as _) })? - { - -1 => Err(io::Error::last_os_error()), - n => Ok(n as usize), - } - } - - pub fn write4(&self, src: &[u8]) -> std::io::Result { - self.write(src) - } - - pub fn write6(&self, src: &[u8]) -> std::io::Result { - self.write(src) - } - - pub fn read(&self, dst: &mut [u8]) -> std::io::Result { - // We don't read(or write) again from the fd because the given fd number might be reclaimed - // so this could make an spurious read/write to another fd and we DEFINITELY don't want that - match self - .fd - .with(|fd| unsafe { read(fd, dst.as_mut_ptr() as _, dst.len()) })? - { - -1 => Err(io::Error::last_os_error()), - n => Ok(n as usize), - } - } - - pub fn close(&self) { - self.fd.close(); - } -} - -impl IfaceDevice { - pub async fn new( - config: &InterfaceConfig, - callbacks: &impl Callbacks, - fallback_strategy: DnsFallbackStrategy, - ) -> Result<(Self, Arc>)> { - let fd = callbacks - .on_set_interface_config( - config.ipv4, - config.ipv6, - DNS_SENTINEL, - fallback_strategy.to_string(), - )? - .ok_or(Error::NoFd)?; - let iface_stream = Arc::new(AsyncFd::new(IfaceStream { - fd: Closeable::new(fd.into()), - })?); - let this = Self(Arc::clone(&iface_stream)); - - Ok((this, iface_stream)) - } - - fn name(&self) -> Result { - let mut ifr = ifreq { - ifr_name: [0; IFNAMSIZ], - ifr_ifru: unsafe { std::mem::zeroed() }, - }; - - match self - .0 - .get_ref() - .fd - .with(|fd| unsafe { ioctl(fd, TUNGETIFF as _, &mut ifr) })? - { - 0 => { - let name_cstr = unsafe { std::ffi::CStr::from_ptr(ifr.ifr_name.as_ptr() as _) }; - Ok(name_cstr.to_string_lossy().into_owned()) - } - _ => Err(get_last_error()), - } - } - - /// Get the current MTU value - pub async fn mtu(&self) -> Result { - let socket = wrapped_socket::WrappedSocket::new(AF_INET, SOCK_STREAM, IPPROTO_IP); - let fd = match socket.as_raw_fd() { - -1 => return Err(get_last_error()), - fd => fd, - }; - - let name = self.name()?; - let iface_name: &[u8] = name.as_ref(); - let mut ifr = ifreq { - ifr_name: [0; IFNAMSIZ], - ifr_ifru: IfrIfru { ifru_mtu: 0 }, - }; - - ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name); - - if unsafe { ioctl(fd, SIOCGIFMTU as _, &ifr) } < 0 { - return Err(get_last_error()); - } - - let mtu = unsafe { ifr.ifr_ifru.ifru_mtu }; - - Ok(mtu as _) - } - - pub async fn add_route( - &self, - route: IpNetwork, - callbacks: &impl Callbacks, - ) -> Result>)>> { - self.0.get_ref().close(); - let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?; - let iface_stream = Arc::new(AsyncFd::new(IfaceStream { - fd: Closeable::new(fd.into()), - })?); - let this = Self(Arc::clone(&iface_stream)); - - Ok(Some((this, iface_stream))) - } - - pub async fn up(&self) -> Result<()> { - Ok(()) - } -} - -fn get_last_error() -> Error { - Error::Io(io::Error::last_os_error()) -} diff --git a/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs b/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs deleted file mode 100644 index 3580866bb..000000000 --- a/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs +++ /dev/null @@ -1,273 +0,0 @@ -use connlib_shared::{messages::Interface as InterfaceConfig, Callbacks, Error, Result}; -use futures::TryStreamExt; -use ip_network::IpNetwork; -use libc::{ - close, fcntl, ioctl, open, read, sockaddr, sockaddr_in, write, F_GETFL, F_SETFL, - IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, IFNAMSIZ, O_NONBLOCK, O_RDWR, -}; -use netlink_packet_route::{rtnl::link::nlas::Nla, RT_SCOPE_UNIVERSE}; -use rtnetlink::{new_connection, Error::NetlinkError, Handle}; -use std::{ - ffi::{c_int, c_short, c_uchar}, - io, - os::fd::{AsRawFd, RawFd}, - sync::Arc, -}; -use tokio::io::unix::AsyncFd; - -use crate::DnsFallbackStrategy; - -const IFACE_NAME: &str = "tun-firezone"; -const TUNSETIFF: u64 = 0x4004_54ca; -const TUN_FILE: &[u8] = b"/dev/net/tun\0"; -const RT_PROT_STATIC: u8 = 4; -const DEFAULT_MTU: u32 = 1280; -const FILE_ALREADY_EXISTS: i32 = -17; - -#[repr(C)] -union IfrIfru { - ifru_addr: sockaddr, - ifru_addr_v4: sockaddr_in, - ifru_addr_v6: sockaddr_in, - ifru_dstaddr: sockaddr, - ifru_broadaddr: sockaddr, - ifru_flags: c_short, - ifru_metric: c_int, - ifru_mtu: c_int, - ifru_phys: c_int, - ifru_media: c_int, - ifru_intval: c_int, - ifru_wake_flags: u32, - ifru_route_refcnt: u32, - ifru_cap: [c_int; 2], - ifru_functional_type: u32, -} - -#[repr(C)] -pub struct ifreq { - ifr_name: [c_uchar; IFNAMSIZ], - ifr_ifru: IfrIfru, -} - -#[derive(Debug)] -pub struct IfaceDevice { - handle: Handle, - connection: tokio::task::JoinHandle<()>, - interface_index: u32, -} - -#[derive(Debug)] -pub struct IfaceStream(RawFd); - -impl AsRawFd for IfaceStream { - fn as_raw_fd(&self) -> RawFd { - self.0 - } -} - -impl Drop for IfaceStream { - fn drop(&mut self) { - unsafe { close(self.0) }; - } -} - -impl Drop for IfaceDevice { - fn drop(&mut self) { - self.connection.abort(); - } -} - -impl IfaceStream { - fn write(&self, buf: &[u8]) -> std::io::Result { - match unsafe { write(self.0, buf.as_ptr() as _, buf.len() as _) } { - -1 => Err(io::Error::last_os_error()), - n => Ok(n as usize), - } - } - - pub fn write4(&self, buf: &[u8]) -> std::io::Result { - self.write(buf) - } - - pub fn write6(&self, buf: &[u8]) -> std::io::Result { - self.write(buf) - } - - pub fn read(&self, dst: &mut [u8]) -> std::io::Result { - match unsafe { read(self.0, dst.as_mut_ptr() as _, dst.len()) } { - -1 => Err(io::Error::last_os_error()), - n => Ok(n as usize), - } - } -} - -impl IfaceDevice { - pub async fn new( - config: &InterfaceConfig, - cb: &impl Callbacks, - _: DnsFallbackStrategy, - ) -> Result<(Self, Arc>)> { - debug_assert!(IFACE_NAME.as_bytes().len() < IFNAMSIZ); - - let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } { - -1 => return Err(get_last_error()), - fd => fd, - }; - - let mut ifr = ifreq { - ifr_name: [0; IFNAMSIZ], - ifr_ifru: IfrIfru { - ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _, - }, - }; - - ifr.ifr_name[..IFACE_NAME.as_bytes().len()].copy_from_slice(IFACE_NAME.as_bytes()); - - if unsafe { ioctl(fd, TUNSETIFF as _, &ifr) } < 0 { - return Err(get_last_error()); - } - - let (connection, handle, _) = new_connection()?; - let join_handle = tokio::spawn(connection); - let interface_index = handle - .link() - .get() - .match_name(IFACE_NAME.to_string()) - .execute() - .try_next() - .await? - .ok_or(Error::NoIface)? - .header - .index; - - set_non_blocking(fd)?; - - let this = Self { - handle, - connection: join_handle, - interface_index, - }; - - this.set_iface_config(config, cb).await?; - - Ok((this, Arc::new(AsyncFd::new(IfaceStream(fd))?))) - } - - /// Get the current MTU value - pub async fn mtu(&self) -> Result { - while let Ok(Some(msg)) = self - .handle - .link() - .get() - .match_index(self.interface_index) - .execute() - .try_next() - .await - { - for nla in msg.nlas { - if let Nla::Mtu(mtu) = nla { - return Ok(mtu as usize); - } - } - } - - Err(Error::NoMtu) - } - - pub async fn add_route( - &self, - route: IpNetwork, - _: &impl Callbacks, - ) -> Result>)>> { - let req = self - .handle - .route() - .add() - .output_interface(self.interface_index) - .protocol(RT_PROT_STATIC) - .scope(RT_SCOPE_UNIVERSE); - let res = match route { - IpNetwork::V4(ipnet) => { - req.v4() - .destination_prefix(ipnet.network_address(), ipnet.netmask()) - .execute() - .await - } - IpNetwork::V6(ipnet) => { - req.v6() - .destination_prefix(ipnet.network_address(), ipnet.netmask()) - .execute() - .await - } - }; - - match res { - Ok(_) => Ok(None), - Err(NetlinkError(err)) if err.raw_code() == FILE_ALREADY_EXISTS => Ok(None), - Err(err) => Err(err.into()), - } - } - - #[tracing::instrument(level = "trace", skip(self))] - pub async fn set_iface_config( - &self, - config: &InterfaceConfig, - _: &impl Callbacks, - ) -> Result<()> { - let ips = self - .handle - .address() - .get() - .set_link_index_filter(self.interface_index) - .execute(); - - ips.try_for_each(|ip| self.handle.address().del(ip).execute()) - .await?; - - self.handle - .link() - .set(self.interface_index) - .mtu(DEFAULT_MTU) - .execute() - .await?; - - let res_v4 = self - .handle - .address() - .add(self.interface_index, config.ipv4.into(), 32) - .execute() - .await; - let res_v6 = self - .handle - .address() - .add(self.interface_index, config.ipv6.into(), 128) - .execute() - .await; - - Ok(res_v4.or(res_v6)?) - } - - pub async fn up(&self) -> Result<()> { - self.handle - .link() - .set(self.interface_index) - .up() - .execute() - .await?; - Ok(()) - } -} - -fn get_last_error() -> Error { - Error::Io(io::Error::last_os_error()) -} - -fn set_non_blocking(fd: RawFd) -> Result<()> { - match unsafe { fcntl(fd, F_GETFL) } { - -1 => Err(get_last_error()), - flags => match unsafe { fcntl(fd, F_SETFL, flags | O_NONBLOCK) } { - -1 => Err(get_last_error()), - _ => Ok(()), - }, - } -} diff --git a/rust/connlib/tunnel/src/device_channel/tun/wrapped_socket.rs b/rust/connlib/tunnel/src/device_channel/tun/wrapped_socket.rs deleted file mode 100644 index d418a68c3..000000000 --- a/rust/connlib/tunnel/src/device_channel/tun/wrapped_socket.rs +++ /dev/null @@ -1,28 +0,0 @@ -use libc::{close, socket}; -use std::ffi::c_int; -use std::os::fd::RawFd; - -pub struct WrappedSocket { - fd: RawFd, -} - -impl WrappedSocket { - pub fn new(domain: c_int, sock_type: c_int, protocol: c_int) -> Self { - let fd = unsafe { socket(domain, sock_type, protocol) }; - Self { fd } - } - - pub fn as_raw_fd(&self) -> RawFd { - self.fd - } -} - -impl Drop for WrappedSocket { - fn drop(&mut self) { - if self.fd == -1 { - return; - } - - unsafe { close(self.fd) }; - } -} diff --git a/rust/connlib/tunnel/src/device_channel/tun_android.rs b/rust/connlib/tunnel/src/device_channel/tun_android.rs new file mode 100644 index 000000000..1050e8e2b --- /dev/null +++ b/rust/connlib/tunnel/src/device_channel/tun_android.rs @@ -0,0 +1,168 @@ +use crate::device_channel::ioctl; +use crate::DnsFallbackStrategy; +use connlib_shared::{ + messages::Interface as InterfaceConfig, Callbacks, Error, Result, DNS_SENTINEL, +}; +use ip_network::IpNetwork; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll}; +use std::{ + io, + os::fd::{AsRawFd, RawFd}, +}; +use tokio::io::unix::AsyncFd; + +mod utils; + +pub(crate) const SIOCGIFMTU: libc::c_ulong = libc::SIOCGIFMTU; + +#[derive(Debug)] +pub(crate) struct Tun { + fd: Closeable, + name: String, +} + +impl Drop for Tun { + fn drop(&mut self) { + unsafe { libc::close(self.fd.fd.as_raw_fd()) }; + } +} + +impl Tun { + pub fn write4(&self, src: &[u8]) -> std::io::Result { + self.fd.with(|fd| write(*fd.get_ref(), src))? + } + + pub fn write6(&self, src: &[u8]) -> std::io::Result { + self.fd.with(|fd| write(*fd.get_ref(), src))? + } + + pub fn poll_read(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll> { + self.fd + .with(|fd| utils::poll_raw_fd(&fd, |fd| read(fd, buf), cx))? + } + + pub fn close(&self) { + self.fd.close(); + } + + pub fn new( + config: &InterfaceConfig, + callbacks: &impl Callbacks, + fallback_strategy: DnsFallbackStrategy, + ) -> Result { + let fd = callbacks + .on_set_interface_config( + config.ipv4, + config.ipv6, + DNS_SENTINEL, + fallback_strategy.to_string(), + )? + .ok_or(Error::NoFd)?; + // Safety: File descriptor is open. + let name = unsafe { interface_name(fd)? }; + + Ok(Tun { + fd: Closeable::new(AsyncFd::new(fd)?), + name, + }) + } + + pub fn name(&self) -> &str { + self.name.as_str() + } + + pub fn add_route( + &self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result> { + self.fd.close(); + let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?; + let name = unsafe { interface_name(fd)? }; + + Ok(Some(Tun { + fd: Closeable::new(AsyncFd::new(fd)?), + name, + })) + } +} + +/// Retrieves the name of the interface pointed to by the provided file descriptor. +/// +/// # Safety +/// +/// The file descriptor must be open. +unsafe fn interface_name(fd: RawFd) -> Result { + const TUNGETIFF: libc::c_ulong = 0x800454d2; + let request = ioctl::Request::::new(); + + ioctl::exec(fd, TUNGETIFF, &request)?; + + Ok(request.name().to_string()) +} + +impl ioctl::Request { + fn new() -> Self { + Self { + name: [0u8; libc::IF_NAMESIZE], + payload: Default::default(), + } + } + + fn name(&self) -> std::borrow::Cow<'_, str> { + // Safety: The memory of `self.name` is always initialized. + let cstr = unsafe { std::ffi::CStr::from_ptr(self.name.as_ptr() as _) }; + + cstr.to_string_lossy() + } +} + +#[derive(Default)] +#[repr(C)] +struct GetInterfaceNamePayload; + +/// Read from the given file descriptor in the buffer. +fn read(fd: RawFd, dst: &mut [u8]) -> io::Result { + // Safety: Within this module, the file descriptor is always valid. + match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } { + -1 => Err(io::Error::last_os_error()), + n => Ok(n as usize), + } +} + +/// Write the buffer to the given file descriptor. +fn write(fd: RawFd, buf: &[u8]) -> io::Result { + // Safety: Within this module, the file descriptor is always valid. + match unsafe { libc::write(fd.as_raw_fd(), buf.as_ptr() as _, buf.len() as _) } { + -1 => Err(io::Error::last_os_error()), + n => Ok(n as usize), + } +} + +#[derive(Debug)] +struct Closeable { + closed: AtomicBool, + fd: AsyncFd, +} + +impl Closeable { + fn new(fd: AsyncFd) -> Self { + Self { + closed: AtomicBool::new(false), + fd: fd, + } + } + + fn with(&self, f: impl FnOnce(&AsyncFd) -> U) -> std::io::Result { + if self.closed.load(Ordering::Acquire) { + return Err(std::io::Error::from_raw_os_error(9)); + } + + Ok(f(&self.fd)) + } + + fn close(&self) { + self.closed.store(true, Ordering::Release); + } +} diff --git a/rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs similarity index 64% rename from rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs rename to rust/connlib/tunnel/src/device_channel/tun_darwin.rs index 9af5bd7f7..8311c48ce 100644 --- a/rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs @@ -1,72 +1,34 @@ +use crate::DnsFallbackStrategy; use connlib_shared::{ messages::Interface as InterfaceConfig, Callbacks, Error, Result, DNS_SENTINEL, }; use ip_network::IpNetwork; use libc::{ - ctl_info, fcntl, getpeername, getsockopt, ioctl, iovec, msghdr, recvmsg, sendmsg, sockaddr, - sockaddr_ctl, sockaddr_in, socklen_t, AF_INET, AF_INET6, AF_SYSTEM, CTLIOCGINFO, F_GETFL, - F_SETFL, IF_NAMESIZE, IPPROTO_IP, O_NONBLOCK, SOCK_STREAM, SYSPROTO_CONTROL, UTUN_OPT_IFNAME, + ctl_info, fcntl, getpeername, getsockopt, ioctl, iovec, msghdr, recvmsg, sendmsg, sockaddr_ctl, + socklen_t, AF_INET, AF_INET6, AF_SYSTEM, CTLIOCGINFO, F_GETFL, F_SETFL, IF_NAMESIZE, + O_NONBLOCK, SYSPROTO_CONTROL, UTUN_OPT_IFNAME, }; +use std::task::{Context, Poll}; use std::{ - ffi::{c_int, c_short, c_uchar}, io, mem::size_of, os::fd::{AsRawFd, RawFd}, - sync::Arc, }; use tokio::io::unix::AsyncFd; -use crate::DnsFallbackStrategy; +mod utils; const CTL_NAME: &[u8] = b"com.apple.net.utun_control"; -const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933; +/// `libc` for darwin doesn't define this constant so we declare it here. +pub(crate) const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933; #[derive(Debug)] -pub(crate) struct IfaceDevice { +pub(crate) struct Tun { name: String, + fd: AsyncFd, } -#[derive(Debug)] -pub(crate) struct IfaceStream { - fd: RawFd, -} - -mod wrapped_socket; - -impl AsRawFd for IfaceStream { - fn as_raw_fd(&self) -> RawFd { - self.fd - } -} - -// For some reason this is not available in libc for darwin :c -#[allow(non_camel_case_types)] -#[repr(C)] -pub struct ifreq { - ifr_name: [c_uchar; IF_NAMESIZE], - ifr_ifru: IfrIfru, -} - -#[repr(C)] -union IfrIfru { - ifru_addr: sockaddr, - ifru_addr_v4: sockaddr_in, - ifru_addr_v6: sockaddr_in, - ifru_dstaddr: sockaddr, - ifru_broadaddr: sockaddr, - ifru_flags: c_short, - ifru_metric: c_int, - ifru_mtu: c_int, - ifru_phys: c_int, - ifru_media: c_int, - ifru_intval: c_int, - ifru_wake_flags: u32, - ifru_route_refcnt: u32, - ifru_cap: [c_int; 2], - ifru_functional_type: u32, -} - -impl IfaceStream { +impl Tun { pub fn write4(&self, src: &[u8]) -> std::io::Result { self.write(src, AF_INET as u8) } @@ -75,35 +37,8 @@ impl IfaceStream { self.write(src, AF_INET6 as u8) } - pub fn read(&self, dst: &mut [u8]) -> std::io::Result { - let mut hdr = [0u8; 4]; - - let mut iov = [ - iovec { - iov_base: hdr.as_mut_ptr() as _, - iov_len: hdr.len(), - }, - iovec { - iov_base: dst.as_mut_ptr() as _, - iov_len: dst.len(), - }, - ]; - - let mut msg_hdr = msghdr { - msg_name: std::ptr::null_mut(), - msg_namelen: 0, - msg_iov: &mut iov[0], - msg_iovlen: iov.len() as _, - msg_control: std::ptr::null_mut(), - msg_controllen: 0, - msg_flags: 0, - }; - - match unsafe { recvmsg(self.fd, &mut msg_hdr, 0) } { - -1 => Err(io::Error::last_os_error()), - 0..=4 => Ok(0), - n => Ok((n - 4) as usize), - } + pub fn poll_read(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll> { + utils::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx) } fn write(&self, src: &[u8], af: u8) -> std::io::Result { @@ -129,19 +64,17 @@ impl IfaceStream { msg_flags: 0, }; - match unsafe { sendmsg(self.fd, &msg_hdr, 0) } { + match unsafe { sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) } { -1 => Err(io::Error::last_os_error()), n => Ok(n as usize), } } -} -impl IfaceDevice { - pub async fn new( + pub fn new( config: &InterfaceConfig, callbacks: &impl Callbacks, fallback_strategy: DnsFallbackStrategy, - ) -> Result<(Self, Arc>)> { + ) -> Result { let mut info = ctl_info { ctl_id: 0, ctl_name: [0; 96], @@ -207,51 +140,28 @@ impl IfaceDevice { set_non_blocking(fd)?; - return Ok(( - Self { name: name(fd)? }, - Arc::new(AsyncFd::new(IfaceStream { fd })?), - )); + return Ok(Self { + name: name(fd)?, + fd: AsyncFd::new(fd)?, + }); } } Err(get_last_error()) } - /// Get the current MTU value - pub async fn mtu(&self) -> Result { - let socket = wrapped_socket::WrappedSocket::new(AF_INET, SOCK_STREAM, IPPROTO_IP); - let fd = match socket.as_raw_fd() { - -1 => return Err(get_last_error()), - fd => fd, - }; - - let iface_name: &[u8] = self.name.as_ref(); - let mut ifr = ifreq { - ifr_name: [0; IF_NAMESIZE], - ifr_ifru: IfrIfru { ifru_mtu: 0 }, - }; - - ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name); - - if unsafe { ioctl(fd, SIOCGIFMTU, &ifr) } < 0 { - return Err(get_last_error()); - } - - Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _) - } - - pub async fn add_route( + pub fn add_route( &self, route: IpNetwork, callbacks: &impl Callbacks, - ) -> Result>)>> { + ) -> Result> { // This will always be None in macos callbacks.on_add_route(route)?; Ok(None) } - pub async fn up(&self) -> Result<()> { - Ok(()) + pub fn name(&self) -> &str { + self.name.as_str() } } @@ -269,6 +179,38 @@ fn set_non_blocking(fd: RawFd) -> Result<()> { } } +fn read(fd: RawFd, dst: &mut [u8]) -> std::io::Result { + let mut hdr = [0u8; 4]; + + let mut iov = [ + iovec { + iov_base: hdr.as_mut_ptr() as _, + iov_len: hdr.len(), + }, + iovec { + iov_base: dst.as_mut_ptr() as _, + iov_len: dst.len(), + }, + ]; + + let mut msg_hdr = msghdr { + msg_name: std::ptr::null_mut(), + msg_namelen: 0, + msg_iov: &mut iov[0], + msg_iovlen: iov.len() as _, + msg_control: std::ptr::null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + // Safety: Within this module, the file descriptor is always valid. + match unsafe { recvmsg(fd, &mut msg_hdr, 0) } { + -1 => Err(io::Error::last_os_error()), + 0..=4 => Ok(0), + n => Ok((n - 4) as usize), + } +} + fn name(fd: RawFd) -> Result { let mut tunnel_name = [0u8; IF_NAMESIZE]; let mut tunnel_name_len = tunnel_name.len() as socklen_t; diff --git a/rust/connlib/tunnel/src/device_channel/tun_linux.rs b/rust/connlib/tunnel/src/device_channel/tun_linux.rs new file mode 100644 index 000000000..7e3a3a4b2 --- /dev/null +++ b/rust/connlib/tunnel/src/device_channel/tun_linux.rs @@ -0,0 +1,271 @@ +use crate::device_channel::ioctl; +use crate::DnsFallbackStrategy; +use connlib_shared::{messages::Interface as InterfaceConfig, Callbacks, Error, Result}; +use futures::TryStreamExt; +use futures_util::future::BoxFuture; +use futures_util::FutureExt; +use ip_network::IpNetwork; +use libc::{ + close, fcntl, open, F_GETFL, F_SETFL, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, O_NONBLOCK, O_RDWR, +}; +use netlink_packet_route::RT_SCOPE_UNIVERSE; +use parking_lot::Mutex; +use rtnetlink::{new_connection, Error::NetlinkError, Handle}; +use std::task::{Context, Poll}; +use std::{ + fmt, io, + os::fd::{AsRawFd, RawFd}, +}; +use tokio::io::unix::AsyncFd; + +mod utils; + +pub(crate) const SIOCGIFMTU: libc::c_ulong = libc::SIOCGIFMTU; + +const IFACE_NAME: &str = "tun-firezone"; +const TUNSETIFF: u64 = 0x4004_54ca; +const TUN_FILE: &[u8] = b"/dev/net/tun\0"; +const RT_PROT_STATIC: u8 = 4; +const DEFAULT_MTU: u32 = 1280; +const FILE_ALREADY_EXISTS: i32 = -17; + +pub struct Tun { + handle: Handle, + connection: tokio::task::JoinHandle<()>, + fd: AsyncFd, + + worker: Mutex>>>, +} + +impl fmt::Debug for Tun { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Tun") + .field("handle", &self.handle) + .field("connection", &self.connection) + .field("fd", &self.fd) + .finish_non_exhaustive() + } +} + +impl Drop for Tun { + fn drop(&mut self) { + unsafe { close(self.fd.as_raw_fd()) }; + self.connection.abort(); + } +} + +impl Tun { + pub fn write4(&self, buf: &[u8]) -> io::Result { + write(self.fd.as_raw_fd(), buf) + } + + pub fn write6(&self, buf: &[u8]) -> io::Result { + write(self.fd.as_raw_fd(), buf) + } + + pub fn poll_read(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll> { + let mut guard = self.worker.lock(); + if let Some(worker) = guard.as_mut() { + match worker.poll_unpin(cx) { + Poll::Ready(Ok(())) => { + *guard = None; + } + Poll::Ready(Err(e)) => { + *guard = None; + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))); + } + Poll::Pending => return Poll::Pending, + } + } + + utils::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx) + } + + pub fn new( + config: &InterfaceConfig, + _: &impl Callbacks, + _: DnsFallbackStrategy, + ) -> Result { + let fd = match unsafe { open(TUN_FILE.as_ptr() as _, O_RDWR) } { + -1 => return Err(get_last_error()), + fd => fd, + }; + + // Safety: We just opened the file descriptor. + unsafe { + ioctl::exec(fd, TUNSETIFF, &ioctl::Request::::new())?; + } + + set_non_blocking(fd)?; + + let (connection, handle, _) = new_connection()?; + let join_handle = tokio::spawn(connection); + + Ok(Self { + handle: handle.clone(), + connection: join_handle, + fd: AsyncFd::new(fd)?, + worker: Mutex::new(Some(set_iface_config(config.clone(), handle).boxed())), + }) + } + + pub fn add_route(&self, route: IpNetwork, _: &impl Callbacks) -> Result> { + let handle = self.handle.clone(); + + let add_route_worker = async move { + let index = handle + .link() + .get() + .match_name(IFACE_NAME.to_string()) + .execute() + .try_next() + .await? + .ok_or(Error::NoIface)? + .header + .index; + + let req = handle + .route() + .add() + .output_interface(index) + .protocol(RT_PROT_STATIC) + .scope(RT_SCOPE_UNIVERSE); + let res = match route { + IpNetwork::V4(ipnet) => { + req.v4() + .destination_prefix(ipnet.network_address(), ipnet.netmask()) + .execute() + .await + } + IpNetwork::V6(ipnet) => { + req.v6() + .destination_prefix(ipnet.network_address(), ipnet.netmask()) + .execute() + .await + } + }; + + match res { + Ok(_) => Ok(()), + Err(NetlinkError(err)) if err.raw_code() == FILE_ALREADY_EXISTS => Ok(()), + Err(err) => Err(err.into()), + } + }; + + let mut guard = self.worker.lock(); + match guard.take() { + None => *guard = Some(add_route_worker.boxed()), + Some(current_worker) => { + *guard = Some( + async move { + current_worker.await?; + add_route_worker.await?; + + Ok(()) + } + .boxed(), + ) + } + } + + Ok(None) + } + + pub fn name(&self) -> &str { + IFACE_NAME + } +} + +#[tracing::instrument(level = "trace", skip(handle))] +async fn set_iface_config(config: InterfaceConfig, handle: Handle) -> Result<()> { + let index = handle + .link() + .get() + .match_name(IFACE_NAME.to_string()) + .execute() + .try_next() + .await? + .ok_or(Error::NoIface)? + .header + .index; + + let ips = handle + .address() + .get() + .set_link_index_filter(index) + .execute(); + + ips.try_for_each(|ip| handle.address().del(ip).execute()) + .await?; + + handle.link().set(index).mtu(DEFAULT_MTU).execute().await?; + + let res_v4 = handle + .address() + .add(index, config.ipv4.into(), 32) + .execute() + .await; + let res_v6 = handle + .address() + .add(index, config.ipv6.into(), 128) + .execute() + .await; + + handle.link().set(index).up().execute().await?; + + Ok(res_v4.or(res_v6)?) +} + +fn get_last_error() -> Error { + Error::Io(io::Error::last_os_error()) +} + +fn set_non_blocking(fd: RawFd) -> Result<()> { + match unsafe { fcntl(fd, F_GETFL) } { + -1 => Err(get_last_error()), + flags => match unsafe { fcntl(fd, F_SETFL, flags | O_NONBLOCK) } { + -1 => Err(get_last_error()), + _ => Ok(()), + }, + } +} + +/// Read from the given file descriptor in the buffer. +fn read(fd: RawFd, dst: &mut [u8]) -> io::Result { + // Safety: Within this module, the file descriptor is always valid. + match unsafe { libc::read(fd, dst.as_mut_ptr() as _, dst.len()) } { + -1 => Err(io::Error::last_os_error()), + n => Ok(n as usize), + } +} + +/// Write the buffer to the given file descriptor. +fn write(fd: RawFd, buf: &[u8]) -> io::Result { + // Safety: Within this module, the file descriptor is always valid. + match unsafe { libc::write(fd, buf.as_ptr() as _, buf.len() as _) } { + -1 => Err(io::Error::last_os_error()), + n => Ok(n as usize), + } +} + +impl ioctl::Request { + fn new() -> Self { + let name_as_bytes = IFACE_NAME.as_bytes(); + debug_assert!(name_as_bytes.len() < libc::IF_NAMESIZE); + + let mut name = [0u8; libc::IF_NAMESIZE]; + name[..name_as_bytes.len()].copy_from_slice(name_as_bytes); + + Self { + name, + payload: SetTunFlagsPayload { + flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _, + }, + } + } +} + +#[repr(C)] +struct SetTunFlagsPayload { + flags: std::ffi::c_short, +} diff --git a/rust/connlib/tunnel/src/device_channel/tun_windows.rs b/rust/connlib/tunnel/src/device_channel/tun_windows.rs new file mode 100644 index 000000000..49fd3e60b --- /dev/null +++ b/rust/connlib/tunnel/src/device_channel/tun_windows.rs @@ -0,0 +1,15 @@ +pub struct Tun {} + +impl Tun { + pub fn new() -> Self { + Self {} + } + + pub fn write4(&self, _: &[u8]) -> std::io::Result { + Ok(0) + } + + pub fn write6(&self, _: &[u8]) -> std::io::Result { + Ok(0) + } +} diff --git a/rust/connlib/tunnel/src/device_channel/utils.rs b/rust/connlib/tunnel/src/device_channel/utils.rs new file mode 100644 index 000000000..736671ac7 --- /dev/null +++ b/rust/connlib/tunnel/src/device_channel/utils.rs @@ -0,0 +1,26 @@ +use std::io; +use std::os::fd::{AsRawFd, RawFd}; +use std::task::{Context, Poll}; +use tokio::io::Ready; + +#[cfg(target_family = "unix")] +pub fn poll_raw_fd( + fd: &tokio::io::unix::AsyncFd, + mut read: impl FnMut(RawFd) -> io::Result, + cx: &mut Context<'_>, +) -> Poll> { + loop { + let mut guard = std::task::ready!(fd.poll_read_ready(cx))?; + + match read(guard.get_inner().as_raw_fd()) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // a read has blocked, but a write might still succeed. + // clear only the read readiness. + guard.clear_ready_matching(Ready::READABLE); + continue; + } + Err(e) => return Poll::Ready(Err(e)), + } + } +} diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index b78b95a4f..4ff426a0a 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,4 +1,4 @@ -use crate::device_channel::create_iface; +use crate::device_channel::Device; use crate::peer::PacketTransformGateway; use crate::{ ConnectedPeer, DnsFallbackStrategy, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, @@ -22,10 +22,13 @@ where { /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] - pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { + pub fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { // Note: the dns fallback strategy is irrelevant for gateways - let device = - Arc::new(create_iface(config, self.callbacks(), DnsFallbackStrategy::default()).await?); + let device = Arc::new(Device::new( + config, + self.callbacks(), + DnsFallbackStrategy::default(), + )?); self.device.store(Some(device.clone())); self.no_device_waker.wake(); diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 1bf638346..52e4e3356 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -398,16 +398,10 @@ where continue; }; - tokio::spawn({ - let callbacks = self.callbacks.clone(); - - async move { - if let Err(e) = device.refresh_mtu().await { - tracing::error!(error = ?e, "refresh_mtu"); - let _ = callbacks.on_error(&e); - } - } - }); + if let Err(e) = device.refresh_mtu() { + tracing::error!(error = ?e, "refresh_mtu"); + let _ = self.callbacks.on_error(&e); + } } if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) { diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 13271984d..c6199d189 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -65,7 +65,6 @@ async fn run( tunnel .set_interface(&init.interface) - .await .context("Failed to set interface")?; let mut eventloop = Eventloop::new(tunnel, portal);