refactor(connlib): split Device creation from initialization (#4069)

This reduces the amount of boilerplate required in the `Tunnel`'s
eventloop. It also makes re-initialization of a `Device` much easier.
This commit is contained in:
Thomas Eizinger
2024-03-12 08:04:18 +11:00
committed by GitHub
parent dde8b646f0
commit 879a9019b3
5 changed files with 117 additions and 110 deletions

View File

@@ -1,4 +1,3 @@
use crate::device_channel::Device;
use crate::ip_packet::{IpPacket, MutableIpPacket};
use crate::peer::PacketTransformClient;
use crate::peer_store::PeerStore;
@@ -170,17 +169,14 @@ where
config: &InterfaceConfig,
dns_mapping: BiMap<IpAddr, DnsServer>,
) -> connlib_shared::Result<()> {
let device = Device::new(
self.device.initialize(
config,
// We can just sort in here because sentinel ips are created in order
dns_mapping.left_values().copied().sorted().collect(),
self.callbacks(),
&self.callbacks().clone(),
)?;
let name = device.name().to_owned();
self.device = Some(device);
self.no_device_waker.wake();
let name = self.device.name().to_owned();
let mut errs = Vec::new();
for sentinel in dns_mapping.left_values() {
@@ -217,10 +213,7 @@ where
#[tracing::instrument(level = "trace", skip(self))]
pub fn add_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> {
let callbacks = self.callbacks().clone();
self.device
.as_mut()
.ok_or(Error::ControlProtocolError)?
.add_route(route, &callbacks)?;
self.device.add_route(route, &callbacks)?;
Ok(())
}
@@ -228,10 +221,7 @@ where
#[tracing::instrument(level = "trace", skip(self))]
pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> {
let callbacks = self.callbacks().clone();
self.device
.as_mut()
.ok_or(Error::ControlProtocolError)?
.remove_route(route, &callbacks)?;
self.device.remove_route(route, &callbacks)?;
Ok(())
}

View File

@@ -206,23 +206,21 @@ where
let ips: Vec<IpNetwork> = addrs.iter().copied().map(Into::into).collect();
if let Some(device) = self.device.as_ref() {
send_dns_answer(
&mut self.role_state,
Rtype::Aaaa,
device,
&resource_description,
&addrs,
);
send_dns_answer(
&mut self.role_state,
Rtype::Aaaa,
&self.device,
&resource_description,
&addrs,
);
send_dns_answer(
&mut self.role_state,
Rtype::A,
device,
&resource_description,
&addrs,
);
}
send_dns_answer(
&mut self.role_state,
Rtype::A,
&self.device,
&resource_description,
&addrs,
);
Ok(ips)
}

View File

@@ -26,44 +26,63 @@ use ip_network::IpNetwork;
use pnet_packet::Packet;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};
use tun::Tun;
pub struct Device {
mtu: usize,
tun: Tun,
tun: Option<Tun>,
waker: Option<Waker>,
mtu_refreshed_at: Instant,
}
impl Device {
pub(crate) fn new() -> Self {
Self {
tun: None,
mtu: 1_280,
waker: None,
mtu_refreshed_at: Instant::now(),
}
}
#[cfg(target_family = "unix")]
pub(crate) fn new(
pub(crate) fn initialize(
&mut self,
config: &Interface,
dns_config: Vec<IpAddr>,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Device, ConnlibError> {
) -> Result<(), ConnlibError> {
let tun = Tun::new(config, dns_config, callbacks)?;
let mtu = ioctl::interface_mtu_by_name(tun.name())?;
Ok(Device {
mtu,
tun,
mtu_refreshed_at: Instant::now(),
})
self.tun = Some(tun);
self.mtu = mtu;
if let Some(waker) = self.waker.take() {
waker.wake();
}
Ok(())
}
#[cfg(target_family = "windows")]
pub(crate) fn new(
pub(crate) fn initialize(
&mut self,
config: &Interface,
dns_config: Vec<IpAddr>,
_: &impl Callbacks<Error = Error>,
) -> Result<Device, ConnlibError> {
Ok(Device {
tun: Tun::new(config, dns_config)?,
mtu: 1_280,
mtu_refreshed_at: Instant::now(),
})
) -> Result<(), ConnlibError> {
let tun = Tun::new(config, dns_config)?;
self.tun = Some(tun);
if let Some(waker) = self.waker.take() {
waker.wake();
}
Ok(())
}
#[cfg(target_family = "unix")]
@@ -72,13 +91,20 @@ impl Device {
buf: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<MutableIpPacket<'b>>> {
let Some(tun) = self.tun.as_mut() else {
self.waker = Some(cx.waker().clone());
return Poll::Pending;
};
use pnet_packet::Packet as _;
if self.mtu_refreshed_at.elapsed() > Duration::from_secs(30) {
self.refresh_mtu()?;
let mtu = ioctl::interface_mtu_by_name(tun.name())?;
self.mtu = mtu;
self.mtu_refreshed_at = Instant::now();
}
let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?;
let n = std::task::ready!(tun.poll_read(&mut buf[..self.mtu], cx))?;
if n == 0 {
return Poll::Ready(Err(io::Error::new(
@@ -101,17 +127,22 @@ impl Device {
#[cfg(target_family = "windows")]
pub(crate) fn poll_read<'b>(
&self,
&mut self,
buf: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<MutableIpPacket<'b>>> {
let Some(tun) = self.tun.as_mut() else {
self.waker = Some(cx.waker().clone());
return Poll::Pending;
};
use pnet_packet::Packet as _;
if self.mtu_refreshed_at.elapsed() > Duration::from_secs(30) {
self.refresh_mtu()?;
// TODO
}
let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?;
let n = std::task::ready!(tun.poll_read(&mut buf[..self.mtu], cx))?;
if n == 0 {
return Poll::Ready(Err(io::Error::new(
@@ -132,12 +163,11 @@ impl Device {
Poll::Ready(Ok(packet))
}
pub(crate) fn mtu(&self) -> usize {
self.mtu
}
pub(crate) fn name(&self) -> &str {
self.tun.name()
self.tun
.as_ref()
.map(|t| t.name())
.unwrap_or("uninitialized")
}
pub(crate) fn remove_route(
@@ -145,7 +175,8 @@ impl Device {
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>, Error> {
self.tun.remove_route(route, callbacks)?;
self.tun_mut()?.remove_route(route, callbacks)?;
Ok(None)
}
@@ -155,33 +186,30 @@ impl Device {
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>, Error> {
self.tun.add_route(route, callbacks)?;
self.tun_mut()?.add_route(route, callbacks)?;
Ok(None)
}
#[cfg(target_family = "unix")]
fn refresh_mtu(&mut self) -> io::Result<()> {
let mtu = ioctl::interface_mtu_by_name(self.tun.name())?;
self.mtu = mtu;
self.mtu_refreshed_at = Instant::now();
Ok(())
}
#[cfg(target_family = "windows")]
fn refresh_mtu(&self) -> io::Result<()> {
// TODO
Ok(())
}
pub fn write(&self, packet: IpPacket<'_>) -> io::Result<usize> {
tracing::trace!(target: "wire", to = "device", bytes = %packet.packet().len());
match packet {
IpPacket::Ipv4Packet(msg) => self.tun.write4(msg.packet()),
IpPacket::Ipv6Packet(msg) => self.tun.write6(msg.packet()),
IpPacket::Ipv4Packet(msg) => self.tun()?.write4(msg.packet()),
IpPacket::Ipv6Packet(msg) => self.tun()?.write6(msg.packet()),
}
}
fn tun(&self) -> io::Result<&Tun> {
self.tun.as_ref().ok_or_else(io_error_not_initialized)
}
fn tun_mut(&mut self) -> io::Result<&mut Tun> {
self.tun.as_mut().ok_or_else(io_error_not_initialized)
}
}
fn io_error_not_initialized() -> io::Error {
io::Error::new(io::ErrorKind::NotConnected, "device is not initialized yet")
}
#[cfg(target_family = "unix")]

View File

@@ -1,4 +1,3 @@
use crate::device_channel::Device;
use crate::ip_packet::MutableIpPacket;
use crate::peer::{PacketTransformGateway, Peer};
use crate::peer_store::PeerStore;
@@ -46,16 +45,18 @@ where
#[tracing::instrument(level = "trace", skip(self))]
pub fn set_interface(&mut self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
// Note: the dns fallback strategy is irrelevant for gateways
let mut device = Device::new(config, vec![], self.callbacks())?;
self.device
.initialize(config, vec![], &self.callbacks().clone())?;
let result_v4 = device.add_route(PEERS_IPV4.parse().unwrap(), self.callbacks());
let result_v6 = device.add_route(PEERS_IPV6.parse().unwrap(), self.callbacks());
let result_v4 = self
.device
.add_route(PEERS_IPV4.parse().unwrap(), &self.callbacks().clone());
let result_v6 = self
.device
.add_route(PEERS_IPV6.parse().unwrap(), &self.callbacks().clone());
result_v4.or(result_v6)?;
let name = device.name().to_owned();
self.device = Some(device);
self.no_device_waker.wake();
let name = self.device.name().to_owned();
tracing::debug!(ip4 = %config.ipv4, ip6 = %config.ipv6, %name, "TUN device initialized");

View File

@@ -9,7 +9,7 @@ use connlib_shared::{
CallbackErrorFacade, Callbacks, Error, Result,
};
use device_channel::Device;
use futures_util::{task::AtomicWaker, FutureExt};
use futures_util::FutureExt;
use peer::PacketTransform;
use peer_store::PeerStore;
use snownet::{Node, Server};
@@ -60,8 +60,7 @@ pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId> {
/// State that differs per role, i.e. clients vs gateways.
role_state: TRoleState,
device: Option<Device>,
no_device_waker: AtomicWaker,
device: Device,
connections_state: ConnectionState<TRole, TId>,
@@ -73,14 +72,9 @@ where
CB: Callbacks + 'static,
{
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<Event<GatewayId>>> {
let Some(device) = self.device.as_mut() else {
self.no_device_waker.register(cx.waker());
return Poll::Pending;
};
match self.role_state.poll_next_event(cx) {
Poll::Ready(Event::SendPacket(packet)) => {
device.write(packet)?;
self.device.write(packet)?;
cx.waker().wake_by_ref();
}
Poll::Ready(other) => return Poll::Ready(Ok(other)),
@@ -96,10 +90,11 @@ where
_ => (),
}
match self
.connections_state
.poll_sockets(device, &mut self.role_state.peers, cx)?
{
match self.connections_state.poll_sockets(
&mut self.device,
&mut self.role_state.peers,
cx,
)? {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
}
@@ -108,7 +103,7 @@ where
ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device.
match device.poll_read(&mut self.read_buf, cx)? {
match self.device.poll_read(&mut self.read_buf, cx)? {
Poll::Ready(packet) => {
let Some((peer_id, packet)) = self.role_state.encapsulate(packet, Instant::now())
else {
@@ -144,11 +139,6 @@ where
Poll::Pending => {}
}
let Some(device) = self.device.as_mut() else {
self.no_device_waker.register(cx.waker());
return Poll::Pending;
};
match self.connections_state.poll_next_event(cx) {
Poll::Ready(Event::StopPeer(id)) => {
self.role_state.peers.remove(&id);
@@ -158,10 +148,11 @@ where
_ => (),
}
match self
.connections_state
.poll_sockets(device, &mut self.role_state.peers, cx)?
{
match self.connections_state.poll_sockets(
&mut self.device,
&mut self.role_state.peers,
cx,
)? {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
}
@@ -170,7 +161,7 @@ where
ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device.
match device.poll_read(&mut self.read_buf, cx)? {
match self.device.poll_read(&mut self.read_buf, cx)? {
Poll::Ready(packet) => {
let Some((peer_id, packet)) = self.role_state.encapsulate(packet) else {
cx.waker().wake_by_ref();
@@ -223,10 +214,9 @@ where
}
Ok(Self {
device: Default::default(),
device: Device::new(),
callbacks,
role_state: Default::default(),
no_device_waker: Default::default(),
connections_state,
read_buf: [0u8; MAX_UDP_SIZE],
})