diff --git a/rust/connlib/clients/shared/src/callbacks.rs b/rust/connlib/clients/shared/src/callbacks.rs index 0cb52b36e..8d2d1f1b6 100644 --- a/rust/connlib/clients/shared/src/callbacks.rs +++ b/rust/connlib/clients/shared/src/callbacks.rs @@ -1,5 +1,4 @@ use connlib_shared::callbacks::ResourceDescription; -use firezone_tunnel::NoInterfaces; use ip_network::{Ipv4Network, Ipv6Network}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -36,9 +35,6 @@ pub trait Callbacks: Clone + Send + Sync { /// Unified error type to use across connlib. #[derive(thiserror::Error, Debug)] pub enum DisconnectError { - /// Failed to bind to interfaces. - #[error(transparent)] - NoInterfaces(#[from] NoInterfaces), /// A panic occurred. #[error("Connlib panicked: {0}")] Panic(String), diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index f3e9613c5..9b26ce723 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -78,9 +78,7 @@ where } Poll::Ready(Some(Command::Reset)) => { self.portal.reconnect(); - if let Err(e) = self.tunnel.reset() { - tracing::warn!("Failed to reconnect tunnel: {e}"); - } + self.tunnel.reset(); continue; } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index bf5ca477f..57aae5e86 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -136,7 +136,7 @@ where tcp_socket_factory, udp_socket_factory, BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), - )?; + ); let mut eventloop = Eventloop::new(tunnel, callbacks, portal, rx); diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index a9221bebe..30f63c6c1 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,8 +1,4 @@ -use crate::{ - device_channel::Device, - sockets::{NoInterfaces, Sockets}, - BUF_SIZE, -}; +use crate::{device_channel::Device, sockets::Sockets, BUF_SIZE}; use futures_util::FutureExt as _; use ip_packet::{IpPacket, MutableIpPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; @@ -42,17 +38,21 @@ impl Io { pub fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - ) -> Result { + ) -> Self { let mut sockets = Sockets::default(); - sockets.rebind(udp_socket_factory.as_ref())?; // Bind sockets on startup. Must happen within a tokio runtime context. + sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context. - Ok(Self { + Self { device: Device::new(), timeout: None, sockets, _tcp_socket_factory: tcp_socket_factory, udp_socket_factory, - }) + } + } + + pub fn poll_has_sockets(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.sockets.poll_has_sockets(cx) } pub fn poll<'b1, 'b2>( @@ -88,10 +88,8 @@ impl Io { &mut self.device } - pub fn rebind_sockets(&mut self) -> Result<(), NoInterfaces> { - self.sockets.rebind(self.udp_socket_factory.as_ref())?; - - Ok(()) + pub fn rebind_sockets(&mut self) { + self.sockets.rebind(self.udp_socket_factory.as_ref()); } pub fn reset_timeout(&mut self, timeout: Instant) { diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 160abb80a..e5c5af62c 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -19,7 +19,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, - task::{Context, Poll}, + task::{ready, Context, Poll}, time::Instant, }; use tun::Tun; @@ -64,7 +64,6 @@ pub type ClientTunnel = Tunnel; pub use client::{ClientState, Request}; pub use gateway::{GatewayState, IPV4_PEERS, IPV6_PEERS}; -pub use sockets::NoInterfaces; /// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway. /// @@ -92,25 +91,25 @@ impl ClientTunnel { tcp_socket_factory: Arc>, udp_socket_factory: Arc>, known_hosts: BTreeMap>, - ) -> Result { - Ok(Self { - io: Io::new(tcp_socket_factory, udp_socket_factory)?, + ) -> Self { + Self { + io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: ClientState::new(private_key, known_hosts, rand::random()), packet_buffer: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - }) + } } - pub fn reset(&mut self) -> Result<(), NoInterfaces> { + pub fn reset(&mut self) { self.role_state.reset(); - self.io.rebind_sockets()?; - - Ok(()) + self.io.rebind_sockets(); } pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { for _ in 0..MAX_EVENTLOOP_ITERS { + ready!(self.io.poll_has_sockets(cx)); // Suspend everything if we don't have any sockets. + if let Some(e) = self.role_state.poll_event() { return Poll::Ready(Ok(e)); } @@ -182,14 +181,14 @@ impl GatewayTunnel { private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - ) -> Result { - Ok(Self { - io: Io::new(tcp_socket_factory, udp_socket_factory)?, + ) -> Self { + Self { + io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: GatewayState::new(private_key, rand::random()), packet_buffer: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - }) + } } pub fn update_relays(&mut self, to_remove: BTreeSet, to_add: Vec) { @@ -199,6 +198,8 @@ impl GatewayTunnel { pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { for _ in 0..MAX_EVENTLOOP_ITERS { + ready!(self.io.poll_has_sockets(cx)); // Suspend everything if we don't have any sockets. + if let Some(other) = self.role_state.poll_event() { return Poll::Ready(Ok(other)); } diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 876086375..7d1326571 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -2,7 +2,7 @@ use socket_factory::{DatagramIn, DatagramOut, SocketFactory, UdpSocket}; use std::{ io, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, - task::{ready, Context, Poll}, + task::{ready, Context, Poll, Waker}, }; const UNSPECIFIED_V4_SOCKET: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0); @@ -10,39 +10,39 @@ const UNSPECIFIED_V6_SOCKET: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIF #[derive(Default)] pub(crate) struct Sockets { + waker: Option, + socket_v4: Option, socket_v6: Option, } impl Sockets { - pub fn rebind( - &mut self, - socket_factory: &dyn SocketFactory, - ) -> Result<(), NoInterfaces> { - let socket_v4 = socket_factory(&SocketAddr::V4(UNSPECIFIED_V4_SOCKET)); - let socket_v6 = socket_factory(&SocketAddr::V6(UNSPECIFIED_V6_SOCKET)); + pub fn rebind(&mut self, socket_factory: &dyn SocketFactory) { + self.socket_v4 = socket_factory(&SocketAddr::V4(UNSPECIFIED_V4_SOCKET)) + .inspect_err(|e| tracing::warn!("Failed to bind IPv4 socket: {e}")) + .ok(); + self.socket_v6 = socket_factory(&SocketAddr::V6(UNSPECIFIED_V6_SOCKET)) + .inspect_err(|e| tracing::warn!("Failed to bind IPv6 socket: {e}")) + .ok(); - let (socket_v4, socket_v6) = match (socket_v4, socket_v6) { - (Err(e), Ok(socket)) => { - tracing::warn!("Failed to bind IPv4 socket: {e}"); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } - (None, Some(socket)) + pub fn poll_has_sockets(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.socket_v4.is_none() && self.socket_v6.is_none() { + let previous = self.waker.replace(cx.waker().clone()); + + if previous.is_none() { + // If we didn't have a waker yet, it means we just lost our sockets. Let the user know everything will be suspended. + tracing::error!("No available UDP sockets") } - (Ok(socket), Err(e)) => { - tracing::warn!("Failed to bind IPv6 socket: {e}"); - (Some(socket), None) - } - (Err(e4), Err(e6)) => { - return Err(NoInterfaces { e4, e6 }); - } - (Ok(v4), Ok(v6)) => (Some(v4), Some(v6)), - }; + return Poll::Pending; + } - self.socket_v4 = socket_v4; - self.socket_v6 = socket_v6; - - Ok(()) + Poll::Ready(()) } /// Flushes all buffered data on the sockets. @@ -77,7 +77,7 @@ impl Sockets { } pub fn poll_recv_from<'b>( - &self, + &mut self, ip4_buffer: &'b mut [u8], ip6_buffer: &'b mut [u8], cx: &mut Context<'_>, @@ -108,13 +108,6 @@ impl Sockets { } } -#[derive(thiserror::Error, Debug)] -#[error("Failed to bind to interfaces: {e4} | {e6}")] -pub struct NoInterfaces { - e4: io::Error, - e6: io::Error, -} - struct PacketIter { ip4: Option, ip6: Option, diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index e9a93a182..708ab7d3c 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -110,7 +110,7 @@ async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { private_key, Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory), - )?; + ); let portal = PhoenixChannel::connect( Secret::new(login), get_user_agent(None, env!("CARGO_PKG_VERSION")),