Mirror of https://github.com/outbackdingo/firezone.git (synced 2026-01-27 18:18:55 +00:00)
refactor(connlib): read from device as part of eventloop (#2520)
As a next step in refactoring the tunnel implementation, I am removing the `device_handler` task and instead using a poll-based function to read from the device. Removing the task means there is one less component that accesses the `Tunnel` via shared memory. The final one after this PR is the `peer_handler`. Once all shared access is gone, we can stop using `Arc<Tunnel>` and, with it, remove all uses of `Mutex` in the tunnel and simply use `&mut self`.

To remove the `device_handler`, we introduce a `Device::poll_read` function that we call as the very first thing in the `Tunnel`'s poll function.

At a later point, we want to think about prioritization within the event loop. I'd suggest deferring that until we have removed the locks, as handling the guards is a bit finicky at this stage.
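In outline, the change moves from "a spawned task reads the device and pokes the shared `Tunnel`" to "the `Tunnel`'s own poll function reads the device first". A minimal sketch of that shape, using stand-in types rather than the actual connlib ones (`Device`, `Event`, and the method bodies here are illustrative only):

```rust
use std::future::poll_fn;
use std::task::{Context, Poll};

// Stand-ins for the real connlib types.
struct Device;
struct Event;

impl Device {
    // Stand-in for `Device::poll_read`; `Ready(None)` signals that the device stopped.
    fn poll_read(&mut self, _cx: &mut Context<'_>) -> Poll<Option<Event>> {
        Poll::Pending
    }
}

struct Tunnel {
    device: Option<Device>,
}

impl Tunnel {
    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event> {
        // Reading from the device is the very first thing the poll function does.
        if let Some(device) = self.device.as_mut() {
            match device.poll_read(cx) {
                Poll::Ready(Some(event)) => return Poll::Ready(event),
                Poll::Ready(None) => self.device = None, // Device stopped; never poll it again.
                Poll::Pending => {}
            }
        }

        // ... poll timers, ICE candidates and the other event sources here ...

        Poll::Pending
    }

    // The async API is a thin wrapper around the poll function.
    async fn next_event(&mut self) -> Event {
        poll_fn(|cx| self.poll_next_event(cx)).await
    }
}
```

Because everything now runs inside one poll function, the device no longer needs to be reachable from a separately spawned task, which is what ultimately allows dropping `Arc<Tunnel>`.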
@@ -286,9 +286,9 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
             .await;
     }
 
-    pub async fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event<GatewayId>) {
+    pub async fn handle_tunnel_event(&mut self, event: Result<firezone_tunnel::Event<GatewayId>>) {
         match event {
-            firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => {
+            Ok(firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate }) => {
                 if let Err(e) = self
                     .phoenix_channel
                     .send(EgressMessages::BroadcastIceCandidates(
@@ -302,11 +302,11 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
                     tracing::error!("Failed to signal ICE candidate: {e}")
                 }
             }
-            firezone_tunnel::Event::ConnectionIntent {
+            Ok(firezone_tunnel::Event::ConnectionIntent {
                 resource,
                 connected_gateway_ids,
                 reference,
-            } => {
+            }) => {
                 if let Err(e) = self
                     .phoenix_channel
                     .clone()
@@ -324,7 +324,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
                 // TODO: Clean up connection in `ClientState` here?
                 }
             }
-            firezone_tunnel::Event::DnsQuery(query) => {
+            Ok(firezone_tunnel::Event::DnsQuery(query)) => {
                 // Until we handle it better on a gateway-like eventloop, making sure not to block the loop
                 let Some(resolver) = self.fallback_resolver.lock().clone() else {
                     return;
@@ -332,14 +332,14 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
                 let tunnel = self.tunnel.clone();
                 tokio::spawn(async move {
                     let response = resolver.lookup(query.name, query.record_type).await;
-                    if let Err(err) = tunnel
-                        .write_dns_lookup_response(response, query.query)
-                        .await
-                    {
+                    if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) {
                         tracing::error!(err = ?err, "DNS lookup failed: {err:#}");
                     }
                 });
             }
+            Err(e) => {
+                tracing::error!("Tunnel failed: {e}");
+            }
         }
     }
 }
@@ -1,12 +1,10 @@
 use crate::bounded_queue::BoundedQueue;
-use crate::device_channel::create_iface;
+use crate::device_channel::{create_iface, Packet};
 use crate::ip_packet::{IpPacket, MutableIpPacket};
-use crate::peer::WriteTo;
 use crate::resource_table::ResourceTable;
 use crate::{
-    dns, peer_by_ip, tokio_util, ConnectedPeer, Device, DnsQuery, Event, PeerConfig, RoleState,
-    Tunnel, DNS_QUERIES_QUEUE_SIZE, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING,
-    MAX_UDP_SIZE,
+    dns, ConnectedPeer, DnsQuery, Event, PeerConfig, RoleState, Tunnel, DNS_QUERIES_QUEUE_SIZE,
+    ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING,
 };
 use boringtun::x25519::{PublicKey, StaticSecret};
 use connlib_shared::error::{ConnlibError as Error, ConnlibError};
@@ -18,14 +16,12 @@ use connlib_shared::{Callbacks, DNS_SENTINEL};
 use futures::channel::mpsc::Receiver;
 use futures::stream;
 use futures_bounded::{PushError, StreamMap};
 use futures_util::SinkExt;
 use hickory_resolver::lookup::Lookup;
 use ip_network::IpNetwork;
 use ip_network_table::IpNetworkTable;
 use std::collections::hash_map::Entry;
-use std::collections::{HashMap, HashSet, VecDeque};
+use std::collections::{HashMap, HashSet};
 use std::net::IpAddr;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::time::Duration;
 use tokio::time::Instant;
@@ -41,7 +37,7 @@ where
     /// and packets will be wrapped with wireguard and sent through it.
     #[tracing::instrument(level = "trace", skip(self))]
     pub async fn add_resource(
-        self: &Arc<Self>,
+        &self,
         resource_description: ResourceDescription,
     ) -> connlib_shared::Result<()> {
         let mut any_valid_route = false;
@@ -71,13 +67,13 @@ where
 
     /// Writes the response to a DNS lookup
     #[tracing::instrument(level = "trace", skip(self))]
-    pub async fn write_dns_lookup_response(
-        self: &Arc<Self>,
+    pub fn write_dns_lookup_response(
+        &self,
         response: hickory_resolver::error::ResolveResult<Lookup>,
         query: IpPacket<'static>,
     ) -> connlib_shared::Result<()> {
         if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? {
-            let Some(ref device) = *self.device.read().await else {
+            let Some(ref device) = *self.device.read() else {
                 return Ok(());
             };
 
@@ -89,17 +85,11 @@ where
 
     /// Sets the interface configuration and starts background tasks.
    #[tracing::instrument(level = "trace", skip(self))]
-    pub async fn set_interface(
-        self: &Arc<Self>,
-        config: &InterfaceConfig,
-    ) -> connlib_shared::Result<()> {
+    pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
         let device = create_iface(config, self.callbacks()).await?;
 
-        *self.device.write().await = Some(device.clone());
-        *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
-            &self.callbacks,
-            device_handler(Arc::clone(self), device),
-        ));
+        *self.device.write() = Some(device.clone());
+        self.no_device_waker.wake();
 
         self.add_route(DNS_SENTINEL.into()).await?;
 
@@ -118,95 +108,24 @@ where
     }
 
     #[tracing::instrument(level = "trace", skip(self))]
-    async fn add_route(self: &Arc<Self>, route: IpNetwork) -> connlib_shared::Result<()> {
-        let mut device = self.device.write().await;
+    async fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> {
+        let device = self
+            .device
+            .write()
+            .take()
+            .ok_or(Error::ControlProtocolError)?;
 
-        if let Some(new_device) = device
-            .as_ref()
-            .ok_or(Error::ControlProtocolError)?
+        let new_device = device
             .config
             .add_route(route, self.callbacks())
             .await?
-        {
-            *device = Some(new_device.clone());
-            *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
-                &self.callbacks,
-                device_handler(Arc::clone(self), new_device),
-            ));
-        }
+            .unwrap_or(device); // Restore the old device.
+        *self.device.write() = Some(new_device);
 
         Ok(())
     }
 }
 
-/// Reads IP packets from the [`Device`] and handles them accordingly.
-async fn device_handler<CB>(
-    tunnel: Arc<Tunnel<CB, ClientState>>,
-    mut device: Device,
-) -> Result<(), ConnlibError>
-where
-    CB: Callbacks + 'static,
-{
-    let device_writer = device.io.clone();
-    let mut buf = [0u8; MAX_UDP_SIZE];
-    'outer: loop {
-        let Some(packet) = device.read().await? else {
-            return Ok(());
-        };
-
-        let dest = packet.destination();
-        let (peer_conn_id, peer_channel, maybe_write_to) = {
-            let peers_by_ip = tunnel.peers_by_ip.read();
-            let peer = peer_by_ip(&peers_by_ip, dest);
-
-            let result = tunnel
-                .role_state
-                .lock()
-                .handle_new_packet(packet, peer, &mut buf);
-
-            let maybe_write_to = match result {
-                Ok(None) => continue,
-                Ok(Some(write_to)) => Ok(write_to),
-                Err(e) => Err(e),
-            };
-
-            let peer = peer.expect("must have peer if we should write bytes");
-
-            (peer.inner.conn_id, peer.channel.clone(), maybe_write_to)
-        };
-
-        let error = match maybe_write_to {
-            Ok(WriteTo::Network(mut packets)) => loop {
-                let Some(packet) = packets.pop_front() else {
-                    continue 'outer;
-                };
-
-                match peer_channel.write(&packet).await {
-                    Ok(_) => continue,
-                    Err(e) => break ConnlibError::IceDataError(e),
-                }
-            },
-            Ok(WriteTo::Resource(packet)) => match device_writer.write(packet) {
-                Ok(_) => continue,
-                Err(e) => ConnlibError::Io(e),
-            },
-            Err(e) => e,
-        };
-
-        tracing::error!(resource_address = %dest, err = ?error, "failed to handle packet {error:#}");
-
-        let _ = tunnel.callbacks.on_error(&error);
-
-        if error.is_fatal_connection_error() {
-            let _ = tunnel
-                .stop_peer_command_sender
-                .clone()
-                .send(peer_conn_id)
-                .await;
-        }
-    }
-}
-
 /// [`Tunnel`] state specific to clients.
 pub struct ClientState {
     active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
@@ -233,35 +152,23 @@ pub struct AwaitingConnectionDetails {
 }
 
 impl ClientState {
-    pub(crate) fn handle_new_packet<'b>(
+    /// Attempt to handle the given packet as a DNS packet.
+    ///
+    /// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back.
+    /// Returns `Err` if the packet is not a DNS query.
+    pub(crate) fn handle_dns<'a>(
         &mut self,
-        packet: MutableIpPacket,
-        peer: Option<&ConnectedPeer<GatewayId>>,
-        buf: &'b mut [u8],
-    ) -> Result<Option<WriteTo<'b>>, ConnlibError> {
+        packet: MutableIpPacket<'a>,
+    ) -> Result<Option<Packet<'a>>, MutableIpPacket<'a>> {
         match dns::parse(&self.resources, packet.as_immutable()) {
-            Some(dns::ResolveStrategy::LocalResponse(pkt)) => {
-                return Ok(Some(WriteTo::Resource(pkt)))
-            }
+            Some(dns::ResolveStrategy::LocalResponse(pkt)) => Ok(Some(pkt)),
             Some(dns::ResolveStrategy::ForwardQuery(query)) => {
                 self.add_pending_dns_query(query);
-                return Ok(None);
+
+                Ok(None)
             }
-            None => {}
+            None => Err(packet),
         }
-
-        let dest = packet.destination();
-
-        let Some(peer) = peer else {
-            self.on_connection_intent(dest);
-            return Ok(None);
-        };
-
-        let Some(bytes) = peer.inner.encapsulate(packet, dest, buf)? else {
-            return Ok(None);
-        };
-
-        Ok(Some(WriteTo::Network(VecDeque::from([bytes]))))
     }
 
     pub(crate) fn attempt_to_reuse_connection(
@@ -65,7 +65,7 @@ where
         tracing::trace!(?peer_config.ips, "new_data_channel_open");
         Box::pin(async move {
             {
-                let Some(device) = tunnel.device.read().await.clone() else {
+                let Some(device) = tunnel.device.read().clone() else {
                     let e = Error::NoIface;
                     tracing::error!(err = ?e, "channel_open");
                     let _ = tunnel.callbacks().on_error(&e);
@@ -1,16 +1,18 @@
+use std::io;
 use std::sync::{
     atomic::{AtomicUsize, Ordering::Relaxed},
     Arc,
 };
+use std::task::{ready, Context, Poll};
 
 use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result};
 use ip_network::IpNetwork;
-use tokio::io::{unix::AsyncFd, Interest};
+use tokio::io::{unix::AsyncFd, Ready};
 
 use tun::{IfaceDevice, IfaceStream};
 
 use crate::device_channel::Packet;
-use crate::{Device, MAX_UDP_SIZE};
+use crate::Device;
 
 mod tun;
 
@@ -23,16 +25,27 @@ pub(crate) struct IfaceConfig {
 pub(crate) struct DeviceIo(Arc<AsyncFd<IfaceStream>>);
 
 impl DeviceIo {
-    pub async fn read(&self, out: &mut [u8]) -> std::io::Result<usize> {
-        self.0
-            .async_io(Interest::READABLE, |inner| inner.read(out))
-            .await
+    pub fn poll_read(&self, out: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+        loop {
+            let mut guard = ready!(self.0.poll_read_ready(cx))?;
+
+            match guard.get_inner().read(out) {
+                Ok(n) => return Poll::Ready(Ok(n)),
+                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+                    // a read has blocked, but a write might still succeed.
+                    // clear only the read readiness.
+                    guard.clear_ready_matching(Ready::READABLE);
+                    continue;
+                }
+                Err(e) => return Poll::Ready(Err(e)),
+            }
+        }
     }
 
     // Note: write is synchronous because it's non-blocking
-    // and some losiness is acceptable and increseases performance
+    // and some lossiness is acceptable and increases performance
     // since we don't block the reading loops.
-    pub fn write(&self, packet: Packet<'_>) -> std::io::Result<usize> {
+    pub fn write(&self, packet: Packet<'_>) -> io::Result<usize> {
         match packet {
             Packet::Ipv4(msg) => self.0.get_ref().write4(&msg),
             Packet::Ipv6(msg) => self.0.get_ref().write6(&msg),
@@ -65,11 +78,7 @@ impl IfaceConfig {
             iface,
             mtu: AtomicUsize::new(mtu),
         });
-        Ok(Some(Device {
-            io,
-            config,
-            buf: Box::new([0u8; MAX_UDP_SIZE]),
-        }))
+        Ok(Some(Device { io, config }))
     }
 }
 
@@ -86,9 +95,5 @@ pub(crate) async fn create_iface(
         mtu: AtomicUsize::new(mtu),
     });
 
-    Ok(Device {
-        io,
-        config,
-        buf: Box::new([0u8; MAX_UDP_SIZE]),
-    })
+    Ok(Device { io, config })
 }
 
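The readiness loop in `DeviceIo::poll_read` above is the standard `tokio::io::unix::AsyncFd` pattern: wait for read readiness, attempt the non-blocking syscall, and on `WouldBlock` clear only the read readiness and wait again. A minimal self-contained sketch of the same pattern, with a `UnixDatagram` standing in for the TUN fd (the `NonBlockingSocket` type and the `poll_recv` name are ours, for illustration only):

```rust
use std::io;
use std::os::unix::net::UnixDatagram;
use std::task::{ready, Context, Poll};

use tokio::io::unix::AsyncFd;
use tokio::io::Ready;

struct NonBlockingSocket(AsyncFd<UnixDatagram>);

impl NonBlockingSocket {
    fn new(sock: UnixDatagram) -> io::Result<Self> {
        sock.set_nonblocking(true)?; // `AsyncFd` requires a non-blocking fd.
        Ok(Self(AsyncFd::new(sock)?))
    }

    fn poll_recv(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        loop {
            // Suspend until the runtime reports read-readiness for the fd.
            let mut guard = ready!(self.0.poll_read_ready(cx))?;

            match guard.get_inner().recv(buf) {
                Ok(n) => return Poll::Ready(Ok(n)),
                // Spurious readiness: clear only the read flag so concurrent
                // writers are unaffected, then wait for readiness again.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    guard.clear_ready_matching(Ready::READABLE);
                }
                Err(e) => return Poll::Ready(Err(e)),
            }
        }
    }
}
```

Compared to the old `async_io(Interest::READABLE, ...)` call, the poll-based form can be driven directly from a hand-written poll function without holding a future across `.await` points.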
@@ -2,6 +2,7 @@ use crate::device_channel::Packet;
 use crate::Device;
 use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result};
 use ip_network::IpNetwork;
+use std::task::{Context, Poll};
 
 #[derive(Clone)]
 pub(crate) struct DeviceIo;
@@ -9,7 +10,7 @@ pub(crate) struct IfaceConfig;
 
 impl DeviceIo {
-    pub async fn read(&self, _: &mut [u8]) -> std::io::Result<usize> {
+    pub fn poll_read(&self, _: &mut [u8], _: &mut Context<'_>) -> Poll<std::io::Result<usize>> {
         todo!()
     }
 
@@ -1,16 +1,11 @@
 use crate::device_channel::create_iface;
 use crate::{
-    peer_by_ip, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
-    MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
+    Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING,
 };
-use connlib_shared::error::ConnlibError;
 use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
 use connlib_shared::Callbacks;
 use futures::channel::mpsc::Receiver;
 use futures_bounded::{PushError, StreamMap};
-use futures_util::SinkExt;
-use std::net::IpAddr;
-use std::sync::Arc;
 use std::task::{ready, Context, Poll};
 use std::time::Duration;
 use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
@@ -21,15 +16,11 @@ where
 {
     /// Sets the interface configuration and starts background tasks.
     #[tracing::instrument(level = "trace", skip(self))]
-    pub async fn set_interface(
-        self: &Arc<Self>,
-        config: &InterfaceConfig,
-    ) -> connlib_shared::Result<()> {
+    pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
         let device = create_iface(config, self.callbacks()).await?;
 
-        *self.device.write().await = Some(device.clone());
-        *self.iface_handler_abort.lock() =
-            Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle());
+        *self.device.write() = Some(device.clone());
+        self.no_device_waker.wake();
 
         tracing::debug!("background_loop_started");
 
@@ -43,67 +34,6 @@ where
     }
 }
 
-/// Reads IP packets from the [`Device`] and handles them accordingly.
-async fn device_handler<CB>(
-    tunnel: Arc<Tunnel<CB, GatewayState>>,
-    mut device: Device,
-) -> Result<(), ConnlibError>
-where
-    CB: Callbacks + 'static,
-{
-    let mut buf = [0u8; MAX_UDP_SIZE];
-    loop {
-        let Some(packet) = device.read().await? else {
-            // Reading a bad IP packet or otherwise from the device seems bad. Should we restart the tunnel or something?
-            return Ok(());
-        };
-
-        let dest = packet.destination();
-
-        let (result, channel, peer_conn_id) = {
-            let peers_by_ip = tunnel.peers_by_ip.read();
-            let Some(peer) = peer_by_ip(&peers_by_ip, dest) else {
-                continue;
-            };
-
-            let result = peer.inner.encapsulate(packet, dest, &mut buf);
-            let channel = peer.channel.clone();
-
-            (result, channel, peer.inner.conn_id)
-        };
-
-        let error = match result {
-            Ok(None) => continue,
-            Ok(Some(b)) => match channel.write(&b).await {
-                Ok(_) => continue,
-                Err(e) => ConnlibError::IceDataError(e),
-            },
-            Err(e) => e,
-        };
-
-        on_error(&tunnel, dest, error, peer_conn_id).await
-    }
-}
-
-async fn on_error<CB>(
-    tunnel: &Tunnel<CB, GatewayState>,
-    dest: IpAddr,
-    e: ConnlibError,
-    peer_conn_id: ClientId,
-) where
-    CB: Callbacks + 'static,
-{
-    tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
-
-    if e.is_fatal_connection_error() {
-        let _ = tunnel
-            .stop_peer_command_sender
-            .clone()
-            .send(peer_conn_id)
-            .await;
-    }
-}
-
 /// [`Tunnel`] state specific to gateways.
 pub struct GatewayState {
     candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,
@@ -16,7 +16,7 @@ use pnet_packet::Packet;
 use hickory_resolver::proto::rr::RecordType;
 use parking_lot::{Mutex, RwLock};
 use peer::{Peer, PeerStats};
-use tokio::{task::AbortHandle, time::MissedTickBehavior};
+use tokio::time::MissedTickBehavior;
 use webrtc::{
     api::{
         interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
@@ -27,10 +27,11 @@ use webrtc::{
 };
 
 use futures::channel::mpsc;
+use futures_util::task::AtomicWaker;
 use futures_util::{SinkExt, StreamExt};
 use itertools::Itertools;
 use std::collections::VecDeque;
-use std::task::{Context, Poll};
+use std::task::{ready, Context, Poll};
 use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration};
 use std::{collections::HashSet, hash::Hash};
 use tokio::time::Interval;
@@ -45,12 +46,13 @@ use connlib_shared::{
 use device_channel::{DeviceIo, IfaceConfig};
 
 pub use client::ClientState;
+use connlib_shared::error::ConnlibError;
 pub use control_protocol::Request;
 pub use gateway::GatewayState;
 pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
 
 use crate::ip_packet::MutableIpPacket;
-use connlib_shared::messages::SecretKey;
+use connlib_shared::messages::{ClientId, SecretKey};
 use index::IndexLfsr;
 
 mod bounded_queue;
@@ -64,7 +66,6 @@ mod ip_packet;
 mod peer;
 mod peer_handler;
 mod resource_table;
-mod tokio_util;
 
 const MAX_UDP_SIZE: usize = (1 << 16) - 1;
 const DNS_QUERIES_QUEUE_SIZE: usize = 100;
@@ -108,35 +109,34 @@ impl From<connlib_shared::messages::Peer> for PeerConfig {
 struct Device {
     config: Arc<IfaceConfig>,
     io: DeviceIo,
-
-    buf: Box<[u8; MAX_UDP_SIZE]>,
 }
 
 impl Device {
-    async fn read(&mut self) -> io::Result<Option<MutableIpPacket<'_>>> {
-        let res = self.io.read(&mut self.buf[..self.config.mtu()]).await?;
-        tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");
+    fn poll_read<'b>(
+        &mut self,
+        buf: &'b mut [u8],
+        cx: &mut Context<'_>,
+    ) -> Poll<io::Result<Option<MutableIpPacket<'b>>>> {
+        let res = ready!(self.io.poll_read(&mut buf[..self.config.mtu()], cx))?;
 
         if res == 0 {
-            return Ok(None);
+            return Poll::Ready(Ok(None));
         }
 
-        Ok(Some(
-            MutableIpPacket::new(&mut self.buf[..res]).ok_or_else(|| {
+        Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..res]).ok_or_else(
+            || {
                 io::Error::new(
                     io::ErrorKind::InvalidInput,
                     "received bytes are not an IP packet",
                 )
-            })?,
-        ))
+            },
+        )?)))
     }
 }
 
 /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets to communicate between peers.
 pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
     next_index: Mutex<IndexLfsr>,
-    // We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact
-    device: tokio::sync::RwLock<Option<Device>>,
     rate_limiter: Arc<RateLimiter>,
     private_key: StaticSecret,
     public_key: PublicKey,
@@ -145,7 +145,6 @@ pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
     peer_connections: Mutex<HashMap<TRoleState::Id, Arc<RTCPeerConnection>>>,
     webrtc_api: API,
     callbacks: CallbackErrorFacade<CB>,
-    iface_handler_abort: Mutex<Option<AbortHandle>>,
 
     /// State that differs per role, i.e. clients vs gateways.
     role_state: Mutex<TRoleState>,
@@ -158,6 +157,142 @@ pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
     mtu_refresh_interval: Mutex<Interval>,
 
     peers_to_stop: Mutex<VecDeque<TRoleState::Id>>,
+
+    device: RwLock<Option<Device>>,
+    read_buf: Mutex<Box<[u8; MAX_UDP_SIZE]>>,
+    write_buf: Mutex<Box<[u8; MAX_UDP_SIZE]>>,
+    no_device_waker: AtomicWaker,
 }
 
+impl<CB> Tunnel<CB, ClientState>
+where
+    CB: Callbacks + 'static,
+{
+    pub async fn next_event(&self) -> Result<Event<GatewayId>> {
+        std::future::poll_fn(|cx| loop {
+            {
+                let mut guard = self.device.write();
+
+                if let Some(device) = guard.as_mut() {
+                    match self.poll_device(device, cx) {
+                        Poll::Ready(Ok(Some(event))) => return Poll::Ready(Ok(event)),
+                        Poll::Ready(Ok(None)) => {
+                            tracing::info!("Device stopped");
+                            guard.take();
+                            continue;
+                        }
+                        Poll::Ready(Err(e)) => {
+                            guard.take(); // Ensure we don't poll a failed device again.
+                            return Poll::Ready(Err(e));
+                        }
+                        Poll::Pending => {}
+                    }
+                } else {
+                    self.no_device_waker.register(cx.waker());
+                }
+            }
+
+            match self.poll_next_event_common(cx) {
+                Poll::Ready(event) => return Poll::Ready(Ok(event)),
+                Poll::Pending => {}
+            }
+
+            return Poll::Pending;
+        })
+        .await
+    }
+
+    pub(crate) fn poll_device(
+        &self,
+        device: &mut Device,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<Option<Event<GatewayId>>>> {
+        loop {
+            let mut read_guard = self.read_buf.lock();
+            let mut write_guard = self.write_buf.lock();
+            let read_buf = read_guard.as_mut_slice();
+            let write_buf = write_guard.as_mut_slice();
+
+            let Some(packet) = ready!(device.poll_read(read_buf, cx))? else {
+                return Poll::Ready(Ok(None));
+            };
+
+            let mut role_state = self.role_state.lock();
+
+            let packet = match role_state.handle_dns(packet) {
+                Ok(Some(response)) => {
+                    device.io.write(response)?;
+                    continue;
+                }
+                Ok(None) => continue,
+                Err(non_dns_packet) => non_dns_packet,
+            };
+
+            let dest = packet.destination();
+
+            let peers_by_ip = self.peers_by_ip.read();
+            let Some(peer) = peer_by_ip(&peers_by_ip, dest) else {
+                role_state.on_connection_intent(dest);
+                continue;
+            };
+
+            self.encapsulate(write_buf, packet, dest, peer);
+
+            continue;
+        }
+    }
+}
+
+impl<CB> Tunnel<CB, GatewayState>
+where
+    CB: Callbacks + 'static,
+{
+    pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Result<Event<ClientId>>> {
+        let mut read_guard = self.read_buf.lock();
+        let mut write_guard = self.write_buf.lock();
+
+        let read_buf = read_guard.as_mut_slice();
+        let write_buf = write_guard.as_mut_slice();
+
+        loop {
+            {
+                let mut device = self.device.write();
+
+                match device.as_mut().map(|d| d.poll_read(read_buf, cx)) {
+                    Some(Poll::Ready(Ok(Some(packet)))) => {
+                        let dest = packet.destination();
+
+                        let peers_by_ip = self.peers_by_ip.read();
+                        let Some(peer) = peer_by_ip(&peers_by_ip, dest) else {
+                            continue;
+                        };
+
+                        self.encapsulate(write_buf, packet, dest, peer);
+
+                        continue;
+                    }
+                    Some(Poll::Ready(Ok(None))) => {
+                        tracing::info!("Device stopped");
+                        drop(device.take());
+                    }
+                    Some(Poll::Ready(Err(e))) => return Poll::Ready(Err(ConnlibError::Io(e))),
+                    Some(Poll::Pending) => {
+                        // device not ready for reading, moving on ..
+                    }
+                    None => {
+                        self.no_device_waker.register(cx.waker());
+                    }
+                }
+            }
+
+            match self.poll_next_event_common(cx) {
+                Poll::Ready(e) => return Poll::Ready(Ok(e)),
+                Poll::Pending => {}
+            }
+
+            return Poll::Pending;
+        }
+    }
+}
 
 pub struct ConnectedPeer<TId> {
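With the reading task gone, a poll function can now find that no device is installed; the new `no_device_waker` bridges that gap. `set_interface` calls `wake()` after installing the device, and the poll side calls `register(cx.waker())` when it sees `None`, so the event loop is re-polled once a device appears. A hedged sketch of that handoff in isolation (the `Slot` type and its methods are ours, not connlib's):

```rust
use std::task::{Context, Poll};

use futures_util::task::AtomicWaker;
use parking_lot::RwLock;

struct Slot<T> {
    value: RwLock<Option<T>>,
    waker: AtomicWaker,
}

impl<T: Clone> Slot<T> {
    // Called from the event loop: either yields the value or parks the task.
    fn poll_value(&self, cx: &mut Context<'_>) -> Poll<T> {
        if let Some(value) = self.value.read().clone() {
            return Poll::Ready(value);
        }

        // Register *before* re-checking, so a `wake()` that races with this
        // poll is not lost.
        self.waker.register(cx.waker());

        match self.value.read().clone() {
            Some(value) => Poll::Ready(value),
            None => Poll::Pending,
        }
    }

    // Called from e.g. `set_interface`: installs the value and wakes the loop.
    fn install(&self, value: T) {
        *self.value.write() = Some(value);
        self.waker.wake();
    }
}
```

The documented `AtomicWaker` idiom re-checks the condition after `register`, as above, to avoid sleeping through a wake-up that lands between the first check and the registration.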
@@ -195,11 +330,7 @@ where
         }
     }
 
-    pub async fn next_event(&self) -> Event<TRoleState::Id> {
-        std::future::poll_fn(|cx| self.poll_next_event(cx)).await
-    }
-
-    pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
+    fn poll_next_event_common(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
         loop {
             if let Some(conn_id) = self.peers_to_stop.lock().pop_front() {
                 let mut peers = self.peers_by_ip.write();
@@ -291,21 +422,9 @@ where
             }
 
             if self.mtu_refresh_interval.lock().poll_tick(cx).is_ready() {
-                // We use `try_read` to acquire a lock on the device because we are within a synchronous context here and cannot use `.await`.
-                // The device is only updated during `add_route` and `set_interface` which would be extremely unlucky to hit at the same time as this timer.
-                // Even if we hit this, we just wait for the next tick to update the MTU.
-                let device = match self.device.try_read().map(|d| d.clone()) {
-                    Ok(Some(device)) => device,
-                    Ok(None) => {
-                        let err = Error::ControlProtocolError;
-                        tracing::error!(?err, "get_iface_config");
-                        let _ = self.callbacks.on_error(&err);
-                        continue;
-                    }
-                    Err(_) => {
-                        tracing::debug!("Unlucky! Somebody is updating the device just as we are about to update its MTU, trying again on the next tick ...");
-                        continue;
-                    }
+                let Some(device) = self.device.read().clone() else {
+                    tracing::debug!("Device temporarily not available");
+                    continue;
                 };
 
                 tokio::spawn({
@@ -335,6 +454,40 @@ where
             return Poll::Pending;
         }
     }
+
+    fn encapsulate(
+        &self,
+        write_buf: &mut [u8],
+        packet: MutableIpPacket,
+        dest: IpAddr,
+        peer: &ConnectedPeer<TRoleState::Id>,
+    ) {
+        let peer_id = peer.inner.conn_id;
+
+        match peer.inner.encapsulate(packet, dest, write_buf) {
+            Ok(None) => {}
+            Ok(Some(b)) => {
+                tokio::spawn({
+                    let channel = peer.channel.clone();
+                    let mut sender = self.stop_peer_command_sender.clone();
+
+                    async move {
+                        if let Err(e) = channel.write(&b).await {
+                            tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
+                            let _ = sender.send(peer_id).await;
+                        }
+                    }
+                });
+            }
+            Err(e) => {
+                tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
+
+                if e.is_fatal_connection_error() {
+                    self.peers_to_stop.lock().push_back(peer_id);
+                }
+            }
+        };
+    }
 }
 
 pub(crate) fn peer_by_ip<Id>(
@@ -403,7 +556,6 @@ where
         let next_index = Default::default();
         let peer_connections = Default::default();
         let device = Default::default();
-        let iface_handler_abort = Default::default();
 
         // ICE
         let mut media_engine = MediaEngine::default();
@@ -433,8 +585,9 @@ where
             next_index,
             webrtc_api,
             device,
+            read_buf: Mutex::new(Box::new([0u8; MAX_UDP_SIZE])),
+            write_buf: Mutex::new(Box::new([0u8; MAX_UDP_SIZE])),
             callbacks: CallbackErrorFacade(callbacks),
-            iface_handler_abort,
             role_state: Default::default(),
             stop_peer_command_receiver: Mutex::new(stop_peer_command_receiver),
             stop_peer_command_sender,
@@ -442,6 +595,7 @@ where
             peer_refresh_interval: Mutex::new(peer_refresh_interval()),
             mtu_refresh_interval: Mutex::new(mtu_refresh_interval()),
             peers_to_stop: Default::default(),
+            no_device_waker: Default::default(),
         })
     }
 
@@ -1,6 +1,7 @@
 use std::sync::Arc;
+use std::time::Duration;
 
-use connlib_shared::{Callbacks, Error, Result};
+use connlib_shared::Callbacks;
 use futures_util::SinkExt;
 use webrtc::data::data_channel::DataChannel;
 
@@ -18,15 +19,15 @@ where
         channel: Arc<DataChannel>,
     ) {
         loop {
-            let Some(device) = self.device.read().await.clone() else {
-                let err = Error::NoIface;
-                tracing::error!(?err);
-                let _ = self.callbacks().on_disconnect(Some(&err));
-                break;
+            let Some(device) = self.device.read().clone() else {
+                tracing::debug!("Device temporarily not available");
+                tokio::time::sleep(Duration::from_millis(100)).await;
+                continue;
             };
             let device_io = device.io;
 
-            let result = self.peer_handler(&peer, channel.clone(), device_io).await;
+            let result =
+                peer_handler(self.callbacks.clone(), &peer, channel.clone(), device_io).await;
 
             if matches!(result, Err(ref err) if err.raw_os_error() == Some(9)) {
                 tracing::warn!("bad_file_descriptor");
@@ -47,66 +48,51 @@ where
             .send(peer.conn_id)
             .await;
         }
     }
+}
 
-    async fn peer_handler(
-        self: &Arc<Self>,
-        peer: &Arc<Peer<TRoleState::Id>>,
-        channel: Arc<DataChannel>,
-        device_io: DeviceIo,
-    ) -> std::io::Result<()> {
-        let mut src_buf = [0u8; MAX_UDP_SIZE];
-        let mut dst_buf = [0u8; MAX_UDP_SIZE];
-        while let Ok(size) = channel.read(&mut src_buf[..]).await {
-            tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
+async fn peer_handler<TId>(
+    callbacks: impl Callbacks,
+    peer: &Arc<Peer<TId>>,
+    channel: Arc<DataChannel>,
+    device_io: DeviceIo,
+) -> std::io::Result<()>
+where
+    TId: Copy,
+{
+    let mut src_buf = [0u8; MAX_UDP_SIZE];
+    let mut dst_buf = [0u8; MAX_UDP_SIZE];
+    while let Ok(size) = channel.read(&mut src_buf[..]).await {
+        tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
 
-            // TODO: Double check that this can only happen on closed channel
-            // I think it's possible to transmit a 0-byte message through the channel
-            // but we would never use that.
-            // We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
-            if size == 0 {
-                break;
-            }
-
-            match self
-                .handle_peer_packet(peer, &channel, &device_io, &src_buf[..size], &mut dst_buf)
-                .await
-            {
-                Err(Error::Io(e)) => return Err(e),
-                Err(other) => {
-                    tracing::error!(error = ?other, "failed to handle peer packet");
-                    let _ = self.callbacks.on_error(&other);
-                }
-                _ => {}
-            }
+        // TODO: Double check that this can only happen on closed channel
+        // I think it's possible to transmit a 0-byte message through the channel
+        // but we would never use that.
+        // We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
+        if size == 0 {
+            break;
         }
 
-        Ok(())
-    }
+        let src = &src_buf[..size];
 
-    #[inline(always)]
-    pub(crate) async fn handle_peer_packet(
-        self: &Arc<Self>,
-        peer: &Arc<Peer<TRoleState::Id>>,
-        channel: &DataChannel,
-        device_writer: &DeviceIo,
-        src: &[u8],
-        dst: &mut [u8],
-    ) -> Result<()> {
-        match peer.decapsulate(src, dst)? {
-            Some(WriteTo::Network(bytes)) => {
+        match peer.decapsulate(src, &mut dst_buf) {
+            Ok(Some(WriteTo::Network(bytes))) => {
                 for packet in bytes {
                     if let Err(e) = channel.write(&packet).await {
                         tracing::error!("Couldn't send packet to connected peer: {e}");
-                        let _ = self.callbacks.on_error(&e.into());
+                        let _ = callbacks.on_error(&e.into());
                     }
                 }
             }
-            Some(WriteTo::Resource(packet)) => {
-                device_writer.write(packet)?;
+            Ok(Some(WriteTo::Resource(packet))) => {
+                device_io.write(packet)?;
             }
-            None => {}
+            Ok(None) => {}
+            Err(other) => {
+                tracing::error!(error = ?other, "failed to handle peer packet");
+                let _ = callbacks.on_error(&other);
+            }
         }
+    }
 
-        Ok(())
-    }
+    Ok(())
+}
@@ -1,23 +0,0 @@
-use connlib_shared::error::ConnlibError;
-use connlib_shared::Callbacks;
-use std::future::Future;
-
-/// Spawns a task into the [`tokio`] runtime.
-///
-/// On error, [`Callbacks::on_error`] is invoked.
-/// This also returns a [`tokio::task::AbortHandle`] which MAY be used to abort the task.
-/// If you don't need it, you are free to drop it.
-/// It won't terminate the task.
-pub(crate) fn spawn_log(
-    cb: &(impl Callbacks + 'static),
-    f: impl Future<Output = Result<(), ConnlibError>> + Send + 'static,
-) -> tokio::task::AbortHandle {
-    let cb = cb.clone();
-
-    tokio::spawn(async move {
-        if let Err(e) = f.await {
-            let _ = cb.on_error(&e);
-        }
-    })
-    .abort_handle()
-}
@@ -173,7 +173,7 @@ impl Eventloop {
             _ => {}
         }
 
-        match self.tunnel.poll_next_event(cx) {
+        match self.tunnel.poll_next_event(cx)? {
             Poll::Ready(firezone_tunnel::Event::SignalIceCandidate {
                 conn_id: client,
                 candidate,