diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index c09cacf4f..128df5cc7 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -166,8 +166,8 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - fn resource_deleted(&self, id: ResourceId) { - // TODO + fn resource_deleted(&mut self, id: ResourceId) { + self.tunnel.remove_resource(id); } fn connection_details( diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index 04087ab3a..d2dd1a683 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -231,6 +231,18 @@ impl ResourceDescription { ResourceDescription::Cidr(r) => Cow::from(r.address.to_string()), } } + + pub fn has_different_address(&self, other: &ResourceDescription) -> bool { + match (self, other) { + (ResourceDescription::Dns(dns_a), ResourceDescription::Dns(dns_b)) => { + dns_a.address != dns_b.address + } + (ResourceDescription::Cidr(cidr_a), ResourceDescription::Cidr(cidr_b)) => { + cidr_a.address != cidr_b.address + } + _ => true, + } + } } /// Description of a resource that maps to a CIDR. diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 7262aa6de..a79f649e3 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -59,14 +59,10 @@ where &mut self, resource_description: ResourceDescription, ) -> connlib_shared::Result<()> { - if self - .role_state - .resource_ids - .contains_key(&resource_description.id()) - { - // TODO - tracing::info!("Resource updates aren't implemented yet"); - return Ok(()); + if let Some(resource) = self.role_state.resource_ids.get(&resource_description.id()) { + if resource.has_different_address(resource) { + self.remove_resource(resource.id()); + } } match &resource_description { @@ -99,6 +95,63 @@ where Ok(()) } + pub fn remove_resource(&mut self, id: ResourceId) { + self.role_state.awaiting_connection.remove(&id); + self.role_state.awaiting_connection_timers.remove(id); + self.role_state + .dns_resources_internal_ips + .retain(|r, _| r.id != id); + self.role_state.dns_resources.retain(|_, r| r.id != id); + self.role_state.cidr_resources.retain(|_, r| r.id != id); + self.role_state + .deferred_dns_queries + .retain(|(r, _), _| r.id != id); + + if let Some(ResourceDescription::Cidr(resource)) = self.role_state.resource_ids.remove(&id) + { + // Note: hopefully the os doesn't coalece routes in a way that removing a more general route deletes the most specific + if let Err(err) = self.remove_route(resource.address) { + tracing::error!(%id, %resource.address, "failed to remove route: {err:?}"); + } + } + + let Some(gateway_id) = self.role_state.resources_gateways.remove(&id) else { + return; + }; + + let Some(peer) = self.role_state.peers.get_mut(&gateway_id) else { + return; + }; + + // First we remove the id from all allowed ips + for (network, resources) in peer + .allowed_ips + .iter_mut() + .filter(|(_, resources)| resources.contains(&id)) + { + resources.remove(&id); + + if !resources.is_empty() { + continue; + } + + // If the allowed_ips doesn't correspond to any resource anymore we + // clean up any related translation. + peer.transform + .translations + .remove_by_left(&network.network_address()); + } + + // We remove all empty allowed ips entry since there's no resource that corresponds to it + peer.allowed_ips.retain(|_, r| !r.is_empty()); + + // If there's no allowed ip left we remove the whole peer because there's no point on keeping it around + if peer.allowed_ips.is_empty() { + self.role_state.peers.remove(&gateway_id); + // TODO: should we have a Node::remove_connection? + } + } + /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] pub fn set_interface( @@ -163,6 +216,22 @@ where Ok(()) } + + #[tracing::instrument(level = "trace", skip(self))] + pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> { + let callbacks = self.callbacks().clone(); + let maybe_new_device = self + .device + .as_mut() + .ok_or(Error::ControlProtocolError)? + .remove_route(route, &callbacks)?; + + if let Some(new_device) = maybe_new_device { + self.device = Some(new_device); + } + + Ok(()) + } } /// [`Tunnel`] state specific to clients. @@ -189,7 +258,7 @@ pub struct ClientState { pub resource_ids: HashMap, pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>, - pub peers: PeerStore, + pub peers: PeerStore>, forwarded_dns_queries: FuturesTupleSet< Result, @@ -332,11 +401,7 @@ impl ClientState { self.resources_gateways.insert(resource, gateway); - if self - .peers - .add_ips(&gateway, &self.get_resource_ip(desc, &domain)) - .is_none() - { + if self.peers.get(&gateway).is_none() { match self .gateway_awaiting_connection_timers // Note: we don't need to set a timer here because @@ -357,6 +422,9 @@ impl ClientState { return Ok(None); }; + self.peers + .add_ips_with_resource(&gateway, &self.get_resource_ip(desc, &domain), &resource); + self.awaiting_connection.remove(&resource); self.awaiting_connection_timers.remove(resource); diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 61da858f1..785f3066c 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -106,10 +106,11 @@ where &domain_response.as_ref().map(|d| d.domain.clone()), )?; - let mut peer: Peer<_, PacketTransformClient> = - Peer::new(ips.clone(), gateway_id, Default::default()); + let resource_ids = HashSet::from([resource_id]); + let mut peer: Peer<_, PacketTransformClient, _> = + Peer::new(gateway_id, Default::default(), &ips, resource_ids); peer.transform.set_dns(self.role_state.dns_mapping()); - self.role_state.peers.insert(peer); + self.role_state.peers.insert(peer, &[]); let peer_ips = if let Some(domain_response) = domain_response { self.dns_response(&resource_id, &domain_response, &gateway_id)? @@ -117,7 +118,9 @@ where ips }; - self.role_state.peers.add_ips(&gateway_id, &peer_ips); + self.role_state + .peers + .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); Ok(()) } @@ -232,7 +235,9 @@ where let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?; - self.role_state.peers.add_ips(&gateway_id, &peer_ips); + self.role_state + .peers + .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); Ok(()) } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index 6760c236f..853fc4b8a 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -167,15 +167,14 @@ where ) -> Result<()> { tracing::trace!(?ips, "new_data_channel_open"); - let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default()); + let mut peer = Peer::new(client_id, PacketTransformGateway::default(), &ips, ()); for address in resource_addresses { peer.transform .add_resource(address, resource.clone(), expires_at); } - self.role_state.peers.insert(peer); - self.role_state.peers.add_ips(&client_id, &ips); + self.role_state.peers.insert(peer, &ips); Ok(()) } diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index 41cac66d5..d062e0929 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -148,6 +148,34 @@ impl Device { })) } + #[cfg(target_family = "unix")] + pub(crate) fn remove_route( + &mut self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result, Error> { + let Some(tun) = self.tun.remove_route(route, callbacks)? else { + return Ok(None); + }; + let mtu = ioctl::interface_mtu_by_name(tun.name())?; + + Ok(Some(Device { + mtu, + tun, + mtu_refreshed_at: Instant::now(), + })) + } + + #[cfg(target_family = "windows")] + pub(crate) fn remove_route( + &mut self, + route: IpNetwork, + _callbacks: &impl Callbacks, + ) -> Result, Error> { + self.tun.remove_route(route)?; + Ok(None) + } + #[cfg(target_family = "windows")] #[allow(unused_mut)] pub(crate) fn add_route( diff --git a/rust/connlib/tunnel/src/device_channel/tun_android.rs b/rust/connlib/tunnel/src/device_channel/tun_android.rs index 6376cb5bf..5a1a91c22 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_android.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_android.rs @@ -75,6 +75,21 @@ impl Tun { name, })) } + + pub fn remove_route( + &self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result> { + self.fd.close(); + let fd = callbacks.on_remove_route(route)?.ok_or(Error::NoFd)?; + let name = unsafe { interface_name(fd)? }; + + Ok(Some(Tun { + fd: Closeable::new(AsyncFd::new(fd)?), + name, + })) + } } /// Retrieves the name of the interface pointed to by the provided file descriptor. diff --git a/rust/connlib/tunnel/src/device_channel/tun_darwin.rs b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs index 262e0156d..4840ed635 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_darwin.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs @@ -153,6 +153,16 @@ impl Tun { Ok(None) } + pub fn remove_route( + &self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result> { + // This will always be None in macos + callbacks.on_remove_route(route)?; + Ok(None) + } + pub fn name(&self) -> &str { self.name.as_str() } diff --git a/rust/connlib/tunnel/src/device_channel/tun_linux.rs b/rust/connlib/tunnel/src/device_channel/tun_linux.rs index 2cfd82208..aa9300920 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_linux.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_linux.rs @@ -6,16 +6,16 @@ use connlib_shared::{ use futures::TryStreamExt; use futures_util::future::BoxFuture; use futures_util::FutureExt; -use ip_network::IpNetwork; +use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use libc::{ close, fcntl, makedev, mknod, open, F_GETFL, F_SETFL, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, O_NONBLOCK, O_RDWR, S_IFCHR, }; use netlink_packet_route::route::{RouteProtocol, RouteScope}; use netlink_packet_route::rule::RuleAction; -use rtnetlink::RuleAddRequest; use rtnetlink::{new_connection, Error::NetlinkError, Handle}; -use std::net::IpAddr; +use rtnetlink::{RouteAddRequest, RuleAddRequest}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::path::Path; use std::task::{Context, Poll}; use std::{ @@ -154,26 +154,9 @@ impl Tun { .header .index; - let req = handle - .route() - .add() - .output_interface(index) - .protocol(RouteProtocol::Static) - .scope(RouteScope::Universe) - .table_id(FIREZONE_TABLE); let res = match route { - IpNetwork::V4(ipnet) => { - req.v4() - .destination_prefix(ipnet.network_address(), ipnet.netmask()) - .execute() - .await - } - IpNetwork::V6(ipnet) => { - req.v6() - .destination_prefix(ipnet.network_address(), ipnet.netmask()) - .execute() - .await - } + IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).execute().await, + IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).execute().await, }; match res { @@ -206,6 +189,53 @@ impl Tun { Ok(None) } + pub fn remove_route(&mut self, route: IpNetwork, _: &impl Callbacks) -> Result> { + let handle = self.handle.clone(); + + let add_route_worker = async move { + let index = handle + .link() + .get() + .match_name(IFACE_NAME.to_string()) + .execute() + .try_next() + .await? + .ok_or(Error::NoIface)? + .header + .index; + + let message = match route { + IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).message_mut().clone(), + IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).message_mut().clone(), + }; + + match handle.route().del(message).execute().await { + Ok(_) => Ok(()), + Err(err) => { + tracing::error!(%route, "failed to add route: {err:#?}"); + Ok(()) + } + } + }; + + match self.worker.take() { + None => self.worker = Some(add_route_worker.boxed()), + Some(current_worker) => { + self.worker = Some( + async move { + current_worker.await?; + add_route_worker.await?; + + Ok(()) + } + .boxed(), + ) + } + } + + Ok(None) + } + pub fn name(&self) -> &str { IFACE_NAME } @@ -327,6 +357,28 @@ fn make_rule(handle: &Handle) -> RuleAddRequest { rule } +fn make_route(idx: u32, handle: &Handle) -> RouteAddRequest { + handle + .route() + .add() + .output_interface(idx) + .protocol(RouteProtocol::Static) + .scope(RouteScope::Universe) + .table_id(FIREZONE_TABLE) +} + +fn make_route_v4(idx: u32, handle: &Handle, route: Ipv4Network) -> RouteAddRequest { + make_route(idx, handle) + .v4() + .destination_prefix(route.network_address(), route.netmask()) +} + +fn make_route_v6(idx: u32, handle: &Handle, route: Ipv6Network) -> RouteAddRequest { + make_route(idx, handle) + .v6() + .destination_prefix(route.network_address(), route.netmask()) +} + fn get_last_error() -> Error { Error::Io(io::Error::last_os_error()) } diff --git a/rust/connlib/tunnel/src/device_channel/tun_windows.rs b/rust/connlib/tunnel/src/device_channel/tun_windows.rs index a7a2965a1..8c750fb6a 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_windows.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_windows.rs @@ -13,8 +13,8 @@ use tokio::sync::mpsc; use windows::Win32::{ NetworkManagement::{ IpHelper::{ - CreateIpForwardEntry2, GetIpInterfaceEntry, InitializeIpForwardEntry, - SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW, + CreateIpForwardEntry2, DeleteIpForwardEntry2, GetIpInterfaceEntry, + InitializeIpForwardEntry, SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW, }, Ndis::NET_LUID_LH, }, @@ -154,37 +154,27 @@ impl Tun { // It's okay if this blocks until the route is added in the OS. pub fn add_route(&self, route: IpNetwork) -> Result<()> { - tracing::debug!("add_route {route}"); - let mut row = MIB_IPFORWARD_ROW2::default(); - // SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults - unsafe { InitializeIpForwardEntry(&mut row) }; - - let prefix = &mut row.DestinationPrefix; - match route { - IpNetwork::V4(x) => { - prefix.PrefixLength = x.netmask(); - prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into(); - } - IpNetwork::V6(x) => { - prefix.PrefixLength = x.netmask(); - prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into(); - } - } - - row.InterfaceIndex = self.iface_idx; - row.Metric = 0; + const DUPLICATE_ERR: u32 = 0x80071392; + let entry = self.forward_entry(route); // SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable. - match unsafe { CreateIpForwardEntry2(&row) } { - Ok(_) => {} - Err(e) => { - if e.code().0 as u32 == 0x80071392 { - // "Object already exists" error - tracing::warn!("Failed to add duplicate route, ignoring"); - } else { - Err(e)?; - } + match unsafe { CreateIpForwardEntry2(&entry) } { + Ok(()) => Ok(()), + Err(e) if e.code().0 as u32 == DUPLICATE_ERR => { + tracing::debug!(%route, "Failed to add duplicate route, ignoring"); + Ok(()) } + Err(e) => Err(e.into()), + } + } + + // It's okay if this blocks until the route is added in the OS. + pub fn remove_route(&self, route: IpNetwork) -> Result<()> { + let entry = self.forward_entry(route); + + // SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable. + unsafe { + DeleteIpForwardEntry2(&entry)?; } Ok(()) } @@ -239,6 +229,29 @@ impl Tun { self.session.send_packet(pkt); Ok(bytes.len()) } + + fn forward_entry(&self, route: IpNetwork) -> MIB_IPFORWARD_ROW2 { + let mut row = MIB_IPFORWARD_ROW2::default(); + // SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults + unsafe { InitializeIpForwardEntry(&mut row) }; + + let prefix = &mut row.DestinationPrefix; + match route { + IpNetwork::V4(x) => { + prefix.PrefixLength = x.netmask(); + prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into(); + } + IpNetwork::V6(x) => { + prefix.PrefixLength = x.netmask(); + prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into(); + } + } + + row.InterfaceIndex = self.iface_idx; + row.Metric = 0; + + row + } } fn start_recv_thread( diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index ed1e6d0e0..736467822 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -6,7 +6,7 @@ use crate::ip_packet::MutableIpPacket; use crate::peer::PacketTransformGateway; use crate::peer_store::PeerStore; use crate::Tunnel; -use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; +use connlib_shared::messages::{ClientId, Interface as InterfaceConfig, ResourceId}; use connlib_shared::Callbacks; use snownet::Server; use tokio::time::{interval, Interval, MissedTickBehavior}; @@ -40,11 +40,22 @@ where pub fn cleanup_connection(&mut self, id: &ClientId) { self.role_state.peers.remove(id); } + + pub fn remove_access(&mut self, id: &ClientId, resource_id: &ResourceId) { + let Some(peer) = self.role_state.peers.get_mut(id) else { + return; + }; + + peer.transform.remove_resource(resource_id); + if peer.transform.is_emptied() { + self.role_state.peers.remove(id); + } + } } /// [`Tunnel`] state specific to gateways. pub struct GatewayState { - pub peers: PeerStore, + pub peers: PeerStore, expire_interval: Interval, } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index cae59b1c8..829a61afc 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -277,14 +277,15 @@ where } // TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation - fn poll_sockets( + fn poll_sockets( &mut self, device: &mut Device, - peer_store: &mut PeerStore, + peer_store: &mut PeerStore, cx: &mut Context<'_>, ) -> Poll> where TTransform: PacketTransform, + TResource: Clone, { let received = match ready!(self.sockets.poll_recv_from(cx)) { Ok(received) => received, diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index bc81ccfae..18cbf0465 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::time::Instant; use bimap::BiMap; use chrono::{DateTime, Utc}; -use connlib_shared::messages::DnsServer; +use connlib_shared::messages::{DnsServer, ResourceId}; use connlib_shared::IpProvider; use connlib_shared::{Error, Result}; use ip_network::IpNetwork; @@ -20,25 +20,44 @@ type ExpiryingResource = (ResourceDescription, Option>); // is 30 seconds. See resolvconf(5) timeout. const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); -pub struct Peer { - allowed_ips: IpNetworkTable<()>, +pub struct Peer { + // TODO: we should refactor this + // in the gateway-side this means that we are explicit about () + // maybe duping the Peer struct is the way to go + pub allowed_ips: IpNetworkTable, pub conn_id: TId, pub transform: TTransform, } -impl Peer +impl Peer> where TId: Copy, TTransform: PacketTransform, +{ + pub(crate) fn insert_id(&mut self, ip: &IpNetwork, id: &ResourceId) { + if let Some(resources) = self.allowed_ips.exact_match_mut(*ip) { + resources.insert(*id); + } else { + self.allowed_ips.insert(*ip, HashSet::from([*id])); + } + } +} + +impl Peer +where + TId: Copy, + TTransform: PacketTransform, + TResource: Clone, { pub(crate) fn new( - ips: Vec, conn_id: TId, transform: TTransform, - ) -> Peer { + ips: &[IpNetwork], + resource: TResource, + ) -> Peer { let mut allowed_ips = IpNetworkTable::new(); for ip in ips { - allowed_ips.insert(ip, ()); + allowed_ips.insert(*ip, resource.clone()); } Peer { @@ -48,10 +67,6 @@ where } } - pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) { - self.allowed_ips.insert(ip, ()); - } - fn is_allowed(&self, addr: IpAddr) -> bool { self.allowed_ips.longest_match(addr).is_some() } @@ -92,7 +107,7 @@ impl Default for PacketTransformGateway { #[derive(Default)] pub struct PacketTransformClient { - translations: BiMap, + pub translations: BiMap, dns_mapping: BiMap, mangled_dns_ids: HashMap, } @@ -133,6 +148,13 @@ impl PacketTransformGateway { .retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now())); } + pub(crate) fn remove_resource(&mut self, resource: &ResourceId) { + self.resources.retain(|_, (r, _)| match r { + connlib_shared::messages::ResourceDescription::Dns(r) => r.id != *resource, + connlib_shared::messages::ResourceDescription::Cidr(r) => r.id != *resource, + }) + } + pub(crate) fn add_resource( &mut self, ip: IpNetwork, diff --git a/rust/connlib/tunnel/src/peer_store.rs b/rust/connlib/tunnel/src/peer_store.rs index a39f4de97..a5c285e86 100644 --- a/rust/connlib/tunnel/src/peer_store.rs +++ b/rust/connlib/tunnel/src/peer_store.rs @@ -1,17 +1,18 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::hash::Hash; use std::net::IpAddr; use crate::peer::{PacketTransform, Peer}; +use connlib_shared::messages::ResourceId; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -pub struct PeerStore { +pub struct PeerStore { id_by_ip: IpNetworkTable, - peer_by_id: HashMap>, + peer_by_id: HashMap>, } -impl Default for PeerStore { +impl Default for PeerStore { fn default() -> Self { Self { id_by_ip: IpNetworkTable::new(), @@ -20,67 +21,92 @@ impl Default for PeerStore { } } -impl PeerStore +impl PeerStore> where - TId: Hash + Eq + Clone + Copy, + TId: Hash + Eq + Copy, TTransform: PacketTransform, { - pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer) -> bool) { + pub fn add_ips_with_resource(&mut self, id: &TId, ips: &[IpNetwork], resource: &ResourceId) { + for ip in ips { + let Some(peer) = self.add_ip(id, ip) else { + continue; + }; + peer.insert_id(ip, resource); + } + } +} + +impl PeerStore +where + TId: Hash + Eq + Copy, + TTransform: PacketTransform, +{ + pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer) -> bool) { self.peer_by_id.retain(f); self.id_by_ip .retain(|_, id| self.peer_by_id.contains_key(id)); } - pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer> { + pub fn add_ip( + &mut self, + id: &TId, + ip: &IpNetwork, + ) -> Option<&mut Peer> { let peer = self.peer_by_id.get_mut(id)?; - - for ip in ips { - self.id_by_ip.insert(*ip, peer.conn_id); - peer.add_allowed_ip(*ip); - } - + self.id_by_ip.insert(*ip, *id); Some(peer) } - pub fn insert(&mut self, peer: Peer) -> Option> { + pub fn insert( + &mut self, + peer: Peer, + ips: &[IpNetwork], + ) -> Option> { self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id); - self.peer_by_id.insert(peer.conn_id, peer) + let id = peer.conn_id; + let old_peer = self.peer_by_id.insert(id, peer); + + for ip in ips { + self.add_ip(&id, ip); + } + + old_peer } - pub fn remove(&mut self, id: &TId) -> Option> { + pub fn remove(&mut self, id: &TId) -> Option> { self.id_by_ip.retain(|_, r_id| r_id != id); self.peer_by_id.remove(id) } - pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer> { + pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer> { let ip = self.id_by_ip.exact_match(ip)?; self.peer_by_id.get(ip) } - pub fn get(&self, id: &TId) -> Option<&Peer> { + pub fn get(&self, id: &TId) -> Option<&Peer> { self.peer_by_id.get(id) } - pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer> { + pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer> { self.peer_by_id.get_mut(id) } - pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer> { + pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer> { let (_, id) = self.id_by_ip.longest_match(ip)?; self.peer_by_id.get(id) } - pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer> { + pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer> { let (_, id) = self.id_by_ip.longest_match(ip)?; self.peer_by_id.get_mut(id) } - pub fn iter_mut(&mut self) -> impl Iterator> { + pub fn iter_mut(&mut self) -> impl Iterator> { self.peer_by_id.values_mut() } - pub fn iter(&mut self) -> impl Iterator> { + pub fn iter(&mut self) -> impl Iterator> { self.peer_by_id.values() } } @@ -93,16 +119,21 @@ mod tests { #[test] fn can_insert_and_retrieve_peer() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &[], + ); assert!(peer_storage.get(&0).is_some()); } #[test] fn can_insert_and_retrieve_peer_by_ip() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); - peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &["100.0.0.0/24".parse().unwrap()], + ); assert_eq!( peer_storage @@ -115,9 +146,11 @@ mod tests { #[test] fn can_remove_peer() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); - peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &["100.0.0.0/24".parse().unwrap()], + ); peer_storage.remove(&0); assert!(peer_storage.get(&0).is_none()); @@ -128,10 +161,15 @@ mod tests { #[test] fn inserting_peer_removes_previous_instances_of_same_id() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); - peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &["100.0.0.0/24".parse().unwrap()], + ); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &[], + ); assert!(peer_storage.get(&0).is_some()); assert!(peer_storage diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index a5cbd1320..59af91d00 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -1,6 +1,6 @@ use crate::messages::{ AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady, - EgressMessages, IngressMessages, RequestConnection, + EgressMessages, IngressMessages, RejectAccess, RequestConnection, }; use crate::CallbackHandler; use anyhow::{anyhow, bail, Result}; @@ -201,6 +201,20 @@ impl Eventloop { } continue; } + + Poll::Ready(phoenix_channel::Event::InboundMessage { + msg: + IngressMessages::RejectAccess(RejectAccess { + client_id, + resource_id, + }), + .. + }) => { + tracing::debug!(client = %client_id, resource = %resource_id, "Access removed"); + + self.tunnel.remove_access(&client_id, &resource_id); + continue; + } Poll::Ready(phoenix_channel::Event::InboundMessage { msg: IngressMessages::Init(_), .. diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index 0bbce9e50..5e93d51f4 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -86,6 +86,12 @@ pub struct AllowAccess { pub reference: String, } +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +pub struct RejectAccess { + pub client_id: ClientId, + pub resource_id: ResourceId, +} + // These messages are the messages that can be received // either by a client or a gateway by the client. #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -93,6 +99,7 @@ pub struct AllowAccess { pub enum IngressMessages { RequestConnection(RequestConnection), AllowAccess(AllowAccess), + RejectAccess(RejectAccess), IceCandidates(ClientIceCandidates), Init(InitGateway), }