From 14abda01fdc4e72563aa5823da836afae635facc Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 16 Jul 2024 09:56:48 +1000 Subject: [PATCH] refactor(connlib): polish DNS resource matching (#5866) In preparation for implementing #5056, I familiarized myself with the current code and ended up implementing a couple of refactorings. --- rust/connlib/shared/src/messages.rs | 3 +- rust/connlib/tunnel/src/client.rs | 7 +- rust/connlib/tunnel/src/dns.rs | 518 +++++++++------------------- 3 files changed, 173 insertions(+), 355 deletions(-) diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index 8d193542e..9b4f86ff7 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -44,8 +44,7 @@ impl ResourceId { ResourceId(Uuid::new_v4()) } - #[cfg(feature = "proptest")] - pub(crate) fn from_u128(v: u128) -> Self { + pub fn from_u128(v: u128) -> Self { Self(Uuid::from_u128(v)) } } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 58c8620e1..c3b42c8ee 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -756,10 +756,7 @@ impl ClientState { .longest_match(destination) .map(|(_, res)| res.id); - let maybe_dns_resource_id = self - .stub_resolver - .get_description(&destination) - .map(|r| r.id); + let maybe_dns_resource_id = self.stub_resolver.resolve_resource_by_ip(&destination); maybe_cidr_resource_id.or(maybe_dns_resource_id) } @@ -927,7 +924,7 @@ impl ClientState { match &resource_description { ResourceDescription::Dns(dns) => { - self.stub_resolver.add_resource(dns); + self.stub_resolver.add_resource(dns.id, dns.address.clone()); } ResourceDescription::Cidr(cidr) => { let existing = self.cidr_resources.insert(cidr.address, cidr.clone()); diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index a2e83a5ea..81a99f096 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,5 +1,4 @@ use crate::client::IpProvider; -use connlib_shared::messages::client::ResourceDescriptionDns; use connlib_shared::messages::{DnsServer, ResourceId}; use connlib_shared::DomainName; use domain::base::RelativeName; @@ -26,13 +25,14 @@ const REVERSE_DNS_ADDRESS_V4: &str = "in-addr"; const REVERSE_DNS_ADDRESS_V6: &str = "ip6"; const DNS_PORT: u16 = 53; -/// Tells the Client how to reply to a single DNS query -#[derive(Debug)] -pub(crate) enum ResolveStrategy<'a> { - /// The query is for a Resource, we have an IP mapped already, and we can respond instantly - LocalResponse(IpPacket<'static>), - /// The query is for a non-Resource, forward it to an upstream or system resolver - ForwardQuery(DnsQuery<'a>), +pub struct StubResolver { + fqdn_to_ips: HashMap>, + ips_to_fqdn: HashMap, + ip_provider: IpProvider, + /// All DNS resources we know about, indexed by their domain (could be wildcard domain like `*.mycompany.com`). + dns_resources: HashMap, + /// Fixed dns name that will be resolved to fixed ip addrs, similar to /etc/hosts + known_hosts: KnownHosts, } #[derive(Debug)] @@ -44,33 +44,13 @@ pub struct DnsQuery<'a> { pub query: ip_packet::IpPacket<'a>, } -impl<'a> DnsQuery<'a> { - pub(crate) fn into_owned(self) -> DnsQuery<'static> { - let Self { - name, - record_type, - query, - } = self; - let buf = query.packet().to_vec(); - let query = ip_packet::IpPacket::owned(buf) - .expect("We are constructing the ip packet from an ip packet"); - - DnsQuery { - name, - record_type, - query, - } - } -} - -impl Clone for DnsQuery<'static> { - fn clone(&self) -> Self { - Self { - name: self.name.clone(), - record_type: self.record_type, - query: self.query.clone(), - } - } +/// Tells the Client how to reply to a single DNS query +#[derive(Debug)] +pub(crate) enum ResolveStrategy<'a> { + /// The query is for a Resource, we have an IP mapped already, and we can respond instantly + LocalResponse(IpPacket<'static>), + /// The query is for a non-Resource, forward it to an upstream or system resolver + ForwardQuery(DnsQuery<'a>), } struct KnownHosts { @@ -116,39 +96,6 @@ impl KnownHosts { } } -pub struct StubResolver { - fqdn_to_ips: HashMap>, - ips_to_fqdn: HashMap, - ip_provider: IpProvider, - /// All DNS resources we know about, indexed by their domain (could be wildcard domain like `*.mycompany.com`). - dns_resources: HashMap, - /// Fixed dns name that will be resolved to fixed ip addrs, similar to /etc/hosts - known_hosts: KnownHosts, -} - -fn fqdn_to_ips_for_known_hosts( - hosts: &HashMap>, -) -> HashMap> { - hosts - .iter() - .filter_map(|(d, a)| DomainName::vec_from_str(d).ok().map(|d| (d, a.clone()))) - .collect() -} - -fn ips_to_fqdn_for_known_hosts( - hosts: &HashMap>, -) -> HashMap { - hosts - .iter() - .filter_map(|(d, a)| { - DomainName::vec_from_str(d) - .ok() - .map(|d| a.iter().map(move |a| (*a, d.clone()))) - }) - .flatten() - .collect() -} - impl StubResolver { pub(crate) fn new(known_hosts: HashMap>) -> StubResolver { StubResolver { @@ -160,30 +107,33 @@ impl StubResolver { } } - pub(crate) fn get_description(&self, ip: &IpAddr) -> Option { - let name = self.ips_to_fqdn.get(ip)?; - get_description(name, &self.dns_resources) + /// Attempts to resolve an IP to a given resource. + /// + /// Semantically, this is like a PTR query, i.e. we check whether we handed out this IP as part of answering a DNS query for one of our resources. + /// This is in the hot-path of packet routing and must be fast! + pub(crate) fn resolve_resource_by_ip(&self, ip: &IpAddr) -> Option { + let (_, resource_id) = self.ips_to_fqdn.get(ip)?; + + Some(*resource_id) } pub(crate) fn get_fqdn(&self, ip: &IpAddr) -> Option<(&DomainName, &Vec)> { - let fqdn = self.ips_to_fqdn.get(ip)?; + let (fqdn, _) = self.ips_to_fqdn.get(ip)?; Some((fqdn, self.fqdn_to_ips.get(fqdn).unwrap())) } - pub(crate) fn add_resource(&mut self, resource: &ResourceDescriptionDns) { - let existing = self - .dns_resources - .insert(resource.address.clone(), resource.clone()); + pub(crate) fn add_resource(&mut self, id: ResourceId, address: String) { + let existing = self.dns_resources.insert(address.clone(), id); if existing.is_none() { - tracing::info!(address = %resource.address, "Activating DNS resource"); + tracing::info!(%address, "Activating DNS resource"); } } pub(crate) fn remove_resource(&mut self, id: ResourceId) { - self.dns_resources.retain(|_, r| { - if r.id == id { - tracing::info!(address = %r.address, "Deactivating DNS resource"); + self.dns_resources.retain(|address, r| { + if *r == id { + tracing::info!(%address, "Deactivating DNS resource"); return false; } @@ -194,18 +144,20 @@ impl StubResolver { fn get_or_assign_a_records( &mut self, fqdn: DomainName, + resource_id: ResourceId, ) -> Vec, DomainName>> { - to_a_records(self.get_or_assign_ips(fqdn).into_iter()) + to_a_records(self.get_or_assign_ips(fqdn, resource_id).into_iter()) } fn get_or_assign_aaaa_records( &mut self, fqdn: DomainName, + resource_id: ResourceId, ) -> Vec, DomainName>> { - to_aaaa_records(self.get_or_assign_ips(fqdn).into_iter()) + to_aaaa_records(self.get_or_assign_ips(fqdn, resource_id).into_iter()) } - fn get_or_assign_ips(&mut self, fqdn: DomainName) -> Vec { + fn get_or_assign_ips(&mut self, fqdn: DomainName, resource_id: ResourceId) -> Vec { let ips = self .fqdn_to_ips .entry(fqdn.clone()) @@ -218,14 +170,14 @@ impl StubResolver { }) .clone(); for ip in &ips { - self.ips_to_fqdn.insert(*ip, fqdn.clone()); + self.ips_to_fqdn.insert(*ip, (fqdn.clone(), resource_id)); } ips } - fn is_fqdn_resource(&self, domain_name: &DomainName) -> bool { - get_description(domain_name, &self.dns_resources).is_some() + fn match_resource(&self, domain_name: &DomainName) -> Option { + match_domain(domain_name, &self.dns_resources) } fn resource_address_name_by_reservse_dns( @@ -233,8 +185,9 @@ impl StubResolver { reverse_dns_name: &DomainName, ) -> Option { let address = reverse_dns_addr(&reverse_dns_name.to_string())?; + let (domain, _) = self.ips_to_fqdn.get(&address)?; - self.ips_to_fqdn.get(&address).cloned() + Some(domain.clone()) } // TODO: we can save a few allocations here still @@ -273,14 +226,14 @@ impl StubResolver { ))); } - let resource_records = match qtype { - Rtype::A if self.is_fqdn_resource(&domain) => { - self.get_or_assign_a_records(domain.clone()) + let maybe_resource = self.match_resource(&domain); + + let resource_records = match (qtype, maybe_resource) { + (Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource), + (Rtype::AAAA, Some(resource)) => { + self.get_or_assign_aaaa_records(domain.clone(), resource) } - Rtype::AAAA if self.is_fqdn_resource(&domain) => { - self.get_or_assign_aaaa_records(domain.clone()) - } - Rtype::PTR => { + (Rtype::PTR, _) => { let fqdn = self.resource_address_name_by_reservse_dns(&domain)?; vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))] @@ -302,6 +255,35 @@ impl StubResolver { } } +impl<'a> DnsQuery<'a> { + pub(crate) fn into_owned(self) -> DnsQuery<'static> { + let Self { + name, + record_type, + query, + } = self; + let buf = query.packet().to_vec(); + let query = ip_packet::IpPacket::owned(buf) + .expect("We are constructing the ip packet from an ip packet"); + + DnsQuery { + name, + record_type, + query, + } + } +} + +impl Clone for DnsQuery<'static> { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + record_type: self.record_type, + query: self.query.clone(), + } + } +} + fn to_a_records(ips: impl Iterator) -> Vec, DomainName>> { ips.filter_map(get_v4) .map(domain::rdata::A::new) @@ -439,41 +421,47 @@ pub fn is_subdomain(name: &DomainName, resource: &str) -> bool { name == &resource } -fn get_description( - name: &DomainName, - dns_resources: &HashMap, -) -> Option { - if let Some(resource) = dns_resources.get(&name.to_string()) { - return Some(resource.clone()); +fn match_domain(name: &DomainName, resources: &HashMap) -> Option +where + T: Copy, +{ + // Safety: `?` is less than 254 bytes long. + const QUESTION_MARK: RelativeName<&'static [u8]> = + unsafe { RelativeName::from_octets_unchecked(b"\x01?") }; + // Safety: `*` is less than 254 bytes long. + const WILDCARD: RelativeName<&'static [u8]> = + unsafe { RelativeName::from_octets_unchecked(b"\x01*") }; + + // First, check for full match. + if let Some(resource) = resources.get(&name.to_string()) { + return Some(*resource); } - if let Some(resource) = dns_resources.get( - &RelativeName::>::from_octets(b"\x01?".as_ref().into()) - .ok()? - .chain(name) - .ok()? - .to_string(), - ) { - return Some(resource.clone()); + // Second, check for `?` matching this domain exactly. + let qm_dot_domain = QUESTION_MARK.chain(name).ok()?.to_string(); + if let Some(resource) = resources.get(&qm_dot_domain) { + return Some(*resource); } + // Third, check for `?` matching up to 1 parent. if let Some(parent) = name.parent() { - if let Some(resource) = dns_resources.get( - &RelativeName::>::from_octets(b"\x01?".as_ref().into()) - .ok()? - .chain(parent) - .ok()? - .to_string(), - ) { - return Some(resource.clone()); + let qm_dot_parent = QUESTION_MARK.chain(parent).ok()?.to_string(); + + if let Some(resource) = resources.get(&qm_dot_parent) { + return Some(*resource); } } - name.iter_suffixes().find_map(|n| { - dns_resources - .get(&RelativeName::wildcard_vec().chain(n).ok()?.to_string()) - .cloned() - }) + // Last, check for any wildcard domains, starting with the most specific one. + for suffix in name.iter_suffixes() { + let wildcard_dot_suffix = WILDCARD.chain(suffix).ok()?.to_string(); + + if let Some(resource) = resources.get(&wildcard_dot_suffix) { + return Some(*resource); + } + } + + None } fn reverse_dns_addr(name: &str) -> Option { @@ -523,65 +511,32 @@ fn get_v6(ip: IpAddr) -> Option { } } +fn fqdn_to_ips_for_known_hosts( + hosts: &HashMap>, +) -> HashMap> { + hosts + .iter() + .filter_map(|(d, a)| DomainName::vec_from_str(d).ok().map(|d| (d, a.clone()))) + .collect() +} + +fn ips_to_fqdn_for_known_hosts( + hosts: &HashMap>, +) -> HashMap { + hosts + .iter() + .filter_map(|(d, a)| { + DomainName::vec_from_str(d) + .ok() + .map(|d| a.iter().map(move |a| (*a, d.clone()))) + }) + .flatten() + .collect() +} + #[cfg(test)] -mod test { - use connlib_shared::{messages::client::ResourceDescriptionDns, DomainName}; - - use crate::dns::is_subdomain; - - use super::{get_description, reverse_dns_addr}; - use std::{collections::HashMap, net::Ipv4Addr}; - - fn foo() -> ResourceDescriptionDns { - serde_json::from_str( - r#"{ - "id": "c4bb3d79-afa7-4660-8918-06c38fda3a4a", - "address": "*.foo.com", - "name": "foo.com wildcard", - "address_description": "foo", - "gateway_groups": [{"id": "bf56f32d-7b2c-4f5d-a784-788977d014a4", "name": "test"}] - }"#, - ) - .unwrap() - } - - fn bar() -> ResourceDescriptionDns { - serde_json::from_str( - r#"{ - "id": "c4bb3d79-afa7-4660-8918-06c38fda3a4b", - "address": "*.bar.com", - "name": "bar.com wildcard", - "address_description": "bar", - "gateway_groups": [{"id": "bf56f32d-7b2c-4f5d-a784-788977d014a4", "name": "test"}] - }"#, - ) - .unwrap() - } - - fn baz() -> ResourceDescriptionDns { - serde_json::from_str( - r#"{ - "id": "c4bb3d79-afa7-4660-8918-06c38fda3a4c", - "address": "baz.com", - "name": "baz.com", - "address_description": "baz", - "gateway_groups": [{"id": "bf56f32d-7b2c-4f5d-a784-788977d014a4", "name": "test"}] - }"#, - ) - .unwrap() - } - - fn dns_resource_fixture() -> HashMap { - let mut dns_resources_fixture = HashMap::new(); - - dns_resources_fixture.insert("*.foo.com".to_string(), foo()); - - dns_resources_fixture.insert("?.bar.com".to_string(), bar()); - - dns_resources_fixture.insert("baz.com".to_string(), baz()); - - dns_resources_fixture - } +mod tests { + use super::*; #[test] fn reverse_dns_addr_works_v4() { @@ -638,198 +593,65 @@ mod test { #[test] fn wildcard_matching() { - let dns_resources_fixture = dns_resource_fixture(); + let resources = HashMap::from([("*.foo.com".to_string(), 0), ("*.com".to_string(), 1)]); - assert_eq!( - get_description( - &DomainName::vec_from_str("a.foo.com").unwrap(), - &dns_resources_fixture, - ) - .unwrap(), - foo(), - ); - - assert_eq!( - get_description( - &DomainName::vec_from_str("foo.com").unwrap(), - &dns_resources_fixture, - ) - .unwrap(), - foo(), - ); - - assert_eq!( - get_description( - &DomainName::vec_from_str("a.b.foo.com").unwrap(), - &dns_resources_fixture, - ) - .unwrap(), - foo(), - ); - - assert!(get_description( - &DomainName::vec_from_str("oo.com").unwrap(), - &dns_resources_fixture, - ) - .is_none(),); + assert_eq!(match_domain(&domain("a.foo.com"), &resources), Some(0)); + assert_eq!(match_domain(&domain("foo.com"), &resources), Some(0)); + assert_eq!(match_domain(&domain("a.b.foo.com"), &resources), Some(0)); + assert_eq!(match_domain(&domain("oo.com"), &resources), Some(1)); + assert_eq!(match_domain(&domain("oo.xyz"), &resources), None); } #[test] fn question_mark_matching() { - let dns_resources_fixture = dns_resource_fixture(); + let resources = HashMap::from([("?.bar.com".to_string(), 1)]); - assert_eq!( - get_description( - &DomainName::vec_from_str("a.bar.com").unwrap(), - &dns_resources_fixture, - ) - .unwrap(), - bar(), - ); - - assert_eq!( - get_description( - &DomainName::vec_from_str("bar.com").unwrap(), - &dns_resources_fixture, - ) - .unwrap(), - bar(), - ); - - assert!(get_description( - &DomainName::vec_from_str("a.b.bar.com").unwrap(), - &dns_resources_fixture, - ) - .is_none(),); + assert_eq!(match_domain(&domain("a.bar.com"), &resources), Some(1)); + assert_eq!(match_domain(&domain("bar.com"), &resources), Some(1)); + assert_eq!(match_domain(&domain("a.b.bar.com"), &resources), None); } #[test] fn exact_matching() { - let dns_resources_fixture = dns_resource_fixture(); + let resources = HashMap::from([("baz.com".to_string(), 2)]); - assert_eq!( - get_description( - &DomainName::vec_from_str("baz.com").unwrap(), - &dns_resources_fixture, - ) - .unwrap(), - baz(), - ); - - assert!(get_description( - &DomainName::vec_from_str("a.baz.com").unwrap(), - &dns_resources_fixture, - ) - .is_none()); - - assert!(get_description( - &DomainName::vec_from_str("a.b.baz.com").unwrap(), - &dns_resources_fixture, - ) - .is_none(),); + assert_eq!(match_domain(&domain("baz.com"), &resources), Some(2)); + assert_eq!(match_domain(&domain("a.baz.com"), &resources), None); + assert_eq!(match_domain(&domain("a.b.baz.com"), &resources), None); } #[test] fn exact_subdomain_match() { - assert!(is_subdomain( - &DomainName::vec_from_str("foo.com").unwrap(), - "foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("a.foo.com").unwrap(), - "foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("a.b.foo.com").unwrap(), - "foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("foo.com").unwrap(), - "a.foo.com" - )); + assert!(is_subdomain(&domain("foo.com"), "foo.com")); + assert!(!is_subdomain(&domain("a.foo.com"), "foo.com")); + assert!(!is_subdomain(&domain("a.b.foo.com"), "foo.com")); + assert!(!is_subdomain(&domain("foo.com"), "a.foo.com")); } #[test] fn wildcard_subdomain_match() { - assert!(is_subdomain( - &DomainName::vec_from_str("foo.com").unwrap(), - "*.foo.com" - )); - - assert!(is_subdomain( - &DomainName::vec_from_str("a.foo.com").unwrap(), - "*.foo.com" - )); - - assert!(is_subdomain( - &DomainName::vec_from_str("a.foo.com").unwrap(), - "*.a.foo.com" - )); - - assert!(is_subdomain( - &DomainName::vec_from_str("b.a.foo.com").unwrap(), - "*.a.foo.com" - )); - - assert!(is_subdomain( - &DomainName::vec_from_str("a.b.foo.com").unwrap(), - "*.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("afoo.com").unwrap(), - "*.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("b.afoo.com").unwrap(), - "*.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("bar.com").unwrap(), - "*.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("foo.com").unwrap(), - "*.a.foo.com" - )); + assert!(is_subdomain(&domain("foo.com"), "*.foo.com")); + assert!(is_subdomain(&domain("a.foo.com"), "*.foo.com")); + assert!(is_subdomain(&domain("a.foo.com"), "*.a.foo.com")); + assert!(is_subdomain(&domain("b.a.foo.com"), "*.a.foo.com")); + assert!(is_subdomain(&domain("a.b.foo.com"), "*.foo.com")); + assert!(!is_subdomain(&domain("afoo.com"), "*.foo.com")); + assert!(!is_subdomain(&domain("b.afoo.com"), "*.foo.com")); + assert!(!is_subdomain(&domain("bar.com"), "*.foo.com")); + assert!(!is_subdomain(&domain("foo.com"), "*.a.foo.com")); } #[test] fn question_mark_subdomain_match() { - assert!(is_subdomain( - &DomainName::vec_from_str("foo.com").unwrap(), - "?.foo.com" - )); + assert!(is_subdomain(&domain("foo.com"), "?.foo.com")); + assert!(is_subdomain(&domain("a.foo.com"), "?.foo.com")); + assert!(!is_subdomain(&domain("a.b.foo.com"), "?.foo.com")); + assert!(!is_subdomain(&domain("bar.com"), "?.foo.com")); + assert!(!is_subdomain(&domain("foo.com"), "?.a.foo.com")); + assert!(!is_subdomain(&domain("afoo.com"), "?.foo.com")); + } - assert!(is_subdomain( - &DomainName::vec_from_str("a.foo.com").unwrap(), - "?.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("a.b.foo.com").unwrap(), - "?.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("bar.com").unwrap(), - "?.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("foo.com").unwrap(), - "?.a.foo.com" - )); - - assert!(!is_subdomain( - &DomainName::vec_from_str("afoo.com").unwrap(), - "?.foo.com" - )); + fn domain(name: &str) -> DomainName { + DomainName::vec_from_str(name).unwrap() } }