feat(connlib): support DoH (#10876)

Building on top of a series of refactors and smaller features, this PR
enables connlib to send DNS queries over HTTPS to one or more configured
DoH providers.

A DoH server itself is addressed via a domain which first needs to be
resolved before it can be contacted. The RFC recommends to perform this
bootstrapping using the system DNS resolvers. For connlib, this is a bit
tricky because the system resolvers may already be set to connlib's
sentinel servers by the time we need to bootstrap the DoH clients.
Therefore, we maintain a dedicated UDP DNS client inside connlib's `Io`
component which is always configured with the latest system DNS
resolvers known to connlib.

The actual bootstrapping of a DoH client happens in the following cases:

1. Our TUN device configuration changes and the configured DNS servers
mapping contains DoH upstreams.
2. We need to make a DNS query to a DoH server but don't have a client
yet.

The first case ensures we bootstrap the DoH clients as early as
possible. The latter case ensures we have a self-healing behaviour in
case the TCP connection to the DoH server breaks (in which case the DoH
client will be de-allocated).

Once the DoH client is initialized, making queries with it is a trivial
act of sending an HTTP request and parsing the HTTP response. Within
connlib, this now requires almost no special handling apart from a new
`dns::Upstream` type that differentiates between Do53 servers (addressed
by a `SocketAddr`) and DoH servers (addressed by a `Url`).

Related: #10764
Related: #10788
Related: #10850
Related: #10851
Related: #10856
Related: #10857
Related: #10871
Related: #10872
Related: #10875
Related: #10881
Resolves: #10790
This commit is contained in:
Thomas Eizinger
2025-11-19 16:10:52 +11:00
committed by GitHub
parent 9b0ae92b29
commit 01e16e87d6
14 changed files with 374 additions and 116 deletions

2
rust/Cargo.lock generated
View File

@@ -2690,6 +2690,7 @@ dependencies = [
"gat-lending-iterator",
"glob",
"hex",
"http-client",
"ip-packet",
"ip_network",
"ip_network_table",
@@ -2697,6 +2698,7 @@ dependencies = [
"l3-tcp",
"l3-udp-dns-client",
"l4-tcp-dns-server",
"l4-udp-dns-client",
"l4-udp-dns-server",
"lru",
"opentelemetry",

View File

@@ -200,7 +200,7 @@ impl Eventloop {
return Ok(ControlFlow::Continue(()));
};
let dns = tunnel.state_mut().update_system_resolvers(dns);
let dns = tunnel.update_system_resolvers(dns);
self.portal_cmd_tx
.send(PortalCommand::UpdateDnsServers(dns))

View File

@@ -29,12 +29,14 @@ futures-bounded = { workspace = true, features = ["tokio"] }
gat-lending-iterator = { workspace = true }
glob = { workspace = true }
hex = { workspace = true }
http-client = { workspace = true }
ip-packet = { workspace = true }
ip_network = { workspace = true }
ip_network_table = { workspace = true }
itertools = { workspace = true, features = ["use_std"] }
l3-udp-dns-client = { workspace = true }
l4-tcp-dns-server = { workspace = true }
l4-udp-dns-client = { workspace = true }
l4-udp-dns-server = { workspace = true }
lru = { workspace = true }
opentelemetry = { workspace = true, features = ["metrics"] }

View File

@@ -778,17 +778,27 @@ impl ClientState {
/// For DNS queries to IPs that are a CIDR resources we want to mangle and forward to the gateway that handles that resource.
///
/// We only want to do this if the upstream DNS server is set by the portal, otherwise, the server might be a local IP.
fn should_forward_dns_query_to_gateway(&self, dns_server: IpAddr) -> bool {
fn should_forward_dns_query_to_gateway(
&self,
dns_server: &dns::Upstream,
) -> Option<SocketAddr> {
if !self.dns_config.has_custom_upstream() {
return false;
return None;
}
let server = match dns_server {
dns::Upstream::Do53 { server } => server,
dns::Upstream::DoH { .. } => return None, // If DoH upstreams are in effect, we never forward queries to upstreams.
};
if self.active_internet_resource().is_some() {
return true;
return Some(*server);
}
self.active_cidr_resources
.longest_match(dns_server)
.longest_match(server.ip())
.is_some()
.then_some(*server)
}
/// Handles UDP & TCP packets targeted at our stub resolver.
@@ -1014,7 +1024,7 @@ impl ClientState {
///
/// Note: The returned list is not necessarily the list of DNS resolvers that is active.
/// If DNS servers are defined in the portal, those will be preferred over the system defined ones.
pub fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>) -> Vec<IpAddr> {
pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>) -> Vec<IpAddr> {
let changed = self.dns_config.update_system_resolvers(new_dns);
if !changed {
@@ -1199,7 +1209,7 @@ impl ClientState {
self.handle_dns_response(
dns::RecursiveResponse {
server,
server: dns::Upstream::Do53 { server },
local,
remote,
query: query_result.query,
@@ -1227,7 +1237,7 @@ impl ClientState {
self.handle_dns_response(
dns::RecursiveResponse {
server,
server: dns::Upstream::Do53 { server },
local,
remote,
query: query_result.query,
@@ -1257,7 +1267,7 @@ impl ClientState {
}
}
fn handle_udp_dns_query(&mut self, upstream: SocketAddr, packet: IpPacket, now: Instant) {
fn handle_udp_dns_query(&mut self, upstream: dns::Upstream, packet: IpPacket, now: Instant) {
let Some(datagram) = packet.as_udp() else {
tracing::debug!(?packet, "Not a UDP packet");
@@ -1355,7 +1365,7 @@ impl ClientState {
message: dns_types::Query,
local: SocketAddr,
remote: SocketAddr,
upstream: SocketAddr,
upstream: dns::Upstream,
transport: dns::Transport,
now: Instant,
) -> Option<dns_types::Response> {
@@ -1374,7 +1384,7 @@ impl ClientState {
return Some(response);
}
dns::ResolveStrategy::RecurseLocal => {
if self.should_forward_dns_query_to_gateway(upstream.ip()) {
if let Some(upstream) = self.should_forward_dns_query_to_gateway(&upstream) {
self.forward_dns_query_to_new_upstream_via_tunnel(
local, remote, upstream, message, transport, now,
);

View File

@@ -1,6 +1,6 @@
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv4Addr, SocketAddr},
};
use dns_types::DoHUrl;
@@ -8,7 +8,7 @@ use ip_network::IpNetwork;
use crate::{
client::{DNS_SENTINELS_V4, DNS_SENTINELS_V6, IpProvider},
dns::DNS_PORT,
dns::{self, DNS_PORT},
};
#[derive(Debug, Default)]
@@ -18,6 +18,7 @@ pub(crate) struct DnsConfig {
/// The Do53 resolvers configured in the portal.
///
/// Has priority over system-configured DNS servers.
/// Has priority over DoH resolvers.
upstream_do53: Vec<IpAddr>,
/// The DoH resolvers configured in the portal.
///
@@ -30,7 +31,7 @@ pub(crate) struct DnsConfig {
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DnsMapping {
inner: Vec<(IpAddr, SocketAddr)>,
inner: Vec<(IpAddr, dns::Upstream)>,
}
impl DnsMapping {
@@ -38,11 +39,11 @@ impl DnsMapping {
self.inner.iter().map(|(ip, _)| ip).copied().collect()
}
pub fn upstream_sockets(&self) -> Vec<SocketAddr> {
pub fn upstream_servers(&self) -> Vec<dns::Upstream> {
self.inner
.iter()
.map(|(_, socket)| socket)
.copied()
.map(|(_, upstream)| upstream)
.cloned()
.collect()
}
@@ -54,16 +55,16 @@ impl DnsMapping {
// Most importantly, it is much easier for us to retain the ordering of the DNS servers if we don't use a map.
#[cfg(test)]
pub(crate) fn sentinel_by_upstream(&self, upstream: SocketAddr) -> Option<IpAddr> {
pub(crate) fn sentinel_by_upstream(&self, upstream: &dns::Upstream) -> Option<IpAddr> {
self.inner
.iter()
.find_map(|(sentinel, candidate)| (candidate == &upstream).then_some(*sentinel))
.find_map(|(sentinel, candidate)| (candidate == upstream).then_some(*sentinel))
}
pub(crate) fn upstream_by_sentinel(&self, sentinel: IpAddr) -> Option<SocketAddr> {
pub(crate) fn upstream_by_sentinel(&self, sentinel: IpAddr) -> Option<dns::Upstream> {
self.inner
.iter()
.find_map(|(candidate, upstream)| (candidate == &sentinel).then_some(*upstream))
.find_map(|(candidate, upstream)| (candidate == &sentinel).then_some(upstream.clone()))
}
}
@@ -104,7 +105,7 @@ impl DnsConfig {
}
pub(crate) fn has_custom_upstream(&self) -> bool {
!self.upstream_do53.is_empty()
!self.upstream_do53.is_empty() || !self.upstream_doh.is_empty()
}
pub(crate) fn mapping(&mut self) -> DnsMapping {
@@ -116,11 +117,14 @@ impl DnsConfig {
}
fn update_dns_mapping(&mut self) -> bool {
let effective_dns_servers =
effective_dns_servers(self.upstream_do53.clone(), self.system_resolvers.clone());
let effective_dns_servers = effective_dns_servers(
self.upstream_do53.clone(),
self.upstream_doh.clone(),
self.system_resolvers.clone(),
);
if HashSet::<SocketAddr>::from_iter(effective_dns_servers.clone())
== HashSet::from_iter(self.mapping.upstream_sockets())
if HashSet::<dns::Upstream>::from_iter(effective_dns_servers.clone())
== HashSet::from_iter(self.mapping.upstream_servers())
{
tracing::debug!(servers = ?effective_dns_servers, "Effective DNS servers are unchanged");
@@ -135,12 +139,22 @@ impl DnsConfig {
fn effective_dns_servers(
upstream_do53: Vec<IpAddr>,
upstream_doh: Vec<DoHUrl>,
default_resolvers: Vec<IpAddr>,
) -> Vec<SocketAddr> {
) -> Vec<dns::Upstream> {
if !upstream_do53.is_empty() {
return upstream_do53
.into_iter()
.map(|ip| SocketAddr::new(ip, DNS_PORT))
.map(|ip| dns::Upstream::Do53 {
server: SocketAddr::new(ip, DNS_PORT),
})
.collect();
}
if !upstream_doh.is_empty() {
return upstream_doh
.into_iter()
.map(|server| dns::Upstream::DoH { server })
.collect();
}
@@ -153,22 +167,28 @@ fn effective_dns_servers(
default_resolvers
.into_iter()
.map(|ip| SocketAddr::new(ip, DNS_PORT))
.map(|ip| dns::Upstream::Do53 {
server: SocketAddr::new(ip, DNS_PORT),
})
.collect()
}
fn sentinel_dns_mapping(dns: &[SocketAddr], old_sentinels: Vec<IpAddr>) -> DnsMapping {
fn sentinel_dns_mapping(dns: &[dns::Upstream], old_sentinels: Vec<IpAddr>) -> DnsMapping {
let mut ip_provider = IpProvider::for_stub_dns_servers(old_sentinels);
let mapping = dns
.iter()
.copied()
.map(|i| {
.map(|u| {
let ip_addr = match u {
dns::Upstream::Do53 { server } => server.ip(),
dns::Upstream::DoH { .. } => IpAddr::V4(Ipv4Addr::UNSPECIFIED), // DoH servers are always mapped to IPv4 servers.
};
(
ip_provider
.get_proxy_ip_for(&i.ip())
.get_proxy_ip_for(&ip_addr)
.expect("We only support up to 256 IPv4 DNS servers and 256 IPv6 DNS servers"),
i,
u.clone(),
)
})
.collect();
@@ -204,11 +224,11 @@ mod tests {
assert_eq!(config.mapping().sentinel_ips().len(), 3);
assert_eq!(
config.mapping().upstream_sockets(),
config.mapping().upstream_servers(),
vec![
socket("1.1.1.1:53"),
socket("1.0.0.1:53"),
socket("[2606:4700:4700::1111]:53"),
do53("1.1.1.1:53"),
do53("1.0.0.1:53"),
do53("[2606:4700:4700::1111]:53"),
]
);
}
@@ -224,8 +244,8 @@ mod tests {
assert_eq!(config.mapping().sentinel_ips().len(), 1);
assert_eq!(
config.mapping().upstream_sockets(),
vec![socket("1.0.0.1:53"),]
config.mapping().upstream_servers(),
vec![do53("1.0.0.1:53"),]
);
}
@@ -238,8 +258,8 @@ mod tests {
assert_eq!(config.mapping().sentinel_ips().len(), 1);
assert_eq!(
config.mapping().upstream_sockets(),
vec![socket("1.1.1.1:53"),]
config.mapping().upstream_servers(),
vec![do53("1.1.1.1:53"),]
);
}
@@ -253,8 +273,8 @@ mod tests {
assert_eq!(config.mapping().sentinel_ips().len(), 1);
assert_eq!(
config.mapping().upstream_sockets(),
vec![socket("1.1.1.1:53"),]
config.mapping().upstream_servers(),
vec![do53("1.1.1.1:53"),]
);
}
@@ -262,7 +282,9 @@ mod tests {
address.parse().unwrap()
}
fn socket(socket: &str) -> SocketAddr {
socket.parse().unwrap()
fn do53(socket: &str) -> dns::Upstream {
dns::Upstream::Do53 {
server: socket.parse().unwrap(),
}
}
}

View File

@@ -2,8 +2,8 @@ use crate::client::IpProvider;
use anyhow::Result;
use connlib_model::{IpStack, ResourceId};
use dns_types::{
DomainName, DomainNameRef, OwnedRecordData, Query, RecordType, Response, ResponseBuilder,
ResponseCode,
DoHUrl, DomainName, DomainNameRef, OwnedRecordData, Query, RecordType, Response,
ResponseBuilder, ResponseCode,
};
use firezone_logging::err_with_src;
use itertools::Itertools;
@@ -54,7 +54,7 @@ struct Resource {
#[derive(Debug)]
pub(crate) struct RecursiveQuery {
/// The server we want to send the query to.
pub server: SocketAddr,
pub server: Upstream,
/// The local address we received the query on.
pub local: SocketAddr,
@@ -73,7 +73,7 @@ pub(crate) struct RecursiveQuery {
#[derive(Debug)]
pub(crate) struct RecursiveResponse {
/// The server we sent the query to.
pub server: SocketAddr,
pub server: Upstream,
/// The local address we received the original query on.
pub local: SocketAddr,
@@ -99,6 +99,14 @@ pub(crate) enum Transport {
Tcp,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, derive_more::Display)]
pub enum Upstream {
#[display("Do53({server})")]
Do53 { server: SocketAddr },
#[display("DoH({server})")]
DoH { server: DoHUrl },
}
/// Tells the Client how to reply to a single DNS query
#[derive(Debug)]
pub(crate) enum ResolveStrategy {

View File

@@ -1,3 +1,4 @@
mod doh;
mod gso_queue;
mod nameserver_set;
mod tcp_dns;
@@ -6,10 +7,12 @@ mod udp_dns;
use crate::{TunnelError, device_channel::Device, dns, otel, sockets::Sockets};
use anyhow::{Context as _, Result};
use chrono::{DateTime, Utc};
use dns_types::DoHUrl;
use futures::FutureExt as _;
use futures_bounded::FuturesTupleSet;
use futures_bounded::{FuturesMap, FuturesTupleSet};
use gat_lending_iterator::LendingIterator;
use gso_queue::GsoQueue;
use http_client::HttpClient;
use ip_packet::{Ecn, IpPacket, MAX_FZ_PAYLOAD};
use nameserver_set::NameserverSet;
use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket};
@@ -57,6 +60,10 @@ pub struct Io {
dns_queries: FuturesTupleSet<Result<dns_types::Response>, DnsQueryMetaData>,
udp_dns_client: l4_udp_dns_client::UdpDnsClient,
doh_clients: BTreeMap<DoHUrl, HttpClient>,
doh_clients_bootstrap: FuturesMap<DoHUrl, Result<HttpClient>>,
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
tun: Device,
@@ -65,10 +72,10 @@ pub struct Io {
dropped_packets: opentelemetry::metrics::Counter<u64>,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct DnsQueryMetaData {
query: dns_types::Query,
server: SocketAddr,
server: dns::Upstream,
local: SocketAddr,
remote: SocketAddr,
transport: dns::Transport,
@@ -164,6 +171,10 @@ impl Io {
tcp_socket_factory.clone(),
udp_socket_factory.clone(),
),
udp_dns_client: l4_udp_dns_client::UdpDnsClient::new(
udp_socket_factory.clone(),
Vec::default(),
),
reval_nameserver_interval: tokio::time::interval(RE_EVALUATE_NAMESERVER_INTERVAL),
tcp_socket_factory,
udp_socket_factory,
@@ -171,6 +182,11 @@ impl Io {
|| futures_bounded::Delay::tokio(DNS_QUERY_TIMEOUT),
1000,
),
doh_clients: Default::default(),
doh_clients_bootstrap: FuturesMap::new(
|| futures_bounded::Delay::tokio(DNS_QUERY_TIMEOUT),
10,
),
gso_queue: GsoQueue::new(),
tun: Device::new(),
udp_dns_server: Default::default(),
@@ -203,6 +219,13 @@ impl Io {
Ok(())
}
pub fn update_system_resolvers(&mut self, resolvers: Vec<IpAddr>) {
tracing::debug!(servers = ?resolvers, "Re-configuring UDP DNS client with new upstreams");
self.udp_dns_client =
l4_udp_dns_client::UdpDnsClient::new(self.udp_socket_factory.clone(), resolvers)
}
pub fn poll_has_sockets(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.sockets.poll_has_sockets(cx)
}
@@ -234,6 +257,16 @@ impl Io {
// We purposely don't want to block the event loop here because we can do plenty of other work while this is running.
let _ = self.nameservers.poll(cx);
while let Poll::Ready((url, result)) = self.doh_clients_bootstrap.poll_unpin(cx) {
match result {
Ok(Ok(client)) => {
self.doh_clients.insert(url.clone(), client);
}
Ok(Err(e)) => tracing::debug!(%url, "Failed to bootstrap DoH client: {e:#}"),
Err(e) => tracing::debug!(%url, "Failed to bootstrap DoH client: {e:#}"),
}
}
let network = self.sockets.poll_recv_from(cx).map(|network| {
anyhow::Ok(
network
@@ -323,6 +356,18 @@ impl Io {
},
});
// We need to discard DoH clients if their queries fail because the connection got closed.
// They will get re-bootstrapped on the next requested DoH query.
if let Poll::Ready(response) = &dns_response
&& let dns::Upstream::DoH { server } = &response.server
&& let Err(e) = &response.message
&& e.is::<http_client::Closed>()
{
tracing::debug!(%server, "Connection of DoH client failed");
self.doh_clients.remove(server);
}
let timeout = self
.timeout
.as_mut()
@@ -423,6 +468,10 @@ impl Io {
self.dns_queries =
FuturesTupleSet::new(|| futures_bounded::Delay::tokio(DNS_QUERY_TIMEOUT), 1000);
self.nameservers.evaluate();
for (server, _) in std::mem::take(&mut self.doh_clients) {
self.bootstrap_doh_client(server);
}
}
pub fn reset_timeout(&mut self, timeout: Instant, reason: &'static str) {
@@ -470,40 +519,69 @@ impl Io {
pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) {
let meta = DnsQueryMetaData {
query: query.message.clone(),
server: query.server,
server: query.server.clone(),
transport: query.transport,
local: query.local,
remote: query.remote,
};
match query.transport {
dns::Transport::Udp => {
if self
.dns_queries
.try_push(
udp_dns::send(self.udp_socket_factory.clone(), query.server, query.message),
meta,
)
.is_err()
{
tracing::debug!("Failed to queue UDP DNS query")
}
match (query.transport, query.server) {
(dns::Transport::Udp, dns::Upstream::Do53 { server }) => {
self.queue_dns_query(
udp_dns::send(self.udp_socket_factory.clone(), server, query.message),
meta,
);
}
dns::Transport::Tcp => {
if self
.dns_queries
.try_push(
tcp_dns::send(self.tcp_socket_factory.clone(), query.server, query.message),
meta,
)
.is_err()
{
tracing::debug!("Failed to queue TCP DNS query")
}
(dns::Transport::Tcp, dns::Upstream::Do53 { server }) => {
self.queue_dns_query(
tcp_dns::send(self.tcp_socket_factory.clone(), server, query.message),
meta,
);
}
(_, dns::Upstream::DoH { server }) => {
let Some(http_client) = self.doh_clients.get(&server).cloned() else {
self.bootstrap_doh_client(server);
// Queue a dummy "query" that instantly fails to ensure we don't let the application run into a timeout.
// This will trigger a SERVFAIL response.
self.queue_dns_query(async { anyhow::bail!("Bootstrapping DoH client") }, meta);
return;
};
self.queue_dns_query(doh::send(http_client, server, query.message), meta);
}
}
}
pub(crate) fn bootstrap_doh_client(&mut self, server: DoHUrl) {
if self.doh_clients.contains_key(&server) {
return;
}
if self.doh_clients_bootstrap.contains(server.clone()) {
return; // Already bootstrapping.
}
let socket_factory = self.tcp_socket_factory.clone();
let addresses = self.udp_dns_client.resolve(server.host());
let _ = self
.doh_clients_bootstrap
.try_push(server.clone(), async move {
tracing::debug!(%server, "Bootstrapping DoH client");
let addresses = addresses.await?;
let http_client =
HttpClient::new(server.host().to_string(), addresses.clone(), socket_factory)
.await?;
tracing::debug!(%server, "Bootstrapped DoH client");
Ok(http_client)
});
}
pub(crate) fn send_udp_dns_response(
&mut self,
to: SocketAddr,
@@ -531,6 +609,16 @@ impl Io {
pub(crate) fn inc_dropped_packet(&self, attrs: &[opentelemetry::KeyValue]) {
self.dropped_packets.add(1, attrs);
}
fn queue_dns_query(
&mut self,
future: impl Future<Output = Result<dns_types::Response>> + Send + 'static,
meta: DnsQueryMetaData,
) {
if self.dns_queries.try_push(future, meta.clone()).is_err() {
tracing::debug!(?meta, "Failed to queue DNS query")
}
}
}
fn is_max_wg_packet_size(d: &DatagramIn) -> bool {
@@ -545,7 +633,7 @@ fn is_max_wg_packet_size(d: &DatagramIn) -> bool {
#[cfg(test)]
mod tests {
use futures::task::noop_waker_ref;
use std::{future::poll_fn, ptr::addr_of_mut};
use std::{future::poll_fn, net::Ipv4Addr, ptr::addr_of_mut};
use super::*;
@@ -581,14 +669,62 @@ mod tests {
assert!(timeout >= now, "timeout = {timeout:?}, now = {now:?}");
}
#[tokio::test]
async fn bootstrap_doh() {
let _guard = firezone_logging::test("debug");
let mut io = Io::for_test();
io.update_system_resolvers(vec![IpAddr::from([1, 1, 1, 1])]);
{
io.send_dns_query(example_com_recursive_query());
let input = io.next().await;
assert_eq!(
input.dns_response.unwrap().message.unwrap_err().to_string(),
"Bootstrapping DoH client"
);
}
// Hack: Advance for a bit but timeout after 2s. We don't emit an event when the client is bootstrapped so this will always be `Pending`.
let _ = tokio::time::timeout(Duration::from_secs(2), io.next()).await;
{
io.send_dns_query(example_com_recursive_query());
let input = io.next().await;
assert_eq!(
input.dns_response.unwrap().message.unwrap().response_code(),
dns_types::ResponseCode::NOERROR
);
}
}
fn example_com_recursive_query() -> dns::RecursiveQuery {
dns::RecursiveQuery {
server: dns::Upstream::DoH {
server: DoHUrl::cloudflare(),
},
local: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 11111),
remote: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 22222),
message: dns_types::Query::new(
"example.com".parse().unwrap(),
dns_types::RecordType::A,
),
transport: dns::Transport::Udp,
}
}
static mut DUMMY_BUF: Buffers = Buffers { ip: Vec::new() };
/// Helper functions to make the test more concise.
impl Io {
fn for_test() -> Io {
let mut io = Io::new(
Arc::new(|_| Err(io::Error::other("not implemented"))),
Arc::new(|_| Err(io::Error::other("not implemented"))),
Arc::new(socket_factory::tcp),
Arc::new(socket_factory::udp),
BTreeSet::new(),
);
io.set_tun(Box::new(DummyTun));

View File

@@ -0,0 +1,17 @@
use anyhow::Result;
use dns_types::DoHUrl;
use http_client::HttpClient;
pub async fn send(
client: HttpClient,
server: DoHUrl,
query: dns_types::Query,
) -> Result<dns_types::Response> {
tracing::trace!(target: "wire::dns::recursive::https", %server, domain = %query.domain());
let request = query.try_into_http_request(&server)?;
let response = client.send_request(request)?.await?;
let response = dns_types::Response::try_from_http_response(response)?;
Ok(response)
}

View File

@@ -145,6 +145,13 @@ impl ClientTunnel {
self.io.reset();
}
pub fn update_system_resolvers(&mut self, resolvers: Vec<IpAddr>) -> Vec<IpAddr> {
let resolvers = self.role_state.update_system_resolvers(resolvers);
self.io.update_system_resolvers(resolvers.clone()); // IO needs the system resolvers to bootstrap DoH upstream.
resolvers
}
/// Shut down the Client tunnel.
pub fn shut_down(mut self) -> BoxFuture<'static, Result<()>> {
// Initiate shutdown.
@@ -178,6 +185,16 @@ impl ClientTunnel {
// Pass up existing events.
if let Some(event) = self.role_state.poll_event() {
if let ClientEvent::TunInterfaceUpdated(config) = &event {
for url in &config.dns_by_sentinel.upstream_servers() {
let dns::Upstream::DoH { server } = url else {
continue;
};
self.io.bootstrap_doh_client(server.clone());
}
}
return Poll::Ready(event);
}
@@ -480,7 +497,9 @@ impl GatewayTunnel {
for query in udp_dns_queries {
if let Some(nameserver) = self.io.fastest_nameserver() {
self.io.send_dns_query(dns::RecursiveQuery {
server: SocketAddr::new(nameserver, dns::DNS_PORT),
server: dns::Upstream::Do53 {
server: SocketAddr::new(nameserver, dns::DNS_PORT),
},
local: query.local,
remote: query.remote,
message: query.message,
@@ -504,7 +523,9 @@ impl GatewayTunnel {
for query in tcp_dns_queries {
if let Some(nameserver) = self.io.fastest_nameserver() {
self.io.send_dns_query(dns::RecursiveQuery {
server: SocketAddr::new(nameserver, dns::DNS_PORT),
server: dns::Upstream::Do53 {
server: SocketAddr::new(nameserver, dns::DNS_PORT),
},
local: query.local,
remote: query.remote,
message: query.message,

View File

@@ -351,7 +351,7 @@ pub(crate) fn assert_udp_dns_packets_properties(ref_client: &RefClient, sim_clie
for (dns_server, query_id) in ref_client.expected_udp_dns_handshakes.iter() {
let _guard =
tracing::info_span!(target: "assertions", "udp_dns", %query_id, %dns_server).entered();
let key = &(*dns_server, *query_id);
let key = &(dns_server.clone(), *query_id);
let queries = &sim_client.sent_udp_dns_queries;
let responses = &sim_client.received_udp_dns_responses;
@@ -374,7 +374,7 @@ pub(crate) fn assert_tcp_dns(ref_client: &RefClient, sim_client: &SimClient) {
for (dns_server, query_id) in ref_client.expected_tcp_dns_handshakes.iter() {
let _guard =
tracing::info_span!(target: "assertions", "tcp_dns", %query_id, %dns_server).entered();
let key = &(*dns_server, *query_id);
let key = &(dns_server.clone(), *query_id);
let queries = &sim_client.sent_tcp_dns_queries;
let responses = &sim_client.received_tcp_dns_responses;

View File

@@ -4,8 +4,8 @@ use super::{
composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*,
strategies::*, stub_portal::StubPortal, transition::*,
};
use crate::client;
use crate::proptest::domain_label;
use crate::{client, dns};
use crate::{dns::is_subdomain, proptest::relay_id};
use connlib_model::{GatewayId, RelayId, Site, StaticSecret};
use dns_types::{DomainName, RecordType};
@@ -756,10 +756,12 @@ impl ReferenceState {
Transition::UpdateUpstreamDoHServers(_) => true,
Transition::UpdateUpstreamSearchDomain(_) => true,
Transition::SendDnsQueries(queries) => queries.iter().all(|query| {
let has_socket_for_server = state
.client
.sending_socket_for(query.dns_server.ip())
.is_some();
let has_socket_for_server = match query.dns_server {
crate::dns::Upstream::Do53 { server } => {
state.client.sending_socket_for(server.ip()).is_some()
}
crate::dns::Upstream::DoH { .. } => true,
};
let has_dns_server = state
.client
@@ -919,14 +921,19 @@ impl ReferenceState {
Vec::from_iter(unique_domains)
}
fn reachable_dns_servers(&self) -> Vec<SocketAddr> {
fn reachable_dns_servers(&self) -> Vec<dns::Upstream> {
self.client
.inner()
.expected_dns_servers()
.into_iter()
.filter(|s| match s {
SocketAddr::V4(_) => self.client.ip4.is_some(),
SocketAddr::V6(_) => self.client.ip6.is_some(),
crate::dns::Upstream::Do53 {
server: SocketAddr::V4(_),
} => self.client.ip4.is_some(),
crate::dns::Upstream::Do53 {
server: SocketAddr::V6(_),
} => self.client.ip6.is_some(),
crate::dns::Upstream::DoH { .. } => true,
})
.collect()
}

View File

@@ -8,7 +8,7 @@ use super::{
transition::{DPort, Destination, DnsQuery, DnsTransport, Identifier, SPort, Seq},
};
use crate::{
ClientState, DnsMapping, DnsResourceRecord,
ClientState, DnsMapping, DnsResourceRecord, dns,
messages::{UpstreamDo53, UpstreamDoH},
proptest::*,
};
@@ -61,11 +61,11 @@ pub(crate) struct SimClient {
pub(crate) resource_status: BTreeMap<ResourceId, ResourceStatus>,
pub(crate) sent_udp_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>,
pub(crate) received_udp_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>,
pub(crate) sent_udp_dns_queries: HashMap<(dns::Upstream, QueryId), IpPacket>,
pub(crate) received_udp_dns_responses: BTreeMap<(dns::Upstream, QueryId), IpPacket>,
pub(crate) sent_tcp_dns_queries: HashSet<(SocketAddr, QueryId)>,
pub(crate) received_tcp_dns_responses: BTreeSet<(SocketAddr, QueryId)>,
pub(crate) sent_tcp_dns_queries: HashSet<(dns::Upstream, QueryId)>,
pub(crate) received_tcp_dns_responses: BTreeSet<(dns::Upstream, QueryId)>,
pub(crate) sent_icmp_requests: HashMap<(Seq, Identifier), IpPacket>,
pub(crate) received_icmp_replies: BTreeMap<(Seq, Identifier), IpPacket>,
@@ -138,8 +138,8 @@ impl SimClient {
}
/// Returns the _effective_ DNS servers that connlib is using.
pub(crate) fn effective_dns_servers(&self) -> Vec<SocketAddr> {
self.dns_by_sentinel.upstream_sockets()
pub(crate) fn effective_dns_servers(&self) -> Vec<dns::Upstream> {
self.dns_by_sentinel.upstream_servers()
}
pub(crate) fn effective_search_domain(&self) -> Option<DomainName> {
@@ -160,11 +160,11 @@ impl SimClient {
domain: DomainName,
r_type: RecordType,
query_id: u16,
upstream: SocketAddr,
upstream: dns::Upstream,
dns_transport: DnsTransport,
now: Instant,
) -> Option<Transmit> {
let Some(sentinel) = self.dns_by_sentinel.sentinel_by_upstream(upstream) else {
let Some(sentinel) = self.dns_by_sentinel.sentinel_by_upstream(&upstream) else {
tracing::error!(%upstream, "Unknown DNS server");
return None;
};
@@ -493,10 +493,10 @@ pub struct RefClient {
/// The expected UDP DNS handshakes.
#[debug(skip)]
pub(crate) expected_udp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>,
pub(crate) expected_udp_dns_handshakes: VecDeque<(dns::Upstream, QueryId)>,
/// The expected TCP DNS handshakes.
#[debug(skip)]
pub(crate) expected_tcp_dns_handshakes: VecDeque<(SocketAddr, QueryId)>,
pub(crate) expected_tcp_dns_handshakes: VecDeque<(dns::Upstream, QueryId)>,
}
impl RefClient {
@@ -926,11 +926,11 @@ impl RefClient {
match query.transport {
DnsTransport::Udp => {
self.expected_udp_dns_handshakes
.push_back((query.dns_server, query.query_id));
.push_back((query.dns_server.clone(), query.query_id));
}
DnsTransport::Tcp => {
self.expected_tcp_dns_handshakes
.push_back((query.dns_server, query.query_id));
.push_back((query.dns_server.clone(), query.query_id));
}
}
@@ -1076,22 +1076,37 @@ impl RefClient {
/// Returns the DNS servers that we expect connlib to use.
///
/// If there are upstream DNS servers configured in the portal, it should use those.
/// If there are upstream Do53 servers configured in the portal, it should use those.
/// If there are no custom servers defined, it should use the DoH servers specified in the portal.
/// Otherwise it should use whatever was configured on the system prior to connlib starting.
///
/// This purposely returns a `Vec` so we also assert the order!
pub(crate) fn expected_dns_servers(&self) -> Vec<SocketAddr> {
pub(crate) fn expected_dns_servers(&self) -> Vec<dns::Upstream> {
if !self.upstream_do53_resolvers.is_empty() {
return self
.upstream_do53_resolvers
.iter()
.map(|u| SocketAddr::new(u.ip, 53))
.map(|u| dns::Upstream::Do53 {
server: SocketAddr::new(u.ip, 53),
})
.collect();
}
if !self.upstream_doh_resolvers.is_empty() {
return self
.upstream_doh_resolvers
.iter()
.map(|u| dns::Upstream::DoH {
server: u.url.clone(),
})
.collect();
}
self.system_dns_resolvers
.iter()
.map(|ip| SocketAddr::new(*ip, 53))
.map(|ip| dns::Upstream::Do53 {
server: SocketAddr::new(*ip, 53),
})
.collect()
}
@@ -1185,7 +1200,12 @@ impl RefClient {
return None;
}
let maybe_active_cidr_resource = self.cidr_resource_by_ip(query.dns_server.ip());
let server = match query.dns_server {
dns::Upstream::Do53 { server } => server,
dns::Upstream::DoH { .. } => return None,
};
let maybe_active_cidr_resource = self.cidr_resource_by_ip(server.ip());
let maybe_active_internet_resource = self.active_internet_resource();
maybe_active_cidr_resource.or(maybe_active_internet_resource)

View File

@@ -279,7 +279,7 @@ impl TunnelTest {
upstream_dns: vec![],
upstream_do53,
search_domain: ref_state.client.inner().search_domain.clone(),
upstream_doh: vec![],
upstream_doh: ref_state.client.inner().upstream_doh_resolvers(),
})
});
}
@@ -424,6 +424,7 @@ impl TunnelTest {
let ipv6 = state.client.inner().sut.tunnel_ip_config().unwrap().v6;
let system_dns = ref_state.client.inner().system_dns_resolvers();
let upstream_do53 = ref_state.client.inner().upstream_do53_resolvers();
let upstream_doh = ref_state.client.inner().upstream_doh_resolvers();
let all_resources = ref_state.client.inner().all_resources();
let internet_resource_state = ref_state.client.inner().internet_resource_active;
@@ -436,8 +437,8 @@ impl TunnelTest {
ipv6,
upstream_dns: Vec::new(),
upstream_do53,
upstream_doh,
search_domain: ref_state.client.inner().search_domain.clone(),
upstream_doh: Vec::new(),
});
c.sut.update_system_resolvers(system_dns);
c.sut.set_resources(all_resources, now);
@@ -927,7 +928,18 @@ impl TunnelTest {
for gateway in self.gateways.values_mut() {
gateway.exec_mut(|g| {
g.deploy_new_dns_servers(config.dns_by_sentinel.upstream_sockets(), now)
// If DoH servers are configured, we never route them through the tunnel.
// Therefore, we also don't need to "deploy" any DNS servers here.
let upstream_do53_servers = config
.dns_by_sentinel
.upstream_servers()
.into_iter()
.filter_map(|u| match u {
dns::Upstream::Do53 { server } => Some(server),
dns::Upstream::DoH { .. } => None,
});
g.deploy_new_dns_servers(upstream_do53_servers, now)
})
}

View File

@@ -1,5 +1,6 @@
use crate::{
client::{CidrResource, IPV4_RESOURCES, IPV6_RESOURCES, Resource},
dns,
messages::{UpstreamDo53, UpstreamDoH},
proptest::{host_v4, host_v6},
};
@@ -15,7 +16,7 @@ use prop::collection;
use proptest::{prelude::*, sample};
use std::{
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
num::NonZeroU16,
};
@@ -119,7 +120,7 @@ pub(crate) struct DnsQuery {
pub(crate) r_type: RecordType,
/// The DNS query ID.
pub(crate) query_id: u16,
pub(crate) dns_server: SocketAddr,
pub(crate) dns_server: dns::Upstream,
pub(crate) transport: DnsTransport,
}
@@ -352,7 +353,7 @@ fn non_dns_ports() -> impl Strategy<Value = u16> {
/// Samples up to 5 DNS queries that will be sent concurrently into connlib.
pub(crate) fn dns_queries(
domain: impl Strategy<Value = (DomainName, Vec<RecordType>)>,
dns_server: impl Strategy<Value = SocketAddr>,
dns_server: impl Strategy<Value = dns::Upstream>,
) -> impl Strategy<Value = Vec<DnsQuery>> {
// Queries can be uniquely identified by the tuple of DNS server and query ID.
let unique_queries = collection::btree_set((dns_server, any::<u16>()), 1..5);