refactor(connlib): unify peer storage (#3738)

Now that we have `&mut` access everywhere in the tunnel, the remaining
shared-memory and locks are in how we store peers. To resolve this, we
introduce a new `PeerStore` that allows us to look up peers by IP and by
ID.
This commit is contained in:
Gabi
2024-02-26 13:07:38 -03:00
committed by GitHub
parent 220c9ee1e1
commit 5edd195320
13 changed files with 282 additions and 266 deletions

2
rust/Cargo.lock generated
View File

@@ -1922,7 +1922,6 @@ name = "firezone-tunnel"
version = "1.0.0"
dependencies = [
"anyhow",
"arc-swap",
"async-trait",
"bimap",
"boringtun",
@@ -1942,7 +1941,6 @@ dependencies = [
"log",
"netlink-packet-core",
"netlink-packet-route",
"parking_lot",
"pnet_packet",
"quinn-udp",
"rand_core 0.6.4",

View File

@@ -311,10 +311,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
Ok(())
}
pub async fn stats_event(&mut self) {
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
}
pub async fn request_log_upload_url(&mut self) {
tracing::info!("Requesting log upload URL from portal");

View File

@@ -185,7 +185,6 @@ where
let runtime_stopper = runtime_stopper.clone();
let callbacks = callbacks.clone();
async move {
let mut log_stats_interval = tokio::time::interval(Duration::from_secs(10));
let mut upload_logs_interval = upload_interval();
loop {
tokio::select! {
@@ -201,7 +200,6 @@ where
}
},
event = poll_fn(|cx| control_plane.tunnel.poll_next_event(cx)) => control_plane.handle_tunnel_event(event).await,
_ = log_stats_interval.tick() => control_plane.stats_event().await,
_ = upload_logs_interval.tick() => control_plane.request_log_upload_url().await,
else => break
}

View File

@@ -14,7 +14,6 @@ serde = { version = "1.0", default-features = false, features = ["derive", "std"
futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] }
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] }
tracing = { workspace = true }
parking_lot = { version = "0.12", default-features = false }
bytes = { version = "1.4", default-features = false, features = ["std"] }
itertools = { version = "0.12", default-features = false, features = ["use_std"] }
connlib-shared = { workspace = true }
@@ -27,7 +26,6 @@ chrono = { workspace = true }
pnet_packet = { version = "0.34" }
futures-bounded = { workspace = true }
hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
arc-swap = "1.6.0"
bimap = "0.6"
resolv-conf = "0.7.0"
socket2 = { version = "0.5" }

View File

@@ -1,7 +1,8 @@
use crate::device_channel::{Device, Packet};
use crate::ip_packet::{IpPacket, MutableIpPacket};
use crate::peer::{PacketTransformClient, Peer};
use crate::{dns, dns::DnsQuery, peer_by_ip, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE};
use crate::peer::PacketTransformClient;
use crate::peer_store::PeerStore;
use crate::{dns, dns::DnsQuery, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE};
use bimap::BiMap;
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
use connlib_shared::messages::{
@@ -22,7 +23,6 @@ use hickory_resolver::TokioAsyncResolver;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::time::{Instant, Interval, MissedTickBehavior};
@@ -47,7 +47,7 @@ impl DnsResource {
}
}
impl<CB> Tunnel<CB, ClientState, Client, GatewayId, PacketTransformClient>
impl<CB> Tunnel<CB, ClientState, Client, GatewayId>
where
CB: Callbacks + 'static,
{
@@ -189,7 +189,7 @@ pub struct ClientState {
pub resource_ids: HashMap<ResourceId, ResourceDescription>,
pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>,
pub peers_by_ip: IpNetworkTable<Arc<Peer<GatewayId, PacketTransformClient>>>,
pub peers: PeerStore<GatewayId, PacketTransformClient>,
forwarded_dns_queries: FuturesTupleSet<
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
@@ -227,7 +227,7 @@ impl ClientState {
Err(non_dns_packet) => non_dns_packet,
};
let Some(peer) = peer_by_ip(&self.peers_by_ip, dest) else {
let Some(peer) = self.peers.peer_by_ip_mut(dest) else {
self.on_connection_intent_ip(dest);
return None;
};
@@ -309,7 +309,7 @@ impl ClientState {
let domain = self.get_awaiting_connection_domain(&resource)?.clone();
if self.is_connected_to(resource, &self.peers_by_ip, &domain) {
if self.is_connected_to(resource, &domain) {
return Err(Error::UnexpectedConnectionDetails);
}
@@ -332,11 +332,11 @@ impl ClientState {
self.resources_gateways.insert(resource, gateway);
let Some(peer) = self
.peers_by_ip
.iter()
.find_map(|(_, p)| (p.conn_id == gateway).then_some(p.clone()))
else {
if self
.peers
.add_ips(&gateway, &self.get_resource_ip(desc, &domain))
.is_none()
{
match self
.gateway_awaiting_connection_timers
// Note: we don't need to set a timer here because
@@ -357,10 +357,6 @@ impl ClientState {
return Ok(None);
};
for ip in self.get_resource_ip(desc, &domain) {
peer.add_allowed_ip(ip);
self.peers_by_ip.insert(ip, peer.clone());
}
self.awaiting_connection.remove(&resource);
self.awaiting_connection_timers.remove(resource);
@@ -531,19 +527,13 @@ impl ClientState {
self.awaiting_connection.contains_key(&resource.id())
}
fn is_connected_to(
&self,
resource: ResourceId,
connected_peers: &IpNetworkTable<Arc<Peer<GatewayId, PacketTransformClient>>>,
domain: &Option<Dname>,
) -> bool {
fn is_connected_to(&self, resource: ResourceId, domain: &Option<Dname>) -> bool {
let Some(resource) = self.resource_ids.get(&resource) else {
return false;
};
let ips = self.get_resource_ip(resource, domain);
ips.iter()
.any(|ip| connected_peers.exact_match(*ip).is_some())
ips.iter().any(|ip| self.peers.exact_match(*ip).is_some())
}
fn get_resource_ip(
@@ -571,7 +561,7 @@ impl ClientState {
}
pub fn cleanup_connected_gateway(&mut self, gateway_id: &GatewayId) {
self.peers_by_ip.retain(|_, p| p.conn_id != *gateway_id);
self.peers.remove(gateway_id);
self.dns_resources_internal_ips.retain(|resource, _| {
!self
.resources_gateways
@@ -668,20 +658,16 @@ impl ClientState {
if self.refresh_dns_timer.poll_tick(cx).is_ready() {
let mut connections = Vec::new();
self.peers_by_ip
.iter()
.for_each(|p| p.1.transform.expire_dns_track());
self.peers
.iter_mut()
.for_each(|p| p.transform.expire_dns_track());
for resource in self.dns_resources_internal_ips.keys() {
let Some(gateway_id) = self.resources_gateways.get(&resource.id) else {
continue;
};
// filter inactive connections
if !self
.peers_by_ip
.iter()
.any(|(_, p)| &p.conn_id == gateway_id)
{
if self.peers.get(gateway_id).is_none() {
continue;
}
@@ -761,7 +747,7 @@ impl Default for ClientState {
dns_resources: Default::default(),
cidr_resources: IpNetworkTable::new(),
resource_ids: Default::default(),
peers_by_ip: IpNetworkTable::new(),
peers: Default::default(),
deferred_dns_queries: Default::default(),
refresh_dns_timer: interval,
dns_mapping: Default::default(),

View File

@@ -1,13 +1,11 @@
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use std::{collections::HashSet, fmt, hash::Hash, net::SocketAddr, sync::Arc};
use std::{collections::HashSet, fmt, hash::Hash, net::SocketAddr};
use connlib_shared::{
messages::{Relay, RequestConnection, ReuseConnection},
Callbacks,
};
use crate::{peer::Peer, Tunnel, REALM};
use crate::{Tunnel, REALM};
mod client;
pub mod gateway;
@@ -18,7 +16,7 @@ pub enum Request {
ReuseConnection(ReuseConnection),
}
impl<CB, TRoleState, TRole, TId, TTransform> Tunnel<CB, TRoleState, TRole, TId, TTransform>
impl<CB, TRoleState, TRole, TId> Tunnel<CB, TRoleState, TRole, TId>
where
CB: Callbacks + 'static,
TId: Eq + Hash + Copy + fmt::Display,
@@ -30,16 +28,6 @@ where
}
}
fn insert_peers<TId: Copy, TTransform>(
peers_by_ip: &mut IpNetworkTable<Arc<Peer<TId, TTransform>>>,
ips: &Vec<IpNetwork>,
peer: Arc<Peer<TId, TTransform>>,
) {
for ip in ips {
peers_by_ip.insert(*ip, peer.clone());
}
}
fn stun(relays: &[Relay], predicate: impl Fn(&SocketAddr) -> bool) -> HashSet<SocketAddr> {
relays
.iter()

View File

@@ -1,4 +1,4 @@
use std::{collections::HashSet, net::IpAddr, sync::Arc};
use std::{collections::HashSet, net::IpAddr};
use boringtun::x25519::PublicKey;
use connlib_shared::{
@@ -23,9 +23,7 @@ use crate::{
};
use crate::{peer::Peer, ClientState, Error, Request, Result, Tunnel};
use super::insert_peers;
impl<CB> Tunnel<CB, ClientState, Client, GatewayId, PacketTransformClient>
impl<CB> Tunnel<CB, ClientState, Client, GatewayId>
where
CB: Callbacks + 'static,
{
@@ -108,24 +106,18 @@ where
&domain_response.as_ref().map(|d| d.domain.clone()),
)?;
let peer = Arc::new(Peer::new(ips.clone(), gateway_id, Default::default()));
let mut peer: Peer<_, PacketTransformClient> =
Peer::new(ips.clone(), gateway_id, Default::default());
peer.transform.set_dns(self.role_state.dns_mapping());
self.role_state.peers.insert(peer);
let peer_ips = if let Some(domain_response) = domain_response {
self.dns_response(&resource_id, &domain_response, &peer)?
self.dns_response(&resource_id, &domain_response, &gateway_id)?
} else {
ips
};
peer.transform.set_dns(self.role_state.dns_mapping());
// cleaning up old state
self.role_state
.peers_by_ip
.retain(|_, p| p.conn_id != gateway_id);
self.connections_state
.peers_by_id
.insert(gateway_id, Arc::clone(&peer));
insert_peers(&mut self.role_state.peers_by_ip, &peer_ips, peer);
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
Ok(())
}
@@ -168,8 +160,14 @@ where
&mut self,
resource_id: &ResourceId,
domain_response: &DomainResponse,
peer: &Peer<GatewayId, PacketTransformClient>,
peer_id: &GatewayId,
) -> Result<Vec<IpNetwork>> {
let peer = self
.role_state
.peers
.get_mut(peer_id)
.ok_or(Error::ControlProtocolError)?;
let resource_description = self
.role_state
.resource_ids
@@ -199,9 +197,6 @@ where
.insert(resource_description.clone(), addrs.clone());
let ips: Vec<IpNetwork> = addrs.iter().copied().map(Into::into).collect();
for ip in &ips {
peer.add_allowed_ip(*ip);
}
if let Some(device) = self.device.as_ref() {
send_dns_answer(
@@ -235,17 +230,10 @@ where
.gateway_by_resource(&resource_id)
.ok_or(Error::UnknownResource)?;
let Some(peer) = self
.role_state
.peers_by_ip
.iter_mut()
.find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p.clone()))
else {
return Err(Error::ControlProtocolError);
};
let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?;
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
let peer_ips = self.dns_response(&resource_id, &domain_response, &peer)?;
insert_peers(&mut self.role_state.peers_by_ip, &peer_ips, peer);
Ok(())
}
}

View File

@@ -1,5 +1,4 @@
use crate::{
control_protocol::insert_peers,
dns::is_subdomain,
peer::{PacketTransformGateway, Peer},
Error, GatewayState, Tunnel,
@@ -17,7 +16,6 @@ use connlib_shared::{
use ip_network::IpNetwork;
use secrecy::{ExposeSecret as _, Secret};
use snownet::{Credentials, Server};
use std::sync::Arc;
/// Description of a resource that maps to a DNS record which had its domain already resolved.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -36,7 +34,7 @@ pub struct ResolvedResourceDescriptionDns {
pub type ResourceDescription =
connlib_shared::messages::ResourceDescription<ResolvedResourceDescriptionDns>;
impl<CB> Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>
impl<CB> Tunnel<CB, GatewayState, Server, ClientId>
where
CB: Callbacks + 'static,
{
@@ -125,14 +123,7 @@ where
expires_at: Option<DateTime<Utc>>,
domain: Option<Dname>,
) -> Option<DomainResponse> {
let Some(peer) = self
.role_state
.peers_by_ip
.iter_mut()
.find_map(|(_, p)| (p.conn_id == client).then_some(p.clone()))
else {
return None;
};
let peer = self.role_state.peers.get_mut(&client)?;
let (addresses, resource_id) = match &resource {
ResourceDescription::Dns(r) => {
@@ -176,25 +167,15 @@ where
) -> Result<()> {
tracing::trace!(?ips, "new_data_channel_open");
let peer = Arc::new(Peer::new(
ips.clone(),
client_id,
PacketTransformGateway::default(),
));
let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default());
for address in resource_addresses {
peer.transform
.add_resource(address, resource.clone(), expires_at);
}
// cleaning up old state
self.role_state
.peers_by_ip
.retain(|_, p| p.conn_id != client_id);
self.connections_state
.peers_by_id
.insert(client_id, Arc::clone(&peer));
insert_peers(&mut self.role_state.peers_by_ip, &ips, peer);
self.role_state.peers.insert(peer);
self.role_state.peers.add_ips(&client_id, &ips);
Ok(())
}

View File

@@ -1,21 +1,20 @@
use crate::device_channel::Device;
use crate::ip_packet::MutableIpPacket;
use crate::peer::{PacketTransformGateway, Peer};
use crate::{peer_by_ip, Tunnel};
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::Callbacks;
use ip_network_table::IpNetworkTable;
use itertools::Itertools;
use snownet::Server;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use crate::device_channel::Device;
use crate::ip_packet::MutableIpPacket;
use crate::peer::PacketTransformGateway;
use crate::peer_store::PeerStore;
use crate::Tunnel;
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::Callbacks;
use snownet::Server;
use tokio::time::{interval, Interval, MissedTickBehavior};
const PEERS_IPV4: &str = "100.64.0.0/11";
const PEERS_IPV6: &str = "fd00:2021:1111::/107";
impl<CB> Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>
impl<CB> Tunnel<CB, GatewayState, Server, ClientId>
where
CB: Callbacks + 'static,
{
@@ -38,16 +37,14 @@ where
}
/// Clean up a connection to a resource.
pub fn cleanup_connection(&mut self, id: ClientId) {
self.connections_state.peers_by_id.remove(&id);
self.role_state.peers_by_ip.retain(|_, p| p.conn_id != id);
pub fn cleanup_connection(&mut self, id: &ClientId) {
self.role_state.peers.remove(id);
}
}
/// [`Tunnel`] state specific to gateways.
pub struct GatewayState {
#[allow(clippy::type_complexity)]
pub peers_by_ip: IpNetworkTable<Arc<Peer<ClientId, PacketTransformGateway>>>,
pub peers: PeerStore<ClientId, PacketTransformGateway>,
expire_interval: Interval,
}
@@ -58,29 +55,23 @@ impl GatewayState {
) -> Option<(ClientId, MutableIpPacket<'a>)> {
let dest = packet.destination();
let peer = peer_by_ip(&self.peers_by_ip, dest)?;
let peer = self.peers.peer_by_ip_mut(dest)?;
let packet = peer.transform(packet)?;
Some((peer.conn_id, packet))
}
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Vec<ClientId>> {
pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
ready!(self.expire_interval.poll_tick(cx));
Poll::Ready(self.expire_resources().collect_vec())
self.expire_resources();
Poll::Ready(())
}
fn expire_resources(&self) -> impl Iterator<Item = ClientId> + '_ {
self.peers_by_ip
.iter()
.unique_by(|(_, p)| p.conn_id)
.for_each(|(_, p)| p.transform.expire_resources());
self.peers_by_ip.iter().filter_map(|(_, p)| {
if p.transform.is_emptied() {
Some(p.conn_id)
} else {
None
}
})
fn expire_resources(&mut self) {
self.peers
.iter_mut()
.for_each(|p| p.transform.expire_resources());
self.peers.retain(|_, p| !p.transform.is_emptied());
}
}
@@ -89,7 +80,7 @@ impl Default for GatewayState {
let mut expire_interval = interval(Duration::from_secs(1));
expire_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
Self {
peers_by_ip: IpNetworkTable::new(),
peers: Default::default(),
expire_interval,
}
}

View File

@@ -10,18 +10,16 @@ use connlib_shared::{
};
use device_channel::Device;
use futures_util::{future::BoxFuture, task::AtomicWaker, FutureExt};
use ip_network_table::IpNetworkTable;
use peer::{PacketTransform, PacketTransformClient, PacketTransformGateway, Peer, PeerStats};
use peer::PacketTransform;
use peer_store::PeerStore;
use pnet_packet::Packet;
use snownet::{IpPacket, Node, Server};
use sockets::{Received, Sockets};
use std::{
collections::{HashMap, HashSet},
collections::HashSet,
fmt,
hash::Hash,
io,
net::IpAddr,
sync::Arc,
task::{ready, Context, Poll},
time::Instant,
};
@@ -37,6 +35,7 @@ mod dns;
mod gateway;
mod ip_packet;
mod peer;
mod peer_store;
mod sockets;
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
@@ -47,12 +46,11 @@ const REALM: &str = "firezone";
#[cfg(target_os = "linux")]
const FIREZONE_MARK: u32 = 0xfd002021;
pub type GatewayTunnel<CB> = Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>;
pub type ClientTunnel<CB> =
Tunnel<CB, ClientState, snownet::Client, GatewayId, PacketTransformClient>;
pub type GatewayTunnel<CB> = Tunnel<CB, GatewayState, Server, ClientId>;
pub type ClientTunnel<CB> = Tunnel<CB, ClientState, snownet::Client, GatewayId>;
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets to communicate between peers.
pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId, TTransform> {
pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId> {
callbacks: CallbackErrorFacade<CB>,
/// State that differs per role, i.e. clients vs gateways.
@@ -61,12 +59,12 @@ pub struct Tunnel<CB: Callbacks, TRoleState, TRole, TId, TTransform> {
device: Option<Device>,
no_device_waker: AtomicWaker,
connections_state: ConnectionState<TRole, TId, TTransform>,
connections_state: ConnectionState<TRole, TId>,
read_buf: [u8; MAX_UDP_SIZE],
}
impl<CB> Tunnel<CB, ClientState, snownet::Client, GatewayId, PacketTransformClient>
impl<CB> Tunnel<CB, ClientState, snownet::Client, GatewayId>
where
CB: Callbacks + 'static,
{
@@ -94,7 +92,10 @@ where
_ => (),
}
match self.connections_state.poll_sockets(device, cx)? {
match self
.connections_state
.poll_sockets(device, &mut self.role_state.peers, cx)?
{
Poll::Ready(()) => {
cx.waker().wake_by_ref();
}
@@ -129,17 +130,14 @@ where
}
}
impl<CB> Tunnel<CB, GatewayState, Server, ClientId, PacketTransformGateway>
impl<CB> Tunnel<CB, GatewayState, Server, ClientId>
where
CB: Callbacks + 'static,
{
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<Event<ClientId>>> {
match self.role_state.poll(cx) {
Poll::Ready(ids) => {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
for id in ids {
self.cleanup_connection(id);
}
}
Poll::Pending => {}
}
@@ -151,14 +149,17 @@ where
match self.connections_state.poll_next_event(cx) {
Poll::Ready(Event::StopPeer(id)) => {
self.role_state.peers_by_ip.retain(|_, p| p.conn_id != id);
self.role_state.peers.remove(&id);
cx.waker().wake_by_ref();
}
Poll::Ready(other) => return Poll::Ready(Ok(other)),
_ => (),
}
match self.connections_state.poll_sockets(device, cx)? {
match self
.connections_state
.poll_sockets(device, &mut self.role_state.peers, cx)?
{
Poll::Ready(()) => {
cx.waker().wake_by_ref();
}
@@ -195,17 +196,10 @@ where
}
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct TunnelStats<TId> {
peer_connections: HashMap<TId, PeerStats<TId>>,
}
impl<CB, TRoleState, TRole, TId, TTransform> Tunnel<CB, TRoleState, TRole, TId, TTransform>
impl<CB, TRoleState, TRole, TId> Tunnel<CB, TRoleState, TRole, TId>
where
CB: Callbacks + 'static,
TId: Eq + Hash + Copy + fmt::Display,
TTransform: PacketTransform,
TRoleState: Default,
{
/// Creates a new tunnel.
@@ -242,34 +236,23 @@ where
pub fn callbacks(&self) -> &CallbackErrorFacade<CB> {
&self.callbacks
}
pub fn stats(&self) -> HashMap<TId, PeerStats<TId>> {
self.connections_state
.peers_by_id
.iter()
.map(|(&id, p)| (id, p.stats()))
.collect()
}
}
struct ConnectionState<TRole, TId, TTransform> {
struct ConnectionState<TRole, TId> {
pub node: Node<TRole, TId>,
write_buf: Box<[u8; MAX_UDP_SIZE]>,
peers_by_id: HashMap<TId, Arc<Peer<TId, TTransform>>>,
connection_pool_timeout: BoxFuture<'static, std::time::Instant>,
sockets: Sockets,
}
impl<TRole, TId, TTransform> ConnectionState<TRole, TId, TTransform>
impl<TRole, TId> ConnectionState<TRole, TId>
where
TId: Eq + Hash + Copy + fmt::Display,
TTransform: PacketTransform,
{
fn new(private_key: StaticSecret) -> Result<Self> {
Ok(ConnectionState {
node: Node::new(private_key, std::time::Instant::now()),
write_buf: Box::new([0; MAX_UDP_SIZE]),
peers_by_id: HashMap::new(),
connection_pool_timeout: sleep_until(std::time::Instant::now()).boxed(),
sockets: Sockets::new()?,
})
@@ -294,7 +277,16 @@ where
Ok(())
}
fn poll_sockets(&mut self, device: &mut Device, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation
fn poll_sockets<TTransform>(
&mut self,
device: &mut Device,
peer_store: &mut PeerStore<TId, TTransform>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>>
where
TTransform: PacketTransform,
{
let received = match ready!(self.sockets.poll_recv_from(cx)) {
Ok(received) => received,
Err(e) => {
@@ -332,7 +324,7 @@ where
tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet");
let Some(peer) = self.peers_by_id.get(&conn_id) else {
let Some(peer) = peer_store.get_mut(&conn_id) else {
tracing::error!(%conn_id, %local, %from, "Couldn't find connection");
continue;
@@ -378,7 +370,6 @@ where
});
}
Some(snownet::Event::ConnectionFailed(id)) => {
self.peers_by_id.remove(&id);
return Poll::Ready(Event::StopPeer(id));
}
_ => {}
@@ -397,13 +388,6 @@ where
}
}
pub(crate) fn peer_by_ip<Id, TTransform>(
peers_by_ip: &IpNetworkTable<Arc<Peer<Id, TTransform>>>,
ip: IpAddr,
) -> Option<&Peer<Id, TTransform>> {
peers_by_ip.longest_match(ip).map(|(_, peer)| peer.as_ref())
}
pub enum Event<TId> {
SignalIceCandidate {
conn_id: TId,

View File

@@ -1,10 +1,8 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use bimap::BiMap;
use boringtun::noise::Tunn;
use chrono::{DateTime, Utc};
@@ -13,7 +11,6 @@ use connlib_shared::IpProvider;
use connlib_shared::{Error, Result};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use parking_lot::{Mutex, RwLock};
use pnet_packet::Packet;
use crate::control_protocol::gateway::ResourceDescription;
@@ -26,31 +23,16 @@ type ExpiryingResource = (ResourceDescription, Option<DateTime<Utc>>);
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
pub struct Peer<TId, TTransform> {
allowed_ips: RwLock<IpNetworkTable<()>>,
allowed_ips: IpNetworkTable<()>,
pub conn_id: TId,
pub transform: TTransform,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct PeerStats<TId> {
pub allowed_ips: Vec<IpNetwork>,
pub conn_id: TId,
}
impl<TId, TTransform> Peer<TId, TTransform>
where
TId: Copy,
TTransform: PacketTransform,
{
pub fn stats(&self) -> PeerStats<TId> {
let allowed_ips = self.allowed_ips.read().iter().map(|(ip, _)| ip).collect();
PeerStats {
allowed_ips,
conn_id: self.conn_id,
}
}
pub(crate) fn new(
ips: Vec<IpNetwork>,
conn_id: TId,
@@ -60,7 +42,6 @@ where
for ip in ips {
allowed_ips.insert(ip, ());
}
let allowed_ips = RwLock::new(allowed_ips);
Peer {
allowed_ips,
@@ -69,21 +50,24 @@ where
}
}
pub(crate) fn add_allowed_ip(&self, ip: IpNetwork) {
self.allowed_ips.write().insert(ip, ());
pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) {
self.allowed_ips.insert(ip, ());
}
fn is_allowed(&self, addr: IpAddr) -> bool {
self.allowed_ips.read().longest_match(addr).is_some()
self.allowed_ips.longest_match(addr).is_some()
}
/// Sends the given packet to this peer by encapsulating it in a wireguard packet.
pub(crate) fn transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
pub(crate) fn transform<'a>(
&mut self,
packet: MutableIpPacket<'a>,
) -> Option<MutableIpPacket<'a>> {
self.transform.packet_transform(packet)
}
pub(crate) fn untransform<'b>(
&self,
&mut self,
addr: IpAddr,
packet: &'b mut [u8],
) -> Result<device_channel::Packet<'b>> {
@@ -98,86 +82,83 @@ where
}
pub struct PacketTransformGateway {
resources: RwLock<IpNetworkTable<ExpiryingResource>>,
resources: IpNetworkTable<ExpiryingResource>,
}
impl Default for PacketTransformGateway {
fn default() -> Self {
Self {
resources: RwLock::new(IpNetworkTable::new()),
resources: IpNetworkTable::new(),
}
}
}
#[derive(Default)]
pub struct PacketTransformClient {
translations: RwLock<BiMap<IpAddr, IpAddr>>,
dns_mapping: ArcSwap<BiMap<IpAddr, DnsServer>>,
mangled_dns_ids: Mutex<HashMap<u16, std::time::Instant>>,
translations: BiMap<IpAddr, IpAddr>,
dns_mapping: BiMap<IpAddr, DnsServer>,
mangled_dns_ids: HashMap<u16, std::time::Instant>,
}
impl PacketTransformClient {
pub fn get_or_assign_translation(
&self,
&mut self,
ip: &IpAddr,
ip_provider: &mut IpProvider,
) -> Option<IpAddr> {
let mut translations = self.translations.write();
if let Some(proxy_ip) = translations.get_by_right(ip) {
if let Some(proxy_ip) = self.translations.get_by_right(ip) {
return Some(*proxy_ip);
}
let proxy_ip = ip_provider.get_proxy_ip_for(ip)?;
translations.insert(proxy_ip, *ip);
self.translations.insert(proxy_ip, *ip);
Some(proxy_ip)
}
pub fn expire_dns_track(&self) {
pub fn expire_dns_track(&mut self) {
self.mangled_dns_ids
.lock()
.retain(|_, exp| exp.elapsed() < IDS_EXPIRE);
}
pub fn set_dns(&self, mapping: BiMap<IpAddr, DnsServer>) {
self.dns_mapping.store(Arc::new(mapping));
pub fn set_dns(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
self.dns_mapping = mapping;
}
}
impl PacketTransformGateway {
pub(crate) fn is_emptied(&self) -> bool {
self.resources.read().is_empty()
self.resources.is_empty()
}
pub(crate) fn expire_resources(&self) {
pub(crate) fn expire_resources(&mut self) {
self.resources
.write()
.retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now()));
}
pub(crate) fn add_resource(
&self,
&mut self,
ip: IpNetwork,
resource: ResourceDescription,
expires_at: Option<DateTime<Utc>>,
) {
self.resources.write().insert(ip, (resource, expires_at));
self.resources.insert(ip, (resource, expires_at));
}
}
pub trait PacketTransform {
fn packet_untransform<'a>(
&self,
&mut self,
addr: &IpAddr,
packet: &'a mut [u8],
) -> Result<(device_channel::Packet<'a>, IpAddr)>;
fn packet_transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>>;
fn packet_transform<'a>(&mut self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>>;
}
impl PacketTransform for PacketTransformGateway {
fn packet_untransform<'a>(
&self,
&mut self,
addr: &IpAddr,
packet: &'a mut [u8],
) -> Result<(device_channel::Packet<'a>, IpAddr)> {
@@ -185,7 +166,7 @@ impl PacketTransform for PacketTransformGateway {
return Err(Error::BadPacket);
};
if self.resources.read().longest_match(dst).is_some() {
if self.resources.longest_match(dst).is_some() {
let packet = make_packet(packet, addr);
Ok((packet, *addr))
} else {
@@ -194,19 +175,18 @@ impl PacketTransform for PacketTransformGateway {
}
}
fn packet_transform<'a>(&self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
fn packet_transform<'a>(&mut self, packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
Some(packet)
}
}
impl PacketTransform for PacketTransformClient {
fn packet_untransform<'a>(
&self,
&mut self,
addr: &IpAddr,
packet: &'a mut [u8],
) -> Result<(device_channel::Packet<'a>, IpAddr)> {
let translations = self.translations.read();
let mut src = *translations.get_by_right(addr).unwrap_or(addr);
let mut src = *self.translations.get_by_right(addr).unwrap_or(addr);
let Some(mut pkt) = MutableIpPacket::new(packet) else {
return Err(Error::BadPacket);
@@ -216,14 +196,11 @@ impl PacketTransform for PacketTransformClient {
if let Some(dgm) = pkt.as_udp() {
if let Some(sentinel) = self
.dns_mapping
.load()
.as_ref()
.get_by_right(&(src, dgm.get_source()).into())
{
if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) {
if self
.mangled_dns_ids
.lock()
.remove(&message.header().id())
.is_some_and(|exp| exp.elapsed() < IDS_EXPIRE)
{
@@ -239,22 +216,19 @@ impl PacketTransform for PacketTransformClient {
Ok((packet, original_src))
}
fn packet_transform<'a>(&self, mut packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
if let Some(translated_ip) = self.translations.read().get_by_left(&packet.destination()) {
fn packet_transform<'a>(
&mut self,
mut packet: MutableIpPacket<'a>,
) -> Option<MutableIpPacket<'a>> {
if let Some(translated_ip) = self.translations.get_by_left(&packet.destination()) {
packet.set_dst(*translated_ip);
packet.update_checksum();
}
if let Some(srv) = self
.dns_mapping
.load()
.as_ref()
.get_by_left(&packet.destination())
{
if let Some(srv) = self.dns_mapping.get_by_left(&packet.destination()) {
if let Some(dgm) = packet.as_udp() {
if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) {
self.mangled_dns_ids
.lock()
.insert(message.header().id(), Instant::now());
packet.set_dst(srv.ip());
packet.update_checksum();

View File

@@ -0,0 +1,141 @@
use std::collections::HashMap;
use std::hash::Hash;
use std::net::IpAddr;
use crate::peer::{PacketTransform, Peer};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
pub struct PeerStore<TId, TTransform> {
id_by_ip: IpNetworkTable<TId>,
peer_by_id: HashMap<TId, Peer<TId, TTransform>>,
}
impl<T, U> Default for PeerStore<T, U> {
fn default() -> Self {
Self {
id_by_ip: IpNetworkTable::new(),
peer_by_id: HashMap::new(),
}
}
}
impl<TId, TTransform> PeerStore<TId, TTransform>
where
TId: Hash + Eq + Clone + Copy,
TTransform: PacketTransform,
{
pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer<TId, TTransform>) -> bool) {
self.peer_by_id.retain(f);
self.id_by_ip
.retain(|_, id| self.peer_by_id.contains_key(id));
}
pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer<TId, TTransform>> {
let peer = self.peer_by_id.get_mut(id)?;
for ip in ips {
self.id_by_ip.insert(*ip, peer.conn_id);
peer.add_allowed_ip(*ip);
}
Some(peer)
}
pub fn insert(&mut self, peer: Peer<TId, TTransform>) -> Option<Peer<TId, TTransform>> {
self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id);
self.peer_by_id.insert(peer.conn_id, peer)
}
pub fn remove(&mut self, id: &TId) -> Option<Peer<TId, TTransform>> {
self.id_by_ip.retain(|_, r_id| r_id != id);
self.peer_by_id.remove(id)
}
pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer<TId, TTransform>> {
let ip = self.id_by_ip.exact_match(ip)?;
self.peer_by_id.get(ip)
}
pub fn get(&self, id: &TId) -> Option<&Peer<TId, TTransform>> {
self.peer_by_id.get(id)
}
pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer<TId, TTransform>> {
self.peer_by_id.get_mut(id)
}
pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer<TId, TTransform>> {
let (_, id) = self.id_by_ip.longest_match(ip)?;
self.peer_by_id.get(id)
}
pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer<TId, TTransform>> {
let (_, id) = self.id_by_ip.longest_match(ip)?;
self.peer_by_id.get_mut(id)
}
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<TId, TTransform>> {
self.peer_by_id.values_mut()
}
pub fn iter(&mut self) -> impl Iterator<Item = &Peer<TId, TTransform>> {
self.peer_by_id.values()
}
}
#[cfg(test)]
mod tests {
use crate::peer::{PacketTransformGateway, Peer};
use super::PeerStore;
#[test]
fn can_insert_and_retrieve_peer() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
assert!(peer_storage.get(&0).is_some());
}
#[test]
fn can_insert_and_retrieve_peer_by_ip() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
assert_eq!(
peer_storage
.peer_by_ip("100.0.0.1".parse().unwrap())
.unwrap()
.conn_id,
0
);
}
#[test]
fn can_remove_peer() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
peer_storage.remove(&0);
assert!(peer_storage.get(&0).is_none());
assert!(peer_storage
.peer_by_ip("100.0.0.1".parse().unwrap())
.is_none())
}
#[test]
fn inserting_peer_removes_previous_instances_of_same_id() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
assert!(peer_storage.get(&0).is_some());
assert!(peer_storage
.peer_by_ip("100.0.0.1".parse().unwrap())
.is_none())
}
}

View File

@@ -29,7 +29,6 @@ pub struct Eventloop {
Result<ResourceDescription<ResolvedResourceDescriptionDns>>,
Either<RequestConnection, AllowAccess>,
>,
print_stats_timer: tokio::time::Interval,
}
impl Eventloop {
@@ -41,7 +40,6 @@ impl Eventloop {
tunnel,
portal,
resolve_tasks: futures_bounded::FuturesTupleSet::new(Duration::from_secs(60), 100),
print_stats_timer: tokio::time::interval(Duration::from_secs(10)),
}
}
}
@@ -104,7 +102,7 @@ impl Eventloop {
Err(e) => {
let client = req.client.id;
self.tunnel.cleanup_connection(client);
self.tunnel.cleanup_connection(&client);
tracing::debug!(%client, "Connection request failed: {:#}", anyhow::Error::new(e));
continue;
@@ -216,11 +214,6 @@ impl Eventloop {
_ => {}
}
if self.print_stats_timer.poll_tick(cx).is_ready() {
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
continue;
}
return Poll::Pending;
}
}