diff --git a/.github/workflows/_integration_tests.yml b/.github/workflows/_integration_tests.yml index 702b0d078..b81bc5b34 100644 --- a/.github/workflows/_integration_tests.yml +++ b/.github/workflows/_integration_tests.yml @@ -106,6 +106,7 @@ jobs: # Too noisy can cause flaky tests due to the amount of data rust_log: debug - name: dns-nm + - name: tcp-dns - name: relay-graceful-shutdown - name: systemd/dns-systemd-resolved steps: diff --git a/rust/Cargo.lock b/rust/Cargo.lock index a01754d7f..628b77edc 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2095,6 +2095,7 @@ dependencies = [ "connlib-model", "derivative", "divan", + "dns-over-tcp", "domain", "firezone-logging", "firezone-relay", @@ -3145,7 +3146,6 @@ name = "ip-packet" version = "0.1.0" dependencies = [ "anyhow", - "domain", "etherparse", "proptest", "test-strategy", diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 2c453873f..74426a47b 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -13,6 +13,7 @@ chrono = { workspace = true } connlib-model = { workspace = true } derivative = "2.2.0" divan = { version = "0.1.14", optional = true } +dns-over-tcp = { workspace = true } domain = { workspace = true } firezone-logging = { workspace = true } futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] } diff --git a/rust/connlib/tunnel/proptest-regressions/tests.txt b/rust/connlib/tunnel/proptest-regressions/tests.txt index 949b601d8..e00bf7fbe 100644 --- a/rust/connlib/tunnel/proptest-regressions/tests.txt +++ b/rust/connlib/tunnel/proptest-regressions/tests.txt @@ -118,3 +118,9 @@ cc b4dd2a98e4e6aa29f875fa7b8e3af451b1ce6ef8b4e9d6c4cd29fcb68e9249de cc 0a717e57a998e97be9134007c6a102c5ebaba5c477c95003eaa8f3c4503f88f1 cc 1ead95151ff4ea386b990d1ec7c81a33a816bd8f81d3e3b54abf181e9ff7f3c7 cc 879b2d7d9592265e8cb2799fc0a5d6ab19c6637f53a3181d9613ac3be3e4e532 +cc 
a5f733ee61b9a545b93f5eccb71631918250f8b0657b2479c5f2e85c10fd013d +cc a5f733ee61b9a545b93f5eccb71631918250f8b0657b2479c5f2e85c10fd013d +cc 33cd1cba9c6ecf15d6ff86c3114752f2437e432c77f671f67b08116d2b507131 +cc d9793b201ec425bd77f9849ea48e63677014aeb4a91a55be9371b81e644b7a24 +cc 8fcbd19c41f0483d9b81aac2ab7440bb23d7796ef9f6bf346f73f0d633f65baa +cc 4494e475d22ff9a318d676f10c79f545982b7787d145925c3719fe47e9868acc diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 8c1362de5..6fd424612 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,6 +1,5 @@ mod resource; -use domain::base::iana::Rcode; pub(crate) use resource::{CidrResource, Resource}; #[cfg(all(feature = "proptest", test))] pub(crate) use resource::{DnsResource, InternetResource}; @@ -24,7 +23,7 @@ use itertools::Itertools; use crate::peer::GatewayOnClient; use crate::utils::earliest; use crate::ClientEvent; -use domain::base::{Message, MessageBuilder}; +use domain::base::Message; use lru::LruCache; use secrecy::{ExposeSecret as _, Secret}; use snownet::{ClientNode, EncryptBuffer, RelaySocket, Transmit}; @@ -74,6 +73,9 @@ const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); /// We only store [`GatewayId`]s so the memory footprint is negligible. const MAX_REMEMBERED_GATEWAYS: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(100) }; +/// How many concurrent TCP DNS clients we can serve _per_ sentinel DNS server IP. +const NUM_CONCURRENT_TCP_DNS_CLIENTS: usize = 10; + /// A sans-IO implementation of a Client's functionality. /// /// Internally, this composes a [`snownet::ClientNode`] with firezone's policy engine around resources. 
@@ -123,6 +125,12 @@ pub struct ClientState { /// Resources that have been disabled by the UI disabled_resources: BTreeSet, + tcp_dns_client: dns_over_tcp::Client, + tcp_dns_server: dns_over_tcp::Server, + /// Tracks the socket on which we received a TCP DNS query by the ID of the recursive DNS query we issued. + tcp_dns_sockets_by_upstream_and_query_id: + HashMap<(SocketAddr, u16), dns_over_tcp::SocketHandle>, + /// Stores the gateways we recently connected to. /// /// We use this as a hint to the portal to re-connect us to the same gateway for a resource. @@ -141,7 +149,11 @@ struct AwaitingConnectionDetails { } impl ClientState { - pub(crate) fn new(known_hosts: BTreeMap>, seed: [u8; 32]) -> Self { + pub(crate) fn new( + known_hosts: BTreeMap>, + seed: [u8; 32], + now: Instant, + ) -> Self { Self { awaiting_connection_details: Default::default(), resources_gateways: Default::default(), @@ -164,6 +176,9 @@ impl ClientState { recently_connected_gateways: LruCache::new(MAX_REMEMBERED_GATEWAYS), upstream_dns: Default::default(), buffered_dns_queries: Default::default(), + tcp_dns_client: dns_over_tcp::Client::new(now, seed), + tcp_dns_server: dns_over_tcp::Server::new(now), + tcp_dns_sockets_by_upstream_and_query_id: Default::default(), } } @@ -283,10 +298,119 @@ impl ClientState { now: Instant, buffer: &mut EncryptBuffer, ) -> Option { - let packet = match self.try_handle_dns(packet, now) { + let non_dns_packet = match self.try_handle_dns(packet, now) { ControlFlow::Break(()) => return None, ControlFlow::Continue(non_dns_packet) => non_dns_packet, }; + + self.encapsulate(non_dns_packet, now, buffer) + } + + /// Handles UDP packets received on the network interface. + /// + /// Most of these packets will be WireGuard encrypted IP packets and will thus yield an [`IpPacket`]. + /// Some of them will however be handled internally, for example, TURN control packets exchanged with relays. 
+ /// + /// In case this function returns `None`, you should call [`ClientState::handle_timeout`] next to fully advance the internal state. + pub(crate) fn handle_network_input( + &mut self, + local: SocketAddr, + from: SocketAddr, + packet: &[u8], + now: Instant, + ) -> Option { + let (gid, packet) = self.node.decapsulate( + local, + from, + packet.as_ref(), + now, + ) + .inspect_err(|e| tracing::debug!(%local, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}")) + .ok()??; + + if self.tcp_dns_client.accepts(&packet) { + self.tcp_dns_client.handle_inbound(packet); + return None; + } + + let Some(peer) = self.peers.get_mut(&gid) else { + tracing::error!(%gid, "Couldn't find connection by ID"); + + return None; + }; + + peer.ensure_allowed_src(&packet) + .inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}")) + .ok()?; + + let packet = maybe_mangle_dns_response_from_cidr_resource( + packet, + &self.dns_mapping, + &mut self.mangled_dns_queries, + now, + ); + + Some(packet) + } + + pub(crate) fn handle_dns_response(&mut self, response: dns::RecursiveResponse) { + let qid = response.query.header().id(); + let server = response.server; + let domain = response + .query + .sole_question() + .ok() + .map(|q| q.into_qname()) + .map(tracing::field::display); + + let _span = tracing::debug_span!("handle_dns_response", %qid, %server, domain).entered(); + + match (response.transport, response.message) { + (dns::Transport::Udp { .. 
}, Err(e)) if e.kind() == io::ErrorKind::TimedOut => { + tracing::debug!("Recursive UDP DNS query timed out") + } + (dns::Transport::Udp { source }, result) => { + let message = result + .inspect(|message| { + tracing::trace!("Received recursive UDP DNS response"); + + if message.header().tc() { + tracing::debug!("Upstream DNS server had to truncate response"); + } + }) + .unwrap_or_else(|e| { + tracing::debug!("Recursive UDP DNS query failed: {e}"); + + dns::servfail(response.query.for_slice_ref()) + }); + + self.try_queue_udp_dns_response(server, source, &message) + .log_unwrap_debug("Failed to queue UDP DNS response"); + } + (dns::Transport::Tcp { source }, result) => { + let message = result + .inspect(|_| { + tracing::trace!("Received recursive TCP DNS response"); + }) + .unwrap_or_else(|e| { + tracing::debug!("Recursive TCP DNS query failed: {e}"); + + dns::servfail(response.query.for_slice_ref()) + }); + + self.tcp_dns_server + .send_message(source, message) + .log_unwrap_debug("Failed to send TCP DNS response"); + } + } + } + + fn encapsulate( + &mut self, + packet: IpPacket, + now: Instant, + buffer: &mut EncryptBuffer, + ) -> Option { let dst = packet.destination(); if is_definitely_not_a_resource(dst) { @@ -326,88 +450,6 @@ impl ClientState { Some(transmit) } - /// Handles UDP packets received on the network interface. - /// - /// Most of these packets will be WireGuard encrypted IP packets and will thus yield an [`IpPacket`]. - /// Some of them will however be handled internally, for example, TURN control packets exchanged with relays. - /// - /// In case this function returns `None`, you should call [`ClientState::handle_timeout`] next to fully advance the internal state. 
- pub(crate) fn handle_network_input( - &mut self, - local: SocketAddr, - from: SocketAddr, - packet: &[u8], - now: Instant, - ) -> Option { - let (gid, packet) = self.node.decapsulate( - local, - from, - packet.as_ref(), - now, - ) - .inspect_err(|e| tracing::debug!(%local, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}")) - .ok()??; - - let Some(peer) = self.peers.get_mut(&gid) else { - tracing::error!(%gid, "Couldn't find connection by ID"); - - return None; - }; - - peer.ensure_allowed_src(&packet) - .inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}")) - .ok()?; - - let packet = maybe_mangle_dns_response_from_cidr_resource( - packet, - &self.dns_mapping, - &mut self.mangled_dns_queries, - now, - ); - - Some(packet) - } - - pub(crate) fn handle_dns_response(&mut self, response: dns::RecursiveResponse) { - let qid = response.query.header().id(); - let server = response.server; - let domain = response - .query - .sole_question() - .ok() - .map(|q| q.into_qname()) - .map(tracing::field::display); - - let _span = tracing::debug_span!("handle_dns_response", %qid, %server, domain).entered(); - - match (response.transport, response.message) { - (dns::Transport::Udp { .. 
}, Err(e)) if e.kind() == io::ErrorKind::TimedOut => { - tracing::debug!("Recursive DNS query timed out") - } - (dns::Transport::Udp { source }, result) => { - let message = result - .inspect(|message| { - tracing::trace!("Received recursive DNS response"); - - if message.header().tc() { - tracing::debug!("Upstream DNS server had to truncate response"); - } - }) - .unwrap_or_else(|e| { - tracing::debug!("Recursive DNS query failed: {e}"); - - MessageBuilder::new_vec() - .start_answer(&response.query, Rcode::SERVFAIL) - .expect("original query is valid") - .into_message() - }); - - self.try_queue_udp_dns_response(server, source, &message) - .log_unwrap_debug("Failed to queue UDP DNS response"); - } - } - } - fn try_queue_udp_dns_response( &mut self, from: SocketAddr, @@ -570,49 +612,18 @@ impl ClientState { } /// Handles UDP & TCP packets targeted at our stub resolver. - fn try_handle_dns(&mut self, mut packet: IpPacket, now: Instant) -> ControlFlow<(), IpPacket> { + fn try_handle_dns(&mut self, packet: IpPacket, now: Instant) -> ControlFlow<(), IpPacket> { let dst = packet.destination(); let Some(upstream) = self.dns_mapping.get_by_left(&dst).map(|s| s.address()) else { return ControlFlow::Continue(packet); // Not for our DNS resolver. 
}; - let (datagram, message) = match parse_udp_dns_message(&packet) { - Ok((datagram, message)) => (datagram, message), - Err(e) => { - tracing::trace!(?packet, "Failed to parse DNS query: {e:#}"); - return ControlFlow::Break(()); - } - }; - - let source = SocketAddr::new(packet.source(), datagram.source_port()); - - match self.stub_resolver.handle(message) { - dns::ResolveStrategy::LocalResponse(response) => { - self.try_queue_udp_dns_response(upstream, source, &response) - .log_unwrap_debug("Failed to queue UDP DNS response"); - } - dns::ResolveStrategy::Recurse => { - let query_id = message.header().id(); - - if self.should_forward_dns_query_to_gateway(upstream.ip()) { - tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query via tunnel"); - - self.mangled_dns_queries - .insert((upstream, message.header().id()), now + IDS_EXPIRE); - packet.set_dst(upstream.ip()); - packet.update_checksum(); - - return ControlFlow::Continue(packet); - } - - tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host"); - - self.buffered_dns_queries - .push_back(dns::RecursiveQuery::via_udp(source, upstream, message)); - } + if self.tcp_dns_server.accepts(&packet) { + self.tcp_dns_server.handle_inbound(packet); + return ControlFlow::Break(()); } - ControlFlow::Break(()) + self.handle_udp_dns_query(upstream, packet, now) } pub fn on_connection_failed(&mut self, resource: ResourceId) { @@ -693,6 +704,36 @@ impl ClientState { self.mangled_dns_queries.clear(); } + fn initialise_tcp_dns_client(&mut self) { + let Some(tun_config) = self.tun_config.as_ref() else { + return; + }; + + self.tcp_dns_client + .set_source_interface(tun_config.ip4, tun_config.ip6); + + let upstream_resolvers = self + .dns_mapping + .right_values() + .map(|s| s.address()) + .collect(); + + if let Err(e) = self.tcp_dns_client.set_resolvers(upstream_resolvers) { + tracing::warn!("Failed to connect to upstream DNS resolvers over TCP: {e:#}"); + } + } + + fn 
initialise_tcp_dns_server(&mut self) { + let sentinel_sockets = self + .dns_mapping + .left_values() + .map(|ip| SocketAddr::new(*ip, DNS_PORT)) + .collect(); + + self.tcp_dns_server + .set_listen_addresses::(sentinel_sockets); + } + pub fn set_disabled_resources(&mut self, new_disabled_resources: BTreeSet) { let current_disabled_resources = self.disabled_resources.clone(); @@ -805,16 +846,23 @@ impl ClientState { } pub fn poll_packets(&mut self) -> Option { - self.buffered_packets.pop_front() + self.buffered_packets + .pop_front() + .or_else(|| self.tcp_dns_server.poll_outbound()) } pub fn poll_timeout(&mut self) -> Option { // The number of mangled DNS queries is expected to be fairly small because we only track them whilst connecting to a CIDR resource that is a DNS server. // Thus, sorting these values on-demand even within `poll_timeout` is expected to be performant enough. let next_dns_query_expiry = self.mangled_dns_queries.values().min().copied(); - let next_node_timeout = self.node.poll_timeout(); - earliest(next_dns_query_expiry, next_node_timeout) + earliest( + earliest( + self.tcp_dns_client.poll_timeout(), + self.tcp_dns_server.poll_timeout(), + ), + earliest(self.node.poll_timeout(), next_dns_query_expiry), + ) } pub fn handle_timeout(&mut self, now: Instant) { @@ -822,6 +870,163 @@ impl ClientState { self.drain_node_events(); self.mangled_dns_queries.retain(|_, exp| now < *exp); + + self.advance_dns_tcp_sockets(now); + } + + /// Advance the TCP DNS server and client state machines. + /// + /// Receiving something on a TCP server socket may trigger packets to be sent on the TCP client socket and vice versa. + /// Therefore, we loop here until none of the `poll-X` functions return anything anymore. + fn advance_dns_tcp_sockets(&mut self, now: Instant) { + loop { + self.tcp_dns_server.handle_timeout(now); + self.tcp_dns_client.handle_timeout(now); + + // Check if we have any pending TCP DNS queries. 
+ if let Some(query) = self.tcp_dns_server.poll_queries() { + self.handle_tcp_dns_query(query); + continue; + } + + // Check if the client wants to emit any packets. + if let Some(packet) = self.tcp_dns_client.poll_outbound() { + let mut buffer = snownet::EncryptBuffer::new(); + + // All packets from the TCP DNS client _should_ go through the tunnel. + let Some(encryped_packet) = self.encapsulate(packet, now, &mut buffer) else { + continue; + }; + + let transmit = encryped_packet.to_transmit(&buffer).into_owned(); + self.buffered_transmits.push_back(transmit); + continue; + } + + // Check if the client has assembled a response to a query. + if let Some(query_result) = self.tcp_dns_client.poll_query_result() { + let server = query_result.server; + let qid = query_result.query.header().id(); + let known_sockets = &mut self.tcp_dns_sockets_by_upstream_and_query_id; + + let Some(source) = known_sockets.remove(&(server, qid)) else { + tracing::debug!(?known_sockets, %server, %qid, "Failed to find TCP socket handle for query result"); + + continue; + }; + + self.handle_dns_response(dns::RecursiveResponse { + server, + query: query_result.query, + message: query_result + .result + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{e:#}"))), + transport: dns::Transport::Tcp { source }, + }); + continue; + } + + break; + } + } + + fn handle_udp_dns_query( + &mut self, + upstream: SocketAddr, + mut packet: IpPacket, + now: Instant, + ) -> ControlFlow<(), IpPacket> { + let (datagram, message) = match parse_udp_dns_message(&packet) { + Ok((datagram, message)) => (datagram, message), + Err(e) => { + tracing::trace!(?packet, "Failed to parse DNS query: {e:#}"); + return ControlFlow::Break(()); + } + }; + + let source = SocketAddr::new(packet.source(), datagram.source_port()); + + match self.stub_resolver.handle(message) { + dns::ResolveStrategy::LocalResponse(response) => { + self.try_queue_udp_dns_response(upstream, source, &response) + .log_unwrap_debug("Failed to 
queue UDP DNS response"); + } + dns::ResolveStrategy::Recurse => { + let query_id = message.header().id(); + + if self.should_forward_dns_query_to_gateway(upstream.ip()) { + tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query via tunnel"); + + self.mangled_dns_queries + .insert((upstream, message.header().id()), now + IDS_EXPIRE); + packet.set_dst(upstream.ip()); + packet.update_checksum(); + + return ControlFlow::Continue(packet); + } + + let query_id = message.header().id(); + + tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host"); + + self.buffered_dns_queries + .push_back(dns::RecursiveQuery::via_udp(source, upstream, message)); + } + } + + ControlFlow::Break(()) + } + + fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query) { + let message = query.message; + + let Some(upstream) = self.dns_mapping.get_by_left(&query.local.ip()) else { + tracing::debug!("Received TCP packet for non-sentinel IP"); + debug_assert!( + false, + "We only dispatch packets to sentinel IPs to the TCP DNS server" + ); + return; + }; + let server = upstream.address(); + + match self.stub_resolver.handle(message.for_slice_ref()) { + dns::ResolveStrategy::LocalResponse(response) => { + self.tcp_dns_server + .send_message(query.socket, response) + .log_unwrap_debug("Failed to send TCP DNS response"); + } + dns::ResolveStrategy::Recurse => { + let query_id = message.header().id(); + + if self.should_forward_dns_query_to_gateway(server.ip()) { + match self.tcp_dns_client.send_query(server, message.clone()) { + Ok(()) => {} + Err(e) => { + tracing::debug!("Failed to send recursive TCP DNS query {e:#}"); + + self.tcp_dns_server + .send_message(query.socket, dns::servfail(message.for_slice_ref())) + .log_unwrap_debug("Failed to send TCP DNS response"); + return; + } + }; + + let existing = self + .tcp_dns_sockets_by_upstream_and_query_id + .insert((server, query_id), query.socket); + + debug_assert!(existing.is_none(), "Query 
IDs should be unique"); + + return; + } + + tracing::trace!(%server, %query_id, "Forwarding TCP DNS query"); + + self.buffered_dns_queries + .push_back(dns::RecursiveQuery::via_tcp(query.socket, server, message)); + } + }; } fn maybe_update_tun_routes(&mut self) { @@ -885,6 +1090,9 @@ impl ClientState { self.tun_config = Some(new_tun_config.clone()); self.buffered_events .push_back(ClientEvent::TunInterfaceUpdated(new_tun_config)); + + self.initialise_tcp_dns_client(); + self.initialise_tcp_dns_server(); } fn drain_node_events(&mut self) { @@ -965,6 +1173,11 @@ impl ClientState { self.node.reset(); self.recently_connected_gateways.clear(); // Ensure we don't have sticky gateways when we roam. self.drain_node_events(); + + // Resetting the client will trigger a failed `QueryResult` for each one that is in-progress. + // Failed queries get translated into `SERVFAIL` responses to the client. + // This will also allocate new local ports for our outgoing TCP connections. + self.initialise_tcp_dns_client(); } pub(crate) fn poll_transmit(&mut self) -> Option> { @@ -1480,7 +1693,7 @@ mod tests { impl ClientState { pub fn for_test() -> ClientState { - ClientState::new(BTreeMap::new(), rand::random()) + ClientState::new(BTreeMap::new(), rand::random(), Instant::now()) } } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 0aa57f1c2..dbd753591 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,6 +1,7 @@ use crate::client::IpProvider; use anyhow::{Context, Result}; use connlib_model::{DomainName, ResourceId}; +use dns_over_tcp::SocketHandle; use domain::rdata::AllRecordData; use domain::{ base::{ @@ -46,7 +47,7 @@ pub struct StubResolver { } /// A query that needs to be forwarded to an upstream DNS server for resolution. 
-#[derive(Debug, Clone)] +#[derive(Debug)] pub(crate) struct RecursiveQuery { pub server: SocketAddr, pub message: Message>, @@ -70,14 +71,29 @@ impl RecursiveQuery { transport: Transport::Udp { source }, } } + + pub(crate) fn via_tcp( + source: SocketHandle, + server: SocketAddr, + message: Message>, + ) -> Self { + Self { + server, + message, + transport: Transport::Tcp { source }, + } + } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug)] pub(crate) enum Transport { Udp { /// The original source we received the DNS query on. source: SocketAddr, }, + Tcp { + source: SocketHandle, + }, } /// Tells the Client how to reply to a single DNS query @@ -259,12 +275,7 @@ impl StubResolver { Err(e) => { tracing::trace!("Failed to handle DNS query: {e:#}"); - let response = MessageBuilder::new_vec() - .start_answer(&message, Rcode::SERVFAIL) - .unwrap() - .into_message(); - - ResolveStrategy::LocalResponse(response) + ResolveStrategy::LocalResponse(servfail(message)) } } } @@ -335,6 +346,13 @@ impl StubResolver { } } +pub fn servfail(message: Message<&[u8]>) -> Message> { + MessageBuilder::new_vec() + .start_answer(&message, Rcode::SERVFAIL) + .expect("should always be able to create a heap-allocated SERVFAIL message") + .into_message() +} + fn to_a_records(ips: impl Iterator) -> Vec, DomainName>> { ips.filter_map(get_v4) .map(domain::rdata::A::new) diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index aacd3d37a..4fcef0ff2 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -17,7 +17,10 @@ use std::{ task::{ready, Context, Poll}, time::{Duration, Instant}, }; -use tokio::sync::mpsc; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::mpsc, +}; use tun::Tun; /// Bundles together all side-effects that connlib needs to have access to. 
@@ -26,7 +29,7 @@ pub struct Io { sockets: Sockets, unwritten_packet: Option, - _tcp_socket_factory: Arc>, + tcp_socket_factory: Arc>, udp_socket_factory: Arc>, dns_queries: FuturesTupleSet>>, DnsQueryMetaData>, @@ -87,7 +90,7 @@ impl Io { inbound_packet_rx, timeout: None, sockets, - _tcp_socket_factory: tcp_socket_factory, + tcp_socket_factory, udp_socket_factory, unwritten_packet: None, dns_queries: FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000), @@ -117,19 +120,20 @@ impl Io { match self.dns_queries.poll_unpin(cx) { Poll::Ready((result, meta)) => { - let response = result - .map(|result| dns::RecursiveResponse { - server: meta.server, - query: meta.query.clone(), - message: result, - transport: meta.transport, - }) - .unwrap_or_else(|_| dns::RecursiveResponse { + let response = match result { + Ok(result) => dns::RecursiveResponse { server: meta.server, query: meta.query, - message: Err(io::Error::from(io::ErrorKind::TimedOut)), + message: result, transport: meta.transport, - }); + }, + Err(e @ futures_bounded::Timeout { .. }) => dns::RecursiveResponse { + server: meta.server, + query: meta.query, + message: Err(io::Error::new(io::ErrorKind::TimedOut, e)), + transport: meta.transport, + }, + }; return Poll::Ready(Ok(Input::DnsResponse(response))); } @@ -255,6 +259,48 @@ impl Io { tracing::debug!("Failed to queue UDP DNS query") } } + dns::Transport::Tcp { .. } => { + let factory = self.tcp_socket_factory.clone(); + let server = query.server; + let meta = DnsQueryMetaData { + query: query.message.clone(), + server, + transport: query.transport, + }; + + if self + .dns_queries + .try_push( + async move { + let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. 
+ let mut tcp_stream = tcp_socket.connect(server).await?; + + let query = query.message.into_octets(); + let dns_message_length = (query.len() as u16).to_be_bytes(); + + tcp_stream.write_all(&dns_message_length).await?; + tcp_stream.write_all(&query).await?; + + let mut response_length = [0u8; 2]; + tcp_stream.read_exact(&mut response_length).await?; + let response_length = u16::from_be_bytes(response_length) as usize; + + // A u16 is at most 65k, meaning we are okay to allocate here based on what the remote is sending. + let mut response = vec![0u8; response_length]; + tcp_stream.read_exact(&mut response).await?; + + let message = Message::from_octets(response) + .map_err(|_| io::Error::other("Failed to parse DNS message"))?; + + Ok(message) + }, + meta, + ) + .is_err() + { + tracing::debug!("Failed to queue TCP DNS query") + } + } } } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 527ba9e61..0e11c318a 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -95,7 +95,7 @@ impl ClientTunnel { ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: ClientState::new(known_hosts, rand::random()), + role_state: ClientState::new(known_hosts, rand::random(), Instant::now()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: Default::default(), diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index 67ce33876..51f69ff26 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -159,31 +159,31 @@ pub(crate) fn assert_routes_are_valid(ref_client: &RefClient, sim_client: &SimCl } } -pub(crate) fn assert_dns_packets_properties(ref_client: &RefClient, sim_client: &SimClient) { +pub(crate) fn assert_udp_dns_packets_properties(ref_client: &RefClient, sim_client: &SimClient) { let unexpected_dns_replies = find_unexpected_entries( - 
&ref_client.expected_dns_handshakes, - &sim_client.received_dns_responses, + &ref_client.expected_udp_dns_handshakes, + &sim_client.received_udp_dns_responses, |(_, id_a), (_, id_b)| id_a == id_b, ); if !unexpected_dns_replies.is_empty() { - tracing::error!(target: "assertions", ?unexpected_dns_replies, "❌ Unexpected DNS replies on client"); + tracing::error!(target: "assertions", ?unexpected_dns_replies, "❌ Unexpected UDP DNS replies on client"); } - for (dns_server, query_id) in ref_client.expected_dns_handshakes.iter() { + for (dns_server, query_id) in ref_client.expected_udp_dns_handshakes.iter() { let _guard = - tracing::info_span!(target: "assertions", "dns", %query_id, %dns_server).entered(); + tracing::info_span!(target: "assertions", "udp_dns", %query_id, %dns_server).entered(); let key = &(*dns_server, *query_id); - let queries = &sim_client.sent_dns_queries; - let responses = &sim_client.received_dns_responses; + let queries = &sim_client.sent_udp_dns_queries; + let responses = &sim_client.received_udp_dns_responses; let Some(client_sent_query) = queries.get(key) else { - tracing::error!(target: "assertions", ?queries, "❌ Missing DNS query on client"); + tracing::error!(target: "assertions", ?queries, "❌ Missing UDP DNS query on client"); continue; }; let Some(client_received_response) = responses.get(key) else { - tracing::error!(target: "assertions", ?responses, "❌ Missing DNS response on client"); + tracing::error!(target: "assertions", ?responses, "❌ Missing UDP DNS response on client"); continue; }; @@ -192,6 +192,26 @@ pub(crate) fn assert_dns_packets_properties(ref_client: &RefClient, sim_client: } } +pub(crate) fn assert_tcp_dns(ref_client: &RefClient, sim_client: &SimClient) { + for (dns_server, query_id) in ref_client.expected_tcp_dns_handshakes.iter() { + let _guard = + tracing::info_span!(target: "assertions", "tcp_dns", %query_id, %dns_server).entered(); + let key = &(*dns_server, *query_id); + + let queries = 
&sim_client.sent_tcp_dns_queries; + let responses = &sim_client.received_tcp_dns_responses; + + if queries.get(key).is_none() { + tracing::error!(target: "assertions", ?queries, "❌ Missing TCP DNS query on client"); + continue; + }; + if responses.get(key).is_none() { + tracing::error!(target: "assertions", ?responses, "❌ Missing TCP DNS response on client"); + continue; + }; + } +} + fn assert_correct_src_and_dst_ips( client_sent_request: &IpPacket, client_received_reply: &IpPacket, diff --git a/rust/connlib/tunnel/src/tests/dns_server_resource.rs b/rust/connlib/tunnel/src/tests/dns_server_resource.rs index 71c95f4b3..69573c4ca 100644 --- a/rust/connlib/tunnel/src/tests/dns_server_resource.rs +++ b/rust/connlib/tunnel/src/tests/dns_server_resource.rs @@ -1,6 +1,6 @@ use std::{ collections::{BTreeMap, BTreeSet, VecDeque}, - net::IpAddr, + net::{IpAddr, SocketAddr}, time::Instant, }; @@ -14,12 +14,47 @@ use domain::{ }; use ip_packet::IpPacket; +pub struct TcpDnsServerResource { + server: dns_over_tcp::Server, +} + #[derive(Debug, Default)] pub struct UdpDnsServerResource { inbound_packets: VecDeque, outbound_packets: VecDeque, } +impl TcpDnsServerResource { + pub fn new(socket: SocketAddr, now: Instant) -> Self { + let mut server = dns_over_tcp::Server::new(now); + server.set_listen_addresses::<5>(BTreeSet::from([socket])); + + Self { server } + } + + pub fn handle_input(&mut self, packet: IpPacket) { + self.server.handle_inbound(packet); + } + + pub fn handle_timeout( + &mut self, + global_dns_records: &BTreeMap>, + + now: Instant, + ) { + self.server.handle_timeout(now); + while let Some(query) = self.server.poll_queries() { + let response = handle_dns_query(query.message.for_slice(), global_dns_records); + + self.server.send_message(query.socket, response).unwrap(); + } + } + + pub fn poll_outbound(&mut self) -> Option { + self.server.poll_outbound() + } +} + impl UdpDnsServerResource { pub fn handle_input(&mut self, packet: IpPacket) { 
self.inbound_packets.push_back(packet); diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index c0e29102f..4d5124104 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -345,6 +345,10 @@ impl ReferenceState { if connected_resources.is_empty() { connected_resources.insert(resource); } + // TCP has retries so we will also be connected to those for sure. + if query.transport == DnsTransport::Tcp { + connected_resources.insert(resource); + } } continue; diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 30cc96cda..587c7d906 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -3,7 +3,7 @@ use super::{ sim_net::{any_ip_stack, any_port, host, Host}, sim_relay::{map_explode, SimRelay}, strategies::latency, - transition::DnsQuery, + transition::{DnsQuery, DnsTransport}, IcmpIdentifier, IcmpSeq, QueryId, }; use crate::{ @@ -15,7 +15,7 @@ use crate::{proptest::*, ClientState}; use bimap::BiMap; use connlib_model::{ClientId, GatewayId, RelayId, ResourceId}; use domain::{ - base::{Message, Rtype, ToName}, + base::{iana::Opcode, Message, MessageBuilder, Question, Rtype, ToName}, rdata::AllRecordData, }; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; @@ -44,34 +44,45 @@ pub(crate) struct SimClient { pub(crate) dns_records: HashMap>, /// Bi-directional mapping between connlib's sentinel DNS IPs and the effective DNS servers. 
- pub(crate) dns_by_sentinel: BiMap, + dns_by_sentinel: BiMap, pub(crate) ipv4_routes: BTreeSet, pub(crate) ipv6_routes: BTreeSet, - pub(crate) sent_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>, - pub(crate) received_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>, + pub(crate) sent_udp_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>, + pub(crate) received_udp_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>, + + pub(crate) sent_tcp_dns_queries: HashSet<(SocketAddr, QueryId)>, + pub(crate) received_tcp_dns_responses: BTreeSet<(SocketAddr, QueryId)>, pub(crate) sent_icmp_requests: HashMap<(u16, u16), IpPacket>, pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket>, + pub(crate) tcp_dns_client: dns_over_tcp::Client, + enc_buffer: EncryptBuffer, } impl SimClient { - pub(crate) fn new(id: ClientId, sut: ClientState) -> Self { + pub(crate) fn new(id: ClientId, sut: ClientState, now: Instant) -> Self { + let mut tcp_dns_client = dns_over_tcp::Client::new(now, [0u8; 32]); + tcp_dns_client.set_source_interface(Ipv4Addr::LOCALHOST, Ipv6Addr::LOCALHOST); + Self { id, sut, dns_records: Default::default(), dns_by_sentinel: Default::default(), - sent_dns_queries: Default::default(), - received_dns_responses: Default::default(), + sent_udp_dns_queries: Default::default(), + received_udp_dns_responses: Default::default(), + sent_tcp_dns_queries: Default::default(), + received_tcp_dns_responses: Default::default(), sent_icmp_requests: Default::default(), received_icmp_replies: Default::default(), enc_buffer: Default::default(), ipv4_routes: Default::default(), ipv6_routes: Default::default(), + tcp_dns_client, } } @@ -80,36 +91,85 @@ impl SimClient { self.dns_by_sentinel.right_values().copied().collect() } + pub(crate) fn set_new_dns_servers(&mut self, mapping: BiMap) { + if self.dns_by_sentinel != mapping { + self.tcp_dns_client + .set_resolvers( + mapping + .left_values() + .map(|ip| SocketAddr::new(*ip, 53)) + .collect(), + ) + 
.unwrap(); + } + + self.dns_by_sentinel = mapping; + } + + pub(crate) fn dns_mapping(&self) -> &BiMap { + &self.dns_by_sentinel + } + pub(crate) fn send_dns_query_for( &mut self, domain: DomainName, r_type: Rtype, query_id: u16, - dns_server: SocketAddr, + upstream: SocketAddr, + dns_transport: DnsTransport, now: Instant, ) -> Option> { - let Some(dns_server) = self.dns_by_sentinel.get_by_right(&dns_server).copied() else { - tracing::error!(%dns_server, "Unknown DNS server"); + let Some(sentinel) = self.dns_by_sentinel.get_by_right(&upstream).copied() else { + tracing::error!(%upstream, "Unknown DNS server"); return None; }; - tracing::debug!(%dns_server, %domain, "Sending DNS query"); + tracing::debug!(%sentinel, %domain, "Sending DNS query"); let src = self .sut - .tunnel_ip_for(dns_server) + .tunnel_ip_for(sentinel) .expect("tunnel should be initialised"); - let packet = ip_packet::make::dns_query( - domain, - r_type, - SocketAddr::new(src, 9999), // An application would pick a random source port that is free. - SocketAddr::new(dns_server, 53), - query_id, - ) - .unwrap(); + // Create the DNS query message + let mut msg_builder = MessageBuilder::new_vec(); - self.encapsulate(packet, now) + msg_builder.header_mut().set_opcode(Opcode::QUERY); + msg_builder.header_mut().set_rd(true); + msg_builder.header_mut().set_id(query_id); + + // Create the query + let mut question_builder = msg_builder.question(); + question_builder + .push(Question::new_in(domain, r_type)) + .unwrap(); + + let message = question_builder.into_message(); + + match dns_transport { + DnsTransport::Udp => { + let packet = ip_packet::make::udp_packet( + src, + sentinel, + 9999, // An application would pick a free source port. 
+ 53, + message.as_octets().to_vec(), + ) + .unwrap(); + + self.sent_udp_dns_queries + .insert((upstream, query_id), packet.clone()); + self.encapsulate(packet, now) + } + DnsTransport::Tcp => { + self.tcp_dns_client + .send_query(SocketAddr::new(sentinel, 53), message) + .unwrap(); + self.sent_tcp_dns_queries.insert((upstream, query_id)); + + None + } + } } pub(crate) fn encapsulate( @@ -131,24 +191,6 @@ impl SimClient { } } - { - if let Some(udp) = packet.as_udp() { - if let Ok(message) = Message::from_slice(udp.payload()) { - debug_assert!( - !message.header().qr(), - "every DNS message sent from the client should be a DNS query" - ); - - // Map back to upstream socket so we can assert on it correctly. - let sentinel = SocketAddr::from((packet.destination(), udp.destination_port())); - let upstream = self.upstream_dns_by_sentinel(&sentinel).unwrap(); - - self.sent_dns_queries - .insert((upstream, message.header().id()), packet.clone()); - } - } - } - let Some(enc_packet) = self.sut.handle_tun_input(packet, now, &mut self.enc_buffer) else { self.sut.handle_timeout(now); // If we handled the packet internally, make sure to advance state. 
return None; @@ -191,6 +233,11 @@ impl SimClient { } } + if self.tcp_dns_client.accepts(&packet) { + self.tcp_dns_client.handle_inbound(packet); + return; + } + if let Some(udp) = packet.as_udp() { if udp.source_port() == 53 { let message = Message::from_slice(udp.payload()) @@ -203,36 +250,9 @@ impl SimClient { return; }; - self.received_dns_responses + self.received_udp_dns_responses .insert((upstream, message.header().id()), packet.clone()); - - for record in message.answer().unwrap() { - let record = record.unwrap(); - let domain = record.owner().to_name(); - - #[expect(clippy::wildcard_enum_match_arm)] - let ip = match record - .into_any_record::>() - .unwrap() - .data() - { - AllRecordData::A(a) => IpAddr::from(a.addr()), - AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()), - AllRecordData::Ptr(_) => { - continue; - } - unhandled => { - panic!("Unexpected record data: {unhandled:?}") - } - }; - - self.dns_records.entry(domain).or_default().push(ip); - } - - // Ensure all IPs are always sorted. - for ips in self.dns_records.values_mut() { - ips.sort() - } + self.handle_dns_response(message); return; } @@ -259,6 +279,36 @@ impl SimClient { Some(*socket) } + + pub(crate) fn handle_dns_response(&mut self, message: &Message<[u8]>) { + for record in message.answer().unwrap() { + let record = record.unwrap(); + let domain = record.owner().to_name(); + + #[expect(clippy::wildcard_enum_match_arm)] + let ip = match record + .into_any_record::>() + .unwrap() + .data() + { + AllRecordData::A(a) => IpAddr::from(a.addr()), + AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()), + AllRecordData::Ptr(_) => { + continue; + } + unhandled => { + panic!("Unexpected record data: {unhandled:?}") + } + }; + + self.dns_records.entry(domain).or_default().push(ip); + } + + // Ensure all IPs are always sorted. + for ips in self.dns_records.values_mut() { + ips.sort() + } + } } /// Reference state for a particular client. 
@@ -327,17 +377,20 @@ pub struct RefClient { #[derivative(Debug = "ignore")] pub(crate) expected_icmp_handshakes: BTreeMap>, - /// The expected DNS handshakes. + /// The expected UDP DNS handshakes. #[derivative(Debug = "ignore")] - pub(crate) expected_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, + pub(crate) expected_udp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, + /// The expected TCP DNS handshakes. + #[derivative(Debug = "ignore")] + pub(crate) expected_tcp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>, } impl RefClient { /// Initialize the [`ClientState`]. /// /// This simulates receiving the `init` message from the portal. - pub(crate) fn init(self) -> SimClient { - let mut client_state = ClientState::new(self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. + pub(crate) fn init(self, now: Instant) -> SimClient { + let mut client_state = ClientState::new(self.known_hosts, self.key.0, now); // Cheating a bit here by reusing the key as seed. client_state.update_interface_config(Interface { ipv4: self.tunnel_ip4, ipv6: self.tunnel_ip6, @@ -345,7 +398,7 @@ impl RefClient { }); client_state.update_system_resolvers(self.system_dns_resolvers.clone()); - SimClient::new(self.id, client_state) + SimClient::new(self.id, client_state, now) } pub(crate) fn disconnect_resource(&mut self, resource: &ResourceId) { @@ -624,8 +677,16 @@ impl RefClient { .or_default() .insert(query.r_type); - self.expected_dns_handshakes - .push_back((query.dns_server, query.query_id)); + match query.transport { + DnsTransport::Udp => { + self.expected_udp_dns_handshakes + .push_back((query.dns_server, query.query_id)); + } + DnsTransport::Tcp => { + self.expected_tcp_dns_handshakes + .push_back((query.dns_server, query.query_id)); + } + } } pub(crate) fn ipv4_cidr_resource_dsts(&self) -> Vec { @@ -930,7 +991,8 @@ fn ref_client( connected_dns_resources: Default::default(), connected_internet_resource: Default::default(), expected_icmp_handshakes: 
Default::default(), - expected_dns_handshakes: Default::default(), + expected_udp_dns_handshakes: Default::default(), + expected_tcp_dns_handshakes: Default::default(), disabled_resources: Default::default(), resources: Default::default(), ipv4_routes: Default::default(), diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 968d38947..ec4b89ba4 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -1,5 +1,5 @@ use super::{ - dns_server_resource::UdpDnsServerResource, + dns_server_resource::{TcpDnsServerResource, UdpDnsServerResource}, reference::{private_key, PrivateKey}, sim_net::{any_port, dual_ip_stack, host, Host}, sim_relay::{map_explode, SimRelay}, @@ -28,6 +28,7 @@ pub(crate) struct SimGateway { pub(crate) received_icmp_requests: BTreeMap, udp_dns_server_resources: HashMap, + tcp_dns_server_resources: HashMap, } impl SimGateway { @@ -38,6 +39,7 @@ impl SimGateway { received_icmp_requests: Default::default(), enc_buffer: Default::default(), udp_dns_server_resources: Default::default(), + tcp_dns_server_resources: Default::default(), } } @@ -70,8 +72,14 @@ impl SimGateway { std::iter::from_fn(|| s.poll_outbound()) }); + let tcp_server_packets = self.tcp_dns_server_resources.values_mut().flat_map(|s| { + s.handle_timeout(global_dns_records, now); + + std::iter::from_fn(|| s.poll_outbound()) + }); udp_server_packets + .chain(tcp_server_packets) .filter_map(|packet| { Some( self.sut @@ -83,12 +91,18 @@ impl SimGateway { .collect() } - pub(crate) fn deploy_new_dns_servers(&mut self, dns_servers: impl Iterator) { + pub(crate) fn deploy_new_dns_servers( + &mut self, + dns_servers: impl Iterator, + now: Instant, + ) { self.udp_dns_server_resources.clear(); for server in dns_servers { self.udp_dns_server_resources .insert(server, UdpDnsServerResource::default()); + self.tcp_dns_server_resources + .insert(server, TcpDnsServerResource::new(server, now)); } 
} @@ -117,6 +131,15 @@ impl SimGateway { } } + if let Some(tcp) = packet.as_tcp() { + let socket = SocketAddr::new(packet.destination(), tcp.destination_port()); + + if let Some(server) = self.tcp_dns_server_resources.get_mut(&socket) { + server.handle_input(packet); + return None; + } + } + tracing::error!(?packet, "Unhandled packet"); None } diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index 936c49b85..d86a70551 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -115,6 +115,9 @@ pub(crate) fn dns_servers() -> impl Strategy> { .prop_filter("must not be in IPv4 resources range", |ip| { !crate::client::IPV4_RESOURCES.contains(*ip) }) + .prop_filter("must be addressable IP", |ip| { + !ip.is_unspecified() && !ip.is_multicast() && !ip.is_broadcast() + }) .prop_map(|ip| SocketAddr::from((ip, 53))), 1..4, ); @@ -126,6 +129,9 @@ pub(crate) fn dns_servers() -> impl Strategy> { .prop_filter("must not be in IPv6 resources range", |ip| { !crate::client::IPV6_RESOURCES.contains(*ip) }) + .prop_filter("must be addressable IP", |ip| { + !ip.is_unspecified() && !ip.is_multicast() + }) .prop_map(|ip| SocketAddr::from((ip, 53))), 1..4, ); diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 8199f157d..caacd82f5 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -47,9 +47,10 @@ impl TunnelTest { // Initialize the system under test from our reference state. pub(crate) fn init_test(ref_state: &ReferenceState, flux_capacitor: FluxCapacitor) -> Self { // Construct client, gateway and relay from the initial state. 
- let mut client = ref_state - .client - .map(|ref_client, _, _| ref_client.init(), debug_span!("client")); + let mut client = ref_state.client.map( + |ref_client, _, _| ref_client.init(flux_capacitor.now()), + debug_span!("client"), + ); let mut gateways = ref_state .gateways @@ -203,10 +204,11 @@ impl TunnelTest { r_type, dns_server, query_id, + transport, } in queries { let transmit = state.client.exec_mut(|sim| { - sim.send_dns_query_for(domain, r_type, query_id, dns_server, now) + sim.send_dns_query_for(domain, r_type, query_id, dns_server, transport, now) }); buffered_transmits.push_from(transmit, &state.client, now); @@ -342,7 +344,8 @@ impl TunnelTest { sim_gateways, &ref_state.global_dns_records, ); - assert_dns_packets_properties(ref_client, sim_client); + assert_udp_dns_packets_properties(ref_client, sim_client); + assert_tcp_dns(ref_client, sim_client); assert_known_hosts_are_valid(ref_client, sim_client); assert_dns_servers_are_valid(ref_client, sim_client); assert_routes_are_valid(ref_client, sim_client); @@ -388,8 +391,10 @@ impl TunnelTest { let server = query.server; let transport = query.transport; - let response = - self.on_recursive_dns_query(query.clone(), &ref_state.global_dns_records); + let response = self.on_recursive_dns_query( + query.message.for_slice_ref(), + &ref_state.global_dns_records, + ); self.client.exec_mut(|c| { c.sut.handle_dns_response(dns::RecursiveResponse { server, @@ -486,6 +491,33 @@ impl TunnelTest { ) { let now = self.flux_capacitor.now(); + // Handle the TCP DNS client, i.e. simulate applications making TCP DNS queries. 
+ self.client.exec_mut(|c| { + c.tcp_dns_client.handle_timeout(now); + + while let Some(result) = c.tcp_dns_client.poll_query_result() { + match result.result { + Ok(message) => { + let upstream = c.dns_mapping().get_by_left(&result.server.ip()).unwrap(); + + c.received_tcp_dns_responses + .insert((*upstream, result.query.header().id())); + c.handle_dns_response(message.for_slice()) + } + Err(e) => { + tracing::error!("TCP DNS query failed: {e:#}"); + } + } + } + }); + while let Some(transmit) = self.client.exec_mut(|c| { + let packet = c.tcp_dns_client.poll_outbound()?; + c.encapsulate(packet, now) + }) { + buffered_transmits.push_from(transmit, &self.client, now) + } + + // Handle the client's `Transmit`s and timeout. while let Some(transmit) = self.client.poll_transmit(now) { self.client.exec_mut(|c| c.receive(transmit, now)) } @@ -495,6 +527,7 @@ impl TunnelTest { } }); + // Handle all gateway `Transmit`s and timeouts. for (_, gateway) in self.gateways.iter_mut() { for transmit in gateway.exec_mut(|g| g.advance_resources(global_dns_records, now)) { buffered_transmits.push_from(transmit, gateway, now); @@ -517,6 +550,7 @@ impl TunnelTest { }); } + // Handle all relay `Transmit`s and timeouts. 
for (_, relay) in self.relays.iter_mut() { while let Some(transmit) = relay.poll_transmit(now) { let Some(reply) = relay.exec_mut(|r| r.receive(transmit, now)) else { @@ -682,7 +716,7 @@ impl TunnelTest { tracing::warn!("Unimplemented"); } ClientEvent::TunInterfaceUpdated(config) => { - if self.client.inner().dns_by_sentinel == config.dns_by_sentinel + if self.client.inner().dns_mapping() == &config.dns_by_sentinel && self.client.inner().ipv4_routes == config.ipv4_routes && self.client.inner().ipv6_routes == config.ipv6_routes { @@ -691,16 +725,19 @@ impl TunnelTest { ); } - if self.client.inner().dns_by_sentinel != config.dns_by_sentinel { + if self.client.inner().dns_mapping() != &config.dns_by_sentinel { for gateway in self.gateways.values_mut() { gateway.exec_mut(|g| { - g.deploy_new_dns_servers(config.dns_by_sentinel.right_values().copied()) + g.deploy_new_dns_servers( + config.dns_by_sentinel.right_values().copied(), + now, + ) }) } } self.client.exec_mut(|c| { - c.dns_by_sentinel = config.dns_by_sentinel; + c.set_new_dns_servers(config.dns_by_sentinel); c.ipv4_routes = config.ipv4_routes; c.ipv6_routes = config.ipv6_routes; }); @@ -778,11 +815,9 @@ impl TunnelTest { fn on_recursive_dns_query( &self, - query: crate::dns::RecursiveQuery, + query: Message<&[u8]>, global_dns_records: &BTreeMap>, ) -> Message> { - let query = query.message; - let response = MessageBuilder::new_vec(); let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 242108d81..474e53d3a 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -98,6 +98,13 @@ pub(crate) struct DnsQuery { /// The DNS query ID. 
pub(crate) query_id: u16, pub(crate) dns_server: SocketAddr, + pub(crate) transport: DnsTransport, +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum DnsTransport { + Udp, + Tcp, } pub(crate) fn ping_random_ip( @@ -202,9 +209,17 @@ pub(crate) fn dns_queries( query_type(), Just(query_id), ptr_query_ip(), + dns_transport(), ) .prop_map( - |(mut domain, dns_server, r_type, query_id, maybe_reverse_record)| { + |( + mut domain, + dns_server, + r_type, + query_id, + maybe_reverse_record, + transport, + )| { if matches!(r_type, Rtype::PTR) { domain = DomainName::reverse_from_addr(maybe_reverse_record).unwrap(); @@ -215,6 +230,7 @@ pub(crate) fn dns_queries( r_type, query_id, dns_server, + transport, } }, ) @@ -231,6 +247,10 @@ fn ptr_query_ip() -> impl Strategy { ] } +fn dns_transport() -> impl Strategy { + prop_oneof![Just(DnsTransport::Udp), Just(DnsTransport::Tcp),] +} + pub(crate) fn query_type() -> impl Strategy { prop_oneof![ Just(Rtype::A), diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs index 9085e6f75..b009d6605 100644 --- a/rust/dns-over-tcp/src/client.rs +++ b/rust/dns-over-tcp/src/client.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BTreeSet, HashMap, HashSet, VecDeque}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::{Duration, Instant}, }; @@ -14,7 +14,7 @@ use ip_packet::IpPacket; use rand::{rngs::StdRng, Rng, SeedableRng}; use smoltcp::{ iface::{Interface, PollResult, SocketSet}, - socket::tcp::{self, Socket}, + socket::tcp, }; /// A sans-io DNS-over-TCP client. @@ -33,7 +33,7 @@ pub struct Client { source_ips: Option<(Ipv4Addr, Ipv6Addr)>, sockets: SocketSet<'static>, - sockets_by_remote: HashMap, + sockets_by_remote: BTreeMap, local_ports_by_socket: HashMap, /// Queries we should send to a DNS resolver. 
pending_queries_by_remote: HashMap>>>, @@ -95,45 +95,17 @@ impl Client { self.sockets = SocketSet::new(vec![]); self.sockets_by_remote.clear(); self.local_ports_by_socket.clear(); + self.abort_all_pending_and_sent_queries(); - self.query_results - .extend( - self.pending_queries_by_remote - .drain() - .flat_map(|(server, queries)| { - into_failed_results(server, queries, || anyhow!("Aborted")) - }), - ); - self.query_results - .extend( - self.sent_queries_by_remote - .drain() - .flat_map(|(server, queries)| { - into_failed_results(server, queries.into_values(), || anyhow!("Aborted")) - }), - ); + // Second, try to allocate a unique port per resolver. + let unique_ports = self.sample_unique_ports(resolvers.len())?; - // Second, try to create all new sockets. - let new_sockets = std::iter::zip(self.sample_unique_ports(resolvers.len())?, resolvers) - .map(|(port, server)| { - let local_endpoint = match server { - SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port), - SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port), - }; - let socket = create_tcp_socket(); - - Ok((server, local_endpoint, socket)) - }) - .collect::>>()?; - - // Third, if everything was successful, change the local state. - for (server, local_endpoint, socket) in new_sockets { - let handle = self.sockets.add(socket); - - self.sockets_by_remote.insert(server, handle); - self.local_ports_by_socket - .insert(handle, local_endpoint.port()); - } + // Third, initialise the sockets. 
+ self.init_sockets( + std::iter::zip(unique_ports, resolvers), + ipv4_source, + ipv6_source, + ); Ok(()) } @@ -231,7 +203,7 @@ impl Client { } for (remote, handle) in self.sockets_by_remote.iter_mut() { - let socket = self.sockets.get_mut::(*handle); + let socket = self.sockets.get_mut::(*handle); let server = *remote; let pending_queries = self.pending_queries_by_remote.entry(server).or_default(); @@ -292,6 +264,52 @@ impl Client { Some(self.last_now + Duration::from(poll_in)) } + fn abort_all_pending_and_sent_queries(&mut self) { + let aborted_pending_queries = + self.pending_queries_by_remote + .drain() + .flat_map(|(server, queries)| { + into_failed_results(server, queries, || anyhow!("Aborted")) + }); + let aborted_sent_queries = + self.sent_queries_by_remote + .drain() + .flat_map(|(server, queries)| { + into_failed_results(server, queries.into_values(), || anyhow!("Aborted")) + }); + + self.query_results + .extend(aborted_pending_queries.chain(aborted_sent_queries)); + } + + fn init_sockets( + &mut self, + ports_and_resolvers: impl IntoIterator, + ipv4_source: Ipv4Addr, + ipv6_source: Ipv6Addr, + ) { + let new_sockets = ports_and_resolvers + .into_iter() + .map(|(port, server)| { + let local_endpoint = match server { + SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port), + SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port), + }; + let socket = create_tcp_socket(); + + (server, local_endpoint, socket) + }) + .collect::>(); + + for (server, local_endpoint, socket) in new_sockets { + let handle = self.sockets.add(socket); + + self.sockets_by_remote.insert(server, handle); + self.local_ports_by_socket + .insert(handle, local_endpoint.port()); + } + } + fn sample_unique_ports(&mut self, num_ports: usize) -> Result> { let mut ports = HashSet::with_capacity(num_ports); let range = MIN_PORT..=MAX_PORT; @@ -312,7 +330,7 @@ impl Client { } fn send_pending_queries( - socket: &mut Socket, + socket: &mut tcp::Socket, server: SocketAddr, 
pending_queries: &mut VecDeque>>, sent_queries: &mut HashMap>>, @@ -348,7 +366,7 @@ fn send_pending_queries( } fn recv_responses( - socket: &mut Socket, + socket: &mut tcp::Socket, server: SocketAddr, pending_queries: &mut VecDeque>>, sent_queries: &mut HashMap>>, diff --git a/rust/dns-over-tcp/src/lib.rs b/rust/dns-over-tcp/src/lib.rs index 846a2f1fc..ba275c7d8 100644 --- a/rust/dns-over-tcp/src/lib.rs +++ b/rust/dns-over-tcp/src/lib.rs @@ -6,7 +6,7 @@ mod stub_device; mod time; pub use client::{Client, QueryResult}; -pub use server::{Server, SocketHandle}; +pub use server::{Query, Server, SocketHandle}; fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> { /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs index 1cd09e291..df94c9582 100644 --- a/rust/dns-over-tcp/src/server.rs +++ b/rust/dns-over-tcp/src/server.rs @@ -70,8 +70,23 @@ impl Server { /// The constant configures, how many concurrent clients you would like to be able to serve per listen address. pub fn set_listen_addresses( &mut self, - addresses: Vec, + addresses: BTreeSet, ) { + let current_listen_endpoints = self + .listen_endpoints + .values() + .copied() + .collect::>(); + + if current_listen_endpoints == addresses { + tracing::debug!( + ?current_listen_endpoints, + "Already listening on this exact set of addresses" + ); + + return; + } + assert!(NUM_CONCURRENT_CLIENTS > 0); let mut sockets = @@ -143,13 +158,6 @@ impl Server { Ok(()) } - /// Resets the socket associated with the given handle. - /// - /// Use this if you encountered an error while processing a previously emitted DNS query. - pub fn reset(&mut self, handle: SocketHandle) { - self.sockets.get_mut::(handle.0).abort(); - } - /// Inform the server that time advanced. /// /// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible. 
diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs index d8ccf968f..ed5b294ba 100644 --- a/rust/dns-over-tcp/tests/client_and_server.rs +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -25,7 +25,7 @@ fn smoke() { .unwrap(); let mut dns_server = dns_over_tcp::Server::new(Instant::now()); - dns_server.set_listen_addresses::<1>(vec![resolver_addr]); + dns_server.set_listen_addresses::<1>(BTreeSet::from([resolver_addr])); for id in 0..5 { dns_client diff --git a/rust/dns-over-tcp/tests/smoke_server.rs b/rust/dns-over-tcp/tests/smoke_server.rs index 5003653ad..95b1620a3 100644 --- a/rust/dns-over-tcp/tests/smoke_server.rs +++ b/rust/dns-over-tcp/tests/smoke_server.rs @@ -1,4 +1,5 @@ use std::{ + collections::BTreeSet, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, process::Stdio, task::{ready, Context, Poll}, @@ -36,7 +37,7 @@ async fn smoke() { let listen_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53)); let mut dns_server = dns_over_tcp::Server::new(Instant::now()); - dns_server.set_listen_addresses::(vec![listen_addr]); + dns_server.set_listen_addresses::(BTreeSet::from([listen_addr])); let mut eventloop = Eventloop::new(Box::new(tun), dns_server); tokio::spawn(std::future::poll_fn(move |cx| eventloop.poll(cx))); diff --git a/rust/ip-packet/Cargo.toml b/rust/ip-packet/Cargo.toml index 74748a8bf..79f691687 100644 --- a/rust/ip-packet/Cargo.toml +++ b/rust/ip-packet/Cargo.toml @@ -11,7 +11,6 @@ proptest = ["dep:proptest"] [dependencies] anyhow = "1.0.86" -domain = "0.10.1" etherparse = "0.15" proptest = { version = "1", optional = true } thiserror = "1" diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index e723b0315..b57791dc8 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -2,9 +2,8 @@ use crate::{IpPacket, IpPacketBuf}; use anyhow::{Context, Result}; -use domain::base::{iana::Opcode, MessageBuilder, Name, Question, 
Rtype}; use etherparse::PacketBuilder; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; /// Helper macro to turn a [`PacketBuilder`] into an [`IpPacket`]. #[macro_export] @@ -151,31 +150,6 @@ where } } -pub fn dns_query( - domain: Name>, - kind: Rtype, - src: SocketAddr, - dst: SocketAddr, - id: u16, -) -> Result { - // Create the DNS query message - let mut msg_builder = MessageBuilder::new_vec(); - - msg_builder.header_mut().set_opcode(Opcode::QUERY); - msg_builder.header_mut().set_rd(true); - msg_builder.header_mut().set_id(id); - - // Create the query - let mut question_builder = msg_builder.question(); - question_builder - .push(Question::new_in(domain, kind)) - .unwrap(); - - let payload = question_builder.finish(); - - udp_packet(src.ip(), dst.ip(), src.port(), dst.port(), payload) -} - #[derive(thiserror::Error, Debug)] #[error("IPs must be of the same version")] pub struct IpVersionMismatch; diff --git a/scripts/tests/tcp-dns.sh b/scripts/tests/tcp-dns.sh new file mode 100755 index 000000000..c4c8946e4 --- /dev/null +++ b/scripts/tests/tcp-dns.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +source "./scripts/tests/lib.sh" + +client sh -c "apk add bind-tools" # The compat tests run using the production image which doesn't have `dig`. + +echo "Resolving DNS resource over TCP" +client sh -c "dig +tcp dns.httpbin" + +echo "Resolving non-DNS resource over TCP" +client sh -c "dig +tcp example.com" + +echo "Testing TCP fallback" +client sh -c "dig 2048.size.dns.netmeister.org" diff --git a/website/src/components/Changelog/Android.tsx b/website/src/components/Changelog/Android.tsx index 7a5c03d20..6a28958ef 100644 --- a/website/src/components/Changelog/Android.tsx +++ b/website/src/components/Changelog/Android.tsx @@ -11,7 +11,9 @@ export default function Android() { title="Android" > {/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. 
*/} - + + Handles DNS queries over TCP correctly. + Ensures Firefox doesn't attempt to use DNS over HTTPS when Firezone is diff --git a/website/src/components/Changelog/Apple.tsx b/website/src/components/Changelog/Apple.tsx index e5da86015..afeb7b05e 100644 --- a/website/src/components/Changelog/Apple.tsx +++ b/website/src/components/Changelog/Apple.tsx @@ -11,7 +11,9 @@ export default function Apple() { title="macOS / iOS" > {/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */} - + + Handles DNS queries over TCP correctly. + Ensures Firefox doesn't attempt to use DNS over HTTPS when Firezone is diff --git a/website/src/components/Changelog/GUI.tsx b/website/src/components/Changelog/GUI.tsx index fd2ca1d2d..7c729c7c4 100644 --- a/website/src/components/Changelog/GUI.tsx +++ b/website/src/components/Changelog/GUI.tsx @@ -15,9 +15,7 @@ export default function GUI({ title }: { title: string }) { {/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */} - - This is a maintenance release with no user-facing changes. - + Handles DNS queries over TCP correctly. The IPC service `firezone-client-ipc.exe` is now signed. diff --git a/website/src/components/Changelog/Headless.tsx b/website/src/components/Changelog/Headless.tsx index 5493587a3..6951a3674 100644 --- a/website/src/components/Changelog/Headless.tsx +++ b/website/src/components/Changelog/Headless.tsx @@ -11,7 +11,9 @@ export default function Headless() { return ( {/* When you cut a release, remove any solved issues from the "known issues" lists over in `client-apps`. This must not be done when the issue's PR merges. */} - + + Handles DNS queries over TCP correctly. + Ensures Firefox doesn't attempt to use DNS over HTTPS when Firezone is