refactor(connlib): move more logic to poll_next_event (#2403)

This commit is contained in:
Thomas Eizinger
2023-10-19 13:30:04 +11:00
committed by GitHub
parent 573124bd2f
commit 919b7890e6
11 changed files with 537 additions and 514 deletions

View File

@@ -145,6 +145,26 @@ impl ConnlibError {
if e.status().is_client_error()
)
}
/// Whether this error is fatal to the underlying connection.
///
/// A fatal error means the peer connection can no longer be used and should
/// be torn down. Two classes of errors are treated as fatal here:
/// - Wireguard session errors: `ConnectionExpired` and `NoCurrentSession`.
/// - A closed ICE data stream, either reported directly or wrapped in an
///   SCTP error.
pub fn is_fatal_connection_error(&self) -> bool {
if let Self::WireguardError(e) = self {
return matches!(
e,
WireGuardError::ConnectionExpired | WireGuardError::NoCurrentSession
);
}
if let Self::IceDataError(e) = self {
return matches!(
e,
webrtc::data::Error::ErrStreamClosed
| webrtc::data::Error::Sctp(webrtc::sctp::Error::ErrStreamClosed)
);
}
// Every other variant is considered recoverable for the connection itself.
false
}
}
#[cfg(target_os = "linux")]

View File

@@ -18,6 +18,7 @@ use connlib_shared::{Callbacks, DNS_SENTINEL};
use futures::channel::mpsc::Receiver;
use futures::stream;
use futures_bounded::{PushError, StreamMap};
use futures_util::SinkExt;
use hickory_resolver::lookup::Lookup;
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
@@ -116,9 +117,8 @@ where
config: &InterfaceConfig,
) -> connlib_shared::Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
self.start_timers().await?;
*self.device.write().await = Some(device.clone());
*self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
&self.callbacks,
device_handler(Arc::clone(self), device),
@@ -203,12 +203,18 @@ where
continue;
};
if let Err(e) = tunnel
.encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
.await
{
if let Err(e) = peer.send(packet, dest, &mut buf).await {
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
let _ = tunnel.callbacks.on_error(&e);
tracing::error!(err = ?e, "failed to handle packet {e:#}")
if e.is_fatal_connection_error() {
let _ = tunnel
.stop_peer_command_sender
.clone()
.send((peer.index, peer.conn_id))
.await;
}
}
}
}

View File

@@ -38,36 +38,6 @@ where
CB: Callbacks + 'static,
TRoleState: RoleState,
{
pub fn on_dc_close_handler(
self: Arc<Self>,
index: u32,
conn_id: TRoleState::Id,
) -> OnCloseHdlrFn {
Box::new(move || {
tracing::debug!("channel_closed");
let tunnel = self.clone();
Box::pin(async move {
tunnel.stop_peer(index, conn_id).await;
})
})
}
pub fn on_peer_connection_state_change_handler(
self: Arc<Self>,
index: u32,
conn_id: TRoleState::Id,
) -> OnPeerConnectionStateChangeHdlrFn {
Box::new(move |state| {
let tunnel = Arc::clone(&self);
Box::pin(async move {
tracing::trace!(?state, "peer_state_update");
if state == RTCPeerConnectionState::Failed {
tunnel.stop_peer(index, conn_id).await;
}
})
})
}
pub async fn add_ice_candidate(
&self,
conn_id: TRoleState::Id,
@@ -84,6 +54,44 @@ where
}
}
/// Builds a `webrtc` peer-connection state-change callback.
///
/// When the connection transitions to [`RTCPeerConnectionState::Failed`], the
/// handler sends `(index, conn_id)` on `stop_command_sender` so the tunnel's
/// event loop can remove and close the peer. Send errors are deliberately
/// ignored (`let _`): if the receiver is gone, the tunnel is shutting down.
pub fn on_peer_connection_state_change_handler<TId>(
index: u32,
conn_id: TId,
stop_command_sender: mpsc::Sender<(u32, TId)>,
) -> OnPeerConnectionStateChangeHdlrFn
where
TId: Copy + Send + Sync + 'static,
{
Box::new(move |state| {
// Clone per invocation: the async block below takes ownership of the sender.
let mut sender = stop_command_sender.clone();
tracing::trace!(?state, "peer_state_update");
Box::pin(async move {
if state == RTCPeerConnectionState::Failed {
let _ = sender.send((index, conn_id)).await;
}
})
})
}
/// Builds a data-channel `on_close` callback.
///
/// On channel close, sends `(index, conn_id)` on `stop_command_sender` so the
/// tunnel's event loop tears down the corresponding peer. Send errors are
/// ignored: a dropped receiver means the tunnel is already shutting down.
pub fn on_dc_close_handler<TId>(
index: u32,
conn_id: TId,
stop_command_sender: mpsc::Sender<(u32, TId)>,
) -> OnCloseHdlrFn
where
TId: Copy + Send + Sync + 'static,
{
Box::new(move || {
// Clone per invocation: the async block below takes ownership of the sender.
let mut sender = stop_command_sender.clone();
tracing::debug!("channel_closed");
Box::pin(async move {
let _ = sender.send((index, conn_id)).await;
})
})
}
#[tracing::instrument(level = "trace", skip(webrtc))]
pub async fn new_peer_connection(
webrtc: &webrtc::api::API,

View File

@@ -16,25 +16,11 @@ use webrtc::{
},
};
use crate::control_protocol::new_peer_connection;
use crate::control_protocol::{
new_peer_connection, on_dc_close_handler, on_peer_connection_state_change_handler,
};
use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<CB>(
tunnel: &Arc<Tunnel<CB, ClientState>>,
state: RTCPeerConnectionState,
gateway_id: GatewayId,
resource_id: ResourceId,
) where
CB: Callbacks + 'static,
{
tracing::trace!("peer_state");
if state == RTCPeerConnectionState::Failed {
tunnel.role_state.lock().on_connection_failed(resource_id);
tunnel.peer_connections.lock().remove(&gateway_id);
}
}
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_update<CB>(
tunnel: &Arc<Tunnel<CB, ClientState>>,
@@ -49,7 +35,11 @@ fn set_connection_state_update<CB>(
move |state: RTCPeerConnectionState| {
let tunnel = Arc::clone(&tunnel);
Box::pin(async move {
handle_connection_state_update(&tunnel, state, gateway_id, resource_id)
tracing::trace!("peer_state");
if state == RTCPeerConnectionState::Failed {
tunnel.role_state.lock().on_connection_failed(resource_id);
tunnel.peer_connections.lock().remove(&gateway_id);
}
})
},
));
@@ -139,7 +129,7 @@ where
}
};
d.on_close(tunnel.clone().on_dc_close_handler(index, gateway_id));
d.on_close(on_dc_close_handler(index, gateway_id, tunnel.stop_peer_command_sender.clone()));
let peer = Arc::new(Peer::new(
tunnel.private_key.clone(),
@@ -171,9 +161,7 @@ where
if let Some(conn) = tunnel.peer_connections.lock().get(&gateway_id) {
conn.on_peer_connection_state_change(
tunnel
.clone()
.on_peer_connection_state_change_handler(index, gateway_id),
on_peer_connection_state_change_handler(index, gateway_id, tunnel.stop_peer_command_sender.clone()),
);
}

View File

@@ -10,7 +10,9 @@ use webrtc::peer_connection::{
RTCPeerConnection,
};
use crate::control_protocol::new_peer_connection;
use crate::control_protocol::{
new_peer_connection, on_dc_close_handler, on_peer_connection_state_change_handler,
};
use crate::{peer::Peer, GatewayState, PeerConfig, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
@@ -110,7 +112,7 @@ where
}
data_channel
.on_close(tunnel.clone().on_dc_close_handler(index, client_id));
.on_close(on_dc_close_handler(index, client_id, tunnel.stop_peer_command_sender.clone()));
let peer = Arc::new(Peer::new(
tunnel.private_key.clone(),
@@ -129,9 +131,9 @@ where
if let Some(conn) = tunnel.peer_connections.lock().get(&client_id) {
conn.on_peer_connection_state_change(
tunnel.clone().on_peer_connection_state_change_handler(
on_peer_connection_state_change_handler(
index,
client_id,
client_id, tunnel.stop_peer_command_sender.clone(),
),
);
}

View File

@@ -8,6 +8,7 @@ use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::Callbacks;
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use futures_util::SinkExt;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
@@ -24,9 +25,8 @@ where
config: &InterfaceConfig,
) -> connlib_shared::Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
self.start_timers().await?;
*self.device.write().await = Some(device.clone());
*self.iface_handler_abort.lock() =
Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle());
@@ -63,11 +63,16 @@ where
continue;
};
if let Err(e) = tunnel
.encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
.await
{
tracing::error!(err = ?e, "failed to handle packet {e:#}")
if let Err(e) = peer.send(packet, dest, &mut buf).await {
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
if e.is_fatal_connection_error() {
let _ = tunnel
.stop_peer_command_sender
.clone()
.send((peer.index, peer.conn_id))
.await;
}
}
}
}

View File

@@ -1,65 +0,0 @@
use std::{net::IpAddr, sync::Arc};
use boringtun::noise::{errors::WireGuardError, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Result};
use crate::{ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel};
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
CB: Callbacks + 'static,
TRoleState: RoleState,
{
#[inline(always)]
pub(crate) async fn encapsulate_and_send_to_peer<'a>(
&self,
mut packet: MutableIpPacket<'_>,
peer: Arc<Peer<TRoleState::Id>>,
dst_addr: &IpAddr,
buf: &mut [u8],
) -> Result<()> {
let encapsulated_packet = peer.encapsulate(&mut packet, buf)?;
match encapsulated_packet.encapsulate_result {
TunnResult::Done => Ok(()),
TunnResult::Err(WireGuardError::ConnectionExpired)
| TunnResult::Err(WireGuardError::NoCurrentSession) => {
self.stop_peer(encapsulated_packet.index, encapsulated_packet.conn_id)
.await;
Ok(())
}
TunnResult::Err(e) => {
tracing::error!(resource_address = %dst_addr, error = ?e, "resource_connection");
let err = e.into();
let _ = self.callbacks.on_error(&err);
Err(err)
}
TunnResult::WriteToNetwork(packet) => {
tracing::trace!(target: "wire", action = "writing", from = "iface", to = %dst_addr);
if let Err(e) = encapsulated_packet
.channel
.write(&Bytes::copy_from_slice(packet))
.await
{
tracing::error!(?e, "webrtc_write");
if matches!(
e,
webrtc::data::Error::ErrStreamClosed
| webrtc::data::Error::Sctp(webrtc::sctp::Error::ErrStreamClosed)
) {
self.stop_peer(encapsulated_packet.index, encapsulated_packet.conn_id)
.await;
}
let err = e.into();
let _ = self.callbacks.on_error(&err);
Err(err)
} else {
Ok(())
}
}
_ => panic!("Unexpected result from encapsulate"),
}
}
}

View File

@@ -3,10 +3,9 @@
//! This is both the wireguard and ICE implementation that should work in tandem.
//! [Tunnel] is the main entry-point for this crate.
use boringtun::{
noise::{errors::WireGuardError, rate_limiter::RateLimiter, TunnResult},
noise::rate_limiter::RateLimiter,
x25519::{PublicKey, StaticSecret},
};
use bytes::Bytes;
use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error};
use ip_network::IpNetwork;
@@ -29,9 +28,12 @@ use webrtc::{
peer_connection::RTCPeerConnection,
};
use futures::channel::mpsc;
use futures_util::{SinkExt, StreamExt};
use std::hash::Hash;
use std::task::{Context, Poll};
use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration};
use tokio::time::Interval;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use connlib_shared::{
@@ -56,19 +58,14 @@ mod control_protocol;
mod device_channel;
mod dns;
mod gateway;
mod iface_handler;
mod index;
mod ip_packet;
mod peer;
mod peer_handler;
mod resource_sender;
mod resource_table;
mod tokio_util;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_MTU_INTERVAL: Duration = Duration::from_secs(30);
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
/// For how long we will attempt to gather ICE candidates before aborting.
@@ -152,6 +149,13 @@ pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
/// State that differs per role, i.e. clients vs gateways.
role_state: Mutex<TRoleState>,
stop_peer_command_receiver: Mutex<mpsc::Receiver<(u32, TRoleState::Id)>>,
stop_peer_command_sender: mpsc::Sender<(u32, TRoleState::Id)>,
rate_limit_reset_interval: Mutex<Interval>,
peer_refresh_interval: Mutex<Interval>,
mtu_refresh_interval: Mutex<Interval>,
}
// TODO: For now we only use these fields with debug
@@ -189,7 +193,109 @@ where
}
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
self.role_state.lock().poll_next_event(cx)
loop {
if self
.rate_limit_reset_interval
.lock()
.poll_tick(cx)
.is_ready()
{
self.rate_limiter.reset_count();
continue;
}
if self.peer_refresh_interval.lock().poll_tick(cx).is_ready() {
let peers_to_refresh = {
let mut peers_by_ip = self.peers_by_ip.write();
peers_to_refresh(&mut peers_by_ip, self.stop_peer_command_sender.clone())
};
for peer in peers_to_refresh {
let callbacks = self.callbacks.clone();
let mut stop_command_sender = self.stop_peer_command_sender.clone();
tokio::spawn(async move {
if let Err(e) = peer.update_timers().await {
tracing::error!("Failed to update timers for peer: {e}");
let _ = callbacks.on_error(&e);
if e.is_fatal_connection_error() {
let _ = stop_command_sender.send((peer.index, peer.conn_id)).await;
}
}
});
}
continue;
}
if self.mtu_refresh_interval.lock().poll_tick(cx).is_ready() {
// We use `try_read` to acquire a lock on the device because we are within a synchronous context here and cannot use `.await`.
// The device is only updated during `add_route` and `set_interface` which would be extremely unlucky to hit at the same time as this timer.
// Even if we hit this, we just wait for the next tick to update the MTU.
let device = match self.device.try_read().map(|d| d.clone()) {
Ok(Some(device)) => device,
Ok(None) => {
let err = Error::ControlProtocolError;
tracing::error!(?err, "get_iface_config");
let _ = self.callbacks.on_error(&err);
continue;
}
Err(_) => {
tracing::debug!("Unlucky! Somebody is updating the device just as we are about to update its MTU, trying again on the next tick ...");
continue;
}
};
tokio::spawn({
let callbacks = self.callbacks.clone();
async move {
if let Err(e) = device.config.refresh_mtu().await {
tracing::error!(error = ?e, "refresh_mtu");
let _ = callbacks.on_error(&e);
}
}
});
}
if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {
return Poll::Ready(event);
}
if let Poll::Ready(Some((i, conn_id))) =
self.stop_peer_command_receiver.lock().poll_next_unpin(cx)
{
let mut peers = self.peers_by_ip.write();
let (maybe_network, maybe_peer) = peers
.iter()
.find_map(|(n, p)| (p.index == i).then_some((n, p.clone())))
.unzip();
if let Some(network) = maybe_network {
peers.remove(network);
}
if let Some(conn) = self.peer_connections.lock().remove(&conn_id) {
tokio::spawn({
let callbacks = self.callbacks.clone();
async move {
if let Some(peer) = maybe_peer {
let _ = peer.shutdown().await;
}
if let Err(e) = conn.close().await {
tracing::warn!(%conn_id, error = ?e, "Can't close peer");
let _ = callbacks.on_error(&e.into());
}
}
});
}
continue;
}
return Poll::Pending;
}
}
}
@@ -284,6 +390,8 @@ where
.with_setting_engine(setting_engine)
.build();
let (stop_peer_command_sender, stop_peer_command_receiver) = mpsc::channel(10);
Ok(Self {
rate_limiter,
private_key,
@@ -296,117 +404,14 @@ where
callbacks: CallbackErrorFacade(callbacks),
iface_handler_abort,
role_state: Default::default(),
stop_peer_command_receiver: Mutex::new(stop_peer_command_receiver),
stop_peer_command_sender,
rate_limit_reset_interval: Mutex::new(rate_limit_reset_interval()),
peer_refresh_interval: Mutex::new(peer_refresh_interval()),
mtu_refresh_interval: Mutex::new(mtu_refresh_interval()),
})
}
#[tracing::instrument(level = "trace", skip(self))]
async fn stop_peer(&self, index: u32, conn_id: TRoleState::Id) {
self.peers_by_ip.write().retain(|_, p| p.index != index);
let conn = self.peer_connections.lock().remove(&conn_id);
if let Some(conn) = conn {
if let Err(e) = conn.close().await {
tracing::warn!(error = ?e, "Can't close peer");
let _ = self.callbacks().on_error(&e.into());
}
}
}
async fn peer_refresh(&self, peer: &Peer<TRoleState::Id>, dst_buf: &mut [u8; MAX_UDP_SIZE]) {
let update_timers_result = peer.update_timers(&mut dst_buf[..]);
match update_timers_result {
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired)
| TunnResult::Err(WireGuardError::NoCurrentSession) => {
self.stop_peer(peer.index, peer.conn_id).await;
let _ = peer.shutdown().await;
}
TunnResult::Err(e) => tracing::error!(error = ?e, "timer_error"),
TunnResult::WriteToNetwork(packet) => {
let bytes = Bytes::copy_from_slice(packet);
peer.send_infallible(bytes, &self.callbacks).await
}
_ => panic!("Unexpected result from update_timers"),
};
}
fn start_rate_limiter_refresh_timer(self: &Arc<Self>) {
let rate_limiter = self.rate_limiter.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(RESET_PACKET_COUNT_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
rate_limiter.reset_count();
interval.tick().await;
}
});
}
fn start_peers_refresh_timer(self: &Arc<Self>) {
let tunnel = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(REFRESH_PEERS_TIMERS_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
let mut dst_buf = [0u8; MAX_UDP_SIZE];
loop {
remove_expired_peers(
&mut tunnel.peers_by_ip.write(),
&mut tunnel.peer_connections.lock(),
);
let peers: Vec<_> = tunnel
.peers_by_ip
.read()
.iter()
.map(|p| p.1)
.unique_by(|p| p.index)
.cloned()
.collect();
for peer in peers {
tunnel.peer_refresh(&peer, &mut dst_buf).await;
}
interval.tick().await;
}
});
}
async fn start_refresh_mtu_timer(self: &Arc<Self>) -> Result<()> {
let dev = self.clone();
let callbacks = self.callbacks().clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(REFRESH_MTU_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
interval.tick().await;
let Some(device) = dev.device.read().await.clone() else {
let err = Error::ControlProtocolError;
tracing::error!(?err, "get_iface_config");
let _ = callbacks.0.on_error(&err);
continue;
};
if let Err(e) = device.config.refresh_mtu().await {
tracing::error!(error = ?e, "refresh_mtu");
let _ = callbacks.0.on_error(&e);
}
}
});
Ok(())
}
async fn start_timers(self: &Arc<Self>) -> Result<()> {
self.start_refresh_mtu_timer().await?;
self.start_rate_limiter_refresh_timer();
self.start_peers_refresh_timer();
Ok(())
}
fn next_index(&self) -> u32 {
self.next_index.lock().next()
}
@@ -416,27 +421,68 @@ where
}
}
/// Constructs the interval for resetting the rate limit count.
///
/// As per documentation on [`RateLimiter::reset_count`], this is configured to run every second.
fn rate_limit_reset_interval() -> Interval {
let mut interval = tokio::time::interval(Duration::from_secs(1));
// Delay (rather than burst) missed ticks so a stalled event loop doesn't
// trigger several resets back-to-back.
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
interval
}
/// Constructs the interval for "refreshing" peers.
///
/// On each tick, we remove expired peers from our map, update wireguard timers and send packets, if any.
fn peer_refresh_interval() -> Interval {
let mut interval = tokio::time::interval(Duration::from_secs(1));
// Delay (rather than burst) missed ticks so a stalled event loop doesn't
// run several refresh passes back-to-back.
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
interval
}
/// Constructs the interval for refreshing the MTU of our TUN device.
///
/// Runs every 30 seconds; MTU changes are rare, so a slow cadence suffices.
fn mtu_refresh_interval() -> Interval {
let mut interval = tokio::time::interval(Duration::from_secs(30));
// Delay missed ticks instead of firing a burst after a stall.
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
interval
}
/// Returns the set of peers whose wireguard timers should be refreshed.
///
/// First evicts expired peers from `peers_by_ip` (scheduling their shutdown
/// via `shutdown_sender`), then collects the remaining peers, deduplicated by
/// `index` because one peer may be registered under several IP networks.
fn peers_to_refresh<TId>(
peers_by_ip: &mut IpNetworkTable<Arc<Peer<TId>>>,
shutdown_sender: mpsc::Sender<(u32, TId)>,
) -> Vec<Arc<Peer<TId>>>
where
TId: Eq + Hash + Copy + Send + Sync + 'static,
{
remove_expired_peers(peers_by_ip, shutdown_sender);
peers_by_ip
.iter()
.map(|p| p.1)
.unique_by(|p| p.index)
.cloned()
.collect()
}
fn remove_expired_peers<TId>(
peers_by_ip: &mut IpNetworkTable<Arc<Peer<TId>>>,
peer_connections: &mut HashMap<TId, Arc<RTCPeerConnection>>,
shutdown_sender: mpsc::Sender<(u32, TId)>,
) where
TId: Eq + Hash + Copy + Send + Sync + 'static,
{
for (_, peer) in peers_by_ip.iter() {
peer.expire_resources();
if peer.is_emptied() {
tracing::trace!(index = peer.index, "peer_expired");
let conn = peer_connections.remove(&peer.conn_id);
let p = peer.clone();
let index = peer.index;
let conn_id = peer.conn_id;
// We are holding a Mutex, particularly a write one, we don't want to make a blocking call
tokio::spawn(async move {
let _ = p.shutdown().await;
if let Some(conn) = conn {
// TODO: it seems that even closing the stream there are messages to the relay
// see where they come from.
let _ = conn.close().await;
}
tracing::trace!(%index, "peer_expired");
tokio::spawn({
let mut sender = shutdown_sender.clone();
async move { sender.send((index, conn_id)).await }
});
}
}
@@ -449,7 +495,7 @@ fn remove_expired_peers<TId>(
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
pub trait RoleState: Default + Send + 'static {
type Id: fmt::Debug + Eq + Hash + Copy + Send + Sync + 'static;
type Id: fmt::Debug + fmt::Display + Eq + Hash + Copy + Unpin + Send + Sync + 'static;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}

View File

@@ -22,10 +22,10 @@ type ExpiryingResource = (ResourceDescription, DateTime<Utc>);
pub(crate) struct Peer<TId> {
pub tunnel: Mutex<Tunn>,
pub index: u32,
pub allowed_ips: RwLock<IpNetworkTable<()>>,
allowed_ips: RwLock<IpNetworkTable<()>>,
pub channel: Arc<DataChannel>,
pub conn_id: TId,
pub resources: Option<RwLock<ResourceTable<ExpiryingResource>>>,
resources: Option<RwLock<ResourceTable<ExpiryingResource>>>,
// Here we store the address that we obtained for the resource that the peer corresponds to.
// This can have the following problem:
// 1. Peer sends packet to address.com and it resolves to 1.1.1.1
@@ -35,7 +35,7 @@ pub(crate) struct Peer<TId> {
// so, TODO: store multiple ips and expire them.
// Note that this case is quite an unlikely edge case so I wouldn't prioritize this fix
// TODO: Also check if there's any case where we want to talk to ipv4 and ipv6 from the same peer.
pub translated_resource_addresses: RwLock<HashMap<IpAddr, ResourceId>>,
translated_resource_addresses: RwLock<HashMap<IpAddr, ResourceId>>,
}
// TODO: For now we only use these fields with debug
@@ -50,14 +50,6 @@ pub(crate) struct PeerStats<TId> {
pub translated_resource_addresses: HashMap<IpAddr, ResourceId>,
}
#[derive(Debug)]
pub(crate) struct EncapsulatedPacket<'a, Id> {
pub index: u32,
pub conn_id: Id,
pub channel: Arc<DataChannel>,
pub encapsulate_result: TunnResult<'a>,
}
impl<TId> Peer<TId>
where
TId: Copy,
@@ -130,7 +122,7 @@ where
}
}
pub(crate) fn get_translation(&self, ip: IpAddr) -> Option<ResourceDescription> {
fn get_translation(&self, ip: IpAddr) -> Option<ResourceDescription> {
let id = self.translated_resource_addresses.read().get(&ip).cloned();
self.resources.as_ref().and_then(|resources| {
id.and_then(|id| resources.read().get_by_id(&id).map(|r| r.0.clone()))
@@ -141,8 +133,25 @@ where
self.allowed_ips.write().insert(ip, ());
}
pub(crate) fn update_timers<'a>(&self, dst: &'a mut [u8]) -> TunnResult<'a> {
self.tunnel.lock().update_timers(dst)
/// Advances this peer's wireguard timers, sending any resulting packet
/// (e.g. a handshake or keep-alive) over the data channel.
///
/// Returns `Ok(())` when there is nothing to send or the packet was written;
/// propagates timer errors and channel write errors to the caller.
pub(crate) async fn update_timers<'a>(&self) -> Result<()> {
/// [`boringtun`] requires us to pass buffers in where it can construct its packets.
///
/// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`].
const MAX_SCRATCH_SPACE: usize = 148;
let mut buf = [0u8; MAX_SCRATCH_SPACE];
let packet = match self.tunnel.lock().update_timers(&mut buf) {
TunnResult::Done => return Ok(()),
TunnResult::Err(e) => return Err(e.into()),
TunnResult::WriteToNetwork(b) => b,
// `update_timers` never decapsulates, so tunnel-write variants are impossible here.
_ => panic!("Unexpected result from update_timers"),
};
// Copy out of the stack scratch buffer before awaiting the channel write.
let bytes = Bytes::copy_from_slice(packet);
self.channel.write(&bytes).await?;
Ok(())
}
pub(crate) async fn shutdown(&self) -> Result<()> {
@@ -189,11 +198,13 @@ where
self.translated_resource_addresses.write().insert(addr, id);
}
pub(crate) fn encapsulate<'a>(
/// Sends the given packet to this peer by encapsulating it in a wireguard packet.
pub(crate) async fn send<'a>(
&self,
packet: &mut MutableIpPacket<'a>,
dst: &'a mut [u8],
) -> Result<EncapsulatedPacket<'a, TId>> {
mut packet: MutableIpPacket<'a>,
dest: IpAddr,
buf: &mut [u8],
) -> Result<()> {
if let Some(resource) = self.get_translation(packet.to_immutable().source()) {
let ResourceDescription::Dns(resource) = resource else {
tracing::error!(
@@ -209,12 +220,18 @@ where
packet.update_checksum();
}
Ok(EncapsulatedPacket {
index: self.index,
conn_id: self.conn_id,
channel: self.channel.clone(),
encapsulate_result: self.tunnel.lock().encapsulate(packet.packet_mut(), dst),
})
let packet = match self.tunnel.lock().encapsulate(packet.packet_mut(), buf) {
TunnResult::Done => return Ok(()),
TunnResult::Err(e) => return Err(e.into()),
TunnResult::WriteToNetwork(b) => b,
_ => panic!("Unexpected result from `encapsulate`"),
};
tracing::trace!(target: "wire", action = "writing", from = "iface", to = %dest);
self.channel.write(&Bytes::copy_from_slice(packet)).await?;
Ok(())
}
pub(crate) fn get_packet_resource(

View File

@@ -1,9 +1,14 @@
use std::net::{IpAddr, ToSocketAddrs};
use std::sync::Arc;
use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult};
use boringtun::x25519::{PublicKey, StaticSecret};
use bytes::Bytes;
use connlib_shared::messages::ResourceDescription;
use connlib_shared::{Callbacks, Error, Result};
use futures_util::SinkExt;
use crate::ip_packet::MutableIpPacket;
use crate::{
device_channel::DeviceIo, index::check_packet_index, peer::Peer, RoleState, Tunnel,
MAX_UDP_SIZE,
@@ -14,16 +19,95 @@ where
CB: Callbacks + 'static,
TRoleState: RoleState,
{
#[inline(always)]
fn is_wireguard_packet_ok(&self, parsed_packet: &Packet, peer: &Peer<TRoleState::Id>) -> bool {
match &parsed_packet {
Packet::HandshakeInit(p) => {
parse_handshake_anon(&self.private_key, &self.public_key, p).is_ok()
pub(crate) async fn start_peer_handler(self: Arc<Self>, peer: Arc<Peer<TRoleState::Id>>) {
loop {
let Some(device) = self.device.read().await.clone() else {
let err = Error::NoIface;
tracing::error!(?err);
let _ = self.callbacks().on_disconnect(Some(&err));
break;
};
let device_io = device.io;
if let Err(err) = self.peer_handler(&peer, device_io).await {
if err.raw_os_error() != Some(9) {
tracing::error!(?err);
let _ = self.callbacks().on_error(&err.into());
break;
} else {
tracing::warn!("bad_file_descriptor");
}
}
Packet::HandshakeResponse(p) => check_packet_index(p.receiver_idx, peer.index),
Packet::PacketCookieReply(p) => check_packet_index(p.receiver_idx, peer.index),
Packet::PacketData(p) => check_packet_index(p.receiver_idx, peer.index),
}
tracing::debug!(peer = ?peer.stats(), "peer_stopped");
let _ = self
.stop_peer_command_sender
.clone()
.send((peer.index, peer.conn_id))
.await;
}
async fn peer_handler(
self: &Arc<Self>,
peer: &Arc<Peer<TRoleState::Id>>,
device_io: DeviceIo,
) -> std::io::Result<()> {
let mut src_buf = [0u8; MAX_UDP_SIZE];
let mut dst_buf = [0u8; MAX_UDP_SIZE];
while let Ok(size) = peer.channel.read(&mut src_buf[..]).await {
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
// TODO: Double check that this can only happen on closed channel
// I think it's possible to transmit a 0-byte message through the channel
// but we would never use that.
// We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
if size == 0 {
break;
}
if let Err(Error::Io(e)) = self
.handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf)
.await
{
return Err(e);
}
}
Ok(())
}
#[inline(always)]
pub(crate) async fn handle_peer_packet(
self: &Arc<Self>,
peer: &Arc<Peer<TRoleState::Id>>,
device_writer: &DeviceIo,
src: &[u8],
dst: &mut [u8],
) -> Result<()> {
let parsed_packet = self.verify_packet(peer, src, dst).await?;
if !is_wireguard_packet_ok(&self.private_key, &self.public_key, &parsed_packet, peer) {
tracing::error!("wireguard_verification");
return Err(Error::BadPacket);
}
let decapsulate_result = peer.tunnel.lock().decapsulate(None, src, dst);
if self
.handle_decapsulated_packet(peer, device_writer, decapsulate_result)
.await?
{
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) =
peer.tunnel.lock().decapsulate(None, &[], dst)
{
let bytes = Bytes::copy_from_slice(packet);
let callbacks = self.callbacks.clone();
let peer = peer.clone();
tokio::spawn(async move { peer.send_infallible(bytes, &callbacks).await });
}
}
Ok(())
}
#[inline(always)]
@@ -80,101 +164,143 @@ where
Ok(true)
}
TunnResult::WriteToTunnelV4(packet, addr) => {
self.send_to_resource(device_io, peer, addr.into(), packet)?;
send_to_resource(device_io, peer, addr.into(), packet)?;
Ok(false)
}
TunnResult::WriteToTunnelV6(packet, addr) => {
self.send_to_resource(device_io, peer, addr.into(), packet)?;
send_to_resource(device_io, peer, addr.into(), packet)?;
Ok(false)
}
}
}
}
#[inline(always)]
pub(crate) async fn handle_peer_packet(
self: &Arc<Self>,
peer: &Arc<Peer<TRoleState::Id>>,
device_writer: &DeviceIo,
src: &[u8],
dst: &mut [u8],
) -> Result<()> {
let parsed_packet = self.verify_packet(peer, src, dst).await?;
if !self.is_wireguard_packet_ok(&parsed_packet, peer) {
tracing::error!("wireguard_verification");
return Err(Error::BadPacket);
}
let decapsulate_result = peer.tunnel.lock().decapsulate(None, src, dst);
if self
.handle_decapsulated_packet(peer, device_writer, decapsulate_result)
.await?
{
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) = {
let res = peer.tunnel.lock().decapsulate(None, &[], dst);
res
} {
let bytes = Bytes::copy_from_slice(packet);
let callbacks = self.callbacks.clone();
let peer = peer.clone();
tokio::spawn(async move { peer.send_infallible(bytes, &callbacks).await });
}
}
Ok(())
}
async fn peer_handler(
self: &Arc<Self>,
peer: &Arc<Peer<TRoleState::Id>>,
device_io: DeviceIo,
) -> std::io::Result<()> {
let mut src_buf = [0u8; MAX_UDP_SIZE];
let mut dst_buf = [0u8; MAX_UDP_SIZE];
while let Ok(size) = peer.channel.read(&mut src_buf[..]).await {
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
// TODO: Double check that this can only happen on closed channel
// I think it's possible to transmit a 0-byte message through the channel
// but we would never use that.
// We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
if size == 0 {
break;
}
if let Err(Error::Io(e)) = self
.handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf)
.await
{
return Err(e);
}
}
Ok(())
}
pub(crate) async fn start_peer_handler(self: Arc<Self>, peer: Arc<Peer<TRoleState::Id>>) {
loop {
let Some(device) = self.device.read().await.clone() else {
let err = Error::NoIface;
tracing::error!(?err);
let _ = self.callbacks().on_disconnect(Some(&err));
break;
};
let device_io = device.io;
if let Err(err) = self.peer_handler(&peer, device_io).await {
if err.raw_os_error() != Some(9) {
tracing::error!(?err);
let _ = self.callbacks().on_error(&err.into());
break;
} else {
tracing::warn!("bad_file_descriptor");
}
}
}
tracing::debug!(peer = ?peer.stats(), "peer_stopped");
self.stop_peer(peer.index, peer.conn_id).await;
#[inline(always)]
/// Verifies that a parsed wireguard packet is plausible for `peer`.
///
/// Handshake initiations are checked cryptographically via
/// `parse_handshake_anon`; all other packet types are checked by comparing
/// the packet's receiver index against the peer's index.
fn is_wireguard_packet_ok<TId>(
private_key: &StaticSecret,
public_key: &PublicKey,
parsed_packet: &Packet,
peer: &Peer<TId>,
) -> bool {
match parsed_packet {
Packet::HandshakeInit(p) => parse_handshake_anon(private_key, public_key, p).is_ok(),
Packet::HandshakeResponse(p) => check_packet_index(p.receiver_idx, peer.index),
Packet::PacketCookieReply(p) => check_packet_index(p.receiver_idx, peer.index),
Packet::PacketData(p) => check_packet_index(p.receiver_idx, peer.index),
}
}
/// Forwards a decapsulated packet towards a resource via the TUN device,
/// but only if `addr` is in the peer's allowed-IPs set.
///
/// Packets from unallowed source IPs are dropped with a warning rather than
/// treated as an error — the function still returns `Ok(())` in that case.
fn send_to_resource<TId>(
device_io: &DeviceIo,
peer: &Arc<Peer<TId>>,
addr: IpAddr,
packet: &mut [u8],
) -> Result<()>
where
TId: Copy,
{
if peer.is_allowed(addr) {
packet_allowed(device_io, peer, addr, packet)?;
Ok(())
} else {
tracing::warn!(%addr, "Received packet from peer with an unallowed ip");
Ok(())
}
}
#[inline(always)]
/// Writes an allowed packet to the TUN device, applying DNS address
/// translation when the packet targets a translated resource.
///
/// If the peer has no resource associated with the packet, the packet is
/// written through unchanged — per the comment below, that is the client
/// receiving traffic from a (trusted) gateway.
pub(crate) fn packet_allowed<TId>(
device_io: &DeviceIo,
peer: &Arc<Peer<TId>>,
addr: IpAddr,
packet: &mut [u8],
) -> Result<()>
where
TId: Copy,
{
let Some((dst, resource)) = peer.get_packet_resource(packet) else {
// If there's no associated resource it means that we are in a client, then the packet comes from a gateway
// and we just trust gateways.
// In gateways this should never happen.
tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len());
send_packet(device_io, packet, addr)?;
return Ok(());
};
// Resolve the real resource address, rewrite the packet's destination and
// fix up checksums before handing it to the device.
let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?;
update_packet(packet, dst_addr);
send_packet(device_io, packet, addr)?;
Ok(())
}
#[inline(always)]
/// Writes `packet` to the TUN device, picking the IPv4 or IPv6 write path
/// based on the IP version of `dst_addr`.
fn send_packet(device_io: &DeviceIo, packet: &mut [u8], dst_addr: IpAddr) -> std::io::Result<()> {
match dst_addr {
IpAddr::V4(_) => device_io.write4(packet)?,
IpAddr::V6(_) => device_io.write6(packet)?,
};
Ok(())
}
/// Rewrites the destination address of an IP packet in place and recomputes
/// its checksum. Buffers that do not parse as an IP packet are left untouched.
#[inline(always)]
fn update_packet(packet: &mut [u8], dst_addr: IpAddr) {
    if let Some(mut pkt) = MutableIpPacket::new(packet) {
        pkt.set_dst(dst_addr);
        pkt.update_checksum();
    }
}
/// Returns `Some(*ip)` when `addr` and `ip` belong to the same IP family
/// (both v4 or both v6), `None` otherwise.
fn get_matching_version_ip(addr: &IpAddr, ip: &IpAddr) -> Option<IpAddr> {
    // For `IpAddr`, `is_ipv6` is the negation of `is_ipv4`, so equality of
    // `is_ipv4` is exactly "same family".
    if addr.is_ipv4() == ip.is_ipv4() {
        Some(*ip)
    } else {
        None
    }
}
fn get_resource_addr_and_port<TId>(
peer: &Arc<Peer<TId>>,
resource: &ResourceDescription,
addr: &IpAddr,
dst: &IpAddr,
) -> Result<(IpAddr, Option<u16>)>
where
TId: Copy,
{
match resource {
ResourceDescription::Dns(r) => {
let mut address = r.address.split(':');
let Some(dst_addr) = address.next() else {
tracing::error!("invalid DNS name for resource: {}", r.address);
return Err(Error::InvalidResource);
};
let Ok(mut dst_addr) = (dst_addr, 0).to_socket_addrs() else {
tracing::warn!(%addr, "Couldn't resolve name");
return Err(Error::InvalidResource);
};
let Some(dst_addr) = dst_addr.find_map(|d| get_matching_version_ip(addr, &d.ip()))
else {
tracing::warn!(%addr, "Couldn't resolve name addr");
return Err(Error::InvalidResource);
};
peer.update_translated_resource_address(r.id, dst_addr);
Ok((
dst_addr,
address
.next()
.map(str::parse::<u16>)
.and_then(std::result::Result::ok),
))
}
ResourceDescription::Cidr(r) => {
if r.address.contains(*dst) {
Ok((
get_matching_version_ip(addr, dst).ok_or(Error::InvalidResource)?,
None,
))
} else {
tracing::warn!(
"client tried to hijack the tunnel for range outside what it's allowed."
);
Err(Error::InvalidSource)
}
}
}
}

View File

@@ -1,130 +0,0 @@
use std::{
net::{IpAddr, ToSocketAddrs},
sync::Arc,
};
use crate::{device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel};
use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result};
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
    CB: Callbacks + 'static,
    TRoleState: RoleState,
{
    /// Rewrites the destination address of an IP packet in place and
    /// recomputes its checksum. Unparsable buffers are left untouched.
    #[inline(always)]
    fn update_packet(&self, packet: &mut [u8], dst_addr: IpAddr) {
        if let Some(mut pkt) = MutableIpPacket::new(packet) {
            pkt.set_dst(dst_addr);
            pkt.update_checksum();
        }
    }

    /// Writes `packet` to the IPv4 or IPv6 side of the device, chosen by the
    /// address family of `dst_addr`.
    #[inline(always)]
    fn send_packet(
        &self,
        device_io: &DeviceIo,
        packet: &mut [u8],
        dst_addr: IpAddr,
    ) -> std::io::Result<()> {
        if dst_addr.is_ipv4() {
            device_io.write4(packet)?;
        } else {
            device_io.write6(packet)?;
        }

        Ok(())
    }

    /// Writes an allowed packet to the device, first rewriting its destination
    /// address when the packet maps to a known resource.
    #[inline(always)]
    pub(crate) fn packet_allowed(
        &self,
        device_io: &DeviceIo,
        peer: &Arc<Peer<TRoleState::Id>>,
        addr: IpAddr,
        packet: &mut [u8],
    ) -> Result<()> {
        let resource = peer.get_packet_resource(packet);

        if let Some((dst, resource)) = resource {
            // Packet targets a resource: translate the destination first.
            let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?;
            self.update_packet(packet, dst_addr);
        } else {
            // If there's no associated resource it means that we are in a client, then the packet comes from a gateway
            // and we just trust gateways.
            // In gateways this should never happen.
            tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len());
        }

        self.send_packet(device_io, packet, addr)?;

        Ok(())
    }

    /// Forwards a packet from `peer` to the device when the peer may use
    /// `addr` as a source; disallowed sources are logged and dropped (still
    /// returns `Ok(())`).
    pub(crate) fn send_to_resource(
        &self,
        device_io: &DeviceIo,
        peer: &Arc<Peer<TRoleState::Id>>,
        addr: IpAddr,
        packet: &mut [u8],
    ) -> Result<()> {
        if !peer.is_allowed(addr) {
            tracing::warn!(%addr, "Received packet from peer with an unallowed ip");
            return Ok(());
        }

        self.packet_allowed(device_io, peer, addr, packet)
    }
}
/// Returns `Some(*ip)` when `addr` and `ip` are of the same IP family,
/// `None` otherwise.
fn get_matching_version_ip(addr: &IpAddr, ip: &IpAddr) -> Option<IpAddr> {
    // Same family iff the `is_ipv4` flags agree (v6 is the complement for `IpAddr`).
    let same_family = addr.is_ipv4() == ip.is_ipv4();
    same_family.then_some(*ip)
}
fn get_resource_addr_and_port<TId>(
peer: &Arc<Peer<TId>>,
resource: &ResourceDescription,
addr: &IpAddr,
dst: &IpAddr,
) -> Result<(IpAddr, Option<u16>)>
where
TId: Copy,
{
match resource {
ResourceDescription::Dns(r) => {
let mut address = r.address.split(':');
let Some(dst_addr) = address.next() else {
tracing::error!("invalid DNS name for resource: {}", r.address);
return Err(Error::InvalidResource);
};
let Ok(mut dst_addr) = (dst_addr, 0).to_socket_addrs() else {
tracing::warn!(%addr, "Couldn't resolve name");
return Err(Error::InvalidResource);
};
let Some(dst_addr) = dst_addr.find_map(|d| get_matching_version_ip(addr, &d.ip()))
else {
tracing::warn!(%addr, "Couldn't resolve name addr");
return Err(Error::InvalidResource);
};
peer.update_translated_resource_address(r.id, dst_addr);
Ok((
dst_addr,
address
.next()
.map(str::parse::<u16>)
.and_then(std::result::Result::ok),
))
}
ResourceDescription::Cidr(r) => {
if r.address.contains(*dst) {
Ok((
get_matching_version_ip(addr, dst).ok_or(Error::InvalidResource)?,
None,
))
} else {
tracing::warn!(
"client tried to hijack the tunnel for range outside what it's allowed."
);
Err(Error::InvalidSource)
}
}
}
}