diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c83246876..03667464d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1912,9 +1912,11 @@ dependencies = [ "domain", "firezone-bin-shared", "firezone-logging", + "futures", "ip-packet", "ip_network", "itertools 0.13.0", + "rand 0.8.5", "smoltcp", "tokio", "tracing", diff --git a/rust/dns-over-tcp/Cargo.toml b/rust/dns-over-tcp/Cargo.toml index a59bf0d4a..6052dea2e 100644 --- a/rust/dns-over-tcp/Cargo.toml +++ b/rust/dns-over-tcp/Cargo.toml @@ -9,12 +9,14 @@ anyhow = "1.0" domain = { workspace = true } ip-packet = { workspace = true } itertools = "0.13" +rand = "0.8" smoltcp = { version = "0.11", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] } tracing = { workspace = true } [dev-dependencies] firezone-bin-shared = { workspace = true } firezone-logging = { workspace = true } +futures = "0.3" ip_network = { version = "0.4", default-features = false } tokio = { workspace = true, features = ["process", "rt", "macros"] } tun = { workspace = true } diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs new file mode 100644 index 000000000..589f28ffa --- /dev/null +++ b/rust/dns-over-tcp/src/client.rs @@ -0,0 +1,417 @@ +use std::{ + collections::{BTreeSet, HashMap, HashSet, VecDeque}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + ops::RangeInclusive, + time::Instant, +}; + +use crate::{codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice}; +use anyhow::{anyhow, bail, Context as _, Result}; +use domain::{base::Message, dep::octseq::OctetsInto}; +use ip_packet::IpPacket; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use smoltcp::{ + iface::{Interface, SocketSet}, + socket::tcp::{self, Socket}, +}; + +/// A sans-io DNS-over-TCP client. +/// +/// The client maintains a single TCP connection for each configured resolver. +/// If the TCP connection fails for some reason, we try to establish a new one. +/// +/// One of the design goals of this client is to always provide a result for a query. +/// If the TCP connection fails, we report all currently pending queries to that resolver as failed. +/// +/// There are however currently no timeouts. +/// If the upstream resolver refuses to answer, we don't fail the query. +pub struct Client { + device: InMemoryDevice, + interface: Interface, + source_ips: Option<(Ipv4Addr, Ipv6Addr)>, + /// The port range we are allowed to use on our local interface for outgoing connections. + port_range: RangeInclusive, + + sockets: SocketSet<'static>, + sockets_by_remote: HashMap, + local_ports_by_socket: HashMap, + /// Queries we should send to a DNS resolver. + pending_queries_by_remote: HashMap>>>, + /// Queries we have sent to a DNS resolver and are waiting for a reply. + sent_queries_by_remote: HashMap>>>, + + query_results: VecDeque, + + rng: StdRng, +} + +#[derive(Debug)] +pub struct QueryResult { + pub query: Message>, + pub server: SocketAddr, + pub result: Result>>, +} + +impl Client { + pub fn new(now: Instant, port_range: RangeInclusive, seed: [u8; 32]) -> Self { + assert!(!port_range.contains(&0), "0 port must not be possible"); + + let mut device = InMemoryDevice::default(); + let interface = create_interface(&mut device, now); + + Self { + device, + interface, + sockets: SocketSet::new(Vec::default()), + source_ips: None, + port_range, + sent_queries_by_remote: Default::default(), + query_results: Default::default(), + rng: StdRng::from_seed(seed), + sockets_by_remote: Default::default(), + local_ports_by_socket: Default::default(), + pending_queries_by_remote: Default::default(), + } + } + + /// Sets the IPv4 and IPv6 source ips to use for outgoing packets. + pub fn set_source_interface(&mut self, v4: Ipv4Addr, v6: Ipv6Addr) { + self.source_ips = Some((v4, v6)); + } + + /// Connect to the specified DNS resolvers. + /// + /// All currently pending queries will be reported as failed. + pub fn connect_to_resolvers(&mut self, resolvers: BTreeSet) -> Result<()> { + let (ipv4_source, ipv6_source) = self.source_ips.context("Missing source IPs")?; + + // First, clear all local state. + self.sockets = SocketSet::new(vec![]); + self.sockets_by_remote.clear(); + self.local_ports_by_socket.clear(); + + self.query_results + .extend( + self.pending_queries_by_remote + .drain() + .flat_map(|(server, queries)| { + into_failed_results(server, queries, || anyhow!("Aborted")) + }), + ); + self.query_results + .extend( + self.sent_queries_by_remote + .drain() + .flat_map(|(server, queries)| { + into_failed_results(server, queries.into_values(), || anyhow!("Aborted")) + }), + ); + + // Second, try to create all new sockets. + let new_sockets = std::iter::zip(self.sample_unique_ports(resolvers.len())?, resolvers).map(|(port, server)| { + let local_endpoint = match server { + SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port), + SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port), + }; + + let mut socket = create_tcp_socket(); + + socket + .connect(self.interface.context(), server, local_endpoint) + .context("Failed to connect socket")?; + + tracing::info!(local = %local_endpoint, remote = %server, "Connecting to DNS resolver"); + + Ok((server, local_endpoint, socket)) + }) + .collect::>>()?; + + // Third, if everything was successful, change the local state. + for (server, local_endpoint, socket) in new_sockets { + let handle = self.sockets.add(socket); + + self.sockets_by_remote.insert(server, handle); + self.local_ports_by_socket + .insert(handle, local_endpoint.port()); + } + + Ok(()) + } + + /// Send the given DNS query to the target server. + /// + /// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them. + pub fn send_query(&mut self, server: SocketAddr, message: Message>) -> Result<()> { + anyhow::ensure!(!message.header().qr(), "Message is a DNS response!"); + anyhow::ensure!( + self.sockets_by_remote.contains_key(&server), + "Unknown DNS resolver" + ); + + self.pending_queries_by_remote + .entry(server) + .or_default() + .push_back(message); + + Ok(()) + } + + /// Checks whether this client can handle the given packet. + /// + /// Only TCP packets originating from one of the connected DNS resolvers are accepted. + pub fn accepts(&self, packet: &IpPacket) -> bool { + let Some(tcp) = packet.as_tcp() else { + tracing::trace!(?packet, "Not a TCP packet"); + + return false; + }; + + let Some((ipv4_source, ipv6_source)) = self.source_ips else { + tracing::trace!("No source interface"); + + return false; + }; + + // If the packet doesn't match our source interface, we don't want it. + match packet.destination() { + IpAddr::V4(v4) if v4 != ipv4_source => return false, + IpAddr::V6(v6) if v6 != ipv6_source => return false, + _ => {} + } + + let remote = SocketAddr::new(packet.source(), tcp.source_port()); + let has_socket = self.sockets_by_remote.contains_key(&remote); + + if !has_socket && tracing::enabled!(tracing::Level::TRACE) { + let open_sockets = BTreeSet::from_iter(self.sockets_by_remote.keys().copied()); + + tracing::trace!(%remote, ?open_sockets, "No open socket for remote"); + } + + has_socket + } + + /// Handle the [`IpPacket`]. + /// + /// This function only inserts the packet into a buffer. + /// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called. + pub fn handle_inbound(&mut self, packet: IpPacket) { + debug_assert!(self.accepts(&packet)); + + self.device.receive(packet); + } + + /// Returns [`IpPacket`]s that should be sent. + pub fn poll_outbound(&mut self) -> Option { + self.device.next_send() + } + + /// Returns the next [`QueryResult`]. + pub fn poll_query_result(&mut self) -> Option { + self.query_results.pop_front() + } + + /// Inform the client that time advanced. + /// + /// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible. + pub fn handle_timeout(&mut self, now: Instant) { + let Some((ipv4_source, ipv6_source)) = self.source_ips else { + return; + }; + + let changed = self.interface.poll( + smoltcp::time::Instant::from(now), + &mut self.device, + &mut self.sockets, + ); + + if !changed && self.pending_queries_by_remote.is_empty() { + return; + } + + for (remote, handle) in self.sockets_by_remote.iter_mut() { + let socket = self.sockets.get_mut::(*handle); + let server = *remote; + + // First, attempt to send all pending queries on this socket. + send_pending_queries( + socket, + server, + &mut self.pending_queries_by_remote, + &mut self.sent_queries_by_remote, + &mut self.query_results, + ); + + // Second, attempt to receive responses. + recv_responses( + socket, + server, + &mut self.pending_queries_by_remote, + &mut self.sent_queries_by_remote, + &mut self.query_results, + ); + + let has_pending_dns_queries = !self + .pending_queries_by_remote + .entry(server) + .or_default() + .is_empty(); + + // Third, if the socket got closed, reconnect it. + if matches!(socket.state(), tcp::State::Closed) && has_pending_dns_queries { + let local_port = self + .local_ports_by_socket + .get(handle) + .expect("must always have a port for each socket"); + let local_endpoint = match server { + SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), *local_port), + SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), *local_port), + }; + + tracing::info!(local = %local_endpoint, remote = %server, "Re-connecting to DNS resolver"); + + socket + .connect(self.interface.context(), server, local_endpoint) + .expect( + "re-connecting a closed socket with the same parameters should always work", + ); + } + } + } + + fn sample_unique_ports(&mut self, num_ports: usize) -> Result> { + let mut ports = HashSet::with_capacity(num_ports); + + if num_ports > self.port_range.len() { + bail!( + "Port range only provides {} values but we need {num_ports}", + self.port_range.len() + ) + } + + while ports.len() < num_ports { + ports.insert(self.rng.gen_range(self.port_range.clone())); + } + + Ok(ports.into_iter()) + } +} + +fn send_pending_queries( + socket: &mut Socket, + server: SocketAddr, + pending_queries_by_remote: &mut HashMap>>>, + sent_queries_by_remote: &mut HashMap>>>, + query_results: &mut VecDeque, +) { + let pending_queries = pending_queries_by_remote.entry(server).or_default(); + let sent_queries = sent_queries_by_remote.entry(server).or_default(); + + loop { + if !socket.can_send() { + break; + } + + let Some(query) = pending_queries.pop_front() else { + break; + }; + + match codec::try_send(socket, query.for_slice_ref()).context("Failed to send DNS query") { + Ok(()) => { + let replaced = sent_queries.insert(query.header().id(), query).is_some(); + debug_assert!(!replaced, "Query ID is not unique"); + } + Err(e) => { + // We failed to send the query, declare the socket as failed. + socket.abort(); + + query_results.extend(into_failed_results( + server, + pending_queries + .drain(..) + .chain(sent_queries.drain().map(|(_, query)| query)), + || anyhow!("{e:#}"), + )); + query_results.push_back(QueryResult { + query, + server, + result: Err(e), + }); + } + } + } +} + +fn recv_responses( + socket: &mut Socket, + server: SocketAddr, + pending_queries_by_remote: &mut HashMap>>>, + sent_queries_by_remote: &mut HashMap>>>, + query_results: &mut VecDeque, +) { + let Some(result) = try_recv_response(socket) + .context("Failed to receive DNS response") + .transpose() + else { + return; // No messages on this socket, continue. + }; + let pending_queries = pending_queries_by_remote.entry(server).or_default(); + let sent_queries = sent_queries_by_remote.entry(server).or_default(); + + let new_results = result + .and_then(|response| { + let query = sent_queries + .remove(&response.header().id()) + .context("DNS resolver sent response for unknown query")?; + + Ok(vec![QueryResult { + query, + server, + result: Ok(response.octets_into()), + }]) + }) + .unwrap_or_else(|e| { + socket.abort(); + + into_failed_results( + server, + pending_queries + .drain(..) + .chain(sent_queries.drain().map(|(_, query)| query)), + || anyhow!("{e:#}"), + ) + .collect() + }); + + query_results.extend(new_results); +} + +fn into_failed_results( + server: SocketAddr, + iter: impl IntoIterator>>, + make_error: impl Fn() -> anyhow::Error, +) -> impl Iterator { + iter.into_iter().map(move |query| QueryResult { + query, + server, + result: Err(make_error()), + }) +} + +fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result>> { + anyhow::ensure!(socket.is_active(), "Socket is not active"); + + if !socket.can_recv() { + tracing::trace!("Not yet ready to receive next message"); + + return Ok(None); + } + + let Some(message) = codec::try_recv(socket)? else { + return Ok(None); + }; + + anyhow::ensure!(message.header().qr(), "DNS message is a query!"); + + Ok(Some(message)) +} diff --git a/rust/dns-over-tcp/src/codec.rs b/rust/dns-over-tcp/src/codec.rs new file mode 100644 index 000000000..a4a32e25f --- /dev/null +++ b/rust/dns-over-tcp/src/codec.rs @@ -0,0 +1,141 @@ +//! Implements sending and receiving of DNS messages over TCP. +//! +//! TCP's stream-oriented nature requires us to know how long the encoded DNS message is before we can read it. +//! For this purpose, DNS messages over TCP are prefixed using a big-endian encoded u16. +//! +//! Source: . + +use anyhow::{Context as _, Result}; +use domain::{ + base::{iana::Rcode, Message, ParsedName, Rtype}, + rdata::AllRecordData, +}; +use itertools::Itertools as _; +use smoltcp::socket::tcp; + +pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()> { + let response = message.as_slice(); + + let dns_message_length = (response.len() as u16).to_be_bytes(); + + let written = socket + .send_slice(&dns_message_length) + .context("Failed to write TCP DNS length header")?; + + anyhow::ensure!( + written == 2, + "Not enough space in write buffer for TCP DNS length header" + ); + + let written = socket + .send_slice(response) + .context("Failed to write DNS message")?; + + anyhow::ensure!( + written == response.len(), + "Not enough space in write buffer for DNS message" + ); + + if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) { + if let Some(ParsedMessage { + qid, + qname, + qtype, + response, + rcode, + records, + }) = parse(message) + { + if response { + let records = records.into_iter().join(" | "); + tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + } else { + tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string()); + } + } + } + + Ok(()) +} + +pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result>> { + let maybe_message = socket + .recv(|r| { + // DNS over TCP has a 2-byte length prefix at the start, see . + let Some((header, message)) = r.split_first_chunk::<2>() else { + return (0, None); + }; + let dns_message_length = u16::from_be_bytes(*header) as usize; + if message.len() < dns_message_length { + return (0, None); // Don't consume any bytes unless we can read the full message at once. + } + + (2 + dns_message_length, Some(Message::from_octets(message))) + }) + .context("Failed to recv TCP data")? + .transpose() + .context("Failed to parse DNS message")?; + + if tracing::event_enabled!(target: "wire::dns::tcp::recv", tracing::Level::TRACE) { + if let Some(ParsedMessage { + qid, + qname, + qtype, + rcode, + response, + records, + }) = maybe_message.and_then(parse) + { + if response { + let records = records.into_iter().join(" | "); + tracing::trace!(target: "wire::dns::tcp::recv", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + } else { + tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string()); + } + } + } + + Ok(maybe_message) +} + +fn parse(message: Message<&[u8]>) -> Option> { + let question = message.sole_question().ok()?; + let answers = message.answer().ok()?; + + let qtype = question.qtype(); + let qname = question.into_qname(); + let qid = message.header().id(); + let response = message.header().qr(); + let rcode = message.header().rcode(); + let records = answers + .into_iter() + .filter_map(|r| { + let data = r + .ok()? + .into_any_record::>() + .ok()? + .data() + .clone(); + + Some(data) + }) + .collect(); + + Some(ParsedMessage { + qid, + qname, + rcode, + qtype, + response, + records, + }) +} + +struct ParsedMessage<'a> { + qid: u16, + qname: ParsedName<&'a [u8]>, + qtype: Rtype, + rcode: Rcode, + response: bool, + records: Vec>>, +} diff --git a/rust/dns-over-tcp/src/interface.rs b/rust/dns-over-tcp/src/interface.rs new file mode 100644 index 000000000..f7dc63239 --- /dev/null +++ b/rust/dns-over-tcp/src/interface.rs @@ -0,0 +1,44 @@ +use std::time::Instant; + +use smoltcp::{ + iface::{Config, Interface, Route}, + wire::{HardwareAddress, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr}, +}; + +use crate::stub_device::InMemoryDevice; + +const IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1); +const IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1); + +/// Creates a smoltcp [`Interface`]. +/// +/// smoltcp's abstractions allow to directly plug it in a TUN device. +/// As a result, it has all the features you'd expect from a network interface: +/// - Setting IP addresses +/// - Defining routes +/// +/// In our implementation, we don't want to use any of that. +/// Our device is entirely backed by in-memory buffers and we and selectively feed IP packets to it. +/// Therefore, we configure it to: +/// - Accept any packet +/// - Define dummy IPs (localhost for IPv4 and IPv6) +/// - Define catch-all routes (0.0.0.0/0) that routes all traffic to the interface +pub fn create_interface(device: &mut InMemoryDevice, now: Instant) -> Interface { + let mut interface = Interface::new(Config::new(HardwareAddress::Ip), device, now.into()); + // Accept packets with any destination IP, not just our interface. + interface.set_any_ip(true); + + // Set our interface IPs. These are just dummies and don't show up anywhere! + interface.update_ip_addrs(|ips| { + ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into()).unwrap(); + ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into()).unwrap(); + }); + + // Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface. + interface.routes_mut().update(|routes| { + routes.push(Route::new_ipv4_gateway(IP4_ADDR)).unwrap(); + routes.push(Route::new_ipv6_gateway(IP6_ADDR)).unwrap(); + }); + + interface +} diff --git a/rust/dns-over-tcp/src/lib.rs b/rust/dns-over-tcp/src/lib.rs index 94a38418e..af49bc6b8 100644 --- a/rust/dns-over-tcp/src/lib.rs +++ b/rust/dns-over-tcp/src/lib.rs @@ -1,4 +1,21 @@ +mod client; +mod codec; +mod interface; mod server; mod stub_device; +pub use client::{Client, QueryResult}; pub use server::{Server, SocketHandle}; + +fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> { + /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. + /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. + /// Being able to buffer at least one of them means we can handle the extreme case. + /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. + const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; + + smoltcp::socket::tcp::Socket::new( + smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + ) +} diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs index 6d212ab05..dceb2ddb9 100644 --- a/rust/dns-over-tcp/src/server.rs +++ b/rust/dns-over-tcp/src/server.rs @@ -4,16 +4,14 @@ use std::{ time::Instant, }; -use crate::stub_device::InMemoryDevice; +use crate::{codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice}; use anyhow::{Context as _, Result}; -use domain::{base::Message, dep::octseq::OctetsInto as _, rdata::AllRecordData}; +use domain::{base::Message, dep::octseq::OctetsInto as _}; use ip_packet::IpPacket; -use itertools::Itertools as _; use smoltcp::{ - iface::{Config, Interface, Route, SocketSet}, + iface::{Interface, SocketSet}, socket::tcp, - storage::RingBuffer, - wire::{HardwareAddress, IpEndpoint, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr}, + wire::IpEndpoint, }; /// A sans-IO implementation of DNS-over-TCP server. @@ -43,34 +41,10 @@ pub struct Query { pub local: SocketAddr, } -const SERVER_IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1); -const SERVER_IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1); - impl Server { pub fn new(now: Instant) -> Self { let mut device = InMemoryDevice::default(); - - let mut interface = - Interface::new(Config::new(HardwareAddress::Ip), &mut device, now.into()); - // Accept packets with any destination IP, not just our interface. - interface.set_any_ip(true); - - // Set our interface IPs. These are just dummies and don't show up anywhere! - interface.update_ip_addrs(|ips| { - ips.push(Ipv4Cidr::new(SERVER_IP4_ADDR, 32).into()).unwrap(); - ips.push(Ipv6Cidr::new(SERVER_IP6_ADDR, 128).into()) - .unwrap(); - }); - - // Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface. - interface.routes_mut().update(|routes| { - routes - .push(Route::new_ipv4_gateway(SERVER_IP4_ADDR)) - .unwrap(); - routes - .push(Route::new_ipv6_gateway(SERVER_IP6_ADDR)) - .unwrap(); - }); + let interface = create_interface(&mut device, now); Self { device, @@ -98,7 +72,12 @@ impl Server { for listen_endpoint in addresses { for _ in 0..NUM_CONCURRENT_CLIENTS { - let handle = sockets.add(create_tcp_socket(listen_endpoint)); + let mut socket = create_tcp_socket(); + socket + .listen(listen_endpoint) + .expect("A fresh socket should always be able to listen"); + + let handle = sockets.add(socket); listen_endpoints.insert(handle, listen_endpoint); } @@ -149,40 +128,9 @@ impl Server { pub fn send_message(&mut self, socket: SocketHandle, message: Message>) -> Result<()> { let socket = self.sockets.get_mut::(socket.0); - let result = write_tcp_dns_response(socket, message.for_slice_ref()); - - if result.is_err() { - socket.abort(); - } - - result.context("Failed to write DNS response")?; // Bail before logging in case we failed to write the response. - - if tracing::event_enabled!(target: "wire::dns::res", tracing::Level::TRACE) { - if let Ok(question) = message.sole_question() { - let qtype = question.qtype(); - let qname = question.into_qname(); - let rcode = message.header().rcode(); - - if let Ok(record_section) = message.answer() { - let records = record_section - .into_iter() - .filter_map(|r| { - let data = r - .ok()? - .into_any_record::>() - .ok()? - .data() - .clone(); - - Some(data) - }) - .join(" | "); - let qid = message.header().id(); - - tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); - } - } - } + write_tcp_dns_response(socket, message.for_slice_ref()) + .inspect_err(|_| socket.abort()) // Abort socket on error. + .context("Failed to write DNS response")?; Ok(()) } @@ -211,28 +159,20 @@ impl Server { for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() { let listen = self.listen_endpoints.get(&handle).copied().unwrap(); - match try_recv_query(socket, listen) { - Ok(Some(message)) => { - if tracing::event_enabled!(target: "wire::dns::qry", tracing::Level::TRACE) { - if let Ok(question) = message.sole_question() { - let qtype = question.qtype(); - let qname = question.into_qname(); - let qid = message.header().id(); - - tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string()); - } + while let Some(result) = try_recv_query(socket, listen).transpose() { + match result { + Ok(message) => { + self.received_queries.push_back(Query { + message: message.octets_into(), + socket: SocketHandle(handle), + local: listen, + }); + } + Err(e) => { + tracing::debug!("Error on receiving DNS query: {e}"); + socket.abort(); + break; } - - self.received_queries.push_back(Query { - message, - socket: SocketHandle(handle), - local: listen, - }); - } - Ok(None) => {} - Err(e) => { - tracing::debug!("Error on receiving DNS query: {e}"); - socket.abort(); } } } @@ -249,28 +189,10 @@ impl Server { } } -fn create_tcp_socket(listen_endpoint: SocketAddr) -> tcp::Socket<'static> { - /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. - /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. - /// Being able to buffer at least one of them means we can handle the extreme case. - /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. - const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; - - let mut socket = tcp::Socket::new( - RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), - RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), - ); - socket - .listen(listen_endpoint) - .expect("A fresh socket should always be able to listen"); - - socket -} - -fn try_recv_query( - socket: &mut tcp::Socket, +fn try_recv_query<'b>( + socket: &'b mut tcp::Socket, listen: SocketAddr, -) -> Result>>> { +) -> Result>> { // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. { @@ -306,56 +228,19 @@ fn try_recv_query( return Ok(None); } - // Read a DNS message from the socket. - let Some(message) = socket - .recv(|r| { - // DNS over TCP has a 2-byte length prefix at the start, see . - let Some((header, message)) = r.split_first_chunk::<2>() else { - return (0, None); - }; - let dns_message_length = u16::from_be_bytes(*header) as usize; - if message.len() < dns_message_length { - return (0, None); // Don't consume any bytes unless we can read the full message at once. - } - - (2 + dns_message_length, Some(Message::from_octets(message))) - }) - .context("Failed to recv TCP data")? - .transpose() - .context("Failed to parse DNS message")? - else { + let Some(message) = codec::try_recv(socket)? else { return Ok(None); }; anyhow::ensure!(!message.header().qr(), "DNS message is a response!"); - Ok(Some(message.octets_into())) + Ok(Some(message)) } fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> { anyhow::ensure!(response.header().qr(), "DNS message is a query!"); - let response = response.as_slice(); - - let dns_message_length = (response.len() as u16).to_be_bytes(); - - let written = socket - .send_slice(&dns_message_length) - .context("Failed to write TCP DNS length header")?; - - anyhow::ensure!( - written == 2, - "Not enough space in write buffer for TCP DNS length header" - ); - - let written = socket - .send_slice(response) - .context("Failed to write DNS message")?; - - anyhow::ensure!( - written == response.len(), - "Not enough space in write buffer for DNS response" - ); + codec::try_send(socket, response)?; Ok(()) } diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs new file mode 100644 index 000000000..17a42e5de --- /dev/null +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -0,0 +1,89 @@ +use std::{ + collections::BTreeSet, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, + time::Instant, +}; + +use dns_over_tcp::QueryResult; +use domain::base::{iana::Rcode, Message, MessageBuilder, Name, Rtype}; + +#[test] +fn smoke() { + let _guard = firezone_logging::test( + "netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,smoltcp=trace,debug", + ); + + let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); + let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + + let resolver_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53)); + + let mut dns_client = dns_over_tcp::Client::new(Instant::now(), 49152..=65535, [0u8; 32]); + dns_client.set_source_interface(ipv4, ipv6); + dns_client + .connect_to_resolvers(BTreeSet::from_iter([resolver_addr])) + .unwrap(); + + let mut dns_server = dns_over_tcp::Server::new(Instant::now()); + dns_server.set_listen_addresses::<1>(vec![resolver_addr]); + + for id in 0..5 { + dns_client + .send_query(resolver_addr, a_query("example.com", id)) + .unwrap(); + } + + let results = std::iter::from_fn(|| progress(&mut dns_client, &mut dns_server)) + .take(5) + .collect::>(); + + for query_result in results { + let result = query_result.result.unwrap(); + + println!("{result:?}") + } +} + +fn a_query(domain: &str, id: u16) -> Message> { + let mut builder = MessageBuilder::new_vec().question(); + builder.header_mut().set_id(id); + builder + .push((Name::vec_from_str(domain).unwrap(), Rtype::A)) + .unwrap(); + + builder.into_message() +} + +fn progress( + dns_client: &mut dns_over_tcp::Client, + dns_server: &mut dns_over_tcp::Server, +) -> Option { + loop { + if let Some(packet) = dns_client.poll_outbound() { + dns_server.handle_inbound(packet); + continue; + } + + if let Some(packet) = dns_server.poll_outbound() { + dns_client.handle_inbound(packet); + continue; + } + + if let Some(query) = dns_server.poll_queries() { + let response = MessageBuilder::new_vec() + .start_answer(&query.message, Rcode::NXDOMAIN) + .unwrap() + .into_message(); + + dns_server.send_message(query.socket, response).unwrap(); + continue; + } + + if let Some(query) = dns_client.poll_query_result() { + return Some(query); + } + + dns_client.handle_timeout(Instant::now()); + dns_server.handle_timeout(Instant::now()); + } +} diff --git a/rust/dns-over-tcp/tests/smoke.rs b/rust/dns-over-tcp/tests/smoke_server.rs similarity index 96% rename from rust/dns-over-tcp/tests/smoke.rs rename to rust/dns-over-tcp/tests/smoke_server.rs index 7a745654e..5003653ad 100644 --- a/rust/dns-over-tcp/tests/smoke.rs +++ b/rust/dns-over-tcp/tests/smoke_server.rs @@ -16,10 +16,9 @@ use tun::Tun; const CLIENT_CONCURRENCY: usize = 3; #[tokio::test] -#[ignore = "Requires root"] +#[ignore = "Requires root & IP forwarding"] async fn smoke() { - let _guard = - firezone_logging::test("netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,debug"); + let _guard = firezone_logging::test("netlink_proto=off,wire::dns=trace,debug"); let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);