mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): remove ControlSignal (#2321)
This commit is contained in:
@@ -12,7 +12,7 @@ RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef as builder
|
||||
COPY --from=planner /build/recipe.json .
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
RUN cargo chef cook --all-targets --release --recipe-path recipe.json
|
||||
COPY . .
|
||||
ARG PACKAGE
|
||||
RUN cargo build -p $PACKAGE --release
|
||||
|
||||
@@ -14,48 +14,19 @@ use connlib_shared::{
|
||||
Result,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use firezone_tunnel::{ClientState, ControlSignal, Request, Tunnel};
|
||||
use firezone_tunnel::{ClientState, Request, Tunnel};
|
||||
use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE};
|
||||
use tokio::io::BufReader;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
use url::Url;
|
||||
|
||||
#[async_trait]
|
||||
impl ControlSignal for ControlSignaler {
|
||||
async fn signal_connection_to(
|
||||
&self,
|
||||
resource: &ResourceDescription,
|
||||
connected_gateway_ids: &[GatewayId],
|
||||
reference: usize,
|
||||
) -> Result<()> {
|
||||
self.control_signal
|
||||
// It's easier if self is not mut
|
||||
.clone()
|
||||
.send_with_ref(
|
||||
EgressMessages::PrepareConnection {
|
||||
resource_id: resource.id(),
|
||||
connected_gateway_ids: connected_gateway_ids.to_vec(),
|
||||
},
|
||||
reference,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ControlPlane<CB: Callbacks> {
|
||||
pub tunnel: Arc<Tunnel<ControlSignaler, CB, ClientState>>,
|
||||
pub control_signaler: ControlSignaler,
|
||||
pub tunnel: Arc<Tunnel<CB, ClientState>>,
|
||||
pub phoenix_channel: PhoenixSenderWithTopic,
|
||||
pub tunnel_init: Mutex<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ControlSignaler {
|
||||
pub control_signal: PhoenixSenderWithTopic,
|
||||
}
|
||||
|
||||
impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
async fn init(
|
||||
@@ -139,7 +110,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
reference: Option<Reference>,
|
||||
) {
|
||||
let tunnel = Arc::clone(&self.tunnel);
|
||||
let mut control_signaler = self.control_signaler.clone();
|
||||
let mut control_signaler = self.phoenix_channel.clone();
|
||||
tokio::spawn(async move {
|
||||
let err = match tunnel
|
||||
.request_connection(resource_id, gateway_id, relays, reference)
|
||||
@@ -147,7 +118,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
{
|
||||
Ok(Request::NewConnection(connection_request)) => {
|
||||
if let Err(err) = control_signaler
|
||||
.control_signal
|
||||
// TODO: create a reference number and keep track for the response
|
||||
.send_with_ref(
|
||||
EgressMessages::RequestConnection(connection_request),
|
||||
@@ -162,7 +132,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
}
|
||||
Ok(Request::ReuseConnection(connection_request)) => {
|
||||
if let Err(err) = control_signaler
|
||||
.control_signal
|
||||
// TODO: create a reference number and keep track for the response
|
||||
.send_with_ref(
|
||||
EgressMessages::ReuseConnection(connection_request),
|
||||
@@ -178,7 +147,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
tunnel.cleanup_connection(resource_id.into());
|
||||
tunnel.cleanup_connection(resource_id);
|
||||
tracing::error!("Error request connection details: {err}");
|
||||
let _ = tunnel.callbacks().on_error(&err);
|
||||
});
|
||||
@@ -250,7 +219,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
return;
|
||||
};
|
||||
// TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection
|
||||
self.tunnel.cleanup_connection(resource_id.into());
|
||||
self.tunnel.cleanup_connection(resource_id);
|
||||
}
|
||||
None => {
|
||||
tracing::error!(
|
||||
@@ -273,8 +242,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
tracing::info!("Requesting log upload URL from portal");
|
||||
|
||||
let _ = self
|
||||
.control_signaler
|
||||
.control_signal
|
||||
.phoenix_channel
|
||||
.send(EgressMessages::CreateLogSink {})
|
||||
.await;
|
||||
}
|
||||
@@ -283,8 +251,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
match event {
|
||||
firezone_tunnel::Event::SignalIceCandidate { conn_id, candidate } => {
|
||||
if let Err(e) = self
|
||||
.control_signaler
|
||||
.control_signal
|
||||
.phoenix_channel
|
||||
.send(EgressMessages::BroadcastIceCandidates(
|
||||
BroadcastGatewayIceCandidates {
|
||||
gateway_ids: vec![conn_id],
|
||||
@@ -296,6 +263,28 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
tracing::error!("Failed to signal ICE candidate: {e}")
|
||||
}
|
||||
}
|
||||
firezone_tunnel::Event::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids,
|
||||
reference,
|
||||
} => {
|
||||
if let Err(e) = self
|
||||
.phoenix_channel
|
||||
.clone()
|
||||
.send_with_ref(
|
||||
EgressMessages::PrepareConnection {
|
||||
resource_id: resource.id(),
|
||||
connected_gateway_ids: connected_gateway_ids.to_vec(),
|
||||
},
|
||||
reference,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to prepare connection: {e}");
|
||||
|
||||
// TODO: Clean up connection in `ClientState` here?
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ pub use connlib_shared::{get_device_id, messages::ResourceDescription};
|
||||
pub use connlib_shared::{Callbacks, Error};
|
||||
pub use tracing_appender::non_blocking::WorkerGuard;
|
||||
|
||||
use crate::control::ControlSignaler;
|
||||
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
|
||||
use connlib_shared::control::SecureUrl;
|
||||
use connlib_shared::{control::PhoenixChannel, login_url, CallbackErrorFacade, Mode, Result};
|
||||
@@ -149,16 +148,15 @@ where
|
||||
}
|
||||
});
|
||||
|
||||
let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("client".to_owned()) };
|
||||
let tunnel = fatal_error!(
|
||||
Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await,
|
||||
Tunnel::new(private_key, callbacks.clone()).await,
|
||||
runtime_stopper,
|
||||
&callbacks
|
||||
);
|
||||
|
||||
let mut control_plane = ControlPlane {
|
||||
tunnel: Arc::new(tunnel),
|
||||
control_signaler,
|
||||
phoenix_channel: connection.sender_with_topic("client".to_owned()),
|
||||
tunnel_init: Mutex::new(false),
|
||||
};
|
||||
|
||||
|
||||
@@ -1,25 +1,34 @@
|
||||
use crate::device_channel::{create_iface, DeviceIo};
|
||||
use crate::ip_packet::IpPacket;
|
||||
use crate::peer::Peer;
|
||||
use crate::resource_table::ResourceTable;
|
||||
use crate::{
|
||||
dns, tokio_util, ConnId, ControlSignal, Device, Event, RoleState, Tunnel,
|
||||
dns, peer_by_ip, tokio_util, Device, Event, PeerConfig, RoleState, Tunnel,
|
||||
ICE_GATHERING_TIMEOUT_SECONDS, MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
|
||||
};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
|
||||
use connlib_shared::messages::{GatewayId, Interface as InterfaceConfig, ResourceDescription};
|
||||
use connlib_shared::messages::{
|
||||
GatewayId, Interface as InterfaceConfig, Key, ResourceDescription, ResourceId, ReuseConnection,
|
||||
SecretKey,
|
||||
};
|
||||
use connlib_shared::{Callbacks, DNS_SENTINEL};
|
||||
use futures::channel::mpsc::Receiver;
|
||||
use futures::stream;
|
||||
use futures_bounded::{PushError, StreamMap};
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
|
||||
impl<C, CB> Tunnel<C, CB, ClientState>
|
||||
impl<CB> Tunnel<CB, ClientState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
/// Adds a the given resource to the tunnel.
|
||||
@@ -47,9 +56,9 @@ where
|
||||
}
|
||||
|
||||
let resource_list = {
|
||||
let mut resources = self.resources.write();
|
||||
resources.insert(resource_description);
|
||||
resources.resource_list()
|
||||
let mut role_state = self.role_state.lock();
|
||||
role_state.resources.insert(resource_description);
|
||||
role_state.resources.resource_list()
|
||||
};
|
||||
|
||||
self.callbacks.on_update_resources(resource_list)?;
|
||||
@@ -80,6 +89,13 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up a connection to a resource.
|
||||
// FIXME: this cleanup connection is wrong!
|
||||
pub fn cleanup_connection(&self, id: ResourceId) {
|
||||
self.role_state.lock().on_connection_failed(id);
|
||||
self.peer_connections.lock().remove(&id.into());
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
async fn add_route(self: &Arc<Self>, route: IpNetwork) -> connlib_shared::Result<()> {
|
||||
let mut device = self.device.write().await;
|
||||
@@ -100,80 +116,14 @@ where
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn connection_intent(self: &Arc<Self>, packet: IpPacket<'_>) {
|
||||
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
|
||||
|
||||
// We can buffer requests here but will drop them for now and let the upper layer reliability protocol handle this
|
||||
if let Some(resource) = self.get_resource(packet.destination()) {
|
||||
// We have awaiting connection to prevent a race condition where
|
||||
// create_peer_connection hasn't added the thing to peer_connections
|
||||
// and we are finding another packet to the same address (otherwise we would just use peer_connections here)
|
||||
let mut awaiting_connection = self.awaiting_connection.lock();
|
||||
let conn_id = ConnId::from(resource.id());
|
||||
if awaiting_connection.get(&conn_id).is_none() {
|
||||
tracing::trace!(
|
||||
resource_ip = %packet.destination(),
|
||||
"resource_connection_intent",
|
||||
);
|
||||
|
||||
awaiting_connection.insert(conn_id, Default::default());
|
||||
let dev = Arc::clone(self);
|
||||
|
||||
let mut connected_gateway_ids: Vec<_> = dev
|
||||
.gateway_awaiting_connection
|
||||
.lock()
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect();
|
||||
connected_gateway_ids
|
||||
.extend(dev.resources_gateways.lock().values().collect::<Vec<_>>());
|
||||
tracing::trace!(
|
||||
gateways = ?connected_gateway_ids,
|
||||
"connected_gateways"
|
||||
);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let reference = {
|
||||
let mut awaiting_connections = dev.awaiting_connection.lock();
|
||||
let Some(awaiting_connection) =
|
||||
awaiting_connections.get_mut(&ConnId::from(resource.id()))
|
||||
else {
|
||||
break;
|
||||
};
|
||||
if awaiting_connection.response_received {
|
||||
break;
|
||||
}
|
||||
awaiting_connection.total_attemps += 1;
|
||||
awaiting_connection.total_attemps
|
||||
};
|
||||
if let Err(e) = dev
|
||||
.control_signaler
|
||||
.signal_connection_to(&resource, &connected_gateway_ids, reference)
|
||||
.await
|
||||
{
|
||||
// Not a deadlock because this is a different task
|
||||
dev.awaiting_connection.lock().remove(&conn_id);
|
||||
tracing::error!(error = ?e, "start_resource_connection");
|
||||
let _ = dev.callbacks.on_error(&e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads IP packets from the [`Device`] and handles them accordingly.
|
||||
async fn device_handler<C, CB>(
|
||||
tunnel: Arc<Tunnel<C, CB, ClientState>>,
|
||||
async fn device_handler<CB>(
|
||||
tunnel: Arc<Tunnel<CB, ClientState>>,
|
||||
mut device: Device,
|
||||
) -> Result<(), ConnlibError>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let device_writer = device.io.clone();
|
||||
@@ -183,7 +133,9 @@ where
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let Some(dns_packet) = dns::parse(&tunnel.resources.read(), packet.as_immutable()) {
|
||||
if let Some(dns_packet) =
|
||||
dns::parse(&tunnel.role_state.lock().resources, packet.as_immutable())
|
||||
{
|
||||
if let Err(e) = send_dns_packet(&device_writer, dns_packet) {
|
||||
tracing::error!(err = %e, "failed to send DNS packet");
|
||||
let _ = tunnel.callbacks.on_error(&e.into());
|
||||
@@ -194,8 +146,11 @@ where
|
||||
|
||||
let dest = packet.destination();
|
||||
|
||||
let Some(peer) = tunnel.peer_by_ip(dest) else {
|
||||
tunnel.connection_intent(packet.as_immutable());
|
||||
let Some(peer) = peer_by_ip(&tunnel.peers_by_ip.read(), dest) else {
|
||||
tunnel
|
||||
.role_state
|
||||
.lock()
|
||||
.on_connection_intent(packet.destination());
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -223,9 +178,190 @@ pub struct ClientState {
|
||||
active_candidate_receivers: StreamMap<GatewayId, RTCIceCandidateInit>,
|
||||
/// We split the receivers of ICE candidates into two phases because we only want to start sending them once we've received an SDP from the gateway.
|
||||
waiting_for_sdp_from_gatway: HashMap<GatewayId, Receiver<RTCIceCandidateInit>>,
|
||||
|
||||
// TODO: Make private
|
||||
pub awaiting_connection: HashMap<ResourceId, AwaitingConnectionDetails>,
|
||||
pub gateway_awaiting_connection: HashMap<GatewayId, Vec<IpNetwork>>,
|
||||
|
||||
awaiting_connection_timers: StreamMap<ResourceId, Instant>,
|
||||
|
||||
pub gateway_public_keys: HashMap<GatewayId, PublicKey>,
|
||||
resources_gateways: HashMap<ResourceId, GatewayId>,
|
||||
resources: ResourceTable<ResourceDescription>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AwaitingConnectionDetails {
|
||||
total_attemps: usize,
|
||||
response_received: bool,
|
||||
gateways: Vec<GatewayId>,
|
||||
}
|
||||
|
||||
impl ClientState {
|
||||
pub(crate) fn attempt_to_reuse_connection(
|
||||
&mut self,
|
||||
resource: ResourceId,
|
||||
gateway: GatewayId,
|
||||
expected_attempts: usize,
|
||||
connected_peers: &mut IpNetworkTable<Arc<Peer>>,
|
||||
) -> Result<Option<ReuseConnection>, ConnlibError> {
|
||||
if self.is_connected_to(resource, connected_peers) {
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
}
|
||||
|
||||
let desc = self
|
||||
.resources
|
||||
.get_by_id(&resource)
|
||||
.ok_or(Error::UnknownResource)?;
|
||||
|
||||
let details = self
|
||||
.awaiting_connection
|
||||
.get_mut(&resource)
|
||||
.ok_or(Error::UnexpectedConnectionDetails)?;
|
||||
|
||||
details.response_received = true;
|
||||
|
||||
if details.total_attemps != expected_attempts {
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
}
|
||||
|
||||
self.resources_gateways.insert(resource, gateway);
|
||||
|
||||
match self.gateway_awaiting_connection.entry(gateway) {
|
||||
Entry::Occupied(mut occupied) => {
|
||||
occupied.get_mut().extend(desc.ips());
|
||||
return Ok(Some(ReuseConnection {
|
||||
resource_id: resource,
|
||||
gateway_id: gateway,
|
||||
}));
|
||||
}
|
||||
Entry::Vacant(vacant) => {
|
||||
vacant.insert(vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
let found = {
|
||||
let peer = connected_peers
|
||||
.iter()
|
||||
.find_map(|(_, p)| (p.conn_id == gateway.into()).then_some(p))
|
||||
.cloned();
|
||||
if let Some(peer) = peer {
|
||||
for ip in desc.ips() {
|
||||
peer.add_allowed_ip(ip);
|
||||
connected_peers.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if found {
|
||||
self.awaiting_connection.remove(&resource);
|
||||
self.awaiting_connection_timers.remove(resource);
|
||||
|
||||
return Ok(Some(ReuseConnection {
|
||||
resource_id: resource,
|
||||
gateway_id: gateway,
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn on_connection_failed(&mut self, resource: ResourceId) {
|
||||
self.awaiting_connection.remove(&resource);
|
||||
let Some(gateway) = self.resources_gateways.remove(&resource) else {
|
||||
return;
|
||||
};
|
||||
self.gateway_awaiting_connection.remove(&gateway);
|
||||
self.awaiting_connection_timers.remove(resource);
|
||||
}
|
||||
|
||||
pub fn on_connection_intent(&mut self, destination: IpAddr) {
|
||||
if self.is_awaiting_connection_to(destination) {
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::trace!(resource_ip = %destination, "resource_connection_intent");
|
||||
|
||||
let Some(resource) = self.get_resource_by_destination(destination) else {
|
||||
return;
|
||||
};
|
||||
|
||||
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
|
||||
|
||||
let resource_id = resource.id();
|
||||
|
||||
let connected_gateway_ids = self
|
||||
.gateway_awaiting_connection
|
||||
.clone()
|
||||
.into_keys()
|
||||
.chain(self.resources_gateways.values().cloned())
|
||||
.collect();
|
||||
|
||||
tracing::trace!(
|
||||
gateways = ?connected_gateway_ids,
|
||||
"connected_gateways"
|
||||
);
|
||||
|
||||
match self.awaiting_connection_timers.try_push(
|
||||
resource_id,
|
||||
stream::poll_fn({
|
||||
let mut interval = tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
|
||||
move |cx| interval.poll_tick(cx).map(Some)
|
||||
}),
|
||||
) {
|
||||
Ok(()) => {}
|
||||
Err(PushError::BeyondCapacity(_)) => {
|
||||
tracing::warn!(%resource_id, "Too many concurrent connection attempts");
|
||||
return;
|
||||
}
|
||||
Err(PushError::Replaced(_)) => {
|
||||
// The timers are equivalent for our purpose so we don't really care about this one.
|
||||
}
|
||||
}
|
||||
|
||||
self.awaiting_connection.insert(
|
||||
resource_id,
|
||||
AwaitingConnectionDetails {
|
||||
total_attemps: 0,
|
||||
response_received: false,
|
||||
gateways: connected_gateway_ids,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn create_peer_config_for_new_connection(
|
||||
&mut self,
|
||||
resource: ResourceId,
|
||||
gateway: GatewayId,
|
||||
shared_key: StaticSecret,
|
||||
) -> Result<PeerConfig, ConnlibError> {
|
||||
let Some(public_key) = self.gateway_public_keys.remove(&gateway) else {
|
||||
self.awaiting_connection.remove(&resource);
|
||||
self.gateway_awaiting_connection.remove(&gateway);
|
||||
|
||||
return Err(Error::ControlProtocolError);
|
||||
};
|
||||
|
||||
let desc = self
|
||||
.resources
|
||||
.get_by_id(&resource)
|
||||
.ok_or(Error::ControlProtocolError)?;
|
||||
|
||||
Ok(PeerConfig {
|
||||
persistent_keepalive: None,
|
||||
public_key,
|
||||
ips: desc.ips(),
|
||||
preshared_key: SecretKey::new(Key(shared_key.to_bytes())),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option<GatewayId> {
|
||||
self.resources_gateways.get(resource).copied()
|
||||
}
|
||||
|
||||
pub fn add_waiting_ice_receiver(
|
||||
&mut self,
|
||||
id: GatewayId,
|
||||
@@ -234,10 +370,11 @@ impl ClientState {
|
||||
self.waiting_for_sdp_from_gatway.insert(id, receiver);
|
||||
}
|
||||
|
||||
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId) {
|
||||
pub fn activate_ice_candidate_receiver(&mut self, id: GatewayId, key: PublicKey) {
|
||||
let Some(receiver) = self.waiting_for_sdp_from_gatway.remove(&id) else {
|
||||
return;
|
||||
};
|
||||
self.gateway_public_keys.insert(id, key);
|
||||
|
||||
match self.active_candidate_receivers.try_push(id, receiver) {
|
||||
Ok(()) => {}
|
||||
@@ -249,6 +386,36 @@ impl ClientState {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_awaiting_connection_to(&self, destination: IpAddr) -> bool {
|
||||
let Some(resource) = self.get_resource_by_destination(destination) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
self.awaiting_connection.contains_key(&resource.id())
|
||||
}
|
||||
|
||||
fn is_connected_to(
|
||||
&self,
|
||||
resource: ResourceId,
|
||||
connected_peers: &IpNetworkTable<Arc<Peer>>,
|
||||
) -> bool {
|
||||
let Some(resource) = self.resources.get_by_id(&resource) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
resource
|
||||
.ips()
|
||||
.iter()
|
||||
.any(|ip| connected_peers.exact_match(*ip).is_some())
|
||||
}
|
||||
|
||||
fn get_resource_by_destination(&self, destination: IpAddr) -> Option<&ResourceDescription> {
|
||||
match destination {
|
||||
IpAddr::V4(ipv4) => self.resources.get_by_ip(ipv4),
|
||||
IpAddr::V6(ipv6) => self.resources.get_by_ip(ipv6),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientState {
|
||||
@@ -259,6 +426,12 @@ impl Default for ClientState {
|
||||
MAX_CONCURRENT_ICE_GATHERING,
|
||||
),
|
||||
waiting_for_sdp_from_gatway: Default::default(),
|
||||
awaiting_connection: Default::default(),
|
||||
gateway_awaiting_connection: Default::default(),
|
||||
awaiting_connection_timers: StreamMap::new(Duration::from_secs(60), 100),
|
||||
gateway_public_keys: Default::default(),
|
||||
resources_gateways: Default::default(),
|
||||
resources: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,18 +441,60 @@ impl RoleState for ClientState {
|
||||
|
||||
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<Self::Id>> {
|
||||
loop {
|
||||
match ready!(self.active_candidate_receivers.poll_next_unpin(cx)) {
|
||||
(conn_id, Some(Ok(c))) => {
|
||||
match self.active_candidate_receivers.poll_next_unpin(cx) {
|
||||
Poll::Ready((conn_id, Some(Ok(c)))) => {
|
||||
return Poll::Ready(Event::SignalIceCandidate {
|
||||
conn_id,
|
||||
candidate: c,
|
||||
})
|
||||
}
|
||||
(id, Some(Err(e))) => {
|
||||
Poll::Ready((id, Some(Err(e)))) => {
|
||||
tracing::warn!(gateway_id = %id, "ICE gathering timed out: {e}")
|
||||
}
|
||||
(_, None) => {}
|
||||
Poll::Ready((_, None)) => continue,
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
match self.awaiting_connection_timers.poll_next_unpin(cx) {
|
||||
Poll::Ready((resource, Some(Ok(_)))) => {
|
||||
let Entry::Occupied(mut entry) = self.awaiting_connection.entry(resource)
|
||||
else {
|
||||
self.awaiting_connection_timers.remove(resource);
|
||||
|
||||
continue;
|
||||
};
|
||||
|
||||
if entry.get().response_received {
|
||||
self.awaiting_connection_timers.remove(resource);
|
||||
|
||||
// entry.remove(); Maybe?
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
entry.get_mut().total_attemps += 1;
|
||||
|
||||
let reference = entry.get_mut().total_attemps;
|
||||
|
||||
return Poll::Ready(Event::ConnectionIntent {
|
||||
resource: self
|
||||
.resources
|
||||
.get_by_id(&resource)
|
||||
.expect("inconsistent internal state")
|
||||
.clone(),
|
||||
connected_gateway_ids: entry.get().gateways.clone(),
|
||||
reference,
|
||||
});
|
||||
}
|
||||
|
||||
Poll::Ready((id, Some(Err(e)))) => {
|
||||
tracing::warn!(resource_id = %id, "Connection establishment timeout: {e}")
|
||||
}
|
||||
Poll::Ready((_, None)) => continue,
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
use boringtun::noise::Tunn;
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::channel::mpsc;
|
||||
use futures_util::SinkExt;
|
||||
use secrecy::ExposeSecret;
|
||||
use std::sync::Arc;
|
||||
use tracing::instrument;
|
||||
|
||||
use connlib_shared::{
|
||||
messages::{Relay, RequestConnection, ResourceDescription, ReuseConnection},
|
||||
messages::{Relay, RequestConnection, ReuseConnection},
|
||||
Callbacks, Error, Result,
|
||||
};
|
||||
use webrtc::data_channel::OnCloseHdlrFn;
|
||||
use webrtc::peer_connection::OnPeerConnectionStateChangeHdlrFn;
|
||||
use webrtc::{
|
||||
data_channel::RTCDataChannel,
|
||||
ice_transport::{
|
||||
ice_candidate::RTCIceCandidateInit, ice_credential_type::RTCIceCredentialType,
|
||||
ice_server::RTCIceServer,
|
||||
@@ -22,7 +19,7 @@ use webrtc::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{peer::Peer, ConnId, ControlSignal, PeerConfig, RoleState, Tunnel};
|
||||
use crate::{ConnId, RoleState, Tunnel};
|
||||
|
||||
mod client;
|
||||
mod gateway;
|
||||
@@ -36,177 +33,35 @@ pub enum Request {
|
||||
ReuseConnection(ReuseConnection),
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
async fn handle_connection_state_update_with_peer<C, CB, TRoleState>(
|
||||
tunnel: &Arc<Tunnel<C, CB, TRoleState>>,
|
||||
state: RTCPeerConnectionState,
|
||||
index: u32,
|
||||
conn_id: ConnId,
|
||||
) where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
tracing::trace!(?state, "peer_state_update");
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
tunnel.stop_peer(index, conn_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn set_connection_state_with_peer<C, CB, TRoleState>(
|
||||
tunnel: &Arc<Tunnel<C, CB, TRoleState>>,
|
||||
peer_connection: &Arc<RTCPeerConnection>,
|
||||
index: u32,
|
||||
conn_id: ConnId,
|
||||
) where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
let tunnel = Arc::clone(tunnel);
|
||||
peer_connection.on_peer_connection_state_change(Box::new(
|
||||
move |state: RTCPeerConnectionState| {
|
||||
let tunnel = Arc::clone(&tunnel);
|
||||
Box::pin(async move {
|
||||
handle_connection_state_update_with_peer(&tunnel, state, index, conn_id).await
|
||||
})
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
#[instrument(level = "trace", skip(self, data_channel, peer_config))]
|
||||
async fn handle_channel_open(
|
||||
self: &Arc<Self>,
|
||||
data_channel: Arc<RTCDataChannel>,
|
||||
index: u32,
|
||||
peer_config: PeerConfig,
|
||||
conn_id: ConnId,
|
||||
resources: Option<(ResourceDescription, DateTime<Utc>)>,
|
||||
) -> Result<()> {
|
||||
tracing::trace!(
|
||||
?peer_config.ips,
|
||||
"data_channel_open",
|
||||
);
|
||||
let channel = data_channel.detach().await?;
|
||||
let tunn = Tunn::new(
|
||||
self.private_key.clone(),
|
||||
peer_config.public_key,
|
||||
Some(peer_config.preshared_key.expose_secret().0),
|
||||
peer_config.persistent_keepalive,
|
||||
index,
|
||||
None,
|
||||
)?;
|
||||
|
||||
let peer = Arc::new(Peer::from_config(
|
||||
tunn,
|
||||
index,
|
||||
&peer_config,
|
||||
channel,
|
||||
conn_id,
|
||||
resources,
|
||||
));
|
||||
|
||||
{
|
||||
// Watch out! we need 2 locks, make sure you don't lock both at the same time anywhere else
|
||||
let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock();
|
||||
let mut peers_by_ip = self.peers_by_ip.write();
|
||||
// In the gateway this will always be none, no harm done
|
||||
match conn_id {
|
||||
ConnId::Gateway(gateway_id) => {
|
||||
if let Some(awaiting_ips) = gateway_awaiting_connection.remove(&gateway_id) {
|
||||
for ip in awaiting_ips {
|
||||
peer.add_allowed_ip(ip);
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnId::Client(_) => {}
|
||||
ConnId::Resource(_) => {}
|
||||
}
|
||||
for ip in peer_config.ips {
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(conn) = self.peer_connections.lock().get(&conn_id) {
|
||||
set_connection_state_with_peer(self, conn, index, conn_id)
|
||||
}
|
||||
|
||||
data_channel.on_close({
|
||||
let tunnel = Arc::clone(self);
|
||||
Box::new(move || {
|
||||
tracing::debug!("channel_closed");
|
||||
let tunnel = tunnel.clone();
|
||||
Box::pin(async move {
|
||||
tunnel.stop_peer(index, conn_id).await;
|
||||
})
|
||||
pub fn on_dc_close_handler(self: Arc<Self>, index: u32, conn_id: ConnId) -> OnCloseHdlrFn {
|
||||
Box::new(move || {
|
||||
tracing::debug!("channel_closed");
|
||||
let tunnel = self.clone();
|
||||
Box::pin(async move {
|
||||
tunnel.stop_peer(index, conn_id).await;
|
||||
})
|
||||
});
|
||||
|
||||
let tunnel = Arc::clone(self);
|
||||
tokio::spawn(async move { tunnel.start_peer_handler(peer).await });
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn new_peer_connection(
|
||||
self: &Arc<Self>,
|
||||
relays: Vec<Relay>,
|
||||
) -> Result<(Arc<RTCPeerConnection>, mpsc::Receiver<RTCIceCandidateInit>)> {
|
||||
let config = RTCConfiguration {
|
||||
ice_servers: relays
|
||||
.into_iter()
|
||||
.map(|srv| match srv {
|
||||
Relay::Stun(stun) => RTCIceServer {
|
||||
urls: vec![stun.uri],
|
||||
..Default::default()
|
||||
},
|
||||
Relay::Turn(turn) => RTCIceServer {
|
||||
urls: vec![turn.uri],
|
||||
username: turn.username,
|
||||
credential: turn.password,
|
||||
// TODO: check what this is used for
|
||||
credential_type: RTCIceCredentialType::Password,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let peer_connection = Arc::new(self.webrtc_api.new_peer_connection(config).await?);
|
||||
|
||||
let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER);
|
||||
|
||||
peer_connection.on_ice_candidate(Box::new(move |candidate| {
|
||||
let Some(candidate) = candidate else {
|
||||
return Box::pin(async {});
|
||||
};
|
||||
|
||||
let mut ice_candidate_tx = ice_candidate_tx.clone();
|
||||
pub fn on_peer_connection_state_change_handler(
|
||||
self: Arc<Self>,
|
||||
index: u32,
|
||||
conn_id: ConnId,
|
||||
) -> OnPeerConnectionStateChangeHdlrFn {
|
||||
Box::new(move |state| {
|
||||
let tunnel = Arc::clone(&self);
|
||||
Box::pin(async move {
|
||||
let ice_candidate = match candidate.to_json() {
|
||||
Ok(ice_candidate) => ice_candidate,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if ice_candidate_tx.send(ice_candidate).await.is_err() {
|
||||
debug_assert!(false, "receiver was dropped before sender")
|
||||
tracing::trace!(?state, "peer_state_update");
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
tunnel.stop_peer(index, conn_id).await;
|
||||
}
|
||||
})
|
||||
}));
|
||||
|
||||
Ok((peer_connection, ice_candidate_rx))
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn add_ice_candidate(
|
||||
@@ -223,11 +78,57 @@ where
|
||||
peer_connection.add_ice_candidate(ice_candidate).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up a connection to a resource.
|
||||
// FIXME: this cleanup connection is wrong!
|
||||
pub fn cleanup_connection(&self, id: ConnId) {
|
||||
self.awaiting_connection.lock().remove(&id);
|
||||
self.peer_connections.lock().remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(webrtc))]
|
||||
pub async fn new_peer_connection(
|
||||
webrtc: &webrtc::api::API,
|
||||
relays: Vec<Relay>,
|
||||
) -> Result<(Arc<RTCPeerConnection>, mpsc::Receiver<RTCIceCandidateInit>)> {
|
||||
let config = RTCConfiguration {
|
||||
ice_servers: relays
|
||||
.into_iter()
|
||||
.map(|srv| match srv {
|
||||
Relay::Stun(stun) => RTCIceServer {
|
||||
urls: vec![stun.uri],
|
||||
..Default::default()
|
||||
},
|
||||
Relay::Turn(turn) => RTCIceServer {
|
||||
urls: vec![turn.uri],
|
||||
username: turn.username,
|
||||
credential: turn.password,
|
||||
// TODO: check what this is used for
|
||||
credential_type: RTCIceCredentialType::Password,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let peer_connection = Arc::new(webrtc.new_peer_connection(config).await?);
|
||||
|
||||
let (ice_candidate_tx, ice_candidate_rx) = mpsc::channel(ICE_CANDIDATE_BUFFER);
|
||||
|
||||
peer_connection.on_ice_candidate(Box::new(move |candidate| {
|
||||
let Some(candidate) = candidate else {
|
||||
return Box::pin(async {});
|
||||
};
|
||||
|
||||
let mut ice_candidate_tx = ice_candidate_tx.clone();
|
||||
Box::pin(async move {
|
||||
let ice_candidate = match candidate.to_json() {
|
||||
Ok(ice_candidate) => ice_candidate,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to serialize ICE candidate to JSON: {e}",);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if ice_candidate_tx.send(ice_candidate).await.is_err() {
|
||||
debug_assert!(false, "receiver was dropped before sender")
|
||||
}
|
||||
})
|
||||
}));
|
||||
|
||||
Ok((peer_connection, ice_candidate_rx))
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use connlib_shared::messages::SecretKey;
|
||||
use connlib_shared::{
|
||||
control::Reference,
|
||||
messages::{GatewayId, Key, Relay, RequestConnection, ResourceId, ReuseConnection},
|
||||
messages::{GatewayId, Key, Relay, RequestConnection, ResourceId},
|
||||
Callbacks,
|
||||
};
|
||||
use rand_core::OsRng;
|
||||
@@ -17,40 +16,32 @@ use webrtc::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{ClientState, ControlSignal, Error, PeerConfig, Request, Result, Tunnel};
|
||||
use crate::control_protocol::new_peer_connection;
|
||||
use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel};
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn handle_connection_state_update<C, CB>(
|
||||
tunnel: &Arc<Tunnel<C, CB, ClientState>>,
|
||||
fn handle_connection_state_update<CB>(
|
||||
tunnel: &Arc<Tunnel<CB, ClientState>>,
|
||||
state: RTCPeerConnectionState,
|
||||
gateway_id: GatewayId,
|
||||
resource_id: ResourceId,
|
||||
) where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
tracing::trace!("peer_state");
|
||||
if state == RTCPeerConnectionState::Failed {
|
||||
tunnel
|
||||
.awaiting_connection
|
||||
.lock()
|
||||
.remove(&resource_id.into());
|
||||
tunnel.role_state.lock().on_connection_failed(resource_id);
|
||||
tunnel.peer_connections.lock().remove(&gateway_id.into());
|
||||
tunnel
|
||||
.gateway_awaiting_connection
|
||||
.lock()
|
||||
.remove(&gateway_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn set_connection_state_update<C, CB>(
|
||||
tunnel: &Arc<Tunnel<C, CB, ClientState>>,
|
||||
fn set_connection_state_update<CB>(
|
||||
tunnel: &Arc<Tunnel<CB, ClientState>>,
|
||||
peer_connection: &Arc<RTCPeerConnection>,
|
||||
gateway_id: GatewayId,
|
||||
resource_id: ResourceId,
|
||||
) where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let tunnel = Arc::clone(tunnel);
|
||||
@@ -64,9 +55,8 @@ fn set_connection_state_update<C, CB>(
|
||||
));
|
||||
}
|
||||
|
||||
impl<C, CB> Tunnel<C, CB, ClientState>
|
||||
impl<CB> Tunnel<CB, ClientState>
|
||||
where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
/// Initiate an ice connection request.
|
||||
@@ -89,77 +79,23 @@ where
|
||||
reference: Option<Reference>,
|
||||
) -> Result<Request> {
|
||||
tracing::trace!("request_connection");
|
||||
let resource_description = self
|
||||
.resources
|
||||
.read()
|
||||
.get_by_id(&resource_id)
|
||||
.ok_or(Error::UnknownResource)?
|
||||
.clone();
|
||||
|
||||
let reference: usize = reference
|
||||
.ok_or(Error::InvalidReference)?
|
||||
.parse()
|
||||
.map_err(|_| Error::InvalidReference)?;
|
||||
{
|
||||
let mut awaiting_connections = self.awaiting_connection.lock();
|
||||
let Some(awaiting_connection) = awaiting_connections.get_mut(&resource_id.into())
|
||||
else {
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
};
|
||||
awaiting_connection.response_received = true;
|
||||
if awaiting_connection.total_attemps != reference
|
||||
|| resource_description
|
||||
.ips()
|
||||
.iter()
|
||||
.any(|&ip| self.peers_by_ip.read().exact_match(ip).is_some())
|
||||
{
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
}
|
||||
|
||||
if let Some(connection) = self.role_state.lock().attempt_to_reuse_connection(
|
||||
resource_id,
|
||||
gateway_id,
|
||||
reference,
|
||||
&mut self.peers_by_ip.write(),
|
||||
)? {
|
||||
return Ok(Request::ReuseConnection(connection));
|
||||
}
|
||||
|
||||
self.resources_gateways
|
||||
.lock()
|
||||
.insert(resource_id, gateway_id);
|
||||
{
|
||||
let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock();
|
||||
if let Some(g) = gateway_awaiting_connection.get_mut(&gateway_id) {
|
||||
g.extend(resource_description.ips());
|
||||
return Ok(Request::ReuseConnection(ReuseConnection {
|
||||
resource_id,
|
||||
gateway_id,
|
||||
}));
|
||||
} else {
|
||||
gateway_awaiting_connection.insert(gateway_id, vec![]);
|
||||
}
|
||||
}
|
||||
{
|
||||
let found = {
|
||||
let mut peers_by_ip = self.peers_by_ip.write();
|
||||
let peer = peers_by_ip
|
||||
.iter()
|
||||
.find_map(|(_, p)| (p.conn_id == gateway_id.into()).then_some(p))
|
||||
.cloned();
|
||||
if let Some(peer) = peer {
|
||||
for ip in resource_description.ips() {
|
||||
peer.add_allowed_ip(ip);
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if found {
|
||||
self.awaiting_connection.lock().remove(&resource_id.into());
|
||||
return Ok(Request::ReuseConnection(ReuseConnection {
|
||||
resource_id,
|
||||
gateway_id,
|
||||
}));
|
||||
}
|
||||
}
|
||||
let peer_connection = {
|
||||
let (peer_connection, receiver) = self.new_peer_connection(relays).await?;
|
||||
let (peer_connection, receiver) = new_peer_connection(&self.webrtc_api, relays).await?;
|
||||
self.role_state
|
||||
.lock()
|
||||
.add_waiting_ice_receiver(gateway_id, receiver);
|
||||
@@ -191,46 +127,63 @@ where
|
||||
Box::pin(async move {
|
||||
tracing::trace!("new_data_channel_opened");
|
||||
let index = tunnel.next_index();
|
||||
let Some(gateway_public_key) =
|
||||
tunnel.gateway_public_keys.lock().remove(&gateway_id)
|
||||
else {
|
||||
tunnel
|
||||
.awaiting_connection
|
||||
.lock()
|
||||
.remove(&resource_id.into());
|
||||
tunnel.peer_connections.lock().remove(&gateway_id.into());
|
||||
tunnel
|
||||
.gateway_awaiting_connection
|
||||
.lock()
|
||||
.remove(&gateway_id);
|
||||
let e = Error::ControlProtocolError;
|
||||
tracing::warn!(err = ?e, "channel_open");
|
||||
let _ = tunnel.callbacks.on_error(&e);
|
||||
return;
|
||||
};
|
||||
let peer_config = PeerConfig {
|
||||
persistent_keepalive: None,
|
||||
public_key: gateway_public_key,
|
||||
ips: resource_description.ips(),
|
||||
preshared_key: SecretKey::new(Key(p_key.to_bytes())),
|
||||
|
||||
let peer_config = match tunnel.role_state.lock().create_peer_config_for_new_connection(resource_id, gateway_id, p_key) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tunnel.peer_connections.lock().remove(&gateway_id.into());
|
||||
|
||||
tracing::warn!(err = ?e, "channel_open");
|
||||
let _ = tunnel.callbacks.on_error(&e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = tunnel
|
||||
.handle_channel_open(d, index, peer_config, gateway_id.into(), None)
|
||||
.await
|
||||
d.on_close(tunnel.clone().on_dc_close_handler(index, gateway_id.into()));
|
||||
|
||||
let peer = Arc::new(Peer::new(
|
||||
tunnel.private_key.clone(),
|
||||
index,
|
||||
peer_config.clone(),
|
||||
d.detach().await.expect("only fails if not opened or not enabled, both of which are always true for us"),
|
||||
gateway_id.into(),
|
||||
None,
|
||||
));
|
||||
|
||||
{
|
||||
tracing::error!(err = ?e, "channel_open");
|
||||
let _ = tunnel.callbacks.on_error(&e);
|
||||
tunnel.peer_connections.lock().remove(&gateway_id.into());
|
||||
tunnel
|
||||
.gateway_awaiting_connection
|
||||
.lock()
|
||||
.remove(&gateway_id);
|
||||
let mut role_state = tunnel.role_state.lock();
|
||||
// Watch out! we need 2 locks, make sure you don't lock both at the same time anywhere else
|
||||
let mut peers_by_ip = tunnel.peers_by_ip.write();
|
||||
|
||||
if let Some(awaiting_ips) =
|
||||
role_state.gateway_awaiting_connection.remove(&gateway_id)
|
||||
{
|
||||
for ip in awaiting_ips {
|
||||
peer.add_allowed_ip(ip);
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
}
|
||||
|
||||
for ip in peer_config.ips {
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(conn) = tunnel.peer_connections.lock().get(&gateway_id.into()) {
|
||||
conn.on_peer_connection_state_change(
|
||||
tunnel
|
||||
.clone()
|
||||
.on_peer_connection_state_change_handler(index, gateway_id.into()),
|
||||
);
|
||||
}
|
||||
|
||||
tokio::spawn(tunnel.clone().start_peer_handler(peer));
|
||||
|
||||
tunnel
|
||||
.awaiting_connection
|
||||
.role_state
|
||||
.lock()
|
||||
.remove(&resource_id.into());
|
||||
.awaiting_connection
|
||||
.remove(&resource_id);
|
||||
})
|
||||
}));
|
||||
|
||||
@@ -260,10 +213,10 @@ where
|
||||
rtc_sdp: RTCSessionDescription,
|
||||
gateway_public_key: PublicKey,
|
||||
) -> Result<()> {
|
||||
let gateway_id = *self
|
||||
.resources_gateways
|
||||
let gateway_id = self
|
||||
.role_state
|
||||
.lock()
|
||||
.get(&resource_id)
|
||||
.gateway_by_resource(&resource_id)
|
||||
.ok_or(Error::UnknownResource)?;
|
||||
let peer_connection = self
|
||||
.peer_connections
|
||||
@@ -271,14 +224,11 @@ where
|
||||
.get(&gateway_id.into())
|
||||
.ok_or(Error::UnknownResource)?
|
||||
.clone();
|
||||
self.gateway_public_keys
|
||||
.lock()
|
||||
.insert(gateway_id, gateway_public_key);
|
||||
|
||||
peer_connection.set_remote_description(rtc_sdp).await?;
|
||||
|
||||
self.role_state
|
||||
.lock()
|
||||
.activate_ice_candidate_receiver(gateway_id);
|
||||
.activate_ice_candidate_receiver(gateway_id, gateway_public_key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -10,15 +10,15 @@ use webrtc::peer_connection::{
|
||||
RTCPeerConnection,
|
||||
};
|
||||
|
||||
use crate::{ControlSignal, GatewayState, PeerConfig, Tunnel};
|
||||
use crate::control_protocol::new_peer_connection;
|
||||
use crate::{peer::Peer, GatewayState, PeerConfig, Tunnel};
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn handle_connection_state_update<C, CB>(
|
||||
tunnel: &Arc<Tunnel<C, CB, GatewayState>>,
|
||||
fn handle_connection_state_update<CB>(
|
||||
tunnel: &Arc<Tunnel<CB, GatewayState>>,
|
||||
state: RTCPeerConnectionState,
|
||||
client_id: ClientId,
|
||||
) where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
tracing::trace!(?state, "peer_state");
|
||||
@@ -28,12 +28,11 @@ fn handle_connection_state_update<C, CB>(
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(tunnel))]
|
||||
fn set_connection_state_update<C, CB>(
|
||||
tunnel: &Arc<Tunnel<C, CB, GatewayState>>,
|
||||
fn set_connection_state_update<CB>(
|
||||
tunnel: &Arc<Tunnel<CB, GatewayState>>,
|
||||
peer_connection: &Arc<RTCPeerConnection>,
|
||||
client_id: ClientId,
|
||||
) where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let tunnel = Arc::clone(tunnel);
|
||||
@@ -45,9 +44,8 @@ fn set_connection_state_update<C, CB>(
|
||||
));
|
||||
}
|
||||
|
||||
impl<C, CB> Tunnel<C, CB, GatewayState>
|
||||
impl<CB> Tunnel<CB, GatewayState>
|
||||
where
|
||||
C: ControlSignal + Clone + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
/// Accept a connection request from a client.
|
||||
@@ -72,7 +70,7 @@ where
|
||||
expires_at: DateTime<Utc>,
|
||||
resource: ResourceDescription,
|
||||
) -> Result<RTCSessionDescription> {
|
||||
let (peer_connection, receiver) = self.new_peer_connection(relays).await?;
|
||||
let (peer_connection, receiver) = new_peer_connection(&self.webrtc_api, relays).await?;
|
||||
self.role_state
|
||||
.lock()
|
||||
.add_new_ice_receiver(client_id, receiver);
|
||||
@@ -88,12 +86,12 @@ where
|
||||
peer_connection.on_data_channel(Box::new(move |d| {
|
||||
tracing::trace!("new_data_channel");
|
||||
let data_channel = Arc::clone(&d);
|
||||
let peer = peer.clone();
|
||||
let peer_config = peer.clone();
|
||||
let tunnel = Arc::clone(&tunnel);
|
||||
let resource = resource.clone();
|
||||
Box::pin(async move {
|
||||
d.on_open(Box::new(move || {
|
||||
tracing::trace!("new_data_channel_open");
|
||||
tracing::trace!(?peer_config.ips, "new_data_channel_open");
|
||||
Box::pin(async move {
|
||||
{
|
||||
let Some(device) = tunnel.device.read().await.clone() else {
|
||||
@@ -103,7 +101,7 @@ where
|
||||
return;
|
||||
};
|
||||
let iface_config = device.config;
|
||||
for &ip in &peer.ips {
|
||||
for &ip in &peer_config.ips {
|
||||
if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await
|
||||
{
|
||||
let _ = tunnel.callbacks.on_error(&e);
|
||||
@@ -111,28 +109,34 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = tunnel
|
||||
.handle_channel_open(
|
||||
data_channel,
|
||||
index,
|
||||
peer,
|
||||
client_id.into(),
|
||||
Some((resource, expires_at)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
let _ = tunnel.callbacks.on_error(&e);
|
||||
tracing::error!(err = ?e, "channel_open");
|
||||
// Note: handle_channel_open can only error out before insert to peers_by_ip
|
||||
// otherwise we would need to clean that up too!
|
||||
let conn = tunnel.peer_connections.lock().remove(&client_id.into());
|
||||
if let Some(conn) = conn {
|
||||
if let Err(e) = conn.close().await {
|
||||
tracing::error!(error = ?e, "webrtc_close_channel");
|
||||
let _ = tunnel.callbacks().on_error(&e.into());
|
||||
}
|
||||
}
|
||||
data_channel
|
||||
.on_close(tunnel.clone().on_dc_close_handler(index, client_id.into()));
|
||||
|
||||
let peer = Arc::new(Peer::new(
|
||||
tunnel.private_key.clone(),
|
||||
index,
|
||||
peer_config.clone(),
|
||||
data_channel.detach().await.expect("only fails if not opened or not enabled, both of which are always true for us"),
|
||||
client_id.into(),
|
||||
Some((resource, expires_at)),
|
||||
));
|
||||
|
||||
let mut peers_by_ip = tunnel.peers_by_ip.write();
|
||||
|
||||
for ip in peer_config.ips {
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
|
||||
if let Some(conn) = tunnel.peer_connections.lock().get(&client_id.into()) {
|
||||
conn.on_peer_connection_state_change(
|
||||
tunnel.clone().on_peer_connection_state_change_handler(
|
||||
index,
|
||||
client_id.into(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
tokio::spawn(tunnel.clone().start_peer_handler(peer));
|
||||
})
|
||||
}))
|
||||
})
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::device_channel::create_iface;
|
||||
use crate::{
|
||||
ControlSignal, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
|
||||
peer_by_ip, ConnId, Device, Event, RoleState, Tunnel, ICE_GATHERING_TIMEOUT_SECONDS,
|
||||
MAX_CONCURRENT_ICE_GATHERING, MAX_UDP_SIZE,
|
||||
};
|
||||
use connlib_shared::error::ConnlibError;
|
||||
@@ -13,9 +13,8 @@ use std::task::{ready, Context, Poll};
|
||||
use std::time::Duration;
|
||||
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
|
||||
|
||||
impl<C, CB> Tunnel<C, CB, GatewayState>
|
||||
impl<CB> Tunnel<CB, GatewayState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
/// Sets the interface configuration and starts background tasks.
|
||||
@@ -35,15 +34,20 @@ where
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up a connection to a resource.
|
||||
// FIXME: this cleanup connection is wrong!
|
||||
pub fn cleanup_connection(&self, id: ConnId) {
|
||||
self.peer_connections.lock().remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads IP packets from the [`Device`] and handles them accordingly.
|
||||
async fn device_handler<C, CB>(
|
||||
tunnel: Arc<Tunnel<C, CB, GatewayState>>,
|
||||
async fn device_handler<CB>(
|
||||
tunnel: Arc<Tunnel<CB, GatewayState>>,
|
||||
mut device: Device,
|
||||
) -> Result<(), ConnlibError>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let mut buf = [0u8; MAX_UDP_SIZE];
|
||||
@@ -55,7 +59,7 @@ where
|
||||
|
||||
let dest = packet.destination();
|
||||
|
||||
let Some(peer) = tunnel.peer_by_ip(dest) else {
|
||||
let Some(peer) = peer_by_ip(&tunnel.peers_by_ip.read(), dest) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
|
||||
@@ -4,11 +4,10 @@ use boringtun::noise::{errors::WireGuardError, TunnResult};
|
||||
use bytes::Bytes;
|
||||
use connlib_shared::{Callbacks, Result};
|
||||
|
||||
use crate::{ip_packet::MutableIpPacket, peer::Peer, ControlSignal, RoleState, Tunnel};
|
||||
use crate::{ip_packet::MutableIpPacket, peer::Peer, RoleState, Tunnel};
|
||||
|
||||
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
|
||||
@@ -13,7 +13,6 @@ use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use itertools::Itertools;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use peer::{Peer, PeerStats};
|
||||
@@ -136,29 +135,6 @@ impl From<connlib_shared::messages::Peer> for PeerConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait used for out-going signals to control plane that are **required** to be made from inside the tunnel.
|
||||
///
|
||||
/// Generally, we try to return from the functions here rather than using this callback.
|
||||
#[async_trait]
|
||||
pub trait ControlSignal {
|
||||
/// Signals to the control plane an intent to initiate a connection to the given resource.
|
||||
///
|
||||
/// Used when a packet is found to a resource we have no connection stablished but is within the list of resources available for the client.
|
||||
async fn signal_connection_to(
|
||||
&self,
|
||||
resource: &ResourceDescription,
|
||||
connected_gateway_ids: &[GatewayId],
|
||||
reference: usize,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
struct AwaitingConnectionDetails {
|
||||
pub total_attemps: usize,
|
||||
pub response_received: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Device {
|
||||
config: Arc<IfaceConfig>,
|
||||
@@ -190,7 +166,7 @@ impl Device {
|
||||
// TODO: We should use newtypes for each kind of Id
|
||||
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets
|
||||
/// to communicate between peers.
|
||||
pub struct Tunnel<C: ControlSignal, CB: Callbacks, TRoleState> {
|
||||
pub struct Tunnel<CB: Callbacks, TRoleState> {
|
||||
next_index: Mutex<IndexLfsr>,
|
||||
// We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact
|
||||
device: tokio::sync::RwLock<Option<Device>>,
|
||||
@@ -199,13 +175,7 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks, TRoleState> {
|
||||
public_key: PublicKey,
|
||||
peers_by_ip: RwLock<IpNetworkTable<Arc<Peer>>>,
|
||||
peer_connections: Mutex<HashMap<ConnId, Arc<RTCPeerConnection>>>,
|
||||
awaiting_connection: Mutex<HashMap<ConnId, AwaitingConnectionDetails>>,
|
||||
gateway_awaiting_connection: Mutex<HashMap<GatewayId, Vec<IpNetwork>>>,
|
||||
resources_gateways: Mutex<HashMap<ResourceId, GatewayId>>,
|
||||
webrtc_api: API,
|
||||
resources: Arc<RwLock<ResourceTable<ResourceDescription>>>,
|
||||
control_signaler: C,
|
||||
gateway_public_keys: Mutex<HashMap<GatewayId, PublicKey>>,
|
||||
callbacks: CallbackErrorFacade<CB>,
|
||||
iface_handler_abort: Mutex<Option<AbortHandle>>,
|
||||
|
||||
@@ -220,18 +190,10 @@ pub struct TunnelStats {
|
||||
public_key: String,
|
||||
peers_by_ip: HashMap<IpNetwork, PeerStats>,
|
||||
peer_connections: Vec<ConnId>,
|
||||
resource_gateways: HashMap<ResourceId, GatewayId>,
|
||||
dns_resources: HashMap<String, ResourceDescription>,
|
||||
network_resources: HashMap<IpNetwork, ResourceDescription>,
|
||||
gateway_public_keys: HashMap<GatewayId, String>,
|
||||
|
||||
awaiting_connection: HashMap<ConnId, AwaitingConnectionDetails>,
|
||||
gateway_awaiting_connection: HashMap<GatewayId, Vec<IpNetwork>>,
|
||||
}
|
||||
|
||||
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
@@ -243,30 +205,11 @@ where
|
||||
.map(|(ip, peer)| (ip, peer.stats()))
|
||||
.collect();
|
||||
let peer_connections = self.peer_connections.lock().keys().cloned().collect();
|
||||
let awaiting_connection = self.awaiting_connection.lock().clone();
|
||||
let gateway_awaiting_connection = self.gateway_awaiting_connection.lock().clone();
|
||||
let resource_gateways = self.resources_gateways.lock().clone();
|
||||
let (network_resources, dns_resources) = {
|
||||
let resources = self.resources.read();
|
||||
(resources.network_resources(), resources.dns_resources())
|
||||
};
|
||||
|
||||
let gateway_public_keys = self
|
||||
.gateway_public_keys
|
||||
.lock()
|
||||
.iter()
|
||||
.map(|(&id, &k)| (id, Key::from(k).to_string()))
|
||||
.collect();
|
||||
TunnelStats {
|
||||
public_key: Key::from(self.public_key).to_string(),
|
||||
peers_by_ip,
|
||||
peer_connections,
|
||||
awaiting_connection,
|
||||
gateway_awaiting_connection,
|
||||
resource_gateways,
|
||||
dns_resources,
|
||||
network_resources,
|
||||
gateway_public_keys,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,14 +220,10 @@ where
|
||||
pub fn poll_next_event(&self, cx: &mut Context<'_>) -> Poll<Event<TRoleState::Id>> {
|
||||
self.role_state.lock().poll_next_event(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn peer_by_ip(&self, ip: IpAddr) -> Option<Arc<Peer>> {
|
||||
self.peers_by_ip
|
||||
.read()
|
||||
.longest_match(ip)
|
||||
.map(|(_, peer)| peer)
|
||||
.cloned()
|
||||
}
|
||||
pub(crate) fn peer_by_ip(peers_by_ip: &IpNetworkTable<Arc<Peer>>, ip: IpAddr) -> Option<Arc<Peer>> {
|
||||
peers_by_ip.longest_match(ip).map(|(_, peer)| peer).cloned()
|
||||
}
|
||||
|
||||
pub enum Event<TId> {
|
||||
@@ -292,11 +231,15 @@ pub enum Event<TId> {
|
||||
conn_id: TId,
|
||||
candidate: RTCIceCandidateInit,
|
||||
},
|
||||
ConnectionIntent {
|
||||
resource: ResourceDescription,
|
||||
connected_gateway_ids: Vec<GatewayId>,
|
||||
reference: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
@@ -305,22 +248,14 @@ where
|
||||
/// # Parameters
|
||||
/// - `private_key`: wireguard's private key.
|
||||
/// - `control_signaler`: this is used to send SDP from the tunnel to the control plane.
|
||||
#[tracing::instrument(level = "trace", skip(private_key, control_signaler, callbacks))]
|
||||
pub async fn new(
|
||||
private_key: StaticSecret,
|
||||
control_signaler: C,
|
||||
callbacks: CB,
|
||||
) -> Result<Self> {
|
||||
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
|
||||
pub async fn new(private_key: StaticSecret, callbacks: CB) -> Result<Self> {
|
||||
let public_key = (&private_key).into();
|
||||
let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT));
|
||||
let peers_by_ip = RwLock::new(IpNetworkTable::new());
|
||||
let next_index = Default::default();
|
||||
let peer_connections = Default::default();
|
||||
let resources: Arc<RwLock<ResourceTable<ResourceDescription>>> = Default::default();
|
||||
let awaiting_connection = Default::default();
|
||||
let gateway_public_keys = Default::default();
|
||||
let resources_gateways = Default::default();
|
||||
let gateway_awaiting_connection = Default::default();
|
||||
let device = Default::default();
|
||||
let iface_handler_abort = Default::default();
|
||||
|
||||
@@ -347,7 +282,6 @@ where
|
||||
.build();
|
||||
|
||||
Ok(Self {
|
||||
gateway_public_keys,
|
||||
rate_limiter,
|
||||
private_key,
|
||||
peer_connections,
|
||||
@@ -355,12 +289,7 @@ where
|
||||
peers_by_ip,
|
||||
next_index,
|
||||
webrtc_api,
|
||||
resources,
|
||||
device,
|
||||
awaiting_connection,
|
||||
gateway_awaiting_connection,
|
||||
control_signaler,
|
||||
resources_gateways,
|
||||
callbacks: CallbackErrorFacade(callbacks),
|
||||
iface_handler_abort,
|
||||
role_state: Default::default(),
|
||||
@@ -411,31 +340,6 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
fn remove_expired_peers(self: &Arc<Self>) {
|
||||
let mut peers_by_ip = self.peers_by_ip.write();
|
||||
|
||||
for (_, peer) in peers_by_ip.iter() {
|
||||
peer.expire_resources();
|
||||
if peer.is_emptied() {
|
||||
tracing::trace!(index = peer.index, "peer_expired");
|
||||
let conn = self.peer_connections.lock().remove(&peer.conn_id);
|
||||
let p = peer.clone();
|
||||
|
||||
// We are holding a Mutex, particularly a write one, we don't want to make a blocking call
|
||||
tokio::spawn(async move {
|
||||
let _ = p.shutdown().await;
|
||||
if let Some(conn) = conn {
|
||||
// TODO: it seems that even closing the stream there are messages to the relay
|
||||
// see where they come from.
|
||||
let _ = conn.close().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
peers_by_ip.retain(|_, p| !p.is_emptied());
|
||||
}
|
||||
|
||||
fn start_peers_refresh_timer(self: &Arc<Self>) {
|
||||
let tunnel = self.clone();
|
||||
|
||||
@@ -445,7 +349,10 @@ where
|
||||
let mut dst_buf = [0u8; MAX_UDP_SIZE];
|
||||
|
||||
loop {
|
||||
tunnel.remove_expired_peers();
|
||||
remove_expired_peers(
|
||||
&mut tunnel.peers_by_ip.write(),
|
||||
&mut tunnel.peer_connections.lock(),
|
||||
);
|
||||
|
||||
let peers: Vec<_> = tunnel
|
||||
.peers_by_ip
|
||||
@@ -497,14 +404,6 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_resource(&self, addr: IpAddr) -> Option<ResourceDescription> {
|
||||
let resources = self.resources.read();
|
||||
match addr {
|
||||
IpAddr::V4(ipv4) => resources.get_by_ip(ipv4).cloned(),
|
||||
IpAddr::V6(ipv6) => resources.get_by_ip(ipv6).cloned(),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_index(&self) -> u32 {
|
||||
self.next_index.lock().next()
|
||||
}
|
||||
@@ -514,6 +413,32 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_expired_peers(
|
||||
peers_by_ip: &mut IpNetworkTable<Arc<Peer>>,
|
||||
peer_connections: &mut HashMap<ConnId, Arc<RTCPeerConnection>>,
|
||||
) {
|
||||
for (_, peer) in peers_by_ip.iter() {
|
||||
peer.expire_resources();
|
||||
if peer.is_emptied() {
|
||||
tracing::trace!(index = peer.index, "peer_expired");
|
||||
let conn = peer_connections.remove(&peer.conn_id);
|
||||
let p = peer.clone();
|
||||
|
||||
// We are holding a Mutex, particularly a write one, we don't want to make a blocking call
|
||||
tokio::spawn(async move {
|
||||
let _ = p.shutdown().await;
|
||||
if let Some(conn) = conn {
|
||||
// TODO: it seems that even closing the stream there are messages to the relay
|
||||
// see where they come from.
|
||||
let _ = conn.close().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
peers_by_ip.retain(|_, p| !p.is_emptied());
|
||||
}
|
||||
|
||||
/// Dedicated trait for abstracting over the different ICE states.
|
||||
///
|
||||
/// By design, this trait does not allow any operations apart from advancing via [`RoleState::poll_next_event`].
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::{collections::HashMap, net::IpAddr, sync::Arc};
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::StaticSecret;
|
||||
use bytes::Bytes;
|
||||
use chrono::{DateTime, Utc};
|
||||
use connlib_shared::{
|
||||
@@ -11,11 +12,10 @@ use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use pnet_packet::MutablePacket;
|
||||
use secrecy::ExposeSecret;
|
||||
use webrtc::data::data_channel::DataChannel;
|
||||
|
||||
use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId};
|
||||
|
||||
use super::PeerConfig;
|
||||
use crate::{ip_packet::MutableIpPacket, resource_table::ResourceTable, ConnId, PeerConfig};
|
||||
|
||||
type ExpiryingResource = (ResourceDescription, DateTime<Utc>);
|
||||
|
||||
@@ -87,34 +87,26 @@ impl Peer {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_config(
|
||||
tunnel: Tunn,
|
||||
index: u32,
|
||||
config: &PeerConfig,
|
||||
channel: Arc<DataChannel>,
|
||||
conn_id: ConnId,
|
||||
resource: Option<(ResourceDescription, DateTime<Utc>)>,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
Mutex::new(tunnel),
|
||||
index,
|
||||
config.ips.clone(),
|
||||
channel,
|
||||
conn_id,
|
||||
resource,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
tunnel: Mutex<Tunn>,
|
||||
private_key: StaticSecret,
|
||||
index: u32,
|
||||
ips: Vec<IpNetwork>,
|
||||
peer_config: PeerConfig,
|
||||
channel: Arc<DataChannel>,
|
||||
conn_id: ConnId,
|
||||
resource: Option<(ResourceDescription, DateTime<Utc>)>,
|
||||
) -> Peer {
|
||||
let tunnel = Tunn::new(
|
||||
private_key.clone(),
|
||||
peer_config.public_key,
|
||||
Some(peer_config.preshared_key.expose_secret().0),
|
||||
peer_config.persistent_keepalive,
|
||||
index,
|
||||
None,
|
||||
)
|
||||
.expect("never actually fails"); // See https://github.com/cloudflare/boringtun/pull/366.
|
||||
|
||||
let mut allowed_ips = IpNetworkTable::new();
|
||||
for ip in ips {
|
||||
for ip in peer_config.ips {
|
||||
allowed_ips.insert(ip, ());
|
||||
}
|
||||
let allowed_ips = RwLock::new(allowed_ips);
|
||||
@@ -123,8 +115,9 @@ impl Peer {
|
||||
resource_table.insert(r);
|
||||
RwLock::new(resource_table)
|
||||
});
|
||||
|
||||
Peer {
|
||||
tunnel,
|
||||
tunnel: Mutex::new(tunnel),
|
||||
index,
|
||||
allowed_ips,
|
||||
channel,
|
||||
|
||||
@@ -5,13 +5,12 @@ use bytes::Bytes;
|
||||
use connlib_shared::{Callbacks, Error, Result};
|
||||
|
||||
use crate::{
|
||||
device_channel::DeviceIo, index::check_packet_index, peer::Peer, ControlSignal, RoleState,
|
||||
Tunnel, MAX_UDP_SIZE,
|
||||
device_channel::DeviceIo, index::check_packet_index, peer::Peer, RoleState, Tunnel,
|
||||
MAX_UDP_SIZE,
|
||||
};
|
||||
|
||||
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
TRoleState: RoleState,
|
||||
{
|
||||
@@ -155,7 +154,7 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn start_peer_handler(self: &Arc<Self>, peer: Arc<Peer>) {
|
||||
pub(crate) async fn start_peer_handler(self: Arc<Self>, peer: Arc<Peer>) {
|
||||
loop {
|
||||
let Some(device) = self.device.read().await.clone() else {
|
||||
let err = Error::NoIface;
|
||||
|
||||
@@ -3,15 +3,12 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, ControlSignal, Tunnel,
|
||||
};
|
||||
use crate::{device_channel::DeviceIo, ip_packet::MutableIpPacket, peer::Peer, Tunnel};
|
||||
|
||||
use connlib_shared::{messages::ResourceDescription, Callbacks, Error, Result};
|
||||
|
||||
impl<C, CB, TRoleState> Tunnel<C, CB, TRoleState>
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
where
|
||||
C: ControlSignal + Send + Sync + 'static,
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
#[inline(always)]
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use connlib_shared::{
|
||||
messages::{GatewayId, ResourceDescription},
|
||||
Result,
|
||||
};
|
||||
use firezone_tunnel::ControlSignal;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ControlSignaler;
|
||||
|
||||
#[async_trait]
|
||||
impl ControlSignal for ControlSignaler {
|
||||
async fn signal_connection_to(
|
||||
&self,
|
||||
resource: &ResourceDescription,
|
||||
_connected_gateway_ids: &[GatewayId],
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients.");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::control::ControlSignaler;
|
||||
use crate::messages::{
|
||||
AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady,
|
||||
EgressMessages, IngressMessages,
|
||||
@@ -7,7 +6,7 @@ use crate::CallbackHandler;
|
||||
use anyhow::Result;
|
||||
use connlib_shared::messages::ClientId;
|
||||
use connlib_shared::Error;
|
||||
use firezone_tunnel::{GatewayState, Tunnel};
|
||||
use firezone_tunnel::{Event, GatewayState, Tunnel};
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
@@ -18,7 +17,7 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||
pub const PHOENIX_TOPIC: &str = "gateway";
|
||||
|
||||
pub struct Eventloop {
|
||||
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
|
||||
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
|
||||
portal: PhoenixChannel<IngressMessages, ()>,
|
||||
|
||||
// TODO: Strongly type request reference (currently `String`)
|
||||
@@ -31,7 +30,7 @@ pub struct Eventloop {
|
||||
|
||||
impl Eventloop {
|
||||
pub(crate) fn new(
|
||||
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
|
||||
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
|
||||
portal: PhoenixChannel<IngressMessages, ()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -190,6 +189,9 @@ impl Eventloop {
|
||||
);
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Event::ConnectionIntent { .. }) => {
|
||||
unreachable!("Not used on the gateway, split the events!")
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::control::ControlSignaler;
|
||||
use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
|
||||
use crate::messages::InitGateway;
|
||||
use anyhow::{Context, Result};
|
||||
@@ -15,7 +14,6 @@ use std::sync::Arc;
|
||||
use tracing_subscriber::layer;
|
||||
use url::Url;
|
||||
|
||||
mod control;
|
||||
mod eventloop;
|
||||
mod messages;
|
||||
|
||||
@@ -30,7 +28,7 @@ async fn main() -> Result<()> {
|
||||
SecretString::new(cli.common.secret),
|
||||
get_device_id(),
|
||||
)?;
|
||||
let tunnel = Arc::new(Tunnel::new(private_key, ControlSignaler, CallbackHandler).await?);
|
||||
let tunnel = Arc::new(Tunnel::new(private_key, CallbackHandler).await?);
|
||||
|
||||
tokio::spawn(backoff::future::retry_notify(
|
||||
ExponentialBackoffBuilder::default()
|
||||
@@ -48,7 +46,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
async fn run(
|
||||
tunnel: Arc<Tunnel<ControlSignaler, CallbackHandler, GatewayState>>,
|
||||
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
|
||||
connect_url: Url,
|
||||
) -> Result<Infallible> {
|
||||
let (portal, init) = phoenix_channel::init::<InitGateway, _, _>(
|
||||
|
||||
Reference in New Issue
Block a user