refactor(connlib): move DNS resolution into tunnel (#3652)

Previously, the mapping between sentinel DNS addresses and upstream DNS
servers was not stored within the tunnel, so we had to perform the
resolution further up. This has changed: the tunnel itself now knows
about this mapping. Thus, we can easily move the actual DNS resolution
into the tunnel as well, thereby reducing the API surface of `Tunnel`
because we no longer need the `write_dns_lookup_response` function.

This is crucial because it was the last place where `Tunnel` was being
cloned in #3391. With this out of the way, we can remove all `Arc`s
and locks from `Tunnel` as part of #3391.
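
For reference, the resulting flow inside `ClientState` is roughly the sketch below. This is a condensed illustration, not connlib's real code: `Resolver`, `Packet`, the `lookup` body, and the capacity of 100 are hypothetical stand-ins for `TokioAsyncResolver`, `device_channel::Packet`, and `DNS_QUERIES_QUEUE_SIZE`; only the `futures_bounded::FuturesTupleSet` calls mirror the ones in the diff.

```rust
use std::collections::HashMap;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_bounded::FuturesTupleSet;

/// Hypothetical stand-in for `TokioAsyncResolver`: one resolver per sentinel IP.
#[derive(Clone)]
struct Resolver {
    upstream: IpAddr,
}

impl Resolver {
    async fn lookup(&self, _name: String) -> Vec<u8> {
        // ... forward the query to `self.upstream` and await the answer ...
        Vec::new()
    }
}

/// Hypothetical stand-ins for connlib's `DnsQuery<'static>` and `device_channel::Packet`.
struct DnsQuery {
    name: String,
    sentinel: IpAddr, // the sentinel address the client sent the query to
}
struct Packet(Vec<u8>);

struct ClientState {
    // Sentinel IP -> upstream resolver, rebuilt by `set_dns_mapping`.
    dns_resolvers: HashMap<IpAddr, Resolver>,
    // In-flight lookups, each tagged with the query that triggered it.
    forwarded_dns_queries: FuturesTupleSet<Vec<u8>, DnsQuery>,
}

impl ClientState {
    fn new() -> Self {
        Self {
            dns_resolvers: HashMap::new(),
            // 60s per-lookup timeout as in the diff; 100 is a made-up capacity.
            forwarded_dns_queries: FuturesTupleSet::new(Duration::from_secs(60), 100),
        }
    }

    /// Resolution now starts *inside* the tunnel: pick the resolver that
    /// belongs to the sentinel the query was addressed to and push the
    /// lookup future into the bounded set.
    fn add_pending_dns_query(&mut self, query: DnsQuery) {
        let Some(resolver) = self.dns_resolvers.get(&query.sentinel).cloned() else {
            return; // unknown upstream: drop the query
        };

        let name = query.name.clone();
        if self
            .forwarded_dns_queries
            .try_push(async move { resolver.lookup(name).await }, query)
            .is_err()
        {
            eprintln!("too many in-flight DNS queries");
        }
    }

    /// Completed lookups surface in the event loop together with their
    /// original query; the tunnel turns them into a packet for the TUN device.
    fn poll_dns(&mut self, cx: &mut Context<'_>) -> Poll<Packet> {
        match self.forwarded_dns_queries.poll_unpin(cx) {
            Poll::Ready((Ok(response), _query)) => Poll::Ready(Packet(response)),
            Poll::Ready((Err(_timed_out), _query)) => Poll::Pending, // real code logs and keeps polling
            Poll::Pending => Poll::Pending,
        }
    }
}
```

Because each lookup future is tagged with the query that spawned it, the tunnel can build the response packet itself (surfaced as `Event::SendPacket` in the real code) instead of handing the result back out through `write_dns_lookup_response`.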
Thomas Eizinger
2024-02-16 11:29:31 +11:00
committed by GitHub
parent 75e447f9d4
commit 19bcaa9539
9 changed files with 96 additions and 145 deletions

rust/Cargo.lock generated

@@ -1218,7 +1218,6 @@ dependencies = [
  "chrono",
  "connlib-shared",
  "firezone-tunnel",
- "hickory-resolver",
  "ip_network",
  "parking_lot",
  "reqwest",


@@ -27,7 +27,6 @@ time = { version = "0.3.34", features = ["formatting"] }
 reqwest = { version = "0.11.22", default-features = false, features = ["stream", "rustls-tls"] }
 tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
 async-compression = { version = "0.4.6", features = ["tokio", "gzip"] }
-hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
 parking_lot = "0.12"
 bimap = "0.6"
 ip_network = { version = "0.4", default-features = false }


@@ -6,7 +6,6 @@ use connlib_shared::control::Reason;
 use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
 use connlib_shared::IpProvider;
 use ip_network::IpNetwork;
-use std::collections::HashMap;
 use std::net::IpAddr;
 use std::path::PathBuf;
 use std::str::FromStr;
@@ -25,8 +24,6 @@ use connlib_shared::{
 };
 use firezone_tunnel::{ClientState, Request, Tunnel};
-use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
-use hickory_resolver::TokioAsyncResolver;
 use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE};
 use tokio::io::BufReader;
 use tokio::sync::Mutex;
@@ -41,11 +38,6 @@ pub struct ControlPlane<CB: Callbacks> {
     pub tunnel: Arc<Tunnel<CB, ClientState>>,
     pub phoenix_channel: PhoenixSenderWithTopic,
     pub tunnel_init: Mutex<bool>,
-    // It's a Mutex<Option<_>> because we need the init message to initialize the resolver
-    // also, in platforms with split DNS and no configured upstream dns this will be None.
-    //
-    // We could still initialize the resolver with no nameservers in those platforms...
-    pub fallback_resolver: parking_lot::Mutex<HashMap<IpAddr, TokioAsyncResolver>>,
 }

 fn effective_dns_servers(
@@ -76,22 +68,6 @@ fn effective_dns_servers(
         .collect()
 }

-fn create_resolvers(
-    sentinel_mapping: BiMap<IpAddr, DnsServer>,
-) -> HashMap<IpAddr, TokioAsyncResolver> {
-    sentinel_mapping
-        .iter()
-        .map(|(sentinel, srv)| {
-            let mut resolver_config = ResolverConfig::new();
-            resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
-
-            (
-                *sentinel,
-                TokioAsyncResolver::tokio(resolver_config, Default::default()),
-            )
-        })
-        .collect()
-}

 fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap<IpAddr, DnsServer> {
     let mut ip_provider = IpProvider::new(
         DNS_SENTINELS_V4.parse().unwrap(),
@@ -148,9 +124,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
             for resource_description in resources {
                 self.add_resource(resource_description);
             }
-
-            // Note: watch out here we're holding 2 mutexes
-            *self.fallback_resolver.lock() = create_resolvers(sentinel_mapping);
         } else {
             tracing::info!("Firezone reinitializated");
         }
@@ -395,25 +368,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
                     // TODO: Clean up connection in `ClientState` here?
                 }
             }
-            Ok(firezone_tunnel::Event::DnsQuery(query)) => {
-                // Until we handle it better on a gateway-like eventloop, making sure not to block the loop
-                let Some(resolver) = self
-                    .fallback_resolver
-                    .lock()
-                    .get(&query.query.destination())
-                    .cloned()
-                else {
-                    return;
-                };
-
-                let tunnel = self.tunnel.clone();
-                tokio::spawn(async move {
-                    let response = resolver.lookup(&query.name, query.record_type).await;
-                    if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) {
-                        tracing::debug!(err = ?err, name = query.name, record_type = ?query.record_type, "DNS lookup failed: {err:#}");
-                    }
-                });
-            }
             Ok(firezone_tunnel::Event::RefreshResources { connections }) => {
                 let mut control_signaler = self.phoenix_channel.clone();
                 tokio::spawn(async move {
@@ -428,6 +382,9 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
                     }
                 });
             }
+            Ok(firezone_tunnel::Event::SendPacket(_)) => {
+                unimplemented!("Handled internally");
+            }
             Err(e) => {
                 tracing::error!("Tunnel failed: {e}");
             }


@@ -12,7 +12,6 @@ use messages::IngressMessages;
 use messages::Messages;
 use messages::ReplyMessages;
 use secrecy::{Secret, SecretString};
-use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::time::{Interval, MissedTickBehavior};
@@ -178,7 +177,6 @@
         tunnel: Arc::new(tunnel),
         phoenix_channel: connection.sender_with_topic("client".to_owned()),
         tunnel_init: Mutex::new(false),
-        fallback_resolver: parking_lot::Mutex::new(HashMap::new()),
     };

     tokio::spawn({
tokio::spawn({


@@ -26,7 +26,7 @@ boringtun = { workspace = true }
 chrono = { workspace = true }
 pnet_packet = { version = "0.34" }
 futures-bounded = { workspace = true }
-hickory-resolver = { workspace = true }
+hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
 arc-swap = "1.6.0"
 bimap = "0.6"
 resolv-conf = "0.7.0"


@@ -1,59 +0,0 @@
-use core::fmt;
-use std::{
-    collections::VecDeque,
-    task::{Context, Poll, Waker},
-};
-
-// Simple bounded queue for one-time events
-#[derive(Debug, Clone)]
-pub(crate) struct BoundedQueue<T> {
-    queue: VecDeque<T>,
-    limit: usize,
-    waker: Option<Waker>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub(crate) struct Full;
-
-impl fmt::Display for Full {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "Queue is full")
-    }
-}
-
-impl<T> BoundedQueue<T> {
-    pub(crate) fn with_capacity(cap: usize) -> BoundedQueue<T> {
-        BoundedQueue {
-            queue: VecDeque::with_capacity(cap),
-            limit: cap,
-            waker: None,
-        }
-    }
-
-    pub(crate) fn poll(&mut self, cx: &Context) -> Poll<T> {
-        if let Some(front) = self.queue.pop_front() {
-            return Poll::Ready(front);
-        }
-
-        self.waker = Some(cx.waker().clone());
-        Poll::Pending
-    }
-
-    fn at_capacity(&self) -> bool {
-        self.queue.len() == self.limit
-    }
-
-    pub(crate) fn push_back(&mut self, x: T) -> Result<(), Full> {
-        if self.at_capacity() {
-            return Err(Full);
-        }
-
-        self.queue.push_back(x);
-
-        if let Some(ref waker) = self.waker {
-            waker.wake_by_ref();
-        }
-
-        Ok(())
-    }
-}


@@ -1,4 +1,3 @@
-use crate::bounded_queue::BoundedQueue;
 use crate::device_channel::{Device, Packet};
 use crate::ip_packet::{IpPacket, MutableIpPacket};
 use crate::peer::PacketTransformClient;
@@ -17,12 +16,13 @@ use connlib_shared::{Callbacks, Dname, IpProvider};
 use domain::base::Rtype;
 use futures::channel::mpsc::Receiver;
 use futures::stream;
-use futures_bounded::{FuturesMap, PushError, StreamMap};
-use hickory_resolver::lookup::Lookup;
+use futures_bounded::{FuturesMap, FuturesTupleSet, PushError, StreamMap};
+use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
+use hickory_resolver::TokioAsyncResolver;
 use ip_network::IpNetwork;
 use ip_network_table::IpNetworkTable;
 use itertools::Itertools;
 use rand_core::OsRng;
 use std::collections::hash_map::Entry;
 use std::collections::{HashMap, HashSet, VecDeque};
@@ -110,24 +110,6 @@
         Ok(())
     }
-
-    /// Writes the response to a DNS lookup
-    #[tracing::instrument(level = "trace", skip(self))]
-    pub fn write_dns_lookup_response(
-        &self,
-        response: hickory_resolver::error::ResolveResult<Lookup>,
-        query: IpPacket<'static>,
-    ) -> connlib_shared::Result<()> {
-        if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? {
-            let Some(ref device) = *self.device.load() else {
-                return Ok(());
-            };
-
-            device.write(pkt)?;
-        }
-
-        Ok(())
-    }

     /// Sets the interface configuration and starts background tasks.
     #[tracing::instrument(level = "trace", skip(self))]
     pub fn set_interface(
@@ -157,7 +139,7 @@
             return Err(errs.pop().unwrap());
         }

-        self.role_state.lock().dns_mapping = dns_mapping;
+        self.role_state.lock().set_dns_mapping(dns_mapping);

         let res_v4 = self.add_route(IPV4_RESOURCES.parse().unwrap());
         let res_v6 = self.add_route(IPV6_RESOURCES.parse().unwrap());
@@ -228,13 +210,17 @@ pub struct ClientState {
     #[allow(clippy::type_complexity)]
     pub peers_by_ip: IpNetworkTable<ConnectedPeer<GatewayId, PacketTransformClient>>,
-    forwarded_dns_queries: BoundedQueue<DnsQuery<'static>>,
+    forwarded_dns_queries: FuturesTupleSet<
+        Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
+        DnsQuery<'static>,
+    >,
     pub ip_provider: IpProvider,
     refresh_dns_timer: Interval,
-    pub dns_mapping: BiMap<IpAddr, DnsServer>,
+    dns_mapping: BiMap<IpAddr, DnsServer>,
+    dns_resolvers: HashMap<IpAddr, TokioAsyncResolver>,
 }

 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -579,6 +565,15 @@ impl ClientState {
         }
     }

+    pub fn set_dns_mapping(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
+        self.dns_mapping = mapping.clone();
+        self.dns_resolvers = create_resolvers(mapping);
+    }
+
+    pub fn dns_mapping(&self) -> BiMap<IpAddr, DnsServer> {
+        self.dns_mapping.clone()
+    }
+
     fn is_awaiting_connection_to_cidr(&self, destination: IpAddr) -> bool {
         let Some(resource) = self.get_cidr_resource_by_destination(destination) else {
             return false;
@@ -633,16 +628,48 @@
     }

     fn add_pending_dns_query(&mut self, query: DnsQuery) {
+        let upstream = query.query.destination();
+        let Some(resolver) = self.dns_resolvers.get(&upstream).cloned() else {
+            tracing::warn!(%upstream, "Dropping DNS query because of unknown upstream DNS server");
+            return;
+        };
+
+        let query = query.into_owned();
+
         if self
             .forwarded_dns_queries
-            .push_back(query.into_owned())
+            .try_push(
+                {
+                    let name = query.name.clone();
+                    let record_type = query.record_type;
+
+                    async move { resolver.lookup(&name, record_type).await }
+                },
+                query,
+            )
             .is_err()
         {
-            tracing::warn!("Too many DNS queries, dropping new ones");
+            tracing::warn!("Too many DNS queries, dropping existing one");
         }
     }
 }

+fn create_resolvers(
+    sentinel_mapping: BiMap<IpAddr, DnsServer>,
+) -> HashMap<IpAddr, TokioAsyncResolver> {
+    sentinel_mapping
+        .into_iter()
+        .map(|(sentinel, srv)| {
+            let mut resolver_config = ResolverConfig::new();
+            resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
+
+            (
+                sentinel,
+                TokioAsyncResolver::tokio(resolver_config, Default::default()),
+            )
+        })
+        .collect()
+}
+
 impl Default for ClientState {
     fn default() -> Self {
         // With this single timer this might mean that some DNS are refreshed too often
@@ -665,7 +692,10 @@
             gateway_public_keys: Default::default(),
             resources_gateways: Default::default(),
-            forwarded_dns_queries: BoundedQueue::with_capacity(DNS_QUERIES_QUEUE_SIZE),
+            forwarded_dns_queries: FuturesTupleSet::new(
+                Duration::from_secs(60),
+                DNS_QUERIES_QUEUE_SIZE,
+            ),
             gateway_preshared_keys: Default::default(),
             // TODO: decide ip ranges
             ip_provider: IpProvider::new(
@@ -680,6 +710,7 @@
             deferred_dns_queries: Default::default(),
             refresh_dns_timer: interval,
             dns_mapping: Default::default(),
+            dns_resolvers: Default::default(),
         }
     }
 }
@@ -780,7 +811,25 @@ impl RoleState for ClientState {
                 return Poll::Ready(Event::RefreshResources { connections });
             }

-            return self.forwarded_dns_queries.poll(cx).map(Event::DnsQuery);
+            match self.forwarded_dns_queries.poll_unpin(cx) {
+                Poll::Ready((Ok(response), query)) => {
+                    match dns::build_response_from_resolve_result(query.query, response) {
+                        Ok(Some(packet)) => return Poll::Ready(Event::SendPacket(packet)),
+                        Ok(None) => continue,
+                        Err(e) => {
+                            tracing::warn!("Failed to build DNS response from lookup result: {e}");
+                            continue;
+                        }
+                    }
+                }
+                Poll::Ready((Err(resolve_timeout), query)) => {
+                    tracing::warn!(name = %query.name, server = %query.query.destination(), "DNS query timed out: {resolve_timeout}");
+                    continue;
+                }
+                Poll::Pending => {}
+            }
+
+            return Poll::Pending;
         }
     }


@@ -164,8 +164,7 @@
     let (peer_sender, peer_receiver) = tokio::sync::mpsc::channel(PEER_QUEUE_SIZE);

-    peer.transform
-        .set_dns(self.role_state.lock().dns_mapping.clone());
+    peer.transform.set_dns(self.role_state.lock().dns_mapping());

     start_handlers(
         Arc::clone(self),


@@ -30,7 +30,7 @@ use webrtc::{
     interceptor::registry::Registry,
 };

-use arc_swap::ArcSwapOption;
+use arc_swap::{access::Access, ArcSwapOption};
 use futures_util::task::AtomicWaker;
 use std::task::{ready, Context, Poll};
 use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
@@ -56,7 +56,6 @@ use connlib_shared::messages::{ClientId, SecretKey};
 use device_channel::Device;
 use index::IndexLfsr;

-mod bounded_queue;
 mod client;
 mod control_protocol;
 mod device_channel;
@@ -365,8 +364,18 @@
             }
         }

-        if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {
-            return Poll::Ready(event);
+        match self.role_state.lock().poll_next_event(cx) {
+            Poll::Ready(Event::SendPacket(packet)) => {
+                let Some(device) = self.device.load().clone() else {
+                    continue;
+                };
+
+                let _ = device.write(packet);
+                continue;
+            }
+            Poll::Ready(other) => return Poll::Ready(other),
+            _ => (),
         }

         return Poll::Pending;
@@ -448,7 +457,7 @@ pub enum Event<TId> {
     RefreshResources {
         connections: Vec<ReuseConnection>,
     },
-    DnsQuery(DnsQuery<'static>),
+    SendPacket(device_channel::Packet<'static>),
 }

 impl<CB, TRoleState> Tunnel<CB, TRoleState>