diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 8c9f3072b..1bc4a890e 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -875,16 +875,7 @@ impl ClientState { self.tcp_dns_client .set_source_interface(tun_config.ip.v4, tun_config.ip.v6); - - let upstream_resolvers = self - .dns_mapping - .right_values() - .map(|s| s.address()) - .collect(); - - if let Err(e) = self.tcp_dns_client.set_resolvers(upstream_resolvers) { - tracing::warn!("Failed to connect to upstream DNS resolvers over TCP: {e:#}"); - } + self.tcp_dns_client.reset(); } fn initialise_tcp_dns_server(&mut self) { @@ -1393,8 +1384,7 @@ impl ClientState { // Resetting the client will trigger a failed `QueryResult` for each one that is in-progress. // Failed queries get translated into `SERVFAIL` responses to the client. - // This will also allocate new local ports for our outgoing TCP connections. - self.initialise_tcp_dns_client(); + self.tcp_dns_client.reset(); } pub(crate) fn poll_transmit(&mut self) -> Option> { diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 1496ef674..da221d5d1 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -102,18 +102,8 @@ impl SimClient { } pub(crate) fn set_new_dns_servers(&mut self, mapping: BiMap) { - if self.dns_by_sentinel != mapping { - self.tcp_dns_client - .set_resolvers( - mapping - .left_values() - .map(|ip| SocketAddr::new(*ip, 53)) - .collect(), - ) - .unwrap(); - } - self.dns_by_sentinel = mapping; + self.tcp_dns_client.reset(); } pub(crate) fn dns_mapping(&self) -> &BiMap { diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs index b009d6605..4c10a4f9c 100644 --- a/rust/dns-over-tcp/src/client.rs +++ b/rust/dns-over-tcp/src/client.rs @@ -85,46 +85,38 @@ impl Client { self.source_ips = Some((v4, v6)); } - /// Connect to the specified DNS resolvers. - /// - /// All currently pending queries will be reported as failed. - pub fn set_resolvers(&mut self, resolvers: BTreeSet) -> Result<()> { - let (ipv4_source, ipv6_source) = self.source_ips.context("Missing source IPs")?; - - // First, clear all local state. - self.sockets = SocketSet::new(vec![]); - self.sockets_by_remote.clear(); - self.local_ports_by_socket.clear(); - self.abort_all_pending_and_sent_queries(); - - // Second, try to allocate a unique port per resolver. - let unique_ports = self.sample_unique_ports(resolvers.len())?; - - // Third, initialise the sockets. - self.init_sockets( - std::iter::zip(unique_ports, resolvers), - ipv4_source, - ipv6_source, - ); - - Ok(()) - } - /// Send the given DNS query to the target server. /// /// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them. pub fn send_query(&mut self, server: SocketAddr, message: Message>) -> Result<()> { anyhow::ensure!(!message.header().qr(), "Message is a DNS response!"); - anyhow::ensure!( - self.sockets_by_remote.contains_key(&server), - "Unknown DNS resolver" - ); self.pending_queries_by_remote .entry(server) .or_default() .push_back(message); + if self.sockets_by_remote.contains_key(&server) { + return Ok(()); + }; + + let local_port = self.sample_new_unique_port()?; + + let (ipv4_source, ipv6_source) = self + .source_ips + .ok_or_else(|| anyhow!("No source interface set"))?; + + let local_endpoint = match server { + SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), local_port), + SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), local_port), + }; + + let handle = self.sockets.add(create_tcp_socket()); + + self.sockets_by_remote.insert(server, handle); + self.local_ports_by_socket + .insert(handle, local_endpoint.port()); + Ok(()) } @@ -264,7 +256,7 @@ impl Client { Some(self.last_now + Duration::from(poll_in)) } - fn abort_all_pending_and_sent_queries(&mut self) { + pub fn reset(&mut self) { let aborted_pending_queries = self.pending_queries_by_remote .drain() @@ -280,52 +272,32 @@ impl Client { self.query_results .extend(aborted_pending_queries.chain(aborted_sent_queries)); + + self.sockets = SocketSet::new(vec![]); + self.sockets_by_remote.clear(); + self.local_ports_by_socket.clear(); } - fn init_sockets( - &mut self, - ports_and_resolvers: impl IntoIterator, - ipv4_source: Ipv4Addr, - ipv6_source: Ipv6Addr, - ) { - let new_sockets = ports_and_resolvers - .into_iter() - .map(|(port, server)| { - let local_endpoint = match server { - SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port), - SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port), - }; - let socket = create_tcp_socket(); + fn sample_new_unique_port(&mut self) -> Result { + let used_ports = self + .local_ports_by_socket + .values() + .copied() + .collect::>(); - (server, local_endpoint, socket) - }) - .collect::>(); - - for (server, local_endpoint, socket) in new_sockets { - let handle = self.sockets.add(socket); - - self.sockets_by_remote.insert(server, handle); - self.local_ports_by_socket - .insert(handle, local_endpoint.port()); - } - } - - fn sample_unique_ports(&mut self, num_ports: usize) -> Result> { - let mut ports = HashSet::with_capacity(num_ports); let range = MIN_PORT..=MAX_PORT; - if num_ports > range.len() { - bail!( - "Port range only provides {} values but we need {num_ports}", - range.len() - ) + if used_ports.len() == range.len() { + bail!("All ports exhausted") } - while ports.len() < num_ports { - ports.insert(self.rng.gen_range(range.clone())); - } + loop { + let port = self.rng.gen_range(range.clone()); - Ok(ports.into_iter()) + if !used_ports.contains(&port) { + return Ok(port); + } + } } } diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs index ed5b294ba..e4c5fe2ef 100644 --- a/rust/dns-over-tcp/tests/client_and_server.rs +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -20,9 +20,6 @@ fn smoke() { let mut dns_client = dns_over_tcp::Client::new(Instant::now(), [0u8; 32]); dns_client.set_source_interface(ipv4, ipv6); - dns_client - .set_resolvers(BTreeSet::from_iter([resolver_addr])) - .unwrap(); let mut dns_server = dns_over_tcp::Server::new(Instant::now()); dns_server.set_listen_addresses::<1>(BTreeSet::from([resolver_addr]));