diff --git a/rust/client-shared/Cargo.toml b/rust/client-shared/Cargo.toml index d098aac00..15e4f5dbb 100644 --- a/rust/client-shared/Cargo.toml +++ b/rust/client-shared/Cargo.toml @@ -32,6 +32,7 @@ url = { workspace = true, features = ["serde"] } [dev-dependencies] chrono = { workspace = true } serde_json = { workspace = true, features = ["std"] } +tokio = { workspace = true, features = ["macros"] } [lints] workspace = true diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 0e9c99aba..ec309da3e 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -1,6 +1,7 @@ use crate::PHOENIX_TOPIC; use anyhow::{Context as _, Result}; use connlib_model::{PublicKey, ResourceView}; +use dns_types::DomainName; use firezone_tunnel::messages::RelaysPresence; use firezone_tunnel::messages::client::{ EgressMessages, FailReason, FlowCreated, FlowCreationFailed, GatewayIceCandidates, @@ -9,21 +10,25 @@ use firezone_tunnel::messages::client::{ use firezone_tunnel::{ ClientEvent, ClientTunnel, DnsResourceRecord, IpConfig, TunConfig, TunnelError, }; +use futures::TryFutureExt; +use futures::stream::FuturesUnordered; use parking_lot::Mutex; use phoenix_channel::{ErrorReply, PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::ops::ControlFlow; use std::pin::pin; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use std::{ collections::BTreeSet, io, net::IpAddr, task::{Context, Poll}, }; -use std::{future, mem}; +use std::{future, iter, mem}; use tokio::sync::{mpsc, watch}; +use tokio_stream::StreamExt; use tun::Tun; /// In-memory cache for DNS resource records. @@ -71,6 +76,7 @@ pub enum Command { enum PortalCommand { Connect(PublicKeyParam), Send(EgressMessages), + UpdateDnsServers(Vec), } /// Unified error type to use across connlib. @@ -109,7 +115,7 @@ impl Eventloop { let tunnel = ClientTunnel::new( tcp_socket_factory, - udp_socket_factory, + udp_socket_factory.clone(), DNS_RESOURCE_RECORDS_CACHE.lock().clone(), is_internet_resource_active, ); @@ -120,6 +126,7 @@ impl Eventloop { portal, portal_event_tx, portal_cmd_rx, + UdpDnsClient::new(udp_socket_factory), )); Self { @@ -285,6 +292,12 @@ impl Eventloop { .context("Failed to emit event")?; } ClientEvent::TunInterfaceUpdated(config) => { + self.portal_cmd_tx + .send(PortalCommand::UpdateDnsServers( + config.dns_by_sentinel.upstream_sockets(), + )) + .await + .context("Failed to send message to portal")?; self.tun_config_sender .send(Some(config)) .context("Failed to emit event")?; @@ -494,6 +507,7 @@ async fn phoenix_channel_event_loop( mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, + mut udp_dns_client: UdpDnsClient, ) { use futures::future::Either; use futures::future::select; @@ -534,11 +548,27 @@ async fn phoenix_channel_event_loop( error, }), _, - )) => tracing::info!( - ?backoff, - ?max_elapsed_time, - "Hiccup in portal connection: {error:#}" - ), + )) => { + tracing::info!( + ?backoff, + ?max_elapsed_time, + "Hiccup in portal connection: {error:#}" + ); + + let ips = match udp_dns_client + .resolve(portal.host()) + .await + .context("Failed to lookup portal host") + { + Ok(ips) => ips.into_iter().collect(), + Err(e) => { + tracing::debug!(host = %portal.host(), "{e:#}"); + continue; + } + }; + + portal.update_ips(ips); + } Either::Left((Err(e), _)) => { let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway. @@ -550,6 +580,9 @@ async fn phoenix_channel_event_loop( Either::Right((Some(PortalCommand::Connect(param)), _)) => { portal.connect(param); } + Either::Right((Some(PortalCommand::UpdateDnsServers(servers)), _)) => { + udp_dns_client.servers = servers; + } Either::Right((None, _)) => { tracing::debug!("Command channel closed: exiting phoenix-channel event-loop"); @@ -569,3 +602,130 @@ fn is_unreachable(e: &io::Error) -> bool { || e.kind() == io::ErrorKind::HostUnreachable || e.kind() == io::ErrorKind::AddrNotAvailable } + +struct UdpDnsClient { + socket_factory: Arc>, + servers: Vec, +} + +impl UdpDnsClient { + const TIMEOUT: Duration = Duration::from_secs(2); + + fn new(socket_factory: Arc>) -> Self { + Self { + socket_factory, + servers: Vec::default(), + } + } + + async fn resolve(&self, host: String) -> Result> { + let host = DomainName::vec_from_str(&host).context("Failed to parse domain name")?; + let servers = self.servers.clone(); + + let (a_records, aaaa_records) = self + .servers + .iter() + .map(|socket| { + futures::future::try_join( + self.send( + *socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::A), + ), + self.send( + *socket, + dns_types::Query::new(host.clone(), dns_types::RecordType::AAAA), + ), + ) + .map_err(|e| { + tracing::debug!(%host, "DNS query failed: {e:#}"); + + e + }) + }) + .collect::>() + .filter_map(|result| result.ok()) + .filter(|(a, b)| { + a.response_code() == dns_types::ResponseCode::NOERROR + && b.response_code() == dns_types::ResponseCode::NOERROR + }) + .next() + .await + .with_context(|| { + format!("All DNS servers ({servers:?}) failed to resolve portal host '{host}'") + })?; + + let ips = iter::empty() + .chain( + a_records + .records() + .filter_map(dns_types::records::extract_ip), + ) + .chain( + aaaa_records + .records() + .filter_map(dns_types::records::extract_ip), + ) + .collect(); + + Ok(ips) + } + + async fn send( + &self, + server: SocketAddr, + query: dns_types::Query, + ) -> io::Result { + let bind_addr = match server { + SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + }; + + // To avoid fragmentation, IP and thus also UDP packets can only reliably sent with an MTU of <= 1500 on the public Internet. + const BUF_SIZE: usize = 1500; + + let udp_socket = self.socket_factory.bind(bind_addr)?; + + let response = tokio::time::timeout( + Self::TIMEOUT, + udp_socket.handshake::(server, &query.into_bytes()), + ) + .await??; + + let response = dns_types::Response::parse(&response).map_err(io::Error::other)?; + + Ok(response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn udp_dns_client_can_resolve_host() { + let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); + client.servers = vec![SocketAddr::new(IpAddr::from([1, 1, 1, 1]), 53)]; + + let ips = client.resolve("example.com".to_owned()).await.unwrap(); + + assert!(!ips.is_empty()) + } + + #[tokio::test] + #[ignore = "Requires Internet"] + async fn udp_dns_client_times_out_unreachable_host() { + let mut client = UdpDnsClient::new(Arc::new(socket_factory::udp)); + client.servers = vec![SocketAddr::new(IpAddr::from([2, 2, 2, 2]), 53)]; + + let now = Instant::now(); + + let error = client.resolve("example.com".to_owned()).await.unwrap_err(); + + assert_eq!( + error.to_string(), + "All DNS servers ([2.2.2.2:53]) failed to resolve portal host 'example.com'" + ); + assert!(now.elapsed() >= UdpDnsClient::TIMEOUT) + } +} diff --git a/rust/client-shared/src/lib.rs b/rust/client-shared/src/lib.rs index 7d88c6cbc..e711a18be 100644 --- a/rust/client-shared/src/lib.rs +++ b/rust/client-shared/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(test, allow(clippy::unwrap_used))] + //! Main connlib library for clients. pub use connlib_model::StaticSecret; pub use eventloop::DisconnectError; diff --git a/rust/connlib/dns-types/lib.rs b/rust/connlib/dns-types/lib.rs index ade0f6087..3fb7a7be4 100644 --- a/rust/connlib/dns-types/lib.rs +++ b/rust/connlib/dns-types/lib.rs @@ -337,6 +337,18 @@ pub mod records { pub fn srv(priority: u16, weight: u16, port: u16, target: DomainName) -> OwnedRecordData { OwnedRecordData::Srv(Srv::new(priority, weight, port, target)) } + + #[expect( + clippy::wildcard_enum_match_arm, + reason = "We explicitly only want A and AAAA records." + )] + pub fn extract_ip(r: Record<'_>) -> Option { + match r.into_data() { + RecordData::A(a) => Some(a.addr().into()), + RecordData::Aaaa(aaaa) => Some(aaaa.addr().into()), + _ => None, + } + } } #[cfg(test)] diff --git a/rust/connlib/phoenix-channel/src/lib.rs b/rust/connlib/phoenix-channel/src/lib.rs index 6fa34f399..9fdfdab1b 100644 --- a/rust/connlib/phoenix-channel/src/lib.rs +++ b/rust/connlib/phoenix-channel/src/lib.rs @@ -64,6 +64,9 @@ pub struct PhoenixChannel { } enum State { + Reconnect { + backoff: Duration, + }, Connected(WebSocketStream>), Connecting( BoxFuture<'static, Result>, InternalError>>, @@ -357,6 +360,20 @@ where self.url_prototype.expose_secret().base_url() } + pub fn host(&self) -> String { + self.url_prototype + .expose_secret() + .host_and_port() + .0 + .to_owned() + } + + pub fn update_ips(&mut self, ips: Vec) { + tracing::debug!(host = %self.host(), current = ?self.resolved_addresses, new = ?ips, "Updating resolved IPs"); + + self.resolved_addresses = ips; + } + /// Initiate a graceful close of the connection. pub fn close(&mut self) -> Result<(), Connecting> { tracing::info!("Closing connection to portal"); @@ -366,7 +383,7 @@ where State::Closing(stream) | State::Connected(stream) => { self.state = State::Closing(stream); } - State::Closed => {} + State::Closed | State::Reconnect { .. } => {} } Ok(()) @@ -393,6 +410,33 @@ where Poll::Pending => return Poll::Pending, }, State::Connected(stream) => stream, + State::Reconnect { backoff } => { + let backoff = *backoff; + let socket_addresses = self.socket_addresses(); + let host = self.host(); + + let secret_url = self + .last_url + .as_ref() + .expect("should have last URL if we receive connection error") + .clone(); + let user_agent = self.user_agent.clone(); + let socket_factory = self.socket_factory.clone(); + + self.state = State::Connecting(Box::pin(async move { + tokio::time::sleep(backoff).await; + create_and_connect_websocket( + secret_url, + socket_addresses, + host, + user_agent, + socket_factory, + ) + .await + })); + + continue; + } State::Connecting(future) => match future.poll_unpin(cx) { Poll::Ready(Ok(stream)) => { self.reconnect_backoff = None; @@ -423,9 +467,6 @@ where return Poll::Ready(Err(Error::FatalIo(io))); } Poll::Ready(Err(e)) => { - let socket_addresses = self.socket_addresses(); - let host = self.host(); - let backoff = match self.reconnect_backoff.as_mut() { Some(reconnect_backoff) => reconnect_backoff .next_backoff() @@ -439,25 +480,7 @@ where } }; - let secret_url = self - .last_url - .as_ref() - .expect("should have last URL if we receive connection error") - .clone(); - let user_agent = self.user_agent.clone(); - let socket_factory = self.socket_factory.clone(); - - self.state = State::Connecting(Box::pin(async move { - tokio::time::sleep(backoff).await; - create_and_connect_websocket( - secret_url, - socket_addresses, - host, - user_agent, - socket_factory, - ) - .await - })); + self.state = State::Reconnect { backoff }; return Poll::Ready(Ok(Event::Hiccup { backoff, @@ -694,14 +717,6 @@ where .map(|ip| SocketAddr::new(*ip, port)) .collect() } - - fn host(&self) -> String { - self.url_prototype - .expose_secret() - .host_and_port() - .0 - .to_owned() - } } #[derive(Debug)] diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index cf5f47bf1..187ad9d4d 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -93,6 +93,7 @@ impl Eventloop { portal, portal_event_tx, portal_cmd_rx, + resolver.clone(), )); Ok(Self { @@ -696,6 +697,7 @@ async fn phoenix_channel_event_loop( mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, + resolver: TokioResolver, ) { use futures::future::Either; use futures::future::select; @@ -740,11 +742,27 @@ async fn phoenix_channel_event_loop( error, }), _, - )) => tracing::info!( - ?backoff, - ?max_elapsed_time, - "Hiccup in portal connection: {error:#}" - ), + )) => { + tracing::info!( + ?backoff, + ?max_elapsed_time, + "Hiccup in portal connection: {error:#}" + ); + + let ips = match resolver + .lookup_ip(portal.host()) + .await + .context("Failed to lookup portal host") + { + Ok(ips) => ips.into_iter().collect(), + Err(e) => { + tracing::debug!(host = %portal.host(), "{e:#}"); + continue; + } + }; + + portal.update_ips(ips); + } Either::Left((Err(e), _)) => { let _ = event_tx.send(Err(e)).await; // We don't care about the result because we are exiting anyway.