diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 6f0939095..d4928b9ce 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -11,7 +11,7 @@ use connlib_model::{ClientId, DomainName, GatewayId, ResourceId}; use filter_engine::FilterEngine; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; -use ip_packet::IpPacket; +use ip_packet::{IpPacket, Protocol, UnsupportedProtocol}; use crate::utils::network_contains_network; use crate::GatewayEvent; @@ -272,7 +272,7 @@ impl ClientOnGateway { now: Instant, ) -> anyhow::Result> { // Filtering a packet is not an error. - if let Err(e) = self.ensure_allowed(&packet) { + if let Err(e) = self.ensure_allowed_dst(&packet) { tracing::debug!(filtered_packet = ?packet, "{e:#}"); return Ok(None); } @@ -287,6 +287,28 @@ impl ClientOnGateway { &mut self, packet: IpPacket, now: Instant, + ) -> anyhow::Result> { + let Some(packet) = self.transform_tun_to_network(packet, now)? else { + return Ok(None); + }; + + self.ensure_client_ip(packet.destination())?; + + if let Err(e) = self.ensure_allowed_resource(packet.source(), packet.source_protocol()) { + tracing::debug!( + "Inbound packet is not allowed, perhaps from an old client session? error = {e:#}" + ); + + return Ok(None); + } + + Ok(Some(packet)) + } + + fn transform_tun_to_network( + &mut self, + packet: IpPacket, + now: Instant, ) -> anyhow::Result> { let (proto, ip) = match self.nat_table.translate_incoming(&packet, now)? { TranslateIncomingResult::Ok { proto, src } => (proto, src), @@ -326,39 +348,38 @@ impl ClientOnGateway { self.resources.contains_key(&resource) } - fn ensure_allowed(&self, packet: &IpPacket) -> anyhow::Result<()> { - self.ensure_allowed_src(packet)?; - self.ensure_allowed_dst(packet)?; + fn ensure_allowed_dst(&self, packet: &IpPacket) -> anyhow::Result<()> { + self.ensure_client_ip(packet.source())?; + self.ensure_allowed_resource(packet.destination(), packet.destination_protocol())?; Ok(()) } - fn ensure_allowed_src(&self, packet: &IpPacket) -> anyhow::Result<()> { - let src = packet.source(); - - if !self.allowed_ips().contains(&src) { - return Err(anyhow::Error::new(SrcNotAllowed(src))); + fn ensure_client_ip(&self, ip: IpAddr) -> anyhow::Result<()> { + if !self.allowed_ips().contains(&ip) { + return Err(anyhow::Error::new(NotClientIp(ip))); } Ok(()) } - /// Check if an incoming packet arriving over the network is ok to be forwarded to the TUN device. - fn ensure_allowed_dst(&self, packet: &IpPacket) -> anyhow::Result<()> { - let dst = packet.destination(); - + fn ensure_allowed_resource( + &self, + ip: IpAddr, + protocol: Result, + ) -> anyhow::Result<()> { // Note a Gateway with Internet resource should never get packets for other resources - if self.internet_resource_enabled && !is_dns_addr(packet.destination()) { + if self.internet_resource_enabled && !is_dns_addr(ip) { return Ok(()); } let (_, filter) = self .filters - .longest_match(dst) + .longest_match(ip) .context("No filter") - .context(DstNotAllowed(dst))?; + .context(NotAllowedResource(ip))?; - filter.apply(packet).context(DstNotAllowed(dst))?; + filter.apply(protocol).context(NotAllowedResource(ip))?; Ok(()) } @@ -373,7 +394,7 @@ impl GatewayOnClient { let src = packet.source(); if self.allowed_ips.longest_match(src).is_none() { - return Err(anyhow::Error::new(SrcNotAllowed(src))); + return Err(anyhow::Error::new(NotClientIp(src))); } Ok(()) @@ -385,12 +406,12 @@ impl GatewayOnClient { } #[derive(Debug, thiserror::Error)] -#[error("Source not allowed: {0}")] -pub(crate) struct SrcNotAllowed(IpAddr); +#[error("Not a client IP: {0}")] +pub(crate) struct NotClientIp(IpAddr); #[derive(Debug, thiserror::Error)] -#[error("Destination not allowed: {0}")] -pub(crate) struct DstNotAllowed(IpAddr); +#[error("Accessing this resource IP is not allowed: {0}")] +pub(crate) struct NotAllowedResource(IpAddr); #[derive(Debug)] enum ResourceOnGateway { @@ -627,18 +648,30 @@ mod tests { peer.expire_resources(now); - assert!(peer.ensure_allowed_dst(&tcp_packet).is_ok()); - assert!(peer.ensure_allowed_dst(&udp_packet).is_ok()); + assert!(peer + .ensure_allowed_resource(tcp_packet.destination(), tcp_packet.destination_protocol()) + .is_ok()); + assert!(peer + .ensure_allowed_resource(udp_packet.destination(), udp_packet.destination_protocol()) + .is_ok()); peer.expire_resources(then); - assert!(peer.ensure_allowed_dst(&tcp_packet).is_err()); - assert!(peer.ensure_allowed_dst(&udp_packet).is_ok()); + assert!(peer + .ensure_allowed_resource(tcp_packet.destination(), tcp_packet.destination_protocol()) + .is_err()); + assert!(peer + .ensure_allowed_resource(udp_packet.destination(), udp_packet.destination_protocol()) + .is_ok()); peer.expire_resources(after_then); - assert!(peer.ensure_allowed_dst(&tcp_packet).is_err()); - assert!(peer.ensure_allowed_dst(&udp_packet).is_err()); + assert!(peer + .ensure_allowed_resource(tcp_packet.destination(), tcp_packet.destination_protocol()) + .is_err()); + assert!(peer + .ensure_allowed_resource(udp_packet.destination(), udp_packet.destination_protocol()) + .is_err()); } #[test] @@ -959,7 +992,9 @@ mod proptests { Protocol::Icmp => icmp_request_packet(src, *dest, 1, 0, &[]), } .unwrap(); - assert!(peer.ensure_allowed_dst(&packet).is_ok()); + assert!(peer + .ensure_allowed_resource(packet.destination(), packet.destination_protocol()) + .is_ok()); } } @@ -1005,7 +1040,9 @@ mod proptests { } .unwrap(); - assert!(peer.ensure_allowed_dst(&packet).is_ok()); + assert!(peer + .ensure_allowed_resource(packet.destination(), packet.destination_protocol()) + .is_ok()); } } @@ -1046,7 +1083,9 @@ mod proptests { None, ); - assert!(peer.ensure_allowed_dst(&packet).is_err()); + assert!(peer + .ensure_allowed_resource(packet.destination(), packet.destination_protocol()) + .is_err()); } #[test_strategy::proptest()] @@ -1110,8 +1149,18 @@ mod proptests { ); peer.remove_resource(&resource_id_removed); - assert!(peer.ensure_allowed_dst(&packet_allowed).is_ok()); - assert!(peer.ensure_allowed_dst(&packet_rejected).is_err()); + assert!(peer + .ensure_allowed_resource( + packet_allowed.destination(), + packet_allowed.destination_protocol() + ) + .is_ok()); + assert!(peer + .ensure_allowed_resource( + packet_rejected.destination(), + packet_rejected.destination_protocol() + ) + .is_err()); } fn cidr_resources( diff --git a/rust/connlib/tunnel/src/peer/filter_engine.rs b/rust/connlib/tunnel/src/peer/filter_engine.rs index ebf63d84c..ead3110ab 100644 --- a/rust/connlib/tunnel/src/peer/filter_engine.rs +++ b/rust/connlib/tunnel/src/peer/filter_engine.rs @@ -1,4 +1,4 @@ -use ip_packet::IpPacket; +use ip_packet::{Protocol, UnsupportedProtocol}; use rangemap::RangeInclusiveSet; use crate::messages::gateway::{Filter, Filters}; @@ -24,13 +24,18 @@ pub(crate) enum Filtered { Udp, #[error("ICMP not allowed")] Icmp, + #[error(transparent)] + UnsupportedProtocol(#[from] UnsupportedProtocol), } impl FilterEngine { - pub(crate) fn apply(&self, packet: &IpPacket) -> Result<(), Filtered> { + pub(crate) fn apply( + &self, + protocol: Result, + ) -> Result<(), Filtered> { match self { FilterEngine::PermitAll => Ok(()), - FilterEngine::PermitSome(filter_engine) => filter_engine.apply(packet), + FilterEngine::PermitSome(filter_engine) => filter_engine.apply(protocol), } } @@ -58,32 +63,15 @@ impl AllowRules { } } - fn apply(&self, packet: &IpPacket) -> Result<(), Filtered> { - if let Some(dest_port) = packet.as_tcp().map(|tcp| tcp.destination_port()) { - if self.tcp.contains(&dest_port) { - return Ok(()); - } - - return Err(Filtered::Tcp); + fn apply(&self, protocol: Result) -> Result<(), Filtered> { + match protocol? { + Protocol::Tcp(port) if self.tcp.contains(&port) => Ok(()), + Protocol::Udp(port) if self.udp.contains(&port) => Ok(()), + Protocol::Icmp(_) if self.icmp => Ok(()), + Protocol::Tcp(_) => Err(Filtered::Tcp), + Protocol::Udp(_) => Err(Filtered::Udp), + Protocol::Icmp(_) => Err(Filtered::Icmp), } - - if let Some(dest_port) = packet.as_udp().map(|udp| udp.destination_port()) { - if self.udp.contains(&dest_port) { - return Ok(()); - } - - return Err(Filtered::Udp); - } - - if packet.is_icmp() || packet.is_icmpv6() { - if self.icmp { - return Ok(()); - } - - return Err(Filtered::Icmp); - } - - Ok(()) } fn add_filters<'a>(&mut self, filters: impl IntoIterator) {