diff --git a/rust/Dockerfile b/rust/Dockerfile index db87b1b23..d7a0f96b2 100644 --- a/rust/Dockerfile +++ b/rust/Dockerfile @@ -12,7 +12,7 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef as builder COPY --from=planner /build/recipe.json . -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --all-targets --release --recipe-path recipe.json COPY . . ARG PACKAGE RUN cargo build -p $PACKAGE --release diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 250cd11c9..04c6ca81c 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -14,48 +14,19 @@ use connlib_shared::{ Result, }; -use async_trait::async_trait; -use firezone_tunnel::{ClientState, ControlSignal, Request, Tunnel}; +use firezone_tunnel::{ClientState, Request, Tunnel}; use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE}; use tokio::io::BufReader; use tokio::sync::Mutex; use tokio_util::codec::{BytesCodec, FramedRead}; use url::Url; -#[async_trait] -impl ControlSignal for ControlSignaler { - async fn signal_connection_to( - &self, - resource: &ResourceDescription, - connected_gateway_ids: &[GatewayId], - reference: usize, - ) -> Result<()> { - self.control_signal - // It's easier if self is not mut - .clone() - .send_with_ref( - EgressMessages::PrepareConnection { - resource_id: resource.id(), - connected_gateway_ids: connected_gateway_ids.to_vec(), - }, - reference, - ) - .await?; - Ok(()) - } -} - pub struct ControlPlane { - pub tunnel: Arc>, - pub control_signaler: ControlSignaler, + pub tunnel: Arc>, + pub phoenix_channel: PhoenixSenderWithTopic, pub tunnel_init: Mutex, } -#[derive(Clone)] -pub struct ControlSignaler { - pub control_signal: PhoenixSenderWithTopic, -} - impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] async fn init( @@ -139,7 +110,7 @@ impl ControlPlane { reference: Option, ) { let tunnel = Arc::clone(&self.tunnel); - let mut control_signaler = self.control_signaler.clone(); + let mut control_signaler = self.phoenix_channel.clone(); tokio::spawn(async move { let err = match tunnel .request_connection(resource_id, gateway_id, relays, reference) @@ -147,7 +118,6 @@ impl ControlPlane { { Ok(Request::NewConnection(connection_request)) => { if let Err(err) = control_signaler - .control_signal // TODO: create a reference number and keep track for the response .send_with_ref( EgressMessages::RequestConnection(connection_request), @@ -162,7 +132,6 @@ impl ControlPlane { } Ok(Request::ReuseConnection(connection_request)) => { if let Err(err) = control_signaler - .control_signal // TODO: create a reference number and keep track for the response .send_with_ref( EgressMessages::ReuseConnection(connection_request), @@ -178,7 +147,7 @@ impl ControlPlane { Err(err) => err, }; - tunnel.cleanup_connection(resource_id.into()); + tunnel.cleanup_connection(resource_id); tracing::error!("Error request connection details: {err}"); let _ = tunnel.callbacks().on_error(&err); }); @@ -250,7 +219,7 @@ impl ControlPlane { return; }; // TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection - self.tunnel.cleanup_connection(resource_id.into()); + self.tunnel.cleanup_connection(resource_id); } None => { tracing::error!( @@ -273,8 +242,7 @@ impl ControlPlane { tracing::info!("Requesting log upload URL from portal"); let _ = self - .control_signaler - .control_signal + .phoenix_channel .send(EgressMessages::CreateLogSink {}) .await; } @@ -283,8 +251,7 @@ impl ControlPlane { match event { firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => { if let Err(e) = self - .control_signaler - .control_signal + .phoenix_channel .send(EgressMessages::BroadcastIceCandidates( BroadcastGatewayIceCandidates { gateway_ids: vec![conn_id], @@ -296,6 +263,28 @@ impl ControlPlane { tracing::error!("Failed to signal ICE candidate: {e}") } } + firezone_tunnel::Event::ConnectionIntent { + resource, + connected_gateway_ids, + reference, + } => { + if let Err(e) = self + .phoenix_channel + .clone() + .send_with_ref( + EgressMessages::PrepareConnection { + resource_id: resource.id(), + connected_gateway_ids: connected_gateway_ids.to_vec(), + }, + reference, + ) + .await + { + tracing::error!("Failed to prepare connection: {e}"); + + // TODO: Clean up connection in `ClientState` here? + } + } } } } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 8d9889ff6..eda2ff2b7 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -3,7 +3,6 @@ pub use connlib_shared::{get_device_id, messages::ResourceDescription}; pub use connlib_shared::{Callbacks, Error}; pub use tracing_appender::non_blocking::WorkerGuard; -use crate::control::ControlSignaler; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use connlib_shared::control::SecureUrl; use connlib_shared::{control::PhoenixChannel, login_url, CallbackErrorFacade, Mode, Result}; @@ -149,16 +148,15 @@ where } }); - let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("client".to_owned()) }; let tunnel = fatal_error!( - Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await, + Tunnel::new(private_key, callbacks.clone()).await, runtime_stopper, &callbacks ); let mut control_plane = ControlPlane { tunnel: Arc::new(tunnel), - control_signaler, + phoenix_channel: connection.sender_with_topic("client".to_owned()), tunnel_init: Mutex::new(false), }; diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 3254e6425..5e2b69352 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,25 +1,34 @@ use crate::device_channel::{create_iface, DeviceIo}; -use crate::ip_packet::IpPacket; +use crate::peer::Peer; +use crate::resource_table::ResourceTable; use crate::{ - dns, tokio_util, ConnId, ControlSignal, Device, Event, RoleState, Tunnel, + dns, peer_by_ip, tokio_util, Device, Event, PeerConfig, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE, }; +use boringtun::x25519::{PublicKey, StaticSecret}; use connlib_shared::error::{ConnlibError as Error, ConnlibError}; -use connlib_shared::messages::{GatewayId, Interface as InterfaceConfig, ResourceDescription}; +use connlib_shared::messages::{ + GatewayId, Interface as InterfaceConfig, Key, ResourceDescription, ResourceId, ReuseConnection, + SecretKey, +}; use connlib_shared::{Callbacks, DNS_SENTINEL}; use futures::channel::mpsc::Receiver; +use futures::stream; use futures_bounded::{PushError, StreamMap}; use ip_network::IpNetwork; +use ip_network_table::IpNetworkTable; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::io; +use std::net::IpAddr; use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll}; use std::time::Duration; +use tokio::time::Instant; use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, { /// Adds a the given resource to the tunnel. @@ -47,9 +56,9 @@ where } let resource_list = { - let mut resources = self.resources.write(); - resources.insert(resource_description); - resources.resource_list() + let mut role_state = self.role_state.lock(); + role_state.resources.insert(resource_description); + role_state.resources.resource_list() }; self.callbacks.on_update_resources(resource_list)?; @@ -80,6 +89,13 @@ where Ok(()) } + /// Clean up a connection to a resource. + // FIXME: this cleanup connection is wrong! + pub fn cleanup_connection(&self, id: ResourceId) { + self.role_state.lock().on_connection_failed(id); + self.peer_connections.lock().remove(&id.into()); + } + #[tracing::instrument(level = "trace", skip(self))] async fn add_route(self: &Arc, route: IpNetwork) -> connlib_shared::Result<()> { let mut device = self.device.write().await; @@ -100,80 +116,14 @@ where Ok(()) } - - #[inline(always)] - fn connection_intent(self: &Arc, packet: IpPacket<'_>) { - const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2); - - // We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this - if let Some(resource) = self.get_resource(packet.destination()) { - // We have awaiting connection to prevent a race condition where - // create_peer_connection hasn't added the thing to peer_connections - // and we are finding another packet to the same address (otherwise we would just use peer_connections here) - let mut awaiting_connection = self.awaiting_connection.lock(); - let conn_id = ConnId::from(resource.id()); - if awaiting_connection.get(&conn_id).is_none() { - tracing::trace!( - resource_ip = %packet.destination(), - "resource_connection_intent", - ); - - awaiting_connection.insert(conn_id, Default::default()); - let dev = Arc::clone(self); - - let mut connected_gateway_ids: Vec<_> = dev - .gateway_awaiting_connection - .lock() - .clone() - .into_keys() - .collect(); - connected_gateway_ids - .extend(dev.resources_gateways.lock().values().collect::>()); - tracing::trace!( - gateways = ?connected_gateway_ids, - "connected_gateways" - ); - tokio::spawn(async move { - let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY); - loop { - interval.tick().await; - let reference = { - let mut awaiting_connections = dev.awaiting_connection.lock(); - let Some(awaiting_connection) = - awaiting_connections.get_mut(&ConnId::from(resource.id())) - else { - break; - }; - if awaiting_connection.response_received { - break; - } - awaiting_connection.total_attemps += 1; - awaiting_connection.total_attemps - }; - if let Err(e) = dev - .control_signaler - .signal_connection_to(&resource, &connected_gateway_ids, reference) - .await - { - // Not a deadlock because this is a different task - dev.awaiting_connection.lock().remove(&conn_id); - tracing::error!(error = ?e, "start_resource_connection"); - let _ = dev.callbacks.on_error(&e); - } - } - }); - } - } - } } /// Reads IP packets from the [`Device`] and handles them accordingly. -async fn device_handler( - tunnel: Arc>, +async fn device_handler( + tunnel: Arc>, mut device: Device, ) -> Result<(), ConnlibError> where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, { let device_writer = device.io.clone(); @@ -183,7 +133,9 @@ where return Ok(()); }; - if let Some(dns_packet) = dns::parse(&tunnel.resources.read(), packet.as_immutable()) { + if let Some(dns_packet) = + dns::parse(&tunnel.role_state.lock().resources, packet.as_immutable()) + { if let Err(e) = send_dns_packet(&device_writer, dns_packet) { tracing::error!(err = %e, "failed to send DNS packet"); let _ = tunnel.callbacks.on_error(&e.into()); @@ -194,8 +146,11 @@ where let dest = packet.destination(); - let Some(peer) = tunnel.peer_by_ip(dest) else { - tunnel.connection_intent(packet.as_immutable()); + let Some(peer) = peer_by_ip(&tunnel.peers_by_ip.read(), dest) else { + tunnel + .role_state + .lock() + .on_connection_intent(packet.destination()); continue; }; @@ -223,9 +178,190 @@ pub struct ClientState { active_candidate_receivers: StreamMap, /// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway. waiting_for_sdp_from_gatway: HashMap>, + + // TODO: Make private + pub awaiting_connection: HashMap, + pub gateway_awaiting_connection: HashMap>, + + awaiting_connection_timers: StreamMap, + + pub gateway_public_keys: HashMap, + resources_gateways: HashMap, + resources: ResourceTable, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AwaitingConnectionDetails { + total_attemps: usize, + response_received: bool, + gateways: Vec, } impl ClientState { + pub(crate) fn attempt_to_reuse_connection( + &mut self, + resource: ResourceId, + gateway: GatewayId, + expected_attempts: usize, + connected_peers: &mut IpNetworkTable>, + ) -> Result, ConnlibError> { + if self.is_connected_to(resource, connected_peers) { + return Err(Error::UnexpectedConnectionDetails); + } + + let desc = self + .resources + .get_by_id(&resource) + .ok_or(Error::UnknownResource)?; + + let details = self + .awaiting_connection + .get_mut(&resource) + .ok_or(Error::UnexpectedConnectionDetails)?; + + details.response_received = true; + + if details.total_attemps != expected_attempts { + return Err(Error::UnexpectedConnectionDetails); + } + + self.resources_gateways.insert(resource, gateway); + + match self.gateway_awaiting_connection.entry(gateway) { + Entry::Occupied(mut occupied) => { + occupied.get_mut().extend(desc.ips()); + return Ok(Some(ReuseConnection { + resource_id: resource, + gateway_id: gateway, + })); + } + Entry::Vacant(vacant) => { + vacant.insert(vec![]); + } + } + + let found = { + let peer = connected_peers + .iter() + .find_map(|(_, p)| (p.conn_id == gateway.into()).then_some(p)) + .cloned(); + if let Some(peer) = peer { + for ip in desc.ips() { + peer.add_allowed_ip(ip); + connected_peers.insert(ip, Arc::clone(&peer)); + } + true + } else { + false + } + }; + + if found { + self.awaiting_connection.remove(&resource); + self.awaiting_connection_timers.remove(resource); + + return Ok(Some(ReuseConnection { + resource_id: resource, + gateway_id: gateway, + })); + } + + Ok(None) + } + + pub fn on_connection_failed(&mut self, resource: ResourceId) { + self.awaiting_connection.remove(&resource); + let Some(gateway) = self.resources_gateways.remove(&resource) else { + return; + }; + self.gateway_awaiting_connection.remove(&gateway); + self.awaiting_connection_timers.remove(resource); + } + + pub fn on_connection_intent(&mut self, destination: IpAddr) { + if self.is_awaiting_connection_to(destination) { + return; + } + + tracing::trace!(resource_ip = %destination, "resource_connection_intent"); + + let Some(resource) = self.get_resource_by_destination(destination) else { + return; + }; + + const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2); + + let resource_id = resource.id(); + + let connected_gateway_ids = self + .gateway_awaiting_connection + .clone() + .into_keys() + .chain(self.resources_gateways.values().cloned()) + .collect(); + + tracing::trace!( + gateways = ?connected_gateway_ids, + "connected_gateways" + ); + + match self.awaiting_connection_timers.try_push( + resource_id, + stream::poll_fn({ + let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY); + move |cx| interval.poll_tick(cx).map(Some) + }), + ) { + Ok(()) => {} + Err(PushError::BeyondCapacity(_)) => { + tracing::warn!(%resource_id, "Too many concurrent connection attempts"); + return; + } + Err(PushError::Replaced(_)) => { + // The timers are equivalent for our purpose so we don't really care about this one. + } + } + + self.awaiting_connection.insert( + resource_id, + AwaitingConnectionDetails { + total_attemps: 0, + response_received: false, + gateways: connected_gateway_ids, + }, + ); + } + + pub fn create_peer_config_for_new_connection( + &mut self, + resource: ResourceId, + gateway: GatewayId, + shared_key: StaticSecret, + ) -> Result { + let Some(public_key) = self.gateway_public_keys.remove(&gateway) else { + self.awaiting_connection.remove(&resource); + self.gateway_awaiting_connection.remove(&gateway); + + return Err(Error::ControlProtocolError); + }; + + let desc = self + .resources + .get_by_id(&resource) + .ok_or(Error::ControlProtocolError)?; + + Ok(PeerConfig { + persistent_keepalive: None, + public_key, + ips: desc.ips(), + preshared_key: SecretKey::new(Key(shared_key.to_bytes())), + }) + } + + pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option { + self.resources_gateways.get(resource).copied() + } + pub fn add_waiting_ice_receiver( &mut self, id: GatewayId, @@ -234,10 +370,11 @@ impl ClientState { self.waiting_for_sdp_from_gatway.insert(id, receiver); } - pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) { + pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId, key: PublicKey) { let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else { return; }; + self.gateway_public_keys.insert(id, key); match self.active_candidate_receivers.try_push(id, receiver) { Ok(()) => {} @@ -249,6 +386,36 @@ impl ClientState { } } } + + fn is_awaiting_connection_to(&self, destination: IpAddr) -> bool { + let Some(resource) = self.get_resource_by_destination(destination) else { + return false; + }; + + self.awaiting_connection.contains_key(&resource.id()) + } + + fn is_connected_to( + &self, + resource: ResourceId, + connected_peers: &IpNetworkTable>, + ) -> bool { + let Some(resource) = self.resources.get_by_id(&resource) else { + return false; + }; + + resource + .ips() + .iter() + .any(|ip| connected_peers.exact_match(*ip).is_some()) + } + + fn get_resource_by_destination(&self, destination: IpAddr) -> Option<&ResourceDescription> { + match destination { + IpAddr::V4(ipv4) => self.resources.get_by_ip(ipv4), + IpAddr::V6(ipv6) => self.resources.get_by_ip(ipv6), + } + } } impl Default for ClientState { @@ -259,6 +426,12 @@ impl Default for ClientState { MAX_CONCURRENT_ICE_GATHERING, ), waiting_for_sdp_from_gatway: Default::default(), + awaiting_connection: Default::default(), + gateway_awaiting_connection: Default::default(), + awaiting_connection_timers: StreamMap::new(Duration::from_secs(60), 100), + gateway_public_keys: Default::default(), + resources_gateways: Default::default(), + resources: Default::default(), } } } @@ -268,18 +441,60 @@ impl RoleState for ClientState { fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) { - (conn_id, Some(Ok(c))) => { + match self.active_candidate_receivers.poll_next_unpin(cx) { + Poll::Ready((conn_id, Some(Ok(c)))) => { return Poll::Ready(Event::SignalIceCandidate { conn_id, candidate: c, }) } - (id, Some(Err(e))) => { + Poll::Ready((id, Some(Err(e)))) => { tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}") } - (_, None) => {} + Poll::Ready((_, None)) => continue, + Poll::Pending => {} } + + match self.awaiting_connection_timers.poll_next_unpin(cx) { + Poll::Ready((resource, Some(Ok(_)))) => { + let Entry::Occupied(mut entry) = self.awaiting_connection.entry(resource) + else { + self.awaiting_connection_timers.remove(resource); + + continue; + }; + + if entry.get().response_received { + self.awaiting_connection_timers.remove(resource); + + // entry.remove(); Maybe? + + continue; + } + + entry.get_mut().total_attemps += 1; + + let reference = entry.get_mut().total_attemps; + + return Poll::Ready(Event::ConnectionIntent { + resource: self + .resources + .get_by_id(&resource) + .expect("inconsistent internal state") + .clone(), + connected_gateway_ids: entry.get().gateways.clone(), + reference, + }); + } + + Poll::Ready((id, Some(Err(e)))) => { + tracing::warn!(resource_id = %id, "Connection establishment timeout: {e}") + } + Poll::Ready((_, None)) => continue, + Poll::Pending => {} + } + + return Poll::Pending; } } } diff --git a/rust/connlib/tunnel/src/control_protocol.rs b/rust/connlib/tunnel/src/control_protocol.rs index f3f5d044e..0d3c92a2e 100644 --- a/rust/connlib/tunnel/src/control_protocol.rs +++ b/rust/connlib/tunnel/src/control_protocol.rs @@ -1,17 +1,14 @@ -use boringtun::noise::Tunn; -use chrono::{DateTime, Utc}; use futures::channel::mpsc; use futures_util::SinkExt; -use secrecy::ExposeSecret; use std::sync::Arc; -use tracing::instrument; use connlib_shared::{ - messages::{Relay, RequestConnection, ResourceDescription, ReuseConnection}, + messages::{Relay, RequestConnection, ReuseConnection}, Callbacks, Error, Result, }; +use webrtc::data_channel::OnCloseHdlrFn; +use webrtc::peer_connection::OnPeerConnectionStateChangeHdlrFn; use webrtc::{ - data_channel::RTCDataChannel, ice_transport::{ ice_candidate::RTCIceCandidateInit, ice_credential_type::RTCIceCredentialType, ice_server::RTCIceServer, @@ -22,7 +19,7 @@ use webrtc::{ }, }; -use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, RoleState, Tunnel}; +use crate::{ConnId, RoleState, Tunnel}; mod client; mod gateway; @@ -36,177 +33,35 @@ pub enum Request { ReuseConnection(ReuseConnection), } -#[tracing::instrument(level = "trace", skip(tunnel))] -async fn handle_connection_state_update_with_peer( - tunnel: &Arc>, - state: RTCPeerConnectionState, - index: u32, - conn_id: ConnId, -) where - C: ControlSignal + Clone + Send + Sync + 'static, - CB: Callbacks + 'static, - TRoleState: RoleState, -{ - tracing::trace!(?state, "peer_state_update"); - if state == RTCPeerConnectionState::Failed { - tunnel.stop_peer(index, conn_id).await; - } -} - -#[tracing::instrument(level = "trace", skip(tunnel))] -fn set_connection_state_with_peer( - tunnel: &Arc>, - peer_connection: &Arc, - index: u32, - conn_id: ConnId, -) where - C: ControlSignal + Clone + Send + Sync + 'static, - CB: Callbacks + 'static, - TRoleState: RoleState, -{ - let tunnel = Arc::clone(tunnel); - peer_connection.on_peer_connection_state_change(Box::new( - move |state: RTCPeerConnectionState| { - let tunnel = Arc::clone(&tunnel); - Box::pin(async move { - handle_connection_state_update_with_peer(&tunnel, state, index, conn_id).await - }) - }, - )); -} - -impl Tunnel +impl Tunnel where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, TRoleState: RoleState, { - #[instrument(level = "trace", skip(self, data_channel, peer_config))] - async fn handle_channel_open( - self: &Arc, - data_channel: Arc, - index: u32, - peer_config: PeerConfig, - conn_id: ConnId, - resources: Option<(ResourceDescription, DateTime)>, - ) -> Result<()> { - tracing::trace!( - ?peer_config.ips, - "data_channel_open", - ); - let channel = data_channel.detach().await?; - let tunn = Tunn::new( - self.private_key.clone(), - peer_config.public_key, - Some(peer_config.preshared_key.expose_secret().0), - peer_config.persistent_keepalive, - index, - None, - )?; - - let peer = Arc::new(Peer::from_config( - tunn, - index, - &peer_config, - channel, - conn_id, - resources, - )); - - { - // Watch out! we need 2 locks, make sure you don't lock both at the same time anywhere else - let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock(); - let mut peers_by_ip = self.peers_by_ip.write(); - // In the gateway this will always be none, no harm done - match conn_id { - ConnId::Gateway(gateway_id) => { - if let Some(awaiting_ips) = gateway_awaiting_connection.remove(&gateway_id) { - for ip in awaiting_ips { - peer.add_allowed_ip(ip); - peers_by_ip.insert(ip, Arc::clone(&peer)); - } - } - } - ConnId::Client(_) => {} - ConnId::Resource(_) => {} - } - for ip in peer_config.ips { - peers_by_ip.insert(ip, Arc::clone(&peer)); - } - } - - if let Some(conn) = self.peer_connections.lock().get(&conn_id) { - set_connection_state_with_peer(self, conn, index, conn_id) - } - - data_channel.on_close({ - let tunnel = Arc::clone(self); - Box::new(move || { - tracing::debug!("channel_closed"); - let tunnel = tunnel.clone(); - Box::pin(async move { - tunnel.stop_peer(index, conn_id).await; - }) + pub fn on_dc_close_handler(self: Arc, index: u32, conn_id: ConnId) -> OnCloseHdlrFn { + Box::new(move || { + tracing::debug!("channel_closed"); + let tunnel = self.clone(); + Box::pin(async move { + tunnel.stop_peer(index, conn_id).await; }) - }); - - let tunnel = Arc::clone(self); - tokio::spawn(async move { tunnel.start_peer_handler(peer).await }); - - Ok(()) + }) } - #[tracing::instrument(level = "trace", skip(self))] - pub async fn new_peer_connection( - self: &Arc, - relays: Vec, - ) -> Result<(Arc, mpsc::Receiver)> { - let config = RTCConfiguration { - ice_servers: relays - .into_iter() - .map(|srv| match srv { - Relay::Stun(stun) => RTCIceServer { - urls: vec![stun.uri], - ..Default::default() - }, - Relay::Turn(turn) => RTCIceServer { - urls: vec![turn.uri], - username: turn.username, - credential: turn.password, - // TODO: check what this is used for - credential_type: RTCIceCredentialType::Password, - }, - }) - .collect(), - ..Default::default() - }; - - let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?); - - let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER); - - peer_connection.on_ice_candidate(Box::new(move |candidate| { - let Some(candidate) = candidate else { - return Box::pin(async {}); - }; - - let mut ice_candidate_tx = ice_candidate_tx.clone(); + pub fn on_peer_connection_state_change_handler( + self: Arc, + index: u32, + conn_id: ConnId, + ) -> OnPeerConnectionStateChangeHdlrFn { + Box::new(move |state| { + let tunnel = Arc::clone(&self); Box::pin(async move { - let ice_candidate = match candidate.to_json() { - Ok(ice_candidate) => ice_candidate, - Err(e) => { - tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",); - return; - } - }; - - if ice_candidate_tx.send(ice_candidate).await.is_err() { - debug_assert!(false, "receiver was dropped before sender") + tracing::trace!(?state, "peer_state_update"); + if state == RTCPeerConnectionState::Failed { + tunnel.stop_peer(index, conn_id).await; } }) - })); - - Ok((peer_connection, ice_candidate_rx)) + }) } pub async fn add_ice_candidate( @@ -223,11 +78,57 @@ where peer_connection.add_ice_candidate(ice_candidate).await?; Ok(()) } - - /// Clean up a connection to a resource. - // FIXME: this cleanup connection is wrong! - pub fn cleanup_connection(&self, id: ConnId) { - self.awaiting_connection.lock().remove(&id); - self.peer_connections.lock().remove(&id); - } +} + +#[tracing::instrument(level = "trace", skip(webrtc))] +pub async fn new_peer_connection( + webrtc: &webrtc::api::API, + relays: Vec, +) -> Result<(Arc, mpsc::Receiver)> { + let config = RTCConfiguration { + ice_servers: relays + .into_iter() + .map(|srv| match srv { + Relay::Stun(stun) => RTCIceServer { + urls: vec![stun.uri], + ..Default::default() + }, + Relay::Turn(turn) => RTCIceServer { + urls: vec![turn.uri], + username: turn.username, + credential: turn.password, + // TODO: check what this is used for + credential_type: RTCIceCredentialType::Password, + }, + }) + .collect(), + ..Default::default() + }; + + let peer_connection = Arc::new(webrtc.new_peer_connection(config).await?); + + let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER); + + peer_connection.on_ice_candidate(Box::new(move |candidate| { + let Some(candidate) = candidate else { + return Box::pin(async {}); + }; + + let mut ice_candidate_tx = ice_candidate_tx.clone(); + Box::pin(async move { + let ice_candidate = match candidate.to_json() { + Ok(ice_candidate) => ice_candidate, + Err(e) => { + tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",); + return; + } + }; + + if ice_candidate_tx.send(ice_candidate).await.is_err() { + debug_assert!(false, "receiver was dropped before sender") + } + }) + })); + + Ok((peer_connection, ice_candidate_rx)) } diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index e5c006c5d..a92d419eb 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -1,10 +1,9 @@ use std::sync::Arc; use boringtun::x25519::{PublicKey, StaticSecret}; -use connlib_shared::messages::SecretKey; use connlib_shared::{ control::Reference, - messages::{GatewayId, Key, Relay, RequestConnection, ResourceId, ReuseConnection}, + messages::{GatewayId, Key, Relay, RequestConnection, ResourceId}, Callbacks, }; use rand_core::OsRng; @@ -17,40 +16,32 @@ use webrtc::{ }, }; -use crate::{ClientState, ControlSignal, Error, PeerConfig, Request, Result, Tunnel}; +use crate::control_protocol::new_peer_connection; +use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel}; #[tracing::instrument(level = "trace", skip(tunnel))] -fn handle_connection_state_update( - tunnel: &Arc>, +fn handle_connection_state_update( + tunnel: &Arc>, state: RTCPeerConnectionState, gateway_id: GatewayId, resource_id: ResourceId, ) where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, { tracing::trace!("peer_state"); if state == RTCPeerConnectionState::Failed { - tunnel - .awaiting_connection - .lock() - .remove(&resource_id.into()); + tunnel.role_state.lock().on_connection_failed(resource_id); tunnel.peer_connections.lock().remove(&gateway_id.into()); - tunnel - .gateway_awaiting_connection - .lock() - .remove(&gateway_id); } } #[tracing::instrument(level = "trace", skip(tunnel))] -fn set_connection_state_update( - tunnel: &Arc>, +fn set_connection_state_update( + tunnel: &Arc>, peer_connection: &Arc, gateway_id: GatewayId, resource_id: ResourceId, ) where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, { let tunnel = Arc::clone(tunnel); @@ -64,9 +55,8 @@ fn set_connection_state_update( )); } -impl Tunnel +impl Tunnel where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, { /// Initiate an ice connection request. @@ -89,77 +79,23 @@ where reference: Option, ) -> Result { tracing::trace!("request_connection"); - let resource_description = self - .resources - .read() - .get_by_id(&resource_id) - .ok_or(Error::UnknownResource)? - .clone(); let reference: usize = reference .ok_or(Error::InvalidReference)? .parse() .map_err(|_| Error::InvalidReference)?; - { - let mut awaiting_connections = self.awaiting_connection.lock(); - let Some(awaiting_connection) = awaiting_connections.get_mut(&resource_id.into()) - else { - return Err(Error::UnexpectedConnectionDetails); - }; - awaiting_connection.response_received = true; - if awaiting_connection.total_attemps != reference - || resource_description - .ips() - .iter() - .any(|&ip| self.peers_by_ip.read().exact_match(ip).is_some()) - { - return Err(Error::UnexpectedConnectionDetails); - } + + if let Some(connection) = self.role_state.lock().attempt_to_reuse_connection( + resource_id, + gateway_id, + reference, + &mut self.peers_by_ip.write(), + )? { + return Ok(Request::ReuseConnection(connection)); } - self.resources_gateways - .lock() - .insert(resource_id, gateway_id); - { - let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock(); - if let Some(g) = gateway_awaiting_connection.get_mut(&gateway_id) { - g.extend(resource_description.ips()); - return Ok(Request::ReuseConnection(ReuseConnection { - resource_id, - gateway_id, - })); - } else { - gateway_awaiting_connection.insert(gateway_id, vec![]); - } - } - { - let found = { - let mut peers_by_ip = self.peers_by_ip.write(); - let peer = peers_by_ip - .iter() - .find_map(|(_, p)| (p.conn_id == gateway_id.into()).then_some(p)) - .cloned(); - if let Some(peer) = peer { - for ip in resource_description.ips() { - peer.add_allowed_ip(ip); - peers_by_ip.insert(ip, Arc::clone(&peer)); - } - true - } else { - false - } - }; - - if found { - self.awaiting_connection.lock().remove(&resource_id.into()); - return Ok(Request::ReuseConnection(ReuseConnection { - resource_id, - gateway_id, - })); - } - } let peer_connection = { - let (peer_connection, receiver) = self.new_peer_connection(relays).await?; + let (peer_connection, receiver) = new_peer_connection(&self.webrtc_api, relays).await?; self.role_state .lock() .add_waiting_ice_receiver(gateway_id, receiver); @@ -191,46 +127,63 @@ where Box::pin(async move { tracing::trace!("new_data_channel_opened"); let index = tunnel.next_index(); - let Some(gateway_public_key) = - tunnel.gateway_public_keys.lock().remove(&gateway_id) - else { - tunnel - .awaiting_connection - .lock() - .remove(&resource_id.into()); - tunnel.peer_connections.lock().remove(&gateway_id.into()); - tunnel - .gateway_awaiting_connection - .lock() - .remove(&gateway_id); - let e = Error::ControlProtocolError; - tracing::warn!(err = ?e, "channel_open"); - let _ = tunnel.callbacks.on_error(&e); - return; - }; - let peer_config = PeerConfig { - persistent_keepalive: None, - public_key: gateway_public_key, - ips: resource_description.ips(), - preshared_key: SecretKey::new(Key(p_key.to_bytes())), + + let peer_config = match tunnel.role_state.lock().create_peer_config_for_new_connection(resource_id, gateway_id, p_key) { + Ok(c) => c, + Err(e) => { + tunnel.peer_connections.lock().remove(&gateway_id.into()); + + tracing::warn!(err = ?e, "channel_open"); + let _ = tunnel.callbacks.on_error(&e); + return; + } }; - if let Err(e) = tunnel - .handle_channel_open(d, index, peer_config, gateway_id.into(), None) - .await + d.on_close(tunnel.clone().on_dc_close_handler(index, gateway_id.into())); + + let peer = Arc::new(Peer::new( + tunnel.private_key.clone(), + index, + peer_config.clone(), + d.detach().await.expect("only fails if not opened or not enabled, both of which are always true for us"), + gateway_id.into(), + None, + )); + { - tracing::error!(err = ?e, "channel_open"); - let _ = tunnel.callbacks.on_error(&e); - tunnel.peer_connections.lock().remove(&gateway_id.into()); - tunnel - .gateway_awaiting_connection - .lock() - .remove(&gateway_id); + let mut role_state = tunnel.role_state.lock(); + // Watch out! we need 2 locks, make sure you don't lock both at the same time anywhere else + let mut peers_by_ip = tunnel.peers_by_ip.write(); + + if let Some(awaiting_ips) = + role_state.gateway_awaiting_connection.remove(&gateway_id) + { + for ip in awaiting_ips { + peer.add_allowed_ip(ip); + peers_by_ip.insert(ip, Arc::clone(&peer)); + } + } + + for ip in peer_config.ips { + peers_by_ip.insert(ip, Arc::clone(&peer)); + } } + + if let Some(conn) = tunnel.peer_connections.lock().get(&gateway_id.into()) { + conn.on_peer_connection_state_change( + tunnel + .clone() + .on_peer_connection_state_change_handler(index, gateway_id.into()), + ); + } + + tokio::spawn(tunnel.clone().start_peer_handler(peer)); + tunnel - .awaiting_connection + .role_state .lock() - .remove(&resource_id.into()); + .awaiting_connection + .remove(&resource_id); }) })); @@ -260,10 +213,10 @@ where rtc_sdp: RTCSessionDescription, gateway_public_key: PublicKey, ) -> Result<()> { - let gateway_id = *self - .resources_gateways + let gateway_id = self + .role_state .lock() - .get(&resource_id) + .gateway_by_resource(&resource_id) .ok_or(Error::UnknownResource)?; let peer_connection = self .peer_connections @@ -271,14 +224,11 @@ where .get(&gateway_id.into()) .ok_or(Error::UnknownResource)? .clone(); - self.gateway_public_keys - .lock() - .insert(gateway_id, gateway_public_key); - peer_connection.set_remote_description(rtc_sdp).await?; + self.role_state .lock() - .activate_ice_candidate_receiver(gateway_id); + .activate_ice_candidate_receiver(gateway_id, gateway_public_key); Ok(()) } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index 0e34ed189..ab7fb16d9 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -10,15 +10,15 @@ use webrtc::peer_connection::{ RTCPeerConnection, }; -use crate::{ControlSignal, GatewayState, PeerConfig, Tunnel}; +use crate::control_protocol::new_peer_connection; +use crate::{peer::Peer, GatewayState, PeerConfig, Tunnel}; #[tracing::instrument(level = "trace", skip(tunnel))] -fn handle_connection_state_update( - tunnel: &Arc>, +fn handle_connection_state_update( + tunnel: &Arc>, state: RTCPeerConnectionState, client_id: ClientId, ) where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, { tracing::trace!(?state, "peer_state"); @@ -28,12 +28,11 @@ fn handle_connection_state_update( } #[tracing::instrument(level = "trace", skip(tunnel))] -fn set_connection_state_update( - tunnel: &Arc>, +fn set_connection_state_update( + tunnel: &Arc>, peer_connection: &Arc, client_id: ClientId, ) where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, { let tunnel = Arc::clone(tunnel); @@ -45,9 +44,8 @@ fn set_connection_state_update( )); } -impl Tunnel +impl Tunnel where - C: ControlSignal + Clone + Send + Sync + 'static, CB: Callbacks + 'static, { /// Accept a connection request from a client. @@ -72,7 +70,7 @@ where expires_at: DateTime, resource: ResourceDescription, ) -> Result { - let (peer_connection, receiver) = self.new_peer_connection(relays).await?; + let (peer_connection, receiver) = new_peer_connection(&self.webrtc_api, relays).await?; self.role_state .lock() .add_new_ice_receiver(client_id, receiver); @@ -88,12 +86,12 @@ where peer_connection.on_data_channel(Box::new(move |d| { tracing::trace!("new_data_channel"); let data_channel = Arc::clone(&d); - let peer = peer.clone(); + let peer_config = peer.clone(); let tunnel = Arc::clone(&tunnel); let resource = resource.clone(); Box::pin(async move { d.on_open(Box::new(move || { - tracing::trace!("new_data_channel_open"); + tracing::trace!(?peer_config.ips, "new_data_channel_open"); Box::pin(async move { { let Some(device) = tunnel.device.read().await.clone() else { @@ -103,7 +101,7 @@ where return; }; let iface_config = device.config; - for &ip in &peer.ips { + for &ip in &peer_config.ips { if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await { let _ = tunnel.callbacks.on_error(&e); @@ -111,28 +109,34 @@ where } } - if let Err(e) = tunnel - .handle_channel_open( - data_channel, - index, - peer, - client_id.into(), - Some((resource, expires_at)), - ) - .await - { - let _ = tunnel.callbacks.on_error(&e); - tracing::error!(err = ?e, "channel_open"); - // Note: handle_channel_open can only error out before insert to peers_by_ip - // otherwise we would need to clean that up too! - let conn = tunnel.peer_connections.lock().remove(&client_id.into()); - if let Some(conn) = conn { - if let Err(e) = conn.close().await { - tracing::error!(error = ?e, "webrtc_close_channel"); - let _ = tunnel.callbacks().on_error(&e.into()); - } - } + data_channel + .on_close(tunnel.clone().on_dc_close_handler(index, client_id.into())); + + let peer = Arc::new(Peer::new( + tunnel.private_key.clone(), + index, + peer_config.clone(), + data_channel.detach().await.expect("only fails if not opened or not enabled, both of which are always true for us"), + client_id.into(), + Some((resource, expires_at)), + )); + + let mut peers_by_ip = tunnel.peers_by_ip.write(); + + for ip in peer_config.ips { + peers_by_ip.insert(ip, Arc::clone(&peer)); } + + if let Some(conn) = tunnel.peer_connections.lock().get(&client_id.into()) { + conn.on_peer_connection_state_change( + tunnel.clone().on_peer_connection_state_change_handler( + index, + client_id.into(), + ), + ); + } + + tokio::spawn(tunnel.clone().start_peer_handler(peer)); }) })) }) diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 4c5d9912e..2f88d9548 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,6 +1,6 @@ use crate::device_channel::create_iface; use crate::{ - ControlSignal, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, + peer_by_ip, ConnId, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE, }; use connlib_shared::error::ConnlibError; @@ -13,9 +13,8 @@ use std::task::{ready, Context, Poll}; use std::time::Duration; use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, { /// Sets the interface configuration and starts background tasks. @@ -35,15 +34,20 @@ where Ok(()) } + + /// Clean up a connection to a resource. + // FIXME: this cleanup connection is wrong! + pub fn cleanup_connection(&self, id: ConnId) { + self.peer_connections.lock().remove(&id); + } } /// Reads IP packets from the [`Device`] and handles them accordingly. -async fn device_handler( - tunnel: Arc>, +async fn device_handler( + tunnel: Arc>, mut device: Device, ) -> Result<(), ConnlibError> where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, { let mut buf = [0u8; MAX_UDP_SIZE]; @@ -55,7 +59,7 @@ where let dest = packet.destination(); - let Some(peer) = tunnel.peer_by_ip(dest) else { + let Some(peer) = peer_by_ip(&tunnel.peers_by_ip.read(), dest) else { continue; }; diff --git a/rust/connlib/tunnel/src/iface_handler.rs b/rust/connlib/tunnel/src/iface_handler.rs index ef69e0efc..98be7c242 100644 --- a/rust/connlib/tunnel/src/iface_handler.rs +++ b/rust/connlib/tunnel/src/iface_handler.rs @@ -4,11 +4,10 @@ use boringtun::noise::{errors::WireGuardError, TunnResult}; use bytes::Bytes; use connlib_shared::{Callbacks, Result}; -use crate::{ip_packet::MutableIpPacket, peer::Peer, ControlSignal, RoleState, Tunnel}; +use crate::{ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel}; -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, TRoleState: RoleState, { diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index c3060a70b..7c5bcb443 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -13,7 +13,6 @@ use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use serde::{Deserialize, Serialize}; -use async_trait::async_trait; use itertools::Itertools; use parking_lot::{Mutex, RwLock}; use peer::{Peer, PeerStats}; @@ -136,29 +135,6 @@ impl From for PeerConfig { } } } - -/// Trait used for out-going signals to control plane that are **required** to be made from inside the tunnel. -/// -/// Generally, we try to return from the functions here rather than using this callback. -#[async_trait] -pub trait ControlSignal { - /// Signals to the control plane an intent to initiate a connection to the given resource. - /// - /// Used when a packet is found to a resource we have no connection stablished but is within the list of resources available for the client. - async fn signal_connection_to( - &self, - resource: &ResourceDescription, - connected_gateway_ids: &[GatewayId], - reference: usize, - ) -> Result<()>; -} - -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -struct AwaitingConnectionDetails { - pub total_attemps: usize, - pub response_received: bool, -} - #[derive(Clone)] struct Device { config: Arc, @@ -190,7 +166,7 @@ impl Device { // TODO: We should use newtypes for each kind of Id /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets /// to communicate between peers. -pub struct Tunnel { +pub struct Tunnel { next_index: Mutex, // We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact device: tokio::sync::RwLock>, @@ -199,13 +175,7 @@ pub struct Tunnel { public_key: PublicKey, peers_by_ip: RwLock>>, peer_connections: Mutex>>, - awaiting_connection: Mutex>, - gateway_awaiting_connection: Mutex>>, - resources_gateways: Mutex>, webrtc_api: API, - resources: Arc>>, - control_signaler: C, - gateway_public_keys: Mutex>, callbacks: CallbackErrorFacade, iface_handler_abort: Mutex>, @@ -220,18 +190,10 @@ pub struct TunnelStats { public_key: String, peers_by_ip: HashMap, peer_connections: Vec, - resource_gateways: HashMap, - dns_resources: HashMap, - network_resources: HashMap, - gateway_public_keys: HashMap, - - awaiting_connection: HashMap, - gateway_awaiting_connection: HashMap>, } -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, TRoleState: RoleState, { @@ -243,30 +205,11 @@ where .map(|(ip, peer)| (ip, peer.stats())) .collect(); let peer_connections = self.peer_connections.lock().keys().cloned().collect(); - let awaiting_connection = self.awaiting_connection.lock().clone(); - let gateway_awaiting_connection = self.gateway_awaiting_connection.lock().clone(); - let resource_gateways = self.resources_gateways.lock().clone(); - let (network_resources, dns_resources) = { - let resources = self.resources.read(); - (resources.network_resources(), resources.dns_resources()) - }; - let gateway_public_keys = self - .gateway_public_keys - .lock() - .iter() - .map(|(&id, &k)| (id, Key::from(k).to_string())) - .collect(); TunnelStats { public_key: Key::from(self.public_key).to_string(), peers_by_ip, peer_connections, - awaiting_connection, - gateway_awaiting_connection, - resource_gateways, - dns_resources, - network_resources, - gateway_public_keys, } } @@ -277,14 +220,10 @@ where pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll> { self.role_state.lock().poll_next_event(cx) } +} - pub(crate) fn peer_by_ip(&self, ip: IpAddr) -> Option> { - self.peers_by_ip - .read() - .longest_match(ip) - .map(|(_, peer)| peer) - .cloned() - } +pub(crate) fn peer_by_ip(peers_by_ip: &IpNetworkTable>, ip: IpAddr) -> Option> { + peers_by_ip.longest_match(ip).map(|(_, peer)| peer).cloned() } pub enum Event { @@ -292,11 +231,15 @@ pub enum Event { conn_id: TId, candidate: RTCIceCandidateInit, }, + ConnectionIntent { + resource: ResourceDescription, + connected_gateway_ids: Vec, + reference: usize, + }, } -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, TRoleState: RoleState, { @@ -305,22 +248,14 @@ where /// # Parameters /// - `private_key`: wireguard's private key. /// - `control_signaler`: this is used to send SDP from the tunnel to the control plane. - #[tracing::instrument(level = "trace", skip(private_key, control_signaler, callbacks))] - pub async fn new( - private_key: StaticSecret, - control_signaler: C, - callbacks: CB, - ) -> Result { + #[tracing::instrument(level = "trace", skip(private_key, callbacks))] + pub async fn new(private_key: StaticSecret, callbacks: CB) -> Result { let public_key = (&private_key).into(); let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT)); let peers_by_ip = RwLock::new(IpNetworkTable::new()); let next_index = Default::default(); let peer_connections = Default::default(); let resources: Arc>> = Default::default(); - let awaiting_connection = Default::default(); - let gateway_public_keys = Default::default(); - let resources_gateways = Default::default(); - let gateway_awaiting_connection = Default::default(); let device = Default::default(); let iface_handler_abort = Default::default(); @@ -347,7 +282,6 @@ where .build(); Ok(Self { - gateway_public_keys, rate_limiter, private_key, peer_connections, @@ -355,12 +289,7 @@ where peers_by_ip, next_index, webrtc_api, - resources, device, - awaiting_connection, - gateway_awaiting_connection, - control_signaler, - resources_gateways, callbacks: CallbackErrorFacade(callbacks), iface_handler_abort, role_state: Default::default(), @@ -411,31 +340,6 @@ where }); } - fn remove_expired_peers(self: &Arc) { - let mut peers_by_ip = self.peers_by_ip.write(); - - for (_, peer) in peers_by_ip.iter() { - peer.expire_resources(); - if peer.is_emptied() { - tracing::trace!(index = peer.index, "peer_expired"); - let conn = self.peer_connections.lock().remove(&peer.conn_id); - let p = peer.clone(); - - // We are holding a Mutex, particularly a write one, we don't want to make a blocking call - tokio::spawn(async move { - let _ = p.shutdown().await; - if let Some(conn) = conn { - // TODO: it seems that even closing the stream there are messages to the relay - // see where they come from. - let _ = conn.close().await; - } - }); - } - } - - peers_by_ip.retain(|_, p| !p.is_emptied()); - } - fn start_peers_refresh_timer(self: &Arc) { let tunnel = self.clone(); @@ -445,7 +349,10 @@ where let mut dst_buf = [0u8; MAX_UDP_SIZE]; loop { - tunnel.remove_expired_peers(); + remove_expired_peers( + &mut tunnel.peers_by_ip.write(), + &mut tunnel.peer_connections.lock(), + ); let peers: Vec<_> = tunnel .peers_by_ip @@ -497,14 +404,6 @@ where Ok(()) } - fn get_resource(&self, addr: IpAddr) -> Option { - let resources = self.resources.read(); - match addr { - IpAddr::V4(ipv4) => resources.get_by_ip(ipv4).cloned(), - IpAddr::V6(ipv6) => resources.get_by_ip(ipv6).cloned(), - } - } - fn next_index(&self) -> u32 { self.next_index.lock().next() } @@ -514,6 +413,32 @@ where } } +fn remove_expired_peers( + peers_by_ip: &mut IpNetworkTable>, + peer_connections: &mut HashMap>, +) { + for (_, peer) in peers_by_ip.iter() { + peer.expire_resources(); + if peer.is_emptied() { + tracing::trace!(index = peer.index, "peer_expired"); + let conn = peer_connections.remove(&peer.conn_id); + let p = peer.clone(); + + // We are holding a Mutex, particularly a write one, we don't want to make a blocking call + tokio::spawn(async move { + let _ = p.shutdown().await; + if let Some(conn) = conn { + // TODO: it seems that even closing the stream there are messages to the relay + // see where they come from. + let _ = conn.close().await; + } + }); + } + } + + peers_by_ip.retain(|_, p| !p.is_emptied()); +} + /// Dedicated trait for abstracting over the different ICE states. /// /// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`]. diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 034ea7832..7abcd6015 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, net::IpAddr, sync::Arc}; use boringtun::noise::{Tunn, TunnResult}; +use boringtun::x25519::StaticSecret; use bytes::Bytes; use chrono::{DateTime, Utc}; use connlib_shared::{ @@ -11,11 +12,10 @@ use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use parking_lot::{Mutex, RwLock}; use pnet_packet::MutablePacket; +use secrecy::ExposeSecret; use webrtc::data::data_channel::DataChannel; -use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId}; - -use super::PeerConfig; +use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId, PeerConfig}; type ExpiryingResource = (ResourceDescription, DateTime); @@ -87,34 +87,26 @@ impl Peer { } } - pub(crate) fn from_config( - tunnel: Tunn, - index: u32, - config: &PeerConfig, - channel: Arc, - conn_id: ConnId, - resource: Option<(ResourceDescription, DateTime)>, - ) -> Self { - Self::new( - Mutex::new(tunnel), - index, - config.ips.clone(), - channel, - conn_id, - resource, - ) - } - pub(crate) fn new( - tunnel: Mutex, + private_key: StaticSecret, index: u32, - ips: Vec, + peer_config: PeerConfig, channel: Arc, conn_id: ConnId, resource: Option<(ResourceDescription, DateTime)>, ) -> Peer { + let tunnel = Tunn::new( + private_key.clone(), + peer_config.public_key, + Some(peer_config.preshared_key.expose_secret().0), + peer_config.persistent_keepalive, + index, + None, + ) + .expect("never actually fails"); // See https://github.com/cloudflare/boringtun/pull/366. + let mut allowed_ips = IpNetworkTable::new(); - for ip in ips { + for ip in peer_config.ips { allowed_ips.insert(ip, ()); } let allowed_ips = RwLock::new(allowed_ips); @@ -123,8 +115,9 @@ impl Peer { resource_table.insert(r); RwLock::new(resource_table) }); + Peer { - tunnel, + tunnel: Mutex::new(tunnel), index, allowed_ips, channel, diff --git a/rust/connlib/tunnel/src/peer_handler.rs b/rust/connlib/tunnel/src/peer_handler.rs index ac75be900..7af92cafd 100644 --- a/rust/connlib/tunnel/src/peer_handler.rs +++ b/rust/connlib/tunnel/src/peer_handler.rs @@ -5,13 +5,12 @@ use bytes::Bytes; use connlib_shared::{Callbacks, Error, Result}; use crate::{ - device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, RoleState, - Tunnel, MAX_UDP_SIZE, + device_channel::DeviceIo, index::check_packet_index, peer::Peer, RoleState, Tunnel, + MAX_UDP_SIZE, }; -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, TRoleState: RoleState, { @@ -155,7 +154,7 @@ where Ok(()) } - pub(crate) async fn start_peer_handler(self: &Arc, peer: Arc) { + pub(crate) async fn start_peer_handler(self: Arc, peer: Arc) { loop { let Some(device) = self.device.read().await.clone() else { let err = Error::NoIface; diff --git a/rust/connlib/tunnel/src/resource_sender.rs b/rust/connlib/tunnel/src/resource_sender.rs index f164be770..c0bc8fdad 100644 --- a/rust/connlib/tunnel/src/resource_sender.rs +++ b/rust/connlib/tunnel/src/resource_sender.rs @@ -3,15 +3,12 @@ use std::{ sync::Arc, }; -use crate::{ - device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, ControlSignal, Tunnel, -}; +use crate::{device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, Tunnel}; use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result}; -impl Tunnel +impl Tunnel where - C: ControlSignal + Send + Sync + 'static, CB: Callbacks + 'static, { #[inline(always)] diff --git a/rust/gateway/src/control.rs b/rust/gateway/src/control.rs deleted file mode 100644 index b944cfa9a..000000000 --- a/rust/gateway/src/control.rs +++ /dev/null @@ -1,22 +0,0 @@ -use async_trait::async_trait; -use connlib_shared::{ - messages::{GatewayId, ResourceDescription}, - Result, -}; -use firezone_tunnel::ControlSignal; - -#[derive(Clone)] -pub struct ControlSignaler; - -#[async_trait] -impl ControlSignal for ControlSignaler { - async fn signal_connection_to( - &self, - resource: &ResourceDescription, - _connected_gateway_ids: &[GatewayId], - _: usize, - ) -> Result<()> { - tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients."); - Ok(()) - } -} diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 1c1056d32..abccc1677 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -1,4 +1,3 @@ -use crate::control::ControlSignaler; use crate::messages::{ AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady, EgressMessages, IngressMessages, @@ -7,7 +6,7 @@ use crate::CallbackHandler; use anyhow::Result; use connlib_shared::messages::ClientId; use connlib_shared::Error; -use firezone_tunnel::{GatewayState, Tunnel}; +use firezone_tunnel::{Event, GatewayState, Tunnel}; use phoenix_channel::PhoenixChannel; use std::convert::Infallible; use std::sync::Arc; @@ -18,7 +17,7 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; pub const PHOENIX_TOPIC: &str = "gateway"; pub struct Eventloop { - tunnel: Arc>, + tunnel: Arc>, portal: PhoenixChannel, // TODO: Strongly type request reference (currently `String`) @@ -31,7 +30,7 @@ pub struct Eventloop { impl Eventloop { pub(crate) fn new( - tunnel: Arc>, + tunnel: Arc>, portal: PhoenixChannel, ) -> Self { Self { @@ -190,6 +189,9 @@ impl Eventloop { ); continue; } + Poll::Ready(Event::ConnectionIntent { .. }) => { + unreachable!("Not used on the gateway, split the events!") + } Poll::Pending => {} } diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index f03c27c9a..2e70fe858 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -1,4 +1,3 @@ -use crate::control::ControlSignaler; use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use crate::messages::InitGateway; use anyhow::{Context, Result}; @@ -15,7 +14,6 @@ use std::sync::Arc; use tracing_subscriber::layer; use url::Url; -mod control; mod eventloop; mod messages; @@ -30,7 +28,7 @@ async fn main() -> Result<()> { SecretString::new(cli.common.secret), get_device_id(), )?; - let tunnel = Arc::new(Tunnel::new(private_key, ControlSignaler, CallbackHandler).await?); + let tunnel = Arc::new(Tunnel::new(private_key, CallbackHandler).await?); tokio::spawn(backoff::future::retry_notify( ExponentialBackoffBuilder::default() @@ -48,7 +46,7 @@ async fn main() -> Result<()> { } async fn run( - tunnel: Arc>, + tunnel: Arc>, connect_url: Url, ) -> Result { let (portal, init) = phoenix_channel::init::(