diff --git a/rust/connlib/dns-over-tcp/src/server.rs b/rust/connlib/dns-over-tcp/src/server.rs index 63f3aa25e..8ea8097d9 100644 --- a/rust/connlib/dns-over-tcp/src/server.rs +++ b/rust/connlib/dns-over-tcp/src/server.rs @@ -108,6 +108,7 @@ impl Server { self.sockets = sockets; self.listen_endpoints = listen_endpoints; self.received_queries.clear(); + self.pending_sockets_by_local_remote_and_query_id.clear(); } /// Checks whether this server can handle the given packet. diff --git a/rust/connlib/dns-over-tcp/tests/client_and_server.rs b/rust/connlib/dns-over-tcp/tests/client_and_server.rs index 759e1fde2..fbd012f1c 100644 --- a/rust/connlib/dns-over-tcp/tests/client_and_server.rs +++ b/rust/connlib/dns-over-tcp/tests/client_and_server.rs @@ -44,10 +44,60 @@ fn smoke() { } } -fn progress( +#[test] +fn no_panic_after_set_listen_address() { + let _guard = firezone_logging::test( + "netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,smoltcp=trace,debug", + ); + + let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); + let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + + let resolver_addr1 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53)); + let resolver_addr2 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 2), 53)); + let resolver_addr3 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 3), 53)); + + let mut dns_client = dns_over_tcp::Client::new(Instant::now(), [0u8; 32]); + dns_client.set_source_interface(ipv4, ipv6); + + let mut dns_server = dns_over_tcp::Server::new(Instant::now()); + dns_server.set_listen_addresses::<2>(BTreeSet::from([resolver_addr1, resolver_addr2])); + + // Feed some queries. + dns_client + .send_query( + resolver_addr1, + Query::new("foo.example.com".parse().unwrap(), RecordType::A), + ) + .unwrap(); + dns_client + .send_query( + resolver_addr2, + Query::new("bar.example.com".parse().unwrap(), RecordType::A), + ) + .unwrap(); + + // Send all packets to server. + let queries = receive_queries::<2>(&mut dns_client, &mut dns_server); + + // Change listen addresses + dns_server.set_listen_addresses::<1>(BTreeSet::from([resolver_addr3])); + + for query in queries { + let _ = dns_server.send_message( + query.local, + query.remote, + ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(), + ); + } +} + +fn receive_queries( dns_client: &mut dns_over_tcp::Client, dns_server: &mut dns_over_tcp::Server, -) -> Option { +) -> Vec { + let mut queries = Vec::with_capacity(N); + loop { if let Some(packet) = dns_client.poll_outbound() { dns_server.handle_inbound(packet); @@ -59,6 +109,39 @@ fn progress( continue; } + if let Some(query) = dns_server.poll_queries() { + queries.push(query); + continue; + } + + dns_client.handle_timeout(Instant::now()); + dns_server.handle_timeout(Instant::now()); + + if queries.len() == N { + return queries; + } + } +} + +fn progress( + dns_client: &mut dns_over_tcp::Client, + dns_server: &mut dns_over_tcp::Server, +) -> Option { + loop { + if let Some(packet) = dns_client.poll_outbound() { + if dns_server.accepts(&packet) { + dns_server.handle_inbound(packet); + } + continue; + } + + if let Some(packet) = dns_server.poll_outbound() { + if dns_client.accepts(&packet) { + dns_client.handle_inbound(packet); + } + continue; + } + if let Some(query) = dns_server.poll_queries() { dns_server .send_message(