From de7d3bff89cdc8f6c410f458bed876456754a5cc Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Nov 2025 14:24:36 +1100 Subject: [PATCH] fix(connlib): re-resolve portal host on WS hiccup (#10817) Currently, the DNS records for the portal's hostname are only resolved during startup. When the WebSocket connection fails, we try to reconnect but only with the IPs that we have previously resolved. If the local IP stack changed since then or the hostname now points to different IPs, we will run into the reconnect-timeout configured in `phoenix-channel`. To fix this, we re-resolve the portal's hostname every time the WebSocket connection fails. For the Gateway, this is easy as we can simply reuse the already existing `TokioResolver` provided by hickory. For the Client, we need to write our own DNS client on top of our socket factory abstraction to ensure we don't create a routing loop with the resulting DNS queries. To simplify things, we only send DNS queries over UDP. Those are not guaranteed to succeed but given that we do this on every "hiccup", we already have a retry mechanism. We use the currently configured upstream DNS servers for this. 
Resolves: #10238 --- rust/client-shared/Cargo.toml | 1 + rust/client-shared/src/eventloop.rs | 176 ++++++++++++++++++++++-- rust/client-shared/src/lib.rs | 2 + rust/connlib/dns-types/lib.rs | 12 ++ rust/connlib/phoenix-channel/src/lib.rs | 77 ++++++----- rust/gateway/src/eventloop.rs | 28 +++- 6 files changed, 252 insertions(+), 44 deletions(-) diff --git a/rust/client-shared/Cargo.toml b/rust/client-shared/Cargo.toml index d098aac00..15e4f5dbb 100644 --- a/rust/client-shared/Cargo.toml +++ b/rust/client-shared/Cargo.toml @@ -32,6 +32,7 @@ url = { workspace = true, features = ["serde"] } [dev-dependencies] chrono = { workspace = true } serde_json = { workspace = true, features = ["std"] } +tokio = { workspace = true, features = ["macros"] } [lints] workspace = true diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 0e9c99aba..ec309da3e 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -1,6 +1,7 @@ use crate::PHOENIX_TOPIC; use anyhow::{Context as _, Result}; use connlib_model::{PublicKey, ResourceView}; +use dns_types::DomainName; use firezone_tunnel::messages::RelaysPresence; use firezone_tunnel::messages::client::{ EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates, @@ -9,21 +10,25 @@ use firezone_tunnel::messages::client::{ use firezone_tunnel::{ ClientEvent, ClientTunnel, DnsResourceRecord, IpConfig, TunConfig, TunnelError, }; +use futures::TryFutureExt; +use futures::stream::FuturesUnordered; use parking_lot::Mutex; use phoenix_channel::{ErrorReply, PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::ops::ControlFlow; use std::pin::pin; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use std::{ collections::BTreeSet, io, net::IpAddr, task::{Context, Poll}, }; -use std::{future, mem}; +use std::{future, iter, mem}; 
use tokio::sync::{mpsc, watch}; +use tokio_stream::StreamExt; use tun::Tun; /// In-memory cache for DNS resource records. @@ -71,6 +76,7 @@ pub enum Command { enum PortalCommand { Connect(PublicKeyParam), Send(EgressMessages), + UpdateDnsServers(Vec<SocketAddr>), } /// Unified error type to use across connlib. @@ -109,7 +115,7 @@ impl Eventloop { let tunnel = ClientTunnel::new( tcp_socket_factory, - udp_socket_factory, + udp_socket_factory.clone(), DNS_RESOURCE_RECORDS_CACHE.lock().clone(), is_internet_resource_active, ); @@ -120,6 +126,7 @@ impl Eventloop { portal, portal_event_tx, portal_cmd_rx, + UdpDnsClient::new(udp_socket_factory), )); Self { @@ -285,6 +292,12 @@ impl Eventloop { .context("Failed to emit event")?; } ClientEvent::TunInterfaceUpdated(config) => { + self.portal_cmd_tx + .send(PortalCommand::UpdateDnsServers( + config.dns_by_sentinel.upstream_sockets(), + )) + .await + .context("Failed to send message to portal")?; self.tun_config_sender .send(Some(config)) .context("Failed to emit event")?; @@ -494,6 +507,7 @@ async fn phoenix_channel_event_loop( mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver<PortalCommand>, + mut udp_dns_client: UdpDnsClient, ) { use futures::future::Either; use futures::future::select; @@ -534,11 +548,27 @@ async fn phoenix_channel_event_loop( error, }), _, - )) => tracing::info!( - ?backoff, - ?max_elapsed_time, - "Hiccup in portal connection: {error:#}" - ), + )) => { + tracing::info!( + ?backoff, + ?max_elapsed_time, + "Hiccup in portal connection: {error:#}" + ); + + let ips = match udp_dns_client + .resolve(portal.host()) + .await + .context("Failed to lookup portal host") + { + Ok(ips) => ips.into_iter().collect(), + Err(e) => { + tracing::debug!(host = %portal.host(), "{e:#}"); + continue; + } + }; + + portal.update_ips(ips); + } Either::Left((Err(e), _)) => { let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway. 
@@ -550,6 +580,9 @@ async fn phoenix_channel_event_loop( Either::Right((Some(PortalCommand::Connect(param)), _)) => { portal.connect(param); } + Either::Right((Some(PortalCommand::UpdateDnsServers(servers)), _)) => { + udp_dns_client.servers = servers; + } Either::Right((None, _)) => { tracing::debug!("Command channel closed: exiting phoenix-channel event-loop"); @@ -569,3 +602,130 @@ fn is_unreachable(e: &io::Error) -> bool { || e.kind() == io::ErrorKind::HostUnreachable || e.kind() == io::ErrorKind::AddrNotAvailable } + +struct UdpDnsClient { + socket_factory: Arc<dyn SocketFactory<UdpSocket>>, + servers: Vec<SocketAddr>, +} + +impl UdpDnsClient { + const TIMEOUT: Duration = Duration::from_secs(2); + + fn new(socket_factory: Arc<dyn SocketFactory<UdpSocket>>) -> Self { + Self { + socket_factory, + servers: Vec::default(), + } + } + + async fn resolve(&self, host: String) -> Result<BTreeSet<IpAddr>> { + let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?; + let servers = self.servers.clone(); + + let (a_records, aaaa_records) = self + .servers + .iter() + .map(|socket| { + futures::future::try_join( + self.send( + *socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::A), + ), + self.send( + *socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA), + ), + ) + .map_err(|e| { + tracing::debug!(%host, "DNS query failed: {e:#}"); + + e + }) + }) + .collect::<FuturesUnordered<_>>() + .filter_map(|result| result.ok()) + .filter(|(a, b)| { + a.response_code() == dns_types::ResponseCode::NOERROR + && b.response_code() == dns_types::ResponseCode::NOERROR + }) + .next() + .await + .with_context(|| { + format!("All DNS servers ({servers:?}) failed to resolve portal host '{host}'") + })?; + + let ips = iter::empty() + .chain( + a_records + .records() + .filter_map(dns_types::records::extract_ip), + ) + .chain( + aaaa_records + .records() + .filter_map(dns_types::records::extract_ip), + ) + .collect(); + + Ok(ips) + } + + async fn send( + &self, + server: SocketAddr, + query: dns_types::Query, + ) -> io::Result<dns_types::Response> { 
let bind_addr = match server { + SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + }; + + // To avoid fragmentation, IP and thus also UDP packets can only reliably be sent with an MTU of <= 1500 on the public Internet. + const BUF_SIZE: usize = 1500; + + let udp_socket = self.socket_factory.bind(bind_addr)?; + + let response = tokio::time::timeout( + Self::TIMEOUT, + udp_socket.handshake::<BUF_SIZE>(server, &query.into_bytes()), + ) + .await??; + + let response = dns_types::Response::parse(&response).map_err(io::Error::other)?; + + Ok(response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn udp_dns_client_can_resolve_host() { + let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); + client.servers = vec![SocketAddr::new(IpAddr::from([1, 1, 1, 1]), 53)]; + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(!ips.is_empty()) + } + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn udp_dns_client_times_out_unreachable_host() { + let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); + client.servers = vec![SocketAddr::new(IpAddr::from([2, 2, 2, 2]), 53)]; + + let now = Instant::now(); + + let error = client.resolve("example.com".to_owned()).await.unwrap_err(); + + assert_eq!( + error.to_string(), + "All DNS servers ([2.2.2.2:53]) failed to resolve portal host 'example.com'" + ); + assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) + } +} diff --git a/rust/client-shared/src/lib.rs b/rust/client-shared/src/lib.rs index 7d88c6cbc..e711a18be 100644 --- a/rust/client-shared/src/lib.rs +++ b/rust/client-shared/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + //! Main connlib library for clients. 
pub use connlib_model::StaticSecret; pub use eventloop::DisconnectError; diff --git a/rust/connlib/dns-types/lib.rs b/rust/connlib/dns-types/lib.rs index ade0f6087..3fb7a7be4 100644 --- a/rust/connlib/dns-types/lib.rs +++ b/rust/connlib/dns-types/lib.rs @@ -337,6 +337,18 @@ pub mod records { pub fn srv(priority: u16, weight: u16, port: u16, target: DomainName) -> OwnedRecordData { OwnedRecordData::Srv(Srv::new(priority, weight, port, target)) } + + #[expect( + clippy::wildcard_enum_match_arm, + reason = "We explicitly only want A and AAAA records." + )] + pub fn extract_ip(r: Record<'_>) -> Option<IpAddr> { + match r.into_data() { + RecordData::A(a) => Some(a.addr().into()), + RecordData::Aaaa(aaaa) => Some(aaaa.addr().into()), + _ => None, + } + } } #[cfg(test)] diff --git a/rust/connlib/phoenix-channel/src/lib.rs b/rust/connlib/phoenix-channel/src/lib.rs index 6fa34f399..9fdfdab1b 100644 --- a/rust/connlib/phoenix-channel/src/lib.rs +++ b/rust/connlib/phoenix-channel/src/lib.rs @@ -64,6 +64,9 @@ pub struct PhoenixChannel { } enum State { + Reconnect { + backoff: Duration, + }, Connected(WebSocketStream>), Connecting( BoxFuture<'static, Result>, InternalError>>, @@ -357,6 +360,20 @@ where self.url_prototype.expose_secret().base_url() } + pub fn host(&self) -> String { + self.url_prototype + .expose_secret() + .host_and_port() + .0 + .to_owned() + } + + pub fn update_ips(&mut self, ips: Vec<IpAddr>) { + tracing::debug!(host = %self.host(), current = ?self.resolved_addresses, new = ?ips, "Updating resolved IPs"); + + self.resolved_addresses = ips; + } + /// Initiate a graceful close of the connection. pub fn close(&mut self) -> Result<(), Connecting> { tracing::info!("Closing connection to portal"); @@ -366,7 +383,7 @@ where State::Closing(stream) | State::Connected(stream) => { self.state = State::Closing(stream); } - State::Closed => {} + State::Closed | State::Reconnect { .. 
} => {} } Ok(()) @@ -393,6 +410,33 @@ where Poll::Pending => return Poll::Pending, }, State::Connected(stream) => stream, + State::Reconnect { backoff } => { + let backoff = *backoff; + let socket_addresses = self.socket_addresses(); + let host = self.host(); + + let secret_url = self + .last_url + .as_ref() + .expect("should have last URL if we receive connection error") + .clone(); + let user_agent = self.user_agent.clone(); + let socket_factory = self.socket_factory.clone(); + + self.state = State::Connecting(Box::pin(async move { + tokio::time::sleep(backoff).await; + create_and_connect_websocket( + secret_url, + socket_addresses, + host, + user_agent, + socket_factory, + ) + .await + })); + + continue; + } State::Connecting(future) => match future.poll_unpin(cx) { Poll::Ready(Ok(stream)) => { self.reconnect_backoff = None; @@ -423,9 +467,6 @@ where return Poll::Ready(Err(Error::FatalIo(io))); } Poll::Ready(Err(e)) => { - let socket_addresses = self.socket_addresses(); - let host = self.host(); - let backoff = match self.reconnect_backoff.as_mut() { Some(reconnect_backoff) => reconnect_backoff .next_backoff() @@ -439,25 +480,7 @@ where } }; - let secret_url = self - .last_url - .as_ref() - .expect("should have last URL if we receive connection error") - .clone(); - let user_agent = self.user_agent.clone(); - let socket_factory = self.socket_factory.clone(); - - self.state = State::Connecting(Box::pin(async move { - tokio::time::sleep(backoff).await; - create_and_connect_websocket( - secret_url, - socket_addresses, - host, - user_agent, - socket_factory, - ) - .await - })); + self.state = State::Reconnect { backoff }; return Poll::Ready(Ok(Event::Hiccup { backoff, @@ -694,14 +717,6 @@ where .map(|ip| SocketAddr::new(*ip, port)) .collect() } - - fn host(&self) -> String { - self.url_prototype - .expose_secret() - .host_and_port() - .0 - .to_owned() - } } #[derive(Debug)] diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 
cf5f47bf1..187ad9d4d 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -93,6 +93,7 @@ impl Eventloop { portal, portal_event_tx, portal_cmd_rx, + resolver.clone(), )); Ok(Self { @@ -696,6 +697,7 @@ async fn phoenix_channel_event_loop( mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, + resolver: TokioResolver, ) { use futures::future::Either; use futures::future::select; @@ -740,11 +742,27 @@ async fn phoenix_channel_event_loop( error, }), _, - )) => tracing::info!( - ?backoff, - ?max_elapsed_time, - "Hiccup in portal connection: {error:#}" - ), + )) => { + tracing::info!( + ?backoff, + ?max_elapsed_time, + "Hiccup in portal connection: {error:#}" + ); + + let ips = match resolver + .lookup_ip(portal.host()) + .await + .context("Failed to lookup portal host") + { + Ok(ips) => ips.into_iter().collect(), + Err(e) => { + tracing::debug!(host = %portal.host(), "{e:#}"); + continue; + } + }; + + portal.update_ips(ips); + } Either::Left((Err(e), _)) => { let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway.