Mirror of https://github.com/outbackdingo/firezone.git (synced 2026-01-27 18:18:55 +00:00)
refactor(connlib): split device handler for clients & gateway (#2301)
rust/connlib/tunnel/src/client.rs (new file, 285 lines)
@@ -0,0 +1,285 @@
use crate::device_channel::{create_iface, DeviceIo};
use crate::ip_packet::IpPacket;
use crate::{
    dns, tokio_util, ConnId, ControlSignal, Device, Event, RoleState, Tunnel,
    ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
};
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
use connlib_shared::messages::{GatewayId, Interface as InterfaceConfig, ResourceDescription};
use connlib_shared::{Callbacks, DNS_SENTINEL};
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use ip_network::IpNetwork;
use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;

impl<C, CB> Tunnel<C, CB, ClientState>
where
    C: ControlSignal + Send + Sync + 'static,
    CB: Callbacks + 'static,
{
    /// Adds a the given resource to the tunnel.
    ///
    /// Once added, when a packet for the resource is intercepted a new data channel will be created
    /// and packets will be wrapped with wireguard and sent through it.
    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn add_resource(
        self: &Arc<Self>,
        resource_description: ResourceDescription,
    ) -> connlib_shared::Result<()> {
        let mut any_valid_route = false;
        {
            for ip in resource_description.ips() {
                if let Err(e) = self.add_route(ip).await {
                    tracing::warn!(route = %ip, error = ?e, "add_route");
                    let _ = self.callbacks().on_error(&e);
                } else {
                    any_valid_route = true;
                }
            }
        }
        if !any_valid_route {
            return Err(Error::InvalidResource);
        }

        let resource_list = {
            let mut resources = self.resources.write();
            resources.insert(resource_description);
            resources.resource_list()
        };

        self.callbacks.on_update_resources(resource_list)?;
        Ok(())
    }

    /// Sets the interface configuration and starts background tasks.
    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn set_interface(
        self: &Arc<Self>,
        config: &InterfaceConfig,
    ) -> connlib_shared::Result<()> {
        let device = create_iface(config, self.callbacks()).await?;
        *self.device.write().await = Some(device.clone());

        self.start_timers().await?;
        *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
            &self.callbacks,
            device_handler(Arc::clone(self), device),
        ));

        self.add_route(DNS_SENTINEL.into()).await?;

        self.callbacks.on_tunnel_ready()?;

        tracing::debug!("background_loop_started");

        Ok(())
    }

    #[tracing::instrument(level = "trace", skip(self))]
    async fn add_route(self: &Arc<Self>, route: IpNetwork) -> connlib_shared::Result<()> {
        let mut device = self.device.write().await;

        if let Some(new_device) = device
            .as_ref()
            .ok_or(Error::ControlProtocolError)?
            .config
            .add_route(route, self.callbacks())
            .await?
        {
            *device = Some(new_device.clone());
            *self.iface_handler_abort.lock() = Some(tokio_util::spawn_log(
                &self.callbacks,
                device_handler(Arc::clone(self), new_device),
            ));
        }

        Ok(())
    }

    #[inline(always)]
    fn connection_intent(self: &Arc<Self>, packet: IpPacket<'_>) {
        const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);

        // We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this
        if let Some(resource) = self.get_resource(packet.source()) {
            // We have awaiting connection to prevent a race condition where
            // create_peer_connection hasn't added the thing to peer_connections
            // and we are finding another packet to the same address (otherwise we would just use peer_connections here)
            let mut awaiting_connection = self.awaiting_connection.lock();
            let conn_id = ConnId::from(resource.id());
            if awaiting_connection.get(&conn_id).is_none() {
                tracing::trace!(
                    resource_ip = %packet.destination(),
                    "resource_connection_intent",
                );

                awaiting_connection.insert(conn_id, Default::default());
                let dev = Arc::clone(self);

                let mut connected_gateway_ids: Vec<_> = dev
                    .gateway_awaiting_connection
                    .lock()
                    .clone()
                    .into_keys()
                    .collect();
                connected_gateway_ids
                    .extend(dev.resources_gateways.lock().values().collect::<Vec<_>>());
                tracing::trace!(
                    gateways = ?connected_gateway_ids,
                    "connected_gateways"
                );
                tokio::spawn(async move {
                    let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
                    loop {
                        interval.tick().await;
                        let reference = {
                            let mut awaiting_connections = dev.awaiting_connection.lock();
                            let Some(awaiting_connection) =
                                awaiting_connections.get_mut(&ConnId::from(resource.id()))
                            else {
                                break;
                            };
                            if awaiting_connection.response_received {
                                break;
                            }
                            awaiting_connection.total_attemps += 1;
                            awaiting_connection.total_attemps
                        };
                        if let Err(e) = dev
                            .control_signaler
                            .signal_connection_to(&resource, &connected_gateway_ids, reference)
                            .await
                        {
                            // Not a deadlock because this is a different task
                            dev.awaiting_connection.lock().remove(&conn_id);
                            tracing::error!(error = ?e, "start_resource_connection");
                            let _ = dev.callbacks.on_error(&e);
                        }
                    }
                });
            }
        }
    }
}

/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<C, CB>(
    tunnel: Arc<Tunnel<C, CB, ClientState>>,
    mut device: Device,
) -> Result<(), ConnlibError>
where
    C: ControlSignal + Send + Sync + 'static,
    CB: Callbacks + 'static,
{
    let device_writer = device.io.clone();
    let mut buf = [0u8; MAX_UDP_SIZE];
    loop {
        let Some(packet) = device.read().await? else {
            return Ok(());
        };

        if let Some(dns_packet) = dns::parse(&tunnel.resources.read(), packet.as_immutable()) {
            if let Err(e) = send_dns_packet(&device_writer, dns_packet) {
                tracing::error!(err = %e, "failed to send DNS packet");
                let _ = tunnel.callbacks.on_error(&e.into());
            }

            continue;
        }

        let dest = packet.destination();

        let Some(peer) = tunnel.peer_by_ip(dest) else {
            tunnel.connection_intent(packet.as_immutable());
            continue;
        };

        if let Err(e) = tunnel
            .encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
            .await
        {
            let _ = tunnel.callbacks.on_error(&e);
            tracing::error!(err = ?e, "failed to handle packet {e:#}")
        }
    }
}

fn send_dns_packet(device_writer: &DeviceIo, packet: dns::Packet) -> io::Result<()> {
    match packet {
        dns::Packet::Ipv4(r) => device_writer.write4(&r[..])?,
        dns::Packet::Ipv6(r) => device_writer.write6(&r[..])?,
    };

    Ok(())
}

/// [`Tunnel`] state specific to clients.
pub struct ClientState {
    active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
    /// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
    waiting_for_sdp_from_gatway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
}

impl ClientState {
    pub fn add_waiting_ice_receiver(
        &mut self,
        id: GatewayId,
        receiver: Receiver<RTCIceCandidateInit>,
    ) {
        self.waiting_for_sdp_from_gatway.insert(id, receiver);
    }

    pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
        let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else {
            return;
        };

        match self.active_candidate_receivers.try_push(id, receiver) {
            Ok(()) => {}
            Err(PushError::BeyondCapacity(_)) => {
                tracing::warn!("Too many active ICE candidate receivers at a time")
            }
            Err(PushError::Replaced(_)) => {
                tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
            }
        }
    }
}

impl Default for ClientState {
    fn default() -> Self {
        Self {
            active_candidate_receivers: StreamMap::new(
                Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
                MAX_CONCURRENT_ICE_GATHERING,
            ),
            waiting_for_sdp_from_gatway: Default::default(),
        }
    }
}

impl RoleState for ClientState {
    type Id = GatewayId;

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
        loop {
            match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
                (conn_id, Some(Ok(c))) => {
                    return Poll::Ready(Event::SignalIceCandidate {
                        conn_id,
                        candidate: c,
                    })
                }
                (id, Some(Err(e))) => {
                    tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
                }
                (_, None) => {}
            }
        }
    }
}
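The two-phase handling in ClientState above (buffer the receiver first, forward candidates once the gateway's SDP arrives) is driven from the control protocol; a minimal sketch, not part of this commit, where `gateway_id` and `ice_rx` are hypothetical values obtained while setting up a peer connection:

    // Sketch only: assumes a GatewayId `gateway_id` and a Receiver<RTCIceCandidateInit> `ice_rx`.
    let mut state = ClientState::default();

    // Phase 1: hold the receiver until the gateway's SDP answer has been received.
    state.add_waiting_ice_receiver(gateway_id, ice_rx);

    // Phase 2: after the SDP answer arrives, candidates are drained via poll_next_event.
    state.activate_ice_candidate_receiver(gateway_id);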
@@ -22,8 +22,7 @@ use webrtc::{
    },
};

use crate::role_state::RoleState;
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, Tunnel};
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, RoleState, Tunnel};

mod client;
mod gateway;
@@ -10,8 +10,7 @@ use webrtc::peer_connection::{
    RTCPeerConnection,
};

use crate::role_state::GatewayState;
use crate::{ControlSignal, PeerConfig, Tunnel};
use crate::{ControlSignal, GatewayState, PeerConfig, Tunnel};

#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<C, CB>(
@@ -9,7 +9,7 @@ use tokio::io::{unix::AsyncFd, Interest};

use tun::{IfaceDevice, IfaceStream};

use crate::Device;
use crate::{Device, MAX_UDP_SIZE};

mod tun;
@@ -65,7 +65,11 @@ impl IfaceConfig {
            iface,
            mtu: AtomicUsize::new(mtu),
        });
        Ok(Some(Device { io, config }))
        Ok(Some(Device {
            io,
            config,
            buf: Box::new([0u8; MAX_UDP_SIZE]),
        }))
    }
}
@@ -82,5 +86,9 @@ pub(crate) async fn create_iface(
        mtu: AtomicUsize::new(mtu),
    });

    Ok(Device { config, io })
    Ok(Device {
        io,
        config,
        buf: Box::new([0u8; MAX_UDP_SIZE]),
    })
}
@@ -1,15 +1,12 @@
use std::{net::IpAddr, sync::Arc};

use crate::{
    ip_packet::{to_dns, IpPacket, MutableIpPacket, Version},
    ControlSignal, Tunnel,
};
use connlib_shared::{messages::ResourceDescription, Callbacks, DNS_SENTINEL};
use crate::ip_packet::{to_dns, IpPacket, MutableIpPacket, Version};
use crate::resource_table::ResourceTable;
use connlib_shared::{messages::ResourceDescription, DNS_SENTINEL};
use domain::base::{
    iana::{Class, Rcode, Rtype},
    Dname, Message, MessageBuilder, ParsedDname, ToDname,
};
use pnet_packet::{udp::MutableUdpPacket, MutablePacket, Packet as UdpPacket, PacketSize};
use std::net::IpAddr;

const DNS_TTL: u32 = 300;
const UDP_HEADER_SIZE: usize = 8;
@@ -18,7 +15,7 @@ const REVERSE_DNS_ADDRESS_V4: &str = "in-addr";
const REVERSE_DNS_ADDRESS_V6: &str = "ip6";

#[derive(Debug, Clone)]
pub(crate) enum SendPacket {
pub(crate) enum Packet {
    Ipv4(Vec<u8>),
    Ipv6(Vec<u8>),
}
@@ -28,152 +25,139 @@ pub(crate) enum SendPacket {
// as we can therefore we won't do it.
//
// See: https://stackoverflow.com/a/55093896
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
    C: ControlSignal + Send + Sync + 'static,
    CB: Callbacks + 'static,
{
    fn build_response(
        self: &Arc<Self>,
        original_buf: &[u8],
        mut dns_answer: Vec<u8>,
    ) -> Option<Vec<u8>> {
        let response_len = dns_answer.len();
        let original_pkt = IpPacket::new(original_buf)?;
        let original_dgm = original_pkt.as_udp()?;
        let hdr_len = original_pkt.packet_size() - original_dgm.payload().len();
        let mut res_buf = Vec::with_capacity(hdr_len + response_len);

        res_buf.extend_from_slice(&original_buf[..hdr_len]);
        res_buf.append(&mut dns_answer);

        let mut pkt = MutableIpPacket::new(&mut res_buf)?;
        let dgm_len = UDP_HEADER_SIZE + response_len;
        pkt.set_len(hdr_len + response_len, dgm_len);
        pkt.swap_src_dst();

        let mut dgm = MutableUdpPacket::new(pkt.payload_mut())?;
        dgm.set_length(dgm_len as u16);
        dgm.set_source(original_dgm.get_destination());
        dgm.set_destination(original_dgm.get_source());

        let mut pkt = MutableIpPacket::new(&mut res_buf)?;
        let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?);
        pkt.as_udp()?.set_checksum(udp_checksum);
        pkt.set_ipv4_checksum();
        Some(res_buf)
pub(crate) fn parse(
    resources: &ResourceTable<ResourceDescription>,
    packet: IpPacket<'_>,
) -> Option<Packet> {
    let version = packet.version();
    if packet.destination() != IpAddr::from(DNS_SENTINEL) {
        return None;
    }

    fn build_dns_with_answer<N>(
        self: &Arc<Self>,
        message: &Message<[u8]>,
        qname: &N,
        qtype: Rtype,
        resource: &ResourceDescription,
    ) -> Option<Vec<u8>>
    where
        N: ToDname + ?Sized,
    {
        let msg_buf = Vec::with_capacity(message.as_slice().len() * 2);
        let msg_builder = MessageBuilder::from_target(msg_buf).expect(
            "Developer error: we should be always be able to create a MessageBuilder from a Vec",
        );
        let mut answer_builder = msg_builder.start_answer(message, Rcode::NoError).ok()?;
        match qtype {
            Rtype::A => answer_builder
                .push((
                    qname,
                    Class::In,
                    DNS_TTL,
                    domain::rdata::A::from(resource.ipv4()?),
                ))
                .ok()?,
            Rtype::Aaaa => answer_builder
                .push((
                    qname,
                    Class::In,
                    DNS_TTL,
                    domain::rdata::Aaaa::from(resource.ipv6()?),
                ))
                .ok()?,
            Rtype::Ptr => answer_builder
                .push((
                    qname,
                    Class::In,
                    DNS_TTL,
                    domain::rdata::Ptr::<ParsedDname<_>>::new(
                        resource.dns_name()?.parse::<Dname<Vec<u8>>>().ok()?.into(),
                    ),
                ))
                .ok()?,
            _ => return None,
        }
        Some(answer_builder.finish())
    let datagram = packet.as_udp()?;
    let message = to_dns(&datagram)?;
    if message.header().qr() {
        return None;
    }

    pub(crate) fn check_for_dns(self: &Arc<Self>, buf: &[u8]) -> Option<SendPacket> {
        let packet = IpPacket::new(buf)?;
        let version = packet.version();
        if packet.destination() != IpAddr::from(DNS_SENTINEL) {
            return None;
        }
        let datagram = packet.as_udp()?;
        let message = to_dns(&datagram)?;
        if message.header().qr() {
            return None;
        }
        let question = message.first_question()?;
        let resource = match question.qtype() {
            Rtype::A | Rtype::Aaaa => self
                .resources
                .read()
                .get_by_name(&ToDname::to_cow(question.qname()).to_string())
                .cloned(),
            Rtype::Ptr => {
                let dns_parts = ToDname::to_cow(question.qname()).to_string();
                let mut dns_parts = dns_parts.split('.').rev();
                if !dns_parts
                    .next()
                    .is_some_and(|d| d == REVERSE_DNS_ADDRESS_END)
                {
                    return None;
                }
                let ip: IpAddr = match dns_parts.next() {
                    Some(REVERSE_DNS_ADDRESS_V4) => {
                        let mut ip = [0u8; 4];
                        for i in ip.iter_mut() {
                            *i = dns_parts.next()?.parse().ok()?;
                        }
                        ip.into()
                    }
                    Some(REVERSE_DNS_ADDRESS_V6) => {
                        let mut ip = [0u8; 16];
                        for i in ip.iter_mut() {
                            *i = u8::from_str_radix(
                                &format!("{}{}", dns_parts.next()?, dns_parts.next()?),
                                16,
                            )
                            .ok()?;
                        }
                        ip.into()
                    }
                    _ => return None,
                };

                if dns_parts.next().is_some() {
                    return None;
                }

                self.resources.read().get_by_ip(ip).cloned()
    let question = message.first_question()?;
    let resource = match question.qtype() {
        Rtype::A | Rtype::Aaaa => resources
            .get_by_name(&ToDname::to_cow(question.qname()).to_string())
            .cloned(),
        Rtype::Ptr => {
            let dns_parts = ToDname::to_cow(question.qname()).to_string();
            let mut dns_parts = dns_parts.split('.').rev();
            if !dns_parts
                .next()
                .is_some_and(|d| d == REVERSE_DNS_ADDRESS_END)
            {
                return None;
            }
            _ => return None,
        };
        let response =
            self.build_dns_with_answer(message, question.qname(), question.qtype(), &resource?)?;
        let response = self.build_response(buf, response);
        response.map(|pkt| match version {
            Version::Ipv4 => SendPacket::Ipv4(pkt),
            Version::Ipv6 => SendPacket::Ipv6(pkt),
        })
    }
            let ip: IpAddr = match dns_parts.next() {
                Some(REVERSE_DNS_ADDRESS_V4) => {
                    let mut ip = [0u8; 4];
                    for i in ip.iter_mut() {
                        *i = dns_parts.next()?.parse().ok()?;
                    }
                    ip.into()
                }
                Some(REVERSE_DNS_ADDRESS_V6) => {
                    let mut ip = [0u8; 16];
                    for i in ip.iter_mut() {
                        *i = u8::from_str_radix(
                            &format!("{}{}", dns_parts.next()?, dns_parts.next()?),
                            16,
                        )
                        .ok()?;
                    }
                    ip.into()
                }
                _ => return None,
            };

            if dns_parts.next().is_some() {
                return None;
            }

            resources.get_by_ip(ip).cloned()
        }
        _ => return None,
    };
    let response = build_dns_with_answer(message, question.qname(), question.qtype(), &resource?)?;
    let response = build_response(packet, response);
    response.map(|pkt| match version {
        Version::Ipv4 => Packet::Ipv4(pkt),
        Version::Ipv6 => Packet::Ipv6(pkt),
    })
}

fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec<u8>) -> Option<Vec<u8>> {
    let response_len = dns_answer.len();
    let original_dgm = original_pkt.as_udp()?;
    let hdr_len = original_pkt.packet_size() - original_dgm.payload().len();
    let mut res_buf = Vec::with_capacity(hdr_len + response_len);

    res_buf.extend_from_slice(&original_pkt.packet()[..hdr_len]);
    res_buf.append(&mut dns_answer);

    let mut pkt = MutableIpPacket::new(&mut res_buf)?;
    let dgm_len = UDP_HEADER_SIZE + response_len;
    pkt.set_len(hdr_len + response_len, dgm_len);
    pkt.swap_src_dst();

    let mut dgm = MutableUdpPacket::new(pkt.payload_mut())?;
    dgm.set_length(dgm_len as u16);
    dgm.set_source(original_dgm.get_destination());
    dgm.set_destination(original_dgm.get_source());

    let mut pkt = MutableIpPacket::new(&mut res_buf)?;
    let udp_checksum = pkt.to_immutable().udp_checksum(&pkt.as_immutable_udp()?);
    pkt.as_udp()?.set_checksum(udp_checksum);
    pkt.set_ipv4_checksum();
    Some(res_buf)
}

fn build_dns_with_answer<N>(
    message: &Message<[u8]>,
    qname: &N,
    qtype: Rtype,
    resource: &ResourceDescription,
) -> Option<Vec<u8>>
where
    N: ToDname + ?Sized,
{
    let msg_buf = Vec::with_capacity(message.as_slice().len() * 2);
    let msg_builder = MessageBuilder::from_target(msg_buf).expect(
        "Developer error: we should be always be able to create a MessageBuilder from a Vec",
    );
    let mut answer_builder = msg_builder.start_answer(message, Rcode::NoError).ok()?;
    match qtype {
        Rtype::A => answer_builder
            .push((
                qname,
                Class::In,
                DNS_TTL,
                domain::rdata::A::from(resource.ipv4()?),
            ))
            .ok()?,
        Rtype::Aaaa => answer_builder
            .push((
                qname,
                Class::In,
                DNS_TTL,
                domain::rdata::Aaaa::from(resource.ipv6()?),
            ))
            .ok()?,
        Rtype::Ptr => answer_builder
            .push((
                qname,
                Class::In,
                DNS_TTL,
                domain::rdata::Ptr::<ParsedDname<_>>::new(
                    resource.dns_name()?.parse::<Dname<Vec<u8>>>().ok()?.into(),
                ),
            ))
            .ok()?,
        _ => return None,
    }
    Some(answer_builder.finish())
}
rust/connlib/tunnel/src/gateway.rs (new file, 120 lines)
@@ -0,0 +1,120 @@
use crate::device_channel::create_iface;
use crate::{
    ControlSignal, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
    MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
};
use connlib_shared::error::ConnlibError;
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::Callbacks;
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;

impl<C, CB> Tunnel<C, CB, GatewayState>
where
    C: ControlSignal + Send + Sync + 'static,
    CB: Callbacks + 'static,
{
    /// Sets the interface configuration and starts background tasks.
    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn set_interface(
        self: &Arc<Self>,
        config: &InterfaceConfig,
    ) -> connlib_shared::Result<()> {
        let device = create_iface(config, self.callbacks()).await?;
        *self.device.write().await = Some(device.clone());

        self.start_timers().await?;
        *self.iface_handler_abort.lock() =
            Some(tokio::spawn(device_handler(Arc::clone(self), device)).abort_handle());

        tracing::debug!("background_loop_started");

        Ok(())
    }
}

/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<C, CB>(
    tunnel: Arc<Tunnel<C, CB, GatewayState>>,
    mut device: Device,
) -> Result<(), ConnlibError>
where
    C: ControlSignal + Send + Sync + 'static,
    CB: Callbacks + 'static,
{
    let mut buf = [0u8; MAX_UDP_SIZE];
    loop {
        let Some(packet) = device.read().await? else {
            // Reading a bad IP packet or otherwise from the device seems bad. Should we restart the tunnel or something?
            return Ok(());
        };

        let dest = packet.destination();

        let Some(peer) = tunnel.peer_by_ip(dest) else {
            continue;
        };

        if let Err(e) = tunnel
            .encapsulate_and_send_to_peer(packet, peer, &dest, &mut buf)
            .await
        {
            tracing::error!(err = ?e, "failed to handle packet {e:#}")
        }
    }
}

/// [`Tunnel`] state specific to gateways.
pub struct GatewayState {
    candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,
}

impl GatewayState {
    pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver<RTCIceCandidateInit>) {
        match self.candidate_receivers.try_push(id, receiver) {
            Ok(()) => {}
            Err(PushError::BeyondCapacity(_)) => {
                tracing::warn!("Too many active ICE candidate receivers at a time")
            }
            Err(PushError::Replaced(_)) => {
                tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
            }
        }
    }
}

impl Default for GatewayState {
    fn default() -> Self {
        Self {
            candidate_receivers: StreamMap::new(
                Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
                MAX_CONCURRENT_ICE_GATHERING,
            ),
        }
    }
}

impl RoleState for GatewayState {
    type Id = ClientId;

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
        loop {
            match ready!(self.candidate_receivers.poll_next_unpin(cx)) {
                (conn_id, Some(Ok(c))) => {
                    return Poll::Ready(Event::SignalIceCandidate {
                        conn_id,
                        candidate: c,
                    })
                }
                (id, Some(Err(e))) => {
                    tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
                }
                (_, None) => {}
            }
        }
    }
}
@@ -1,18 +1,10 @@
use std::{net::IpAddr, sync::Arc, time::Duration};
use std::{net::IpAddr, sync::Arc};

use boringtun::noise::{errors::WireGuardError, Tunn, TunnResult};
use boringtun::noise::{errors::WireGuardError, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};
use connlib_shared::{Callbacks, Result};

use crate::role_state::RoleState;
use crate::{
    device_channel::{DeviceIo, IfaceConfig},
    dns,
    peer::EncapsulatedPacket,
    ConnId, ControlSignal, Tunnel, MAX_UDP_SIZE,
};

const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
use crate::{ip_packet::MutableIpPacket, peer::Peer, ControlSignal, RoleState, Tunnel};

impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
where
@@ -21,74 +13,15 @@ where
    TRoleState: RoleState,
{
    #[inline(always)]
    fn connection_intent(self: &Arc<Self>, src: &[u8], dst_addr: &IpAddr) {
        // We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this
        if let Some(resource) = self.get_resource(src) {
            // We have awaiting connection to prevent a race condition where
            // create_peer_connection hasn't added the thing to peer_connections
            // and we are finding another packet to the same address (otherwise we would just use peer_connections here)
            let mut awaiting_connection = self.awaiting_connection.lock();
            let conn_id = ConnId::from(resource.id());
            if awaiting_connection.get(&conn_id).is_none() {
                tracing::trace!(
                    resource_ip = %dst_addr,
                    "resource_connection_intent",
                );

                awaiting_connection.insert(conn_id, Default::default());
                let dev = Arc::clone(self);

                let mut connected_gateway_ids: Vec<_> = dev
                    .gateway_awaiting_connection
                    .lock()
                    .clone()
                    .into_keys()
                    .collect();
                connected_gateway_ids
                    .extend(dev.resources_gateways.lock().values().collect::<Vec<_>>());
                tracing::trace!(
                    gateways = ?connected_gateway_ids,
                    "connected_gateways"
                );
                tokio::spawn(async move {
                    let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
                    loop {
                        interval.tick().await;
                        let reference = {
                            let mut awaiting_connections = dev.awaiting_connection.lock();
                            let Some(awaiting_connection) =
                                awaiting_connections.get_mut(&ConnId::from(resource.id()))
                            else {
                                break;
                            };
                            if awaiting_connection.response_received {
                                break;
                            }
                            awaiting_connection.total_attemps += 1;
                            awaiting_connection.total_attemps
                        };
                        if let Err(e) = dev
                            .control_signaler
                            .signal_connection_to(&resource, &connected_gateway_ids, reference)
                            .await
                        {
                            // Not a deadlock because this is a different task
                            dev.awaiting_connection.lock().remove(&conn_id);
                            tracing::error!(error = ?e, "start_resource_connection");
                            let _ = dev.callbacks.on_error(&e);
                        }
                    }
                });
            }
        }
    }

    #[inline(always)]
    async fn handle_encapsulated_packet<'a>(
    pub(crate) async fn encapsulate_and_send_to_peer<'a>(
        &self,
        encapsulated_packet: EncapsulatedPacket<'a>,
        mut packet: MutableIpPacket<'_>,
        peer: Arc<Peer>,
        dst_addr: &IpAddr,
        buf: &mut [u8],
    ) -> Result<()> {
        let encapsulated_packet = peer.encapsulate(&mut packet, buf)?;

        match encapsulated_packet.encapsulate_result {
            TunnResult::Done => Ok(()),
            TunnResult::Err(WireGuardError::ConnectionExpired)
@@ -130,72 +63,4 @@ where
            _ => panic!("Unexpected result from encapsulate"),
        }
    }

    #[inline(always)]
    async fn handle_iface_packet(
        self: &Arc<Self>,
        device_writer: &DeviceIo,
        src: &mut [u8],
        dst: &mut [u8],
    ) -> Result<()> {
        if let Some(r) = self.check_for_dns(src) {
            match r {
                dns::SendPacket::Ipv4(r) => device_writer.write4(&r[..])?,
                dns::SendPacket::Ipv6(r) => device_writer.write6(&r[..])?,
            };
            return Ok(());
        }

        let dst_addr = match Tunn::dst_address(src) {
            Some(addr) => addr,
            None => return Err(Error::BadPacket),
        };

        let encapsulated_packet = {
            match self.peers_by_ip.read().longest_match(dst_addr).map(|p| p.1) {
                Some(peer) => peer.encapsulate(src, dst)?,
                None => {
                    self.connection_intent(src, &dst_addr);
                    return Ok(());
                }
            }
        };

        self.handle_encapsulated_packet(encapsulated_packet, &dst_addr)
            .await
    }

    #[tracing::instrument(level = "trace", skip(self, iface_config, device_io))]
    pub(crate) async fn iface_handler(
        self: &Arc<Self>,
        iface_config: Arc<IfaceConfig>,
        device_io: DeviceIo,
    ) {
        let device_writer = device_io.clone();
        let mut src = [0u8; MAX_UDP_SIZE];
        let mut dst = [0u8; MAX_UDP_SIZE];
        loop {
            let res = match device_io.read(&mut src[..iface_config.mtu()]).await {
                Ok(res) => res,
                Err(e) => {
                    tracing::error!(err = ?e, "failed to read interface: {e:#}");
                    let _ = self.callbacks.on_error(&e.into());
                    break;
                }
            };
            tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");

            if res == 0 {
                break;
            }

            if let Err(e) = self
                .handle_iface_packet(&device_writer, &mut src[..res], &mut dst)
                .await
            {
                let _ = self.callbacks.on_error(&e);
                tracing::error!(err = ?e, "failed to handle packet {e:#}")
            }
        }
    }
}
@@ -32,10 +32,20 @@ macro_rules! swap_src_dst {
impl<'a> MutableIpPacket<'a> {
    #[inline]
    pub(crate) fn new(data: &mut [u8]) -> Option<MutableIpPacket> {
        match data[0] >> 4 {
            4 => MutableIpv4Packet::new(data).map(Into::into),
            6 => MutableIpv6Packet::new(data).map(Into::into),
            _ => None,
        let packet = match data[0] >> 4 {
            4 => MutableIpv4Packet::new(data)?.into(),
            6 => MutableIpv6Packet::new(data)?.into(),
            _ => return None,
        };

        Some(packet)
    }

    #[inline]
    pub(crate) fn destination(&self) -> IpAddr {
        match self {
            MutableIpPacket::MutableIpv4Packet(i) => i.get_destination().into(),
            MutableIpPacket::MutableIpv6Packet(i) => i.get_destination().into(),
        }
    }
@@ -87,6 +97,13 @@ impl<'a> MutableIpPacket<'a> {
        }
    }

    pub(crate) fn as_immutable(&self) -> IpPacket<'_> {
        match self {
            Self::MutableIpv4Packet(p) => IpPacket::Ipv4Packet(p.to_immutable()),
            Self::MutableIpv6Packet(p) => IpPacket::Ipv6Packet(p.to_immutable()),
        }
    }

    pub(crate) fn as_udp(&mut self) -> Option<MutableUdpPacket> {
        self.to_immutable()
            .is_udp()
@@ -174,14 +191,6 @@ pub(crate) enum IpPacket<'a> {
}

impl<'a> IpPacket<'a> {
    pub(crate) fn new(data: &[u8]) -> Option<IpPacket> {
        match data[0] >> 4 {
            4 => Ipv4Packet::new(data).map(Into::into),
            6 => Ipv6Packet::new(data).map(Into::into),
            _ => None,
        }
    }

    pub(crate) fn version(&self) -> Version {
        match self {
            IpPacket::Ipv4Packet(_) => Version::Ipv4,
@@ -214,13 +223,6 @@ impl<'a> IpPacket<'a> {
            .flatten()
    }

    pub(crate) fn destination(&self) -> IpAddr {
        match self {
            Self::Ipv4Packet(p) => p.get_destination().into(),
            Self::Ipv6Packet(p) => p.get_destination().into(),
        }
    }

    pub(crate) fn source(&self) -> IpAddr {
        match self {
            Self::Ipv4Packet(p) => p.get_source().into(),
@@ -228,6 +230,13 @@ impl<'a> IpPacket<'a> {
        }
    }

    pub(crate) fn destination(&self) -> IpAddr {
        match self {
            Self::Ipv4Packet(p) => p.get_destination().into(),
            Self::Ipv6Packet(p) => p.get_destination().into(),
        }
    }

    pub(crate) fn udp_checksum(&self, dgm: &UdpPacket<'_>) -> u16 {
        match self {
            Self::Ipv4Packet(p) => udp::ipv4_checksum(dgm, &p.get_source(), &p.get_destination()),
@@ -3,12 +3,12 @@
//! This is both the wireguard and ICE implementation that should work in tandem.
//! [Tunnel] is the main entry-point for this crate.
use boringtun::{
    noise::{errors::WireGuardError, rate_limiter::RateLimiter, Tunn, TunnResult},
    noise::{errors::WireGuardError, rate_limiter::RateLimiter, TunnResult},
    x25519::{PublicKey, StaticSecret},
};
use bytes::Bytes;

use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error, DNS_SENTINEL};
use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use serde::{Deserialize, Serialize};
@@ -29,7 +29,7 @@ use webrtc::{
};

use std::task::{Context, Poll};
use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
use std::{collections::HashMap, fmt, io, net::IpAddr, sync::Arc, time::Duration};
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;

use connlib_shared::{
@@ -39,19 +39,22 @@ use connlib_shared::{
    Result,
};

use device_channel::{create_iface, DeviceIo, IfaceConfig};
use device_channel::{DeviceIo, IfaceConfig};

pub use client::ClientState;
pub use control_protocol::Request;
pub use role_state::{ClientState, GatewayState};
pub use gateway::GatewayState;
pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;

use crate::role_state::RoleState;
use crate::ip_packet::MutableIpPacket;
use connlib_shared::messages::SecretKey;
use index::IndexLfsr;

mod client;
mod control_protocol;
mod device_channel;
mod dns;
mod gateway;
mod iface_handler;
mod index;
mod ip_packet;
@@ -59,12 +62,23 @@ mod peer;
mod peer_handler;
mod resource_sender;
mod resource_table;
mod role_state;
mod tokio_util;

const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_MTU_INTERVAL: Duration = Duration::from_secs(30);
/// For how long we will attempt to gather ICE candidates before aborting.
///
/// Chosen arbitrarily.
/// Very likely, the actual WebRTC connection will timeout before this.
/// This timeout is just here to eventually clean-up tasks if they are somehow broken.
const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60;

/// How many concurrent ICE gathering attempts we are allow.
///
/// Chosen arbitrarily.
const MAX_CONCURRENT_ICE_GATHERING: usize = 100;

// Note: Taken from boringtun
const HANDSHAKE_RATE_LIMIT: u64 = 100;
@@ -149,8 +163,30 @@ struct AwaitingConnectionDetails {

#[derive(Clone)]
struct Device {
    pub config: Arc<IfaceConfig>,
    pub io: DeviceIo,
    config: Arc<IfaceConfig>,
    io: DeviceIo,

    buf: Box<[u8; MAX_UDP_SIZE]>,
}

impl Device {
    async fn read(&mut self) -> io::Result<Option<MutableIpPacket<'_>>> {
        let res = self.io.read(&mut self.buf[..self.config.mtu()]).await?;
        tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");

        if res == 0 {
            return Ok(None);
        }

        Ok(Some(
            MutableIpPacket::new(&mut self.buf[..res]).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "received bytes are not an IP packet",
                )
            })?,
        ))
    }
}

// TODO: We should use newtypes for each kind of Id
@@ -243,6 +279,14 @@ where
    pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
        self.role_state.lock().poll_next_event(cx)
    }

    pub(crate) fn peer_by_ip(&self, ip: IpAddr) -> Option<Arc<Peer>> {
        self.peers_by_ip
            .read()
            .longest_match(ip)
            .map(|(_, peer)| peer)
            .cloned()
    }
}

pub enum Event<TId> {
@@ -327,86 +371,6 @@ where
        })
    }

    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn add_route(self: &Arc<Self>, route: IpNetwork) -> Result<()> {
        let mut device = self.device.write().await;

        if let Some(new_device) = device
            .as_ref()
            .ok_or(Error::ControlProtocolError)?
            .config
            .add_route(route, self.callbacks())
            .await?
        {
            *device = Some(new_device.clone());
            let dev = Arc::clone(self);
            self.iface_handler_abort.lock().replace(
                tokio::spawn(
                    async move { dev.iface_handler(new_device.config, new_device.io).await },
                )
                .abort_handle(),
            );
        }

        Ok(())
    }

    /// Adds a the given resource to the tunnel.
    ///
    /// Once added, when a packet for the resource is intercepted a new data channel will be created
    /// and packets will be wrapped with wireguard and sent through it.
    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn add_resource(
        self: &Arc<Self>,
        resource_description: ResourceDescription,
    ) -> Result<()> {
        let mut any_valid_route = false;
        {
            for ip in resource_description.ips() {
                if let Err(e) = self.add_route(ip).await {
                    tracing::warn!(route = %ip, error = ?e, "add_route");
                    let _ = self.callbacks().on_error(&e);
                } else {
                    any_valid_route = true;
                }
            }
        }
        if !any_valid_route {
            return Err(Error::InvalidResource);
        }

        let resource_list = {
            let mut resources = self.resources.write();
            resources.insert(resource_description);
            resources.resource_list()
        };

        self.callbacks.on_update_resources(resource_list)?;
        Ok(())
    }

    /// Sets the interface configuration and starts background tasks.
    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn set_interface(self: &Arc<Self>, config: &InterfaceConfig) -> Result<()> {
        let device = create_iface(config, self.callbacks()).await?;
        *self.device.write().await = Some(device.clone());

        self.start_timers().await?;
        let dev = Arc::clone(self);
        *self.iface_handler_abort.lock() = Some(
            tokio::spawn(async move { dev.iface_handler(device.config, device.io).await })
                .abort_handle(),
        );

        self.add_route(DNS_SENTINEL.into()).await?;

        self.callbacks.on_tunnel_ready()?;

        tracing::debug!("background_loop_started");

        Ok(())
    }

    #[tracing::instrument(level = "trace", skip(self))]
    async fn stop_peer(&self, index: u32, conn_id: ConnId) {
        self.peers_by_ip.write().retain(|_, p| p.index != index);
@@ -537,8 +501,7 @@ where
        Ok(())
    }

    fn get_resource(&self, buff: &[u8]) -> Option<ResourceDescription> {
        let addr = Tunn::dst_address(buff)?;
    fn get_resource(&self, addr: IpAddr) -> Option<ResourceDescription> {
        let resources = self.resources.read();
        match addr {
            IpAddr::V4(ipv4) => resources.get_by_ip(ipv4).cloned(),
@@ -554,3 +517,13 @@ where
        &self.callbacks
    }
}

/// Dedicated trait for abstracting over the different ICE states.
///
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
pub trait RoleState: Default + Send + 'static {
    type Id: fmt::Debug;

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}
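Since RoleState only exposes poll_next_event, a caller that wants to await the next event can wrap the poll-based API in a future; a minimal sketch, not part of this commit:

    // Sketch only: turns the poll-based API into an async fn using futures::future::poll_fn.
    use futures::future::poll_fn;

    async fn next_event<S: RoleState>(state: &mut S) -> Event<S::Id> {
        poll_fn(|cx| state.poll_next_event(cx)).await
    }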
@@ -10,6 +10,7 @@ use connlib_shared::{
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use parking_lot::{Mutex, RwLock};
use pnet_packet::MutablePacket;
use webrtc::data::data_channel::DataChannel;

use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId};
@@ -194,14 +195,9 @@ impl Peer {

    pub(crate) fn encapsulate<'a>(
        &self,
        src: &'a mut [u8],
        packet: &mut MutableIpPacket<'a>,
        dst: &'a mut [u8],
    ) -> Result<EncapsulatedPacket<'a>> {
        let Some(mut packet) = MutableIpPacket::new(src) else {
            debug_assert!(false, "Got non-ip packet from the tunnel interface");
            tracing::error!("Developer error: we should never see a packet through the tunnel wire that isn't ip");
            return Err(Error::BadPacket);
        };
        if let Some(resource) = self.get_translation(packet.to_immutable().source()) {
            let ResourceDescription::Dns(resource) = resource else {
                tracing::error!(
@@ -210,7 +206,7 @@ impl Peer {
                return Err(Error::ControlProtocolError);
            };

            match &mut packet {
            match packet {
                MutableIpPacket::MutableIpv4Packet(ref mut p) => p.set_source(resource.ipv4),
                MutableIpPacket::MutableIpv6Packet(ref mut p) => p.set_source(resource.ipv6),
            }
@@ -221,7 +217,7 @@ impl Peer {
            index: self.index,
            conn_id: self.conn_id,
            channel: self.channel.clone(),
            encapsulate_result: self.tunnel.lock().encapsulate(src, dst),
            encapsulate_result: self.tunnel.lock().encapsulate(packet.packet_mut(), dst),
        })
    }
@@ -4,10 +4,9 @@ use boringtun::noise::{handshake::parse_handshake_anon, Packet, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};

use crate::role_state::RoleState;
use crate::{
    device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, Tunnel,
    MAX_UDP_SIZE,
    device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, RoleState,
    Tunnel, MAX_UDP_SIZE,
};

impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
@@ -1,148 +0,0 @@
use crate::Event;
use connlib_shared::messages::{ClientId, GatewayId};
use futures::channel::mpsc::Receiver;
use futures_bounded::{PushError, StreamMap};
use std::collections::HashMap;
use std::fmt;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;

/// Dedicated trait for abstracting over the different ICE states.
///
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
/// The state should only be modified when the concrete type is known, e.g. [`ClientState`] or [`GatewayState`].
pub trait RoleState: Default + Send + 'static {
    type Id: fmt::Debug;

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>>;
}

/// For how long we will attempt to gather ICE candidates before aborting.
///
/// Chosen arbitrarily.
/// Very likely, the actual WebRTC connection will timeout before this.
/// This timeout is just here to eventually clean-up tasks if they are somehow broken.
const ICE_GATHERING_TIMEOUT_SECONDS: u64 = 5 * 60;

/// How many concurrent ICE gathering attempts we are allow.
///
/// Chosen arbitrarily.
const MAX_CONCURRENT_ICE_GATHERING: usize = 100;

/// [`Tunnel`](crate::Tunnel) state specific to clients.
pub struct ClientState {
    active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
    /// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
    waiting_for_sdp_from_gatway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
}

impl ClientState {
    pub fn add_waiting_ice_receiver(
        &mut self,
        id: GatewayId,
        receiver: Receiver<RTCIceCandidateInit>,
    ) {
        self.waiting_for_sdp_from_gatway.insert(id, receiver);
    }

    pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
        let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else {
            return;
        };

        match self.active_candidate_receivers.try_push(id, receiver) {
            Ok(()) => {}
            Err(PushError::BeyondCapacity(_)) => {
                tracing::warn!("Too many active ICE candidate receivers at a time")
            }
            Err(PushError::Replaced(_)) => {
                tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
            }
        }
    }
}

impl Default for ClientState {
    fn default() -> Self {
        Self {
            active_candidate_receivers: StreamMap::new(
                Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
                MAX_CONCURRENT_ICE_GATHERING,
            ),
            waiting_for_sdp_from_gatway: Default::default(),
        }
    }
}

impl RoleState for ClientState {
    type Id = GatewayId;

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
        loop {
            match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
                (conn_id, Some(Ok(c))) => {
                    return Poll::Ready(Event::SignalIceCandidate {
                        conn_id,
                        candidate: c,
                    })
                }
                (id, Some(Err(e))) => {
                    tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
                }
                (_, None) => {}
            }
        }
    }
}

/// [`Tunnel`](crate::Tunnel) state specific to gateways.
pub struct GatewayState {
    candidate_receivers: StreamMap<ClientId, RTCIceCandidateInit>,
}

impl GatewayState {
    pub fn add_new_ice_receiver(&mut self, id: ClientId, receiver: Receiver<RTCIceCandidateInit>) {
        match self.candidate_receivers.try_push(id, receiver) {
            Ok(()) => {}
            Err(PushError::BeyondCapacity(_)) => {
                tracing::warn!("Too many active ICE candidate receivers at a time")
            }
            Err(PushError::Replaced(_)) => {
                tracing::warn!(%id, "Replaced old ICE candidate receiver with new one")
            }
        }
    }
}

impl Default for GatewayState {
    fn default() -> Self {
        Self {
            candidate_receivers: StreamMap::new(
                Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS),
                MAX_CONCURRENT_ICE_GATHERING,
            ),
        }
    }
}

impl RoleState for GatewayState {
    type Id = ClientId;

    fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
        loop {
            match ready!(self.candidate_receivers.poll_next_unpin(cx)) {
                (conn_id, Some(Ok(c))) => {
                    return Poll::Ready(Event::SignalIceCandidate {
                        conn_id,
                        candidate: c,
                    })
                }
                (id, Some(Err(e))) => {
                    tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
                }
                (_, None) => {}
            }
        }
    }
}
rust/connlib/tunnel/src/tokio_util.rs (new file, 23 lines)
@@ -0,0 +1,23 @@
use connlib_shared::error::ConnlibError;
use connlib_shared::Callbacks;
use std::future::Future;

/// Spawns a task into the [`tokio`] runtime.
///
/// On error, [`Callbacks::on_error`] is invoked.
/// This also returns a [`tokio::task::AbortHandle`] which MAY be used to abort the task.
/// If you don't need it, you are free to drop it.
/// It won't terminate the task.
pub(crate) fn spawn_log(
    cb: &(impl Callbacks + 'static),
    f: impl Future<Output = Result<(), ConnlibError>> + Send + 'static,
) -> tokio::task::AbortHandle {
    let cb = cb.clone();

    tokio::spawn(async move {
        if let Err(e) = f.await {
            let _ = cb.on_error(&e);
        }
    })
    .abort_handle()
}
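A minimal usage sketch for spawn_log, not part of this commit; `callbacks` (any `Callbacks` implementation) and `do_work()` (returning `Result<(), ConnlibError>`) are hypothetical names used only for illustration:

    // Sketch only: any error of the spawned future is reported through Callbacks::on_error.
    let abort_handle = tokio_util::spawn_log(&callbacks, async move { do_work().await });

    // The returned AbortHandle may be kept to cancel the task later (abort_handle.abort())
    // or simply dropped; dropping it does not terminate the task.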