refactor(connlib): deal with resources one at a time (#5886)

The two primary users of the `add_resources` and `remove_resources`
functions are the client's event loop and the `tunnel_test`. Both of them
only ever pass a single resource at a time.

It is thus simpler to remove the inner loop from within `ClientState`
and simply process a single resource at a time.
This commit is contained in:
Thomas Eizinger
2024-07-18 14:59:12 +10:00
committed by GitHub
parent 5268756b60
commit 4937291d23
7 changed files with 160 additions and 141 deletions

View File

@@ -229,10 +229,10 @@ where
self.tunnel.update_relays(HashSet::default(), relays)
}
IngressMessages::ResourceCreatedOrUpdated(resource) => {
self.tunnel.add_resources(&[resource]);
self.tunnel.add_resource(resource);
}
IngressMessages::ResourceDeleted(resource) => {
self.tunnel.remove_resources(&[resource]);
self.tunnel.remove_resource(resource);
}
IngressMessages::RelaysPresence(RelaysPresence {
disconnected_ids,

View File

@@ -228,8 +228,6 @@ mod tests {
};
tunnel.set_tun(Tun::new().unwrap());
tunnel.set_new_interface_config(interface).unwrap();
let resources = vec![];
tunnel.add_resources(&resources);
let tunnel = tokio::spawn(async move {
std::future::poll_fn(|cx| tunnel.poll_next_event(cx))

View File

@@ -0,0 +1,7 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc 9b48b1c90455e632268397a3253352fd834ff7e2952f8efa5959543547be8892 # shrinks to input = _AddingSameResourceWithDifferentAddressUpdatesTheAddressArgs { resource: ResourceDescriptionCidr { id: ResourceId(0003585c-0f03-a9db-f663-31382f9195f3), address: V6(Ipv6Network { network_address: ::ffff:143.55.54.183, netmask: 128 }), name: "pammh", address_description: None, sites: [Site { name: "laey", id: SiteId(6707ba24-4d4b-4fb0-dae7-64b89f4401b8) }] }, new_address: V6(Ipv6Network { network_address: ::ffff:127.0.0.0, netmask: 126 }) }

View File

@@ -73,8 +73,8 @@ impl ClientTunnel {
}
/// Adds the given resource to the tunnel.
pub fn add_resources(&mut self, resources: &[ResourceDescription]) {
self.role_state.add_resources(resources);
pub fn add_resource(&mut self, resource: ResourceDescription) {
self.role_state.add_resource(resource);
self.role_state
.buffered_events
@@ -89,8 +89,8 @@ impl ClientTunnel {
});
}
pub fn remove_resources(&mut self, ids: &[ResourceId]) {
self.role_state.remove_resources(ids);
pub fn remove_resource(&mut self, id: ResourceId) {
self.role_state.remove_resource(id);
self.role_state
.buffered_events
@@ -237,7 +237,7 @@ pub struct ClientState {
/// All CIDR resources we know about, indexed by the IP range they cover (like `1.1.0.0/8`).
cidr_resources: IpNetworkTable<ResourceDescriptionCidr>,
/// All resources indexed by their ID.
resource_ids: HashMap<ResourceId, ResourceDescription>,
resources_by_id: HashMap<ResourceId, ResourceDescription>,
/// The DNS resolvers configured on the system outside of connlib.
system_resolvers: Vec<IpAddr>,
@@ -277,7 +277,7 @@ impl ClientState {
awaiting_connection_details: Default::default(),
resources_gateways: Default::default(),
cidr_resources: IpNetworkTable::new(),
resource_ids: Default::default(),
resources_by_id: Default::default(),
peers: Default::default(),
dns_mapping: Default::default(),
buffered_events: Default::default(),
@@ -312,7 +312,7 @@ impl ClientState {
}
pub(crate) fn resources(&self) -> Vec<callbacks::ResourceDescription> {
self.resource_ids
self.resources_by_id
.values()
.sorted()
.cloned()
@@ -344,7 +344,7 @@ impl ClientState {
}
fn set_resource_offline(&mut self, id: ResourceId) {
let Some(resource) = self.resource_ids.get(&id).cloned() else {
let Some(resource) = self.resources_by_id.get(&id).cloned() else {
return;
};
@@ -511,7 +511,7 @@ impl ClientState {
tracing::trace!("Creating or reusing connection");
let desc = self
.resource_ids
.resources_by_id
.get(&resource_id)
.context("Unknown resource")?;
@@ -663,7 +663,7 @@ impl ClientState {
destination: &IpAddr,
now: Instant,
) {
debug_assert!(self.resource_ids.contains_key(&resource));
debug_assert!(self.resources_by_id.contains_key(&resource));
let gateways = self
.resources_gateways
@@ -886,109 +886,98 @@ impl ClientState {
/// Sets a new set of resources.
///
/// This function does **not** perform a blanket "clear all and set new resources".
/// Instead, it diffs which resources to remove and which ones to add.
/// Instead, it diffs which resources to remove first and then adds the new ones.
///
/// This is important because we don't want to lose state like resolved DNS names for resources that didn't change.
/// Removing a resource interrupts routing for all packets, even if the resource is added back right away because [`GatewayOnClient`] tracks the allowed IPs which has to contain the resource ID.
///
/// TODO: Add a test that asserts the above.
/// That is tricky because we need to assert on state deleted by [`ClientState::remove_resources`] and check that it did in fact not get deleted.
fn set_resources(&mut self, new_resources: Vec<ResourceDescription>) {
self.remove_resources(
&HashSet::from_iter(self.resource_ids.keys().copied())
.difference(&HashSet::<ResourceId>::from_iter(
new_resources.iter().map(|r| r.id()),
))
.copied()
.collect_vec(),
);
/// That is tricky because we need to assert on state deleted by [`ClientState::remove_resource`] and check that it did in fact not get deleted.
pub(crate) fn set_resources(&mut self, new_resources: Vec<ResourceDescription>) {
let current_resource_ids = self.resources_by_id.keys().copied().collect::<HashSet<_>>();
let new_resource_ids = new_resources.iter().map(|r| r.id()).collect();
self.add_resources(
&HashSet::from_iter(new_resources.iter().cloned())
.difference(&HashSet::<ResourceDescription>::from_iter(
self.resource_ids.values().cloned(),
))
.cloned()
.collect_vec(),
);
}
// First, remove all resources that are not present in the new resource list.
for id in current_resource_ids.difference(&new_resource_ids).copied() {
self.remove_resource(id);
}
pub(crate) fn add_resources(&mut self, resources: &[ResourceDescription]) {
for resource_description in resources {
if let Some(resource) = self.resource_ids.get(&resource_description.id()) {
if resource.has_different_address(resource_description) {
self.remove_resources(&[resource.id()]);
}
}
match &resource_description {
ResourceDescription::Dns(dns) => {
self.stub_resolver.add_resource(dns.id, dns.address.clone());
}
ResourceDescription::Cidr(cidr) => {
let existing = self.cidr_resources.insert(cidr.address, cidr.clone());
match existing {
Some(existing) if existing.id != cidr.id => {
tracing::info!(address = %cidr.address, old = %existing.name, new = %cidr.name, "Replacing CIDR resource");
}
Some(_) => {}
None => {
tracing::info!(address = %cidr.address, name = %cidr.name, "Activating CIDR resource");
}
}
}
ResourceDescription::Internet(_) => {}
}
self.resource_ids
.insert(resource_description.id(), resource_description.clone());
// Second, add all resources.
for resource in new_resources {
self.add_resource(resource)
}
}
#[tracing::instrument(level = "debug", skip_all, fields(?ids))]
pub(crate) fn remove_resources(&mut self, ids: &[ResourceId]) {
for id in ids {
self.awaiting_connection_details.remove(id);
self.stub_resolver.remove_resource(*id);
self.cidr_resources.retain(|_, r| {
if r.id == *id {
tracing::info!(address = %r.address, name = %r.name, "Deactivating CIDR resource");
return false;
pub(crate) fn add_resource(&mut self, new_resource: ResourceDescription) {
if let Some(resource) = self.resources_by_id.get(&new_resource.id()) {
if resource.has_different_address(&new_resource) {
self.remove_resource(resource.id());
}
}
match &new_resource {
ResourceDescription::Dns(dns) => {
self.stub_resolver.add_resource(dns.id, dns.address.clone());
}
ResourceDescription::Cidr(cidr) => {
let existing = self.cidr_resources.insert(cidr.address, cidr.clone());
match existing {
Some(existing) if existing.id != cidr.id => {
tracing::info!(address = %cidr.address, old = %existing.name, new = %cidr.name, "Replacing CIDR resource");
}
Some(_) => {}
None => {
tracing::info!(address = %cidr.address, name = %cidr.name, "Activating CIDR resource");
}
}
}
ResourceDescription::Internet(_) => {}
}
true
});
self.resources_by_id.insert(new_resource.id(), new_resource);
}
self.resource_ids.remove(id);
#[tracing::instrument(level = "debug", skip_all, fields(?id))]
pub(crate) fn remove_resource(&mut self, id: ResourceId) {
self.awaiting_connection_details.remove(&id);
self.stub_resolver.remove_resource(id);
self.cidr_resources.retain(|_, r| {
if r.id == id {
tracing::info!(address = %r.address, name = %r.name, "Deactivating CIDR resource");
return false;
}
let Some(peer) = peer_by_resource_mut(&self.resources_gateways, &mut self.peers, *id)
else {
true
});
self.resources_by_id.remove(&id);
let Some(peer) = peer_by_resource_mut(&self.resources_gateways, &mut self.peers, id) else {
return;
};
let gateway_id = peer.id();
// First we remove the id from all allowed ips
for (_, resources) in peer
.allowed_ips
.iter_mut()
.filter(|(_, resources)| resources.contains(&id))
{
resources.remove(&id);
if !resources.is_empty() {
continue;
};
let gateway_id = peer.id();
// First we remove the id from all allowed ips
for (_, resources) in peer
.allowed_ips
.iter_mut()
.filter(|(_, resources)| resources.contains(id))
{
resources.remove(id);
if !resources.is_empty() {
continue;
}
}
}
// We remove all empty allowed-IP entries since no resource corresponds to them
peer.allowed_ips.retain(|_, r| !r.is_empty());
// We remove all empty allowed-IP entries since no resource corresponds to them
peer.allowed_ips.retain(|_, r| !r.is_empty());
// If there's no allowed ip left we remove the whole peer because there's no point in keeping it around
if peer.allowed_ips.is_empty() {
self.peers.remove(&gateway_id);
self.update_site_status_by_gateway(&gateway_id, Status::Unknown);
// TODO: should we have a Node::remove_connection?
}
// If there's no allowed ip left we remove the whole peer because there's no point in keeping it around
if peer.allowed_ips.is_empty() {
self.peers.remove(&gateway_id);
self.update_site_status_by_gateway(&gateway_id, Status::Unknown);
// TODO: should we have a Node::remove_connection?
}
}
@@ -1539,10 +1528,8 @@ mod proptests {
) {
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
ResourceDescription::Cidr(resource1.clone()),
ResourceDescription::Cidr(resource2.clone()),
]);
client_state.add_resource(ResourceDescription::Cidr(resource1.clone()));
client_state.add_resource(ResourceDescription::Cidr(resource2.clone()));
assert_eq!(
hashset(client_state.routes()),
@@ -1560,10 +1547,8 @@ mod proptests {
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
ResourceDescription::Cidr(resource1.clone()),
ResourceDescription::Dns(resource2.clone()),
]);
client_state.add_resource(ResourceDescription::Cidr(resource1.clone()));
client_state.add_resource(ResourceDescription::Dns(resource2.clone()));
assert_eq!(
hashset(client_state.resources()),
@@ -1573,7 +1558,7 @@ mod proptests {
])
);
client_state.add_resources(&[ResourceDescription::Cidr(resource3.clone())]);
client_state.add_resource(ResourceDescription::Cidr(resource3.clone()));
assert_eq!(
hashset(client_state.resources()),
@@ -1593,14 +1578,14 @@ mod proptests {
use callbacks as cb;
let mut client_state = ClientState::for_test();
client_state.add_resources(&[ResourceDescription::Cidr(resource.clone())]);
client_state.add_resource(ResourceDescription::Cidr(resource.clone()));
let updated_resource = ResourceDescriptionCidr {
address: new_address,
..resource
};
client_state.add_resources(&[ResourceDescription::Cidr(updated_resource.clone())]);
client_state.add_resource(ResourceDescription::Cidr(updated_resource.clone()));
assert_eq!(
hashset(client_state.resources()),
@@ -1622,7 +1607,7 @@ mod proptests {
use callbacks as cb;
let mut client_state = ClientState::for_test();
client_state.add_resources(&[ResourceDescription::Dns(resource.clone())]);
client_state.add_resource(ResourceDescription::Dns(resource.clone()));
let dns_as_cidr_resource = ResourceDescriptionCidr {
address,
@@ -1632,7 +1617,7 @@ mod proptests {
sites: resource.sites,
};
client_state.add_resources(&[ResourceDescription::Cidr(dns_as_cidr_resource.clone())]);
client_state.add_resource(ResourceDescription::Cidr(dns_as_cidr_resource.clone()));
assert_eq!(
hashset(client_state.resources()),
@@ -1654,12 +1639,10 @@ mod proptests {
use callbacks as cb;
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
ResourceDescription::Dns(dns_resource.clone()),
ResourceDescription::Cidr(cidr_resource.clone()),
]);
client_state.add_resource(ResourceDescription::Dns(dns_resource.clone()));
client_state.add_resource(ResourceDescription::Cidr(cidr_resource.clone()));
client_state.remove_resources(&[dns_resource.id]);
client_state.remove_resource(dns_resource.id);
assert_eq!(
hashset(client_state.resources()),
@@ -1672,7 +1655,7 @@ mod proptests {
expected_routes(vec![cidr_resource.address])
);
client_state.remove_resources(&[cidr_resource.id]);
client_state.remove_resource(cidr_resource.id);
assert_eq!(hashset(client_state.resources().iter()), hashset(&[]));
assert_eq!(hashset(client_state.routes()), expected_routes(vec![]));
@@ -1688,10 +1671,8 @@ mod proptests {
use callbacks as cb;
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
ResourceDescription::Dns(dns_resource1),
ResourceDescription::Cidr(cidr_resource1),
]);
client_state.add_resource(ResourceDescription::Dns(dns_resource1));
client_state.add_resource(ResourceDescription::Cidr(cidr_resource1));
client_state.set_resources(vec![
ResourceDescription::Dns(dns_resource2.clone()),
@@ -1718,8 +1699,11 @@ mod proptests {
#[strategy(gateway_id())] gateway: GatewayId,
) {
let mut client_state = ClientState::for_test();
client_state.add_resources(&resources_online);
client_state.add_resources(&resources_unknown);
for r in resources_online.iter().chain(resources_unknown.iter()) {
client_state.add_resource(r.clone())
}
let first_resource = resources_online.first().unwrap();
client_state
.resources_gateways
@@ -1745,7 +1729,9 @@ mod proptests {
#[strategy(gateway_id())] gateway: GatewayId,
) {
let mut client_state = ClientState::for_test();
client_state.add_resources(&resources);
for r in &resources {
client_state.add_resource(r.clone())
}
let first_resources = resources.first().unwrap();
client_state
.resources_gateways
@@ -1768,8 +1754,10 @@ mod proptests {
#[strategy(resource())] single_site_resource: ResourceDescription,
) {
let mut client_state = ClientState::for_test();
client_state.add_resources(&multi_site_resources);
client_state.add_resources(&[single_site_resource.clone()]);
client_state.add_resource(single_site_resource.clone());
for r in &multi_site_resources {
client_state.add_resource(r.clone())
}
client_state.set_resource_offline(single_site_resource.id());

View File

@@ -246,7 +246,7 @@ impl ReferenceStateMachine for ReferenceState {
)
},
)
.with_if_not_empty(1, state.client.inner().all_resources(), |resources| {
.with_if_not_empty(1, state.client.inner().all_resource_ids(), |resources| {
sample::select(resources).prop_map(Transition::RemoveResource)
})
.boxed()
@@ -382,7 +382,12 @@ impl ReferenceStateMachine for ReferenceState {
match transition {
Transition::AddCidrResource { resource } => {
// Resource IDs must be unique.
if state.client.inner().all_resources().contains(&resource.id) {
if state
.client
.inner()
.all_resource_ids()
.contains(&resource.id)
{
return false;
}
let Some(gid) = state.portal.gateway_for_resource(resource.id) else {
@@ -438,7 +443,12 @@ impl ReferenceStateMachine for ReferenceState {
}
// Resource IDs must be unique.
if state.client.inner().all_resources().contains(&resource.id) {
if state
.client
.inner()
.all_resource_ids()
.contains(&resource.id)
{
return false;
}
@@ -537,7 +547,7 @@ impl ReferenceStateMachine for ReferenceState {
.expected_dns_servers()
.contains(dns_server)
}
Transition::RemoveResource(id) => state.client.inner().all_resources().contains(id),
Transition::RemoveResource(id) => state.client.inner().all_resource_ids().contains(id),
Transition::RoamClient { ip4, ip6, port } => {
// In production, we always rebind to a new port so we never roam to our old existing IP / port combination.

View File

@@ -9,7 +9,7 @@ use crate::{tests::sut::hickory_name_to_domain, ClientState};
use bimap::BiMap;
use connlib_shared::{
messages::{
client::{ResourceDescriptionCidr, ResourceDescriptionDns},
client::{ResourceDescription, ResourceDescriptionCidr, ResourceDescriptionDns},
ClientId, DnsServer, GatewayId, Interface, ResourceId,
},
proptest::{client_id, domain_name},
@@ -542,12 +542,28 @@ impl RefClient {
self.cidr_resource_by_ip(dns_server)
}
pub(crate) fn all_resources(&self) -> Vec<ResourceId> {
pub(crate) fn all_resource_ids(&self) -> Vec<ResourceId> {
let cidr_resources = self.cidr_resources.iter().map(|(_, r)| r.id);
let dns_resources = self.dns_resources.keys().copied();
Vec::from_iter(cidr_resources.chain(dns_resources))
}
/// Returns the entries of `self.cidr_resources` and `self.dns_resources` as a single list of
/// [`ResourceDescription`]s.
///
/// Every entry is cloned and wrapped in the matching enum variant
/// ([`ResourceDescription::Cidr`] or [`ResourceDescription::Dns`]).
/// All CIDR resources come first, followed by all DNS resources.
pub(crate) fn all_resources(&self) -> Vec<ResourceDescription> {
let cidr_resources = self
.cidr_resources
.iter()
// Each item is a pair; only its second element (the resource itself) is kept.
.map(|(_, r)| r)
.cloned()
.map(ResourceDescription::Cidr);
let dns_resources = self
.dns_resources
.values()
.cloned()
.map(ResourceDescription::Dns);
// CIDR entries are chained ahead of DNS entries.
Vec::from_iter(cidr_resources.chain(dns_resources))
}
}
fn is_subdomain(name: &str, record: &str) -> bool {

View File

@@ -156,14 +156,12 @@ impl StateMachineTest for TunnelTest {
Transition::AddCidrResource { resource } => {
state
.client
.exec_mut(|c| c.sut.add_resources(&[ResourceDescription::Cidr(resource)]));
.exec_mut(|c| c.sut.add_resource(ResourceDescription::Cidr(resource)));
}
Transition::AddDnsResource { resource, .. } => state
.client
.exec_mut(|c| c.sut.add_resources(&[ResourceDescription::Dns(resource)])),
Transition::RemoveResource(id) => {
state.client.exec_mut(|c| c.sut.remove_resources(&[id]))
}
.exec_mut(|c| c.sut.add_resource(ResourceDescription::Dns(resource))),
Transition::RemoveResource(id) => state.client.exec_mut(|c| c.sut.remove_resource(id)),
Transition::SendICMPPacketToNonResourceIp {
src,
dst,
@@ -253,7 +251,9 @@ impl StateMachineTest for TunnelTest {
HashSet::default(),
HashSet::from_iter(map_explode(state.relays.iter(), "client")),
ref_state.now,
)
);
c.sut
.set_resources(ref_state.client.inner().all_resources());
});
}
};