diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 55c135fb0..5e96b5fe8 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -366,6 +366,20 @@ impl ControlPlane { } }); } + Ok(firezone_tunnel::Event::RefreshResources { connections }) => { + let mut control_signaler = self.phoenix_channel.clone(); + tokio::spawn(async move { + for connection in connections { + let resource_id = connection.resource_id; + if let Err(err) = control_signaler + .send_with_ref(EgressMessages::ReuseConnection(connection), resource_id) + .await + { + tracing::warn!(%resource_id, ?err, "failed to refresh resource dns: {err:#?}"); + } + } + }); + } Err(e) => { tracing::error!("Tunnel failed: {e}"); } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 2d474ddbb..b352a800e 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -29,7 +29,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::time::Instant; +use tokio::time::{Instant, Interval, MissedTickBehavior}; use webrtc::ice_transport::ice_candidate::RTCIceCandidate; // Using str here because Ipv4/6Network doesn't support `const` 🙃 @@ -197,6 +197,8 @@ pub struct ClientState { forwarded_dns_queries: BoundedQueue>, ip_provider: IpProvider, + + refresh_dns_timer: Interval, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -596,23 +598,35 @@ impl ClientState { .get(resource) .cloned() .unwrap_or(Vec::new()); - let addr_v4 = addrs.iter().copied().filter(IpAddr::is_ipv4); + // We collect here to eagerly filter so that `next` is not called more times than needed with the ip_provider + // that could cause an ip exhaustion. + // This is needed to get the length, since even if zip is run, only until addr_v4 is None, `next` is still being called in both elements + // so ip_provider consumes an extra ip and this could in the long-run consume all ips since this function is also called to refresh ip allocations.. + let addr_v4 = addrs.iter().copied().filter(IpAddr::is_ipv4).collect_vec(); + let len = addr_v4.len(); let internal_ips_v4 = internal_ips .iter() .copied() .filter_map(get_v4) .chain(&mut self.ip_provider.ipv4) .map(Into::::into) - .zip(addr_v4); + .zip(addr_v4.clone()) + .take(len); - let addr_v6 = addrs.iter().copied().filter(IpAddr::is_ipv6); + tracing::warn!("external_ips: {addr_v4:?}"); + tracing::warn!("internal_ips: {internal_ips:?}"); + + // Same note as for ipv4, though an exhaustion is not a realistic scenario. + let addr_v6 = addrs.iter().copied().filter(IpAddr::is_ipv6).collect_vec(); + let len = addr_v6.len(); let internal_ips_v6 = internal_ips .iter() .copied() .filter_map(get_v6) .chain(&mut self.ip_provider.ipv6) .map(Into::::into) - .zip(addr_v6); + .zip(addr_v6) + .take(len); internal_ips_v4.chain(internal_ips_v6).collect() } @@ -634,6 +648,11 @@ impl IpProvider { impl Default for ClientState { fn default() -> Self { + // With this single timer this might mean that some DNS are refreshed too often + // however... this also mean any resource is refresh within a 5 mins interval + // therefore, only the first time it's added that happens, after that it doesn't matter. + let mut interval = tokio::time::interval(Duration::from_secs(300)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); Self { active_candidate_receivers: StreamMap::new( Duration::from_secs(ICE_GATHERING_TIMEOUT_SECONDS), @@ -662,6 +681,7 @@ impl Default for ClientState { resource_ids: Default::default(), peers_by_ip: IpNetworkTable::new(), deferred_dns_queries: Default::default(), + refresh_dns_timer: interval, } } } @@ -733,6 +753,30 @@ impl RoleState for ClientState { Poll::Pending => {} } + if self.refresh_dns_timer.poll_tick(cx).is_ready() { + let mut connections = Vec::new(); + for resource in self.dns_resources_internal_ips.keys() { + let Some(gateway_id) = self.resources_gateways.get(&resource.id) else { + continue; + }; + // filter inactive connections + if !self + .peers_by_ip + .iter() + .any(|(_, p)| &p.inner.conn_id == gateway_id) + { + continue; + } + + connections.push(ReuseConnection { + resource_id: resource.id, + gateway_id: *gateway_id, + payload: Some(resource.address.clone()), + }); + } + return Poll::Ready(Event::RefreshResources { connections }); + } + return self.forwarded_dns_queries.poll(cx).map(Event::DnsQuery); } } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 72b99e7ac..4e20707ce 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -8,7 +8,10 @@ use boringtun::{ }; use bytes::Bytes; -use connlib_shared::{messages::Key, CallbackErrorFacade, Callbacks, Error}; +use connlib_shared::{ + messages::{Key, ReuseConnection}, + CallbackErrorFacade, Callbacks, Error, +}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use ip_packet::IpPacket; @@ -447,6 +450,9 @@ pub enum Event { connected_gateway_ids: HashSet, reference: usize, }, + RefreshResources { + connections: Vec, + }, DnsQuery(DnsQuery<'static>), }