From 919b7890e6f87480f28c8d398de0136349be8c66 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 19 Oct 2023 13:30:04 +1100 Subject: [PATCH] refactor(connlib): move more logic to `poll_next_event` (#2403) --- rust/connlib/shared/src/error.rs | 20 ++ rust/connlib/tunnel/src/client.rs | 20 +- rust/connlib/tunnel/src/control_protocol.rs | 68 ++-- .../tunnel/src/control_protocol/client.rs | 32 +- .../tunnel/src/control_protocol/gateway.rs | 10 +- rust/connlib/tunnel/src/gateway.rs | 19 +- rust/connlib/tunnel/src/iface_handler.rs | 65 ---- rust/connlib/tunnel/src/lib.rs | 304 ++++++++++------- rust/connlib/tunnel/src/peer.rs | 65 ++-- rust/connlib/tunnel/src/peer_handler.rs | 318 ++++++++++++------ rust/connlib/tunnel/src/resource_sender.rs | 130 ------- 11 files changed, 537 insertions(+), 514 deletions(-) delete mode 100644 rust/connlib/tunnel/src/iface_handler.rs delete mode 100644 rust/connlib/tunnel/src/resource_sender.rs diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index 78e4956f8..ffcb41db7 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -145,6 +145,26 @@ impl ConnlibError { if e.status().is_client_error() ) } + + /// Whether this error is fatal to the underlying connection. + pub fn is_fatal_connection_error(&self) -> bool { + if let Self::WireguardError(e) = self { + return matches!( + e, + WireGuardError::ConnectionExpired | WireGuardError::NoCurrentSession + ); + } + + if let Self::IceDataError(e) = self { + return matches!( + e, + webrtc::data::Error::ErrStreamClosed + | webrtc::data::Error::Sctp(webrtc::sctp::Error::ErrStreamClosed) + ); + } + + false + } } #[cfg(target_os = "linux")] diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 35deaab67..3835de76b 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -18,6 +18,7 @@ use connlib_shared::{Callbacks, DNS_SENTINEL}; use futures::channel::mpsc::Receiver; use futures::stream; use futures_bounded::{PushError, StreamMap}; +use futures_util::SinkExt; use hickory_resolver::lookup::Lookup; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; @@ -116,9 +117,8 @@ where config: &InterfaceConfig, ) -> connlib_shared::Result<()> { let device = create_iface(config, self.callbacks()).await?; - *self.device.write().await = Some(device.clone()); - self.start_timers().await?; + *self.device.write().await = Some(device.clone()); *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log( &self.callbacks, device_handler(Arc::clone(self), device), @@ -203,12 +203,18 @@ where continue; }; - if let Err(e) = tunnel - .encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf) - .await - { + if let Err(e) = peer.send(packet, dest, &mut buf).await { + tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}"); + let _ = tunnel.callbacks.on_error(&e); - tracing::error!(err = ?e, "failed to handle packet {e:#}") + + if e.is_fatal_connection_error() { + let _ = tunnel + .stop_peer_command_sender + .clone() + .send((peer.index, peer.conn_id)) + .await; + } } } } diff --git a/rust/connlib/tunnel/src/control_protocol.rs b/rust/connlib/tunnel/src/control_protocol.rs index cc77b848e..9139dff4d 100644 --- a/rust/connlib/tunnel/src/control_protocol.rs +++ b/rust/connlib/tunnel/src/control_protocol.rs @@ -38,36 +38,6 @@ where CB: Callbacks + 'static, TRoleState: RoleState, { - pub fn on_dc_close_handler( - self: Arc, - index: u32, - conn_id: TRoleState::Id, - ) -> OnCloseHdlrFn { - Box::new(move || { - tracing::debug!("channel_closed"); - let tunnel = self.clone(); - Box::pin(async move { - tunnel.stop_peer(index, conn_id).await; - }) - }) - } - - pub fn on_peer_connection_state_change_handler( - self: Arc, - index: u32, - conn_id: TRoleState::Id, - ) -> OnPeerConnectionStateChangeHdlrFn { - Box::new(move |state| { - let tunnel = Arc::clone(&self); - Box::pin(async move { - tracing::trace!(?state, "peer_state_update"); - if state == RTCPeerConnectionState::Failed { - tunnel.stop_peer(index, conn_id).await; - } - }) - }) - } - pub async fn add_ice_candidate( &self, conn_id: TRoleState::Id, @@ -84,6 +54,44 @@ where } } +pub fn on_peer_connection_state_change_handler( + index: u32, + conn_id: TId, + stop_command_sender: mpsc::Sender<(u32, TId)>, +) -> OnPeerConnectionStateChangeHdlrFn +where + TId: Copy + Send + Sync + 'static, +{ + Box::new(move |state| { + let mut sender = stop_command_sender.clone(); + + tracing::trace!(?state, "peer_state_update"); + Box::pin(async move { + if state == RTCPeerConnectionState::Failed { + let _ = sender.send((index, conn_id)).await; + } + }) + }) +} + +pub fn on_dc_close_handler( + index: u32, + conn_id: TId, + stop_command_sender: mpsc::Sender<(u32, TId)>, +) -> OnCloseHdlrFn +where + TId: Copy + Send + Sync + 'static, +{ + Box::new(move || { + let mut sender = stop_command_sender.clone(); + + tracing::debug!("channel_closed"); + Box::pin(async move { + let _ = sender.send((index, conn_id)).await; + }) + }) +} + #[tracing::instrument(level = "trace", skip(webrtc))] pub async fn new_peer_connection( webrtc: &webrtc::api::API, diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 350435828..e90f84022 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -16,25 +16,11 @@ use webrtc::{ }, }; -use crate::control_protocol::new_peer_connection; +use crate::control_protocol::{ + new_peer_connection, on_dc_close_handler, on_peer_connection_state_change_handler, +}; use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel}; -#[tracing::instrument(level = "trace", skip(tunnel))] -fn handle_connection_state_update( - tunnel: &Arc>, - state: RTCPeerConnectionState, - gateway_id: GatewayId, - resource_id: ResourceId, -) where - CB: Callbacks + 'static, -{ - tracing::trace!("peer_state"); - if state == RTCPeerConnectionState::Failed { - tunnel.role_state.lock().on_connection_failed(resource_id); - tunnel.peer_connections.lock().remove(&gateway_id); - } -} - #[tracing::instrument(level = "trace", skip(tunnel))] fn set_connection_state_update( tunnel: &Arc>, @@ -49,7 +35,11 @@ fn set_connection_state_update( move |state: RTCPeerConnectionState| { let tunnel = Arc::clone(&tunnel); Box::pin(async move { - handle_connection_state_update(&tunnel, state, gateway_id, resource_id) + tracing::trace!("peer_state"); + if state == RTCPeerConnectionState::Failed { + tunnel.role_state.lock().on_connection_failed(resource_id); + tunnel.peer_connections.lock().remove(&gateway_id); + } }) }, )); @@ -139,7 +129,7 @@ where } }; - d.on_close(tunnel.clone().on_dc_close_handler(index, gateway_id)); + d.on_close(on_dc_close_handler(index, gateway_id, tunnel.stop_peer_command_sender.clone())); let peer = Arc::new(Peer::new( tunnel.private_key.clone(), @@ -171,9 +161,7 @@ where if let Some(conn) = tunnel.peer_connections.lock().get(&gateway_id) { conn.on_peer_connection_state_change( - tunnel - .clone() - .on_peer_connection_state_change_handler(index, gateway_id), + on_peer_connection_state_change_handler(index, gateway_id, tunnel.stop_peer_command_sender.clone()), ); } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index 8bea0167c..d2f083167 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -10,7 +10,9 @@ use webrtc::peer_connection::{ RTCPeerConnection, }; -use crate::control_protocol::new_peer_connection; +use crate::control_protocol::{ + new_peer_connection, on_dc_close_handler, on_peer_connection_state_change_handler, +}; use crate::{peer::Peer, GatewayState, PeerConfig, Tunnel}; #[tracing::instrument(level = "trace", skip(tunnel))] @@ -110,7 +112,7 @@ where } data_channel - .on_close(tunnel.clone().on_dc_close_handler(index, client_id)); + .on_close(on_dc_close_handler(index, client_id, tunnel.stop_peer_command_sender.clone())); let peer = Arc::new(Peer::new( tunnel.private_key.clone(), @@ -129,9 +131,9 @@ where if let Some(conn) = tunnel.peer_connections.lock().get(&client_id) { conn.on_peer_connection_state_change( - tunnel.clone().on_peer_connection_state_change_handler( + on_peer_connection_state_change_handler( index, - client_id, + client_id, tunnel.stop_peer_command_sender.clone(), ), ); } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 79f2d55c0..bbc22cbb2 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -8,6 +8,7 @@ use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; use connlib_shared::Callbacks; use futures::channel::mpsc::Receiver; use futures_bounded::{PushError, StreamMap}; +use futures_util::SinkExt; use std::sync::Arc; use std::task::{ready, Context, Poll}; use std::time::Duration; @@ -24,9 +25,8 @@ where config: &InterfaceConfig, ) -> connlib_shared::Result<()> { let device = create_iface(config, self.callbacks()).await?; - *self.device.write().await = Some(device.clone()); - self.start_timers().await?; + *self.device.write().await = Some(device.clone()); *self.iface_handler_abort.lock() = Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle()); @@ -63,11 +63,16 @@ where continue; }; - if let Err(e) = tunnel - .encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf) - .await - { - tracing::error!(err = ?e, "failed to handle packet {e:#}") + if let Err(e) = peer.send(packet, dest, &mut buf).await { + tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}"); + + if e.is_fatal_connection_error() { + let _ = tunnel + .stop_peer_command_sender + .clone() + .send((peer.index, peer.conn_id)) + .await; + } } } } diff --git a/rust/connlib/tunnel/src/iface_handler.rs b/rust/connlib/tunnel/src/iface_handler.rs deleted file mode 100644 index a5569aaf9..000000000 --- a/rust/connlib/tunnel/src/iface_handler.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::{net::IpAddr, sync::Arc}; - -use boringtun::noise::{errors::WireGuardError, TunnResult}; -use bytes::Bytes; -use connlib_shared::{Callbacks, Result}; - -use crate::{ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel}; - -impl Tunnel -where - CB: Callbacks + 'static, - TRoleState: RoleState, -{ - #[inline(always)] - pub(crate) async fn encapsulate_and_send_to_peer<'a>( - &self, - mut packet: MutableIpPacket<'_>, - peer: Arc>, - dst_addr: &IpAddr, - buf: &mut [u8], - ) -> Result<()> { - let encapsulated_packet = peer.encapsulate(&mut packet, buf)?; - - match encapsulated_packet.encapsulate_result { - TunnResult::Done => Ok(()), - TunnResult::Err(WireGuardError::ConnectionExpired) - | TunnResult::Err(WireGuardError::NoCurrentSession) => { - self.stop_peer(encapsulated_packet.index, encapsulated_packet.conn_id) - .await; - Ok(()) - } - - TunnResult::Err(e) => { - tracing::error!(resource_address = %dst_addr, error = ?e, "resource_connection"); - let err = e.into(); - let _ = self.callbacks.on_error(&err); - Err(err) - } - TunnResult::WriteToNetwork(packet) => { - tracing::trace!(target: "wire", action = "writing", from = "iface", to = %dst_addr); - if let Err(e) = encapsulated_packet - .channel - .write(&Bytes::copy_from_slice(packet)) - .await - { - tracing::error!(?e, "webrtc_write"); - if matches!( - e, - webrtc::data::Error::ErrStreamClosed - | webrtc::data::Error::Sctp(webrtc::sctp::Error::ErrStreamClosed) - ) { - self.stop_peer(encapsulated_packet.index, encapsulated_packet.conn_id) - .await; - } - let err = e.into(); - let _ = self.callbacks.on_error(&err); - Err(err) - } else { - Ok(()) - } - } - _ => panic!("Unexpected result from encapsulate"), - } - } -} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 5fb70a090..b22b32016 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -3,10 +3,9 @@ //! This is both the wireguard and ICE implementation that should work in tandem. //! [Tunnel] is the main entry-point for this crate. use boringtun::{ - noise::{errors::WireGuardError, rate_limiter::RateLimiter, TunnResult}, + noise::rate_limiter::RateLimiter, x25519::{PublicKey, StaticSecret}, }; -use bytes::Bytes; use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error}; use ip_network::IpNetwork; @@ -29,9 +28,12 @@ use webrtc::{ peer_connection::RTCPeerConnection, }; +use futures::channel::mpsc; +use futures_util::{SinkExt, StreamExt}; use std::hash::Hash; use std::task::{Context, Poll}; use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration}; +use tokio::time::Interval; use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; use connlib_shared::{ @@ -56,19 +58,14 @@ mod control_protocol; mod device_channel; mod dns; mod gateway; -mod iface_handler; mod index; mod ip_packet; mod peer; mod peer_handler; -mod resource_sender; mod resource_table; mod tokio_util; const MAX_UDP_SIZE: usize = (1 << 16) - 1; -const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1); -const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1); -const REFRESH_MTU_INTERVAL: Duration = Duration::from_secs(30); const DNS_QUERIES_QUEUE_SIZE: usize = 100; /// For how long we will attempt to gather ICE candidates before aborting. @@ -152,6 +149,13 @@ pub struct Tunnel { /// State that differs per role, i.e. clients vs gateways. role_state: Mutex, + + stop_peer_command_receiver: Mutex>, + stop_peer_command_sender: mpsc::Sender<(u32, TRoleState::Id)>, + + rate_limit_reset_interval: Mutex, + peer_refresh_interval: Mutex, + mtu_refresh_interval: Mutex, } // TODO: For now we only use these fields with debug @@ -189,7 +193,109 @@ where } pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll> { - self.role_state.lock().poll_next_event(cx) + loop { + if self + .rate_limit_reset_interval + .lock() + .poll_tick(cx) + .is_ready() + { + self.rate_limiter.reset_count(); + continue; + } + + if self.peer_refresh_interval.lock().poll_tick(cx).is_ready() { + let peers_to_refresh = { + let mut peers_by_ip = self.peers_by_ip.write(); + + peers_to_refresh(&mut peers_by_ip, self.stop_peer_command_sender.clone()) + }; + + for peer in peers_to_refresh { + let callbacks = self.callbacks.clone(); + let mut stop_command_sender = self.stop_peer_command_sender.clone(); + + tokio::spawn(async move { + if let Err(e) = peer.update_timers().await { + tracing::error!("Failed to update timers for peer: {e}"); + let _ = callbacks.on_error(&e); + + if e.is_fatal_connection_error() { + let _ = stop_command_sender.send((peer.index, peer.conn_id)).await; + } + } + }); + } + continue; + } + + if self.mtu_refresh_interval.lock().poll_tick(cx).is_ready() { + // We use `try_read` to acquire a lock on the device because we are within a synchronous context here and cannot use `.await`. + // The device is only updated during `add_route` and `set_interface` which would be extremely unlucky to hit at the same time as this timer. + // Even if we hit this, we just wait for the next tick to update the MTU. + let device = match self.device.try_read().map(|d| d.clone()) { + Ok(Some(device)) => device, + Ok(None) => { + let err = Error::ControlProtocolError; + tracing::error!(?err, "get_iface_config"); + let _ = self.callbacks.on_error(&err); + continue; + } + Err(_) => { + tracing::debug!("Unlucky! Somebody is updating the device just as we are about to update its MTU, trying again on the next tick ..."); + continue; + } + }; + + tokio::spawn({ + let callbacks = self.callbacks.clone(); + + async move { + if let Err(e) = device.config.refresh_mtu().await { + tracing::error!(error = ?e, "refresh_mtu"); + let _ = callbacks.on_error(&e); + } + } + }); + } + + if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) { + return Poll::Ready(event); + } + + if let Poll::Ready(Some((i, conn_id))) = + self.stop_peer_command_receiver.lock().poll_next_unpin(cx) + { + let mut peers = self.peers_by_ip.write(); + + let (maybe_network, maybe_peer) = peers + .iter() + .find_map(|(n, p)| (p.index == i).then_some((n, p.clone()))) + .unzip(); + + if let Some(network) = maybe_network { + peers.remove(network); + } + + if let Some(conn) = self.peer_connections.lock().remove(&conn_id) { + tokio::spawn({ + let callbacks = self.callbacks.clone(); + async move { + if let Some(peer) = maybe_peer { + let _ = peer.shutdown().await; + } + if let Err(e) = conn.close().await { + tracing::warn!(%conn_id, error = ?e, "Can't close peer"); + let _ = callbacks.on_error(&e.into()); + } + } + }); + } + continue; + } + + return Poll::Pending; + } } } @@ -284,6 +390,8 @@ where .with_setting_engine(setting_engine) .build(); + let (stop_peer_command_sender, stop_peer_command_receiver) = mpsc::channel(10); + Ok(Self { rate_limiter, private_key, @@ -296,117 +404,14 @@ where callbacks: CallbackErrorFacade(callbacks), iface_handler_abort, role_state: Default::default(), + stop_peer_command_receiver: Mutex::new(stop_peer_command_receiver), + stop_peer_command_sender, + rate_limit_reset_interval: Mutex::new(rate_limit_reset_interval()), + peer_refresh_interval: Mutex::new(peer_refresh_interval()), + mtu_refresh_interval: Mutex::new(mtu_refresh_interval()), }) } - #[tracing::instrument(level = "trace", skip(self))] - async fn stop_peer(&self, index: u32, conn_id: TRoleState::Id) { - self.peers_by_ip.write().retain(|_, p| p.index != index); - let conn = self.peer_connections.lock().remove(&conn_id); - if let Some(conn) = conn { - if let Err(e) = conn.close().await { - tracing::warn!(error = ?e, "Can't close peer"); - let _ = self.callbacks().on_error(&e.into()); - } - } - } - - async fn peer_refresh(&self, peer: &Peer, dst_buf: &mut [u8; MAX_UDP_SIZE]) { - let update_timers_result = peer.update_timers(&mut dst_buf[..]); - - match update_timers_result { - TunnResult::Done => {} - TunnResult::Err(WireGuardError::ConnectionExpired) - | TunnResult::Err(WireGuardError::NoCurrentSession) => { - self.stop_peer(peer.index, peer.conn_id).await; - let _ = peer.shutdown().await; - } - TunnResult::Err(e) => tracing::error!(error = ?e, "timer_error"), - TunnResult::WriteToNetwork(packet) => { - let bytes = Bytes::copy_from_slice(packet); - peer.send_infallible(bytes, &self.callbacks).await - } - - _ => panic!("Unexpected result from update_timers"), - }; - } - - fn start_rate_limiter_refresh_timer(self: &Arc) { - let rate_limiter = self.rate_limiter.clone(); - tokio::spawn(async move { - let mut interval = tokio::time::interval(RESET_PACKET_COUNT_INTERVAL); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - loop { - rate_limiter.reset_count(); - interval.tick().await; - } - }); - } - - fn start_peers_refresh_timer(self: &Arc) { - let tunnel = self.clone(); - - tokio::spawn(async move { - let mut interval = tokio::time::interval(REFRESH_PEERS_TIMERS_INTERVAL); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - let mut dst_buf = [0u8; MAX_UDP_SIZE]; - - loop { - remove_expired_peers( - &mut tunnel.peers_by_ip.write(), - &mut tunnel.peer_connections.lock(), - ); - - let peers: Vec<_> = tunnel - .peers_by_ip - .read() - .iter() - .map(|p| p.1) - .unique_by(|p| p.index) - .cloned() - .collect(); - - for peer in peers { - tunnel.peer_refresh(&peer, &mut dst_buf).await; - } - - interval.tick().await; - } - }); - } - - async fn start_refresh_mtu_timer(self: &Arc) -> Result<()> { - let dev = self.clone(); - let callbacks = self.callbacks().clone(); - tokio::spawn(async move { - let mut interval = tokio::time::interval(REFRESH_MTU_INTERVAL); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - loop { - interval.tick().await; - - let Some(device) = dev.device.read().await.clone() else { - let err = Error::ControlProtocolError; - tracing::error!(?err, "get_iface_config"); - let _ = callbacks.0.on_error(&err); - continue; - }; - if let Err(e) = device.config.refresh_mtu().await { - tracing::error!(error = ?e, "refresh_mtu"); - let _ = callbacks.0.on_error(&e); - } - } - }); - - Ok(()) - } - - async fn start_timers(self: &Arc) -> Result<()> { - self.start_refresh_mtu_timer().await?; - self.start_rate_limiter_refresh_timer(); - self.start_peers_refresh_timer(); - Ok(()) - } - fn next_index(&self) -> u32 { self.next_index.lock().next() } @@ -416,27 +421,68 @@ where } } +/// Constructs the interval for resetting the rate limit count. +/// +/// As per documentation on [`RateLimiter::reset_count`], this is configured to run every second. +fn rate_limit_reset_interval() -> Interval { + let mut interval = tokio::time::interval(Duration::from_secs(1)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + interval +} + +/// Constructs the interval for "refreshing" peers. +/// +/// On each tick, we remove expired peers from our map, update wireguard timers and send packets, if any. +fn peer_refresh_interval() -> Interval { + let mut interval = tokio::time::interval(Duration::from_secs(1)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + interval +} + +/// Constructs the interval for refreshing the MTU of our TUN device. +fn mtu_refresh_interval() -> Interval { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + interval +} + +fn peers_to_refresh( + peers_by_ip: &mut IpNetworkTable>>, + shutdown_sender: mpsc::Sender<(u32, TId)>, +) -> Vec>> +where + TId: Eq + Hash + Copy + Send + Sync + 'static, +{ + remove_expired_peers(peers_by_ip, shutdown_sender); + + peers_by_ip + .iter() + .map(|p| p.1) + .unique_by(|p| p.index) + .cloned() + .collect() +} + fn remove_expired_peers( peers_by_ip: &mut IpNetworkTable>>, - peer_connections: &mut HashMap>, + shutdown_sender: mpsc::Sender<(u32, TId)>, ) where TId: Eq + Hash + Copy + Send + Sync + 'static, { for (_, peer) in peers_by_ip.iter() { peer.expire_resources(); if peer.is_emptied() { - tracing::trace!(index = peer.index, "peer_expired"); - let conn = peer_connections.remove(&peer.conn_id); - let p = peer.clone(); + let index = peer.index; + let conn_id = peer.conn_id; - // We are holding a Mutex, particularly a write one, we don't want to make a blocking call - tokio::spawn(async move { - let _ = p.shutdown().await; - if let Some(conn) = conn { - // TODO: it seems that even closing the stream there are messages to the relay - // see where they come from. - let _ = conn.close().await; - } + tracing::trace!(%index, "peer_expired"); + + tokio::spawn({ + let mut sender = shutdown_sender.clone(); + async move { sender.send((index, conn_id)).await } }); } } @@ -449,7 +495,7 @@ fn remove_expired_peers( /// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`]. /// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`]. pub trait RoleState: Default + Send + 'static { - type Id: fmt::Debug + Eq + Hash + Copy + Send + Sync + 'static; + type Id: fmt::Debug + fmt::Display + Eq + Hash + Copy + Unpin + Send + Sync + 'static; fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>; } diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index d58f5de9d..906825040 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -22,10 +22,10 @@ type ExpiryingResource = (ResourceDescription, DateTime); pub(crate) struct Peer { pub tunnel: Mutex, pub index: u32, - pub allowed_ips: RwLock>, + allowed_ips: RwLock>, pub channel: Arc, pub conn_id: TId, - pub resources: Option>>, + resources: Option>>, // Here we store the address that we obtained for the resource that the peer corresponds to. // This can have the following problem: // 1. Peer sends packet to address.com and it resolves to 1.1.1.1 @@ -35,7 +35,7 @@ pub(crate) struct Peer { // so, TODO: store multiple ips and expire them. // Note that this case is quite an unlikely edge case so I wouldn't prioritize this fix // TODO: Also check if there's any case where we want to talk to ipv4 and ipv6 from the same peer. - pub translated_resource_addresses: RwLock>, + translated_resource_addresses: RwLock>, } // TODO: For now we only use these fields with debug @@ -50,14 +50,6 @@ pub(crate) struct PeerStats { pub translated_resource_addresses: HashMap, } -#[derive(Debug)] -pub(crate) struct EncapsulatedPacket<'a, Id> { - pub index: u32, - pub conn_id: Id, - pub channel: Arc, - pub encapsulate_result: TunnResult<'a>, -} - impl Peer where TId: Copy, @@ -130,7 +122,7 @@ where } } - pub(crate) fn get_translation(&self, ip: IpAddr) -> Option { + fn get_translation(&self, ip: IpAddr) -> Option { let id = self.translated_resource_addresses.read().get(&ip).cloned(); self.resources.as_ref().and_then(|resources| { id.and_then(|id| resources.read().get_by_id(&id).map(|r| r.0.clone())) @@ -141,8 +133,25 @@ where self.allowed_ips.write().insert(ip, ()); } - pub(crate) fn update_timers<'a>(&self, dst: &'a mut [u8]) -> TunnResult<'a> { - self.tunnel.lock().update_timers(dst) + pub(crate) async fn update_timers<'a>(&self) -> Result<()> { + /// [`boringtun`] requires us to pass buffers in where it can construct its packets. + /// + /// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`]. + const MAX_SCRATCH_SPACE: usize = 148; + + let mut buf = [0u8; MAX_SCRATCH_SPACE]; + + let packet = match self.tunnel.lock().update_timers(&mut buf) { + TunnResult::Done => return Ok(()), + TunnResult::Err(e) => return Err(e.into()), + TunnResult::WriteToNetwork(b) => b, + _ => panic!("Unexpected result from update_timers"), + }; + + let bytes = Bytes::copy_from_slice(packet); + self.channel.write(&bytes).await?; + + Ok(()) } pub(crate) async fn shutdown(&self) -> Result<()> { @@ -189,11 +198,13 @@ where self.translated_resource_addresses.write().insert(addr, id); } - pub(crate) fn encapsulate<'a>( + /// Sends the given packet to this peer by encapsulating it in a wireguard packet. + pub(crate) async fn send<'a>( &self, - packet: &mut MutableIpPacket<'a>, - dst: &'a mut [u8], - ) -> Result> { + mut packet: MutableIpPacket<'a>, + dest: IpAddr, + buf: &mut [u8], + ) -> Result<()> { if let Some(resource) = self.get_translation(packet.to_immutable().source()) { let ResourceDescription::Dns(resource) = resource else { tracing::error!( @@ -209,12 +220,18 @@ where packet.update_checksum(); } - Ok(EncapsulatedPacket { - index: self.index, - conn_id: self.conn_id, - channel: self.channel.clone(), - encapsulate_result: self.tunnel.lock().encapsulate(packet.packet_mut(), dst), - }) + let packet = match self.tunnel.lock().encapsulate(packet.packet_mut(), buf) { + TunnResult::Done => return Ok(()), + TunnResult::Err(e) => return Err(e.into()), + TunnResult::WriteToNetwork(b) => b, + _ => panic!("Unexpected result from `encapsulate`"), + }; + + tracing::trace!(target: "wire", action = "writing", from = "iface", to = %dest); + + self.channel.write(&Bytes::copy_from_slice(packet)).await?; + + Ok(()) } pub(crate) fn get_packet_resource( diff --git a/rust/connlib/tunnel/src/peer_handler.rs b/rust/connlib/tunnel/src/peer_handler.rs index c2edaa086..7a790efe5 100644 --- a/rust/connlib/tunnel/src/peer_handler.rs +++ b/rust/connlib/tunnel/src/peer_handler.rs @@ -1,9 +1,14 @@ +use std::net::{IpAddr, ToSocketAddrs}; use std::sync::Arc; use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult}; +use boringtun::x25519::{PublicKey, StaticSecret}; use bytes::Bytes; +use connlib_shared::messages::ResourceDescription; use connlib_shared::{Callbacks, Error, Result}; +use futures_util::SinkExt; +use crate::ip_packet::MutableIpPacket; use crate::{ device_channel::DeviceIo, index::check_packet_index, peer::Peer, RoleState, Tunnel, MAX_UDP_SIZE, @@ -14,16 +19,95 @@ where CB: Callbacks + 'static, TRoleState: RoleState, { - #[inline(always)] - fn is_wireguard_packet_ok(&self, parsed_packet: &Packet, peer: &Peer) -> bool { - match &parsed_packet { - Packet::HandshakeInit(p) => { - parse_handshake_anon(&self.private_key, &self.public_key, p).is_ok() + pub(crate) async fn start_peer_handler(self: Arc, peer: Arc>) { + loop { + let Some(device) = self.device.read().await.clone() else { + let err = Error::NoIface; + tracing::error!(?err); + let _ = self.callbacks().on_disconnect(Some(&err)); + break; + }; + let device_io = device.io; + + if let Err(err) = self.peer_handler(&peer, device_io).await { + if err.raw_os_error() != Some(9) { + tracing::error!(?err); + let _ = self.callbacks().on_error(&err.into()); + break; + } else { + tracing::warn!("bad_file_descriptor"); + } } - Packet::HandshakeResponse(p) => check_packet_index(p.receiver_idx, peer.index), - Packet::PacketCookieReply(p) => check_packet_index(p.receiver_idx, peer.index), - Packet::PacketData(p) => check_packet_index(p.receiver_idx, peer.index), } + tracing::debug!(peer = ?peer.stats(), "peer_stopped"); + let _ = self + .stop_peer_command_sender + .clone() + .send((peer.index, peer.conn_id)) + .await; + } + + async fn peer_handler( + self: &Arc, + peer: &Arc>, + device_io: DeviceIo, + ) -> std::io::Result<()> { + let mut src_buf = [0u8; MAX_UDP_SIZE]; + let mut dst_buf = [0u8; MAX_UDP_SIZE]; + while let Ok(size) = peer.channel.read(&mut src_buf[..]).await { + tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer"); + + // TODO: Double check that this can only happen on closed channel + // I think it's possible to transmit a 0-byte message through the channel + // but we would never use that. + // We should keep track of an open/closed channel ourselves if we wanted to do it properly then. + if size == 0 { + break; + } + + if let Err(Error::Io(e)) = self + .handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf) + .await + { + return Err(e); + } + } + + Ok(()) + } + + #[inline(always)] + pub(crate) async fn handle_peer_packet( + self: &Arc, + peer: &Arc>, + device_writer: &DeviceIo, + src: &[u8], + dst: &mut [u8], + ) -> Result<()> { + let parsed_packet = self.verify_packet(peer, src, dst).await?; + if !is_wireguard_packet_ok(&self.private_key, &self.public_key, &parsed_packet, peer) { + tracing::error!("wireguard_verification"); + return Err(Error::BadPacket); + } + + let decapsulate_result = peer.tunnel.lock().decapsulate(None, src, dst); + + if self + .handle_decapsulated_packet(peer, device_writer, decapsulate_result) + .await? + { + // Flush pending queue + while let TunnResult::WriteToNetwork(packet) = + peer.tunnel.lock().decapsulate(None, &[], dst) + { + let bytes = Bytes::copy_from_slice(packet); + let callbacks = self.callbacks.clone(); + let peer = peer.clone(); + tokio::spawn(async move { peer.send_infallible(bytes, &callbacks).await }); + } + } + + Ok(()) } #[inline(always)] @@ -80,101 +164,143 @@ where Ok(true) } TunnResult::WriteToTunnelV4(packet, addr) => { - self.send_to_resource(device_io, peer, addr.into(), packet)?; + send_to_resource(device_io, peer, addr.into(), packet)?; Ok(false) } TunnResult::WriteToTunnelV6(packet, addr) => { - self.send_to_resource(device_io, peer, addr.into(), packet)?; + send_to_resource(device_io, peer, addr.into(), packet)?; Ok(false) } } } +} - #[inline(always)] - pub(crate) async fn handle_peer_packet( - self: &Arc, - peer: &Arc>, - device_writer: &DeviceIo, - src: &[u8], - dst: &mut [u8], - ) -> Result<()> { - let parsed_packet = self.verify_packet(peer, src, dst).await?; - if !self.is_wireguard_packet_ok(&parsed_packet, peer) { - tracing::error!("wireguard_verification"); - return Err(Error::BadPacket); - } - - let decapsulate_result = peer.tunnel.lock().decapsulate(None, src, dst); - - if self - .handle_decapsulated_packet(peer, device_writer, decapsulate_result) - .await? - { - // Flush pending queue - while let TunnResult::WriteToNetwork(packet) = { - let res = peer.tunnel.lock().decapsulate(None, &[], dst); - res - } { - let bytes = Bytes::copy_from_slice(packet); - let callbacks = self.callbacks.clone(); - let peer = peer.clone(); - tokio::spawn(async move { peer.send_infallible(bytes, &callbacks).await }); - } - } - - Ok(()) - } - - async fn peer_handler( - self: &Arc, - peer: &Arc>, - device_io: DeviceIo, - ) -> std::io::Result<()> { - let mut src_buf = [0u8; MAX_UDP_SIZE]; - let mut dst_buf = [0u8; MAX_UDP_SIZE]; - while let Ok(size) = peer.channel.read(&mut src_buf[..]).await { - tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer"); - - // TODO: Double check that this can only happen on closed channel - // I think it's possible to transmit a 0-byte message through the channel - // but we would never use that. - // We should keep track of an open/closed channel ourselves if we wanted to do it properly then. - if size == 0 { - break; - } - - if let Err(Error::Io(e)) = self - .handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf) - .await - { - return Err(e); - } - } - - Ok(()) - } - - pub(crate) async fn start_peer_handler(self: Arc, peer: Arc>) { - loop { - let Some(device) = self.device.read().await.clone() else { - let err = Error::NoIface; - tracing::error!(?err); - let _ = self.callbacks().on_disconnect(Some(&err)); - break; - }; - let device_io = device.io; - - if let Err(err) = self.peer_handler(&peer, device_io).await { - if err.raw_os_error() != Some(9) { - tracing::error!(?err); - let _ = self.callbacks().on_error(&err.into()); - break; - } else { - tracing::warn!("bad_file_descriptor"); - } - } - } - tracing::debug!(peer = ?peer.stats(), "peer_stopped"); - self.stop_peer(peer.index, peer.conn_id).await; +#[inline(always)] +fn is_wireguard_packet_ok( + private_key: &StaticSecret, + public_key: &PublicKey, + parsed_packet: &Packet, + peer: &Peer, +) -> bool { + match parsed_packet { + Packet::HandshakeInit(p) => parse_handshake_anon(private_key, public_key, p).is_ok(), + Packet::HandshakeResponse(p) => check_packet_index(p.receiver_idx, peer.index), + Packet::PacketCookieReply(p) => check_packet_index(p.receiver_idx, peer.index), + Packet::PacketData(p) => check_packet_index(p.receiver_idx, peer.index), + } +} + +fn send_to_resource( + device_io: &DeviceIo, + peer: &Arc>, + addr: IpAddr, + packet: &mut [u8], +) -> Result<()> +where + TId: Copy, +{ + if peer.is_allowed(addr) { + packet_allowed(device_io, peer, addr, packet)?; + Ok(()) + } else { + tracing::warn!(%addr, "Received packet from peer with an unallowed ip"); + Ok(()) + } +} + +#[inline(always)] +pub(crate) fn packet_allowed( + device_io: &DeviceIo, + peer: &Arc>, + addr: IpAddr, + packet: &mut [u8], +) -> Result<()> +where + TId: Copy, +{ + let Some((dst, resource)) = peer.get_packet_resource(packet) else { + // If there's no associated resource it means that we are in a client, then the packet comes from a gateway + // and we just trust gateways. + // In gateways this should never happen. + tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len()); + send_packet(device_io, packet, addr)?; + return Ok(()); + }; + + let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?; + update_packet(packet, dst_addr); + send_packet(device_io, packet, addr)?; + Ok(()) +} + +#[inline(always)] +fn send_packet(device_io: &DeviceIo, packet: &mut [u8], dst_addr: IpAddr) -> std::io::Result<()> { + match dst_addr { + IpAddr::V4(_) => device_io.write4(packet)?, + IpAddr::V6(_) => device_io.write6(packet)?, + }; + Ok(()) +} + +#[inline(always)] +fn update_packet(packet: &mut [u8], dst_addr: IpAddr) { + let Some(mut pkt) = MutableIpPacket::new(packet) else { + return; + }; + pkt.set_dst(dst_addr); + pkt.update_checksum(); +} + +fn get_matching_version_ip(addr: &IpAddr, ip: &IpAddr) -> Option { + ((addr.is_ipv4() && ip.is_ipv4()) || (addr.is_ipv6() && ip.is_ipv6())).then_some(*ip) +} + +fn get_resource_addr_and_port( + peer: &Arc>, + resource: &ResourceDescription, + addr: &IpAddr, + dst: &IpAddr, +) -> Result<(IpAddr, Option)> +where + TId: Copy, +{ + match resource { + ResourceDescription::Dns(r) => { + let mut address = r.address.split(':'); + let Some(dst_addr) = address.next() else { + tracing::error!("invalid DNS name for resource: {}", r.address); + return Err(Error::InvalidResource); + }; + let Ok(mut dst_addr) = (dst_addr, 0).to_socket_addrs() else { + tracing::warn!(%addr, "Couldn't resolve name"); + return Err(Error::InvalidResource); + }; + let Some(dst_addr) = dst_addr.find_map(|d| get_matching_version_ip(addr, &d.ip())) + else { + tracing::warn!(%addr, "Couldn't resolve name addr"); + return Err(Error::InvalidResource); + }; + peer.update_translated_resource_address(r.id, dst_addr); + Ok(( + dst_addr, + address + .next() + .map(str::parse::) + .and_then(std::result::Result::ok), + )) + } + ResourceDescription::Cidr(r) => { + if r.address.contains(*dst) { + Ok(( + get_matching_version_ip(addr, dst).ok_or(Error::InvalidResource)?, + None, + )) + } else { + tracing::warn!( + "client tried to hijack the tunnel for range outside what it's allowed." + ); + Err(Error::InvalidSource) + } + } } } diff --git a/rust/connlib/tunnel/src/resource_sender.rs b/rust/connlib/tunnel/src/resource_sender.rs deleted file mode 100644 index 666faa926..000000000 --- a/rust/connlib/tunnel/src/resource_sender.rs +++ /dev/null @@ -1,130 +0,0 @@ -use std::{ - net::{IpAddr, ToSocketAddrs}, - sync::Arc, -}; - -use crate::{device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel}; - -use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result}; - -impl Tunnel -where - CB: Callbacks + 'static, - TRoleState: RoleState, -{ - #[inline(always)] - fn update_packet(&self, packet: &mut [u8], dst_addr: IpAddr) { - let Some(mut pkt) = MutableIpPacket::new(packet) else { - return; - }; - pkt.set_dst(dst_addr); - pkt.update_checksum(); - } - - #[inline(always)] - fn send_packet( - &self, - device_io: &DeviceIo, - packet: &mut [u8], - dst_addr: IpAddr, - ) -> std::io::Result<()> { - match dst_addr { - IpAddr::V4(_) => device_io.write4(packet)?, - IpAddr::V6(_) => device_io.write6(packet)?, - }; - Ok(()) - } - - #[inline(always)] - pub(crate) fn packet_allowed( - &self, - device_io: &DeviceIo, - peer: &Arc>, - addr: IpAddr, - packet: &mut [u8], - ) -> Result<()> { - let Some((dst, resource)) = peer.get_packet_resource(packet) else { - // If there's no associated resource it means that we are in a client, then the packet comes from a gateway - // and we just trust gateways. - // In gateways this should never happen. - tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len()); - self.send_packet(device_io, packet, addr)?; - return Ok(()); - }; - - let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?; - self.update_packet(packet, dst_addr); - self.send_packet(device_io, packet, addr)?; - Ok(()) - } - - pub(crate) fn send_to_resource( - &self, - device_io: &DeviceIo, - peer: &Arc>, - addr: IpAddr, - packet: &mut [u8], - ) -> Result<()> { - if peer.is_allowed(addr) { - self.packet_allowed(device_io, peer, addr, packet)?; - Ok(()) - } else { - tracing::warn!(%addr, "Received packet from peer with an unallowed ip"); - Ok(()) - } - } -} - -fn get_matching_version_ip(addr: &IpAddr, ip: &IpAddr) -> Option { - ((addr.is_ipv4() && ip.is_ipv4()) || (addr.is_ipv6() && ip.is_ipv6())).then_some(*ip) -} - -fn get_resource_addr_and_port( - peer: &Arc>, - resource: &ResourceDescription, - addr: &IpAddr, - dst: &IpAddr, -) -> Result<(IpAddr, Option)> -where - TId: Copy, -{ - match resource { - ResourceDescription::Dns(r) => { - let mut address = r.address.split(':'); - let Some(dst_addr) = address.next() else { - tracing::error!("invalid DNS name for resource: {}", r.address); - return Err(Error::InvalidResource); - }; - let Ok(mut dst_addr) = (dst_addr, 0).to_socket_addrs() else { - tracing::warn!(%addr, "Couldn't resolve name"); - return Err(Error::InvalidResource); - }; - let Some(dst_addr) = dst_addr.find_map(|d| get_matching_version_ip(addr, &d.ip())) - else { - tracing::warn!(%addr, "Couldn't resolve name addr"); - return Err(Error::InvalidResource); - }; - peer.update_translated_resource_address(r.id, dst_addr); - Ok(( - dst_addr, - address - .next() - .map(str::parse::) - .and_then(std::result::Result::ok), - )) - } - ResourceDescription::Cidr(r) => { - if r.address.contains(*dst) { - Ok(( - get_matching_version_ip(addr, dst).ok_or(Error::InvalidResource)?, - None, - )) - } else { - tracing::warn!( - "client tried to hijack the tunnel for range outside what it's allowed." - ); - Err(Error::InvalidSource) - } - } - } -}