mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): split Device creation from initialization (#4069)
This reduces the amount of boilerplate required in the `Tunnel`'s eventloop. It also makes re-initialization of a `Device` much easier.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
use crate::device_channel::Device;
|
||||
use crate::ip_packet::{IpPacket, MutableIpPacket};
|
||||
use crate::peer::PacketTransformClient;
|
||||
use crate::peer_store::PeerStore;
|
||||
@@ -170,17 +169,14 @@ where
|
||||
config: &InterfaceConfig,
|
||||
dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
) -> connlib_shared::Result<()> {
|
||||
let device = Device::new(
|
||||
self.device.initialize(
|
||||
config,
|
||||
// We can just sort in here because sentinel ips are created in order
|
||||
dns_mapping.left_values().copied().sorted().collect(),
|
||||
self.callbacks(),
|
||||
&self.callbacks().clone(),
|
||||
)?;
|
||||
|
||||
let name = device.name().to_owned();
|
||||
|
||||
self.device = Some(device);
|
||||
self.no_device_waker.wake();
|
||||
let name = self.device.name().to_owned();
|
||||
|
||||
let mut errs = Vec::new();
|
||||
for sentinel in dns_mapping.left_values() {
|
||||
@@ -217,10 +213,7 @@ where
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn add_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> {
|
||||
let callbacks = self.callbacks().clone();
|
||||
self.device
|
||||
.as_mut()
|
||||
.ok_or(Error::ControlProtocolError)?
|
||||
.add_route(route, &callbacks)?;
|
||||
self.device.add_route(route, &callbacks)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -228,10 +221,7 @@ where
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> {
|
||||
let callbacks = self.callbacks().clone();
|
||||
self.device
|
||||
.as_mut()
|
||||
.ok_or(Error::ControlProtocolError)?
|
||||
.remove_route(route, &callbacks)?;
|
||||
self.device.remove_route(route, &callbacks)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -206,23 +206,21 @@ where
|
||||
|
||||
let ips: Vec<IpNetwork> = addrs.iter().copied().map(Into::into).collect();
|
||||
|
||||
if let Some(device) = self.device.as_ref() {
|
||||
send_dns_answer(
|
||||
&mut self.role_state,
|
||||
Rtype::Aaaa,
|
||||
device,
|
||||
&resource_description,
|
||||
&addrs,
|
||||
);
|
||||
send_dns_answer(
|
||||
&mut self.role_state,
|
||||
Rtype::Aaaa,
|
||||
&self.device,
|
||||
&resource_description,
|
||||
&addrs,
|
||||
);
|
||||
|
||||
send_dns_answer(
|
||||
&mut self.role_state,
|
||||
Rtype::A,
|
||||
device,
|
||||
&resource_description,
|
||||
&addrs,
|
||||
);
|
||||
}
|
||||
send_dns_answer(
|
||||
&mut self.role_state,
|
||||
Rtype::A,
|
||||
&self.device,
|
||||
&resource_description,
|
||||
&addrs,
|
||||
);
|
||||
|
||||
Ok(ips)
|
||||
}
|
||||
|
||||
@@ -26,44 +26,63 @@ use ip_network::IpNetwork;
|
||||
use pnet_packet::Packet;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use std::time::{Duration, Instant};
|
||||
use tun::Tun;
|
||||
|
||||
pub struct Device {
|
||||
mtu: usize,
|
||||
tun: Tun,
|
||||
tun: Option<Tun>,
|
||||
waker: Option<Waker>,
|
||||
mtu_refreshed_at: Instant,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
tun: None,
|
||||
mtu: 1_280,
|
||||
waker: None,
|
||||
mtu_refreshed_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
pub(crate) fn new(
|
||||
pub(crate) fn initialize(
|
||||
&mut self,
|
||||
config: &Interface,
|
||||
dns_config: Vec<IpAddr>,
|
||||
callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Device, ConnlibError> {
|
||||
) -> Result<(), ConnlibError> {
|
||||
let tun = Tun::new(config, dns_config, callbacks)?;
|
||||
let mtu = ioctl::interface_mtu_by_name(tun.name())?;
|
||||
|
||||
Ok(Device {
|
||||
mtu,
|
||||
tun,
|
||||
mtu_refreshed_at: Instant::now(),
|
||||
})
|
||||
self.tun = Some(tun);
|
||||
self.mtu = mtu;
|
||||
|
||||
if let Some(waker) = self.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_family = "windows")]
|
||||
pub(crate) fn new(
|
||||
pub(crate) fn initialize(
|
||||
&mut self,
|
||||
config: &Interface,
|
||||
dns_config: Vec<IpAddr>,
|
||||
_: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Device, ConnlibError> {
|
||||
Ok(Device {
|
||||
tun: Tun::new(config, dns_config)?,
|
||||
mtu: 1_280,
|
||||
mtu_refreshed_at: Instant::now(),
|
||||
})
|
||||
) -> Result<(), ConnlibError> {
|
||||
let tun = Tun::new(config, dns_config)?;
|
||||
|
||||
self.tun = Some(tun);
|
||||
|
||||
if let Some(waker) = self.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
@@ -72,13 +91,20 @@ impl Device {
|
||||
buf: &'b mut [u8],
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<MutableIpPacket<'b>>> {
|
||||
let Some(tun) = self.tun.as_mut() else {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
use pnet_packet::Packet as _;
|
||||
|
||||
if self.mtu_refreshed_at.elapsed() > Duration::from_secs(30) {
|
||||
self.refresh_mtu()?;
|
||||
let mtu = ioctl::interface_mtu_by_name(tun.name())?;
|
||||
self.mtu = mtu;
|
||||
self.mtu_refreshed_at = Instant::now();
|
||||
}
|
||||
|
||||
let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?;
|
||||
let n = std::task::ready!(tun.poll_read(&mut buf[..self.mtu], cx))?;
|
||||
|
||||
if n == 0 {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
@@ -101,17 +127,22 @@ impl Device {
|
||||
|
||||
#[cfg(target_family = "windows")]
|
||||
pub(crate) fn poll_read<'b>(
|
||||
&self,
|
||||
&mut self,
|
||||
buf: &'b mut [u8],
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<MutableIpPacket<'b>>> {
|
||||
let Some(tun) = self.tun.as_mut() else {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
use pnet_packet::Packet as _;
|
||||
|
||||
if self.mtu_refreshed_at.elapsed() > Duration::from_secs(30) {
|
||||
self.refresh_mtu()?;
|
||||
// TODO
|
||||
}
|
||||
|
||||
let n = std::task::ready!(self.tun.poll_read(&mut buf[..self.mtu()], cx))?;
|
||||
let n = std::task::ready!(tun.poll_read(&mut buf[..self.mtu], cx))?;
|
||||
|
||||
if n == 0 {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
@@ -132,12 +163,11 @@ impl Device {
|
||||
Poll::Ready(Ok(packet))
|
||||
}
|
||||
|
||||
pub(crate) fn mtu(&self) -> usize {
|
||||
self.mtu
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
self.tun.name()
|
||||
self.tun
|
||||
.as_ref()
|
||||
.map(|t| t.name())
|
||||
.unwrap_or("uninitialized")
|
||||
}
|
||||
|
||||
pub(crate) fn remove_route(
|
||||
@@ -145,7 +175,8 @@ impl Device {
|
||||
route: IpNetwork,
|
||||
callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Option<Device>, Error> {
|
||||
self.tun.remove_route(route, callbacks)?;
|
||||
self.tun_mut()?.remove_route(route, callbacks)?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
@@ -155,33 +186,30 @@ impl Device {
|
||||
route: IpNetwork,
|
||||
callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Option<Device>, Error> {
|
||||
self.tun.add_route(route, callbacks)?;
|
||||
self.tun_mut()?.add_route(route, callbacks)?;
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
fn refresh_mtu(&mut self) -> io::Result<()> {
|
||||
let mtu = ioctl::interface_mtu_by_name(self.tun.name())?;
|
||||
self.mtu = mtu;
|
||||
self.mtu_refreshed_at = Instant::now();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_family = "windows")]
|
||||
fn refresh_mtu(&self) -> io::Result<()> {
|
||||
// TODO
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write(&self, packet: IpPacket<'_>) -> io::Result<usize> {
|
||||
tracing::trace!(target: "wire", to = "device", bytes = %packet.packet().len());
|
||||
|
||||
match packet {
|
||||
IpPacket::Ipv4Packet(msg) => self.tun.write4(msg.packet()),
|
||||
IpPacket::Ipv6Packet(msg) => self.tun.write6(msg.packet()),
|
||||
IpPacket::Ipv4Packet(msg) => self.tun()?.write4(msg.packet()),
|
||||
IpPacket::Ipv6Packet(msg) => self.tun()?.write6(msg.packet()),
|
||||
}
|
||||
}
|
||||
|
||||
fn tun(&self) -> io::Result<&Tun> {
|
||||
self.tun.as_ref().ok_or_else(io_error_not_initialized)
|
||||
}
|
||||
|
||||
fn tun_mut(&mut self) -> io::Result<&mut Tun> {
|
||||
self.tun.as_mut().ok_or_else(io_error_not_initialized)
|
||||
}
|
||||
}
|
||||
|
||||
fn io_error_not_initialized() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::NotConnected, "device is not initialized yet")
|
||||
}
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::device_channel::Device;
|
||||
use crate::ip_packet::MutableIpPacket;
|
||||
use crate::peer::{PacketTransformGateway, Peer};
|
||||
use crate::peer_store::PeerStore;
|
||||
@@ -46,16 +45,18 @@ where
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn set_interface(&mut self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
|
||||
// Note: the dns fallback strategy is irrelevant for gateways
|
||||
let mut device = Device::new(config, vec![], self.callbacks())?;
|
||||
self.device
|
||||
.initialize(config, vec![], &self.callbacks().clone())?;
|
||||
|
||||
let result_v4 = device.add_route(PEERS_IPV4.parse().unwrap(), self.callbacks());
|
||||
let result_v6 = device.add_route(PEERS_IPV6.parse().unwrap(), self.callbacks());
|
||||
let result_v4 = self
|
||||
.device
|
||||
.add_route(PEERS_IPV4.parse().unwrap(), &self.callbacks().clone());
|
||||
let result_v6 = self
|
||||
.device
|
||||
.add_route(PEERS_IPV6.parse().unwrap(), &self.callbacks().clone());
|
||||
result_v4.or(result_v6)?;
|
||||
|
||||
let name = device.name().to_owned();
|
||||
|
||||
self.device = Some(device);
|
||||
self.no_device_waker.wake();
|
||||
let name = self.device.name().to_owned();
|
||||
|
||||
tracing::debug!(ip4 = %config.ipv4, ip6 = %config.ipv6, %name, "TUN device initialized");
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use connlib_shared::{
|
||||
CallbackErrorFacade, Callbacks, Error, Result,
|
||||
};
|
||||
use device_channel::Device;
|
||||
use futures_util::{task::AtomicWaker, FutureExt};
|
||||
use futures_util::FutureExt;
|
||||
use peer::PacketTransform;
|
||||
use peer_store::PeerStore;
|
||||
use snownet::{Node, Server};
|
||||
@@ -60,8 +60,7 @@ pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId> {
|
||||
/// State that differs per role, i.e. clients vs gateways.
|
||||
role_state: TRoleState,
|
||||
|
||||
device: Option<Device>,
|
||||
no_device_waker: AtomicWaker,
|
||||
device: Device,
|
||||
|
||||
connections_state: ConnectionState<TRole, TId>,
|
||||
|
||||
@@ -73,14 +72,9 @@ where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<Event<GatewayId>>> {
|
||||
let Some(device) = self.device.as_mut() else {
|
||||
self.no_device_waker.register(cx.waker());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
match self.role_state.poll_next_event(cx) {
|
||||
Poll::Ready(Event::SendPacket(packet)) => {
|
||||
device.write(packet)?;
|
||||
self.device.write(packet)?;
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
Poll::Ready(other) => return Poll::Ready(Ok(other)),
|
||||
@@ -96,10 +90,11 @@ where
|
||||
_ => (),
|
||||
}
|
||||
|
||||
match self
|
||||
.connections_state
|
||||
.poll_sockets(device, &mut self.role_state.peers, cx)?
|
||||
{
|
||||
match self.connections_state.poll_sockets(
|
||||
&mut self.device,
|
||||
&mut self.role_state.peers,
|
||||
cx,
|
||||
)? {
|
||||
Poll::Ready(()) => {
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
@@ -108,7 +103,7 @@ where
|
||||
|
||||
ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device.
|
||||
|
||||
match device.poll_read(&mut self.read_buf, cx)? {
|
||||
match self.device.poll_read(&mut self.read_buf, cx)? {
|
||||
Poll::Ready(packet) => {
|
||||
let Some((peer_id, packet)) = self.role_state.encapsulate(packet, Instant::now())
|
||||
else {
|
||||
@@ -144,11 +139,6 @@ where
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
let Some(device) = self.device.as_mut() else {
|
||||
self.no_device_waker.register(cx.waker());
|
||||
return Poll::Pending;
|
||||
};
|
||||
|
||||
match self.connections_state.poll_next_event(cx) {
|
||||
Poll::Ready(Event::StopPeer(id)) => {
|
||||
self.role_state.peers.remove(&id);
|
||||
@@ -158,10 +148,11 @@ where
|
||||
_ => (),
|
||||
}
|
||||
|
||||
match self
|
||||
.connections_state
|
||||
.poll_sockets(device, &mut self.role_state.peers, cx)?
|
||||
{
|
||||
match self.connections_state.poll_sockets(
|
||||
&mut self.device,
|
||||
&mut self.role_state.peers,
|
||||
cx,
|
||||
)? {
|
||||
Poll::Ready(()) => {
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
@@ -170,7 +161,7 @@ where
|
||||
|
||||
ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device.
|
||||
|
||||
match device.poll_read(&mut self.read_buf, cx)? {
|
||||
match self.device.poll_read(&mut self.read_buf, cx)? {
|
||||
Poll::Ready(packet) => {
|
||||
let Some((peer_id, packet)) = self.role_state.encapsulate(packet) else {
|
||||
cx.waker().wake_by_ref();
|
||||
@@ -223,10 +214,9 @@ where
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
device: Default::default(),
|
||||
device: Device::new(),
|
||||
callbacks,
|
||||
role_state: Default::default(),
|
||||
no_device_waker: Default::default(),
|
||||
connections_state,
|
||||
read_buf: [0u8; MAX_UDP_SIZE],
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user