From f879b430e4e443a503d78fa837851ae5ba7f985c Mon Sep 17 00:00:00 2001 From: Gabi Date: Wed, 27 Mar 2024 22:28:11 -0300 Subject: [PATCH] feat(connlib): react to config updates (#4322) * Move the resource changes to `ClientState` to unit test easier * Add unit tests * Set new config on update from portal * Set parameters as told by portal on re-init Fixes: #2728 --- rust/connlib/clients/shared/src/eventloop.rs | 27 +- rust/connlib/shared/src/messages.rs | 4 +- rust/connlib/tunnel/src/client.rs | 461 +++++++++++++++---- 3 files changed, 388 insertions(+), 104 deletions(-) diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 4adf1a85e..26cae863e 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -25,7 +25,6 @@ use url::Url; pub struct Eventloop { tunnel: ClientTunnel, - tunnel_init: bool, portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, rx: tokio::sync::mpsc::UnboundedReceiver, @@ -50,7 +49,6 @@ impl Eventloop { Self { tunnel, portal, - tunnel_init: false, connection_intents: SentConnectionIntents::default(), log_upload_interval: upload_interval(), rx, @@ -172,8 +170,10 @@ where fn handle_portal_inbound_message(&mut self, msg: IngressMessages) { match msg { - IngressMessages::ConfigChanged(_) => { - tracing::warn!("Config changes are not yet implemented"); + IngressMessages::ConfigChanged(config) => { + if let Err(e) = self.tunnel.set_interface(config.interface.clone()) { + tracing::warn!(?config, "Failed to update configuration: {e:?}"); + } } IngressMessages::IceCandidates(GatewayIceCandidates { gateway_id, @@ -187,18 +187,13 @@ where interface, resources, }) => { - if !self.tunnel_init { - if let Err(e) = self.tunnel.set_interface(interface) { - tracing::warn!("Failed to set interface on tunnel: {e}"); - return; - } - - self.tunnel_init = true; - tracing::info!("Firezone Started!"); - let _ = self.tunnel.add_resources(&resources); - } else { - tracing::info!("Firezone reinitializated"); + if let Err(e) = self.tunnel.set_interface(interface) { + tracing::warn!("Failed to set interface on tunnel: {e}"); + return; } + + tracing::info!("Firezone Started!"); + let _ = self.tunnel.set_resources(resources); } IngressMessages::ResourceCreatedOrUpdated(resource) => { let resource_id = resource.id(); @@ -208,7 +203,7 @@ where } } IngressMessages::ResourceDeleted(RemoveResource(resource)) => { - self.tunnel.remove_resource(resource); + self.tunnel.remove_resources(&[resource]); } } } diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index 984d9719f..5300bec0e 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -140,7 +140,7 @@ impl PartialEq for RequestConnection { impl Eq for RequestConnection {} -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ResourceDescription { Dns(TDNS), @@ -285,7 +285,7 @@ impl ResourceDescription { } /// Description of a resource that maps to a CIDR. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)] pub struct ResourceDescriptionCidr { /// Resource's id. pub id: ResourceId, diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 6221a3ab6..09f5eb8a7 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -58,111 +58,44 @@ impl ClientTunnel where CB: Callbacks + 'static, { + pub fn set_resources( + &mut self, + resources: Vec, + ) -> connlib_shared::Result<()> { + self.role_state.set_resources(resources); + + self.update_routes()?; + self.update_resource_list(); + + Ok(()) + } + /// Adds a the given resource to the tunnel. pub fn add_resources( &mut self, resources: &[ResourceDescription], ) -> connlib_shared::Result<()> { - for resource_description in resources { - if let Some(resource) = self.role_state.resource_ids.get(&resource_description.id()) { - if resource.has_different_address(resource) { - self.remove_resource(resource.id()); - } - } + self.role_state.add_resources(resources); - match &resource_description { - ResourceDescription::Dns(dns) => { - self.role_state - .dns_resources - .insert(dns.address.clone(), dns.clone()); - } - ResourceDescription::Cidr(cidr) => { - self.role_state - .cidr_resources - .insert(cidr.address, cidr.clone()); - } - } - - self.role_state - .resource_ids - .insert(resource_description.id(), resource_description.clone()); - } - - self.update_resource_list(); self.update_routes()?; + self.update_resource_list(); Ok(()) } - #[tracing::instrument(level = "debug", skip_all, fields(%id))] - pub fn remove_resource(&mut self, id: ResourceId) { - self.role_state.awaiting_connection.remove(&id); - self.role_state - .dns_resources_internal_ips - .retain(|r, _| r.id != id); - self.role_state.dns_resources.retain(|_, r| r.id != id); - self.role_state.cidr_resources.retain(|_, r| r.id != id); - self.role_state - .deferred_dns_queries - .retain(|(r, _), _| r.id != id); - - self.role_state.resource_ids.remove(&id); + pub fn remove_resources(&mut self, ids: &[ResourceId]) { + self.role_state.remove_resources(ids); if let Err(err) = self.update_routes() { - tracing::error!(%id, "Failed to update routes: {err:?}"); + tracing::error!(?ids, "Failed to update routes: {err:?}"); } self.update_resource_list(); - - let Some(gateway_id) = self.role_state.resources_gateways.remove(&id) else { - tracing::debug!("No gateway associated with resource"); - return; - }; - - let Some(peer) = self.role_state.peers.get_mut(&gateway_id) else { - return; - }; - - // First we remove the id from all allowed ips - for (network, resources) in peer - .allowed_ips - .iter_mut() - .filter(|(_, resources)| resources.contains(&id)) - { - resources.remove(&id); - - if !resources.is_empty() { - continue; - } - - // If the allowed_ips doesn't correspond to any resource anymore we - // clean up any related translation. - peer.transform - .translations - .remove_by_left(&network.network_address()); - } - - // We remove all empty allowed ips entry since there's no resource that corresponds to it - peer.allowed_ips.retain(|_, r| !r.is_empty()); - - // If there's no allowed ip left we remove the whole peer because there's no point on keeping it around - if peer.allowed_ips.is_empty() { - self.role_state.peers.remove(&gateway_id); - // TODO: should we have a Node::remove_connection? - } - - tracing::debug!("Resource removed") } fn update_resource_list(&self) { - self.callbacks.on_update_resources( - self.role_state - .resource_ids - .values() - .sorted() - .cloned() - .collect_vec(), - ); + self.callbacks + .on_update_resources(self.role_state.resources()); } /// Updates the system's dns @@ -373,6 +306,10 @@ impl ClientState { } } + fn resources(&self) -> Vec { + self.resource_ids.values().sorted().cloned().collect_vec() + } + pub(crate) fn encapsulate<'s>( &'s mut self, packet: MutableIpPacket<'_>, @@ -916,6 +853,100 @@ impl ClientState { self.node.poll_transmit() } + fn set_resources(&mut self, new_resources: Vec) { + self.remove_resources( + &HashSet::from_iter(self.resource_ids.keys().copied()) + .difference(&HashSet::::from_iter( + new_resources.iter().map(|r| r.id()), + )) + .copied() + .collect_vec(), + ); + + self.add_resources( + &HashSet::from_iter(new_resources.iter().cloned()) + .difference(&HashSet::::from_iter( + self.resource_ids.values().cloned(), + )) + .cloned() + .collect_vec(), + ); + } + + fn add_resources(&mut self, resources: &[ResourceDescription]) { + for resource_description in resources { + if let Some(resource) = self.resource_ids.get(&resource_description.id()) { + if resource.has_different_address(resource_description) { + self.remove_resources(&[resource.id()]); + } + } + + match &resource_description { + ResourceDescription::Dns(dns) => { + self.dns_resources.insert(dns.address.clone(), dns.clone()); + } + ResourceDescription::Cidr(cidr) => { + self.cidr_resources.insert(cidr.address, cidr.clone()); + } + } + + self.resource_ids + .insert(resource_description.id(), resource_description.clone()); + } + } + + #[tracing::instrument(level = "debug", skip_all, fields(?ids))] + fn remove_resources(&mut self, ids: &[ResourceId]) { + for id in ids { + self.awaiting_connection.remove(id); + self.dns_resources_internal_ips.retain(|r, _| r.id != *id); + self.dns_resources.retain(|_, r| r.id != *id); + self.cidr_resources.retain(|_, r| r.id != *id); + self.deferred_dns_queries.retain(|(r, _), _| r.id != *id); + + self.resource_ids.remove(id); + + let Some(gateway_id) = self.resources_gateways.remove(id) else { + tracing::debug!("No gateway associated with resource"); + continue; + }; + + let Some(peer) = self.peers.get_mut(&gateway_id) else { + continue; + }; + + // First we remove the id from all allowed ips + for (network, resources) in peer + .allowed_ips + .iter_mut() + .filter(|(_, resources)| resources.contains(id)) + { + resources.remove(id); + + if !resources.is_empty() { + continue; + } + + // If the allowed_ips doesn't correspond to any resource anymore we + // clean up any related translation. + peer.transform + .translations + .remove_by_left(&network.network_address()); + } + + // We remove all empty allowed ips entry since there's no resource that corresponds to it + peer.allowed_ips.retain(|_, r| !r.is_empty()); + + // If there's no allowed ip left we remove the whole peer because there's no point on keeping it around + if peer.allowed_ips.is_empty() { + self.peers.remove(&gateway_id); + // TODO: should we have a Node::remove_connection? + } + } + + tracing::debug!("Resources removed") + } + fn update_dns_mapping(&mut self) -> bool { let Some(config) = &self.interface_config else { return false; @@ -1257,6 +1288,203 @@ mod tests { ) } + #[test] + fn add_resources_works() { + let mut client_state = ClientState::for_test(); + + client_state.add_resources(&[ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset( + [ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com") + ] + .iter() + ) + ); + assert_eq!( + hashset(client_state.routes()), + expected_routes(vec![IpNetwork::from_str("10.0.0.0/24").unwrap()]) + ); + + client_state.add_resources(&[cidr_baz_resource("11.0.0.0/24")]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset( + [ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + cidr_baz_resource("11.0.0.0/24") + ] + .iter() + ) + ); + assert_eq!( + hashset(client_state.routes()), + expected_routes(vec![ + IpNetwork::from_str("10.0.0.0/24").unwrap(), + IpNetwork::from_str("11.0.0.0/24").unwrap() + ]) + ); + } + + #[test] + fn add_resources_update_works_cidr() { + let mut client_state = ClientState::for_test(); + + client_state.add_resources(&[ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + client_state.add_resources(&[cidr_foo_resource("11.0.0.0/24")]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset( + [ + cidr_foo_resource("11.0.0.0/24"), + dns_bar_resource("baz.com") + ] + .iter() + ) + ); + assert_eq!( + HashSet::::from_iter(client_state.routes()), + expected_routes(vec![IpNetwork::from_str("11.0.0.0/24").unwrap()]) + ); + } + + #[test] + fn add_resources_update_works_to_dns() { + let mut client_state = ClientState::for_test(); + + client_state.add_resources(&[ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + client_state.add_resources(&[cidr_bar_id("11.0.0.0/24")]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset([cidr_bar_id("11.0.0.0/24"), cidr_foo_resource("10.0.0.0/24")].iter()) + ); + assert_eq!( + hashset(client_state.routes()), + expected_routes(vec![ + IpNetwork::from_str("10.0.0.0/24").unwrap(), + IpNetwork::from_str("11.0.0.0/24").unwrap() + ]) + ); + } + + #[test] + fn remove_resources_works() { + let mut client_state = ClientState::for_test(); + + client_state.add_resources(&[ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + client_state.remove_resources(&[cidr_foo_id()]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset([dns_bar_resource("baz.com")].iter()) + ); + assert_eq!(hashset(client_state.routes()), expected_routes(vec![])); + } + + #[test] + fn set_resource_works() { + let mut client_state = ClientState::for_test(); + + client_state.set_resources(vec![ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset( + [ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com") + ] + .iter() + ) + ); + assert_eq!( + HashSet::::from_iter(client_state.routes()), + expected_routes(vec![IpNetwork::from_str("10.0.0.0/24").unwrap()]) + ); + } + + #[test] + fn set_resource_replaces_old_resources() { + let mut client_state = ClientState::for_test(); + + client_state.set_resources(vec![ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + client_state.set_resources(vec![cidr_baz_resource("11.0.0.0/24")]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset([cidr_baz_resource("11.0.0.0/24")].iter()) + ); + assert_eq!( + hashset(client_state.routes()), + expected_routes(vec![IpNetwork::from_str("11.0.0.0/24").unwrap()]) + ); + } + + #[test] + fn set_resource_updates_old_resource_with_same_id() { + let mut client_state = ClientState::for_test(); + + client_state.set_resources(vec![ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + client_state.set_resources(vec![cidr_foo_resource("11.0.0.0/24")]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset([cidr_foo_resource("11.0.0.0/24")].iter()) + ); + assert_eq!( + hashset(client_state.routes()), + expected_routes(vec![IpNetwork::from_str("11.0.0.0/24").unwrap()]) + ); + } + + #[test] + fn set_resource_keeps_resource_if_unchanged() { + let mut client_state = ClientState::for_test(); + + client_state.set_resources(vec![ + cidr_foo_resource("10.0.0.0/24"), + dns_bar_resource("baz.com"), + ]); + client_state.set_resources(vec![cidr_foo_resource("10.0.0.0/24")]); + + assert_eq!( + hashset(client_state.resources().iter()), + hashset([cidr_foo_resource("10.0.0.0/24")].iter()) + ); + assert_eq!( + hashset(client_state.routes()), + expected_routes(vec![IpNetwork::from_str("10.0.0.0/24").unwrap()]) + ); + } + impl ClientState { fn for_test() -> ClientState { ClientState::new(StaticSecret::random_from_rng(OsRng)) @@ -1307,7 +1535,68 @@ mod tests { }) } + fn cidr_foo_resource(addr: &str) -> ResourceDescription { + ResourceDescription::Cidr(ResourceDescriptionCidr { + id: cidr_foo_id(), + address: addr.parse().unwrap(), + name: "foo".to_string(), + }) + } + + fn cidr_bar_id(addr: &str) -> ResourceDescription { + ResourceDescription::Cidr(ResourceDescriptionCidr { + id: dns_bar_id(), + address: addr.parse().unwrap(), + name: "foo".to_string(), + }) + } + + fn dns_bar_resource(addr: &str) -> ResourceDescription { + ResourceDescription::Dns(ResourceDescriptionDns { + id: dns_bar_id(), + address: addr.to_string(), + name: "bar".to_string(), + }) + } + + fn cidr_baz_resource(addr: &str) -> ResourceDescription { + ResourceDescription::Cidr(ResourceDescriptionCidr { + id: cidr_baz_id(), + address: addr.parse().unwrap(), + name: "baz".to_string(), + }) + } + + fn cidr_foo_id() -> ResourceId { + resource_id("fb51081a-2e06-4b59-b5a8-33592de9ebb1") + } + + fn cidr_baz_id() -> ResourceId { + resource_id("4e0bf4ea-4175-4cdb-a7c2-cbeffa8ccc5d") + } + + fn dns_bar_id() -> ResourceId { + resource_id("868483b6-431e-484d-bdd6-dad60ed26418") + } + fn ip(addr: &str) -> IpAddr { addr.parse().unwrap() } + + fn resource_id(id: &str) -> ResourceId { + id.parse().unwrap() + } + + fn expected_routes(resource_routes: Vec) -> HashSet { + HashSet::from_iter( + resource_routes + .into_iter() + .chain(iter::once(IpNetwork::from_str(IPV4_RESOURCES).unwrap())) + .chain(iter::once(IpNetwork::from_str(IPV6_RESOURCES).unwrap())), + ) + } + + fn hashset(val: impl Iterator) -> HashSet { + HashSet::from_iter(val) + } }