From 4d2dc3dfcb47fcca1c75e22350958eea42435295 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 8 Nov 2024 05:06:42 +0000 Subject: [PATCH] refactor(connlib): don't rely on DNS when reconnecting to portal (#7289) Currently, `connlib` uses the feature of "known hosts" to provide DNS functionality for some domains even without any network connectivity. This is primarily used to ensure that when we reconnect to the portal, we can resolve the domain name which allows us to then create network connections. With recent changes to how our phoenix-channel implementation works, this is actually no longer necessary. Currently, we re-resolve the domain every time we connect, even though we already resolved them when connecting to it for the first time. This step is unnecessary and we can simply directly use the previously resolved IP addresses for the portal domain. --- rust/connlib/clients/shared/src/lib.rs | 9 +--- rust/connlib/tunnel/src/lib.rs | 3 +- rust/phoenix-channel/src/lib.rs | 71 ++++++++++++-------------- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index c9a524799..f4068220d 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -12,7 +12,7 @@ use eventloop::Command; use firezone_tunnel::ClientTunnel; use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::net::IpAddr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedReceiver; @@ -124,12 +124,7 @@ async fn connect( where CB: Callbacks + 'static, { - let tunnel = ClientTunnel::new( - tcp_socket_factory, - udp_socket_factory, - BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), - ); - + let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory); let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx); std::future::poll_fn(|cx| eventloop.poll(cx)).await?; diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 7c7069c7f..2524c34b6 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -89,11 +89,10 @@ impl ClientTunnel { pub fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - known_hosts: BTreeMap>, ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: ClientState::new(known_hosts, rand::random(), Instant::now()), + role_state: ClientState::new(BTreeMap::default(), rand::random(), Instant::now()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: Default::default(), diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index ea670ed20..b7e6b5298 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -28,7 +28,7 @@ use tokio_tungstenite::{ tungstenite::{handshake::client::Request, Message}, MaybeTlsStream, WebSocketStream, }; -use url::{Host, Url}; +use url::Url; pub use get_user_agent::get_user_agent; pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam}; @@ -72,22 +72,26 @@ enum State { impl State { fn connect( url: Url, + addresses: Vec, user_agent: String, socket_factory: Arc>, ) -> Self { - Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed()) + Self::Connecting( + create_and_connect_websocket(url, addresses, user_agent, socket_factory).boxed(), + ) } } async fn create_and_connect_websocket( url: Url, + addresses: Vec, user_agent: String, socket_factory: Arc>, ) -> Result>, InternalError> { tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal"); let duration = Duration::from_secs(5); - let socket = tokio::time::timeout(duration, make_socket(&url, &*socket_factory)) + let socket = tokio::time::timeout(duration, connect(addresses, &*socket_factory)) .await .map_err(|_| InternalError::Timeout { duration })??; @@ -98,28 +102,12 @@ async fn create_and_connect_websocket( Ok(stream) } -async fn make_socket( - url: &Url, +async fn connect( + addresses: Vec, socket_factory: &dyn SocketFactory, ) -> Result { - let port = url - .port_or_known_default() - .expect("scheme to be http, https, ws or wss"); - let addrs: Vec = match url.host().ok_or(InternalError::InvalidUrl)? { - Host::Domain(n) => tokio::net::lookup_host((n, port)) - .await - .map_err(|_| InternalError::InvalidUrl)? - .collect(), - Host::Ipv6(ip) => { - vec![(ip, port).into()] - } - Host::Ipv4(ip) => { - vec![(ip, port).into()] - } - }; - let mut last_error = None; - for addr in addrs { + for addr in addresses { let Ok(socket) = socket_factory(&addr) else { continue; }; @@ -274,8 +262,7 @@ where let _span = telemetry_span!("resolve_portal_url", host = %host_and_port.0).entered(); // Statically resolve the host in the URL to a set of addresses. - // We don't use these directly because we need to connect to the domain via TLS which requires a hostname. - // We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs. + // We use these when connecting the socket to avoid a dependency on DNS resolution later on. let resolved_addresses = host_and_port .to_socket_addrs()? .map(|addr| addr.ip()) @@ -304,16 +291,6 @@ where }) } - /// Returns the addresses that have been resolved for our server host. - pub fn resolved_addresses(&self) -> Vec { - self.resolved_addresses.clone() - } - - /// The host we are connecting / connected to. - pub fn server_host(&self) -> &str { - self.url_prototype.expose_secret().host_and_port().0 - } - /// Join the provided room. /// /// If successful, a [`Event::JoinedRoom`] event will be emitted. @@ -352,7 +329,12 @@ where // 2. Set state to `Connecting` without a timer. let user_agent = self.user_agent.clone(); - self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone()); + self.state = State::connect( + url.clone(), + self.socket_addresses(), + user_agent, + self.socket_factory.clone(), + ); self.last_url = Some(url); // 3. In case we were already re-connecting, we need to wake the suspended task. @@ -431,13 +413,19 @@ where .clone(); let user_agent = self.user_agent.clone(); let socket_factory = self.socket_factory.clone(); + let socket_addresses = self.socket_addresses(); tracing::debug!(error = std_dyn_err(&e), ?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error"); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; - create_and_connect_websocket(secret_url, user_agent, socket_factory) - .await + create_and_connect_websocket( + secret_url, + socket_addresses, + user_agent, + socket_factory, + ) + .await })); continue; @@ -645,6 +633,15 @@ where OutboundRequestId(next_id) } + + fn socket_addresses(&self) -> Vec { + let port = self.url_prototype.expose_secret().host_and_port().1; + + self.resolved_addresses + .iter() + .map(|ip| SocketAddr::new(*ip, port)) + .collect() + } } #[derive(Debug)]