fix(connlib): re-resolve portal host on WS hiccup (#10817)

Currently, the DNS records for the portal's hostname are only resolved
during startup. When the WebSocket connection fails, we try to reconnect,
but only with the IPs we previously resolved. If the local IP stack has
changed since then or the hostname now points to different IPs, we run
into the reconnect timeout configured in `phoenix-channel`.

To fix this, we re-resolve the portal's hostname every time the
WebSocket connection fails. For the Gateway, this is easy: we can
simply reuse the already existing `TokioResolver` provided by hickory.
For the Client, we need to write our own DNS client on top of our
socket-factory abstraction to ensure the resulting DNS queries don't
create a routing loop. To keep things simple, we only send DNS queries
over UDP. Those are not guaranteed to succeed, but because we re-resolve
on every "hiccup", we effectively already have a retry mechanism. The
lookups use the currently configured upstream DNS servers.
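
For illustration, here is a minimal sketch of the Gateway side, assuming
hickory-resolver's `TokioResolver` and its `lookup_ip` API (the same type
the Gateway event loop below receives); `refresh_portal_ips` is a
hypothetical helper, and a failed lookup is non-fatal because the next
"hiccup" triggers another attempt:

    use hickory_resolver::TokioResolver;
    use std::net::IpAddr;

    async fn refresh_portal_ips(resolver: &TokioResolver, host: &str) -> Option<Vec<IpAddr>> {
        match resolver.lookup_ip(host).await {
            Ok(lookup) => Some(lookup.into_iter().collect()),
            Err(e) => {
                // Non-fatal: the next hiccup retries the lookup.
                tracing::debug!(%host, "Failed to re-resolve portal host: {e:#}");
                None
            }
        }
    }

The Client-side `UdpDnsClient` below follows the same shape, but issues
the A/AAAA queries itself over UDP via the socket factory.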

Resolves: #10238
commit de7d3bff89 (parent 189c358975)
Author: Thomas Eizinger
Committer: GitHub
Date:   2025-11-11 14:24:36 +11:00

6 changed files with 252 additions and 44 deletions


@@ -32,6 +32,7 @@ url = { workspace = true, features = ["serde"] }
[dev-dependencies]
chrono = { workspace = true }
serde_json = { workspace = true, features = ["std"] }
tokio = { workspace = true, features = ["macros"] }
[lints]
workspace = true


@@ -1,6 +1,7 @@
use crate::PHOENIX_TOPIC;
use anyhow::{Context as _, Result};
use connlib_model::{PublicKey, ResourceView};
use dns_types::DomainName;
use firezone_tunnel::messages::RelaysPresence;
use firezone_tunnel::messages::client::{
EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates,
@@ -9,21 +10,25 @@ use firezone_tunnel::messages::client::{
use firezone_tunnel::{
ClientEvent, ClientTunnel, DnsResourceRecord, IpConfig, TunConfig, TunnelError,
};
use futures::TryFutureExt;
use futures::stream::FuturesUnordered;
use parking_lot::Mutex;
use phoenix_channel::{ErrorReply, PhoenixChannel, PublicKeyParam};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::ops::ControlFlow;
use std::pin::pin;
use std::sync::Arc;
use std::time::Instant;
use std::time::{Duration, Instant};
use std::{
collections::BTreeSet,
io,
net::IpAddr,
task::{Context, Poll},
};
use std::{future, mem};
use std::{future, iter, mem};
use tokio::sync::{mpsc, watch};
use tokio_stream::StreamExt;
use tun::Tun;
/// In-memory cache for DNS resource records.
@@ -71,6 +76,7 @@ pub enum Command {
enum PortalCommand {
Connect(PublicKeyParam),
Send(EgressMessages),
UpdateDnsServers(Vec<SocketAddr>),
}
/// Unified error type to use across connlib.
@@ -109,7 +115,7 @@ impl Eventloop {
let tunnel = ClientTunnel::new(
tcp_socket_factory,
udp_socket_factory,
udp_socket_factory.clone(),
DNS_RESOURCE_RECORDS_CACHE.lock().clone(),
is_internet_resource_active,
);
@@ -120,6 +126,7 @@ impl Eventloop {
portal,
portal_event_tx,
portal_cmd_rx,
UdpDnsClient::new(udp_socket_factory),
));
Self {
@@ -285,6 +292,12 @@ impl Eventloop {
.context("Failed to emit event")?;
}
ClientEvent::TunInterfaceUpdated(config) => {
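// Forward the upstream resolvers to the portal event loop so it can
// re-resolve the portal host with them on connection hiccups.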
self.portal_cmd_tx
.send(PortalCommand::UpdateDnsServers(
config.dns_by_sentinel.upstream_sockets(),
))
.await
.context("Failed to send message to portal")?;
self.tun_config_sender
.send(Some(config))
.context("Failed to emit event")?;
@@ -494,6 +507,7 @@ async fn phoenix_channel_event_loop(
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
mut cmd_rx: mpsc::Receiver<PortalCommand>,
mut udp_dns_client: UdpDnsClient,
) {
use futures::future::Either;
use futures::future::select;
@@ -534,11 +548,27 @@ async fn phoenix_channel_event_loop(
error,
}),
_,
)) => tracing::info!(
?backoff,
?max_elapsed_time,
"Hiccup in portal connection: {error:#}"
),
)) => {
tracing::info!(
?backoff,
?max_elapsed_time,
"Hiccup in portal connection: {error:#}"
);
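// Re-resolve the portal host on every hiccup so the next reconnect
// attempt uses fresh IPs instead of the ones resolved at startup.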
let ips = match udp_dns_client
.resolve(portal.host())
.await
.context("Failed to lookup portal host")
{
Ok(ips) => ips.into_iter().collect(),
Err(e) => {
tracing::debug!(host = %portal.host(), "{e:#}");
continue;
}
};
portal.update_ips(ips);
}
Either::Left((Err(e), _)) => {
let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway.
@@ -550,6 +580,9 @@ async fn phoenix_channel_event_loop(
Either::Right((Some(PortalCommand::Connect(param)), _)) => {
portal.connect(param);
}
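// Keep the DNS client pointed at the currently configured upstream resolvers.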
Either::Right((Some(PortalCommand::UpdateDnsServers(servers)), _)) => {
udp_dns_client.servers = servers;
}
Either::Right((None, _)) => {
tracing::debug!("Command channel closed: exiting phoenix-channel event-loop");
@@ -569,3 +602,130 @@ fn is_unreachable(e: &io::Error) -> bool {
|| e.kind() == io::ErrorKind::HostUnreachable
|| e.kind() == io::ErrorKind::AddrNotAvailable
}
struct UdpDnsClient {
socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
servers: Vec<SocketAddr>,
}
impl UdpDnsClient {
const TIMEOUT: Duration = Duration::from_secs(2);
fn new(socket_factory: Arc<dyn SocketFactory<UdpSocket>>) -> Self {
Self {
socket_factory,
servers: Vec::default(),
}
}
async fn resolve(&self, host: String) -> Result<Vec<IpAddr>> {
let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?;
let servers = self.servers.clone();
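// Fan out A and AAAA queries to all configured servers concurrently and
// take the first server whose two answers both come back NOERROR.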
let (a_records, aaaa_records) = self
.servers
.iter()
.map(|socket| {
futures::future::try_join(
self.send(
*socket,
dns_types::Query::new(host.clone(), dns_types::RecordType::A),
),
self.send(
*socket,
dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA),
),
)
.map_err(|e| {
tracing::debug!(%host, "DNS query failed: {e:#}");
e
})
})
.collect::<FuturesUnordered<_>>()
.filter_map(|result| result.ok())
.filter(|(a, b)| {
a.response_code() == dns_types::ResponseCode::NOERROR
&& b.response_code() == dns_types::ResponseCode::NOERROR
})
.next()
.await
.with_context(|| {
format!("All DNS servers ({servers:?}) failed to resolve portal host '{host}'")
})?;
let ips = iter::empty()
.chain(
a_records
.records()
.filter_map(dns_types::records::extract_ip),
)
.chain(
aaaa_records
.records()
.filter_map(dns_types::records::extract_ip),
)
.collect();
Ok(ips)
}
async fn send(
&self,
server: SocketAddr,
query: dns_types::Query,
) -> io::Result<dns_types::Response> {
let bind_addr = match server {
SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
};
// To avoid fragmentation, IP (and thus UDP) packets can only reliably be sent at sizes of <= 1500 bytes (the typical Internet MTU).
const BUF_SIZE: usize = 1500;
let udp_socket = self.socket_factory.bind(bind_addr)?;
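// Send the query and wait for a single response datagram, bounded by TIMEOUT.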
let response = tokio::time::timeout(
Self::TIMEOUT,
udp_socket.handshake::<BUF_SIZE>(server, &query.into_bytes()),
)
.await??;
let response = dns_types::Response::parse(&response).map_err(io::Error::other)?;
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
#[ignore = "Requires Internet"]
async fn udp_dns_client_can_resolve_host() {
let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp));
client.servers = vec![SocketAddr::new(IpAddr::from([1, 1, 1, 1]), 53)];
let ips = client.resolve("example.com".to_owned()).await.unwrap();
assert!(!ips.is_empty())
}
#[tokio::test]
#[ignore = "Requires Internet"]
async fn udp_dns_client_times_out_unreachable_host() {
let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp));
client.servers = vec![SocketAddr::new(IpAddr::from([2, 2, 2, 2]), 53)];
let now = Instant::now();
let error = client.resolve("example.com".to_owned()).await.unwrap_err();
assert_eq!(
error.to_string(),
"All DNS servers ([2.2.2.2:53]) failed to resolve portal host 'example.com'"
);
assert!(now.elapsed() >= UdpDnsClient::TIMEOUT)
}
}


@@ -1,3 +1,5 @@
#![cfg_attr(test, allow(clippy::unwrap_used))]
//! Main connlib library for clients.
pub use connlib_model::StaticSecret;
pub use eventloop::DisconnectError;


@@ -337,6 +337,18 @@ pub mod records {
pub fn srv(priority: u16, weight: u16, port: u16, target: DomainName) -> OwnedRecordData {
OwnedRecordData::Srv(Srv::new(priority, weight, port, target))
}
#[expect(
clippy::wildcard_enum_match_arm,
reason = "We explicitly only want A and AAAA records."
)]
pub fn extract_ip(r: Record<'_>) -> Option<IpAddr> {
match r.into_data() {
RecordData::A(a) => Some(a.addr().into()),
RecordData::Aaaa(aaaa) => Some(aaaa.addr().into()),
_ => None,
}
}
}
#[cfg(test)]


@@ -64,6 +64,9 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TFinish> {
}
enum State {
Reconnect {
backoff: Duration,
},
Connected(WebSocketStream<MaybeTlsStream<TcpStream>>),
Connecting(
BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError>>,
@@ -357,6 +360,20 @@ where
self.url_prototype.expose_secret().base_url()
}
pub fn host(&self) -> String {
self.url_prototype
.expose_secret()
.host_and_port()
.0
.to_owned()
}
pub fn update_ips(&mut self, ips: Vec<IpAddr>) {
tracing::debug!(host = %self.host(), current = ?self.resolved_addresses, new = ?ips, "Updating resolved IPs");
self.resolved_addresses = ips;
}
/// Initiate a graceful close of the connection.
pub fn close(&mut self) -> Result<(), Connecting> {
tracing::info!("Closing connection to portal");
@@ -366,7 +383,7 @@ where
State::Closing(stream) | State::Connected(stream) => {
self.state = State::Closing(stream);
}
State::Closed => {}
State::Closed | State::Reconnect { .. } => {}
}
Ok(())
@@ -393,6 +410,33 @@ where
Poll::Pending => return Poll::Pending,
},
State::Connected(stream) => stream,
State::Reconnect { backoff } => {
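// The connect future is built lazily here rather than when the error
// occurred, so IPs passed to `update_ips` after a `Hiccup` event are
// picked up by this attempt.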
let backoff = *backoff;
let socket_addresses = self.socket_addresses();
let host = self.host();
let secret_url = self
.last_url
.as_ref()
.expect("should have last URL if we receive connection error")
.clone();
let user_agent = self.user_agent.clone();
let socket_factory = self.socket_factory.clone();
self.state = State::Connecting(Box::pin(async move {
tokio::time::sleep(backoff).await;
create_and_connect_websocket(
secret_url,
socket_addresses,
host,
user_agent,
socket_factory,
)
.await
}));
continue;
}
State::Connecting(future) => match future.poll_unpin(cx) {
Poll::Ready(Ok(stream)) => {
self.reconnect_backoff = None;
@@ -423,9 +467,6 @@ where
return Poll::Ready(Err(Error::FatalIo(io)));
}
Poll::Ready(Err(e)) => {
let socket_addresses = self.socket_addresses();
let host = self.host();
let backoff = match self.reconnect_backoff.as_mut() {
Some(reconnect_backoff) => reconnect_backoff
.next_backoff()
@@ -439,25 +480,7 @@ where
}
};
let secret_url = self
.last_url
.as_ref()
.expect("should have last URL if we receive connection error")
.clone();
let user_agent = self.user_agent.clone();
let socket_factory = self.socket_factory.clone();
self.state = State::Connecting(Box::pin(async move {
tokio::time::sleep(backoff).await;
create_and_connect_websocket(
secret_url,
socket_addresses,
host,
user_agent,
socket_factory,
)
.await
}));
self.state = State::Reconnect { backoff };
return Poll::Ready(Ok(Event::Hiccup {
backoff,
@@ -694,14 +717,6 @@ where
.map(|ip| SocketAddr::new(*ip, port))
.collect()
}
fn host(&self) -> String {
self.url_prototype
.expose_secret()
.host_and_port()
.0
.to_owned()
}
}
#[derive(Debug)]


@@ -93,6 +93,7 @@ impl Eventloop {
portal,
portal_event_tx,
portal_cmd_rx,
resolver.clone(),
));
Ok(Self {
@@ -696,6 +697,7 @@ async fn phoenix_channel_event_loop(
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
mut cmd_rx: mpsc::Receiver<PortalCommand>,
resolver: TokioResolver,
) {
use futures::future::Either;
use futures::future::select;
@@ -740,11 +742,27 @@ async fn phoenix_channel_event_loop(
error,
}),
_,
)) => tracing::info!(
?backoff,
?max_elapsed_time,
"Hiccup in portal connection: {error:#}"
),
)) => {
tracing::info!(
?backoff,
?max_elapsed_time,
"Hiccup in portal connection: {error:#}"
);
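// Same as the Client: re-resolve the portal host so the next reconnect
// attempt uses fresh IPs.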
let ips = match resolver
.lookup_ip(portal.host())
.await
.context("Failed to lookup portal host")
{
Ok(ips) => ips.into_iter().collect(),
Err(e) => {
tracing::debug!(host = %portal.host(), "{e:#}");
continue;
}
};
portal.update_ips(ips);
}
Either::Left((Err(e), _)) => {
let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway.