diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 21b278c05..00a8485a5 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1112,7 +1112,7 @@ impl ClientState { } #[tracing::instrument(level = "debug", skip_all, fields(?ids))] - fn remove_resources(&mut self, ids: &[ResourceId]) { + pub(crate) fn remove_resources(&mut self, ids: &[ResourceId]) { for id in ids { self.awaiting_connection.remove(id); self.dns_resources_internal_ips diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index c0bc37517..a39af3625 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -269,6 +269,9 @@ impl ReferenceStateMachine for ReferenceState { ) }, ) + .with_if_not_empty(1, state.all_resources(), |resources| { + sample::select(resources).prop_map(Transition::RemoveResource) + }) .boxed() } @@ -280,6 +283,11 @@ impl ReferenceStateMachine for ReferenceState { Transition::AddCidrResource(r) => { state.client_cidr_resources.insert(r.address, r.clone()); } + Transition::RemoveResource(id) => { + state.client_cidr_resources.retain(|_, r| &r.id != id); + state.client_connected_cidr_resources.remove(id); + state.client_dns_resources.remove(id); + } Transition::AddDnsResource { resource: new_resource, records, @@ -487,6 +495,10 @@ impl ReferenceStateMachine for ReferenceState { state.global_dns_records.contains_key(domain) && state.expected_dns_servers().contains(dns_server) } + Transition::RemoveResource(id) => { + state.client_cidr_resources.iter().any(|(_, r)| &r.id == id) + || state.client_dns_resources.contains_key(id) + } } } } @@ -704,6 +716,13 @@ impl ReferenceState { self.cidr_resource_by_ip(dns_server) } + + fn all_resources(&self) -> Vec { + let cidr_resources = self.client_cidr_resources.iter().map(|(_, r)| r.id); + let dns_resources = self.client_dns_resources.keys().copied(); + + Vec::from_iter(cidr_resources.chain(dns_resources)) + } } fn matches_domain(resource_address: &str, domain: &DomainName) -> bool { diff --git a/rust/connlib/tunnel/src/tests/sim_node.rs b/rust/connlib/tunnel/src/tests/sim_node.rs index a47f3e6c4..4a89a3e41 100644 --- a/rust/connlib/tunnel/src/tests/sim_node.rs +++ b/rust/connlib/tunnel/src/tests/sim_node.rs @@ -1,7 +1,9 @@ use super::sim_relay::SimRelay; use crate::{ClientState, GatewayState}; use connlib_shared::{ - messages::{client::ResourceDescription, ClientId, DnsServer, GatewayId, Interface}, + messages::{ + client::ResourceDescription, ClientId, DnsServer, GatewayId, Interface, ResourceId, + }, StaticSecret, }; use ip_network::{Ipv4Network, Ipv6Network}; @@ -104,6 +106,12 @@ impl SimNode { self.state.add_resources(&[resource]); }) } + + pub(crate) fn remove_resource(&mut self, resource: ResourceId) { + self.span.in_scope(|| { + self.state.remove_resources(&[resource]); + }) + } } impl SimNode { diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index d5092b771..fc7d84d6a 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -147,11 +147,12 @@ impl StateMachineTest for TunnelTest { // Act: Apply the transition match transition { Transition::AddCidrResource(r) => { - state.client.add_resource(ResourceDescription::Cidr(r)); + state.client.add_resource(ResourceDescription::Cidr(r)) } Transition::AddDnsResource { resource, .. } => state .client .add_resource(ResourceDescription::Dns(resource)), + Transition::RemoveResource(id) => state.client.remove_resource(id), Transition::SendICMPPacketToNonResourceIp { src, dst, diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 963c124ad..cd1010d50 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -2,7 +2,7 @@ use super::strategies::*; use connlib_shared::{ messages::{ client::{ResourceDescriptionCidr, ResourceDescriptionDns}, - DnsServer, + DnsServer, ResourceId, }, proptest::*, DomainName, @@ -69,6 +69,9 @@ pub(crate) enum Transition { /// Advance time by this many milliseconds. Tick { millis: u64 }, + + /// Remove a resource from the client. + RemoveResource(ResourceId), } pub(crate) fn ping_random_ip(