From 77b00b3be957b7c176807df703b661cb597a3532 Mon Sep 17 00:00:00 2001 From: Gabi Date: Tue, 27 Feb 2024 00:24:14 -0300 Subject: [PATCH] feat(connlib): support resource updates from the portal (#3754) This PR doesn't yet provide support for the update of upstream DNS but it does provide support for all the other resources update messages. Should comply with the description of issue #2022 but it doesn't respond to DNS upstream updates which is imply it should on the issue title --------- Signed-off-by: Gabi Co-authored-by: Thomas Eizinger --- rust/connlib/clients/shared/src/control.rs | 4 +- rust/connlib/shared/src/messages.rs | 12 ++ rust/connlib/tunnel/src/client.rs | 96 ++++++++++++--- .../tunnel/src/control_protocol/client.rs | 15 ++- .../tunnel/src/control_protocol/gateway.rs | 5 +- rust/connlib/tunnel/src/device_channel.rs | 28 +++++ .../tunnel/src/device_channel/tun_android.rs | 15 +++ .../tunnel/src/device_channel/tun_darwin.rs | 10 ++ .../tunnel/src/device_channel/tun_linux.rs | 96 +++++++++++---- .../tunnel/src/device_channel/tun_windows.rs | 73 +++++++----- rust/connlib/tunnel/src/gateway.rs | 15 ++- rust/connlib/tunnel/src/lib.rs | 5 +- rust/connlib/tunnel/src/peer.rs | 48 +++++--- rust/connlib/tunnel/src/peer_store.rs | 110 ++++++++++++------ rust/gateway/src/eventloop.rs | 16 ++- rust/gateway/src/messages.rs | 7 ++ 16 files changed, 425 insertions(+), 130 deletions(-) diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index c09cacf4f..128df5cc7 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -166,8 +166,8 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - fn resource_deleted(&self, id: ResourceId) { - // TODO + fn resource_deleted(&mut self, id: ResourceId) { + self.tunnel.remove_resource(id); } fn connection_details( diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index 04087ab3a..d2dd1a683 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -231,6 +231,18 @@ impl ResourceDescription { ResourceDescription::Cidr(r) => Cow::from(r.address.to_string()), } } + + pub fn has_different_address(&self, other: &ResourceDescription) -> bool { + match (self, other) { + (ResourceDescription::Dns(dns_a), ResourceDescription::Dns(dns_b)) => { + dns_a.address != dns_b.address + } + (ResourceDescription::Cidr(cidr_a), ResourceDescription::Cidr(cidr_b)) => { + cidr_a.address != cidr_b.address + } + _ => true, + } + } } /// Description of a resource that maps to a CIDR. diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 7262aa6de..a79f649e3 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -59,14 +59,10 @@ where &mut self, resource_description: ResourceDescription, ) -> connlib_shared::Result<()> { - if self - .role_state - .resource_ids - .contains_key(&resource_description.id()) - { - // TODO - tracing::info!("Resource updates aren't implemented yet"); - return Ok(()); + if let Some(resource) = self.role_state.resource_ids.get(&resource_description.id()) { + if resource.has_different_address(resource) { + self.remove_resource(resource.id()); + } } match &resource_description { @@ -99,6 +95,63 @@ where Ok(()) } + pub fn remove_resource(&mut self, id: ResourceId) { + self.role_state.awaiting_connection.remove(&id); + self.role_state.awaiting_connection_timers.remove(id); + self.role_state + .dns_resources_internal_ips + .retain(|r, _| r.id != id); + self.role_state.dns_resources.retain(|_, r| r.id != id); + self.role_state.cidr_resources.retain(|_, r| r.id != id); + self.role_state + .deferred_dns_queries + .retain(|(r, _), _| r.id != id); + + if let Some(ResourceDescription::Cidr(resource)) = self.role_state.resource_ids.remove(&id) + { + // Note: hopefully the os doesn't coalece routes in a way that removing a more general route deletes the most specific + if let Err(err) = self.remove_route(resource.address) { + tracing::error!(%id, %resource.address, "failed to remove route: {err:?}"); + } + } + + let Some(gateway_id) = self.role_state.resources_gateways.remove(&id) else { + return; + }; + + let Some(peer) = self.role_state.peers.get_mut(&gateway_id) else { + return; + }; + + // First we remove the id from all allowed ips + for (network, resources) in peer + .allowed_ips + .iter_mut() + .filter(|(_, resources)| resources.contains(&id)) + { + resources.remove(&id); + + if !resources.is_empty() { + continue; + } + + // If the allowed_ips doesn't correspond to any resource anymore we + // clean up any related translation. + peer.transform + .translations + .remove_by_left(&network.network_address()); + } + + // We remove all empty allowed ips entry since there's no resource that corresponds to it + peer.allowed_ips.retain(|_, r| !r.is_empty()); + + // If there's no allowed ip left we remove the whole peer because there's no point on keeping it around + if peer.allowed_ips.is_empty() { + self.role_state.peers.remove(&gateway_id); + // TODO: should we have a Node::remove_connection? + } + } + /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] pub fn set_interface( @@ -163,6 +216,22 @@ where Ok(()) } + + #[tracing::instrument(level = "trace", skip(self))] + pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> { + let callbacks = self.callbacks().clone(); + let maybe_new_device = self + .device + .as_mut() + .ok_or(Error::ControlProtocolError)? + .remove_route(route, &callbacks)?; + + if let Some(new_device) = maybe_new_device { + self.device = Some(new_device); + } + + Ok(()) + } } /// [`Tunnel`] state specific to clients. @@ -189,7 +258,7 @@ pub struct ClientState { pub resource_ids: HashMap, pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>, - pub peers: PeerStore, + pub peers: PeerStore>, forwarded_dns_queries: FuturesTupleSet< Result, @@ -332,11 +401,7 @@ impl ClientState { self.resources_gateways.insert(resource, gateway); - if self - .peers - .add_ips(&gateway, &self.get_resource_ip(desc, &domain)) - .is_none() - { + if self.peers.get(&gateway).is_none() { match self .gateway_awaiting_connection_timers // Note: we don't need to set a timer here because @@ -357,6 +422,9 @@ impl ClientState { return Ok(None); }; + self.peers + .add_ips_with_resource(&gateway, &self.get_resource_ip(desc, &domain), &resource); + self.awaiting_connection.remove(&resource); self.awaiting_connection_timers.remove(resource); diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 61da858f1..785f3066c 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -106,10 +106,11 @@ where &domain_response.as_ref().map(|d| d.domain.clone()), )?; - let mut peer: Peer<_, PacketTransformClient> = - Peer::new(ips.clone(), gateway_id, Default::default()); + let resource_ids = HashSet::from([resource_id]); + let mut peer: Peer<_, PacketTransformClient, _> = + Peer::new(gateway_id, Default::default(), &ips, resource_ids); peer.transform.set_dns(self.role_state.dns_mapping()); - self.role_state.peers.insert(peer); + self.role_state.peers.insert(peer, &[]); let peer_ips = if let Some(domain_response) = domain_response { self.dns_response(&resource_id, &domain_response, &gateway_id)? @@ -117,7 +118,9 @@ where ips }; - self.role_state.peers.add_ips(&gateway_id, &peer_ips); + self.role_state + .peers + .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); Ok(()) } @@ -232,7 +235,9 @@ where let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?; - self.role_state.peers.add_ips(&gateway_id, &peer_ips); + self.role_state + .peers + .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); Ok(()) } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index 6760c236f..853fc4b8a 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -167,15 +167,14 @@ where ) -> Result<()> { tracing::trace!(?ips, "new_data_channel_open"); - let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default()); + let mut peer = Peer::new(client_id, PacketTransformGateway::default(), &ips, ()); for address in resource_addresses { peer.transform .add_resource(address, resource.clone(), expires_at); } - self.role_state.peers.insert(peer); - self.role_state.peers.add_ips(&client_id, &ips); + self.role_state.peers.insert(peer, &ips); Ok(()) } diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index 41cac66d5..d062e0929 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -148,6 +148,34 @@ impl Device { })) } + #[cfg(target_family = "unix")] + pub(crate) fn remove_route( + &mut self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result, Error> { + let Some(tun) = self.tun.remove_route(route, callbacks)? else { + return Ok(None); + }; + let mtu = ioctl::interface_mtu_by_name(tun.name())?; + + Ok(Some(Device { + mtu, + tun, + mtu_refreshed_at: Instant::now(), + })) + } + + #[cfg(target_family = "windows")] + pub(crate) fn remove_route( + &mut self, + route: IpNetwork, + _callbacks: &impl Callbacks, + ) -> Result, Error> { + self.tun.remove_route(route)?; + Ok(None) + } + #[cfg(target_family = "windows")] #[allow(unused_mut)] pub(crate) fn add_route( diff --git a/rust/connlib/tunnel/src/device_channel/tun_android.rs b/rust/connlib/tunnel/src/device_channel/tun_android.rs index 6376cb5bf..5a1a91c22 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_android.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_android.rs @@ -75,6 +75,21 @@ impl Tun { name, })) } + + pub fn remove_route( + &self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result> { + self.fd.close(); + let fd = callbacks.on_remove_route(route)?.ok_or(Error::NoFd)?; + let name = unsafe { interface_name(fd)? }; + + Ok(Some(Tun { + fd: Closeable::new(AsyncFd::new(fd)?), + name, + })) + } } /// Retrieves the name of the interface pointed to by the provided file descriptor. diff --git a/rust/connlib/tunnel/src/device_channel/tun_darwin.rs b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs index 262e0156d..4840ed635 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_darwin.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs @@ -153,6 +153,16 @@ impl Tun { Ok(None) } + pub fn remove_route( + &self, + route: IpNetwork, + callbacks: &impl Callbacks, + ) -> Result> { + // This will always be None in macos + callbacks.on_remove_route(route)?; + Ok(None) + } + pub fn name(&self) -> &str { self.name.as_str() } diff --git a/rust/connlib/tunnel/src/device_channel/tun_linux.rs b/rust/connlib/tunnel/src/device_channel/tun_linux.rs index 2cfd82208..aa9300920 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_linux.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_linux.rs @@ -6,16 +6,16 @@ use connlib_shared::{ use futures::TryStreamExt; use futures_util::future::BoxFuture; use futures_util::FutureExt; -use ip_network::IpNetwork; +use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use libc::{ close, fcntl, makedev, mknod, open, F_GETFL, F_SETFL, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, O_NONBLOCK, O_RDWR, S_IFCHR, }; use netlink_packet_route::route::{RouteProtocol, RouteScope}; use netlink_packet_route::rule::RuleAction; -use rtnetlink::RuleAddRequest; use rtnetlink::{new_connection, Error::NetlinkError, Handle}; -use std::net::IpAddr; +use rtnetlink::{RouteAddRequest, RuleAddRequest}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::path::Path; use std::task::{Context, Poll}; use std::{ @@ -154,26 +154,9 @@ impl Tun { .header .index; - let req = handle - .route() - .add() - .output_interface(index) - .protocol(RouteProtocol::Static) - .scope(RouteScope::Universe) - .table_id(FIREZONE_TABLE); let res = match route { - IpNetwork::V4(ipnet) => { - req.v4() - .destination_prefix(ipnet.network_address(), ipnet.netmask()) - .execute() - .await - } - IpNetwork::V6(ipnet) => { - req.v6() - .destination_prefix(ipnet.network_address(), ipnet.netmask()) - .execute() - .await - } + IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).execute().await, + IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).execute().await, }; match res { @@ -206,6 +189,53 @@ impl Tun { Ok(None) } + pub fn remove_route(&mut self, route: IpNetwork, _: &impl Callbacks) -> Result> { + let handle = self.handle.clone(); + + let add_route_worker = async move { + let index = handle + .link() + .get() + .match_name(IFACE_NAME.to_string()) + .execute() + .try_next() + .await? + .ok_or(Error::NoIface)? + .header + .index; + + let message = match route { + IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).message_mut().clone(), + IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).message_mut().clone(), + }; + + match handle.route().del(message).execute().await { + Ok(_) => Ok(()), + Err(err) => { + tracing::error!(%route, "failed to add route: {err:#?}"); + Ok(()) + } + } + }; + + match self.worker.take() { + None => self.worker = Some(add_route_worker.boxed()), + Some(current_worker) => { + self.worker = Some( + async move { + current_worker.await?; + add_route_worker.await?; + + Ok(()) + } + .boxed(), + ) + } + } + + Ok(None) + } + pub fn name(&self) -> &str { IFACE_NAME } @@ -327,6 +357,28 @@ fn make_rule(handle: &Handle) -> RuleAddRequest { rule } +fn make_route(idx: u32, handle: &Handle) -> RouteAddRequest { + handle + .route() + .add() + .output_interface(idx) + .protocol(RouteProtocol::Static) + .scope(RouteScope::Universe) + .table_id(FIREZONE_TABLE) +} + +fn make_route_v4(idx: u32, handle: &Handle, route: Ipv4Network) -> RouteAddRequest { + make_route(idx, handle) + .v4() + .destination_prefix(route.network_address(), route.netmask()) +} + +fn make_route_v6(idx: u32, handle: &Handle, route: Ipv6Network) -> RouteAddRequest { + make_route(idx, handle) + .v6() + .destination_prefix(route.network_address(), route.netmask()) +} + fn get_last_error() -> Error { Error::Io(io::Error::last_os_error()) } diff --git a/rust/connlib/tunnel/src/device_channel/tun_windows.rs b/rust/connlib/tunnel/src/device_channel/tun_windows.rs index a7a2965a1..8c750fb6a 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_windows.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_windows.rs @@ -13,8 +13,8 @@ use tokio::sync::mpsc; use windows::Win32::{ NetworkManagement::{ IpHelper::{ - CreateIpForwardEntry2, GetIpInterfaceEntry, InitializeIpForwardEntry, - SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW, + CreateIpForwardEntry2, DeleteIpForwardEntry2, GetIpInterfaceEntry, + InitializeIpForwardEntry, SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW, }, Ndis::NET_LUID_LH, }, @@ -154,37 +154,27 @@ impl Tun { // It's okay if this blocks until the route is added in the OS. pub fn add_route(&self, route: IpNetwork) -> Result<()> { - tracing::debug!("add_route {route}"); - let mut row = MIB_IPFORWARD_ROW2::default(); - // SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults - unsafe { InitializeIpForwardEntry(&mut row) }; - - let prefix = &mut row.DestinationPrefix; - match route { - IpNetwork::V4(x) => { - prefix.PrefixLength = x.netmask(); - prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into(); - } - IpNetwork::V6(x) => { - prefix.PrefixLength = x.netmask(); - prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into(); - } - } - - row.InterfaceIndex = self.iface_idx; - row.Metric = 0; + const DUPLICATE_ERR: u32 = 0x80071392; + let entry = self.forward_entry(route); // SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable. - match unsafe { CreateIpForwardEntry2(&row) } { - Ok(_) => {} - Err(e) => { - if e.code().0 as u32 == 0x80071392 { - // "Object already exists" error - tracing::warn!("Failed to add duplicate route, ignoring"); - } else { - Err(e)?; - } + match unsafe { CreateIpForwardEntry2(&entry) } { + Ok(()) => Ok(()), + Err(e) if e.code().0 as u32 == DUPLICATE_ERR => { + tracing::debug!(%route, "Failed to add duplicate route, ignoring"); + Ok(()) } + Err(e) => Err(e.into()), + } + } + + // It's okay if this blocks until the route is added in the OS. + pub fn remove_route(&self, route: IpNetwork) -> Result<()> { + let entry = self.forward_entry(route); + + // SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable. + unsafe { + DeleteIpForwardEntry2(&entry)?; } Ok(()) } @@ -239,6 +229,29 @@ impl Tun { self.session.send_packet(pkt); Ok(bytes.len()) } + + fn forward_entry(&self, route: IpNetwork) -> MIB_IPFORWARD_ROW2 { + let mut row = MIB_IPFORWARD_ROW2::default(); + // SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults + unsafe { InitializeIpForwardEntry(&mut row) }; + + let prefix = &mut row.DestinationPrefix; + match route { + IpNetwork::V4(x) => { + prefix.PrefixLength = x.netmask(); + prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into(); + } + IpNetwork::V6(x) => { + prefix.PrefixLength = x.netmask(); + prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into(); + } + } + + row.InterfaceIndex = self.iface_idx; + row.Metric = 0; + + row + } } fn start_recv_thread( diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index ed1e6d0e0..736467822 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -6,7 +6,7 @@ use crate::ip_packet::MutableIpPacket; use crate::peer::PacketTransformGateway; use crate::peer_store::PeerStore; use crate::Tunnel; -use connlib_shared::messages::{ClientId, Interface as InterfaceConfig}; +use connlib_shared::messages::{ClientId, Interface as InterfaceConfig, ResourceId}; use connlib_shared::Callbacks; use snownet::Server; use tokio::time::{interval, Interval, MissedTickBehavior}; @@ -40,11 +40,22 @@ where pub fn cleanup_connection(&mut self, id: &ClientId) { self.role_state.peers.remove(id); } + + pub fn remove_access(&mut self, id: &ClientId, resource_id: &ResourceId) { + let Some(peer) = self.role_state.peers.get_mut(id) else { + return; + }; + + peer.transform.remove_resource(resource_id); + if peer.transform.is_emptied() { + self.role_state.peers.remove(id); + } + } } /// [`Tunnel`] state specific to gateways. pub struct GatewayState { - pub peers: PeerStore, + pub peers: PeerStore, expire_interval: Interval, } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index cae59b1c8..829a61afc 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -277,14 +277,15 @@ where } // TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation - fn poll_sockets( + fn poll_sockets( &mut self, device: &mut Device, - peer_store: &mut PeerStore, + peer_store: &mut PeerStore, cx: &mut Context<'_>, ) -> Poll> where TTransform: PacketTransform, + TResource: Clone, { let received = match ready!(self.sockets.poll_recv_from(cx)) { Ok(received) => received, diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index bc81ccfae..18cbf0465 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::time::Instant; use bimap::BiMap; use chrono::{DateTime, Utc}; -use connlib_shared::messages::DnsServer; +use connlib_shared::messages::{DnsServer, ResourceId}; use connlib_shared::IpProvider; use connlib_shared::{Error, Result}; use ip_network::IpNetwork; @@ -20,25 +20,44 @@ type ExpiryingResource = (ResourceDescription, Option>); // is 30 seconds. See resolvconf(5) timeout. const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); -pub struct Peer { - allowed_ips: IpNetworkTable<()>, +pub struct Peer { + // TODO: we should refactor this + // in the gateway-side this means that we are explicit about () + // maybe duping the Peer struct is the way to go + pub allowed_ips: IpNetworkTable, pub conn_id: TId, pub transform: TTransform, } -impl Peer +impl Peer> where TId: Copy, TTransform: PacketTransform, +{ + pub(crate) fn insert_id(&mut self, ip: &IpNetwork, id: &ResourceId) { + if let Some(resources) = self.allowed_ips.exact_match_mut(*ip) { + resources.insert(*id); + } else { + self.allowed_ips.insert(*ip, HashSet::from([*id])); + } + } +} + +impl Peer +where + TId: Copy, + TTransform: PacketTransform, + TResource: Clone, { pub(crate) fn new( - ips: Vec, conn_id: TId, transform: TTransform, - ) -> Peer { + ips: &[IpNetwork], + resource: TResource, + ) -> Peer { let mut allowed_ips = IpNetworkTable::new(); for ip in ips { - allowed_ips.insert(ip, ()); + allowed_ips.insert(*ip, resource.clone()); } Peer { @@ -48,10 +67,6 @@ where } } - pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) { - self.allowed_ips.insert(ip, ()); - } - fn is_allowed(&self, addr: IpAddr) -> bool { self.allowed_ips.longest_match(addr).is_some() } @@ -92,7 +107,7 @@ impl Default for PacketTransformGateway { #[derive(Default)] pub struct PacketTransformClient { - translations: BiMap, + pub translations: BiMap, dns_mapping: BiMap, mangled_dns_ids: HashMap, } @@ -133,6 +148,13 @@ impl PacketTransformGateway { .retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now())); } + pub(crate) fn remove_resource(&mut self, resource: &ResourceId) { + self.resources.retain(|_, (r, _)| match r { + connlib_shared::messages::ResourceDescription::Dns(r) => r.id != *resource, + connlib_shared::messages::ResourceDescription::Cidr(r) => r.id != *resource, + }) + } + pub(crate) fn add_resource( &mut self, ip: IpNetwork, diff --git a/rust/connlib/tunnel/src/peer_store.rs b/rust/connlib/tunnel/src/peer_store.rs index a39f4de97..a5c285e86 100644 --- a/rust/connlib/tunnel/src/peer_store.rs +++ b/rust/connlib/tunnel/src/peer_store.rs @@ -1,17 +1,18 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::hash::Hash; use std::net::IpAddr; use crate::peer::{PacketTransform, Peer}; +use connlib_shared::messages::ResourceId; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -pub struct PeerStore { +pub struct PeerStore { id_by_ip: IpNetworkTable, - peer_by_id: HashMap>, + peer_by_id: HashMap>, } -impl Default for PeerStore { +impl Default for PeerStore { fn default() -> Self { Self { id_by_ip: IpNetworkTable::new(), @@ -20,67 +21,92 @@ impl Default for PeerStore { } } -impl PeerStore +impl PeerStore> where - TId: Hash + Eq + Clone + Copy, + TId: Hash + Eq + Copy, TTransform: PacketTransform, { - pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer) -> bool) { + pub fn add_ips_with_resource(&mut self, id: &TId, ips: &[IpNetwork], resource: &ResourceId) { + for ip in ips { + let Some(peer) = self.add_ip(id, ip) else { + continue; + }; + peer.insert_id(ip, resource); + } + } +} + +impl PeerStore +where + TId: Hash + Eq + Copy, + TTransform: PacketTransform, +{ + pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer) -> bool) { self.peer_by_id.retain(f); self.id_by_ip .retain(|_, id| self.peer_by_id.contains_key(id)); } - pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer> { + pub fn add_ip( + &mut self, + id: &TId, + ip: &IpNetwork, + ) -> Option<&mut Peer> { let peer = self.peer_by_id.get_mut(id)?; - - for ip in ips { - self.id_by_ip.insert(*ip, peer.conn_id); - peer.add_allowed_ip(*ip); - } - + self.id_by_ip.insert(*ip, *id); Some(peer) } - pub fn insert(&mut self, peer: Peer) -> Option> { + pub fn insert( + &mut self, + peer: Peer, + ips: &[IpNetwork], + ) -> Option> { self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id); - self.peer_by_id.insert(peer.conn_id, peer) + let id = peer.conn_id; + let old_peer = self.peer_by_id.insert(id, peer); + + for ip in ips { + self.add_ip(&id, ip); + } + + old_peer } - pub fn remove(&mut self, id: &TId) -> Option> { + pub fn remove(&mut self, id: &TId) -> Option> { self.id_by_ip.retain(|_, r_id| r_id != id); self.peer_by_id.remove(id) } - pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer> { + pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer> { let ip = self.id_by_ip.exact_match(ip)?; self.peer_by_id.get(ip) } - pub fn get(&self, id: &TId) -> Option<&Peer> { + pub fn get(&self, id: &TId) -> Option<&Peer> { self.peer_by_id.get(id) } - pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer> { + pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer> { self.peer_by_id.get_mut(id) } - pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer> { + pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer> { let (_, id) = self.id_by_ip.longest_match(ip)?; self.peer_by_id.get(id) } - pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer> { + pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer> { let (_, id) = self.id_by_ip.longest_match(ip)?; self.peer_by_id.get_mut(id) } - pub fn iter_mut(&mut self) -> impl Iterator> { + pub fn iter_mut(&mut self) -> impl Iterator> { self.peer_by_id.values_mut() } - pub fn iter(&mut self) -> impl Iterator> { + pub fn iter(&mut self) -> impl Iterator> { self.peer_by_id.values() } } @@ -93,16 +119,21 @@ mod tests { #[test] fn can_insert_and_retrieve_peer() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &[], + ); assert!(peer_storage.get(&0).is_some()); } #[test] fn can_insert_and_retrieve_peer_by_ip() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); - peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &["100.0.0.0/24".parse().unwrap()], + ); assert_eq!( peer_storage @@ -115,9 +146,11 @@ mod tests { #[test] fn can_remove_peer() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); - peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &["100.0.0.0/24".parse().unwrap()], + ); peer_storage.remove(&0); assert!(peer_storage.get(&0).is_none()); @@ -128,10 +161,15 @@ mod tests { #[test] fn inserting_peer_removes_previous_instances_of_same_id() { - let mut peer_storage = PeerStore::default(); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); - peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]); - peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default())); + let mut peer_storage = PeerStore::<_, _, ()>::default(); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &["100.0.0.0/24".parse().unwrap()], + ); + peer_storage.insert( + Peer::new(0, PacketTransformGateway::default(), &[], ()), + &[], + ); assert!(peer_storage.get(&0).is_some()); assert!(peer_storage diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index a5cbd1320..59af91d00 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -1,6 +1,6 @@ use crate::messages::{ AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady, - EgressMessages, IngressMessages, RequestConnection, + EgressMessages, IngressMessages, RejectAccess, RequestConnection, }; use crate::CallbackHandler; use anyhow::{anyhow, bail, Result}; @@ -201,6 +201,20 @@ impl Eventloop { } continue; } + + Poll::Ready(phoenix_channel::Event::InboundMessage { + msg: + IngressMessages::RejectAccess(RejectAccess { + client_id, + resource_id, + }), + .. + }) => { + tracing::debug!(client = %client_id, resource = %resource_id, "Access removed"); + + self.tunnel.remove_access(&client_id, &resource_id); + continue; + } Poll::Ready(phoenix_channel::Event::InboundMessage { msg: IngressMessages::Init(_), .. diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index 0bbce9e50..5e93d51f4 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -86,6 +86,12 @@ pub struct AllowAccess { pub reference: String, } +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +pub struct RejectAccess { + pub client_id: ClientId, + pub resource_id: ResourceId, +} + // These messages are the messages that can be received // either by a client or a gateway by the client. #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -93,6 +99,7 @@ pub struct AllowAccess { pub enum IngressMessages { RequestConnection(RequestConnection), AllowAccess(AllowAccess), + RejectAccess(RejectAccess), IceCandidates(ClientIceCandidates), Init(InitGateway), }