diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3872d3160..c945d46f0 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2690,6 +2690,7 @@ dependencies = [ "gat-lending-iterator", "glob", "hex", + "http-client", "ip-packet", "ip_network", "ip_network_table", @@ -2697,6 +2698,7 @@ dependencies = [ "l3-tcp", "l3-udp-dns-client", "l4-tcp-dns-server", + "l4-udp-dns-client", "l4-udp-dns-server", "lru", "opentelemetry", diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 3d0379bdb..951698b36 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -200,7 +200,7 @@ impl Eventloop { return Ok(ControlFlow::Continue(())); }; - let dns = tunnel.state_mut().update_system_resolvers(dns); + let dns = tunnel.update_system_resolvers(dns); self.portal_cmd_tx .send(PortalCommand::UpdateDnsServers(dns)) diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 47678d274..ce4c055d5 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -29,12 +29,14 @@ futures-bounded = { workspace = true, features = ["tokio"] } gat-lending-iterator = { workspace = true } glob = { workspace = true } hex = { workspace = true } +http-client = { workspace = true } ip-packet = { workspace = true } ip_network = { workspace = true } ip_network_table = { workspace = true } itertools = { workspace = true, features = ["use_std"] } l3-udp-dns-client = { workspace = true } l4-tcp-dns-server = { workspace = true } +l4-udp-dns-client = { workspace = true } l4-udp-dns-server = { workspace = true } lru = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index bd7b91f5e..03c644ea5 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -778,17 +778,27 @@ impl ClientState { /// For DNS queries to IPs that are a CIDR resources we want to mangle and forward to the gateway that handles that resource. /// /// We only want to do this if the upstream DNS server is set by the portal, otherwise, the server might be a local IP. - fn should_forward_dns_query_to_gateway(&self, dns_server: IpAddr) -> bool { + fn should_forward_dns_query_to_gateway( + &self, + dns_server: &dns::Upstream, + ) -> Option { if !self.dns_config.has_custom_upstream() { - return false; + return None; } + + let server = match dns_server { + dns::Upstream::Do53 { server } => server, + dns::Upstream::DoH { .. } => return None, // If DoH upstreams are in effect, we never forward queries to upstreams. + }; + if self.active_internet_resource().is_some() { - return true; + return Some(*server); } self.active_cidr_resources - .longest_match(dns_server) + .longest_match(server.ip()) .is_some() + .then_some(*server) } /// Handles UDP & TCP packets targeted at our stub resolver. @@ -1014,7 +1024,7 @@ impl ClientState { /// /// Note: The returned list is not necessarily the list of DNS resolvers that is active. /// If DNS servers are defined in the portal, those will be preferred over the system defined ones. - pub fn update_system_resolvers(&mut self, new_dns: Vec) -> Vec { + pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec) -> Vec { let changed = self.dns_config.update_system_resolvers(new_dns); if !changed { @@ -1199,7 +1209,7 @@ impl ClientState { self.handle_dns_response( dns::RecursiveResponse { - server, + server: dns::Upstream::Do53 { server }, local, remote, query: query_result.query, @@ -1227,7 +1237,7 @@ impl ClientState { self.handle_dns_response( dns::RecursiveResponse { - server, + server: dns::Upstream::Do53 { server }, local, remote, query: query_result.query, @@ -1257,7 +1267,7 @@ impl ClientState { } } - fn handle_udp_dns_query(&mut self, upstream: SocketAddr, packet: IpPacket, now: Instant) { + fn handle_udp_dns_query(&mut self, upstream: dns::Upstream, packet: IpPacket, now: Instant) { let Some(datagram) = packet.as_udp() else { tracing::debug!(?packet, "Not a UDP packet"); @@ -1355,7 +1365,7 @@ impl ClientState { message: dns_types::Query, local: SocketAddr, remote: SocketAddr, - upstream: SocketAddr, + upstream: dns::Upstream, transport: dns::Transport, now: Instant, ) -> Option { @@ -1374,7 +1384,7 @@ impl ClientState { return Some(response); } dns::ResolveStrategy::RecurseLocal => { - if self.should_forward_dns_query_to_gateway(upstream.ip()) { + if let Some(upstream) = self.should_forward_dns_query_to_gateway(&upstream) { self.forward_dns_query_to_new_upstream_via_tunnel( local, remote, upstream, message, transport, now, ); diff --git a/rust/connlib/tunnel/src/client/dns_config.rs b/rust/connlib/tunnel/src/client/dns_config.rs index f76ec9551..b284376d2 100644 --- a/rust/connlib/tunnel/src/client/dns_config.rs +++ b/rust/connlib/tunnel/src/client/dns_config.rs @@ -1,6 +1,6 @@ use std::{ collections::HashSet, - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, }; use dns_types::DoHUrl; @@ -8,7 +8,7 @@ use ip_network::IpNetwork; use crate::{ client::{DNS_SENTINELS_V4, DNS_SENTINELS_V6, IpProvider}, - dns::DNS_PORT, + dns::{self, DNS_PORT}, }; #[derive(Debug, Default)] @@ -18,6 +18,7 @@ pub(crate) struct DnsConfig { /// The Do53 resolvers configured in the portal. /// /// Has priority over system-configured DNS servers. + /// Has priority over DoH resolvers. upstream_do53: Vec, /// The DoH resolvers configured in the portal. /// @@ -30,7 +31,7 @@ pub(crate) struct DnsConfig { #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct DnsMapping { - inner: Vec<(IpAddr, SocketAddr)>, + inner: Vec<(IpAddr, dns::Upstream)>, } impl DnsMapping { @@ -38,11 +39,11 @@ impl DnsMapping { self.inner.iter().map(|(ip, _)| ip).copied().collect() } - pub fn upstream_sockets(&self) -> Vec { + pub fn upstream_servers(&self) -> Vec { self.inner .iter() - .map(|(_, socket)| socket) - .copied() + .map(|(_, upstream)| upstream) + .cloned() .collect() } @@ -54,16 +55,16 @@ impl DnsMapping { // Most importantly, it is much easier for us to retain the ordering of the DNS servers if we don't use a map. #[cfg(test)] - pub(crate) fn sentinel_by_upstream(&self, upstream: SocketAddr) -> Option { + pub(crate) fn sentinel_by_upstream(&self, upstream: &dns::Upstream) -> Option { self.inner .iter() - .find_map(|(sentinel, candidate)| (candidate == &upstream).then_some(*sentinel)) + .find_map(|(sentinel, candidate)| (candidate == upstream).then_some(*sentinel)) } - pub(crate) fn upstream_by_sentinel(&self, sentinel: IpAddr) -> Option { + pub(crate) fn upstream_by_sentinel(&self, sentinel: IpAddr) -> Option { self.inner .iter() - .find_map(|(candidate, upstream)| (candidate == &sentinel).then_some(*upstream)) + .find_map(|(candidate, upstream)| (candidate == &sentinel).then_some(upstream.clone())) } } @@ -104,7 +105,7 @@ impl DnsConfig { } pub(crate) fn has_custom_upstream(&self) -> bool { - !self.upstream_do53.is_empty() + !self.upstream_do53.is_empty() || !self.upstream_doh.is_empty() } pub(crate) fn mapping(&mut self) -> DnsMapping { @@ -116,11 +117,14 @@ impl DnsConfig { } fn update_dns_mapping(&mut self) -> bool { - let effective_dns_servers = - effective_dns_servers(self.upstream_do53.clone(), self.system_resolvers.clone()); + let effective_dns_servers = effective_dns_servers( + self.upstream_do53.clone(), + self.upstream_doh.clone(), + self.system_resolvers.clone(), + ); - if HashSet::::from_iter(effective_dns_servers.clone()) - == HashSet::from_iter(self.mapping.upstream_sockets()) + if HashSet::::from_iter(effective_dns_servers.clone()) + == HashSet::from_iter(self.mapping.upstream_servers()) { tracing::debug!(servers = ?effective_dns_servers, "Effective DNS servers are unchanged"); @@ -135,12 +139,22 @@ impl DnsConfig { fn effective_dns_servers( upstream_do53: Vec, + upstream_doh: Vec, default_resolvers: Vec, -) -> Vec { +) -> Vec { if !upstream_do53.is_empty() { return upstream_do53 .into_iter() - .map(|ip| SocketAddr::new(ip, DNS_PORT)) + .map(|ip| dns::Upstream::Do53 { + server: SocketAddr::new(ip, DNS_PORT), + }) + .collect(); + } + + if !upstream_doh.is_empty() { + return upstream_doh + .into_iter() + .map(|server| dns::Upstream::DoH { server }) .collect(); } @@ -153,22 +167,28 @@ fn effective_dns_servers( default_resolvers .into_iter() - .map(|ip| SocketAddr::new(ip, DNS_PORT)) + .map(|ip| dns::Upstream::Do53 { + server: SocketAddr::new(ip, DNS_PORT), + }) .collect() } -fn sentinel_dns_mapping(dns: &[SocketAddr], old_sentinels: Vec) -> DnsMapping { +fn sentinel_dns_mapping(dns: &[dns::Upstream], old_sentinels: Vec) -> DnsMapping { let mut ip_provider = IpProvider::for_stub_dns_servers(old_sentinels); let mapping = dns .iter() - .copied() - .map(|i| { + .map(|u| { + let ip_addr = match u { + dns::Upstream::Do53 { server } => server.ip(), + dns::Upstream::DoH { .. } => IpAddr::V4(Ipv4Addr::UNSPECIFIED), // DoH servers are always mapped to IPv4 servers. + }; + ( ip_provider - .get_proxy_ip_for(&i.ip()) + .get_proxy_ip_for(&ip_addr) .expect("We only support up to 256 IPv4 DNS servers and 256 IPv6 DNS servers"), - i, + u.clone(), ) }) .collect(); @@ -204,11 +224,11 @@ mod tests { assert_eq!(config.mapping().sentinel_ips().len(), 3); assert_eq!( - config.mapping().upstream_sockets(), + config.mapping().upstream_servers(), vec![ - socket("1.1.1.1:53"), - socket("1.0.0.1:53"), - socket("[2606:4700:4700::1111]:53"), + do53("1.1.1.1:53"), + do53("1.0.0.1:53"), + do53("[2606:4700:4700::1111]:53"), ] ); } @@ -224,8 +244,8 @@ mod tests { assert_eq!(config.mapping().sentinel_ips().len(), 1); assert_eq!( - config.mapping().upstream_sockets(), - vec![socket("1.0.0.1:53"),] + config.mapping().upstream_servers(), + vec![do53("1.0.0.1:53"),] ); } @@ -238,8 +258,8 @@ mod tests { assert_eq!(config.mapping().sentinel_ips().len(), 1); assert_eq!( - config.mapping().upstream_sockets(), - vec![socket("1.1.1.1:53"),] + config.mapping().upstream_servers(), + vec![do53("1.1.1.1:53"),] ); } @@ -253,8 +273,8 @@ mod tests { assert_eq!(config.mapping().sentinel_ips().len(), 1); assert_eq!( - config.mapping().upstream_sockets(), - vec![socket("1.1.1.1:53"),] + config.mapping().upstream_servers(), + vec![do53("1.1.1.1:53"),] ); } @@ -262,7 +282,9 @@ mod tests { address.parse().unwrap() } - fn socket(socket: &str) -> SocketAddr { - socket.parse().unwrap() + fn do53(socket: &str) -> dns::Upstream { + dns::Upstream::Do53 { + server: socket.parse().unwrap(), + } } } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 3d730bf18..89c2feff8 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -2,8 +2,8 @@ use crate::client::IpProvider; use anyhow::Result; use connlib_model::{IpStack, ResourceId}; use dns_types::{ - DomainName, DomainNameRef, OwnedRecordData, Query, RecordType, Response, ResponseBuilder, - ResponseCode, + DoHUrl, DomainName, DomainNameRef, OwnedRecordData, Query, RecordType, Response, + ResponseBuilder, ResponseCode, }; use firezone_logging::err_with_src; use itertools::Itertools; @@ -54,7 +54,7 @@ struct Resource { #[derive(Debug)] pub(crate) struct RecursiveQuery { /// The server we want to send the query to. - pub server: SocketAddr, + pub server: Upstream, /// The local address we received the query on. pub local: SocketAddr, @@ -73,7 +73,7 @@ pub(crate) struct RecursiveQuery { #[derive(Debug)] pub(crate) struct RecursiveResponse { /// The server we sent the query to. - pub server: SocketAddr, + pub server: Upstream, /// The local address we received the original query on. pub local: SocketAddr, @@ -99,6 +99,14 @@ pub(crate) enum Transport { Tcp, } +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, derive_more::Display)] +pub enum Upstream { + #[display("Do53({server})")] + Do53 { server: SocketAddr }, + #[display("DoH({server})")] + DoH { server: DoHUrl }, +} + /// Tells the Client how to reply to a single DNS query #[derive(Debug)] pub(crate) enum ResolveStrategy { diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 63c64a925..9fe1978fc 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,3 +1,4 @@ +mod doh; mod gso_queue; mod nameserver_set; mod tcp_dns; @@ -6,10 +7,12 @@ mod udp_dns; use crate::{TunnelError, device_channel::Device, dns, otel, sockets::Sockets}; use anyhow::{Context as _, Result}; use chrono::{DateTime, Utc}; +use dns_types::DoHUrl; use futures::FutureExt as _; -use futures_bounded::FuturesTupleSet; +use futures_bounded::{FuturesMap, FuturesTupleSet}; use gat_lending_iterator::LendingIterator; use gso_queue::GsoQueue; +use http_client::HttpClient; use ip_packet::{Ecn, IpPacket, MAX_FZ_PAYLOAD}; use nameserver_set::NameserverSet; use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket}; @@ -57,6 +60,10 @@ pub struct Io { dns_queries: FuturesTupleSet, DnsQueryMetaData>, + udp_dns_client: l4_udp_dns_client::UdpDnsClient, + doh_clients: BTreeMap, + doh_clients_bootstrap: FuturesMap>, + timeout: Option>>, tun: Device, @@ -65,10 +72,10 @@ pub struct Io { dropped_packets: opentelemetry::metrics::Counter, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct DnsQueryMetaData { query: dns_types::Query, - server: SocketAddr, + server: dns::Upstream, local: SocketAddr, remote: SocketAddr, transport: dns::Transport, @@ -164,6 +171,10 @@ impl Io { tcp_socket_factory.clone(), udp_socket_factory.clone(), ), + udp_dns_client: l4_udp_dns_client::UdpDnsClient::new( + udp_socket_factory.clone(), + Vec::default(), + ), reval_nameserver_interval: tokio::time::interval(RE_EVALUATE_NAMESERVER_INTERVAL), tcp_socket_factory, udp_socket_factory, @@ -171,6 +182,11 @@ impl Io { || futures_bounded::Delay::tokio(DNS_QUERY_TIMEOUT), 1000, ), + doh_clients: Default::default(), + doh_clients_bootstrap: FuturesMap::new( + || futures_bounded::Delay::tokio(DNS_QUERY_TIMEOUT), + 10, + ), gso_queue: GsoQueue::new(), tun: Device::new(), udp_dns_server: Default::default(), @@ -203,6 +219,13 @@ impl Io { Ok(()) } + pub fn update_system_resolvers(&mut self, resolvers: Vec) { + tracing::debug!(servers = ?resolvers, "Re-configuring UDP DNS client with new upstreams"); + + self.udp_dns_client = + l4_udp_dns_client::UdpDnsClient::new(self.udp_socket_factory.clone(), resolvers) + } + pub fn poll_has_sockets(&mut self, cx: &mut Context<'_>) -> Poll<()> { self.sockets.poll_has_sockets(cx) } @@ -234,6 +257,16 @@ impl Io { // We purposely don't want to block the event loop here because we can do plenty of other work while this is running. let _ = self.nameservers.poll(cx); + while let Poll::Ready((url, result)) = self.doh_clients_bootstrap.poll_unpin(cx) { + match result { + Ok(Ok(client)) => { + self.doh_clients.insert(url.clone(), client); + } + Ok(Err(e)) => tracing::debug!(%url, "Failed to bootstrap DoH client: {e:#}"), + Err(e) => tracing::debug!(%url, "Failed to bootstrap DoH client: {e:#}"), + } + } + let network = self.sockets.poll_recv_from(cx).map(|network| { anyhow::Ok( network @@ -323,6 +356,18 @@ impl Io { }, }); + // We need to discard DoH clients if their queries fail because the connection got closed. + // They will get re-bootstrapped on the next requested DoH query. + if let Poll::Ready(response) = &dns_response + && let dns::Upstream::DoH { server } = &response.server + && let Err(e) = &response.message + && e.is::() + { + tracing::debug!(%server, "Connection of DoH client failed"); + + self.doh_clients.remove(server); + } + let timeout = self .timeout .as_mut() @@ -423,6 +468,10 @@ impl Io { self.dns_queries = FuturesTupleSet::new(|| futures_bounded::Delay::tokio(DNS_QUERY_TIMEOUT), 1000); self.nameservers.evaluate(); + + for (server, _) in std::mem::take(&mut self.doh_clients) { + self.bootstrap_doh_client(server); + } } pub fn reset_timeout(&mut self, timeout: Instant, reason: &'static str) { @@ -470,40 +519,69 @@ impl Io { pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) { let meta = DnsQueryMetaData { query: query.message.clone(), - server: query.server, + server: query.server.clone(), transport: query.transport, local: query.local, remote: query.remote, }; - match query.transport { - dns::Transport::Udp => { - if self - .dns_queries - .try_push( - udp_dns::send(self.udp_socket_factory.clone(), query.server, query.message), - meta, - ) - .is_err() - { - tracing::debug!("Failed to queue UDP DNS query") - } + match (query.transport, query.server) { + (dns::Transport::Udp, dns::Upstream::Do53 { server }) => { + self.queue_dns_query( + udp_dns::send(self.udp_socket_factory.clone(), server, query.message), + meta, + ); } - dns::Transport::Tcp => { - if self - .dns_queries - .try_push( - tcp_dns::send(self.tcp_socket_factory.clone(), query.server, query.message), - meta, - ) - .is_err() - { - tracing::debug!("Failed to queue TCP DNS query") - } + (dns::Transport::Tcp, dns::Upstream::Do53 { server }) => { + self.queue_dns_query( + tcp_dns::send(self.tcp_socket_factory.clone(), server, query.message), + meta, + ); + } + (_, dns::Upstream::DoH { server }) => { + let Some(http_client) = self.doh_clients.get(&server).cloned() else { + self.bootstrap_doh_client(server); + + // Queue a dummy "query" that instantly fails to ensure we don't let the application run into a timeout. + // This will trigger a SERVFAIL response. + self.queue_dns_query(async { anyhow::bail!("Bootstrapping DoH client") }, meta); + + return; + }; + + self.queue_dns_query(doh::send(http_client, server, query.message), meta); } } } + pub(crate) fn bootstrap_doh_client(&mut self, server: DoHUrl) { + if self.doh_clients.contains_key(&server) { + return; + } + + if self.doh_clients_bootstrap.contains(server.clone()) { + return; // Already bootstrapping. + } + + let socket_factory = self.tcp_socket_factory.clone(); + let addresses = self.udp_dns_client.resolve(server.host()); + + let _ = self + .doh_clients_bootstrap + .try_push(server.clone(), async move { + tracing::debug!(%server, "Bootstrapping DoH client"); + + let addresses = addresses.await?; + let http_client = + HttpClient::new(server.host().to_string(), addresses.clone(), socket_factory) + .await?; + + tracing::debug!(%server, "Bootstrapped DoH client"); + + Ok(http_client) + }); + } + pub(crate) fn send_udp_dns_response( &mut self, to: SocketAddr, @@ -531,6 +609,16 @@ impl Io { pub(crate) fn inc_dropped_packet(&self, attrs: &[opentelemetry::KeyValue]) { self.dropped_packets.add(1, attrs); } + + fn queue_dns_query( + &mut self, + future: impl Future> + Send + 'static, + meta: DnsQueryMetaData, + ) { + if self.dns_queries.try_push(future, meta.clone()).is_err() { + tracing::debug!(?meta, "Failed to queue DNS query") + } + } } fn is_max_wg_packet_size(d: &DatagramIn) -> bool { @@ -545,7 +633,7 @@ fn is_max_wg_packet_size(d: &DatagramIn) -> bool { #[cfg(test)] mod tests { use futures::task::noop_waker_ref; - use std::{future::poll_fn, ptr::addr_of_mut}; + use std::{future::poll_fn, net::Ipv4Addr, ptr::addr_of_mut}; use super::*; @@ -581,14 +669,62 @@ mod tests { assert!(timeout >= now, "timeout = {timeout:?}, now = {now:?}"); } + #[tokio::test] + async fn bootstrap_doh() { + let _guard = firezone_logging::test("debug"); + + let mut io = Io::for_test(); + io.update_system_resolvers(vec![IpAddr::from([1, 1, 1, 1])]); + + { + io.send_dns_query(example_com_recursive_query()); + + let input = io.next().await; + + assert_eq!( + input.dns_response.unwrap().message.unwrap_err().to_string(), + "Bootstrapping DoH client" + ); + } + + // Hack: Advance for a bit but timeout after 2s. We don't emit an event when the client is bootstrapped so this will always be `Pending`. + let _ = tokio::time::timeout(Duration::from_secs(2), io.next()).await; + + { + io.send_dns_query(example_com_recursive_query()); + + let input = io.next().await; + + assert_eq!( + input.dns_response.unwrap().message.unwrap().response_code(), + dns_types::ResponseCode::NOERROR + ); + } + } + + fn example_com_recursive_query() -> dns::RecursiveQuery { + dns::RecursiveQuery { + server: dns::Upstream::DoH { + server: DoHUrl::cloudflare(), + }, + local: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 11111), + remote: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 22222), + message: dns_types::Query::new( + "example.com".parse().unwrap(), + dns_types::RecordType::A, + ), + transport: dns::Transport::Udp, + } + } + static mut DUMMY_BUF: Buffers = Buffers { ip: Vec::new() }; /// Helper functions to make the test more concise. impl Io { fn for_test() -> Io { let mut io = Io::new( - Arc::new(|_| Err(io::Error::other("not implemented"))), - Arc::new(|_| Err(io::Error::other("not implemented"))), + Arc::new(socket_factory::tcp), + Arc::new(socket_factory::udp), BTreeSet::new(), ); io.set_tun(Box::new(DummyTun)); diff --git a/rust/connlib/tunnel/src/io/doh.rs b/rust/connlib/tunnel/src/io/doh.rs new file mode 100644 index 000000000..5d8216fc2 --- /dev/null +++ b/rust/connlib/tunnel/src/io/doh.rs @@ -0,0 +1,17 @@ +use anyhow::Result; +use dns_types::DoHUrl; +use http_client::HttpClient; + +pub async fn send( + client: HttpClient, + server: DoHUrl, + query: dns_types::Query, +) -> Result { + tracing::trace!(target: "wire::dns::recursive::https", %server, domain = %query.domain()); + + let request = query.try_into_http_request(&server)?; + let response = client.send_request(request)?.await?; + let response = dns_types::Response::try_from_http_response(response)?; + + Ok(response) +} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 48c429f6c..40bef8ed5 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -145,6 +145,13 @@ impl ClientTunnel { self.io.reset(); } + pub fn update_system_resolvers(&mut self, resolvers: Vec) -> Vec { + let resolvers = self.role_state.update_system_resolvers(resolvers); + self.io.update_system_resolvers(resolvers.clone()); // IO needs the system resolvers to bootstrap DoH upstream. + + resolvers + } + /// Shut down the Client tunnel. pub fn shut_down(mut self) -> BoxFuture<'static, Result<()>> { // Initiate shutdown. @@ -178,6 +185,16 @@ impl ClientTunnel { // Pass up existing events. if let Some(event) = self.role_state.poll_event() { + if let ClientEvent::TunInterfaceUpdated(config) = &event { + for url in &config.dns_by_sentinel.upstream_servers() { + let dns::Upstream::DoH { server } = url else { + continue; + }; + + self.io.bootstrap_doh_client(server.clone()); + } + } + return Poll::Ready(event); } @@ -480,7 +497,9 @@ impl GatewayTunnel { for query in udp_dns_queries { if let Some(nameserver) = self.io.fastest_nameserver() { self.io.send_dns_query(dns::RecursiveQuery { - server: SocketAddr::new(nameserver, dns::DNS_PORT), + server: dns::Upstream::Do53 { + server: SocketAddr::new(nameserver, dns::DNS_PORT), + }, local: query.local, remote: query.remote, message: query.message, @@ -504,7 +523,9 @@ impl GatewayTunnel { for query in tcp_dns_queries { if let Some(nameserver) = self.io.fastest_nameserver() { self.io.send_dns_query(dns::RecursiveQuery { - server: SocketAddr::new(nameserver, dns::DNS_PORT), + server: dns::Upstream::Do53 { + server: SocketAddr::new(nameserver, dns::DNS_PORT), + }, local: query.local, remote: query.remote, message: query.message, diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index ad6d7804a..292a03d7b 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -351,7 +351,7 @@ pub(crate) fn assert_udp_dns_packets_properties(ref_client: &RefClient, sim_clie for (dns_server, query_id) in ref_client.expected_udp_dns_handshakes.iter() { let _guard = tracing::info_span!(target: "assertions", "udp_dns", %query_id, %dns_server).entered(); - let key = &(*dns_server, *query_id); + let key = &(dns_server.clone(), *query_id); let queries = &sim_client.sent_udp_dns_queries; let responses = &sim_client.received_udp_dns_responses; @@ -374,7 +374,7 @@ pub(crate) fn assert_tcp_dns(ref_client: &RefClient, sim_client: &SimClient) { for (dns_server, query_id) in ref_client.expected_tcp_dns_handshakes.iter() { let _guard = tracing::info_span!(target: "assertions", "tcp_dns", %query_id, %dns_server).entered(); - let key = &(*dns_server, *query_id); + let key = &(dns_server.clone(), *query_id); let queries = &sim_client.sent_tcp_dns_queries; let responses = &sim_client.received_tcp_dns_responses; diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 58cd4f8c8..1319de69a 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -4,8 +4,8 @@ use super::{ composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; -use crate::client; use crate::proptest::domain_label; +use crate::{client, dns}; use crate::{dns::is_subdomain, proptest::relay_id}; use connlib_model::{GatewayId, RelayId, Site, StaticSecret}; use dns_types::{DomainName, RecordType}; @@ -756,10 +756,12 @@ impl ReferenceState { Transition::UpdateUpstreamDoHServers(_) => true, Transition::UpdateUpstreamSearchDomain(_) => true, Transition::SendDnsQueries(queries) => queries.iter().all(|query| { - let has_socket_for_server = state - .client - .sending_socket_for(query.dns_server.ip()) - .is_some(); + let has_socket_for_server = match query.dns_server { + crate::dns::Upstream::Do53 { server } => { + state.client.sending_socket_for(server.ip()).is_some() + } + crate::dns::Upstream::DoH { .. } => true, + }; let has_dns_server = state .client @@ -919,14 +921,19 @@ impl ReferenceState { Vec::from_iter(unique_domains) } - fn reachable_dns_servers(&self) -> Vec { + fn reachable_dns_servers(&self) -> Vec { self.client .inner() .expected_dns_servers() .into_iter() .filter(|s| match s { - SocketAddr::V4(_) => self.client.ip4.is_some(), - SocketAddr::V6(_) => self.client.ip6.is_some(), + crate::dns::Upstream::Do53 { + server: SocketAddr::V4(_), + } => self.client.ip4.is_some(), + crate::dns::Upstream::Do53 { + server: SocketAddr::V6(_), + } => self.client.ip6.is_some(), + crate::dns::Upstream::DoH { .. } => true, }) .collect() } diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index af532c3f7..38df0b937 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -8,7 +8,7 @@ use super::{ transition::{DPort, Destination, DnsQuery, DnsTransport, Identifier, SPort, Seq}, }; use crate::{ - ClientState, DnsMapping, DnsResourceRecord, + ClientState, DnsMapping, DnsResourceRecord, dns, messages::{UpstreamDo53, UpstreamDoH}, proptest::*, }; @@ -61,11 +61,11 @@ pub(crate) struct SimClient { pub(crate) resource_status: BTreeMap, - pub(crate) sent_udp_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>, - pub(crate) received_udp_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>, + pub(crate) sent_udp_dns_queries: HashMap<(dns::Upstream, QueryId), IpPacket>, + pub(crate) received_udp_dns_responses: BTreeMap<(dns::Upstream, QueryId), IpPacket>, - pub(crate) sent_tcp_dns_queries: HashSet<(SocketAddr, QueryId)>, - pub(crate) received_tcp_dns_responses: BTreeSet<(SocketAddr, QueryId)>, + pub(crate) sent_tcp_dns_queries: HashSet<(dns::Upstream, QueryId)>, + pub(crate) received_tcp_dns_responses: BTreeSet<(dns::Upstream, QueryId)>, pub(crate) sent_icmp_requests: HashMap<(Seq, Identifier), IpPacket>, pub(crate) received_icmp_replies: BTreeMap<(Seq, Identifier), IpPacket>, @@ -138,8 +138,8 @@ impl SimClient { } /// Returns the _effective_ DNS servers that connlib is using. - pub(crate) fn effective_dns_servers(&self) -> Vec { - self.dns_by_sentinel.upstream_sockets() + pub(crate) fn effective_dns_servers(&self) -> Vec { + self.dns_by_sentinel.upstream_servers() } pub(crate) fn effective_search_domain(&self) -> Option { @@ -160,11 +160,11 @@ impl SimClient { domain: DomainName, r_type: RecordType, query_id: u16, - upstream: SocketAddr, + upstream: dns::Upstream, dns_transport: DnsTransport, now: Instant, ) -> Option { - let Some(sentinel) = self.dns_by_sentinel.sentinel_by_upstream(upstream) else { + let Some(sentinel) = self.dns_by_sentinel.sentinel_by_upstream(&upstream) else { tracing::error!(%upstream, "Unknown DNS server"); return None; }; @@ -493,10 +493,10 @@ pub struct RefClient { /// The expected UDP DNS handshakes. #[debug(skip)] - pub(crate) expected_udp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, + pub(crate) expected_udp_dns_handshakes: VecDeque<(dns::Upstream, QueryId)>, /// The expected TCP DNS handshakes. #[debug(skip)] - pub(crate) expected_tcp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, + pub(crate) expected_tcp_dns_handshakes: VecDeque<(dns::Upstream, QueryId)>, } impl RefClient { @@ -926,11 +926,11 @@ impl RefClient { match query.transport { DnsTransport::Udp => { self.expected_udp_dns_handshakes - .push_back((query.dns_server, query.query_id)); + .push_back((query.dns_server.clone(), query.query_id)); } DnsTransport::Tcp => { self.expected_tcp_dns_handshakes - .push_back((query.dns_server, query.query_id)); + .push_back((query.dns_server.clone(), query.query_id)); } } @@ -1076,22 +1076,37 @@ impl RefClient { /// Returns the DNS servers that we expect connlib to use. /// - /// If there are upstream DNS servers configured in the portal, it should use those. + /// If there are upstream Do53 servers configured in the portal, it should use those. + /// If there are no custom servers defined, it should use the DoH servers specified in the portal. /// Otherwise it should use whatever was configured on the system prior to connlib starting. /// /// This purposely returns a `Vec` so we also assert the order! - pub(crate) fn expected_dns_servers(&self) -> Vec { + pub(crate) fn expected_dns_servers(&self) -> Vec { if !self.upstream_do53_resolvers.is_empty() { return self .upstream_do53_resolvers .iter() - .map(|u| SocketAddr::new(u.ip, 53)) + .map(|u| dns::Upstream::Do53 { + server: SocketAddr::new(u.ip, 53), + }) + .collect(); + } + + if !self.upstream_doh_resolvers.is_empty() { + return self + .upstream_doh_resolvers + .iter() + .map(|u| dns::Upstream::DoH { + server: u.url.clone(), + }) .collect(); } self.system_dns_resolvers .iter() - .map(|ip| SocketAddr::new(*ip, 53)) + .map(|ip| dns::Upstream::Do53 { + server: SocketAddr::new(*ip, 53), + }) .collect() } @@ -1185,7 +1200,12 @@ impl RefClient { return None; } - let maybe_active_cidr_resource = self.cidr_resource_by_ip(query.dns_server.ip()); + let server = match query.dns_server { + dns::Upstream::Do53 { server } => server, + dns::Upstream::DoH { .. } => return None, + }; + + let maybe_active_cidr_resource = self.cidr_resource_by_ip(server.ip()); let maybe_active_internet_resource = self.active_internet_resource(); maybe_active_cidr_resource.or(maybe_active_internet_resource) diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index ff0babbbe..5e71bddbc 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -279,7 +279,7 @@ impl TunnelTest { upstream_dns: vec![], upstream_do53, search_domain: ref_state.client.inner().search_domain.clone(), - upstream_doh: vec![], + upstream_doh: ref_state.client.inner().upstream_doh_resolvers(), }) }); } @@ -424,6 +424,7 @@ impl TunnelTest { let ipv6 = state.client.inner().sut.tunnel_ip_config().unwrap().v6; let system_dns = ref_state.client.inner().system_dns_resolvers(); let upstream_do53 = ref_state.client.inner().upstream_do53_resolvers(); + let upstream_doh = ref_state.client.inner().upstream_doh_resolvers(); let all_resources = ref_state.client.inner().all_resources(); let internet_resource_state = ref_state.client.inner().internet_resource_active; @@ -436,8 +437,8 @@ impl TunnelTest { ipv6, upstream_dns: Vec::new(), upstream_do53, + upstream_doh, search_domain: ref_state.client.inner().search_domain.clone(), - upstream_doh: Vec::new(), }); c.sut.update_system_resolvers(system_dns); c.sut.set_resources(all_resources, now); @@ -927,7 +928,18 @@ impl TunnelTest { for gateway in self.gateways.values_mut() { gateway.exec_mut(|g| { - g.deploy_new_dns_servers(config.dns_by_sentinel.upstream_sockets(), now) + // If DoH servers are configured, we never route them through the tunnel. + // Therefore, we also don't need to "deploy" any DNS servers here. + let upstream_do53_servers = config + .dns_by_sentinel + .upstream_servers() + .into_iter() + .filter_map(|u| match u { + dns::Upstream::Do53 { server } => Some(server), + dns::Upstream::DoH { .. } => None, + }); + + g.deploy_new_dns_servers(upstream_do53_servers, now) }) } diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index d4d60a45e..32b168bdf 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -1,5 +1,6 @@ use crate::{ client::{CidrResource, IPV4_RESOURCES, IPV6_RESOURCES, Resource}, + dns, messages::{UpstreamDo53, UpstreamDoH}, proptest::{host_v4, host_v6}, }; @@ -15,7 +16,7 @@ use prop::collection; use proptest::{prelude::*, sample}; use std::{ collections::{BTreeMap, BTreeSet}, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, num::NonZeroU16, }; @@ -119,7 +120,7 @@ pub(crate) struct DnsQuery { pub(crate) r_type: RecordType, /// The DNS query ID. pub(crate) query_id: u16, - pub(crate) dns_server: SocketAddr, + pub(crate) dns_server: dns::Upstream, pub(crate) transport: DnsTransport, } @@ -352,7 +353,7 @@ fn non_dns_ports() -> impl Strategy { /// Samples up to 5 DNS queries that will be sent concurrently into connlib. pub(crate) fn dns_queries( domain: impl Strategy)>, - dns_server: impl Strategy, + dns_server: impl Strategy, ) -> impl Strategy> { // Queries can be uniquely identified by the tuple of DNS server and query ID. let unique_queries = collection::btree_set((dns_server, any::()), 1..5);