refactor(connlib): split device handler for clients & gateway (#2301)

This commit is contained in:
Thomas Eizinger
2023-10-12 10:02:31 +11:00
committed by GitHub
parent 1c03cfc80f
commit dbf0e445b0
13 changed files with 686 additions and 574 deletions

View File

@@ -0,0 +1,285 @@
use crate::device_channel::{create_iface, DeviceIo};
use crate::ip_packet::IpPacket;
use crate::{
dns, tokio_util, ConnId, ControlSignal, Device, Event, RoleState, Tunnel,
ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
};
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
use connlib_shared::messages::{GatewayId, Interface as InterfaceConfig, ResourceDescription};
use connlib_shared::{Callbacks, DNS_SENTINEL};
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use ip_network::IpNetwork;
use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
impl<C, CB> Tunnel<C, CB, ClientState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
/// Adds a the given resource to the tunnel.
///
/// Once added, when a packet for the resource is intercepted a new data channel will be created
/// and packets will be wrapped with wireguard and sent through it.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(
self: &Arc<Self>,
resource_description: ResourceDescription,
) -> connlib_shared::Result<()> {
let mut any_valid_route = false;
{
for ip in resource_description.ips() {
if let Err(e) = self.add_route(ip).await {
tracing::warn!(route = %ip, error = ?e, "add_route");
let _ = self.callbacks().on_error(&e);
} else {
any_valid_route = true;
}
}
}
if !any_valid_route {
return Err(Error::InvalidResource);
}
let resource_list = {
let mut resources = self.resources.write();
resources.insert(resource_description);
resources.resource_list()
};
self.callbacks.on_update_resources(resource_list)?;
Ok(())
}
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(
self: &Arc<Self>,
config: &InterfaceConfig,
) -> connlib_shared::Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
self.start_timers().await?;
*self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
&self.callbacks,
device_handler(Arc::clone(self), device),
));
self.add_route(DNS_SENTINEL.into()).await?;
self.callbacks.on_tunnel_ready()?;
tracing::debug!("background_loop_started");
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn add_route(self: &Arc<Self>, route: IpNetwork) -> connlib_shared::Result<()> {
let mut device = self.device.write().await;
if let Some(new_device) = device
.as_ref()
.ok_or(Error::ControlProtocolError)?
.config
.add_route(route, self.callbacks())
.await?
{
*device = Some(new_device.clone());
*self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
&self.callbacks,
device_handler(Arc::clone(self), new_device),
));
}
Ok(())
}
#[inline(always)]
fn connection_intent(self: &Arc<Self>, packet: IpPacket<'_>) {
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
// We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this
if let Some(resource) = self.get_resource(packet.source()) {
// We have awaiting connection to prevent a race condition where
// create_peer_connection hasn't added the thing to peer_connections
// and we are finding another packet to the same address (otherwise we would just use peer_connections here)
let mut awaiting_connection = self.awaiting_connection.lock();
let conn_id = ConnId::from(resource.id());
if awaiting_connection.get(&conn_id).is_none() {
tracing::trace!(
resource_ip = %packet.destination(),
"resource_connection_intent",
);
awaiting_connection.insert(conn_id, Default::default());
let dev = Arc::clone(self);
let mut connected_gateway_ids: Vec<_> = dev
.gateway_awaiting_connection
.lock()
.clone()
.into_keys()
.collect();
connected_gateway_ids
.extend(dev.resources_gateways.lock().values().collect::<Vec<_>>());
tracing::trace!(
gateways = ?connected_gateway_ids,
"connected_gateways"
);
tokio::spawn(async move {
let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
loop {
interval.tick().await;
let reference = {
let mut awaiting_connections = dev.awaiting_connection.lock();
let Some(awaiting_connection) =
awaiting_connections.get_mut(&ConnId::from(resource.id()))
else {
break;
};
if awaiting_connection.response_received {
break;
}
awaiting_connection.total_attemps += 1;
awaiting_connection.total_attemps
};
if let Err(e) = dev
.control_signaler
.signal_connection_to(&resource, &connected_gateway_ids, reference)
.await
{
// Not a deadlock because this is a different task
dev.awaiting_connection.lock().remove(&conn_id);
tracing::error!(error = ?e, "start_resource_connection");
let _ = dev.callbacks.on_error(&e);
}
}
});
}
}
}
}
/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<C, CB>(
tunnel: Arc<Tunnel<C, CB, ClientState>>,
mut device: Device,
) -> Result<(), ConnlibError>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
let device_writer = device.io.clone();
let mut buf = [0u8; MAX_UDP_SIZE];
loop {
let Some(packet) = device.read().await? else {
return Ok(());
};
if let Some(dns_packet) = dns::parse(&tunnel.resources.read(), packet.as_immutable()) {
if let Err(e) = send_dns_packet(&device_writer, dns_packet) {
tracing::error!(err = %e, "failed to send DNS packet");
let _ = tunnel.callbacks.on_error(&e.into());
}
continue;
}
let dest = packet.destination();
let Some(peer) = tunnel.peer_by_ip(dest) else {
tunnel.connection_intent(packet.as_immutable());
continue;
};
if let Err(e) = tunnel
.encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
.await
{
let _ = tunnel.callbacks.on_error(&e);
tracing::error!(err = ?e, "failed to handle packet {e:#}")
}
}
}
fn send_dns_packet(device_writer: &DeviceIo, packet: dns::Packet) -> io::Result<()> {
match packet {
dns::Packet::Ipv4(r) => device_writer.write4(&r[..])?,
dns::Packet::Ipv6(r) => device_writer.write6(&r[..])?,
};
Ok(())
}
/// [`Tunnel`] state specific to clients.
pub struct ClientState {
active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
/// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
waiting_for_sdp_from_gatway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
}
impl ClientState {
pub fn add_waiting_ice_receiver(
&mut self,
id: GatewayId,
receiver: Receiver<RTCIceCandidateInit>,
) {
self.waiting_for_sdp_from_gatway.insert(id, receiver);
}
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else {
return;
};
match self.active_candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!("Too many active ICE candidate receivers at a time")
}
Err(PushError::Replaced(_)) => {
tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
}
}
}
}
impl Default for ClientState {
fn default() -> Self {
Self {
active_candidate_receivers: StreamMap::new(
Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
MAX_CONCURRENT_ICE_GATHERING,
),
waiting_for_sdp_from_gatway: Default::default(),
}
}
}
impl RoleState for ClientState {
type Id = GatewayId;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
}
}
}
}

View File

@@ -22,8 +22,7 @@ use webrtc::{
},
};
use crate::role_state::RoleState;
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, Tunnel};
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, RoleState, Tunnel};
mod client;
mod gateway;

View File

@@ -10,8 +10,7 @@ use webrtc::peer_connection::{
RTCPeerConnection,
};
use crate::role_state::GatewayState;
use crate::{ControlSignal, PeerConfig, Tunnel};
use crate::{ControlSignal, GatewayState, PeerConfig, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<C, CB>(

View File

@@ -9,7 +9,7 @@ use tokio::io::{unix::AsyncFd, Interest};
use tun::{IfaceDevice, IfaceStream};
use crate::Device;
use crate::{Device, MAX_UDP_SIZE};
mod tun;
@@ -65,7 +65,11 @@ impl IfaceConfig {
iface,
mtu: AtomicUsize::new(mtu),
});
Ok(Some(Device { io, config }))
Ok(Some(Device {
io,
config,
buf: Box::new([0u8; MAX_UDP_SIZE]),
}))
}
}
@@ -82,5 +86,9 @@ pub(crate) async fn create_iface(
mtu: AtomicUsize::new(mtu),
});
Ok(Device { config, io })
Ok(Device {
io,
config,
buf: Box::new([0u8; MAX_UDP_SIZE]),
})
}

View File

@@ -1,15 +1,12 @@
use std::{net::IpAddr, sync::Arc};
use crate::{
ip_packet::{to_dns, IpPacket, MutableIpPacket, Version},
ControlSignal, Tunnel,
};
use connlib_shared::{messages::ResourceDescription, Callbacks, DNS_SENTINEL};
use crate::ip_packet::{to_dns, IpPacket, MutableIpPacket, Version};
use crate::resource_table::ResourceTable;
use connlib_shared::{messages::ResourceDescription, DNS_SENTINEL};
use domain::base::{
iana::{Class, Rcode, Rtype},
Dname, Message, MessageBuilder, ParsedDname, ToDname,
};
use pnet_packet::{udp::MutableUdpPacket, MutablePacket, Packet as UdpPacket, PacketSize};
use std::net::IpAddr;
const DNS_TTL: u32 = 300;
const UDP_HEADER_SIZE: usize = 8;
@@ -18,7 +15,7 @@ const REVERSE_DNS_ADDRESS_V4: &str = "in-addr";
const REVERSE_DNS_ADDRESS_V6: &str = "ip6";
#[derive(Debug, Clone)]
pub(crate) enum SendPacket {
pub(crate) enum Packet {
Ipv4(Vec<u8>),
Ipv6(Vec<u8>),
}
@@ -28,152 +25,139 @@ pub(crate) enum SendPacket {
// as we can therefore we won't do it.
//
// See: https://stackoverflow.com/a/55093896
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
fn build_response(
self: &Arc<Self>,
original_buf: &[u8],
mut dns_answer: Vec<u8>,
) -> Option<Vec<u8>> {
let response_len = dns_answer.len();
let original_pkt = IpPacket::new(original_buf)?;
let original_dgm = original_pkt.as_udp()?;
let hdr_len = original_pkt.packet_size() - original_dgm.payload().len();
let mut res_buf = Vec::with_capacity(hdr_len + response_len);
res_buf.extend_from_slice(&original_buf[..hdr_len]);
res_buf.append(&mut dns_answer);
let mut pkt = MutableIpPacket::new(&mut res_buf)?;
let dgm_len = UDP_HEADER_SIZE + response_len;
pkt.set_len(hdr_len + response_len, dgm_len);
pkt.swap_src_dst();
let mut dgm = MutableUdpPacket::new(pkt.payload_mut())?;
dgm.set_length(dgm_len as u16);
dgm.set_source(original_dgm.get_destination());
dgm.set_destination(original_dgm.get_source());
let mut pkt = MutableIpPacket::new(&mut res_buf)?;
let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?);
pkt.as_udp()?.set_checksum(udp_checksum);
pkt.set_ipv4_checksum();
Some(res_buf)
pub(crate) fn parse(
resources: &ResourceTable<ResourceDescription>,
packet: IpPacket<'_>,
) -> Option<Packet> {
let version = packet.version();
if packet.destination() != IpAddr::from(DNS_SENTINEL) {
return None;
}
fn build_dns_with_answer<N>(
self: &Arc<Self>,
message: &Message<[u8]>,
qname: &N,
qtype: Rtype,
resource: &ResourceDescription,
) -> Option<Vec<u8>>
where
N: ToDname + ?Sized,
{
let msg_buf = Vec::with_capacity(message.as_slice().len() * 2);
let msg_builder = MessageBuilder::from_target(msg_buf).expect(
"Developer error: we should be always be able to create a MessageBuilder from a Vec",
);
let mut answer_builder = msg_builder.start_answer(message, Rcode::NoError).ok()?;
match qtype {
Rtype::A => answer_builder
.push((
qname,
Class::In,
DNS_TTL,
domain::rdata::A::from(resource.ipv4()?),
))
.ok()?,
Rtype::Aaaa => answer_builder
.push((
qname,
Class::In,
DNS_TTL,
domain::rdata::Aaaa::from(resource.ipv6()?),
))
.ok()?,
Rtype::Ptr => answer_builder
.push((
qname,
Class::In,
DNS_TTL,
domain::rdata::Ptr::<ParsedDname<_>>::new(
resource.dns_name()?.parse::<Dname<Vec<u8>>>().ok()?.into(),
),
))
.ok()?,
_ => return None,
}
Some(answer_builder.finish())
let datagram = packet.as_udp()?;
let message = to_dns(&datagram)?;
if message.header().qr() {
return None;
}
pub(crate) fn check_for_dns(self: &Arc<Self>, buf: &[u8]) -> Option<SendPacket> {
let packet = IpPacket::new(buf)?;
let version = packet.version();
if packet.destination() != IpAddr::from(DNS_SENTINEL) {
return None;
}
let datagram = packet.as_udp()?;
let message = to_dns(&datagram)?;
if message.header().qr() {
return None;
}
let question = message.first_question()?;
let resource = match question.qtype() {
Rtype::A | Rtype::Aaaa => self
.resources
.read()
.get_by_name(&ToDname::to_cow(question.qname()).to_string())
.cloned(),
Rtype::Ptr => {
let dns_parts = ToDname::to_cow(question.qname()).to_string();
let mut dns_parts = dns_parts.split('.').rev();
if !dns_parts
.next()
.is_some_and(|d| d == REVERSE_DNS_ADDRESS_END)
{
return None;
}
let ip: IpAddr = match dns_parts.next() {
Some(REVERSE_DNS_ADDRESS_V4) => {
let mut ip = [0u8; 4];
for i in ip.iter_mut() {
*i = dns_parts.next()?.parse().ok()?;
}
ip.into()
}
Some(REVERSE_DNS_ADDRESS_V6) => {
let mut ip = [0u8; 16];
for i in ip.iter_mut() {
*i = u8::from_str_radix(
&format!("{}{}", dns_parts.next()?, dns_parts.next()?),
16,
)
.ok()?;
}
ip.into()
}
_ => return None,
};
if dns_parts.next().is_some() {
return None;
}
self.resources.read().get_by_ip(ip).cloned()
let question = message.first_question()?;
let resource = match question.qtype() {
Rtype::A | Rtype::Aaaa => resources
.get_by_name(&ToDname::to_cow(question.qname()).to_string())
.cloned(),
Rtype::Ptr => {
let dns_parts = ToDname::to_cow(question.qname()).to_string();
let mut dns_parts = dns_parts.split('.').rev();
if !dns_parts
.next()
.is_some_and(|d| d == REVERSE_DNS_ADDRESS_END)
{
return None;
}
_ => return None,
};
let response =
self.build_dns_with_answer(message, question.qname(), question.qtype(), &resource?)?;
let response = self.build_response(buf, response);
response.map(|pkt| match version {
Version::Ipv4 => SendPacket::Ipv4(pkt),
Version::Ipv6 => SendPacket::Ipv6(pkt),
})
}
let ip: IpAddr = match dns_parts.next() {
Some(REVERSE_DNS_ADDRESS_V4) => {
let mut ip = [0u8; 4];
for i in ip.iter_mut() {
*i = dns_parts.next()?.parse().ok()?;
}
ip.into()
}
Some(REVERSE_DNS_ADDRESS_V6) => {
let mut ip = [0u8; 16];
for i in ip.iter_mut() {
*i = u8::from_str_radix(
&format!("{}{}", dns_parts.next()?, dns_parts.next()?),
16,
)
.ok()?;
}
ip.into()
}
_ => return None,
};
if dns_parts.next().is_some() {
return None;
}
resources.get_by_ip(ip).cloned()
}
_ => return None,
};
let response = build_dns_with_answer(message, question.qname(), question.qtype(), &resource?)?;
let response = build_response(packet, response);
response.map(|pkt| match version {
Version::Ipv4 => Packet::Ipv4(pkt),
Version::Ipv6 => Packet::Ipv6(pkt),
})
}
fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec<u8>) -> Option<Vec<u8>> {
let response_len = dns_answer.len();
let original_dgm = original_pkt.as_udp()?;
let hdr_len = original_pkt.packet_size() - original_dgm.payload().len();
let mut res_buf = Vec::with_capacity(hdr_len + response_len);
res_buf.extend_from_slice(&original_pkt.packet()[..hdr_len]);
res_buf.append(&mut dns_answer);
let mut pkt = MutableIpPacket::new(&mut res_buf)?;
let dgm_len = UDP_HEADER_SIZE + response_len;
pkt.set_len(hdr_len + response_len, dgm_len);
pkt.swap_src_dst();
let mut dgm = MutableUdpPacket::new(pkt.payload_mut())?;
dgm.set_length(dgm_len as u16);
dgm.set_source(original_dgm.get_destination());
dgm.set_destination(original_dgm.get_source());
let mut pkt = MutableIpPacket::new(&mut res_buf)?;
let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?);
pkt.as_udp()?.set_checksum(udp_checksum);
pkt.set_ipv4_checksum();
Some(res_buf)
}
fn build_dns_with_answer<N>(
message: &Message<[u8]>,
qname: &N,
qtype: Rtype,
resource: &ResourceDescription,
) -> Option<Vec<u8>>
where
N: ToDname + ?Sized,
{
let msg_buf = Vec::with_capacity(message.as_slice().len() * 2);
let msg_builder = MessageBuilder::from_target(msg_buf).expect(
"Developer error: we should be always be able to create a MessageBuilder from a Vec",
);
let mut answer_builder = msg_builder.start_answer(message, Rcode::NoError).ok()?;
match qtype {
Rtype::A => answer_builder
.push((
qname,
Class::In,
DNS_TTL,
domain::rdata::A::from(resource.ipv4()?),
))
.ok()?,
Rtype::Aaaa => answer_builder
.push((
qname,
Class::In,
DNS_TTL,
domain::rdata::Aaaa::from(resource.ipv6()?),
))
.ok()?,
Rtype::Ptr => answer_builder
.push((
qname,
Class::In,
DNS_TTL,
domain::rdata::Ptr::<ParsedDname<_>>::new(
resource.dns_name()?.parse::<Dname<Vec<u8>>>().ok()?.into(),
),
))
.ok()?,
_ => return None,
}
Some(answer_builder.finish())
}

View File

@@ -0,0 +1,120 @@
use crate::device_channel::create_iface;
use crate::{
ControlSignal, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
};
use connlib_shared::error::ConnlibError;
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::Callbacks;
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
impl<C, CB> Tunnel<C, CB, GatewayState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(
self: &Arc<Self>,
config: &InterfaceConfig,
) -> connlib_shared::Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
self.start_timers().await?;
*self.iface_handler_abort.lock() =
Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle());
tracing::debug!("background_loop_started");
Ok(())
}
}
/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<C, CB>(
tunnel: Arc<Tunnel<C, CB, GatewayState>>,
mut device: Device,
) -> Result<(), ConnlibError>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
let mut buf = [0u8; MAX_UDP_SIZE];
loop {
let Some(packet) = device.read().await? else {
// Reading a bad IP packet or otherwise from the device seems bad. Should we restart the tunnel or something?
return Ok(());
};
let dest = packet.destination();
let Some(peer) = tunnel.peer_by_ip(dest) else {
continue;
};
if let Err(e) = tunnel
.encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
.await
{
tracing::error!(err = ?e, "failed to handle packet {e:#}")
}
}
}
/// [`Tunnel`] state specific to gateways.
pub struct GatewayState {
candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,
}
impl GatewayState {
pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver<RTCIceCandidateInit>) {
match self.candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!("Too many active ICE candidate receivers at a time")
}
Err(PushError::Replaced(_)) => {
tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
}
}
}
}
impl Default for GatewayState {
fn default() -> Self {
Self {
candidate_receivers: StreamMap::new(
Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
MAX_CONCURRENT_ICE_GATHERING,
),
}
}
}
impl RoleState for GatewayState {
type Id = ClientId;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
}
}
}
}

View File

@@ -1,18 +1,10 @@
use std::{net::IpAddr, sync::Arc, time::Duration};
use std::{net::IpAddr, sync::Arc};
use boringtun::noise::{errors::WireGuardError, Tunn, TunnResult};
use boringtun::noise::{errors::WireGuardError, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};
use connlib_shared::{Callbacks, Result};
use crate::role_state::RoleState;
use crate::{
device_channel::{DeviceIo, IfaceConfig},
dns,
peer::EncapsulatedPacket,
ConnId, ControlSignal, Tunnel, MAX_UDP_SIZE,
};
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
use crate::{ip_packet::MutableIpPacket, peer::Peer, ControlSignal, RoleState, Tunnel};
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
@@ -21,74 +13,15 @@ where
TRoleState: RoleState,
{
#[inline(always)]
fn connection_intent(self: &Arc<Self>, src: &[u8], dst_addr: &IpAddr) {
// We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this
if let Some(resource) = self.get_resource(src) {
// We have awaiting connection to prevent a race condition where
// create_peer_connection hasn't added the thing to peer_connections
// and we are finding another packet to the same address (otherwise we would just use peer_connections here)
let mut awaiting_connection = self.awaiting_connection.lock();
let conn_id = ConnId::from(resource.id());
if awaiting_connection.get(&conn_id).is_none() {
tracing::trace!(
resource_ip = %dst_addr,
"resource_connection_intent",
);
awaiting_connection.insert(conn_id, Default::default());
let dev = Arc::clone(self);
let mut connected_gateway_ids: Vec<_> = dev
.gateway_awaiting_connection
.lock()
.clone()
.into_keys()
.collect();
connected_gateway_ids
.extend(dev.resources_gateways.lock().values().collect::<Vec<_>>());
tracing::trace!(
gateways = ?connected_gateway_ids,
"connected_gateways"
);
tokio::spawn(async move {
let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
loop {
interval.tick().await;
let reference = {
let mut awaiting_connections = dev.awaiting_connection.lock();
let Some(awaiting_connection) =
awaiting_connections.get_mut(&ConnId::from(resource.id()))
else {
break;
};
if awaiting_connection.response_received {
break;
}
awaiting_connection.total_attemps += 1;
awaiting_connection.total_attemps
};
if let Err(e) = dev
.control_signaler
.signal_connection_to(&resource, &connected_gateway_ids, reference)
.await
{
// Not a deadlock because this is a different task
dev.awaiting_connection.lock().remove(&conn_id);
tracing::error!(error = ?e, "start_resource_connection");
let _ = dev.callbacks.on_error(&e);
}
}
});
}
}
}
#[inline(always)]
async fn handle_encapsulated_packet<'a>(
pub(crate) async fn encapsulate_and_send_to_peer<'a>(
&self,
encapsulated_packet: EncapsulatedPacket<'a>,
mut packet: MutableIpPacket<'_>,
peer: Arc<Peer>,
dst_addr: &IpAddr,
buf: &mut [u8],
) -> Result<()> {
let encapsulated_packet = peer.encapsulate(&mut packet, buf)?;
match encapsulated_packet.encapsulate_result {
TunnResult::Done => Ok(()),
TunnResult::Err(WireGuardError::ConnectionExpired)
@@ -130,72 +63,4 @@ where
_ => panic!("Unexpected result from encapsulate"),
}
}
#[inline(always)]
async fn handle_iface_packet(
self: &Arc<Self>,
device_writer: &DeviceIo,
src: &mut [u8],
dst: &mut [u8],
) -> Result<()> {
if let Some(r) = self.check_for_dns(src) {
match r {
dns::SendPacket::Ipv4(r) => device_writer.write4(&r[..])?,
dns::SendPacket::Ipv6(r) => device_writer.write6(&r[..])?,
};
return Ok(());
}
let dst_addr = match Tunn::dst_address(src) {
Some(addr) => addr,
None => return Err(Error::BadPacket),
};
let encapsulated_packet = {
match self.peers_by_ip.read().longest_match(dst_addr).map(|p| p.1) {
Some(peer) => peer.encapsulate(src, dst)?,
None => {
self.connection_intent(src, &dst_addr);
return Ok(());
}
}
};
self.handle_encapsulated_packet(encapsulated_packet, &dst_addr)
.await
}
#[tracing::instrument(level = "trace", skip(self, iface_config, device_io))]
pub(crate) async fn iface_handler(
self: &Arc<Self>,
iface_config: Arc<IfaceConfig>,
device_io: DeviceIo,
) {
let device_writer = device_io.clone();
let mut src = [0u8; MAX_UDP_SIZE];
let mut dst = [0u8; MAX_UDP_SIZE];
loop {
let res = match device_io.read(&mut src[..iface_config.mtu()]).await {
Ok(res) => res,
Err(e) => {
tracing::error!(err = ?e, "failed to read interface: {e:#}");
let _ = self.callbacks.on_error(&e.into());
break;
}
};
tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");
if res == 0 {
break;
}
if let Err(e) = self
.handle_iface_packet(&device_writer, &mut src[..res], &mut dst)
.await
{
let _ = self.callbacks.on_error(&e);
tracing::error!(err = ?e, "failed to handle packet {e:#}")
}
}
}
}

View File

@@ -32,10 +32,20 @@ macro_rules! swap_src_dst {
impl<'a> MutableIpPacket<'a> {
#[inline]
pub(crate) fn new(data: &mut [u8]) -> Option<MutableIpPacket> {
match data[0] >> 4 {
4 => MutableIpv4Packet::new(data).map(Into::into),
6 => MutableIpv6Packet::new(data).map(Into::into),
_ => None,
let packet = match data[0] >> 4 {
4 => MutableIpv4Packet::new(data)?.into(),
6 => MutableIpv6Packet::new(data)?.into(),
_ => return None,
};
Some(packet)
}
#[inline]
pub(crate) fn destination(&self) -> IpAddr {
match self {
MutableIpPacket::MutableIpv4Packet(i) => i.get_destination().into(),
MutableIpPacket::MutableIpv6Packet(i) => i.get_destination().into(),
}
}
@@ -87,6 +97,13 @@ impl<'a> MutableIpPacket<'a> {
}
}
pub(crate) fn as_immutable(&self) -> IpPacket<'_> {
match self {
Self::MutableIpv4Packet(p) => IpPacket::Ipv4Packet(p.to_immutable()),
Self::MutableIpv6Packet(p) => IpPacket::Ipv6Packet(p.to_immutable()),
}
}
pub(crate) fn as_udp(&mut self) -> Option<MutableUdpPacket> {
self.to_immutable()
.is_udp()
@@ -174,14 +191,6 @@ pub(crate) enum IpPacket<'a> {
}
impl<'a> IpPacket<'a> {
pub(crate) fn new(data: &[u8]) -> Option<IpPacket> {
match data[0] >> 4 {
4 => Ipv4Packet::new(data).map(Into::into),
6 => Ipv6Packet::new(data).map(Into::into),
_ => None,
}
}
pub(crate) fn version(&self) -> Version {
match self {
IpPacket::Ipv4Packet(_) => Version::Ipv4,
@@ -214,13 +223,6 @@ impl<'a> IpPacket<'a> {
.flatten()
}
pub(crate) fn destination(&self) -> IpAddr {
match self {
Self::Ipv4Packet(p) => p.get_destination().into(),
Self::Ipv6Packet(p) => p.get_destination().into(),
}
}
pub(crate) fn source(&self) -> IpAddr {
match self {
Self::Ipv4Packet(p) => p.get_source().into(),
@@ -228,6 +230,13 @@ impl<'a> IpPacket<'a> {
}
}
pub(crate) fn destination(&self) -> IpAddr {
match self {
Self::Ipv4Packet(p) => p.get_destination().into(),
Self::Ipv6Packet(p) => p.get_destination().into(),
}
}
pub(crate) fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 {
match self {
Self::Ipv4Packet(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()),

View File

@@ -3,12 +3,12 @@
//! This is both the wireguard and ICE implementation that should work in tandem.
//! [Tunnel] is the main entry-point for this crate.
use boringtun::{
noise::{errors::WireGuardError, rate_limiter::RateLimiter, Tunn, TunnResult},
noise::{errors::WireGuardError, rate_limiter::RateLimiter, TunnResult},
x25519::{PublicKey, StaticSecret},
};
use bytes::Bytes;
use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error, DNS_SENTINEL};
use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use serde::{Deserialize, Serialize};
@@ -29,7 +29,7 @@ use webrtc::{
};
use std::task::{Context, Poll};
use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use connlib_shared::{
@@ -39,19 +39,22 @@ use connlib_shared::{
Result,
};
use device_channel::{create_iface, DeviceIo, IfaceConfig};
use device_channel::{DeviceIo, IfaceConfig};
pub use client::ClientState;
pub use control_protocol::Request;
pub use role_state::{ClientState, GatewayState};
pub use gateway::GatewayState;
pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::role_state::RoleState;
use crate::ip_packet::MutableIpPacket;
use connlib_shared::messages::SecretKey;
use index::IndexLfsr;
mod client;
mod control_protocol;
mod device_channel;
mod dns;
mod gateway;
mod iface_handler;
mod index;
mod ip_packet;
@@ -59,12 +62,23 @@ mod peer;
mod peer_handler;
mod resource_sender;
mod resource_table;
mod role_state;
mod tokio_util;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_MTU_INTERVAL: Duration = Duration::from_secs(30);
/// For how long we will attempt to gather ICE candidates before aborting.
///
/// Chosen arbitrarily.
/// Very likely, the actual WebRTC connection will timeout before this.
/// This timeout is just here to eventually clean-up tasks if they are somehow broken.
const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60;
/// How many concurrent ICE gathering attempts we are allow.
///
/// Chosen arbitrarily.
const MAX_CONCURRENT_ICE_GATHERING: usize = 100;
// Note: Taken from boringtun
const HANDSHAKE_RATE_LIMIT: u64 = 100;
@@ -149,8 +163,30 @@ struct AwaitingConnectionDetails {
#[derive(Clone)]
struct Device {
pub config: Arc<IfaceConfig>,
pub io: DeviceIo,
config: Arc<IfaceConfig>,
io: DeviceIo,
buf: Box<[u8; MAX_UDP_SIZE]>,
}
impl Device {
async fn read(&mut self) -> io::Result<Option<MutableIpPacket<'_>>> {
let res = self.io.read(&mut self.buf[..self.config.mtu()]).await?;
tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");
if res == 0 {
return Ok(None);
}
Ok(Some(
MutableIpPacket::new(&mut self.buf[..res]).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"received bytes are not an IP packet",
)
})?,
))
}
}
// TODO: We should use newtypes for each kind of Id
@@ -243,6 +279,14 @@ where
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
self.role_state.lock().poll_next_event(cx)
}
pub(crate) fn peer_by_ip(&self, ip: IpAddr) -> Option<Arc<Peer>> {
self.peers_by_ip
.read()
.longest_match(ip)
.map(|(_, peer)| peer)
.cloned()
}
}
pub enum Event<TId> {
@@ -327,86 +371,6 @@ where
})
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_route(self: &Arc<Self>, route: IpNetwork) -> Result<()> {
let mut device = self.device.write().await;
if let Some(new_device) = device
.as_ref()
.ok_or(Error::ControlProtocolError)?
.config
.add_route(route, self.callbacks())
.await?
{
*device = Some(new_device.clone());
let dev = Arc::clone(self);
self.iface_handler_abort.lock().replace(
tokio::spawn(
async move { dev.iface_handler(new_device.config, new_device.io).await },
)
.abort_handle(),
);
}
Ok(())
}
/// Adds a the given resource to the tunnel.
///
/// Once added, when a packet for the resource is intercepted a new data channel will be created
/// and packets will be wrapped with wireguard and sent through it.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(
self: &Arc<Self>,
resource_description: ResourceDescription,
) -> Result<()> {
let mut any_valid_route = false;
{
for ip in resource_description.ips() {
if let Err(e) = self.add_route(ip).await {
tracing::warn!(route = %ip, error = ?e, "add_route");
let _ = self.callbacks().on_error(&e);
} else {
any_valid_route = true;
}
}
}
if !any_valid_route {
return Err(Error::InvalidResource);
}
let resource_list = {
let mut resources = self.resources.write();
resources.insert(resource_description);
resources.resource_list()
};
self.callbacks.on_update_resources(resource_list)?;
Ok(())
}
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(self: &Arc<Self>, config: &InterfaceConfig) -> Result<()> {
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
self.start_timers().await?;
let dev = Arc::clone(self);
*self.iface_handler_abort.lock() = Some(
tokio::spawn(async move { dev.iface_handler(device.config, device.io).await })
.abort_handle(),
);
self.add_route(DNS_SENTINEL.into()).await?;
self.callbacks.on_tunnel_ready()?;
tracing::debug!("background_loop_started");
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn stop_peer(&self, index: u32, conn_id: ConnId) {
self.peers_by_ip.write().retain(|_, p| p.index != index);
@@ -537,8 +501,7 @@ where
Ok(())
}
fn get_resource(&self, buff: &[u8]) -> Option<ResourceDescription> {
let addr = Tunn::dst_address(buff)?;
fn get_resource(&self, addr: IpAddr) -> Option<ResourceDescription> {
let resources = self.resources.read();
match addr {
IpAddr::V4(ipv4) => resources.get_by_ip(ipv4).cloned(),
@@ -554,3 +517,13 @@ where
&self.callbacks
}
}
/// Dedicated trait for abstracting over the different ICE states.
///
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
pub trait RoleState: Default + Send + 'static {
type Id: fmt::Debug;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}

View File

@@ -10,6 +10,7 @@ use connlib_shared::{
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use parking_lot::{Mutex, RwLock};
use pnet_packet::MutablePacket;
use webrtc::data::data_channel::DataChannel;
use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId};
@@ -194,14 +195,9 @@ impl Peer {
pub(crate) fn encapsulate<'a>(
&self,
src: &'a mut [u8],
packet: &mut MutableIpPacket<'a>,
dst: &'a mut [u8],
) -> Result<EncapsulatedPacket<'a>> {
let Some(mut packet) = MutableIpPacket::new(src) else {
debug_assert!(false, "Got non-ip packet from the tunnel interface");
tracing::error!("Developer error: we should never see a packet through the tunnel wire that isn't ip");
return Err(Error::BadPacket);
};
if let Some(resource) = self.get_translation(packet.to_immutable().source()) {
let ResourceDescription::Dns(resource) = resource else {
tracing::error!(
@@ -210,7 +206,7 @@ impl Peer {
return Err(Error::ControlProtocolError);
};
match &mut packet {
match packet {
MutableIpPacket::MutableIpv4Packet(ref mut p) => p.set_source(resource.ipv4),
MutableIpPacket::MutableIpv6Packet(ref mut p) => p.set_source(resource.ipv6),
}
@@ -221,7 +217,7 @@ impl Peer {
index: self.index,
conn_id: self.conn_id,
channel: self.channel.clone(),
encapsulate_result: self.tunnel.lock().encapsulate(src, dst),
encapsulate_result: self.tunnel.lock().encapsulate(packet.packet_mut(), dst),
})
}

View File

@@ -4,10 +4,9 @@ use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};
use crate::role_state::RoleState;
use crate::{
device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, Tunnel,
MAX_UDP_SIZE,
device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, RoleState,
Tunnel, MAX_UDP_SIZE,
};
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>

View File

@@ -1,148 +0,0 @@
use crate::Event;
use connlib_shared::messages::{ClientId, GatewayId};
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use std::collections::HashMap;
use std::fmt;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
/// Dedicated trait for abstracting over the different ICE states.
///
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
pub trait RoleState: Default + Send + 'static {
type Id: fmt::Debug;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}
/// For how long we will attempt to gather ICE candidates before aborting.
///
/// Chosen arbitrarily.
/// Very likely, the actual WebRTC connection will timeout before this.
/// This timeout is just here to eventually clean-up tasks if they are somehow broken.
const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60;
/// How many concurrent ICE gathering attempts we are allow.
///
/// Chosen arbitrarily.
const MAX_CONCURRENT_ICE_GATHERING: usize = 100;
/// [`Tunnel`](crate::Tunnel) state specific to clients.
pub struct ClientState {
active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
/// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
waiting_for_sdp_from_gatway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
}
impl ClientState {
pub fn add_waiting_ice_receiver(
&mut self,
id: GatewayId,
receiver: Receiver<RTCIceCandidateInit>,
) {
self.waiting_for_sdp_from_gatway.insert(id, receiver);
}
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else {
return;
};
match self.active_candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!("Too many active ICE candidate receivers at a time")
}
Err(PushError::Replaced(_)) => {
tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
}
}
}
}
impl Default for ClientState {
fn default() -> Self {
Self {
active_candidate_receivers: StreamMap::new(
Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
MAX_CONCURRENT_ICE_GATHERING,
),
waiting_for_sdp_from_gatway: Default::default(),
}
}
}
impl RoleState for ClientState {
type Id = GatewayId;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
}
}
}
}
/// [`Tunnel`](crate::Tunnel) state specific to gateways.
pub struct GatewayState {
candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,
}
impl GatewayState {
pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver<RTCIceCandidateInit>) {
match self.candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!("Too many active ICE candidate receivers at a time")
}
Err(PushError::Replaced(_)) => {
tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
}
}
}
}
impl Default for GatewayState {
fn default() -> Self {
Self {
candidate_receivers: StreamMap::new(
Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
MAX_CONCURRENT_ICE_GATHERING,
),
}
}
}
impl RoleState for GatewayState {
type Id = ClientId;
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
}
}
}
}

View File

@@ -0,0 +1,23 @@
use connlib_shared::error::ConnlibError;
use connlib_shared::Callbacks;
use std::future::Future;
/// Spawns a task into the [`tokio`] runtime.
///
/// On error, [`Callbacks::on_error`] is invoked.
/// This also returns a [`tokio::task::AbortHandle`] which MAY be used to abort the task.
/// If you don't need it, you are free to drop it.
/// It won't terminate the task.
pub(crate) fn spawn_log(
cb: &(impl Callbacks + 'static),
f: impl Future<Output = Result<(), ConnlibError>> + Send + 'static,
) -> tokio::task::AbortHandle {
let cb = cb.clone();
tokio::spawn(async move {
if let Err(e) = f.await {
let _ = cb.on_error(&e);
}
})
.abort_handle()
}