From e534207bbdfb2ba6a9f09f8d3c17e95ccc0ead57 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 5 Mar 2025 14:10:59 +1100 Subject: [PATCH] refactor(connlib): remove `SocketHandle` from TCP DNS server API (#8360) At present, the TCP DNS server we use in `connlib` exposes an opaque `SocketHandle` with each received query. This handle refers to the socket that the query was received on. The response needs to be sent back on the same socket because it effectively refers to the TCP stream that was established. We need to track this `SocketHandle` all the way through to our user-space DNS client in `connlib` which actually resolves queries with a DNS server. In order to be able to reuse this DNS client on the Gateway where we receive DNS queries using a user-space socket (and thus don't have such a `SocketHandle`), we need to remove this abstraction from the public API of the TCP DNS server. A TCP stream is effectively identified by the source and destination socket address: A given 4-tuple (source IP, source port, destination IP, destination port) can only ever hold a single TCP connection. As such, returning the local and remote `SocketAddr` with the query is sufficient to uniquely identify the socket. --- rust/connlib/tunnel/src/client.rs | 34 ++++---- rust/connlib/tunnel/src/dns.rs | 9 +- .../tunnel/src/tests/dns_server_resource.rs | 4 +- rust/dns-over-tcp/src/client.rs | 4 +- rust/dns-over-tcp/src/lib.rs | 2 +- rust/dns-over-tcp/src/server.rs | 82 ++++++++++++------- rust/dns-over-tcp/tests/client_and_server.rs | 4 +- rust/dns-over-tcp/tests/smoke_server.rs | 2 +- 8 files changed, 89 insertions(+), 52 deletions(-) diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index cc890911a..bd08cf78b 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -133,9 +133,8 @@ pub struct ClientState { tcp_dns_client: dns_over_tcp::Client, tcp_dns_server: dns_over_tcp::Server, - /// Tracks the socket on which we received a TCP DNS query by the ID of the recursive DNS query we issued. - tcp_dns_sockets_by_upstream_and_query_id: - HashMap<(SocketAddr, u16), dns_over_tcp::SocketHandle>, + /// Tracks the TCP stream (i.e. socket-pair) on which we received a TCP DNS query by the ID of the recursive DNS query we issued. + tcp_dns_streams_by_upstream_and_query_id: HashMap<(SocketAddr, u16), (SocketAddr, SocketAddr)>, /// Stores the gateways we recently connected to. /// @@ -240,7 +239,7 @@ impl ClientState { buffered_dns_queries: Default::default(), tcp_dns_client: dns_over_tcp::Client::new(now, seed), tcp_dns_server: dns_over_tcp::Server::new(now), - tcp_dns_sockets_by_upstream_and_query_id: Default::default(), + tcp_dns_streams_by_upstream_and_query_id: Default::default(), pending_flows: Default::default(), dns_resource_nat_by_gateway: BTreeMap::new(), } @@ -583,7 +582,7 @@ impl ClientState { "Failed to queue UDP DNS response: {}" ); } - (dns::Transport::Tcp { source }, result) => { + (dns::Transport::Tcp { local, remote }, result) => { let message = result .inspect(|_| { tracing::trace!("Received recursive TCP DNS response"); @@ -595,7 +594,7 @@ impl ClientState { }); unwrap_or_warn!( - self.tcp_dns_server.send_message(source, message), + self.tcp_dns_server.send_message(local, remote, message), "Failed to send TCP DNS response: {}" ); } @@ -1134,9 +1133,9 @@ impl ClientState { if let Some(query_result) = self.tcp_dns_client.poll_query_result() { let server = query_result.server; let qid = query_result.query.header().id(); - let known_sockets = &mut self.tcp_dns_sockets_by_upstream_and_query_id; + let known_sockets = &mut self.tcp_dns_streams_by_upstream_and_query_id; - let Some(source) = known_sockets.remove(&(server, qid)) else { + let Some((local, remote)) = known_sockets.remove(&(server, qid)) else { tracing::warn!(?known_sockets, %server, %qid, "Failed to find TCP socket handle for query result"); continue; @@ -1148,7 +1147,7 @@ impl ClientState { message: query_result .result .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{e:#}"))), - transport: dns::Transport::Tcp { source }, + transport: dns::Transport::Tcp { local, remote }, }); continue; } @@ -1280,7 +1279,8 @@ impl ClientState { self.update_dns_resource_nat(now, iter::empty()); unwrap_or_debug!( - self.tcp_dns_server.send_message(query.socket, response), + self.tcp_dns_server + .send_message(query.local, query.remote, response), "Failed to send TCP DNS response: {}" ); } @@ -1295,7 +1295,8 @@ impl ClientState { self.buffered_dns_queries .push_back(dns::RecursiveQuery::via_tcp( - query.socket, + query.local, + query.remote, server, query.message, )); @@ -1337,8 +1338,11 @@ impl ClientState { ); unwrap_or_debug!( - self.tcp_dns_server - .send_message(query.socket, dns::servfail(query.message.for_slice_ref())), + self.tcp_dns_server.send_message( + query.local, + query.remote, + dns::servfail(query.message.for_slice_ref()) + ), "Failed to send TCP DNS response: {}" ); return; @@ -1346,8 +1350,8 @@ impl ClientState { }; let existing = self - .tcp_dns_sockets_by_upstream_and_query_id - .insert((server, query_id), query.socket); + .tcp_dns_streams_by_upstream_and_query_id + .insert((server, query_id), (query.local, query.remote)); debug_assert!(existing.is_none(), "Query IDs should be unique"); } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 13d9b4a1c..a52454619 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,7 +1,6 @@ use crate::client::IpProvider; use anyhow::{Context, Result}; use connlib_model::{DomainName, ResourceId}; -use dns_over_tcp::SocketHandle; use domain::rdata::AllRecordData; use domain::{ base::{ @@ -74,14 +73,15 @@ impl RecursiveQuery { } pub(crate) fn via_tcp( - source: SocketHandle, + local: SocketAddr, + remote: SocketAddr, server: SocketAddr, message: Message>, ) -> Self { Self { server, message, - transport: Transport::Tcp { source }, + transport: Transport::Tcp { local, remote }, } } } @@ -93,7 +93,8 @@ pub(crate) enum Transport { source: SocketAddr, }, Tcp { - source: SocketHandle, + local: SocketAddr, + remote: SocketAddr, }, } diff --git a/rust/connlib/tunnel/src/tests/dns_server_resource.rs b/rust/connlib/tunnel/src/tests/dns_server_resource.rs index 6308fffbf..fc13973d2 100644 --- a/rust/connlib/tunnel/src/tests/dns_server_resource.rs +++ b/rust/connlib/tunnel/src/tests/dns_server_resource.rs @@ -39,7 +39,9 @@ impl TcpDnsServerResource { while let Some(query) = self.server.poll_queries() { let response = handle_dns_query(query.message.for_slice(), global_dns_records); - self.server.send_message(query.socket, response).unwrap(); + self.server + .send_message(query.local, query.remote, response) + .unwrap(); } } diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs index 4c10a4f9c..bdc3c9fc4 100644 --- a/rust/dns-over-tcp/src/client.rs +++ b/rust/dns-over-tcp/src/client.rs @@ -195,6 +195,8 @@ impl Client { } for (remote, handle) in self.sockets_by_remote.iter_mut() { + let _guard = tracing::trace_span!("socket", %handle).entered(); + let socket = self.sockets.get_mut::(*handle); let server = *remote; @@ -399,7 +401,7 @@ fn into_failed_results( fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result>> { if !socket.can_recv() { - tracing::trace!("Not yet ready to receive next message"); + tracing::trace!(state = %socket.state(), "Not yet ready to receive next message"); return Ok(None); } diff --git a/rust/dns-over-tcp/src/lib.rs b/rust/dns-over-tcp/src/lib.rs index ba275c7d8..7e4872141 100644 --- a/rust/dns-over-tcp/src/lib.rs +++ b/rust/dns-over-tcp/src/lib.rs @@ -6,7 +6,7 @@ mod stub_device; mod time; pub use client::{Client, QueryResult}; -pub use server::{Query, Server, SocketHandle}; +pub use server::{Query, Server}; fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> { /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs index 862c32e4b..0ae3b13c4 100644 --- a/rust/dns-over-tcp/src/server.rs +++ b/rust/dns-over-tcp/src/server.rs @@ -12,7 +12,7 @@ use anyhow::{Context as _, Result}; use domain::{base::Message, dep::octseq::OctetsInto as _}; use ip_packet::IpPacket; use smoltcp::{ - iface::{Interface, PollResult, SocketSet}, + iface::{Interface, PollResult, SocketHandle, SocketSet}, socket::tcp, wire::IpEndpoint, }; @@ -25,7 +25,11 @@ pub struct Server { interface: Interface, sockets: SocketSet<'static>, - listen_endpoints: HashMap, + listen_endpoints: HashMap, + + /// Tracks the [`SocketHandle`] on which we need to send a reply for a given query by the local socket address, remote socket address and query ID. + pending_sockets_by_local_remote_and_query_id: + HashMap<(SocketAddr, SocketAddr, u16), SocketHandle>, received_queries: VecDeque, @@ -33,18 +37,12 @@ pub struct Server { last_now: Instant, } -/// Opaque handle to a TCP socket. -/// -/// This purposely does not implement [`Clone`] or [`Copy`] to make them single-use. -#[derive(Debug, PartialEq, Eq, Hash)] -#[must_use = "An active `SocketHandle` means a TCP socket is waiting for a reply somewhere"] -pub struct SocketHandle(smoltcp::iface::SocketHandle); - pub struct Query { pub message: Message>, - pub socket: SocketHandle, - /// The address of the socket that received the query. + /// The local address of the socket that received the query. pub local: SocketAddr, + /// The remote address of the client that sent the query. + pub remote: SocketAddr, } impl Server { @@ -57,6 +55,7 @@ impl Server { interface, sockets: SocketSet::new(Vec::default()), listen_endpoints: Default::default(), + pending_sockets_by_local_remote_and_query_id: Default::default(), received_queries: Default::default(), created_at: now, last_now: now, @@ -123,7 +122,7 @@ impl Server { }; let dst = SocketAddr::new(packet.destination(), tcp.destination_port()); - let is_listening = self.listen_endpoints.values().any(|listen| listen == &dst); + let is_listening = self.listen_endpoints.values().any(|s| s == &dst); if !is_listening && tracing::enabled!(tracing::Level::TRACE) { let listen_endpoints = BTreeSet::from_iter(self.listen_endpoints.values().copied()); @@ -144,12 +143,22 @@ impl Server { self.device.receive(packet); } - /// Send a message on the socket associated with the handle. + /// Send a query response from the given source to the provided destination socket. /// - /// This fails if the socket is not writeable. + /// This fails if the socket is not writeable or if we don't have a pending query for this client. /// On any error, the TCP connection is automatically reset. - pub fn send_message(&mut self, socket: SocketHandle, message: Message>) -> Result<()> { - let socket = self.sockets.get_mut::(socket.0); + pub fn send_message( + &mut self, + src: SocketAddr, + dst: SocketAddr, + message: Message>, + ) -> Result<()> { + let handle = self + .pending_sockets_by_local_remote_and_query_id + .remove(&(src, dst, message.header().id())) + .context("No pending query found for message")?; + + let socket = self.sockets.get_mut::(handle); write_tcp_dns_response(socket, message.for_slice_ref()) .inspect_err(|_| socket.abort()) // Abort socket on error. @@ -175,15 +184,24 @@ impl Server { } for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() { - let listen = self.listen_endpoints.get(&handle).copied().unwrap(); + let local = self.listen_endpoints.get(&handle).copied().unwrap(); - while let Some(result) = try_recv_query(socket, listen).transpose() { + let _guard = tracing::trace_span!("socket", %handle).entered(); + + while let Some(result) = try_recv_query(socket, local).transpose() { match result { - Ok(message) => { + Ok((message, remote)) => { + let qid = message.header().id(); + + tracing::trace!(%local, %remote, %qid, "Received DNS query"); + + self.pending_sockets_by_local_remote_and_query_id + .insert((local, remote, qid), handle); + self.received_queries.push_back(Query { - message: message.octets_into(), - socket: SocketHandle(handle), - local: listen, + message, + local, + remote, }); } Err(e) => { @@ -215,10 +233,10 @@ impl Server { } } -fn try_recv_query<'b>( - socket: &'b mut tcp::Socket, +fn try_recv_query( + socket: &mut tcp::Socket, listen: SocketAddr, -) -> Result>> { +) -> Result>, SocketAddr)>> { // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. { @@ -245,8 +263,7 @@ fn try_recv_query<'b>( // Ensure we can recv, send and have space to send. if !socket.can_recv() || !socket.can_send() || socket.send_queue() > 0 { tracing::trace!( - can_recv = %socket.can_recv(), - can_send = %socket.can_send(), + state = %socket.state(), send_queue = %socket.send_queue(), "Not yet ready to receive next message" ); @@ -260,7 +277,16 @@ fn try_recv_query<'b>( anyhow::ensure!(!message.header().qr(), "DNS message is a response!"); - Ok(Some(message)) + let message = message.octets_into(); + + let remote = socket + .remote_endpoint() + .context("Unknown remote endpoint despite having just received a message")?; + + Ok(Some(( + message, + SocketAddr::new(remote.addr.into(), remote.port), + ))) } fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> { diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs index e4c5fe2ef..7ff038fe2 100644 --- a/rust/dns-over-tcp/tests/client_and_server.rs +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -72,7 +72,9 @@ fn progress( .unwrap() .into_message(); - dns_server.send_message(query.socket, response).unwrap(); + dns_server + .send_message(query.local, query.remote, response) + .unwrap(); continue; } diff --git a/rust/dns-over-tcp/tests/smoke_server.rs b/rust/dns-over-tcp/tests/smoke_server.rs index 28f654d10..0c296d5db 100644 --- a/rust/dns-over-tcp/tests/smoke_server.rs +++ b/rust/dns-over-tcp/tests/smoke_server.rs @@ -113,7 +113,7 @@ impl Eventloop { .into_message(); self.dns_server - .send_message(query.socket, response) + .send_message(query.local, query.remote, response) .unwrap(); continue; }