fix(connlib): mangle packet for upstream dns as resource (#3134)

Fixes #3027 

Left a few TODO, will solve it when doing #3123 

Draft because we're still testing but it's almost ready
This commit is contained in:
Gabi
2024-01-09 18:08:07 -03:00
committed by GitHub
parent 33133d7448
commit 2af8d6096c
5 changed files with 120 additions and 14 deletions

View File

@@ -100,8 +100,6 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
return Err(e);
} else {
*init = true;
*self.fallback_resolver.lock() =
create_resolver(interface.upstream_dns, self.tunnel.callbacks());
tracing::info!("Firezone Started!");
}
} else {
@@ -109,6 +107,9 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
}
self.tunnel.set_upstream_dns(&interface.upstream_dns);
*self.fallback_resolver.lock() =
create_resolver(interface.upstream_dns, self.tunnel.callbacks());
for resource_description in resources {
self.add_resource(resource_description);
}

View File

@@ -9,8 +9,8 @@ use crate::{
use boringtun::x25519::{PublicKey, StaticSecret};
use connlib_shared::error::{ConnlibError as Error, ConnlibError};
use connlib_shared::messages::{
GatewayId, Interface as InterfaceConfig, Key, ResourceDescription, ResourceDescriptionCidr,
ResourceDescriptionDns, ResourceId, ReuseConnection, SecretKey,
DnsServer, GatewayId, Interface as InterfaceConfig, Key, ResourceDescription,
ResourceDescriptionCidr, ResourceDescriptionDns, ResourceId, ReuseConnection, SecretKey,
};
use connlib_shared::{Callbacks, Dname, DNS_SENTINEL};
use domain::base::Rtype;
@@ -25,7 +25,7 @@ use itertools::Itertools;
use rand_core::OsRng;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
@@ -109,6 +109,30 @@ where
Ok(())
}
/// Sets the dns server.
pub fn set_upstream_dns(&self, upstream_dns: &[DnsServer]) {
let upstream_dns: HashSet<_> = upstream_dns
.iter()
.map(|dns| {
let DnsServer::IpPort(dns) = dns;
dns.address
})
.collect();
// TODO: assuming single dns
let Some(upstream_dns_srv) = upstream_dns.iter().next().cloned() else {
return;
};
self.role_state.lock().upstream_dns = upstream_dns;
self.role_state
.lock()
.peers_by_ip
.iter()
.for_each(|p| p.1.inner.transform.set_dns(upstream_dns_srv.ip()));
}
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub fn set_interface(&self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
@@ -199,6 +223,8 @@ pub struct ClientState {
pub ip_provider: IpProvider,
refresh_dns_timer: Interval,
pub upstream_dns: HashSet<SocketAddr>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -217,7 +243,7 @@ impl ClientState {
pub(crate) fn handle_dns<'a>(
&mut self,
packet: MutableIpPacket<'a>,
) -> Result<Option<Packet<'a>>, MutableIpPacket<'a>> {
) -> Result<Option<Packet<'a>>, (MutableIpPacket<'a>, IpAddr)> {
match dns::parse(
&self.dns_resources,
&self.dns_resources_internal_ips,
@@ -225,6 +251,19 @@ impl ClientState {
) {
Some(dns::ResolveStrategy::LocalResponse(query)) => Ok(Some(query)),
Some(dns::ResolveStrategy::ForwardQuery(query)) => {
// There's an edge case here, where the resolver's ip has been resolved before as
// a dns resource... we will ignore that weird case for now.
// Assuming a single upstream dns until #3123 lands
if let Some(upstream_dns) = self.upstream_dns.iter().next() {
if self
.cidr_resources
.longest_match(upstream_dns.ip())
.is_some()
{
return Err((packet, upstream_dns.ip()));
}
}
self.add_pending_dns_query(query);
Ok(None)
@@ -236,7 +275,10 @@ impl ClientState {
Ok(None)
}
None => Err(packet),
None => {
let dest = packet.destination();
Err((packet, dest))
}
}
}
@@ -694,6 +736,7 @@ impl Default for ClientState {
peers_by_ip: IpNetworkTable::new(),
deferred_dns_queries: Default::default(),
refresh_dns_timer: interval,
upstream_dns: Default::default(),
}
}
}
@@ -767,6 +810,11 @@ impl RoleState for ClientState {
if self.refresh_dns_timer.poll_tick(cx).is_ready() {
let mut connections = Vec::new();
self.peers_by_ip
.iter()
.for_each(|p| p.1.inner.transform.expire_dns_track());
for resource in self.dns_resources_internal_ips.keys() {
let Some(gateway_id) = self.resources_gateways.get(&resource.id) else {
continue;

View File

@@ -164,6 +164,10 @@ where
let (peer_sender, peer_receiver) = tokio::sync::mpsc::channel(PEER_QUEUE_SIZE);
if let Some(upstream_dns) = self.role_state.lock().upstream_dns.iter().next() {
peer.transform.set_dns(upstream_dns.ip());
}
start_handlers(
Arc::clone(self),
Arc::clone(&self.device),

View File

@@ -213,7 +213,7 @@ where
tracing::trace!(target: "wire", action = "read", from = "device", dest = %packet.destination());
let packet = match role_state.handle_dns(packet) {
let (packet, dest) = match role_state.handle_dns(packet) {
Ok(Some(response)) => {
device.write(response)?;
continue;
@@ -222,8 +222,6 @@ where
Err(non_dns_packet) => non_dns_packet,
};
let dest = packet.destination();
let Some(peer) = peer_by_ip(&role_state.peers_by_ip, dest) else {
role_state.on_connection_intent_ip(dest);
continue;

View File

@@ -1,14 +1,17 @@
use std::borrow::Cow;
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwapOption;
use bimap::BiMap;
use boringtun::noise::rate_limiter::RateLimiter;
use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::StaticSecret;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use connlib_shared::DNS_SENTINEL;
use connlib_shared::{messages::ResourceDescription, Error, Result};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
@@ -22,6 +25,10 @@ use crate::{device_channel, ip_packet::MutableIpPacket, PeerConfig};
type ExpiryingResource = (ResourceDescription, DateTime<Utc>);
// The max time a dns request can be configured to live in resolvconf
// is 30 seconds. See resolvconf(5) timeout.
const IDS_EXPIRE: std::time::Duration = std::time::Duration::from_secs(60);
pub(crate) struct Peer<TId, TTransform> {
tunnel: Mutex<Tunn>,
allowed_ips: RwLock<IpNetworkTable<()>>,
@@ -194,6 +201,9 @@ impl Default for PacketTransformGateway {
#[derive(Default)]
pub struct PacketTransformClient {
translations: RwLock<BiMap<IpAddr, IpAddr>>,
// TODO: Upstream dns could be something that's not an ip
upstream_dns: ArcSwapOption<IpAddr>,
mangled_dns_ids: Mutex<HashMap<u16, std::time::Instant>>,
}
impl PacketTransformClient {
@@ -212,6 +222,16 @@ impl PacketTransformClient {
translations.insert(proxy_ip, *ip);
Some(proxy_ip)
}
pub fn expire_dns_track(&self) {
self.mangled_dns_ids
.lock()
.retain(|_, exp| exp.elapsed() < IDS_EXPIRE);
}
pub fn set_dns(&self, ip_addr: IpAddr) {
self.upstream_dns.store(Some(Arc::new(ip_addr)));
}
}
impl PacketTransformGateway {
@@ -274,16 +294,37 @@ impl PacketTransform for PacketTransformClient {
packet: &'a mut [u8],
) -> Result<(device_channel::Packet<'a>, IpAddr)> {
let translations = self.translations.read();
let src = translations.get_by_right(addr).unwrap_or(addr);
let mut src = *translations.get_by_right(addr).unwrap_or(addr);
let Some(mut pkt) = MutableIpPacket::new(packet) else {
return Err(Error::BadPacket);
};
pkt.set_src(*src);
let original_src = src;
if self
.upstream_dns
.load()
.as_ref()
.is_some_and(|upstream| **upstream == src)
{
if let Some(dgm) = pkt.as_udp() {
if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) {
if self
.mangled_dns_ids
.lock()
.remove(&message.header().id())
.is_some_and(|exp| exp.elapsed() < IDS_EXPIRE)
{
src = DNS_SENTINEL.into();
}
}
}
}
pkt.set_src(src);
pkt.update_checksum();
let packet = make_packet(packet, addr);
Ok((packet, *src))
Ok((packet, original_src))
}
fn packet_transform<'a>(&self, mut packet: MutableIpPacket<'a>) -> Option<MutableIpPacket<'a>> {
@@ -292,6 +333,20 @@ impl PacketTransform for PacketTransformClient {
packet.update_checksum();
}
if packet.destination() == DNS_SENTINEL {
if let Some(ip) = self.upstream_dns.load().as_ref() {
if let Some(dgm) = packet.as_udp() {
if let Ok(message) = domain::base::Message::from_slice(dgm.payload()) {
self.mangled_dns_ids
.lock()
.insert(message.header().id(), Instant::now());
packet.set_dst(**ip);
packet.update_checksum();
}
}
}
}
Some(packet)
}
}