diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 997461deb..1c2e760e9 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2688,6 +2688,7 @@ dependencies = [ "ip_network_table", "itertools 0.14.0", "l3-tcp", + "l3-udp-dns-client", "l4-tcp-dns-server", "l4-udp-dns-server", "lru", diff --git a/rust/connlib/l3-udp-dns-client/lib.rs b/rust/connlib/l3-udp-dns-client/lib.rs index 0c41846db..21beb97f7 100644 --- a/rust/connlib/l3-udp-dns-client/lib.rs +++ b/rust/connlib/l3-udp-dns-client/lib.rs @@ -8,7 +8,7 @@ use anyhow::{Context as _, Result, anyhow, bail}; use ip_packet::IpPacket; use rand::{Rng, SeedableRng, rngs::StdRng}; -const TIMEOUT: Duration = Duration::from_secs(5); +const TIMEOUT: Duration = Duration::from_secs(30); /// A sans-io DNS-over-UDP client. pub struct Client { diff --git a/rust/connlib/l4-udp-dns-server/lib.rs b/rust/connlib/l4-udp-dns-server/lib.rs index 87c34a6bf..581fe46a2 100644 --- a/rust/connlib/l4-udp-dns-server/lib.rs +++ b/rust/connlib/l4-udp-dns-server/lib.rs @@ -89,7 +89,7 @@ impl Server { } if let Poll::Ready(Some(result)) = self.reading_udp_queries.poll_next_unpin(cx) { - let (from, message) = result + let (remote, message) = result .context("Failed to read UDP DNS query") .map_err(anyhow_to_io)?; @@ -102,7 +102,7 @@ impl Server { return Poll::Ready(Ok(Query { local, - from, + remote, message, })); } @@ -144,7 +144,7 @@ async fn read_udp_query(socket: Arc) -> Result<(SocketAddr, dns_types pub struct Query { pub local: SocketAddr, - pub from: SocketAddr, + pub remote: SocketAddr, pub message: dns_types::Query, } @@ -192,7 +192,7 @@ mod tests { let query = poll_fn(|cx| server.poll(cx)).await.unwrap(); server - .send_response(query.from, dns_types::Response::no_error(&query.message)) + .send_response(query.remote, dns_types::Response::no_error(&query.message)) .unwrap(); } }); @@ -224,7 +224,7 @@ mod tests { let query = poll_fn(|cx| server.poll(cx)).await.unwrap(); server - .send_response(query.from, dns_types::Response::no_error(&query.message)) + .send_response(query.remote, dns_types::Response::no_error(&query.message)) .unwrap(); } }); diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index d86708704..6f368956e 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -17,7 +17,7 @@ bufferpool = { workspace = true } bytes = { workspace = true, features = ["std"] } chrono = { workspace = true } connlib-model = { workspace = true } -derive_more = { workspace = true, features = ["debug", "from"] } +derive_more = { workspace = true, features = ["debug", "from", "display"] } divan = { workspace = true, optional = true } dns-over-tcp = { workspace = true } dns-types = { workspace = true } @@ -33,6 +33,7 @@ ip-packet = { workspace = true } ip_network = { workspace = true } ip_network_table = { workspace = true } itertools = { workspace = true, features = ["use_std"] } +l3-udp-dns-client = { workspace = true } l4-tcp-dns-server = { workspace = true } l4-udp-dns-server = { workspace = true } lru = { workspace = true } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index bdd580cb5..edd3f71a3 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -21,7 +21,6 @@ use secrecy::ExposeSecret as _; use crate::client::dns_cache::DnsCache; use crate::dns::{DnsResourceRecord, StubResolver}; -use crate::expiring_map::{self, ExpiringMap}; use crate::messages::Interface as InterfaceConfig; use crate::messages::{IceCredentials, SecretKey}; use crate::peer_store::PeerStore; @@ -80,10 +79,6 @@ pub(crate) const DNS_SENTINELS_V6: Ipv6Network = match Ipv6Network::new( Err(_) => unreachable!(), }; -// The max time a dns request can be configured to live in resolvconf -// is 30 seconds. See resolvconf(5) timeout. -const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60); - /// How many gateways we at most remember that we connected to. /// /// 100 has been chosen as a pretty arbitrary value. @@ -122,8 +117,6 @@ pub struct ClientState { /// Manages the DNS configuration. dns_config: DnsConfig, - /// UDP DNS queries that had their destination IP mangled to redirect them to another DNS resolver through the tunnel. - udp_dns_sockets_by_upstream_and_query_id: ExpiringMap<(SocketAddr, u16), SocketAddr>, /// Manages internal dns records and emits forwarding event when not internally handled stub_resolver: StubResolver, /// Caches responses from DNS servers. @@ -132,10 +125,12 @@ pub struct ClientState { /// Configuration of the TUN device, when it is up. tun_config: Option, + udp_dns_client: l3_udp_dns_client::Client, tcp_dns_client: dns_over_tcp::Client, tcp_dns_server: dns_over_tcp::Server, - /// Tracks the TCP stream (i.e. socket-pair) on which we received a TCP DNS query by the ID of the recursive DNS query we issued. - tcp_dns_streams_by_upstream_and_query_id: HashMap<(SocketAddr, u16), (SocketAddr, SocketAddr)>, + /// Tracks the UDP/TCP stream (i.e. socket-pair) on which we received a DNS query by the ID of the recursive DNS query we issued. + dns_streams_by_upstream_and_query_id: + HashMap<(dns::Transport, SocketAddr, u16), (SocketAddr, SocketAddr)>, /// Stores the gateways we recently connected to. /// @@ -152,8 +147,7 @@ pub struct ClientState { struct PendingFlow { last_intent_sent_at: Instant, resource_packets: UniquePacketBuffer, - udp_dns_queries: AllocRingBuffer, - tcp_dns_queries: AllocRingBuffer, + dns_queries: AllocRingBuffer, } impl PendingFlow { @@ -170,8 +164,7 @@ impl PendingFlow { Self::CAPACITY_POW_2, "pending-flow-resources", ), - udp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2), - tcp_dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2), + dns_queries: AllocRingBuffer::with_capacity_power_of_2(Self::CAPACITY_POW_2), }; this.push(trigger); @@ -181,11 +174,8 @@ impl PendingFlow { fn push(&mut self, trigger: ConnectionTrigger) { match trigger { ConnectionTrigger::PacketForResource(packet) => self.resource_packets.push(packet), - ConnectionTrigger::UdpDnsQueryForSite(packet) => { - self.udp_dns_queries.enqueue(packet); - } - ConnectionTrigger::TcpDnsQueryForSite(query) => { - self.tcp_dns_queries.enqueue(query); + ConnectionTrigger::DnsQueryForSite(query) => { + self.dns_queries.enqueue(query); } ConnectionTrigger::IcmpDestinationUnreachableProhibited => {} } @@ -211,16 +201,16 @@ impl ClientState { node: ClientNode::new(seed, now), sites_status: Default::default(), gateways_site: Default::default(), - udp_dns_sockets_by_upstream_and_query_id: Default::default(), stub_resolver: StubResolver::new(records), dns_cache: Default::default(), buffered_transmits: Default::default(), is_internet_resource_active, recently_connected_gateways: LruCache::new(MAX_REMEMBERED_GATEWAYS), buffered_dns_queries: Default::default(), + udp_dns_client: l3_udp_dns_client::Client::new(now, seed), tcp_dns_client: dns_over_tcp::Client::new(now, seed), tcp_dns_server: dns_over_tcp::Server::new(now), - tcp_dns_streams_by_upstream_and_query_id: Default::default(), + dns_streams_by_upstream_and_query_id: Default::default(), pending_flows: Default::default(), dns_resource_nat: Default::default(), pending_tun_update: Default::default(), @@ -442,6 +432,11 @@ impl ClientState { .inspect_err(|e| tracing::debug!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate: {e:#}")) .ok()??; + if self.udp_dns_client.accepts(&packet) { + self.udp_dns_client.handle_inbound(packet); + return None; + } + if self.tcp_dns_client.accepts(&packet) { self.tcp_dns_client.handle_inbound(packet); return None; @@ -488,13 +483,6 @@ impl ClientState { .inspect_err(|e| tracing::debug!(%gid, %local, %from, "{e}")) .ok()?; - let packet = maybe_mangle_dns_response_from_upstream_dns_server( - packet, - &mut self.udp_dns_sockets_by_upstream_and_query_id, - &mut self.dns_cache, - now, - ); - if feature_flags::icmp_error_unreachable_prohibited_create_new_flow() && let Ok(Some((failed_packet, error))) = packet.icmp_error() && error.is_unreachable_prohibited() @@ -517,7 +505,7 @@ impl ClientState { let server = response.server; let domain = response.query.domain(); - let _span = tracing::debug_span!("handle_dns_response", %qid, %server, %domain).entered(); + let _span = tracing::debug_span!("handle_dns_response", %qid, %server, local = %response.local, %domain).entered(); match (response.transport, response.message) { (dns::Transport::Udp, Err(e)) if e.kind() == io::ErrorKind::TimedOut => { @@ -753,7 +741,7 @@ impl ClientState { // If we are making this connection because we want to send a DNS query to the Gateway, // mark it as "used" through the DNS resource ID. - if !pending_flow.udp_dns_queries.is_empty() || !pending_flow.tcp_dns_queries.is_empty() { + if !pending_flow.dns_queries.is_empty() { self.peers.add_ips_with_resource( &gid, [ @@ -765,34 +753,19 @@ impl ClientState { } // 2. Buffered UDP DNS queries for the Gateway - for packet in pending_flow.udp_dns_queries { + for query in pending_flow.dns_queries { let gateway = self.peers.get(&gid).context("Unknown peer")?; // If this error happens we have a bug: We just inserted it above. - let upstream = gateway.tun_dns_server_endpoint(packet.destination()); - let packet = - self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet); + let upstream = gateway.tun_dns_server_endpoint(query.local.ip()); - encapsulate_and_buffer( - packet, - gid, + self.forward_dns_query_to_new_upstream_via_tunnel( + query.local, + query.remote, + upstream, + query.message, + query.transport, now, - &mut self.node, - &mut self.buffered_transmits, - ) - } - - // 3. Buffered TCP DNS queries for the Gateway - for query in pending_flow.tcp_dns_queries { - let server = match query.local { - SocketAddr::V4(_) => { - SocketAddr::new(gateway_tun.v4.into(), crate::gateway::TUN_DNS_PORT) - } - SocketAddr::V6(_) => { - SocketAddr::new(gateway_tun.v6.into(), crate::gateway::TUN_DNS_PORT) - } - }; - - self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query); + ); } Ok(Ok(())) @@ -832,7 +805,9 @@ impl ClientState { return ControlFlow::Break(()); } - self.handle_udp_dns_query(upstream, packet, now) + self.handle_udp_dns_query(upstream, packet, now); + + ControlFlow::Break(()) } pub fn on_connection_failed(&mut self, resource: ResourceId) { @@ -903,6 +878,8 @@ impl ClientState { return; }; + self.udp_dns_client + .set_source_interface(tun_config.ip.v4, tun_config.ip.v6); self.tcp_dns_client .set_source_interface(tun_config.ip.v4, tun_config.ip.v6); self.tcp_dns_client.reset(); @@ -1104,9 +1081,9 @@ impl ClientState { pub fn poll_timeout(&mut self) -> Option<(Instant, &'static str)> { iter::empty() .chain( - self.udp_dns_sockets_by_upstream_and_query_id + self.udp_dns_client .poll_timeout() - .map(|instant| (instant, "DNS socket timeout")), + .map(|instant| (instant, "UDP DNS client")), ) .chain( self.dns_cache @@ -1131,43 +1108,57 @@ impl ClientState { self.node.handle_timeout(now); self.drain_node_events(); - self.udp_dns_sockets_by_upstream_and_query_id - .handle_timeout(now); - - while let Some(event) = self.udp_dns_sockets_by_upstream_and_query_id.poll_event() { - let expiring_map::Event::EntryExpired { key, value } = event; - - tracing::debug!( - ?key, - ?value, - "Mapping entry for forwarded DNS query expired" - ); - } - - self.advance_dns_tcp_sockets(now); + self.advance_dns_clients_and_servers(now); self.send_dns_resource_nat_packets(now); self.dns_cache.handle_timeout(now); } - /// Advance the TCP DNS server and client state machines. + /// Advance the DNS server and client state machines. /// - /// Receiving something on a TCP server socket may trigger packets to be sent on the TCP client socket and vice versa. + /// Receiving something on a UDP/TCP server socket may trigger packets to be sent on the UDP/TCP client socket and vice versa. /// Therefore, we loop here until non of the `poll-X` functions return anything anymore. - fn advance_dns_tcp_sockets(&mut self, now: Instant) { + fn advance_dns_clients_and_servers(&mut self, now: Instant) { loop { self.tcp_dns_server.handle_timeout(now); self.tcp_dns_client.handle_timeout(now); + self.udp_dns_client.handle_timeout(now); // Check if have any pending TCP DNS queries. if let Some(query) = self.tcp_dns_server.poll_queries() { - self.handle_tcp_dns_query(query, now); + let Some(upstream) = self + .dns_config + .mapping() + .upstream_by_sentinel(query.local.ip()) + else { + // This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request. + continue; + }; + + if let Some(response) = self.handle_dns_query( + query.message, + query.local, + query.remote, + upstream, + dns::Transport::Tcp, + now, + ) { + unwrap_or_debug!( + self.tcp_dns_server + .send_message(query.local, query.remote, response), + "Failed to send TCP DNS response: {}" + ); + } continue; } - // Check if the client wants to emit any packets. - if let Some(packet) = self.tcp_dns_client.poll_outbound() { - // All packets from the TCP DNS client _should_ go through the tunnel. + // Check if the clients wants to emit any packets. + if let Some(packet) = self + .tcp_dns_client + .poll_outbound() + .or_else(|| self.udp_dns_client.poll_outbound()) + { + // All packets from the DNS clients _should_ go through the tunnel. let Some(transmit) = self.encapsulate(packet, now) else { continue; }; @@ -1176,13 +1167,45 @@ impl ClientState { continue; } - // Check if the client has assembled a response to a query. + // Check if the UDP DNS client has assembled a response to a query. + if let Some(query_result) = self.udp_dns_client.poll_query_result() { + let server = query_result.server; + let qid = query_result.query.id(); + let known_sockets = &mut self.dns_streams_by_upstream_and_query_id; + + let Some((local, remote)) = + known_sockets.remove(&(dns::Transport::Udp, server, qid)) + else { + tracing::warn!(?known_sockets, %server, %qid, "Failed to find UDP socket handle for query result"); + + continue; + }; + + self.handle_dns_response( + dns::RecursiveResponse { + server, + local, + remote, + query: query_result.query, + message: query_result + .result + .map_err(|e| io::Error::other(format!("{e:#}"))), + transport: dns::Transport::Udp, + }, + now, + ); + continue; + } + + // Check if the TCP DNS client has assembled a response to a query. if let Some(query_result) = self.tcp_dns_client.poll_query_result() { let server = query_result.server; let qid = query_result.query.id(); - let known_sockets = &mut self.tcp_dns_streams_by_upstream_and_query_id; + let known_sockets = &mut self.dns_streams_by_upstream_and_query_id; - let Some((local, remote)) = known_sockets.remove(&(server, qid)) else { + let Some((local, remote)) = + known_sockets.remove(&(dns::Transport::Tcp, server, qid)) + else { tracing::warn!(?known_sockets, %server, %qid, "Failed to find TCP socket handle for query result"); continue; @@ -1222,16 +1245,11 @@ impl ClientState { } } - fn handle_udp_dns_query( - &mut self, - upstream: SocketAddr, - packet: IpPacket, - now: Instant, - ) -> ControlFlow<(), IpPacket> { + fn handle_udp_dns_query(&mut self, upstream: SocketAddr, packet: IpPacket, now: Instant) { let Some(datagram) = packet.as_udp() else { tracing::debug!(?packet, "Not a UDP packet"); - return ControlFlow::Break(()); + return; }; if datagram.destination_port() != DNS_PORT { @@ -1239,81 +1257,28 @@ impl ClientState { ?packet, "UDP DNS queries are only supported on port {DNS_PORT}" ); - return ControlFlow::Break(()); + return; } let message = match dns_types::Query::parse(datagram.payload()) { Ok(message) => message, Err(e) => { tracing::warn!(?packet, "Failed to parse DNS query: {e:#}"); - return ControlFlow::Break(()); + return; } }; - let destination = SocketAddr::new(packet.destination(), datagram.destination_port()); - let source = SocketAddr::new(packet.source(), datagram.source_port()); + let local = SocketAddr::new(packet.destination(), datagram.destination_port()); + let remote = SocketAddr::new(packet.source(), datagram.source_port()); - if let Some(response) = self.dns_cache.try_answer(&message, now) { - unwrap_or_debug!( - self.try_queue_udp_dns_response(destination, source, response), + if let Some(response) = + self.handle_dns_query(message, local, remote, upstream, dns::Transport::Udp, now) + { + unwrap_or_warn!( + self.try_queue_udp_dns_response(local, remote, response), "Failed to queue UDP DNS response: {}" ); - - return ControlFlow::Break(()); - } - - match self.stub_resolver.handle(&message) { - dns::ResolveStrategy::LocalResponse(response) => { - self.dns_resource_nat.recreate(message.domain()); - self.update_dns_resource_nat(now, iter::empty()); - self.dns_cache.insert(message.domain(), &response, now); - - unwrap_or_debug!( - self.try_queue_udp_dns_response(destination, source, response), - "Failed to queue UDP DNS response: {}" - ); - } - dns::ResolveStrategy::RecurseLocal => { - if self.should_forward_dns_query_to_gateway(upstream.ip()) { - let packet = self - .mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet); - - return ControlFlow::Continue(packet); - } - let query_id = message.id(); - - tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host"); - - self.buffered_dns_queries - .push_back(dns::RecursiveQuery::via_udp( - destination, - source, - upstream, - message, - )); - } - dns::ResolveStrategy::RecurseSite(resource) => { - let Some(gateway) = - peer_by_resource_mut(&self.resources_gateways, &mut self.peers, resource) - else { - self.on_not_connected_resource( - resource, - ConnectionTrigger::UdpDnsQueryForSite(packet), - now, - ); - return ControlFlow::Break(()); - }; - - let upstream = gateway.tun_dns_server_endpoint(packet.destination()); - - let packet = - self.mangle_udp_dns_query_to_new_upstream_through_tunnel(upstream, now, packet); - - return ControlFlow::Continue(packet); - } - } - - ControlFlow::Break(()) + }; } fn handle_llmnr_dns_query(&mut self, packet: IpPacket, now: Instant) { @@ -1373,97 +1338,47 @@ impl ClientState { } } - fn mangle_udp_dns_query_to_new_upstream_through_tunnel( + fn handle_dns_query( &mut self, + message: dns_types::Query, + local: SocketAddr, + remote: SocketAddr, upstream: SocketAddr, + transport: dns::Transport, now: Instant, - mut packet: IpPacket, - ) -> IpPacket { - let dst_ip = packet.destination(); - let datagram = packet - .as_udp() - .expect("to be a valid UDP packet at this point"); + ) -> Option { + let query_id = message.id(); - let dst_port = datagram.destination_port(); - let query_id = dns_types::Query::parse(datagram.payload()) - .expect("to be a valid DNS query at this point") - .id(); - - let connlib_dns_server = SocketAddr::new(dst_ip, dst_port); - - self.udp_dns_sockets_by_upstream_and_query_id.insert( - (upstream, query_id), - connlib_dns_server, - now, - IDS_EXPIRE, - ); - if let Err(e) = packet.set_dst(upstream.ip()) { - tracing::warn!("Failed to set destination IP for UDP DNS query: {e:#}"); - } - // TODO: Remove this once we disallow non-standard DNS ports: https://github.com/firezone/firezone/issues/8330 - packet - .as_udp_mut() - .expect("to be a valid UDP packet at this point") - .set_destination_port(upstream.port()); - - packet.update_checksum(); - - tracing::trace!(%upstream, %connlib_dns_server, %query_id, "Forwarding UDP DNS query via tunnel"); - - packet - } - - fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query, now: Instant) { - let query_id = query.message.id(); - - let Some(server) = self - .dns_config - .mapping() - .upstream_by_sentinel(query.local.ip()) - else { - // This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request. - return; - }; - - if let Some(response) = self.dns_cache.try_answer(&query.message, now) { - unwrap_or_debug!( - self.tcp_dns_server - .send_message(query.local, query.remote, response), - "Failed to send TCP DNS response: {}" - ); - - return; + if let Some(response) = self.dns_cache.try_answer(&message, now) { + return Some(response); } - match self.stub_resolver.handle(&query.message) { + match self.stub_resolver.handle(&message) { dns::ResolveStrategy::LocalResponse(response) => { - self.dns_resource_nat.recreate(query.message.domain()); + self.dns_resource_nat.recreate(message.domain()); self.update_dns_resource_nat(now, iter::empty()); - self.dns_cache - .insert(query.message.domain(), &response, now); + self.dns_cache.insert(message.domain(), &response, now); - unwrap_or_debug!( - self.tcp_dns_server - .send_message(query.local, query.remote, response), - "Failed to send TCP DNS response: {}" - ); + return Some(response); } dns::ResolveStrategy::RecurseLocal => { - if self.should_forward_dns_query_to_gateway(server.ip()) { - self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query); + if self.should_forward_dns_query_to_gateway(upstream.ip()) { + self.forward_dns_query_to_new_upstream_via_tunnel( + local, remote, upstream, message, transport, now, + ); - return; + return None; } - tracing::trace!(%server, %query_id, "Forwarding TCP DNS query"); + tracing::trace!(%upstream, %query_id, "Forwarding {transport} DNS query"); - self.buffered_dns_queries - .push_back(dns::RecursiveQuery::via_tcp( - query.local, - query.remote, - server, - query.message, - )); + self.buffered_dns_queries.push_back(dns::RecursiveQuery { + server: upstream, + local, + remote, + message, + transport, + }); } dns::ResolveStrategy::RecurseSite(resource) => { let Some(gateway) = @@ -1471,54 +1386,62 @@ impl ClientState { else { self.on_not_connected_resource( resource, - ConnectionTrigger::TcpDnsQueryForSite(query), + ConnectionTrigger::DnsQueryForSite(DnsQueryForSite { + local, + remote, + transport, + message, + }), now, ); - return; + return None; }; - let server = gateway.tun_dns_server_endpoint(query.local.ip()); + let server = gateway.tun_dns_server_endpoint(local.ip()); - self.forward_tcp_dns_query_to_new_upstream_via_tunnel(server, query); + self.forward_dns_query_to_new_upstream_via_tunnel( + local, remote, server, message, transport, now, + ); } }; + + None } - fn forward_tcp_dns_query_to_new_upstream_via_tunnel( + fn forward_dns_query_to_new_upstream_via_tunnel( &mut self, + local: SocketAddr, + remote: SocketAddr, server: SocketAddr, - query: dns_over_tcp::Query, + query: dns_types::Query, + transport: dns::Transport, + now: Instant, ) { - let query_id = query.message.id(); + let query_id = query.id(); - match self - .tcp_dns_client - .send_query(server, query.message.clone()) - { + let result = match transport { + dns::Transport::Udp => self.udp_dns_client.send_query(server, query, now), + dns::Transport::Tcp => self.tcp_dns_client.send_query(server, query), + }; + + match result { Ok(()) => {} Err(e) => { tracing::warn!( - "Failed to send recursive TCP DNS query to upstream resolver: {e:#}" - ); - - unwrap_or_debug!( - self.tcp_dns_server.send_message( - query.local, - query.remote, - dns_types::Response::servfail(&query.message) - ), - "Failed to send TCP DNS response: {}" + "Failed to send recursive {transport} DNS query to upstream resolver: {e:#}" ); return; } }; + tracing::trace!(%server, %local, %query_id, "Forwarding {transport} DNS query via tunnel"); + let existing = self - .tcp_dns_streams_by_upstream_and_query_id - .insert((server, query_id), (query.local, query.remote)); + .dns_streams_by_upstream_and_query_id + .insert((transport, server, query_id), (local, remote)); if let Some((existing_local, existing_remote)) = existing - && (existing_local != query.local || existing_remote != query.remote) + && (existing_local != local || existing_remote != remote) { debug_assert!(false, "Query IDs should be unique"); } @@ -1982,73 +1905,30 @@ fn is_definitely_not_a_resource(ip: IpAddr) -> bool { false } -fn maybe_mangle_dns_response_from_upstream_dns_server( - mut packet: IpPacket, - udp_dns_sockets_by_upstream_and_query_id: &mut ExpiringMap<(SocketAddr, u16), SocketAddr>, - dns_cache: &mut DnsCache, - now: Instant, -) -> IpPacket { - let src_ip = packet.source(); - - let Some(udp) = packet.as_udp() else { - return packet; - }; - - let src_port = udp.source_port(); - let src_socket = SocketAddr::new(src_ip, src_port); - - let Ok(message) = dns_types::Response::parse(udp.payload()) else { - return packet; - }; - - let Some(expiring_map::Entry { - value: original_dst, - .. - }) = udp_dns_sockets_by_upstream_and_query_id.remove(&(src_socket, message.id())) - else { - return packet; - }; - - dns_cache.insert(message.domain(), &message, now); - - tracing::trace!(server = %src_ip, query_id = %message.id(), domain = %message.domain(), "Received UDP DNS response via tunnel"); - - if let Err(e) = packet.set_src(original_dst.ip()) { - tracing::warn!("Failed to set source IP for UDP DNS query: {e:#}"); - } - - packet - .as_udp_mut() - .expect("we parsed it as a UDP packet earlier") - .set_source_port(original_dst.port()); - - packet.update_checksum(); - - packet -} - /// What triggered us to establish a connection to a Gateway. enum ConnectionTrigger { /// A packet received on the TUN device with a destination IP that maps to one of our resources. PacketForResource(IpPacket), - /// A UDP DNS query that needs to be resolved within a particular site that we aren't connected to yet. - /// - /// This packet isn't mangled yet to point to the Gateway's TUN device IP because at the time of buffering, that IP is unknown. - UdpDnsQueryForSite(IpPacket), - /// A TCP DNS query that needs to be resolved within a particular site that we aren't connected to yet. - TcpDnsQueryForSite(dns_over_tcp::Query), + /// A DNS query that needs to be resolved within a particular site that we aren't connected to yet. + DnsQueryForSite(DnsQueryForSite), /// We have received an ICMP error that is marked as "access prohibited". /// /// Most likely, the Gateway is filtering these packets because the Client doesn't have access (anymore). IcmpDestinationUnreachableProhibited, } +struct DnsQueryForSite { + local: SocketAddr, + remote: SocketAddr, + transport: dns::Transport, + message: dns_types::Query, +} + impl ConnectionTrigger { fn name(&self) -> &'static str { match self { ConnectionTrigger::PacketForResource(_) => "packet-for-resource", - ConnectionTrigger::UdpDnsQueryForSite(_) => "udp-dns-query-for-site", - ConnectionTrigger::TcpDnsQueryForSite(_) => "tcp-dns-query-for-site", + ConnectionTrigger::DnsQueryForSite(_) => "dns-query-for-site", ConnectionTrigger::IcmpDestinationUnreachableProhibited => { "icmp-destination-unreachable-prohibited" } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index d22511e5d..5ab316075 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -92,41 +92,11 @@ pub(crate) struct RecursiveResponse { pub transport: Transport, } -impl RecursiveQuery { - pub(crate) fn via_udp( - local: SocketAddr, - remote: SocketAddr, - server: SocketAddr, - message: dns_types::Query, - ) -> Self { - Self { - server, - local, - remote, - message, - transport: Transport::Udp, - } - } - - pub(crate) fn via_tcp( - local: SocketAddr, - remote: SocketAddr, - server: SocketAddr, - message: dns_types::Query, - ) -> Self { - Self { - server, - local, - remote, - message, - transport: Transport::Tcp, - } - } -} - -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::Display)] pub(crate) enum Transport { + #[display("UDP")] Udp, + #[display("TCP")] Tcp, } diff --git a/rust/connlib/tunnel/src/expiring_map.rs b/rust/connlib/tunnel/src/expiring_map.rs index 895126e81..22fd800df 100644 --- a/rust/connlib/tunnel/src/expiring_map.rs +++ b/rust/connlib/tunnel/src/expiring_map.rs @@ -53,6 +53,7 @@ where self.inner.get(key) } + #[cfg(test)] pub fn remove(&mut self, key: &K) -> Option> { self.expiration.retain(|_, keys| { keys.retain(|k| k != key); diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 46c96d723..ccbc9d719 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -456,17 +456,18 @@ impl GatewayTunnel { for query in udp_dns_queries { if let Some(nameserver) = self.io.fastest_nameserver() { - self.io.send_dns_query(dns::RecursiveQuery::via_udp( - query.local, - query.from, - SocketAddr::new(nameserver, dns::DNS_PORT), - query.message, - )); + self.io.send_dns_query(dns::RecursiveQuery { + server: SocketAddr::new(nameserver, dns::DNS_PORT), + local: query.local, + remote: query.remote, + message: query.message, + transport: dns::Transport::Udp, + }); } else { tracing::warn!(query = ?query.message, "No nameserver available to handle UDP DNS query"); if let Err(e) = self.io.send_udp_dns_response( - query.from, + query.remote, query.local, dns_types::Response::servfail(&query.message), ) { @@ -479,12 +480,13 @@ impl GatewayTunnel { for query in tcp_dns_queries { if let Some(nameserver) = self.io.fastest_nameserver() { - self.io.send_dns_query(dns::RecursiveQuery::via_tcp( - query.local, - query.remote, - SocketAddr::new(nameserver, dns::DNS_PORT), - query.message, - )); + self.io.send_dns_query(dns::RecursiveQuery { + server: SocketAddr::new(nameserver, dns::DNS_PORT), + local: query.local, + remote: query.remote, + message: query.message, + transport: dns::Transport::Tcp, + }); } else { tracing::warn!(query = ?query.message, "No nameserver available to handle TCP DNS query");