diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index c9a524799..f4068220d 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -12,7 +12,7 @@ use eventloop::Command; use firezone_tunnel::ClientTunnel; use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::net::IpAddr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedReceiver; @@ -124,12 +124,7 @@ async fn connect( where CB: Callbacks + 'static, { - let tunnel = ClientTunnel::new( - tcp_socket_factory, - udp_socket_factory, - BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), - ); - + let tunnel = ClientTunnel::new(tcp_socket_factory, udp_socket_factory); let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx); std::future::poll_fn(|cx| eventloop.poll(cx)).await?; diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 7c7069c7f..2524c34b6 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -89,11 +89,10 @@ impl ClientTunnel { pub fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - known_hosts: BTreeMap>, ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: ClientState::new(known_hosts, rand::random(), Instant::now()), + role_state: ClientState::new(BTreeMap::default(), rand::random(), Instant::now()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: Default::default(), diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index ea670ed20..b7e6b5298 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -28,7 +28,7 @@ use tokio_tungstenite::{ tungstenite::{handshake::client::Request, Message}, MaybeTlsStream, WebSocketStream, }; -use url::{Host, Url}; +use url::Url; pub use get_user_agent::get_user_agent; pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam}; @@ -72,22 +72,26 @@ enum State { impl State { fn connect( url: Url, + addresses: Vec, user_agent: String, socket_factory: Arc>, ) -> Self { - Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed()) + Self::Connecting( + create_and_connect_websocket(url, addresses, user_agent, socket_factory).boxed(), + ) } } async fn create_and_connect_websocket( url: Url, + addresses: Vec, user_agent: String, socket_factory: Arc>, ) -> Result>, InternalError> { tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal"); let duration = Duration::from_secs(5); - let socket = tokio::time::timeout(duration, make_socket(&url, &*socket_factory)) + let socket = tokio::time::timeout(duration, connect(addresses, &*socket_factory)) .await .map_err(|_| InternalError::Timeout { duration })??; @@ -98,28 +102,12 @@ async fn create_and_connect_websocket( Ok(stream) } -async fn make_socket( - url: &Url, +async fn connect( + addresses: Vec, socket_factory: &dyn SocketFactory, ) -> Result { - let port = url - .port_or_known_default() - .expect("scheme to be http, https, ws or wss"); - let addrs: Vec = match url.host().ok_or(InternalError::InvalidUrl)? { - Host::Domain(n) => tokio::net::lookup_host((n, port)) - .await - .map_err(|_| InternalError::InvalidUrl)? - .collect(), - Host::Ipv6(ip) => { - vec![(ip, port).into()] - } - Host::Ipv4(ip) => { - vec![(ip, port).into()] - } - }; - let mut last_error = None; - for addr in addrs { + for addr in addresses { let Ok(socket) = socket_factory(&addr) else { continue; }; @@ -274,8 +262,7 @@ where let _span = telemetry_span!("resolve_portal_url", host = %host_and_port.0).entered(); // Statically resolve the host in the URL to a set of addresses. - // We don't use these directly because we need to connect to the domain via TLS which requires a hostname. - // We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs. + // We use these when connecting the socket to avoid a dependency on DNS resolution later on. let resolved_addresses = host_and_port .to_socket_addrs()? .map(|addr| addr.ip()) @@ -304,16 +291,6 @@ where }) } - /// Returns the addresses that have been resolved for our server host. - pub fn resolved_addresses(&self) -> Vec { - self.resolved_addresses.clone() - } - - /// The host we are connecting / connected to. - pub fn server_host(&self) -> &str { - self.url_prototype.expose_secret().host_and_port().0 - } - /// Join the provided room. /// /// If successful, a [`Event::JoinedRoom`] event will be emitted. @@ -352,7 +329,12 @@ where // 2. Set state to `Connecting` without a timer. let user_agent = self.user_agent.clone(); - self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone()); + self.state = State::connect( + url.clone(), + self.socket_addresses(), + user_agent, + self.socket_factory.clone(), + ); self.last_url = Some(url); // 3. In case we were already re-connecting, we need to wake the suspended task. @@ -431,13 +413,19 @@ where .clone(); let user_agent = self.user_agent.clone(); let socket_factory = self.socket_factory.clone(); + let socket_addresses = self.socket_addresses(); tracing::debug!(error = std_dyn_err(&e), ?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error"); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; - create_and_connect_websocket(secret_url, user_agent, socket_factory) - .await + create_and_connect_websocket( + secret_url, + socket_addresses, + user_agent, + socket_factory, + ) + .await })); continue; @@ -645,6 +633,15 @@ where OutboundRequestId(next_id) } + + fn socket_addresses(&self) -> Vec { + let port = self.url_prototype.expose_secret().host_and_port().1; + + self.resolved_addresses + .iter() + .map(|ip| SocketAddr::new(*ip, port)) + .collect() + } } #[derive(Debug)]