diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index b910c14eb..70d831df0 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -1218,7 +1218,6 @@ dependencies = [
  "chrono",
  "connlib-shared",
  "firezone-tunnel",
- "hickory-resolver",
  "ip_network",
  "parking_lot",
  "reqwest",
diff --git a/rust/connlib/clients/shared/Cargo.toml b/rust/connlib/clients/shared/Cargo.toml
index 35bdd582f..76e894534 100644
--- a/rust/connlib/clients/shared/Cargo.toml
+++ b/rust/connlib/clients/shared/Cargo.toml
@@ -27,7 +27,6 @@ time = { version = "0.3.34", features = ["formatting"] }
 reqwest = { version = "0.11.22", default-features = false, features = ["stream", "rustls-tls"] }
 tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
 async-compression = { version = "0.4.6", features = ["tokio", "gzip"] }
-hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
 parking_lot = "0.12"
 bimap = "0.6"
 ip_network = { version = "0.4", default-features = false }
diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs
index 2f4dc0143..edc9f8311 100644
--- a/rust/connlib/clients/shared/src/control.rs
+++ b/rust/connlib/clients/shared/src/control.rs
@@ -6,7 +6,6 @@ use connlib_shared::control::Reason;
 use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
 use connlib_shared::IpProvider;
 use ip_network::IpNetwork;
-use std::collections::HashMap;
 use std::net::IpAddr;
 use std::path::PathBuf;
 use std::str::FromStr;
@@ -25,8 +24,6 @@ use connlib_shared::{
 };
 
 use firezone_tunnel::{ClientState, Request, Tunnel};
-use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
-use hickory_resolver::TokioAsyncResolver;
 use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE};
 use tokio::io::BufReader;
 use tokio::sync::Mutex;
@@ -41,11 +38,6 @@ pub struct ControlPlane<CB: Callbacks> {
     pub tunnel: Arc<Tunnel<CB, ClientState>>,
     pub phoenix_channel: PhoenixSenderWithTopic,
     pub tunnel_init: Mutex<bool>,
-    // It's a Mutex because we need the init message to initialize the resolver
-    // also, in platforms with split DNS and no configured upstream dns this will be None.
-    //
-    // We could still initialize the resolver with no nameservers in those platforms...
-    pub fallback_resolver: parking_lot::Mutex<HashMap<IpAddr, TokioAsyncResolver>>,
 }
 
 fn effective_dns_servers(
@@ -76,22 +68,6 @@ fn effective_dns_servers(
         .collect()
 }
 
-fn create_resolvers(
-    sentinel_mapping: BiMap<IpAddr, DnsServer>,
-) -> HashMap<IpAddr, TokioAsyncResolver> {
-    sentinel_mapping
-        .iter()
-        .map(|(sentinel, srv)| {
-            let mut resolver_config = ResolverConfig::new();
-            resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
-            (
-                *sentinel,
-                TokioAsyncResolver::tokio(resolver_config, Default::default()),
-            )
-        })
-        .collect()
-}
-
 fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap<IpAddr, DnsServer> {
     let mut ip_provider = IpProvider::new(
         DNS_SENTINELS_V4.parse().unwrap(),
@@ -148,9 +124,6 @@ impl ControlPlane {
             for resource_description in resources {
                 self.add_resource(resource_description);
             }
-
-            // Note: watch out here we're holding 2 mutexes
-            *self.fallback_resolver.lock() = create_resolvers(sentinel_mapping);
         } else {
             tracing::info!("Firezone reinitializated");
         }
@@ -395,25 +368,6 @@ impl ControlPlane {
                     // TODO: Clean up connection in `ClientState` here?
                }
            }
-            Ok(firezone_tunnel::Event::DnsQuery(query)) => {
-                // Until we handle it better on a gateway-like eventloop, making sure not to block the loop
-                let Some(resolver) = self
-                    .fallback_resolver
-                    .lock()
-                    .get(&query.query.destination())
-                    .cloned()
-                else {
-                    return;
-                };
-
-                let tunnel = self.tunnel.clone();
-                tokio::spawn(async move {
-                    let response = resolver.lookup(&query.name, query.record_type).await;
-                    if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) {
-                        tracing::debug!(err = ?err, name = query.name, record_type = ?query.record_type, "DNS lookup failed: {err:#}");
-                    }
-                });
-            }
             Ok(firezone_tunnel::Event::RefreshResources { connections }) => {
                 let mut control_signaler = self.phoenix_channel.clone();
                 tokio::spawn(async move {
@@ -428,6 +382,9 @@
                    }
                });
            }
+            Ok(firezone_tunnel::Event::SendPacket(_)) => {
+                unimplemented!("Handled internally");
+            }
             Err(e) => {
                 tracing::error!("Tunnel failed: {e}");
             }
diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs
index f6159601f..04e104f58 100644
--- a/rust/connlib/clients/shared/src/lib.rs
+++ b/rust/connlib/clients/shared/src/lib.rs
@@ -12,7 +12,6 @@ use messages::IngressMessages;
 use messages::Messages;
 use messages::ReplyMessages;
 use secrecy::{Secret, SecretString};
-use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::time::{Interval, MissedTickBehavior};
@@ -178,7 +177,6 @@ where
         tunnel: Arc::new(tunnel),
         phoenix_channel: connection.sender_with_topic("client".to_owned()),
         tunnel_init: Mutex::new(false),
-        fallback_resolver: parking_lot::Mutex::new(HashMap::new()),
     };
 
     tokio::spawn({
diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml
index 285fa879f..cd99a5b83 100644
--- a/rust/connlib/tunnel/Cargo.toml
+++ b/rust/connlib/tunnel/Cargo.toml
@@ -26,7 +26,7 @@ boringtun = { workspace = true }
 chrono = { workspace = true }
 pnet_packet = { version = "0.34" }
 futures-bounded = { workspace = true }
-hickory-resolver = { workspace = true }
+hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
 arc-swap = "1.6.0"
 bimap = "0.6"
 resolv-conf = "0.7.0"
diff --git a/rust/connlib/tunnel/src/bounded_queue.rs b/rust/connlib/tunnel/src/bounded_queue.rs
deleted file mode 100644
index 898440fd3..000000000
--- a/rust/connlib/tunnel/src/bounded_queue.rs
+++ /dev/null
@@ -1,59 +0,0 @@
-use core::fmt;
-use std::{
-    collections::VecDeque,
-    task::{Context, Poll, Waker},
-};
-
-// Simple bounded queue for one-time events
-#[derive(Debug, Clone)]
-pub(crate) struct BoundedQueue<T> {
-    queue: VecDeque<T>,
-    limit: usize,
-    waker: Option<Waker>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub(crate) struct Full;
-
-impl fmt::Display for Full {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "Queue is full")
-    }
-}
-
-impl<T> BoundedQueue<T> {
-    pub(crate) fn with_capacity(cap: usize) -> BoundedQueue<T> {
-        BoundedQueue {
-            queue: VecDeque::with_capacity(cap),
-            limit: cap,
-            waker: None,
-        }
-    }
-
-    pub(crate) fn poll(&mut self, cx: &Context) -> Poll<T> {
-        if let Some(front) = self.queue.pop_front() {
-            return Poll::Ready(front);
-        }
-
-        self.waker = Some(cx.waker().clone());
-        Poll::Pending
-    }
-
-    fn at_capacity(&self) -> bool {
-        self.queue.len() == self.limit
-    }
-
-    pub(crate) fn push_back(&mut self, x: T) -> Result<(), Full> {
-        if self.at_capacity() {
-            return Err(Full);
-        }
-
-        self.queue.push_back(x);
-
-        if let Some(ref waker) = self.waker {
-            waker.wake_by_ref();
-        }
-
-        Ok(())
-    }
-}
diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs
index 8e115e4a9..da30c68b5 100644
--- a/rust/connlib/tunnel/src/client.rs
+++ b/rust/connlib/tunnel/src/client.rs
@@ -1,4 +1,3 @@
-use crate::bounded_queue::BoundedQueue;
 use crate::device_channel::{Device, Packet};
 use crate::ip_packet::{IpPacket, MutableIpPacket};
 use crate::peer::PacketTransformClient;
@@ -17,12 +16,13 @@ use connlib_shared::{Callbacks, Dname, IpProvider};
 use domain::base::Rtype;
 use futures::channel::mpsc::Receiver;
 use futures::stream;
-use futures_bounded::{FuturesMap, PushError, StreamMap};
-use hickory_resolver::lookup::Lookup;
+use futures_bounded::{FuturesMap, FuturesTupleSet, PushError, StreamMap};
 use ip_network::IpNetwork;
 use ip_network_table::IpNetworkTable;
 use itertools::Itertools;
 
+use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
+use hickory_resolver::TokioAsyncResolver;
 use rand_core::OsRng;
 use std::collections::hash_map::Entry;
 use std::collections::{HashMap, HashSet, VecDeque};
@@ -110,24 +110,6 @@ where
         Ok(())
     }
 
-    /// Writes the response to a DNS lookup
-    #[tracing::instrument(level = "trace", skip(self))]
-    pub fn write_dns_lookup_response(
-        &self,
-        response: hickory_resolver::error::ResolveResult<Lookup>,
-        query: IpPacket<'static>,
-    ) -> connlib_shared::Result<()> {
-        if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? {
-            let Some(ref device) = *self.device.load() else {
-                return Ok(());
-            };
-
-            device.write(pkt)?;
-        }
-
-        Ok(())
-    }
-
     /// Sets the interface configuration and starts background tasks.
     #[tracing::instrument(level = "trace", skip(self))]
     pub fn set_interface(
@@ -157,7 +139,7 @@ where
             return Err(errs.pop().unwrap());
         }
 
-        self.role_state.lock().dns_mapping = dns_mapping;
+        self.role_state.lock().set_dns_mapping(dns_mapping);
 
         let res_v4 = self.add_route(IPV4_RESOURCES.parse().unwrap());
         let res_v6 = self.add_route(IPV6_RESOURCES.parse().unwrap());
@@ -228,13 +210,17 @@ pub struct ClientState {
     #[allow(clippy::type_complexity)]
     pub peers_by_ip: IpNetworkTable<Arc<Peer<GatewayId, PacketTransformClient>>>,
 
-    forwarded_dns_queries: BoundedQueue<DnsQuery<'static>>,
+    forwarded_dns_queries: FuturesTupleSet<
+        Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
+        DnsQuery<'static>,
+    >,
 
     pub ip_provider: IpProvider,
 
     refresh_dns_timer: Interval,
 
-    pub dns_mapping: BiMap<IpAddr, DnsServer>,
+    dns_mapping: BiMap<IpAddr, DnsServer>,
+    dns_resolvers: HashMap<IpAddr, TokioAsyncResolver>,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -579,6 +565,15 @@ impl ClientState {
         }
     }
 
+    pub fn set_dns_mapping(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
+        self.dns_mapping = mapping.clone();
+        self.dns_resolvers = create_resolvers(mapping);
+    }
+
+    pub fn dns_mapping(&self) -> BiMap<IpAddr, DnsServer> {
+        self.dns_mapping.clone()
+    }
+
     fn is_awaiting_connection_to_cidr(&self, destination: IpAddr) -> bool {
         let Some(resource) = self.get_cidr_resource_by_destination(destination) else {
             return false;
@@ -633,16 +628,48 @@ impl ClientState {
     }
 
     fn add_pending_dns_query(&mut self, query: DnsQuery) {
+        let upstream = query.query.destination();
+        let Some(resolver) = self.dns_resolvers.get(&upstream).cloned() else {
+            tracing::warn!(%upstream, "Dropping DNS query because of unknown upstream DNS server");
+            return;
+        };
+
+        let query = query.into_owned();
+
         if self
             .forwarded_dns_queries
-            .push_back(query.into_owned())
+            .try_push(
+                {
+                    let name = query.name.clone();
+                    let record_type = query.record_type;
+
+                    async move { resolver.lookup(&name, record_type).await }
+                },
+                query,
+            )
             .is_err()
         {
-            tracing::warn!("Too many DNS queries, dropping new ones");
+            tracing::warn!("Too many DNS queries, dropping existing one");
         }
     }
 }
 
+fn create_resolvers(
+    sentinel_mapping: BiMap<IpAddr, DnsServer>,
+) -> HashMap<IpAddr, TokioAsyncResolver> {
+    sentinel_mapping
+        .into_iter()
+        .map(|(sentinel, srv)| {
+            let mut resolver_config = ResolverConfig::new();
+            resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
+            (
+                sentinel,
+                TokioAsyncResolver::tokio(resolver_config, Default::default()),
+            )
+        })
+        .collect()
+}
+
 impl Default for ClientState {
     fn default() -> Self {
         // With this single timer this might mean that some DNS are refreshed too often
@@ -665,7 +692,10 @@ impl Default for ClientState {
             gateway_public_keys: Default::default(),
             resources_gateways: Default::default(),
 
-            forwarded_dns_queries: BoundedQueue::with_capacity(DNS_QUERIES_QUEUE_SIZE),
+            forwarded_dns_queries: FuturesTupleSet::new(
+                Duration::from_secs(60),
+                DNS_QUERIES_QUEUE_SIZE,
+            ),
             gateway_preshared_keys: Default::default(),
             // TODO: decide ip ranges
             ip_provider: IpProvider::new(
@@ -680,6 +710,7 @@
             deferred_dns_queries: Default::default(),
             refresh_dns_timer: interval,
             dns_mapping: Default::default(),
+            dns_resolvers: Default::default(),
         }
     }
 }
@@ -780,7 +811,25 @@ impl RoleState for ClientState {
                return Poll::Ready(Event::RefreshResources { connections });
            }
 
-            return self.forwarded_dns_queries.poll(cx).map(Event::DnsQuery);
+            match self.forwarded_dns_queries.poll_unpin(cx) {
+                Poll::Ready((Ok(response), query)) => {
+                    match dns::build_response_from_resolve_result(query.query, response) {
+                        Ok(Some(packet)) => return Poll::Ready(Event::SendPacket(packet)),
+                        Ok(None) => continue,
+                        Err(e) => {
+                            tracing::warn!("Failed to build DNS response from lookup result: {e}");
+                            continue;
+                        }
+                    }
+                }
+                Poll::Ready((Err(resolve_timeout), query)) => {
+                    tracing::warn!(name = %query.name, server = %query.query.destination(), "DNS query timed out: {resolve_timeout}");
+                    continue;
+                }
+                Poll::Pending => {}
+            }
+
+            return Poll::Pending;
         }
     }
diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs
index 8b93d5f26..8b452842d 100644
--- a/rust/connlib/tunnel/src/control_protocol/client.rs
+++ b/rust/connlib/tunnel/src/control_protocol/client.rs
@@ -164,8 +164,7 @@ where
 
     let (peer_sender, peer_receiver) = tokio::sync::mpsc::channel(PEER_QUEUE_SIZE);
 
-    peer.transform
-        .set_dns(self.role_state.lock().dns_mapping.clone());
+    peer.transform.set_dns(self.role_state.lock().dns_mapping());
 
     start_handlers(
         Arc::clone(self),
diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs
index 02e10fbd3..7c803b719 100644
--- a/rust/connlib/tunnel/src/lib.rs
+++ b/rust/connlib/tunnel/src/lib.rs
@@ -30,7 +30,7 @@ use webrtc::{
     interceptor::registry::Registry,
 };
 
-use arc_swap::ArcSwapOption;
+use arc_swap::{access::Access, ArcSwapOption};
 use futures_util::task::AtomicWaker;
 use std::task::{ready, Context, Poll};
 use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
@@ -56,7 +56,6 @@ use connlib_shared::messages::{ClientId, SecretKey};
 use device_channel::Device;
 use index::IndexLfsr;
 
-mod bounded_queue;
 mod client;
 mod control_protocol;
 mod device_channel;
@@ -365,8 +364,18 @@ where
            }
        }
 
-        if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {
-            return Poll::Ready(event);
+        match self.role_state.lock().poll_next_event(cx) {
+            Poll::Ready(Event::SendPacket(packet)) => {
+                let Some(device) = self.device.load().clone() else {
+                    continue;
+                };
+
+                let _ = device.write(packet);
+
+                continue;
+            }
+            Poll::Ready(other) => return Poll::Ready(other),
+            _ => (),
         }
 
         return Poll::Pending;
@@ -448,7 +457,7 @@ pub enum Event {
     RefreshResources {
         connections: Vec<ReuseConnection>,
     },
-    DnsQuery(DnsQuery<'static>),
+    SendPacket(device_channel::Packet<'static>),
 }
 
 impl Tunnel
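For reviewers unfamiliar with futures_bounded, here is a minimal, self-contained sketch (not part of the PR) of the FuturesTupleSet pattern that ClientState now uses for forwarded DNS queries: a capacity-bounded set of futures with a per-future timeout, where each future carries a tag holding the data needed to handle its result (here a toy query id, in the PR the original DnsQuery). It only exercises the calls the diff itself relies on (new, try_push, poll_unpin); the tokio runtime, the fake lookup future, and the loop around poll_fn are illustrative assumptions.

// Sketch only: assumes the futures-bounded crate and a tokio runtime; the "lookup" is faked.
use std::future::poll_fn;
use std::time::Duration;

use futures_bounded::FuturesTupleSet;

#[tokio::main]
async fn main() {
    // Same constructor shape as the diff: a per-future timeout plus a bound on in-flight futures.
    let mut inflight: FuturesTupleSet<&'static str, u16> =
        FuturesTupleSet::new(Duration::from_secs(60), 2);

    // try_push fails once the set is at capacity; the PR logs a warning in that case.
    for id in 0..3u16 {
        if inflight
            .try_push(async move { "fake DNS answer" }, id)
            .is_err()
        {
            eprintln!("too many in-flight queries, could not enqueue query {id}");
        }
    }

    // poll_unpin yields (Ok(output) or Err(timeout), tag) as futures finish, which is the
    // shape ClientState::poll_next_event matches on before emitting Event::SendPacket.
    for _ in 0..2 {
        let (result, id) = poll_fn(|cx| inflight.poll_unpin(cx)).await;
        match result {
            Ok(answer) => println!("query {id} resolved: {answer}"),
            Err(timeout) => println!("query {id} timed out: {timeout}"),
        }
    }
}

Driving the lookups from the tunnel's own poll loop this way is what lets the PR delete the Mutex-guarded resolver map and the per-query tokio::spawn calls from the control plane.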