fix(connlib): clear pending sockets on DNS server re-creation (#9093)

Our DNS over TCP implementation uses `smoltcp` which requires us to
manage sockets individually, i.e. there is no such thing as a listening
socket. Instead, we have to create multiple sockets and rotate through
them.

Whenever we receive new DNS servers from the host app, we throw away all
of those sockets and create new ones.

The way we refer to these sockets internally is via `smoltcp`'s
`SocketHandle`. These are just indices into a `Vec` and this access can
panic when it is out of range. Normally that doesn't happen because such
a `SocketHandle` is only created when the socket is created and
therefore, each `SocketHandle` in existence should be valid.

What we overlooked is that these sockets get destroyed and re-created
when we call `set_listen_addresses` which happens when the host app
tells us about new DNS servers. In that case, sockets that we had just
received a query on and are waiting for a response have their handles
stored in a temporary `HashMap`. Attempting to send back a response for
one of those queries will then either fail with an error that the socket
is not in the right state or - worse - panic with an out of bounds error
if the previously had more listen addresses than we have now.

To fix this, we need to clear this map of pending queries every time we
call `set_listen_addresses`.
This commit is contained in:
Thomas Eizinger
2025-05-12 21:39:59 +10:00
committed by GitHub
parent 7e4fe68485
commit f01fd4ddf6
2 changed files with 86 additions and 2 deletions

View File

@@ -108,6 +108,7 @@ impl Server {
self.sockets = sockets;
self.listen_endpoints = listen_endpoints;
self.received_queries.clear();
self.pending_sockets_by_local_remote_and_query_id.clear();
}
/// Checks whether this server can handle the given packet.

View File

@@ -44,10 +44,60 @@ fn smoke() {
}
}
fn progress(
#[test]
fn no_panic_after_set_listen_address() {
let _guard = firezone_logging::test(
"netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,smoltcp=trace,debug",
);
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
let resolver_addr1 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53));
let resolver_addr2 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 2), 53));
let resolver_addr3 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 3), 53));
let mut dns_client = dns_over_tcp::Client::new(Instant::now(), [0u8; 32]);
dns_client.set_source_interface(ipv4, ipv6);
let mut dns_server = dns_over_tcp::Server::new(Instant::now());
dns_server.set_listen_addresses::<2>(BTreeSet::from([resolver_addr1, resolver_addr2]));
// Feed some queries.
dns_client
.send_query(
resolver_addr1,
Query::new("foo.example.com".parse().unwrap(), RecordType::A),
)
.unwrap();
dns_client
.send_query(
resolver_addr2,
Query::new("bar.example.com".parse().unwrap(), RecordType::A),
)
.unwrap();
// Send all packets to server.
let queries = receive_queries::<2>(&mut dns_client, &mut dns_server);
// Change listen addresses
dns_server.set_listen_addresses::<1>(BTreeSet::from([resolver_addr3]));
for query in queries {
let _ = dns_server.send_message(
query.local,
query.remote,
ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(),
);
}
}
fn receive_queries<const N: usize>(
dns_client: &mut dns_over_tcp::Client,
dns_server: &mut dns_over_tcp::Server,
) -> Option<QueryResult> {
) -> Vec<dns_over_tcp::Query> {
let mut queries = Vec::with_capacity(N);
loop {
if let Some(packet) = dns_client.poll_outbound() {
dns_server.handle_inbound(packet);
@@ -59,6 +109,39 @@ fn progress(
continue;
}
if let Some(query) = dns_server.poll_queries() {
queries.push(query);
continue;
}
dns_client.handle_timeout(Instant::now());
dns_server.handle_timeout(Instant::now());
if queries.len() == N {
return queries;
}
}
}
fn progress(
dns_client: &mut dns_over_tcp::Client,
dns_server: &mut dns_over_tcp::Server,
) -> Option<QueryResult> {
loop {
if let Some(packet) = dns_client.poll_outbound() {
if dns_server.accepts(&packet) {
dns_server.handle_inbound(packet);
}
continue;
}
if let Some(packet) = dns_server.poll_outbound() {
if dns_client.accepts(&packet) {
dns_client.handle_inbound(packet);
}
continue;
}
if let Some(query) = dns_server.poll_queries() {
dns_server
.send_message(