diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 29edc1079..efd6cffbe 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -242,13 +242,6 @@ pub struct ClientState { /// /// The [`Instant`] tracks when the DNS query expires. mangled_dns_queries: HashMap, - /// UDP DNS queries that were forwarded to an upstream server, indexed by the DNS query ID + the server we sent it to. - /// - /// The value is the original source IP. - /// - /// DNS query IDs don't appear to be unique across servers they are being sent to on some operating systems (looking at you, Windows). - /// Hence, we need to index by ID + socket of the DNS server. - recursive_udp_dns_queries: HashMap<(u16, SocketAddr), SocketAddr>, /// Manages internal dns records and emits forwarding event when not internally handled stub_resolver: StubResolver, @@ -292,7 +285,6 @@ impl ClientState { sites_status: Default::default(), gateways_site: Default::default(), mangled_dns_queries: Default::default(), - recursive_udp_dns_queries: Default::default(), stub_resolver: StubResolver::new(known_hosts), disabled_resources: Default::default(), buffered_transmits: Default::default(), @@ -507,12 +499,7 @@ impl ClientState { fn try_handle_dns_response(&mut self, response: dns::RecursiveResponse) -> anyhow::Result<()> { match (response.transport, response.message) { - (dns::Transport::Udp, result) => { - let destination = self - .recursive_udp_dns_queries - .remove(&(response.query.header().id(), response.server)) - .context("Unknown query")?; - + (dns::Transport::Udp { source }, result) => { let dns_response = result.unwrap_or_else(|e| { tracing::debug!("UDP DNS query failed: {e}"); @@ -523,7 +510,7 @@ impl ClientState { }); let ip_packet = self - .try_handle_udp_dns_response(response.server, destination, &dns_response) + .try_handle_udp_dns_response(response.server, source, &dns_response) .context("Failed to produce UDP DNS response packet")?; self.buffered_packets.push_back(ip_packet); @@ -737,14 +724,12 @@ impl ClientState { } let query_id = message.header().id(); - let original_src = SocketAddr::new(packet.source(), datagram.source_port()); + let source = SocketAddr::new(packet.source(), datagram.source_port()); tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query"); - self.recursive_udp_dns_queries - .insert((query_id, upstream), original_src); self.buffered_dns_queries - .push_back(dns::RecursiveQuery::via_udp(upstream, message)); + .push_back(dns::RecursiveQuery::via_udp(source, upstream, message)); } Err(e) => { tracing::trace!(?packet, "Failed to handle DNS query: {e:#}"); diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index fad2c447d..c9e6207b4 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -63,18 +63,21 @@ pub(crate) struct RecursiveResponse { } impl RecursiveQuery { - pub(crate) fn via_udp(server: SocketAddr, message: Message<&[u8]>) -> Self { + pub(crate) fn via_udp(source: SocketAddr, server: SocketAddr, message: Message<&[u8]>) -> Self { Self { server, message: message.octets_into(), - transport: Transport::Udp, + transport: Transport::Udp { source }, } } } #[derive(Debug, Clone, Copy)] pub(crate) enum Transport { - Udp, + Udp { + /// The original source we received the DNS query on. + source: SocketAddr, + }, } /// Tells the Client how to reply to a single DNS query diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 08eae8156..aacd3d37a 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -217,7 +217,7 @@ impl Io { pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) { match query.transport { - dns::Transport::Udp => { + dns::Transport::Udp { .. } => { let factory = self.udp_socket_factory.clone(); let server = query.server; let bind_addr = match query.server { @@ -227,7 +227,7 @@ impl Io { let meta = DnsQueryMetaData { query: query.message.clone(), server, - transport: dns::Transport::Udp, + transport: query.transport, }; if self