diff --git a/rust/connlib/dns-types/lib.rs b/rust/connlib/dns-types/lib.rs index 1f53d18f9..9cc1ca89c 100644 --- a/rust/connlib/dns-types/lib.rs +++ b/rust/connlib/dns-types/lib.rs @@ -225,6 +225,12 @@ impl Response { Self::parse(response.body()) } + pub fn with_id(mut self, id: u16) -> Self { + self.inner.header_mut().set_id(id); + + self + } + pub fn id(&self) -> u16 { self.inner.header().id() } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index c77c155f4..3a4ee4bb1 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -508,49 +508,47 @@ impl ClientState { let _span = tracing::debug_span!("handle_dns_response", %qid, %server, local = %response.local, %domain).entered(); - match (response.transport, response.message) { - (dns::Transport::Udp, Err(e)) - if e.downcast_ref::() - .is_some_and(|e| e.kind() == io::ErrorKind::TimedOut) => - { - tracing::debug!("Recursive UDP DNS query timed out") + let message = match response.message { + Ok(response) => { + tracing::trace!("Received recursive DNS response"); + + if response.truncated() { + tracing::debug!("Upstream DNS server had to truncate response"); + } + + response } - (dns::Transport::Udp, result) => { - let message = result - .inspect(|message| { - tracing::trace!("Received recursive UDP DNS response"); + Err(e) + if response.transport == dns::Transport::Udp + && e.downcast_ref::() + .is_some_and(|e| e.kind() == io::ErrorKind::TimedOut) => + { + tracing::debug!("Recursive UDP DNS query timed out"); - if message.truncated() { - tracing::debug!("Upstream DNS server had to truncate response"); - } + return; // Our UDP DNS query timeout is likely longer than the one from the OS, so don't bother sending a response. + } + Err(e) => { + tracing::debug!("Recursive DNS query failed: {e:#}"); - self.dns_cache.insert(domain, message, now); - }) - .unwrap_or_else(|e| { - tracing::debug!("Recursive UDP DNS query failed: {e:#}"); + dns_types::Response::servfail(&response.query) + } + }; - dns_types::Response::servfail(&response.query) - }); + // Ensure the response we are sending back has the original query ID. + // Recursive DoH queries set the ID to 0. + let message = message.with_id(qid); + self.dns_cache.insert(domain, &message, now); + + match response.transport { + dns::Transport::Udp => { self.buffered_packets.extend(into_udp_dns_packet( response.local, response.remote, message, )); } - (dns::Transport::Tcp, result) => { - let message = result - .inspect(|message| { - tracing::trace!("Received recursive TCP DNS response"); - - self.dns_cache.insert(domain, message, now); - }) - .unwrap_or_else(|e| { - tracing::debug!("Recursive TCP DNS query failed: {e:#}"); - - dns_types::Response::servfail(&response.query) - }); - + dns::Transport::Tcp => { unwrap_or_warn!( self.tcp_dns_server .send_message(response.local, response.remote, message), diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 5e71bddbc..1969c0306 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -554,8 +554,13 @@ impl TunnelTest { let server = query.server; let transport = query.transport; + // DoH queries are always sent with an ID of 0, simulate that in the tests. + let message = matches!(server, dns::Upstream::DoH { .. }) + .then_some(query.message.clone().with_id(0)) + .unwrap_or(query.message.clone()); + let response = - self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records, now); + self.on_recursive_dns_query(&message, &ref_state.global_dns_records, now); self.client.exec_mut(|c| { c.sut.handle_dns_response( dns::RecursiveResponse {