mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): don't rely on DNS when reconnecting to portal (#7289)
Currently, `connlib` uses the feature of "known hosts" to provide DNS functionality for some domains even without any network connectivity. This is primarily used to ensure that when we reconnect to the portal, we can resolve the domain name which allows us to then create network connections. With recent changes to how our phoenix-channel implementation works, this is actually no longer necessary. Currently, we re-resolve the domain every time we connect, even though we already resolved them when connecting to it for the first time. This step is unnecessary and we can simply directly use the previously resolved IP addresses for the portal domain.
This commit is contained in:
@@ -12,7 +12,7 @@ use eventloop::Command;
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
@@ -124,12 +124,7 @@ async fn connect<CB>(
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let tunnel = ClientTunnel::new(
|
||||
tcp_socket_factory,
|
||||
udp_socket_factory,
|
||||
BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]),
|
||||
);
|
||||
|
||||
let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory);
|
||||
let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx);
|
||||
|
||||
std::future::poll_fn(|cx| eventloop.poll(cx)).await?;
|
||||
|
||||
@@ -89,11 +89,10 @@ impl ClientTunnel {
|
||||
pub fn new(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
known_hosts: BTreeMap<String, Vec<IpAddr>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
io: Io::new(tcp_socket_factory, udp_socket_factory),
|
||||
role_state: ClientState::new(known_hosts, rand::random(), Instant::now()),
|
||||
role_state: ClientState::new(BTreeMap::default(), rand::random(), Instant::now()),
|
||||
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
encrypt_buf: Default::default(),
|
||||
|
||||
@@ -28,7 +28,7 @@ use tokio_tungstenite::{
|
||||
tungstenite::{handshake::client::Request, Message},
|
||||
MaybeTlsStream, WebSocketStream,
|
||||
};
|
||||
use url::{Host, Url};
|
||||
use url::Url;
|
||||
|
||||
pub use get_user_agent::get_user_agent;
|
||||
pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam};
|
||||
@@ -72,22 +72,26 @@ enum State {
|
||||
impl State {
|
||||
fn connect(
|
||||
url: Url,
|
||||
addresses: Vec<SocketAddr>,
|
||||
user_agent: String,
|
||||
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
) -> Self {
|
||||
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
|
||||
Self::Connecting(
|
||||
create_and_connect_websocket(url, addresses, user_agent, socket_factory).boxed(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_and_connect_websocket(
|
||||
url: Url,
|
||||
addresses: Vec<SocketAddr>,
|
||||
user_agent: String,
|
||||
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
|
||||
tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal");
|
||||
|
||||
let duration = Duration::from_secs(5);
|
||||
let socket = tokio::time::timeout(duration, make_socket(&url, &*socket_factory))
|
||||
let socket = tokio::time::timeout(duration, connect(addresses, &*socket_factory))
|
||||
.await
|
||||
.map_err(|_| InternalError::Timeout { duration })??;
|
||||
|
||||
@@ -98,28 +102,12 @@ async fn create_and_connect_websocket(
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
async fn make_socket(
|
||||
url: &Url,
|
||||
async fn connect(
|
||||
addresses: Vec<SocketAddr>,
|
||||
socket_factory: &dyn SocketFactory<TcpSocket>,
|
||||
) -> Result<TcpStream, InternalError> {
|
||||
let port = url
|
||||
.port_or_known_default()
|
||||
.expect("scheme to be http, https, ws or wss");
|
||||
let addrs: Vec<SocketAddr> = match url.host().ok_or(InternalError::InvalidUrl)? {
|
||||
Host::Domain(n) => tokio::net::lookup_host((n, port))
|
||||
.await
|
||||
.map_err(|_| InternalError::InvalidUrl)?
|
||||
.collect(),
|
||||
Host::Ipv6(ip) => {
|
||||
vec![(ip, port).into()]
|
||||
}
|
||||
Host::Ipv4(ip) => {
|
||||
vec![(ip, port).into()]
|
||||
}
|
||||
};
|
||||
|
||||
let mut last_error = None;
|
||||
for addr in addrs {
|
||||
for addr in addresses {
|
||||
let Ok(socket) = socket_factory(&addr) else {
|
||||
continue;
|
||||
};
|
||||
@@ -274,8 +262,7 @@ where
|
||||
let _span = telemetry_span!("resolve_portal_url", host = %host_and_port.0).entered();
|
||||
|
||||
// Statically resolve the host in the URL to a set of addresses.
|
||||
// We don't use these directly because we need to connect to the domain via TLS which requires a hostname.
|
||||
// We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs.
|
||||
// We use these when connecting the socket to avoid a dependency on DNS resolution later on.
|
||||
let resolved_addresses = host_and_port
|
||||
.to_socket_addrs()?
|
||||
.map(|addr| addr.ip())
|
||||
@@ -304,16 +291,6 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the addresses that have been resolved for our server host.
|
||||
pub fn resolved_addresses(&self) -> Vec<IpAddr> {
|
||||
self.resolved_addresses.clone()
|
||||
}
|
||||
|
||||
/// The host we are connecting / connected to.
|
||||
pub fn server_host(&self) -> &str {
|
||||
self.url_prototype.expose_secret().host_and_port().0
|
||||
}
|
||||
|
||||
/// Join the provided room.
|
||||
///
|
||||
/// If successful, a [`Event::JoinedRoom`] event will be emitted.
|
||||
@@ -352,7 +329,12 @@ where
|
||||
|
||||
// 2. Set state to `Connecting` without a timer.
|
||||
let user_agent = self.user_agent.clone();
|
||||
self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone());
|
||||
self.state = State::connect(
|
||||
url.clone(),
|
||||
self.socket_addresses(),
|
||||
user_agent,
|
||||
self.socket_factory.clone(),
|
||||
);
|
||||
self.last_url = Some(url);
|
||||
|
||||
// 3. In case we were already re-connecting, we need to wake the suspended task.
|
||||
@@ -431,13 +413,19 @@ where
|
||||
.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
let socket_factory = self.socket_factory.clone();
|
||||
let socket_addresses = self.socket_addresses();
|
||||
|
||||
tracing::debug!(error = std_dyn_err(&e), ?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error");
|
||||
|
||||
self.state = State::Connecting(Box::pin(async move {
|
||||
tokio::time::sleep(backoff).await;
|
||||
create_and_connect_websocket(secret_url, user_agent, socket_factory)
|
||||
.await
|
||||
create_and_connect_websocket(
|
||||
secret_url,
|
||||
socket_addresses,
|
||||
user_agent,
|
||||
socket_factory,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
|
||||
continue;
|
||||
@@ -645,6 +633,15 @@ where
|
||||
|
||||
OutboundRequestId(next_id)
|
||||
}
|
||||
|
||||
fn socket_addresses(&self) -> Vec<SocketAddr> {
|
||||
let port = self.url_prototype.expose_secret().host_and_port().1;
|
||||
|
||||
self.resolved_addresses
|
||||
.iter()
|
||||
.map(|ip| SocketAddr::new(*ip, port))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
Reference in New Issue
Block a user