refactor(connlib): don't manually build DNS responses (#6193)

Resolves: #5540.
This commit is contained in:
Thomas Eizinger
2024-08-07 05:27:27 +01:00
committed by GitHub
parent 622fa63535
commit a81f5128e5

View File

@@ -12,14 +12,13 @@ use hickory_resolver::proto::error::{ProtoError, ProtoErrorKind};
use hickory_resolver::proto::op::MessageType;
use hickory_resolver::proto::rr::RecordType;
use ip_packet::udp::UdpPacket;
use ip_packet::IpPacket;
use ip_packet::Packet as _;
use ip_packet::{udp::MutableUdpPacket, IpPacket, MutableIpPacket, MutablePacket, PacketSize};
use itertools::Itertools;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
const DNS_TTL: u32 = 1;
const UDP_HEADER_SIZE: usize = 8;
const REVERSE_DNS_ADDRESS_END: &str = "arpa";
const REVERSE_DNS_ADDRESS_V4: &str = "in-addr";
const REVERSE_DNS_ADDRESS_V6: &str = "ip6";
@@ -225,9 +224,16 @@ impl StubResolver {
if let Some(records) = self.known_hosts.get_records(qtype, &domain) {
let response = build_dns_with_answer(message, domain, records)?;
return Some(ResolveStrategy::LocalResponse(build_response(
packet, response,
)));
let packet = ip_packet::make::udp_packet(
packet.destination(),
packet.source(),
datagram.get_destination(),
datagram.get_source(),
response,
)
.into_immutable();
return Some(ResolveStrategy::LocalResponse(packet));
}
let maybe_resource = self.match_resource(&domain);
@@ -259,10 +265,16 @@ impl StubResolver {
};
let response = build_dns_with_answer(message, domain, resource_records)?;
let packet = ip_packet::make::udp_packet(
packet.destination(),
packet.source(),
datagram.get_destination(),
datagram.get_source(),
response,
)
.into_immutable();
Some(ResolveStrategy::LocalResponse(build_response(
packet, response,
)))
Some(ResolveStrategy::LocalResponse(packet))
}
}
@@ -313,6 +325,7 @@ pub(crate) fn build_response_from_resolve_result(
original_pkt: IpPacket<'_>,
response: hickory_resolver::error::ResolveResult<Lookup>,
) -> Result<IpPacket, hickory_resolver::error::ResolveError> {
let datagram = original_pkt.unwrap_as_udp();
let mut message = original_pkt.unwrap_as_dns();
message.set_message_type(MessageType::Response);
@@ -340,60 +353,26 @@ pub(crate) fn build_response_from_resolve_result(
}
};
let packet = build_response(original_pkt, response.to_vec()?);
let packet = ip_packet::make::udp_packet(
original_pkt.destination(),
original_pkt.source(),
datagram.get_destination(),
datagram.get_source(),
response.to_vec()?,
)
.into_immutable();
Ok(packet)
}
/// Constructs an IP packet responding to an IP packet containing a DNS query
fn build_response(original_pkt: IpPacket<'_>, mut dns_answer: Vec<u8>) -> IpPacket<'static> {
let response_len = dns_answer.len();
let original_dgm = original_pkt.unwrap_as_udp();
let hdr_len = original_pkt.packet_size() - original_dgm.payload().len();
let mut res_buf = Vec::with_capacity(hdr_len + response_len + 20);
// TODO: this is some weirdness due to how MutableIpPacket is implemented
// we need an extra 20 bytes padding.
res_buf.extend_from_slice(&[0; 20]);
res_buf.extend_from_slice(&original_pkt.packet()[..hdr_len]);
res_buf.append(&mut dns_answer);
let mut pkt = MutableIpPacket::new(&mut res_buf).unwrap();
let dgm_len = UDP_HEADER_SIZE + response_len;
match &mut pkt {
MutableIpPacket::Ipv4(p) => p.set_total_length((hdr_len + response_len) as u16),
MutableIpPacket::Ipv6(p) => p.set_payload_length(dgm_len as u16),
}
pkt.swap_src_dst();
let mut dgm = MutableUdpPacket::new(pkt.payload_mut()).unwrap();
dgm.set_length(dgm_len as u16);
dgm.set_source(original_dgm.get_destination());
dgm.set_destination(original_dgm.get_source());
let mut pkt = MutableIpPacket::new(&mut res_buf).unwrap();
let udp_checksum = pkt
.to_immutable()
.udp_checksum(&pkt.to_immutable().unwrap_as_udp());
pkt.unwrap_as_udp().set_checksum(udp_checksum);
pkt.set_ipv4_checksum();
// TODO: more of this weirdness
res_buf.drain(0..20);
IpPacket::owned(res_buf).unwrap()
}
fn build_dns_with_answer(
message: &Message<[u8]>,
qname: DomainName,
records: Vec<AllRecordData<Vec<u8>, DomainName>>,
) -> Option<Vec<u8>> {
let msg_buf = Vec::with_capacity(message.as_slice().len() * 2);
let msg_builder = MessageBuilder::from_target(msg_buf).expect(
"Developer error: we should be always be able to create a MessageBuilder from a Vec",
);
let mut answer_builder = msg_builder.start_answer(message, Rcode::NOERROR).ok()?;
let mut answer_builder = MessageBuilder::new_vec()
.start_answer(message, Rcode::NOERROR)
.ok()?;
answer_builder.header_mut().set_ra(true);
for record in records {