mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): unify peer storage (#3738)
Now that we have `&mut` access everywhere in the tunnel, the remaining shared-memory and locks are in how we store peers. To resolve this, we introduce a new `PeerStore` that allows us to look up peers by IP and by ID.
This commit is contained in:
2
rust/Cargo.lock
generated
2
rust/Cargo.lock
generated
@@ -1922,7 +1922,6 @@ name = "firezone-tunnel"
|
||||
version = "1.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"bimap",
|
||||
"boringtun",
|
||||
@@ -1942,7 +1941,6 @@ dependencies = [
|
||||
"log",
|
||||
"netlink-packet-core",
|
||||
"netlink-packet-route",
|
||||
"parking_lot",
|
||||
"pnet_packet",
|
||||
"quinn-udp",
|
||||
"rand_core 0.6.4",
|
||||
|
||||
@@ -311,10 +311,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn stats_event(&mut self) {
|
||||
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
|
||||
}
|
||||
|
||||
pub async fn request_log_upload_url(&mut self) {
|
||||
tracing::info!("Requesting log upload URL from portal");
|
||||
|
||||
|
||||
@@ -185,7 +185,6 @@ where
|
||||
let runtime_stopper = runtime_stopper.clone();
|
||||
let callbacks = callbacks.clone();
|
||||
async move {
|
||||
let mut log_stats_interval = tokio::time::interval(Duration::from_secs(10));
|
||||
let mut upload_logs_interval = upload_interval();
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -201,7 +200,6 @@ where
|
||||
}
|
||||
},
|
||||
event = poll_fn(|cx| control_plane.tunnel.poll_next_event(cx)) => control_plane.handle_tunnel_event(event).await,
|
||||
_ = log_stats_interval.tick() => control_plane.stats_event().await,
|
||||
_ = upload_logs_interval.tick() => control_plane.request_log_upload_url().await,
|
||||
else => break
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ serde = { version = "1.0", default-features = false, features = ["derive", "std"
|
||||
futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] }
|
||||
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] }
|
||||
tracing = { workspace = true }
|
||||
parking_lot = { version = "0.12", default-features = false }
|
||||
bytes = { version = "1.4", default-features = false, features = ["std"] }
|
||||
itertools = { version = "0.12", default-features = false, features = ["use_std"] }
|
||||
connlib-shared = { workspace = true }
|
||||
@@ -27,7 +26,6 @@ chrono = { workspace = true }
|
||||
pnet_packet = { version = "0.34" }
|
||||
futures-bounded = { workspace = true }
|
||||
hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
|
||||
arc-swap = "1.6.0"
|
||||
bimap = "0.6"
|
||||
resolv-conf = "0.7.0"
|
||||
socket2 = { version = "0.5" }
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::device_channel::{Device, Packet};
|
||||
use crate::ip_packet::{IpPacket, MutableIpPacket};
|
||||
use crate::peer::{PacketTransformClient, Peer};
|
||||
use crate::{dns, dns::DnsQuery, peer_by_ip, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE};
|
||||
use crate::peer::PacketTransformClient;
|
||||
use crate::peer_store::PeerStore;
|
||||
use crate::{dns, dns::DnsQuery, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE};
|
||||
use bimap::BiMap;
|
||||
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
|
||||
use connlib_shared::messages::{
|
||||
@@ -22,7 +23,6 @@ use hickory_resolver::TokioAsyncResolver;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::time::{Instant, Interval, MissedTickBehavior};
|
||||
@@ -47,7 +47,7 @@ impl DnsResource {
|
||||
}
|
||||
}
|
||||
|
||||
impl<CB> Tunnel<CB, ClientState, Client, GatewayId, PacketTransformClient>
|
||||
impl<CB> Tunnel<CB, ClientState, Client, GatewayId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
@@ -189,7 +189,7 @@ pub struct ClientState {
|
||||
pub resource_ids: HashMap<ResourceId, ResourceDescription>,
|
||||
pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>,
|
||||
|
||||
pub peers_by_ip: IpNetworkTable<Arc<Peer<GatewayId, PacketTransformClient>>>,
|
||||
pub peers: PeerStore<GatewayId, PacketTransformClient>,
|
||||
|
||||
forwarded_dns_queries: FuturesTupleSet<
|
||||
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
|
||||
@@ -227,7 +227,7 @@ impl ClientState {
|
||||
Err(non_dns_packet) => non_dns_packet,
|
||||
};
|
||||
|
||||
let Some(peer) = peer_by_ip(&self.peers_by_ip, dest) else {
|
||||
let Some(peer) = self.peers.peer_by_ip_mut(dest) else {
|
||||
self.on_connection_intent_ip(dest);
|
||||
return None;
|
||||
};
|
||||
@@ -309,7 +309,7 @@ impl ClientState {
|
||||
|
||||
let domain = self.get_awaiting_connection_domain(&resource)?.clone();
|
||||
|
||||
if self.is_connected_to(resource, &self.peers_by_ip, &domain) {
|
||||
if self.is_connected_to(resource, &domain) {
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
}
|
||||
|
||||
@@ -332,11 +332,11 @@ impl ClientState {
|
||||
|
||||
self.resources_gateways.insert(resource, gateway);
|
||||
|
||||
let Some(peer) = self
|
||||
.peers_by_ip
|
||||
.iter()
|
||||
.find_map(|(_, p)| (p.conn_id == gateway).then_some(p.clone()))
|
||||
else {
|
||||
if self
|
||||
.peers
|
||||
.add_ips(&gateway, &self.get_resource_ip(desc, &domain))
|
||||
.is_none()
|
||||
{
|
||||
match self
|
||||
.gateway_awaiting_connection_timers
|
||||
// Note: we don't need to set a timer here because
|
||||
@@ -357,10 +357,6 @@ impl ClientState {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
for ip in self.get_resource_ip(desc, &domain) {
|
||||
peer.add_allowed_ip(ip);
|
||||
self.peers_by_ip.insert(ip, peer.clone());
|
||||
}
|
||||
self.awaiting_connection.remove(&resource);
|
||||
self.awaiting_connection_timers.remove(resource);
|
||||
|
||||
@@ -531,19 +527,13 @@ impl ClientState {
|
||||
self.awaiting_connection.contains_key(&resource.id())
|
||||
}
|
||||
|
||||
fn is_connected_to(
|
||||
&self,
|
||||
resource: ResourceId,
|
||||
connected_peers: &IpNetworkTable<Arc<Peer<GatewayId, PacketTransformClient>>>,
|
||||
domain: &Option<Dname>,
|
||||
) -> bool {
|
||||
fn is_connected_to(&self, resource: ResourceId, domain: &Option<Dname>) -> bool {
|
||||
let Some(resource) = self.resource_ids.get(&resource) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let ips = self.get_resource_ip(resource, domain);
|
||||
ips.iter()
|
||||
.any(|ip| connected_peers.exact_match(*ip).is_some())
|
||||
ips.iter().any(|ip| self.peers.exact_match(*ip).is_some())
|
||||
}
|
||||
|
||||
fn get_resource_ip(
|
||||
@@ -571,7 +561,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
pub fn cleanup_connected_gateway(&mut self, gateway_id: &GatewayId) {
|
||||
self.peers_by_ip.retain(|_, p| p.conn_id != *gateway_id);
|
||||
self.peers.remove(gateway_id);
|
||||
self.dns_resources_internal_ips.retain(|resource, _| {
|
||||
!self
|
||||
.resources_gateways
|
||||
@@ -668,20 +658,16 @@ impl ClientState {
|
||||
if self.refresh_dns_timer.poll_tick(cx).is_ready() {
|
||||
let mut connections = Vec::new();
|
||||
|
||||
self.peers_by_ip
|
||||
.iter()
|
||||
.for_each(|p| p.1.transform.expire_dns_track());
|
||||
self.peers
|
||||
.iter_mut()
|
||||
.for_each(|p| p.transform.expire_dns_track());
|
||||
|
||||
for resource in self.dns_resources_internal_ips.keys() {
|
||||
let Some(gateway_id) = self.resources_gateways.get(&resource.id) else {
|
||||
continue;
|
||||
};
|
||||
// filter inactive connections
|
||||
if !self
|
||||
.peers_by_ip
|
||||
.iter()
|
||||
.any(|(_, p)| &p.conn_id == gateway_id)
|
||||
{
|
||||
if self.peers.get(gateway_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -761,7 +747,7 @@ impl Default for ClientState {
|
||||
dns_resources: Default::default(),
|
||||
cidr_resources: IpNetworkTable::new(),
|
||||
resource_ids: Default::default(),
|
||||
peers_by_ip: IpNetworkTable::new(),
|
||||
peers: Default::default(),
|
||||
deferred_dns_queries: Default::default(),
|
||||
refresh_dns_timer: interval,
|
||||
dns_mapping: Default::default(),
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use std::{collections::HashSet, fmt, hash::Hash, net::SocketAddr, sync::Arc};
|
||||
use std::{collections::HashSet, fmt, hash::Hash, net::SocketAddr};
|
||||
|
||||
use connlib_shared::{
|
||||
messages::{Relay, RequestConnection, ReuseConnection},
|
||||
Callbacks,
|
||||
};
|
||||
|
||||
use crate::{peer::Peer, Tunnel, REALM};
|
||||
use crate::{Tunnel, REALM};
|
||||
|
||||
mod client;
|
||||
pub mod gateway;
|
||||
@@ -18,7 +16,7 @@ pub enum Request {
|
||||
ReuseConnection(ReuseConnection),
|
||||
}
|
||||
|
||||
impl<CB, TRoleState, TRole, TId, TTransform> Tunnel<CB, TRoleState, TRole, TId, TTransform>
|
||||
impl<CB, TRoleState, TRole, TId> Tunnel<CB, TRoleState, TRole, TId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
TId: Eq + Hash + Copy + fmt::Display,
|
||||
@@ -30,16 +28,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_peers<TId: Copy, TTransform>(
|
||||
peers_by_ip: &mut IpNetworkTable<Arc<Peer<TId, TTransform>>>,
|
||||
ips: &Vec<IpNetwork>,
|
||||
peer: Arc<Peer<TId, TTransform>>,
|
||||
) {
|
||||
for ip in ips {
|
||||
peers_by_ip.insert(*ip, peer.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn stun(relays: &[Relay], predicate: impl Fn(&SocketAddr) -> bool) -> HashSet<SocketAddr> {
|
||||
relays
|
||||
.iter()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::HashSet, net::IpAddr, sync::Arc};
|
||||
use std::{collections::HashSet, net::IpAddr};
|
||||
|
||||
use boringtun::x25519::PublicKey;
|
||||
use connlib_shared::{
|
||||
@@ -23,9 +23,7 @@ use crate::{
|
||||
};
|
||||
use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel};
|
||||
|
||||
use super::insert_peers;
|
||||
|
||||
impl<CB> Tunnel<CB, ClientState, Client, GatewayId, PacketTransformClient>
|
||||
impl<CB> Tunnel<CB, ClientState, Client, GatewayId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
@@ -108,24 +106,18 @@ where
|
||||
&domain_response.as_ref().map(|d| d.domain.clone()),
|
||||
)?;
|
||||
|
||||
let peer = Arc::new(Peer::new(ips.clone(), gateway_id, Default::default()));
|
||||
let mut peer: Peer<_, PacketTransformClient> =
|
||||
Peer::new(ips.clone(), gateway_id, Default::default());
|
||||
peer.transform.set_dns(self.role_state.dns_mapping());
|
||||
self.role_state.peers.insert(peer);
|
||||
|
||||
let peer_ips = if let Some(domain_response) = domain_response {
|
||||
self.dns_response(&resource_id, &domain_response, &peer)?
|
||||
self.dns_response(&resource_id, &domain_response, &gateway_id)?
|
||||
} else {
|
||||
ips
|
||||
};
|
||||
|
||||
peer.transform.set_dns(self.role_state.dns_mapping());
|
||||
|
||||
// cleaning up old state
|
||||
self.role_state
|
||||
.peers_by_ip
|
||||
.retain(|_, p| p.conn_id != gateway_id);
|
||||
self.connections_state
|
||||
.peers_by_id
|
||||
.insert(gateway_id, Arc::clone(&peer));
|
||||
insert_peers(&mut self.role_state.peers_by_ip, &peer_ips, peer);
|
||||
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -168,8 +160,14 @@ where
|
||||
&mut self,
|
||||
resource_id: &ResourceId,
|
||||
domain_response: &DomainResponse,
|
||||
peer: &Peer<GatewayId, PacketTransformClient>,
|
||||
peer_id: &GatewayId,
|
||||
) -> Result<Vec<IpNetwork>> {
|
||||
let peer = self
|
||||
.role_state
|
||||
.peers
|
||||
.get_mut(peer_id)
|
||||
.ok_or(Error::ControlProtocolError)?;
|
||||
|
||||
let resource_description = self
|
||||
.role_state
|
||||
.resource_ids
|
||||
@@ -199,9 +197,6 @@ where
|
||||
.insert(resource_description.clone(), addrs.clone());
|
||||
|
||||
let ips: Vec<IpNetwork> = addrs.iter().copied().map(Into::into).collect();
|
||||
for ip in &ips {
|
||||
peer.add_allowed_ip(*ip);
|
||||
}
|
||||
|
||||
if let Some(device) = self.device.as_ref() {
|
||||
send_dns_answer(
|
||||
@@ -235,17 +230,10 @@ where
|
||||
.gateway_by_resource(&resource_id)
|
||||
.ok_or(Error::UnknownResource)?;
|
||||
|
||||
let Some(peer) = self
|
||||
.role_state
|
||||
.peers_by_ip
|
||||
.iter_mut()
|
||||
.find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p.clone()))
|
||||
else {
|
||||
return Err(Error::ControlProtocolError);
|
||||
};
|
||||
let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?;
|
||||
|
||||
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
|
||||
|
||||
let peer_ips = self.dns_response(&resource_id, &domain_response, &peer)?;
|
||||
insert_peers(&mut self.role_state.peers_by_ip, &peer_ips, peer);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::{
|
||||
control_protocol::insert_peers,
|
||||
dns::is_subdomain,
|
||||
peer::{PacketTransformGateway, Peer},
|
||||
Error, GatewayState, Tunnel,
|
||||
@@ -17,7 +16,6 @@ use connlib_shared::{
|
||||
use ip_network::IpNetwork;
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use snownet::{Credentials, Server};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Description of a resource that maps to a DNS record which had its domain already resolved.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
@@ -36,7 +34,7 @@ pub struct ResolvedResourceDescriptionDns {
|
||||
pub type ResourceDescription =
|
||||
connlib_shared::messages::ResourceDescription<ResolvedResourceDescriptionDns>;
|
||||
|
||||
impl<CB> Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>
|
||||
impl<CB> Tunnel<CB, GatewayState, Server, ClientId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
@@ -125,14 +123,7 @@ where
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
domain: Option<Dname>,
|
||||
) -> Option<DomainResponse> {
|
||||
let Some(peer) = self
|
||||
.role_state
|
||||
.peers_by_ip
|
||||
.iter_mut()
|
||||
.find_map(|(_, p)| (p.conn_id == client).then_some(p.clone()))
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
let peer = self.role_state.peers.get_mut(&client)?;
|
||||
|
||||
let (addresses, resource_id) = match &resource {
|
||||
ResourceDescription::Dns(r) => {
|
||||
@@ -176,25 +167,15 @@ where
|
||||
) -> Result<()> {
|
||||
tracing::trace!(?ips, "new_data_channel_open");
|
||||
|
||||
let peer = Arc::new(Peer::new(
|
||||
ips.clone(),
|
||||
client_id,
|
||||
PacketTransformGateway::default(),
|
||||
));
|
||||
let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default());
|
||||
|
||||
for address in resource_addresses {
|
||||
peer.transform
|
||||
.add_resource(address, resource.clone(), expires_at);
|
||||
}
|
||||
|
||||
// cleaning up old state
|
||||
self.role_state
|
||||
.peers_by_ip
|
||||
.retain(|_, p| p.conn_id != client_id);
|
||||
self.connections_state
|
||||
.peers_by_id
|
||||
.insert(client_id, Arc::clone(&peer));
|
||||
insert_peers(&mut self.role_state.peers_by_ip, &ips, peer);
|
||||
self.role_state.peers.insert(peer);
|
||||
self.role_state.peers.add_ips(&client_id, &ips);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
use crate::device_channel::Device;
|
||||
use crate::ip_packet::MutableIpPacket;
|
||||
use crate::peer::{PacketTransformGateway, Peer};
|
||||
use crate::{peer_by_ip, Tunnel};
|
||||
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
|
||||
use connlib_shared::Callbacks;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use itertools::Itertools;
|
||||
use snownet::Server;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::device_channel::Device;
|
||||
use crate::ip_packet::MutableIpPacket;
|
||||
use crate::peer::PacketTransformGateway;
|
||||
use crate::peer_store::PeerStore;
|
||||
use crate::Tunnel;
|
||||
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
|
||||
use connlib_shared::Callbacks;
|
||||
use snownet::Server;
|
||||
use tokio::time::{interval, Interval, MissedTickBehavior};
|
||||
|
||||
const PEERS_IPV4: &str = "100.64.0.0/11";
|
||||
const PEERS_IPV6: &str = "fd00:2021:1111::/107";
|
||||
|
||||
impl<CB> Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>
|
||||
impl<CB> Tunnel<CB, GatewayState, Server, ClientId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
@@ -38,16 +37,14 @@ where
|
||||
}
|
||||
|
||||
/// Clean up a connection to a resource.
|
||||
pub fn cleanup_connection(&mut self, id: ClientId) {
|
||||
self.connections_state.peers_by_id.remove(&id);
|
||||
self.role_state.peers_by_ip.retain(|_, p| p.conn_id != id);
|
||||
pub fn cleanup_connection(&mut self, id: &ClientId) {
|
||||
self.role_state.peers.remove(id);
|
||||
}
|
||||
}
|
||||
|
||||
/// [`Tunnel`] state specific to gateways.
|
||||
pub struct GatewayState {
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub peers_by_ip: IpNetworkTable<Arc<Peer<ClientId, PacketTransformGateway>>>,
|
||||
pub peers: PeerStore<ClientId, PacketTransformGateway>,
|
||||
expire_interval: Interval,
|
||||
}
|
||||
|
||||
@@ -58,29 +55,23 @@ impl GatewayState {
|
||||
) -> Option<(ClientId, MutableIpPacket<'a>)> {
|
||||
let dest = packet.destination();
|
||||
|
||||
let peer = peer_by_ip(&self.peers_by_ip, dest)?;
|
||||
let peer = self.peers.peer_by_ip_mut(dest)?;
|
||||
let packet = peer.transform(packet)?;
|
||||
|
||||
Some((peer.conn_id, packet))
|
||||
}
|
||||
|
||||
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Vec<ClientId>> {
|
||||
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
|
||||
ready!(self.expire_interval.poll_tick(cx));
|
||||
Poll::Ready(self.expire_resources().collect_vec())
|
||||
self.expire_resources();
|
||||
Poll::Ready(())
|
||||
}
|
||||
|
||||
fn expire_resources(&self) -> impl Iterator<Item = ClientId> + '_ {
|
||||
self.peers_by_ip
|
||||
.iter()
|
||||
.unique_by(|(_, p)| p.conn_id)
|
||||
.for_each(|(_, p)| p.transform.expire_resources());
|
||||
self.peers_by_ip.iter().filter_map(|(_, p)| {
|
||||
if p.transform.is_emptied() {
|
||||
Some(p.conn_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
fn expire_resources(&mut self) {
|
||||
self.peers
|
||||
.iter_mut()
|
||||
.for_each(|p| p.transform.expire_resources());
|
||||
self.peers.retain(|_, p| !p.transform.is_emptied());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +80,7 @@ impl Default for GatewayState {
|
||||
let mut expire_interval = interval(Duration::from_secs(1));
|
||||
expire_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
Self {
|
||||
peers_by_ip: IpNetworkTable::new(),
|
||||
peers: Default::default(),
|
||||
expire_interval,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,18 +10,16 @@ use connlib_shared::{
|
||||
};
|
||||
use device_channel::Device;
|
||||
use futures_util::{future::BoxFuture, task::AtomicWaker, FutureExt};
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use peer::{PacketTransform, PacketTransformClient, PacketTransformGateway, Peer, PeerStats};
|
||||
use peer::PacketTransform;
|
||||
use peer_store::PeerStore;
|
||||
use pnet_packet::Packet;
|
||||
use snownet::{IpPacket, Node, Server};
|
||||
use sockets::{Received, Sockets};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::HashSet,
|
||||
fmt,
|
||||
hash::Hash,
|
||||
io,
|
||||
net::IpAddr,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
time::Instant,
|
||||
};
|
||||
@@ -37,6 +35,7 @@ mod dns;
|
||||
mod gateway;
|
||||
mod ip_packet;
|
||||
mod peer;
|
||||
mod peer_store;
|
||||
mod sockets;
|
||||
|
||||
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
|
||||
@@ -47,12 +46,11 @@ const REALM: &str = "firezone";
|
||||
#[cfg(target_os = "linux")]
|
||||
const FIREZONE_MARK: u32 = 0xfd002021;
|
||||
|
||||
pub type GatewayTunnel<CB> = Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>;
|
||||
pub type ClientTunnel<CB> =
|
||||
Tunnel<CB, ClientState, snownet::Client, GatewayId, PacketTransformClient>;
|
||||
pub type GatewayTunnel<CB> = Tunnel<CB, GatewayState, Server, ClientId>;
|
||||
pub type ClientTunnel<CB> = Tunnel<CB, ClientState, snownet::Client, GatewayId>;
|
||||
|
||||
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets to communicate between peers.
|
||||
pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId, TTransform> {
|
||||
pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId> {
|
||||
callbacks: CallbackErrorFacade<CB>,
|
||||
|
||||
/// State that differs per role, i.e. clients vs gateways.
|
||||
@@ -61,12 +59,12 @@ pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId, TTransform> {
|
||||
device: Option<Device>,
|
||||
no_device_waker: AtomicWaker,
|
||||
|
||||
connections_state: ConnectionState<TRole, TId, TTransform>,
|
||||
connections_state: ConnectionState<TRole, TId>,
|
||||
|
||||
read_buf: [u8; MAX_UDP_SIZE],
|
||||
}
|
||||
|
||||
impl<CB> Tunnel<CB, ClientState, snownet::Client, GatewayId, PacketTransformClient>
|
||||
impl<CB> Tunnel<CB, ClientState, snownet::Client, GatewayId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
@@ -94,7 +92,10 @@ where
|
||||
_ => (),
|
||||
}
|
||||
|
||||
match self.connections_state.poll_sockets(device, cx)? {
|
||||
match self
|
||||
.connections_state
|
||||
.poll_sockets(device, &mut self.role_state.peers, cx)?
|
||||
{
|
||||
Poll::Ready(()) => {
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
@@ -129,17 +130,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<CB> Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>
|
||||
impl<CB> Tunnel<CB, GatewayState, Server, ClientId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<Event<ClientId>>> {
|
||||
match self.role_state.poll(cx) {
|
||||
Poll::Ready(ids) => {
|
||||
Poll::Ready(()) => {
|
||||
cx.waker().wake_by_ref();
|
||||
for id in ids {
|
||||
self.cleanup_connection(id);
|
||||
}
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
@@ -151,14 +149,17 @@ where
|
||||
|
||||
match self.connections_state.poll_next_event(cx) {
|
||||
Poll::Ready(Event::StopPeer(id)) => {
|
||||
self.role_state.peers_by_ip.retain(|_, p| p.conn_id != id);
|
||||
self.role_state.peers.remove(&id);
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
Poll::Ready(other) => return Poll::Ready(Ok(other)),
|
||||
_ => (),
|
||||
}
|
||||
|
||||
match self.connections_state.poll_sockets(device, cx)? {
|
||||
match self
|
||||
.connections_state
|
||||
.poll_sockets(device, &mut self.role_state.peers, cx)?
|
||||
{
|
||||
Poll::Ready(()) => {
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
@@ -195,17 +196,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TunnelStats<TId> {
|
||||
peer_connections: HashMap<TId, PeerStats<TId>>,
|
||||
}
|
||||
|
||||
impl<CB, TRoleState, TRole, TId, TTransform> Tunnel<CB, TRoleState, TRole, TId, TTransform>
|
||||
impl<CB, TRoleState, TRole, TId> Tunnel<CB, TRoleState, TRole, TId>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
TId: Eq + Hash + Copy + fmt::Display,
|
||||
TTransform: PacketTransform,
|
||||
TRoleState: Default,
|
||||
{
|
||||
/// Creates a new tunnel.
|
||||
@@ -242,34 +236,23 @@ where
|
||||
pub fn callbacks(&self) -> &CallbackErrorFacade<CB> {
|
||||
&self.callbacks
|
||||
}
|
||||
|
||||
pub fn stats(&self) -> HashMap<TId, PeerStats<TId>> {
|
||||
self.connections_state
|
||||
.peers_by_id
|
||||
.iter()
|
||||
.map(|(&id, p)| (id, p.stats()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionState<TRole, TId, TTransform> {
|
||||
struct ConnectionState<TRole, TId> {
|
||||
pub node: Node<TRole, TId>,
|
||||
write_buf: Box<[u8; MAX_UDP_SIZE]>,
|
||||
peers_by_id: HashMap<TId, Arc<Peer<TId, TTransform>>>,
|
||||
connection_pool_timeout: BoxFuture<'static, std::time::Instant>,
|
||||
sockets: Sockets,
|
||||
}
|
||||
|
||||
impl<TRole, TId, TTransform> ConnectionState<TRole, TId, TTransform>
|
||||
impl<TRole, TId> ConnectionState<TRole, TId>
|
||||
where
|
||||
TId: Eq + Hash + Copy + fmt::Display,
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
fn new(private_key: StaticSecret) -> Result<Self> {
|
||||
Ok(ConnectionState {
|
||||
node: Node::new(private_key, std::time::Instant::now()),
|
||||
write_buf: Box::new([0; MAX_UDP_SIZE]),
|
||||
peers_by_id: HashMap::new(),
|
||||
connection_pool_timeout: sleep_until(std::time::Instant::now()).boxed(),
|
||||
sockets: Sockets::new()?,
|
||||
})
|
||||
@@ -294,7 +277,16 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_sockets(&mut self, device: &mut Device, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
// TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation
|
||||
fn poll_sockets<TTransform>(
|
||||
&mut self,
|
||||
device: &mut Device,
|
||||
peer_store: &mut PeerStore<TId, TTransform>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
let received = match ready!(self.sockets.poll_recv_from(cx)) {
|
||||
Ok(received) => received,
|
||||
Err(e) => {
|
||||
@@ -332,7 +324,7 @@ where
|
||||
|
||||
tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet");
|
||||
|
||||
let Some(peer) = self.peers_by_id.get(&conn_id) else {
|
||||
let Some(peer) = peer_store.get_mut(&conn_id) else {
|
||||
tracing::error!(%conn_id, %local, %from, "Couldn't find connection");
|
||||
|
||||
continue;
|
||||
@@ -378,7 +370,6 @@ where
|
||||
});
|
||||
}
|
||||
Some(snownet::Event::ConnectionFailed(id)) => {
|
||||
self.peers_by_id.remove(&id);
|
||||
return Poll::Ready(Event::StopPeer(id));
|
||||
}
|
||||
_ => {}
|
||||
@@ -397,13 +388,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn peer_by_ip<Id, TTransform>(
|
||||
peers_by_ip: &IpNetworkTable<Arc<Peer<Id, TTransform>>>,
|
||||
ip: IpAddr,
|
||||
) -> Option<&Peer<Id, TTransform>> {
|
||||
peers_by_ip.longest_match(ip).map(|(_, peer)| peer.as_ref())
|
||||
}
|
||||
|
||||
pub enum Event<TId> {
|
||||
SignalIceCandidate {
|
||||
conn_id: TId,
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use bimap::BiMap;
|
||||
use boringtun::noise::Tunn;
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -13,7 +11,6 @@ use connlib_shared::IpProvider;
|
||||
use connlib_shared::{Error, Result};
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use pnet_packet::Packet;
|
||||
|
||||
use crate::control_protocol::gateway::ResourceDescription;
|
||||
@@ -26,31 +23,16 @@ type ExpiryingResource = (ResourceDescription, Option<DateTime<Utc>>);
|
||||
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
pub struct Peer<TId, TTransform> {
|
||||
allowed_ips: RwLock<IpNetworkTable<()>>,
|
||||
allowed_ips: IpNetworkTable<()>,
|
||||
pub conn_id: TId,
|
||||
pub transform: TTransform,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PeerStats<TId> {
|
||||
pub allowed_ips: Vec<IpNetwork>,
|
||||
pub conn_id: TId,
|
||||
}
|
||||
|
||||
impl<TId, TTransform> Peer<TId, TTransform>
|
||||
where
|
||||
TId: Copy,
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
pub fn stats(&self) -> PeerStats<TId> {
|
||||
let allowed_ips = self.allowed_ips.read().iter().map(|(ip, _)| ip).collect();
|
||||
PeerStats {
|
||||
allowed_ips,
|
||||
conn_id: self.conn_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
ips: Vec<IpNetwork>,
|
||||
conn_id: TId,
|
||||
@@ -60,7 +42,6 @@ where
|
||||
for ip in ips {
|
||||
allowed_ips.insert(ip, ());
|
||||
}
|
||||
let allowed_ips = RwLock::new(allowed_ips);
|
||||
|
||||
Peer {
|
||||
allowed_ips,
|
||||
@@ -69,21 +50,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_allowed_ip(&self, ip: IpNetwork) {
|
||||
self.allowed_ips.write().insert(ip, ());
|
||||
pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) {
|
||||
self.allowed_ips.insert(ip, ());
|
||||
}
|
||||
|
||||
fn is_allowed(&self, addr: IpAddr) -> bool {
|
||||
self.allowed_ips.read().longest_match(addr).is_some()
|
||||
self.allowed_ips.longest_match(addr).is_some()
|
||||
}
|
||||
|
||||
/// Sends the given packet to this peer by encapsulating it in a wireguard packet.
|
||||
pub(crate) fn transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
|
||||
pub(crate) fn transform<'a>(
|
||||
&mut self,
|
||||
packet: MutableIpPacket<'a>,
|
||||
) -> Option<MutableIpPacket<'a>> {
|
||||
self.transform.packet_transform(packet)
|
||||
}
|
||||
|
||||
pub(crate) fn untransform<'b>(
|
||||
&self,
|
||||
&mut self,
|
||||
addr: IpAddr,
|
||||
packet: &'b mut [u8],
|
||||
) -> Result<device_channel::Packet<'b>> {
|
||||
@@ -98,86 +82,83 @@ where
|
||||
}
|
||||
|
||||
pub struct PacketTransformGateway {
|
||||
resources: RwLock<IpNetworkTable<ExpiryingResource>>,
|
||||
resources: IpNetworkTable<ExpiryingResource>,
|
||||
}
|
||||
|
||||
impl Default for PacketTransformGateway {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
resources: RwLock::new(IpNetworkTable::new()),
|
||||
resources: IpNetworkTable::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct PacketTransformClient {
|
||||
translations: RwLock<BiMap<IpAddr, IpAddr>>,
|
||||
dns_mapping: ArcSwap<BiMap<IpAddr, DnsServer>>,
|
||||
mangled_dns_ids: Mutex<HashMap<u16, std::time::Instant>>,
|
||||
translations: BiMap<IpAddr, IpAddr>,
|
||||
dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
mangled_dns_ids: HashMap<u16, std::time::Instant>,
|
||||
}
|
||||
|
||||
impl PacketTransformClient {
|
||||
pub fn get_or_assign_translation(
|
||||
&self,
|
||||
&mut self,
|
||||
ip: &IpAddr,
|
||||
ip_provider: &mut IpProvider,
|
||||
) -> Option<IpAddr> {
|
||||
let mut translations = self.translations.write();
|
||||
if let Some(proxy_ip) = translations.get_by_right(ip) {
|
||||
if let Some(proxy_ip) = self.translations.get_by_right(ip) {
|
||||
return Some(*proxy_ip);
|
||||
}
|
||||
|
||||
let proxy_ip = ip_provider.get_proxy_ip_for(ip)?;
|
||||
|
||||
translations.insert(proxy_ip, *ip);
|
||||
self.translations.insert(proxy_ip, *ip);
|
||||
Some(proxy_ip)
|
||||
}
|
||||
|
||||
pub fn expire_dns_track(&self) {
|
||||
pub fn expire_dns_track(&mut self) {
|
||||
self.mangled_dns_ids
|
||||
.lock()
|
||||
.retain(|_, exp| exp.elapsed() < IDS_EXPIRE);
|
||||
}
|
||||
|
||||
pub fn set_dns(&self, mapping: BiMap<IpAddr, DnsServer>) {
|
||||
self.dns_mapping.store(Arc::new(mapping));
|
||||
pub fn set_dns(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
|
||||
self.dns_mapping = mapping;
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketTransformGateway {
|
||||
pub(crate) fn is_emptied(&self) -> bool {
|
||||
self.resources.read().is_empty()
|
||||
self.resources.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn expire_resources(&self) {
|
||||
pub(crate) fn expire_resources(&mut self) {
|
||||
self.resources
|
||||
.write()
|
||||
.retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now()));
|
||||
}
|
||||
|
||||
pub(crate) fn add_resource(
|
||||
&self,
|
||||
&mut self,
|
||||
ip: IpNetwork,
|
||||
resource: ResourceDescription,
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
) {
|
||||
self.resources.write().insert(ip, (resource, expires_at));
|
||||
self.resources.insert(ip, (resource, expires_at));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PacketTransform {
|
||||
fn packet_untransform<'a>(
|
||||
&self,
|
||||
&mut self,
|
||||
addr: &IpAddr,
|
||||
packet: &'a mut [u8],
|
||||
) -> Result<(device_channel::Packet<'a>, IpAddr)>;
|
||||
|
||||
fn packet_transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>>;
|
||||
fn packet_transform<'a>(&mut self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>>;
|
||||
}
|
||||
|
||||
impl PacketTransform for PacketTransformGateway {
|
||||
fn packet_untransform<'a>(
|
||||
&self,
|
||||
&mut self,
|
||||
addr: &IpAddr,
|
||||
packet: &'a mut [u8],
|
||||
) -> Result<(device_channel::Packet<'a>, IpAddr)> {
|
||||
@@ -185,7 +166,7 @@ impl PacketTransform for PacketTransformGateway {
|
||||
return Err(Error::BadPacket);
|
||||
};
|
||||
|
||||
if self.resources.read().longest_match(dst).is_some() {
|
||||
if self.resources.longest_match(dst).is_some() {
|
||||
let packet = make_packet(packet, addr);
|
||||
Ok((packet, *addr))
|
||||
} else {
|
||||
@@ -194,19 +175,18 @@ impl PacketTransform for PacketTransformGateway {
|
||||
}
|
||||
}
|
||||
|
||||
fn packet_transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
|
||||
fn packet_transform<'a>(&mut self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
|
||||
Some(packet)
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketTransform for PacketTransformClient {
|
||||
fn packet_untransform<'a>(
|
||||
&self,
|
||||
&mut self,
|
||||
addr: &IpAddr,
|
||||
packet: &'a mut [u8],
|
||||
) -> Result<(device_channel::Packet<'a>, IpAddr)> {
|
||||
let translations = self.translations.read();
|
||||
let mut src = *translations.get_by_right(addr).unwrap_or(addr);
|
||||
let mut src = *self.translations.get_by_right(addr).unwrap_or(addr);
|
||||
|
||||
let Some(mut pkt) = MutableIpPacket::new(packet) else {
|
||||
return Err(Error::BadPacket);
|
||||
@@ -216,14 +196,11 @@ impl PacketTransform for PacketTransformClient {
|
||||
if let Some(dgm) = pkt.as_udp() {
|
||||
if let Some(sentinel) = self
|
||||
.dns_mapping
|
||||
.load()
|
||||
.as_ref()
|
||||
.get_by_right(&(src, dgm.get_source()).into())
|
||||
{
|
||||
if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) {
|
||||
if self
|
||||
.mangled_dns_ids
|
||||
.lock()
|
||||
.remove(&message.header().id())
|
||||
.is_some_and(|exp| exp.elapsed() < IDS_EXPIRE)
|
||||
{
|
||||
@@ -239,22 +216,19 @@ impl PacketTransform for PacketTransformClient {
|
||||
Ok((packet, original_src))
|
||||
}
|
||||
|
||||
fn packet_transform<'a>(&self, mut packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
|
||||
if let Some(translated_ip) = self.translations.read().get_by_left(&packet.destination()) {
|
||||
fn packet_transform<'a>(
|
||||
&mut self,
|
||||
mut packet: MutableIpPacket<'a>,
|
||||
) -> Option<MutableIpPacket<'a>> {
|
||||
if let Some(translated_ip) = self.translations.get_by_left(&packet.destination()) {
|
||||
packet.set_dst(*translated_ip);
|
||||
packet.update_checksum();
|
||||
}
|
||||
|
||||
if let Some(srv) = self
|
||||
.dns_mapping
|
||||
.load()
|
||||
.as_ref()
|
||||
.get_by_left(&packet.destination())
|
||||
{
|
||||
if let Some(srv) = self.dns_mapping.get_by_left(&packet.destination()) {
|
||||
if let Some(dgm) = packet.as_udp() {
|
||||
if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) {
|
||||
self.mangled_dns_ids
|
||||
.lock()
|
||||
.insert(message.header().id(), Instant::now());
|
||||
packet.set_dst(srv.ip());
|
||||
packet.update_checksum();
|
||||
|
||||
141
rust/connlib/tunnel/src/peer_store.rs
Normal file
141
rust/connlib/tunnel/src/peer_store.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::net::IpAddr;
|
||||
|
||||
use crate::peer::{PacketTransform, Peer};
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
|
||||
pub struct PeerStore<TId, TTransform> {
|
||||
id_by_ip: IpNetworkTable<TId>,
|
||||
peer_by_id: HashMap<TId, Peer<TId, TTransform>>,
|
||||
}
|
||||
|
||||
impl<T, U> Default for PeerStore<T, U> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
id_by_ip: IpNetworkTable::new(),
|
||||
peer_by_id: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TId, TTransform> PeerStore<TId, TTransform>
|
||||
where
|
||||
TId: Hash + Eq + Clone + Copy,
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer<TId, TTransform>) -> bool) {
|
||||
self.peer_by_id.retain(f);
|
||||
self.id_by_ip
|
||||
.retain(|_, id| self.peer_by_id.contains_key(id));
|
||||
}
|
||||
|
||||
pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer<TId, TTransform>> {
|
||||
let peer = self.peer_by_id.get_mut(id)?;
|
||||
|
||||
for ip in ips {
|
||||
self.id_by_ip.insert(*ip, peer.conn_id);
|
||||
peer.add_allowed_ip(*ip);
|
||||
}
|
||||
|
||||
Some(peer)
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, peer: Peer<TId, TTransform>) -> Option<Peer<TId, TTransform>> {
|
||||
self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id);
|
||||
|
||||
self.peer_by_id.insert(peer.conn_id, peer)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: &TId) -> Option<Peer<TId, TTransform>> {
|
||||
self.id_by_ip.retain(|_, r_id| r_id != id);
|
||||
self.peer_by_id.remove(id)
|
||||
}
|
||||
|
||||
pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer<TId, TTransform>> {
|
||||
let ip = self.id_by_ip.exact_match(ip)?;
|
||||
self.peer_by_id.get(ip)
|
||||
}
|
||||
|
||||
pub fn get(&self, id: &TId) -> Option<&Peer<TId, TTransform>> {
|
||||
self.peer_by_id.get(id)
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer<TId, TTransform>> {
|
||||
self.peer_by_id.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer<TId, TTransform>> {
|
||||
let (_, id) = self.id_by_ip.longest_match(ip)?;
|
||||
self.peer_by_id.get(id)
|
||||
}
|
||||
|
||||
pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer<TId, TTransform>> {
|
||||
let (_, id) = self.id_by_ip.longest_match(ip)?;
|
||||
self.peer_by_id.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<TId, TTransform>> {
|
||||
self.peer_by_id.values_mut()
|
||||
}
|
||||
|
||||
pub fn iter(&mut self) -> impl Iterator<Item = &Peer<TId, TTransform>> {
|
||||
self.peer_by_id.values()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::peer::{PacketTransformGateway, Peer};
|
||||
|
||||
use super::PeerStore;
|
||||
|
||||
#[test]
|
||||
fn can_insert_and_retrieve_peer() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
assert!(peer_storage.get(&0).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_insert_and_retrieve_peer_by_ip() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
|
||||
|
||||
assert_eq!(
|
||||
peer_storage
|
||||
.peer_by_ip("100.0.0.1".parse().unwrap())
|
||||
.unwrap()
|
||||
.conn_id,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_remove_peer() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
|
||||
peer_storage.remove(&0);
|
||||
|
||||
assert!(peer_storage.get(&0).is_none());
|
||||
assert!(peer_storage
|
||||
.peer_by_ip("100.0.0.1".parse().unwrap())
|
||||
.is_none())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inserting_peer_removes_previous_instances_of_same_id() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
|
||||
assert!(peer_storage.get(&0).is_some());
|
||||
assert!(peer_storage
|
||||
.peer_by_ip("100.0.0.1".parse().unwrap())
|
||||
.is_none())
|
||||
}
|
||||
}
|
||||
@@ -29,7 +29,6 @@ pub struct Eventloop {
|
||||
Result<ResourceDescription<ResolvedResourceDescriptionDns>>,
|
||||
Either<RequestConnection, AllowAccess>,
|
||||
>,
|
||||
print_stats_timer: tokio::time::Interval,
|
||||
}
|
||||
|
||||
impl Eventloop {
|
||||
@@ -41,7 +40,6 @@ impl Eventloop {
|
||||
tunnel,
|
||||
portal,
|
||||
resolve_tasks: futures_bounded::FuturesTupleSet::new(Duration::from_secs(60), 100),
|
||||
print_stats_timer: tokio::time::interval(Duration::from_secs(10)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,7 +102,7 @@ impl Eventloop {
|
||||
Err(e) => {
|
||||
let client = req.client.id;
|
||||
|
||||
self.tunnel.cleanup_connection(client);
|
||||
self.tunnel.cleanup_connection(&client);
|
||||
tracing::debug!(%client, "Connection request failed: {:#}", anyhow::Error::new(e));
|
||||
|
||||
continue;
|
||||
@@ -216,11 +214,6 @@ impl Eventloop {
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if self.print_stats_timer.poll_tick(cx).is_ready() {
|
||||
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user