refactor(connlib): don't rely on DNS when reconnecting to portal (#7289)

Currently, `connlib` uses the feature of "known hosts" to provide DNS
functionality for some domains even without any network connectivity.
This is primarily used to ensure that when we reconnect to the portal,
we can resolve the domain name which allows us to then create network
connections.

With recent changes to how our phoenix-channel implementation works,
this is actually no longer necessary. Currently, we re-resolve the
domain every time we connect, even though we already resolved them when
connecting to it for the first time. This step is unnecessary and we can
simply directly use the previously resolved IP addresses for the portal
domain.
This commit is contained in:
Thomas Eizinger
2024-11-08 05:06:42 +00:00
committed by GitHub
parent cdd3e4d25c
commit 4d2dc3dfcb
3 changed files with 37 additions and 46 deletions

View File

@@ -12,7 +12,7 @@ use eventloop::Command;
use firezone_tunnel::ClientTunnel;
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::collections::{BTreeMap, BTreeSet};
use std::collections::BTreeSet;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedReceiver;
@@ -124,12 +124,7 @@ async fn connect<CB>(
where
CB: Callbacks + 'static,
{
let tunnel = ClientTunnel::new(
tcp_socket_factory,
udp_socket_factory,
BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]),
);
let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory);
let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx);
std::future::poll_fn(|cx| eventloop.poll(cx)).await?;

View File

@@ -89,11 +89,10 @@ impl ClientTunnel {
pub fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
known_hosts: BTreeMap<String, Vec<IpAddr>>,
) -> Self {
Self {
io: Io::new(tcp_socket_factory, udp_socket_factory),
role_state: ClientState::new(known_hosts, rand::random(), Instant::now()),
role_state: ClientState::new(BTreeMap::default(), rand::random(), Instant::now()),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
encrypt_buf: Default::default(),

View File

@@ -28,7 +28,7 @@ use tokio_tungstenite::{
tungstenite::{handshake::client::Request, Message},
MaybeTlsStream, WebSocketStream,
};
use url::{Host, Url};
use url::Url;
pub use get_user_agent::get_user_agent;
pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam};
@@ -72,22 +72,26 @@ enum State {
impl State {
fn connect(
url: Url,
addresses: Vec<SocketAddr>,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> Self {
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
Self::Connecting(
create_and_connect_websocket(url, addresses, user_agent, socket_factory).boxed(),
)
}
}
async fn create_and_connect_websocket(
url: Url,
addresses: Vec<SocketAddr>,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal");
let duration = Duration::from_secs(5);
let socket = tokio::time::timeout(duration, make_socket(&url, &*socket_factory))
let socket = tokio::time::timeout(duration, connect(addresses, &*socket_factory))
.await
.map_err(|_| InternalError::Timeout { duration })??;
@@ -98,28 +102,12 @@ async fn create_and_connect_websocket(
Ok(stream)
}
async fn make_socket(
url: &Url,
async fn connect(
addresses: Vec<SocketAddr>,
socket_factory: &dyn SocketFactory<TcpSocket>,
) -> Result<TcpStream, InternalError> {
let port = url
.port_or_known_default()
.expect("scheme to be http, https, ws or wss");
let addrs: Vec<SocketAddr> = match url.host().ok_or(InternalError::InvalidUrl)? {
Host::Domain(n) => tokio::net::lookup_host((n, port))
.await
.map_err(|_| InternalError::InvalidUrl)?
.collect(),
Host::Ipv6(ip) => {
vec![(ip, port).into()]
}
Host::Ipv4(ip) => {
vec![(ip, port).into()]
}
};
let mut last_error = None;
for addr in addrs {
for addr in addresses {
let Ok(socket) = socket_factory(&addr) else {
continue;
};
@@ -274,8 +262,7 @@ where
let _span = telemetry_span!("resolve_portal_url", host = %host_and_port.0).entered();
// Statically resolve the host in the URL to a set of addresses.
// We don't use these directly because we need to connect to the domain via TLS which requires a hostname.
// We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs.
// We use these when connecting the socket to avoid a dependency on DNS resolution later on.
let resolved_addresses = host_and_port
.to_socket_addrs()?
.map(|addr| addr.ip())
@@ -304,16 +291,6 @@ where
})
}
/// Returns the addresses that have been resolved for our server host.
pub fn resolved_addresses(&self) -> Vec<IpAddr> {
self.resolved_addresses.clone()
}
/// The host we are connecting / connected to.
pub fn server_host(&self) -> &str {
self.url_prototype.expose_secret().host_and_port().0
}
/// Join the provided room.
///
/// If successful, a [`Event::JoinedRoom`] event will be emitted.
@@ -352,7 +329,12 @@ where
// 2. Set state to `Connecting` without a timer.
let user_agent = self.user_agent.clone();
self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone());
self.state = State::connect(
url.clone(),
self.socket_addresses(),
user_agent,
self.socket_factory.clone(),
);
self.last_url = Some(url);
// 3. In case we were already re-connecting, we need to wake the suspended task.
@@ -431,13 +413,19 @@ where
.clone();
let user_agent = self.user_agent.clone();
let socket_factory = self.socket_factory.clone();
let socket_addresses = self.socket_addresses();
tracing::debug!(error = std_dyn_err(&e), ?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error");
self.state = State::Connecting(Box::pin(async move {
tokio::time::sleep(backoff).await;
create_and_connect_websocket(secret_url, user_agent, socket_factory)
.await
create_and_connect_websocket(
secret_url,
socket_addresses,
user_agent,
socket_factory,
)
.await
}));
continue;
@@ -645,6 +633,15 @@ where
OutboundRequestId(next_id)
}
fn socket_addresses(&self) -> Vec<SocketAddr> {
let port = self.url_prototype.expose_secret().host_and_port().1;
self.resolved_addresses
.iter()
.map(|ip| SocketAddr::new(*ip, port))
.collect()
}
}
#[derive(Debug)]