refactor(connlib): read from device as part of eventloop (#2520)

As a next step in refactoring the tunnel implementation, I am removing
the `device_handler` task and instead using a poll-based function to
read from the device. Removing the task means there is one fewer
component that accesses the `Tunnel` via shared memory. After this PR,
the only remaining one is the `peer_handler`.

Once all shared access is gone, we can stop using `Arc<Tunnel>` and,
with it, remove all uses of `Mutex` in the tunnel and simply use
`&mut self`.

To remove the `device_handler`, we introduce a `Device::poll_read`
function that we call as the very first thing in the `Tunnel`'s poll
function. At a later point, we want to think about prioritization
within the event loop. I'd suggest deferring that until we have removed
the locks, because handling the guards is a bit finicky at this stage.
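
For illustration, here is a minimal, self-contained sketch of the pattern
(not the PR's actual code): `MockDevice` and `Eventloop` are made-up
stand-ins, and a tokio runtime is assumed. The point is that the device
read becomes a `poll_*` function driven from the event loop's own poll
function instead of from a separately spawned task:

```rust
use std::collections::VecDeque;
use std::task::{Context, Poll};

struct MockDevice {
    packets: VecDeque<Vec<u8>>,
}

impl MockDevice {
    // Stand-in for `Device::poll_read`: yields a packet if one is available.
    // A real device would register `cx.waker()` for I/O readiness before
    // returning `Poll::Pending`.
    fn poll_read(&mut self, _cx: &mut Context<'_>) -> Poll<Option<Vec<u8>>> {
        match self.packets.pop_front() {
            Some(packet) => Poll::Ready(Some(packet)),
            None => Poll::Pending,
        }
    }
}

struct Eventloop {
    device: MockDevice,
}

impl Eventloop {
    // The device read is the very first thing we poll; other event sources
    // (peer channels, timers, ...) would be polled afterwards.
    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<Vec<u8>>> {
        if let Poll::Ready(maybe_packet) = self.device.poll_read(cx) {
            return Poll::Ready(maybe_packet);
        }

        // ... poll the remaining event sources here ...

        Poll::Pending
    }
}

#[tokio::main]
async fn main() {
    let mut eventloop = Eventloop {
        device: MockDevice {
            packets: VecDeque::from([b"ping".to_vec()]),
        },
    };

    // Drive the poll function as a future, the same way `Tunnel::next_event`
    // wraps its poll function via `std::future::poll_fn`.
    let packet = std::future::poll_fn(|cx| eventloop.poll_next_event(cx)).await;
    println!("read packet from device: {packet:?}");
}
```

Because everything is polled from one place, the event loop keeps exclusive
access to its state, which is what later allows dropping `Arc<Tunnel>` and the
internal `Mutex`es in favour of `&mut self`.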
Thomas Eizinger
2023-11-03 11:47:26 +11:00
committed by GitHub
parent efe54cc2ec
commit b404f10d87
10 changed files with 304 additions and 344 deletions

View File

@@ -286,9 +286,9 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
.await;
}
pub async fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event<GatewayId>) {
pub async fn handle_tunnel_event(&mut self, event: Result<firezone_tunnel::Event<GatewayId>>) {
match event {
firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => {
Ok(firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate }) => {
if let Err(e) = self
.phoenix_channel
.send(EgressMessages::BroadcastIceCandidates(
@@ -302,11 +302,11 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
tracing::error!("Failed to signal ICE candidate: {e}")
}
}
firezone_tunnel::Event::ConnectionIntent {
Ok(firezone_tunnel::Event::ConnectionIntent {
resource,
connected_gateway_ids,
reference,
} => {
}) => {
if let Err(e) = self
.phoenix_channel
.clone()
@@ -324,7 +324,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
// TODO: Clean up connection in `ClientState` here?
}
}
firezone_tunnel::Event::DnsQuery(query) => {
Ok(firezone_tunnel::Event::DnsQuery(query)) => {
// Until we handle it better on a gateway-like eventloop, making sure not to block the loop
let Some(resolver) = self.fallback_resolver.lock().clone() else {
return;
@@ -332,14 +332,14 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
let tunnel = self.tunnel.clone();
tokio::spawn(async move {
let response = resolver.lookup(query.name, query.record_type).await;
if let Err(err) = tunnel
.write_dns_lookup_response(response, query.query)
.await
{
if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) {
tracing::error!(err = ?err, "DNS lookup failed: {err:#}");
}
});
}
Err(e) => {
tracing::error!("Tunnel failed: {e}");
}
}
}
}

View File

@@ -1,12 +1,10 @@
use crate::bounded_queue::BoundedQueue;
use crate::device_channel::create_iface;
use crate::device_channel::{create_iface, Packet};
use crate::ip_packet::{IpPacket, MutableIpPacket};
use crate::peer::WriteTo;
use crate::resource_table::ResourceTable;
use crate::{
dns, peer_by_ip, tokio_util, ConnectedPeer, Device, DnsQuery, Event, PeerConfig, RoleState,
Tunnel, DNS_QUERIES_QUEUE_SIZE, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING,
MAX_UDP_SIZE,
dns, ConnectedPeer, DnsQuery, Event, PeerConfig, RoleState, Tunnel, DNS_QUERIES_QUEUE_SIZE,
ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING,
};
use boringtun::x25519::{PublicKey, StaticSecret};
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
@@ -18,14 +16,12 @@ use connlib_shared::{Callbacks, DNS_SENTINEL};
use futures::channel::mpsc::Receiver;
use futures::stream;
use futures_bounded::{PushError, StreamMap};
use futures_util::SinkExt;
use hickory_resolver::lookup::Lookup;
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::time::Instant;
@@ -41,7 +37,7 @@ where
/// and packets will be wrapped with wireguard and sent through it.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(
self: &Arc<Self>,
&self,
resource_description: ResourceDescription,
) -> connlib_shared::Result<()> {
let mut any_valid_route = false;
@@ -71,13 +67,13 @@ where
/// Writes the response to a DNS lookup
#[tracing::instrument(level = "trace", skip(self))]
pub async fn write_dns_lookup_response(
self: &Arc<Self>,
pub fn write_dns_lookup_response(
&self,
response: hickory_resolver::error::ResolveResult<Lookup>,
query: IpPacket<'static>,
) -> connlib_shared::Result<()> {
if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? {
let Some(ref device) = *self.device.read().await else {
let Some(ref device) = *self.device.read() else {
return Ok(());
};
@@ -89,17 +85,11 @@ where
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(
self: &Arc<Self>,
config: &InterfaceConfig,
) -> connlib_shared::Result<()> {
pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
*self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
&self.callbacks,
device_handler(Arc::clone(self), device),
));
*self.device.write() = Some(device.clone());
self.no_device_waker.wake();
self.add_route(DNS_SENTINEL.into()).await?;
@@ -118,95 +108,24 @@ where
}
#[tracing::instrument(level = "trace", skip(self))]
async fn add_route(self: &Arc<Self>, route: IpNetwork) -> connlib_shared::Result<()> {
let mut device = self.device.write().await;
async fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> {
let device = self
.device
.write()
.take()
.ok_or(Error::ControlProtocolError)?;
if let Some(new_device) = device
.as_ref()
.ok_or(Error::ControlProtocolError)?
let new_device = device
.config
.add_route(route, self.callbacks())
.await?
{
*device = Some(new_device.clone());
*self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
&self.callbacks,
device_handler(Arc::clone(self), new_device),
));
}
.unwrap_or(device); // Restore the old device.
*self.device.write() = Some(new_device);
Ok(())
}
}
/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<CB>(
tunnel: Arc<Tunnel<CB, ClientState>>,
mut device: Device,
) -> Result<(), ConnlibError>
where
CB: Callbacks + 'static,
{
let device_writer = device.io.clone();
let mut buf = [0u8; MAX_UDP_SIZE];
'outer: loop {
let Some(packet) = device.read().await? else {
return Ok(());
};
let dest = packet.destination();
let (peer_conn_id, peer_channel, maybe_write_to) = {
let peers_by_ip = tunnel.peers_by_ip.read();
let peer = peer_by_ip(&peers_by_ip, dest);
let result = tunnel
.role_state
.lock()
.handle_new_packet(packet, peer, &mut buf);
let maybe_write_to = match result {
Ok(None) => continue,
Ok(Some(write_to)) => Ok(write_to),
Err(e) => Err(e),
};
let peer = peer.expect("must have peer if we should write bytes");
(peer.inner.conn_id, peer.channel.clone(), maybe_write_to)
};
let error = match maybe_write_to {
Ok(WriteTo::Network(mut packets)) => loop {
let Some(packet) = packets.pop_front() else {
continue 'outer;
};
match peer_channel.write(&packet).await {
Ok(_) => continue,
Err(e) => break ConnlibError::IceDataError(e),
}
},
Ok(WriteTo::Resource(packet)) => match device_writer.write(packet) {
Ok(_) => continue,
Err(e) => ConnlibError::Io(e),
},
Err(e) => e,
};
tracing::error!(resource_address = %dest, err = ?error, "failed to handle packet {error:#}");
let _ = tunnel.callbacks.on_error(&error);
if error.is_fatal_connection_error() {
let _ = tunnel
.stop_peer_command_sender
.clone()
.send(peer_conn_id)
.await;
}
}
}
/// [`Tunnel`] state specific to clients.
pub struct ClientState {
active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
@@ -233,35 +152,23 @@ pub struct AwaitingConnectionDetails {
}
impl ClientState {
pub(crate) fn handle_new_packet<'b>(
/// Attempt to handle the given packet as a DNS packet.
///
/// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back.
/// Returns `Err` if the packet is not a DNS query.
pub(crate) fn handle_dns<'a>(
&mut self,
packet: MutableIpPacket,
peer: Option<&ConnectedPeer<GatewayId>>,
buf: &'b mut [u8],
) -> Result<Option<WriteTo<'b>>, ConnlibError> {
packet: MutableIpPacket<'a>,
) -> Result<Option<Packet<'a>>, MutableIpPacket<'a>> {
match dns::parse(&self.resources, packet.as_immutable()) {
Some(dns::ResolveStrategy::LocalResponse(pkt)) => {
return Ok(Some(WriteTo::Resource(pkt)))
}
Some(dns::ResolveStrategy::LocalResponse(pkt)) => Ok(Some(pkt)),
Some(dns::ResolveStrategy::ForwardQuery(query)) => {
self.add_pending_dns_query(query);
return Ok(None);
Ok(None)
}
None => {}
None => Err(packet),
}
let dest = packet.destination();
let Some(peer) = peer else {
self.on_connection_intent(dest);
return Ok(None);
};
let Some(bytes) = peer.inner.encapsulate(packet, dest, buf)? else {
return Ok(None);
};
Ok(Some(WriteTo::Network(VecDeque::from([bytes]))))
}
pub(crate) fn attempt_to_reuse_connection(

View File

@@ -65,7 +65,7 @@ where
tracing::trace!(?peer_config.ips, "new_data_channel_open");
Box::pin(async move {
{
let Some(device) = tunnel.device.read().await.clone() else {
let Some(device) = tunnel.device.read().clone() else {
let e = Error::NoIface;
tracing::error!(err = ?e, "channel_open");
let _ = tunnel.callbacks().on_error(&e);

View File

@@ -1,16 +1,18 @@
use std::io;
use std::sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
};
use std::task::{ready, Context, Poll};
use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result};
use ip_network::IpNetwork;
use tokio::io::{unix::AsyncFd, Interest};
use tokio::io::{unix::AsyncFd, Ready};
use tun::{IfaceDevice, IfaceStream};
use crate::device_channel::Packet;
use crate::{Device, MAX_UDP_SIZE};
use crate::Device;
mod tun;
@@ -23,16 +25,27 @@ pub(crate) struct IfaceConfig {
pub(crate) struct DeviceIo(Arc<AsyncFd<IfaceStream>>);
impl DeviceIo {
pub async fn read(&self, out: &mut [u8]) -> std::io::Result<usize> {
self.0
.async_io(Interest::READABLE, |inner| inner.read(out))
.await
pub fn poll_read(&self, out: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.0.poll_read_ready(cx))?;
match guard.get_inner().read(out) {
Ok(n) => return Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// a read has blocked, but a write might still succeed.
// clear only the read readiness.
guard.clear_ready_matching(Ready::READABLE);
continue;
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
// Note: write is synchronous because it's non-blocking
// and some lossiness is acceptable and increases performance
// since we don't block the reading loops.
pub fn write(&self, packet: Packet<'_>) -> std::io::Result<usize> {
pub fn write(&self, packet: Packet<'_>) -> io::Result<usize> {
match packet {
Packet::Ipv4(msg) => self.0.get_ref().write4(&msg),
Packet::Ipv6(msg) => self.0.get_ref().write6(&msg),
@@ -65,11 +78,7 @@ impl IfaceConfig {
iface,
mtu: AtomicUsize::new(mtu),
});
Ok(Some(Device {
io,
config,
buf: Box::new([0u8; MAX_UDP_SIZE]),
}))
Ok(Some(Device { io, config }))
}
}
@@ -86,9 +95,5 @@ pub(crate) async fn create_iface(
mtu: AtomicUsize::new(mtu),
});
Ok(Device {
io,
config,
buf: Box::new([0u8; MAX_UDP_SIZE]),
})
Ok(Device { io, config })
}

View File

@@ -2,6 +2,7 @@ use crate::device_channel::Packet;
use crate::Device;
use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result};
use ip_network::IpNetwork;
use std::task::{Context, Poll};
#[derive(Clone)]
pub(crate) struct DeviceIo;
@@ -9,7 +10,7 @@ pub(crate) struct DeviceIo;
pub(crate) struct IfaceConfig;
impl DeviceIo {
pub async fn read(&self, _: &mut [u8]) -> std::io::Result<usize> {
pub fn poll_read(&self, _: &mut [u8], _: &mut Context<'_>) -> Poll<std::io::Result<usize>> {
todo!()
}

View File

@@ -1,16 +1,11 @@
use crate::device_channel::create_iface;
use crate::{
peer_by_ip, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING,
};
use connlib_shared::error::ConnlibError;
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::Callbacks;
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use futures_util::SinkExt;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
@@ -21,15 +16,11 @@ where
{
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(
self: &Arc<Self>,
config: &InterfaceConfig,
) -> connlib_shared::Result<()> {
pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
*self.iface_handler_abort.lock() =
Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle());
*self.device.write() = Some(device.clone());
self.no_device_waker.wake();
tracing::debug!("background_loop_started");
@@ -43,67 +34,6 @@ where
}
}
/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<CB>(
tunnel: Arc<Tunnel<CB, GatewayState>>,
mut device: Device,
) -> Result<(), ConnlibError>
where
CB: Callbacks + 'static,
{
let mut buf = [0u8; MAX_UDP_SIZE];
loop {
let Some(packet) = device.read().await? else {
// Reading a bad IP packet or otherwise from the device seems bad. Should we restart the tunnel or something?
return Ok(());
};
let dest = packet.destination();
let (result, channel, peer_conn_id) = {
let peers_by_ip = tunnel.peers_by_ip.read();
let Some(peer) = peer_by_ip(&peers_by_ip, dest) else {
continue;
};
let result = peer.inner.encapsulate(packet, dest, &mut buf);
let channel = peer.channel.clone();
(result, channel, peer.inner.conn_id)
};
let error = match result {
Ok(None) => continue,
Ok(Some(b)) => match channel.write(&b).await {
Ok(_) => continue,
Err(e) => ConnlibError::IceDataError(e),
},
Err(e) => e,
};
on_error(&tunnel, dest, error, peer_conn_id).await
}
}
async fn on_error<CB>(
tunnel: &Tunnel<CB, GatewayState>,
dest: IpAddr,
e: ConnlibError,
peer_conn_id: ClientId,
) where
CB: Callbacks + 'static,
{
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
if e.is_fatal_connection_error() {
let _ = tunnel
.stop_peer_command_sender
.clone()
.send(peer_conn_id)
.await;
}
}
/// [`Tunnel`] state specific to gateways.
pub struct GatewayState {
candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,

View File

@@ -16,7 +16,7 @@ use pnet_packet::Packet;
use hickory_resolver::proto::rr::RecordType;
use parking_lot::{Mutex, RwLock};
use peer::{Peer, PeerStats};
use tokio::{task::AbortHandle, time::MissedTickBehavior};
use tokio::time::MissedTickBehavior;
use webrtc::{
api::{
interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
@@ -27,10 +27,11 @@ use webrtc::{
};
use futures::channel::mpsc;
use futures_util::task::AtomicWaker;
use futures_util::{SinkExt, StreamExt};
use itertools::Itertools;
use std::collections::VecDeque;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll};
use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration};
use std::{collections::HashSet, hash::Hash};
use tokio::time::Interval;
@@ -45,12 +46,13 @@ use connlib_shared::{
use device_channel::{DeviceIo, IfaceConfig};
pub use client::ClientState;
use connlib_shared::error::ConnlibError;
pub use control_protocol::Request;
pub use gateway::GatewayState;
pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::ip_packet::MutableIpPacket;
use connlib_shared::messages::SecretKey;
use connlib_shared::messages::{ClientId, SecretKey};
use index::IndexLfsr;
mod bounded_queue;
@@ -64,7 +66,6 @@ mod ip_packet;
mod peer;
mod peer_handler;
mod resource_table;
mod tokio_util;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
@@ -108,35 +109,34 @@ impl From<connlib_shared::messages::Peer> for PeerConfig {
struct Device {
config: Arc<IfaceConfig>,
io: DeviceIo,
buf: Box<[u8; MAX_UDP_SIZE]>,
}
impl Device {
async fn read(&mut self) -> io::Result<Option<MutableIpPacket<'_>>> {
let res = self.io.read(&mut self.buf[..self.config.mtu()]).await?;
tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");
fn poll_read<'b>(
&mut self,
buf: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<MutableIpPacket<'b>>>> {
let res = ready!(self.io.poll_read(&mut buf[..self.config.mtu()], cx))?;
if res == 0 {
return Ok(None);
return Poll::Ready(Ok(None));
}
Ok(Some(
MutableIpPacket::new(&mut self.buf[..res]).ok_or_else(|| {
Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..res]).ok_or_else(
|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"received bytes are not an IP packet",
)
})?,
))
},
)?)))
}
}
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets to communicate between peers.
pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
next_index: Mutex<IndexLfsr>,
// We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact
device: tokio::sync::RwLock<Option<Device>>,
rate_limiter: Arc<RateLimiter>,
private_key: StaticSecret,
public_key: PublicKey,
@@ -145,7 +145,6 @@ pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
peer_connections: Mutex<HashMap<TRoleState::Id, Arc<RTCPeerConnection>>>,
webrtc_api: API,
callbacks: CallbackErrorFacade<CB>,
iface_handler_abort: Mutex<Option<AbortHandle>>,
/// State that differs per role, i.e. clients vs gateways.
role_state: Mutex<TRoleState>,
@@ -158,6 +157,142 @@ pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
mtu_refresh_interval: Mutex<Interval>,
peers_to_stop: Mutex<VecDeque<TRoleState::Id>>,
device: RwLock<Option<Device>>,
read_buf: Mutex<Box<[u8; MAX_UDP_SIZE]>>,
write_buf: Mutex<Box<[u8; MAX_UDP_SIZE]>>,
no_device_waker: AtomicWaker,
}
impl<CB> Tunnel<CB, ClientState>
where
CB: Callbacks + 'static,
{
pub async fn next_event(&self) -> Result<Event<GatewayId>> {
std::future::poll_fn(|cx| loop {
{
let mut guard = self.device.write();
if let Some(device) = guard.as_mut() {
match self.poll_device(device, cx) {
Poll::Ready(Ok(Some(event))) => return Poll::Ready(Ok(event)),
Poll::Ready(Ok(None)) => {
tracing::info!("Device stopped");
guard.take();
continue;
}
Poll::Ready(Err(e)) => {
guard.take(); // Ensure we don't poll a failed device again.
return Poll::Ready(Err(e));
}
Poll::Pending => {}
}
} else {
self.no_device_waker.register(cx.waker());
}
}
match self.poll_next_event_common(cx) {
Poll::Ready(event) => return Poll::Ready(Ok(event)),
Poll::Pending => {}
}
return Poll::Pending;
})
.await
}
pub(crate) fn poll_device(
&self,
device: &mut Device,
cx: &mut Context<'_>,
) -> Poll<Result<Option<Event<GatewayId>>>> {
loop {
let mut read_guard = self.read_buf.lock();
let mut write_guard = self.write_buf.lock();
let read_buf = read_guard.as_mut_slice();
let write_buf = write_guard.as_mut_slice();
let Some(packet) = ready!(device.poll_read(read_buf, cx))? else {
return Poll::Ready(Ok(None));
};
let mut role_state = self.role_state.lock();
let packet = match role_state.handle_dns(packet) {
Ok(Some(response)) => {
device.io.write(response)?;
continue;
}
Ok(None) => continue,
Err(non_dns_packet) => non_dns_packet,
};
let dest = packet.destination();
let peers_by_ip = self.peers_by_ip.read();
let Some(peer) = peer_by_ip(&peers_by_ip, dest) else {
role_state.on_connection_intent(dest);
continue;
};
self.encapsulate(write_buf, packet, dest, peer);
continue;
}
}
}
impl<CB> Tunnel<CB, GatewayState>
where
CB: Callbacks + 'static,
{
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Result<Event<ClientId>>> {
let mut read_guard = self.read_buf.lock();
let mut write_guard = self.write_buf.lock();
let read_buf = read_guard.as_mut_slice();
let write_buf = write_guard.as_mut_slice();
loop {
{
let mut device = self.device.write();
match device.as_mut().map(|d| d.poll_read(read_buf, cx)) {
Some(Poll::Ready(Ok(Some(packet)))) => {
let dest = packet.destination();
let peers_by_ip = self.peers_by_ip.read();
let Some(peer) = peer_by_ip(&peers_by_ip, dest) else {
continue;
};
self.encapsulate(write_buf, packet, dest, peer);
continue;
}
Some(Poll::Ready(Ok(None))) => {
tracing::info!("Device stopped");
drop(device.take());
}
Some(Poll::Ready(Err(e))) => return Poll::Ready(Err(ConnlibError::Io(e))),
Some(Poll::Pending) => {
// device not ready for reading, moving on ..
}
None => {
self.no_device_waker.register(cx.waker());
}
}
}
match self.poll_next_event_common(cx) {
Poll::Ready(e) => return Poll::Ready(Ok(e)),
Poll::Pending => {}
}
return Poll::Pending;
}
}
}
pub struct ConnectedPeer<TId> {
@@ -195,11 +330,7 @@ where
}
}
pub async fn next_event(&self) -> Event<TRoleState::Id> {
std::future::poll_fn(|cx| self.poll_next_event(cx)).await
}
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
fn poll_next_event_common(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
loop {
if let Some(conn_id) = self.peers_to_stop.lock().pop_front() {
let mut peers = self.peers_by_ip.write();
@@ -291,21 +422,9 @@ where
}
if self.mtu_refresh_interval.lock().poll_tick(cx).is_ready() {
// We use `try_read` to acquire a lock on the device because we are within a synchronous context here and cannot use `.await`.
// The device is only updated during `add_route` and `set_interface` which would be extremely unlucky to hit at the same time as this timer.
// Even if we hit this, we just wait for the next tick to update the MTU.
let device = match self.device.try_read().map(|d| d.clone()) {
Ok(Some(device)) => device,
Ok(None) => {
let err = Error::ControlProtocolError;
tracing::error!(?err, "get_iface_config");
let _ = self.callbacks.on_error(&err);
continue;
}
Err(_) => {
tracing::debug!("Unlucky! Somebody is updating the device just as we are about to update its MTU, trying again on the next tick ...");
continue;
}
let Some(device) = self.device.read().clone() else {
tracing::debug!("Device temporarily not available");
continue;
};
tokio::spawn({
@@ -335,6 +454,40 @@ where
return Poll::Pending;
}
}
fn encapsulate(
&self,
write_buf: &mut [u8],
packet: MutableIpPacket,
dest: IpAddr,
peer: &ConnectedPeer<TRoleState::Id>,
) {
let peer_id = peer.inner.conn_id;
match peer.inner.encapsulate(packet, dest, write_buf) {
Ok(None) => {}
Ok(Some(b)) => {
tokio::spawn({
let channel = peer.channel.clone();
let mut sender = self.stop_peer_command_sender.clone();
async move {
if let Err(e) = channel.write(&b).await {
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
let _ = sender.send(peer_id).await;
}
}
});
}
Err(e) => {
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
if e.is_fatal_connection_error() {
self.peers_to_stop.lock().push_back(peer_id);
}
}
};
}
}
pub(crate) fn peer_by_ip<Id>(
@@ -403,7 +556,6 @@ where
let next_index = Default::default();
let peer_connections = Default::default();
let device = Default::default();
let iface_handler_abort = Default::default();
// ICE
let mut media_engine = MediaEngine::default();
@@ -433,8 +585,9 @@ where
next_index,
webrtc_api,
device,
read_buf: Mutex::new(Box::new([0u8; MAX_UDP_SIZE])),
write_buf: Mutex::new(Box::new([0u8; MAX_UDP_SIZE])),
callbacks: CallbackErrorFacade(callbacks),
iface_handler_abort,
role_state: Default::default(),
stop_peer_command_receiver: Mutex::new(stop_peer_command_receiver),
stop_peer_command_sender,
@@ -442,6 +595,7 @@ where
peer_refresh_interval: Mutex::new(peer_refresh_interval()),
mtu_refresh_interval: Mutex::new(mtu_refresh_interval()),
peers_to_stop: Default::default(),
no_device_waker: Default::default(),
})
}

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use std::time::Duration;
use connlib_shared::{Callbacks, Error, Result};
use connlib_shared::Callbacks;
use futures_util::SinkExt;
use webrtc::data::data_channel::DataChannel;
@@ -18,15 +19,15 @@ where
channel: Arc<DataChannel>,
) {
loop {
let Some(device) = self.device.read().await.clone() else {
let err = Error::NoIface;
tracing::error!(?err);
let _ = self.callbacks().on_disconnect(Some(&err));
break;
let Some(device) = self.device.read().clone() else {
tracing::debug!("Device temporarily not available");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let device_io = device.io;
let result = self.peer_handler(&peer, channel.clone(), device_io).await;
let result =
peer_handler(self.callbacks.clone(), &peer, channel.clone(), device_io).await;
if matches!(result, Err(ref err) if err.raw_os_error() == Some(9)) {
tracing::warn!("bad_file_descriptor");
@@ -47,66 +48,51 @@ where
.send(peer.conn_id)
.await;
}
}
async fn peer_handler(
self: &Arc<Self>,
peer: &Arc<Peer<TRoleState::Id>>,
channel: Arc<DataChannel>,
device_io: DeviceIo,
) -> std::io::Result<()> {
let mut src_buf = [0u8; MAX_UDP_SIZE];
let mut dst_buf = [0u8; MAX_UDP_SIZE];
while let Ok(size) = channel.read(&mut src_buf[..]).await {
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
async fn peer_handler<TId>(
callbacks: impl Callbacks,
peer: &Arc<Peer<TId>>,
channel: Arc<DataChannel>,
device_io: DeviceIo,
) -> std::io::Result<()>
where
TId: Copy,
{
let mut src_buf = [0u8; MAX_UDP_SIZE];
let mut dst_buf = [0u8; MAX_UDP_SIZE];
while let Ok(size) = channel.read(&mut src_buf[..]).await {
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
// TODO: Double check that this can only happen on closed channel
// I think it's possible to transmit a 0-byte message through the channel
// but we would never use that.
// We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
if size == 0 {
break;
}
match self
.handle_peer_packet(peer, &channel, &device_io, &src_buf[..size], &mut dst_buf)
.await
{
Err(Error::Io(e)) => return Err(e),
Err(other) => {
tracing::error!(error = ?other, "failed to handle peer packet");
let _ = self.callbacks.on_error(&other);
}
_ => {}
}
// TODO: Double check that this can only happen on closed channel
// I think it's possible to transmit a 0-byte message through the channel
// but we would never use that.
// We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
if size == 0 {
break;
}
Ok(())
}
let src = &src_buf[..size];
#[inline(always)]
pub(crate) async fn handle_peer_packet(
self: &Arc<Self>,
peer: &Arc<Peer<TRoleState::Id>>,
channel: &DataChannel,
device_writer: &DeviceIo,
src: &[u8],
dst: &mut [u8],
) -> Result<()> {
match peer.decapsulate(src, dst)? {
Some(WriteTo::Network(bytes)) => {
match peer.decapsulate(src, &mut dst_buf) {
Ok(Some(WriteTo::Network(bytes))) => {
for packet in bytes {
if let Err(e) = channel.write(&packet).await {
tracing::error!("Couldn't send packet to connected peer: {e}");
let _ = self.callbacks.on_error(&e.into());
let _ = callbacks.on_error(&e.into());
}
}
}
Some(WriteTo::Resource(packet)) => {
device_writer.write(packet)?;
Ok(Some(WriteTo::Resource(packet))) => {
device_io.write(packet)?;
}
Ok(None) => {}
Err(other) => {
tracing::error!(error = ?other, "failed to handle peer packet");
let _ = callbacks.on_error(&other);
}
None => {}
}
Ok(())
}
Ok(())
}

View File

@@ -1,23 +0,0 @@
use connlib_shared::error::ConnlibError;
use connlib_shared::Callbacks;
use std::future::Future;
/// Spawns a task into the [`tokio`] runtime.
///
/// On error, [`Callbacks::on_error`] is invoked.
/// This also returns a [`tokio::task::AbortHandle`] which MAY be used to abort the task.
/// If you don't need it, you are free to drop it.
/// It won't terminate the task.
pub(crate) fn spawn_log(
cb: &(impl Callbacks + 'static),
f: impl Future<Output = Result<(), ConnlibError>> + Send + 'static,
) -> tokio::task::AbortHandle {
let cb = cb.clone();
tokio::spawn(async move {
if let Err(e) = f.await {
let _ = cb.on_error(&e);
}
})
.abort_handle()
}

View File

@@ -173,7 +173,7 @@ impl Eventloop {
_ => {}
}
match self.tunnel.poll_next_event(cx) {
match self.tunnel.poll_next_event(cx)? {
Poll::Ready(firezone_tunnel::Event::SignalIceCandidate {
conn_id: client,
candidate,