From dbf0e445b00ebf375bd5c68dffc5e1752f2e5dff Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 12 Oct 2023 10:02:31 +1100 Subject: [PATCH] refactor(connlib): split device handler for clients & gateway (#2301) --- rust/connlib/tunnel/src/client.rs | 285 +++++++++++++++++ rust/connlib/tunnel/src/control_protocol.rs | 3 +- .../tunnel/src/control_protocol/gateway.rs | 3 +- .../src/device_channel/device_channel_unix.rs | 14 +- rust/connlib/tunnel/src/dns.rs | 290 +++++++++--------- rust/connlib/tunnel/src/gateway.rs | 120 ++++++++ rust/connlib/tunnel/src/iface_handler.rs | 155 +--------- rust/connlib/tunnel/src/ip_packet.rs | 47 +-- rust/connlib/tunnel/src/lib.rs | 155 ++++------ rust/connlib/tunnel/src/peer.rs | 12 +- rust/connlib/tunnel/src/peer_handler.rs | 5 +- rust/connlib/tunnel/src/role_state.rs | 148 --------- rust/connlib/tunnel/src/tokio_util.rs | 23 ++ 13 files changed, 686 insertions(+), 574 deletions(-) create mode 100644 rust/connlib/tunnel/src/client.rs create mode 100644 rust/connlib/tunnel/src/gateway.rs delete mode 100644 rust/connlib/tunnel/src/role_state.rs create mode 100644 rust/connlib/tunnel/src/tokio_util.rs diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs new file mode 100644 index 000000000..7ce52de80 --- /dev/null +++ b/rust/connlib/tunnel/src/client.rs @@ -0,0 +1,285 @@ +use crate::device_channel::{create_iface, DeviceIo}; +use crate::ip_packet::IpPacket; +use crate::{ + dns, tokio_util, ConnId, ControlSignal, Device, Event, RoleState, Tunnel, + ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE, +}; +use connlib_shared::error::{ConnlibError as Error, ConnlibError}; +use connlib_shared::messages::{GatewayId, Interface as InterfaceConfig, ResourceDescription}; +use connlib_shared::{Callbacks, DNS_SENTINEL}; +use futures::channel::mpsc::Receiver; +use futures_bounded::{PushError, StreamMap}; +use ip_network::IpNetwork; +use std::collections::HashMap; +use std::io; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; +use std::time::Duration; +use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; + +impl Tunnel +where + C: ControlSignal + Send + Sync + 'static, + CB: Callbacks + 'static, +{ + /// Adds a the given resource to the tunnel. + /// + /// Once added, when a packet for the resource is intercepted a new data channel will be created + /// and packets will be wrapped with wireguard and sent through it. + #[tracing::instrument(level = "trace", skip(self))] + pub async fn add_resource( + self: &Arc, + resource_description: ResourceDescription, + ) -> connlib_shared::Result<()> { + let mut any_valid_route = false; + { + for ip in resource_description.ips() { + if let Err(e) = self.add_route(ip).await { + tracing::warn!(route = %ip, error = ?e, "add_route"); + let _ = self.callbacks().on_error(&e); + } else { + any_valid_route = true; + } + } + } + if !any_valid_route { + return Err(Error::InvalidResource); + } + + let resource_list = { + let mut resources = self.resources.write(); + resources.insert(resource_description); + resources.resource_list() + }; + + self.callbacks.on_update_resources(resource_list)?; + Ok(()) + } + + /// Sets the interface configuration and starts background tasks. + #[tracing::instrument(level = "trace", skip(self))] + pub async fn set_interface( + self: &Arc, + config: &InterfaceConfig, + ) -> connlib_shared::Result<()> { + let device = create_iface(config, self.callbacks()).await?; + *self.device.write().await = Some(device.clone()); + + self.start_timers().await?; + *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log( + &self.callbacks, + device_handler(Arc::clone(self), device), + )); + + self.add_route(DNS_SENTINEL.into()).await?; + + self.callbacks.on_tunnel_ready()?; + + tracing::debug!("background_loop_started"); + + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn add_route(self: &Arc, route: IpNetwork) -> connlib_shared::Result<()> { + let mut device = self.device.write().await; + + if let Some(new_device) = device + .as_ref() + .ok_or(Error::ControlProtocolError)? + .config + .add_route(route, self.callbacks()) + .await? + { + *device = Some(new_device.clone()); + *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log( + &self.callbacks, + device_handler(Arc::clone(self), new_device), + )); + } + + Ok(()) + } + + #[inline(always)] + fn connection_intent(self: &Arc, packet: IpPacket<'_>) { + const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2); + + // We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this + if let Some(resource) = self.get_resource(packet.source()) { + // We have awaiting connection to prevent a race condition where + // create_peer_connection hasn't added the thing to peer_connections + // and we are finding another packet to the same address (otherwise we would just use peer_connections here) + let mut awaiting_connection = self.awaiting_connection.lock(); + let conn_id = ConnId::from(resource.id()); + if awaiting_connection.get(&conn_id).is_none() { + tracing::trace!( + resource_ip = %packet.destination(), + "resource_connection_intent", + ); + + awaiting_connection.insert(conn_id, Default::default()); + let dev = Arc::clone(self); + + let mut connected_gateway_ids: Vec<_> = dev + .gateway_awaiting_connection + .lock() + .clone() + .into_keys() + .collect(); + connected_gateway_ids + .extend(dev.resources_gateways.lock().values().collect::>()); + tracing::trace!( + gateways = ?connected_gateway_ids, + "connected_gateways" + ); + tokio::spawn(async move { + let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY); + loop { + interval.tick().await; + let reference = { + let mut awaiting_connections = dev.awaiting_connection.lock(); + let Some(awaiting_connection) = + awaiting_connections.get_mut(&ConnId::from(resource.id())) + else { + break; + }; + if awaiting_connection.response_received { + break; + } + awaiting_connection.total_attemps += 1; + awaiting_connection.total_attemps + }; + if let Err(e) = dev + .control_signaler + .signal_connection_to(&resource, &connected_gateway_ids, reference) + .await + { + // Not a deadlock because this is a different task + dev.awaiting_connection.lock().remove(&conn_id); + tracing::error!(error = ?e, "start_resource_connection"); + let _ = dev.callbacks.on_error(&e); + } + } + }); + } + } + } +} + +/// Reads IP packets from the [`Device`] and handles them accordingly. +async fn device_handler( + tunnel: Arc>, + mut device: Device, +) -> Result<(), ConnlibError> +where + C: ControlSignal + Send + Sync + 'static, + CB: Callbacks + 'static, +{ + let device_writer = device.io.clone(); + let mut buf = [0u8; MAX_UDP_SIZE]; + loop { + let Some(packet) = device.read().await? else { + return Ok(()); + }; + + if let Some(dns_packet) = dns::parse(&tunnel.resources.read(), packet.as_immutable()) { + if let Err(e) = send_dns_packet(&device_writer, dns_packet) { + tracing::error!(err = %e, "failed to send DNS packet"); + let _ = tunnel.callbacks.on_error(&e.into()); + } + + continue; + } + + let dest = packet.destination(); + + let Some(peer) = tunnel.peer_by_ip(dest) else { + tunnel.connection_intent(packet.as_immutable()); + continue; + }; + + if let Err(e) = tunnel + .encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf) + .await + { + let _ = tunnel.callbacks.on_error(&e); + tracing::error!(err = ?e, "failed to handle packet {e:#}") + } + } +} + +fn send_dns_packet(device_writer: &DeviceIo, packet: dns::Packet) -> io::Result<()> { + match packet { + dns::Packet::Ipv4(r) => device_writer.write4(&r[..])?, + dns::Packet::Ipv6(r) => device_writer.write6(&r[..])?, + }; + + Ok(()) +} + +/// [`Tunnel`] state specific to clients. +pub struct ClientState { + active_candidate_receivers: StreamMap, + /// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway. + waiting_for_sdp_from_gatway: HashMap>, +} + +impl ClientState { + pub fn add_waiting_ice_receiver( + &mut self, + id: GatewayId, + receiver: Receiver, + ) { + self.waiting_for_sdp_from_gatway.insert(id, receiver); + } + + pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) { + let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else { + return; + }; + + match self.active_candidate_receivers.try_push(id, receiver) { + Ok(()) => {} + Err(PushError::BeyondCapacity(_)) => { + tracing::warn!("Too many active ICE candidate receivers at a time") + } + Err(PushError::Replaced(_)) => { + tracing::warn!(%id, "Replaced old ICE candidate receiver with new one") + } + } + } +} + +impl Default for ClientState { + fn default() -> Self { + Self { + active_candidate_receivers: StreamMap::new( + Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), + MAX_CONCURRENT_ICE_GATHERING, + ), + waiting_for_sdp_from_gatway: Default::default(), + } + } +} + +impl RoleState for ClientState { + type Id = GatewayId; + + fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) { + (conn_id, Some(Ok(c))) => { + return Poll::Ready(Event::SignalIceCandidate { + conn_id, + candidate: c, + }) + } + (id, Some(Err(e))) => { + tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") + } + (_, None) => {} + } + } + } +} diff --git a/rust/connlib/tunnel/src/control_protocol.rs b/rust/connlib/tunnel/src/control_protocol.rs index f6b21a077..f3f5d044e 100644 --- a/rust/connlib/tunnel/src/control_protocol.rs +++ b/rust/connlib/tunnel/src/control_protocol.rs @@ -22,8 +22,7 @@ use webrtc::{ }, }; -use crate::role_state::RoleState; -use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, Tunnel}; +use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, RoleState, Tunnel}; mod client; mod gateway; diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index 6ab93ee60..0e34ed189 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -10,8 +10,7 @@ use webrtc::peer_connection::{ RTCPeerConnection, }; -use crate::role_state::GatewayState; -use crate::{ControlSignal, PeerConfig, Tunnel}; +use crate::{ControlSignal, GatewayState, PeerConfig, Tunnel}; #[tracing::instrument(level = "trace", skip(tunnel))] fn handle_connection_state_update( diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs index 3de06cf94..927997908 100644 --- a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs +++ b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs @@ -9,7 +9,7 @@ use tokio::io::{unix::AsyncFd, Interest}; use tun::{IfaceDevice, IfaceStream}; -use crate::Device; +use crate::{Device, MAX_UDP_SIZE}; mod tun; @@ -65,7 +65,11 @@ impl IfaceConfig { iface, mtu: AtomicUsize::new(mtu), }); - Ok(Some(Device { io, config })) + Ok(Some(Device { + io, + config, + buf: Box::new([0u8; MAX_UDP_SIZE]), + })) } } @@ -82,5 +86,9 @@ pub(crate) async fn create_iface( mtu: AtomicUsize::new(mtu), }); - Ok(Device { config, io }) + Ok(Device { + io, + config, + buf: Box::new([0u8; MAX_UDP_SIZE]), + }) } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index c592d0e70..41783c564 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,15 +1,12 @@ -use std::{net::IpAddr, sync::Arc}; - -use crate::{ - ip_packet::{to_dns, IpPacket, MutableIpPacket, Version}, - ControlSignal, Tunnel, -}; -use connlib_shared::{messages::ResourceDescription, Callbacks, DNS_SENTINEL}; +use crate::ip_packet::{to_dns, IpPacket, MutableIpPacket, Version}; +use crate::resource_table::ResourceTable; +use connlib_shared::{messages::ResourceDescription, DNS_SENTINEL}; use domain::base::{ iana::{Class, Rcode, Rtype}, Dname, Message, MessageBuilder, ParsedDname, ToDname, }; use pnet_packet::{udp::MutableUdpPacket, MutablePacket, Packet as UdpPacket, PacketSize}; +use std::net::IpAddr; const DNS_TTL: u32 = 300; const UDP_HEADER_SIZE: usize = 8; @@ -18,7 +15,7 @@ const REVERSE_DNS_ADDRESS_V4: &str = "in-addr"; const REVERSE_DNS_ADDRESS_V6: &str = "ip6"; #[derive(Debug, Clone)] -pub(crate) enum SendPacket { +pub(crate) enum Packet { Ipv4(Vec), Ipv6(Vec), } @@ -28,152 +25,139 @@ pub(crate) enum SendPacket { // as we can therefore we won't do it. // // See: https://stackoverflow.com/a/55093896 -impl Tunnel -where - C: ControlSignal + Send + Sync + 'static, - CB: Callbacks + 'static, -{ - fn build_response( - self: &Arc, - original_buf: &[u8], - mut dns_answer: Vec, - ) -> Option> { - let response_len = dns_answer.len(); - let original_pkt = IpPacket::new(original_buf)?; - let original_dgm = original_pkt.as_udp()?; - let hdr_len = original_pkt.packet_size() - original_dgm.payload().len(); - let mut res_buf = Vec::with_capacity(hdr_len + response_len); - - res_buf.extend_from_slice(&original_buf[..hdr_len]); - res_buf.append(&mut dns_answer); - - let mut pkt = MutableIpPacket::new(&mut res_buf)?; - let dgm_len = UDP_HEADER_SIZE + response_len; - pkt.set_len(hdr_len + response_len, dgm_len); - pkt.swap_src_dst(); - - let mut dgm = MutableUdpPacket::new(pkt.payload_mut())?; - dgm.set_length(dgm_len as u16); - dgm.set_source(original_dgm.get_destination()); - dgm.set_destination(original_dgm.get_source()); - - let mut pkt = MutableIpPacket::new(&mut res_buf)?; - let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?); - pkt.as_udp()?.set_checksum(udp_checksum); - pkt.set_ipv4_checksum(); - Some(res_buf) +pub(crate) fn parse( + resources: &ResourceTable, + packet: IpPacket<'_>, +) -> Option { + let version = packet.version(); + if packet.destination() != IpAddr::from(DNS_SENTINEL) { + return None; } - - fn build_dns_with_answer( - self: &Arc, - message: &Message<[u8]>, - qname: &N, - qtype: Rtype, - resource: &ResourceDescription, - ) -> Option> - where - N: ToDname + ?Sized, - { - let msg_buf = Vec::with_capacity(message.as_slice().len() * 2); - let msg_builder = MessageBuilder::from_target(msg_buf).expect( - "Developer error: we should be always be able to create a MessageBuilder from a Vec", - ); - let mut answer_builder = msg_builder.start_answer(message, Rcode::NoError).ok()?; - match qtype { - Rtype::A => answer_builder - .push(( - qname, - Class::In, - DNS_TTL, - domain::rdata::A::from(resource.ipv4()?), - )) - .ok()?, - Rtype::Aaaa => answer_builder - .push(( - qname, - Class::In, - DNS_TTL, - domain::rdata::Aaaa::from(resource.ipv6()?), - )) - .ok()?, - Rtype::Ptr => answer_builder - .push(( - qname, - Class::In, - DNS_TTL, - domain::rdata::Ptr::>::new( - resource.dns_name()?.parse::>>().ok()?.into(), - ), - )) - .ok()?, - _ => return None, - } - Some(answer_builder.finish()) + let datagram = packet.as_udp()?; + let message = to_dns(&datagram)?; + if message.header().qr() { + return None; } - - pub(crate) fn check_for_dns(self: &Arc, buf: &[u8]) -> Option { - let packet = IpPacket::new(buf)?; - let version = packet.version(); - if packet.destination() != IpAddr::from(DNS_SENTINEL) { - return None; - } - let datagram = packet.as_udp()?; - let message = to_dns(&datagram)?; - if message.header().qr() { - return None; - } - let question = message.first_question()?; - let resource = match question.qtype() { - Rtype::A | Rtype::Aaaa => self - .resources - .read() - .get_by_name(&ToDname::to_cow(question.qname()).to_string()) - .cloned(), - Rtype::Ptr => { - let dns_parts = ToDname::to_cow(question.qname()).to_string(); - let mut dns_parts = dns_parts.split('.').rev(); - if !dns_parts - .next() - .is_some_and(|d| d == REVERSE_DNS_ADDRESS_END) - { - return None; - } - let ip: IpAddr = match dns_parts.next() { - Some(REVERSE_DNS_ADDRESS_V4) => { - let mut ip = [0u8; 4]; - for i in ip.iter_mut() { - *i = dns_parts.next()?.parse().ok()?; - } - ip.into() - } - Some(REVERSE_DNS_ADDRESS_V6) => { - let mut ip = [0u8; 16]; - for i in ip.iter_mut() { - *i = u8::from_str_radix( - &format!("{}{}", dns_parts.next()?, dns_parts.next()?), - 16, - ) - .ok()?; - } - ip.into() - } - _ => return None, - }; - - if dns_parts.next().is_some() { - return None; - } - - self.resources.read().get_by_ip(ip).cloned() + let question = message.first_question()?; + let resource = match question.qtype() { + Rtype::A | Rtype::Aaaa => resources + .get_by_name(&ToDname::to_cow(question.qname()).to_string()) + .cloned(), + Rtype::Ptr => { + let dns_parts = ToDname::to_cow(question.qname()).to_string(); + let mut dns_parts = dns_parts.split('.').rev(); + if !dns_parts + .next() + .is_some_and(|d| d == REVERSE_DNS_ADDRESS_END) + { + return None; } - _ => return None, - }; - let response = - self.build_dns_with_answer(message, question.qname(), question.qtype(), &resource?)?; - let response = self.build_response(buf, response); - response.map(|pkt| match version { - Version::Ipv4 => SendPacket::Ipv4(pkt), - Version::Ipv6 => SendPacket::Ipv6(pkt), - }) - } + let ip: IpAddr = match dns_parts.next() { + Some(REVERSE_DNS_ADDRESS_V4) => { + let mut ip = [0u8; 4]; + for i in ip.iter_mut() { + *i = dns_parts.next()?.parse().ok()?; + } + ip.into() + } + Some(REVERSE_DNS_ADDRESS_V6) => { + let mut ip = [0u8; 16]; + for i in ip.iter_mut() { + *i = u8::from_str_radix( + &format!("{}{}", dns_parts.next()?, dns_parts.next()?), + 16, + ) + .ok()?; + } + ip.into() + } + _ => return None, + }; + + if dns_parts.next().is_some() { + return None; + } + + resources.get_by_ip(ip).cloned() + } + _ => return None, + }; + let response = build_dns_with_answer(message, question.qname(), question.qtype(), &resource?)?; + let response = build_response(packet, response); + response.map(|pkt| match version { + Version::Ipv4 => Packet::Ipv4(pkt), + Version::Ipv6 => Packet::Ipv6(pkt), + }) +} + +fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec) -> Option> { + let response_len = dns_answer.len(); + let original_dgm = original_pkt.as_udp()?; + let hdr_len = original_pkt.packet_size() - original_dgm.payload().len(); + let mut res_buf = Vec::with_capacity(hdr_len + response_len); + + res_buf.extend_from_slice(&original_pkt.packet()[..hdr_len]); + res_buf.append(&mut dns_answer); + + let mut pkt = MutableIpPacket::new(&mut res_buf)?; + let dgm_len = UDP_HEADER_SIZE + response_len; + pkt.set_len(hdr_len + response_len, dgm_len); + pkt.swap_src_dst(); + + let mut dgm = MutableUdpPacket::new(pkt.payload_mut())?; + dgm.set_length(dgm_len as u16); + dgm.set_source(original_dgm.get_destination()); + dgm.set_destination(original_dgm.get_source()); + + let mut pkt = MutableIpPacket::new(&mut res_buf)?; + let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?); + pkt.as_udp()?.set_checksum(udp_checksum); + pkt.set_ipv4_checksum(); + Some(res_buf) +} + +fn build_dns_with_answer( + message: &Message<[u8]>, + qname: &N, + qtype: Rtype, + resource: &ResourceDescription, +) -> Option> +where + N: ToDname + ?Sized, +{ + let msg_buf = Vec::with_capacity(message.as_slice().len() * 2); + let msg_builder = MessageBuilder::from_target(msg_buf).expect( + "Developer error: we should be always be able to create a MessageBuilder from a Vec", + ); + let mut answer_builder = msg_builder.start_answer(message, Rcode::NoError).ok()?; + match qtype { + Rtype::A => answer_builder + .push(( + qname, + Class::In, + DNS_TTL, + domain::rdata::A::from(resource.ipv4()?), + )) + .ok()?, + Rtype::Aaaa => answer_builder + .push(( + qname, + Class::In, + DNS_TTL, + domain::rdata::Aaaa::from(resource.ipv6()?), + )) + .ok()?, + Rtype::Ptr => answer_builder + .push(( + qname, + Class::In, + DNS_TTL, + domain::rdata::Ptr::>::new( + resource.dns_name()?.parse::>>().ok()?.into(), + ), + )) + .ok()?, + _ => return None, + } + Some(answer_builder.finish()) } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs new file mode 100644 index 000000000..4c5d9912e --- /dev/null +++ b/rust/connlib/tunnel/src/gateway.rs @@ -0,0 +1,120 @@ +use crate::device_channel::create_iface; +use crate::{ + ControlSignal, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, + MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE, +}; +use connlib_shared::error::ConnlibError; +use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; +use connlib_shared::Callbacks; +use futures::channel::mpsc::Receiver; +use futures_bounded::{PushError, StreamMap}; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; +use std::time::Duration; +use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; + +impl Tunnel +where + C: ControlSignal + Send + Sync + 'static, + CB: Callbacks + 'static, +{ + /// Sets the interface configuration and starts background tasks. + #[tracing::instrument(level = "trace", skip(self))] + pub async fn set_interface( + self: &Arc, + config: &InterfaceConfig, + ) -> connlib_shared::Result<()> { + let device = create_iface(config, self.callbacks()).await?; + *self.device.write().await = Some(device.clone()); + + self.start_timers().await?; + *self.iface_handler_abort.lock() = + Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle()); + + tracing::debug!("background_loop_started"); + + Ok(()) + } +} + +/// Reads IP packets from the [`Device`] and handles them accordingly. +async fn device_handler( + tunnel: Arc>, + mut device: Device, +) -> Result<(), ConnlibError> +where + C: ControlSignal + Send + Sync + 'static, + CB: Callbacks + 'static, +{ + let mut buf = [0u8; MAX_UDP_SIZE]; + loop { + let Some(packet) = device.read().await? else { + // Reading a bad IP packet or otherwise from the device seems bad. Should we restart the tunnel or something? + return Ok(()); + }; + + let dest = packet.destination(); + + let Some(peer) = tunnel.peer_by_ip(dest) else { + continue; + }; + + if let Err(e) = tunnel + .encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf) + .await + { + tracing::error!(err = ?e, "failed to handle packet {e:#}") + } + } +} + +/// [`Tunnel`] state specific to gateways. +pub struct GatewayState { + candidate_receivers: StreamMap, +} + +impl GatewayState { + pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver) { + match self.candidate_receivers.try_push(id, receiver) { + Ok(()) => {} + Err(PushError::BeyondCapacity(_)) => { + tracing::warn!("Too many active ICE candidate receivers at a time") + } + Err(PushError::Replaced(_)) => { + tracing::warn!(%id, "Replaced old ICE candidate receiver with new one") + } + } + } +} + +impl Default for GatewayState { + fn default() -> Self { + Self { + candidate_receivers: StreamMap::new( + Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), + MAX_CONCURRENT_ICE_GATHERING, + ), + } + } +} + +impl RoleState for GatewayState { + type Id = ClientId; + + fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.candidate_receivers.poll_next_unpin(cx)) { + (conn_id, Some(Ok(c))) => { + return Poll::Ready(Event::SignalIceCandidate { + conn_id, + candidate: c, + }) + } + (id, Some(Err(e))) => { + tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") + } + (_, None) => {} + } + } + } +} diff --git a/rust/connlib/tunnel/src/iface_handler.rs b/rust/connlib/tunnel/src/iface_handler.rs index f28879ea7..ef69e0efc 100644 --- a/rust/connlib/tunnel/src/iface_handler.rs +++ b/rust/connlib/tunnel/src/iface_handler.rs @@ -1,18 +1,10 @@ -use std::{net::IpAddr, sync::Arc, time::Duration}; +use std::{net::IpAddr, sync::Arc}; -use boringtun::noise::{errors::WireGuardError, Tunn, TunnResult}; +use boringtun::noise::{errors::WireGuardError, TunnResult}; use bytes::Bytes; -use connlib_shared::{Callbacks, Error, Result}; +use connlib_shared::{Callbacks, Result}; -use crate::role_state::RoleState; -use crate::{ - device_channel::{DeviceIo, IfaceConfig}, - dns, - peer::EncapsulatedPacket, - ConnId, ControlSignal, Tunnel, MAX_UDP_SIZE, -}; - -const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2); +use crate::{ip_packet::MutableIpPacket, peer::Peer, ControlSignal, RoleState, Tunnel}; impl Tunnel where @@ -21,74 +13,15 @@ where TRoleState: RoleState, { #[inline(always)] - fn connection_intent(self: &Arc, src: &[u8], dst_addr: &IpAddr) { - // We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this - if let Some(resource) = self.get_resource(src) { - // We have awaiting connection to prevent a race condition where - // create_peer_connection hasn't added the thing to peer_connections - // and we are finding another packet to the same address (otherwise we would just use peer_connections here) - let mut awaiting_connection = self.awaiting_connection.lock(); - let conn_id = ConnId::from(resource.id()); - if awaiting_connection.get(&conn_id).is_none() { - tracing::trace!( - resource_ip = %dst_addr, - "resource_connection_intent", - ); - - awaiting_connection.insert(conn_id, Default::default()); - let dev = Arc::clone(self); - - let mut connected_gateway_ids: Vec<_> = dev - .gateway_awaiting_connection - .lock() - .clone() - .into_keys() - .collect(); - connected_gateway_ids - .extend(dev.resources_gateways.lock().values().collect::>()); - tracing::trace!( - gateways = ?connected_gateway_ids, - "connected_gateways" - ); - tokio::spawn(async move { - let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY); - loop { - interval.tick().await; - let reference = { - let mut awaiting_connections = dev.awaiting_connection.lock(); - let Some(awaiting_connection) = - awaiting_connections.get_mut(&ConnId::from(resource.id())) - else { - break; - }; - if awaiting_connection.response_received { - break; - } - awaiting_connection.total_attemps += 1; - awaiting_connection.total_attemps - }; - if let Err(e) = dev - .control_signaler - .signal_connection_to(&resource, &connected_gateway_ids, reference) - .await - { - // Not a deadlock because this is a different task - dev.awaiting_connection.lock().remove(&conn_id); - tracing::error!(error = ?e, "start_resource_connection"); - let _ = dev.callbacks.on_error(&e); - } - } - }); - } - } - } - - #[inline(always)] - async fn handle_encapsulated_packet<'a>( + pub(crate) async fn encapsulate_and_send_to_peer<'a>( &self, - encapsulated_packet: EncapsulatedPacket<'a>, + mut packet: MutableIpPacket<'_>, + peer: Arc, dst_addr: &IpAddr, + buf: &mut [u8], ) -> Result<()> { + let encapsulated_packet = peer.encapsulate(&mut packet, buf)?; + match encapsulated_packet.encapsulate_result { TunnResult::Done => Ok(()), TunnResult::Err(WireGuardError::ConnectionExpired) @@ -130,72 +63,4 @@ where _ => panic!("Unexpected result from encapsulate"), } } - - #[inline(always)] - async fn handle_iface_packet( - self: &Arc, - device_writer: &DeviceIo, - src: &mut [u8], - dst: &mut [u8], - ) -> Result<()> { - if let Some(r) = self.check_for_dns(src) { - match r { - dns::SendPacket::Ipv4(r) => device_writer.write4(&r[..])?, - dns::SendPacket::Ipv6(r) => device_writer.write6(&r[..])?, - }; - return Ok(()); - } - - let dst_addr = match Tunn::dst_address(src) { - Some(addr) => addr, - None => return Err(Error::BadPacket), - }; - - let encapsulated_packet = { - match self.peers_by_ip.read().longest_match(dst_addr).map(|p| p.1) { - Some(peer) => peer.encapsulate(src, dst)?, - None => { - self.connection_intent(src, &dst_addr); - return Ok(()); - } - } - }; - - self.handle_encapsulated_packet(encapsulated_packet, &dst_addr) - .await - } - - #[tracing::instrument(level = "trace", skip(self, iface_config, device_io))] - pub(crate) async fn iface_handler( - self: &Arc, - iface_config: Arc, - device_io: DeviceIo, - ) { - let device_writer = device_io.clone(); - let mut src = [0u8; MAX_UDP_SIZE]; - let mut dst = [0u8; MAX_UDP_SIZE]; - loop { - let res = match device_io.read(&mut src[..iface_config.mtu()]).await { - Ok(res) => res, - Err(e) => { - tracing::error!(err = ?e, "failed to read interface: {e:#}"); - let _ = self.callbacks.on_error(&e.into()); - break; - } - }; - tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface"); - - if res == 0 { - break; - } - - if let Err(e) = self - .handle_iface_packet(&device_writer, &mut src[..res], &mut dst) - .await - { - let _ = self.callbacks.on_error(&e); - tracing::error!(err = ?e, "failed to handle packet {e:#}") - } - } - } } diff --git a/rust/connlib/tunnel/src/ip_packet.rs b/rust/connlib/tunnel/src/ip_packet.rs index 94cd3ca15..74c552238 100644 --- a/rust/connlib/tunnel/src/ip_packet.rs +++ b/rust/connlib/tunnel/src/ip_packet.rs @@ -32,10 +32,20 @@ macro_rules! swap_src_dst { impl<'a> MutableIpPacket<'a> { #[inline] pub(crate) fn new(data: &mut [u8]) -> Option { - match data[0] >> 4 { - 4 => MutableIpv4Packet::new(data).map(Into::into), - 6 => MutableIpv6Packet::new(data).map(Into::into), - _ => None, + let packet = match data[0] >> 4 { + 4 => MutableIpv4Packet::new(data)?.into(), + 6 => MutableIpv6Packet::new(data)?.into(), + _ => return None, + }; + + Some(packet) + } + + #[inline] + pub(crate) fn destination(&self) -> IpAddr { + match self { + MutableIpPacket::MutableIpv4Packet(i) => i.get_destination().into(), + MutableIpPacket::MutableIpv6Packet(i) => i.get_destination().into(), } } @@ -87,6 +97,13 @@ impl<'a> MutableIpPacket<'a> { } } + pub(crate) fn as_immutable(&self) -> IpPacket<'_> { + match self { + Self::MutableIpv4Packet(p) => IpPacket::Ipv4Packet(p.to_immutable()), + Self::MutableIpv6Packet(p) => IpPacket::Ipv6Packet(p.to_immutable()), + } + } + pub(crate) fn as_udp(&mut self) -> Option { self.to_immutable() .is_udp() @@ -174,14 +191,6 @@ pub(crate) enum IpPacket<'a> { } impl<'a> IpPacket<'a> { - pub(crate) fn new(data: &[u8]) -> Option { - match data[0] >> 4 { - 4 => Ipv4Packet::new(data).map(Into::into), - 6 => Ipv6Packet::new(data).map(Into::into), - _ => None, - } - } - pub(crate) fn version(&self) -> Version { match self { IpPacket::Ipv4Packet(_) => Version::Ipv4, @@ -214,13 +223,6 @@ impl<'a> IpPacket<'a> { .flatten() } - pub(crate) fn destination(&self) -> IpAddr { - match self { - Self::Ipv4Packet(p) => p.get_destination().into(), - Self::Ipv6Packet(p) => p.get_destination().into(), - } - } - pub(crate) fn source(&self) -> IpAddr { match self { Self::Ipv4Packet(p) => p.get_source().into(), @@ -228,6 +230,13 @@ impl<'a> IpPacket<'a> { } } + pub(crate) fn destination(&self) -> IpAddr { + match self { + Self::Ipv4Packet(p) => p.get_destination().into(), + Self::Ipv6Packet(p) => p.get_destination().into(), + } + } + pub(crate) fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 { match self { Self::Ipv4Packet(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()), diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index ae083aab5..6a9bf1037 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -3,12 +3,12 @@ //! This is both the wireguard and ICE implementation that should work in tandem. //! [Tunnel] is the main entry-point for this crate. use boringtun::{ - noise::{errors::WireGuardError, rate_limiter::RateLimiter, Tunn, TunnResult}, + noise::{errors::WireGuardError, rate_limiter::RateLimiter, TunnResult}, x25519::{PublicKey, StaticSecret}, }; use bytes::Bytes; -use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error, DNS_SENTINEL}; +use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use serde::{Deserialize, Serialize}; @@ -29,7 +29,7 @@ use webrtc::{ }; use std::task::{Context, Poll}; -use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration}; +use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration}; use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; use connlib_shared::{ @@ -39,19 +39,22 @@ use connlib_shared::{ Result, }; -use device_channel::{create_iface, DeviceIo, IfaceConfig}; +use device_channel::{DeviceIo, IfaceConfig}; +pub use client::ClientState; pub use control_protocol::Request; -pub use role_state::{ClientState, GatewayState}; +pub use gateway::GatewayState; pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use crate::role_state::RoleState; +use crate::ip_packet::MutableIpPacket; use connlib_shared::messages::SecretKey; use index::IndexLfsr; +mod client; mod control_protocol; mod device_channel; mod dns; +mod gateway; mod iface_handler; mod index; mod ip_packet; @@ -59,12 +62,23 @@ mod peer; mod peer_handler; mod resource_sender; mod resource_table; -mod role_state; +mod tokio_util; const MAX_UDP_SIZE: usize = (1 << 16) - 1; const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1); const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1); const REFRESH_MTU_INTERVAL: Duration = Duration::from_secs(30); +/// For how long we will attempt to gather ICE candidates before aborting. +/// +/// Chosen arbitrarily. +/// Very likely, the actual WebRTC connection will timeout before this. +/// This timeout is just here to eventually clean-up tasks if they are somehow broken. +const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60; + +/// How many concurrent ICE gathering attempts we are allow. +/// +/// Chosen arbitrarily. +const MAX_CONCURRENT_ICE_GATHERING: usize = 100; // Note: Taken from boringtun const HANDSHAKE_RATE_LIMIT: u64 = 100; @@ -149,8 +163,30 @@ struct AwaitingConnectionDetails { #[derive(Clone)] struct Device { - pub config: Arc, - pub io: DeviceIo, + config: Arc, + io: DeviceIo, + + buf: Box<[u8; MAX_UDP_SIZE]>, +} + +impl Device { + async fn read(&mut self) -> io::Result>> { + let res = self.io.read(&mut self.buf[..self.config.mtu()]).await?; + tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface"); + + if res == 0 { + return Ok(None); + } + + Ok(Some( + MutableIpPacket::new(&mut self.buf[..res]).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "received bytes are not an IP packet", + ) + })?, + )) + } } // TODO: We should use newtypes for each kind of Id @@ -243,6 +279,14 @@ where pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll> { self.role_state.lock().poll_next_event(cx) } + + pub(crate) fn peer_by_ip(&self, ip: IpAddr) -> Option> { + self.peers_by_ip + .read() + .longest_match(ip) + .map(|(_, peer)| peer) + .cloned() + } } pub enum Event { @@ -327,86 +371,6 @@ where }) } - #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_route(self: &Arc, route: IpNetwork) -> Result<()> { - let mut device = self.device.write().await; - - if let Some(new_device) = device - .as_ref() - .ok_or(Error::ControlProtocolError)? - .config - .add_route(route, self.callbacks()) - .await? - { - *device = Some(new_device.clone()); - let dev = Arc::clone(self); - self.iface_handler_abort.lock().replace( - tokio::spawn( - async move { dev.iface_handler(new_device.config, new_device.io).await }, - ) - .abort_handle(), - ); - } - - Ok(()) - } - - /// Adds a the given resource to the tunnel. - /// - /// Once added, when a packet for the resource is intercepted a new data channel will be created - /// and packets will be wrapped with wireguard and sent through it. - #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_resource( - self: &Arc, - resource_description: ResourceDescription, - ) -> Result<()> { - let mut any_valid_route = false; - { - for ip in resource_description.ips() { - if let Err(e) = self.add_route(ip).await { - tracing::warn!(route = %ip, error = ?e, "add_route"); - let _ = self.callbacks().on_error(&e); - } else { - any_valid_route = true; - } - } - } - if !any_valid_route { - return Err(Error::InvalidResource); - } - - let resource_list = { - let mut resources = self.resources.write(); - resources.insert(resource_description); - resources.resource_list() - }; - - self.callbacks.on_update_resources(resource_list)?; - Ok(()) - } - - /// Sets the interface configuration and starts background tasks. - #[tracing::instrument(level = "trace", skip(self))] - pub async fn set_interface(self: &Arc, config: &InterfaceConfig) -> Result<()> { - let device = create_iface(config, self.callbacks()).await?; - *self.device.write().await = Some(device.clone()); - - self.start_timers().await?; - let dev = Arc::clone(self); - *self.iface_handler_abort.lock() = Some( - tokio::spawn(async move { dev.iface_handler(device.config, device.io).await }) - .abort_handle(), - ); - - self.add_route(DNS_SENTINEL.into()).await?; - - self.callbacks.on_tunnel_ready()?; - - tracing::debug!("background_loop_started"); - - Ok(()) - } - #[tracing::instrument(level = "trace", skip(self))] async fn stop_peer(&self, index: u32, conn_id: ConnId) { self.peers_by_ip.write().retain(|_, p| p.index != index); @@ -537,8 +501,7 @@ where Ok(()) } - fn get_resource(&self, buff: &[u8]) -> Option { - let addr = Tunn::dst_address(buff)?; + fn get_resource(&self, addr: IpAddr) -> Option { let resources = self.resources.read(); match addr { IpAddr::V4(ipv4) => resources.get_by_ip(ipv4).cloned(), @@ -554,3 +517,13 @@ where &self.callbacks } } + +/// Dedicated trait for abstracting over the different ICE states. +/// +/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`]. +/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`]. +pub trait RoleState: Default + Send + 'static { + type Id: fmt::Debug; + + fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>; +} diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 9df880108..034ea7832 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -10,6 +10,7 @@ use connlib_shared::{ use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use parking_lot::{Mutex, RwLock}; +use pnet_packet::MutablePacket; use webrtc::data::data_channel::DataChannel; use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId}; @@ -194,14 +195,9 @@ impl Peer { pub(crate) fn encapsulate<'a>( &self, - src: &'a mut [u8], + packet: &mut MutableIpPacket<'a>, dst: &'a mut [u8], ) -> Result> { - let Some(mut packet) = MutableIpPacket::new(src) else { - debug_assert!(false, "Got non-ip packet from the tunnel interface"); - tracing::error!("Developer error: we should never see a packet through the tunnel wire that isn't ip"); - return Err(Error::BadPacket); - }; if let Some(resource) = self.get_translation(packet.to_immutable().source()) { let ResourceDescription::Dns(resource) = resource else { tracing::error!( @@ -210,7 +206,7 @@ impl Peer { return Err(Error::ControlProtocolError); }; - match &mut packet { + match packet { MutableIpPacket::MutableIpv4Packet(ref mut p) => p.set_source(resource.ipv4), MutableIpPacket::MutableIpv6Packet(ref mut p) => p.set_source(resource.ipv6), } @@ -221,7 +217,7 @@ impl Peer { index: self.index, conn_id: self.conn_id, channel: self.channel.clone(), - encapsulate_result: self.tunnel.lock().encapsulate(src, dst), + encapsulate_result: self.tunnel.lock().encapsulate(packet.packet_mut(), dst), }) } diff --git a/rust/connlib/tunnel/src/peer_handler.rs b/rust/connlib/tunnel/src/peer_handler.rs index d475d009b..ac75be900 100644 --- a/rust/connlib/tunnel/src/peer_handler.rs +++ b/rust/connlib/tunnel/src/peer_handler.rs @@ -4,10 +4,9 @@ use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult}; use bytes::Bytes; use connlib_shared::{Callbacks, Error, Result}; -use crate::role_state::RoleState; use crate::{ - device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, Tunnel, - MAX_UDP_SIZE, + device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, RoleState, + Tunnel, MAX_UDP_SIZE, }; impl Tunnel diff --git a/rust/connlib/tunnel/src/role_state.rs b/rust/connlib/tunnel/src/role_state.rs deleted file mode 100644 index 0fa60c018..000000000 --- a/rust/connlib/tunnel/src/role_state.rs +++ /dev/null @@ -1,148 +0,0 @@ -use crate::Event; -use connlib_shared::messages::{ClientId, GatewayId}; -use futures::channel::mpsc::Receiver; -use futures_bounded::{PushError, StreamMap}; -use std::collections::HashMap; -use std::fmt; -use std::task::{ready, Context, Poll}; -use std::time::Duration; -use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; - -/// Dedicated trait for abstracting over the different ICE states. -/// -/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`]. -/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`]. -pub trait RoleState: Default + Send + 'static { - type Id: fmt::Debug; - - fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>; -} - -/// For how long we will attempt to gather ICE candidates before aborting. -/// -/// Chosen arbitrarily. -/// Very likely, the actual WebRTC connection will timeout before this. -/// This timeout is just here to eventually clean-up tasks if they are somehow broken. -const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60; - -/// How many concurrent ICE gathering attempts we are allow. -/// -/// Chosen arbitrarily. -const MAX_CONCURRENT_ICE_GATHERING: usize = 100; - -/// [`Tunnel`](crate::Tunnel) state specific to clients. -pub struct ClientState { - active_candidate_receivers: StreamMap, - /// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway. - waiting_for_sdp_from_gatway: HashMap>, -} - -impl ClientState { - pub fn add_waiting_ice_receiver( - &mut self, - id: GatewayId, - receiver: Receiver, - ) { - self.waiting_for_sdp_from_gatway.insert(id, receiver); - } - - pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) { - let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else { - return; - }; - - match self.active_candidate_receivers.try_push(id, receiver) { - Ok(()) => {} - Err(PushError::BeyondCapacity(_)) => { - tracing::warn!("Too many active ICE candidate receivers at a time") - } - Err(PushError::Replaced(_)) => { - tracing::warn!(%id, "Replaced old ICE candidate receiver with new one") - } - } - } -} - -impl Default for ClientState { - fn default() -> Self { - Self { - active_candidate_receivers: StreamMap::new( - Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), - MAX_CONCURRENT_ICE_GATHERING, - ), - waiting_for_sdp_from_gatway: Default::default(), - } - } -} - -impl RoleState for ClientState { - type Id = GatewayId; - - fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) { - (conn_id, Some(Ok(c))) => { - return Poll::Ready(Event::SignalIceCandidate { - conn_id, - candidate: c, - }) - } - (id, Some(Err(e))) => { - tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") - } - (_, None) => {} - } - } - } -} - -/// [`Tunnel`](crate::Tunnel) state specific to gateways. -pub struct GatewayState { - candidate_receivers: StreamMap, -} - -impl GatewayState { - pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver) { - match self.candidate_receivers.try_push(id, receiver) { - Ok(()) => {} - Err(PushError::BeyondCapacity(_)) => { - tracing::warn!("Too many active ICE candidate receivers at a time") - } - Err(PushError::Replaced(_)) => { - tracing::warn!(%id, "Replaced old ICE candidate receiver with new one") - } - } - } -} - -impl Default for GatewayState { - fn default() -> Self { - Self { - candidate_receivers: StreamMap::new( - Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), - MAX_CONCURRENT_ICE_GATHERING, - ), - } - } -} - -impl RoleState for GatewayState { - type Id = ClientId; - - fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match ready!(self.candidate_receivers.poll_next_unpin(cx)) { - (conn_id, Some(Ok(c))) => { - return Poll::Ready(Event::SignalIceCandidate { - conn_id, - candidate: c, - }) - } - (id, Some(Err(e))) => { - tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") - } - (_, None) => {} - } - } - } -} diff --git a/rust/connlib/tunnel/src/tokio_util.rs b/rust/connlib/tunnel/src/tokio_util.rs new file mode 100644 index 000000000..8c0840ba3 --- /dev/null +++ b/rust/connlib/tunnel/src/tokio_util.rs @@ -0,0 +1,23 @@ +use connlib_shared::error::ConnlibError; +use connlib_shared::Callbacks; +use std::future::Future; + +/// Spawns a task into the [`tokio`] runtime. +/// +/// On error, [`Callbacks::on_error`] is invoked. +/// This also returns a [`tokio::task::AbortHandle`] which MAY be used to abort the task. +/// If you don't need it, you are free to drop it. +/// It won't terminate the task. +pub(crate) fn spawn_log( + cb: &(impl Callbacks + 'static), + f: impl Future> + Send + 'static, +) -> tokio::task::AbortHandle { + let cb = cb.clone(); + + tokio::spawn(async move { + if let Err(e) = f.await { + let _ = cb.on_error(&e); + } + }) + .abort_handle() +}