diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 2d38c1480..d73bbea33 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,13 +1,16 @@ mod dns_cache; mod dns_resource_nat; +mod gateway_on_client; mod resource; +pub(crate) use crate::client::gateway_on_client::GatewayOnClient; +#[cfg(all(feature = "proptest", test))] +pub(crate) use resource::DnsResource; +pub(crate) use resource::{CidrResource, InternetResource, Resource}; + use dns_resource_nat::DnsResourceNat; use dns_types::ResponseCode; use firezone_telemetry::{analytics, feature_flags}; -#[cfg(all(feature = "proptest", test))] -pub(crate) use resource::DnsResource; -pub(crate) use resource::{CidrResource, InternetResource, Resource}; use ringbuffer::{AllocRingBuffer, RingBuffer}; use crate::client::dns_cache::DnsCache; @@ -31,7 +34,6 @@ use ip_packet::{IpPacket, MAX_UDP_PAYLOAD}; use itertools::Itertools; use crate::ClientEvent; -use crate::peer::GatewayOnClient; use lru::LruCache; use secrecy::{ExposeSecret as _, Secret}; use snownet::{ClientNode, NoTurnServers, RelaySocket, Transmit}; diff --git a/rust/connlib/tunnel/src/client/gateway_on_client.rs b/rust/connlib/tunnel/src/client/gateway_on_client.rs new file mode 100644 index 000000000..b2f22bdb9 --- /dev/null +++ b/rust/connlib/tunnel/src/client/gateway_on_client.rs @@ -0,0 +1,69 @@ +use std::{ + collections::HashSet, + net::{IpAddr, SocketAddr}, +}; + +use connlib_model::{GatewayId, ResourceId}; +use ip_network::IpNetwork; +use ip_network_table::IpNetworkTable; +use ip_packet::IpPacket; + +use crate::{IpConfig, NotAllowedResource}; + +/// The state of one gateway on a client. +pub(crate) struct GatewayOnClient { + id: GatewayId, + gateway_tun: IpConfig, + pub allowed_ips: IpNetworkTable>, +} + +impl GatewayOnClient { + pub(crate) fn insert_id(&mut self, ip: &IpNetwork, id: &ResourceId) { + if let Some(resources) = self.allowed_ips.exact_match_mut(*ip) { + resources.insert(*id); + } else { + self.allowed_ips.insert(*ip, HashSet::from([*id])); + } + } + + /// For a given destination IP, return the endpoint to which the DNS query should be sent. + pub(crate) fn tun_dns_server_endpoint(&self, dst: IpAddr) -> SocketAddr { + let new_dst_ip = match dst { + IpAddr::V4(_) => self.gateway_tun.v4.into(), + IpAddr::V6(_) => self.gateway_tun.v6.into(), + }; + let new_dst_port = crate::gateway::TUN_DNS_PORT; + + SocketAddr::new(new_dst_ip, new_dst_port) + } +} + +impl GatewayOnClient { + pub(crate) fn new(id: GatewayId, gateway_tun: IpConfig) -> GatewayOnClient { + GatewayOnClient { + id, + allowed_ips: IpNetworkTable::new(), + gateway_tun, + } + } +} + +impl GatewayOnClient { + pub(crate) fn ensure_allowed_src(&self, packet: &IpPacket) -> anyhow::Result<()> { + let src = packet.source(); + + if self.gateway_tun.is_ip(src) { + return Ok(()); + } + + if self.allowed_ips.longest_match(src).is_none() { + return Err(anyhow::Error::new(NotAllowedResource(src))); + } + + Ok(()) + } + + pub fn id(&self) -> GatewayId { + self.id + } +} diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index cbfe5bfbd..60f3f2e2f 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,8 +1,14 @@ +mod client_on_gateway; +mod filter_engine; +mod nat_table; + +pub(crate) use crate::gateway::client_on_gateway::ClientOnGateway; + +use crate::gateway::client_on_gateway::TranslateOutboundResult; use crate::messages::gateway::ResourceDescription; use crate::messages::{Answer, IceCredentials, ResolveRequest, SecretKey}; -use crate::peer::TranslateOutboundResult; +use crate::peer_store::PeerStore; use crate::{GatewayEvent, IpConfig, p2p_control}; -use crate::{peer::ClientOnGateway, peer_store::PeerStore}; use anyhow::{Context, Result}; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/gateway/client_on_gateway.rs similarity index 91% rename from rust/connlib/tunnel/src/peer.rs rename to rust/connlib/tunnel/src/gateway/client_on_gateway.rs index 6b0aa1fba..c5943b5d9 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/gateway/client_on_gateway.rs @@ -1,65 +1,23 @@ -use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque, btree_map}; +use std::collections::{BTreeMap, BTreeSet, VecDeque, btree_map}; use std::iter; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::time::Instant; -use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; -use crate::messages::gateway::Filters; -use crate::messages::gateway::ResourceDescription; +use anyhow::{Context, Result, bail}; use chrono::{DateTime, Utc}; -use connlib_model::{ClientId, GatewayId, ResourceId}; +use connlib_model::{ClientId, ResourceId}; use dns_types::DomainName; -use filter_engine::FilterEngine; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; use ip_packet::{IpPacket, Protocol, UnsupportedProtocol}; +use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; +use crate::gateway::filter_engine::FilterEngine; +use crate::gateway::nat_table::{NatTable, TranslateIncomingResult}; +use crate::messages::gateway::Filters; +use crate::messages::gateway::ResourceDescription; use crate::utils::network_contains_network; -use crate::{GatewayEvent, IpConfig, otel}; - -use anyhow::{Context, Result, bail}; -use nat_table::{NatTable, TranslateIncomingResult}; - -mod filter_engine; -mod nat_table; - -/// The state of one gateway on a client. -pub(crate) struct GatewayOnClient { - id: GatewayId, - gateway_tun: IpConfig, - pub allowed_ips: IpNetworkTable>, -} - -impl GatewayOnClient { - pub(crate) fn insert_id(&mut self, ip: &IpNetwork, id: &ResourceId) { - if let Some(resources) = self.allowed_ips.exact_match_mut(*ip) { - resources.insert(*id); - } else { - self.allowed_ips.insert(*ip, HashSet::from([*id])); - } - } - - /// For a given destination IP, return the endpoint to which the DNS query should be sent. - pub(crate) fn tun_dns_server_endpoint(&self, dst: IpAddr) -> SocketAddr { - let new_dst_ip = match dst { - IpAddr::V4(_) => self.gateway_tun.v4.into(), - IpAddr::V6(_) => self.gateway_tun.v6.into(), - }; - let new_dst_port = crate::gateway::TUN_DNS_PORT; - - SocketAddr::new(new_dst_ip, new_dst_port) - } -} - -impl GatewayOnClient { - pub(crate) fn new(id: GatewayId, gateway_tun: IpConfig) -> GatewayOnClient { - GatewayOnClient { - id, - allowed_ips: IpNetworkTable::new(), - gateway_tun, - } - } -} +use crate::{GatewayEvent, IpConfig, NotAllowedResource, NotClientIp, otel}; /// The state of one client on a gateway. pub struct ClientOnGateway { @@ -70,8 +28,8 @@ pub struct ClientOnGateway { resources: BTreeMap, /// Caches the existence of internet resource - internet_resource_enabled: bool, - filters: IpNetworkTable, + internet_resource_enabled: Option, + filters: IpNetworkTable<(FilterEngine, ResourceId)>, permanent_translations: BTreeMap, nat_table: NatTable, buffered_events: VecDeque, @@ -79,6 +37,13 @@ pub struct ClientOnGateway { num_dropped_packets: opentelemetry::metrics::Counter, } +#[derive(Debug, PartialEq)] +pub enum TranslateOutboundResult { + Send(IpPacket), + DestinationUnreachable(IpPacket), + Filtered(IpPacket), +} + impl ClientOnGateway { pub(crate) fn new( id: ClientId, @@ -94,7 +59,7 @@ impl ClientOnGateway { permanent_translations: Default::default(), nat_table: Default::default(), buffered_events: Default::default(), - internet_resource_enabled: false, + internet_resource_enabled: None, num_dropped_packets: otel::metrics::network_packet_dropped(), } } @@ -310,11 +275,14 @@ impl ClientOnGateway { self.recalculate_cidr_filters(); self.recalculate_dns_filters(); - self.internet_resource_enabled = self.resources.values().any(|r| r.is_internet_resource()); + self.internet_resource_enabled = self + .resources + .iter() + .find_map(|(id, r)| r.is_internet_resource().then_some(*id)); } fn recalculate_cidr_filters(&mut self) { - for resource in self.resources.values().filter(|r| r.is_cidr()) { + for (id, resource) in self.resources.iter().filter(|(_, r)| r.is_cidr()) { for ip in &resource.ips() { let filters = self.resources.values().filter_map(|r| { r.ips() @@ -323,7 +291,7 @@ impl ClientOnGateway { .then_some(r.filters()) }); - insert_filters(&mut self.filters, *ip, filters); + insert_filters(&mut self.filters, *ip, *id, filters); } } } @@ -339,6 +307,7 @@ impl ClientOnGateway { insert_filters( &mut self.filters, IpNetwork::from(*addr), + *resource_id, iter::once(resource.filters()), ); } @@ -386,7 +355,7 @@ impl ClientOnGateway { return Ok(Some(packet)); } - if let Err(e) = self.ensure_allowed_resource(packet.source(), packet.source_protocol()) { + if let Err(e) = self.classify_resource(packet.source(), packet.source_protocol()) { tracing::debug!( "Inbound packet is not allowed, perhaps from an old client session? error = {e:#}" ); @@ -515,7 +484,7 @@ impl ClientOnGateway { return Ok(()); } - self.ensure_allowed_resource(packet.destination(), packet.destination_protocol())?; + self.classify_resource(packet.destination(), packet.destination_protocol())?; Ok(()) } @@ -528,25 +497,32 @@ impl ClientOnGateway { Ok(()) } - fn ensure_allowed_resource( + /// Classifies traffic to/from a resource IP. + /// + /// If traffic with this resource is allowed, the resource ID is returned. + fn classify_resource( &self, - ip: IpAddr, + resource_ip: IpAddr, protocol: Result, - ) -> anyhow::Result<()> { + ) -> anyhow::Result { // Note a Gateway with Internet resource should never get packets for other resources - if self.internet_resource_enabled && !is_dns_addr(ip) { - return Ok(()); + if let Some(rid) = self.internet_resource_enabled + && !is_dns_addr(resource_ip) + { + return Ok(rid); } - let (_, filter) = self + let (_, (filter, rid)) = self .filters - .longest_match(ip) + .longest_match(resource_ip) .context("No filter") - .context(NotAllowedResource(ip))?; + .context(NotAllowedResource(resource_ip))?; - filter.apply(protocol).context(NotAllowedResource(ip))?; + filter + .apply(protocol) + .context(NotAllowedResource(resource_ip))?; - Ok(()) + Ok(*rid) } pub fn id(&self) -> ClientId { @@ -554,41 +530,6 @@ impl ClientOnGateway { } } -#[derive(Debug, PartialEq)] -pub enum TranslateOutboundResult { - Send(IpPacket), - DestinationUnreachable(IpPacket), - Filtered(IpPacket), -} - -impl GatewayOnClient { - pub(crate) fn ensure_allowed_src(&self, packet: &IpPacket) -> anyhow::Result<()> { - let src = packet.source(); - - if self.gateway_tun.is_ip(src) { - return Ok(()); - } - - if self.allowed_ips.longest_match(src).is_none() { - return Err(anyhow::Error::new(NotAllowedResource(src))); - } - - Ok(()) - } - - pub fn id(&self) -> GatewayId { - self.id - } -} - -#[derive(Debug, thiserror::Error)] -#[error("Not a client IP: {0}")] -pub(crate) struct NotClientIp(IpAddr); - -#[derive(Debug, thiserror::Error)] -#[error("Traffic to/from this resource IP is not allowed: {0}")] -pub(crate) struct NotAllowedResource(IpAddr); - #[derive(Debug)] enum ResourceOnGateway { Cidr { @@ -748,35 +689,32 @@ fn is_dns_addr(addr: IpAddr) -> bool { } fn insert_filters<'a>( - filter_store: &mut IpNetworkTable, + filter_store: &mut IpNetworkTable<(FilterEngine, ResourceId)>, ip: IpNetwork, + id: ResourceId, filters: impl Iterator + Clone, ) { let filter_engine = FilterEngine::with_filters(filters); tracing::trace!(%ip, filters = ?filter_engine, "Installing new filters"); - filter_store.insert(ip, filter_engine); + filter_store.insert(ip, (filter_engine, id)); } #[cfg(test)] mod tests { + use super::*; + use std::{ - collections::BTreeSet, net::{Ipv4Addr, Ipv6Addr}, - time::{Duration, Instant}, + time::Duration, }; - use crate::{ - IpConfig, - messages::gateway::{Filter, PortRange, ResourceDescription, ResourceDescriptionCidr}, - peer::{TranslateOutboundResult, nat_table}, - }; - use chrono::Utc; - use connlib_model::{ClientId, ResourceId}; - use ip_network::{IpNetwork, Ipv4Network}; use ip_packet::make::TcpFlags; - use super::ClientOnGateway; + use crate::{ + gateway::nat_table, + messages::gateway::{Filter, PortRange, ResourceDescriptionCidr}, + }; #[test] fn gateway_filters_expire_individually() { @@ -831,52 +769,34 @@ mod tests { peer.expire_resources(now); assert!( - peer.ensure_allowed_resource( - tcp_packet.destination(), - tcp_packet.destination_protocol() - ) - .is_ok() + peer.classify_resource(tcp_packet.destination(), tcp_packet.destination_protocol()) + .is_ok() ); assert!( - peer.ensure_allowed_resource( - udp_packet.destination(), - udp_packet.destination_protocol() - ) - .is_ok() + peer.classify_resource(udp_packet.destination(), udp_packet.destination_protocol()) + .is_ok() ); peer.expire_resources(then); assert!( - peer.ensure_allowed_resource( - tcp_packet.destination(), - tcp_packet.destination_protocol() - ) - .is_err() + peer.classify_resource(tcp_packet.destination(), tcp_packet.destination_protocol()) + .is_err() ); assert!( - peer.ensure_allowed_resource( - udp_packet.destination(), - udp_packet.destination_protocol() - ) - .is_ok() + peer.classify_resource(udp_packet.destination(), udp_packet.destination_protocol()) + .is_ok() ); peer.expire_resources(after_then); assert!( - peer.ensure_allowed_resource( - tcp_packet.destination(), - tcp_packet.destination_protocol() - ) - .is_err() + peer.classify_resource(tcp_packet.destination(), tcp_packet.destination_protocol()) + .is_err() ); assert!( - peer.ensure_allowed_resource( - udp_packet.destination(), - udp_packet.destination_protocol() - ) - .is_err() + peer.classify_resource(udp_packet.destination(), udp_packet.destination_protocol()) + .is_err() ); } @@ -906,7 +826,7 @@ mod tests { assert!(matches!( peer.translate_outbound(request, Instant::now()).unwrap(), - crate::peer::TranslateOutboundResult::Send(_) + TranslateOutboundResult::Send(_) )); assert!( peer.translate_inbound(response, Instant::now()) @@ -952,7 +872,7 @@ mod tests { assert!(matches!( peer.translate_outbound(pkt, Instant::now()).unwrap(), - crate::peer::TranslateOutboundResult::Filtered(_) + TranslateOutboundResult::Filtered(_) )); let pkt = ip_packet::make::udp_packet( @@ -966,7 +886,7 @@ mod tests { assert!(matches!( peer.translate_outbound(pkt, Instant::now()).unwrap(), - crate::peer::TranslateOutboundResult::Filtered(_) + TranslateOutboundResult::Filtered(_) )); let pkt = ip_packet::make::udp_packet( @@ -1016,7 +936,7 @@ mod tests { assert!(matches!( peer.translate_outbound(pkt, Instant::now()).unwrap(), - crate::peer::TranslateOutboundResult::Filtered(_) + TranslateOutboundResult::Filtered(_) )); let pkt = ip_packet::make::udp_packet( @@ -1413,7 +1333,7 @@ mod proptests { } .unwrap(); assert!( - peer.ensure_allowed_resource(packet.destination(), packet.destination_protocol()) + peer.classify_resource(packet.destination(), packet.destination_protocol()) .is_ok() ); } @@ -1476,7 +1396,7 @@ mod proptests { .unwrap(); assert!( - peer.ensure_allowed_resource(packet.destination(), packet.destination_protocol()) + peer.classify_resource(packet.destination(), packet.destination_protocol()) .is_ok() ); } @@ -1529,7 +1449,7 @@ mod proptests { ); assert!( - peer.ensure_allowed_resource(packet.destination(), packet.destination_protocol()) + peer.classify_resource(packet.destination(), packet.destination_protocol()) .is_err() ); } @@ -1612,14 +1532,14 @@ mod proptests { peer.remove_resource(&resource_id_removed); assert!( - peer.ensure_allowed_resource( + peer.classify_resource( packet_allowed.destination(), packet_allowed.destination_protocol() ) .is_ok() ); assert!( - peer.ensure_allowed_resource( + peer.classify_resource( packet_rejected.destination(), packet_rejected.destination_protocol() ) diff --git a/rust/connlib/tunnel/src/peer/filter_engine.rs b/rust/connlib/tunnel/src/gateway/filter_engine.rs similarity index 100% rename from rust/connlib/tunnel/src/peer/filter_engine.rs rename to rust/connlib/tunnel/src/gateway/filter_engine.rs diff --git a/rust/connlib/tunnel/src/peer/nat_table.rs b/rust/connlib/tunnel/src/gateway/nat_table.rs similarity index 100% rename from rust/connlib/tunnel/src/peer/nat_table.rs rename to rust/connlib/tunnel/src/gateway/nat_table.rs diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 6df110e52..f778663eb 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -34,7 +34,6 @@ mod io; pub mod messages; mod otel; mod p2p_control; -mod peer; mod peer_store; #[cfg(all(test, feature = "proptest"))] mod proptest; @@ -636,6 +635,14 @@ impl Drop for TunnelError { } } +#[derive(Debug, thiserror::Error)] +#[error("Not a client IP: {0}")] +pub(crate) struct NotClientIp(IpAddr); + +#[derive(Debug, thiserror::Error)] +#[error("Traffic to/from this resource IP is not allowed: {0}")] +pub(crate) struct NotAllowedResource(IpAddr); + /// Adapter-struct to [`fmt::Display`] a [`BTreeSet`]. #[expect(dead_code, reason = "It is used in the `Debug` impl of `TunConfig`")] struct DisplaySet<'a, T>(&'a BTreeSet); diff --git a/rust/connlib/tunnel/src/peer_store.rs b/rust/connlib/tunnel/src/peer_store.rs index d28cbc09c..6896e6c72 100644 --- a/rust/connlib/tunnel/src/peer_store.rs +++ b/rust/connlib/tunnel/src/peer_store.rs @@ -3,11 +3,13 @@ use std::collections::{HashMap, hash_map::Entry}; use std::hash::Hash; use std::net::IpAddr; -use crate::peer::{ClientOnGateway, GatewayOnClient}; use connlib_model::{ClientId, GatewayId, ResourceId}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; +use crate::client::GatewayOnClient; +use crate::gateway::ClientOnGateway; + pub(crate) struct PeerStore { id_by_ip: IpNetworkTable, peer_by_id: HashMap,