diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 4b5243878..908037904 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -792,6 +792,7 @@ dependencies = [ name = "connlib-client-shared" version = "1.20231001.0" dependencies = [ + "anyhow", "async-trait", "backoff", "chrono", @@ -1245,7 +1246,7 @@ dependencies = [ "connlib-shared", "firezone-tunnel", "futures", - "futures-bounded", + "futures-bounded 0.1.0", "headless-utils", "phoenix-channel", "secrecy", @@ -1283,6 +1284,7 @@ dependencies = [ "connlib-shared", "domain", "futures", + "futures-bounded 0.2.0", "futures-util", "ip_network", "ip_network_table", @@ -1345,6 +1347,15 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-bounded" +version = "0.2.0" +source = "git+https://github.com/libp2p/rust-libp2p?branch=feat/stream-map#1e4ad64558159dfc94b50daf701b3ee7315553b9" +dependencies = [ + "futures-timer", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.28" diff --git a/rust/connlib/clients/shared/Cargo.toml b/rust/connlib/clients/shared/Cargo.toml index dcabcaa2e..bd2dd9235 100644 --- a/rust/connlib/clients/shared/Cargo.toml +++ b/rust/connlib/clients/shared/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" mock = ["connlib-shared/mock"] [dependencies] +anyhow = "1.0.75" tokio = { version = "1.32", default-features = false, features = ["sync", "rt"] } tokio-util = "0.7.9" secrecy = { workspace = true } diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 8b02445e5..ed22a1f03 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -9,13 +9,12 @@ use connlib_shared::{ control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference}, messages::{GatewayId, ResourceDescription, ResourceId}, Callbacks, - Error::{self, ControlProtocolError}, + Error::{self}, Result, }; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; use async_trait::async_trait; -use firezone_tunnel::{ConnId, ControlSignal, Request, Tunnel}; +use firezone_tunnel::{ClientState, ControlSignal, Request, Tunnel}; use tokio::sync::Mutex; use tokio_util::codec::{BytesCodec, FramedRead}; use url::Url; @@ -41,35 +40,10 @@ impl ControlSignal for ControlSignaler { .await?; Ok(()) } - - async fn signal_ice_candidate( - &self, - ice_candidate: RTCIceCandidate, - conn_id: ConnId, - ) -> Result<()> { - // TODO: We probably want to have different signal_ice_candidate - // functions for gateway/client but ultimately we just want - // separate control_plane modules - if let ConnId::Gateway(id) = conn_id { - self.control_signal - .clone() - .send(EgressMessages::BroadcastIceCandidates( - BroadcastGatewayIceCandidates { - gateway_ids: vec![id], - candidates: vec![ice_candidate.to_json()?], - }, - )) - .await?; - - Ok(()) - } else { - Err(ControlProtocolError) - } - } } pub struct ControlPlane { - pub tunnel: Arc>, + pub tunnel: Arc>, pub control_signaler: ControlSignaler, pub tunnel_init: Mutex, } @@ -301,6 +275,26 @@ impl ControlPlane { .send(EgressMessages::CreateLogSink {}) .await; } + + pub async fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event) { + match event { + firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => { + if let Err(e) = self + .control_signaler + .control_signal + .send(EgressMessages::BroadcastIceCandidates( + BroadcastGatewayIceCandidates { + gateway_ids: vec![conn_id], + candidates: vec![candidate], + }, + )) + .await + { + tracing::error!("Failed to signal ICE candidate: {e}") + } + } + } + } } async fn upload(path: PathBuf, url: Url) -> io::Result<()> { diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 094b88a51..8d9889ff6 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -165,6 +165,7 @@ where tokio::spawn(async move { let mut log_stats_interval = tokio::time::interval(Duration::from_secs(10)); let mut upload_logs_interval = upload_interval(); + loop { tokio::select! { Some((msg, reference)) = control_plane_receiver.recv() => { @@ -173,6 +174,7 @@ where Err(err) => control_plane.handle_error(err, reference).await, } }, + event = control_plane.tunnel.next_event() => control_plane.handle_tunnel_event(event).await, _ = log_stats_interval.tick() => control_plane.stats_event().await, _ = upload_logs_interval.tick() => control_plane.request_log_upload_url().await, else => break diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index 0a958f0a1..65ede6c42 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -49,6 +49,12 @@ impl fmt::Display for ClientId { } } +impl fmt::Display for GatewayId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + /// Represents a wireguard peer. #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Peer { diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 7c6e17608..40889b95a 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -25,6 +25,7 @@ domain = "0.8" boringtun = { workspace = true } chrono = { workspace = true } pnet_packet = { version = "0.34" } +futures-bounded = { git = "https://github.com/libp2p/rust-libp2p", branch = "feat/stream-map" } # TODO: research replacing for https://github.com/algesten/str0m webrtc = { version = "0.8" } diff --git a/rust/connlib/tunnel/src/control_protocol.rs b/rust/connlib/tunnel/src/control_protocol.rs index 027bab19e..f6b21a077 100644 --- a/rust/connlib/tunnel/src/control_protocol.rs +++ b/rust/connlib/tunnel/src/control_protocol.rs @@ -1,5 +1,7 @@ use boringtun::noise::Tunn; use chrono::{DateTime, Utc}; +use futures::channel::mpsc; +use futures_util::SinkExt; use secrecy::ExposeSecret; use std::sync::Arc; use tracing::instrument; @@ -20,6 +22,7 @@ use webrtc::{ }, }; +use crate::role_state::RoleState; use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, Tunnel}; mod client; @@ -35,14 +38,15 @@ pub enum Request { } #[tracing::instrument(level = "trace", skip(tunnel))] -async fn handle_connection_state_update_with_peer( - tunnel: &Arc>, +async fn handle_connection_state_update_with_peer( + tunnel: &Arc>, state: RTCPeerConnectionState, index: u32, conn_id: ConnId, ) where C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { tracing::trace!(?state, "peer_state_update"); if state == RTCPeerConnectionState::Failed { @@ -51,14 +55,15 @@ async fn handle_connection_state_update_with_peer( } #[tracing::instrument(level = "trace", skip(tunnel))] -fn set_connection_state_with_peer( - tunnel: &Arc>, +fn set_connection_state_with_peer( + tunnel: &Arc>, peer_connection: &Arc, index: u32, conn_id: ConnId, ) where C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { let tunnel = Arc::clone(tunnel); peer_connection.on_peer_connection_state_change(Box::new( @@ -71,10 +76,11 @@ fn set_connection_state_with_peer( )); } -impl Tunnel +impl Tunnel where C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { #[instrument(level = "trace", skip(self, data_channel, peer_config))] async fn handle_channel_open( @@ -152,11 +158,10 @@ where } #[tracing::instrument(level = "trace", skip(self))] - async fn initialize_peer_request( + pub async fn new_peer_connection( self: &Arc, relays: Vec, - conn_id: ConnId, - ) -> Result> { + ) -> Result<(Arc, mpsc::Receiver)> { let config = RTCConfiguration { ice_servers: relays .into_iter() @@ -176,50 +181,33 @@ where .collect(), ..Default::default() }; + let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?); - let (ice_candidate_tx, ice_candidate_rx) = tokio::sync::mpsc::channel(ICE_CANDIDATE_BUFFER); - self.ice_candidate_queue - .lock() - .insert(conn_id, ice_candidate_rx); + let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER); - let callbacks = self.callbacks().clone(); peer_connection.on_ice_candidate(Box::new(move |candidate| { - let ice_candidate_tx = ice_candidate_tx.clone(); - let callbacks = callbacks.clone(); + let Some(candidate) = candidate else { + return Box::pin(async {}); + }; + + let mut ice_candidate_tx = ice_candidate_tx.clone(); Box::pin(async move { - if let Err(e) = ice_candidate_tx.send(candidate).await { - tracing::error!(err = ?e, "buffer_ice_candidate"); - let _ = callbacks.on_error(&e.into()); + let ice_candidate = match candidate.to_json() { + Ok(ice_candidate) => ice_candidate, + Err(e) => { + tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",); + return; + } + }; + + if ice_candidate_tx.send(ice_candidate).await.is_err() { + debug_assert!(false, "receiver was dropped before sender") } }) })); - Ok(peer_connection) - } - - fn start_ice_candidate_handler(&self, conn_id: ConnId) -> Result<()> { - let mut ice_candidate_rx = self - .ice_candidate_queue - .lock() - .remove(&conn_id) - .ok_or(Error::ControlProtocolError)?; - let control_signaler = self.control_signaler.clone(); - let callbacks = self.callbacks().clone(); - - tokio::spawn(async move { - while let Some(ice_candidate) = ice_candidate_rx.recv().await.flatten() { - if let Err(e) = control_signaler - .signal_ice_candidate(ice_candidate, conn_id) - .await - { - tracing::error!(err = ?e, "add_ice_candidate"); - let _ = callbacks.on_error(&e); - } - } - }); - - Ok(()) + Ok((peer_connection, ice_candidate_rx)) } pub async fn add_ice_candidate( diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 151d482e4..e5c006c5d 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -1,14 +1,10 @@ use std::sync::Arc; use boringtun::x25519::{PublicKey, StaticSecret}; -use chrono::{DateTime, Utc}; use connlib_shared::messages::SecretKey; use connlib_shared::{ control::Reference, - messages::{ - ClientId, GatewayId, Key, Relay, RequestConnection, ResourceDescription, ResourceId, - ReuseConnection, - }, + messages::{GatewayId, Key, Relay, RequestConnection, ResourceId, ReuseConnection}, Callbacks, }; use rand_core::OsRng; @@ -21,11 +17,11 @@ use webrtc::{ }, }; -use crate::{ControlSignal, Error, PeerConfig, Request, Result, Tunnel}; +use crate::{ClientState, ControlSignal, Error, PeerConfig, Request, Result, Tunnel}; #[tracing::instrument(level = "trace", skip(tunnel))] fn handle_connection_state_update( - tunnel: &Arc>, + tunnel: &Arc>, state: RTCPeerConnectionState, gateway_id: GatewayId, resource_id: ResourceId, @@ -49,7 +45,7 @@ fn handle_connection_state_update( #[tracing::instrument(level = "trace", skip(tunnel))] fn set_connection_state_update( - tunnel: &Arc>, + tunnel: &Arc>, peer_connection: &Arc, gateway_id: GatewayId, resource_id: ResourceId, @@ -68,7 +64,7 @@ fn set_connection_state_update( )); } -impl Tunnel +impl Tunnel where C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, @@ -163,10 +159,11 @@ where } } let peer_connection = { - let peer_connection = Arc::new( - self.initialize_peer_request(relays, gateway_id.into()) - .await?, - ); + let (peer_connection, receiver) = self.new_peer_connection(relays).await?; + self.role_state + .lock() + .add_waiting_ice_receiver(gateway_id, receiver); + let peer_connection = Arc::new(peer_connection); let mut peer_connections = self.peer_connections.lock(); peer_connections.insert(gateway_id.into(), Arc::clone(&peer_connection)); peer_connection @@ -279,24 +276,10 @@ where .insert(gateway_id, gateway_public_key); peer_connection.set_remote_description(rtc_sdp).await?; - self.start_ice_candidate_handler(gateway_id.into())?; + self.role_state + .lock() + .activate_ice_candidate_receiver(gateway_id); Ok(()) } - - pub fn allow_access( - &self, - resource: ResourceDescription, - client_id: ClientId, - expires_at: DateTime, - ) { - if let Some(peer) = self - .peers_by_ip - .write() - .iter_mut() - .find_map(|(_, p)| (p.conn_id == client_id.into()).then_some(p)) - { - peer.add_resource(resource, expires_at); - } - } } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index 1bc781a02..6ab93ee60 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -10,11 +10,12 @@ use webrtc::peer_connection::{ RTCPeerConnection, }; +use crate::role_state::GatewayState; use crate::{ControlSignal, PeerConfig, Tunnel}; #[tracing::instrument(level = "trace", skip(tunnel))] fn handle_connection_state_update( - tunnel: &Arc>, + tunnel: &Arc>, state: RTCPeerConnectionState, client_id: ClientId, ) where @@ -29,7 +30,7 @@ fn handle_connection_state_update( #[tracing::instrument(level = "trace", skip(tunnel))] fn set_connection_state_update( - tunnel: &Arc>, + tunnel: &Arc>, peer_connection: &Arc, client_id: ClientId, ) where @@ -45,7 +46,7 @@ fn set_connection_state_update( )); } -impl Tunnel +impl Tunnel where C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, @@ -72,10 +73,10 @@ where expires_at: DateTime, resource: ResourceDescription, ) -> Result { - let peer_connection = self - .initialize_peer_request(relays, client_id.into()) - .await?; - self.start_ice_candidate_handler(client_id.into())?; + let (peer_connection, receiver) = self.new_peer_connection(relays).await?; + self.role_state + .lock() + .add_new_ice_receiver(client_id, receiver); let index = self.next_index(); let tunnel = Arc::clone(self); @@ -150,4 +151,20 @@ where Ok(local_desc) } + + pub fn allow_access( + &self, + resource: ResourceDescription, + client_id: ClientId, + expires_at: DateTime, + ) { + if let Some(peer) = self + .peers_by_ip + .write() + .iter_mut() + .find_map(|(_, p)| (p.conn_id == client_id.into()).then_some(p)) + { + peer.add_resource(resource, expires_at); + } + } } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 7b802833d..c592d0e70 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -28,7 +28,7 @@ pub(crate) enum SendPacket { // as we can therefore we won't do it. // // See: https://stackoverflow.com/a/55093896 -impl Tunnel +impl Tunnel where C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, diff --git a/rust/connlib/tunnel/src/iface_handler.rs b/rust/connlib/tunnel/src/iface_handler.rs index 4c6a120dd..f28879ea7 100644 --- a/rust/connlib/tunnel/src/iface_handler.rs +++ b/rust/connlib/tunnel/src/iface_handler.rs @@ -4,6 +4,7 @@ use boringtun::noise::{errors::WireGuardError, Tunn, TunnResult}; use bytes::Bytes; use connlib_shared::{Callbacks, Error, Result}; +use crate::role_state::RoleState; use crate::{ device_channel::{DeviceIo, IfaceConfig}, dns, @@ -13,10 +14,11 @@ use crate::{ const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2); -impl Tunnel +impl Tunnel where C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { #[inline(always)] fn connection_intent(self: &Arc, src: &[u8], dst_addr: &IpAddr) { diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 10c7e4e56..ae083aab5 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -24,12 +24,13 @@ use webrtc::{ interceptor_registry::register_default_interceptors, media_engine::MediaEngine, setting_engine::SettingEngine, APIBuilder, API, }, - ice_transport::ice_candidate::RTCIceCandidate, interceptor::registry::Registry, peer_connection::RTCPeerConnection, }; -use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration}; +use std::task::{Context, Poll}; +use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration}; +use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; use connlib_shared::{ messages::{ @@ -41,8 +42,10 @@ use connlib_shared::{ use device_channel::{create_iface, DeviceIo, IfaceConfig}; pub use control_protocol::Request; +pub use role_state::{ClientState, GatewayState}; pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use crate::role_state::RoleState; use connlib_shared::messages::SecretKey; use index::IndexLfsr; @@ -56,6 +59,7 @@ mod peer; mod peer_handler; mod resource_sender; mod resource_table; +mod role_state; const MAX_UDP_SIZE: usize = (1 << 16) - 1; const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1); @@ -90,6 +94,16 @@ impl From for ConnId { } } +impl fmt::Display for ConnId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ConnId::Gateway(inner) => fmt::Display::fmt(inner, f), + ConnId::Client(inner) => fmt::Display::fmt(inner, f), + ConnId::Resource(inner) => fmt::Display::fmt(inner, f), + } + } +} + /// Represent's the tunnel actual peer's config /// Obtained from connlib_shared's Peer #[derive(Clone)] @@ -125,13 +139,6 @@ pub trait ControlSignal { connected_gateway_ids: &[GatewayId], reference: usize, ) -> Result<()>; - - /// Signals a new candidate to the control plane - async fn signal_ice_candidate( - &self, - ice_candidate: RTCIceCandidate, - conn_id: ConnId, - ) -> Result<()>; } #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -149,7 +156,7 @@ struct Device { // TODO: We should use newtypes for each kind of Id /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets /// to communicate between peers. -pub struct Tunnel { +pub struct Tunnel { next_index: Mutex, // We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact device: tokio::sync::RwLock>, @@ -158,8 +165,6 @@ pub struct Tunnel { public_key: PublicKey, peers_by_ip: RwLock>>, peer_connections: Mutex>>, - ice_candidate_queue: - Mutex>>>, awaiting_connection: Mutex>, gateway_awaiting_connection: Mutex>>, resources_gateways: Mutex>, @@ -169,6 +174,9 @@ pub struct Tunnel { gateway_public_keys: Mutex>, callbacks: CallbackErrorFacade, iface_handler_abort: Mutex>, + + /// State that differs per role, i.e. clients vs gateways. + role_state: Mutex, } // TODO: For now we only use these fields with debug @@ -187,10 +195,11 @@ pub struct TunnelStats { gateway_awaiting_connection: HashMap>, } -impl Tunnel +impl Tunnel where C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { pub fn stats(&self) -> TunnelStats { let peers_by_ip = self @@ -226,12 +235,28 @@ where gateway_public_keys, } } + + pub async fn next_event(&self) -> Event { + std::future::poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll> { + self.role_state.lock().poll_next_event(cx) + } } -impl Tunnel +pub enum Event { + SignalIceCandidate { + conn_id: TId, + candidate: RTCIceCandidateInit, + }, +} + +impl Tunnel where C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { /// Creates a new tunnel. /// @@ -255,7 +280,6 @@ where let resources_gateways = Default::default(); let gateway_awaiting_connection = Default::default(); let device = Default::default(); - let ice_candidate_queue = Default::default(); let iface_handler_abort = Default::default(); // ICE @@ -297,9 +321,9 @@ where gateway_awaiting_connection, control_signaler, resources_gateways, - ice_candidate_queue, callbacks: CallbackErrorFacade(callbacks), iface_handler_abort, + role_state: Default::default(), }) } diff --git a/rust/connlib/tunnel/src/peer_handler.rs b/rust/connlib/tunnel/src/peer_handler.rs index d6c57a155..d475d009b 100644 --- a/rust/connlib/tunnel/src/peer_handler.rs +++ b/rust/connlib/tunnel/src/peer_handler.rs @@ -4,15 +4,17 @@ use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult}; use bytes::Bytes; use connlib_shared::{Callbacks, Error, Result}; +use crate::role_state::RoleState; use crate::{ device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, Tunnel, MAX_UDP_SIZE, }; -impl Tunnel +impl Tunnel where C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, + TRoleState: RoleState, { #[inline(always)] fn is_wireguard_packet_ok(&self, parsed_packet: &Packet, peer: &Peer) -> bool { diff --git a/rust/connlib/tunnel/src/resource_sender.rs b/rust/connlib/tunnel/src/resource_sender.rs index f0eb9d8d0..f164be770 100644 --- a/rust/connlib/tunnel/src/resource_sender.rs +++ b/rust/connlib/tunnel/src/resource_sender.rs @@ -9,7 +9,7 @@ use crate::{ use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result}; -impl Tunnel +impl Tunnel where C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, diff --git a/rust/connlib/tunnel/src/role_state.rs b/rust/connlib/tunnel/src/role_state.rs new file mode 100644 index 000000000..0fa60c018 --- /dev/null +++ b/rust/connlib/tunnel/src/role_state.rs @@ -0,0 +1,148 @@ +use crate::Event; +use connlib_shared::messages::{ClientId, GatewayId}; +use futures::channel::mpsc::Receiver; +use futures_bounded::{PushError, StreamMap}; +use std::collections::HashMap; +use std::fmt; +use std::task::{ready, Context, Poll}; +use std::time::Duration; +use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; + +/// Dedicated trait for abstracting over the different ICE states. +/// +/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`]. +/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`]. +pub trait RoleState: Default + Send + 'static { + type Id: fmt::Debug; + + fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>; +} + +/// For how long we will attempt to gather ICE candidates before aborting. +/// +/// Chosen arbitrarily. +/// Very likely, the actual WebRTC connection will timeout before this. +/// This timeout is just here to eventually clean-up tasks if they are somehow broken. +const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60; + +/// How many concurrent ICE gathering attempts we are allow. +/// +/// Chosen arbitrarily. +const MAX_CONCURRENT_ICE_GATHERING: usize = 100; + +/// [`Tunnel`](crate::Tunnel) state specific to clients. +pub struct ClientState { + active_candidate_receivers: StreamMap, + /// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway. + waiting_for_sdp_from_gatway: HashMap>, +} + +impl ClientState { + pub fn add_waiting_ice_receiver( + &mut self, + id: GatewayId, + receiver: Receiver, + ) { + self.waiting_for_sdp_from_gatway.insert(id, receiver); + } + + pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) { + let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else { + return; + }; + + match self.active_candidate_receivers.try_push(id, receiver) { + Ok(()) => {} + Err(PushError::BeyondCapacity(_)) => { + tracing::warn!("Too many active ICE candidate receivers at a time") + } + Err(PushError::Replaced(_)) => { + tracing::warn!(%id, "Replaced old ICE candidate receiver with new one") + } + } + } +} + +impl Default for ClientState { + fn default() -> Self { + Self { + active_candidate_receivers: StreamMap::new( + Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), + MAX_CONCURRENT_ICE_GATHERING, + ), + waiting_for_sdp_from_gatway: Default::default(), + } + } +} + +impl RoleState for ClientState { + type Id = GatewayId; + + fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) { + (conn_id, Some(Ok(c))) => { + return Poll::Ready(Event::SignalIceCandidate { + conn_id, + candidate: c, + }) + } + (id, Some(Err(e))) => { + tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") + } + (_, None) => {} + } + } + } +} + +/// [`Tunnel`](crate::Tunnel) state specific to gateways. +pub struct GatewayState { + candidate_receivers: StreamMap, +} + +impl GatewayState { + pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver) { + match self.candidate_receivers.try_push(id, receiver) { + Ok(()) => {} + Err(PushError::BeyondCapacity(_)) => { + tracing::warn!("Too many active ICE candidate receivers at a time") + } + Err(PushError::Replaced(_)) => { + tracing::warn!(%id, "Replaced old ICE candidate receiver with new one") + } + } + } +} + +impl Default for GatewayState { + fn default() -> Self { + Self { + candidate_receivers: StreamMap::new( + Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), + MAX_CONCURRENT_ICE_GATHERING, + ), + } + } +} + +impl RoleState for GatewayState { + type Id = ClientId; + + fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.candidate_receivers.poll_next_unpin(cx)) { + (conn_id, Some(Ok(c))) => { + return Poll::Ready(Event::SignalIceCandidate { + conn_id, + candidate: c, + }) + } + (id, Some(Err(e))) => { + tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") + } + (_, None) => {} + } + } + } +} diff --git a/rust/gateway/src/control.rs b/rust/gateway/src/control.rs index 659bbc4b4..b944cfa9a 100644 --- a/rust/gateway/src/control.rs +++ b/rust/gateway/src/control.rs @@ -1,24 +1,12 @@ use async_trait::async_trait; -use connlib_shared::messages::ClientId; -use connlib_shared::Error::ControlProtocolError; use connlib_shared::{ messages::{GatewayId, ResourceDescription}, Result, }; -use firezone_tunnel::{ConnId, ControlSignal}; -use tokio::sync::mpsc; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; +use firezone_tunnel::ControlSignal; #[derive(Clone)] -pub struct ControlSignaler { - tx: mpsc::Sender<(ClientId, RTCIceCandidate)>, -} - -impl ControlSignaler { - pub fn new(tx: mpsc::Sender<(ClientId, RTCIceCandidate)>) -> Self { - Self { tx } - } -} +pub struct ControlSignaler; #[async_trait] impl ControlSignal for ControlSignaler { @@ -31,20 +19,4 @@ impl ControlSignal for ControlSignaler { tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients."); Ok(()) } - - async fn signal_ice_candidate( - &self, - ice_candidate: RTCIceCandidate, - conn_id: ConnId, - ) -> Result<()> { - // TODO: We probably want to have different signal_ice_candidate - // functions for gateway/client but ultimately we just want - // separate control_plane modules - if let ConnId::Client(id) = conn_id { - let _ = self.tx.send((id, ice_candidate)).await; - Ok(()) - } else { - Err(ControlProtocolError) - } - } } diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 8c6af57f6..1c1056d32 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -7,21 +7,18 @@ use crate::CallbackHandler; use anyhow::Result; use connlib_shared::messages::ClientId; use connlib_shared::Error; -use firezone_tunnel::Tunnel; +use firezone_tunnel::{GatewayState, Tunnel}; use phoenix_channel::PhoenixChannel; use std::convert::Infallible; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::sync::mpsc; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; pub const PHOENIX_TOPIC: &str = "gateway"; -pub struct Eventloop<'a> { - tunnel: Arc>, - control_rx: &'a mut mpsc::Receiver<(ClientId, RTCIceCandidate)>, +pub struct Eventloop { + tunnel: Arc>, portal: PhoenixChannel, // TODO: Strongly type request reference (currently `String`) @@ -32,15 +29,13 @@ pub struct Eventloop<'a> { print_stats_timer: tokio::time::Interval, } -impl<'a> Eventloop<'a> { +impl Eventloop { pub(crate) fn new( - tunnel: Arc>, - control_rx: &'a mut mpsc::Receiver<(ClientId, RTCIceCandidate)>, + tunnel: Arc>, portal: PhoenixChannel, - ) -> Eventloop<'a> { + ) -> Self { Self { tunnel, - control_rx, portal, // TODO: Pick sane values for timeouts and size. @@ -54,34 +49,10 @@ impl<'a> Eventloop<'a> { } } -impl Eventloop<'_> { +impl Eventloop { #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")] pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - if let Poll::Ready(Some((client, ice_candidate))) = self.control_rx.poll_recv(cx) { - let ice_candidate = match ice_candidate.to_json() { - Ok(ice_candidate) => ice_candidate, - Err(e) => { - tracing::warn!( - "Failed to serialize ICE candidate to JSON: {:#}", - anyhow::Error::new(e) - ); - continue; - } - }; - - tracing::debug!(%client, candidate = %ice_candidate.candidate, "Sending ICE candidate to client"); - - let _id = self.portal.send( - PHOENIX_TOPIC, - EgressMessages::BroadcastIceCandidates(BroadcastClientIceCandidates { - client_ids: vec![client], - candidates: vec![ice_candidate], - }), - ); - continue; - } - match self.connection_request_tasks.poll_unpin(cx) { Poll::Ready(((client, reference), Ok(Ok(gateway_rtc_session_description)))) => { tracing::debug!(%client, %reference, "Connection is ready"); @@ -203,6 +174,25 @@ impl Eventloop<'_> { _ => {} } + match self.tunnel.poll_next_event(cx) { + Poll::Ready(firezone_tunnel::Event::SignalIceCandidate { + conn_id: client, + candidate, + }) => { + tracing::debug!(%client, candidate = %candidate.candidate, "Sending ICE candidate to client"); + + let _id = self.portal.send( + PHOENIX_TOPIC, + EgressMessages::BroadcastIceCandidates(BroadcastClientIceCandidates { + client_ids: vec![client], + candidates: vec![candidate], + }), + ); + continue; + } + Poll::Pending => {} + } + if self.print_stats_timer.poll_tick(cx).is_ready() { tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats()); continue; diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 6d5c04898..f03c27c9a 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -2,23 +2,18 @@ use crate::control::ControlSignaler; use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use crate::messages::InitGateway; use anyhow::{Context, Result}; -use backoff::backoff::Backoff; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_shared::messages::ClientId; use connlib_shared::{get_device_id, get_user_agent, login_url, Callbacks, Mode}; -use firezone_tunnel::Tunnel; -use futures::future; +use firezone_tunnel::{GatewayState, Tunnel}; +use futures::{future, TryFutureExt}; use headless_utils::{setup_global_subscriber, CommonArgs}; use phoenix_channel::SecureUrl; use secrecy::{Secret, SecretString}; use std::convert::Infallible; -use std::pin::pin; use std::sync::Arc; -use tokio::sync::mpsc; use tracing_subscriber::layer; use url::Url; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; mod control; mod eventloop; @@ -35,40 +30,25 @@ async fn main() -> Result<()> { SecretString::new(cli.common.secret), get_device_id(), )?; + let tunnel = Arc::new(Tunnel::new(private_key, ControlSignaler, CallbackHandler).await?); - // Note: This channel is only needed because [`Tunnel`] does not (yet) have a synchronous, poll-like interface. If it would have, ICE candidates would be emitted as events and we could just hand them to the phoenix channel. - let (control_tx, mut control_rx) = mpsc::channel(1); - let signaler = ControlSignaler::new(control_tx); - let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?); - - let mut backoff = ExponentialBackoffBuilder::default() - .with_max_elapsed_time(None) - .build(); - - let eventloop = async { - loop { - let error = match run(tunnel.clone(), &mut control_rx, connect_url.clone()).await { - Err(e) => e, - Ok(never) => match never {}, - }; - - let t = backoff - .next_backoff() - .expect("the exponential backoff reconnect loop should run indefinitely"); + tokio::spawn(backoff::future::retry_notify( + ExponentialBackoffBuilder::default() + .with_max_elapsed_time(None) + .build(), + move || run(tunnel.clone(), connect_url.clone()).map_err(backoff::Error::transient), + |error, t| { tracing::warn!(retry_in = ?t, "Error connecting to portal: {error:#}"); + }, + )); - tokio::time::sleep(t).await; - } - }; - - future::select(pin!(eventloop), pin!(tokio::signal::ctrl_c())).await; + tokio::signal::ctrl_c().await?; Ok(()) } async fn run( - tunnel: Arc>, - control_rx: &mut mpsc::Receiver<(ClientId, RTCIceCandidate)>, + tunnel: Arc>, connect_url: Url, ) -> Result { let (portal, init) = phoenix_channel::init::( @@ -84,7 +64,7 @@ async fn run( .await .context("Failed to set interface")?; - let mut eventloop = Eventloop::new(tunnel, control_rx, portal); + let mut eventloop = Eventloop::new(tunnel, portal); future::poll_fn(|cx| eventloop.poll(cx)).await }