mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
chore(connlib): add time-related tests to l3-udp-dns-client (#10913)
This module didn't have any tests yet so I generated some with Claude and trimmed them down to a meaningful set.
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
#![cfg_attr(test, allow(clippy::unwrap_used))]
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
@@ -20,9 +22,6 @@ pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
|
||||
query_results: VecDeque<QueryResult>,
|
||||
|
||||
rng: StdRng,
|
||||
|
||||
_created_at: Instant,
|
||||
last_now: Instant,
|
||||
}
|
||||
|
||||
struct PendingQuery {
|
||||
@@ -39,7 +38,7 @@ pub struct QueryResult {
|
||||
}
|
||||
|
||||
impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
pub fn new(now: Instant, seed: [u8; 32]) -> Self {
|
||||
pub fn new(seed: [u8; 32]) -> Self {
|
||||
// Sadly, these can't be compile-time assertions :(
|
||||
assert!(MIN_PORT >= 49152, "Must use ephemeral port range");
|
||||
assert!(MIN_PORT < MAX_PORT, "Port range must not have length 0");
|
||||
@@ -47,8 +46,6 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
Self {
|
||||
source_ips: None,
|
||||
rng: StdRng::from_seed(seed),
|
||||
_created_at: now,
|
||||
last_now: now,
|
||||
pending_queries_by_local_port: Default::default(),
|
||||
scheduled_queries: Default::default(),
|
||||
query_results: Default::default(),
|
||||
@@ -62,7 +59,8 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
|
||||
/// Send the given DNS query to the target server.
|
||||
///
|
||||
/// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them.
|
||||
/// This only queues the message. You need to call [`Client::poll_outbound`] to retrieve
|
||||
/// the resulting IP packet and send it to the server.
|
||||
pub fn send_query(
|
||||
&mut self,
|
||||
server: SocketAddr,
|
||||
@@ -102,7 +100,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
|
||||
/// Checks whether this client can handle the given packet.
|
||||
///
|
||||
/// Only TCP packets originating from one of the connected DNS resolvers are accepted.
|
||||
/// Only UDP packets for pending DNS queries are accepted.
|
||||
pub fn accepts(&self, packet: &IpPacket) -> bool {
|
||||
let Some(udp) = packet.as_udp() else {
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -129,10 +127,6 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
.contains_key(&udp.destination_port())
|
||||
}
|
||||
|
||||
/// Handle the [`IpPacket`].
|
||||
///
|
||||
/// This function only inserts the packet into a buffer.
|
||||
/// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called.
|
||||
pub fn handle_inbound(&mut self, packet: IpPacket) {
|
||||
debug_assert!(self.accepts(&packet));
|
||||
|
||||
@@ -182,12 +176,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
self.query_results.pop_front()
|
||||
}
|
||||
|
||||
/// Inform the client that time advanced.
|
||||
///
|
||||
/// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible.
|
||||
pub fn handle_timeout(&mut self, now: Instant) {
|
||||
self.last_now = now;
|
||||
|
||||
for (
|
||||
_,
|
||||
PendingQuery {
|
||||
@@ -247,3 +236,93 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_timeout_multiple_queries() {
|
||||
let mut client = create_test_client();
|
||||
let now = Instant::now();
|
||||
let server1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
|
||||
let server2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 53);
|
||||
|
||||
// Send two queries at the same time
|
||||
client
|
||||
.send_query(server1, create_test_query(), now)
|
||||
.unwrap();
|
||||
client
|
||||
.send_query(server2, create_test_query(), now)
|
||||
.unwrap();
|
||||
assert_eq!(client.poll_timeout(), Some(now + TIMEOUT));
|
||||
|
||||
// Send third query 10 seconds later
|
||||
let later = now + Duration::from_secs(10);
|
||||
client
|
||||
.send_query(server1, create_test_query(), later)
|
||||
.unwrap();
|
||||
|
||||
// poll_timeout should return the earliest timeout
|
||||
assert_eq!(client.poll_timeout(), Some(now + TIMEOUT));
|
||||
|
||||
// Advance to after first two timeouts but before third
|
||||
client.handle_timeout(now + TIMEOUT + Duration::from_secs(1));
|
||||
|
||||
// First two queries should have timed out
|
||||
assert!(client.poll_query_result().unwrap().result.is_err());
|
||||
assert!(client.poll_query_result().unwrap().result.is_err());
|
||||
assert!(client.poll_query_result().is_none());
|
||||
|
||||
// Third query should still be pending
|
||||
assert_eq!(client.poll_timeout(), Some(later + TIMEOUT));
|
||||
|
||||
// Advance past third timeout
|
||||
client.handle_timeout(later + TIMEOUT + Duration::from_secs(1));
|
||||
assert!(client.poll_query_result().unwrap().result.is_err());
|
||||
assert!(client.poll_timeout().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_times_out_all_pending_queries() {
|
||||
let mut client = create_test_client();
|
||||
let now = Instant::now();
|
||||
let server1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
|
||||
let server2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 53);
|
||||
let query1 = create_test_query();
|
||||
let query2 = create_test_query();
|
||||
|
||||
// Send multiple queries
|
||||
client.send_query(server1, query1, now).unwrap();
|
||||
client.send_query(server2, query2, now).unwrap();
|
||||
|
||||
// Reset should abort all pending queries
|
||||
client.reset();
|
||||
|
||||
// Both queries should have error results
|
||||
assert!(client.poll_query_result().unwrap().result.is_err());
|
||||
assert!(client.poll_query_result().unwrap().result.is_err());
|
||||
assert!(client.poll_query_result().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_poll_timeout_returns_none_when_no_pending_queries() {
|
||||
let mut client = create_test_client();
|
||||
|
||||
// No pending queries, should return None
|
||||
assert!(client.poll_timeout().is_none());
|
||||
}
|
||||
|
||||
fn create_test_client() -> Client {
|
||||
let seed = [0u8; 32];
|
||||
let mut client = Client::new(seed);
|
||||
client.set_source_interface(Ipv4Addr::new(10, 0, 0, 1), Ipv6Addr::LOCALHOST);
|
||||
client
|
||||
}
|
||||
|
||||
fn create_test_query() -> dns_types::Query {
|
||||
use std::str::FromStr;
|
||||
let domain = dns_types::DomainName::from_str("example.com").unwrap();
|
||||
dns_types::Query::new(domain, dns_types::RecordType::A)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,7 +208,7 @@ impl ClientState {
|
||||
is_internet_resource_active,
|
||||
recently_connected_gateways: LruCache::new(MAX_REMEMBERED_GATEWAYS),
|
||||
buffered_dns_queries: Default::default(),
|
||||
udp_dns_client: l3_udp_dns_client::Client::new(now, seed),
|
||||
udp_dns_client: l3_udp_dns_client::Client::new(seed),
|
||||
tcp_dns_client: dns_over_tcp::Client::new(now, seed),
|
||||
tcp_dns_server: dns_over_tcp::Server::new(now),
|
||||
dns_streams_by_upstream_and_query_id: Default::default(),
|
||||
|
||||
Reference in New Issue
Block a user