diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index 04e407be5..d8cdc8055 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -12,14 +12,13 @@ use hickory_resolver::proto::error::{ProtoError, ProtoErrorKind}; use hickory_resolver::proto::op::MessageType; use hickory_resolver::proto::rr::RecordType; use ip_packet::udp::UdpPacket; +use ip_packet::IpPacket; use ip_packet::Packet as _; -use ip_packet::{udp::MutableUdpPacket, IpPacket, MutableIpPacket, MutablePacket, PacketSize}; use itertools::Itertools; use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; const DNS_TTL: u32 = 1; -const UDP_HEADER_SIZE: usize = 8; const REVERSE_DNS_ADDRESS_END: &str = "arpa"; const REVERSE_DNS_ADDRESS_V4: &str = "in-addr"; const REVERSE_DNS_ADDRESS_V6: &str = "ip6"; @@ -225,9 +224,16 @@ impl StubResolver { if let Some(records) = self.known_hosts.get_records(qtype, &domain) { let response = build_dns_with_answer(message, domain, records)?; - return Some(ResolveStrategy::LocalResponse(build_response( - packet, response, - ))); + let packet = ip_packet::make::udp_packet( + packet.destination(), + packet.source(), + datagram.get_destination(), + datagram.get_source(), + response, + ) + .into_immutable(); + + return Some(ResolveStrategy::LocalResponse(packet)); } let maybe_resource = self.match_resource(&domain); @@ -259,10 +265,16 @@ impl StubResolver { }; let response = build_dns_with_answer(message, domain, resource_records)?; + let packet = ip_packet::make::udp_packet( + packet.destination(), + packet.source(), + datagram.get_destination(), + datagram.get_source(), + response, + ) + .into_immutable(); - Some(ResolveStrategy::LocalResponse(build_response( - packet, response, - ))) + Some(ResolveStrategy::LocalResponse(packet)) } } @@ -313,6 +325,7 @@ pub(crate) fn build_response_from_resolve_result( original_pkt: IpPacket<'_>, response: hickory_resolver::error::ResolveResult, ) -> Result { + let datagram = original_pkt.unwrap_as_udp(); let mut message = original_pkt.unwrap_as_dns(); message.set_message_type(MessageType::Response); @@ -340,60 +353,26 @@ pub(crate) fn build_response_from_resolve_result( } }; - let packet = build_response(original_pkt, response.to_vec()?); + let packet = ip_packet::make::udp_packet( + original_pkt.destination(), + original_pkt.source(), + datagram.get_destination(), + datagram.get_source(), + response.to_vec()?, + ) + .into_immutable(); Ok(packet) } -/// Constructs an IP packet responding to an IP packet containing a DNS query -fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec) -> IpPacket<'static> { - let response_len = dns_answer.len(); - let original_dgm = original_pkt.unwrap_as_udp(); - let hdr_len = original_pkt.packet_size() - original_dgm.payload().len(); - let mut res_buf = Vec::with_capacity(hdr_len + response_len + 20); - - // TODO: this is some weirdness due to how MutableIpPacket is implemented - // we need an extra 20 bytes padding. - res_buf.extend_from_slice(&[0; 20]); - res_buf.extend_from_slice(&original_pkt.packet()[..hdr_len]); - res_buf.append(&mut dns_answer); - - let mut pkt = MutableIpPacket::new(&mut res_buf).unwrap(); - let dgm_len = UDP_HEADER_SIZE + response_len; - match &mut pkt { - MutableIpPacket::Ipv4(p) => p.set_total_length((hdr_len + response_len) as u16), - MutableIpPacket::Ipv6(p) => p.set_payload_length(dgm_len as u16), - } - pkt.swap_src_dst(); - - let mut dgm = MutableUdpPacket::new(pkt.payload_mut()).unwrap(); - dgm.set_length(dgm_len as u16); - dgm.set_source(original_dgm.get_destination()); - dgm.set_destination(original_dgm.get_source()); - - let mut pkt = MutableIpPacket::new(&mut res_buf).unwrap(); - let udp_checksum = pkt - .to_immutable() - .udp_checksum(&pkt.to_immutable().unwrap_as_udp()); - pkt.unwrap_as_udp().set_checksum(udp_checksum); - pkt.set_ipv4_checksum(); - - // TODO: more of this weirdness - res_buf.drain(0..20); - IpPacket::owned(res_buf).unwrap() -} - fn build_dns_with_answer( message: &Message<[u8]>, qname: DomainName, records: Vec, DomainName>>, ) -> Option> { - let msg_buf = Vec::with_capacity(message.as_slice().len() * 2); - let msg_builder = MessageBuilder::from_target(msg_buf).expect( - "Developer error: we should be always be able to create a MessageBuilder from a Vec", - ); - - let mut answer_builder = msg_builder.start_answer(message, Rcode::NOERROR).ok()?; + let mut answer_builder = MessageBuilder::new_vec() + .start_answer(message, Rcode::NOERROR) + .ok()?; answer_builder.header_mut().set_ra(true); for record in records {