mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): move DNS resolution into tunnel (#3652)
Previously, this mapping was not stored within the tunnel so we had to perform the resolution further up. This has changed and the tunnel itself now knows about this mapping. Thus, we can easily move the actual DNS resolution also into the tunnel, thereby reducing the API surface of `Tunnel` because we don't need the `write_dns_lookup_response` function. This is crucial because it is the last place where `Tunnel` is being cloned in #3391. With this sorted out the way, we can remove all `Arc`s and locks from `Tunnel` as part of #3391.
This commit is contained in:
1
rust/Cargo.lock
generated
1
rust/Cargo.lock
generated
@@ -1218,7 +1218,6 @@ dependencies = [
|
||||
"chrono",
|
||||
"connlib-shared",
|
||||
"firezone-tunnel",
|
||||
"hickory-resolver",
|
||||
"ip_network",
|
||||
"parking_lot",
|
||||
"reqwest",
|
||||
|
||||
@@ -27,7 +27,6 @@ time = { version = "0.3.34", features = ["formatting"] }
|
||||
reqwest = { version = "0.11.22", default-features = false, features = ["stream", "rustls-tls"] }
|
||||
tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
|
||||
async-compression = { version = "0.4.6", features = ["tokio", "gzip"] }
|
||||
hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
|
||||
parking_lot = "0.12"
|
||||
bimap = "0.6"
|
||||
ip_network = { version = "0.4", default-features = false }
|
||||
|
||||
@@ -6,7 +6,6 @@ use connlib_shared::control::Reason;
|
||||
use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
|
||||
use connlib_shared::IpProvider;
|
||||
use ip_network::IpNetwork;
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
@@ -25,8 +24,6 @@ use connlib_shared::{
|
||||
};
|
||||
|
||||
use firezone_tunnel::{ClientState, Request, Tunnel};
|
||||
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
|
||||
use hickory_resolver::TokioAsyncResolver;
|
||||
use reqwest::header::{CONTENT_ENCODING, CONTENT_TYPE};
|
||||
use tokio::io::BufReader;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -41,11 +38,6 @@ pub struct ControlPlane<CB: Callbacks> {
|
||||
pub tunnel: Arc<Tunnel<CB, ClientState>>,
|
||||
pub phoenix_channel: PhoenixSenderWithTopic,
|
||||
pub tunnel_init: Mutex<bool>,
|
||||
// It's a Mutex<Option<_>> because we need the init message to initialize the resolver
|
||||
// also, in platforms with split DNS and no configured upstream dns this will be None.
|
||||
//
|
||||
// We could still initialize the resolver with no nameservers in those platforms...
|
||||
pub fallback_resolver: parking_lot::Mutex<HashMap<IpAddr, TokioAsyncResolver>>,
|
||||
}
|
||||
|
||||
fn effective_dns_servers(
|
||||
@@ -76,22 +68,6 @@ fn effective_dns_servers(
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn create_resolvers(
|
||||
sentinel_mapping: BiMap<IpAddr, DnsServer>,
|
||||
) -> HashMap<IpAddr, TokioAsyncResolver> {
|
||||
sentinel_mapping
|
||||
.iter()
|
||||
.map(|(sentinel, srv)| {
|
||||
let mut resolver_config = ResolverConfig::new();
|
||||
resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
|
||||
(
|
||||
*sentinel,
|
||||
TokioAsyncResolver::tokio(resolver_config, Default::default()),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap<IpAddr, DnsServer> {
|
||||
let mut ip_provider = IpProvider::new(
|
||||
DNS_SENTINELS_V4.parse().unwrap(),
|
||||
@@ -148,9 +124,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
for resource_description in resources {
|
||||
self.add_resource(resource_description);
|
||||
}
|
||||
|
||||
// Note: watch out here we're holding 2 mutexes
|
||||
*self.fallback_resolver.lock() = create_resolvers(sentinel_mapping);
|
||||
} else {
|
||||
tracing::info!("Firezone reinitializated");
|
||||
}
|
||||
@@ -395,25 +368,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
// TODO: Clean up connection in `ClientState` here?
|
||||
}
|
||||
}
|
||||
Ok(firezone_tunnel::Event::DnsQuery(query)) => {
|
||||
// Until we handle it better on a gateway-like eventloop, making sure not to block the loop
|
||||
let Some(resolver) = self
|
||||
.fallback_resolver
|
||||
.lock()
|
||||
.get(&query.query.destination())
|
||||
.cloned()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let tunnel = self.tunnel.clone();
|
||||
tokio::spawn(async move {
|
||||
let response = resolver.lookup(&query.name, query.record_type).await;
|
||||
if let Err(err) = tunnel.write_dns_lookup_response(response, query.query) {
|
||||
tracing::debug!(err = ?err, name = query.name, record_type = ?query.record_type, "DNS lookup failed: {err:#}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(firezone_tunnel::Event::RefreshResources { connections }) => {
|
||||
let mut control_signaler = self.phoenix_channel.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -428,6 +382,9 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(firezone_tunnel::Event::SendPacket(_)) => {
|
||||
unimplemented!("Handled internally");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Tunnel failed: {e}");
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ use messages::IngressMessages;
|
||||
use messages::Messages;
|
||||
use messages::ReplyMessages;
|
||||
use secrecy::{Secret, SecretString};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time::{Interval, MissedTickBehavior};
|
||||
@@ -178,7 +177,6 @@ where
|
||||
tunnel: Arc::new(tunnel),
|
||||
phoenix_channel: connection.sender_with_topic("client".to_owned()),
|
||||
tunnel_init: Mutex::new(false),
|
||||
fallback_resolver: parking_lot::Mutex::new(HashMap::new()),
|
||||
};
|
||||
|
||||
tokio::spawn({
|
||||
|
||||
@@ -26,7 +26,7 @@ boringtun = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
pnet_packet = { version = "0.34" }
|
||||
futures-bounded = { workspace = true }
|
||||
hickory-resolver = { workspace = true }
|
||||
hickory-resolver = { workspace = true, features = ["tokio-runtime"] }
|
||||
arc-swap = "1.6.0"
|
||||
bimap = "0.6"
|
||||
resolv-conf = "0.7.0"
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
use core::fmt;
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
|
||||
// Simple bounded queue for one-time events
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BoundedQueue<T> {
|
||||
queue: VecDeque<T>,
|
||||
limit: usize,
|
||||
waker: Option<Waker>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct Full;
|
||||
|
||||
impl fmt::Display for Full {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> BoundedQueue<T> {
|
||||
pub(crate) fn with_capacity(cap: usize) -> BoundedQueue<T> {
|
||||
BoundedQueue {
|
||||
queue: VecDeque::with_capacity(cap),
|
||||
limit: cap,
|
||||
waker: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll(&mut self, cx: &Context) -> Poll<T> {
|
||||
if let Some(front) = self.queue.pop_front() {
|
||||
return Poll::Ready(front);
|
||||
}
|
||||
|
||||
self.waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn at_capacity(&self) -> bool {
|
||||
self.queue.len() == self.limit
|
||||
}
|
||||
|
||||
pub(crate) fn push_back(&mut self, x: T) -> Result<(), Full> {
|
||||
if self.at_capacity() {
|
||||
return Err(Full);
|
||||
}
|
||||
|
||||
self.queue.push_back(x);
|
||||
|
||||
if let Some(ref waker) = self.waker {
|
||||
waker.wake_by_ref();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::bounded_queue::BoundedQueue;
|
||||
use crate::device_channel::{Device, Packet};
|
||||
use crate::ip_packet::{IpPacket, MutableIpPacket};
|
||||
use crate::peer::PacketTransformClient;
|
||||
@@ -17,12 +16,13 @@ use connlib_shared::{Callbacks, Dname, IpProvider};
|
||||
use domain::base::Rtype;
|
||||
use futures::channel::mpsc::Receiver;
|
||||
use futures::stream;
|
||||
use futures_bounded::{FuturesMap, PushError, StreamMap};
|
||||
use hickory_resolver::lookup::Lookup;
|
||||
use futures_bounded::{FuturesMap, FuturesTupleSet, PushError, StreamMap};
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
use itertools::Itertools;
|
||||
|
||||
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig};
|
||||
use hickory_resolver::TokioAsyncResolver;
|
||||
use rand_core::OsRng;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
@@ -110,24 +110,6 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes the response to a DNS lookup
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn write_dns_lookup_response(
|
||||
&self,
|
||||
response: hickory_resolver::error::ResolveResult<Lookup>,
|
||||
query: IpPacket<'static>,
|
||||
) -> connlib_shared::Result<()> {
|
||||
if let Some(pkt) = dns::build_response_from_resolve_result(query, response)? {
|
||||
let Some(ref device) = *self.device.load() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
device.write(pkt)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the interface configuration and starts background tasks.
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn set_interface(
|
||||
@@ -157,7 +139,7 @@ where
|
||||
return Err(errs.pop().unwrap());
|
||||
}
|
||||
|
||||
self.role_state.lock().dns_mapping = dns_mapping;
|
||||
self.role_state.lock().set_dns_mapping(dns_mapping);
|
||||
|
||||
let res_v4 = self.add_route(IPV4_RESOURCES.parse().unwrap());
|
||||
let res_v6 = self.add_route(IPV6_RESOURCES.parse().unwrap());
|
||||
@@ -228,13 +210,17 @@ pub struct ClientState {
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub peers_by_ip: IpNetworkTable<ConnectedPeer<GatewayId, PacketTransformClient>>,
|
||||
|
||||
forwarded_dns_queries: BoundedQueue<DnsQuery<'static>>,
|
||||
forwarded_dns_queries: FuturesTupleSet<
|
||||
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
|
||||
DnsQuery<'static>,
|
||||
>,
|
||||
|
||||
pub ip_provider: IpProvider,
|
||||
|
||||
refresh_dns_timer: Interval,
|
||||
|
||||
pub dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
dns_resolvers: HashMap<IpAddr, TokioAsyncResolver>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -579,6 +565,15 @@ impl ClientState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dns_mapping(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
|
||||
self.dns_mapping = mapping.clone();
|
||||
self.dns_resolvers = create_resolvers(mapping);
|
||||
}
|
||||
|
||||
pub fn dns_mapping(&self) -> BiMap<IpAddr, DnsServer> {
|
||||
self.dns_mapping.clone()
|
||||
}
|
||||
|
||||
fn is_awaiting_connection_to_cidr(&self, destination: IpAddr) -> bool {
|
||||
let Some(resource) = self.get_cidr_resource_by_destination(destination) else {
|
||||
return false;
|
||||
@@ -633,16 +628,48 @@ impl ClientState {
|
||||
}
|
||||
|
||||
fn add_pending_dns_query(&mut self, query: DnsQuery) {
|
||||
let upstream = query.query.destination();
|
||||
let Some(resolver) = self.dns_resolvers.get(&upstream).cloned() else {
|
||||
tracing::warn!(%upstream, "Dropping DNS query because of unknown upstream DNS server");
|
||||
return;
|
||||
};
|
||||
|
||||
let query = query.into_owned();
|
||||
|
||||
if self
|
||||
.forwarded_dns_queries
|
||||
.push_back(query.into_owned())
|
||||
.try_push(
|
||||
{
|
||||
let name = query.name.clone();
|
||||
let record_type = query.record_type;
|
||||
|
||||
async move { resolver.lookup(&name, record_type).await }
|
||||
},
|
||||
query,
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
tracing::warn!("Too many DNS queries, dropping new ones");
|
||||
tracing::warn!("Too many DNS queries, dropping existing one");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_resolvers(
|
||||
sentinel_mapping: BiMap<IpAddr, DnsServer>,
|
||||
) -> HashMap<IpAddr, TokioAsyncResolver> {
|
||||
sentinel_mapping
|
||||
.into_iter()
|
||||
.map(|(sentinel, srv)| {
|
||||
let mut resolver_config = ResolverConfig::new();
|
||||
resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp));
|
||||
(
|
||||
sentinel,
|
||||
TokioAsyncResolver::tokio(resolver_config, Default::default()),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl Default for ClientState {
|
||||
fn default() -> Self {
|
||||
// With this single timer this might mean that some DNS are refreshed too often
|
||||
@@ -665,7 +692,10 @@ impl Default for ClientState {
|
||||
|
||||
gateway_public_keys: Default::default(),
|
||||
resources_gateways: Default::default(),
|
||||
forwarded_dns_queries: BoundedQueue::with_capacity(DNS_QUERIES_QUEUE_SIZE),
|
||||
forwarded_dns_queries: FuturesTupleSet::new(
|
||||
Duration::from_secs(60),
|
||||
DNS_QUERIES_QUEUE_SIZE,
|
||||
),
|
||||
gateway_preshared_keys: Default::default(),
|
||||
// TODO: decide ip ranges
|
||||
ip_provider: IpProvider::new(
|
||||
@@ -680,6 +710,7 @@ impl Default for ClientState {
|
||||
deferred_dns_queries: Default::default(),
|
||||
refresh_dns_timer: interval,
|
||||
dns_mapping: Default::default(),
|
||||
dns_resolvers: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -780,7 +811,25 @@ impl RoleState for ClientState {
|
||||
return Poll::Ready(Event::RefreshResources { connections });
|
||||
}
|
||||
|
||||
return self.forwarded_dns_queries.poll(cx).map(Event::DnsQuery);
|
||||
match self.forwarded_dns_queries.poll_unpin(cx) {
|
||||
Poll::Ready((Ok(response), query)) => {
|
||||
match dns::build_response_from_resolve_result(query.query, response) {
|
||||
Ok(Some(packet)) => return Poll::Ready(Event::SendPacket(packet)),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to build DNS response from lookup result: {e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready((Err(resolve_timeout), query)) => {
|
||||
tracing::warn!(name = %query.name, server = %query.query.destination(), "DNS query timed out: {resolve_timeout}");
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -164,8 +164,7 @@ where
|
||||
|
||||
let (peer_sender, peer_receiver) = tokio::sync::mpsc::channel(PEER_QUEUE_SIZE);
|
||||
|
||||
peer.transform
|
||||
.set_dns(self.role_state.lock().dns_mapping.clone());
|
||||
peer.transform.set_dns(self.role_state.lock().dns_mapping());
|
||||
|
||||
start_handlers(
|
||||
Arc::clone(self),
|
||||
|
||||
@@ -30,7 +30,7 @@ use webrtc::{
|
||||
interceptor::registry::Registry,
|
||||
};
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use arc_swap::{access::Access, ArcSwapOption};
|
||||
use futures_util::task::AtomicWaker;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::{collections::HashMap, fmt, net::IpAddr, sync::Arc, time::Duration};
|
||||
@@ -56,7 +56,6 @@ use connlib_shared::messages::{ClientId, SecretKey};
|
||||
use device_channel::Device;
|
||||
use index::IndexLfsr;
|
||||
|
||||
mod bounded_queue;
|
||||
mod client;
|
||||
mod control_protocol;
|
||||
mod device_channel;
|
||||
@@ -365,8 +364,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
if let Poll::Ready(event) = self.role_state.lock().poll_next_event(cx) {
|
||||
return Poll::Ready(event);
|
||||
match self.role_state.lock().poll_next_event(cx) {
|
||||
Poll::Ready(Event::SendPacket(packet)) => {
|
||||
let Some(device) = self.device.load().clone() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let _ = device.write(packet);
|
||||
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(other) => return Poll::Ready(other),
|
||||
_ => (),
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
@@ -448,7 +457,7 @@ pub enum Event<TId> {
|
||||
RefreshResources {
|
||||
connections: Vec<ReuseConnection>,
|
||||
},
|
||||
DnsQuery(DnsQuery<'static>),
|
||||
SendPacket(device_channel::Packet<'static>),
|
||||
}
|
||||
|
||||
impl<CB, TRoleState> Tunnel<CB, TRoleState>
|
||||
|
||||
Reference in New Issue
Block a user