mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
fix(connlib): re-resolve portal host on WS hiccup (#10817)
Currently, the DNS records for the portal's hostname are only resolved during startup. When the WebSocket connection fails, we try to reconnect but only with the IPs that we have previously resolved. If the local IP stack changed since then or the hostname now points to different IPs, we will run into the reconnect-timeout configured in `phoenix-channel`. To fix this, we re-resolve the portal's hostname every time the WebSocket connection fails. For the Gateway, this is easy as we can simply reuse the already existing `TokioResolver` provided by hickory. For the Client, we need to write our own DNS client on top of our socket factory abstraction to ensure we don't create a routing loop with the resulting DNS queries. To simplify things, we only send DNS queries over UDP. Those are not guaranteed to succeed but given that we do this on every "hiccup", we already have a retry mechanism. We use the currently configured upstream DNS servers for this. Resolves: #10238
This commit is contained in:
@@ -32,6 +32,7 @@ url = { workspace = true, features = ["serde"] }
|
||||
[dev-dependencies]
|
||||
chrono = { workspace = true }
|
||||
serde_json = { workspace = true, features = ["std"] }
|
||||
tokio = { workspace = true, features = ["macros"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::PHOENIX_TOPIC;
|
||||
use anyhow::{Context as _, Result};
|
||||
use connlib_model::{PublicKey, ResourceView};
|
||||
use dns_types::DomainName;
|
||||
use firezone_tunnel::messages::RelaysPresence;
|
||||
use firezone_tunnel::messages::client::{
|
||||
EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates,
|
||||
@@ -9,21 +10,25 @@ use firezone_tunnel::messages::client::{
|
||||
use firezone_tunnel::{
|
||||
ClientEvent, ClientTunnel, DnsResourceRecord, IpConfig, TunConfig, TunnelError,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use parking_lot::Mutex;
|
||||
use phoenix_channel::{ErrorReply, PhoenixChannel, PublicKeyParam};
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::ops::ControlFlow;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{
|
||||
collections::BTreeSet,
|
||||
io,
|
||||
net::IpAddr,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use std::{future, mem};
|
||||
use std::{future, iter, mem};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio_stream::StreamExt;
|
||||
use tun::Tun;
|
||||
|
||||
/// In-memory cache for DNS resource records.
|
||||
@@ -71,6 +76,7 @@ pub enum Command {
|
||||
enum PortalCommand {
|
||||
Connect(PublicKeyParam),
|
||||
Send(EgressMessages),
|
||||
UpdateDnsServers(Vec<SocketAddr>),
|
||||
}
|
||||
|
||||
/// Unified error type to use across connlib.
|
||||
@@ -109,7 +115,7 @@ impl Eventloop {
|
||||
|
||||
let tunnel = ClientTunnel::new(
|
||||
tcp_socket_factory,
|
||||
udp_socket_factory,
|
||||
udp_socket_factory.clone(),
|
||||
DNS_RESOURCE_RECORDS_CACHE.lock().clone(),
|
||||
is_internet_resource_active,
|
||||
);
|
||||
@@ -120,6 +126,7 @@ impl Eventloop {
|
||||
portal,
|
||||
portal_event_tx,
|
||||
portal_cmd_rx,
|
||||
UdpDnsClient::new(udp_socket_factory),
|
||||
));
|
||||
|
||||
Self {
|
||||
@@ -285,6 +292,12 @@ impl Eventloop {
|
||||
.context("Failed to emit event")?;
|
||||
}
|
||||
ClientEvent::TunInterfaceUpdated(config) => {
|
||||
self.portal_cmd_tx
|
||||
.send(PortalCommand::UpdateDnsServers(
|
||||
config.dns_by_sentinel.upstream_sockets(),
|
||||
))
|
||||
.await
|
||||
.context("Failed to send message to portal")?;
|
||||
self.tun_config_sender
|
||||
.send(Some(config))
|
||||
.context("Failed to emit event")?;
|
||||
@@ -494,6 +507,7 @@ async fn phoenix_channel_event_loop(
|
||||
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
|
||||
mut cmd_rx: mpsc::Receiver<PortalCommand>,
|
||||
mut udp_dns_client: UdpDnsClient,
|
||||
) {
|
||||
use futures::future::Either;
|
||||
use futures::future::select;
|
||||
@@ -534,11 +548,27 @@ async fn phoenix_channel_event_loop(
|
||||
error,
|
||||
}),
|
||||
_,
|
||||
)) => tracing::info!(
|
||||
?backoff,
|
||||
?max_elapsed_time,
|
||||
"Hiccup in portal connection: {error:#}"
|
||||
),
|
||||
)) => {
|
||||
tracing::info!(
|
||||
?backoff,
|
||||
?max_elapsed_time,
|
||||
"Hiccup in portal connection: {error:#}"
|
||||
);
|
||||
|
||||
let ips = match udp_dns_client
|
||||
.resolve(portal.host())
|
||||
.await
|
||||
.context("Failed to lookup portal host")
|
||||
{
|
||||
Ok(ips) => ips.into_iter().collect(),
|
||||
Err(e) => {
|
||||
tracing::debug!(host = %portal.host(), "{e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
portal.update_ips(ips);
|
||||
}
|
||||
Either::Left((Err(e), _)) => {
|
||||
let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway.
|
||||
|
||||
@@ -550,6 +580,9 @@ async fn phoenix_channel_event_loop(
|
||||
Either::Right((Some(PortalCommand::Connect(param)), _)) => {
|
||||
portal.connect(param);
|
||||
}
|
||||
Either::Right((Some(PortalCommand::UpdateDnsServers(servers)), _)) => {
|
||||
udp_dns_client.servers = servers;
|
||||
}
|
||||
Either::Right((None, _)) => {
|
||||
tracing::debug!("Command channel closed: exiting phoenix-channel event-loop");
|
||||
|
||||
@@ -569,3 +602,130 @@ fn is_unreachable(e: &io::Error) -> bool {
|
||||
|| e.kind() == io::ErrorKind::HostUnreachable
|
||||
|| e.kind() == io::ErrorKind::AddrNotAvailable
|
||||
}
|
||||
|
||||
struct UdpDnsClient {
|
||||
socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
servers: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpDnsClient {
|
||||
const TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
fn new(socket_factory: Arc<dyn SocketFactory<UdpSocket>>) -> Self {
|
||||
Self {
|
||||
socket_factory,
|
||||
servers: Vec::default(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve(&self, host: String) -> Result<Vec<IpAddr>> {
|
||||
let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?;
|
||||
let servers = self.servers.clone();
|
||||
|
||||
let (a_records, aaaa_records) = self
|
||||
.servers
|
||||
.iter()
|
||||
.map(|socket| {
|
||||
futures::future::try_join(
|
||||
self.send(
|
||||
*socket,
|
||||
dns_types::Query::new(host.clone(), dns_types::RecordType::A),
|
||||
),
|
||||
self.send(
|
||||
*socket,
|
||||
dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA),
|
||||
),
|
||||
)
|
||||
.map_err(|e| {
|
||||
tracing::debug!(%host, "DNS query failed: {e:#}");
|
||||
|
||||
e
|
||||
})
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>()
|
||||
.filter_map(|result| result.ok())
|
||||
.filter(|(a, b)| {
|
||||
a.response_code() == dns_types::ResponseCode::NOERROR
|
||||
&& b.response_code() == dns_types::ResponseCode::NOERROR
|
||||
})
|
||||
.next()
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("All DNS servers ({servers:?}) failed to resolve portal host '{host}'")
|
||||
})?;
|
||||
|
||||
let ips = iter::empty()
|
||||
.chain(
|
||||
a_records
|
||||
.records()
|
||||
.filter_map(dns_types::records::extract_ip),
|
||||
)
|
||||
.chain(
|
||||
aaaa_records
|
||||
.records()
|
||||
.filter_map(dns_types::records::extract_ip),
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(ips)
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
server: SocketAddr,
|
||||
query: dns_types::Query,
|
||||
) -> io::Result<dns_types::Response> {
|
||||
let bind_addr = match server {
|
||||
SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
|
||||
SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
|
||||
};
|
||||
|
||||
// To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet.
|
||||
const BUF_SIZE: usize = 1500;
|
||||
|
||||
let udp_socket = self.socket_factory.bind(bind_addr)?;
|
||||
|
||||
let response = tokio::time::timeout(
|
||||
Self::TIMEOUT,
|
||||
udp_socket.handshake::<BUF_SIZE>(server, &query.into_bytes()),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let response = dns_types::Response::parse(&response).map_err(io::Error::other)?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "Requires Internet"]
|
||||
async fn udp_dns_client_can_resolve_host() {
|
||||
let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp));
|
||||
client.servers = vec![SocketAddr::new(IpAddr::from([1, 1, 1, 1]), 53)];
|
||||
|
||||
let ips = client.resolve("example.com".to_owned()).await.unwrap();
|
||||
|
||||
assert!(!ips.is_empty())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "Requires Internet"]
|
||||
async fn udp_dns_client_times_out_unreachable_host() {
|
||||
let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp));
|
||||
client.servers = vec![SocketAddr::new(IpAddr::from([2, 2, 2, 2]), 53)];
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let error = client.resolve("example.com".to_owned()).await.unwrap_err();
|
||||
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"All DNS servers ([2.2.2.2:53]) failed to resolve portal host 'example.com'"
|
||||
);
|
||||
assert!(now.elapsed() >= UdpDnsClient::TIMEOUT)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![cfg_attr(test, allow(clippy::unwrap_used))]
|
||||
|
||||
//! Main connlib library for clients.
|
||||
pub use connlib_model::StaticSecret;
|
||||
pub use eventloop::DisconnectError;
|
||||
|
||||
@@ -337,6 +337,18 @@ pub mod records {
|
||||
pub fn srv(priority: u16, weight: u16, port: u16, target: DomainName) -> OwnedRecordData {
|
||||
OwnedRecordData::Srv(Srv::new(priority, weight, port, target))
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::wildcard_enum_match_arm,
|
||||
reason = "We explicitly only want A and AAAA records."
|
||||
)]
|
||||
pub fn extract_ip(r: Record<'_>) -> Option<IpAddr> {
|
||||
match r.into_data() {
|
||||
RecordData::A(a) => Some(a.addr().into()),
|
||||
RecordData::Aaaa(aaaa) => Some(aaaa.addr().into()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -64,6 +64,9 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TFinish> {
|
||||
}
|
||||
|
||||
enum State {
|
||||
Reconnect {
|
||||
backoff: Duration,
|
||||
},
|
||||
Connected(WebSocketStream<MaybeTlsStream<TcpStream>>),
|
||||
Connecting(
|
||||
BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError>>,
|
||||
@@ -357,6 +360,20 @@ where
|
||||
self.url_prototype.expose_secret().base_url()
|
||||
}
|
||||
|
||||
pub fn host(&self) -> String {
|
||||
self.url_prototype
|
||||
.expose_secret()
|
||||
.host_and_port()
|
||||
.0
|
||||
.to_owned()
|
||||
}
|
||||
|
||||
pub fn update_ips(&mut self, ips: Vec<IpAddr>) {
|
||||
tracing::debug!(host = %self.host(), current = ?self.resolved_addresses, new = ?ips, "Updating resolved IPs");
|
||||
|
||||
self.resolved_addresses = ips;
|
||||
}
|
||||
|
||||
/// Initiate a graceful close of the connection.
|
||||
pub fn close(&mut self) -> Result<(), Connecting> {
|
||||
tracing::info!("Closing connection to portal");
|
||||
@@ -366,7 +383,7 @@ where
|
||||
State::Closing(stream) | State::Connected(stream) => {
|
||||
self.state = State::Closing(stream);
|
||||
}
|
||||
State::Closed => {}
|
||||
State::Closed | State::Reconnect { .. } => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -393,6 +410,33 @@ where
|
||||
Poll::Pending => return Poll::Pending,
|
||||
},
|
||||
State::Connected(stream) => stream,
|
||||
State::Reconnect { backoff } => {
|
||||
let backoff = *backoff;
|
||||
let socket_addresses = self.socket_addresses();
|
||||
let host = self.host();
|
||||
|
||||
let secret_url = self
|
||||
.last_url
|
||||
.as_ref()
|
||||
.expect("should have last URL if we receive connection error")
|
||||
.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
let socket_factory = self.socket_factory.clone();
|
||||
|
||||
self.state = State::Connecting(Box::pin(async move {
|
||||
tokio::time::sleep(backoff).await;
|
||||
create_and_connect_websocket(
|
||||
secret_url,
|
||||
socket_addresses,
|
||||
host,
|
||||
user_agent,
|
||||
socket_factory,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
|
||||
continue;
|
||||
}
|
||||
State::Connecting(future) => match future.poll_unpin(cx) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
self.reconnect_backoff = None;
|
||||
@@ -423,9 +467,6 @@ where
|
||||
return Poll::Ready(Err(Error::FatalIo(io)));
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
let socket_addresses = self.socket_addresses();
|
||||
let host = self.host();
|
||||
|
||||
let backoff = match self.reconnect_backoff.as_mut() {
|
||||
Some(reconnect_backoff) => reconnect_backoff
|
||||
.next_backoff()
|
||||
@@ -439,25 +480,7 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
let secret_url = self
|
||||
.last_url
|
||||
.as_ref()
|
||||
.expect("should have last URL if we receive connection error")
|
||||
.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
let socket_factory = self.socket_factory.clone();
|
||||
|
||||
self.state = State::Connecting(Box::pin(async move {
|
||||
tokio::time::sleep(backoff).await;
|
||||
create_and_connect_websocket(
|
||||
secret_url,
|
||||
socket_addresses,
|
||||
host,
|
||||
user_agent,
|
||||
socket_factory,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
self.state = State::Reconnect { backoff };
|
||||
|
||||
return Poll::Ready(Ok(Event::Hiccup {
|
||||
backoff,
|
||||
@@ -694,14 +717,6 @@ where
|
||||
.map(|ip| SocketAddr::new(*ip, port))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn host(&self) -> String {
|
||||
self.url_prototype
|
||||
.expose_secret()
|
||||
.host_and_port()
|
||||
.0
|
||||
.to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -93,6 +93,7 @@ impl Eventloop {
|
||||
portal,
|
||||
portal_event_tx,
|
||||
portal_cmd_rx,
|
||||
resolver.clone(),
|
||||
));
|
||||
|
||||
Ok(Self {
|
||||
@@ -696,6 +697,7 @@ async fn phoenix_channel_event_loop(
|
||||
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
|
||||
mut cmd_rx: mpsc::Receiver<PortalCommand>,
|
||||
resolver: TokioResolver,
|
||||
) {
|
||||
use futures::future::Either;
|
||||
use futures::future::select;
|
||||
@@ -740,11 +742,27 @@ async fn phoenix_channel_event_loop(
|
||||
error,
|
||||
}),
|
||||
_,
|
||||
)) => tracing::info!(
|
||||
?backoff,
|
||||
?max_elapsed_time,
|
||||
"Hiccup in portal connection: {error:#}"
|
||||
),
|
||||
)) => {
|
||||
tracing::info!(
|
||||
?backoff,
|
||||
?max_elapsed_time,
|
||||
"Hiccup in portal connection: {error:#}"
|
||||
);
|
||||
|
||||
let ips = match resolver
|
||||
.lookup_ip(portal.host())
|
||||
.await
|
||||
.context("Failed to lookup portal host")
|
||||
{
|
||||
Ok(ips) => ips.into_iter().collect(),
|
||||
Err(e) => {
|
||||
tracing::debug!(host = %portal.host(), "{e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
portal.update_ips(ips);
|
||||
}
|
||||
Either::Left((Err(e), _)) => {
|
||||
let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user