diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index cc890911a..bd08cf78b 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -133,9 +133,8 @@ pub struct ClientState { tcp_dns_client: dns_over_tcp::Client, tcp_dns_server: dns_over_tcp::Server, - /// Tracks the socket on which we received a TCP DNS query by the ID of the recursive DNS query we issued. - tcp_dns_sockets_by_upstream_and_query_id: - HashMap<(SocketAddr, u16), dns_over_tcp::SocketHandle>, + /// Tracks the TCP stream (i.e. socket-pair) on which we received a TCP DNS query by the ID of the recursive DNS query we issued. + tcp_dns_streams_by_upstream_and_query_id: HashMap<(SocketAddr, u16), (SocketAddr, SocketAddr)>, /// Stores the gateways we recently connected to. /// @@ -240,7 +239,7 @@ impl ClientState { buffered_dns_queries: Default::default(), tcp_dns_client: dns_over_tcp::Client::new(now, seed), tcp_dns_server: dns_over_tcp::Server::new(now), - tcp_dns_sockets_by_upstream_and_query_id: Default::default(), + tcp_dns_streams_by_upstream_and_query_id: Default::default(), pending_flows: Default::default(), dns_resource_nat_by_gateway: BTreeMap::new(), } @@ -583,7 +582,7 @@ impl ClientState { "Failed to queue UDP DNS response: {}" ); } - (dns::Transport::Tcp { source }, result) => { + (dns::Transport::Tcp { local, remote }, result) => { let message = result .inspect(|_| { tracing::trace!("Received recursive TCP DNS response"); @@ -595,7 +594,7 @@ impl ClientState { }); unwrap_or_warn!( - self.tcp_dns_server.send_message(source, message), + self.tcp_dns_server.send_message(local, remote, message), "Failed to send TCP DNS response: {}" ); } @@ -1134,9 +1133,9 @@ impl ClientState { if let Some(query_result) = self.tcp_dns_client.poll_query_result() { let server = query_result.server; let qid = query_result.query.header().id(); - let known_sockets = &mut self.tcp_dns_sockets_by_upstream_and_query_id; + let known_sockets = &mut self.tcp_dns_streams_by_upstream_and_query_id; - let Some(source) = known_sockets.remove(&(server, qid)) else { + let Some((local, remote)) = known_sockets.remove(&(server, qid)) else { tracing::warn!(?known_sockets, %server, %qid, "Failed to find TCP socket handle for query result"); continue; @@ -1148,7 +1147,7 @@ impl ClientState { message: query_result .result .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{e:#}"))), - transport: dns::Transport::Tcp { source }, + transport: dns::Transport::Tcp { local, remote }, }); continue; } @@ -1280,7 +1279,8 @@ impl ClientState { self.update_dns_resource_nat(now, iter::empty()); unwrap_or_debug!( - self.tcp_dns_server.send_message(query.socket, response), + self.tcp_dns_server + .send_message(query.local, query.remote, response), "Failed to send TCP DNS response: {}" ); } @@ -1295,7 +1295,8 @@ impl ClientState { self.buffered_dns_queries .push_back(dns::RecursiveQuery::via_tcp( - query.socket, + query.local, + query.remote, server, query.message, )); @@ -1337,8 +1338,11 @@ impl ClientState { ); unwrap_or_debug!( - self.tcp_dns_server - .send_message(query.socket, dns::servfail(query.message.for_slice_ref())), + self.tcp_dns_server.send_message( + query.local, + query.remote, + dns::servfail(query.message.for_slice_ref()) + ), "Failed to send TCP DNS response: {}" ); return; @@ -1346,8 +1350,8 @@ impl ClientState { }; let existing = self - .tcp_dns_sockets_by_upstream_and_query_id - .insert((server, query_id), query.socket); + .tcp_dns_streams_by_upstream_and_query_id + .insert((server, query_id), (query.local, query.remote)); debug_assert!(existing.is_none(), "Query IDs should be unique"); } diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 13d9b4a1c..a52454619 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -1,7 +1,6 @@ use crate::client::IpProvider; use anyhow::{Context, Result}; use connlib_model::{DomainName, ResourceId}; -use dns_over_tcp::SocketHandle; use domain::rdata::AllRecordData; use domain::{ base::{ @@ -74,14 +73,15 @@ impl RecursiveQuery { } pub(crate) fn via_tcp( - source: SocketHandle, + local: SocketAddr, + remote: SocketAddr, server: SocketAddr, message: Message>, ) -> Self { Self { server, message, - transport: Transport::Tcp { source }, + transport: Transport::Tcp { local, remote }, } } } @@ -93,7 +93,8 @@ pub(crate) enum Transport { source: SocketAddr, }, Tcp { - source: SocketHandle, + local: SocketAddr, + remote: SocketAddr, }, } diff --git a/rust/connlib/tunnel/src/tests/dns_server_resource.rs b/rust/connlib/tunnel/src/tests/dns_server_resource.rs index 6308fffbf..fc13973d2 100644 --- a/rust/connlib/tunnel/src/tests/dns_server_resource.rs +++ b/rust/connlib/tunnel/src/tests/dns_server_resource.rs @@ -39,7 +39,9 @@ impl TcpDnsServerResource { while let Some(query) = self.server.poll_queries() { let response = handle_dns_query(query.message.for_slice(), global_dns_records); - self.server.send_message(query.socket, response).unwrap(); + self.server + .send_message(query.local, query.remote, response) + .unwrap(); } } diff --git a/rust/dns-over-tcp/src/client.rs b/rust/dns-over-tcp/src/client.rs index 4c10a4f9c..bdc3c9fc4 100644 --- a/rust/dns-over-tcp/src/client.rs +++ b/rust/dns-over-tcp/src/client.rs @@ -195,6 +195,8 @@ impl Client { } for (remote, handle) in self.sockets_by_remote.iter_mut() { + let _guard = tracing::trace_span!("socket", %handle).entered(); + let socket = self.sockets.get_mut::(*handle); let server = *remote; @@ -399,7 +401,7 @@ fn into_failed_results( fn try_recv_response<'b>(socket: &'b mut tcp::Socket) -> Result>> { if !socket.can_recv() { - tracing::trace!("Not yet ready to receive next message"); + tracing::trace!(state = %socket.state(), "Not yet ready to receive next message"); return Ok(None); } diff --git a/rust/dns-over-tcp/src/lib.rs b/rust/dns-over-tcp/src/lib.rs index ba275c7d8..7e4872141 100644 --- a/rust/dns-over-tcp/src/lib.rs +++ b/rust/dns-over-tcp/src/lib.rs @@ -6,7 +6,7 @@ mod stub_device; mod time; pub use client::{Client, QueryResult}; -pub use server::{Query, Server, SocketHandle}; +pub use server::{Query, Server}; fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> { /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. diff --git a/rust/dns-over-tcp/src/server.rs b/rust/dns-over-tcp/src/server.rs index 862c32e4b..0ae3b13c4 100644 --- a/rust/dns-over-tcp/src/server.rs +++ b/rust/dns-over-tcp/src/server.rs @@ -12,7 +12,7 @@ use anyhow::{Context as _, Result}; use domain::{base::Message, dep::octseq::OctetsInto as _}; use ip_packet::IpPacket; use smoltcp::{ - iface::{Interface, PollResult, SocketSet}, + iface::{Interface, PollResult, SocketHandle, SocketSet}, socket::tcp, wire::IpEndpoint, }; @@ -25,7 +25,11 @@ pub struct Server { interface: Interface, sockets: SocketSet<'static>, - listen_endpoints: HashMap, + listen_endpoints: HashMap, + + /// Tracks the [`SocketHandle`] on which we need to send a reply for a given query by the local socket address, remote socket address and query ID. + pending_sockets_by_local_remote_and_query_id: + HashMap<(SocketAddr, SocketAddr, u16), SocketHandle>, received_queries: VecDeque, @@ -33,18 +37,12 @@ pub struct Server { last_now: Instant, } -/// Opaque handle to a TCP socket. -/// -/// This purposely does not implement [`Clone`] or [`Copy`] to make them single-use. -#[derive(Debug, PartialEq, Eq, Hash)] -#[must_use = "An active `SocketHandle` means a TCP socket is waiting for a reply somewhere"] -pub struct SocketHandle(smoltcp::iface::SocketHandle); - pub struct Query { pub message: Message>, - pub socket: SocketHandle, - /// The address of the socket that received the query. + /// The local address of the socket that received the query. pub local: SocketAddr, + /// The remote address of the client that sent the query. + pub remote: SocketAddr, } impl Server { @@ -57,6 +55,7 @@ impl Server { interface, sockets: SocketSet::new(Vec::default()), listen_endpoints: Default::default(), + pending_sockets_by_local_remote_and_query_id: Default::default(), received_queries: Default::default(), created_at: now, last_now: now, @@ -123,7 +122,7 @@ impl Server { }; let dst = SocketAddr::new(packet.destination(), tcp.destination_port()); - let is_listening = self.listen_endpoints.values().any(|listen| listen == &dst); + let is_listening = self.listen_endpoints.values().any(|s| s == &dst); if !is_listening && tracing::enabled!(tracing::Level::TRACE) { let listen_endpoints = BTreeSet::from_iter(self.listen_endpoints.values().copied()); @@ -144,12 +143,22 @@ impl Server { self.device.receive(packet); } - /// Send a message on the socket associated with the handle. + /// Send a query response from the given source to the provided destination socket. /// - /// This fails if the socket is not writeable. + /// This fails if the socket is not writeable or if we don't have a pending query for this client. /// On any error, the TCP connection is automatically reset. - pub fn send_message(&mut self, socket: SocketHandle, message: Message>) -> Result<()> { - let socket = self.sockets.get_mut::(socket.0); + pub fn send_message( + &mut self, + src: SocketAddr, + dst: SocketAddr, + message: Message>, + ) -> Result<()> { + let handle = self + .pending_sockets_by_local_remote_and_query_id + .remove(&(src, dst, message.header().id())) + .context("No pending query found for message")?; + + let socket = self.sockets.get_mut::(handle); write_tcp_dns_response(socket, message.for_slice_ref()) .inspect_err(|_| socket.abort()) // Abort socket on error. @@ -175,15 +184,24 @@ impl Server { } for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() { - let listen = self.listen_endpoints.get(&handle).copied().unwrap(); + let local = self.listen_endpoints.get(&handle).copied().unwrap(); - while let Some(result) = try_recv_query(socket, listen).transpose() { + let _guard = tracing::trace_span!("socket", %handle).entered(); + + while let Some(result) = try_recv_query(socket, local).transpose() { match result { - Ok(message) => { + Ok((message, remote)) => { + let qid = message.header().id(); + + tracing::trace!(%local, %remote, %qid, "Received DNS query"); + + self.pending_sockets_by_local_remote_and_query_id + .insert((local, remote, qid), handle); + self.received_queries.push_back(Query { - message: message.octets_into(), - socket: SocketHandle(handle), - local: listen, + message, + local, + remote, }); } Err(e) => { @@ -215,10 +233,10 @@ impl Server { } } -fn try_recv_query<'b>( - socket: &'b mut tcp::Socket, +fn try_recv_query( + socket: &mut tcp::Socket, listen: SocketAddr, -) -> Result>> { +) -> Result>, SocketAddr)>> { // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. { @@ -245,8 +263,7 @@ fn try_recv_query<'b>( // Ensure we can recv, send and have space to send. if !socket.can_recv() || !socket.can_send() || socket.send_queue() > 0 { tracing::trace!( - can_recv = %socket.can_recv(), - can_send = %socket.can_send(), + state = %socket.state(), send_queue = %socket.send_queue(), "Not yet ready to receive next message" ); @@ -260,7 +277,16 @@ fn try_recv_query<'b>( anyhow::ensure!(!message.header().qr(), "DNS message is a response!"); - Ok(Some(message)) + let message = message.octets_into(); + + let remote = socket + .remote_endpoint() + .context("Unknown remote endpoint despite having just received a message")?; + + Ok(Some(( + message, + SocketAddr::new(remote.addr.into(), remote.port), + ))) } fn write_tcp_dns_response(socket: &mut tcp::Socket, response: Message<&[u8]>) -> Result<()> { diff --git a/rust/dns-over-tcp/tests/client_and_server.rs b/rust/dns-over-tcp/tests/client_and_server.rs index e4c5fe2ef..7ff038fe2 100644 --- a/rust/dns-over-tcp/tests/client_and_server.rs +++ b/rust/dns-over-tcp/tests/client_and_server.rs @@ -72,7 +72,9 @@ fn progress( .unwrap() .into_message(); - dns_server.send_message(query.socket, response).unwrap(); + dns_server + .send_message(query.local, query.remote, response) + .unwrap(); continue; } diff --git a/rust/dns-over-tcp/tests/smoke_server.rs b/rust/dns-over-tcp/tests/smoke_server.rs index 28f654d10..0c296d5db 100644 --- a/rust/dns-over-tcp/tests/smoke_server.rs +++ b/rust/dns-over-tcp/tests/smoke_server.rs @@ -113,7 +113,7 @@ impl Eventloop { .into_message(); self.dns_server - .send_message(query.socket, response) + .send_message(query.local, query.remote, response) .unwrap(); continue; }