From 19bcaa95398bcf079d98758c8f1f13051ca0c7e7 Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Fri, 16 Feb 2024 11:29:31 +1100
Subject: [PATCH] refactor(connlib): move DNS resolution into tunnel (#3652)

Previously, the mapping between DNS sentinel IPs and upstream DNS
servers was not stored within the tunnel, so we had to perform the
resolution further up. This has changed: the tunnel itself now knows
about this mapping. Thus, we can easily move the actual DNS resolution
into the tunnel as well, reducing the API surface of `Tunnel` because
we no longer need the `write_dns_lookup_response` function.

This is crucial because it is the last place where `Tunnel` is being
cloned in #3391. With this out of the way, we can remove all `Arc`s and
locks from `Tunnel` as part of #3391.
---
 rust/Cargo.lock                            |   1 -
 rust/connlib/clients/shared/Cargo.toml     |   1 -
 rust/connlib/clients/shared/src/control.rs |  49 +-------
 rust/connlib/clients/shared/src/lib.rs     |   2 -
 rust/connlib/tunnel/Cargo.toml             |   2 +-
 rust/connlib/tunnel/src/bounded_queue.rs   |  59 ----------
 rust/connlib/tunnel/src/client.rs          | 105 +++++++++++++-----
 .../tunnel/src/control_protocol/client.rs  |   3 +-
 rust/connlib/tunnel/src/lib.rs             |  19 +++-
 9 files changed, 96 insertions(+), 145 deletions(-)
 delete mode 100644 rust/connlib/tunnel/src/bounded_queue.rs

diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index b910c14eb..70d831df0 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -1218,7 +1218,6 @@ dependencies = [
  "chrono",
  "connlib-shared",
  "firezone-tunnel",
- "hickory-resolver",
  "ip_network",
  "parking_lot",
  "reqwest",
diff --git a/rust/connlib/clients/shared/Cargo.toml b/rust/connlib/clients/shared/Cargo.toml
index 35bdd582f..76e894534 100644
--- a/rust/connlib/clients/shared/Cargo.toml
+++ b/rust/connlib/clients/shared/Cargo.toml
@@ -27,7 +27,6 @@ time = { version = "0.3.34", features = ["formatting"] }
 reqwest = { version = "0.11.22", default-features = false, features = ["stream", "rustls-tls"] }
 tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
 async-compression = { version = "0.4.6", features = ["tokio", "gzip"] }
-hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
 parking_lot = "0.12"
 bimap = "0.6"
 ip_network = { version = "0.4", default-features = false }
diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs
index 2f4dc0143..edc9f8311 100644
--- a/rust/connlib/clients/shared/src/control.rs
+++ b/rust/connlib/clients/shared/src/control.rs
@@ -6,7 +6,6 @@ use connlib_shared::control::Reason;
 use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
 use connlib_shared::IpProvider;
 use ip_network::IpNetwork;
-use std::collections::HashMap;
 use std::net::IpAddr;
 use std::path::PathBuf;
 use std::str::FromStr;
@@ -25,8 +24,6 @@ use connlib_shared::{
 };
 
 use firezone_tunnel::{ClientState, Request, Tunnel};
-use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
-use hickory_resolver::TokioAsyncResolver;
 use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE};
 use tokio::io::BufReader;
 use tokio::sync::Mutex;
@@ -41,11 +38,6 @@ pub struct ControlPlane {
     pub tunnel: Arc<Tunnel<CB, ClientState>>,
     pub phoenix_channel: PhoenixSenderWithTopic,
     pub tunnel_init: Mutex<bool>,
-    // It's a Mutex> because we need the init message to initialize the resolver
-    // also, in platforms with split DNS and no configured upstream dns this will be None.
-    //
-    // We could still initialize the resolver with no nameservers in those platforms...
-    pub fallback_resolver: parking_lot::Mutex<HashMap<IpAddr, TokioAsyncResolver>>,
 }
 
 fn effective_dns_servers(
@@ -76,22 +68,6 @@ fn effective_dns_servers(
         .collect()
 }
 
-fn create_resolvers(
-    sentinel_mapping: BiMap<IpAddr, DnsServer>,
-) -> HashMap<IpAddr, TokioAsyncResolver> {
-    sentinel_mapping
-        .iter()
-        .map(|(sentinel, srv)| {
-            let mut resolver_config = ResolverConfig::new();
-            resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
-            (
-                *sentinel,
-                TokioAsyncResolver::tokio(resolver_config, Default::default()),
-            )
-        })
-        .collect()
-}
-
 fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap<IpAddr, DnsServer> {
     let mut ip_provider = IpProvider::new(
         DNS_SENTINELS_V4.parse().unwrap(),
@@ -148,9 +124,6 @@ impl ControlPlane {
             for resource_description in resources {
                 self.add_resource(resource_description);
             }
-
-            // Note: watch out here we're holding 2 mutexes
-            *self.fallback_resolver.lock() = create_resolvers(sentinel_mapping);
         } else {
             tracing::info!("Firezone reinitializated");
         }
@@ -395,25 +368,6 @@ impl ControlPlane {
                     // TODO: Clean up connection in `ClientState` here?
                 }
             }
-            Ok(firezone_tunnel::Event::DnsQuery(query)) => {
-                // Until we handle it better on a gateway-like eventloop, making sure not to block the loop
-                let Some(resolver) = self
-                    .fallback_resolver
-                    .lock()
-                    .get(&query.query.destination())
-                    .cloned()
-                else {
-                    return;
-                };
-
-                let tunnel = self.tunnel.clone();
-                tokio::spawn(async move {
-                    let response = resolver.lookup(&query.name, query.record_type).await;
-                    if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) {
-                        tracing::debug!(err = ?err, name = query.name, record_type = ?query.record_type, "DNS lookup failed: {err:#}");
-                    }
-                });
-            }
             Ok(firezone_tunnel::Event::RefreshResources { connections }) => {
                 let mut control_signaler = self.phoenix_channel.clone();
                 tokio::spawn(async move {
@@ -428,6 +382,9 @@ impl ControlPlane {
                     }
                 });
             }
+            Ok(firezone_tunnel::Event::SendPacket(_)) => {
+                unimplemented!("Handled internally");
+            }
             Err(e) => {
                 tracing::error!("Tunnel failed: {e}");
             }
diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs
index f6159601f..04e104f58 100644
--- a/rust/connlib/clients/shared/src/lib.rs
+++ b/rust/connlib/clients/shared/src/lib.rs
@@ -12,7 +12,6 @@ use messages::IngressMessages;
 use messages::Messages;
 use messages::ReplyMessages;
 use secrecy::{Secret, SecretString};
-use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::time::{Interval, MissedTickBehavior};
@@ -178,7 +177,6 @@ where
         tunnel: Arc::new(tunnel),
         phoenix_channel: connection.sender_with_topic("client".to_owned()),
         tunnel_init: Mutex::new(false),
-        fallback_resolver: parking_lot::Mutex::new(HashMap::new()),
     };
 
     tokio::spawn({
diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml
index 285fa879f..cd99a5b83 100644
--- a/rust/connlib/tunnel/Cargo.toml
+++ b/rust/connlib/tunnel/Cargo.toml
@@ -26,7 +26,7 @@ boringtun = { workspace = true }
 chrono = { workspace = true }
 pnet_packet = { version = "0.34" }
 futures-bounded = { workspace = true }
-hickory-resolver = { workspace = true }
+hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
 arc-swap = "1.6.0"
 bimap = "0.6"
 resolv-conf = "0.7.0"
diff --git a/rust/connlib/tunnel/src/bounded_queue.rs b/rust/connlib/tunnel/src/bounded_queue.rs
deleted file mode 100644
index 898440fd3..000000000
--- a/rust/connlib/tunnel/src/bounded_queue.rs
+++ /dev/null
@@ -1,59 +0,0 @@
-use core::fmt;
-use std::{
-    collections::VecDeque,
-    task::{Context, Poll, Waker},
-};
-
-// Simple bounded queue for one-time events
-#[derive(Debug, Clone)]
-pub(crate) struct BoundedQueue<T> {
-    queue: VecDeque<T>,
-    limit: usize,
-    waker: Option<Waker>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub(crate) struct Full;
-
-impl fmt::Display for Full {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "Queue is full")
-    }
-}
-
-impl<T> BoundedQueue<T> {
-    pub(crate) fn with_capacity(cap: usize) -> BoundedQueue<T> {
-        BoundedQueue {
-            queue: VecDeque::with_capacity(cap),
-            limit: cap,
-            waker: None,
-        }
-    }
-
-    pub(crate) fn poll(&mut self, cx: &Context) -> Poll<T> {
-        if let Some(front) = self.queue.pop_front() {
-            return Poll::Ready(front);
-        }
-
-        self.waker = Some(cx.waker().clone());
-        Poll::Pending
-    }
-
-    fn at_capacity(&self) -> bool {
-        self.queue.len() == self.limit
-    }
-
-    pub(crate) fn push_back(&mut self, x: T) -> Result<(), Full> {
-        if self.at_capacity() {
-            return Err(Full);
-        }
-
-        self.queue.push_back(x);
-
-        if let Some(ref waker) = self.waker {
-            waker.wake_by_ref();
-        }
-
-        Ok(())
-    }
-}
diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs
index 8e115e4a9..da30c68b5 100644
--- a/rust/connlib/tunnel/src/client.rs
+++ b/rust/connlib/tunnel/src/client.rs
@@ -1,4 +1,3 @@
-use crate::bounded_queue::BoundedQueue;
 use crate::device_channel::{Device, Packet};
 use crate::ip_packet::{IpPacket, MutableIpPacket};
 use crate::peer::PacketTransformClient;
@@ -17,12 +16,13 @@ use connlib_shared::{Callbacks, Dname, IpProvider};
 use domain::base::Rtype;
 use futures::channel::mpsc::Receiver;
 use futures::stream;
-use futures_bounded::{FuturesMap, PushError, StreamMap};
-use hickory_resolver::lookup::Lookup;
+use futures_bounded::{FuturesMap, FuturesTupleSet, PushError, StreamMap};
 use ip_network::IpNetwork;
 use ip_network_table::IpNetworkTable;
 use itertools::Itertools;
 
+use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
+use hickory_resolver::TokioAsyncResolver;
 use rand_core::OsRng;
 use std::collections::hash_map::Entry;
 use std::collections::{HashMap, HashSet, VecDeque};
@@ -110,24 +110,6 @@ where
         Ok(())
     }
 
-    /// Writes the response to a DNS lookup
-    #[tracing::instrument(level = "trace", skip(self))]
-    pub fn write_dns_lookup_response(
-        &self,
-        response: hickory_resolver::error::ResolveResult<Lookup>,
-        query: IpPacket<'static>,
-    ) -> connlib_shared::Result<()> {
-        if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? {
-            let Some(ref device) = *self.device.load() else {
-                return Ok(());
-            };
-
-            device.write(pkt)?;
-        }
-
-        Ok(())
-    }
-
     /// Sets the interface configuration and starts background tasks.
     #[tracing::instrument(level = "trace", skip(self))]
     pub fn set_interface(
@@ -157,7 +139,7 @@ where
             return Err(errs.pop().unwrap());
         }
 
-        self.role_state.lock().dns_mapping = dns_mapping;
+        self.role_state.lock().set_dns_mapping(dns_mapping);
 
         let res_v4 = self.add_route(IPV4_RESOURCES.parse().unwrap());
         let res_v6 = self.add_route(IPV6_RESOURCES.parse().unwrap());
@@ -228,13 +210,17 @@ pub struct ClientState {
     #[allow(clippy::type_complexity)]
     pub peers_by_ip: IpNetworkTable<ConnectedPeer<GatewayId, PacketTransformClient>>,
 
-    forwarded_dns_queries: BoundedQueue<DnsQuery<'static>>,
+    forwarded_dns_queries: FuturesTupleSet<
+        Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
+        DnsQuery<'static>,
+    >,
 
     pub ip_provider: IpProvider,
 
     refresh_dns_timer: Interval,
 
-    pub dns_mapping: BiMap<IpAddr, DnsServer>,
+    dns_mapping: BiMap<IpAddr, DnsServer>,
+    dns_resolvers: HashMap<IpAddr, TokioAsyncResolver>,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -579,6 +565,15 @@ impl ClientState {
         }
     }
 
+    pub fn set_dns_mapping(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
+        self.dns_mapping = mapping.clone();
+        self.dns_resolvers = create_resolvers(mapping);
+    }
+
+    pub fn dns_mapping(&self) -> BiMap<IpAddr, DnsServer> {
+        self.dns_mapping.clone()
+    }
+
     fn is_awaiting_connection_to_cidr(&self, destination: IpAddr) -> bool {
         let Some(resource) = self.get_cidr_resource_by_destination(destination) else {
             return false;
@@ -633,16 +628,48 @@ impl ClientState {
     }
 
     fn add_pending_dns_query(&mut self, query: DnsQuery) {
+        let upstream = query.query.destination();
+        let Some(resolver) = self.dns_resolvers.get(&upstream).cloned() else {
+            tracing::warn!(%upstream, "Dropping DNS query because of unknown upstream DNS server");
+            return;
+        };
+
+        let query = query.into_owned();
+
         if self
             .forwarded_dns_queries
-            .push_back(query.into_owned())
+            .try_push(
+                {
+                    let name = query.name.clone();
+                    let record_type = query.record_type;
+
+                    async move { resolver.lookup(&name, record_type).await }
+                },
+                query,
+            )
             .is_err()
         {
-            tracing::warn!("Too many DNS queries, dropping new ones");
+            tracing::warn!("Too many DNS queries, dropping existing one");
         }
     }
 }
 
+fn create_resolvers(
+    sentinel_mapping: BiMap<IpAddr, DnsServer>,
+) -> HashMap<IpAddr, TokioAsyncResolver> {
+    sentinel_mapping
+        .into_iter()
+        .map(|(sentinel, srv)| {
+            let mut resolver_config = ResolverConfig::new();
+            resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
+            (
+                sentinel,
+                TokioAsyncResolver::tokio(resolver_config, Default::default()),
+            )
+        })
+        .collect()
+}
+
 impl Default for ClientState {
     fn default() -> Self {
         // With this single timer this might mean that some DNS are refreshed too often
@@ -665,7 +692,10 @@ impl Default for ClientState {
             gateway_public_keys: Default::default(),
             resources_gateways: Default::default(),
 
-            forwarded_dns_queries: BoundedQueue::with_capacity(DNS_QUERIES_QUEUE_SIZE),
+            forwarded_dns_queries: FuturesTupleSet::new(
+                Duration::from_secs(60),
+                DNS_QUERIES_QUEUE_SIZE,
+            ),
             gateway_preshared_keys: Default::default(),
             // TODO: decide ip ranges
             ip_provider: IpProvider::new(
@@ -680,6 +710,7 @@ impl Default for ClientState {
             deferred_dns_queries: Default::default(),
             refresh_dns_timer: interval,
             dns_mapping: Default::default(),
+            dns_resolvers: Default::default(),
         }
     }
 }
@@ -780,7 +811,25 @@ impl RoleState for ClientState {
                 return Poll::Ready(Event::RefreshResources { connections });
             }
 
-            return self.forwarded_dns_queries.poll(cx).map(Event::DnsQuery);
+            match self.forwarded_dns_queries.poll_unpin(cx) {
+                Poll::Ready((Ok(response), query)) => {
+                    match dns::build_response_from_resolve_result(query.query, response) {
+                        Ok(Some(packet)) => return Poll::Ready(Event::SendPacket(packet)),
+                        Ok(None) => continue,
+                        Err(e) => {
+                            tracing::warn!("Failed to build DNS response from lookup result: {e}");
+                            continue;
+                        }
+                    }
+                }
+                Poll::Ready((Err(resolve_timeout), query)) => {
+                    tracing::warn!(name = %query.name, server = %query.query.destination(), "DNS query timed out: {resolve_timeout}");
+                    continue;
+                }
+                Poll::Pending => {}
+            }
+
+            return Poll::Pending;
         }
     }
diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs
index 8b93d5f26..8b452842d 100644
--- a/rust/connlib/tunnel/src/control_protocol/client.rs
+++ b/rust/connlib/tunnel/src/control_protocol/client.rs
@@ -164,8 +164,7 @@ where
 
         let (peer_sender, peer_receiver) = tokio::sync::mpsc::channel(PEER_QUEUE_SIZE);
 
-        peer.transform
-            .set_dns(self.role_state.lock().dns_mapping.clone());
+        peer.transform.set_dns(self.role_state.lock().dns_mapping());
 
         start_handlers(
             Arc::clone(self),
diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs
index 02e10fbd3..7c803b719 100644
--- a/rust/connlib/tunnel/src/lib.rs
+++ b/rust/connlib/tunnel/src/lib.rs
@@ -30,7 +30,7 @@ use webrtc::{
     interceptor::registry::Registry,
 };
 
-use arc_swap::ArcSwapOption;
+use arc_swap::{access::Access, ArcSwapOption};
 use futures_util::task::AtomicWaker;
 use std::task::{ready, Context, Poll};
 use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
@@ -56,7 +56,6 @@ use connlib_shared::messages::{ClientId, SecretKey};
 use device_channel::Device;
 use index::IndexLfsr;
 
-mod bounded_queue;
 mod client;
 mod control_protocol;
 mod device_channel;
@@ -365,8 +364,18 @@ where
             }
         }
 
-        if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {
-            return Poll::Ready(event);
+        match self.role_state.lock().poll_next_event(cx) {
+            Poll::Ready(Event::SendPacket(packet)) => {
+                let Some(device) = self.device.load().clone() else {
+                    continue;
+                };
+
+                let _ = device.write(packet);
+
+                continue;
+            }
+            Poll::Ready(other) => return Poll::Ready(other),
+            _ => (),
         }
 
         return Poll::Pending;
@@ -448,7 +457,7 @@ pub enum Event {
     RefreshResources {
         connections: Vec<ReuseConnection>,
     },
-    DnsQuery(DnsQuery<'static>),
+    SendPacket(device_channel::Packet<'static>),
 }
 
 impl Tunnel
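
The heart of this refactor is that `ClientState` now owns the in-flight DNS lookups in a
`futures_bounded::FuturesTupleSet` and surfaces finished lookups as `Event::SendPacket`,
which the tunnel writes to the device itself. The following is a minimal, self-contained
sketch of that `FuturesTupleSet` pattern (bounded concurrency, a per-future timeout, and
each result paired with the data needed to act on it), not connlib code: it assumes the
`futures-bounded` and `tokio` crates, and the toy lookup future and `u16` query id are
illustrative only.

use std::time::Duration;

use futures_bounded::FuturesTupleSet;

#[tokio::main]
async fn main() {
    // At most 100 lookups in flight; each one fails with a timeout error after 60 seconds,
    // mirroring `FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE)` above.
    let mut queries: FuturesTupleSet<String, u16> =
        FuturesTupleSet::new(Duration::from_secs(60), 100);

    // Push a toy lookup future together with the data needed to act on its result later
    // (connlib stores the original `DnsQuery`). `try_push` errors when the set is full;
    // the patch logs a warning and drops a query in that case.
    if queries
        .try_push(async { "93.184.216.34".to_owned() }, 1)
        .is_err()
    {
        eprintln!("too many DNS queries in flight, dropping one");
    }

    // Drain results from a single poll loop; an `Err` here means the future timed out.
    let (result, query_id) = std::future::poll_fn(|cx| queries.poll_unpin(cx)).await;
    match result {
        Ok(address) => println!("query {query_id} resolved to {address}"),
        Err(timeout) => println!("query {query_id} timed out: {timeout}"),
    }
}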