From 274cc8655739999ea2ff0778dba7d29ca9acd577 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 12 Oct 2024 09:04:45 +1100 Subject: [PATCH] chore(connlib): add sans-IO DNS-over-TCP client (#7007) This brings us one step closer to completing #6140. In Firezone, users can define custom upstream DNS servers that take priority over system-defined DNS servers. The IPs of these servers could also be resources, meaning the DNS queries must be sent through the WireGuard tunnel to the gateway. For UDP DNS queries, that is easy because each query is only a single packet. For TCP DNS queries, we need to have a dedicated TCP-capable DNS server that parses all incoming queries. If they are required to be forwarded to the gateway, we then need a TCP-capable DNS client that can send them to the actual upstream DNS server. This PR implements such a DNS client. The design is tailored for what we need in `connlib`: We maintain a permanent TCP connection to each upstream DNS server and send queries to them. Most likely, users will only have a handful of DNS servers defined. TCP requires a three-way handshake before any application data can be sent; maintaining a connection should therefore greatly improve DNS resolution latency. DNS resolvers are encouraged to keep TCP connections open but may close them if they run out of resources. We only re-connect once we have more queries to send in order to not spam the resolver with connections. Resolves: #7000. 
--------- Signed-off-by: Thomas Eizinger --- rust/Cargo.lock | 2 + rust/dns-over-tcp/Cargo.toml | 2 + rust/dns-over-tcp/src/client.rs | 417 ++++++++++++++++++ rust/dns-over-tcp/src/codec.rs | 141 ++++++ rust/dns-over-tcp/src/interface.rs | 44 ++ rust/dns-over-tcp/src/lib.rs | 17 + rust/dns-over-tcp/src/server.rs | 181 ++------ rust/dns-over-tcp/tests/client_and_server.rs | 89 ++++ .../tests/{smoke.rs => smoke_server.rs} | 5 +- 9 files changed, 747 insertions(+), 151 deletions(-) create mode 100644 rust/dns-over-tcp/src/client.rs create mode 100644 rust/dns-over-tcp/src/codec.rs create mode 100644 rust/dns-over-tcp/src/interface.rs create mode 100644 rust/dns-over-tcp/tests/client_and_server.rs rename rust/dns-over-tcp/tests/{smoke.rs => smoke_server.rs} (96%) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c83246876..03667464d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1912,9 +1912,11 @@ dependencies = [ "domain", "firezone-bin-shared", "firezone-logging", + "futures", "ip-packet", "ip_network", "itertools 0.13.0", + "rand 0.8.5", "smoltcp", "tokio", "tracing", diff --git a/rust/dns-over-tcp/Cargo.toml b/rust/dns-over-tcp/Cargo.toml index a59bf0d4a..6052dea2e 100644 --- a/rust/dns-over-tcp/Cargo.toml +++ b/rust/dns-over-tcp/Cargo.toml @@ -9,12 +9,14 @@ anyhow = "1.0" domain = { workspace = true } ip-packet = { workspace = true } itertools = "0.13" +rand = "0.8" smoltcp = { version = "0.11", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] } tracing = { workspace = true } [dev-dependencies] firezone-bin-shared = { workspace = true } firezone-logging = { workspace = true } +futures = "0.3" ip_network = { version = "0.4", default-features = false } tokio = { workspace = true, features = ["process", "rt", "macros"] } tun = { workspace = true } diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs new file mode 100644 index 000000000..589f28ffa --- /dev/null +++ 
b/rust/dns-over-tcp/src/client.rs @@ -0,0 +1,417 @@ +use std::{ + collections::{BTreeSet, HashMap, HashSet, VecDeque}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + ops::RangeInclusive, + time::Instant, +}; + +use crate::{codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice}; +use anyhow::{anyhow, bail, Context as _, Result}; +use domain::{base::Message, dep::octseq::OctetsInto}; +use ip_packet::IpPacket; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use smoltcp::{ + iface::{Interface, SocketSet}, + socket::tcp::{self, Socket}, +}; + +/// A sans-io DNS-over-TCP client. +/// +/// The client maintains a single TCP connection for each configured resolver. +/// If the TCP connection fails for some reason, we try to establish a new one. +/// +/// One of the design goals of this client is to always provide a result for a query. +/// If the TCP connection fails, we report all currently pending queries to that resolver as failed. +/// +/// There are however currently no timeouts. +/// If the upstream resolver refuses to answer, we don't fail the query. +pub struct Client { + device: InMemoryDevice, + interface: Interface, + source_ips: Option<(Ipv4Addr, Ipv6Addr)>, + /// The port range we are allowed to use on our local interface for outgoing connections. + port_range: RangeInclusive, + + sockets: SocketSet<'static>, + sockets_by_remote: HashMap, + local_ports_by_socket: HashMap, + /// Queries we should send to a DNS resolver. + pending_queries_by_remote: HashMap>>>, + /// Queries we have sent to a DNS resolver and are waiting for a reply. 
+ sent_queries_by_remote: HashMap>>>, + + query_results: VecDeque, + + rng: StdRng, +} + +#[derive(Debug)] +pub struct QueryResult { + pub query: Message>, + pub server: SocketAddr, + pub result: Result>>, +} + +impl Client { + pub fn new(now: Instant, port_range: RangeInclusive, seed: [u8; 32]) -> Self { + assert!(!port_range.contains(&0), "0 port must not be possible"); + + let mut device = InMemoryDevice::default(); + let interface = create_interface(&mut device, now); + + Self { + device, + interface, + sockets: SocketSet::new(Vec::default()), + source_ips: None, + port_range, + sent_queries_by_remote: Default::default(), + query_results: Default::default(), + rng: StdRng::from_seed(seed), + sockets_by_remote: Default::default(), + local_ports_by_socket: Default::default(), + pending_queries_by_remote: Default::default(), + } + } + + /// Sets the IPv4 and IPv6 source ips to use for outgoing packets. + pub fn set_source_interface(&mut self, v4: Ipv4Addr, v6: Ipv6Addr) { + self.source_ips = Some((v4, v6)); + } + + /// Connect to the specified DNS resolvers. + /// + /// All currently pending queries will be reported as failed. + pub fn connect_to_resolvers(&mut self, resolvers: BTreeSet) -> Result<()> { + let (ipv4_source, ipv6_source) = self.source_ips.context("Missing source IPs")?; + + // First, clear all local state. + self.sockets = SocketSet::new(vec![]); + self.sockets_by_remote.clear(); + self.local_ports_by_socket.clear(); + + self.query_results + .extend( + self.pending_queries_by_remote + .drain() + .flat_map(|(server, queries)| { + into_failed_results(server, queries, || anyhow!("Aborted")) + }), + ); + self.query_results + .extend( + self.sent_queries_by_remote + .drain() + .flat_map(|(server, queries)| { + into_failed_results(server, queries.into_values(), || anyhow!("Aborted")) + }), + ); + + // Second, try to create all new sockets. 
+ let new_sockets = std::iter::zip(self.sample_unique_ports(resolvers.len())?, resolvers).map(|(port, server)| { + let local_endpoint = match server { + SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), port), + SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), port), + }; + + let mut socket = create_tcp_socket(); + + socket + .connect(self.interface.context(), server, local_endpoint) + .context("Failed to connect socket")?; + + tracing::info!(local = %local_endpoint, remote = %server, "Connecting to DNS resolver"); + + Ok((server, local_endpoint, socket)) + }) + .collect::>>()?; + + // Third, if everything was successful, change the local state. + for (server, local_endpoint, socket) in new_sockets { + let handle = self.sockets.add(socket); + + self.sockets_by_remote.insert(server, handle); + self.local_ports_by_socket + .insert(handle, local_endpoint.port()); + } + + Ok(()) + } + + /// Send the given DNS query to the target server. + /// + /// This only queues the message. You need to call [`Client::handle_timeout`] to actually send them. + pub fn send_query(&mut self, server: SocketAddr, message: Message>) -> Result<()> { + anyhow::ensure!(!message.header().qr(), "Message is a DNS response!"); + anyhow::ensure!( + self.sockets_by_remote.contains_key(&server), + "Unknown DNS resolver" + ); + + self.pending_queries_by_remote + .entry(server) + .or_default() + .push_back(message); + + Ok(()) + } + + /// Checks whether this client can handle the given packet. + /// + /// Only TCP packets originating from one of the connected DNS resolvers are accepted. + pub fn accepts(&self, packet: &IpPacket) -> bool { + let Some(tcp) = packet.as_tcp() else { + tracing::trace!(?packet, "Not a TCP packet"); + + return false; + }; + + let Some((ipv4_source, ipv6_source)) = self.source_ips else { + tracing::trace!("No source interface"); + + return false; + }; + + // If the packet doesn't match our source interface, we don't want it. 
+ match packet.destination() { + IpAddr::V4(v4) if v4 != ipv4_source => return false, + IpAddr::V6(v6) if v6 != ipv6_source => return false, + _ => {} + } + + let remote = SocketAddr::new(packet.source(), tcp.source_port()); + let has_socket = self.sockets_by_remote.contains_key(&remote); + + if !has_socket && tracing::enabled!(tracing::Level::TRACE) { + let open_sockets = BTreeSet::from_iter(self.sockets_by_remote.keys().copied()); + + tracing::trace!(%remote, ?open_sockets, "No open socket for remote"); + } + + has_socket + } + + /// Handle the [`IpPacket`]. + /// + /// This function only inserts the packet into a buffer. + /// To actually process the packets in the buffer, [`Client::handle_timeout`] must be called. + pub fn handle_inbound(&mut self, packet: IpPacket) { + debug_assert!(self.accepts(&packet)); + + self.device.receive(packet); + } + + /// Returns [`IpPacket`]s that should be sent. + pub fn poll_outbound(&mut self) -> Option { + self.device.next_send() + } + + /// Returns the next [`QueryResult`]. + pub fn poll_query_result(&mut self) -> Option { + self.query_results.pop_front() + } + + /// Inform the client that time advanced. + /// + /// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible. + pub fn handle_timeout(&mut self, now: Instant) { + let Some((ipv4_source, ipv6_source)) = self.source_ips else { + return; + }; + + let changed = self.interface.poll( + smoltcp::time::Instant::from(now), + &mut self.device, + &mut self.sockets, + ); + + if !changed && self.pending_queries_by_remote.is_empty() { + return; + } + + for (remote, handle) in self.sockets_by_remote.iter_mut() { + let socket = self.sockets.get_mut::(*handle); + let server = *remote; + + // First, attempt to send all pending queries on this socket. 
+ send_pending_queries( + socket, + server, + &mut self.pending_queries_by_remote, + &mut self.sent_queries_by_remote, + &mut self.query_results, + ); + + // Second, attempt to receive responses. + recv_responses( + socket, + server, + &mut self.pending_queries_by_remote, + &mut self.sent_queries_by_remote, + &mut self.query_results, + ); + + let has_pending_dns_queries = !self + .pending_queries_by_remote + .entry(server) + .or_default() + .is_empty(); + + // Third, if the socket got closed, reconnect it. + if matches!(socket.state(), tcp::State::Closed) && has_pending_dns_queries { + let local_port = self + .local_ports_by_socket + .get(handle) + .expect("must always have a port for each socket"); + let local_endpoint = match server { + SocketAddr::V4(_) => SocketAddr::new(ipv4_source.into(), *local_port), + SocketAddr::V6(_) => SocketAddr::new(ipv6_source.into(), *local_port), + }; + + tracing::info!(local = %local_endpoint, remote = %server, "Re-connecting to DNS resolver"); + + socket + .connect(self.interface.context(), server, local_endpoint) + .expect( + "re-connecting a closed socket with the same parameters should always work", + ); + } + } + } + + fn sample_unique_ports(&mut self, num_ports: usize) -> Result> { + let mut ports = HashSet::with_capacity(num_ports); + + if num_ports > self.port_range.len() { + bail!( + "Port range only provides {} values but we need {num_ports}", + self.port_range.len() + ) + } + + while ports.len() < num_ports { + ports.insert(self.rng.gen_range(self.port_range.clone())); + } + + Ok(ports.into_iter()) + } +} + +fn send_pending_queries( + socket: &mut Socket, + server: SocketAddr, + pending_queries_by_remote: &mut HashMap>>>, + sent_queries_by_remote: &mut HashMap>>>, + query_results: &mut VecDeque, +) { + let pending_queries = pending_queries_by_remote.entry(server).or_default(); + let sent_queries = sent_queries_by_remote.entry(server).or_default(); + + loop { + if !socket.can_send() { + break; + } + + let Some(query) = 
pending_queries.pop_front() else { + break; + }; + + match codec::try_send(socket, query.for_slice_ref()).context("Failed to send DNS query") { + Ok(()) => { + let replaced = sent_queries.insert(query.header().id(), query).is_some(); + debug_assert!(!replaced, "Query ID is not unique"); + } + Err(e) => { + // We failed to send the query, declare the socket as failed. + socket.abort(); + + query_results.extend(into_failed_results( + server, + pending_queries + .drain(..) + .chain(sent_queries.drain().map(|(_, query)| query)), + || anyhow!("{e:#}"), + )); + query_results.push_back(QueryResult { + query, + server, + result: Err(e), + }); + } + } + } +} + +fn recv_responses( + socket: &mut Socket, + server: SocketAddr, + pending_queries_by_remote: &mut HashMap>>>, + sent_queries_by_remote: &mut HashMap>>>, + query_results: &mut VecDeque, +) { + let Some(result) = try_recv_response(socket) + .context("Failed to receive DNS response") + .transpose() + else { + return; // No messages on this socket, continue. + }; + let pending_queries = pending_queries_by_remote.entry(server).or_default(); + let sent_queries = sent_queries_by_remote.entry(server).or_default(); + + let new_results = result + .and_then(|response| { + let query = sent_queries + .remove(&response.header().id()) + .context("DNS resolver sent response for unknown query")?; + + Ok(vec![QueryResult { + query, + server, + result: Ok(response.octets_into()), + }]) + }) + .unwrap_or_else(|e| { + socket.abort(); + + into_failed_results( + server, + pending_queries + .drain(..) 
+ .chain(sent_queries.drain().map(|(_, query)| query)), + || anyhow!("{e:#}"), + ) + .collect() + }); + + query_results.extend(new_results); +} + +fn into_failed_results( + server: SocketAddr, + iter: impl IntoIterator>>, + make_error: impl Fn() -> anyhow::Error, +) -> impl Iterator { + iter.into_iter().map(move |query| QueryResult { + query, + server, + result: Err(make_error()), + }) +} + +fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result>> { + anyhow::ensure!(socket.is_active(), "Socket is not active"); + + if !socket.can_recv() { + tracing::trace!("Not yet ready to receive next message"); + + return Ok(None); + } + + let Some(message) = codec::try_recv(socket)? else { + return Ok(None); + }; + + anyhow::ensure!(message.header().qr(), "DNS message is a query!"); + + Ok(Some(message)) +} diff --git a/rust/dns-over-tcp/src/codec.rs b/rust/dns-over-tcp/src/codec.rs new file mode 100644 index 000000000..a4a32e25f --- /dev/null +++ b/rust/dns-over-tcp/src/codec.rs @@ -0,0 +1,141 @@ +//! Implements sending and receiving of DNS messages over TCP. +//! +//! TCP's stream-oriented nature requires us to know how long the encoded DNS message is before we can read it. +//! For this purpose, DNS messages over TCP are prefixed using a big-endian encoded u16. +//! +//! Source: . 
+ +use anyhow::{Context as _, Result}; +use domain::{ + base::{iana::Rcode, Message, ParsedName, Rtype}, + rdata::AllRecordData, +}; +use itertools::Itertools as _; +use smoltcp::socket::tcp; + +pub fn try_send(socket: &mut tcp::Socket, message: Message<&[u8]>) -> Result<()> { + let response = message.as_slice(); + + let dns_message_length = (response.len() as u16).to_be_bytes(); + + let written = socket + .send_slice(&dns_message_length) + .context("Failed to write TCP DNS length header")?; + + anyhow::ensure!( + written == 2, + "Not enough space in write buffer for TCP DNS length header" + ); + + let written = socket + .send_slice(response) + .context("Failed to write DNS message")?; + + anyhow::ensure!( + written == response.len(), + "Not enough space in write buffer for DNS message" + ); + + if tracing::event_enabled!(target: "wire::dns::tcp::send", tracing::Level::TRACE) { + if let Some(ParsedMessage { + qid, + qname, + qtype, + response, + rcode, + records, + }) = parse(message) + { + if response { + let records = records.into_iter().join(" | "); + tracing::trace!(target: "wire::dns::tcp::send", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + } else { + tracing::trace!(target: "wire::dns::tcp::send", %qid, "{:5} {qname}", qtype.to_string()); + } + } + } + + Ok(()) +} + +pub fn try_recv<'b>(socket: &'b mut tcp::Socket) -> Result>> { + let maybe_message = socket + .recv(|r| { + // DNS over TCP has a 2-byte length prefix at the start, see . + let Some((header, message)) = r.split_first_chunk::<2>() else { + return (0, None); + }; + let dns_message_length = u16::from_be_bytes(*header) as usize; + if message.len() < dns_message_length { + return (0, None); // Don't consume any bytes unless we can read the full message at once. + } + + (2 + dns_message_length, Some(Message::from_octets(message))) + }) + .context("Failed to recv TCP data")? 
+ .transpose() + .context("Failed to parse DNS message")?; + + if tracing::event_enabled!(target: "wire::dns::tcp::recv", tracing::Level::TRACE) { + if let Some(ParsedMessage { + qid, + qname, + qtype, + rcode, + response, + records, + }) = maybe_message.and_then(parse) + { + if response { + let records = records.into_iter().join(" | "); + tracing::trace!(target: "wire::dns::tcp::recv", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + } else { + tracing::trace!(target: "wire::dns::tcp::recv", %qid, "{:5} {qname}", qtype.to_string()); + } + } + } + + Ok(maybe_message) +} + +fn parse(message: Message<&[u8]>) -> Option> { + let question = message.sole_question().ok()?; + let answers = message.answer().ok()?; + + let qtype = question.qtype(); + let qname = question.into_qname(); + let qid = message.header().id(); + let response = message.header().qr(); + let rcode = message.header().rcode(); + let records = answers + .into_iter() + .filter_map(|r| { + let data = r + .ok()? + .into_any_record::>() + .ok()? + .data() + .clone(); + + Some(data) + }) + .collect(); + + Some(ParsedMessage { + qid, + qname, + rcode, + qtype, + response, + records, + }) +} + +struct ParsedMessage<'a> { + qid: u16, + qname: ParsedName<&'a [u8]>, + qtype: Rtype, + rcode: Rcode, + response: bool, + records: Vec>>, +} diff --git a/rust/dns-over-tcp/src/interface.rs b/rust/dns-over-tcp/src/interface.rs new file mode 100644 index 000000000..f7dc63239 --- /dev/null +++ b/rust/dns-over-tcp/src/interface.rs @@ -0,0 +1,44 @@ +use std::time::Instant; + +use smoltcp::{ + iface::{Config, Interface, Route}, + wire::{HardwareAddress, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr}, +}; + +use crate::stub_device::InMemoryDevice; + +const IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1); +const IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1); + +/// Creates a smoltcp [`Interface`]. +/// +/// smoltcp's abstractions allow to directly plug it in a TUN device. 
+/// As a result, it has all the features you'd expect from a network interface: +/// - Setting IP addresses +/// - Defining routes +/// +/// In our implementation, we don't want to use any of that. +/// Our device is entirely backed by in-memory buffers and we selectively feed IP packets to it. +/// Therefore, we configure it to: +/// - Accept any packet +/// - Define dummy IPs (localhost for IPv4 and IPv6) +/// - Define catch-all routes (0.0.0.0/0) that route all traffic to the interface +pub fn create_interface(device: &mut InMemoryDevice, now: Instant) -> Interface { + let mut interface = Interface::new(Config::new(HardwareAddress::Ip), device, now.into()); + // Accept packets with any destination IP, not just our interface. + interface.set_any_ip(true); + + // Set our interface IPs. These are just dummies and don't show up anywhere! + interface.update_ip_addrs(|ips| { + ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into()).unwrap(); + ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into()).unwrap(); + }); + + // Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface. + interface.routes_mut().update(|routes| { + routes.push(Route::new_ipv4_gateway(IP4_ADDR)).unwrap(); + routes.push(Route::new_ipv6_gateway(IP6_ADDR)).unwrap(); + }); + + interface +} diff --git a/rust/dns-over-tcp/src/lib.rs b/rust/dns-over-tcp/src/lib.rs index 94a38418e..af49bc6b8 100644 --- a/rust/dns-over-tcp/src/lib.rs +++ b/rust/dns-over-tcp/src/lib.rs @@ -1,4 +1,21 @@ +mod client; +mod codec; +mod interface; mod server; mod stub_device; +pub use client::{Client, QueryResult}; pub use server::{Server, SocketHandle}; + +fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> { + /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. + /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. + /// Being able to buffer at least one of them means we can handle the extreme case. 
+ /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. + const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; + + smoltcp::socket::tcp::Socket::new( + smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + ) +} diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs index 6d212ab05..dceb2ddb9 100644 --- a/rust/dns-over-tcp/src/server.rs +++ b/rust/dns-over-tcp/src/server.rs @@ -4,16 +4,14 @@ use std::{ time::Instant, }; -use crate::stub_device::InMemoryDevice; +use crate::{codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice}; use anyhow::{Context as _, Result}; -use domain::{base::Message, dep::octseq::OctetsInto as _, rdata::AllRecordData}; +use domain::{base::Message, dep::octseq::OctetsInto as _}; use ip_packet::IpPacket; -use itertools::Itertools as _; use smoltcp::{ - iface::{Config, Interface, Route, SocketSet}, + iface::{Interface, SocketSet}, socket::tcp, - storage::RingBuffer, - wire::{HardwareAddress, IpEndpoint, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr}, + wire::IpEndpoint, }; /// A sans-IO implementation of DNS-over-TCP server. @@ -43,34 +41,10 @@ pub struct Query { pub local: SocketAddr, } -const SERVER_IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1); -const SERVER_IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1); - impl Server { pub fn new(now: Instant) -> Self { let mut device = InMemoryDevice::default(); - - let mut interface = - Interface::new(Config::new(HardwareAddress::Ip), &mut device, now.into()); - // Accept packets with any destination IP, not just our interface. - interface.set_any_ip(true); - - // Set our interface IPs. These are just dummies and don't show up anywhere! 
- interface.update_ip_addrs(|ips| { - ips.push(Ipv4Cidr::new(SERVER_IP4_ADDR, 32).into()).unwrap(); - ips.push(Ipv6Cidr::new(SERVER_IP6_ADDR, 128).into()) - .unwrap(); - }); - - // Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface. - interface.routes_mut().update(|routes| { - routes - .push(Route::new_ipv4_gateway(SERVER_IP4_ADDR)) - .unwrap(); - routes - .push(Route::new_ipv6_gateway(SERVER_IP6_ADDR)) - .unwrap(); - }); + let interface = create_interface(&mut device, now); Self { device, @@ -98,7 +72,12 @@ impl Server { for listen_endpoint in addresses { for _ in 0..NUM_CONCURRENT_CLIENTS { - let handle = sockets.add(create_tcp_socket(listen_endpoint)); + let mut socket = create_tcp_socket(); + socket + .listen(listen_endpoint) + .expect("A fresh socket should always be able to listen"); + + let handle = sockets.add(socket); listen_endpoints.insert(handle, listen_endpoint); } @@ -149,40 +128,9 @@ impl Server { pub fn send_message(&mut self, socket: SocketHandle, message: Message>) -> Result<()> { let socket = self.sockets.get_mut::(socket.0); - let result = write_tcp_dns_response(socket, message.for_slice_ref()); - - if result.is_err() { - socket.abort(); - } - - result.context("Failed to write DNS response")?; // Bail before logging in case we failed to write the response. - - if tracing::event_enabled!(target: "wire::dns::res", tracing::Level::TRACE) { - if let Ok(question) = message.sole_question() { - let qtype = question.qtype(); - let qname = question.into_qname(); - let rcode = message.header().rcode(); - - if let Ok(record_section) = message.answer() { - let records = record_section - .into_iter() - .filter_map(|r| { - let data = r - .ok()? - .into_any_record::>() - .ok()? 
- .data() - .clone(); - - Some(data) - }) - .join(" | "); - let qid = message.header().id(); - - tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); - } - } - } + write_tcp_dns_response(socket, message.for_slice_ref()) + .inspect_err(|_| socket.abort()) // Abort socket on error. + .context("Failed to write DNS response")?; Ok(()) } @@ -211,28 +159,20 @@ impl Server { for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() { let listen = self.listen_endpoints.get(&handle).copied().unwrap(); - match try_recv_query(socket, listen) { - Ok(Some(message)) => { - if tracing::event_enabled!(target: "wire::dns::qry", tracing::Level::TRACE) { - if let Ok(question) = message.sole_question() { - let qtype = question.qtype(); - let qname = question.into_qname(); - let qid = message.header().id(); - - tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string()); - } + while let Some(result) = try_recv_query(socket, listen).transpose() { + match result { + Ok(message) => { + self.received_queries.push_back(Query { + message: message.octets_into(), + socket: SocketHandle(handle), + local: listen, + }); + } + Err(e) => { + tracing::debug!("Error on receiving DNS query: {e}"); + socket.abort(); + break; } - - self.received_queries.push_back(Query { - message, - socket: SocketHandle(handle), - local: listen, - }); - } - Ok(None) => {} - Err(e) => { - tracing::debug!("Error on receiving DNS query: {e}"); - socket.abort(); } } } @@ -249,28 +189,10 @@ impl Server { } } -fn create_tcp_socket(listen_endpoint: SocketAddr) -> tcp::Socket<'static> { - /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. - /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. - /// Being able to buffer at least one of them means we can handle the extreme case. 
- /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. - const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; - - let mut socket = tcp::Socket::new( - RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), - RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), - ); - socket - .listen(listen_endpoint) - .expect("A fresh socket should always be able to listen"); - - socket -} - -fn try_recv_query( - socket: &mut tcp::Socket, +fn try_recv_query<'b>( + socket: &'b mut tcp::Socket, listen: SocketAddr, -) -> Result>>> { +) -> Result>> { // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. { @@ -306,56 +228,19 @@ fn try_recv_query( return Ok(None); } - // Read a DNS message from the socket. - let Some(message) = socket - .recv(|r| { - // DNS over TCP has a 2-byte length prefix at the start, see . - let Some((header, message)) = r.split_first_chunk::<2>() else { - return (0, None); - }; - let dns_message_length = u16::from_be_bytes(*header) as usize; - if message.len() < dns_message_length { - return (0, None); // Don't consume any bytes unless we can read the full message at once. - } - - (2 + dns_message_length, Some(Message::from_octets(message))) - }) - .context("Failed to recv TCP data")? - .transpose() - .context("Failed to parse DNS message")? - else { + let Some(message) = codec::try_recv(socket)? 
else { return Ok(None); }; anyhow::ensure!(!message.header().qr(), "DNS message is a response!"); - Ok(Some(message.octets_into())) + Ok(Some(message)) } fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> { anyhow::ensure!(response.header().qr(), "DNS message is a query!"); - let response = response.as_slice(); - - let dns_message_length = (response.len() as u16).to_be_bytes(); - - let written = socket - .send_slice(&dns_message_length) - .context("Failed to write TCP DNS length header")?; - - anyhow::ensure!( - written == 2, - "Not enough space in write buffer for TCP DNS length header" - ); - - let written = socket - .send_slice(response) - .context("Failed to write DNS message")?; - - anyhow::ensure!( - written == response.len(), - "Not enough space in write buffer for DNS response" - ); + codec::try_send(socket, response)?; Ok(()) } diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs new file mode 100644 index 000000000..17a42e5de --- /dev/null +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -0,0 +1,89 @@ +use std::{ + collections::BTreeSet, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, + time::Instant, +}; + +use dns_over_tcp::QueryResult; +use domain::base::{iana::Rcode, Message, MessageBuilder, Name, Rtype}; + +#[test] +fn smoke() { + let _guard = firezone_logging::test( + "netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,smoltcp=trace,debug", + ); + + let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); + let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + + let resolver_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53)); + + let mut dns_client = dns_over_tcp::Client::new(Instant::now(), 49152..=65535, [0u8; 32]); + dns_client.set_source_interface(ipv4, ipv6); + dns_client + .connect_to_resolvers(BTreeSet::from_iter([resolver_addr])) + .unwrap(); + + let mut dns_server = 
dns_over_tcp::Server::new(Instant::now()); + dns_server.set_listen_addresses::<1>(vec![resolver_addr]); + + for id in 0..5 { + dns_client + .send_query(resolver_addr, a_query("example.com", id)) + .unwrap(); + } + + let results = std::iter::from_fn(|| progress(&mut dns_client, &mut dns_server)) + .take(5) + .collect::>(); + + for query_result in results { + let result = query_result.result.unwrap(); + + println!("{result:?}") + } +} + +fn a_query(domain: &str, id: u16) -> Message> { + let mut builder = MessageBuilder::new_vec().question(); + builder.header_mut().set_id(id); + builder + .push((Name::vec_from_str(domain).unwrap(), Rtype::A)) + .unwrap(); + + builder.into_message() +} + +fn progress( + dns_client: &mut dns_over_tcp::Client, + dns_server: &mut dns_over_tcp::Server, +) -> Option { + loop { + if let Some(packet) = dns_client.poll_outbound() { + dns_server.handle_inbound(packet); + continue; + } + + if let Some(packet) = dns_server.poll_outbound() { + dns_client.handle_inbound(packet); + continue; + } + + if let Some(query) = dns_server.poll_queries() { + let response = MessageBuilder::new_vec() + .start_answer(&query.message, Rcode::NXDOMAIN) + .unwrap() + .into_message(); + + dns_server.send_message(query.socket, response).unwrap(); + continue; + } + + if let Some(query) = dns_client.poll_query_result() { + return Some(query); + } + + dns_client.handle_timeout(Instant::now()); + dns_server.handle_timeout(Instant::now()); + } +} diff --git a/rust/dns-over-tcp/tests/smoke.rs b/rust/dns-over-tcp/tests/smoke_server.rs similarity index 96% rename from rust/dns-over-tcp/tests/smoke.rs rename to rust/dns-over-tcp/tests/smoke_server.rs index 7a745654e..5003653ad 100644 --- a/rust/dns-over-tcp/tests/smoke.rs +++ b/rust/dns-over-tcp/tests/smoke_server.rs @@ -16,10 +16,9 @@ use tun::Tun; const CLIENT_CONCURRENCY: usize = 3; #[tokio::test] -#[ignore = "Requires root"] +#[ignore = "Requires root & IP forwarding"] async fn smoke() { - let _guard = - 
firezone_logging::test("netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,debug"); + let _guard = firezone_logging::test("netlink_proto=off,wire::dns=trace,debug"); let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]);