mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
fix(connlib): clear pending sockets on DNS server re-creation (#9093)
Our DNS over TCP implementation uses `smoltcp` which requires us to manage sockets individually, i.e. there is no such thing as a listening socket. Instead, we have to create multiple sockets and rotate through them. Whenever we receive new DNS servers from the host app, we throw away all of those sockets and create new ones. The way we refer to these sockets internally is via `smoltcp`'s `SocketHandle`. These are just indices into a `Vec` and this access can panic when it is out of range. Normally that doesn't happen because such a `SocketHandle` is only created when the socket is created and therefore, each `SocketHandle` in existence should be valid. What we overlooked is that these sockets get destroyed and re-created when we call `set_listen_addresses` which happens when the host app tells us about new DNS servers. In that case, sockets that we had just received a query on and are waiting for a response have their handles stored in a temporary `HashMap`. Attempting to send back a response for one of those queries will then either fail with an error that the socket is not in the right state or - worse - panic with an out of bounds error if the previously had more listen addresses than we have now. To fix this, we need to clear this map of pending queries every time we call `set_listen_addresses`.
This commit is contained in:
@@ -108,6 +108,7 @@ impl Server {
|
||||
self.sockets = sockets;
|
||||
self.listen_endpoints = listen_endpoints;
|
||||
self.received_queries.clear();
|
||||
self.pending_sockets_by_local_remote_and_query_id.clear();
|
||||
}
|
||||
|
||||
/// Checks whether this server can handle the given packet.
|
||||
|
||||
@@ -44,10 +44,60 @@ fn smoke() {
|
||||
}
|
||||
}
|
||||
|
||||
fn progress(
|
||||
#[test]
|
||||
fn no_panic_after_set_listen_address() {
|
||||
let _guard = firezone_logging::test(
|
||||
"netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,smoltcp=trace,debug",
|
||||
);
|
||||
|
||||
let ipv4 = Ipv4Addr::from([100, 90, 215, 97]);
|
||||
let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);
|
||||
|
||||
let resolver_addr1 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53));
|
||||
let resolver_addr2 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 2), 53));
|
||||
let resolver_addr3 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 3), 53));
|
||||
|
||||
let mut dns_client = dns_over_tcp::Client::new(Instant::now(), [0u8; 32]);
|
||||
dns_client.set_source_interface(ipv4, ipv6);
|
||||
|
||||
let mut dns_server = dns_over_tcp::Server::new(Instant::now());
|
||||
dns_server.set_listen_addresses::<2>(BTreeSet::from([resolver_addr1, resolver_addr2]));
|
||||
|
||||
// Feed some queries.
|
||||
dns_client
|
||||
.send_query(
|
||||
resolver_addr1,
|
||||
Query::new("foo.example.com".parse().unwrap(), RecordType::A),
|
||||
)
|
||||
.unwrap();
|
||||
dns_client
|
||||
.send_query(
|
||||
resolver_addr2,
|
||||
Query::new("bar.example.com".parse().unwrap(), RecordType::A),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Send all packets to server.
|
||||
let queries = receive_queries::<2>(&mut dns_client, &mut dns_server);
|
||||
|
||||
// Change listen addresses
|
||||
dns_server.set_listen_addresses::<1>(BTreeSet::from([resolver_addr3]));
|
||||
|
||||
for query in queries {
|
||||
let _ = dns_server.send_message(
|
||||
query.local,
|
||||
query.remote,
|
||||
ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn receive_queries<const N: usize>(
|
||||
dns_client: &mut dns_over_tcp::Client,
|
||||
dns_server: &mut dns_over_tcp::Server,
|
||||
) -> Option<QueryResult> {
|
||||
) -> Vec<dns_over_tcp::Query> {
|
||||
let mut queries = Vec::with_capacity(N);
|
||||
|
||||
loop {
|
||||
if let Some(packet) = dns_client.poll_outbound() {
|
||||
dns_server.handle_inbound(packet);
|
||||
@@ -59,6 +109,39 @@ fn progress(
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(query) = dns_server.poll_queries() {
|
||||
queries.push(query);
|
||||
continue;
|
||||
}
|
||||
|
||||
dns_client.handle_timeout(Instant::now());
|
||||
dns_server.handle_timeout(Instant::now());
|
||||
|
||||
if queries.len() == N {
|
||||
return queries;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn progress(
|
||||
dns_client: &mut dns_over_tcp::Client,
|
||||
dns_server: &mut dns_over_tcp::Server,
|
||||
) -> Option<QueryResult> {
|
||||
loop {
|
||||
if let Some(packet) = dns_client.poll_outbound() {
|
||||
if dns_server.accepts(&packet) {
|
||||
dns_server.handle_inbound(packet);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(packet) = dns_server.poll_outbound() {
|
||||
if dns_client.accepts(&packet) {
|
||||
dns_client.handle_inbound(packet);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(query) = dns_server.poll_queries() {
|
||||
dns_server
|
||||
.send_message(
|
||||
|
||||
Reference in New Issue
Block a user