diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 15bda0115..cdab319bf 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -110,7 +110,11 @@ impl Io { let mut sockets = Sockets::default(); sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context. - let mut nameservers = NameserverSet::new(nameservers, udp_socket_factory.clone()); + let mut nameservers = NameserverSet::new( + nameservers, + tcp_socket_factory.clone(), + udp_socket_factory.clone(), + ); nameservers.evaluate(); Self { diff --git a/rust/connlib/tunnel/src/io/nameserver_set.rs b/rust/connlib/tunnel/src/io/nameserver_set.rs index bc9ce28f9..ceb18e3ba 100644 --- a/rust/connlib/tunnel/src/io/nameserver_set.rs +++ b/rust/connlib/tunnel/src/io/nameserver_set.rs @@ -10,11 +10,13 @@ use std::{ use connlib_model::DomainName; use domain::base::{iana::Rcode, Message, MessageBuilder, Question, Rtype}; use futures_bounded::FuturesTupleSet; -use socket_factory::{SocketFactory, UdpSocket}; +use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use crate::io::udp_dns; -const MAX_DNS_SERVERS: usize = 10; // We don't bother selecting from more than 10 servers. +use super::tcp_dns; + +const MAX_DNS_SERVERS: usize = 20; // We don't bother selecting from more than 10 servers over UDP and TCP. const DNS_TIMEOUT: Duration = Duration::from_secs(2); // Every sensible DNS servers should respond within 2s. static FIREZONE_DEV: LazyLock = LazyLock::new(|| { @@ -25,7 +27,8 @@ pub struct NameserverSet { inner: BTreeSet, nameserver_by_rtt: BTreeMap, - socket_factory: Arc>, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, queries: FuturesTupleSet>>, QueryMetaData>, } @@ -37,12 +40,14 @@ struct QueryMetaData { impl NameserverSet { pub fn new( inner: BTreeSet, + tcp_socket_factory: Arc>, udp_socket_factory: Arc>, ) -> Self { Self { queries: FuturesTupleSet::new(DNS_TIMEOUT, MAX_DNS_SERVERS), inner, - socket_factory: udp_socket_factory, + tcp_socket_factory, + udp_socket_factory, nameserver_by_rtt: Default::default(), } } @@ -56,7 +61,7 @@ impl NameserverSet { .queries .try_push( udp_dns::send( - self.socket_factory.clone(), + self.udp_socket_factory.clone(), SocketAddr::new(nameserver, crate::dns::DNS_PORT), query_firezone_dev(), ), @@ -64,7 +69,22 @@ impl NameserverSet { ) .is_err() { - tracing::debug!(%nameserver, "Failed to queue another DNS query"); + tracing::debug!(%nameserver, "Failed to queue another UDP DNS query"); + } + + if self + .queries + .try_push( + tcp_dns::send( + self.tcp_socket_factory.clone(), + SocketAddr::new(nameserver, crate::dns::DNS_PORT), + query_firezone_dev(), + ), + QueryMetaData { nameserver, start }, + ) + .is_err() + { + tracing::debug!(%nameserver, "Failed to queue another TCP DNS query"); } } } @@ -145,6 +165,7 @@ mod tests { Ipv4Addr::new(9, 9, 9, 9).into(), Ipv4Addr::new(100, 100, 100, 100).into(), // Also include an unreachable server. ]), + Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp), ); set.evaluate(); @@ -158,7 +179,11 @@ mod tests { async fn can_handle_no_servers() { let _guard = firezone_logging::test("debug"); - let mut set = NameserverSet::new(BTreeSet::default(), Arc::new(socket_factory::udp)); + let mut set = NameserverSet::new( + BTreeSet::default(), + Arc::new(socket_factory::tcp), + Arc::new(socket_factory::udp), + ); std::future::poll_fn(|cx| set.poll(cx)).await;