refactor(connlib): remove ControlSignal (#2321)

This commit is contained in:
Thomas Eizinger
2023-10-18 17:28:04 +11:00
committed by GitHub
parent a929c6bdbf
commit 2cfe7befef
16 changed files with 609 additions and 657 deletions

View File

@@ -12,7 +12,7 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef as builder
COPY --from=planner /build/recipe.json .
RUN cargo chef cook --release --recipe-path recipe.json
RUN cargo chef cook --all-targets --release --recipe-path recipe.json
COPY . .
ARG PACKAGE
RUN cargo build -p $PACKAGE --release

View File

@@ -14,48 +14,19 @@ use connlib_shared::{
Result,
};
use async_trait::async_trait;
use firezone_tunnel::{ClientState, ControlSignal, Request, Tunnel};
use firezone_tunnel::{ClientState, Request, Tunnel};
use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE};
use tokio::io::BufReader;
use tokio::sync::Mutex;
use tokio_util::codec::{BytesCodec, FramedRead};
use url::Url;
#[async_trait]
impl ControlSignal for ControlSignaler {
async fn signal_connection_to(
&self,
resource: &ResourceDescription,
connected_gateway_ids: &[GatewayId],
reference: usize,
) -> Result<()> {
self.control_signal
// It's easier if self is not mut
.clone()
.send_with_ref(
EgressMessages::PrepareConnection {
resource_id: resource.id(),
connected_gateway_ids: connected_gateway_ids.to_vec(),
},
reference,
)
.await?;
Ok(())
}
}
pub struct ControlPlane<CB: Callbacks> {
pub tunnel: Arc<Tunnel<ControlSignaler, CB, ClientState>>,
pub control_signaler: ControlSignaler,
pub tunnel: Arc<Tunnel<CB, ClientState>>,
pub phoenix_channel: PhoenixSenderWithTopic,
pub tunnel_init: Mutex<bool>,
}
#[derive(Clone)]
pub struct ControlSignaler {
pub control_signal: PhoenixSenderWithTopic,
}
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn init(
@@ -139,7 +110,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
reference: Option<Reference>,
) {
let tunnel = Arc::clone(&self.tunnel);
let mut control_signaler = self.control_signaler.clone();
let mut control_signaler = self.phoenix_channel.clone();
tokio::spawn(async move {
let err = match tunnel
.request_connection(resource_id, gateway_id, relays, reference)
@@ -147,7 +118,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
{
Ok(Request::NewConnection(connection_request)) => {
if let Err(err) = control_signaler
.control_signal
// TODO: create a reference number and keep track for the response
.send_with_ref(
EgressMessages::RequestConnection(connection_request),
@@ -162,7 +132,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
Ok(Request::ReuseConnection(connection_request)) => {
if let Err(err) = control_signaler
.control_signal
// TODO: create a reference number and keep track for the response
.send_with_ref(
EgressMessages::ReuseConnection(connection_request),
@@ -178,7 +147,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
Err(err) => err,
};
tunnel.cleanup_connection(resource_id.into());
tunnel.cleanup_connection(resource_id);
tracing::error!("Error request connection details: {err}");
let _ = tunnel.callbacks().on_error(&err);
});
@@ -250,7 +219,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
return;
};
// TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection
self.tunnel.cleanup_connection(resource_id.into());
self.tunnel.cleanup_connection(resource_id);
}
None => {
tracing::error!(
@@ -273,8 +242,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
tracing::info!("Requesting log upload URL from portal");
let _ = self
.control_signaler
.control_signal
.phoenix_channel
.send(EgressMessages::CreateLogSink {})
.await;
}
@@ -283,8 +251,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
match event {
firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => {
if let Err(e) = self
.control_signaler
.control_signal
.phoenix_channel
.send(EgressMessages::BroadcastIceCandidates(
BroadcastGatewayIceCandidates {
gateway_ids: vec![conn_id],
@@ -296,6 +263,28 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
tracing::error!("Failed to signal ICE candidate: {e}")
}
}
firezone_tunnel::Event::ConnectionIntent {
resource,
connected_gateway_ids,
reference,
} => {
if let Err(e) = self
.phoenix_channel
.clone()
.send_with_ref(
EgressMessages::PrepareConnection {
resource_id: resource.id(),
connected_gateway_ids: connected_gateway_ids.to_vec(),
},
reference,
)
.await
{
tracing::error!("Failed to prepare connection: {e}");
// TODO: Clean up connection in `ClientState` here?
}
}
}
}
}

View File

@@ -3,7 +3,6 @@ pub use connlib_shared::{get_device_id, messages::ResourceDescription};
pub use connlib_shared::{Callbacks, Error};
pub use tracing_appender::non_blocking::WorkerGuard;
use crate::control::ControlSignaler;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use connlib_shared::control::SecureUrl;
use connlib_shared::{control::PhoenixChannel, login_url, CallbackErrorFacade, Mode, Result};
@@ -149,16 +148,15 @@ where
}
});
let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("client".to_owned()) };
let tunnel = fatal_error!(
Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await,
Tunnel::new(private_key, callbacks.clone()).await,
runtime_stopper,
&callbacks
);
let mut control_plane = ControlPlane {
tunnel: Arc::new(tunnel),
control_signaler,
phoenix_channel: connection.sender_with_topic("client".to_owned()),
tunnel_init: Mutex::new(false),
};

View File

@@ -1,25 +1,34 @@
use crate::device_channel::{create_iface, DeviceIo};
use crate::ip_packet::IpPacket;
use crate::peer::Peer;
use crate::resource_table::ResourceTable;
use crate::{
dns, tokio_util, ConnId, ControlSignal, Device, Event, RoleState, Tunnel,
dns, peer_by_ip, tokio_util, Device, Event, PeerConfig, RoleState, Tunnel,
ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
};
use boringtun::x25519::{PublicKey, StaticSecret};
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
use connlib_shared::messages::{GatewayId, Interface as InterfaceConfig, ResourceDescription};
use connlib_shared::messages::{
GatewayId, Interface as InterfaceConfig, Key, ResourceDescription, ResourceId, ReuseConnection,
SecretKey,
};
use connlib_shared::{Callbacks, DNS_SENTINEL};
use futures::channel::mpsc::Receiver;
use futures::stream;
use futures_bounded::{PushError, StreamMap};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::time::Instant;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
impl<C, CB> Tunnel<C, CB, ClientState>
impl<CB> Tunnel<CB, ClientState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
/// Adds a the given resource to the tunnel.
@@ -47,9 +56,9 @@ where
}
let resource_list = {
let mut resources = self.resources.write();
resources.insert(resource_description);
resources.resource_list()
let mut role_state = self.role_state.lock();
role_state.resources.insert(resource_description);
role_state.resources.resource_list()
};
self.callbacks.on_update_resources(resource_list)?;
@@ -80,6 +89,13 @@ where
Ok(())
}
/// Clean up a connection to a resource.
// FIXME: this cleanup connection is wrong!
pub fn cleanup_connection(&self, id: ResourceId) {
self.role_state.lock().on_connection_failed(id);
self.peer_connections.lock().remove(&id.into());
}
#[tracing::instrument(level = "trace", skip(self))]
async fn add_route(self: &Arc<Self>, route: IpNetwork) -> connlib_shared::Result<()> {
let mut device = self.device.write().await;
@@ -100,80 +116,14 @@ where
Ok(())
}
#[inline(always)]
fn connection_intent(self: &Arc<Self>, packet: IpPacket<'_>) {
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
// We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this
if let Some(resource) = self.get_resource(packet.destination()) {
// We have awaiting connection to prevent a race condition where
// create_peer_connection hasn't added the thing to peer_connections
// and we are finding another packet to the same address (otherwise we would just use peer_connections here)
let mut awaiting_connection = self.awaiting_connection.lock();
let conn_id = ConnId::from(resource.id());
if awaiting_connection.get(&conn_id).is_none() {
tracing::trace!(
resource_ip = %packet.destination(),
"resource_connection_intent",
);
awaiting_connection.insert(conn_id, Default::default());
let dev = Arc::clone(self);
let mut connected_gateway_ids: Vec<_> = dev
.gateway_awaiting_connection
.lock()
.clone()
.into_keys()
.collect();
connected_gateway_ids
.extend(dev.resources_gateways.lock().values().collect::<Vec<_>>());
tracing::trace!(
gateways = ?connected_gateway_ids,
"connected_gateways"
);
tokio::spawn(async move {
let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
loop {
interval.tick().await;
let reference = {
let mut awaiting_connections = dev.awaiting_connection.lock();
let Some(awaiting_connection) =
awaiting_connections.get_mut(&ConnId::from(resource.id()))
else {
break;
};
if awaiting_connection.response_received {
break;
}
awaiting_connection.total_attemps += 1;
awaiting_connection.total_attemps
};
if let Err(e) = dev
.control_signaler
.signal_connection_to(&resource, &connected_gateway_ids, reference)
.await
{
// Not a deadlock because this is a different task
dev.awaiting_connection.lock().remove(&conn_id);
tracing::error!(error = ?e, "start_resource_connection");
let _ = dev.callbacks.on_error(&e);
}
}
});
}
}
}
}
/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<C, CB>(
tunnel: Arc<Tunnel<C, CB, ClientState>>,
async fn device_handler<CB>(
tunnel: Arc<Tunnel<CB, ClientState>>,
mut device: Device,
) -> Result<(), ConnlibError>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
let device_writer = device.io.clone();
@@ -183,7 +133,9 @@ where
return Ok(());
};
if let Some(dns_packet) = dns::parse(&tunnel.resources.read(), packet.as_immutable()) {
if let Some(dns_packet) =
dns::parse(&tunnel.role_state.lock().resources, packet.as_immutable())
{
if let Err(e) = send_dns_packet(&device_writer, dns_packet) {
tracing::error!(err = %e, "failed to send DNS packet");
let _ = tunnel.callbacks.on_error(&e.into());
@@ -194,8 +146,11 @@ where
let dest = packet.destination();
let Some(peer) = tunnel.peer_by_ip(dest) else {
tunnel.connection_intent(packet.as_immutable());
let Some(peer) = peer_by_ip(&tunnel.peers_by_ip.read(), dest) else {
tunnel
.role_state
.lock()
.on_connection_intent(packet.destination());
continue;
};
@@ -223,9 +178,190 @@ pub struct ClientState {
active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
/// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
waiting_for_sdp_from_gatway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
// TODO: Make private
pub awaiting_connection: HashMap<ResourceId, AwaitingConnectionDetails>,
pub gateway_awaiting_connection: HashMap<GatewayId, Vec<IpNetwork>>,
awaiting_connection_timers: StreamMap<ResourceId, Instant>,
pub gateway_public_keys: HashMap<GatewayId, PublicKey>,
resources_gateways: HashMap<ResourceId, GatewayId>,
resources: ResourceTable<ResourceDescription>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwaitingConnectionDetails {
total_attemps: usize,
response_received: bool,
gateways: Vec<GatewayId>,
}
impl ClientState {
pub(crate) fn attempt_to_reuse_connection(
&mut self,
resource: ResourceId,
gateway: GatewayId,
expected_attempts: usize,
connected_peers: &mut IpNetworkTable<Arc<Peer>>,
) -> Result<Option<ReuseConnection>, ConnlibError> {
if self.is_connected_to(resource, connected_peers) {
return Err(Error::UnexpectedConnectionDetails);
}
let desc = self
.resources
.get_by_id(&resource)
.ok_or(Error::UnknownResource)?;
let details = self
.awaiting_connection
.get_mut(&resource)
.ok_or(Error::UnexpectedConnectionDetails)?;
details.response_received = true;
if details.total_attemps != expected_attempts {
return Err(Error::UnexpectedConnectionDetails);
}
self.resources_gateways.insert(resource, gateway);
match self.gateway_awaiting_connection.entry(gateway) {
Entry::Occupied(mut occupied) => {
occupied.get_mut().extend(desc.ips());
return Ok(Some(ReuseConnection {
resource_id: resource,
gateway_id: gateway,
}));
}
Entry::Vacant(vacant) => {
vacant.insert(vec![]);
}
}
let found = {
let peer = connected_peers
.iter()
.find_map(|(_, p)| (p.conn_id == gateway.into()).then_some(p))
.cloned();
if let Some(peer) = peer {
for ip in desc.ips() {
peer.add_allowed_ip(ip);
connected_peers.insert(ip, Arc::clone(&peer));
}
true
} else {
false
}
};
if found {
self.awaiting_connection.remove(&resource);
self.awaiting_connection_timers.remove(resource);
return Ok(Some(ReuseConnection {
resource_id: resource,
gateway_id: gateway,
}));
}
Ok(None)
}
pub fn on_connection_failed(&mut self, resource: ResourceId) {
self.awaiting_connection.remove(&resource);
let Some(gateway) = self.resources_gateways.remove(&resource) else {
return;
};
self.gateway_awaiting_connection.remove(&gateway);
self.awaiting_connection_timers.remove(resource);
}
pub fn on_connection_intent(&mut self, destination: IpAddr) {
if self.is_awaiting_connection_to(destination) {
return;
}
tracing::trace!(resource_ip = %destination, "resource_connection_intent");
let Some(resource) = self.get_resource_by_destination(destination) else {
return;
};
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
let resource_id = resource.id();
let connected_gateway_ids = self
.gateway_awaiting_connection
.clone()
.into_keys()
.chain(self.resources_gateways.values().cloned())
.collect();
tracing::trace!(
gateways = ?connected_gateway_ids,
"connected_gateways"
);
match self.awaiting_connection_timers.try_push(
resource_id,
stream::poll_fn({
let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
move |cx| interval.poll_tick(cx).map(Some)
}),
) {
Ok(()) => {}
Err(PushError::BeyondCapacity(_)) => {
tracing::warn!(%resource_id, "Too many concurrent connection attempts");
return;
}
Err(PushError::Replaced(_)) => {
// The timers are equivalent for our purpose so we don't really care about this one.
}
}
self.awaiting_connection.insert(
resource_id,
AwaitingConnectionDetails {
total_attemps: 0,
response_received: false,
gateways: connected_gateway_ids,
},
);
}
pub fn create_peer_config_for_new_connection(
&mut self,
resource: ResourceId,
gateway: GatewayId,
shared_key: StaticSecret,
) -> Result<PeerConfig, ConnlibError> {
let Some(public_key) = self.gateway_public_keys.remove(&gateway) else {
self.awaiting_connection.remove(&resource);
self.gateway_awaiting_connection.remove(&gateway);
return Err(Error::ControlProtocolError);
};
let desc = self
.resources
.get_by_id(&resource)
.ok_or(Error::ControlProtocolError)?;
Ok(PeerConfig {
persistent_keepalive: None,
public_key,
ips: desc.ips(),
preshared_key: SecretKey::new(Key(shared_key.to_bytes())),
})
}
pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option<GatewayId> {
self.resources_gateways.get(resource).copied()
}
pub fn add_waiting_ice_receiver(
&mut self,
id: GatewayId,
@@ -234,10 +370,11 @@ impl ClientState {
self.waiting_for_sdp_from_gatway.insert(id, receiver);
}
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId, key: PublicKey) {
let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else {
return;
};
self.gateway_public_keys.insert(id, key);
match self.active_candidate_receivers.try_push(id, receiver) {
Ok(()) => {}
@@ -249,6 +386,36 @@ impl ClientState {
}
}
}
fn is_awaiting_connection_to(&self, destination: IpAddr) -> bool {
let Some(resource) = self.get_resource_by_destination(destination) else {
return false;
};
self.awaiting_connection.contains_key(&resource.id())
}
fn is_connected_to(
&self,
resource: ResourceId,
connected_peers: &IpNetworkTable<Arc<Peer>>,
) -> bool {
let Some(resource) = self.resources.get_by_id(&resource) else {
return false;
};
resource
.ips()
.iter()
.any(|ip| connected_peers.exact_match(*ip).is_some())
}
fn get_resource_by_destination(&self, destination: IpAddr) -> Option<&ResourceDescription> {
match destination {
IpAddr::V4(ipv4) => self.resources.get_by_ip(ipv4),
IpAddr::V6(ipv6) => self.resources.get_by_ip(ipv6),
}
}
}
impl Default for ClientState {
@@ -259,6 +426,12 @@ impl Default for ClientState {
MAX_CONCURRENT_ICE_GATHERING,
),
waiting_for_sdp_from_gatway: Default::default(),
awaiting_connection: Default::default(),
gateway_awaiting_connection: Default::default(),
awaiting_connection_timers: StreamMap::new(Duration::from_secs(60), 100),
gateway_public_keys: Default::default(),
resources_gateways: Default::default(),
resources: Default::default(),
}
}
}
@@ -268,18 +441,60 @@ impl RoleState for ClientState {
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
loop {
match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
(conn_id, Some(Ok(c))) => {
match self.active_candidate_receivers.poll_next_unpin(cx) {
Poll::Ready((conn_id, Some(Ok(c)))) => {
return Poll::Ready(Event::SignalIceCandidate {
conn_id,
candidate: c,
})
}
(id, Some(Err(e))) => {
Poll::Ready((id, Some(Err(e)))) => {
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
}
(_, None) => {}
Poll::Ready((_, None)) => continue,
Poll::Pending => {}
}
match self.awaiting_connection_timers.poll_next_unpin(cx) {
Poll::Ready((resource, Some(Ok(_)))) => {
let Entry::Occupied(mut entry) = self.awaiting_connection.entry(resource)
else {
self.awaiting_connection_timers.remove(resource);
continue;
};
if entry.get().response_received {
self.awaiting_connection_timers.remove(resource);
// entry.remove(); Maybe?
continue;
}
entry.get_mut().total_attemps += 1;
let reference = entry.get_mut().total_attemps;
return Poll::Ready(Event::ConnectionIntent {
resource: self
.resources
.get_by_id(&resource)
.expect("inconsistent internal state")
.clone(),
connected_gateway_ids: entry.get().gateways.clone(),
reference,
});
}
Poll::Ready((id, Some(Err(e)))) => {
tracing::warn!(resource_id = %id, "Connection establishment timeout: {e}")
}
Poll::Ready((_, None)) => continue,
Poll::Pending => {}
}
return Poll::Pending;
}
}
}

View File

@@ -1,17 +1,14 @@
use boringtun::noise::Tunn;
use chrono::{DateTime, Utc};
use futures::channel::mpsc;
use futures_util::SinkExt;
use secrecy::ExposeSecret;
use std::sync::Arc;
use tracing::instrument;
use connlib_shared::{
messages::{Relay, RequestConnection, ResourceDescription, ReuseConnection},
messages::{Relay, RequestConnection, ReuseConnection},
Callbacks, Error, Result,
};
use webrtc::data_channel::OnCloseHdlrFn;
use webrtc::peer_connection::OnPeerConnectionStateChangeHdlrFn;
use webrtc::{
data_channel::RTCDataChannel,
ice_transport::{
ice_candidate::RTCIceCandidateInit, ice_credential_type::RTCIceCredentialType,
ice_server::RTCIceServer,
@@ -22,7 +19,7 @@ use webrtc::{
},
};
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, RoleState, Tunnel};
use crate::{ConnId, RoleState, Tunnel};
mod client;
mod gateway;
@@ -36,177 +33,35 @@ pub enum Request {
ReuseConnection(ReuseConnection),
}
#[tracing::instrument(level = "trace", skip(tunnel))]
async fn handle_connection_state_update_with_peer<C, CB, TRoleState>(
tunnel: &Arc<Tunnel<C, CB, TRoleState>>,
state: RTCPeerConnectionState,
index: u32,
conn_id: ConnId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
tracing::trace!(?state, "peer_state_update");
if state == RTCPeerConnectionState::Failed {
tunnel.stop_peer(index, conn_id).await;
}
}
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_with_peer<C, CB, TRoleState>(
tunnel: &Arc<Tunnel<C, CB, TRoleState>>,
peer_connection: &Arc<RTCPeerConnection>,
index: u32,
conn_id: ConnId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
let tunnel = Arc::clone(tunnel);
peer_connection.on_peer_connection_state_change(Box::new(
move |state: RTCPeerConnectionState| {
let tunnel = Arc::clone(&tunnel);
Box::pin(async move {
handle_connection_state_update_with_peer(&tunnel, state, index, conn_id).await
})
},
));
}
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
#[instrument(level = "trace", skip(self, data_channel, peer_config))]
async fn handle_channel_open(
self: &Arc<Self>,
data_channel: Arc<RTCDataChannel>,
index: u32,
peer_config: PeerConfig,
conn_id: ConnId,
resources: Option<(ResourceDescription, DateTime<Utc>)>,
) -> Result<()> {
tracing::trace!(
?peer_config.ips,
"data_channel_open",
);
let channel = data_channel.detach().await?;
let tunn = Tunn::new(
self.private_key.clone(),
peer_config.public_key,
Some(peer_config.preshared_key.expose_secret().0),
peer_config.persistent_keepalive,
index,
None,
)?;
let peer = Arc::new(Peer::from_config(
tunn,
index,
&peer_config,
channel,
conn_id,
resources,
));
{
// Watch out! we need 2 locks, make sure you don't lock both at the same time anywhere else
let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock();
let mut peers_by_ip = self.peers_by_ip.write();
// In the gateway this will always be none, no harm done
match conn_id {
ConnId::Gateway(gateway_id) => {
if let Some(awaiting_ips) = gateway_awaiting_connection.remove(&gateway_id) {
for ip in awaiting_ips {
peer.add_allowed_ip(ip);
peers_by_ip.insert(ip, Arc::clone(&peer));
}
}
}
ConnId::Client(_) => {}
ConnId::Resource(_) => {}
}
for ip in peer_config.ips {
peers_by_ip.insert(ip, Arc::clone(&peer));
}
}
if let Some(conn) = self.peer_connections.lock().get(&conn_id) {
set_connection_state_with_peer(self, conn, index, conn_id)
}
data_channel.on_close({
let tunnel = Arc::clone(self);
Box::new(move || {
tracing::debug!("channel_closed");
let tunnel = tunnel.clone();
Box::pin(async move {
tunnel.stop_peer(index, conn_id).await;
})
pub fn on_dc_close_handler(self: Arc<Self>, index: u32, conn_id: ConnId) -> OnCloseHdlrFn {
Box::new(move || {
tracing::debug!("channel_closed");
let tunnel = self.clone();
Box::pin(async move {
tunnel.stop_peer(index, conn_id).await;
})
});
let tunnel = Arc::clone(self);
tokio::spawn(async move { tunnel.start_peer_handler(peer).await });
Ok(())
})
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn new_peer_connection(
self: &Arc<Self>,
relays: Vec<Relay>,
) -> Result<(Arc<RTCPeerConnection>, mpsc::Receiver<RTCIceCandidateInit>)> {
let config = RTCConfiguration {
ice_servers: relays
.into_iter()
.map(|srv| match srv {
Relay::Stun(stun) => RTCIceServer {
urls: vec![stun.uri],
..Default::default()
},
Relay::Turn(turn) => RTCIceServer {
urls: vec![turn.uri],
username: turn.username,
credential: turn.password,
// TODO: check what this is used for
credential_type: RTCIceCredentialType::Password,
},
})
.collect(),
..Default::default()
};
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER);
peer_connection.on_ice_candidate(Box::new(move |candidate| {
let Some(candidate) = candidate else {
return Box::pin(async {});
};
let mut ice_candidate_tx = ice_candidate_tx.clone();
pub fn on_peer_connection_state_change_handler(
self: Arc<Self>,
index: u32,
conn_id: ConnId,
) -> OnPeerConnectionStateChangeHdlrFn {
Box::new(move |state| {
let tunnel = Arc::clone(&self);
Box::pin(async move {
let ice_candidate = match candidate.to_json() {
Ok(ice_candidate) => ice_candidate,
Err(e) => {
tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",);
return;
}
};
if ice_candidate_tx.send(ice_candidate).await.is_err() {
debug_assert!(false, "receiver was dropped before sender")
tracing::trace!(?state, "peer_state_update");
if state == RTCPeerConnectionState::Failed {
tunnel.stop_peer(index, conn_id).await;
}
})
}));
Ok((peer_connection, ice_candidate_rx))
})
}
pub async fn add_ice_candidate(
@@ -223,11 +78,57 @@ where
peer_connection.add_ice_candidate(ice_candidate).await?;
Ok(())
}
/// Clean up a connection to a resource.
// FIXME: this cleanup connection is wrong!
pub fn cleanup_connection(&self, id: ConnId) {
self.awaiting_connection.lock().remove(&id);
self.peer_connections.lock().remove(&id);
}
}
#[tracing::instrument(level = "trace", skip(webrtc))]
pub async fn new_peer_connection(
webrtc: &webrtc::api::API,
relays: Vec<Relay>,
) -> Result<(Arc<RTCPeerConnection>, mpsc::Receiver<RTCIceCandidateInit>)> {
let config = RTCConfiguration {
ice_servers: relays
.into_iter()
.map(|srv| match srv {
Relay::Stun(stun) => RTCIceServer {
urls: vec![stun.uri],
..Default::default()
},
Relay::Turn(turn) => RTCIceServer {
urls: vec![turn.uri],
username: turn.username,
credential: turn.password,
// TODO: check what this is used for
credential_type: RTCIceCredentialType::Password,
},
})
.collect(),
..Default::default()
};
let peer_connection = Arc::new(webrtc.new_peer_connection(config).await?);
let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER);
peer_connection.on_ice_candidate(Box::new(move |candidate| {
let Some(candidate) = candidate else {
return Box::pin(async {});
};
let mut ice_candidate_tx = ice_candidate_tx.clone();
Box::pin(async move {
let ice_candidate = match candidate.to_json() {
Ok(ice_candidate) => ice_candidate,
Err(e) => {
tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",);
return;
}
};
if ice_candidate_tx.send(ice_candidate).await.is_err() {
debug_assert!(false, "receiver was dropped before sender")
}
})
}));
Ok((peer_connection, ice_candidate_rx))
}

View File

@@ -1,10 +1,9 @@
use std::sync::Arc;
use boringtun::x25519::{PublicKey, StaticSecret};
use connlib_shared::messages::SecretKey;
use connlib_shared::{
control::Reference,
messages::{GatewayId, Key, Relay, RequestConnection, ResourceId, ReuseConnection},
messages::{GatewayId, Key, Relay, RequestConnection, ResourceId},
Callbacks,
};
use rand_core::OsRng;
@@ -17,40 +16,32 @@ use webrtc::{
},
};
use crate::{ClientState, ControlSignal, Error, PeerConfig, Request, Result, Tunnel};
use crate::control_protocol::new_peer_connection;
use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB, ClientState>>,
fn handle_connection_state_update<CB>(
tunnel: &Arc<Tunnel<CB, ClientState>>,
state: RTCPeerConnectionState,
gateway_id: GatewayId,
resource_id: ResourceId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
{
tracing::trace!("peer_state");
if state == RTCPeerConnectionState::Failed {
tunnel
.awaiting_connection
.lock()
.remove(&resource_id.into());
tunnel.role_state.lock().on_connection_failed(resource_id);
tunnel.peer_connections.lock().remove(&gateway_id.into());
tunnel
.gateway_awaiting_connection
.lock()
.remove(&gateway_id);
}
}
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB, ClientState>>,
fn set_connection_state_update<CB>(
tunnel: &Arc<Tunnel<CB, ClientState>>,
peer_connection: &Arc<RTCPeerConnection>,
gateway_id: GatewayId,
resource_id: ResourceId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
{
let tunnel = Arc::clone(tunnel);
@@ -64,9 +55,8 @@ fn set_connection_state_update<C, CB>(
));
}
impl<C, CB> Tunnel<C, CB, ClientState>
impl<CB> Tunnel<CB, ClientState>
where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
{
/// Initiate an ice connection request.
@@ -89,77 +79,23 @@ where
reference: Option<Reference>,
) -> Result<Request> {
tracing::trace!("request_connection");
let resource_description = self
.resources
.read()
.get_by_id(&resource_id)
.ok_or(Error::UnknownResource)?
.clone();
let reference: usize = reference
.ok_or(Error::InvalidReference)?
.parse()
.map_err(|_| Error::InvalidReference)?;
{
let mut awaiting_connections = self.awaiting_connection.lock();
let Some(awaiting_connection) = awaiting_connections.get_mut(&resource_id.into())
else {
return Err(Error::UnexpectedConnectionDetails);
};
awaiting_connection.response_received = true;
if awaiting_connection.total_attemps != reference
|| resource_description
.ips()
.iter()
.any(|&ip| self.peers_by_ip.read().exact_match(ip).is_some())
{
return Err(Error::UnexpectedConnectionDetails);
}
if let Some(connection) = self.role_state.lock().attempt_to_reuse_connection(
resource_id,
gateway_id,
reference,
&mut self.peers_by_ip.write(),
)? {
return Ok(Request::ReuseConnection(connection));
}
self.resources_gateways
.lock()
.insert(resource_id, gateway_id);
{
let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock();
if let Some(g) = gateway_awaiting_connection.get_mut(&gateway_id) {
g.extend(resource_description.ips());
return Ok(Request::ReuseConnection(ReuseConnection {
resource_id,
gateway_id,
}));
} else {
gateway_awaiting_connection.insert(gateway_id, vec![]);
}
}
{
let found = {
let mut peers_by_ip = self.peers_by_ip.write();
let peer = peers_by_ip
.iter()
.find_map(|(_, p)| (p.conn_id == gateway_id.into()).then_some(p))
.cloned();
if let Some(peer) = peer {
for ip in resource_description.ips() {
peer.add_allowed_ip(ip);
peers_by_ip.insert(ip, Arc::clone(&peer));
}
true
} else {
false
}
};
if found {
self.awaiting_connection.lock().remove(&resource_id.into());
return Ok(Request::ReuseConnection(ReuseConnection {
resource_id,
gateway_id,
}));
}
}
let peer_connection = {
let (peer_connection, receiver) = self.new_peer_connection(relays).await?;
let (peer_connection, receiver) = new_peer_connection(&self.webrtc_api, relays).await?;
self.role_state
.lock()
.add_waiting_ice_receiver(gateway_id, receiver);
@@ -191,46 +127,63 @@ where
Box::pin(async move {
tracing::trace!("new_data_channel_opened");
let index = tunnel.next_index();
let Some(gateway_public_key) =
tunnel.gateway_public_keys.lock().remove(&gateway_id)
else {
tunnel
.awaiting_connection
.lock()
.remove(&resource_id.into());
tunnel.peer_connections.lock().remove(&gateway_id.into());
tunnel
.gateway_awaiting_connection
.lock()
.remove(&gateway_id);
let e = Error::ControlProtocolError;
tracing::warn!(err = ?e, "channel_open");
let _ = tunnel.callbacks.on_error(&e);
return;
};
let peer_config = PeerConfig {
persistent_keepalive: None,
public_key: gateway_public_key,
ips: resource_description.ips(),
preshared_key: SecretKey::new(Key(p_key.to_bytes())),
let peer_config = match tunnel.role_state.lock().create_peer_config_for_new_connection(resource_id, gateway_id, p_key) {
Ok(c) => c,
Err(e) => {
tunnel.peer_connections.lock().remove(&gateway_id.into());
tracing::warn!(err = ?e, "channel_open");
let _ = tunnel.callbacks.on_error(&e);
return;
}
};
if let Err(e) = tunnel
.handle_channel_open(d, index, peer_config, gateway_id.into(), None)
.await
d.on_close(tunnel.clone().on_dc_close_handler(index, gateway_id.into()));
let peer = Arc::new(Peer::new(
tunnel.private_key.clone(),
index,
peer_config.clone(),
d.detach().await.expect("only fails if not opened or not enabled, both of which are always true for us"),
gateway_id.into(),
None,
));
{
tracing::error!(err = ?e, "channel_open");
let _ = tunnel.callbacks.on_error(&e);
tunnel.peer_connections.lock().remove(&gateway_id.into());
tunnel
.gateway_awaiting_connection
.lock()
.remove(&gateway_id);
let mut role_state = tunnel.role_state.lock();
// Watch out! we need 2 locks, make sure you don't lock both at the same time anywhere else
let mut peers_by_ip = tunnel.peers_by_ip.write();
if let Some(awaiting_ips) =
role_state.gateway_awaiting_connection.remove(&gateway_id)
{
for ip in awaiting_ips {
peer.add_allowed_ip(ip);
peers_by_ip.insert(ip, Arc::clone(&peer));
}
}
for ip in peer_config.ips {
peers_by_ip.insert(ip, Arc::clone(&peer));
}
}
if let Some(conn) = tunnel.peer_connections.lock().get(&gateway_id.into()) {
conn.on_peer_connection_state_change(
tunnel
.clone()
.on_peer_connection_state_change_handler(index, gateway_id.into()),
);
}
tokio::spawn(tunnel.clone().start_peer_handler(peer));
tunnel
.awaiting_connection
.role_state
.lock()
.remove(&resource_id.into());
.awaiting_connection
.remove(&resource_id);
})
}));
@@ -260,10 +213,10 @@ where
rtc_sdp: RTCSessionDescription,
gateway_public_key: PublicKey,
) -> Result<()> {
let gateway_id = *self
.resources_gateways
let gateway_id = self
.role_state
.lock()
.get(&resource_id)
.gateway_by_resource(&resource_id)
.ok_or(Error::UnknownResource)?;
let peer_connection = self
.peer_connections
@@ -271,14 +224,11 @@ where
.get(&gateway_id.into())
.ok_or(Error::UnknownResource)?
.clone();
self.gateway_public_keys
.lock()
.insert(gateway_id, gateway_public_key);
peer_connection.set_remote_description(rtc_sdp).await?;
self.role_state
.lock()
.activate_ice_candidate_receiver(gateway_id);
.activate_ice_candidate_receiver(gateway_id, gateway_public_key);
Ok(())
}

View File

@@ -10,15 +10,15 @@ use webrtc::peer_connection::{
RTCPeerConnection,
};
use crate::{ControlSignal, GatewayState, PeerConfig, Tunnel};
use crate::control_protocol::new_peer_connection;
use crate::{peer::Peer, GatewayState, PeerConfig, Tunnel};
#[tracing::instrument(level = "trace", skip(tunnel))]
fn handle_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB, GatewayState>>,
fn handle_connection_state_update<CB>(
tunnel: &Arc<Tunnel<CB, GatewayState>>,
state: RTCPeerConnectionState,
client_id: ClientId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
{
tracing::trace!(?state, "peer_state");
@@ -28,12 +28,11 @@ fn handle_connection_state_update<C, CB>(
}
#[tracing::instrument(level = "trace", skip(tunnel))]
fn set_connection_state_update<C, CB>(
tunnel: &Arc<Tunnel<C, CB, GatewayState>>,
fn set_connection_state_update<CB>(
tunnel: &Arc<Tunnel<CB, GatewayState>>,
peer_connection: &Arc<RTCPeerConnection>,
client_id: ClientId,
) where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
{
let tunnel = Arc::clone(tunnel);
@@ -45,9 +44,8 @@ fn set_connection_state_update<C, CB>(
));
}
impl<C, CB> Tunnel<C, CB, GatewayState>
impl<CB> Tunnel<CB, GatewayState>
where
C: ControlSignal + Clone + Send + Sync + 'static,
CB: Callbacks + 'static,
{
/// Accept a connection request from a client.
@@ -72,7 +70,7 @@ where
expires_at: DateTime<Utc>,
resource: ResourceDescription,
) -> Result<RTCSessionDescription> {
let (peer_connection, receiver) = self.new_peer_connection(relays).await?;
let (peer_connection, receiver) = new_peer_connection(&self.webrtc_api, relays).await?;
self.role_state
.lock()
.add_new_ice_receiver(client_id, receiver);
@@ -88,12 +86,12 @@ where
peer_connection.on_data_channel(Box::new(move |d| {
tracing::trace!("new_data_channel");
let data_channel = Arc::clone(&d);
let peer = peer.clone();
let peer_config = peer.clone();
let tunnel = Arc::clone(&tunnel);
let resource = resource.clone();
Box::pin(async move {
d.on_open(Box::new(move || {
tracing::trace!("new_data_channel_open");
tracing::trace!(?peer_config.ips, "new_data_channel_open");
Box::pin(async move {
{
let Some(device) = tunnel.device.read().await.clone() else {
@@ -103,7 +101,7 @@ where
return;
};
let iface_config = device.config;
for &ip in &peer.ips {
for &ip in &peer_config.ips {
if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await
{
let _ = tunnel.callbacks.on_error(&e);
@@ -111,28 +109,34 @@ where
}
}
if let Err(e) = tunnel
.handle_channel_open(
data_channel,
index,
peer,
client_id.into(),
Some((resource, expires_at)),
)
.await
{
let _ = tunnel.callbacks.on_error(&e);
tracing::error!(err = ?e, "channel_open");
// Note: handle_channel_open can only error out before insert to peers_by_ip
// otherwise we would need to clean that up too!
let conn = tunnel.peer_connections.lock().remove(&client_id.into());
if let Some(conn) = conn {
if let Err(e) = conn.close().await {
tracing::error!(error = ?e, "webrtc_close_channel");
let _ = tunnel.callbacks().on_error(&e.into());
}
}
data_channel
.on_close(tunnel.clone().on_dc_close_handler(index, client_id.into()));
let peer = Arc::new(Peer::new(
tunnel.private_key.clone(),
index,
peer_config.clone(),
data_channel.detach().await.expect("only fails if not opened or not enabled, both of which are always true for us"),
client_id.into(),
Some((resource, expires_at)),
));
let mut peers_by_ip = tunnel.peers_by_ip.write();
for ip in peer_config.ips {
peers_by_ip.insert(ip, Arc::clone(&peer));
}
if let Some(conn) = tunnel.peer_connections.lock().get(&client_id.into()) {
conn.on_peer_connection_state_change(
tunnel.clone().on_peer_connection_state_change_handler(
index,
client_id.into(),
),
);
}
tokio::spawn(tunnel.clone().start_peer_handler(peer));
})
}))
})

View File

@@ -1,6 +1,6 @@
use crate::device_channel::create_iface;
use crate::{
ControlSignal, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
peer_by_ip, ConnId, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
};
use connlib_shared::error::ConnlibError;
@@ -13,9 +13,8 @@ use std::task::{ready, Context, Poll};
use std::time::Duration;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
impl<C, CB> Tunnel<C, CB, GatewayState>
impl<CB> Tunnel<CB, GatewayState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
/// Sets the interface configuration and starts background tasks.
@@ -35,15 +34,20 @@ where
Ok(())
}
/// Clean up a connection to a resource.
// FIXME: this cleanup connection is wrong!
pub fn cleanup_connection(&self, id: ConnId) {
self.peer_connections.lock().remove(&id);
}
}
/// Reads IP packets from the [`Device`] and handles them accordingly.
async fn device_handler<C, CB>(
tunnel: Arc<Tunnel<C, CB, GatewayState>>,
async fn device_handler<CB>(
tunnel: Arc<Tunnel<CB, GatewayState>>,
mut device: Device,
) -> Result<(), ConnlibError>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
let mut buf = [0u8; MAX_UDP_SIZE];
@@ -55,7 +59,7 @@ where
let dest = packet.destination();
let Some(peer) = tunnel.peer_by_ip(dest) else {
let Some(peer) = peer_by_ip(&tunnel.peers_by_ip.read(), dest) else {
continue;
};

View File

@@ -4,11 +4,10 @@ use boringtun::noise::{errors::WireGuardError, TunnResult};
use bytes::Bytes;
use connlib_shared::{Callbacks, Result};
use crate::{ip_packet::MutableIpPacket, peer::Peer, ControlSignal, RoleState, Tunnel};
use crate::{ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel};
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{

View File

@@ -13,7 +13,6 @@ use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use serde::{Deserialize, Serialize};
use async_trait::async_trait;
use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use peer::{Peer, PeerStats};
@@ -136,29 +135,6 @@ impl From<connlib_shared::messages::Peer> for PeerConfig {
}
}
}
/// Trait used for out-going signals to control plane that are **required** to be made from inside the tunnel.
///
/// Generally, we try to return from the functions here rather than using this callback.
#[async_trait]
pub trait ControlSignal {
/// Signals to the control plane an intent to initiate a connection to the given resource.
///
/// Used when a packet is found to a resource we have no connection stablished but is within the list of resources available for the client.
async fn signal_connection_to(
&self,
resource: &ResourceDescription,
connected_gateway_ids: &[GatewayId],
reference: usize,
) -> Result<()>;
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct AwaitingConnectionDetails {
pub total_attemps: usize,
pub response_received: bool,
}
#[derive(Clone)]
struct Device {
config: Arc<IfaceConfig>,
@@ -190,7 +166,7 @@ impl Device {
// TODO: We should use newtypes for each kind of Id
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets
/// to communicate between peers.
pub struct Tunnel<C: ControlSignal, CB: Callbacks, TRoleState> {
pub struct Tunnel<CB: Callbacks, TRoleState> {
next_index: Mutex<IndexLfsr>,
// We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact
device: tokio::sync::RwLock<Option<Device>>,
@@ -199,13 +175,7 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks, TRoleState> {
public_key: PublicKey,
peers_by_ip: RwLock<IpNetworkTable<Arc<Peer>>>,
peer_connections: Mutex<HashMap<ConnId, Arc<RTCPeerConnection>>>,
awaiting_connection: Mutex<HashMap<ConnId, AwaitingConnectionDetails>>,
gateway_awaiting_connection: Mutex<HashMap<GatewayId, Vec<IpNetwork>>>,
resources_gateways: Mutex<HashMap<ResourceId, GatewayId>>,
webrtc_api: API,
resources: Arc<RwLock<ResourceTable<ResourceDescription>>>,
control_signaler: C,
gateway_public_keys: Mutex<HashMap<GatewayId, PublicKey>>,
callbacks: CallbackErrorFacade<CB>,
iface_handler_abort: Mutex<Option<AbortHandle>>,
@@ -220,18 +190,10 @@ pub struct TunnelStats {
public_key: String,
peers_by_ip: HashMap<IpNetwork, PeerStats>,
peer_connections: Vec<ConnId>,
resource_gateways: HashMap<ResourceId, GatewayId>,
dns_resources: HashMap<String, ResourceDescription>,
network_resources: HashMap<IpNetwork, ResourceDescription>,
gateway_public_keys: HashMap<GatewayId, String>,
awaiting_connection: HashMap<ConnId, AwaitingConnectionDetails>,
gateway_awaiting_connection: HashMap<GatewayId, Vec<IpNetwork>>,
}
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
@@ -243,30 +205,11 @@ where
.map(|(ip, peer)| (ip, peer.stats()))
.collect();
let peer_connections = self.peer_connections.lock().keys().cloned().collect();
let awaiting_connection = self.awaiting_connection.lock().clone();
let gateway_awaiting_connection = self.gateway_awaiting_connection.lock().clone();
let resource_gateways = self.resources_gateways.lock().clone();
let (network_resources, dns_resources) = {
let resources = self.resources.read();
(resources.network_resources(), resources.dns_resources())
};
let gateway_public_keys = self
.gateway_public_keys
.lock()
.iter()
.map(|(&id, &k)| (id, Key::from(k).to_string()))
.collect();
TunnelStats {
public_key: Key::from(self.public_key).to_string(),
peers_by_ip,
peer_connections,
awaiting_connection,
gateway_awaiting_connection,
resource_gateways,
dns_resources,
network_resources,
gateway_public_keys,
}
}
@@ -277,14 +220,10 @@ where
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
self.role_state.lock().poll_next_event(cx)
}
}
pub(crate) fn peer_by_ip(&self, ip: IpAddr) -> Option<Arc<Peer>> {
self.peers_by_ip
.read()
.longest_match(ip)
.map(|(_, peer)| peer)
.cloned()
}
pub(crate) fn peer_by_ip(peers_by_ip: &IpNetworkTable<Arc<Peer>>, ip: IpAddr) -> Option<Arc<Peer>> {
peers_by_ip.longest_match(ip).map(|(_, peer)| peer).cloned()
}
pub enum Event<TId> {
@@ -292,11 +231,15 @@ pub enum Event<TId> {
conn_id: TId,
candidate: RTCIceCandidateInit,
},
ConnectionIntent {
resource: ResourceDescription,
connected_gateway_ids: Vec<GatewayId>,
reference: usize,
},
}
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
@@ -305,22 +248,14 @@ where
/// # Parameters
/// - `private_key`: wireguard's private key.
/// - `control_signaler`: this is used to send SDP from the tunnel to the control plane.
#[tracing::instrument(level = "trace", skip(private_key, control_signaler, callbacks))]
pub async fn new(
private_key: StaticSecret,
control_signaler: C,
callbacks: CB,
) -> Result<Self> {
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
pub async fn new(private_key: StaticSecret, callbacks: CB) -> Result<Self> {
let public_key = (&private_key).into();
let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT));
let peers_by_ip = RwLock::new(IpNetworkTable::new());
let next_index = Default::default();
let peer_connections = Default::default();
let resources: Arc<RwLock<ResourceTable<ResourceDescription>>> = Default::default();
let awaiting_connection = Default::default();
let gateway_public_keys = Default::default();
let resources_gateways = Default::default();
let gateway_awaiting_connection = Default::default();
let device = Default::default();
let iface_handler_abort = Default::default();
@@ -347,7 +282,6 @@ where
.build();
Ok(Self {
gateway_public_keys,
rate_limiter,
private_key,
peer_connections,
@@ -355,12 +289,7 @@ where
peers_by_ip,
next_index,
webrtc_api,
resources,
device,
awaiting_connection,
gateway_awaiting_connection,
control_signaler,
resources_gateways,
callbacks: CallbackErrorFacade(callbacks),
iface_handler_abort,
role_state: Default::default(),
@@ -411,31 +340,6 @@ where
});
}
fn remove_expired_peers(self: &Arc<Self>) {
let mut peers_by_ip = self.peers_by_ip.write();
for (_, peer) in peers_by_ip.iter() {
peer.expire_resources();
if peer.is_emptied() {
tracing::trace!(index = peer.index, "peer_expired");
let conn = self.peer_connections.lock().remove(&peer.conn_id);
let p = peer.clone();
// We are holding a Mutex, particularly a write one, we don't want to make a blocking call
tokio::spawn(async move {
let _ = p.shutdown().await;
if let Some(conn) = conn {
// TODO: it seems that even closing the stream there are messages to the relay
// see where they come from.
let _ = conn.close().await;
}
});
}
}
peers_by_ip.retain(|_, p| !p.is_emptied());
}
fn start_peers_refresh_timer(self: &Arc<Self>) {
let tunnel = self.clone();
@@ -445,7 +349,10 @@ where
let mut dst_buf = [0u8; MAX_UDP_SIZE];
loop {
tunnel.remove_expired_peers();
remove_expired_peers(
&mut tunnel.peers_by_ip.write(),
&mut tunnel.peer_connections.lock(),
);
let peers: Vec<_> = tunnel
.peers_by_ip
@@ -497,14 +404,6 @@ where
Ok(())
}
fn get_resource(&self, addr: IpAddr) -> Option<ResourceDescription> {
let resources = self.resources.read();
match addr {
IpAddr::V4(ipv4) => resources.get_by_ip(ipv4).cloned(),
IpAddr::V6(ipv6) => resources.get_by_ip(ipv6).cloned(),
}
}
fn next_index(&self) -> u32 {
self.next_index.lock().next()
}
@@ -514,6 +413,32 @@ where
}
}
fn remove_expired_peers(
peers_by_ip: &mut IpNetworkTable<Arc<Peer>>,
peer_connections: &mut HashMap<ConnId, Arc<RTCPeerConnection>>,
) {
for (_, peer) in peers_by_ip.iter() {
peer.expire_resources();
if peer.is_emptied() {
tracing::trace!(index = peer.index, "peer_expired");
let conn = peer_connections.remove(&peer.conn_id);
let p = peer.clone();
// We are holding a Mutex, particularly a write one, we don't want to make a blocking call
tokio::spawn(async move {
let _ = p.shutdown().await;
if let Some(conn) = conn {
// TODO: it seems that even closing the stream there are messages to the relay
// see where they come from.
let _ = conn.close().await;
}
});
}
}
peers_by_ip.retain(|_, p| !p.is_emptied());
}
/// Dedicated trait for abstracting over the different ICE states.
///
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].

View File

@@ -1,6 +1,7 @@
use std::{collections::HashMap, net::IpAddr, sync::Arc};
use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::StaticSecret;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use connlib_shared::{
@@ -11,11 +12,10 @@ use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use parking_lot::{Mutex, RwLock};
use pnet_packet::MutablePacket;
use secrecy::ExposeSecret;
use webrtc::data::data_channel::DataChannel;
use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId};
use super::PeerConfig;
use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId, PeerConfig};
type ExpiryingResource = (ResourceDescription, DateTime<Utc>);
@@ -87,34 +87,26 @@ impl Peer {
}
}
pub(crate) fn from_config(
tunnel: Tunn,
index: u32,
config: &PeerConfig,
channel: Arc<DataChannel>,
conn_id: ConnId,
resource: Option<(ResourceDescription, DateTime<Utc>)>,
) -> Self {
Self::new(
Mutex::new(tunnel),
index,
config.ips.clone(),
channel,
conn_id,
resource,
)
}
pub(crate) fn new(
tunnel: Mutex<Tunn>,
private_key: StaticSecret,
index: u32,
ips: Vec<IpNetwork>,
peer_config: PeerConfig,
channel: Arc<DataChannel>,
conn_id: ConnId,
resource: Option<(ResourceDescription, DateTime<Utc>)>,
) -> Peer {
let tunnel = Tunn::new(
private_key.clone(),
peer_config.public_key,
Some(peer_config.preshared_key.expose_secret().0),
peer_config.persistent_keepalive,
index,
None,
)
.expect("never actually fails"); // See https://github.com/cloudflare/boringtun/pull/366.
let mut allowed_ips = IpNetworkTable::new();
for ip in ips {
for ip in peer_config.ips {
allowed_ips.insert(ip, ());
}
let allowed_ips = RwLock::new(allowed_ips);
@@ -123,8 +115,9 @@ impl Peer {
resource_table.insert(r);
RwLock::new(resource_table)
});
Peer {
tunnel,
tunnel: Mutex::new(tunnel),
index,
allowed_ips,
channel,

View File

@@ -5,13 +5,12 @@ use bytes::Bytes;
use connlib_shared::{Callbacks, Error, Result};
use crate::{
device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, RoleState,
Tunnel, MAX_UDP_SIZE,
device_channel::DeviceIo, index::check_packet_index, peer::Peer, RoleState, Tunnel,
MAX_UDP_SIZE,
};
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
TRoleState: RoleState,
{
@@ -155,7 +154,7 @@ where
Ok(())
}
pub(crate) async fn start_peer_handler(self: &Arc<Self>, peer: Arc<Peer>) {
pub(crate) async fn start_peer_handler(self: Arc<Self>, peer: Arc<Peer>) {
loop {
let Some(device) = self.device.read().await.clone() else {
let err = Error::NoIface;

View File

@@ -3,15 +3,12 @@ use std::{
sync::Arc,
};
use crate::{
device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, ControlSignal, Tunnel,
};
use crate::{device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, Tunnel};
use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result};
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
impl<CB, TRoleState> Tunnel<CB, TRoleState>
where
C: ControlSignal + Send + Sync + 'static,
CB: Callbacks + 'static,
{
#[inline(always)]

View File

@@ -1,22 +0,0 @@
use async_trait::async_trait;
use connlib_shared::{
messages::{GatewayId, ResourceDescription},
Result,
};
use firezone_tunnel::ControlSignal;
#[derive(Clone)]
pub struct ControlSignaler;
#[async_trait]
impl ControlSignal for ControlSignaler {
async fn signal_connection_to(
&self,
resource: &ResourceDescription,
_connected_gateway_ids: &[GatewayId],
_: usize,
) -> Result<()> {
tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients.");
Ok(())
}
}

View File

@@ -1,4 +1,3 @@
use crate::control::ControlSignaler;
use crate::messages::{
AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady,
EgressMessages, IngressMessages,
@@ -7,7 +6,7 @@ use crate::CallbackHandler;
use anyhow::Result;
use connlib_shared::messages::ClientId;
use connlib_shared::Error;
use firezone_tunnel::{GatewayState, Tunnel};
use firezone_tunnel::{Event, GatewayState, Tunnel};
use phoenix_channel::PhoenixChannel;
use std::convert::Infallible;
use std::sync::Arc;
@@ -18,7 +17,7 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
pub const PHOENIX_TOPIC: &str = "gateway";
pub struct Eventloop {
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
portal: PhoenixChannel<IngressMessages, ()>,
// TODO: Strongly type request reference (currently `String`)
@@ -31,7 +30,7 @@ pub struct Eventloop {
impl Eventloop {
pub(crate) fn new(
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
portal: PhoenixChannel<IngressMessages, ()>,
) -> Self {
Self {
@@ -190,6 +189,9 @@ impl Eventloop {
);
continue;
}
Poll::Ready(Event::ConnectionIntent { .. }) => {
unreachable!("Not used on the gateway, split the events!")
}
Poll::Pending => {}
}

View File

@@ -1,4 +1,3 @@
use crate::control::ControlSignaler;
use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
use crate::messages::InitGateway;
use anyhow::{Context, Result};
@@ -15,7 +14,6 @@ use std::sync::Arc;
use tracing_subscriber::layer;
use url::Url;
mod control;
mod eventloop;
mod messages;
@@ -30,7 +28,7 @@ async fn main() -> Result<()> {
SecretString::new(cli.common.secret),
get_device_id(),
)?;
let tunnel = Arc::new(Tunnel::new(private_key, ControlSignaler, CallbackHandler).await?);
let tunnel = Arc::new(Tunnel::new(private_key, CallbackHandler).await?);
tokio::spawn(backoff::future::retry_notify(
ExponentialBackoffBuilder::default()
@@ -48,7 +46,7 @@ async fn main() -> Result<()> {
}
async fn run(
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
connect_url: Url,
) -> Result<Infallible> {
let (portal, init) = phoenix_channel::init::<InitGateway, _, _>(