From ecfa919bbcbf542c50194a8f4bc6e1f7c0bb3e9f Mon Sep 17 00:00:00 2001 From: Gabi Date: Fri, 22 Dec 2023 16:12:32 -0300 Subject: [PATCH] refactor(connlib): refresh dns addresses (#2994) Fix for #2956 this is achieved by refreshing access to every resource every 5 minutes. There's still an open question for this PR: When the gateway resolves an ip the gateway allows access to a DNS resource it resolves the address and allow access to that ip for that client. Right now, until the access for that resource doesn't expire that access isn't revoked. We could change it so that we require the client to refresh such access(with this PR those refresh queries are already being made every 5 minutes) every x minutes on top of the `expires_at` or we can keep `expires_at` as to mean "allow access until `expires_at` for whatever this resource resolves to". cc @jamilbk @AndrewDryga --- rust/connlib/clients/shared/src/control.rs | 14 ++++++ rust/connlib/tunnel/src/client.rs | 54 ++++++++++++++++++++-- rust/connlib/tunnel/src/lib.rs | 8 +++- 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 55c135fb0..5e96b5fe8 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -366,6 +366,20 @@ impl ControlPlane { } }); } + Ok(firezone_tunnel::Event::RefreshResources { connections }) => { + let mut control_signaler = self.phoenix_channel.clone(); + tokio::spawn(async move { + for connection in connections { + let resource_id = connection.resource_id; + if let Err(err) = control_signaler + .send_with_ref(EgressMessages::ReuseConnection(connection), resource_id) + .await + { + tracing::warn!(%resource_id, ?err, "failed to refresh resource dns: {err:#?}"); + } + } + }); + } Err(e) => { tracing::error!("Tunnel failed: {e}"); } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 2d474ddbb..b352a800e 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -29,7 +29,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::time::Instant; +use tokio::time::{Instant, Interval, MissedTickBehavior}; use webrtc::ice_transport::ice_candidate::RTCIceCandidate; // Using str here because Ipv4/6Network doesn't support `const` 🙃 @@ -197,6 +197,8 @@ pub struct ClientState { forwarded_dns_queries: BoundedQueue>, ip_provider: IpProvider, + + refresh_dns_timer: Interval, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -596,23 +598,35 @@ impl ClientState { .get(resource) .cloned() .unwrap_or(Vec::new()); - let addr_v4 = addrs.iter().copied().filter(IpAddr::is_ipv4); + // We collect here to eagerly filter so that `next` is not called more times than needed with the ip_provider + // that could cause an ip exhaustion. + // This is needed to get the length, since even if zip is run, only until addr_v4 is None, `next` is still being called in both elements + // so ip_provider consumes an extra ip and this could in the long-run consume all ips since this function is also called to refresh ip allocations.. + let addr_v4 = addrs.iter().copied().filter(IpAddr::is_ipv4).collect_vec(); + let len = addr_v4.len(); let internal_ips_v4 = internal_ips .iter() .copied() .filter_map(get_v4) .chain(&mut self.ip_provider.ipv4) .map(Into::::into) - .zip(addr_v4); + .zip(addr_v4.clone()) + .take(len); - let addr_v6 = addrs.iter().copied().filter(IpAddr::is_ipv6); + tracing::warn!("external_ips: {addr_v4:?}"); + tracing::warn!("internal_ips: {internal_ips:?}"); + + // Same note as for ipv4, though an exhaustion is not a realistic scenario. + let addr_v6 = addrs.iter().copied().filter(IpAddr::is_ipv6).collect_vec(); + let len = addr_v6.len(); let internal_ips_v6 = internal_ips .iter() .copied() .filter_map(get_v6) .chain(&mut self.ip_provider.ipv6) .map(Into::::into) - .zip(addr_v6); + .zip(addr_v6) + .take(len); internal_ips_v4.chain(internal_ips_v6).collect() } @@ -634,6 +648,11 @@ impl IpProvider { impl Default for ClientState { fn default() -> Self { + // With this single timer this might mean that some DNS are refreshed too often + // however... this also mean any resource is refresh within a 5 mins interval + // therefore, only the first time it's added that happens, after that it doesn't matter. + let mut interval = tokio::time::interval(Duration::from_secs(300)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); Self { active_candidate_receivers: StreamMap::new( Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), @@ -662,6 +681,7 @@ impl Default for ClientState { resource_ids: Default::default(), peers_by_ip: IpNetworkTable::new(), deferred_dns_queries: Default::default(), + refresh_dns_timer: interval, } } } @@ -733,6 +753,30 @@ impl RoleState for ClientState { Poll::Pending => {} } + if self.refresh_dns_timer.poll_tick(cx).is_ready() { + let mut connections = Vec::new(); + for resource in self.dns_resources_internal_ips.keys() { + let Some(gateway_id) = self.resources_gateways.get(&resource.id) else { + continue; + }; + // filter inactive connections + if !self + .peers_by_ip + .iter() + .any(|(_, p)| &p.inner.conn_id == gateway_id) + { + continue; + } + + connections.push(ReuseConnection { + resource_id: resource.id, + gateway_id: *gateway_id, + payload: Some(resource.address.clone()), + }); + } + return Poll::Ready(Event::RefreshResources { connections }); + } + return self.forwarded_dns_queries.poll(cx).map(Event::DnsQuery); } } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 72b99e7ac..4e20707ce 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -8,7 +8,10 @@ use boringtun::{ }; use bytes::Bytes; -use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error}; +use connlib_shared::{ + messages::{Key, ReuseConnection}, + CallbackErrorFacade, Callbacks, Error, +}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use ip_packet::IpPacket; @@ -447,6 +450,9 @@ pub enum Event { connected_gateway_ids: HashSet, reference: usize, }, + RefreshResources { + connections: Vec, + }, DnsQuery(DnsQuery<'static>), }