From 879a9019b3ed3f1e6b4501f5729eea8e5f54f0a2 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 12 Mar 2024 08:04:18 +1100 Subject: [PATCH] refactor(connlib): split `Device` creation from initialization (#4069) This reduces the amount of boilerplate required in the `Tunnel`'s eventloop. It also makes re-initialization of a `Device` much easier. --- rust/connlib/tunnel/src/client.rs | 20 +-- .../tunnel/src/control_protocol/client.rs | 30 +++-- rust/connlib/tunnel/src/device_channel.rs | 118 +++++++++++------- rust/connlib/tunnel/src/gateway.rs | 17 +-- rust/connlib/tunnel/src/lib.rs | 42 +++---- 5 files changed, 117 insertions(+), 110 deletions(-) diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 9d019968a..5e69da8c7 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,4 +1,3 @@ -use crate::device_channel::Device; use crate::ip_packet::{IpPacket, MutableIpPacket}; use crate::peer::PacketTransformClient; use crate::peer_store::PeerStore; @@ -170,17 +169,14 @@ where config: &InterfaceConfig, dns_mapping: BiMap, ) -> connlib_shared::Result<()> { - let device = Device::new( + self.device.initialize( config, // We can just sort in here because sentinel ips are created in order dns_mapping.left_values().copied().sorted().collect(), - self.callbacks(), + &self.callbacks().clone(), )?; - let name = device.name().to_owned(); - - self.device = Some(device); - self.no_device_waker.wake(); + let name = self.device.name().to_owned(); let mut errs = Vec::new(); for sentinel in dns_mapping.left_values() { @@ -217,10 +213,7 @@ where #[tracing::instrument(level = "trace", skip(self))] pub fn add_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> { let callbacks = self.callbacks().clone(); - self.device - .as_mut() - .ok_or(Error::ControlProtocolError)? - .add_route(route, &callbacks)?; + self.device.add_route(route, &callbacks)?; Ok(()) } @@ -228,10 +221,7 @@ where #[tracing::instrument(level = "trace", skip(self))] pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> { let callbacks = self.callbacks().clone(); - self.device - .as_mut() - .ok_or(Error::ControlProtocolError)? - .remove_route(route, &callbacks)?; + self.device.remove_route(route, &callbacks)?; Ok(()) } diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 4f9193207..c1f567296 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -206,23 +206,21 @@ where let ips: Vec = addrs.iter().copied().map(Into::into).collect(); - if let Some(device) = self.device.as_ref() { - send_dns_answer( - &mut self.role_state, - Rtype::Aaaa, - device, - &resource_description, - &addrs, - ); + send_dns_answer( + &mut self.role_state, + Rtype::Aaaa, + &self.device, + &resource_description, + &addrs, + ); - send_dns_answer( - &mut self.role_state, - Rtype::A, - device, - &resource_description, - &addrs, - ); - } + send_dns_answer( + &mut self.role_state, + Rtype::A, + &self.device, + &resource_description, + &addrs, + ); Ok(ips) } diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index cbe15eddf..1fe4c2be2 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -26,44 +26,63 @@ use ip_network::IpNetwork; use pnet_packet::Packet; use std::io; use std::net::IpAddr; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, Waker}; use std::time::{Duration, Instant}; use tun::Tun; pub struct Device { mtu: usize, - tun: Tun, + tun: Option, + waker: Option, mtu_refreshed_at: Instant, } impl Device { + pub(crate) fn new() -> Self { + Self { + tun: None, + mtu: 1_280, + waker: None, + mtu_refreshed_at: Instant::now(), + } + } + #[cfg(target_family = "unix")] - pub(crate) fn new( + pub(crate) fn initialize( + &mut self, config: &Interface, dns_config: Vec, callbacks: &impl Callbacks, - ) -> Result { + ) -> Result<(), ConnlibError> { let tun = Tun::new(config, dns_config, callbacks)?; let mtu = ioctl::interface_mtu_by_name(tun.name())?; - Ok(Device { - mtu, - tun, - mtu_refreshed_at: Instant::now(), - }) + self.tun = Some(tun); + self.mtu = mtu; + + if let Some(waker) = self.waker.take() { + waker.wake(); + } + + Ok(()) } #[cfg(target_family = "windows")] - pub(crate) fn new( + pub(crate) fn initialize( + &mut self, config: &Interface, dns_config: Vec, _: &impl Callbacks, - ) -> Result { - Ok(Device { - tun: Tun::new(config, dns_config)?, - mtu: 1_280, - mtu_refreshed_at: Instant::now(), - }) + ) -> Result<(), ConnlibError> { + let tun = Tun::new(config, dns_config)?; + + self.tun = Some(tun); + + if let Some(waker) = self.waker.take() { + waker.wake(); + } + + Ok(()) } #[cfg(target_family = "unix")] @@ -72,13 +91,20 @@ impl Device { buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>> { + let Some(tun) = self.tun.as_mut() else { + self.waker = Some(cx.waker().clone()); + return Poll::Pending; + }; + use pnet_packet::Packet as _; if self.mtu_refreshed_at.elapsed() > Duration::from_secs(30) { - self.refresh_mtu()?; + let mtu = ioctl::interface_mtu_by_name(tun.name())?; + self.mtu = mtu; + self.mtu_refreshed_at = Instant::now(); } - let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?; + let n = std::task::ready!(tun.poll_read(&mut buf[..self.mtu], cx))?; if n == 0 { return Poll::Ready(Err(io::Error::new( @@ -101,17 +127,22 @@ impl Device { #[cfg(target_family = "windows")] pub(crate) fn poll_read<'b>( - &self, + &mut self, buf: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>> { + let Some(tun) = self.tun.as_mut() else { + self.waker = Some(cx.waker().clone()); + return Poll::Pending; + }; + use pnet_packet::Packet as _; if self.mtu_refreshed_at.elapsed() > Duration::from_secs(30) { - self.refresh_mtu()?; + // TODO } - let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?; + let n = std::task::ready!(tun.poll_read(&mut buf[..self.mtu], cx))?; if n == 0 { return Poll::Ready(Err(io::Error::new( @@ -132,12 +163,11 @@ impl Device { Poll::Ready(Ok(packet)) } - pub(crate) fn mtu(&self) -> usize { - self.mtu - } - pub(crate) fn name(&self) -> &str { - self.tun.name() + self.tun + .as_ref() + .map(|t| t.name()) + .unwrap_or("uninitialized") } pub(crate) fn remove_route( @@ -145,7 +175,8 @@ impl Device { route: IpNetwork, callbacks: &impl Callbacks, ) -> Result, Error> { - self.tun.remove_route(route, callbacks)?; + self.tun_mut()?.remove_route(route, callbacks)?; + Ok(None) } @@ -155,33 +186,30 @@ impl Device { route: IpNetwork, callbacks: &impl Callbacks, ) -> Result, Error> { - self.tun.add_route(route, callbacks)?; + self.tun_mut()?.add_route(route, callbacks)?; Ok(None) } - #[cfg(target_family = "unix")] - fn refresh_mtu(&mut self) -> io::Result<()> { - let mtu = ioctl::interface_mtu_by_name(self.tun.name())?; - self.mtu = mtu; - self.mtu_refreshed_at = Instant::now(); - - Ok(()) - } - - #[cfg(target_family = "windows")] - fn refresh_mtu(&self) -> io::Result<()> { - // TODO - Ok(()) - } - pub fn write(&self, packet: IpPacket<'_>) -> io::Result { tracing::trace!(target: "wire", to = "device", bytes = %packet.packet().len()); match packet { - IpPacket::Ipv4Packet(msg) => self.tun.write4(msg.packet()), - IpPacket::Ipv6Packet(msg) => self.tun.write6(msg.packet()), + IpPacket::Ipv4Packet(msg) => self.tun()?.write4(msg.packet()), + IpPacket::Ipv6Packet(msg) => self.tun()?.write6(msg.packet()), } } + + fn tun(&self) -> io::Result<&Tun> { + self.tun.as_ref().ok_or_else(io_error_not_initialized) + } + + fn tun_mut(&mut self) -> io::Result<&mut Tun> { + self.tun.as_mut().ok_or_else(io_error_not_initialized) + } +} + +fn io_error_not_initialized() -> io::Error { + io::Error::new(io::ErrorKind::NotConnected, "device is not initialized yet") } #[cfg(target_family = "unix")] diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index a08a3628f..afb4904f5 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,4 +1,3 @@ -use crate::device_channel::Device; use crate::ip_packet::MutableIpPacket; use crate::peer::{PacketTransformGateway, Peer}; use crate::peer_store::PeerStore; @@ -46,16 +45,18 @@ where #[tracing::instrument(level = "trace", skip(self))] pub fn set_interface(&mut self, config: &InterfaceConfig) -> connlib_shared::Result<()> { // Note: the dns fallback strategy is irrelevant for gateways - let mut device = Device::new(config, vec![], self.callbacks())?; + self.device + .initialize(config, vec![], &self.callbacks().clone())?; - let result_v4 = device.add_route(PEERS_IPV4.parse().unwrap(), self.callbacks()); - let result_v6 = device.add_route(PEERS_IPV6.parse().unwrap(), self.callbacks()); + let result_v4 = self + .device + .add_route(PEERS_IPV4.parse().unwrap(), &self.callbacks().clone()); + let result_v6 = self + .device + .add_route(PEERS_IPV6.parse().unwrap(), &self.callbacks().clone()); result_v4.or(result_v6)?; - let name = device.name().to_owned(); - - self.device = Some(device); - self.no_device_waker.wake(); + let name = self.device.name().to_owned(); tracing::debug!(ip4 = %config.ipv4, ip6 = %config.ipv6, %name, "TUN device initialized"); diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index d17e69331..1b1b71b88 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -9,7 +9,7 @@ use connlib_shared::{ CallbackErrorFacade, Callbacks, Error, Result, }; use device_channel::Device; -use futures_util::{task::AtomicWaker, FutureExt}; +use futures_util::FutureExt; use peer::PacketTransform; use peer_store::PeerStore; use snownet::{Node, Server}; @@ -60,8 +60,7 @@ pub struct Tunnel { /// State that differs per role, i.e. clients vs gateways. role_state: TRoleState, - device: Option, - no_device_waker: AtomicWaker, + device: Device, connections_state: ConnectionState, @@ -73,14 +72,9 @@ where CB: Callbacks + 'static, { pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>> { - let Some(device) = self.device.as_mut() else { - self.no_device_waker.register(cx.waker()); - return Poll::Pending; - }; - match self.role_state.poll_next_event(cx) { Poll::Ready(Event::SendPacket(packet)) => { - device.write(packet)?; + self.device.write(packet)?; cx.waker().wake_by_ref(); } Poll::Ready(other) => return Poll::Ready(Ok(other)), @@ -96,10 +90,11 @@ where _ => (), } - match self - .connections_state - .poll_sockets(device, &mut self.role_state.peers, cx)? - { + match self.connections_state.poll_sockets( + &mut self.device, + &mut self.role_state.peers, + cx, + )? { Poll::Ready(()) => { cx.waker().wake_by_ref(); } @@ -108,7 +103,7 @@ where ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device. - match device.poll_read(&mut self.read_buf, cx)? { + match self.device.poll_read(&mut self.read_buf, cx)? { Poll::Ready(packet) => { let Some((peer_id, packet)) = self.role_state.encapsulate(packet, Instant::now()) else { @@ -144,11 +139,6 @@ where Poll::Pending => {} } - let Some(device) = self.device.as_mut() else { - self.no_device_waker.register(cx.waker()); - return Poll::Pending; - }; - match self.connections_state.poll_next_event(cx) { Poll::Ready(Event::StopPeer(id)) => { self.role_state.peers.remove(&id); @@ -158,10 +148,11 @@ where _ => (), } - match self - .connections_state - .poll_sockets(device, &mut self.role_state.peers, cx)? - { + match self.connections_state.poll_sockets( + &mut self.device, + &mut self.role_state.peers, + cx, + )? { Poll::Ready(()) => { cx.waker().wake_by_ref(); } @@ -170,7 +161,7 @@ where ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device. - match device.poll_read(&mut self.read_buf, cx)? { + match self.device.poll_read(&mut self.read_buf, cx)? { Poll::Ready(packet) => { let Some((peer_id, packet)) = self.role_state.encapsulate(packet) else { cx.waker().wake_by_ref(); @@ -223,10 +214,9 @@ where } Ok(Self { - device: Default::default(), + device: Device::new(), callbacks, role_state: Default::default(), - no_device_waker: Default::default(), connections_state, read_buf: [0u8; MAX_UDP_SIZE], })