diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 0bd362025..9f2a55f7d 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -229,10 +229,10 @@ where self.tunnel.update_relays(HashSet::default(), relays) } IngressMessages::ResourceCreatedOrUpdated(resource) => { - self.tunnel.add_resources(&[resource]); + self.tunnel.add_resource(resource); } IngressMessages::ResourceDeleted(resource) => { - self.tunnel.remove_resources(&[resource]); + self.tunnel.remove_resource(resource); } IngressMessages::RelaysPresence(RelaysPresence { disconnected_ids, diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index eabfd7f52..1a2148a07 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -228,8 +228,6 @@ mod tests { }; tunnel.set_tun(Tun::new().unwrap()); tunnel.set_new_interface_config(interface).unwrap(); - let resources = vec![]; - tunnel.add_resources(&resources); let tunnel = tokio::spawn(async move { std::future::poll_fn(|cx| tunnel.poll_next_event(cx)) diff --git a/rust/connlib/tunnel/proptest-regressions/client.txt b/rust/connlib/tunnel/proptest-regressions/client.txt new file mode 100644 index 000000000..57c0e09ca --- /dev/null +++ b/rust/connlib/tunnel/proptest-regressions/client.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 9b48b1c90455e632268397a3253352fd834ff7e2952f8efa5959543547be8892 # shrinks to input = _AddingSameResourceWithDifferentAddressUpdatesTheAddressArgs { resource: ResourceDescriptionCidr { id: ResourceId(0003585c-0f03-a9db-f663-31382f9195f3), address: V6(Ipv6Network { network_address: ::ffff:143.55.54.183, netmask: 128 }), name: "pammh", address_description: None, sites: [Site { name: "laey", id: SiteId(6707ba24-4d4b-4fb0-dae7-64b89f4401b8) }] }, new_address: V6(Ipv6Network { network_address: ::ffff:127.0.0.0, netmask: 126 }) } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index c4e4183c7..020204c3f 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -73,8 +73,8 @@ impl ClientTunnel { } /// Adds a the given resource to the tunnel. - pub fn add_resources(&mut self, resources: &[ResourceDescription]) { - self.role_state.add_resources(resources); + pub fn add_resource(&mut self, resource: ResourceDescription) { + self.role_state.add_resource(resource); self.role_state .buffered_events @@ -89,8 +89,8 @@ impl ClientTunnel { }); } - pub fn remove_resources(&mut self, ids: &[ResourceId]) { - self.role_state.remove_resources(ids); + pub fn remove_resource(&mut self, id: ResourceId) { + self.role_state.remove_resource(id); self.role_state .buffered_events @@ -237,7 +237,7 @@ pub struct ClientState { /// All CIDR resources we know about, indexed by the IP range they cover (like `1.1.0.0/8`). cidr_resources: IpNetworkTable, /// All resources indexed by their ID. - resource_ids: HashMap, + resources_by_id: HashMap, /// The DNS resolvers configured on the system outside of connlib. system_resolvers: Vec, @@ -277,7 +277,7 @@ impl ClientState { awaiting_connection_details: Default::default(), resources_gateways: Default::default(), cidr_resources: IpNetworkTable::new(), - resource_ids: Default::default(), + resources_by_id: Default::default(), peers: Default::default(), dns_mapping: Default::default(), buffered_events: Default::default(), @@ -312,7 +312,7 @@ impl ClientState { } pub(crate) fn resources(&self) -> Vec { - self.resource_ids + self.resources_by_id .values() .sorted() .cloned() @@ -344,7 +344,7 @@ impl ClientState { } fn set_resource_offline(&mut self, id: ResourceId) { - let Some(resource) = self.resource_ids.get(&id).cloned() else { + let Some(resource) = self.resources_by_id.get(&id).cloned() else { return; }; @@ -511,7 +511,7 @@ impl ClientState { tracing::trace!("Creating or reusing connection"); let desc = self - .resource_ids + .resources_by_id .get(&resource_id) .context("Unknown resource")?; @@ -663,7 +663,7 @@ impl ClientState { destination: &IpAddr, now: Instant, ) { - debug_assert!(self.resource_ids.contains_key(&resource)); + debug_assert!(self.resources_by_id.contains_key(&resource)); let gateways = self .resources_gateways @@ -886,109 +886,98 @@ impl ClientState { /// Sets a new set of resources. /// /// This function does **not** perform a blanket "clear all and set new resources". - /// Instead, it diffs which resources to remove and which ones to add. + /// Instead, it diffs which resources to remove first and then adds the new ones. /// - /// This is important because we don't want to lose state like resolved DNS names for resources that didn't change. + /// Removing a resource interrupts routing for all packets, even if the resource is added back right away because [`GatewayOnClient`] tracks the allowed IPs which has to contain the resource ID. /// /// TODO: Add a test that asserts the above. - /// That is tricky because we need to assert on state deleted by [`ClientState::remove_resources`] and check that it did in fact not get deleted. - fn set_resources(&mut self, new_resources: Vec) { - self.remove_resources( - &HashSet::from_iter(self.resource_ids.keys().copied()) - .difference(&HashSet::::from_iter( - new_resources.iter().map(|r| r.id()), - )) - .copied() - .collect_vec(), - ); + /// That is tricky because we need to assert on state deleted by [`ClientState::remove_resource`] and check that it did in fact not get deleted. + pub(crate) fn set_resources(&mut self, new_resources: Vec) { + let current_resource_ids = self.resources_by_id.keys().copied().collect::>(); + let new_resource_ids = new_resources.iter().map(|r| r.id()).collect(); - self.add_resources( - &HashSet::from_iter(new_resources.iter().cloned()) - .difference(&HashSet::::from_iter( - self.resource_ids.values().cloned(), - )) - .cloned() - .collect_vec(), - ); - } + // First, remove all resources that are not present in the new resource list. + for id in current_resource_ids.difference(&new_resource_ids).copied() { + self.remove_resource(id); + } - pub(crate) fn add_resources(&mut self, resources: &[ResourceDescription]) { - for resource_description in resources { - if let Some(resource) = self.resource_ids.get(&resource_description.id()) { - if resource.has_different_address(resource_description) { - self.remove_resources(&[resource.id()]); - } - } - - match &resource_description { - ResourceDescription::Dns(dns) => { - self.stub_resolver.add_resource(dns.id, dns.address.clone()); - } - ResourceDescription::Cidr(cidr) => { - let existing = self.cidr_resources.insert(cidr.address, cidr.clone()); - - match existing { - Some(existing) if existing.id != cidr.id => { - tracing::info!(address = %cidr.address, old = %existing.name, new = %cidr.name, "Replacing CIDR resource"); - } - Some(_) => {} - None => { - tracing::info!(address = %cidr.address, name = %cidr.name, "Activating CIDR resource"); - } - } - } - ResourceDescription::Internet(_) => {} - } - - self.resource_ids - .insert(resource_description.id(), resource_description.clone()); + // Second, add all resources. + for resource in new_resources { + self.add_resource(resource) } } - #[tracing::instrument(level = "debug", skip_all, fields(?ids))] - pub(crate) fn remove_resources(&mut self, ids: &[ResourceId]) { - for id in ids { - self.awaiting_connection_details.remove(id); - self.stub_resolver.remove_resource(*id); - self.cidr_resources.retain(|_, r| { - if r.id == *id { - tracing::info!(address = %r.address, name = %r.name, "Deactivating CIDR resource"); - return false; + pub(crate) fn add_resource(&mut self, new_resource: ResourceDescription) { + if let Some(resource) = self.resources_by_id.get(&new_resource.id()) { + if resource.has_different_address(&new_resource) { + self.remove_resource(resource.id()); + } + } + + match &new_resource { + ResourceDescription::Dns(dns) => { + self.stub_resolver.add_resource(dns.id, dns.address.clone()); + } + ResourceDescription::Cidr(cidr) => { + let existing = self.cidr_resources.insert(cidr.address, cidr.clone()); + + match existing { + Some(existing) if existing.id != cidr.id => { + tracing::info!(address = %cidr.address, old = %existing.name, new = %cidr.name, "Replacing CIDR resource"); + } + Some(_) => {} + None => { + tracing::info!(address = %cidr.address, name = %cidr.name, "Activating CIDR resource"); + } } + } + ResourceDescription::Internet(_) => {} + } - true - }); + self.resources_by_id.insert(new_resource.id(), new_resource); + } - self.resource_ids.remove(id); + #[tracing::instrument(level = "debug", skip_all, fields(?id))] + pub(crate) fn remove_resource(&mut self, id: ResourceId) { + self.awaiting_connection_details.remove(&id); + self.stub_resolver.remove_resource(id); + self.cidr_resources.retain(|_, r| { + if r.id == id { + tracing::info!(address = %r.address, name = %r.name, "Deactivating CIDR resource"); + return false; + } - let Some(peer) = peer_by_resource_mut(&self.resources_gateways, &mut self.peers, *id) - else { + true + }); + + self.resources_by_id.remove(&id); + + let Some(peer) = peer_by_resource_mut(&self.resources_gateways, &mut self.peers, id) else { + return; + }; + let gateway_id = peer.id(); + + // First we remove the id from all allowed ips + for (_, resources) in peer + .allowed_ips + .iter_mut() + .filter(|(_, resources)| resources.contains(&id)) + { + resources.remove(&id); + + if !resources.is_empty() { continue; - }; - let gateway_id = peer.id(); - - // First we remove the id from all allowed ips - for (_, resources) in peer - .allowed_ips - .iter_mut() - .filter(|(_, resources)| resources.contains(id)) - { - resources.remove(id); - - if !resources.is_empty() { - continue; - } } + } - // We remove all empty allowed ips entry since there's no resource that corresponds to it - peer.allowed_ips.retain(|_, r| !r.is_empty()); + // We remove all empty allowed ips entry since there's no resource that corresponds to it + peer.allowed_ips.retain(|_, r| !r.is_empty()); - // If there's no allowed ip left we remove the whole peer because there's no point on keeping it around - if peer.allowed_ips.is_empty() { - self.peers.remove(&gateway_id); - self.update_site_status_by_gateway(&gateway_id, Status::Unknown); - // TODO: should we have a Node::remove_connection? - } + // If there's no allowed ip left we remove the whole peer because there's no point on keeping it around + if peer.allowed_ips.is_empty() { + self.peers.remove(&gateway_id); + self.update_site_status_by_gateway(&gateway_id, Status::Unknown); + // TODO: should we have a Node::remove_connection? } } @@ -1539,10 +1528,8 @@ mod proptests { ) { let mut client_state = ClientState::for_test(); - client_state.add_resources(&[ - ResourceDescription::Cidr(resource1.clone()), - ResourceDescription::Cidr(resource2.clone()), - ]); + client_state.add_resource(ResourceDescription::Cidr(resource1.clone())); + client_state.add_resource(ResourceDescription::Cidr(resource2.clone())); assert_eq!( hashset(client_state.routes()), @@ -1560,10 +1547,8 @@ mod proptests { let mut client_state = ClientState::for_test(); - client_state.add_resources(&[ - ResourceDescription::Cidr(resource1.clone()), - ResourceDescription::Dns(resource2.clone()), - ]); + client_state.add_resource(ResourceDescription::Cidr(resource1.clone())); + client_state.add_resource(ResourceDescription::Dns(resource2.clone())); assert_eq!( hashset(client_state.resources()), @@ -1573,7 +1558,7 @@ mod proptests { ]) ); - client_state.add_resources(&[ResourceDescription::Cidr(resource3.clone())]); + client_state.add_resource(ResourceDescription::Cidr(resource3.clone())); assert_eq!( hashset(client_state.resources()), @@ -1593,14 +1578,14 @@ mod proptests { use callbacks as cb; let mut client_state = ClientState::for_test(); - client_state.add_resources(&[ResourceDescription::Cidr(resource.clone())]); + client_state.add_resource(ResourceDescription::Cidr(resource.clone())); let updated_resource = ResourceDescriptionCidr { address: new_address, ..resource }; - client_state.add_resources(&[ResourceDescription::Cidr(updated_resource.clone())]); + client_state.add_resource(ResourceDescription::Cidr(updated_resource.clone())); assert_eq!( hashset(client_state.resources()), @@ -1622,7 +1607,7 @@ mod proptests { use callbacks as cb; let mut client_state = ClientState::for_test(); - client_state.add_resources(&[ResourceDescription::Dns(resource.clone())]); + client_state.add_resource(ResourceDescription::Dns(resource.clone())); let dns_as_cidr_resource = ResourceDescriptionCidr { address, @@ -1632,7 +1617,7 @@ mod proptests { sites: resource.sites, }; - client_state.add_resources(&[ResourceDescription::Cidr(dns_as_cidr_resource.clone())]); + client_state.add_resource(ResourceDescription::Cidr(dns_as_cidr_resource.clone())); assert_eq!( hashset(client_state.resources()), @@ -1654,12 +1639,10 @@ mod proptests { use callbacks as cb; let mut client_state = ClientState::for_test(); - client_state.add_resources(&[ - ResourceDescription::Dns(dns_resource.clone()), - ResourceDescription::Cidr(cidr_resource.clone()), - ]); + client_state.add_resource(ResourceDescription::Dns(dns_resource.clone())); + client_state.add_resource(ResourceDescription::Cidr(cidr_resource.clone())); - client_state.remove_resources(&[dns_resource.id]); + client_state.remove_resource(dns_resource.id); assert_eq!( hashset(client_state.resources()), @@ -1672,7 +1655,7 @@ mod proptests { expected_routes(vec![cidr_resource.address]) ); - client_state.remove_resources(&[cidr_resource.id]); + client_state.remove_resource(cidr_resource.id); assert_eq!(hashset(client_state.resources().iter()), hashset(&[])); assert_eq!(hashset(client_state.routes()), expected_routes(vec![])); @@ -1688,10 +1671,8 @@ mod proptests { use callbacks as cb; let mut client_state = ClientState::for_test(); - client_state.add_resources(&[ - ResourceDescription::Dns(dns_resource1), - ResourceDescription::Cidr(cidr_resource1), - ]); + client_state.add_resource(ResourceDescription::Dns(dns_resource1)); + client_state.add_resource(ResourceDescription::Cidr(cidr_resource1)); client_state.set_resources(vec![ ResourceDescription::Dns(dns_resource2.clone()), @@ -1718,8 +1699,11 @@ mod proptests { #[strategy(gateway_id())] gateway: GatewayId, ) { let mut client_state = ClientState::for_test(); - client_state.add_resources(&resources_online); - client_state.add_resources(&resources_unknown); + + for r in resources_online.iter().chain(resources_unknown.iter()) { + client_state.add_resource(r.clone()) + } + let first_resource = resources_online.first().unwrap(); client_state .resources_gateways @@ -1745,7 +1729,9 @@ mod proptests { #[strategy(gateway_id())] gateway: GatewayId, ) { let mut client_state = ClientState::for_test(); - client_state.add_resources(&resources); + for r in &resources { + client_state.add_resource(r.clone()) + } let first_resources = resources.first().unwrap(); client_state .resources_gateways @@ -1768,8 +1754,10 @@ mod proptests { #[strategy(resource())] single_site_resource: ResourceDescription, ) { let mut client_state = ClientState::for_test(); - client_state.add_resources(&multi_site_resources); - client_state.add_resources(&[single_site_resource.clone()]); + client_state.add_resource(single_site_resource.clone()); + for r in &multi_site_resources { + client_state.add_resource(r.clone()) + } client_state.set_resource_offline(single_site_resource.id()); diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 91ecc2e2d..6bd0b4b21 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -246,7 +246,7 @@ impl ReferenceStateMachine for ReferenceState { ) }, ) - .with_if_not_empty(1, state.client.inner().all_resources(), |resources| { + .with_if_not_empty(1, state.client.inner().all_resource_ids(), |resources| { sample::select(resources).prop_map(Transition::RemoveResource) }) .boxed() @@ -382,7 +382,12 @@ impl ReferenceStateMachine for ReferenceState { match transition { Transition::AddCidrResource { resource } => { // Resource IDs must be unique. - if state.client.inner().all_resources().contains(&resource.id) { + if state + .client + .inner() + .all_resource_ids() + .contains(&resource.id) + { return false; } let Some(gid) = state.portal.gateway_for_resource(resource.id) else { @@ -438,7 +443,12 @@ impl ReferenceStateMachine for ReferenceState { } // Resource IDs must be unique. - if state.client.inner().all_resources().contains(&resource.id) { + if state + .client + .inner() + .all_resource_ids() + .contains(&resource.id) + { return false; } @@ -537,7 +547,7 @@ impl ReferenceStateMachine for ReferenceState { .expected_dns_servers() .contains(dns_server) } - Transition::RemoveResource(id) => state.client.inner().all_resources().contains(id), + Transition::RemoveResource(id) => state.client.inner().all_resource_ids().contains(id), Transition::RoamClient { ip4, ip6, port } => { // In production, we always rebind to a new port so we never roam to our old existing IP / port combination. diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 94871de24..bc419ed62 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -9,7 +9,7 @@ use crate::{tests::sut::hickory_name_to_domain, ClientState}; use bimap::BiMap; use connlib_shared::{ messages::{ - client::{ResourceDescriptionCidr, ResourceDescriptionDns}, + client::{ResourceDescription, ResourceDescriptionCidr, ResourceDescriptionDns}, ClientId, DnsServer, GatewayId, Interface, ResourceId, }, proptest::{client_id, domain_name}, @@ -542,12 +542,28 @@ impl RefClient { self.cidr_resource_by_ip(dns_server) } - pub(crate) fn all_resources(&self) -> Vec { + pub(crate) fn all_resource_ids(&self) -> Vec { let cidr_resources = self.cidr_resources.iter().map(|(_, r)| r.id); let dns_resources = self.dns_resources.keys().copied(); Vec::from_iter(cidr_resources.chain(dns_resources)) } + + pub(crate) fn all_resources(&self) -> Vec { + let cidr_resources = self + .cidr_resources + .iter() + .map(|(_, r)| r) + .cloned() + .map(ResourceDescription::Cidr); + let dns_resources = self + .dns_resources + .values() + .cloned() + .map(ResourceDescription::Dns); + + Vec::from_iter(cidr_resources.chain(dns_resources)) + } } fn is_subdomain(name: &str, record: &str) -> bool { diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 90d352aed..98d3eabd8 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -156,14 +156,12 @@ impl StateMachineTest for TunnelTest { Transition::AddCidrResource { resource } => { state .client - .exec_mut(|c| c.sut.add_resources(&[ResourceDescription::Cidr(resource)])); + .exec_mut(|c| c.sut.add_resource(ResourceDescription::Cidr(resource))); } Transition::AddDnsResource { resource, .. } => state .client - .exec_mut(|c| c.sut.add_resources(&[ResourceDescription::Dns(resource)])), - Transition::RemoveResource(id) => { - state.client.exec_mut(|c| c.sut.remove_resources(&[id])) - } + .exec_mut(|c| c.sut.add_resource(ResourceDescription::Dns(resource))), + Transition::RemoveResource(id) => state.client.exec_mut(|c| c.sut.remove_resource(id)), Transition::SendICMPPacketToNonResourceIp { src, dst, @@ -253,7 +251,9 @@ impl StateMachineTest for TunnelTest { HashSet::default(), HashSet::from_iter(map_explode(state.relays.iter(), "client")), ref_state.now, - ) + ); + c.sut + .set_resources(ref_state.client.inner().all_resources()); }); } };