mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(rust): introduce dns-types crate (#8380)
A sizeable chunk of Firezone's Rust components deal with parsing, manipulating and emitting DNS queries and responses. The API surface of DNS is quite large and to make handling of all corner-cases easier, we depend on the `domain` library to do the heavy-lifting for us. For better or worse, `domain` follows a lazy-parsing approach. Thus, creating a new DNS message doesn't actually verify that it is in fact valid. Within Firezone, we make several assumptions around DNS messages, such as that they will only ever contain a single question. Historically, DNS allows for multiple questions per query but in practise, nobody uses that. Due to how we handle DNS in Firezone, manipulating these messages happens in multiple places. That combined with the lazy-parsing approach from `domain` warrants having our own `dns-types` library that wraps `domain` and provides us with types that offer the interface we need in the rest of the codebase. Resolves: #7019
This commit is contained in:
2
.github/workflows/_rust.yml
vendored
2
.github/workflows/_rust.yml
vendored
@@ -114,7 +114,7 @@ jobs:
|
||||
rg --count --no-ignore "Packet for Internet resource" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Performed IP-NAT46" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Performed IP-NAT64" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Too big DNS response, truncating" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Truncating DNS response" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Destination is unreachable" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Forwarding query for DNS resource to corresponding site" $TESTCASES_DIR
|
||||
rg --count --no-ignore "Expanded single-label query into FQDN using search-domain" $TESTCASES_DIR
|
||||
|
||||
25
rust/Cargo.lock
generated
25
rust/Cargo.lock
generated
@@ -1044,6 +1044,7 @@ dependencies = [
|
||||
"backoff",
|
||||
"connlib-client-shared",
|
||||
"connlib-model",
|
||||
"dns-types",
|
||||
"firezone-logging",
|
||||
"firezone-telemetry",
|
||||
"flume",
|
||||
@@ -1075,6 +1076,7 @@ dependencies = [
|
||||
"backoff",
|
||||
"connlib-client-shared",
|
||||
"connlib-model",
|
||||
"dns-types",
|
||||
"firezone-logging",
|
||||
"firezone-telemetry",
|
||||
"flume",
|
||||
@@ -1107,6 +1109,7 @@ dependencies = [
|
||||
"bimap",
|
||||
"chrono",
|
||||
"connlib-model",
|
||||
"dns-types",
|
||||
"firezone-logging",
|
||||
"firezone-tunnel",
|
||||
"ip_network",
|
||||
@@ -1130,7 +1133,6 @@ name = "connlib-model"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"boringtun",
|
||||
"domain",
|
||||
"ip_network",
|
||||
"itertools 0.13.0",
|
||||
"serde",
|
||||
@@ -1706,13 +1708,12 @@ name = "dns-over-tcp"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"domain",
|
||||
"dns-types",
|
||||
"firezone-bin-shared",
|
||||
"firezone-logging",
|
||||
"futures",
|
||||
"ip-packet",
|
||||
"ip_network",
|
||||
"itertools 0.13.0",
|
||||
"rand 0.8.5",
|
||||
"smoltcp",
|
||||
"tokio",
|
||||
@@ -1720,6 +1721,15 @@ dependencies = [
|
||||
"tun",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dns-types"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"domain",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "domain"
|
||||
version = "0.10.3"
|
||||
@@ -2013,7 +2023,7 @@ dependencies = [
|
||||
"clap",
|
||||
"connlib-model",
|
||||
"dns-lookup",
|
||||
"domain",
|
||||
"dns-types",
|
||||
"either",
|
||||
"firezone-bin-shared",
|
||||
"firezone-logging",
|
||||
@@ -2141,6 +2151,7 @@ dependencies = [
|
||||
"connlib-client-shared",
|
||||
"connlib-model",
|
||||
"dirs 5.0.1",
|
||||
"dns-types",
|
||||
"firezone-bin-shared",
|
||||
"firezone-logging",
|
||||
"firezone-telemetry",
|
||||
@@ -2274,7 +2285,7 @@ dependencies = [
|
||||
"derive_more 1.0.0",
|
||||
"divan",
|
||||
"dns-over-tcp",
|
||||
"domain",
|
||||
"dns-types",
|
||||
"firezone-logging",
|
||||
"firezone-relay",
|
||||
"futures",
|
||||
@@ -3514,7 +3525,7 @@ name = "l4-tcp-dns-server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"domain",
|
||||
"dns-types",
|
||||
"futures",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -3525,7 +3536,7 @@ name = "l4-udp-dns-server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"domain",
|
||||
"dns-types",
|
||||
"futures",
|
||||
"tokio",
|
||||
"tracing",
|
||||
|
||||
@@ -10,6 +10,7 @@ members = [
|
||||
"connlib/snownet",
|
||||
"connlib/tunnel",
|
||||
"dns-over-tcp",
|
||||
"dns-types",
|
||||
"gateway",
|
||||
"gui-client/src-common",
|
||||
"gui-client/src-tauri",
|
||||
@@ -50,7 +51,6 @@ difference = "2.0.0"
|
||||
dirs = "5.0.1"
|
||||
divan = "0.1.17"
|
||||
dns-lookup = "2.0"
|
||||
domain = { version = "0.10", features = ["serde"] }
|
||||
either = "1"
|
||||
env_logger = "0.11.6"
|
||||
etherparse = "0.16"
|
||||
@@ -164,6 +164,7 @@ snownet = { path = "connlib/snownet" }
|
||||
l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" }
|
||||
l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" }
|
||||
dns-over-tcp = { path = "dns-over-tcp" }
|
||||
dns-types = { path = "dns-types" }
|
||||
firezone-relay = { path = "relay" }
|
||||
connlib-model = { path = "connlib/model" }
|
||||
firezone-tunnel = { path = "connlib/tunnel" }
|
||||
|
||||
@@ -15,6 +15,7 @@ anyhow = { workspace = true }
|
||||
backoff = { workspace = true }
|
||||
connlib-client-shared = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
firezone-telemetry = { workspace = true }
|
||||
flume = { workspace = true }
|
||||
|
||||
@@ -9,7 +9,8 @@ use crate::tun::Tun;
|
||||
use anyhow::{Context as _, Result};
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
|
||||
use connlib_model::{DomainName, ResourceView};
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use firezone_logging::{err_with_src, sentry_layer};
|
||||
use firezone_telemetry::{Telemetry, ANDROID_DSN};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
|
||||
@@ -13,6 +13,7 @@ anyhow = { workspace = true }
|
||||
backoff = { workspace = true }
|
||||
connlib-client-shared = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
firezone-telemetry = { workspace = true }
|
||||
flume = { workspace = true }
|
||||
|
||||
@@ -9,8 +9,8 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
|
||||
use connlib_model::DomainName;
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use firezone_logging::err_with_src;
|
||||
use firezone_logging::sentry_layer;
|
||||
use firezone_telemetry::Telemetry;
|
||||
|
||||
@@ -9,6 +9,7 @@ anyhow = { workspace = true }
|
||||
backoff = { workspace = true }
|
||||
bimap = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
firezone-tunnel = { workspace = true }
|
||||
ip_network = { workspace = true }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use connlib_model::{DomainName, ResourceView};
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||
|
||||
@@ -9,7 +9,7 @@ path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
domain = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tokio = { workspace = true, features = ["net", "io-util"] }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
|
||||
#![cfg_attr(test, allow(clippy::unwrap_used))]
|
||||
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use domain::base::Message;
|
||||
use anyhow::{Context as _, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, stream::FuturesUnordered, task::AtomicWaker, FutureExt, StreamExt as _,
|
||||
};
|
||||
@@ -31,7 +30,7 @@ pub struct Server {
|
||||
/// A set of futures that read DNS queries from TCP streams.
|
||||
#[expect(clippy::type_complexity, reason = "We don't care.")]
|
||||
reading_tcp_queries: FuturesUnordered<
|
||||
BoxFuture<'static, Result<Option<(SocketAddr, Message<Vec<u8>>, TcpStream)>>>,
|
||||
BoxFuture<'static, Result<Option<(SocketAddr, dns_types::Query, TcpStream)>>>,
|
||||
>,
|
||||
/// A set of futures that send DNS responses over TCP streams.
|
||||
sending_tcp_responses: FuturesUnordered<BoxFuture<'static, Result<(TcpStream, SocketAddr)>>>,
|
||||
@@ -68,7 +67,11 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn send_response(&mut self, to: SocketAddr, response: Message<Vec<u8>>) -> io::Result<()> {
|
||||
pub fn send_response(
|
||||
&mut self,
|
||||
to: SocketAddr,
|
||||
response: dns_types::Response,
|
||||
) -> io::Result<()> {
|
||||
let mut stream = self
|
||||
.tcp_streams_by_remote
|
||||
.remove(&to)
|
||||
@@ -76,7 +79,9 @@ impl Server {
|
||||
|
||||
self.sending_tcp_responses.push(
|
||||
async move {
|
||||
let len = response.as_slice().len() as u16;
|
||||
let response = response.into_bytes(u16::MAX); // DNS over TCP has a 16-bit length field, we can't encode anything bigger than that.
|
||||
|
||||
let len = response.len() as u16;
|
||||
let len = len.to_be_bytes();
|
||||
|
||||
stream
|
||||
@@ -84,7 +89,7 @@ impl Server {
|
||||
.await
|
||||
.context("Failed to write TCP DNS header")?;
|
||||
stream
|
||||
.write_all(response.as_slice())
|
||||
.write_all(&response)
|
||||
.await
|
||||
.context("Failed to write TCP DNS message")?;
|
||||
|
||||
@@ -161,7 +166,7 @@ fn anyhow_to_io(e: anyhow::Error) -> io::Error {
|
||||
async fn read_tcp_query(
|
||||
mut stream: TcpStream,
|
||||
from: SocketAddr,
|
||||
) -> Result<Option<(SocketAddr, Message<Vec<u8>>, TcpStream)>> {
|
||||
) -> Result<Option<(SocketAddr, dns_types::Query, TcpStream)>> {
|
||||
let mut buf = [0; 2];
|
||||
match stream.read_exact(&mut buf).await {
|
||||
Ok(2) => {}
|
||||
@@ -178,8 +183,7 @@ async fn read_tcp_query(
|
||||
.await
|
||||
.context("Failed to read TCP DNS message")?;
|
||||
|
||||
let message =
|
||||
Message::try_from_octets(buf).map_err(|_| anyhow!("Failed to parse DNS message"))?;
|
||||
let message = dns_types::Query::parse(&buf).context("Failed to parse DNS message")?;
|
||||
|
||||
Ok(Some((from, message, stream)))
|
||||
}
|
||||
@@ -187,7 +191,7 @@ async fn read_tcp_query(
|
||||
pub struct Query {
|
||||
pub local: SocketAddr,
|
||||
pub remote: SocketAddr,
|
||||
pub message: Message<Vec<u8>>,
|
||||
pub message: dns_types::Query,
|
||||
}
|
||||
|
||||
fn make_tcp_listener(socket: impl ToSocketAddrs) -> Result<TcpListener> {
|
||||
@@ -205,7 +209,6 @@ fn make_tcp_listener(socket: impl ToSocketAddrs) -> Result<TcpListener> {
|
||||
|
||||
#[cfg(all(test, unix))]
|
||||
mod tests {
|
||||
use domain::base::{iana::Rcode, MessageBuilder};
|
||||
use std::future::poll_fn;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::process::ExitStatus;
|
||||
@@ -227,7 +230,7 @@ mod tests {
|
||||
let query = poll_fn(|cx| server.poll(cx)).await.unwrap();
|
||||
|
||||
server
|
||||
.send_response(query.remote, empty_dns_response(query.message))
|
||||
.send_response(query.remote, dns_types::Response::no_error(&query.message))
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
@@ -245,13 +248,6 @@ mod tests {
|
||||
server_task.abort();
|
||||
}
|
||||
|
||||
fn empty_dns_response(message: Message<Vec<u8>>) -> Message<Vec<u8>> {
|
||||
MessageBuilder::new_vec()
|
||||
.start_answer(&message, Rcode::NOERROR)
|
||||
.unwrap()
|
||||
.into_message()
|
||||
}
|
||||
|
||||
async fn dig(server: SocketAddr) -> ExitStatus {
|
||||
tokio::process::Command::new("dig")
|
||||
.arg(format!("@{}", server.ip()))
|
||||
|
||||
@@ -9,7 +9,7 @@ path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
domain = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tokio = { workspace = true, features = ["net"] }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
|
||||
#![cfg_attr(test, allow(clippy::unwrap_used))]
|
||||
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use domain::base::Message;
|
||||
use anyhow::{Context as _, Result};
|
||||
use futures::{
|
||||
future::BoxFuture,
|
||||
stream::{self, BoxStream, FuturesUnordered},
|
||||
@@ -24,8 +23,8 @@ pub struct Server {
|
||||
udp_v6: Option<Arc<UdpSocket>>,
|
||||
|
||||
// Streams that read incoming queries from the UDP sockets.
|
||||
reading_udp_v4_queries: BoxStream<'static, Result<(SocketAddr, Message<Vec<u8>>)>>,
|
||||
reading_udp_v6_queries: BoxStream<'static, Result<(SocketAddr, Message<Vec<u8>>)>>,
|
||||
reading_udp_v4_queries: BoxStream<'static, Result<(SocketAddr, dns_types::Query)>>,
|
||||
reading_udp_v6_queries: BoxStream<'static, Result<(SocketAddr, dns_types::Query)>>,
|
||||
|
||||
// Futures that send responses on the UDP sockets.
|
||||
sending_udp_v4_responses: FuturesUnordered<BoxFuture<'static, Result<()>>>,
|
||||
@@ -69,7 +68,11 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn send_response(&mut self, to: SocketAddr, response: Message<Vec<u8>>) -> io::Result<()> {
|
||||
pub fn send_response(
|
||||
&mut self,
|
||||
to: SocketAddr,
|
||||
response: dns_types::Response,
|
||||
) -> io::Result<()> {
|
||||
let (udp_socket, workers) = match (to, self.udp_v4.clone(), self.udp_v6.clone()) {
|
||||
(SocketAddr::V4(_), Some(socket), _) => (socket, &mut self.sending_udp_v4_responses),
|
||||
(SocketAddr::V6(_), _, Some(socket)) => (socket, &mut self.sending_udp_v6_responses),
|
||||
@@ -79,8 +82,13 @@ impl Server {
|
||||
|
||||
workers.push(
|
||||
async move {
|
||||
// TODO: Make this limit configurable.
|
||||
// The current 1200 are conservative and should be safe for the public Internet and our WireGuard tunnel.
|
||||
// Worst-case, the client will re-query over TCP.
|
||||
let payload = response.into_bytes(1200);
|
||||
|
||||
udp_socket
|
||||
.send_to(response.as_slice(), to)
|
||||
.send_to(&payload, to)
|
||||
.await
|
||||
.context("Failed to send UDP response")?;
|
||||
|
||||
@@ -141,7 +149,7 @@ impl Server {
|
||||
/// Produces a stream of incoming DNS queries from a UDP socket for as long as there is at least one strong reference to the socket.
|
||||
fn udp_dns_query_stream(
|
||||
udp_socket: Weak<UdpSocket>,
|
||||
) -> BoxStream<'static, Result<(SocketAddr, Message<Vec<u8>>)>> {
|
||||
) -> BoxStream<'static, Result<(SocketAddr, dns_types::Query)>> {
|
||||
stream::repeat(udp_socket) // We start with an infinite stream of weak references to the UDP socket.
|
||||
.filter_map(|udp_socket| async move { udp_socket.upgrade() }) // For each item pulled from the stream, we first try to upgrade to a strong reference.
|
||||
.then(read_udp_query) // And then read single DNS query from the socket.
|
||||
@@ -152,7 +160,7 @@ fn anyhow_to_io(e: anyhow::Error) -> io::Error {
|
||||
io::Error::other(format!("{e:#}"))
|
||||
}
|
||||
|
||||
async fn read_udp_query(socket: Arc<UdpSocket>) -> Result<(SocketAddr, Message<Vec<u8>>)> {
|
||||
async fn read_udp_query(socket: Arc<UdpSocket>) -> Result<(SocketAddr, dns_types::Query)> {
|
||||
let mut buffer = vec![0u8; 2000]; // On the public Internet, any MTU > 1500 is very unlikely so 2000 is a safe bet.
|
||||
|
||||
let (len, from) = socket
|
||||
@@ -162,15 +170,14 @@ async fn read_udp_query(socket: Arc<UdpSocket>) -> Result<(SocketAddr, Message<V
|
||||
|
||||
buffer.truncate(len);
|
||||
|
||||
let message =
|
||||
Message::try_from_octets(buffer).map_err(|_| anyhow!("Failed to parse DNS message"))?;
|
||||
let message = dns_types::Query::parse(&buffer).context("Failed to parse DNS message")?;
|
||||
|
||||
Ok((from, message))
|
||||
}
|
||||
|
||||
pub struct Query {
|
||||
pub source: SocketAddr,
|
||||
pub message: Message<Vec<u8>>,
|
||||
pub message: dns_types::Query,
|
||||
}
|
||||
|
||||
fn make_udp_socket(socket: impl ToSocketAddrs) -> Result<UdpSocket> {
|
||||
@@ -201,7 +208,6 @@ impl Default for Server {
|
||||
|
||||
#[cfg(all(test, unix))]
|
||||
mod tests {
|
||||
use domain::base::{iana::Rcode, MessageBuilder};
|
||||
use std::future::poll_fn;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::process::ExitStatus;
|
||||
@@ -223,7 +229,7 @@ mod tests {
|
||||
let query = poll_fn(|cx| server.poll(cx)).await.unwrap();
|
||||
|
||||
server
|
||||
.send_response(query.source, empty_dns_response(query.message))
|
||||
.send_response(query.source, dns_types::Response::no_error(&query.message))
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
@@ -241,13 +247,6 @@ mod tests {
|
||||
server_task.abort();
|
||||
}
|
||||
|
||||
fn empty_dns_response(message: Message<Vec<u8>>) -> Message<Vec<u8>> {
|
||||
MessageBuilder::new_vec()
|
||||
.start_answer(&message, Rcode::NOERROR)
|
||||
.unwrap()
|
||||
.into_message()
|
||||
}
|
||||
|
||||
async fn dig(server: SocketAddr) -> ExitStatus {
|
||||
tokio::process::Command::new("dig")
|
||||
.arg(format!("@{}", server.ip()))
|
||||
|
||||
@@ -6,7 +6,6 @@ license = { workspace = true }
|
||||
|
||||
[dependencies]
|
||||
boringtun = { workspace = true }
|
||||
domain = { workspace = true }
|
||||
ip_network = { workspace = true, features = ["serde"] }
|
||||
serde = { workspace = true, features = ["derive", "std"] }
|
||||
uuid = { workspace = true, features = ["std", "v4", "serde"] }
|
||||
|
||||
@@ -13,9 +13,6 @@ pub use view::{
|
||||
CidrResourceView, DnsResourceView, InternetResourceView, ResourceStatus, ResourceView,
|
||||
};
|
||||
|
||||
pub type DomainName = domain::base::Name<Vec<u8>>;
|
||||
pub type DomainRecord = domain::rdata::AllRecordData<Vec<u8>, DomainName>;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
@@ -15,7 +15,7 @@ connlib-model = { workspace = true }
|
||||
derive_more = { workspace = true, features = ["debug"] }
|
||||
divan = { workspace = true, optional = true }
|
||||
dns-over-tcp = { workspace = true }
|
||||
domain = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
futures-bounded = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod resource;
|
||||
|
||||
use dns_types::DomainName;
|
||||
pub(crate) use resource::{CidrResource, Resource};
|
||||
#[cfg(all(feature = "proptest", test))]
|
||||
pub(crate) use resource::{DnsResource, InternetResource};
|
||||
@@ -14,9 +15,7 @@ use crate::unique_packet_buffer::UniquePacketBuffer;
|
||||
use crate::{dns, is_peer, p2p_control, IpConfig, TunConfig, IPV4_TUNNEL, IPV6_TUNNEL};
|
||||
use anyhow::Context;
|
||||
use bimap::BiMap;
|
||||
use connlib_model::{
|
||||
DomainName, GatewayId, PublicKey, RelayId, ResourceId, ResourceStatus, ResourceView,
|
||||
};
|
||||
use connlib_model::{GatewayId, PublicKey, RelayId, ResourceId, ResourceStatus, ResourceView};
|
||||
use connlib_model::{Site, SiteId};
|
||||
use firezone_logging::{err_with_src, telemetry_event, unwrap_or_debug, unwrap_or_warn};
|
||||
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
|
||||
@@ -26,7 +25,6 @@ use itertools::Itertools;
|
||||
|
||||
use crate::peer::GatewayOnClient;
|
||||
use crate::ClientEvent;
|
||||
use domain::base::Message;
|
||||
use lru::LruCache;
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use snownet::{ClientNode, NoTurnServers, RelaySocket, Transmit};
|
||||
@@ -444,17 +442,12 @@ impl ClientState {
|
||||
/// We call this function every time a client issues a DNS query for a certain domain.
|
||||
/// Coupling this behaviour together allows a client to refresh the DNS resolution of a DNS resource on the Gateway
|
||||
/// through local DNS resolutions.
|
||||
fn clear_dns_resource_nat_for_domain(&mut self, message: Message<&[u8]>) {
|
||||
let Ok(question) = message.sole_question() else {
|
||||
return;
|
||||
};
|
||||
let domain = question.into_qname();
|
||||
|
||||
fn clear_dns_resource_nat_for_domain(&mut self, message: &dns_types::Response) {
|
||||
let mut any_deleted = false;
|
||||
|
||||
self.dns_resource_nat_by_gateway
|
||||
.retain(|(_, candidate), _| {
|
||||
if candidate == &domain {
|
||||
if candidate == &message.domain() {
|
||||
any_deleted = true;
|
||||
return false;
|
||||
}
|
||||
@@ -463,7 +456,7 @@ impl ClientState {
|
||||
});
|
||||
|
||||
if any_deleted {
|
||||
tracing::debug!(%domain, "Cleared DNS resource NAT");
|
||||
tracing::debug!(domain = %message.domain(), "Cleared DNS resource NAT");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,16 +544,11 @@ impl ClientState {
|
||||
}
|
||||
|
||||
pub(crate) fn handle_dns_response(&mut self, response: dns::RecursiveResponse) {
|
||||
let qid = response.query.header().id();
|
||||
let qid = response.query.id();
|
||||
let server = response.server;
|
||||
let domain = response
|
||||
.query
|
||||
.sole_question()
|
||||
.ok()
|
||||
.map(|q| q.into_qname())
|
||||
.map(tracing::field::display);
|
||||
let domain = response.query.domain();
|
||||
|
||||
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, domain).entered();
|
||||
let _span = tracing::debug_span!("handle_dns_response", %qid, %server, %domain).entered();
|
||||
|
||||
match (response.transport, response.message) {
|
||||
(dns::Transport::Udp { .. }, Err(e)) if e.kind() == io::ErrorKind::TimedOut => {
|
||||
@@ -571,14 +559,14 @@ impl ClientState {
|
||||
.inspect(|message| {
|
||||
tracing::trace!("Received recursive UDP DNS response");
|
||||
|
||||
if message.header().tc() {
|
||||
if message.truncated() {
|
||||
tracing::debug!("Upstream DNS server had to truncate response");
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|e| {
|
||||
telemetry_event!("Recursive UDP DNS query failed: {}", err_with_src(&e));
|
||||
|
||||
dns::servfail(response.query.for_slice_ref())
|
||||
dns_types::Response::servfail(&response.query)
|
||||
});
|
||||
|
||||
unwrap_or_warn!(
|
||||
@@ -594,7 +582,7 @@ impl ClientState {
|
||||
.unwrap_or_else(|e| {
|
||||
telemetry_event!("Recursive TCP DNS query failed: {}", err_with_src(&e));
|
||||
|
||||
dns::servfail(response.query.for_slice_ref())
|
||||
dns_types::Response::servfail(&response.query)
|
||||
});
|
||||
|
||||
unwrap_or_warn!(
|
||||
@@ -665,7 +653,7 @@ impl ClientState {
|
||||
&mut self,
|
||||
from: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
message: Message<Vec<u8>>,
|
||||
message: dns_types::Response,
|
||||
) -> anyhow::Result<()> {
|
||||
let saddr = *self
|
||||
.dns_mapping
|
||||
@@ -677,7 +665,7 @@ impl ClientState {
|
||||
dst.ip(),
|
||||
DNS_PORT,
|
||||
dst.port(),
|
||||
truncate_dns_response(message),
|
||||
message.into_bytes(MAX_UDP_PAYLOAD),
|
||||
)?;
|
||||
|
||||
self.buffered_packets.push_back(ip_packet);
|
||||
@@ -1144,7 +1132,7 @@ impl ClientState {
|
||||
// Check if the client has assembled a response to a query.
|
||||
if let Some(query_result) = self.tcp_dns_client.poll_query_result() {
|
||||
let server = query_result.server;
|
||||
let qid = query_result.query.header().id();
|
||||
let qid = query_result.query.id();
|
||||
let known_sockets = &mut self.tcp_dns_streams_by_upstream_and_query_id;
|
||||
|
||||
let Some((local, remote)) = known_sockets.remove(&(server, qid)) else {
|
||||
@@ -1188,7 +1176,7 @@ impl ClientState {
|
||||
return ControlFlow::Break(());
|
||||
}
|
||||
|
||||
let message = match Message::from_octets(datagram.payload()) {
|
||||
let message = match dns_types::Query::parse(datagram.payload()) {
|
||||
Ok(message) => message,
|
||||
Err(e) => {
|
||||
tracing::warn!(?packet, "Failed to parse DNS query: {e:#}");
|
||||
@@ -1198,9 +1186,9 @@ impl ClientState {
|
||||
|
||||
let source = SocketAddr::new(packet.source(), datagram.source_port());
|
||||
|
||||
match self.stub_resolver.handle(message) {
|
||||
match self.stub_resolver.handle(&message) {
|
||||
dns::ResolveStrategy::LocalResponse(response) => {
|
||||
self.clear_dns_resource_nat_for_domain(response.for_slice_ref());
|
||||
self.clear_dns_resource_nat_for_domain(&response);
|
||||
self.update_dns_resource_nat(now, iter::empty());
|
||||
|
||||
unwrap_or_debug!(
|
||||
@@ -1215,7 +1203,7 @@ impl ClientState {
|
||||
|
||||
return ControlFlow::Continue(packet);
|
||||
}
|
||||
let query_id = message.header().id();
|
||||
let query_id = message.id();
|
||||
|
||||
tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host");
|
||||
|
||||
@@ -1261,7 +1249,7 @@ impl ClientState {
|
||||
return;
|
||||
}
|
||||
|
||||
let message = match Message::from_octets(datagram.payload()) {
|
||||
let message = match dns_types::Query::parse(datagram.payload()) {
|
||||
Ok(message) => message,
|
||||
Err(e) => {
|
||||
tracing::warn!(?packet, "Failed to parse DNS query: {e:#}");
|
||||
@@ -1269,9 +1257,9 @@ impl ClientState {
|
||||
}
|
||||
};
|
||||
|
||||
match self.stub_resolver.handle(message) {
|
||||
match self.stub_resolver.handle(&message) {
|
||||
dns::ResolveStrategy::LocalResponse(response) => {
|
||||
self.clear_dns_resource_nat_for_domain(response.for_slice_ref());
|
||||
self.clear_dns_resource_nat_for_domain(&response);
|
||||
self.update_dns_resource_nat(now, iter::empty());
|
||||
|
||||
let maybe_packet = ip_packet::make::udp_packet(
|
||||
@@ -1279,7 +1267,7 @@ impl ClientState {
|
||||
packet.source(),
|
||||
datagram.destination_port(),
|
||||
datagram.source_port(),
|
||||
truncate_dns_response(response),
|
||||
response.into_bytes(MAX_UDP_PAYLOAD),
|
||||
)
|
||||
.inspect_err(|e| {
|
||||
tracing::debug!("Failed to create LLMNR DNS response packet: {e:#}");
|
||||
@@ -1309,9 +1297,8 @@ impl ClientState {
|
||||
.expect("to be a valid UDP packet at this point");
|
||||
|
||||
let dst_port = datagram.destination_port();
|
||||
let query_id = Message::from_octets(datagram.payload())
|
||||
let query_id = dns_types::Query::parse(datagram.payload())
|
||||
.expect("to be a valid DNS query at this point")
|
||||
.header()
|
||||
.id();
|
||||
|
||||
let connlib_dns_server = SocketAddr::new(dst_ip, dst_port);
|
||||
@@ -1336,7 +1323,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query, now: Instant) {
|
||||
let query_id = query.message.header().id();
|
||||
let query_id = query.message.id();
|
||||
|
||||
let Some(upstream) = self.dns_mapping.get_by_left(&query.local.ip()) else {
|
||||
// This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request.
|
||||
@@ -1344,9 +1331,9 @@ impl ClientState {
|
||||
};
|
||||
let server = upstream.address();
|
||||
|
||||
match self.stub_resolver.handle(query.message.for_slice_ref()) {
|
||||
match self.stub_resolver.handle(&query.message) {
|
||||
dns::ResolveStrategy::LocalResponse(response) => {
|
||||
self.clear_dns_resource_nat_for_domain(response.for_slice_ref());
|
||||
self.clear_dns_resource_nat_for_domain(&response);
|
||||
self.update_dns_resource_nat(now, iter::empty());
|
||||
|
||||
unwrap_or_debug!(
|
||||
@@ -1396,7 +1383,7 @@ impl ClientState {
|
||||
server: SocketAddr,
|
||||
query: dns_over_tcp::Query,
|
||||
) {
|
||||
let query_id = query.message.header().id();
|
||||
let query_id = query.message.id();
|
||||
|
||||
match self
|
||||
.tcp_dns_client
|
||||
@@ -1412,7 +1399,7 @@ impl ClientState {
|
||||
self.tcp_dns_server.send_message(
|
||||
query.local,
|
||||
query.remote,
|
||||
dns::servfail(query.message.for_slice_ref())
|
||||
dns_types::Response::servfail(&query.message)
|
||||
),
|
||||
"Failed to send TCP DNS response: {}"
|
||||
);
|
||||
@@ -2002,23 +1989,17 @@ fn maybe_mangle_dns_response_from_upstream_dns_server(
|
||||
let src_port = udp.source_port();
|
||||
let src_socket = SocketAddr::new(src_ip, src_port);
|
||||
|
||||
let Ok(message) = domain::base::Message::from_slice(udp.payload()) else {
|
||||
let Ok(message) = dns_types::Response::parse(udp.payload()) else {
|
||||
return packet;
|
||||
};
|
||||
|
||||
let Some(original_dst) =
|
||||
udp_dns_sockets_by_upstream_and_query_id.remove(&(src_socket, message.header().id()))
|
||||
udp_dns_sockets_by_upstream_and_query_id.remove(&(src_socket, message.id()))
|
||||
else {
|
||||
return packet;
|
||||
};
|
||||
|
||||
let domain = message
|
||||
.sole_question()
|
||||
.ok()
|
||||
.map(|q| q.into_qname())
|
||||
.map(tracing::field::display);
|
||||
|
||||
tracing::trace!(server = %src_ip, query_id = %message.header().id(), domain, "Received UDP DNS response via tunnel");
|
||||
tracing::trace!(server = %src_ip, query_id = %message.id(), domain = %message.domain(), "Received UDP DNS response via tunnel");
|
||||
|
||||
packet.set_src(original_dst.ip());
|
||||
packet
|
||||
@@ -2031,28 +2012,6 @@ fn maybe_mangle_dns_response_from_upstream_dns_server(
|
||||
packet
|
||||
}
|
||||
|
||||
fn truncate_dns_response(mut message: Message<Vec<u8>>) -> Vec<u8> {
|
||||
let message_length = message.as_octets().len();
|
||||
if message_length <= MAX_UDP_PAYLOAD {
|
||||
return message.into_octets();
|
||||
}
|
||||
|
||||
tracing::debug!(?message, %message_length, "Too big DNS response, truncating");
|
||||
|
||||
message.header_mut().set_tc(true);
|
||||
|
||||
let message_truncation = match message.answer() {
|
||||
Ok(answer) if answer.pos() <= MAX_UDP_PAYLOAD => answer.pos(),
|
||||
// This should be very unlikely or impossible.
|
||||
_ => message.question().pos(),
|
||||
};
|
||||
|
||||
let mut message_bytes = message.into_octets();
|
||||
message_bytes.truncate(message_truncation);
|
||||
|
||||
message_bytes
|
||||
}
|
||||
|
||||
/// What triggered us to establish a connection to a Gateway.
|
||||
enum ConnectionTrigger {
|
||||
/// A packet received on the TUN device with a destination IP that maps to one of our resources.
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use domain::base::iana::Rcode;
|
||||
use domain::base::{Message, ParsedName, Rtype};
|
||||
use domain::rdata::AllRecordData;
|
||||
use ip_packet::IpPacket;
|
||||
use itertools::Itertools;
|
||||
use std::io;
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tracing::Level;
|
||||
@@ -46,8 +42,8 @@ impl Device {
|
||||
|
||||
for packet in &buf[..n] {
|
||||
if tracing::event_enabled!(target: "wire::dns::qry", Level::TRACE) {
|
||||
if let Some((qtype, qname, qid)) = parse_dns_query(packet) {
|
||||
tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string());
|
||||
if let Some(query) = parse_dns_query(packet) {
|
||||
tracing::trace!(target: "wire::dns::qry", ?query);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,8 +68,8 @@ impl Device {
|
||||
|
||||
pub fn send(&mut self, packet: IpPacket) -> io::Result<()> {
|
||||
if tracing::event_enabled!(target: "wire::dns::res", Level::TRACE) {
|
||||
if let Some((qtype, qname, records, rcode, qid)) = parse_dns_response(&packet) {
|
||||
tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
|
||||
if let Some(response) = parse_dns_response(&packet) {
|
||||
tracing::trace!(target: "wire::dns::res", ?response);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,62 +98,20 @@ fn io_error_not_initialized() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::NotConnected, "device is not initialized yet")
|
||||
}
|
||||
|
||||
fn parse_dns_query(packet: &IpPacket) -> Option<(Rtype, ParsedName<&[u8]>, u16)> {
|
||||
fn parse_dns_query(packet: &IpPacket) -> Option<dns_types::Query> {
|
||||
let udp = packet.as_udp()?;
|
||||
if udp.destination_port() != crate::dns::DNS_PORT {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = &Message::from_slice(udp.payload()).ok()?;
|
||||
|
||||
if message.header().qr() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let question = message.sole_question().ok()?;
|
||||
|
||||
let qtype = question.qtype();
|
||||
let qname = question.into_qname();
|
||||
let id = message.header().id();
|
||||
|
||||
Some((qtype, qname, id))
|
||||
dns_types::Query::parse(udp.payload()).ok()
|
||||
}
|
||||
|
||||
#[expect(clippy::type_complexity)]
|
||||
fn parse_dns_response(packet: &IpPacket) -> Option<(Rtype, ParsedName<&[u8]>, String, Rcode, u16)> {
|
||||
fn parse_dns_response(packet: &IpPacket) -> Option<dns_types::Response> {
|
||||
let udp = packet.as_udp()?;
|
||||
if udp.source_port() != crate::dns::DNS_PORT {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = &Message::from_slice(udp.payload()).ok()?;
|
||||
|
||||
if !message.header().qr() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let question = message.sole_question().ok()?;
|
||||
|
||||
let qtype = question.qtype();
|
||||
let qname = question.into_qname();
|
||||
let rcode = message.header().rcode();
|
||||
|
||||
let record_section = message.answer().ok()?;
|
||||
|
||||
let records = record_section
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
let data = r
|
||||
.ok()?
|
||||
.into_any_record::<AllRecordData<_, _>>()
|
||||
.ok()?
|
||||
.data()
|
||||
.clone();
|
||||
|
||||
Some(data)
|
||||
})
|
||||
.join(" | ");
|
||||
let id = message.header().id();
|
||||
|
||||
Some((qtype, qname, records, rcode, id))
|
||||
dns_types::Response::parse(udp.payload()).ok()
|
||||
}
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
use crate::client::IpProvider;
|
||||
use anyhow::{Context, Result};
|
||||
use connlib_model::{DomainName, ResourceId};
|
||||
use domain::base::name::FlattenInto;
|
||||
use domain::rdata::AllRecordData;
|
||||
use domain::{
|
||||
base::{
|
||||
iana::{Class, Rcode, Rtype},
|
||||
Message, MessageBuilder, ToName,
|
||||
},
|
||||
dep::octseq::OctetsInto,
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_model::ResourceId;
|
||||
use dns_types::{
|
||||
prelude::*, DomainName, DomainNameRef, OwnedRecordData, Query, RecordType, Response,
|
||||
ResponseBuilder, ResponseCode,
|
||||
};
|
||||
use firezone_logging::{err_with_src, telemetry_span};
|
||||
use itertools::Itertools;
|
||||
use pattern::{Candidate, Pattern};
|
||||
use std::io;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::LazyLock;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
net::SocketAddr,
|
||||
@@ -34,14 +28,14 @@ pub(crate) const DNS_PORT: u16 = 53;
|
||||
/// For Chrome and other Chrome-based browsers, this is not required as
|
||||
/// Chrome will automatically disable DoH if your server(s) don't support
|
||||
/// it. See <https://www.chromium.org/developers/dns-over-https/#faq>.
|
||||
static DOH_CANARY_DOMAIN: LazyLock<DomainName> = LazyLock::new(|| {
|
||||
DomainName::vec_from_str("use-application-dns.net")
|
||||
.expect("static domain name should always parse")
|
||||
});
|
||||
///
|
||||
/// SAFETY: We have a unit-test for it.
|
||||
pub const DOH_CANARY_DOMAIN: DomainNameRef =
|
||||
unsafe { DomainNameRef::from_octets_unchecked(b"\x13use-application-dns\x03net\x00") };
|
||||
|
||||
pub struct StubResolver {
|
||||
fqdn_to_ips: BTreeMap<(DomainName, ResourceId), Vec<IpAddr>>,
|
||||
ips_to_fqdn: HashMap<IpAddr, (DomainName, ResourceId)>,
|
||||
fqdn_to_ips: BTreeMap<(dns_types::DomainName, ResourceId), Vec<IpAddr>>,
|
||||
ips_to_fqdn: HashMap<IpAddr, (dns_types::DomainName, ResourceId)>,
|
||||
ip_provider: IpProvider,
|
||||
/// All DNS resources we know about, indexed by the glob pattern they match against.
|
||||
dns_resources: BTreeMap<Pattern, ResourceId>,
|
||||
@@ -52,7 +46,7 @@ pub struct StubResolver {
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct RecursiveQuery {
|
||||
pub server: SocketAddr,
|
||||
pub message: Message<Vec<u8>>,
|
||||
pub message: dns_types::Query,
|
||||
pub transport: Transport,
|
||||
}
|
||||
|
||||
@@ -60,16 +54,20 @@ pub(crate) struct RecursiveQuery {
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct RecursiveResponse {
|
||||
pub server: SocketAddr,
|
||||
pub query: Message<Vec<u8>>,
|
||||
pub message: io::Result<Message<Vec<u8>>>,
|
||||
pub query: dns_types::Query,
|
||||
pub message: io::Result<dns_types::Response>,
|
||||
pub transport: Transport,
|
||||
}
|
||||
|
||||
impl RecursiveQuery {
|
||||
pub(crate) fn via_udp(source: SocketAddr, server: SocketAddr, message: Message<&[u8]>) -> Self {
|
||||
pub(crate) fn via_udp(
|
||||
source: SocketAddr,
|
||||
server: SocketAddr,
|
||||
message: dns_types::Query,
|
||||
) -> Self {
|
||||
Self {
|
||||
server,
|
||||
message: message.octets_into(),
|
||||
message,
|
||||
transport: Transport::Udp { source },
|
||||
}
|
||||
}
|
||||
@@ -78,7 +76,7 @@ impl RecursiveQuery {
|
||||
local: SocketAddr,
|
||||
remote: SocketAddr,
|
||||
server: SocketAddr,
|
||||
message: Message<Vec<u8>>,
|
||||
message: dns_types::Query,
|
||||
) -> Self {
|
||||
Self {
|
||||
server,
|
||||
@@ -104,7 +102,7 @@ pub(crate) enum Transport {
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ResolveStrategy {
|
||||
/// The query is for a Resource, we have an IP mapped already, and we can respond instantly
|
||||
LocalResponse(Message<Vec<u8>>),
|
||||
LocalResponse(Response),
|
||||
/// The query is for a non-Resource, forward it locally to an upstream or system resolver.
|
||||
RecurseLocal,
|
||||
/// The query is for a DNS resource but for a type that we don't intercept (i.e. SRV, TXT, ...), forward it to the site that hosts the DNS resource and resolve it there.
|
||||
@@ -128,13 +126,16 @@ impl StubResolver {
|
||||
///
|
||||
/// Semantically, this is like a PTR query, i.e. we check whether we handed out this IP as part of answering a DNS query for one of our resources.
|
||||
/// This is in the hot-path of packet routing and must be fast!
|
||||
pub(crate) fn resolve_resource_by_ip(&self, ip: &IpAddr) -> Option<&(DomainName, ResourceId)> {
|
||||
pub(crate) fn resolve_resource_by_ip(
|
||||
&self,
|
||||
ip: &IpAddr,
|
||||
) -> Option<&(dns_types::DomainName, ResourceId)> {
|
||||
self.ips_to_fqdn.get(ip)
|
||||
}
|
||||
|
||||
pub(crate) fn resolved_resources(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (&DomainName, &ResourceId, &Vec<IpAddr>)> + '_ {
|
||||
) -> impl Iterator<Item = (&dns_types::DomainName, &ResourceId, &Vec<IpAddr>)> + '_ {
|
||||
self.fqdn_to_ips
|
||||
.iter()
|
||||
.map(|((domain, resource), ips)| (domain, resource, ips))
|
||||
@@ -160,21 +161,33 @@ impl StubResolver {
|
||||
|
||||
fn get_or_assign_a_records(
|
||||
&mut self,
|
||||
fqdn: DomainName,
|
||||
fqdn: dns_types::DomainName,
|
||||
resource_id: ResourceId,
|
||||
) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
|
||||
to_a_records(self.get_or_assign_ips(fqdn, resource_id).into_iter())
|
||||
) -> Vec<OwnedRecordData> {
|
||||
self.get_or_assign_ips(fqdn, resource_id)
|
||||
.into_iter()
|
||||
.filter_map(get_v4)
|
||||
.map(dns_types::records::a)
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
fn get_or_assign_aaaa_records(
|
||||
&mut self,
|
||||
fqdn: DomainName,
|
||||
fqdn: dns_types::DomainName,
|
||||
resource_id: ResourceId,
|
||||
) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
|
||||
to_aaaa_records(self.get_or_assign_ips(fqdn, resource_id).into_iter())
|
||||
) -> Vec<OwnedRecordData> {
|
||||
self.get_or_assign_ips(fqdn, resource_id)
|
||||
.into_iter()
|
||||
.filter_map(get_v6)
|
||||
.map(dns_types::records::aaaa)
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
fn get_or_assign_ips(&mut self, fqdn: DomainName, resource_id: ResourceId) -> Vec<IpAddr> {
|
||||
fn get_or_assign_ips(
|
||||
&mut self,
|
||||
fqdn: dns_types::DomainName,
|
||||
resource_id: ResourceId,
|
||||
) -> Vec<IpAddr> {
|
||||
let ips = self
|
||||
.fqdn_to_ips
|
||||
.entry((fqdn.clone(), resource_id))
|
||||
@@ -197,7 +210,7 @@ impl StubResolver {
|
||||
/// Attempts to match the given domain against our list of possible patterns.
|
||||
///
|
||||
/// This performs a linear search and is thus O(N) and **must not** be called in the hot-path of packet routing.
|
||||
fn match_resource_linear(&self, domain: &DomainName) -> Option<ResourceId> {
|
||||
fn match_resource_linear(&self, domain: &dns_types::DomainName) -> Option<ResourceId> {
|
||||
let _span = telemetry_span!("match_resource_linear").entered();
|
||||
|
||||
let name = Candidate::from_domain(domain);
|
||||
@@ -222,8 +235,8 @@ impl StubResolver {
|
||||
|
||||
fn resource_address_name_by_reservse_dns(
|
||||
&self,
|
||||
reverse_dns_name: &DomainName,
|
||||
) -> Option<DomainName> {
|
||||
reverse_dns_name: &dns_types::DomainName,
|
||||
) -> Option<dns_types::DomainName> {
|
||||
let address = reverse_dns_addr(&reverse_dns_name.to_string())?;
|
||||
let (domain, _) = self.ips_to_fqdn.get(&address)?;
|
||||
|
||||
@@ -231,40 +244,14 @@ impl StubResolver {
|
||||
}
|
||||
|
||||
/// Processes the incoming DNS query.
|
||||
///
|
||||
/// Any errors will result in an immediate `SERVFAIL` response.
|
||||
pub(crate) fn handle(&mut self, message: Message<&[u8]>) -> ResolveStrategy {
|
||||
match self.try_handle(message) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to handle DNS query: {e:#}");
|
||||
|
||||
ResolveStrategy::LocalResponse(servfail(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_handle(&mut self, message: Message<&[u8]>) -> Result<ResolveStrategy> {
|
||||
anyhow::ensure!(
|
||||
!message.header().qr(),
|
||||
"Can only handle DNS queries, not responses"
|
||||
);
|
||||
|
||||
// We don't need to support multiple questions/qname in a single query because
|
||||
// nobody does it and since this run with each packet we want to squeeze as much optimization
|
||||
// as we can therefore we won't do it.
|
||||
//
|
||||
// See: https://stackoverflow.com/a/55093896
|
||||
let question = message
|
||||
.sole_question()
|
||||
.context("Expected a single 'question'")?;
|
||||
let domain = question.qname().to_vec();
|
||||
let qtype = question.qtype();
|
||||
pub(crate) fn handle(&mut self, query: &Query) -> ResolveStrategy {
|
||||
let domain = query.domain();
|
||||
let qtype = query.qtype();
|
||||
|
||||
tracing::trace!("Parsed packet as DNS query: '{qtype} {domain}'");
|
||||
|
||||
if domain == *DOH_CANARY_DOMAIN {
|
||||
return Ok(ResolveStrategy::LocalResponse(nxdomain(message)));
|
||||
if domain == DOH_CANARY_DOMAIN {
|
||||
return ResolveStrategy::LocalResponse(Response::nxdomain(query));
|
||||
}
|
||||
|
||||
// We override the `domain` here to ensure that we use the FQDN everywhere from here on.
|
||||
@@ -278,29 +265,30 @@ impl StubResolver {
|
||||
// `match_resource` is `O(N)` which we deem fine for DNS queries.
|
||||
let maybe_resource = self.match_resource_linear(&domain);
|
||||
|
||||
let resource_records = match (qtype, maybe_resource) {
|
||||
(Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource),
|
||||
(Rtype::AAAA, Some(resource)) => {
|
||||
let records = match (qtype, maybe_resource) {
|
||||
(RecordType::A, Some(resource)) => {
|
||||
self.get_or_assign_a_records(domain.clone(), resource)
|
||||
}
|
||||
(RecordType::AAAA, Some(resource)) => {
|
||||
self.get_or_assign_aaaa_records(domain.clone(), resource)
|
||||
}
|
||||
(Rtype::SRV | Rtype::TXT, Some(resource)) => {
|
||||
(RecordType::SRV | RecordType::TXT, Some(resource)) => {
|
||||
tracing::debug!(%qtype, %resource, "Forwarding query for DNS resource to corresponding site");
|
||||
|
||||
return Ok(ResolveStrategy::RecurseSite(resource));
|
||||
return ResolveStrategy::RecurseSite(resource);
|
||||
}
|
||||
(Rtype::PTR, _) => {
|
||||
(RecordType::PTR, _) => {
|
||||
let Some(fqdn) = self.resource_address_name_by_reservse_dns(&domain) else {
|
||||
return Ok(ResolveStrategy::RecurseLocal);
|
||||
return ResolveStrategy::RecurseLocal;
|
||||
};
|
||||
|
||||
vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))]
|
||||
vec![dns_types::records::ptr(fqdn)]
|
||||
}
|
||||
(Rtype::HTTPS, Some(_)) => {
|
||||
(RecordType::HTTPS, Some(_)) => {
|
||||
// We must intercept queries for the HTTPS record type to force the client to issue an A / AAAA query instead.
|
||||
// Otherwise, the client won't use the IPs we issue for a particular domain and the traffic cannot be tunneled.
|
||||
|
||||
let response = build_dns_with_answer(message, Vec::default())?;
|
||||
return Ok(ResolveStrategy::LocalResponse(response));
|
||||
return ResolveStrategy::LocalResponse(Response::no_error(query));
|
||||
}
|
||||
(_, None) if is_single_label_domain => {
|
||||
// Queries for single-label domains, i.e. local hostnames are never recursively resolved but are instead answered with nxdomain.
|
||||
@@ -310,15 +298,18 @@ impl StubResolver {
|
||||
"Query for single-label non-resource domain, responding with NXDOMAIN"
|
||||
);
|
||||
|
||||
return Ok(ResolveStrategy::LocalResponse(nxdomain(message)));
|
||||
return ResolveStrategy::LocalResponse(Response::nxdomain(query));
|
||||
}
|
||||
_ => return Ok(ResolveStrategy::RecurseLocal),
|
||||
_ => return ResolveStrategy::RecurseLocal,
|
||||
};
|
||||
|
||||
tracing::trace!(%qtype, %domain, records = ?resource_records, "Forming DNS response");
|
||||
tracing::trace!(%qtype, %domain, records = ?records, "Forming DNS response");
|
||||
|
||||
let response = build_dns_with_answer(message, resource_records)?;
|
||||
Ok(ResolveStrategy::LocalResponse(response))
|
||||
let response = ResponseBuilder::for_query(query, ResponseCode::NOERROR)
|
||||
.with_records(records.into_iter().map(|r| (domain.clone(), DNS_TTL, r)))
|
||||
.build();
|
||||
|
||||
ResolveStrategy::LocalResponse(response)
|
||||
}
|
||||
|
||||
pub(crate) fn set_search_domain(&mut self, new_search_domain: Option<DomainName>) {
|
||||
@@ -354,60 +345,7 @@ impl StubResolver {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn servfail(message: Message<&[u8]>) -> Message<Vec<u8>> {
|
||||
MessageBuilder::new_vec()
|
||||
.start_answer(&message, Rcode::SERVFAIL)
|
||||
.expect("should always be able to create a heap-allocated SERVFAIL message")
|
||||
.into_message()
|
||||
}
|
||||
|
||||
fn nxdomain(message: Message<&[u8]>) -> Message<Vec<u8>> {
|
||||
MessageBuilder::new_vec()
|
||||
.start_answer(&message, Rcode::NXDOMAIN)
|
||||
.expect("should always be able to create a heap-allocated NXDOMAIN message")
|
||||
.into_message()
|
||||
}
|
||||
|
||||
fn to_a_records(ips: impl Iterator<Item = IpAddr>) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
|
||||
ips.filter_map(get_v4)
|
||||
.map(domain::rdata::A::new)
|
||||
.map(AllRecordData::A)
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
fn to_aaaa_records(ips: impl Iterator<Item = IpAddr>) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
|
||||
ips.filter_map(get_v6)
|
||||
.map(domain::rdata::Aaaa::new)
|
||||
.map(AllRecordData::Aaaa)
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
fn build_dns_with_answer(
|
||||
message: Message<&[u8]>,
|
||||
records: Vec<AllRecordData<Vec<u8>, DomainName>>,
|
||||
) -> Result<Message<Vec<u8>>> {
|
||||
// Take the original qname out of the message.
|
||||
// DNS queries should always respond for the exact same qname that was queried, even if we expanded a single-label domain.
|
||||
let qname = message
|
||||
.sole_question()
|
||||
.context("Expected a single question")?
|
||||
.into_qname();
|
||||
|
||||
let mut answer_builder = MessageBuilder::new_vec()
|
||||
.start_answer(&message, Rcode::NOERROR)
|
||||
.context("Failed to create answer from query")?;
|
||||
answer_builder.header_mut().set_ra(true);
|
||||
|
||||
for record in records {
|
||||
answer_builder
|
||||
.push((qname, Class::IN, DNS_TTL, record))
|
||||
.context("Failed to push record")?;
|
||||
}
|
||||
|
||||
Ok(answer_builder.into_message())
|
||||
}
|
||||
|
||||
pub fn is_subdomain(name: &DomainName, resource: &str) -> bool {
|
||||
pub fn is_subdomain(name: &dns_types::DomainName, resource: &str) -> bool {
|
||||
let pattern = match Pattern::new(resource) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
@@ -586,7 +524,7 @@ mod pattern {
|
||||
pub struct Candidate(String);
|
||||
|
||||
impl Candidate {
|
||||
pub fn from_domain(domain: &DomainName) -> Self {
|
||||
pub fn from_domain(domain: &dns_types::DomainName) -> Self {
|
||||
Self(domain.to_string().replace('.', "/"))
|
||||
}
|
||||
}
|
||||
@@ -640,7 +578,6 @@ mod pattern {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use domain::base::Question;
|
||||
use std::str::FromStr as _;
|
||||
use test_case::test_case;
|
||||
|
||||
@@ -763,22 +700,19 @@ mod tests {
|
||||
fn query_for_doh_canary_domain_records_nx_domain() {
|
||||
let mut resolver = StubResolver::default();
|
||||
|
||||
let mut builder = MessageBuilder::new_vec().question();
|
||||
builder
|
||||
.push(Question::new_in(
|
||||
"use-application-dns.net".parse::<DomainName>().unwrap(),
|
||||
Rtype::A,
|
||||
))
|
||||
.unwrap();
|
||||
let query = builder.into_message();
|
||||
let query = Query::new(
|
||||
"use-application-dns.net"
|
||||
.parse::<dns_types::DomainName>()
|
||||
.unwrap(),
|
||||
RecordType::A,
|
||||
);
|
||||
|
||||
let ResolveStrategy::LocalResponse(response) = resolver.handle(query.for_slice_ref())
|
||||
else {
|
||||
let ResolveStrategy::LocalResponse(response) = resolver.handle(&query) else {
|
||||
panic!("Unexpected result")
|
||||
};
|
||||
|
||||
assert_eq!(response.header().rcode(), Rcode::NXDOMAIN);
|
||||
assert_eq!(response.answer().unwrap().count(), 0);
|
||||
assert_eq!(response.response_code(), ResponseCode::NXDOMAIN);
|
||||
assert_eq!(response.records().count(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -808,7 +742,7 @@ mod benches {
|
||||
.unwrap()
|
||||
.to_string();
|
||||
|
||||
let needle = DomainName::vec_from_str(&needle).unwrap();
|
||||
let needle = dns_types::DomainName::vec_from_str(&needle).unwrap();
|
||||
|
||||
(resolver, needle)
|
||||
})
|
||||
|
||||
@@ -6,7 +6,8 @@ use crate::{peer::ClientOnGateway, peer_store::PeerStore};
|
||||
use anyhow::{Context, Result};
|
||||
use boringtun::x25519::PublicKey;
|
||||
use chrono::{DateTime, Utc};
|
||||
use connlib_model::{ClientId, DomainName, RelayId, ResourceId};
|
||||
use connlib_model::{ClientId, RelayId, ResourceId};
|
||||
use dns_types::DomainName;
|
||||
use ip_packet::{FzP2pControlSlice, IpPacket};
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use snownet::{Credentials, NoTurnServers, RelaySocket, ServerNode, Transmit};
|
||||
|
||||
@@ -5,7 +5,6 @@ mod udp_dns;
|
||||
|
||||
use crate::{device_channel::Device, dns, sockets::Sockets};
|
||||
use anyhow::Result;
|
||||
use domain::base::Message;
|
||||
use firezone_logging::{telemetry_event, telemetry_span};
|
||||
use futures::FutureExt as _;
|
||||
use futures_bounded::FuturesTupleSet;
|
||||
@@ -56,7 +55,7 @@ pub struct Io {
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
|
||||
dns_queries: FuturesTupleSet<io::Result<Message<Vec<u8>>>, DnsQueryMetaData>,
|
||||
dns_queries: FuturesTupleSet<io::Result<dns_types::Response>, DnsQueryMetaData>,
|
||||
|
||||
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
|
||||
|
||||
@@ -66,7 +65,7 @@ pub struct Io {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct DnsQueryMetaData {
|
||||
query: Message<Vec<u8>>,
|
||||
query: dns_types::Query,
|
||||
server: SocketAddr,
|
||||
transport: dns::Transport,
|
||||
}
|
||||
@@ -338,7 +337,7 @@ impl Io {
|
||||
pub(crate) fn send_udp_dns_response(
|
||||
&mut self,
|
||||
to: SocketAddr,
|
||||
message: Message<Vec<u8>>,
|
||||
message: dns_types::Response,
|
||||
) -> io::Result<()> {
|
||||
self.udp_dns_server.send_response(to, message)
|
||||
}
|
||||
@@ -346,7 +345,7 @@ impl Io {
|
||||
pub(crate) fn send_tcp_dns_response(
|
||||
&mut self,
|
||||
to: SocketAddr,
|
||||
message: Message<Vec<u8>>,
|
||||
message: dns_types::Response,
|
||||
) -> io::Result<()> {
|
||||
self.tcp_dns_server.send_response(to, message)
|
||||
}
|
||||
|
||||
@@ -2,13 +2,12 @@ use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
io,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::{Arc, LazyLock},
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use connlib_model::DomainName;
|
||||
use domain::base::{iana::Rcode, Message, MessageBuilder, Question, Rtype};
|
||||
use dns_types::{prelude::*, DomainNameRef, Query, RecordType, ResponseCode};
|
||||
use futures_bounded::FuturesTupleSet;
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
|
||||
@@ -19,9 +18,8 @@ use super::tcp_dns;
|
||||
const MAX_DNS_SERVERS: usize = 20; // We don't bother selecting from more than 10 servers over UDP and TCP.
|
||||
const DNS_TIMEOUT: Duration = Duration::from_secs(2); // Every sensible DNS servers should respond within 2s.
|
||||
|
||||
static FIREZONE_DEV: LazyLock<DomainName> = LazyLock::new(|| {
|
||||
DomainName::vec_from_str("firezone.dev").expect("static domain should always parse")
|
||||
});
|
||||
pub const FIREZONE_DEV: DomainNameRef =
|
||||
unsafe { DomainNameRef::from_octets_unchecked(b"\x08firezone\x03dev\x00") };
|
||||
|
||||
pub struct NameserverSet {
|
||||
inner: BTreeSet<IpAddr>,
|
||||
@@ -29,7 +27,7 @@ pub struct NameserverSet {
|
||||
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
queries: FuturesTupleSet<io::Result<Message<Vec<u8>>>, QueryMetaData>,
|
||||
queries: FuturesTupleSet<io::Result<dns_types::Response>, QueryMetaData>,
|
||||
}
|
||||
|
||||
struct QueryMetaData {
|
||||
@@ -63,7 +61,7 @@ impl NameserverSet {
|
||||
udp_dns::send(
|
||||
self.udp_socket_factory.clone(),
|
||||
SocketAddr::new(nameserver, crate::dns::DNS_PORT),
|
||||
query_firezone_dev(),
|
||||
Query::new(FIREZONE_DEV.to_vec(), RecordType::A),
|
||||
),
|
||||
QueryMetaData { nameserver, start },
|
||||
)
|
||||
@@ -78,7 +76,7 @@ impl NameserverSet {
|
||||
tcp_dns::send(
|
||||
self.tcp_socket_factory.clone(),
|
||||
SocketAddr::new(nameserver, crate::dns::DNS_PORT),
|
||||
query_firezone_dev(),
|
||||
Query::new(FIREZONE_DEV.to_vec(), RecordType::A),
|
||||
),
|
||||
QueryMetaData { nameserver, start },
|
||||
)
|
||||
@@ -102,7 +100,7 @@ impl NameserverSet {
|
||||
|
||||
loop {
|
||||
match ready!(self.queries.poll_unpin(cx)) {
|
||||
(Ok(Ok(response)), meta) if response.header().rcode() == Rcode::NOERROR => {
|
||||
(Ok(Ok(response)), meta) if response.response_code() == ResponseCode::NOERROR => {
|
||||
let rtt = meta.start.elapsed();
|
||||
|
||||
tracing::debug!(nameserver = %meta.nameserver, ?rtt, ?response, "DNS query completed");
|
||||
@@ -133,25 +131,22 @@ impl NameserverSet {
|
||||
}
|
||||
}
|
||||
|
||||
fn query_firezone_dev() -> Message<Vec<u8>> {
|
||||
let mut builder = MessageBuilder::new_vec().question();
|
||||
builder.header_mut().set_random_id();
|
||||
builder.header_mut().set_rd(true);
|
||||
builder.header_mut().set_qr(false);
|
||||
|
||||
builder
|
||||
.push(Question::new_in(FIREZONE_DEV.clone(), Rtype::A))
|
||||
.expect("static question should always be valid");
|
||||
|
||||
builder.into_message()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
use dns_types::DomainName;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn const_domain_is_correct() {
|
||||
assert_eq!(
|
||||
FIREZONE_DEV,
|
||||
DomainName::vec_from_str("firezone.dev").unwrap()
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "Needs Internet"]
|
||||
async fn can_evaluate_fastest_nameserver() {
|
||||
|
||||
@@ -1,26 +1,19 @@
|
||||
use std::{io, net::SocketAddr, sync::Arc};
|
||||
|
||||
use domain::base::{Message, ToName as _};
|
||||
use socket_factory::{SocketFactory, TcpSocket};
|
||||
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
|
||||
|
||||
pub async fn send(
|
||||
factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
server: SocketAddr,
|
||||
query: Message<Vec<u8>>,
|
||||
) -> io::Result<Message<Vec<u8>>> {
|
||||
let domain = query
|
||||
.sole_question()
|
||||
.expect("all queries should be for a single name")
|
||||
.qname()
|
||||
.to_vec();
|
||||
|
||||
tracing::trace!(target: "wire::dns::recursive::tcp", %server, %domain);
|
||||
query: dns_types::Query,
|
||||
) -> io::Result<dns_types::Response> {
|
||||
tracing::trace!(target: "wire::dns::recursive::tcp", %server, domain = %query.domain());
|
||||
|
||||
let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver.
|
||||
let mut tcp_stream = tcp_socket.connect(server).await?;
|
||||
|
||||
let query = query.into_octets();
|
||||
let query = query.into_bytes();
|
||||
let dns_message_length = (query.len() as u16).to_be_bytes();
|
||||
|
||||
tcp_stream.write_all(&dns_message_length).await?;
|
||||
@@ -34,8 +27,7 @@ pub async fn send(
|
||||
let mut response = vec![0u8; response_length];
|
||||
tcp_stream.read_exact(&mut response).await?;
|
||||
|
||||
let message = Message::from_octets(response)
|
||||
.map_err(|_| io::Error::other("Failed to parse DNS message"))?;
|
||||
let message = dns_types::Response::parse(&response).map_err(io::Error::other)?;
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
@@ -4,37 +4,30 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use domain::base::{Message, ToName as _};
|
||||
use socket_factory::{SocketFactory, UdpSocket};
|
||||
|
||||
pub async fn send(
|
||||
factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
server: SocketAddr,
|
||||
query: Message<Vec<u8>>,
|
||||
) -> io::Result<Message<Vec<u8>>> {
|
||||
let domain = query
|
||||
.sole_question()
|
||||
.expect("all queries should be for a single name")
|
||||
.qname()
|
||||
.to_vec();
|
||||
query: dns_types::Query,
|
||||
) -> io::Result<dns_types::Response> {
|
||||
tracing::trace!(target: "wire::dns::recursive::udp", %server, domain = %query.domain());
|
||||
|
||||
let bind_addr = match server {
|
||||
SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
|
||||
SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
|
||||
};
|
||||
|
||||
tracing::trace!(target: "wire::dns::recursive::udp", %server, %domain);
|
||||
|
||||
// To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet.
|
||||
const BUF_SIZE: usize = 1500;
|
||||
|
||||
let udp_socket = factory(&bind_addr)?;
|
||||
|
||||
let response = udp_socket
|
||||
.handshake::<BUF_SIZE>(server, query.as_slice())
|
||||
.handshake::<BUF_SIZE>(server, &query.into_bytes())
|
||||
.await?;
|
||||
|
||||
let message = Message::from_octets(response)
|
||||
.map_err(|_| io::Error::other("Failed to parse DNS message"))?;
|
||||
let response = dns_types::Response::parse(&response).map_err(io::Error::other)?;
|
||||
|
||||
Ok(message)
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
use anyhow::Result;
|
||||
use bimap::BiMap;
|
||||
use chrono::Utc;
|
||||
use connlib_model::{ClientId, DomainName, GatewayId, PublicKey, ResourceId, ResourceView};
|
||||
use connlib_model::{ClientId, GatewayId, PublicKey, ResourceId, ResourceView};
|
||||
use dns_types::DomainName;
|
||||
use io::{Buffers, Io};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
@@ -256,7 +257,7 @@ impl GatewayTunnel {
|
||||
let message = response.message.unwrap_or_else(|e| {
|
||||
tracing::debug!("DNS query failed: {e}");
|
||||
|
||||
dns::servfail(response.query.for_slice_ref())
|
||||
dns_types::Response::servfail(&response.query)
|
||||
});
|
||||
|
||||
match response.transport {
|
||||
@@ -323,7 +324,7 @@ impl GatewayTunnel {
|
||||
self.io.send_dns_query(dns::RecursiveQuery::via_udp(
|
||||
query.source,
|
||||
SocketAddr::new(nameserver, dns::DNS_PORT),
|
||||
query.message.for_slice_ref(),
|
||||
query.message,
|
||||
));
|
||||
}
|
||||
Poll::Ready(io::Input::TcpDnsQuery(query)) => {
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
|
||||
use chrono::{serde::ts_seconds, DateTime, Utc};
|
||||
use connlib_model::RelayId;
|
||||
use dns_types::DomainName;
|
||||
use ip_network::IpNetwork;
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -14,8 +15,6 @@ mod key;
|
||||
|
||||
pub use key::{Key, SecretKey};
|
||||
|
||||
use crate::DomainName;
|
||||
|
||||
/// Represents a wireguard peer.
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct Peer {
|
||||
|
||||
@@ -18,7 +18,8 @@ pub const DOMAIN_STATUS_EVENT: FzP2pEventType = FzP2pEventType::new(1);
|
||||
pub mod dns_resource_nat {
|
||||
use super::*;
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_model::{DomainName, ResourceId};
|
||||
use connlib_model::ResourceId;
|
||||
use dns_types::DomainName;
|
||||
use ip_packet::{FzP2pControlSlice, IpPacket};
|
||||
use std::net::IpAddr;
|
||||
|
||||
@@ -113,7 +114,6 @@ pub mod dns_resource_nat {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use domain::base::Name;
|
||||
|
||||
use super::*;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
@@ -192,7 +192,7 @@ pub mod dns_resource_nat {
|
||||
let domain =
|
||||
DomainName::vec_from_str(&format!("{label}.{label}.{label}.{label}.{label}.com"))
|
||||
.unwrap();
|
||||
assert_eq!(domain.len(), Name::MAX_LEN);
|
||||
assert_eq!(domain.len(), dns_types::MAX_NAME_LEN);
|
||||
|
||||
domain
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES};
|
||||
use crate::messages::gateway::Filters;
|
||||
use crate::messages::gateway::ResourceDescription;
|
||||
use chrono::{DateTime, Utc};
|
||||
use connlib_model::{ClientId, DomainName, GatewayId, ResourceId};
|
||||
use connlib_model::{ClientId, GatewayId, ResourceId};
|
||||
use dns_types::DomainName;
|
||||
use filter_engine::FilterEngine;
|
||||
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
|
||||
use ip_network_table::IpNetworkTable;
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::{
|
||||
sim_gateway::SimGateway,
|
||||
transition::{Destination, ReplyTo},
|
||||
};
|
||||
use connlib_model::{DomainName, GatewayId};
|
||||
use connlib_model::GatewayId;
|
||||
use ip_packet::IpPacket;
|
||||
use itertools::Itertools;
|
||||
use std::{
|
||||
@@ -377,7 +377,7 @@ fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket, expe
|
||||
fn assert_destination_is_dns_resource(
|
||||
gateway_received_request: &IpPacket,
|
||||
global_dns_records: &DnsRecords,
|
||||
domain: &DomainName,
|
||||
domain: &dns_types::DomainName,
|
||||
) {
|
||||
let actual = gateway_received_request.destination();
|
||||
let possible_resource_ips = global_dns_records
|
||||
|
||||
@@ -3,21 +3,21 @@ use std::{
|
||||
net::IpAddr,
|
||||
};
|
||||
|
||||
use connlib_model::{DomainName, DomainRecord};
|
||||
use domain::base::{RecordData, Rtype};
|
||||
use dns_types::prelude::*;
|
||||
use dns_types::{DomainName, OwnedRecordData, RecordType};
|
||||
use itertools::Itertools;
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub(crate) struct DnsRecords {
|
||||
inner: BTreeMap<DomainName, BTreeSet<DomainRecord>>,
|
||||
inner: BTreeMap<DomainName, BTreeSet<OwnedRecordData>>,
|
||||
}
|
||||
|
||||
impl DnsRecords {
|
||||
pub(crate) fn domain_ips_iter(&self, name: &DomainName) -> impl Iterator<Item = IpAddr> + '_ {
|
||||
#[expect(clippy::wildcard_enum_match_arm)]
|
||||
self.domain_records_iter(name).filter_map(|r| match r {
|
||||
DomainRecord::A(a) => Some(a.addr().into()),
|
||||
DomainRecord::Aaaa(aaaa) => Some(aaaa.addr().into()),
|
||||
OwnedRecordData::A(a) => Some(a.addr().into()),
|
||||
OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
@@ -25,8 +25,8 @@ impl DnsRecords {
|
||||
pub(crate) fn ips_iter(&self) -> impl Iterator<Item = IpAddr> + '_ {
|
||||
#[expect(clippy::wildcard_enum_match_arm)]
|
||||
self.inner.values().flatten().filter_map(|r| match r {
|
||||
DomainRecord::A(a) => Some(a.addr().into()),
|
||||
DomainRecord::Aaaa(aaaa) => Some(aaaa.addr().into()),
|
||||
OwnedRecordData::A(a) => Some(a.addr().into()),
|
||||
OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
@@ -34,7 +34,7 @@ impl DnsRecords {
|
||||
pub(crate) fn domain_records_iter(
|
||||
&self,
|
||||
name: &DomainName,
|
||||
) -> impl Iterator<Item = DomainRecord> + '_ {
|
||||
) -> impl Iterator<Item = OwnedRecordData> + '_ {
|
||||
self.inner.get(name).cloned().into_iter().flatten()
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ impl DnsRecords {
|
||||
self.inner.extend(other.inner);
|
||||
}
|
||||
|
||||
pub(crate) fn domain_rtypes(&self, name: &DomainName) -> Vec<Rtype> {
|
||||
pub(crate) fn domain_rtypes(&self, name: &DomainName) -> Vec<RecordType> {
|
||||
self.domain_records_iter(name)
|
||||
.map(|r| r.rtype())
|
||||
.dedup()
|
||||
@@ -64,7 +64,7 @@ impl DnsRecords {
|
||||
|
||||
impl<I> From<I> for DnsRecords
|
||||
where
|
||||
BTreeMap<DomainName, BTreeSet<DomainRecord>>: From<I>,
|
||||
BTreeMap<DomainName, BTreeSet<OwnedRecordData>>: From<I>,
|
||||
{
|
||||
fn from(value: I) -> Self {
|
||||
Self {
|
||||
@@ -75,7 +75,7 @@ where
|
||||
|
||||
impl<I> FromIterator<I> for DnsRecords
|
||||
where
|
||||
BTreeMap<DomainName, BTreeSet<DomainRecord>>: FromIterator<I>,
|
||||
BTreeMap<DomainName, BTreeSet<OwnedRecordData>>: FromIterator<I>,
|
||||
{
|
||||
fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
|
||||
Self {
|
||||
@@ -83,10 +83,3 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ip_to_domain_record(ip: IpAddr) -> DomainRecord {
|
||||
match ip {
|
||||
IpAddr::V4(ip) => DomainRecord::A(ip.into()),
|
||||
IpAddr::V6(ip) => DomainRecord::Aaaa(ip.into()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,8 @@ use std::{
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use domain::base::{
|
||||
iana::{Class, Rcode},
|
||||
Message, MessageBuilder, Record, RecordData, ToName, Ttl,
|
||||
};
|
||||
use dns_types::prelude::*;
|
||||
use dns_types::ResponseCode;
|
||||
use ip_packet::{IpPacket, MAX_UDP_PAYLOAD};
|
||||
|
||||
use super::dns_records::DnsRecords;
|
||||
@@ -37,7 +35,7 @@ impl TcpDnsServerResource {
|
||||
pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, now: Instant) {
|
||||
self.server.handle_timeout(now);
|
||||
while let Some(query) = self.server.poll_queries() {
|
||||
let response = handle_dns_query(query.message.for_slice(), global_dns_records);
|
||||
let response = handle_dns_query(&query.message, global_dns_records);
|
||||
|
||||
self.server
|
||||
.send_message(query.local, query.remote, response)
|
||||
@@ -58,9 +56,9 @@ impl UdpDnsServerResource {
|
||||
pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, _: Instant) {
|
||||
while let Some(packet) = self.inbound_packets.pop_front() {
|
||||
let udp = packet.as_udp().unwrap();
|
||||
let query = Message::from_octets(udp.payload().to_vec()).unwrap();
|
||||
let query = dns_types::Query::parse(udp.payload()).unwrap();
|
||||
|
||||
let response = handle_dns_query(query.for_slice(), global_dns_records);
|
||||
let response = handle_dns_query(&query, global_dns_records);
|
||||
|
||||
self.outbound_packets.push_back(
|
||||
ip_packet::make::udp_packet(
|
||||
@@ -68,7 +66,7 @@ impl UdpDnsServerResource {
|
||||
packet.source(),
|
||||
udp.destination_port(),
|
||||
udp.source_port(),
|
||||
truncate_dns_response(response),
|
||||
response.into_bytes(MAX_UDP_PAYLOAD),
|
||||
)
|
||||
.expect("src and dst are retrieved from the same packet"),
|
||||
)
|
||||
@@ -80,44 +78,18 @@ impl UdpDnsServerResource {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_dns_query(query: &Message<[u8]>, global_dns_records: &DnsRecords) -> Message<Vec<u8>> {
|
||||
let response = MessageBuilder::new_vec();
|
||||
let mut answers = response.start_answer(query, Rcode::NOERROR).unwrap();
|
||||
fn handle_dns_query(
|
||||
query: &dns_types::Query,
|
||||
global_dns_records: &DnsRecords,
|
||||
) -> dns_types::Response {
|
||||
let domain = query.domain().to_vec();
|
||||
|
||||
for query in query.question() {
|
||||
let query = query.unwrap();
|
||||
let name = query.qname().to_name::<Vec<u8>>();
|
||||
let records = global_dns_records
|
||||
.domain_records_iter(&domain)
|
||||
.filter(|r| r.rtype() == query.qtype())
|
||||
.map(|rdata| (domain.clone(), 60 * 60 * 24, rdata));
|
||||
|
||||
let records = global_dns_records
|
||||
.domain_records_iter(&name)
|
||||
.filter(|r| r.rtype() == query.qtype())
|
||||
.map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata));
|
||||
|
||||
for record in records {
|
||||
answers.push(record).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
answers.into_message()
|
||||
}
|
||||
|
||||
fn truncate_dns_response(message: Message<Vec<u8>>) -> Vec<u8> {
|
||||
let mut message_bytes = message.as_octets().to_vec();
|
||||
|
||||
if message_bytes.len() > MAX_UDP_PAYLOAD {
|
||||
let mut new_message = message.clone();
|
||||
new_message.header_mut().set_tc(true);
|
||||
|
||||
let message_truncation = match message.answer() {
|
||||
Ok(answer) if answer.pos() <= MAX_UDP_PAYLOAD => answer.pos(),
|
||||
// This should be very unlikely or impossible.
|
||||
_ => message.question().pos(),
|
||||
};
|
||||
|
||||
message_bytes = new_message.as_octets().to_vec();
|
||||
|
||||
message_bytes.truncate(message_truncation);
|
||||
}
|
||||
|
||||
message_bytes
|
||||
dns_types::ResponseBuilder::for_query(query, ResponseCode::NOERROR)
|
||||
.with_records(records)
|
||||
.build()
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ use super::{
|
||||
composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*,
|
||||
strategies::*, stub_portal::StubPortal, transition::*,
|
||||
};
|
||||
use crate::{client, DomainName};
|
||||
use crate::client;
|
||||
use crate::{dns::is_subdomain, proptest::relay_id};
|
||||
use connlib_model::{GatewayId, RelayId, StaticSecret};
|
||||
use domain::base::Rtype;
|
||||
use dns_types::{DomainName, RecordType};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use itertools::Itertools;
|
||||
use prop::sample::select;
|
||||
@@ -616,8 +616,9 @@ impl ReferenceState {
|
||||
return false;
|
||||
};
|
||||
|
||||
let is_ptr_query = matches!(query.r_type, Rtype::PTR);
|
||||
let is_ptr_query = matches!(query.r_type, RecordType::PTR);
|
||||
let is_known_domain = state.global_dns_records.contains_domain(&domain);
|
||||
|
||||
// In case we sampled a PTR query, the domain doesn't have to exist.
|
||||
let ptr_or_known_domain = is_ptr_query || is_known_domain;
|
||||
|
||||
@@ -708,8 +709,8 @@ impl ReferenceState {
|
||||
.dns_records
|
||||
.get(name)
|
||||
.is_some_and(|r| match src {
|
||||
IpAddr::V4(_) => r.contains(&Rtype::A),
|
||||
IpAddr::V6(_) => r.contains(&Rtype::AAAA),
|
||||
IpAddr::V4(_) => r.contains(&RecordType::A),
|
||||
IpAddr::V6(_) => r.contains(&RecordType::AAAA),
|
||||
})
|
||||
&& self.gateways.contains_key(gateway)
|
||||
}
|
||||
@@ -719,10 +720,10 @@ impl ReferenceState {
|
||||
impl ReferenceState {
|
||||
// We surface what are the existing rtypes for a domain so that it's easier
|
||||
// for the proptests to hit an existing record.
|
||||
fn all_domains(&self) -> Vec<(DomainName, Vec<Rtype>)> {
|
||||
fn all_domains(&self) -> Vec<(DomainName, Vec<RecordType>)> {
|
||||
fn domains_and_rtypes(
|
||||
records: &DnsRecords,
|
||||
) -> impl Iterator<Item = (DomainName, Vec<Rtype>)> + use<'_> {
|
||||
) -> impl Iterator<Item = (DomainName, Vec<RecordType>)> + use<'_> {
|
||||
records
|
||||
.domains_iter()
|
||||
.map(|d| (d.clone(), records.domain_rtypes(&d)))
|
||||
@@ -741,14 +742,14 @@ impl ReferenceState {
|
||||
|
||||
// We surface what are the existing rtypes for a domain so that it's easier
|
||||
// for the proptests to hit an existing record.
|
||||
fn single_label_queries_for_search_domains(&self) -> Vec<(DomainName, Vec<Rtype>)> {
|
||||
fn single_label_queries_for_search_domains(&self) -> Vec<(DomainName, Vec<RecordType>)> {
|
||||
let Some(search_domain) = self.client.inner().search_domain.clone() else {
|
||||
return Vec::default();
|
||||
};
|
||||
|
||||
fn domains_and_rtypes(
|
||||
records: &DnsRecords,
|
||||
) -> impl Iterator<Item = (DomainName, Vec<Rtype>)> + use<'_> {
|
||||
) -> impl Iterator<Item = (DomainName, Vec<RecordType>)> + use<'_> {
|
||||
records
|
||||
.domains_iter()
|
||||
.map(|d| (d.clone(), records.domain_rtypes(&d)))
|
||||
|
||||
@@ -10,17 +10,13 @@ use super::{
|
||||
use crate::{
|
||||
client::{CidrResource, DnsResource, InternetResource, Resource},
|
||||
messages::{DnsServer, Interface},
|
||||
DomainName,
|
||||
};
|
||||
use crate::{proptest::*, ClientState};
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use bimap::BiMap;
|
||||
use connlib_model::{ClientId, GatewayId, RelayId, ResourceId, ResourceStatus, SiteId};
|
||||
use domain::{
|
||||
base::{iana::Opcode, name::FlattenInto, Message, MessageBuilder, Question, Rtype, ToName},
|
||||
rdata::AllRecordData,
|
||||
};
|
||||
use dns_types::{prelude::*, DomainName, Query, RecordData, RecordType};
|
||||
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use ip_packet::{Icmpv4Type, Icmpv6Type, IpPacket, Layer4Protocol};
|
||||
@@ -115,7 +111,7 @@ impl SimClient {
|
||||
pub(crate) fn send_dns_query_for(
|
||||
&mut self,
|
||||
domain: DomainName,
|
||||
r_type: Rtype,
|
||||
r_type: RecordType,
|
||||
query_id: u16,
|
||||
upstream: SocketAddr,
|
||||
dns_transport: DnsTransport,
|
||||
@@ -133,20 +129,7 @@ impl SimClient {
|
||||
.tunnel_ip_for(sentinel)
|
||||
.expect("tunnel should be initialised");
|
||||
|
||||
// Create the DNS query message
|
||||
let mut msg_builder = MessageBuilder::new_vec();
|
||||
|
||||
msg_builder.header_mut().set_opcode(Opcode::QUERY);
|
||||
msg_builder.header_mut().set_rd(true);
|
||||
msg_builder.header_mut().set_id(query_id);
|
||||
|
||||
// Create the query
|
||||
let mut question_builder = msg_builder.question();
|
||||
question_builder
|
||||
.push(Question::new_in(domain, r_type))
|
||||
.unwrap();
|
||||
|
||||
let message = question_builder.into_message();
|
||||
let query = Query::new(domain, r_type).with_id(query_id);
|
||||
|
||||
match dns_transport {
|
||||
DnsTransport::Udp => {
|
||||
@@ -155,7 +138,7 @@ impl SimClient {
|
||||
sentinel,
|
||||
9999, // An application would pick a free source port.
|
||||
53,
|
||||
message.as_octets().to_vec(),
|
||||
query.into_bytes(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -165,7 +148,7 @@ impl SimClient {
|
||||
}
|
||||
DnsTransport::Tcp => {
|
||||
self.tcp_dns_client
|
||||
.send_query(SocketAddr::new(sentinel, 53), message)
|
||||
.send_query(SocketAddr::new(sentinel, 53), query)
|
||||
.unwrap();
|
||||
self.sent_tcp_dns_queries.insert((upstream, query_id));
|
||||
|
||||
@@ -246,15 +229,15 @@ impl SimClient {
|
||||
match failed_packet.layer4_protocol() {
|
||||
Layer4Protocol::Udp { src, dst } => {
|
||||
self.received_udp_replies
|
||||
.insert((SPort(dst), DPort(src)), packet.clone());
|
||||
.insert((SPort(dst), DPort(src)), packet);
|
||||
}
|
||||
Layer4Protocol::Tcp { src, dst } => {
|
||||
self.received_tcp_replies
|
||||
.insert((SPort(dst), DPort(src)), packet.clone());
|
||||
.insert((SPort(dst), DPort(src)), packet);
|
||||
}
|
||||
Layer4Protocol::Icmp { seq, id } => {
|
||||
self.received_icmp_replies
|
||||
.insert((Seq(seq), Identifier(id)), packet.clone());
|
||||
.insert((Seq(seq), Identifier(id)), packet);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,7 +246,7 @@ impl SimClient {
|
||||
|
||||
if let Some(udp) = packet.as_udp() {
|
||||
if udp.source_port() == 53 {
|
||||
let message = Message::from_slice(udp.payload())
|
||||
let response = dns_types::Response::parse(udp.payload())
|
||||
.expect("ip packets on port 53 to be DNS packets");
|
||||
|
||||
// Map back to upstream socket so we can assert on it correctly.
|
||||
@@ -274,10 +257,10 @@ impl SimClient {
|
||||
};
|
||||
|
||||
self.received_udp_dns_responses
|
||||
.insert((upstream, message.header().id()), packet.clone());
|
||||
.insert((upstream, response.id()), packet.clone());
|
||||
|
||||
if !message.header().tc() {
|
||||
self.handle_dns_response(message);
|
||||
if !response.truncated() {
|
||||
self.handle_dns_response(&response);
|
||||
}
|
||||
|
||||
return;
|
||||
@@ -341,26 +324,19 @@ impl SimClient {
|
||||
Some(*socket)
|
||||
}
|
||||
|
||||
pub(crate) fn handle_dns_response(&mut self, message: &Message<[u8]>) {
|
||||
for record in message.answer().unwrap() {
|
||||
let record = record.unwrap();
|
||||
let domain = record.owner().to_name();
|
||||
|
||||
pub(crate) fn handle_dns_response(&mut self, response: &dns_types::Response) {
|
||||
for record in response.records() {
|
||||
#[expect(clippy::wildcard_enum_match_arm)]
|
||||
let ip = match record
|
||||
.into_any_record::<AllRecordData<_, _>>()
|
||||
.unwrap()
|
||||
.data()
|
||||
{
|
||||
AllRecordData::A(a) => IpAddr::from(a.addr()),
|
||||
AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()),
|
||||
AllRecordData::Ptr(_) => {
|
||||
let ip = match record.data() {
|
||||
RecordData::A(a) => IpAddr::from(a.addr()),
|
||||
RecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()),
|
||||
RecordData::Ptr(_) => {
|
||||
continue;
|
||||
}
|
||||
AllRecordData::Txt(_) => {
|
||||
RecordData::Txt(_) => {
|
||||
continue;
|
||||
}
|
||||
AllRecordData::Srv(_) => {
|
||||
RecordData::Srv(_) => {
|
||||
continue;
|
||||
}
|
||||
unhandled => {
|
||||
@@ -368,7 +344,10 @@ impl SimClient {
|
||||
}
|
||||
};
|
||||
|
||||
self.dns_records.entry(domain).or_default().push(ip);
|
||||
self.dns_records
|
||||
.entry(response.domain())
|
||||
.or_default()
|
||||
.push(ip);
|
||||
}
|
||||
|
||||
// Ensure all IPs are always sorted.
|
||||
@@ -419,7 +398,7 @@ pub struct RefClient {
|
||||
/// The IPs assigned to a domain by connlib are an implementation detail that we don't want to model in these tests.
|
||||
/// Instead, we just remember what _kind_ of records we resolved to be able to sample a matching src IP.
|
||||
#[debug(skip)]
|
||||
pub(crate) dns_records: BTreeMap<DomainName, BTreeSet<Rtype>>,
|
||||
pub(crate) dns_records: BTreeMap<DomainName, BTreeSet<RecordType>>,
|
||||
|
||||
/// Whether we are connected to the gateway serving the Internet resource.
|
||||
#[debug(skip)]
|
||||
@@ -907,7 +886,7 @@ impl RefClient {
|
||||
.find(|id| !self.disabled_resources.contains(id))
|
||||
}
|
||||
|
||||
fn resolved_domains(&self) -> impl Iterator<Item = (DomainName, BTreeSet<Rtype>)> + '_ {
|
||||
fn resolved_domains(&self) -> impl Iterator<Item = (DomainName, BTreeSet<RecordType>)> + '_ {
|
||||
self.dns_records
|
||||
.iter()
|
||||
.filter(|(domain, _)| self.dns_resource_by_domain(domain).is_some())
|
||||
@@ -953,7 +932,7 @@ impl RefClient {
|
||||
.filter_map(|(domain, records)| {
|
||||
records
|
||||
.iter()
|
||||
.any(|r| matches!(r, &Rtype::A))
|
||||
.any(|r| matches!(r, &RecordType::A))
|
||||
.then_some(domain)
|
||||
})
|
||||
.collect()
|
||||
@@ -964,7 +943,7 @@ impl RefClient {
|
||||
.filter_map(|(domain, records)| {
|
||||
records
|
||||
.iter()
|
||||
.any(|r| matches!(r, &Rtype::AAAA))
|
||||
.any(|r| matches!(r, &RecordType::AAAA))
|
||||
.then_some(domain)
|
||||
})
|
||||
.collect()
|
||||
@@ -1065,7 +1044,10 @@ impl RefClient {
|
||||
|
||||
// If we are querying a DNS resource, we will issue a connection intent to the DNS resource, not the CIDR resource.
|
||||
if self.dns_resource_by_domain(&query.domain).is_some()
|
||||
&& matches!(query.r_type, Rtype::A | Rtype::AAAA | Rtype::PTR)
|
||||
&& matches!(
|
||||
query.r_type,
|
||||
RecordType::A | RecordType::AAAA | RecordType::PTR
|
||||
)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -1077,7 +1059,7 @@ impl RefClient {
|
||||
}
|
||||
|
||||
pub(crate) fn is_site_specific_dns_query(&self, query: &DnsQuery) -> Option<ResourceId> {
|
||||
if !matches!(query.r_type, Rtype::SRV | Rtype::TXT) {
|
||||
if !matches!(query.r_type, RecordType::SRV | RecordType::TXT) {
|
||||
return None;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::dns_records::{ip_to_domain_record, DnsRecords};
|
||||
use super::dns_records::DnsRecords;
|
||||
use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal};
|
||||
use crate::client::{
|
||||
CidrResource, DnsResource, InternetResource, DNS_SENTINELS_V4, DNS_SENTINELS_V6,
|
||||
@@ -6,7 +6,8 @@ use crate::client::{
|
||||
};
|
||||
use crate::messages::DnsServer;
|
||||
use crate::{proptest::*, IPV4_TUNNEL, IPV6_TUNNEL};
|
||||
use connlib_model::{DomainRecord, RelayId, Site};
|
||||
use connlib_model::{RelayId, Site};
|
||||
use dns_types::OwnedRecordData;
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use itertools::Itertools;
|
||||
use prop::sample;
|
||||
@@ -27,22 +28,20 @@ pub(crate) fn global_dns_records() -> impl Strategy<Value = DnsRecords> {
|
||||
.prop_map_into()
|
||||
}
|
||||
|
||||
fn dns_record() -> impl Strategy<Value = DomainRecord> {
|
||||
fn dns_record() -> impl Strategy<Value = OwnedRecordData> {
|
||||
prop_oneof![
|
||||
3 => non_reserved_ip().prop_map(ip_to_domain_record),
|
||||
3 => non_reserved_ip().prop_map(dns_types::records::ip),
|
||||
1 => collection::vec(txt_record(), 6..=10)
|
||||
.prop_map(|sections| { sections.into_iter().flatten().collect_vec() })
|
||||
.prop_map(|o| domain::rdata::Txt::from_octets(o).unwrap())
|
||||
.prop_map(DomainRecord::Txt)
|
||||
.prop_map(|content| dns_types::records::txt(content).unwrap())
|
||||
]
|
||||
}
|
||||
|
||||
pub(crate) fn site_specific_dns_record() -> impl Strategy<Value = DomainRecord> {
|
||||
pub(crate) fn site_specific_dns_record() -> impl Strategy<Value = OwnedRecordData> {
|
||||
prop_oneof![
|
||||
collection::vec(txt_record(), 6..=10)
|
||||
.prop_map(|sections| { sections.into_iter().flatten().collect_vec() })
|
||||
.prop_map(|o| domain::rdata::Txt::from_octets(o).unwrap())
|
||||
.prop_map(DomainRecord::Txt),
|
||||
.prop_map(|content| dns_types::records::txt(content).unwrap()),
|
||||
srv_record()
|
||||
]
|
||||
}
|
||||
@@ -60,7 +59,7 @@ fn txt_record() -> impl Strategy<Value = Vec<u8>> {
|
||||
})
|
||||
}
|
||||
|
||||
fn srv_record() -> impl Strategy<Value = DomainRecord> {
|
||||
fn srv_record() -> impl Strategy<Value = OwnedRecordData> {
|
||||
(
|
||||
any::<u16>(),
|
||||
any::<u16>(),
|
||||
@@ -68,7 +67,7 @@ fn srv_record() -> impl Strategy<Value = DomainRecord> {
|
||||
domain_name(2..4).prop_map(|d| d.parse().unwrap()),
|
||||
)
|
||||
.prop_map(|(priority, weight, port, target)| {
|
||||
DomainRecord::Srv(domain::rdata::Srv::new(priority, weight, port, target))
|
||||
dns_types::records::srv(priority, weight, port, target)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -270,12 +269,12 @@ fn double_star_wildcard_dns_resource(
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn resolved_ips() -> impl Strategy<Value = BTreeSet<DomainRecord>> {
|
||||
pub(crate) fn resolved_ips() -> impl Strategy<Value = BTreeSet<OwnedRecordData>> {
|
||||
let record = prop_oneof![
|
||||
dns_resource_ip4s().prop_map_into(),
|
||||
dns_resource_ip6s().prop_map_into()
|
||||
]
|
||||
.prop_map(ip_to_domain_record);
|
||||
.prop_map(dns_types::records::ip);
|
||||
|
||||
collection::btree_set(record, 1..6)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,9 @@ use crate::{
|
||||
client::DnsResource,
|
||||
messages::{gateway, DnsServer},
|
||||
};
|
||||
use connlib_model::{DomainName, GatewayId};
|
||||
use connlib_model::GatewayId;
|
||||
use connlib_model::{ResourceId, SiteId};
|
||||
use dns_types::DomainName;
|
||||
use itertools::Itertools;
|
||||
use proptest::{
|
||||
collection,
|
||||
|
||||
@@ -17,8 +17,8 @@ use crate::tests::transition::Transition;
|
||||
use crate::utils::earliest;
|
||||
use crate::{dns, messages::Interface, ClientEvent, GatewayEvent};
|
||||
use connlib_model::{ClientId, GatewayId, PublicKey, RelayId};
|
||||
use domain::base::iana::{Class, Rcode};
|
||||
use domain::base::{Message, MessageBuilder, Record, RecordData, ToName as _, Ttl};
|
||||
use dns_types::prelude::*;
|
||||
use dns_types::ResponseCode;
|
||||
use rand::distributions::DistString;
|
||||
use rand::SeedableRng;
|
||||
use sha2::Digest;
|
||||
@@ -416,10 +416,8 @@ impl TunnelTest {
|
||||
let server = query.server;
|
||||
let transport = query.transport;
|
||||
|
||||
let response = self.on_recursive_dns_query(
|
||||
query.message.for_slice_ref(),
|
||||
&ref_state.global_dns_records,
|
||||
);
|
||||
let response =
|
||||
self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records);
|
||||
self.client.exec_mut(|c| {
|
||||
c.sut.handle_dns_response(dns::RecursiveResponse {
|
||||
server,
|
||||
@@ -531,8 +529,8 @@ impl TunnelTest {
|
||||
let upstream = c.dns_mapping().get_by_left(&result.server.ip()).unwrap();
|
||||
|
||||
c.received_tcp_dns_responses
|
||||
.insert((*upstream, result.query.header().id()));
|
||||
c.handle_dns_response(message.for_slice())
|
||||
.insert((*upstream, result.query.id()));
|
||||
c.handle_dns_response(&message)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("TCP DNS query failed: {e:#}");
|
||||
@@ -779,28 +777,22 @@ impl TunnelTest {
|
||||
|
||||
fn on_recursive_dns_query(
|
||||
&self,
|
||||
query: Message<&[u8]>,
|
||||
query: &dns_types::Query,
|
||||
global_dns_records: &DnsRecords,
|
||||
) -> Message<Vec<u8>> {
|
||||
let response = MessageBuilder::new_vec();
|
||||
let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap();
|
||||
|
||||
let query = query.sole_question().unwrap();
|
||||
let name = query.qname().to_vec();
|
||||
) -> dns_types::Response {
|
||||
let qtype = query.qtype();
|
||||
let domain = query.domain();
|
||||
|
||||
let records = global_dns_records
|
||||
.domain_records_iter(&name)
|
||||
.filter(|record| qtype == record.rtype())
|
||||
.map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata));
|
||||
let response = dns_types::ResponseBuilder::for_query(query, ResponseCode::NOERROR)
|
||||
.with_records(
|
||||
global_dns_records
|
||||
.domain_records_iter(&domain)
|
||||
.filter(|record| qtype == record.rtype())
|
||||
.map(|rdata| (domain.clone(), 60 * 60 * 24, rdata)),
|
||||
)
|
||||
.build();
|
||||
|
||||
for record in records {
|
||||
answers.push(record).unwrap();
|
||||
}
|
||||
|
||||
let response = answers.into_message();
|
||||
|
||||
tracing::debug!(%name, %qtype, "Responding to DNS query");
|
||||
tracing::debug!(%domain, %qtype, "Responding to DNS query");
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
@@ -2,12 +2,11 @@ use crate::{
|
||||
client::{Resource, IPV4_RESOURCES, IPV6_RESOURCES},
|
||||
proptest::{host_v4, host_v6},
|
||||
};
|
||||
use connlib_model::RelayId;
|
||||
use connlib_model::{RelayId, ResourceId};
|
||||
use dns_types::{DomainName, RecordType};
|
||||
|
||||
use super::sim_net::{any_ip_stack, any_port, Host};
|
||||
use crate::messages::DnsServer;
|
||||
use connlib_model::{DomainName, ResourceId};
|
||||
use domain::base::Rtype;
|
||||
use prop::collection;
|
||||
use proptest::{prelude::*, sample};
|
||||
use std::{
|
||||
@@ -89,7 +88,7 @@ pub(crate) enum Transition {
|
||||
pub(crate) struct DnsQuery {
|
||||
pub(crate) domain: DomainName,
|
||||
/// The type of DNS query we should send.
|
||||
pub(crate) r_type: Rtype,
|
||||
pub(crate) r_type: RecordType,
|
||||
/// The DNS query ID.
|
||||
pub(crate) query_id: u16,
|
||||
pub(crate) dns_server: SocketAddr,
|
||||
@@ -283,7 +282,7 @@ fn non_dns_ports() -> impl Strategy<Value = u16> {
|
||||
|
||||
/// Samples up to 5 DNS queries that will be sent concurrently into connlib.
|
||||
pub(crate) fn dns_queries(
|
||||
domain: impl Strategy<Value = (DomainName, Vec<Rtype>)>,
|
||||
domain: impl Strategy<Value = (DomainName, Vec<RecordType>)>,
|
||||
dns_server: impl Strategy<Value = SocketAddr>,
|
||||
) -> impl Strategy<Value = Vec<DnsQuery>> {
|
||||
// Queries can be uniquely identified by the tuple of DNS server and query ID.
|
||||
@@ -317,7 +316,7 @@ pub(crate) fn dns_queries(
|
||||
maybe_reverse_record,
|
||||
transport,
|
||||
)| {
|
||||
if matches!(r_type, Rtype::PTR) {
|
||||
if matches!(r_type, RecordType::PTR) {
|
||||
domain =
|
||||
DomainName::reverse_from_addr(maybe_reverse_record).unwrap();
|
||||
}
|
||||
@@ -358,10 +357,15 @@ fn dns_transport() -> impl Strategy<Value = DnsTransport> {
|
||||
///
|
||||
/// Similarrly to trigger NAT64 and NAT46 we need to query for A when only AAAA is available and vice versa.
|
||||
pub(crate) fn maybe_available_response_rtypes(
|
||||
available_rtypes: Vec<Rtype>,
|
||||
) -> impl Strategy<Value = Rtype> {
|
||||
if available_rtypes.contains(&Rtype::A) || available_rtypes.contains(&Rtype::AAAA) {
|
||||
sample::select(vec![Rtype::PTR, Rtype::MX, Rtype::A, Rtype::AAAA])
|
||||
available_rtypes: Vec<RecordType>,
|
||||
) -> impl Strategy<Value = RecordType> {
|
||||
if available_rtypes.contains(&RecordType::A) || available_rtypes.contains(&RecordType::AAAA) {
|
||||
sample::select(vec![
|
||||
RecordType::PTR,
|
||||
RecordType::MX,
|
||||
RecordType::A,
|
||||
RecordType::AAAA,
|
||||
])
|
||||
} else {
|
||||
sample::select(available_rtypes)
|
||||
}
|
||||
|
||||
@@ -7,10 +7,9 @@ license = { workspace = true }
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
domain = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
ip-packet = { workspace = true }
|
||||
itertools = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
smoltcp = { workspace = true, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::{
|
||||
time::smol_now,
|
||||
};
|
||||
use anyhow::{anyhow, bail, Context as _, Result};
|
||||
use domain::{base::Message, dep::octseq::OctetsInto};
|
||||
use ip_packet::IpPacket;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use smoltcp::{
|
||||
@@ -36,9 +35,9 @@ pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
|
||||
sockets_by_remote: BTreeMap<SocketAddr, smoltcp::iface::SocketHandle>,
|
||||
local_ports_by_socket: HashMap<smoltcp::iface::SocketHandle, u16>,
|
||||
/// Queries we should send to a DNS resolver.
|
||||
pending_queries_by_remote: HashMap<SocketAddr, VecDeque<Message<Vec<u8>>>>,
|
||||
pending_queries_by_remote: HashMap<SocketAddr, VecDeque<dns_types::Query>>,
|
||||
/// Queries we have sent to a DNS resolver and are waiting for a reply.
|
||||
sent_queries_by_remote: HashMap<SocketAddr, HashMap<u16, Message<Vec<u8>>>>,
|
||||
sent_queries_by_remote: HashMap<SocketAddr, HashMap<u16, dns_types::Query>>,
|
||||
|
||||
query_results: VecDeque<QueryResult>,
|
||||
|
||||
@@ -50,9 +49,9 @@ pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct QueryResult {
|
||||
pub query: Message<Vec<u8>>,
|
||||
pub query: dns_types::Query,
|
||||
pub server: SocketAddr,
|
||||
pub result: Result<Message<Vec<u8>>>,
|
||||
pub result: Result<dns_types::Response>,
|
||||
}
|
||||
|
||||
impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
@@ -88,9 +87,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
/// Send the given DNS query to the target server.
|
||||
///
|
||||
/// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them.
|
||||
pub fn send_query(&mut self, server: SocketAddr, message: Message<Vec<u8>>) -> Result<()> {
|
||||
anyhow::ensure!(!message.header().qr(), "Message is a DNS response!");
|
||||
|
||||
pub fn send_query(&mut self, server: SocketAddr, message: dns_types::Query) -> Result<()> {
|
||||
self.pending_queries_by_remote
|
||||
.entry(server)
|
||||
.or_default()
|
||||
@@ -308,8 +305,8 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
|
||||
fn send_pending_queries(
|
||||
socket: &mut tcp::Socket,
|
||||
server: SocketAddr,
|
||||
pending_queries: &mut VecDeque<Message<Vec<u8>>>,
|
||||
sent_queries: &mut HashMap<u16, Message<Vec<u8>>>,
|
||||
pending_queries: &mut VecDeque<dns_types::Query>,
|
||||
sent_queries: &mut HashMap<u16, dns_types::Query>,
|
||||
query_results: &mut VecDeque<QueryResult>,
|
||||
) {
|
||||
loop {
|
||||
@@ -321,9 +318,9 @@ fn send_pending_queries(
|
||||
break;
|
||||
};
|
||||
|
||||
match codec::try_send(socket, query.for_slice_ref()).context("Failed to send DNS query") {
|
||||
match codec::try_send(socket, query.as_bytes()).context("Failed to send DNS query") {
|
||||
Ok(()) => {
|
||||
let replaced = sent_queries.insert(query.header().id(), query).is_some();
|
||||
let replaced = sent_queries.insert(query.id(), query).is_some();
|
||||
debug_assert!(!replaced, "Query ID is not unique");
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -344,8 +341,8 @@ fn send_pending_queries(
|
||||
fn recv_responses(
|
||||
socket: &mut tcp::Socket,
|
||||
server: SocketAddr,
|
||||
pending_queries: &mut VecDeque<Message<Vec<u8>>>,
|
||||
sent_queries: &mut HashMap<u16, Message<Vec<u8>>>,
|
||||
pending_queries: &mut VecDeque<dns_types::Query>,
|
||||
sent_queries: &mut HashMap<u16, dns_types::Query>,
|
||||
query_results: &mut VecDeque<QueryResult>,
|
||||
) {
|
||||
let Some(result) = try_recv_response(socket)
|
||||
@@ -358,13 +355,13 @@ fn recv_responses(
|
||||
let new_results = result
|
||||
.and_then(|response| {
|
||||
let query = sent_queries
|
||||
.remove(&response.header().id())
|
||||
.remove(&response.id())
|
||||
.context("DNS resolver sent response for unknown query")?;
|
||||
|
||||
Ok(vec![QueryResult {
|
||||
query,
|
||||
server,
|
||||
result: Ok(response.octets_into()),
|
||||
result: Ok(response),
|
||||
}])
|
||||
})
|
||||
.unwrap_or_else(|e| {
|
||||
@@ -379,8 +376,8 @@ fn recv_responses(
|
||||
fn fail_all_queries<'a>(
|
||||
error: &'a anyhow::Error,
|
||||
server: SocketAddr,
|
||||
pending_queries: &'a mut VecDeque<Message<Vec<u8>>>,
|
||||
sent_queries: &'a mut HashMap<u16, Message<Vec<u8>>>,
|
||||
pending_queries: &'a mut VecDeque<dns_types::Query>,
|
||||
sent_queries: &'a mut HashMap<u16, dns_types::Query>,
|
||||
) -> impl Iterator<Item = QueryResult> + 'a {
|
||||
let pending_queries = pending_queries.drain(..);
|
||||
let sent_queries = sent_queries.drain().map(|(_, query)| query);
|
||||
@@ -391,7 +388,7 @@ fn fail_all_queries<'a>(
|
||||
|
||||
fn into_failed_results(
|
||||
server: SocketAddr,
|
||||
iter: impl IntoIterator<Item = Message<Vec<u8>>>,
|
||||
iter: impl IntoIterator<Item = dns_types::Query>,
|
||||
make_error: impl Fn() -> anyhow::Error,
|
||||
) -> impl Iterator<Item = QueryResult> {
|
||||
iter.into_iter().map(move |query| QueryResult {
|
||||
@@ -401,18 +398,14 @@ fn into_failed_results(
|
||||
})
|
||||
}
|
||||
|
||||
fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result<Option<Message<&'b [u8]>>> {
|
||||
fn try_recv_response(socket: &mut tcp::Socket) -> Result<Option<dns_types::Response>> {
|
||||
if !socket.can_recv() {
|
||||
tracing::trace!(state = %socket.state(), "Not yet ready to receive next message");
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let Some(message) = codec::try_recv(socket)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let maybe_response = codec::try_recv(socket)?;
|
||||
|
||||
anyhow::ensure!(message.header().qr(), "DNS message is a query!");
|
||||
|
||||
Ok(Some(message))
|
||||
Ok(maybe_response)
|
||||
}
|
||||
|
||||
@@ -6,17 +6,10 @@
|
||||
//! Source: <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use domain::{
|
||||
base::{iana::Rcode, Message, ParsedName, Rtype},
|
||||
rdata::AllRecordData,
|
||||
};
|
||||
use itertools::Itertools as _;
|
||||
use smoltcp::socket::tcp;
|
||||
|
||||
pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()> {
|
||||
let response = message.as_slice();
|
||||
|
||||
let dns_message_length = (response.len() as u16).to_be_bytes();
|
||||
pub fn try_send(socket: &mut tcp::Socket, message: &[u8]) -> Result<()> {
|
||||
let dns_message_length = (message.len() as u16).to_be_bytes();
|
||||
|
||||
let written = socket
|
||||
.send_slice(&dns_message_length)
|
||||
@@ -28,37 +21,40 @@ pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()>
|
||||
);
|
||||
|
||||
let written = socket
|
||||
.send_slice(response)
|
||||
.send_slice(message)
|
||||
.context("Failed to write DNS message")?;
|
||||
|
||||
anyhow::ensure!(
|
||||
written == response.len(),
|
||||
written == message.len(),
|
||||
"Not enough space in write buffer for DNS message"
|
||||
);
|
||||
|
||||
if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) {
|
||||
if let Some(ParsedMessage {
|
||||
qid,
|
||||
qname,
|
||||
qtype,
|
||||
response,
|
||||
rcode,
|
||||
records,
|
||||
}) = parse(message)
|
||||
{
|
||||
if response {
|
||||
let records = records.into_iter().join(" | ");
|
||||
tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
|
||||
} else {
|
||||
tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
// if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) {
|
||||
// if let Some(ParsedMessage {
|
||||
// qid,
|
||||
// qname,
|
||||
// qtype,
|
||||
// response,
|
||||
// rcode,
|
||||
// records,
|
||||
// }) = parse(message)
|
||||
// {
|
||||
// if response {
|
||||
// let records = records.into_iter().join(" | ");
|
||||
// tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
|
||||
// } else {
|
||||
// tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string());
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result<Option<Message<&'b [u8]>>> {
|
||||
pub fn try_recv<'b, M>(socket: &'b mut tcp::Socket) -> Result<Option<M>>
|
||||
where
|
||||
M: TryFrom<&'b [u8], Error: std::error::Error + Send + Sync + 'static>,
|
||||
{
|
||||
let maybe_message = socket
|
||||
.recv(|r| {
|
||||
// DNS over TCP has a 2-byte length prefix at the start, see <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
|
||||
@@ -70,72 +66,30 @@ pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result<Option<Message<&'b [u
|
||||
return (0, None); // Don't consume any bytes unless we can read the full message at once.
|
||||
}
|
||||
|
||||
(2 + dns_message_length, Some(Message::from_octets(message)))
|
||||
(2 + dns_message_length, Some(M::try_from(message)))
|
||||
})
|
||||
.context("Failed to recv TCP data")?
|
||||
.transpose()
|
||||
.context("Failed to parse DNS message")?;
|
||||
|
||||
if tracing::event_enabled!(target: "wire::dns::tcp::recv", tracing::Level::TRACE) {
|
||||
if let Some(ParsedMessage {
|
||||
qid,
|
||||
qname,
|
||||
qtype,
|
||||
rcode,
|
||||
response,
|
||||
records,
|
||||
}) = maybe_message.and_then(parse)
|
||||
{
|
||||
if response {
|
||||
let records = records.into_iter().join(" | ");
|
||||
tracing::trace!(target: "wire::dns::tcp::recv", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
|
||||
} else {
|
||||
tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
// if tracing::event_enabled!(target: "wire::dns::tcp::recv", tracing::Level::TRACE) {
|
||||
// if let Some(ParsedMessage {
|
||||
// qid,
|
||||
// qname,
|
||||
// qtype,
|
||||
// rcode,
|
||||
// response,
|
||||
// records,
|
||||
// }) = maybe_message.and_then(parse)
|
||||
// {
|
||||
// if response {
|
||||
// let records = records.into_iter().join(" | ");
|
||||
// tracing::trace!(target: "wire::dns::tcp::recv", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string());
|
||||
// } else {
|
||||
// tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string());
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
Ok(maybe_message)
|
||||
}
|
||||
|
||||
fn parse(message: Message<&[u8]>) -> Option<ParsedMessage<'_>> {
|
||||
let question = message.sole_question().ok()?;
|
||||
let answers = message.answer().ok()?;
|
||||
|
||||
let qtype = question.qtype();
|
||||
let qname = question.into_qname();
|
||||
let qid = message.header().id();
|
||||
let response = message.header().qr();
|
||||
let rcode = message.header().rcode();
|
||||
let records = answers
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
let data = r
|
||||
.ok()?
|
||||
.into_any_record::<AllRecordData<_, _>>()
|
||||
.ok()?
|
||||
.data()
|
||||
.clone();
|
||||
|
||||
Some(data)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Some(ParsedMessage {
|
||||
qid,
|
||||
qname,
|
||||
rcode,
|
||||
qtype,
|
||||
response,
|
||||
records,
|
||||
})
|
||||
}
|
||||
|
||||
struct ParsedMessage<'a> {
|
||||
qid: u16,
|
||||
qname: ParsedName<&'a [u8]>,
|
||||
qtype: Rtype,
|
||||
rcode: Rcode,
|
||||
response: bool,
|
||||
records: Vec<AllRecordData<&'a [u8], ParsedName<&'a [u8]>>>,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::{
|
||||
time::smol_now,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use domain::{base::Message, dep::octseq::OctetsInto as _};
|
||||
use ip_packet::IpPacket;
|
||||
use smoltcp::{
|
||||
iface::{Interface, PollResult, SocketHandle, SocketSet},
|
||||
@@ -38,7 +37,7 @@ pub struct Server {
|
||||
}
|
||||
|
||||
pub struct Query {
|
||||
pub message: Message<Vec<u8>>,
|
||||
pub message: dns_types::Query,
|
||||
/// The local address of the socket that received the query.
|
||||
pub local: SocketAddr,
|
||||
/// The remote address of the client that sent the query.
|
||||
@@ -151,16 +150,16 @@ impl Server {
|
||||
&mut self,
|
||||
src: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
message: Message<Vec<u8>>,
|
||||
response: dns_types::Response,
|
||||
) -> Result<()> {
|
||||
let handle = self
|
||||
.pending_sockets_by_local_remote_and_query_id
|
||||
.remove(&(src, dst, message.header().id()))
|
||||
.remove(&(src, dst, response.id()))
|
||||
.context("No pending query found for message")?;
|
||||
|
||||
let socket = self.sockets.get_mut::<tcp::Socket>(handle);
|
||||
|
||||
write_tcp_dns_response(socket, message.for_slice_ref())
|
||||
codec::try_send(socket, &response.into_bytes(u16::MAX))
|
||||
.inspect_err(|_| socket.abort()) // Abort socket on error.
|
||||
.context("Failed to write DNS response")?;
|
||||
|
||||
@@ -191,7 +190,7 @@ impl Server {
|
||||
while let Some(result) = try_recv_query(socket, local).transpose() {
|
||||
match result {
|
||||
Ok((message, remote)) => {
|
||||
let qid = message.header().id();
|
||||
let qid = message.id();
|
||||
|
||||
tracing::trace!(%local, %remote, %qid, "Received DNS query");
|
||||
|
||||
@@ -236,7 +235,7 @@ impl Server {
|
||||
fn try_recv_query(
|
||||
socket: &mut tcp::Socket,
|
||||
listen: SocketAddr,
|
||||
) -> Result<Option<(Message<Vec<u8>>, SocketAddr)>> {
|
||||
) -> Result<Option<(dns_types::Query, SocketAddr)>> {
|
||||
// smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket.
|
||||
// to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing.
|
||||
{
|
||||
@@ -271,28 +270,16 @@ fn try_recv_query(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let Some(message) = codec::try_recv(socket)? else {
|
||||
let Some(query) = codec::try_recv(socket)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
anyhow::ensure!(!message.header().qr(), "DNS message is a response!");
|
||||
|
||||
let message = message.octets_into();
|
||||
|
||||
let remote = socket
|
||||
.remote_endpoint()
|
||||
.context("Unknown remote endpoint despite having just received a message")?;
|
||||
|
||||
Ok(Some((
|
||||
message,
|
||||
query,
|
||||
SocketAddr::new(remote.addr.into(), remote.port),
|
||||
)))
|
||||
}
|
||||
|
||||
fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> {
|
||||
anyhow::ensure!(response.header().qr(), "DNS message is a query!");
|
||||
|
||||
codec::try_send(socket, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::{
|
||||
};
|
||||
|
||||
use dns_over_tcp::QueryResult;
|
||||
use domain::base::{iana::Rcode, Message, MessageBuilder, Name, Rtype};
|
||||
use dns_types::{Query, RecordType, ResponseBuilder, ResponseCode};
|
||||
|
||||
#[test]
|
||||
fn smoke() {
|
||||
@@ -26,7 +26,10 @@ fn smoke() {
|
||||
|
||||
for id in 0..5 {
|
||||
dns_client
|
||||
.send_query(resolver_addr, a_query("example.com", id))
|
||||
.send_query(
|
||||
resolver_addr,
|
||||
Query::new("example.com".parse().unwrap(), RecordType::A).with_id(id),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
@@ -41,16 +44,6 @@ fn smoke() {
|
||||
}
|
||||
}
|
||||
|
||||
fn a_query(domain: &str, id: u16) -> Message<Vec<u8>> {
|
||||
let mut builder = MessageBuilder::new_vec().question();
|
||||
builder.header_mut().set_id(id);
|
||||
builder
|
||||
.push((Name::vec_from_str(domain).unwrap(), Rtype::A))
|
||||
.unwrap();
|
||||
|
||||
builder.into_message()
|
||||
}
|
||||
|
||||
fn progress(
|
||||
dns_client: &mut dns_over_tcp::Client,
|
||||
dns_server: &mut dns_over_tcp::Server,
|
||||
@@ -67,13 +60,12 @@ fn progress(
|
||||
}
|
||||
|
||||
if let Some(query) = dns_server.poll_queries() {
|
||||
let response = MessageBuilder::new_vec()
|
||||
.start_answer(&query.message, Rcode::NXDOMAIN)
|
||||
.unwrap()
|
||||
.into_message();
|
||||
|
||||
dns_server
|
||||
.send_message(query.local, query.remote, response)
|
||||
.send_message(
|
||||
query.local,
|
||||
query.remote,
|
||||
ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(),
|
||||
)
|
||||
.unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use domain::base::{iana::Rcode, MessageBuilder};
|
||||
use dns_types::{ResponseBuilder, ResponseCode};
|
||||
use firezone_bin_shared::TunDeviceManager;
|
||||
use ip_network::Ipv4Network;
|
||||
use tokio::task::JoinSet;
|
||||
@@ -107,13 +107,12 @@ impl Eventloop {
|
||||
}
|
||||
|
||||
if let Some(query) = self.dns_server.poll_queries() {
|
||||
let response = MessageBuilder::new_vec()
|
||||
.start_answer(&query.message, Rcode::NXDOMAIN)
|
||||
.unwrap()
|
||||
.into_message();
|
||||
|
||||
self.dns_server
|
||||
.send_message(query.local, query.remote, response)
|
||||
.send_message(
|
||||
query.local,
|
||||
query.remote,
|
||||
ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(),
|
||||
)
|
||||
.unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
16
rust/dns-types/Cargo.toml
Normal file
16
rust/dns-types/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "dns-types"
|
||||
version = "0.1.0"
|
||||
edition = { workspace = true }
|
||||
license = { workspace = true }
|
||||
|
||||
[lib]
|
||||
path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
domain = { version = "0.10", features = ["serde"] } # Not a workspace dependency because we don't want any other crates to depend on it.
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
357
rust/dns-types/lib.rs
Normal file
357
rust/dns-types/lib.rs
Normal file
@@ -0,0 +1,357 @@
|
||||
#![cfg_attr(test, allow(clippy::unwrap_used))]
|
||||
|
||||
use domain::{
|
||||
base::{
|
||||
message_builder::AnswerBuilder, name::FlattenInto, HeaderCounts, Message, MessageBuilder,
|
||||
ParsedName, Question, RecordSection,
|
||||
},
|
||||
dep::octseq::OctetsInto,
|
||||
rdata::AllRecordData,
|
||||
};
|
||||
|
||||
pub mod prelude {
|
||||
// Re-export trait names so other crates can call the functions on them.
|
||||
// We don't export the name though so that it cannot conflict.
|
||||
pub use domain::base::name::FlattenInto as _;
|
||||
pub use domain::base::RecordData as _;
|
||||
pub use domain::base::ToName as _;
|
||||
}
|
||||
|
||||
pub const MAX_NAME_LEN: usize = domain::base::Name::MAX_LEN;
|
||||
|
||||
pub type RecordType = domain::base::iana::Rtype;
|
||||
|
||||
pub type DomainNameRef<'a> = domain::base::Name<&'a [u8]>;
|
||||
pub type Record<'a> =
|
||||
domain::base::Record<ParsedName<&'a [u8]>, AllRecordData<&'a [u8], ParsedName<&'a [u8]>>>;
|
||||
pub type RecordData<'a> = AllRecordData<&'a [u8], ParsedName<&'a [u8]>>;
|
||||
|
||||
pub type DomainName = domain::base::Name<Vec<u8>>;
|
||||
pub type OwnedRecord = domain::base::Record<DomainName, AllRecordData<Vec<u8>, DomainName>>;
|
||||
pub type OwnedRecordData = AllRecordData<Vec<u8>, DomainName>;
|
||||
|
||||
pub type ResponseCode = domain::base::iana::Rcode;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Query {
|
||||
inner: Message<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Query {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Query")
|
||||
.field("qid", &self.inner.header().id())
|
||||
.field("type", &self.qtype())
|
||||
.field("domain", &self.domain())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn parse(slice: &[u8]) -> Result<Self, Error> {
|
||||
let message = Message::from_octets(slice).map_err(|_| Error::TooShort)?;
|
||||
|
||||
if message.header().qr() {
|
||||
return Err(Error::NotAQuery);
|
||||
}
|
||||
|
||||
// We don't need to support multiple questions/qname in a single query because
|
||||
// nobody does it and since this run with each packet we want to squeeze as much optimization
|
||||
// as we can therefore we won't do it.
|
||||
//
|
||||
// See: https://stackoverflow.com/a/55093896
|
||||
let _ = message.sole_question()?; // Verify that there is exactly one question.
|
||||
|
||||
// Verify that we can parse the answers + all records
|
||||
for record in message.answer()? {
|
||||
record?.into_any_record::<AllRecordData<_, _>>()?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
inner: message.octets_into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(domain: DomainName, rtype: RecordType) -> Self {
|
||||
let mut inner = MessageBuilder::new_vec().question();
|
||||
inner.header_mut().set_qr(false);
|
||||
inner.header_mut().set_rd(true); // Default to recursion desired.
|
||||
inner.header_mut().set_random_id(); // Default to a random id.
|
||||
|
||||
inner
|
||||
.push((domain, rtype))
|
||||
.expect("Vec-backed message builder never fails");
|
||||
|
||||
Self {
|
||||
inner: inner.into_message(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_id(mut self, id: u16) -> Self {
|
||||
self.inner.header_mut().set_id(id);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn id(&self) -> u16 {
|
||||
self.inner.header().id()
|
||||
}
|
||||
|
||||
pub fn domain(&self) -> DomainName {
|
||||
self.question().into_qname().flatten_into()
|
||||
}
|
||||
|
||||
pub fn qtype(&self) -> RecordType {
|
||||
self.question().qtype()
|
||||
}
|
||||
|
||||
pub fn into_bytes(self) -> Vec<u8> {
|
||||
self.inner.into_octets()
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
self.inner.as_slice()
|
||||
}
|
||||
|
||||
fn question(&self) -> Question<ParsedName<&[u8]>> {
|
||||
self.inner.sole_question().expect("verified in ctor")
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for Query {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
|
||||
Self::parse(slice)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for Response {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
|
||||
Self::parse(slice)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
inner: Message<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Response {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Response")
|
||||
.field("qid", &self.inner.header().id())
|
||||
.field("domain", &self.domain())
|
||||
.field("type", &self.qtype())
|
||||
.field("response_code", &self.response_code())
|
||||
.finish_non_exhaustive() // TODO: Add records?
|
||||
}
|
||||
}
|
||||
|
||||
impl Response {
|
||||
/// Creates an empty, "NOERROR" response for the given query.
|
||||
pub fn no_error(query: &Query) -> Self {
|
||||
ResponseBuilder::for_query(query, ResponseCode::NOERROR).build()
|
||||
}
|
||||
|
||||
pub fn servfail(query: &Query) -> Self {
|
||||
ResponseBuilder::for_query(query, ResponseCode::SERVFAIL).build()
|
||||
}
|
||||
|
||||
pub fn nxdomain(query: &Query) -> Self {
|
||||
ResponseBuilder::for_query(query, ResponseCode::NXDOMAIN).build()
|
||||
}
|
||||
|
||||
pub fn parse(slice: &[u8]) -> Result<Self, Error> {
|
||||
let message = Message::from_octets(slice).map_err(|_| Error::TooShort)?;
|
||||
|
||||
if !message.header().qr() {
|
||||
return Err(Error::NotAResponse);
|
||||
}
|
||||
|
||||
let _ = message.sole_question()?; // Verify that there is exactly one question.
|
||||
|
||||
// Verify that we can parse the answers + all records
|
||||
for record in message.answer()? {
|
||||
record?.into_any_record::<AllRecordData<_, _>>()?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
inner: message.octets_into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> u16 {
|
||||
self.inner.header().id()
|
||||
}
|
||||
|
||||
pub fn truncated(&self) -> bool {
|
||||
self.inner.header().tc()
|
||||
}
|
||||
|
||||
pub fn domain(&self) -> DomainName {
|
||||
self.question().into_qname().flatten_into()
|
||||
}
|
||||
|
||||
pub fn qtype(&self) -> RecordType {
|
||||
self.question().qtype()
|
||||
}
|
||||
|
||||
pub fn response_code(&self) -> ResponseCode {
|
||||
self.inner.header().rcode()
|
||||
}
|
||||
|
||||
pub fn records(&self) -> impl Iterator<Item = Record<'_>> {
|
||||
self.answer().into_iter().map(|r| {
|
||||
r.expect("verified in ctor")
|
||||
.into_any_record::<AllRecordData<_, _>>()
|
||||
.expect("verified in ctor")
|
||||
})
|
||||
}
|
||||
|
||||
/// Serializes this response into a byte slice.
|
||||
///
|
||||
/// The `max_len` parameter specifies the maximum size of the payload.
|
||||
/// In case the payload is bigger than `max_len`, it will be truncated and the TC bit in the header will be set.
|
||||
pub fn into_bytes(mut self, max_len: u16) -> Vec<u8> {
|
||||
let qid = self.inner.header().id();
|
||||
|
||||
let len = self.inner.as_slice().len();
|
||||
if len <= max_len as usize {
|
||||
return self.inner.into_octets();
|
||||
}
|
||||
|
||||
tracing::debug!(%len, %max_len, %qid, domain = %self.domain(), "Truncating DNS response");
|
||||
|
||||
self.inner.header_mut().set_tc(true);
|
||||
|
||||
let start_of_answer = self.answer().pos();
|
||||
|
||||
let mut bytes = self.inner.into_octets();
|
||||
bytes.truncate(start_of_answer);
|
||||
|
||||
let headercounts = HeaderCounts::for_message_slice_mut(&mut bytes);
|
||||
|
||||
// We deleted everything after answers, reset all counts to 0.
|
||||
headercounts.as_slice_mut().fill(0);
|
||||
|
||||
// Set the question count to 1.
|
||||
headercounts.set_qdcount(1);
|
||||
|
||||
bytes
|
||||
}
|
||||
|
||||
fn question(&self) -> Question<ParsedName<&[u8]>> {
|
||||
self.inner.sole_question().expect("verified in ctor")
|
||||
}
|
||||
|
||||
fn answer(&self) -> RecordSection<'_, Vec<u8>> {
|
||||
self.inner.answer().expect("verified in ctor")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResponseBuilder {
|
||||
inner: AnswerBuilder<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl ResponseBuilder {
|
||||
pub fn for_query(query: &Query, code: ResponseCode) -> Self {
|
||||
let inner = MessageBuilder::new_vec()
|
||||
.start_answer(&query.inner, code)
|
||||
.expect("Vec-backed message builder never fails");
|
||||
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn with_records(mut self, records: impl IntoIterator<Item: Into<OwnedRecord>>) -> Self {
|
||||
for record in records {
|
||||
self.inner
|
||||
.push(record.into())
|
||||
.expect("Vec-backed message builder never fails");
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Response {
|
||||
Response {
|
||||
inner: self.inner.into_message(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("Bytes slice is too short to contain a message")]
|
||||
TooShort,
|
||||
#[error("DNS message is not a query")]
|
||||
NotAQuery,
|
||||
#[error("DNS message is not a response")]
|
||||
NotAResponse,
|
||||
#[error(transparent)]
|
||||
Parse(#[from] domain::base::wire::ParseError),
|
||||
}
|
||||
|
||||
pub mod records {
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use domain::rdata::{rfc1035::TxtError, Aaaa, Ptr, Srv, Txt, A};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn ptr(domain: DomainName) -> OwnedRecordData {
|
||||
OwnedRecordData::Ptr(Ptr::new(domain))
|
||||
}
|
||||
|
||||
pub fn a(ip: Ipv4Addr) -> OwnedRecordData {
|
||||
OwnedRecordData::A(A::new(ip))
|
||||
}
|
||||
|
||||
pub fn aaaa(ip: Ipv6Addr) -> OwnedRecordData {
|
||||
OwnedRecordData::Aaaa(Aaaa::new(ip))
|
||||
}
|
||||
|
||||
pub fn ip(ip: IpAddr) -> OwnedRecordData {
|
||||
match ip {
|
||||
IpAddr::V4(ip) => a(ip),
|
||||
IpAddr::V6(ip) => aaaa(ip),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn txt(content: Vec<u8>) -> Result<OwnedRecordData, TxtError> {
|
||||
Ok(OwnedRecordData::Txt(Txt::from_octets(content)?))
|
||||
}
|
||||
|
||||
pub fn srv(priority: u16, weight: u16, port: u16, target: DomainName) -> OwnedRecordData {
|
||||
OwnedRecordData::Srv(Srv::new(priority, weight, port, target))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_truncate_response() {
|
||||
let domain = DomainName::vec_from_str("example.com").unwrap();
|
||||
|
||||
let query = Query::new(domain.clone(), RecordType::A);
|
||||
let response = ResponseBuilder::for_query(&query, ResponseCode::NOERROR)
|
||||
.with_records(std::iter::repeat_n(
|
||||
(domain.clone(), 1, records::a(Ipv4Addr::LOCALHOST)),
|
||||
1000,
|
||||
))
|
||||
.build();
|
||||
|
||||
let bytes = response.into_bytes(1000);
|
||||
|
||||
let parsed_response = Response::parse(&bytes).unwrap();
|
||||
|
||||
assert!(parsed_response.truncated());
|
||||
assert_eq!(parsed_response.records().count(), 0);
|
||||
assert_eq!(parsed_response.domain(), domain);
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ chrono = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
dns-lookup = { workspace = true }
|
||||
domain = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
either = { workspace = true }
|
||||
firezone-bin-shared = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use boringtun::x25519::PublicKey;
|
||||
use connlib_model::DomainName;
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
use dns_lookup::{AddrInfoHints, AddrInfoIter, LookupError};
|
||||
use dns_types::DomainName;
|
||||
use firezone_bin_shared::TunDeviceManager;
|
||||
use firezone_logging::{telemetry_event, telemetry_span};
|
||||
use firezone_tunnel::messages::gateway::{
|
||||
|
||||
@@ -14,6 +14,7 @@ backoff = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive", "env", "string"] }
|
||||
connlib-client-shared = { workspace = true }
|
||||
connlib-model = { workspace = true }
|
||||
dns-types = { workspace = true }
|
||||
firezone-bin-shared = { workspace = true }
|
||||
firezone-logging = { workspace = true }
|
||||
firezone-telemetry = { workspace = true }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::DnsController;
|
||||
use anyhow::{bail, Context as _, Result};
|
||||
use connlib_model::DomainName;
|
||||
use dns_types::DomainName;
|
||||
use firezone_bin_shared::{platform::DnsControlMethod, TunDeviceManager};
|
||||
use std::{net::IpAddr, process::Command, str::FromStr};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use connlib_model::DomainName;
|
||||
use dns_types::DomainName;
|
||||
use std::{
|
||||
fs,
|
||||
io::{self, Write},
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
use super::DnsController;
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_model::DomainName;
|
||||
use dns_types::DomainName;
|
||||
use firezone_bin_shared::platform::{DnsControlMethod, CREATE_NO_WINDOW, TUNNEL_UUID};
|
||||
use firezone_bin_shared::windows::error::EPT_S_NOT_REGISTERED;
|
||||
use std::{io, net::IpAddr, os::windows::process::CommandExt, path::Path, process::Command};
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_client_shared::Callbacks;
|
||||
use connlib_model::{DomainName, ResourceView};
|
||||
use connlib_model::ResourceView;
|
||||
use dns_types::DomainName;
|
||||
use firezone_bin_shared::platform::DnsControlMethod;
|
||||
use firezone_logging::FilterReloadHandle;
|
||||
use std::{
|
||||
|
||||
@@ -43,9 +43,9 @@ pub const MAX_IP_SIZE: usize = 1280;
|
||||
///
|
||||
/// IPv6 headers are always a fixed size whereas IPv4 headers can vary.
|
||||
/// The max length of an IPv4 header is > the fixed length of an IPv6 header.
|
||||
pub const MAX_IP_PAYLOAD: usize = MAX_IP_SIZE - etherparse::Ipv4Header::MAX_LEN;
|
||||
pub const MAX_IP_PAYLOAD: u16 = (MAX_IP_SIZE - etherparse::Ipv4Header::MAX_LEN) as u16;
|
||||
/// The maximum payload a UDP packet can have.
|
||||
pub const MAX_UDP_PAYLOAD: usize = MAX_IP_PAYLOAD - etherparse::UdpHeader::LEN;
|
||||
pub const MAX_UDP_PAYLOAD: u16 = MAX_IP_PAYLOAD - etherparse::UdpHeader::LEN as u16;
|
||||
|
||||
/// The maximum size of the payload that Firezone will send between nodes.
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user