diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml index e3283310d..040c8a599 100644 --- a/.github/workflows/_rust.yml +++ b/.github/workflows/_rust.yml @@ -116,6 +116,7 @@ jobs: rg --count --no-ignore "Performed IP-NAT64" $TESTCASES_DIR rg --count --no-ignore "Too big DNS response, truncating" $TESTCASES_DIR rg --count --no-ignore "Destination is unreachable" $TESTCASES_DIR + rg --count --no-ignore "Forwarding query for DNS resource to corresponding site" $TESTCASES_DIR env: # diff --git a/rust/connlib/tunnel/proptest-regressions/tests.txt b/rust/connlib/tunnel/proptest-regressions/tests.txt index 225515f93..792751988 100644 --- a/rust/connlib/tunnel/proptest-regressions/tests.txt +++ b/rust/connlib/tunnel/proptest-regressions/tests.txt @@ -159,3 +159,4 @@ cc a7f22e7cc2c79ffd580baf4bc8296557c67afe245ccf07e895e7cd2a969a228e cc eca099d2fdef9adba841f523ce426089fda9bf7deb3bc43a86c4f09cf4b1199d cc 2d4a7f40ce445d9b159941ba5cf94b635db018c6229a88e22796091e4c94b059 cc 16a8e929be616a64b36204ff393a1cf376db5559d051627ef4eff1055f9604a5 +cc b5dc48d89cc4f0c61ed3b7c58338f8f9f06654a5948bad62869ea4bbecf270d8 diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 1bc4a890e..cc890911a 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -3,6 +3,7 @@ mod resource; pub(crate) use resource::{CidrResource, Resource}; #[cfg(all(feature = "proptest", test))] pub(crate) use resource::{DnsResource, InternetResource}; +use ringbuffer::{AllocRingBuffer, RingBuffer}; use crate::dns::StubResolver; use crate::expiring_map::ExpiringMap; @@ -179,7 +180,9 @@ impl DnsResourceNatState { struct PendingFlow { last_intent_sent_at: Instant, - packets: UniquePacketBuffer, + resource_packets: UniquePacketBuffer, + udp_dns_queries: AllocRingBuffer, + tcp_dns_queries: AllocRingBuffer, } impl PendingFlow { @@ -189,13 +192,25 @@ impl PendingFlow { /// Thus, we may receive a fair few packets before we can send them. const CAPACITY_POW_2: usize = 7; // 2^7 = 128 - fn new(now: Instant, packet: IpPacket) -> Self { - let mut packets = UniquePacketBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2); - packets.push(packet); - - Self { + fn new(now: Instant, trigger: ConnectionTrigger) -> Self { + let mut this = Self { last_intent_sent_at: now, - packets, + resource_packets: UniquePacketBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2), + udp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2), + tcp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2), + }; + this.push(trigger); + + this + } + + fn push(&mut self, trigger: ConnectionTrigger) { + match trigger { + ConnectionTrigger::PacketForResource(packet) => { + self.resource_packets.push(packet); + } + ConnectionTrigger::UdpDnsQueryForSite(packet) => self.udp_dns_queries.push(packet), + ConnectionTrigger::TcpDnsQueryForSite(query) => self.tcp_dns_queries.push(query), } } } @@ -524,7 +539,7 @@ impl ClientState { .inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}")) .ok()?; - let packet = maybe_mangle_dns_response_from_cidr_resource( + let packet = maybe_mangle_dns_response_from_upstream_dns_server( packet, &mut self.udp_dns_sockets_by_upstream_and_query_id, ); @@ -749,7 +764,10 @@ impl ClientState { self.peers.add_ip(&gateway_id, &gateway_tun.v4.into()); self.peers.add_ip(&gateway_id, &gateway_tun.v6.into()); - let buffered_packets = pending_flow.packets; + // Deal with buffered packets + + // 1. Buffered packets for resources + let buffered_resource_packets = pending_flow.resource_packets; match resource { Resource::Cidr(_) | Resource::Internet(_) => { @@ -760,7 +778,7 @@ impl ClientState { ); // For CIDR and Internet resources, we can directly queue the buffered packets. - for packet in buffered_packets { + for packet in buffered_resource_packets { encapsulate_and_buffer( packet, gateway_id, @@ -770,7 +788,40 @@ impl ClientState { ); } } - Resource::Dns(_) => self.update_dns_resource_nat(now, buffered_packets.into_iter()), + Resource::Dns(_) => { + self.update_dns_resource_nat(now, buffered_resource_packets.into_iter()) + } + } + + // 2. Buffered UDP DNS queries for the Gateway + for packet in pending_flow.udp_dns_queries { + let gateway = self.peers.get(&gateway_id).context("Unknown peer")?; // If this error happens we have a bug: We just inserted it above. + + let upstream = gateway.tun_dns_server_endpoint(packet.destination()); + let packet = + self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet); + + encapsulate_and_buffer( + packet, + gateway_id, + now, + &mut self.node, + &mut self.buffered_transmits, + ) + } + + // 3. Buffered TCP DNS queries for the Gateway + for query in pending_flow.tcp_dns_queries { + let server = match query.local { + SocketAddr::V4(_) => { + SocketAddr::new(gateway_tun.v4.into(), crate::gateway::TUN_DNS_PORT) + } + SocketAddr::V6(_) => { + SocketAddr::new(gateway_tun.v6.into(), crate::gateway::TUN_DNS_PORT) + } + }; + + self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query); } Ok(Ok(())) @@ -820,16 +871,23 @@ impl ClientState { } #[tracing::instrument(level = "debug", skip_all, fields(%resource))] - fn on_not_connected_resource(&mut self, resource: ResourceId, packet: IpPacket, now: Instant) { + fn on_not_connected_resource( + &mut self, + resource: ResourceId, + trigger: impl Into, + now: Instant, + ) { + let trigger = trigger.into(); + debug_assert!(self.resources_by_id.contains_key(&resource)); match self.pending_flows.entry(resource) { Entry::Vacant(v) => { - v.insert(PendingFlow::new(now, packet)); + v.insert(PendingFlow::new(now, trigger)); } Entry::Occupied(mut o) => { let pending_flow = o.get_mut(); - pending_flow.packets.push(packet); + pending_flow.push(trigger); let time_since_last_intent = now.duration_since(pending_flow.last_intent_sent_at); @@ -1102,7 +1160,7 @@ impl ClientState { fn handle_udp_dns_query( &mut self, upstream: SocketAddr, - mut packet: IpPacket, + packet: IpPacket, now: Instant, ) -> ControlFlow<(), IpPacket> { let Some(datagram) = packet.as_udp() else { @@ -1131,29 +1189,13 @@ impl ClientState { "Failed to queue UDP DNS response: {}" ); } - dns::ResolveStrategy::Recurse => { - let query_id = message.header().id(); - + dns::ResolveStrategy::RecurseLocal => { if self.should_forward_dns_query_to_gateway(upstream.ip()) { - tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query via tunnel"); - - self.udp_dns_sockets_by_upstream_and_query_id.insert( - (upstream, message.header().id()), - SocketAddr::new(packet.destination(), dns::DNS_PORT), - now + IDS_EXPIRE, - ); - packet.set_dst(upstream.ip()); - // TODO: Remove this once we disallow non-standard DNS ports: https://github.com/firezone/firezone/issues/8330 - packet - .as_udp_mut() - .expect("we parsed it as a UDP packet earlier") - .set_destination_port(upstream.port()); - - packet.update_checksum(); + let packet = self + .mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet); return ControlFlow::Continue(packet); } - let query_id = message.header().id(); tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host"); @@ -1161,13 +1203,70 @@ impl ClientState { self.buffered_dns_queries .push_back(dns::RecursiveQuery::via_udp(source, upstream, message)); } + dns::ResolveStrategy::RecurseSite(resource) => { + let Some(gateway) = + peer_by_resource_mut(&self.resources_gateways, &mut self.peers, resource) + else { + self.on_not_connected_resource( + resource, + ConnectionTrigger::UdpDnsQueryForSite(packet), + now, + ); + return ControlFlow::Break(()); + }; + + let upstream = gateway.tun_dns_server_endpoint(packet.destination()); + + let packet = + self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet); + + return ControlFlow::Continue(packet); + } } ControlFlow::Break(()) } + fn mangle_udp_dns_query_to_new_upstream_through_tunnel( + &mut self, + upstream: SocketAddr, + now: Instant, + mut packet: IpPacket, + ) -> IpPacket { + let dst_ip = packet.destination(); + let datagram = packet + .as_udp() + .expect("to be a valid UDP packet at this point"); + + let dst_port = datagram.destination_port(); + let query_id = parse_udp_dns_message(&datagram) + .expect("to be a valid DNS query at this point") + .header() + .id(); + + let connlib_dns_server = SocketAddr::new(dst_ip, dst_port); + + self.udp_dns_sockets_by_upstream_and_query_id.insert( + (upstream, query_id), + connlib_dns_server, + now + IDS_EXPIRE, + ); + packet.set_dst(upstream.ip()); + // TODO: Remove this once we disallow non-standard DNS ports: https://github.com/firezone/firezone/issues/8330 + packet + .as_udp_mut() + .expect("to be a valid UDP packet at this point") + .set_destination_port(upstream.port()); + + packet.update_checksum(); + + tracing::trace!(%upstream, %connlib_dns_server, %query_id, "Forwarding UDP DNS query via tunnel"); + + packet + } + fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query, now: Instant) { - let message = query.message; + let query_id = query.message.header().id(); let Some(upstream) = self.dns_mapping.get_by_left(&query.local.ip()) else { // This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request. @@ -1175,7 +1274,7 @@ impl ClientState { }; let server = upstream.address(); - match self.stub_resolver.handle(message.for_slice_ref()) { + match self.stub_resolver.handle(query.message.for_slice_ref()) { dns::ResolveStrategy::LocalResponse(response) => { self.clear_dns_resource_nat_for_domain(response.for_slice_ref()); self.update_dns_resource_nat(now, iter::empty()); @@ -1185,31 +1284,9 @@ impl ClientState { "Failed to send TCP DNS response: {}" ); } - dns::ResolveStrategy::Recurse => { - let query_id = message.header().id(); - + dns::ResolveStrategy::RecurseLocal => { if self.should_forward_dns_query_to_gateway(server.ip()) { - match self.tcp_dns_client.send_query(server, message.clone()) { - Ok(()) => {} - Err(e) => { - tracing::warn!("Failed to send recursive TCP DNS query: {e:#}"); - - unwrap_or_debug!( - self.tcp_dns_server.send_message( - query.socket, - dns::servfail(message.for_slice_ref()) - ), - "Failed to send TCP DNS response: {}" - ); - return; - } - }; - - let existing = self - .tcp_dns_sockets_by_upstream_and_query_id - .insert((server, query_id), query.socket); - - debug_assert!(existing.is_none(), "Query IDs should be unique"); + self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query); return; } @@ -1217,11 +1294,64 @@ impl ClientState { tracing::trace!(%server, %query_id, "Forwarding TCP DNS query"); self.buffered_dns_queries - .push_back(dns::RecursiveQuery::via_tcp(query.socket, server, message)); + .push_back(dns::RecursiveQuery::via_tcp( + query.socket, + server, + query.message, + )); + } + dns::ResolveStrategy::RecurseSite(resource) => { + let Some(gateway) = + peer_by_resource_mut(&self.resources_gateways, &mut self.peers, resource) + else { + self.on_not_connected_resource( + resource, + ConnectionTrigger::TcpDnsQueryForSite(query), + now, + ); + return; + }; + + let server = gateway.tun_dns_server_endpoint(query.local.ip()); + + self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query); } }; } + fn forward_tcp_dns_query_to_new_upstream_via_tunnel( + &mut self, + server: SocketAddr, + query: dns_over_tcp::Query, + ) { + let query_id = query.message.header().id(); + + match self + .tcp_dns_client + .send_query(server, query.message.clone()) + { + Ok(()) => {} + Err(e) => { + tracing::warn!( + "Failed to send recursive TCP DNS query to upstream resolver: {e:#}" + ); + + unwrap_or_debug!( + self.tcp_dns_server + .send_message(query.socket, dns::servfail(query.message.for_slice_ref())), + "Failed to send TCP DNS response: {}" + ); + return; + } + }; + + let existing = self + .tcp_dns_sockets_by_upstream_and_query_id + .insert((server, query_id), query.socket); + + debug_assert!(existing.is_none(), "Query IDs should be unique"); + } + fn maybe_update_tun_routes(&mut self) { let Some(config) = self.tun_config.clone() else { return; @@ -1790,7 +1920,7 @@ fn is_definitely_not_a_resource(ip: IpAddr) -> bool { false } -fn maybe_mangle_dns_response_from_cidr_resource( +fn maybe_mangle_dns_response_from_upstream_dns_server( mut packet: IpPacket, udp_dns_sockets_by_upstream_and_query_id: &mut ExpiringMap<(SocketAddr, u16), SocketAddr>, ) -> IpPacket { @@ -1854,6 +1984,24 @@ fn truncate_dns_response(mut message: Message>) -> Vec { message_bytes } +/// What triggered us to establish a connection to a Gateway. +enum ConnectionTrigger { + /// A packet received on the TUN device with a destination IP that maps to one of our resources. + PacketForResource(IpPacket), + /// A UDP DNS query that needs to be resolved within a particular site that we aren't connected to yet. + /// + /// This packet isn't mangled yet to point to the Gateway's TUN device IP because at the time of buffering, that IP is unknown. + UdpDnsQueryForSite(IpPacket), + /// A TCP DNS query that needs to be resolved within a particular site that we aren't connected to yet. + TcpDnsQueryForSite(dns_over_tcp::Query), +} + +impl From for ConnectionTrigger { + fn from(v: IpPacket) -> Self { + Self::PacketForResource(v) + } +} + pub struct IpProvider { ipv4: Box + Send + Sync>, ipv6: Box + Send + Sync>, diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index f614c5341..13d9b4a1c 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -102,8 +102,10 @@ pub(crate) enum Transport { pub(crate) enum ResolveStrategy { /// The query is for a Resource, we have an IP mapped already, and we can respond instantly LocalResponse(Message>), - /// The query is for a non-Resource, forward it to an upstream or system resolver. - Recurse, + /// The query is for a non-Resource, forward it locally to an upstream or system resolver. + RecurseLocal, + /// The query is for a DNS resource but for a type that we don't intercept (i.e. SRV, TXT, ...), forward it to the site that hosts the DNS resource and resolve it there. + RecurseSite(ResourceId), } impl Default for StubResolver { @@ -274,9 +276,14 @@ impl StubResolver { (Rtype::AAAA, Some(resource)) => { self.get_or_assign_aaaa_records(domain.clone(), resource) } + (Rtype::SRV | Rtype::TXT, Some(resource)) => { + tracing::debug!(%qtype, %resource, "Forwarding query for DNS resource to corresponding site"); + + return Ok(ResolveStrategy::RecurseSite(resource)); + } (Rtype::PTR, _) => { let Some(fqdn) = self.resource_address_name_by_reservse_dns(&domain) else { - return Ok(ResolveStrategy::Recurse); + return Ok(ResolveStrategy::RecurseLocal); }; vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))] @@ -288,7 +295,7 @@ impl StubResolver { let response = build_dns_with_answer(message, domain, Vec::default())?; return Ok(ResolveStrategy::LocalResponse(response)); } - _ => return Ok(ResolveStrategy::Recurse), + _ => return Ok(ResolveStrategy::RecurseLocal), }; tracing::trace!(%qtype, %domain, records = ?resource_records, "Forming DNS response"); diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index d01a09941..6291d06eb 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -15,6 +15,8 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::{Duration, Instant}; +pub const TUN_DNS_PORT: u16 = 53535; + const EXPIRE_RESOURCES_INTERVAL: Duration = Duration::from_secs(1); /// A SANS-IO implementation of a gateway's functionality. diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 8f2c51196..2a68136c2 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,6 +1,6 @@ use std::collections::{hash_map, BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::time::Instant; use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; @@ -37,6 +37,17 @@ impl GatewayOnClient { self.allowed_ips.insert(*ip, HashSet::from([*id])); } } + + /// For a given destination IP, return the endpoint to which the DNS query should be sent. + pub(crate) fn tun_dns_server_endpoint(&self, dst: IpAddr) -> SocketAddr { + let new_dst_ip = match dst { + IpAddr::V4(_) => self.gateway_tun.v4.into(), + IpAddr::V6(_) => self.gateway_tun.v6.into(), + }; + let new_dst_port = crate::gateway::TUN_DNS_PORT; + + SocketAddr::new(new_dst_ip, new_dst_port) + } } impl GatewayOnClient { diff --git a/rust/connlib/tunnel/src/tests/dns_records.rs b/rust/connlib/tunnel/src/tests/dns_records.rs index 56428cc9d..e468f51d3 100644 --- a/rust/connlib/tunnel/src/tests/dns_records.rs +++ b/rust/connlib/tunnel/src/tests/dns_records.rs @@ -56,6 +56,10 @@ impl DnsRecords { .dedup() .collect_vec() } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.is_empty() + } } impl From for DnsRecords diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index b21c282b3..c43ad011e 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -701,10 +701,23 @@ impl ReferenceState { // We surface what are the existing rtypes for a domain so that it's easier // for the proptests to hit an existing record. fn all_domains(&self) -> Vec<(DomainName, Vec)> { - self.global_dns_records - .domains_iter() - .map(|d| (d.clone(), self.global_dns_records.domain_rtypes(&d))) - .collect() + fn domains_and_rtypes( + records: &DnsRecords, + ) -> impl Iterator)> + use<'_> { + records + .domains_iter() + .map(|d| (d.clone(), records.domain_rtypes(&d))) + } + + // We may have multiple gateways in a site, so we need to dedup. + let unique_domains = self + .gateways + .values() + .flat_map(|g| domains_and_rtypes(g.inner().dns_records())) + .chain(domains_and_rtypes(&self.global_dns_records)) + .collect::>(); + + Vec::from_iter(unique_domains) } fn reachable_dns_servers(&self) -> Vec { diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index da221d5d1..e380803b2 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -358,6 +358,9 @@ impl SimClient { AllRecordData::Txt(_) => { continue; } + AllRecordData::Srv(_) => { + continue; + } unhandled => { panic!("Unexpected record data: {unhandled:?}") } @@ -787,6 +790,11 @@ impl RefClient { } } + if let Some(resource) = self.is_site_specific_dns_query(query) { + self.set_resource_online(resource); + return; + } + if let Some(resource) = self.dns_query_via_resource(query) { self.connect_to_internet_or_cidr_resource(resource); self.set_resource_online(resource); @@ -1018,6 +1026,14 @@ impl RefClient { maybe_active_cidr_resource.or(maybe_active_internet_resource) } + pub(crate) fn is_site_specific_dns_query(&self, query: &DnsQuery) -> Option { + if !matches!(query.r_type, Rtype::SRV | Rtype::TXT) { + return None; + } + + self.dns_resource_by_domain(&query.domain) + } + pub(crate) fn all_resource_ids(&self) -> Vec { self.resources.iter().map(|r| r.id()).collect() } diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 34aabd8a5..b926998e1 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -16,6 +16,7 @@ use proptest::prelude::*; use snownet::Transmit; use std::{ collections::{BTreeMap, HashMap}, + iter, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Instant, }; @@ -34,15 +35,21 @@ pub(crate) struct SimGateway { /// The received TCP packets, indexed by our custom TCP payload. pub(crate) received_tcp_requests: BTreeMap, + site_specific_dns_records: DnsRecords, udp_dns_server_resources: HashMap, tcp_dns_server_resources: HashMap, } impl SimGateway { - pub(crate) fn new(id: GatewayId, sut: GatewayState) -> Self { + pub(crate) fn new( + id: GatewayId, + sut: GatewayState, + site_specific_dns_records: DnsRecords, + ) -> Self { Self { id, sut, + site_specific_dns_records, received_icmp_requests: Default::default(), udp_dns_server_resources: Default::default(), tcp_dns_server_resources: Default::default(), @@ -77,16 +84,35 @@ impl SimGateway { global_dns_records: &DnsRecords, now: Instant, ) -> Vec> { - let udp_server_packets = self.udp_dns_server_resources.values_mut().flat_map(|s| { - s.handle_timeout(global_dns_records, now); + let Some(ip_config) = self.sut.tunnel_ip_config() else { + tracing::error!("Tunnel IP configuration not set"); + return Vec::new(); + }; - std::iter::from_fn(|| s.poll_outbound()) - }); - let tcp_server_packets = self.tcp_dns_server_resources.values_mut().flat_map(|s| { - s.handle_timeout(global_dns_records, now); + let udp_server_packets = + self.udp_dns_server_resources + .iter_mut() + .flat_map(|(socket, server)| { + if ip_config.is_ip(socket.ip()) { + server.handle_timeout(&self.site_specific_dns_records, now); + } else { + server.handle_timeout(global_dns_records, now); + } - std::iter::from_fn(|| s.poll_outbound()) - }); + std::iter::from_fn(|| server.poll_outbound()) + }); + let tcp_server_packets = + self.tcp_dns_server_resources + .iter_mut() + .flat_map(|(socket, server)| { + if ip_config.is_ip(socket.ip()) { + server.handle_timeout(&self.site_specific_dns_records, now); + } else { + server.handle_timeout(global_dns_records, now); + } + + std::iter::from_fn(|| server.poll_outbound()) + }); udp_server_packets .chain(tcp_server_packets) @@ -109,7 +135,22 @@ impl SimGateway { ) { self.udp_dns_server_resources.clear(); - for server in dns_servers { + let tun_dns_server_port = 53535; // Hardcoded here so we think about backwards-compatibility when changing it. + let Some(ip_config) = self.sut.tunnel_ip_config() else { + tracing::error!("Tunnel IP configuration not set"); + return; + }; + + for server in dns_servers + .chain(iter::once(SocketAddr::from(( + ip_config.v4, + tun_dns_server_port, + )))) + .chain(iter::once(SocketAddr::from(( + ip_config.v6, + tun_dns_server_port, + )))) + { self.udp_dns_server_resources .insert(server, UdpDnsServerResource::default()); self.tcp_dns_server_resources @@ -255,6 +296,8 @@ pub struct RefGateway { pub(crate) key: PrivateKey, pub(crate) tunnel_ip4: Ipv4Addr, pub(crate) tunnel_ip6: Ipv6Addr, + + site_specific_dns_records: DnsRecords, } impl RefGateway { @@ -262,24 +305,29 @@ impl RefGateway { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self, id: GatewayId, now: Instant) -> SimGateway { - let mut sut = GatewayState::new(self.key.0, now); + let mut sut = GatewayState::new(self.key.0, now); // Cheating a bit here by reusing the key as seed. sut.update_tun_device(IpConfig { v4: self.tunnel_ip4, v6: self.tunnel_ip6, }); - SimGateway::new(id, sut) // Cheating a bit here by reusing the key as seed. + SimGateway::new(id, sut, self.site_specific_dns_records) + } + + pub fn dns_records(&self) -> &DnsRecords { + &self.site_specific_dns_records } } pub(crate) fn ref_gateway_host( tunnel_ip4s: impl Strategy, tunnel_ip6s: impl Strategy, + site_specific_dns_records: impl Strategy, ) -> impl Strategy> { host( dual_ip_stack(), any_port(), - ref_gateway(tunnel_ip4s, tunnel_ip6s), + ref_gateway(tunnel_ip4s, tunnel_ip6s, site_specific_dns_records), latency(200), // We assume gateways have a somewhat decent Internet connection. ) } @@ -287,14 +335,22 @@ pub(crate) fn ref_gateway_host( fn ref_gateway( tunnel_ip4s: impl Strategy, tunnel_ip6s: impl Strategy, + site_specific_dns_records: impl Strategy, ) -> impl Strategy { - (private_key(), tunnel_ip4s, tunnel_ip6s).prop_map(move |(key, tunnel_ip4, tunnel_ip6)| { - RefGateway { - key, - tunnel_ip4, - tunnel_ip6, - } - }) + ( + private_key(), + tunnel_ip4s, + tunnel_ip6s, + site_specific_dns_records, + ) + .prop_map( + move |(key, tunnel_ip4, tunnel_ip6, site_specific_dns_records)| RefGateway { + key, + tunnel_ip4, + tunnel_ip6, + site_specific_dns_records, + }, + ) } fn icmp_error_reply(packet: &IpPacket, error: IcmpError) -> Result { diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index cacb3dd8e..cb4187688 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -37,6 +37,16 @@ fn dns_record() -> impl Strategy { ] } +pub(crate) fn site_specific_dns_record() -> impl Strategy { + prop_oneof![ + collection::vec(txt_record(), 6..=10) + .prop_map(|sections| { sections.into_iter().flatten().collect_vec() }) + .prop_map(|o| domain::rdata::Txt::from_octets(o).unwrap()) + .prop_map(DomainRecord::Txt), + srv_record() + ] +} + // A maximum length txt record section fn txt_record() -> impl Strategy> { "[a-z]{255}".prop_map(|s| { @@ -50,6 +60,18 @@ fn txt_record() -> impl Strategy> { }) } +fn srv_record() -> impl Strategy { + ( + any::(), + any::(), + any::(), + domain_name(2..4).prop_map(|d| d.parse().unwrap()), + ) + .prop_map(|(priority, weight, port, target)| { + DomainRecord::Srv(domain::rdata::Srv::new(priority, weight, port, target)) + }) +} + pub(crate) fn packet_source_v4(client: Ipv4Addr) -> impl Strategy { prop_oneof![ 10 => Just(client), diff --git a/rust/connlib/tunnel/src/tests/stub_portal.rs b/rust/connlib/tunnel/src/tests/stub_portal.rs index dc39ae9da..4bbcc8816 100644 --- a/rust/connlib/tunnel/src/tests/stub_portal.rs +++ b/rust/connlib/tunnel/src/tests/stub_portal.rs @@ -3,15 +3,19 @@ use super::{ sim_client::{ref_client_host, RefClient}, sim_gateway::{ref_gateway_host, RefGateway}, sim_net::Host, - strategies::{resolved_ips, subdomain_records}, + strategies::{resolved_ips, site_specific_dns_record, subdomain_records}, }; -use crate::messages::{gateway, DnsServer}; use crate::{client, proptest::*}; +use crate::{ + client::DnsResource, + messages::{gateway, DnsServer}, +}; use connlib_model::GatewayId; use connlib_model::{ResourceId, SiteId}; use itertools::Itertools; use proptest::{ - sample::Selector, + collection, + sample::{self, Selector}, strategy::{Just, Strategy}, }; use std::{ @@ -223,15 +227,22 @@ impl StubPortal { } pub(crate) fn gateways(&self) -> impl Strategy>> { + let dns_resources = self.dns_resources.clone(); + self.gateways_by_site - .values() - .flatten() - .map(|(gid, ipv4_addr, ipv6_addr)| { - ( - Just(*gid), - ref_gateway_host(Just(*ipv4_addr), Just(*ipv6_addr)), - ) - }) // Map each ID to a strategy that samples a gateway. + .iter() + .flat_map(|(site_id, gateways)| { + gateways.iter().map(|(gid, ipv4_addr, ipv6_addr)| { + ( + Just(*gid), + ref_gateway_host( + Just(*ipv4_addr), + Just(*ipv6_addr), + site_specific_dns_records(dns_resources.clone(), *site_id), + ), + ) + }) + }) .collect::>() // A `Vec` implements `Strategy>` .prop_map(BTreeMap::from_iter) } @@ -250,39 +261,68 @@ impl StubPortal { } pub(crate) fn dns_resource_records(&self) -> impl Strategy { - self.dns_resources - .values() - .map(|resource| { - let address = resource.address.clone(); - - // Only generate simple wildcard domains for these tests. - // The matching logic is extensively unit-tested so we don't need to cover all cases here. - // What we do want to cover is multiple domains pointing to the same resource. - // For example, `*.example.com` and `app.example.com`. - match address.split_once('.') { - Some(("*" | "**", base)) => { - subdomain_records(base.to_owned(), domain_label()).boxed() - } - _ => resolved_ips() - .prop_map(move |resolved_ips| { - DnsRecords::from([(address.parse().unwrap(), resolved_ips)]) - }) - .boxed(), - } - }) - .collect::>() - .prop_map(|records| { - let mut map = DnsRecords::default(); - - for record in records { - map.merge(record) - } - - map - }) + dns_resource_records(self.dns_resources.clone().into_values()) } } +/// Generates site-specific DNS records for a particular site. +fn site_specific_dns_records( + dns_resources: BTreeMap, + site: SiteId, +) -> impl Strategy { + let dns_resources_in_site = dns_resources + .into_values() + .filter(move |resource| resource.sites.iter().any(|s| s.id == site)); + + dns_resource_records(dns_resources_in_site).prop_flat_map(|records| { + if records.is_empty() { + Just(DnsRecords::default()).boxed() + } else { + collection::btree_map( + sample::select(records.domains_iter().collect::>()), + collection::btree_set(site_specific_dns_record(), 1..6), + 0..5, + ) + .prop_map_into() + .boxed() + } + }) +} + +fn dns_resource_records( + dns_resources: impl Iterator, +) -> impl Strategy { + dns_resources + .map(|resource| { + let address = resource.address; + + // Only generate simple wildcard domains for these tests. + // The matching logic is extensively unit-tested so we don't need to cover all cases here. + // What we do want to cover is multiple domains pointing to the same resource. + // For example, `*.example.com` and `app.example.com`. + match address.split_once('.') { + Some(("*" | "**", base)) => { + subdomain_records(base.to_owned(), domain_label()).boxed() + } + _ => resolved_ips() + .prop_map(move |resolved_ips| { + DnsRecords::from([(address.parse().unwrap(), resolved_ips)]) + }) + .boxed(), + } + }) + .collect::>() + .prop_map(|records| { + let mut map = DnsRecords::default(); + + for record in records { + map.merge(record) + } + + map + }) +} + /// An [`Iterator`] over the possible IPv4 addresses of a tunnel interface. /// /// We use the CG-NAT range for IPv4.