From 128d0eb4074374f6063387a00cb3a4f79d527b73 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 7 Aug 2024 09:54:49 +0100 Subject: [PATCH] feat(connlib): transparently forward non-resources DNS queries (#6181) Currently, `connlib` depends on `hickory-resolver` to perform DNS queries for non-resources. This is unnecessary. Instead of buffering the original UDP DNS query, consulting hickory to resolve the name and mapping the response back, we can simply take the UDP payload and send it via our protected socket directly to the original upstream DNS server. This ensures `connlib` is as transparent as possible for DNS queries for non-resources. Additionally, it removes a lot of error handling and other cruft that we currently have to perform because we are using hickory. For example, hickory will automatically retry a DNS query after a certain timeout. However, the OS / client talking to `connlib` will also retry after a certain timeout because it is making DNS queries over an unreliable transport (UDP). It is thus unnecessary for us to do that internally. To correctly test this change, our test-suite needed some refactoring. Specifically, DNS servers are now modelled as dedicated `Host`s that can receive (UDP) traffic. Lastly, we can remove our dependency on `hickory-proto` and `hickory-resolver` everywhere and only use `domain` for parsing DNS messages. Resolves: #6141. Related: #6033. Related: #4800. (Impossible to happen with this design) --- rust/Cargo.lock | 102 +----- rust/bin-shared/Cargo.toml | 2 +- rust/connlib/clients/android/Cargo.toml | 2 +- rust/connlib/clients/apple/Cargo.toml | 2 +- rust/connlib/clients/shared/src/eventloop.rs | 16 +- rust/connlib/tunnel/Cargo.toml | 7 +- rust/connlib/tunnel/src/client.rs | 319 +++++-------------- rust/connlib/tunnel/src/dns.rs | 160 +++------- rust/connlib/tunnel/src/io.rs | 177 +--------- rust/connlib/tunnel/src/lib.rs | 14 - rust/connlib/tunnel/src/tests.rs | 1 + rust/connlib/tunnel/src/tests/reference.rs | 65 ++-- rust/connlib/tunnel/src/tests/sim_client.rs | 105 ++---- rust/connlib/tunnel/src/tests/sim_dns.rs | 124 +++++++ rust/connlib/tunnel/src/tests/sim_gateway.rs | 10 +- rust/connlib/tunnel/src/tests/sim_net.rs | 17 + rust/connlib/tunnel/src/tests/strategies.rs | 72 +++-- rust/connlib/tunnel/src/tests/sut.rs | 106 +++--- rust/connlib/tunnel/src/tests/transition.rs | 39 ++- rust/headless-client/Cargo.toml | 2 +- rust/ip-packet/Cargo.toml | 2 +- rust/ip-packet/src/lib.rs | 5 +- rust/ip-packet/src/make.rs | 85 +++-- rust/socket-factory/Cargo.toml | 5 - rust/socket-factory/src/lib.rs | 59 ---- 25 files changed, 498 insertions(+), 1000 deletions(-) create mode 100644 rust/connlib/tunnel/src/tests/sim_dns.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index f274dedad..ec9f6dc62 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1608,18 +1608,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf" -[[package]] -name = "enum-as-inner" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "enumflags2" version = "0.7.9" @@ -1999,11 +1987,8 @@ dependencies = [ "domain", "firezone-relay", "futures", - "futures-bounded", "futures-util", "hex", - "hickory-proto", - "hickory-resolver", "ip-packet", "ip_network", "ip_network_table", @@ -2025,6 +2010,7 @@ dependencies = [ "tracing-appender", "tracing-subscriber", "tun", + "uuid", ] [[package]] @@ -2121,9 +2107,9 @@ dependencies = [ [[package]] name = "futures-bounded" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e2774cc104e198ef3d3e1ff4ab40f86fa3245d6cb6a3a46174f21463cee173" +checksum = "91f328e7fb845fc832912fb6a34f40cf6d1888c92f974d1893a54e97b5ff542e" dependencies = [ "futures-timer", "futures-util", @@ -2655,49 +2641,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" -[[package]] -name = "hickory-proto" -version = "0.24.0" -source = "git+https://github.com/hickory-dns/hickory-dns?rev=a3669bd80f3f7b97f0c301c15f1cba6368d97b63#a3669bd80f3f7b97f0c301c15f1cba6368d97b63" -dependencies = [ - "async-trait", - "cfg-if", - "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", - "idna", - "ipnet", - "once_cell", - "rand 0.8.5", - "thiserror", - "tinyvec", - "tokio", - "tracing", - "url", -] - -[[package]] -name = "hickory-resolver" -version = "0.24.0" -source = "git+https://github.com/hickory-dns/hickory-dns?rev=a3669bd80f3f7b97f0c301c15f1cba6368d97b63#a3669bd80f3f7b97f0c301c15f1cba6368d97b63" -dependencies = [ - "cfg-if", - "futures-util", - "hickory-proto", - "ipconfig", - "lru-cache", - "once_cell", - "parking_lot", - "rand 0.8.5", - "resolv-conf", - "smallvec 1.13.2", - "thiserror", - "tokio", - "tracing", -] - [[package]] name = "hkdf" version = "0.12.4" @@ -2725,17 +2668,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "hostname" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" -dependencies = [ - "libc", - "match_cfg", - "winapi", -] - [[package]] name = "hostname" version = "0.4.0" @@ -3061,7 +2993,7 @@ dependencies = [ name = "ip-packet" version = "0.1.0" dependencies = [ - "hickory-proto", + "domain", "pnet_packet", "proptest", "test-strategy", @@ -3355,12 +3287,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd1bc4d24ad230d21fb898d1116b1801d7adfc449d42026475862ab48b11e70e" -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -3404,15 +3330,6 @@ dependencies = [ "tracing-subscriber", ] -[[package]] -name = "lru-cache" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" -dependencies = [ - "linked-hash-map", -] - [[package]] name = "mac" version = "0.1.1" @@ -3464,12 +3381,6 @@ dependencies = [ "tendril", ] -[[package]] -name = "match_cfg" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" - [[package]] name = "matchers" version = "0.1.0" @@ -4438,7 +4349,7 @@ dependencies = [ "base64 0.22.1", "futures", "hex", - "hostname 0.4.0", + "hostname", "libc", "rand_core 0.6.4", "secrecy", @@ -5044,7 +4955,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" dependencies = [ - "hostname 0.3.1", "quick-error", ] @@ -5636,8 +5546,6 @@ dependencies = [ name = "socket-factory" version = "0.1.0" dependencies = [ - "async-trait", - "hickory-proto", "quinn-udp", "socket2", "tokio", diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index a2bbe74b9..59bdb6f5c 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -20,7 +20,7 @@ tun = { workspace = true } url = { version = "2.5.2", default-features = false } [dev-dependencies] -tokio = { workspace = true, features = ["macros"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [target.'cfg(target_os = "linux")'.dependencies] libc = "0.2" diff --git a/rust/connlib/clients/android/Cargo.toml b/rust/connlib/clients/android/Cargo.toml index 94095533b..a51601bbb 100644 --- a/rust/connlib/clients/android/Cargo.toml +++ b/rust/connlib/clients/android/Cargo.toml @@ -25,7 +25,7 @@ secrecy = { workspace = true } serde_json = "1" socket-factory = { workspace = true } thiserror = "1" -tokio = { workspace = true, features = ["rt"] } +tokio = { workspace = true, features = ["rt-multi-thread"] } tracing = { workspace = true, features = ["std", "attributes"] } tracing-appender = "0.2" tracing-subscriber = { workspace = true } diff --git a/rust/connlib/clients/apple/Cargo.toml b/rust/connlib/clients/apple/Cargo.toml index 18f406a6d..e085ad16c 100644 --- a/rust/connlib/clients/apple/Cargo.toml +++ b/rust/connlib/clients/apple/Cargo.toml @@ -22,7 +22,7 @@ secrecy = { workspace = true } serde_json = "1" socket-factory = { workspace = true } swift-bridge = { workspace = true } -tokio = { workspace = true, features = ["rt"] } +tokio = { workspace = true, features = ["rt-multi-thread"] } tracing = { workspace = true } tracing-appender = "0.2" tracing-subscriber = "0.3" diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 673846f34..7d4f27972 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -205,12 +205,7 @@ where fn handle_portal_inbound_message(&mut self, msg: IngressMessages) { match msg { IngressMessages::ConfigChanged(config) => { - if let Err(e) = self - .tunnel - .set_new_interface_config(config.interface.clone()) - { - tracing::warn!(?config, "Failed to update configuration: {e:?}"); - } + self.tunnel.set_new_interface_config(config.interface) } IngressMessages::IceCandidates(GatewayIceCandidates { gateway_id, @@ -225,14 +220,11 @@ where resources, relays, }) => { - if let Err(e) = self.tunnel.set_new_interface_config(interface) { - tracing::warn!("Failed to set interface on tunnel: {e}"); - return; - } + self.tunnel.set_new_interface_config(interface); + self.tunnel.set_resources(resources); + self.tunnel.update_relays(BTreeSet::default(), relays); tracing::info!("Firezone Started!"); - self.tunnel.set_resources(resources); - self.tunnel.update_relays(BTreeSet::default(), relays) } IngressMessages::ResourceCreatedOrUpdated(resource) => { self.tunnel.add_resource(resource); diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index aa621e0ea..45df4d9d5 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -12,11 +12,8 @@ chrono = { workspace = true } connlib-shared = { workspace = true } domain = { workspace = true } futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] } -futures-bounded = { workspace = true } futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] } hex = "0.4.3" -hickory-proto = { workspace = true } -hickory-resolver = { workspace = true, features = ["tokio-runtime"] } ip-packet = { workspace = true } ip_network = { version = "0.4", default-features = false } ip_network_table = { version = "0.2", default-features = false } @@ -27,17 +24,17 @@ rangemap = "1.5.1" secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } snownet = { workspace = true } -socket-factory = { workspace = true, features = ["hickory"] } +socket-factory = { workspace = true } socket2 = { workspace = true } thiserror = { version = "1.0", default-features = false } tokio = { workspace = true } tracing = { workspace = true, features = ["attributes"] } tun = { workspace = true } +uuid = { version = "1.10", default-features = false, features = ["std", "v4"] } [dev-dependencies] derivative = "2.2.0" firezone-relay = { workspace = true, features = ["proptest"] } -hickory-proto = { workspace = true } ip-packet = { workspace = true, features = ["proptest"] } proptest-state-machine = "0.3" rand = "0.8" diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 54e6cee5e..1c5413a00 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,7 +1,6 @@ +use crate::dns; use crate::dns::StubResolver; -use crate::io::DnsQueryError; use crate::peer_store::PeerStore; -use crate::{dns, dns::DnsQuery}; use anyhow::Context; use bimap::BiMap; use connlib_shared::callbacks::Status; @@ -18,14 +17,14 @@ use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; use itertools::Itertools; -use tracing::Level; use crate::peer::GatewayOnClient; use crate::utils::{self, earliest, turn}; use crate::{ClientEvent, ClientTunnel, Tun}; -use core::fmt; +use domain::base::Message; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{ClientNode, RelaySocket}; +use snownet::{ClientNode, RelaySocket, Transmit}; +use std::borrow::Cow; use std::collections::hash_map::Entry; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter; @@ -140,29 +139,12 @@ impl ClientTunnel { pub fn set_new_dns(&mut self, new_dns: Vec) { // We store the sentinel dns both in the config and in the system's resolvers // but when we calculate the dns mapping, those are ignored. - let dns_changed = self.role_state.update_system_resolvers(new_dns); - - if !dns_changed { - return; - } - - self.io - .set_upstream_dns_servers(self.role_state.dns_mapping()); + self.role_state.update_system_resolvers(new_dns); } #[tracing::instrument(level = "trace", skip(self))] - pub fn set_new_interface_config( - &mut self, - config: InterfaceConfig, - ) -> connlib_shared::Result<()> { - let dns_changed = self.role_state.update_interface_config(config); - - if dns_changed { - self.io - .set_upstream_dns_servers(self.role_state.dns_mapping()); - } - - Ok(()) + pub fn set_new_interface_config(&mut self, config: InterfaceConfig) { + self.role_state.update_interface_config(config); } pub fn cleanup_connection(&mut self, id: ResourceId) { @@ -273,15 +255,19 @@ pub struct ClientState { /// The DNS resolvers configured on the system outside of connlib. system_resolvers: Vec, - /// DNS queries that we need to forward to the system resolver. - buffered_dns_queries: VecDeque>, - /// Maps from connlib-assigned IP of a DNS server back to the originally configured system DNS resolver. dns_mapping: BiMap, /// DNS queries that had their destination IP mangled because the servers is a CIDR resource. /// /// The [`Instant`] tracks when the DNS query expires. mangled_dns_queries: HashMap, + /// DNS queries that were forwarded to an upstream server. + /// + /// - The [`SocketAddr`] is the original source IP. + /// - The [`Instant`] tracks when the DNS query expires. + /// + /// We store an explicit expiry to avoid a memory leak in case of a non-responding DNS server. + forwarded_dns_queries: HashMap, /// Manages internal dns records and emits forwarding event when not internally handled stub_resolver: StubResolver, @@ -293,6 +279,7 @@ pub struct ClientState { buffered_events: VecDeque, buffered_packets: VecDeque>, + buffered_transmits: VecDeque>, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -318,14 +305,15 @@ impl ClientState { buffered_events: Default::default(), interface_config: Default::default(), buffered_packets: Default::default(), - buffered_dns_queries: Default::default(), node: ClientNode::new(private_key.into(), seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), mangled_dns_queries: Default::default(), + forwarded_dns_queries: Default::default(), stub_resolver: StubResolver::new(known_hosts), disabled_resources: Default::default(), + buffered_transmits: Default::default(), } } @@ -432,7 +420,7 @@ impl ClientState { packet: MutableIpPacket<'_>, now: Instant, ) -> Option> { - let (packet, dst) = match self.handle_dns(packet) { + let (packet, dst) = match self.try_handle_dns_query(packet, now) { Ok(response) => { self.buffered_packets.push_back(response?.to_owned()); return None; @@ -492,6 +480,10 @@ impl ClientState { now: Instant, buffer: &'b mut [u8], ) -> Option> { + if let Some(response) = self.try_handle_forwarded_dns_response(from, packet) { + return Some(response); + }; + let (gid, packet) = self.node.decapsulate( local, from, @@ -627,35 +619,42 @@ impl ClientState { !interface.upstream_dns.is_empty() } - /// Attempt to handle the given packet as a DNS packet. + /// Attempt to handle the given packet as a DNS query packet. /// /// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back. /// Returns `Err` if the packet is not a DNS query. - fn handle_dns<'a>( + fn try_handle_dns_query<'a>( &mut self, packet: MutableIpPacket<'a>, + now: Instant, ) -> Result>, (MutableIpPacket<'a>, IpAddr)> { match self .stub_resolver .handle(&self.dns_mapping, packet.as_immutable()) { Some(dns::ResolveStrategy::LocalResponse(query)) => Ok(Some(query)), - Some(dns::ResolveStrategy::ForwardQuery(query)) => { - // There's an edge case here, where the resolver's ip has been resolved before as - // a dns resource... we will ignore that weird case for now. - if let Some(upstream_dns) = self.dns_mapping.get_by_left(&query.query.destination()) - { - let ip = upstream_dns.ip(); + Some(dns::ResolveStrategy::ForwardQuery { + upstream: server, + query_id: id, + payload, + original_src, + }) => { + let ip = server.ip(); - // In case the DNS server is a CIDR resource, it needs to go through the tunnel. - if self.is_upstream_set_by_the_portal() - && self.active_cidr_resources.longest_match(ip).is_some() - { - return Err((packet, ip)); - } + // In case the DNS server is a CIDR resource, it needs to go through the tunnel. + if self.is_upstream_set_by_the_portal() + && self.active_cidr_resources.longest_match(ip).is_some() + { + return Err((packet, ip)); } - self.buffered_dns_queries.push_back(query.into_owned()); + self.forwarded_dns_queries + .insert(id, (original_src, now + IDS_EXPIRE)); + self.buffered_transmits.push_back(Transmit { + src: None, + dst: server, + payload: Cow::Owned(payload), + }); Ok(None) } @@ -666,46 +665,26 @@ impl ClientState { } } - #[tracing::instrument(level = "debug", skip_all, fields(name = %query.name, server = %query.query.destination()))] // On debug level, we can log potentially sensitive information such as domain names. - pub(crate) fn on_dns_result( + fn try_handle_forwarded_dns_response<'a>( &mut self, - query: DnsQuery<'static>, - response: Result< - Result< - Result, - futures_bounded::Timeout, - >, - DnsQueryError, - >, - ) { - let query = query.query; - let make_error_reply = { - let query = query.clone(); + from: SocketAddr, + packet: &[u8], + ) -> Option> { + // The sentinel DNS server shall be the source. If we don't have a sentinel DNS for this socket, it cannot be a DNS response. + let saddr = *self.dns_mapping.get_by_right(&DnsServer::from(from))?; + let sport = DNS_PORT; - |e: &dyn fmt::Display| { - // To avoid sensitive data getting into the logs, only log the error if debug logging is enabled. - // We always want to see a warning. - if tracing::enabled!(Level::DEBUG) { - tracing::warn!("DNS query failed: {e}"); - } else { - tracing::warn!("DNS query failed"); - }; + let message = Message::from_slice(packet).ok()?; + let query_id = message.header().id(); - ip_packet::make::dns_err_response(query, hickory_proto::op::ResponseCode::ServFail) - .into_immutable() - } - }; + let (destination, _) = self.forwarded_dns_queries.remove(&query_id)?; + let daddr = destination.ip(); + let dport = destination.port(); - let dns_reply = match response { - Ok(Ok(response)) => match dns::build_response_from_resolve_result(query, response) { - Ok(dns_reply) => dns_reply, - Err(e) => make_error_reply(&e), - }, - Ok(Err(timeout)) => make_error_reply(&timeout), - Err(e) => make_error_reply(&e), - }; - - self.buffered_packets.push_back(dns_reply); + Some( + ip_packet::make::udp_packet(saddr, daddr, sport, dport, packet.to_vec()) + .into_immutable(), + ) } pub fn on_connection_failed(&mut self, resource: ResourceId) { @@ -844,15 +823,13 @@ impl ClientState { .filter(|resource| self.is_resource_enabled(resource)) } - #[must_use] - pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec) -> bool { + pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec) { self.system_resolvers = new_dns; self.update_dns_mapping() } - #[must_use] - pub(crate) fn update_interface_config(&mut self, config: InterfaceConfig) -> bool { + pub(crate) fn update_interface_config(&mut self, config: InterfaceConfig) { self.interface_config = Some(config); self.update_dns_mapping() @@ -862,10 +839,6 @@ impl ClientState { self.buffered_packets.pop_front() } - pub fn poll_dns_queries(&mut self) -> Option> { - self.buffered_dns_queries.pop_front() - } - pub fn poll_timeout(&mut self) -> Option { // The number of mangled DNS queries is expected to be fairly small because we only track them whilst connecting to a CIDR resource that is a DNS server. // Thus, sorting these values on-demand even within `poll_timeout` is expected to be performant enough. @@ -878,6 +851,7 @@ impl ClientState { pub fn handle_timeout(&mut self, now: Instant) { self.node.handle_timeout(now); self.mangled_dns_queries.retain(|_, exp| now < *exp); + self.forwarded_dns_queries.retain(|_, (_, exp)| now < *exp); self.drain_node_events(); } @@ -965,7 +939,9 @@ impl ClientState { } pub(crate) fn poll_transmit(&mut self) -> Option> { - self.node.poll_transmit() + self.buffered_transmits + .pop_front() + .or_else(|| self.node.poll_transmit()) } /// Sets a new set of resources. @@ -1078,9 +1054,11 @@ impl ClientState { self.resources_gateways.remove(&id); } - fn update_dns_mapping(&mut self) -> bool { + fn update_dns_mapping(&mut self) { let Some(config) = &self.interface_config else { - return false; + tracing::debug!("Unable to update DNS servesr without interface configuration"); + + return; }; let effective_dns_servers = @@ -1089,7 +1067,9 @@ impl ClientState { if HashSet::<&DnsServer>::from_iter(effective_dns_servers.iter()) == HashSet::from_iter(self.dns_mapping.right_values()) { - return false; + tracing::debug!("Effective DNS servers are unchanged"); + + return; } let dns_mapping = sentinel_dns_mapping( @@ -1121,8 +1101,6 @@ impl ClientState { ip4: self.routes().filter_map(utils::ipv4).collect(), ip6: self.routes().filter_map(utils::ipv6).collect(), }); - - true } pub fn update_relays( @@ -1390,134 +1368,6 @@ mod tests { assert!(is_definitely_not_a_resource(ip("ff02::2"))) } - #[test] - fn update_system_dns_works() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let dns_changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - - assert!(dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]); - } - - #[test] - fn update_system_dns_without_change_is_a_no_op() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - let dns_changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]); - } - - #[test] - fn update_system_dns_with_change_works() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - - assert!(dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]); - } - - #[test] - fn update_to_system_with_sentinels_are_ignored() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - let dns_changed = client_state.update_system_resolvers(vec![ - ip("1.1.1.1"), - ip("100.100.111.1"), - ip("fd00:2021:1111:8000:100:100:111:0"), - ]); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]); - } - - #[test] - fn upstream_dns_wins_over_system() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_with_dns()); - - let dns_changed = client_state.update_dns_mapping(); - assert!(dns_changed); - - let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - assert!(!dns_changed); - - dns_mapping_is_exactly(client_state.dns_mapping(), dns_list()); - } - - #[test] - fn upstream_dns_change_updates() { - let mut client_state = ClientState::for_test(); - - let dns_changed = client_state.update_interface_config(interface_config_with_dns()); - - assert!(dns_changed); - - let dns_changed = client_state.update_interface_config(InterfaceConfig { - upstream_dns: vec![dns("8.8.8.8:53")], - ..interface_config_without_dns() - }); - - assert!(dns_changed); - - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("8.8.8.8:53")]); - } - - #[test] - fn upstream_dns_no_change_is_a_no_op() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_with_dns()); - - let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - - assert!(dns_changed); - - let dns_changed = client_state.update_interface_config(interface_config_with_dns()); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), dns_list()); - } - - #[test] - fn upstream_dns_sentinels_are_ignored() { - let mut client_state = ClientState::for_test(); - let mut config = interface_config_with_dns(); - - let _ = client_state.update_interface_config(config.clone()); - - config.upstream_dns.push(dns("100.100.111.1:53")); - config - .upstream_dns - .push(dns("[fd00:2021:1111:8000:100:100:111:0]:53")); - - let dns_changed = client_state.update_interface_config(config); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), dns_list()) - } - - #[test] - fn system_dns_takes_over_when_upstream_are_unset() { - let mut client_state = ClientState::for_test(); - let _ = client_state.update_interface_config(interface_config_with_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - let dns_changed = client_state.update_interface_config(interface_config_without_dns()); - - assert!(dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]); - } - #[test] fn sentinel_dns_works() { let servers = dns_list(); @@ -1559,29 +1409,6 @@ mod tests { } } - fn dns_mapping_is_exactly(mapping: BiMap, servers: Vec) { - assert_eq!( - HashSet::<&DnsServer>::from_iter(mapping.right_values()), - HashSet::from_iter(servers.iter()) - ) - } - - fn interface_config_without_dns() -> InterfaceConfig { - InterfaceConfig { - ipv4: "10.0.0.1".parse().unwrap(), - ipv6: "fe80::".parse().unwrap(), - upstream_dns: Vec::new(), - } - } - - fn interface_config_with_dns() -> InterfaceConfig { - InterfaceConfig { - ipv4: "10.0.0.1".parse().unwrap(), - ipv6: "fe80::".parse().unwrap(), - upstream_dns: dns_list(), - } - } - fn sentinel_ranges() -> Vec { vec![ IpNetwork::V4(DNS_SENTINELS_V4), diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index d8cdc8055..5c47056e3 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -7,16 +7,11 @@ use domain::base::{ Message, MessageBuilder, ToName, }; use domain::rdata::AllRecordData; -use hickory_resolver::lookup::Lookup; -use hickory_resolver::proto::error::{ProtoError, ProtoErrorKind}; -use hickory_resolver::proto::op::MessageType; -use hickory_resolver::proto::rr::RecordType; -use ip_packet::udp::UdpPacket; use ip_packet::IpPacket; use ip_packet::Packet as _; use itertools::Itertools; use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; const DNS_TTL: u32 = 1; const REVERSE_DNS_ADDRESS_END: &str = "arpa"; @@ -34,22 +29,18 @@ pub struct StubResolver { known_hosts: KnownHosts, } -#[derive(Debug)] -pub struct DnsQuery<'a> { - pub name: DomainName, - pub record_type: RecordType, - // We could be much more efficient with this field, - // we only need the header to create the response. - pub query: ip_packet::IpPacket<'a>, -} - /// Tells the Client how to reply to a single DNS query #[derive(Debug)] -pub(crate) enum ResolveStrategy<'a> { +pub(crate) enum ResolveStrategy { /// The query is for a Resource, we have an IP mapped already, and we can respond instantly LocalResponse(IpPacket<'static>), - /// The query is for a non-Resource, forward it to an upstream or system resolver - ForwardQuery(DnsQuery<'a>), + /// The query is for a non-Resource, forward it to an upstream or system resolver. + ForwardQuery { + upstream: SocketAddr, + original_src: SocketAddr, + query_id: u16, + payload: Vec, + }, } struct KnownHosts { @@ -193,29 +184,35 @@ impl StubResolver { self.dns_resources.values().contains(resource) } - // TODO: we can save a few allocations here still - // We don't need to support multiple questions/qname in a single query because - // nobody does it and since this run with each packet we want to squeeze as much optimization - // as we can therefore we won't do it. - // - // See: https://stackoverflow.com/a/55093896 /// Parses an incoming packet as a DNS query and decides how to respond to it /// /// Returns: /// - `None` if the packet is not a valid DNS query destined for one of our sentinel resolvers /// - Otherwise, a strategy for responding to the query - pub(crate) fn handle<'a>( + pub(crate) fn handle( &mut self, dns_mapping: &bimap::BiMap, - packet: IpPacket<'a>, - ) -> Option> { - dns_mapping.get_by_left(&packet.destination())?; + packet: IpPacket, + ) -> Option { + let upstream = dns_mapping.get_by_left(&packet.destination())?.address(); let datagram = packet.as_udp()?; - let message = as_dns(&datagram)?; + + // We only support DNS on port 53. + if datagram.get_destination() != DNS_PORT { + return None; + } + + let message = Message::from_octets(datagram.payload()).ok()?; + if message.header().qr() { return None; } + // We don't need to support multiple questions/qname in a single query because + // nobody does it and since this run with each packet we want to squeeze as much optimization + // as we can therefore we won't do it. + // + // See: https://stackoverflow.com/a/55093896 let question = message.first_question()?; let domain = question.qname().to_vec(); let qtype = question.qtype(); @@ -240,11 +237,12 @@ impl StubResolver { let resource_records = match (qtype, maybe_resource) { (_, Some(resource)) if !self.knows_resource(&resource) => { - return Some(ResolveStrategy::ForwardQuery(DnsQuery { - name: domain, - record_type: u16::from(qtype).into(), - query: packet, - })) + return Some(ResolveStrategy::ForwardQuery { + upstream, + query_id: message.header().id(), + payload: message.into_octets().to_vec(), + original_src: SocketAddr::new(packet.source(), datagram.get_source()), + }) } (Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource), (Rtype::AAAA, Some(resource)) => { @@ -256,11 +254,12 @@ impl StubResolver { vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))] } _ => { - return Some(ResolveStrategy::ForwardQuery(DnsQuery { - name: domain, - record_type: u16::from(qtype).into(), - query: packet, - })) + return Some(ResolveStrategy::ForwardQuery { + upstream, + query_id: message.header().id(), + payload: message.into_octets().to_vec(), + original_src: SocketAddr::new(packet.source(), datagram.get_source()), + }) } }; @@ -278,35 +277,6 @@ impl StubResolver { } } -impl<'a> DnsQuery<'a> { - pub(crate) fn into_owned(self) -> DnsQuery<'static> { - let Self { - name, - record_type, - query, - } = self; - let buf = query.packet().to_vec(); - let query = ip_packet::IpPacket::owned(buf) - .expect("We are constructing the ip packet from an ip packet"); - - DnsQuery { - name, - record_type, - query, - } - } -} - -impl Clone for DnsQuery<'static> { - fn clone(&self) -> Self { - Self { - name: self.name.clone(), - record_type: self.record_type, - query: self.query.clone(), - } - } -} - fn to_a_records(ips: impl Iterator) -> Vec, DomainName>> { ips.filter_map(get_v4) .map(domain::rdata::A::new) @@ -321,57 +291,13 @@ fn to_aaaa_records(ips: impl Iterator) -> Vec, - response: hickory_resolver::error::ResolveResult, -) -> Result { - let datagram = original_pkt.unwrap_as_udp(); - let mut message = original_pkt.unwrap_as_dns(); - - message.set_message_type(MessageType::Response); - message.set_recursion_available(true); - - let response = match response.map_err(|err| err.kind().clone()) { - Ok(response) => message.add_answers(response.records().to_vec()), - Err(hickory_resolver::error::ResolveErrorKind::Proto(ProtoError { kind, .. })) - if matches!(*kind, ProtoErrorKind::NoRecordsFound { .. }) => - { - let ProtoErrorKind::NoRecordsFound { - soa, response_code, .. - } = *kind - else { - panic!("Impossible - We matched on `ProtoErrorKind::NoRecordsFound` but then could not destructure that same variant"); - }; - if let Some(soa) = soa { - message.add_name_server(soa.into_record_of_rdata()); - } - - message.set_response_code(response_code) - } - Err(e) => { - return Err(e.into()); - } - }; - - let packet = ip_packet::make::udp_packet( - original_pkt.destination(), - original_pkt.source(), - datagram.get_destination(), - datagram.get_source(), - response.to_vec()?, - ) - .into_immutable(); - - Ok(packet) -} - fn build_dns_with_answer( - message: &Message<[u8]>, + message: Message<&[u8]>, qname: DomainName, records: Vec, DomainName>>, ) -> Option> { let mut answer_builder = MessageBuilder::new_vec() - .start_answer(message, Rcode::NOERROR) + .start_answer(&message, Rcode::NOERROR) .ok()?; answer_builder.header_mut().set_ra(true); @@ -384,12 +310,6 @@ fn build_dns_with_answer( Some(answer_builder.finish()) } -pub fn as_dns<'a>(pkt: &'a UdpPacket<'a>) -> Option<&'a Message<[u8]>> { - (pkt.get_destination() == DNS_PORT) - .then(|| Message::from_slice(pkt.payload()).ok()) - .flatten() -} - pub fn is_subdomain(name: &DomainName, resource: &str) -> bool { let question_mark = RelativeName::>::from_octets(b"\x01?".as_ref().into()).unwrap(); let Ok(resource) = DomainName::vec_from_str(resource) else { diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index ec14f7892..6846ad7e5 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,29 +1,15 @@ -use crate::{device_channel::Device, dns::DnsQuery, sockets::Sockets}; -use connlib_shared::messages::DnsServer; -use futures::Future; -use futures_bounded::FuturesTupleSet; +use crate::{device_channel::Device, sockets::Sockets}; use futures_util::FutureExt as _; -use hickory_proto::iocompat::AsyncIoTokioAsStd; -use hickory_proto::TokioTime; -use hickory_resolver::{ - config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}, - name_server::{GenericConnector, RuntimeProvider}, - AsyncResolver, TokioHandle, -}; use ip_packet::{IpPacket, MutableIpPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ - collections::HashMap, io, - net::{IpAddr, SocketAddr}, pin::Pin, sync::Arc, task::{ready, Context, Poll}, - time::{Duration, Instant}, + time::Instant, }; -const DNS_QUERIES_QUEUE_SIZE: usize = 100; - /// Bundles together all side-effects that connlib needs to have access to. pub struct Io { /// The TUN device offered to the user. @@ -33,29 +19,16 @@ pub struct Io { /// The UDP sockets used to send & receive packets from the network. sockets: Sockets, - tcp_socket_factory: Arc>, + _tcp_socket_factory: Arc>, udp_socket_factory: Arc>, timeout: Option>>, - - upstream_dns_servers: HashMap>>, - forwarded_dns_queries: FuturesTupleSet< - Result, - DnsQuery<'static>, - >, } pub enum Input<'a, I> { Timeout(Instant), Device(MutableIpPacket<'a>), Network(I), - DnsResponse( - DnsQuery<'static>, - Result< - Result, - futures_bounded::Timeout, - >, - ), } impl Io { @@ -73,13 +46,8 @@ impl Io { device: Device::new(), timeout: None, sockets, - tcp_socket_factory, + _tcp_socket_factory: tcp_socket_factory, udp_socket_factory, - upstream_dns_servers: HashMap::default(), - forwarded_dns_queries: FuturesTupleSet::new( - Duration::from_secs(60), - DNS_QUERIES_QUEUE_SIZE, - ), }) } @@ -90,10 +58,6 @@ impl Io { ip6_bffer: &'b mut [u8], device_buffer: &'b mut [u8], ) -> Poll>>>> { - if let Poll::Ready((response, query)) = self.forwarded_dns_queries.poll_unpin(cx) { - return Poll::Ready(Ok(Input::DnsResponse(query, response))); - } - if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { return Poll::Ready(Ok(Input::Network(network))); } @@ -126,50 +90,6 @@ impl Io { Ok(()) } - pub fn set_upstream_dns_servers( - &mut self, - dns_servers: impl IntoIterator, - ) { - tracing::info!("Setting new DNS resolvers"); - - self.forwarded_dns_queries = - FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE); - self.upstream_dns_servers = create_resolvers( - dns_servers, - TokioRuntimeProvider::new( - self.tcp_socket_factory.clone(), - self.udp_socket_factory.clone(), - ), - ); - } - - pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) -> Result<(), DnsQueryError> { - let upstream = query.query.destination(); - let resolver = self - .upstream_dns_servers - .get(&upstream) - .cloned() - .expect("Only DNS queries to known upstream servers should be forwarded to `Io`"); - - if self - .forwarded_dns_queries - .try_push( - { - let name = query.name.clone().to_string(); - let record_type = query.record_type; - - async move { resolver.lookup(&name, record_type).await } - }, - query, - ) - .is_err() - { - return Err(DnsQueryError::TooManyQueries); - } - - Ok(()) - } - pub fn reset_timeout(&mut self, timeout: Instant) { let timeout = tokio::time::Instant::from_std(timeout); @@ -198,92 +118,3 @@ impl Io { Ok(()) } } - -#[derive(Debug, thiserror::Error)] -pub enum DnsQueryError { - #[error("Too many ongoing DNS queries")] - TooManyQueries, -} - -/// Identical to [`TokioRuntimeProvider`](hickory_resolver::name_server::TokioRuntimeProvider) but using our own [`SocketFactory`]. -#[derive(Clone)] -struct TokioRuntimeProvider { - handle: TokioHandle, - tcp_socket_factory: Arc>, - udp_socket_factory: Arc>, -} - -impl TokioRuntimeProvider { - fn new( - tcp_socket_factory: Arc>, - udp_socket_factory: Arc>, - ) -> TokioRuntimeProvider { - Self { - handle: Default::default(), - tcp_socket_factory, - udp_socket_factory, - } - } -} - -impl RuntimeProvider for TokioRuntimeProvider { - type Handle = TokioHandle; - type Timer = TokioTime; - type Udp = UdpSocket; - type Tcp = AsyncIoTokioAsStd; - - fn create_handle(&self) -> Self::Handle { - self.handle.clone() - } - - fn connect_tcp( - &self, - server_addr: SocketAddr, - ) -> Pin>>> { - let socket = (self.tcp_socket_factory)(&server_addr); - Box::pin(async move { - let socket = socket?; - let stream = socket.connect(server_addr).await?; - - Ok(AsyncIoTokioAsStd(stream)) - }) - } - - fn bind_udp( - &self, - local_addr: SocketAddr, - _server_addr: SocketAddr, - ) -> Pin>>> { - let socket = (self.udp_socket_factory)(&local_addr); - - Box::pin(async move { socket }) - } -} - -fn create_resolvers( - dns_servers: impl IntoIterator, - runtime_provider: TokioRuntimeProvider, -) -> HashMap>> { - dns_servers - .into_iter() - .map(|(sentinel, srv)| { - let mut resolver_config = ResolverConfig::new(); - resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp)); - resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Tcp)); - - let mut resolver_opts = ResolverOpts::default(); - resolver_opts.edns0 = true; - resolver_opts.cache_size = 0; - resolver_opts.attempts = 1; - - ( - sentinel, - AsyncResolver::new_with_conn( - resolver_config, - resolver_opts, - GenericConnector::new(runtime_provider.clone()), - ), - ) - }) - .collect() -} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 379012b19..0378ab0f5 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -116,13 +116,6 @@ impl ClientTunnel { continue; } - if let Some(dns_query) = self.role_state.poll_dns_queries() { - if let Err(e) = self.io.perform_dns_query(dns_query.clone()) { - self.role_state.on_dns_result(dns_query, Err(e)) - } - continue; - } - if let Some(timeout) = self.role_state.poll_timeout() { self.io.reset_timeout(timeout); } @@ -163,10 +156,6 @@ impl ClientTunnel { continue; } - Poll::Ready(io::Input::DnsResponse(query, response)) => { - self.role_state.on_dns_result(query, Ok(response)); - continue; - } Poll::Pending => {} } @@ -254,9 +243,6 @@ impl GatewayTunnel { continue; } - Poll::Ready(io::Input::DnsResponse(_, _)) => { - unreachable!("Gateway does not (yet) resolve DNS queries via `Io`") - } Poll::Pending => {} } diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 8be1f5cdf..02dd19b3c 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -8,6 +8,7 @@ mod flux_capacitor; mod reference; mod run_count_appender; mod sim_client; +mod sim_dns; mod sim_gateway; mod sim_net; mod sim_relay; diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 9e4ca709d..a7d477788 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -1,5 +1,5 @@ use super::{ - composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, + composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; use crate::dns::is_subdomain; @@ -10,7 +10,7 @@ use connlib_shared::{ }, DomainName, StaticSecret, }; -use hickory_proto::rr::RecordType; +use domain::base::Rtype; use proptest::{prelude::*, sample}; use proptest_state_machine::ReferenceStateMachine; use std::{ @@ -27,6 +27,7 @@ pub(crate) struct ReferenceState { pub(crate) client: Host, pub(crate) gateways: BTreeMap>, pub(crate) relays: BTreeMap>, + pub(crate) dns_servers: BTreeMap>, pub(crate) portal: StubPortal, @@ -58,13 +59,14 @@ impl ReferenceStateMachine for ReferenceState { fn init_state() -> BoxedStrategy { stub_portal() - .prop_flat_map(move |portal| { + .prop_flat_map(|portal| { let gateways = portal.gateways(); let dns_resource_records = portal.dns_resource_records(); let client = portal.client(); let relays = relays(); let global_dns_records = global_dns_records(); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources. let drop_direct_client_traffic = any::(); + let dns_servers = dns_servers(); ( client, @@ -74,6 +76,7 @@ impl ReferenceStateMachine for ReferenceState { relays, global_dns_records, drop_direct_client_traffic, + dns_servers, ) }) .prop_filter_map( @@ -86,6 +89,7 @@ impl ReferenceStateMachine for ReferenceState { relays, mut global_dns, drop_direct_client_traffic, + dns_servers, )| { let mut routing_table = RoutingTable::default(); @@ -104,6 +108,12 @@ impl ReferenceStateMachine for ReferenceState { }; } + for (id, dns_server) in &dns_servers { + if !routing_table.add_host(*id, dns_server) { + return None; + }; + } + // Merge all DNS records into `global_dns`. global_dns.extend(records); @@ -111,6 +121,7 @@ impl ReferenceStateMachine for ReferenceState { c, gateways, relays, + dns_servers, portal, global_dns, drop_direct_client_traffic, @@ -120,7 +131,7 @@ impl ReferenceStateMachine for ReferenceState { ) .prop_filter( "private keys must be unique", - |(c, gateways, _, _, _, _, _)| { + |(c, gateways, _, _, _, _, _, _)| { let different_keys = gateways .iter() .map(|(_, g)| g.inner().key) @@ -135,6 +146,7 @@ impl ReferenceStateMachine for ReferenceState { client, gateways, relays, + dns_servers, portal, global_dns_records, drop_direct_client_traffic, @@ -144,6 +156,7 @@ impl ReferenceStateMachine for ReferenceState { client, gateways, relays, + dns_servers, portal, global_dns_records, network, @@ -162,13 +175,11 @@ impl ReferenceStateMachine for ReferenceState { CompositeStrategy::default() .with( 1, - system_dns_servers() - .prop_map(|servers| Transition::UpdateSystemDnsServers { servers }), + update_system_dns_servers(state.dns_servers.values().cloned().collect()), ) .with( 1, - upstream_dns_servers() - .prop_map(|servers| Transition::UpdateUpstreamDnsServers { servers }), + update_upstream_dns_servers(state.dns_servers.values().cloned().collect()), ) .with_if_not_empty( 5, @@ -396,12 +407,12 @@ impl ReferenceStateMachine for ReferenceState { state.portal.gateway_for_resource(r).copied() }) }), - Transition::UpdateSystemDnsServers { servers } => { + Transition::UpdateSystemDnsServers(servers) => { state .client .exec_mut(|client| client.system_dns_resolvers.clone_from(servers)); } - Transition::UpdateUpstreamDnsServers { servers } => { + Transition::UpdateUpstreamDnsServers(servers) => { state .client .exec_mut(|client| client.upstream_dns_resolvers.clone_from(servers)); @@ -513,34 +524,28 @@ impl ReferenceStateMachine for ReferenceState { ref_client.is_valid_icmp_packet(seq, identifier) && ref_client.dns_records.get(dst).is_some_and(|r| match src { - IpAddr::V4(_) => r.contains(&RecordType::A), - IpAddr::V6(_) => r.contains(&RecordType::AAAA), + IpAddr::V4(_) => r.contains(&Rtype::A), + IpAddr::V6(_) => r.contains(&Rtype::AAAA), }) && state.gateways.contains_key(gateway) } - Transition::UpdateSystemDnsServers { servers } => { - // TODO: PRODUCTION CODE DOES NOT HANDLE THIS! - - if state.client.ip4.is_none() && servers.iter().all(|s| s.is_ipv4()) { - return false; - } - if state.client.ip6.is_none() && servers.iter().all(|s| s.is_ipv6()) { - return false; + Transition::UpdateSystemDnsServers(servers) => { + if servers.is_empty() { + return true; // Clearing is allowed. } - true + servers + .iter() + .any(|dns_server| state.client.sending_socket_for(*dns_server).is_some()) } - Transition::UpdateUpstreamDnsServers { servers } => { - // TODO: PRODUCTION CODE DOES NOT HANDLE THIS! - - if state.client.ip4.is_none() && servers.iter().all(|s| s.ip().is_ipv4()) { - return false; - } - if state.client.ip6.is_none() && servers.iter().all(|s| s.ip().is_ipv6()) { - return false; + Transition::UpdateUpstreamDnsServers(servers) => { + if servers.is_empty() { + return true; // Clearing is allowed. } - true + servers + .iter() + .any(|dns_server| state.client.sending_socket_for(dns_server.ip()).is_some()) } Transition::SendDnsQuery { domain, dns_server, .. diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index eb7011001..46ff466fe 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -2,11 +2,10 @@ use super::{ reference::{private_key, PrivateKey, ResourceDst}, sim_net::{any_ip_stack, any_port, host, Host}, sim_relay::{map_explode, SimRelay}, - strategies::{latency, system_dns_servers, upstream_dns_servers}, - sut::domain_to_hickory_name, + strategies::latency, IcmpIdentifier, IcmpSeq, QueryId, }; -use crate::{tests::sut::hickory_name_to_domain, ClientState}; +use crate::ClientState; use bimap::BiMap; use connlib_shared::{ messages::{ @@ -16,10 +15,9 @@ use connlib_shared::{ proptest::{client_id, domain_name}, DomainName, }; -use hickory_proto::{ - op::MessageType, - rr::{rdata, RData, RecordType}, - serialize::binary::BinDecodable as _, +use domain::{ + base::{Message, Rtype, ToName}, + rdata::AllRecordData, }; use ip_network::{Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; @@ -80,7 +78,7 @@ impl SimClient { pub(crate) fn send_dns_query_for( &mut self, domain: DomainName, - r_type: RecordType, + r_type: Rtype, query_id: u16, dns_server: SocketAddr, now: Instant, @@ -92,15 +90,13 @@ impl SimClient { tracing::debug!(%dns_server, %domain, "Sending DNS query"); - let name = domain_to_hickory_name(domain); - let src = self .sut .tunnel_ip_for(dns_server) .expect("tunnel should be initialised"); let packet = ip_packet::make::dns_query( - name, + domain, r_type, SocketAddr::new(src, 9999), // An application would pick a random source port that is free. SocketAddr::new(dns_server, 53), @@ -130,14 +126,13 @@ impl SimClient { let packet = packet.to_owned().into_immutable(); if let Some(udp) = packet.as_udp() { - if let Ok(message) = hickory_proto::op::Message::from_bytes(udp.payload()) { - debug_assert_eq!( - message.message_type(), - MessageType::Query, + if let Ok(message) = Message::from_slice(udp.payload()) { + debug_assert!( + !message.header().qr(), "every DNS message sent from the client should be a DNS query" ); - self.sent_dns_queries.insert(message.id(), packet); + self.sent_dns_queries.insert(message.header().id(), packet); } } } @@ -175,18 +170,24 @@ impl SimClient { if let Some(udp) = packet.as_udp() { if udp.get_source() == 53 { - let mut message = hickory_proto::op::Message::from_bytes(udp.payload()) + let message = Message::from_slice(udp.payload()) .expect("ip packets on port 53 to be DNS packets"); self.received_dns_responses - .insert(message.id(), packet.to_owned()); + .insert(message.header().id(), packet.to_owned()); - for record in message.take_answers().into_iter() { - let domain = hickory_name_to_domain(record.name().clone()); + for record in message.answer().unwrap() { + let record = record.unwrap(); + let domain = record.owner().to_name(); - let ip = match record.data() { - Some(RData::A(rdata::A(ip4))) => IpAddr::from(*ip4), - Some(RData::AAAA(rdata::AAAA(ip6))) => IpAddr::from(*ip6), + #[allow(clippy::wildcard_enum_match_arm)] + let ip = match record + .into_any_record::>() + .unwrap() + .data() + { + AllRecordData::A(a) => IpAddr::from(a.addr()), + AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()), unhandled => { panic!("Unexpected record data: {unhandled:?}") } @@ -253,7 +254,7 @@ pub struct RefClient { /// The IPs assigned to a domain by connlib are an implementation detail that we don't want to model in these tests. /// Instead, we just remember what _kind_ of records we resolved to be able to sample a matching src IP. #[derivative(Debug = "ignore")] - pub(crate) dns_records: BTreeMap>, + pub(crate) dns_records: BTreeMap>, /// Whether we are connected to the gateway serving the Internet resource. pub(crate) connected_internet_resources: bool, @@ -287,12 +288,12 @@ impl RefClient { /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self) -> SimClient { let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. - let _ = client_state.update_interface_config(Interface { + client_state.update_interface_config(Interface { ipv4: self.tunnel_ip4, ipv6: self.tunnel_ip6, - upstream_dns: self.upstream_dns_resolvers, + upstream_dns: self.upstream_dns_resolvers.clone(), }); - let _ = client_state.update_system_resolvers(self.system_dns_resolvers); + client_state.update_system_resolvers(self.system_dns_resolvers.clone()); SimClient::new(self.id, client_state) } @@ -478,7 +479,7 @@ impl RefClient { .find(|id| !self.disabled_resources.contains(id)) } - fn resolved_domains(&self) -> impl Iterator)> + '_ { + fn resolved_domains(&self) -> impl Iterator)> + '_ { self.dns_records .iter() .filter(|(domain, _)| self.dns_resource_by_domain(domain).is_some()) @@ -499,7 +500,7 @@ impl RefClient { .filter_map(|(domain, records)| { records .iter() - .any(|r| matches!(r, RecordType::A)) + .any(|r| matches!(r, &Rtype::A)) .then_some(domain) }) .collect() @@ -510,7 +511,7 @@ impl RefClient { .filter_map(|(domain, records)| { records .iter() - .any(|r| matches!(r, RecordType::AAAA)) + .any(|r| matches!(r, &Rtype::AAAA)) .then_some(domain) }) .collect() @@ -525,7 +526,7 @@ impl RefClient { return self .upstream_dns_resolvers .iter() - .map(|s| s.address()) + .map(DnsServer::address) .collect(); } @@ -687,32 +688,6 @@ pub(crate) fn ref_client_host( ref_client(tunnel_ip4s, tunnel_ip6s), latency(300), // TODO: Increase with #6062. ) - .prop_filter("at least one DNS server needs to be reachable", |host| { - // TODO: PRODUCTION CODE DOES NOT HANDLE THIS! - - let upstream_dns_resolvers = &host.inner().upstream_dns_resolvers; - let system_dns = &host.inner().system_dns_resolvers; - - if !upstream_dns_resolvers.is_empty() { - if host.ip4.is_none() && upstream_dns_resolvers.iter().all(|s| s.ip().is_ipv4()) { - return false; - } - if host.ip6.is_none() && upstream_dns_resolvers.iter().all(|s| s.ip().is_ipv6()) { - return false; - } - - return true; - } - - if host.ip4.is_none() && system_dns.iter().all(|s| s.is_ipv4()) { - return false; - } - if host.ip6.is_none() && system_dns.iter().all(|s| s.is_ipv6()) { - return false; - } - - true - }) } fn ref_client( @@ -725,26 +700,16 @@ fn ref_client( client_id(), private_key(), known_hosts(), - system_dns_servers(), - upstream_dns_servers(), ) .prop_map( - move |( - tunnel_ip4, - tunnel_ip6, - id, - key, - known_hosts, - system_dns_resolvers, - upstream_dns_resolvers, - )| RefClient { + move |(tunnel_ip4, tunnel_ip6, id, key, known_hosts)| RefClient { id, key, known_hosts, tunnel_ip4, tunnel_ip6, - system_dns_resolvers, - upstream_dns_resolvers, + system_dns_resolvers: Default::default(), + upstream_dns_resolvers: Default::default(), internet_resource: Default::default(), cidr_resources: IpNetworkTable::new(), dns_resources: Default::default(), diff --git a/rust/connlib/tunnel/src/tests/sim_dns.rs b/rust/connlib/tunnel/src/tests/sim_dns.rs new file mode 100644 index 000000000..583bb8cab --- /dev/null +++ b/rust/connlib/tunnel/src/tests/sim_dns.rs @@ -0,0 +1,124 @@ +use super::{ + sim_net::{host, Host}, + strategies::latency, +}; +use connlib_shared::DomainName; +use domain::{ + base::{ + iana::{Class, Rcode}, + Message, MessageBuilder, Record, Rtype, ToName as _, Ttl, + }, + rdata::AllRecordData, +}; +use firezone_relay::IpStack; +use proptest::{ + arbitrary::any, + strategy::{Just, Strategy}, +}; +use snownet::Transmit; +use std::{ + borrow::Cow, + collections::{BTreeMap, HashSet}, + fmt, + net::{IpAddr, SocketAddr}, + time::Instant, +}; +use uuid::Uuid; + +pub(crate) fn dns_server_id() -> impl Strategy { + any::().prop_map(DnsServerId::from_u128) +} + +pub(crate) fn ref_dns_host(addr: SocketAddr) -> impl Strategy> { + let ip = addr.ip(); + let port = addr.port(); + + host( + Just(IpStack::from(ip)), + Just(port), + Just(RefDns {}), + latency(50), + ) +} + +#[derive(Debug, Clone)] +pub(crate) struct RefDns {} + +#[derive(Debug)] +pub(crate) struct SimDns {} + +impl SimDns { + pub(crate) fn receive( + &mut self, + global_dns_records: &BTreeMap>, + transmit: Transmit, + _now: Instant, + ) -> Option> { + let query = Message::from_octets(&transmit.payload).ok()?; + + let response = MessageBuilder::new_vec(); + let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); + + let query = query.sole_question().unwrap(); + let name = query.qname().to_vec(); + + let records = global_dns_records + .get(&name) + .into_iter() + .flatten() + .filter(|ip| { + #[allow(clippy::wildcard_enum_match_arm)] + match query.qtype() { + Rtype::A => ip.is_ipv4(), + Rtype::AAAA => ip.is_ipv6(), + _ => todo!(), + } + }) + .copied() + .map(|ip| match ip { + IpAddr::V4(v4) => AllRecordData::, DomainName>::A(v4.into()), + IpAddr::V6(v6) => AllRecordData::, DomainName>::Aaaa(v6.into()), + }) + .map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata)); + + for record in records { + answers.push(record).unwrap(); + } + + let payload = answers.finish(); + + tracing::debug!(%name, "Responding to DNS query"); + + Some(Transmit { + src: Some(transmit.dst), + dst: transmit.src.unwrap(), + payload: Cow::Owned(payload), + }) + } +} + +#[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct DnsServerId(Uuid); + +impl DnsServerId { + #[cfg(feature = "proptest")] + pub fn from_u128(v: u128) -> Self { + Self(Uuid::from_u128(v)) + } +} + +impl fmt::Display for DnsServerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if cfg!(feature = "proptest") { + write!(f, "{:X}", self.0.as_u128()) + } else { + write!(f, "{}", self.0) + } + } +} + +impl fmt::Debug for DnsServerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self, f) + } +} diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index e11c5e8f0..3ed441e7a 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -4,7 +4,7 @@ use super::{ sim_relay::{map_explode, SimRelay}, strategies::latency, }; -use crate::{tests::sut::hickory_name_to_domain, GatewayState}; +use crate::GatewayState; use connlib_shared::{ messages::{GatewayId, RelayId}, DomainName, @@ -67,6 +67,8 @@ impl SimGateway { ) -> Option> { let packet = packet.to_owned(); + // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? + if packet.as_icmp().is_some() { self.received_icmp_requests.push_back(packet.clone()); @@ -78,11 +80,7 @@ impl SimGateway { if packet.as_udp().is_some() { let response = ip_packet::make::dns_ok_response(packet, |name| { - global_dns_records - .get(&hickory_name_to_domain(name.clone())) - .cloned() - .into_iter() - .flatten() + global_dns_records.get(name).cloned().into_iter().flatten() }); let transmit = self.sut.encapsulate(response, now)?.into_owned(); diff --git a/rust/connlib/tunnel/src/tests/sim_net.rs b/rust/connlib/tunnel/src/tests/sim_net.rs index 718fd095c..adf97f8a3 100644 --- a/rust/connlib/tunnel/src/tests/sim_net.rs +++ b/rust/connlib/tunnel/src/tests/sim_net.rs @@ -1,3 +1,4 @@ +use super::sim_dns::DnsServerId; use crate::tests::buffered_transmits::BufferedTransmits; use crate::tests::strategies::documentation_ip6s; use connlib_shared::messages::{ClientId, GatewayId, RelayId}; @@ -122,6 +123,15 @@ impl Host { } } + pub(crate) fn single_socket(&self) -> SocketAddr { + match (self.ip4, self.ip6) { + (None, Some(ip6)) => SocketAddr::new(ip6.into(), self.default_port), + (Some(ip4), None) => SocketAddr::new(ip4.into(), self.default_port), + (Some(_), Some(_)) => panic!("Dual-stack host"), + (None, None) => panic!("No socket available"), + } + } + pub(crate) fn latency(&self) -> Duration { self.latency } @@ -253,6 +263,7 @@ pub(crate) enum HostId { Client(ClientId), Gateway(GatewayId), Relay(RelayId), + DnsServer(DnsServerId), Stale, } @@ -274,6 +285,12 @@ impl From for HostId { } } +impl From for HostId { + fn from(v: DnsServerId) -> Self { + Self::DnsServer(v) + } +} + pub(crate) fn host( socket_ips: impl Strategy, default_port: impl Strategy, diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index bf7e7cafc..60cb96f33 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -1,4 +1,9 @@ -use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal}; +use super::{ + sim_dns::{dns_server_id, ref_dns_host, DnsServerId, RefDns}, + sim_net::Host, + sim_relay::ref_relay_host, + stub_portal::StubPortal, +}; use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; use connlib_shared::{ messages::{ @@ -6,7 +11,7 @@ use connlib_shared::{ ResourceDescriptionCidr, ResourceDescriptionDns, ResourceDescriptionInternet, Site, SiteId, }, - DnsServer, GatewayId, RelayId, + GatewayId, RelayId, }, proptest::{ any_ip_network, cidr_resource, dns_resource, domain_name, gateway_id, relay_id, site, @@ -14,42 +19,15 @@ use connlib_shared::{ DomainName, }; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; -use itertools::Itertools as _; +use itertools::Itertools; use prop::sample; use proptest::{collection, prelude::*}; use std::{ collections::{BTreeMap, HashMap, HashSet}, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Duration, }; -pub(crate) fn upstream_dns_servers() -> impl Strategy> { - let ip4_dns_servers = collection::vec( - any::().prop_map(|ip| DnsServer::from((ip, 53))), - 1..4, - ); - let ip6_dns_servers = collection::vec( - any::().prop_map(|ip| DnsServer::from((ip, 53))), - 1..4, - ); - - // TODO: PRODUCTION CODE DOES NOT HAVE A SAFEGUARD FOR THIS YET. - // AN ADMIN COULD CONFIGURE ONLY IPv4 SERVERS IN WHICH CASE WE ARE SCREWED IF THE CLIENT ONLY HAS IPv6 CONNECTIVITY. - - prop_oneof![ - Just(Vec::new()), - (ip4_dns_servers, ip6_dns_servers).prop_map(|(mut ip4_servers, ip6_servers)| { - ip4_servers.extend(ip6_servers); - - ip4_servers - }) - ] -} - -pub(crate) fn system_dns_servers() -> impl Strategy> { - collection::vec(any::(), 1..4) // Always need at least 1 system DNS server. TODO: Should we test what happens if we don't? -} - pub(crate) fn global_dns_records() -> impl Strategy>> { collection::btree_map( domain_name(2..4).prop_map(|d| d.parse().unwrap()), @@ -144,6 +122,38 @@ pub(crate) fn relays() -> impl Strategy>> { collection::btree_map(relay_id(), ref_relay_host(), 1..=2) } +/// Sample a list of DNS servers. +/// +/// We make sure to always have at least 1 IPv4 and 1 IPv6 DNS server. +pub(crate) fn dns_servers() -> impl Strategy>> { + let ip4_dns_servers = collection::hash_set( + any::().prop_map(|ip| SocketAddr::from((ip, 53))), + 1..4, + ); + let ip6_dns_servers = collection::hash_set( + any::().prop_map(|ip| SocketAddr::from((ip, 53))), + 1..4, + ); + + (ip4_dns_servers, ip6_dns_servers).prop_flat_map(|(ip4_dns_servers, ip6_dns_servers)| { + let servers = Vec::from_iter(ip4_dns_servers.into_iter().chain(ip6_dns_servers)); + + // First, generate a unique number of IDs, one for each DNS server. + let ids = collection::hash_set(dns_server_id(), servers.len()); + + (ids, Just(servers)) + .prop_flat_map(move |(ids, servers)| { + let ids = ids.into_iter(); + + // Second, zip the IDs and addresses together. + ids.zip(servers) + .map(|(id, addr)| (Just(id), ref_dns_host(addr))) + .collect::>() + }) + .prop_map(BTreeMap::from_iter) // Third, turn the `Vec` of tuples into a `BTreeMap`. + }) +} + fn any_site(sites: HashSet) -> impl Strategy { sample::select(Vec::from_iter(sites)) } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 37f6917c3..fc5ce336c 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -1,6 +1,7 @@ use super::buffered_transmits::BufferedTransmits; use super::reference::ReferenceState; use super::sim_client::SimClient; +use super::sim_dns::{DnsServerId, SimDns}; use super::sim_gateway::SimGateway; use super::sim_net::{Host, HostId, RoutingTable}; use super::sim_relay::SimRelay; @@ -10,17 +11,12 @@ use crate::tests::assertions::*; use crate::tests::flux_capacitor::FluxCapacitor; use crate::tests::transition::Transition; use crate::utils::earliest; -use crate::{dns::DnsQuery, ClientEvent, GatewayEvent, Request}; +use crate::{ClientEvent, GatewayEvent, Request}; use connlib_shared::messages::client::ResourceDescription; use connlib_shared::{ messages::{ClientId, GatewayId, Interface, RelayId}, DomainName, }; -use hickory_proto::{ - op::Query, - rr::{RData, Record, RecordType}, -}; -use hickory_resolver::lookup::Lookup; use proptest_state_machine::{ReferenceStateMachine, StateMachineTest}; use secrecy::ExposeSecret as _; use snownet::Transmit; @@ -28,8 +24,6 @@ use std::iter; use std::{ collections::{BTreeMap, HashSet}, net::IpAddr, - str::FromStr as _, - sync::Arc, time::{Duration, Instant}, }; use tracing::debug_span; @@ -47,6 +41,7 @@ pub(crate) struct TunnelTest { client: Host, gateways: BTreeMap>, relays: BTreeMap>, + dns_servers: BTreeMap>, drop_direct_client_traffic: bool, network: RoutingTable, @@ -101,6 +96,16 @@ impl StateMachineTest for TunnelTest { }) .collect::>(); + let dns_servers = ref_state + .dns_servers + .iter() + .map(|(did, dns_server)| { + let dns_server = dns_server.map(|_, _, _| SimDns {}, debug_span!("dns", %did)); + + (*did, dns_server) + }) + .collect::>(); + // Configure client and gateway with the relays. client.exec_mut(|c| c.update_relays(iter::empty(), relays.iter(), flux_capacitor.now())); for gateway in gateways.values_mut() { @@ -116,6 +121,7 @@ impl StateMachineTest for TunnelTest { gateways, logger, relays, + dns_servers, }; let mut buffered_transmits = BufferedTransmits::default(); @@ -221,12 +227,12 @@ impl StateMachineTest for TunnelTest { buffered_transmits.push_from(transmit, &state.client, now); } - Transition::UpdateSystemDnsServers { servers } => { + Transition::UpdateSystemDnsServers(servers) => { state .client .exec_mut(|c| c.sut.update_system_resolvers(servers)); } - Transition::UpdateUpstreamDnsServers { servers } => { + Transition::UpdateUpstreamDnsServers(servers) => { state.client.exec_mut(|c| { c.sut.update_interface_config(Interface { ipv4: c.sut.tunnel_ip4().unwrap(), @@ -259,7 +265,7 @@ impl StateMachineTest for TunnelTest { // Simulate receiving `init`. state.client.exec_mut(|c| { - let _ = c.sut.update_interface_config(Interface { + c.sut.update_interface_config(Interface { ipv4, ipv6, upstream_dns, @@ -454,10 +460,6 @@ impl TunnelTest { ); continue; } - if let Some(query) = self.client.exec_mut(|client| client.sut.poll_dns_queries()) { - self.on_forwarded_dns_query(query, ref_state); - continue; - } self.client.exec_mut(|sim| { while let Some(packet) = sim.sut.poll_packets() { sim.on_received_packet(packet) @@ -514,7 +516,7 @@ impl TunnelTest { for (_, relay) in self.relays.iter_mut() { while let Some(transmit) = relay.poll_transmit(now) { - let Some(reply) = relay.exec_mut(|g| g.receive(transmit, now)) else { + let Some(reply) = relay.exec_mut(|r| r.receive(transmit, now)) else { continue; }; @@ -523,6 +525,18 @@ impl TunnelTest { relay.exec_mut(|r| r.sut.handle_timeout(now)) } + + for (_, dns_server) in self.dns_servers.iter_mut() { + while let Some(transmit) = dns_server.poll_transmit(now) { + let Some(reply) = + dns_server.exec_mut(|d| d.receive(global_dns_records, transmit, now)) + else { + continue; + }; + + buffered_transmits.push_from(reply, dns_server, now); + } + } } fn poll_timeout(&mut self) -> Option { @@ -591,6 +605,12 @@ impl TunnelTest { HostId::Stale => { tracing::debug!(%dst, "Dropping packet because host roamed away or is offline"); } + HostId::DnsServer(id) => { + self.dns_servers + .get_mut(&id) + .expect("unknown DNS server") + .receive(transmit, now); + } } } @@ -780,41 +800,6 @@ impl TunnelTest { ClientEvent::TunRoutesUpdated { .. } => {} } } - - // TODO: Should we vary the following things via proptests? - // - Forwarded DNS query timing out? - // - hickory error? - // - TTL? - fn on_forwarded_dns_query(&mut self, query: DnsQuery<'static>, ref_state: &ReferenceState) { - let all_ips = &ref_state - .global_dns_records - .get(&query.name) - .expect("Forwarded DNS query to be for known domain"); - - let name = domain_to_hickory_name(query.name.clone()); - let requested_type = query.record_type; - - let record_data = all_ips - .iter() - .filter_map(|ip| match (requested_type, ip) { - (RecordType::A, IpAddr::V4(v4)) => Some(RData::A((*v4).into())), - (RecordType::AAAA, IpAddr::V6(v6)) => Some(RData::AAAA((*v6).into())), - (RecordType::A, IpAddr::V6(_)) | (RecordType::AAAA, IpAddr::V4(_)) => None, - _ => unreachable!(), - }) - .map(|rdata| Record::from_rdata(name.clone(), 86400_u32, rdata)) - .collect::>(); - - self.client.exec_mut(|c| { - c.sut.on_dns_result( - query, - Ok(Ok(Ok(Lookup::new_with_max_ttl( - Query::query(name, requested_type), - record_data, - )))), - ) - }) - } } fn on_gateway_event( @@ -837,22 +822,3 @@ fn on_gateway_event( GatewayEvent::RefreshDns { .. } => todo!(), } } - -pub(crate) fn hickory_name_to_domain(mut name: hickory_proto::rr::Name) -> DomainName { - name.set_fqdn(false); // Hack to work around hickory always parsing as FQ - let name = name.to_string(); - - let domain = DomainName::from_chars(name.chars()).unwrap(); - debug_assert_eq!(name, domain.to_string()); - - domain -} - -pub(crate) fn domain_to_hickory_name(domain: DomainName) -> hickory_proto::rr::Name { - let domain = domain.to_string(); - - let name = hickory_proto::rr::Name::from_str(&domain).unwrap(); - debug_assert_eq!(name.to_string(), domain); - - name -} diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 006c6135a..6b2f4bbc5 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -1,9 +1,12 @@ -use super::sim_net::{any_ip_stack, any_port, Host}; +use super::{ + sim_dns::RefDns, + sim_net::{any_ip_stack, any_port, Host}, +}; use connlib_shared::{ messages::{client::ResourceDescription, DnsServer, RelayId, ResourceId}, DomainName, }; -use hickory_proto::rr::RecordType; +use domain::base::Rtype; use proptest::{prelude::*, sample}; use std::{ collections::{BTreeMap, HashSet}, @@ -50,16 +53,16 @@ pub(crate) enum Transition { SendDnsQuery { domain: DomainName, /// The type of DNS query we should send. - r_type: RecordType, + r_type: Rtype, /// The DNS query ID. query_id: u16, dns_server: SocketAddr, }, /// The system's DNS servers changed. - UpdateSystemDnsServers { servers: Vec }, + UpdateSystemDnsServers(Vec), /// The upstream DNS servers changed. - UpdateUpstreamDnsServers { servers: Vec }, + UpdateUpstreamDnsServers(Vec), /// Roam the client to a new pair of sockets. RoamClient { @@ -165,7 +168,7 @@ where ( domain, dns_server.prop_map_into(), - prop_oneof![Just(RecordType::A), Just(RecordType::AAAA)], + prop_oneof![Just(Rtype::A), Just(Rtype::AAAA)], any::(), ) .prop_map( @@ -185,3 +188,27 @@ pub(crate) fn roam_client() -> impl Strategy { port, }) } + +pub(crate) fn update_system_dns_servers( + dns_servers: Vec>, +) -> impl Strategy { + let max = dns_servers.len(); + + sample::subsequence(dns_servers, ..=max).prop_map(|seq| { + Transition::UpdateSystemDnsServers( + seq.into_iter().map(|h| h.single_socket().ip()).collect(), + ) + }) +} + +pub(crate) fn update_upstream_dns_servers( + dns_servers: Vec>, +) -> impl Strategy { + let max = dns_servers.len(); + + sample::subsequence(dns_servers, ..=max).prop_map(|seq| { + Transition::UpdateUpstreamDnsServers( + seq.into_iter().map(|h| h.single_socket().into()).collect(), + ) + }) +} diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index 3688a9139..d95f16b87 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -26,7 +26,7 @@ socket-factory = { workspace = true } thiserror = { version = "1.0", default-features = false } # This actually relies on many other features in Tokio, so this will probably # fail to build outside the workspace. -tokio = { workspace = true, features = ["macros", "signal", "process", "time"] } +tokio = { workspace = true, features = ["macros", "signal", "process", "time", "rt-multi-thread"] } tokio-stream = "0.1.15" tokio-util = { version = "0.7.11", features = ["codec"] } tracing = { workspace = true } diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 6f3275fe3..4863910a1 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -10,7 +10,7 @@ publish = false proptest = ["dep:proptest"] [dependencies] -hickory-proto = { workspace = true } +domain = "0.10.1" pnet_packet = { version = "0.35" } proptest = { version = "1", optional = true } thiserror = "1" diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index d90a7041b..dcb5f758e 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -8,6 +8,7 @@ pub use pnet_packet::*; #[cfg(all(test, feature = "proptest"))] mod proptests; +use domain::base::Message; use pnet_packet::{ icmp::{ destination_unreachable::IcmpCodes, echo_reply::MutableEchoReplyPacket, @@ -1409,9 +1410,9 @@ impl<'a> IpPacket<'a> { } /// Unwrap this [`IpPacket`] as a DNS message, panicking in case it is not. - pub fn unwrap_as_dns(&self) -> hickory_proto::op::Message { + pub fn unwrap_as_dns(&self) -> Message> { let udp = self.unwrap_as_udp(); - let message = match hickory_proto::op::Message::from_vec(udp.payload()) { + let message = match Message::from_octets(udp.payload().to_vec()) { Ok(message) => message, Err(e) => { panic!("Failed to parse UDP payload as DNS message: {e}"); diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index 14aef1ad9..24d27b7a0 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -1,9 +1,12 @@ //! Factory module for making all kinds of packets. use crate::{IpPacket, MutableIpPacket}; -use hickory_proto::{ - op::{Message, Query, ResponseCode}, - rr::{Name, RData, Record, RecordType}, +use domain::{ + base::{ + iana::{Class, Opcode, Rcode}, + MessageBuilder, Name, Question, Record, Rtype, ToName, Ttl, + }, + rdata::AllRecordData, }; use pnet_packet::{ ip::IpNextHeaderProtocol, @@ -244,24 +247,26 @@ where } pub fn dns_query( - domain: Name, - kind: RecordType, + domain: Name>, + kind: Rtype, src: SocketAddr, dst: SocketAddr, id: u16, ) -> MutableIpPacket<'static> { // Create the DNS query message - let mut msg = Message::new(); - msg.set_message_type(hickory_proto::op::MessageType::Query); - msg.set_op_code(hickory_proto::op::OpCode::Query); - msg.set_recursion_desired(true); - msg.set_id(id); + let mut msg_builder = MessageBuilder::new_vec(); + + msg_builder.header_mut().set_opcode(Opcode::QUERY); + msg_builder.header_mut().set_rd(true); + msg_builder.header_mut().set_id(id); // Create the query - let query = Query::query(domain, kind); - msg.add_query(query); + let mut question_builder = msg_builder.question(); + question_builder + .push(Question::new_in(domain, kind)) + .unwrap(); - let payload = msg.to_vec().unwrap(); + let payload = question_builder.finish(); udp_packet(src.ip(), dst.ip(), src.port(), dst.port(), payload) } @@ -269,60 +274,42 @@ pub fn dns_query( /// Makes a DNS response to the given DNS query packet, using a resolver callback. pub fn dns_ok_response( packet: IpPacket<'static>, - resolve: impl Fn(&Name) -> I, + resolve: impl Fn(&Name>) -> I, ) -> MutableIpPacket<'static> where I: Iterator, { let udp = packet.unwrap_as_udp(); - let mut query = packet.unwrap_as_dns(); + let query = packet.unwrap_as_dns(); - let mut response = Message::new(); - response.set_id(query.id()); - response.set_message_type(hickory_proto::op::MessageType::Response); + let response = MessageBuilder::new_vec(); + let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); - for query in query.take_queries() { - response.add_query(query.clone()); + for query in query.question() { + let query = query.unwrap(); + let name = query.qname().to_name(); - let records = resolve(query.name()) + let records = resolve(&name) .filter(|ip| { #[allow(clippy::wildcard_enum_match_arm)] - match query.query_type() { - RecordType::A => ip.is_ipv4(), - RecordType::AAAA => ip.is_ipv6(), + match query.qtype() { + Rtype::A => ip.is_ipv4(), + Rtype::AAAA => ip.is_ipv6(), _ => todo!(), } }) .map(|ip| match ip { - IpAddr::V4(v4) => RData::A(v4.into()), - IpAddr::V6(v6) => RData::AAAA(v6.into()), + IpAddr::V4(v4) => AllRecordData::, Name>>::A(v4.into()), + IpAddr::V6(v6) => AllRecordData::, Name>>::Aaaa(v6.into()), }) - .map(|rdata| Record::from_rdata(query.name().clone(), 86400_u32, rdata)); + .map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata)); - response.add_answers(records); + for record in records { + answers.push(record).unwrap(); + } } - let payload = response.to_vec().unwrap(); - - udp_packet( - packet.destination(), - packet.source(), - udp.get_destination(), - udp.get_source(), - payload, - ) -} - -/// Makes a DNS response to the given DNS query packet, using the given error code. -pub fn dns_err_response(packet: IpPacket<'static>, code: ResponseCode) -> MutableIpPacket<'static> { - let udp = packet.unwrap_as_udp(); - let query = packet.unwrap_as_dns(); - - debug_assert_ne!(code, ResponseCode::NoError); - - let response = Message::error_msg(query.id(), query.op_code(), code); - - let payload = response.to_vec().unwrap(); + let payload = answers.finish(); udp_packet( packet.destination(), diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index b9ae869ab..d8f003aca 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -4,12 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -async-trait = { version = "0.1", optional = true } -hickory-proto = { workspace = true, optional = true } quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" } socket2 = { workspace = true } tokio = { version = "1.39", features = ["net"] } tracing = "0.1" - -[features] -hickory = ["dep:hickory-proto", "dep:async-trait"] diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index fe7967733..f17d42103 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -301,62 +301,3 @@ impl UdpSocket { Ok(Some(src)) } } - -#[cfg(feature = "hickory")] -mod hickory { - use super::*; - use hickory_proto::{ - udp::DnsUdpSocket as DnsUdpSocketTrait, udp::UdpSocket as UdpSocketTrait, TokioTime, - }; - use tokio::net::UdpSocket as TokioUdpSocket; - - #[async_trait::async_trait] - impl UdpSocketTrait for crate::UdpSocket { - /// setups up a "client" udp connection that will only receive packets from the associated address - async fn connect(addr: SocketAddr) -> io::Result { - let inner = ::connect(addr).await?; - let socket = Self::new(inner)?; - - Ok(socket) - } - - /// same as connect, but binds to the specified local address for sending address - async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result { - let inner = - ::connect_with_bind(addr, bind_addr).await?; - let socket = Self::new(inner)?; - - Ok(socket) - } - - /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything) - async fn bind(addr: SocketAddr) -> io::Result { - let inner = ::bind(addr).await?; - let socket = Self::new(inner)?; - - Ok(socket) - } - } - - #[cfg(feature = "hickory")] - impl DnsUdpSocketTrait for crate::UdpSocket { - type Time = TokioTime; - - fn poll_recv_from( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - ::poll_recv_from(&self.inner, cx, buf) - } - - fn poll_send_to( - &self, - cx: &mut Context<'_>, - buf: &[u8], - target: SocketAddr, - ) -> Poll> { - ::poll_send_to(&self.inner, cx, buf, target) - } - } -}