refactor(connlib): use events to handle ICE candidates (#2279)

Thomas Eizinger
2023-10-11 09:26:42 +11:00
committed by GitHub
parent 0d411f60aa
commit 82c2bf3574
18 changed files with 351 additions and 230 deletions
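This commit replaces the callback-style `ControlSignal::signal_ice_candidate` with an event stream: role-specific state buffers ICE candidates, and `Tunnel` surfaces them through a poll-based `poll_next_event` plus an async `next_event` wrapper. Below is a minimal sketch of that pattern, using simplified stand-ins for `Event` and the role state (std only, no webrtc types):

```rust
use std::collections::VecDeque;
use std::task::{Context, Poll, Waker};

/// Simplified stand-in for `firezone_tunnel::Event<TId>`.
enum Event<TId> {
    SignalIceCandidate { conn_id: TId, candidate: String },
}

/// Simplified stand-in for a role state: buffers events and wakes the poller.
struct State<TId> {
    pending: VecDeque<Event<TId>>,
    waker: Option<Waker>,
}

impl<TId> State<TId> {
    fn push(&mut self, event: Event<TId>) {
        self.pending.push_back(event);
        if let Some(waker) = self.waker.take() {
            waker.wake(); // re-poll the consumer
        }
    }

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<TId>> {
        if let Some(event) = self.pending.pop_front() {
            return Poll::Ready(event);
        }
        self.waker = Some(cx.waker().clone());
        Poll::Pending
    }

    /// Awaitable wrapper, mirroring `Tunnel::next_event` in this diff.
    async fn next_event(&mut self) -> Event<TId> {
        std::future::poll_fn(|cx| self.poll_next_event(cx)).await
    }
}
```

Consumers integrate this as just another branch: the client adds `event = control_plane.tunnel.next_event()` to its `tokio::select!` loop, while the gateway calls `poll_next_event` directly from its hand-written `Eventloop::poll`.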

rust/Cargo.lock (generated)

@@ -792,6 +792,7 @@ dependencies = [
name = "connlib-client-shared"
version = "1.20231001.0"
dependencies = [
"anyhow",
"async-trait",
"backoff",
"chrono",
@@ -1245,7 +1246,7 @@ dependencies = [
"connlib-shared",
"firezone-tunnel",
"futures",
"futures-bounded",
"futures-bounded 0.1.0",
"headless-utils",
"phoenix-channel",
"secrecy",
@@ -1283,6 +1284,7 @@ dependencies = [
"connlib-shared",
"domain",
"futures",
"futures-bounded 0.2.0",
"futures-util",
"ip_network",
"ip_network_table",
@@ -1345,6 +1347,15 @@ dependencies = [
"futures-util",
]
[[package]]
name = "futures-bounded"
version = "0.2.0"
source = "git+https://github.com/libp2p/rust-libp2p?branch=feat/stream-map#1e4ad64558159dfc94b50daf701b3ee7315553b9"
dependencies = [
"futures-timer",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.28"


@@ -8,6 +8,7 @@ edition = "2021"
mock = ["connlib-shared/mock"]
[dependencies]
anyhow = "1.0.75"
tokio = { version = "1.32", default-features = false, features = ["sync", "rt"] }
tokio-util = "0.7.9"
secrecy = { workspace = true }


@@ -9,13 +9,12 @@ use connlib_shared::{
control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference},
messages::{GatewayId, ResourceDescription, ResourceId},
Callbacks,
Error::{self, ControlProtocolError},
Error::{self},
Result,
};
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use async_trait::async_trait;
use firezone_tunnel::{ConnId, ControlSignal, Request, Tunnel};
use firezone_tunnel::{ClientState, ControlSignal, Request, Tunnel};
use tokio::sync::Mutex;
use tokio_util::codec::{BytesCodec, FramedRead};
use url::Url;
@@ -41,35 +40,10 @@ impl ControlSignal for ControlSignaler {
.await?;
Ok(())
}
async fn signal_ice_candidate(
&self,
ice_candidate: RTCIceCandidate,
conn_id: ConnId,
) -> Result<()> {
// TODO: We probably want to have different signal_ice_candidate
// functions for gateway/client but ultimately we just want
// separate control_plane modules
if let ConnId::Gateway(id) = conn_id {
self.control_signal
.clone()
.send(EgressMessages::BroadcastIceCandidates(
BroadcastGatewayIceCandidates {
gateway_ids: vec![id],
candidates: vec![ice_candidate.to_json()?],
},
))
.await?;
Ok(())
} else {
Err(ControlProtocolError)
}
}
}
pub struct ControlPlane<CB: Callbacks> {
pub tunnel: Arc<Tunnel<ControlSignaler, CB>>,
pub tunnel: Arc<Tunnel<ControlSignaler, CB, ClientState>>,
pub control_signaler: ControlSignaler,
pub tunnel_init: Mutex<bool>,
}
@@ -301,6 +275,26 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
.send(EgressMessages::CreateLogSink {})
.await;
}
pub async fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event<GatewayId>) {
match event {
firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => {
if let Err(e) = self
.control_signaler
.control_signal
.send(EgressMessages::BroadcastIceCandidates(
BroadcastGatewayIceCandidates {
gateway_ids: vec![conn_id],
candidates: vec![candidate],
},
))
.await
{
tracing::error!("Failed to signal ICE candidate: {e}")
}
}
}
}
}
async fn upload(path: PathBuf, url: Url) -> io::Result<()> {


@@ -165,6 +165,7 @@ where
tokio::spawn(async move {
let mut log_stats_interval = tokio::time::interval(Duration::from_secs(10));
let mut upload_logs_interval = upload_interval();
loop {
tokio::select! {
Some((msg, reference)) = control_plane_receiver.recv() => {
@@ -173,6 +174,7 @@ where
Err(err) => control_plane.handle_error(err, reference).await,
}
},
event = control_plane.tunnel.next_event() => control_plane.handle_tunnel_event(event).await,
_ = log_stats_interval.tick() => control_plane.stats_event().await,
_ = upload_logs_interval.tick() => control_plane.request_log_upload_url().await,
else => break


@@ -49,6 +49,12 @@ impl fmt::Display for ClientId {
}
}
impl fmt::Display for GatewayId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Represents a wireguard peer.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Peer {


@@ -25,6 +25,7 @@ domain = "0.8"
boringtun = { workspace = true }
chrono = { workspace = true }
pnet_packet = { version = "0.34" }
futures-bounded = { git = "https://github.com/libp2p/rust-libp2p", branch = "feat/stream-map" }
# TODO: research replacing for https://github.com/algesten/str0m
webrtc = { version = "0.8" }


@@ -1,5 +1,7 @@
use boringtun::noise::Tunn;
use chrono::{DateTime, Utc};
use futures::channel::mpsc;
use futures_util::SinkExt;
use secrecy::ExposeSecret;
use std::sync::Arc;
use tracing::instrument;
@@ -20,6 +22,7 @@ use webrtc::{
},
};
use crate::role_state::RoleState;
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, Tunnel};
mod client;
@@ -35,14 +38,15 @@ pub enum Request {
}
#[tracing::instrument(level = "trace", skip(tunnel))]
async fn handle_connection_state_update_with_peer<C, CB>(
tunnel: &Arc<Tunnel<C, CB>>,
async fn handle_connection_state_update_with_peer<C, CB, TRoleState>(
tunnel: &Arc<Tunnel<C, CB, TRoleState>>,
state: RTCPeerConnectionState,
index: u32,
conn_id: ConnId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
tracing::trace!(?state, "peer_state_update");
if state == RTCPeerConnectionState::Failed {
@@ -51,14 +55,15 @@ async fn handle_connection_state_update_with_peer<C, CB>(
}
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_with_peer<C, CB>(
tunnel: &Arc<Tunnel<C, CB>>,
fn set_connection_state_with_peer<C, CB, TRoleState>(
tunnel: &Arc<Tunnel<C, CB, TRoleState>>,
peer_connection: &Arc<RTCPeerConnection>,
index: u32,
conn_id: ConnId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
let tunnel = Arc::clone(tunnel);
peer_connection.on_peer_connection_state_change(Box::new(
@@ -71,10 +76,11 @@ fn set_connection_state_with_peer<C, CB>(
));
}
impl<C, CB> Tunnel<C, CB>
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
#[instrument(level = "trace", skip(self, data_channel, peer_config))]
async fn handle_channel_open(
@@ -152,11 +158,10 @@ where
}
#[tracing::instrument(level = "trace", skip(self))]
async fn initialize_peer_request(
pub async fn new_peer_connection(
self: &Arc<Self>,
relays: Vec<Relay>,
conn_id: ConnId,
) -> Result<Arc<RTCPeerConnection>> {
) -> Result<(Arc<RTCPeerConnection>, mpsc::Receiver<RTCIceCandidateInit>)> {
let config = RTCConfiguration {
ice_servers: relays
.into_iter()
@@ -176,50 +181,33 @@ where
.collect(),
..Default::default()
};
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
let (ice_candidate_tx, ice_candidate_rx) = tokio::sync::mpsc::channel(ICE_CANDIDATE_BUFFER);
self.ice_candidate_queue
.lock()
.insert(conn_id, ice_candidate_rx);
let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER);
let callbacks = self.callbacks().clone();
peer_connection.on_ice_candidate(Box::new(move |candidate| {
let ice_candidate_tx = ice_candidate_tx.clone();
let callbacks = callbacks.clone();
let Some(candidate) = candidate else {
return Box::pin(async {});
};
let mut ice_candidate_tx = ice_candidate_tx.clone();
Box::pin(async move {
if let Err(e) = ice_candidate_tx.send(candidate).await {
tracing::error!(err = ?e, "buffer_ice_candidate");
let _ = callbacks.on_error(&e.into());
let ice_candidate = match candidate.to_json() {
Ok(ice_candidate) => ice_candidate,
Err(e) => {
tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",);
return;
}
};
if ice_candidate_tx.send(ice_candidate).await.is_err() {
debug_assert!(false, "receiver was dropped before sender")
}
})
}));
Ok(peer_connection)
}
fn start_ice_candidate_handler(&self, conn_id: ConnId) -> Result<()> {
let mut ice_candidate_rx = self
.ice_candidate_queue
.lock()
.remove(&conn_id)
.ok_or(Error::ControlProtocolError)?;
let control_signaler = self.control_signaler.clone();
let callbacks = self.callbacks().clone();
tokio::spawn(async move {
while let Some(ice_candidate) = ice_candidate_rx.recv().await.flatten() {
if let Err(e) = control_signaler
.signal_ice_candidate(ice_candidate, conn_id)
.await
{
tracing::error!(err = ?e, "add_ice_candidate");
let _ = callbacks.on_error(&e);
}
}
});
Ok(())
Ok((peer_connection, ice_candidate_rx))
}
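`new_peer_connection` now returns the receiving half of a bounded channel instead of queuing candidates internally: the `on_ice_candidate` callback serializes each candidate and pushes it into a `futures::channel::mpsc` sender, and the receiver is later handed to the role state. A self-contained sketch of that hand-off, with a hypothetical `demo` function and string payloads in place of `RTCIceCandidateInit`:

```rust
use futures::channel::mpsc;
use futures::{SinkExt, StreamExt};

async fn demo() {
    // Bounded, like `mpsc::channel(ICE_CANDIDATE_BUFFER)` above.
    let (tx, mut rx) = mpsc::channel::<String>(16);

    // Producer: what the `on_ice_candidate` closure does per candidate.
    let producer = async move {
        for candidate in ["candidate:0 1 UDP 2122252543 ...", "candidate:1 ..."] {
            let mut tx = tx.clone();
            if tx.send(candidate.to_owned()).await.is_err() {
                break; // receiver dropped; nothing left to signal
            }
        }
        // `tx` drops here, closing the channel so the consumer finishes.
    };

    // Consumer: in the real code, the receiver is drained by the role
    // state's `StreamMap` and surfaced as `Event::SignalIceCandidate`.
    let consumer = async {
        while let Some(candidate) = rx.next().await {
            println!("signal {candidate}");
        }
    };

    futures::join!(producer, consumer);
}

fn main() {
    futures::executor::block_on(demo());
}
```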
pub async fn add_ice_candidate(


@@ -1,14 +1,10 @@
use std::sync::Arc;
use boringtun::x25519::{PublicKey, StaticSecret};
use chrono::{DateTime, Utc};
use connlib_shared::messages::SecretKey;
use connlib_shared::{
control::Reference,
messages::{
ClientId, GatewayId, Key, Relay, RequestConnection, ResourceDescription, ResourceId,
ReuseConnection,
},
messages::{GatewayId, Key, Relay, RequestConnection, ResourceId, ReuseConnection},
Callbacks,
};
use rand_core::OsRng;
@@ -21,11 +17,11 @@ use webrtc::{
},
};
use crate::{ControlSignal, Error, PeerConfig, Request, Result, Tunnel};
use crate::{ClientState, ControlSignal, Error, PeerConfig, Request, Result, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB>>,
tunnel: &Arc<Tunnel<C, CB, ClientState>>,
state: RTCPeerConnectionState,
gateway_id: GatewayId,
resource_id: ResourceId,
@@ -49,7 +45,7 @@ fn handle_connection_state_update<C, CB>(
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB>>,
tunnel: &Arc<Tunnel<C, CB, ClientState>>,
peer_connection: &Arc<RTCPeerConnection>,
gateway_id: GatewayId,
resource_id: ResourceId,
@@ -68,7 +64,7 @@ fn set_connection_state_update<C, CB>(
));
}
impl<C, CB> Tunnel<C, CB>
impl<C, CB> Tunnel<C, CB, ClientState>
where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
@@ -163,10 +159,11 @@ where
}
}
let peer_connection = {
let peer_connection = Arc::new(
self.initialize_peer_request(relays, gateway_id.into())
.await?,
);
let (peer_connection, receiver) = self.new_peer_connection(relays).await?;
self.role_state
.lock()
.add_waiting_ice_receiver(gateway_id, receiver);
let peer_connection = Arc::new(peer_connection);
let mut peer_connections = self.peer_connections.lock();
peer_connections.insert(gateway_id.into(), Arc::clone(&peer_connection));
peer_connection
@@ -279,24 +276,10 @@ where
.insert(gateway_id, gateway_public_key);
peer_connection.set_remote_description(rtc_sdp).await?;
self.start_ice_candidate_handler(gateway_id.into())?;
self.role_state
.lock()
.activate_ice_candidate_receiver(gateway_id);
Ok(())
}
pub fn allow_access(
&self,
resource: ResourceDescription,
client_id: ClientId,
expires_at: DateTime<Utc>,
) {
if let Some(peer) = self
.peers_by_ip
.write()
.iter_mut()
.find_map(|(_, p)| (p.conn_id == client_id.into()).then_some(p))
{
peer.add_resource(resource, expires_at);
}
}
}


@@ -10,11 +10,12 @@ use webrtc::peer_connection::{
RTCPeerConnection,
};
use crate::role_state::GatewayState;
use crate::{ControlSignal, PeerConfig, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB>>,
tunnel: &Arc<Tunnel<C, CB, GatewayState>>,
state: RTCPeerConnectionState,
client_id: ClientId,
) where
@@ -29,7 +30,7 @@ fn handle_connection_state_update<C, CB>(
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB>>,
tunnel: &Arc<Tunnel<C, CB, GatewayState>>,
peer_connection: &Arc<RTCPeerConnection>,
client_id: ClientId,
) where
@@ -45,7 +46,7 @@ fn set_connection_state_update<C, CB>(
));
}
impl<C, CB> Tunnel<C, CB>
impl<C, CB> Tunnel<C, CB, GatewayState>
where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
@@ -72,10 +73,10 @@ where
expires_at: DateTime<Utc>,
resource: ResourceDescription,
) -> Result<RTCSessionDescription> {
let peer_connection = self
.initialize_peer_request(relays, client_id.into())
.await?;
self.start_ice_candidate_handler(client_id.into())?;
let (peer_connection, receiver) = self.new_peer_connection(relays).await?;
self.role_state
.lock()
.add_new_ice_receiver(client_id, receiver);
let index = self.next_index();
let tunnel = Arc::clone(self);
@@ -150,4 +151,20 @@ where
Ok(local_desc)
}
pub fn allow_access(
&self,
resource: ResourceDescription,
client_id: ClientId,
expires_at: DateTime<Utc>,
) {
if let Some(peer) = self
.peers_by_ip
.write()
.iter_mut()
.find_map(|(_, p)| (p.conn_id == client_id.into()).then_some(p))
{
peer.add_resource(resource, expires_at);
}
}
}


@@ -28,7 +28,7 @@ pub(crate) enum SendPacket {
// as we can, therefore we won't do it.
//
// See: https://stackoverflow.com/a/55093896
impl<C, CB> Tunnel<C, CB>
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,


@@ -4,6 +4,7 @@ use boringtun::noise::{errors::WireGuardError, Tunn, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};
use crate::role_state::RoleState;
use crate::{
device_channel::{DeviceIo, IfaceConfig},
dns,
@@ -13,10 +14,11 @@ use crate::{
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
impl<C, CB> Tunnel<C, CB>
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
#[inline(always)]
fn connection_intent(self: &Arc<Self>, src: &[u8], dst_addr: &IpAddr) {


@@ -24,12 +24,13 @@ use webrtc::{
interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
setting_engine::SettingEngine, APIBuilder, API,
},
ice_transport::ice_candidate::RTCIceCandidate,
interceptor::registry::Registry,
peer_connection::RTCPeerConnection,
};
use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration};
use std::task::{Context, Poll};
use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use connlib_shared::{
messages::{
@@ -41,8 +42,10 @@ use connlib_shared::{
use device_channel::{create_iface, DeviceIo, IfaceConfig};
pub use control_protocol::Request;
pub use role_state::{ClientState, GatewayState};
pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::role_state::RoleState;
use connlib_shared::messages::SecretKey;
use index::IndexLfsr;
@@ -56,6 +59,7 @@ mod peer;
mod peer_handler;
mod resource_sender;
mod resource_table;
mod role_state;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1);
@@ -90,6 +94,16 @@ impl From<ResourceId> for ConnId {
}
}
impl fmt::Display for ConnId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnId::Gateway(inner) => fmt::Display::fmt(inner, f),
ConnId::Client(inner) => fmt::Display::fmt(inner, f),
ConnId::Resource(inner) => fmt::Display::fmt(inner, f),
}
}
}
/// Represents the actual config of a tunnel peer,
/// obtained from connlib_shared's Peer.
#[derive(Clone)]
@@ -125,13 +139,6 @@ pub trait ControlSignal {
connected_gateway_ids: &[GatewayId],
reference: usize,
) -> Result<()>;
/// Signals a new candidate to the control plane
async fn signal_ice_candidate(
&self,
ice_candidate: RTCIceCandidate,
conn_id: ConnId,
) -> Result<()>;
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
@@ -149,7 +156,7 @@ struct Device {
// TODO: We should use newtypes for each kind of Id
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets
/// to communicate between peers.
pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
pub struct Tunnel<C: ControlSignal, CB: Callbacks, TRoleState> {
next_index: Mutex<IndexLfsr>,
// We use a tokio RwLock here since this is only read/written during config, so there's no relevant performance impact
device: tokio::sync::RwLock<Option<Device>>,
@@ -158,8 +165,6 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
public_key: PublicKey,
peers_by_ip: RwLock<IpNetworkTable<Arc<Peer>>>,
peer_connections: Mutex<HashMap<ConnId, Arc<RTCPeerConnection>>>,
ice_candidate_queue:
Mutex<HashMap<ConnId, tokio::sync::mpsc::Receiver<Option<RTCIceCandidate>>>>,
awaiting_connection: Mutex<HashMap<ConnId, AwaitingConnectionDetails>>,
gateway_awaiting_connection: Mutex<HashMap<GatewayId, Vec<IpNetwork>>>,
resources_gateways: Mutex<HashMap<ResourceId, GatewayId>>,
@@ -169,6 +174,9 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
gateway_public_keys: Mutex<HashMap<GatewayId, PublicKey>>,
callbacks: CallbackErrorFacade<CB>,
iface_handler_abort: Mutex<Option<AbortHandle>>,
/// State that differs per role, i.e. clients vs gateways.
role_state: Mutex<TRoleState>,
}
// TODO: For now we only use these fields with debug
@@ -187,10 +195,11 @@ pub struct TunnelStats {
gateway_awaiting_connection: HashMap<GatewayId, Vec<IpNetwork>>,
}
impl<C, CB> Tunnel<C, CB>
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
pub fn stats(&self) -> TunnelStats {
let peers_by_ip = self
@@ -226,12 +235,28 @@ where
gateway_public_keys,
}
}
pub async fn next_event(&self) -> Event<TRoleState::Id> {
std::future::poll_fn(|cx| self.poll_next_event(cx)).await
}
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
self.role_state.lock().poll_next_event(cx)
}
}
impl<C, CB> Tunnel<C, CB>
pub enum Event<TId> {
SignalIceCandidate {
conn_id: TId,
candidate: RTCIceCandidateInit,
},
}
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
/// Creates a new tunnel.
///
@@ -255,7 +280,6 @@ where
let resources_gateways = Default::default();
let gateway_awaiting_connection = Default::default();
let device = Default::default();
let ice_candidate_queue = Default::default();
let iface_handler_abort = Default::default();
// ICE
@@ -297,9 +321,9 @@ where
gateway_awaiting_connection,
control_signaler,
resources_gateways,
ice_candidate_queue,
callbacks: CallbackErrorFacade(callbacks),
iface_handler_abort,
role_state: Default::default(),
})
}


@@ -4,15 +4,17 @@ use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};
use crate::role_state::RoleState;
use crate::{
device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, Tunnel,
MAX_UDP_SIZE,
};
impl<C, CB> Tunnel<C, CB>
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
#[inline(always)]
fn is_wireguard_packet_ok(&self, parsed_packet: &Packet, peer: &Peer) -> bool {


@@ -9,7 +9,7 @@ use crate::{
use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result};
impl<C, CB> Tunnel<C, CB>
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,


@@ -0,0 +1,148 @@
use crate::Event;
use connlib_shared::messages::{ClientId, GatewayId};
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use std::collections::HashMap;
use std::fmt;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
/// Dedicated trait for abstracting over the different ICE states.
///
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
pub trait RoleState: Default + Send + 'static {
type Id: fmt::Debug;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}
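The split is deliberate: generic `Tunnel` code can only drain a role state through this trait, while mutation (registering and activating receivers) requires naming the concrete `ClientState` or `GatewayState`. A compressed illustration of that boundary, re-declaring a simplified `Event` (the toy `drain` helper is not part of the crate):

```rust
use std::fmt;
use std::task::{Context, Poll};

enum Event<TId> {
    SignalIceCandidate { conn_id: TId, candidate: String },
}

trait RoleState: Default + Send + 'static {
    type Id: fmt::Debug;
    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}

// Role-agnostic code, like `Tunnel::poll_next_event`, sees only the trait...
fn drain<T: RoleState>(state: &mut T, cx: &mut Context<'_>) {
    while let Poll::Ready(Event::SignalIceCandidate { conn_id, candidate }) =
        state.poll_next_event(cx)
    {
        println!("candidate {candidate} for {conn_id:?}");
    }
}

// ...whereas `add_waiting_ice_receiver` and `activate_ice_candidate_receiver`
// exist only on the concrete types, so only role-aware code can mutate them.
```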
/// For how long we will attempt to gather ICE candidates before aborting.
///
/// Chosen arbitrarily.
/// Very likely, the actual WebRTC connection will timeout before this.
/// This timeout is just here to eventually clean up tasks if they are somehow broken.
const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60;
/// How many concurrent ICE gathering attempts we allow.
///
/// Chosen arbitrarily.
const MAX_CONCURRENT_ICE_GATHERING: usize = 100;
/// [`Tunnel`](crate::Tunnel) state specific to clients.
pub struct ClientState {
active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
/// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
waiting_for_sdp_from_gateway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
}
impl ClientState {
pub fn add_waiting_ice_receiver(
&mut self,
id: GatewayId,
receiver: Receiver<RTCIceCandidateInit>,
) {
self.waiting_for_sdp_from_gateway.insert(id, receiver);
}
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
let Some(receiver) = self.waiting_for_sdp_from_gateway.remove(&id) else {
return;
};
match self.active_candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!("Too many active ICE candidate receivers at a time")
}
Err(PushError::Replaced(_)) => {
tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
}
}
}
}
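The two-phase hand-over described above boils down to parking receivers in one map and moving them into the active set once the gateway's SDP answer arrives. A toy model with plain `HashMap`s (hypothetical `TwoPhase` type; the real active set is a `futures_bounded::StreamMap` with a timeout and capacity limit):

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Park receivers until the SDP answer arrives, then activate them.
struct TwoPhase<Id, R> {
    waiting: HashMap<Id, R>,
    active: HashMap<Id, R>,
}

impl<Id: Hash + Eq, R> TwoPhase<Id, R> {
    /// Mirrors `add_waiting_ice_receiver`: candidates buffer but are not sent.
    fn park(&mut self, id: Id, receiver: R) {
        self.waiting.insert(id, receiver);
    }

    /// Mirrors `activate_ice_candidate_receiver`: from here on, candidates
    /// drained from the receiver are actually signalled to the gateway.
    fn activate(&mut self, id: Id) {
        if let Some(receiver) = self.waiting.remove(&id) {
            self.active.insert(id, receiver);
        }
    }
}
```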
impl Default for ClientState {
fn default() -> Self {
Self {
active_candidate_receivers: StreamMap::new(
Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
MAX_CONCURRENT_ICE_GATHERING,
),
waiting_for_sdp_from_gateway: Default::default(),
}
}
}
impl RoleState for ClientState {
type Id = GatewayId;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
}
}
}
}
/// [`Tunnel`](crate::Tunnel) state specific to gateways.
pub struct GatewayState {
candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,
}
impl GatewayState {
pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver<RTCIceCandidateInit>) {
match self.candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!("Too many active ICE candidate receivers at a time")
}
Err(PushError::Replaced(_)) => {
tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
}
}
}
}
impl Default for GatewayState {
fn default() -> Self {
Self {
candidate_receivers: StreamMap::new(
Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
MAX_CONCURRENT_ICE_GATHERING,
),
}
}
}
impl RoleState for GatewayState {
type Id = ClientId;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
tracing::warn!(client_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
}
}
}
}


@@ -1,24 +1,12 @@
use async_trait::async_trait;
use connlib_shared::messages::ClientId;
use connlib_shared::Error::ControlProtocolError;
use connlib_shared::{
messages::{GatewayId, ResourceDescription},
Result,
};
use firezone_tunnel::{ConnId, ControlSignal};
use tokio::sync::mpsc;
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use firezone_tunnel::ControlSignal;
#[derive(Clone)]
pub struct ControlSignaler {
tx: mpsc::Sender<(ClientId, RTCIceCandidate)>,
}
impl ControlSignaler {
pub fn new(tx: mpsc::Sender<(ClientId, RTCIceCandidate)>) -> Self {
Self { tx }
}
}
pub struct ControlSignaler;
#[async_trait]
impl ControlSignal for ControlSignaler {
@@ -31,20 +19,4 @@ impl ControlSignal for ControlSignaler {
tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients.");
Ok(())
}
async fn signal_ice_candidate(
&self,
ice_candidate: RTCIceCandidate,
conn_id: ConnId,
) -> Result<()> {
// TODO: We probably want to have different signal_ice_candidate
// functions for gateway/client but ultimately we just want
// separate control_plane modules
if let ConnId::Client(id) = conn_id {
let _ = self.tx.send((id, ice_candidate)).await;
Ok(())
} else {
Err(ControlProtocolError)
}
}
}


@@ -7,21 +7,18 @@ use crate::CallbackHandler;
use anyhow::Result;
use connlib_shared::messages::ClientId;
use connlib_shared::Error;
use firezone_tunnel::Tunnel;
use firezone_tunnel::{GatewayState, Tunnel};
use phoenix_channel::PhoenixChannel;
use std::convert::Infallible;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::mpsc;
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
pub const PHOENIX_TOPIC: &str = "gateway";
pub struct Eventloop<'a> {
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler>>,
control_rx: &'a mut mpsc::Receiver<(ClientId, RTCIceCandidate)>,
pub struct Eventloop {
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
portal: PhoenixChannel<IngressMessages, ()>,
// TODO: Strongly type request reference (currently `String`)
@@ -32,15 +29,13 @@ pub struct Eventloop<'a> {
print_stats_timer: tokio::time::Interval,
}
impl<'a> Eventloop<'a> {
impl Eventloop {
pub(crate) fn new(
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler>>,
control_rx: &'a mut mpsc::Receiver<(ClientId, RTCIceCandidate)>,
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
portal: PhoenixChannel<IngressMessages, ()>,
) -> Eventloop<'a> {
) -> Self {
Self {
tunnel,
control_rx,
portal,
// TODO: Pick sane values for timeouts and size.
@@ -54,34 +49,10 @@ impl<'a> Eventloop<'a> {
}
}
impl Eventloop<'_> {
impl Eventloop {
#[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")]
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Infallible>> {
loop {
if let Poll::Ready(Some((client, ice_candidate))) = self.control_rx.poll_recv(cx) {
let ice_candidate = match ice_candidate.to_json() {
Ok(ice_candidate) => ice_candidate,
Err(e) => {
tracing::warn!(
"Failed to serialize ICE candidate to JSON: {:#}",
anyhow::Error::new(e)
);
continue;
}
};
tracing::debug!(%client, candidate = %ice_candidate.candidate, "Sending ICE candidate to client");
let _id = self.portal.send(
PHOENIX_TOPIC,
EgressMessages::BroadcastIceCandidates(BroadcastClientIceCandidates {
client_ids: vec![client],
candidates: vec![ice_candidate],
}),
);
continue;
}
match self.connection_request_tasks.poll_unpin(cx) {
Poll::Ready(((client, reference), Ok(Ok(gateway_rtc_session_description)))) => {
tracing::debug!(%client, %reference, "Connection is ready");
@@ -203,6 +174,25 @@ impl Eventloop<'_> {
_ => {}
}
match self.tunnel.poll_next_event(cx) {
Poll::Ready(firezone_tunnel::Event::SignalIceCandidate {
conn_id: client,
candidate,
}) => {
tracing::debug!(%client, candidate = %candidate.candidate, "Sending ICE candidate to client");
let _id = self.portal.send(
PHOENIX_TOPIC,
EgressMessages::BroadcastIceCandidates(BroadcastClientIceCandidates {
client_ids: vec![client],
candidates: vec![candidate],
}),
);
continue;
}
Poll::Pending => {}
}
if self.print_stats_timer.poll_tick(cx).is_ready() {
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
continue;


@@ -2,23 +2,18 @@ use crate::control::ControlSignaler;
use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
use crate::messages::InitGateway;
use anyhow::{Context, Result};
use backoff::backoff::Backoff;
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use connlib_shared::messages::ClientId;
use connlib_shared::{get_device_id, get_user_agent, login_url, Callbacks, Mode};
use firezone_tunnel::Tunnel;
use futures::future;
use firezone_tunnel::{GatewayState, Tunnel};
use futures::{future, TryFutureExt};
use headless_utils::{setup_global_subscriber, CommonArgs};
use phoenix_channel::SecureUrl;
use secrecy::{Secret, SecretString};
use std::convert::Infallible;
use std::pin::pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing_subscriber::layer;
use url::Url;
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
mod control;
mod eventloop;
@@ -35,40 +30,25 @@ async fn main() -> Result<()> {
SecretString::new(cli.common.secret),
get_device_id(),
)?;
let tunnel = Arc::new(Tunnel::new(private_key, ControlSignaler, CallbackHandler).await?);
// Note: This channel is only needed because [`Tunnel`] does not (yet) have a synchronous, poll-like interface. If it did, ICE candidates would be emitted as events and we could hand them straight to the phoenix channel.
let (control_tx, mut control_rx) = mpsc::channel(1);
let signaler = ControlSignaler::new(control_tx);
let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?);
let mut backoff = ExponentialBackoffBuilder::default()
.with_max_elapsed_time(None)
.build();
let eventloop = async {
loop {
let error = match run(tunnel.clone(), &mut control_rx, connect_url.clone()).await {
Err(e) => e,
Ok(never) => match never {},
};
let t = backoff
.next_backoff()
.expect("the exponential backoff reconnect loop should run indefinitely");
tokio::spawn(backoff::future::retry_notify(
ExponentialBackoffBuilder::default()
.with_max_elapsed_time(None)
.build(),
move || run(tunnel.clone(), connect_url.clone()).map_err(backoff::Error::transient),
|error, t| {
tracing::warn!(retry_in = ?t, "Error connecting to portal: {error:#}");
},
));
tokio::time::sleep(t).await;
}
};
future::select(pin!(eventloop), pin!(tokio::signal::ctrl_c())).await;
tokio::signal::ctrl_c().await?;
Ok(())
}
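The manual reconnect loop is replaced with `backoff::future::retry_notify`: every failure from `run` is marked transient, so with `with_max_elapsed_time(None)` the gateway retries forever, logging the delay before each attempt. A runnable sketch of the same pattern, with a hypothetical `connect` standing in for `run` (assumes the `backoff` crate with its async feature enabled, plus `tokio`):

```rust
use std::time::Duration;

use backoff::ExponentialBackoffBuilder;

/// Hypothetical fallible step standing in for `run(tunnel, connect_url)`.
async fn connect() -> Result<(), std::io::Error> {
    Err(std::io::Error::new(std::io::ErrorKind::Other, "portal unreachable"))
}

#[tokio::main]
async fn main() {
    let reconnect = backoff::future::retry_notify(
        ExponentialBackoffBuilder::default()
            .with_max_elapsed_time(None) // retry indefinitely
            .build(),
        // Mark every error as transient so the backoff never gives up.
        || async { connect().await.map_err(backoff::Error::transient) },
        |error, t: Duration| eprintln!("retrying in {t:?}: {error}"),
    );

    // As in the diff: run the reconnect loop until Ctrl-C.
    tokio::select! {
        _ = reconnect => {}
        _ = tokio::signal::ctrl_c() => {}
    }
}
```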
async fn run(
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler>>,
control_rx: &mut mpsc::Receiver<(ClientId, RTCIceCandidate)>,
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
connect_url: Url,
) -> Result<Infallible> {
let (portal, init) = phoenix_channel::init::<InitGateway, _, _>(
@@ -84,7 +64,7 @@ async fn run(
.await
.context("Failed to set interface")?;
let mut eventloop = Eventloop::new(tunnel, control_rx, portal);
let mut eventloop = Eventloop::new(tunnel, portal);
future::poll_fn(|cx| eventloop.poll(cx)).await
}