diff --git a/rust/connlib/l3-udp-dns-client/lib.rs b/rust/connlib/l3-udp-dns-client/lib.rs index 21beb97f7..ea13ab080 100644 --- a/rust/connlib/l3-udp-dns-client/lib.rs +++ b/rust/connlib/l3-udp-dns-client/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + use std::{ collections::{HashMap, VecDeque}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -20,9 +22,6 @@ pub struct Client { query_results: VecDeque, rng: StdRng, - - _created_at: Instant, - last_now: Instant, } struct PendingQuery { @@ -39,7 +38,7 @@ pub struct QueryResult { } impl Client { - pub fn new(now: Instant, seed: [u8; 32]) -> Self { + pub fn new(seed: [u8; 32]) -> Self { // Sadly, these can't be compile-time assertions :( assert!(MIN_PORT >= 49152, "Must use ephemeral port range"); assert!(MIN_PORT < MAX_PORT, "Port range must not have length 0"); @@ -47,8 +46,6 @@ impl Client { Self { source_ips: None, rng: StdRng::from_seed(seed), - _created_at: now, - last_now: now, pending_queries_by_local_port: Default::default(), scheduled_queries: Default::default(), query_results: Default::default(), @@ -62,7 +59,8 @@ impl Client { /// Send the given DNS query to the target server. /// - /// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them. + /// This only queues the message. You need to call [`Client::poll_outbound`] to retrieve + /// the resulting IP packet and send it to the server. pub fn send_query( &mut self, server: SocketAddr, @@ -102,7 +100,7 @@ impl Client { /// Checks whether this client can handle the given packet. /// - /// Only TCP packets originating from one of the connected DNS resolvers are accepted. + /// Only UDP packets for pending DNS queries are accepted. pub fn accepts(&self, packet: &IpPacket) -> bool { let Some(udp) = packet.as_udp() else { #[cfg(debug_assertions)] @@ -129,10 +127,6 @@ impl Client { .contains_key(&udp.destination_port()) } - /// Handle the [`IpPacket`]. - /// - /// This function only inserts the packet into a buffer. - /// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called. pub fn handle_inbound(&mut self, packet: IpPacket) { debug_assert!(self.accepts(&packet)); @@ -182,12 +176,7 @@ impl Client { self.query_results.pop_front() } - /// Inform the client that time advanced. - /// - /// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible. pub fn handle_timeout(&mut self, now: Instant) { - self.last_now = now; - for ( _, PendingQuery { @@ -247,3 +236,93 @@ impl Client { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_timeout_multiple_queries() { + let mut client = create_test_client(); + let now = Instant::now(); + let server1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53); + let server2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 53); + + // Send two queries at the same time + client + .send_query(server1, create_test_query(), now) + .unwrap(); + client + .send_query(server2, create_test_query(), now) + .unwrap(); + assert_eq!(client.poll_timeout(), Some(now + TIMEOUT)); + + // Send third query 10 seconds later + let later = now + Duration::from_secs(10); + client + .send_query(server1, create_test_query(), later) + .unwrap(); + + // poll_timeout should return the earliest timeout + assert_eq!(client.poll_timeout(), Some(now + TIMEOUT)); + + // Advance to after first two timeouts but before third + client.handle_timeout(now + TIMEOUT + Duration::from_secs(1)); + + // First two queries should have timed out + assert!(client.poll_query_result().unwrap().result.is_err()); + assert!(client.poll_query_result().unwrap().result.is_err()); + assert!(client.poll_query_result().is_none()); + + // Third query should still be pending + assert_eq!(client.poll_timeout(), Some(later + TIMEOUT)); + + // Advance past third timeout + client.handle_timeout(later + TIMEOUT + Duration::from_secs(1)); + assert!(client.poll_query_result().unwrap().result.is_err()); + assert!(client.poll_timeout().is_none()); + } + + #[test] + fn test_reset_times_out_all_pending_queries() { + let mut client = create_test_client(); + let now = Instant::now(); + let server1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53); + let server2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 53); + let query1 = create_test_query(); + let query2 = create_test_query(); + + // Send multiple queries + client.send_query(server1, query1, now).unwrap(); + client.send_query(server2, query2, now).unwrap(); + + // Reset should abort all pending queries + client.reset(); + + // Both queries should have error results + assert!(client.poll_query_result().unwrap().result.is_err()); + assert!(client.poll_query_result().unwrap().result.is_err()); + assert!(client.poll_query_result().is_none()); + } + + #[test] + fn test_poll_timeout_returns_none_when_no_pending_queries() { + let mut client = create_test_client(); + + // No pending queries, should return None + assert!(client.poll_timeout().is_none()); + } + + fn create_test_client() -> Client { + let seed = [0u8; 32]; + let mut client = Client::new(seed); + client.set_source_interface(Ipv4Addr::new(10, 0, 0, 1), Ipv6Addr::LOCALHOST); + client + } + + fn create_test_query() -> dns_types::Query { + use std::str::FromStr; + let domain = dns_types::DomainName::from_str("example.com").unwrap(); + dns_types::Query::new(domain, dns_types::RecordType::A) + } +} diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 03c644ea5..16bf03b75 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -208,7 +208,7 @@ impl ClientState { is_internet_resource_active, recently_connected_gateways: LruCache::new(MAX_REMEMBERED_GATEWAYS), buffered_dns_queries: Default::default(), - udp_dns_client: l3_udp_dns_client::Client::new(now, seed), + udp_dns_client: l3_udp_dns_client::Client::new(seed), tcp_dns_client: dns_over_tcp::Client::new(now, seed), tcp_dns_server: dns_over_tcp::Server::new(now), dns_streams_by_upstream_and_query_id: Default::default(),