feat(gateway): forward queries to local nameserver (#8350)

The DNS server added in #8285 was only a placeholder: it added the
infrastructure to receive DNS queries on the IP of the TUN device at
port 53535 but answered every query with SERVFAIL. For this DNS server
to be useful, we need to take those queries and replay them towards a
DNS server that is configured locally on the Gateway.

To achieve this, we parse `/etc/resolv.conf` during startup of the
Gateway and pass the contained nameservers into the tunnel. From there,
the Gateway's event-loop can receive the queries, feed them into the
already existing machinery for performing recursive DNS queries that we
use on the Client and resolve the records.

In its current implementation, we probe the nameservers defined in
`/etc/resolv.conf` and only forward queries to the one that answers
fastest. If the lookup fails, we send back a SERVFAIL error and log a
message.

Resolves: #8221
This commit is contained in:
Thomas Eizinger
2025-03-06 07:23:01 +11:00
committed by GitHub
parent e4ab0f1cb4
commit eacf67f2bc
12 changed files with 370 additions and 82 deletions

1
rust/Cargo.lock generated
View File

@@ -2027,6 +2027,7 @@ dependencies = [
"moka",
"nix 0.29.0",
"phoenix-channel",
"resolv-conf",
"rustls",
"secrecy",
"serde",

View File

@@ -93,6 +93,7 @@ rand_core = "0.6.4"
rangemap = "1.5.1"
rayon = "1.10.0"
reqwest = { version = "0.12.9", default-features = false }
resolv-conf = "0.7.0"
rtnetlink = { version = "0.14.1", default-features = false, features = ["tokio_socket"] }
rustls = { version = "0.23.21", default-features = false, features = ["ring"] }
sadness-generator = "0.6.0"

View File

@@ -118,9 +118,15 @@ impl Server {
continue;
};
self.tcp_streams_by_remote.insert(from, stream); // Store the stream so we can send a response back later.
let local = stream.local_addr()?;
// Store the stream so we can send a response back later.
// We don't need to index by the local address because we only ever listen on a single socket.
self.tcp_streams_by_remote.insert(from, stream);
return Poll::Ready(Ok(Query {
source: from,
local,
remote: from,
message,
}));
}
@@ -179,7 +185,8 @@ async fn read_tcp_query(
}
pub struct Query {
pub source: SocketAddr,
pub local: SocketAddr,
pub remote: SocketAddr,
pub message: Message<Vec<u8>>,
}
@@ -220,7 +227,7 @@ mod tests {
let query = poll_fn(|cx| server.poll(cx)).await.unwrap();
server
.send_response(query.source, empty_dns_response(query.message))
.send_response(query.remote, empty_dns_response(query.message))
.unwrap();
}
});

View File

@@ -86,7 +86,7 @@ impl RecursiveQuery {
}
}
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub(crate) enum Transport {
Udp {
/// The original source we received the DNS query on.

View File

@@ -1,4 +1,7 @@
mod gso_queue;
mod nameserver_set;
mod tcp_dns;
mod udp_dns;
use crate::{device_channel::Device, dns, sockets::Sockets};
use anyhow::Result;
@@ -8,17 +11,17 @@ use futures::FutureExt as _;
use futures_bounded::FuturesTupleSet;
use gso_queue::GsoQueue;
use ip_packet::{IpPacket, MAX_FZ_PAYLOAD};
use nameserver_set::NameserverSet;
use socket_factory::{DatagramIn, SocketFactory, TcpSocket, UdpSocket};
use std::{
collections::VecDeque,
collections::{BTreeSet, VecDeque},
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::Instrument;
use tun::Tun;
@@ -45,6 +48,8 @@ pub struct Io {
sockets: Sockets,
gso_queue: GsoQueue,
nameservers: NameserverSet,
udp_dns_server: l4_udp_dns_server::Server,
tcp_dns_server: l4_tcp_dns_server::Server,
@@ -100,14 +105,19 @@ impl Io {
pub fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
nameservers: BTreeSet<IpAddr>,
) -> Self {
let mut sockets = Sockets::default();
sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context.
let mut nameservers = NameserverSet::new(nameservers, udp_socket_factory.clone());
nameservers.evaluate();
Self {
outbound_packet_buffer: VecDeque::with_capacity(10), // It is unlikely that we process more than 10 packets after 1 GRO call.
timeout: None,
sockets,
nameservers,
tcp_socket_factory,
udp_socket_factory,
dns_queries: FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000),
@@ -136,6 +146,15 @@ impl Io {
self.sockets.poll_has_sockets(cx)
}
pub fn fastest_nameserver(&self) -> io::Result<IpAddr> {
let ns = self
.nameservers
.fastest()
.ok_or(io::Error::other(NoNameserverAvailable))?;
Ok(ns)
}
pub fn poll<'b>(
&mut self,
cx: &mut Context<'_>,
@@ -146,6 +165,7 @@ impl Io {
>,
> {
ready!(self.flush_send_queue(cx)?);
ready!(self.nameservers.poll(cx));
if let Poll::Ready(network) =
self.sockets
@@ -255,6 +275,7 @@ impl Io {
self.sockets.rebind(self.udp_socket_factory.as_ref());
self.gso_queue.clear();
self.dns_queries = FuturesTupleSet::new(DNS_QUERY_TIMEOUT, 1000);
self.nameservers.evaluate();
}
pub fn reset_timeout(&mut self, timeout: Instant) {
@@ -274,39 +295,19 @@ impl Io {
}
pub fn send_dns_query(&mut self, query: dns::RecursiveQuery) {
let meta = DnsQueryMetaData {
query: query.message.clone(),
server: query.server,
transport: query.transport,
};
match query.transport {
dns::Transport::Udp { .. } => {
let factory = self.udp_socket_factory.clone();
let server = query.server;
let bind_addr = match query.server {
SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
};
let meta = DnsQueryMetaData {
query: query.message.clone(),
server,
transport: query.transport,
};
if self
.dns_queries
.try_push(
async move {
// To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet.
const BUF_SIZE: usize = 1500;
let udp_socket = factory(&bind_addr)?;
let response = udp_socket
.handshake::<BUF_SIZE>(server, query.message.as_slice())
.await?;
let message = Message::from_octets(response)
.map_err(|_| io::Error::other("Failed to parse DNS message"))?;
Ok(message)
}
.instrument(telemetry_span!("recursive_udp_dns_query")),
udp_dns::send(self.udp_socket_factory.clone(), query.server, query.message)
.instrument(telemetry_span!("recursive_udp_dns_query")),
meta,
)
.is_err()
@@ -315,41 +316,11 @@ impl Io {
}
}
dns::Transport::Tcp { .. } => {
let factory = self.tcp_socket_factory.clone();
let server = query.server;
let meta = DnsQueryMetaData {
query: query.message.clone(),
server,
transport: query.transport,
};
if self
.dns_queries
.try_push(
async move {
let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
let mut tcp_stream = tcp_socket.connect(server).await?;
let query = query.message.into_octets();
let dns_message_length = (query.len() as u16).to_be_bytes();
tcp_stream.write_all(&dns_message_length).await?;
tcp_stream.write_all(&query).await?;
let mut response_length = [0u8; 2];
tcp_stream.read_exact(&mut response_length).await?;
let response_length = u16::from_be_bytes(response_length) as usize;
// A u16 is at most 65k, meaning we are okay to allocate here based on what the remote is sending.
let mut response = vec![0u8; response_length];
tcp_stream.read_exact(&mut response).await?;
let message = Message::from_octets(response)
.map_err(|_| io::Error::other("Failed to parse DNS message"))?;
Ok(message)
}
.instrument(telemetry_span!("recursive_tcp_dns_query")),
tcp_dns::send(self.tcp_socket_factory.clone(), query.server, query.message)
.instrument(telemetry_span!("recursive_tcp_dns_query")),
meta,
)
.is_err()
@@ -377,6 +348,10 @@ impl Io {
}
}
#[derive(Debug, thiserror::Error)]
#[error("No nameserver available to handle DNS query")]
pub struct NoNameserverAvailable;
fn is_max_wg_packet_size(d: &DatagramIn) -> bool {
let len = d.packet.len();
if len > MAX_FZ_PAYLOAD {
@@ -446,6 +421,7 @@ mod tests {
let mut io = Io::new(
Arc::new(|_| Err(io::Error::other("not implemented"))),
Arc::new(|_| Err(io::Error::other("not implemented"))),
BTreeSet::new(),
);
io.set_tun(Box::new(DummyTun));

View File

@@ -0,0 +1,167 @@
use std::{
collections::{BTreeMap, BTreeSet},
io,
net::{IpAddr, SocketAddr},
sync::{Arc, LazyLock},
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use connlib_model::DomainName;
use domain::base::{iana::Rcode, Message, MessageBuilder, Question, Rtype};
use futures_bounded::FuturesTupleSet;
use socket_factory::{SocketFactory, UdpSocket};
use crate::io::udp_dns;
const MAX_DNS_SERVERS: usize = 10; // We don't bother selecting from more than 10 servers.
const DNS_TIMEOUT: Duration = Duration::from_secs(2); // Every sensible DNS server should respond within 2s.
/// The domain name used in the probe query that is sent to every nameserver to measure its round-trip time.
static FIREZONE_DEV: LazyLock<DomainName> = LazyLock::new(|| {
DomainName::vec_from_str("firezone.dev").expect("static domain should always parse")
});
/// A set of candidate nameservers that keeps track of which one answers fastest.
///
/// `evaluate` sends a probe query to every nameserver and `poll` records the
/// round-trip time of each successful answer; `fastest` returns the best candidate.
pub struct NameserverSet {
/// All nameservers this set was created with.
inner: BTreeSet<IpAddr>,
/// Nameservers that answered the probe with `NOERROR`, keyed (and thus ordered) by their round-trip time.
///
/// NOTE(review): two nameservers with the exact same RTT share a key, so the later insert replaces the earlier one.
nameserver_by_rtt: BTreeMap<Duration, IpAddr>,
/// Factory for the UDP sockets that the probe queries are sent from.
socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
/// In-flight probe queries; bounded to `MAX_DNS_SERVERS` entries, each timing out after `DNS_TIMEOUT`.
queries: FuturesTupleSet<io::Result<Message<Vec<u8>>>, QueryMetaData>,
}
/// Bookkeeping attached to every in-flight probe query.
struct QueryMetaData {
/// The nameserver the probe was sent to.
nameserver: IpAddr,
/// When the evaluation round started; the RTT is measured as the time elapsed since this instant.
start: Instant,
}
impl NameserverSet {
    /// Creates a new set over the given nameservers.
    ///
    /// No probe is sent until [`NameserverSet::evaluate`] is called, so
    /// [`NameserverSet::fastest`] returns `None` until then.
    pub fn new(
        inner: BTreeSet<IpAddr>,
        udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
    ) -> Self {
        Self {
            queries: FuturesTupleSet::new(DNS_TIMEOUT, MAX_DNS_SERVERS),
            inner,
            socket_factory: udp_socket_factory,
            nameserver_by_rtt: Default::default(),
        }
    }

    /// Starts a new evaluation round by sending a probe query to every nameserver.
    ///
    /// Results of any previous round are discarded. The probes only make progress
    /// when [`NameserverSet::poll`] is called.
    pub fn evaluate(&mut self) {
        self.nameserver_by_rtt.clear();

        // Drop probes that are still in flight from a previous round: they would
        // otherwise report RTTs measured against an outdated start time and take
        // up slots of the bounded set, causing `try_push` below to fail.
        self.queries = FuturesTupleSet::new(DNS_TIMEOUT, MAX_DNS_SERVERS);

        let start = Instant::now();

        for nameserver in self.inner.iter().copied() {
            if self
                .queries
                .try_push(
                    udp_dns::send(
                        self.socket_factory.clone(),
                        SocketAddr::new(nameserver, crate::dns::DNS_PORT),
                        query_firezone_dev(),
                    ),
                    QueryMetaData { nameserver, start },
                )
                .is_err()
            {
                tracing::debug!(%nameserver, "Failed to queue another DNS query");
            }
        }
    }

    /// Returns the nameserver with the lowest round-trip time measured so far.
    ///
    /// Returns `None` if no nameserver has successfully answered a probe (yet).
    pub fn fastest(&self) -> Option<IpAddr> {
        let (_, ns) = self.nameserver_by_rtt.first_key_value()?;

        Some(*ns)
    }

    /// Drives the in-flight probe queries.
    ///
    /// Returns `Poll::Ready` once no probe is in flight anymore, regardless of
    /// whether any nameserver answered, and `Poll::Pending` while at least one
    /// probe is still outstanding.
    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if self.queries.is_empty() {
            return Poll::Ready(());
        }

        loop {
            match ready!(self.queries.poll_unpin(cx)) {
                (Ok(Ok(response)), meta) if response.header().rcode() == Rcode::NOERROR => {
                    let rtt = meta.start.elapsed();

                    tracing::debug!(nameserver = %meta.nameserver, ?rtt, ?response, "DNS query completed");

                    self.nameserver_by_rtt.insert(rtt, meta.nameserver);
                }
                (Ok(Ok(response)), meta) => {
                    tracing::debug!(nameserver = %meta.nameserver, ?response, "DNS query failed");
                }
                (Ok(Err(e)), meta) => {
                    tracing::debug!(nameserver = %meta.nameserver, "DNS query failed: {e}");
                }
                (Err(_), meta) => {
                    tracing::debug!(nameserver = %meta.nameserver, "DNS query timed out after {DNS_TIMEOUT:?}");
                }
            }

            if !self.queries.is_empty() {
                continue;
            }

            // The round is over. We must return here even if no nameserver answered:
            // polling the now empty set can never yield another result and would
            // leave the caller pending without anything to wake it up.
            match self.fastest() {
                Some(fastest) => tracing::info!(%fastest, "Evaluated fastest nameserver"),
                None => tracing::warn!("None of the nameservers answered the probe query"),
            }

            return Poll::Ready(());
        }
    }
}
/// Builds the probe query: a recursive `A` lookup for `firezone.dev` with a random message ID.
fn query_firezone_dev() -> Message<Vec<u8>> {
    let mut question_section = MessageBuilder::new_vec().question();

    let header = question_section.header_mut();
    header.set_random_id();
    header.set_rd(true); // Ask the nameserver to resolve recursively.
    header.set_qr(false); // This is a query, not a response.

    let question = Question::new_in(FIREZONE_DEV.clone(), Rtype::A);
    question_section
        .push(question)
        .expect("static question should always be valid");

    question_section.into_message()
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Probes a handful of public resolvers and expects at least one of them to answer.
    #[tokio::test]
    #[ignore = "Needs Internet"]
    async fn can_evaluate_fastest_nameserver() {
        let _guard = firezone_logging::test("debug");

        let candidates: BTreeSet<IpAddr> = [
            [1, 1, 1, 1],
            [8, 8, 8, 8],
            [8, 8, 4, 4],
            [9, 9, 9, 9],
            [100, 100, 100, 100], // Also include an unreachable server.
        ]
        .into_iter()
        .map(|octets| IpAddr::from(Ipv4Addr::from(octets)))
        .collect();
        let mut nameservers = NameserverSet::new(candidates, Arc::new(socket_factory::udp));

        nameservers.evaluate();
        std::future::poll_fn(|cx| nameservers.poll(cx)).await;

        assert!(nameservers.fastest().is_some());
    }

    /// Polling a set without any nameservers must resolve immediately and report no fastest server.
    #[tokio::test]
    async fn can_handle_no_servers() {
        let _guard = firezone_logging::test("debug");

        let no_servers = BTreeSet::default();
        let mut nameservers = NameserverSet::new(no_servers, Arc::new(socket_factory::udp));

        std::future::poll_fn(|cx| nameservers.poll(cx)).await;

        assert!(nameservers.fastest().is_none());
    }
}

View File

@@ -0,0 +1,41 @@
use std::{io, net::SocketAddr, sync::Arc};
use domain::base::{Message, ToName as _};
use socket_factory::{SocketFactory, TcpSocket};
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
/// Sends `query` to `server` over a fresh TCP connection and returns the parsed response.
///
/// The message is framed with the two-byte, big-endian length prefix that DNS uses on TCP.
///
/// # Errors
///
/// Returns an [`io::Error`] if:
/// - the query does not contain exactly one question,
/// - the query is too large for the 16-bit length prefix,
/// - creating the socket, connecting, writing or reading fails,
/// - the response cannot be parsed as a DNS message.
pub async fn send(
    factory: Arc<dyn SocketFactory<TcpSocket>>,
    server: SocketAddr,
    query: Message<Vec<u8>>,
) -> io::Result<Message<Vec<u8>>> {
    // Queries may originate from remote peers, so a malformed one must fail
    // this lookup rather than panic the whole process.
    let domain = query
        .sole_question()
        .map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "DNS query must contain exactly one question",
            )
        })?
        .qname()
        .to_vec();

    tracing::trace!(target: "wire::dns::recursive::tcp", %server, %domain);

    let query = query.into_octets();

    // A plain `as u16` cast would silently truncate and corrupt the framing.
    let dns_message_length = u16::try_from(query.len())
        .map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "DNS query is too large to be sent over TCP",
            )
        })?
        .to_be_bytes();

    let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
    let mut tcp_stream = tcp_socket.connect(server).await?;

    // Send length prefix and message in a single write so that they are not
    // needlessly split across two TCP segments.
    let mut framed_query = Vec::with_capacity(dns_message_length.len() + query.len());
    framed_query.extend_from_slice(&dns_message_length);
    framed_query.extend_from_slice(&query);

    tcp_stream.write_all(&framed_query).await?;

    let mut response_length = [0u8; 2];
    tcp_stream.read_exact(&mut response_length).await?;
    let response_length = u16::from_be_bytes(response_length) as usize;

    // A u16 is at most 65k, meaning we are okay to allocate here based on what the remote is sending.
    let mut response = vec![0u8; response_length];
    tcp_stream.read_exact(&mut response).await?;

    let message = Message::from_octets(response)
        .map_err(|_| io::Error::other("Failed to parse DNS message"))?;

    Ok(message)
}

View File

@@ -0,0 +1,40 @@
use std::{
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
};
use domain::base::{Message, ToName as _};
use socket_factory::{SocketFactory, UdpSocket};
/// Sends `query` to `server` from a freshly bound UDP socket and returns the parsed response.
///
/// The socket is bound to the unspecified address of the same IP version as `server`,
/// on a port chosen by the operating system.
///
/// # Errors
///
/// Returns an [`io::Error`] if:
/// - the query does not contain exactly one question,
/// - creating the socket or exchanging the datagrams fails,
/// - the response cannot be parsed as a DNS message.
pub async fn send(
    factory: Arc<dyn SocketFactory<UdpSocket>>,
    server: SocketAddr,
    query: Message<Vec<u8>>,
) -> io::Result<Message<Vec<u8>>> {
    // Queries may originate from remote peers, so a malformed one must fail
    // this lookup rather than panic the whole process.
    let domain = query
        .sole_question()
        .map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "DNS query must contain exactly one question",
            )
        })?
        .qname()
        .to_vec();

    let bind_addr = match server {
        SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
        SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
    };

    tracing::trace!(target: "wire::dns::recursive::udp", %server, %domain);

    // To avoid fragmentation, IP and thus also UDP packets can only be reliably sent with an MTU of <= 1500 on the public Internet.
    const BUF_SIZE: usize = 1500;

    let udp_socket = factory(&bind_addr)?;

    let response = udp_socket
        .handshake::<BUF_SIZE>(server, query.as_slice())
        .await?;

    let message = Message::from_octets(response)
        .map_err(|_| io::Error::other("Failed to parse DNS message"))?;

    Ok(message)
}

View File

@@ -66,6 +66,7 @@ pub type ClientTunnel = Tunnel<ClientState>;
pub use client::ClientState;
pub use gateway::{DnsResourceNatEntry, GatewayState, ResolveDnsRequest};
pub use io::NoNameserverAvailable;
pub use utils::turn;
/// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway.
@@ -107,7 +108,11 @@ impl ClientTunnel {
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
) -> Self {
Self {
io: Io::new(tcp_socket_factory, udp_socket_factory),
io: Io::new(
tcp_socket_factory,
udp_socket_factory.clone(),
BTreeSet::default(),
),
role_state: ClientState::new(rand::random(), Instant::now()),
buffers: Buffers::default(),
}
@@ -216,9 +221,10 @@ impl GatewayTunnel {
pub fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
nameservers: BTreeSet<IpAddr>,
) -> Self {
Self {
io: Io::new(tcp_socket_factory, udp_socket_factory),
io: Io::new(tcp_socket_factory, udp_socket_factory.clone(), nameservers),
role_state: GatewayState::new(rand::random(), Instant::now()),
buffers: Buffers::default(),
}
@@ -246,8 +252,23 @@ impl GatewayTunnel {
}
match self.io.poll(cx, &mut self.buffers)? {
Poll::Ready(io::Input::DnsResponse(_)) => {
unreachable!("Gateway doesn't use user-space DNS resolution")
Poll::Ready(io::Input::DnsResponse(response)) => {
let message = response.message.unwrap_or_else(|e| {
tracing::debug!("DNS query failed: {e}");
dns::servfail(response.query.for_slice_ref())
});
match response.transport {
dns::Transport::Udp { source } => {
self.io.send_udp_dns_response(source, message)?;
}
dns::Transport::Tcp { remote, .. } => {
self.io.send_tcp_dns_response(remote, message)?;
}
}
continue;
}
Poll::Ready(io::Input::Timeout(timeout)) => {
self.role_state.handle_timeout(timeout, Utc::now());
@@ -296,14 +317,25 @@ impl GatewayTunnel {
continue;
}
Poll::Ready(io::Input::UdpDnsQuery(query)) => self.io.send_udp_dns_response(
query.source,
dns::servfail(query.message.for_slice_ref()),
)?,
Poll::Ready(io::Input::TcpDnsQuery(query)) => self.io.send_tcp_dns_response(
query.source,
dns::servfail(query.message.for_slice_ref()),
)?,
Poll::Ready(io::Input::UdpDnsQuery(query)) => {
let nameserver = self.io.fastest_nameserver()?;
self.io.send_dns_query(dns::RecursiveQuery::via_udp(
query.source,
SocketAddr::new(nameserver, dns::DNS_PORT),
query.message.for_slice_ref(),
));
}
Poll::Ready(io::Input::TcpDnsQuery(query)) => {
let nameserver = self.io.fastest_nameserver()?;
self.io.send_dns_query(dns::RecursiveQuery::via_tcp(
query.local,
query.remote,
SocketAddr::new(nameserver, dns::DNS_PORT),
query.message,
));
}
Poll::Pending => {}
}

View File

@@ -29,6 +29,7 @@ libc = { workspace = true, features = ["std", "const-extern-fn", "extra_traits"]
moka = { workspace = true, features = ["future"] }
nix = { workspace = true }
phoenix-channel = { workspace = true }
resolv-conf = { workspace = true }
rustls = { workspace = true }
secrecy = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }

View File

@@ -123,6 +123,12 @@ impl Eventloop {
continue;
}
if e.root_cause()
.is::<firezone_tunnel::NoNameserverAvailable>()
{
return Poll::Ready(Err(Error::NoNameserversAvailable(e)));
}
telemetry_event!("Tunnel error: {e:#}");
continue;
}
@@ -544,6 +550,8 @@ pub enum Error {
UpdateTun(#[source] anyhow::Error),
#[error("{0:#}")]
BindDnsSockets(#[source] anyhow::Error),
#[error("{0:#}")]
NoNameserversAvailable(#[source] anyhow::Error),
}
async fn resolve(domain: DomainName) -> Result<Vec<IpAddr>> {

View File

@@ -16,10 +16,10 @@ use phoenix_channel::LoginUrl;
use futures::{future, TryFutureExt};
use phoenix_channel::PhoenixChannel;
use secrecy::{Secret, SecretString};
use std::path::Path;
use std::pin::pin;
use std::process::ExitCode;
use std::sync::Arc;
use std::{collections::BTreeSet, path::Path};
use tokio::io::AsyncWriteExt;
use tokio::signal::ctrl_c;
use tracing_subscriber::layer;
@@ -113,7 +113,21 @@ async fn try_main(cli: Cli) -> Result<ExitCode> {
)
.context("Failed to construct URL for logging into portal")?;
let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory));
let resolv_conf = resolv_conf::Config::parse(
std::fs::read_to_string("/etc/resolv.conf").context("Failed to read /etc/resolv.conf")?,
)
.context("Failed to parse /etc/resolv.conf")?;
let nameservers = resolv_conf
.nameservers
.into_iter()
.map(|ip| ip.into())
.collect::<BTreeSet<_>>();
let mut tunnel = GatewayTunnel::new(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
nameservers,
);
let portal = PhoenixChannel::disconnected(
Secret::new(login),
get_user_agent(None, env!("CARGO_PKG_VERSION")),