feat(connlib): support resource updates from the portal (#3754)

This PR doesn't yet provide support for the update of upstream DNS but
it does provide support for all the other resources update messages.

Should comply with the description of issue #2022 but it doesn't respond
to DNS upstream updates which is imply it should on the issue title

---------

Signed-off-by: Gabi <gabrielalejandro7@gmail.com>
Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Gabi
2024-02-27 00:24:14 -03:00
committed by GitHub
parent 67aeb009e9
commit 77b00b3be9
16 changed files with 425 additions and 130 deletions

View File

@@ -166,8 +166,8 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
fn resource_deleted(&self, id: ResourceId) {
// TODO
fn resource_deleted(&mut self, id: ResourceId) {
self.tunnel.remove_resource(id);
}
fn connection_details(

View File

@@ -231,6 +231,18 @@ impl ResourceDescription {
ResourceDescription::Cidr(r) => Cow::from(r.address.to_string()),
}
}
pub fn has_different_address(&self, other: &ResourceDescription) -> bool {
match (self, other) {
(ResourceDescription::Dns(dns_a), ResourceDescription::Dns(dns_b)) => {
dns_a.address != dns_b.address
}
(ResourceDescription::Cidr(cidr_a), ResourceDescription::Cidr(cidr_b)) => {
cidr_a.address != cidr_b.address
}
_ => true,
}
}
}
/// Description of a resource that maps to a CIDR.

View File

@@ -59,14 +59,10 @@ where
&mut self,
resource_description: ResourceDescription,
) -> connlib_shared::Result<()> {
if self
.role_state
.resource_ids
.contains_key(&resource_description.id())
{
// TODO
tracing::info!("Resource updates aren't implemented yet");
return Ok(());
if let Some(resource) = self.role_state.resource_ids.get(&resource_description.id()) {
if resource.has_different_address(resource) {
self.remove_resource(resource.id());
}
}
match &resource_description {
@@ -99,6 +95,63 @@ where
Ok(())
}
pub fn remove_resource(&mut self, id: ResourceId) {
self.role_state.awaiting_connection.remove(&id);
self.role_state.awaiting_connection_timers.remove(id);
self.role_state
.dns_resources_internal_ips
.retain(|r, _| r.id != id);
self.role_state.dns_resources.retain(|_, r| r.id != id);
self.role_state.cidr_resources.retain(|_, r| r.id != id);
self.role_state
.deferred_dns_queries
.retain(|(r, _), _| r.id != id);
if let Some(ResourceDescription::Cidr(resource)) = self.role_state.resource_ids.remove(&id)
{
// Note: hopefully the os doesn't coalece routes in a way that removing a more general route deletes the most specific
if let Err(err) = self.remove_route(resource.address) {
tracing::error!(%id, %resource.address, "failed to remove route: {err:?}");
}
}
let Some(gateway_id) = self.role_state.resources_gateways.remove(&id) else {
return;
};
let Some(peer) = self.role_state.peers.get_mut(&gateway_id) else {
return;
};
// First we remove the id from all allowed ips
for (network, resources) in peer
.allowed_ips
.iter_mut()
.filter(|(_, resources)| resources.contains(&id))
{
resources.remove(&id);
if !resources.is_empty() {
continue;
}
// If the allowed_ips doesn't correspond to any resource anymore we
// clean up any related translation.
peer.transform
.translations
.remove_by_left(&network.network_address());
}
// We remove all empty allowed ips entry since there's no resource that corresponds to it
peer.allowed_ips.retain(|_, r| !r.is_empty());
// If there's no allowed ip left we remove the whole peer because there's no point on keeping it around
if peer.allowed_ips.is_empty() {
self.role_state.peers.remove(&gateway_id);
// TODO: should we have a Node::remove_connection?
}
}
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub fn set_interface(
@@ -163,6 +216,22 @@ where
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> {
let callbacks = self.callbacks().clone();
let maybe_new_device = self
.device
.as_mut()
.ok_or(Error::ControlProtocolError)?
.remove_route(route, &callbacks)?;
if let Some(new_device) = maybe_new_device {
self.device = Some(new_device);
}
Ok(())
}
}
/// [`Tunnel`] state specific to clients.
@@ -189,7 +258,7 @@ pub struct ClientState {
pub resource_ids: HashMap<ResourceId, ResourceDescription>,
pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>,
pub peers: PeerStore<GatewayId, PacketTransformClient>,
pub peers: PeerStore<GatewayId, PacketTransformClient, HashSet<ResourceId>>,
forwarded_dns_queries: FuturesTupleSet<
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
@@ -332,11 +401,7 @@ impl ClientState {
self.resources_gateways.insert(resource, gateway);
if self
.peers
.add_ips(&gateway, &self.get_resource_ip(desc, &domain))
.is_none()
{
if self.peers.get(&gateway).is_none() {
match self
.gateway_awaiting_connection_timers
// Note: we don't need to set a timer here because
@@ -357,6 +422,9 @@ impl ClientState {
return Ok(None);
};
self.peers
.add_ips_with_resource(&gateway, &self.get_resource_ip(desc, &domain), &resource);
self.awaiting_connection.remove(&resource);
self.awaiting_connection_timers.remove(resource);

View File

@@ -106,10 +106,11 @@ where
&domain_response.as_ref().map(|d| d.domain.clone()),
)?;
let mut peer: Peer<_, PacketTransformClient> =
Peer::new(ips.clone(), gateway_id, Default::default());
let resource_ids = HashSet::from([resource_id]);
let mut peer: Peer<_, PacketTransformClient, _> =
Peer::new(gateway_id, Default::default(), &ips, resource_ids);
peer.transform.set_dns(self.role_state.dns_mapping());
self.role_state.peers.insert(peer);
self.role_state.peers.insert(peer, &[]);
let peer_ips = if let Some(domain_response) = domain_response {
self.dns_response(&resource_id, &domain_response, &gateway_id)?
@@ -117,7 +118,9 @@ where
ips
};
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
self.role_state
.peers
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
Ok(())
}
@@ -232,7 +235,9 @@ where
let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?;
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
self.role_state
.peers
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
Ok(())
}

View File

@@ -167,15 +167,14 @@ where
) -> Result<()> {
tracing::trace!(?ips, "new_data_channel_open");
let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default());
let mut peer = Peer::new(client_id, PacketTransformGateway::default(), &ips, ());
for address in resource_addresses {
peer.transform
.add_resource(address, resource.clone(), expires_at);
}
self.role_state.peers.insert(peer);
self.role_state.peers.add_ips(&client_id, &ips);
self.role_state.peers.insert(peer, &ips);
Ok(())
}

View File

@@ -148,6 +148,34 @@ impl Device {
}))
}
#[cfg(target_family = "unix")]
pub(crate) fn remove_route(
&mut self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>, Error> {
let Some(tun) = self.tun.remove_route(route, callbacks)? else {
return Ok(None);
};
let mtu = ioctl::interface_mtu_by_name(tun.name())?;
Ok(Some(Device {
mtu,
tun,
mtu_refreshed_at: Instant::now(),
}))
}
#[cfg(target_family = "windows")]
pub(crate) fn remove_route(
&mut self,
route: IpNetwork,
_callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Device>, Error> {
self.tun.remove_route(route)?;
Ok(None)
}
#[cfg(target_family = "windows")]
#[allow(unused_mut)]
pub(crate) fn add_route(

View File

@@ -75,6 +75,21 @@ impl Tun {
name,
}))
}
pub fn remove_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Self>> {
self.fd.close();
let fd = callbacks.on_remove_route(route)?.ok_or(Error::NoFd)?;
let name = unsafe { interface_name(fd)? };
Ok(Some(Tun {
fd: Closeable::new(AsyncFd::new(fd)?),
name,
}))
}
}
/// Retrieves the name of the interface pointed to by the provided file descriptor.

View File

@@ -153,6 +153,16 @@ impl Tun {
Ok(None)
}
pub fn remove_route(
&self,
route: IpNetwork,
callbacks: &impl Callbacks<Error = Error>,
) -> Result<Option<Self>> {
// This will always be None in macos
callbacks.on_remove_route(route)?;
Ok(None)
}
pub fn name(&self) -> &str {
self.name.as_str()
}

View File

@@ -6,16 +6,16 @@ use connlib_shared::{
use futures::TryStreamExt;
use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use ip_network::IpNetwork;
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use libc::{
close, fcntl, makedev, mknod, open, F_GETFL, F_SETFL, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN,
O_NONBLOCK, O_RDWR, S_IFCHR,
};
use netlink_packet_route::route::{RouteProtocol, RouteScope};
use netlink_packet_route::rule::RuleAction;
use rtnetlink::RuleAddRequest;
use rtnetlink::{new_connection, Error::NetlinkError, Handle};
use std::net::IpAddr;
use rtnetlink::{RouteAddRequest, RuleAddRequest};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::Path;
use std::task::{Context, Poll};
use std::{
@@ -154,26 +154,9 @@ impl Tun {
.header
.index;
let req = handle
.route()
.add()
.output_interface(index)
.protocol(RouteProtocol::Static)
.scope(RouteScope::Universe)
.table_id(FIREZONE_TABLE);
let res = match route {
IpNetwork::V4(ipnet) => {
req.v4()
.destination_prefix(ipnet.network_address(), ipnet.netmask())
.execute()
.await
}
IpNetwork::V6(ipnet) => {
req.v6()
.destination_prefix(ipnet.network_address(), ipnet.netmask())
.execute()
.await
}
IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).execute().await,
IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).execute().await,
};
match res {
@@ -206,6 +189,53 @@ impl Tun {
Ok(None)
}
pub fn remove_route(&mut self, route: IpNetwork, _: &impl Callbacks) -> Result<Option<Self>> {
let handle = self.handle.clone();
let add_route_worker = async move {
let index = handle
.link()
.get()
.match_name(IFACE_NAME.to_string())
.execute()
.try_next()
.await?
.ok_or(Error::NoIface)?
.header
.index;
let message = match route {
IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).message_mut().clone(),
IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).message_mut().clone(),
};
match handle.route().del(message).execute().await {
Ok(_) => Ok(()),
Err(err) => {
tracing::error!(%route, "failed to add route: {err:#?}");
Ok(())
}
}
};
match self.worker.take() {
None => self.worker = Some(add_route_worker.boxed()),
Some(current_worker) => {
self.worker = Some(
async move {
current_worker.await?;
add_route_worker.await?;
Ok(())
}
.boxed(),
)
}
}
Ok(None)
}
pub fn name(&self) -> &str {
IFACE_NAME
}
@@ -327,6 +357,28 @@ fn make_rule(handle: &Handle) -> RuleAddRequest {
rule
}
fn make_route(idx: u32, handle: &Handle) -> RouteAddRequest {
handle
.route()
.add()
.output_interface(idx)
.protocol(RouteProtocol::Static)
.scope(RouteScope::Universe)
.table_id(FIREZONE_TABLE)
}
fn make_route_v4(idx: u32, handle: &Handle, route: Ipv4Network) -> RouteAddRequest<Ipv4Addr> {
make_route(idx, handle)
.v4()
.destination_prefix(route.network_address(), route.netmask())
}
fn make_route_v6(idx: u32, handle: &Handle, route: Ipv6Network) -> RouteAddRequest<Ipv6Addr> {
make_route(idx, handle)
.v6()
.destination_prefix(route.network_address(), route.netmask())
}
fn get_last_error() -> Error {
Error::Io(io::Error::last_os_error())
}

View File

@@ -13,8 +13,8 @@ use tokio::sync::mpsc;
use windows::Win32::{
NetworkManagement::{
IpHelper::{
CreateIpForwardEntry2, GetIpInterfaceEntry, InitializeIpForwardEntry,
SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW,
CreateIpForwardEntry2, DeleteIpForwardEntry2, GetIpInterfaceEntry,
InitializeIpForwardEntry, SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW,
},
Ndis::NET_LUID_LH,
},
@@ -154,37 +154,27 @@ impl Tun {
// It's okay if this blocks until the route is added in the OS.
pub fn add_route(&self, route: IpNetwork) -> Result<()> {
tracing::debug!("add_route {route}");
let mut row = MIB_IPFORWARD_ROW2::default();
// SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults
unsafe { InitializeIpForwardEntry(&mut row) };
let prefix = &mut row.DestinationPrefix;
match route {
IpNetwork::V4(x) => {
prefix.PrefixLength = x.netmask();
prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into();
}
IpNetwork::V6(x) => {
prefix.PrefixLength = x.netmask();
prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into();
}
}
row.InterfaceIndex = self.iface_idx;
row.Metric = 0;
const DUPLICATE_ERR: u32 = 0x80071392;
let entry = self.forward_entry(route);
// SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable.
match unsafe { CreateIpForwardEntry2(&row) } {
Ok(_) => {}
Err(e) => {
if e.code().0 as u32 == 0x80071392 {
// "Object already exists" error
tracing::warn!("Failed to add duplicate route, ignoring");
} else {
Err(e)?;
}
match unsafe { CreateIpForwardEntry2(&entry) } {
Ok(()) => Ok(()),
Err(e) if e.code().0 as u32 == DUPLICATE_ERR => {
tracing::debug!(%route, "Failed to add duplicate route, ignoring");
Ok(())
}
Err(e) => Err(e.into()),
}
}
// It's okay if this blocks until the route is added in the OS.
pub fn remove_route(&self, route: IpNetwork) -> Result<()> {
let entry = self.forward_entry(route);
// SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable.
unsafe {
DeleteIpForwardEntry2(&entry)?;
}
Ok(())
}
@@ -239,6 +229,29 @@ impl Tun {
self.session.send_packet(pkt);
Ok(bytes.len())
}
fn forward_entry(&self, route: IpNetwork) -> MIB_IPFORWARD_ROW2 {
let mut row = MIB_IPFORWARD_ROW2::default();
// SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults
unsafe { InitializeIpForwardEntry(&mut row) };
let prefix = &mut row.DestinationPrefix;
match route {
IpNetwork::V4(x) => {
prefix.PrefixLength = x.netmask();
prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into();
}
IpNetwork::V6(x) => {
prefix.PrefixLength = x.netmask();
prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into();
}
}
row.InterfaceIndex = self.iface_idx;
row.Metric = 0;
row
}
}
fn start_recv_thread(

View File

@@ -6,7 +6,7 @@ use crate::ip_packet::MutableIpPacket;
use crate::peer::PacketTransformGateway;
use crate::peer_store::PeerStore;
use crate::Tunnel;
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig, ResourceId};
use connlib_shared::Callbacks;
use snownet::Server;
use tokio::time::{interval, Interval, MissedTickBehavior};
@@ -40,11 +40,22 @@ where
pub fn cleanup_connection(&mut self, id: &ClientId) {
self.role_state.peers.remove(id);
}
pub fn remove_access(&mut self, id: &ClientId, resource_id: &ResourceId) {
let Some(peer) = self.role_state.peers.get_mut(id) else {
return;
};
peer.transform.remove_resource(resource_id);
if peer.transform.is_emptied() {
self.role_state.peers.remove(id);
}
}
}
/// [`Tunnel`] state specific to gateways.
pub struct GatewayState {
pub peers: PeerStore<ClientId, PacketTransformGateway>,
pub peers: PeerStore<ClientId, PacketTransformGateway, ()>,
expire_interval: Interval,
}

View File

@@ -277,14 +277,15 @@ where
}
// TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation
fn poll_sockets<TTransform>(
fn poll_sockets<TTransform, TResource>(
&mut self,
device: &mut Device,
peer_store: &mut PeerStore<TId, TTransform>,
peer_store: &mut PeerStore<TId, TTransform, TResource>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>>
where
TTransform: PacketTransform,
TResource: Clone,
{
let received = match ready!(self.sockets.poll_recv_from(cx)) {
Ok(received) => received,

View File

@@ -1,10 +1,10 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::time::Instant;
use bimap::BiMap;
use chrono::{DateTime, Utc};
use connlib_shared::messages::DnsServer;
use connlib_shared::messages::{DnsServer, ResourceId};
use connlib_shared::IpProvider;
use connlib_shared::{Error, Result};
use ip_network::IpNetwork;
@@ -20,25 +20,44 @@ type ExpiryingResource = (ResourceDescription, Option<DateTime<Utc>>);
// is 30 seconds. See resolvconf(5) timeout.
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
pub struct Peer<TId, TTransform> {
allowed_ips: IpNetworkTable<()>,
pub struct Peer<TId, TTransform, TResource> {
// TODO: we should refactor this
// in the gateway-side this means that we are explicit about ()
// maybe duping the Peer struct is the way to go
pub allowed_ips: IpNetworkTable<TResource>,
pub conn_id: TId,
pub transform: TTransform,
}
impl<TId, TTransform> Peer<TId, TTransform>
impl<TId, TTransform> Peer<TId, TTransform, HashSet<ResourceId>>
where
TId: Copy,
TTransform: PacketTransform,
{
pub(crate) fn insert_id(&mut self, ip: &IpNetwork, id: &ResourceId) {
if let Some(resources) = self.allowed_ips.exact_match_mut(*ip) {
resources.insert(*id);
} else {
self.allowed_ips.insert(*ip, HashSet::from([*id]));
}
}
}
impl<TId, TTransform, TResource> Peer<TId, TTransform, TResource>
where
TId: Copy,
TTransform: PacketTransform,
TResource: Clone,
{
pub(crate) fn new(
ips: Vec<IpNetwork>,
conn_id: TId,
transform: TTransform,
) -> Peer<TId, TTransform> {
ips: &[IpNetwork],
resource: TResource,
) -> Peer<TId, TTransform, TResource> {
let mut allowed_ips = IpNetworkTable::new();
for ip in ips {
allowed_ips.insert(ip, ());
allowed_ips.insert(*ip, resource.clone());
}
Peer {
@@ -48,10 +67,6 @@ where
}
}
pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) {
self.allowed_ips.insert(ip, ());
}
fn is_allowed(&self, addr: IpAddr) -> bool {
self.allowed_ips.longest_match(addr).is_some()
}
@@ -92,7 +107,7 @@ impl Default for PacketTransformGateway {
#[derive(Default)]
pub struct PacketTransformClient {
translations: BiMap<IpAddr, IpAddr>,
pub translations: BiMap<IpAddr, IpAddr>,
dns_mapping: BiMap<IpAddr, DnsServer>,
mangled_dns_ids: HashMap<u16, std::time::Instant>,
}
@@ -133,6 +148,13 @@ impl PacketTransformGateway {
.retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now()));
}
pub(crate) fn remove_resource(&mut self, resource: &ResourceId) {
self.resources.retain(|_, (r, _)| match r {
connlib_shared::messages::ResourceDescription::Dns(r) => r.id != *resource,
connlib_shared::messages::ResourceDescription::Cidr(r) => r.id != *resource,
})
}
pub(crate) fn add_resource(
&mut self,
ip: IpNetwork,

View File

@@ -1,17 +1,18 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::net::IpAddr;
use crate::peer::{PacketTransform, Peer};
use connlib_shared::messages::ResourceId;
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
pub struct PeerStore<TId, TTransform> {
pub struct PeerStore<TId, TTransform, TResource> {
id_by_ip: IpNetworkTable<TId>,
peer_by_id: HashMap<TId, Peer<TId, TTransform>>,
peer_by_id: HashMap<TId, Peer<TId, TTransform, TResource>>,
}
impl<T, U> Default for PeerStore<T, U> {
impl<TId, TTransform, TResource> Default for PeerStore<TId, TTransform, TResource> {
fn default() -> Self {
Self {
id_by_ip: IpNetworkTable::new(),
@@ -20,67 +21,92 @@ impl<T, U> Default for PeerStore<T, U> {
}
}
impl<TId, TTransform> PeerStore<TId, TTransform>
impl<TId, TTransform> PeerStore<TId, TTransform, HashSet<ResourceId>>
where
TId: Hash + Eq + Clone + Copy,
TId: Hash + Eq + Copy,
TTransform: PacketTransform,
{
pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer<TId, TTransform>) -> bool) {
pub fn add_ips_with_resource(&mut self, id: &TId, ips: &[IpNetwork], resource: &ResourceId) {
for ip in ips {
let Some(peer) = self.add_ip(id, ip) else {
continue;
};
peer.insert_id(ip, resource);
}
}
}
impl<TId, TTransform, TResource> PeerStore<TId, TTransform, TResource>
where
TId: Hash + Eq + Copy,
TTransform: PacketTransform,
{
pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer<TId, TTransform, TResource>) -> bool) {
self.peer_by_id.retain(f);
self.id_by_ip
.retain(|_, id| self.peer_by_id.contains_key(id));
}
pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer<TId, TTransform>> {
pub fn add_ip(
&mut self,
id: &TId,
ip: &IpNetwork,
) -> Option<&mut Peer<TId, TTransform, TResource>> {
let peer = self.peer_by_id.get_mut(id)?;
for ip in ips {
self.id_by_ip.insert(*ip, peer.conn_id);
peer.add_allowed_ip(*ip);
}
self.id_by_ip.insert(*ip, *id);
Some(peer)
}
pub fn insert(&mut self, peer: Peer<TId, TTransform>) -> Option<Peer<TId, TTransform>> {
pub fn insert(
&mut self,
peer: Peer<TId, TTransform, TResource>,
ips: &[IpNetwork],
) -> Option<Peer<TId, TTransform, TResource>> {
self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id);
self.peer_by_id.insert(peer.conn_id, peer)
let id = peer.conn_id;
let old_peer = self.peer_by_id.insert(id, peer);
for ip in ips {
self.add_ip(&id, ip);
}
old_peer
}
pub fn remove(&mut self, id: &TId) -> Option<Peer<TId, TTransform>> {
pub fn remove(&mut self, id: &TId) -> Option<Peer<TId, TTransform, TResource>> {
self.id_by_ip.retain(|_, r_id| r_id != id);
self.peer_by_id.remove(id)
}
pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer<TId, TTransform>> {
pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer<TId, TTransform, TResource>> {
let ip = self.id_by_ip.exact_match(ip)?;
self.peer_by_id.get(ip)
}
pub fn get(&self, id: &TId) -> Option<&Peer<TId, TTransform>> {
pub fn get(&self, id: &TId) -> Option<&Peer<TId, TTransform, TResource>> {
self.peer_by_id.get(id)
}
pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer<TId, TTransform>> {
pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer<TId, TTransform, TResource>> {
self.peer_by_id.get_mut(id)
}
pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer<TId, TTransform>> {
pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer<TId, TTransform, TResource>> {
let (_, id) = self.id_by_ip.longest_match(ip)?;
self.peer_by_id.get(id)
}
pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer<TId, TTransform>> {
pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer<TId, TTransform, TResource>> {
let (_, id) = self.id_by_ip.longest_match(ip)?;
self.peer_by_id.get_mut(id)
}
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<TId, TTransform>> {
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<TId, TTransform, TResource>> {
self.peer_by_id.values_mut()
}
pub fn iter(&mut self) -> impl Iterator<Item = &Peer<TId, TTransform>> {
pub fn iter(&mut self) -> impl Iterator<Item = &Peer<TId, TTransform, TResource>> {
self.peer_by_id.values()
}
}
@@ -93,16 +119,21 @@ mod tests {
#[test]
fn can_insert_and_retrieve_peer() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
let mut peer_storage = PeerStore::<_, _, ()>::default();
peer_storage.insert(
Peer::new(0, PacketTransformGateway::default(), &[], ()),
&[],
);
assert!(peer_storage.get(&0).is_some());
}
#[test]
fn can_insert_and_retrieve_peer_by_ip() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
let mut peer_storage = PeerStore::<_, _, ()>::default();
peer_storage.insert(
Peer::new(0, PacketTransformGateway::default(), &[], ()),
&["100.0.0.0/24".parse().unwrap()],
);
assert_eq!(
peer_storage
@@ -115,9 +146,11 @@ mod tests {
#[test]
fn can_remove_peer() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
let mut peer_storage = PeerStore::<_, _, ()>::default();
peer_storage.insert(
Peer::new(0, PacketTransformGateway::default(), &[], ()),
&["100.0.0.0/24".parse().unwrap()],
);
peer_storage.remove(&0);
assert!(peer_storage.get(&0).is_none());
@@ -128,10 +161,15 @@ mod tests {
#[test]
fn inserting_peer_removes_previous_instances_of_same_id() {
let mut peer_storage = PeerStore::default();
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
let mut peer_storage = PeerStore::<_, _, ()>::default();
peer_storage.insert(
Peer::new(0, PacketTransformGateway::default(), &[], ()),
&["100.0.0.0/24".parse().unwrap()],
);
peer_storage.insert(
Peer::new(0, PacketTransformGateway::default(), &[], ()),
&[],
);
assert!(peer_storage.get(&0).is_some());
assert!(peer_storage

View File

@@ -1,6 +1,6 @@
use crate::messages::{
AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady,
EgressMessages, IngressMessages, RequestConnection,
EgressMessages, IngressMessages, RejectAccess, RequestConnection,
};
use crate::CallbackHandler;
use anyhow::{anyhow, bail, Result};
@@ -201,6 +201,20 @@ impl Eventloop {
}
continue;
}
Poll::Ready(phoenix_channel::Event::InboundMessage {
msg:
IngressMessages::RejectAccess(RejectAccess {
client_id,
resource_id,
}),
..
}) => {
tracing::debug!(client = %client_id, resource = %resource_id, "Access removed");
self.tunnel.remove_access(&client_id, &resource_id);
continue;
}
Poll::Ready(phoenix_channel::Event::InboundMessage {
msg: IngressMessages::Init(_),
..

View File

@@ -86,6 +86,12 @@ pub struct AllowAccess {
pub reference: String,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct RejectAccess {
pub client_id: ClientId,
pub resource_id: ResourceId,
}
// These messages are the messages that can be received
// either by a client or a gateway by the client.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
@@ -93,6 +99,7 @@ pub struct AllowAccess {
pub enum IngressMessages {
RequestConnection(RequestConnection),
AllowAccess(AllowAccess),
RejectAccess(RejectAccess),
IceCandidates(ClientIceCandidates),
Init(InitGateway),
}