feat(connlib): transparently forward non-resources DNS queries (#6181)

Currently, `connlib` depends on `hickory-resolver` to perform DNS
queries for non-resources. This is unnecessary. Instead of buffering the
original UDP DNS query, consulting hickory to resolve the name and
mapping the response back, we can simply take the UDP payload and send
it via our protected socket directly to the original upstream DNS
server.

This ensures `connlib` is as transparent as possible for DNS queries for
non-resources. Additionally, it removes a lot of error handling and
other cruft that we currently have to perform because we are using
hickory. For example, hickory will automatically retry a DNS query after
a certain timeout. However, the OS / client talking to `connlib` will
also retry after a certain timeout because it is making DNS queries over
an unreliable transport (UDP). It is thus unnecessary for us to do that
internally.

To correctly test this change, our test-suite needed some refactoring.
Specifically, DNS servers are now modelled as dedicated `Host`s that can
receive (UDP) traffic.

Lastly, we can remove our dependency on `hickory-proto` and
`hickory-resolver` everywhere and only use `domain` for parsing DNS
messages.

Resolves: #6141.
Related: #6033.
Related: #4800. (Impossible to happen with this design)
This commit is contained in:
Thomas Eizinger
2024-08-07 09:54:49 +01:00
committed by GitHub
parent 376900ca4e
commit 128d0eb407
25 changed files with 498 additions and 1000 deletions

102
rust/Cargo.lock generated
View File

@@ -1608,18 +1608,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf"
[[package]]
name = "enum-as-inner"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.72",
]
[[package]]
name = "enumflags2"
version = "0.7.9"
@@ -1999,11 +1987,8 @@ dependencies = [
"domain",
"firezone-relay",
"futures",
"futures-bounded",
"futures-util",
"hex",
"hickory-proto",
"hickory-resolver",
"ip-packet",
"ip_network",
"ip_network_table",
@@ -2025,6 +2010,7 @@ dependencies = [
"tracing-appender",
"tracing-subscriber",
"tun",
"uuid",
]
[[package]]
@@ -2121,9 +2107,9 @@ dependencies = [
[[package]]
name = "futures-bounded"
version = "0.2.3"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1e2774cc104e198ef3d3e1ff4ab40f86fa3245d6cb6a3a46174f21463cee173"
checksum = "91f328e7fb845fc832912fb6a34f40cf6d1888c92f974d1893a54e97b5ff542e"
dependencies = [
"futures-timer",
"futures-util",
@@ -2655,49 +2641,6 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46"
[[package]]
name = "hickory-proto"
version = "0.24.0"
source = "git+https://github.com/hickory-dns/hickory-dns?rev=a3669bd80f3f7b97f0c301c15f1cba6368d97b63#a3669bd80f3f7b97f0c301c15f1cba6368d97b63"
dependencies = [
"async-trait",
"cfg-if",
"data-encoding",
"enum-as-inner",
"futures-channel",
"futures-io",
"futures-util",
"idna",
"ipnet",
"once_cell",
"rand 0.8.5",
"thiserror",
"tinyvec",
"tokio",
"tracing",
"url",
]
[[package]]
name = "hickory-resolver"
version = "0.24.0"
source = "git+https://github.com/hickory-dns/hickory-dns?rev=a3669bd80f3f7b97f0c301c15f1cba6368d97b63#a3669bd80f3f7b97f0c301c15f1cba6368d97b63"
dependencies = [
"cfg-if",
"futures-util",
"hickory-proto",
"ipconfig",
"lru-cache",
"once_cell",
"parking_lot",
"rand 0.8.5",
"resolv-conf",
"smallvec 1.13.2",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "hkdf"
version = "0.12.4"
@@ -2725,17 +2668,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "hostname"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867"
dependencies = [
"libc",
"match_cfg",
"winapi",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -3061,7 +2993,7 @@ dependencies = [
name = "ip-packet"
version = "0.1.0"
dependencies = [
"hickory-proto",
"domain",
"pnet_packet",
"proptest",
"test-strategy",
@@ -3355,12 +3287,6 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd1bc4d24ad230d21fb898d1116b1801d7adfc449d42026475862ab48b11e70e"
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.4.13"
@@ -3404,15 +3330,6 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "lru-cache"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
dependencies = [
"linked-hash-map",
]
[[package]]
name = "mac"
version = "0.1.1"
@@ -3464,12 +3381,6 @@ dependencies = [
"tendril",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
[[package]]
name = "matchers"
version = "0.1.0"
@@ -4438,7 +4349,7 @@ dependencies = [
"base64 0.22.1",
"futures",
"hex",
"hostname 0.4.0",
"hostname",
"libc",
"rand_core 0.6.4",
"secrecy",
@@ -5044,7 +4955,6 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
dependencies = [
"hostname 0.3.1",
"quick-error",
]
@@ -5636,8 +5546,6 @@ dependencies = [
name = "socket-factory"
version = "0.1.0"
dependencies = [
"async-trait",
"hickory-proto",
"quinn-udp",
"socket2",
"tokio",

View File

@@ -20,7 +20,7 @@ tun = { workspace = true }
url = { version = "2.5.2", default-features = false }
[dev-dependencies]
tokio = { workspace = true, features = ["macros"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
[target.'cfg(target_os = "linux")'.dependencies]
libc = "0.2"

View File

@@ -25,7 +25,7 @@ secrecy = { workspace = true }
serde_json = "1"
socket-factory = { workspace = true }
thiserror = "1"
tokio = { workspace = true, features = ["rt"] }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tracing = { workspace = true, features = ["std", "attributes"] }
tracing-appender = "0.2"
tracing-subscriber = { workspace = true }

View File

@@ -22,7 +22,7 @@ secrecy = { workspace = true }
serde_json = "1"
socket-factory = { workspace = true }
swift-bridge = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tracing = { workspace = true }
tracing-appender = "0.2"
tracing-subscriber = "0.3"

View File

@@ -205,12 +205,7 @@ where
fn handle_portal_inbound_message(&mut self, msg: IngressMessages) {
match msg {
IngressMessages::ConfigChanged(config) => {
if let Err(e) = self
.tunnel
.set_new_interface_config(config.interface.clone())
{
tracing::warn!(?config, "Failed to update configuration: {e:?}");
}
self.tunnel.set_new_interface_config(config.interface)
}
IngressMessages::IceCandidates(GatewayIceCandidates {
gateway_id,
@@ -225,14 +220,11 @@ where
resources,
relays,
}) => {
if let Err(e) = self.tunnel.set_new_interface_config(interface) {
tracing::warn!("Failed to set interface on tunnel: {e}");
return;
}
self.tunnel.set_new_interface_config(interface);
self.tunnel.set_resources(resources);
self.tunnel.update_relays(BTreeSet::default(), relays);
tracing::info!("Firezone Started!");
self.tunnel.set_resources(resources);
self.tunnel.update_relays(BTreeSet::default(), relays)
}
IngressMessages::ResourceCreatedOrUpdated(resource) => {
self.tunnel.add_resource(resource);

View File

@@ -12,11 +12,8 @@ chrono = { workspace = true }
connlib-shared = { workspace = true }
domain = { workspace = true }
futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] }
futures-bounded = { workspace = true }
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] }
hex = "0.4.3"
hickory-proto = { workspace = true }
hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
ip-packet = { workspace = true }
ip_network = { version = "0.4", default-features = false }
ip_network_table = { version = "0.2", default-features = false }
@@ -27,17 +24,17 @@ rangemap = "1.5.1"
secrecy = { workspace = true }
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
snownet = { workspace = true }
socket-factory = { workspace = true, features = ["hickory"] }
socket-factory = { workspace = true }
socket2 = { workspace = true }
thiserror = { version = "1.0", default-features = false }
tokio = { workspace = true }
tracing = { workspace = true, features = ["attributes"] }
tun = { workspace = true }
uuid = { version = "1.10", default-features = false, features = ["std", "v4"] }
[dev-dependencies]
derivative = "2.2.0"
firezone-relay = { workspace = true, features = ["proptest"] }
hickory-proto = { workspace = true }
ip-packet = { workspace = true, features = ["proptest"] }
proptest-state-machine = "0.3"
rand = "0.8"

View File

@@ -1,7 +1,6 @@
use crate::dns;
use crate::dns::StubResolver;
use crate::io::DnsQueryError;
use crate::peer_store::PeerStore;
use crate::{dns, dns::DnsQuery};
use anyhow::Context;
use bimap::BiMap;
use connlib_shared::callbacks::Status;
@@ -18,14 +17,14 @@ use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
use ip_packet::{IpPacket, MutableIpPacket, Packet as _};
use itertools::Itertools;
use tracing::Level;
use crate::peer::GatewayOnClient;
use crate::utils::{self, earliest, turn};
use crate::{ClientEvent, ClientTunnel, Tun};
use core::fmt;
use domain::base::Message;
use secrecy::{ExposeSecret as _, Secret};
use snownet::{ClientNode, RelaySocket};
use snownet::{ClientNode, RelaySocket, Transmit};
use std::borrow::Cow;
use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::iter;
@@ -140,29 +139,12 @@ impl ClientTunnel {
pub fn set_new_dns(&mut self, new_dns: Vec<IpAddr>) {
// We store the sentinel dns both in the config and in the system's resolvers
// but when we calculate the dns mapping, those are ignored.
let dns_changed = self.role_state.update_system_resolvers(new_dns);
if !dns_changed {
return;
}
self.io
.set_upstream_dns_servers(self.role_state.dns_mapping());
self.role_state.update_system_resolvers(new_dns);
}
#[tracing::instrument(level = "trace", skip(self))]
pub fn set_new_interface_config(
&mut self,
config: InterfaceConfig,
) -> connlib_shared::Result<()> {
let dns_changed = self.role_state.update_interface_config(config);
if dns_changed {
self.io
.set_upstream_dns_servers(self.role_state.dns_mapping());
}
Ok(())
pub fn set_new_interface_config(&mut self, config: InterfaceConfig) {
self.role_state.update_interface_config(config);
}
pub fn cleanup_connection(&mut self, id: ResourceId) {
@@ -273,15 +255,19 @@ pub struct ClientState {
/// The DNS resolvers configured on the system outside of connlib.
system_resolvers: Vec<IpAddr>,
/// DNS queries that we need to forward to the system resolver.
buffered_dns_queries: VecDeque<DnsQuery<'static>>,
/// Maps from connlib-assigned IP of a DNS server back to the originally configured system DNS resolver.
dns_mapping: BiMap<IpAddr, DnsServer>,
/// DNS queries that had their destination IP mangled because the servers is a CIDR resource.
///
/// The [`Instant`] tracks when the DNS query expires.
mangled_dns_queries: HashMap<u16, Instant>,
/// DNS queries that were forwarded to an upstream server.
///
/// - The [`SocketAddr`] is the original source IP.
/// - The [`Instant`] tracks when the DNS query expires.
///
/// We store an explicit expiry to avoid a memory leak in case of a non-responding DNS server.
forwarded_dns_queries: HashMap<u16, (SocketAddr, Instant)>,
/// Manages internal dns records and emits forwarding event when not internally handled
stub_resolver: StubResolver,
@@ -293,6 +279,7 @@ pub struct ClientState {
buffered_events: VecDeque<ClientEvent>,
buffered_packets: VecDeque<IpPacket<'static>>,
buffered_transmits: VecDeque<Transmit<'static>>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -318,14 +305,15 @@ impl ClientState {
buffered_events: Default::default(),
interface_config: Default::default(),
buffered_packets: Default::default(),
buffered_dns_queries: Default::default(),
node: ClientNode::new(private_key.into(), seed),
system_resolvers: Default::default(),
sites_status: Default::default(),
gateways_site: Default::default(),
mangled_dns_queries: Default::default(),
forwarded_dns_queries: Default::default(),
stub_resolver: StubResolver::new(known_hosts),
disabled_resources: Default::default(),
buffered_transmits: Default::default(),
}
}
@@ -432,7 +420,7 @@ impl ClientState {
packet: MutableIpPacket<'_>,
now: Instant,
) -> Option<snownet::Transmit<'s>> {
let (packet, dst) = match self.handle_dns(packet) {
let (packet, dst) = match self.try_handle_dns_query(packet, now) {
Ok(response) => {
self.buffered_packets.push_back(response?.to_owned());
return None;
@@ -492,6 +480,10 @@ impl ClientState {
now: Instant,
buffer: &'b mut [u8],
) -> Option<IpPacket<'b>> {
if let Some(response) = self.try_handle_forwarded_dns_response(from, packet) {
return Some(response);
};
let (gid, packet) = self.node.decapsulate(
local,
from,
@@ -627,35 +619,42 @@ impl ClientState {
!interface.upstream_dns.is_empty()
}
/// Attempt to handle the given packet as a DNS packet.
/// Attempt to handle the given packet as a DNS query packet.
///
/// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back.
/// Returns `Err` if the packet is not a DNS query.
fn handle_dns<'a>(
fn try_handle_dns_query<'a>(
&mut self,
packet: MutableIpPacket<'a>,
now: Instant,
) -> Result<Option<IpPacket<'a>>, (MutableIpPacket<'a>, IpAddr)> {
match self
.stub_resolver
.handle(&self.dns_mapping, packet.as_immutable())
{
Some(dns::ResolveStrategy::LocalResponse(query)) => Ok(Some(query)),
Some(dns::ResolveStrategy::ForwardQuery(query)) => {
// There's an edge case here, where the resolver's ip has been resolved before as
// a dns resource... we will ignore that weird case for now.
if let Some(upstream_dns) = self.dns_mapping.get_by_left(&query.query.destination())
{
let ip = upstream_dns.ip();
Some(dns::ResolveStrategy::ForwardQuery {
upstream: server,
query_id: id,
payload,
original_src,
}) => {
let ip = server.ip();
// In case the DNS server is a CIDR resource, it needs to go through the tunnel.
if self.is_upstream_set_by_the_portal()
&& self.active_cidr_resources.longest_match(ip).is_some()
{
return Err((packet, ip));
}
// In case the DNS server is a CIDR resource, it needs to go through the tunnel.
if self.is_upstream_set_by_the_portal()
&& self.active_cidr_resources.longest_match(ip).is_some()
{
return Err((packet, ip));
}
self.buffered_dns_queries.push_back(query.into_owned());
self.forwarded_dns_queries
.insert(id, (original_src, now + IDS_EXPIRE));
self.buffered_transmits.push_back(Transmit {
src: None,
dst: server,
payload: Cow::Owned(payload),
});
Ok(None)
}
@@ -666,46 +665,26 @@ impl ClientState {
}
}
#[tracing::instrument(level = "debug", skip_all, fields(name = %query.name, server = %query.query.destination()))] // On debug level, we can log potentially sensitive information such as domain names.
pub(crate) fn on_dns_result(
fn try_handle_forwarded_dns_response<'a>(
&mut self,
query: DnsQuery<'static>,
response: Result<
Result<
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
futures_bounded::Timeout,
>,
DnsQueryError,
>,
) {
let query = query.query;
let make_error_reply = {
let query = query.clone();
from: SocketAddr,
packet: &[u8],
) -> Option<IpPacket<'a>> {
// The sentinel DNS server shall be the source. If we don't have a sentinel DNS for this socket, it cannot be a DNS response.
let saddr = *self.dns_mapping.get_by_right(&DnsServer::from(from))?;
let sport = DNS_PORT;
|e: &dyn fmt::Display| {
// To avoid sensitive data getting into the logs, only log the error if debug logging is enabled.
// We always want to see a warning.
if tracing::enabled!(Level::DEBUG) {
tracing::warn!("DNS query failed: {e}");
} else {
tracing::warn!("DNS query failed");
};
let message = Message::from_slice(packet).ok()?;
let query_id = message.header().id();
ip_packet::make::dns_err_response(query, hickory_proto::op::ResponseCode::ServFail)
.into_immutable()
}
};
let (destination, _) = self.forwarded_dns_queries.remove(&query_id)?;
let daddr = destination.ip();
let dport = destination.port();
let dns_reply = match response {
Ok(Ok(response)) => match dns::build_response_from_resolve_result(query, response) {
Ok(dns_reply) => dns_reply,
Err(e) => make_error_reply(&e),
},
Ok(Err(timeout)) => make_error_reply(&timeout),
Err(e) => make_error_reply(&e),
};
self.buffered_packets.push_back(dns_reply);
Some(
ip_packet::make::udp_packet(saddr, daddr, sport, dport, packet.to_vec())
.into_immutable(),
)
}
pub fn on_connection_failed(&mut self, resource: ResourceId) {
@@ -844,15 +823,13 @@ impl ClientState {
.filter(|resource| self.is_resource_enabled(resource))
}
#[must_use]
pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>) -> bool {
pub(crate) fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>) {
self.system_resolvers = new_dns;
self.update_dns_mapping()
}
#[must_use]
pub(crate) fn update_interface_config(&mut self, config: InterfaceConfig) -> bool {
pub(crate) fn update_interface_config(&mut self, config: InterfaceConfig) {
self.interface_config = Some(config);
self.update_dns_mapping()
@@ -862,10 +839,6 @@ impl ClientState {
self.buffered_packets.pop_front()
}
pub fn poll_dns_queries(&mut self) -> Option<DnsQuery<'static>> {
self.buffered_dns_queries.pop_front()
}
pub fn poll_timeout(&mut self) -> Option<Instant> {
// The number of mangled DNS queries is expected to be fairly small because we only track them whilst connecting to a CIDR resource that is a DNS server.
// Thus, sorting these values on-demand even within `poll_timeout` is expected to be performant enough.
@@ -878,6 +851,7 @@ impl ClientState {
pub fn handle_timeout(&mut self, now: Instant) {
self.node.handle_timeout(now);
self.mangled_dns_queries.retain(|_, exp| now < *exp);
self.forwarded_dns_queries.retain(|_, (_, exp)| now < *exp);
self.drain_node_events();
}
@@ -965,7 +939,9 @@ impl ClientState {
}
pub(crate) fn poll_transmit(&mut self) -> Option<snownet::Transmit<'static>> {
self.node.poll_transmit()
self.buffered_transmits
.pop_front()
.or_else(|| self.node.poll_transmit())
}
/// Sets a new set of resources.
@@ -1078,9 +1054,11 @@ impl ClientState {
self.resources_gateways.remove(&id);
}
fn update_dns_mapping(&mut self) -> bool {
fn update_dns_mapping(&mut self) {
let Some(config) = &self.interface_config else {
return false;
tracing::debug!("Unable to update DNS servesr without interface configuration");
return;
};
let effective_dns_servers =
@@ -1089,7 +1067,9 @@ impl ClientState {
if HashSet::<&DnsServer>::from_iter(effective_dns_servers.iter())
== HashSet::from_iter(self.dns_mapping.right_values())
{
return false;
tracing::debug!("Effective DNS servers are unchanged");
return;
}
let dns_mapping = sentinel_dns_mapping(
@@ -1121,8 +1101,6 @@ impl ClientState {
ip4: self.routes().filter_map(utils::ipv4).collect(),
ip6: self.routes().filter_map(utils::ipv6).collect(),
});
true
}
pub fn update_relays(
@@ -1390,134 +1368,6 @@ mod tests {
assert!(is_definitely_not_a_resource(ip("ff02::2")))
}
#[test]
fn update_system_dns_works() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());
let dns_changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
assert!(dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]);
}
#[test]
fn update_system_dns_without_change_is_a_no_op() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());
let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let dns_changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
assert!(!dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]);
}
#[test]
fn update_system_dns_with_change_works() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());
let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
assert!(dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]);
}
#[test]
fn update_to_system_with_sentinels_are_ignored() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());
let _ = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let dns_changed = client_state.update_system_resolvers(vec![
ip("1.1.1.1"),
ip("100.100.111.1"),
ip("fd00:2021:1111:8000:100:100:111:0"),
]);
assert!(!dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]);
}
#[test]
fn upstream_dns_wins_over_system() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_with_dns());
let dns_changed = client_state.update_dns_mapping();
assert!(dns_changed);
let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
assert!(!dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), dns_list());
}
#[test]
fn upstream_dns_change_updates() {
let mut client_state = ClientState::for_test();
let dns_changed = client_state.update_interface_config(interface_config_with_dns());
assert!(dns_changed);
let dns_changed = client_state.update_interface_config(InterfaceConfig {
upstream_dns: vec![dns("8.8.8.8:53")],
..interface_config_without_dns()
});
assert!(dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("8.8.8.8:53")]);
}
#[test]
fn upstream_dns_no_change_is_a_no_op() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_with_dns());
let dns_changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
assert!(dns_changed);
let dns_changed = client_state.update_interface_config(interface_config_with_dns());
assert!(!dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), dns_list());
}
#[test]
fn upstream_dns_sentinels_are_ignored() {
let mut client_state = ClientState::for_test();
let mut config = interface_config_with_dns();
let _ = client_state.update_interface_config(config.clone());
config.upstream_dns.push(dns("100.100.111.1:53"));
config
.upstream_dns
.push(dns("[fd00:2021:1111:8000:100:100:111:0]:53"));
let dns_changed = client_state.update_interface_config(config);
assert!(!dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), dns_list())
}
#[test]
fn system_dns_takes_over_when_upstream_are_unset() {
let mut client_state = ClientState::for_test();
let _ = client_state.update_interface_config(interface_config_with_dns());
let _ = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
let dns_changed = client_state.update_interface_config(interface_config_without_dns());
assert!(dns_changed);
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]);
}
#[test]
fn sentinel_dns_works() {
let servers = dns_list();
@@ -1559,29 +1409,6 @@ mod tests {
}
}
fn dns_mapping_is_exactly(mapping: BiMap<IpAddr, DnsServer>, servers: Vec<DnsServer>) {
assert_eq!(
HashSet::<&DnsServer>::from_iter(mapping.right_values()),
HashSet::from_iter(servers.iter())
)
}
fn interface_config_without_dns() -> InterfaceConfig {
InterfaceConfig {
ipv4: "10.0.0.1".parse().unwrap(),
ipv6: "fe80::".parse().unwrap(),
upstream_dns: Vec::new(),
}
}
fn interface_config_with_dns() -> InterfaceConfig {
InterfaceConfig {
ipv4: "10.0.0.1".parse().unwrap(),
ipv6: "fe80::".parse().unwrap(),
upstream_dns: dns_list(),
}
}
fn sentinel_ranges() -> Vec<IpNetwork> {
vec![
IpNetwork::V4(DNS_SENTINELS_V4),

View File

@@ -7,16 +7,11 @@ use domain::base::{
Message, MessageBuilder, ToName,
};
use domain::rdata::AllRecordData;
use hickory_resolver::lookup::Lookup;
use hickory_resolver::proto::error::{ProtoError, ProtoErrorKind};
use hickory_resolver::proto::op::MessageType;
use hickory_resolver::proto::rr::RecordType;
use ip_packet::udp::UdpPacket;
use ip_packet::IpPacket;
use ip_packet::Packet as _;
use itertools::Itertools;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
const DNS_TTL: u32 = 1;
const REVERSE_DNS_ADDRESS_END: &str = "arpa";
@@ -34,22 +29,18 @@ pub struct StubResolver {
known_hosts: KnownHosts,
}
#[derive(Debug)]
pub struct DnsQuery<'a> {
pub name: DomainName,
pub record_type: RecordType,
// We could be much more efficient with this field,
// we only need the header to create the response.
pub query: ip_packet::IpPacket<'a>,
}
/// Tells the Client how to reply to a single DNS query
#[derive(Debug)]
pub(crate) enum ResolveStrategy<'a> {
pub(crate) enum ResolveStrategy {
/// The query is for a Resource, we have an IP mapped already, and we can respond instantly
LocalResponse(IpPacket<'static>),
/// The query is for a non-Resource, forward it to an upstream or system resolver
ForwardQuery(DnsQuery<'a>),
/// The query is for a non-Resource, forward it to an upstream or system resolver.
ForwardQuery {
upstream: SocketAddr,
original_src: SocketAddr,
query_id: u16,
payload: Vec<u8>,
},
}
struct KnownHosts {
@@ -193,29 +184,35 @@ impl StubResolver {
self.dns_resources.values().contains(resource)
}
// TODO: we can save a few allocations here still
// We don't need to support multiple questions/qname in a single query because
// nobody does it and since this run with each packet we want to squeeze as much optimization
// as we can therefore we won't do it.
//
// See: https://stackoverflow.com/a/55093896
/// Parses an incoming packet as a DNS query and decides how to respond to it
///
/// Returns:
/// - `None` if the packet is not a valid DNS query destined for one of our sentinel resolvers
/// - Otherwise, a strategy for responding to the query
pub(crate) fn handle<'a>(
pub(crate) fn handle(
&mut self,
dns_mapping: &bimap::BiMap<IpAddr, DnsServer>,
packet: IpPacket<'a>,
) -> Option<ResolveStrategy<'a>> {
dns_mapping.get_by_left(&packet.destination())?;
packet: IpPacket,
) -> Option<ResolveStrategy> {
let upstream = dns_mapping.get_by_left(&packet.destination())?.address();
let datagram = packet.as_udp()?;
let message = as_dns(&datagram)?;
// We only support DNS on port 53.
if datagram.get_destination() != DNS_PORT {
return None;
}
let message = Message::from_octets(datagram.payload()).ok()?;
if message.header().qr() {
return None;
}
// We don't need to support multiple questions/qname in a single query because
// nobody does it and since this run with each packet we want to squeeze as much optimization
// as we can therefore we won't do it.
//
// See: https://stackoverflow.com/a/55093896
let question = message.first_question()?;
let domain = question.qname().to_vec();
let qtype = question.qtype();
@@ -240,11 +237,12 @@ impl StubResolver {
let resource_records = match (qtype, maybe_resource) {
(_, Some(resource)) if !self.knows_resource(&resource) => {
return Some(ResolveStrategy::ForwardQuery(DnsQuery {
name: domain,
record_type: u16::from(qtype).into(),
query: packet,
}))
return Some(ResolveStrategy::ForwardQuery {
upstream,
query_id: message.header().id(),
payload: message.into_octets().to_vec(),
original_src: SocketAddr::new(packet.source(), datagram.get_source()),
})
}
(Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource),
(Rtype::AAAA, Some(resource)) => {
@@ -256,11 +254,12 @@ impl StubResolver {
vec![AllRecordData::Ptr(domain::rdata::Ptr::new(fqdn))]
}
_ => {
return Some(ResolveStrategy::ForwardQuery(DnsQuery {
name: domain,
record_type: u16::from(qtype).into(),
query: packet,
}))
return Some(ResolveStrategy::ForwardQuery {
upstream,
query_id: message.header().id(),
payload: message.into_octets().to_vec(),
original_src: SocketAddr::new(packet.source(), datagram.get_source()),
})
}
};
@@ -278,35 +277,6 @@ impl StubResolver {
}
}
impl<'a> DnsQuery<'a> {
pub(crate) fn into_owned(self) -> DnsQuery<'static> {
let Self {
name,
record_type,
query,
} = self;
let buf = query.packet().to_vec();
let query = ip_packet::IpPacket::owned(buf)
.expect("We are constructing the ip packet from an ip packet");
DnsQuery {
name,
record_type,
query,
}
}
}
impl Clone for DnsQuery<'static> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
record_type: self.record_type,
query: self.query.clone(),
}
}
}
fn to_a_records(ips: impl Iterator<Item = IpAddr>) -> Vec<AllRecordData<Vec<u8>, DomainName>> {
ips.filter_map(get_v4)
.map(domain::rdata::A::new)
@@ -321,57 +291,13 @@ fn to_aaaa_records(ips: impl Iterator<Item = IpAddr>) -> Vec<AllRecordData<Vec<u
.collect_vec()
}
pub(crate) fn build_response_from_resolve_result(
original_pkt: IpPacket<'_>,
response: hickory_resolver::error::ResolveResult<Lookup>,
) -> Result<IpPacket, hickory_resolver::error::ResolveError> {
let datagram = original_pkt.unwrap_as_udp();
let mut message = original_pkt.unwrap_as_dns();
message.set_message_type(MessageType::Response);
message.set_recursion_available(true);
let response = match response.map_err(|err| err.kind().clone()) {
Ok(response) => message.add_answers(response.records().to_vec()),
Err(hickory_resolver::error::ResolveErrorKind::Proto(ProtoError { kind, .. }))
if matches!(*kind, ProtoErrorKind::NoRecordsFound { .. }) =>
{
let ProtoErrorKind::NoRecordsFound {
soa, response_code, ..
} = *kind
else {
panic!("Impossible - We matched on `ProtoErrorKind::NoRecordsFound` but then could not destructure that same variant");
};
if let Some(soa) = soa {
message.add_name_server(soa.into_record_of_rdata());
}
message.set_response_code(response_code)
}
Err(e) => {
return Err(e.into());
}
};
let packet = ip_packet::make::udp_packet(
original_pkt.destination(),
original_pkt.source(),
datagram.get_destination(),
datagram.get_source(),
response.to_vec()?,
)
.into_immutable();
Ok(packet)
}
fn build_dns_with_answer(
message: &Message<[u8]>,
message: Message<&[u8]>,
qname: DomainName,
records: Vec<AllRecordData<Vec<u8>, DomainName>>,
) -> Option<Vec<u8>> {
let mut answer_builder = MessageBuilder::new_vec()
.start_answer(message, Rcode::NOERROR)
.start_answer(&message, Rcode::NOERROR)
.ok()?;
answer_builder.header_mut().set_ra(true);
@@ -384,12 +310,6 @@ fn build_dns_with_answer(
Some(answer_builder.finish())
}
pub fn as_dns<'a>(pkt: &'a UdpPacket<'a>) -> Option<&'a Message<[u8]>> {
(pkt.get_destination() == DNS_PORT)
.then(|| Message::from_slice(pkt.payload()).ok())
.flatten()
}
pub fn is_subdomain(name: &DomainName, resource: &str) -> bool {
let question_mark = RelativeName::<Vec<_>>::from_octets(b"\x01?".as_ref().into()).unwrap();
let Ok(resource) = DomainName::vec_from_str(resource) else {

View File

@@ -1,29 +1,15 @@
use crate::{device_channel::Device, dns::DnsQuery, sockets::Sockets};
use connlib_shared::messages::DnsServer;
use futures::Future;
use futures_bounded::FuturesTupleSet;
use crate::{device_channel::Device, sockets::Sockets};
use futures_util::FutureExt as _;
use hickory_proto::iocompat::AsyncIoTokioAsStd;
use hickory_proto::TokioTime;
use hickory_resolver::{
config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts},
name_server::{GenericConnector, RuntimeProvider},
AsyncResolver, TokioHandle,
};
use ip_packet::{IpPacket, MutableIpPacket};
use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket};
use std::{
collections::HashMap,
io,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time::{Duration, Instant},
time::Instant,
};
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
/// Bundles together all side-effects that connlib needs to have access to.
pub struct Io {
/// The TUN device offered to the user.
@@ -33,29 +19,16 @@ pub struct Io {
/// The UDP sockets used to send & receive packets from the network.
sockets: Sockets,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
_tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
upstream_dns_servers: HashMap<IpAddr, AsyncResolver<GenericConnector<TokioRuntimeProvider>>>,
forwarded_dns_queries: FuturesTupleSet<
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
DnsQuery<'static>,
>,
}
pub enum Input<'a, I> {
Timeout(Instant),
Device(MutableIpPacket<'a>),
Network(I),
DnsResponse(
DnsQuery<'static>,
Result<
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
futures_bounded::Timeout,
>,
),
}
impl Io {
@@ -73,13 +46,8 @@ impl Io {
device: Device::new(),
timeout: None,
sockets,
tcp_socket_factory,
_tcp_socket_factory: tcp_socket_factory,
udp_socket_factory,
upstream_dns_servers: HashMap::default(),
forwarded_dns_queries: FuturesTupleSet::new(
Duration::from_secs(60),
DNS_QUERIES_QUEUE_SIZE,
),
})
}
@@ -90,10 +58,6 @@ impl Io {
ip6_bffer: &'b mut [u8],
device_buffer: &'b mut [u8],
) -> Poll<io::Result<Input<'b, impl Iterator<Item = DatagramIn<'b>>>>> {
if let Poll::Ready((response, query)) = self.forwarded_dns_queries.poll_unpin(cx) {
return Poll::Ready(Ok(Input::DnsResponse(query, response)));
}
if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? {
return Poll::Ready(Ok(Input::Network(network)));
}
@@ -126,50 +90,6 @@ impl Io {
Ok(())
}
pub fn set_upstream_dns_servers(
&mut self,
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,
) {
tracing::info!("Setting new DNS resolvers");
self.forwarded_dns_queries =
FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE);
self.upstream_dns_servers = create_resolvers(
dns_servers,
TokioRuntimeProvider::new(
self.tcp_socket_factory.clone(),
self.udp_socket_factory.clone(),
),
);
}
pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) -> Result<(), DnsQueryError> {
let upstream = query.query.destination();
let resolver = self
.upstream_dns_servers
.get(&upstream)
.cloned()
.expect("Only DNS queries to known upstream servers should be forwarded to `Io`");
if self
.forwarded_dns_queries
.try_push(
{
let name = query.name.clone().to_string();
let record_type = query.record_type;
async move { resolver.lookup(&name, record_type).await }
},
query,
)
.is_err()
{
return Err(DnsQueryError::TooManyQueries);
}
Ok(())
}
pub fn reset_timeout(&mut self, timeout: Instant) {
let timeout = tokio::time::Instant::from_std(timeout);
@@ -198,92 +118,3 @@ impl Io {
Ok(())
}
}
#[derive(Debug, thiserror::Error)]
pub enum DnsQueryError {
#[error("Too many ongoing DNS queries")]
TooManyQueries,
}
/// Identical to [`TokioRuntimeProvider`](hickory_resolver::name_server::TokioRuntimeProvider) but using our own [`SocketFactory`].
#[derive(Clone)]
struct TokioRuntimeProvider {
handle: TokioHandle,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
}
impl TokioRuntimeProvider {
fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
) -> TokioRuntimeProvider {
Self {
handle: Default::default(),
tcp_socket_factory,
udp_socket_factory,
}
}
}
impl RuntimeProvider for TokioRuntimeProvider {
type Handle = TokioHandle;
type Timer = TokioTime;
type Udp = UdpSocket;
type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;
fn create_handle(&self) -> Self::Handle {
self.handle.clone()
}
fn connect_tcp(
&self,
server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
let socket = (self.tcp_socket_factory)(&server_addr);
Box::pin(async move {
let socket = socket?;
let stream = socket.connect(server_addr).await?;
Ok(AsyncIoTokioAsStd(stream))
})
}
fn bind_udp(
&self,
local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
let socket = (self.udp_socket_factory)(&local_addr);
Box::pin(async move { socket })
}
}
fn create_resolvers(
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,
runtime_provider: TokioRuntimeProvider,
) -> HashMap<IpAddr, AsyncResolver<GenericConnector<TokioRuntimeProvider>>> {
dns_servers
.into_iter()
.map(|(sentinel, srv)| {
let mut resolver_config = ResolverConfig::new();
resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Tcp));
let mut resolver_opts = ResolverOpts::default();
resolver_opts.edns0 = true;
resolver_opts.cache_size = 0;
resolver_opts.attempts = 1;
(
sentinel,
AsyncResolver::new_with_conn(
resolver_config,
resolver_opts,
GenericConnector::new(runtime_provider.clone()),
),
)
})
.collect()
}

View File

@@ -116,13 +116,6 @@ impl ClientTunnel {
continue;
}
if let Some(dns_query) = self.role_state.poll_dns_queries() {
if let Err(e) = self.io.perform_dns_query(dns_query.clone()) {
self.role_state.on_dns_result(dns_query, Err(e))
}
continue;
}
if let Some(timeout) = self.role_state.poll_timeout() {
self.io.reset_timeout(timeout);
}
@@ -163,10 +156,6 @@ impl ClientTunnel {
continue;
}
Poll::Ready(io::Input::DnsResponse(query, response)) => {
self.role_state.on_dns_result(query, Ok(response));
continue;
}
Poll::Pending => {}
}
@@ -254,9 +243,6 @@ impl GatewayTunnel {
continue;
}
Poll::Ready(io::Input::DnsResponse(_, _)) => {
unreachable!("Gateway does not (yet) resolve DNS queries via `Io`")
}
Poll::Pending => {}
}

View File

@@ -8,6 +8,7 @@ mod flux_capacitor;
mod reference;
mod run_count_appender;
mod sim_client;
mod sim_dns;
mod sim_gateway;
mod sim_net;
mod sim_relay;

View File

@@ -1,5 +1,5 @@
use super::{
composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*,
composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*,
strategies::*, stub_portal::StubPortal, transition::*,
};
use crate::dns::is_subdomain;
@@ -10,7 +10,7 @@ use connlib_shared::{
},
DomainName, StaticSecret,
};
use hickory_proto::rr::RecordType;
use domain::base::Rtype;
use proptest::{prelude::*, sample};
use proptest_state_machine::ReferenceStateMachine;
use std::{
@@ -27,6 +27,7 @@ pub(crate) struct ReferenceState {
pub(crate) client: Host<RefClient>,
pub(crate) gateways: BTreeMap<GatewayId, Host<RefGateway>>,
pub(crate) relays: BTreeMap<RelayId, Host<u64>>,
pub(crate) dns_servers: BTreeMap<DnsServerId, Host<RefDns>>,
pub(crate) portal: StubPortal,
@@ -58,13 +59,14 @@ impl ReferenceStateMachine for ReferenceState {
fn init_state() -> BoxedStrategy<Self::State> {
stub_portal()
.prop_flat_map(move |portal| {
.prop_flat_map(|portal| {
let gateways = portal.gateways();
let dns_resource_records = portal.dns_resource_records();
let client = portal.client();
let relays = relays();
let global_dns_records = global_dns_records(); // Start out with a set of global DNS records so we have something to resolve outside of DNS resources.
let drop_direct_client_traffic = any::<bool>();
let dns_servers = dns_servers();
(
client,
@@ -74,6 +76,7 @@ impl ReferenceStateMachine for ReferenceState {
relays,
global_dns_records,
drop_direct_client_traffic,
dns_servers,
)
})
.prop_filter_map(
@@ -86,6 +89,7 @@ impl ReferenceStateMachine for ReferenceState {
relays,
mut global_dns,
drop_direct_client_traffic,
dns_servers,
)| {
let mut routing_table = RoutingTable::default();
@@ -104,6 +108,12 @@ impl ReferenceStateMachine for ReferenceState {
};
}
for (id, dns_server) in &dns_servers {
if !routing_table.add_host(*id, dns_server) {
return None;
};
}
// Merge all DNS records into `global_dns`.
global_dns.extend(records);
@@ -111,6 +121,7 @@ impl ReferenceStateMachine for ReferenceState {
c,
gateways,
relays,
dns_servers,
portal,
global_dns,
drop_direct_client_traffic,
@@ -120,7 +131,7 @@ impl ReferenceStateMachine for ReferenceState {
)
.prop_filter(
"private keys must be unique",
|(c, gateways, _, _, _, _, _)| {
|(c, gateways, _, _, _, _, _, _)| {
let different_keys = gateways
.iter()
.map(|(_, g)| g.inner().key)
@@ -135,6 +146,7 @@ impl ReferenceStateMachine for ReferenceState {
client,
gateways,
relays,
dns_servers,
portal,
global_dns_records,
drop_direct_client_traffic,
@@ -144,6 +156,7 @@ impl ReferenceStateMachine for ReferenceState {
client,
gateways,
relays,
dns_servers,
portal,
global_dns_records,
network,
@@ -162,13 +175,11 @@ impl ReferenceStateMachine for ReferenceState {
CompositeStrategy::default()
.with(
1,
system_dns_servers()
.prop_map(|servers| Transition::UpdateSystemDnsServers { servers }),
update_system_dns_servers(state.dns_servers.values().cloned().collect()),
)
.with(
1,
upstream_dns_servers()
.prop_map(|servers| Transition::UpdateUpstreamDnsServers { servers }),
update_upstream_dns_servers(state.dns_servers.values().cloned().collect()),
)
.with_if_not_empty(
5,
@@ -396,12 +407,12 @@ impl ReferenceStateMachine for ReferenceState {
state.portal.gateway_for_resource(r).copied()
})
}),
Transition::UpdateSystemDnsServers { servers } => {
Transition::UpdateSystemDnsServers(servers) => {
state
.client
.exec_mut(|client| client.system_dns_resolvers.clone_from(servers));
}
Transition::UpdateUpstreamDnsServers { servers } => {
Transition::UpdateUpstreamDnsServers(servers) => {
state
.client
.exec_mut(|client| client.upstream_dns_resolvers.clone_from(servers));
@@ -513,34 +524,28 @@ impl ReferenceStateMachine for ReferenceState {
ref_client.is_valid_icmp_packet(seq, identifier)
&& ref_client.dns_records.get(dst).is_some_and(|r| match src {
IpAddr::V4(_) => r.contains(&RecordType::A),
IpAddr::V6(_) => r.contains(&RecordType::AAAA),
IpAddr::V4(_) => r.contains(&Rtype::A),
IpAddr::V6(_) => r.contains(&Rtype::AAAA),
})
&& state.gateways.contains_key(gateway)
}
Transition::UpdateSystemDnsServers { servers } => {
// TODO: PRODUCTION CODE DOES NOT HANDLE THIS!
if state.client.ip4.is_none() && servers.iter().all(|s| s.is_ipv4()) {
return false;
}
if state.client.ip6.is_none() && servers.iter().all(|s| s.is_ipv6()) {
return false;
Transition::UpdateSystemDnsServers(servers) => {
if servers.is_empty() {
return true; // Clearing is allowed.
}
true
servers
.iter()
.any(|dns_server| state.client.sending_socket_for(*dns_server).is_some())
}
Transition::UpdateUpstreamDnsServers { servers } => {
// TODO: PRODUCTION CODE DOES NOT HANDLE THIS!
if state.client.ip4.is_none() && servers.iter().all(|s| s.ip().is_ipv4()) {
return false;
}
if state.client.ip6.is_none() && servers.iter().all(|s| s.ip().is_ipv6()) {
return false;
Transition::UpdateUpstreamDnsServers(servers) => {
if servers.is_empty() {
return true; // Clearing is allowed.
}
true
servers
.iter()
.any(|dns_server| state.client.sending_socket_for(dns_server.ip()).is_some())
}
Transition::SendDnsQuery {
domain, dns_server, ..

View File

@@ -2,11 +2,10 @@ use super::{
reference::{private_key, PrivateKey, ResourceDst},
sim_net::{any_ip_stack, any_port, host, Host},
sim_relay::{map_explode, SimRelay},
strategies::{latency, system_dns_servers, upstream_dns_servers},
sut::domain_to_hickory_name,
strategies::latency,
IcmpIdentifier, IcmpSeq, QueryId,
};
use crate::{tests::sut::hickory_name_to_domain, ClientState};
use crate::ClientState;
use bimap::BiMap;
use connlib_shared::{
messages::{
@@ -16,10 +15,9 @@ use connlib_shared::{
proptest::{client_id, domain_name},
DomainName,
};
use hickory_proto::{
op::MessageType,
rr::{rdata, RData, RecordType},
serialize::binary::BinDecodable as _,
use domain::{
base::{Message, Rtype, ToName},
rdata::AllRecordData,
};
use ip_network::{Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
@@ -80,7 +78,7 @@ impl SimClient {
pub(crate) fn send_dns_query_for(
&mut self,
domain: DomainName,
r_type: RecordType,
r_type: Rtype,
query_id: u16,
dns_server: SocketAddr,
now: Instant,
@@ -92,15 +90,13 @@ impl SimClient {
tracing::debug!(%dns_server, %domain, "Sending DNS query");
let name = domain_to_hickory_name(domain);
let src = self
.sut
.tunnel_ip_for(dns_server)
.expect("tunnel should be initialised");
let packet = ip_packet::make::dns_query(
name,
domain,
r_type,
SocketAddr::new(src, 9999), // An application would pick a random source port that is free.
SocketAddr::new(dns_server, 53),
@@ -130,14 +126,13 @@ impl SimClient {
let packet = packet.to_owned().into_immutable();
if let Some(udp) = packet.as_udp() {
if let Ok(message) = hickory_proto::op::Message::from_bytes(udp.payload()) {
debug_assert_eq!(
message.message_type(),
MessageType::Query,
if let Ok(message) = Message::from_slice(udp.payload()) {
debug_assert!(
!message.header().qr(),
"every DNS message sent from the client should be a DNS query"
);
self.sent_dns_queries.insert(message.id(), packet);
self.sent_dns_queries.insert(message.header().id(), packet);
}
}
}
@@ -175,18 +170,24 @@ impl SimClient {
if let Some(udp) = packet.as_udp() {
if udp.get_source() == 53 {
let mut message = hickory_proto::op::Message::from_bytes(udp.payload())
let message = Message::from_slice(udp.payload())
.expect("ip packets on port 53 to be DNS packets");
self.received_dns_responses
.insert(message.id(), packet.to_owned());
.insert(message.header().id(), packet.to_owned());
for record in message.take_answers().into_iter() {
let domain = hickory_name_to_domain(record.name().clone());
for record in message.answer().unwrap() {
let record = record.unwrap();
let domain = record.owner().to_name();
let ip = match record.data() {
Some(RData::A(rdata::A(ip4))) => IpAddr::from(*ip4),
Some(RData::AAAA(rdata::AAAA(ip6))) => IpAddr::from(*ip6),
#[allow(clippy::wildcard_enum_match_arm)]
let ip = match record
.into_any_record::<AllRecordData<_, _>>()
.unwrap()
.data()
{
AllRecordData::A(a) => IpAddr::from(a.addr()),
AllRecordData::Aaaa(aaaa) => IpAddr::from(aaaa.addr()),
unhandled => {
panic!("Unexpected record data: {unhandled:?}")
}
@@ -253,7 +254,7 @@ pub struct RefClient {
/// The IPs assigned to a domain by connlib are an implementation detail that we don't want to model in these tests.
/// Instead, we just remember what _kind_ of records we resolved to be able to sample a matching src IP.
#[derivative(Debug = "ignore")]
pub(crate) dns_records: BTreeMap<DomainName, HashSet<RecordType>>,
pub(crate) dns_records: BTreeMap<DomainName, HashSet<Rtype>>,
/// Whether we are connected to the gateway serving the Internet resource.
pub(crate) connected_internet_resources: bool,
@@ -287,12 +288,12 @@ impl RefClient {
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self) -> SimClient {
let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed.
let _ = client_state.update_interface_config(Interface {
client_state.update_interface_config(Interface {
ipv4: self.tunnel_ip4,
ipv6: self.tunnel_ip6,
upstream_dns: self.upstream_dns_resolvers,
upstream_dns: self.upstream_dns_resolvers.clone(),
});
let _ = client_state.update_system_resolvers(self.system_dns_resolvers);
client_state.update_system_resolvers(self.system_dns_resolvers.clone());
SimClient::new(self.id, client_state)
}
@@ -478,7 +479,7 @@ impl RefClient {
.find(|id| !self.disabled_resources.contains(id))
}
fn resolved_domains(&self) -> impl Iterator<Item = (DomainName, HashSet<RecordType>)> + '_ {
fn resolved_domains(&self) -> impl Iterator<Item = (DomainName, HashSet<Rtype>)> + '_ {
self.dns_records
.iter()
.filter(|(domain, _)| self.dns_resource_by_domain(domain).is_some())
@@ -499,7 +500,7 @@ impl RefClient {
.filter_map(|(domain, records)| {
records
.iter()
.any(|r| matches!(r, RecordType::A))
.any(|r| matches!(r, &Rtype::A))
.then_some(domain)
})
.collect()
@@ -510,7 +511,7 @@ impl RefClient {
.filter_map(|(domain, records)| {
records
.iter()
.any(|r| matches!(r, RecordType::AAAA))
.any(|r| matches!(r, &Rtype::AAAA))
.then_some(domain)
})
.collect()
@@ -525,7 +526,7 @@ impl RefClient {
return self
.upstream_dns_resolvers
.iter()
.map(|s| s.address())
.map(DnsServer::address)
.collect();
}
@@ -687,32 +688,6 @@ pub(crate) fn ref_client_host(
ref_client(tunnel_ip4s, tunnel_ip6s),
latency(300), // TODO: Increase with #6062.
)
.prop_filter("at least one DNS server needs to be reachable", |host| {
// TODO: PRODUCTION CODE DOES NOT HANDLE THIS!
let upstream_dns_resolvers = &host.inner().upstream_dns_resolvers;
let system_dns = &host.inner().system_dns_resolvers;
if !upstream_dns_resolvers.is_empty() {
if host.ip4.is_none() && upstream_dns_resolvers.iter().all(|s| s.ip().is_ipv4()) {
return false;
}
if host.ip6.is_none() && upstream_dns_resolvers.iter().all(|s| s.ip().is_ipv6()) {
return false;
}
return true;
}
if host.ip4.is_none() && system_dns.iter().all(|s| s.is_ipv4()) {
return false;
}
if host.ip6.is_none() && system_dns.iter().all(|s| s.is_ipv6()) {
return false;
}
true
})
}
fn ref_client(
@@ -725,26 +700,16 @@ fn ref_client(
client_id(),
private_key(),
known_hosts(),
system_dns_servers(),
upstream_dns_servers(),
)
.prop_map(
move |(
tunnel_ip4,
tunnel_ip6,
id,
key,
known_hosts,
system_dns_resolvers,
upstream_dns_resolvers,
)| RefClient {
move |(tunnel_ip4, tunnel_ip6, id, key, known_hosts)| RefClient {
id,
key,
known_hosts,
tunnel_ip4,
tunnel_ip6,
system_dns_resolvers,
upstream_dns_resolvers,
system_dns_resolvers: Default::default(),
upstream_dns_resolvers: Default::default(),
internet_resource: Default::default(),
cidr_resources: IpNetworkTable::new(),
dns_resources: Default::default(),

View File

@@ -0,0 +1,124 @@
use super::{
sim_net::{host, Host},
strategies::latency,
};
use connlib_shared::DomainName;
use domain::{
base::{
iana::{Class, Rcode},
Message, MessageBuilder, Record, Rtype, ToName as _, Ttl,
},
rdata::AllRecordData,
};
use firezone_relay::IpStack;
use proptest::{
arbitrary::any,
strategy::{Just, Strategy},
};
use snownet::Transmit;
use std::{
borrow::Cow,
collections::{BTreeMap, HashSet},
fmt,
net::{IpAddr, SocketAddr},
time::Instant,
};
use uuid::Uuid;
pub(crate) fn dns_server_id() -> impl Strategy<Value = DnsServerId> {
any::<u128>().prop_map(DnsServerId::from_u128)
}
pub(crate) fn ref_dns_host(addr: SocketAddr) -> impl Strategy<Value = Host<RefDns>> {
let ip = addr.ip();
let port = addr.port();
host(
Just(IpStack::from(ip)),
Just(port),
Just(RefDns {}),
latency(50),
)
}
#[derive(Debug, Clone)]
pub(crate) struct RefDns {}
#[derive(Debug)]
pub(crate) struct SimDns {}
impl SimDns {
pub(crate) fn receive(
&mut self,
global_dns_records: &BTreeMap<DomainName, HashSet<IpAddr>>,
transmit: Transmit,
_now: Instant,
) -> Option<Transmit<'static>> {
let query = Message::from_octets(&transmit.payload).ok()?;
let response = MessageBuilder::new_vec();
let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap();
let query = query.sole_question().unwrap();
let name = query.qname().to_vec();
let records = global_dns_records
.get(&name)
.into_iter()
.flatten()
.filter(|ip| {
#[allow(clippy::wildcard_enum_match_arm)]
match query.qtype() {
Rtype::A => ip.is_ipv4(),
Rtype::AAAA => ip.is_ipv6(),
_ => todo!(),
}
})
.copied()
.map(|ip| match ip {
IpAddr::V4(v4) => AllRecordData::<Vec<_>, DomainName>::A(v4.into()),
IpAddr::V6(v6) => AllRecordData::<Vec<_>, DomainName>::Aaaa(v6.into()),
})
.map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata));
for record in records {
answers.push(record).unwrap();
}
let payload = answers.finish();
tracing::debug!(%name, "Responding to DNS query");
Some(Transmit {
src: Some(transmit.dst),
dst: transmit.src.unwrap(),
payload: Cow::Owned(payload),
})
}
}
#[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct DnsServerId(Uuid);
impl DnsServerId {
#[cfg(feature = "proptest")]
pub fn from_u128(v: u128) -> Self {
Self(Uuid::from_u128(v))
}
}
impl fmt::Display for DnsServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if cfg!(feature = "proptest") {
write!(f, "{:X}", self.0.as_u128())
} else {
write!(f, "{}", self.0)
}
}
}
impl fmt::Debug for DnsServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self, f)
}
}

View File

@@ -4,7 +4,7 @@ use super::{
sim_relay::{map_explode, SimRelay},
strategies::latency,
};
use crate::{tests::sut::hickory_name_to_domain, GatewayState};
use crate::GatewayState;
use connlib_shared::{
messages::{GatewayId, RelayId},
DomainName,
@@ -67,6 +67,8 @@ impl SimGateway {
) -> Option<Transmit<'static>> {
let packet = packet.to_owned();
// TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`?
if packet.as_icmp().is_some() {
self.received_icmp_requests.push_back(packet.clone());
@@ -78,11 +80,7 @@ impl SimGateway {
if packet.as_udp().is_some() {
let response = ip_packet::make::dns_ok_response(packet, |name| {
global_dns_records
.get(&hickory_name_to_domain(name.clone()))
.cloned()
.into_iter()
.flatten()
global_dns_records.get(name).cloned().into_iter().flatten()
});
let transmit = self.sut.encapsulate(response, now)?.into_owned();

View File

@@ -1,3 +1,4 @@
use super::sim_dns::DnsServerId;
use crate::tests::buffered_transmits::BufferedTransmits;
use crate::tests::strategies::documentation_ip6s;
use connlib_shared::messages::{ClientId, GatewayId, RelayId};
@@ -122,6 +123,15 @@ impl<T> Host<T> {
}
}
pub(crate) fn single_socket(&self) -> SocketAddr {
match (self.ip4, self.ip6) {
(None, Some(ip6)) => SocketAddr::new(ip6.into(), self.default_port),
(Some(ip4), None) => SocketAddr::new(ip4.into(), self.default_port),
(Some(_), Some(_)) => panic!("Dual-stack host"),
(None, None) => panic!("No socket available"),
}
}
pub(crate) fn latency(&self) -> Duration {
self.latency
}
@@ -253,6 +263,7 @@ pub(crate) enum HostId {
Client(ClientId),
Gateway(GatewayId),
Relay(RelayId),
DnsServer(DnsServerId),
Stale,
}
@@ -274,6 +285,12 @@ impl From<ClientId> for HostId {
}
}
impl From<DnsServerId> for HostId {
fn from(v: DnsServerId) -> Self {
Self::DnsServer(v)
}
}
pub(crate) fn host<T>(
socket_ips: impl Strategy<Value = IpStack>,
default_port: impl Strategy<Value = u16>,

View File

@@ -1,4 +1,9 @@
use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal};
use super::{
sim_dns::{dns_server_id, ref_dns_host, DnsServerId, RefDns},
sim_net::Host,
sim_relay::ref_relay_host,
stub_portal::StubPortal,
};
use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES};
use connlib_shared::{
messages::{
@@ -6,7 +11,7 @@ use connlib_shared::{
ResourceDescriptionCidr, ResourceDescriptionDns, ResourceDescriptionInternet, Site,
SiteId,
},
DnsServer, GatewayId, RelayId,
GatewayId, RelayId,
},
proptest::{
any_ip_network, cidr_resource, dns_resource, domain_name, gateway_id, relay_id, site,
@@ -14,42 +19,15 @@ use connlib_shared::{
DomainName,
};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use itertools::Itertools as _;
use itertools::Itertools;
use prop::sample;
use proptest::{collection, prelude::*};
use std::{
collections::{BTreeMap, HashMap, HashSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::Duration,
};
pub(crate) fn upstream_dns_servers() -> impl Strategy<Value = Vec<DnsServer>> {
let ip4_dns_servers = collection::vec(
any::<Ipv4Addr>().prop_map(|ip| DnsServer::from((ip, 53))),
1..4,
);
let ip6_dns_servers = collection::vec(
any::<Ipv6Addr>().prop_map(|ip| DnsServer::from((ip, 53))),
1..4,
);
// TODO: PRODUCTION CODE DOES NOT HAVE A SAFEGUARD FOR THIS YET.
// AN ADMIN COULD CONFIGURE ONLY IPv4 SERVERS IN WHICH CASE WE ARE SCREWED IF THE CLIENT ONLY HAS IPv6 CONNECTIVITY.
prop_oneof![
Just(Vec::new()),
(ip4_dns_servers, ip6_dns_servers).prop_map(|(mut ip4_servers, ip6_servers)| {
ip4_servers.extend(ip6_servers);
ip4_servers
})
]
}
pub(crate) fn system_dns_servers() -> impl Strategy<Value = Vec<IpAddr>> {
collection::vec(any::<IpAddr>(), 1..4) // Always need at least 1 system DNS server. TODO: Should we test what happens if we don't?
}
pub(crate) fn global_dns_records() -> impl Strategy<Value = BTreeMap<DomainName, HashSet<IpAddr>>> {
collection::btree_map(
domain_name(2..4).prop_map(|d| d.parse().unwrap()),
@@ -144,6 +122,38 @@ pub(crate) fn relays() -> impl Strategy<Value = BTreeMap<RelayId, Host<u64>>> {
collection::btree_map(relay_id(), ref_relay_host(), 1..=2)
}
/// Sample a list of DNS servers.
///
/// We make sure to always have at least 1 IPv4 and 1 IPv6 DNS server.
pub(crate) fn dns_servers() -> impl Strategy<Value = BTreeMap<DnsServerId, Host<RefDns>>> {
let ip4_dns_servers = collection::hash_set(
any::<Ipv4Addr>().prop_map(|ip| SocketAddr::from((ip, 53))),
1..4,
);
let ip6_dns_servers = collection::hash_set(
any::<Ipv6Addr>().prop_map(|ip| SocketAddr::from((ip, 53))),
1..4,
);
(ip4_dns_servers, ip6_dns_servers).prop_flat_map(|(ip4_dns_servers, ip6_dns_servers)| {
let servers = Vec::from_iter(ip4_dns_servers.into_iter().chain(ip6_dns_servers));
// First, generate a unique number of IDs, one for each DNS server.
let ids = collection::hash_set(dns_server_id(), servers.len());
(ids, Just(servers))
.prop_flat_map(move |(ids, servers)| {
let ids = ids.into_iter();
// Second, zip the IDs and addresses together.
ids.zip(servers)
.map(|(id, addr)| (Just(id), ref_dns_host(addr)))
.collect::<Vec<_>>()
})
.prop_map(BTreeMap::from_iter) // Third, turn the `Vec` of tuples into a `BTreeMap`.
})
}
fn any_site(sites: HashSet<Site>) -> impl Strategy<Value = Site> {
sample::select(Vec::from_iter(sites))
}

View File

@@ -1,6 +1,7 @@
use super::buffered_transmits::BufferedTransmits;
use super::reference::ReferenceState;
use super::sim_client::SimClient;
use super::sim_dns::{DnsServerId, SimDns};
use super::sim_gateway::SimGateway;
use super::sim_net::{Host, HostId, RoutingTable};
use super::sim_relay::SimRelay;
@@ -10,17 +11,12 @@ use crate::tests::assertions::*;
use crate::tests::flux_capacitor::FluxCapacitor;
use crate::tests::transition::Transition;
use crate::utils::earliest;
use crate::{dns::DnsQuery, ClientEvent, GatewayEvent, Request};
use crate::{ClientEvent, GatewayEvent, Request};
use connlib_shared::messages::client::ResourceDescription;
use connlib_shared::{
messages::{ClientId, GatewayId, Interface, RelayId},
DomainName,
};
use hickory_proto::{
op::Query,
rr::{RData, Record, RecordType},
};
use hickory_resolver::lookup::Lookup;
use proptest_state_machine::{ReferenceStateMachine, StateMachineTest};
use secrecy::ExposeSecret as _;
use snownet::Transmit;
@@ -28,8 +24,6 @@ use std::iter;
use std::{
collections::{BTreeMap, HashSet},
net::IpAddr,
str::FromStr as _,
sync::Arc,
time::{Duration, Instant},
};
use tracing::debug_span;
@@ -47,6 +41,7 @@ pub(crate) struct TunnelTest {
client: Host<SimClient>,
gateways: BTreeMap<GatewayId, Host<SimGateway>>,
relays: BTreeMap<RelayId, Host<SimRelay>>,
dns_servers: BTreeMap<DnsServerId, Host<SimDns>>,
drop_direct_client_traffic: bool,
network: RoutingTable,
@@ -101,6 +96,16 @@ impl StateMachineTest for TunnelTest {
})
.collect::<BTreeMap<_, _>>();
let dns_servers = ref_state
.dns_servers
.iter()
.map(|(did, dns_server)| {
let dns_server = dns_server.map(|_, _, _| SimDns {}, debug_span!("dns", %did));
(*did, dns_server)
})
.collect::<BTreeMap<_, _>>();
// Configure client and gateway with the relays.
client.exec_mut(|c| c.update_relays(iter::empty(), relays.iter(), flux_capacitor.now()));
for gateway in gateways.values_mut() {
@@ -116,6 +121,7 @@ impl StateMachineTest for TunnelTest {
gateways,
logger,
relays,
dns_servers,
};
let mut buffered_transmits = BufferedTransmits::default();
@@ -221,12 +227,12 @@ impl StateMachineTest for TunnelTest {
buffered_transmits.push_from(transmit, &state.client, now);
}
Transition::UpdateSystemDnsServers { servers } => {
Transition::UpdateSystemDnsServers(servers) => {
state
.client
.exec_mut(|c| c.sut.update_system_resolvers(servers));
}
Transition::UpdateUpstreamDnsServers { servers } => {
Transition::UpdateUpstreamDnsServers(servers) => {
state.client.exec_mut(|c| {
c.sut.update_interface_config(Interface {
ipv4: c.sut.tunnel_ip4().unwrap(),
@@ -259,7 +265,7 @@ impl StateMachineTest for TunnelTest {
// Simulate receiving `init`.
state.client.exec_mut(|c| {
let _ = c.sut.update_interface_config(Interface {
c.sut.update_interface_config(Interface {
ipv4,
ipv6,
upstream_dns,
@@ -454,10 +460,6 @@ impl TunnelTest {
);
continue;
}
if let Some(query) = self.client.exec_mut(|client| client.sut.poll_dns_queries()) {
self.on_forwarded_dns_query(query, ref_state);
continue;
}
self.client.exec_mut(|sim| {
while let Some(packet) = sim.sut.poll_packets() {
sim.on_received_packet(packet)
@@ -514,7 +516,7 @@ impl TunnelTest {
for (_, relay) in self.relays.iter_mut() {
while let Some(transmit) = relay.poll_transmit(now) {
let Some(reply) = relay.exec_mut(|g| g.receive(transmit, now)) else {
let Some(reply) = relay.exec_mut(|r| r.receive(transmit, now)) else {
continue;
};
@@ -523,6 +525,18 @@ impl TunnelTest {
relay.exec_mut(|r| r.sut.handle_timeout(now))
}
for (_, dns_server) in self.dns_servers.iter_mut() {
while let Some(transmit) = dns_server.poll_transmit(now) {
let Some(reply) =
dns_server.exec_mut(|d| d.receive(global_dns_records, transmit, now))
else {
continue;
};
buffered_transmits.push_from(reply, dns_server, now);
}
}
}
fn poll_timeout(&mut self) -> Option<Instant> {
@@ -591,6 +605,12 @@ impl TunnelTest {
HostId::Stale => {
tracing::debug!(%dst, "Dropping packet because host roamed away or is offline");
}
HostId::DnsServer(id) => {
self.dns_servers
.get_mut(&id)
.expect("unknown DNS server")
.receive(transmit, now);
}
}
}
@@ -780,41 +800,6 @@ impl TunnelTest {
ClientEvent::TunRoutesUpdated { .. } => {}
}
}
// TODO: Should we vary the following things via proptests?
// - Forwarded DNS query timing out?
// - hickory error?
// - TTL?
fn on_forwarded_dns_query(&mut self, query: DnsQuery<'static>, ref_state: &ReferenceState) {
let all_ips = &ref_state
.global_dns_records
.get(&query.name)
.expect("Forwarded DNS query to be for known domain");
let name = domain_to_hickory_name(query.name.clone());
let requested_type = query.record_type;
let record_data = all_ips
.iter()
.filter_map(|ip| match (requested_type, ip) {
(RecordType::A, IpAddr::V4(v4)) => Some(RData::A((*v4).into())),
(RecordType::AAAA, IpAddr::V6(v6)) => Some(RData::AAAA((*v6).into())),
(RecordType::A, IpAddr::V6(_)) | (RecordType::AAAA, IpAddr::V4(_)) => None,
_ => unreachable!(),
})
.map(|rdata| Record::from_rdata(name.clone(), 86400_u32, rdata))
.collect::<Arc<_>>();
self.client.exec_mut(|c| {
c.sut.on_dns_result(
query,
Ok(Ok(Ok(Lookup::new_with_max_ttl(
Query::query(name, requested_type),
record_data,
)))),
)
})
}
}
fn on_gateway_event(
@@ -837,22 +822,3 @@ fn on_gateway_event(
GatewayEvent::RefreshDns { .. } => todo!(),
}
}
pub(crate) fn hickory_name_to_domain(mut name: hickory_proto::rr::Name) -> DomainName {
name.set_fqdn(false); // Hack to work around hickory always parsing as FQ
let name = name.to_string();
let domain = DomainName::from_chars(name.chars()).unwrap();
debug_assert_eq!(name, domain.to_string());
domain
}
pub(crate) fn domain_to_hickory_name(domain: DomainName) -> hickory_proto::rr::Name {
let domain = domain.to_string();
let name = hickory_proto::rr::Name::from_str(&domain).unwrap();
debug_assert_eq!(name.to_string(), domain);
name
}

View File

@@ -1,9 +1,12 @@
use super::sim_net::{any_ip_stack, any_port, Host};
use super::{
sim_dns::RefDns,
sim_net::{any_ip_stack, any_port, Host},
};
use connlib_shared::{
messages::{client::ResourceDescription, DnsServer, RelayId, ResourceId},
DomainName,
};
use hickory_proto::rr::RecordType;
use domain::base::Rtype;
use proptest::{prelude::*, sample};
use std::{
collections::{BTreeMap, HashSet},
@@ -50,16 +53,16 @@ pub(crate) enum Transition {
SendDnsQuery {
domain: DomainName,
/// The type of DNS query we should send.
r_type: RecordType,
r_type: Rtype,
/// The DNS query ID.
query_id: u16,
dns_server: SocketAddr,
},
/// The system's DNS servers changed.
UpdateSystemDnsServers { servers: Vec<IpAddr> },
UpdateSystemDnsServers(Vec<IpAddr>),
/// The upstream DNS servers changed.
UpdateUpstreamDnsServers { servers: Vec<DnsServer> },
UpdateUpstreamDnsServers(Vec<DnsServer>),
/// Roam the client to a new pair of sockets.
RoamClient {
@@ -165,7 +168,7 @@ where
(
domain,
dns_server.prop_map_into(),
prop_oneof![Just(RecordType::A), Just(RecordType::AAAA)],
prop_oneof![Just(Rtype::A), Just(Rtype::AAAA)],
any::<u16>(),
)
.prop_map(
@@ -185,3 +188,27 @@ pub(crate) fn roam_client() -> impl Strategy<Value = Transition> {
port,
})
}
pub(crate) fn update_system_dns_servers(
dns_servers: Vec<Host<RefDns>>,
) -> impl Strategy<Value = Transition> {
let max = dns_servers.len();
sample::subsequence(dns_servers, ..=max).prop_map(|seq| {
Transition::UpdateSystemDnsServers(
seq.into_iter().map(|h| h.single_socket().ip()).collect(),
)
})
}
pub(crate) fn update_upstream_dns_servers(
dns_servers: Vec<Host<RefDns>>,
) -> impl Strategy<Value = Transition> {
let max = dns_servers.len();
sample::subsequence(dns_servers, ..=max).prop_map(|seq| {
Transition::UpdateUpstreamDnsServers(
seq.into_iter().map(|h| h.single_socket().into()).collect(),
)
})
}

View File

@@ -26,7 +26,7 @@ socket-factory = { workspace = true }
thiserror = { version = "1.0", default-features = false }
# This actually relies on many other features in Tokio, so this will probably
# fail to build outside the workspace. <https://github.com/firezone/firezone/pull/4328#discussion_r1540342142>
tokio = { workspace = true, features = ["macros", "signal", "process", "time"] }
tokio = { workspace = true, features = ["macros", "signal", "process", "time", "rt-multi-thread"] }
tokio-stream = "0.1.15"
tokio-util = { version = "0.7.11", features = ["codec"] }
tracing = { workspace = true }

View File

@@ -10,7 +10,7 @@ publish = false
proptest = ["dep:proptest"]
[dependencies]
hickory-proto = { workspace = true }
domain = "0.10.1"
pnet_packet = { version = "0.35" }
proptest = { version = "1", optional = true }
thiserror = "1"

View File

@@ -8,6 +8,7 @@ pub use pnet_packet::*;
#[cfg(all(test, feature = "proptest"))]
mod proptests;
use domain::base::Message;
use pnet_packet::{
icmp::{
destination_unreachable::IcmpCodes, echo_reply::MutableEchoReplyPacket,
@@ -1409,9 +1410,9 @@ impl<'a> IpPacket<'a> {
}
/// Unwrap this [`IpPacket`] as a DNS message, panicking in case it is not.
pub fn unwrap_as_dns(&self) -> hickory_proto::op::Message {
pub fn unwrap_as_dns(&self) -> Message<Vec<u8>> {
let udp = self.unwrap_as_udp();
let message = match hickory_proto::op::Message::from_vec(udp.payload()) {
let message = match Message::from_octets(udp.payload().to_vec()) {
Ok(message) => message,
Err(e) => {
panic!("Failed to parse UDP payload as DNS message: {e}");

View File

@@ -1,9 +1,12 @@
//! Factory module for making all kinds of packets.
use crate::{IpPacket, MutableIpPacket};
use hickory_proto::{
op::{Message, Query, ResponseCode},
rr::{Name, RData, Record, RecordType},
use domain::{
base::{
iana::{Class, Opcode, Rcode},
MessageBuilder, Name, Question, Record, Rtype, ToName, Ttl,
},
rdata::AllRecordData,
};
use pnet_packet::{
ip::IpNextHeaderProtocol,
@@ -244,24 +247,26 @@ where
}
pub fn dns_query(
domain: Name,
kind: RecordType,
domain: Name<Vec<u8>>,
kind: Rtype,
src: SocketAddr,
dst: SocketAddr,
id: u16,
) -> MutableIpPacket<'static> {
// Create the DNS query message
let mut msg = Message::new();
msg.set_message_type(hickory_proto::op::MessageType::Query);
msg.set_op_code(hickory_proto::op::OpCode::Query);
msg.set_recursion_desired(true);
msg.set_id(id);
let mut msg_builder = MessageBuilder::new_vec();
msg_builder.header_mut().set_opcode(Opcode::QUERY);
msg_builder.header_mut().set_rd(true);
msg_builder.header_mut().set_id(id);
// Create the query
let query = Query::query(domain, kind);
msg.add_query(query);
let mut question_builder = msg_builder.question();
question_builder
.push(Question::new_in(domain, kind))
.unwrap();
let payload = msg.to_vec().unwrap();
let payload = question_builder.finish();
udp_packet(src.ip(), dst.ip(), src.port(), dst.port(), payload)
}
@@ -269,60 +274,42 @@ pub fn dns_query(
/// Makes a DNS response to the given DNS query packet, using a resolver callback.
pub fn dns_ok_response<I>(
packet: IpPacket<'static>,
resolve: impl Fn(&Name) -> I,
resolve: impl Fn(&Name<Vec<u8>>) -> I,
) -> MutableIpPacket<'static>
where
I: Iterator<Item = IpAddr>,
{
let udp = packet.unwrap_as_udp();
let mut query = packet.unwrap_as_dns();
let query = packet.unwrap_as_dns();
let mut response = Message::new();
response.set_id(query.id());
response.set_message_type(hickory_proto::op::MessageType::Response);
let response = MessageBuilder::new_vec();
let mut answers = response.start_answer(&query, Rcode::NOERROR).unwrap();
for query in query.take_queries() {
response.add_query(query.clone());
for query in query.question() {
let query = query.unwrap();
let name = query.qname().to_name();
let records = resolve(query.name())
let records = resolve(&name)
.filter(|ip| {
#[allow(clippy::wildcard_enum_match_arm)]
match query.query_type() {
RecordType::A => ip.is_ipv4(),
RecordType::AAAA => ip.is_ipv6(),
match query.qtype() {
Rtype::A => ip.is_ipv4(),
Rtype::AAAA => ip.is_ipv6(),
_ => todo!(),
}
})
.map(|ip| match ip {
IpAddr::V4(v4) => RData::A(v4.into()),
IpAddr::V6(v6) => RData::AAAA(v6.into()),
IpAddr::V4(v4) => AllRecordData::<Vec<_>, Name<Vec<_>>>::A(v4.into()),
IpAddr::V6(v6) => AllRecordData::<Vec<_>, Name<Vec<_>>>::Aaaa(v6.into()),
})
.map(|rdata| Record::from_rdata(query.name().clone(), 86400_u32, rdata));
.map(|rdata| Record::new(name.clone(), Class::IN, Ttl::from_days(1), rdata));
response.add_answers(records);
for record in records {
answers.push(record).unwrap();
}
}
let payload = response.to_vec().unwrap();
udp_packet(
packet.destination(),
packet.source(),
udp.get_destination(),
udp.get_source(),
payload,
)
}
/// Makes a DNS response to the given DNS query packet, using the given error code.
pub fn dns_err_response(packet: IpPacket<'static>, code: ResponseCode) -> MutableIpPacket<'static> {
let udp = packet.unwrap_as_udp();
let query = packet.unwrap_as_dns();
debug_assert_ne!(code, ResponseCode::NoError);
let response = Message::error_msg(query.id(), query.op_code(), code);
let payload = response.to_vec().unwrap();
let payload = answers.finish();
udp_packet(
packet.destination(),

View File

@@ -4,12 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = { version = "0.1", optional = true }
hickory-proto = { workspace = true, optional = true }
quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" }
socket2 = { workspace = true }
tokio = { version = "1.39", features = ["net"] }
tracing = "0.1"
[features]
hickory = ["dep:hickory-proto", "dep:async-trait"]

View File

@@ -301,62 +301,3 @@ impl UdpSocket {
Ok(Some(src))
}
}
#[cfg(feature = "hickory")]
mod hickory {
use super::*;
use hickory_proto::{
udp::DnsUdpSocket as DnsUdpSocketTrait, udp::UdpSocket as UdpSocketTrait, TokioTime,
};
use tokio::net::UdpSocket as TokioUdpSocket;
#[async_trait::async_trait]
impl UdpSocketTrait for crate::UdpSocket {
/// setups up a "client" udp connection that will only receive packets from the associated address
async fn connect(addr: SocketAddr) -> io::Result<Self> {
let inner = <TokioUdpSocket as UdpSocketTrait>::connect(addr).await?;
let socket = Self::new(inner)?;
Ok(socket)
}
/// same as connect, but binds to the specified local address for sending address
async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
let inner =
<TokioUdpSocket as UdpSocketTrait>::connect_with_bind(addr, bind_addr).await?;
let socket = Self::new(inner)?;
Ok(socket)
}
/// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
async fn bind(addr: SocketAddr) -> io::Result<Self> {
let inner = <TokioUdpSocket as UdpSocketTrait>::bind(addr).await?;
let socket = Self::new(inner)?;
Ok(socket)
}
}
#[cfg(feature = "hickory")]
impl DnsUdpSocketTrait for crate::UdpSocket {
type Time = TokioTime;
fn poll_recv_from(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
<TokioUdpSocket as DnsUdpSocketTrait>::poll_recv_from(&self.inner, cx, buf)
}
fn poll_send_to(
&self,
cx: &mut Context<'_>,
buf: &[u8],
target: SocketAddr,
) -> Poll<io::Result<usize>> {
<TokioUdpSocket as DnsUdpSocketTrait>::poll_send_to(&self.inner, cx, buf, target)
}
}
}