diff --git a/rust/Cargo.lock b/rust/Cargo.lock index f274dedad..ec9f6dc62 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1608,18 +1608,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf" -[[package]] -name = "enum-as-inner" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "enumflags2" version = "0.7.9" @@ -1999,11 +1987,8 @@ dependencies = [ "domain", "firezone-relay", "futures", - "futures-bounded", "futures-util", "hex", - "hickory-proto", - "hickory-resolver", "ip-packet", "ip_network", "ip_network_table", @@ -2025,6 +2010,7 @@ dependencies = [ "tracing-appender", "tracing-subscriber", "tun", + "uuid", ] [[package]] @@ -2121,9 +2107,9 @@ dependencies = [ [[package]] name = "futures-bounded" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e2774cc104e198ef3d3e1ff4ab40f86fa3245d6cb6a3a46174f21463cee173" +checksum = "91f328e7fb845fc832912fb6a34f40cf6d1888c92f974d1893a54e97b5ff542e" dependencies = [ "futures-timer", "futures-util", @@ -2655,49 +2641,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" -[[package]] -name = "hickory-proto" -version = "0.24.0" -source = "git+https://github.com/hickory-dns/hickory-dns?rev=a3669bd80f3f7b97f0c301c15f1cba6368d97b63#a3669bd80f3f7b97f0c301c15f1cba6368d97b63" -dependencies = [ - "async-trait", - "cfg-if", - "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", - "idna", - "ipnet", - "once_cell", - "rand 0.8.5", - "thiserror", - "tinyvec", - "tokio", - "tracing", - "url", -] - -[[package]] -name = "hickory-resolver" -version = "0.24.0" -source = "git+https://github.com/hickory-dns/hickory-dns?rev=a3669bd80f3f7b97f0c301c15f1cba6368d97b63#a3669bd80f3f7b97f0c301c15f1cba6368d97b63" -dependencies = [ - "cfg-if", - "futures-util", - "hickory-proto", - "ipconfig", - "lru-cache", - "once_cell", - "parking_lot", - "rand 0.8.5", - "resolv-conf", - "smallvec 1.13.2", - "thiserror", - "tokio", - "tracing", -] - [[package]] name = "hkdf" version = "0.12.4" @@ -2725,17 +2668,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "hostname" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" -dependencies = [ - "libc", - "match_cfg", - "winapi", -] - [[package]] name = "hostname" version = "0.4.0" @@ -3061,7 +2993,7 @@ dependencies = [ name = "ip-packet" version = "0.1.0" dependencies = [ - "hickory-proto", + "domain", "pnet_packet", "proptest", "test-strategy", @@ -3355,12 +3287,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd1bc4d24ad230d21fb898d1116b1801d7adfc449d42026475862ab48b11e70e" -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -3404,15 +3330,6 @@ dependencies = [ "tracing-subscriber", ] -[[package]] -name = "lru-cache" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" -dependencies = [ - "linked-hash-map", -] - [[package]] name = "mac" version = "0.1.1" @@ -3464,12 +3381,6 @@ dependencies = [ "tendril", ] -[[package]] -name = "match_cfg" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" - [[package]] name = "matchers" version = "0.1.0" @@ -4438,7 +4349,7 @@ dependencies = [ "base64 0.22.1", "futures", "hex", - "hostname 0.4.0", + "hostname", "libc", "rand_core 0.6.4", "secrecy", @@ -5044,7 +4955,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" dependencies = [ - "hostname 0.3.1", "quick-error", ] @@ -5636,8 +5546,6 @@ dependencies = [ name = "socket-factory" version = "0.1.0" dependencies = [ - "async-trait", - "hickory-proto", "quinn-udp", "socket2", "tokio", diff --git a/rust/bin-shared/Cargo.toml b/rust/bin-shared/Cargo.toml index a2bbe74b9..59bdb6f5c 100644 --- a/rust/bin-shared/Cargo.toml +++ b/rust/bin-shared/Cargo.toml @@ -20,7 +20,7 @@ tun = { workspace = true } url = { version = "2.5.2", default-features = false } [dev-dependencies] -tokio = { workspace = true, features = ["macros"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [target.'cfg(target_os = "linux")'.dependencies] libc = "0.2" diff --git a/rust/connlib/clients/android/Cargo.toml b/rust/connlib/clients/android/Cargo.toml index 94095533b..a51601bbb 100644 --- a/rust/connlib/clients/android/Cargo.toml +++ b/rust/connlib/clients/android/Cargo.toml @@ -25,7 +25,7 @@ secrecy = { workspace = true } serde_json = "1" socket-factory = { workspace = true } thiserror = "1" -tokio = { workspace = true, features = ["rt"] } +tokio = { workspace = true, features = ["rt-multi-thread"] } tracing = { workspace = true, features = ["std", "attributes"] } tracing-appender = "0.2" tracing-subscriber = { workspace = true } diff --git a/rust/connlib/clients/apple/Cargo.toml b/rust/connlib/clients/apple/Cargo.toml index 18f406a6d..e085ad16c 100644 --- a/rust/connlib/clients/apple/Cargo.toml +++ b/rust/connlib/clients/apple/Cargo.toml @@ -22,7 +22,7 @@ secrecy = { workspace = true } serde_json = "1" socket-factory = { workspace = true } swift-bridge = { workspace = true } -tokio = { workspace = true, features = ["rt"] } +tokio = { workspace = true, features = ["rt-multi-thread"] } tracing = { workspace = true } tracing-appender = "0.2" tracing-subscriber = "0.3" diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 673846f34..7d4f27972 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -205,12 +205,7 @@ where fn handle_portal_inbound_message(&mut self, msg: IngressMessages) { match msg { IngressMessages::ConfigChanged(config) => { - if let Err(e) = self - .tunnel - .set_new_interface_config(config.interface.clone()) - { - tracing::warn!(?config, "Failed to update configuration: {e:?}"); - } + self.tunnel.set_new_interface_config(config.interface) } IngressMessages::IceCandidates(GatewayIceCandidates { gateway_id, @@ -225,14 +220,11 @@ where resources, relays, }) => { - if let Err(e) = self.tunnel.set_new_interface_config(interface) { - tracing::warn!("Failed to set interface on tunnel: {e}"); - return; - } + self.tunnel.set_new_interface_config(interface); + self.tunnel.set_resources(resources); + self.tunnel.update_relays(BTreeSet::default(), relays); tracing::info!("Firezone Started!"); - self.tunnel.set_resources(resources); - self.tunnel.update_relays(BTreeSet::default(), relays) } IngressMessages::ResourceCreatedOrUpdated(resource) => { self.tunnel.add_resource(resource); diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index aa621e0ea..45df4d9d5 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -12,11 +12,8 @@ chrono = { workspace = true } connlib-shared = { workspace = true } domain = { workspace = true } futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] } -futures-bounded = { workspace = true } futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] } hex = "0.4.3" -hickory-proto = { workspace = true } -hickory-resolver = { workspace = true, features = ["tokio-runtime"] } ip-packet = { workspace = true } ip_network = { version = "0.4", default-features = false } ip_network_table = { version = "0.2", default-features = false } @@ -27,17 +24,17 @@ rangemap = "1.5.1" secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } snownet = { workspace = true } -socket-factory = { workspace = true, features = ["hickory"] } +socket-factory = { workspace = true } socket2 = { workspace = true } thiserror = { version = "1.0", default-features = false } tokio = { workspace = true } tracing = { workspace = true, features = ["attributes"] } tun = { workspace = true } +uuid = { version = "1.10", default-features = false, features = ["std", "v4"] } [dev-dependencies] derivative = "2.2.0" firezone-relay = { workspace = true, features = ["proptest"] } -hickory-proto = { workspace = true } ip-packet = { workspace = true, features = ["proptest"] } proptest-state-machine = "0.3" rand = "0.8" diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 54e6cee5e..1c5413a00 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,7 +1,6 @@ +use crate::dns; use crate::dns::StubResolver; -use crate::io::DnsQueryError; use crate::peer_store::PeerStore; -use crate::{dns, dns::DnsQuery}; use anyhow::Context; use bimap::BiMap; use connlib_shared::callbacks::Status; @@ -18,14 +17,14 @@ use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; use itertools::Itertools; -use tracing::Level; use crate::peer::GatewayOnClient; use crate::utils::{self, earliest, turn}; use crate::{ClientEvent, ClientTunnel, Tun}; -use core::fmt; +use domain::base::Message; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{ClientNode, RelaySocket}; +use snownet::{ClientNode, RelaySocket, Transmit}; +use std::borrow::Cow; use std::collections::hash_map::Entry; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter; @@ -140,29 +139,12 @@ impl ClientTunnel { pub fn set_new_dns(&mut self, new_dns: Vec) { // We store the sentinel dns both in the config and in the system's resolvers // but when we calculate the dns mapping, those are ignored. - let dns_changed = self.role_state.update_system_resolvers(new_dns); - - if !dns_changed { - return; - } - - self.io - .set_upstream_dns_servers(self.role_state.dns_mapping()); + self.role_state.update_system_resolvers(new_dns); } #[tracing::instrument(level = "trace", skip(self))] - pub fn set_new_interface_config( - &mut self, - config: InterfaceConfig, - ) -> connlib_shared::Result<()> { - let dns_changed = self.role_state.update_interface_config(config); - - if dns_changed { - self.io - .set_upstream_dns_servers(self.role_state.dns_mapping()); - } - - Ok(()) + pub fn set_new_interface_config(&mut self, config: InterfaceConfig) { + self.role_state.update_interface_config(config); } pub fn cleanup_connection(&mut self, id: ResourceId) { @@ -273,15 +255,19 @@ pub struct ClientState { /// The DNS resolvers configured on the system outside of connlib. system_resolvers: Vec, - /// DNS queries that we need to forward to the system resolver. - buffered_dns_queries: VecDeque>, - /// Maps from connlib-assigned IP of a DNS server back to the originally configured system DNS resolver. dns_mapping: BiMap, /// DNS queries that had their destination IP mangled because the servers is a CIDR resource. /// /// The [`Instant`] tracks when the DNS query expires. mangled_dns_queries: HashMap, + /// DNS queries that were forwarded to an upstream server. + /// + /// - The [`SocketAddr`] is the original source IP. + /// - The [`Instant`] tracks when the DNS query expires. + /// + /// We store an explicit expiry to avoid a memory leak in case of a non-responding DNS server. + forwarded_dns_queries: HashMap, /// Manages internal dns records and emits forwarding event when not internally handled stub_resolver: StubResolver, @@ -293,6 +279,7 @@ pub struct ClientState { buffered_events: VecDeque, buffered_packets: VecDeque>, + buffered_transmits: VecDeque>, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -318,14 +305,15 @@ impl ClientState { buffered_events: Default::default(), interface_config: Default::default(), buffered_packets: Default::default(), - buffered_dns_queries: Default::default(), node: ClientNode::new(private_key.into(), seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), mangled_dns_queries: Default::default(), + forwarded_dns_queries: Default::default(), stub_resolver: StubResolver::new(known_hosts), disabled_resources: Default::default(), + buffered_transmits: Default::default(), } } @@ -432,7 +420,7 @@ impl ClientState { packet: MutableIpPacket<'_>, now: Instant, ) -> Option> { - let (packet, dst) = match self.handle_dns(packet) { + let (packet, dst) = match self.try_handle_dns_query(packet, now) { Ok(response) => { self.buffered_packets.push_back(response?.to_owned()); return None; @@ -492,6 +480,10 @@ impl ClientState { now: Instant, buffer: &'b mut [u8], ) -> Option> { + if let Some(response) = self.try_handle_forwarded_dns_response(from, packet) { + return Some(response); + }; + let (gid, packet) = self.node.decapsulate( local, from, @@ -627,35 +619,42 @@ impl ClientState { !interface.upstream_dns.is_empty() } - /// Attempt to handle the given packet as a DNS packet. + /// Attempt to handle the given packet as a DNS query packet. /// /// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back. /// Returns `Err` if the packet is not a DNS query. - fn handle_dns<'a>( + fn try_handle_dns_query<'a>( &mut self, packet: MutableIpPacket<'a>, + now: Instant, ) -> Result>, (MutableIpPacket<'a>, IpAddr)> { match self .stub_resolver .handle(&self.dns_mapping, packet.as_immutable()) { Some(dns::ResolveStrategy::LocalResponse(query)) => Ok(Some(query)), - Some(dns::ResolveStrategy::ForwardQuery(query)) => { - // There's an edge case here, where the resolver's ip has been resolved before as - // a dns resource... we will ignore that weird case for now. - if let Some(upstream_dns) = self.dns_mapping.get_by_left(&query.query.destination()) - { - let ip = upstream_dns.ip(); + Some(dns::ResolveStrategy::ForwardQuery { + upstream: server, + query_id: id, + payload, + original_src, + }) => { + let ip = server.ip(); - // In case the DNS server is a CIDR resource, it needs to go through the tunnel. - if self.is_upstream_set_by_the_portal() - && self.active_cidr_resources.longest_match(ip).is_some() - { - return Err((packet, ip)); - } + // In case the DNS server is a CIDR resource, it needs to go through the tunnel. + if self.is_upstream_set_by_the_portal() + && self.active_cidr_resources.longest_match(ip).is_some() + { + return Err((packet, ip)); } - self.buffered_dns_queries.push_back(query.into_owned()); + self.forwarded_dns_queries + .insert(id, (original_src, now + IDS_EXPIRE)); + self.buffered_transmits.push_back(Transmit { + src: None, + dst: server, + payload: Cow::Owned(payload), + }); Ok(None) } @@ -666,46 +665,26 @@ impl ClientState { } } - #[tracing::instrument(level = "debug", skip_all, fields(name = %query.name, server = %query.query.destination()))] // On debug level, we can log potentially sensitive information such as domain names. - pub(crate) fn on_dns_result( + fn try_handle_forwarded_dns_response<'a>( &mut self, - query: DnsQuery<'static>, - response: Result< - Result< - Result, - futures_bounded::Timeout, - >, - DnsQueryError, - >, - ) { - let query = query.query; - let make_error_reply = { - let query = query.clone(); + from: SocketAddr, + packet: &[u8], + ) -> Option> { + // The sentinel DNS server shall be the source. If we don't have a sentinel DNS for this socket, it cannot be a DNS response. + let saddr = *self.dns_mapping.get_by_right(&DnsServer::from(from))?; + let sport = DNS_PORT; - |e: &dyn fmt::Display| { - // To avoid sensitive data getting into the logs, only log the error if debug logging is enabled. - // We always want to see a warning. - if tracing::enabled!(Level::DEBUG) { - tracing::warn!("DNS query failed: {e}"); - } else { - tracing::warn!("DNS query failed"); - }; + let message = Message::from_slice(packet).ok()?; + let query_id = message.header().id(); - ip_packet::make::dns_err_response(query, hickory_proto::op::ResponseCode::ServFail) - .into_immutable() - } - }; + let (destination, _) = self.forwarded_dns_queries.remove(&query_id)?; + let daddr = destination.ip(); + let dport = destination.port(); - let dns_reply = match response { - Ok(Ok(response)) => match dns::build_response_from_resolve_result(query, response) { - Ok(dns_reply) => dns_reply, - Err(e) => make_error_reply(&e), - }, - Ok(Err(timeout)) => make_error_reply(&timeout), - Err(e) => make_error_reply(&e), - }; - - self.buffered_packets.push_back(dns_reply); + Some( + ip_packet::make::udp_packet(saddr, daddr, sport, dport, packet.to_vec()) + .into_immutable(), + ) } pub fn on_connection_failed(&mut self, resource: ResourceId) { @@ -844,15 +823,13 @@ impl ClientState { .filter(|resource| self.is_resource_enabled(resource)) } - #[must_use] - pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec) -> bool { + pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec) { self.system_resolvers = new_dns; self.update_dns_mapping() } - #[must_use] - pub(crate) fn update_interface_config(&mut self, config: InterfaceConfig) -> bool { + pub(crate) fn update_interface_config(&mut self, config: InterfaceConfig) { self.interface_config = Some(config); self.update_dns_mapping() @@ -862,10 +839,6 @@ impl ClientState { self.buffered_packets.pop_front() } - pub fn poll_dns_queries(&mut self) -> Option> { - self.buffered_dns_queries.pop_front() - } - pub fn poll_timeout(&mut self) -> Option { // The number of mangled DNS queries is expected to be fairly small because we only track them whilst connecting to a CIDR resource that is a DNS server. // Thus, sorting these values on-demand even within `poll_timeout` is expected to be performant enough. @@ -878,6 +851,7 @@ impl ClientState { pub fn handle_timeout(&mut self, now: Instant) { self.node.handle_timeout(now); self.mangled_dns_queries.retain(|_, exp| now < *exp); + self.forwarded_dns_queries.retain(|_, (_, exp)| now < *exp); self.drain_node_events(); } @@ -965,7 +939,9 @@ impl ClientState { } pub(crate) fn poll_transmit(&mut self) -> Option> { - self.node.poll_transmit() + self.buffered_transmits + .pop_front() + .or_else(|| self.node.poll_transmit()) } /// Sets a new set of resources. @@ -1078,9 +1054,11 @@ impl ClientState { self.resources_gateways.remove(&id); } - fn update_dns_mapping(&mut self) -> bool { + fn update_dns_mapping(&mut self) { let Some(config) = &self.interface_config else { - return false; + tracing::debug!("Unable to update DNS servesr without interface configuration"); + + return; }; let effective_dns_servers = @@ -1089,7 +1067,9 @@ impl ClientState { if HashSet::<&DnsServer>::from_iter(effective_dns_servers.iter()) == HashSet::from_iter(self.dns_mapping.right_values()) { - return false; + tracing::debug!("Effective DNS servers are unchanged"); + + return; } let dns_mapping = sentinel_dns_mapping( @@ -1121,8 +1101,6 @@ impl ClientState { ip4: self.routes().filter_map(utils::ipv4).collect(), ip6: self.routes().filter_map(utils::ipv6).collect(), }); - - true } pub fn update_relays( @@ -1390,134 +1368,6 @@ mod tests { assert!(is_definitely_not_a_resource(ip("ff02::2"))) } - #[test] - fn update_system_dns_works() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let dns_changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - - assert!(dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]); - } - - #[test] - fn update_system_dns_without_change_is_a_no_op() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - let dns_changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]); - } - - #[test] - fn update_system_dns_with_change_works() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - - assert!(dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]); - } - - #[test] - fn update_to_system_with_sentinels_are_ignored() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_without_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]); - let dns_changed = client_state.update_system_resolvers(vec![ - ip("1.1.1.1"), - ip("100.100.111.1"), - ip("fd00:2021:1111:8000:100:100:111:0"), - ]); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]); - } - - #[test] - fn upstream_dns_wins_over_system() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_with_dns()); - - let dns_changed = client_state.update_dns_mapping(); - assert!(dns_changed); - - let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - assert!(!dns_changed); - - dns_mapping_is_exactly(client_state.dns_mapping(), dns_list()); - } - - #[test] - fn upstream_dns_change_updates() { - let mut client_state = ClientState::for_test(); - - let dns_changed = client_state.update_interface_config(interface_config_with_dns()); - - assert!(dns_changed); - - let dns_changed = client_state.update_interface_config(InterfaceConfig { - upstream_dns: vec![dns("8.8.8.8:53")], - ..interface_config_without_dns() - }); - - assert!(dns_changed); - - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("8.8.8.8:53")]); - } - - #[test] - fn upstream_dns_no_change_is_a_no_op() { - let mut client_state = ClientState::for_test(); - client_state.interface_config = Some(interface_config_with_dns()); - - let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - - assert!(dns_changed); - - let dns_changed = client_state.update_interface_config(interface_config_with_dns()); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), dns_list()); - } - - #[test] - fn upstream_dns_sentinels_are_ignored() { - let mut client_state = ClientState::for_test(); - let mut config = interface_config_with_dns(); - - let _ = client_state.update_interface_config(config.clone()); - - config.upstream_dns.push(dns("100.100.111.1:53")); - config - .upstream_dns - .push(dns("[fd00:2021:1111:8000:100:100:111:0]:53")); - - let dns_changed = client_state.update_interface_config(config); - - assert!(!dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), dns_list()) - } - - #[test] - fn system_dns_takes_over_when_upstream_are_unset() { - let mut client_state = ClientState::for_test(); - let _ = client_state.update_interface_config(interface_config_with_dns()); - - let _ = client_state.update_system_resolvers(vec![ip("1.0.0.1")]); - let dns_changed = client_state.update_interface_config(interface_config_without_dns()); - - assert!(dns_changed); - dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]); - } - #[test] fn sentinel_dns_works() { let servers = dns_list(); @@ -1559,29 +1409,6 @@ mod tests { } } - fn dns_mapping_is_exactly(mapping: BiMap, servers: Vec) { - assert_eq!( - HashSet::<&DnsServer>::from_iter(mapping.right_values()), - HashSet::from_iter(servers.iter()) - ) - } - - fn interface_config_without_dns() -> InterfaceConfig { - InterfaceConfig { - ipv4: "10.0.0.1".parse().unwrap(), - ipv6: "fe80::".parse().unwrap(), - upstream_dns: Vec::new(), - } - } - - fn interface_config_with_dns() -> InterfaceConfig { - InterfaceConfig { - ipv4: "10.0.0.1".parse().unwrap(), - ipv6: "fe80::".parse().unwrap(), - upstream_dns: dns_list(), - } - } - fn sentinel_ranges() -> Vec { vec![ IpNetwork::V4(DNS_SENTINELS_V4), diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index d8cdc8055..5c47056e3 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -7,16 +7,11 @@ use domain::base::{ Message, MessageBuilder, ToName, }; use domain::rdata::AllRecordData; -use hickory_resolver::lookup::Lookup; -use hickory_resolver::proto::error::{ProtoError, ProtoErrorKind}; -use hickory_resolver::proto::op::MessageType; -use hickory_resolver::proto::rr::RecordType; -use ip_packet::udp::UdpPacket; use ip_packet::IpPacket; use ip_packet::Packet as _; use itertools::Itertools; use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; const DNS_TTL: u32 = 1; const REVERSE_DNS_ADDRESS_END: &str = "arpa"; @@ -34,22 +29,18 @@ pub struct StubResolver { known_hosts: KnownHosts, } -#[derive(Debug)] -pub struct DnsQuery<'a> { - pub name: DomainName, - pub record_type: RecordType, - // We could be much more efficient with this field, - // we only need the header to create the response. - pub query: ip_packet::IpPacket<'a>, -} - /// Tells the Client how to reply to a single DNS query #[derive(Debug)] -pub(crate) enum ResolveStrategy<'a> { +pub(crate) enum ResolveStrategy { /// The query is for a Resource, we have an IP mapped already, and we can respond instantly LocalResponse(IpPacket<'static>), - /// The query is for a non-Resource, forward it to an upstream or system resolver - ForwardQuery(DnsQuery<'a>), + /// The query is for a non-Resource, forward it to an upstream or system resolver. + ForwardQuery { + upstream: SocketAddr, + original_src: SocketAddr, + query_id: u16, + payload: Vec, + }, } struct KnownHosts { @@ -193,29 +184,35 @@ impl StubResolver { self.dns_resources.values().contains(resource) } - // TODO: we can save a few allocations here still - // We don't need to support multiple questions/qname in a single query because - // nobody does it and since this run with each packet we want to squeeze as much optimization - // as we can therefore we won't do it. - // - // See: https://stackoverflow.com/a/55093896 /// Parses an incoming packet as a DNS query and decides how to respond to it /// /// Returns: /// - `None` if the packet is not a valid DNS query destined for one of our sentinel resolvers /// - Otherwise, a strategy for responding to the query - pub(crate) fn handle<'a>( + pub(crate) fn handle( &mut self, dns_mapping: &bimap::BiMap, - packet: IpPacket<'a>, - ) -> Option> { - dns_mapping.get_by_left(&packet.destination())?; + packet: IpPacket, + ) -> Option { + let upstream = dns_mapping.get_by_left(&packet.destination())?.address(); let datagram = packet.as_udp()?; - let message = as_dns(&datagram)?; + + // We only support DNS on port 53. + if datagram.get_destination() != DNS_PORT { + return None; + } + + let message = Message::from_octets(datagram.payload()).ok()?; + if message.header().qr() { return None; } + // We don't need to support multiple questions/qname in a single query because + // nobody does it and since this run with each packet we want to squeeze as much optimization + // as we can therefore we won't do it. + // + // See: https://stackoverflow.com/a/55093896 let question = message.first_question()?; let domain = question.qname().to_vec(); let qtype = question.qtype(); @@ -240,11 +237,12 @@ impl StubResolver { let resource_records = match (qtype, maybe_resource) { (_, Some(resource)) if !self.knows_resource(&resource) => { - return Some(ResolveStrategy::ForwardQuery(DnsQuery { - name: domain, - record_type: u16::from(qtype).into(), - query: packet, - })) + return Some(ResolveStrategy::ForwardQuery { + upstream, + query_id: message.header().id(), + payload: message.into_octets().to_vec(), + original_src: SocketAddr::new(packet.source(), datagram.get_source()), + }) } (Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource), (Rtype::AAAA, Some(resource)) => { @@ -256,11 +254,12 @@ impl StubResolver { vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))] } _ => { - return Some(ResolveStrategy::ForwardQuery(DnsQuery { - name: domain, - record_type: u16::from(qtype).into(), - query: packet, - })) + return Some(ResolveStrategy::ForwardQuery { + upstream, + query_id: message.header().id(), + payload: message.into_octets().to_vec(), + original_src: SocketAddr::new(packet.source(), datagram.get_source()), + }) } }; @@ -278,35 +277,6 @@ impl StubResolver { } } -impl<'a> DnsQuery<'a> { - pub(crate) fn into_owned(self) -> DnsQuery<'static> { - let Self { - name, - record_type, - query, - } = self; - let buf = query.packet().to_vec(); - let query = ip_packet::IpPacket::owned(buf) - .expect("We are constructing the ip packet from an ip packet"); - - DnsQuery { - name, - record_type, - query, - } - } -} - -impl Clone for DnsQuery<'static> { - fn clone(&self) -> Self { - Self { - name: self.name.clone(), - record_type: self.record_type, - query: self.query.clone(), - } - } -} - fn to_a_records(ips: impl Iterator) -> Vec, DomainName>> { ips.filter_map(get_v4) .map(domain::rdata::A::new) @@ -321,57 +291,13 @@ fn to_aaaa_records(ips: impl Iterator) -> Vec, - response: hickory_resolver::error::ResolveResult, -) -> Result { - let datagram = original_pkt.unwrap_as_udp(); - let mut message = original_pkt.unwrap_as_dns(); - - message.set_message_type(MessageType::Response); - message.set_recursion_available(true); - - let response = match response.map_err(|err| err.kind().clone()) { - Ok(response) => message.add_answers(response.records().to_vec()), - Err(hickory_resolver::error::ResolveErrorKind::Proto(ProtoError { kind, .. })) - if matches!(*kind, ProtoErrorKind::NoRecordsFound { .. }) => - { - let ProtoErrorKind::NoRecordsFound { - soa, response_code, .. - } = *kind - else { - panic!("Impossible - We matched on `ProtoErrorKind::NoRecordsFound` but then could not destructure that same variant"); - }; - if let Some(soa) = soa { - message.add_name_server(soa.into_record_of_rdata()); - } - - message.set_response_code(response_code) - } - Err(e) => { - return Err(e.into()); - } - }; - - let packet = ip_packet::make::udp_packet( - original_pkt.destination(), - original_pkt.source(), - datagram.get_destination(), - datagram.get_source(), - response.to_vec()?, - ) - .into_immutable(); - - Ok(packet) -} - fn build_dns_with_answer( - message: &Message<[u8]>, + message: Message<&[u8]>, qname: DomainName, records: Vec, DomainName>>, ) -> Option> { let mut answer_builder = MessageBuilder::new_vec() - .start_answer(message, Rcode::NOERROR) + .start_answer(&message, Rcode::NOERROR) .ok()?; answer_builder.header_mut().set_ra(true); @@ -384,12 +310,6 @@ fn build_dns_with_answer( Some(answer_builder.finish()) } -pub fn as_dns<'a>(pkt: &'a UdpPacket<'a>) -> Option<&'a Message<[u8]>> { - (pkt.get_destination() == DNS_PORT) - .then(|| Message::from_slice(pkt.payload()).ok()) - .flatten() -} - pub fn is_subdomain(name: &DomainName, resource: &str) -> bool { let question_mark = RelativeName::>::from_octets(b"\x01?".as_ref().into()).unwrap(); let Ok(resource) = DomainName::vec_from_str(resource) else { diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index ec14f7892..6846ad7e5 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,29 +1,15 @@ -use crate::{device_channel::Device, dns::DnsQuery, sockets::Sockets}; -use connlib_shared::messages::DnsServer; -use futures::Future; -use futures_bounded::FuturesTupleSet; +use crate::{device_channel::Device, sockets::Sockets}; use futures_util::FutureExt as _; -use hickory_proto::iocompat::AsyncIoTokioAsStd; -use hickory_proto::TokioTime; -use hickory_resolver::{ - config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}, - name_server::{GenericConnector, RuntimeProvider}, - AsyncResolver, TokioHandle, -}; use ip_packet::{IpPacket, MutableIpPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ - collections::HashMap, io, - net::{IpAddr, SocketAddr}, pin::Pin, sync::Arc, task::{ready, Context, Poll}, - time::{Duration, Instant}, + time::Instant, }; -const DNS_QUERIES_QUEUE_SIZE: usize = 100; - /// Bundles together all side-effects that connlib needs to have access to. pub struct Io { /// The TUN device offered to the user. @@ -33,29 +19,16 @@ pub struct Io { /// The UDP sockets used to send & receive packets from the network. sockets: Sockets, - tcp_socket_factory: Arc>, + _tcp_socket_factory: Arc>, udp_socket_factory: Arc>, timeout: Option>>, - - upstream_dns_servers: HashMap>>, - forwarded_dns_queries: FuturesTupleSet< - Result, - DnsQuery<'static>, - >, } pub enum Input<'a, I> { Timeout(Instant), Device(MutableIpPacket<'a>), Network(I), - DnsResponse( - DnsQuery<'static>, - Result< - Result, - futures_bounded::Timeout, - >, - ), } impl Io { @@ -73,13 +46,8 @@ impl Io { device: Device::new(), timeout: None, sockets, - tcp_socket_factory, + _tcp_socket_factory: tcp_socket_factory, udp_socket_factory, - upstream_dns_servers: HashMap::default(), - forwarded_dns_queries: FuturesTupleSet::new( - Duration::from_secs(60), - DNS_QUERIES_QUEUE_SIZE, - ), }) } @@ -90,10 +58,6 @@ impl Io { ip6_bffer: &'b mut [u8], device_buffer: &'b mut [u8], ) -> Poll>>>> { - if let Poll::Ready((response, query)) = self.forwarded_dns_queries.poll_unpin(cx) { - return Poll::Ready(Ok(Input::DnsResponse(query, response))); - } - if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { return Poll::Ready(Ok(Input::Network(network))); } @@ -126,50 +90,6 @@ impl Io { Ok(()) } - pub fn set_upstream_dns_servers( - &mut self, - dns_servers: impl IntoIterator, - ) { - tracing::info!("Setting new DNS resolvers"); - - self.forwarded_dns_queries = - FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE); - self.upstream_dns_servers = create_resolvers( - dns_servers, - TokioRuntimeProvider::new( - self.tcp_socket_factory.clone(), - self.udp_socket_factory.clone(), - ), - ); - } - - pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) -> Result<(), DnsQueryError> { - let upstream = query.query.destination(); - let resolver = self - .upstream_dns_servers - .get(&upstream) - .cloned() - .expect("Only DNS queries to known upstream servers should be forwarded to `Io`"); - - if self - .forwarded_dns_queries - .try_push( - { - let name = query.name.clone().to_string(); - let record_type = query.record_type; - - async move { resolver.lookup(&name, record_type).await } - }, - query, - ) - .is_err() - { - return Err(DnsQueryError::TooManyQueries); - } - - Ok(()) - } - pub fn reset_timeout(&mut self, timeout: Instant) { let timeout = tokio::time::Instant::from_std(timeout); @@ -198,92 +118,3 @@ impl Io { Ok(()) } } - -#[derive(Debug, thiserror::Error)] -pub enum DnsQueryError { - #[error("Too many ongoing DNS queries")] - TooManyQueries, -} - -/// Identical to [`TokioRuntimeProvider`](hickory_resolver::name_server::TokioRuntimeProvider) but using our own [`SocketFactory`]. -#[derive(Clone)] -struct TokioRuntimeProvider { - handle: TokioHandle, - tcp_socket_factory: Arc>, - udp_socket_factory: Arc>, -} - -impl TokioRuntimeProvider { - fn new( - tcp_socket_factory: Arc>, - udp_socket_factory: Arc>, - ) -> TokioRuntimeProvider { - Self { - handle: Default::default(), - tcp_socket_factory, - udp_socket_factory, - } - } -} - -impl RuntimeProvider for TokioRuntimeProvider { - type Handle = TokioHandle; - type Timer = TokioTime; - type Udp = UdpSocket; - type Tcp = AsyncIoTokioAsStd; - - fn create_handle(&self) -> Self::Handle { - self.handle.clone() - } - - fn connect_tcp( - &self, - server_addr: SocketAddr, - ) -> Pin>>> { - let socket = (self.tcp_socket_factory)(&server_addr); - Box::pin(async move { - let socket = socket?; - let stream = socket.connect(server_addr).await?; - - Ok(AsyncIoTokioAsStd(stream)) - }) - } - - fn bind_udp( - &self, - local_addr: SocketAddr, - _server_addr: SocketAddr, - ) -> Pin>>> { - let socket = (self.udp_socket_factory)(&local_addr); - - Box::pin(async move { socket }) - } -} - -fn create_resolvers( - dns_servers: impl IntoIterator, - runtime_provider: TokioRuntimeProvider, -) -> HashMap>> { - dns_servers - .into_iter() - .map(|(sentinel, srv)| { - let mut resolver_config = ResolverConfig::new(); - resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp)); - resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Tcp)); - - let mut resolver_opts = ResolverOpts::default(); - resolver_opts.edns0 = true; - resolver_opts.cache_size = 0; - resolver_opts.attempts = 1; - - ( - sentinel, - AsyncResolver::new_with_conn( - resolver_config, - resolver_opts, - GenericConnector::new(runtime_provider.clone()), - ), - ) - }) - .collect() -} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 379012b19..0378ab0f5 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -116,13 +116,6 @@ impl ClientTunnel { continue; } - if let Some(dns_query) = self.role_state.poll_dns_queries() { - if let Err(e) = self.io.perform_dns_query(dns_query.clone()) { - self.role_state.on_dns_result(dns_query, Err(e)) - } - continue; - } - if let Some(timeout) = self.role_state.poll_timeout() { self.io.reset_timeout(timeout); } @@ -163,10 +156,6 @@ impl ClientTunnel { continue; } - Poll::Ready(io::Input::DnsResponse(query, response)) => { - self.role_state.on_dns_result(query, Ok(response)); - continue; - } Poll::Pending => {} } @@ -254,9 +243,6 @@ impl GatewayTunnel { continue; } - Poll::Ready(io::Input::DnsResponse(_, _)) => { - unreachable!("Gateway does not (yet) resolve DNS queries via `Io`") - } Poll::Pending => {} } diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 8be1f5cdf..02dd19b3c 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -8,6 +8,7 @@ mod flux_capacitor; mod reference; mod run_count_appender; mod sim_client; +mod sim_dns; mod sim_gateway; mod sim_net; mod sim_relay; diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 9e4ca709d..a7d477788 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -1,5 +1,5 @@ use super::{ - composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, + composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; use crate::dns::is_subdomain; @@ -10,7 +10,7 @@ use connlib_shared::{ }, DomainName, StaticSecret, }; -use hickory_proto::rr::RecordType; +use domain::base::Rtype; use proptest::{prelude::*, sample}; use proptest_state_machine::ReferenceStateMachine; use std::{ @@ -27,6 +27,7 @@ pub(crate) struct ReferenceState { pub(crate) client: Host, pub(crate) gateways: BTreeMap>, pub(crate) relays: BTreeMap>, + pub(crate) dns_servers: BTreeMap>, pub(crate) portal: StubPortal, @@ -58,13 +59,14 @@ impl ReferenceStateMachine for ReferenceState { fn init_state() -> BoxedStrategy { stub_portal() - .prop_flat_map(move |portal| { + .prop_flat_map(|portal| { let gateways = portal.gateways(); let dns_resource_records = portal.dns_resource_records(); let client = portal.client(); let relays = relays(); let global_dns_records = global_dns_records(); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources. let drop_direct_client_traffic = any::(); + let dns_servers = dns_servers(); ( client, @@ -74,6 +76,7 @@ impl ReferenceStateMachine for ReferenceState { relays, global_dns_records, drop_direct_client_traffic, + dns_servers, ) }) .prop_filter_map( @@ -86,6 +89,7 @@ impl ReferenceStateMachine for ReferenceState { relays, mut global_dns, drop_direct_client_traffic, + dns_servers, )| { let mut routing_table = RoutingTable::default(); @@ -104,6 +108,12 @@ impl ReferenceStateMachine for ReferenceState { }; } + for (id, dns_server) in &dns_servers { + if !routing_table.add_host(*id, dns_server) { + return None; + }; + } + // Merge all DNS records into `global_dns`. global_dns.extend(records); @@ -111,6 +121,7 @@ impl ReferenceStateMachine for ReferenceState { c, gateways, relays, + dns_servers, portal, global_dns, drop_direct_client_traffic, @@ -120,7 +131,7 @@ impl ReferenceStateMachine for ReferenceState { ) .prop_filter( "private keys must be unique", - |(c, gateways, _, _, _, _, _)| { + |(c, gateways, _, _, _, _, _, _)| { let different_keys = gateways .iter() .map(|(_, g)| g.inner().key) @@ -135,6 +146,7 @@ impl ReferenceStateMachine for ReferenceState { client, gateways, relays, + dns_servers, portal, global_dns_records, drop_direct_client_traffic, @@ -144,6 +156,7 @@ impl ReferenceStateMachine for ReferenceState { client, gateways, relays, + dns_servers, portal, global_dns_records, network, @@ -162,13 +175,11 @@ impl ReferenceStateMachine for ReferenceState { CompositeStrategy::default() .with( 1, - system_dns_servers() - .prop_map(|servers| Transition::UpdateSystemDnsServers { servers }), + update_system_dns_servers(state.dns_servers.values().cloned().collect()), ) .with( 1, - upstream_dns_servers() - .prop_map(|servers| Transition::UpdateUpstreamDnsServers { servers }), + update_upstream_dns_servers(state.dns_servers.values().cloned().collect()), ) .with_if_not_empty( 5, @@ -396,12 +407,12 @@ impl ReferenceStateMachine for ReferenceState { state.portal.gateway_for_resource(r).copied() }) }), - Transition::UpdateSystemDnsServers { servers } => { + Transition::UpdateSystemDnsServers(servers) => { state .client .exec_mut(|client| client.system_dns_resolvers.clone_from(servers)); } - Transition::UpdateUpstreamDnsServers { servers } => { + Transition::UpdateUpstreamDnsServers(servers) => { state .client .exec_mut(|client| client.upstream_dns_resolvers.clone_from(servers)); @@ -513,34 +524,28 @@ impl ReferenceStateMachine for ReferenceState { ref_client.is_valid_icmp_packet(seq, identifier) && ref_client.dns_records.get(dst).is_some_and(|r| match src { - IpAddr::V4(_) => r.contains(&RecordType::A), - IpAddr::V6(_) => r.contains(&RecordType::AAAA), + IpAddr::V4(_) => r.contains(&Rtype::A), + IpAddr::V6(_) => r.contains(&Rtype::AAAA), }) && state.gateways.contains_key(gateway) } - Transition::UpdateSystemDnsServers { servers } => { - // TODO: PRODUCTION CODE DOES NOT HANDLE THIS! - - if state.client.ip4.is_none() && servers.iter().all(|s| s.is_ipv4()) { - return false; - } - if state.client.ip6.is_none() && servers.iter().all(|s| s.is_ipv6()) { - return false; + Transition::UpdateSystemDnsServers(servers) => { + if servers.is_empty() { + return true; // Clearing is allowed. } - true + servers + .iter() + .any(|dns_server| state.client.sending_socket_for(*dns_server).is_some()) } - Transition::UpdateUpstreamDnsServers { servers } => { - // TODO: PRODUCTION CODE DOES NOT HANDLE THIS! - - if state.client.ip4.is_none() && servers.iter().all(|s| s.ip().is_ipv4()) { - return false; - } - if state.client.ip6.is_none() && servers.iter().all(|s| s.ip().is_ipv6()) { - return false; + Transition::UpdateUpstreamDnsServers(servers) => { + if servers.is_empty() { + return true; // Clearing is allowed. } - true + servers + .iter() + .any(|dns_server| state.client.sending_socket_for(dns_server.ip()).is_some()) } Transition::SendDnsQuery { domain, dns_server, .. diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index eb7011001..46ff466fe 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -2,11 +2,10 @@ use super::{ reference::{private_key, PrivateKey, ResourceDst}, sim_net::{any_ip_stack, any_port, host, Host}, sim_relay::{map_explode, SimRelay}, - strategies::{latency, system_dns_servers, upstream_dns_servers}, - sut::domain_to_hickory_name, + strategies::latency, IcmpIdentifier, IcmpSeq, QueryId, }; -use crate::{tests::sut::hickory_name_to_domain, ClientState}; +use crate::ClientState; use bimap::BiMap; use connlib_shared::{ messages::{ @@ -16,10 +15,9 @@ use connlib_shared::{ proptest::{client_id, domain_name}, DomainName, }; -use hickory_proto::{ - op::MessageType, - rr::{rdata, RData, RecordType}, - serialize::binary::BinDecodable as _, +use domain::{ + base::{Message, Rtype, ToName}, + rdata::AllRecordData, }; use ip_network::{Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; @@ -80,7 +78,7 @@ impl SimClient { pub(crate) fn send_dns_query_for( &mut self, domain: DomainName, - r_type: RecordType, + r_type: Rtype, query_id: u16, dns_server: SocketAddr, now: Instant, @@ -92,15 +90,13 @@ impl SimClient { tracing::debug!(%dns_server, %domain, "Sending DNS query"); - let name = domain_to_hickory_name(domain); - let src = self .sut .tunnel_ip_for(dns_server) .expect("tunnel should be initialised"); let packet = ip_packet::make::dns_query( - name, + domain, r_type, SocketAddr::new(src, 9999), // An application would pick a random source port that is free. SocketAddr::new(dns_server, 53), @@ -130,14 +126,13 @@ impl SimClient { let packet = packet.to_owned().into_immutable(); if let Some(udp) = packet.as_udp() { - if let Ok(message) = hickory_proto::op::Message::from_bytes(udp.payload()) { - debug_assert_eq!( - message.message_type(), - MessageType::Query, + if let Ok(message) = Message::from_slice(udp.payload()) { + debug_assert!( + !message.header().qr(), "every DNS message sent from the client should be a DNS query" ); - self.sent_dns_queries.insert(message.id(), packet); + self.sent_dns_queries.insert(message.header().id(), packet); } } } @@ -175,18 +170,24 @@ impl SimClient { if let Some(udp) = packet.as_udp() { if udp.get_source() == 53 { - let mut message = hickory_proto::op::Message::from_bytes(udp.payload()) + let message = Message::from_slice(udp.payload()) .expect("ip packets on port 53 to be DNS packets"); self.received_dns_responses - .insert(message.id(), packet.to_owned()); + .insert(message.header().id(), packet.to_owned()); - for record in message.take_answers().into_iter() { - let domain = hickory_name_to_domain(record.name().clone()); + for record in message.answer().unwrap() { + let record = record.unwrap(); + let domain = record.owner().to_name(); - let ip = match record.data() { - Some(RData::A(rdata::A(ip4))) => IpAddr::from(*ip4), - Some(RData::AAAA(rdata::AAAA(ip6))) => IpAddr::from(*ip6), + #[allow(clippy::wildcard_enum_match_arm)] + let ip = match record + .into_any_record::>() + .unwrap() + .data() + { + AllRecordData::A(a) => IpAddr::from(a.addr()), + AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()), unhandled => { panic!("Unexpected record data: {unhandled:?}") } @@ -253,7 +254,7 @@ pub struct RefClient { /// The IPs assigned to a domain by connlib are an implementation detail that we don't want to model in these tests. /// Instead, we just remember what _kind_ of records we resolved to be able to sample a matching src IP. #[derivative(Debug = "ignore")] - pub(crate) dns_records: BTreeMap>, + pub(crate) dns_records: BTreeMap>, /// Whether we are connected to the gateway serving the Internet resource. pub(crate) connected_internet_resources: bool, @@ -287,12 +288,12 @@ impl RefClient { /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self) -> SimClient { let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. - let _ = client_state.update_interface_config(Interface { + client_state.update_interface_config(Interface { ipv4: self.tunnel_ip4, ipv6: self.tunnel_ip6, - upstream_dns: self.upstream_dns_resolvers, + upstream_dns: self.upstream_dns_resolvers.clone(), }); - let _ = client_state.update_system_resolvers(self.system_dns_resolvers); + client_state.update_system_resolvers(self.system_dns_resolvers.clone()); SimClient::new(self.id, client_state) } @@ -478,7 +479,7 @@ impl RefClient { .find(|id| !self.disabled_resources.contains(id)) } - fn resolved_domains(&self) -> impl Iterator)> + '_ { + fn resolved_domains(&self) -> impl Iterator)> + '_ { self.dns_records .iter() .filter(|(domain, _)| self.dns_resource_by_domain(domain).is_some()) @@ -499,7 +500,7 @@ impl RefClient { .filter_map(|(domain, records)| { records .iter() - .any(|r| matches!(r, RecordType::A)) + .any(|r| matches!(r, &Rtype::A)) .then_some(domain) }) .collect() @@ -510,7 +511,7 @@ impl RefClient { .filter_map(|(domain, records)| { records .iter() - .any(|r| matches!(r, RecordType::AAAA)) + .any(|r| matches!(r, &Rtype::AAAA)) .then_some(domain) }) .collect() @@ -525,7 +526,7 @@ impl RefClient { return self .upstream_dns_resolvers .iter() - .map(|s| s.address()) + .map(DnsServer::address) .collect(); } @@ -687,32 +688,6 @@ pub(crate) fn ref_client_host( ref_client(tunnel_ip4s, tunnel_ip6s), latency(300), // TODO: Increase with #6062. ) - .prop_filter("at least one DNS server needs to be reachable", |host| { - // TODO: PRODUCTION CODE DOES NOT HANDLE THIS! - - let upstream_dns_resolvers = &host.inner().upstream_dns_resolvers; - let system_dns = &host.inner().system_dns_resolvers; - - if !upstream_dns_resolvers.is_empty() { - if host.ip4.is_none() && upstream_dns_resolvers.iter().all(|s| s.ip().is_ipv4()) { - return false; - } - if host.ip6.is_none() && upstream_dns_resolvers.iter().all(|s| s.ip().is_ipv6()) { - return false; - } - - return true; - } - - if host.ip4.is_none() && system_dns.iter().all(|s| s.is_ipv4()) { - return false; - } - if host.ip6.is_none() && system_dns.iter().all(|s| s.is_ipv6()) { - return false; - } - - true - }) } fn ref_client( @@ -725,26 +700,16 @@ fn ref_client( client_id(), private_key(), known_hosts(), - system_dns_servers(), - upstream_dns_servers(), ) .prop_map( - move |( - tunnel_ip4, - tunnel_ip6, - id, - key, - known_hosts, - system_dns_resolvers, - upstream_dns_resolvers, - )| RefClient { + move |(tunnel_ip4, tunnel_ip6, id, key, known_hosts)| RefClient { id, key, known_hosts, tunnel_ip4, tunnel_ip6, - system_dns_resolvers, - upstream_dns_resolvers, + system_dns_resolvers: Default::default(), + upstream_dns_resolvers: Default::default(), internet_resource: Default::default(), cidr_resources: IpNetworkTable::new(), dns_resources: Default::default(), diff --git a/rust/connlib/tunnel/src/tests/sim_dns.rs b/rust/connlib/tunnel/src/tests/sim_dns.rs new file mode 100644 index 000000000..583bb8cab --- /dev/null +++ b/rust/connlib/tunnel/src/tests/sim_dns.rs @@ -0,0 +1,124 @@ +use super::{ + sim_net::{host, Host}, + strategies::latency, +}; +use connlib_shared::DomainName; +use domain::{ + base::{ + iana::{Class, Rcode}, + Message, MessageBuilder, Record, Rtype, ToName as _, Ttl, + }, + rdata::AllRecordData, +}; +use firezone_relay::IpStack; +use proptest::{ + arbitrary::any, + strategy::{Just, Strategy}, +}; +use snownet::Transmit; +use std::{ + borrow::Cow, + collections::{BTreeMap, HashSet}, + fmt, + net::{IpAddr, SocketAddr}, + time::Instant, +}; +use uuid::Uuid; + +pub(crate) fn dns_server_id() -> impl Strategy { + any::().prop_map(DnsServerId::from_u128) +} + +pub(crate) fn ref_dns_host(addr: SocketAddr) -> impl Strategy> { + let ip = addr.ip(); + let port = addr.port(); + + host( + Just(IpStack::from(ip)), + Just(port), + Just(RefDns {}), + latency(50), + ) +} + +#[derive(Debug, Clone)] +pub(crate) struct RefDns {} + +#[derive(Debug)] +pub(crate) struct SimDns {} + +impl SimDns { + pub(crate) fn receive( + &mut self, + global_dns_records: &BTreeMap>, + transmit: Transmit, + _now: Instant, + ) -> Option> { + let query = Message::from_octets(&transmit.payload).ok()?; + + let response = MessageBuilder::new_vec(); + let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); + + let query = query.sole_question().unwrap(); + let name = query.qname().to_vec(); + + let records = global_dns_records + .get(&name) + .into_iter() + .flatten() + .filter(|ip| { + #[allow(clippy::wildcard_enum_match_arm)] + match query.qtype() { + Rtype::A => ip.is_ipv4(), + Rtype::AAAA => ip.is_ipv6(), + _ => todo!(), + } + }) + .copied() + .map(|ip| match ip { + IpAddr::V4(v4) => AllRecordData::, DomainName>::A(v4.into()), + IpAddr::V6(v6) => AllRecordData::, DomainName>::Aaaa(v6.into()), + }) + .map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata)); + + for record in records { + answers.push(record).unwrap(); + } + + let payload = answers.finish(); + + tracing::debug!(%name, "Responding to DNS query"); + + Some(Transmit { + src: Some(transmit.dst), + dst: transmit.src.unwrap(), + payload: Cow::Owned(payload), + }) + } +} + +#[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct DnsServerId(Uuid); + +impl DnsServerId { + #[cfg(feature = "proptest")] + pub fn from_u128(v: u128) -> Self { + Self(Uuid::from_u128(v)) + } +} + +impl fmt::Display for DnsServerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if cfg!(feature = "proptest") { + write!(f, "{:X}", self.0.as_u128()) + } else { + write!(f, "{}", self.0) + } + } +} + +impl fmt::Debug for DnsServerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self, f) + } +} diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index e11c5e8f0..3ed441e7a 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -4,7 +4,7 @@ use super::{ sim_relay::{map_explode, SimRelay}, strategies::latency, }; -use crate::{tests::sut::hickory_name_to_domain, GatewayState}; +use crate::GatewayState; use connlib_shared::{ messages::{GatewayId, RelayId}, DomainName, @@ -67,6 +67,8 @@ impl SimGateway { ) -> Option> { let packet = packet.to_owned(); + // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? + if packet.as_icmp().is_some() { self.received_icmp_requests.push_back(packet.clone()); @@ -78,11 +80,7 @@ impl SimGateway { if packet.as_udp().is_some() { let response = ip_packet::make::dns_ok_response(packet, |name| { - global_dns_records - .get(&hickory_name_to_domain(name.clone())) - .cloned() - .into_iter() - .flatten() + global_dns_records.get(name).cloned().into_iter().flatten() }); let transmit = self.sut.encapsulate(response, now)?.into_owned(); diff --git a/rust/connlib/tunnel/src/tests/sim_net.rs b/rust/connlib/tunnel/src/tests/sim_net.rs index 718fd095c..adf97f8a3 100644 --- a/rust/connlib/tunnel/src/tests/sim_net.rs +++ b/rust/connlib/tunnel/src/tests/sim_net.rs @@ -1,3 +1,4 @@ +use super::sim_dns::DnsServerId; use crate::tests::buffered_transmits::BufferedTransmits; use crate::tests::strategies::documentation_ip6s; use connlib_shared::messages::{ClientId, GatewayId, RelayId}; @@ -122,6 +123,15 @@ impl Host { } } + pub(crate) fn single_socket(&self) -> SocketAddr { + match (self.ip4, self.ip6) { + (None, Some(ip6)) => SocketAddr::new(ip6.into(), self.default_port), + (Some(ip4), None) => SocketAddr::new(ip4.into(), self.default_port), + (Some(_), Some(_)) => panic!("Dual-stack host"), + (None, None) => panic!("No socket available"), + } + } + pub(crate) fn latency(&self) -> Duration { self.latency } @@ -253,6 +263,7 @@ pub(crate) enum HostId { Client(ClientId), Gateway(GatewayId), Relay(RelayId), + DnsServer(DnsServerId), Stale, } @@ -274,6 +285,12 @@ impl From for HostId { } } +impl From for HostId { + fn from(v: DnsServerId) -> Self { + Self::DnsServer(v) + } +} + pub(crate) fn host( socket_ips: impl Strategy, default_port: impl Strategy, diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index bf7e7cafc..60cb96f33 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -1,4 +1,9 @@ -use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal}; +use super::{ + sim_dns::{dns_server_id, ref_dns_host, DnsServerId, RefDns}, + sim_net::Host, + sim_relay::ref_relay_host, + stub_portal::StubPortal, +}; use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; use connlib_shared::{ messages::{ @@ -6,7 +11,7 @@ use connlib_shared::{ ResourceDescriptionCidr, ResourceDescriptionDns, ResourceDescriptionInternet, Site, SiteId, }, - DnsServer, GatewayId, RelayId, + GatewayId, RelayId, }, proptest::{ any_ip_network, cidr_resource, dns_resource, domain_name, gateway_id, relay_id, site, @@ -14,42 +19,15 @@ use connlib_shared::{ DomainName, }; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; -use itertools::Itertools as _; +use itertools::Itertools; use prop::sample; use proptest::{collection, prelude::*}; use std::{ collections::{BTreeMap, HashMap, HashSet}, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Duration, }; -pub(crate) fn upstream_dns_servers() -> impl Strategy> { - let ip4_dns_servers = collection::vec( - any::().prop_map(|ip| DnsServer::from((ip, 53))), - 1..4, - ); - let ip6_dns_servers = collection::vec( - any::().prop_map(|ip| DnsServer::from((ip, 53))), - 1..4, - ); - - // TODO: PRODUCTION CODE DOES NOT HAVE A SAFEGUARD FOR THIS YET. - // AN ADMIN COULD CONFIGURE ONLY IPv4 SERVERS IN WHICH CASE WE ARE SCREWED IF THE CLIENT ONLY HAS IPv6 CONNECTIVITY. - - prop_oneof![ - Just(Vec::new()), - (ip4_dns_servers, ip6_dns_servers).prop_map(|(mut ip4_servers, ip6_servers)| { - ip4_servers.extend(ip6_servers); - - ip4_servers - }) - ] -} - -pub(crate) fn system_dns_servers() -> impl Strategy> { - collection::vec(any::(), 1..4) // Always need at least 1 system DNS server. TODO: Should we test what happens if we don't? -} - pub(crate) fn global_dns_records() -> impl Strategy>> { collection::btree_map( domain_name(2..4).prop_map(|d| d.parse().unwrap()), @@ -144,6 +122,38 @@ pub(crate) fn relays() -> impl Strategy>> { collection::btree_map(relay_id(), ref_relay_host(), 1..=2) } +/// Sample a list of DNS servers. +/// +/// We make sure to always have at least 1 IPv4 and 1 IPv6 DNS server. +pub(crate) fn dns_servers() -> impl Strategy>> { + let ip4_dns_servers = collection::hash_set( + any::().prop_map(|ip| SocketAddr::from((ip, 53))), + 1..4, + ); + let ip6_dns_servers = collection::hash_set( + any::().prop_map(|ip| SocketAddr::from((ip, 53))), + 1..4, + ); + + (ip4_dns_servers, ip6_dns_servers).prop_flat_map(|(ip4_dns_servers, ip6_dns_servers)| { + let servers = Vec::from_iter(ip4_dns_servers.into_iter().chain(ip6_dns_servers)); + + // First, generate a unique number of IDs, one for each DNS server. + let ids = collection::hash_set(dns_server_id(), servers.len()); + + (ids, Just(servers)) + .prop_flat_map(move |(ids, servers)| { + let ids = ids.into_iter(); + + // Second, zip the IDs and addresses together. + ids.zip(servers) + .map(|(id, addr)| (Just(id), ref_dns_host(addr))) + .collect::>() + }) + .prop_map(BTreeMap::from_iter) // Third, turn the `Vec` of tuples into a `BTreeMap`. + }) +} + fn any_site(sites: HashSet) -> impl Strategy { sample::select(Vec::from_iter(sites)) } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 37f6917c3..fc5ce336c 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -1,6 +1,7 @@ use super::buffered_transmits::BufferedTransmits; use super::reference::ReferenceState; use super::sim_client::SimClient; +use super::sim_dns::{DnsServerId, SimDns}; use super::sim_gateway::SimGateway; use super::sim_net::{Host, HostId, RoutingTable}; use super::sim_relay::SimRelay; @@ -10,17 +11,12 @@ use crate::tests::assertions::*; use crate::tests::flux_capacitor::FluxCapacitor; use crate::tests::transition::Transition; use crate::utils::earliest; -use crate::{dns::DnsQuery, ClientEvent, GatewayEvent, Request}; +use crate::{ClientEvent, GatewayEvent, Request}; use connlib_shared::messages::client::ResourceDescription; use connlib_shared::{ messages::{ClientId, GatewayId, Interface, RelayId}, DomainName, }; -use hickory_proto::{ - op::Query, - rr::{RData, Record, RecordType}, -}; -use hickory_resolver::lookup::Lookup; use proptest_state_machine::{ReferenceStateMachine, StateMachineTest}; use secrecy::ExposeSecret as _; use snownet::Transmit; @@ -28,8 +24,6 @@ use std::iter; use std::{ collections::{BTreeMap, HashSet}, net::IpAddr, - str::FromStr as _, - sync::Arc, time::{Duration, Instant}, }; use tracing::debug_span; @@ -47,6 +41,7 @@ pub(crate) struct TunnelTest { client: Host, gateways: BTreeMap>, relays: BTreeMap>, + dns_servers: BTreeMap>, drop_direct_client_traffic: bool, network: RoutingTable, @@ -101,6 +96,16 @@ impl StateMachineTest for TunnelTest { }) .collect::>(); + let dns_servers = ref_state + .dns_servers + .iter() + .map(|(did, dns_server)| { + let dns_server = dns_server.map(|_, _, _| SimDns {}, debug_span!("dns", %did)); + + (*did, dns_server) + }) + .collect::>(); + // Configure client and gateway with the relays. client.exec_mut(|c| c.update_relays(iter::empty(), relays.iter(), flux_capacitor.now())); for gateway in gateways.values_mut() { @@ -116,6 +121,7 @@ impl StateMachineTest for TunnelTest { gateways, logger, relays, + dns_servers, }; let mut buffered_transmits = BufferedTransmits::default(); @@ -221,12 +227,12 @@ impl StateMachineTest for TunnelTest { buffered_transmits.push_from(transmit, &state.client, now); } - Transition::UpdateSystemDnsServers { servers } => { + Transition::UpdateSystemDnsServers(servers) => { state .client .exec_mut(|c| c.sut.update_system_resolvers(servers)); } - Transition::UpdateUpstreamDnsServers { servers } => { + Transition::UpdateUpstreamDnsServers(servers) => { state.client.exec_mut(|c| { c.sut.update_interface_config(Interface { ipv4: c.sut.tunnel_ip4().unwrap(), @@ -259,7 +265,7 @@ impl StateMachineTest for TunnelTest { // Simulate receiving `init`. state.client.exec_mut(|c| { - let _ = c.sut.update_interface_config(Interface { + c.sut.update_interface_config(Interface { ipv4, ipv6, upstream_dns, @@ -454,10 +460,6 @@ impl TunnelTest { ); continue; } - if let Some(query) = self.client.exec_mut(|client| client.sut.poll_dns_queries()) { - self.on_forwarded_dns_query(query, ref_state); - continue; - } self.client.exec_mut(|sim| { while let Some(packet) = sim.sut.poll_packets() { sim.on_received_packet(packet) @@ -514,7 +516,7 @@ impl TunnelTest { for (_, relay) in self.relays.iter_mut() { while let Some(transmit) = relay.poll_transmit(now) { - let Some(reply) = relay.exec_mut(|g| g.receive(transmit, now)) else { + let Some(reply) = relay.exec_mut(|r| r.receive(transmit, now)) else { continue; }; @@ -523,6 +525,18 @@ impl TunnelTest { relay.exec_mut(|r| r.sut.handle_timeout(now)) } + + for (_, dns_server) in self.dns_servers.iter_mut() { + while let Some(transmit) = dns_server.poll_transmit(now) { + let Some(reply) = + dns_server.exec_mut(|d| d.receive(global_dns_records, transmit, now)) + else { + continue; + }; + + buffered_transmits.push_from(reply, dns_server, now); + } + } } fn poll_timeout(&mut self) -> Option { @@ -591,6 +605,12 @@ impl TunnelTest { HostId::Stale => { tracing::debug!(%dst, "Dropping packet because host roamed away or is offline"); } + HostId::DnsServer(id) => { + self.dns_servers + .get_mut(&id) + .expect("unknown DNS server") + .receive(transmit, now); + } } } @@ -780,41 +800,6 @@ impl TunnelTest { ClientEvent::TunRoutesUpdated { .. } => {} } } - - // TODO: Should we vary the following things via proptests? - // - Forwarded DNS query timing out? - // - hickory error? - // - TTL? - fn on_forwarded_dns_query(&mut self, query: DnsQuery<'static>, ref_state: &ReferenceState) { - let all_ips = &ref_state - .global_dns_records - .get(&query.name) - .expect("Forwarded DNS query to be for known domain"); - - let name = domain_to_hickory_name(query.name.clone()); - let requested_type = query.record_type; - - let record_data = all_ips - .iter() - .filter_map(|ip| match (requested_type, ip) { - (RecordType::A, IpAddr::V4(v4)) => Some(RData::A((*v4).into())), - (RecordType::AAAA, IpAddr::V6(v6)) => Some(RData::AAAA((*v6).into())), - (RecordType::A, IpAddr::V6(_)) | (RecordType::AAAA, IpAddr::V4(_)) => None, - _ => unreachable!(), - }) - .map(|rdata| Record::from_rdata(name.clone(), 86400_u32, rdata)) - .collect::>(); - - self.client.exec_mut(|c| { - c.sut.on_dns_result( - query, - Ok(Ok(Ok(Lookup::new_with_max_ttl( - Query::query(name, requested_type), - record_data, - )))), - ) - }) - } } fn on_gateway_event( @@ -837,22 +822,3 @@ fn on_gateway_event( GatewayEvent::RefreshDns { .. } => todo!(), } } - -pub(crate) fn hickory_name_to_domain(mut name: hickory_proto::rr::Name) -> DomainName { - name.set_fqdn(false); // Hack to work around hickory always parsing as FQ - let name = name.to_string(); - - let domain = DomainName::from_chars(name.chars()).unwrap(); - debug_assert_eq!(name, domain.to_string()); - - domain -} - -pub(crate) fn domain_to_hickory_name(domain: DomainName) -> hickory_proto::rr::Name { - let domain = domain.to_string(); - - let name = hickory_proto::rr::Name::from_str(&domain).unwrap(); - debug_assert_eq!(name.to_string(), domain); - - name -} diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 006c6135a..6b2f4bbc5 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -1,9 +1,12 @@ -use super::sim_net::{any_ip_stack, any_port, Host}; +use super::{ + sim_dns::RefDns, + sim_net::{any_ip_stack, any_port, Host}, +}; use connlib_shared::{ messages::{client::ResourceDescription, DnsServer, RelayId, ResourceId}, DomainName, }; -use hickory_proto::rr::RecordType; +use domain::base::Rtype; use proptest::{prelude::*, sample}; use std::{ collections::{BTreeMap, HashSet}, @@ -50,16 +53,16 @@ pub(crate) enum Transition { SendDnsQuery { domain: DomainName, /// The type of DNS query we should send. - r_type: RecordType, + r_type: Rtype, /// The DNS query ID. query_id: u16, dns_server: SocketAddr, }, /// The system's DNS servers changed. - UpdateSystemDnsServers { servers: Vec }, + UpdateSystemDnsServers(Vec), /// The upstream DNS servers changed. - UpdateUpstreamDnsServers { servers: Vec }, + UpdateUpstreamDnsServers(Vec), /// Roam the client to a new pair of sockets. RoamClient { @@ -165,7 +168,7 @@ where ( domain, dns_server.prop_map_into(), - prop_oneof![Just(RecordType::A), Just(RecordType::AAAA)], + prop_oneof![Just(Rtype::A), Just(Rtype::AAAA)], any::(), ) .prop_map( @@ -185,3 +188,27 @@ pub(crate) fn roam_client() -> impl Strategy { port, }) } + +pub(crate) fn update_system_dns_servers( + dns_servers: Vec>, +) -> impl Strategy { + let max = dns_servers.len(); + + sample::subsequence(dns_servers, ..=max).prop_map(|seq| { + Transition::UpdateSystemDnsServers( + seq.into_iter().map(|h| h.single_socket().ip()).collect(), + ) + }) +} + +pub(crate) fn update_upstream_dns_servers( + dns_servers: Vec>, +) -> impl Strategy { + let max = dns_servers.len(); + + sample::subsequence(dns_servers, ..=max).prop_map(|seq| { + Transition::UpdateUpstreamDnsServers( + seq.into_iter().map(|h| h.single_socket().into()).collect(), + ) + }) +} diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index 3688a9139..d95f16b87 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -26,7 +26,7 @@ socket-factory = { workspace = true } thiserror = { version = "1.0", default-features = false } # This actually relies on many other features in Tokio, so this will probably # fail to build outside the workspace. -tokio = { workspace = true, features = ["macros", "signal", "process", "time"] } +tokio = { workspace = true, features = ["macros", "signal", "process", "time", "rt-multi-thread"] } tokio-stream = "0.1.15" tokio-util = { version = "0.7.11", features = ["codec"] } tracing = { workspace = true } diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 6f3275fe3..4863910a1 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -10,7 +10,7 @@ publish = false proptest = ["dep:proptest"] [dependencies] -hickory-proto = { workspace = true } +domain = "0.10.1" pnet_packet = { version = "0.35" } proptest = { version = "1", optional = true } thiserror = "1" diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index d90a7041b..dcb5f758e 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -8,6 +8,7 @@ pub use pnet_packet::*; #[cfg(all(test, feature = "proptest"))] mod proptests; +use domain::base::Message; use pnet_packet::{ icmp::{ destination_unreachable::IcmpCodes, echo_reply::MutableEchoReplyPacket, @@ -1409,9 +1410,9 @@ impl<'a> IpPacket<'a> { } /// Unwrap this [`IpPacket`] as a DNS message, panicking in case it is not. - pub fn unwrap_as_dns(&self) -> hickory_proto::op::Message { + pub fn unwrap_as_dns(&self) -> Message> { let udp = self.unwrap_as_udp(); - let message = match hickory_proto::op::Message::from_vec(udp.payload()) { + let message = match Message::from_octets(udp.payload().to_vec()) { Ok(message) => message, Err(e) => { panic!("Failed to parse UDP payload as DNS message: {e}"); diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index 14aef1ad9..24d27b7a0 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -1,9 +1,12 @@ //! Factory module for making all kinds of packets. use crate::{IpPacket, MutableIpPacket}; -use hickory_proto::{ - op::{Message, Query, ResponseCode}, - rr::{Name, RData, Record, RecordType}, +use domain::{ + base::{ + iana::{Class, Opcode, Rcode}, + MessageBuilder, Name, Question, Record, Rtype, ToName, Ttl, + }, + rdata::AllRecordData, }; use pnet_packet::{ ip::IpNextHeaderProtocol, @@ -244,24 +247,26 @@ where } pub fn dns_query( - domain: Name, - kind: RecordType, + domain: Name>, + kind: Rtype, src: SocketAddr, dst: SocketAddr, id: u16, ) -> MutableIpPacket<'static> { // Create the DNS query message - let mut msg = Message::new(); - msg.set_message_type(hickory_proto::op::MessageType::Query); - msg.set_op_code(hickory_proto::op::OpCode::Query); - msg.set_recursion_desired(true); - msg.set_id(id); + let mut msg_builder = MessageBuilder::new_vec(); + + msg_builder.header_mut().set_opcode(Opcode::QUERY); + msg_builder.header_mut().set_rd(true); + msg_builder.header_mut().set_id(id); // Create the query - let query = Query::query(domain, kind); - msg.add_query(query); + let mut question_builder = msg_builder.question(); + question_builder + .push(Question::new_in(domain, kind)) + .unwrap(); - let payload = msg.to_vec().unwrap(); + let payload = question_builder.finish(); udp_packet(src.ip(), dst.ip(), src.port(), dst.port(), payload) } @@ -269,60 +274,42 @@ pub fn dns_query( /// Makes a DNS response to the given DNS query packet, using a resolver callback. pub fn dns_ok_response( packet: IpPacket<'static>, - resolve: impl Fn(&Name) -> I, + resolve: impl Fn(&Name>) -> I, ) -> MutableIpPacket<'static> where I: Iterator, { let udp = packet.unwrap_as_udp(); - let mut query = packet.unwrap_as_dns(); + let query = packet.unwrap_as_dns(); - let mut response = Message::new(); - response.set_id(query.id()); - response.set_message_type(hickory_proto::op::MessageType::Response); + let response = MessageBuilder::new_vec(); + let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); - for query in query.take_queries() { - response.add_query(query.clone()); + for query in query.question() { + let query = query.unwrap(); + let name = query.qname().to_name(); - let records = resolve(query.name()) + let records = resolve(&name) .filter(|ip| { #[allow(clippy::wildcard_enum_match_arm)] - match query.query_type() { - RecordType::A => ip.is_ipv4(), - RecordType::AAAA => ip.is_ipv6(), + match query.qtype() { + Rtype::A => ip.is_ipv4(), + Rtype::AAAA => ip.is_ipv6(), _ => todo!(), } }) .map(|ip| match ip { - IpAddr::V4(v4) => RData::A(v4.into()), - IpAddr::V6(v6) => RData::AAAA(v6.into()), + IpAddr::V4(v4) => AllRecordData::, Name>>::A(v4.into()), + IpAddr::V6(v6) => AllRecordData::, Name>>::Aaaa(v6.into()), }) - .map(|rdata| Record::from_rdata(query.name().clone(), 86400_u32, rdata)); + .map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata)); - response.add_answers(records); + for record in records { + answers.push(record).unwrap(); + } } - let payload = response.to_vec().unwrap(); - - udp_packet( - packet.destination(), - packet.source(), - udp.get_destination(), - udp.get_source(), - payload, - ) -} - -/// Makes a DNS response to the given DNS query packet, using the given error code. -pub fn dns_err_response(packet: IpPacket<'static>, code: ResponseCode) -> MutableIpPacket<'static> { - let udp = packet.unwrap_as_udp(); - let query = packet.unwrap_as_dns(); - - debug_assert_ne!(code, ResponseCode::NoError); - - let response = Message::error_msg(query.id(), query.op_code(), code); - - let payload = response.to_vec().unwrap(); + let payload = answers.finish(); udp_packet( packet.destination(), diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index b9ae869ab..d8f003aca 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -4,12 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -async-trait = { version = "0.1", optional = true } -hickory-proto = { workspace = true, optional = true } quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" } socket2 = { workspace = true } tokio = { version = "1.39", features = ["net"] } tracing = "0.1" - -[features] -hickory = ["dep:hickory-proto", "dep:async-trait"] diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index fe7967733..f17d42103 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -301,62 +301,3 @@ impl UdpSocket { Ok(Some(src)) } } - -#[cfg(feature = "hickory")] -mod hickory { - use super::*; - use hickory_proto::{ - udp::DnsUdpSocket as DnsUdpSocketTrait, udp::UdpSocket as UdpSocketTrait, TokioTime, - }; - use tokio::net::UdpSocket as TokioUdpSocket; - - #[async_trait::async_trait] - impl UdpSocketTrait for crate::UdpSocket { - /// setups up a "client" udp connection that will only receive packets from the associated address - async fn connect(addr: SocketAddr) -> io::Result { - let inner = ::connect(addr).await?; - let socket = Self::new(inner)?; - - Ok(socket) - } - - /// same as connect, but binds to the specified local address for sending address - async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result { - let inner = - ::connect_with_bind(addr, bind_addr).await?; - let socket = Self::new(inner)?; - - Ok(socket) - } - - /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything) - async fn bind(addr: SocketAddr) -> io::Result { - let inner = ::bind(addr).await?; - let socket = Self::new(inner)?; - - Ok(socket) - } - } - - #[cfg(feature = "hickory")] - impl DnsUdpSocketTrait for crate::UdpSocket { - type Time = TokioTime; - - fn poll_recv_from( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - ::poll_recv_from(&self.inner, cx, buf) - } - - fn poll_send_to( - &self, - cx: &mut Context<'_>, - buf: &[u8], - target: SocketAddr, - ) -> Poll> { - ::poll_send_to(&self.inner, cx, buf, target) - } - } -}