diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 8f1492ec1..91b5830c2 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -286,9 +286,9 @@ impl ControlPlane { .await; } - pub async fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event) { + pub async fn handle_tunnel_event(&mut self, event: Result>) { match event { - firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => { + Ok(firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate }) => { if let Err(e) = self .phoenix_channel .send(EgressMessages::BroadcastIceCandidates( @@ -302,11 +302,11 @@ impl ControlPlane { tracing::error!("Failed to signal ICE candidate: {e}") } } - firezone_tunnel::Event::ConnectionIntent { + Ok(firezone_tunnel::Event::ConnectionIntent { resource, connected_gateway_ids, reference, - } => { + }) => { if let Err(e) = self .phoenix_channel .clone() @@ -324,7 +324,7 @@ impl ControlPlane { // TODO: Clean up connection in `ClientState` here? } } - firezone_tunnel::Event::DnsQuery(query) => { + Ok(firezone_tunnel::Event::DnsQuery(query)) => { // Until we handle it better on a gateway-like eventloop, making sure not to block the loop let Some(resolver) = self.fallback_resolver.lock().clone() else { return; @@ -332,14 +332,14 @@ impl ControlPlane { let tunnel = self.tunnel.clone(); tokio::spawn(async move { let response = resolver.lookup(query.name, query.record_type).await; - if let Err(err) = tunnel - .write_dns_lookup_response(response, query.query) - .await - { + if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) { tracing::error!(err = ?err, "DNS lookup failed: {err:#}"); } }); } + Err(e) => { + tracing::error!("Tunnel failed: {e}"); + } } } } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index ed652d169..fa1fb54ae 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,12 +1,10 @@ use crate::bounded_queue::BoundedQueue; -use crate::device_channel::create_iface; +use crate::device_channel::{create_iface, Packet}; use crate::ip_packet::{IpPacket, MutableIpPacket}; -use crate::peer::WriteTo; use crate::resource_table::ResourceTable; use crate::{ - dns, peer_by_ip, tokio_util, ConnectedPeer, Device, DnsQuery, Event, PeerConfig, RoleState, - Tunnel, DNS_QUERIES_QUEUE_SIZE, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, - MAX_UDP_SIZE, + dns, ConnectedPeer, DnsQuery, Event, PeerConfig, RoleState, Tunnel, DNS_QUERIES_QUEUE_SIZE, + ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, }; use boringtun::x25519::{PublicKey, StaticSecret}; use connlib_shared::error::{ConnlibError as Error, ConnlibError}; @@ -18,14 +16,12 @@ use connlib_shared::{Callbacks, DNS_SENTINEL}; use futures::channel::mpsc::Receiver; use futures::stream; use futures_bounded::{PushError, StreamMap}; -use futures_util::SinkExt; use hickory_resolver::lookup::Lookup; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, HashSet}; use std::net::IpAddr; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use tokio::time::Instant; @@ -41,7 +37,7 @@ where /// and packets will be wrapped with wireguard and sent through it. #[tracing::instrument(level = "trace", skip(self))] pub async fn add_resource( - self: &Arc, + &self, resource_description: ResourceDescription, ) -> connlib_shared::Result<()> { let mut any_valid_route = false; @@ -71,13 +67,13 @@ where /// Writes the response to a DNS lookup #[tracing::instrument(level = "trace", skip(self))] - pub async fn write_dns_lookup_response( - self: &Arc, + pub fn write_dns_lookup_response( + &self, response: hickory_resolver::error::ResolveResult, query: IpPacket<'static>, ) -> connlib_shared::Result<()> { if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? { - let Some(ref device) = *self.device.read().await else { + let Some(ref device) = *self.device.read() else { return Ok(()); }; @@ -89,17 +85,11 @@ where /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] - pub async fn set_interface( - self: &Arc, - config: &InterfaceConfig, - ) -> connlib_shared::Result<()> { + pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { let device = create_iface(config, self.callbacks()).await?; - *self.device.write().await = Some(device.clone()); - *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log( - &self.callbacks, - device_handler(Arc::clone(self), device), - )); + *self.device.write() = Some(device.clone()); + self.no_device_waker.wake(); self.add_route(DNS_SENTINEL.into()).await?; @@ -118,95 +108,24 @@ where } #[tracing::instrument(level = "trace", skip(self))] - async fn add_route(self: &Arc, route: IpNetwork) -> connlib_shared::Result<()> { - let mut device = self.device.write().await; + async fn add_route(&self, route: IpNetwork) -> connlib_shared::Result<()> { + let device = self + .device + .write() + .take() + .ok_or(Error::ControlProtocolError)?; - if let Some(new_device) = device - .as_ref() - .ok_or(Error::ControlProtocolError)? + let new_device = device .config .add_route(route, self.callbacks()) .await? - { - *device = Some(new_device.clone()); - *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log( - &self.callbacks, - device_handler(Arc::clone(self), new_device), - )); - } + .unwrap_or(device); // Restore the old device. + *self.device.write() = Some(new_device); Ok(()) } } -/// Reads IP packets from the [`Device`] and handles them accordingly. -async fn device_handler( - tunnel: Arc>, - mut device: Device, -) -> Result<(), ConnlibError> -where - CB: Callbacks + 'static, -{ - let device_writer = device.io.clone(); - let mut buf = [0u8; MAX_UDP_SIZE]; - 'outer: loop { - let Some(packet) = device.read().await? else { - return Ok(()); - }; - - let dest = packet.destination(); - let (peer_conn_id, peer_channel, maybe_write_to) = { - let peers_by_ip = tunnel.peers_by_ip.read(); - let peer = peer_by_ip(&peers_by_ip, dest); - - let result = tunnel - .role_state - .lock() - .handle_new_packet(packet, peer, &mut buf); - - let maybe_write_to = match result { - Ok(None) => continue, - Ok(Some(write_to)) => Ok(write_to), - Err(e) => Err(e), - }; - - let peer = peer.expect("must have peer if we should write bytes"); - - (peer.inner.conn_id, peer.channel.clone(), maybe_write_to) - }; - - let error = match maybe_write_to { - Ok(WriteTo::Network(mut packets)) => loop { - let Some(packet) = packets.pop_front() else { - continue 'outer; - }; - - match peer_channel.write(&packet).await { - Ok(_) => continue, - Err(e) => break ConnlibError::IceDataError(e), - } - }, - Ok(WriteTo::Resource(packet)) => match device_writer.write(packet) { - Ok(_) => continue, - Err(e) => ConnlibError::Io(e), - }, - Err(e) => e, - }; - - tracing::error!(resource_address = %dest, err = ?error, "failed to handle packet {error:#}"); - - let _ = tunnel.callbacks.on_error(&error); - - if error.is_fatal_connection_error() { - let _ = tunnel - .stop_peer_command_sender - .clone() - .send(peer_conn_id) - .await; - } - } -} - /// [`Tunnel`] state specific to clients. pub struct ClientState { active_candidate_receivers: StreamMap, @@ -233,35 +152,23 @@ pub struct AwaitingConnectionDetails { } impl ClientState { - pub(crate) fn handle_new_packet<'b>( + /// Attempt to handle the given packet as a DNS packet. + /// + /// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back. + /// Returns `Err` if the packet is not a DNS query. + pub(crate) fn handle_dns<'a>( &mut self, - packet: MutableIpPacket, - peer: Option<&ConnectedPeer>, - buf: &'b mut [u8], - ) -> Result>, ConnlibError> { + packet: MutableIpPacket<'a>, + ) -> Result>, MutableIpPacket<'a>> { match dns::parse(&self.resources, packet.as_immutable()) { - Some(dns::ResolveStrategy::LocalResponse(pkt)) => { - return Ok(Some(WriteTo::Resource(pkt))) - } + Some(dns::ResolveStrategy::LocalResponse(pkt)) => Ok(Some(pkt)), Some(dns::ResolveStrategy::ForwardQuery(query)) => { self.add_pending_dns_query(query); - return Ok(None); + + Ok(None) } - None => {} + None => Err(packet), } - - let dest = packet.destination(); - - let Some(peer) = peer else { - self.on_connection_intent(dest); - return Ok(None); - }; - - let Some(bytes) = peer.inner.encapsulate(packet, dest, buf)? else { - return Ok(None); - }; - - Ok(Some(WriteTo::Network(VecDeque::from([bytes])))) } pub(crate) fn attempt_to_reuse_connection( diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index fe2a950ed..d3eea8b67 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -65,7 +65,7 @@ where tracing::trace!(?peer_config.ips, "new_data_channel_open"); Box::pin(async move { { - let Some(device) = tunnel.device.read().await.clone() else { + let Some(device) = tunnel.device.read().clone() else { let e = Error::NoIface; tracing::error!(err = ?e, "channel_open"); let _ = tunnel.callbacks().on_error(&e); diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs index 186339fbf..f5200875d 100644 --- a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs +++ b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs @@ -1,16 +1,18 @@ +use std::io; use std::sync::{ atomic::{AtomicUsize, Ordering::Relaxed}, Arc, }; +use std::task::{ready, Context, Poll}; use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result}; use ip_network::IpNetwork; -use tokio::io::{unix::AsyncFd, Interest}; +use tokio::io::{unix::AsyncFd, Ready}; use tun::{IfaceDevice, IfaceStream}; use crate::device_channel::Packet; -use crate::{Device, MAX_UDP_SIZE}; +use crate::Device; mod tun; @@ -23,16 +25,27 @@ pub(crate) struct IfaceConfig { pub(crate) struct DeviceIo(Arc>); impl DeviceIo { - pub async fn read(&self, out: &mut [u8]) -> std::io::Result { - self.0 - .async_io(Interest::READABLE, |inner| inner.read(out)) - .await + pub fn poll_read(&self, out: &mut [u8], cx: &mut Context<'_>) -> Poll> { + loop { + let mut guard = ready!(self.0.poll_read_ready(cx))?; + + match guard.get_inner().read(out) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // a read has blocked, but a write might still succeed. + // clear only the read readiness. + guard.clear_ready_matching(Ready::READABLE); + continue; + } + Err(e) => return Poll::Ready(Err(e)), + } + } } // Note: write is synchronous because it's non-blocking // and some losiness is acceptable and increseases performance // since we don't block the reading loops. - pub fn write(&self, packet: Packet<'_>) -> std::io::Result { + pub fn write(&self, packet: Packet<'_>) -> io::Result { match packet { Packet::Ipv4(msg) => self.0.get_ref().write4(&msg), Packet::Ipv6(msg) => self.0.get_ref().write6(&msg), @@ -65,11 +78,7 @@ impl IfaceConfig { iface, mtu: AtomicUsize::new(mtu), }); - Ok(Some(Device { - io, - config, - buf: Box::new([0u8; MAX_UDP_SIZE]), - })) + Ok(Some(Device { io, config })) } } @@ -86,9 +95,5 @@ pub(crate) async fn create_iface( mtu: AtomicUsize::new(mtu), }); - Ok(Device { - io, - config, - buf: Box::new([0u8; MAX_UDP_SIZE]), - }) + Ok(Device { io, config }) } diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_win.rs b/rust/connlib/tunnel/src/device_channel/device_channel_win.rs index 872b35e8a..b2bd654be 100644 --- a/rust/connlib/tunnel/src/device_channel/device_channel_win.rs +++ b/rust/connlib/tunnel/src/device_channel/device_channel_win.rs @@ -2,6 +2,7 @@ use crate::device_channel::Packet; use crate::Device; use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result}; use ip_network::IpNetwork; +use std::task::{Context, Poll}; #[derive(Clone)] pub(crate) struct DeviceIo; @@ -9,7 +10,7 @@ pub(crate) struct DeviceIo; pub(crate) struct IfaceConfig; impl DeviceIo { - pub async fn read(&self, _: &mut [u8]) -> std::io::Result { + pub fn poll_read(&self, _: &mut [u8], _: &mut Context<'_>) -> Poll> { todo!() } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index e7ae27115..2f94733e3 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,16 +1,11 @@ use crate::device_channel::create_iface; use crate::{ - peer_by_ip, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, - MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE, + Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, }; -use connlib_shared::error::ConnlibError; use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; use connlib_shared::Callbacks; use futures::channel::mpsc::Receiver; use futures_bounded::{PushError, StreamMap}; -use futures_util::SinkExt; -use std::net::IpAddr; -use std::sync::Arc; use std::task::{ready, Context, Poll}; use std::time::Duration; use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; @@ -21,15 +16,11 @@ where { /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] - pub async fn set_interface( - self: &Arc, - config: &InterfaceConfig, - ) -> connlib_shared::Result<()> { + pub async fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { let device = create_iface(config, self.callbacks()).await?; - *self.device.write().await = Some(device.clone()); - *self.iface_handler_abort.lock() = - Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle()); + *self.device.write() = Some(device.clone()); + self.no_device_waker.wake(); tracing::debug!("background_loop_started"); @@ -43,67 +34,6 @@ where } } -/// Reads IP packets from the [`Device`] and handles them accordingly. -async fn device_handler( - tunnel: Arc>, - mut device: Device, -) -> Result<(), ConnlibError> -where - CB: Callbacks + 'static, -{ - let mut buf = [0u8; MAX_UDP_SIZE]; - loop { - let Some(packet) = device.read().await? else { - // Reading a bad IP packet or otherwise from the device seems bad. Should we restart the tunnel or something? - return Ok(()); - }; - - let dest = packet.destination(); - - let (result, channel, peer_conn_id) = { - let peers_by_ip = tunnel.peers_by_ip.read(); - let Some(peer) = peer_by_ip(&peers_by_ip, dest) else { - continue; - }; - - let result = peer.inner.encapsulate(packet, dest, &mut buf); - let channel = peer.channel.clone(); - - (result, channel, peer.inner.conn_id) - }; - - let error = match result { - Ok(None) => continue, - Ok(Some(b)) => match channel.write(&b).await { - Ok(_) => continue, - Err(e) => ConnlibError::IceDataError(e), - }, - Err(e) => e, - }; - - on_error(&tunnel, dest, error, peer_conn_id).await - } -} - -async fn on_error( - tunnel: &Tunnel, - dest: IpAddr, - e: ConnlibError, - peer_conn_id: ClientId, -) where - CB: Callbacks + 'static, -{ - tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}"); - - if e.is_fatal_connection_error() { - let _ = tunnel - .stop_peer_command_sender - .clone() - .send(peer_conn_id) - .await; - } -} - /// [`Tunnel`] state specific to gateways. pub struct GatewayState { candidate_receivers: StreamMap, diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index f6e446838..7641f5f7d 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -16,7 +16,7 @@ use pnet_packet::Packet; use hickory_resolver::proto::rr::RecordType; use parking_lot::{Mutex, RwLock}; use peer::{Peer, PeerStats}; -use tokio::{task::AbortHandle, time::MissedTickBehavior}; +use tokio::time::MissedTickBehavior; use webrtc::{ api::{ interceptor_registry::register_default_interceptors, media_engine::MediaEngine, @@ -27,10 +27,11 @@ use webrtc::{ }; use futures::channel::mpsc; +use futures_util::task::AtomicWaker; use futures_util::{SinkExt, StreamExt}; use itertools::Itertools; use std::collections::VecDeque; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration}; use std::{collections::HashSet, hash::Hash}; use tokio::time::Interval; @@ -45,12 +46,13 @@ use connlib_shared::{ use device_channel::{DeviceIo, IfaceConfig}; pub use client::ClientState; +use connlib_shared::error::ConnlibError; pub use control_protocol::Request; pub use gateway::GatewayState; pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; use crate::ip_packet::MutableIpPacket; -use connlib_shared::messages::SecretKey; +use connlib_shared::messages::{ClientId, SecretKey}; use index::IndexLfsr; mod bounded_queue; @@ -64,7 +66,6 @@ mod ip_packet; mod peer; mod peer_handler; mod resource_table; -mod tokio_util; const MAX_UDP_SIZE: usize = (1 << 16) - 1; const DNS_QUERIES_QUEUE_SIZE: usize = 100; @@ -108,35 +109,34 @@ impl From for PeerConfig { struct Device { config: Arc, io: DeviceIo, - - buf: Box<[u8; MAX_UDP_SIZE]>, } impl Device { - async fn read(&mut self) -> io::Result>> { - let res = self.io.read(&mut self.buf[..self.config.mtu()]).await?; - tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface"); + fn poll_read<'b>( + &mut self, + buf: &'b mut [u8], + cx: &mut Context<'_>, + ) -> Poll>>> { + let res = ready!(self.io.poll_read(&mut buf[..self.config.mtu()], cx))?; if res == 0 { - return Ok(None); + return Poll::Ready(Ok(None)); } - Ok(Some( - MutableIpPacket::new(&mut self.buf[..res]).ok_or_else(|| { + Poll::Ready(Ok(Some(MutableIpPacket::new(&mut buf[..res]).ok_or_else( + || { io::Error::new( io::ErrorKind::InvalidInput, "received bytes are not an IP packet", ) - })?, - )) + }, + )?))) } } /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets to communicate between peers. pub struct Tunnel { next_index: Mutex, - // We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact - device: tokio::sync::RwLock>, rate_limiter: Arc, private_key: StaticSecret, public_key: PublicKey, @@ -145,7 +145,6 @@ pub struct Tunnel { peer_connections: Mutex>>, webrtc_api: API, callbacks: CallbackErrorFacade, - iface_handler_abort: Mutex>, /// State that differs per role, i.e. clients vs gateways. role_state: Mutex, @@ -158,6 +157,142 @@ pub struct Tunnel { mtu_refresh_interval: Mutex, peers_to_stop: Mutex>, + + device: RwLock>, + read_buf: Mutex>, + write_buf: Mutex>, + no_device_waker: AtomicWaker, +} + +impl Tunnel +where + CB: Callbacks + 'static, +{ + pub async fn next_event(&self) -> Result> { + std::future::poll_fn(|cx| loop { + { + let mut guard = self.device.write(); + + if let Some(device) = guard.as_mut() { + match self.poll_device(device, cx) { + Poll::Ready(Ok(Some(event))) => return Poll::Ready(Ok(event)), + Poll::Ready(Ok(None)) => { + tracing::info!("Device stopped"); + guard.take(); + continue; + } + Poll::Ready(Err(e)) => { + guard.take(); // Ensure we don't poll a failed device again. + return Poll::Ready(Err(e)); + } + Poll::Pending => {} + } + } else { + self.no_device_waker.register(cx.waker()); + } + } + + match self.poll_next_event_common(cx) { + Poll::Ready(event) => return Poll::Ready(Ok(event)), + Poll::Pending => {} + } + + return Poll::Pending; + }) + .await + } + + pub(crate) fn poll_device( + &self, + device: &mut Device, + cx: &mut Context<'_>, + ) -> Poll>>> { + loop { + let mut read_guard = self.read_buf.lock(); + let mut write_guard = self.write_buf.lock(); + let read_buf = read_guard.as_mut_slice(); + let write_buf = write_guard.as_mut_slice(); + + let Some(packet) = ready!(device.poll_read(read_buf, cx))? else { + return Poll::Ready(Ok(None)); + }; + + let mut role_state = self.role_state.lock(); + + let packet = match role_state.handle_dns(packet) { + Ok(Some(response)) => { + device.io.write(response)?; + continue; + } + Ok(None) => continue, + Err(non_dns_packet) => non_dns_packet, + }; + + let dest = packet.destination(); + + let peers_by_ip = self.peers_by_ip.read(); + let Some(peer) = peer_by_ip(&peers_by_ip, dest) else { + role_state.on_connection_intent(dest); + continue; + }; + + self.encapsulate(write_buf, packet, dest, peer); + + continue; + } + } +} + +impl Tunnel +where + CB: Callbacks + 'static, +{ + pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll>> { + let mut read_guard = self.read_buf.lock(); + let mut write_guard = self.write_buf.lock(); + + let read_buf = read_guard.as_mut_slice(); + let write_buf = write_guard.as_mut_slice(); + + loop { + { + let mut device = self.device.write(); + + match device.as_mut().map(|d| d.poll_read(read_buf, cx)) { + Some(Poll::Ready(Ok(Some(packet)))) => { + let dest = packet.destination(); + + let peers_by_ip = self.peers_by_ip.read(); + let Some(peer) = peer_by_ip(&peers_by_ip, dest) else { + continue; + }; + + self.encapsulate(write_buf, packet, dest, peer); + + continue; + } + Some(Poll::Ready(Ok(None))) => { + tracing::info!("Device stopped"); + drop(device.take()); + } + Some(Poll::Ready(Err(e))) => return Poll::Ready(Err(ConnlibError::Io(e))), + Some(Poll::Pending) => { + // device not ready for reading, moving on .. + } + None => { + self.no_device_waker.register(cx.waker()); + } + } + } + + match self.poll_next_event_common(cx) { + Poll::Ready(e) => return Poll::Ready(Ok(e)), + Poll::Pending => {} + } + + return Poll::Pending; + } + } } pub struct ConnectedPeer { @@ -195,11 +330,7 @@ where } } - pub async fn next_event(&self) -> Event { - std::future::poll_fn(|cx| self.poll_next_event(cx)).await - } - - pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll> { + fn poll_next_event_common(&self, cx: &mut Context<'_>) -> Poll> { loop { if let Some(conn_id) = self.peers_to_stop.lock().pop_front() { let mut peers = self.peers_by_ip.write(); @@ -291,21 +422,9 @@ where } if self.mtu_refresh_interval.lock().poll_tick(cx).is_ready() { - // We use `try_read` to acquire a lock on the device because we are within a synchronous context here and cannot use `.await`. - // The device is only updated during `add_route` and `set_interface` which would be extremely unlucky to hit at the same time as this timer. - // Even if we hit this, we just wait for the next tick to update the MTU. - let device = match self.device.try_read().map(|d| d.clone()) { - Ok(Some(device)) => device, - Ok(None) => { - let err = Error::ControlProtocolError; - tracing::error!(?err, "get_iface_config"); - let _ = self.callbacks.on_error(&err); - continue; - } - Err(_) => { - tracing::debug!("Unlucky! Somebody is updating the device just as we are about to update its MTU, trying again on the next tick ..."); - continue; - } + let Some(device) = self.device.read().clone() else { + tracing::debug!("Device temporarily not available"); + continue; }; tokio::spawn({ @@ -335,6 +454,40 @@ where return Poll::Pending; } } + + fn encapsulate( + &self, + write_buf: &mut [u8], + packet: MutableIpPacket, + dest: IpAddr, + peer: &ConnectedPeer, + ) { + let peer_id = peer.inner.conn_id; + + match peer.inner.encapsulate(packet, dest, write_buf) { + Ok(None) => {} + Ok(Some(b)) => { + tokio::spawn({ + let channel = peer.channel.clone(); + let mut sender = self.stop_peer_command_sender.clone(); + + async move { + if let Err(e) = channel.write(&b).await { + tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}"); + let _ = sender.send(peer_id).await; + } + } + }); + } + Err(e) => { + tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}"); + + if e.is_fatal_connection_error() { + self.peers_to_stop.lock().push_back(peer_id); + } + } + }; + } } pub(crate) fn peer_by_ip( @@ -403,7 +556,6 @@ where let next_index = Default::default(); let peer_connections = Default::default(); let device = Default::default(); - let iface_handler_abort = Default::default(); // ICE let mut media_engine = MediaEngine::default(); @@ -433,8 +585,9 @@ where next_index, webrtc_api, device, + read_buf: Mutex::new(Box::new([0u8; MAX_UDP_SIZE])), + write_buf: Mutex::new(Box::new([0u8; MAX_UDP_SIZE])), callbacks: CallbackErrorFacade(callbacks), - iface_handler_abort, role_state: Default::default(), stop_peer_command_receiver: Mutex::new(stop_peer_command_receiver), stop_peer_command_sender, @@ -442,6 +595,7 @@ where peer_refresh_interval: Mutex::new(peer_refresh_interval()), mtu_refresh_interval: Mutex::new(mtu_refresh_interval()), peers_to_stop: Default::default(), + no_device_waker: Default::default(), }) } diff --git a/rust/connlib/tunnel/src/peer_handler.rs b/rust/connlib/tunnel/src/peer_handler.rs index 994ed433d..a66884610 100644 --- a/rust/connlib/tunnel/src/peer_handler.rs +++ b/rust/connlib/tunnel/src/peer_handler.rs @@ -1,6 +1,7 @@ use std::sync::Arc; +use std::time::Duration; -use connlib_shared::{Callbacks, Error, Result}; +use connlib_shared::Callbacks; use futures_util::SinkExt; use webrtc::data::data_channel::DataChannel; @@ -18,15 +19,15 @@ where channel: Arc, ) { loop { - let Some(device) = self.device.read().await.clone() else { - let err = Error::NoIface; - tracing::error!(?err); - let _ = self.callbacks().on_disconnect(Some(&err)); - break; + let Some(device) = self.device.read().clone() else { + tracing::debug!("Device temporarily not available"); + tokio::time::sleep(Duration::from_millis(100)).await; + continue; }; let device_io = device.io; - let result = self.peer_handler(&peer, channel.clone(), device_io).await; + let result = + peer_handler(self.callbacks.clone(), &peer, channel.clone(), device_io).await; if matches!(result, Err(ref err) if err.raw_os_error() == Some(9)) { tracing::warn!("bad_file_descriptor"); @@ -47,66 +48,51 @@ where .send(peer.conn_id) .await; } +} - async fn peer_handler( - self: &Arc, - peer: &Arc>, - channel: Arc, - device_io: DeviceIo, - ) -> std::io::Result<()> { - let mut src_buf = [0u8; MAX_UDP_SIZE]; - let mut dst_buf = [0u8; MAX_UDP_SIZE]; - while let Ok(size) = channel.read(&mut src_buf[..]).await { - tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer"); +async fn peer_handler( + callbacks: impl Callbacks, + peer: &Arc>, + channel: Arc, + device_io: DeviceIo, +) -> std::io::Result<()> +where + TId: Copy, +{ + let mut src_buf = [0u8; MAX_UDP_SIZE]; + let mut dst_buf = [0u8; MAX_UDP_SIZE]; + while let Ok(size) = channel.read(&mut src_buf[..]).await { + tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer"); - // TODO: Double check that this can only happen on closed channel - // I think it's possible to transmit a 0-byte message through the channel - // but we would never use that. - // We should keep track of an open/closed channel ourselves if we wanted to do it properly then. - if size == 0 { - break; - } - - match self - .handle_peer_packet(peer, &channel, &device_io, &src_buf[..size], &mut dst_buf) - .await - { - Err(Error::Io(e)) => return Err(e), - Err(other) => { - tracing::error!(error = ?other, "failed to handle peer packet"); - let _ = self.callbacks.on_error(&other); - } - _ => {} - } + // TODO: Double check that this can only happen on closed channel + // I think it's possible to transmit a 0-byte message through the channel + // but we would never use that. + // We should keep track of an open/closed channel ourselves if we wanted to do it properly then. + if size == 0 { + break; } - Ok(()) - } + let src = &src_buf[..size]; - #[inline(always)] - pub(crate) async fn handle_peer_packet( - self: &Arc, - peer: &Arc>, - channel: &DataChannel, - device_writer: &DeviceIo, - src: &[u8], - dst: &mut [u8], - ) -> Result<()> { - match peer.decapsulate(src, dst)? { - Some(WriteTo::Network(bytes)) => { + match peer.decapsulate(src, &mut dst_buf) { + Ok(Some(WriteTo::Network(bytes))) => { for packet in bytes { if let Err(e) = channel.write(&packet).await { tracing::error!("Couldn't send packet to connected peer: {e}"); - let _ = self.callbacks.on_error(&e.into()); + let _ = callbacks.on_error(&e.into()); } } } - Some(WriteTo::Resource(packet)) => { - device_writer.write(packet)?; + Ok(Some(WriteTo::Resource(packet))) => { + device_io.write(packet)?; + } + Ok(None) => {} + Err(other) => { + tracing::error!(error = ?other, "failed to handle peer packet"); + let _ = callbacks.on_error(&other); } - None => {} } - - Ok(()) } + + Ok(()) } diff --git a/rust/connlib/tunnel/src/tokio_util.rs b/rust/connlib/tunnel/src/tokio_util.rs deleted file mode 100644 index 8c0840ba3..000000000 --- a/rust/connlib/tunnel/src/tokio_util.rs +++ /dev/null @@ -1,23 +0,0 @@ -use connlib_shared::error::ConnlibError; -use connlib_shared::Callbacks; -use std::future::Future; - -/// Spawns a task into the [`tokio`] runtime. -/// -/// On error, [`Callbacks::on_error`] is invoked. -/// This also returns a [`tokio::task::AbortHandle`] which MAY be used to abort the task. -/// If you don't need it, you are free to drop it. -/// It won't terminate the task. -pub(crate) fn spawn_log( - cb: &(impl Callbacks + 'static), - f: impl Future> + Send + 'static, -) -> tokio::task::AbortHandle { - let cb = cb.clone(); - - tokio::spawn(async move { - if let Err(e) = f.await { - let _ = cb.on_error(&e); - } - }) - .abort_handle() -} diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index d0c3ef6b3..d666559b0 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -173,7 +173,7 @@ impl Eventloop { _ => {} } - match self.tunnel.poll_next_event(cx) { + match self.tunnel.poll_next_event(cx)? { Poll::Ready(firezone_tunnel::Event::SignalIceCandidate { conn_id: client, candidate,