diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml index 2ee2ba0f5..0f91bc668 100644 --- a/.github/workflows/_rust.yml +++ b/.github/workflows/_rust.yml @@ -114,7 +114,7 @@ jobs: rg --count --no-ignore "Packet for Internet resource" $TESTCASES_DIR rg --count --no-ignore "Performed IP-NAT46" $TESTCASES_DIR rg --count --no-ignore "Performed IP-NAT64" $TESTCASES_DIR - rg --count --no-ignore "Too big DNS response, truncating" $TESTCASES_DIR + rg --count --no-ignore "Truncating DNS response" $TESTCASES_DIR rg --count --no-ignore "Destination is unreachable" $TESTCASES_DIR rg --count --no-ignore "Forwarding query for DNS resource to corresponding site" $TESTCASES_DIR rg --count --no-ignore "Expanded single-label query into FQDN using search-domain" $TESTCASES_DIR diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 48efc1427..26a47d049 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1044,6 +1044,7 @@ dependencies = [ "backoff", "connlib-client-shared", "connlib-model", + "dns-types", "firezone-logging", "firezone-telemetry", "flume", @@ -1075,6 +1076,7 @@ dependencies = [ "backoff", "connlib-client-shared", "connlib-model", + "dns-types", "firezone-logging", "firezone-telemetry", "flume", @@ -1107,6 +1109,7 @@ dependencies = [ "bimap", "chrono", "connlib-model", + "dns-types", "firezone-logging", "firezone-tunnel", "ip_network", @@ -1130,7 +1133,6 @@ name = "connlib-model" version = "0.1.0" dependencies = [ "boringtun", - "domain", "ip_network", "itertools 0.13.0", "serde", @@ -1706,13 +1708,12 @@ name = "dns-over-tcp" version = "0.1.0" dependencies = [ "anyhow", - "domain", + "dns-types", "firezone-bin-shared", "firezone-logging", "futures", "ip-packet", "ip_network", - "itertools 0.13.0", "rand 0.8.5", "smoltcp", "tokio", @@ -1720,6 +1721,15 @@ dependencies = [ "tun", ] +[[package]] +name = "dns-types" +version = "0.1.0" +dependencies = [ + "domain", + "thiserror 1.0.69", + "tracing", +] + [[package]] name = "domain" version = "0.10.3" @@ -2013,7 +2023,7 @@ dependencies = [ "clap", "connlib-model", "dns-lookup", - "domain", + "dns-types", "either", "firezone-bin-shared", "firezone-logging", @@ -2141,6 +2151,7 @@ dependencies = [ "connlib-client-shared", "connlib-model", "dirs 5.0.1", + "dns-types", "firezone-bin-shared", "firezone-logging", "firezone-telemetry", @@ -2274,7 +2285,7 @@ dependencies = [ "derive_more 1.0.0", "divan", "dns-over-tcp", - "domain", + "dns-types", "firezone-logging", "firezone-relay", "futures", @@ -3514,7 +3525,7 @@ name = "l4-tcp-dns-server" version = "0.1.0" dependencies = [ "anyhow", - "domain", + "dns-types", "futures", "tokio", "tracing", @@ -3525,7 +3536,7 @@ name = "l4-udp-dns-server" version = "0.1.0" dependencies = [ "anyhow", - "domain", + "dns-types", "futures", "tokio", "tracing", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 0adbe9eb8..6de2df915 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -10,6 +10,7 @@ members = [ "connlib/snownet", "connlib/tunnel", "dns-over-tcp", + "dns-types", "gateway", "gui-client/src-common", "gui-client/src-tauri", @@ -50,7 +51,6 @@ difference = "2.0.0" dirs = "5.0.1" divan = "0.1.17" dns-lookup = "2.0" -domain = { version = "0.10", features = ["serde"] } either = "1" env_logger = "0.11.6" etherparse = "0.16" @@ -164,6 +164,7 @@ snownet = { path = "connlib/snownet" } l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" } l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" } dns-over-tcp = { path = "dns-over-tcp" } +dns-types = { path = "dns-types" } firezone-relay = { path = "relay" } connlib-model = { path = "connlib/model" } firezone-tunnel = { path = "connlib/tunnel" } diff --git a/rust/connlib/clients/android/Cargo.toml b/rust/connlib/clients/android/Cargo.toml index d3ee7d7b8..be9b73529 100644 --- a/rust/connlib/clients/android/Cargo.toml +++ b/rust/connlib/clients/android/Cargo.toml @@ -15,6 +15,7 @@ anyhow = { workspace = true } backoff = { workspace = true } connlib-client-shared = { workspace = true } connlib-model = { workspace = true } +dns-types = { workspace = true } firezone-logging = { workspace = true } firezone-telemetry = { workspace = true } flume = { workspace = true } diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 63a929be3..4865487cd 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -9,7 +9,8 @@ use crate::tun::Tun; use anyhow::{Context as _, Result}; use backoff::ExponentialBackoffBuilder; use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; -use connlib_model::{DomainName, ResourceView}; +use connlib_model::ResourceView; +use dns_types::DomainName; use firezone_logging::{err_with_src, sentry_layer}; use firezone_telemetry::{Telemetry, ANDROID_DSN}; use ip_network::{Ipv4Network, Ipv6Network}; diff --git a/rust/connlib/clients/apple/Cargo.toml b/rust/connlib/clients/apple/Cargo.toml index 186f1d134..e1db51a5e 100644 --- a/rust/connlib/clients/apple/Cargo.toml +++ b/rust/connlib/clients/apple/Cargo.toml @@ -13,6 +13,7 @@ anyhow = { workspace = true } backoff = { workspace = true } connlib-client-shared = { workspace = true } connlib-model = { workspace = true } +dns-types = { workspace = true } firezone-logging = { workspace = true } firezone-telemetry = { workspace = true } flume = { workspace = true } diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 2c4598254..0b9e8f86a 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -9,8 +9,8 @@ use anyhow::Context; use anyhow::Result; use backoff::ExponentialBackoffBuilder; use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; -use connlib_model::DomainName; use connlib_model::ResourceView; +use dns_types::DomainName; use firezone_logging::err_with_src; use firezone_logging::sentry_layer; use firezone_telemetry::Telemetry; diff --git a/rust/connlib/clients/shared/Cargo.toml b/rust/connlib/clients/shared/Cargo.toml index 9b4220be7..4661c391a 100644 --- a/rust/connlib/clients/shared/Cargo.toml +++ b/rust/connlib/clients/shared/Cargo.toml @@ -9,6 +9,7 @@ anyhow = { workspace = true } backoff = { workspace = true } bimap = { workspace = true } connlib-model = { workspace = true } +dns-types = { workspace = true } firezone-logging = { workspace = true } firezone-tunnel = { workspace = true } ip_network = { workspace = true } diff --git a/rust/connlib/clients/shared/src/callbacks.rs b/rust/connlib/clients/shared/src/callbacks.rs index 3d4002aff..3704d8b65 100644 --- a/rust/connlib/clients/shared/src/callbacks.rs +++ b/rust/connlib/clients/shared/src/callbacks.rs @@ -1,4 +1,5 @@ -use connlib_model::{DomainName, ResourceView}; +use connlib_model::ResourceView; +use dns_types::DomainName; use ip_network::{Ipv4Network, Ipv6Network}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, diff --git a/rust/connlib/l4-tcp-dns-server/Cargo.toml b/rust/connlib/l4-tcp-dns-server/Cargo.toml index e26842e82..9e51d950f 100644 --- a/rust/connlib/l4-tcp-dns-server/Cargo.toml +++ b/rust/connlib/l4-tcp-dns-server/Cargo.toml @@ -9,7 +9,7 @@ path = "lib.rs" [dependencies] anyhow = { workspace = true } -domain = { workspace = true } +dns-types = { workspace = true } futures = { workspace = true } tokio = { workspace = true, features = ["net", "io-util"] } tracing = { workspace = true } diff --git a/rust/connlib/l4-tcp-dns-server/lib.rs b/rust/connlib/l4-tcp-dns-server/lib.rs index 2a05d41e5..d843ea3e8 100644 --- a/rust/connlib/l4-tcp-dns-server/lib.rs +++ b/rust/connlib/l4-tcp-dns-server/lib.rs @@ -2,8 +2,7 @@ #![cfg_attr(test, allow(clippy::unwrap_used))] -use anyhow::{anyhow, Context as _, Result}; -use domain::base::Message; +use anyhow::{Context as _, Result}; use futures::{ future::BoxFuture, stream::FuturesUnordered, task::AtomicWaker, FutureExt, StreamExt as _, }; @@ -31,7 +30,7 @@ pub struct Server { /// A set of futures that read DNS queries from TCP streams. #[expect(clippy::type_complexity, reason = "We don't care.")] reading_tcp_queries: FuturesUnordered< - BoxFuture<'static, Result>, TcpStream)>>>, + BoxFuture<'static, Result>>, >, /// A set of futures that send DNS responses over TCP streams. sending_tcp_responses: FuturesUnordered>>, @@ -68,7 +67,11 @@ impl Server { Ok(()) } - pub fn send_response(&mut self, to: SocketAddr, response: Message>) -> io::Result<()> { + pub fn send_response( + &mut self, + to: SocketAddr, + response: dns_types::Response, + ) -> io::Result<()> { let mut stream = self .tcp_streams_by_remote .remove(&to) @@ -76,7 +79,9 @@ impl Server { self.sending_tcp_responses.push( async move { - let len = response.as_slice().len() as u16; + let response = response.into_bytes(u16::MAX); // DNS over TCP has a 16-bit length field, we can't encode anything bigger than that. + + let len = response.len() as u16; let len = len.to_be_bytes(); stream @@ -84,7 +89,7 @@ impl Server { .await .context("Failed to write TCP DNS header")?; stream - .write_all(response.as_slice()) + .write_all(&response) .await .context("Failed to write TCP DNS message")?; @@ -161,7 +166,7 @@ fn anyhow_to_io(e: anyhow::Error) -> io::Error { async fn read_tcp_query( mut stream: TcpStream, from: SocketAddr, -) -> Result>, TcpStream)>> { +) -> Result> { let mut buf = [0; 2]; match stream.read_exact(&mut buf).await { Ok(2) => {} @@ -178,8 +183,7 @@ async fn read_tcp_query( .await .context("Failed to read TCP DNS message")?; - let message = - Message::try_from_octets(buf).map_err(|_| anyhow!("Failed to parse DNS message"))?; + let message = dns_types::Query::parse(&buf).context("Failed to parse DNS message")?; Ok(Some((from, message, stream))) } @@ -187,7 +191,7 @@ async fn read_tcp_query( pub struct Query { pub local: SocketAddr, pub remote: SocketAddr, - pub message: Message>, + pub message: dns_types::Query, } fn make_tcp_listener(socket: impl ToSocketAddrs) -> Result { @@ -205,7 +209,6 @@ fn make_tcp_listener(socket: impl ToSocketAddrs) -> Result { #[cfg(all(test, unix))] mod tests { - use domain::base::{iana::Rcode, MessageBuilder}; use std::future::poll_fn; use std::net::{Ipv4Addr, Ipv6Addr}; use std::process::ExitStatus; @@ -227,7 +230,7 @@ mod tests { let query = poll_fn(|cx| server.poll(cx)).await.unwrap(); server - .send_response(query.remote, empty_dns_response(query.message)) + .send_response(query.remote, dns_types::Response::no_error(&query.message)) .unwrap(); } }); @@ -245,13 +248,6 @@ mod tests { server_task.abort(); } - fn empty_dns_response(message: Message>) -> Message> { - MessageBuilder::new_vec() - .start_answer(&message, Rcode::NOERROR) - .unwrap() - .into_message() - } - async fn dig(server: SocketAddr) -> ExitStatus { tokio::process::Command::new("dig") .arg(format!("@{}", server.ip())) diff --git a/rust/connlib/l4-udp-dns-server/Cargo.toml b/rust/connlib/l4-udp-dns-server/Cargo.toml index b411da673..5da5731fa 100644 --- a/rust/connlib/l4-udp-dns-server/Cargo.toml +++ b/rust/connlib/l4-udp-dns-server/Cargo.toml @@ -9,7 +9,7 @@ path = "lib.rs" [dependencies] anyhow = { workspace = true } -domain = { workspace = true } +dns-types = { workspace = true } futures = { workspace = true } tokio = { workspace = true, features = ["net"] } tracing = { workspace = true } diff --git a/rust/connlib/l4-udp-dns-server/lib.rs b/rust/connlib/l4-udp-dns-server/lib.rs index 820e70655..a30d529e2 100644 --- a/rust/connlib/l4-udp-dns-server/lib.rs +++ b/rust/connlib/l4-udp-dns-server/lib.rs @@ -2,8 +2,7 @@ #![cfg_attr(test, allow(clippy::unwrap_used))] -use anyhow::{anyhow, Context as _, Result}; -use domain::base::Message; +use anyhow::{Context as _, Result}; use futures::{ future::BoxFuture, stream::{self, BoxStream, FuturesUnordered}, @@ -24,8 +23,8 @@ pub struct Server { udp_v6: Option>, // Streams that read incoming queries from the UDP sockets. - reading_udp_v4_queries: BoxStream<'static, Result<(SocketAddr, Message>)>>, - reading_udp_v6_queries: BoxStream<'static, Result<(SocketAddr, Message>)>>, + reading_udp_v4_queries: BoxStream<'static, Result<(SocketAddr, dns_types::Query)>>, + reading_udp_v6_queries: BoxStream<'static, Result<(SocketAddr, dns_types::Query)>>, // Futures that send responses on the UDP sockets. sending_udp_v4_responses: FuturesUnordered>>, @@ -69,7 +68,11 @@ impl Server { Ok(()) } - pub fn send_response(&mut self, to: SocketAddr, response: Message>) -> io::Result<()> { + pub fn send_response( + &mut self, + to: SocketAddr, + response: dns_types::Response, + ) -> io::Result<()> { let (udp_socket, workers) = match (to, self.udp_v4.clone(), self.udp_v6.clone()) { (SocketAddr::V4(_), Some(socket), _) => (socket, &mut self.sending_udp_v4_responses), (SocketAddr::V6(_), _, Some(socket)) => (socket, &mut self.sending_udp_v6_responses), @@ -79,8 +82,13 @@ impl Server { workers.push( async move { + // TODO: Make this limit configurable. + // The current 1200 are conservative and should be safe for the public Internet and our WireGuard tunnel. + // Worst-case, the client will re-query over TCP. + let payload = response.into_bytes(1200); + udp_socket - .send_to(response.as_slice(), to) + .send_to(&payload, to) .await .context("Failed to send UDP response")?; @@ -141,7 +149,7 @@ impl Server { /// Produces a stream of incoming DNS queries from a UDP socket for as long as there is at least one strong reference to the socket. fn udp_dns_query_stream( udp_socket: Weak, -) -> BoxStream<'static, Result<(SocketAddr, Message>)>> { +) -> BoxStream<'static, Result<(SocketAddr, dns_types::Query)>> { stream::repeat(udp_socket) // We start with an infinite stream of weak references to the UDP socket. .filter_map(|udp_socket| async move { udp_socket.upgrade() }) // For each item pulled from the stream, we first try to upgrade to a strong reference. .then(read_udp_query) // And then read single DNS query from the socket. @@ -152,7 +160,7 @@ fn anyhow_to_io(e: anyhow::Error) -> io::Error { io::Error::other(format!("{e:#}")) } -async fn read_udp_query(socket: Arc) -> Result<(SocketAddr, Message>)> { +async fn read_udp_query(socket: Arc) -> Result<(SocketAddr, dns_types::Query)> { let mut buffer = vec![0u8; 2000]; // On the public Internet, any MTU > 1500 is very unlikely so 2000 is a safe bet. let (len, from) = socket @@ -162,15 +170,14 @@ async fn read_udp_query(socket: Arc) -> Result<(SocketAddr, Message>, + pub message: dns_types::Query, } fn make_udp_socket(socket: impl ToSocketAddrs) -> Result { @@ -201,7 +208,6 @@ impl Default for Server { #[cfg(all(test, unix))] mod tests { - use domain::base::{iana::Rcode, MessageBuilder}; use std::future::poll_fn; use std::net::{Ipv4Addr, Ipv6Addr}; use std::process::ExitStatus; @@ -223,7 +229,7 @@ mod tests { let query = poll_fn(|cx| server.poll(cx)).await.unwrap(); server - .send_response(query.source, empty_dns_response(query.message)) + .send_response(query.source, dns_types::Response::no_error(&query.message)) .unwrap(); } }); @@ -241,13 +247,6 @@ mod tests { server_task.abort(); } - fn empty_dns_response(message: Message>) -> Message> { - MessageBuilder::new_vec() - .start_answer(&message, Rcode::NOERROR) - .unwrap() - .into_message() - } - async fn dig(server: SocketAddr) -> ExitStatus { tokio::process::Command::new("dig") .arg(format!("@{}", server.ip())) diff --git a/rust/connlib/model/Cargo.toml b/rust/connlib/model/Cargo.toml index ba14e080e..4509b50ea 100644 --- a/rust/connlib/model/Cargo.toml +++ b/rust/connlib/model/Cargo.toml @@ -6,7 +6,6 @@ license = { workspace = true } [dependencies] boringtun = { workspace = true } -domain = { workspace = true } ip_network = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive", "std"] } uuid = { workspace = true, features = ["std", "v4", "serde"] } diff --git a/rust/connlib/model/src/lib.rs b/rust/connlib/model/src/lib.rs index 87c7bfe62..49615651c 100644 --- a/rust/connlib/model/src/lib.rs +++ b/rust/connlib/model/src/lib.rs @@ -13,9 +13,6 @@ pub use view::{ CidrResourceView, DnsResourceView, InternetResourceView, ResourceStatus, ResourceView, }; -pub type DomainName = domain::base::Name>; -pub type DomainRecord = domain::rdata::AllRecordData, DomainName>; - use serde::{Deserialize, Serialize}; use std::fmt; use std::str::FromStr; diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index ada1f7af5..d2797a56c 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -15,7 +15,7 @@ connlib-model = { workspace = true } derive_more = { workspace = true, features = ["debug"] } divan = { workspace = true, optional = true } dns-over-tcp = { workspace = true } -domain = { workspace = true } +dns-types = { workspace = true } firezone-logging = { workspace = true } futures = { workspace = true } futures-bounded = { workspace = true } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 272f42f5d..899ea9617 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,5 +1,6 @@ mod resource; +use dns_types::DomainName; pub(crate) use resource::{CidrResource, Resource}; #[cfg(all(feature = "proptest", test))] pub(crate) use resource::{DnsResource, InternetResource}; @@ -14,9 +15,7 @@ use crate::unique_packet_buffer::UniquePacketBuffer; use crate::{dns, is_peer, p2p_control, IpConfig, TunConfig, IPV4_TUNNEL, IPV6_TUNNEL}; use anyhow::Context; use bimap::BiMap; -use connlib_model::{ - DomainName, GatewayId, PublicKey, RelayId, ResourceId, ResourceStatus, ResourceView, -}; +use connlib_model::{GatewayId, PublicKey, RelayId, ResourceId, ResourceStatus, ResourceView}; use connlib_model::{Site, SiteId}; use firezone_logging::{err_with_src, telemetry_event, unwrap_or_debug, unwrap_or_warn}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; @@ -26,7 +25,6 @@ use itertools::Itertools; use crate::peer::GatewayOnClient; use crate::ClientEvent; -use domain::base::Message; use lru::LruCache; use secrecy::{ExposeSecret as _, Secret}; use snownet::{ClientNode, NoTurnServers, RelaySocket, Transmit}; @@ -444,17 +442,12 @@ impl ClientState { /// We call this function every time a client issues a DNS query for a certain domain. /// Coupling this behaviour together allows a client to refresh the DNS resolution of a DNS resource on the Gateway /// through local DNS resolutions. - fn clear_dns_resource_nat_for_domain(&mut self, message: Message<&[u8]>) { - let Ok(question) = message.sole_question() else { - return; - }; - let domain = question.into_qname(); - + fn clear_dns_resource_nat_for_domain(&mut self, message: &dns_types::Response) { let mut any_deleted = false; self.dns_resource_nat_by_gateway .retain(|(_, candidate), _| { - if candidate == &domain { + if candidate == &message.domain() { any_deleted = true; return false; } @@ -463,7 +456,7 @@ impl ClientState { }); if any_deleted { - tracing::debug!(%domain, "Cleared DNS resource NAT"); + tracing::debug!(domain = %message.domain(), "Cleared DNS resource NAT"); } } @@ -551,16 +544,11 @@ impl ClientState { } pub(crate) fn handle_dns_response(&mut self, response: dns::RecursiveResponse) { - let qid = response.query.header().id(); + let qid = response.query.id(); let server = response.server; - let domain = response - .query - .sole_question() - .ok() - .map(|q| q.into_qname()) - .map(tracing::field::display); + let domain = response.query.domain(); - let _span = tracing::debug_span!("handle_dns_response", %qid, %server, domain).entered(); + let _span = tracing::debug_span!("handle_dns_response", %qid, %server, %domain).entered(); match (response.transport, response.message) { (dns::Transport::Udp { .. }, Err(e)) if e.kind() == io::ErrorKind::TimedOut => { @@ -571,14 +559,14 @@ impl ClientState { .inspect(|message| { tracing::trace!("Received recursive UDP DNS response"); - if message.header().tc() { + if message.truncated() { tracing::debug!("Upstream DNS server had to truncate response"); } }) .unwrap_or_else(|e| { telemetry_event!("Recursive UDP DNS query failed: {}", err_with_src(&e)); - dns::servfail(response.query.for_slice_ref()) + dns_types::Response::servfail(&response.query) }); unwrap_or_warn!( @@ -594,7 +582,7 @@ impl ClientState { .unwrap_or_else(|e| { telemetry_event!("Recursive TCP DNS query failed: {}", err_with_src(&e)); - dns::servfail(response.query.for_slice_ref()) + dns_types::Response::servfail(&response.query) }); unwrap_or_warn!( @@ -665,7 +653,7 @@ impl ClientState { &mut self, from: SocketAddr, dst: SocketAddr, - message: Message>, + message: dns_types::Response, ) -> anyhow::Result<()> { let saddr = *self .dns_mapping @@ -677,7 +665,7 @@ impl ClientState { dst.ip(), DNS_PORT, dst.port(), - truncate_dns_response(message), + message.into_bytes(MAX_UDP_PAYLOAD), )?; self.buffered_packets.push_back(ip_packet); @@ -1144,7 +1132,7 @@ impl ClientState { // Check if the client has assembled a response to a query. if let Some(query_result) = self.tcp_dns_client.poll_query_result() { let server = query_result.server; - let qid = query_result.query.header().id(); + let qid = query_result.query.id(); let known_sockets = &mut self.tcp_dns_streams_by_upstream_and_query_id; let Some((local, remote)) = known_sockets.remove(&(server, qid)) else { @@ -1188,7 +1176,7 @@ impl ClientState { return ControlFlow::Break(()); } - let message = match Message::from_octets(datagram.payload()) { + let message = match dns_types::Query::parse(datagram.payload()) { Ok(message) => message, Err(e) => { tracing::warn!(?packet, "Failed to parse DNS query: {e:#}"); @@ -1198,9 +1186,9 @@ impl ClientState { let source = SocketAddr::new(packet.source(), datagram.source_port()); - match self.stub_resolver.handle(message) { + match self.stub_resolver.handle(&message) { dns::ResolveStrategy::LocalResponse(response) => { - self.clear_dns_resource_nat_for_domain(response.for_slice_ref()); + self.clear_dns_resource_nat_for_domain(&response); self.update_dns_resource_nat(now, iter::empty()); unwrap_or_debug!( @@ -1215,7 +1203,7 @@ impl ClientState { return ControlFlow::Continue(packet); } - let query_id = message.header().id(); + let query_id = message.id(); tracing::trace!(server = %upstream, %query_id, "Forwarding UDP DNS query directly via host"); @@ -1261,7 +1249,7 @@ impl ClientState { return; } - let message = match Message::from_octets(datagram.payload()) { + let message = match dns_types::Query::parse(datagram.payload()) { Ok(message) => message, Err(e) => { tracing::warn!(?packet, "Failed to parse DNS query: {e:#}"); @@ -1269,9 +1257,9 @@ impl ClientState { } }; - match self.stub_resolver.handle(message) { + match self.stub_resolver.handle(&message) { dns::ResolveStrategy::LocalResponse(response) => { - self.clear_dns_resource_nat_for_domain(response.for_slice_ref()); + self.clear_dns_resource_nat_for_domain(&response); self.update_dns_resource_nat(now, iter::empty()); let maybe_packet = ip_packet::make::udp_packet( @@ -1279,7 +1267,7 @@ impl ClientState { packet.source(), datagram.destination_port(), datagram.source_port(), - truncate_dns_response(response), + response.into_bytes(MAX_UDP_PAYLOAD), ) .inspect_err(|e| { tracing::debug!("Failed to create LLMNR DNS response packet: {e:#}"); @@ -1309,9 +1297,8 @@ impl ClientState { .expect("to be a valid UDP packet at this point"); let dst_port = datagram.destination_port(); - let query_id = Message::from_octets(datagram.payload()) + let query_id = dns_types::Query::parse(datagram.payload()) .expect("to be a valid DNS query at this point") - .header() .id(); let connlib_dns_server = SocketAddr::new(dst_ip, dst_port); @@ -1336,7 +1323,7 @@ impl ClientState { } fn handle_tcp_dns_query(&mut self, query: dns_over_tcp::Query, now: Instant) { - let query_id = query.message.header().id(); + let query_id = query.message.id(); let Some(upstream) = self.dns_mapping.get_by_left(&query.local.ip()) else { // This is highly-unlikely but might be possible if our DNS mapping changes whilst the TCP DNS server is processing a request. @@ -1344,9 +1331,9 @@ impl ClientState { }; let server = upstream.address(); - match self.stub_resolver.handle(query.message.for_slice_ref()) { + match self.stub_resolver.handle(&query.message) { dns::ResolveStrategy::LocalResponse(response) => { - self.clear_dns_resource_nat_for_domain(response.for_slice_ref()); + self.clear_dns_resource_nat_for_domain(&response); self.update_dns_resource_nat(now, iter::empty()); unwrap_or_debug!( @@ -1396,7 +1383,7 @@ impl ClientState { server: SocketAddr, query: dns_over_tcp::Query, ) { - let query_id = query.message.header().id(); + let query_id = query.message.id(); match self .tcp_dns_client @@ -1412,7 +1399,7 @@ impl ClientState { self.tcp_dns_server.send_message( query.local, query.remote, - dns::servfail(query.message.for_slice_ref()) + dns_types::Response::servfail(&query.message) ), "Failed to send TCP DNS response: {}" ); @@ -2002,23 +1989,17 @@ fn maybe_mangle_dns_response_from_upstream_dns_server( let src_port = udp.source_port(); let src_socket = SocketAddr::new(src_ip, src_port); - let Ok(message) = domain::base::Message::from_slice(udp.payload()) else { + let Ok(message) = dns_types::Response::parse(udp.payload()) else { return packet; }; let Some(original_dst) = - udp_dns_sockets_by_upstream_and_query_id.remove(&(src_socket, message.header().id())) + udp_dns_sockets_by_upstream_and_query_id.remove(&(src_socket, message.id())) else { return packet; }; - let domain = message - .sole_question() - .ok() - .map(|q| q.into_qname()) - .map(tracing::field::display); - - tracing::trace!(server = %src_ip, query_id = %message.header().id(), domain, "Received UDP DNS response via tunnel"); + tracing::trace!(server = %src_ip, query_id = %message.id(), domain = %message.domain(), "Received UDP DNS response via tunnel"); packet.set_src(original_dst.ip()); packet @@ -2031,28 +2012,6 @@ fn maybe_mangle_dns_response_from_upstream_dns_server( packet } -fn truncate_dns_response(mut message: Message>) -> Vec { - let message_length = message.as_octets().len(); - if message_length <= MAX_UDP_PAYLOAD { - return message.into_octets(); - } - - tracing::debug!(?message, %message_length, "Too big DNS response, truncating"); - - message.header_mut().set_tc(true); - - let message_truncation = match message.answer() { - Ok(answer) if answer.pos() <= MAX_UDP_PAYLOAD => answer.pos(), - // This should be very unlikely or impossible. - _ => message.question().pos(), - }; - - let mut message_bytes = message.into_octets(); - message_bytes.truncate(message_truncation); - - message_bytes -} - /// What triggered us to establish a connection to a Gateway. enum ConnectionTrigger { /// A packet received on the TUN device with a destination IP that maps to one of our resources. diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index 9036e990a..134c757c1 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -1,8 +1,4 @@ -use domain::base::iana::Rcode; -use domain::base::{Message, ParsedName, Rtype}; -use domain::rdata::AllRecordData; use ip_packet::IpPacket; -use itertools::Itertools; use std::io; use std::task::{Context, Poll, Waker}; use tracing::Level; @@ -46,8 +42,8 @@ impl Device { for packet in &buf[..n] { if tracing::event_enabled!(target: "wire::dns::qry", Level::TRACE) { - if let Some((qtype, qname, qid)) = parse_dns_query(packet) { - tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string()); + if let Some(query) = parse_dns_query(packet) { + tracing::trace!(target: "wire::dns::qry", ?query); } } @@ -72,8 +68,8 @@ impl Device { pub fn send(&mut self, packet: IpPacket) -> io::Result<()> { if tracing::event_enabled!(target: "wire::dns::res", Level::TRACE) { - if let Some((qtype, qname, records, rcode, qid)) = parse_dns_response(&packet) { - tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + if let Some(response) = parse_dns_response(&packet) { + tracing::trace!(target: "wire::dns::res", ?response); } } @@ -102,62 +98,20 @@ fn io_error_not_initialized() -> io::Error { io::Error::new(io::ErrorKind::NotConnected, "device is not initialized yet") } -fn parse_dns_query(packet: &IpPacket) -> Option<(Rtype, ParsedName<&[u8]>, u16)> { +fn parse_dns_query(packet: &IpPacket) -> Option { let udp = packet.as_udp()?; if udp.destination_port() != crate::dns::DNS_PORT { return None; } - let message = &Message::from_slice(udp.payload()).ok()?; - - if message.header().qr() { - return None; - } - - let question = message.sole_question().ok()?; - - let qtype = question.qtype(); - let qname = question.into_qname(); - let id = message.header().id(); - - Some((qtype, qname, id)) + dns_types::Query::parse(udp.payload()).ok() } -#[expect(clippy::type_complexity)] -fn parse_dns_response(packet: &IpPacket) -> Option<(Rtype, ParsedName<&[u8]>, String, Rcode, u16)> { +fn parse_dns_response(packet: &IpPacket) -> Option { let udp = packet.as_udp()?; if udp.source_port() != crate::dns::DNS_PORT { return None; } - let message = &Message::from_slice(udp.payload()).ok()?; - - if !message.header().qr() { - return None; - } - - let question = message.sole_question().ok()?; - - let qtype = question.qtype(); - let qname = question.into_qname(); - let rcode = message.header().rcode(); - - let record_section = message.answer().ok()?; - - let records = record_section - .into_iter() - .filter_map(|r| { - let data = r - .ok()? - .into_any_record::>() - .ok()? - .data() - .clone(); - - Some(data) - }) - .join(" | "); - let id = message.header().id(); - - Some((qtype, qname, records, rcode, id)) + dns_types::Response::parse(udp.payload()).ok() } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 61dd14f8d..6290470d9 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,21 +1,15 @@ use crate::client::IpProvider; -use anyhow::{Context, Result}; -use connlib_model::{DomainName, ResourceId}; -use domain::base::name::FlattenInto; -use domain::rdata::AllRecordData; -use domain::{ - base::{ - iana::{Class, Rcode, Rtype}, - Message, MessageBuilder, ToName, - }, - dep::octseq::OctetsInto, +use anyhow::{Context as _, Result}; +use connlib_model::ResourceId; +use dns_types::{ + prelude::*, DomainName, DomainNameRef, OwnedRecordData, Query, RecordType, Response, + ResponseBuilder, ResponseCode, }; use firezone_logging::{err_with_src, telemetry_span}; use itertools::Itertools; use pattern::{Candidate, Pattern}; use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::sync::LazyLock; use std::{ collections::{BTreeMap, HashMap}, net::SocketAddr, @@ -34,14 +28,14 @@ pub(crate) const DNS_PORT: u16 = 53; /// For Chrome and other Chrome-based browsers, this is not required as /// Chrome will automatically disable DoH if your server(s) don't support /// it. See . -static DOH_CANARY_DOMAIN: LazyLock = LazyLock::new(|| { - DomainName::vec_from_str("use-application-dns.net") - .expect("static domain name should always parse") -}); +/// +/// SAFETY: We have a unit-test for it. +pub const DOH_CANARY_DOMAIN: DomainNameRef = + unsafe { DomainNameRef::from_octets_unchecked(b"\x13use-application-dns\x03net\x00") }; pub struct StubResolver { - fqdn_to_ips: BTreeMap<(DomainName, ResourceId), Vec>, - ips_to_fqdn: HashMap, + fqdn_to_ips: BTreeMap<(dns_types::DomainName, ResourceId), Vec>, + ips_to_fqdn: HashMap, ip_provider: IpProvider, /// All DNS resources we know about, indexed by the glob pattern they match against. dns_resources: BTreeMap, @@ -52,7 +46,7 @@ pub struct StubResolver { #[derive(Debug)] pub(crate) struct RecursiveQuery { pub server: SocketAddr, - pub message: Message>, + pub message: dns_types::Query, pub transport: Transport, } @@ -60,16 +54,20 @@ pub(crate) struct RecursiveQuery { #[derive(Debug)] pub(crate) struct RecursiveResponse { pub server: SocketAddr, - pub query: Message>, - pub message: io::Result>>, + pub query: dns_types::Query, + pub message: io::Result, pub transport: Transport, } impl RecursiveQuery { - pub(crate) fn via_udp(source: SocketAddr, server: SocketAddr, message: Message<&[u8]>) -> Self { + pub(crate) fn via_udp( + source: SocketAddr, + server: SocketAddr, + message: dns_types::Query, + ) -> Self { Self { server, - message: message.octets_into(), + message, transport: Transport::Udp { source }, } } @@ -78,7 +76,7 @@ impl RecursiveQuery { local: SocketAddr, remote: SocketAddr, server: SocketAddr, - message: Message>, + message: dns_types::Query, ) -> Self { Self { server, @@ -104,7 +102,7 @@ pub(crate) enum Transport { #[derive(Debug)] pub(crate) enum ResolveStrategy { /// The query is for a Resource, we have an IP mapped already, and we can respond instantly - LocalResponse(Message>), + LocalResponse(Response), /// The query is for a non-Resource, forward it locally to an upstream or system resolver. RecurseLocal, /// The query is for a DNS resource but for a type that we don't intercept (i.e. SRV, TXT, ...), forward it to the site that hosts the DNS resource and resolve it there. @@ -128,13 +126,16 @@ impl StubResolver { /// /// Semantically, this is like a PTR query, i.e. we check whether we handed out this IP as part of answering a DNS query for one of our resources. /// This is in the hot-path of packet routing and must be fast! - pub(crate) fn resolve_resource_by_ip(&self, ip: &IpAddr) -> Option<&(DomainName, ResourceId)> { + pub(crate) fn resolve_resource_by_ip( + &self, + ip: &IpAddr, + ) -> Option<&(dns_types::DomainName, ResourceId)> { self.ips_to_fqdn.get(ip) } pub(crate) fn resolved_resources( &self, - ) -> impl Iterator)> + '_ { + ) -> impl Iterator)> + '_ { self.fqdn_to_ips .iter() .map(|((domain, resource), ips)| (domain, resource, ips)) @@ -160,21 +161,33 @@ impl StubResolver { fn get_or_assign_a_records( &mut self, - fqdn: DomainName, + fqdn: dns_types::DomainName, resource_id: ResourceId, - ) -> Vec, DomainName>> { - to_a_records(self.get_or_assign_ips(fqdn, resource_id).into_iter()) + ) -> Vec { + self.get_or_assign_ips(fqdn, resource_id) + .into_iter() + .filter_map(get_v4) + .map(dns_types::records::a) + .collect_vec() } fn get_or_assign_aaaa_records( &mut self, - fqdn: DomainName, + fqdn: dns_types::DomainName, resource_id: ResourceId, - ) -> Vec, DomainName>> { - to_aaaa_records(self.get_or_assign_ips(fqdn, resource_id).into_iter()) + ) -> Vec { + self.get_or_assign_ips(fqdn, resource_id) + .into_iter() + .filter_map(get_v6) + .map(dns_types::records::aaaa) + .collect_vec() } - fn get_or_assign_ips(&mut self, fqdn: DomainName, resource_id: ResourceId) -> Vec { + fn get_or_assign_ips( + &mut self, + fqdn: dns_types::DomainName, + resource_id: ResourceId, + ) -> Vec { let ips = self .fqdn_to_ips .entry((fqdn.clone(), resource_id)) @@ -197,7 +210,7 @@ impl StubResolver { /// Attempts to match the given domain against our list of possible patterns. /// /// This performs a linear search and is thus O(N) and **must not** be called in the hot-path of packet routing. - fn match_resource_linear(&self, domain: &DomainName) -> Option { + fn match_resource_linear(&self, domain: &dns_types::DomainName) -> Option { let _span = telemetry_span!("match_resource_linear").entered(); let name = Candidate::from_domain(domain); @@ -222,8 +235,8 @@ impl StubResolver { fn resource_address_name_by_reservse_dns( &self, - reverse_dns_name: &DomainName, - ) -> Option { + reverse_dns_name: &dns_types::DomainName, + ) -> Option { let address = reverse_dns_addr(&reverse_dns_name.to_string())?; let (domain, _) = self.ips_to_fqdn.get(&address)?; @@ -231,40 +244,14 @@ impl StubResolver { } /// Processes the incoming DNS query. - /// - /// Any errors will result in an immediate `SERVFAIL` response. - pub(crate) fn handle(&mut self, message: Message<&[u8]>) -> ResolveStrategy { - match self.try_handle(message) { - Ok(s) => s, - Err(e) => { - tracing::warn!("Failed to handle DNS query: {e:#}"); - - ResolveStrategy::LocalResponse(servfail(message)) - } - } - } - - fn try_handle(&mut self, message: Message<&[u8]>) -> Result { - anyhow::ensure!( - !message.header().qr(), - "Can only handle DNS queries, not responses" - ); - - // We don't need to support multiple questions/qname in a single query because - // nobody does it and since this run with each packet we want to squeeze as much optimization - // as we can therefore we won't do it. - // - // See: https://stackoverflow.com/a/55093896 - let question = message - .sole_question() - .context("Expected a single 'question'")?; - let domain = question.qname().to_vec(); - let qtype = question.qtype(); + pub(crate) fn handle(&mut self, query: &Query) -> ResolveStrategy { + let domain = query.domain(); + let qtype = query.qtype(); tracing::trace!("Parsed packet as DNS query: '{qtype} {domain}'"); - if domain == *DOH_CANARY_DOMAIN { - return Ok(ResolveStrategy::LocalResponse(nxdomain(message))); + if domain == DOH_CANARY_DOMAIN { + return ResolveStrategy::LocalResponse(Response::nxdomain(query)); } // We override the `domain` here to ensure that we use the FQDN everywhere from here on. @@ -278,29 +265,30 @@ impl StubResolver { // `match_resource` is `O(N)` which we deem fine for DNS queries. let maybe_resource = self.match_resource_linear(&domain); - let resource_records = match (qtype, maybe_resource) { - (Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource), - (Rtype::AAAA, Some(resource)) => { + let records = match (qtype, maybe_resource) { + (RecordType::A, Some(resource)) => { + self.get_or_assign_a_records(domain.clone(), resource) + } + (RecordType::AAAA, Some(resource)) => { self.get_or_assign_aaaa_records(domain.clone(), resource) } - (Rtype::SRV | Rtype::TXT, Some(resource)) => { + (RecordType::SRV | RecordType::TXT, Some(resource)) => { tracing::debug!(%qtype, %resource, "Forwarding query for DNS resource to corresponding site"); - return Ok(ResolveStrategy::RecurseSite(resource)); + return ResolveStrategy::RecurseSite(resource); } - (Rtype::PTR, _) => { + (RecordType::PTR, _) => { let Some(fqdn) = self.resource_address_name_by_reservse_dns(&domain) else { - return Ok(ResolveStrategy::RecurseLocal); + return ResolveStrategy::RecurseLocal; }; - vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))] + vec![dns_types::records::ptr(fqdn)] } - (Rtype::HTTPS, Some(_)) => { + (RecordType::HTTPS, Some(_)) => { // We must intercept queries for the HTTPS record type to force the client to issue an A / AAAA query instead. // Otherwise, the client won't use the IPs we issue for a particular domain and the traffic cannot be tunneled. - let response = build_dns_with_answer(message, Vec::default())?; - return Ok(ResolveStrategy::LocalResponse(response)); + return ResolveStrategy::LocalResponse(Response::no_error(query)); } (_, None) if is_single_label_domain => { // Queries for single-label domains, i.e. local hostnames are never recursively resolved but are instead answered with nxdomain. @@ -310,15 +298,18 @@ impl StubResolver { "Query for single-label non-resource domain, responding with NXDOMAIN" ); - return Ok(ResolveStrategy::LocalResponse(nxdomain(message))); + return ResolveStrategy::LocalResponse(Response::nxdomain(query)); } - _ => return Ok(ResolveStrategy::RecurseLocal), + _ => return ResolveStrategy::RecurseLocal, }; - tracing::trace!(%qtype, %domain, records = ?resource_records, "Forming DNS response"); + tracing::trace!(%qtype, %domain, records = ?records, "Forming DNS response"); - let response = build_dns_with_answer(message, resource_records)?; - Ok(ResolveStrategy::LocalResponse(response)) + let response = ResponseBuilder::for_query(query, ResponseCode::NOERROR) + .with_records(records.into_iter().map(|r| (domain.clone(), DNS_TTL, r))) + .build(); + + ResolveStrategy::LocalResponse(response) } pub(crate) fn set_search_domain(&mut self, new_search_domain: Option) { @@ -354,60 +345,7 @@ impl StubResolver { } } -pub fn servfail(message: Message<&[u8]>) -> Message> { - MessageBuilder::new_vec() - .start_answer(&message, Rcode::SERVFAIL) - .expect("should always be able to create a heap-allocated SERVFAIL message") - .into_message() -} - -fn nxdomain(message: Message<&[u8]>) -> Message> { - MessageBuilder::new_vec() - .start_answer(&message, Rcode::NXDOMAIN) - .expect("should always be able to create a heap-allocated NXDOMAIN message") - .into_message() -} - -fn to_a_records(ips: impl Iterator) -> Vec, DomainName>> { - ips.filter_map(get_v4) - .map(domain::rdata::A::new) - .map(AllRecordData::A) - .collect_vec() -} - -fn to_aaaa_records(ips: impl Iterator) -> Vec, DomainName>> { - ips.filter_map(get_v6) - .map(domain::rdata::Aaaa::new) - .map(AllRecordData::Aaaa) - .collect_vec() -} - -fn build_dns_with_answer( - message: Message<&[u8]>, - records: Vec, DomainName>>, -) -> Result>> { - // Take the original qname out of the message. - // DNS queries should always respond for the exact same qname that was queried, even if we expanded a single-label domain. - let qname = message - .sole_question() - .context("Expected a single question")? - .into_qname(); - - let mut answer_builder = MessageBuilder::new_vec() - .start_answer(&message, Rcode::NOERROR) - .context("Failed to create answer from query")?; - answer_builder.header_mut().set_ra(true); - - for record in records { - answer_builder - .push((qname, Class::IN, DNS_TTL, record)) - .context("Failed to push record")?; - } - - Ok(answer_builder.into_message()) -} - -pub fn is_subdomain(name: &DomainName, resource: &str) -> bool { +pub fn is_subdomain(name: &dns_types::DomainName, resource: &str) -> bool { let pattern = match Pattern::new(resource) { Ok(p) => p, Err(e) => { @@ -586,7 +524,7 @@ mod pattern { pub struct Candidate(String); impl Candidate { - pub fn from_domain(domain: &DomainName) -> Self { + pub fn from_domain(domain: &dns_types::DomainName) -> Self { Self(domain.to_string().replace('.', "/")) } } @@ -640,7 +578,6 @@ mod pattern { #[cfg(test)] mod tests { use super::*; - use domain::base::Question; use std::str::FromStr as _; use test_case::test_case; @@ -763,22 +700,19 @@ mod tests { fn query_for_doh_canary_domain_records_nx_domain() { let mut resolver = StubResolver::default(); - let mut builder = MessageBuilder::new_vec().question(); - builder - .push(Question::new_in( - "use-application-dns.net".parse::().unwrap(), - Rtype::A, - )) - .unwrap(); - let query = builder.into_message(); + let query = Query::new( + "use-application-dns.net" + .parse::() + .unwrap(), + RecordType::A, + ); - let ResolveStrategy::LocalResponse(response) = resolver.handle(query.for_slice_ref()) - else { + let ResolveStrategy::LocalResponse(response) = resolver.handle(&query) else { panic!("Unexpected result") }; - assert_eq!(response.header().rcode(), Rcode::NXDOMAIN); - assert_eq!(response.answer().unwrap().count(), 0); + assert_eq!(response.response_code(), ResponseCode::NXDOMAIN); + assert_eq!(response.records().count(), 0); } } @@ -808,7 +742,7 @@ mod benches { .unwrap() .to_string(); - let needle = DomainName::vec_from_str(&needle).unwrap(); + let needle = dns_types::DomainName::vec_from_str(&needle).unwrap(); (resolver, needle) }) diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 6291d06eb..982b1a339 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -6,7 +6,8 @@ use crate::{peer::ClientOnGateway, peer_store::PeerStore}; use anyhow::{Context, Result}; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; -use connlib_model::{ClientId, DomainName, RelayId, ResourceId}; +use connlib_model::{ClientId, RelayId, ResourceId}; +use dns_types::DomainName; use ip_packet::{FzP2pControlSlice, IpPacket}; use secrecy::{ExposeSecret as _, Secret}; use snownet::{Credentials, NoTurnServers, RelaySocket, ServerNode, Transmit}; diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index cdab319bf..6f44ebd74 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -5,7 +5,6 @@ mod udp_dns; use crate::{device_channel::Device, dns, sockets::Sockets}; use anyhow::Result; -use domain::base::Message; use firezone_logging::{telemetry_event, telemetry_span}; use futures::FutureExt as _; use futures_bounded::FuturesTupleSet; @@ -56,7 +55,7 @@ pub struct Io { tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - dns_queries: FuturesTupleSet>>, DnsQueryMetaData>, + dns_queries: FuturesTupleSet, DnsQueryMetaData>, timeout: Option>>, @@ -66,7 +65,7 @@ pub struct Io { #[derive(Debug)] struct DnsQueryMetaData { - query: Message>, + query: dns_types::Query, server: SocketAddr, transport: dns::Transport, } @@ -338,7 +337,7 @@ impl Io { pub(crate) fn send_udp_dns_response( &mut self, to: SocketAddr, - message: Message>, + message: dns_types::Response, ) -> io::Result<()> { self.udp_dns_server.send_response(to, message) } @@ -346,7 +345,7 @@ impl Io { pub(crate) fn send_tcp_dns_response( &mut self, to: SocketAddr, - message: Message>, + message: dns_types::Response, ) -> io::Result<()> { self.tcp_dns_server.send_response(to, message) } diff --git a/rust/connlib/tunnel/src/io/nameserver_set.rs b/rust/connlib/tunnel/src/io/nameserver_set.rs index ceb18e3ba..14a2d57e0 100644 --- a/rust/connlib/tunnel/src/io/nameserver_set.rs +++ b/rust/connlib/tunnel/src/io/nameserver_set.rs @@ -2,13 +2,12 @@ use std::{ collections::{BTreeMap, BTreeSet}, io, net::{IpAddr, SocketAddr}, - sync::{Arc, LazyLock}, + sync::Arc, task::{ready, Context, Poll}, time::{Duration, Instant}, }; -use connlib_model::DomainName; -use domain::base::{iana::Rcode, Message, MessageBuilder, Question, Rtype}; +use dns_types::{prelude::*, DomainNameRef, Query, RecordType, ResponseCode}; use futures_bounded::FuturesTupleSet; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; @@ -19,9 +18,8 @@ use super::tcp_dns; const MAX_DNS_SERVERS: usize = 20; // We don't bother selecting from more than 10 servers over UDP and TCP. const DNS_TIMEOUT: Duration = Duration::from_secs(2); // Every sensible DNS servers should respond within 2s. -static FIREZONE_DEV: LazyLock = LazyLock::new(|| { - DomainName::vec_from_str("firezone.dev").expect("static domain should always parse") -}); +pub const FIREZONE_DEV: DomainNameRef = + unsafe { DomainNameRef::from_octets_unchecked(b"\x08firezone\x03dev\x00") }; pub struct NameserverSet { inner: BTreeSet, @@ -29,7 +27,7 @@ pub struct NameserverSet { tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - queries: FuturesTupleSet>>, QueryMetaData>, + queries: FuturesTupleSet, QueryMetaData>, } struct QueryMetaData { @@ -63,7 +61,7 @@ impl NameserverSet { udp_dns::send( self.udp_socket_factory.clone(), SocketAddr::new(nameserver, crate::dns::DNS_PORT), - query_firezone_dev(), + Query::new(FIREZONE_DEV.to_vec(), RecordType::A), ), QueryMetaData { nameserver, start }, ) @@ -78,7 +76,7 @@ impl NameserverSet { tcp_dns::send( self.tcp_socket_factory.clone(), SocketAddr::new(nameserver, crate::dns::DNS_PORT), - query_firezone_dev(), + Query::new(FIREZONE_DEV.to_vec(), RecordType::A), ), QueryMetaData { nameserver, start }, ) @@ -102,7 +100,7 @@ impl NameserverSet { loop { match ready!(self.queries.poll_unpin(cx)) { - (Ok(Ok(response)), meta) if response.header().rcode() == Rcode::NOERROR => { + (Ok(Ok(response)), meta) if response.response_code() == ResponseCode::NOERROR => { let rtt = meta.start.elapsed(); tracing::debug!(nameserver = %meta.nameserver, ?rtt, ?response, "DNS query completed"); @@ -133,25 +131,22 @@ impl NameserverSet { } } -fn query_firezone_dev() -> Message> { - let mut builder = MessageBuilder::new_vec().question(); - builder.header_mut().set_random_id(); - builder.header_mut().set_rd(true); - builder.header_mut().set_qr(false); - - builder - .push(Question::new_in(FIREZONE_DEV.clone(), Rtype::A)) - .expect("static question should always be valid"); - - builder.into_message() -} - #[cfg(test)] mod tests { use std::net::Ipv4Addr; + use dns_types::DomainName; + use super::*; + #[test] + fn const_domain_is_correct() { + assert_eq!( + FIREZONE_DEV, + DomainName::vec_from_str("firezone.dev").unwrap() + ) + } + #[tokio::test] #[ignore = "Needs Internet"] async fn can_evaluate_fastest_nameserver() { diff --git a/rust/connlib/tunnel/src/io/tcp_dns.rs b/rust/connlib/tunnel/src/io/tcp_dns.rs index 120baf443..ee575db59 100644 --- a/rust/connlib/tunnel/src/io/tcp_dns.rs +++ b/rust/connlib/tunnel/src/io/tcp_dns.rs @@ -1,26 +1,19 @@ use std::{io, net::SocketAddr, sync::Arc}; -use domain::base::{Message, ToName as _}; use socket_factory::{SocketFactory, TcpSocket}; use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; pub async fn send( factory: Arc>, server: SocketAddr, - query: Message>, -) -> io::Result>> { - let domain = query - .sole_question() - .expect("all queries should be for a single name") - .qname() - .to_vec(); - - tracing::trace!(target: "wire::dns::recursive::tcp", %server, %domain); + query: dns_types::Query, +) -> io::Result { + tracing::trace!(target: "wire::dns::recursive::tcp", %server, domain = %query.domain()); let tcp_socket = factory(&server)?; // TODO: Optimise this to reuse a TCP socket to the same resolver. let mut tcp_stream = tcp_socket.connect(server).await?; - let query = query.into_octets(); + let query = query.into_bytes(); let dns_message_length = (query.len() as u16).to_be_bytes(); tcp_stream.write_all(&dns_message_length).await?; @@ -34,8 +27,7 @@ pub async fn send( let mut response = vec![0u8; response_length]; tcp_stream.read_exact(&mut response).await?; - let message = Message::from_octets(response) - .map_err(|_| io::Error::other("Failed to parse DNS message"))?; + let message = dns_types::Response::parse(&response).map_err(io::Error::other)?; Ok(message) } diff --git a/rust/connlib/tunnel/src/io/udp_dns.rs b/rust/connlib/tunnel/src/io/udp_dns.rs index c926d76a5..3e9727d69 100644 --- a/rust/connlib/tunnel/src/io/udp_dns.rs +++ b/rust/connlib/tunnel/src/io/udp_dns.rs @@ -4,37 +4,30 @@ use std::{ sync::Arc, }; -use domain::base::{Message, ToName as _}; use socket_factory::{SocketFactory, UdpSocket}; pub async fn send( factory: Arc>, server: SocketAddr, - query: Message>, -) -> io::Result>> { - let domain = query - .sole_question() - .expect("all queries should be for a single name") - .qname() - .to_vec(); + query: dns_types::Query, +) -> io::Result { + tracing::trace!(target: "wire::dns::recursive::udp", %server, domain = %query.domain()); + let bind_addr = match server { SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), }; - tracing::trace!(target: "wire::dns::recursive::udp", %server, %domain); - // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. const BUF_SIZE: usize = 1500; let udp_socket = factory(&bind_addr)?; let response = udp_socket - .handshake::(server, query.as_slice()) + .handshake::(server, &query.into_bytes()) .await?; - let message = Message::from_octets(response) - .map_err(|_| io::Error::other("Failed to parse DNS message"))?; + let response = dns_types::Response::parse(&response).map_err(io::Error::other)?; - Ok(message) + Ok(response) } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index f9f0cea47..4071d8e07 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -8,7 +8,8 @@ use anyhow::Result; use bimap::BiMap; use chrono::Utc; -use connlib_model::{ClientId, DomainName, GatewayId, PublicKey, ResourceId, ResourceView}; +use connlib_model::{ClientId, GatewayId, PublicKey, ResourceId, ResourceView}; +use dns_types::DomainName; use io::{Buffers, Io}; use ip_network::{Ipv4Network, Ipv6Network}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; @@ -256,7 +257,7 @@ impl GatewayTunnel { let message = response.message.unwrap_or_else(|e| { tracing::debug!("DNS query failed: {e}"); - dns::servfail(response.query.for_slice_ref()) + dns_types::Response::servfail(&response.query) }); match response.transport { @@ -323,7 +324,7 @@ impl GatewayTunnel { self.io.send_dns_query(dns::RecursiveQuery::via_udp( query.source, SocketAddr::new(nameserver, dns::DNS_PORT), - query.message.for_slice_ref(), + query.message, )); } Poll::Ready(io::Input::TcpDnsQuery(query)) => { diff --git a/rust/connlib/tunnel/src/messages.rs b/rust/connlib/tunnel/src/messages.rs index 0654c1c27..0abe89dff 100644 --- a/rust/connlib/tunnel/src/messages.rs +++ b/rust/connlib/tunnel/src/messages.rs @@ -3,6 +3,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use chrono::{serde::ts_seconds, DateTime, Utc}; use connlib_model::RelayId; +use dns_types::DomainName; use ip_network::IpNetwork; use secrecy::{ExposeSecret as _, Secret}; use serde::{Deserialize, Serialize}; @@ -14,8 +15,6 @@ mod key; pub use key::{Key, SecretKey}; -use crate::DomainName; - /// Represents a wireguard peer. #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Peer { diff --git a/rust/connlib/tunnel/src/p2p_control.rs b/rust/connlib/tunnel/src/p2p_control.rs index f756a87d9..18c99c2ab 100644 --- a/rust/connlib/tunnel/src/p2p_control.rs +++ b/rust/connlib/tunnel/src/p2p_control.rs @@ -18,7 +18,8 @@ pub const DOMAIN_STATUS_EVENT: FzP2pEventType = FzP2pEventType::new(1); pub mod dns_resource_nat { use super::*; use anyhow::{Context as _, Result}; - use connlib_model::{DomainName, ResourceId}; + use connlib_model::ResourceId; + use dns_types::DomainName; use ip_packet::{FzP2pControlSlice, IpPacket}; use std::net::IpAddr; @@ -113,7 +114,6 @@ pub mod dns_resource_nat { #[cfg(test)] mod tests { - use domain::base::Name; use super::*; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -192,7 +192,7 @@ pub mod dns_resource_nat { let domain = DomainName::vec_from_str(&format!("{label}.{label}.{label}.{label}.{label}.com")) .unwrap(); - assert_eq!(domain.len(), Name::MAX_LEN); + assert_eq!(domain.len(), dns_types::MAX_NAME_LEN); domain } diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 2a68136c2..6df61ff26 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -7,7 +7,8 @@ use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; use crate::messages::gateway::Filters; use crate::messages::gateway::ResourceDescription; use chrono::{DateTime, Utc}; -use connlib_model::{ClientId, DomainName, GatewayId, ResourceId}; +use connlib_model::{ClientId, GatewayId, ResourceId}; +use dns_types::DomainName; use filter_engine::FilterEngine; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index 44fabe202..51f4fad5e 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -4,7 +4,7 @@ use super::{ sim_gateway::SimGateway, transition::{Destination, ReplyTo}, }; -use connlib_model::{DomainName, GatewayId}; +use connlib_model::GatewayId; use ip_packet::IpPacket; use itertools::Itertools; use std::{ @@ -377,7 +377,7 @@ fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket, expe fn assert_destination_is_dns_resource( gateway_received_request: &IpPacket, global_dns_records: &DnsRecords, - domain: &DomainName, + domain: &dns_types::DomainName, ) { let actual = gateway_received_request.destination(); let possible_resource_ips = global_dns_records diff --git a/rust/connlib/tunnel/src/tests/dns_records.rs b/rust/connlib/tunnel/src/tests/dns_records.rs index e468f51d3..564423199 100644 --- a/rust/connlib/tunnel/src/tests/dns_records.rs +++ b/rust/connlib/tunnel/src/tests/dns_records.rs @@ -3,21 +3,21 @@ use std::{ net::IpAddr, }; -use connlib_model::{DomainName, DomainRecord}; -use domain::base::{RecordData, Rtype}; +use dns_types::prelude::*; +use dns_types::{DomainName, OwnedRecordData, RecordType}; use itertools::Itertools; #[derive(Debug, Default, Clone)] pub(crate) struct DnsRecords { - inner: BTreeMap>, + inner: BTreeMap>, } impl DnsRecords { pub(crate) fn domain_ips_iter(&self, name: &DomainName) -> impl Iterator + '_ { #[expect(clippy::wildcard_enum_match_arm)] self.domain_records_iter(name).filter_map(|r| match r { - DomainRecord::A(a) => Some(a.addr().into()), - DomainRecord::Aaaa(aaaa) => Some(aaaa.addr().into()), + OwnedRecordData::A(a) => Some(a.addr().into()), + OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()), _ => None, }) } @@ -25,8 +25,8 @@ impl DnsRecords { pub(crate) fn ips_iter(&self) -> impl Iterator + '_ { #[expect(clippy::wildcard_enum_match_arm)] self.inner.values().flatten().filter_map(|r| match r { - DomainRecord::A(a) => Some(a.addr().into()), - DomainRecord::Aaaa(aaaa) => Some(aaaa.addr().into()), + OwnedRecordData::A(a) => Some(a.addr().into()), + OwnedRecordData::Aaaa(aaaa) => Some(aaaa.addr().into()), _ => None, }) } @@ -34,7 +34,7 @@ impl DnsRecords { pub(crate) fn domain_records_iter( &self, name: &DomainName, - ) -> impl Iterator + '_ { + ) -> impl Iterator + '_ { self.inner.get(name).cloned().into_iter().flatten() } @@ -50,7 +50,7 @@ impl DnsRecords { self.inner.extend(other.inner); } - pub(crate) fn domain_rtypes(&self, name: &DomainName) -> Vec { + pub(crate) fn domain_rtypes(&self, name: &DomainName) -> Vec { self.domain_records_iter(name) .map(|r| r.rtype()) .dedup() @@ -64,7 +64,7 @@ impl DnsRecords { impl From for DnsRecords where - BTreeMap>: From, + BTreeMap>: From, { fn from(value: I) -> Self { Self { @@ -75,7 +75,7 @@ where impl FromIterator for DnsRecords where - BTreeMap>: FromIterator, + BTreeMap>: FromIterator, { fn from_iter>(iter: T) -> Self { Self { @@ -83,10 +83,3 @@ where } } } - -pub(crate) fn ip_to_domain_record(ip: IpAddr) -> DomainRecord { - match ip { - IpAddr::V4(ip) => DomainRecord::A(ip.into()), - IpAddr::V6(ip) => DomainRecord::Aaaa(ip.into()), - } -} diff --git a/rust/connlib/tunnel/src/tests/dns_server_resource.rs b/rust/connlib/tunnel/src/tests/dns_server_resource.rs index fc13973d2..35d84b1e1 100644 --- a/rust/connlib/tunnel/src/tests/dns_server_resource.rs +++ b/rust/connlib/tunnel/src/tests/dns_server_resource.rs @@ -4,10 +4,8 @@ use std::{ time::Instant, }; -use domain::base::{ - iana::{Class, Rcode}, - Message, MessageBuilder, Record, RecordData, ToName, Ttl, -}; +use dns_types::prelude::*; +use dns_types::ResponseCode; use ip_packet::{IpPacket, MAX_UDP_PAYLOAD}; use super::dns_records::DnsRecords; @@ -37,7 +35,7 @@ impl TcpDnsServerResource { pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, now: Instant) { self.server.handle_timeout(now); while let Some(query) = self.server.poll_queries() { - let response = handle_dns_query(query.message.for_slice(), global_dns_records); + let response = handle_dns_query(&query.message, global_dns_records); self.server .send_message(query.local, query.remote, response) @@ -58,9 +56,9 @@ impl UdpDnsServerResource { pub fn handle_timeout(&mut self, global_dns_records: &DnsRecords, _: Instant) { while let Some(packet) = self.inbound_packets.pop_front() { let udp = packet.as_udp().unwrap(); - let query = Message::from_octets(udp.payload().to_vec()).unwrap(); + let query = dns_types::Query::parse(udp.payload()).unwrap(); - let response = handle_dns_query(query.for_slice(), global_dns_records); + let response = handle_dns_query(&query, global_dns_records); self.outbound_packets.push_back( ip_packet::make::udp_packet( @@ -68,7 +66,7 @@ impl UdpDnsServerResource { packet.source(), udp.destination_port(), udp.source_port(), - truncate_dns_response(response), + response.into_bytes(MAX_UDP_PAYLOAD), ) .expect("src and dst are retrieved from the same packet"), ) @@ -80,44 +78,18 @@ impl UdpDnsServerResource { } } -fn handle_dns_query(query: &Message<[u8]>, global_dns_records: &DnsRecords) -> Message> { - let response = MessageBuilder::new_vec(); - let mut answers = response.start_answer(query, Rcode::NOERROR).unwrap(); +fn handle_dns_query( + query: &dns_types::Query, + global_dns_records: &DnsRecords, +) -> dns_types::Response { + let domain = query.domain().to_vec(); - for query in query.question() { - let query = query.unwrap(); - let name = query.qname().to_name::>(); + let records = global_dns_records + .domain_records_iter(&domain) + .filter(|r| r.rtype() == query.qtype()) + .map(|rdata| (domain.clone(), 60 * 60 * 24, rdata)); - let records = global_dns_records - .domain_records_iter(&name) - .filter(|r| r.rtype() == query.qtype()) - .map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata)); - - for record in records { - answers.push(record).unwrap(); - } - } - - answers.into_message() -} - -fn truncate_dns_response(message: Message>) -> Vec { - let mut message_bytes = message.as_octets().to_vec(); - - if message_bytes.len() > MAX_UDP_PAYLOAD { - let mut new_message = message.clone(); - new_message.header_mut().set_tc(true); - - let message_truncation = match message.answer() { - Ok(answer) if answer.pos() <= MAX_UDP_PAYLOAD => answer.pos(), - // This should be very unlikely or impossible. - _ => message.question().pos(), - }; - - message_bytes = new_message.as_octets().to_vec(); - - message_bytes.truncate(message_truncation); - } - - message_bytes + dns_types::ResponseBuilder::for_query(query, ResponseCode::NOERROR) + .with_records(records) + .build() } diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 93acf03ae..915d7f166 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -4,10 +4,10 @@ use super::{ composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; -use crate::{client, DomainName}; +use crate::client; use crate::{dns::is_subdomain, proptest::relay_id}; use connlib_model::{GatewayId, RelayId, StaticSecret}; -use domain::base::Rtype; +use dns_types::{DomainName, RecordType}; use ip_network::{Ipv4Network, Ipv6Network}; use itertools::Itertools; use prop::sample::select; @@ -616,8 +616,9 @@ impl ReferenceState { return false; }; - let is_ptr_query = matches!(query.r_type, Rtype::PTR); + let is_ptr_query = matches!(query.r_type, RecordType::PTR); let is_known_domain = state.global_dns_records.contains_domain(&domain); + // In case we sampled a PTR query, the domain doesn't have to exist. let ptr_or_known_domain = is_ptr_query || is_known_domain; @@ -708,8 +709,8 @@ impl ReferenceState { .dns_records .get(name) .is_some_and(|r| match src { - IpAddr::V4(_) => r.contains(&Rtype::A), - IpAddr::V6(_) => r.contains(&Rtype::AAAA), + IpAddr::V4(_) => r.contains(&RecordType::A), + IpAddr::V6(_) => r.contains(&RecordType::AAAA), }) && self.gateways.contains_key(gateway) } @@ -719,10 +720,10 @@ impl ReferenceState { impl ReferenceState { // We surface what are the existing rtypes for a domain so that it's easier // for the proptests to hit an existing record. - fn all_domains(&self) -> Vec<(DomainName, Vec)> { + fn all_domains(&self) -> Vec<(DomainName, Vec)> { fn domains_and_rtypes( records: &DnsRecords, - ) -> impl Iterator)> + use<'_> { + ) -> impl Iterator)> + use<'_> { records .domains_iter() .map(|d| (d.clone(), records.domain_rtypes(&d))) @@ -741,14 +742,14 @@ impl ReferenceState { // We surface what are the existing rtypes for a domain so that it's easier // for the proptests to hit an existing record. - fn single_label_queries_for_search_domains(&self) -> Vec<(DomainName, Vec)> { + fn single_label_queries_for_search_domains(&self) -> Vec<(DomainName, Vec)> { let Some(search_domain) = self.client.inner().search_domain.clone() else { return Vec::default(); }; fn domains_and_rtypes( records: &DnsRecords, - ) -> impl Iterator)> + use<'_> { + ) -> impl Iterator)> + use<'_> { records .domains_iter() .map(|d| (d.clone(), records.domain_rtypes(&d))) diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 849a442e5..de574d701 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -10,17 +10,13 @@ use super::{ use crate::{ client::{CidrResource, DnsResource, InternetResource, Resource}, messages::{DnsServer, Interface}, - DomainName, }; use crate::{proptest::*, ClientState}; use anyhow::Context; use anyhow::Result; use bimap::BiMap; use connlib_model::{ClientId, GatewayId, RelayId, ResourceId, ResourceStatus, SiteId}; -use domain::{ - base::{iana::Opcode, name::FlattenInto, Message, MessageBuilder, Question, Rtype, ToName}, - rdata::AllRecordData, -}; +use dns_types::{prelude::*, DomainName, Query, RecordData, RecordType}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; use ip_packet::{Icmpv4Type, Icmpv6Type, IpPacket, Layer4Protocol}; @@ -115,7 +111,7 @@ impl SimClient { pub(crate) fn send_dns_query_for( &mut self, domain: DomainName, - r_type: Rtype, + r_type: RecordType, query_id: u16, upstream: SocketAddr, dns_transport: DnsTransport, @@ -133,20 +129,7 @@ impl SimClient { .tunnel_ip_for(sentinel) .expect("tunnel should be initialised"); - // Create the DNS query message - let mut msg_builder = MessageBuilder::new_vec(); - - msg_builder.header_mut().set_opcode(Opcode::QUERY); - msg_builder.header_mut().set_rd(true); - msg_builder.header_mut().set_id(query_id); - - // Create the query - let mut question_builder = msg_builder.question(); - question_builder - .push(Question::new_in(domain, r_type)) - .unwrap(); - - let message = question_builder.into_message(); + let query = Query::new(domain, r_type).with_id(query_id); match dns_transport { DnsTransport::Udp => { @@ -155,7 +138,7 @@ impl SimClient { sentinel, 9999, // An application would pick a free source port. 53, - message.as_octets().to_vec(), + query.into_bytes(), ) .unwrap(); @@ -165,7 +148,7 @@ impl SimClient { } DnsTransport::Tcp => { self.tcp_dns_client - .send_query(SocketAddr::new(sentinel, 53), message) + .send_query(SocketAddr::new(sentinel, 53), query) .unwrap(); self.sent_tcp_dns_queries.insert((upstream, query_id)); @@ -246,15 +229,15 @@ impl SimClient { match failed_packet.layer4_protocol() { Layer4Protocol::Udp { src, dst } => { self.received_udp_replies - .insert((SPort(dst), DPort(src)), packet.clone()); + .insert((SPort(dst), DPort(src)), packet); } Layer4Protocol::Tcp { src, dst } => { self.received_tcp_replies - .insert((SPort(dst), DPort(src)), packet.clone()); + .insert((SPort(dst), DPort(src)), packet); } Layer4Protocol::Icmp { seq, id } => { self.received_icmp_replies - .insert((Seq(seq), Identifier(id)), packet.clone()); + .insert((Seq(seq), Identifier(id)), packet); } } @@ -263,7 +246,7 @@ impl SimClient { if let Some(udp) = packet.as_udp() { if udp.source_port() == 53 { - let message = Message::from_slice(udp.payload()) + let response = dns_types::Response::parse(udp.payload()) .expect("ip packets on port 53 to be DNS packets"); // Map back to upstream socket so we can assert on it correctly. @@ -274,10 +257,10 @@ impl SimClient { }; self.received_udp_dns_responses - .insert((upstream, message.header().id()), packet.clone()); + .insert((upstream, response.id()), packet.clone()); - if !message.header().tc() { - self.handle_dns_response(message); + if !response.truncated() { + self.handle_dns_response(&response); } return; @@ -341,26 +324,19 @@ impl SimClient { Some(*socket) } - pub(crate) fn handle_dns_response(&mut self, message: &Message<[u8]>) { - for record in message.answer().unwrap() { - let record = record.unwrap(); - let domain = record.owner().to_name(); - + pub(crate) fn handle_dns_response(&mut self, response: &dns_types::Response) { + for record in response.records() { #[expect(clippy::wildcard_enum_match_arm)] - let ip = match record - .into_any_record::>() - .unwrap() - .data() - { - AllRecordData::A(a) => IpAddr::from(a.addr()), - AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()), - AllRecordData::Ptr(_) => { + let ip = match record.data() { + RecordData::A(a) => IpAddr::from(a.addr()), + RecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()), + RecordData::Ptr(_) => { continue; } - AllRecordData::Txt(_) => { + RecordData::Txt(_) => { continue; } - AllRecordData::Srv(_) => { + RecordData::Srv(_) => { continue; } unhandled => { @@ -368,7 +344,10 @@ impl SimClient { } }; - self.dns_records.entry(domain).or_default().push(ip); + self.dns_records + .entry(response.domain()) + .or_default() + .push(ip); } // Ensure all IPs are always sorted. @@ -419,7 +398,7 @@ pub struct RefClient { /// The IPs assigned to a domain by connlib are an implementation detail that we don't want to model in these tests. /// Instead, we just remember what _kind_ of records we resolved to be able to sample a matching src IP. #[debug(skip)] - pub(crate) dns_records: BTreeMap>, + pub(crate) dns_records: BTreeMap>, /// Whether we are connected to the gateway serving the Internet resource. #[debug(skip)] @@ -907,7 +886,7 @@ impl RefClient { .find(|id| !self.disabled_resources.contains(id)) } - fn resolved_domains(&self) -> impl Iterator)> + '_ { + fn resolved_domains(&self) -> impl Iterator)> + '_ { self.dns_records .iter() .filter(|(domain, _)| self.dns_resource_by_domain(domain).is_some()) @@ -953,7 +932,7 @@ impl RefClient { .filter_map(|(domain, records)| { records .iter() - .any(|r| matches!(r, &Rtype::A)) + .any(|r| matches!(r, &RecordType::A)) .then_some(domain) }) .collect() @@ -964,7 +943,7 @@ impl RefClient { .filter_map(|(domain, records)| { records .iter() - .any(|r| matches!(r, &Rtype::AAAA)) + .any(|r| matches!(r, &RecordType::AAAA)) .then_some(domain) }) .collect() @@ -1065,7 +1044,10 @@ impl RefClient { // If we are querying a DNS resource, we will issue a connection intent to the DNS resource, not the CIDR resource. if self.dns_resource_by_domain(&query.domain).is_some() - && matches!(query.r_type, Rtype::A | Rtype::AAAA | Rtype::PTR) + && matches!( + query.r_type, + RecordType::A | RecordType::AAAA | RecordType::PTR + ) { return None; } @@ -1077,7 +1059,7 @@ impl RefClient { } pub(crate) fn is_site_specific_dns_query(&self, query: &DnsQuery) -> Option { - if !matches!(query.r_type, Rtype::SRV | Rtype::TXT) { + if !matches!(query.r_type, RecordType::SRV | RecordType::TXT) { return None; } diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index cb4187688..34f76060d 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -1,4 +1,4 @@ -use super::dns_records::{ip_to_domain_record, DnsRecords}; +use super::dns_records::DnsRecords; use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal}; use crate::client::{ CidrResource, DnsResource, InternetResource, DNS_SENTINELS_V4, DNS_SENTINELS_V6, @@ -6,7 +6,8 @@ use crate::client::{ }; use crate::messages::DnsServer; use crate::{proptest::*, IPV4_TUNNEL, IPV6_TUNNEL}; -use connlib_model::{DomainRecord, RelayId, Site}; +use connlib_model::{RelayId, Site}; +use dns_types::OwnedRecordData; use ip_network::{Ipv4Network, Ipv6Network}; use itertools::Itertools; use prop::sample; @@ -27,22 +28,20 @@ pub(crate) fn global_dns_records() -> impl Strategy { .prop_map_into() } -fn dns_record() -> impl Strategy { +fn dns_record() -> impl Strategy { prop_oneof![ - 3 => non_reserved_ip().prop_map(ip_to_domain_record), + 3 => non_reserved_ip().prop_map(dns_types::records::ip), 1 => collection::vec(txt_record(), 6..=10) .prop_map(|sections| { sections.into_iter().flatten().collect_vec() }) - .prop_map(|o| domain::rdata::Txt::from_octets(o).unwrap()) - .prop_map(DomainRecord::Txt) + .prop_map(|content| dns_types::records::txt(content).unwrap()) ] } -pub(crate) fn site_specific_dns_record() -> impl Strategy { +pub(crate) fn site_specific_dns_record() -> impl Strategy { prop_oneof![ collection::vec(txt_record(), 6..=10) .prop_map(|sections| { sections.into_iter().flatten().collect_vec() }) - .prop_map(|o| domain::rdata::Txt::from_octets(o).unwrap()) - .prop_map(DomainRecord::Txt), + .prop_map(|content| dns_types::records::txt(content).unwrap()), srv_record() ] } @@ -60,7 +59,7 @@ fn txt_record() -> impl Strategy> { }) } -fn srv_record() -> impl Strategy { +fn srv_record() -> impl Strategy { ( any::(), any::(), @@ -68,7 +67,7 @@ fn srv_record() -> impl Strategy { domain_name(2..4).prop_map(|d| d.parse().unwrap()), ) .prop_map(|(priority, weight, port, target)| { - DomainRecord::Srv(domain::rdata::Srv::new(priority, weight, port, target)) + dns_types::records::srv(priority, weight, port, target) }) } @@ -270,12 +269,12 @@ fn double_star_wildcard_dns_resource( }) } -pub(crate) fn resolved_ips() -> impl Strategy> { +pub(crate) fn resolved_ips() -> impl Strategy> { let record = prop_oneof![ dns_resource_ip4s().prop_map_into(), dns_resource_ip6s().prop_map_into() ] - .prop_map(ip_to_domain_record); + .prop_map(dns_types::records::ip); collection::btree_set(record, 1..6) } diff --git a/rust/connlib/tunnel/src/tests/stub_portal.rs b/rust/connlib/tunnel/src/tests/stub_portal.rs index de296f342..4829d7bf2 100644 --- a/rust/connlib/tunnel/src/tests/stub_portal.rs +++ b/rust/connlib/tunnel/src/tests/stub_portal.rs @@ -10,8 +10,9 @@ use crate::{ client::DnsResource, messages::{gateway, DnsServer}, }; -use connlib_model::{DomainName, GatewayId}; +use connlib_model::GatewayId; use connlib_model::{ResourceId, SiteId}; +use dns_types::DomainName; use itertools::Itertools; use proptest::{ collection, diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 2e1484557..7547560ad 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -17,8 +17,8 @@ use crate::tests::transition::Transition; use crate::utils::earliest; use crate::{dns, messages::Interface, ClientEvent, GatewayEvent}; use connlib_model::{ClientId, GatewayId, PublicKey, RelayId}; -use domain::base::iana::{Class, Rcode}; -use domain::base::{Message, MessageBuilder, Record, RecordData, ToName as _, Ttl}; +use dns_types::prelude::*; +use dns_types::ResponseCode; use rand::distributions::DistString; use rand::SeedableRng; use sha2::Digest; @@ -416,10 +416,8 @@ impl TunnelTest { let server = query.server; let transport = query.transport; - let response = self.on_recursive_dns_query( - query.message.for_slice_ref(), - &ref_state.global_dns_records, - ); + let response = + self.on_recursive_dns_query(&query.message, &ref_state.global_dns_records); self.client.exec_mut(|c| { c.sut.handle_dns_response(dns::RecursiveResponse { server, @@ -531,8 +529,8 @@ impl TunnelTest { let upstream = c.dns_mapping().get_by_left(&result.server.ip()).unwrap(); c.received_tcp_dns_responses - .insert((*upstream, result.query.header().id())); - c.handle_dns_response(message.for_slice()) + .insert((*upstream, result.query.id())); + c.handle_dns_response(&message) } Err(e) => { tracing::error!("TCP DNS query failed: {e:#}"); @@ -779,28 +777,22 @@ impl TunnelTest { fn on_recursive_dns_query( &self, - query: Message<&[u8]>, + query: &dns_types::Query, global_dns_records: &DnsRecords, - ) -> Message> { - let response = MessageBuilder::new_vec(); - let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap(); - - let query = query.sole_question().unwrap(); - let name = query.qname().to_vec(); + ) -> dns_types::Response { let qtype = query.qtype(); + let domain = query.domain(); - let records = global_dns_records - .domain_records_iter(&name) - .filter(|record| qtype == record.rtype()) - .map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata)); + let response = dns_types::ResponseBuilder::for_query(query, ResponseCode::NOERROR) + .with_records( + global_dns_records + .domain_records_iter(&domain) + .filter(|record| qtype == record.rtype()) + .map(|rdata| (domain.clone(), 60 * 60 * 24, rdata)), + ) + .build(); - for record in records { - answers.push(record).unwrap(); - } - - let response = answers.into_message(); - - tracing::debug!(%name, %qtype, "Responding to DNS query"); + tracing::debug!(%domain, %qtype, "Responding to DNS query"); response } diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index bdcc93bc3..f96e5d5a6 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -2,12 +2,11 @@ use crate::{ client::{Resource, IPV4_RESOURCES, IPV6_RESOURCES}, proptest::{host_v4, host_v6}, }; -use connlib_model::RelayId; +use connlib_model::{RelayId, ResourceId}; +use dns_types::{DomainName, RecordType}; use super::sim_net::{any_ip_stack, any_port, Host}; use crate::messages::DnsServer; -use connlib_model::{DomainName, ResourceId}; -use domain::base::Rtype; use prop::collection; use proptest::{prelude::*, sample}; use std::{ @@ -89,7 +88,7 @@ pub(crate) enum Transition { pub(crate) struct DnsQuery { pub(crate) domain: DomainName, /// The type of DNS query we should send. - pub(crate) r_type: Rtype, + pub(crate) r_type: RecordType, /// The DNS query ID. pub(crate) query_id: u16, pub(crate) dns_server: SocketAddr, @@ -283,7 +282,7 @@ fn non_dns_ports() -> impl Strategy { /// Samples up to 5 DNS queries that will be sent concurrently into connlib. pub(crate) fn dns_queries( - domain: impl Strategy)>, + domain: impl Strategy)>, dns_server: impl Strategy, ) -> impl Strategy> { // Queries can be uniquely identified by the tuple of DNS server and query ID. @@ -317,7 +316,7 @@ pub(crate) fn dns_queries( maybe_reverse_record, transport, )| { - if matches!(r_type, Rtype::PTR) { + if matches!(r_type, RecordType::PTR) { domain = DomainName::reverse_from_addr(maybe_reverse_record).unwrap(); } @@ -358,10 +357,15 @@ fn dns_transport() -> impl Strategy { /// /// Similarrly to trigger NAT64 and NAT46 we need to query for A when only AAAA is available and vice versa. pub(crate) fn maybe_available_response_rtypes( - available_rtypes: Vec, -) -> impl Strategy { - if available_rtypes.contains(&Rtype::A) || available_rtypes.contains(&Rtype::AAAA) { - sample::select(vec![Rtype::PTR, Rtype::MX, Rtype::A, Rtype::AAAA]) + available_rtypes: Vec, +) -> impl Strategy { + if available_rtypes.contains(&RecordType::A) || available_rtypes.contains(&RecordType::AAAA) { + sample::select(vec![ + RecordType::PTR, + RecordType::MX, + RecordType::A, + RecordType::AAAA, + ]) } else { sample::select(available_rtypes) } diff --git a/rust/dns-over-tcp/Cargo.toml b/rust/dns-over-tcp/Cargo.toml index 970cb1931..08e7e71f4 100644 --- a/rust/dns-over-tcp/Cargo.toml +++ b/rust/dns-over-tcp/Cargo.toml @@ -7,10 +7,9 @@ license = { workspace = true } [dependencies] anyhow = { workspace = true } -domain = { workspace = true } +dns-types = { workspace = true } firezone-logging = { workspace = true } ip-packet = { workspace = true } -itertools = { workspace = true } rand = { workspace = true } smoltcp = { workspace = true, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] } tracing = { workspace = true } diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs index 339255639..03410cabf 100644 --- a/rust/dns-over-tcp/src/client.rs +++ b/rust/dns-over-tcp/src/client.rs @@ -9,7 +9,6 @@ use crate::{ time::smol_now, }; use anyhow::{anyhow, bail, Context as _, Result}; -use domain::{base::Message, dep::octseq::OctetsInto}; use ip_packet::IpPacket; use rand::{rngs::StdRng, Rng, SeedableRng}; use smoltcp::{ @@ -36,9 +35,9 @@ pub struct Client { sockets_by_remote: BTreeMap, local_ports_by_socket: HashMap, /// Queries we should send to a DNS resolver. - pending_queries_by_remote: HashMap>>>, + pending_queries_by_remote: HashMap>, /// Queries we have sent to a DNS resolver and are waiting for a reply. - sent_queries_by_remote: HashMap>>>, + sent_queries_by_remote: HashMap>, query_results: VecDeque, @@ -50,9 +49,9 @@ pub struct Client { #[derive(Debug)] pub struct QueryResult { - pub query: Message>, + pub query: dns_types::Query, pub server: SocketAddr, - pub result: Result>>, + pub result: Result, } impl Client { @@ -88,9 +87,7 @@ impl Client { /// Send the given DNS query to the target server. /// /// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them. - pub fn send_query(&mut self, server: SocketAddr, message: Message>) -> Result<()> { - anyhow::ensure!(!message.header().qr(), "Message is a DNS response!"); - + pub fn send_query(&mut self, server: SocketAddr, message: dns_types::Query) -> Result<()> { self.pending_queries_by_remote .entry(server) .or_default() @@ -308,8 +305,8 @@ impl Client { fn send_pending_queries( socket: &mut tcp::Socket, server: SocketAddr, - pending_queries: &mut VecDeque>>, - sent_queries: &mut HashMap>>, + pending_queries: &mut VecDeque, + sent_queries: &mut HashMap, query_results: &mut VecDeque, ) { loop { @@ -321,9 +318,9 @@ fn send_pending_queries( break; }; - match codec::try_send(socket, query.for_slice_ref()).context("Failed to send DNS query") { + match codec::try_send(socket, query.as_bytes()).context("Failed to send DNS query") { Ok(()) => { - let replaced = sent_queries.insert(query.header().id(), query).is_some(); + let replaced = sent_queries.insert(query.id(), query).is_some(); debug_assert!(!replaced, "Query ID is not unique"); } Err(e) => { @@ -344,8 +341,8 @@ fn send_pending_queries( fn recv_responses( socket: &mut tcp::Socket, server: SocketAddr, - pending_queries: &mut VecDeque>>, - sent_queries: &mut HashMap>>, + pending_queries: &mut VecDeque, + sent_queries: &mut HashMap, query_results: &mut VecDeque, ) { let Some(result) = try_recv_response(socket) @@ -358,13 +355,13 @@ fn recv_responses( let new_results = result .and_then(|response| { let query = sent_queries - .remove(&response.header().id()) + .remove(&response.id()) .context("DNS resolver sent response for unknown query")?; Ok(vec![QueryResult { query, server, - result: Ok(response.octets_into()), + result: Ok(response), }]) }) .unwrap_or_else(|e| { @@ -379,8 +376,8 @@ fn recv_responses( fn fail_all_queries<'a>( error: &'a anyhow::Error, server: SocketAddr, - pending_queries: &'a mut VecDeque>>, - sent_queries: &'a mut HashMap>>, + pending_queries: &'a mut VecDeque, + sent_queries: &'a mut HashMap, ) -> impl Iterator + 'a { let pending_queries = pending_queries.drain(..); let sent_queries = sent_queries.drain().map(|(_, query)| query); @@ -391,7 +388,7 @@ fn fail_all_queries<'a>( fn into_failed_results( server: SocketAddr, - iter: impl IntoIterator>>, + iter: impl IntoIterator, make_error: impl Fn() -> anyhow::Error, ) -> impl Iterator { iter.into_iter().map(move |query| QueryResult { @@ -401,18 +398,14 @@ fn into_failed_results( }) } -fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result>> { +fn try_recv_response(socket: &mut tcp::Socket) -> Result> { if !socket.can_recv() { tracing::trace!(state = %socket.state(), "Not yet ready to receive next message"); return Ok(None); } - let Some(message) = codec::try_recv(socket)? else { - return Ok(None); - }; + let maybe_response = codec::try_recv(socket)?; - anyhow::ensure!(message.header().qr(), "DNS message is a query!"); - - Ok(Some(message)) + Ok(maybe_response) } diff --git a/rust/dns-over-tcp/src/codec.rs b/rust/dns-over-tcp/src/codec.rs index a4a32e25f..da95b0e71 100644 --- a/rust/dns-over-tcp/src/codec.rs +++ b/rust/dns-over-tcp/src/codec.rs @@ -6,17 +6,10 @@ //! Source: . use anyhow::{Context as _, Result}; -use domain::{ - base::{iana::Rcode, Message, ParsedName, Rtype}, - rdata::AllRecordData, -}; -use itertools::Itertools as _; use smoltcp::socket::tcp; -pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()> { - let response = message.as_slice(); - - let dns_message_length = (response.len() as u16).to_be_bytes(); +pub fn try_send(socket: &mut tcp::Socket, message: &[u8]) -> Result<()> { + let dns_message_length = (message.len() as u16).to_be_bytes(); let written = socket .send_slice(&dns_message_length) @@ -28,37 +21,40 @@ pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()> ); let written = socket - .send_slice(response) + .send_slice(message) .context("Failed to write DNS message")?; anyhow::ensure!( - written == response.len(), + written == message.len(), "Not enough space in write buffer for DNS message" ); - if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) { - if let Some(ParsedMessage { - qid, - qname, - qtype, - response, - rcode, - records, - }) = parse(message) - { - if response { - let records = records.into_iter().join(" | "); - tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); - } else { - tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string()); - } - } - } + // if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) { + // if let Some(ParsedMessage { + // qid, + // qname, + // qtype, + // response, + // rcode, + // records, + // }) = parse(message) + // { + // if response { + // let records = records.into_iter().join(" | "); + // tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + // } else { + // tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string()); + // } + // } + // } Ok(()) } -pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result>> { +pub fn try_recv<'b, M>(socket: &'b mut tcp::Socket) -> Result> +where + M: TryFrom<&'b [u8], Error: std::error::Error + Send + Sync + 'static>, +{ let maybe_message = socket .recv(|r| { // DNS over TCP has a 2-byte length prefix at the start, see . @@ -70,72 +66,30 @@ pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result [{records}]", qtype.to_string()); - } else { - tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string()); - } - } - } + // if tracing::event_enabled!(target: "wire::dns::tcp::recv", tracing::Level::TRACE) { + // if let Some(ParsedMessage { + // qid, + // qname, + // qtype, + // rcode, + // response, + // records, + // }) = maybe_message.and_then(parse) + // { + // if response { + // let records = records.into_iter().join(" | "); + // tracing::trace!(target: "wire::dns::tcp::recv", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + // } else { + // tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string()); + // } + // } + // } Ok(maybe_message) } - -fn parse(message: Message<&[u8]>) -> Option> { - let question = message.sole_question().ok()?; - let answers = message.answer().ok()?; - - let qtype = question.qtype(); - let qname = question.into_qname(); - let qid = message.header().id(); - let response = message.header().qr(); - let rcode = message.header().rcode(); - let records = answers - .into_iter() - .filter_map(|r| { - let data = r - .ok()? - .into_any_record::>() - .ok()? - .data() - .clone(); - - Some(data) - }) - .collect(); - - Some(ParsedMessage { - qid, - qname, - rcode, - qtype, - response, - records, - }) -} - -struct ParsedMessage<'a> { - qid: u16, - qname: ParsedName<&'a [u8]>, - qtype: Rtype, - rcode: Rcode, - response: bool, - records: Vec>>, -} diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs index 0ae3b13c4..63f3aa25e 100644 --- a/rust/dns-over-tcp/src/server.rs +++ b/rust/dns-over-tcp/src/server.rs @@ -9,7 +9,6 @@ use crate::{ time::smol_now, }; use anyhow::{Context as _, Result}; -use domain::{base::Message, dep::octseq::OctetsInto as _}; use ip_packet::IpPacket; use smoltcp::{ iface::{Interface, PollResult, SocketHandle, SocketSet}, @@ -38,7 +37,7 @@ pub struct Server { } pub struct Query { - pub message: Message>, + pub message: dns_types::Query, /// The local address of the socket that received the query. pub local: SocketAddr, /// The remote address of the client that sent the query. @@ -151,16 +150,16 @@ impl Server { &mut self, src: SocketAddr, dst: SocketAddr, - message: Message>, + response: dns_types::Response, ) -> Result<()> { let handle = self .pending_sockets_by_local_remote_and_query_id - .remove(&(src, dst, message.header().id())) + .remove(&(src, dst, response.id())) .context("No pending query found for message")?; let socket = self.sockets.get_mut::(handle); - write_tcp_dns_response(socket, message.for_slice_ref()) + codec::try_send(socket, &response.into_bytes(u16::MAX)) .inspect_err(|_| socket.abort()) // Abort socket on error. .context("Failed to write DNS response")?; @@ -191,7 +190,7 @@ impl Server { while let Some(result) = try_recv_query(socket, local).transpose() { match result { Ok((message, remote)) => { - let qid = message.header().id(); + let qid = message.id(); tracing::trace!(%local, %remote, %qid, "Received DNS query"); @@ -236,7 +235,7 @@ impl Server { fn try_recv_query( socket: &mut tcp::Socket, listen: SocketAddr, -) -> Result>, SocketAddr)>> { +) -> Result> { // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. { @@ -271,28 +270,16 @@ fn try_recv_query( return Ok(None); } - let Some(message) = codec::try_recv(socket)? else { + let Some(query) = codec::try_recv(socket)? else { return Ok(None); }; - anyhow::ensure!(!message.header().qr(), "DNS message is a response!"); - - let message = message.octets_into(); - let remote = socket .remote_endpoint() .context("Unknown remote endpoint despite having just received a message")?; Ok(Some(( - message, + query, SocketAddr::new(remote.addr.into(), remote.port), ))) } - -fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> { - anyhow::ensure!(response.header().qr(), "DNS message is a query!"); - - codec::try_send(socket, response)?; - - Ok(()) -} diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs index 7ff038fe2..759e1fde2 100644 --- a/rust/dns-over-tcp/tests/client_and_server.rs +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -5,7 +5,7 @@ use std::{ }; use dns_over_tcp::QueryResult; -use domain::base::{iana::Rcode, Message, MessageBuilder, Name, Rtype}; +use dns_types::{Query, RecordType, ResponseBuilder, ResponseCode}; #[test] fn smoke() { @@ -26,7 +26,10 @@ fn smoke() { for id in 0..5 { dns_client - .send_query(resolver_addr, a_query("example.com", id)) + .send_query( + resolver_addr, + Query::new("example.com".parse().unwrap(), RecordType::A).with_id(id), + ) .unwrap(); } @@ -41,16 +44,6 @@ fn smoke() { } } -fn a_query(domain: &str, id: u16) -> Message> { - let mut builder = MessageBuilder::new_vec().question(); - builder.header_mut().set_id(id); - builder - .push((Name::vec_from_str(domain).unwrap(), Rtype::A)) - .unwrap(); - - builder.into_message() -} - fn progress( dns_client: &mut dns_over_tcp::Client, dns_server: &mut dns_over_tcp::Server, @@ -67,13 +60,12 @@ fn progress( } if let Some(query) = dns_server.poll_queries() { - let response = MessageBuilder::new_vec() - .start_answer(&query.message, Rcode::NXDOMAIN) - .unwrap() - .into_message(); - dns_server - .send_message(query.local, query.remote, response) + .send_message( + query.local, + query.remote, + ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(), + ) .unwrap(); continue; } diff --git a/rust/dns-over-tcp/tests/smoke_server.rs b/rust/dns-over-tcp/tests/smoke_server.rs index 0c296d5db..ec37b3ddb 100644 --- a/rust/dns-over-tcp/tests/smoke_server.rs +++ b/rust/dns-over-tcp/tests/smoke_server.rs @@ -7,7 +7,7 @@ use std::{ }; use anyhow::{Context as _, Result}; -use domain::base::{iana::Rcode, MessageBuilder}; +use dns_types::{ResponseBuilder, ResponseCode}; use firezone_bin_shared::TunDeviceManager; use ip_network::Ipv4Network; use tokio::task::JoinSet; @@ -107,13 +107,12 @@ impl Eventloop { } if let Some(query) = self.dns_server.poll_queries() { - let response = MessageBuilder::new_vec() - .start_answer(&query.message, Rcode::NXDOMAIN) - .unwrap() - .into_message(); - self.dns_server - .send_message(query.local, query.remote, response) + .send_message( + query.local, + query.remote, + ResponseBuilder::for_query(&query.message, ResponseCode::NXDOMAIN).build(), + ) .unwrap(); continue; } diff --git a/rust/dns-types/Cargo.toml b/rust/dns-types/Cargo.toml new file mode 100644 index 000000000..abae2b8e7 --- /dev/null +++ b/rust/dns-types/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "dns-types" +version = "0.1.0" +edition = { workspace = true } +license = { workspace = true } + +[lib] +path = "lib.rs" + +[dependencies] +domain = { version = "0.10", features = ["serde"] } # Not a workspace dependency because we don't want any other crates to depend on it. +thiserror = { workspace = true } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/rust/dns-types/lib.rs b/rust/dns-types/lib.rs new file mode 100644 index 000000000..df5757977 --- /dev/null +++ b/rust/dns-types/lib.rs @@ -0,0 +1,357 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + +use domain::{ + base::{ + message_builder::AnswerBuilder, name::FlattenInto, HeaderCounts, Message, MessageBuilder, + ParsedName, Question, RecordSection, + }, + dep::octseq::OctetsInto, + rdata::AllRecordData, +}; + +pub mod prelude { + // Re-export trait names so other crates can call the functions on them. + // We don't export the name though so that it cannot conflict. + pub use domain::base::name::FlattenInto as _; + pub use domain::base::RecordData as _; + pub use domain::base::ToName as _; +} + +pub const MAX_NAME_LEN: usize = domain::base::Name::MAX_LEN; + +pub type RecordType = domain::base::iana::Rtype; + +pub type DomainNameRef<'a> = domain::base::Name<&'a [u8]>; +pub type Record<'a> = + domain::base::Record, AllRecordData<&'a [u8], ParsedName<&'a [u8]>>>; +pub type RecordData<'a> = AllRecordData<&'a [u8], ParsedName<&'a [u8]>>; + +pub type DomainName = domain::base::Name>; +pub type OwnedRecord = domain::base::Record, DomainName>>; +pub type OwnedRecordData = AllRecordData, DomainName>; + +pub type ResponseCode = domain::base::iana::Rcode; + +#[derive(Clone)] +pub struct Query { + inner: Message>, +} + +impl std::fmt::Debug for Query { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Query") + .field("qid", &self.inner.header().id()) + .field("type", &self.qtype()) + .field("domain", &self.domain()) + .finish() + } +} + +impl Query { + pub fn parse(slice: &[u8]) -> Result { + let message = Message::from_octets(slice).map_err(|_| Error::TooShort)?; + + if message.header().qr() { + return Err(Error::NotAQuery); + } + + // We don't need to support multiple questions/qname in a single query because + // nobody does it and since this run with each packet we want to squeeze as much optimization + // as we can therefore we won't do it. + // + // See: https://stackoverflow.com/a/55093896 + let _ = message.sole_question()?; // Verify that there is exactly one question. + + // Verify that we can parse the answers + all records + for record in message.answer()? { + record?.into_any_record::>()?; + } + + Ok(Self { + inner: message.octets_into(), + }) + } + + pub fn new(domain: DomainName, rtype: RecordType) -> Self { + let mut inner = MessageBuilder::new_vec().question(); + inner.header_mut().set_qr(false); + inner.header_mut().set_rd(true); // Default to recursion desired. + inner.header_mut().set_random_id(); // Default to a random id. + + inner + .push((domain, rtype)) + .expect("Vec-backed message builder never fails"); + + Self { + inner: inner.into_message(), + } + } + + pub fn with_id(mut self, id: u16) -> Self { + self.inner.header_mut().set_id(id); + + self + } + + pub fn id(&self) -> u16 { + self.inner.header().id() + } + + pub fn domain(&self) -> DomainName { + self.question().into_qname().flatten_into() + } + + pub fn qtype(&self) -> RecordType { + self.question().qtype() + } + + pub fn into_bytes(self) -> Vec { + self.inner.into_octets() + } + + pub fn as_bytes(&self) -> &[u8] { + self.inner.as_slice() + } + + fn question(&self) -> Question> { + self.inner.sole_question().expect("verified in ctor") + } +} + +impl TryFrom<&[u8]> for Query { + type Error = Error; + + fn try_from(slice: &[u8]) -> Result { + Self::parse(slice) + } +} + +impl TryFrom<&[u8]> for Response { + type Error = Error; + + fn try_from(slice: &[u8]) -> Result { + Self::parse(slice) + } +} + +pub struct Response { + inner: Message>, +} + +impl std::fmt::Debug for Response { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Response") + .field("qid", &self.inner.header().id()) + .field("domain", &self.domain()) + .field("type", &self.qtype()) + .field("response_code", &self.response_code()) + .finish_non_exhaustive() // TODO: Add records? + } +} + +impl Response { + /// Creates an empty, "NOERROR" response for the given query. + pub fn no_error(query: &Query) -> Self { + ResponseBuilder::for_query(query, ResponseCode::NOERROR).build() + } + + pub fn servfail(query: &Query) -> Self { + ResponseBuilder::for_query(query, ResponseCode::SERVFAIL).build() + } + + pub fn nxdomain(query: &Query) -> Self { + ResponseBuilder::for_query(query, ResponseCode::NXDOMAIN).build() + } + + pub fn parse(slice: &[u8]) -> Result { + let message = Message::from_octets(slice).map_err(|_| Error::TooShort)?; + + if !message.header().qr() { + return Err(Error::NotAResponse); + } + + let _ = message.sole_question()?; // Verify that there is exactly one question. + + // Verify that we can parse the answers + all records + for record in message.answer()? { + record?.into_any_record::>()?; + } + + Ok(Self { + inner: message.octets_into(), + }) + } + + pub fn id(&self) -> u16 { + self.inner.header().id() + } + + pub fn truncated(&self) -> bool { + self.inner.header().tc() + } + + pub fn domain(&self) -> DomainName { + self.question().into_qname().flatten_into() + } + + pub fn qtype(&self) -> RecordType { + self.question().qtype() + } + + pub fn response_code(&self) -> ResponseCode { + self.inner.header().rcode() + } + + pub fn records(&self) -> impl Iterator> { + self.answer().into_iter().map(|r| { + r.expect("verified in ctor") + .into_any_record::>() + .expect("verified in ctor") + }) + } + + /// Serializes this response into a byte slice. + /// + /// The `max_len` parameter specifies the maximum size of the payload. + /// In case the payload is bigger than `max_len`, it will be truncated and the TC bit in the header will be set. + pub fn into_bytes(mut self, max_len: u16) -> Vec { + let qid = self.inner.header().id(); + + let len = self.inner.as_slice().len(); + if len <= max_len as usize { + return self.inner.into_octets(); + } + + tracing::debug!(%len, %max_len, %qid, domain = %self.domain(), "Truncating DNS response"); + + self.inner.header_mut().set_tc(true); + + let start_of_answer = self.answer().pos(); + + let mut bytes = self.inner.into_octets(); + bytes.truncate(start_of_answer); + + let headercounts = HeaderCounts::for_message_slice_mut(&mut bytes); + + // We deleted everything after answers, reset all counts to 0. + headercounts.as_slice_mut().fill(0); + + // Set the question count to 1. + headercounts.set_qdcount(1); + + bytes + } + + fn question(&self) -> Question> { + self.inner.sole_question().expect("verified in ctor") + } + + fn answer(&self) -> RecordSection<'_, Vec> { + self.inner.answer().expect("verified in ctor") + } +} + +pub struct ResponseBuilder { + inner: AnswerBuilder>, +} + +impl ResponseBuilder { + pub fn for_query(query: &Query, code: ResponseCode) -> Self { + let inner = MessageBuilder::new_vec() + .start_answer(&query.inner, code) + .expect("Vec-backed message builder never fails"); + + Self { inner } + } + + pub fn with_records(mut self, records: impl IntoIterator>) -> Self { + for record in records { + self.inner + .push(record.into()) + .expect("Vec-backed message builder never fails"); + } + + self + } + + pub fn build(self) -> Response { + Response { + inner: self.inner.into_message(), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Bytes slice is too short to contain a message")] + TooShort, + #[error("DNS message is not a query")] + NotAQuery, + #[error("DNS message is not a response")] + NotAResponse, + #[error(transparent)] + Parse(#[from] domain::base::wire::ParseError), +} + +pub mod records { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + use domain::rdata::{rfc1035::TxtError, Aaaa, Ptr, Srv, Txt, A}; + + use super::*; + + pub fn ptr(domain: DomainName) -> OwnedRecordData { + OwnedRecordData::Ptr(Ptr::new(domain)) + } + + pub fn a(ip: Ipv4Addr) -> OwnedRecordData { + OwnedRecordData::A(A::new(ip)) + } + + pub fn aaaa(ip: Ipv6Addr) -> OwnedRecordData { + OwnedRecordData::Aaaa(Aaaa::new(ip)) + } + + pub fn ip(ip: IpAddr) -> OwnedRecordData { + match ip { + IpAddr::V4(ip) => a(ip), + IpAddr::V6(ip) => aaaa(ip), + } + } + + pub fn txt(content: Vec) -> Result { + Ok(OwnedRecordData::Txt(Txt::from_octets(content)?)) + } + + pub fn srv(priority: u16, weight: u16, port: u16, target: DomainName) -> OwnedRecordData { + OwnedRecordData::Srv(Srv::new(priority, weight, port, target)) + } +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use super::*; + + #[test] + fn can_truncate_response() { + let domain = DomainName::vec_from_str("example.com").unwrap(); + + let query = Query::new(domain.clone(), RecordType::A); + let response = ResponseBuilder::for_query(&query, ResponseCode::NOERROR) + .with_records(std::iter::repeat_n( + (domain.clone(), 1, records::a(Ipv4Addr::LOCALHOST)), + 1000, + )) + .build(); + + let bytes = response.into_bytes(1000); + + let parsed_response = Response::parse(&bytes).unwrap(); + + assert!(parsed_response.truncated()); + assert_eq!(parsed_response.records().count(), 0); + assert_eq!(parsed_response.domain(), domain); + } +} diff --git a/rust/gateway/Cargo.toml b/rust/gateway/Cargo.toml index bcc01a268..54e6accf5 100644 --- a/rust/gateway/Cargo.toml +++ b/rust/gateway/Cargo.toml @@ -15,7 +15,7 @@ chrono = { workspace = true } clap = { workspace = true } connlib-model = { workspace = true } dns-lookup = { workspace = true } -domain = { workspace = true } +dns-types = { workspace = true } either = { workspace = true } firezone-bin-shared = { workspace = true } firezone-logging = { workspace = true } diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index ceecfe868..74c98058d 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -1,8 +1,8 @@ use anyhow::{Context as _, Result}; use boringtun::x25519::PublicKey; -use connlib_model::DomainName; #[cfg(not(target_os = "windows"))] use dns_lookup::{AddrInfoHints, AddrInfoIter, LookupError}; +use dns_types::DomainName; use firezone_bin_shared::TunDeviceManager; use firezone_logging::{telemetry_event, telemetry_span}; use firezone_tunnel::messages::gateway::{ diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index 9752cc22e..943c1395e 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -14,6 +14,7 @@ backoff = { workspace = true } clap = { workspace = true, features = ["derive", "env", "string"] } connlib-client-shared = { workspace = true } connlib-model = { workspace = true } +dns-types = { workspace = true } firezone-bin-shared = { workspace = true } firezone-logging = { workspace = true } firezone-telemetry = { workspace = true } diff --git a/rust/headless-client/src/dns_control/linux.rs b/rust/headless-client/src/dns_control/linux.rs index 2c68fda67..390d6ab31 100644 --- a/rust/headless-client/src/dns_control/linux.rs +++ b/rust/headless-client/src/dns_control/linux.rs @@ -1,6 +1,6 @@ use super::DnsController; use anyhow::{bail, Context as _, Result}; -use connlib_model::DomainName; +use dns_types::DomainName; use firezone_bin_shared::{platform::DnsControlMethod, TunDeviceManager}; use std::{net::IpAddr, process::Command, str::FromStr}; diff --git a/rust/headless-client/src/dns_control/linux/etc_resolv_conf.rs b/rust/headless-client/src/dns_control/linux/etc_resolv_conf.rs index de840d11d..61be67a2c 100644 --- a/rust/headless-client/src/dns_control/linux/etc_resolv_conf.rs +++ b/rust/headless-client/src/dns_control/linux/etc_resolv_conf.rs @@ -1,5 +1,5 @@ use anyhow::{bail, Context, Result}; -use connlib_model::DomainName; +use dns_types::DomainName; use std::{ fs, io::{self, Write}, diff --git a/rust/headless-client/src/dns_control/windows.rs b/rust/headless-client/src/dns_control/windows.rs index 1b41e0b4c..a588c8c8f 100644 --- a/rust/headless-client/src/dns_control/windows.rs +++ b/rust/headless-client/src/dns_control/windows.rs @@ -15,7 +15,7 @@ use super::DnsController; use anyhow::{Context as _, Result}; -use connlib_model::DomainName; +use dns_types::DomainName; use firezone_bin_shared::platform::{DnsControlMethod, CREATE_NO_WINDOW, TUNNEL_UUID}; use firezone_bin_shared::windows::error::EPT_S_NOT_REGISTERED; use std::{io, net::IpAddr, os::windows::process::CommandExt, path::Path, process::Command}; diff --git a/rust/headless-client/src/lib.rs b/rust/headless-client/src/lib.rs index a417f6bd9..96738a563 100644 --- a/rust/headless-client/src/lib.rs +++ b/rust/headless-client/src/lib.rs @@ -12,7 +12,8 @@ use anyhow::{Context as _, Result}; use connlib_client_shared::Callbacks; -use connlib_model::{DomainName, ResourceView}; +use connlib_model::ResourceView; +use dns_types::DomainName; use firezone_bin_shared::platform::DnsControlMethod; use firezone_logging::FilterReloadHandle; use std::{ diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 3b92f3146..419d9652f 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -43,9 +43,9 @@ pub const MAX_IP_SIZE: usize = 1280; /// /// IPv6 headers are always a fixed size whereas IPv4 headers can vary. /// The max length of an IPv4 header is > the fixed length of an IPv6 header. -pub const MAX_IP_PAYLOAD: usize = MAX_IP_SIZE - etherparse::Ipv4Header::MAX_LEN; +pub const MAX_IP_PAYLOAD: u16 = (MAX_IP_SIZE - etherparse::Ipv4Header::MAX_LEN) as u16; /// The maximum payload a UDP packet can have. -pub const MAX_UDP_PAYLOAD: usize = MAX_IP_PAYLOAD - etherparse::UdpHeader::LEN; +pub const MAX_UDP_PAYLOAD: u16 = MAX_IP_PAYLOAD - etherparse::UdpHeader::LEN as u16; /// The maximum size of the payload that Firezone will send between nodes. ///