From 2af8d6096c688e78807846a9eec90ca07df536b9 Mon Sep 17 00:00:00 2001 From: Gabi Date: Tue, 9 Jan 2024 18:08:07 -0300 Subject: [PATCH] fix(connlib): mangle packet for upstream dns as resource (#3134) Fixes #3027 Left a few TODO, will solve it when doing #3123 Draft because we're still testing but it's almost ready --- rust/connlib/clients/shared/src/control.rs | 5 +- rust/connlib/tunnel/src/client.rs | 58 +++++++++++++++-- .../tunnel/src/control_protocol/client.rs | 4 ++ rust/connlib/tunnel/src/lib.rs | 4 +- rust/connlib/tunnel/src/peer.rs | 63 +++++++++++++++++-- 5 files changed, 120 insertions(+), 14 deletions(-) diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 5e96b5fe8..d6d9a25f6 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -100,8 +100,6 @@ impl ControlPlane { return Err(e); } else { *init = true; - *self.fallback_resolver.lock() = - create_resolver(interface.upstream_dns, self.tunnel.callbacks()); tracing::info!("Firezone Started!"); } } else { @@ -109,6 +107,9 @@ impl ControlPlane { } } + self.tunnel.set_upstream_dns(&interface.upstream_dns); + *self.fallback_resolver.lock() = + create_resolver(interface.upstream_dns, self.tunnel.callbacks()); for resource_description in resources { self.add_resource(resource_description); } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index a91801096..863563534 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -9,8 +9,8 @@ use crate::{ use boringtun::x25519::{PublicKey, StaticSecret}; use connlib_shared::error::{ConnlibError as Error, ConnlibError}; use connlib_shared::messages::{ - GatewayId, Interface as InterfaceConfig, Key, ResourceDescription, ResourceDescriptionCidr, - ResourceDescriptionDns, ResourceId, ReuseConnection, SecretKey, + DnsServer, GatewayId, Interface as InterfaceConfig, Key, ResourceDescription, + ResourceDescriptionCidr, ResourceDescriptionDns, ResourceId, ReuseConnection, SecretKey, }; use connlib_shared::{Callbacks, Dname, DNS_SENTINEL}; use domain::base::Rtype; @@ -25,7 +25,7 @@ use itertools::Itertools; use rand_core::OsRng; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet, VecDeque}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -109,6 +109,30 @@ where Ok(()) } + /// Sets the dns server. + pub fn set_upstream_dns(&self, upstream_dns: &[DnsServer]) { + let upstream_dns: HashSet<_> = upstream_dns + .iter() + .map(|dns| { + let DnsServer::IpPort(dns) = dns; + dns.address + }) + .collect(); + + // TODO: assuming single dns + let Some(upstream_dns_srv) = upstream_dns.iter().next().cloned() else { + return; + }; + + self.role_state.lock().upstream_dns = upstream_dns; + + self.role_state + .lock() + .peers_by_ip + .iter() + .for_each(|p| p.1.inner.transform.set_dns(upstream_dns_srv.ip())); + } + /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] pub fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> { @@ -199,6 +223,8 @@ pub struct ClientState { pub ip_provider: IpProvider, refresh_dns_timer: Interval, + + pub upstream_dns: HashSet, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -217,7 +243,7 @@ impl ClientState { pub(crate) fn handle_dns<'a>( &mut self, packet: MutableIpPacket<'a>, - ) -> Result>, MutableIpPacket<'a>> { + ) -> Result>, (MutableIpPacket<'a>, IpAddr)> { match dns::parse( &self.dns_resources, &self.dns_resources_internal_ips, @@ -225,6 +251,19 @@ impl ClientState { ) { Some(dns::ResolveStrategy::LocalResponse(query)) => Ok(Some(query)), Some(dns::ResolveStrategy::ForwardQuery(query)) => { + // There's an edge case here, where the resolver's ip has been resolved before as + // a dns resource... we will ignore that weird case for now. + // Assuming a single upstream dns until #3123 lands + if let Some(upstream_dns) = self.upstream_dns.iter().next() { + if self + .cidr_resources + .longest_match(upstream_dns.ip()) + .is_some() + { + return Err((packet, upstream_dns.ip())); + } + } + self.add_pending_dns_query(query); Ok(None) @@ -236,7 +275,10 @@ impl ClientState { Ok(None) } - None => Err(packet), + None => { + let dest = packet.destination(); + Err((packet, dest)) + } } } @@ -694,6 +736,7 @@ impl Default for ClientState { peers_by_ip: IpNetworkTable::new(), deferred_dns_queries: Default::default(), refresh_dns_timer: interval, + upstream_dns: Default::default(), } } } @@ -767,6 +810,11 @@ impl RoleState for ClientState { if self.refresh_dns_timer.poll_tick(cx).is_ready() { let mut connections = Vec::new(); + + self.peers_by_ip + .iter() + .for_each(|p| p.1.inner.transform.expire_dns_track()); + for resource in self.dns_resources_internal_ips.keys() { let Some(gateway_id) = self.resources_gateways.get(&resource.id) else { continue; diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index fa54b9080..b97491f6d 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -164,6 +164,10 @@ where let (peer_sender, peer_receiver) = tokio::sync::mpsc::channel(PEER_QUEUE_SIZE); + if let Some(upstream_dns) = self.role_state.lock().upstream_dns.iter().next() { + peer.transform.set_dns(upstream_dns.ip()); + } + start_handlers( Arc::clone(self), Arc::clone(&self.device), diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 4e20707ce..708481ee9 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -213,7 +213,7 @@ where tracing::trace!(target: "wire", action = "read", from = "device", dest = %packet.destination()); - let packet = match role_state.handle_dns(packet) { + let (packet, dest) = match role_state.handle_dns(packet) { Ok(Some(response)) => { device.write(response)?; continue; @@ -222,8 +222,6 @@ where Err(non_dns_packet) => non_dns_packet, }; - let dest = packet.destination(); - let Some(peer) = peer_by_ip(&role_state.peers_by_ip, dest) else { role_state.on_connection_intent_ip(dest); continue; diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 5d713d434..ce64493af 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,14 +1,17 @@ use std::borrow::Cow; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::net::IpAddr; use std::sync::Arc; +use std::time::Instant; +use arc_swap::ArcSwapOption; use bimap::BiMap; use boringtun::noise::rate_limiter::RateLimiter; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::StaticSecret; use bytes::Bytes; use chrono::{DateTime, Utc}; +use connlib_shared::DNS_SENTINEL; use connlib_shared::{messages::ResourceDescription, Error, Result}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; @@ -22,6 +25,10 @@ use crate::{device_channel, ip_packet::MutableIpPacket, PeerConfig}; type ExpiryingResource = (ResourceDescription, DateTime); +// The max time a dns request can be configured to live in resolvconf +// is 30 seconds. See resolvconf(5) timeout. +const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); + pub(crate) struct Peer { tunnel: Mutex, allowed_ips: RwLock>, @@ -194,6 +201,9 @@ impl Default for PacketTransformGateway { #[derive(Default)] pub struct PacketTransformClient { translations: RwLock>, + // TODO: Upstream dns could be something that's not an ip + upstream_dns: ArcSwapOption, + mangled_dns_ids: Mutex>, } impl PacketTransformClient { @@ -212,6 +222,16 @@ impl PacketTransformClient { translations.insert(proxy_ip, *ip); Some(proxy_ip) } + + pub fn expire_dns_track(&self) { + self.mangled_dns_ids + .lock() + .retain(|_, exp| exp.elapsed() < IDS_EXPIRE); + } + + pub fn set_dns(&self, ip_addr: IpAddr) { + self.upstream_dns.store(Some(Arc::new(ip_addr))); + } } impl PacketTransformGateway { @@ -274,16 +294,37 @@ impl PacketTransform for PacketTransformClient { packet: &'a mut [u8], ) -> Result<(device_channel::Packet<'a>, IpAddr)> { let translations = self.translations.read(); - let src = translations.get_by_right(addr).unwrap_or(addr); + let mut src = *translations.get_by_right(addr).unwrap_or(addr); let Some(mut pkt) = MutableIpPacket::new(packet) else { return Err(Error::BadPacket); }; - pkt.set_src(*src); + let original_src = src; + if self + .upstream_dns + .load() + .as_ref() + .is_some_and(|upstream| **upstream == src) + { + if let Some(dgm) = pkt.as_udp() { + if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) { + if self + .mangled_dns_ids + .lock() + .remove(&message.header().id()) + .is_some_and(|exp| exp.elapsed() < IDS_EXPIRE) + { + src = DNS_SENTINEL.into(); + } + } + } + } + + pkt.set_src(src); pkt.update_checksum(); let packet = make_packet(packet, addr); - Ok((packet, *src)) + Ok((packet, original_src)) } fn packet_transform<'a>(&self, mut packet: MutableIpPacket<'a>) -> Option> { @@ -292,6 +333,20 @@ impl PacketTransform for PacketTransformClient { packet.update_checksum(); } + if packet.destination() == DNS_SENTINEL { + if let Some(ip) = self.upstream_dns.load().as_ref() { + if let Some(dgm) = packet.as_udp() { + if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) { + self.mangled_dns_ids + .lock() + .insert(message.header().id(), Instant::now()); + packet.set_dst(**ip); + packet.update_checksum(); + } + } + } + } + Some(packet) } }