mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): move more logic to poll_next_event (#2403)
This commit is contained in:
@@ -145,6 +145,26 @@ impl ConnlibError {
|
||||
if e.status().is_client_error()
|
||||
)
|
||||
}
|
||||
|
||||
/// Whether this error is fatal to the underlying connection.
|
||||
pub fn is_fatal_connection_error(&self) -> bool {
|
||||
if let Self::WireguardError(e) = self {
|
||||
return matches!(
|
||||
e,
|
||||
WireGuardError::ConnectionExpired | WireGuardError::NoCurrentSession
|
||||
);
|
||||
}
|
||||
|
||||
if let Self::IceDataError(e) = self {
|
||||
return matches!(
|
||||
e,
|
||||
webrtc::data::Error::ErrStreamClosed
|
||||
| webrtc::data::Error::Sctp(webrtc::sctp::Error::ErrStreamClosed)
|
||||
);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
|
||||
@@ -18,6 +18,7 @@ use connlib_shared::{Callbacks, DNS_SENTINEL};
|
||||
use futures::channel::mpsc::Receiver;
|
||||
use futures::stream;
|
||||
use futures_bounded::{PushError, StreamMap};
|
||||
use futures_util::SinkExt;
|
||||
use hickory_resolver::lookup::Lookup;
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
@@ -116,9 +117,8 @@ where
|
||||
config: &InterfaceConfig,
|
||||
) -> connlib_shared::Result<()> {
|
||||
let device = create_iface(config, self.callbacks()).await?;
|
||||
*self.device.write().await = Some(device.clone());
|
||||
|
||||
self.start_timers().await?;
|
||||
*self.device.write().await = Some(device.clone());
|
||||
*self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
|
||||
&self.callbacks,
|
||||
device_handler(Arc::clone(self), device),
|
||||
@@ -203,12 +203,18 @@ where
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Err(e) = tunnel
|
||||
.encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
|
||||
.await
|
||||
{
|
||||
if let Err(e) = peer.send(packet, dest, &mut buf).await {
|
||||
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
|
||||
|
||||
let _ = tunnel.callbacks.on_error(&e);
|
||||
tracing::error!(err = ?e, "failed to handle packet {e:#}")
|
||||
|
||||
if e.is_fatal_connection_error() {
|
||||
let _ = tunnel
|
||||
.stop_peer_command_sender
|
||||
.clone()
|
||||
.send((peer.index, peer.conn_id))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,36 +38,6 @@ where
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
pub fn on_dc_close_handler(
|
||||
self: Arc<Self>,
|
||||
index: u32,
|
||||
conn_id: TRoleState::Id,
|
||||
) -> OnCloseHdlrFn {
|
||||
Box::new(move || {
|
||||
tracing::debug!("channel_closed");
|
||||
let tunnel = self.clone();
|
||||
Box::pin(async move {
|
||||
tunnel.stop_peer(index, conn_id).await;
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn on_peer_connection_state_change_handler(
|
||||
self: Arc<Self>,
|
||||
index: u32,
|
||||
conn_id: TRoleState::Id,
|
||||
) -> OnPeerConnectionStateChangeHdlrFn {
|
||||
Box::new(move |state| {
|
||||
let tunnel = Arc::clone(&self);
|
||||
Box::pin(async move {
|
||||
tracing::trace!(?state, "peer_state_update");
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
tunnel.stop_peer(index, conn_id).await;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn add_ice_candidate(
|
||||
&self,
|
||||
conn_id: TRoleState::Id,
|
||||
@@ -84,6 +54,44 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn on_peer_connection_state_change_handler<TId>(
|
||||
index: u32,
|
||||
conn_id: TId,
|
||||
stop_command_sender: mpsc::Sender<(u32, TId)>,
|
||||
) -> OnPeerConnectionStateChangeHdlrFn
|
||||
where
|
||||
TId: Copy + Send + Sync + 'static,
|
||||
{
|
||||
Box::new(move |state| {
|
||||
let mut sender = stop_command_sender.clone();
|
||||
|
||||
tracing::trace!(?state, "peer_state_update");
|
||||
Box::pin(async move {
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
let _ = sender.send((index, conn_id)).await;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn on_dc_close_handler<TId>(
|
||||
index: u32,
|
||||
conn_id: TId,
|
||||
stop_command_sender: mpsc::Sender<(u32, TId)>,
|
||||
) -> OnCloseHdlrFn
|
||||
where
|
||||
TId: Copy + Send + Sync + 'static,
|
||||
{
|
||||
Box::new(move || {
|
||||
let mut sender = stop_command_sender.clone();
|
||||
|
||||
tracing::debug!("channel_closed");
|
||||
Box::pin(async move {
|
||||
let _ = sender.send((index, conn_id)).await;
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(webrtc))]
|
||||
pub async fn new_peer_connection(
|
||||
webrtc: &webrtc::api::API,
|
||||
|
||||
@@ -16,25 +16,11 @@ use webrtc::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::control_protocol::new_peer_connection;
|
||||
use crate::control_protocol::{
|
||||
new_peer_connection, on_dc_close_handler, on_peer_connection_state_change_handler,
|
||||
};
|
||||
use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel};
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn handle_connection_state_update<CB>(
|
||||
tunnel: &Arc<Tunnel<CB, ClientState>>,
|
||||
state: RTCPeerConnectionState,
|
||||
gateway_id: GatewayId,
|
||||
resource_id: ResourceId,
|
||||
) where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
tracing::trace!("peer_state");
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
tunnel.role_state.lock().on_connection_failed(resource_id);
|
||||
tunnel.peer_connections.lock().remove(&gateway_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn set_connection_state_update<CB>(
|
||||
tunnel: &Arc<Tunnel<CB, ClientState>>,
|
||||
@@ -49,7 +35,11 @@ fn set_connection_state_update<CB>(
|
||||
move |state: RTCPeerConnectionState| {
|
||||
let tunnel = Arc::clone(&tunnel);
|
||||
Box::pin(async move {
|
||||
handle_connection_state_update(&tunnel, state, gateway_id, resource_id)
|
||||
tracing::trace!("peer_state");
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
tunnel.role_state.lock().on_connection_failed(resource_id);
|
||||
tunnel.peer_connections.lock().remove(&gateway_id);
|
||||
}
|
||||
})
|
||||
},
|
||||
));
|
||||
@@ -139,7 +129,7 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
d.on_close(tunnel.clone().on_dc_close_handler(index, gateway_id));
|
||||
d.on_close(on_dc_close_handler(index, gateway_id, tunnel.stop_peer_command_sender.clone()));
|
||||
|
||||
let peer = Arc::new(Peer::new(
|
||||
tunnel.private_key.clone(),
|
||||
@@ -171,9 +161,7 @@ where
|
||||
|
||||
if let Some(conn) = tunnel.peer_connections.lock().get(&gateway_id) {
|
||||
conn.on_peer_connection_state_change(
|
||||
tunnel
|
||||
.clone()
|
||||
.on_peer_connection_state_change_handler(index, gateway_id),
|
||||
on_peer_connection_state_change_handler(index, gateway_id, tunnel.stop_peer_command_sender.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,9 @@ use webrtc::peer_connection::{
|
||||
RTCPeerConnection,
|
||||
};
|
||||
|
||||
use crate::control_protocol::new_peer_connection;
|
||||
use crate::control_protocol::{
|
||||
new_peer_connection, on_dc_close_handler, on_peer_connection_state_change_handler,
|
||||
};
|
||||
use crate::{peer::Peer, GatewayState, PeerConfig, Tunnel};
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
@@ -110,7 +112,7 @@ where
|
||||
}
|
||||
|
||||
data_channel
|
||||
.on_close(tunnel.clone().on_dc_close_handler(index, client_id));
|
||||
.on_close(on_dc_close_handler(index, client_id, tunnel.stop_peer_command_sender.clone()));
|
||||
|
||||
let peer = Arc::new(Peer::new(
|
||||
tunnel.private_key.clone(),
|
||||
@@ -129,9 +131,9 @@ where
|
||||
|
||||
if let Some(conn) = tunnel.peer_connections.lock().get(&client_id) {
|
||||
conn.on_peer_connection_state_change(
|
||||
tunnel.clone().on_peer_connection_state_change_handler(
|
||||
on_peer_connection_state_change_handler(
|
||||
index,
|
||||
client_id,
|
||||
client_id, tunnel.stop_peer_command_sender.clone(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
|
||||
use connlib_shared::Callbacks;
|
||||
use futures::channel::mpsc::Receiver;
|
||||
use futures_bounded::{PushError, StreamMap};
|
||||
use futures_util::SinkExt;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::time::Duration;
|
||||
@@ -24,9 +25,8 @@ where
|
||||
config: &InterfaceConfig,
|
||||
) -> connlib_shared::Result<()> {
|
||||
let device = create_iface(config, self.callbacks()).await?;
|
||||
*self.device.write().await = Some(device.clone());
|
||||
|
||||
self.start_timers().await?;
|
||||
*self.device.write().await = Some(device.clone());
|
||||
*self.iface_handler_abort.lock() =
|
||||
Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle());
|
||||
|
||||
@@ -63,11 +63,16 @@ where
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Err(e) = tunnel
|
||||
.encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
|
||||
.await
|
||||
{
|
||||
tracing::error!(err = ?e, "failed to handle packet {e:#}")
|
||||
if let Err(e) = peer.send(packet, dest, &mut buf).await {
|
||||
tracing::error!(resource_address = %dest, err = ?e, "failed to handle packet {e:#}");
|
||||
|
||||
if e.is_fatal_connection_error() {
|
||||
let _ = tunnel
|
||||
.stop_peer_command_sender
|
||||
.clone()
|
||||
.send((peer.index, peer.conn_id))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
use std::{net::IpAddr, sync::Arc};
|
||||
|
||||
use boringtun::noise::{errors::WireGuardError, TunnResult};
|
||||
use bytes::Bytes;
|
||||
use connlib_shared::{Callbacks, Result};
|
||||
|
||||
use crate::{ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel};
|
||||
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
#[inline(always)]
|
||||
pub(crate) async fn encapsulate_and_send_to_peer<'a>(
|
||||
&self,
|
||||
mut packet: MutableIpPacket<'_>,
|
||||
peer: Arc<Peer<TRoleState::Id>>,
|
||||
dst_addr: &IpAddr,
|
||||
buf: &mut [u8],
|
||||
) -> Result<()> {
|
||||
let encapsulated_packet = peer.encapsulate(&mut packet, buf)?;
|
||||
|
||||
match encapsulated_packet.encapsulate_result {
|
||||
TunnResult::Done => Ok(()),
|
||||
TunnResult::Err(WireGuardError::ConnectionExpired)
|
||||
| TunnResult::Err(WireGuardError::NoCurrentSession) => {
|
||||
self.stop_peer(encapsulated_packet.index, encapsulated_packet.conn_id)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!(resource_address = %dst_addr, error = ?e, "resource_connection");
|
||||
let err = e.into();
|
||||
let _ = self.callbacks.on_error(&err);
|
||||
Err(err)
|
||||
}
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
tracing::trace!(target: "wire", action = "writing", from = "iface", to = %dst_addr);
|
||||
if let Err(e) = encapsulated_packet
|
||||
.channel
|
||||
.write(&Bytes::copy_from_slice(packet))
|
||||
.await
|
||||
{
|
||||
tracing::error!(?e, "webrtc_write");
|
||||
if matches!(
|
||||
e,
|
||||
webrtc::data::Error::ErrStreamClosed
|
||||
| webrtc::data::Error::Sctp(webrtc::sctp::Error::ErrStreamClosed)
|
||||
) {
|
||||
self.stop_peer(encapsulated_packet.index, encapsulated_packet.conn_id)
|
||||
.await;
|
||||
}
|
||||
let err = e.into();
|
||||
let _ = self.callbacks.on_error(&err);
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
_ => panic!("Unexpected result from encapsulate"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,9 @@
|
||||
//! This is both the wireguard and ICE implementation that should work in tandem.
|
||||
//! [Tunnel] is the main entry-point for this crate.
|
||||
use boringtun::{
|
||||
noise::{errors::WireGuardError, rate_limiter::RateLimiter, TunnResult},
|
||||
noise::rate_limiter::RateLimiter,
|
||||
x25519::{PublicKey, StaticSecret},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
|
||||
use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error};
|
||||
use ip_network::IpNetwork;
|
||||
@@ -29,9 +28,12 @@ use webrtc::{
|
||||
peer_connection::RTCPeerConnection,
|
||||
};
|
||||
|
||||
use futures::channel::mpsc;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::hash::Hash;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration};
|
||||
use tokio::time::Interval;
|
||||
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
|
||||
use connlib_shared::{
|
||||
@@ -56,19 +58,14 @@ mod control_protocol;
|
||||
mod device_channel;
|
||||
mod dns;
|
||||
mod gateway;
|
||||
mod iface_handler;
|
||||
mod index;
|
||||
mod ip_packet;
|
||||
mod peer;
|
||||
mod peer_handler;
|
||||
mod resource_sender;
|
||||
mod resource_table;
|
||||
mod tokio_util;
|
||||
|
||||
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
|
||||
const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1);
|
||||
const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1);
|
||||
const REFRESH_MTU_INTERVAL: Duration = Duration::from_secs(30);
|
||||
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
|
||||
|
||||
/// For how long we will attempt to gather ICE candidates before aborting.
|
||||
@@ -152,6 +149,13 @@ pub struct Tunnel<CB: Callbacks, TRoleState: RoleState> {
|
||||
|
||||
/// State that differs per role, i.e. clients vs gateways.
|
||||
role_state: Mutex<TRoleState>,
|
||||
|
||||
stop_peer_command_receiver: Mutex<mpsc::Receiver<(u32, TRoleState::Id)>>,
|
||||
stop_peer_command_sender: mpsc::Sender<(u32, TRoleState::Id)>,
|
||||
|
||||
rate_limit_reset_interval: Mutex<Interval>,
|
||||
peer_refresh_interval: Mutex<Interval>,
|
||||
mtu_refresh_interval: Mutex<Interval>,
|
||||
}
|
||||
|
||||
// TODO: For now we only use these fields with debug
|
||||
@@ -189,7 +193,109 @@ where
|
||||
}
|
||||
|
||||
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
|
||||
self.role_state.lock().poll_next_event(cx)
|
||||
loop {
|
||||
if self
|
||||
.rate_limit_reset_interval
|
||||
.lock()
|
||||
.poll_tick(cx)
|
||||
.is_ready()
|
||||
{
|
||||
self.rate_limiter.reset_count();
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.peer_refresh_interval.lock().poll_tick(cx).is_ready() {
|
||||
let peers_to_refresh = {
|
||||
let mut peers_by_ip = self.peers_by_ip.write();
|
||||
|
||||
peers_to_refresh(&mut peers_by_ip, self.stop_peer_command_sender.clone())
|
||||
};
|
||||
|
||||
for peer in peers_to_refresh {
|
||||
let callbacks = self.callbacks.clone();
|
||||
let mut stop_command_sender = self.stop_peer_command_sender.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = peer.update_timers().await {
|
||||
tracing::error!("Failed to update timers for peer: {e}");
|
||||
let _ = callbacks.on_error(&e);
|
||||
|
||||
if e.is_fatal_connection_error() {
|
||||
let _ = stop_command_sender.send((peer.index, peer.conn_id)).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.mtu_refresh_interval.lock().poll_tick(cx).is_ready() {
|
||||
// We use `try_read` to acquire a lock on the device because we are within a synchronous context here and cannot use `.await`.
|
||||
// The device is only updated during `add_route` and `set_interface` which would be extremely unlucky to hit at the same time as this timer.
|
||||
// Even if we hit this, we just wait for the next tick to update the MTU.
|
||||
let device = match self.device.try_read().map(|d| d.clone()) {
|
||||
Ok(Some(device)) => device,
|
||||
Ok(None) => {
|
||||
let err = Error::ControlProtocolError;
|
||||
tracing::error!(?err, "get_iface_config");
|
||||
let _ = self.callbacks.on_error(&err);
|
||||
continue;
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::debug!("Unlucky! Somebody is updating the device just as we are about to update its MTU, trying again on the next tick ...");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn({
|
||||
let callbacks = self.callbacks.clone();
|
||||
|
||||
async move {
|
||||
if let Err(e) = device.config.refresh_mtu().await {
|
||||
tracing::error!(error = ?e, "refresh_mtu");
|
||||
let _ = callbacks.on_error(&e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {
|
||||
return Poll::Ready(event);
|
||||
}
|
||||
|
||||
if let Poll::Ready(Some((i, conn_id))) =
|
||||
self.stop_peer_command_receiver.lock().poll_next_unpin(cx)
|
||||
{
|
||||
let mut peers = self.peers_by_ip.write();
|
||||
|
||||
let (maybe_network, maybe_peer) = peers
|
||||
.iter()
|
||||
.find_map(|(n, p)| (p.index == i).then_some((n, p.clone())))
|
||||
.unzip();
|
||||
|
||||
if let Some(network) = maybe_network {
|
||||
peers.remove(network);
|
||||
}
|
||||
|
||||
if let Some(conn) = self.peer_connections.lock().remove(&conn_id) {
|
||||
tokio::spawn({
|
||||
let callbacks = self.callbacks.clone();
|
||||
async move {
|
||||
if let Some(peer) = maybe_peer {
|
||||
let _ = peer.shutdown().await;
|
||||
}
|
||||
if let Err(e) = conn.close().await {
|
||||
tracing::warn!(%conn_id, error = ?e, "Can't close peer");
|
||||
let _ = callbacks.on_error(&e.into());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,6 +390,8 @@ where
|
||||
.with_setting_engine(setting_engine)
|
||||
.build();
|
||||
|
||||
let (stop_peer_command_sender, stop_peer_command_receiver) = mpsc::channel(10);
|
||||
|
||||
Ok(Self {
|
||||
rate_limiter,
|
||||
private_key,
|
||||
@@ -296,117 +404,14 @@ where
|
||||
callbacks: CallbackErrorFacade(callbacks),
|
||||
iface_handler_abort,
|
||||
role_state: Default::default(),
|
||||
stop_peer_command_receiver: Mutex::new(stop_peer_command_receiver),
|
||||
stop_peer_command_sender,
|
||||
rate_limit_reset_interval: Mutex::new(rate_limit_reset_interval()),
|
||||
peer_refresh_interval: Mutex::new(peer_refresh_interval()),
|
||||
mtu_refresh_interval: Mutex::new(mtu_refresh_interval()),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
async fn stop_peer(&self, index: u32, conn_id: TRoleState::Id) {
|
||||
self.peers_by_ip.write().retain(|_, p| p.index != index);
|
||||
let conn = self.peer_connections.lock().remove(&conn_id);
|
||||
if let Some(conn) = conn {
|
||||
if let Err(e) = conn.close().await {
|
||||
tracing::warn!(error = ?e, "Can't close peer");
|
||||
let _ = self.callbacks().on_error(&e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn peer_refresh(&self, peer: &Peer<TRoleState::Id>, dst_buf: &mut [u8; MAX_UDP_SIZE]) {
|
||||
let update_timers_result = peer.update_timers(&mut dst_buf[..]);
|
||||
|
||||
match update_timers_result {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(WireGuardError::ConnectionExpired)
|
||||
| TunnResult::Err(WireGuardError::NoCurrentSession) => {
|
||||
self.stop_peer(peer.index, peer.conn_id).await;
|
||||
let _ = peer.shutdown().await;
|
||||
}
|
||||
TunnResult::Err(e) => tracing::error!(error = ?e, "timer_error"),
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
let bytes = Bytes::copy_from_slice(packet);
|
||||
peer.send_infallible(bytes, &self.callbacks).await
|
||||
}
|
||||
|
||||
_ => panic!("Unexpected result from update_timers"),
|
||||
};
|
||||
}
|
||||
|
||||
fn start_rate_limiter_refresh_timer(self: &Arc<Self>) {
|
||||
let rate_limiter = self.rate_limiter.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(RESET_PACKET_COUNT_INTERVAL);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
loop {
|
||||
rate_limiter.reset_count();
|
||||
interval.tick().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn start_peers_refresh_timer(self: &Arc<Self>) {
|
||||
let tunnel = self.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(REFRESH_PEERS_TIMERS_INTERVAL);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
let mut dst_buf = [0u8; MAX_UDP_SIZE];
|
||||
|
||||
loop {
|
||||
remove_expired_peers(
|
||||
&mut tunnel.peers_by_ip.write(),
|
||||
&mut tunnel.peer_connections.lock(),
|
||||
);
|
||||
|
||||
let peers: Vec<_> = tunnel
|
||||
.peers_by_ip
|
||||
.read()
|
||||
.iter()
|
||||
.map(|p| p.1)
|
||||
.unique_by(|p| p.index)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
for peer in peers {
|
||||
tunnel.peer_refresh(&peer, &mut dst_buf).await;
|
||||
}
|
||||
|
||||
interval.tick().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_refresh_mtu_timer(self: &Arc<Self>) -> Result<()> {
|
||||
let dev = self.clone();
|
||||
let callbacks = self.callbacks().clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(REFRESH_MTU_INTERVAL);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let Some(device) = dev.device.read().await.clone() else {
|
||||
let err = Error::ControlProtocolError;
|
||||
tracing::error!(?err, "get_iface_config");
|
||||
let _ = callbacks.0.on_error(&err);
|
||||
continue;
|
||||
};
|
||||
if let Err(e) = device.config.refresh_mtu().await {
|
||||
tracing::error!(error = ?e, "refresh_mtu");
|
||||
let _ = callbacks.0.on_error(&e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_timers(self: &Arc<Self>) -> Result<()> {
|
||||
self.start_refresh_mtu_timer().await?;
|
||||
self.start_rate_limiter_refresh_timer();
|
||||
self.start_peers_refresh_timer();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn next_index(&self) -> u32 {
|
||||
self.next_index.lock().next()
|
||||
}
|
||||
@@ -416,27 +421,68 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs the interval for resetting the rate limit count.
|
||||
///
|
||||
/// As per documentation on [`RateLimiter::reset_count`], this is configured to run every second.
|
||||
fn rate_limit_reset_interval() -> Interval {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
|
||||
interval
|
||||
}
|
||||
|
||||
/// Constructs the interval for "refreshing" peers.
|
||||
///
|
||||
/// On each tick, we remove expired peers from our map, update wireguard timers and send packets, if any.
|
||||
fn peer_refresh_interval() -> Interval {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
|
||||
interval
|
||||
}
|
||||
|
||||
/// Constructs the interval for refreshing the MTU of our TUN device.
|
||||
fn mtu_refresh_interval() -> Interval {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(30));
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
|
||||
interval
|
||||
}
|
||||
|
||||
fn peers_to_refresh<TId>(
|
||||
peers_by_ip: &mut IpNetworkTable<Arc<Peer<TId>>>,
|
||||
shutdown_sender: mpsc::Sender<(u32, TId)>,
|
||||
) -> Vec<Arc<Peer<TId>>>
|
||||
where
|
||||
TId: Eq + Hash + Copy + Send + Sync + 'static,
|
||||
{
|
||||
remove_expired_peers(peers_by_ip, shutdown_sender);
|
||||
|
||||
peers_by_ip
|
||||
.iter()
|
||||
.map(|p| p.1)
|
||||
.unique_by(|p| p.index)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn remove_expired_peers<TId>(
|
||||
peers_by_ip: &mut IpNetworkTable<Arc<Peer<TId>>>,
|
||||
peer_connections: &mut HashMap<TId, Arc<RTCPeerConnection>>,
|
||||
shutdown_sender: mpsc::Sender<(u32, TId)>,
|
||||
) where
|
||||
TId: Eq + Hash + Copy + Send + Sync + 'static,
|
||||
{
|
||||
for (_, peer) in peers_by_ip.iter() {
|
||||
peer.expire_resources();
|
||||
if peer.is_emptied() {
|
||||
tracing::trace!(index = peer.index, "peer_expired");
|
||||
let conn = peer_connections.remove(&peer.conn_id);
|
||||
let p = peer.clone();
|
||||
let index = peer.index;
|
||||
let conn_id = peer.conn_id;
|
||||
|
||||
// We are holding a Mutex, particularly a write one, we don't want to make a blocking call
|
||||
tokio::spawn(async move {
|
||||
let _ = p.shutdown().await;
|
||||
if let Some(conn) = conn {
|
||||
// TODO: it seems that even closing the stream there are messages to the relay
|
||||
// see where they come from.
|
||||
let _ = conn.close().await;
|
||||
}
|
||||
tracing::trace!(%index, "peer_expired");
|
||||
|
||||
tokio::spawn({
|
||||
let mut sender = shutdown_sender.clone();
|
||||
async move { sender.send((index, conn_id)).await }
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -449,7 +495,7 @@ fn remove_expired_peers<TId>(
|
||||
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
|
||||
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
|
||||
pub trait RoleState: Default + Send + 'static {
|
||||
type Id: fmt::Debug + Eq + Hash + Copy + Send + Sync + 'static;
|
||||
type Id: fmt::Debug + fmt::Display + Eq + Hash + Copy + Unpin + Send + Sync + 'static;
|
||||
|
||||
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
|
||||
}
|
||||
|
||||
@@ -22,10 +22,10 @@ type ExpiryingResource = (ResourceDescription, DateTime<Utc>);
|
||||
pub(crate) struct Peer<TId> {
|
||||
pub tunnel: Mutex<Tunn>,
|
||||
pub index: u32,
|
||||
pub allowed_ips: RwLock<IpNetworkTable<()>>,
|
||||
allowed_ips: RwLock<IpNetworkTable<()>>,
|
||||
pub channel: Arc<DataChannel>,
|
||||
pub conn_id: TId,
|
||||
pub resources: Option<RwLock<ResourceTable<ExpiryingResource>>>,
|
||||
resources: Option<RwLock<ResourceTable<ExpiryingResource>>>,
|
||||
// Here we store the address that we obtained for the resource that the peer corresponds to.
|
||||
// This can have the following problem:
|
||||
// 1. Peer sends packet to address.com and it resolves to 1.1.1.1
|
||||
@@ -35,7 +35,7 @@ pub(crate) struct Peer<TId> {
|
||||
// so, TODO: store multiple ips and expire them.
|
||||
// Note that this case is quite an unlikely edge case so I wouldn't prioritize this fix
|
||||
// TODO: Also check if there's any case where we want to talk to ipv4 and ipv6 from the same peer.
|
||||
pub translated_resource_addresses: RwLock<HashMap<IpAddr, ResourceId>>,
|
||||
translated_resource_addresses: RwLock<HashMap<IpAddr, ResourceId>>,
|
||||
}
|
||||
|
||||
// TODO: For now we only use these fields with debug
|
||||
@@ -50,14 +50,6 @@ pub(crate) struct PeerStats<TId> {
|
||||
pub translated_resource_addresses: HashMap<IpAddr, ResourceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct EncapsulatedPacket<'a, Id> {
|
||||
pub index: u32,
|
||||
pub conn_id: Id,
|
||||
pub channel: Arc<DataChannel>,
|
||||
pub encapsulate_result: TunnResult<'a>,
|
||||
}
|
||||
|
||||
impl<TId> Peer<TId>
|
||||
where
|
||||
TId: Copy,
|
||||
@@ -130,7 +122,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_translation(&self, ip: IpAddr) -> Option<ResourceDescription> {
|
||||
fn get_translation(&self, ip: IpAddr) -> Option<ResourceDescription> {
|
||||
let id = self.translated_resource_addresses.read().get(&ip).cloned();
|
||||
self.resources.as_ref().and_then(|resources| {
|
||||
id.and_then(|id| resources.read().get_by_id(&id).map(|r| r.0.clone()))
|
||||
@@ -141,8 +133,25 @@ where
|
||||
self.allowed_ips.write().insert(ip, ());
|
||||
}
|
||||
|
||||
pub(crate) fn update_timers<'a>(&self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
||||
self.tunnel.lock().update_timers(dst)
|
||||
pub(crate) async fn update_timers<'a>(&self) -> Result<()> {
|
||||
/// [`boringtun`] requires us to pass buffers in where it can construct its packets.
|
||||
///
|
||||
/// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`].
|
||||
const MAX_SCRATCH_SPACE: usize = 148;
|
||||
|
||||
let mut buf = [0u8; MAX_SCRATCH_SPACE];
|
||||
|
||||
let packet = match self.tunnel.lock().update_timers(&mut buf) {
|
||||
TunnResult::Done => return Ok(()),
|
||||
TunnResult::Err(e) => return Err(e.into()),
|
||||
TunnResult::WriteToNetwork(b) => b,
|
||||
_ => panic!("Unexpected result from update_timers"),
|
||||
};
|
||||
|
||||
let bytes = Bytes::copy_from_slice(packet);
|
||||
self.channel.write(&bytes).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&self) -> Result<()> {
|
||||
@@ -189,11 +198,13 @@ where
|
||||
self.translated_resource_addresses.write().insert(addr, id);
|
||||
}
|
||||
|
||||
pub(crate) fn encapsulate<'a>(
|
||||
/// Sends the given packet to this peer by encapsulating it in a wireguard packet.
|
||||
pub(crate) async fn send<'a>(
|
||||
&self,
|
||||
packet: &mut MutableIpPacket<'a>,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<EncapsulatedPacket<'a, TId>> {
|
||||
mut packet: MutableIpPacket<'a>,
|
||||
dest: IpAddr,
|
||||
buf: &mut [u8],
|
||||
) -> Result<()> {
|
||||
if let Some(resource) = self.get_translation(packet.to_immutable().source()) {
|
||||
let ResourceDescription::Dns(resource) = resource else {
|
||||
tracing::error!(
|
||||
@@ -209,12 +220,18 @@ where
|
||||
|
||||
packet.update_checksum();
|
||||
}
|
||||
Ok(EncapsulatedPacket {
|
||||
index: self.index,
|
||||
conn_id: self.conn_id,
|
||||
channel: self.channel.clone(),
|
||||
encapsulate_result: self.tunnel.lock().encapsulate(packet.packet_mut(), dst),
|
||||
})
|
||||
let packet = match self.tunnel.lock().encapsulate(packet.packet_mut(), buf) {
|
||||
TunnResult::Done => return Ok(()),
|
||||
TunnResult::Err(e) => return Err(e.into()),
|
||||
TunnResult::WriteToNetwork(b) => b,
|
||||
_ => panic!("Unexpected result from `encapsulate`"),
|
||||
};
|
||||
|
||||
tracing::trace!(target: "wire", action = "writing", from = "iface", to = %dest);
|
||||
|
||||
self.channel.write(&Bytes::copy_from_slice(packet)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn get_packet_resource(
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use std::net::{IpAddr, ToSocketAddrs};
|
||||
use std::sync::Arc;
|
||||
|
||||
use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use bytes::Bytes;
|
||||
use connlib_shared::messages::ResourceDescription;
|
||||
use connlib_shared::{Callbacks, Error, Result};
|
||||
use futures_util::SinkExt;
|
||||
|
||||
use crate::ip_packet::MutableIpPacket;
|
||||
use crate::{
|
||||
device_channel::DeviceIo, index::check_packet_index, peer::Peer, RoleState, Tunnel,
|
||||
MAX_UDP_SIZE,
|
||||
@@ -14,16 +19,95 @@ where
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn is_wireguard_packet_ok(&self, parsed_packet: &Packet, peer: &Peer<TRoleState::Id>) -> bool {
|
||||
match &parsed_packet {
|
||||
Packet::HandshakeInit(p) => {
|
||||
parse_handshake_anon(&self.private_key, &self.public_key, p).is_ok()
|
||||
pub(crate) async fn start_peer_handler(self: Arc<Self>, peer: Arc<Peer<TRoleState::Id>>) {
|
||||
loop {
|
||||
let Some(device) = self.device.read().await.clone() else {
|
||||
let err = Error::NoIface;
|
||||
tracing::error!(?err);
|
||||
let _ = self.callbacks().on_disconnect(Some(&err));
|
||||
break;
|
||||
};
|
||||
let device_io = device.io;
|
||||
|
||||
if let Err(err) = self.peer_handler(&peer, device_io).await {
|
||||
if err.raw_os_error() != Some(9) {
|
||||
tracing::error!(?err);
|
||||
let _ = self.callbacks().on_error(&err.into());
|
||||
break;
|
||||
} else {
|
||||
tracing::warn!("bad_file_descriptor");
|
||||
}
|
||||
}
|
||||
Packet::HandshakeResponse(p) => check_packet_index(p.receiver_idx, peer.index),
|
||||
Packet::PacketCookieReply(p) => check_packet_index(p.receiver_idx, peer.index),
|
||||
Packet::PacketData(p) => check_packet_index(p.receiver_idx, peer.index),
|
||||
}
|
||||
tracing::debug!(peer = ?peer.stats(), "peer_stopped");
|
||||
let _ = self
|
||||
.stop_peer_command_sender
|
||||
.clone()
|
||||
.send((peer.index, peer.conn_id))
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn peer_handler(
|
||||
self: &Arc<Self>,
|
||||
peer: &Arc<Peer<TRoleState::Id>>,
|
||||
device_io: DeviceIo,
|
||||
) -> std::io::Result<()> {
|
||||
let mut src_buf = [0u8; MAX_UDP_SIZE];
|
||||
let mut dst_buf = [0u8; MAX_UDP_SIZE];
|
||||
while let Ok(size) = peer.channel.read(&mut src_buf[..]).await {
|
||||
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
|
||||
|
||||
// TODO: Double check that this can only happen on closed channel
|
||||
// I think it's possible to transmit a 0-byte message through the channel
|
||||
// but we would never use that.
|
||||
// We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
|
||||
if size == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(Error::Io(e)) = self
|
||||
.handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf)
|
||||
.await
|
||||
{
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) async fn handle_peer_packet(
|
||||
self: &Arc<Self>,
|
||||
peer: &Arc<Peer<TRoleState::Id>>,
|
||||
device_writer: &DeviceIo,
|
||||
src: &[u8],
|
||||
dst: &mut [u8],
|
||||
) -> Result<()> {
|
||||
let parsed_packet = self.verify_packet(peer, src, dst).await?;
|
||||
if !is_wireguard_packet_ok(&self.private_key, &self.public_key, &parsed_packet, peer) {
|
||||
tracing::error!("wireguard_verification");
|
||||
return Err(Error::BadPacket);
|
||||
}
|
||||
|
||||
let decapsulate_result = peer.tunnel.lock().decapsulate(None, src, dst);
|
||||
|
||||
if self
|
||||
.handle_decapsulated_packet(peer, device_writer, decapsulate_result)
|
||||
.await?
|
||||
{
|
||||
// Flush pending queue
|
||||
while let TunnResult::WriteToNetwork(packet) =
|
||||
peer.tunnel.lock().decapsulate(None, &[], dst)
|
||||
{
|
||||
let bytes = Bytes::copy_from_slice(packet);
|
||||
let callbacks = self.callbacks.clone();
|
||||
let peer = peer.clone();
|
||||
tokio::spawn(async move { peer.send_infallible(bytes, &callbacks).await });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -80,101 +164,143 @@ where
|
||||
Ok(true)
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||||
self.send_to_resource(device_io, peer, addr.into(), packet)?;
|
||||
send_to_resource(device_io, peer, addr.into(), packet)?;
|
||||
Ok(false)
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||||
self.send_to_resource(device_io, peer, addr.into(), packet)?;
|
||||
send_to_resource(device_io, peer, addr.into(), packet)?;
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) async fn handle_peer_packet(
|
||||
self: &Arc<Self>,
|
||||
peer: &Arc<Peer<TRoleState::Id>>,
|
||||
device_writer: &DeviceIo,
|
||||
src: &[u8],
|
||||
dst: &mut [u8],
|
||||
) -> Result<()> {
|
||||
let parsed_packet = self.verify_packet(peer, src, dst).await?;
|
||||
if !self.is_wireguard_packet_ok(&parsed_packet, peer) {
|
||||
tracing::error!("wireguard_verification");
|
||||
return Err(Error::BadPacket);
|
||||
}
|
||||
|
||||
let decapsulate_result = peer.tunnel.lock().decapsulate(None, src, dst);
|
||||
|
||||
if self
|
||||
.handle_decapsulated_packet(peer, device_writer, decapsulate_result)
|
||||
.await?
|
||||
{
|
||||
// Flush pending queue
|
||||
while let TunnResult::WriteToNetwork(packet) = {
|
||||
let res = peer.tunnel.lock().decapsulate(None, &[], dst);
|
||||
res
|
||||
} {
|
||||
let bytes = Bytes::copy_from_slice(packet);
|
||||
let callbacks = self.callbacks.clone();
|
||||
let peer = peer.clone();
|
||||
tokio::spawn(async move { peer.send_infallible(bytes, &callbacks).await });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn peer_handler(
|
||||
self: &Arc<Self>,
|
||||
peer: &Arc<Peer<TRoleState::Id>>,
|
||||
device_io: DeviceIo,
|
||||
) -> std::io::Result<()> {
|
||||
let mut src_buf = [0u8; MAX_UDP_SIZE];
|
||||
let mut dst_buf = [0u8; MAX_UDP_SIZE];
|
||||
while let Ok(size) = peer.channel.read(&mut src_buf[..]).await {
|
||||
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
|
||||
|
||||
// TODO: Double check that this can only happen on closed channel
|
||||
// I think it's possible to transmit a 0-byte message through the channel
|
||||
// but we would never use that.
|
||||
// We should keep track of an open/closed channel ourselves if we wanted to do it properly then.
|
||||
if size == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(Error::Io(e)) = self
|
||||
.handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf)
|
||||
.await
|
||||
{
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn start_peer_handler(self: Arc<Self>, peer: Arc<Peer<TRoleState::Id>>) {
|
||||
loop {
|
||||
let Some(device) = self.device.read().await.clone() else {
|
||||
let err = Error::NoIface;
|
||||
tracing::error!(?err);
|
||||
let _ = self.callbacks().on_disconnect(Some(&err));
|
||||
break;
|
||||
};
|
||||
let device_io = device.io;
|
||||
|
||||
if let Err(err) = self.peer_handler(&peer, device_io).await {
|
||||
if err.raw_os_error() != Some(9) {
|
||||
tracing::error!(?err);
|
||||
let _ = self.callbacks().on_error(&err.into());
|
||||
break;
|
||||
} else {
|
||||
tracing::warn!("bad_file_descriptor");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!(peer = ?peer.stats(), "peer_stopped");
|
||||
self.stop_peer(peer.index, peer.conn_id).await;
|
||||
#[inline(always)]
|
||||
fn is_wireguard_packet_ok<TId>(
|
||||
private_key: &StaticSecret,
|
||||
public_key: &PublicKey,
|
||||
parsed_packet: &Packet,
|
||||
peer: &Peer<TId>,
|
||||
) -> bool {
|
||||
match parsed_packet {
|
||||
Packet::HandshakeInit(p) => parse_handshake_anon(private_key, public_key, p).is_ok(),
|
||||
Packet::HandshakeResponse(p) => check_packet_index(p.receiver_idx, peer.index),
|
||||
Packet::PacketCookieReply(p) => check_packet_index(p.receiver_idx, peer.index),
|
||||
Packet::PacketData(p) => check_packet_index(p.receiver_idx, peer.index),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_to_resource<TId>(
|
||||
device_io: &DeviceIo,
|
||||
peer: &Arc<Peer<TId>>,
|
||||
addr: IpAddr,
|
||||
packet: &mut [u8],
|
||||
) -> Result<()>
|
||||
where
|
||||
TId: Copy,
|
||||
{
|
||||
if peer.is_allowed(addr) {
|
||||
packet_allowed(device_io, peer, addr, packet)?;
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::warn!(%addr, "Received packet from peer with an unallowed ip");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn packet_allowed<TId>(
|
||||
device_io: &DeviceIo,
|
||||
peer: &Arc<Peer<TId>>,
|
||||
addr: IpAddr,
|
||||
packet: &mut [u8],
|
||||
) -> Result<()>
|
||||
where
|
||||
TId: Copy,
|
||||
{
|
||||
let Some((dst, resource)) = peer.get_packet_resource(packet) else {
|
||||
// If there's no associated resource it means that we are in a client, then the packet comes from a gateway
|
||||
// and we just trust gateways.
|
||||
// In gateways this should never happen.
|
||||
tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len());
|
||||
send_packet(device_io, packet, addr)?;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?;
|
||||
update_packet(packet, dst_addr);
|
||||
send_packet(device_io, packet, addr)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn send_packet(device_io: &DeviceIo, packet: &mut [u8], dst_addr: IpAddr) -> std::io::Result<()> {
|
||||
match dst_addr {
|
||||
IpAddr::V4(_) => device_io.write4(packet)?,
|
||||
IpAddr::V6(_) => device_io.write6(packet)?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn update_packet(packet: &mut [u8], dst_addr: IpAddr) {
|
||||
let Some(mut pkt) = MutableIpPacket::new(packet) else {
|
||||
return;
|
||||
};
|
||||
pkt.set_dst(dst_addr);
|
||||
pkt.update_checksum();
|
||||
}
|
||||
|
||||
fn get_matching_version_ip(addr: &IpAddr, ip: &IpAddr) -> Option<IpAddr> {
|
||||
((addr.is_ipv4() && ip.is_ipv4()) || (addr.is_ipv6() && ip.is_ipv6())).then_some(*ip)
|
||||
}
|
||||
|
||||
fn get_resource_addr_and_port<TId>(
|
||||
peer: &Arc<Peer<TId>>,
|
||||
resource: &ResourceDescription,
|
||||
addr: &IpAddr,
|
||||
dst: &IpAddr,
|
||||
) -> Result<(IpAddr, Option<u16>)>
|
||||
where
|
||||
TId: Copy,
|
||||
{
|
||||
match resource {
|
||||
ResourceDescription::Dns(r) => {
|
||||
let mut address = r.address.split(':');
|
||||
let Some(dst_addr) = address.next() else {
|
||||
tracing::error!("invalid DNS name for resource: {}", r.address);
|
||||
return Err(Error::InvalidResource);
|
||||
};
|
||||
let Ok(mut dst_addr) = (dst_addr, 0).to_socket_addrs() else {
|
||||
tracing::warn!(%addr, "Couldn't resolve name");
|
||||
return Err(Error::InvalidResource);
|
||||
};
|
||||
let Some(dst_addr) = dst_addr.find_map(|d| get_matching_version_ip(addr, &d.ip()))
|
||||
else {
|
||||
tracing::warn!(%addr, "Couldn't resolve name addr");
|
||||
return Err(Error::InvalidResource);
|
||||
};
|
||||
peer.update_translated_resource_address(r.id, dst_addr);
|
||||
Ok((
|
||||
dst_addr,
|
||||
address
|
||||
.next()
|
||||
.map(str::parse::<u16>)
|
||||
.and_then(std::result::Result::ok),
|
||||
))
|
||||
}
|
||||
ResourceDescription::Cidr(r) => {
|
||||
if r.address.contains(*dst) {
|
||||
Ok((
|
||||
get_matching_version_ip(addr, dst).ok_or(Error::InvalidResource)?,
|
||||
None,
|
||||
))
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"client tried to hijack the tunnel for range outside what it's allowed."
|
||||
);
|
||||
Err(Error::InvalidSource)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
use std::{
|
||||
net::{IpAddr, ToSocketAddrs},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel};
|
||||
|
||||
use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result};
|
||||
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn update_packet(&self, packet: &mut [u8], dst_addr: IpAddr) {
|
||||
let Some(mut pkt) = MutableIpPacket::new(packet) else {
|
||||
return;
|
||||
};
|
||||
pkt.set_dst(dst_addr);
|
||||
pkt.update_checksum();
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn send_packet(
|
||||
&self,
|
||||
device_io: &DeviceIo,
|
||||
packet: &mut [u8],
|
||||
dst_addr: IpAddr,
|
||||
) -> std::io::Result<()> {
|
||||
match dst_addr {
|
||||
IpAddr::V4(_) => device_io.write4(packet)?,
|
||||
IpAddr::V6(_) => device_io.write6(packet)?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn packet_allowed(
|
||||
&self,
|
||||
device_io: &DeviceIo,
|
||||
peer: &Arc<Peer<TRoleState::Id>>,
|
||||
addr: IpAddr,
|
||||
packet: &mut [u8],
|
||||
) -> Result<()> {
|
||||
let Some((dst, resource)) = peer.get_packet_resource(packet) else {
|
||||
// If there's no associated resource it means that we are in a client, then the packet comes from a gateway
|
||||
// and we just trust gateways.
|
||||
// In gateways this should never happen.
|
||||
tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len());
|
||||
self.send_packet(device_io, packet, addr)?;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?;
|
||||
self.update_packet(packet, dst_addr);
|
||||
self.send_packet(device_io, packet, addr)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn send_to_resource(
|
||||
&self,
|
||||
device_io: &DeviceIo,
|
||||
peer: &Arc<Peer<TRoleState::Id>>,
|
||||
addr: IpAddr,
|
||||
packet: &mut [u8],
|
||||
) -> Result<()> {
|
||||
if peer.is_allowed(addr) {
|
||||
self.packet_allowed(device_io, peer, addr, packet)?;
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::warn!(%addr, "Received packet from peer with an unallowed ip");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_matching_version_ip(addr: &IpAddr, ip: &IpAddr) -> Option<IpAddr> {
|
||||
((addr.is_ipv4() && ip.is_ipv4()) || (addr.is_ipv6() && ip.is_ipv6())).then_some(*ip)
|
||||
}
|
||||
|
||||
fn get_resource_addr_and_port<TId>(
|
||||
peer: &Arc<Peer<TId>>,
|
||||
resource: &ResourceDescription,
|
||||
addr: &IpAddr,
|
||||
dst: &IpAddr,
|
||||
) -> Result<(IpAddr, Option<u16>)>
|
||||
where
|
||||
TId: Copy,
|
||||
{
|
||||
match resource {
|
||||
ResourceDescription::Dns(r) => {
|
||||
let mut address = r.address.split(':');
|
||||
let Some(dst_addr) = address.next() else {
|
||||
tracing::error!("invalid DNS name for resource: {}", r.address);
|
||||
return Err(Error::InvalidResource);
|
||||
};
|
||||
let Ok(mut dst_addr) = (dst_addr, 0).to_socket_addrs() else {
|
||||
tracing::warn!(%addr, "Couldn't resolve name");
|
||||
return Err(Error::InvalidResource);
|
||||
};
|
||||
let Some(dst_addr) = dst_addr.find_map(|d| get_matching_version_ip(addr, &d.ip()))
|
||||
else {
|
||||
tracing::warn!(%addr, "Couldn't resolve name addr");
|
||||
return Err(Error::InvalidResource);
|
||||
};
|
||||
peer.update_translated_resource_address(r.id, dst_addr);
|
||||
Ok((
|
||||
dst_addr,
|
||||
address
|
||||
.next()
|
||||
.map(str::parse::<u16>)
|
||||
.and_then(std::result::Result::ok),
|
||||
))
|
||||
}
|
||||
ResourceDescription::Cidr(r) => {
|
||||
if r.address.contains(*dst) {
|
||||
Ok((
|
||||
get_matching_version_ip(addr, dst).ok_or(Error::InvalidResource)?,
|
||||
None,
|
||||
))
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"client tried to hijack the tunnel for range outside what it's allowed."
|
||||
);
|
||||
Err(Error::InvalidSource)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user