diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 344adab39..742c217c9 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -188,8 +188,11 @@ impl ClientOnGateway { } pub(crate) fn expire_resources(&mut self, now: DateTime) { - self.resources - .retain(|_, r| !r.expires_at.is_some_and(|e| e <= now)); + for resource in self.resources.values_mut() { + resource.retain(|r| !r.expires_at.is_some_and(|e| e <= now)); + } + + self.resources.retain(|_, r| !r.is_empty()); self.recalculate_filters(); } @@ -205,14 +208,15 @@ impl ClientOnGateway { filters: Filters, expires_at: Option>, ) { - self.resources.insert( - resource, - ResourceOnGateway { + self.resources + .entry(resource) + .or_default() + .push(ResourceOnGateway { ips, filters, + // Each resource subdomain can expire individually so it's worth keeping a list expires_at, - }, - ); + }); self.recalculate_filters(); } @@ -225,7 +229,10 @@ impl ClientOnGateway { let Some(old_resource) = self.resources.get_mut(&resource.id()) else { return; }; - old_resource.filters = resource.filters(); + for r in old_resource { + r.filters = resource.filters(); + } + self.recalculate_filters(); } @@ -235,10 +242,10 @@ impl ClientOnGateway { // in case that 2 or more resources have overlapping rules. fn recalculate_filters(&mut self) { self.filters = IpNetworkTable::new(); - for resource in self.resources.values() { + for resource in self.resources.values().flatten() { for ip in &resource.ips { let mut filter_engine = FilterEngine::empty(); - let filters = self.resources.values().filter_map(|r| { + let filters = self.resources.values().flatten().filter_map(|r| { r.ips .iter() .any(|r_ip| network_contains_network(*r_ip, *ip)) @@ -358,7 +365,7 @@ struct ResourceOnGateway { pub struct ClientOnGateway { id: ClientId, allowed_ips: IpNetworkTable<()>, - resources: HashMap, + resources: HashMap>, filters: IpNetworkTable, } @@ -563,6 +570,75 @@ mod proptests { } } + #[test_strategy::proptest()] + fn gateway_accepts_allowed_packet_multiple_ips_resource_multiple_adds( + #[strategy(client_id())] client_id: ClientId, + #[strategy(resource_id())] resource_id: ResourceId, + #[strategy(collection::vec(source_resource_and_host_within(), 1..=5))] config_res_1: Vec<( + IpAddr, + IpNetwork, + IpAddr, + )>, + #[strategy(collection::vec(source_resource_and_host_within(), 1..=5))] config_res_2: Vec<( + IpAddr, + IpNetwork, + IpAddr, + )>, + #[strategy(filters_with_allowed_protocol())] protocol_config: (Filters, Protocol), + #[strategy(any::())] sport: u16, + #[strategy(any::>())] payload: Vec, + ) { + let (src_1, resource_addr_1, dest_1): (Vec<_>, Vec<_>, Vec<_>) = + config_res_1.into_iter().multiunzip(); + let (src_2, resource_addr_2, dest_2): (Vec<_>, Vec<_>, Vec<_>) = + config_res_2.into_iter().multiunzip(); + let (filters, protocol) = protocol_config; + let mut src = Vec::new(); + src.extend(src_1); + src.extend(src_2); + let mut peer = ClientOnGateway::new( + client_id, + &src.clone().into_iter().map(Into::into).collect_vec(), + ); + + peer.add_resource(resource_addr_1, resource_id, filters.clone(), None); + peer.add_resource(resource_addr_2, resource_id, filters, None); + + for dest in dest_1 { + for src in &src { + if dest.is_ipv4() == src.is_ipv4() { + let packet = match protocol { + Protocol::Tcp { dport } => { + tcp_packet(*src, dest, sport, dport, payload.clone()) + } + Protocol::Udp { dport } => { + udp_packet(*src, dest, sport, dport, payload.clone()) + } + Protocol::Icmp => icmp_request_packet(*src, dest), + }; + assert!(peer.ensure_allowed(&packet).is_ok()); + } + } + } + + for dest in dest_2 { + for src in &src { + if dest.is_ipv4() == src.is_ipv4() { + let packet = match protocol { + Protocol::Tcp { dport } => { + tcp_packet(*src, dest, sport, dport, payload.clone()) + } + Protocol::Udp { dport } => { + udp_packet(*src, dest, sport, dport, payload.clone()) + } + Protocol::Icmp => icmp_request_packet(*src, dest), + }; + assert!(peer.ensure_allowed(&packet).is_ok()); + } + } + } + } + #[test_strategy::proptest()] fn gateway_accepts_different_resources_with_same_ip_packet( #[strategy(client_id())] client_id: ClientId,