mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat(connlib): support resource updates from the portal (#3754)
This PR doesn't yet provide support for the update of upstream DNS but it does provide support for all the other resources update messages. Should comply with the description of issue #2022 but it doesn't respond to DNS upstream updates which is imply it should on the issue title --------- Signed-off-by: Gabi <gabrielalejandro7@gmail.com> Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
@@ -166,8 +166,8 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
fn resource_deleted(&self, id: ResourceId) {
|
||||
// TODO
|
||||
fn resource_deleted(&mut self, id: ResourceId) {
|
||||
self.tunnel.remove_resource(id);
|
||||
}
|
||||
|
||||
fn connection_details(
|
||||
|
||||
@@ -231,6 +231,18 @@ impl ResourceDescription {
|
||||
ResourceDescription::Cidr(r) => Cow::from(r.address.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_different_address(&self, other: &ResourceDescription) -> bool {
|
||||
match (self, other) {
|
||||
(ResourceDescription::Dns(dns_a), ResourceDescription::Dns(dns_b)) => {
|
||||
dns_a.address != dns_b.address
|
||||
}
|
||||
(ResourceDescription::Cidr(cidr_a), ResourceDescription::Cidr(cidr_b)) => {
|
||||
cidr_a.address != cidr_b.address
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Description of a resource that maps to a CIDR.
|
||||
|
||||
@@ -59,14 +59,10 @@ where
|
||||
&mut self,
|
||||
resource_description: ResourceDescription,
|
||||
) -> connlib_shared::Result<()> {
|
||||
if self
|
||||
.role_state
|
||||
.resource_ids
|
||||
.contains_key(&resource_description.id())
|
||||
{
|
||||
// TODO
|
||||
tracing::info!("Resource updates aren't implemented yet");
|
||||
return Ok(());
|
||||
if let Some(resource) = self.role_state.resource_ids.get(&resource_description.id()) {
|
||||
if resource.has_different_address(resource) {
|
||||
self.remove_resource(resource.id());
|
||||
}
|
||||
}
|
||||
|
||||
match &resource_description {
|
||||
@@ -99,6 +95,63 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_resource(&mut self, id: ResourceId) {
|
||||
self.role_state.awaiting_connection.remove(&id);
|
||||
self.role_state.awaiting_connection_timers.remove(id);
|
||||
self.role_state
|
||||
.dns_resources_internal_ips
|
||||
.retain(|r, _| r.id != id);
|
||||
self.role_state.dns_resources.retain(|_, r| r.id != id);
|
||||
self.role_state.cidr_resources.retain(|_, r| r.id != id);
|
||||
self.role_state
|
||||
.deferred_dns_queries
|
||||
.retain(|(r, _), _| r.id != id);
|
||||
|
||||
if let Some(ResourceDescription::Cidr(resource)) = self.role_state.resource_ids.remove(&id)
|
||||
{
|
||||
// Note: hopefully the os doesn't coalece routes in a way that removing a more general route deletes the most specific
|
||||
if let Err(err) = self.remove_route(resource.address) {
|
||||
tracing::error!(%id, %resource.address, "failed to remove route: {err:?}");
|
||||
}
|
||||
}
|
||||
|
||||
let Some(gateway_id) = self.role_state.resources_gateways.remove(&id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(peer) = self.role_state.peers.get_mut(&gateway_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
// First we remove the id from all allowed ips
|
||||
for (network, resources) in peer
|
||||
.allowed_ips
|
||||
.iter_mut()
|
||||
.filter(|(_, resources)| resources.contains(&id))
|
||||
{
|
||||
resources.remove(&id);
|
||||
|
||||
if !resources.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the allowed_ips doesn't correspond to any resource anymore we
|
||||
// clean up any related translation.
|
||||
peer.transform
|
||||
.translations
|
||||
.remove_by_left(&network.network_address());
|
||||
}
|
||||
|
||||
// We remove all empty allowed ips entry since there's no resource that corresponds to it
|
||||
peer.allowed_ips.retain(|_, r| !r.is_empty());
|
||||
|
||||
// If there's no allowed ip left we remove the whole peer because there's no point on keeping it around
|
||||
if peer.allowed_ips.is_empty() {
|
||||
self.role_state.peers.remove(&gateway_id);
|
||||
// TODO: should we have a Node::remove_connection?
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the interface configuration and starts background tasks.
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn set_interface(
|
||||
@@ -163,6 +216,22 @@ where
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn remove_route(&mut self, route: IpNetwork) -> connlib_shared::Result<()> {
|
||||
let callbacks = self.callbacks().clone();
|
||||
let maybe_new_device = self
|
||||
.device
|
||||
.as_mut()
|
||||
.ok_or(Error::ControlProtocolError)?
|
||||
.remove_route(route, &callbacks)?;
|
||||
|
||||
if let Some(new_device) = maybe_new_device {
|
||||
self.device = Some(new_device);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// [`Tunnel`] state specific to clients.
|
||||
@@ -189,7 +258,7 @@ pub struct ClientState {
|
||||
pub resource_ids: HashMap<ResourceId, ResourceDescription>,
|
||||
pub deferred_dns_queries: HashMap<(DnsResource, Rtype), IpPacket<'static>>,
|
||||
|
||||
pub peers: PeerStore<GatewayId, PacketTransformClient>,
|
||||
pub peers: PeerStore<GatewayId, PacketTransformClient, HashSet<ResourceId>>,
|
||||
|
||||
forwarded_dns_queries: FuturesTupleSet<
|
||||
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
|
||||
@@ -332,11 +401,7 @@ impl ClientState {
|
||||
|
||||
self.resources_gateways.insert(resource, gateway);
|
||||
|
||||
if self
|
||||
.peers
|
||||
.add_ips(&gateway, &self.get_resource_ip(desc, &domain))
|
||||
.is_none()
|
||||
{
|
||||
if self.peers.get(&gateway).is_none() {
|
||||
match self
|
||||
.gateway_awaiting_connection_timers
|
||||
// Note: we don't need to set a timer here because
|
||||
@@ -357,6 +422,9 @@ impl ClientState {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
self.peers
|
||||
.add_ips_with_resource(&gateway, &self.get_resource_ip(desc, &domain), &resource);
|
||||
|
||||
self.awaiting_connection.remove(&resource);
|
||||
self.awaiting_connection_timers.remove(resource);
|
||||
|
||||
|
||||
@@ -106,10 +106,11 @@ where
|
||||
&domain_response.as_ref().map(|d| d.domain.clone()),
|
||||
)?;
|
||||
|
||||
let mut peer: Peer<_, PacketTransformClient> =
|
||||
Peer::new(ips.clone(), gateway_id, Default::default());
|
||||
let resource_ids = HashSet::from([resource_id]);
|
||||
let mut peer: Peer<_, PacketTransformClient, _> =
|
||||
Peer::new(gateway_id, Default::default(), &ips, resource_ids);
|
||||
peer.transform.set_dns(self.role_state.dns_mapping());
|
||||
self.role_state.peers.insert(peer);
|
||||
self.role_state.peers.insert(peer, &[]);
|
||||
|
||||
let peer_ips = if let Some(domain_response) = domain_response {
|
||||
self.dns_response(&resource_id, &domain_response, &gateway_id)?
|
||||
@@ -117,7 +118,9 @@ where
|
||||
ips
|
||||
};
|
||||
|
||||
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
|
||||
self.role_state
|
||||
.peers
|
||||
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -232,7 +235,9 @@ where
|
||||
|
||||
let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?;
|
||||
|
||||
self.role_state.peers.add_ips(&gateway_id, &peer_ips);
|
||||
self.role_state
|
||||
.peers
|
||||
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -167,15 +167,14 @@ where
|
||||
) -> Result<()> {
|
||||
tracing::trace!(?ips, "new_data_channel_open");
|
||||
|
||||
let mut peer = Peer::new(ips.clone(), client_id, PacketTransformGateway::default());
|
||||
let mut peer = Peer::new(client_id, PacketTransformGateway::default(), &ips, ());
|
||||
|
||||
for address in resource_addresses {
|
||||
peer.transform
|
||||
.add_resource(address, resource.clone(), expires_at);
|
||||
}
|
||||
|
||||
self.role_state.peers.insert(peer);
|
||||
self.role_state.peers.add_ips(&client_id, &ips);
|
||||
self.role_state.peers.insert(peer, &ips);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -148,6 +148,34 @@ impl Device {
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
pub(crate) fn remove_route(
|
||||
&mut self,
|
||||
route: IpNetwork,
|
||||
callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Option<Device>, Error> {
|
||||
let Some(tun) = self.tun.remove_route(route, callbacks)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let mtu = ioctl::interface_mtu_by_name(tun.name())?;
|
||||
|
||||
Ok(Some(Device {
|
||||
mtu,
|
||||
tun,
|
||||
mtu_refreshed_at: Instant::now(),
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(target_family = "windows")]
|
||||
pub(crate) fn remove_route(
|
||||
&mut self,
|
||||
route: IpNetwork,
|
||||
_callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Option<Device>, Error> {
|
||||
self.tun.remove_route(route)?;
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(target_family = "windows")]
|
||||
#[allow(unused_mut)]
|
||||
pub(crate) fn add_route(
|
||||
|
||||
@@ -75,6 +75,21 @@ impl Tun {
|
||||
name,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn remove_route(
|
||||
&self,
|
||||
route: IpNetwork,
|
||||
callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Option<Self>> {
|
||||
self.fd.close();
|
||||
let fd = callbacks.on_remove_route(route)?.ok_or(Error::NoFd)?;
|
||||
let name = unsafe { interface_name(fd)? };
|
||||
|
||||
Ok(Some(Tun {
|
||||
fd: Closeable::new(AsyncFd::new(fd)?),
|
||||
name,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves the name of the interface pointed to by the provided file descriptor.
|
||||
|
||||
@@ -153,6 +153,16 @@ impl Tun {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn remove_route(
|
||||
&self,
|
||||
route: IpNetwork,
|
||||
callbacks: &impl Callbacks<Error = Error>,
|
||||
) -> Result<Option<Self>> {
|
||||
// This will always be None in macos
|
||||
callbacks.on_remove_route(route)?;
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
@@ -6,16 +6,16 @@ use connlib_shared::{
|
||||
use futures::TryStreamExt;
|
||||
use futures_util::future::BoxFuture;
|
||||
use futures_util::FutureExt;
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
|
||||
use libc::{
|
||||
close, fcntl, makedev, mknod, open, F_GETFL, F_SETFL, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN,
|
||||
O_NONBLOCK, O_RDWR, S_IFCHR,
|
||||
};
|
||||
use netlink_packet_route::route::{RouteProtocol, RouteScope};
|
||||
use netlink_packet_route::rule::RuleAction;
|
||||
use rtnetlink::RuleAddRequest;
|
||||
use rtnetlink::{new_connection, Error::NetlinkError, Handle};
|
||||
use std::net::IpAddr;
|
||||
use rtnetlink::{RouteAddRequest, RuleAddRequest};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::path::Path;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
@@ -154,26 +154,9 @@ impl Tun {
|
||||
.header
|
||||
.index;
|
||||
|
||||
let req = handle
|
||||
.route()
|
||||
.add()
|
||||
.output_interface(index)
|
||||
.protocol(RouteProtocol::Static)
|
||||
.scope(RouteScope::Universe)
|
||||
.table_id(FIREZONE_TABLE);
|
||||
let res = match route {
|
||||
IpNetwork::V4(ipnet) => {
|
||||
req.v4()
|
||||
.destination_prefix(ipnet.network_address(), ipnet.netmask())
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
IpNetwork::V6(ipnet) => {
|
||||
req.v6()
|
||||
.destination_prefix(ipnet.network_address(), ipnet.netmask())
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).execute().await,
|
||||
IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).execute().await,
|
||||
};
|
||||
|
||||
match res {
|
||||
@@ -206,6 +189,53 @@ impl Tun {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn remove_route(&mut self, route: IpNetwork, _: &impl Callbacks) -> Result<Option<Self>> {
|
||||
let handle = self.handle.clone();
|
||||
|
||||
let add_route_worker = async move {
|
||||
let index = handle
|
||||
.link()
|
||||
.get()
|
||||
.match_name(IFACE_NAME.to_string())
|
||||
.execute()
|
||||
.try_next()
|
||||
.await?
|
||||
.ok_or(Error::NoIface)?
|
||||
.header
|
||||
.index;
|
||||
|
||||
let message = match route {
|
||||
IpNetwork::V4(ipnet) => make_route_v4(index, &handle, ipnet).message_mut().clone(),
|
||||
IpNetwork::V6(ipnet) => make_route_v6(index, &handle, ipnet).message_mut().clone(),
|
||||
};
|
||||
|
||||
match handle.route().del(message).execute().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
tracing::error!(%route, "failed to add route: {err:#?}");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match self.worker.take() {
|
||||
None => self.worker = Some(add_route_worker.boxed()),
|
||||
Some(current_worker) => {
|
||||
self.worker = Some(
|
||||
async move {
|
||||
current_worker.await?;
|
||||
add_route_worker.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.boxed(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
IFACE_NAME
|
||||
}
|
||||
@@ -327,6 +357,28 @@ fn make_rule(handle: &Handle) -> RuleAddRequest {
|
||||
rule
|
||||
}
|
||||
|
||||
fn make_route(idx: u32, handle: &Handle) -> RouteAddRequest {
|
||||
handle
|
||||
.route()
|
||||
.add()
|
||||
.output_interface(idx)
|
||||
.protocol(RouteProtocol::Static)
|
||||
.scope(RouteScope::Universe)
|
||||
.table_id(FIREZONE_TABLE)
|
||||
}
|
||||
|
||||
fn make_route_v4(idx: u32, handle: &Handle, route: Ipv4Network) -> RouteAddRequest<Ipv4Addr> {
|
||||
make_route(idx, handle)
|
||||
.v4()
|
||||
.destination_prefix(route.network_address(), route.netmask())
|
||||
}
|
||||
|
||||
fn make_route_v6(idx: u32, handle: &Handle, route: Ipv6Network) -> RouteAddRequest<Ipv6Addr> {
|
||||
make_route(idx, handle)
|
||||
.v6()
|
||||
.destination_prefix(route.network_address(), route.netmask())
|
||||
}
|
||||
|
||||
fn get_last_error() -> Error {
|
||||
Error::Io(io::Error::last_os_error())
|
||||
}
|
||||
|
||||
@@ -13,8 +13,8 @@ use tokio::sync::mpsc;
|
||||
use windows::Win32::{
|
||||
NetworkManagement::{
|
||||
IpHelper::{
|
||||
CreateIpForwardEntry2, GetIpInterfaceEntry, InitializeIpForwardEntry,
|
||||
SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW,
|
||||
CreateIpForwardEntry2, DeleteIpForwardEntry2, GetIpInterfaceEntry,
|
||||
InitializeIpForwardEntry, SetIpInterfaceEntry, MIB_IPFORWARD_ROW2, MIB_IPINTERFACE_ROW,
|
||||
},
|
||||
Ndis::NET_LUID_LH,
|
||||
},
|
||||
@@ -154,37 +154,27 @@ impl Tun {
|
||||
|
||||
// It's okay if this blocks until the route is added in the OS.
|
||||
pub fn add_route(&self, route: IpNetwork) -> Result<()> {
|
||||
tracing::debug!("add_route {route}");
|
||||
let mut row = MIB_IPFORWARD_ROW2::default();
|
||||
// SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults
|
||||
unsafe { InitializeIpForwardEntry(&mut row) };
|
||||
|
||||
let prefix = &mut row.DestinationPrefix;
|
||||
match route {
|
||||
IpNetwork::V4(x) => {
|
||||
prefix.PrefixLength = x.netmask();
|
||||
prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into();
|
||||
}
|
||||
IpNetwork::V6(x) => {
|
||||
prefix.PrefixLength = x.netmask();
|
||||
prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into();
|
||||
}
|
||||
}
|
||||
|
||||
row.InterfaceIndex = self.iface_idx;
|
||||
row.Metric = 0;
|
||||
const DUPLICATE_ERR: u32 = 0x80071392;
|
||||
let entry = self.forward_entry(route);
|
||||
|
||||
// SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable.
|
||||
match unsafe { CreateIpForwardEntry2(&row) } {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
if e.code().0 as u32 == 0x80071392 {
|
||||
// "Object already exists" error
|
||||
tracing::warn!("Failed to add duplicate route, ignoring");
|
||||
} else {
|
||||
Err(e)?;
|
||||
}
|
||||
match unsafe { CreateIpForwardEntry2(&entry) } {
|
||||
Ok(()) => Ok(()),
|
||||
Err(e) if e.code().0 as u32 == DUPLICATE_ERR => {
|
||||
tracing::debug!(%route, "Failed to add duplicate route, ignoring");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
// It's okay if this blocks until the route is added in the OS.
|
||||
pub fn remove_route(&self, route: IpNetwork) -> Result<()> {
|
||||
let entry = self.forward_entry(route);
|
||||
|
||||
// SAFETY: Windows shouldn't store the reference anywhere, it's just a way to pass lots of arguments at once. And no other thread sees this variable.
|
||||
unsafe {
|
||||
DeleteIpForwardEntry2(&entry)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -239,6 +229,29 @@ impl Tun {
|
||||
self.session.send_packet(pkt);
|
||||
Ok(bytes.len())
|
||||
}
|
||||
|
||||
fn forward_entry(&self, route: IpNetwork) -> MIB_IPFORWARD_ROW2 {
|
||||
let mut row = MIB_IPFORWARD_ROW2::default();
|
||||
// SAFETY: Windows shouldn't store the reference anywhere, it's just setting defaults
|
||||
unsafe { InitializeIpForwardEntry(&mut row) };
|
||||
|
||||
let prefix = &mut row.DestinationPrefix;
|
||||
match route {
|
||||
IpNetwork::V4(x) => {
|
||||
prefix.PrefixLength = x.netmask();
|
||||
prefix.Prefix.Ipv4 = SocketAddrV4::new(x.network_address(), 0).into();
|
||||
}
|
||||
IpNetwork::V6(x) => {
|
||||
prefix.PrefixLength = x.netmask();
|
||||
prefix.Prefix.Ipv6 = SocketAddrV6::new(x.network_address(), 0, 0, 0).into();
|
||||
}
|
||||
}
|
||||
|
||||
row.InterfaceIndex = self.iface_idx;
|
||||
row.Metric = 0;
|
||||
|
||||
row
|
||||
}
|
||||
}
|
||||
|
||||
fn start_recv_thread(
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::ip_packet::MutableIpPacket;
|
||||
use crate::peer::PacketTransformGateway;
|
||||
use crate::peer_store::PeerStore;
|
||||
use crate::Tunnel;
|
||||
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig};
|
||||
use connlib_shared::messages::{ClientId, Interface as InterfaceConfig, ResourceId};
|
||||
use connlib_shared::Callbacks;
|
||||
use snownet::Server;
|
||||
use tokio::time::{interval, Interval, MissedTickBehavior};
|
||||
@@ -40,11 +40,22 @@ where
|
||||
pub fn cleanup_connection(&mut self, id: &ClientId) {
|
||||
self.role_state.peers.remove(id);
|
||||
}
|
||||
|
||||
pub fn remove_access(&mut self, id: &ClientId, resource_id: &ResourceId) {
|
||||
let Some(peer) = self.role_state.peers.get_mut(id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
peer.transform.remove_resource(resource_id);
|
||||
if peer.transform.is_emptied() {
|
||||
self.role_state.peers.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [`Tunnel`] state specific to gateways.
|
||||
pub struct GatewayState {
|
||||
pub peers: PeerStore<ClientId, PacketTransformGateway>,
|
||||
pub peers: PeerStore<ClientId, PacketTransformGateway, ()>,
|
||||
expire_interval: Interval,
|
||||
}
|
||||
|
||||
|
||||
@@ -277,14 +277,15 @@ where
|
||||
}
|
||||
|
||||
// TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation
|
||||
fn poll_sockets<TTransform>(
|
||||
fn poll_sockets<TTransform, TResource>(
|
||||
&mut self,
|
||||
device: &mut Device,
|
||||
peer_store: &mut PeerStore<TId, TTransform>,
|
||||
peer_store: &mut PeerStore<TId, TTransform, TResource>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
TTransform: PacketTransform,
|
||||
TResource: Clone,
|
||||
{
|
||||
let received = match ready!(self.sockets.poll_recv_from(cx)) {
|
||||
Ok(received) => received,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::IpAddr;
|
||||
use std::time::Instant;
|
||||
|
||||
use bimap::BiMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use connlib_shared::messages::DnsServer;
|
||||
use connlib_shared::messages::{DnsServer, ResourceId};
|
||||
use connlib_shared::IpProvider;
|
||||
use connlib_shared::{Error, Result};
|
||||
use ip_network::IpNetwork;
|
||||
@@ -20,25 +20,44 @@ type ExpiryingResource = (ResourceDescription, Option<DateTime<Utc>>);
|
||||
// is 30 seconds. See resolvconf(5) timeout.
|
||||
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
pub struct Peer<TId, TTransform> {
|
||||
allowed_ips: IpNetworkTable<()>,
|
||||
pub struct Peer<TId, TTransform, TResource> {
|
||||
// TODO: we should refactor this
|
||||
// in the gateway-side this means that we are explicit about ()
|
||||
// maybe duping the Peer struct is the way to go
|
||||
pub allowed_ips: IpNetworkTable<TResource>,
|
||||
pub conn_id: TId,
|
||||
pub transform: TTransform,
|
||||
}
|
||||
|
||||
impl<TId, TTransform> Peer<TId, TTransform>
|
||||
impl<TId, TTransform> Peer<TId, TTransform, HashSet<ResourceId>>
|
||||
where
|
||||
TId: Copy,
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
pub(crate) fn insert_id(&mut self, ip: &IpNetwork, id: &ResourceId) {
|
||||
if let Some(resources) = self.allowed_ips.exact_match_mut(*ip) {
|
||||
resources.insert(*id);
|
||||
} else {
|
||||
self.allowed_ips.insert(*ip, HashSet::from([*id]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TId, TTransform, TResource> Peer<TId, TTransform, TResource>
|
||||
where
|
||||
TId: Copy,
|
||||
TTransform: PacketTransform,
|
||||
TResource: Clone,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
ips: Vec<IpNetwork>,
|
||||
conn_id: TId,
|
||||
transform: TTransform,
|
||||
) -> Peer<TId, TTransform> {
|
||||
ips: &[IpNetwork],
|
||||
resource: TResource,
|
||||
) -> Peer<TId, TTransform, TResource> {
|
||||
let mut allowed_ips = IpNetworkTable::new();
|
||||
for ip in ips {
|
||||
allowed_ips.insert(ip, ());
|
||||
allowed_ips.insert(*ip, resource.clone());
|
||||
}
|
||||
|
||||
Peer {
|
||||
@@ -48,10 +67,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_allowed_ip(&mut self, ip: IpNetwork) {
|
||||
self.allowed_ips.insert(ip, ());
|
||||
}
|
||||
|
||||
fn is_allowed(&self, addr: IpAddr) -> bool {
|
||||
self.allowed_ips.longest_match(addr).is_some()
|
||||
}
|
||||
@@ -92,7 +107,7 @@ impl Default for PacketTransformGateway {
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct PacketTransformClient {
|
||||
translations: BiMap<IpAddr, IpAddr>,
|
||||
pub translations: BiMap<IpAddr, IpAddr>,
|
||||
dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
mangled_dns_ids: HashMap<u16, std::time::Instant>,
|
||||
}
|
||||
@@ -133,6 +148,13 @@ impl PacketTransformGateway {
|
||||
.retain(|_, (_, e)| !e.is_some_and(|e| e <= Utc::now()));
|
||||
}
|
||||
|
||||
pub(crate) fn remove_resource(&mut self, resource: &ResourceId) {
|
||||
self.resources.retain(|_, (r, _)| match r {
|
||||
connlib_shared::messages::ResourceDescription::Dns(r) => r.id != *resource,
|
||||
connlib_shared::messages::ResourceDescription::Cidr(r) => r.id != *resource,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn add_resource(
|
||||
&mut self,
|
||||
ip: IpNetwork,
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::Hash;
|
||||
use std::net::IpAddr;
|
||||
|
||||
use crate::peer::{PacketTransform, Peer};
|
||||
use connlib_shared::messages::ResourceId;
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
|
||||
pub struct PeerStore<TId, TTransform> {
|
||||
pub struct PeerStore<TId, TTransform, TResource> {
|
||||
id_by_ip: IpNetworkTable<TId>,
|
||||
peer_by_id: HashMap<TId, Peer<TId, TTransform>>,
|
||||
peer_by_id: HashMap<TId, Peer<TId, TTransform, TResource>>,
|
||||
}
|
||||
|
||||
impl<T, U> Default for PeerStore<T, U> {
|
||||
impl<TId, TTransform, TResource> Default for PeerStore<TId, TTransform, TResource> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
id_by_ip: IpNetworkTable::new(),
|
||||
@@ -20,67 +21,92 @@ impl<T, U> Default for PeerStore<T, U> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<TId, TTransform> PeerStore<TId, TTransform>
|
||||
impl<TId, TTransform> PeerStore<TId, TTransform, HashSet<ResourceId>>
|
||||
where
|
||||
TId: Hash + Eq + Clone + Copy,
|
||||
TId: Hash + Eq + Copy,
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer<TId, TTransform>) -> bool) {
|
||||
pub fn add_ips_with_resource(&mut self, id: &TId, ips: &[IpNetwork], resource: &ResourceId) {
|
||||
for ip in ips {
|
||||
let Some(peer) = self.add_ip(id, ip) else {
|
||||
continue;
|
||||
};
|
||||
peer.insert_id(ip, resource);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TId, TTransform, TResource> PeerStore<TId, TTransform, TResource>
|
||||
where
|
||||
TId: Hash + Eq + Copy,
|
||||
TTransform: PacketTransform,
|
||||
{
|
||||
pub fn retain(&mut self, f: impl Fn(&TId, &mut Peer<TId, TTransform, TResource>) -> bool) {
|
||||
self.peer_by_id.retain(f);
|
||||
self.id_by_ip
|
||||
.retain(|_, id| self.peer_by_id.contains_key(id));
|
||||
}
|
||||
|
||||
pub fn add_ips(&mut self, id: &TId, ips: &[IpNetwork]) -> Option<&Peer<TId, TTransform>> {
|
||||
pub fn add_ip(
|
||||
&mut self,
|
||||
id: &TId,
|
||||
ip: &IpNetwork,
|
||||
) -> Option<&mut Peer<TId, TTransform, TResource>> {
|
||||
let peer = self.peer_by_id.get_mut(id)?;
|
||||
|
||||
for ip in ips {
|
||||
self.id_by_ip.insert(*ip, peer.conn_id);
|
||||
peer.add_allowed_ip(*ip);
|
||||
}
|
||||
|
||||
self.id_by_ip.insert(*ip, *id);
|
||||
Some(peer)
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, peer: Peer<TId, TTransform>) -> Option<Peer<TId, TTransform>> {
|
||||
pub fn insert(
|
||||
&mut self,
|
||||
peer: Peer<TId, TTransform, TResource>,
|
||||
ips: &[IpNetwork],
|
||||
) -> Option<Peer<TId, TTransform, TResource>> {
|
||||
self.id_by_ip.retain(|_, &mut r_id| r_id != peer.conn_id);
|
||||
|
||||
self.peer_by_id.insert(peer.conn_id, peer)
|
||||
let id = peer.conn_id;
|
||||
let old_peer = self.peer_by_id.insert(id, peer);
|
||||
|
||||
for ip in ips {
|
||||
self.add_ip(&id, ip);
|
||||
}
|
||||
|
||||
old_peer
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: &TId) -> Option<Peer<TId, TTransform>> {
|
||||
pub fn remove(&mut self, id: &TId) -> Option<Peer<TId, TTransform, TResource>> {
|
||||
self.id_by_ip.retain(|_, r_id| r_id != id);
|
||||
self.peer_by_id.remove(id)
|
||||
}
|
||||
|
||||
pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer<TId, TTransform>> {
|
||||
pub fn exact_match(&self, ip: IpNetwork) -> Option<&Peer<TId, TTransform, TResource>> {
|
||||
let ip = self.id_by_ip.exact_match(ip)?;
|
||||
self.peer_by_id.get(ip)
|
||||
}
|
||||
|
||||
pub fn get(&self, id: &TId) -> Option<&Peer<TId, TTransform>> {
|
||||
pub fn get(&self, id: &TId) -> Option<&Peer<TId, TTransform, TResource>> {
|
||||
self.peer_by_id.get(id)
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer<TId, TTransform>> {
|
||||
pub fn get_mut(&mut self, id: &TId) -> Option<&mut Peer<TId, TTransform, TResource>> {
|
||||
self.peer_by_id.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer<TId, TTransform>> {
|
||||
pub fn peer_by_ip(&self, ip: IpAddr) -> Option<&Peer<TId, TTransform, TResource>> {
|
||||
let (_, id) = self.id_by_ip.longest_match(ip)?;
|
||||
self.peer_by_id.get(id)
|
||||
}
|
||||
|
||||
pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer<TId, TTransform>> {
|
||||
pub fn peer_by_ip_mut(&mut self, ip: IpAddr) -> Option<&mut Peer<TId, TTransform, TResource>> {
|
||||
let (_, id) = self.id_by_ip.longest_match(ip)?;
|
||||
self.peer_by_id.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<TId, TTransform>> {
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Peer<TId, TTransform, TResource>> {
|
||||
self.peer_by_id.values_mut()
|
||||
}
|
||||
|
||||
pub fn iter(&mut self) -> impl Iterator<Item = &Peer<TId, TTransform>> {
|
||||
pub fn iter(&mut self) -> impl Iterator<Item = &Peer<TId, TTransform, TResource>> {
|
||||
self.peer_by_id.values()
|
||||
}
|
||||
}
|
||||
@@ -93,16 +119,21 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn can_insert_and_retrieve_peer() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
let mut peer_storage = PeerStore::<_, _, ()>::default();
|
||||
peer_storage.insert(
|
||||
Peer::new(0, PacketTransformGateway::default(), &[], ()),
|
||||
&[],
|
||||
);
|
||||
assert!(peer_storage.get(&0).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_insert_and_retrieve_peer_by_ip() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
|
||||
let mut peer_storage = PeerStore::<_, _, ()>::default();
|
||||
peer_storage.insert(
|
||||
Peer::new(0, PacketTransformGateway::default(), &[], ()),
|
||||
&["100.0.0.0/24".parse().unwrap()],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
peer_storage
|
||||
@@ -115,9 +146,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn can_remove_peer() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
|
||||
let mut peer_storage = PeerStore::<_, _, ()>::default();
|
||||
peer_storage.insert(
|
||||
Peer::new(0, PacketTransformGateway::default(), &[], ()),
|
||||
&["100.0.0.0/24".parse().unwrap()],
|
||||
);
|
||||
peer_storage.remove(&0);
|
||||
|
||||
assert!(peer_storage.get(&0).is_none());
|
||||
@@ -128,10 +161,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn inserting_peer_removes_previous_instances_of_same_id() {
|
||||
let mut peer_storage = PeerStore::default();
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
peer_storage.add_ips(&0, &["100.0.0.0/24".parse().unwrap()]);
|
||||
peer_storage.insert(Peer::new(vec![], 0, PacketTransformGateway::default()));
|
||||
let mut peer_storage = PeerStore::<_, _, ()>::default();
|
||||
peer_storage.insert(
|
||||
Peer::new(0, PacketTransformGateway::default(), &[], ()),
|
||||
&["100.0.0.0/24".parse().unwrap()],
|
||||
);
|
||||
peer_storage.insert(
|
||||
Peer::new(0, PacketTransformGateway::default(), &[], ()),
|
||||
&[],
|
||||
);
|
||||
|
||||
assert!(peer_storage.get(&0).is_some());
|
||||
assert!(peer_storage
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::messages::{
|
||||
AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady,
|
||||
EgressMessages, IngressMessages, RequestConnection,
|
||||
EgressMessages, IngressMessages, RejectAccess, RequestConnection,
|
||||
};
|
||||
use crate::CallbackHandler;
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
@@ -201,6 +201,20 @@ impl Eventloop {
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
Poll::Ready(phoenix_channel::Event::InboundMessage {
|
||||
msg:
|
||||
IngressMessages::RejectAccess(RejectAccess {
|
||||
client_id,
|
||||
resource_id,
|
||||
}),
|
||||
..
|
||||
}) => {
|
||||
tracing::debug!(client = %client_id, resource = %resource_id, "Access removed");
|
||||
|
||||
self.tunnel.remove_access(&client_id, &resource_id);
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(phoenix_channel::Event::InboundMessage {
|
||||
msg: IngressMessages::Init(_),
|
||||
..
|
||||
|
||||
@@ -86,6 +86,12 @@ pub struct AllowAccess {
|
||||
pub reference: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
pub struct RejectAccess {
|
||||
pub client_id: ClientId,
|
||||
pub resource_id: ResourceId,
|
||||
}
|
||||
|
||||
// These messages are the messages that can be received
|
||||
// either by a client or a gateway by the client.
|
||||
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
|
||||
@@ -93,6 +99,7 @@ pub struct AllowAccess {
|
||||
pub enum IngressMessages {
|
||||
RequestConnection(RequestConnection),
|
||||
AllowAccess(AllowAccess),
|
||||
RejectAccess(RejectAccess),
|
||||
IceCandidates(ClientIceCandidates),
|
||||
Init(InitGateway),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user