From cd2dea78463e65cd7ff79bfbe4631bc67afff5ae Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 11 Oct 2024 08:05:12 +1100 Subject: [PATCH] chore: add sans-IO DNS-over-TCP implementation (#6997) This splits out the actual DNS server from #6944 into a separate crate. At present, it only contains a DNS server. Later, we will likely add a DNS client to it as well because the proptests and connlib itself will need a user-space DNS TCP client. The implementation uses `smoltcp` but that is entirely encapsulated. The `Server` struct exposes only a high-level interface for - feeding inbound packets as well as retrieving outbound packets - retrieving parsed DNS queries and sending DNS responses Related: #6140. --- rust/Cargo.lock | 89 +++++++ rust/Cargo.toml | 2 + rust/dns-over-tcp/Cargo.toml | 20 ++ rust/dns-over-tcp/src/lib.rs | 4 + rust/dns-over-tcp/src/server.rs | 361 +++++++++++++++++++++++++++ rust/dns-over-tcp/src/stub_device.rs | 86 +++++++ rust/dns-over-tcp/tests/smoke.rs | 133 ++++++++++ rust/ip-packet/src/lib.rs | 2 +- 8 files changed, 696 insertions(+), 1 deletion(-) create mode 100644 rust/dns-over-tcp/Cargo.toml create mode 100644 rust/dns-over-tcp/src/lib.rs create mode 100644 rust/dns-over-tcp/src/server.rs create mode 100644 rust/dns-over-tcp/src/stub_device.rs create mode 100644 rust/dns-over-tcp/tests/smoke.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8a6e9c28d..c83246876 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1685,6 +1685,38 @@ dependencies = [ "uuid", ] +[[package]] +name = "defmt" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a99dd22262668b887121d4672af5a64b238f026099f1a2a1b322066c9ecfe9e0" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9f309eff1f79b3ebdf252954d90ae440599c26c2c553fe87a2d17195f2dcb" +dependencies = [ + "defmt-parser", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "defmt-parser" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff4a5fefe330e8d7f31b16a318f9ce81000d8e35e69b93eae154d16d2278f70f" +dependencies = [ + "thiserror", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1872,6 +1904,23 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "dns-over-tcp" +version = "0.1.0" +dependencies = [ + "anyhow", + "domain", + "firezone-bin-shared", + "firezone-logging", + "ip-packet", + "ip_network", + "itertools 0.13.0", + "smoltcp", + "tokio", + "tracing", + "tun", +] + [[package]] name = "domain" version = "0.10.1" @@ -3236,6 +3285,15 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -3267,6 +3325,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.3.3" @@ -4327,6 +4395,12 @@ dependencies = [ "libc", ] +[[package]] +name = "managed" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" + [[package]] name = "markup5ever" version = "0.11.0" @@ -6947,6 +7021,21 @@ dependencies = [ "serde", ] +[[package]] +name = "smoltcp" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a1a996951e50b5971a2c8c0fa05a381480d70a933064245c4a223ddc87ccc97" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "cfg-if", + "defmt", + "heapless", + "log", + "managed", +] + [[package]] name = "snownet" version = "0.1.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index a8d205519..5f2c48194 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -7,6 +7,7 @@ members = [ "connlib/model", "connlib/snownet", "connlib/tunnel", + "dns-over-tcp", "gateway", "gui-client/src-common", "gui-client/src-tauri", @@ -51,6 +52,7 @@ firezone-bin-shared = { path = "bin-shared" } firezone-logging = { path = "logging" } firezone-telemetry = { path = "telemetry" } snownet = { path = "connlib/snownet" } +dns-over-tcp = { path = "dns-over-tcp" } firezone-relay = { path = "relay" } connlib-model = { path = "connlib/model" } firezone-tunnel = { path = "connlib/tunnel" } diff --git a/rust/dns-over-tcp/Cargo.toml b/rust/dns-over-tcp/Cargo.toml new file mode 100644 index 000000000..a59bf0d4a --- /dev/null +++ b/rust/dns-over-tcp/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "dns-over-tcp" +version = "0.1.0" +edition = "2021" +description = "User-space implementation of DNS over TCP." + +[dependencies] +anyhow = "1.0" +domain = { workspace = true } +ip-packet = { workspace = true } +itertools = "0.13" +smoltcp = { version = "0.11", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] } +tracing = { workspace = true } + +[dev-dependencies] +firezone-bin-shared = { workspace = true } +firezone-logging = { workspace = true } +ip_network = { version = "0.4", default-features = false } +tokio = { workspace = true, features = ["process", "rt", "macros"] } +tun = { workspace = true } diff --git a/rust/dns-over-tcp/src/lib.rs b/rust/dns-over-tcp/src/lib.rs new file mode 100644 index 000000000..94a38418e --- /dev/null +++ b/rust/dns-over-tcp/src/lib.rs @@ -0,0 +1,4 @@ +mod server; +mod stub_device; + +pub use server::{Server, SocketHandle}; diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs new file mode 100644 index 000000000..6d212ab05 --- /dev/null +++ b/rust/dns-over-tcp/src/server.rs @@ -0,0 +1,361 @@ +use std::{ + collections::{BTreeSet, HashMap, VecDeque}, + net::SocketAddr, + time::Instant, +}; + +use crate::stub_device::InMemoryDevice; +use anyhow::{Context as _, Result}; +use domain::{base::Message, dep::octseq::OctetsInto as _, rdata::AllRecordData}; +use ip_packet::IpPacket; +use itertools::Itertools as _; +use smoltcp::{ + iface::{Config, Interface, Route, SocketSet}, + socket::tcp, + storage::RingBuffer, + wire::{HardwareAddress, IpEndpoint, Ipv4Address, Ipv4Cidr, Ipv6Address, Ipv6Cidr}, +}; + +/// A sans-IO implementation of DNS-over-TCP server. +/// +/// Listens on a specified number of socket addresses, parses incoming DNS queries and allows writing back responses. +pub struct Server { + device: InMemoryDevice, + interface: Interface, + + sockets: SocketSet<'static>, + listen_endpoints: HashMap, + + received_queries: VecDeque, +} + +/// Opaque handle to a TCP socket. +/// +/// This purposely does not implement [`Clone`] or [`Copy`] to make them single-use. +#[derive(Debug, PartialEq, Eq, Hash)] +#[must_use = "An active `SocketHandle` means a TCP socket is waiting for a reply somewhere"] +pub struct SocketHandle(smoltcp::iface::SocketHandle); + +pub struct Query { + pub message: Message>, + pub socket: SocketHandle, + /// The address of the socket that received the query. + pub local: SocketAddr, +} + +const SERVER_IP4_ADDR: Ipv4Address = Ipv4Address::new(127, 0, 0, 1); +const SERVER_IP6_ADDR: Ipv6Address = Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1); + +impl Server { + pub fn new(now: Instant) -> Self { + let mut device = InMemoryDevice::default(); + + let mut interface = + Interface::new(Config::new(HardwareAddress::Ip), &mut device, now.into()); + // Accept packets with any destination IP, not just our interface. + interface.set_any_ip(true); + + // Set our interface IPs. These are just dummies and don't show up anywhere! + interface.update_ip_addrs(|ips| { + ips.push(Ipv4Cidr::new(SERVER_IP4_ADDR, 32).into()).unwrap(); + ips.push(Ipv6Cidr::new(SERVER_IP6_ADDR, 128).into()) + .unwrap(); + }); + + // Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface. + interface.routes_mut().update(|routes| { + routes + .push(Route::new_ipv4_gateway(SERVER_IP4_ADDR)) + .unwrap(); + routes + .push(Route::new_ipv6_gateway(SERVER_IP6_ADDR)) + .unwrap(); + }); + + Self { + device, + interface, + sockets: SocketSet::new(Vec::default()), + listen_endpoints: Default::default(), + received_queries: Default::default(), + } + } + + /// Listen on the specified addresses. + /// + /// This resets all sockets we were previously listening on. + /// This function is generic over a `NUM_CONCURRENT_CLIENTS` constant. + /// The constant configures, how many concurrent clients you would like to be able to serve per listen address. + pub fn set_listen_addresses( + &mut self, + addresses: Vec, + ) { + assert!(NUM_CONCURRENT_CLIENTS > 0); + + let mut sockets = + SocketSet::new(Vec::with_capacity(addresses.len() * NUM_CONCURRENT_CLIENTS)); + let mut listen_endpoints = HashMap::with_capacity(addresses.len()); + + for listen_endpoint in addresses { + for _ in 0..NUM_CONCURRENT_CLIENTS { + let handle = sockets.add(create_tcp_socket(listen_endpoint)); + listen_endpoints.insert(handle, listen_endpoint); + } + + tracing::info!(%listen_endpoint, concurrency = %NUM_CONCURRENT_CLIENTS, "Created listening TCP socket"); + } + + self.sockets = sockets; + self.listen_endpoints = listen_endpoints; + self.received_queries.clear(); + } + + /// Checks whether this server can handle the given packet. + /// + /// Only TCP packets targeted at one of sockets configured with [`Server::set_listen_addresses`] are accepted. + pub fn accepts(&self, packet: &IpPacket) -> bool { + let Some(tcp) = packet.as_tcp() else { + tracing::trace!(?packet, "Not a TCP packet"); + + return false; + }; + + let dst = SocketAddr::new(packet.destination(), tcp.destination_port()); + let is_listening = self.listen_endpoints.values().any(|listen| listen == &dst); + + if !is_listening && tracing::enabled!(tracing::Level::TRACE) { + let listen_endpoints = BTreeSet::from_iter(self.listen_endpoints.values().copied()); + + tracing::trace!(%dst, ?listen_endpoints, "No listening socket for destination"); + } + + is_listening + } + + /// Handle the [`IpPacket`]. + /// + /// This function only inserts the packet into a buffer. + /// To actually process the packets in the buffer, [`Server::handle_timeout`] must be called. + pub fn handle_inbound(&mut self, packet: IpPacket) { + debug_assert!(self.accepts(&packet)); + + self.device.receive(packet); + } + + /// Send a message on the socket associated with the handle. + /// + /// This fails if the socket is not writeable. + /// On any error, the TCP connection is automatically reset. + pub fn send_message(&mut self, socket: SocketHandle, message: Message>) -> Result<()> { + let socket = self.sockets.get_mut::(socket.0); + + let result = write_tcp_dns_response(socket, message.for_slice_ref()); + + if result.is_err() { + socket.abort(); + } + + result.context("Failed to write DNS response")?; // Bail before logging in case we failed to write the response. + + if tracing::event_enabled!(target: "wire::dns::res", tracing::Level::TRACE) { + if let Ok(question) = message.sole_question() { + let qtype = question.qtype(); + let qname = question.into_qname(); + let rcode = message.header().rcode(); + + if let Ok(record_section) = message.answer() { + let records = record_section + .into_iter() + .filter_map(|r| { + let data = r + .ok()? + .into_any_record::>() + .ok()? + .data() + .clone(); + + Some(data) + }) + .join(" | "); + let qid = message.header().id(); + + tracing::trace!(target: "wire::dns::res", %qid, %rcode, "{:5} {qname} => [{records}]", qtype.to_string()); + } + } + } + + Ok(()) + } + + /// Resets the socket associated with the given handle. + /// + /// Use this if you encountered an error while processing a previously emitted DNS query. + pub fn reset(&mut self, handle: SocketHandle) { + self.sockets.get_mut::(handle.0).abort(); + } + + /// Inform the server that time advanced. + /// + /// Typical for a sans-IO design, `handle_timeout` will work through all local buffers and process them as much as possible. + pub fn handle_timeout(&mut self, now: Instant) { + let changed = self.interface.poll( + smoltcp::time::Instant::from(now), + &mut self.device, + &mut self.sockets, + ); + + if !changed { + return; + } + + for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() { + let listen = self.listen_endpoints.get(&handle).copied().unwrap(); + + match try_recv_query(socket, listen) { + Ok(Some(message)) => { + if tracing::event_enabled!(target: "wire::dns::qry", tracing::Level::TRACE) { + if let Ok(question) = message.sole_question() { + let qtype = question.qtype(); + let qname = question.into_qname(); + let qid = message.header().id(); + + tracing::trace!(target: "wire::dns::qry", %qid, "{:5} {qname}", qtype.to_string()); + } + } + + self.received_queries.push_back(Query { + message, + socket: SocketHandle(handle), + local: listen, + }); + } + Ok(None) => {} + Err(e) => { + tracing::debug!("Error on receiving DNS query: {e}"); + socket.abort(); + } + } + } + } + + /// Returns [`IpPacket`]s that should be sent. + pub fn poll_outbound(&mut self) -> Option { + self.device.next_send() + } + + /// Returns queries received from a DNS client. + pub fn poll_queries(&mut self) -> Option { + self.received_queries.pop_front() + } +} + +fn create_tcp_socket(listen_endpoint: SocketAddr) -> tcp::Socket<'static> { + /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. + /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. + /// Being able to buffer at least one of them means we can handle the extreme case. + /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. + const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; + + let mut socket = tcp::Socket::new( + RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + ); + socket + .listen(listen_endpoint) + .expect("A fresh socket should always be able to listen"); + + socket +} + +fn try_recv_query( + socket: &mut tcp::Socket, + listen: SocketAddr, +) -> Result>>> { + // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. + // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. + { + use smoltcp::socket::tcp::State::*; + + if matches!(socket.state(), Closed | TimeWait | CloseWait) { + tracing::debug!(state = %socket.state(), "Resetting socket to listen state"); + + socket.abort(); + socket + .listen(listen) + .expect("Can always listen after `abort()`"); + } + } + + // We configure `smoltcp` with "any-ip", meaning packets to technically any IP will be routed here to us. + if let Some(local) = socket.local_endpoint() { + anyhow::ensure!( + local == IpEndpoint::from(listen), + "Bad destination socket: {local}" + ) + } + + // Ensure we can recv, send and have space to send. + if !socket.can_recv() || !socket.can_send() || socket.send_queue() > 0 { + tracing::trace!( + can_recv = %socket.can_recv(), + can_send = %socket.can_send(), + send_queue = %socket.send_queue(), + "Not yet ready to receive next message" + ); + + return Ok(None); + } + + // Read a DNS message from the socket. + let Some(message) = socket + .recv(|r| { + // DNS over TCP has a 2-byte length prefix at the start, see . + let Some((header, message)) = r.split_first_chunk::<2>() else { + return (0, None); + }; + let dns_message_length = u16::from_be_bytes(*header) as usize; + if message.len() < dns_message_length { + return (0, None); // Don't consume any bytes unless we can read the full message at once. + } + + (2 + dns_message_length, Some(Message::from_octets(message))) + }) + .context("Failed to recv TCP data")? + .transpose() + .context("Failed to parse DNS message")? + else { + return Ok(None); + }; + + anyhow::ensure!(!message.header().qr(), "DNS message is a response!"); + + Ok(Some(message.octets_into())) +} + +fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> { + anyhow::ensure!(response.header().qr(), "DNS message is a query!"); + + let response = response.as_slice(); + + let dns_message_length = (response.len() as u16).to_be_bytes(); + + let written = socket + .send_slice(&dns_message_length) + .context("Failed to write TCP DNS length header")?; + + anyhow::ensure!( + written == 2, + "Not enough space in write buffer for TCP DNS length header" + ); + + let written = socket + .send_slice(response) + .context("Failed to write DNS message")?; + + anyhow::ensure!( + written == response.len(), + "Not enough space in write buffer for DNS response" + ); + + Ok(()) +} diff --git a/rust/dns-over-tcp/src/stub_device.rs b/rust/dns-over-tcp/src/stub_device.rs new file mode 100644 index 000000000..91ea4d249 --- /dev/null +++ b/rust/dns-over-tcp/src/stub_device.rs @@ -0,0 +1,86 @@ +use std::collections::VecDeque; + +use ip_packet::{IpPacket, IpPacketBuf}; + +/// A in-memory device for [`smoltcp`] that is entirely backed by buffers. +#[derive(Debug, Default)] +pub(crate) struct InMemoryDevice { + inbound_packets: VecDeque, + outbound_packets: VecDeque, +} + +impl InMemoryDevice { + pub(crate) fn receive(&mut self, packet: IpPacket) { + self.inbound_packets.push_back(packet); + } + + pub(crate) fn next_send(&mut self) -> Option { + self.outbound_packets.pop_front() + } +} + +impl smoltcp::phy::Device for InMemoryDevice { + type RxToken<'a> = SmolRxToken; + type TxToken<'a> = SmolTxToken<'a>; + + fn receive( + &mut self, + _timestamp: smoltcp::time::Instant, + ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let rx_token = SmolRxToken { + packet: self.inbound_packets.pop_front()?, + }; + let tx_token = SmolTxToken { + outbound_packets: &mut self.outbound_packets, + }; + + Some((rx_token, tx_token)) + } + + fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { + Some(SmolTxToken { + outbound_packets: &mut self.outbound_packets, + }) + } + + fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities { + let mut caps = smoltcp::phy::DeviceCapabilities::default(); + caps.medium = smoltcp::phy::Medium::Ip; + caps.max_transmission_unit = ip_packet::PACKET_SIZE; + + caps + } +} + +pub(crate) struct SmolTxToken<'a> { + outbound_packets: &'a mut VecDeque, +} + +impl<'a> smoltcp::phy::TxToken for SmolTxToken<'a> { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut ip_packet_buf = IpPacketBuf::new(); + let result = f(ip_packet_buf.buf()); + + let mut ip_packet = IpPacket::new(ip_packet_buf, len).unwrap(); + ip_packet.update_checksum(); + self.outbound_packets.push_back(ip_packet); + + result + } +} + +pub(crate) struct SmolRxToken { + packet: IpPacket, +} + +impl smoltcp::phy::RxToken for SmolRxToken { + fn consume(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(self.packet.packet_mut()) + } +} diff --git a/rust/dns-over-tcp/tests/smoke.rs b/rust/dns-over-tcp/tests/smoke.rs new file mode 100644 index 000000000..7a745654e --- /dev/null +++ b/rust/dns-over-tcp/tests/smoke.rs @@ -0,0 +1,133 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, + process::Stdio, + task::{ready, Context, Poll}, + time::Instant, +}; + +use anyhow::{Context as _, Result}; +use domain::base::{iana::Rcode, MessageBuilder}; +use firezone_bin_shared::TunDeviceManager; +use ip_network::Ipv4Network; +use ip_packet::{IpPacket, IpPacketBuf}; +use tokio::task::JoinSet; +use tun::Tun; + +const CLIENT_CONCURRENCY: usize = 3; + +#[tokio::test] +#[ignore = "Requires root"] +async fn smoke() { + let _guard = + firezone_logging::test("netlink_proto=off,wire::dns::res=trace,dns_over_tcp=trace,debug"); + + let ipv4 = Ipv4Addr::from([100, 90, 215, 97]); + let ipv6 = Ipv6Addr::from([0xfd00, 0x2021, 0x1111, 0x0, 0x0, 0x0, 0x0016, 0x588f]); + + let mut device_manager = TunDeviceManager::new(1280).unwrap(); + let tun = device_manager.make_tun().unwrap(); + device_manager.set_ips(ipv4, ipv6).await.unwrap(); + device_manager + .set_routes( + vec![Ipv4Network::new(Ipv4Addr::new(100, 100, 111, 0), 24).unwrap()], + vec![], + ) + .await + .unwrap(); + + let listen_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(100, 100, 111, 1), 53)); + let mut dns_server = dns_over_tcp::Server::new(Instant::now()); + dns_server.set_listen_addresses::(vec![listen_addr]); + let mut eventloop = Eventloop::new(Box::new(tun), dns_server); + + tokio::spawn(std::future::poll_fn(move |cx| eventloop.poll(cx))); + + // Running the queries multiple times ensures we can reuse sockets. + run_queries(listen_addr.ip()).await; + run_queries(listen_addr.ip()).await; +} + +async fn run_queries(dns_server: IpAddr) { + let mut set = JoinSet::new(); + + for _ in 0..CLIENT_CONCURRENCY { + set.spawn(dig(dns_server)); + } + + let exit_codes = set + .join_all() + .await + .into_iter() + .collect::>>() + .unwrap(); + + for status in exit_codes { + assert_eq!(status, 0) + } +} + +async fn dig(dns_server: IpAddr) -> Result { + let exit_status = tokio::process::Command::new("dig") + .args([ + "+tcp", + "+tries=1", + "+keepopen", // Reuse the TCP socket + &format!("@{dns_server}"), + "example.com", + "example.com", // Querying more than one domain ensures a client can reuse a TCP connection + "example.com", + "example.com", + ]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .status() + .await? + .code() + .context("Missing code")?; + + Ok(exit_status) +} + +struct Eventloop { + tun: Box, + dns_server: dns_over_tcp::Server, +} + +impl Eventloop { + fn new(tun: Box, dns_server: dns_over_tcp::Server) -> Self { + Self { tun, dns_server } + } + + fn poll(&mut self, cx: &mut Context) -> Poll<()> { + loop { + if let Some(packet) = self.dns_server.poll_outbound() { + match packet { + IpPacket::Ipv4(v4) => self.tun.write4(v4.packet()).unwrap(), + IpPacket::Ipv6(v6) => self.tun.write6(v6.packet()).unwrap(), + }; + continue; + } + + if let Some(query) = self.dns_server.poll_queries() { + let response = MessageBuilder::new_vec() + .start_answer(&query.message, Rcode::NXDOMAIN) + .unwrap() + .into_message(); + + self.dns_server + .send_message(query.socket, response) + .unwrap(); + continue; + } + + let mut packet_buf = IpPacketBuf::default(); + let num_read = ready!(self.tun.poll_read(packet_buf.buf(), cx)).unwrap(); + let packet = IpPacket::new(packet_buf, num_read).unwrap(); + + if self.dns_server.accepts(&packet) { + self.dns_server.handle_inbound(packet); + self.dns_server.handle_timeout(Instant::now()); + } + } + } +} diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 0a08745ee..4822a6476 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -789,7 +789,7 @@ impl IpPacket { } } - fn packet_mut(&mut self) -> &mut [u8] { + pub fn packet_mut(&mut self) -> &mut [u8] { match self { IpPacket::Ipv4(v4) => v4.packet_mut(), IpPacket::Ipv6(v6) => v6.packet_mut(),