refactor(connlib): polish DNS resource matching (#5866)

In preparation for implementing #5056, I familiarized myself with the
current code and ended up implementing a couple of refactorings.
This commit is contained in:
Thomas Eizinger
2024-07-16 09:56:48 +10:00
committed by GitHub
parent 92a2a7852b
commit 14abda01fd
3 changed files with 173 additions and 355 deletions

View File

@@ -44,8 +44,7 @@ impl ResourceId {
ResourceId(Uuid::new_v4())
}
#[cfg(feature = "proptest")]
pub(crate) fn from_u128(v: u128) -> Self {
pub fn from_u128(v: u128) -> Self {
Self(Uuid::from_u128(v))
}
}

View File

@@ -756,10 +756,7 @@ impl ClientState {
.longest_match(destination)
.map(|(_, res)| res.id);
let maybe_dns_resource_id = self
.stub_resolver
.get_description(&destination)
.map(|r| r.id);
let maybe_dns_resource_id = self.stub_resolver.resolve_resource_by_ip(&destination);
maybe_cidr_resource_id.or(maybe_dns_resource_id)
}
@@ -927,7 +924,7 @@ impl ClientState {
match &resource_description {
ResourceDescription::Dns(dns) => {
self.stub_resolver.add_resource(dns);
self.stub_resolver.add_resource(dns.id, dns.address.clone());
}
ResourceDescription::Cidr(cidr) => {
let existing = self.cidr_resources.insert(cidr.address, cidr.clone());

View File

@@ -1,5 +1,4 @@
use crate::client::IpProvider;
use connlib_shared::messages::client::ResourceDescriptionDns;
use connlib_shared::messages::{DnsServer, ResourceId};
use connlib_shared::DomainName;
use domain::base::RelativeName;
@@ -26,13 +25,14 @@ const REVERSE_DNS_ADDRESS_V4: &str = "in-addr";
const REVERSE_DNS_ADDRESS_V6: &str = "ip6";
const DNS_PORT: u16 = 53;
/// Tells the Client how to reply to a single DNS query
#[derive(Debug)]
pub(crate) enum ResolveStrategy<'a> {
/// The query is for a Resource, we have an IP mapped already, and we can respond instantly
LocalResponse(IpPacket<'static>),
/// The query is for a non-Resource, forward it to an upstream or system resolver
ForwardQuery(DnsQuery<'a>),
pub struct StubResolver {
fqdn_to_ips: HashMap<DomainName, Vec<IpAddr>>,
ips_to_fqdn: HashMap<IpAddr, (DomainName, ResourceId)>,
ip_provider: IpProvider,
/// All DNS resources we know about, indexed by their domain (could be wildcard domain like `*.mycompany.com`).
dns_resources: HashMap<String, ResourceId>,
/// Fixed dns name that will be resolved to fixed ip addrs, similar to /etc/hosts
known_hosts: KnownHosts,
}
#[derive(Debug)]
@@ -44,33 +44,13 @@ pub struct DnsQuery<'a> {
pub query: ip_packet::IpPacket<'a>,
}
impl<'a> DnsQuery<'a> {
pub(crate) fn into_owned(self) -> DnsQuery<'static> {
let Self {
name,
record_type,
query,
} = self;
let buf = query.packet().to_vec();
let query = ip_packet::IpPacket::owned(buf)
.expect("We are constructing the ip packet from an ip packet");
DnsQuery {
name,
record_type,
query,
}
}
}
impl Clone for DnsQuery<'static> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
record_type: self.record_type,
query: self.query.clone(),
}
}
/// Tells the Client how to reply to a single DNS query
#[derive(Debug)]
pub(crate) enum ResolveStrategy<'a> {
/// The query is for a Resource, we have an IP mapped already, and we can respond instantly
LocalResponse(IpPacket<'static>),
/// The query is for a non-Resource, forward it to an upstream or system resolver
ForwardQuery(DnsQuery<'a>),
}
struct KnownHosts {
@@ -116,39 +96,6 @@ impl KnownHosts {
}
}
pub struct StubResolver {
fqdn_to_ips: HashMap<DomainName, Vec<IpAddr>>,
ips_to_fqdn: HashMap<IpAddr, DomainName>,
ip_provider: IpProvider,
/// All DNS resources we know about, indexed by their domain (could be wildcard domain like `*.mycompany.com`).
dns_resources: HashMap<String, ResourceDescriptionDns>,
/// Fixed dns name that will be resolved to fixed ip addrs, similar to /etc/hosts
known_hosts: KnownHosts,
}
fn fqdn_to_ips_for_known_hosts(
hosts: &HashMap<String, Vec<IpAddr>>,
) -> HashMap<DomainName, Vec<IpAddr>> {
hosts
.iter()
.filter_map(|(d, a)| DomainName::vec_from_str(d).ok().map(|d| (d, a.clone())))
.collect()
}
fn ips_to_fqdn_for_known_hosts(
hosts: &HashMap<String, Vec<IpAddr>>,
) -> HashMap<IpAddr, DomainName> {
hosts
.iter()
.filter_map(|(d, a)| {
DomainName::vec_from_str(d)
.ok()
.map(|d| a.iter().map(move |a| (*a, d.clone())))
})
.flatten()
.collect()
}
impl StubResolver {
pub(crate) fn new(known_hosts: HashMap<String, Vec<IpAddr>>) -> StubResolver {
StubResolver {
@@ -160,30 +107,33 @@ impl StubResolver {
}
}
pub(crate) fn get_description(&self, ip: &IpAddr) -> Option<ResourceDescriptionDns> {
let name = self.ips_to_fqdn.get(ip)?;
get_description(name, &self.dns_resources)
/// Attempts to resolve an IP to a given resource.
///
/// Semantically, this is like a PTR query, i.e. we check whether we handed out this IP as part of answering a DNS query for one of our resources.
/// This is in the hot-path of packet routing and must be fast!
pub(crate) fn resolve_resource_by_ip(&self, ip: &IpAddr) -> Option<ResourceId> {
let (_, resource_id) = self.ips_to_fqdn.get(ip)?;
Some(*resource_id)
}
pub(crate) fn get_fqdn(&self, ip: &IpAddr) -> Option<(&DomainName, &Vec<IpAddr>)> {
let fqdn = self.ips_to_fqdn.get(ip)?;
let (fqdn, _) = self.ips_to_fqdn.get(ip)?;
Some((fqdn, self.fqdn_to_ips.get(fqdn).unwrap()))
}
pub(crate) fn add_resource(&mut self, resource: &ResourceDescriptionDns) {
let existing = self
.dns_resources
.insert(resource.address.clone(), resource.clone());
pub(crate) fn add_resource(&mut self, id: ResourceId, address: String) {
let existing = self.dns_resources.insert(address.clone(), id);
if existing.is_none() {
tracing::info!(address = %resource.address, "Activating DNS resource");
tracing::info!(%address, "Activating DNS resource");
}
}
pub(crate) fn remove_resource(&mut self, id: ResourceId) {
self.dns_resources.retain(|_, r| {
if r.id == id {
tracing::info!(address = %r.address, "Deactivating DNS resource");
self.dns_resources.retain(|address, r| {
if *r == id {
tracing::info!(%address, "Deactivating DNS resource");
return false;
}
@@ -194,18 +144,20 @@ impl StubResolver {
fn get_or_assign_a_records(
&mut self,
fqdn: DomainName,
resource_id: ResourceId,
) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
to_a_records(self.get_or_assign_ips(fqdn).into_iter())
to_a_records(self.get_or_assign_ips(fqdn, resource_id).into_iter())
}
fn get_or_assign_aaaa_records(
&mut self,
fqdn: DomainName,
resource_id: ResourceId,
) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
to_aaaa_records(self.get_or_assign_ips(fqdn).into_iter())
to_aaaa_records(self.get_or_assign_ips(fqdn, resource_id).into_iter())
}
fn get_or_assign_ips(&mut self, fqdn: DomainName) -> Vec<IpAddr> {
fn get_or_assign_ips(&mut self, fqdn: DomainName, resource_id: ResourceId) -> Vec<IpAddr> {
let ips = self
.fqdn_to_ips
.entry(fqdn.clone())
@@ -218,14 +170,14 @@ impl StubResolver {
})
.clone();
for ip in &ips {
self.ips_to_fqdn.insert(*ip, fqdn.clone());
self.ips_to_fqdn.insert(*ip, (fqdn.clone(), resource_id));
}
ips
}
fn is_fqdn_resource(&self, domain_name: &DomainName) -> bool {
get_description(domain_name, &self.dns_resources).is_some()
fn match_resource(&self, domain_name: &DomainName) -> Option<ResourceId> {
match_domain(domain_name, &self.dns_resources)
}
fn resource_address_name_by_reservse_dns(
@@ -233,8 +185,9 @@ impl StubResolver {
reverse_dns_name: &DomainName,
) -> Option<DomainName> {
let address = reverse_dns_addr(&reverse_dns_name.to_string())?;
let (domain, _) = self.ips_to_fqdn.get(&address)?;
self.ips_to_fqdn.get(&address).cloned()
Some(domain.clone())
}
// TODO: we can save a few allocations here still
@@ -273,14 +226,14 @@ impl StubResolver {
)));
}
let resource_records = match qtype {
Rtype::A if self.is_fqdn_resource(&domain) => {
self.get_or_assign_a_records(domain.clone())
let maybe_resource = self.match_resource(&domain);
let resource_records = match (qtype, maybe_resource) {
(Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource),
(Rtype::AAAA, Some(resource)) => {
self.get_or_assign_aaaa_records(domain.clone(), resource)
}
Rtype::AAAA if self.is_fqdn_resource(&domain) => {
self.get_or_assign_aaaa_records(domain.clone())
}
Rtype::PTR => {
(Rtype::PTR, _) => {
let fqdn = self.resource_address_name_by_reservse_dns(&domain)?;
vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))]
@@ -302,6 +255,35 @@ impl StubResolver {
}
}
impl<'a> DnsQuery<'a> {
pub(crate) fn into_owned(self) -> DnsQuery<'static> {
let Self {
name,
record_type,
query,
} = self;
let buf = query.packet().to_vec();
let query = ip_packet::IpPacket::owned(buf)
.expect("We are constructing the ip packet from an ip packet");
DnsQuery {
name,
record_type,
query,
}
}
}
impl Clone for DnsQuery<'static> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
record_type: self.record_type,
query: self.query.clone(),
}
}
}
fn to_a_records(ips: impl Iterator<Item = IpAddr>) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
ips.filter_map(get_v4)
.map(domain::rdata::A::new)
@@ -439,41 +421,47 @@ pub fn is_subdomain(name: &DomainName, resource: &str) -> bool {
name == &resource
}
fn get_description(
name: &DomainName,
dns_resources: &HashMap<String, ResourceDescriptionDns>,
) -> Option<ResourceDescriptionDns> {
if let Some(resource) = dns_resources.get(&name.to_string()) {
return Some(resource.clone());
fn match_domain<T>(name: &DomainName, resources: &HashMap<String, T>) -> Option<T>
where
T: Copy,
{
// Safety: `?` is less than 254 bytes long.
const QUESTION_MARK: RelativeName<&'static [u8]> =
unsafe { RelativeName::from_octets_unchecked(b"\x01?") };
// Safety: `*` is less than 254 bytes long.
const WILDCARD: RelativeName<&'static [u8]> =
unsafe { RelativeName::from_octets_unchecked(b"\x01*") };
// First, check for full match.
if let Some(resource) = resources.get(&name.to_string()) {
return Some(*resource);
}
if let Some(resource) = dns_resources.get(
&RelativeName::<Vec<_>>::from_octets(b"\x01?".as_ref().into())
.ok()?
.chain(name)
.ok()?
.to_string(),
) {
return Some(resource.clone());
// Second, check for `?` matching this domain exactly.
let qm_dot_domain = QUESTION_MARK.chain(name).ok()?.to_string();
if let Some(resource) = resources.get(&qm_dot_domain) {
return Some(*resource);
}
// Third, check for `?` matching up to 1 parent.
if let Some(parent) = name.parent() {
if let Some(resource) = dns_resources.get(
&RelativeName::<Vec<_>>::from_octets(b"\x01?".as_ref().into())
.ok()?
.chain(parent)
.ok()?
.to_string(),
) {
return Some(resource.clone());
let qm_dot_parent = QUESTION_MARK.chain(parent).ok()?.to_string();
if let Some(resource) = resources.get(&qm_dot_parent) {
return Some(*resource);
}
}
name.iter_suffixes().find_map(|n| {
dns_resources
.get(&RelativeName::wildcard_vec().chain(n).ok()?.to_string())
.cloned()
})
// Last, check for any wildcard domains, starting with the most specific one.
for suffix in name.iter_suffixes() {
let wildcard_dot_suffix = WILDCARD.chain(suffix).ok()?.to_string();
if let Some(resource) = resources.get(&wildcard_dot_suffix) {
return Some(*resource);
}
}
None
}
fn reverse_dns_addr(name: &str) -> Option<IpAddr> {
@@ -523,65 +511,32 @@ fn get_v6(ip: IpAddr) -> Option<Ipv6Addr> {
}
}
fn fqdn_to_ips_for_known_hosts(
hosts: &HashMap<String, Vec<IpAddr>>,
) -> HashMap<DomainName, Vec<IpAddr>> {
hosts
.iter()
.filter_map(|(d, a)| DomainName::vec_from_str(d).ok().map(|d| (d, a.clone())))
.collect()
}
fn ips_to_fqdn_for_known_hosts(
hosts: &HashMap<String, Vec<IpAddr>>,
) -> HashMap<IpAddr, DomainName> {
hosts
.iter()
.filter_map(|(d, a)| {
DomainName::vec_from_str(d)
.ok()
.map(|d| a.iter().map(move |a| (*a, d.clone())))
})
.flatten()
.collect()
}
#[cfg(test)]
mod test {
use connlib_shared::{messages::client::ResourceDescriptionDns, DomainName};
use crate::dns::is_subdomain;
use super::{get_description, reverse_dns_addr};
use std::{collections::HashMap, net::Ipv4Addr};
fn foo() -> ResourceDescriptionDns {
serde_json::from_str(
r#"{
"id": "c4bb3d79-afa7-4660-8918-06c38fda3a4a",
"address": "*.foo.com",
"name": "foo.com wildcard",
"address_description": "foo",
"gateway_groups": [{"id": "bf56f32d-7b2c-4f5d-a784-788977d014a4", "name": "test"}]
}"#,
)
.unwrap()
}
fn bar() -> ResourceDescriptionDns {
serde_json::from_str(
r#"{
"id": "c4bb3d79-afa7-4660-8918-06c38fda3a4b",
"address": "*.bar.com",
"name": "bar.com wildcard",
"address_description": "bar",
"gateway_groups": [{"id": "bf56f32d-7b2c-4f5d-a784-788977d014a4", "name": "test"}]
}"#,
)
.unwrap()
}
fn baz() -> ResourceDescriptionDns {
serde_json::from_str(
r#"{
"id": "c4bb3d79-afa7-4660-8918-06c38fda3a4c",
"address": "baz.com",
"name": "baz.com",
"address_description": "baz",
"gateway_groups": [{"id": "bf56f32d-7b2c-4f5d-a784-788977d014a4", "name": "test"}]
}"#,
)
.unwrap()
}
fn dns_resource_fixture() -> HashMap<String, ResourceDescriptionDns> {
let mut dns_resources_fixture = HashMap::new();
dns_resources_fixture.insert("*.foo.com".to_string(), foo());
dns_resources_fixture.insert("?.bar.com".to_string(), bar());
dns_resources_fixture.insert("baz.com".to_string(), baz());
dns_resources_fixture
}
mod tests {
use super::*;
#[test]
fn reverse_dns_addr_works_v4() {
@@ -638,198 +593,65 @@ mod test {
#[test]
fn wildcard_matching() {
let dns_resources_fixture = dns_resource_fixture();
let resources = HashMap::from([("*.foo.com".to_string(), 0), ("*.com".to_string(), 1)]);
assert_eq!(
get_description(
&DomainName::vec_from_str("a.foo.com").unwrap(),
&dns_resources_fixture,
)
.unwrap(),
foo(),
);
assert_eq!(
get_description(
&DomainName::vec_from_str("foo.com").unwrap(),
&dns_resources_fixture,
)
.unwrap(),
foo(),
);
assert_eq!(
get_description(
&DomainName::vec_from_str("a.b.foo.com").unwrap(),
&dns_resources_fixture,
)
.unwrap(),
foo(),
);
assert!(get_description(
&DomainName::vec_from_str("oo.com").unwrap(),
&dns_resources_fixture,
)
.is_none(),);
assert_eq!(match_domain(&domain("a.foo.com"), &resources), Some(0));
assert_eq!(match_domain(&domain("foo.com"), &resources), Some(0));
assert_eq!(match_domain(&domain("a.b.foo.com"), &resources), Some(0));
assert_eq!(match_domain(&domain("oo.com"), &resources), Some(1));
assert_eq!(match_domain(&domain("oo.xyz"), &resources), None);
}
#[test]
fn question_mark_matching() {
let dns_resources_fixture = dns_resource_fixture();
let resources = HashMap::from([("?.bar.com".to_string(), 1)]);
assert_eq!(
get_description(
&DomainName::vec_from_str("a.bar.com").unwrap(),
&dns_resources_fixture,
)
.unwrap(),
bar(),
);
assert_eq!(
get_description(
&DomainName::vec_from_str("bar.com").unwrap(),
&dns_resources_fixture,
)
.unwrap(),
bar(),
);
assert!(get_description(
&DomainName::vec_from_str("a.b.bar.com").unwrap(),
&dns_resources_fixture,
)
.is_none(),);
assert_eq!(match_domain(&domain("a.bar.com"), &resources), Some(1));
assert_eq!(match_domain(&domain("bar.com"), &resources), Some(1));
assert_eq!(match_domain(&domain("a.b.bar.com"), &resources), None);
}
#[test]
fn exact_matching() {
let dns_resources_fixture = dns_resource_fixture();
let resources = HashMap::from([("baz.com".to_string(), 2)]);
assert_eq!(
get_description(
&DomainName::vec_from_str("baz.com").unwrap(),
&dns_resources_fixture,
)
.unwrap(),
baz(),
);
assert!(get_description(
&DomainName::vec_from_str("a.baz.com").unwrap(),
&dns_resources_fixture,
)
.is_none());
assert!(get_description(
&DomainName::vec_from_str("a.b.baz.com").unwrap(),
&dns_resources_fixture,
)
.is_none(),);
assert_eq!(match_domain(&domain("baz.com"), &resources), Some(2));
assert_eq!(match_domain(&domain("a.baz.com"), &resources), None);
assert_eq!(match_domain(&domain("a.b.baz.com"), &resources), None);
}
#[test]
fn exact_subdomain_match() {
assert!(is_subdomain(
&DomainName::vec_from_str("foo.com").unwrap(),
"foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("a.foo.com").unwrap(),
"foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("a.b.foo.com").unwrap(),
"foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("foo.com").unwrap(),
"a.foo.com"
));
assert!(is_subdomain(&domain("foo.com"), "foo.com"));
assert!(!is_subdomain(&domain("a.foo.com"), "foo.com"));
assert!(!is_subdomain(&domain("a.b.foo.com"), "foo.com"));
assert!(!is_subdomain(&domain("foo.com"), "a.foo.com"));
}
#[test]
fn wildcard_subdomain_match() {
assert!(is_subdomain(
&DomainName::vec_from_str("foo.com").unwrap(),
"*.foo.com"
));
assert!(is_subdomain(
&DomainName::vec_from_str("a.foo.com").unwrap(),
"*.foo.com"
));
assert!(is_subdomain(
&DomainName::vec_from_str("a.foo.com").unwrap(),
"*.a.foo.com"
));
assert!(is_subdomain(
&DomainName::vec_from_str("b.a.foo.com").unwrap(),
"*.a.foo.com"
));
assert!(is_subdomain(
&DomainName::vec_from_str("a.b.foo.com").unwrap(),
"*.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("afoo.com").unwrap(),
"*.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("b.afoo.com").unwrap(),
"*.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("bar.com").unwrap(),
"*.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("foo.com").unwrap(),
"*.a.foo.com"
));
assert!(is_subdomain(&domain("foo.com"), "*.foo.com"));
assert!(is_subdomain(&domain("a.foo.com"), "*.foo.com"));
assert!(is_subdomain(&domain("a.foo.com"), "*.a.foo.com"));
assert!(is_subdomain(&domain("b.a.foo.com"), "*.a.foo.com"));
assert!(is_subdomain(&domain("a.b.foo.com"), "*.foo.com"));
assert!(!is_subdomain(&domain("afoo.com"), "*.foo.com"));
assert!(!is_subdomain(&domain("b.afoo.com"), "*.foo.com"));
assert!(!is_subdomain(&domain("bar.com"), "*.foo.com"));
assert!(!is_subdomain(&domain("foo.com"), "*.a.foo.com"));
}
#[test]
fn question_mark_subdomain_match() {
assert!(is_subdomain(
&DomainName::vec_from_str("foo.com").unwrap(),
"?.foo.com"
));
assert!(is_subdomain(&domain("foo.com"), "?.foo.com"));
assert!(is_subdomain(&domain("a.foo.com"), "?.foo.com"));
assert!(!is_subdomain(&domain("a.b.foo.com"), "?.foo.com"));
assert!(!is_subdomain(&domain("bar.com"), "?.foo.com"));
assert!(!is_subdomain(&domain("foo.com"), "?.a.foo.com"));
assert!(!is_subdomain(&domain("afoo.com"), "?.foo.com"));
}
assert!(is_subdomain(
&DomainName::vec_from_str("a.foo.com").unwrap(),
"?.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("a.b.foo.com").unwrap(),
"?.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("bar.com").unwrap(),
"?.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("foo.com").unwrap(),
"?.a.foo.com"
));
assert!(!is_subdomain(
&DomainName::vec_from_str("afoo.com").unwrap(),
"?.foo.com"
));
fn domain(name: &str) -> DomainName {
DomainName::vec_from_str(name).unwrap()
}
}